TestRewriteIfOverAggregation.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.assertions.BasePlanTest;
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.testng.annotations.Test;

import java.util.Optional;

import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_CONDITIONAL_AGGREGATION_ENABLED;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.groupingSet;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan;

/**
 * Plan tests for the conditional-aggregation rewrite that is switched on by the
 * {@code OPTIMIZE_CONDITIONAL_AGGREGATION_ENABLED} session property.
 *
 * <p>When the {@code IF} condition depends only on {@code GROUPING(...)}, the expected plan
 * carries the condition as an aggregation mask instead of an {@code IF}. When the condition
 * depends on another aggregate, the {@code IF} is expected to stay in a projection above
 * the aggregation.
 */
public class TestRewriteIfOverAggregation
        extends BasePlanTest
{
    // FROM / GROUP BY clause shared by the grouping-set queries in this class.
    private static final String ORDERS_GROUPING_SETS =
            "FROM orders GROUP BY GROUPING SETS ((orderstatus), (orderstatus, shippriority))";

    /**
     * Builds a session from the query runner's default session with
     * {@code OPTIMIZE_CONDITIONAL_AGGREGATION_ENABLED} set to {@code "true"}.
     */
    private Session enableOptimization()
    {
        Session defaults = getQueryRunner().getDefaultSession();
        return Session.builder(defaults)
                .setSystemProperty(OPTIMIZE_CONDITIONAL_AGGREGATION_ENABLED, "true")
                .build();
    }

    /**
     * An {@code IF} guarded by {@code GROUPING(...) = 0} around {@code sum} is expected to become
     * a masked PARTIAL {@code sum} aggregation.
     */
    @Test
    public void testConditionOnGrouping()
    {
        String sql = "SELECT orderstatus, shippriority, IF(GROUPING(orderstatus, shippriority) = 0, sum(totalprice)) "
                + ORDERS_GROUPING_SETS;
        assertPlan(
                sql,
                enableOptimization(),
                maskedGroupingSetAggregation(
                        "pricesum",
                        "sum",
                        ImmutableList.of("totalprice"),
                        ImmutableMap.of("totalprice", "totalprice")));
    }

    /**
     * The condition compares an aggregate ({@code count(1)}) rather than {@code GROUPING(...)},
     * so the query should not be rewritten. The expected plan keeps the {@code IF} in a
     * projection above the final aggregation.
     */
    @Test
    public void testConditionOnAggregation()
    {
        PlanMatchPattern scan = tableScan("orders", ImmutableMap.of("totalprice", "totalprice", "orderpriority", "orderpriority"));

        PlanMatchPattern partialAggregation = aggregation(
                ImmutableMap.of(
                        "partial_avg", functionCall("avg", ImmutableList.of("totalprice")),
                        "partial_count", functionCall("count", ImmutableList.of())),
                project(
                        ImmutableMap.of("totalprice", expression("totalprice"), "orderpriority", expression("orderpriority")),
                        scan));

        PlanMatchPattern finalAggregation = aggregation(
                ImmutableMap.of(
                        "avg", functionCall("avg", ImmutableList.of("partial_avg")),
                        "count", functionCall("count", ImmutableList.of("partial_count"))),
                exchange(partialAggregation));

        assertPlan(
                "select orderpriority, if(count(1)>3000, avg(totalprice)) from orders group by orderpriority ",
                enableOptimization(),
                anyTree(
                        project(
                                ImmutableMap.of("ifexp", expression("if(count > 3000, avg, null)")),
                                finalAggregation)));
    }

    /**
     * Same shape as {@link #testConditionOnGrouping()}, but uses the two-argument aggregate
     * {@code max_by}. Both arguments are expected to pass through the grouping-set node.
     */
    @Test
    public void testMultipleArgumentsAggregation()
    {
        String sql = "SELECT orderstatus, shippriority, IF(GROUPING(orderstatus, shippriority) = 0, max_by(shippriority, totalprice)) "
                + ORDERS_GROUPING_SETS;
        assertPlan(
                sql,
                enableOptimization(),
                maskedGroupingSetAggregation(
                        "result",
                        "max_by",
                        ImmutableList.of("shippriority", "totalprice"),
                        ImmutableMap.of("totalprice", "totalprice", "shippriority", "shippriority")));
    }

    /**
     * Builds the expected plan for {@code IF(GROUPING(orderstatus, shippriority) = 0, agg(...))}
     * over {@code GROUPING SETS ((orderstatus), (orderstatus, shippriority))} on {@code orders}.
     *
     * <p>The pattern is a PARTIAL aggregation keyed on the grouping-set columns plus
     * {@code groupid}. Its single aggregate is masked by a {@code mask} symbol, which is
     * projected on top of the grouping-set node.
     *
     * @param aggregateAlias alias bound to the masked aggregate
     * @param aggregateFunction name of the aggregate function expected in the plan
     * @param aggregateArguments argument symbols of the aggregate function
     * @param groupingSetPassThrough aggregation arguments carried through the grouping-set node
     * @return the plan pattern, wrapped in {@code anyTree}
     */
    private static PlanMatchPattern maskedGroupingSetAggregation(
            String aggregateAlias,
            String aggregateFunction,
            ImmutableList<String> aggregateArguments,
            ImmutableMap<String, String> groupingSetPassThrough)
    {
        // Assuming groupid is the 0-based index of the grouping set: for (orderstatus)
        // GROUPING(orderstatus, shippriority) is 1, and for (orderstatus, shippriority) it is 0.
        // SQL arrays are 1-based, hence the groupid+1 subscript.
        PlanMatchPattern maskedSource = project(
                ImmutableMap.of("mask", expression("array[1, 0][groupid+1]=0")),
                groupingSet(
                        ImmutableList.of(ImmutableList.of("orderstatus"), ImmutableList.of("orderstatus", "shippriority")),
                        groupingSetPassThrough,
                        "groupid",
                        ImmutableMap.of("orderstatus$gid", expression("orderstatus"), "shippriority$gid", expression("shippriority")),
                        tableScan("orders", ImmutableMap.of("totalprice", "totalprice", "orderstatus", "orderstatus", "shippriority", "shippriority"))));

        return anyTree(
                aggregation(
                        new PlanMatchPattern.GroupingSetDescriptor(ImmutableList.of("orderstatus$gid", "shippriority$gid", "groupid"), 2, ImmutableSet.of()),
                        ImmutableMap.of(Optional.of(aggregateAlias), functionCall(aggregateFunction, aggregateArguments)),
                        ImmutableMap.of(new Symbol(aggregateAlias), new Symbol("mask")),
                        Optional.of(new Symbol("groupid")),
                        AggregationNode.Step.PARTIAL,
                        maskedSource));
    }
}