// MapTransformValueFunction.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.scalar;
import com.facebook.presto.bytecode.BytecodeBlock;
import com.facebook.presto.bytecode.BytecodeNode;
import com.facebook.presto.bytecode.CallSiteBinder;
import com.facebook.presto.bytecode.ClassDefinition;
import com.facebook.presto.bytecode.FieldDefinition;
import com.facebook.presto.bytecode.MethodDefinition;
import com.facebook.presto.bytecode.Parameter;
import com.facebook.presto.bytecode.Scope;
import com.facebook.presto.bytecode.Variable;
import com.facebook.presto.bytecode.control.ForLoop;
import com.facebook.presto.bytecode.control.IfStatement;
import com.facebook.presto.bytecode.control.TryCatch;
import com.facebook.presto.common.QualifiedObjectName;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.block.BlockBuilderStatus;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeSignatureParameter;
import com.facebook.presto.metadata.BoundVariables;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.SqlScalarFunction;
import com.facebook.presto.spi.ErrorCodeSupplier;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor;
import com.facebook.presto.spi.function.FunctionKind;
import com.facebook.presto.spi.function.LambdaArgumentDescriptor;
import com.facebook.presto.spi.function.LambdaDescriptor;
import com.facebook.presto.spi.function.Signature;
import com.facebook.presto.spi.function.SqlFunctionVisibility;
import com.facebook.presto.sql.gen.SqlTypeBytecodeExpression;
import com.facebook.presto.sql.gen.lambda.BinaryFunctionInterface;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.primitives.Primitives;
import java.lang.invoke.MethodHandle;
import java.util.Optional;
import static com.facebook.presto.bytecode.Access.FINAL;
import static com.facebook.presto.bytecode.Access.PRIVATE;
import static com.facebook.presto.bytecode.Access.PUBLIC;
import static com.facebook.presto.bytecode.Access.STATIC;
import static com.facebook.presto.bytecode.Access.a;
import static com.facebook.presto.bytecode.Parameter.arg;
import static com.facebook.presto.bytecode.ParameterizedType.type;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.add;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantInt;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantNull;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantString;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.equal;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.getStatic;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.invokeStatic;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.lessThan;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newInstance;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.setStatic;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.subtract;
import static com.facebook.presto.bytecode.instruction.VariableInstruction.incrementVariable;
import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.common.type.UnknownType.UNKNOWN;
import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE;
import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.ArgumentProperty.functionTypeArgumentProperty;
import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.ArgumentProperty.valueTypeArgumentProperty;
import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.NullConvention.RETURN_NULL_ON_NULL;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static com.facebook.presto.spi.function.Signature.typeVariable;
import static com.facebook.presto.sql.gen.SqlTypeBytecodeExpression.constantType;
import static com.facebook.presto.util.CompilerUtils.defineClass;
import static com.facebook.presto.util.CompilerUtils.makeClassName;
import static com.facebook.presto.util.Reflection.methodHandle;
/**
 * Scalar function {@code transform_values(map(K,V1), function(K,V1,V2))} returning {@code map(K,V2)}.
 * <p>
 * The executable implementation is produced per specialization: {@link #generateTransform}
 * emits a class with a static {@code transform(Block, BinaryFunctionInterface)} method tailored
 * to the bound key, value and transformed-value types, and exposes it as a {@link MethodHandle}.
 */
public final class MapTransformValueFunction
        extends SqlScalarFunction
{
    public static final MapTransformValueFunction MAP_TRANSFORM_VALUE_FUNCTION = new MapTransformValueFunction();

    private final ComplexTypeFunctionDescriptor descriptor;

    private MapTransformValueFunction()
    {
        super(transformValuesSignature());

        // Lambda argument #1 is tied to call argument #0 through prependAllSubscripts.
        // NOTE(review): presumably this is the lambda's value parameter mapped onto the input map - confirm
        // against ComplexTypeFunctionDescriptor.
        LambdaArgumentDescriptor lambdaArgument = new LambdaArgumentDescriptor(0, ComplexTypeFunctionDescriptor::prependAllSubscripts);
        LambdaDescriptor lambda = new LambdaDescriptor(1, ImmutableMap.of(1, lambdaArgument));
        this.descriptor = new ComplexTypeFunctionDescriptor(
                true,
                ImmutableList.of(lambda),
                Optional.of(ImmutableSet.of(0)),
                Optional.of(ComplexTypeFunctionDescriptor::clearRequiredSubfields),
                getSignature());
    }

    /**
     * Builds the generic signature {@code transform_values(map(K,V1), function(K,V1,V2)) -> map(K,V2)}
     * in the built-in Java namespace.
     */
    private static Signature transformValuesSignature()
    {
        return new Signature(
                QualifiedObjectName.valueOf(JAVA_BUILTIN_NAMESPACE, "transform_values"),
                FunctionKind.SCALAR,
                ImmutableList.of(typeVariable("K"), typeVariable("V1"), typeVariable("V2")),
                ImmutableList.of(),
                parseTypeSignature("map(K,V2)"),
                ImmutableList.of(parseTypeSignature("map(K,V1)"), parseTypeSignature("function(K,V1,V2)")),
                false);
    }

    /**
     * @return {@link SqlFunctionVisibility#PUBLIC}
     */
    @Override
    public final SqlFunctionVisibility getVisibility()
    {
        return SqlFunctionVisibility.PUBLIC;
    }

    /**
     * @return always {@code false}
     */
    // NOTE(review): presumably non-deterministic because the supplied lambda may be - confirm.
    @Override
    public boolean isDeterministic()
    {
        return false;
    }

    /**
     * @return the human-readable description of the function
     */
    @Override
    public String getDescription()
    {
        return "apply lambda to each entry of the map and transform the value";
    }

    /**
     * @return the descriptor built once in the constructor
     */
    @Override
    public ComplexTypeFunctionDescriptor getComplexTypeFunctionDescriptor()
    {
        return descriptor;
    }

    /**
     * Binds {@code K}, {@code V1} and {@code V2}, derives the result type {@code map(K,V2)} and
     * returns an implementation backed by a freshly generated {@code transform} method.
     * The map argument uses {@code RETURN_NULL_ON_NULL}; the second argument is a
     * {@link BinaryFunctionInterface} lambda.
     *
     * @param boundVariables source of the bound type variables
     * @param arity not used by this implementation
     * @param functionAndTypeManager used to resolve the parameterized result map type
     * @return the specialized scalar function implementation
     */
    @Override
    public BuiltInScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager)
    {
        Type boundKey = boundVariables.getTypeVariable("K");
        Type boundValue = boundVariables.getTypeVariable("V1");
        Type boundTransformedValue = boundVariables.getTypeVariable("V2");

        Type boundResultMap = functionAndTypeManager.getParameterizedType(
                StandardTypes.MAP,
                ImmutableList.of(
                        TypeSignatureParameter.of(boundKey.getTypeSignature()),
                        TypeSignatureParameter.of(boundTransformedValue.getTypeSignature())));

        MethodHandle transform = generateTransform(boundKey, boundValue, boundTransformedValue, boundResultMap);
        return new BuiltInScalarFunctionImplementation(
                false,
                ImmutableList.of(
                        valueTypeArgumentProperty(RETURN_NULL_ON_NULL),
                        functionTypeArgumentProperty(BinaryFunctionInterface.class)),
                transform);
    }

    /**
     * Generates a class exposing {@code public static Block transform(Block block, BinaryFunctionInterface function)}
     * and returns a handle to that method.
     * <p>
     * The generated method walks {@code block} two positions at a time, reading the key at
     * {@code position} and the value at {@code position + 1}. For every pair it applies the lambda,
     * copies the key to the output and writes the lambda result (or a null) as the new value.
     * If the lambda throws, the open map entry is closed before the exception is propagated.
     *
     * @param keyType bound type of {@code K}
     * @param valueType bound type of {@code V1}
     * @param transformedValueType bound type of {@code V2}
     * @param resultMapType the {@code map(K,V2)} type used to create builders and extract the result
     * @return method handle for the generated {@code transform} method
     */
    private static MethodHandle generateTransform(Type keyType, Type valueType, Type transformedValueType, Type resultMapType)
    {
        CallSiteBinder callSiteBinder = new CallSiteBinder();
        Class<?> boxedKeyClass = Primitives.wrap(keyType.getJavaType());
        Class<?> boxedValueClass = Primitives.wrap(valueType.getJavaType());
        Class<?> boxedResultClass = Primitives.wrap(transformedValueType.getJavaType());

        ClassDefinition classDefinition = new ClassDefinition(
                a(PUBLIC, FINAL),
                makeClassName("MapTransformValue"),
                type(Object.class));
        classDefinition.declareDefaultConstructor(a(PRIVATE));

        // Static builder of the result map type, created in the class initializer; every call
        // derives its own builder from it via newBlockBuilderLike.
        FieldDefinition templateBuilder = classDefinition.declareField(a(PRIVATE, STATIC, FINAL), "BLOCK_BUILDER", BlockBuilder.class);
        classDefinition.getClassInitializer().getBody().append(setStatic(
                templateBuilder,
                constantType(callSiteBinder, resultMapType).invoke("createBlockBuilder", BlockBuilder.class, constantNull(BlockBuilderStatus.class), constantInt(10))));

        // Declare: public static Block transform(Block block, BinaryFunctionInterface function)
        Parameter mapBlock = arg("block", Block.class);
        Parameter lambda = arg("function", BinaryFunctionInterface.class);
        MethodDefinition transform = classDefinition.declareMethod(
                a(PUBLIC, STATIC),
                "transform",
                type(Block.class),
                ImmutableList.of(mapBlock, lambda));

        BytecodeBlock code = transform.getBody();
        Scope locals = transform.getScope();
        Variable entryCount = locals.declareVariable(int.class, "positionCount");
        Variable index = locals.declareVariable(int.class, "position");
        Variable resultBuilder = locals.declareVariable(BlockBuilder.class, "mapBlockBuilder");
        Variable entryBuilder = locals.declareVariable(BlockBuilder.class, "blockBuilder");
        Variable keySlot = locals.declareVariable(boxedKeyClass, "keyElement");
        Variable valueSlot = locals.declareVariable(boxedValueClass, "valueElement");
        Variable resultSlot = locals.declareVariable(boxedResultClass, "transformedValueElement");

        // Prologue: read the position count, derive a per-call map builder and open its single entry.
        code.append(entryCount.set(mapBlock.invoke("getPositionCount", int.class)))
                .append(resultBuilder.set(getStatic(templateBuilder).invoke("newBlockBuilderLike", BlockBuilder.class, constantNull(BlockBuilderStatus.class), entryCount)))
                .append(entryBuilder.set(resultBuilder.invoke("beginBlockEntry", BlockBuilder.class)));

        // Step 1: load the key at the current position.
        SqlTypeBytecodeExpression keySqlType = constantType(callSiteBinder, keyType);
        BytecodeNode readKey;
        if (keyType.equals(UNKNOWN)) {
            // A key of unknown type can only be null. Assign a null so the later invocation never
            // sees an uninitialized variable at compile time, and fail if this is reached at runtime.
            // The entry is closed first: we may be inside a TRY() call and subsequent calls must
            // not find the builder in an inconsistent state.
            readKey = new BytecodeBlock()
                    .append(resultBuilder.invoke("closeEntry", BlockBuilder.class).pop())
                    .append(keySlot.set(constantNull(boxedKeyClass)))
                    .append(nullKeyFailure());
        }
        else {
            readKey = new BytecodeBlock().append(keySlot.set(keySqlType.getValue(mapBlock, index).cast(boxedKeyClass)));
        }

        // Step 2: load the value stored right after the key, keeping nulls as nulls.
        SqlTypeBytecodeExpression valueSqlType = constantType(callSiteBinder, valueType);
        BytecodeNode readValue;
        if (valueType.equals(UNKNOWN)) {
            readValue = new BytecodeBlock().append(valueSlot.set(constantNull(boxedValueClass)));
        }
        else {
            readValue = new IfStatement()
                    .condition(mapBlock.invoke("isNull", boolean.class, add(index, constantInt(1))))
                    .ifTrue(valueSlot.set(constantNull(boxedValueClass)))
                    .ifFalse(valueSlot.set(valueSqlType.getValue(mapBlock, add(index, constantInt(1))).cast(boxedValueClass)));
        }

        // Step 4 (prepared ahead of the loop): write the lambda result, or a null, as the new value.
        BytecodeNode writeResult;
        if (transformedValueType.equals(UNKNOWN)) {
            writeResult = new BytecodeBlock().append(entryBuilder.invoke("appendNull", BlockBuilder.class).pop());
        }
        else {
            writeResult = new IfStatement()
                    .condition(equal(resultSlot, constantNull(boxedResultClass)))
                    .ifTrue(entryBuilder.invoke("appendNull", BlockBuilder.class).pop())
                    .ifFalse(constantType(callSiteBinder, transformedValueType).writeValue(entryBuilder, resultSlot.cast(transformedValueType.getJavaType())));
        }

        // Step 3: apply the lambda; on any Throwable close the entry, then rethrow (unchecked
        // throwables as-is, everything else wrapped in a RuntimeException).
        Variable failure = locals.declareVariable(Throwable.class, "transformationException");
        BytecodeNode applyLambda = resultSlot.set(
                lambda.invoke("apply", Object.class, keySlot.cast(Object.class), valueSlot.cast(Object.class))
                        .cast(boxedResultClass));
        BytecodeNode closeAndRethrow = new BytecodeBlock()
                .append(resultBuilder.invoke("closeEntry", BlockBuilder.class).pop())
                .putVariable(failure)
                .append(invokeStatic(Throwables.class, "throwIfUnchecked", void.class, failure))
                .append(newInstance(RuntimeException.class, failure))
                .throwObject();

        BytecodeBlock perEntry = new BytecodeBlock()
                .append(readKey)
                .append(readValue)
                .append(new TryCatch(
                        "Close builder before throwing to avoid subsequent calls finding it in an inconsistent state if we are in in a TRY() call.",
                        applyLambda,
                        closeAndRethrow,
                        type(Throwable.class)))
                .append(keySqlType.invoke("appendTo", void.class, mapBlock, index, entryBuilder))
                .append(writeResult);

        // Keys and values alternate in the block, hence the stride of two.
        code.append(new ForLoop()
                .initialize(index.set(constantInt(0)))
                .condition(lessThan(index, entryCount))
                .update(incrementVariable(index, (byte) 2))
                .body(perEntry));

        // Epilogue: close the entry and return the object at the builder's last position.
        code.append(resultBuilder.invoke("closeEntry", BlockBuilder.class).pop());
        code.append(constantType(callSiteBinder, resultMapType)
                .invoke(
                        "getObject",
                        Object.class,
                        resultBuilder.cast(Block.class),
                        subtract(resultBuilder.invoke("getPositionCount", int.class), constantInt(1)))
                .ret());

        Class<?> generated = defineClass(classDefinition, Object.class, callSiteBinder.getBindings(), MapTransformValueFunction.class.getClassLoader());
        return methodHandle(generated, "transform", Block.class, BinaryFunctionInterface.class);
    }

    /**
     * Bytecode that instantiates and throws a {@link PrestoException} with
     * {@code INVALID_FUNCTION_ARGUMENT} and the message "map key cannot be null".
     */
    private static BytecodeNode nullKeyFailure()
    {
        return new BytecodeBlock()
                .append(newInstance(
                        PrestoException.class,
                        getStatic(INVALID_FUNCTION_ARGUMENT.getDeclaringClass(), "INVALID_FUNCTION_ARGUMENT").cast(ErrorCodeSupplier.class),
                        constantString("map key cannot be null")))
                .throwObject();
    }
}