TestLogicalAddExchangesBelowPartialAggregationOverGroupIdRuleSet.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner;
import com.facebook.presto.Session;
import com.facebook.presto.execution.TaskManagerConfig;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.sql.planner.assertions.BasePlanTest;
import com.facebook.presto.sql.planner.assertions.PlanAssert;
import com.facebook.presto.sql.planner.plan.GroupIdNode;
import com.facebook.presto.testing.LocalQueryRunner;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.airlift.units.DataSize;
import org.testng.annotations.Test;
import static com.facebook.presto.SystemSessionProperties.ADD_EXCHANGE_BELOW_PARTIAL_AGGREGATION_OVER_GROUP_ID;
import static com.facebook.presto.SystemSessionProperties.MERGE_AGGREGATIONS_WITH_AND_WITHOUT_FILTER;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_STREAMING;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION;
import static io.airlift.units.DataSize.Unit.KILOBYTE;
import static io.airlift.units.DataSize.Unit.MEGABYTE;
public class TestLogicalAddExchangesBelowPartialAggregationOverGroupIdRuleSet
extends BasePlanTest
{
public TestLogicalAddExchangesBelowPartialAggregationOverGroupIdRuleSet()
{
super(TestLogicalAddExchangesBelowPartialAggregationOverGroupIdRuleSet::setup);
}
private static LocalQueryRunner setup()
{
// We set available max-partial-aggregation-memory to a low value to allow the rule to trigger for the TPCH tiny scale factor
TaskManagerConfig taskManagerConfig = new TaskManagerConfig().setMaxPartialAggregationMemoryUsage(DataSize.succinctDataSize(1, KILOBYTE));
return createQueryRunner(ImmutableMap.of(ADD_EXCHANGE_BELOW_PARTIAL_AGGREGATION_OVER_GROUP_ID, "true"), taskManagerConfig);
}
@Test
public void testRollup()
{
assertDistributedPlan("SELECT orderkey, suppkey, partkey, sum(quantity) from lineitem GROUP BY ROLLUP(orderkey, suppkey, partkey)",
anyTree(node(GroupIdNode.class,
// Since 'orderkey' will be the variable with the highest frequency, we repartition on it
anyTree(exchange(LOCAL, REPARTITION, ImmutableList.of(), ImmutableSet.of("orderkey"),
exchange(REMOTE_STREAMING, REPARTITION, ImmutableList.of(), ImmutableSet.of("orderkey"),
anyTree(tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey")))))))));
}
@Test
public void testNegativeCases()
{
// MERGE_AGGREGATIONS_WITH_AND_WITHOUT_FILTER adds a Project for an 'expr' that is pass-through through the GroupIdNode node
// The Rule does not apply when such a variable is used in an Aggregation but not in the GroupId grouping set
Session enableMergeAggregationWithAndWithoutFilter = Session.builder(getQueryRunner().getDefaultSession())
.setSystemProperty(MERGE_AGGREGATIONS_WITH_AND_WITHOUT_FILTER, "true")
.build();
String sql = "select partkey, sum(quantity), sum(quantity) filter (where discount > 0.1) from lineitem group by grouping sets((), (partkey))";
assertDistributedPlan(sql, enableMergeAggregationWithAndWithoutFilter,
anyTree(node(GroupIdNode.class,
project(ImmutableMap.of("partkey", expression("partkey"), "quantity", expression("quantity"), "expr", expression("discount > DOUBLE'0.1'")),
tableScan("lineitem",
ImmutableMap.of("partkey", "partkey", "quantity", "quantity", "discount", "discount"))))));
// Rule does not apply when aggregation will be effective due to a sufficiently high max-partial-aggregation-memory
TaskManagerConfig taskManagerConfig = new TaskManagerConfig().setMaxPartialAggregationMemoryUsage(DataSize.succinctDataSize(1, MEGABYTE));
try (LocalQueryRunner queryRunner = createQueryRunner(ImmutableMap.of(ADD_EXCHANGE_BELOW_PARTIAL_AGGREGATION_OVER_GROUP_ID, "true"), taskManagerConfig)) {
queryRunner.inTransaction(queryRunner.getDefaultSession(), transactionSession -> {
Plan plan = queryRunner.createPlan(transactionSession,
"SELECT orderkey, suppkey, partkey, sum(quantity) from lineitem GROUP BY ROLLUP(orderkey, suppkey, partkey)",
WarningCollector.NOOP);
PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getStatsCalculator(), plan,
anyTree(node(GroupIdNode.class,
tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey")))));
return null;
});
}
}
}