RowExpressionCompiler.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.gen;
import com.facebook.presto.bytecode.Binding;
import com.facebook.presto.bytecode.BytecodeBlock;
import com.facebook.presto.bytecode.BytecodeNode;
import com.facebook.presto.bytecode.CallSiteBinder;
import com.facebook.presto.bytecode.ClassDefinition;
import com.facebook.presto.bytecode.Scope;
import com.facebook.presto.bytecode.Variable;
import com.facebook.presto.common.function.SqlFunctionProperties;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.function.FunctionMetadata;
import com.facebook.presto.spi.function.SqlFunctionId;
import com.facebook.presto.spi.function.SqlInvokedFunction;
import com.facebook.presto.spi.function.SqlInvokedScalarFunctionImplementation;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.InputReferenceExpression;
import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.RowExpressionVisitor;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.gen.LambdaBytecodeGenerator.CompiledLambda;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantTrue;
import static com.facebook.presto.bytecode.instruction.Constant.loadBoolean;
import static com.facebook.presto.bytecode.instruction.Constant.loadDouble;
import static com.facebook.presto.bytecode.instruction.Constant.loadFloat;
import static com.facebook.presto.bytecode.instruction.Constant.loadInt;
import static com.facebook.presto.bytecode.instruction.Constant.loadLong;
import static com.facebook.presto.bytecode.instruction.Constant.loadString;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IS_NULL;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.OR;
import static com.facebook.presto.sql.gen.BytecodeUtils.generateWrite;
import static com.facebook.presto.sql.gen.BytecodeUtils.loadConstant;
import static com.facebook.presto.sql.gen.LambdaBytecodeGenerator.generateLambda;
import static com.facebook.presto.sql.gen.LambdaBytecodeGenerator.generateMethodsForLambda;
import static com.facebook.presto.sql.relational.SqlFunctionUtils.getSqlFunctionRowExpression;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static java.lang.String.format;
public class RowExpressionCompiler
{
private final ClassDefinition classDefinition;
private final CallSiteBinder callSiteBinder;
private final CachedInstanceBinder cachedInstanceBinder;
private final RowExpressionVisitor<BytecodeNode, Scope> fieldReferenceCompiler;
private final Metadata metadata;
private final SqlFunctionProperties sqlFunctionProperties;
private final Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions;
private final Map<LambdaDefinitionExpression, CompiledLambda> compiledLambdaMap;
private final AtomicInteger lambdaCounter;
RowExpressionCompiler(
ClassDefinition classDefinition,
CallSiteBinder callSiteBinder,
CachedInstanceBinder cachedInstanceBinder,
RowExpressionVisitor<BytecodeNode, Scope> fieldReferenceCompiler,
Metadata metadata,
SqlFunctionProperties sqlFunctionProperties,
Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions,
Map<LambdaDefinitionExpression, CompiledLambda> compiledLambdaMap,
AtomicInteger lambdaCounter)
{
this.classDefinition = classDefinition;
this.callSiteBinder = callSiteBinder;
this.cachedInstanceBinder = cachedInstanceBinder;
this.fieldReferenceCompiler = fieldReferenceCompiler;
this.metadata = metadata;
this.sqlFunctionProperties = sqlFunctionProperties;
this.sessionFunctions = sessionFunctions;
this.compiledLambdaMap = new HashMap<>(compiledLambdaMap);
this.lambdaCounter = lambdaCounter;
}
public BytecodeNode compile(RowExpression rowExpression, Scope scope, Optional<Variable> outputBlockVariable)
{
return compile(rowExpression, scope, outputBlockVariable, Optional.empty());
}
// When outputBlockVariable is presented, the generated bytecode will write the evaluated value into the outputBlockVariable,
// otherwise the value will be left on stack.
public BytecodeNode compile(RowExpression rowExpression, Scope scope, Optional<Variable> outputBlockVariable, Optional<Class> lambdaInterface)
{
return rowExpression.accept(new Visitor(), new Context(scope, outputBlockVariable, lambdaInterface));
}
private class Visitor
implements RowExpressionVisitor<BytecodeNode, Context>
{
@Override
public BytecodeNode visitCall(CallExpression call, Context context)
{
FunctionAndTypeManager functionAndTypeManager = metadata.getFunctionAndTypeManager();
FunctionMetadata functionMetadata = functionAndTypeManager.getFunctionMetadata(call.getFunctionHandle());
BytecodeGeneratorContext generatorContext;
switch (functionMetadata.getImplementationType()) {
case JAVA:
// Pre-compile lambda bytecode and update compiled lambda map
compiledLambdaMap.putAll(generateMethodsForLambda(
classDefinition,
callSiteBinder,
cachedInstanceBinder,
call,
metadata,
sqlFunctionProperties,
sessionFunctions,
"sql" + call.hashCode(),
compiledLambdaMap.keySet(),
lambdaCounter));
generatorContext = new BytecodeGeneratorContext(
RowExpressionCompiler.this,
context.getScope(),
callSiteBinder,
cachedInstanceBinder,
functionAndTypeManager);
return (new FunctionCallCodeGenerator()).generateCall(call.getFunctionHandle(), generatorContext, call.getType(), call.getArguments(), context.getOutputBlockVariable());
case SQL:
SqlInvokedScalarFunctionImplementation functionImplementation = (SqlInvokedScalarFunctionImplementation) functionAndTypeManager.getScalarFunctionImplementation(call.getFunctionHandle());
RowExpression function = getSqlFunctionRowExpression(
functionMetadata,
functionImplementation,
metadata.getFunctionAndTypeManager(),
sqlFunctionProperties,
sessionFunctions,
call.getArguments());
// Pre-compile lambda bytecode and update compiled lambda map
compiledLambdaMap.putAll(generateMethodsForLambda(
classDefinition,
callSiteBinder,
cachedInstanceBinder,
function,
metadata,
sqlFunctionProperties,
sessionFunctions,
"sql",
compiledLambdaMap.keySet(),
lambdaCounter));
// generate bytecode for SQL function
RowExpressionCompiler newRowExpressionCompiler = new RowExpressionCompiler(
classDefinition,
callSiteBinder,
cachedInstanceBinder,
fieldReferenceCompiler,
metadata,
sqlFunctionProperties,
sessionFunctions,
compiledLambdaMap,
lambdaCounter);
// If called on null input, directly use the generated bytecode
if (functionMetadata.isCalledOnNullInput() || call.getArguments().isEmpty()) {
return newRowExpressionCompiler.compile(
function,
context.getScope(),
context.getOutputBlockVariable(),
context.getLambdaInterface());
}
// If returns null on null input, generate if(any input is null, null, generated bytecode)
generatorContext = new BytecodeGeneratorContext(
newRowExpressionCompiler,
context.getScope(),
callSiteBinder,
cachedInstanceBinder,
functionAndTypeManager);
return (new IfCodeGenerator()).generateExpression(
generatorContext,
call.getType(),
ImmutableList.of(
call.getArguments().stream()
.map(argument -> new SpecialFormExpression(IS_NULL, BOOLEAN, argument))
.reduce((a, b) -> new SpecialFormExpression(OR, BOOLEAN, a, b)).get(),
new ConstantExpression(null, call.getType()),
function),
context.getOutputBlockVariable());
default:
throw new IllegalArgumentException(format("Unsupported function implementation type: %s", functionMetadata.getImplementationType()));
}
}
@Override
public BytecodeNode visitConstant(ConstantExpression constant, Context context)
{
Object value = constant.getValue();
Class<?> javaType = constant.getType().getJavaType();
BytecodeBlock block = new BytecodeBlock();
if (value == null) {
block.comment("constant null")
.append(context.getScope().getVariable("wasNull").set(constantTrue()))
.pushJavaDefault(javaType);
}
else {
// use LDC for primitives (boolean, short, int, long, float, double)
block.comment("constant " + constant.getType().getTypeSignature());
if (javaType == boolean.class) {
block.append(loadBoolean((Boolean) value));
}
else if (javaType == byte.class || javaType == short.class || javaType == int.class) {
block.append(loadInt(((Number) value).intValue()));
}
else if (javaType == long.class) {
block.append(loadLong((Long) value));
}
else if (javaType == float.class) {
block.append(loadFloat((Float) value));
}
else if (javaType == double.class) {
block.append(loadDouble((Double) value));
}
else if (javaType == String.class) {
block.append(loadString((String) value));
}
else {
// bind constant object directly into the call-site using invoke dynamic
Binding binding = callSiteBinder.bind(value, constant.getType().getJavaType());
block = new BytecodeBlock()
.setDescription("constant " + constant.getType())
.comment(constant.toString())
.append(loadConstant(binding));
}
}
if (context.getOutputBlockVariable().isPresent()) {
block.append(generateWrite(
callSiteBinder,
context.getScope(),
context.getScope().getVariable("wasNull"),
constant.getType(),
context.getOutputBlockVariable().get()));
}
return block;
}
@Override
public BytecodeNode visitInputReference(InputReferenceExpression node, Context context)
{
BytecodeNode inputReferenceBytecode = fieldReferenceCompiler.visitInputReference(node, context.getScope());
if (!context.getOutputBlockVariable().isPresent()) {
return inputReferenceBytecode;
}
return new BytecodeBlock()
.append(inputReferenceBytecode)
.append(generateWrite(
callSiteBinder,
context.getScope(),
context.getScope().getVariable("wasNull"),
node.getType(),
context.getOutputBlockVariable().get()));
}
@Override
public BytecodeNode visitLambda(LambdaDefinitionExpression lambda, Context context)
{
checkArgument(!context.getOutputBlockVariable().isPresent(), "lambda definition expression does not support writing to block");
checkState(compiledLambdaMap.containsKey(lambda), "lambda expressions map does not contain this lambda definition");
if (!context.lambdaInterface.get().isAnnotationPresent(FunctionalInterface.class)) {
// lambdaInterface is checked to be annotated with FunctionalInterface when generating ScalarFunctionImplementation
throw new VerifyException("lambda should be generated as class annotated with FunctionalInterface");
}
BytecodeGeneratorContext generatorContext = new BytecodeGeneratorContext(
RowExpressionCompiler.this,
context.getScope(),
callSiteBinder,
cachedInstanceBinder,
metadata.getFunctionAndTypeManager());
return generateLambda(
generatorContext,
ImmutableList.of(),
compiledLambdaMap.get(lambda),
context.getLambdaInterface().get());
}
@Override
public BytecodeNode visitVariableReference(VariableReferenceExpression reference, Context context)
{
BytecodeNode variableReferenceByteCode = fieldReferenceCompiler.visitVariableReference(reference, context.getScope());
if (!context.getOutputBlockVariable().isPresent()) {
return variableReferenceByteCode;
}
return new BytecodeBlock()
.append(variableReferenceByteCode)
.append(generateWrite(
callSiteBinder,
context.getScope(),
context.getScope().getVariable("wasNull"),
reference.getType(),
context.getOutputBlockVariable().get()));
}
@Override
public BytecodeNode visitSpecialForm(SpecialFormExpression specialForm, Context context)
{
SpecialFormBytecodeGenerator generator;
switch (specialForm.getForm()) {
// lazy evaluation
case IF:
generator = new IfCodeGenerator();
break;
case NULL_IF:
generator = new NullIfCodeGenerator();
break;
case SWITCH:
// (SWITCH <expr> (WHEN <expr> <expr>) (WHEN <expr> <expr>) <expr>)
generator = new SwitchCodeGenerator();
break;
// functions that take null as input
case IS_NULL:
generator = new IsNullCodeGenerator();
break;
case COALESCE:
generator = new CoalesceCodeGenerator();
break;
// functions that require varargs and/or complex types (e.g., lists)
case IN:
generator = new InCodeGenerator(metadata.getFunctionAndTypeManager());
break;
// optimized implementations (shortcircuiting behavior)
case AND:
generator = new AndCodeGenerator();
break;
case OR:
generator = new OrCodeGenerator();
break;
case DEREFERENCE:
generator = new DereferenceCodeGenerator();
break;
case ROW_CONSTRUCTOR:
generator = new RowConstructorCodeGenerator();
break;
case BIND:
generator = new BindCodeGenerator(compiledLambdaMap, context.getLambdaInterface().get());
break;
default:
throw new IllegalStateException("Cannot compile special form: " + specialForm.getForm());
}
BytecodeGeneratorContext generatorContext = new BytecodeGeneratorContext(
RowExpressionCompiler.this,
context.getScope(),
callSiteBinder,
cachedInstanceBinder,
metadata.getFunctionAndTypeManager());
return generator.generateExpression(generatorContext, specialForm.getType(), specialForm.getArguments(), context.getOutputBlockVariable());
}
}
private static class Context
{
private final Scope scope;
private final Optional<Variable> outputBlockVariable;
private final Optional<Class> lambdaInterface;
public Context(Scope scope, Optional<Variable> outputBlockVariable, Optional<Class> lambdaInterface)
{
this.scope = scope;
this.outputBlockVariable = outputBlockVariable;
this.lambdaInterface = lambdaInterface;
}
public Scope getScope()
{
return scope;
}
public Optional<Variable> getOutputBlockVariable()
{
return outputBlockVariable;
}
public Optional<Class> getLambdaInterface()
{
return lambdaInterface;
}
}
}