NativeFunctionNamespaceManager.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sidecar.functionNamespace;
import com.facebook.airlift.log.Logger;
import com.facebook.presto.common.CatalogSchemaName;
import com.facebook.presto.common.QualifiedObjectName;
import com.facebook.presto.common.type.NamedTypeSignature;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeManager;
import com.facebook.presto.common.type.TypeSignature;
import com.facebook.presto.common.type.TypeSignatureParameter;
import com.facebook.presto.common.type.UserDefinedType;
import com.facebook.presto.functionNamespace.AbstractSqlInvokedFunctionNamespaceManager;
import com.facebook.presto.functionNamespace.JsonBasedUdfFunctionMetadata;
import com.facebook.presto.functionNamespace.ServingCatalog;
import com.facebook.presto.functionNamespace.SqlInvokedFunctionNamespaceManagerConfig;
import com.facebook.presto.functionNamespace.UdfFunctionSignatureMap;
import com.facebook.presto.functionNamespace.execution.SqlFunctionExecutors;
import com.facebook.presto.spi.NodeManager;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.AggregationFunctionImplementation;
import com.facebook.presto.spi.function.AggregationFunctionMetadata;
import com.facebook.presto.spi.function.AlterRoutineCharacteristics;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.function.FunctionMetadata;
import com.facebook.presto.spi.function.FunctionMetadataManager;
import com.facebook.presto.spi.function.FunctionNamespaceTransactionHandle;
import com.facebook.presto.spi.function.LongVariableConstraint;
import com.facebook.presto.spi.function.Parameter;
import com.facebook.presto.spi.function.ScalarFunctionImplementation;
import com.facebook.presto.spi.function.Signature;
import com.facebook.presto.spi.function.SqlFunction;
import com.facebook.presto.spi.function.SqlFunctionHandle;
import com.facebook.presto.spi.function.SqlFunctionId;
import com.facebook.presto.spi.function.SqlFunctionSupplier;
import com.facebook.presto.spi.function.SqlInvokedAggregationFunctionImplementation;
import com.facebook.presto.spi.function.SqlInvokedFunction;
import com.facebook.presto.spi.function.TypeVariableConstraint;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Suppliers;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.UncheckedExecutionException;
import javax.inject.Inject;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;
import static com.facebook.presto.common.type.TypeSignatureUtils.resolveIntermediateType;
import static com.facebook.presto.spi.StandardErrorCode.DUPLICATE_FUNCTION_ERROR;
import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static com.facebook.presto.spi.StandardErrorCode.GENERIC_USER_ERROR;
import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED;
import static com.facebook.presto.spi.function.FunctionVersion.notVersioned;
import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.CPP;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.MoreCollectors.onlyElement;
import static java.lang.String.format;
import static java.util.Collections.unmodifiableMap;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.HOURS;
/**
 * Function namespace manager whose functions are built from CPP function signatures returned by a
 * {@link FunctionDefinitionProvider}.
 * <p>
 * The function set is fetched lazily on first use and memoized for the lifetime of this instance.
 * CREATE / ALTER / DROP FUNCTION are rejected with {@code NOT_SUPPORTED}.
 */
public class NativeFunctionNamespaceManager
        extends AbstractSqlInvokedFunctionNamespaceManager
{
    private static final Logger log = Logger.get(NativeFunctionNamespaceManager.class);

    private final Map<QualifiedObjectName, UserDefinedType> userDefinedTypes = new ConcurrentHashMap<>();
    private final Map<SqlFunctionHandle, AggregationFunctionImplementation> aggregationImplementationByHandle = new ConcurrentHashMap<>();
    private final FunctionDefinitionProvider functionDefinitionProvider;
    private final NodeManager nodeManager;
    // Backing store filled by bootstrapNamespace(); lookups and listings go through memoizedFunctionsSupplier.
    private final Map<SqlFunctionId, SqlInvokedFunction> functions = new ConcurrentHashMap<>();
    private final Supplier<Map<SqlFunctionId, SqlInvokedFunction>> memoizedFunctionsSupplier;
    private final FunctionMetadataManager functionMetadataManager;
    // Bounded cache of signature -> specialized function, used to resolve NativeFunctionHandles.
    private final LoadingCache<Signature, SqlFunctionSupplier> specializedFunctionKeyCache;

    @Inject
    public NativeFunctionNamespaceManager(
            @ServingCatalog String catalogName,
            SqlFunctionExecutors sqlFunctionExecutors,
            SqlInvokedFunctionNamespaceManagerConfig config,
            FunctionDefinitionProvider functionDefinitionProvider,
            NodeManager nodeManager,
            FunctionMetadataManager functionMetadataManager)
    {
        super(catalogName, sqlFunctionExecutors, config);
        this.functionDefinitionProvider = requireNonNull(functionDefinitionProvider, "functionDefinitionProvider is null");
        this.nodeManager = requireNonNull(nodeManager, "nodeManager is null");
        // Deferred: the provider is only queried the first time the function map is needed.
        this.memoizedFunctionsSupplier = Suppliers.memoize(this::bootstrapNamespace);
        this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null");
        this.specializedFunctionKeyCache = CacheBuilder.newBuilder()
                .maximumSize(1000)
                .expireAfterWrite(1, HOURS)
                .build(CacheLoader.from(this::doGetSpecializedFunctionKey));
    }

    // Cache loader for specializedFunctionKeyCache.
    private SqlFunctionSupplier doGetSpecializedFunctionKey(Signature signature)
    {
        return functionMetadataManager.getSpecializedFunctionKey(signature);
    }

    /**
     * Loads every function signature from the {@link FunctionDefinitionProvider} and registers it.
     * This is the delegate of {@link #memoizedFunctionsSupplier}.
     *
     * @return an unmodifiable view of the registered functions, or an empty map when the provider
     * returned no signatures
     */
    private synchronized Map<SqlFunctionId, SqlInvokedFunction> bootstrapNamespace()
    {
        functions.clear();
        UdfFunctionSignatureMap nativeFunctionSignatureMap = functionDefinitionProvider.getUdfDefinition(nodeManager);
        if (nativeFunctionSignatureMap == null || nativeFunctionSignatureMap.isEmpty()) {
            // Suppliers.memoize caches this empty result too, so the provider is not queried again.
            return ImmutableMap.of();
        }
        populateNamespaceManager(nativeFunctionSignatureMap);
        checkArgument(!functions.isEmpty(), "functions map is empty !");
        return unmodifiableMap(functions);
    }

    /**
     * Converts each (name, metadata list) entry into {@link SqlInvokedFunction}s and registers them.
     *
     * @throws PrestoException with DUPLICATE_FUNCTION_ERROR if two entries produce the same function id
     */
    private synchronized void populateNamespaceManager(UdfFunctionSignatureMap udfFunctionSignatureMap)
    {
        Map<String, List<JsonBasedUdfFunctionMetadata>> udfSignatureMap = udfFunctionSignatureMap.getUDFSignatureMap();
        udfSignatureMap.forEach((name, metaInfoList) -> {
            // All overloads of a name are built before any of them is registered.
            List<SqlInvokedFunction> overloads = metaInfoList.stream()
                    .map(metaInfo -> createSqlInvokedFunction(name, metaInfo))
                    .collect(toImmutableList());
            overloads.forEach(this::createFunction);
        });
    }

    /**
     * Returns the aggregation implementation for {@code functionHandle}, building and caching it on the
     * first request for that handle.
     * <p>
     * {@link NativeFunctionHandle}s are resolved through the {@link FunctionMetadataManager} using their
     * concrete signature; other {@link SqlFunctionHandle}s are resolved from this namespace's function map.
     *
     * @throws IllegalArgumentException if the handle is not a {@link SqlFunctionHandle}, or the resolved
     * native function has no aggregation metadata
     * @throws PrestoException if the function cannot be resolved
     */
    @Override
    public final AggregationFunctionImplementation getAggregateFunctionImplementation(FunctionHandle functionHandle, TypeManager typeManager)
    {
        checkCatalog(functionHandle);
        checkArgument(functionHandle instanceof SqlFunctionHandle, "Unsupported FunctionHandle type '%s'", functionHandle.getClass().getSimpleName());
        SqlFunctionHandle sqlFunctionHandle = (SqlFunctionHandle) functionHandle;

        AggregationFunctionImplementation cached = aggregationImplementationByHandle.get(sqlFunctionHandle);
        if (cached != null) {
            return cached;
        }

        AggregationFunctionImplementation implementation;
        if (functionHandle instanceof NativeFunctionHandle) {
            implementation = processNativeFunctionHandle((NativeFunctionHandle) sqlFunctionHandle, typeManager);
        }
        else {
            implementation = processSqlFunctionHandle(sqlFunctionHandle, typeManager);
        }

        // Built outside the map so no map lock is held while resolving. If another thread raced us,
        // keep the first stored value so every caller observes the same instance.
        AggregationFunctionImplementation previous = aggregationImplementationByHandle.putIfAbsent(sqlFunctionHandle, implementation);
        return previous != null ? previous : implementation;
    }

    // Builds (without caching) the aggregation implementation for a signature-based native handle.
    private AggregationFunctionImplementation processNativeFunctionHandle(NativeFunctionHandle nativeFunctionHandle, TypeManager typeManager)
    {
        Signature signature = nativeFunctionHandle.getSignature();
        SqlInvokedFunction sqlFunction = getSqlInvokedFunctionFromSignature(signature);
        checkArgument(
                sqlFunction.getAggregationMetadata().isPresent(),
                "Need aggregationMetadata to get aggregation function implementation");
        AggregationFunctionMetadata aggregationMetadata = sqlFunction.getAggregationMetadata().get();

        // Resolve the declared intermediate type using the declared argument types of the registered
        // function and the concrete argument types of this signature.
        TypeSignature resolvedIntermediateType = resolveIntermediateType(
                aggregationMetadata.getIntermediateType(),
                sqlFunction.getFunctionId().getArgumentTypes(),
                signature.getArgumentTypes());
        List<Type> parameters = signature.getArgumentTypes().stream()
                .map(typeManager::getType)
                .collect(toImmutableList());

        return new SqlInvokedAggregationFunctionImplementation(
                typeManager.getType(resolvedIntermediateType),
                typeManager.getType(signature.getReturnType()),
                aggregationMetadata.isOrderSensitive(),
                parameters);
    }

    // Builds (without caching) the aggregation implementation for a function registered in this namespace.
    private AggregationFunctionImplementation processSqlFunctionHandle(SqlFunctionHandle sqlFunctionHandle, TypeManager typeManager)
    {
        SqlFunctionId functionId = sqlFunctionHandle.getFunctionId();
        // The memoized map never holds null values, so null means "absent".
        SqlInvokedFunction function = memoizedFunctionsSupplier.get().get(functionId);
        if (function == null) {
            throw new PrestoException(GENERIC_USER_ERROR, format("Function '%s' is missing from cache", functionId.getId()));
        }
        return sqlInvokedFunctionToAggregationImplementation(function, typeManager);
    }

    /**
     * Converts one JSON function description into a {@link SqlInvokedFunction} in this catalog.
     *
     * @param functionName unqualified function name; the schema comes from the metadata
     * @param jsonBasedUdfFunctionMetaData description of a single overload
     * @throws IllegalStateException if the routine language is not CPP
     * @throws IllegalArgumentException if the number of parameter names differs from the number of parameter types
     */
    protected synchronized SqlInvokedFunction createSqlInvokedFunction(String functionName, JsonBasedUdfFunctionMetadata jsonBasedUdfFunctionMetaData)
    {
        checkState(jsonBasedUdfFunctionMetaData.getRoutineCharacteristics().getLanguage().equals(CPP), "NativeFunctionNamespaceManager only supports CPP UDF");
        QualifiedObjectName qualifiedFunctionName = QualifiedObjectName.valueOf(new CatalogSchemaName(getCatalogName(), jsonBasedUdfFunctionMetaData.getSchema()), functionName);
        List<String> parameterNameList = jsonBasedUdfFunctionMetaData.getParamNames();
        List<TypeSignature> parameterTypeList = convertApplicableTypeToVariable(jsonBasedUdfFunctionMetaData.getParamTypes());
        // Names and types are zipped by index below; a mismatch would otherwise silently drop
        // parameter types or fail with an uninformative IndexOutOfBoundsException.
        checkArgument(
                parameterNameList.size() == parameterTypeList.size(),
                "Function %s declares %s parameter names but %s parameter types",
                qualifiedFunctionName,
                parameterNameList.size(),
                parameterTypeList.size());
        List<TypeVariableConstraint> typeVariableConstraintsList = jsonBasedUdfFunctionMetaData.getTypeVariableConstraints().isPresent() ?
                jsonBasedUdfFunctionMetaData.getTypeVariableConstraints().get() : Collections.emptyList();
        List<LongVariableConstraint> longVariableConstraintList = jsonBasedUdfFunctionMetaData.getLongVariableConstraints().isPresent() ?
                jsonBasedUdfFunctionMetaData.getLongVariableConstraints().get() : Collections.emptyList();
        TypeSignature outputType = convertApplicableTypeToVariable(jsonBasedUdfFunctionMetaData.getOutputType());

        ImmutableList.Builder<Parameter> parameterBuilder = ImmutableList.builder();
        for (int i = 0; i < parameterNameList.size(); i++) {
            parameterBuilder.add(new Parameter(parameterNameList.get(i), parameterTypeList.get(i)));
        }

        Optional<AggregationFunctionMetadata> aggregationFunctionMetadata =
                jsonBasedUdfFunctionMetaData.getAggregateMetadata()
                        .map(metadata -> new AggregationFunctionMetadata(
                                convertApplicableTypeToVariable(metadata.getIntermediateType()),
                                metadata.isOrderSensitive()));

        return new SqlInvokedFunction(
                qualifiedFunctionName,
                parameterBuilder.build(),
                typeVariableConstraintsList,
                longVariableConstraintList,
                outputType,
                jsonBasedUdfFunctionMetaData.getDocString(),
                jsonBasedUdfFunctionMetaData.getRoutineCharacteristics(),
                "",
                jsonBasedUdfFunctionMetaData.getVariableArity(),
                notVersioned(),
                jsonBasedUdfFunctionMetaData.getFunctionKind(),
                aggregationFunctionMetadata);
    }

    /**
     * Returns every registered overload whose signature name equals {@code functionName}
     * (a linear scan over the memoized function map).
     */
    @Override
    protected Collection<SqlInvokedFunction> fetchFunctionsDirect(QualifiedObjectName functionName)
    {
        return memoizedFunctionsSupplier.get().values().stream()
                .filter(function -> function.getSignature().getName().equals(functionName))
                .collect(toImmutableList());
    }

    // Returns the registered type, or null when no type with that name was added.
    @Override
    protected UserDefinedType fetchUserDefinedTypeDirect(QualifiedObjectName typeName)
    {
        return userDefinedTypes.get(typeName);
    }

    /**
     * Returns metadata for the handle. {@link NativeFunctionHandle}s are resolved from their signature;
     * other handles must match exactly one registered function.
     */
    @Override
    protected FunctionMetadata fetchFunctionMetadataDirect(SqlFunctionHandle functionHandle)
    {
        if (functionHandle instanceof NativeFunctionHandle) {
            return getMetadataFromNativeFunctionHandle(functionHandle);
        }
        return fetchFunctionsDirect(functionHandle.getFunctionId().getFunctionName()).stream()
                .filter(function -> function.getRequiredFunctionHandle().equals(functionHandle))
                .map(this::sqlInvokedFunctionToMetadata)
                .collect(onlyElement());
    }

    // Requires exactly one registered function whose handle equals functionHandle.
    @Override
    protected ScalarFunctionImplementation fetchFunctionImplementationDirect(SqlFunctionHandle functionHandle)
    {
        return fetchFunctionsDirect(functionHandle.getFunctionId().getFunctionName()).stream()
                .filter(function -> function.getRequiredFunctionHandle().equals(functionHandle))
                .map(this::sqlInvokedFunctionToImplementation)
                .collect(onlyElement());
    }

    @Override
    public synchronized void createFunction(SqlInvokedFunction function, boolean replace)
    {
        throw new PrestoException(NOT_SUPPORTED, "Create Function is not supported in NativeFunctionNamespaceManager");
    }

    @Override
    public void alterFunction(QualifiedObjectName functionName, Optional<List<TypeSignature>> parameterTypes, AlterRoutineCharacteristics alterRoutineCharacteristics)
    {
        throw new PrestoException(NOT_SUPPORTED, "Alter Function is not supported in NativeFunctionNamespaceManager");
    }

    @Override
    public void dropFunction(QualifiedObjectName functionName, Optional<List<TypeSignature>> parameterTypes, boolean exists)
    {
        throw new PrestoException(NOT_SUPPORTED, "Drop Function is not supported in NativeFunctionNamespaceManager");
    }

    // Returns all registered functions; likePattern and escape are not applied here.
    @Override
    public Collection<SqlInvokedFunction> listFunctions(Optional<String> likePattern, Optional<String> escape)
    {
        return memoizedFunctionsSupplier.get().values();
    }

    /**
     * Registers a user-defined type.
     *
     * @throws IllegalArgumentException if a type with the same name is already registered
     */
    @Override
    public void addUserDefinedType(UserDefinedType userDefinedType)
    {
        QualifiedObjectName name = userDefinedType.getUserDefinedTypeName();
        // putIfAbsent makes the duplicate check and the insert one atomic step; a separate
        // containsKey + put would let two concurrent registrations of the same name both succeed.
        UserDefinedType existing = userDefinedTypes.putIfAbsent(name, userDefinedType);
        checkArgument(
                existing == null,
                "Parametric type %s already registered",
                name);
    }

    @Override
    public final FunctionHandle getFunctionHandle(Optional<? extends FunctionNamespaceTransactionHandle> transactionHandle, Signature signature)
    {
        FunctionHandle functionHandle = super.getFunctionHandle(transactionHandle, signature);
        // Only generic variadic signatures are handled here (the superclass returns null for them);
        // for normal signatures we use the AbstractSqlInvokedFunctionNamespaceManager function handle.
        if (functionHandle == null) {
            return new NativeFunctionHandle(signature);
        }
        return functionHandle;
    }

    // Todo: Improve the handling of parameter type differentiation in native execution.
    // HACK: Currently, we lack support for correctly identifying the parameterKind, specifically between TYPE and VARIABLE,
    // in native execution. The following utility functions help bridge this gap by parsing the type signature and verifying whether its base
    // and parameters are of a supported type. The valid types list are non - parametric types that Presto supports.

    /**
     * Single-signature convenience overload of {@link #convertApplicableTypeToVariable(List)}.
     */
    public static TypeSignature convertApplicableTypeToVariable(TypeSignature typeSignature)
    {
        List<TypeSignature> typeSignaturesList = convertApplicableTypeToVariable(ImmutableList.of(typeSignature));
        checkArgument(!typeSignaturesList.isEmpty(), "Type signature list is empty for : %s", typeSignature);
        return typeSignaturesList.get(0);
    }

    /**
     * Rewrites each parameterized signature so that the parameters of DECIMAL (at any nesting level)
     * become variable parameters; non-parameterized signatures are returned unchanged.
     *
     * @return a new mutable list, same size and order as the input
     */
    public static List<TypeSignature> convertApplicableTypeToVariable(List<TypeSignature> typeSignatures)
    {
        List<TypeSignature> newTypeSignaturesList = new ArrayList<>(typeSignatures.size());
        for (TypeSignature typeSignature : typeSignatures) {
            if (!typeSignature.getParameters().isEmpty()) {
                TypeSignature newTypeSignature =
                        new TypeSignature(
                                typeSignature.getBase(),
                                getTypeSignatureParameters(
                                        typeSignature,
                                        typeSignature.getParameters()));
                newTypeSignaturesList.add(newTypeSignature);
            }
            else {
                newTypeSignaturesList.add(typeSignature);
            }
        }
        return newTypeSignaturesList;
    }

    @VisibleForTesting
    public FunctionDefinitionProvider getFunctionDefinitionProvider()
    {
        return functionDefinitionProvider;
    }

    // Recursively rewrites the parameters of typeSignature; leaf parameters of a DECIMAL base are
    // turned into variable parameters, long literals are kept as-is.
    private static List<TypeSignatureParameter> getTypeSignatureParameters(
            TypeSignature typeSignature,
            List<TypeSignatureParameter> typeSignatureParameterList)
    {
        List<TypeSignatureParameter> newParameterTypeList = new ArrayList<>(typeSignatureParameterList.size());
        for (TypeSignatureParameter parameter : typeSignatureParameterList) {
            if (parameter.isLongLiteral()) {
                newParameterTypeList.add(parameter);
                continue;
            }

            boolean isNamedTypeSignature = parameter.isNamedTypeSignature();
            TypeSignature parameterTypeSignature;
            // If it's a named type signatures only in the case of row signature types.
            if (isNamedTypeSignature) {
                parameterTypeSignature = parameter.getNamedTypeSignature().getTypeSignature();
            }
            else {
                parameterTypeSignature = parameter.getTypeSignature();
            }

            if (parameterTypeSignature.getParameters().isEmpty()) {
                boolean changeTypeToVariable = isDecimalTypeBase(typeSignature.getBase());
                if (changeTypeToVariable) {
                    newParameterTypeList.add(
                            TypeSignatureParameter.of(parameterTypeSignature.getBase()));
                }
                else {
                    if (isNamedTypeSignature) {
                        newParameterTypeList.add(TypeSignatureParameter.of(parameter.getNamedTypeSignature()));
                    }
                    else {
                        newParameterTypeList.add(TypeSignatureParameter.of(parameterTypeSignature));
                    }
                }
            }
            else {
                TypeSignature newTypeSignature =
                        new TypeSignature(
                                parameterTypeSignature.getBase(),
                                getTypeSignatureParameters(
                                        parameterTypeSignature.getStandardTypeSignature(),
                                        parameterTypeSignature.getParameters()));
                if (isNamedTypeSignature) {
                    // NOTE(review): unlike the leaf branch above, this drops the row field name
                    // (Optional.empty()); confirm whether that is intentional for signature matching.
                    newParameterTypeList.add(
                            TypeSignatureParameter.of(
                                    new NamedTypeSignature(
                                            Optional.empty(),
                                            newTypeSignature)));
                }
                else {
                    newParameterTypeList.add(TypeSignatureParameter.of(newTypeSignature));
                }
            }
        }
        return newParameterTypeList;
    }

    private static boolean isDecimalTypeBase(String typeBase)
    {
        return typeBase.equals(StandardTypes.DECIMAL);
    }

    // Hack ends here

    /**
     * Registers a function built from provider metadata.
     *
     * @throws PrestoException with DUPLICATE_FUNCTION_ERROR if a function with the same id is already registered
     */
    private synchronized void createFunction(SqlInvokedFunction function)
    {
        checkFunctionLanguageSupported(function);
        SqlFunctionId functionId = function.getFunctionId();
        if (functions.containsKey(functionId)) {
            throw new PrestoException(DUPLICATE_FUNCTION_ERROR, format("Function '%s' already exists", functionId.getId()));
        }
        functions.put(functionId, function.withVersion("1"));
    }

    // Resolves a signature via the specialized-function cache, unwrapping loader failures into PrestoExceptions.
    private SqlFunction getSqlFunctionFromSignature(Signature signature)
    {
        try {
            return specializedFunctionKeyCache.getUnchecked(signature).getFunction();
        }
        catch (UncheckedExecutionException e) {
            throw convertToPrestoException(e, format("Error getting FunctionMetadata for signature: %s", signature));
        }
    }

    // Like getSqlFunctionFromSignature, but requires the result to be a SqlInvokedFunction and fails
    // with a descriptive error instead of a ClassCastException.
    private SqlInvokedFunction getSqlInvokedFunctionFromSignature(Signature signature)
    {
        SqlFunction function = getSqlFunctionFromSignature(signature);
        if (!(function instanceof SqlInvokedFunction)) {
            throw new PrestoException(
                    GENERIC_INTERNAL_ERROR,
                    format("Expected SqlInvokedFunction for signature %s, but got %s",
                            signature,
                            function == null ? "null" : function.getClass().getName()));
        }
        return (SqlInvokedFunction) function;
    }

    // Builds FunctionMetadata using the handle's concrete signature (name, argument and return types)
    // and the resolved function's parameter names, kind, language and characteristics.
    private FunctionMetadata getMetadataFromNativeFunctionHandle(SqlFunctionHandle functionHandle)
    {
        Signature signature = ((NativeFunctionHandle) functionHandle).getSignature();
        SqlInvokedFunction sqlFunction = getSqlInvokedFunctionFromSignature(signature);
        return new FunctionMetadata(
                signature.getName(),
                signature.getArgumentTypes(),
                sqlFunction.getParameters().stream()
                        .map(Parameter::getName)
                        .collect(toImmutableList()),
                signature.getReturnType(),
                sqlFunction.getSignature().getKind(),
                sqlFunction.getRoutineCharacteristics().getLanguage(),
                getFunctionImplementationType(sqlFunction),
                sqlFunction.isDeterministic(),
                sqlFunction.isCalledOnNullInput(),
                sqlFunction.getVersion(),
                sqlFunction.getComplexTypeFunctionDescriptor());
    }

    // Rethrows a PrestoException cause as-is; wraps any other cause as GENERIC_INTERNAL_ERROR.
    private static PrestoException convertToPrestoException(UncheckedExecutionException exception, String failureMessage)
    {
        Throwable cause = exception.getCause();
        if (cause instanceof PrestoException) {
            return (PrestoException) cause;
        }
        return new PrestoException(GENERIC_INTERNAL_ERROR, failureMessage, cause);
    }
}