// ShardJoins.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.cost.StatsCalculator;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.UnnestNode;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.PlannerUtils;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;

import static com.facebook.presto.SystemSessionProperties.getJoinShardCount;
import static com.facebook.presto.SystemSessionProperties.getShardedJoinStrategy;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.spi.plan.JoinType.FULL;
import static com.facebook.presto.spi.plan.JoinType.RIGHT;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.ShardedJoinStrategy.ALWAYS;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.ShardedJoinStrategy.COST_BASED;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.ShardedJoinStrategy.DISABLED;
import static com.facebook.presto.sql.planner.PlannerUtils.isBroadcastJoin;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;

/**
 * Shards joins to eliminate skew.
 * <p>
 * Transforms
 * <pre>
 * - Join
 *      S.key = T.key
 *      - S
 *      - T
 * </pre>
 * to
 * <pre>
 * - Join
 *      S.key = T.key and leftShard = rightShard
 *      - Project(leftShard := random(NumShards))
 *          - S
 *      - Unnest(seq) as rightShard
 *          - Project(seq := sequence(0, NumShards - 1))
 *              - T
 * </pre>
 * Each left row is tagged with one randomly chosen shard, while each right row is
 * expanded into one copy per shard, so every left row still meets every right row
 * that has the same key. Because the right side is duplicated, joins that preserve
 * unmatched right rows (RIGHT and FULL) are never rewritten; broadcast joins are
 * skipped as well.
 * <p>
 * The rewrite is driven by the sharded-join strategy and the join shard count
 * session properties.
 */
public class ShardJoins
        implements PlanOptimizer
{
    private final Metadata metadata;
    private final FunctionAndTypeManager functionAndTypeManager;
    // Stored for the cost-based strategy; nothing in this class reads it yet.
    private final StatsCalculator statsCalculator;
    private boolean isEnabledForTesting;

    public ShardJoins(Metadata metadata, FunctionAndTypeManager functionAndTypeManager, StatsCalculator statsCalculator)
    {
        this.metadata = requireNonNull(metadata, "metadata is null");
        this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
        this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null");
    }

    @Override
    public void setEnabledForTesting(boolean isSet)
    {
        isEnabledForTesting = isSet;
    }

    /**
     * The optimizer runs when it was force-enabled for testing or when the session's
     * sharded-join strategy is anything other than DISABLED.
     */
    @Override
    public boolean isEnabled(Session session)
    {
        return isEnabledForTesting || !getShardedJoinStrategy(session).equals(DISABLED);
    }

    /**
     * Rewrites every applicable join in {@code plan} into its sharded form.
     *
     * @return the rewritten plan together with a flag telling whether any join was changed;
     * the input plan and {@code false} when the optimizer is not enabled
     */
    @Override
    public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector)
    {
        if (!isEnabled(session)) {
            return PlanOptimizerResult.optimizerResult(plan, false);
        }

        Rewriter rewriter = new Rewriter(session, metadata, functionAndTypeManager, idAllocator, variableAllocator);
        PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(rewriter, plan, new HashSet<>());
        return PlanOptimizerResult.optimizerResult(rewrittenPlan, rewriter.isPlanChanged());
    }

    private static class Rewriter
            extends SimplePlanRewriter<Set<VariableReferenceExpression>>
    {
        private final Session session;
        private final Metadata metadata;
        private final FunctionAndTypeManager functionAndTypeManager;
        private final PlanNodeIdAllocator planNodeIdAllocator;
        private final VariableAllocator planVariableAllocator;
        private boolean planChanged;

        private Rewriter(Session session, Metadata metadata,
                FunctionAndTypeManager functionAndTypeManager, PlanNodeIdAllocator planNodeIdAllocator, VariableAllocator planVariableAllocator)
        {
            this.session = requireNonNull(session, "session is null");
            this.metadata = requireNonNull(metadata, "metadata is null");
            this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
            this.planNodeIdAllocator = requireNonNull(planNodeIdAllocator, "planNodeIdAllocator is null");
            this.planVariableAllocator = requireNonNull(planVariableAllocator, "planVariableAllocator is null");
        }

        /**
         * @return true once at least one join has been replaced by its sharded form
         */
        public boolean isPlanChanged()
        {
            return planChanged;
        }

        /**
         * Replaces an applicable join with a join over the sharded inputs and an extra
         * {@code leftShard = rightShard} equi-clause; other joins are only descended into.
         */
        @Override
        public PlanNode visitJoin(JoinNode joinNode, RewriteContext<Set<VariableReferenceExpression>> context)
        {
            if (!isApplicable(joinNode)) {
                return context.defaultRewrite(joinNode);
            }

            long numShards = getNumberOfShards();
            if (numShards < 2) {
                // With fewer than two shards there is nothing to spread out; leave the join
                // untouched instead of failing later, after plan variables were allocated.
                return context.defaultRewrite(joinNode);
            }

            VariableReferenceExpression leftShardVariable = planVariableAllocator.newVariable("leftShard", BIGINT);
            VariableReferenceExpression rightShardVariable = planVariableAllocator.newVariable("rightShard", BIGINT);

            // Left side: tag every row with a single randomly picked shard.
            RowExpression randomShard = call(
                    functionAndTypeManager,
                    "random",
                    BIGINT,
                    constant(numShards, BIGINT));
            PlanNode newLeftChild = PlannerUtils.addProjections(
                    joinNode.getLeft(),
                    planNodeIdAllocator,
                    planVariableAllocator,
                    ImmutableList.of(randomShard),
                    ImmutableList.of(leftShardVariable));

            // Right side: one copy of every row per shard.
            PlanNode newRightChild = shardInput(numShards, joinNode.getRight(), rightShardVariable);

            // Original criteria first, shard equality last.
            List<EquiJoinClause> joinCriteria = new ArrayList<>(joinNode.getCriteria());
            joinCriteria.add(new EquiJoinClause(leftShardVariable, rightShardVariable));

            // Output variables are kept as-is, so the shard columns do not leak out of the join.
            PlanNode result = new JoinNode(
                    joinNode.getSourceLocation(),
                    joinNode.getId(),
                    joinNode.getStatsEquivalentPlanNode(),
                    joinNode.getType(),
                    newLeftChild,
                    newRightChild,
                    joinCriteria,
                    joinNode.getOutputVariables(),
                    joinNode.getFilter(),
                    joinNode.getLeftHashVariable(),
                    joinNode.getRightHashVariable(),
                    joinNode.getDistributionType(),
                    joinNode.getDynamicFilters());

            planChanged = true;
            // Continue into the (new) children so that joins nested below are considered too.
            return context.defaultRewrite(result);
        }

        /**
         * A join is sharded only if it is neither FULL nor RIGHT, is not a broadcast join, and
         * the session strategy is ALWAYS, or COST_BASED with {@link #shouldShardJoin} agreeing.
         */
        private boolean isApplicable(JoinNode joinNode)
        {
            if (joinNode.getType() == FULL || joinNode.getType() == RIGHT) {
                return false;
            }
            if (isBroadcastJoin(joinNode)) {
                return false;
            }
            return getShardedJoinStrategy(session).equals(ALWAYS) ||
                    (getShardedJoinStrategy(session).equals(COST_BASED) && shouldShardJoin(joinNode));
        }

        private boolean shouldShardJoin(JoinNode joinNode)
        {
            // TODO: implement based on HBO stats
            return false;
        }

        /**
         * Expands {@code source} so that every row appears once per shard.
         * <p>
         * A projection adds the array {@code sequence(0, numShards - 1)}, which is then unnested
         * into {@code shardVariable}.
         *
         * @param numShards number of shards; must be greater than one
         * @param source plan node whose rows are replicated
         * @param shardVariable variable receiving the unnested shard number
         * @return an unnest node producing the outputs of {@code source} plus {@code shardVariable}
         */
        private PlanNode shardInput(long numShards, PlanNode source, VariableReferenceExpression shardVariable)
        {
            checkState(numShards > 1, "numShards must be greater than 1, got %s", numShards);

            RowExpression sequenceExpression = call(
                    functionAndTypeManager,
                    "sequence",
                    new ArrayType(BIGINT),
                    constant((long) 0, BIGINT),
                    constant((long) numShards - 1, BIGINT));

            VariableReferenceExpression sequenceVariable = planVariableAllocator.newVariable(sequenceExpression);
            PlanNode projectSequence = PlannerUtils.addProjections(source, planNodeIdAllocator, planVariableAllocator, ImmutableList.of(sequenceExpression), ImmutableList.of(sequenceVariable));

            // Replicate only the original columns: the sequence array is consumed by the unnest
            // and is not needed by the join, so it should not be copied onto every exploded row.
            return new UnnestNode(source.getSourceLocation(),
                    planNodeIdAllocator.getNextId(),
                    projectSequence,
                    source.getOutputVariables(),
                    ImmutableMap.of(sequenceVariable, ImmutableList.of(shardVariable)),
                    Optional.empty());
        }

        private int getNumberOfShards()
        {
            // TODO: compute number of shards based on stats
            return getJoinShardCount(session);
        }
    }
}