AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;
import com.facebook.presto.Session;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.TaskCountEstimator;
import com.facebook.presto.cost.VariableStatsEstimate;
import com.facebook.presto.execution.TaskManagerConfig;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Partitioning;
import com.facebook.presto.spi.plan.PartitioningScheme;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties;
import com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.GroupIdNode;
import com.facebook.presto.sql.relational.ProjectNodeUtils;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multiset;
import io.airlift.units.DataSize;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import static com.facebook.presto.SystemSessionProperties.getTaskConcurrency;
import static com.facebook.presto.SystemSessionProperties.isEnabledAddExchangeBelowGroupId;
import static com.facebook.presto.matching.Capture.newCapture;
import static com.facebook.presto.matching.Pattern.nonEmpty;
import static com.facebook.presto.matching.Pattern.typeOf;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
import static com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties.fixedParallelism;
import static com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.deriveProperties;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_STREAMING;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.partitionedExchange;
import static com.facebook.presto.sql.planner.plan.Patterns.Aggregation.groupingColumns;
import static com.facebook.presto.sql.planner.plan.Patterns.Aggregation.step;
import static com.facebook.presto.sql.planner.plan.Patterns.source;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMultiset.toImmutableMultiset;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.lang.Double.isNaN;
import static java.lang.Math.min;
import static java.util.Comparator.comparing;
import static java.util.Objects.requireNonNull;
/**
* Transforms
* <pre>
* - Exchange
* - [ Projection ]
* - Partial Aggregation
* - GroupId
* </pre>
* to
* <pre>
* - Exchange
* - [ Projection ]
* - Partial Aggregation
* - GroupId
* - LocalExchange
* - RemoteExchange
* </pre>
* <p>
* Rationale: GroupId increases number of rows (number of times equal to number of grouping sets) and then
* partial aggregation reduces number of rows. However, under certain conditions, exchanging the rows before
* GroupId (before multiplication) makes partial aggregation more effective, resulting in less data being
* exchanged afterwards.
*/
public class AddExchangesBelowPartialAggregationOverGroupIdRuleSet
{
private static final Capture<ProjectNode> PROJECTION = newCapture();
private static final Capture<AggregationNode> AGGREGATION = newCapture();
private static final Capture<GroupIdNode> GROUP_ID = newCapture();
private static final Capture<ExchangeNode> REMOTE_EXCHANGE = newCapture();
private static final Pattern<ExchangeNode> WITH_PROJECTION =
// If there was no exchange here, adding new exchanges could break property derivations logic of AddExchanges, AddLocalExchanges
typeOf(ExchangeNode.class)
.matching(e -> e.getScope().isRemote()).capturedAs(REMOTE_EXCHANGE)
.with(source().matching(
// PushPartialAggregationThroughExchange adds a projection. However, it can be removed if RemoveRedundantIdentityProjections is run in the mean-time.
typeOf(ProjectNode.class).matching(ProjectNodeUtils::isIdentity).capturedAs(PROJECTION)
.with(source().matching(
typeOf(AggregationNode.class).capturedAs(AGGREGATION)
.with(step().equalTo(AggregationNode.Step.PARTIAL))
.with(nonEmpty(groupingColumns()))
.with(source().matching(
typeOf(GroupIdNode.class).capturedAs(GROUP_ID)))))));
private static final Pattern<ExchangeNode> WITHOUT_PROJECTION =
// If there was no exchange here, adding new exchanges could break property derivations logic of AddExchanges, AddLocalExchanges
typeOf(ExchangeNode.class)
.matching(e -> e.getScope().isRemote()).capturedAs(REMOTE_EXCHANGE)
.with(source().matching(
typeOf(AggregationNode.class).capturedAs(AGGREGATION)
.with(step().equalTo(AggregationNode.Step.PARTIAL))
.with(nonEmpty(groupingColumns()))
.with(source().matching(
typeOf(GroupIdNode.class).capturedAs(GROUP_ID)))));
private static final double GROUPING_SETS_SYMBOL_REQUIRED_FREQUENCY = 0.5;
private static final double ANTI_SKEWNESS_MARGIN = 3;
private final TaskCountEstimator taskCountEstimator;
private final DataSize maxPartialAggregationMemoryUsage;
private final Metadata metadata;
private final boolean nativeExecution;
public AddExchangesBelowPartialAggregationOverGroupIdRuleSet(
TaskCountEstimator taskCountEstimator,
TaskManagerConfig taskManagerConfig,
Metadata metadata,
boolean nativeExecution)
{
this.taskCountEstimator = requireNonNull(taskCountEstimator, "taskCountEstimator is null");
this.maxPartialAggregationMemoryUsage = taskManagerConfig.getMaxPartialAggregationMemoryUsage();
this.metadata = metadata;
this.nativeExecution = nativeExecution;
}
public Set<Rule<?>> rules()
{
return ImmutableSet.of(
belowProjectionRule(),
belowExchangeRule());
}
@VisibleForTesting
AddExchangesBelowExchangePartialAggregationGroupId belowExchangeRule()
{
return new AddExchangesBelowExchangePartialAggregationGroupId();
}
@VisibleForTesting
AddExchangesBelowProjectionPartialAggregationGroupId belowProjectionRule()
{
return new AddExchangesBelowProjectionPartialAggregationGroupId();
}
@VisibleForTesting
class AddExchangesBelowProjectionPartialAggregationGroupId
extends BaseAddExchangesBelowExchangePartialAggregationGroupId
{
@Override
public Pattern<ExchangeNode> getPattern()
{
return WITH_PROJECTION;
}
@Override
public Result apply(ExchangeNode exchange, Captures captures, Context context)
{
ProjectNode project = captures.get(PROJECTION);
AggregationNode aggregation = captures.get(AGGREGATION);
GroupIdNode groupId = captures.get(GROUP_ID);
return transform(aggregation, groupId, context)
.map(newAggregation -> Result.ofPlanNode(
exchange.replaceChildren(ImmutableList.of(
project.replaceChildren(ImmutableList.of(
newAggregation))))))
.orElseGet(Result::empty);
}
}
@VisibleForTesting
class AddExchangesBelowExchangePartialAggregationGroupId
extends BaseAddExchangesBelowExchangePartialAggregationGroupId
{
@Override
public Pattern<ExchangeNode> getPattern()
{
return WITHOUT_PROJECTION;
}
@Override
public Result apply(ExchangeNode exchange, Captures captures, Context context)
{
AggregationNode aggregation = captures.get(AGGREGATION);
GroupIdNode groupId = captures.get(GROUP_ID);
return transform(aggregation, groupId, context)
.map(newAggregation -> {
PlanNode newExchange = exchange.replaceChildren(ImmutableList.of(newAggregation));
return Result.ofPlanNode(newExchange);
})
.orElseGet(Result::empty);
}
}
private abstract class BaseAddExchangesBelowExchangePartialAggregationGroupId
implements Rule<ExchangeNode>
{
@Override
public boolean isEnabled(Session session)
{
return isEnabledAddExchangeBelowGroupId(session);
}
protected Optional<PlanNode> transform(AggregationNode aggregation, GroupIdNode groupId, Context context)
{
Set<VariableReferenceExpression> groupingKeys = aggregation.getGroupingKeys().stream()
.filter(symbol -> !groupId.getGroupIdVariable().equals(symbol))
.collect(toImmutableSet());
Multiset<VariableReferenceExpression> groupingSetHistogram = groupId.getGroupingSets().stream()
.flatMap(Collection::stream)
.collect(toImmutableMultiset());
if (!Objects.equals(groupingSetHistogram.elementSet(), groupingKeys)) {
// TODO handle the case when some aggregation keys are pass-through in GroupId (e.g. common in all grouping sets)
// TODO handle the case when some grouping set symbols are not used in aggregation (possible?)
return Optional.empty();
}
double aggregationMemoryRequirements = estimateAggregationMemoryRequirements(groupingKeys, groupId, groupingSetHistogram, context);
if (isNaN(aggregationMemoryRequirements) || aggregationMemoryRequirements < maxPartialAggregationMemoryUsage.toBytes()) {
// Aggregation will be effective even without exchanges (or we have insufficient information).
return Optional.empty();
}
List<VariableReferenceExpression> desiredHashVariables = groupingSetHistogram.entrySet().stream()
// Take only frequently used symbols
.filter(entry -> entry.getCount() >= groupId.getGroupingSets().size() * GROUPING_SETS_SYMBOL_REQUIRED_FREQUENCY)
.map(Multiset.Entry::getElement)
// And only the symbols used in the aggregation (these are usually all symbols)
.peek(symbol -> verify(groupingKeys.contains(symbol), "%s not found in the grouping keys [%s]", symbol, groupingKeys))
// Transform to symbols before GroupId
.map(groupId.getGroupingColumns()::get)
.collect(toImmutableList());
// Use only the symbol with the highest cardinality (if we have statistics). This makes partial aggregation more efficient in case of
// low correlation between symbol that are in every grouping set vs additional symbols.
PlanNodeStatsEstimate sourceStats = context.getStatsProvider().getStats(groupId.getSource());
desiredHashVariables = desiredHashVariables.stream()
.filter(symbol -> !isNaN(sourceStats.getVariableStatistics(symbol).getDistinctValuesCount()))
.max(comparing(symbol -> sourceStats.getVariableStatistics(symbol).getDistinctValuesCount()))
.map(symbol -> (List<VariableReferenceExpression>) ImmutableList.of(symbol)).orElse(desiredHashVariables);
StreamPreferredProperties requiredProperties = fixedParallelism().withPartitioning(desiredHashVariables);
StreamProperties sourceProperties = derivePropertiesRecursively(groupId.getSource(), context);
if (requiredProperties.isSatisfiedBy(sourceProperties)) {
// Stream is already (locally) partitioned just as we want.
// In fact, there might be just a LocalExchange below and no Remote. For now, we give up in this situation anyway. To properly support such situation:
// 1. aggregation effectiveness estimation below need to consider the (helpful) fact that stream is already partitioned, so each operator will need less memory
// 2. if the local exchange becomes unnecessary (after we add a remove on top of it), it should be removed. What if the local exchange is somewhere further
// down the tree?
return Optional.empty();
}
double estimatedGroups = estimateGroupCount(desiredHashVariables, context.getStatsProvider().getStats(groupId.getSource()));
if (isNaN(estimatedGroups) || estimatedGroups * ANTI_SKEWNESS_MARGIN < maximalConcurrencyAfterRepartition(context)) {
// Desired hash symbols form too few groups. Hashing over them would harm concurrency.
// TODO instead of taking symbols with >GROUPING_SETS_SYMBOL_REQUIRED_FREQUENCY presence, we could take symbols from high freq to low until there are enough groups
return Optional.empty();
}
PlanNode source = groupId.getSource();
// Above we only checked the data is not yet locally partitioned and it could be already globally partitioned (but not locally). TODO avoid remote exchange in this case
// TODO If the aggregation memory requirements are only slightly above `maxPartialAggregationMemoryUsage`, adding only LocalExchange could be enough
source = partitionedExchange(
context.getIdAllocator().getNextId(),
REMOTE_STREAMING,
source,
new PartitioningScheme(
Partitioning.create(FIXED_HASH_DISTRIBUTION, desiredHashVariables),
source.getOutputVariables()));
source = partitionedExchange(
context.getIdAllocator().getNextId(),
LOCAL,
source,
new PartitioningScheme(
Partitioning.create(FIXED_HASH_DISTRIBUTION, desiredHashVariables),
source.getOutputVariables()));
PlanNode newGroupId = groupId.replaceChildren(ImmutableList.of(source));
PlanNode newAggregation = aggregation.replaceChildren(ImmutableList.of(newGroupId));
return Optional.of(newAggregation);
}
private int maximalConcurrencyAfterRepartition(Context context)
{
return getTaskConcurrency(context.getSession()) * taskCountEstimator.estimateHashedTaskCount(context.getSession());
}
private double estimateAggregationMemoryRequirements(Set<VariableReferenceExpression> groupingKeys,
GroupIdNode groupId,
Multiset<VariableReferenceExpression> groupingSetHistogram,
Context context)
{
checkArgument(Objects.equals(groupingSetHistogram.elementSet(), groupingKeys)); // Otherwise math below would be off-topic
PlanNodeStatsEstimate sourceStats = context.getStatsProvider().getStats(groupId.getSource());
double keysMemoryRequirements = 0;
for (List<VariableReferenceExpression> groupingSet : groupId.getGroupingSets()) {
List<VariableReferenceExpression> sourceVariables = groupingSet.stream()
.map(groupId.getGroupingColumns()::get)
.collect(toImmutableList());
double keyWidth = sourceStats.getOutputSizeForVariables(sourceVariables) / sourceStats.getOutputRowCount();
double keyNdv = min(estimateGroupCount(sourceVariables, sourceStats), sourceStats.getOutputRowCount());
keysMemoryRequirements += keyWidth * keyNdv;
}
// TODO consider also memory requirements for aggregation values
return keysMemoryRequirements;
}
private double estimateGroupCount(List<VariableReferenceExpression> variables, PlanNodeStatsEstimate statsEstimate)
{
return variables.stream()
.map(statsEstimate::getVariableStatistics)
.mapToDouble(this::ndvIncludingNull)
// This assumes no correlation, maximum number of aggregation keys
.reduce(1, (a, b) -> a * b);
}
private double ndvIncludingNull(VariableStatsEstimate variableStatsEstimate)
{
if (variableStatsEstimate.getNullsFraction() == 0.) {
return variableStatsEstimate.getDistinctValuesCount();
}
return variableStatsEstimate.getDistinctValuesCount() + 1;
}
private StreamProperties derivePropertiesRecursively(PlanNode node, Context context)
{
PlanNode resolvedPlanNode = context.getLookup().resolve(node);
List<StreamProperties> inputProperties = resolvedPlanNode.getSources().stream()
.map(source -> derivePropertiesRecursively(source, context))
.collect(toImmutableList());
return deriveProperties(resolvedPlanNode, inputProperties, metadata, context.getSession(), nativeExecution);
}
}
}