TestRowExpressionTranslator.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.relational;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.expressions.translator.FunctionTranslator;
import com.facebook.presto.expressions.translator.RowExpressionTranslator;
import com.facebook.presto.expressions.translator.RowExpressionTreeTranslator;
import com.facebook.presto.expressions.translator.TranslatedExpression;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.spi.function.FunctionMetadata;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.ScalarOperator;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.TestingRowExpressionTranslator;
import com.facebook.presto.sql.planner.TypeProvider;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.testng.annotations.Test;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.expressions.translator.FunctionTranslator.buildFunctionTranslator;
import static com.facebook.presto.expressions.translator.RowExpressionTreeTranslator.translateWith;
import static com.facebook.presto.expressions.translator.TranslatedExpression.untranslated;
import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression;
import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;
import static java.util.Objects.requireNonNull;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;
/**
 * Tests for translating {@link RowExpression} trees into strings via
 * {@link RowExpressionTreeTranslator#translateWith} and a {@link FunctionTranslator}
 * built from the annotated methods in {@link TestFunctions}.
 */
public class TestRowExpressionTranslator
{
    private static final Metadata METADATA = MetadataManager.createTestMetadataManager();

    private final FunctionAndTypeManager functionAndTypeManager;
    private final TestingRowExpressionTranslator sqlToRowExpressionTranslator;

    public TestRowExpressionTranslator()
    {
        this.functionAndTypeManager = METADATA.getFunctionAndTypeManager();
        this.sqlToRowExpressionTranslator = new TestingRowExpressionTranslator(METADATA);
    }

    @Test
    public void testEndToEndFunctionTranslation()
    {
        TypeProvider typeProvider = TypeProvider.copyOf(ImmutableMap.of("col1", BIGINT));
        // The cast doubles as a check that the SQL is converted into a function call expression.
        CallExpression callExpression = (CallExpression) toRowExpression("LN(bitwise_and(1, col1))", typeProvider);
        TranslatedExpression<String> translatedExpression = translate(callExpression);
        assertEquals(translatedExpression.getTranslated(), Optional.of("LNof(1 BITWISE_AND col1)"));
    }

    @Test
    public void testEndToEndSpecialFormTranslation()
    {
        TypeProvider typeProvider = TypeProvider.copyOf(ImmutableMap.of("col1", BOOLEAN, "col2", BOOLEAN));
        TranslatedExpression<String> translatedExpression = translate("col1 AND col2", typeProvider);
        assertEquals(translatedExpression.getTranslated(), Optional.of("col1 TEST_AND col2"));
    }

    @Test
    public void testMissingFunctionTranslator()
    {
        // TestFunctions declares no translation for ABS, so the call must stay untranslated.
        TypeProvider typeProvider = TypeProvider.copyOf(ImmutableMap.of("col1", DOUBLE));
        TranslatedExpression<String> translatedExpression = translate("ABS(col1)", typeProvider);
        assertFalse(translatedExpression.getTranslated().isPresent());
    }

    @Test
    public void testIncorrectFunctionSignatureInDefinition()
    {
        // TestFunctions.ceil is declared with a BOOLEAN argument, which does not match CEIL(DOUBLE).
        TypeProvider typeProvider = TypeProvider.copyOf(ImmutableMap.of("col1", DOUBLE));
        TranslatedExpression<String> translatedExpression = translate("CEIL(col1)", typeProvider);
        assertFalse(translatedExpression.getTranslated().isPresent());
    }

    @Test
    public void testHiddenFunctionNot()
    {
        TranslatedExpression<String> translatedExpression = translate("NOT true", TypeProvider.empty());
        assertEquals(translatedExpression.getTranslated(), Optional.of("NOT_2 true"));
    }

    @Test
    public void testBasicOperator()
    {
        TypeProvider typeProvider = TypeProvider.copyOf(ImmutableMap.of("col1", BIGINT, "col2", BIGINT));
        TranslatedExpression<String> translatedExpression = translate("col1 + col2", typeProvider);
        assertEquals(translatedExpression.getTranslated(), Optional.of("col1 -|- col2"));
    }

    @Test
    public void testLessThanOperator()
    {
        TypeProvider typeProvider = TypeProvider.copyOf(ImmutableMap.of("col1", BIGINT, "col2", BIGINT));
        TranslatedExpression<String> translatedExpression = translate("col1 < col2", typeProvider);
        assertEquals(translatedExpression.getTranslated(), Optional.of("col1 LT col2"));
    }

    @Test
    public void testUntranslatableSpecialForm()
    {
        // TestFunctionTranslator only handles the AND special form, so OR must stay untranslated.
        TypeProvider typeProvider = TypeProvider.copyOf(ImmutableMap.of("col1", BOOLEAN, "col2", BOOLEAN));
        TranslatedExpression<String> translatedExpression = translate("col1 OR col2", typeProvider);
        assertFalse(translatedExpression.getTranslated().isPresent());
    }

    /**
     * Parses {@code sql} and converts it to a {@link RowExpression} using the given column types.
     */
    private RowExpression toRowExpression(String sql, TypeProvider typeProvider)
    {
        return sqlToRowExpressionTranslator.translate(expression(sql), typeProvider);
    }

    /**
     * Parses {@code sql} and translates the resulting row expression with {@link TestFunctionTranslator}.
     */
    private TranslatedExpression<String> translate(String sql, TypeProvider typeProvider)
    {
        return translate(toRowExpression(sql, typeProvider));
    }

    /**
     * Translates {@code rowExpression} with a {@link TestFunctionTranslator} whose function
     * mappings come from {@link TestFunctions}, using an empty column-handle context.
     */
    private TranslatedExpression<String> translate(RowExpression rowExpression)
    {
        return translateWith(
                rowExpression,
                new TestFunctionTranslator(functionAndTypeManager, buildFunctionTranslator(ImmutableSet.of(TestFunctions.class))),
                emptyMap());
    }

    /**
     * Translator that renders row expressions as strings for assertions.
     *
     * <ul>
     * <li>Constants render via {@code toString()}.</li>
     * <li>Variables render as their name.</li>
     * <li>Calls are delegated to the supplied {@link FunctionTranslator}.</li>
     * <li>Only the {@code AND} special form is supported; every other form is returned untranslated.</li>
     * </ul>
     */
    private static class TestFunctionTranslator
            extends RowExpressionTranslator<String, Map<VariableReferenceExpression, ColumnHandle>>
    {
        private final FunctionAndTypeManager functionAndTypeManager;
        private final FunctionTranslator<String> functionTranslator;

        TestFunctionTranslator(FunctionAndTypeManager functionAndTypeManager, FunctionTranslator<String> functionTranslator)
        {
            this.functionTranslator = requireNonNull(functionTranslator, "functionTranslator is null");
            this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
        }

        @Override
        public TranslatedExpression<String> translateConstant(ConstantExpression literal, Map<VariableReferenceExpression, ColumnHandle> context, RowExpressionTreeTranslator<String, Map<VariableReferenceExpression, ColumnHandle>> rowExpressionTreeTranslator)
        {
            return new TranslatedExpression<>(Optional.of(literal.toString()), literal, emptyList());
        }

        @Override
        public TranslatedExpression<String> translateCall(CallExpression callExpression, Map<VariableReferenceExpression, ColumnHandle> context, RowExpressionTreeTranslator<String, Map<VariableReferenceExpression, ColumnHandle>> rowExpressionTreeTranslator)
        {
            // Translate arguments first so the function translator can combine their string forms.
            List<TranslatedExpression<String>> translatedArguments = callExpression.getArguments().stream()
                    .map(argument -> rowExpressionTreeTranslator.rewrite(argument, context))
                    .collect(Collectors.toList());
            FunctionMetadata functionMetadata = functionAndTypeManager.getFunctionMetadata(callExpression.getFunctionHandle());
            try {
                return functionTranslator.translate(functionMetadata, callExpression, translatedArguments);
            }
            catch (Throwable ignored) {
                // Translation is best-effort: any failure while translating the call is reported as
                // "not translated" so the test can assert on the absence of a translation.
                return untranslated(callExpression, translatedArguments);
            }
        }

        @Override
        public TranslatedExpression<String> translateSpecialForm(SpecialFormExpression specialFormExpression, Map<VariableReferenceExpression, ColumnHandle> context, RowExpressionTreeTranslator<String, Map<VariableReferenceExpression, ColumnHandle>> rowExpressionTreeTranslator)
        {
            if (!specialFormExpression.getForm().equals(SpecialFormExpression.Form.AND)) {
                return untranslated(specialFormExpression);
            }
            List<TranslatedExpression<String>> translatedArguments = specialFormExpression.getArguments().stream()
                    .map(argument -> rowExpressionTreeTranslator.rewrite(argument, context))
                    .collect(Collectors.toList());
            // The tests only build binary ANDs over translatable operands.
            assertTrue(translatedArguments.get(0).getTranslated().isPresent());
            assertTrue(translatedArguments.get(1).getTranslated().isPresent());
            String left = translatedArguments.get(0).getTranslated().get();
            String right = translatedArguments.get(1).getTranslated().get();
            return new TranslatedExpression<>(
                    Optional.of(left + " TEST_AND " + right),
                    specialFormExpression,
                    translatedArguments);
        }

        @Override
        public TranslatedExpression<String> translateVariable(VariableReferenceExpression variable, Map<VariableReferenceExpression, ColumnHandle> context, RowExpressionTreeTranslator<String, Map<VariableReferenceExpression, ColumnHandle>> rowExpressionTreeTranslator)
        {
            return new TranslatedExpression<>(Optional.of(variable.getName()), variable, emptyList());
        }
    }

    /**
     * Annotated translation methods consumed by {@code buildFunctionTranslator}.
     * Each method takes the already-translated string form of its arguments and returns
     * the translated string for the call.
     */
    public static class TestFunctions
    {
        @ScalarFunction
        @SqlType(StandardTypes.BIGINT)
        public static String bitwiseAnd(@SqlType(StandardTypes.BIGINT) String left, @SqlType(StandardTypes.BIGINT) String right)
        {
            return left + " BITWISE_AND " + right;
        }

        @ScalarFunction("ln")
        @SqlType(StandardTypes.DOUBLE)
        public static String ln(@SqlType(StandardTypes.DOUBLE) String sql)
        {
            return "LNof(" + sql + ")";
        }

        // Intentionally declared with a BOOLEAN argument so it does not match CEIL(DOUBLE);
        // see testIncorrectFunctionSignatureInDefinition.
        @ScalarFunction("ceil")
        @SqlType(StandardTypes.DOUBLE)
        public static String ceil(@SqlType(StandardTypes.BOOLEAN) String sql)
        {
            return "CEILof(" + sql + ")";
        }

        @ScalarFunction("not")
        @SqlType(StandardTypes.BOOLEAN)
        public static String not(@SqlType(StandardTypes.BOOLEAN) String sql)
        {
            return "NOT_2 " + sql;
        }

        @ScalarOperator(OperatorType.ADD)
        @SqlType(StandardTypes.BIGINT)
        public static String plus(@SqlType(StandardTypes.BIGINT) String left, @SqlType(StandardTypes.BIGINT) String right)
        {
            return left + " -|- " + right;
        }

        @ScalarOperator(OperatorType.LESS_THAN)
        @SqlType(StandardTypes.BOOLEAN)
        public static String lessThan(@SqlType(StandardTypes.BIGINT) String left, @SqlType(StandardTypes.BIGINT) String right)
        {
            return left + " LT " + right;
        }
    }
}