HashGenerationOptimizer.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.optimizations;
import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.DistinctLimitNode;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.MarkDistinctNode;
import com.facebook.presto.spi.plan.MergeJoinNode;
import com.facebook.presto.spi.plan.PartitioningScheme;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.SemiJoinNode;
import com.facebook.presto.spi.plan.SpatialJoinNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.plan.UnionNode;
import com.facebook.presto.spi.plan.UnnestNode;
import com.facebook.presto.spi.plan.WindowNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.plan.ApplyNode;
import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.GroupIdNode;
import com.facebook.presto.sql.planner.plan.IndexJoinNode;
import com.facebook.presto.sql.planner.plan.InternalPlanVisitor;
import com.facebook.presto.sql.planner.plan.LateralJoinNode;
import com.facebook.presto.sql.planner.plan.RowNumberNode;
import com.facebook.presto.sql.planner.plan.SequenceNode;
import com.facebook.presto.sql.planner.plan.TopNRowNumberNode;
import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Lists;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import static com.facebook.presto.SystemSessionProperties.skipHashGenerationForJoinWithTableScanInput;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.spi.plan.JoinType.INNER;
import static com.facebook.presto.spi.plan.JoinType.LEFT;
import static com.facebook.presto.spi.plan.JoinType.RIGHT;
import static com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL;
import static com.facebook.presto.spi.plan.ProjectNode.Locality.REMOTE;
import static com.facebook.presto.sql.planner.PlannerUtils.HASH_CODE;
import static com.facebook.presto.sql.planner.PlannerUtils.INITIAL_HASH_VALUE;
import static com.facebook.presto.sql.planner.PlannerUtils.orNullHashCode;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
import static com.facebook.presto.sql.planner.optimizations.SetOperationNodeUtils.fromListMultimap;
import static com.facebook.presto.sql.planner.plan.ChildReplacer.replaceChildren;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Stream.concat;
public class HashGenerationOptimizer
implements PlanOptimizer
{
private final FunctionAndTypeManager functionAndTypeManager;
// When true, isEnabled() reports true regardless of the session property
private boolean isEnabledForTesting;

/**
 * Creates the optimizer.
 *
 * @param functionAndTypeManager handed to the plan rewriter, which uses it when computing hashes; must not be null
 * @throws NullPointerException if {@code functionAndTypeManager} is null
 */
public HashGenerationOptimizer(FunctionAndTypeManager functionAndTypeManager)
{
    // the message names the actual parameter (previously the stale "functionManager is null")
    this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
}
/**
 * Forces this optimizer on, or hands control back to the session property.
 *
 * @param isSet when true, {@code isEnabled} returns true for every session
 */
@Override
public void setEnabledForTesting(boolean isSet)
{
    this.isEnabledForTesting = isSet;
}
/**
 * Reports whether hash generation should run for {@code session}.
 *
 * @return true when forced on for testing; otherwise the result of
 *         {@code SystemSessionProperties.isOptimizeHashGenerationEnabled(session)}
 */
@Override
public boolean isEnabled(Session session)
{
    if (isEnabledForTesting) {
        return true;
    }
    return SystemSessionProperties.isOptimizeHashGenerationEnabled(session);
}
/**
 * Rewrites {@code plan} so hash-based operators can use precomputed hash variables.
 * When the optimizer is disabled, the input plan is returned as-is and reported as unchanged;
 * when enabled, the rewritten plan is always reported as changed.
 */
@Override
public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector)
{
    requireNonNull(plan, "plan is null");
    requireNonNull(session, "session is null");
    requireNonNull(types, "types is null");
    requireNonNull(variableAllocator, "variableAllocator is null");
    requireNonNull(idAllocator, "idAllocator is null");

    if (!isEnabled(session)) {
        return PlanOptimizerResult.optimizerResult(plan, false);
    }

    // the root has no parent, so the rewrite starts from an empty hash preference
    Rewriter rewriter = new Rewriter(idAllocator, variableAllocator, functionAndTypeManager, session);
    PlanWithProperties rewritten = rewriter.accept(plan, new HashComputationSet());
    return PlanOptimizerResult.optimizerResult(rewritten.getNode(), true);
}
private static class Rewriter
extends InternalPlanVisitor<PlanWithProperties, HashComputationSet>
{
private final PlanNodeIdAllocator idAllocator;
private final VariableAllocator variableAllocator;
private final FunctionAndTypeManager functionAndTypeManager;
private final Session session;

/**
 * @param idAllocator allocates ids for plan nodes created during the rewrite (e.g. extra local projections)
 * @param variableAllocator allocates the new hash variables
 * @param functionAndTypeManager passed to {@code computeHash} when building hash computations
 * @param session consulted for session properties that tune hash generation
 * @throws NullPointerException if any argument is null
 */
private Rewriter(PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator, FunctionAndTypeManager functionAndTypeManager, Session session)
{
    this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
    this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator is null");
    // the message names the actual parameter (previously the stale "functionManager is null")
    this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
    this.session = requireNonNull(session, "session is null");
}
/**
 * Fallback for nodes without a dedicated visit method: the node's source (if any) is planned
 * with the parent's preference, always pruning extra hash variables.
 */
@Override
public PlanWithProperties visitPlan(PlanNode node, HashComputationSet parentPreference)
{
    boolean alwaysPruneExtraHashVariables = true;
    return planSimpleNodeWithProperties(node, parentPreference, alwaysPruneExtraHashVariables);
}
/**
 * Plans an EnforceSingleRowNode without forwarding the parent's preference: this plan node
 * can only have a single input variable, so no extra hash variables are added beneath it.
 */
@Override
public PlanWithProperties visitEnforceSingleRow(EnforceSingleRowNode node, HashComputationSet parentPreference)
{
    HashComputationSet noPreferredHashes = new HashComputationSet();
    return planSimpleNodeWithProperties(node, noPreferredHashes, true);
}
/**
 * Leaves an ApplyNode untouched and exposes no hash variables. Apply node is not supported by
 * execution, so it is deliberately not rewritten; that way the query fails in sanity checkers.
 */
@Override
public PlanWithProperties visitApply(ApplyNode node, HashComputationSet context)
{
    Map<HashComputation, VariableReferenceExpression> noHashVariables = ImmutableMap.of();
    return new PlanWithProperties(node, noHashVariables);
}
public PlanWithProperties visitSequence(SequenceNode node, HashComputationSet context)
{
List<PlanNode> cteProducers = node.getCteProducers().stream()
.map(c ->
planAndEnforce(c, new HashComputationSet(), true, new HashComputationSet()).getNode())
.collect(ImmutableList.toImmutableList());
PlanWithProperties primarySource = plan(node.getPrimarySource(), context);
return new PlanWithProperties(
replaceChildren(node, ImmutableList.<PlanNode>builder()
.addAll(cteProducers)
.add(primarySource.getNode())
.build()),
primarySource.getHashVariables());
}
/**
 * Leaves a LateralJoinNode untouched and exposes no hash variables. Lateral join node is not
 * supported by execution, so it is deliberately not rewritten; that way the query fails in
 * sanity checkers.
 */
@Override
public PlanWithProperties visitLateralJoin(LateralJoinNode node, HashComputationSet context)
{
    Map<HashComputation, VariableReferenceExpression> noHashVariables = ImmutableMap.of();
    return new PlanWithProperties(node, noHashVariables);
}
/**
 * Plans an aggregation. A hash over the grouping keys is required from the source unless the
 * aggregation is streamable, is eligible for segmented aggregation, or its keys satisfy
 * canSkipHashGeneration. Aggregation does not pass through preferred hash variables; only the
 * group-by hash (when one is computed) is exposed to the parent.
 */
@Override
public PlanWithProperties visitAggregation(AggregationNode node, HashComputationSet parentPreference)
{
    List<VariableReferenceExpression> groupingKeys = node.getGroupingKeys();
    boolean needsGroupByHash = !node.isStreamable()
            && !node.isSegmentedAggregationEligible()
            && !canSkipHashGeneration(groupingKeys);
    // todo: for segmented aggregation, add optimizations for the fields that need to compute hash
    Optional<HashComputation> groupByHash = needsGroupByHash
            ? computeHash(groupingKeys, functionAndTypeManager)
            : Optional.empty();

    // the same set is both required and preferred, so nothing beyond the group-by hash flows up
    HashComputationSet requiredHashes = new HashComputationSet(groupByHash);
    PlanWithProperties child = planAndEnforce(node.getSource(), requiredHashes, false, requiredHashes);
    Optional<VariableReferenceExpression> hashVariable = groupByHash.map(child::getRequiredHashVariable);

    AggregationNode rewritten = new AggregationNode(
            node.getSourceLocation(),
            node.getId(),
            child.getNode(),
            node.getAggregations(),
            node.getGroupingSets(),
            node.getPreGroupedVariables(),
            node.getStep(),
            hashVariable,
            node.getGroupIdVariable(),
            node.getAggregationId());

    Map<HashComputation, VariableReferenceExpression> exposedHashes = ImmutableMap.of();
    if (hashVariable.isPresent()) {
        exposedHashes = ImmutableMap.of(groupByHash.get(), hashVariable.get());
    }
    return new PlanWithProperties(rewritten, exposedHashes);
}
/**
 * Returns true when no precomputed hash is needed for {@code partitionVariables}: the list is
 * empty, or it holds exactly one variable of type BIGINT.
 * HACK: bigint grouped aggregation has special operators that do not use precomputed hash, so we can skip hash generation
 */
private boolean canSkipHashGeneration(List<VariableReferenceExpression> partitionVariables)
{
    switch (partitionVariables.size()) {
        case 0:
            return true;
        case 1:
            return partitionVariables.get(0).getType().equals(BIGINT);
        default:
            return false;
    }
}
/**
 * Plans a GroupIdNode as a simple node, forwarding only the part of the parent's preference
 * whose variables are exported by this node's source.
 */
@Override
public PlanWithProperties visitGroupId(GroupIdNode node, HashComputationSet parentPreference)
{
    List<VariableReferenceExpression> sourceOutputs = node.getSource().getOutputVariables();
    HashComputationSet prunedPreference = parentPreference.pruneVariables(sourceOutputs);
    return planSimpleNodeWithProperties(node, prunedPreference);
}
/**
 * Plans a DistinctLimitNode. When canSkipHashGeneration holds for the distinct variables it is
 * planned as a simple node; otherwise the source must produce a hash over the distinct
 * variables, which becomes the node's hash variable and the only hash exposed to the parent.
 */
@Override
public PlanWithProperties visitDistinctLimit(DistinctLimitNode node, HashComputationSet parentPreference)
{
    // skip hash variable generation for single bigint (or no) distinct variables
    if (canSkipHashGeneration(node.getDistinctVariables())) {
        return planSimpleNodeWithProperties(node, parentPreference);
    }

    Optional<HashComputation> distinctHash = computeHash(node.getDistinctVariables(), functionAndTypeManager);
    HashComputationSet required = new HashComputationSet(distinctHash);
    HashComputationSet preferred = parentPreference.withHashComputation(node, distinctHash);
    PlanWithProperties child = planAndEnforce(node.getSource(), required, false, preferred);

    HashComputation distinctHashComputation = distinctHash.get();
    VariableReferenceExpression hashVariable = child.getRequiredHashVariable(distinctHashComputation);

    // TODO: we need to reason about how pre-computed hashes from child relate to distinct variables. We should be able to include any precomputed hash
    // that's functionally dependent on the distinct field in the set of distinct fields of the new node to be able to propagate it downstream.
    // Currently, such precomputed hashes will be dropped by this operation.
    DistinctLimitNode rewritten = new DistinctLimitNode(
            node.getSourceLocation(),
            node.getId(),
            child.getNode(),
            node.getLimit(),
            node.isPartial(),
            node.getDistinctVariables(),
            Optional.of(hashVariable),
            node.getTimeoutMillis());
    return new PlanWithProperties(rewritten, ImmutableMap.of(distinctHashComputation, hashVariable));
}
/**
 * Plans a MarkDistinctNode. When canSkipHashGeneration holds for the distinct variables it is
 * planned as a simple node; otherwise the source must produce a hash over the distinct
 * variables, used as the node's hash variable. All of the child's hash variables are passed
 * through to the parent.
 */
@Override
public PlanWithProperties visitMarkDistinct(MarkDistinctNode node, HashComputationSet parentPreference)
{
    // skip hash variable generation for single bigint (or no) distinct variables
    if (canSkipHashGeneration(node.getDistinctVariables())) {
        return planSimpleNodeWithProperties(node, parentPreference);
    }

    Optional<HashComputation> distinctHash = computeHash(node.getDistinctVariables(), functionAndTypeManager);
    HashComputationSet required = new HashComputationSet(distinctHash);
    HashComputationSet preferred = parentPreference.withHashComputation(node, distinctHash);
    PlanWithProperties child = planAndEnforce(node.getSource(), required, false, preferred);
    Optional<VariableReferenceExpression> hashVariable = Optional.of(child.getRequiredHashVariable(distinctHash.get()));

    MarkDistinctNode rewritten = new MarkDistinctNode(
            node.getSourceLocation(),
            node.getId(),
            child.getNode(),
            node.getMarkerVariable(),
            node.getDistinctVariables(),
            hashVariable);
    return new PlanWithProperties(rewritten, child.getHashVariables());
}
/**
 * Plans a RowNumberNode. Without partition-by variables it is planned as a simple node;
 * otherwise the source must produce a hash over the partition-by variables, used as the node's
 * hash variable. All of the child's hash variables are passed through to the parent.
 */
@Override
public PlanWithProperties visitRowNumber(RowNumberNode node, HashComputationSet parentPreference)
{
    boolean partitioned = !node.getPartitionBy().isEmpty();
    if (!partitioned) {
        return planSimpleNodeWithProperties(node, parentPreference);
    }

    Optional<HashComputation> partitionHash = computeHash(node.getPartitionBy(), functionAndTypeManager);
    HashComputationSet required = new HashComputationSet(partitionHash);
    HashComputationSet preferred = parentPreference.withHashComputation(node, partitionHash);
    PlanWithProperties child = planAndEnforce(node.getSource(), required, false, preferred);
    VariableReferenceExpression hashVariable = child.getRequiredHashVariable(partitionHash.get());

    RowNumberNode rewritten = new RowNumberNode(
            node.getSourceLocation(),
            node.getId(),
            child.getNode(),
            node.getPartitionBy(),
            node.getRowNumberVariable(),
            node.getMaxRowCountPerPartition(),
            node.isPartial(),
            Optional.of(hashVariable));
    return new PlanWithProperties(rewritten, child.getHashVariables());
}
/**
 * Plans a TopNRowNumberNode. Without partition-by variables it is planned as a simple node;
 * otherwise the source must produce a hash over the partition-by variables, used as the node's
 * hash variable. All of the child's hash variables are passed through to the parent.
 */
@Override
public PlanWithProperties visitTopNRowNumber(TopNRowNumberNode node, HashComputationSet parentPreference)
{
    boolean partitioned = !node.getPartitionBy().isEmpty();
    if (!partitioned) {
        return planSimpleNodeWithProperties(node, parentPreference);
    }

    Optional<HashComputation> partitionHash = computeHash(node.getPartitionBy(), functionAndTypeManager);
    HashComputationSet required = new HashComputationSet(partitionHash);
    HashComputationSet preferred = parentPreference.withHashComputation(node, partitionHash);
    PlanWithProperties child = planAndEnforce(node.getSource(), required, false, preferred);
    VariableReferenceExpression hashVariable = child.getRequiredHashVariable(partitionHash.get());

    TopNRowNumberNode rewritten = new TopNRowNumberNode(
            node.getSourceLocation(),
            node.getId(),
            child.getNode(),
            node.getSpecification(),
            node.getRowNumberVariable(),
            node.getMaxRowCountPerPartition(),
            node.isPartial(),
            Optional.of(hashVariable));
    return new PlanWithProperties(rewritten, child.getHashVariables());
}
/**
 * Returns true when the hash for one join input may be dropped: the input is a TableScanNode,
 * a hash computation is present and reports {@code isSingleBigIntVariable()}, and the parent
 * does not prefer that hash.
 */
private boolean skipHashComputeForJoinInput(PlanNode node, Optional<HashComputation> hashComputation, HashComputationSet parentPreference)
{
    if (!(node instanceof TableScanNode) || !hashComputation.isPresent()) {
        return false;
    }
    HashComputation hash = hashComputation.get();
    return hash.isSingleBigIntVariable() && !parentPreference.getHashes().contains(hash);
}
/**
 * Plans a JoinNode.
 *
 * <p>Without equi-join criteria, both inputs are planned with no hashes and the join exposes none.
 * Otherwise each input is planned to produce exactly the hash over its join keys. When the
 * session enables skipHashGenerationForJoinWithTableScanInput, a side's hash is dropped if
 * skipHashComputeForJoinInput holds for it. Hash variables of the left input are offered to the
 * parent for INNER and LEFT joins, those of the right input for INNER and RIGHT joins; the
 * parent's preference then decides which are kept.
 */
@Override
public PlanWithProperties visitJoin(JoinNode node, HashComputationSet parentPreference)
{
    List<EquiJoinClause> clauses = node.getCriteria();
    if (clauses.isEmpty()) {
        // join does not pass through preferred hash variables since they take more memory and since
        // the join node filters, may take more compute
        PlanWithProperties left = planAndEnforce(node.getLeft(), new HashComputationSet(), true, new HashComputationSet());
        PlanWithProperties right = planAndEnforce(node.getRight(), new HashComputationSet(), true, new HashComputationSet());
        checkState(left.getHashVariables().isEmpty() && right.getHashVariables().isEmpty());
        return new PlanWithProperties(
                replaceChildren(node, ImmutableList.of(left.getNode(), right.getNode())),
                ImmutableMap.of());
    }

    // read the session flag once; it applies to both join inputs
    boolean skipHashForTableScanInput = skipHashGenerationForJoinWithTableScanInput(session);

    // join does not pass through preferred hash variables since they take more memory and since
    // the join node filters, may take more compute
    Optional<HashComputation> leftHashComputation = computeHash(Lists.transform(clauses, EquiJoinClause::getLeft), functionAndTypeManager);
    if (skipHashForTableScanInput && skipHashComputeForJoinInput(node.getLeft(), leftHashComputation, parentPreference)) {
        leftHashComputation = Optional.empty();
    }
    HashComputationSet leftHashes = new HashComputationSet(leftHashComputation);
    PlanWithProperties left = planAndEnforce(node.getLeft(), leftHashes, true, leftHashes);
    // same Optional.map idiom as visitAggregation
    Optional<VariableReferenceExpression> leftHashVariable = leftHashComputation.map(left::getRequiredHashVariable);

    Optional<HashComputation> rightHashComputation = computeHash(Lists.transform(clauses, EquiJoinClause::getRight), functionAndTypeManager);
    if (skipHashForTableScanInput && skipHashComputeForJoinInput(node.getRight(), rightHashComputation, parentPreference)) {
        rightHashComputation = Optional.empty();
    }
    // drop undesired hash variables from build to save memory
    HashComputationSet rightHashes = new HashComputationSet(rightHashComputation);
    PlanWithProperties right = planAndEnforce(node.getRight(), rightHashes, true, rightHashes);
    Optional<VariableReferenceExpression> rightHashVariable = rightHashComputation.map(right::getRequiredHashVariable);

    // build map of all hash variables
    // NOTE: Full outer join doesn't use hash variables
    Map<HashComputation, VariableReferenceExpression> allHashVariables = new HashMap<>();
    if (node.getType() == INNER || node.getType() == LEFT) {
        allHashVariables.putAll(left.getHashVariables());
    }
    if (node.getType() == INNER || node.getType() == RIGHT) {
        allHashVariables.putAll(right.getHashVariables());
    }

    return buildJoinNodeWithPreferredHashes(node, left, right, allHashVariables, parentPreference, leftHashVariable, rightHashVariable);
}
/**
 * Rebuilds {@code node} on the planned inputs. Only the entries of {@code allHashVariables} that
 * the parent prefers are kept. The join outputs are the children's output variables (left first,
 * then right) that are either original join outputs or one of those retained hash variables.
 */
private PlanWithProperties buildJoinNodeWithPreferredHashes(
        JoinNode node,
        PlanWithProperties left,
        PlanWithProperties right,
        Map<HashComputation, VariableReferenceExpression> allHashVariables,
        HashComputationSet parentPreference,
        Optional<VariableReferenceExpression> leftHashVariable,
        Optional<VariableReferenceExpression> rightHashVariable)
{
    // retain only hash variables preferred by parent nodes
    ImmutableMap.Builder<HashComputation, VariableReferenceExpression> preferredBuilder = ImmutableMap.builder();
    for (Entry<HashComputation, VariableReferenceExpression> entry : allHashVariables.entrySet()) {
        if (parentPreference.getHashes().contains(entry.getKey())) {
            preferredBuilder.put(entry);
        }
    }
    Map<HashComputation, VariableReferenceExpression> hashVariablesWithParentPreferences = preferredBuilder.build();

    Set<VariableReferenceExpression> preferredHashOutputs = ImmutableSet.copyOf(hashVariablesWithParentPreferences.values());
    ImmutableList.Builder<VariableReferenceExpression> outputVariables = ImmutableList.builder();
    for (PlanWithProperties side : ImmutableList.of(left, right)) {
        for (VariableReferenceExpression variable : side.getNode().getOutputVariables()) {
            if (node.getOutputVariables().contains(variable) || preferredHashOutputs.contains(variable)) {
                outputVariables.add(variable);
            }
        }
    }

    JoinNode rewritten = new JoinNode(
            node.getSourceLocation(),
            node.getId(),
            node.getType(),
            left.getNode(),
            right.getNode(),
            node.getCriteria(),
            outputVariables.build(),
            node.getFilter(),
            leftHashVariable,
            rightHashVariable,
            node.getDistributionType(),
            node.getDynamicFilters());
    return new PlanWithProperties(rewritten, hashVariablesWithParentPreferences);
}
/**
 * Plans a SemiJoinNode. The source must produce the hash of its join variable, and the filtering
 * source the hash of its filtering join variable. Only the source side's hash variables are
 * exposed to the parent.
 */
@Override
public PlanWithProperties visitSemiJoin(SemiJoinNode node, HashComputationSet parentPreference)
{
    Optional<HashComputation> sourceHashComputation = computeHash(ImmutableList.of(node.getSourceJoinVariable()), functionAndTypeManager);
    HashComputationSet sourceHashes = new HashComputationSet(sourceHashComputation);
    PlanWithProperties source = planAndEnforce(node.getSource(), sourceHashes, true, sourceHashes);
    VariableReferenceExpression sourceHashVariable = source.getRequiredHashVariable(sourceHashComputation.get());

    Optional<HashComputation> filterHashComputation = computeHash(ImmutableList.of(node.getFilteringSourceJoinVariable()), functionAndTypeManager);
    HashComputationSet filteringHashes = new HashComputationSet(filterHashComputation);
    PlanWithProperties filteringSource = planAndEnforce(node.getFilteringSource(), filteringHashes, true, filteringHashes);
    VariableReferenceExpression filteringSourceHashVariable = filteringSource.getRequiredHashVariable(filterHashComputation.get());

    SemiJoinNode rewritten = new SemiJoinNode(
            node.getSourceLocation(),
            node.getId(),
            source.getNode(),
            filteringSource.getNode(),
            node.getSourceJoinVariable(),
            node.getFilteringSourceJoinVariable(),
            node.getSemiJoinOutput(),
            Optional.of(sourceHashVariable),
            Optional.of(filteringSourceHashVariable),
            node.getDistributionType(),
            node.getDynamicFilters());
    return new PlanWithProperties(rewritten, source.getHashVariables());
}
/**
 * Plans both spatial-join inputs with no required or preferred hashes, verifies that neither
 * produces hash variables, and exposes none from the rebuilt node.
 */
@Override
public PlanWithProperties visitSpatialJoin(SpatialJoinNode node, HashComputationSet parentPreference)
{
    HashComputationSet noHashes = new HashComputationSet();
    PlanWithProperties probe = planAndEnforce(node.getLeft(), noHashes, true, noHashes);
    PlanWithProperties build = planAndEnforce(node.getRight(), noHashes, true, noHashes);
    verify(probe.getHashVariables().isEmpty(), "probe side of the spatial join should not include hash variables");
    verify(build.getHashVariables().isEmpty(), "build side of the spatial join should not include hash variables");

    PlanNode rewritten = replaceChildren(node, ImmutableList.of(probe.getNode(), build.getNode()));
    return new PlanWithProperties(rewritten, ImmutableMap.of());
}
/**
 * Plans an IndexJoinNode. The probe source must produce the hash over the probe keys and the
 * index source the hash over the index keys. The index side's hash variables are always offered
 * to the parent; the probe side's only for INNER joins.
 */
@Override
public PlanWithProperties visitIndexJoin(IndexJoinNode node, HashComputationSet parentPreference)
{
    List<IndexJoinNode.EquiJoinClause> clauses = node.getCriteria();

    // join does not pass through preferred hash variables since they take more memory and since
    // the join node filters, may take more compute
    Optional<HashComputation> probeHashComputation = computeHash(Lists.transform(clauses, IndexJoinNode.EquiJoinClause::getProbe), functionAndTypeManager);
    HashComputationSet probeHashes = new HashComputationSet(probeHashComputation);
    PlanWithProperties probe = planAndEnforce(node.getProbeSource(), probeHashes, true, probeHashes);
    VariableReferenceExpression probeHashVariable = probe.getRequiredHashVariable(probeHashComputation.get());

    Optional<HashComputation> indexHashComputation = computeHash(Lists.transform(clauses, IndexJoinNode.EquiJoinClause::getIndex), functionAndTypeManager);
    HashComputationSet indexHashes = new HashComputationSet(indexHashComputation);
    PlanWithProperties index = planAndEnforce(node.getIndexSource(), indexHashes, true, indexHashes);
    VariableReferenceExpression indexHashVariable = index.getRequiredHashVariable(indexHashComputation.get());

    // build map of all hash variables
    Map<HashComputation, VariableReferenceExpression> allHashVariables = new HashMap<>();
    if (node.getType() == INNER) {
        allHashVariables.putAll(probe.getHashVariables());
    }
    allHashVariables.putAll(index.getHashVariables());

    IndexJoinNode rewritten = new IndexJoinNode(
            node.getSourceLocation(),
            node.getId(),
            node.getType(),
            probe.getNode(),
            index.getNode(),
            node.getCriteria(),
            node.getFilter(),
            Optional.of(probeHashVariable),
            Optional.of(indexHashVariable));
    return new PlanWithProperties(rewritten, allHashVariables);
}
/**
 * Plans both merge-join inputs with no required or preferred hashes, verifies that neither
 * produces hash variables, and exposes none from the rebuilt node.
 */
@Override
public PlanWithProperties visitMergeJoin(MergeJoinNode node, HashComputationSet parentPreference)
{
    HashComputationSet noHashes = new HashComputationSet();
    PlanWithProperties leftInput = planAndEnforce(node.getLeft(), noHashes, true, noHashes);
    PlanWithProperties rightInput = planAndEnforce(node.getRight(), noHashes, true, noHashes);
    verify(leftInput.getHashVariables().isEmpty(), "left side of the merge join should not include hash variables");
    verify(rightInput.getHashVariables().isEmpty(), "right side of the merge join should not include hash variables");

    PlanNode rewritten = replaceChildren(node, ImmutableList.of(leftInput.getNode(), rightInput.getNode()));
    return new PlanWithProperties(rewritten, ImmutableMap.of());
}
/**
 * Plans a WindowNode. Without partition-by variables it is planned as a simple node; otherwise
 * the source must produce a hash over the partition-by variables, used as the node's hash
 * variable. All of the child's hash variables are passed through to the parent.
 */
@Override
public PlanWithProperties visitWindow(WindowNode node, HashComputationSet parentPreference)
{
    boolean partitioned = !node.getPartitionBy().isEmpty();
    if (!partitioned) {
        return planSimpleNodeWithProperties(node, parentPreference);
    }

    Optional<HashComputation> partitionHash = computeHash(node.getPartitionBy(), functionAndTypeManager);
    HashComputationSet required = new HashComputationSet(partitionHash);
    HashComputationSet preferred = parentPreference.withHashComputation(node, partitionHash);
    PlanWithProperties child = planAndEnforce(node.getSource(), required, true, preferred);
    VariableReferenceExpression hashVariable = child.getRequiredHashVariable(partitionHash.get());

    WindowNode rewritten = new WindowNode(
            node.getSourceLocation(),
            node.getId(),
            child.getNode(),
            node.getSpecification(),
            node.getWindowFunctions(),
            Optional.of(hashVariable),
            node.getPrePartitionedInputs(),
            node.getPreSortedOrderPrefix());
    return new PlanWithProperties(rewritten, child.getHashVariables());
}
/**
 * Plans an ExchangeNode so that every hash the parent prefers (restricted to this node's
 * outputs) is carried across the exchange as an extra output column.
 *
 * For a FIXED_HASH_DISTRIBUTION partitioning whose arguments are all variables, the hash over
 * those arguments is added to the preference and becomes the partitioning scheme's precomputed
 * hash column. Each source is planned with the preference translated to that source's input
 * variables and is required to produce every translated hash.
 */
@Override
public PlanWithProperties visitExchange(ExchangeNode node, HashComputationSet parentPreference)
{
// remove any hash variables not exported by this node
HashComputationSet preference = parentPreference.pruneVariables(node.getOutputVariables());
// Currently, precomputed hash values are only supported for system hash distributions without constants
Optional<HashComputation> partitionVariables = Optional.empty();
PartitioningScheme partitioningScheme = node.getPartitioningScheme();
if (partitioningScheme.getPartitioning().getHandle().equals(FIXED_HASH_DISTRIBUTION) &&
partitioningScheme.getPartitioning().getArguments().stream().allMatch(VariableReferenceExpression.class::isInstance)) {
// add precomputed hash for exchange
partitionVariables = computeHash(
partitioningScheme.getPartitioning().getArguments().stream()
.map(VariableReferenceExpression.class::cast)
.collect(toImmutableList()),
functionAndTypeManager);
preference = preference.withHashComputation(partitionVariables);
}
// establish fixed ordering for hash variables: the same list drives both the output layout
// below and the extra inputs appended for every source, so they line up positionally
List<HashComputation> hashVariableOrder = ImmutableList.copyOf(preference.getHashes());
Map<HashComputation, VariableReferenceExpression> newHashVariables = new HashMap<>();
for (HashComputation preferredHashVariable : hashVariableOrder) {
newHashVariables.put(preferredHashVariable, variableAllocator.newHashVariable());
}
// rewrite partition function to include new variables (and precomputed hash)
partitioningScheme = new PartitioningScheme(
partitioningScheme.getPartitioning(),
ImmutableList.<VariableReferenceExpression>builder()
.addAll(partitioningScheme.getOutputLayout())
.addAll(hashVariableOrder.stream()
.map(newHashVariables::get)
.collect(toImmutableList()))
.build(),
partitionVariables.map(newHashVariables::get),
partitioningScheme.isReplicateNullsAndAny(),
partitioningScheme.isScaleWriters(),
partitioningScheme.getEncoding(),
partitioningScheme.getBucketToPartition());
// add hash variables to sources
ImmutableList.Builder<List<VariableReferenceExpression>> newInputs = ImmutableList.builder();
ImmutableList.Builder<PlanNode> newSources = ImmutableList.builder();
for (int sourceId = 0; sourceId < node.getSources().size(); sourceId++) {
PlanNode source = node.getSources().get(sourceId);
List<VariableReferenceExpression> inputVariables = node.getInputs().get(sourceId);
// inputs of this source are positionally aligned with the exchange's output variables
Map<VariableReferenceExpression, VariableReferenceExpression> outputToInputMap = new HashMap<>();
for (int variableId = 0; variableId < inputVariables.size(); variableId++) {
outputToInputMap.put(node.getOutputVariables().get(variableId), inputVariables.get(variableId));
}
// NOTE(review): Optional.of throws if a variable has no mapping; this relies on every preferred
// variable being an exchange output (the preference was pruned above, and the partitioning
// arguments are presumably outputs as well)
Function<VariableReferenceExpression, Optional<VariableReferenceExpression>> outputToInputTranslator = variable -> Optional.of(outputToInputMap.get(variable));
HashComputationSet sourceContext = preference.translate(outputToInputTranslator);
PlanWithProperties child = planAndEnforce(source, sourceContext, true, sourceContext);
newSources.add(child.getNode());
// add hash variables to inputs in the required order
ImmutableList.Builder<VariableReferenceExpression> newInputVariables = ImmutableList.builder();
newInputVariables.addAll(inputVariables);
for (HashComputation preferredHashSymbol : hashVariableOrder) {
HashComputation hashComputation = preferredHashSymbol.translate(outputToInputTranslator).get();
newInputVariables.add(child.getRequiredHashVariable(hashComputation));
}
newInputs.add(newInputVariables.build());
}
return new PlanWithProperties(
new ExchangeNode(
node.getSourceLocation(),
node.getId(),
node.getType(),
node.getScope(),
partitioningScheme,
newSources.build(),
newInputs.build(),
node.isEnsureSourceOrdering(),
node.getOrderingScheme()),
newHashVariables);
}
/**
 * Plans a UnionNode so that every hash the parent prefers (restricted to this node's outputs)
 * becomes a new union output. Each source is planned with the preference translated to its own
 * input variables and is required to produce every translated hash; that source's hash
 * variable is appended as the input for the corresponding new output.
 */
@Override
public PlanWithProperties visitUnion(UnionNode node, HashComputationSet parentPreference)
{
// remove any hash variables not exported by this node
HashComputationSet preference = parentPreference.pruneVariables(node.getOutputVariables());
// create new hash variables
Map<HashComputation, VariableReferenceExpression> newHashVariables = new HashMap<>();
for (HashComputation preferredHashSymbol : preference.getHashes()) {
newHashVariables.put(preferredHashSymbol, variableAllocator.newHashVariable());
}
// add hash variables to sources
// start from the existing output -> per-source inputs mapping; new hash outputs are added after it
ImmutableListMultimap.Builder<VariableReferenceExpression, VariableReferenceExpression> newVariableMapping = ImmutableListMultimap.builder();
node.getVariableMapping().forEach(newVariableMapping::putAll);
ImmutableList.Builder<PlanNode> newSources = ImmutableList.builder();
for (int sourceId = 0; sourceId < node.getSources().size(); sourceId++) {
// translate preference to input variables
// (note: the loop variable below holds a single output variable despite its plural name)
Map<VariableReferenceExpression, VariableReferenceExpression> outputToInputMap = new HashMap<>();
for (VariableReferenceExpression outputVariables : node.getOutputVariables()) {
outputToInputMap.put(outputVariables, node.getVariableMapping().get(outputVariables).get(sourceId));
}
Function<VariableReferenceExpression, Optional<VariableReferenceExpression>> outputToInputTranslator = variable -> Optional.of(outputToInputMap.get(variable));
HashComputationSet sourcePreference = preference.translate(outputToInputTranslator);
PlanWithProperties child = planAndEnforce(node.getSources().get(sourceId), sourcePreference, true, sourcePreference);
newSources.add(child.getNode());
// add hash variables to inputs
// each new hash output collects one input per source, appended in source order
for (Entry<HashComputation, VariableReferenceExpression> entry : newHashVariables.entrySet()) {
HashComputation hashComputation = entry.getKey().translate(outputToInputTranslator).get();
newVariableMapping.put(entry.getValue(), child.getRequiredHashVariable(hashComputation));
}
}
// the multimap's key order (original outputs, then new hash outputs) becomes the union's output order
ListMultimap<VariableReferenceExpression, VariableReferenceExpression> outputsToInputs = newVariableMapping.build();
return new PlanWithProperties(
new UnionNode(
node.getSourceLocation(),
node.getId(),
newSources.build(),
ImmutableList.copyOf(outputsToInputs.keySet()),
fromListMultimap(outputsToInputs)),
newHashVariables);
}
/**
 * Plans a ProjectNode. The parent's preference is translated to source variables through
 * computeIdentityTranslations (presumably the assignments that are plain variable references;
 * hashes that cannot be translated are dropped). Each translated hash is forwarded from the
 * child if the child already produces it, or otherwise computed here into a fresh hash variable.
 *
 * For a REMOTE projection with hash assignments, the hashes are computed in a new LOCAL
 * projection placed beneath it, and the remote projection forwards them by identity.
 */
@Override
public PlanWithProperties visitProject(ProjectNode node, HashComputationSet parentPreference)
{
Map<VariableReferenceExpression, VariableReferenceExpression> outputToInputMapping = computeIdentityTranslations(node.getAssignments().getMap());
HashComputationSet sourceContext = parentPreference.translate(variable -> Optional.ofNullable(outputToInputMapping.get(variable)));
PlanWithProperties child = plan(node.getSource(), sourceContext);
// create a new project node with all assignments from the original node
Assignments.Builder newAssignments = Assignments.builder();
newAssignments.putAll(node.getAssignments());
// and all hash variables that could be translated to the source variables
Map<VariableReferenceExpression, RowExpression> hashAssignments = new HashMap<>();
Map<HashComputation, VariableReferenceExpression> allHashVariables = new HashMap<>();
for (HashComputation hashComputation : sourceContext.getHashes()) {
VariableReferenceExpression hashVariable = child.getHashVariables().get(hashComputation);
RowExpression hashExpression;
if (hashVariable == null) {
// child does not produce this hash: compute it here into a new variable
hashVariable = variableAllocator.newHashVariable();
hashExpression = hashComputation.getHashExpression();
}
else {
// child already produces this hash: pass it through by identity
hashExpression = hashVariable;
}
hashAssignments.put(hashVariable, hashExpression);
allHashVariables.put(hashComputation, hashVariable);
}
if (node.getLocality().equals(REMOTE) && !hashAssignments.isEmpty()) {
// if the ProjectNode is remote, create a local projection with identity projection and hash
Assignments.Builder localProjectionAssignments = Assignments.builder();
child.getNode().getOutputVariables().forEach(variable -> localProjectionAssignments.put(variable, variable));
localProjectionAssignments.putAll(hashAssignments);
ProjectNode localProjectNode = new ProjectNode(child.getNode().getSourceLocation(), idAllocator.getNextId(), child.getNode(), localProjectionAssignments.build(), LOCAL);
// add identity projection for hash variable to remote projection
hashAssignments.keySet().forEach(variable -> newAssignments.put(variable, variable));
return new PlanWithProperties(new ProjectNode(localProjectNode.getSourceLocation(), idAllocator.getNextId(), localProjectNode, newAssignments.build(), REMOTE), allHashVariables);
}
newAssignments.putAll(hashAssignments);
return new PlanWithProperties(new ProjectNode(node.getSourceLocation(), node.getId(), child.getNode(), newAssignments.build(), node.getLocality()), allHashVariables);
}
@Override
public PlanWithProperties visitUnnest(UnnestNode node, HashComputationSet parentPreference)
{
    // The source can only be asked for hashes over variables it actually outputs.
    HashComputationSet sourcePreference = parentPreference.pruneVariables(node.getSource().getOutputVariables());
    PlanWithProperties source = plan(node.getSource(), sourcePreference);

    // Expose only the hash variables the parent requested; any other hash the source produced is dropped.
    Map<HashComputation, VariableReferenceExpression> forwardedHashes = new HashMap<>(source.getHashVariables());
    forwardedHashes.keySet().removeIf(hash -> !parentPreference.getHashes().contains(hash));

    // Forwarded hash variables pass through the unnest as extra replicate variables, after the original ones.
    List<VariableReferenceExpression> replicateVariables = ImmutableList.<VariableReferenceExpression>builder()
            .addAll(node.getReplicateVariables())
            .addAll(forwardedHashes.values())
            .build();
    UnnestNode rewritten = new UnnestNode(
            node.getSourceLocation(),
            node.getId(),
            source.getNode(),
            replicateVariables,
            node.getUnnestVariables(),
            node.getOrdinalityVariable());
    return new PlanWithProperties(rewritten, forwardedHashes);
}
private PlanWithProperties planSimpleNodeWithProperties(PlanNode node, HashComputationSet preferredHashes)
{
    // Convenience overload: always passes true for alwaysPruneExtraHashVariables.
    boolean pruneExtraHashVariables = true;
    return planSimpleNodeWithProperties(node, preferredHashes, pruneExtraHashVariables);
}
private PlanWithProperties planSimpleNodeWithProperties(
        PlanNode node,
        HashComputationSet preferredHashes,
        boolean alwaysPruneExtraHashVariables)
{
    // A leaf has no source to take hash variables from.
    if (node.getSources().isEmpty()) {
        return new PlanWithProperties(node, ImmutableMap.of());
    }

    // There is no requirement to produce hash variables for the single source, only a preference.
    PlanNode onlySource = Iterables.getOnlyElement(node.getSources());
    PlanWithProperties plannedSource = planAndEnforce(onlySource, new HashComputationSet(), alwaysPruneExtraHashVariables, preferredHashes);
    PlanNode rewritten = replaceChildren(node, ImmutableList.of(plannedSource.getNode()));

    // Report only the hash variables that the rewritten node still outputs.
    List<VariableReferenceExpression> outputs = rewritten.getOutputVariables();
    Map<HashComputation, VariableReferenceExpression> survivingHashes = new HashMap<>(plannedSource.getHashVariables());
    survivingHashes.values().removeIf(variable -> !outputs.contains(variable));
    return new PlanWithProperties(rewritten, survivingHashes);
}
private PlanWithProperties planAndEnforce(
        PlanNode node,
        HashComputationSet requiredHashes,
        boolean pruneExtraHashVariables,
        HashComputationSet preferredHashes)
{
    PlanWithProperties result = plan(node, preferredHashes);
    Set<HashComputation> produced = result.getHashVariables().keySet();

    // Every required hash must be present in the planned result.
    boolean satisfied = produced.containsAll(requiredHashes.getHashes());
    if (satisfied && pruneExtraHashVariables) {
        // When pruning, every produced hash must also be required or preferred (i.e. there is nothing to prune).
        Set<HashComputation> allowed = ImmutableSet.<HashComputation>builder()
                .addAll(requiredHashes.getHashes())
                .addAll(preferredHashes.getHashes())
                .build();
        satisfied = allowed.containsAll(produced);
    }

    // Otherwise add a projection that drops unneeded hashes and computes missing required ones.
    return satisfied ? result : enforce(result, requiredHashes);
}
private PlanWithProperties enforce(PlanWithProperties planWithProperties, HashComputationSet requiredHashes)
{
    Set<HashComputation> required = requiredHashes.getHashes();
    Map<HashComputation, VariableReferenceExpression> availableHashes = planWithProperties.getHashVariables();
    // Reverse view: which hash computation (if any) each output variable carries.
    Map<VariableReferenceExpression, HashComputation> hashByVariable = planWithProperties.getHashVariables().inverse();

    Assignments.Builder assignments = Assignments.builder();
    Map<HashComputation, VariableReferenceExpression> outputHashVariables = new HashMap<>();

    // Forward every source column as an identity, except hash columns the caller does not require.
    for (VariableReferenceExpression variable : planWithProperties.getNode().getOutputVariables()) {
        HashComputation carriedHash = hashByVariable.get(variable);
        boolean unneededHash = carriedHash != null && !required.contains(carriedHash);
        if (unneededHash) {
            continue;
        }
        assignments.put(variable, variable);
        if (carriedHash != null) {
            outputHashVariables.put(carriedHash, variable);
        }
    }

    // Compute every required hash that the source does not already provide.
    for (HashComputation hashComputation : required) {
        if (availableHashes.containsKey(hashComputation)) {
            continue;
        }
        RowExpression hashExpression = hashComputation.getHashExpression();
        VariableReferenceExpression hashVariable = variableAllocator.newHashVariable();
        assignments.put(hashVariable, hashExpression);
        outputHashVariables.put(hashComputation, hashVariable);
    }

    ProjectNode projectNode = new ProjectNode(
            planWithProperties.getNode().getSourceLocation(),
            idAllocator.getNextId(),
            planWithProperties.getNode().getStatsEquivalentPlanNode(),
            planWithProperties.getNode(),
            assignments.build(),
            LOCAL);
    return new PlanWithProperties(projectNode, outputHashVariables);
}
/**
 * Plans {@code node} under the given hash preference and checks that the result is self-consistent.
 *
 * @throws IllegalStateException if the rewritten node declares a hash variable that it does not output
 */
private PlanWithProperties plan(PlanNode node, HashComputationSet parentPreference)
{
    // accept() already assigns node's stats-equivalent plan node to the rewritten node, so the result is
    // returned as-is rather than copying the node a second time with the same value.
    PlanWithProperties result = accept(node, parentPreference);
    checkState(
            result.getNode().getOutputVariables().containsAll(result.getHashVariables().values()),
            "Node %s declares hash variables not in the output",
            result.getNode().getClass().getSimpleName());
    return result;
}
private PlanWithProperties accept(PlanNode node, HashComputationSet context)
{
    // Dispatch to the matching visit method, then carry the original node's stats-equivalent plan node over.
    PlanWithProperties visited = node.accept(this, context);
    PlanNode withStats = visited.getNode().assignStatsEquivalentPlanNode(node.getStatsEquivalentPlanNode());
    return new PlanWithProperties(withStats, visited.getHashVariables());
}
}
private static class HashComputationSet
{
private final Set<HashComputation> hashes;
public HashComputationSet()
{
hashes = ImmutableSet.of();
}
public HashComputationSet(Optional<HashComputation> hash)
{
requireNonNull(hash, "hash is null");
if (hash.isPresent()) {
this.hashes = ImmutableSet.of(hash.get());
}
else {
this.hashes = ImmutableSet.of();
}
}
private HashComputationSet(Iterable<HashComputation> hashes)
{
requireNonNull(hashes, "hashes is null");
this.hashes = ImmutableSet.copyOf(hashes);
}
public Set<HashComputation> getHashes()
{
return hashes;
}
public HashComputationSet pruneVariables(List<VariableReferenceExpression> variables)
{
Set<VariableReferenceExpression> uniqueVariables = ImmutableSet.copyOf(variables);
return new HashComputationSet(hashes.stream()
.filter(hash -> hash.canComputeWith(uniqueVariables))
.collect(toImmutableSet()));
}
public HashComputationSet translate(Function<VariableReferenceExpression, Optional<VariableReferenceExpression>> translator)
{
Set<HashComputation> newHashes = hashes.stream()
.map(hash -> hash.translate(translator))
.filter(Optional::isPresent)
.map(Optional::get)
.collect(toImmutableSet());
return new HashComputationSet(newHashes);
}
public HashComputationSet withHashComputation(PlanNode node, Optional<HashComputation> hashComputation)
{
return pruneVariables(node.getOutputVariables()).withHashComputation(hashComputation);
}
public HashComputationSet withHashComputation(Optional<HashComputation> hashComputation)
{
if (!hashComputation.isPresent()) {
return this;
}
return new HashComputationSet(ImmutableSet.<HashComputation>builder()
.addAll(hashes)
.add(hashComputation.get())
.build());
}
}
/**
 * Builds a hash computation over {@code fields}, in iteration order.
 *
 * @param fields the variables to hash; must not be null
 * @param functionAndTypeManager used by the resulting computation to resolve hash functions
 * @return empty when {@code fields} yields no variables (HashComputation rejects an empty field list),
 *         otherwise the computation over those variables
 */
private static Optional<HashComputation> computeHash(Iterable<VariableReferenceExpression> fields, FunctionAndTypeManager functionAndTypeManager)
{
    requireNonNull(fields, "fields is null");
    // Materialize once and reuse the copy: the argument is only an Iterable, so iterating it a second time
    // (once for the emptiness check, again inside the constructor) is not guaranteed to see the same elements.
    List<VariableReferenceExpression> variables = ImmutableList.copyOf(fields);
    if (variables.isEmpty()) {
        return Optional.empty();
    }
    return Optional.of(new HashComputation(variables, functionAndTypeManager));
}
/**
 * A hash over an ordered, non-empty list of variables.
 * <p>
 * {@code equals} and {@code hashCode} consider only the field list, so two computations over the same
 * variables in the same order are equal regardless of the {@link FunctionAndTypeManager} they hold.
 */
private static class HashComputation
{
    private final List<VariableReferenceExpression> fields;
    // Used to resolve the functions in getHashExpression(); not part of equals/hashCode.
    private final FunctionAndTypeManager functionAndTypeManager;

    private HashComputation(Iterable<VariableReferenceExpression> fields, FunctionAndTypeManager functionAndTypeManager)
    {
        requireNonNull(fields, "fields is null");
        this.fields = ImmutableList.copyOf(fields);
        this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
        checkArgument(!this.fields.isEmpty(), "fields can not be empty");
    }

    /**
     * Rewrites every field through {@code translator}.
     *
     * @return the translated computation, or empty if any field has no translation
     */
    public Optional<HashComputation> translate(Function<VariableReferenceExpression, Optional<VariableReferenceExpression>> translator)
    {
        ImmutableList.Builder<VariableReferenceExpression> newVariables = ImmutableList.builder();
        for (VariableReferenceExpression field : fields) {
            Optional<VariableReferenceExpression> newVariable = translator.apply(field);
            if (!newVariable.isPresent()) {
                return Optional.empty();
            }
            newVariables.add(newVariable.get());
        }
        return computeHash(newVariables.build(), functionAndTypeManager);
    }

    /** Returns true when every field of this hash is contained in {@code availableFields}. */
    public boolean canComputeWith(Set<VariableReferenceExpression> availableFields)
    {
        return availableFields.containsAll(fields);
    }

    /** Returns true when this hash is over exactly one variable, and that variable is BIGINT. */
    public boolean isSingleBigIntVariable()
    {
        return fields.size() == 1 && Iterables.getOnlyElement(fields).getType().equals(BIGINT);
    }

    /**
     * Builds the BIGINT hash expression: starting from INITIAL_HASH_VALUE, each field's hash code
     * (wrapped by orNullHashCode) is folded in with combine_hash, in field order.
     */
    private RowExpression getHashExpression()
    {
        RowExpression hashExpression = constant(INITIAL_HASH_VALUE, BIGINT);
        for (VariableReferenceExpression field : fields) {
            hashExpression = getHashFunctionCall(hashExpression, field);
        }
        return hashExpression;
    }

    /** Returns {@code combine_hash(previousHashValue, orNullHashCode(HASH_CODE(variable)))}. */
    private RowExpression getHashFunctionCall(RowExpression previousHashValue, VariableReferenceExpression variable)
    {
        CallExpression functionCall = call(functionAndTypeManager, HASH_CODE, BIGINT, variable);
        return call(functionAndTypeManager, "combine_hash", BIGINT, previousHashValue, orNullHashCode(functionCall));
    }

    @Override
    public boolean equals(Object o)
    {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        HashComputation that = (HashComputation) o;
        return Objects.equals(fields, that.fields);
    }

    @Override
    public int hashCode()
    {
        return Objects.hash(fields);
    }

    @Override
    public String toString()
    {
        return toStringHelper(this)
                .add("fields", fields)
                .toString();
    }
}
/**
 * A rewritten plan node paired with the hash variables it exposes, keyed by the hash each variable holds.
 */
private static class PlanWithProperties
{
    private final PlanNode node;
    // Copied into an ImmutableBiMap, so each hash computation maps to a distinct variable and vice versa.
    private final BiMap<HashComputation, VariableReferenceExpression> hashVariables;

    public PlanWithProperties(PlanNode node, Map<HashComputation, VariableReferenceExpression> hashVariables)
    {
        requireNonNull(node, "node is null");
        requireNonNull(hashVariables, "hashVariables is null");
        this.node = node;
        this.hashVariables = ImmutableBiMap.copyOf(hashVariables);
    }

    public PlanNode getNode()
    {
        return node;
    }

    public BiMap<HashComputation, VariableReferenceExpression> getHashVariables()
    {
        return hashVariables;
    }

    /**
     * Returns the variable holding {@code hash}.
     *
     * @throws NullPointerException if no variable was generated for {@code hash}
     */
    public VariableReferenceExpression getRequiredHashVariable(HashComputation hash)
    {
        return requireNonNull(hashVariables.get(hash), () -> "No hash variable generated for " + hash);
    }
}
/**
 * Maps each output variable whose assignment is a bare variable reference to that input variable.
 * Assignments with any other kind of expression are omitted. Returns a fresh, mutable {@link HashMap}.
 */
private static Map<VariableReferenceExpression, VariableReferenceExpression> computeIdentityTranslations(Map<VariableReferenceExpression, RowExpression> assignments)
{
    Map<VariableReferenceExpression, VariableReferenceExpression> translations = new HashMap<>();
    assignments.forEach((output, expression) -> {
        if (expression instanceof VariableReferenceExpression) {
            translations.put(output, (VariableReferenceExpression) expression);
        }
    });
    return translations;
}
}