JoinSwappingUtils.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;
import com.facebook.presto.Session;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.UnnestNode;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.GroupReference;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties;
import com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.airlift.units.DataSize;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;
import static com.facebook.presto.SystemSessionProperties.getJoinMaxBroadcastTableSize;
import static com.facebook.presto.SystemSessionProperties.getTaskConcurrency;
import static com.facebook.presto.SystemSessionProperties.isJoinSpillingEnabled;
import static com.facebook.presto.SystemSessionProperties.isSpillEnabled;
import static com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties.defaultParallelism;
import static com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties.exactlyPartitionedOn;
import static com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties.fixedParallelism;
import static com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties.singleStream;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.gatheringExchange;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.systemPartitionedExchange;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.lang.Double.NaN;
import static java.lang.Double.isNaN;
/**
 * Static helpers for rules that swap the probe (left) and build (right) sides of a {@link JoinNode}.
 * <p>
 * The class offers three kinds of functionality:
 * <ul>
 * <li>building a swapped join whose local exchanges are adjusted to the new probe/build roles
 * ({@link #createRuntimeSwappedJoinNode});</li>
 * <li>checking stream properties of a candidate probe side ({@link #checkProbeSidePropertySatisfied});</li>
 * <li>size heuristics used to decide whether a swap is worthwhile ({@link #isBelowBroadcastLimit},
 * {@link #isSmallerThanThreshold}, {@link #getFirstKnownOutputSizeInBytes(PlanNode, Lookup, StatsProvider)}).</li>
 * </ul>
 */
public class JoinSwappingUtils
{
    // Node types treated as "expanding": when such a node has no size estimate, the size search
    // in getFirstKnownOutputSizeInBytes gives up (returns NaN) instead of looking at its sources.
    static final List<Class<? extends PlanNode>> EXPANDING_NODE_CLASSES = ImmutableList.of(JoinNode.class, UnnestNode.class);

    // Factor used by isSmallerThanThreshold: A counts as smaller only if A * factor < B.
    private static final double SIZE_DIFFERENCE_THRESHOLD = 8;

    private JoinSwappingUtils() {}

    /**
     * Builds the flipped version of {@code joinNode}, removing a now-unnecessary local exchange on the new
     * probe side and adding a local exchange on the new build side when its stream properties require one.
     *
     * @param joinNode the join whose children should be swapped
     * @param metadata metadata used when deriving stream properties
     * @param lookup used to resolve group references to concrete plan nodes
     * @param session session supplying task concurrency and spill settings
     * @param idAllocator allocator for the id of a newly created exchange node
     * @param nativeExecution passed through to stream property derivation
     * @return the swapped join, or {@link Optional#empty()} when the swap has to be abandoned: either the join
     * output exposes the probe-side hash variable that would change, or that hash variable cannot be mapped
     * positionally onto the output of the exchange's source
     */
    public static Optional<JoinNode> createRuntimeSwappedJoinNode(
            JoinNode joinNode,
            Metadata metadata,
            Lookup lookup,
            Session session,
            PlanNodeIdAllocator idAllocator,
            boolean nativeExecution)
    {
        JoinNode swapped = joinNode.flipChildren();

        PlanNode newLeft = swapped.getLeft();
        Optional<VariableReferenceExpression> leftHashVariable = swapped.getLeftHashVariable();

        // Remove an unnecessary LocalExchange on the new probe side. If the immediate left child (new probe side) of the
        // join is an exchange, there are two cases: it was introduced for the previous build side, or it stands for a union.
        // An exchange with more than one source is the union case; a single-source exchange may be safe to remove.
        PlanNode resolvedSwappedLeft = lookup.resolve(newLeft);
        if (resolvedSwappedLeft instanceof ExchangeNode && resolvedSwappedLeft.getSources().size() == 1) {
            PlanNode exchangeSource = resolvedSwappedLeft.getSources().get(0);
            // Only skip the exchange if the node below it already satisfies the required probe side property.
            if (checkProbeSidePropertySatisfied(exchangeSource, metadata, lookup, session, nativeExecution)) {
                newLeft = exchangeSource;

                if (leftHashVariable.isPresent()) {
                    VariableReferenceExpression oldHashVariable = leftHashVariable.get();

                    // When the join output layout contains the new left side's hash variable (e.g. a nested join in a single
                    // stage whose output carries the hash variable of its new probe), removing the exchange would change the
                    // join's output variables, and that change would have to be propagated upwards by rewriting parent nodes.
                    // This is against typical iterative optimizer behavior and the case is rare, so abort the swap.
                    if (swapped.getOutputVariables().contains(oldHashVariable)) {
                        return Optional.empty();
                    }

                    // The HashGenerationOptimizer appends hash variables to node output layouts in the same order, so the
                    // position of the old hash variable in the exchange output is used to pick the hash variable of the
                    // exchange's source, which becomes the left hash variable of the swapped join.
                    int hashVariableIndex = resolvedSwappedLeft.getOutputVariables().indexOf(oldHashVariable);
                    List<VariableReferenceExpression> sourceOutputs = exchangeSource.getOutputVariables();
                    // indexOf yields -1 when the variable is missing, and the source layout may be shorter than the
                    // exchange layout; in both cases no positional mapping exists, so abort instead of throwing.
                    if (hashVariableIndex < 0 || hashVariableIndex >= sourceOutputs.size()) {
                        return Optional.empty();
                    }
                    leftHashVariable = Optional.of(sourceOutputs.get(hashVariableIndex));
                }
            }
        }

        // Add a local exchange if the new build side does not satisfy the partitioning conditions.
        List<VariableReferenceExpression> buildJoinVariables = swapped.getCriteria().stream()
                .map(EquiJoinClause::getRight)
                .collect(toImmutableList());
        PlanNode newRight = swapped.getRight();
        if (!checkBuildSidePropertySatisfied(swapped.getRight(), buildJoinVariables, metadata, lookup, session, nativeExecution)) {
            if (getTaskConcurrency(session) > 1) {
                // Several build drivers: repartition locally on the build join keys.
                newRight = systemPartitionedExchange(
                        idAllocator.getNextId(),
                        LOCAL,
                        swapped.getRight(),
                        buildJoinVariables,
                        swapped.getRightHashVariable());
            }
            else {
                // Task concurrency of one: gather into a single stream.
                newRight = gatheringExchange(idAllocator.getNextId(), LOCAL, swapped.getRight());
            }
        }

        JoinNode newJoinNode = new JoinNode(
                swapped.getSourceLocation(),
                swapped.getId(),
                swapped.getType(),
                newLeft,
                newRight,
                swapped.getCriteria(),
                swapped.getOutputVariables(),
                swapped.getFilter(),
                leftHashVariable,
                swapped.getRightHashVariable(),
                swapped.getDistributionType(),
                swapped.getDynamicFilters());
        return Optional.of(newJoinNode);
    }

    /**
     * Checks whether {@code node} can serve directly as the probe side of a join, i.e. whether the probe side
     * left after removing an unnecessary local exchange is still valid.
     * <p>
     * The required property is fixed parallelism when both spill and join spilling are enabled for the session,
     * and the session's default parallelism otherwise.
     *
     * @return true if the stream properties derived for {@code node} satisfy the required probe property
     */
    public static boolean checkProbeSidePropertySatisfied(PlanNode node, Metadata metadata, Lookup lookup, Session session, boolean nativeExecution)
    {
        boolean joinSpillActive = isSpillEnabled(session) && isJoinSpillingEnabled(session);
        StreamPreferredProperties requiredProbeProperty = joinSpillActive ? fixedParallelism() : defaultParallelism(session);

        StreamPropertyDerivations.StreamProperties nodeProperty = derivePropertiesRecursively(node, metadata, lookup, session, nativeExecution);
        return requiredProbeProperty.isSatisfiedBy(nodeProperty);
    }

    /**
     * Checks whether {@code node} can be fed directly as the build side of a join.
     * <p>
     * With task concurrency above one the node must be exactly partitioned on {@code partitioningColumns};
     * otherwise it must produce a single stream.
     *
     * @return true if the stream properties derived for {@code node} satisfy the required build property
     */
    private static boolean checkBuildSidePropertySatisfied(
            PlanNode node,
            List<VariableReferenceExpression> partitioningColumns,
            Metadata metadata,
            Lookup lookup,
            Session session,
            boolean nativeExecution)
    {
        StreamPreferredProperties requiredBuildProperty = getTaskConcurrency(session) > 1
                ? exactlyPartitionedOn(partitioningColumns)
                : singleStream();

        StreamPropertyDerivations.StreamProperties nodeProperty = derivePropertiesRecursively(node, metadata, lookup, session, nativeExecution);
        return requiredBuildProperty.isSatisfiedBy(nodeProperty);
    }

    /**
     * Derives the stream properties of {@code node} bottom-up: the node is resolved through {@code lookup},
     * properties of all of its sources are derived first, and the results are combined for the node itself.
     */
    private static StreamPropertyDerivations.StreamProperties derivePropertiesRecursively(
            PlanNode node,
            Metadata metadata,
            Lookup lookup,
            Session session,
            boolean nativeExecution)
    {
        PlanNode actual = lookup.resolve(node);
        List<StreamPropertyDerivations.StreamProperties> inputProperties = actual.getSources().stream()
                .map(source -> derivePropertiesRecursively(source, metadata, lookup, session, nativeExecution))
                .collect(toImmutableList());
        return StreamPropertyDerivations.deriveProperties(actual, inputProperties, metadata, session, nativeExecution);
    }

    /**
     * Returns true if the source tables size reported by {@link DetermineJoinDistributionType} for
     * {@code planNode} does not exceed the session's maximum broadcast table size. If that size is NaN
     * the comparison, and therefore the result, is false.
     */
    public static boolean isBelowBroadcastLimit(PlanNode planNode, Rule.Context context)
    {
        DataSize joinMaxBroadcastTableSize = getJoinMaxBroadcastTableSize(context.getSession());
        double sourceTablesSizeInBytes = DetermineJoinDistributionType.getSourceTablesSizeInBytes(planNode, context);
        return sourceTablesSizeInBytes <= joinMaxBroadcastTableSize.toBytes();
    }

    /**
     * Returns true if the first known output size of {@code planNodeA}, multiplied by the size difference
     * threshold, is still strictly smaller than that of {@code planNodeB}. When either size is unknown (NaN)
     * the comparison is false.
     */
    public static boolean isSmallerThanThreshold(PlanNode planNodeA, PlanNode planNodeB, Rule.Context context)
    {
        double aOutputSize = getFirstKnownOutputSizeInBytes(planNodeA, context);
        double bOutputSize = getFirstKnownOutputSizeInBytes(planNodeB, context);
        return aOutputSize * SIZE_DIFFERENCE_THRESHOLD < bOutputSize;
    }

    // Convenience overload that takes the lookup and stats provider from the rule context.
    private static double getFirstKnownOutputSizeInBytes(PlanNode node, Rule.Context context)
    {
        return getFirstKnownOutputSizeInBytes(node, context.getLookup(), context.getStatsProvider());
    }

    /**
     * Recursively looks for the first source node with a known estimate and uses that to return an approximate output size.
     * Returns NaN if an un-estimated expanding node (Join or Unnest) is encountered, or if an un-estimated leaf is reached.
     * The amount of reduction in size from un-estimated non-expanding nodes (e.g. an un-estimated filter or aggregation)
     * is not accounted here. We make use of the first available estimate and make decision about flipping join sides only if
     * we find a large difference in output size of both sides.
     *
     * @param node the node to estimate; a {@link GroupReference} is resolved through {@code lookup} first
     * @param lookup used to resolve group references
     * @param statsProvider source of per-node statistics
     * @return the summed size over all nodes {@code node} resolves to, or NaN when no usable estimate exists
     */
    @VisibleForTesting
    public static double getFirstKnownOutputSizeInBytes(PlanNode node, Lookup lookup, StatsProvider statsProvider)
    {
        Stream<PlanNode> resolvedNodes = node instanceof GroupReference
                ? lookup.resolveGroup(node)
                : Stream.of(node);
        return resolvedNodes
                .mapToDouble(resolvedNode -> firstKnownOutputSizeOfResolvedNode(resolvedNode, lookup, statsProvider))
                .sum();
    }

    // Size estimate for an already resolved node: its own estimate if known, otherwise the sum of the
    // first known estimates of its sources; NaN as soon as any part of that sum is unknown.
    private static double firstKnownOutputSizeOfResolvedNode(PlanNode resolvedNode, Lookup lookup, StatsProvider statsProvider)
    {
        double outputSizeInBytes = statsProvider.getStats(resolvedNode).getOutputSizeInBytes(resolvedNode);
        if (!isNaN(outputSizeInBytes)) {
            return outputSizeInBytes;
        }

        // An expanding node may produce more than its inputs, so source estimates say nothing about it.
        if (EXPANDING_NODE_CLASSES.stream().anyMatch(clazz -> clazz.isInstance(resolvedNode))) {
            return NaN;
        }

        List<PlanNode> sourceNodes = resolvedNode.getSources();
        if (sourceNodes.isEmpty()) {
            return NaN;
        }

        double sourcesOutputSizeInBytes = 0;
        for (PlanNode sourceNode : sourceNodes) {
            double sourceSizeInBytes = getFirstKnownOutputSizeInBytes(sourceNode, lookup, statsProvider);
            if (isNaN(sourceSizeInBytes)) {
                return NaN;
            }
            sourcesOutputSizeInBytes += sourceSizeInBytes;
        }
        return sourcesOutputSizeInBytes;
    }
}