TestSimplifyCardinalityMap.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.sql.TestingRowExpressionTranslator;
import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.Test;
import java.util.Map;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.metadata.FunctionAndTypeManager.createTestFunctionAndTypeManager;
import static com.facebook.presto.sql.planner.PlannerUtils.createMapType;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values;
import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment;
public class TestSimplifyCardinalityMap
extends BaseRuleTest
{
private final TestingRowExpressionTranslator testSqlToRowExpressionTranslator = new TestingRowExpressionTranslator();
@Test
public void testRewriteMapValuesCardinality()
{
assertRewritten("cardinality(map_values(m))", "cardinality(m)");
}
@Test
public void testRewriteMapValuesMixedCasesCardinality()
{
assertRewritten("CaRDinality(map_keys(m))", "cardinaLITY(m)");
}
@Test
public void testNoRewriteMapValuesCardinality()
{
assertRewriteDoesNotFire("cardinality(map(ARRAY[1,3], ARRAY[2,4]))");
}
@Test
public void testNestedRewriteMapValuesCardinality()
{
assertRewritten(
"cardinality(map(ARRAY[cardinality(map_values(m_1))], ARRAY[cardinality(map_values(m_2))]))",
"cardinality(map(ARRAY[cardinality(m_1)], ARRAY[cardinality(m_2)]))");
}
@Test
public void testNestedRewriteMapKeysCardinality()
{
assertRewritten(
"cardinality(map(ARRAY[cardinality(map_keys(m_1)),3], ARRAY[2,cardinality(map_keys(m_2))]))",
"cardinality(map(ARRAY[cardinality(m_1),3], ARRAY[2,cardinality(m_2)]))");
assertRewritten(
"cast(cardinality(map_keys(m)) as varchar)",
"cast(cardinality(m) as varchar)");
}
@Test
public void testAnotherNestedRewriteMapValuesCardinality()
{
assertRewritten(
"cardinality(map(ARRAY[cardinality(map_values(map(ARRAY[1,3], ARRAY[2,4]))),3], ARRAY[2,cardinality(map_values(m_2))]))",
"cardinality(map(ARRAY[cardinality(map(ARRAY[1,3], ARRAY[2,4])),3], ARRAY[2,cardinality(m_2)]))");
}
private void assertRewriteDoesNotFire(String expression)
{
RowExpression inputExpression = testSqlToRowExpressionTranslator.translate(expression, ImmutableMap.of());
tester().assertThat(new SimplifyCardinalityMap(createTestFunctionAndTypeManager()).projectRowExpressionRewriteRule())
.on(p -> p.project(assignment(p.variable("x"), inputExpression), p.values()))
.doesNotFire();
}
private void assertRewritten(String inputExpressionStr,
String expectedExpressionStr)
{
Type mapType = createMapType(getFunctionManager(), BIGINT, BIGINT);
Map<String, Type> types = ImmutableMap.of("m", mapType, "m_1", mapType, "m_2", mapType);
RowExpression inputExpression = testSqlToRowExpressionTranslator.translate(inputExpressionStr, types);
tester().assertThat(new SimplifyCardinalityMap(createTestFunctionAndTypeManager()).projectRowExpressionRewriteRule())
.on(p -> p.project(assignment(p.variable("x"), inputExpression), p.values(p.variable("m", mapType), p.variable("m_1", mapType), p.variable("m_2", mapType))))
.matches(project(ImmutableMap.of("x", expression(expectedExpressionStr)), values("m", "m_1", "m_2")));
}
}