TestNullabilityAnalyzer.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.TestingRowExpressionTranslator;
import com.facebook.presto.sql.parser.ParsingOptions;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.iterative.rule.LambdaCaptureDesugaringRewriter;
import com.facebook.presto.sql.tree.Expression;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.Test;
import java.util.Collection;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.RowType.field;
import static com.facebook.presto.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences;
import static org.testng.Assert.assertEquals;
/**
 * Tests {@link NullabilityAnalyzer#mayReturnNullOnNonNullInput(RowExpression)}.
 * <p>
 * Each case is a SQL expression string. It is parsed, its identifiers are rewritten to symbol
 * references, it is desugared, and it is translated to a {@link RowExpression} against
 * {@link #TYPES}. The analyzer's result is then compared with the expected value.
 */
public class TestNullabilityAnalyzer
{
    private static final Metadata METADATA = MetadataManager.createTestMetadataManager();
    private static final TestingRowExpressionTranslator TRANSLATOR = new TestingRowExpressionTranslator(METADATA);
    // Variables available to the test expressions:
    //   a: BIGINT, b: ARRAY(BIGINT), c: ROW(field_1 BIGINT)
    private static final TypeProvider TYPES = TypeProvider.viewOf(new ImmutableMap.Builder<String, Type>()
            .put("a", BIGINT)
            .put("b", new ArrayType(BIGINT))
            .put("c", RowType.from(ImmutableList.of(field("field_1", BIGINT))))
            .build());
    // Upper-case to match the naming of the other static final constants in this class.
    private static final NullabilityAnalyzer ANALYZER = new NullabilityAnalyzer(METADATA.getFunctionAndTypeManager());

    @Test
    void test()
    {
        assertNullability("TRY_CAST(JSON '123' AS VARCHAR)", true);
        assertNullability("TRY_CAST(a AS VARCHAR)", true);
        assertNullability("CAST(a AS VARCHAR)", true);
        assertNullability("TRY_CAST('123' AS VARCHAR)", false);
        assertNullability("CAST('123' AS VARCHAR)", false);
        assertNullability("a = 1", false);
        assertNullability("(a/9+1)*5-10 > 10", false);
        assertNullability("1", false);
        assertNullability("a", false);
        assertNullability("TRY(a + 1)", true);
        assertNullability("IF(a > 10, 1)", true);
        assertNullability("a IN (1, NULL)", true);
        assertNullability("CASE WHEN a> 10 THEN 1 END", true);
        assertNullability("c.field_1", true);
        assertNullability("b[0]", true);
        assertNullability("a BETWEEN 1 AND 2", false);
        // nested: the TRY(...) case from above used as an operand of an enclosing comparison
        assertNullability("1 = TRY(a + 1)", true);
    }

    /**
     * Converts {@code expression} to a {@link RowExpression} and asserts the analyzer's verdict on it.
     *
     * @param expression SQL expression text; its identifiers are expected to be variables declared in {@link #TYPES}
     * @param mayReturnNullForNotNullInput expected result of {@link NullabilityAnalyzer#mayReturnNullOnNonNullInput(RowExpression)}
     */
    private void assertNullability(String expression, boolean mayReturnNullForNotNullInput)
    {
        Expression rawExpression = rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(expression, new ParsingOptions()));
        Expression desugaredExpression = new TestingDesugarExpressions(TYPES.allVariables()).rewrite(rawExpression);
        RowExpression rowExpression = TRANSLATOR.translate(desugaredExpression, TYPES);
        // Name the input in the failure message; otherwise a failure cannot be traced to one of the many cases in test().
        assertEquals(
                ANALYZER.mayReturnNullOnNonNullInput(rowExpression),
                mayReturnNullForNotNullInput,
                "mayReturnNullOnNonNullInput mismatch for expression: " + expression);
    }

    /**
     * Applies {@link DesugarTryExpressionRewriter} and then {@link LambdaCaptureDesugaringRewriter}
     * to an AST expression before it is translated.
     */
    private static final class TestingDesugarExpressions
    {
        private final VariableAllocator variableAllocator;

        /**
         * @param variables variables used to seed the allocator passed to the lambda-capture rewriter
         *                  (presumably so any newly allocated variables avoid these names; confirm in VariableAllocator)
         */
        public TestingDesugarExpressions(Collection<VariableReferenceExpression> variables)
        {
            this.variableAllocator = new VariableAllocator(variables);
        }

        /**
         * Returns {@code expression} after both desugaring rewrites have been applied, in order.
         */
        public Expression rewrite(Expression expression)
        {
            expression = DesugarTryExpressionRewriter.rewrite(expression);
            expression = LambdaCaptureDesugaringRewriter.rewrite(expression, variableAllocator);
            return expression;
        }
    }
}