// TestPruneOrderByInAggregation.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Ordering;
import com.facebook.presto.spi.plan.OrderingScheme;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.facebook.presto.sql.tree.SortItem;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.Test;

import java.util.List;
import java.util.Optional;

import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST;
import static com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.singleGroupingSet;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.sort;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values;

public class TestPruneOrderByInAggregation
        extends BaseRuleTest
{
    private static final FunctionAndTypeManager FUNCTION_MANAGER = MetadataManager.createTestMetadataManager().getFunctionAndTypeManager();

    @Test
    public void testBasics()
    {
        tester().assertThat(new PruneOrderByInAggregation(FUNCTION_MANAGER))
                .on(this::buildAggregation)
                .matches(
                        aggregation(
                                singleGroupingSet("key"),
                                ImmutableMap.of(
                                        Optional.of("avg"), functionCall("avg", ImmutableList.of("input")),
                                        Optional.of("array_agg"), functionCall(
                                                "array_agg",
                                                ImmutableList.of("input"),
                                                ImmutableList.of(sort("input", SortItem.Ordering.ASCENDING, SortItem.NullOrdering.UNDEFINED)))),
                                ImmutableMap.of(
                                        new Symbol("avg"), new Symbol("mask"),
                                        new Symbol("array_agg"), new Symbol("mask")),
                                Optional.empty(),
                                SINGLE,
                                values("input", "key", "keyHash", "mask")));
    }

    private AggregationNode buildAggregation(PlanBuilder planBuilder)
    {
        VariableReferenceExpression avg = planBuilder.variable("avg");
        VariableReferenceExpression arrayAgg = planBuilder.variable("array_agg");
        VariableReferenceExpression input = planBuilder.variable("input");
        VariableReferenceExpression key = planBuilder.variable("key");
        VariableReferenceExpression keyHash = planBuilder.variable("keyHash");
        VariableReferenceExpression mask = planBuilder.variable("mask");
        List<VariableReferenceExpression> sourceVariables = ImmutableList.of(input, key, keyHash, mask);
        OrderingScheme orderingScheme = new OrderingScheme(ImmutableList.of(new Ordering(input, ASC_NULLS_LAST)));

        return planBuilder.aggregation(aggregationBuilder -> aggregationBuilder
                .singleGroupingSet(key)
                .addAggregation(avg, planBuilder.rowExpression("avg(input order by input)"), Optional.empty(), Optional.of(orderingScheme), false, Optional.of(mask))
                .addAggregation(arrayAgg, planBuilder.rowExpression("array_agg(input order by input)"), Optional.empty(), Optional.of(orderingScheme), false, Optional.of(mask))
                .hashVariable(keyHash)
                .source(planBuilder.values(sourceVariables, ImmutableList.of())));
    }
}