SqlToRowExpressionTranslator.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.relational;
import com.facebook.presto.Session;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.function.SqlFunctionProperties;
import com.facebook.presto.common.transaction.TransactionId;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.CharType;
import com.facebook.presto.common.type.DecimalParseResult;
import com.facebook.presto.common.type.Decimals;
import com.facebook.presto.common.type.DistinctType;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.common.type.RowType.Field;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeWithName;
import com.facebook.presto.common.type.UnknownType;
import com.facebook.presto.common.type.VarcharType;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.function.SqlFunctionId;
import com.facebook.presto.spi.function.SqlInvokedFunction;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.ExistsExpression;
import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
import com.facebook.presto.spi.relation.QuantifiedComparisonExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression.Form;
import com.facebook.presto.spi.relation.UnresolvedSymbolExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.FunctionAndTypeResolver;
import com.facebook.presto.sql.analyzer.SemanticErrorCode;
import com.facebook.presto.sql.analyzer.SemanticException;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.facebook.presto.sql.tree.ArithmeticBinaryExpression;
import com.facebook.presto.sql.tree.ArithmeticUnaryExpression;
import com.facebook.presto.sql.tree.ArrayConstructor;
import com.facebook.presto.sql.tree.AstVisitor;
import com.facebook.presto.sql.tree.AtTimeZone;
import com.facebook.presto.sql.tree.BetweenPredicate;
import com.facebook.presto.sql.tree.BinaryLiteral;
import com.facebook.presto.sql.tree.BindExpression;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.CharLiteral;
import com.facebook.presto.sql.tree.CoalesceExpression;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.CurrentTime;
import com.facebook.presto.sql.tree.CurrentUser;
import com.facebook.presto.sql.tree.DecimalLiteral;
import com.facebook.presto.sql.tree.DereferenceExpression;
import com.facebook.presto.sql.tree.DoubleLiteral;
import com.facebook.presto.sql.tree.EnumLiteral;
import com.facebook.presto.sql.tree.ExistsPredicate;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.Extract;
import com.facebook.presto.sql.tree.FieldReference;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.GenericLiteral;
import com.facebook.presto.sql.tree.Identifier;
import com.facebook.presto.sql.tree.IfExpression;
import com.facebook.presto.sql.tree.InListExpression;
import com.facebook.presto.sql.tree.InPredicate;
import com.facebook.presto.sql.tree.IntervalLiteral;
import com.facebook.presto.sql.tree.IsNotNullPredicate;
import com.facebook.presto.sql.tree.IsNullPredicate;
import com.facebook.presto.sql.tree.LambdaArgumentDeclaration;
import com.facebook.presto.sql.tree.LambdaExpression;
import com.facebook.presto.sql.tree.LikePredicate;
import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.NodeRef;
import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.NullIfExpression;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.Row;
import com.facebook.presto.sql.tree.SearchedCaseExpression;
import com.facebook.presto.sql.tree.SimpleCaseExpression;
import com.facebook.presto.sql.tree.StringLiteral;
import com.facebook.presto.sql.tree.SubscriptExpression;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.sql.tree.TimeLiteral;
import com.facebook.presto.sql.tree.TimestampLiteral;
import com.facebook.presto.sql.tree.TryExpression;
import com.facebook.presto.sql.tree.WhenClause;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.regex.Pattern;
import static com.facebook.presto.common.function.OperatorType.BETWEEN;
import static com.facebook.presto.common.function.OperatorType.EQUAL;
import static com.facebook.presto.common.function.OperatorType.NEGATION;
import static com.facebook.presto.common.function.OperatorType.SUBSCRIPT;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.CharType.createCharType;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.JsonType.JSON;
import static com.facebook.presto.common.type.SmallintType.SMALLINT;
import static com.facebook.presto.common.type.TimeType.TIME;
import static com.facebook.presto.common.type.TimeWithTimeZoneType.TIME_WITH_TIME_ZONE;
import static com.facebook.presto.common.type.TimestampType.TIMESTAMP;
import static com.facebook.presto.common.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE;
import static com.facebook.presto.common.type.TinyintType.TINYINT;
import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.common.type.TypeUtils.isEnumType;
import static com.facebook.presto.common.type.VarbinaryType.VARBINARY;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.common.type.VarcharType.createVarcharType;
import static com.facebook.presto.metadata.CastType.CAST;
import static com.facebook.presto.metadata.CastType.TRY_CAST;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.AND;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.BIND;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.COALESCE;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.DEREFERENCE;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IF;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IN;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IS_NULL;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.NULL_IF;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.OR;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.ROW_CONSTRUCTOR;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.SWITCH;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.WHEN;
import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.getSourceLocation;
import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.resolveEnumLiteral;
import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TYPE_MISMATCH;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.facebook.presto.sql.relational.Expressions.constantNull;
import static com.facebook.presto.sql.relational.Expressions.field;
import static com.facebook.presto.sql.relational.Expressions.inSubquery;
import static com.facebook.presto.sql.relational.Expressions.quantifiedComparison;
import static com.facebook.presto.sql.relational.Expressions.specialForm;
import static com.facebook.presto.sql.tree.DereferenceExpression.getQualifiedName;
import static com.facebook.presto.type.LikePatternType.LIKE_PATTERN;
import static com.facebook.presto.util.DateTimeUtils.parseDayTimeInterval;
import static com.facebook.presto.util.DateTimeUtils.parseTimeWithTimeZone;
import static com.facebook.presto.util.DateTimeUtils.parseTimeWithoutTimeZone;
import static com.facebook.presto.util.DateTimeUtils.parseTimestampLiteral;
import static com.facebook.presto.util.DateTimeUtils.parseYearMonthInterval;
import static com.facebook.presto.util.LegacyRowFieldOrdinalAccessUtil.parseAnonymousRowFieldOrdinalAccess;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.airlift.slice.SliceUtf8.countCodePoints;
import static io.airlift.slice.Slices.utf8Slice;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
public final class SqlToRowExpressionTranslator
{
// Regexes recognizing LIKE patterns that contain no '_' and no interior '%'. Presumably consumed by
// generateLikePrefixOrSuffixMatch to special-case such patterns — confirm against its body.
// NOTE(review): the regexes do not account for escape characters; visitLikePredicate only takes the
// prefix/suffix path when no ESCAPE clause is present.
// Matches e.g. 'abc%': literal text followed by a single trailing '%'.
private static final Pattern LIKE_PREFIX_MATCH_PATTERN = Pattern.compile("^[^%_]*%$");
// Matches e.g. '%abc': a single leading '%' followed by literal text.
private static final Pattern LIKE_SUFFIX_MATCH_PATTERN = Pattern.compile("^%[^%_]*$");
// Matches e.g. '%abc%': literal text wrapped in one leading and one trailing '%'.
private static final Pattern LIKE_SIMPLE_EXISTS_PATTERN = Pattern.compile("^%[^%_]*%$");
// Static utility holder; not meant to be instantiated.
private SqlToRowExpressionTranslator() {}
/**
 * Translates an analyzed AST expression into a {@link RowExpression} using a brand-new
 * {@link Context}, so nothing translated earlier is reused.
 *
 * @param expression the AST expression to translate
 * @param types types of AST nodes, as looked up during translation
 * @param layout variables that should be turned into input-channel field references
 * @param functionAndTypeManager source of function and type resolution
 * @param session session providing user, transaction, function properties and session functions
 * @return the translated expression, never {@code null}
 */
public static RowExpression translate(
        Expression expression,
        Map<NodeRef<Expression>, Type> types,
        Map<VariableReferenceExpression, Integer> layout,
        FunctionAndTypeManager functionAndTypeManager,
        Session session)
{
    Context freshContext = new Context();
    return translate(expression, types, layout, functionAndTypeManager, session, freshContext);
}
/**
 * Translates an AST expression, unpacking the session-scoped settings the translator needs.
 *
 * @param context memoization state shared with the caller; it is populated as a side effect
 * @return the translated expression, never {@code null}
 */
public static RowExpression translate(
        Expression expression,
        Map<NodeRef<Expression>, Type> types,
        Map<VariableReferenceExpression, Integer> layout,
        FunctionAndTypeManager functionAndTypeManager,
        Session session,
        Context context)
{
    // Optional.of rejects a null user, exactly like the inline form did.
    Optional<String> sessionUser = Optional.of(session.getUser());
    Optional<TransactionId> sessionTransaction = session.getTransactionId();
    SqlFunctionProperties functionProperties = session.getSqlFunctionProperties();
    Map<SqlFunctionId, SqlInvokedFunction> functionsInSession = session.getSessionFunctions();
    return translate(
            expression,
            types,
            layout,
            functionAndTypeManager,
            sessionUser,
            sessionTransaction,
            functionProperties,
            functionsInSession,
            context);
}
/**
 * Core translation entry point: runs the visitor over {@code expression}.
 *
 * @param context memoization state; every translated sub-expression is recorded in it
 * @return the translated expression
 * @throws NullPointerException if the visitor produced {@code null}
 */
public static RowExpression translate(
        Expression expression,
        Map<NodeRef<Expression>, Type> types,
        Map<VariableReferenceExpression, Integer> layout,
        FunctionAndTypeManager functionAndTypeManager,
        Optional<String> user,
        Optional<TransactionId> transactionId,
        SqlFunctionProperties sqlFunctionProperties,
        Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions,
        Context context)
{
    RowExpression translated = new Visitor(types, layout, functionAndTypeManager, user, transactionId, sqlFunctionProperties, sessionFunctions)
            .process(expression, context);
    return requireNonNull(translated, "translated expression is null");
}
/**
 * Translation memo that remembers, by object identity, which AST {@link Expression} produced
 * which {@link RowExpression}, in both directions.
 */
public static class Context
{
    // Identity-keyed: two structurally equal but distinct nodes are tracked separately.
    private final Map<Expression, RowExpression> rowExpressionMap = new IdentityHashMap<>();
    private final Map<RowExpression, Expression> expressionMap = new IdentityHashMap<>();

    public Context()
    {
    }

    /** Returns the live, mutable AST-to-RowExpression map. */
    public Map<Expression, RowExpression> getRowExpressionMap()
    {
        return rowExpressionMap;
    }

    /** Returns the live, mutable RowExpression-to-AST map. */
    public Map<RowExpression, Expression> getExpressionMap()
    {
        return expressionMap;
    }

    /** Records the pairing in both maps. */
    public void put(Expression expression, RowExpression rowExpression)
    {
        expressionMap.put(rowExpression, expression);
        rowExpressionMap.put(expression, rowExpression);
    }
}
private static class Visitor
extends AstVisitor<RowExpression, Context>
{
// Node types, read through getType(Expression).
private final Map<NodeRef<Expression>, Type> types;
// Variables with an entry here are emitted as input-channel field references (see visitSymbolReference).
private final Map<VariableReferenceExpression, Integer> layout;
private final FunctionAndTypeManager functionAndTypeManager;
private final FunctionAndTypeResolver functionAndTypeResolver;
private final Optional<String> user;
private final Optional<TransactionId> transactionId;
private final SqlFunctionProperties sqlFunctionProperties;
private final Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions;
private final FunctionResolution functionResolution;

/**
 * Creates a visitor bound to one translation's inputs.
 * Every argument is required; a {@code null} fails fast with a message naming the argument.
 */
private Visitor(
        Map<NodeRef<Expression>, Type> types,
        Map<VariableReferenceExpression, Integer> layout,
        FunctionAndTypeManager functionAndTypeManager,
        Optional<String> user,
        Optional<TransactionId> transactionId,
        SqlFunctionProperties sqlFunctionProperties,
        Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions)
{
    this.types = requireNonNull(types, "types is null");
    this.layout = requireNonNull(layout, "layout is null");
    this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
    this.functionAndTypeResolver = functionAndTypeManager.getFunctionAndTypeResolver();
    this.user = requireNonNull(user, "user is null");
    this.transactionId = requireNonNull(transactionId, "transactionId is null");
    this.sqlFunctionProperties = requireNonNull(sqlFunctionProperties, "sqlFunctionProperties is null");
    this.functionResolution = new FunctionResolution(functionAndTypeResolver);
    this.sessionFunctions = requireNonNull(sessionFunctions, "sessionFunctions is null");
}
/** Looks up the type recorded for {@code node}; returns {@code null} when none is recorded. */
private Type getType(Expression node)
{
    NodeRef<Expression> key = NodeRef.of(node);
    return types.get(key);
}
/**
 * Translates {@code node}, reusing a translation already memoized in {@code context}.
 * Newly translated results are recorded in the context.
 *
 * @throws UnsupportedOperationException if {@code node} is not an {@link Expression}
 */
@Override
public RowExpression process(Node node, Context context)
{
    if (node instanceof Expression) {
        Expression expression = (Expression) node;
        Map<Expression, RowExpression> memo = context.getRowExpressionMap();
        if (memo.containsKey(expression)) {
            return memo.get(expression);
        }
        RowExpression translated = super.process(expression, context);
        context.put(expression, translated);
        return translated;
    }
    throw new UnsupportedOperationException("not yet implemented: expression translator for " + node.getClass().getName());
}
/** Fallback for AST node kinds that have no dedicated visit method. */
@Override
protected RowExpression visitExpression(Expression node, Context context)
{
    String nodeClassName = node.getClass().getName();
    throw new UnsupportedOperationException("not yet implemented: expression translator for " + nodeClassName);
}
/** Translates a bare identifier into a variable reference of the identifier's type. */
@Override
protected RowExpression visitIdentifier(Identifier node, Context context)
{
    // identifier should never be reachable with the exception of lambda within VALUES (#9711)
    String variableName = node.getValue();
    Type variableType = getType(node);
    return new VariableReferenceExpression(getSourceLocation(node), variableName, variableType);
}
/** Translates a positional field reference into an input-field expression. */
@Override
protected RowExpression visitFieldReference(FieldReference node, Context context)
{
    Type fieldType = getType(node);
    return field(getSourceLocation(node), node.getFieldIndex(), fieldType);
}
/** Translates {@code NULL} into a null constant of type UNKNOWN. */
@Override
protected RowExpression visitNullLiteral(NullLiteral node, Context context)
{
    return constantNull(
            getSourceLocation(node),
            UnknownType.UNKNOWN);
}
/** Translates TRUE/FALSE into a BOOLEAN constant. */
@Override
protected RowExpression visitBooleanLiteral(BooleanLiteral node, Context context)
{
    boolean literalValue = node.getValue();
    return constant(literalValue, BOOLEAN);
}
/** Translates an integral literal: INTEGER when it fits in 32 bits, otherwise BIGINT. */
@Override
protected RowExpression visitLongLiteral(LongLiteral node, Context context)
{
    long literal = node.getValue();
    boolean fitsInInteger = literal >= Integer.MIN_VALUE && literal <= Integer.MAX_VALUE;
    Type literalType = fitsInInteger ? INTEGER : BIGINT;
    return constant(literal, literalType);
}
/** Translates a floating-point literal into a DOUBLE constant. */
@Override
protected RowExpression visitDoubleLiteral(DoubleLiteral node, Context context)
{
    Type doubleType = functionAndTypeManager.getType(DOUBLE.getTypeSignature());
    return constant(node.getValue(), doubleType);
}
/** Translates a DECIMAL literal using the value and type chosen by {@code Decimals.parse}. */
@Override
protected RowExpression visitDecimalLiteral(DecimalLiteral node, Context context)
{
    DecimalParseResult decimal = Decimals.parse(node.getValue());
    Object decimalValue = decimal.getObject();
    return constant(decimalValue, decimal.getType());
}
/** Translates a string literal into a varchar(n) constant, with n counted in Unicode code points. */
@Override
protected RowExpression visitStringLiteral(StringLiteral node, Context context)
{
    Slice utf8 = node.getSlice();
    int codePoints = countCodePoints(utf8);
    return constant(utf8, createVarcharType(codePoints));
}
/** Translates a CHAR literal into a char(n) constant, where n is derived from the literal text. */
@Override
protected RowExpression visitCharLiteral(CharLiteral node, Context context)
{
    // NOTE(review): n is String.length() of getValue() (UTF-16 units), unlike visitStringLiteral, which
    // counts code points of the slice — presumably chosen to agree with the analyzer; confirm.
    int declaredLength = node.getValue().length();
    return constant(node.getSlice(), createCharType(declaredLength));
}
/** Translates a binary literal into a VARBINARY constant. */
@Override
protected RowExpression visitBinaryLiteral(BinaryLiteral node, Context context)
{
    Type binaryType = VARBINARY;
    return constant(node.getValue(), binaryType);
}
/**
 * Translates an enum literal into a constant of the named enum type.
 *
 * @throws IllegalArgumentException if the type name cannot be parsed or resolved; the underlying
 *         failure is attached as the cause
 */
@Override
protected RowExpression visitEnumLiteral(EnumLiteral node, Context context)
{
    Type type;
    try {
        type = functionAndTypeResolver.getType(parseTypeSignature(node.getType()));
    }
    catch (IllegalArgumentException e) {
        // Keep the original failure as the cause so the root parse/lookup error is not lost.
        throw new IllegalArgumentException("Unsupported type: " + node.getType(), e);
    }
    return constant(node.getValue(), type);
}
/**
 * Translates a typed literal such as {@code BIGINT '42'} or {@code JSON '{}'}.
 * TINYINT/SMALLINT/INTEGER/BIGINT literals are parsed into constants, JSON goes through
 * {@code json_parse}, and any other type becomes a CAST from the varchar text.
 *
 * @throws IllegalArgumentException if the type name cannot be resolved (cause attached)
 * @throws SemanticException with INVALID_LITERAL if an integral literal is malformed or out of range
 */
@Override
protected RowExpression visitGenericLiteral(GenericLiteral node, Context context)
{
    Type type;
    try {
        type = functionAndTypeResolver.getType(parseTypeSignature(node.getType()));
    }
    catch (IllegalArgumentException e) {
        // Keep the original failure as the cause so the root parse/lookup error is not lost.
        throw new IllegalArgumentException("Unsupported type: " + node.getType(), e);
    }
    try {
        if (TINYINT.equals(type)) {
            return constant((long) Byte.parseByte(node.getValue()), TINYINT);
        }
        else if (SMALLINT.equals(type)) {
            return constant((long) Short.parseShort(node.getValue()), SMALLINT);
        }
        else if (INTEGER.equals(type)) {
            return constant((long) Integer.parseInt(node.getValue()), INTEGER);
        }
        else if (BIGINT.equals(type)) {
            return constant(Long.parseLong(node.getValue()), BIGINT);
        }
    }
    catch (NumberFormatException e) {
        // SemanticException takes a format string plus arguments. Pre-formatting the message and
        // passing it as the format would make any '%' in the literal text (e.g. TINYINT '1%') be
        // parsed as a format specifier, so pass the pieces as arguments instead.
        throw new SemanticException(SemanticErrorCode.INVALID_LITERAL, node, "Invalid formatted generic %s literal: %s", type, node);
    }
    if (JSON.equals(type)) {
        return call(
                getSourceLocation(node),
                "json_parse",
                functionAndTypeResolver.lookupFunction("json_parse", fromTypes(VARCHAR)),
                getType(node),
                constant(utf8Slice(node.getValue()), VARCHAR));
    }
    // Every other type: CAST(<literal text> AS <type>).
    return call(
            getSourceLocation(node),
            CAST.name(),
            functionAndTypeResolver.lookupCast("CAST", VARCHAR, getType(node)),
            getType(node),
            constant(utf8Slice(node.getValue()), VARCHAR));
}
/**
 * Translates a TIME literal. TIME WITH TIME ZONE is parsed with its zone; plain TIME is parsed
 * in the session time zone under legacy timestamp semantics, or zone-free otherwise.
 */
@Override
protected RowExpression visitTimeLiteral(TimeLiteral node, Context context)
{
    Type literalType = getType(node);
    String text = node.getValue();
    long value;
    if (literalType.equals(TIME_WITH_TIME_ZONE)) {
        value = parseTimeWithTimeZone(text);
    }
    else if (sqlFunctionProperties.isLegacyTimestamp()) {
        // legacy semantics: interpret the time in the client's session time zone
        value = parseTimeWithoutTimeZone(sqlFunctionProperties.getTimeZoneKey(), text);
    }
    else {
        value = parseTimeWithoutTimeZone(text);
    }
    return constant(value, literalType);
}
/** Translates a TIMESTAMP literal; legacy timestamp semantics parse it in the session time zone. */
@Override
protected RowExpression visitTimestampLiteral(TimestampLiteral node, Context context)
{
    String text = node.getValue();
    long value = sqlFunctionProperties.isLegacyTimestamp()
            ? parseTimestampLiteral(sqlFunctionProperties.getTimeZoneKey(), text)
            : parseTimestampLiteral(text);
    return constant(value, getType(node));
}
/** Translates an INTERVAL literal (year-month or day-time) into a signed constant. */
@Override
protected RowExpression visitIntervalLiteral(IntervalLiteral node, Context context)
{
    long value = node.isYearToMonth()
            ? node.getSign().multiplier() * parseYearMonthInterval(node.getValue(), node.getStartField(), node.getEndField())
            : node.getSign().multiplier() * parseDayTimeInterval(node.getValue(), node.getStartField(), node.getEndField());
    return constant(value, getType(node));
}
/** Translates {@code left <op> right} into a BOOLEAN call of the resolved comparison operator. */
@Override
protected RowExpression visitComparisonExpression(ComparisonExpression node, Context context)
{
    RowExpression lhs = process(node.getLeft(), context);
    RowExpression rhs = process(node.getRight(), context);
    FunctionHandle comparison = functionResolution.comparisonFunction(node.getOperator(), lhs.getType(), rhs.getType());
    return call(getSourceLocation(node), node.getOperator().name(), comparison, BOOLEAN, lhs, rhs);
}
/**
 * Translates a function call: arguments are translated first, and their types drive resolution
 * of the qualified function name (session functions included).
 */
@Override
protected RowExpression visitFunctionCall(FunctionCall node, Context context)
{
    ImmutableList.Builder<RowExpression> translatedArguments = ImmutableList.builder();
    for (Expression argument : node.getArguments()) {
        translatedArguments.add(process(argument, context));
    }
    List<RowExpression> arguments = translatedArguments.build();

    ImmutableList.Builder<TypeSignatureProvider> signatureProviders = ImmutableList.builder();
    for (RowExpression argument : arguments) {
        signatureProviders.add(new TypeSignatureProvider(argument.getType().getTypeSignature()));
    }

    String displayName = node.getName().toString();
    FunctionHandle resolved = functionAndTypeResolver.resolveFunction(
            Optional.of(sessionFunctions),
            transactionId,
            functionAndTypeResolver.qualifyObjectName(node.getName()),
            signatureProviders.build());
    return call(displayName, resolved, getType(node), arguments);
}
/**
 * Translates a symbol: if the layout assigns it an input channel it becomes a field reference,
 * otherwise it stays a variable reference.
 */
@Override
protected RowExpression visitSymbolReference(SymbolReference node, Context context)
{
    VariableReferenceExpression variable = new VariableReferenceExpression(getSourceLocation(node), node.getName(), getType(node));
    Integer channel = layout.get(variable);
    if (channel == null) {
        return variable;
    }
    return field(variable.getSourceLocation(), channel, variable.getType());
}
/**
 * Translates a lambda. All but the last type parameter of the lambda's type are used as the
 * argument types.
 */
@Override
protected RowExpression visitLambdaExpression(LambdaExpression node, Context context)
{
    RowExpression body = process(node.getBody(), context);
    List<Type> signatureTypes = getType(node).getTypeParameters();
    List<Type> argumentTypes = signatureTypes.subList(0, signatureTypes.size() - 1);
    ImmutableList.Builder<String> argumentNames = ImmutableList.builder();
    for (LambdaArgumentDeclaration argument : node.getArguments()) {
        argumentNames.add(argument.getName().getValue());
    }
    return new LambdaDefinitionExpression(getSourceLocation(node), argumentTypes, argumentNames.build(), body);
}
/**
 * Translates a BIND expression into the BIND special form. Its arguments are the bound values in
 * order, followed by the function they are bound to.
 */
@Override
protected RowExpression visitBindExpression(BindExpression node, Context context)
{
    // The former value-type list was built but never read; it has been removed.
    ImmutableList.Builder<RowExpression> arguments = ImmutableList.builder();
    for (Expression value : node.getValues()) {
        arguments.add(process(value, context));
    }
    arguments.add(process(node.getFunction(), context));
    return specialForm(BIND, getType(node), arguments.build());
}
/** Translates {@code left <op> right} into a call of the resolved arithmetic operator. */
@Override
protected RowExpression visitArithmeticBinary(ArithmeticBinaryExpression node, Context context)
{
    RowExpression lhs = process(node.getLeft(), context);
    RowExpression rhs = process(node.getRight(), context);
    FunctionHandle arithmetic = functionResolution.arithmeticFunction(node.getOperator(), lhs.getType(), rhs.getType());
    return call(getSourceLocation(node), node.getOperator().name(), arithmetic, getType(node), lhs, rhs);
}
/** Translates unary {@code +x} (identity) and {@code -x} (NEGATION operator call). */
@Override
protected RowExpression visitArithmeticUnary(ArithmeticUnaryExpression node, Context context)
{
    RowExpression operand = process(node.getValue(), context);
    switch (node.getSign()) {
        case MINUS: {
            FunctionHandle negation = functionAndTypeResolver.resolveOperator(NEGATION, fromTypes(operand.getType()));
            return call(getSourceLocation(node), NEGATION.name(), negation, getType(node), operand);
        }
        case PLUS:
            return operand;
    }
    throw new UnsupportedOperationException("Unsupported unary operator: " + node.getSign());
}
/** Translates AND/OR into the corresponding BOOLEAN special form. */
@Override
protected RowExpression visitLogicalBinaryExpression(LogicalBinaryExpression node, Context context)
{
    // Resolve the form before translating operands so an unknown operator fails first.
    Form form = toLogicalForm(node.getOperator());
    RowExpression left = process(node.getLeft(), context);
    RowExpression right = process(node.getRight(), context);
    return specialForm(getSourceLocation(node), form, BOOLEAN, left, right);
}

// Maps a logical AST operator onto its special form.
private static Form toLogicalForm(LogicalBinaryExpression.Operator operator)
{
    switch (operator) {
        case AND:
            return AND;
        case OR:
            return OR;
        default:
            throw new IllegalStateException("Unknown logical operator: " + operator);
    }
}
/** Translates CAST / TRY_CAST into a call of the corresponding cast function. */
@Override
protected RowExpression visitCast(Cast node, Context context)
{
    RowExpression value = process(node.getExpression(), context);
    Type targetType = getType(node);
    // The CastType enum names ("CAST", "TRY_CAST") double as the cast lookup names.
    String castName = node.isSafe() ? TRY_CAST.name() : CAST.name();
    return call(getSourceLocation(node), castName, functionAndTypeResolver.lookupCast(castName, value.getType(), targetType), targetType, value);
}
/** Translates COALESCE(a, b, ...) into the COALESCE special form over the translated operands. */
@Override
protected RowExpression visitCoalesceExpression(CoalesceExpression node, Context context)
{
    ImmutableList.Builder<RowExpression> operands = ImmutableList.builder();
    for (Expression operand : node.getOperands()) {
        operands.add(process(operand, context));
    }
    return specialForm(COALESCE, getType(node), operands.build());
}
/** Translates {@code CASE operand WHEN ... END} into a SWITCH special form. */
@Override
protected RowExpression visitSimpleCaseExpression(SimpleCaseExpression node, Context context)
{
    RowExpression operand = process(node.getOperand(), context);
    return buildSwitch(operand, node.getWhenClauses(), node.getDefaultValue(), getType(node), context);
}
/**
 * Translates {@code CASE WHEN p1 THEN v1 ... ELSE v END} as the equivalent
 * {@code CASE true WHEN p1 THEN v1 ... ELSE v END}.
 */
@Override
protected RowExpression visitSearchedCaseExpression(SearchedCaseExpression node, Context context)
{
    RowExpression alwaysTrue = new ConstantExpression(getSourceLocation(node), true, BOOLEAN);
    return buildSwitch(alwaysTrue, node.getWhenClauses(), node.getDefaultValue(), getType(node), context);
}
/**
 * Builds {@code SWITCH(operand, WHEN(cond1, res1), ..., default)}. Without an explicit default,
 * a null of {@code returnType} is used.
 */
private RowExpression buildSwitch(RowExpression operand, List<WhenClause> whenClauses, Optional<Expression> defaultValue, Type returnType, Context context)
{
    ImmutableList.Builder<RowExpression> switchArguments = ImmutableList.builder();
    switchArguments.add(operand);
    for (WhenClause whenClause : whenClauses) {
        Type resultType = getType(whenClause.getResult());
        RowExpression condition = process(whenClause.getOperand(), context);
        RowExpression result = process(whenClause.getResult(), context);
        switchArguments.add(specialForm(getSourceLocation(whenClause), WHEN, resultType, condition, result));
    }
    RowExpression elseBranch = defaultValue
            .map(value -> process(value, context))
            .orElseGet(() -> constantNull(operand.getSourceLocation(), returnType));
    switchArguments.add(elseBranch);
    return specialForm(SWITCH, returnType, switchArguments.build());
}
/**
 * Translates {@code base.field}. Depending on the recorded types this yields an
 * UnresolvedSymbolExpression (no type recorded for the base), an enum constant (enum base and
 * enum result), or a DEREFERENCE special form whose second argument is the field's ordinal as an
 * INTEGER constant.
 */
@Override
protected RowExpression visitDereferenceExpression(DereferenceExpression node, Context context)
{
Type returnType = getType(node);
Type baseType = getType(node.getBase());
// No type recorded for the base: keep the dotted name unresolved.
if (baseType == null) {
return new UnresolvedSymbolExpression(getSourceLocation(node), returnType, getQualifiedName(node).getParts());
}
// Enum-typed base and result: the dereference names an enum value, translated to a constant.
if (isEnumType(baseType) && isEnumType(returnType)) {
return constant(resolveEnumLiteral(node, baseType), returnType);
}
// Unwrap a named type, then a distinct type, to reach the underlying row type.
if (baseType instanceof TypeWithName) {
baseType = ((TypeWithName) baseType).getType();
}
if (baseType instanceof DistinctType) {
baseType = ((DistinctType) baseType).getBaseType();
}
// NOTE(review): throws ClassCastException if the unwrapped base type is not a RowType.
RowType rowType = (RowType) baseType;
String fieldName = node.getField().getValue();
List<Field> fields = rowType.getFields();
int index = -1;
// Case-insensitive match against named fields; a second match fails the ambiguity check.
// NOTE(review): the ambiguity message formats the Field object itself, not just its name.
for (int i = 0; i < fields.size(); i++) {
Field field = fields.get(i);
if (field.getName().isPresent() && field.getName().get().equalsIgnoreCase(fieldName)) {
checkArgument(index < 0, "Ambiguous field %s in type %s", field, rowType.getDisplayName());
index = i;
}
}
// Legacy fallback when no named field matched: let parseAnonymousRowFieldOrdinalAccess try to
// resolve fieldName to an ordinal (presumably for anonymous fields — confirm its accepted syntax).
if (sqlFunctionProperties.isLegacyRowFieldOrdinalAccessEnabled() && index < 0) {
OptionalInt rowIndex = parseAnonymousRowFieldOrdinalAccess(fieldName, fields);
if (rowIndex.isPresent()) {
index = rowIndex.getAsInt();
}
}
checkState(index >= 0, "could not find field name: %s", node.getField());
// The DEREFERENCE node carries the base's source location, not that of the whole expression.
return specialForm(getSourceLocation(node.getBase()), DEREFERENCE, returnType, process(node.getBase(), context), constant((long) index, INTEGER));
}
/** Translates IF(cond, then[, else]); a missing else branch becomes a null of the IF's type. */
@Override
protected RowExpression visitIfExpression(IfExpression node, Context context)
{
    RowExpression condition = process(node.getCondition(), context);
    RowExpression whenTrue = process(node.getTrueValue(), context);
    RowExpression whenFalse = node.getFalseValue().isPresent()
            ? process(node.getFalseValue().get(), context)
            : constantNull(getSourceLocation(node), getType(node));
    return specialForm(IF, getType(node), ImmutableList.of(condition, whenTrue, whenFalse));
}
/** Translates TRY(expr) into a call of {@code $internal$try} on a zero-argument lambda wrapping expr. */
@Override
protected RowExpression visitTryExpression(TryExpression node, Context context)
{
    RowExpression body = process(node.getInnerExpression(), context);
    LambdaDefinitionExpression thunk = new LambdaDefinitionExpression(
            getSourceLocation(node),
            ImmutableList.of(),
            ImmutableList.of(),
            body);
    return call(functionAndTypeResolver, "$internal$try", getType(node), thunk);
}
/** Builds a BOOLEAN {@code lhs = rhs} call using the resolved EQUAL comparison. */
private RowExpression buildEquals(RowExpression lhs, RowExpression rhs)
{
    FunctionHandle equalFunction = functionResolution.comparisonFunction(ComparisonExpression.Operator.EQUAL, lhs.getType(), rhs.getType());
    return call(EQUAL.getOperator(), equalFunction, BOOLEAN, lhs, rhs);
}
/** Translates EXISTS(subquery), reusing the translated subquery's source location. */
@Override
protected RowExpression visitExists(ExistsPredicate existsPredicate, Context context)
{
    RowExpression translatedSubquery = process(existsPredicate.getSubquery(), context);
    return new ExistsExpression(translatedSubquery.getSourceLocation(), translatedSubquery);
}
/**
 * Translates {@code value <op> ANY/ALL/SOME (subquery)}. The AST operator and quantifier are
 * mapped onto the RowExpression enums by constant name.
 */
@Override
protected RowExpression visitQuantifiedComparisonExpression(com.facebook.presto.sql.tree.QuantifiedComparisonExpression expression, Context context)
{
    OperatorType operator = OperatorType.valueOf(expression.getOperator().name());
    QuantifiedComparisonExpression.Quantifier quantifier = QuantifiedComparisonExpression.Quantifier.valueOf(expression.getQuantifier().name());
    RowExpression value = process(expression.getValue(), context);
    RowExpression subquery = process(expression.getSubquery(), context);
    return quantifiedComparison(operator, quantifier, value, subquery);
}
/**
 * Translates IN. A single-element list becomes an equality; a longer list becomes the IN special
 * form; a subquery becomes an in-subquery over two variable references.
 */
@Override
protected RowExpression visitInPredicate(InPredicate node, Context context)
{
    RowExpression value = process(node.getValue(), context);
    Expression valueList = node.getValueList();
    if (valueList instanceof InListExpression) {
        List<Expression> candidates = ((InListExpression) valueList).getValues();
        if (candidates.size() == 1) {
            return buildEquals(value, process(candidates.get(0), context));
        }
        ImmutableList.Builder<RowExpression> arguments = ImmutableList.builder();
        arguments.add(value);
        for (Expression candidate : candidates) {
            arguments.add(process(candidate, context));
        }
        return specialForm(IN, BOOLEAN, arguments.build());
    }

    // IN (subquery): both sides are required to already be variable references.
    RowExpression subquery = process(valueList, context);
    checkArgument(value instanceof VariableReferenceExpression, "Unexpected expression: %s", value);
    checkArgument(subquery instanceof VariableReferenceExpression, "Unexpected expression: %s", subquery);
    return inSubquery((VariableReferenceExpression) value, (VariableReferenceExpression) subquery);
}
/** Translates {@code x IS NOT NULL} as {@code not(IS_NULL(x))}. */
@Override
protected RowExpression visitIsNotNullPredicate(IsNotNullPredicate node, Context context)
{
    RowExpression operand = process(node.getValue(), context);
    RowExpression isNull = specialForm(IS_NULL, BOOLEAN, ImmutableList.of(operand));
    return call(getSourceLocation(node), "not", functionResolution.notFunction(), BOOLEAN, isNull);
}
/** Translates {@code x IS NULL} into the IS_NULL special form. */
@Override
protected RowExpression visitIsNullPredicate(IsNullPredicate node, Context context)
{
    RowExpression operand = process(node.getValue(), context);
    return specialForm(
            getSourceLocation(node),
            IS_NULL,
            BOOLEAN,
            operand);
}
/** Translates {@code NOT x} into a BOOLEAN call of the resolved "not" function. */
@Override
protected RowExpression visitNotExpression(NotExpression node, Context context)
{
    RowExpression operand = process(node.getValue(), context);
    return call(getSourceLocation(node), "not", functionResolution.notFunction(), BOOLEAN, operand);
}
/**
 * Translates NULLIF(first, second).
 *
 * If functionAndTypeManager.nullIfSpecialFormEnabled() is true, this emits the NULL_IF special
 * form. Otherwise it expands to IF(first = second, NULL, first). For that comparison, both sides
 * are cast to their common super type when their types differ. The IF's else-branch always uses
 * the original, uncast first argument.
 *
 * @throws SemanticException with TYPE_MISMATCH when the two argument types have no common super type
 */
@Override
protected RowExpression visitNullIfExpression(NullIfExpression node, Context context)
{
RowExpression first = process(node.getFirst(), context);
RowExpression second = process(node.getSecond(), context);
Type returnType = getType(node);
if (!functionAndTypeManager.nullIfSpecialFormEnabled()) {
// If the first type is unknown, Presto's NULL_IF semantics say the type must not be inferred from the second argument.
// Always return a null with unknown type.
if (first.getType().equals(UnknownType.UNKNOWN)) {
return constantNull(UnknownType.UNKNOWN);
}
// Keep the uncast first argument; it is what the expression returns when the values differ.
RowExpression firstArgWithoutCast = first;
if (!second.getType().equals(first.getType())) {
Optional<Type> commonType = functionAndTypeResolver.getCommonSuperType(first.getType(), second.getType());
if (!commonType.isPresent()) {
throw new SemanticException(TYPE_MISMATCH, node, "Types are not comparable with NULLIF: %s vs %s", first.getType(), second.getType());
}
// cast(first as <common type>)
if (!first.getType().equals(commonType.get())) {
first = call(
getSourceLocation(node),
CAST.name(),
functionAndTypeResolver.lookupCast(CAST.name(), first.getType(), commonType.get()),
commonType.get(), first);
}
// cast(second as <common type>)
if (!second.getType().equals(commonType.get())) {
second = call(
getSourceLocation(node),
CAST.name(),
functionAndTypeResolver.lookupCast(CAST.name(), second.getType(), commonType.get()),
commonType.get(), second);
}
}
FunctionHandle equalsFunctionHandle = functionAndTypeResolver.resolveOperator(EQUAL, fromTypes(first.getType(), second.getType()));
// equal(cast(first as <common type>), cast(second as <common type>))
RowExpression equal = call(EQUAL.name(), equalsFunctionHandle, BOOLEAN, first, second);
// if (equal(cast(first as <common type>), cast(second as <common type>)), null of returnType, <uncast first>)
return specialForm(IF, returnType, equal, constantNull(returnType), firstArgWithoutCast);
}
return specialForm(getSourceLocation(node), NULL_IF, returnType, first, second);
}
/**
 * Translates {@code value BETWEEN min AND max} into a call to the BETWEEN operator, resolved
 * against the types of the three translated operands.
 */
@Override
protected RowExpression visitBetweenPredicate(BetweenPredicate node, Context context)
{
    RowExpression operand = process(node.getValue(), context);
    RowExpression lowerBound = process(node.getMin(), context);
    RowExpression upperBound = process(node.getMax(), context);

    FunctionHandle betweenOperator = functionAndTypeResolver.resolveOperator(
            BETWEEN,
            fromTypes(operand.getType(), lowerBound.getType(), upperBound.getType()));
    return call(getSourceLocation(node), BETWEEN.name(), betweenOperator, BOOLEAN, operand, lowerBound, upperBound);
}
/**
 * Translates a LIKE predicate.
 * <p>
 * Without an ESCAPE clause, the translator first tries the prefix/suffix/contains rewrite. Failing
 * that, it emits a LIKE call, casting the pattern to LIKE_PATTERN when the LIKE_PATTERN function is
 * supported. With an ESCAPE clause, the pattern and escape are either combined through
 * LIKE_PATTERN, or passed directly to the three-argument VARCHAR LIKE function.
 */
@Override
protected RowExpression visitLikePredicate(LikePredicate node, Context context)
{
    RowExpression value = process(node.getValue(), context);
    RowExpression pattern = process(node.getPattern(), context);

    if (!node.getEscape().isPresent()) {
        RowExpression rewritten = generateLikePrefixOrSuffixMatch(value, pattern);
        if (rewritten != null) {
            return rewritten;
        }
        if (functionResolution.supportsLikePatternFunction()) {
            FunctionHandle toLikePattern = functionAndTypeResolver.lookupCast("CAST", VARCHAR, LIKE_PATTERN);
            RowExpression castPattern = call(getSourceLocation(node), CAST.name(), toLikePattern, LIKE_PATTERN, pattern);
            return likeFunctionCall(value, castPattern);
        }
        return likeFunctionCall(value, pattern);
    }

    RowExpression escape = process(node.getEscape().get(), context);
    if (functionResolution.supportsLikePatternFunction()) {
        RowExpression compiledPattern = call(getSourceLocation(node), "LIKE_PATTERN", functionResolution.likePatternFunction(), LIKE_PATTERN, pattern, escape);
        return likeFunctionCall(value, compiledPattern);
    }
    return call(value.getSourceLocation(), "LIKE", functionResolution.likeVarcharVarcharVarcharFunction(), BOOLEAN, value, pattern, escape);
}
/**
 * Tries to rewrite {@code value LIKE <constant pattern>} into a cheaper, LIKE-free expression.
 * <p>
 * Only VARCHAR values with a constant Slice pattern that contains no {@code _} and is longer than
 * one character are considered. A rewrite is produced when the pattern matches
 * LIKE_PREFIX_MATCH_PATTERN, LIKE_SUFFIX_MATCH_PATTERN or LIKE_SIMPLE_EXISTS_PATTERN, which are
 * defined elsewhere in this class and presumably describe {@code 'abc%'}, {@code '%abc'} and
 * {@code '%abc%'} respectively. TODO confirm against those regexes.
 *
 * @param value the translated left-hand side of LIKE
 * @param pattern the translated pattern expression
 * @return the rewritten expression, or {@code null} when no rewrite applies and the caller should
 *     fall back to a regular LIKE call
 */
private RowExpression generateLikePrefixOrSuffixMatch(RowExpression value, RowExpression pattern)
{
    if (!(value.getType() instanceof VarcharType) || !(pattern instanceof ConstantExpression)) {
        return null;
    }
    Object constObject = ((ConstantExpression) pattern).getValue();
    if (!(constObject instanceof Slice)) {
        return null;
    }
    Slice slice = (Slice) constObject;
    String patternString = slice.toStringUtf8();
    // SUBSTR positions and lengths are measured in Unicode code points. String.length() counts
    // UTF-16 code units and over-counts supplementary characters (e.g. emoji), which would build a
    // too-long prefix/suffix and cause false mismatches. Count code points instead.
    int matchCharacterLength = patternString.codePointCount(0, patternString.length());
    // The '%' wildcard is a single UTF-8 byte, so stripping it is a one-byte slice adjustment.
    int matchBytesLength = slice.length();
    if (matchCharacterLength <= 1 || patternString.contains("_")) {
        return null;
    }

    if (LIKE_PREFIX_MATCH_PATTERN.matcher(patternString).matches()) {
        // prefix match
        // x LIKE 'some string%' is the same as SUBSTR(x, 1, length('some string')) = 'some string'; the trailing '%' doesn't matter
        return buildEquals(
                call(functionAndTypeManager, "SUBSTR", VARCHAR, value, constant(1L, BIGINT), constant((long) matchCharacterLength - 1, BIGINT)),
                constant(slice.slice(0, matchBytesLength - 1), VARCHAR));
    }
    if (LIKE_SUFFIX_MATCH_PATTERN.matcher(patternString).matches()) {
        // suffix match
        // x LIKE '%some string' is the same as SUBSTR(x, -length('some string')) = 'some string'
        return buildEquals(
                call(functionAndTypeManager, "SUBSTR", VARCHAR, value, constant(-(long) (matchCharacterLength - 1), BIGINT)),
                constant(slice.slice(1, matchBytesLength - 1), VARCHAR));
    }
    // The split-based rewrite needs a non-empty literal between the two '%' (more than the two '%'
    // bytes). SPLIT rejects an empty delimiter, so a pattern such as '%%' (if the regex admits it)
    // is left to the regular LIKE path.
    if (matchBytesLength > 2 && LIKE_SIMPLE_EXISTS_PATTERN.matcher(patternString).matches()) {
        // pattern should just exist in the string ignoring leading and trailing stuff
        // x LIKE '%some string%' is the same as CARDINALITY(SPLIT(x, 'some string', 2)) = 2
        // Split is most efficient as it uses string.indexOf java builtin so little memory/cpu overhead
        return buildEquals(
                call(functionAndTypeManager, "CARDINALITY", BIGINT, call(functionAndTypeManager, "SPLIT", new ArrayType(VARCHAR), value, constant(slice.slice(1, matchBytesLength - 2), VARCHAR), constant(2L, BIGINT))),
                constant(2L, BIGINT));
    }
    return null;
}
/**
 * Builds a two-argument LIKE call for a VARCHAR or CHAR value.
 * <p>
 * For VARCHAR, the function handle depends on whether the LIKE_PATTERN function is supported. For
 * CHAR, the CHAR-specific LIKE function is used.
 *
 * @throws IllegalStateException if the value is neither VARCHAR nor CHAR
 */
private RowExpression likeFunctionCall(RowExpression value, RowExpression pattern)
{
    Type valueType = value.getType();
    if (!(valueType instanceof VarcharType)) {
        checkState(valueType instanceof CharType, "LIKE value type is neither VARCHAR or CHAR");
        return call(value.getSourceLocation(), "LIKE", functionResolution.likeCharFunction(valueType), BOOLEAN, value, pattern);
    }
    FunctionHandle varcharLike = functionResolution.supportsLikePatternFunction()
            ? functionResolution.likeVarcharFunction()
            : functionResolution.likeVarcharVarcharFunction();
    return call(value.getSourceLocation(), "LIKE", varcharLike, BOOLEAN, value, pattern);
}
/**
 * Translates {@code base[index]}.
 * <p>
 * On a ROW base, the index must be a constant 1-based field position. It is turned into a
 * DEREFERENCE special form with a 0-based INTEGER field index. Every other base type goes through
 * the SUBSCRIPT operator.
 *
 * @throws IllegalStateException if a ROW subscript is not a constant BIGINT-valued position within
 *     the row's field count
 */
@Override
protected RowExpression visitSubscriptExpression(SubscriptExpression node, Context context)
{
    RowExpression base = process(node.getBase(), context);
    RowExpression index = process(node.getIndex(), context);
    Type baseType = base.getType();

    if (!(baseType instanceof RowType)) {
        FunctionHandle subscriptOperator = functionAndTypeResolver.resolveOperator(SUBSCRIPT, fromTypes(baseType, index.getType()));
        return call(getSourceLocation(node), SUBSCRIPT.name(), subscriptOperator, getType(node), base, index);
    }

    // ROW subscripts are rewritten as a DEREFERENCE of the addressed field
    checkState(index instanceof ConstantExpression, "Subscript expression on ROW requires a ConstantExpression");
    Object rawPosition = ((ConstantExpression) index).getValue();
    checkState(rawPosition instanceof Long, "ConstantExpression should contain a valid integer index into the row");
    long oneBasedPosition = (Long) rawPosition;
    int fieldCount = baseType.getTypeParameters().size();
    checkState(
            oneBasedPosition >= 1 && oneBasedPosition <= fieldCount,
            "Subscript index out of bounds %s: should be >= 1 and <= %s",
            oneBasedPosition,
            fieldCount);
    return specialForm(getSourceLocation(node), DEREFERENCE, getType(node), base, Expressions.constant(oneBasedPosition - 1, INTEGER));
}
/**
 * Translates {@code ARRAY[...]} into a call to the array constructor, resolved for the types of
 * the translated elements.
 */
@Override
protected RowExpression visitArrayConstructor(ArrayConstructor node, Context context)
{
    List<RowExpression> elements = node.getValues().stream()
            .map(element -> process(element, context))
            .collect(toImmutableList());
    List<Type> elementTypes = elements.stream()
            .map(RowExpression::getType)
            .collect(toImmutableList());
    FunctionHandle arrayConstructor = functionResolution.arrayConstructor(elementTypes);
    return call("ARRAY", arrayConstructor, getType(node), elements);
}
/**
 * Translates a {@code ROW(...)} literal into a ROW_CONSTRUCTOR special form over its translated
 * items.
 */
@Override
protected RowExpression visitRow(Row node, Context context)
{
    List<RowExpression> fieldValues = node.getItems().stream()
            .map(item -> process(item, context))
            .collect(toImmutableList());
    return specialForm(ROW_CONSTRUCTOR, getType(node), fieldValues);
}
/**
 * Translates CURRENT_DATE / CURRENT_TIME / LOCALTIME / CURRENT_TIMESTAMP / LOCALTIMESTAMP into
 * a call to the matching zero-argument function.
 *
 * @throws UnsupportedOperationException if a non-default precision is requested or the
 *     function kind is not handled
 */
@Override
protected RowExpression visitCurrentTime(CurrentTime node, Context context)
{
    if (node.getPrecision() != null) {
        throw new UnsupportedOperationException("not yet implemented: non-default precision");
    }

    String functionName;
    switch (node.getFunction()) {
        case DATE:
            functionName = "current_date";
            break;
        case TIME:
            functionName = "current_time";
            break;
        case LOCALTIME:
            functionName = "localtime";
            break;
        case TIMESTAMP:
            functionName = "current_timestamp";
            break;
        case LOCALTIMESTAMP:
            functionName = "localtimestamp";
            break;
        default:
            throw new UnsupportedOperationException("not yet implemented: " + node.getFunction());
    }
    return call(functionAndTypeResolver, functionName, getType(node));
}
/**
 * Translates {@code EXTRACT(field FROM expr)} into a call to the single-argument function that
 * returns that field (e.g. {@code year}, {@code day_of_week}). Synonymous fields such as DAY and
 * DAY_OF_MONTH map to the same function.
 *
 * @throws UnsupportedOperationException if the field is not handled
 */
@Override
protected RowExpression visitExtract(Extract node, Context context)
{
    RowExpression value = process(node.getExpression(), context);

    String functionName;
    switch (node.getField()) {
        case YEAR:
            functionName = "year";
            break;
        case QUARTER:
            functionName = "quarter";
            break;
        case MONTH:
            functionName = "month";
            break;
        case WEEK:
            functionName = "week";
            break;
        case DAY:
        case DAY_OF_MONTH:
            functionName = "day";
            break;
        case DAY_OF_WEEK:
        case DOW:
            functionName = "day_of_week";
            break;
        case DAY_OF_YEAR:
        case DOY:
            functionName = "day_of_year";
            break;
        case YEAR_OF_WEEK:
        case YOW:
            functionName = "year_of_week";
            break;
        case HOUR:
            functionName = "hour";
            break;
        case MINUTE:
            functionName = "minute";
            break;
        case SECOND:
            functionName = "second";
            break;
        case TIMEZONE_MINUTE:
            functionName = "timezone_minute";
            break;
        case TIMEZONE_HOUR:
            functionName = "timezone_hour";
            break;
        default:
            throw new UnsupportedOperationException("not yet implemented: " + node.getField());
    }
    return call(functionAndTypeResolver, functionName, getType(node), value);
}
/**
 * Translates {@code value AT TIME ZONE zone} into an {@code at_timezone} call.
 * <p>
 * Zone-less TIME and TIMESTAMP values are first cast to TIME WITH TIME ZONE and TIMESTAMP WITH
 * TIME ZONE respectively. Other value types are passed through unchanged.
 */
@Override
protected RowExpression visitAtTimeZone(AtTimeZone node, Context context)
{
    RowExpression value = process(node.getValue(), context);
    RowExpression timeZone = process(node.getTimeZone(), context);
    Type valueType = value.getType();

    Type zonedType = null;
    if (valueType.equals(TIME)) {
        zonedType = TIME_WITH_TIME_ZONE;
    }
    else if (valueType.equals(TIMESTAMP)) {
        zonedType = TIMESTAMP_WITH_TIME_ZONE;
    }

    if (zonedType != null) {
        FunctionHandle widen = functionAndTypeResolver.lookupCast("CAST", valueType, zonedType);
        value = call(getSourceLocation(node), CAST.name(), widen, zonedType, value);
    }
    return call(functionAndTypeResolver, "at_timezone", getType(node), value, timeZone);
}
/**
 * Translates CURRENT_USER into a call to the zero-argument {@code $current_user} function.
 */
@Override
protected RowExpression visitCurrentUser(CurrentUser node, Context context)
{
    Type userType = getType(node);
    return call(functionAndTypeResolver, "$current_user", userType);
}
}
}