TestPruneAggregationSourceColumns.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.Test;
import java.util.List;
import java.util.Optional;
import java.util.function.Predicate;
import static com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.singleGroupingSet;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values;
import static com.google.common.base.Predicates.alwaysTrue;
import static com.google.common.collect.ImmutableList.toImmutableList;
public class TestPruneAggregationSourceColumns
        extends BaseRuleTest
{
    // Name of the source column that the aggregation below never reads.
    private static final String UNUSED_COLUMN = "unused";

    /**
     * When the values source exposes a column the aggregation does not consume, the rule
     * must fire and insert a strict projection under the aggregation that retains only the
     * aggregation argument, grouping key, hash and mask columns.
     */
    @Test
    public void testNotAllInputsReferenced()
    {
        tester().assertThat(new PruneAggregationSourceColumns())
                .on(planBuilder -> buildAggregation(planBuilder, alwaysTrue()))
                .matches(
                        aggregation(
                                singleGroupingSet("key"),
                                ImmutableMap.of(Optional.of("avg"), functionCall("avg", ImmutableList.of("input"))),
                                ImmutableMap.of(new Symbol("avg"), new Symbol("mask")),
                                Optional.empty(),
                                SINGLE,
                                strictProject(
                                        ImmutableMap.of(
                                                "input", expression("input"),
                                                "key", expression("key"),
                                                "keyHash", expression("keyHash"),
                                                "mask", expression("mask")),
                                        values("input", "key", "keyHash", "mask", UNUSED_COLUMN))));
    }

    /**
     * When the unused column is already absent from the source, every remaining source
     * column is referenced by the aggregation, so the rule must not fire.
     */
    @Test
    public void testAllInputsReferenced()
    {
        tester().assertThat(new PruneAggregationSourceColumns())
                .on(planBuilder -> buildAggregation(planBuilder, sourceVariable -> !sourceVariable.getName().equals(UNUSED_COLUMN)))
                .doesNotFire();
    }

    /**
     * Builds an aggregation computing {@code avg(input)} masked by {@code mask}, grouped on
     * {@code key} with {@code keyHash} as the hash variable. Its source is a values node with
     * no rows whose output columns are those of input, key, keyHash, mask and unused that are
     * accepted by {@code sourceVariableFilter}.
     *
     * @param planBuilder builder used to create variables and plan nodes
     * @param sourceVariableFilter decides which candidate columns the values source outputs
     * @return the constructed aggregation node
     */
    private AggregationNode buildAggregation(PlanBuilder planBuilder, Predicate<VariableReferenceExpression> sourceVariableFilter)
    {
        VariableReferenceExpression avgOutput = planBuilder.variable("avg");
        VariableReferenceExpression inputColumn = planBuilder.variable("input");
        VariableReferenceExpression groupingKey = planBuilder.variable("key");
        VariableReferenceExpression groupingKeyHash = planBuilder.variable("keyHash");
        VariableReferenceExpression aggregationMask = planBuilder.variable("mask");
        VariableReferenceExpression unusedColumn = planBuilder.variable(UNUSED_COLUMN);

        // Only the candidate columns accepted by the filter become outputs of the values source.
        List<VariableReferenceExpression> sourceOutputs = ImmutableList.of(inputColumn, groupingKey, groupingKeyHash, aggregationMask, unusedColumn)
                .stream()
                .filter(sourceVariableFilter)
                .collect(toImmutableList());

        return planBuilder.aggregation(builder -> builder
                .singleGroupingSet(groupingKey)
                .addAggregation(
                        avgOutput,
                        planBuilder.rowExpression("avg(input)"),
                        Optional.empty(),
                        Optional.empty(),
                        false,
                        Optional.of(aggregationMask))
                .hashVariable(groupingKeyHash)
                .source(planBuilder.values(sourceOutputs, ImmutableList.of())));
    }
}