// TestPullUpExpressionInLambdaRules.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.MapType;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.Test;

import java.lang.invoke.MethodHandle;

import static com.facebook.presto.SystemSessionProperties.PULL_EXPRESSION_FROM_LAMBDA_ENABLED;
import static com.facebook.presto.common.block.MethodHandleUtil.compose;
import static com.facebook.presto.common.block.MethodHandleUtil.nativeValueGetter;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values;
import static com.facebook.presto.testing.TestingEnvironment.getOperatorMethodHandle;

/**
 * Rule tests for {@link PullUpExpressionInLambdaRules}.
 *
 * <p>Every test enables {@code PULL_EXPRESSION_FROM_LAMBDA_ENABLED}, builds a small plan whose
 * project assignment or filter predicate contains a lambda, and then asserts either the exact
 * shape of the rewritten plan or that the rule does not fire.
 */
public class TestPullUpExpressionInLambdaRules
        extends BaseRuleTest
{
    // Equality / hash-code method handles for BIGINT keys; they are only used to construct ID_MAP_TYPE below.
    private static final MethodHandle KEY_NATIVE_EQUALS = getOperatorMethodHandle(OperatorType.EQUAL, BIGINT, BIGINT);
    private static final MethodHandle KEY_BLOCK_EQUALS = compose(KEY_NATIVE_EQUALS, nativeValueGetter(BIGINT), nativeValueGetter(BIGINT));
    private static final MethodHandle KEY_NATIVE_HASH_CODE = getOperatorMethodHandle(OperatorType.HASH_CODE, BIGINT);
    private static final MethodHandle KEY_BLOCK_HASH_CODE = compose(KEY_NATIVE_HASH_CODE, nativeValueGetter(BIGINT));

    // Shared type instances used by the tests. These must be declared after the KEY_* handles,
    // because static initializers run in textual order.
    private static final MapType ID_MAP_TYPE = new MapType(BIGINT, BIGINT, KEY_BLOCK_EQUALS, KEY_BLOCK_HASH_CODE);
    private static final ArrayType BIGINT_ARRAY = new ArrayType(BIGINT);
    private static final ArrayType BIGINT_ARRAY_OF_ARRAYS = new ArrayType(BIGINT_ARRAY);
    private static final ArrayType VARCHAR_ARRAY = new ArrayType(VARCHAR);
    private static final ArrayType BOOLEAN_ARRAY = new ArrayType(BOOLEAN);

    /**
     * {@code array_sort(map_keys(idmap))} references no lambda argument; the expected plan computes
     * it in a child projection named {@code array_sort} and references that variable in the lambda.
     */
    @Test
    public void testProjection()
    {
        tester().assertThat(new PullUpExpressionInLambdaRules(getFunctionManager()).projectNodeRule())
                .setSystemProperty(PULL_EXPRESSION_FROM_LAMBDA_ENABLED, "true")
                .on(p ->
                {
                    // Register idmap up front so that rowExpression() below can resolve its type.
                    p.variable("idmap", ID_MAP_TYPE);
                    return p.project(
                            Assignments.builder().put(p.variable("expr"), p.rowExpression("map_filter(idmap, (k, v) -> array_position(array_sort(map_keys(idmap)), k) <= 200)")).build(),
                            p.values(p.variable("idmap", ID_MAP_TYPE)));
                })
                .matches(
                        project(
                                ImmutableMap.of("expr", expression("map_filter(idmap, (k, v) -> (array_position(array_sort, k)) <= (INTEGER'200'))")),
                                project(
                                        ImmutableMap.of("array_sort", expression("array_sort(map_keys(idmap))")),
                                        values("idmap"))));
    }

    /**
     * Filter counterpart of {@link #testProjection()}: the expected plan is an identity projection
     * of {@code idmap} over the rewritten filter over a projection that computes {@code array_sort}.
     */
    @Test
    public void testFilter()
    {
        tester().assertThat(new PullUpExpressionInLambdaRules(getFunctionManager()).filterNodeRule())
                .setSystemProperty(PULL_EXPRESSION_FROM_LAMBDA_ENABLED, "true")
                .on(p ->
                {
                    p.variable("idmap", ID_MAP_TYPE);
                    return p.filter(
                            p.rowExpression("cardinality(map_filter(idmap, (k, v) -> array_position(array_sort(map_keys(idmap)), k) <= 200)) > 0"),
                            p.values(p.variable("idmap", ID_MAP_TYPE)));
                })
                .matches(
                        project(
                                ImmutableMap.of("idmap", expression("idmap")),
                                filter(
                                        "(cardinality(map_filter(idmap, (k, v) -> (array_position(array_sort, k)) <= (INTEGER'200')))) > (INTEGER'0')",
                                        project(
                                                ImmutableMap.of("array_sort", expression("array_sort(map_keys(idmap))")),
                                                values("idmap")))));
    }

    /**
     * The only lambda-independent sub-expression contains {@code random()}; the project rule must not fire.
     */
    @Test
    public void testNonDeterministicProjection()
    {
        tester().assertThat(new PullUpExpressionInLambdaRules(getFunctionManager()).projectNodeRule())
                .setSystemProperty(PULL_EXPRESSION_FROM_LAMBDA_ENABLED, "true")
                .on(p ->
                {
                    p.variable("idmap", ID_MAP_TYPE);
                    return p.project(
                            Assignments.builder().put(p.variable("expr"), p.rowExpression("map_filter(idmap, (k, v) -> array_position(array[random()], k) <= 200)")).build(),
                            p.values(p.variable("idmap", ID_MAP_TYPE)));
                })
                .doesNotFire();
    }

    /**
     * Filter counterpart of {@link #testNonDeterministicProjection()}: the filter rule must not fire.
     */
    @Test
    public void testNonDeterministicFilter()
    {
        tester().assertThat(new PullUpExpressionInLambdaRules(getFunctionManager()).filterNodeRule())
                .setSystemProperty(PULL_EXPRESSION_FROM_LAMBDA_ENABLED, "true")
                .on(p ->
                {
                    p.variable("idmap", ID_MAP_TYPE);
                    return p.filter(
                            p.rowExpression("cardinality(map_filter(idmap, (k, v) -> array_position(array_sort(array[random(), random()]), k) <= 200)) > 0"),
                            p.values(p.variable("idmap", ID_MAP_TYPE)));
                })
                .doesNotFire();
    }

    /**
     * The array built inside the lambda uses the lambda argument {@code v}; the project rule must not fire.
     */
    @Test
    public void testNoValidProjection()
    {
        tester().assertThat(new PullUpExpressionInLambdaRules(getFunctionManager()).projectNodeRule())
                .setSystemProperty(PULL_EXPRESSION_FROM_LAMBDA_ENABLED, "true")
                .on(p ->
                {
                    p.variable("idmap", ID_MAP_TYPE);
                    return p.project(
                            Assignments.builder().put(p.variable("expr"), p.rowExpression("map_filter(idmap, (k, v) -> array_position(array_sort(array[v]), k) <= 200)")).build(),
                            p.values(p.variable("idmap", ID_MAP_TYPE)));
                })
                .doesNotFire();
    }

    /**
     * The array built inside the lambda uses both lambda arguments; the filter rule must not fire.
     */
    @Test
    public void testNoValidFilter()
    {
        tester().assertThat(new PullUpExpressionInLambdaRules(getFunctionManager()).filterNodeRule())
                .setSystemProperty(PULL_EXPRESSION_FROM_LAMBDA_ENABLED, "true")
                .on(p ->
                {
                    p.variable("idmap", ID_MAP_TYPE);
                    return p.filter(
                            p.rowExpression("cardinality(map_filter(idmap, (k, v) -> array_position(array_sort(array[v, k]), k) <= 200)) > 0"),
                            p.values(p.variable("idmap", ID_MAP_TYPE)));
                })
                .doesNotFire();
    }

    /**
     * {@code slice(arr2, 1, 10)} sits inside two nested lambdas but uses neither {@code x} nor
     * {@code y}; the expected plan computes it in a child projection named {@code slice}.
     */
    @Test
    public void testNestedLambdaInProjection()
    {
        tester().assertThat(new PullUpExpressionInLambdaRules(getFunctionManager()).projectNodeRule())
                .setSystemProperty(PULL_EXPRESSION_FROM_LAMBDA_ENABLED, "true")
                .on(p ->
                {
                    p.variable("expr", BIGINT_ARRAY_OF_ARRAYS);
                    p.variable("arr1", BIGINT_ARRAY);
                    p.variable("arr2", BIGINT_ARRAY);
                    return p.project(
                            Assignments.builder().put(p.variable("expr", BIGINT_ARRAY_OF_ARRAYS), p.rowExpression("transform(arr1, x->transform(arr2, y->slice(arr2, 1, x)))".replace("1, x", "1, 10"))).build(),
                            p.values(p.variable("arr1", BIGINT_ARRAY), p.variable("arr2", BIGINT_ARRAY)));
                })
                .matches(
                        project(
                                ImmutableMap.of("expr", expression("transform(arr1, (x) -> transform(arr2, (y) -> slice))")),
                                project(
                                        ImmutableMap.of("slice", expression("slice(arr2, 1, 10)")),
                                        values("arr1", "arr2"))));
    }

    /**
     * Same shape as {@link #testNestedLambdaInProjection()}, but the inner {@code slice} call uses the
     * outer lambda argument {@code x}; the rule must not fire.
     */
    @Test
    public void testInvalidNestedLambdaInProjection()
    {
        tester().assertThat(new PullUpExpressionInLambdaRules(getFunctionManager()).projectNodeRule())
                .setSystemProperty(PULL_EXPRESSION_FROM_LAMBDA_ENABLED, "true")
                .on(p ->
                {
                    p.variable("expr", BIGINT_ARRAY_OF_ARRAYS);
                    p.variable("arr1", BIGINT_ARRAY);
                    p.variable("arr2", BIGINT_ARRAY);
                    return p.project(
                            Assignments.builder().put(p.variable("expr", BIGINT_ARRAY_OF_ARRAYS), p.rowExpression("transform(arr1, x->transform(arr2, y->slice(arr2, 1, x)))")).build(),
                            p.values(p.variable("arr1", BIGINT_ARRAY), p.variable("arr2", BIGINT_ARRAY)));
                })
                .doesNotFire();
    }

    /**
     * An expression containing {@code TRY(...)} must not trigger the rule.
     * NOTE(review): presumably TRY is represented as a lambda internally, which is why it is covered here - confirm.
     */
    @Test
    public void testSkipTryFunction()
    {
        tester().assertThat(new PullUpExpressionInLambdaRules(getFunctionManager()).projectNodeRule())
                .setSystemProperty(PULL_EXPRESSION_FROM_LAMBDA_ENABLED, "true")
                .on(p ->
                {
                    p.variable("x");
                    return p.project(
                            Assignments.builder().put(p.variable("expr", VARCHAR), p.rowExpression("JSON_FORMAT(CAST(TRY(MAP(ARRAY[NULL], ARRAY[x])) AS JSON))")).build(),
                            p.values(p.variable("x")));
                })
                .doesNotFire();
    }

    /**
     * In a searched CASE, the first WHEN condition ({@code arr2 is null}) is lambda-independent; the
     * expected plan computes it in a child projection as {@code expr_0}.
     */
    @Test
    public void testSwitchWhenExpression()
    {
        tester().assertThat(new PullUpExpressionInLambdaRules(getFunctionManager()).projectNodeRule())
                .setSystemProperty(PULL_EXPRESSION_FROM_LAMBDA_ENABLED, "true")
                .on(p ->
                {
                    p.variable("arr", VARCHAR_ARRAY);
                    p.variable("arr2", VARCHAR_ARRAY);
                    return p.project(
                            Assignments.builder().put(p.variable("expr", VARCHAR), p.rowExpression(
                                    "transform(arr, x -> concat(case when arr2 is null then '*' when contains(arr2, x) then '+' else ' ' end, x))")).build(),
                            p.values(p.variable("arr", VARCHAR_ARRAY), p.variable("arr2", VARCHAR_ARRAY)));
                })
                .matches(
                        project(
                                ImmutableMap.of("expr", expression("transform(arr, x -> concat(case when expr_0 then '*' when contains(arr2, x) then '+' else ' ' end, x))")),
                                project(
                                        ImmutableMap.of("expr_0", expression("arr2 is null")),
                                        values("arr", "arr2"))));
    }

    /**
     * The only candidate for extraction ({@code arr2 is null}) is the condition of the second WHEN
     * clause, so the rule is expected to skip it and not fire.
     */
    @Test
    public void testInvalidSwitchWhenExpression()
    {
        tester().assertThat(new PullUpExpressionInLambdaRules(getFunctionManager()).projectNodeRule())
                .setSystemProperty(PULL_EXPRESSION_FROM_LAMBDA_ENABLED, "true")
                .on(p ->
                {
                    p.variable("arr", VARCHAR_ARRAY);
                    p.variable("arr2", VARCHAR_ARRAY);
                    return p.project(
                            Assignments.builder().put(p.variable("expr", VARCHAR), p.rowExpression(
                                    "transform(arr, x -> concat(case when contains(arr2, x) then '*' when arr2 is null then '+' else ' ' end, x))")).build(),
                            p.values(p.variable("arr", VARCHAR_ARRAY), p.variable("arr2", VARCHAR_ARRAY)));
                })
                .doesNotFire();
    }

    /**
     * In a simple CASE, both the operand ({@code col1 > 2}) and the first WHEN value
     * ({@code arr2 is null}) are expected to be pulled up, as {@code expr_1} and {@code expr_0}.
     */
    @Test
    public void testCaseWhenExpression()
    {
        tester().assertThat(new PullUpExpressionInLambdaRules(getFunctionManager()).projectNodeRule())
                .setSystemProperty(PULL_EXPRESSION_FROM_LAMBDA_ENABLED, "true")
                .on(p ->
                {
                    p.variable("arr", VARCHAR_ARRAY);
                    p.variable("arr2", VARCHAR_ARRAY);
                    p.variable("col1");
                    return p.project(
                            Assignments.builder().put(p.variable("expr", VARCHAR), p.rowExpression(
                                    "transform(arr, x -> concat(case (col1 > 2) when arr2 is null then '*' when contains(arr2, x) then '+' else ' ' end, x))")).build(),
                            p.values(p.variable("arr", VARCHAR_ARRAY), p.variable("arr2", VARCHAR_ARRAY), p.variable("col1")));
                })
                .matches(
                        project(
                                ImmutableMap.of("expr", expression("transform(arr, x -> concat(case expr_1 when expr_0 then '*' when contains(arr2, x) then '+' else ' ' end, x))")),
                                project(
                                        ImmutableMap.of("expr_0", expression("arr2 is null"), "expr_1", expression("col1>2")),
                                        values("arr", "arr2", "col1"))));
    }

    /**
     * {@code col2[2]} appears only in the value branch of an {@code if} whose condition is the lambda
     * argument; the rule must not fire.
     */
    @Test
    public void testConditionalExpression()
    {
        tester().assertThat(new PullUpExpressionInLambdaRules(getFunctionManager()).projectNodeRule())
                .setSystemProperty(PULL_EXPRESSION_FROM_LAMBDA_ENABLED, "true")
                .on(p ->
                {
                    p.variable("col1", BOOLEAN_ARRAY);
                    p.variable("col2", BIGINT_ARRAY);
                    return p.project(
                            Assignments.builder().put(p.variable("expr", VARCHAR), p.rowExpression(
                                    "transform(col1, x -> if(x, col2[2], 0))")).build(),
                            p.values(p.variable("col1", BOOLEAN_ARRAY), p.variable("col2", BIGINT_ARRAY)));
                })
                .doesNotFire();
    }

    /**
     * The {@code if} condition ({@code col3 > 2}) is lambda-independent and is expected to be pulled
     * up as {@code greater_than}; the value branch {@code col2[2]} stays inside the lambda.
     */
    @Test
    public void testIfExpressionOnCondition()
    {
        tester().assertThat(new PullUpExpressionInLambdaRules(getFunctionManager()).projectNodeRule())
                .setSystemProperty(PULL_EXPRESSION_FROM_LAMBDA_ENABLED, "true")
                .on(p ->
                {
                    p.variable("col1", BOOLEAN_ARRAY);
                    p.variable("col2", BIGINT_ARRAY);
                    p.variable("col3");
                    return p.project(
                            Assignments.builder().put(p.variable("expr", VARCHAR), p.rowExpression(
                                    "transform(col1, x -> if(col3 > 2, col2[2], 0))")).build(),
                            p.values(p.variable("col1", BOOLEAN_ARRAY), p.variable("col2", BIGINT_ARRAY), p.variable("col3")));
                })
                .matches(
                        project(
                                ImmutableMap.of("expr", expression("transform(col1, x -> if(greater_than, col2[2], 0))")),
                                project(
                                        ImmutableMap.of("greater_than", expression("col3>2")),
                                        values("col1", "col2", "col3"))));
    }

    /**
     * {@code col3 - 2} appears only in the value branch of an {@code if} whose condition is the lambda
     * argument; the rule must not fire.
     */
    @Test
    public void testIfExpressionOnValue()
    {
        tester().assertThat(new PullUpExpressionInLambdaRules(getFunctionManager()).projectNodeRule())
                .setSystemProperty(PULL_EXPRESSION_FROM_LAMBDA_ENABLED, "true")
                .on(p ->
                {
                    p.variable("col1", BOOLEAN_ARRAY);
                    p.variable("col2", BIGINT_ARRAY);
                    p.variable("col3");
                    return p.project(
                            Assignments.builder().put(p.variable("expr", VARCHAR), p.rowExpression(
                                    "transform(col1, x -> if(x, col3 - 2, 0))")).build(),
                            p.values(p.variable("col1", BOOLEAN_ARRAY), p.variable("col2", BIGINT_ARRAY), p.variable("col3")));
                })
                .doesNotFire();
    }

    /**
     * A lambda whose body is just the subscript {@code col2[2]} must not trigger the rule.
     */
    @Test
    public void testSubscriptExpression()
    {
        tester().assertThat(new PullUpExpressionInLambdaRules(getFunctionManager()).projectNodeRule())
                .setSystemProperty(PULL_EXPRESSION_FROM_LAMBDA_ENABLED, "true")
                .on(p ->
                {
                    p.variable("col1", BOOLEAN_ARRAY);
                    p.variable("col2", BIGINT_ARRAY);
                    p.variable("col3");
                    return p.project(
                            Assignments.builder().put(p.variable("expr", VARCHAR), p.rowExpression(
                                    "transform(col1, x -> col2[2])")).build(),
                            p.values(p.variable("col1", BOOLEAN_ARRAY), p.variable("col2", BIGINT_ARRAY), p.variable("col3")));
                })
                .doesNotFire();
    }

    /**
     * The pattern operand of {@code LIKE} ({@code concat(col, 'a')}) is lambda-independent; the
     * expected plan computes it in a child projection as {@code concat_1}.
     */
    @Test
    public void testLikeExpression()
    {
        tester().assertThat(new PullUpExpressionInLambdaRules(getFunctionManager()).projectNodeRule())
                .setSystemProperty(PULL_EXPRESSION_FROM_LAMBDA_ENABLED, "true")
                .on(p ->
                {
                    p.variable("expr", BOOLEAN_ARRAY);
                    p.variable("col", VARCHAR);
                    p.variable("arr1", VARCHAR_ARRAY);
                    return p.project(
                            Assignments.builder().put(p.variable("expr", BOOLEAN_ARRAY), p.rowExpression("transform(arr1, x-> x like concat(col, 'a'))")).build(),
                            p.values(p.variable("arr1", VARCHAR_ARRAY), p.variable("col", VARCHAR)));
                })
                .matches(
                        project(
                                ImmutableMap.of("expr", expression("transform(arr1, x -> x like concat_1)")),
                                project(
                                        ImmutableMap.of("concat_1", expression("concat(col, 'a')")),
                                        values("arr1", "col"))));
    }

    /**
     * Same as {@link #testLikeExpression()} but with {@code regexp_like}: the pattern argument is
     * expected to be pulled up as {@code concat_1}.
     */
    @Test
    public void testRegexpLikeExpression()
    {
        tester().assertThat(new PullUpExpressionInLambdaRules(getFunctionManager()).projectNodeRule())
                .setSystemProperty(PULL_EXPRESSION_FROM_LAMBDA_ENABLED, "true")
                .on(p ->
                {
                    p.variable("expr", BOOLEAN_ARRAY);
                    p.variable("col", VARCHAR);
                    p.variable("arr1", VARCHAR_ARRAY);
                    return p.project(
                            Assignments.builder().put(p.variable("expr", BOOLEAN_ARRAY), p.rowExpression("transform(arr1, x-> regexp_like(x, concat(col, 'a')))")).build(),
                            p.values(p.variable("arr1", VARCHAR_ARRAY), p.variable("col", VARCHAR)));
                })
                .matches(
                        project(
                                ImmutableMap.of("expr", expression("transform(arr1, x -> regexp_like(x, concat_1))")),
                                project(
                                        ImmutableMap.of("concat_1", expression("concat(col, 'a')")),
                                        values("arr1", "col"))));
    }
}