TestPruneWindowColumns.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;
import com.facebook.presto.common.block.SortOrder;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.plan.DataOrganizationSpecification;
import com.facebook.presto.spi.plan.Ordering;
import com.facebook.presto.spi.plan.OrderingScheme;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.WindowNode;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.assertions.ExpectedValueProvider;
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.testng.annotations.Test;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager;
import static com.facebook.presto.spi.plan.WindowNode.Frame.BoundType.CURRENT_ROW;
import static com.facebook.presto.spi.plan.WindowNode.Frame.BoundType.FOLLOWING;
import static com.facebook.presto.spi.plan.WindowNode.Frame.BoundType.PRECEDING;
import static com.facebook.presto.spi.plan.WindowNode.Frame.BoundType.UNBOUNDED_PRECEDING;
import static com.facebook.presto.spi.plan.WindowNode.Frame.WindowType.RANGE;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.window;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.windowFrame;
import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignments;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.google.common.base.Predicates.alwaysTrue;
import static com.google.common.collect.ImmutableList.toImmutableList;
public class TestPruneWindowColumns
extends BaseRuleTest
{
private static final String FUNCTION_NAME = "min";
private static final FunctionHandle FUNCTION_HANDLE = createTestMetadataManager().getFunctionAndTypeManager().lookupFunction(FUNCTION_NAME, fromTypes(BIGINT));
private static final List<String> inputSymbolNameList =
ImmutableList.of("orderKey", "partitionKey", "hash", "startValue1", "startValue2", "startValue3", "endValue1", "endValue2", "endValue3", "sortKeyForStartComparison3",
"sortKeyForEndComparison3", "input1", "input2", "input3", "unused");
private static final Set<String> inputSymbolNameSet = ImmutableSet.copyOf(inputSymbolNameList);
private static final ExpectedValueProvider<WindowNode.Frame> frameProvider1 = windowFrame(
RANGE,
UNBOUNDED_PRECEDING,
Optional.of("startValue1"),
CURRENT_ROW,
Optional.of("endValue1"),
Optional.of("orderKey"));
private static final ExpectedValueProvider<WindowNode.Frame> frameProvider2 = windowFrame(
RANGE,
UNBOUNDED_PRECEDING,
Optional.of("startValue2"),
CURRENT_ROW,
Optional.of("endValue2"),
Optional.of("orderKey"));
private static final ExpectedValueProvider<WindowNode.Frame> frameProvider3 = windowFrame(
RANGE,
PRECEDING,
Optional.of("startValue3"),
Optional.of(BIGINT),
Optional.of("sortKeyForStartComparison3"),
Optional.of(BIGINT),
FOLLOWING,
Optional.of("endValue3"),
Optional.of(BIGINT),
Optional.of("sortKeyForEndComparison3"),
Optional.of(BIGINT));
@Test
public void testWindowNotNeeded()
{
tester().assertThat(new PruneWindowColumns())
.on(p -> buildProjectedWindow(p, symbol -> inputSymbolNameSet.contains(symbol.getName()), alwaysTrue()))
.matches(
strictProject(
Maps.asMap(inputSymbolNameSet, PlanMatchPattern::expression),
values(inputSymbolNameList)));
}
@Test
public void testOneFunctionNotNeeded()
{
tester().assertThat(new PruneWindowColumns())
.on(p -> buildProjectedWindow(p,
symbol -> symbol.getName().equals("output2") || symbol.getName().equals("output3") || symbol.getName().equals("unused"),
alwaysTrue()))
.matches(
strictProject(
ImmutableMap.of(
"output2", expression("output2"),
"output3", expression("output3"),
"unused", expression("unused")),
window(windowBuilder -> windowBuilder
.prePartitionedInputs(ImmutableSet.of())
.specification(
ImmutableList.of("partitionKey"),
ImmutableList.of("orderKey"),
ImmutableMap.of("orderKey", SortOrder.ASC_NULLS_FIRST))
.preSortedOrderPrefix(0)
.addFunction(
"output2",
functionCall("min", ImmutableList.of("input2")),
FUNCTION_HANDLE,
frameProvider2)
.addFunction(
"output3",
functionCall("min", ImmutableList.of("input3")),
FUNCTION_HANDLE,
frameProvider3)
.hashSymbol("hash"),
strictProject(
Maps.asMap(
Sets.difference(inputSymbolNameSet, ImmutableSet.of("input1", "startValue1", "endValue1")),
PlanMatchPattern::expression),
values(inputSymbolNameList)))));
}
@Test
public void testTwoFunctionsNotNeeded()
{
tester().assertThat(new PruneWindowColumns())
.on(p -> buildProjectedWindow(p,
symbol -> symbol.getName().equals("output3") || symbol.getName().equals("unused"),
alwaysTrue()))
.matches(
strictProject(
ImmutableMap.of(
"output3", expression("output3"),
"unused", expression("unused")),
window(windowBuilder -> windowBuilder
.prePartitionedInputs(ImmutableSet.of())
.specification(
ImmutableList.of("partitionKey"),
ImmutableList.of("orderKey"),
ImmutableMap.of("orderKey", SortOrder.ASC_NULLS_FIRST))
.preSortedOrderPrefix(0)
.addFunction(
"output3",
functionCall("min", ImmutableList.of("input3")),
FUNCTION_HANDLE,
frameProvider3)
.hashSymbol("hash"),
strictProject(
Maps.asMap(
Sets.difference(inputSymbolNameSet, ImmutableSet.of("input1", "startValue1", "endValue1", "input2", "startValue2", "endValue2")),
PlanMatchPattern::expression),
values(inputSymbolNameList)))));
}
@Test
public void testAllColumnsNeeded()
{
tester().assertThat(new PruneWindowColumns())
.on(p -> buildProjectedWindow(p, alwaysTrue(), alwaysTrue()))
.doesNotFire();
}
@Test
public void testUsedInputsNotNeeded()
{
// If the WindowNode needs all its inputs, we can't discard them from its child.
tester().assertThat(new PruneWindowColumns())
.on(p -> buildProjectedWindow(
p,
// only the window function outputs
symbol -> !inputSymbolNameSet.contains(symbol.getName()),
// only the used input symbols
symbol -> !symbol.getName().equals("unused")))
.doesNotFire();
}
@Test
public void testUnusedInputNotNeeded()
{
tester().assertThat(new PruneWindowColumns())
.on(p -> buildProjectedWindow(
p,
// only the window function outputs
symbol -> !inputSymbolNameSet.contains(symbol.getName()),
alwaysTrue()))
.matches(
strictProject(
ImmutableMap.of(
"output1", expression("output1"),
"output2", expression("output2"),
"output3", expression("output3")),
window(windowBuilder -> windowBuilder
.prePartitionedInputs(ImmutableSet.of())
.specification(
ImmutableList.of("partitionKey"),
ImmutableList.of("orderKey"),
ImmutableMap.of("orderKey", SortOrder.ASC_NULLS_FIRST))
.preSortedOrderPrefix(0)
.addFunction(
"output1",
functionCall("min", ImmutableList.of("input1")),
FUNCTION_HANDLE,
frameProvider1)
.addFunction(
"output2",
functionCall("min", ImmutableList.of("input2")),
FUNCTION_HANDLE,
frameProvider2)
.addFunction(
"output3",
functionCall("min", ImmutableList.of("input3")),
FUNCTION_HANDLE,
frameProvider3)
.hashSymbol("hash"),
strictProject(
Maps.asMap(
Sets.filter(inputSymbolNameSet, symbolName -> !symbolName.equals("unused")),
PlanMatchPattern::expression),
values(inputSymbolNameList)))));
}
private static PlanNode buildProjectedWindow(
PlanBuilder p,
Predicate<VariableReferenceExpression> projectionFilter,
Predicate<VariableReferenceExpression> sourceFilter)
{
VariableReferenceExpression orderKey = p.variable("orderKey");
VariableReferenceExpression partitionKey = p.variable("partitionKey");
VariableReferenceExpression hash = p.variable("hash");
VariableReferenceExpression startValue1 = p.variable("startValue1");
VariableReferenceExpression startValue2 = p.variable("startValue2");
VariableReferenceExpression startValue3 = p.variable("startValue3");
VariableReferenceExpression sortKeyForStartComparison3 = p.variable("sortKeyForStartComparison3");
VariableReferenceExpression endValue1 = p.variable("endValue1");
VariableReferenceExpression endValue2 = p.variable("endValue2");
VariableReferenceExpression endValue3 = p.variable("endValue3");
VariableReferenceExpression sortKeyForEndComparison3 = p.variable("sortKeyForEndComparison3");
VariableReferenceExpression input1 = p.variable("input1");
VariableReferenceExpression input2 = p.variable("input2");
VariableReferenceExpression input3 = p.variable("input3");
VariableReferenceExpression unused = p.variable("unused");
VariableReferenceExpression output1 = p.variable("output1");
VariableReferenceExpression output2 = p.variable("output2");
VariableReferenceExpression output3 = p.variable("output3");
List<VariableReferenceExpression> inputs = ImmutableList.of(orderKey, partitionKey, hash, startValue1, startValue2, startValue3, endValue1, endValue2, endValue3,
sortKeyForStartComparison3, sortKeyForEndComparison3, input1, input2, input3, unused);
List<VariableReferenceExpression> outputs = ImmutableList.<VariableReferenceExpression>builder().addAll(inputs).add(output1, output2, output3).build();
List<VariableReferenceExpression> filteredInputs = inputs.stream().filter(sourceFilter).collect(toImmutableList());
return p.project(
identityAssignments(
outputs.stream()
.filter(projectionFilter)
.collect(toImmutableList())),
p.window(
new DataOrganizationSpecification(
ImmutableList.of(partitionKey),
Optional.of(new OrderingScheme(
ImmutableList.of(new Ordering(orderKey, SortOrder.ASC_NULLS_FIRST))))),
ImmutableMap.of(
output1,
new WindowNode.Function(
call(FUNCTION_NAME, FUNCTION_HANDLE, BIGINT, input1),
new WindowNode.Frame(
RANGE,
UNBOUNDED_PRECEDING,
Optional.of(startValue1),
Optional.of(orderKey),
CURRENT_ROW,
Optional.of(endValue1),
Optional.of(orderKey),
Optional.of(new SymbolReference(startValue1.getName())).map(Expression::toString),
Optional.of(new SymbolReference(endValue2.getName())).map(Expression::toString)),
false),
output2,
new WindowNode.Function(
call(FUNCTION_NAME, FUNCTION_HANDLE, BIGINT, input2),
new WindowNode.Frame(
RANGE,
UNBOUNDED_PRECEDING,
Optional.of(startValue2),
Optional.of(orderKey),
CURRENT_ROW,
Optional.of(endValue2),
Optional.of(orderKey),
Optional.of(new SymbolReference(startValue2.getName())).map(Expression::toString),
Optional.of(new SymbolReference(endValue2.getName())).map(Expression::toString)),
false),
output3,
new WindowNode.Function(
call(FUNCTION_NAME, FUNCTION_HANDLE, BIGINT, input3),
new WindowNode.Frame(
RANGE,
PRECEDING,
Optional.of(startValue3),
Optional.of(sortKeyForStartComparison3),
FOLLOWING,
Optional.of(endValue3),
Optional.of(sortKeyForEndComparison3),
Optional.of(new SymbolReference(startValue3.getName())).map(Expression::toString),
Optional.of(new SymbolReference(endValue3.getName())).map(Expression::toString)),
false)),
hash,
p.values(
filteredInputs,
ImmutableList.of())));
}
}