NativeExecutionTypeRewrite.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.rewrite;
import com.facebook.airlift.log.Logger;
import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.UnknownTypeException;
import com.facebook.presto.common.type.BigintEnumType;
import com.facebook.presto.common.type.EnumType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeWithName;
import com.facebook.presto.common.type.VarcharEnumType;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.security.AccessControl;
import com.facebook.presto.sql.analyzer.FunctionAndTypeResolver;
import com.facebook.presto.sql.analyzer.QueryExplainer;
import com.facebook.presto.sql.analyzer.SemanticException;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.tree.ArrayConstructor;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.DereferenceExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.NodeRef;
import com.facebook.presto.sql.tree.Parameter;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.Statement;
import com.facebook.presto.sql.tree.StringLiteral;
import com.google.common.collect.ImmutableList;

import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.common.type.StandardTypes.BIGINT;
import static com.facebook.presto.common.type.StandardTypes.BIGINT_ENUM;
import static com.facebook.presto.common.type.StandardTypes.VARCHAR;
import static com.facebook.presto.common.type.StandardTypes.VARCHAR_ENUM;
import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TYPE_MISMATCH;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;
/**
* Queries can fail on native worker due to following missing support in Velox:<p>
* 1. Named types: Presto supports {@link TypeWithName} which Velox is not able to parse.<p>
* 2. {@link EnumType}: Velox does not support EnumTypes as well as its companion function {@code ENUM_KEY}.<p>
*
* This rewrite addresses the above issues by resolving the type or function in coordinator for native execution:<p>
* 1. Peel {@link TypeWithName} and only preserve the actual base type.<p>
* 2. Rewrite {@code CAST(col AS EnumType<T>)} -> {@code CAST(col AS <T>)}. <p> TODO: preserve the original type information for `typeof`. <p>
* 3. Since enum can be treated as a map, rewrite {@code ENUM_KEY(EnumType<T>)} -> {@code ELEMENT_AT(MAP(<T>, VARCHAR))}. <p>
*/
final class NativeExecutionTypeRewrite
implements StatementRewrite.Rewrite
{
private static final Logger LOG = Logger.get(ExpressionRewriter.class);
private static final String FUNCTION_ENUM_KEY = "enum_key";
private static final String FUNCTION_ELEMENT_AT = "element_at";
private static final String FUNCTION_MAP = "map";
@Override
public Statement rewrite(
Session session,
Metadata metadata,
SqlParser parser,
Optional<QueryExplainer> queryExplainer,
Statement node,
List<Expression> parameters,
Map<NodeRef<Parameter>, Expression> parameterLookup,
AccessControl accessControl,
WarningCollector warningCollector,
String query)
{
if (SystemSessionProperties.isNativeExecutionEnabled(session)
&& SystemSessionProperties.isNativeExecutionTypeRewriteEnabled(session)) {
return (Statement) new Rewriter(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()).process(node, null);
}
return node;
}
public static Expression rewriteEnumExpressions(Expression expression, FunctionAndTypeResolver functionAndTypeResolver)
{
return ExpressionTreeRewriter.rewriteWith(new EnumExpressionRewriter(functionAndTypeResolver), expression);
}
private static class EnumExpressionRewriter
extends ExpressionRewriter<Void>
{
private final FunctionAndTypeResolver functionAndTypeResolver;
public EnumExpressionRewriter(FunctionAndTypeResolver functionAndTypeResolver)
{
this.functionAndTypeResolver = functionAndTypeResolver;
}
private Expression convertEnumTypeToLiteral(DereferenceExpression key, Type type)
{
String enumKey = key.getField().getValue().toUpperCase();
if (type instanceof BigintEnumType) {
Map<String, Long> enumMap = ((EnumType) type).getEnumMap();
Long enumValue = enumMap.get(enumKey);
if (enumValue == null) {
throw new SemanticException(TYPE_MISMATCH, ".*'" + type.getDisplayName() + "." + key.getField().getValue() + "' cannot be resolved");
}
return new Cast(new LongLiteral(enumValue.toString()), BIGINT);
}
else if (type instanceof VarcharEnumType) {
Map<String, String> enumMap = ((EnumType) type).getEnumMap();
String enumValue = enumMap.get(enumKey);
if (enumValue == null) {
throw new SemanticException(TYPE_MISMATCH, ".*'" + type.getDisplayName() + "." + key.getField().getValue() + "' cannot be resolved");
}
return new StringLiteral(enumValue);
}
return key;
}
@Override
public Expression rewriteExpression(Expression expression, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
return treeRewriter.defaultRewrite(expression, null);
}
@Override
public Expression rewriteDereferenceExpression(DereferenceExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
try {
Type argumentType = functionAndTypeResolver.getType(parseTypeSignature(node.getBase().toString()));
if (argumentType instanceof TypeWithName) {
argumentType = ((TypeWithName) argumentType).getType();
if (argumentType instanceof EnumType) {
return convertEnumTypeToLiteral(node, argumentType);
}
}
}
catch (IllegalArgumentException | UnknownTypeException e) {
// Returns the original expression if rewrite fails.
LOG.warn(e.getMessage());
return node;
}
return node;
}
@Override
public Expression rewriteCast(Cast node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
// Rewrite any enum types to their values.
node = treeRewriter.defaultRewrite(node, null);
// Rewrite type to base type.
try {
Type type = functionAndTypeResolver.getType(parseTypeSignature(node.getType()));
if (type instanceof TypeWithName) {
// Peel user defined type name.
type = ((TypeWithName) type).getType();
switch (type.getTypeSignature().getBase()) {
case BIGINT_ENUM:
return new Cast(node.getLocation(), node.getExpression(), BIGINT, node.isSafe(), node.isTypeOnly());
case VARCHAR_ENUM:
return new Cast(node.getLocation(), node.getExpression(), VARCHAR, node.isSafe(), node.isTypeOnly());
default:
return new Cast(node.getLocation(), node.getExpression(), type.getTypeSignature().getBase(), node.isSafe(), node.isTypeOnly());
}
}
}
catch (IllegalArgumentException | UnknownTypeException e) {
throw new SemanticException(TYPE_MISMATCH, node, "Unknown type: " + node.getType());
}
return node;
}
private boolean isValidEnumKeyFunctionCall(FunctionCall node)
{
return node.getName().equals(QualifiedName.of(FUNCTION_ENUM_KEY))
&& node.getArguments().size() == 1;
}
private Expression convertEnumTypeToMapExpression(Type type)
{
ImmutableList.Builder<Expression> keys = ImmutableList.builder();
ImmutableList.Builder<Expression> values = ImmutableList.builder();
switch (type.getTypeSignature().getBase()) {
case BIGINT_ENUM:
for (Map.Entry<String, Long> entry : ((BigintEnumType) type).getEnumMap().entrySet()) {
keys.add(new LongLiteral(entry.getValue().toString()));
values.add(new StringLiteral(entry.getKey()));
}
break;
case VARCHAR_ENUM:
for (Map.Entry<String, String> entry : ((VarcharEnumType) type).getEnumMap().entrySet()) {
keys.add(new StringLiteral(entry.getValue()));
values.add(new StringLiteral(entry.getKey()));
}
break;
default:
throw new SemanticException(TYPE_MISMATCH, "Unknown type: " + type);
}
return new FunctionCall(QualifiedName.of(FUNCTION_MAP),
ImmutableList.of(
new ArrayConstructor(keys.build()),
new ArrayConstructor(values.build())));
}
@Override
public Expression rewriteFunctionCall(FunctionCall node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
QualifiedName functionName = node.getName();
List<Expression> arguments = node.getArguments();
if (isValidEnumKeyFunctionCall(node)) {
Expression argument = arguments.get(0);
functionName = QualifiedName.of(FUNCTION_ELEMENT_AT);
Type argumentType;
if (argument instanceof Cast) {
argumentType = functionAndTypeResolver.getType(parseTypeSignature(((Cast) argument).getType()));
}
else if (argument instanceof DereferenceExpression) {
argumentType = functionAndTypeResolver.getType(parseTypeSignature(((DereferenceExpression) argument).getBase().toString()));
}
else {
// ENUM_KEY is only supported with Cast or DereferenceExpression for now.
// Return node without rewriting.
return node;
}
if (argumentType instanceof TypeWithName) {
// Rewrite ENUM_KEY(EnumType<T>) -> ELEMENT_AT(MAP(<T>, VARCHAR))
argumentType = ((TypeWithName) argumentType).getType();
Expression enumMapExpression = convertEnumTypeToMapExpression(argumentType);
Expression enumValue = treeRewriter.rewrite(argument, null);
if (argumentType instanceof EnumType) {
arguments = ImmutableList.of(enumMapExpression, enumValue);
}
}
}
else {
node = treeRewriter.defaultRewrite(node, null);
arguments = node.getArguments();
}
return node.getLocation().isPresent()
? new FunctionCall(node.getLocation().get(), functionName, node.getWindow(), node.getFilter(), node.getOrderBy(), node.isDistinct(), node.isIgnoreNulls(), arguments)
: new FunctionCall(functionName, node.getWindow(), node.getFilter(), node.getOrderBy(), node.isDistinct(), node.isIgnoreNulls(), arguments);
}
}
private static final class Rewriter
extends DefaultTreeRewriter<Void>
{
private final FunctionAndTypeResolver functionAndTypeResolver;
public Rewriter(FunctionAndTypeResolver functionAndTypeResolver)
{
this.functionAndTypeResolver = requireNonNull(functionAndTypeResolver, "functionAndTypeResolver is null");
}
@Override
protected Node visitExpression(Expression node, Void context)
{
return rewriteEnumExpressions(node, this.functionAndTypeResolver);
}
@Override
protected Node visitFunctionCall(FunctionCall node, Void context)
{
return rewriteEnumExpressions(node, this.functionAndTypeResolver);
}
}
}