ReorderJoins.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;
import com.facebook.airlift.log.Logger;
import com.facebook.presto.Session;
import com.facebook.presto.cost.CostComparator;
import com.facebook.presto.cost.CostProvider;
import com.facebook.presto.cost.PlanCostEstimate;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.DeterminismEvaluator;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType;
import com.facebook.presto.sql.planner.EqualityInference;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.MultiJoinNode;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Predicate;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Ordering;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import static com.facebook.presto.SystemSessionProperties.confidenceBasedBroadcastEnabled;
import static com.facebook.presto.SystemSessionProperties.getJoinDistributionType;
import static com.facebook.presto.SystemSessionProperties.getJoinReorderingStrategy;
import static com.facebook.presto.SystemSessionProperties.getMaxReorderedJoins;
import static com.facebook.presto.SystemSessionProperties.shouldHandleComplexEquiJoins;
import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT;
import static com.facebook.presto.expressions.LogicalRowExpressions.and;
import static com.facebook.presto.expressions.RowExpressionNodeInliner.replaceExpression;
import static com.facebook.presto.spi.plan.JoinDistributionType.PARTITIONED;
import static com.facebook.presto.spi.plan.JoinDistributionType.REPLICATED;
import static com.facebook.presto.spi.plan.JoinType.INNER;
import static com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.AUTOMATIC;
import static com.facebook.presto.sql.planner.EqualityInference.createEqualityInference;
import static com.facebook.presto.sql.planner.PlannerUtils.addProjections;
import static com.facebook.presto.sql.planner.VariablesExtractor.extractUnique;
import static com.facebook.presto.sql.planner.iterative.ConfidenceBasedBroadcastUtil.confidenceBasedBroadcast;
import static com.facebook.presto.sql.planner.iterative.rule.DetermineJoinDistributionType.isBelowMaxBroadcastSize;
import static com.facebook.presto.sql.planner.iterative.rule.DetermineJoinDistributionType.mustPartition;
import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult.INFINITE_COST_RESULT;
import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult.UNKNOWN_COST_RESULT;
import static com.facebook.presto.sql.planner.optimizations.JoinNodeUtils.toRowExpression;
import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.isAtMostScalar;
import static com.facebook.presto.sql.planner.plan.AssignmentUtils.getNonIdentityAssignments;
import static com.facebook.presto.sql.planner.plan.Patterns.join;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Predicates.in;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.common.collect.Sets.powerSet;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toCollection;
public class ReorderJoins
implements Rule<JoinNode>
{
private static final Logger log = Logger.get(ReorderJoins.class);
// We check that join distribution type is absent because we only want
// to do this transformation once (reordered joins will have distribution type already set).
private final Pattern<JoinNode> joinNodePattern;
private final CostComparator costComparator;
private final Metadata metadata;
private final FunctionResolution functionResolution;
private final DeterminismEvaluator determinismEvaluator;
private String statsSource;
/**
 * Creates the rule and builds the pattern selecting the joins it is applied to.
 *
 * @param costComparator comparator later handed to {@link JoinEnumerator} to rank candidate join orders
 * @param metadata provides the function and type manager used for function resolution and determinism checks
 * @throws NullPointerException if {@code costComparator} or {@code metadata} is null
 */
public ReorderJoins(CostComparator costComparator, Metadata metadata)
{
    this.costComparator = requireNonNull(costComparator, "costComparator is null");
    this.metadata = requireNonNull(metadata, "metadata is null");
    this.functionResolution = new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver());
    this.determinismEvaluator = new RowExpressionDeterminismEvaluator(metadata.getFunctionAndTypeManager());
    this.joinNodePattern = join().matching(candidate -> isReorderCandidate(candidate));
}

/**
 * Tells whether a join is matched by this rule: it has no distribution type yet,
 * it is an INNER join, and its filter (TRUE when absent) is deterministic.
 */
private boolean isReorderCandidate(JoinNode candidate)
{
    if (candidate.getDistributionType().isPresent()) {
        return false;
    }
    if (candidate.getType() != INNER) {
        return false;
    }
    return determinismEvaluator.isDeterministic(candidate.getFilter().orElse(TRUE_CONSTANT));
}
/**
 * Returns the join pattern that was built in the constructor.
 */
@Override
public Pattern<JoinNode> getPattern()
{
    return this.joinNodePattern;
}
/**
 * The rule is enabled only when the session's join reordering strategy is AUTOMATIC.
 */
@Override
public boolean isEnabled(Session session)
{
    return AUTOMATIC == getJoinReorderingStrategy(session);
}
/**
 * Reports the rule as cost-based exactly when it is enabled for the session:
 * whenever reordering runs, the join order is chosen by cost.
 */
@Override
public boolean isCostBased(Session session)
{
    boolean enabled = isEnabled(session);
    return enabled;
}
/**
 * Returns the stats source name recorded by the most recent {@code apply} call that
 * produced a reordered plan, or {@code null} if no such call has happened yet.
 * <p>
 * NOTE(review): the backing field is a plain mutable field overwritten by every successful
 * {@code apply}; confirm that a rule instance is not shared by concurrent optimizations.
 */
public String getStatsSource()
{
    return this.statsSource;
}
/**
 * Flattens the matched join into a multi-join, asks a {@link JoinEnumerator} for the best
 * join order and returns the reordered plan.
 * <p>
 * Returns an empty result when the enumerator produced no plan node. When the flattened
 * node carries assignments, the reordered plan is wrapped in a LOCAL {@link ProjectNode}
 * with those assignments.
 */
@Override
public Result apply(JoinNode joinNode, Captures captures, Context context)
{
    MultiJoinNode flattened = toMultiJoinNode(
            joinNode,
            context.getLookup(),
            getMaxReorderedJoins(context.getSession()),
            shouldHandleComplexEquiJoins(context.getSession()),
            functionResolution,
            determinismEvaluator);

    JoinEnumerator enumerator = new JoinEnumerator(costComparator, flattened.getFilter(), context, determinismEvaluator, functionResolution, metadata);
    JoinEnumerationResult best = enumerator.chooseJoinOrder(flattened.getSources(), flattened.getOutputVariables());

    Optional<PlanNode> reordered = best.getPlanNode();
    if (!reordered.isPresent()) {
        return Result.empty();
    }

    // Only recorded once a reordered plan is actually going to be returned.
    statsSource = context.getStatsProvider().getStats(joinNode).getSourceInfo().getSourceInfoName();

    PlanNode plan = reordered.get();
    if (flattened.getAssignments().isEmpty()) {
        return Result.ofPlanNode(plan);
    }

    PlanNode projected = new ProjectNode(
            plan.getSourceLocation(),
            context.getIdAllocator().getNextId(),
            plan,
            flattened.getAssignments(),
            LOCAL);
    return Result.ofPlanNode(projected);
}
@VisibleForTesting
static class JoinEnumerator
{
private final Session session;
private final CostProvider costProvider;
// Using Ordering to facilitate rule determinism
private final Ordering<JoinEnumerationResult> resultComparator;
private final PlanNodeIdAllocator idAllocator;
private final Metadata metadata;
private final RowExpression allFilter;
private final EqualityInference allFilterInference;
private final LogicalRowExpressions logicalRowExpressions;
private final Lookup lookup;
private final Context context;
private final Map<Set<PlanNode>, JoinEnumerationResult> memo = new HashMap<>();
private final FunctionResolution functionResolution;
/**
 * Creates an enumerator bound to one rule invocation.
 * <p>
 * All arguments are validated eagerly so that a missing collaborator fails here with a
 * descriptive message rather than later during enumeration.
 *
 * @param costComparator comparator used (for the context's session) to rank results by cost
 * @param filter conjunction of all predicates of the flattened join
 * @param context rule context supplying session, cost provider, id allocator and lookup
 * @param determinismEvaluator passed to {@link LogicalRowExpressions}
 * @param functionResolution used to recognise equality calls and passed to {@link LogicalRowExpressions}
 * @param metadata used to build equality inferences
 * @throws NullPointerException if any argument (or a required context component) is null
 */
@VisibleForTesting
JoinEnumerator(CostComparator costComparator, RowExpression filter, Context context, DeterminismEvaluator determinismEvaluator, FunctionResolution functionResolution, Metadata metadata)
{
    this.context = requireNonNull(context, "context is null");
    this.session = requireNonNull(context.getSession(), "session is null");
    this.costProvider = requireNonNull(context.getCostProvider(), "costProvider is null");
    this.resultComparator = requireNonNull(costComparator, "costComparator is null").forSession(session).onResultOf(result -> result.cost);
    this.idAllocator = requireNonNull(context.getIdAllocator(), "idAllocator is null");
    this.allFilter = requireNonNull(filter, "filter is null");
    this.lookup = requireNonNull(context.getLookup(), "lookup is null");
    this.metadata = requireNonNull(metadata, "metadata is null");
    this.allFilterInference = createEqualityInference(metadata, filter);
    requireNonNull(determinismEvaluator, "determinismEvaluator is null");
    this.functionResolution = requireNonNull(functionResolution, "functionResolution is null");
    this.logicalRowExpressions = new LogicalRowExpressions(determinismEvaluator, functionResolution, metadata.getFunctionAndTypeManager());
}
/**
 * Returns the cheapest join order for the given sources, memoized by the set of sources.
 * <p>
 * The context timeout is checked on every call, and the chosen plan (if any) is logged
 * at debug level on every call, including memo hits.
 */
private JoinEnumerationResult chooseJoinOrder(LinkedHashSet<PlanNode> sources, List<VariableReferenceExpression> outputVariables)
{
    context.checkTimeoutNotExhausted();

    Set<PlanNode> memoKey = ImmutableSet.copyOf(sources);
    JoinEnumerationResult chosen = memo.get(memoKey);
    if (chosen == null) {
        chosen = enumerateJoinOrder(memoKey, sources, outputVariables);
    }
    chosen.planNode.ifPresent((planNode) -> log.debug("Least cost join was: %s", planNode));
    return chosen;
}

/**
 * Tries every two-way partitioning of the sources, stores the outcome in the memo and returns it.
 * An unknown-cost candidate aborts the enumeration immediately; infinite-cost candidates are skipped.
 */
private JoinEnumerationResult enumerateJoinOrder(Set<PlanNode> memoKey, LinkedHashSet<PlanNode> sources, List<VariableReferenceExpression> outputVariables)
{
    checkState(sources.size() > 1, "sources size is less than or equal to one");

    List<JoinEnumerationResult> finiteCandidates = new ArrayList<>();
    for (Set<Integer> partition : generatePartitions(sources.size())) {
        JoinEnumerationResult candidate = createJoinAccordingToPartitioning(sources, outputVariables, partition);
        if (candidate.equals(UNKNOWN_COST_RESULT)) {
            memo.put(memoKey, candidate);
            return candidate;
        }
        if (!candidate.equals(INFINITE_COST_RESULT)) {
            finiteCandidates.add(candidate);
        }
    }

    JoinEnumerationResult chosen;
    if (finiteCandidates.isEmpty()) {
        chosen = INFINITE_COST_RESULT;
    }
    else {
        chosen = resultComparator.min(finiteCandidates);
    }
    memo.put(memoKey, chosen);
    return chosen;
}
/**
 * Enumerates every way of splitting {@code totalNodes} node indexes (0 .. totalNodes - 1)
 * into two non-empty groups. Each returned set is one group; the other group is implied
 * by the indexes that are missing from it. To avoid producing a split and its mirror
 * image, only sets containing index 0 are returned.
 *
 * @param totalNodes number of nodes to split, must be greater than 1
 * @return a set of index sets, each describing one partitioning of the nodes
 * @throws IllegalArgumentException if {@code totalNodes} is not greater than 1
 */
@VisibleForTesting
static Set<Set<Integer>> generatePartitions(int totalNodes)
{
    checkArgument(totalNodes > 1, "totalNodes must be greater than 1");

    Set<Integer> allIndexes = IntStream.range(0, totalNodes)
            .boxed()
            .collect(toImmutableSet());

    ImmutableSet.Builder<Set<Integer>> partitions = ImmutableSet.builder();
    for (Set<Integer> candidate : powerSet(allIndexes)) {
        // keep proper subsets that contain the 0th node
        if (candidate.contains(0) && candidate.size() < allIndexes.size()) {
            partitions.add(candidate);
        }
    }
    return partitions.build();
}
/**
 * Splits the sources into a left group (the sources at the indexes listed in
 * {@code partitioning}) and a right group (all remaining sources, in their original
 * order) and builds the join of the two groups.
 */
@VisibleForTesting
JoinEnumerationResult createJoinAccordingToPartitioning(LinkedHashSet<PlanNode> sources, List<VariableReferenceExpression> outputVariables, Set<Integer> partitioning)
{
    List<PlanNode> indexedSources = ImmutableList.copyOf(sources);

    LinkedHashSet<PlanNode> leftSources = new LinkedHashSet<>();
    for (Integer index : partitioning) {
        leftSources.add(indexedSources.get(index));
    }

    LinkedHashSet<PlanNode> rightSources = new LinkedHashSet<>();
    for (PlanNode source : sources) {
        if (!leftSources.contains(source)) {
            rightSources.add(source);
        }
    }

    return createJoin(leftSources, rightSources, outputVariables);
}
/**
 * Builds the join between the best plan for {@code leftSources} and the best plan for
 * {@code rightSources}.
 * <p>
 * Returns {@code INFINITE_COST_RESULT} when no equi-join clause connects the two sides,
 * and propagates an unknown or infinite result of either side unchanged. Complex
 * equi-join arguments are materialised by projections added on top of the respective side.
 */
private JoinEnumerationResult createJoin(LinkedHashSet<PlanNode> leftSources, LinkedHashSet<PlanNode> rightSources, List<VariableReferenceExpression> outputVariables)
{
    HashSet<VariableReferenceExpression> leftVariables = collectOutputVariables(leftSources);
    HashSet<VariableReferenceExpression> rightVariables = collectOutputVariables(rightSources);

    List<RowExpression> joinPredicates = getJoinPredicates(leftVariables, rightVariables);
    JoinCondition joinConditions = extractJoinConditions(joinPredicates, leftVariables, rightVariables, context.getVariableAllocator());
    Map<VariableReferenceExpression, RowExpression> newLeftAssignments = joinConditions.getNewLeftAssignments();
    Map<VariableReferenceExpression, RowExpression> newRightAssignments = joinConditions.getNewRightAssignments();

    // Variables generated for complex equi-join arguments belong to the side that computes them
    leftVariables.addAll(newLeftAssignments.keySet());
    rightVariables.addAll(newRightAssignments.keySet());

    List<EquiJoinClause> joinClauses = joinConditions.getJoinClauses();
    if (joinClauses.isEmpty()) {
        return INFINITE_COST_RESULT;
    }
    List<RowExpression> joinFilters = joinConditions.getJoinFilters();

    Set<VariableReferenceExpression> requiredJoinVariables = ImmutableSet.<VariableReferenceExpression>builder()
            .addAll(outputVariables)
            .addAll(extractUnique(joinPredicates))
            .build();

    JoinEnumerationResult leftResult = getJoinSource(
            leftSources,
            requiredJoinVariables.stream()
                    .filter(leftVariables::contains)
                    .collect(toImmutableList()));
    if (leftResult.equals(UNKNOWN_COST_RESULT) || leftResult.equals(INFINITE_COST_RESULT)) {
        // identity comparison: leftResult is the very constant being propagated
        return leftResult;
    }
    PlanNode left = appendAssignments(
            leftResult.planNode.orElseThrow(() -> new VerifyException("Plan node is not present")),
            newLeftAssignments);

    JoinEnumerationResult rightResult = getJoinSource(
            rightSources,
            requiredJoinVariables.stream()
                    .filter(rightVariables::contains)
                    .collect(toImmutableList()));
    if (rightResult.equals(UNKNOWN_COST_RESULT) || rightResult.equals(INFINITE_COST_RESULT)) {
        return rightResult;
    }
    PlanNode right = appendAssignments(
            rightResult.planNode.orElseThrow(() -> new VerifyException("Plan node is not present")),
            newRightAssignments);

    // left input variables come first in the join output
    List<VariableReferenceExpression> sortedOutputVariables = Stream.concat(left.getOutputVariables().stream(), right.getOutputVariables().stream())
            .filter(outputVariables::contains)
            .collect(toImmutableList());

    JoinNode join = new JoinNode(
            left.getSourceLocation(),
            idAllocator.getNextId(),
            INNER,
            left,
            right,
            joinClauses,
            sortedOutputVariables,
            joinFilters.isEmpty() ? Optional.empty() : Optional.of(and(joinFilters)),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            ImmutableMap.of());
    return setJoinNodeProperties(join);
}

/**
 * Gathers the output variables of all given nodes into a fresh mutable set.
 */
private static HashSet<VariableReferenceExpression> collectOutputVariables(LinkedHashSet<PlanNode> nodes)
{
    return nodes.stream()
            .flatMap(node -> node.getOutputVariables().stream())
            .collect(toCollection(HashSet::new));
}

/**
 * Returns {@code source} unchanged when there is nothing to add; otherwise projects all of
 * its output variables as identities plus the given new assignments.
 */
private PlanNode appendAssignments(PlanNode source, Map<VariableReferenceExpression, RowExpression> newAssignments)
{
    if (newAssignments.isEmpty()) {
        return source;
    }
    ImmutableMap.Builder<VariableReferenceExpression, RowExpression> assignments = ImmutableMap.builder();
    for (VariableReferenceExpression outputVariable : source.getOutputVariables()) {
        assignments.put(outputVariable, outputVariable);
    }
    assignments.putAll(newAssignments);
    return addProjections(source, idAllocator, assignments.build());
}
/**
 * Collects the predicates that connect the left and the right variable sets.
 * <p>
 * First come the conjuncts of the overall filter that could not be used for equality
 * inference, rewritten to the combined scope and kept only if they cannot be expressed
 * with the left variables alone or the right variables alone. They are followed by the
 * inferred equalities that straddle the two sides.
 */
private List<RowExpression> getJoinPredicates(Set<VariableReferenceExpression> leftVariables, Set<VariableReferenceExpression> rightVariables)
{
    ImmutableList.Builder<RowExpression> joinPredicates = ImmutableList.builder();

    EqualityInference.Builder inferenceBuilder = new EqualityInference.Builder(metadata);
    for (RowExpression conjunct : inferenceBuilder.nonInferableConjuncts(allFilter)) {
        RowExpression rewritten = allFilterInference.rewriteExpression(
                conjunct,
                variable -> leftVariables.contains(variable) || rightVariables.contains(variable));
        if (rewritten == null) {
            // not expressible with the variables of these two sides
            continue;
        }
        // drop expressions that need only the left or only the right variables
        if (allFilterInference.rewriteExpression(rewritten, leftVariables::contains) != null) {
            continue;
        }
        if (allFilterInference.rewriteExpression(rewritten, rightVariables::contains) != null) {
            continue;
        }
        joinPredicates.add(rewritten);
    }

    // create equality inference on available variables
    // TODO: make generateEqualitiesPartitionedBy take left and right scope
    List<RowExpression> scopeEqualities = allFilterInference
            .generateEqualitiesPartitionedBy(variable -> leftVariables.contains(variable) || rightVariables.contains(variable))
            .getScopeEqualities();
    EqualityInference scopeInference = createEqualityInference(metadata, scopeEqualities.toArray(new RowExpression[0]));
    joinPredicates.addAll(scopeInference.generateEqualitiesPartitionedBy(in(leftVariables)).getScopeStraddlingEqualities());

    return joinPredicates.build();
}
/**
 * Produces the plan for one side of a join.
 * <p>
 * Several nodes are delegated to {@link #chooseJoinOrder}. A single node is returned as is,
 * or wrapped in a {@link FilterNode} when the overall filter yields predicates expressible
 * with {@code outputVariables} only (inferred equalities first, then rewritten non-inferable
 * conjuncts).
 */
private JoinEnumerationResult getJoinSource(LinkedHashSet<PlanNode> nodes, List<VariableReferenceExpression> outputVariables)
{
    if (nodes.size() != 1) {
        return chooseJoinOrder(nodes, outputVariables);
    }

    PlanNode source = getOnlyElement(nodes);

    ImmutableList.Builder<RowExpression> pushedDown = ImmutableList.builder();
    pushedDown.addAll(allFilterInference.generateEqualitiesPartitionedBy(outputVariables::contains).getScopeEqualities());

    EqualityInference.Builder inferenceBuilder = new EqualityInference.Builder(metadata);
    Stream<RowExpression> rewrittenConjuncts = StreamSupport.stream(inferenceBuilder.nonInferableConjuncts(allFilter).spliterator(), false)
            .map(conjunct -> allFilterInference.rewriteExpression(conjunct, outputVariables::contains))
            .filter(Objects::nonNull);
    pushedDown.addAll(rewrittenConjuncts.iterator());

    RowExpression filter = logicalRowExpressions.combineConjuncts(pushedDown.build());
    if (TRUE_CONSTANT.equals(filter)) {
        return createJoinEnumerationResult(source);
    }
    return createJoinEnumerationResult(new FilterNode(source.getSourceLocation(), idAllocator.getNextId(), source, filter));
}
/**
 * Splits join predicates into equi-join clauses and residual join filters.
 * <p>
 * A predicate becomes an equi-join clause when it is a two-argument equality call whose
 * arguments draw their variables from opposite sides of the join. Arguments that are not
 * plain variable references are replaced by freshly allocated variables, and the
 * corresponding assignments are reported for the side that has to compute them.
 * Every other predicate is returned as a join filter.
 *
 * @param joinPredicates predicates connecting the two sides
 * @param leftVariables variables available on the left side
 * @param rightVariables variables available on the right side
 * @param variableAllocator allocator used for the variables of complex equi-join arguments
 */
@VisibleForTesting
JoinCondition extractJoinConditions(List<RowExpression> joinPredicates,
        Set<VariableReferenceExpression> leftVariables,
        Set<VariableReferenceExpression> rightVariables,
        VariableAllocator variableAllocator)
{
    ImmutableList.Builder<EquiJoinClause> equiClauses = ImmutableList.builder();
    ImmutableList.Builder<RowExpression> residualFilters = ImmutableList.builder();
    ImmutableMap.Builder<VariableReferenceExpression, RowExpression> leftAssignments = ImmutableMap.builder();
    ImmutableMap.Builder<VariableReferenceExpression, RowExpression> rightAssignments = ImmutableMap.builder();

    for (RowExpression predicate : joinPredicates) {
        if (!isBinaryEquality(predicate)) {
            residualFilters.add(predicate);
            continue;
        }

        CallExpression equality = (CallExpression) predicate;
        RowExpression leftArgument = equality.getArguments().get(0);
        RowExpression rightArgument = equality.getArguments().get(1);

        // The two arguments must refer to different sides of the join
        Set<VariableReferenceExpression> firstVariables = extractUnique(leftArgument);
        Set<VariableReferenceExpression> secondVariables = extractUnique(rightArgument);
        boolean firstLeftSecondRight = leftVariables.containsAll(firstVariables) && rightVariables.containsAll(secondVariables);
        boolean sidesAreSeparated = firstLeftSecondRight
                || (rightVariables.containsAll(firstVariables) && leftVariables.containsAll(secondVariables));
        if (!sidesAreSeparated) {
            // Arguments have a mix of join sides, use this predicate as a filter
            residualFilters.add(predicate);
            continue;
        }

        // Make the first argument the left-side one, swapping when the second argument is covered by the left variables
        if (leftVariables.containsAll(secondVariables)) {
            RowExpression swapped = rightArgument;
            rightArgument = leftArgument;
            leftArgument = swapped;
        }

        // Complex equi-join arguments, e.g. leftVar = ADD(rightVar1, rightVar2), get a new variable and assignment
        if (!(leftArgument instanceof VariableReferenceExpression)) {
            VariableReferenceExpression leftVariable = variableAllocator.newVariable(leftArgument);
            leftAssignments.put(leftVariable, leftArgument);
            leftArgument = leftVariable;
        }
        if (!(rightArgument instanceof VariableReferenceExpression)) {
            VariableReferenceExpression rightVariable = variableAllocator.newVariable(rightArgument);
            rightAssignments.put(rightVariable, rightArgument);
            rightArgument = rightVariable;
        }

        equiClauses.add(new EquiJoinClause((VariableReferenceExpression) leftArgument, (VariableReferenceExpression) rightArgument));
    }

    return new JoinCondition(equiClauses.build(), residualFilters.build(), leftAssignments.build(), rightAssignments.build());
}

/**
 * Tells whether the predicate is a call to the equality function with exactly two arguments.
 */
private boolean isBinaryEquality(RowExpression predicate)
{
    if (!(predicate instanceof CallExpression)) {
        return false;
    }
    CallExpression call = (CallExpression) predicate;
    return functionResolution.isEqualFunction(call.getFunctionHandle())
            && call.getArguments().size() == 2;
}
/**
 * Result of splitting join predicates: the equi-join clauses, the residual join filters,
 * and the assignments that must be projected on the left and right side to compute the
 * variables used by complex equi-join clauses.
 * <p>
 * The fields are assigned once in the constructor; the supplied collections are stored
 * as given (not copied).
 */
@VisibleForTesting
static class JoinCondition
{
    final List<EquiJoinClause> joinClauses;
    final List<RowExpression> joinFilters;
    final Map<VariableReferenceExpression, RowExpression> newLeftAssignments;
    final Map<VariableReferenceExpression, RowExpression> newRightAssignments;

    /**
     * @param joinClauses equi-join clauses between the two sides
     * @param joinFilters predicates that could not be turned into equi-join clauses
     * @param left new assignments to be computed on the left side
     * @param right new assignments to be computed on the right side
     * @throws NullPointerException if any argument is null
     */
    public JoinCondition(List<EquiJoinClause> joinClauses, List<RowExpression> joinFilters,
            Map<VariableReferenceExpression, RowExpression> left, Map<VariableReferenceExpression, RowExpression> right)
    {
        this.joinClauses = requireNonNull(joinClauses, "joinClauses is null");
        this.joinFilters = requireNonNull(joinFilters, "joinFilters is null");
        this.newLeftAssignments = requireNonNull(left, "left is null");
        this.newRightAssignments = requireNonNull(right, "right is null");
    }

    /** Returns the equi-join clauses. */
    public List<EquiJoinClause> getJoinClauses()
    {
        return joinClauses;
    }

    /** Returns the residual join filters. */
    public List<RowExpression> getJoinFilters()
    {
        return joinFilters;
    }

    /** Returns the assignments to add on the left side, keyed by the newly created variable. */
    public Map<VariableReferenceExpression, RowExpression> getNewLeftAssignments()
    {
        return newLeftAssignments;
    }

    /** Returns the assignments to add on the right side, keyed by the newly created variable. */
    public Map<VariableReferenceExpression, RowExpression> getNewRightAssignments()
    {
        return newRightAssignments;
    }
}
/**
 * Chooses the distribution type and child order for a freshly built join.
 * <p>
 * A side that is at most scalar is replicated (flipping the children when that side is the
 * left one). Otherwise, when confidence based broadcast applies and yields a node, that node
 * is used. Failing that, the cheapest of the possible (distribution, flip) variants is
 * returned, or {@code UNKNOWN_COST_RESULT} if any variant has unknown cost.
 */
private JoinEnumerationResult setJoinNodeProperties(JoinNode joinNode)
{
    // TODO avoid stat (but not cost) recalculation for all considered (distribution,flip) pairs, since resulting relation is the same in all cases
    if (isAtMostScalar(joinNode.getRight(), lookup)) {
        return createJoinEnumerationResult(joinNode.withDistributionType(REPLICATED));
    }
    if (isAtMostScalar(joinNode.getLeft(), lookup)) {
        return createJoinEnumerationResult(joinNode.flipChildren().withDistributionType(REPLICATED));
    }

    boolean tryConfidenceBasedBroadcast = isBelowMaxBroadcastSize(joinNode, context)
            && isBelowMaxBroadcastSize(joinNode.flipChildren(), context)
            && !mustPartition(joinNode)
            && confidenceBasedBroadcastEnabled(context.getSession());
    if (tryConfidenceBasedBroadcast) {
        Optional<JoinNode> broadcastChoice = confidenceBasedBroadcast(joinNode, context);
        if (broadcastChoice.isPresent()) {
            return createJoinEnumerationResult(broadcastChoice.get());
        }
    }

    List<JoinEnumerationResult> candidates = getPossibleJoinNodes(joinNode, getJoinDistributionType(session));
    verify(!candidates.isEmpty(), "possibleJoinNodes is empty");

    for (JoinEnumerationResult candidate : candidates) {
        if (UNKNOWN_COST_RESULT.equals(candidate)) {
            return UNKNOWN_COST_RESULT;
        }
    }
    return resultComparator.min(candidates);
}
/**
 * Lists the candidate join variants for the session-level join distribution setting.
 * <p>
 * Cross joins are always replicated. PARTITIONED and BROADCAST yield the corresponding
 * variants only; AUTOMATIC yields the partitioned variants followed by those replicated
 * variants that are below the maximum broadcast size.
 *
 * @throws IllegalArgumentException if the join is not INNER or the distribution type is not handled
 */
private List<JoinEnumerationResult> getPossibleJoinNodes(JoinNode joinNode, JoinDistributionType distributionType)
{
    checkArgument(joinNode.getType() == INNER, "unexpected join node type: %s", joinNode.getType());

    if (joinNode.isCrossJoin()) {
        return getPossibleJoinNodes(joinNode, REPLICATED);
    }

    switch (distributionType) {
        case AUTOMATIC:
            return ImmutableList.<JoinEnumerationResult>builder()
                    .addAll(getPossibleJoinNodes(joinNode, PARTITIONED))
                    .addAll(getPossibleJoinNodes(joinNode, REPLICATED, node -> isBelowMaxBroadcastSize(node, context)))
                    .build();
        case BROADCAST:
            return getPossibleJoinNodes(joinNode, REPLICATED);
        case PARTITIONED:
            return getPossibleJoinNodes(joinNode, PARTITIONED);
        default:
            throw new IllegalArgumentException("unexpected join distribution type: " + distributionType);
    }
}
/**
 * Lists both child orders of the join with the given distribution type, without any restriction.
 */
private List<JoinEnumerationResult> getPossibleJoinNodes(JoinNode joinNode, com.facebook.presto.spi.plan.JoinDistributionType distributionType)
{
    Predicate<JoinNode> acceptAll = (node) -> true;
    return getPossibleJoinNodes(joinNode, distributionType, acceptAll);
}
/**
 * Builds the join in its given and in its flipped child order, both with the requested
 * distribution type, and returns the enumeration results of those variants accepted by
 * {@code isAllowed} (given order first).
 */
private List<JoinEnumerationResult> getPossibleJoinNodes(JoinNode joinNode, com.facebook.presto.spi.plan.JoinDistributionType distributionType, Predicate<JoinNode> isAllowed)
{
    // Both variants are constructed up front, before any of them is tested or costed
    List<JoinNode> variants = ImmutableList.of(
            joinNode.withDistributionType(distributionType),
            joinNode.flipChildren().withDistributionType(distributionType));

    ImmutableList.Builder<JoinEnumerationResult> accepted = ImmutableList.builder();
    for (JoinNode variant : variants) {
        if (isAllowed.apply(variant)) {
            accepted.add(createJoinEnumerationResult(variant));
        }
    }
    return accepted.build();
}
/**
 * Wraps the plan node together with the cost reported for it by the cost provider.
 */
private JoinEnumerationResult createJoinEnumerationResult(PlanNode planNode)
{
    PlanCostEstimate cost = costProvider.getCost(planNode);
    return JoinEnumerationResult.createJoinEnumerationResult(Optional.of(planNode), cost);
}
}
/**
 * Flattens the given join (and the eligible joins beneath it) into a {@link MultiJoinNode}.
 *
 * @param joinLimit maximum number of joins to flatten; the flattener is given
 *        {@code joinLimit + 1} as its source limit
 */
public static MultiJoinNode toMultiJoinNode(JoinNode joinNode, Lookup lookup, int joinLimit, boolean handleComplexEquiJoins, FunctionResolution functionResolution, DeterminismEvaluator determinismEvaluator)
{
    // the number of sources is the number of joins + 1
    int sourceLimit = joinLimit + 1;
    JoinNodeFlattener flattener = new JoinNodeFlattener(joinNode, lookup, sourceLimit, handleComplexEquiJoins, functionResolution, determinismEvaluator);
    return flattener.toMultiJoinNode();
}
@VisibleForTesting
private static class JoinNodeFlattener
{
private final LinkedHashSet<PlanNode> sources = new LinkedHashSet<>();
private final Assignments intermediateAssignments;
private final boolean handleComplexEquiJoins;
private List<RowExpression> filters = new ArrayList<>();
private final List<VariableReferenceExpression> outputVariables;
private final FunctionResolution functionResolution;
private final DeterminismEvaluator determinismEvaluator;
private final Lookup lookup;
/**
 * Flattens {@code node} eagerly: collects the sources and filters of the join tree, then
 * resolves the assignments gathered from intermediate projections down to the variables
 * produced by the collected sources and inlines them into the collected filters.
 *
 * @throws NullPointerException if {@code node}, {@code lookup}, {@code functionResolution} or {@code determinismEvaluator} is null
 * @throws IllegalStateException if {@code node} is not an INNER join
 */
JoinNodeFlattener(JoinNode node, Lookup lookup, int sourceLimit, boolean handleComplexEquiJoins, FunctionResolution functionResolution,
        DeterminismEvaluator determinismEvaluator)
{
    requireNonNull(node, "node is null");
    checkState(node.getType() == INNER, "join type must be INNER");

    this.outputVariables = node.getOutputVariables();
    this.lookup = requireNonNull(lookup, "lookup is null");
    this.functionResolution = requireNonNull(functionResolution, "functionResolution is null");
    this.determinismEvaluator = requireNonNull(determinismEvaluator, "determinismEvaluator is null");
    this.handleComplexEquiJoins = handleComplexEquiJoins;

    // All collaborators are set; walking the tree fills `sources`, `filters` and the collected assignments
    Map<VariableReferenceExpression, RowExpression> collectedAssignments = new HashMap<>();
    flattenNode(node, sourceLimit, collectedAssignments);

    // We resolve the intermediate assignments to only inputs of the flattened join node
    ImmutableSet<VariableReferenceExpression> sourceVariables = sources.stream()
            .flatMap(source -> source.getOutputVariables().stream())
            .collect(toImmutableSet());
    this.intermediateAssignments = resolveAssignments(collectedAssignments, sourceVariables);

    rewriteFilterWithInlinedAssignments(this.intermediateAssignments);
}
/**
 * Resolves every assignment in the given (mutable) map in place and returns the outcome
 * as {@link Assignments}.
 */
private Assignments resolveAssignments(Map<VariableReferenceExpression, RowExpression> assignments, Set<VariableReferenceExpression> availableVariables)
{
    HashSet<VariableReferenceExpression> alreadyResolved = new HashSet<>();
    // Iterate over a snapshot of the keys: resolveVariable writes back into the map
    List<VariableReferenceExpression> assignedVariables = ImmutableList.copyOf(assignments.keySet());
    for (VariableReferenceExpression assignedVariable : assignedVariables) {
        resolveVariable(assignedVariable, alreadyResolved, assignments, availableVariables);
    }
    return Assignments.builder()
            .putAll(assignments)
            .build();
}
/**
 * Rewrites the assignment of {@code variable} in place so that the variables it references
 * that have their own assignments are replaced by those (already resolved) assignments.
 * <p>
 * Referenced variables that are neither in {@code availableVariables} nor already in
 * {@code resolvedVariables} are resolved recursively first; afterwards {@code variable} is
 * added to {@code resolvedVariables}.
 * <p>
 * NOTE(review): a referenced variable that is not available and has no entry in
 * {@code assignments} makes the recursive call read {@code null} from the map — confirm
 * this cannot happen or that the helpers tolerate it. Cyclic assignments would recurse
 * without bound.
 *
 * @param variable the assigned variable to resolve
 * @param resolvedVariables variables whose assignments are already resolved; updated by this method
 * @param assignments mutable map of assignments; the entry of {@code variable} is overwritten
 * @param availableVariables variables that need no resolution
 */
private void resolveVariable(VariableReferenceExpression variable, HashSet<VariableReferenceExpression> resolvedVariables, Map<VariableReferenceExpression,
RowExpression> assignments, Set<VariableReferenceExpression> availableVariables)
{
RowExpression expression = assignments.get(variable);
// Variables used by the expression, minus the available ones, minus the already resolved ones.
// Sets.difference returns a view (per Guava), so variables resolved by an earlier recursive call
// below are expected to be skipped when the iteration reaches them.
Sets.SetView<VariableReferenceExpression> variablesToResolve = Sets.difference(Sets.difference(extractUnique(expression), availableVariables), resolvedVariables);
// Recursively resolve any unresolved variables
variablesToResolve.forEach(variableToResolve -> resolveVariable(variableToResolve, resolvedVariables, assignments, availableVariables));
// Modify the assignment for the variable: replace the variables it references with their (now resolved) assignments
assignments.put(variable, replaceExpression(expression, assignments));
// Mark this variable as resolved
resolvedVariables.add(variable);
}
/**
 * Replaces the collected filters by copies in which the given assignments are inlined.
 * The resulting filter list is immutable.
 */
private void rewriteFilterWithInlinedAssignments(Assignments assignments)
{
    filters = filters.stream()
            .map(filter -> replaceExpression(filter, assignments.getMap()))
            .collect(toImmutableList());
}
/**
 * Walks the plan below {@code node}, adding join inputs to {@code sources} and join
 * conditions to {@code filters}.
 * <p>
 * A node becomes a source when it is not a join, when the source limit is reached, or when
 * the join is not INNER, has a non-deterministic filter, or already has a distribution type.
 * Projections directly above a join are looked through (recording their non-identity
 * assignments) only when complex equi-join handling is enabled and all their assignments
 * are deterministic; any other projection becomes a source.
 */
private void flattenNode(PlanNode node, int limit, Map<VariableReferenceExpression, RowExpression> assignmentsBuilder)
{
    PlanNode resolved = lookup.resolve(node);

    if (resolved instanceof ProjectNode) {
        ProjectNode projectNode = (ProjectNode) resolved;
        // A ProjectNode could be 'hiding' a join source by building an assignment of a complex equi-join criteria like `left.key = right1.key1 + right1.key2`
        // We open up the join space by tracking the assignments from this Project node; these will be inlined into the overall filters once we finish
        // traversing the join graph
        // We only do this if the ProjectNode assignments are deterministic
        boolean canLookThrough = handleComplexEquiJoins
                && lookup.resolve(projectNode.getSource()) instanceof JoinNode
                && projectNode.getAssignments().getExpressions().stream().allMatch(determinismEvaluator::isDeterministic);
        if (!canLookThrough) {
            sources.add(node);
            return;
        }
        // We keep track of only the non-identity assignments since these are the ones that will be inlined into the overall filters
        assignmentsBuilder.putAll(getNonIdentityAssignments(projectNode.getAssignments()));
        flattenNode(projectNode.getSource(), limit, assignmentsBuilder);
        return;
    }

    if (!(resolved instanceof JoinNode)) {
        sources.add(node);
        return;
    }
    // (limit - 2) because you need to account for adding left and right side
    if (sources.size() > (limit - 2)) {
        sources.add(node);
        return;
    }

    JoinNode joinNode = (JoinNode) resolved;
    boolean reorderable = joinNode.getType() == INNER
            && determinismEvaluator.isDeterministic(joinNode.getFilter().orElse(TRUE_CONSTANT))
            && !joinNode.getDistributionType().isPresent();
    if (!reorderable) {
        sources.add(node);
        return;
    }

    // we set the left limit to limit - 1 to account for the node on the right
    flattenNode(joinNode.getLeft(), limit - 1, assignmentsBuilder);
    flattenNode(joinNode.getRight(), limit, assignmentsBuilder);

    for (EquiJoinClause criteria : joinNode.getCriteria()) {
        filters.add(toRowExpression(criteria, functionResolution));
    }
    joinNode.getFilter().ifPresent(filters::add);
}
/**
 * Builds the {@link MultiJoinNode} from the collected sources and filters.
 * <p>
 * Output variables produced directly by a source are kept. An output variable that comes
 * from an intermediate assignment is replaced by the variables that assignment uses, and
 * in that case the returned node carries the full set of assignments so the caller can
 * restore the original output variables with a projection. When every output variable
 * comes straight from a source, the returned node has empty assignments.
 *
 * @throws IllegalStateException if an output variable is neither a source variable nor an intermediate assignment
 */
MultiJoinNode toMultiJoinNode()
{
    ImmutableSet<VariableReferenceExpression> sourceVariables = sources.stream()
            .flatMap(source -> source.getOutputVariables().stream())
            .collect(toImmutableSet());

    ImmutableSet.Builder<VariableReferenceExpression> flattenedOutputs = ImmutableSet.builder();
    Assignments.Builder wrapperAssignments = Assignments.builder();
    boolean needsWrapperProject = false;

    for (VariableReferenceExpression outputVariable : outputVariables) {
        if (sourceVariables.contains(outputVariable)) {
            // produced directly by a source: identity assignment
            wrapperAssignments.put(outputVariable, outputVariable);
            flattenedOutputs.add(outputVariable);
        }
        else {
            checkState(intermediateAssignments.getMap().containsKey(outputVariable),
                    "Output variable [%s] not found in input variables or in intermediate assignments", outputVariable);
            needsWrapperProject = true;
            RowExpression definition = intermediateAssignments.get(outputVariable);
            wrapperAssignments.put(outputVariable, definition);
            // the reordered join has to output the inputs of the definition instead
            flattenedOutputs.addAll(extractUnique(definition));
        }
    }

    Assignments assignments;
    if (needsWrapperProject) {
        assignments = wrapperAssignments.build();
    }
    else {
        assignments = Assignments.of();
    }

    return new MultiJoinNode(
            sources,
            and(filters),
            flattenedOutputs.build().asList(),
            assignments,
            false,
            Optional.empty());
}
}
/**
 * Outcome of enumerating a join: a plan node with its cost, or one of the two plan-less
 * sentinels {@link #UNKNOWN_COST_RESULT} and {@link #INFINITE_COST_RESULT}.
 * <p>
 * Instances do not override {@code equals}, so comparisons against the sentinels are
 * identity comparisons.
 */
@VisibleForTesting
static class JoinEnumerationResult
{
    public static final JoinEnumerationResult UNKNOWN_COST_RESULT = new JoinEnumerationResult(Optional.empty(), PlanCostEstimate.unknown());
    public static final JoinEnumerationResult INFINITE_COST_RESULT = new JoinEnumerationResult(Optional.empty(), PlanCostEstimate.infinite());

    private final Optional<PlanNode> planNode;
    private final PlanCostEstimate cost;

    /**
     * @throws NullPointerException if an argument is null
     * @throws IllegalArgumentException unless the plan node is present exactly when the cost
     *         has no unknown components and is not infinite
     */
    private JoinEnumerationResult(Optional<PlanNode> planNode, PlanCostEstimate cost)
    {
        this.planNode = requireNonNull(planNode, "planNode is null");
        this.cost = requireNonNull(cost, "cost is null");
        // The previous condition used `!unknown || !infinite` for the "plan present" case, which
        // holds for practically every cost and therefore let an unknown or infinite cost through
        // together with a plan node. Both conditions must hold for a plan node to be allowed.
        boolean costIsKnownAndFinite = !cost.hasUnknownComponents() && !cost.equals(PlanCostEstimate.infinite());
        checkArgument(planNode.isPresent() == costIsKnownAndFinite,
                "planNode should be present if and only if cost is known");
    }

    /** Returns the plan node; empty for the unknown-cost and infinite-cost sentinels. */
    public Optional<PlanNode> getPlanNode()
    {
        return planNode;
    }

    /** Returns the cost associated with this result. */
    public PlanCostEstimate getCost()
    {
        return cost;
    }

    /**
     * Returns {@link #UNKNOWN_COST_RESULT} when the cost has unknown components,
     * {@link #INFINITE_COST_RESULT} when the cost is infinite, and otherwise a new result
     * holding the given plan node and cost.
     */
    static JoinEnumerationResult createJoinEnumerationResult(Optional<PlanNode> planNode, PlanCostEstimate cost)
    {
        if (cost.hasUnknownComponents()) {
            return UNKNOWN_COST_RESULT;
        }
        if (cost.equals(PlanCostEstimate.infinite())) {
            return INFINITE_COST_RESULT;
        }
        return new JoinEnumerationResult(planNode, cost);
    }
}
}