TranslateExpressionsUtil.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner;
import com.facebook.presto.Session;
import com.facebook.presto.common.type.FunctionType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.operator.aggregation.BuiltInAggregationFunctionImplementation;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.function.JavaAggregationFunctionImplementation;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.security.DenyAllAccessControl;
import com.facebook.presto.sql.analyzer.Analysis;
import com.facebook.presto.sql.analyzer.RelationId;
import com.facebook.presto.sql.analyzer.RelationType;
import com.facebook.presto.sql.analyzer.Scope;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.relational.SqlToRowExpressionTranslator;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.LambdaArgumentDeclaration;
import com.facebook.presto.sql.tree.LambdaExpression;
import com.facebook.presto.sql.tree.NodeRef;
import com.google.common.collect.ImmutableMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static com.facebook.presto.spi.WarningCollector.NOOP;
import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.analyzeExpression;
import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Collections.emptyMap;
/**
 * Static helpers that analyze SQL AST {@link Expression}s and translate them into {@link RowExpression}s.
 */
public final class TranslateExpressionsUtil
{
    private TranslateExpressionsUtil() {}

    /**
     * Analyzes {@code expression} in an anonymous, empty scope and translates it into a {@link RowExpression}.
     *
     * <p>Translation uses every type held by {@code analysis} (not only the types of {@code expression}),
     * which is required when the expression contains subqueries.
     *
     * @param expression the expression to analyze and translate
     * @param metadata metadata used for analysis and for function/type resolution
     * @param session the current session
     * @param sqlParser parser handed to the expression analyzer
     * @param variableAllocator source of the variable types visible to {@code expression}
     * @param analysis analysis state passed to the analyzer and queried for types afterwards
     * @param context translation context forwarded to {@link SqlToRowExpressionTranslator}
     * @return the translated row expression
     */
    public static RowExpression toRowExpression(Expression expression, Metadata metadata, Session session, SqlParser sqlParser, VariableAllocator variableAllocator, Analysis analysis, SqlToRowExpressionTranslator.Context context)
    {
        Scope scope = Scope.builder().withRelationType(RelationId.anonymous(), new RelationType()).build();
        // Invoked only for its effect on `analysis`; the returned analysis object itself is not needed.
        // NOTE(review): relies on analyzeExpression making the new types visible through
        // analysis.getTypes() below — confirm against ExpressionAnalyzer.
        analyzeExpression(
                session,
                metadata,
                new DenyAllAccessControl(),
                sqlParser,
                scope,
                TypeProvider.viewOf(variableAllocator.getVariables()),
                analysis,
                expression,
                WarningCollector.NOOP,
                analysis.getTypes());
        return toRowExpression(
                expression,
                metadata,
                session,
                analysis.getTypes(), // We need to pass all types when translating subqueries. TODO(pranjalssh): Add a proper test for complex queries which need this
                context);
    }

    /**
     * Translates {@code expression} into a {@link RowExpression} using already-computed expression types.
     *
     * @param expression the expression to translate
     * @param metadata metadata providing the function and type manager
     * @param session the current session
     * @param types types of {@code expression} and its sub-expressions, keyed by {@link NodeRef}
     * @param context translation context forwarded to {@link SqlToRowExpressionTranslator}
     * @return the translated row expression
     */
    public static RowExpression toRowExpression(Expression expression, Metadata metadata, Session session, Map<NodeRef<Expression>, Type> types, SqlToRowExpressionTranslator.Context context)
    {
        return SqlToRowExpressionTranslator.translate(expression, types, ImmutableMap.of(), metadata.getFunctionAndTypeManager(), session, context);
    }

    /**
     * Computes the types of the arguments passed to the function identified by {@code functionHandle},
     * including every sub-expression of those arguments.
     *
     * <p>Lambda arguments are typed from the function's declared function-typed parameters, matched
     * positionally (i-th lambda argument to i-th function-typed parameter). All other arguments are
     * analyzed against {@code typeProvider}.
     *
     * <p>NOTE(review): when lambdas are present this looks up the function's Java aggregate
     * implementation, so {@code functionHandle} presumably must denote an aggregation function — confirm.
     *
     * @param functionHandle handle of the called function
     * @param arguments the call's argument expressions
     * @param metadata metadata used to resolve the function and its types
     * @param sqlParser parser handed to the expression analyzer
     * @param session the current session
     * @param typeProvider types of variables referenced by the non-lambda arguments
     * @return an immutable map from each typed expression node to its type
     * @throws com.google.common.base.VerifyException if the lambda arguments do not line up with the
     *         function's function-typed parameters
     */
    public static Map<NodeRef<Expression>, Type> analyzeCallExpressionTypes(
            FunctionHandle functionHandle,
            List<Expression> arguments,
            Metadata metadata,
            SqlParser sqlParser,
            Session session,
            TypeProvider typeProvider)
    {
        List<LambdaExpression> lambdaExpressions = arguments.stream()
                .filter(LambdaExpression.class::isInstance)
                .map(LambdaExpression.class::cast)
                .collect(toImmutableList());
        ImmutableMap.Builder<NodeRef<Expression>, Type> builder = ImmutableMap.builder();
        if (!lambdaExpressions.isEmpty()) {
            List<FunctionType> functionTypes = metadata.getFunctionAndTypeManager().getFunctionMetadata(functionHandle).getArgumentTypes().stream()
                    .filter(typeSignature -> typeSignature.getBase().equals(FunctionType.NAME))
                    .map(typeSignature -> (FunctionType) (metadata.getFunctionAndTypeManager().getType(typeSignature)))
                    .collect(toImmutableList());
            JavaAggregationFunctionImplementation javaAggregateFunctionImplementation = metadata.getFunctionAndTypeManager().getJavaAggregateFunctionImplementation(functionHandle);
            if (javaAggregateFunctionImplementation instanceof BuiltInAggregationFunctionImplementation) {
                int lambdaInterfaceCount = ((BuiltInAggregationFunctionImplementation) javaAggregateFunctionImplementation).getLambdaInterfaces().size();
                verify(lambdaExpressions.size() == functionTypes.size(),
                        "Number of lambda arguments (%s) does not match number of function-typed parameters (%s)",
                        lambdaExpressions.size(),
                        functionTypes.size());
                verify(lambdaExpressions.size() == lambdaInterfaceCount,
                        "Number of lambda arguments (%s) does not match number of lambda interfaces (%s)",
                        lambdaExpressions.size(),
                        lambdaInterfaceCount);
            }
            // Holds for every implementation: each lambda argument needs a function-typed parameter,
            // otherwise the positional lookup below would fail with an unexplained index error.
            verify(lambdaExpressions.size() <= functionTypes.size(),
                    "Found %s lambda arguments but only %s function-typed parameters",
                    lambdaExpressions.size(),
                    functionTypes.size());
            for (int i = 0; i < lambdaExpressions.size(); i++) {
                // To compile lambda, LambdaDefinitionExpression needs to be generated from LambdaExpression,
                // which requires the types of all sub-expressions.
                //
                // In project and filter expression compilation, ExpressionAnalyzer.getExpressionTypesFromInput
                // is used to generate the types of all sub-expressions. (see visitScanFilterAndProject and visitFilter)
                //
                // This does not work here since the function call representation in final aggregation node
                // is currently a hack: it takes intermediate type as input, and may not be a valid
                // function call in Presto.
                //
                // TODO: Once the final aggregation function call representation is fixed,
                // the same mechanism in project and filter expression should be used here.
                addLambdaExpressionTypes(builder, lambdaExpressions.get(i), functionTypes.get(i), metadata, sqlParser, session);
            }
        }
        for (Expression argument : arguments) {
            if (argument instanceof LambdaExpression) {
                continue;
            }
            builder.putAll(analyze(argument, metadata, sqlParser, session, typeProvider));
        }
        return builder.build();
    }

    /**
     * Adds to {@code builder} the type of {@code lambdaExpression} itself, of each of its argument
     * declarations, and of every sub-expression of its body.
     *
     * @throws com.google.common.base.VerifyException if the lambda's arity differs from {@code functionType}'s
     */
    private static void addLambdaExpressionTypes(
            ImmutableMap.Builder<NodeRef<Expression>, Type> builder,
            LambdaExpression lambdaExpression,
            FunctionType functionType,
            Metadata metadata,
            SqlParser sqlParser,
            Session session)
    {
        List<LambdaArgumentDeclaration> lambdaArguments = lambdaExpression.getArguments();
        List<Type> argumentTypes = functionType.getArgumentTypes();
        verify(lambdaArguments.size() == argumentTypes.size(),
                "Lambda declares %s arguments but its function type has %s",
                lambdaArguments.size(),
                argumentTypes.size());
        Map<NodeRef<Expression>, Type> lambdaArgumentExpressionTypes = new HashMap<>();
        Map<String, Type> lambdaArgumentSymbolTypes = new HashMap<>();
        for (int j = 0; j < lambdaArguments.size(); j++) {
            LambdaArgumentDeclaration argument = lambdaArguments.get(j);
            Type type = argumentTypes.get(j);
            lambdaArgumentExpressionTypes.put(NodeRef.of(argument), type);
            // The body refers to arguments by name, so it is analyzed with a name -> type provider.
            lambdaArgumentSymbolTypes.put(argument.getName().getValue(), type);
        }
        // the lambda expression itself
        builder.put(NodeRef.of(lambdaExpression), functionType)
                // expressions from lambda arguments
                .putAll(lambdaArgumentExpressionTypes)
                // expressions from lambda body
                .putAll(analyze(lambdaExpression.getBody(), metadata, sqlParser, session, TypeProvider.copyOf(lambdaArgumentSymbolTypes)));
    }

    /**
     * Returns the types of {@code expression} and all of its sub-expressions, resolving variable
     * references through {@code typeProvider}. No parameters are bound and warnings are discarded.
     */
    private static Map<NodeRef<Expression>, Type> analyze(Expression expression, Metadata metadata, SqlParser sqlParser, Session session, TypeProvider typeProvider)
    {
        return getExpressionTypes(
                session,
                metadata,
                sqlParser,
                typeProvider,
                expression,
                emptyMap(),
                NOOP);
    }
}