MergePartialAggregationsWithFilter.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.PartitioningScheme;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

import static com.facebook.presto.SystemSessionProperties.isMergeAggregationsWithAndWithoutFilter;
import static com.facebook.presto.expressions.LogicalRowExpressions.or;
import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static com.facebook.presto.spi.plan.AggregationNode.Step.FINAL;
import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL;
import static com.facebook.presto.sql.planner.PlannerUtils.addProjections;
import static com.facebook.presto.sql.planner.optimizations.AggregationNodeUtils.removeFilterAndMask;
import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignments;
import static com.facebook.presto.sql.planner.plan.ChildReplacer.replaceChildren;
import static com.facebook.presto.sql.relational.Expressions.constantNull;
import static com.facebook.presto.sql.relational.Expressions.specialForm;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static java.util.Objects.requireNonNull;

/**
 * Merges masked partial aggregations with partial aggregations that compute the same function over the same
 * arguments, moving the masks into the partial aggregation's grouping keys instead.
 * <p>
 * The rewrite applies to grouped PARTIAL/FINAL aggregation pairs whose functions are all not called on NULL input.
 * A masked partial aggregation such as {@code sum(a) mask m} is removed in one of two ways:
 * <ul>
 * <li>if the same partial aggregation also computes the unmasked {@code sum(a)}, that result is reused;</li>
 * <li>otherwise, if several masked aggregations share the same unmasked form (e.g. {@code sum(a) mask m1} and
 * {@code sum(a) mask m2}), they are replaced by a single {@code sum(a) mask (m1 OR m2)}.</li>
 * </ul>
 * In both cases the original mask variables are added to the partial aggregation's grouping keys, and a projection
 * placed directly below the final aggregation restores each removed result as {@code IF(mask, merged_result, NULL)}.
 * Because the final aggregation functions skip NULL inputs, this yields the same result as the masked aggregation.
 *
 * <pre>
 *     - Aggregation (Final)
 *          sum_1 := sum(partial_sum_1)
 *          sum_2 := sum(partial_sum_2)
 *          group_by_key [gb]
 *          - Exchange
 *              - Aggregation (Partial)
 *                  partial_sum_1 := sum(a)
 *                  partial_sum_2 := sum(a) mask m
 *                  group_by_key [gb]
 * </pre>
 * into
 * <pre>
 *     - Aggregation (Final)
 *          sum_1 := sum(partial_sum_1)
 *          sum_2 := sum(partial_sum_2)
 *          group_by_key [gb]
 *          - Project
 *              partial_sum_2 := IF(m, partial_sum_1, null)
 *              - Exchange
 *                  - Aggregation (Partial)
 *                      partial_sum_1 := sum(a)
 *                      group_by_key [gb, m]
 * </pre>
 */
public class MergePartialAggregationsWithFilter
        implements PlanOptimizer
{
    private final FunctionAndTypeManager functionAndTypeManager;
    private boolean isEnabledForTesting;

    public MergePartialAggregationsWithFilter(FunctionAndTypeManager functionAndTypeManager)
    {
        this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
    }

    @Override
    public void setEnabledForTesting(boolean isSet)
    {
        isEnabledForTesting = isSet;
    }

    @Override
    public boolean isEnabled(Session session)
    {
        return isEnabledForTesting || isMergeAggregationsWithAndWithoutFilter(session);
    }

    @Override
    public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector)
    {
        if (isEnabled(session)) {
            Rewriter rewriter = new Rewriter(session, variableAllocator, idAllocator, functionAndTypeManager);
            PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(rewriter, plan, new Context());
            return PlanOptimizerResult.optimizerResult(rewrittenPlan, rewriter.isPlanChanged());
        }

        return PlanOptimizerResult.optimizerResult(plan, false);
    }

    /**
     * State handed from a rewritten partial aggregation up to the matching final aggregation.
     * It is non-empty only while the rewriter is between those two nodes.
     */
    private static class Context
    {
        // Output of each removed masked partial aggregation -> the mask variable that aggregation used
        private final Map<VariableReferenceExpression, VariableReferenceExpression> partialResultToMask;
        // Output of each removed masked partial aggregation -> the partial output that now carries its value
        private final Map<VariableReferenceExpression, VariableReferenceExpression> partialOutputMapping;
        // Outputs of the combined (OR-masked) partial aggregations added by the rewrite
        private final List<VariableReferenceExpression> newAggregationOutput;

        public Context()
        {
            partialResultToMask = new HashMap<>();
            partialOutputMapping = new HashMap<>();
            newAggregationOutput = new LinkedList<>();
        }

        public boolean isEmpty()
        {
            return partialOutputMapping.isEmpty();
        }

        public void clear()
        {
            partialResultToMask.clear();
            partialOutputMapping.clear();
            newAggregationOutput.clear();
        }

        public Map<VariableReferenceExpression, VariableReferenceExpression> getPartialOutputMapping()
        {
            return partialOutputMapping;
        }

        public Map<VariableReferenceExpression, VariableReferenceExpression> getPartialResultToMask()
        {
            return partialResultToMask;
        }

        public List<VariableReferenceExpression> getNewAggregationOutput()
        {
            return newAggregationOutput;
        }
    }

    private static class Rewriter
            extends SimplePlanRewriter<Context>
    {
        private final Session session;
        private final VariableAllocator variableAllocator;
        private final PlanNodeIdAllocator planNodeIdAllocator;
        private final FunctionAndTypeManager functionAndTypeManager;
        // True only once an aggregation node has actually been rewritten
        private boolean planChanged;

        public Rewriter(Session session, VariableAllocator variableAllocator, PlanNodeIdAllocator planNodeIdAllocator, FunctionAndTypeManager functionAndTypeManager)
        {
            this.session = requireNonNull(session, "session is null");
            this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator is null");
            this.planNodeIdAllocator = requireNonNull(planNodeIdAllocator, "planNodeIdAllocator is null");
            this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
        }

        /**
         * Builds an IF special form over {@code arguments} (condition, then-value, else-value); the result type is
         * taken from {@code arguments[1]}.
         */
        public static RowExpression ifThenElse(RowExpression... arguments)
        {
            return specialForm(SpecialFormExpression.Form.IF, arguments[1].getType(), arguments);
        }

        public boolean isPlanChanged()
        {
            return planChanged;
        }

        @Override
        public PlanNode visitPlan(PlanNode node, RewriteContext<Context> context)
        {
            List<PlanNode> children = node.getSources().stream()
                    .map(child -> context.rewrite(child, context.get()))
                    .collect(toImmutableList());
            // Only exchanges and projections are understood between a rewritten partial and its final aggregation
            if (!context.get().isEmpty()) {
                throw new PrestoException(GENERIC_INTERNAL_ERROR, "Unexpected plan node between partial and final aggregation");
            }
            return replaceChildren(node, children);
        }

        @Override
        public PlanNode visitAggregation(AggregationNode node, RewriteContext<Context> context)
        {
            PlanNode rewrittenSource = context.rewrite(node.getSource(), context.get());
            if (canOptimize(node)) {
                checkState(node.getAggregations().values().stream().noneMatch(x -> x.getFilter().isPresent()), "All aggregation filters should already be rewritten to mask before this optimization");
                if (node.getStep().equals(PARTIAL)) {
                    return createPartialAggregationNode(node, rewrittenSource, context);
                }
                if (node.getStep().equals(FINAL)) {
                    return createFinalAggregationNode(node, rewrittenSource, context);
                }
            }
            return node.replaceChildren(ImmutableList.of(rewrittenSource));
        }

        /**
         * Returns whether {@code node} is eligible for the rewrite: it must have grouping keys and none of its
         * aggregation functions may be called on NULL input.
         */
        private boolean canOptimize(AggregationNode node)
        {
            // Before optimization, for aggregations with a mask, input rows are skipped when the mask is false. After optimization, the
            // partial aggregation output is projected to NULL when the mask is false, so the function must not be called on NULL input
            // for the result to stay correct.
            // Applying the optimization to global aggregations leads to an exception at
            // https://github.com/prestodb/presto/blob/dfbf21744ccd900d1a650571ffc35915db9b9f59/presto-main/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java#L627
            return !node.getGroupingKeys().isEmpty() && node.getAggregations().values().stream()
                    .map(aggregation -> functionAndTypeManager.getFunctionMetadata(aggregation.getFunctionHandle()))
                    .noneMatch(metadata -> metadata.isCalledOnNullInput());
        }

        /**
         * Removes mergeable masked aggregations from a partial aggregation and records in the context how the final
         * aggregation can recover their results. Returns the node unchanged (with the rewritten source) when nothing
         * can be merged.
         */
        private AggregationNode createPartialAggregationNode(AggregationNode node, PlanNode rewrittenSource, RewriteContext<Context> context)
        {
            checkState(context.get().isEmpty(), "There should be no partial aggregation left unmerged for a partial aggregation node");

            // Masked aggregations keyed by output variable. Keying by output rather than by aggregation keeps identical masked
            // aggregations with different outputs apart instead of failing on duplicate map keys.
            Map<VariableReferenceExpression, AggregationNode.Aggregation> maskedAggregations = node.getAggregations().entrySet().stream()
                    .filter(entry -> entry.getValue().getMask().isPresent())
                    .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));

            // Unmasked aggregation -> an output computing it; identical unmasked aggregations may repeat, and any one output can be reused
            Map<AggregationNode.Aggregation, VariableReferenceExpression> unmaskedAggregationToOutput = node.getAggregations().entrySet().stream()
                    .filter(entry -> !entry.getValue().getMask().isPresent())
                    .collect(toImmutableMap(Map.Entry::getValue, Map.Entry::getKey, (first, second) -> first));

            // Output of every masked aggregation removed by the rewrite -> the partial output that carries its value afterwards
            ImmutableMap.Builder<VariableReferenceExpression, VariableReferenceExpression> replacementBuilder = ImmutableMap.builder();

            // Case 1: the unmasked form is already computed by this node, so reuse that output
            Map<VariableReferenceExpression, VariableReferenceExpression> reusedUnmaskedOutputs = maskedAggregations.entrySet().stream()
                    .filter(entry -> unmaskedAggregationToOutput.containsKey(removeFilterAndMask(entry.getValue())))
                    .collect(toImmutableMap(Map.Entry::getKey, entry -> unmaskedAggregationToOutput.get(removeFilterAndMask(entry.getValue()))));
            replacementBuilder.putAll(reusedUnmaskedOutputs);

            // Case 2: no unmasked twin, but several masked aggregations share the same unmasked form
            List<List<VariableReferenceExpression>> maskedOutputGroups = maskedAggregations.entrySet().stream()
                    .filter(entry -> !reusedUnmaskedOutputs.containsKey(entry.getKey()))
                    .collect(Collectors.groupingBy(entry -> removeFilterAndMask(entry.getValue()), Collectors.mapping(Map.Entry::getKey, Collectors.toList())))
                    .values().stream()
                    .filter(outputs -> outputs.size() > 1)
                    .collect(toImmutableList());

            // Replace each group by one aggregation whose mask is the OR of the group's masks
            ImmutableMap.Builder<VariableReferenceExpression, RowExpression> newMaskAssignmentsBuilder = ImmutableMap.builder();
            ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> aggregationsAddedBuilder = ImmutableMap.builder();
            for (List<VariableReferenceExpression> outputs : maskedOutputGroups) {
                AggregationNode.Aggregation template = maskedAggregations.get(outputs.get(0));
                List<VariableReferenceExpression> masksToCombine = outputs.stream()
                        .map(output -> maskedAggregations.get(output).getMask().get())
                        .collect(toImmutableList());
                RowExpression combinedMask = or(masksToCombine);
                VariableReferenceExpression combinedMaskVariable = variableAllocator.newVariable(combinedMask);
                newMaskAssignmentsBuilder.put(combinedMaskVariable, combinedMask);
                AggregationNode.Aggregation combinedAggregation = new AggregationNode.Aggregation(
                        template.getCall(),
                        Optional.empty(),
                        template.getOrderBy(),
                        template.isDistinct(),
                        Optional.of(combinedMaskVariable));
                VariableReferenceExpression combinedOutput = variableAllocator.newVariable(combinedAggregation.getCall());
                aggregationsAddedBuilder.put(combinedOutput, combinedAggregation);
                outputs.forEach(output -> replacementBuilder.put(output, combinedOutput));
            }

            Map<VariableReferenceExpression, VariableReferenceExpression> partialOutputMapping = replacementBuilder.build();
            if (partialOutputMapping.isEmpty()) {
                // Nothing can be merged: keep the node as it is
                return (AggregationNode) node.replaceChildren(ImmutableList.of(rewrittenSource));
            }
            Map<VariableReferenceExpression, VariableReferenceExpression> partialResultToMask = partialOutputMapping.keySet().stream()
                    .collect(toImmutableMap(Function.identity(), output -> maskedAggregations.get(output).getMask().get()));
            Map<VariableReferenceExpression, RowExpression> newMaskAssignments = newMaskAssignmentsBuilder.build();
            Map<VariableReferenceExpression, AggregationNode.Aggregation> aggregationsAdded = aggregationsAddedBuilder.build();

            Context mergeContext = context.get();
            mergeContext.getPartialOutputMapping().putAll(partialOutputMapping);
            mergeContext.getPartialResultToMask().putAll(partialResultToMask);
            mergeContext.getNewAggregationOutput().addAll(aggregationsAdded.keySet());

            // Also group by every original mask, so each group is entirely inside or entirely outside each mask.
            // Masks that already are grouping keys are not added a second time.
            AggregationNode.GroupingSetDescriptor groupingSetDescriptor = node.getGroupingSets();
            Set<VariableReferenceExpression> maskVariables = new HashSet<>(partialResultToMask.values());
            maskVariables.removeAll(groupingSetDescriptor.getGroupingKeys());
            List<VariableReferenceExpression> groupingVariables = ImmutableList.<VariableReferenceExpression>builder()
                    .addAll(groupingSetDescriptor.getGroupingKeys())
                    .addAll(maskVariables)
                    .build();
            AggregationNode.GroupingSetDescriptor partialGroupingSetDescriptor = new AggregationNode.GroupingSetDescriptor(
                    groupingVariables, groupingSetDescriptor.getGroupingSetCount(), groupingSetDescriptor.getGlobalGroupingSets());

            Map<VariableReferenceExpression, AggregationNode.Aggregation> newAggregations = ImmutableMap.<VariableReferenceExpression, AggregationNode.Aggregation>builder()
                    .putAll(node.getAggregations().entrySet().stream()
                            .filter(entry -> !partialOutputMapping.containsKey(entry.getKey()))
                            .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)))
                    .putAll(aggregationsAdded)
                    .build();

            PlanNode newChild = rewrittenSource;
            if (!newMaskAssignments.isEmpty()) {
                // Compute the OR-combined masks below the aggregation
                newChild = addProjections(newChild, planNodeIdAllocator, newMaskAssignments);
            }

            planChanged = true;
            return new AggregationNode(
                    node.getSourceLocation(),
                    node.getId(),
                    newChild,
                    newAggregations,
                    partialGroupingSetDescriptor,
                    node.getPreGroupedVariables(),
                    PARTIAL,
                    node.getHashVariable(),
                    node.getGroupIdVariable(),
                    node.getAggregationId());
        }

        /**
         * Rewires a final aggregation to read {@code IF(mask, merged_partial_result, NULL)} for every partial result
         * removed by {@link #createPartialAggregationNode}, then clears the context. Returns the node unchanged (with
         * the rewritten source) when the context is empty.
         */
        private AggregationNode createFinalAggregationNode(AggregationNode node, PlanNode rewrittenSource, RewriteContext<Context> context)
        {
            Context mergeContext = context.get();
            if (mergeContext.isEmpty()) {
                return (AggregationNode) node.replaceChildren(ImmutableList.of(rewrittenSource));
            }
            List<VariableReferenceExpression> intermediateVariables = node.getAggregations().values().stream()
                    .map(Rewriter::getPartialInput)
                    .collect(toImmutableList());
            checkState(intermediateVariables.containsAll(mergeContext.getPartialResultToMask().keySet()),
                    "Final aggregation does not consume every merged partial aggregation output");

            ImmutableList.Builder<RowExpression> projectionsFromPartialAgg = ImmutableList.builder();
            ImmutableList.Builder<VariableReferenceExpression> variablesForPartialAggResult = ImmutableList.builder();
            ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> newFinalAggregationMap = ImmutableMap.builder();
            for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
                AggregationNode.Aggregation aggregation = entry.getValue();
                VariableReferenceExpression partialInput = getPartialInput(aggregation);
                VariableReferenceExpression maskVariable = mergeContext.getPartialResultToMask().get(partialInput);
                if (maskVariable == null) {
                    // This partial result was not merged; keep the aggregation as is
                    newFinalAggregationMap.put(entry.getKey(), aggregation);
                    continue;
                }
                VariableReferenceExpression toMergePartialInput = mergeContext.getPartialOutputMapping().get(partialInput);
                // NULL for groups outside the mask; the final function skips NULL inputs (see canOptimize)
                RowExpression conditionalResult = ifThenElse(maskVariable, toMergePartialInput, constantNull(toMergePartialInput.getType()));
                projectionsFromPartialAgg.add(conditionalResult);
                VariableReferenceExpression maskedPartialResult = variableAllocator.newVariable(toMergePartialInput);
                variablesForPartialAggResult.add(maskedPartialResult);

                CallExpression originalExpression = aggregation.getCall();
                CallExpression newExpression = new CallExpression(originalExpression.getSourceLocation(),
                        originalExpression.getDisplayName(),
                        originalExpression.getFunctionHandle(),
                        originalExpression.getType(),
                        ImmutableList.<RowExpression>builder()
                                .add(maskedPartialResult)
                                .addAll(originalExpression.getArguments().subList(1, originalExpression.getArguments().size()))
                                .build());

                AggregationNode.Aggregation newFinalAggregation = new AggregationNode.Aggregation(
                        newExpression,
                        aggregation.getFilter(),
                        aggregation.getOrderBy(),
                        aggregation.isDistinct(),
                        aggregation.getMask());
                newFinalAggregationMap.put(entry.getKey(), newFinalAggregation);
            }

            PlanNode projectNode = addProjections(rewrittenSource, planNodeIdAllocator, variableAllocator, projectionsFromPartialAgg.build(), variablesForPartialAggResult.build());
            mergeContext.clear();
            planChanged = true;
            return new AggregationNode(
                    node.getSourceLocation(),
                    node.getId(),
                    projectNode,
                    newFinalAggregationMap.build(),
                    node.getGroupingSets(),
                    node.getPreGroupedVariables(),
                    node.getStep(),
                    node.getHashVariable(),
                    node.getGroupIdVariable(),
                    node.getAggregationId());
        }

        /**
         * Returns the first argument of a final aggregation, which must be the variable holding the partial result.
         *
         * @throws IllegalStateException if the aggregation has no arguments or its first argument is not a variable
         */
        private static VariableReferenceExpression getPartialInput(AggregationNode.Aggregation aggregation)
        {
            List<RowExpression> arguments = aggregation.getArguments();
            checkState(!arguments.isEmpty() && arguments.get(0) instanceof VariableReferenceExpression,
                    "Expected the first argument of a final aggregation to be a variable, but got %s", arguments);
            return (VariableReferenceExpression) arguments.get(0);
        }

        @Override
        public PlanNode visitProject(ProjectNode node, RewriteContext<Context> context)
        {
            PlanNode rewrittenSource = context.rewrite(node.getSource(), context.get());
            Context mergeContext = context.get();
            if (!mergeContext.isEmpty()) {
                // Drop identity-style references to removed partial results and pass masks and new partial results through
                Assignments.Builder assignments = Assignments.builder();
                Map<VariableReferenceExpression, RowExpression> excludeMergedAssignments = node.getAssignments().getMap().entrySet().stream()
                        .filter(x -> !(x.getValue() instanceof VariableReferenceExpression && mergeContext.getPartialOutputMapping().containsKey(x.getValue())))
                        .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
                assignments.putAll(excludeMergedAssignments);
                assignments.putAll(identityAssignments(mergeContext.getPartialResultToMask().values()));
                assignments.putAll(identityAssignments(mergeContext.getNewAggregationOutput()));
                return new ProjectNode(
                        node.getSourceLocation(),
                        node.getId(),
                        rewrittenSource,
                        assignments.build(),
                        node.getLocality());
            }
            return node.replaceChildren(ImmutableList.of(rewrittenSource));
        }

        @Override
        public PlanNode visitExchange(ExchangeNode node, RewriteContext<Context> context)
        {
            ImmutableList.Builder<PlanNode> rewriteChildren = ImmutableList.builder();
            for (PlanNode child : node.getSources()) {
                context.get().clear();
                rewriteChildren.add(context.rewrite(child, context.get()));
            }
            List<PlanNode> children = rewriteChildren.build();
            if (!context.get().isEmpty()) {
                // The partial aggregation below changed its outputs, so the exchange layout must follow
                PartitioningScheme partitioning = new PartitioningScheme(
                        node.getPartitioningScheme().getPartitioning(),
                        children.get(children.size() - 1).getOutputVariables(),
                        node.getPartitioningScheme().getHashColumn(),
                        node.getPartitioningScheme().isReplicateNullsAndAny(),
                        node.getPartitioningScheme().isScaleWriters(),
                        node.getPartitioningScheme().getEncoding(),
                        node.getPartitioningScheme().getBucketToPartition());

                return new ExchangeNode(
                        node.getSourceLocation(),
                        node.getId(),
                        node.getType(),
                        node.getScope(),
                        partitioning,
                        children,
                        children.stream().map(PlanNode::getOutputVariables).collect(toImmutableList()),
                        node.isEnsureSourceOrdering(),
                        node.getOrderingScheme());
            }
            return node.replaceChildren(children);
        }
    }
}