PolymorphicScalarFunction.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.metadata;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.PolymorphicScalarFunctionBuilder.MethodAndNativeContainerTypes;
import com.facebook.presto.metadata.PolymorphicScalarFunctionBuilder.MethodsGroup;
import com.facebook.presto.metadata.PolymorphicScalarFunctionBuilder.SpecializeContext;
import com.facebook.presto.operator.scalar.BuiltInScalarFunctionImplementation;
import com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice;
import com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.ArgumentProperty;
import com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.NullConvention;
import com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.ReturnPlaceConvention;
import com.facebook.presto.spi.function.Signature;
import com.facebook.presto.spi.function.SqlFunctionVisibility;
import com.facebook.presto.util.Reflection;
import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Primitives;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Optional;
import static com.facebook.presto.metadata.SignatureBinder.applyBoundVariables;
import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.NullConvention.BLOCK_AND_POSITION;
import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.NullConvention.USE_NULL_FLAG;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Collections.emptyList;
import static java.util.Objects.requireNonNull;
class PolymorphicScalarFunction
extends SqlScalarFunction
{
private final String description;
private final SqlFunctionVisibility visibility;
private final boolean deterministic;
private final boolean calledOnNullInput;
private final List<PolymorphicScalarFunctionChoice> choices;
PolymorphicScalarFunction(
Signature signature,
String description,
SqlFunctionVisibility visibility,
boolean deterministic,
boolean calledOnNullInput,
List<PolymorphicScalarFunctionChoice> choices)
{
super(signature);
this.description = description;
this.visibility = visibility;
this.deterministic = deterministic;
this.calledOnNullInput = calledOnNullInput;
this.choices = requireNonNull(choices, "choices is null");
}
@Override
public SqlFunctionVisibility getVisibility()
{
return visibility;
}
@Override
public boolean isDeterministic()
{
return deterministic;
}
@Override
public boolean isCalledOnNullInput()
{
return calledOnNullInput;
}
@Override
public String getDescription()
{
return description;
}
@Override
public BuiltInScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager)
{
ImmutableList.Builder<ScalarFunctionImplementationChoice> implementationChoices = ImmutableList.builder();
for (PolymorphicScalarFunctionChoice choice : choices) {
implementationChoices.add(getScalarFunctionImplementationChoice(boundVariables, functionAndTypeManager, choice));
}
return new BuiltInScalarFunctionImplementation(implementationChoices.build());
}
private ScalarFunctionImplementationChoice getScalarFunctionImplementationChoice(
BoundVariables boundVariables,
FunctionAndTypeManager functionAndTypeManager,
PolymorphicScalarFunctionChoice choice)
{
List<Type> resolvedParameterTypes = applyBoundVariables(functionAndTypeManager, getSignature().getArgumentTypes(), boundVariables);
Type resolvedReturnType = applyBoundVariables(functionAndTypeManager, getSignature().getReturnType(), boundVariables);
SpecializeContext context = new SpecializeContext(boundVariables, resolvedParameterTypes, resolvedReturnType, functionAndTypeManager);
Optional<MethodAndNativeContainerTypes> matchingMethod = Optional.empty();
Optional<MethodsGroup> matchingMethodsGroup = Optional.empty();
for (MethodsGroup candidateMethodsGroup : choice.getMethodsGroups()) {
for (MethodAndNativeContainerTypes candidateMethod : candidateMethodsGroup.getMethods()) {
if (matchesParameterAndReturnTypes(candidateMethod, resolvedParameterTypes, resolvedReturnType, choice.getArgumentProperties(), choice.isNullableResult())) {
if (matchingMethod.isPresent()) {
throw new IllegalStateException("two matching methods (" + matchingMethod.get().getMethod().getName() + " and " + candidateMethod.getMethod().getName() + ") for parameter types " + resolvedParameterTypes);
}
matchingMethod = Optional.of(candidateMethod);
matchingMethodsGroup = Optional.of(candidateMethodsGroup);
}
}
}
checkState(matchingMethod.isPresent(), "no matching method for parameter types %s", resolvedParameterTypes);
List<Object> extraParameters = computeExtraParameters(matchingMethodsGroup.get(), context);
MethodHandle methodHandle = applyExtraParameters(matchingMethod.get().getMethod(), extraParameters, choice.getArgumentProperties());
return new ScalarFunctionImplementationChoice(choice.isNullableResult(), choice.getArgumentProperties(), choice.getReturnPlaceConvention(), methodHandle, Optional.empty());
}
private static boolean matchesParameterAndReturnTypes(
MethodAndNativeContainerTypes methodAndNativeContainerTypes,
List<Type> resolvedTypes,
Type returnType,
List<ArgumentProperty> argumentProperties,
boolean nullableResult)
{
Method method = methodAndNativeContainerTypes.getMethod();
checkState(method.getParameterCount() >= resolvedTypes.size(),
"method %s has not enough arguments: %s (should have at least %s)", method.getName(), method.getParameterCount(), resolvedTypes.size());
Class<?>[] methodParameterJavaTypes = method.getParameterTypes();
for (int i = 0, methodParameterIndex = 0; i < resolvedTypes.size(); i++) {
NullConvention nullConvention = argumentProperties.get(i).getNullConvention();
Class<?> expectedType = null;
Class<?> actualType;
switch (nullConvention) {
case RETURN_NULL_ON_NULL:
case USE_NULL_FLAG:
expectedType = methodParameterJavaTypes[methodParameterIndex];
actualType = getNullAwareContainerType(resolvedTypes.get(i).getJavaType(), false);
break;
case USE_BOXED_TYPE:
expectedType = methodParameterJavaTypes[methodParameterIndex];
actualType = getNullAwareContainerType(resolvedTypes.get(i).getJavaType(), true);
break;
case BLOCK_AND_POSITION:
Optional<Class<?>> explicitNativeContainerTypes = methodAndNativeContainerTypes.getExplicitNativeContainerTypes().get(i);
if (explicitNativeContainerTypes.isPresent()) {
expectedType = explicitNativeContainerTypes.get();
}
actualType = getNullAwareContainerType(resolvedTypes.get(i).getJavaType(), false);
break;
default:
throw new UnsupportedOperationException("unknown NullConvention");
}
if (!actualType.equals(expectedType)) {
return false;
}
methodParameterIndex += nullConvention.getParameterCount();
}
return method.getReturnType().equals(getNullAwareContainerType(returnType.getJavaType(), nullableResult));
}
private static List<Object> computeExtraParameters(MethodsGroup methodsGroup, SpecializeContext context)
{
return methodsGroup.getExtraParametersFunction().map(function -> function.apply(context)).orElse(emptyList());
}
private static int getNullFlagsCount(List<ArgumentProperty> argumentProperties)
{
return (int) argumentProperties.stream()
.filter(argumentProperty -> argumentProperty.getNullConvention() == USE_NULL_FLAG)
.count();
}
private static int getBlockPositionCount(List<ArgumentProperty> argumentProperties)
{
return (int) argumentProperties.stream()
.filter(argumentProperty -> argumentProperty.getNullConvention() == BLOCK_AND_POSITION)
.count();
}
private MethodHandle applyExtraParameters(Method matchingMethod, List<Object> extraParameters, List<ArgumentProperty> argumentProperties)
{
Signature signature = getSignature();
int expectedArgumentsCount = signature.getArgumentTypes().size() + getNullFlagsCount(argumentProperties) + getBlockPositionCount(argumentProperties) + extraParameters.size();
int matchingMethodArgumentCount = matchingMethod.getParameterCount();
checkState(matchingMethodArgumentCount == expectedArgumentsCount,
"method %s has invalid number of arguments: %s (should have %s)", matchingMethod.getName(), matchingMethodArgumentCount, expectedArgumentsCount);
MethodHandle matchingMethodHandle = Reflection.methodHandle(matchingMethod);
matchingMethodHandle = MethodHandles.insertArguments(
matchingMethodHandle,
matchingMethodArgumentCount - extraParameters.size(),
extraParameters.toArray());
return matchingMethodHandle;
}
private static Class<?> getNullAwareContainerType(Class<?> clazz, boolean nullable)
{
if (nullable) {
return Primitives.wrap(clazz);
}
return clazz;
}
static final class PolymorphicScalarFunctionChoice
{
private final boolean nullableResult;
private final List<ArgumentProperty> argumentProperties;
private final ReturnPlaceConvention returnPlaceConvention;
private final List<MethodsGroup> methodsGroups;
PolymorphicScalarFunctionChoice(
boolean nullableResult,
List<ArgumentProperty> argumentProperties,
ReturnPlaceConvention returnPlaceConvention,
List<MethodsGroup> methodsGroups)
{
this.nullableResult = nullableResult;
this.argumentProperties = ImmutableList.copyOf(requireNonNull(argumentProperties, "argumentProperties is null"));
this.returnPlaceConvention = requireNonNull(returnPlaceConvention, "returnPlaceConvention is null");
this.methodsGroups = ImmutableList.copyOf(requireNonNull(methodsGroups, "methodsWithExtraParametersFunctions is null"));
}
boolean isNullableResult()
{
return nullableResult;
}
List<MethodsGroup> getMethodsGroups()
{
return methodsGroups;
}
List<ArgumentProperty> getArgumentProperties()
{
return argumentProperties;
}
ReturnPlaceConvention getReturnPlaceConvention()
{
return returnPlaceConvention;
}
}
}