CombineApproxPercentileFunctions.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeSignature;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static com.facebook.presto.SystemSessionProperties.isCombineApproxPercentileEnabled;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.sql.planner.plan.Patterns.aggregation;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;

/**
 * For multiple approx_percentile() function calls on the same column with different percentile arguments, combine them into one call on an array of percentile arguments.
 * <p>
 * From:
 * <pre>
 * - Aggregation (approx_percentile(col, 0.2), approx_percentile(col, 0.8))
 * </pre>
 * To:
 * <pre>
 * - Project (approx_percentile_results[1], approx_percentile_results[2])
 *   - Aggregation (approx_percentile_results <- approx_percentile(col, array))
 *     - Project (col <- col, array <- [0.2, 0.8])
 * </pre>
 */

public class CombineApproxPercentileFunctions
        implements Rule<AggregationNode>
{
    private static final String APPROX_PERCENTILE = "approx_percentile";
    private static final String ARRAY_CONSTRUCTOR = "array_constructor";
    private static final String ELEMENT_AT = "element_at";
    // Limit specified in `ArrayConstructor` function: at most this many elements (inclusive).
    private static final int ARRAY_SIZE_LIMIT = 254;

    // Matches aggregation nodes that contain more than one approx_percentile call.
    private static final Pattern<AggregationNode> PATTERN = aggregation()
            .matching(CombineApproxPercentileFunctions::hasMultipleApproxPercentile);

    private final FunctionAndTypeManager functionAndTypeManager;

    /**
     * @param functionAndTypeManager used to build the array_constructor, approx_percentile and element_at calls; must not be null
     */
    public CombineApproxPercentileFunctions(FunctionAndTypeManager functionAndTypeManager)
    {
        this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
    }

    /**
     * Returns true when the aggregation node has more than one aggregation whose display name is approx_percentile.
     */
    private static boolean hasMultipleApproxPercentile(AggregationNode aggregation)
    {
        return aggregation.getAggregations().values().stream()
                .filter(agg -> agg.getCall().getDisplayName().equals(APPROX_PERCENTILE)).count() > 1;
    }

    @Override
    public Pattern<AggregationNode> getPattern()
    {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session)
    {
        return isCombineApproxPercentileEnabled(session);
    }

    /**
     * Get the zero-based position of the percentile argument in the arguments of an approx_percentile call.
     *
     * @param functionHandle handle of the approx_percentile function
     * @return 1 or 2, depending on the signature
     * @throws IllegalStateException if the argument types match none of the expected signatures
     */
    private static int getPercentilePosition(FunctionHandle functionHandle)
    {
        // approx_percentile(x, percentage) -> arguments.size() == 2
        // approx_percentile(x, percentage, accuracy) -> arguments.size() == 3 && arguments.get(1) is Double
        // approx_percentile(x, w, percentage) -> arguments.size() == 3 && arguments.get(1) is BigInt
        // approx_percentile(x, w, percentage, accuracy) -> arguments.size() == 4
        List<TypeSignature> argumentTypes = functionHandle.getArgumentTypes();
        if (argumentTypes.size() == 2 || (argumentTypes.size() == 3 && argumentTypes.get(1).getBase().equals(StandardTypes.DOUBLE))) {
            return 1;
        }
        checkState(
                argumentTypes.size() == 4 || (argumentTypes.size() == 3 && argumentTypes.get(1).getBase().equals(StandardTypes.BIGINT)),
                "Unexpected approx_percentile argument types: %s",
                argumentTypes);
        return 2;
    }

    /**
     * Checks whether two approx_percentile aggregations can be combined into one call.
     * They can when mask, order by, filter, distinct flag and function handle are equal, and every
     * call argument other than the percentile argument is equal.
     */
    private static boolean aggregationCanMerge(AggregationNode.Aggregation aggregation1, AggregationNode.Aggregation aggregation2)
    {
        if (!aggregation1.getMask().equals(aggregation2.getMask())
                || !aggregation1.getOrderBy().equals(aggregation2.getOrderBy())
                || !aggregation1.getFilter().equals(aggregation2.getFilter())
                || aggregation1.isDistinct() != aggregation2.isDistinct()) {
            return false;
        }
        // Check call expression, the only difference should be percentile argument
        CallExpression expression1 = aggregation1.getCall();
        CallExpression expression2 = aggregation2.getCall();
        int percentilePosition = getPercentilePosition(expression1.getFunctionHandle());
        if (!expression1.getFunctionHandle().equals(expression2.getFunctionHandle()) || expression1.getArguments().size() != expression2.getArguments().size()) {
            return false;
        }
        List<RowExpression> arguments1 = expression1.getArguments();
        List<RowExpression> arguments2 = expression2.getArguments();
        for (int i = 0; i < arguments1.size(); ++i) {
            if (i != percentilePosition && !arguments1.get(i).equals(arguments2.get(i))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns a copy of {@code arguments} in which the element at {@code percentilePosition} is replaced by {@code percentileArgument}.
     */
    private static List<RowExpression> changePercentileArgument(List<RowExpression> arguments, RowExpression percentileArgument, int percentilePosition)
    {
        ImmutableList.Builder<RowExpression> newAggCallArguments = new ImmutableList.Builder<>();
        for (int i = 0; i < arguments.size(); ++i) {
            if (i == percentilePosition) {
                newAggCallArguments.add(percentileArgument);
            }
            else {
                newAggCallArguments.add(arguments.get(i));
            }
        }
        return newAggCallArguments.build();
    }

    /**
     * Split the aggregations in candidateAggregations into multiple lists, with each list containing aggregations which can be merged.
     * Every aggregation ends up in exactly one list; an aggregation that merges with nothing forms a list of size one.
     */
    private static List<List<AggregationNode.Aggregation>> createMergeableAggregations(List<AggregationNode.Aggregation> candidateAggregations)
    {
        ImmutableList.Builder<List<AggregationNode.Aggregation>> result = ImmutableList.builder();
        // Aggregations which have already been assigned to a list
        Set<AggregationNode.Aggregation> mergedAggregation = new HashSet<>();
        for (int i = 0; i < candidateAggregations.size(); ++i) {
            if (mergedAggregation.contains(candidateAggregations.get(i))) {
                continue;
            }
            // Start a new list with the i-th aggregation, then collect every later unassigned aggregation which can merge with it
            ImmutableList.Builder<AggregationNode.Aggregation> aggregationCanBeMerged = ImmutableList.builder();
            mergedAggregation.add(candidateAggregations.get(i));
            aggregationCanBeMerged.add(candidateAggregations.get(i));
            for (int j = i + 1; j < candidateAggregations.size(); ++j) {
                if (mergedAggregation.contains(candidateAggregations.get(j))) {
                    continue;
                }
                if (aggregationCanMerge(candidateAggregations.get(i), candidateAggregations.get(j))) {
                    mergedAggregation.add(candidateAggregations.get(j));
                    aggregationCanBeMerged.add(candidateAggregations.get(j));
                }
            }
            result.add(aggregationCanBeMerged.build());
        }
        return result.build();
    }

    /**
     * Builds an array_constructor call whose elements are the percentile arguments of the given aggregations, in list order.
     * The array element type is taken from the first percentile argument.
     */
    private CallExpression createArrayPercentile(List<AggregationNode.Aggregation> aggregations)
    {
        List<RowExpression> percentileArray = aggregations.stream().map(x -> x.getArguments().get(getPercentilePosition(x.getFunctionHandle()))).collect(Collectors.toList());

        return call(
                functionAndTypeManager,
                ARRAY_CONSTRUCTOR,
                new ArrayType(percentileArray.get(0).getType()),
                percentileArray);
    }

    /**
     * Builds the combined approx_percentile aggregation for a list of mergeable aggregations.
     * The first aggregation of the list is used as the template: its percentile argument is replaced by
     * {@code arrayVariableReference}, and its filter, order by, distinct flag and mask are carried over.
     * The result type is an array of the type of the first call argument.
     */
    private AggregationNode.Aggregation createArrayAggregation(List<AggregationNode.Aggregation> candidateList, VariableReferenceExpression arrayVariableReference)
    {
        AggregationNode.Aggregation aggregationBeforeMerge = candidateList.get(0);
        int percentilePosition = getPercentilePosition(aggregationBeforeMerge.getFunctionHandle());
        List<RowExpression> newAggCallArguments = changePercentileArgument(aggregationBeforeMerge.getCall().getArguments(), arrayVariableReference, percentilePosition);
        Type colType = aggregationBeforeMerge.getCall().getArguments().get(0).getType();
        CallExpression approxPercentileCall = call(
                functionAndTypeManager,
                APPROX_PERCENTILE,
                new ArrayType(colType),
                newAggCallArguments);

        return new AggregationNode.Aggregation(
                approxPercentileCall,
                aggregationBeforeMerge.getFilter(),
                aggregationBeforeMerge.getOrderBy(),
                aggregationBeforeMerge.isDistinct(),
                aggregationBeforeMerge.getMask());
    }

    /**
     * Rewrites the aggregation node into Project - Aggregation - Project, where mergeable approx_percentile calls
     * are replaced by a single call taking an array of percentiles, and the original outputs are recovered
     * with element_at on the array result. Returns an empty result when nothing can be combined.
     */
    @Override
    public Result apply(AggregationNode aggregationNode, Captures captures, Context context)
    {
        ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();

        // Only approx_percentile which does not take array as percentile arguments
        List<AggregationNode.Aggregation> approxPercentile = aggregationNode.getAggregations().values().stream().filter(
                x -> x.getCall().getDisplayName().equals(APPROX_PERCENTILE) && !(x.getCall().getType() instanceof ArrayType)
        ).collect(Collectors.toList());

        // Remove aggregations which occur more than once, as we assume that there are no duplicates in later stages
        Map<AggregationNode.Aggregation, Long> aggregationOccurrences = approxPercentile.stream().collect(Collectors.groupingBy(identity(), Collectors.counting()));
        ImmutableList<AggregationNode.Aggregation> candidateApproxPercentile = approxPercentile.stream().filter(x -> aggregationOccurrences.get(x) == 1).collect(toImmutableList());

        // Group the aggregations on the same column and have the same function handle
        Map<RowExpression, Map<FunctionHandle, List<AggregationNode.Aggregation>>> sameColumnHandle =
                candidateApproxPercentile.stream().collect(Collectors.groupingBy(x -> x.getCall().getArguments().get(0), LinkedHashMap::new,
                        Collectors.groupingBy(x -> x.getFunctionHandle(), LinkedHashMap::new, Collectors.toList())));

        // Each list contains the aggregations which can be combined
        ImmutableList.Builder<List<AggregationNode.Aggregation>> candidateLists = ImmutableList.builder();
        sameColumnHandle.values().forEach(sameHandle -> {
            sameHandle.values().forEach(aggregationList -> {
                candidateLists.addAll(createMergeableAggregations(aggregationList));
            });
        });
        // ArrayConstructor does not support more than 254 elements, so a list of exactly ARRAY_SIZE_LIMIT aggregations is still eligible
        List<List<AggregationNode.Aggregation>> candidateAggregationLists =
                candidateLists.build().stream().filter(x -> x.size() > 1 && x.size() <= ARRAY_SIZE_LIMIT).collect(Collectors.toList());

        if (candidateAggregationLists.isEmpty()) {
            return Result.empty();
        }

        // Record which aggregations are combined
        Set<AggregationNode.Aggregation> combinedAggregations = candidateAggregationLists.stream().flatMap(List::stream).collect(Collectors.toSet());
        // Record mapping between aggregation and corresponding output variable reference
        Map<AggregationNode.Aggregation, VariableReferenceExpression> aggregationVariableMap = new HashMap<>();
        // Record which variable references are combined
        Set<VariableReferenceExpression> combinedVariableReference = new HashSet<>();
        aggregationNode.getAggregations().forEach((variable, aggregation) -> {
            if (combinedAggregations.contains(aggregation)) {
                aggregationVariableMap.put(aggregation, variable);
                combinedVariableReference.add(variable);
            }
        });

        // Build a project node as the source of the aggregation. This contains aggregation inputs and percentile arrays
        Assignments.Builder sourceProjectAssignments = Assignments.builder();
        // Build a project node as the output of the aggregation, do subscription for the combined approx_percentile, and keep the others
        Assignments.Builder outputProjectAssignments = Assignments.builder();
        for (List<AggregationNode.Aggregation> candidateList : candidateAggregationLists) {
            // Build array of percentile arguments
            RowExpression arrayExpression = createArrayPercentile(candidateList);
            VariableReferenceExpression arrayVariableReference = context.getVariableAllocator().newVariable(arrayExpression);
            sourceProjectAssignments.put(arrayVariableReference, arrayExpression);

            // Build aggregations taking percentile array as arguments
            AggregationNode.Aggregation newAggregation = createArrayAggregation(candidateList, arrayVariableReference);
            VariableReferenceExpression newVariableReference = context.getVariableAllocator().newVariable(newAggregation.getCall());
            aggregations.put(newVariableReference, newAggregation);

            // Build element_at expression. The x-th aggregation of the list reads position x + 1 of the array result
            Map<VariableReferenceExpression, RowExpression> elementAtMap =
                    IntStream.range(0, candidateList.size()).boxed().collect(ImmutableMap.toImmutableMap(
                            x -> aggregationVariableMap.get(candidateList.get(x)),
                            x -> call(
                                    functionAndTypeManager,
                                    ELEMENT_AT,
                                    candidateList.get(x).getArguments().get(0).getType(),
                                    ImmutableList.of(newVariableReference, constant((long) x + 1, BIGINT)))));
            outputProjectAssignments.putAll(elementAtMap);
        }

        // Add aggregations which are not combined to the new aggregation node.
        aggregationNode.getAggregations().forEach((key, value) -> {
            if (!combinedVariableReference.contains(key)) {
                aggregations.put(key, value);
            }
        });

        // Add output of the old aggregations which are not changed in the rewrite to the parent projection node.
        aggregationNode.getOutputVariables().forEach(variable -> {
            if (!combinedVariableReference.contains(variable)) {
                outputProjectAssignments.put(variable, variable);
            }
        });

        // Pass every output of the original source through the new source projection unchanged
        aggregationNode.getSource().getOutputVariables().forEach(variable -> sourceProjectAssignments.put(variable, variable));

        return Result.ofPlanNode(
                new ProjectNode(context.getIdAllocator().getNextId(),
                        new AggregationNode(
                                aggregationNode.getSourceLocation(),
                                context.getIdAllocator().getNextId(),
                                new ProjectNode(context.getIdAllocator().getNextId(),
                                        aggregationNode.getSource(), sourceProjectAssignments.build()),
                                aggregations.build(),
                                aggregationNode.getGroupingSets(),
                                ImmutableList.of(),
                                aggregationNode.getStep(),
                                aggregationNode.getHashVariable(),
                                aggregationNode.getGroupIdVariable(),
                                aggregationNode.getAggregationId()),
                        outputProjectAssignments.build()));
    }
}