RewriteAggregationIfToFilter.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.expressions.DefaultRowExpressionTraversalVisitor;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.AggregationNode.Aggregation;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.FeaturesConfig.AggregationIfToFilterRewriteStrategy;
import com.facebook.presto.sql.planner.VariablesExtractor;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.relational.Expressions;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSortedSet;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;

import static com.facebook.presto.SystemSessionProperties.getAggregationIfToFilterRewriteStrategy;
import static com.facebook.presto.common.RuntimeMetricName.REWRITE_AGGREGATION_IF_TO_FILTER_APPLIED;
import static com.facebook.presto.common.RuntimeUnit.NONE;
import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT;
import static com.facebook.presto.expressions.LogicalRowExpressions.or;
import static com.facebook.presto.matching.Capture.newCapture;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IF;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.AggregationIfToFilterRewriteStrategy.FILTER_WITH_IF;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.AggregationIfToFilterRewriteStrategy.UNWRAP_IF;
import static com.facebook.presto.sql.planner.plan.Patterns.aggregation;
import static com.facebook.presto.sql.planner.plan.Patterns.project;
import static com.facebook.presto.sql.planner.plan.Patterns.source;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.ImmutableSortedMap.toImmutableSortedMap;
import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;

/**
 * A optimizer rule which rewrites
 * AGG(IF(condition, expr))
 * to
 * AGG(IF(condition, expr)) FILTER (WHERE condition).
 * if AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY is FILTER_WITH_IF,
 * or
 * AGG(expr) FILTER (WHERE condition).
 * if AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY is UNWRAP_IF.
 * <p>
 * The rewritten plan is more efficient because:
 * 1. The filter can be pushed down to the scan node.
 * 2. The rows not matching the condition are not aggregated.
 * <p>
 * Note that unwrapping the IF expression in the aggregate might cause issues if the true branch return errors for rows not matching the filters. For example:
 * 'IF(CARDINALITY(array) > 0, array[1]))'
 * Session property AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY and canUnwrapIf() control whether to enable IF unwrapping:
 * 1. If the strategy is FILTER_WITH_IF, then keep the IF expression.
 * 2. If the strategy is UNWRAP_IF_SAFE,  then unwrap the IF expression if it is safe to do so.
 * 3. If the strategy is UNWRAP_IF, then unwrap the IF expression after it passes the checks; note that this is an unsafe mode since the checks are not exhaustive.
 */
public class RewriteAggregationIfToFilter
        implements Rule<AggregationNode>
{
    private static final Capture<ProjectNode> CHILD = newCapture();

    private static final Pattern<AggregationNode> PATTERN = aggregation()
            .with(source().matching(project().capturedAs(CHILD)));

    private final FunctionAndTypeManager functionAndTypeManager;
    private final RowExpressionDeterminismEvaluator rowExpressionDeterminismEvaluator;
    private final StandardFunctionResolution standardFunctionResolution;

    public RewriteAggregationIfToFilter(FunctionAndTypeManager functionAndTypeManager)
    {
        this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionManager is null");
        this.rowExpressionDeterminismEvaluator = new RowExpressionDeterminismEvaluator(functionAndTypeManager);
        this.standardFunctionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
    }

    @Override
    public boolean isEnabled(Session session)
    {
        return getAggregationIfToFilterRewriteStrategy(session).ordinal() > AggregationIfToFilterRewriteStrategy.DISABLED.ordinal();
    }

    @Override
    public Pattern<AggregationNode> getPattern()
    {
        return PATTERN;
    }

    @Override
    public Result apply(AggregationNode aggregationNode, Captures captures, Context context)
    {
        ProjectNode sourceProject = captures.get(CHILD);

        Set<Aggregation> aggregationsToRewrite = aggregationNode.getAggregations().values().stream()
                .filter(aggregation -> shouldRewriteAggregation(aggregation, sourceProject))
                .collect(toImmutableSet());
        if (aggregationsToRewrite.isEmpty()) {
            return Result.empty();
        }

        context.getSession().getRuntimeStats().addMetricValue(REWRITE_AGGREGATION_IF_TO_FILTER_APPLIED, NONE, 1);
        // Get the corresponding assignments in the input project.
        // The aggregationReferences only has the aggregations to rewrite, thus the sourceAssignments only has IF/CAST(IF) expressions with NULL false results.
        // Multiple aggregations may reference the same input. We use a map to dedup them based on the VariableReferenceExpression, so that we only do the rewrite once per input
        // IF expression.
        // The order of sourceAssignments determines the order of generating the new variables for the IF conditions and results. We use a sorted map to get a deterministic
        // order based on the name of the VariableReferenceExpressions.
        Map<VariableReferenceExpression, RowExpression> sourceAssignments = aggregationsToRewrite.stream()
                .map(aggregation -> (VariableReferenceExpression) aggregation.getArguments().get(0))
                .collect(toImmutableSortedMap(VariableReferenceExpression::compareTo, identity(), variable -> sourceProject.getAssignments().get(variable), (left, right) -> left));

        Assignments.Builder newAssignments = Assignments.builder();
        newAssignments.putAll(sourceProject.getAssignments());

        // Map from the aggregation reference to the IF condition reference which will be put in the mask.
        Map<VariableReferenceExpression, VariableReferenceExpression> aggregationReferenceToConditionReference = new HashMap<>();
        // Map from the aggregation reference to the IF result reference. This only contains the aggregates where the IF can be safely unwrapped.
        // E.g., SUM(IF(CARDINALITY(array) > 0, array[1])) will not be included in this map as array[1] can return errors if we unwrap the IF.
        Map<VariableReferenceExpression, VariableReferenceExpression> aggregationReferenceToIfResultReference = new HashMap<>();

        AggregationIfToFilterRewriteStrategy rewriteStrategy = getAggregationIfToFilterRewriteStrategy(context.getSession());
        for (Map.Entry<VariableReferenceExpression, RowExpression> entry : sourceAssignments.entrySet()) {
            VariableReferenceExpression outputVariable = entry.getKey();
            RowExpression rowExpression = entry.getValue();
            SpecialFormExpression ifExpression = (SpecialFormExpression) ((rowExpression instanceof CallExpression)
                    ? ((CallExpression) rowExpression).getArguments().get(0)
                    : rowExpression);

            RowExpression condition = ifExpression.getArguments().get(0);
            VariableReferenceExpression conditionReference = context.getVariableAllocator().newVariable(condition);
            newAssignments.put(conditionReference, condition);
            aggregationReferenceToConditionReference.put(outputVariable, conditionReference);

            if (canUnwrapIf(ifExpression, rewriteStrategy)) {
                RowExpression trueResult = ifExpression.getArguments().get(1);
                if (rowExpression instanceof CallExpression) {
                    // Wrap the result with CAST().
                    trueResult = new CallExpression(
                            ((CallExpression) rowExpression).getDisplayName(),
                            ((CallExpression) rowExpression).getFunctionHandle(),
                            rowExpression.getType(),
                            ImmutableList.of(trueResult));
                }
                VariableReferenceExpression ifResultReference = context.getVariableAllocator().newVariable(trueResult);
                newAssignments.put(ifResultReference, trueResult);
                aggregationReferenceToIfResultReference.put(outputVariable, ifResultReference);
            }
        }

        // Build new aggregations.
        ImmutableMap.Builder<VariableReferenceExpression, Aggregation> aggregations = ImmutableMap.builder();
        // Stores the masks used to build the filter predicates. Use set to dedup the predicates.
        ImmutableSortedSet.Builder<VariableReferenceExpression> masks = ImmutableSortedSet.naturalOrder();
        for (Map.Entry<VariableReferenceExpression, Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
            VariableReferenceExpression output = entry.getKey();
            Aggregation aggregation = entry.getValue();
            if (!aggregationsToRewrite.contains(aggregation)) {
                aggregations.put(output, aggregation);
                continue;
            }
            VariableReferenceExpression aggregationReference = (VariableReferenceExpression) aggregation.getArguments().get(0);
            CallExpression callExpression = aggregation.getCall();
            VariableReferenceExpression ifResultReference = aggregationReferenceToIfResultReference.get(aggregationReference);
            if (ifResultReference != null) {
                callExpression = new CallExpression(
                        callExpression.getSourceLocation(),
                        callExpression.getDisplayName(),
                        callExpression.getFunctionHandle(),
                        callExpression.getType(),
                        ImmutableList.of(ifResultReference));
            }
            VariableReferenceExpression mask = aggregationReferenceToConditionReference.get(aggregationReference);
            aggregations.put(output, new Aggregation(
                    callExpression,
                    Optional.empty(),
                    aggregation.getOrderBy(),
                    aggregation.isDistinct(),
                    Optional.of(aggregationReferenceToConditionReference.get(aggregationReference))));
            masks.add(mask);
        }

        RowExpression predicate = TRUE_CONSTANT;
        if (!aggregationNode.hasNonEmptyGroupingSet() && aggregationsToRewrite.size() == aggregationNode.getAggregations().size()) {
            // All aggregations are rewritten by this rule. We can add a filter with all the masks to make the query more efficient.
            predicate = or(masks.build());
        }
        return Result.ofPlanNode(
                new AggregationNode(
                        aggregationNode.getSourceLocation(),
                        context.getIdAllocator().getNextId(),
                        new FilterNode(
                                aggregationNode.getSourceLocation(),
                                context.getIdAllocator().getNextId(),
                                new ProjectNode(
                                        context.getIdAllocator().getNextId(),
                                        sourceProject.getSource(),
                                        newAssignments.build()),
                                predicate),
                        aggregations.build(),
                        aggregationNode.getGroupingSets(),
                        aggregationNode.getPreGroupedVariables(),
                        aggregationNode.getStep(),
                        aggregationNode.getHashVariable(),
                        aggregationNode.getGroupIdVariable(),
                        aggregationNode.getAggregationId()));
    }

    private boolean shouldRewriteAggregation(Aggregation aggregation, ProjectNode sourceProject)
    {
        if (functionAndTypeManager.getFunctionMetadata(aggregation.getFunctionHandle()).isCalledOnNullInput()) {
            // This rewrite will filter out the null values. It could change the behavior if the aggregation is also applied on NULLs.
            return false;
        }
        if (!(aggregation.getArguments().size() == 1 && aggregation.getArguments().get(0) instanceof VariableReferenceExpression)) {
            // Currently we only handle aggregation with a single VariableReferenceExpression. The detailed expressions are in a project node below this aggregation.
            return false;
        }
        if (aggregation.getFilter().isPresent() || aggregation.getMask().isPresent()) {
            // Do not rewrite the aggregation if it already has a filter or mask.
            return false;
        }
        RowExpression sourceExpression = sourceProject.getAssignments().get((VariableReferenceExpression) aggregation.getArguments().get(0));
        if (sourceExpression instanceof CallExpression) {
            CallExpression callExpression = (CallExpression) sourceExpression;
            if (callExpression.getArguments().size() == 1 && standardFunctionResolution.isCastFunction(callExpression.getFunctionHandle())) {
                // If the expression is CAST(), check the expression inside.
                sourceExpression = callExpression.getArguments().get(0);
            }
        }
        if (!(sourceExpression instanceof SpecialFormExpression) || !rowExpressionDeterminismEvaluator.isDeterministic(sourceExpression)) {
            return false;
        }
        SpecialFormExpression expression = (SpecialFormExpression) sourceExpression;
        // Only rewrite the aggregation if the else branch is not present or the else result is NULL.
        return expression.getForm() == IF && Expressions.isNull(expression.getArguments().get(2));
    }

    private boolean canUnwrapIf(SpecialFormExpression ifExpression, AggregationIfToFilterRewriteStrategy rewriteStrategy)
    {
        if (rewriteStrategy == FILTER_WITH_IF) {
            return false;
        }

        // Some use cases use IF expression to avoid returning errors when evaluating the true branch. For example, IF(CARDINALITY(array) > 0, array[1])).
        // We shouldn't unwrap the IF for those cases.
        // But if the condition expression doesn't reference any variables referenced in the true branch, unwrapping the if should not cause exceptions for the true branch.
        Set<VariableReferenceExpression> ifConditionReferences = VariablesExtractor.extractUnique(ifExpression.getArguments().get(0));
        Set<VariableReferenceExpression> ifResultReferences = VariablesExtractor.extractUnique(ifExpression.getArguments().get(1));
        if (ifConditionReferences.stream().noneMatch(ifResultReferences::contains)) {
            return true;
        }

        if (rewriteStrategy != UNWRAP_IF) {
            return false;
        }

        AtomicBoolean result = new AtomicBoolean(true);
        ifExpression.getArguments().get(1).accept(new DefaultRowExpressionTraversalVisitor<AtomicBoolean>()
        {
            @Override
            public Void visitLambda(LambdaDefinitionExpression lambda, AtomicBoolean result)
            {
                // Unwrapping the IF expression in the aggregate might cause issues if the true branch return errors for rows not matching the filters.
                // To be safe, we don't unwrap the IF expressions when the true branch has lambdas.
                result.set(false);
                return null;
            }

            @Override
            public Void visitCall(CallExpression call, AtomicBoolean result)
            {
                Optional<OperatorType> operatorType = functionAndTypeManager.getFunctionMetadata(call.getFunctionHandle()).getOperatorType();
                // Unwrapping the IF expression in the aggregate might cause issues if the true branch return errors for rows not matching the filters.
                // For example, array[1] could return out of bound error and a / b could return DIVISION_BY_ZERO error. So we doesn't unwrap the IF expression in these cases.
                if (operatorType.isPresent() && (operatorType.get() == OperatorType.DIVIDE || operatorType.get() == OperatorType.SUBSCRIPT)) {
                    result.set(false);
                    return null;
                }
                return super.visitCall(call, result);
            }
        }, result);
        return result.get();
    }
}