TestInlineSqlFunctions.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.common.CatalogSchemaName;
import com.facebook.presto.common.QualifiedObjectName;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.IntegerType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.functionNamespace.SqlInvokedFunctionNamespaceManagerConfig;
import com.facebook.presto.functionNamespace.execution.NoopSqlFunctionExecutor;
import com.facebook.presto.functionNamespace.execution.SqlFunctionExecutors;
import com.facebook.presto.functionNamespace.testing.InMemoryFunctionNamespaceManager;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.function.FunctionImplementationType;
import com.facebook.presto.spi.function.Parameter;
import com.facebook.presto.spi.function.RoutineCharacteristics;
import com.facebook.presto.spi.function.SqlInvokedFunction;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.sql.TestingRowExpressionTranslator;
import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
import com.facebook.presto.sql.planner.iterative.rule.test.RuleAssert;
import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import java.util.Map;

import static com.facebook.presto.common.type.StandardTypes.INTEGER;
import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.metadata.FunctionAndTypeManager.createTestFunctionAndTypeManager;
import static com.facebook.presto.spi.function.FunctionImplementationType.THRIFT;
import static com.facebook.presto.spi.function.FunctionVersion.notVersioned;
import static com.facebook.presto.spi.function.RoutineCharacteristics.Determinism.DETERMINISTIC;
import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.SQL;
import static com.facebook.presto.spi.function.RoutineCharacteristics.NullCallClause.RETURNS_NULL_ON_NULL_INPUT;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values;
import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment;

/**
 * Rule tests for {@link InlineSqlFunctions}.
 *
 * <p>SQL-invoked functions whose body is written in SQL are expected to be expanded in place
 * inside project assignments. A function declared in the "java" language (mapped to the THRIFT
 * implementation type in {@link #setup()}) must not be inlined. Nothing must be inlined when the
 * {@code inline_sql_functions} session property is {@code false}.
 */
public class TestInlineSqlFunctions
        extends BaseRuleTest
{
    // Non-SQL routine language; mapped to the THRIFT implementation type in setup().
    private static final RoutineCharacteristics.Language JAVA = new RoutineCharacteristics.Language("java");

    // unittest.memory.square(x integer) -> integer, SQL body "RETURN x * x".
    private static final SqlInvokedFunction SQL_FUNCTION_SQUARE = new SqlInvokedFunction(
            QualifiedObjectName.valueOf(new CatalogSchemaName("unittest", "memory"), "square"),
            ImmutableList.of(new Parameter("x", parseTypeSignature(INTEGER))),
            parseTypeSignature(INTEGER),
            "square",
            RoutineCharacteristics.builder()
                    .setDeterminism(DETERMINISTIC)
                    .setNullCallClause(RETURNS_NULL_ON_NULL_INPUT)
                    .build(),
            "RETURN x * x",
            notVersioned());

    // unittest.memory.foo(x integer) -> integer, declared in the "java" language with an empty body.
    private static final SqlInvokedFunction THRIFT_FUNCTION_FOO = new SqlInvokedFunction(
            QualifiedObjectName.valueOf(new CatalogSchemaName("unittest", "memory"), "foo"),
            ImmutableList.of(new Parameter("x", parseTypeSignature(INTEGER))),
            parseTypeSignature(INTEGER),
            "thrift function foo",
            RoutineCharacteristics.builder()
                    .setLanguage(JAVA)
                    .setDeterminism(DETERMINISTIC)
                    .setNullCallClause(RETURNS_NULL_ON_NULL_INPUT)
                    .build(),
            "",
            notVersioned());

    // unittest.memory.add_1_int(x array(int)) -> array(int); its body contains a lambda.
    private static final SqlInvokedFunction SQL_FUNCTION_ADD_1_TO_INT_ARRAY = new SqlInvokedFunction(
            QualifiedObjectName.valueOf(new CatalogSchemaName("unittest", "memory"), "add_1_int"),
            ImmutableList.of(new Parameter("x", parseTypeSignature("array(int)"))),
            parseTypeSignature("array(int)"),
            "add 1 to all elements of array",
            RoutineCharacteristics.builder()
                    .setDeterminism(DETERMINISTIC)
                    .setNullCallClause(RETURNS_NULL_ON_NULL_INPUT)
                    .build(),
            "RETURN transform(x, x -> x + 1)",
            notVersioned());

    // unittest.memory.add_1_bigint(x array(bigint)) -> array(bigint); the integer literal 1
    // inside the lambda is expected to be coerced to bigint when inlined.
    private static final SqlInvokedFunction SQL_FUNCTION_ADD_1_TO_BIGINT_ARRAY = new SqlInvokedFunction(
            QualifiedObjectName.valueOf(new CatalogSchemaName("unittest", "memory"), "add_1_bigint"),
            ImmutableList.of(new Parameter("x", parseTypeSignature("array(bigint)"))),
            parseTypeSignature("array(bigint)"),
            "add 1 to all elements of array",
            RoutineCharacteristics.builder()
                    .setDeterminism(DETERMINISTIC)
                    .setNullCallClause(RETURNS_NULL_ON_NULL_INPUT)
                    .build(),
            "RETURN transform(x, x -> x + 1)",
            notVersioned());

    /**
     * Creates a fresh {@link RuleTester} and registers the in-memory "unittest" function namespace.
     * The namespace supports "sql" and "java" functions, with "java" executed via THRIFT.
     * The four test functions above are then created in it.
     */
    @BeforeClass
    public final void setup()
    {
        // Assigns the tester field inherited from BaseRuleTest so every test sees the namespace below.
        tester = new RuleTester();
        FunctionAndTypeManager functionAndTypeManager = tester.getMetadata().getFunctionAndTypeManager();
        functionAndTypeManager.addFunctionNamespace(
                "unittest",
                new InMemoryFunctionNamespaceManager(
                        "unittest",
                        new SqlFunctionExecutors(
                                ImmutableMap.of(
                                        SQL, FunctionImplementationType.SQL,
                                        JAVA, THRIFT),
                                new NoopSqlFunctionExecutor()),
                        new SqlInvokedFunctionNamespaceManagerConfig().setSupportedFunctionLanguages("sql,java")));
        functionAndTypeManager.createFunction(SQL_FUNCTION_SQUARE, true);
        functionAndTypeManager.createFunction(THRIFT_FUNCTION_FOO, true);
        functionAndTypeManager.createFunction(SQL_FUNCTION_ADD_1_TO_INT_ARRAY, true);
        functionAndTypeManager.createFunction(SQL_FUNCTION_ADD_1_TO_BIGINT_ARRAY, true);
    }

    @Test
    public void testInlineFunction()
    {
        assertInlined("unittest.memory.square(x)", "x * x", "x", IntegerType.INTEGER);
    }

    @Test
    public void testInlineFunctionInsideFunction()
    {
        assertInlined("abs(unittest.memory.square(x))", "abs(x * x)", "x", IntegerType.INTEGER);
    }

    @Test
    public void testInlineFunctionContainingLambda()
    {
        // The lambda parameter is renamed so it cannot collide with the outer variable "x".
        assertInlined(
                "unittest.memory.add_1_int(x)",
                "transform(x, \"x$lambda\" -> \"x$lambda\" + 1)",
                "x",
                new ArrayType(IntegerType.INTEGER));
    }

    @Test
    public void testInlineSqlFunctionCoercesConstantWithCast()
    {
        assertInlined(
                "unittest.memory.add_1_bigint(x)",
                "transform(x, \"x$lambda\" -> \"x$lambda\" + CAST(1 AS bigint))",
                "x",
                new ArrayType(BigintType.BIGINT));
    }

    @Test
    public void testNoInlineThriftFunction()
    {
        assertNotInlined("unittest.memory.foo(x)", "x", IntegerType.INTEGER);
    }

    @Test
    public void testNoInlineIntoPlanWhenInlineIsDisabled()
    {
        assertNotInlined("unittest.memory.square(x)",
                ImmutableMap.of("inline_sql_functions", "false"),
                "x",
                IntegerType.INTEGER);
    }

    /**
     * Asserts that the InlineSqlFunctions project rule rewrites {@code inputExpressionStr}.
     * The expression is placed in a project assignment over a values node. The rewritten
     * assignment must match {@code expectedExpressionStr}.
     *
     * @param inputExpressionStr SQL text of the expression to rewrite
     * @param expectedExpressionStr SQL text the rewritten assignment must match
     * @param variable name of the single input variable produced by the values node
     * @param type type of {@code variable}
     */
    protected void assertInlined(String inputExpressionStr, String expectedExpressionStr, String variable, Type type)
    {
        RowExpression inputExpression = new TestingRowExpressionTranslator(tester.getMetadata()).translate(inputExpressionStr, ImmutableMap.of(variable, type));

        tester().assertThat(new InlineSqlFunctions(tester.getMetadata()).projectRowExpressionRewriteRule())
                .on(p -> p.project(assignment(p.variable("var"), inputExpression), p.values(p.variable(variable, type))))
                .matches(project(ImmutableMap.of("var", expression(expectedExpressionStr)), values(variable)));
    }

    /**
     * Asserts that the InlineSqlFunctions project rule does not fire on {@code expression}
     * under the default session.
     */
    private void assertNotInlined(String expression, String variable, Type type)
    {
        assertNotInlined(expression, ImmutableMap.of(), variable, type);
    }

    /**
     * Asserts that the InlineSqlFunctions project rule does not fire on {@code expression}.
     * The given system session properties are applied first.
     *
     * @param expression SQL text of the expression placed in the project assignment
     * @param sessionValues system session properties to set before running the rule
     * @param variable name of the single input variable produced by the values node
     * @param type type of {@code variable}
     */
    private void assertNotInlined(String expression, Map<String, String> sessionValues, String variable, Type type)
    {
        RowExpression inputExpression = new TestingRowExpressionTranslator(tester.getMetadata()).translate(expression, ImmutableMap.of(variable, type));
        // This must be the same rule that assertInlined exercises. It must also see the metadata
        // holding the "unittest" namespace. Checking doesNotFire() on an unrelated rule would
        // pass for any input and prove nothing about inlining.
        RuleAssert ruleAssert = tester().assertThat(new InlineSqlFunctions(tester.getMetadata()).projectRowExpressionRewriteRule());
        sessionValues.forEach((k, v) -> ruleAssert.setSystemProperty(k, v));
        ruleAssert
                .on(p -> p.project(assignment(p.variable("var"), inputExpression), p.values(p.variable(variable, type))))
                .doesNotFire();
    }
}