TestFunctionAndTypeManager.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.metadata;

import com.facebook.presto.Session;
import com.facebook.presto.common.QualifiedObjectName;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.common.type.TypeSignature;
import com.facebook.presto.operator.scalar.BuiltInScalarFunctionImplementation;
import com.facebook.presto.operator.scalar.CustomFunctions;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.function.Parameter;
import com.facebook.presto.spi.function.RoutineCharacteristics;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.Signature;
import com.facebook.presto.spi.function.SqlFunction;
import com.facebook.presto.spi.function.SqlFunctionId;
import com.facebook.presto.spi.function.SqlFunctionVisibility;
import com.facebook.presto.spi.function.SqlInvokedFunction;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.TypeVariableConstraint;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.tree.QualifiedName;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.testng.annotations.Test;

import java.lang.invoke.MethodHandles;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
import static com.facebook.presto.SystemSessionProperties.EXPERIMENTAL_FUNCTIONS_ENABLED;
import static com.facebook.presto.common.function.OperatorType.CAST;
import static com.facebook.presto.common.function.OperatorType.SATURATED_FLOOR_CAST;
import static com.facebook.presto.common.function.OperatorType.tryGetOperatorType;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.HyperLogLogType.HYPER_LOG_LOG;
import static com.facebook.presto.common.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE;
import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.metadata.FunctionAndTypeManager.createTestFunctionAndTypeManager;
import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.ArgumentProperty.valueTypeArgumentProperty;
import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.NullConvention.RETURN_NULL_ON_NULL;
import static com.facebook.presto.spi.function.FunctionKind.SCALAR;
import static com.facebook.presto.spi.function.FunctionVersion.notVersioned;
import static com.facebook.presto.spi.function.Signature.typeVariable;
import static com.facebook.presto.spi.function.SqlFunctionVisibility.PUBLIC;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypeSignatures;
import static com.facebook.presto.sql.planner.LiteralEncoder.getMagicLiteralFunctionSignature;
import static com.facebook.presto.sql.tree.ArithmeticBinaryExpression.Operator.ADD;
import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.GREATER_THAN;
import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
import static com.facebook.presto.tpch.TpchMetadata.TINY_SCHEMA_NAME;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Lists.transform;
import static java.lang.String.format;
import static java.util.Collections.nCopies;
import static java.util.stream.Collectors.toList;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;

/**
 * Tests for {@link FunctionAndTypeManager}: cast and operator lookup, magic literal functions,
 * built-in function registration conflicts, function listing visibility, session functions,
 * and overload resolution (exact match, type variables, coercion, variable arity, unknown arguments).
 */
public class TestFunctionAndTypeManager
{
    @Test
    public void testIdentityCast()
    {
        FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager();
        FunctionHandle exactOperator = functionAndTypeManager.lookupCast(CastType.CAST, HYPER_LOG_LOG, HYPER_LOG_LOG);
        assertEquals(exactOperator, new BuiltInFunctionHandle(new Signature(CAST.getFunctionName(), SCALAR, HYPER_LOG_LOG.getTypeSignature(), HYPER_LOG_LOG.getTypeSignature())));
    }

    /**
     * Every listed operator (except casts) whose signature has no type variables and no calculated
     * argument types must resolve to exactly itself when called with its declared argument types,
     * i.e. an exact match wins over any candidate reachable via coercion.
     */
    @Test
    public void testExactMatchBeforeCoercion()
    {
        FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager();
        boolean foundOperator = false;
        for (SqlFunction function : functionAndTypeManager.listOperators()) {
            OperatorType operatorType = tryGetOperatorType(function.getSignature().getName()).get();
            if (operatorType == CAST || operatorType == SATURATED_FLOOR_CAST) {
                continue;
            }
            if (!function.getSignature().getTypeVariableConstraints().isEmpty()) {
                continue;
            }
            if (function.getSignature().getArgumentTypes().stream().anyMatch(TypeSignature::isCalculated)) {
                continue;
            }
            BuiltInFunctionHandle exactOperator = (BuiltInFunctionHandle) functionAndTypeManager.resolveOperator(operatorType, fromTypeSignatures(function.getSignature().getArgumentTypes()));
            assertEquals(exactOperator.getSignature(), function.getSignature());
            foundOperator = true;
        }
        // guard against the loop silently testing nothing
        assertTrue(foundOperator);
    }

    /**
     * Checks the shape of the magic literal signature for TIMESTAMP WITH TIME ZONE and that the
     * function it names can be resolved with the expected argument and return types.
     */
    @Test
    public void testMagicLiteralFunction()
    {
        Signature signature = getMagicLiteralFunctionSignature(TIMESTAMP_WITH_TIME_ZONE);
        assertEquals(signature.getNameSuffix(), "$literal$timestamp with time zone");
        assertEquals(signature.getArgumentTypes(), ImmutableList.of(parseTypeSignature(StandardTypes.BIGINT)));
        assertEquals(signature.getReturnType().getBase(), StandardTypes.TIMESTAMP_WITH_TIME_ZONE);

        FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager();
        BuiltInFunctionHandle functionHandle = (BuiltInFunctionHandle) functionAndTypeManager.resolveFunction(
                Optional.empty(),
                TEST_SESSION.getTransactionId(),
                signature.getName(),
                fromTypeSignatures(signature.getArgumentTypes()));
        assertEquals(functionAndTypeManager.getFunctionMetadata(functionHandle).getArgumentTypes(), ImmutableList.of(parseTypeSignature(StandardTypes.BIGINT)));
        // verify the *resolved* function's return type, not the locally built signature again
        assertEquals(functionAndTypeManager.getFunctionMetadata(functionHandle).getReturnType().getBase(), StandardTypes.TIMESTAMP_WITH_TIME_ZONE);
    }

    /**
     * Registering the same built-in function twice must be rejected.
     */
    @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "\\QFunction already registered: presto.default.custom_add(bigint,bigint):bigint\\E")
    public void testDuplicateFunctions()
    {
        List<SqlFunction> functions = new FunctionListBuilder()
                .scalars(CustomFunctions.class)
                .getFunctions()
                .stream()
                .filter(input -> input.getSignature().getNameSuffix().equals("custom_add"))
                .collect(toImmutableList());

        FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager();
        functionAndTypeManager.registerBuiltInFunctions(functions);
        functionAndTypeManager.registerBuiltInFunctions(functions);
    }

    /**
     * Registering a scalar named {@code sum} must conflict with the existing {@code sum} aggregation.
     */
    @Test(expectedExceptions = IllegalStateException.class, expectedExceptionsMessageRegExp = "'presto.default.sum' is both an aggregation and a scalar function")
    public void testConflictingScalarAggregation()
    {
        List<SqlFunction> functions = new FunctionListBuilder()
                .scalars(ScalarSum.class)
                .getFunctions();

        FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager();
        functionAndTypeManager.registerBuiltInFunctions(functions);
    }

    /**
     * With the default test session, {@code quantiles_at_values} must not be listed;
     * {@code like} and the {@code *_for_stats} aggregations must never be listed.
     */
    @Test
    public void testListingVisibilityBetaFunctionsDisabled()
    {
        FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager();
        List<SqlFunction> functions = functionAndTypeManager.listFunctions(TEST_SESSION, Optional.empty(), Optional.empty());
        List<String> names = transform(functions, input -> input.getSignature().getNameSuffix());

        assertTrue(names.contains("length"), "Expected function names " + names + " to contain 'length'");
        assertTrue(names.contains("stddev"), "Expected function names " + names + " to contain 'stddev'");
        assertTrue(names.contains("rank"), "Expected function names " + names + " to contain 'rank'");
        assertFalse(names.contains("quantiles_at_values"), "Expected function names " + names + " not to contain 'quantiles_at_values'");
        assertFalse(names.contains("like"), "Expected function names " + names + " not to contain 'like'");
        assertFalse(names.contains("sum_data_size_for_stats"), "Expected function names " + names + " not to contain 'sum_data_size_for_stats'");
        assertFalse(names.contains("max_data_size_for_stats"), "Expected function names " + names + " not to contain 'max_data_size_for_stats'");
    }

    /**
     * With {@code EXPERIMENTAL_FUNCTIONS_ENABLED} set, {@code tdigest_agg} and {@code quantiles_at_values}
     * must be listed, while {@code like} and the {@code *_for_stats} aggregations stay unlisted.
     */
    @Test
    public void testListingVisibilityBetaFunctionsEnabled()
    {
        Session session = testSessionBuilder()
                .setCatalog("tpch")
                .setSchema(TINY_SCHEMA_NAME)
                .setSystemProperty(EXPERIMENTAL_FUNCTIONS_ENABLED, "true")
                .build();
        FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager();
        List<SqlFunction> functions = functionAndTypeManager.listFunctions(session, Optional.empty(), Optional.empty());
        List<String> names = transform(functions, input -> input.getSignature().getNameSuffix());

        assertTrue(names.contains("length"), "Expected function names " + names + " to contain 'length'");
        assertTrue(names.contains("stddev"), "Expected function names " + names + " to contain 'stddev'");
        assertTrue(names.contains("rank"), "Expected function names " + names + " to contain 'rank'");
        assertTrue(names.contains("tdigest_agg"), "Expected function names " + names + " to contain 'tdigest_agg'");
        assertTrue(names.contains("quantiles_at_values"), "Expected function names " + names + " to contain 'quantiles_at_values'");
        assertFalse(names.contains("like"), "Expected function names " + names + " not to contain 'like'");
        assertFalse(names.contains("sum_data_size_for_stats"), "Expected function names " + names + " not to contain 'sum_data_size_for_stats'");
        assertFalse(names.contains("max_data_size_for_stats"), "Expected function names " + names + " not to contain 'max_data_size_for_stats'");
    }

    /**
     * Arithmetic and comparison operators must report the matching operator category, and
     * {@code not} must report no operator type at all.
     */
    @Test
    public void testOperatorTypes()
    {
        FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager();
        FunctionResolution functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());

        assertTrue(functionAndTypeManager.getFunctionMetadata(functionResolution.arithmeticFunction(ADD, BIGINT, BIGINT)).getOperatorType().map(OperatorType::isArithmeticOperator).orElse(false));
        assertFalse(functionAndTypeManager.getFunctionMetadata(functionResolution.arithmeticFunction(ADD, BIGINT, BIGINT)).getOperatorType().map(OperatorType::isComparisonOperator).orElse(true));
        assertTrue(functionAndTypeManager.getFunctionMetadata(functionResolution.comparisonFunction(GREATER_THAN, BIGINT, BIGINT)).getOperatorType().map(OperatorType::isComparisonOperator).orElse(false));
        assertFalse(functionAndTypeManager.getFunctionMetadata(functionResolution.comparisonFunction(GREATER_THAN, BIGINT, BIGINT)).getOperatorType().map(OperatorType::isArithmeticOperator).orElse(true));
        assertFalse(functionAndTypeManager.getFunctionMetadata(functionResolution.notFunction()).getOperatorType().isPresent());
    }

    /**
     * Session functions passed to {@code resolveFunction} are resolved by argument types,
     * including via coercion (an {@code int} argument resolves to the {@code bigint} overload).
     */
    @Test
    public void testSessionFunctions()
    {
        FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager();

        SqlFunctionId bigintSignature = new SqlFunctionId(QualifiedObjectName.valueOf("presto.default.foo"), ImmutableList.of(parseTypeSignature("bigint")));
        SqlInvokedFunction bigintFunction = new SqlInvokedFunction(
                bigintSignature.getFunctionName(),
                ImmutableList.of(new Parameter("x", parseTypeSignature("bigint"))),
                parseTypeSignature("bigint"),
                "",
                RoutineCharacteristics.builder().build(),
                "",
                notVersioned());

        SqlFunctionId varcharSignature = new SqlFunctionId(QualifiedObjectName.valueOf("presto.default.foo"), ImmutableList.of(parseTypeSignature("varchar")));
        SqlInvokedFunction varcharFunction = new SqlInvokedFunction(
                varcharSignature.getFunctionName(),
                ImmutableList.of(new Parameter("x", parseTypeSignature("varchar"))),
                parseTypeSignature("varchar"),
                "",
                RoutineCharacteristics.builder().build(),
                "",
                notVersioned());

        Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions = ImmutableMap.of(bigintSignature, bigintFunction, varcharSignature, varcharFunction);

        assertEquals(
                functionAndTypeManager.resolveFunction(
                        Optional.of(sessionFunctions),
                        Optional.empty(),
                        bigintSignature.getFunctionName(),
                        ImmutableList.of(new TypeSignatureProvider(parseTypeSignature("bigint")))),
                new SessionFunctionHandle(bigintFunction));
        assertEquals(
                functionAndTypeManager.resolveFunction(
                        Optional.of(sessionFunctions),
                        Optional.empty(),
                        varcharSignature.getFunctionName(),
                        ImmutableList.of(new TypeSignatureProvider(parseTypeSignature("varchar")))),
                new SessionFunctionHandle(varcharFunction));
        assertEquals(
                functionAndTypeManager.resolveFunction(
                        Optional.of(sessionFunctions),
                        Optional.empty(),
                        bigintSignature.getFunctionName(),
                        ImmutableList.of(new TypeSignatureProvider(parseTypeSignature("int")))),
                new SessionFunctionHandle(bigintFunction));
    }

    @Test
    public void testResolveFunctionByExactMatch()
    {
        assertThatResolveFunction()
                .among(functionSignature("bigint", "bigint"))
                .forParameters("bigint", "bigint")
                .returns(functionSignature("bigint", "bigint"));
    }

    @Test
    public void testResolveTypeParametrizedFunction()
    {
        assertThatResolveFunction()
                .among(functionSignature(ImmutableList.of("T", "T"), "boolean", ImmutableList.of(typeVariable("T"))))
                .forParameters("bigint", "bigint")
                .returns(functionSignature("bigint", "bigint"));
    }

    @Test
    public void testResolveFunctionWithCoercion()
    {
        assertThatResolveFunction()
                .among(
                        functionSignature("decimal(p,s)", "double"),
                        functionSignature("decimal(p,s)", "decimal(p,s)"),
                        functionSignature("double", "double"))
                .forParameters("bigint", "bigint")
                .returns(functionSignature("decimal(19,0)", "decimal(19,0)"));
    }

    @Test
    public void testAmbiguousCallWithNoCoercion()
    {
        assertThatResolveFunction()
                .among(
                        functionSignature("decimal(p,s)", "decimal(p,s)"),
                        functionSignature(ImmutableList.of("T", "T"), "boolean", ImmutableList.of(typeVariable("T"))))
                .forParameters("decimal(3,1)", "decimal(3,1)")
                .returns(functionSignature("decimal(3,1)", "decimal(3,1)"));
    }

    @Test
    public void testAmbiguousCallWithCoercion()
    {
        assertThatResolveFunction()
                .among(
                        functionSignature("decimal(p,s)", "double"),
                        functionSignature("double", "decimal(p,s)"))
                .forParameters("bigint", "bigint")
                .failsWithMessage("Could not choose a best candidate operator. Explicit type casts must be added.");
    }

    @Test
    public void testResolveFunctionWithCoercionInTypes()
    {
        assertThatResolveFunction()
                .among(
                        functionSignature("array(decimal(p,s))", "array(double)"),
                        functionSignature("array(decimal(p,s))", "array(decimal(p,s))"),
                        functionSignature("array(double)", "array(double)"))
                .forParameters("array(bigint)", "array(bigint)")
                .returns(functionSignature("array(decimal(19,0))", "array(decimal(19,0))"));
    }

    @Test
    public void testResolveFunctionWithVariableArity()
    {
        assertThatResolveFunction()
                .among(
                        functionSignature("double", "double", "double"),
                        functionSignature("decimal(p,s)").setVariableArity(true))
                .forParameters("bigint", "bigint", "bigint")
                .returns(functionSignature("decimal(19,0)", "decimal(19,0)", "decimal(19,0)"));

        assertThatResolveFunction()
                .among(
                        functionSignature("double", "double", "double"),
                        functionSignature("bigint").setVariableArity(true))
                .forParameters("bigint", "bigint", "bigint")
                .returns(functionSignature("bigint", "bigint", "bigint"));
    }

    @Test
    public void testResolveFunctionWithVariadicBound()
    {
        assertThatResolveFunction()
                .among(
                        functionSignature("bigint", "bigint", "bigint"),
                        functionSignature(
                                ImmutableList.of("T1", "T2", "T3"),
                                "boolean",
                                ImmutableList.of(Signature.withVariadicBound("T1", "decimal"),
                                        Signature.withVariadicBound("T2", "decimal"),
                                        Signature.withVariadicBound("T3", "decimal"))))
                .forParameters("unknown", "bigint", "bigint")
                .returns(functionSignature("bigint", "bigint", "bigint"));
    }

    @Test
    public void testResolveFunctionForUnknown()
    {
        assertThatResolveFunction()
                .among(
                        functionSignature("bigint"))
                .forParameters("unknown")
                .returns(functionSignature("bigint"));

        // when coercion between the types exists, and the most specific function can be determined with the main algorithm
        assertThatResolveFunction()
                .among(
                        functionSignature("bigint"),
                        functionSignature("integer"))
                .forParameters("unknown")
                .returns(functionSignature("integer"));

        // function that requires only unknown coercion must be preferred
        assertThatResolveFunction()
                .among(
                        functionSignature("bigint", "bigint"),
                        functionSignature("integer", "integer"))
                .forParameters("unknown", "bigint")
                .returns(functionSignature("bigint", "bigint"));

        // when coercion between the types doesn't exist but the return type is the same, any of the candidates may be chosen
        assertThatResolveFunction()
                .among(
                        functionSignature(ImmutableList.of("JoniRegExp"), "boolean"),
                        functionSignature(ImmutableList.of("integer"), "boolean"))
                .forParameters("unknown")
                // any function can be selected, but to make it deterministic we sort function signatures alphabetically
                .returns(functionSignature("integer"));

        // when the return type is different
        assertThatResolveFunction()
                .among(
                        functionSignature(ImmutableList.of("JoniRegExp"), "JoniRegExp"),
                        functionSignature(ImmutableList.of("integer"), "integer"))
                .forParameters("unknown")
                .failsWithMessage("Could not choose a best candidate operator. Explicit type casts must be added.");
    }

    /**
     * Builds a scalar signature returning {@code boolean} with the given argument types.
     */
    private static SignatureBuilder functionSignature(String... argumentTypes)
    {
        return functionSignature(ImmutableList.copyOf(argumentTypes), "boolean");
    }

    /**
     * Builds a scalar signature with no type variable constraints.
     */
    private static SignatureBuilder functionSignature(List<String> arguments, String returnType)
    {
        return functionSignature(arguments, returnType, ImmutableList.of());
    }

    /**
     * Builds a scalar signature. The names {@code p, s, p1, s1, ... s3} are parsed as literal
     * (numeric) parameters so that types such as {@code decimal(p,s)} stay generic.
     */
    private static SignatureBuilder functionSignature(List<String> arguments, String returnType, List<TypeVariableConstraint> typeVariableConstraints)
    {
        ImmutableSet<String> literalParameters = ImmutableSet.of("p", "s", "p1", "s1", "p2", "s2", "p3", "s3");
        List<TypeSignature> argumentSignatures = arguments.stream()
                .map(signature -> parseTypeSignature(signature, literalParameters))
                .collect(toImmutableList());
        return new SignatureBuilder()
                .returnType(parseTypeSignature(returnType, literalParameters))
                .argumentTypes(argumentSignatures)
                .typeVariableConstraints(typeVariableConstraints)
                .kind(SCALAR);
    }

    private static ResolveFunctionAssertion assertThatResolveFunction()
    {
        return new ResolveFunctionAssertion();
    }

    /**
     * Fluent helper: registers one placeholder scalar function per candidate signature (all named
     * {@link #TEST_FUNCTION_NAME}) into a fresh manager, resolves the name for the given parameter
     * types, and checks either the resolved handle or the failure message.
     */
    private static class ResolveFunctionAssertion
    {
        private static final String TEST_FUNCTION_NAME = "TEST_FUNCTION_NAME";

        private List<SignatureBuilder> functionSignatures = ImmutableList.of();
        private List<TypeSignature> parameterTypes = ImmutableList.of();

        /**
         * Sets the candidate signatures that will be registered before resolution.
         */
        public ResolveFunctionAssertion among(SignatureBuilder... functionSignatures)
        {
            this.functionSignatures = ImmutableList.copyOf(functionSignatures);
            return this;
        }

        /**
         * Sets the actual argument types used for resolution.
         */
        public ResolveFunctionAssertion forParameters(String... parameters)
        {
            this.parameterTypes = parseTypeSignatures(parameters);
            return this;
        }

        /**
         * Asserts that resolution yields a built-in handle for {@code functionSignature}.
         */
        public ResolveFunctionAssertion returns(SignatureBuilder functionSignature)
        {
            FunctionHandle expectedFunction = new BuiltInFunctionHandle(functionSignature.name(TEST_FUNCTION_NAME).build());
            FunctionHandle actualFunction = resolveFunctionHandle();
            // TestNG's assertEquals takes (actual, expected)
            assertEquals(actualFunction, expectedFunction);
            return this;
        }

        /**
         * Asserts that resolution throws a {@link RuntimeException} whose message contains every
         * one of {@code messages}.
         */
        public ResolveFunctionAssertion failsWithMessage(String... messages)
        {
            try {
                resolveFunctionHandle();
                // AssertionError is not a RuntimeException, so it is not swallowed by the catch below
                fail("didn't fail as expected");
            }
            catch (RuntimeException e) {
                String actualMessage = e.getMessage();
                for (String expectedMessage : messages) {
                    // a null message would otherwise cause an NPE that hides the original exception
                    if (actualMessage == null || !actualMessage.contains(expectedMessage)) {
                        fail(format("%s doesn't contain %s", actualMessage, expectedMessage), e);
                    }
                }
            }
            return this;
        }

        private FunctionHandle resolveFunctionHandle()
        {
            FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager();
            functionAndTypeManager.registerBuiltInFunctions(createFunctionsFromSignatures());
            return functionAndTypeManager.resolveFunction(
                    Optional.empty(),
                    TEST_SESSION.getTransactionId(),
                    functionAndTypeManager.getFunctionAndTypeResolver().qualifyObjectName(QualifiedName.of(TEST_FUNCTION_NAME)),
                    fromTypeSignatures(parameterTypes));
        }

        /**
         * Wraps each candidate signature in a scalar function with a placeholder implementation;
         * these tests only check which signature is resolved.
         */
        private List<BuiltInFunction> createFunctionsFromSignatures()
        {
            ImmutableList.Builder<BuiltInFunction> functions = ImmutableList.builder();
            for (SignatureBuilder functionSignature : functionSignatures) {
                Signature signature = functionSignature.name(TEST_FUNCTION_NAME).build();
                functions.add(new SqlScalarFunction(signature)
                {
                    @Override
                    public BuiltInScalarFunctionImplementation specialize(
                            BoundVariables boundVariables,
                            int arity,
                            FunctionAndTypeManager functionAndTypeManager)
                    {
                        return new BuiltInScalarFunctionImplementation(
                                false,
                                nCopies(arity, valueTypeArgumentProperty(RETURN_NULL_ON_NULL)),
                                MethodHandles.identity(Void.class));
                    }

                    @Override
                    public SqlFunctionVisibility getVisibility()
                    {
                        return PUBLIC;
                    }

                    @Override
                    public boolean isDeterministic()
                    {
                        return false;
                    }

                    @Override
                    public String getDescription()
                    {
                        return "testing function that does nothing";
                    }
                });
            }
            return functions.build();
        }

        private static List<TypeSignature> parseTypeSignatures(String... signatures)
        {
            return ImmutableList.copyOf(signatures)
                    .stream()
                    .map(TypeSignature::parseTypeSignature)
                    .collect(toList());
        }
    }

    /**
     * Scalar function named {@code sum}, used to provoke a scalar/aggregation name conflict.
     */
    public static final class ScalarSum
    {
        private ScalarSum() {}

        @ScalarFunction
        @SqlType(StandardTypes.BIGINT)
        public static long sum(@SqlType(StandardTypes.BIGINT) long a, @SqlType(StandardTypes.BIGINT) long b)
        {
            return a + b;
        }
    }
}