MultipleDistinctAggregationToMarkDistinct.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.AggregationNode.Aggregation;
import com.facebook.presto.spi.plan.MarkDistinctNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.sql.planner.plan.Patterns.aggregation;
import static java.util.stream.Collectors.toSet;

/**
 * Implements distinct aggregations with different inputs by transforming plans of the following shape:
 * <pre>
 * - Aggregation
 *        GROUP BY (k)
 *        F1(DISTINCT a0, a1, ...)
 *        F2(DISTINCT b0, b1, ...)
 *        F3(c0, c1, ...)
 *     - X
 * </pre>
 * into
 * <pre>
 * - Aggregation
 *        GROUP BY (k)
 *        F1(a0, a1, ...) mask ($0)
 *        F2(b0, b1, ...) mask ($1)
 *        F3(c0, c1, ...)
 *     - MarkDistinct (k, a0, a1, ...) -> $0
 *          - MarkDistinct (k, b0, b1, ...) -> $1
 *              - X
 * </pre>
 */
public class MultipleDistinctAggregationToMarkDistinct
        implements Rule<AggregationNode>
{
    private static final Pattern<AggregationNode> PATTERN = aggregation()
            .matching(
                    Predicates.and(
                            MultipleDistinctAggregationToMarkDistinct::hasNoDistinctWithFilterOrMask,
                            Predicates.or(
                                    MultipleDistinctAggregationToMarkDistinct::hasMultipleDistincts,
                                    MultipleDistinctAggregationToMarkDistinct::hasMixedDistinctAndNonDistincts)));

    private static boolean hasNoDistinctWithFilterOrMask(AggregationNode aggregation)
    {
        return aggregation.getAggregations()
                .values().stream()
                .noneMatch(e -> e.isDistinct() && (e.getFilter().isPresent() || e.getMask().isPresent()));
    }

    private static boolean hasMultipleDistincts(AggregationNode aggregation)
    {
        return aggregation.getAggregations()
                .values().stream()
                .filter(e -> e.isDistinct())
                .map(Aggregation::getArguments)
                .map(HashSet::new)
                .distinct()
                .count() > 1;
    }

    private static boolean hasMixedDistinctAndNonDistincts(AggregationNode aggregation)
    {
        long distincts = aggregation.getAggregations()
                .values().stream()
                .filter(Aggregation::isDistinct)
                .count();

        return distincts > 0 && distincts < aggregation.getAggregations().size();
    }

    @Override
    public Pattern<AggregationNode> getPattern()
    {
        return PATTERN;
    }

    @Override
    public Result apply(AggregationNode parent, Captures captures, Context context)
    {
        if (!SystemSessionProperties.useMarkDistinct(context.getSession())) {
            return Result.empty();
        }

        // the distinct marker for the given set of input columns
        Map<Set<VariableReferenceExpression>, VariableReferenceExpression> markers = new HashMap<>();

        Map<VariableReferenceExpression, Aggregation> newAggregations = new HashMap<>();
        PlanNode subPlan = parent.getSource();

        for (Map.Entry<VariableReferenceExpression, Aggregation> entry : parent.getAggregations().entrySet()) {
            Aggregation aggregation = entry.getValue();

            if (aggregation.isDistinct() && !aggregation.getFilter().isPresent() && !aggregation.getMask().isPresent()) {
                Set<VariableReferenceExpression> inputs = aggregation.getArguments().stream()
                        .map(VariableReferenceExpression.class::cast)
                        .collect(toSet());

                VariableReferenceExpression marker = markers.get(inputs);
                if (marker == null) {
                    marker = context.getVariableAllocator().newVariable(Iterables.getLast(inputs).getName(), BOOLEAN, "distinct");
                    markers.put(inputs, marker);

                    ImmutableSet.Builder<VariableReferenceExpression> distinctVariables = ImmutableSet.<VariableReferenceExpression>builder()
                            .addAll(parent.getGroupingKeys())
                            .addAll(inputs);
                    parent.getGroupIdVariable().ifPresent(distinctVariables::add);

                    subPlan = new MarkDistinctNode(
                            subPlan.getSourceLocation(),
                            context.getIdAllocator().getNextId(),
                            subPlan,
                            marker,
                            ImmutableList.copyOf(distinctVariables.build()),
                            Optional.empty());
                }

                // remove the distinct flag and set the distinct marker
                newAggregations.put(entry.getKey(),
                        new Aggregation(
                                aggregation.getCall(),
                                aggregation.getFilter(),
                                aggregation.getOrderBy(),
                                false,
                                Optional.of(marker)));
            }
            else {
                newAggregations.put(entry.getKey(), aggregation);
            }
        }

        return Result.ofPlanNode(
                new AggregationNode(
                        parent.getSourceLocation(),
                        parent.getId(),
                        subPlan,
                        newAggregations,
                        parent.getGroupingSets(),
                        ImmutableList.of(),
                        parent.getStep(),
                        parent.getHashVariable(),
                        parent.getGroupIdVariable(),
                        parent.getAggregationId()));
    }
}