TestPushPartialAggregationThroughExchange.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;
import com.facebook.presto.cost.PartialAggregationStatsEstimate;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.VariableStatsEstimate;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.Test;
import static com.facebook.presto.SystemSessionProperties.PARTIAL_AGGREGATION_STRATEGY;
import static com.facebook.presto.SystemSessionProperties.USE_PARTIAL_AGGREGATION_HISTORY;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL;
import static com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE;
import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.FACT;
import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.LOW;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values;
import static com.facebook.presto.sql.relational.Expressions.variable;
public class TestPushPartialAggregationThroughExchange
extends BaseRuleTest
{
@Test
public void testPartialAggregationAdded()
{
tester().assertThat(new PushPartialAggregationThroughExchange(getFunctionManager(), false))
.setSystemProperty(PARTIAL_AGGREGATION_STRATEGY, "AUTOMATIC")
.on(p -> {
VariableReferenceExpression a = p.variable("a");
return p.aggregation(ab -> ab
.source(
p.exchange(e -> e
.addSource(p.values(a))
.addInputsSet(a)
.singleDistributionPartitioningScheme(a)))
.addAggregation(p.variable("SUM", DOUBLE), p.rowExpression("SUM(a)"))
.globalGrouping()
.step(PARTIAL));
})
.matches(exchange(
project(
aggregation(
ImmutableMap.of("SUM", functionCall("sum", ImmutableList.of("a"))),
PARTIAL,
values("a")))));
}
@Test
public void testNoPartialAggregationWhenDisabled()
{
tester().assertThat(new PushPartialAggregationThroughExchange(getFunctionManager(), false))
.setSystemProperty(PARTIAL_AGGREGATION_STRATEGY, "NEVER")
.on(p -> {
VariableReferenceExpression a = p.variable("a");
return p.aggregation(ab -> ab
.source(
p.exchange(e -> e
.addSource(p.values(a))
.addInputsSet(a)
.singleDistributionPartitioningScheme(a)))
.addAggregation(p.variable("SUM", DOUBLE), p.rowExpression("SUM(a)"))
.globalGrouping()
.step(PARTIAL));
})
.doesNotFire();
}
@Test
public void testNoPartialAggregationWhenReductionBelowThreshold()
{
tester().assertThat(new PushPartialAggregationThroughExchange(getFunctionManager(), false))
.setSystemProperty(PARTIAL_AGGREGATION_STRATEGY, "AUTOMATIC")
.on(p -> {
VariableReferenceExpression a = p.variable("a", DOUBLE);
VariableReferenceExpression b = p.variable("b", DOUBLE);
return p.aggregation(ab -> ab
.source(
p.exchange(e -> e
.addSource(p.values(new PlanNodeId("values"), a, b))
.addInputsSet(a, b)
.singleDistributionPartitioningScheme(a, b)))
.addAggregation(p.variable("SUM", DOUBLE), p.rowExpression("SUM(a)"))
.singleGroupingSet(b)
.step(SINGLE));
})
.overrideStats("values", PlanNodeStatsEstimate.builder()
.setOutputRowCount(1000)
.addVariableStatistics(variable("b", DOUBLE), new VariableStatsEstimate(0, 100, 0, 8, 800))
.setConfidence(FACT)
.build())
.doesNotFire();
}
@Test
public void testNoPartialAggregationWhenReductionBelowThresholdUsingPartialAggregationStats()
{
tester().assertThat(new PushPartialAggregationThroughExchange(getFunctionManager(), false))
.setSystemProperty(PARTIAL_AGGREGATION_STRATEGY, "AUTOMATIC")
.setSystemProperty(USE_PARTIAL_AGGREGATION_HISTORY, "true")
.on(p -> constructAggregation(p))
.overrideStats("aggregation", PlanNodeStatsEstimate.builder()
.addVariableStatistics(variable("b", DOUBLE), new VariableStatsEstimate(0, 100, 0, 8, 800))
.setConfidence(FACT)
.setPartialAggregationStatsEstimate(new PartialAggregationStatsEstimate(1000, 800, 10, 10))
.build())
.doesNotFire();
}
@Test
public void testNoPartialAggregationWhenReductionAboveThresholdUsingPartialAggregationStats()
{
// when use_partial_aggregation_history=true, we use row count reduction (instead of bytes) to decide if partial aggregation is useful
tester().assertThat(new PushPartialAggregationThroughExchange(getFunctionManager(), false))
.setSystemProperty(PARTIAL_AGGREGATION_STRATEGY, "AUTOMATIC")
.setSystemProperty(USE_PARTIAL_AGGREGATION_HISTORY, "true")
.on(p -> constructAggregation(p))
.overrideStats("aggregation", PlanNodeStatsEstimate.builder()
.addVariableStatistics(variable("b", DOUBLE), new VariableStatsEstimate(0, 100, 0, 8, 800))
.setConfidence(FACT)
.setPartialAggregationStatsEstimate(new PartialAggregationStatsEstimate(1000, 300, 10, 10))
.build())
.doesNotFire();
}
@Test
public void testNoPartialAggregationWhenRowReductionBelowThreshold()
{
tester().assertThat(new PushPartialAggregationThroughExchange(getFunctionManager(), false))
.setSystemProperty(PARTIAL_AGGREGATION_STRATEGY, "AUTOMATIC")
.setSystemProperty(USE_PARTIAL_AGGREGATION_HISTORY, "true")
.on(p -> constructAggregation(p))
.overrideStats("aggregation", PlanNodeStatsEstimate.builder()
.addVariableStatistics(variable("b", DOUBLE), new VariableStatsEstimate(0, 100, 0, 8, 800))
.setConfidence(FACT)
.setPartialAggregationStatsEstimate(new PartialAggregationStatsEstimate(0, 300, 10, 8))
.build())
.doesNotFire();
}
@Test
public void testPartialAggregationWhenRowReductionAboveThreshold()
{
tester().assertThat(new PushPartialAggregationThroughExchange(getFunctionManager(), false))
.setSystemProperty(PARTIAL_AGGREGATION_STRATEGY, "AUTOMATIC")
.setSystemProperty(USE_PARTIAL_AGGREGATION_HISTORY, "true")
.on(p -> constructAggregation(p))
.overrideStats("aggregation", PlanNodeStatsEstimate.builder()
.addVariableStatistics(variable("b", DOUBLE), new VariableStatsEstimate(0, 100, 0, 8, 800))
.setConfidence(FACT)
.setPartialAggregationStatsEstimate(new PartialAggregationStatsEstimate(0, 300, 10, 1))
.build())
.matches(aggregation(ImmutableMap.of("sum", functionCall("sum", ImmutableList.of("sum0"))),
aggregation(
ImmutableMap.of("sum0", functionCall("sum", ImmutableList.of("a"))),
exchange(
values("a", "b")))));
}
@Test
public void testPartialAggregationEnabledWhenNotConfident()
{
tester().assertThat(new PushPartialAggregationThroughExchange(getFunctionManager(), false))
.setSystemProperty(PARTIAL_AGGREGATION_STRATEGY, "AUTOMATIC")
.on(p -> {
VariableReferenceExpression a = p.variable("a", DOUBLE);
VariableReferenceExpression b = p.variable("b", DOUBLE);
return p.aggregation(ab -> ab
.source(
p.exchange(e -> e
.addSource(p.values(new PlanNodeId("values"), a, b))
.addInputsSet(a, b)
.singleDistributionPartitioningScheme(a, b)))
.addAggregation(p.variable("SUM", DOUBLE), p.rowExpression("SUM(a)"))
.singleGroupingSet(b)
.step(PARTIAL));
})
.overrideStats("values", PlanNodeStatsEstimate.builder()
.setOutputRowCount(1000)
.addVariableStatistics(variable("b", DOUBLE), new VariableStatsEstimate(0, 100, 0, 8, 800))
.setConfidence(LOW)
.build())
.matches(exchange(
project(
aggregation(
ImmutableMap.of("SUM", functionCall("sum", ImmutableList.of("a"))),
PARTIAL,
values("a", "b")))));
}
private static AggregationNode constructAggregation(PlanBuilder p)
{
VariableReferenceExpression a = p.variable("a", DOUBLE);
VariableReferenceExpression b = p.variable("b", DOUBLE);
return p.aggregation(ab -> ab
.source(
p.exchange(e -> e
.addSource(p.values(new PlanNodeId("values"), a, b))
.addInputsSet(a, b)
.singleDistributionPartitioningScheme(
ImmutableList.of(a, b))))
.addAggregation(p.variable("sum", DOUBLE), p.rowExpression("sum(a)"))
.singleGroupingSet(b)
.setPlanNodeId(new PlanNodeId("aggregation")));
}
}