JoinPrefilter.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.VarcharType;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.SemiJoinNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.IntStream;

import static com.facebook.presto.SystemSessionProperties.isJoinPrefilterEnabled;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet;
import static com.facebook.presto.spi.plan.JoinType.INNER;
import static com.facebook.presto.spi.plan.JoinType.LEFT;
import static com.facebook.presto.sql.planner.PlannerUtils.addProjections;
import static com.facebook.presto.sql.planner.PlannerUtils.clonePlanNode;
import static com.facebook.presto.sql.planner.PlannerUtils.isScanFilterProject;
import static com.facebook.presto.sql.planner.PlannerUtils.orNullHashCode;
import static com.facebook.presto.sql.planner.PlannerUtils.projectExpressions;
import static com.facebook.presto.sql.planner.PlannerUtils.restrictOutput;
import static com.facebook.presto.sql.planner.plan.ChildReplacer.replaceChildren;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.facebook.presto.sql.relational.Expressions.callOperator;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;

/**
 * This optimizer filters the right side of a join with the unique join keys on the left side of the join. When the join key is wide or
 * there are multiple join keys, it filters on a hash of the keys instead of on the keys themselves.
 * It will convert plan from
 * <pre>
 *     - InnerJoin
 *          leftKey = rightKey
 *          - scan l
 *          - scan r
 * </pre>
 * into
 * <pre>
 *     - InnerJoin
 *          leftKey = rightKey
 *          - scan l
 *          - semiJoin
 *              r.rightKey in l.leftKey
 *              - scan r
 *              - distinct aggregation
 *                  group by leftKey
 *                  - scan l
 * </pre>
 * And for join with varchar type
 * <pre>
 *     - InnerJoin
 *          leftKey (varchar) = rightKey (varchar)
 *          - scan l
 *          - scan r
 * </pre>
 * into
 * <pre>
 *     - InnerJoin
 *          leftKey (varchar) = rightKey (varchar)
 *          - scan l
 *          - semiJoin
 *              r.rightKeyHash in l.leftKeyHash
 *              - project
 *                  r.rightKeyHash = xx_hash64(r.rightKey)
 *                  - scan r
 *              - distinct aggregation
 *                  group by leftKeyHash
 *                  - project
 *                      l.leftKeyHash = xx_hash64(l.leftKey)
 *                      - scan l
 * </pre>
 * And for join with multiple keys
 * <pre>
 *     - InnerJoin
 *          leftKey1 = rightKey1 and leftKey2 = rightKey2
 *          - scan l
 *          - scan r
 * </pre>
 * into
 * <pre>
 *     - InnerJoin
 *          leftKey1 = rightKey1 and leftKey2 = rightKey2
 *          - scan l
 *          - semiJoin
 *              r.rightKeysHash in l.leftKeysHash
 *              - project
 *                  r.rightKeysHash = combine_hash(xx_hash64(rightKey1), xx_hash64(rightKey2))
 *                  - scan r
 *              - distinct aggregation
 *                  group by leftKeysHash
 *                  - project
 *                      l.leftKeysHash = combine_hash(xx_hash64(leftKey1), xx_hash64(leftKey2))
 *                      - scan l
 * </pre>
 */

public class JoinPrefilter
        implements PlanOptimizer
{
    private final Metadata metadata;
    // When set, the optimizer runs regardless of the session property (see isEnabled).
    private boolean isEnabledForTesting;

    public JoinPrefilter(Metadata metadata)
    {
        this.metadata = requireNonNull(metadata, "metadata is null");
    }

    @Override
    public void setEnabledForTesting(boolean isSet)
    {
        isEnabledForTesting = isSet;
    }

    @Override
    public boolean isEnabled(Session session)
    {
        return isEnabledForTesting || isJoinPrefilterEnabled(session);
    }

    /**
     * Rewrites eligible joins in {@code plan} so that their right side is pre-filtered by a semi join against the
     * distinct left join keys. Returns the input plan unchanged (and flagged as not changed) when the optimizer is disabled.
     */
    @Override
    public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector)
    {
        if (isEnabled(session)) {
            Rewriter rewriter = new Rewriter(session, metadata, idAllocator, variableAllocator, metadata.getFunctionAndTypeManager());
            PlanNode rewritten = SimplePlanRewriter.rewriteWith(rewriter, plan, null);
            return PlanOptimizerResult.optimizerResult(rewritten, rewriter.isPlanChanged());
        }

        return PlanOptimizerResult.optimizerResult(plan, false);
    }

    private static class Rewriter
            extends SimplePlanRewriter<Void>
    {
        private final Session session;
        private final Metadata metadata;
        private final PlanNodeIdAllocator idAllocator;
        private final VariableAllocator variableAllocator;
        private final FunctionAndTypeManager functionAndTypeManager;
        // Set to true as soon as any join node gets at least one replaced child.
        private boolean planChanged;

        private Rewriter(Session session, Metadata metadata, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator, FunctionAndTypeManager functionAndTypeManager)
        {
            this.session = requireNonNull(session, "session is null");
            this.metadata = requireNonNull(metadata, "metadata is null");
            this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
            this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator is null");
            this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
        }

        /**
         * Rewrites both children of the join first, then, for an INNER or LEFT join with at least one equi-join clause
         * whose (rewritten) left side is a scan-filter-project, wraps the right side in
         * {@code Filter(SemiJoin(right, DISTINCT leftKeyOrHash))}.
         */
        @Override
        public PlanNode visitJoin(JoinNode node, RewriteContext<Void> context)
        {
            PlanNode left = node.getLeft();
            PlanNode right = node.getRight();

            PlanNode rewrittenLeft = rewriteWith(this, left);
            PlanNode rewrittenRight = rewriteWith(this, right);
            List<EquiJoinClause> equiJoinClause = node.getCriteria();

            // We apply this for only left and inner join and the left side of the join is a simple scan
            if ((node.getType() == LEFT || node.getType() == INNER) && isScanFilterProject(rewrittenLeft) && !equiJoinClause.isEmpty()) {
                List<VariableReferenceExpression> leftKeyList = equiJoinClause.stream().map(EquiJoinClause::getLeft).collect(toImmutableList());
                List<VariableReferenceExpression> rightKeyList = equiJoinClause.stream().map(EquiJoinClause::getRight).collect(toImmutableList());
                checkState(
                        IntStream.range(0, leftKeyList.size()).allMatch(i -> leftKeyList.get(i).getType().equals(rightKeyList.get(i).getType())),
                        "Left and right join key types do not match");

                // Filter on a hash when there are several keys or the single key is a varchar
                boolean hashJoinKey = leftKeyList.size() > 1 || (leftKeyList.get(0).getType().equals(VARCHAR) || leftKeyList.get(0).getType() instanceof VarcharType);

                // First create a SELECT DISTINCT leftKey FROM left
                // leftVarMap is read below as a mapping from the original left variables to the variables of the cloned plan
                Map<VariableReferenceExpression, VariableReferenceExpression> leftVarMap = new HashMap<>();
                PlanNode leftKeys = clonePlanNode(rewrittenLeft, session, metadata, idAllocator, leftKeyList, leftVarMap);
                ImmutableList.Builder<RowExpression> expressionsToProject = ImmutableList.builder();
                if (hashJoinKey) {
                    // The projection sits on top of the cloned plan, so the hash has to be computed over the cloned
                    // variables (same translation as the non-hash branch), not over the original join variables
                    List<VariableReferenceExpression> clonedLeftKeys = leftKeyList.stream().map(leftVarMap::get).collect(toImmutableList());
                    expressionsToProject.add(getVariableHash(clonedLeftKeys));
                }
                else {
                    expressionsToProject.add(leftVarMap.get(leftKeyList.get(0)));
                }
                PlanNode projectNode = projectExpressions(leftKeys, idAllocator, variableAllocator, expressionsToProject.build(), ImmutableList.of());

                // On the right side the hash is appended as an extra projected column; a plain key is used directly
                VariableReferenceExpression rightKeyToFilter = rightKeyList.get(0);
                if (hashJoinKey) {
                    RowExpression hashExpression = getVariableHash(rightKeyList);
                    rightKeyToFilter = variableAllocator.newVariable(hashExpression);
                    rewrittenRight = addProjections(rewrittenRight, idAllocator, ImmutableMap.of(rightKeyToFilter, hashExpression));
                }

                // DISTINCT on the leftkey or hash if wide column
                // NOTE(review): the grouping keys are also passed as the argument following the grouping set, which is
                // presumably the list of pre-grouped variables - confirm that this is intended for an unsorted scan
                PlanNode filteringSource = new AggregationNode(
                        left.getSourceLocation(),
                        idAllocator.getNextId(),
                        projectNode,
                        ImmutableMap.of(),
                        singleGroupingSet(projectNode.getOutputVariables()),
                        projectNode.getOutputVariables(),
                        AggregationNode.Step.SINGLE,
                        Optional.empty(),
                        Optional.empty(),
                        Optional.empty());

                // There should be only one output variable. Project that
                filteringSource = projectExpressions(filteringSource, idAllocator, variableAllocator, ImmutableList.of(filteringSource.getOutputVariables().get(0)), ImmutableList.of());

                // Now we add a semijoin as the right side
                VariableReferenceExpression semiJoinOutput = variableAllocator.newVariable("semiJoinOutput", BOOLEAN);
                SemiJoinNode semiJoinNode = new SemiJoinNode(
                        right.getSourceLocation(),
                        idAllocator.getNextId(),
                        node.getStatsEquivalentPlanNode(),
                        rewrittenRight,
                        filteringSource,
                        rightKeyToFilter,
                        filteringSource.getOutputVariables().get(0),
                        semiJoinOutput,
                        Optional.empty(),
                        Optional.empty(),
                        Optional.empty(),
                        ImmutableMap.of());

                // Keep only the right rows whose key (or key hash) was found among the left keys
                rewrittenRight = new FilterNode(semiJoinNode.getSourceLocation(), idAllocator.getNextId(), semiJoinNode, semiJoinOutput);
                // Drop the helper columns (semi join output, hash) so the join sees the original right-side outputs
                if (rewrittenRight.getOutputVariables().size() > right.getOutputVariables().size()) {
                    rewrittenRight = restrictOutput(rewrittenRight, idAllocator, right.getOutputVariables());
                }
            }

            // Reference comparison: any child that was rebuilt means the plan changed
            if (rewrittenLeft != left || rewrittenRight != right) {
                planChanged = true;
                return replaceChildren(node, ImmutableList.of(rewrittenLeft, rewrittenRight));
            }

            return node;
        }

        public boolean isPlanChanged()
        {
            return planChanged;
        }

        /**
         * Builds a single BIGINT hash expression over the given variables: {@code XX_HASH_64(v)} for one variable,
         * otherwise the per-variable hashes, each wrapped with {@code orNullHashCode}, folded left to right with {@code combine_hash}.
         *
         * @param inputVariables variables to hash; must contain at least one element
         */
        private RowExpression getVariableHash(List<VariableReferenceExpression> inputVariables)
        {
            List<CallExpression> hashExpressionList = inputVariables.stream().map(keyVariable ->
                    callOperator(functionAndTypeManager.getFunctionAndTypeResolver(), OperatorType.XX_HASH_64, BIGINT, keyVariable)).collect(toImmutableList());
            RowExpression hashExpression = hashExpressionList.get(0);
            if (hashExpressionList.size() > 1) {
                hashExpression = orNullHashCode(hashExpression);
                for (int i = 1; i < hashExpressionList.size(); ++i) {
                    hashExpression = call(functionAndTypeManager, "combine_hash", BIGINT, hashExpression, orNullHashCode(hashExpressionList.get(i)));
                }
            }
            return hashExpression;
        }
    }
}