RewriteIfOverAggregation.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.expressions.DefaultRowExpressionTraversalVisitor;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.AggregationNode.Aggregation;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.RowExpressionVariableInliner;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;

import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import static com.facebook.presto.SystemSessionProperties.isOptimizeConditionalAggregationEnabled;
import static com.facebook.presto.expressions.LogicalRowExpressions.and;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IF;
import static com.facebook.presto.sql.planner.VariablesExtractor.extractAll;
import static com.facebook.presto.sql.planner.VariablesExtractor.extractUnique;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;

/**
 * Rewrites IF(predicate, AGG(x)) so that the predicate is also applied as a mask on AGG(x).
 * The plan changes as follows.
 * When the aggregation does not have a mask:
 * <p>
 * From:
 * <pre>
 * - Project (val <- IF(condition, agg))
 *   - Aggregation (agg <- AGG(x))
 * </pre>
 * To:
 * <pre>
 * - Project (val <- IF(condition, agg))
 *   - Aggregation (agg <- AGG(x), mask <- Mask)
 *     - Project (Mask <- condition)
 * </pre>
 * <p>
 * Or
 * When the aggregation already has mask
 * <p>
 * From:
 * <pre>
 * - Project (val <- IF(condition, agg))
 *   - Aggregation (agg <- AGG(x), mask)
 * </pre>
 * To:
 * <pre>
 * - Project (val <- IF(condition, agg))
 *   - Aggregation (agg <- AGG(x), mask <- Mask)
 *     - Project (Mask <- AND(mask, condition))
 * </pre>
 * <p>
 */
public class RewriteIfOverAggregation
        implements PlanOptimizer
{
    private final FunctionAndTypeManager functionAndTypeManager;
    private boolean isEnabledForTesting;

    /**
     * Creates the optimizer.
     *
     * @param functionAndTypeManager manager passed to the {@link RowExpressionDeterminismEvaluator}
     * that is created on each {@code optimize} call
     * @throws NullPointerException if {@code functionAndTypeManager} is null
     */
    public RewriteIfOverAggregation(FunctionAndTypeManager functionAndTypeManager)
    {
        // Fail fast at construction time rather than surfacing a null later while planning a query.
        this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
    }

    /**
     * Test hook: when set to {@code true}, {@link #isEnabled(Session)} reports the optimizer as enabled
     * without consulting the session; when {@code false}, the session property alone decides.
     *
     * @param isSet whether to force-enable the optimizer
     */
    @Override
    public void setEnabledForTesting(boolean isSet)
    {
        this.isEnabledForTesting = isSet;
    }

    /**
     * Tells whether this optimizer should run for the given session.
     *
     * @param session session whose conditional-aggregation property is consulted
     * @return {@code true} if the testing override is set or the session property enables the optimization
     */
    @Override
    public boolean isEnabled(Session session)
    {
        // The testing override wins; the session is only consulted when the override is off.
        if (isEnabledForTesting) {
            return true;
        }
        return isOptimizeConditionalAggregationEnabled(session);
    }

    /**
     * Runs the IF-over-aggregation rewrite on {@code plan} when the optimizer is enabled.
     *
     * @param plan root of the plan to rewrite
     * @param session session used to decide whether the optimizer is enabled
     * @param types not used by this optimizer
     * @param variableAllocator allocator used to create the new mask variables
     * @param idAllocator allocator used to create ids for the inserted project nodes
     * @param warningCollector not used by this optimizer
     * @return the (possibly rewritten) plan together with a flag telling whether it changed
     */
    @Override
    public PlanOptimizerResult optimize(PlanNode plan,
            Session session,
            TypeProvider types,
            VariableAllocator variableAllocator,
            PlanNodeIdAllocator idAllocator,
            WarningCollector warningCollector)
    {
        // Disabled: hand the plan back untouched.
        if (!isEnabled(session)) {
            return PlanOptimizerResult.optimizerResult(plan, false);
        }

        RowExpressionDeterminismEvaluator determinismEvaluator = new RowExpressionDeterminismEvaluator(functionAndTypeManager);
        Rewriter planRewriter = new Rewriter(variableAllocator, idAllocator, determinismEvaluator);
        // Start the traversal with no candidate IF expressions.
        PlanNode optimizedPlan = SimplePlanRewriter.rewriteWith(planRewriter, plan, ImmutableMap.of());
        return PlanOptimizerResult.optimizerResult(optimizedPlan, planRewriter.isPlanChanged());
    }

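    // The rewrite context (Map<VariableReferenceExpression, RowExpression>) maps the variable used as the true value of a
    // candidate IF expression to that IF expression; visitAggregation matches the keys against aggregation outputs.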
    private static class Rewriter
            extends SimplePlanRewriter<Map<VariableReferenceExpression, RowExpression>>
    {
        private final VariableAllocator planVariableAllocator;
        private final PlanNodeIdAllocator planNodeIdAllocator;
        private final RowExpressionDeterminismEvaluator determinismEvaluator;
        private boolean planChanged;

        /**
         * Creates a rewriter for a single optimizer run.
         *
         * @param variableAllocator used to allocate the mask variables added below rewritten aggregations
         * @param idAllocator used to allocate ids for the project nodes inserted below rewritten aggregations
         * @param determinismEvaluator used to reject non-deterministic IF expressions
         * @throws NullPointerException if any argument is null
         */
        private Rewriter(VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, RowExpressionDeterminismEvaluator determinismEvaluator)
        {
            // Validate eagerly so that a missing collaborator is reported here rather than in the middle of a plan traversal.
            this.planVariableAllocator = requireNonNull(variableAllocator, "variableAllocator is null");
            this.planNodeIdAllocator = requireNonNull(idAllocator, "idAllocator is null");
            this.determinismEvaluator = requireNonNull(determinismEvaluator, "determinismEvaluator is null");
        }

        /**
         * Returns the second argument (the value produced when the condition holds) of a candidate IF expression.
         *
         * @param rowExpression expression expected to be a special form whose second argument is a variable reference
         * @return the variable used as the true value
         * @throws IllegalStateException if {@code rowExpression} is not a special form or its second argument is not a variable
         */
        private static VariableReferenceExpression getTrueValueFromIf(RowExpression rowExpression)
        {
            // Both checks carry the offending expression so that a violated invariant can be diagnosed from the stack trace alone.
            checkState(rowExpression instanceof SpecialFormExpression, "Expected a special form expression, but got: %s", rowExpression);
            RowExpression trueValue = ((SpecialFormExpression) rowExpression).getArguments().get(1);
            checkState(trueValue instanceof VariableReferenceExpression, "Expected the true value to be a variable reference in: %s", rowExpression);
            return (VariableReferenceExpression) trueValue;
        }

        /**
         * Substitutes every variable in {@code expression} that has an entry in {@code assignments}
         * with the expression assigned to it; variables without an entry are kept as they are.
         *
         * @param expression expression whose variables are substituted
         * @param assignments source of the variable-to-expression mapping
         * @return the expression after substitution
         */
        private static RowExpression inlineReferences(RowExpression expression, Assignments assignments)
        {
            Map<VariableReferenceExpression, RowExpression> definitions = assignments.getMap();
            return RowExpressionVariableInliner.inlineVariables(reference -> definitions.getOrDefault(reference, reference), expression);
        }

        /**
         * @return {@code true} once {@code visitAggregation} has rewritten at least one aggregation node
         */
        public boolean isPlanChanged()
        {
            return this.planChanged;
        }

        /**
         * Default handling for every node type without a dedicated visit method: rewrites the children
         * with an empty candidate map.
         */
        @Override
        public PlanNode visitPlan(PlanNode node, RewriteContext<Map<VariableReferenceExpression, RowExpression>> context)
        {
            // The targeted shape is a Project holding IF(predicate, agg) on top of an Aggregation (possibly through
            // further projects). Candidates are therefore only forwarded by visitProject; any other node in between
            // resets them to an empty map.
            Map<VariableReferenceExpression, RowExpression> noCandidates = ImmutableMap.of();
            return context.defaultRewrite(node, noCandidates);
        }

        /**
         * Collects the IF expressions that may be pushed into an aggregation below this project and forwards
         * them, keyed by their true-value variable, to the children.
         */
        @Override
        public PlanNode visitProject(ProjectNode node, RewriteContext<Map<VariableReferenceExpression, RowExpression>> context)
        {
            Assignments assignments = node.getAssignments();

            // Masking changes what the aggregation outputs, so a variable is only eligible when it is referenced
            // exactly once across all assignment expressions, i.e. nothing but the IF expression consumes it.
            Map<VariableReferenceExpression, Long> referenceCounts = assignments.getExpressions().stream()
                    .flatMap(expression -> extractAll(expression).stream())
                    .collect(Collectors.groupingBy(identity(), Collectors.counting()));
            Set<VariableReferenceExpression> singleUseVariables = referenceCounts.entrySet().stream()
                    .filter(count -> count.getValue() == 1)
                    .map(Map.Entry::getKey)
                    .collect(toImmutableSet());

            // Gather every candidate IF expression found anywhere inside the assignment expressions.
            ImmutableSet.Builder<RowExpression> foundIfExpressions = ImmutableSet.builder();
            IfExpressionExtractor extractor = new IfExpressionExtractor();
            for (RowExpression expression : assignments.getExpressions()) {
                expression.accept(extractor, foundIfExpressions);
            }

            // Keep the IF expressions whose true value is referenced only once in this project.
            ImmutableMap.Builder<VariableReferenceExpression, RowExpression> localBuilder = ImmutableMap.builder();
            for (RowExpression ifExpression : foundIfExpressions.build()) {
                VariableReferenceExpression trueValue = getTrueValueFromIf(ifExpression);
                if (singleUseVariables.contains(trueValue)) {
                    localBuilder.put(trueValue, ifExpression);
                }
            }
            Map<VariableReferenceExpression, RowExpression> candidatesInAssignments = localBuilder.build();

            // Candidates handed down from above survive only if their true value is referenced once here (an identity
            // assignment in this case). Their IF expression is inlined through this project's assignments so that it
            // can be resolved when the condition is pushed to the aggregation.
            ImmutableMap.Builder<VariableReferenceExpression, RowExpression> inheritedBuilder = ImmutableMap.builder();
            for (Map.Entry<VariableReferenceExpression, RowExpression> inherited : context.get().entrySet()) {
                if (singleUseVariables.contains(inherited.getKey())) {
                    inheritedBuilder.put(inherited.getKey(), inlineReferences(inherited.getValue(), assignments));
                }
            }
            Map<VariableReferenceExpression, RowExpression> candidatePassedFromContext = inheritedBuilder.build();

            ImmutableMap.Builder<VariableReferenceExpression, RowExpression> merged = ImmutableMap.builder();
            merged.putAll(candidatesInAssignments);
            merged.putAll(candidatePassedFromContext);
            return context.defaultRewrite(node, merged.build());
        }

        /**
         * Pushes the condition of each applicable candidate IF expression into the mask of the aggregation it
         * wraps. A project node computing the new mask variables is inserted between the aggregation and its source.
         * When no candidate applies, the node is rewritten by default with an empty candidate map.
         */
        @Override
        public PlanNode visitAggregation(AggregationNode node, RewriteContext<Map<VariableReferenceExpression, RowExpression>> context)
        {
            PlanNode source = node.getSource();
            Map<VariableReferenceExpression, Aggregation> aggregations = node.getAggregations();

            // A candidate applies when its key is one of this node's aggregation outputs and every other variable
            // of the IF expression is produced by the aggregation's source.
            ImmutableMap.Builder<VariableReferenceExpression, RowExpression> applicable = ImmutableMap.builder();
            for (Map.Entry<VariableReferenceExpression, RowExpression> entry : context.get().entrySet()) {
                VariableReferenceExpression output = entry.getKey();
                if (!aggregations.containsKey(output)) {
                    continue;
                }
                boolean resolvableFromSource = extractUnique(entry.getValue()).stream()
                        .filter(variable -> !variable.equals(output))
                        .allMatch(source.getOutputVariables()::contains);
                if (resolvableFromSource) {
                    applicable.put(output, entry.getValue());
                }
            }
            Map<VariableReferenceExpression, RowExpression> candidate = applicable.build();
            if (candidate.isEmpty()) {
                return context.defaultRewrite(node, ImmutableMap.of());
            }

            Assignments.Builder maskProjection = Assignments.builder();
            ImmutableMap.Builder<VariableReferenceExpression, Aggregation> rewrittenAggregations = ImmutableMap.builder();
            for (Map.Entry<VariableReferenceExpression, RowExpression> entry : candidate.entrySet()) {
                VariableReferenceExpression aggregationOutput = entry.getKey();
                RowExpression ifExpression = entry.getValue();
                checkState(ifExpression instanceof SpecialFormExpression && ((SpecialFormExpression) ifExpression).getForm().equals(IF));
                RowExpression condition = ((SpecialFormExpression) ifExpression).getArguments().get(0);

                Aggregation existing = aggregations.get(aggregationOutput);
                // An existing mask is kept by AND-ing it with the IF condition.
                RowExpression maskExpression;
                if (existing.getMask().isPresent()) {
                    maskExpression = and(existing.getMask().get(), condition);
                }
                else {
                    maskExpression = condition;
                }
                VariableReferenceExpression maskVariable = planVariableAllocator.newVariable(maskExpression);

                maskProjection.put(maskVariable, maskExpression);
                rewrittenAggregations.put(
                        aggregationOutput,
                        new Aggregation(
                                existing.getCall(),
                                existing.getFilter(),
                                existing.getOrderBy(),
                                existing.isDistinct(),
                                Optional.of(maskVariable)));
            }

            // Everything the source already produced is passed through the new project unchanged.
            maskProjection.putAll(source.getOutputVariables().stream().collect(toImmutableMap(identity(), identity())));
            // Aggregations that were not rewritten are carried over as they are, after the rewritten ones.
            for (Map.Entry<VariableReferenceExpression, Aggregation> entry : aggregations.entrySet()) {
                if (!candidate.containsKey(entry.getKey())) {
                    rewrittenAggregations.put(entry.getKey(), entry.getValue());
                }
            }

            planChanged = true;
            ProjectNode maskProject = new ProjectNode(planNodeIdAllocator.getNextId(), source, maskProjection.build());
            AggregationNode maskedAggregation = new AggregationNode(
                    node.getSourceLocation(),
                    node.getId(),
                    maskProject,
                    rewrittenAggregations.build(),
                    node.getGroupingSets(),
                    node.getPreGroupedVariables(),
                    node.getStep(),
                    node.getHashVariable(),
                    node.getGroupIdVariable(),
                    node.getAggregationId());

            return context.defaultRewrite(maskedAggregation, ImmutableMap.of());
        }

        /**
         * Tells whether {@code rowExpression} has the shape this optimizer rewrites: a deterministic IF special form
         * whose true value is a plain variable reference and whose false value is either absent or a NULL constant.
         *
         * @param rowExpression expression to inspect
         * @return {@code true} if the expression is a candidate for the rewrite
         */
        private boolean isCandidateIfExpression(RowExpression rowExpression)
        {
            // Structural checks come first: they are simple field reads and reject most expressions, so the
            // determinism evaluator is only consulted for expressions that already have the right shape.
            if (!(rowExpression instanceof SpecialFormExpression)) {
                return false;
            }
            SpecialFormExpression specialForm = (SpecialFormExpression) rowExpression;
            if (specialForm.getForm() != IF) {
                return false;
            }
            if (!(specialForm.getArguments().get(1) instanceof VariableReferenceExpression)) {
                return false;
            }
            // A false branch, when present, must be a NULL constant.
            if (specialForm.getArguments().size() != 2) {
                RowExpression falseValue = specialForm.getArguments().get(2);
                if (!(falseValue instanceof ConstantExpression) || !((ConstantExpression) falseValue).isNull()) {
                    return false;
                }
            }
            return determinismEvaluator.isDeterministic(specialForm);
        }

        /**
         * Row expression visitor that adds every special form accepted by {@code isCandidateIfExpression}
         * to the builder passed as visitor context.
         */
        private class IfExpressionExtractor
                extends DefaultRowExpressionTraversalVisitor<ImmutableSet.Builder<RowExpression>>
        {
            @Override
            public Void visitSpecialForm(SpecialFormExpression specialForm, ImmutableSet.Builder<RowExpression> context)
            {
                boolean rewritable = isCandidateIfExpression(specialForm);
                if (rewritable) {
                    context.add(specialForm);
                }
                // Always hand over to the default traversal afterwards (presumably it descends into the arguments).
                return super.visitSpecialForm(specialForm, context);
            }
        }
    }
}