DetermineJoinDistributionType.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.cost.CostComparator;
import com.facebook.presto.cost.LocalCostEstimate;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.cost.TaskCountEstimator;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.plan.CteConsumerNode;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.plan.ValuesNode;
import com.facebook.presto.spi.statistics.HistoryBasedSourceInfo;
import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher;
import com.facebook.presto.sql.planner.plan.RemoteSourceNode;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Ordering;
import io.airlift.units.DataSize;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

import static com.facebook.presto.SystemSessionProperties.confidenceBasedBroadcastEnabled;
import static com.facebook.presto.SystemSessionProperties.getJoinDistributionType;
import static com.facebook.presto.SystemSessionProperties.getJoinMaxBroadcastTableSize;
import static com.facebook.presto.SystemSessionProperties.isSizeBasedJoinDistributionTypeEnabled;
import static com.facebook.presto.SystemSessionProperties.isUseBroadcastJoinWhenBuildSizeSmallProbeSizeUnknownEnabled;
import static com.facebook.presto.SystemSessionProperties.treatLowConfidenceZeroEstimationAsUnknownEnabled;
import static com.facebook.presto.cost.CostCalculatorWithEstimatedExchanges.calculateJoinCostWithoutOutput;
import static com.facebook.presto.spi.plan.JoinDistributionType.PARTITIONED;
import static com.facebook.presto.spi.plan.JoinDistributionType.REPLICATED;
import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.LOW;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.AUTOMATIC;
import static com.facebook.presto.sql.planner.iterative.ConfidenceBasedBroadcastUtil.confidenceBasedBroadcast;
import static com.facebook.presto.sql.planner.iterative.ConfidenceBasedBroadcastUtil.treatLowConfidenceZeroEstimationsAsUnknown;
import static com.facebook.presto.sql.planner.iterative.rule.JoinSwappingUtils.isBelowBroadcastLimit;
import static com.facebook.presto.sql.planner.iterative.rule.JoinSwappingUtils.isSmallerThanThreshold;
import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.isAtMostScalar;
import static com.facebook.presto.sql.planner.plan.Patterns.join;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static java.lang.Double.NaN;
import static java.util.Objects.requireNonNull;

/**
 * Chooses a distribution type ({@code PARTITIONED} or {@code REPLICATED}) for every {@link JoinNode} that does not
 * have one yet, possibly flipping the join's children so that the other input becomes the build side.
 *
 * <p>With the {@code AUTOMATIC} join distribution type the decision is cost based, falling back to
 * confidence-based, history-based, size-based or syntactic-order heuristics when costs cannot be fully estimated.
 * Any other configured type keeps the syntactic join order.
 */
public class DetermineJoinDistributionType
        implements Rule<JoinNode>
{
    private static final Pattern<JoinNode> PATTERN = join().matching(joinNode -> !joinNode.getDistributionType().isPresent());

    private final CostComparator costComparator;
    private final TaskCountEstimator taskCountEstimator;

    // Source-info name of the stats of the join for which the most recent cost-based decision was made;
    // null until this rule has made a cost-based decision.
    private String statsSource;

    public DetermineJoinDistributionType(CostComparator costComparator, TaskCountEstimator taskCountEstimator)
    {
        this.costComparator = requireNonNull(costComparator, "costComparator is null");
        this.taskCountEstimator = requireNonNull(taskCountEstimator, "taskCountEstimator is null");
    }

    @Override
    public boolean isCostBased(Session session)
    {
        return getJoinDistributionType(session) == AUTOMATIC;
    }

    @Override
    public String getStatsSource()
    {
        return statsSource;
    }

    @Override
    public Pattern<JoinNode> getPattern()
    {
        return PATTERN;
    }

    @Override
    public Result apply(JoinNode joinNode, Captures captures, Context context)
    {
        JoinDistributionType joinDistributionType = getJoinDistributionType(context.getSession());
        if (joinDistributionType == AUTOMATIC) {
            PlanNode resultNode = getCostBasedJoin(joinNode, context);
            statsSource = context.getStatsProvider().getStats(joinNode).getSourceInfo().getSourceInfoName();
            return Result.ofPlanNode(resultNode);
        }
        return Result.ofPlanNode(getSyntacticOrderJoin(joinNode, context, joinDistributionType));
    }

    /**
     * Returns whether the build (right) side of {@code joinNode} is small enough to be replicated.
     *
     * <p>The build side qualifies when its estimated output size is at most the session's maximum broadcast table
     * size, or, with size-based join distribution enabled, when the summed size of its source tables is. Unknown
     * (NaN) sizes never qualify because every comparison with NaN is false. When low-confidence zero estimates are
     * treated as unknown, a build side estimated at zero rows with LOW confidence does not qualify either.
     */
    public static boolean isBelowMaxBroadcastSize(JoinNode joinNode, Context context)
    {
        PlanNode buildSide = joinNode.getRight();
        if (treatLowConfidenceZeroEstimationAsUnknownEnabled(context.getSession()) && isLowConfidenceZero(buildSide, context)) {
            return false;
        }

        DataSize joinMaxBroadcastTableSize = getJoinMaxBroadcastTableSize(context.getSession());
        PlanNodeStatsEstimate buildSideStatsEstimate = context.getStatsProvider().getStats(buildSide);
        double buildSideSizeInBytes = buildSideStatsEstimate.getOutputSizeInBytes(buildSide);
        return buildSideSizeInBytes <= joinMaxBroadcastTableSize.toBytes()
                || (isSizeBasedJoinDistributionTypeEnabled(context.getSession())
                && getSourceTablesSizeInBytes(buildSide, context) <= joinMaxBroadcastTableSize.toBytes());
    }

    /**
     * Picks the distribution and join order for {@code joinNode} when the distribution type is {@code AUTOMATIC}.
     *
     * <p>Candidates (replicated/partitioned, original/flipped) are costed. Before comparing costs, the
     * confidence-based broadcast and low-confidence-zero heuristics may decide. If any candidate cost has unknown
     * components, or there are no candidates, a history-based broadcast, size-based or syntactic-order fallback is
     * used. Otherwise the cheapest candidate according to {@link CostComparator} wins.
     */
    private PlanNode getCostBasedJoin(JoinNode joinNode, Context context)
    {
        JoinNode flippedJoin = joinNode.flipChildren();

        List<PlanNodeWithCost> possibleJoinNodes = new ArrayList<>();
        addJoinsWithDifferentDistributions(joinNode, possibleJoinNodes, context);
        addJoinsWithDifferentDistributions(flippedJoin, possibleJoinNodes, context);

        // Both sides fit under the broadcast limit: let the confidence-based broadcast heuristic choose, if enabled.
        if (isBelowMaxBroadcastSize(joinNode, context) && isBelowMaxBroadcastSize(flippedJoin, context) && !mustPartition(joinNode) && confidenceBasedBroadcastEnabled(context.getSession())) {
            Optional<JoinNode> result = confidenceBasedBroadcast(joinNode, context);
            if (result.isPresent()) {
                return result.get();
            }
        }

        boolean buildSideLowConfidenceZero = isLowConfidenceZero(joinNode.getRight(), context);
        boolean probeSideLowConfidenceZero = isLowConfidenceZero(joinNode.getLeft(), context);
        if ((buildSideLowConfidenceZero || probeSideLowConfidenceZero) && treatLowConfidenceZeroEstimationAsUnknownEnabled(context.getSession())) {
            Optional<JoinNode> result = treatLowConfidenceZeroEstimationsAsUnknown(probeSideLowConfidenceZero, buildSideLowConfidenceZero, joinNode, context);
            if (result.isPresent()) {
                return result.get();
            }
        }

        if (possibleJoinNodes.isEmpty() || possibleJoinNodes.stream().anyMatch(result -> result.getCost().hasUnknownComponents())) {
            // TODO: currently this session parameter is added so as to roll out the plan change gradually, after proved to be a better choice, make it default and get rid of the session parameter here.
            if (isUseBroadcastJoinWhenBuildSizeSmallProbeSizeUnknownEnabled(context.getSession())) {
                Optional<JoinNode> broadcastJoin = getHistoryBasedBroadcastJoin(possibleJoinNodes, context);
                if (broadcastJoin.isPresent()) {
                    return broadcastJoin.get();
                }
            }
            if (isSizeBasedJoinDistributionTypeEnabled(context.getSession())) {
                return getSizeBasedJoin(joinNode, context);
            }
            return getSyntacticOrderJoin(joinNode, context, AUTOMATIC);
        }

        // Using Ordering to facilitate rule determinism
        Ordering<PlanNodeWithCost> planNodeOrderings = costComparator.forSession(context.getSession()).onResultOf(PlanNodeWithCost::getCost);
        return planNodeOrderings.min(possibleJoinNodes).getPlanNode();
    }

    /**
     * Returns the REPLICATED candidate when it is the only one in {@code possibleJoinNodes} and the stats of its
     * build side come from history-based statistics; empty otherwise.
     *
     * <p>This covers the "build side small, probe side unknown" case. When both join orders are broadcastable there
     * are two REPLICATED candidates, so this is not that case. Returning empty lets the caller fall back to the
     * size-based or syntactic choice instead of failing in {@code getOnlyElement}.
     */
    private static Optional<JoinNode> getHistoryBasedBroadcastJoin(List<PlanNodeWithCost> possibleJoinNodes, Context context)
    {
        List<JoinNode> broadcastJoins = possibleJoinNodes.stream()
                .map(candidate -> (JoinNode) candidate.getPlanNode())
                .filter(candidate -> candidate.getDistributionType().get().equals(REPLICATED))
                .collect(toImmutableList());
        if (broadcastJoins.size() != 1) {
            return Optional.empty();
        }
        JoinNode broadcastJoin = getOnlyElement(broadcastJoins);
        if (context.getStatsProvider().getStats(broadcastJoin.getBuild()).getSourceInfo() instanceof HistoryBasedSourceInfo) {
            return Optional.of(broadcastJoin);
        }
        return Optional.empty();
    }

    /**
     * Chooses the join based on input sizes when costs are not fully known. A side under the broadcast limit is
     * preferred as the replicated build side (right side first). Otherwise a side that is sufficiently smaller than
     * the other becomes the partitioned build side, as decided by {@code isSmallerThanThreshold}. If neither applies,
     * the syntactic order is used.
     */
    private JoinNode getSizeBasedJoin(JoinNode joinNode, Context context)
    {
        boolean isRightSideSmall = isBelowBroadcastLimit(joinNode.getRight(), context);
        if (isRightSideSmall && !mustPartition(joinNode)) {
            // choose right join side with small source tables as replicated build side
            return joinNode.withDistributionType(REPLICATED);
        }

        boolean isLeftSideSmall = isBelowBroadcastLimit(joinNode.getLeft(), context);
        JoinNode flippedJoin = joinNode.flipChildren();
        if (isLeftSideSmall && !mustPartition(flippedJoin)) {
            // choose join left side with small source tables as replicated build side
            return flippedJoin.withDistributionType(REPLICATED);
        }

        if (isRightSideSmall) {
            // right side is small enough, but must be partitioned
            return joinNode.withDistributionType(PARTITIONED);
        }

        if (isLeftSideSmall) {
            // left side is small enough, but must be partitioned
            return flippedJoin.withDistributionType(PARTITIONED);
        }

        // Flip join sides if one side is smaller than the other by more than SIZE_DIFFERENCE_THRESHOLD times.
        // We use 8x factor because getFirstKnownOutputSizeInBytes may not have accounted for the reduction in the size of
        // the output from a filter or aggregation due to lack of estimates.
        // We use getFirstKnownOutputSizeInBytes instead of getSourceTablesSizeInBytes to account for the reduction in
        // output size from the operators between the join and the table scan as much as possible when comparing the sizes of the join sides.

        // All the REPLICATED cases were handled in the code above, so now we only consider PARTITIONED cases here
        if (isSmallerThanThreshold(joinNode.getRight(), joinNode.getLeft(), context) && !mustReplicate(joinNode, context)) {
            return joinNode.withDistributionType(PARTITIONED);
        }

        if (isSmallerThanThreshold(joinNode.getLeft(), joinNode.getRight(), context) && !mustReplicate(flippedJoin, context)) {
            return flippedJoin.withDistributionType(PARTITIONED);
        }

        // neither side is small enough, choose syntactic join order
        return getSyntacticOrderJoin(joinNode, context, AUTOMATIC);
    }

    public static double getSourceTablesSizeInBytes(PlanNode node, Context context)
    {
        return getSourceTablesSizeInBytes(node, context.getLookup(), context.getStatsProvider());
    }

    /**
     * Sums the estimated output sizes of the leaf sources (table scans, values, remote sources and CTE consumers)
     * under {@code node}. Returns NaN when the subtree contains any of {@code JoinSwappingUtils.EXPANDING_NODE_CLASSES},
     * and the sum is NaN if any leaf estimate is unknown.
     */
    @VisibleForTesting
    static double getSourceTablesSizeInBytes(PlanNode node, Lookup lookup, StatsProvider statsProvider)
    {
        boolean hasExpandingNodes = PlanNodeSearcher.searchFrom(node, lookup)
                .whereIsInstanceOfAny(JoinSwappingUtils.EXPANDING_NODE_CLASSES)
                .matches();
        if (hasExpandingNodes) {
            return NaN;
        }

        List<PlanNode> sourceNodes = PlanNodeSearcher.searchFrom(node, lookup)
                .whereIsInstanceOfAny(ImmutableList.of(TableScanNode.class, ValuesNode.class, RemoteSourceNode.class, CteConsumerNode.class))
                .findAll();

        return sourceNodes.stream()
                .mapToDouble(sourceNode -> statsProvider.getStats(sourceNode).getOutputSizeInBytes(sourceNode))
                .sum();
    }

    /**
     * Adds the costed REPLICATED and/or PARTITIONED variants of {@code joinNode} (in its current child order) that
     * are allowed for it to {@code possibleJoinNodes}.
     */
    private void addJoinsWithDifferentDistributions(JoinNode joinNode, List<PlanNodeWithCost> possibleJoinNodes, Context context)
    {
        if (!mustPartition(joinNode) && isBelowMaxBroadcastSize(joinNode, context)) {
            possibleJoinNodes.add(getJoinNodeWithCost(context, joinNode.withDistributionType(REPLICATED)));
        }
        // don't consider partitioned inequality joins because they execute on a single node.
        if (!mustReplicate(joinNode, context) && !joinNode.getCriteria().isEmpty()) {
            possibleJoinNodes.add(getJoinNodeWithCost(context, joinNode.withDistributionType(PARTITIONED)));
        }
    }

    /**
     * Keeps the syntactic child order and picks the distribution forced by the join type, otherwise PARTITIONED
     * when {@code joinDistributionType} allows partitioning and REPLICATED if it does not.
     */
    private JoinNode getSyntacticOrderJoin(JoinNode joinNode, Context context, JoinDistributionType joinDistributionType)
    {
        if (mustPartition(joinNode)) {
            return joinNode.withDistributionType(PARTITIONED);
        }
        if (mustReplicate(joinNode, context)) {
            return joinNode.withDistributionType(REPLICATED);
        }
        if (joinDistributionType.canPartition()) {
            return joinNode.withDistributionType(PARTITIONED);
        }
        return joinNode.withDistributionType(REPLICATED);
    }

    public static boolean mustPartition(JoinNode joinNode)
    {
        return joinNode.getType().mustPartition();
    }

    /**
     * Returns true when the join type requires replication for the join's criteria, or when the build (right)
     * side is at most scalar according to {@code QueryCardinalityUtil.isAtMostScalar}.
     */
    private static boolean mustReplicate(JoinNode joinNode, Context context)
    {
        if (joinNode.getType().mustReplicate(joinNode.getCriteria())) {
            return true;
        }
        return isAtMostScalar(joinNode.getRight(), context.getLookup());
    }

    /**
     * Wraps {@code possibleJoinNode} with a cost that excludes the join output (see the note below).
     */
    private PlanNodeWithCost getJoinNodeWithCost(Context context, JoinNode possibleJoinNode)
    {
        StatsProvider stats = context.getStatsProvider();
        boolean replicated = possibleJoinNode.getDistributionType().get().equals(REPLICATED);
        /*
         *   HACK!
         *
         *   Currently cost model always has to compute the total cost of an operation.
         *   For JOIN the total cost consists of 4 parts:
         *     - Cost of exchanges that have to be introduced to execute a JOIN
         *     - Cost of building a hash table
         *     - Cost of probing a hash table
         *     - Cost of building an output for matched rows
         *
         *   When output size for a JOIN cannot be estimated the cost model returns
         *   UNKNOWN cost for the join.
         *
         *   However assuming the cost of JOIN output is always the same, we can still make
         *   cost based decisions based on the input cost for different types of JOINs.
         *
         *   Although the side flipping can be made purely based on stats (smaller side
         *   always goes to the right), determining JOIN type is not that simple. As when
         *   choosing REPLICATED over REPARTITIONED join the cost of exchanging and building
         *   the hash table scales with the number of nodes where the build side is replicated.
         *
         *   TODO Decision about the distribution should be based on LocalCostEstimate only when PlanCostEstimate cannot be calculated. Otherwise cost comparator cannot take query.max-memory into account.
         */
        int estimatedSourceDistributedTaskCount = taskCountEstimator.estimateSourceDistributedTaskCount();
        LocalCostEstimate cost = calculateJoinCostWithoutOutput(
                possibleJoinNode.getLeft(),
                possibleJoinNode.getRight(),
                stats,
                replicated,
                estimatedSourceDistributedTaskCount);
        return new PlanNodeWithCost(cost.toPlanCost(), possibleJoinNode);
    }

    /**
     * Returns true when {@code planNode}'s stats have LOW confidence and an estimated output row count of exactly 0.
     */
    private static boolean isLowConfidenceZero(PlanNode planNode, Context context)
    {
        PlanNodeStatsEstimate statsEstimate = context.getStatsProvider().getStats(planNode);
        return statsEstimate.confidenceLevel() == LOW && statsEstimate.getOutputRowCount() == 0;
    }
}