TestIterativePlanFragmenter.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.spark.planner;

import com.facebook.presto.Session;
import com.facebook.presto.common.predicate.TupleDomain;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.cost.StatsAndCosts;
import com.facebook.presto.dispatcher.NoOpQueryManager;
import com.facebook.presto.execution.NodeTaskMap;
import com.facebook.presto.execution.QueryManagerConfig;
import com.facebook.presto.execution.scheduler.LegacyNetworkTopology;
import com.facebook.presto.execution.scheduler.NodeScheduler;
import com.facebook.presto.execution.scheduler.NodeSchedulerConfig;
import com.facebook.presto.execution.scheduler.nodeSelection.NodeSelectionStats;
import com.facebook.presto.execution.scheduler.nodeSelection.SimpleTtlNodeSelectorConfig;
import com.facebook.presto.metadata.CatalogManager;
import com.facebook.presto.metadata.InMemoryNodeManager;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.spark.planner.IterativePlanFragmenter.PlanAndFragments;
import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.spi.ConnectorId;
import com.facebook.presto.spi.TableHandle;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.JoinDistributionType;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.PartitioningHandle;
import com.facebook.presto.spi.plan.PartitioningScheme;
import com.facebook.presto.spi.plan.PlanFragmentId;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.SimplePlanFragment;
import com.facebook.presto.spi.plan.StageExecutionDescriptor;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.spi.security.AllowAllAccessControl;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.planner.NodePartitioningManager;
import com.facebook.presto.sql.planner.PartitioningProviderManager;
import com.facebook.presto.sql.planner.Plan;
import com.facebook.presto.sql.planner.PlanFragment;
import com.facebook.presto.sql.planner.PlanFragmenter;
import com.facebook.presto.sql.planner.SubPlan;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.JsonCodecSimplePlanFragmentSerde;
import com.facebook.presto.sql.planner.sanity.PlanChecker;
import com.facebook.presto.sql.planner.sanity.PlanCheckerProviderManager;
import com.facebook.presto.sql.planner.sanity.PlanCheckerProviderManagerConfig;
import com.facebook.presto.tpch.TpchColumnHandle;
import com.facebook.presto.tpch.TpchTableHandle;
import com.facebook.presto.tpch.TpchTableLayoutHandle;
import com.facebook.presto.tpch.TpchTransactionHandle;
import com.facebook.presto.transaction.TransactionManager;
import com.facebook.presto.ttl.nodettlfetchermanagers.ThrowingNodeTtlFetcherManager;
import com.facebook.presto.util.FinalizerService;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;

import static com.facebook.airlift.json.JsonCodec.jsonCodec;
import static com.facebook.presto.SystemSessionProperties.FORCE_SINGLE_NODE_OUTPUT;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager;
import static com.facebook.presto.spark.planner.TestIterativePlanFragmenter.CanonicalTestFragment.toCanonicalTestFragment;
import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet;
import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment;
import static com.facebook.presto.sql.planner.optimizations.AggregationNodeUtils.count;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_STREAMING;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.systemPartitionedExchange;
import static com.facebook.presto.sql.relational.Expressions.variable;
import static com.facebook.presto.testing.TestingSession.createBogusTestingCatalog;
import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
import static com.facebook.presto.transaction.InMemoryTransactionManager.createTestTransactionManager;
import static com.facebook.presto.transaction.TransactionBuilder.transaction;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.util.Objects.requireNonNull;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;

public class TestIterativePlanFragmenter
{
    private PlanFragmenter planFragmenter;
    private Session session;
    private MetadataManager metadata;
    private TransactionManager transactionManager;
    private FinalizerService finalizerService;
    private NodeScheduler nodeScheduler;
    private NodePartitioningManager nodePartitioningManager;
    private PlanCheckerProviderManager planCheckerProviderManager;

    @BeforeClass
    public void setUp()
    {
        session = testSessionBuilder()
                .setCatalog("tpch")
                .setSystemProperty(FORCE_SINGLE_NODE_OUTPUT, "false")
                .build();

        CatalogManager catalogManager = new CatalogManager();
        catalogManager.registerCatalog(createBogusTestingCatalog("tpch"));
        transactionManager = createTestTransactionManager(catalogManager);
        metadata = createTestMetadataManager(transactionManager);

        finalizerService = new FinalizerService();
        finalizerService.start();
        nodeScheduler = new NodeScheduler(
                new LegacyNetworkTopology(),
                new InMemoryNodeManager(),
                new NodeSelectionStats(),
                new NodeSchedulerConfig().setIncludeCoordinator(true),
                new NodeTaskMap(finalizerService),
                new ThrowingNodeTtlFetcherManager(),
                new NoOpQueryManager(),
                new SimpleTtlNodeSelectorConfig());
        PartitioningProviderManager partitioningProviderManager = new PartitioningProviderManager();
        nodePartitioningManager = new NodePartitioningManager(nodeScheduler, partitioningProviderManager, new NodeSelectionStats());
        planCheckerProviderManager = new PlanCheckerProviderManager(new JsonCodecSimplePlanFragmentSerde(jsonCodec(SimplePlanFragment.class)), new PlanCheckerProviderManagerConfig());
        planFragmenter = new PlanFragmenter(metadata, nodePartitioningManager, new QueryManagerConfig(), new FeaturesConfig(), planCheckerProviderManager);
    }

    @AfterClass(alwaysRun = true)
    public void tearDown()
    {
        planFragmenter = null;
        session = null;
        transactionManager = null;
        metadata = null;
        finalizerService.destroy();
        finalizerService = null;
        nodeScheduler.stop();
        nodeScheduler = null;
        nodePartitioningManager = null;
    }

    @Test
    public void testIterativePlanFragmenter()
    {
        TableScanNode ts1 = tableScan("ts1", "orderkey");
        TableScanNode ts2 = tableScan("ts2", "orderkey_0");
        PlanNode p1 = project("p1", ts1, variable("orderkey_1", BIGINT), variable("orderkey", BIGINT));
        ExchangeNode remoteExchange1 = systemPartitionedExchange(
                new PlanNodeId("re1"),
                REMOTE_STREAMING,
                p1,
                ImmutableList.of(new VariableReferenceExpression(Optional.empty(), "orderkey_1", BIGINT)),
                Optional.empty());
        ExchangeNode remoteExchange2 = systemPartitionedExchange(
                new PlanNodeId("re2"),
                REMOTE_STREAMING,
                ts2,
                ImmutableList.of(new VariableReferenceExpression(Optional.empty(), "orderkey_0", BIGINT)),
                Optional.empty());
        ExchangeNode localExchange = systemPartitionedExchange(
                new PlanNodeId("le"),
                LOCAL,
                remoteExchange2,
                ImmutableList.of(new VariableReferenceExpression(Optional.empty(), "orderkey_0", BIGINT)),
                Optional.empty());

        JoinNode join = join("join",
                remoteExchange1,
                localExchange,
                JoinDistributionType.PARTITIONED,
                "orderkey_1",
                "orderkey_0");
        Map<String, Type> types = ImmutableMap.of(
                "orderkey", BIGINT,
                "orderkey_1", BIGINT,
                "orderkey_0", BIGINT);
        TypeProvider typeProvider = TypeProvider.copyOf(types);
        Plan plan = new Plan(join, typeProvider, StatsAndCosts.empty());

        SubPlan fullFragmentedPlan = getFullFragmentedPlan(plan);

        inTransaction(session -> runTestIterativePlanFragmenter(join, plan, fullFragmentedPlan, session));
    }

    private Void runTestIterativePlanFragmenter(PlanNode node, Plan plan, SubPlan fullFragmentedPlan, Session session)
    {
        TestingFragmentTracker testingFragmentTracker = new TestingFragmentTracker();
        IterativePlanFragmenter iterativePlanFragmenter = new IterativePlanFragmenter(
                plan,
                testingFragmentTracker::isFragmentFinished,
                metadata,
                new PlanChecker(new FeaturesConfig(), planCheckerProviderManager),
                new PlanNodeIdAllocator(),
                nodePartitioningManager,
                new QueryManagerConfig(),
                session,
                WarningCollector.NOOP,
                false);

        PlanAndFragments nextPlanAndFragments = getNextPlanAndFragments(iterativePlanFragmenter, node);
        assertTrue(nextPlanAndFragments.getRemainingPlan().isPresent());
        assertEquals(nextPlanAndFragments.getReadyFragments().size(), 2);

        // nothing new is ready for execution, you are returned the same plan you sent in
        // and no fragments.
        PlanAndFragments previousPlanAndFragments = nextPlanAndFragments;
        nextPlanAndFragments = getNextPlanAndFragments(iterativePlanFragmenter, previousPlanAndFragments.getRemainingPlan().get());
        assertTrue(nextPlanAndFragments.getReadyFragments().isEmpty());
        assertTrue(nextPlanAndFragments.getRemainingPlan().isPresent());
        assertEquals(previousPlanAndFragments.getRemainingPlan().get(), nextPlanAndFragments.getRemainingPlan().get());

        // finish one fragment
        // still nothing is ready for execution as the join stage has two dependencies
        previousPlanAndFragments = nextPlanAndFragments;
        testingFragmentTracker.addFinishedFragment(new PlanFragmentId(1));

        nextPlanAndFragments = getNextPlanAndFragments(iterativePlanFragmenter, previousPlanAndFragments.getRemainingPlan().get());
        assertEquals(previousPlanAndFragments, nextPlanAndFragments);

        testingFragmentTracker.addFinishedFragment(new PlanFragmentId(2));
        previousPlanAndFragments = nextPlanAndFragments;
        nextPlanAndFragments = getNextPlanAndFragments(iterativePlanFragmenter, previousPlanAndFragments.getRemainingPlan().get());

        // when the root fragment is ready to execute, there should be no remaining plan left
        assertFalse(nextPlanAndFragments.getRemainingPlan().isPresent());
        assertEquals(nextPlanAndFragments.getReadyFragments().size(), 1);

        assertSubPlansEquivalent(nextPlanAndFragments.getReadyFragments().get(0), fullFragmentedPlan);
        return null;
    }

    private void assertSubPlansEquivalent(SubPlan subPlan1, SubPlan subPlan2)
    {
        assertEquals(toCanonicalTestFragment(subPlan1.getFragment()), toCanonicalTestFragment(subPlan2.getFragment()));
        Set<CanonicalTestFragment> subPlan1Children = subPlan1.getChildren().stream()
                .map(child -> toCanonicalTestFragment(child.getFragment()))
                .collect(toImmutableSet());
        Set<CanonicalTestFragment> subPlan2Children = subPlan2.getChildren().stream()
                .map(child -> toCanonicalTestFragment(child.getFragment()))
                .collect(toImmutableSet());

        assertEquals(subPlan1Children, subPlan2Children);
    }

    private TableScanNode tableScan(String id, String... symbols)
    {
        List<VariableReferenceExpression> variables = Arrays.stream(symbols)
                .map(symbol -> new VariableReferenceExpression(Optional.empty(), symbol, BIGINT))
                .collect(toImmutableList());
        return tableScan(id, variables);
    }

    private TableScanNode tableScan(String id, List<VariableReferenceExpression> variables)
    {
        ImmutableMap.Builder<VariableReferenceExpression, ColumnHandle> assignments = ImmutableMap.builder();

        for (VariableReferenceExpression variable : variables) {
            assignments.put(variable, new TpchColumnHandle("orderkey", BIGINT));
        }

        TpchTableHandle tableHandle = new TpchTableHandle("orders", 1.0);
        return new TableScanNode(
                Optional.empty(),
                new PlanNodeId(id),
                new TableHandle(
                        new ConnectorId("tpch"),
                        tableHandle,
                        TpchTransactionHandle.INSTANCE,
                        Optional.of(new TpchTableLayoutHandle(tableHandle, TupleDomain.all()))),
                variables,
                assignments.build(),
                TupleDomain.all(),
                TupleDomain.all(), Optional.empty());
    }

    private PlanNode project(String id, PlanNode source, VariableReferenceExpression variable, RowExpression expression)
    {
        return new ProjectNode(
                new PlanNodeId(id),
                source,
                assignment(variable, expression));
    }

    private AggregationNode aggregation(String id, PlanNode source)
    {
        AggregationNode.Aggregation aggregation = count(metadata.getFunctionAndTypeManager());

        return new AggregationNode(
                Optional.empty(),
                new PlanNodeId(id),
                source,
                ImmutableMap.of(new VariableReferenceExpression(Optional.empty(), "count", BIGINT), aggregation),
                singleGroupingSet(source.getOutputVariables()),
                ImmutableList.of(),
                AggregationNode.Step.FINAL,
                Optional.empty(),
                Optional.empty(),
                Optional.empty());
    }

    /**
     * EquiJoinClause is created from symbols in form of:
     * symbol[0] = symbol[1] AND symbol[2] = symbol[3] AND ...
     */
    private JoinNode join(String planNodeId, PlanNode left, PlanNode right, JoinDistributionType distributionType, String... symbols)
    {
        checkArgument(symbols.length % 2 == 0);
        ImmutableList.Builder<EquiJoinClause> criteria = ImmutableList.builder();

        for (int i = 0; i < symbols.length; i += 2) {
            criteria.add(new EquiJoinClause(new VariableReferenceExpression(Optional.empty(), symbols[i], BIGINT), new VariableReferenceExpression(Optional.empty(), symbols[i + 1], BIGINT)));
        }

        return new JoinNode(
                Optional.empty(),
                new PlanNodeId(planNodeId),
                JoinType.INNER,
                left,
                right,
                criteria.build(),
                ImmutableList.<VariableReferenceExpression>builder()
                        .addAll(left.getOutputVariables())
                        .addAll(right.getOutputVariables())
                        .build(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.of(distributionType),
                ImmutableMap.of());
    }

    private SubPlan getFullFragmentedPlan(Plan plan)
    {
        return inTransaction(session -> planFragmenter.createSubPlans(session, plan, false, new PlanNodeIdAllocator(), WarningCollector.NOOP));
    }

    private PlanAndFragments getNextPlanAndFragments(IterativePlanFragmenter iterativePlanFragmenter, PlanNode node)
    {
        return iterativePlanFragmenter.createReadySubPlans(node);
    }

    private <T> T inTransaction(Function<Session, T> transactionSessionConsumer)
    {
        return transaction(transactionManager, new AllowAllAccessControl())
                .singleStatement()
                .execute(session, session -> {
                    // metadata.getCatalogHandle() registers the catalog for the transaction
                    session.getCatalog().ifPresent(catalog -> metadata.getCatalogHandle(session, catalog));
                    return transactionSessionConsumer.apply(session);
                });
    }

    private static class TestingFragmentTracker
    {
        private final Set<PlanFragmentId> finishedFragments = new HashSet<>();

        public void addFinishedFragment(PlanFragmentId id)
        {
            finishedFragments.add(id);
        }

        public boolean isFragmentFinished(PlanFragmentId id)
        {
            return finishedFragments.contains(id);
        }
    }

    static class CanonicalTestFragment
    {
        // it's tricky to compare plans between fragments
        // because the remotes sources will be different
        // just make sure they have the same root node
        // for sanity checking
        private final Class<PlanNode> clazz;
        private final Set<VariableReferenceExpression> variables;
        private final PartitioningHandle partitioning;
        private final List<PlanNodeId> tableScanSchedulingOrder;
        private final List<Type> types;

        // can't compare the remoteSourceNodes themselves
        // because fragment numbering can differ,
        // so just ensure that there are the same number
        private final int numberOfRemoteSourceNodes;
        private final PartitioningScheme partitioningScheme;
        private final StageExecutionDescriptor stageExecutionDescriptor;
        private final boolean outputTableWriterFragment;
        private final Optional<StatsAndCosts> statsAndCosts;

        public CanonicalTestFragment(
                Class<PlanNode> clazz,
                Set<VariableReferenceExpression> variables,
                PartitioningHandle partitioning,
                List<PlanNodeId> tableScanSchedulingOrder,
                List<Type> types,
                int numberOfRemoteSourceNodes,
                PartitioningScheme partitioningScheme,
                StageExecutionDescriptor stageExecutionDescriptor,
                boolean outputTableWriterFragment,
                Optional<StatsAndCosts> statsAndCosts)
        {
            this.clazz = requireNonNull(clazz, "clazz is null");
            this.variables = ImmutableSet.copyOf(requireNonNull(variables, "variables is null"));
            this.partitioning = requireNonNull(partitioning, "partitioning is null");
            this.tableScanSchedulingOrder = ImmutableList.copyOf(requireNonNull(tableScanSchedulingOrder, "tableScanSchedulingOrder is null"));
            this.types = ImmutableList.copyOf(requireNonNull(types, "types is null"));
            this.numberOfRemoteSourceNodes = numberOfRemoteSourceNodes;
            this.partitioningScheme = requireNonNull(partitioningScheme, "partitioningScheme is null");
            this.stageExecutionDescriptor = requireNonNull(stageExecutionDescriptor, "stageExecutionDescriptor is null");
            this.outputTableWriterFragment = outputTableWriterFragment;
            this.statsAndCosts = requireNonNull(statsAndCosts, "statsAndCosts is null");
        }

        public static CanonicalTestFragment toCanonicalTestFragment(PlanFragment planFragment)
        {
            return new CanonicalTestFragment(
                    (Class<PlanNode>) planFragment.getRoot().getClass(),
                    planFragment.getVariables(),
                    planFragment.getPartitioning(),
                    planFragment.getTableScanSchedulingOrder(),
                    planFragment.getTypes(),
                    planFragment.getRemoteSourceNodes().size(),
                    planFragment.getPartitioningScheme(),
                    planFragment.getStageExecutionDescriptor(),
                    planFragment.isOutputTableWriterFragment(),
                    planFragment.getStatsAndCosts());
        }

        @Override
        public boolean equals(Object o)
        {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            CanonicalTestFragment that = (CanonicalTestFragment) o;
            return numberOfRemoteSourceNodes == that.numberOfRemoteSourceNodes &&
                    outputTableWriterFragment == that.outputTableWriterFragment &&
                    clazz.equals(that.clazz) &&
                    variables.equals(that.variables) &&
                    partitioning.equals(that.partitioning) &&
                    tableScanSchedulingOrder.equals(that.tableScanSchedulingOrder) &&
                    types.equals(that.types) &&
                    partitioningScheme.equals(that.partitioningScheme) &&
                    stageExecutionDescriptor.equals(that.stageExecutionDescriptor) &&
                    statsAndCosts.equals(that.statsAndCosts);
        }

        @Override
        public int hashCode()
        {
            return Objects.hash(
                    clazz,
                    variables,
                    partitioning,
                    tableScanSchedulingOrder,
                    types,
                    numberOfRemoteSourceNodes,
                    partitioningScheme,
                    stageExecutionDescriptor,
                    outputTableWriterFragment,
                    statsAndCosts);
        }

        @Override
        public String toString()
        {
            return toStringHelper(this)
                    .add("clazz", clazz)
                    .add("variables", variables)
                    .add("partitioning", partitioning)
                    .add("tableScanSchedulingOrder", tableScanSchedulingOrder)
                    .add("types", types)
                    .add("numberOfRemoteSourceNodes", numberOfRemoteSourceNodes)
                    .add("partitioningScheme", partitioningScheme)
                    .add("stageExecutionDescriptor", stageExecutionDescriptor)
                    .add("outputTableWriterFragment", outputTableWriterFragment)
                    .add("statsAndCosts", statsAndCosts)
                    .toString();
        }
    }
}