// PushDownFilterExpressionEvaluationThroughCrossJoin.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.PlannerUtils;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static com.facebook.presto.SystemSessionProperties.getPushdownFilterExpressionEvaluationThroughCrossJoinStrategy;
import static com.facebook.presto.expressions.LogicalRowExpressions.extractConjuncts;
import static com.facebook.presto.expressions.LogicalRowExpressions.extractDisjuncts;
import static com.facebook.presto.matching.Capture.newCapture;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.PushDownFilterThroughCrossJoinStrategy.DISABLED;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.PushDownFilterThroughCrossJoinStrategy.REWRITTEN_TO_INNER_JOIN;
import static com.facebook.presto.sql.gen.CommonSubExpressionRewriter.rewriteExpressionWithCSE;
import static com.facebook.presto.sql.planner.VariablesExtractor.extractAll;
import static com.facebook.presto.sql.planner.iterative.rule.CrossJoinWithArrayContainsToInnerJoin.getCandidateArrayContainsExpression;
import static com.facebook.presto.sql.planner.iterative.rule.CrossJoinWithArrayNotContainsToAntiJoin.getCandidateArrayNotContainsExpression;
import static com.facebook.presto.sql.planner.iterative.rule.CrossJoinWithOrFilterToInnerJoin.getCandidateOrExpression;
import static com.facebook.presto.sql.planner.plan.Patterns.filter;
import static com.facebook.presto.sql.planner.plan.Patterns.join;
import static com.facebook.presto.sql.planner.plan.Patterns.source;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;
import static java.util.stream.Stream.concat;

/**
 * A cross join produces more rows than it receives, so an expression that only depends on one side of the
 * join is cheaper to evaluate below the join than in the filter above it. This rule moves such expressions
 * into projections on the join inputs and rewrites the filter to reference the projected variables.
 * <pre>
 *     - Filter l_key1 = cardinality(r_key1)
 *          - Cross Join
 *              - scan l
 *              - scan r
 * </pre>
 * becomes
 * <pre>
 *     - Filter l_key1 = card
 *          - Cross Join
 *              - scan l
 *              - project
 *                  card := cardinality(r_key1)
 *                  - scan r
 * </pre>
 */
public class PushDownFilterExpressionEvaluationThroughCrossJoin
        implements Rule<FilterNode>
{
    private static final Capture<JoinNode> CHILD = newCapture();

    // Matches a filter sitting directly on top of an inner join that has no equi-join criteria.
    private static final Pattern<FilterNode> PATTERN = filter()
            .with(source().matching(join().matching(PushDownFilterExpressionEvaluationThroughCrossJoin::isCrossJoin).capturedAs(CHILD)));

    private final FunctionAndTypeManager functionAndTypeManager;
    private final RowExpressionDeterminismEvaluator determinismEvaluator;

    public PushDownFilterExpressionEvaluationThroughCrossJoin(FunctionAndTypeManager functionAndTypeManager)
    {
        this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
        this.determinismEvaluator = new RowExpressionDeterminismEvaluator(this.functionAndTypeManager);
    }

    @Override
    public Pattern<FilterNode> getPattern()
    {
        return PATTERN;
    }

    /**
     * The rule is active for every strategy except {@code DISABLED}.
     */
    @Override
    public boolean isEnabled(Session session)
    {
        boolean disabled = getPushdownFilterExpressionEvaluationThroughCrossJoinStrategy(session).equals(DISABLED);
        return !disabled;
    }

    /**
     * Projects the pushable sub-expressions of the filter predicate below the captured cross join and
     * rewrites the predicate to use the projected variables. Returns {@code Result.empty()} when there is
     * nothing to push down, or when the strategy is {@code REWRITTEN_TO_INNER_JOIN} and the rewritten
     * filter is not recognized by any of the cross-join-to-inner-join candidate checks.
     */
    @Override
    public Result apply(FilterNode filterNode, Captures captures, Context context)
    {
        JoinNode crossJoin = captures.get(CHILD);
        PlanNode originalLeft = crossJoin.getLeft();
        PlanNode originalRight = crossJoin.getRight();
        RowExpression predicate = filterNode.getPredicate();
        FunctionResolution functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());

        List<Set<RowExpression>> pushable = getRowExpressions(functionResolution, predicate, originalLeft.getOutputVariables(), originalRight.getOutputVariables());
        Set<RowExpression> leftExpressions = pushable.get(0);
        Set<RowExpression> rightExpressions = pushable.get(1);
        if (leftExpressions.isEmpty() && rightExpressions.isEmpty()) {
            return Result.empty();
        }

        // One fresh variable per pushed-down expression; left-side expressions are allocated first.
        Map<RowExpression, VariableReferenceExpression> expressionToVariable = concat(leftExpressions.stream(), rightExpressions.stream())
                .collect(toImmutableMap(identity(), expression -> context.getVariableAllocator().newVariable(expression)));
        RowExpression rewrittenPredicate = rewriteExpressionWithCSE(predicate, expressionToVariable);

        PlanNode newLeft = projectIfNeeded(originalLeft, leftExpressions, expressionToVariable, context);
        PlanNode newRight = projectIfNeeded(originalRight, rightExpressions, expressionToVariable, context);

        // Under REWRITTEN_TO_INNER_JOIN the rewrite is kept only if it enables a cross join to inner join conversion
        boolean innerJoinRewriteRequired = getPushdownFilterExpressionEvaluationThroughCrossJoinStrategy(context.getSession()).equals(REWRITTEN_TO_INNER_JOIN);
        if (innerJoinRewriteRequired
                && !canRewriteToInnerJoin(functionResolution, rewrittenPredicate, newLeft.getOutputVariables(), newRight.getOutputVariables())) {
            return Result.empty();
        }

        // Identity projection on top restores the output variables of the original filter.
        Assignments.Builder passThrough = Assignments.builder();
        passThrough.putAll(filterNode.getOutputVariables().stream().collect(toImmutableMap(identity(), identity())));

        // Node ids are requested outermost-first (project, filter, join) because arguments evaluate left to right.
        return Result.ofPlanNode(
                new ProjectNode(
                        context.getIdAllocator().getNextId(),
                        new FilterNode(
                                filterNode.getSourceLocation(),
                                context.getIdAllocator().getNextId(),
                                rebuildCrossJoin(crossJoin, newLeft, newRight, context),
                                rewrittenPredicate),
                        passThrough.build()));
    }

    /**
     * True for an inner join without equi-join criteria.
     */
    private static boolean isCrossJoin(JoinNode joinNode)
    {
        return joinNode.getCriteria().isEmpty() && joinNode.getType().equals(JoinType.INNER);
    }

    /**
     * True when at least one of the OR-filter, array-contains or array-not-contains candidate lookups
     * finds a candidate in {@code filter}.
     */
    private static boolean canRewriteToInnerJoin(FunctionResolution functionResolution, RowExpression filter, List<VariableReferenceExpression> left, List<VariableReferenceExpression> right)
    {
        if (getCandidateOrExpression(filter, left, right) != null) {
            return true;
        }
        if (getCandidateArrayContainsExpression(functionResolution, filter, left, right) != null) {
            return true;
        }
        return getCandidateArrayNotContainsExpression(functionResolution, filter, left, right) != null;
    }

    /**
     * Adds a projection computing {@code expressions} on top of {@code source}, or returns {@code source}
     * untouched when there is nothing to compute.
     */
    private static PlanNode projectIfNeeded(PlanNode source, Set<RowExpression> expressions, Map<RowExpression, VariableReferenceExpression> expressionToVariable, Context context)
    {
        Map<VariableReferenceExpression, RowExpression> assignments = expressions.stream()
                .collect(toImmutableMap(expressionToVariable::get, identity()));
        if (assignments.isEmpty()) {
            return source;
        }
        return PlannerUtils.addProjections(source, context.getIdAllocator(), assignments);
    }

    /**
     * Copies {@code original} onto the new inputs with a fresh node id; the output is every variable of
     * the left input followed by every variable of the right input.
     */
    private static JoinNode rebuildCrossJoin(JoinNode original, PlanNode left, PlanNode right, Context context)
    {
        List<VariableReferenceExpression> outputVariables = ImmutableList.<VariableReferenceExpression>builder()
                .addAll(left.getOutputVariables())
                .addAll(right.getOutputVariables())
                .build();
        return new JoinNode(
                original.getSourceLocation(),
                context.getIdAllocator().getNextId(),
                original.getType(),
                left,
                right,
                original.getCriteria(),
                outputVariables,
                original.getFilter(),
                original.getLeftHashVariable(),
                original.getRightHashVariable(),
                original.getDistributionType(),
                original.getDynamicFilters());
    }

    /**
     * Collects the expressions that can be evaluated below the join. The returned list always has two
     * elements: index 0 holds expressions for the left input, index 1 those for the right input.
     */
    // TODO: this function only handles filters in the form of an OR condition, array contains, etc.; make it generic so it works for all RowExpressions
    private List<Set<RowExpression>> getRowExpressions(FunctionResolution functionResolution, RowExpression filterPredicate, List<VariableReferenceExpression> left, List<VariableReferenceExpression> right)
    {
        List<List<Set<RowExpression>>> candidatesBySource = ImmutableList.of(
                getRowExpressionsFromOrCondition(filterPredicate, left, right),
                getRowExpressionsFromArrayContains(functionResolution, filterPredicate, left, right),
                getRowExpressionsFromArrayNotContains(functionResolution, filterPredicate, left, right));

        ImmutableSet.Builder<RowExpression> leftCandidate = ImmutableSet.builder();
        ImmutableSet.Builder<RowExpression> rightCandidate = ImmutableSet.builder();
        for (List<Set<RowExpression>> candidates : candidatesBySource) {
            leftCandidate.addAll(candidates.get(0));
            rightCandidate.addAll(candidates.get(1));
        }
        return ImmutableList.of(leftCandidate.build(), rightCandidate.build());
    }

    /**
     * Looks at every disjunct of every conjunct of the predicate and, for each call displayed as
     * {@code EQUAL}, considers its two arguments as push down candidates.
     */
    private List<Set<RowExpression>> getRowExpressionsFromOrCondition(RowExpression filterPredicate, List<VariableReferenceExpression> left, List<VariableReferenceExpression> right)
    {
        Set<RowExpression> leftSide = new HashSet<>();
        Set<RowExpression> rightSide = new HashSet<>();
        for (RowExpression conjunct : extractConjuncts(filterPredicate)) {
            for (RowExpression disjunct : extractDisjuncts(conjunct)) {
                if (!isEqualityCall(disjunct)) {
                    continue;
                }
                addCandidateArguments((CallExpression) disjunct, left, right, leftSide, rightSide);
            }
        }
        return ImmutableList.of(leftSide, rightSide);
    }

    /**
     * When the whole predicate is an array contains call, considers its two arguments as push down candidates.
     */
    private List<Set<RowExpression>> getRowExpressionsFromArrayContains(FunctionResolution functionResolution, RowExpression filterPredicate, List<VariableReferenceExpression> left, List<VariableReferenceExpression> right)
    {
        Set<RowExpression> leftSide = new HashSet<>();
        Set<RowExpression> rightSide = new HashSet<>();
        if (isArrayContainsCall(functionResolution, filterPredicate)) {
            addCandidateArguments((CallExpression) filterPredicate, left, right, leftSide, rightSide);
        }
        return ImmutableList.of(leftSide, rightSide);
    }

    /**
     * When the whole predicate is the negation of an array contains call, considers the two arguments of
     * that call as push down candidates.
     */
    private List<Set<RowExpression>> getRowExpressionsFromArrayNotContains(FunctionResolution functionResolution, RowExpression filterPredicate, List<VariableReferenceExpression> left, List<VariableReferenceExpression> right)
    {
        Set<RowExpression> leftSide = new HashSet<>();
        Set<RowExpression> rightSide = new HashSet<>();
        if (!PlannerUtils.isNegationExpression(functionResolution, filterPredicate)) {
            return ImmutableList.of(leftSide, rightSide);
        }
        RowExpression negated = filterPredicate.getChildren().get(0);
        if (isArrayContainsCall(functionResolution, negated)) {
            addCandidateArguments((CallExpression) negated, left, right, leftSide, rightSide);
        }
        return ImmutableList.of(leftSide, rightSide);
    }

    /**
     * True when {@code expression} is a call whose display name is {@code EQUAL}.
     */
    private static boolean isEqualityCall(RowExpression expression)
    {
        return expression instanceof CallExpression
                && ((CallExpression) expression).getDisplayName().equals("EQUAL");
    }

    /**
     * True when {@code expression} is a call to the array contains function.
     */
    private static boolean isArrayContainsCall(FunctionResolution functionResolution, RowExpression expression)
    {
        return expression instanceof CallExpression
                && functionResolution.isArrayContainsFunction(((CallExpression) expression).getFunctionHandle());
    }

    /**
     * Offers the first and then the second argument of {@code call} as push down candidates.
     */
    private void addCandidateArguments(CallExpression call, List<VariableReferenceExpression> left, List<VariableReferenceExpression> right, Set<RowExpression> leftRowExpression, Set<RowExpression> rightRowExpression)
    {
        addCandidateExpression(call.getArguments().get(0), left, right, leftRowExpression, rightRowExpression);
        addCandidateExpression(call.getArguments().get(1), left, right, leftRowExpression, rightRowExpression);
    }

    /**
     * Records {@code candidate} for the side whose variables cover every variable it references. Skipped
     * are expressions that reference no variable, are not deterministic, or are already a plain variable
     * reference; expressions mixing variables of both sides are recorded for neither.
     */
    private void addCandidateExpression(RowExpression candidate, List<VariableReferenceExpression> left, List<VariableReferenceExpression> right, Set<RowExpression> leftRowExpression, Set<RowExpression> rightRowExpression)
    {
        List<VariableReferenceExpression> referenced = extractAll(candidate);
        if (referenced.isEmpty() || !determinismEvaluator.isDeterministic(candidate) || candidate instanceof VariableReferenceExpression) {
            return;
        }
        if (left.containsAll(referenced)) {
            leftRowExpression.add(candidate);
            return;
        }
        if (right.containsAll(referenced)) {
            rightRowExpression.add(candidate);
        }
    }
}