MapToMapCast.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.scalar;
import com.facebook.presto.annotation.UsedByGeneratedCode;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.function.SqlFunctionProperties;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeSignatureParameter;
import com.facebook.presto.metadata.BoundVariables;
import com.facebook.presto.metadata.CastType;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.SqlOperator;
import com.facebook.presto.operator.aggregation.TypedSet;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.JavaScalarFunctionImplementation;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import static com.facebook.presto.common.block.MethodHandleUtil.compose;
import static com.facebook.presto.common.block.MethodHandleUtil.nativeValueGetter;
import static com.facebook.presto.common.block.MethodHandleUtil.nativeValueWriter;
import static com.facebook.presto.common.function.OperatorType.CAST;
import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.ArgumentProperty.valueTypeArgumentProperty;
import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.NullConvention.RETURN_NULL_ON_NULL;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_CAST_ARGUMENT;
import static com.facebook.presto.spi.function.Signature.typeVariable;
import static com.facebook.presto.util.Failures.internalError;
import static com.facebook.presto.util.Reflection.methodHandle;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.primitives.Primitives.unwrap;
import static java.lang.invoke.MethodHandles.permuteArguments;
import static java.lang.invoke.MethodType.methodType;
/**
 * CAST operator from {@code map(FK,FV)} to {@code map(TK,TV)}.
 * <p>
 * Keys and values are cast element-wise using the registered CAST implementations. When a key or value
 * type is unchanged, its elements are copied without invoking any cast. When keys are cast, the result
 * is checked for duplicate keys and null keys.
 */
public final class MapToMapCast
        extends SqlOperator
{
    public static final MapToMapCast MAP_TO_MAP_CAST = new MapToMapCast();

    private static final MethodHandle METHOD_HANDLE = methodHandle(
            MapToMapCast.class,
            "mapCast",
            MethodHandle.class,
            MethodHandle.class,
            Type.class,
            SqlFunctionProperties.class,
            Block.class);

    private static final MethodHandle CHECK_LONG_IS_NOT_NULL = methodHandle(MapToMapCast.class, "checkLongIsNotNull", Long.class);
    private static final MethodHandle CHECK_DOUBLE_IS_NOT_NULL = methodHandle(MapToMapCast.class, "checkDoubleIsNotNull", Double.class);
    private static final MethodHandle CHECK_BOOLEAN_IS_NOT_NULL = methodHandle(MapToMapCast.class, "checkBooleanIsNotNull", Boolean.class);
    private static final MethodHandle CHECK_SLICE_IS_NOT_NULL = methodHandle(MapToMapCast.class, "checkSliceIsNotNull", Slice.class);
    private static final MethodHandle CHECK_BLOCK_IS_NOT_NULL = methodHandle(MapToMapCast.class, "checkBlockIsNotNull", Block.class);

    // Building blocks for the null-safe value writer used when a value cast may return null.
    private static final MethodHandle IS_NULL_VALUE = methodHandle(MapToMapCast.class, "isNullValue", Object.class);
    private static final MethodHandle APPEND_NULL_VALUE = methodHandle(MapToMapCast.class, "appendNullValue", BlockBuilder.class);

    public MapToMapCast()
    {
        super(CAST,
                ImmutableList.of(typeVariable("FK"), typeVariable("FV"), typeVariable("TK"), typeVariable("TV")),
                ImmutableList.of(),
                parseTypeSignature("map(TK,TV)"),
                ImmutableList.of(parseTypeSignature("map(FK,FV)")));
    }

    /**
     * Builds the implementation of this cast for the bound key/value types.
     * <p>
     * A key or value processor is built only when its source and target types are not the same instance.
     * A {@code null} processor tells {@link #mapCast} to copy that side of each entry unchanged.
     */
    @Override
    public BuiltInScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager)
    {
        checkArgument(arity == 1, "Expected arity to be 1");
        Type fromKeyType = boundVariables.getTypeVariable("FK");
        Type fromValueType = boundVariables.getTypeVariable("FV");
        Type toKeyType = boundVariables.getTypeVariable("TK");
        Type toValueType = boundVariables.getTypeVariable("TV");
        Type toMapType = functionAndTypeManager.getParameterizedType(
                "map",
                ImmutableList.of(
                        TypeSignatureParameter.of(toKeyType.getTypeSignature()),
                        TypeSignatureParameter.of(toValueType.getTypeSignature())));

        MethodHandle keyProcessor = (fromKeyType == toKeyType) ? null : buildProcessor(functionAndTypeManager, fromKeyType, toKeyType, true);
        MethodHandle valueProcessor = (fromValueType == toValueType) ? null : buildProcessor(functionAndTypeManager, fromValueType, toValueType, false);
        // Bind the processors and target map type; the remaining arguments are (SqlFunctionProperties, Block).
        MethodHandle target = MethodHandles.insertArguments(METHOD_HANDLE, 0, keyProcessor, valueProcessor, toMapType);
        return new BuiltInScalarFunctionImplementation(true, ImmutableList.of(valueTypeArgumentProperty(RETURN_NULL_ON_NULL)), target);
    }

    /**
     * Builds a MethodHandle that reads one element, casts it and appends the result to a map-entry builder.
     * <p>
     * The returned handle has the signature
     * {@code (Block fromMap, int position, SqlFunctionProperties properties, BlockBuilder entryBuilder)void}.
     * It reads the element at {@code position} of {@code fromMap}, casts it from {@code fromType} to
     * {@code toType}, and writes the result to {@code entryBuilder}.
     * <p>
     * If the cast implementation is nullable:
     * <ul>
     * <li>for keys, a null cast result fails with {@code INVALID_CAST_ARGUMENT} ("map key is null");
     * <li>for values, a null cast result is written as a null entry.
     * </ul>
     */
    private MethodHandle buildProcessor(FunctionAndTypeManager functionAndTypeManager, Type fromType, Type toType, boolean isKey)
    {
        MethodHandle getter = nativeValueGetter(fromType);

        // Adapt cast that takes ([SqlFunctionProperties,] ?) to one that takes (?, SqlFunctionProperties), where ? is the return type of getter.
        JavaScalarFunctionImplementation castImplementation = functionAndTypeManager.getJavaScalarFunctionImplementation(functionAndTypeManager.lookupCast(CastType.CAST, fromType, toType));
        MethodHandle cast = castImplementation.getMethodHandle();
        if (cast.type().parameterArray()[0] != SqlFunctionProperties.class) {
            cast = MethodHandles.dropArguments(cast, 0, SqlFunctionProperties.class);
        }
        cast = permuteArguments(cast, methodType(cast.type().returnType(), cast.type().parameterArray()[1], cast.type().parameterArray()[0]), 1, 0);
        MethodHandle target = compose(cast, getter);

        // If the key cast function is nullable, check the result is not null: a map key may never be null.
        if (isKey && castImplementation.isNullable()) {
            target = compose(nullChecker(target.type().returnType()), target);
        }

        // Reorder the writer's arguments so the value comes first, followed by the builder.
        MethodHandle writer = nativeValueWriter(toType);
        writer = permuteArguments(writer, methodType(void.class, writer.type().parameterArray()[1], writer.type().parameterArray()[0]), 1, 0);

        Class<?> resultType = target.type().returnType();
        if (!isKey && castImplementation.isNullable() && !resultType.isPrimitive()) {
            // A nullable value cast may legitimately yield null. Previously the null was unboxed or handed to the
            // typed writer, which failed with an internal error; instead, write a null map value.
            MethodHandle writeNonNull = writer.asType(methodType(void.class, resultType, writer.type().parameterType(1)));
            MethodHandle isNull = IS_NULL_VALUE.asType(methodType(boolean.class, resultType));
            MethodHandle writeNull = MethodHandles.dropArguments(APPEND_NULL_VALUE, 0, resultType).asType(writeNonNull.type());
            return compose(MethodHandles.guardWithTest(isNull, writeNull, writeNonNull), target);
        }

        return compose(writer, target.asType(methodType(unwrap(resultType), target.type().parameterArray())));
    }

    /**
     * Returns a null checker MethodHandle that only returns the value when it is not null.
     * <p>
     * The signature of the returned MethodHandle could be one of the following depending on javaType:
     * <ul>
     * <li>(Long value)long
     * <li>(Double value)double
     * <li>(Boolean value)boolean
     * <li>(Slice value)Slice
     * <li>(Block value)Block
     * </ul>
     *
     * @throws IllegalArgumentException if {@code javaType} is not one of the types above
     */
    private MethodHandle nullChecker(Class<?> javaType)
    {
        if (javaType == Long.class) {
            return CHECK_LONG_IS_NOT_NULL;
        }
        else if (javaType == Double.class) {
            return CHECK_DOUBLE_IS_NOT_NULL;
        }
        else if (javaType == Boolean.class) {
            return CHECK_BOOLEAN_IS_NOT_NULL;
        }
        else if (javaType == Slice.class) {
            return CHECK_SLICE_IS_NOT_NULL;
        }
        else if (javaType == Block.class) {
            return CHECK_BLOCK_IS_NOT_NULL;
        }
        else {
            throw new IllegalArgumentException("Unknown java type " + javaType);
        }
    }

    /** Unboxes a cast key result, failing with INVALID_CAST_ARGUMENT when it is null. */
    @UsedByGeneratedCode
    public static long checkLongIsNotNull(Long value)
    {
        if (value == null) {
            throw new PrestoException(INVALID_CAST_ARGUMENT, "map key is null");
        }
        return value;
    }

    /** Unboxes a cast key result, failing with INVALID_CAST_ARGUMENT when it is null. */
    @UsedByGeneratedCode
    public static double checkDoubleIsNotNull(Double value)
    {
        if (value == null) {
            throw new PrestoException(INVALID_CAST_ARGUMENT, "map key is null");
        }
        return value;
    }

    /** Unboxes a cast key result, failing with INVALID_CAST_ARGUMENT when it is null. */
    @UsedByGeneratedCode
    public static boolean checkBooleanIsNotNull(Boolean value)
    {
        if (value == null) {
            throw new PrestoException(INVALID_CAST_ARGUMENT, "map key is null");
        }
        return value;
    }

    /** Returns a cast key result, failing with INVALID_CAST_ARGUMENT when it is null. */
    @UsedByGeneratedCode
    public static Slice checkSliceIsNotNull(Slice value)
    {
        if (value == null) {
            throw new PrestoException(INVALID_CAST_ARGUMENT, "map key is null");
        }
        return value;
    }

    /** Returns a cast key result, failing with INVALID_CAST_ARGUMENT when it is null. */
    @UsedByGeneratedCode
    public static Block checkBlockIsNotNull(Block value)
    {
        if (value == null) {
            throw new PrestoException(INVALID_CAST_ARGUMENT, "map key is null");
        }
        return value;
    }

    /** Guard used by the null-safe value writer: true when a cast value result is null. */
    @UsedByGeneratedCode
    public static boolean isNullValue(Object value)
    {
        return value == null;
    }

    /** Null branch of the null-safe value writer: appends a null map value. */
    @UsedByGeneratedCode
    public static void appendNullValue(BlockBuilder blockBuilder)
    {
        blockBuilder.appendNull();
    }

    /**
     * Casts a single map.
     * <p>
     * {@code fromMap} is read as alternating key/value positions: keys at even positions, values at the
     * following odd positions.
     *
     * @param keyProcessFunction processor built by {@code buildProcessor} for keys, or {@code null} to copy keys unchanged
     * @param valueProcessFunction processor built by {@code buildProcessor} for values, or {@code null} to copy values unchanged
     * @param toMapType the target map type
     * @param properties session properties passed to the cast implementations
     * @param fromMap the source map
     * @return the cast map; {@code fromMap} itself if neither keys nor values need casting
     * @throws PrestoException with {@code INVALID_CAST_ARGUMENT} if the cast keys collide or a key casts to null
     */
    @UsedByGeneratedCode
    public static Block mapCast(
            MethodHandle keyProcessFunction,
            MethodHandle valueProcessFunction,
            Type toMapType,
            SqlFunctionProperties properties,
            Block fromMap)
    {
        if (keyProcessFunction == null && valueProcessFunction == null) {
            return fromMap;
        }

        Type toKeyType = toMapType.getTypeParameters().get(0);
        Type toValueType = toMapType.getTypeParameters().get(1);
        BlockBuilder mapBlockBuilder = toMapType.createBlockBuilder(null, fromMap.getPositionCount());
        BlockBuilder blockBuilder = mapBlockBuilder.beginBlockEntry();

        if (keyProcessFunction == null) {
            // common case: keys are unchanged, so they stay unique and no duplicate check is required.
            // This branch is only reached when valueProcessFunction is non-null (see the early return above).
            for (int i = 0; i < fromMap.getPositionCount(); i += 2) {
                toKeyType.appendTo(fromMap, i, blockBuilder);
                if (fromMap.isNull(i + 1)) {
                    blockBuilder.appendNull();
                    continue;
                }
                try {
                    valueProcessFunction.invokeExact(fromMap, i + 1, properties, blockBuilder);
                }
                catch (Throwable t) {
                    throw internalError(t);
                }
            }
        }
        else {
            TypedSet typedSet = new TypedSet(toKeyType, fromMap.getPositionCount() / 2, "map-to-map cast");
            for (int i = 0; i < fromMap.getPositionCount(); i += 2) {
                try {
                    keyProcessFunction.invokeExact(fromMap, i, properties, blockBuilder);
                }
                catch (Throwable t) {
                    throw internalError(t);
                }
                // Each iteration appends exactly one key and one value, so the key just written
                // lives at position i of the entry builder.
                if (!typedSet.add(blockBuilder, i)) {
                    throw new PrestoException(INVALID_CAST_ARGUMENT, "duplicate keys");
                }

                if (fromMap.isNull(i + 1)) {
                    blockBuilder.appendNull();
                    continue;
                }

                if (valueProcessFunction != null) {
                    try {
                        valueProcessFunction.invokeExact(fromMap, i + 1, properties, blockBuilder);
                    }
                    catch (Throwable t) {
                        throw internalError(t);
                    }
                }
                else {
                    toValueType.appendTo(fromMap, i + 1, blockBuilder);
                }
            }
        }

        mapBlockBuilder.closeEntry();
        return (Block) toMapType.getObject(mapBlockBuilder, mapBlockBuilder.getPositionCount() - 1);
    }
}