EffectivePredicateExtractor.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.OperatorNotFoundException;
import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.DistinctLimitNode;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.LimitNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.SemiJoinNode;
import com.facebook.presto.spi.plan.SortNode;
import com.facebook.presto.spi.plan.SpatialJoinNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.plan.TopNNode;
import com.facebook.presto.spi.plan.UnionNode;
import com.facebook.presto.spi.plan.WindowNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.plan.AssignUniqueId;
import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.InternalPlanVisitor;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.facebook.presto.sql.relational.RowExpressionDomainTranslator;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;
import static com.facebook.presto.common.function.OperatorType.EQUAL;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.expressions.LogicalRowExpressions.FALSE_CONSTANT;
import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT;
import static com.facebook.presto.expressions.LogicalRowExpressions.extractConjuncts;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IS_NULL;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static com.facebook.presto.sql.planner.optimizations.SetOperationNodeUtils.outputMap;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.facebook.presto.sql.relational.Expressions.specialForm;
import static com.google.common.base.Predicates.in;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;
public class EffectivePredicateExtractor
{
private final RowExpressionDomainTranslator domainTranslator;
private final FunctionAndTypeManager functionAndTypeManager;
public EffectivePredicateExtractor(RowExpressionDomainTranslator domainTranslator, FunctionAndTypeManager functionAndTypeManager)
{
this.domainTranslator = requireNonNull(domainTranslator, "domainTranslator is null");
this.functionAndTypeManager = functionAndTypeManager;
}
public RowExpression extract(PlanNode node)
{
return node.accept(new Visitor(domainTranslator, functionAndTypeManager), null);
}
private static class Visitor
extends InternalPlanVisitor<RowExpression, Void>
{
private final RowExpressionDomainTranslator domainTranslator;
private final LogicalRowExpressions logicalRowExpressions;
private final RowExpressionDeterminismEvaluator determinismEvaluator;
private final FunctionAndTypeManager functionManger;
public Visitor(RowExpressionDomainTranslator domainTranslator, FunctionAndTypeManager functionAndTypeManager)
{
this.domainTranslator = requireNonNull(domainTranslator, "domainTranslator is null");
this.functionManger = requireNonNull(functionAndTypeManager);
this.determinismEvaluator = new RowExpressionDeterminismEvaluator(functionAndTypeManager);
this.logicalRowExpressions = new LogicalRowExpressions(determinismEvaluator, new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver()), functionAndTypeManager);
}
@Override
public RowExpression visitPlan(PlanNode node, Void context)
{
return TRUE_CONSTANT;
}
@Override
public RowExpression visitAggregation(AggregationNode node, Void context)
{
// GROUP BY () always produces a group, regardless of whether there's any
// input (unlike the case where there are group by keys, which produce
// no output if there's no input).
// Therefore, we can't say anything about the effective predicate of the
// output of such an aggregation.
if (node.getGroupingKeys().isEmpty()) {
return TRUE_CONSTANT;
}
RowExpression underlyingPredicate = node.getSource().accept(this, context);
return pullExpressionThroughVariables(underlyingPredicate, node.getGroupingKeys());
}
@Override
public RowExpression visitFilter(FilterNode node, Void context)
{
RowExpression underlyingPredicate = node.getSource().accept(this, context);
RowExpression predicate = node.getPredicate();
// Remove non-deterministic conjuncts
predicate = logicalRowExpressions.filterDeterministicConjuncts(predicate);
return logicalRowExpressions.combineConjuncts(predicate, underlyingPredicate);
}
@Override
public RowExpression visitExchange(ExchangeNode node, Void context)
{
return deriveCommonPredicates(node, source -> {
Map<VariableReferenceExpression, VariableReferenceExpression> mappings = new HashMap<>();
for (int i = 0; i < node.getInputs().get(source).size(); i++) {
mappings.put(
node.getOutputVariables().get(i),
node.getInputs().get(source).get(i));
}
return mappings.entrySet();
});
}
@Override
public RowExpression visitEnforceSingleRow(EnforceSingleRowNode node, Void context)
{
if (node.getSource() instanceof ProjectNode) {
return node.getSource().accept(this, context);
}
return TRUE_CONSTANT;
}
@Override
public RowExpression visitProject(ProjectNode node, Void context)
{
// TODO: add simple algebraic solver for projection translation (right now only considers identity projections)
RowExpression underlyingPredicate = node.getSource().accept(this, context);
List<RowExpression> projectionEqualities = node.getAssignments().getMap().entrySet().stream()
.filter(this::notIdentityAssignment)
.filter(this::canCompareEquity)
.map(this::toEquality)
.collect(toImmutableList());
return pullExpressionThroughVariables(logicalRowExpressions.combineConjuncts(
ImmutableList.<RowExpression>builder()
.addAll(projectionEqualities)
.add(underlyingPredicate)
.build()),
node.getOutputVariables());
}
@Override
public RowExpression visitTopN(TopNNode node, Void context)
{
return node.getSource().accept(this, context);
}
@Override
public RowExpression visitLimit(LimitNode node, Void context)
{
return node.getSource().accept(this, context);
}
@Override
public RowExpression visitAssignUniqueId(AssignUniqueId node, Void context)
{
return node.getSource().accept(this, context);
}
@Override
public RowExpression visitDistinctLimit(DistinctLimitNode node, Void context)
{
return node.getSource().accept(this, context);
}
@Override
public RowExpression visitTableScan(TableScanNode node, Void context)
{
Map<ColumnHandle, VariableReferenceExpression> assignments = ImmutableBiMap.copyOf(node.getAssignments()).inverse();
return domainTranslator.toPredicate(node.getCurrentConstraint().simplify().transform(column -> assignments.containsKey(column) ? assignments.get(column) : null));
}
@Override
public RowExpression visitSort(SortNode node, Void context)
{
return node.getSource().accept(this, context);
}
@Override
public RowExpression visitWindow(WindowNode node, Void context)
{
return node.getSource().accept(this, context);
}
@Override
public RowExpression visitUnion(UnionNode node, Void context)
{
return deriveCommonPredicates(node, source -> outputMap(node, source).entries());
}
@Override
public RowExpression visitJoin(JoinNode node, Void context)
{
RowExpression leftPredicate = node.getLeft().accept(this, context);
RowExpression rightPredicate = node.getRight().accept(this, context);
List<RowExpression> joinConjuncts = node.getCriteria().stream()
.map(this::toRowExpression)
.collect(toImmutableList());
switch (node.getType()) {
case INNER:
return pullExpressionThroughVariables(logicalRowExpressions.combineConjuncts(ImmutableList.<RowExpression>builder()
.add(leftPredicate)
.add(rightPredicate)
.add(logicalRowExpressions.combineConjuncts(joinConjuncts))
.add(node.getFilter().orElse(TRUE_CONSTANT))
.build()), node.getOutputVariables());
case LEFT:
return logicalRowExpressions.combineConjuncts(ImmutableList.<RowExpression>builder()
.add(pullExpressionThroughVariables(leftPredicate, node.getOutputVariables()))
.addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(rightPredicate), node.getOutputVariables(), node.getRight().getOutputVariables()::contains))
.addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getOutputVariables(), node.getRight().getOutputVariables()::contains))
.build());
case RIGHT:
return logicalRowExpressions.combineConjuncts(ImmutableList.<RowExpression>builder()
.add(pullExpressionThroughVariables(rightPredicate, node.getOutputVariables()))
.addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(leftPredicate), node.getOutputVariables(), node.getLeft().getOutputVariables()::contains))
.addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getOutputVariables(), node.getLeft().getOutputVariables()::contains))
.build());
case FULL:
return logicalRowExpressions.combineConjuncts(ImmutableList.<RowExpression>builder()
.addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(leftPredicate), node.getOutputVariables(), node.getLeft().getOutputVariables()::contains))
.addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(rightPredicate), node.getOutputVariables(), node.getRight().getOutputVariables()::contains))
.addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getOutputVariables(), node.getLeft().getOutputVariables()::contains, node.getRight().getOutputVariables()::contains))
.build());
default:
throw new UnsupportedOperationException("Unknown join type: " + node.getType());
}
}
private Iterable<RowExpression> pullNullableConjunctsThroughOuterJoin(List<RowExpression> conjuncts, Collection<VariableReferenceExpression> outputVariables, Predicate<VariableReferenceExpression>... nullVariableScopes)
{
// Conjuncts without any symbol dependencies cannot be applied to the effective predicate (e.g. FALSE literal)
return conjuncts.stream()
.map(expression -> pullExpressionThroughVariables(expression, outputVariables))
.map(expression -> VariablesExtractor.extractAll(expression).isEmpty() ? TRUE_CONSTANT : expression)
.map(expressionOrNullVariables(nullVariableScopes))
.collect(toImmutableList());
}
public Function<RowExpression, RowExpression> expressionOrNullVariables(final Predicate<VariableReferenceExpression>... nullVariableScopes)
{
return expression -> {
ImmutableList.Builder<RowExpression> resultDisjunct = ImmutableList.builder();
resultDisjunct.add(expression);
for (Predicate<VariableReferenceExpression> nullVariableScope : nullVariableScopes) {
List<VariableReferenceExpression> variables = VariablesExtractor.extractUnique(expression).stream()
.filter(nullVariableScope)
.collect(toImmutableList());
if (Iterables.isEmpty(variables)) {
continue;
}
ImmutableList.Builder<RowExpression> nullConjuncts = ImmutableList.builder();
for (VariableReferenceExpression variable : variables) {
nullConjuncts.add(specialForm(IS_NULL, BOOLEAN, variable));
}
resultDisjunct.add(LogicalRowExpressions.and(nullConjuncts.build()));
}
return LogicalRowExpressions.or(resultDisjunct.build());
};
}
@Override
public RowExpression visitSemiJoin(SemiJoinNode node, Void context)
{
// Filtering source does not change the effective predicate over the output symbols
return node.getSource().accept(this, context);
}
@Override
public RowExpression visitSpatialJoin(SpatialJoinNode node, Void context)
{
RowExpression leftPredicate = node.getLeft().accept(this, context);
RowExpression rightPredicate = node.getRight().accept(this, context);
switch (node.getType()) {
case INNER:
return logicalRowExpressions.combineConjuncts(ImmutableList.<RowExpression>builder()
.add(pullExpressionThroughVariables(leftPredicate, node.getOutputVariables()))
.add(pullExpressionThroughVariables(rightPredicate, node.getOutputVariables()))
.build());
case LEFT:
return logicalRowExpressions.combineConjuncts(ImmutableList.<RowExpression>builder()
.add(pullExpressionThroughVariables(leftPredicate, node.getOutputVariables()))
.addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(rightPredicate), node.getOutputVariables(), node.getRight().getOutputVariables()::contains))
.build());
default:
throw new IllegalArgumentException("Unsupported spatial join type: " + node.getType());
}
}
private RowExpression toRowExpression(EquiJoinClause equiJoinClause)
{
return buildEqualsExpression(functionManger, equiJoinClause.getLeft(), equiJoinClause.getRight());
}
private RowExpression deriveCommonPredicates(PlanNode node, Function<Integer, Collection<Map.Entry<VariableReferenceExpression, VariableReferenceExpression>>> mapping)
{
// Find the predicates that can be pulled up from each source
List<Set<RowExpression>> sourceOutputConjuncts = new ArrayList<>();
for (int i = 0; i < node.getSources().size(); i++) {
RowExpression underlyingPredicate = node.getSources().get(i).accept(this, null);
List<RowExpression> equalities = mapping.apply(i).stream()
.filter(this::notIdentityAssignment)
.filter(this::canCompareEquity)
.map(this::toEquality)
.collect(toImmutableList());
sourceOutputConjuncts.add(ImmutableSet.copyOf(extractConjuncts(pullExpressionThroughVariables(logicalRowExpressions.combineConjuncts(
ImmutableList.<RowExpression>builder()
.addAll(equalities)
.add(underlyingPredicate)
.build()),
node.getOutputVariables()))));
}
// Find the intersection of predicates across all sources
// TODO: use a more precise way to determine overlapping conjuncts (e.g. commutative predicates)
Iterator<Set<RowExpression>> iterator = sourceOutputConjuncts.iterator();
Set<RowExpression> potentialOutputConjuncts = iterator.next();
while (iterator.hasNext()) {
potentialOutputConjuncts = Sets.intersection(potentialOutputConjuncts, iterator.next());
}
return logicalRowExpressions.combineConjuncts(potentialOutputConjuncts);
}
private boolean notIdentityAssignment(Map.Entry<VariableReferenceExpression, ? extends RowExpression> entry)
{
return !entry.getKey().equals(entry.getValue());
}
private boolean canCompareEquity(Map.Entry<VariableReferenceExpression, ? extends RowExpression> entry)
{
try {
functionManger.resolveOperator(EQUAL, fromTypes(entry.getKey().getType(), entry.getValue().getType()));
return true;
}
catch (OperatorNotFoundException e) {
return false;
}
}
private RowExpression toEquality(Map.Entry<VariableReferenceExpression, ? extends RowExpression> entry)
{
return buildEqualsExpression(functionManger, entry.getKey(), entry.getValue());
}
private static CallExpression buildEqualsExpression(FunctionAndTypeManager functionAndTypeManager, RowExpression left, RowExpression right)
{
return call(
left.getSourceLocation(),
EQUAL.getFunctionName().getObjectName(),
functionAndTypeManager.resolveOperator(EQUAL, fromTypes(left.getType(), right.getType())),
BOOLEAN,
left,
right);
}
private RowExpression pullExpressionThroughVariables(RowExpression expression, Collection<VariableReferenceExpression> variables)
{
EqualityInference equalityInference = new EqualityInference.Builder(functionManger)
.addEqualityInference(expression)
.build();
ImmutableList.Builder<RowExpression> effectiveConjuncts = ImmutableList.builder();
for (RowExpression conjunct : new EqualityInference.Builder(functionManger).nonInferableConjuncts(expression)) {
if (determinismEvaluator.isDeterministic(conjunct)) {
RowExpression rewritten = equalityInference.rewriteExpression(conjunct, in(variables));
if (rewritten != null && (hasVariableReferences(rewritten) || rewritten.equals(FALSE_CONSTANT))) {
effectiveConjuncts.add(rewritten);
}
// If equality inference has reduced the predicate to an expression referring to only constants, it does not make sense to pull this predicate up
}
}
effectiveConjuncts.addAll(equalityInference.generateEqualitiesPartitionedBy(in(variables)).getScopeEqualities());
return logicalRowExpressions.combineConjuncts(effectiveConjuncts.build());
}
private static boolean hasVariableReferences(RowExpression rowExpression)
{
return !VariablesExtractor.extractUnique(rowExpression).isEmpty();
}
}
}