PredicatePushDown.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.optimizations;
import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.VarcharType;
import com.facebook.presto.expressions.DynamicFilters;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.expressions.RowExpressionNodeInliner;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinDistributionType;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.MarkDistinctNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.SemiJoinNode;
import com.facebook.presto.spi.plan.SortNode;
import com.facebook.presto.spi.plan.SpatialJoinNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.plan.UnionNode;
import com.facebook.presto.spi.plan.UnnestNode;
import com.facebook.presto.spi.plan.WindowNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.ExpressionOptimizer;
import com.facebook.presto.spi.relation.ExpressionOptimizerProvider;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.EffectivePredicateExtractor;
import com.facebook.presto.sql.planner.EqualityInference;
import com.facebook.presto.sql.planner.InequalityInference;
import com.facebook.presto.sql.planner.RowExpressionVariableInliner;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.VariablesExtractor;
import com.facebook.presto.sql.planner.plan.AssignUniqueId;
import com.facebook.presto.sql.planner.plan.AssignmentUtils;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.GroupIdNode;
import com.facebook.presto.sql.planner.plan.SampleNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.relational.Expressions;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.facebook.presto.sql.relational.RowExpressionDomainTranslator;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import io.airlift.slice.Slices;
import java.util.ArrayList;
import java.util.Collection;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import static com.facebook.presto.SystemSessionProperties.shouldGenerateDomainFilters;
import static com.facebook.presto.SystemSessionProperties.shouldInferInequalityPredicates;
import static com.facebook.presto.common.function.OperatorType.BETWEEN;
import static com.facebook.presto.common.function.OperatorType.EQUAL;
import static com.facebook.presto.common.function.OperatorType.GREATER_THAN_OR_EQUAL;
import static com.facebook.presto.common.function.OperatorType.IS_DISTINCT_FROM;
import static com.facebook.presto.common.function.OperatorType.LESS_THAN_OR_EQUAL;
import static com.facebook.presto.common.function.OperatorType.NOT_EQUAL;
import static com.facebook.presto.common.function.OperatorType.negate;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.expressions.LogicalRowExpressions.FALSE_CONSTANT;
import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT;
import static com.facebook.presto.expressions.LogicalRowExpressions.extractConjuncts;
import static com.facebook.presto.spi.plan.JoinDistributionType.PARTITIONED;
import static com.facebook.presto.spi.plan.JoinDistributionType.REPLICATED;
import static com.facebook.presto.spi.plan.JoinType.FULL;
import static com.facebook.presto.spi.plan.JoinType.INNER;
import static com.facebook.presto.spi.plan.JoinType.LEFT;
import static com.facebook.presto.spi.plan.JoinType.RIGHT;
import static com.facebook.presto.spi.plan.ProjectNode.Locality;
import static com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL;
import static com.facebook.presto.spi.plan.ProjectNode.Locality.REMOTE;
import static com.facebook.presto.spi.plan.ProjectNode.Locality.UNKNOWN;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static com.facebook.presto.sql.planner.VariablesExtractor.extractUnique;
import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignments;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.facebook.presto.sql.relational.Expressions.constantNull;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Predicates.in;
import static com.google.common.base.Predicates.not;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.Iterables.filter;
import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;
public class PredicatePushDown
implements PlanOptimizer
{
private final Metadata metadata;
private final EffectivePredicateExtractor effectivePredicateExtractor;
private final SqlParser sqlParser;
private final RowExpressionDomainTranslator rowExpressionDomainTranslator;
private final boolean nativeExecution;
private final ExpressionOptimizerProvider expressionOptimizerProvider;
public PredicatePushDown(Metadata metadata, SqlParser sqlParser, ExpressionOptimizerProvider expressionOptimizerProvider, boolean nativeExecution)
{
this.metadata = requireNonNull(metadata, "metadata is null");
rowExpressionDomainTranslator = new RowExpressionDomainTranslator(metadata);
this.effectivePredicateExtractor = new EffectivePredicateExtractor(rowExpressionDomainTranslator, metadata.getFunctionAndTypeManager());
this.sqlParser = requireNonNull(sqlParser, "sqlParser is null");
this.expressionOptimizerProvider = requireNonNull(expressionOptimizerProvider, "expressionOptimizerProvider is null");
this.nativeExecution = nativeExecution;
}
@Override
public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector)
{
requireNonNull(plan, "plan is null");
requireNonNull(session, "session is null");
requireNonNull(types, "types is null");
requireNonNull(idAllocator, "idAllocator is null");
Rewriter rewriter = new Rewriter(variableAllocator, idAllocator, metadata, effectivePredicateExtractor, rowExpressionDomainTranslator, expressionOptimizerProvider, sqlParser, session, nativeExecution);
PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(rewriter, plan, TRUE_CONSTANT);
return PlanOptimizerResult.optimizerResult(rewrittenPlan, rewriter.isPlanChanged());
}
public static RowExpression createDynamicFilterExpression(String id, RowExpression input, FunctionAndTypeManager functionAndTypeManager)
{
return createDynamicFilterExpression(id, input, functionAndTypeManager, EQUAL.name());
}
private static RowExpression createDynamicFilterExpression(
String id,
RowExpression input,
FunctionAndTypeManager functionAndTypeManager,
String operator)
{
return call(
functionAndTypeManager,
DynamicFilters.DynamicFilterPlaceholderFunction.NAME,
BooleanType.BOOLEAN,
ImmutableList.of(
input,
new ConstantExpression(input.getSourceLocation(), Slices.utf8Slice(operator), VarcharType.VARCHAR),
new ConstantExpression(input.getSourceLocation(), Slices.utf8Slice(id), VarcharType.VARCHAR)));
}
private static class Rewriter
extends SimplePlanRewriter<RowExpression>
{
private final VariableAllocator variableAllocator;
private final PlanNodeIdAllocator idAllocator;
private final Metadata metadata;
private final EffectivePredicateExtractor effectivePredicateExtractor;
private final RowExpressionDomainTranslator rowExpressionDomainTranslator;
private final ExpressionOptimizerProvider expressionOptimizerProvider;
private final Session session;
private final boolean nativeExecution;
private final ExpressionEquivalence expressionEquivalence;
private final RowExpressionDeterminismEvaluator determinismEvaluator;
private final LogicalRowExpressions logicalRowExpressions;
private final FunctionAndTypeManager functionAndTypeManager;
private final ExternalCallExpressionChecker externalCallExpressionChecker;
private boolean planChanged;
private Rewriter(
VariableAllocator variableAllocator,
PlanNodeIdAllocator idAllocator,
Metadata metadata,
EffectivePredicateExtractor effectivePredicateExtractor,
RowExpressionDomainTranslator rowExpressionDomainTranslator,
ExpressionOptimizerProvider expressionOptimizerProvider,
SqlParser sqlParser,
Session session,
boolean nativeExecution)
{
this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator is null");
this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
this.metadata = requireNonNull(metadata, "metadata is null");
this.effectivePredicateExtractor = requireNonNull(effectivePredicateExtractor, "effectivePredicateExtractor is null");
this.rowExpressionDomainTranslator = rowExpressionDomainTranslator;
this.expressionOptimizerProvider = requireNonNull(expressionOptimizerProvider, "expressionOptimizerProvider is null");
this.session = requireNonNull(session, "session is null");
this.nativeExecution = nativeExecution;
this.expressionEquivalence = new ExpressionEquivalence(metadata, sqlParser);
this.determinismEvaluator = new RowExpressionDeterminismEvaluator(metadata);
this.logicalRowExpressions = new LogicalRowExpressions(determinismEvaluator, new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()), metadata.getFunctionAndTypeManager());
this.functionAndTypeManager = metadata.getFunctionAndTypeManager();
this.externalCallExpressionChecker = new ExternalCallExpressionChecker(functionAndTypeManager);
}
public boolean isPlanChanged()
{
return planChanged;
}
@Override
public PlanNode visitPlan(PlanNode node, RewriteContext<RowExpression> context)
{
PlanNode rewrittenNode = context.defaultRewrite(node, TRUE_CONSTANT);
if (!context.get().equals(TRUE_CONSTANT)) {
// Drop in a FilterNode b/c we cannot push our predicate down any further
planChanged = true;
rewrittenNode = new FilterNode(node.getSourceLocation(), idAllocator.getNextId(), rewrittenNode, context.get());
}
return rewrittenNode;
}
@Override
public PlanNode visitExchange(ExchangeNode node, RewriteContext<RowExpression> context)
{
boolean modified = false;
ImmutableList.Builder<PlanNode> builder = ImmutableList.builder();
for (int i = 0; i < node.getSources().size(); i++) {
Map<VariableReferenceExpression, VariableReferenceExpression> outputsToInputs = new HashMap<>();
for (int index = 0; index < node.getInputs().get(i).size(); index++) {
outputsToInputs.put(
node.getOutputVariables().get(index),
node.getInputs().get(i).get(index));
}
RowExpression sourcePredicate = RowExpressionVariableInliner.inlineVariables(outputsToInputs, context.get());
PlanNode source = node.getSources().get(i);
PlanNode rewrittenSource = context.rewrite(source, sourcePredicate);
if (rewrittenSource != source) {
modified = true;
}
builder.add(rewrittenSource);
}
if (modified) {
planChanged = true;
return new ExchangeNode(
node.getSourceLocation(),
node.getId(),
node.getType(),
node.getScope(),
node.getPartitioningScheme(),
builder.build(),
node.getInputs(),
node.isEnsureSourceOrdering(),
node.getOrderingScheme());
}
return node;
}
@Override
public PlanNode visitWindow(WindowNode node, RewriteContext<RowExpression> context)
{
// TODO: This could be broader. We can push down conjuncts if they are constant for all rows in a window partition.
// The simplest way to guarantee this is if the conjuncts are deterministic functions of the partitioning variables.
// This can leave out cases where they're both functions of some set of common expressions and the partitioning
// function is injective, but that's a rare case. The majority of window nodes are expected to be partitioned by
// pre-projected variables.
Predicate<RowExpression> isSupported = conjunct ->
determinismEvaluator.isDeterministic(conjunct) &&
extractUnique(conjunct).stream().allMatch(node.getPartitionBy()::contains);
Map<Boolean, List<RowExpression>> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(isSupported));
PlanNode rewrittenNode = context.defaultRewrite(node, logicalRowExpressions.combineConjuncts(conjuncts.get(true)));
if (!conjuncts.get(false).isEmpty()) {
planChanged = true;
rewrittenNode = new FilterNode(node.getSourceLocation(), idAllocator.getNextId(), rewrittenNode, logicalRowExpressions.combineConjuncts(conjuncts.get(false)));
}
return rewrittenNode;
}
@Override
public PlanNode visitProject(ProjectNode node, RewriteContext<RowExpression> context)
{
Set<VariableReferenceExpression> deterministicVariables = node.getAssignments().entrySet().stream()
.filter(entry -> determinismEvaluator.isDeterministic(entry.getValue()))
.map(Map.Entry::getKey)
.collect(Collectors.toSet());
Predicate<RowExpression> deterministic = conjunct -> deterministicVariables.containsAll(extractUnique(conjunct));
Map<Boolean, List<RowExpression>> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(deterministic));
// Push down conjuncts from the inherited predicate that only depend on deterministic assignments with
// certain limitations.
List<RowExpression> deterministicConjuncts = conjuncts.get(true);
// We partition the expressions in the deterministicConjuncts into two lists, and only inline the
// expressions that are in the inlining targets list.
Map<Boolean, List<RowExpression>> inlineConjuncts = deterministicConjuncts.stream()
.collect(Collectors.partitioningBy(expression -> isInliningCandidate(expression, node)));
List<RowExpression> inlinedDeterministicConjuncts = inlineConjuncts.get(true).stream()
.map(entry -> RowExpressionVariableInliner.inlineVariables(node.getAssignments().getMap(), entry))
.collect(Collectors.toList());
PlanNode rewrittenNode = context.defaultRewrite(node, logicalRowExpressions.combineConjuncts(inlinedDeterministicConjuncts));
// All deterministic conjuncts that contains non-inlining targets, and non-deterministic conjuncts,
// if any, will be in the filter node.
List<RowExpression> nonInliningConjuncts = inlineConjuncts.get(false);
nonInliningConjuncts.addAll(conjuncts.get(false));
if (!nonInliningConjuncts.isEmpty()) {
planChanged = true;
rewrittenNode = new FilterNode(node.getSourceLocation(), idAllocator.getNextId(), rewrittenNode, logicalRowExpressions.combineConjuncts(nonInliningConjuncts));
}
return rewrittenNode;
}
private boolean isInliningCandidate(RowExpression expression, ProjectNode node)
{
// candidate symbols for inlining are
// 1. references to simple constants
// 2. references to complex expressions that appear only once
// which come from the node, as opposed to an enclosing scope,
// and the expression does not contain remote functions.
Set<VariableReferenceExpression> childOutputSet = ImmutableSet.copyOf(node.getOutputVariables());
Map<VariableReferenceExpression, Long> dependencies = VariablesExtractor.extractAll(expression).stream()
.filter(childOutputSet::contains)
.collect(Collectors.groupingBy(identity(), Collectors.counting()));
return dependencies.entrySet().stream()
.allMatch(entry -> (entry.getValue() == 1 && !node.getAssignments().get(entry.getKey()).accept(new ExternalCallExpressionChecker(functionAndTypeManager), null)) ||
node.getAssignments().get(entry.getKey()) instanceof ConstantExpression);
}
@Override
public PlanNode visitGroupId(GroupIdNode node, RewriteContext<RowExpression> context)
{
Map<VariableReferenceExpression, VariableReferenceExpression> commonGroupingVariableMapping = node.getGroupingColumns().entrySet().stream()
.filter(entry -> node.getCommonGroupingColumns().contains(entry.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
Predicate<RowExpression> pushdownEligiblePredicate = conjunct -> extractUnique(conjunct).stream()
.allMatch(commonGroupingVariableMapping.keySet()::contains);
Map<Boolean, List<RowExpression>> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(pushdownEligiblePredicate));
// Push down conjuncts from the inherited predicate that apply to common grouping symbols
PlanNode rewrittenNode = context.defaultRewrite(node, RowExpressionVariableInliner.inlineVariables(commonGroupingVariableMapping, logicalRowExpressions.combineConjuncts(conjuncts.get(true))));
// All other conjuncts, if any, will be in the filter node.
if (!conjuncts.get(false).isEmpty()) {
planChanged = true;
rewrittenNode = new FilterNode(node.getSourceLocation(), idAllocator.getNextId(), rewrittenNode, logicalRowExpressions.combineConjuncts(conjuncts.get(false)));
}
return rewrittenNode;
}
@Override
public PlanNode visitMarkDistinct(MarkDistinctNode node, RewriteContext<RowExpression> context)
{
Set<VariableReferenceExpression> pushDownableVariables = ImmutableSet.copyOf(node.getDistinctVariables());
Map<Boolean, List<RowExpression>> conjuncts = extractConjuncts(context.get()).stream()
.collect(Collectors.partitioningBy(conjunct -> pushDownableVariables.containsAll(extractUnique(conjunct))));
PlanNode rewrittenNode = context.defaultRewrite(node, logicalRowExpressions.combineConjuncts(conjuncts.get(true)));
if (!conjuncts.get(false).isEmpty()) {
planChanged = true;
rewrittenNode = new FilterNode(node.getSourceLocation(), idAllocator.getNextId(), rewrittenNode, logicalRowExpressions.combineConjuncts(conjuncts.get(false)));
}
return rewrittenNode;
}
@Override
public PlanNode visitSort(SortNode node, RewriteContext<RowExpression> context)
{
return context.defaultRewrite(node, context.get());
}
@Override
public PlanNode visitUnion(UnionNode node, RewriteContext<RowExpression> context)
{
boolean modified = false;
ImmutableList.Builder<PlanNode> builder = ImmutableList.builder();
for (int i = 0; i < node.getSources().size(); i++) {
RowExpression sourcePredicate = RowExpressionVariableInliner.inlineVariables(node.sourceVariableMap(i), context.get());
PlanNode source = node.getSources().get(i);
PlanNode rewrittenSource = context.rewrite(source, sourcePredicate);
if (rewrittenSource != source) {
modified = true;
}
builder.add(rewrittenSource);
}
if (modified) {
planChanged = true;
return new UnionNode(node.getSourceLocation(), node.getId(), builder.build(), node.getOutputVariables(), node.getVariableMapping());
}
return node;
}
@Deprecated
@Override
public PlanNode visitFilter(FilterNode node, RewriteContext<RowExpression> context)
{
PlanNode rewrittenPlan = context.rewrite(node.getSource(), logicalRowExpressions.combineConjuncts(node.getPredicate(), context.get()));
if (!(rewrittenPlan instanceof FilterNode)) {
planChanged = true;
return rewrittenPlan;
}
FilterNode rewrittenFilterNode = (FilterNode) rewrittenPlan;
if (!areExpressionsEquivalent(rewrittenFilterNode.getPredicate(), node.getPredicate())
|| node.getSource() != rewrittenFilterNode.getSource()) {
planChanged = true;
return rewrittenPlan;
}
return node;
}
@Override
public PlanNode visitJoin(JoinNode node, RewriteContext<RowExpression> context)
{
RowExpression inheritedPredicate = context.get();
// See if we can rewrite outer joins in terms of a plain inner join
node = tryNormalizeToOuterToInnerJoin(node, inheritedPredicate);
RowExpression leftEffectivePredicate = effectivePredicateExtractor.extract(node.getLeft());
RowExpression rightEffectivePredicate = effectivePredicateExtractor.extract(node.getRight());
RowExpression joinPredicate = extractJoinPredicate(node);
RowExpression leftPredicate;
RowExpression rightPredicate;
RowExpression postJoinPredicate;
RowExpression newJoinPredicate;
switch (node.getType()) {
case INNER:
InnerJoinPushDownResult innerJoinPushDownResult = processInnerJoin(inheritedPredicate,
leftEffectivePredicate,
rightEffectivePredicate,
joinPredicate,
node.getLeft().getOutputVariables(),
shouldInferInequalityPredicates(session));
leftPredicate = innerJoinPushDownResult.getLeftPredicate();
rightPredicate = innerJoinPushDownResult.getRightPredicate();
postJoinPredicate = innerJoinPushDownResult.getPostJoinPredicate();
newJoinPredicate = innerJoinPushDownResult.getJoinPredicate();
break;
case LEFT:
OuterJoinPushDownResult leftOuterJoinPushDownResult = processLimitedOuterJoin(inheritedPredicate,
leftEffectivePredicate,
rightEffectivePredicate,
joinPredicate,
node.getLeft().getOutputVariables(),
shouldInferInequalityPredicates(session));
leftPredicate = leftOuterJoinPushDownResult.getOuterJoinPredicate();
rightPredicate = leftOuterJoinPushDownResult.getInnerJoinPredicate();
postJoinPredicate = leftOuterJoinPushDownResult.getPostJoinPredicate();
newJoinPredicate = leftOuterJoinPushDownResult.getJoinPredicate();
break;
case RIGHT:
OuterJoinPushDownResult rightOuterJoinPushDownResult = processLimitedOuterJoin(inheritedPredicate,
rightEffectivePredicate,
leftEffectivePredicate,
joinPredicate,
node.getRight().getOutputVariables(),
shouldInferInequalityPredicates(session));
leftPredicate = rightOuterJoinPushDownResult.getInnerJoinPredicate();
rightPredicate = rightOuterJoinPushDownResult.getOuterJoinPredicate();
postJoinPredicate = rightOuterJoinPushDownResult.getPostJoinPredicate();
newJoinPredicate = rightOuterJoinPushDownResult.getJoinPredicate();
break;
case FULL:
leftPredicate = TRUE_CONSTANT;
rightPredicate = TRUE_CONSTANT;
postJoinPredicate = inheritedPredicate;
newJoinPredicate = joinPredicate;
break;
default:
throw new UnsupportedOperationException("Unsupported join type: " + node.getType());
}
newJoinPredicate = simplifyExpression(newJoinPredicate);
// TODO: find a better way to directly optimize FALSE LITERAL in join predicate
if (newJoinPredicate.equals(FALSE_CONSTANT)) {
newJoinPredicate = buildEqualsExpression(functionAndTypeManager, constant(0L, BIGINT), constant(1L, BIGINT));
}
// Create identity projections for all existing symbols
Assignments.Builder leftProjections = Assignments.builder()
.putAll(identityAssignments(node.getLeft().getOutputVariables()));
Assignments.Builder rightProjections = Assignments.builder()
.putAll(identityAssignments(node.getRight().getOutputVariables()));
Locality leftLocality = LOCAL;
Locality rightLocality = LOCAL;
// Create new projections for the new join clauses
List<EquiJoinClause> equiJoinClauses = new ArrayList<>();
ImmutableList.Builder<RowExpression> joinFilterBuilder = ImmutableList.builder();
for (RowExpression conjunct : extractConjuncts(newJoinPredicate)) {
if (joinEqualityExpression(node.getLeft().getOutputVariables()).test(conjunct)) {
boolean alignedComparison = Iterables.all(extractUnique(getLeft(conjunct)), in(node.getLeft().getOutputVariables()));
RowExpression leftExpression = (alignedComparison) ? getLeft(conjunct) : getRight(conjunct);
RowExpression rightExpression = (alignedComparison) ? getRight(conjunct) : getLeft(conjunct);
VariableReferenceExpression leftVariable = variableForExpression(leftExpression);
if (!node.getLeft().getOutputVariables().contains(leftVariable)) {
leftProjections.put(leftVariable, leftExpression);
if (leftExpression.accept(externalCallExpressionChecker, null)) {
leftLocality = REMOTE;
}
}
VariableReferenceExpression rightVariable = variableForExpression(rightExpression);
if (!node.getRight().getOutputVariables().contains(rightVariable)) {
rightProjections.put(rightVariable, rightExpression);
if (rightExpression.accept(externalCallExpressionChecker, null)) {
rightLocality = REMOTE;
}
}
equiJoinClauses.add(new EquiJoinClause(leftVariable, rightVariable));
}
else {
joinFilterBuilder.add(conjunct);
}
}
PlanNode leftSource;
PlanNode rightSource;
List<RowExpression> joinFilter = joinFilterBuilder.build();
boolean dynamicFilterEnabled = isEnableDynamicFiltering();
Map<String, VariableReferenceExpression> dynamicFilters = ImmutableMap.of();
if (dynamicFilterEnabled) {
DynamicFiltersResult dynamicFiltersResult = createDynamicFilters(node, equiJoinClauses, joinFilter, idAllocator, metadata.getFunctionAndTypeManager());
dynamicFilters = dynamicFiltersResult.getDynamicFilters();
leftPredicate = logicalRowExpressions.combineConjuncts(leftPredicate, logicalRowExpressions.combineConjuncts(dynamicFiltersResult.getPredicates()));
}
boolean equiJoinClausesUnmodified = ImmutableSet.copyOf(equiJoinClauses).equals(ImmutableSet.copyOf(node.getCriteria()));
if (dynamicFilterEnabled && !equiJoinClausesUnmodified) {
leftSource = context.rewrite(wrapInProjectIfNeeded(node.getLeft(), leftProjections.build()), leftPredicate);
rightSource = context.rewrite(wrapInProjectIfNeeded(node.getRight(), rightProjections.build()), rightPredicate);
}
else {
leftSource = context.rewrite(node.getLeft(), leftPredicate);
rightSource = context.rewrite(node.getRight(), rightPredicate);
}
Optional<RowExpression> newJoinFilter = Optional.of(logicalRowExpressions.combineConjuncts(joinFilter));
if (newJoinFilter.get() == TRUE_CONSTANT) {
newJoinFilter = Optional.empty();
}
if (node.getType() == INNER && newJoinFilter.isPresent() && equiJoinClauses.isEmpty()) {
// if we do not have any equi conjunct we do not pushdown non-equality condition into
// inner join, so we plan execution as nested-loops-join followed by filter instead
// hash join.
// todo: remove the code when we have support for filter function in nested loop join
postJoinPredicate = logicalRowExpressions.combineConjuncts(postJoinPredicate, newJoinFilter.get());
newJoinFilter = Optional.empty();
}
boolean filtersEquivalent =
newJoinFilter.isPresent() == node.getFilter().isPresent() &&
(!newJoinFilter.isPresent() || areExpressionsEquivalent(newJoinFilter.get(), node.getFilter().get()));
PlanNode output = node;
if (leftSource != node.getLeft() ||
rightSource != node.getRight() ||
!filtersEquivalent ||
(dynamicFilterEnabled && !dynamicFilters.equals(node.getDynamicFilters())) ||
!equiJoinClausesUnmodified) {
leftSource = wrapInProjectIfNeeded(leftSource, leftProjections.build(), leftLocality);
rightSource = wrapInProjectIfNeeded(rightSource, rightProjections.build(), rightLocality);
checkState(ImmutableSet.<VariableReferenceExpression>builder()
.addAll(leftSource.getOutputVariables())
.addAll(rightSource.getOutputVariables())
.build().containsAll(node.getOutputVariables()),
"JoinNode predicate pushdown incorrect : Left and right source are not producing original JoinNode output variables");
// if the distribution type is already set, make sure that changes from PredicatePushDown
// don't make the join node invalid.
Optional<JoinDistributionType> distributionType = node.getDistributionType();
if (node.getDistributionType().isPresent()) {
if (node.getType().mustPartition()) {
distributionType = Optional.of(PARTITIONED);
}
if (node.getType().mustReplicate(equiJoinClauses)) {
distributionType = Optional.of(REPLICATED);
}
}
List<VariableReferenceExpression> newJoinOutputVariables = node.getOutputVariables();
// If, the new Join node is a cross-join OR
// we have a post join predicate that refers to variables that were not already referenced by the JoinNode
if ((node.getType() == INNER && equiJoinClauses.isEmpty() && !newJoinFilter.isPresent())
|| (!ImmutableSet.copyOf(newJoinOutputVariables).containsAll(extractUnique(postJoinPredicate)))) {
// Set the new output variables to be left + right output variables
newJoinOutputVariables = ImmutableList.<VariableReferenceExpression>builder()
.addAll(leftSource.getOutputVariables())
.addAll(rightSource.getOutputVariables())
.build();
}
planChanged = true;
output = new JoinNode(
node.getSourceLocation(),
node.getId(),
node.getType(),
leftSource,
rightSource,
equiJoinClauses,
newJoinOutputVariables,
newJoinFilter,
node.getLeftHashVariable(),
node.getRightHashVariable(),
distributionType,
dynamicFilters);
}
if (!postJoinPredicate.equals(TRUE_CONSTANT)) {
planChanged = true;
output = new FilterNode(node.getSourceLocation(), idAllocator.getNextId(), output, postJoinPredicate);
}
if (!node.getOutputVariables().equals(output.getOutputVariables())) {
planChanged = true;
output = new ProjectNode(node.getSourceLocation(), idAllocator.getNextId(), output, identityAssignments(node.getOutputVariables()), LOCAL);
}
return output;
}
private PlanNode wrapInProjectIfNeeded(PlanNode childNode, Assignments assignments)
{
return wrapInProjectIfNeeded(childNode, assignments, UNKNOWN);
}
private PlanNode wrapInProjectIfNeeded(PlanNode childNode, Assignments assignments, Locality locality)
{
if ((childNode instanceof ProjectNode || childNode instanceof JoinNode)
&& AssignmentUtils.isIdentity(assignments)) {
// By wrapping an identity Project over a child node of type :
// ProjectNode - we are adding no value
// JoinNode - we are preventing this JoinNode from participating in join re-ordering
// So we return the child node as is, without an identity project
return childNode;
}
return new ProjectNode(childNode.getSourceLocation(), idAllocator.getNextId(), childNode, assignments, locality);
}
private static DynamicFiltersResult createDynamicFilters(
JoinNode node,
List<EquiJoinClause> equiJoinClauses,
List<RowExpression> joinFilter,
PlanNodeIdAllocator idAllocator,
FunctionAndTypeManager functionAndTypeManager)
{
Map<String, VariableReferenceExpression> dynamicFilters = ImmutableMap.of();
List<RowExpression> predicates = ImmutableList.of();
if (node.getType() == INNER || node.getType() == RIGHT) {
List<CallExpression> clauses = getDynamicFilterClauses(node, equiJoinClauses, joinFilter, functionAndTypeManager);
List<VariableReferenceExpression> buildSymbols = clauses.stream()
.map(expression -> (VariableReferenceExpression) expression.getArguments().get(1))
.collect(Collectors.toList());
BiMap<VariableReferenceExpression, String> buildSymbolToIdMap = HashBiMap.create(node.getDynamicFilters()).inverse();
for (VariableReferenceExpression buildSymbol : buildSymbols) {
buildSymbolToIdMap.put(buildSymbol, idAllocator.getNextId().toString());
}
ImmutableList.Builder<RowExpression> predicatesBuilder = ImmutableList.builder();
for (CallExpression expression : clauses) {
RowExpression probeExpression = expression.getArguments().get(0);
VariableReferenceExpression buildSymbol = (VariableReferenceExpression) expression.getArguments().get(1);
String id = buildSymbolToIdMap.get(buildSymbol);
RowExpression predicate = createDynamicFilterExpression(id, probeExpression, functionAndTypeManager, expression.getDisplayName());
predicatesBuilder.add(predicate);
}
dynamicFilters = buildSymbolToIdMap.inverse();
predicates = predicatesBuilder.build();
}
return new DynamicFiltersResult(dynamicFilters, predicates);
}
private static List<CallExpression> getDynamicFilterClauses(
JoinNode node,
List<EquiJoinClause> equiJoinClauses,
List<RowExpression> joinFilter,
FunctionAndTypeManager functionAndTypeManager)
{
// New equiJoinClauses could potentially not contain symbols used in current dynamic filters.
// Since we use PredicatePushdown to push dynamic filters themselves,
// instead of separate ApplyDynamicFilters rule we derive dynamic filters within PredicatePushdown itself.
// Even if equiJoinClauses.equals(node.getCriteria), current dynamic filters may not match equiJoinClauses
ImmutableList.Builder<CallExpression> clausesBuilder = ImmutableList.builder();
for (EquiJoinClause clause : equiJoinClauses) {
VariableReferenceExpression probeSymbol = clause.getLeft();
VariableReferenceExpression buildSymbol = clause.getRight();
clausesBuilder.add(call(
EQUAL.name(),
functionAndTypeManager.resolveOperator(EQUAL, fromTypes(probeSymbol.getType(), buildSymbol.getType())),
BOOLEAN,
probeSymbol,
buildSymbol));
}
for (RowExpression filter : joinFilter) {
if ((filter instanceof CallExpression)) {
CallExpression call = (CallExpression) filter;
List<RowExpression> arguments = call.getArguments();
// TODO: support for complex inequalities, e.g. left < right + 10, NOT, LIKE
if (arguments.size() == 1) {
continue;
}
if (arguments.size() == 3) {
// try convert BETWEEN into GREATER_THAN_OR_EQUAL and LESS_THAN_OR_EQUAL
String function = call.getDisplayName();
if (function.equals(BETWEEN.name()) && arguments.get(0) instanceof VariableReferenceExpression) {
if (arguments.get(1) instanceof VariableReferenceExpression) {
CallExpression callExpression = call(
GREATER_THAN_OR_EQUAL.name(),
functionAndTypeManager.resolveOperator(GREATER_THAN_OR_EQUAL, fromTypes(arguments.get(0).getType(), arguments.get(1).getType())),
BOOLEAN,
arguments.get(0),
arguments.get(1));
Optional<CallExpression> comparisonExpression = getDynamicFilterComparison(node, callExpression, functionAndTypeManager);
if (comparisonExpression.isPresent()) {
clausesBuilder.add(comparisonExpression.get());
}
}
if (arguments.get(2) instanceof VariableReferenceExpression) {
CallExpression callExpression = call(
LESS_THAN_OR_EQUAL.name(),
functionAndTypeManager.resolveOperator(LESS_THAN_OR_EQUAL, fromTypes(arguments.get(0).getType(), arguments.get(2).getType())),
BOOLEAN,
arguments.get(0),
arguments.get(2));
Optional<CallExpression> comparisonExpression = getDynamicFilterComparison(node, callExpression, functionAndTypeManager);
if (comparisonExpression.isPresent()) {
clausesBuilder.add(comparisonExpression.get());
}
}
}
continue;
}
checkArgument(arguments.size() == 2, "invalid arguments count: %s", arguments.size());
Optional<CallExpression> comparisonExpression = getDynamicFilterComparison(node, call, functionAndTypeManager);
if (comparisonExpression.isPresent()) {
clausesBuilder.add(comparisonExpression.get());
}
}
}
return clausesBuilder.build();
}
private static Optional<CallExpression> getDynamicFilterComparison(
JoinNode node,
CallExpression call,
FunctionAndTypeManager functionAndTypeManager)
{
Optional<OperatorType> operatorType = functionAndTypeManager.getFunctionMetadata(call.getFunctionHandle()).getOperatorType();
if (!operatorType.isPresent()) {
return Optional.empty();
}
OperatorType operator = operatorType.get();
List<RowExpression> arguments = call.getArguments();
RowExpression left = arguments.get(0);
RowExpression right = arguments.get(1);
// supported comparison for dynamic filtering: EQUAL, LESS_THAN, LESS_THAN_OR_EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL
if (!operator.isComparisonOperator()) {
return Optional.empty();
}
if (operator == NOT_EQUAL || operator == IS_DISTINCT_FROM) {
return Optional.empty();
}
// supported expression for dynamic filtering:
// either 1. left child contains left variables and right child contains right variables
// or, 2. left child contains right variables and right child contains left variables
Set<VariableReferenceExpression> leftUniqueOutputs = extractUnique(left);
Set<VariableReferenceExpression> rightUniqueOutputs = extractUnique(right);
boolean leftChildContainsLeftVariables = node.getLeft().getOutputVariables().containsAll(leftUniqueOutputs);
boolean rightChildContainsRightVariables = node.getRight().getOutputVariables().containsAll(rightUniqueOutputs);
boolean leftChildContainsRightVariables = node.getLeft().getOutputVariables().containsAll(rightUniqueOutputs);
boolean rightChildContainsLeftVariables = node.getRight().getOutputVariables().containsAll(leftUniqueOutputs);
if (!((leftChildContainsLeftVariables && rightChildContainsRightVariables) || (leftChildContainsRightVariables && rightChildContainsLeftVariables))) {
return Optional.empty();
}
boolean shouldFlip = false;
if (leftChildContainsRightVariables && rightChildContainsLeftVariables) {
shouldFlip = true;
}
if (shouldFlip) {
operator = negate(operator);
left = arguments.get(1);
right = arguments.get(0);
}
if (!(right instanceof VariableReferenceExpression)) {
return Optional.empty();
}
return Optional.of(call(
operator.name(),
functionAndTypeManager.resolveOperator(operator, fromTypes(left.getType(), right.getType())),
BOOLEAN,
left,
right));
}
private static DynamicFiltersResult createDynamicFilters(
VariableReferenceExpression probeVariable,
VariableReferenceExpression buildVariable,
PlanNodeIdAllocator idAllocator,
FunctionAndTypeManager functionAndTypeManager)
{
ImmutableMap.Builder<String, VariableReferenceExpression> dynamicFiltersBuilder = ImmutableMap.builder();
ImmutableList.Builder<RowExpression> predicatesBuilder = ImmutableList.builder();
String id = idAllocator.getNextId().toString();
predicatesBuilder.add(createDynamicFilterExpression(id, probeVariable, functionAndTypeManager));
dynamicFiltersBuilder.put(id, buildVariable);
return new DynamicFiltersResult(dynamicFiltersBuilder.build(), predicatesBuilder.build());
}
private static class DynamicFiltersResult
{
private final Map<String, VariableReferenceExpression> dynamicFilters;
private final List<RowExpression> predicates;
public DynamicFiltersResult(Map<String, VariableReferenceExpression> dynamicFilters, List<RowExpression> predicates)
{
this.dynamicFilters = dynamicFilters;
this.predicates = predicates;
}
public Map<String, VariableReferenceExpression> getDynamicFilters()
{
return dynamicFilters;
}
public List<RowExpression> getPredicates()
{
return predicates;
}
}
private static RowExpression getLeft(RowExpression expression)
{
checkArgument(expression instanceof CallExpression && ((CallExpression) expression).getArguments().size() == 2, "must be binary call expression");
return ((CallExpression) expression).getArguments().get(0);
}
private static RowExpression getRight(RowExpression expression)
{
checkArgument(expression instanceof CallExpression && ((CallExpression) expression).getArguments().size() == 2, "must be binary call expression");
return ((CallExpression) expression).getArguments().get(1);
}
@Override
public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext<RowExpression> context)
{
RowExpression inheritedPredicate = context.get();
// See if we can rewrite left join in terms of a plain inner join
if (node.getType() == SpatialJoinNode.Type.LEFT && canConvertOuterToInner(node.getRight().getOutputVariables(), inheritedPredicate)) {
planChanged = true;
node = new SpatialJoinNode(
node.getSourceLocation(),
node.getId(),
SpatialJoinNode.Type.INNER,
node.getLeft(),
node.getRight(),
node.getOutputVariables(),
node.getFilter(),
node.getLeftPartitionVariable(),
node.getRightPartitionVariable(),
node.getKdbTree());
}
RowExpression leftEffectivePredicate = effectivePredicateExtractor.extract(node.getLeft());
RowExpression rightEffectivePredicate = effectivePredicateExtractor.extract(node.getRight());
RowExpression joinPredicate = node.getFilter();
RowExpression leftPredicate;
RowExpression rightPredicate;
RowExpression postJoinPredicate;
RowExpression newJoinPredicate;
switch (node.getType()) {
case INNER:
InnerJoinPushDownResult innerJoinPushDownResult = processInnerJoin(
inheritedPredicate,
leftEffectivePredicate,
rightEffectivePredicate,
joinPredicate,
node.getLeft().getOutputVariables(),
shouldInferInequalityPredicates(session));
leftPredicate = innerJoinPushDownResult.getLeftPredicate();
rightPredicate = innerJoinPushDownResult.getRightPredicate();
postJoinPredicate = innerJoinPushDownResult.getPostJoinPredicate();
newJoinPredicate = innerJoinPushDownResult.getJoinPredicate();
break;
case LEFT:
OuterJoinPushDownResult leftOuterJoinPushDownResult = processLimitedOuterJoin(
inheritedPredicate,
leftEffectivePredicate,
rightEffectivePredicate,
joinPredicate,
node.getLeft().getOutputVariables(),
shouldInferInequalityPredicates(session));
leftPredicate = leftOuterJoinPushDownResult.getOuterJoinPredicate();
rightPredicate = leftOuterJoinPushDownResult.getInnerJoinPredicate();
postJoinPredicate = leftOuterJoinPushDownResult.getPostJoinPredicate();
newJoinPredicate = leftOuterJoinPushDownResult.getJoinPredicate();
break;
default:
throw new IllegalArgumentException("Unsupported spatial join type: " + node.getType());
}
newJoinPredicate = simplifyExpression(newJoinPredicate);
verify(!newJoinPredicate.equals(FALSE_CONSTANT), "Spatial join predicate is missing");
PlanNode leftSource = context.rewrite(node.getLeft(), leftPredicate);
PlanNode rightSource = context.rewrite(node.getRight(), rightPredicate);
PlanNode output = node;
if (leftSource != node.getLeft() ||
rightSource != node.getRight() ||
!areExpressionsEquivalent(newJoinPredicate, joinPredicate)) {
// Create identity projections for all existing symbols
Assignments.Builder leftProjections = Assignments.builder()
.putAll(identityAssignments(node.getLeft().getOutputVariables()));
Assignments.Builder rightProjections = Assignments.builder()
.putAll(identityAssignments(node.getRight().getOutputVariables()));
leftSource = new ProjectNode(node.getSourceLocation(), idAllocator.getNextId(), leftSource, leftProjections.build(), LOCAL);
rightSource = new ProjectNode(node.getSourceLocation(), idAllocator.getNextId(), rightSource, rightProjections.build(), LOCAL);
planChanged = true;
output = new SpatialJoinNode(
node.getSourceLocation(),
node.getId(),
node.getType(),
leftSource,
rightSource,
node.getOutputVariables(),
newJoinPredicate,
node.getLeftPartitionVariable(),
node.getRightPartitionVariable(),
node.getKdbTree());
}
if (!postJoinPredicate.equals(TRUE_CONSTANT)) {
planChanged = true;
output = new FilterNode(node.getSourceLocation(), idAllocator.getNextId(), output, postJoinPredicate);
}
return output;
}
private VariableReferenceExpression variableForExpression(RowExpression expression)
{
if (expression instanceof VariableReferenceExpression) {
return (VariableReferenceExpression) expression;
}
return variableAllocator.newVariable(expression);
}
private OuterJoinPushDownResult processLimitedOuterJoin(RowExpression inheritedPredicate,
RowExpression outerEffectivePredicate,
RowExpression innerEffectivePredicate,
RowExpression joinPredicate,
Collection<VariableReferenceExpression> outerVariables,
boolean inferInequalityPredicates)
{
checkArgument(Iterables.all(extractUnique(outerEffectivePredicate), in(outerVariables)), "outerEffectivePredicate must only contain variables from outerVariables");
checkArgument(Iterables.all(extractUnique(innerEffectivePredicate), not(in(outerVariables))), "innerEffectivePredicate must not contain variables from outerVariables");
ImmutableList.Builder<RowExpression> outerPushdownConjuncts = ImmutableList.builder();
ImmutableList.Builder<RowExpression> innerPushdownConjuncts = ImmutableList.builder();
ImmutableList.Builder<RowExpression> postJoinConjuncts = ImmutableList.builder();
ImmutableList.Builder<RowExpression> joinConjuncts = ImmutableList.builder();
// Strip out non-deterministic conjuncts
postJoinConjuncts.addAll(filter(extractConjuncts(inheritedPredicate), not(determinismEvaluator::isDeterministic)));
inheritedPredicate = logicalRowExpressions.filterDeterministicConjuncts(inheritedPredicate);
outerEffectivePredicate = logicalRowExpressions.filterDeterministicConjuncts(outerEffectivePredicate);
innerEffectivePredicate = logicalRowExpressions.filterDeterministicConjuncts(innerEffectivePredicate);
joinConjuncts.addAll(filter(extractConjuncts(joinPredicate), not(determinismEvaluator::isDeterministic)));
joinPredicate = logicalRowExpressions.filterDeterministicConjuncts(joinPredicate);
// Generate equality inferences
EqualityInference inheritedInference = createEqualityInference(inheritedPredicate);
EqualityInference outerInference = createEqualityInference(inheritedPredicate, outerEffectivePredicate);
EqualityInference.EqualityPartition equalityPartition = inheritedInference.generateEqualitiesPartitionedBy(in(outerVariables));
RowExpression outerOnlyInheritedEqualities = logicalRowExpressions.combineConjuncts(equalityPartition.getScopeEqualities());
EqualityInference potentialNullSymbolInference = createEqualityInference(outerOnlyInheritedEqualities, outerEffectivePredicate, innerEffectivePredicate, joinPredicate);
// Generate inequality inferences
if (inferInequalityPredicates) {
InequalityInference inequalityInference = new InequalityInference.Builder(functionAndTypeManager, expressionEquivalence, Optional.of(outerVariables))
.addInequalityInferences(joinPredicate, inheritedPredicate)
.build();
innerPushdownConjuncts.addAll(inequalityInference.inferInequalities());
}
// See if we can push inherited predicates down
for (RowExpression conjunct : nonInferableConjuncts(inheritedPredicate)) {
RowExpression outerRewritten = outerInference.rewriteExpression(conjunct, in(outerVariables));
if (outerRewritten != null) {
outerPushdownConjuncts.add(outerRewritten);
// A conjunct can only be pushed down into an inner side if it can be rewritten in terms of the outer side
RowExpression innerRewritten = potentialNullSymbolInference.rewriteExpression(outerRewritten, not(in(outerVariables)));
if (innerRewritten != null) {
innerPushdownConjuncts.add(innerRewritten);
}
}
else {
postJoinConjuncts.add(conjunct);
}
}
if (shouldGenerateDomainFilters(session)) {
// Extract domains for each of the variables from the inherited predicate
// See related comment on #processInnerJoin
rowExpressionDomainTranslator.fromPredicate(session.toConnectorSession(), inheritedPredicate)
.getTupleDomain()
.getDomains()
.ifPresent(map -> map.forEach((variable, domain) -> {
// For outer-side, inferred domains can be pushed down as-is
if (outerVariables.contains(variable)) {
outerPushdownConjuncts.add(rowExpressionDomainTranslator.toPredicate(domain, variable));
}
// For inner-side, only domains that don't include NULL can be pushed down
else if (!domain.isNullAllowed()) {
innerPushdownConjuncts.add(rowExpressionDomainTranslator.toPredicate(domain, variable));
}
}));
}
// Add the equalities from the inferences back in
outerPushdownConjuncts.addAll(equalityPartition.getScopeEqualities());
postJoinConjuncts.addAll(equalityPartition.getScopeComplementEqualities());
postJoinConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities());
// See if we can push down any outer effective predicates to the inner side
for (RowExpression conjunct : nonInferableConjuncts(outerEffectivePredicate)) {
RowExpression rewritten = potentialNullSymbolInference.rewriteExpression(conjunct, not(in(outerVariables)));
if (rewritten != null) {
innerPushdownConjuncts.add(rewritten);
}
}
// See if we can push down join predicates to the inner side
for (RowExpression conjunct : nonInferableConjuncts(joinPredicate)) {
RowExpression innerRewritten = potentialNullSymbolInference.rewriteExpression(conjunct, not(in(outerVariables)));
if (innerRewritten != null) {
innerPushdownConjuncts.add(innerRewritten);
}
else {
joinConjuncts.add(conjunct);
}
}
// Push outer and join equalities into the inner side. For example:
// SELECT * FROM nation LEFT OUTER JOIN region ON nation.regionkey = region.regionkey and nation.name = region.name WHERE nation.name = 'blah'
EqualityInference potentialNullSymbolInferenceWithoutInnerInferred = createEqualityInference(outerOnlyInheritedEqualities, outerEffectivePredicate, joinPredicate);
innerPushdownConjuncts.addAll(potentialNullSymbolInferenceWithoutInnerInferred.generateEqualitiesPartitionedBy(not(in(outerVariables))).getScopeEqualities());
// TODO: we can further improve simplifying the equalities by considering other relationships from the outer side
EqualityInference.EqualityPartition joinEqualityPartition = createEqualityInference(joinPredicate).generateEqualitiesPartitionedBy(not(in(outerVariables)));
innerPushdownConjuncts.addAll(joinEqualityPartition.getScopeEqualities());
joinConjuncts.addAll(joinEqualityPartition.getScopeComplementEqualities())
.addAll(joinEqualityPartition.getScopeStraddlingEqualities());
return new OuterJoinPushDownResult(logicalRowExpressions.combineConjuncts(outerPushdownConjuncts.build()),
logicalRowExpressions.combineConjuncts(innerPushdownConjuncts.build()),
logicalRowExpressions.combineConjuncts(joinConjuncts.build()),
logicalRowExpressions.combineConjuncts(postJoinConjuncts.build()));
}
private static class OuterJoinPushDownResult
{
private final RowExpression outerJoinPredicate;
private final RowExpression innerJoinPredicate;
private final RowExpression joinPredicate;
private final RowExpression postJoinPredicate;
private OuterJoinPushDownResult(RowExpression outerJoinPredicate, RowExpression innerJoinPredicate, RowExpression joinPredicate, RowExpression postJoinPredicate)
{
this.outerJoinPredicate = outerJoinPredicate;
this.innerJoinPredicate = innerJoinPredicate;
this.joinPredicate = joinPredicate;
this.postJoinPredicate = postJoinPredicate;
}
private RowExpression getOuterJoinPredicate()
{
return outerJoinPredicate;
}
private RowExpression getInnerJoinPredicate()
{
return innerJoinPredicate;
}
public RowExpression getJoinPredicate()
{
return joinPredicate;
}
private RowExpression getPostJoinPredicate()
{
return postJoinPredicate;
}
}
private InnerJoinPushDownResult processInnerJoin(
RowExpression inheritedPredicate,
RowExpression leftEffectivePredicate,
RowExpression rightEffectivePredicate,
RowExpression joinPredicate,
Collection<VariableReferenceExpression> leftVariables,
boolean inferInequalityPredicates)
{
checkArgument(Iterables.all(extractUnique(leftEffectivePredicate), in(leftVariables)), "leftEffectivePredicate must only contain variables from leftVariables");
checkArgument(Iterables.all(extractUnique(rightEffectivePredicate), not(in(leftVariables))), "rightEffectivePredicate must not contain variables from leftVariables");
ImmutableList.Builder<RowExpression> leftPushDownConjuncts = ImmutableList.builder();
ImmutableList.Builder<RowExpression> rightPushDownConjuncts = ImmutableList.builder();
ImmutableList.Builder<RowExpression> joinConjuncts = ImmutableList.builder();
// Strip out non-deterministic conjuncts
joinConjuncts.addAll(filter(extractConjuncts(inheritedPredicate), not(determinismEvaluator::isDeterministic)));
inheritedPredicate = logicalRowExpressions.filterDeterministicConjuncts(inheritedPredicate);
joinConjuncts.addAll(filter(extractConjuncts(joinPredicate), not(determinismEvaluator::isDeterministic)));
joinPredicate = logicalRowExpressions.filterDeterministicConjuncts(joinPredicate);
leftEffectivePredicate = logicalRowExpressions.filterDeterministicConjuncts(leftEffectivePredicate);
rightEffectivePredicate = logicalRowExpressions.filterDeterministicConjuncts(rightEffectivePredicate);
// Generate inequality inferences
if (inferInequalityPredicates) {
InequalityInference inequalityInference = new InequalityInference.Builder(functionAndTypeManager, expressionEquivalence, Optional.empty())
.addInequalityInferences(joinPredicate, inheritedPredicate)
.build();
joinConjuncts.addAll(inequalityInference.inferInequalities());
}
// Generate equality inferences
EqualityInference allInference = new EqualityInference.Builder(functionAndTypeManager)
.addEqualityInference(inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, joinPredicate)
.build();
EqualityInference allInferenceWithoutLeftInferred = new EqualityInference.Builder(functionAndTypeManager)
.addEqualityInference(inheritedPredicate, rightEffectivePredicate, joinPredicate)
.build();
EqualityInference allInferenceWithoutRightInferred = new EqualityInference.Builder(functionAndTypeManager)
.addEqualityInference(inheritedPredicate, leftEffectivePredicate, joinPredicate)
.build();
// Sort through conjuncts in inheritedPredicate that were not used for inference
for (RowExpression conjunct : new EqualityInference.Builder(functionAndTypeManager).nonInferableConjuncts(inheritedPredicate)) {
RowExpression leftRewrittenConjunct = allInference.rewriteExpression(conjunct, in(leftVariables));
if (leftRewrittenConjunct != null) {
leftPushDownConjuncts.add(leftRewrittenConjunct);
}
RowExpression rightRewrittenConjunct = allInference.rewriteExpression(conjunct, not(in(leftVariables)));
if (rightRewrittenConjunct != null) {
rightPushDownConjuncts.add(rightRewrittenConjunct);
}
// Drop predicate after join only if unable to push down to either side
if (leftRewrittenConjunct == null && rightRewrittenConjunct == null) {
joinConjuncts.add(conjunct);
}
}
// See if we can push the right effective predicate to the left side
for (RowExpression conjunct : new EqualityInference.Builder(functionAndTypeManager).nonInferableConjuncts(rightEffectivePredicate)) {
RowExpression rewritten = allInference.rewriteExpression(conjunct, in(leftVariables));
if (rewritten != null) {
leftPushDownConjuncts.add(rewritten);
}
}
// See if we can push the left effective predicate to the right side
for (RowExpression conjunct : new EqualityInference.Builder(functionAndTypeManager).nonInferableConjuncts(leftEffectivePredicate)) {
RowExpression rewritten = allInference.rewriteExpression(conjunct, not(in(leftVariables)));
if (rewritten != null) {
rightPushDownConjuncts.add(rewritten);
}
}
// See if we can push any parts of the join predicates to either side
for (RowExpression conjunct : new EqualityInference.Builder(functionAndTypeManager).nonInferableConjuncts(joinPredicate)) {
RowExpression leftRewritten = allInference.rewriteExpression(conjunct, in(leftVariables));
if (leftRewritten != null) {
leftPushDownConjuncts.add(leftRewritten);
}
RowExpression rightRewritten = allInference.rewriteExpression(conjunct, not(in(leftVariables)));
if (rightRewritten != null) {
rightPushDownConjuncts.add(rightRewritten);
}
if (leftRewritten == null && rightRewritten == null) {
joinConjuncts.add(conjunct);
}
}
if (shouldGenerateDomainFilters(session)) {
// Extract domains for each of the variables from the inherited predicate
// and translate them back to predicates that can be added to the sources
// These prove to be helpful to extract predicates on columns that cannot be converted to CNF cleanly
// E.g. for [(Left=4 AND Right=20) or (Left=5 AND Right=21) or (Left=6 AND Right=22)]
// We would extract the TupleDomain = [ Left IN (4,5,6), Right IN (20,21,22)], then convert them into predicates on 'Left' & 'Right'
// Note that, we can end up adding logically equivalent duplicate conjuncts for some variables
// These are usually eliminated during later stages of predicate simplification. However, if some remain
// these redundant predicates do end up increasing plan node cost (with no impact to correctness)
// TODO : Move this filter addition as a cost-based rule if/when we implement a true CBO
rowExpressionDomainTranslator.fromPredicate(session.toConnectorSession(), inheritedPredicate)
.getTupleDomain()
.getDomains()
.ifPresent(map -> map.forEach((variable, domain) -> {
if (leftVariables.contains(variable)) {
leftPushDownConjuncts.add(rowExpressionDomainTranslator.toPredicate(domain, variable));
}
else {
rightPushDownConjuncts.add(rowExpressionDomainTranslator.toPredicate(domain, variable));
}
}));
}
// Add equalities from the inference back in
leftPushDownConjuncts.addAll(allInferenceWithoutLeftInferred.generateEqualitiesPartitionedBy(in(leftVariables)).getScopeEqualities());
rightPushDownConjuncts.addAll(allInferenceWithoutRightInferred.generateEqualitiesPartitionedBy(not(in(leftVariables))).getScopeEqualities());
joinConjuncts.addAll(allInference.generateEqualitiesPartitionedBy(in(leftVariables)::apply).getScopeStraddlingEqualities()); // scope straddling equalities get dropped in as part of the join predicate
return new Rewriter.InnerJoinPushDownResult(
expressionOptimizerProvider,
logicalRowExpressions.combineConjuncts(leftPushDownConjuncts.build()),
logicalRowExpressions.combineConjuncts(rightPushDownConjuncts.build()),
logicalRowExpressions.combineConjuncts(joinConjuncts.build()), TRUE_CONSTANT);
}
private static class InnerJoinPushDownResult
{
private final RowExpression leftPredicate;
private final RowExpression rightPredicate;
private final RowExpression joinPredicate;
private final RowExpression postJoinPredicate;
private final ExpressionOptimizerProvider expressionOptimizerProvider;
private InnerJoinPushDownResult(
ExpressionOptimizerProvider expressionOptimizerProvider,
RowExpression leftPredicate,
RowExpression rightPredicate,
RowExpression joinPredicate,
RowExpression postJoinPredicate)
{
this.expressionOptimizerProvider = requireNonNull(expressionOptimizerProvider, "expressionOptimizerProvider is null");
this.leftPredicate = requireNonNull(leftPredicate, "leftPredicate is null");
this.rightPredicate = requireNonNull(rightPredicate, "rightPredicate is null");
this.joinPredicate = requireNonNull(joinPredicate, "joinPredicate is null");
this.postJoinPredicate = requireNonNull(postJoinPredicate, "postJoinPredicate is null");
}
private RowExpression getLeftPredicate()
{
return leftPredicate;
}
private RowExpression getRightPredicate()
{
return rightPredicate;
}
private RowExpression getJoinPredicate()
{
return joinPredicate;
}
private RowExpression getPostJoinPredicate()
{
return postJoinPredicate;
}
}
private RowExpression extractJoinPredicate(JoinNode joinNode)
{
ImmutableList.Builder<RowExpression> builder = ImmutableList.builder();
for (EquiJoinClause equiJoinClause : joinNode.getCriteria()) {
builder.add(toRowExpression(equiJoinClause));
}
joinNode.getFilter().ifPresent(builder::add);
return logicalRowExpressions.combineConjuncts(builder.build());
}
private RowExpression toRowExpression(EquiJoinClause equiJoinClause)
{
return buildEqualsExpression(functionAndTypeManager, equiJoinClause.getLeft(), equiJoinClause.getRight());
}
private JoinNode tryNormalizeToOuterToInnerJoin(JoinNode node, RowExpression inheritedPredicate)
{
checkArgument(EnumSet.of(INNER, RIGHT, LEFT, FULL).contains(node.getType()), "Unsupported join type: %s", node.getType());
if (node.getType() == JoinType.INNER) {
return node;
}
if (node.getType() == JoinType.FULL) {
boolean canConvertToLeftJoin = canConvertOuterToInner(node.getLeft().getOutputVariables(), inheritedPredicate);
boolean canConvertToRightJoin = canConvertOuterToInner(node.getRight().getOutputVariables(), inheritedPredicate);
if (!canConvertToLeftJoin && !canConvertToRightJoin) {
return node;
}
if (canConvertToLeftJoin && canConvertToRightJoin) {
return new JoinNode(
node.getSourceLocation(),
node.getId(),
INNER,
node.getLeft(),
node.getRight(),
node.getCriteria(),
node.getOutputVariables(),
node.getFilter(),
node.getLeftHashVariable(),
node.getRightHashVariable(),
node.getDistributionType(),
node.getDynamicFilters());
}
else {
return new JoinNode(
node.getSourceLocation(),
node.getId(),
canConvertToLeftJoin ? LEFT : RIGHT,
node.getLeft(),
node.getRight(),
node.getCriteria(),
node.getOutputVariables(),
node.getFilter(),
node.getLeftHashVariable(),
node.getRightHashVariable(),
node.getDistributionType(),
node.getDynamicFilters());
}
}
if (node.getType() == JoinType.LEFT && !canConvertOuterToInner(node.getRight().getOutputVariables(), inheritedPredicate) ||
node.getType() == JoinType.RIGHT && !canConvertOuterToInner(node.getLeft().getOutputVariables(), inheritedPredicate)) {
return node;
}
return new JoinNode(
node.getSourceLocation(),
node.getId(),
JoinType.INNER,
node.getLeft(),
node.getRight(),
node.getCriteria(),
node.getOutputVariables(),
node.getFilter(),
node.getLeftHashVariable(),
node.getRightHashVariable(),
node.getDistributionType(),
node.getDynamicFilters());
}
private boolean canConvertOuterToInner(List<VariableReferenceExpression> innerVariablesForOuterJoin, RowExpression inheritedPredicate)
{
Set<VariableReferenceExpression> innerVariables = ImmutableSet.copyOf(innerVariablesForOuterJoin);
for (RowExpression conjunct : extractConjuncts(inheritedPredicate)) {
if (determinismEvaluator.isDeterministic(conjunct)) {
// Ignore a conjunct for this test if we can not deterministically get responses from it
RowExpression response = nullInputEvaluator(innerVariables, conjunct);
if (response == null || Expressions.isNull(response) || FALSE_CONSTANT.equals(response)) {
// If there is a single conjunct that returns FALSE or NULL given all NULL inputs for the inner side symbols of an outer join
// then this conjunct removes all effects of the outer join, and effectively turns this into an equivalent of an inner join.
// So, let's just rewrite this join as an INNER join
return true;
}
}
}
return false;
}
// Temporary implementation for joins because the SimplifyExpressions optimizers can not run properly on join clauses
private RowExpression simplifyExpression(RowExpression expression)
{
return expressionOptimizerProvider.getExpressionOptimizer(session.toConnectorSession()).optimize(expression, ExpressionOptimizer.Level.SERIALIZABLE, session.toConnectorSession());
}
private boolean areExpressionsEquivalent(RowExpression leftExpression, RowExpression rightExpression)
{
return expressionEquivalence.areExpressionsEquivalent(simplifyExpression(leftExpression), simplifyExpression(rightExpression));
}
/**
* Evaluates an expression's response to binding the specified input symbols to NULL
*/
private RowExpression nullInputEvaluator(final Collection<VariableReferenceExpression> nullSymbols, RowExpression expression)
{
expression = RowExpressionNodeInliner.replaceExpression(expression, nullSymbols.stream()
.collect(Collectors.toMap(identity(), variable -> constantNull(variable.getSourceLocation(), variable.getType()))));
return expressionOptimizerProvider.getExpressionOptimizer(session.toConnectorSession()).optimize(expression, ExpressionOptimizer.Level.OPTIMIZED, session.toConnectorSession());
}
private Predicate<RowExpression> joinEqualityExpression(final Collection<VariableReferenceExpression> leftVariables)
{
return expression -> {
// At this point in time, our join predicates need to be deterministic
if (determinismEvaluator.isDeterministic(expression) && isOperation(expression, EQUAL)) {
Set<VariableReferenceExpression> variables1 = extractUnique(getLeft(expression));
Set<VariableReferenceExpression> variables2 = extractUnique(getRight(expression));
if (variables1.isEmpty() || variables2.isEmpty()) {
return false;
}
return (Iterables.all(variables1, in(leftVariables)) && Iterables.all(variables2, not(in(leftVariables)))) ||
(Iterables.all(variables2, in(leftVariables)) && Iterables.all(variables1, not(in(leftVariables))));
}
return false;
};
}
private boolean isOperation(RowExpression expression, OperatorType type)
{
if (expression instanceof CallExpression) {
Optional<OperatorType> operatorType = functionAndTypeManager.getFunctionMetadata(((CallExpression) expression).getFunctionHandle()).getOperatorType();
if (operatorType.isPresent()) {
return operatorType.get().equals(type);
}
}
return false;
}
@Override
public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext<RowExpression> context)
{
Set<RowExpression> inheritedConjuncts = ImmutableSet.copyOf(extractConjuncts(context.get()));
if (inheritedConjuncts.contains(node.getSemiJoinOutput()) ||
inheritedConjuncts.contains(logicalRowExpressions.equalsCallExpression(node.getSemiJoinOutput(), TRUE_CONSTANT)) ||
inheritedConjuncts.contains(logicalRowExpressions.equalsCallExpression(TRUE_CONSTANT, node.getSemiJoinOutput()))) {
return visitFilteringSemiJoin(node, context);
}
return visitNonFilteringSemiJoin(node, context);
}
private PlanNode visitNonFilteringSemiJoin(SemiJoinNode node, RewriteContext<RowExpression> context)
{
RowExpression inheritedPredicate = context.get();
List<RowExpression> sourceConjuncts = new ArrayList<>();
List<RowExpression> postJoinConjuncts = new ArrayList<>();
// TODO: see if there are predicates that can be inferred from the semi join output
PlanNode rewrittenFilteringSource = context.defaultRewrite(node.getFilteringSource(), TRUE_CONSTANT);
// Push inheritedPredicates down to the source if they don't involve the semi join output
EqualityInference inheritedInference = new EqualityInference.Builder(functionAndTypeManager)
.addEqualityInference(inheritedPredicate)
.build();
for (RowExpression conjunct : new EqualityInference.Builder(functionAndTypeManager).nonInferableConjuncts(inheritedPredicate)) {
RowExpression rewrittenConjunct = inheritedInference.rewriteExpressionAllowNonDeterministic(conjunct, in(node.getSource().getOutputVariables()));
// Since each source row is reflected exactly once in the output, ok to push non-deterministic predicates down
if (rewrittenConjunct != null) {
sourceConjuncts.add(rewrittenConjunct);
}
else {
postJoinConjuncts.add(conjunct);
}
}
// Add the inherited equality predicates back in
EqualityInference.EqualityPartition equalityPartition = inheritedInference.generateEqualitiesPartitionedBy(in(node.getSource()
.getOutputVariables())::apply);
sourceConjuncts.addAll(equalityPartition.getScopeEqualities());
postJoinConjuncts.addAll(equalityPartition.getScopeComplementEqualities());
postJoinConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities());
PlanNode rewrittenSource = context.rewrite(node.getSource(), logicalRowExpressions.combineConjuncts(sourceConjuncts));
PlanNode output = node;
if (rewrittenSource != node.getSource() || rewrittenFilteringSource != node.getFilteringSource()) {
planChanged = true;
output = new SemiJoinNode(node.getSourceLocation(), node.getId(), rewrittenSource, rewrittenFilteringSource, node.getSourceJoinVariable(), node.getFilteringSourceJoinVariable(), node.getSemiJoinOutput(), node.getSourceHashVariable(), node.getFilteringSourceHashVariable(), node.getDistributionType(), node.getDynamicFilters());
}
if (!postJoinConjuncts.isEmpty()) {
planChanged = true;
output = new FilterNode(node.getSourceLocation(), idAllocator.getNextId(), output, logicalRowExpressions.combineConjuncts(postJoinConjuncts));
}
return output;
}
private boolean isEnableDynamicFiltering()
{
return !nativeExecution && SystemSessionProperties.isEnableDynamicFiltering(session);
}
private PlanNode visitFilteringSemiJoin(SemiJoinNode node, RewriteContext<RowExpression> context)
{
List<RowExpression> postJoinConjuncts = new ArrayList<>();
List<RowExpression> sourceConjuncts = new ArrayList<>();
List<RowExpression> filteringSourceConjuncts = new ArrayList<>();
// Remove any conjuncts which involve the semi join output from the passed in predicate, these cannot be pushed down or rewritten
ImmutableList.Builder<RowExpression> predicateOnSources = ImmutableList.builder();
LogicalRowExpressions.extractConjuncts(context.get()).forEach(conjunct -> {
if (extractUnique(conjunct).contains(node.getSemiJoinOutput())) {
postJoinConjuncts.add(conjunct);
}
else {
predicateOnSources.add(conjunct);
}
});
RowExpression inheritedPredicate = logicalRowExpressions.combineConjuncts(predicateOnSources.build());
RowExpression deterministicInheritedPredicate = logicalRowExpressions.filterDeterministicConjuncts(inheritedPredicate);
RowExpression sourceEffectivePredicate = logicalRowExpressions.filterDeterministicConjuncts(effectivePredicateExtractor.extract(node.getSource()));
RowExpression filteringSourceEffectivePredicate = logicalRowExpressions.filterDeterministicConjuncts(effectivePredicateExtractor.extract(node.getFilteringSource()));
RowExpression joinExpression = buildEqualsExpression(functionAndTypeManager, node.getSourceJoinVariable(), node.getFilteringSourceJoinVariable());
List<VariableReferenceExpression> sourceVariables = node.getSource().getOutputVariables();
List<VariableReferenceExpression> filteringSourceVariables = node.getFilteringSource().getOutputVariables();
// Generate equality inferences
EqualityInference allInference = createEqualityInference(deterministicInheritedPredicate, sourceEffectivePredicate, filteringSourceEffectivePredicate, joinExpression);
EqualityInference allInferenceWithoutSourceInferred = createEqualityInference(deterministicInheritedPredicate, filteringSourceEffectivePredicate, joinExpression);
EqualityInference allInferenceWithoutFilteringSourceInferred = createEqualityInference(deterministicInheritedPredicate, sourceEffectivePredicate, joinExpression);
// Push inherited Predicates down to the source if possible
for (RowExpression conjunct : nonInferableConjuncts(inheritedPredicate)) {
RowExpression rewrittenConjunct = allInference.rewriteExpressionAllowNonDeterministic(conjunct, in(sourceVariables));
// Since each source row is reflected exactly once in the output, ok to push non-deterministic predicates down
if (rewrittenConjunct != null) {
sourceConjuncts.add(rewrittenConjunct);
}
else {
postJoinConjuncts.add(conjunct);
}
}
// Push inherited Predicates down to the filtering source if possible
for (RowExpression conjunct : nonInferableConjuncts(deterministicInheritedPredicate)) {
RowExpression rewrittenConjunct = allInference.rewriteExpression(conjunct, in(filteringSourceVariables));
// We cannot push non-deterministic predicates to filtering side. Each filtering side row have to be
// logically reevaluated for each source row.
if (rewrittenConjunct != null) {
filteringSourceConjuncts.add(rewrittenConjunct);
}
}
// move effective predicate conjuncts source <-> filter
// See if we can push the filtering source effective predicate to the source side
for (RowExpression conjunct : nonInferableConjuncts(filteringSourceEffectivePredicate)) {
RowExpression rewritten = allInference.rewriteExpression(conjunct, in(sourceVariables));
if (rewritten != null) {
sourceConjuncts.add(rewritten);
}
}
// See if we can push the source effective predicate to the filtering source side
for (RowExpression conjunct : nonInferableConjuncts(sourceEffectivePredicate)) {
RowExpression rewritten = allInference.rewriteExpression(conjunct, in(filteringSourceVariables));
if (rewritten != null) {
filteringSourceConjuncts.add(rewritten);
}
}
// Add equalities from the inference back in
sourceConjuncts.addAll(allInferenceWithoutSourceInferred.generateEqualitiesPartitionedBy(in(sourceVariables)).getScopeEqualities());
filteringSourceConjuncts.addAll(allInferenceWithoutFilteringSourceInferred.generateEqualitiesPartitionedBy(in(filteringSourceVariables)).getScopeEqualities());
PlanNode rewrittenSource = context.rewrite(node.getSource(), logicalRowExpressions.combineConjuncts(sourceConjuncts));
PlanNode rewrittenFilteringSource = context.rewrite(node.getFilteringSource(), logicalRowExpressions.combineConjuncts(filteringSourceConjuncts));
Map<String, VariableReferenceExpression> dynamicFilters = ImmutableMap.of();
if (isEnableDynamicFiltering()) {
DynamicFiltersResult dynamicFiltersResult = createDynamicFilters(node.getSourceJoinVariable(), node.getFilteringSourceJoinVariable(), idAllocator, metadata.getFunctionAndTypeManager());
dynamicFilters = dynamicFiltersResult.getDynamicFilters();
// add filter node on top of probe
rewrittenSource = new FilterNode(node.getSourceLocation(), idAllocator.getNextId(), rewrittenSource, logicalRowExpressions.combineConjuncts(dynamicFiltersResult.getPredicates()));
}
PlanNode output = node;
if (rewrittenSource != node.getSource() || rewrittenFilteringSource != node.getFilteringSource() || !dynamicFilters.isEmpty()) {
planChanged = true;
output = new SemiJoinNode(
node.getSourceLocation(),
node.getId(),
rewrittenSource,
rewrittenFilteringSource,
node.getSourceJoinVariable(),
node.getFilteringSourceJoinVariable(),
node.getSemiJoinOutput(),
node.getSourceHashVariable(),
node.getFilteringSourceHashVariable(),
node.getDistributionType(),
dynamicFilters);
}
if (!postJoinConjuncts.isEmpty()) {
planChanged = true;
output = new FilterNode(node.getSourceLocation(), idAllocator.getNextId(), output, logicalRowExpressions.combineConjuncts(postJoinConjuncts));
}
return output;
}
private Iterable<RowExpression> nonInferableConjuncts(RowExpression inheritedPredicate)
{
return new EqualityInference.Builder(functionAndTypeManager)
.nonInferableConjuncts(inheritedPredicate);
}
private EqualityInference createEqualityInference(RowExpression... expressions)
{
return new EqualityInference.Builder(functionAndTypeManager)
.addEqualityInference(expressions)
.build();
}
@Override
public PlanNode visitAggregation(AggregationNode node, RewriteContext<RowExpression> context)
{
if (node.hasEmptyGroupingSet()) {
// TODO: in case of grouping sets, we should be able to push the filters over grouping keys below the aggregation
// and also preserve the filter above the aggregation if it has an empty grouping set
return visitPlan(node, context);
}
RowExpression inheritedPredicate = context.get();
EqualityInference equalityInference = createEqualityInference(inheritedPredicate);
List<RowExpression> pushdownConjuncts = new ArrayList<>();
List<RowExpression> postAggregationConjuncts = new ArrayList<>();
List<VariableReferenceExpression> groupingKeyVariables = node.getGroupingKeys();
// Strip out non-deterministic conjuncts
postAggregationConjuncts.addAll(ImmutableList.copyOf(filter(extractConjuncts(inheritedPredicate), not(determinismEvaluator::isDeterministic))));
inheritedPredicate = logicalRowExpressions.filterDeterministicConjuncts(inheritedPredicate);
// Sort non-equality predicates by those that can be pushed down and those that cannot
for (RowExpression conjunct : nonInferableConjuncts(inheritedPredicate)) {
if (node.getGroupIdVariable().isPresent() && extractUnique(conjunct).contains(node.getGroupIdVariable().get())) {
// aggregation operator synthesizes outputs for group ids corresponding to the global grouping set (i.e., ()), so we
// need to preserve any predicates that evaluate the group id to run after the aggregation
// TODO: we should be able to infer if conditions on grouping() correspond to global grouping sets to determine whether
// we need to do this for each specific case
postAggregationConjuncts.add(conjunct);
continue;
}
RowExpression rewrittenConjunct = equalityInference.rewriteExpression(conjunct, in(groupingKeyVariables));
if (rewrittenConjunct != null) {
pushdownConjuncts.add(rewrittenConjunct);
}
else {
postAggregationConjuncts.add(conjunct);
}
}
// Add the equality predicates back in
EqualityInference.EqualityPartition equalityPartition = equalityInference.generateEqualitiesPartitionedBy(in(groupingKeyVariables)::apply);
pushdownConjuncts.addAll(equalityPartition.getScopeEqualities());
postAggregationConjuncts.addAll(equalityPartition.getScopeComplementEqualities());
postAggregationConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities());
PlanNode rewrittenSource = context.rewrite(node.getSource(), logicalRowExpressions.combineConjuncts(pushdownConjuncts));
PlanNode output = node;
if (rewrittenSource != node.getSource()) {
planChanged = true;
output = new AggregationNode(
node.getSourceLocation(),
node.getId(),
rewrittenSource,
node.getAggregations(),
node.getGroupingSets(),
ImmutableList.of(),
node.getStep(),
node.getHashVariable(),
node.getGroupIdVariable(),
node.getAggregationId());
}
if (!postAggregationConjuncts.isEmpty()) {
planChanged = true;
output = new FilterNode(node.getSourceLocation(), idAllocator.getNextId(), output, logicalRowExpressions.combineConjuncts(postAggregationConjuncts));
}
return output;
}
@Override
public PlanNode visitUnnest(UnnestNode node, RewriteContext<RowExpression> context)
{
RowExpression inheritedPredicate = context.get();
EqualityInference equalityInference = createEqualityInference(inheritedPredicate);
List<RowExpression> pushdownConjuncts = new ArrayList<>();
List<RowExpression> postUnnestConjuncts = new ArrayList<>();
// Strip out non-deterministic conjuncts
postUnnestConjuncts.addAll(ImmutableList.copyOf(filter(extractConjuncts(inheritedPredicate), not(determinismEvaluator::isDeterministic))));
inheritedPredicate = logicalRowExpressions.filterDeterministicConjuncts(inheritedPredicate);
// Sort non-equality predicates by those that can be pushed down and those that cannot
for (RowExpression conjunct : nonInferableConjuncts(inheritedPredicate)) {
RowExpression rewrittenConjunct = equalityInference.rewriteExpression(conjunct, in(node.getReplicateVariables()));
if (rewrittenConjunct != null) {
pushdownConjuncts.add(rewrittenConjunct);
}
else {
postUnnestConjuncts.add(conjunct);
}
}
// Add the equality predicates back in
EqualityInference.EqualityPartition equalityPartition = equalityInference.generateEqualitiesPartitionedBy(in(node.getReplicateVariables())::apply);
pushdownConjuncts.addAll(equalityPartition.getScopeEqualities());
postUnnestConjuncts.addAll(equalityPartition.getScopeComplementEqualities());
postUnnestConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities());
PlanNode rewrittenSource = context.rewrite(node.getSource(), logicalRowExpressions.combineConjuncts(pushdownConjuncts));
PlanNode output = node;
if (rewrittenSource != node.getSource()) {
planChanged = true;
output = new UnnestNode(node.getSourceLocation(), node.getId(), rewrittenSource, node.getReplicateVariables(), node.getUnnestVariables(), node.getOrdinalityVariable());
}
if (!postUnnestConjuncts.isEmpty()) {
planChanged = true;
output = new FilterNode(node.getSourceLocation(), idAllocator.getNextId(), output, logicalRowExpressions.combineConjuncts(postUnnestConjuncts));
}
return output;
}
@Override
public PlanNode visitSample(SampleNode node, RewriteContext<RowExpression> context)
{
return context.defaultRewrite(node, context.get());
}
@Override
public PlanNode visitTableScan(TableScanNode node, RewriteContext<RowExpression> context)
{
RowExpression predicate = simplifyExpression(context.get());
if (!TRUE_CONSTANT.equals(predicate)) {
planChanged = true;
return new FilterNode(node.getSourceLocation(), idAllocator.getNextId(), node, predicate);
}
return node;
}
@Override
public PlanNode visitAssignUniqueId(AssignUniqueId node, RewriteContext<RowExpression> context)
{
Set<VariableReferenceExpression> predicateVariables = extractUnique(context.get());
checkState(!predicateVariables.contains(node.getIdVariable()), "UniqueId in predicate is not yet supported");
return context.defaultRewrite(node, context.get());
}
private static CallExpression buildEqualsExpression(FunctionAndTypeManager functionAndTypeManager, RowExpression left, RowExpression right)
{
return call(
EQUAL.getFunctionName().getObjectName(),
functionAndTypeManager.resolveOperator(EQUAL, fromTypes(left.getType(), right.getType())),
BOOLEAN,
left,
right);
}
}
}