AddIntermediateAggregations.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;
import com.facebook.presto.Session;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.AggregationNode.Aggregation;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.Map;
import java.util.Optional;
import static com.facebook.presto.SystemSessionProperties.getTaskConcurrency;
import static com.facebook.presto.SystemSessionProperties.isEnableIntermediateAggregations;
import static com.facebook.presto.matching.Pattern.empty;
import static com.facebook.presto.spi.plan.AggregationNode.Step.FINAL;
import static com.facebook.presto.spi.plan.AggregationNode.Step.INTERMEDIATE;
import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL;
import static com.facebook.presto.sql.planner.optimizations.AggregationNodeUtils.extractAggregationUniqueVariables;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.gatheringExchange;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.roundRobinExchange;
import static com.facebook.presto.sql.planner.plan.Patterns.Aggregation.groupingColumns;
import static com.facebook.presto.sql.planner.plan.Patterns.Aggregation.step;
import static com.facebook.presto.sql.planner.plan.Patterns.aggregation;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.Iterables.getOnlyElement;
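/**
 * Adds INTERMEDIATE aggregations between an un-grouped FINAL aggregation and its preceding
 * PARTIAL aggregation.
 * <p>
 * From:
 * <pre>
 * - Aggregation (FINAL)
 *   - RemoteExchange (GATHER)
 *     - Aggregation (PARTIAL)
 * </pre>
 * To:
 * <pre>
 * - Aggregation (FINAL)
 *   - LocalExchange (GATHER)
 *     - Aggregation (INTERMEDIATE)
 *       - LocalExchange (ARBITRARY)
 *         - RemoteExchange (GATHER)
 *           - Aggregation (INTERMEDIATE)
 *             - LocalExchange (GATHER)
 *               - Aggregation (PARTIAL)
 * </pre>
 * <p>
 * The three nodes directly below the FINAL aggregation (LocalExchange (GATHER), Aggregation
 * (INTERMEDIATE) and LocalExchange (ARBITRARY)) are only added when the session's task
 * concurrency is greater than one; otherwise the FINAL aggregation sits directly on top of
 * the rewritten RemoteExchange subtree.
 */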
public class AddIntermediateAggregations
        implements Rule<AggregationNode>
{
    // Matches FINAL aggregations that have no grouping columns and no ORDER BY clause.
    private static final Pattern<AggregationNode> PATTERN = aggregation()
            .with(step().equalTo(FINAL))
            .with(empty(groupingColumns()))
            .matching(finalAggregation -> !finalAggregation.hasOrderings());

    /**
     * @return the pattern selecting the FINAL aggregations this rule is applied to
     */
    @Override
    public Pattern<AggregationNode> getPattern()
    {
        return PATTERN;
    }

    /**
     * @return whether intermediate aggregations are enabled for the given session
     */
    @Override
    public boolean isEnabled(Session session)
    {
        return isEnableIntermediateAggregations(session);
    }

    /**
     * Rewrites the subtree below the matched FINAL aggregation so that INTERMEDIATE aggregations
     * are placed between it and the PARTIAL aggregation(s) feeding it.
     *
     * @return the FINAL aggregation on top of the rewritten source; an empty result when the
     * subtree below cannot be rewritten (see {@link #recurseToPartial}), or when task concurrency
     * is greater than one and the FINAL aggregation's assignments cannot be re-expressed over
     * their inputs (see {@link #inputsAsOutputs})
     */
    @Override
    public Result apply(AggregationNode aggregation, Captures captures, Context context)
    {
        Lookup lookup = context.getLookup();
        PlanNodeIdAllocator idAllocator = context.getIdAllocator();
        Session session = context.getSession();
        TypeProvider types = TypeProvider.viewOf(context.getVariableAllocator().getVariables());

        Optional<PlanNode> rewrittenSource = recurseToPartial(lookup.resolve(aggregation.getSource()), lookup, idAllocator, types);
        if (!rewrittenSource.isPresent()) {
            return Result.empty();
        }

        PlanNode newChild = rewrittenSource.get();
        if (getTaskConcurrency(session) > 1) {
            Map<VariableReferenceExpression, Aggregation> localIntermediates = inputsAsOutputs(aggregation.getAggregations(), types);
            if (localIntermediates.isEmpty()) {
                // The FINAL step's inputs cannot be expressed as an INTERMEDIATE step: give up on the whole rewrite
                return Result.empty();
            }
            newChild = addParallelIntermediate(aggregation, newChild, localIntermediates, idAllocator);
        }
        return Result.ofPlanNode(aggregation.replaceChildren(ImmutableList.of(newChild)));
    }

    /**
     * Wraps {@code source} in a local round-robin exchange, an INTERMEDIATE aggregation computing
     * {@code aggregations}, and a local gathering exchange, in that order from bottom to top.
     * All remaining aggregation properties are copied from {@code finalAggregation}.
     */
    private static PlanNode addParallelIntermediate(
            AggregationNode finalAggregation,
            PlanNode source,
            Map<VariableReferenceExpression, Aggregation> aggregations,
            PlanNodeIdAllocator idAllocator)
    {
        // Plan node ids are drawn bottom-up: exchange, aggregation, exchange
        PlanNode distributed = roundRobinExchange(idAllocator.getNextId(), LOCAL, source);
        PlanNode intermediate = new AggregationNode(
                finalAggregation.getSourceLocation(),
                idAllocator.getNextId(),
                distributed,
                aggregations,
                finalAggregation.getGroupingSets(),
                finalAggregation.getPreGroupedVariables(),
                INTERMEDIATE,
                finalAggregation.getHashVariable(),
                finalAggregation.getGroupIdVariable(),
                finalAggregation.getAggregationId());
        return gatheringExchange(idAllocator.getNextId(), LOCAL, intermediate);
    }

    /**
     * Walks down through ExchangeNodes and ProjectNodes looking for PARTIAL aggregations, and puts
     * a gathering INTERMEDIATE aggregation on top of each one found.
     *
     * @return the rewritten copy of {@code node}, or empty as soon as a node is met that is
     * neither a PARTIAL aggregation, an ExchangeNode nor a ProjectNode
     */
    private Optional<PlanNode> recurseToPartial(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, TypeProvider types)
    {
        if (node instanceof AggregationNode) {
            AggregationNode candidate = (AggregationNode) node;
            if (candidate.getStep() == PARTIAL) {
                return Optional.of(addGatheringIntermediate(candidate, idAllocator, types));
            }
        }

        boolean canLookThrough = node instanceof ExchangeNode || node instanceof ProjectNode;
        if (!canLookThrough) {
            return Optional.empty();
        }

        // Every child has to be rewritable, otherwise the whole subtree is rejected
        ImmutableList.Builder<PlanNode> rewrittenChildren = ImmutableList.builder();
        for (PlanNode child : node.getSources()) {
            Optional<PlanNode> rewrittenChild = recurseToPartial(lookup.resolve(child), lookup, idAllocator, types);
            if (!rewrittenChild.isPresent()) {
                return Optional.empty();
            }
            rewrittenChildren.add(rewrittenChild.get());
        }
        return Optional.of(node.replaceChildren(rewrittenChildren.build()));
    }

    /**
     * Puts a local gathering exchange and an INTERMEDIATE aggregation on top of the given
     * un-grouped PARTIAL aggregation. The INTERMEDIATE step consumes the PARTIAL step's outputs
     * and re-emits them under the same variables.
     */
    private PlanNode addGatheringIntermediate(AggregationNode aggregation, PlanNodeIdAllocator idAllocator, TypeProvider types)
    {
        verify(aggregation.getGroupingKeys().isEmpty(), "Should be an un-grouped aggregation");
        ExchangeNode localGather = gatheringExchange(idAllocator.getNextId(), LOCAL, aggregation);
        return new AggregationNode(
                aggregation.getSourceLocation(),
                idAllocator.getNextId(),
                localGather,
                outputsAsInputs(aggregation.getAggregations()),
                aggregation.getGroupingSets(),
                aggregation.getPreGroupedVariables(),
                INTERMEDIATE,
                aggregation.getHashVariable(),
                aggregation.getGroupIdVariable(),
                aggregation.getAggregationId());
    }

    /**
     * Rewrites every assignment so that its single argument is its own output variable, keeping
     * the call's return type.
     * <p>
     * Example:
     * 'a' := sum('b') => 'a' := sum('a')
     * 'a' := count(*) => 'a' := count('a')
     *
     * @throws IllegalStateException if any aggregation has an ORDER BY
     */
    private static Map<VariableReferenceExpression, Aggregation> outputsAsInputs(Map<VariableReferenceExpression, Aggregation> assignments)
    {
        ImmutableMap.Builder<VariableReferenceExpression, Aggregation> rewritten = ImmutableMap.builder();
        assignments.forEach((output, partial) -> {
            checkState(!partial.getOrderBy().isPresent(), "Intermediate aggregation does not support ORDER BY");
            appendAggregation(rewritten, partial, output, partial.getCall().getType());
        });
        return rewritten.build();
    }

    /**
     * Rewrites every assignment so that its output variable is its own single input variable.
     * This is only meaningful for aggregation steps that consume partial state (e.g. INTERMEDIATE
     * and split FINALs), which have exactly one input and one output.
     * <p>
     * Example:
     * 'a' := sum('b') => 'b' := sum('b')
     *
     * @return the rewritten assignments, or an empty map as soon as one aggregation does not have
     * exactly one argument, or has an ORDER BY or a filter
     */
    private static Map<VariableReferenceExpression, Aggregation> inputsAsOutputs(Map<VariableReferenceExpression, Aggregation> assignments, TypeProvider types)
    {
        ImmutableMap.Builder<VariableReferenceExpression, Aggregation> rewritten = ImmutableMap.builder();
        // Only the aggregations matter here; the original output variables are replaced by the inputs
        for (Aggregation aggregation : assignments.values()) {
            if (aggregation.getArguments().size() != 1 || aggregation.getOrderBy().isPresent() || aggregation.getFilter().isPresent()) {
                return ImmutableMap.of();
            }
            VariableReferenceExpression input = getOnlyElement(extractAggregationUniqueVariables(aggregation));
            // An intermediate step produces the same type it consumes, so reuse the argument's type
            RowExpression onlyArgument = aggregation.getCall().getArguments().get(0);
            appendAggregation(rewritten, aggregation, input, onlyArgument.getType());
        }
        return rewritten.build();
    }

    /**
     * Adds to {@code builder} an entry keyed by {@code varRef} whose aggregation calls the same
     * function as {@code aggregation}, with {@code varRef} as its only argument and
     * {@code returnType} as its type. Every optional property of the new aggregation is left
     * empty and its boolean flag is false.
     */
    private static void appendAggregation(ImmutableMap.Builder<VariableReferenceExpression, Aggregation> builder, Aggregation aggregation, VariableReferenceExpression varRef, Type returnType)
    {
        CallExpression originalCall = aggregation.getCall();
        CallExpression rewrittenCall = new CallExpression(
                originalCall.getSourceLocation(),
                originalCall.getDisplayName(),
                originalCall.getFunctionHandle(),
                returnType,
                ImmutableList.of(varRef));
        Aggregation rewritten = new Aggregation(
                rewrittenCall,
                Optional.empty(),
                Optional.empty(),
                false,
                Optional.empty()); // No mask for INTERMEDIATE
        builder.put(varRef, rewritten);
    }
}