AggregationNodeUtils.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.cost.CachingStatsProvider;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.StatsCalculator;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.VariablesExtractor;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.LOW;
import static com.google.common.collect.ImmutableList.toImmutableList;

/**
 * Static helpers for building and inspecting {@link AggregationNode.Aggregation} instances.
 */
public class AggregationNodeUtils
{
    private AggregationNodeUtils() {}

    /**
     * Builds an argument-less {@code count} aggregation.
     *
     * <p>The call is typed as BIGINT and resolved through the given manager's function
     * resolver; the resulting aggregation is non-distinct and every optional component
     * is left empty.
     *
     * @param functionAndTypeManager manager used to resolve the {@code count} function handle
     * @return a new aggregation wrapping the {@code count} call
     */
    public static AggregationNode.Aggregation count(FunctionAndTypeManager functionAndTypeManager)
    {
        FunctionResolution functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
        CallExpression countCall = new CallExpression(
                "count",
                functionResolution.countFunction(),
                BIGINT,
                ImmutableList.of());

        return new AggregationNode.Aggregation(
                countCall,
                Optional.empty(),
                Optional.empty(),
                false,
                Optional.empty());
    }

    /**
     * Collects the distinct variables referenced by an aggregation.
     *
     * <p>Variables are gathered, in this order, from the aggregation's arguments, its filter
     * (when present) and its order-by variables (when present).
     *
     * @param aggregation aggregation to inspect
     * @return the set of referenced variables, without duplicates
     */
    public static Set<VariableReferenceExpression> extractAggregationUniqueVariables(AggregationNode.Aggregation aggregation)
    {
        ImmutableSet.Builder<VariableReferenceExpression> variables = ImmutableSet.builder();

        for (RowExpression argument : aggregation.getArguments()) {
            variables.addAll(extractAll(argument));
        }

        Optional<RowExpression> filter = aggregation.getFilter();
        if (filter.isPresent()) {
            variables.addAll(extractAll(filter.get()));
        }

        aggregation.getOrderBy().ifPresent(ordering -> variables.addAll(ordering.getOrderByVariables()));

        return variables.build();
    }

    /**
     * Returns an immutable list copy of everything {@link VariablesExtractor#extractAll} reports for the expression.
     */
    private static List<VariableReferenceExpression> extractAll(RowExpression expression)
    {
        return VariablesExtractor.extractAll(expression).stream().collect(toImmutableList());
    }

    /**
     * Decides whether every group-by key of the aggregation is considered low cardinality,
     * based on the statistics estimated for the given table scan.
     *
     * <p>A key is treated as high cardinality when its estimated distinct-values count is
     * greater than or equal to {@code count}. When the scan estimate has LOW confidence the
     * method answers {@code true} without looking at the keys.
     *
     * @param aggregationNode aggregation whose grouping keys are examined
     * @param scanNode table scan whose statistics are consulted
     * @param session session handed to the stats provider
     * @param statsCalculator calculator backing the stats provider
     * @param types type information handed to the stats provider
     * @param count threshold; a distinct-values count at or above it disqualifies a key
     * @return {@code true} if no grouping key reaches the threshold, or if the estimate is low confidence
     */
    public static boolean isAllLowCardinalityGroupByKeys(AggregationNode aggregationNode, TableScanNode scanNode, Session session, StatsCalculator statsCalculator, TypeProvider types, long count)
    {
        List<VariableReferenceExpression> groupingKeys = aggregationNode.getGroupingSets()
                .getGroupingKeys()
                .stream()
                .collect(Collectors.toList());

        StatsProvider provider = new CachingStatsProvider(statsCalculator, session, types);
        PlanNodeStatsEstimate scanStats = provider.getStats(scanNode);

        if (scanStats.confidenceLevel() == LOW) {
            // Without a trustworthy estimate, err on the side of treating the keys as low cardinality.
            // TODO(kaikalur) : maybe return low card only for partition keys if/when we can detect that
            return true;
        }

        // NOTE(review): if a key's distinct-values count is unknown and surfaces as NaN, the
        // comparison below is false and the key counts as low cardinality -- confirm that is intended.
        for (VariableReferenceExpression key : groupingKeys) {
            if (scanStats.getVariableStatistics(key).getDistinctValuesCount() >= count) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the aggregation with its filter and mask removed.
     *
     * <p>If the aggregation has neither, the very same instance is returned; otherwise a new
     * aggregation is created that keeps the call, ordering and distinct flag.
     *
     * @param aggregation aggregation to strip
     * @return an aggregation without filter or mask
     */
    public static AggregationNode.Aggregation removeFilterAndMask(AggregationNode.Aggregation aggregation)
    {
        boolean hasFilter = aggregation.getFilter().isPresent();
        boolean hasMask = aggregation.getMask().isPresent();

        if (!hasFilter && !hasMask) {
            // Nothing to strip: hand back the original instance untouched.
            return aggregation;
        }

        return new AggregationNode.Aggregation(
                aggregation.getCall(),
                Optional.empty(),
                aggregation.getOrderBy(),
                aggregation.isDistinct(),
                Optional.empty());
    }
}