CrossJoinWithArrayContainsToInnerJoin.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;
import com.facebook.presto.Session;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.UnnestNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.PlannerUtils;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.List;
import java.util.Optional;
import static com.facebook.presto.SystemSessionProperties.isRewriteCrossJoinArrayContainsToInnerJoinEnabled;
import static com.facebook.presto.expressions.LogicalRowExpressions.and;
import static com.facebook.presto.expressions.LogicalRowExpressions.extractConjuncts;
import static com.facebook.presto.matching.Capture.newCapture;
import static com.facebook.presto.sql.planner.PlannerUtils.isSupportedArrayContainsFilter;
import static com.facebook.presto.sql.planner.plan.Patterns.filter;
import static com.facebook.presto.sql.planner.plan.Patterns.join;
import static com.facebook.presto.sql.planner.plan.Patterns.source;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;
/**
* Inner join with contains function inside join clause will be run as cross join with filter.
* When the join condition has pattern of contains(array, element), we can rewrite it to a inner join. For example:
* <pre>
* - Filter contains(l_array_key, r_key)
* - Cross join
* - scan l
* - scan r
* </pre>
* into:
* <pre>
* - Join
* field = r_key
* - Unnest
* field <- unnest distinct_array
* - project
* distinct_array := array_distinct(l_array_key)
* - scan l
* l_array_key
* - scan r
* r_key
* </pre>
*/
public class CrossJoinWithArrayContainsToInnerJoin
implements Rule<FilterNode>
{
private static final Capture<JoinNode> CHILD = newCapture();
private static final Pattern<FilterNode> PATTERN = filter()
.with(source().matching(join().matching(x -> x.getType().equals(JoinType.INNER) && x.getCriteria().isEmpty()).capturedAs(CHILD)));
private final FunctionAndTypeManager functionAndTypeManager;
public CrossJoinWithArrayContainsToInnerJoin(FunctionAndTypeManager functionAndTypeManager)
{
this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
}
public static RowExpression getCandidateArrayContainsExpression(FunctionResolution functionResolution, RowExpression filterPredicate, List<VariableReferenceExpression> leftInput, List<VariableReferenceExpression> rightInput)
{
List<RowExpression> andConjuncts = extractConjuncts(filterPredicate);
for (RowExpression conjunct : andConjuncts) {
if (isSupportedArrayContainsFilter(functionResolution, conjunct, leftInput, rightInput)) {
return conjunct;
}
}
return null;
}
@Override
public Pattern<FilterNode> getPattern()
{
return PATTERN;
}
@Override
public boolean isEnabled(Session session)
{
return isRewriteCrossJoinArrayContainsToInnerJoinEnabled(session);
}
@Override
public Result apply(FilterNode node, Captures captures, Context context)
{
JoinNode joinNode = captures.get(CHILD);
if (!(joinNode.getType().equals(JoinType.INNER) && joinNode.getCriteria().isEmpty())) {
return Result.empty();
}
List<VariableReferenceExpression> leftInput = joinNode.getLeft().getOutputVariables();
List<VariableReferenceExpression> rightInput = joinNode.getRight().getOutputVariables();
RowExpression filterExpression = node.getPredicate();
FunctionResolution functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
RowExpression arrayContainsExpression = getCandidateArrayContainsExpression(functionResolution, filterExpression, leftInput, rightInput);
if (arrayContainsExpression == null) {
return Result.empty();
}
List<RowExpression> andConjuncts = extractConjuncts(filterExpression);
List<RowExpression> remainingConjuncts = andConjuncts.stream().filter(x -> !x.equals(arrayContainsExpression)).collect(toImmutableList());
RowExpression array = ((CallExpression) arrayContainsExpression).getArguments().get(0);
RowExpression element = ((CallExpression) arrayContainsExpression).getArguments().get(1);
checkState(element instanceof VariableReferenceExpression, "Argument to CONTAINS is not a column");
checkState(array instanceof VariableReferenceExpression, "Argument to CONTAINS is not a column");
VariableReferenceExpression elementVar = (VariableReferenceExpression) element;
boolean arrayAtLeftInput = leftInput.contains(array);
PlanNode inputWithArray = arrayAtLeftInput ? joinNode.getLeft() : joinNode.getRight();
CallExpression arrayDistinct = call(functionAndTypeManager, "array_distinct", new ArrayType(element.getType()), array);
VariableReferenceExpression arrayDistinctVariable = context.getVariableAllocator().newVariable(arrayDistinct);
PlanNode project = PlannerUtils.addProjections(inputWithArray, context.getIdAllocator(), ImmutableMap.of(arrayDistinctVariable, arrayDistinct));
VariableReferenceExpression unnestVariable = context.getVariableAllocator().newVariable("field", element.getType());
UnnestNode unnest = new UnnestNode(inputWithArray.getSourceLocation(),
context.getIdAllocator().getNextId(),
project,
project.getOutputVariables(),
ImmutableMap.of(arrayDistinctVariable, ImmutableList.of(unnestVariable)),
Optional.empty());
EquiJoinClause equiJoinClause = arrayAtLeftInput ? new EquiJoinClause(unnestVariable, elementVar) : new EquiJoinClause(elementVar, unnestVariable);
JoinNode newJoinNode = new JoinNode(joinNode.getSourceLocation(),
context.getIdAllocator().getNextId(),
joinNode.getType(),
arrayAtLeftInput ? unnest : joinNode.getLeft(),
arrayAtLeftInput ? joinNode.getRight() : unnest,
ImmutableList.of(equiJoinClause),
joinNode.getOutputVariables(),
joinNode.getFilter(),
Optional.empty(),
Optional.empty(),
joinNode.getDistributionType(),
joinNode.getDynamicFilters());
if (!remainingConjuncts.isEmpty()) {
return Result.ofPlanNode(new FilterNode(node.getSourceLocation(), context.getIdAllocator().getNextId(), newJoinNode, and(remainingConjuncts)));
}
return Result.ofPlanNode(newJoinNode);
}
}