CrossJoinWithOrFilterToInnerJoin.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.CastType;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.UnnestNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.PlannerUtils;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.SystemSessionProperties.isRewriteCrossJoinOrToInnerJoinEnabled;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.DateType.DATE;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.expressions.LogicalRowExpressions.and;
import static com.facebook.presto.expressions.LogicalRowExpressions.extractConjuncts;
import static com.facebook.presto.expressions.LogicalRowExpressions.extractDisjuncts;
import static com.facebook.presto.matching.Capture.newCapture;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.SWITCH;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.WHEN;
import static com.facebook.presto.sql.planner.plan.Patterns.filter;
import static com.facebook.presto.sql.planner.plan.Patterns.join;
import static com.facebook.presto.sql.planner.plan.Patterns.source;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.facebook.presto.sql.relational.Expressions.coalesceNullToFalse;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.facebook.presto.sql.relational.Expressions.constantNull;
import static com.facebook.presto.sql.relational.Expressions.not;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static java.util.Objects.requireNonNull;

/**
 * Inner join with "or" inside join clause will be run as cross join with filter, which can degrade performance, especially when selectivity of join is low.
 * When the join condition has pattern of l_key1=r_key1 or l_key1=r_key2, we can rewrite it to a inner join. For example:
 * <pre>
 * - Filter l_key1=r_key1 or l_key1=r_key2
 *      - Cross join
 *          - scan l
 *          - scan r
 * </pre>
 * into:
 * <pre>
 *     - Project
 *          - Filter
 *              CASE field WHEN 1 l_key1 = r_key1 WHEN 2 NOT(coalesce(l_key1 = r_key1, false)) and l_key2 = r_key2 else NULL END
 *              - Inner Join
 *                  l_key = r_key and l_field = r_field
 *                  - Project
 *                      key1 := key1
 *                      key2 := key2
 *                      field := field
 *                      key := case field when 1 then key1 when 2 then key2 else null end
 *                      - Unnest
 *                          field <- unnest arr
 *                          - Project
 *                              key1 := key1
 *                              key2 := key2
 *                              arr := array[1, 2]
 *                              _ scan l
 *                                  key1, key2
 *                   - Project
 *                      key1 := key1
 *                      key2 := key2
 *                      field := field
 *                      key := case field when 1 then key1 when 2 then key2 else null end
 *                      - Unnest
 *                          field <- unnest arr
 *                          - Project
 *                              key1 := key1
 *                              key2 := key2
 *                              arr := array[1, 2]
 *                              _ scan r
 *                                  key1, key2
 * </pre>
 */
public class CrossJoinWithOrFilterToInnerJoin
        implements Rule<FilterNode>
{
    private static final List<Type> SUPPORTED_JOIN_KEY_TYPE = ImmutableList.of(BIGINT, INTEGER, VARCHAR, DATE);
    private static final Capture<JoinNode> CHILD = newCapture();

    private static final Pattern<FilterNode> PATTERN = filter()
            .with(source().matching(join().matching(x -> x.getType().equals(JoinType.INNER) && x.getCriteria().isEmpty()).capturedAs(CHILD)));

    private final FunctionAndTypeManager functionAndTypeManager;

    public CrossJoinWithOrFilterToInnerJoin(FunctionAndTypeManager functionAndTypeManager)
    {
        this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
    }

    // Valid only if it's an equal expression, and one argument from left of join and the other argument from right of join.
    private static boolean isValidExpression(RowExpression rowExpression, List<VariableReferenceExpression> leftInput, List<VariableReferenceExpression> rightInput)
    {
        if (!(rowExpression instanceof CallExpression) || !((CallExpression) rowExpression).getDisplayName().equals("EQUAL")) {
            return false;
        }
        CallExpression callExpression = (CallExpression) rowExpression;
        RowExpression argument0 = callExpression.getArguments().get(0);
        RowExpression argument1 = callExpression.getArguments().get(1);
        return SUPPORTED_JOIN_KEY_TYPE.containsAll(ImmutableList.of(argument0.getType(), argument1.getType()))
                && ((leftInput.contains(argument0) && rightInput.contains(argument1)) || (leftInput.contains(argument1) && rightInput.contains(argument0)));
    }

    public static RowExpression getCandidateOrExpression(RowExpression filterPredicate, List<VariableReferenceExpression> leftInput, List<VariableReferenceExpression> rightInput)
    {
        List<RowExpression> andConjuncts = extractConjuncts(filterPredicate);
        for (RowExpression conjunct : andConjuncts) {
            List<RowExpression> equalExpressionList = extractDisjuncts(conjunct);
            if (!equalExpressionList.isEmpty() && equalExpressionList.stream().allMatch(x -> isValidExpression(x, leftInput, rightInput))) {
                return conjunct;
            }
        }
        return null;
    }

    @Override
    public Pattern<FilterNode> getPattern()
    {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session)
    {
        return isRewriteCrossJoinOrToInnerJoinEnabled(session);
    }

    private RewrittenJoinInput rewriteJoinInput(List<VariableReferenceExpression> variablesInOrCondition, PlanNode joinInput, Type finalJoinKeyType, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator)
    {
        Map<VariableReferenceExpression, VariableReferenceExpression> castVariableMap = new HashMap<>();
        Map<VariableReferenceExpression, RowExpression> castExpressionMap = new HashMap<>();
        if (!variablesInOrCondition.stream().allMatch(x -> x.getType().equals(finalJoinKeyType))) {
            // cast all to varchar type
            for (int i = 0; i < variablesInOrCondition.size(); ++i) {
                CallExpression castExpression = call("CAST", functionAndTypeManager.lookupCast(CastType.CAST, variablesInOrCondition.get(i).getType(), VARCHAR), VARCHAR, variablesInOrCondition.get(i));
                VariableReferenceExpression castVariable = variableAllocator.newVariable(castExpression);
                castVariableMap.put(variablesInOrCondition.get(i), castVariable);
                castExpressionMap.put(castVariable, castExpression);
            }
        }

        ImmutableList.Builder<RowExpression> constantsArgument = ImmutableList.builder();
        for (int i = 0; i < variablesInOrCondition.size(); ++i) {
            constantsArgument.add(constant((long) i + 1, INTEGER));
        }
        CallExpression arrayConstruct = call(functionAndTypeManager, "array_constructor", new ArrayType(INTEGER), constantsArgument.build());

        VariableReferenceExpression arrayVariable = variableAllocator.newVariable(arrayConstruct);
        ImmutableMap.Builder<VariableReferenceExpression, RowExpression> projectAssignment = ImmutableMap.builder();
        PlanNode project = PlannerUtils.addProjections(joinInput, idAllocator, projectAssignment.put(arrayVariable, arrayConstruct).putAll(castExpressionMap).build());

        VariableReferenceExpression unnestVariable = variableAllocator.newVariable("field", INTEGER);
        UnnestNode unnest = new UnnestNode(joinInput.getSourceLocation(),
                idAllocator.getNextId(),
                project,
                project.getOutputVariables().stream().filter(x -> !x.equals(arrayVariable)).collect(toImmutableList()),
                ImmutableMap.of(arrayVariable, ImmutableList.of(unnestVariable)),
                Optional.empty());

        ImmutableList.Builder<RowExpression> whenExpression = ImmutableList.builder();
        whenExpression.add(unnestVariable);
        for (int i = 0; i < variablesInOrCondition.size(); ++i) {
            whenExpression.add(new SpecialFormExpression(WHEN, finalJoinKeyType, constant((long) i + 1, INTEGER), castVariableMap.isEmpty() ? variablesInOrCondition.get(i) : castVariableMap.get(variablesInOrCondition.get(i))));
        }
        whenExpression.add(constantNull(finalJoinKeyType));
        SpecialFormExpression joinKeyExpression = new SpecialFormExpression(SWITCH, finalJoinKeyType, whenExpression.build());
        VariableReferenceExpression newJoinVariable = variableAllocator.newVariable(joinKeyExpression);
        PlanNode rewrittenInput = PlannerUtils.addProjections(unnest, idAllocator, variableAllocator, ImmutableList.of(joinKeyExpression), ImmutableList.of(newJoinVariable));
        return new RewrittenJoinInput(rewrittenInput, unnestVariable, newJoinVariable);
    }

    private VariableReferenceExpression getVariableInEqualComparison(RowExpression rowExpression, List<VariableReferenceExpression> candidate)
    {
        checkArgument(rowExpression instanceof CallExpression && ((CallExpression) rowExpression).getDisplayName().equals("EQUAL"));
        CallExpression callExpression = (CallExpression) rowExpression;
        RowExpression argument0 = callExpression.getArguments().get(0);
        RowExpression argument1 = callExpression.getArguments().get(1);
        if (candidate.contains(argument0)) {
            return (VariableReferenceExpression) argument0;
        }
        else if (candidate.contains(argument1)) {
            return (VariableReferenceExpression) argument1;
        }
        checkState(false, "argument does not exist in candidate list");
        return null;
    }

    @Override
    public Result apply(FilterNode filterNode, Captures captures, Context context)
    {
        JoinNode joinNode = captures.get(CHILD);
        if (!(joinNode.getType().equals(JoinType.INNER) && joinNode.getCriteria().isEmpty())) {
            return Result.empty();
        }
        RowExpression candidateOrExpressions = getCandidateOrExpression(filterNode.getPredicate(), joinNode.getLeft().getOutputVariables(), joinNode.getRight().getOutputVariables());
        if (candidateOrExpressions == null) {
            return Result.empty();
        }
        List<RowExpression> andConjuncts = extractConjuncts(filterNode.getPredicate());
        List<RowExpression> leftAndConjuncts = andConjuncts.stream().filter(x -> !x.equals(candidateOrExpressions)).collect(toImmutableList());
        List<RowExpression> equalExpressionList = extractDisjuncts(candidateOrExpressions);
        List<VariableReferenceExpression> variablesUsedInOrComparisionFromLeft = equalExpressionList.stream().map(x -> getVariableInEqualComparison(x, joinNode.getLeft().getOutputVariables())).collect(toImmutableList());
        List<VariableReferenceExpression> variablesUsedInOrComparisionFromRight = equalExpressionList.stream().map(x -> getVariableInEqualComparison(x, joinNode.getRight().getOutputVariables())).collect(toImmutableList());

        if (variablesUsedInOrComparisionFromLeft.isEmpty() || variablesUsedInOrComparisionFromRight.isEmpty()) {
            return Result.empty();
        }
        // Apply optimization only when the variables in or condition is of type int/bigint/varchar/date types.
        if (variablesUsedInOrComparisionFromLeft.stream().anyMatch(x -> !SUPPORTED_JOIN_KEY_TYPE.contains(x.getType()))
                || variablesUsedInOrComparisionFromRight.stream().anyMatch(x -> !SUPPORTED_JOIN_KEY_TYPE.contains(x.getType()))) {
            return Result.empty();
        }

        // Check if all candidate variables are of the same type
        Type joinKeyType = VARCHAR;
        List<Type> leftOrPredicateTypes = variablesUsedInOrComparisionFromLeft.stream().map(x -> x.getType()).distinct().collect(toImmutableList());
        List<Type> rightOrPredicateTypes = variablesUsedInOrComparisionFromRight.stream().map(x -> x.getType()).distinct().collect(toImmutableList());
        if (leftOrPredicateTypes.size() == 1 && rightOrPredicateTypes.size() == 1 && leftOrPredicateTypes.get(0).equals(rightOrPredicateTypes.get(0))) {
            joinKeyType = leftOrPredicateTypes.get(0);
        }

        RewrittenJoinInput leftJoinInput = rewriteJoinInput(variablesUsedInOrComparisionFromLeft, joinNode.getLeft(), joinKeyType, context.getVariableAllocator(), context.getIdAllocator());
        RewrittenJoinInput rightJoinInput = rewriteJoinInput(variablesUsedInOrComparisionFromRight, joinNode.getRight(), joinKeyType, context.getVariableAllocator(), context.getIdAllocator());

        ImmutableList.Builder<VariableReferenceExpression> joinOutput = ImmutableList.builder();
        joinOutput.add(leftJoinInput.getJoinKey()).add(leftJoinInput.getUnnestIndex()).addAll(joinNode.getOutputVariables());
        JoinNode newJoinNode = new JoinNode(joinNode.getSourceLocation(),
                context.getIdAllocator().getNextId(),
                joinNode.getType(),
                leftJoinInput.getNode(),
                rightJoinInput.getNode(),
                ImmutableList.of(new EquiJoinClause(leftJoinInput.getJoinKey(), rightJoinInput.getJoinKey()),
                        new EquiJoinClause(leftJoinInput.getUnnestIndex(), rightJoinInput.getUnnestIndex())),
                joinOutput.build(),
                joinNode.getFilter(),
                Optional.empty(),
                Optional.empty(),
                joinNode.getDistributionType(),
                joinNode.getDynamicFilters());

        // Deduplicate the rows which matched multiple times
        ImmutableList.Builder<RowExpression> whenExpression = ImmutableList.builder();
        whenExpression.add(leftJoinInput.getUnnestIndex());
        for (int i = 0; i < equalExpressionList.size(); ++i) {
            ImmutableList.Builder<RowExpression> matchCondition = ImmutableList.builder();
            for (int j = 0; j < i; ++j) {
                matchCondition.add(not(functionAndTypeManager, coalesceNullToFalse(equalExpressionList.get(j))));
            }
            matchCondition.add(equalExpressionList.get(i));
            whenExpression.add(new SpecialFormExpression(WHEN, BOOLEAN, constant((long) i + 1, INTEGER), and(matchCondition.build())));
        }
        whenExpression.add(constantNull(BOOLEAN));
        SpecialFormExpression dedupFilter = new SpecialFormExpression(SWITCH, BOOLEAN, whenExpression.build());
        FilterNode newFilterNode = new FilterNode(joinNode.getSourceLocation(), context.getIdAllocator().getNextId(), newJoinNode, dedupFilter);
        if (!leftAndConjuncts.isEmpty()) {
            newFilterNode = new FilterNode(filterNode.getSourceLocation(), context.getIdAllocator().getNextId(), newFilterNode, and(leftAndConjuncts));
        }
        // So that the output of new node is exactly the same
        Assignments.Builder identity = Assignments.builder();
        identity.putAll(filterNode.getOutputVariables().stream().collect(toImmutableMap(x -> x, x -> x)));
        ProjectNode projectUnusedOutput = new ProjectNode(context.getIdAllocator().getNextId(), newFilterNode, identity.build());
        return Result.ofPlanNode(projectUnusedOutput);
    }

    private static class RewrittenJoinInput
    {
        private final PlanNode node;
        private final VariableReferenceExpression unnestIndex;
        private final VariableReferenceExpression joinKey;

        public RewrittenJoinInput(PlanNode node, VariableReferenceExpression unnestIndex, VariableReferenceExpression joinKey)
        {
            this.node = node;
            this.unnestIndex = unnestIndex;
            this.joinKey = joinKey;
        }

        public PlanNode getNode()
        {
            return node;
        }

        public VariableReferenceExpression getJoinKey()
        {
            return joinKey;
        }

        public VariableReferenceExpression getUnnestIndex()
        {
            return unnestIndex;
        }
    }
}