PullUpExpressionInLambdaRules.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;
import com.facebook.presto.Session;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.InputReferenceExpression;
import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.RowExpressionVisitor;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.PlannerUtils;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.airlift.joni.Regex;
import io.airlift.slice.Slice;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.IntStream;
import static com.facebook.presto.SystemSessionProperties.isPullExpressionFromLambdaEnabled;
import static com.facebook.presto.sql.planner.VariablesExtractor.extractAll;
import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignments;
import static com.facebook.presto.sql.planner.plan.Patterns.filter;
import static com.facebook.presto.sql.planner.plan.Patterns.project;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;
/**
* If the body of a lambda function contains expressions that do not refer to the arguments of the lambda function, they can be
* evaluated once outside of the lambda function instead of being evaluated repeatedly inside the lambda body. An example of the optimization is:
* Before:
* <pre>
* - Project
* expr := filter(array, x -> x > id1+id2)
* - TableScan
* array: array(bigint)
* id1: bigint
* id2: bigint
* </pre>
* After:
* <pre>
* - Project
* expr := filter(array, x -> x > sum)
* - Project
* sum := id1+id2
* - TableScan
* array: array(bigint)
* id1: bigint
* id2: bigint
* </pre>
*/
public class PullUpExpressionInLambdaRules
{
// Used by the extractor to keep only deterministic expressions as pull-up candidates.
private final RowExpressionDeterminismEvaluator determinismEvaluator;
// Used by the extractor to recognize try, subscript, cast and like_pattern functions.
private final FunctionResolution functionResolution;

/**
 * Creates the rule set, deriving the determinism evaluator and the function resolution
 * helper from the given manager.
 *
 * @param functionAndTypeManager manager used to build both helpers; must not be null
 * @throws NullPointerException if {@code functionAndTypeManager} is null
 */
public PullUpExpressionInLambdaRules(FunctionAndTypeManager functionAndTypeManager)
{
    FunctionAndTypeManager manager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
    this.determinismEvaluator = new RowExpressionDeterminismEvaluator(manager);
    this.functionResolution = new FunctionResolution(manager.getFunctionAndTypeResolver());
}
/**
 * Walks {@code rowExpression} with a {@link ValidExpressionExtractor} (starting outside of any lambda)
 * and returns the collected candidates that reference at least one variable.
 *
 * @param determinismEvaluator evaluator handed to the extractor
 * @param functionResolution function resolution helper handed to the extractor
 * @param inputVariables variables that the extractor accepts as valid references
 * @param rowExpression expression to search for candidates
 * @return the candidates, in the order the extractor collected them
 */
private static Set<RowExpression> getCandidateRowExpression(RowExpressionDeterminismEvaluator determinismEvaluator, FunctionResolution functionResolution, List<VariableReferenceExpression> inputVariables,
        RowExpression rowExpression)
{
    ImmutableSet.Builder<RowExpression> collected = ImmutableSet.builder();
    rowExpression.accept(new ValidExpressionExtractor(determinismEvaluator, functionResolution, inputVariables, collected), false);

    ImmutableSet.Builder<RowExpression> withVariables = ImmutableSet.builder();
    for (RowExpression candidate : collected.build()) {
        // A candidate without any variable reference is a constant expression; do not pull it out
        if (!extractAll(candidate).isEmpty()) {
            withVariables.add(candidate);
        }
    }
    return withVariables.build();
}
/**
 * Tells whether the rules of this set are enabled for the session, as reported by
 * {@code isPullExpressionFromLambdaEnabled}.
 *
 * @param session session whose properties are consulted
 * @return the value of the session property
 */
public boolean isRuleEnabled(Session session)
{
    boolean enabled = isPullExpressionFromLambdaEnabled(session);
    return enabled;
}
/**
 * Returns every rule of this set: the filter node rule followed by the project node rule.
 *
 * @return an immutable set holding both rules
 */
public Set<Rule<?>> rules()
{
    ImmutableSet.Builder<Rule<?>> allRules = ImmutableSet.builder();
    allRules.add(filterNodeRule());
    allRules.add(projectNodeRule());
    return allRules.build();
}
/**
 * Creates a new instance of the rule that handles {@link FilterNode} predicates.
 *
 * @return a fresh filter node rule bound to this rule set
 */
public Rule<FilterNode> filterNodeRule()
{
    Rule<FilterNode> filterRule = new PullUpExpressionInLambdaFilterNodeRule();
    return filterRule;
}
/**
 * Creates a new instance of the rule that handles {@link ProjectNode} assignments.
 *
 * @return a fresh project node rule bound to this rule set
 */
public Rule<ProjectNode> projectNodeRule()
{
    Rule<ProjectNode> projectRule = new PullUpExpressionInLambdaProjectNodeRule();
    return projectRule;
}
/**
 * Rule applied to every {@link ProjectNode}. For each assignment it looks for candidate expressions
 * inside lambda bodies, assigns each candidate to a newly allocated variable, and rewrites the
 * assignment to reference that variable. The pulled expressions are handed to
 * {@code PlannerUtils.addProjections} together with the original source.
 */
private final class PullUpExpressionInLambdaProjectNodeRule
        implements Rule<ProjectNode>
{
    @Override
    public boolean isEnabled(Session session)
    {
        return isRuleEnabled(session);
    }

    @Override
    public Pattern<ProjectNode> getPattern()
    {
        return project();
    }

    /**
     * Rewrites the project node when at least one assignment has a candidate expression.
     *
     * @return {@code Result.empty()} when nothing can be pulled up, otherwise a new project node
     * whose source is the plan returned by {@code PlannerUtils.addProjections}
     */
    @Override
    public Result apply(ProjectNode node, Captures captures, Context context)
    {
        PlanNode source = node.getSource();
        List<VariableReferenceExpression> sourceOutputs = source.getOutputVariables();
        // new variable -> expression pulled out of a lambda body, accumulated over all assignments
        ImmutableMap.Builder<VariableReferenceExpression, RowExpression> pulledUp = ImmutableMap.builder();
        Assignments.Builder rewrittenAssignments = Assignments.builder();

        for (Map.Entry<VariableReferenceExpression, RowExpression> assignment : node.getAssignments().getMap().entrySet()) {
            RowExpression original = assignment.getValue();
            Set<RowExpression> candidates = getCandidateRowExpression(determinismEvaluator, functionResolution, sourceOutputs, original);
            if (candidates.isEmpty()) {
                // Nothing to pull up; keep the assignment untouched
                rewrittenAssignments.put(assignment.getKey(), original);
                continue;
            }

            // expression -> new variable, for this assignment only
            ImmutableMap.Builder<RowExpression, VariableReferenceExpression> replacements = ImmutableMap.builder();
            for (RowExpression candidate : candidates) {
                VariableReferenceExpression variable = context.getVariableAllocator().newVariable(candidate);
                replacements.put(candidate, variable);
                pulledUp.put(variable, candidate);
            }
            RowExpression rewritten = original.accept(new ExpressionRewriter(replacements.build()), null);
            rewrittenAssignments.put(assignment.getKey(), rewritten);
        }

        Map<VariableReferenceExpression, RowExpression> pulledUpExpressions = pulledUp.build();
        if (pulledUpExpressions.isEmpty()) {
            return Result.empty();
        }

        PlanNode newSource = PlannerUtils.addProjections(source, context.getIdAllocator(), pulledUpExpressions);
        ProjectNode rewrittenProject = new ProjectNode(context.getIdAllocator().getNextId(), newSource, rewrittenAssignments.build());
        return Result.ofPlanNode(rewrittenProject);
    }
}
/**
 * Rule applied to every {@link FilterNode}. Candidate expressions found inside lambda bodies of the
 * predicate are assigned to newly allocated variables and handed to
 * {@code PlannerUtils.addProjections} together with the filter's source. The predicate is rewritten
 * to reference those variables, and an identity projection over the original filter outputs is
 * placed on top of the new filter.
 */
private final class PullUpExpressionInLambdaFilterNodeRule
        implements Rule<FilterNode>
{
    @Override
    public boolean isEnabled(Session session)
    {
        return isRuleEnabled(session);
    }

    @Override
    public Pattern<FilterNode> getPattern()
    {
        return filter();
    }

    /**
     * Rewrites the filter node when its predicate has at least one candidate expression.
     *
     * @return {@code Result.empty()} when nothing can be pulled up, otherwise a project node
     * wrapping the rewritten filter
     */
    @Override
    public Result apply(FilterNode filterNode, Captures captures, Context context)
    {
        PlanNode source = filterNode.getSource();
        RowExpression originalPredicate = filterNode.getPredicate();
        Set<RowExpression> candidates = getCandidateRowExpression(determinismEvaluator, functionResolution, source.getOutputVariables(), originalPredicate);
        if (candidates.isEmpty()) {
            return Result.empty();
        }

        // Build both directions at once: expression -> variable for rewriting, variable -> expression for projecting
        ImmutableMap.Builder<RowExpression, VariableReferenceExpression> replacements = ImmutableMap.builder();
        ImmutableMap.Builder<VariableReferenceExpression, RowExpression> pulledUp = ImmutableMap.builder();
        for (RowExpression candidate : candidates) {
            VariableReferenceExpression variable = context.getVariableAllocator().newVariable(candidate);
            replacements.put(candidate, variable);
            pulledUp.put(variable, candidate);
        }

        RowExpression rewrittenPredicate = originalPredicate.accept(new ExpressionRewriter(replacements.build()), null);
        PlanNode newSource = PlannerUtils.addProjections(source, context.getIdAllocator(), pulledUp.build());

        // The node ids are requested inline so that the project id is allocated before the filter id
        PlanNode restoredOutputs = new ProjectNode(
                context.getIdAllocator().getNextId(),
                new FilterNode(filterNode.getSourceLocation(), context.getIdAllocator().getNextId(), newSource, rewrittenPredicate),
                identityAssignments(filterNode.getOutputVariables()));
        return Result.ofPlanNode(restoredOutputs);
    }
}
/**
 * Visitor that collects expressions which can be pulled out of lambda bodies.
 * <p>
 * The Boolean context tells whether the visited expression is inside a lambda body. The Boolean result
 * tells whether the visited expression is "valid", meaning it sits inside a lambda body, is built only
 * from constants and input variables, and is deterministic. When a parent inside a lambda body is not
 * valid as a whole, its valid arguments are recorded as candidates.
 */
private static class ValidExpressionExtractor
        implements RowExpressionVisitor<Boolean, Boolean>
{
    // Bind expression will complicate the lambda expression, we apply this optimization before DesugarLambdaRule. And if there are bind expression, skip
    private static final List<SpecialFormExpression.Form> UNSUPPORTED_TYPES = ImmutableList.of(SpecialFormExpression.Form.BIND);
    // Only expressions whose type maps to one of these Java types are recorded as candidates
    private static final List<Class<?>> SUPPORTED_JAVA_TYPES = ImmutableList.of(boolean.class, long.class, double.class, Slice.class, Block.class);

    private final RowExpressionDeterminismEvaluator determinismEvaluator;
    private final FunctionResolution functionResolution;
    private final List<VariableReferenceExpression> inputVariables;
    private final ImmutableSet.Builder<RowExpression> candidates;

    public ValidExpressionExtractor(RowExpressionDeterminismEvaluator determinismEvaluator,
            FunctionResolution functionResolution,
            List<VariableReferenceExpression> inputVariables,
            ImmutableSet.Builder<RowExpression> candidates)
    {
        this.determinismEvaluator = requireNonNull(determinismEvaluator, "determinismEvaluator is null");
        this.functionResolution = requireNonNull(functionResolution, "functionResolution is null");
        this.inputVariables = requireNonNull(inputVariables, "inputVariables is null");
        this.candidates = requireNonNull(candidates, "candidates is null");
    }

    @Override
    public Boolean visitCall(CallExpression call, Boolean context)
    {
        // Skip try function as pulling out function within try function can throw exception.
        // Skip subscript function as it can throw exception when pull out
        if (functionResolution.isTryFunction(call.getFunctionHandle()) || functionResolution.isSubscriptFunction(call.getFunctionHandle())) {
            return false;
        }

        // Visit each distinct argument once, in order; arguments are visited even outside a lambda so nested lambdas are reached
        ImmutableMap.Builder<RowExpression, Boolean> validityBuilder = ImmutableMap.builder();
        for (RowExpression argument : ImmutableSet.copyOf(call.getArguments())) {
            validityBuilder.put(argument, argument.accept(this, context));
        }
        Map<RowExpression, Boolean> validity = validityBuilder.build();

        if (!context) {
            return false;
        }

        boolean everyArgumentValid = !validity.containsValue(false);
        if (!everyArgumentValid) {
            // The call as a whole is not valid, so record its valid arguments instead
            for (Map.Entry<RowExpression, Boolean> entry : validity.entrySet()) {
                if (entry.getValue()) {
                    RowExpression unwrapped = getArgumentForRegexTypeExpression(entry.getKey());
                    if (isSupportedExpression(unwrapped)) {
                        candidates.add(unwrapped);
                    }
                }
            }
        }
        return everyArgumentValid && determinismEvaluator.isDeterministic(call);
    }

    // For the conditional expressions, not all arguments will be evaluated, we only try to extract from the arguments which will always be executed
    private static List<RowExpression> getValidArguments(SpecialFormExpression specialForm)
    {
        List<RowExpression> arguments = specialForm.getArguments();
        switch (specialForm.getForm()) {
            case IF:
            case COALESCE:
            case WHEN:
                return ImmutableList.of(arguments.get(0));
            case SWITCH:
                return ImmutableList.of(arguments.get(0), arguments.get(1));
            default:
                return arguments;
        }
    }

    // When expression cannot be pulled out, hence if we get a when expression, try to pull out its argument instead
    private static RowExpression getArgumentOfWhen(RowExpression expression)
    {
        RowExpression current = expression;
        while (current instanceof SpecialFormExpression && ((SpecialFormExpression) current).getForm().equals(SpecialFormExpression.Form.WHEN)) {
            current = ((SpecialFormExpression) current).getArguments().get(0);
        }
        return current;
    }

    // If the input is a CAST expression to cast to JoniRegexType or LikePatternType (underlying Java type is Regex.class) or is a like_pattern function, return the argument
    // Still return even if it's not a cast/like_pattern expression, as these types will be filtered by the isSupportedExpression later
    private RowExpression getArgumentForRegexTypeExpression(RowExpression rowExpression)
    {
        RowExpression current = rowExpression;
        while (isRegexTypedCastOrLikePattern(current)) {
            current = ((CallExpression) current).getArguments().get(0);
        }
        return current;
    }

    // True when the expression has Regex as its Java type and is a call to a cast or like_pattern function
    private boolean isRegexTypedCastOrLikePattern(RowExpression expression)
    {
        if (expression.getType().getJavaType() != Regex.class || !(expression instanceof CallExpression)) {
            return false;
        }
        CallExpression call = (CallExpression) expression;
        return functionResolution.isCastFunction(call.getFunctionHandle())
                || functionResolution.isLikePatternFunction(call.getFunctionHandle());
    }

    @Override
    public Boolean visitSpecialForm(SpecialFormExpression specialForm, Boolean context)
    {
        if (UNSUPPORTED_TYPES.contains(specialForm.getForm())) {
            return false;
        }

        // Arguments that are not always evaluated are not visited and count as not valid
        List<RowExpression> alwaysEvaluated = getValidArguments(specialForm);
        ImmutableMap.Builder<RowExpression, Boolean> validityBuilder = ImmutableMap.builder();
        for (RowExpression argument : ImmutableSet.copyOf(specialForm.getArguments())) {
            boolean valid = false;
            if (alwaysEvaluated.contains(argument)) {
                valid = argument.accept(this, context);
            }
            validityBuilder.put(argument, valid);
        }
        Map<RowExpression, Boolean> validity = validityBuilder.build();

        if (!context) {
            return false;
        }

        boolean everyArgumentValid = !validity.containsValue(false);
        if (!everyArgumentValid) {
            // The special form as a whole is not valid, so record its valid arguments instead
            for (Map.Entry<RowExpression, Boolean> entry : validity.entrySet()) {
                if (entry.getValue()) {
                    RowExpression unwrapped = getArgumentOfWhen(entry.getKey());
                    if (isSupportedExpression(unwrapped)) {
                        candidates.add(unwrapped);
                    }
                }
            }
        }
        return everyArgumentValid && determinismEvaluator.isDeterministic(specialForm);
    }

    @Override
    public Boolean visitLambda(LambdaDefinitionExpression lambda, Boolean context)
    {
        RowExpression body = lambda.getBody();
        // The body is always visited as "inside a lambda"; if it is valid as a whole, the body itself is a candidate
        boolean bodyValid = body.accept(this, true);
        if (bodyValid && isSupportedExpression(body)) {
            candidates.add(body);
        }
        // For simplicity, we do not pull out lambda expressions
        return false;
    }

    @Override
    public Boolean visitVariableReference(VariableReferenceExpression reference, Boolean context)
    {
        return inputVariables.contains(reference);
    }

    @Override
    public Boolean visitConstant(ConstantExpression literal, Boolean context)
    {
        return true;
    }

    @Override
    public Boolean visitInputReference(InputReferenceExpression reference, Boolean context)
    {
        return false;
    }

    // WHEN expression should only exist within SWITCH expression, and will throw exception in RowExpressionInterpreter, also no byte code generator for standalone WHEN expression
    // Pull out LikePatternType and JoniRegexpType out can lead to byte code generation failure because of the underlying Regex type.
    private static boolean isSupportedExpression(RowExpression expression)
    {
        boolean pullableKind;
        if (expression instanceof CallExpression) {
            pullableKind = true;
        }
        else if (expression instanceof SpecialFormExpression) {
            pullableKind = !((SpecialFormExpression) expression).getForm().equals(SpecialFormExpression.Form.WHEN);
        }
        else {
            pullableKind = false;
        }
        return pullableKind && SUPPORTED_JAVA_TYPES.contains(expression.getType().getJavaType());
    }
}
/**
 * Visitor that rebuilds an expression bottom-up, replacing every call or special form that
 * (after its arguments have been rewritten) is a key of the given map with the mapped variable.
 * Nodes whose arguments did not change are returned as the original instance.
 */
private static class ExpressionRewriter
        implements RowExpressionVisitor<RowExpression, Void>
{
    // expression to replace -> variable that replaces it
    private final Map<RowExpression, VariableReferenceExpression> expressionMap;

    /**
     * @param expressionMap expressions to replace, keyed by expression; copied into an immutable map
     */
    public ExpressionRewriter(Map<RowExpression, VariableReferenceExpression> expressionMap)
    {
        this.expressionMap = ImmutableMap.copyOf(expressionMap);
    }

    @Override
    public RowExpression visitCall(CallExpression call, Void context)
    {
        List<RowExpression> rewrittenArguments = call.getArguments().stream().map(argument -> argument.accept(this, null)).collect(toImmutableList());
        RowExpression rewritten = new CallExpression(
                call.getSourceLocation(),
                call.getDisplayName(),
                call.getFunctionHandle(),
                call.getType(),
                rewrittenArguments);
        // The map is immutable and holds no null values, so a null lookup means "no replacement"
        VariableReferenceExpression replacement = expressionMap.get(rewritten);
        if (replacement != null) {
            return replacement;
        }
        if (rowExpressionsNotChanged(call.getArguments(), rewrittenArguments)) {
            return call;
        }
        return rewritten;
    }

    @Override
    public RowExpression visitInputReference(InputReferenceExpression reference, Void context)
    {
        return reference;
    }

    @Override
    public RowExpression visitConstant(ConstantExpression literal, Void context)
    {
        return literal;
    }

    @Override
    public RowExpression visitLambda(LambdaDefinitionExpression lambda, Void context)
    {
        RowExpression rewrittenBody = lambda.getBody().accept(this, context);
        // Keep the original lambda when its body is unchanged, as the other visit methods do
        if (rewrittenBody.equals(lambda.getBody())) {
            return lambda;
        }
        return new LambdaDefinitionExpression(lambda.getSourceLocation(), lambda.getArgumentTypes(), lambda.getArguments(), rewrittenBody);
    }

    @Override
    public RowExpression visitVariableReference(VariableReferenceExpression reference, Void context)
    {
        return reference;
    }

    @Override
    public RowExpression visitSpecialForm(SpecialFormExpression specialForm, Void context)
    {
        List<RowExpression> rewrittenArguments = specialForm.getArguments().stream().map(argument -> argument.accept(this, null)).collect(toImmutableList());
        // Carry the source location over to the rebuilt node, as visitCall and visitLambda do
        SpecialFormExpression rewritten = new SpecialFormExpression(
                specialForm.getSourceLocation(),
                specialForm.getForm(),
                specialForm.getType(),
                rewrittenArguments);
        VariableReferenceExpression replacement = expressionMap.get(rewritten);
        if (replacement != null) {
            return replacement;
        }
        if (rowExpressionsNotChanged(specialForm.getArguments(), rewrittenArguments)) {
            return specialForm;
        }
        return rewritten;
    }

    /**
     * Compares the two argument lists position by position.
     *
     * @return true when every rewritten argument equals the original at the same index
     * @throws IllegalArgumentException if the lists differ in size
     */
    private boolean rowExpressionsNotChanged(List<RowExpression> original, List<RowExpression> rewritten)
    {
        checkArgument(original.size() == rewritten.size());
        // Primitive int stream; no need to box every index
        return IntStream.range(0, original.size()).allMatch(idx -> original.get(idx).equals(rewritten.get(idx)));
    }
}
}