TestAddNotNullFiltersToJoinNode.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.TestingRowExpressionTranslator;
import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinNotNullInferenceStrategy;
import com.facebook.presto.sql.planner.assertions.BasePlanTest;
import com.facebook.presto.sql.planner.iterative.rule.AddNotNullFiltersToJoinNode.ExtractInferredNotNullVariablesVisitor;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import static com.facebook.presto.SystemSessionProperties.JOINS_NOT_NULL_INFERENCE_STRATEGY;
import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_NULLS_IN_JOINS;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.metadata.FunctionAndTypeManager.createTestFunctionAndTypeManager;
import static com.facebook.presto.spi.plan.JoinType.INNER;
import static com.facebook.presto.spi.plan.JoinType.LEFT;
import static com.facebook.presto.spi.plan.JoinType.RIGHT;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinNotNullInferenceStrategy.INFER_FROM_STANDARD_OPERATORS;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinNotNullInferenceStrategy.USE_FUNCTION_METADATA;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan;
import static com.facebook.presto.sql.relational.Expressions.variable;
import static com.facebook.presto.testing.assertions.Assert.assertEquals;
public class TestAddNotNullFiltersToJoinNode
extends BasePlanTest
{
private static final Map<String, Type> testVariableTypeMap;
private final TestingRowExpressionTranslator rowExpressionTranslator;
private final FunctionAndTypeManager functionAndTypeManager;
public TestAddNotNullFiltersToJoinNode()
{
super(ImmutableMap.of(OPTIMIZE_NULLS_IN_JOINS, Boolean.toString(false),
JOINS_NOT_NULL_INFERENCE_STRATEGY, USE_FUNCTION_METADATA.toString()));
functionAndTypeManager = createTestFunctionAndTypeManager();
Metadata metadata = MetadataManager.createTestMetadataManager();
rowExpressionTranslator = new TestingRowExpressionTranslator(metadata);
}
@DataProvider
public static Object[][] getExistingNotNullVarsTestCases()
{
return new Object[][] {
{"a IS NOT NULL AND b IS NOT NULL", new String[] {"a", "b"}},
{"a > 10 AND b IS NOT NULL AND c is NOT NULL", new String[] {"b", "c"}},
{"a is NULL AND b IS NOT NULL", new String[] {"b"}},
{"NOT(a is NULL)", new String[] {"a"}},
{"NOT(a is NULL OR b is NULL)", new String[] {}},
{"a is NOT NULL OR b is NOT NULL", new String[] {}}
};
}
@DataProvider
public static Object[][] standardOperatorTestCases()
{
return new Object[][] {
{"a + b > 10", new String[] {"a", "b"}},
{"a != 10 - b", new String[] {"a", "b"}},
{"a > b", new String[] {"a", "b"}},
{"a + NULL > b", new String[] {"a", "b"}},
// We can infer NOT NULL predicates on arguments of an AND expression
{"a > b and c = d", new String[] {"a", "b", "c", "d"}},
{"a IS NULL and b > c", new String[] {"b", "c"}},
// We cannot infer NOT NULL predicates on arguments of an OR expression
{"a > b OR c = d", new String[] {}},
// COALESCE can operate on NULL arguments, so cant infer predicates on its arguments
{"COALESCE(a,b)", new String[] {}},
// IN can operate on NULL arguments, so cant infer predicates on its arguments
{"a IN (b,10,NULL)", new String[] {}},
// arr[3] = 10 translates to EQUAL(SUBSCRIPT(arr, 3), 10). SUBSCRIPT is a standard Operator, so we can infer that 'arr' is NOT NULL
{"arr[3] = 10", new String[] {"arr"}},
// c_struct.a = 10 translates to EQUAL(DEREFERENCE(c_struct, 0), 10).
// We chose to not make any inferences for DEREFERENCE clauses, hence we don't add any NOT NULL clauses for 'c_struct' or 'c_struct.a'
{"c_struct.a = 10", new String[] {}},
// NULLs are only inferred from standard operators
{"NOT (b + 10 > c)", new String[] {}},
{"abs(b + c) > 10", new String[] {}},
{"random(b) = ceil(c)", new String[] {}},
{"d > e and abs(b + c) > 10", new String[] {"d", "e"}},
};
}
@DataProvider
public static Object[][] nonStandardOperatorTestCases()
{
return new Object[][] {
{"NOT (b + 10 > c)", new String[] {"b", "c"}},
{"abs(b + c) > 10", new String[] {"b", "c"}},
{"random(b) = ceil(c)", new String[] {"b", "c"}},
};
}
@Test
public void testNotNullPredicatesAddedForSingleEquiJoinClause()
{
String query = "select 1 from lineitem l join orders o on l.orderkey = o.orderkey";
assertPlan(query,
anyTree(
join(INNER,
ImmutableList.of(equiJoinClause("LINE_ORDER_KEY", "ORDERS_ORDER_KEY")),
anyTree(
filter("LINE_ORDER_KEY IS NOT NULL",
tableScan("lineitem", ImmutableMap.of("LINE_ORDER_KEY", "orderkey")))),
anyTree(
filter("ORDERS_ORDER_KEY IS NOT NULL",
tableScan("orders", ImmutableMap.of("ORDERS_ORDER_KEY", "orderkey")))))));
}
@Test
public void testNotNullPredicatesAddedForCrossJoinReducedToInnerJoin()
{
String query = "select 1 from lineitem l, orders o where l.orderkey = o.orderkey";
assertPlan(query,
anyTree(
join(INNER,
ImmutableList.of(equiJoinClause("LINE_ORDER_KEY", "ORDERS_ORDER_KEY")),
anyTree(
filter("LINE_ORDER_KEY IS NOT NULL",
tableScan("lineitem", ImmutableMap.of("LINE_ORDER_KEY", "orderkey")))),
anyTree(
filter("ORDERS_ORDER_KEY IS NOT NULL",
tableScan("orders", ImmutableMap.of("ORDERS_ORDER_KEY", "orderkey")))))));
query = "select 1 from lineitem l join orders o on l.orderkey = o.orderkey, customer c where c.custkey = o.custkey";
assertPlan(query,
anyTree(
join(INNER,
ImmutableList.of(equiJoinClause("LINE_ORDER_KEY", "ORDERS_ORDER_KEY")),
anyTree(
filter("LINE_ORDER_KEY IS NOT NULL",
tableScan("lineitem", ImmutableMap.of("LINE_ORDER_KEY", "orderkey")))),
anyTree(join(INNER,
ImmutableList.of(equiJoinClause("ORDERS_CUSTOMER_KEY", "CUSTOMER_CUSTOMER_KEY")),
anyTree(
filter("ORDERS_CUSTOMER_KEY IS NOT NULL AND ORDERS_ORDER_KEY IS NOT NULL",
tableScan("orders", ImmutableMap.of("ORDERS_ORDER_KEY", "orderkey",
"ORDERS_CUSTOMER_KEY", "custkey")))),
anyTree(
filter("CUSTOMER_CUSTOMER_KEY IS NOT NULL",
tableScan("customer", ImmutableMap.of("CUSTOMER_CUSTOMER_KEY", "custkey")))))))));
}
@Test
public void testMultipleNotNullsAddedForMultipleEquiJoinClause()
{
String query = "select 1 from lineitem l join orders o on l.orderkey = o.orderkey and l.partkey = o.custkey";
assertPlan(query,
anyTree(
join(INNER,
ImmutableList.of(equiJoinClause("LINE_ORDER_KEY", "ORDERS_ORDER_KEY"),
equiJoinClause("partkey", "custkey")),
anyTree(
filter("LINE_ORDER_KEY IS NOT NULL AND partkey IS NOT NULL",
tableScan("lineitem",
ImmutableMap.of(
"LINE_ORDER_KEY", "orderkey",
"partkey", "partkey")))),
anyTree(
filter("ORDERS_ORDER_KEY IS NOT NULL AND custkey IS NOT NULL",
tableScan("orders",
ImmutableMap.of(
"ORDERS_ORDER_KEY", "orderkey",
"custkey", "custkey")))))));
}
@Test
public void testNotNullInferredForJoinFilter()
{
String query = "select 1 from lineitem l join orders o on l.orderkey = o.orderkey and partkey + custkey > 10";
assertPlan(query,
anyTree(
join(INNER,
// Only single equi join clause in this case
ImmutableList.of(equiJoinClause("LINE_ORDER_KEY", "ORDERS_ORDER_KEY")),
// Extra join filter is passed unchanged
Optional.of("partkey + custkey > 10"),
// We can infer NOT NULL filters on partkey and custkey since the ADD function cannot operate on NULL arguments
anyTree(
filter("LINE_ORDER_KEY IS NOT NULL AND partkey IS NOT NULL",
tableScan("lineitem",
ImmutableMap.of(
"LINE_ORDER_KEY", "orderkey",
"partkey", "partkey")))),
anyTree(
filter("ORDERS_ORDER_KEY IS NOT NULL AND custkey IS NOT NULL",
tableScan("orders",
ImmutableMap.of(
"ORDERS_ORDER_KEY", "orderkey",
"custkey", "custkey")))))));
}
@Test
public void testNotNullPredicatesAddedOnlyForInnerSideTablesVariableReferences()
{
String query = "select 1 from lineitem l left join orders o on l.orderkey = o.orderkey and partkey - custkey > 10";
assertPlan(query,
anyTree(
join(LEFT,
ImmutableList.of(equiJoinClause("LINE_ORDER_KEY", "ORDERS_ORDER_KEY")),
Optional.of("partkey - custkey > 10"),
anyTree(
tableScan("lineitem", ImmutableMap.of(
"LINE_ORDER_KEY", "orderkey",
"partkey", "partkey"))),
anyTree(
filter("ORDERS_ORDER_KEY IS NOT NULL and custkey IS NOT NULL",
tableScan("orders",
ImmutableMap.of(
"ORDERS_ORDER_KEY", "orderkey",
"custkey", "custkey")))))));
query = "select 1 from lineitem l right join orders o on l.orderkey = o.orderkey and custkey > partkey";
assertPlan(query,
anyTree(
join(RIGHT,
ImmutableList.of(equiJoinClause("LINE_ORDER_KEY", "ORDERS_ORDER_KEY")),
Optional.of("custkey > partkey"),
anyTree(
filter("LINE_ORDER_KEY IS NOT NULL and partkey IS NOT NULL",
tableScan("lineitem", ImmutableMap.of(
"LINE_ORDER_KEY", "orderkey",
"partkey", "partkey")))),
anyTree(
tableScan("orders",
ImmutableMap.of(
"ORDERS_ORDER_KEY", "orderkey",
"custkey", "custkey"))))));
}
@Test(dataProvider = "standardOperatorTestCases")
public void testNotNullInferenceForInferFromStandardOperatorsStrategy(String filterSql, String[] expectedInferredNotNullVariables)
{
assertInferredNotNullVariableRefsListMatch(INFER_FROM_STANDARD_OPERATORS, filterSql, buildVariableReferencesList(expectedInferredNotNullVariables));
}
@Test(dataProvider = "nonStandardOperatorTestCases")
public void testNotNullInferenceForUseFunctionMetadataStrategy(String filterSql, String[] expectedInferredNotNullVariables)
{
assertInferredNotNullVariableRefsListMatch(USE_FUNCTION_METADATA, filterSql, buildVariableReferencesList(expectedInferredNotNullVariables));
}
@Test(dataProvider = "getExistingNotNullVarsTestCases")
public void testGetExistingNotNullVars(String filterSql, String[] expectedNotNullVars)
{
Set<VariableReferenceExpression> actual = new AddNotNullFiltersToJoinNode(functionAndTypeManager).getExistingNotNullVariables(
Optional.of(rowExpressionTranslator.translate(filterSql, testVariableTypeMap)));
assertEquals(actual, buildVariableReferencesList(expectedNotNullVars));
}
private List<VariableReferenceExpression> buildVariableReferencesList(String... var)
{
return Arrays.stream(var).map(x -> variable(x, testVariableTypeMap.get(x))).collect(Collectors.toList());
}
private void assertInferredNotNullVariableRefsListMatch(JoinNotNullInferenceStrategy notNullInferenceStrategy,
String filterSql, List<VariableReferenceExpression> expectedInferredNotNullVariables)
{
ExtractInferredNotNullVariablesVisitor visitor = new ExtractInferredNotNullVariablesVisitor(functionAndTypeManager, notNullInferenceStrategy);
RowExpression rowExpression = rowExpressionTranslator.translate(filterSql, testVariableTypeMap);
ImmutableSet.Builder<VariableReferenceExpression> builder = ImmutableSet.builder();
rowExpression.accept(visitor, builder);
assertEquals(builder.build(), expectedInferredNotNullVariables);
}
static {
ImmutableMap.Builder<String, Type> builder = ImmutableMap.builder();
builder.put("a", BIGINT);
builder.put("b", BIGINT);
builder.put("c", BIGINT);
builder.put("d", BIGINT);
builder.put("e", BIGINT);
builder.put("arr", new ArrayType(BIGINT));
builder.put("c_struct", RowType.from(ImmutableList.of(RowType.field("a", BIGINT))));
testVariableTypeMap = builder.build();
}
}