TestJdbcComputePushdown.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.plugin.jdbc.optimization;

import com.facebook.presto.Session;
import com.facebook.presto.common.predicate.TupleDomain;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.StatsAndCosts;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.plugin.jdbc.JdbcColumnHandle;
import com.facebook.presto.plugin.jdbc.JdbcTableHandle;
import com.facebook.presto.plugin.jdbc.JdbcTableLayoutHandle;
import com.facebook.presto.plugin.jdbc.JdbcTypeHandle;
import com.facebook.presto.plugin.jdbc.optimization.function.OperatorTranslators;
import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.spi.ConnectorId;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.SchemaTableName;
import com.facebook.presto.spi.TableHandle;
import com.facebook.presto.spi.connector.ConnectorTransactionHandle;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.DeterminismEvaluator;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.TestingRowExpressionTranslator;
import com.facebook.presto.sql.planner.Plan;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.assertions.MatchResult;
import com.facebook.presto.sql.planner.assertions.Matcher;
import com.facebook.presto.sql.planner.assertions.PlanAssert;
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import com.facebook.presto.sql.planner.assertions.SymbolAliases;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.facebook.presto.sql.relational.RowExpressionOptimizer;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.testing.TestingConnectorSession;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.testng.annotations.Test;

import java.sql.Types;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node;
import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.function.Function.identity;
import static java.util.stream.Collectors.toMap;
import static org.testng.Assert.assertTrue;

public class TestJdbcComputePushdown
{
    private static final Metadata METADATA = MetadataManager.createTestMetadataManager();
    private static final PlanBuilder PLAN_BUILDER = new PlanBuilder(TEST_SESSION, new PlanNodeIdAllocator(), METADATA);
    private static final PlanNodeIdAllocator ID_ALLOCATOR = new PlanNodeIdAllocator();
    private static final String CATALOG_NAME = "Jdbc";
    private static final String CONNECTOR_ID = new ConnectorId(CATALOG_NAME).toString();

    private final TestingRowExpressionTranslator sqlToRowExpressionTranslator;

    private final JdbcComputePushdown jdbcComputePushdown;

    public TestJdbcComputePushdown()
    {
        this.sqlToRowExpressionTranslator = new TestingRowExpressionTranslator(METADATA);
        FunctionAndTypeManager functionAndTypeManager = METADATA.getFunctionAndTypeManager();
        StandardFunctionResolution functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
        DeterminismEvaluator determinismEvaluator = new RowExpressionDeterminismEvaluator(functionAndTypeManager);

        this.jdbcComputePushdown = new JdbcComputePushdown(
                functionAndTypeManager,
                functionResolution,
                determinismEvaluator,
                (ConnectorSession session) -> new RowExpressionOptimizer(METADATA),
                "'",
                getFunctionTranslators());
    }

    @Test
    public void testJdbcComputePushdownAll()
    {
        String table = "test_table";
        String schema = "test_schema";

        String expression = "(c1 + c2) - c2";
        TypeProvider typeProvider = TypeProvider.copyOf(ImmutableMap.of("c1", BIGINT, "c2", BIGINT));
        RowExpression rowExpression = sqlToRowExpressionTranslator.translateAndOptimize(expression(expression), typeProvider);
        Set<ColumnHandle> columns = Stream.of("c1", "c2").map(TestJdbcComputePushdown::integerJdbcColumnHandle).collect(Collectors.toSet());
        PlanNode original = filter(jdbcTableScan(schema, table, BIGINT, "c1", "c2"), rowExpression);

        JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table);
        ConnectorSession session = new TestingConnectorSession(ImmutableList.of());
        JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle(session.getSqlFunctionProperties(), jdbcTableHandle, TupleDomain.none(), Optional.of(new JdbcExpression("(('c1' + 'c2') - 'c2')")));

        PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR);
        assertPlanMatch(
                actual,
                PlanMatchPattern.filter(expression, JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns)));
    }

    @Test
    public void testJdbcComputePushdownBooleanOperations()
    {
        String table = "test_table";
        String schema = "test_schema";

        String expression = "(((c1 + c2) - c2 <> c2) OR c2 = c1) AND c1 <> c2";
        TypeProvider typeProvider = TypeProvider.copyOf(ImmutableMap.of("c1", BIGINT, "c2", BIGINT));
        RowExpression rowExpression = sqlToRowExpressionTranslator.translateAndOptimize(expression(expression), typeProvider);
        Set<ColumnHandle> columns = Stream.of("c1", "c2").map(TestJdbcComputePushdown::integerJdbcColumnHandle).collect(Collectors.toSet());
        PlanNode original = filter(jdbcTableScan(schema, table, BIGINT, "c1", "c2"), rowExpression);

        JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table);
        ConnectorSession session = new TestingConnectorSession(ImmutableList.of());
        JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle(
                session.getSqlFunctionProperties(),
                jdbcTableHandle,
                TupleDomain.none(),
                Optional.of(new JdbcExpression("((((((('c1' + 'c2') - 'c2') <> 'c2')) OR (('c2' = 'c1')))) AND (('c1' <> 'c2')))")));

        PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR);
        assertPlanMatch(
                actual,
                PlanMatchPattern.filter(expression, JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns)));
    }

    @Test
    public void testJdbcComputePushdownUnsupported()
    {
        String table = "test_table";
        String schema = "test_schema";

        String expression = "(c1 + c2) > c2";
        TypeProvider typeProvider = TypeProvider.copyOf(ImmutableMap.of("c1", BIGINT, "c2", BIGINT));
        RowExpression rowExpression = sqlToRowExpressionTranslator.translateAndOptimize(expression(expression), typeProvider);
        Set<ColumnHandle> columns = Stream.of("c1", "c2").map(TestJdbcComputePushdown::integerJdbcColumnHandle).collect(Collectors.toSet());
        PlanNode original = filter(jdbcTableScan(schema, table, BIGINT, "c1", "c2"), rowExpression);

        JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table);
        ConnectorSession session = new TestingConnectorSession(ImmutableList.of());
        // Test should expect an empty entry for translatedSql since > is an unsupported function currently in the optimizer
        JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle(session.getSqlFunctionProperties(), jdbcTableHandle, TupleDomain.none(), Optional.empty());

        PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR);
        assertPlanMatch(actual, PlanMatchPattern.filter(
                expression,
                JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns)));
    }

    @Test
    public void testJdbcComputePushdownWithConstants()
    {
        String table = "test_table";
        String schema = "test_schema";

        String expression = "(c1 + c2) = 3";
        TypeProvider typeProvider = TypeProvider.copyOf(ImmutableMap.of("c1", BIGINT, "c2", BIGINT));
        RowExpression rowExpression = sqlToRowExpressionTranslator.translateAndOptimize(expression(expression), typeProvider);
        Set<ColumnHandle> columns = Stream.of("c1", "c2").map(TestJdbcComputePushdown::integerJdbcColumnHandle).collect(Collectors.toSet());
        PlanNode original = filter(jdbcTableScan(schema, table, BIGINT, "c1", "c2"), rowExpression);

        JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table);
        ConnectorSession session = new TestingConnectorSession(ImmutableList.of());
        JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle(
                session.getSqlFunctionProperties(),
                jdbcTableHandle,
                TupleDomain.none(),
                Optional.of(new JdbcExpression("(('c1' + 'c2') = ?)", ImmutableList.of(new ConstantExpression(Long.valueOf(3), INTEGER)))));

        PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR);
        assertPlanMatch(actual, PlanMatchPattern.filter(
                expression,
                JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns)));
    }

    @Test
    public void testJdbcComputePushdownNotOperator()
    {
        String table = "test_table";
        String schema = "test_schema";

        String expression = "c1 AND NOT(c2)";
        TypeProvider typeProvider = TypeProvider.copyOf(ImmutableMap.of("c1", BOOLEAN, "c2", BOOLEAN));
        RowExpression rowExpression = sqlToRowExpressionTranslator.translateAndOptimize(expression(expression), typeProvider);
        PlanNode original = filter(jdbcTableScan(schema, table, BOOLEAN, "c1", "c2"), rowExpression);

        Set<ColumnHandle> columns = Stream.of("c1", "c2").map(TestJdbcComputePushdown::booleanJdbcColumnHandle).collect(Collectors.toSet());
        JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table);
        ConnectorSession session = new TestingConnectorSession(ImmutableList.of());
        JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle(
                session.getSqlFunctionProperties(),
                jdbcTableHandle,
                TupleDomain.none(),
                Optional.of(new JdbcExpression("(('c1') AND ((NOT('c2'))))")));

        PlanNode actual = this.jdbcComputePushdown.optimize(original, session, null, ID_ALLOCATOR);
        assertPlanMatch(actual, PlanMatchPattern.filter(
                expression,
                JdbcTableScanMatcher.jdbcTableScanPattern(jdbcTableLayoutHandle, columns)));
    }

    private Set<Class<?>> getFunctionTranslators()
    {
        return ImmutableSet.of(OperatorTranslators.class);
    }

    private static VariableReferenceExpression newVariable(String name, Type type)
    {
        return new VariableReferenceExpression(Optional.empty(), name, type);
    }

    private static JdbcColumnHandle integerJdbcColumnHandle(String name)
    {
        return new JdbcColumnHandle(CONNECTOR_ID, name, new JdbcTypeHandle(Types.BIGINT, "integer", 10, 0), BIGINT, false, Optional.empty());
    }

    private static JdbcColumnHandle booleanJdbcColumnHandle(String name)
    {
        return new JdbcColumnHandle(CONNECTOR_ID, name, new JdbcTypeHandle(Types.BOOLEAN, "boolean", 1, 0), BOOLEAN, false, Optional.empty());
    }

    private static JdbcColumnHandle getColumnHandleForVariable(String name, Type type)
    {
        return type.equals(BOOLEAN) ? booleanJdbcColumnHandle(name) : integerJdbcColumnHandle(name);
    }

    private static void assertPlanMatch(PlanNode actual, PlanMatchPattern expected)
    {
        assertPlanMatch(actual, expected, TypeProvider.empty());
    }

    private static void assertPlanMatch(PlanNode actual, PlanMatchPattern expected, TypeProvider typeProvider)
    {
        // always check the actual plan node has a filter node to prevent accidentally missing predicates
        Queue<PlanNode> nodes = new LinkedList<>(ImmutableSet.of(actual));
        boolean hasFilterNode = false;
        while (!nodes.isEmpty()) {
            PlanNode node = nodes.poll();
            nodes.addAll(node.getSources());

            // we always enforce the original filter to be present
            // the following assertPlan will guarantee the completeness of the filter
            if (node instanceof FilterNode && node.getSources().get(0) instanceof TableScanNode) {
                hasFilterNode = true;
                break;
            }
        }
        assertTrue(hasFilterNode, "filter is missing from the pushdown plan");

        PlanAssert.assertPlan(
                TEST_SESSION,
                METADATA,
                (node, sourceStats, lookup, session, types) -> PlanNodeStatsEstimate.unknown(),
                new Plan(actual, typeProvider, StatsAndCosts.empty()),
                expected);
    }

    private TableScanNode jdbcTableScan(String schema, String table, Type type, String... columnNames)
    {
        JdbcTableHandle jdbcTableHandle = new JdbcTableHandle(CONNECTOR_ID, new SchemaTableName(schema, table), CATALOG_NAME, schema, table);
        JdbcTableLayoutHandle jdbcTableLayoutHandle = new JdbcTableLayoutHandle(TEST_SESSION.getSqlFunctionProperties(), jdbcTableHandle, TupleDomain.none(), Optional.empty());
        TableHandle tableHandle = new TableHandle(new ConnectorId(CATALOG_NAME), jdbcTableHandle, new ConnectorTransactionHandle() {}, Optional.of(jdbcTableLayoutHandle));

        return PLAN_BUILDER.tableScan(
                tableHandle,
                Arrays.stream(columnNames).map(column -> newVariable(column, type)).collect(toImmutableList()),
                Arrays.stream(columnNames)
                        .map(column -> newVariable(column, type))
                        .collect(toMap(identity(), entry -> getColumnHandleForVariable(entry.getName(), type))));
    }

    private FilterNode filter(PlanNode source, RowExpression predicate)
    {
        return PLAN_BUILDER.filter(predicate, source);
    }

    private static final class JdbcTableScanMatcher
            implements Matcher
    {
        private final JdbcTableLayoutHandle jdbcTableLayoutHandle;
        private final Set<ColumnHandle> columns;

        static PlanMatchPattern jdbcTableScanPattern(JdbcTableLayoutHandle jdbcTableLayoutHandle, Set<ColumnHandle> columns)
        {
            return node(TableScanNode.class).with(new JdbcTableScanMatcher(jdbcTableLayoutHandle, columns));
        }

        private JdbcTableScanMatcher(JdbcTableLayoutHandle jdbcTableLayoutHandle, Set<ColumnHandle> columns)
        {
            this.jdbcTableLayoutHandle = jdbcTableLayoutHandle;
            this.columns = columns;
        }

        @Override
        public boolean shapeMatches(PlanNode node)
        {
            return node instanceof TableScanNode;
        }

        @Override
        public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases)
        {
            checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName());

            TableScanNode tableScanNode = (TableScanNode) node;
            JdbcTableLayoutHandle layoutHandle = (JdbcTableLayoutHandle) tableScanNode.getTable().getLayout().get();
            if (jdbcTableLayoutHandle.getTable().equals(layoutHandle.getTable())
                    && jdbcTableLayoutHandle.getTupleDomain().equals(layoutHandle.getTupleDomain())
                    && ((!jdbcTableLayoutHandle.getAdditionalPredicate().isPresent() && !layoutHandle.getAdditionalPredicate().isPresent())
                        || jdbcTableLayoutHandle.getAdditionalPredicate().get().getExpression().equals(layoutHandle.getAdditionalPredicate().get().getExpression()))) {
                return MatchResult.match(
                        SymbolAliases.builder().putAll(
                        columns.stream()
                                .map(column -> ((JdbcColumnHandle) column).getColumnName())
                                .collect(toMap(identity(), SymbolReference::new)))
                        .build());
            }

            return MatchResult.NO_MATCH;
        }
    }
}