LeftJoinWithArrayContainsToEquiJoinCondition.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;
import com.facebook.presto.Session;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.UnnestNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.FeaturesConfig.LeftJoinArrayContainsToInnerJoinStrategy;
import com.facebook.presto.sql.planner.PlannerUtils;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.List;
import java.util.Optional;
import static com.facebook.presto.SystemSessionProperties.getLeftJoinArrayContainsToInnerJoinStrategy;
import static com.facebook.presto.expressions.LogicalRowExpressions.and;
import static com.facebook.presto.expressions.LogicalRowExpressions.extractConjuncts;
import static com.facebook.presto.sql.planner.VariablesExtractor.extractAll;
import static com.facebook.presto.sql.planner.plan.Patterns.join;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;
/**
* When the join condition of a left join has pattern of contains(array, element) where array, we can rewrite it as a equi join condition. For example:
* <pre>
* - Left Join
* empty join clause
* filter: contains(r_array, l_key)
* - scan l
* - scan r
* </pre>
* into:
* <pre>
* - Left Join
* l_key = field
* - scan l
* - Unnest
* field <- unnest distinct_array
* - project
* distinct_array := remove_nulls(array_distinct(r_array))
* - scan r
* r_array
* </pre>
*/
public class LeftJoinWithArrayContainsToEquiJoinCondition
implements Rule<JoinNode>
{
private static final Pattern<JoinNode> PATTERN = join().matching(x -> x.getType().equals(JoinType.LEFT) && x.getCriteria().isEmpty() && x.getFilter().isPresent());
private final FunctionAndTypeManager functionAndTypeManager;
private final RowExpressionDeterminismEvaluator determinismEvaluator;
private final FunctionResolution functionResolution;
public LeftJoinWithArrayContainsToEquiJoinCondition(FunctionAndTypeManager functionAndTypeManager)
{
this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
this.determinismEvaluator = new RowExpressionDeterminismEvaluator(functionAndTypeManager);
this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
}
@Override
public Pattern<JoinNode> getPattern()
{
return PATTERN;
}
@Override
public boolean isEnabled(Session session)
{
// TODO: implement cost based with HBO
return getLeftJoinArrayContainsToInnerJoinStrategy(session).equals(LeftJoinArrayContainsToInnerJoinStrategy.ALWAYS_ENABLED);
}
@Override
public Result apply(JoinNode node, Captures captures, Context context)
{
RowExpression filterPredicate = node.getFilter().get();
List<VariableReferenceExpression> leftInput = node.getLeft().getOutputVariables();
List<VariableReferenceExpression> rightInput = node.getRight().getOutputVariables();
List<RowExpression> andConjuncts = extractConjuncts(filterPredicate);
Optional<RowExpression> arrayContains = andConjuncts.stream().filter(rowExpression -> isSupportedJoinCondition(rowExpression, leftInput, rightInput)).findFirst();
if (!arrayContains.isPresent()) {
return Result.empty();
}
List<RowExpression> remainingConjuncts = andConjuncts.stream().filter(rowExpression -> !rowExpression.equals(arrayContains.get())).collect(toImmutableList());
RowExpression array = ((CallExpression) arrayContains.get()).getArguments().get(0);
RowExpression element = ((CallExpression) arrayContains.get()).getArguments().get(1);
checkState(array.getType() instanceof ArrayType && ((ArrayType) array.getType()).getElementType().equals(element.getType()));
PlanNode newLeft = node.getLeft();
ImmutableMap.Builder<VariableReferenceExpression, RowExpression> leftAssignment = ImmutableMap.builder();
VariableReferenceExpression elementVariable;
if (!(element instanceof VariableReferenceExpression)) {
elementVariable = context.getVariableAllocator().newVariable(element);
leftAssignment.put(elementVariable, element);
newLeft = PlannerUtils.addProjections(node.getLeft(), context.getIdAllocator(), leftAssignment.build());
}
else {
elementVariable = (VariableReferenceExpression) element;
}
CallExpression arrayDistinct = call(functionAndTypeManager, "array_distinct", new ArrayType(element.getType()), array);
CallExpression arrayFilterNull = call(functionAndTypeManager, "remove_nulls", arrayDistinct.getType(), ImmutableList.of(arrayDistinct));
VariableReferenceExpression arrayFilterNullVariable = context.getVariableAllocator().newVariable(arrayFilterNull);
PlanNode newRight = PlannerUtils.addProjections(node.getRight(), context.getIdAllocator(), ImmutableMap.of(arrayFilterNullVariable, arrayFilterNull));
VariableReferenceExpression unnestVariable = context.getVariableAllocator().newVariable("unnest", element.getType());
UnnestNode unnestNode = new UnnestNode(newRight.getSourceLocation(),
context.getIdAllocator().getNextId(),
newRight,
newRight.getOutputVariables(),
ImmutableMap.of(arrayFilterNullVariable, ImmutableList.of(unnestVariable)),
Optional.empty());
EquiJoinClause equiJoinClause = new EquiJoinClause(elementVariable, unnestVariable);
return Result.ofPlanNode(new JoinNode(node.getSourceLocation(),
context.getIdAllocator().getNextId(),
node.getType(),
newLeft,
unnestNode,
ImmutableList.of(equiJoinClause),
node.getOutputVariables(),
remainingConjuncts.isEmpty() ? Optional.empty() : Optional.of(and(remainingConjuncts)),
Optional.empty(),
Optional.empty(),
node.getDistributionType(),
node.getDynamicFilters()));
}
private boolean isSupportedJoinCondition(RowExpression rowExpression, List<VariableReferenceExpression> leftInput, List<VariableReferenceExpression> rightInput)
{
if (rowExpression instanceof CallExpression && functionResolution.isArrayContainsFunction(((CallExpression) rowExpression).getFunctionHandle())) {
RowExpression arrayExpression = ((CallExpression) rowExpression).getArguments().get(0);
RowExpression elementExpression = ((CallExpression) rowExpression).getArguments().get(1);
return determinismEvaluator.isDeterministic(arrayExpression) && rightInput.containsAll(extractAll(arrayExpression))
&& determinismEvaluator.isDeterministic(elementExpression) && leftInput.containsAll(extractAll(elementExpression));
}
return false;
}
}