TestRewriteCaseExpressionPredicate.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.sql.TestingRowExpressionTranslator;
import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.Test;

import java.util.Map;

import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_CASE_EXPRESSION_PREDICATE;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values;

/**
 * Rule tests for {@link RewriteCaseExpressionPredicate#filterRowExpressionRewriteRule()}.
 * <p>
 * As the expected expressions below show, a comparison against a CASE expression is rewritten
 * into an OR with one term per branch. Each term ANDs the branch's result comparison with the
 * condition that selects that branch. The ELSE (or implicit NULL) branch is guarded by the
 * negation of every WHEN condition.
 */
public class TestRewriteCaseExpressionPredicate
        extends BaseRuleTest
{
    private static final MetadataManager METADATA = createTestMetadataManager();
    // Column types used when translating SQL predicates; the source ValuesNode in the helpers below
    // declares the same names and types so the filter's variables match the source's outputs.
    private static final Map<String, Type> TYPE_MAP = ImmutableMap.of("col1", INTEGER, "col2", INTEGER, "col3", VARCHAR);

    private final TestingRowExpressionTranslator testSqlToRowExpressionTranslator = new TestingRowExpressionTranslator();

    @Test
    public void testRewriterDoesNotFireOnPredicateWithoutCaseExpression()
    {
        assertRewriteDoesNotFire("col1 > 1");
    }

    @Test
    public void testRewriterDoesNotFireOnPredicateWithoutComparisonFunction()
    {
        assertRewriteDoesNotFire("(case when col1=1 then 'case1' when col2=2 then 'case2' else 'default' end)");
    }

    @Test
    public void testRewriterDoesNotFireOnPredicateWithFunctionCallOnComparisonValue()
    {
        assertRewriteDoesNotFire("(case when col1=1 then 'case1' when col2=2 then 'case2' else 'default' end) = upper('case1')");
        assertRewriteDoesNotFire("(case when col1=1 then 10 when col2=2 then 20 else 30 end) = ceil(col1)");
    }

    @Test
    public void testRewriterDoesNotFireOnInvalidSearchCaseExpression()
    {
        // Not all WHEN conditions compare the same left-hand expression
        assertRewriteDoesNotFire("(case when col1=1 then 'case1' when col2=2 then 'case2' else 'default' end) = 'case1'");
        assertRewriteDoesNotFire("(case when col1=1 then 'case1' when ceil(col1)=2 then 'case2' else 'default' end) = 'case1'");

        // WHEN conditions are not equality comparisons
        assertRewriteDoesNotFire("(case when col1>1 then 1 when col1>2 then 2 else 3 end) > 2");
        assertRewriteDoesNotFire("(case when col1<1 then 1 when col1<2 then 2 else 3 end) < 2");

        // Not all right-hand values of the WHEN conditions are constants
        assertRewriteDoesNotFire("(case when col1=1 then 'case1' when col1=ceil(1) then 'case2' else 'default' end) = 'case1'");

        // Right-hand values of the WHEN conditions are not unique
        assertRewriteDoesNotFire("(case when col1=1 then 'case1' when col1=1 then 'case2' else 'default' end) = 'case1'");
    }

    @Test
    public void testSimpleCaseExpressionRewrite()
    {
        assertRewrittenExpression(
                "(case col1 when 1 then 'case1' when 2 then 'case2' else 'default' end) = 'case1'",
                "('case1' = 'case1' AND col1 = 1) OR ('case2' = 'case1' AND col1 = 2) OR ('default' = 'case1' AND (NOT(col1 = 1) AND NOT(col1 = 2)))");

        assertRewrittenExpression(
                "(case col1 when 1 then 'case1' when 2 then 'case2' else 'default' end) = 'case2'",
                "('case1' = 'case2' AND col1 = 1) OR ('case2' = 'case2' AND col1 = 2) OR ('default' = 'case2' AND (NOT(col1 = 1) AND NOT(col1 = 2)))");

        assertRewrittenExpression(
                "(case col1 when 1 then 'case1' when 2 then 'case2' else 'default' end) = 'default'",
                "('case1' = 'default' AND col1 = 1) OR ('case2' = 'default' AND col1 = 2) OR ('default' = 'default' AND (NOT(col1 = 1) AND NOT(col1 = 2)))");
    }

    @Test
    public void testSearchedCaseExpressionRewrite()
    {
        assertRewrittenExpression(
                "(case when col1=1 then 'case1' when col1=2 then 'case2' else 'default' end) = 'case1'",
                "('case1' = 'case1' AND col1 = 1) OR ('case2' = 'case1' AND col1 = 2) OR ('default' = 'case1' AND (NOT(col1 = 1) AND NOT(col1 = 2)))");

        assertRewrittenExpression(
                "(case when lower(col3)='a' then 'case1' when lower(col3)='b' then 'case2' else 'default' end) = 'case1'",
                "('case1' = 'case1' AND lower(col3) = 'a') OR ('case2' = 'case1' AND lower(col3) = 'b') OR ('default' = 'case1' AND (NOT(lower(col3) = 'a') AND NOT(lower(col3) = 'b')))");

        assertRewrittenExpression(
                "(case when ceil(col1)=1 then 'case1' when ceil(col1)=2 then 'case2' else 'default' end) = 'default'",
                "('case1' = 'default' AND ceil(col1) = 1) OR ('case2' = 'default' AND ceil(col1) = 2) OR ('default' = 'default' AND (NOT(ceil(col1) = 1) AND NOT(ceil(col1) = 2)))");
    }

    @Test
    public void testRewriterOnCaseExpressionInRightSideOfComparisonFunction()
    {
        assertRewrittenExpression(
                "(case col1 when 1 then 10 when 2 then 20 else 30 end) > 20",
                "(10 > 20 AND col1 = 1) OR (20 > 20 AND col1 = 2) OR (30 > 20 AND (NOT(col1 = 1) AND NOT(col1 = 2)))");

        // Operand order is preserved when the CASE expression is on the right-hand side
        assertRewrittenExpression(
                "25 < (case col1 when 1 then 10 when 2 then 20 else 30 end)",
                "(25 < 10 AND col1 = 1) OR (25 < 20 AND col1 = 2) OR (25 < 30 AND (NOT(col1 = 1) AND NOT(col1 = 2)))");
    }

    @Test
    public void testRewriterWhenMoreThanOneConditionMatches()
    {
        assertRewrittenExpression(
                "(case col1 when 1 then 'case' when 2 then 'case' else 'default' end) = 'case'",
                "('case' = 'case' AND col1 = 1) OR ('case' = 'case' AND col1 = 2) OR ('default' = 'case' AND (NOT(col1 = 1) AND NOT(col1 = 2)))");

        assertRewrittenExpression(
                "(case col1 when 1 then concat('default', 'AndCase1') when 2 then 'case2' else 'defaultAndCase1' end) = 'defaultAndCase1'",
                "(concat('default', 'AndCase1') = 'defaultAndCase1' AND col1 = 1) OR ('case2' = 'defaultAndCase1' AND col1 = 2) OR ('defaultAndCase1' = 'defaultAndCase1' AND (NOT(col1 = 1) AND NOT(col1 = 2)))");

        assertRewrittenExpression(
                "(case col3 when 'data1' then 'case1' when 'data2' then 'case2' else col3 end) = 'case1'",
                "('case1' = 'case1' AND col3 = 'data1') OR ('case2' = 'case1' AND col3 = 'data2') OR (col3 = 'case1' AND (NOT(col3 = 'data1') AND NOT(col3 = 'data2')))");
    }

    @Test
    public void testRewriterOnCaseExpressionWithoutElseClause()
    {
        // A missing ELSE clause is rewritten as an ELSE NULL branch
        assertRewrittenExpression(
                "(case col1 when 1 then 'case1' when 2 then 'case2' end) = 'case1'",
                "('case1' = 'case1' AND col1 = 1) OR ('case2' = 'case1' AND col1 = 2) OR (null = 'case1' AND (NOT(col1 = 1) AND NOT(col1 = 2)))");

        assertRewrittenExpression(
                "(case col1 when 1 then 'case1' when 2 then 'case2' end) = 'case3'",
                "('case1' = 'case3' AND col1 = 1) OR ('case2' = 'case3' AND col1 = 2) OR (null = 'case3' AND (NOT(col1 = 1) AND NOT(col1 = 2)))");

        assertRewrittenExpression(
                "(case col1 when 1 then 'case1' when 2 then 'case2' end) = 'case2'",
                "('case1' = 'case2' AND col1 = 1) OR ('case2' = 'case2' AND col1 = 2) OR (null = 'case2' AND (NOT(col1 = 1) AND NOT(col1 = 2)))");
    }

    @Test
    public void testRewriterOnCaseExpressionWithCastFunction()
    {
        // A CAST wrapping the CASE expression is pushed into every branch result
        assertRewrittenExpression(
                "cast((case col1 when 1 then 'case11' when 2 then 'case2' else 'def' end) as VARCHAR(6)) = 'case11'",
                "(cast('case11' as VARCHAR(6)) = 'case11' AND col1 = 1) OR (cast('case2' as VARCHAR(6)) = 'case11' AND col1 = 2) OR (cast('def' as VARCHAR(6)) = 'case11' AND (NOT(col1 = 1) AND NOT(col1 = 2)))");

        // A CAST on the comparison value is kept unchanged in every branch term
        assertRewrittenExpression(
                "(case col1 when 1 then 'case1' when 2 then 'case2' else 'default' end) = cast('case1' AS VARCHAR)",
                "('case1' = cast('case1' AS VARCHAR) AND col1 = 1) OR ('case2' = cast('case1' AS VARCHAR) AND col1 = 2) OR ('default' = cast('case1' AS VARCHAR) AND (NOT(col1 = 1) AND NOT(col1 = 2)))");

        // A CAST inside a WHEN condition is kept unchanged in the branch-selecting condition
        assertRewrittenExpression(
                "(case when col1=cast('1' as INTEGER) then 'case1' when col1=cast('2' as INTEGER) then 'case2' else 'default' end) = 'case1'",
                "('case1' = 'case1' AND col1 = cast('1' as INTEGER)) OR ('case2' = 'case1' AND col1 = cast('2' as INTEGER)) OR ('default' = 'case1' AND (NOT(col1 = cast('1' as INTEGER)) AND NOT(col1 = cast('2' as INTEGER))))");
    }

    @Test
    public void testIfSubExpressionsAreRewritten()
    {
        assertRewrittenExpression(
                "((case col1 when 1 then 'a' else 'b' end) = 'a') = true",
                "(('a' = 'a' AND col1 = 1) OR ('b' = 'a' AND NOT(col1=1))) = true");
    }

    /**
     * Asserts that the rule leaves a filter with the given predicate untouched.
     *
     * @param expression SQL predicate over col1, col2 and col3, typed according to {@code TYPE_MAP}
     */
    private void assertRewriteDoesNotFire(String expression)
    {
        RowExpression predicate = testSqlToRowExpressionTranslator.translate(expression, TYPE_MAP);

        tester().assertThat(new RewriteCaseExpressionPredicate(METADATA.getFunctionAndTypeManager()).filterRowExpressionRewriteRule())
                .setSystemProperty(OPTIMIZE_CASE_EXPRESSION_PREDICATE, "true")
                // Same source shape as assertRewrittenExpression so both helpers exercise identical plans
                .on(p -> p.filter(predicate, p.values(p.variable("col1", INTEGER), p.variable("col2", INTEGER), p.variable("col3", VARCHAR))))
                .doesNotFire();
    }

    /**
     * Asserts that the rule rewrites a filter with the given predicate into a filter matching the expected predicate.
     *
     * @param inputExpressionStr SQL predicate over col1, col2 and col3, typed according to {@code TYPE_MAP}
     * @param expectedExpressionStr SQL form of the predicate the rewritten filter must match
     */
    private void assertRewrittenExpression(String inputExpressionStr, String expectedExpressionStr)
    {
        RowExpression inputExpression = testSqlToRowExpressionTranslator.translate(inputExpressionStr, TYPE_MAP);

        tester().assertThat(new RewriteCaseExpressionPredicate(METADATA.getFunctionAndTypeManager()).filterRowExpressionRewriteRule())
                .setSystemProperty(OPTIMIZE_CASE_EXPRESSION_PREDICATE, "true")
                // Declare source variables with the TYPE_MAP types rather than PlanBuilder's BIGINT default,
                // so they agree with the variables referenced by the translated predicate
                .on(p -> p.filter(inputExpression, p.values(p.variable("col1", INTEGER), p.variable("col2", INTEGER), p.variable("col3", VARCHAR))))
                .matches(filter(expectedExpressionStr, values("col1", "col2", "col3")));
    }
}