TestPolymorphicScalarFunction.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.metadata;

import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.LongArrayBlock;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.common.type.TypeSignature;
import com.facebook.presto.operator.scalar.BuiltInScalarFunctionImplementation;
import com.facebook.presto.spi.function.Signature;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import org.testng.annotations.Test;

import java.util.Collections;
import java.util.Optional;

import static com.facebook.presto.common.function.OperatorType.ADD;
import static com.facebook.presto.common.function.OperatorType.IS_DISTINCT_FROM;
import static com.facebook.presto.common.type.Decimals.MAX_SHORT_PRECISION;
import static com.facebook.presto.common.type.StandardTypes.BOOLEAN;
import static com.facebook.presto.common.type.StandardTypes.VARCHAR;
import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.metadata.FunctionAndTypeManager.createTestFunctionAndTypeManager;
import static com.facebook.presto.metadata.TestPolymorphicScalarFunction.TestMethods.VARCHAR_TO_BIGINT_RETURN_VALUE;
import static com.facebook.presto.metadata.TestPolymorphicScalarFunction.TestMethods.VARCHAR_TO_VARCHAR_RETURN_VALUE;
import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.ArgumentProperty.valueTypeArgumentProperty;
import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.NullConvention.BLOCK_AND_POSITION;
import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.NullConvention.USE_NULL_FLAG;
import static com.facebook.presto.spi.function.FunctionKind.SCALAR;
import static com.facebook.presto.spi.function.Signature.comparableWithVariadicBound;
import static java.lang.Math.toIntExact;
import static java.util.Arrays.asList;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;

public class TestPolymorphicScalarFunction
{
    private static final FunctionAndTypeManager FUNCTION_AND_TYPE_MANAGER = createTestFunctionAndTypeManager();
    private static final Signature SIGNATURE = SignatureBuilder.builder()
            .name("foo")
            .kind(SCALAR)
            .returnType(parseTypeSignature(StandardTypes.BIGINT))
            .argumentTypes(parseTypeSignature("varchar(x)", ImmutableSet.of("x")))
            .build();
    private static final long INPUT_VARCHAR_LENGTH = 10;
    private static final String INPUT_VARCHAR_SIGNATURE = "varchar(" + INPUT_VARCHAR_LENGTH + ")";
    private static final TypeSignature INPUT_VARCHAR_TYPE = parseTypeSignature(INPUT_VARCHAR_SIGNATURE);
    private static final Slice INPUT_SLICE = Slices.allocate(toIntExact(INPUT_VARCHAR_LENGTH));
    private static final BoundVariables BOUND_VARIABLES = new BoundVariables(
            ImmutableMap.of("V", FUNCTION_AND_TYPE_MANAGER.getType(INPUT_VARCHAR_TYPE)),
            ImmutableMap.of("x", INPUT_VARCHAR_LENGTH));

    private static final TypeSignature DECIMAL_SIGNATURE = parseTypeSignature("decimal(a_precision, a_scale)", ImmutableSet.of("a_precision", "a_scale"));
    private static final BoundVariables LONG_DECIMAL_BOUND_VARIABLES = new BoundVariables(
            ImmutableMap.of(),
            ImmutableMap.of("a_precision", MAX_SHORT_PRECISION + 1L, "a_scale", 2L));
    private static final BoundVariables SHORT_DECIMAL_BOUND_VARIABLES = new BoundVariables(
            ImmutableMap.of(),
            ImmutableMap.of("a_precision", (long) MAX_SHORT_PRECISION, "a_scale", 2L));

    @Test
    public void testSelectsMultipleChoiceWithBlockPosition()
            throws Throwable
    {
        Signature signature = SignatureBuilder.builder()
                .kind(SCALAR)
                .operatorType(IS_DISTINCT_FROM)
                .argumentTypes(DECIMAL_SIGNATURE, DECIMAL_SIGNATURE)
                .returnType(parseTypeSignature(BOOLEAN))
                .build();

        SqlScalarFunction function = SqlScalarFunction.builder(TestMethods.class, IS_DISTINCT_FROM)
                .signature(signature)
                .deterministic(true)
                .choice(choice -> choice
                        .argumentProperties(
                                valueTypeArgumentProperty(USE_NULL_FLAG),
                                valueTypeArgumentProperty(USE_NULL_FLAG))
                        .implementation(methodsGroup -> methodsGroup
                                .methods("shortShort", "longLong")))
                .choice(choice -> choice
                        .argumentProperties(
                                valueTypeArgumentProperty(BLOCK_AND_POSITION),
                                valueTypeArgumentProperty(BLOCK_AND_POSITION))
                        .implementation(methodsGroup -> methodsGroup
                                .methodWithExplicitJavaTypes("blockPositionLongLong",
                                        asList(Optional.of(Slice.class), Optional.of(Slice.class)))
                                .methodWithExplicitJavaTypes("blockPositionShortShort",
                                        asList(Optional.of(long.class), Optional.of(long.class)))))
                .build();

        BuiltInScalarFunctionImplementation functionImplementation = function.specialize(SHORT_DECIMAL_BOUND_VARIABLES, 2, FUNCTION_AND_TYPE_MANAGER);

        assertEquals(functionImplementation.getAllChoices().size(), 2);
        assertEquals(functionImplementation.getAllChoices().get(0).getArgumentProperties(), Collections.nCopies(2, valueTypeArgumentProperty(USE_NULL_FLAG)));
        assertEquals(functionImplementation.getAllChoices().get(1).getArgumentProperties(), Collections.nCopies(2, valueTypeArgumentProperty(BLOCK_AND_POSITION)));
        Block block1 = new LongArrayBlock(0, Optional.empty(), new long[0]);
        Block block2 = new LongArrayBlock(0, Optional.empty(), new long[0]);
        assertFalse((boolean) functionImplementation.getAllChoices().get(1).getMethodHandle().invoke(block1, 0, block2, 0));
        functionImplementation = function.specialize(LONG_DECIMAL_BOUND_VARIABLES, 2, FUNCTION_AND_TYPE_MANAGER);
        assertTrue((boolean) functionImplementation.getAllChoices().get(1).getMethodHandle().invoke(block1, 0, block2, 0));
    }

    @Test
    public void testSelectsMethodBasedOnArgumentTypes()
            throws Throwable
    {
        SqlScalarFunction function = SqlScalarFunction.builder(TestMethods.class)
                .signature(SIGNATURE)
                .deterministic(true)
                .calledOnNullInput(false)
                .choice(choice -> choice
                        .implementation(methodsGroup -> methodsGroup.methods("bigintToBigintReturnExtraParameter"))
                        .implementation(methodsGroup -> methodsGroup
                                .methods("varcharToBigintReturnExtraParameter")
                                .withExtraParameters(context -> ImmutableList.of(context.getLiteral("x")))))
                .build();

        BuiltInScalarFunctionImplementation functionImplementation = function.specialize(BOUND_VARIABLES, 1, FUNCTION_AND_TYPE_MANAGER);
        assertEquals(functionImplementation.getMethodHandle().invoke(INPUT_SLICE), INPUT_VARCHAR_LENGTH);
    }

    @Test
    public void testSelectsMethodBasedOnReturnType()
            throws Throwable
    {
        SqlScalarFunction function = SqlScalarFunction.builder(TestMethods.class)
                .signature(SIGNATURE)
                .deterministic(true)
                .calledOnNullInput(false)
                .choice(choice -> choice
                        .implementation(methodsGroup -> methodsGroup.methods("varcharToVarcharCreateSliceWithExtraParameterLength"))
                        .implementation(methodsGroup -> methodsGroup
                                .methods("varcharToBigintReturnExtraParameter")
                                .withExtraParameters(context -> ImmutableList.of(42))))
                .build();

        BuiltInScalarFunctionImplementation functionImplementation = function.specialize(BOUND_VARIABLES, 1, FUNCTION_AND_TYPE_MANAGER);

        assertEquals(functionImplementation.getMethodHandle().invoke(INPUT_SLICE), VARCHAR_TO_BIGINT_RETURN_VALUE);
    }

    @Test
    public void testSameLiteralInArgumentsAndReturnValue()
            throws Throwable
    {
        Signature signature = SignatureBuilder.builder()
                .name("foo")
                .kind(SCALAR)
                .returnType(parseTypeSignature("varchar(x)", ImmutableSet.of("x")))
                .argumentTypes(parseTypeSignature("varchar(x)", ImmutableSet.of("x")))
                .build();

        SqlScalarFunction function = SqlScalarFunction.builder(TestMethods.class)
                .signature(signature)
                .deterministic(true)
                .calledOnNullInput(false)
                .choice(choice -> choice
                        .implementation(methodsGroup -> methodsGroup.methods("varcharToVarchar")))
                .build();

        BuiltInScalarFunctionImplementation functionImplementation = function.specialize(BOUND_VARIABLES, 1, FUNCTION_AND_TYPE_MANAGER);
        Slice slice = (Slice) functionImplementation.getMethodHandle().invoke(INPUT_SLICE);
        assertEquals(slice, VARCHAR_TO_VARCHAR_RETURN_VALUE);
    }

    @Test
    public void testTypeParameters()
            throws Throwable
    {
        Signature signature = SignatureBuilder.builder()
                .name("foo")
                .kind(SCALAR)
                .typeVariableConstraints(comparableWithVariadicBound("V", VARCHAR))
                .returnType(parseTypeSignature("V"))
                .argumentTypes(parseTypeSignature("V"))
                .build();

        SqlScalarFunction function = SqlScalarFunction.builder(TestMethods.class)
                .signature(signature)
                .deterministic(true)
                .calledOnNullInput(false)
                .choice(choice -> choice
                        .implementation(methodsGroup -> methodsGroup.methods("varcharToVarchar")))
                .build();

        BuiltInScalarFunctionImplementation functionImplementation = function.specialize(BOUND_VARIABLES, 1, FUNCTION_AND_TYPE_MANAGER);
        Slice slice = (Slice) functionImplementation.getMethodHandle().invoke(INPUT_SLICE);
        assertEquals(slice, VARCHAR_TO_VARCHAR_RETURN_VALUE);
    }

    @Test
    public void testSetsHiddenToTrueForOperators()
    {
        Signature signature = SignatureBuilder.builder()
                .operatorType(ADD)
                .kind(SCALAR)
                .returnType(parseTypeSignature("varchar(x)", ImmutableSet.of("x")))
                .argumentTypes(parseTypeSignature("varchar(x)", ImmutableSet.of("x")))
                .build();

        SqlScalarFunction function = SqlScalarFunction.builder(TestMethods.class, ADD)
                .signature(signature)
                .deterministic(true)
                .choice(choice -> choice
                        .implementation(methodsGroup -> methodsGroup.methods("varcharToVarchar")))
                .build();

        BuiltInScalarFunctionImplementation functionImplementation = function.specialize(BOUND_VARIABLES, 1, FUNCTION_AND_TYPE_MANAGER);
    }

    @Test(expectedExceptions = {IllegalStateException.class},
            expectedExceptionsMessageRegExp = "method foo was not found in class com.facebook.presto.metadata.TestPolymorphicScalarFunction\\$TestMethods")
    public void testFailIfNotAllMethodsPresent()
    {
        SqlScalarFunction.builder(TestMethods.class)
                .signature(SIGNATURE)
                .deterministic(true)
                .calledOnNullInput(false)
                .choice(choice -> choice
                        .implementation(methodsGroup -> methodsGroup.methods("bigintToBigintReturnExtraParameter"))
                        .implementation(methodsGroup -> methodsGroup.methods("foo")))
                .build();
    }

    @Test(expectedExceptions = {IllegalStateException.class},
            expectedExceptionsMessageRegExp = "methods must be selected first")
    public void testFailNoMethodsAreSelectedWhenExtraParametersFunctionIsSet()
    {
        SqlScalarFunction.builder(TestMethods.class)
                .signature(SIGNATURE)
                .deterministic(true)
                .calledOnNullInput(false)
                .choice(choice -> choice
                        .implementation(methodsGroup -> methodsGroup
                                .withExtraParameters(context -> ImmutableList.of(42))))
                .build();
    }

    @Test(expectedExceptions = {IllegalStateException.class},
            expectedExceptionsMessageRegExp = "two matching methods \\(varcharToBigintReturnFirstExtraParameter and varcharToBigintReturnExtraParameter\\) for parameter types \\[varchar\\(10\\)\\]")
    public void testFailIfTwoMethodsWithSameArguments()
    {
        SqlScalarFunction function = SqlScalarFunction.builder(TestMethods.class)
                .signature(SIGNATURE)
                .deterministic(true)
                .calledOnNullInput(false)
                .choice(choice -> choice
                        .implementation(methodsGroup -> methodsGroup.methods("varcharToBigintReturnFirstExtraParameter"))
                        .implementation(methodsGroup -> methodsGroup.methods("varcharToBigintReturnExtraParameter")))
                .build();

        function.specialize(BOUND_VARIABLES, 1, FUNCTION_AND_TYPE_MANAGER);
    }

    public static class TestMethods
    {
        static final Slice VARCHAR_TO_VARCHAR_RETURN_VALUE = Slices.utf8Slice("hello world");
        static final long VARCHAR_TO_BIGINT_RETURN_VALUE = 42L;

        public static Slice varcharToVarchar(Slice varchar)
        {
            return VARCHAR_TO_VARCHAR_RETURN_VALUE;
        }

        public static long varcharToBigint(Slice varchar)
        {
            return VARCHAR_TO_BIGINT_RETURN_VALUE;
        }

        public static long varcharToBigintReturnExtraParameter(Slice varchar, long extraParameter)
        {
            return extraParameter;
        }

        public static long bigintToBigintReturnExtraParameter(long bigint, int extraParameter)
        {
            return bigint;
        }

        public static long varcharToBigintReturnFirstExtraParameter(Slice varchar, long extraParameter1, int extraParameter2)
        {
            return extraParameter1;
        }

        public static Slice varcharToVarcharCreateSliceWithExtraParameterLength(Slice string, int extraParameter)
        {
            return Slices.allocate(extraParameter);
        }

        public static boolean blockPositionLongLong(Block left, int leftPosition, Block right, int rightPosition)
        {
            return true;
        }

        public static boolean blockPositionShortShort(Block left, int leftPosition, Block right, int rightPosition)
        {
            return false;
        }

        public static boolean shortShort(long left, boolean leftNull, long right, boolean rightNull)
        {
            return false;
        }

        public static boolean longLong(Slice left, boolean leftNull, Slice right, boolean rightNull)
        {
            return false;
        }
    }
}