PushAggregationThroughOuterJoin.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;
import com.facebook.presto.Session;
import com.facebook.presto.common.QualifiedObjectName;
import com.facebook.presto.common.block.SortOrder;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.Ordering;
import com.facebook.presto.spi.plan.OrderingScheme;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.ValuesNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import static com.facebook.presto.SystemSessionProperties.shouldPushAggregationThroughJoin;
import static com.facebook.presto.SystemSessionProperties.useDefaultsForCorrelatedAggregationPushdownThroughOuterJoins;
import static com.facebook.presto.matching.Capture.newCapture;
import static com.facebook.presto.spi.plan.AggregationNode.globalAggregation;
import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet;
import static com.facebook.presto.sql.planner.PlannerUtils.coalesce;
import static com.facebook.presto.sql.planner.RowExpressionVariableInliner.inlineVariables;
import static com.facebook.presto.sql.planner.optimizations.DistinctOutputQueryUtil.isDistinct;
import static com.facebook.presto.sql.planner.plan.Patterns.aggregation;
import static com.facebook.presto.sql.planner.plan.Patterns.join;
import static com.facebook.presto.sql.planner.plan.Patterns.source;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.facebook.presto.sql.relational.Expressions.constantNull;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;
/**
* This optimizer pushes aggregations below outer joins when: the aggregation
* is on top of the outer join, it groups by all columns in the outer table, and
* the outer rows are guaranteed to be distinct.
* <p>
* When the aggregation is pushed down, we still need to perform aggregations
* on the null values that come out of the absent values in an outer
* join. We add a cross join with a row of aggregations on null literals,
* and coalesce the aggregation that results from the left outer join with
* the result of the aggregation over nulls.
* <p>
* Example:
* <pre>
* - Filter ("nationkey" > "avg")
* - Aggregate(Group by: all columns from the left table, aggregation:
* avg("n2.nationkey"))
* - LeftJoin("regionkey" = "regionkey")
* - AssignUniqueId (nation)
* - Tablescan (nation)
* - Tablescan (nation)
* </pre>
* <p>
* Is rewritten to:
* <pre>
* - Filter ("nationkey" > "avg")
* - project(regionkey, coalesce("avg", "avg_over_null"))
* - CrossJoin
* - LeftJoin("regionkey" = "regionkey")
* - AssignUniqueId (nation)
* - Tablescan (nation)
* - Aggregate(Group by: regionkey, aggregation:
* avg(nationkey))
* - Tablescan (nation)
* - Aggregate
* avg(null_literal)
* - Values (null_literal)
* </pre>
*/
public class PushAggregationThroughOuterJoin
implements Rule<AggregationNode>
{
private static final Capture<JoinNode> JOIN = newCapture();
private static final Pattern<AggregationNode> PATTERN = aggregation()
.with(source().matching(join().capturedAs(JOIN)));
private final FunctionAndTypeManager functionAndTypeManager;
/**
 * Creates the rule.
 *
 * @param functionAndTypeManager manager used by this rule to look up function metadata and to build
 *         the function resolution helper for the aggregations being rewritten; must not be null
 * @throws NullPointerException if {@code functionAndTypeManager} is null
 */
public PushAggregationThroughOuterJoin(FunctionAndTypeManager functionAndTypeManager)
{
    // The message names the actual parameter so a failure points at the right argument.
    this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
}
/**
* Returns the pattern this rule is matched against: an aggregation whose source
* is a join, with that join captured as {@code JOIN}.
*/
@Override
public Pattern<AggregationNode> getPattern()
{
return PATTERN;
}
/**
* Reports whether this rule is enabled for the given session, as decided by
* {@code SystemSessionProperties.shouldPushAggregationThroughJoin}.
*
* @param session the session whose properties are consulted
* @return the value returned by {@code shouldPushAggregationThroughJoin(session)}
*/
@Override
public boolean isEnabled(Session session)
{
return shouldPushAggregationThroughJoin(session);
}
/**
 * Pushes the matched aggregation below the captured outer join.
 * <p>
 * The rewrite is attempted only when the join has no filter, is a LEFT or RIGHT join,
 * the aggregation groups on exactly the output variables of the outer side, and the
 * outer side is reported as distinct by {@code isDistinct}. The aggregation is then
 * re-created over the inner side, grouped by the inner side's join keys, and the join is
 * re-created on top of it. The result is finally wrapped by
 * {@link #coalesceWithNullAggregation}; if that step declines, the rule yields no result.
 *
 * @param aggregation the matched aggregation node
 * @param captures captures holding the join beneath the aggregation
 * @param context rule context supplying the lookup, allocators and session
 * @return the rewritten plan, or an empty result when the rule does not apply
 */
@Override
public Result apply(AggregationNode aggregation, Captures captures, Context context)
{
    JoinNode join = captures.get(JOIN);

    // Applicability checks, evaluated in the same order as a short-circuiting disjunction.
    if (join.getFilter().isPresent()) {
        return Result.empty();
    }
    JoinType joinType = join.getType();
    if (joinType != JoinType.LEFT && joinType != JoinType.RIGHT) {
        return Result.empty();
    }
    PlanNode outerTable = getOuterTable(join);
    if (!groupsOnAllColumns(aggregation, outerTable.getOutputVariables())) {
        return Result.empty();
    }
    Lookup lookup = context.getLookup();
    if (!isDistinct(lookup.resolve(outerTable), lookup::resolve)) {
        return Result.empty();
    }

    // The pushed-down aggregation groups by the join keys that belong to the inner side.
    boolean leftOuter = joinType == JoinType.LEFT;
    List<VariableReferenceExpression> innerJoinKeys = join.getCriteria().stream()
            .map(clause -> leftOuter ? clause.getRight() : clause.getLeft())
            .collect(toImmutableList());

    AggregationNode pushedAggregation = new AggregationNode(
            aggregation.getSourceLocation(),
            aggregation.getId(),
            getInnerTable(join),
            aggregation.getAggregations(),
            singleGroupingSet(innerJoinKeys),
            ImmutableList.of(),
            aggregation.getStep(),
            aggregation.getHashVariable(),
            aggregation.getGroupIdVariable(),
            aggregation.getAggregationId());

    // The aggregated side replaces the inner side and contributes only its aggregation outputs.
    List<VariableReferenceExpression> aggregationOutputs = ImmutableList.copyOf(pushedAggregation.getAggregations().keySet());
    PlanNode newLeft;
    PlanNode newRight;
    List<VariableReferenceExpression> leftOutputs;
    List<VariableReferenceExpression> rightOutputs;
    if (leftOuter) {
        newLeft = join.getLeft();
        newRight = pushedAggregation;
        leftOutputs = join.getLeft().getOutputVariables();
        rightOutputs = aggregationOutputs;
    }
    else {
        newLeft = pushedAggregation;
        newRight = join.getRight();
        leftOutputs = aggregationOutputs;
        rightOutputs = join.getRight().getOutputVariables();
    }

    JoinNode rewrittenJoin = new JoinNode(
            join.getSourceLocation(),
            join.getId(),
            join.getType(),
            newLeft,
            newRight,
            join.getCriteria(),
            ImmutableList.<VariableReferenceExpression>builder()
                    .addAll(leftOutputs)
                    .addAll(rightOutputs)
                    .build(),
            join.getFilter(),
            join.getLeftHashVariable(),
            join.getRightHashVariable(),
            join.getDistributionType(),
            join.getDynamicFilters());

    Optional<PlanNode> coalesced = coalesceWithNullAggregation(
            pushedAggregation,
            rewrittenJoin,
            context.getVariableAllocator(),
            context.getIdAllocator(),
            context.getLookup(),
            useDefaultsForCorrelatedAggregationPushdownThroughOuterJoins(context.getSession()));
    return coalesced.map(Result::ofPlanNode).orElseGet(Result::empty);
}
/**
 * Returns the inner (null-supplying) input of a LEFT or RIGHT join: the right input
 * for a LEFT join, the left input for a RIGHT join.
 *
 * @throws IllegalStateException if the join is neither LEFT nor RIGHT
 */
private static PlanNode getInnerTable(JoinNode join)
{
    JoinType type = join.getType();
    checkState(type == JoinType.LEFT || type == JoinType.RIGHT, "expected LEFT or RIGHT JOIN");
    return type == JoinType.LEFT ? join.getRight() : join.getLeft();
}
/**
 * Returns the outer (row-preserving) input of a LEFT or RIGHT join: the left input
 * for a LEFT join, the right input for a RIGHT join.
 *
 * @throws IllegalStateException if the join is neither LEFT nor RIGHT
 */
private static PlanNode getOuterTable(JoinNode join)
{
    JoinType type = join.getType();
    checkState(type == JoinType.LEFT || type == JoinType.RIGHT, "expected LEFT or RIGHT JOIN");
    return type == JoinType.LEFT ? join.getLeft() : join.getRight();
}
/**
 * Tells whether the aggregation's grouping keys are exactly the given columns,
 * compared as sets (order and duplicates are ignored).
 */
private static boolean groupsOnAllColumns(AggregationNode node, List<VariableReferenceExpression> columns)
{
    Set<VariableReferenceExpression> groupingKeys = new HashSet<>(node.getGroupingKeys());
    Set<VariableReferenceExpression> expectedKeys = new HashSet<>(columns);
    return groupingKeys.equals(expectedKeys);
}
// With the aggregation evaluated above the join, outer rows that have no match in the inner table
// still feed one all-null row into the aggregation. For some aggregate functions (count, for example)
// aggregating that single null row gives one or zero instead of null. To keep results unchanged after
// the pushdown, every aggregation output of the new outer join is wrapped in a coalesce whose fallback
// is either a literal default or the same aggregation computed over a single row of nulls.
//
// Returns an empty Optional when the aggregation over nulls cannot be built.
private Optional<PlanNode> coalesceWithNullAggregation(AggregationNode aggregationNode, PlanNode outerJoin, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, Lookup lookup, boolean useDefaultsForCorrelatedAggregations)
{
    // Build the aggregation over a row of nulls first; if that is not possible the rewrite is abandoned.
    Optional<MappedAggregationInfo> overNullInfo = createAggregationOverNull(aggregationNode, variableAllocator, idAllocator, lookup);
    if (!overNullInfo.isPresent()) {
        return Optional.empty();
    }
    AggregationNode aggregationOverNull = overNullInfo.get().getAggregation();
    Map<VariableReferenceExpression, VariableReferenceExpression> overNullByAggregation = overNullInfo.get().getVariableMapping();
    Map<VariableReferenceExpression, AggregationNode.Aggregation> aggregations = aggregationNode.getAggregations();

    FunctionResolution functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());

    // Aggregations whose fallback can be a literal, so they do not need the over-null aggregation.
    Map<VariableReferenceExpression, RowExpression> literalDefaults = new HashMap<>();
    if (useDefaultsForCorrelatedAggregations) {
        for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : aggregations.entrySet()) {
            VariableReferenceExpression output = entry.getKey();
            AggregationNode.Aggregation aggregation = entry.getValue();
            FunctionHandle functionHandle = aggregation.getFunctionHandle();
            if (functionResolution.isCountFunction(functionHandle) && !aggregation.getArguments().isEmpty()) { // Can also include count_if
                literalDefaults.put(output, constant(Long.valueOf(0), output.getType()));
            }
            else if (!functionAndTypeManager.getFunctionMetadata(functionHandle).isCalledOnNullInput()) {
                literalDefaults.put(output, constantNull(output.getType()));
            }
        }
    }

    // The cross join with the over-null aggregation is only needed when at least one aggregation has no literal default.
    PlanNode projectSource = outerJoin;
    if (literalDefaults.size() < aggregations.size()) {
        ImmutableList<VariableReferenceExpression> crossJoinOutputs = ImmutableList.<VariableReferenceExpression>builder()
                .addAll(outerJoin.getOutputVariables())
                .addAll(aggregationOverNull.getOutputVariables())
                .build();
        projectSource = new JoinNode(
                outerJoin.getSourceLocation(),
                idAllocator.getNextId(),
                JoinType.INNER,
                outerJoin,
                aggregationOverNull,
                ImmutableList.of(),
                crossJoinOutputs,
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                ImmutableMap.of());
    }

    // Pass non-aggregation outputs through unchanged; coalesce each aggregation output with its fallback.
    Assignments.Builder assignments = Assignments.builder();
    for (VariableReferenceExpression output : outerJoin.getOutputVariables()) {
        if (!aggregations.containsKey(output)) {
            assignments.put(output, output);
            continue;
        }
        RowExpression fallback = literalDefaults.getOrDefault(output, overNullByAggregation.get(output));
        assignments.put(output, coalesce(ImmutableList.of(output, fallback)));
    }
    return Optional.of(new ProjectNode(idAllocator.getNextId(), projectSource, assignments.build()));
}
/**
* Builds a global, SINGLE-step aggregation over a one-row VALUES node made of typed nulls,
* containing one aggregation for each aggregation of {@code referenceAggregation}.
* <p>
* Arguments, filter, ordering and mask of each aggregation are re-pointed from the output
* variables of the reference aggregation's source to freshly allocated variables of the null row.
* Note that {@code lookup} is not used by this method.
*
* @param referenceAggregation the aggregation whose functions are replicated over the null row
* @param variableAllocator allocator for the null-row variables and the new aggregation outputs
* @param idAllocator allocator for the ids of the new VALUES and aggregation nodes
* @param lookup unused here
* @return the new aggregation together with a mapping from each output variable of
* {@code referenceAggregation} to the corresponding output of the new aggregation; empty when
* some aggregation has no argument that is directly an output variable of the reference source
*/
private Optional<MappedAggregationInfo> createAggregationOverNull(AggregationNode referenceAggregation, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, Lookup lookup)
{
// Create a values node that consists of a single row of nulls.
// Map the output variables of the referenceAggregation's source
// to the variables of the new values node.
ImmutableList.Builder<VariableReferenceExpression> nullVariables = ImmutableList.builder();
ImmutableList.Builder<RowExpression> nullLiterals = ImmutableList.builder();
ImmutableMap.Builder<VariableReferenceExpression, VariableReferenceExpression> sourcesVariableMappingBuilder = ImmutableMap.builder();
for (VariableReferenceExpression sourceVariable : referenceAggregation.getSource().getOutputVariables()) {
// One null literal per source column, carrying that column's type and source location.
RowExpression nullLiteral = constantNull(sourceVariable.getSourceLocation(), sourceVariable.getType());
nullLiterals.add(nullLiteral);
VariableReferenceExpression nullVariable = variableAllocator.newVariable(nullLiteral);
nullVariables.add(nullVariable);
// TODO The type should be from sourceVariable.getType
sourcesVariableMappingBuilder.put(sourceVariable, nullVariable);
}
ValuesNode nullRow = new ValuesNode(
referenceAggregation.getSourceLocation(),
idAllocator.getNextId(),
nullVariables.build(),
ImmutableList.of(nullLiterals.build()),
Optional.empty());
Map<VariableReferenceExpression, VariableReferenceExpression> sourcesVariableMapping = sourcesVariableMappingBuilder.build();
// For each aggregation function in the reference node, create a corresponding aggregation function
// that points to the nullRow. Map the variables from the aggregations in referenceAggregation to the
// variables in these new aggregations.
ImmutableMap.Builder<VariableReferenceExpression, VariableReferenceExpression> aggregationsVariableMappingBuilder = ImmutableMap.builder();
ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> aggregationsOverNullBuilder = ImmutableMap.builder();
for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : referenceAggregation.getAggregations().entrySet()) {
VariableReferenceExpression aggregationVariable = entry.getKey();
AggregationNode.Aggregation aggregation = entry.getValue();
// Give up on the whole rewrite if this aggregation has no argument that is directly a source variable.
if (!isUsingVariables(aggregation, sourcesVariableMapping.keySet())) {
return Optional.empty();
}
// Same function call as the original, with every source variable replaced by its null-row variable.
AggregationNode.Aggregation overNullAggregation = new AggregationNode.Aggregation(
new CallExpression(
aggregation.getCall().getSourceLocation(),
aggregation.getCall().getDisplayName(),
aggregation.getCall().getFunctionHandle(),
aggregation.getCall().getType(),
aggregation.getArguments()
.stream()
.map(argument -> inlineVariables(sourcesVariableMapping, argument))
.collect(toImmutableList())),
aggregation.getFilter().map(filter -> inlineVariables(sourcesVariableMapping, filter)),
aggregation.getOrderBy().map(orderBy -> inlineOrderByVariables(sourcesVariableMapping, orderBy)),
aggregation.isDistinct(),
// The mask takes the name and source location of the mapped null-row variable but keeps the mask's own type.
aggregation.getMask().map(x -> new VariableReferenceExpression(sourcesVariableMapping.get(x).getSourceLocation(), sourcesVariableMapping.get(x).getName(), x.getType())));
// Name the new output after the function; its type is that of the original aggregation output.
QualifiedObjectName functionName = functionAndTypeManager.getFunctionMetadata(overNullAggregation.getFunctionHandle()).getName();
VariableReferenceExpression overNull = variableAllocator.newVariable(aggregation.getCall().getSourceLocation(), functionName.getObjectName(), aggregationVariable.getType());
aggregationsOverNullBuilder.put(overNull, overNullAggregation);
aggregationsVariableMappingBuilder.put(aggregationVariable, overNull);
}
Map<VariableReferenceExpression, VariableReferenceExpression> aggregationsSymbolMapping = aggregationsVariableMappingBuilder.build();
// create an aggregation node whose source is the null row.
AggregationNode aggregationOverNullRow = new AggregationNode(
referenceAggregation.getSourceLocation(),
idAllocator.getNextId(),
nullRow,
aggregationsOverNullBuilder.build(),
globalAggregation(),
ImmutableList.of(),
AggregationNode.Step.SINGLE,
Optional.empty(),
Optional.empty(),
Optional.empty());
return Optional.of(new MappedAggregationInfo(aggregationOverNullRow, aggregationsSymbolMapping));
}
/**
 * Rewrites an ordering scheme so that each order-by variable is replaced by the variable
 * it maps to in {@code variableMapping}, keeping the original sort order of every entry
 * and the original sequence of entries.
 *
 * @param variableMapping mapping from the variables used by {@code orderingScheme} to their replacements
 * @param orderingScheme the ordering scheme to translate
 * @return a new ordering scheme over the replacement variables
 * @throws NullPointerException if an order-by variable has no entry in {@code variableMapping}
 */
private static OrderingScheme inlineOrderByVariables(Map<VariableReferenceExpression, VariableReferenceExpression> variableMapping, OrderingScheme orderingScheme)
{
    // This is a logic expanded from ExpressionTreeRewriter::rewriteSortItems
    // Each ordering is translated in a single pass; no intermediate list/map needs to be built and re-read.
    ImmutableList.Builder<Ordering> translatedOrderings = ImmutableList.builder();
    for (VariableReferenceExpression variable : orderingScheme.getOrderByVariables()) {
        // Fail with the offending variable in the message instead of an anonymous NPE from a collection builder.
        VariableReferenceExpression translated = requireNonNull(
                variableMapping.get(variable),
                () -> "no mapping for order by variable " + variable);
        translatedOrderings.add(new Ordering(translated, orderingScheme.getOrdering(variable)));
    }
    return new OrderingScheme(translatedOrderings.build());
}
/**
 * Tells whether at least one argument of the aggregation is itself a variable reference
 * contained in {@code sourceVariables}. Only top-level arguments that are variable
 * references are considered.
 */
private static boolean isUsingVariables(AggregationNode.Aggregation aggregation, Set<VariableReferenceExpression> sourceVariables)
{
    return aggregation.getArguments().stream()
            .filter(VariableReferenceExpression.class::isInstance)
            .anyMatch(sourceVariables::contains);
}
/**
 * Immutable holder pairing an aggregation node with a variable-to-variable mapping
 * associated with it.
 */
private static final class MappedAggregationInfo
{
    private final AggregationNode aggregationNode;
    private final Map<VariableReferenceExpression, VariableReferenceExpression> variableMapping;

    /**
     * @param aggregationNode the aggregation node to carry; must not be null
     * @param variableMapping the variable mapping to carry; must not be null
     * @throws NullPointerException if either argument is null
     */
    public MappedAggregationInfo(AggregationNode aggregationNode, Map<VariableReferenceExpression, VariableReferenceExpression> variableMapping)
    {
        // Reject nulls at construction so a bad value fails here rather than at a later use site.
        this.aggregationNode = requireNonNull(aggregationNode, "aggregationNode is null");
        this.variableMapping = requireNonNull(variableMapping, "variableMapping is null");
    }

    /** Returns the variable mapping passed at construction. */
    public Map<VariableReferenceExpression, VariableReferenceExpression> getVariableMapping()
    {
        return variableMapping;
    }

    /** Returns the aggregation node passed at construction. */
    public AggregationNode getAggregation()
    {
        return aggregationNode;
    }
}
}