RemoveRedundantDistinctAggregation.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.plan.InternalPlanVisitor;
import com.google.common.collect.ImmutableList;

import java.util.List;
import java.util.Set;

import static com.facebook.presto.SystemSessionProperties.isRemoveRedundantDistinctAggregationEnabled;
import static com.facebook.presto.spi.plan.AggregationNode.isDistinct;
import static com.facebook.presto.sql.planner.plan.ChildReplacer.replaceChildren;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.util.Objects.requireNonNull;

/**
 * Removes a distinct aggregation when the output of its source is already distinct on the same keys.
 * For example, for the query {@code select distinct k, sum(x) from table group by k}
 * the plan will change
 * <p>
 * From:
 * <pre>
 * - Aggregation group by k, sum
 *   - Aggregation (sum &lt;- AGG(x)) group by k
 * </pre>
 * To:
 * <pre>
 * - Aggregation (sum &lt;- AGG(x)) group by k
 * </pre>
 */
public class RemoveRedundantDistinctAggregation
        implements PlanOptimizer
{
    // When set, the optimizer runs regardless of the session property.
    private boolean isEnabledForTesting;

    /**
     * Forces the optimizer on (or stops forcing it) independently of the session property.
     *
     * @param isSet {@code true} to enable the optimizer unconditionally
     */
    @Override
    public void setEnabledForTesting(boolean isSet)
    {
        isEnabledForTesting = isSet;
    }

    /**
     * @param session session whose properties are consulted
     * @return {@code true} if the testing override is set or the session property enables this optimizer
     */
    @Override
    public boolean isEnabled(Session session)
    {
        return isEnabledForTesting || isRemoveRedundantDistinctAggregationEnabled(session);
    }

    /**
     * Rewrites {@code plan}, dropping distinct aggregations that are found to be redundant.
     * The plan is returned untouched (and reported as unchanged) when the optimizer is disabled.
     *
     * @return the resulting plan together with a flag telling whether any distinct aggregation was removed
     */
    @Override
    public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector)
    {
        if (isEnabled(session)) {
            Rewriter rewriter = new Rewriter();
            PlanWithProperties result = rewriter.accept(plan);
            return PlanOptimizerResult.optimizerResult(result.getNode(), rewriter.isPlanChanged());
        }
        return PlanOptimizerResult.optimizerResult(plan, false);
    }

    /**
     * A rewritten plan node paired with the distinctness information known about its output.
     */
    private static final class PlanWithProperties
    {
        private final PlanNode node;
        // Variables in each set combines to be distinct in the output of the plan node.
        private final List<Set<VariableReferenceExpression>> distinctVariableSet;

        /**
         * @param node the (possibly rewritten) plan node, must not be null
         * @param distinctVariableSet sets of variables on which the output of {@code node} is distinct,
         * must not be null; an empty list means nothing is known
         */
        public PlanWithProperties(PlanNode node, List<Set<VariableReferenceExpression>> distinctVariableSet)
        {
            this.node = requireNonNull(node, "node is null");
            this.distinctVariableSet = requireNonNull(distinctVariableSet, "distinctVariableSet is null");
        }

        public PlanNode getNode()
        {
            return node;
        }

        /**
         * @return the sets of variables on which the output of the node is distinct
         */
        public List<Set<VariableReferenceExpression>> getProperties()
        {
            return distinctVariableSet;
        }
    }

    /**
     * Visitor that rebuilds the plan bottom-up, tracking which variable sets are distinct in the
     * output of each node and removing distinct aggregations covered by such a set.
     */
    private static final class Rewriter
            extends InternalPlanVisitor<PlanWithProperties, Void>
    {
        // Set once at least one distinct aggregation has been removed.
        private boolean planChanged;

        public boolean isPlanChanged()
        {
            return planChanged;
        }

        @Override
        public PlanWithProperties visitPlan(PlanNode node, Void context)
        {
            // For nodes such as join, unnest etc. the distinct properties may be violated, hence pass empty list for these cases.
            return planAndReplace(node, false);
        }

        @Override
        public PlanWithProperties visitAggregation(AggregationNode node, Void context)
        {
            PlanWithProperties child = accept(node.getSource());
            // The distinct is redundant if some variable set already known to be distinct in the
            // child's output is fully contained in the grouping keys of this distinct aggregation.
            // NOTE(review): the child replaces this node as-is, so this presumably relies on the
            // child's output variables being acceptable to the parent - confirm.
            if (isDistinct(node) && child.getProperties().stream().anyMatch(node.getGroupingKeys()::containsAll)) {
                planChanged = true;
                return child;
            }
            ImmutableList.Builder<Set<VariableReferenceExpression>> properties = ImmutableList.builder();
            // Only do it for aggregations with one single grouping set
            if (node.getGroupingSetCount() == 1 && !node.getGroupingKeys().isEmpty()) {
                properties.add(node.getGroupingKeys().stream().collect(toImmutableSet()));
            }
            // Properties of the child are not propagated: only this node's grouping keys are reported.
            PlanNode newAggregation = node.replaceChildren(ImmutableList.of(child.getNode()));
            return new PlanWithProperties(newAggregation, properties.build());
        }

        @Override
        public PlanWithProperties visitProject(ProjectNode node, Void context)
        {
            // Projections forward the distinct properties of their source unchanged.
            return planAndReplace(node, true);
        }

        /**
         * Visits {@code node} and assigns the stats-equivalent plan node of the original node
         * to the node produced by the visit.
         */
        private PlanWithProperties accept(PlanNode node)
        {
            PlanWithProperties result = node.accept(this, null);
            return new PlanWithProperties(
                    result.getNode().assignStatsEquivalentPlanNode(node.getStatsEquivalentPlanNode()),
                    result.getProperties());
        }

        /**
         * Rewrites all sources of {@code node} and rebuilds the node on top of the rewritten sources.
         *
         * @param node the node whose sources are rewritten
         * @param passProperties if {@code true}, the distinct properties of all sources are concatenated
         * and attached to the result; otherwise the result carries no properties
         */
        private PlanWithProperties planAndReplace(PlanNode node, boolean passProperties)
        {
            List<PlanWithProperties> children = node.getSources().stream().map(this::accept).collect(toImmutableList());
            PlanNode result = replaceChildren(node, children.stream().map(PlanWithProperties::getNode).collect(toImmutableList()));
            if (!passProperties) {
                return new PlanWithProperties(result, ImmutableList.of());
            }
            ImmutableList.Builder<Set<VariableReferenceExpression>> properties = ImmutableList.builder();
            children.stream().map(PlanWithProperties::getProperties).forEach(properties::addAll);
            return new PlanWithProperties(result, properties.build());
        }
    }
}