TestMergePartialAggregationsWithFilter.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.optimizations;
import com.facebook.presto.Session;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.assertions.BasePlanTest;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.testng.annotations.Test;
import java.util.Optional;
import static com.facebook.presto.SystemSessionProperties.MERGE_AGGREGATIONS_WITH_AND_WITHOUT_FILTER;
import static com.facebook.presto.SystemSessionProperties.PARTIAL_AGGREGATION_STRATEGY;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.GroupingSetDescriptor;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.globalAggregation;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.groupingSet;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.singleGroupingSet;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.sort;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan;
import static com.facebook.presto.sql.tree.SortItem.NullOrdering.LAST;
import static com.facebook.presto.sql.tree.SortItem.Ordering.ASCENDING;
public class TestMergePartialAggregationsWithFilter
extends BasePlanTest
{
private Session enableOptimization()
{
return Session.builder(this.getQueryRunner().getDefaultSession())
.setSystemProperty(MERGE_AGGREGATIONS_WITH_AND_WITHOUT_FILTER, "true")
.setSystemProperty(PARTIAL_AGGREGATION_STRATEGY, "AUTOMATIC")
.build();
}
private Session disableOptimization()
{
return Session.builder(this.getQueryRunner().getDefaultSession())
.setSystemProperty(MERGE_AGGREGATIONS_WITH_AND_WITHOUT_FILTER, "false")
.setSystemProperty(PARTIAL_AGGREGATION_STRATEGY, "AUTOMATIC")
.build();
}
@Test
public void testOptimizationApplied()
{
assertPlan("SELECT partkey, sum(quantity), sum(quantity) filter (where orderkey > 0) from lineitem group by partkey",
enableOptimization(),
anyTree(
aggregation(
singleGroupingSet("partkey"),
ImmutableMap.of(Optional.of("finalSum"), functionCall("sum", ImmutableList.of("partialSum")), Optional.of("maskFinalSum"), functionCall("sum", ImmutableList.of("maskPartialSum"))),
ImmutableMap.of(),
Optional.empty(),
AggregationNode.Step.FINAL,
project(
ImmutableMap.of("maskPartialSum", expression("IF(expr, partialSum, null)")),
anyTree(
aggregation(
singleGroupingSet("partkey", "expr"),
ImmutableMap.of(Optional.of("partialSum"), functionCall("sum", ImmutableList.of("quantity"))),
ImmutableMap.of(),
Optional.empty(),
AggregationNode.Step.PARTIAL,
anyTree(
project(
ImmutableMap.of("expr", expression("orderkey > 0")),
tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "partkey", "partkey", "quantity", "quantity"))))))))),
false);
}
@Test
public void testOptimizationAppliedAllHasMask()
{
assertPlan("SELECT partkey, sum(quantity) filter (where orderkey > 10), sum(quantity) filter (where orderkey > 0) from lineitem group by partkey",
enableOptimization(),
anyTree(
aggregation(
singleGroupingSet("partkey"),
ImmutableMap.of(Optional.of("maskFinalSum"), functionCall("sum", ImmutableList.of("maskPartialSum")),
Optional.of("maskFinalSum2"), functionCall("sum", ImmutableList.of("maskPartialSum2"))),
ImmutableMap.of(),
Optional.empty(),
AggregationNode.Step.FINAL,
project(
ImmutableMap.of("maskPartialSum", expression("IF(expr, partialSum, null)"),
"maskPartialSum2", expression("IF(expr2, partialSum, null)")),
anyTree(
aggregation(
singleGroupingSet("partkey", "expr", "expr2"),
ImmutableMap.of(Optional.of("partialSum"), functionCall("sum", ImmutableList.of("quantity"))),
ImmutableMap.of(new Symbol("partialSum"), new Symbol("expr_or")),
Optional.empty(),
AggregationNode.Step.PARTIAL,
project(
ImmutableMap.of("expr_or", expression("expr or expr2")),
project(
ImmutableMap.of("expr", expression("orderkey > 0"), "expr2", expression("orderkey >10")),
tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "partkey", "partkey", "quantity", "quantity"))))))))),
false);
}
@Test
public void testOptimizationDisabled()
{
assertPlan("SELECT partkey, sum(quantity), sum(quantity) filter (where orderkey > 0) from lineitem group by partkey",
disableOptimization(),
anyTree(
aggregation(
singleGroupingSet("partkey"),
ImmutableMap.of(Optional.of("finalSum"), functionCall("sum", ImmutableList.of("partialSum")), Optional.of("maskFinalSum"), functionCall("sum", ImmutableList.of("maskPartialSum"))),
ImmutableMap.of(),
Optional.empty(),
AggregationNode.Step.FINAL,
anyTree(
aggregation(
singleGroupingSet("partkey"),
ImmutableMap.of(Optional.of("partialSum"), functionCall("sum", ImmutableList.of("quantity")), Optional.of("maskPartialSum"), functionCall("sum", ImmutableList.of("quantity"))),
ImmutableMap.of(new Symbol("maskPartialSum"), new Symbol("expr")),
Optional.empty(),
AggregationNode.Step.PARTIAL,
project(
ImmutableMap.of("expr", expression("orderkey > 0")),
tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "partkey", "partkey", "quantity", "quantity"))))))),
false);
}
@Test
public void testMultipleAggregations()
{
assertPlan("SELECT partkey, sum(quantity), sum(quantity) filter (where orderkey > 0), avg(quantity), avg(quantity) filter (where orderkey > 0) from lineitem group by partkey",
enableOptimization(),
anyTree(
aggregation(
singleGroupingSet("partkey"),
ImmutableMap.of(Optional.of("finalSum"), functionCall("sum", ImmutableList.of("partialSum")), Optional.of("maskFinalSum"), functionCall("sum", ImmutableList.of("maskPartialSum")),
Optional.of("finalAvg"), functionCall("avg", ImmutableList.of("partialAvg")), Optional.of("maskFinalAvg"), functionCall("avg", ImmutableList.of("maskPartialAvg"))),
ImmutableMap.of(),
Optional.empty(),
AggregationNode.Step.FINAL,
project(
ImmutableMap.of("maskPartialSum", expression("IF(expr, partialSum, null)"), "maskPartialAvg", expression("IF(expr, partialAvg, null)")),
anyTree(
aggregation(
singleGroupingSet("partkey", "expr"),
ImmutableMap.of(Optional.of("partialSum"), functionCall("sum", ImmutableList.of("quantity")), Optional.of("partialAvg"), functionCall("avg", ImmutableList.of("quantity"))),
ImmutableMap.of(),
Optional.empty(),
AggregationNode.Step.PARTIAL,
anyTree(
project(
ImmutableMap.of("expr", expression("orderkey > 0")),
tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "partkey", "partkey", "quantity", "quantity"))))))))),
false);
}
@Test
public void testAggregationsMultipleLevel()
{
assertPlan("select partkey, avg(sum), avg(sum) filter (where suppkey > 0), avg(filtersum) from (select partkey, suppkey, sum(quantity) sum, sum(quantity) filter (where orderkey > 0) filtersum from lineitem group by partkey, suppkey) t group by partkey",
enableOptimization(),
anyTree(
aggregation(
singleGroupingSet("partkey"),
ImmutableMap.of(Optional.of("finalAvg"), functionCall("avg", ImmutableList.of("partialAvg")), Optional.of("maskFinalAvg"), functionCall("avg", ImmutableList.of("maskPartialAvg")),
Optional.of("finalFilterAvg"), functionCall("avg", ImmutableList.of("partialFilterAvg"))),
ImmutableMap.of(),
Optional.empty(),
AggregationNode.Step.FINAL,
project(
ImmutableMap.of("maskPartialAvg", expression("IF(expr_2, partialAvg, null)")),
anyTree(
aggregation(
singleGroupingSet("partkey", "expr_2"),
ImmutableMap.of(Optional.of("partialAvg"), functionCall("avg", ImmutableList.of("finalSum")), Optional.of("partialFilterAvg"), functionCall("avg", ImmutableList.of("maskFinalSum"))),
ImmutableMap.of(),
Optional.empty(),
AggregationNode.Step.PARTIAL,
anyTree(
project(
ImmutableMap.of("expr_2", expression("suppkey > 0")),
aggregation(
singleGroupingSet("partkey", "suppkey"),
ImmutableMap.of(Optional.of("finalSum"), functionCall("sum", ImmutableList.of("partialSum")), Optional.of("maskFinalSum"), functionCall("sum", ImmutableList.of("maskPartialSum"))),
ImmutableMap.of(),
Optional.empty(),
AggregationNode.Step.FINAL,
project(
ImmutableMap.of("maskPartialSum", expression("IF(expr, partialSum, null)")),
anyTree(
aggregation(
singleGroupingSet("partkey", "suppkey", "expr"),
ImmutableMap.of(Optional.of("partialSum"), functionCall("sum", ImmutableList.of("quantity"))),
ImmutableMap.of(),
Optional.empty(),
AggregationNode.Step.PARTIAL,
anyTree(
project(
ImmutableMap.of("expr", expression("orderkey > 0")),
tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "partkey", "partkey", "quantity", "quantity", "suppkey", "suppkey"))))))))))))))),
false);
}
@Test
public void testAggregationsMultipleLevelAllAggWithMask()
{
assertPlan("select partkey, avg(sum) filter (where suppkey > 10), avg(sum) filter (where suppkey > 0), avg(filtersum) from (select partkey, suppkey, sum(quantity) filter (where orderkey > 10) sum, sum(quantity) filter (where orderkey > 0) filtersum from lineitem group by partkey, suppkey) t group by partkey",
enableOptimization(),
anyTree(
aggregation(
singleGroupingSet("partkey"),
ImmutableMap.of(Optional.of("finalAvg"), functionCall("avg", ImmutableList.of("maskPartialAvg_g10")), Optional.of("maskFinalAvg"), functionCall("avg", ImmutableList.of("maskPartialAvg")),
Optional.of("finalFilterAvg"), functionCall("avg", ImmutableList.of("partialFilterAvg"))),
ImmutableMap.of(),
Optional.empty(),
AggregationNode.Step.FINAL,
project(
ImmutableMap.of("maskPartialAvg", expression("IF(expr_2, partialAvg, null)"),
"maskPartialAvg_g10", expression("IF(expr_2_g10, partialAvg, null)")),
anyTree(
aggregation(
singleGroupingSet("partkey", "expr_2", "expr_2_g10"),
ImmutableMap.of(Optional.of("partialAvg"), functionCall("avg", ImmutableList.of("finalSum_g10")), Optional.of("partialFilterAvg"), functionCall("avg", ImmutableList.of("maskFinalSum"))),
ImmutableMap.of(new Symbol("partialAvg"), new Symbol("expr_2_or")),
Optional.empty(),
AggregationNode.Step.PARTIAL,
project(
ImmutableMap.of("expr_2_or", expression("expr_2 or expr_2_g10")),
project(
ImmutableMap.of("expr_2", expression("suppkey > 0"), "expr_2_g10", expression("suppkey > 10")),
aggregation(
singleGroupingSet("partkey", "suppkey"),
ImmutableMap.of(Optional.of("finalSum_g10"), functionCall("sum", ImmutableList.of("maskPartialSum_g10")), Optional.of("maskFinalSum"), functionCall("sum", ImmutableList.of("maskPartialSum"))),
ImmutableMap.of(),
Optional.empty(),
AggregationNode.Step.FINAL,
project(
ImmutableMap.of("maskPartialSum", expression("IF(expr, partialSum, null)"),
"maskPartialSum_g10", expression("IF(expr_g10, partialSum, null)")),
anyTree(
aggregation(
singleGroupingSet("partkey", "suppkey", "expr", "expr_g10"),
ImmutableMap.of(Optional.of("partialSum"), functionCall("sum", ImmutableList.of("quantity"))),
ImmutableMap.of(new Symbol("partialSum"), new Symbol("expr_or")),
Optional.empty(),
AggregationNode.Step.PARTIAL,
project(
ImmutableMap.of("expr_or", expression("expr or expr_g10")),
project(
ImmutableMap.of("expr", expression("orderkey > 0"), "expr_g10", expression("orderkey > 10")),
tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "partkey", "partkey", "quantity", "quantity", "suppkey", "suppkey"))))))))))))))),
false);
}
@Test
public void testGlobalOptimization()
{
assertPlan("SELECT sum(quantity), sum(quantity) filter (where orderkey > 0) from lineitem",
enableOptimization(),
anyTree(
aggregation(
globalAggregation(),
ImmutableMap.of(Optional.of("finalSum"), functionCall("sum", ImmutableList.of("partialSum")), Optional.of("maskFinalSum"), functionCall("sum", ImmutableList.of("maskPartialSum"))),
ImmutableMap.of(),
Optional.empty(),
AggregationNode.Step.FINAL,
anyTree(
aggregation(
globalAggregation(),
ImmutableMap.of(Optional.of("partialSum"), functionCall("sum", ImmutableList.of("quantity")), Optional.of("maskPartialSum"), functionCall("sum", ImmutableList.of("quantity"))),
ImmutableMap.of(new Symbol("maskPartialSum"), new Symbol("expr")),
Optional.empty(),
AggregationNode.Step.PARTIAL,
project(
ImmutableMap.of("expr", expression("orderkey > 0")),
tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "quantity", "quantity"))))))),
false);
}
@Test
public void testHasOrderBy()
{
assertPlan("select partkey, array_agg(suppkey order by suppkey), array_agg(suppkey order by suppkey) filter (where orderkey > 0) from lineitem group by partkey",
enableOptimization(),
anyTree(
aggregation(
singleGroupingSet("partkey"),
ImmutableMap.of(Optional.of("array_agg"), functionCall("array_agg", ImmutableList.of("suppkey"), ImmutableList.of(sort("suppkey", ASCENDING, LAST))),
Optional.of("array_agg_filter"), functionCall("array_agg", ImmutableList.of("suppkey"), ImmutableList.of(sort("suppkey", ASCENDING, LAST)))),
ImmutableMap.of(new Symbol("array_agg_filter"), new Symbol("expr")),
Optional.empty(),
AggregationNode.Step.SINGLE,
anyTree(
project(
ImmutableMap.of("expr", expression("orderkey > 0")),
tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "partkey", "partkey", "suppkey", "suppkey")))))),
false);
}
@Test
public void testGroupingSets()
{
assertPlan("SELECT partkey, sum(quantity), sum(quantity) filter (where orderkey > 0) from lineitem group by grouping sets((), (partkey))",
enableOptimization(),
anyTree(
aggregation(
new GroupingSetDescriptor(ImmutableList.of("partkey$gid", "groupid"), 2, ImmutableSet.of(0)),
ImmutableMap.of(Optional.of("finalSum"), functionCall("sum", ImmutableList.of("partialSum")), Optional.of("maskFinalSum"), functionCall("sum", ImmutableList.of("maskPartialSum"))),
ImmutableMap.of(),
Optional.of(new Symbol("groupid")),
AggregationNode.Step.FINAL,
project(
ImmutableMap.of("maskPartialSum", expression("IF(expr, partialSum, null)")),
anyTree(
aggregation(
new GroupingSetDescriptor(ImmutableList.of("partkey$gid", "groupid", "expr"), 2, ImmutableSet.of(0)),
ImmutableMap.of(Optional.of("partialSum"), functionCall("sum", ImmutableList.of("quantity"))),
ImmutableMap.of(),
Optional.of(new Symbol("groupid")),
AggregationNode.Step.PARTIAL,
anyTree(
groupingSet(
ImmutableList.of(ImmutableList.of(), ImmutableList.of("partkey")),
ImmutableMap.of("quantity", "quantity", "expr", "expr"),
"groupid",
ImmutableMap.of("partkey$gid", expression("partkey")),
project(
ImmutableMap.of("expr", expression("orderkey > 0")),
tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "partkey", "partkey", "quantity", "quantity")))))))))),
false);
}
@Test
public void testCalledOnNull()
{
assertPlan("SELECT partkey, count(*), count(*) filter (where orderkey > 0) from lineitem group by partkey",
enableOptimization(),
anyTree(
aggregation(
singleGroupingSet("partkey"),
ImmutableMap.of(Optional.of("finalCnt"), functionCall("count", ImmutableList.of("partialCnt")), Optional.of("maskFinalCnt"), functionCall("count", ImmutableList.of("maskPartialCnt"))),
ImmutableMap.of(),
Optional.empty(),
AggregationNode.Step.FINAL,
anyTree(
aggregation(
singleGroupingSet("partkey"),
ImmutableMap.of(Optional.of("partialCnt"), functionCall("count", ImmutableList.of()), Optional.of("maskPartialCnt"), functionCall("count", ImmutableList.of())),
ImmutableMap.of(new Symbol("maskPartialCnt"), new Symbol("expr")),
Optional.empty(),
AggregationNode.Step.PARTIAL,
project(
ImmutableMap.of("expr", expression("orderkey > 0")),
tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "partkey", "partkey"))))))),
false);
}
}