FunctionResource.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.server;
import com.facebook.airlift.json.JsonCodec;
import com.facebook.presto.PrestoMediaTypes;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.PageBuilder;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.block.BlockEncodingManager;
import com.facebook.presto.common.type.Decimals;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.functionNamespace.JsonBasedUdfFunctionMetadata;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.operator.scalar.BuiltInScalarFunctionImplementation;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.function.FunctionKind;
import com.facebook.presto.spi.function.RoutineCharacteristics;
import com.facebook.presto.spi.function.SqlFunction;
import com.facebook.presto.spi.function.SqlFunctionId;
import com.facebook.presto.spi.function.SqlFunctionVisibility;
import com.facebook.presto.spi.page.PagesSerde;
import com.facebook.presto.spi.page.SerializedPage;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import io.airlift.slice.BasicSliceInput;
import io.airlift.slice.DynamicSliceOutput;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import javax.inject.Inject;
import javax.ws.rs.Consumes;
import javax.ws.rs.GET;
import javax.ws.rs.HEAD;
import javax.ws.rs.POST;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.io.UnsupportedEncodingException;
import java.lang.invoke.MethodHandle;
import java.math.BigDecimal;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_ARGUMENTS;
import static com.facebook.presto.spi.function.RoutineCharacteristics.Determinism.DETERMINISTIC;
import static com.facebook.presto.spi.function.RoutineCharacteristics.Determinism.NOT_DETERMINISTIC;
import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.JAVA;
import static com.facebook.presto.spi.function.RoutineCharacteristics.NullCallClause.CALLED_ON_NULL_INPUT;
import static com.facebook.presto.spi.function.RoutineCharacteristics.NullCallClause.RETURNS_NULL_ON_NULL_INPUT;
import static com.facebook.presto.spi.page.PagesSerdeUtil.readSerializedPage;
import static com.facebook.presto.spi.page.PagesSerdeUtil.writeSerializedPage;
import static io.airlift.slice.Slices.wrappedBuffer;
import static java.lang.Double.longBitsToDouble;
@Path("/v1/functions")
public class FunctionResource
{
private final FunctionAndTypeManager manager;
private final JsonCodec<Map<String, List<JsonBasedUdfFunctionMetadata>>> jsonCodec;
private final PagesSerde pagesSerde;
private String etag = "\"etag\"";
@Inject
public FunctionResource(FunctionAndTypeManager manager, JsonCodec<Map<String, List<JsonBasedUdfFunctionMetadata>>> jsonCodec)
{
this.manager = manager;
this.jsonCodec = jsonCodec;
this.pagesSerde = createPagesSerde();
}
@GET
@Produces(MediaType.APPLICATION_JSON)
public String getFunctions()
{
Map<String, List<JsonBasedUdfFunctionMetadata>> udfMap = new HashMap<>();
Collection<SqlFunction> builtInFunctions = manager.listBuiltInFunctions();
for (SqlFunction function : builtInFunctions) {
if (function.getSignature().getKind() != FunctionKind.SCALAR) {
continue;
}
if (function.getVisibility() == SqlFunctionVisibility.HIDDEN) {
continue;
}
JsonBasedUdfFunctionMetadata metadata = sqlFunctionToMetadata(function);
String functionName = function.getSignature().getName().getObjectName();
List<JsonBasedUdfFunctionMetadata> functionList = new ArrayList<>();
if (udfMap.containsKey(functionName)) {
functionList = udfMap.get(functionName);
}
functionList.add(metadata);
udfMap.put(functionName, functionList);
}
return jsonCodec.toJson(udfMap);
}
private static JsonBasedUdfFunctionMetadata sqlFunctionToMetadata(SqlFunction function)
{
return new JsonBasedUdfFunctionMetadata(
function.getDescription() != null ? function.getDescription() : "",
function.getSignature().getKind(),
function.getSignature().getReturnType(),
function.getSignature().getArgumentTypes(),
function.getSignature().getName().getSchemaName(),
function.getSignature().isVariableArity(),
new RoutineCharacteristics(
JAVA,
function.isDeterministic() ? DETERMINISTIC : NOT_DETERMINISTIC,
function.isCalledOnNullInput() ? CALLED_ON_NULL_INPUT : RETURNS_NULL_ON_NULL_INPUT),
Optional.empty(),
Optional.of(
new SqlFunctionId(
function.getSignature().getName(),
function.getSignature().getArgumentTypes())),
Optional.of("1"),
Optional.of(function.getSignature().getTypeVariableConstraints()),
Optional.ofNullable(function.getSignature().getLongVariableConstraints()),
Optional.empty());
}
@GET
@Path("/{schema}")
@Produces(MediaType.APPLICATION_JSON)
public String getFunctionsBySchema(@PathParam("schema") String schema)
{
Map<String, List<JsonBasedUdfFunctionMetadata>> udfMap = new HashMap<>();
Collection<SqlFunction> builtInFunctions = manager.listBuiltInFunctions();
for (SqlFunction function : builtInFunctions) {
if (!function.getSignature().getName().getSchemaName().equals(schema)) {
continue;
}
if (function.getSignature().getKind() != FunctionKind.SCALAR) {
continue;
}
if (function.getVisibility() == SqlFunctionVisibility.HIDDEN) {
continue;
}
String functionName = function.getSignature().getName().getObjectName();
List<JsonBasedUdfFunctionMetadata> functionList = new ArrayList<>();
if (udfMap.containsKey(functionName)) {
functionList = udfMap.get(functionName);
}
JsonBasedUdfFunctionMetadata metadata = sqlFunctionToMetadata(function);
functionList.add(metadata);
udfMap.put(functionName, functionList);
}
return jsonCodec.toJson(udfMap);
}
@GET
@Path("/{schema}/{functionName}")
@Produces(MediaType.APPLICATION_JSON)
public String getFunctionsBySchemaAndName(@PathParam("schema") String schema, @PathParam("functionName") String functionName)
{
Map<String, List<JsonBasedUdfFunctionMetadata>> udfMap = new HashMap<>();
Collection<SqlFunction> functionList = manager.listBuiltInFunctions();
List<JsonBasedUdfFunctionMetadata> filteredList = new ArrayList<>();
for (SqlFunction function : functionList) {
if (function.getSignature().getKind() != FunctionKind.SCALAR) {
continue;
}
if (function.getVisibility() == SqlFunctionVisibility.HIDDEN) {
continue;
}
if (function.getSignature().getName().getSchemaName().equals(schema) &&
function.getSignature().getName().getObjectName().equals(functionName)) {
filteredList.add(sqlFunctionToMetadata(function));
}
}
if (!filteredList.isEmpty()) {
udfMap.put(functionName, filteredList);
}
return jsonCodec.toJson(udfMap);
}
@POST
@Path("/{schema}/{functionName}/{functionId}/{version}")
@Consumes(PrestoMediaTypes.PRESTO_PAGES)
@Produces(PrestoMediaTypes.PRESTO_PAGES)
public byte[] execute(
@PathParam("schema") String schema,
@PathParam("functionName") String functionName,
@PathParam("functionId") String functionId,
@PathParam("version") String version,
byte[] serializedPageByteArray)
{
Slice slice = wrappedBuffer(serializedPageByteArray);
SerializedPage serializedPage = readSerializedPage(new BasicSliceInput(slice));
Page inputPage = pagesSerde.deserialize(serializedPage);
List<TypeSignatureProvider> argumentTypeSignatures = extractArgumentTypeSignatures(functionId);
Type[] types = new Type[argumentTypeSignatures.size()];
Block[] blocks = new Block[argumentTypeSignatures.size()];
for (int i = 0; i < argumentTypeSignatures.size(); i++) {
types[i] = manager.getType(argumentTypeSignatures.get(i).getTypeSignature());
blocks[i] = inputPage.getBlock(i);
}
FunctionHandle functionHandle = manager.lookupFunction(functionName, argumentTypeSignatures);
BuiltInScalarFunctionImplementation functionImplementation = (BuiltInScalarFunctionImplementation) manager.getJavaScalarFunctionImplementation(functionHandle);
int positionCount = inputPage.getPositionCount();
int channelCount = inputPage.getChannelCount();
Type returnType = manager.getType(manager.getFunctionMetadata(functionHandle).getReturnType());
PageBuilder pageBuilder = new PageBuilder(Collections.singletonList(returnType));
for (int position = 0; position < positionCount; position++) {
Object[] inputValues = new Object[blocks.length];
for (int i = 0; i < blocks.length; i++) {
if (blocks[i].isNull(position)) {
inputValues[i] = null;
}
else {
inputValues[i] = deserializeBlock(types[i], blocks[i].getRegion(position, 1));
}
}
Object result = executeFunction(functionImplementation, inputValues);
pageBuilder.declarePosition();
BlockBuilder output = pageBuilder.getBlockBuilder(0);
createResultBlock(output, returnType, result);
}
Page outputPage = pageBuilder.build();
DynamicSliceOutput sliceOutput = new DynamicSliceOutput((int) outputPage.getRetainedSizeInBytes());
writeSerializedPage(sliceOutput, pagesSerde.serialize(outputPage));
return sliceOutput.slice().byteArray();
}
public static List<TypeSignatureProvider> extractArgumentTypeSignatures(String encodedFunctionId)
{
String functionId;
try {
functionId = URLDecoder.decode(encodedFunctionId, StandardCharsets.UTF_8.toString());
}
catch (UnsupportedEncodingException e) {
throw new PrestoException(INVALID_ARGUMENTS, "Invalid functionId !");
}
SqlFunctionId sqlFunctionId = SqlFunctionId.parseSqlFunctionId(functionId);
return sqlFunctionId.getArgumentTypes().stream()
.map(TypeSignatureProvider::new)
.collect(Collectors.toList());
}
private Object deserializeBlock(Type type, Block block)
{
if (block.isNull(0)) {
return null;
}
switch (type.getTypeSignature().getBase()) {
case "boolean":
return block.getByte(0) != 0;
case "integer":
case "bigint":
case "smallint":
case "tinyint":
case "real":
case "interval day to second":
case "interval year to month":
case "timestamp":
case "time":
return block.toLong(0);
case "double":
return longBitsToDouble(block.toLong(0));
case "varchar":
case "varbinary":
return block.getSlice(0, 0, block.getSliceLength(0));
default:
throw new IllegalArgumentException("Unsupported type for deserialization: " + type);
}
}
private Object executeFunction(BuiltInScalarFunctionImplementation functionImplementation, Object[] arguments)
{
MethodHandle methodHandle = functionImplementation.getMethodHandle();
try {
return methodHandle.invokeWithArguments(arguments);
}
catch (Throwable throwable) {
throw new RuntimeException("Error during function execution", throwable);
}
}
private static PagesSerde createPagesSerde()
{
return new PagesSerde(new BlockEncodingManager(),
Optional.empty(),
Optional.empty(),
Optional.empty());
}
private void createResultBlock(BlockBuilder output, Type type, Object result)
{
switch (type.getTypeSignature().getBase()) {
case "integer":
case "bigint":
case "smallint":
case "tinyint":
case "real":
case "interval day to second":
case "interval year to month":
case "timestamp":
case "time":
type.writeLong(output, (Long) result);
break;
case "double":
type.writeDouble(output, (Double) result);
break;
case "boolean":
type.writeBoolean(output, (Boolean) result);
break;
case "varchar":
case "char":
case "json":
case "varbinary":
if (result instanceof Slice) {
type.writeSlice(output, (Slice) result);
}
else if (result instanceof byte[]) {
type.writeSlice(output, Slices.wrappedBuffer((byte[]) result));
}
else {
type.writeSlice(output, Slices.utf8Slice(result.toString()));
}
break;
case "decimal":
BigDecimal bd = (BigDecimal) result;
type.writeSlice(output, Decimals.encodeScaledValue(bd));
break;
case "object":
case "array":
case "row":
case "map":
type.writeObject(output, result);
break;
}
}
@HEAD
public Response getFunctionHeaders()
{
return Response.ok()
.header("ETag", etag)
.build();
}
}