TestCommonSubExpressionRewriter.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.gen;

import com.facebook.presto.Session;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.relational.SqlToRowExpressionTranslator;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.NodeRef;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.Test;

import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences;
import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes;
import static com.facebook.presto.sql.gen.CommonSubExpressionRewriter.collectCSEByLevel;
import static com.facebook.presto.sql.gen.CommonSubExpressionRewriter.getExpressionsPartitionedByCSE;
import static com.facebook.presto.sql.gen.CommonSubExpressionRewriter.rewriteExpressionWithCSE;
import static org.testng.Assert.assertEquals;

public class TestCommonSubExpressionRewriter
{
    private static final Session SESSION = TEST_SESSION;
    private static final Metadata METADATA = MetadataManager.createTestMetadataManager();
    private static final TypeProvider TYPES = TypeProvider.viewOf(
            ImmutableMap.<String, Type>builder()
                    .put("x", BIGINT)
                    .put("y", BIGINT)
                    .put("z", BIGINT)
                    .put("add$cse", BIGINT)
                    .put("multiply$cse", BIGINT)
                    .put("add$cse_0", BIGINT)
                    .put("expr$cse", BIGINT).build());

    @Test
    void testGetExpressionsWithCSE()
    {
        List<RowExpression> expressions = ImmutableList.of(rowExpression("x + y"), rowExpression("(x + y) * 2"), rowExpression("x + 2"), rowExpression("y * (x + 2)"), rowExpression("x * y"));
        Map<List<RowExpression>, Boolean> expressionsWithCSE = getExpressionsPartitionedByCSE(expressions, 3);
        assertEquals(
                expressionsWithCSE,
                ImmutableMap.of(
                        ImmutableList.of(rowExpression("x + y"), rowExpression("(x + y) * 2")), true,
                        ImmutableList.of(rowExpression("x + 2"), rowExpression("y * (x + 2)")), true,
                        ImmutableList.of(rowExpression("x * y")), false));
        expressions = ImmutableList.of(rowExpression("x + y"), rowExpression("x * 2"), rowExpression("x + y + x * 2"), rowExpression("y * 2"), rowExpression("x + y * 2"));
        expressionsWithCSE = getExpressionsPartitionedByCSE(expressions, 3);
        assertEquals(
                expressionsWithCSE,
                ImmutableMap.of(
                        ImmutableList.of(rowExpression("x + y"), rowExpression("x + y + x * 2"), rowExpression("x * 2")), true,
                        ImmutableList.of(rowExpression("y * 2"), rowExpression("x + y * 2")), true));

        expressionsWithCSE = getExpressionsPartitionedByCSE(expressions, 2);
        assertEquals(
                expressionsWithCSE,
                ImmutableMap.of(
                        ImmutableList.of(rowExpression("x + y"), rowExpression("x + y + x * 2")), true,
                        ImmutableList.of(rowExpression("y * 2"), rowExpression("x + y * 2")), true,
                        ImmutableList.of(rowExpression("x * 2")), true));
    }

    @Test
    void testCollectCSEByLevel()
    {
        List<RowExpression> expressions = ImmutableList.of(rowExpression("x * 2 + y + z"), rowExpression("(x * 2 + y + 1) * 2"), rowExpression("(x * 2)  + (x * 2 + y + z)"));
        Map<Integer, Map<RowExpression, VariableReferenceExpression>> cseByLevel = collectCSEByLevel(expressions);
        assertEquals(cseByLevel, ImmutableMap.of(
                3, ImmutableMap.of(rowExpression("\"add$cse\" + z"), rowExpression("\"add$cse_0\"")),
                2, ImmutableMap.of(rowExpression("\"multiply$cse\" + y"), rowExpression("\"add$cse\"")),
                1, ImmutableMap.of(rowExpression("x * 2"), rowExpression("\"multiply$cse\""))));
    }

    @Test
    void testCollectCSEByLevelCaseStatement()
    {
        List<RowExpression> expressions = ImmutableList.of(rowExpression("1 + CASE WHEN x = 1 THEN y + z WHEN x = 2 THEN z * 2 END"), rowExpression("2 + CASE WHEN x = 1 THEN y + z WHEN x = 2 THEN z * 2 END"));
        Map<Integer, Map<RowExpression, VariableReferenceExpression>> cseByLevel = collectCSEByLevel(expressions);
        assertEquals(cseByLevel, ImmutableMap.of(3, ImmutableMap.of(rowExpression("CASE WHEN x = 1 THEN y + z WHEN x = 2 THEN z * 2 END"), rowExpression("\"expr$cse\""))));
    }

    @Test
    void testNoRedundantCSE()
    {
        List<RowExpression> expressions = ImmutableList.of(rowExpression("x * 2 + y + z"), rowExpression("(x * 2 + y + z) * 2"), rowExpression("x * 2"));
        Map<Integer, Map<RowExpression, VariableReferenceExpression>> cseByLevel = collectCSEByLevel(expressions);
        // x * 2 + y is redundant thus should not appear in results
        assertEquals(cseByLevel, ImmutableMap.of(
                3, ImmutableMap.of(rowExpression("\"multiply$cse\" + y + z"), rowExpression("\"add$cse\"")),
                1, ImmutableMap.of(rowExpression("x * 2"), rowExpression("\"multiply$cse\""))));
    }

    @Test
    void testRewriteExpressionWithCSE()
    {
        assertEquals(
                rewriteExpressionWithCSE(
                        rowExpression("(x * y + z) * (y + z) + (x * y)"),
                        ImmutableMap.of(
                                rowExpression("x * y"), variable("multiply$cse"),
                                rowExpression("y + z"), variable("add$cse"),
                                rowExpression("\"multiply$cse\" + z"), variable("add$cse_0"))),
                rowExpression("\"add$cse_0\" * \"add$cse\" + \"multiply$cse\""));
    }

    private VariableReferenceExpression variable(String variable)
    {
        return new VariableReferenceExpression(Optional.empty(), variable, TYPES.allTypes().get(variable));
    }

    private RowExpression rowExpression(String sql)
    {
        Expression expression = rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(sql));
        Map<NodeRef<Expression>, Type> expressionTypes = getExpressionTypes(
                SESSION,
                METADATA,
                new SqlParser(),
                TYPES,
                expression,
                ImmutableMap.of(),
                WarningCollector.NOOP);
        return SqlToRowExpressionTranslator.translate(
                expression,
                expressionTypes,
                ImmutableMap.of(),
                METADATA.getFunctionAndTypeManager(),
                SESSION);
    }
}