SqlFunctionUtils.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.relational;
import com.facebook.presto.Session;
import com.facebook.presto.common.function.SqlFunctionProperties;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.expressions.RowExpressionRewriter;
import com.facebook.presto.expressions.RowExpressionTreeRewriter;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.function.FunctionMetadata;
import com.facebook.presto.spi.function.SqlFunctionId;
import com.facebook.presto.spi.function.SqlInvokedFunction;
import com.facebook.presto.spi.function.SqlInvokedScalarFunctionImplementation;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.ExpressionAnalysis;
import com.facebook.presto.sql.analyzer.FunctionAndTypeResolver;
import com.facebook.presto.sql.parser.ParsingOptions;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.PlannerUtils;
import com.facebook.presto.sql.planner.iterative.rule.LambdaCaptureDesugaringRewriter;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.facebook.presto.sql.tree.Identifier;
import com.facebook.presto.sql.tree.LambdaArgumentDeclaration;
import com.facebook.presto.sql.tree.LambdaExpression;
import com.facebook.presto.sql.tree.NodeRef;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.collect.ImmutableMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import static com.facebook.presto.spi.function.FunctionImplementationType.SQL;
import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.analyzeSqlFunctionExpression;
import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.createSymbolReference;
import static com.facebook.presto.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DECIMAL;
import static com.facebook.presto.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DOUBLE;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static java.lang.String.format;
import static java.util.Locale.ENGLISH;
import static java.util.function.Function.identity;
/**
 * Utilities for expanding SQL-invoked scalar functions (functions whose implementation type is
 * {@code SQL}) into the calling expression, either as an AST {@link Expression} or as a
 * {@link RowExpression}, and for translating a standalone SQL function body into a {@link RowExpression}.
 */
public final class SqlFunctionUtils
{
    private SqlFunctionUtils() {}

    /**
     * Expands a call to a SQL-invoked scalar function into an AST expression in which every reference to
     * a function parameter has been replaced by the corresponding call argument.
     *
     * @param functionMetadata metadata of the called function; must describe a SQL function with argument names
     * @param implementation the SQL body of the function
     * @param functionAndTypeResolver used to resolve argument types and to analyze the body
     * @param variableAllocator allocator for the parameter and lambda variables introduced during expansion
     * @param sqlFunctionProperties properties used when parsing and analyzing the body
     * @param arguments the call arguments, positionally matching the function's argument names
     * @return the expanded function body with the arguments bound
     * @throws IllegalArgumentException if the function is not a SQL function, has no argument names,
     * or the number of arguments differs from the number of argument names
     * @throws IllegalStateException if the function's argument names and argument types differ in count
     */
    public static Expression getSqlFunctionExpression(
            FunctionMetadata functionMetadata,
            SqlInvokedScalarFunctionImplementation implementation,
            FunctionAndTypeResolver functionAndTypeResolver,
            VariableAllocator variableAllocator,
            SqlFunctionProperties sqlFunctionProperties,
            List<Expression> arguments)
    {
        // Validate before allocating: allocation dereferences the argument names and would otherwise
        // fail with NoSuchElementException (and consume names from the caller's allocator) first.
        checkSqlFunction(functionMetadata);
        Map<String, VariableReferenceExpression> argumentVariables = allocateFunctionArgumentVariables(functionMetadata, functionAndTypeResolver, variableAllocator);
        Expression expression = getSqlFunctionImplementationExpression(functionMetadata, implementation, functionAndTypeResolver, variableAllocator, sqlFunctionProperties, argumentVariables);
        return SqlFunctionArgumentBinder.bindFunctionArguments(
                expression,
                functionMetadata.getArgumentNames().get(),
                arguments,
                argumentVariables);
    }

    /**
     * Expands a call to a SQL-invoked scalar function into a {@link RowExpression} in which every reference
     * to a function parameter has been replaced by the corresponding call argument.
     *
     * @param functionMetadata metadata of the called function; must describe a SQL function with argument names
     * @param implementation the SQL body of the function
     * @param functionAndTypeManager used to resolve types, analyze the body, and translate it to a row expression
     * @param sqlFunctionProperties properties used when parsing, analyzing and translating the body
     * @param sessionFunctions session functions made available to the translation
     * @param arguments the call arguments, positionally matching the function's argument names
     * @return the translated function body with the arguments bound
     * @throws IllegalArgumentException if the function is not a SQL function, has no argument names,
     * or the number of arguments differs from the number of argument names
     * @throws IllegalStateException if the function's argument names and argument types differ in count
     */
    public static RowExpression getSqlFunctionRowExpression(
            FunctionMetadata functionMetadata,
            SqlInvokedScalarFunctionImplementation implementation,
            FunctionAndTypeManager functionAndTypeManager,
            SqlFunctionProperties sqlFunctionProperties,
            Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions,
            List<RowExpression> arguments)
    {
        // Validate before allocating: allocation dereferences the argument names.
        checkSqlFunction(functionMetadata);
        FunctionAndTypeResolver functionAndTypeResolver = functionAndTypeManager.getFunctionAndTypeResolver();
        VariableAllocator variableAllocator = new VariableAllocator();
        Map<String, VariableReferenceExpression> argumentVariables = allocateFunctionArgumentVariables(functionMetadata, functionAndTypeResolver, variableAllocator);
        Expression expression = getSqlFunctionImplementationExpression(functionMetadata, implementation, functionAndTypeResolver, variableAllocator, sqlFunctionProperties, argumentVariables);

        // The expanded body refers to parameters through the allocated variables, so it is analyzed
        // against the variables' names and types.
        Map<String, Type> variableTypes = argumentVariables.values().stream()
                .collect(toImmutableMap(VariableReferenceExpression::getName, VariableReferenceExpression::getType));
        Map<NodeRef<Expression>, Type> expressionTypes = analyzeSqlFunctionExpression(
                functionAndTypeResolver,
                sqlFunctionProperties,
                expression,
                variableTypes).getExpressionTypes();

        // Translate to row expression
        RowExpression translated = SqlToRowExpressionTranslator.translate(
                expression,
                expressionTypes,
                ImmutableMap.of(),
                functionAndTypeManager,
                Optional.empty(),
                Optional.empty(),
                sqlFunctionProperties,
                sessionFunctions,
                new SqlToRowExpressionTranslator.Context());
        return SqlFunctionArgumentBinder.bindFunctionArguments(
                translated,
                functionMetadata.getArgumentNames().get(),
                arguments,
                argumentVariables);
    }

    /**
     * Parses {@code functionBody} as a SQL expression and translates it to a {@link RowExpression}.
     * Before translation, every identifier whose lower-cased name is a key of
     * {@code columnNameToInputVariableNameMap} is renamed to the mapped value. The result is then
     * analyzed with the names and types of {@code variables}.
     *
     * @param functionBody the SQL expression text to parse
     * @param variables variables whose names and types are visible to the analyzed expression
     * @param functionAndTypeManager used to analyze and translate the expression
     * @param session source of the SQL function properties and session functions
     * @param columnNameToInputVariableNameMap identifier renames, keyed by lower-cased identifier name
     * @return the translated row expression
     */
    public static RowExpression sqlFunctionToRowExpression(String functionBody,
            Set<VariableReferenceExpression> variables,
            FunctionAndTypeManager functionAndTypeManager,
            Session session,
            Map<String, String> columnNameToInputVariableNameMap)
    {
        SqlFunctionProperties sqlFunctionProperties = session.getSqlFunctionProperties();
        Expression parsed = parseSqlFunctionExpression(
                new SqlInvokedScalarFunctionImplementation(functionBody),
                sqlFunctionProperties);
        // Translate the parameter names in functionBody to input variable names
        Expression expression = rewriteIdentifiers(parsed, columnNameToInputVariableNameMap);
        Map<String, Type> variableTypes = variables.stream()
                .collect(toImmutableMap(
                        VariableReferenceExpression::getName,
                        VariableReferenceExpression::getType));
        return SqlToRowExpressionTranslator.translate(
                expression,
                analyzeSqlFunctionExpression(
                        functionAndTypeManager.getFunctionAndTypeResolver(),
                        sqlFunctionProperties,
                        expression,
                        variableTypes)
                        .getExpressionTypes(),
                ImmutableMap.of(),
                functionAndTypeManager,
                Optional.empty(),
                Optional.empty(),
                sqlFunctionProperties,
                session.getSessionFunctions(),
                new SqlToRowExpressionTranslator.Context());
    }

    /**
     * Checks that {@code functionMetadata} describes a SQL-implemented function that declares argument names.
     * Must run before anything calls {@code getArgumentNames().get()}.
     *
     * @throws IllegalArgumentException if the implementation type is not {@code SQL} or the argument names are absent
     */
    private static void checkSqlFunction(FunctionMetadata functionMetadata)
    {
        checkArgument(functionMetadata.getImplementationType().equals(SQL), format("Expect SQL function, get %s", functionMetadata.getImplementationType()));
        checkArgument(functionMetadata.getArgumentNames().isPresent(), "ArgumentNames is missing");
    }

    /**
     * Parses, normalizes, analyzes, coerces and lambda-rewrites the SQL body of a function.
     * The caller must already have validated {@code functionMetadata} with {@link #checkSqlFunction}.
     *
     * @param argumentVariables allocated variable per declared argument name
     * @return the prepared body, with parameter references rewritten to symbol references of
     * {@code argumentVariables} and lambda captures desugared
     */
    private static Expression getSqlFunctionImplementationExpression(
            FunctionMetadata functionMetadata,
            SqlInvokedScalarFunctionImplementation implementation,
            FunctionAndTypeResolver functionAndTypeResolver,
            VariableAllocator variableAllocator,
            SqlFunctionProperties sqlFunctionProperties,
            Map<String, VariableReferenceExpression> argumentVariables)
    {
        Expression expression = normalizeParameters(functionMetadata.getArgumentNames().get(), parseSqlFunctionExpression(implementation, sqlFunctionProperties));
        // At this stage the body still refers to parameters by their declared names.
        Map<String, Type> argumentTypes = argumentVariables.entrySet().stream()
                .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().getType()));
        ExpressionAnalysis functionAnalysis = analyzeSqlFunctionExpression(
                functionAndTypeResolver,
                sqlFunctionProperties,
                expression,
                argumentTypes);
        expression = coerceIfNecessary(expression, functionAnalysis);
        return rewriteLambdaExpression(expression, argumentVariables, functionAnalysis, variableAllocator);
    }

    /**
     * Rewrites every identifier that matches a declared argument name case-insensitively so that it uses
     * the argument name exactly as declared.
     */
    private static Expression normalizeParameters(List<String> argumentNames, Expression sqlFunction)
    {
        // Lower-case with a fixed locale rather than the JVM default, so that e.g. a Turkish default locale
        // does not turn "ID" into "ıd" and break the match.
        // NOTE(review): assumes Identifier#getValueLowerCase also lower-cases with ENGLISH — confirm.
        return rewriteIdentifiers(sqlFunction, argumentNames.stream().collect(toImmutableMap(name -> name.toLowerCase(ENGLISH), identity())));
    }

    /**
     * Replaces every identifier whose lower-cased value is a key of {@code nameMapping} with a new identifier
     * named by the mapped value. All other identifiers are left untouched.
     */
    private static Expression rewriteIdentifiers(Expression expression, Map<String, String> nameMapping)
    {
        return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Map<String, String>>()
        {
            @Override
            public Expression rewriteIdentifier(Identifier node, Map<String, String> context, ExpressionTreeRewriter<Map<String, String>> treeRewriter)
            {
                String name = node.getValueLowerCase();
                if (context.containsKey(name)) {
                    return new Identifier(context.get(name));
                }
                return node;
            }
        }, expression, nameMapping);
    }

    /**
     * Parses a function body as a standalone expression. Decimal literals are treated as DOUBLE or DECIMAL
     * according to {@code sqlFunctionProperties}.
     */
    private static Expression parseSqlFunctionExpression(SqlInvokedScalarFunctionImplementation functionImplementation, SqlFunctionProperties sqlFunctionProperties)
    {
        ParsingOptions parsingOptions = ParsingOptions.builder()
                .setDecimalLiteralTreatment(sqlFunctionProperties.isParseDecimalLiteralAsDouble() ? AS_DOUBLE : AS_DECIMAL)
                .build();
        // TODO: Use injector-created SqlParser, which could potentially be different from the adhoc SqlParser.
        return new SqlParser().createReturn(functionImplementation.getImplementation(), parsingOptions).getExpression();
    }

    /**
     * Allocates one variable per declared argument, named after the argument and typed with its resolved type.
     * Requires the argument names to be present; callers validate this with {@link #checkSqlFunction} first.
     *
     * @return the allocated variables keyed by declared argument name, in declaration order
     * @throws IllegalStateException if the numbers of argument names and argument types differ
     */
    private static Map<String, VariableReferenceExpression> allocateFunctionArgumentVariables(FunctionMetadata functionMetadata, FunctionAndTypeResolver functionAndTypeResolver, VariableAllocator variableAllocator)
    {
        List<String> argumentNames = functionMetadata.getArgumentNames().get();
        List<Type> argumentTypes = functionMetadata.getArgumentTypes().stream().map(functionAndTypeResolver::getType).collect(toImmutableList());
        checkState(argumentNames.size() == argumentTypes.size(), format("Expect argumentNames (size %d) and argumentTypes (size %d) to be of the same size", argumentNames.size(), argumentTypes.size()));
        ImmutableMap.Builder<String, VariableReferenceExpression> builder = ImmutableMap.builder();
        for (int i = 0; i < argumentNames.size(); i++) {
            builder.put(argumentNames.get(i), variableAllocator.newVariable(argumentNames.get(i), argumentTypes.get(i)));
        }
        return builder.build();
    }

    /**
     * Rewrites the function body in three passes:
     * <ol>
     * <li>gives each lambda argument a freshly allocated variable and turns its references into symbol references;</li>
     * <li>turns the remaining identifiers that name a function parameter into symbol references of the
     * parameter's allocated variable;</li>
     * <li>desugars lambda captures.</li>
     * </ol>
     */
    private static Expression rewriteLambdaExpression(Expression sqlFunction, Map<String, VariableReferenceExpression> arguments, ExpressionAnalysis functionAnalysis, VariableAllocator variableAllocator)
    {
        Map<NodeRef<Identifier>, LambdaArgumentDeclaration> lambdaArgumentReferences = functionAnalysis.getLambdaArgumentReferences();
        Map<NodeRef<Expression>, Type> expressionTypes = functionAnalysis.getExpressionTypes();

        // Rewrite reference to LambdaArgumentDeclaration.
        // Map entries are already unique, so no de-duplication is needed.
        Map<NodeRef<LambdaArgumentDeclaration>, VariableReferenceExpression> variables = expressionTypes.entrySet().stream()
                .filter(entry -> entry.getKey().getNode() instanceof LambdaArgumentDeclaration)
                .collect(toImmutableMap(
                        entry -> NodeRef.of((LambdaArgumentDeclaration) entry.getKey().getNode()),
                        entry -> PlannerUtils.newVariable(variableAllocator, ((LambdaArgumentDeclaration) entry.getKey().getNode()).getName(), entry.getValue(), "lambda")));
        Expression rewritten = ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Map<NodeRef<Identifier>, LambdaArgumentDeclaration>>()
        {
            @Override
            public Expression rewriteLambdaExpression(LambdaExpression node, Map<NodeRef<Identifier>, LambdaArgumentDeclaration> context, ExpressionTreeRewriter<Map<NodeRef<Identifier>, LambdaArgumentDeclaration>> treeRewriter)
            {
                return new LambdaExpression(
                        node.getArguments().stream()
                                .map(argument -> new LambdaArgumentDeclaration(new Identifier(variables.get(NodeRef.of(argument)).getName())))
                                .collect(toImmutableList()),
                        treeRewriter.rewrite(node.getBody(), context));
            }

            @Override
            public Expression rewriteIdentifier(Identifier node, Map<NodeRef<Identifier>, LambdaArgumentDeclaration> context, ExpressionTreeRewriter<Map<NodeRef<Identifier>, LambdaArgumentDeclaration>> treeRewriter)
            {
                NodeRef<Identifier> ref = NodeRef.of(node);
                if (context.containsKey(ref)) {
                    return createSymbolReference(variables.get(NodeRef.of(context.get(ref))));
                }
                return node;
            }
        }, sqlFunction, lambdaArgumentReferences);

        // Rewrite the remaining references to function parameters (inside or outside lambdas) into
        // symbol references of the allocated parameter variables.
        rewritten = ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Map<String, VariableReferenceExpression>>()
        {
            @Override
            public Expression rewriteIdentifier(Identifier node, Map<String, VariableReferenceExpression> context, ExpressionTreeRewriter<Map<String, VariableReferenceExpression>> treeRewriter)
            {
                if (context.containsKey(node.getValue())) {
                    return createSymbolReference(context.get(node.getValue()));
                }
                return node;
            }
        }, rewritten, arguments);

        // Desugar lambda capture
        return LambdaCaptureDesugaringRewriter.rewrite(rewritten, variableAllocator);
    }

    /**
     * Wraps every sub-expression for which {@code analysis} records a coercion in a {@link Cast} to the
     * coercion type. The cast is marked type-only when the analysis says so.
     */
    private static Expression coerceIfNecessary(Expression sqlFunction, ExpressionAnalysis analysis)
    {
        return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<ExpressionAnalysis>()
        {
            @Override
            public Expression rewriteExpression(Expression expression, ExpressionAnalysis context, ExpressionTreeRewriter<ExpressionAnalysis> treeRewriter)
            {
                // Rewrite children first (propagating the real context instead of null) so nested coercions apply too
                Expression rewritten = treeRewriter.defaultRewrite(expression, context);
                Type coercion = context.getCoercion(expression);
                if (coercion == null) {
                    return rewritten;
                }
                return new Cast(
                        rewritten,
                        coercion.getTypeSignature().toString(),
                        false,
                        context.isTypeOnlyCoercion(expression));
            }
        }, sqlFunction, analysis);
    }

    /**
     * Substitutes call arguments for references to the allocated parameter variables, on either AST
     * expressions or row expressions.
     */
    private static final class SqlFunctionArgumentBinder
    {
        private SqlFunctionArgumentBinder() {}

        /**
         * Replaces each {@link SymbolReference} naming an allocated parameter variable with the matching argument.
         */
        public static Expression bindFunctionArguments(Expression function, List<String> argumentNames, List<Expression> argumentValues, Map<String, VariableReferenceExpression> argumentVariables)
        {
            return ExpressionTreeRewriter.rewriteWith(
                    new ExpressionFunctionVisitor(),
                    function,
                    buildArgumentBindings(argumentNames, argumentValues, argumentVariables));
        }

        /**
         * Replaces each {@link VariableReferenceExpression} naming an allocated parameter variable with the
         * matching argument.
         */
        public static RowExpression bindFunctionArguments(RowExpression function, List<String> argumentNames, List<RowExpression> argumentValues, Map<String, VariableReferenceExpression> argumentVariables)
        {
            return RowExpressionTreeRewriter.rewriteWith(new RowExpressionRewriter<Map<String, RowExpression>>()
            {
                @Override
                public RowExpression rewriteVariableReference(VariableReferenceExpression variable, Map<String, RowExpression> context, RowExpressionTreeRewriter<Map<String, RowExpression>> treeRewriter)
                {
                    if (context.containsKey(variable.getName())) {
                        return context.get(variable.getName());
                    }
                    return variable;
                }
            }, function, buildArgumentBindings(argumentNames, argumentValues, argumentVariables));
        }

        /**
         * Maps each allocated parameter variable name to the argument value at the same position as its
         * declared argument name.
         *
         * @throws IllegalArgumentException if the numbers of argument names and argument values differ
         */
        private static <T> Map<String, T> buildArgumentBindings(List<String> argumentNames, List<T> argumentValues, Map<String, VariableReferenceExpression> argumentVariables)
        {
            checkArgument(argumentNames.size() == argumentValues.size(), format("Expect same size for argumentNames (%d) and argumentValues (%d)", argumentNames.size(), argumentValues.size()));
            ImmutableMap.Builder<String, T> argumentBindings = ImmutableMap.builder();
            for (int i = 0; i < argumentNames.size(); i++) {
                String argumentName = argumentNames.get(i);
                argumentBindings.put(argumentVariables.get(argumentName).getName(), argumentValues.get(i));
            }
            return argumentBindings.build();
        }

        /**
         * Replaces symbol references whose name is a key of the context map with the mapped expression.
         */
        private static final class ExpressionFunctionVisitor
                extends ExpressionRewriter<Map<String, Expression>>
        {
            @Override
            public Expression rewriteSymbolReference(SymbolReference node, Map<String, Expression> context, ExpressionTreeRewriter<Map<String, Expression>> treeRewriter)
            {
                if (context.containsKey(node.getName())) {
                    return context.get(node.getName());
                }
                return node;
            }
        }
    }
}