TestReplaceConstantVariableReferencesWithConstants.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.optimizations;
import com.facebook.presto.Session;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.assertions.BasePlanTest;
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.Test;
import static com.facebook.presto.SystemSessionProperties.REWRITE_EXPRESSION_WITH_CONSTANT_EXPRESSION;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.metadata.FunctionAndTypeManager.createTestFunctionAndTypeManager;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.limit;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.output;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.semiJoin;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.sort;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.topN;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.union;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.unnest;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values;
import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment;
import static com.facebook.presto.sql.tree.SortItem.NullOrdering.LAST;
import static com.facebook.presto.sql.tree.SortItem.Ordering.ASCENDING;
public class TestReplaceConstantVariableReferencesWithConstants
extends BasePlanTest
{
private Session enableOptimization()
{
return Session.builder(this.getQueryRunner().getDefaultSession())
.setSystemProperty(REWRITE_EXPRESSION_WITH_CONSTANT_EXPRESSION, "true")
.build();
}
@Test
public void testAggregation()
{
assertPlan("select orderkey, orderpriority, avg(totalprice) from orders where orderpriority='3-MEDIUM' group by 1, 2",
enableOptimization(),
output(
ImmutableList.of("orderkey", "expr_12", "avg"),
project(
ImmutableMap.of("expr_12", expression("'3-MEDIUM'")),
aggregation(
ImmutableMap.of("avg", functionCall("avg", ImmutableList.of("totalprice"))),
anyTree(
filter(
"orderpriority = '3-MEDIUM'",
tableScan("orders", ImmutableMap.of("orderkey", "orderkey", "orderpriority", "orderpriority", "totalprice", "totalprice"))))))));
}
@Test
public void testUnnest()
{
assertPlan("select orderkey, orderpriority, idx from orders cross join unnest(array[1, 2]) t(idx) where orderpriority='3-MEDIUM'",
enableOptimization(),
output(
ImmutableList.of("orderkey", "expr_9", "field"),
project(
ImmutableMap.of("expr_9", expression("'3-MEDIUM'")),
unnest(
ImmutableMap.of("expr", ImmutableList.of("field")),
project(
ImmutableMap.of("expr", expression("array[1, 2]")),
filter(
"orderpriority = '3-MEDIUM'",
tableScan("orders", ImmutableMap.of("orderkey", "orderkey", "orderpriority", "orderpriority"))))))));
}
@Test
public void testInnerJoin()
{
assertPlan("select o.orderkey, o.orderpriority, l.tax from lineitem l join orders o on o.orderkey = l.orderkey where o.orderpriority='3-MEDIUM'",
enableOptimization(),
output(
ImmutableList.of("orderkey_0", "expr_14", "tax"),
project(
ImmutableMap.of("expr_14", expression("'3-MEDIUM'")),
join(
JoinType.INNER,
ImmutableList.of(equiJoinClause("orderkey", "orderkey_0")),
anyTree(
tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "tax", "tax"))),
anyTree(
filter(
"orderpriority = '3-MEDIUM'",
tableScan("orders", ImmutableMap.of("orderkey_0", "orderkey", "orderpriority", "orderpriority"))))))));
}
@Test
public void testLeftJoinNotTrigger()
{
assertPlan("select o.orderkey, o.orderpriority, l.tax from lineitem l left join (select orderkey, orderpriority from orders where orderpriority='3-MEDIUM') o on o.orderkey = l.orderkey",
enableOptimization(),
output(
ImmutableList.of("orderkey_0", "expr_9", "tax"),
join(
JoinType.LEFT,
ImmutableList.of(equiJoinClause("orderkey", "orderkey_0")),
anyTree(
tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "tax", "tax"))),
anyTree(
project(
ImmutableMap.of("expr_9", expression("'3-MEDIUM'")),
filter(
"orderpriority = '3-MEDIUM'",
tableScan("orders", ImmutableMap.of("orderkey_0", "orderkey", "orderpriority", "orderpriority"))))))));
}
@Test
public void testLeftJoinTrigger()
{
assertPlan("select o.orderkey, l.linestatus, l.tax from (select tax, linestatus, orderkey from lineitem where linestatus ='O') l left join orders o on o.orderkey = l.orderkey",
enableOptimization(),
output(
ImmutableList.of("orderkey_0", "expr_26", "tax"),
project(
ImmutableMap.of("expr_26", expression("'O'")),
join(
JoinType.LEFT,
ImmutableList.of(equiJoinClause("orderkey", "orderkey_0")),
anyTree(
filter(
"linestatus = 'O'",
tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "tax", "tax", "linestatus", "linestatus")))),
anyTree(
tableScan("orders", ImmutableMap.of("orderkey_0", "orderkey")))))));
}
@Test
public void testSemiJoin()
{
assertPlan("select orderpriority, orderkey from orders where orderpriority='3-MEDIUM' and orderkey in (select orderkey from lineitem)",
enableOptimization(),
output(
ImmutableList.of("expr_15", "orderkey"),
project(
ImmutableMap.of("expr_15", expression("'3-MEDIUM'")),
filter(
"expr_8",
project(
semiJoin(
"orderkey",
"orderkey_1",
"expr_8",
project(
filter(
"orderpriority = '3-MEDIUM'",
tableScan("orders", ImmutableMap.of("orderkey", "orderkey", "orderpriority", "orderpriority")))),
anyTree(
project(
tableScan("lineitem", ImmutableMap.of("orderkey_1", "orderkey"))))))))));
}
@Test
public void testSimpleFilter()
{
assertPlan("select orderkey, orderpriority from orders where orderpriority='3-MEDIUM'",
enableOptimization(),
output(
ImmutableList.of("orderkey", "expr_6"),
project(
ImmutableMap.of("expr_6", expression("'3-MEDIUM'")),
filter(
"orderpriority = '3-MEDIUM'",
tableScan("orders", ImmutableMap.of("orderkey", "orderkey", "orderpriority", "orderpriority"))))));
}
@Test
public void testFilterOnSameVariable()
{
assertPlan("select orderkey, orderpriority from orders where orderpriority='3-MEDIUM' and orderpriority='5-HIGH'",
enableOptimization(),
output(
values("orderkey", "orderpriority")));
}
@Test
public void testJoinWithFilter()
{
assertPlan("with t1 as (select orderkey, orderstatus from orders where orderkey = 10) select l.orderkey, partkey, orderstatus from t1 join lineitem l on t1.orderkey = l.orderkey where partkey in (select suppkey from lineitem)",
enableOptimization(),
output(
ImmutableList.of("orderkey_9", "partkey", "orderstatus"),
project(
ImmutableMap.of("orderkey_9", expression("10")),
filter(
"expr_37",
project(
semiJoin("partkey", "suppkey_17", "expr_37",
project(
join(JoinType.INNER,
ImmutableList.of(),
anyTree(
filter("orderkey = 10",
tableScan("orders", ImmutableMap.of("orderstatus", "orderstatus", "orderkey", "orderkey")))),
anyTree(
filter(
"orderkey_9 = 10",
tableScan("lineitem", ImmutableMap.of("orderkey_9", "orderkey", "partkey", "partkey")))))),
anyTree(
tableScan("lineitem", ImmutableMap.of("suppkey_17", "suppkey")))))))));
}
@Test
public void testConstantRowExpression()
{
assertPlan("select orderkey+1 as nk from lineitem where orderkey=1",
enableOptimization(),
output(
project(
ImmutableMap.of("expr", expression("2")),
filter(
"orderkey=1",
tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey"))))));
}
@Test
public void testSort()
{
assertPlan("select orderkey, orderpriority from orders where orderpriority='3-MEDIUM' order by orderpriority",
enableOptimization(),
output(
project(
ImmutableMap.of("orderkey", expression("orderkey"), "orderpriority", expression("'3-MEDIUM'")),
filter(
"orderpriority='3-MEDIUM'",
tableScan("orders", ImmutableMap.of("orderpriority", "orderpriority", "orderkey", "orderkey"))))));
}
@Test
public void testSortOnMultipleKey()
{
ImmutableList<PlanMatchPattern.Ordering> orderBy = ImmutableList.of(sort("orderkey", ASCENDING, LAST));
assertPlan("select orderkey, orderpriority from orders where orderpriority='3-MEDIUM' order by orderpriority, orderkey",
enableOptimization(),
output(
project(
ImmutableMap.of("orderkey", expression("orderkey"), "orderpriority", expression("'3-MEDIUM'")),
anyTree(
sort(orderBy,
anyTree(
project(
ImmutableMap.of("orderkey", expression("orderkey")),
filter(
"orderpriority='3-MEDIUM'",
tableScan("orders", ImmutableMap.of("orderpriority", "orderpriority", "orderkey", "orderkey"))))))))));
}
@Test
public void testTopN()
{
assertPlan("select orderkey, orderpriority from orders where orderpriority='3-MEDIUM' order by orderpriority limit 10",
enableOptimization(),
output(
project(
ImmutableMap.of("orderkey", expression("orderkey"), "orderpriority", expression("'3-MEDIUM'")),
limit(
10,
anyTree(
filter(
"orderpriority='3-MEDIUM'",
tableScan("orders", ImmutableMap.of("orderpriority", "orderpriority", "orderkey", "orderkey"))))))));
}
@Test
public void testTopNSortOnMultipleKey()
{
ImmutableList<PlanMatchPattern.Ordering> orderBy = ImmutableList.of(sort("orderkey", ASCENDING, LAST));
assertPlan("select orderkey, orderpriority from orders where orderpriority='3-MEDIUM' order by orderpriority, orderkey limit 10",
enableOptimization(),
output(
project(
ImmutableMap.of("orderkey", expression("orderkey"), "orderpriority", expression("'3-MEDIUM'")),
topN(10, orderBy,
anyTree(
topN(
10,
orderBy,
project(
ImmutableMap.of("orderkey", expression("orderkey")),
filter(
"orderpriority='3-MEDIUM'",
tableScan("orders", ImmutableMap.of("orderpriority", "orderpriority", "orderkey", "orderkey"))))))))));
}
@Test
public void testUnionAllWithSameConstant()
{
assertPlan("select orderkey, price, count(*) from (select orderkey, extendedprice as price from lineitem where orderkey=5 union all select orderkey, totalprice as price from orders where orderkey=5) group by orderkey, price",
enableOptimization(),
output(
project(
ImmutableMap.of("orderkey_11", expression("5")),
aggregation(
ImmutableMap.of("count", functionCall("count", ImmutableList.of("count_29"))),
exchange(
anyTree(
aggregation(
ImmutableMap.of("count_29", functionCall("count", ImmutableList.of())),
project(
filter(
"orderkey = 5",
tableScan("lineitem", ImmutableMap.of("extendedprice", "extendedprice", "orderkey", "orderkey")))))),
anyTree(
aggregation(
ImmutableMap.of("count_29", functionCall("count", ImmutableList.of())),
project(
filter(
"orderkey_4 = 5",
tableScan("orders", ImmutableMap.of("orderkey_4", "orderkey", "totalprice", "totalprice")))))))))));
}
@Test
public void testUnionAllWithDifferentConstant()
{
assertPlan("select orderkey, price, count(*) from (select orderkey, extendedprice as price from lineitem where orderkey=5 union all select orderkey, totalprice as price from orders where orderkey=2) group by orderkey, price",
enableOptimization(),
output(
project(
ImmutableMap.of("orderkey_11", expression("orderkey_11")),
aggregation(
ImmutableMap.of("count", functionCall("count", ImmutableList.of("count_29"))),
exchange(
project(
project(
ImmutableMap.of("orderkey_11", expression("orderkey")),
aggregation(
ImmutableMap.of("count_29", functionCall("count", ImmutableList.of())),
project(
project(
ImmutableMap.of("orderkey", expression("5")),
filter(
"orderkey = 5",
tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "extendedprice", "extendedprice")))))))),
project(
project(
ImmutableMap.of("orderkey_11", expression("orderkey_4")),
aggregation(
ImmutableMap.of("count_29", functionCall("count", ImmutableList.of())),
project(
project(
ImmutableMap.of("orderkey_4", expression("2")),
filter(
"orderkey_4 = 2",
tableScan("orders", ImmutableMap.of("orderkey_4", "orderkey", "totalprice", "totalprice")))))))))))));
}
@Test
public void testUnionAllWithOneConstant()
{
assertPlan("select orderkey, price, count(*) from (select orderkey, extendedprice as price from lineitem where orderkey=5 union all select orderkey, totalprice as price from orders) group by orderkey, price",
enableOptimization(),
output(
project(
ImmutableMap.of("orderkey_11", expression("orderkey_11")),
aggregation(
ImmutableMap.of("count", functionCall("count", ImmutableList.of("count_29"))),
exchange(
project(
project(
ImmutableMap.of("orderkey_11", expression("orderkey")),
aggregation(
ImmutableMap.of("count_29", functionCall("count", ImmutableList.of())),
project(
project(
ImmutableMap.of("orderkey", expression("5")),
filter(
"orderkey = 5",
tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "extendedprice", "extendedprice")))))))),
project(
project(
ImmutableMap.of("orderkey_11", expression("orderkey_4")),
aggregation(
ImmutableMap.of("count_29", functionCall("count", ImmutableList.of())),
project(
tableScan("orders", ImmutableMap.of("orderkey_4", "orderkey", "totalprice", "totalprice")))))))))));
}
@Test
public void testExtractConstantFromFilter()
{
RuleTester tester = new RuleTester();
tester.assertThat(new ReplaceConstantVariableReferencesWithConstants(createTestFunctionAndTypeManager()))
.on(planBuilder ->
{
VariableReferenceExpression key1 = planBuilder.variable("key1", INTEGER);
VariableReferenceExpression key2 = planBuilder.variable("key2", INTEGER);
VariableReferenceExpression count = planBuilder.variable("cnt");
return planBuilder.aggregation(
aggregationBuilder -> aggregationBuilder
.source(
planBuilder.filter(
planBuilder.rowExpression("key1= 3"),
planBuilder.values(key1, key2)))
.singleGroupingSet(key1, key2)
.addAggregation(count, planBuilder.rowExpression("count()")));
})
.matches(
aggregation(
ImmutableMap.of("cnt", functionCall("count", ImmutableList.of())),
project(
ImmutableMap.of("key1", expression("3")),
filter(
"key1=3",
values("key1", "key2")))));
}
@Test
public void testJoinPlanChange()
{
RuleTester tester = new RuleTester();
tester.assertThat(new ReplaceConstantVariableReferencesWithConstants(createTestFunctionAndTypeManager()))
.on(planBuilder ->
{
VariableReferenceExpression key1 = planBuilder.variable("key1", INTEGER);
VariableReferenceExpression key2 = planBuilder.variable("key2", INTEGER);
return planBuilder.join(
JoinType.INNER,
planBuilder.filter(
planBuilder.rowExpression("key1= 3"),
planBuilder.values(key1)),
planBuilder.filter(
planBuilder.rowExpression("key2= 5"),
planBuilder.values(key2)));
})
.matches(
join(
project(
ImmutableMap.of("key1", expression("3")),
filter(
"key1=3",
values("key1"))),
project(
ImmutableMap.of("key2", expression("5")),
filter(
"key2=5",
values("key2")))));
}
@Test
public void testUnionPlanChange()
{
RuleTester tester = new RuleTester();
tester.assertThat(new ReplaceConstantVariableReferencesWithConstants(createTestFunctionAndTypeManager()))
.on(planBuilder ->
{
VariableReferenceExpression input1Source1 = planBuilder.variable("input1_source1", INTEGER);
VariableReferenceExpression input2Source1 = planBuilder.variable("input2_source1", INTEGER);
VariableReferenceExpression input1Source2 = planBuilder.variable("input1_source2", INTEGER);
VariableReferenceExpression input2Source2 = planBuilder.variable("input2_source2", INTEGER);
VariableReferenceExpression output1 = planBuilder.variable("output1", INTEGER);
VariableReferenceExpression output2 = planBuilder.variable("output2", INTEGER);
return
planBuilder.union(
ImmutableListMultimap.<VariableReferenceExpression, VariableReferenceExpression>builder().putAll(output1, input1Source1, input1Source2)
.putAll(output2, input2Source1, input2Source2).build(),
ImmutableList.of(
planBuilder.filter(
planBuilder.rowExpression("input1_source1 = 3"),
planBuilder.values(input1Source1, input2Source1)),
planBuilder.filter(
planBuilder.rowExpression("input1_source2 = 3"),
planBuilder.values(input1Source2, input2Source2))));
})
.matches(
union(
project(
ImmutableMap.of("input1_source1", expression("3"), "input2_source1", expression("input2_source1")),
filter(
"input1_source1=3",
values("input1_source1", "input2_source1"))),
project(
ImmutableMap.of("input1_source2", expression("3"), "input2_source2", expression("input2_source2")),
filter(
"input1_source2=3",
values("input1_source2", "input2_source2")))));
}
// Do not extract constant variable when having conflicting filters
@Test
public void testConflictFilter()
{
RuleTester tester = new RuleTester();
tester.assertThat(new ReplaceConstantVariableReferencesWithConstants(createTestFunctionAndTypeManager()))
.on(planBuilder ->
{
VariableReferenceExpression key1 = planBuilder.variable("key1", INTEGER);
VariableReferenceExpression key2 = planBuilder.variable("key2", INTEGER);
VariableReferenceExpression count = planBuilder.variable("cnt");
return planBuilder.aggregation(
aggregationBuilder -> aggregationBuilder
.source(
planBuilder.filter(
planBuilder.rowExpression("key1=2"),
planBuilder.filter(
planBuilder.rowExpression("key1= 3"),
planBuilder.values(key1, key2))))
.singleGroupingSet(key1, key2)
.addAggregation(count, planBuilder.rowExpression("count()")));
})
.matches(
aggregation(
ImmutableMap.of("cnt", functionCall("count", ImmutableList.of())),
filter(
"key1=2",
filter(
"key1=3",
values("key1", "key2")))));
}
@Test
public void testConflictFilterConjunct()
{
RuleTester tester = new RuleTester();
tester.assertThat(new ReplaceConstantVariableReferencesWithConstants(createTestFunctionAndTypeManager()))
.on(planBuilder ->
{
VariableReferenceExpression key1 = planBuilder.variable("key1", INTEGER);
VariableReferenceExpression key2 = planBuilder.variable("key2", INTEGER);
VariableReferenceExpression count = planBuilder.variable("cnt");
return planBuilder.aggregation(
aggregationBuilder -> aggregationBuilder
.source(
planBuilder.filter(
planBuilder.rowExpression("key1=2 and key2=5"),
planBuilder.filter(
planBuilder.rowExpression("key1= 3"),
planBuilder.values(key1, key2))))
.singleGroupingSet(key1, key2)
.addAggregation(count, planBuilder.rowExpression("count()")));
})
.matches(
aggregation(
ImmutableMap.of("cnt", functionCall("count", ImmutableList.of())),
filter(
"key1=2 and key2=5",
filter(
"key1=3",
values("key1", "key2")))));
}
@Test
public void testNonConflictFilter()
{
RuleTester tester = new RuleTester();
tester.assertThat(new ReplaceConstantVariableReferencesWithConstants(createTestFunctionAndTypeManager()))
.on(planBuilder ->
{
VariableReferenceExpression key1 = planBuilder.variable("key1", INTEGER);
VariableReferenceExpression key2 = planBuilder.variable("key2", INTEGER);
VariableReferenceExpression count = planBuilder.variable("cnt");
return planBuilder.aggregation(
aggregationBuilder -> aggregationBuilder
.source(
planBuilder.filter(
planBuilder.rowExpression("key1=3"),
planBuilder.filter(
planBuilder.rowExpression("key1= 3"),
planBuilder.values(key1, key2))))
.singleGroupingSet(key1, key2)
.addAggregation(count, planBuilder.rowExpression("count()")));
})
.matches(
aggregation(
ImmutableMap.of("cnt", functionCall("count", ImmutableList.of())),
project(
ImmutableMap.of("key1", expression("3")),
filter(
"3=3",
filter(
"key1=3",
values("key1", "key2"))))));
}
@Test
public void testNonConflictFilterConjunct()
{
RuleTester tester = new RuleTester();
tester.assertThat(new ReplaceConstantVariableReferencesWithConstants(createTestFunctionAndTypeManager()))
.on(planBuilder ->
{
VariableReferenceExpression key1 = planBuilder.variable("key1", INTEGER);
VariableReferenceExpression key2 = planBuilder.variable("key2", INTEGER);
VariableReferenceExpression count = planBuilder.variable("cnt");
return planBuilder.aggregation(
aggregationBuilder -> aggregationBuilder
.source(
planBuilder.filter(
planBuilder.rowExpression("key1=3 and key2=5"),
planBuilder.filter(
planBuilder.rowExpression("key1= 3"),
planBuilder.values(key1, key2))))
.singleGroupingSet(key1, key2)
.addAggregation(count, planBuilder.rowExpression("count()")));
})
.matches(
aggregation(
ImmutableMap.of("cnt", functionCall("count", ImmutableList.of())),
project(
ImmutableMap.of("key1", expression("3"), "key2", expression("5")),
filter(
"3=3 and key2=5",
filter(
"key1=3",
values("key1", "key2"))))));
}
@Test
public void testFilterOnExpression()
{
RuleTester tester = new RuleTester();
tester.assertThat(new ReplaceConstantVariableReferencesWithConstants(createTestFunctionAndTypeManager()))
.on(planBuilder ->
{
VariableReferenceExpression key1 = planBuilder.variable("key1", INTEGER);
VariableReferenceExpression key2 = planBuilder.variable("key2", INTEGER);
VariableReferenceExpression count = planBuilder.variable("cnt");
return planBuilder.aggregation(
aggregationBuilder -> aggregationBuilder
.source(
planBuilder.filter(
planBuilder.rowExpression("key2=key1+2"),
planBuilder.filter(
planBuilder.rowExpression("key1= 3"),
planBuilder.values(key1, key2))))
.singleGroupingSet(key1, key2)
.addAggregation(count, planBuilder.rowExpression("count()")));
})
.matches(
aggregation(
ImmutableMap.of("cnt", functionCall("count", ImmutableList.of())),
project(
ImmutableMap.of("key1", expression("3"), "key2", expression("key2")),
filter(
"key2=3+2",
filter(
"key1=3",
values("key1", "key2"))))));
}
@Test
public void testFilterAndProjectOnExpression()
{
RuleTester tester = new RuleTester();
tester.assertThat(new ReplaceConstantVariableReferencesWithConstants(createTestFunctionAndTypeManager()))
.on(planBuilder ->
{
VariableReferenceExpression key1 = planBuilder.variable("key1", INTEGER);
VariableReferenceExpression key2 = planBuilder.variable("key2", INTEGER);
VariableReferenceExpression count = planBuilder.variable("cnt");
return planBuilder.aggregation(
aggregationBuilder -> aggregationBuilder
.source(
planBuilder.project(
assignment(key2, planBuilder.rowExpression("key1+2"), key1, planBuilder.rowExpression("key1")),
planBuilder.filter(
planBuilder.rowExpression("key1= 3"),
planBuilder.values(key1, key2))))
.singleGroupingSet(key1, key2)
.addAggregation(count, planBuilder.rowExpression("count()")));
})
.matches(
aggregation(
ImmutableMap.of("cnt", functionCall("count", ImmutableList.of())),
project(
ImmutableMap.of("key1", expression("3"), "key2", expression("key2")),
project(
ImmutableMap.of("key2", expression("3+2"), "key1", expression("3")),
filter(
"key1=3",
values("key1", "key2"))))));
}
@Test
public void testConflictFilterAndProject()
{
RuleTester tester = new RuleTester();
tester.assertThat(new ReplaceConstantVariableReferencesWithConstants(createTestFunctionAndTypeManager()))
.on(planBuilder ->
{
VariableReferenceExpression key1 = planBuilder.variable("key1", INTEGER);
VariableReferenceExpression key2 = planBuilder.variable("key2", INTEGER);
VariableReferenceExpression count = planBuilder.variable("cnt");
return planBuilder.aggregation(
aggregationBuilder -> aggregationBuilder
.source(
planBuilder.filter(
planBuilder.rowExpression("key1= 3"),
planBuilder.project(
assignment(key1, planBuilder.rowExpression("2"), key2, planBuilder.rowExpression("key2")),
planBuilder.values(key1, key2))))
.singleGroupingSet(key1, key2)
.addAggregation(count, planBuilder.rowExpression("count()")));
})
.matches(
aggregation(
ImmutableMap.of("cnt", functionCall("count", ImmutableList.of())),
filter(
"key1=3",
project(
ImmutableMap.of("key1", expression("2")),
values("key1", "key2")))));
}
@Test
public void testExtractConstantFromProject()
{
RuleTester tester = new RuleTester();
tester.assertThat(new ReplaceConstantVariableReferencesWithConstants(createTestFunctionAndTypeManager()))
.on(planBuilder ->
{
VariableReferenceExpression key1 = planBuilder.variable("key1", INTEGER);
VariableReferenceExpression key2 = planBuilder.variable("key2", INTEGER);
VariableReferenceExpression count = planBuilder.variable("cnt");
return planBuilder.aggregation(
aggregationBuilder -> aggregationBuilder
.source(
planBuilder.filter(
planBuilder.rowExpression("key2 <> 3"),
planBuilder.project(
assignment(key1, planBuilder.rowExpression("3"), key2, planBuilder.rowExpression("key2")),
planBuilder.values(key1, key2))))
.singleGroupingSet(key1, key2)
.addAggregation(count, planBuilder.rowExpression("count()")));
})
.matches(
aggregation(
ImmutableMap.of("cnt", functionCall("count", ImmutableList.of())),
project(
ImmutableMap.of("key1", expression("3")),
filter(
"key2 <> 3",
project(
ImmutableMap.of("key1", expression("3"), "key2", expression("key2")),
values("key1", "key2"))))));
}
// If project get conflict constant, use the latest one
@Test
public void testConflictConstantFromProject()
{
RuleTester tester = new RuleTester();
tester.assertThat(new ReplaceConstantVariableReferencesWithConstants(createTestFunctionAndTypeManager()))
.on(planBuilder ->
{
VariableReferenceExpression key1 = planBuilder.variable("key1", INTEGER);
VariableReferenceExpression key2 = planBuilder.variable("key2", INTEGER);
VariableReferenceExpression count = planBuilder.variable("cnt");
return planBuilder.aggregation(
aggregationBuilder -> aggregationBuilder
.source(
planBuilder.filter(
planBuilder.rowExpression("key2 <> 3"),
planBuilder.project(
assignment(key1, planBuilder.rowExpression("5"), key2, planBuilder.rowExpression("key2")),
planBuilder.project(
assignment(key1, planBuilder.rowExpression("3"), key2, planBuilder.rowExpression("key2")),
planBuilder.values(key1, key2)))))
.singleGroupingSet(key1, key2)
.addAggregation(count, planBuilder.rowExpression("count()")));
})
.matches(
aggregation(
ImmutableMap.of("cnt", functionCall("count", ImmutableList.of())),
project(
ImmutableMap.of("key1", expression("5")),
filter(
"key2 <> 3",
project(
ImmutableMap.of("key1", expression("5"), "key2", expression("key2")),
project(
ImmutableMap.of("key1", expression("3"), "key2", expression("key2")),
values("key1", "key2")))))));
}
}