TestConnectorOptimization.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.common.predicate.TupleDomain;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.StatsAndCosts;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.spi.ConnectorId;
import com.facebook.presto.spi.ConnectorPlanOptimizer;
import com.facebook.presto.spi.ConnectorTableLayoutHandle;
import com.facebook.presto.spi.TableHandle;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.OutputNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.PlanVisitor;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.plan.UnionNode;
import com.facebook.presto.spi.plan.ValuesNode;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.Plan;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.assertions.MatchResult;
import com.facebook.presto.sql.planner.assertions.Matcher;
import com.facebook.presto.sql.planner.assertions.PlanAssert;
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import com.facebook.presto.sql.planner.assertions.SymbolAliases;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.testng.annotations.Test;

import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT;
import static com.facebook.presto.expressions.LogicalRowExpressions.and;
import static com.facebook.presto.expressions.LogicalRowExpressions.or;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.lang.reflect.Modifier.isFinal;
import static java.util.function.Function.identity;
import static java.util.stream.Collectors.toMap;
import static org.testng.Assert.assertEquals;

/**
 * Tests {@link ApplyConnectorOptimization}: which plan node types are exposed to connector
 * optimizers, and how per-connector {@link ConnectorPlanOptimizer}s rewrite plans that mix
 * table scans from several connectors.
 */
public class TestConnectorOptimization
{
    private static final Metadata METADATA = MetadataManager.createTestMetadataManager();
    private static final PlanBuilder PLAN_BUILDER = new PlanBuilder(TEST_SESSION, new PlanNodeIdAllocator(), METADATA);

    /**
     * The connector-accessible node set must equal the final {@link PlanNode} subclasses that
     * {@link PlanVisitor} declares a visit method for.
     */
    @Test
    public void testSupportedPlanNodes()
    {
        // the unchecked cast below is guarded by the isAssignableFrom filter
        @SuppressWarnings("unchecked")
        Set<Class<? extends PlanNode>> expected = Arrays.stream(PlanVisitor.class.getDeclaredMethods())
                .map(Method::getParameterTypes)
                .filter(parameterTypes -> parameterTypes.length > 0)
                .filter(parameterTypes -> PlanNode.class.isAssignableFrom(parameterTypes[0]))  // is accessible in SPI
                .filter(parameterTypes -> isFinal(parameterTypes[0].getModifiers()))  // is a final class
                .map(parameterTypes -> (Class<? extends PlanNode>) parameterTypes[0])
                .collect(toImmutableSet());

        assertEquals(ApplyConnectorOptimization.CONNECTOR_ACCESSIBLE_PLAN_NODES, expected);
    }

    /**
     * The plan is returned unchanged when no optimizers are registered, or when only a no-op
     * optimizer is registered for a connector the plan does not scan.
     */
    @Test
    public void testEmptyOptimizers()
    {
        PlanNode plan = output(filter(tableScan("cat1", "a", "b"), TRUE_CONSTANT), "a");
        PlanNode actual = optimize(plan, ImmutableMap.of());
        assertEquals(actual, plan);

        actual = optimize(plan, ImmutableMap.of(new ConnectorId("cat2"), ImmutableSet.of(noop())));
        assertEquals(actual, plan);
    }

    /**
     * A union over scans from several connectors (plus a values node) is left unchanged when no
     * registered optimizer rewrites anything.
     */
    @Test
    public void testMultipleConnectors()
    {
        PlanNode plan = output(
                union(
                        tableScan("cat1", "a", "b"),
                        tableScan("cat2", "a", "b"),
                        tableScan("cat3", "a", "b"),
                        tableScan("cat4", "a", "b"),
                        tableScan("cat2", "a", "b"),
                        tableScan("cat1", "a", "b"),
                        values("a", "b")),
                "a");

        PlanNode actual = optimize(plan, ImmutableMap.of());
        assertEquals(actual, plan);

        actual = optimize(plan, ImmutableMap.of(new ConnectorId("cat2"), ImmutableSet.of(noop())));
        assertEquals(actual, plan);
    }

    /**
     * With filter pushdown registered for every connector, each filter in a nested union tree is
     * removed and its predicate folded into the scan's layout, without disturbing the tree shape.
     */
    @Test
    public void testPlanUpdateWithComplexStructures()
    {
        PlanNode plan = output(
                union(
                        filter(tableScan("cat1", "a", "b"), TRUE_CONSTANT),
                        filter(tableScan("cat2", "a", "b"), TRUE_CONSTANT),
                        union(
                                filter(tableScan("cat3", "a", "b"), TRUE_CONSTANT),
                                union(
                                        filter(tableScan("cat4", "a", "b"), TRUE_CONSTANT),
                                        filter(tableScan("cat1", "a", "b"), TRUE_CONSTANT))),
                        filter(tableScan("cat2", "a", "b"), TRUE_CONSTANT),
                        union(filter(tableScan("cat1", "a", "b"), TRUE_CONSTANT))),
                "a");

        PlanNode actual = optimize(plan, ImmutableMap.of());
        assertEquals(actual, plan);

        // force updating every leaf node
        actual = optimize(
                plan,
                ImmutableMap.of(
                        new ConnectorId("cat1"), ImmutableSet.of(filterPushdown()),
                        new ConnectorId("cat2"), ImmutableSet.of(filterPushdown()),
                        new ConnectorId("cat3"), ImmutableSet.of(filterPushdown()),
                        new ConnectorId("cat4"), ImmutableSet.of(filterPushdown())));

        // assert all filters removed
        assertPlanMatch(
                actual,
                PlanMatchPattern.output(
                        PlanMatchPattern.union(
                                SimpleTableScanMatcher.tableScan("cat1", TRUE_CONSTANT),
                                SimpleTableScanMatcher.tableScan("cat2", TRUE_CONSTANT),
                                PlanMatchPattern.union(
                                        SimpleTableScanMatcher.tableScan("cat3", TRUE_CONSTANT),
                                        PlanMatchPattern.union(
                                                SimpleTableScanMatcher.tableScan("cat4", TRUE_CONSTANT),
                                                SimpleTableScanMatcher.tableScan("cat1", TRUE_CONSTANT))),
                                SimpleTableScanMatcher.tableScan("cat2", TRUE_CONSTANT),
                                PlanMatchPattern.union(
                                        SimpleTableScanMatcher.tableScan("cat1", TRUE_CONSTANT)))));
    }

    /**
     * A filter directly above a scan is removed and its predicate recorded in the scan's layout.
     */
    @Test
    public void testPushFilterToTableScan()
    {
        RowExpression expectedPredicate = and(newBigintVariable("a"), newBigintVariable("b"));
        PlanNode plan = output(
                filter(
                        tableScan("cat1", "a", "b"),
                        expectedPredicate),
                "a");
        PlanNode actual = optimize(plan, ImmutableMap.of(new ConnectorId("cat1"), ImmutableSet.of(filterPushdown())));

        // assert structure; FilterNode is removed
        assertPlanMatch(actual, PlanMatchPattern.output(SimpleTableScanMatcher.tableScan("cat1", expectedPredicate)));
    }

    /**
     * A connector optimizer can introduce a new filter above a scan, or AND its predicate into an
     * existing filter above the scan.
     */
    @Test
    public void testAddFilterToTableScan()
    {
        RowExpression expectedPredicate = and(newBigintVariable("a"), newBigintVariable("b"));

        // (1) without filter node case
        PlanNode plan = output(tableScan("cat1", "a", "b"), "a");
        PlanNode actual = optimize(plan, ImmutableMap.of(new ConnectorId("cat1"), ImmutableSet.of(addFilterToTableScan(expectedPredicate))));

        // assert FilterNode is added
        assertPlanMatch(
                actual,
                PlanMatchPattern.output(
                        PlanMatchPattern.filter(
                                "a AND b",
                                SimpleTableScanMatcher.tableScan("cat1", "a", "b"))),
                TypeProvider.viewOf(ImmutableMap.of("a", BIGINT, "b", BIGINT)));

        // (2) with filter node case
        RowExpression existingPredicate = or(newBigintVariable("a"), newBigintVariable("b"));
        plan = output(
                filter(
                        tableScan("cat1", "a", "b"),
                        existingPredicate),
                "a");
        actual = optimize(plan, ImmutableMap.of(new ConnectorId("cat1"), ImmutableSet.of(addFilterToTableScan(expectedPredicate))));

        // assert filter gets added as a part of conjuncts
        assertPlanMatch(
                actual,
                PlanMatchPattern.output(
                        PlanMatchPattern.filter(
                                "(a OR b) AND (a AND b)",
                                SimpleTableScanMatcher.tableScan("cat1", "a", "b"))),
                TypeProvider.viewOf(ImmutableMap.of("a", BIGINT, "b", BIGINT)));
    }

    /**
     * Builds a scan on the given connector with one BIGINT variable per column name, each backed
     * by a fresh anonymous {@link ColumnHandle}.
     */
    private TableScanNode tableScan(String connectorName, String... columnNames)
    {
        return PLAN_BUILDER.tableScan(
                connectorName,
                Arrays.stream(columnNames).map(TestConnectorOptimization::newBigintVariable).collect(toImmutableList()),
                Arrays.stream(columnNames).map(TestConnectorOptimization::newBigintVariable).collect(toMap(identity(), variable -> new ColumnHandle() {})));
    }

    /**
     * Wraps {@code source} in a {@link FilterNode} with the given predicate.
     */
    private FilterNode filter(PlanNode source, RowExpression predicate)
    {
        return PLAN_BUILDER.filter(predicate, source);
    }

    /**
     * Wraps {@code source} in an {@link OutputNode} exposing the named BIGINT columns.
     */
    private OutputNode output(PlanNode source, String... columnNames)
    {
        return PLAN_BUILDER.output(
                Arrays.stream(columnNames).collect(toImmutableList()),
                Arrays.stream(columnNames).map(TestConnectorOptimization::newBigintVariable).collect(toImmutableList()),
                source);
    }

    /**
     * Builds a {@link UnionNode} over {@code sources}, mapping every source output variable onto
     * an identically named union output.
     */
    private UnionNode union(PlanNode... sources)
    {
        ImmutableListMultimap.Builder<VariableReferenceExpression, VariableReferenceExpression> outputsToInputs = ImmutableListMultimap.builder();
        for (PlanNode source : sources) {
            // Insert in each source's declared order. Collecting through a HashMap would make the
            // union's output order depend on hashing.
            for (VariableReferenceExpression variable : source.getOutputVariables()) {
                outputsToInputs.put(variable, variable);
            }
        }
        return PLAN_BUILDER.union(outputsToInputs.build(), Arrays.asList(sources));
    }

    /**
     * Builds a {@link ValuesNode} over BIGINT columns with the given names. The row argument of 5
     * is passed straight through to {@code PlanBuilder#values}.
     */
    private ValuesNode values(String... columnNames)
    {
        VariableReferenceExpression[] columns = new VariableReferenceExpression[columnNames.length];
        for (int i = 0; i < columnNames.length; i++) {
            columns[i] = newBigintVariable(columnNames[i]);
        }
        return PLAN_BUILDER.values(5, columns);
    }

    /**
     * Creates a BIGINT variable reference with no source location.
     */
    private static VariableReferenceExpression newBigintVariable(String name)
    {
        return new VariableReferenceExpression(Optional.empty(), name, BIGINT);
    }

    /**
     * Asserts that {@code actual} matches {@code expected} using an empty type provider.
     */
    private static void assertPlanMatch(PlanNode actual, PlanMatchPattern expected)
    {
        assertPlanMatch(actual, expected, TypeProvider.empty());
    }

    /**
     * Asserts that {@code actual} matches {@code expected}. The stats calculator always reports
     * unknown statistics, since these tests only check plan structure.
     */
    private static void assertPlanMatch(PlanNode actual, PlanMatchPattern expected, TypeProvider typeProvider)
    {
        PlanAssert.assertPlan(
                TEST_SESSION,
                METADATA,
                (node, sourceStats, lookup, session, types) -> PlanNodeStatsEstimate.unknown(),
                new Plan(actual, typeProvider, StatsAndCosts.empty()),
                expected);
    }

    /**
     * Runs {@link ApplyConnectorOptimization} with a fixed connector-to-optimizers map and returns
     * the rewritten plan root.
     */
    private static PlanNode optimize(PlanNode plan, Map<ConnectorId, Set<ConnectorPlanOptimizer>> optimizers)
    {
        ApplyConnectorOptimization optimizer = new ApplyConnectorOptimization(() -> optimizers);
        return optimizer.optimize(plan, TEST_SESSION, TypeProvider.empty(), new VariableAllocator(), new PlanNodeIdAllocator(), WarningCollector.NOOP).getPlanNode();
    }

    /**
     * Returns a connector optimizer that folds filters sitting on table scans into the scan layout.
     */
    private static ConnectorPlanOptimizer filterPushdown()
    {
        return (maxSubplan, session, variableAllocator, idAllocator) -> maxSubplan.accept(new TestFilterPushdownVisitor(), null);
    }

    /**
     * Returns a connector optimizer that adds {@code filter} on top of every table scan.
     */
    private static ConnectorPlanOptimizer addFilterToTableScan(RowExpression filter)
    {
        return (maxSubplan, session, variableAllocator, idAllocator) -> maxSubplan.accept(new TestAddFilterVisitor(filter, idAllocator), null);
    }

    /**
     * Returns a connector optimizer that hands back its input subplan unchanged.
     */
    private static ConnectorPlanOptimizer noop()
    {
        return (maxSubplan, session, variableAllocator, idAllocator) -> maxSubplan;
    }

    /**
     * Base visitor that rebuilds every node from its recursively visited sources. Subclasses
     * override specific node types and fall back to {@link #visitPlan} to keep descending.
     */
    private static class TestPlanOptimizationVisitor
            extends PlanVisitor<PlanNode, Void>
    {
        @Override
        public PlanNode visitPlan(PlanNode node, Void context)
        {
            ImmutableList.Builder<PlanNode> children = ImmutableList.builder();
            for (PlanNode child : node.getSources()) {
                children.add(child.accept(this, context));
            }
            return node.replaceChildren(children.build());
        }
    }

    /**
     * Replaces {@code Filter(TableScan)} with a single {@link TableScanNode} whose layout records
     * the filter predicate. Other filters are traversed into, so deeper scans are still reached.
     */
    private static class TestFilterPushdownVisitor
            extends TestPlanOptimizationVisitor
    {
        @Override
        public PlanNode visitFilter(FilterNode node, Void context)
        {
            if (node.getSource() instanceof TableScanNode) {
                TableScanNode tableScanNode = (TableScanNode) node.getSource();
                TableHandle handle = tableScanNode.getTable();
                return new TableScanNode(
                        Optional.empty(),
                        tableScanNode.getId(),
                        new TableHandle(
                                handle.getConnectorId(),
                                handle.getConnectorHandle(),
                                handle.getTransaction(),
                                Optional.of(new TestConnectorTableLayoutHandle(node.getPredicate()))),
                        tableScanNode.getOutputVariables(),
                        tableScanNode.getAssignments(),
                        tableScanNode.getTableConstraints(),
                        TupleDomain.all(),
                        TupleDomain.all(),
                        Optional.empty());
            }
            // Not directly above a scan: keep the filter but still visit its sources.
            return visitPlan(node, context);
        }

        /**
         * Layout handle that carries only the pushed-down predicate; equality is by predicate.
         */
        static class TestConnectorTableLayoutHandle
                implements ConnectorTableLayoutHandle
        {
            private final RowExpression predicate;

            TestConnectorTableLayoutHandle(RowExpression predicate)
            {
                this.predicate = predicate;
            }

            @Override
            public boolean equals(Object obj)
            {
                if (this == obj) {
                    return true;
                }

                if (!(obj instanceof TestConnectorTableLayoutHandle)) {
                    return false;
                }

                TestConnectorTableLayoutHandle other = (TestConnectorTableLayoutHandle) obj;
                return Objects.equals(predicate, other.predicate);
            }

            @Override
            public int hashCode()
            {
                return Objects.hashCode(predicate);
            }

            // Makes plan-match failure messages show the predicate instead of an identity hash.
            @Override
            public String toString()
            {
                return toStringHelper(this)
                        .add("predicate", predicate)
                        .toString();
            }
        }
    }

    /**
     * Places {@code filter} above every table scan. If a filter already sits directly on a scan,
     * its predicate is ANDed with {@code filter} instead. Other filters are traversed into.
     */
    private static class TestAddFilterVisitor
            extends TestPlanOptimizationVisitor
    {
        private final RowExpression filter;
        private final PlanNodeIdAllocator idAllocator;

        TestAddFilterVisitor(RowExpression filter, PlanNodeIdAllocator idAllocator)
        {
            this.filter = filter;
            this.idAllocator = idAllocator;
        }

        @Override
        public PlanNode visitFilter(FilterNode node, Void context)
        {
            if (node.getSource() instanceof TableScanNode) {
                return new FilterNode(Optional.empty(), node.getId(), node.getSource(), and(node.getPredicate(), filter));
            }
            // Not directly above a scan: keep the filter but still visit its sources.
            return visitPlan(node, context);
        }

        @Override
        public PlanNode visitTableScan(TableScanNode node, Void context)
        {
            return new FilterNode(Optional.empty(), idAllocator.getNextId(), node, filter);
        }
    }

    /**
     * A simplified table scan matcher for multiple-connector support.
     * The goal is to test plan structural matching rather than table scan details: only the
     * connector id and the (optional) layout handle are compared, and the given column names are
     * exposed as symbol aliases on a match.
     */
    private static final class SimpleTableScanMatcher
            implements Matcher
    {
        private final ConnectorId connectorId;
        private final Optional<ConnectorTableLayoutHandle> connectorTableLayoutHandle;
        private final String[] columns;

        /**
         * Matches a scan on {@code connectorName} whose layout carries {@code predicate}. A null
         * predicate means the scan must have no layout.
         */
        public static PlanMatchPattern tableScan(String connectorName, RowExpression predicate, String... columnNames)
        {
            return node(TableScanNode.class)
                    .with(new SimpleTableScanMatcher(
                            new ConnectorId(connectorName),
                            Optional.ofNullable(predicate).map(TestFilterPushdownVisitor.TestConnectorTableLayoutHandle::new),
                            columnNames));
        }

        /**
         * Matches a scan on {@code connectorName} with no layout, aliasing the given column names.
         */
        public static PlanMatchPattern tableScan(String connectorName, String... columnNames)
        {
            return tableScan(connectorName, null, columnNames);
        }

        private SimpleTableScanMatcher(
                ConnectorId connectorId,
                Optional<ConnectorTableLayoutHandle> connectorTableLayoutHandle,
                String... columns)
        {
            this.connectorId = connectorId;
            this.connectorTableLayoutHandle = connectorTableLayoutHandle;
            this.columns = columns;
        }

        @Override
        public boolean shapeMatches(PlanNode node)
        {
            return node instanceof TableScanNode;
        }

        @Override
        public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases)
        {
            checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName());

            TableScanNode tableScanNode = (TableScanNode) node;
            if (connectorId.equals(tableScanNode.getTable().getConnectorId()) &&
                    connectorTableLayoutHandle.equals(tableScanNode.getTable().getLayout())) {
                return MatchResult.match(SymbolAliases.builder().putAll(Arrays.stream(columns).collect(toMap(identity(), SymbolReference::new))).build());
            }

            return MatchResult.NO_MATCH;
        }

        @Override
        public String toString()
        {
            return toStringHelper(this)
                    .omitNullValues()
                    .add("connectorId", connectorId)
                    .add("connectorTableLayoutHandle", connectorTableLayoutHandle.orElse(null))
                    .add("columns", Arrays.asList(columns))
                    .toString();
        }
    }
}