// ApplyConnectorOptimization.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.spi.ConnectorId;
import com.facebook.presto.spi.ConnectorPlanOptimizer;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.CteConsumerNode;
import com.facebook.presto.spi.plan.CteProducerNode;
import com.facebook.presto.spi.plan.CteReferenceNode;
import com.facebook.presto.spi.plan.DeleteNode;
import com.facebook.presto.spi.plan.DistinctLimitNode;
import com.facebook.presto.spi.plan.ExceptNode;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.IndexSourceNode;
import com.facebook.presto.spi.plan.IntersectNode;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.LimitNode;
import com.facebook.presto.spi.plan.MarkDistinctNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.SemiJoinNode;
import com.facebook.presto.spi.plan.SortNode;
import com.facebook.presto.spi.plan.TableFinishNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.plan.TableWriterNode;
import com.facebook.presto.spi.plan.TopNNode;
import com.facebook.presto.spi.plan.UnionNode;
import com.facebook.presto.spi.plan.UnnestNode;
import com.facebook.presto.spi.plan.ValuesNode;
import com.facebook.presto.sql.planner.TypeProvider;
import com.google.common.base.Supplier;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;

import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.Map;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;

import static com.facebook.presto.SystemSessionProperties.isIncludeValuesNodeInConnectorOptimizer;
import static com.facebook.presto.common.RuntimeUnit.NANO;
import static com.facebook.presto.sql.OptimizerRuntimeTrackUtil.getOptimizerNameForLog;
import static com.facebook.presto.sql.OptimizerRuntimeTrackUtil.trackOptimizerRuntime;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.util.Objects.requireNonNull;

/**
 * Plan optimizer that delegates parts of the plan to the {@link ConnectorPlanOptimizer}s
 * registered per connector.
 * <p>
 * For every connector reachable from the plan's leaves, the optimizer looks for "max closures":
 * subtrees that reach only that connector, that consist solely of node types listed in
 * {@link #CONNECTOR_ACCESSIBLE_PLAN_NODES}, and whose parent subtree does not satisfy the same
 * conditions. Each such subtree is passed through the connector's optimizers, and every ancestor
 * of a replaced subtree is rebuilt so that it points at the new child.
 */
public class ApplyConnectorOptimization
        implements PlanOptimizer
{
    // The only plan node types a subtree may contain for it to be handed to a connector optimizer.
    static final Set<Class<? extends PlanNode>> CONNECTOR_ACCESSIBLE_PLAN_NODES = ImmutableSet.of(
            CteProducerNode.class,
            CteConsumerNode.class,
            CteReferenceNode.class,
            DistinctLimitNode.class,
            FilterNode.class,
            TableScanNode.class,
            IndexSourceNode.class,
            LimitNode.class,
            SortNode.class,
            TopNNode.class,
            ValuesNode.class,
            ProjectNode.class,
            AggregationNode.class,
            MarkDistinctNode.class,
            UnionNode.class,
            IntersectNode.class,
            ExceptNode.class,
            SemiJoinNode.class,
            JoinNode.class,
            UnnestNode.class,
            TableWriterNode.class,
            TableFinishNode.class,
            DeleteNode.class);

    // for a leaf node that does not belong to any connector (e.g., ValuesNode)
    private static final ConnectorId EMPTY_CONNECTOR_ID = new ConnectorId("$internal$" + ApplyConnectorOptimization.class + "_CONNECTOR");

    private final Supplier<Map<ConnectorId, Set<ConnectorPlanOptimizer>>> connectorOptimizersSupplier;

    /**
     * @param connectorOptimizersSupplier supplies, on every {@link #optimize} call, the optimizers registered for each connector
     */
    public ApplyConnectorOptimization(Supplier<Map<ConnectorId, Set<ConnectorPlanOptimizer>>> connectorOptimizersSupplier)
    {
        this.connectorOptimizersSupplier = requireNonNull(connectorOptimizersSupplier, "connectorOptimizersSupplier is null");
    }

    /**
     * Runs the registered connector optimizers over every max closure of the plan.
     *
     * @return the (possibly rewritten) plan; the result is flagged as changed only when at least
     * one connector optimizer returned a node different from the one it was given
     * @throws IllegalStateException if a connector optimizer returns a node whose output variables
     * do not cover the output variables of the node it replaced
     */
    @Override
    public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector)
    {
        requireNonNull(plan, "plan is null");
        requireNonNull(session, "session is null");
        requireNonNull(types, "types is null");
        requireNonNull(variableAllocator, "variableAllocator is null");
        requireNonNull(idAllocator, "idAllocator is null");

        boolean enableVerboseRuntimeStats = SystemSessionProperties.isVerboseRuntimeStatsEnabled(session);
        Map<ConnectorId, Set<ConnectorPlanOptimizer>> connectorOptimizers = connectorOptimizersSupplier.get();
        if (connectorOptimizers.isEmpty()) {
            return PlanOptimizerResult.optimizerResult(plan, false);
        }

        // retrieve all the connectors
        ImmutableSet.Builder<ConnectorId> connectorIds = ImmutableSet.builder();
        getAllConnectorIds(plan, connectorIds);

        // becomes true as soon as any connector optimizer hands back a different node
        boolean planChanged = false;

        // for each connector, retrieve the set of subplans to optimize
        // TODO: what if a new connector is added by an existing one
        // There are cases (e.g., query federation) where a connector C1 needs to
        // create a UNION_ALL to federate data sources from both C1 and C2 (regardless of the classloader issue).
        // For such case, it is dangerous to re-calculate the "max closure" given the fixpoint property will be broken.
        // In order to preserve the fixpoint, we will "pretend" the newly added C2 table scan is part of C1's job to maintain.
        for (ConnectorId connectorId : connectorIds.build()) {
            Set<ConnectorPlanOptimizer> optimizers = connectorOptimizers.get(connectorId);
            if (optimizers == null) {
                continue;
            }

            // the context is rebuilt per connector because `plan` may have been rewritten by a previous iteration
            ImmutableMap.Builder<PlanNode, ConnectorPlanNodeContext> contextMapBuilder = ImmutableMap.builder();
            buildConnectorPlanNodeContext(plan, null, contextMapBuilder);
            Map<PlanNode, ConnectorPlanNodeContext> contextMap = contextMapBuilder.build();

            // keep track of changed nodes; the keys are original nodes and the values are the new nodes
            Map<PlanNode, PlanNode> updates = new HashMap<>();

            // process connector optimizers
            for (Map.Entry<PlanNode, ConnectorPlanNodeContext> entry : contextMap.entrySet()) {
                PlanNode node = entry.getKey();
                ConnectorPlanNodeContext context = entry.getValue();
                Optional<PlanNode> parent = context.getParent();

                // A subtree with root `node` is treated as a max closure only when:
                //    * The subtree with root `node` is a closure.
                //    * `node` has a parent, and the subtree rooted at that parent is not a closure.
                // NOTE(review): a parentless root that is itself a closure is skipped here; presumably the plan root
                // is never a connector-accessible node (e.g., it is an output node) -- confirm before relaxing this.
                if (!context.isClosure(connectorId, session) ||
                        !parent.isPresent() ||
                        contextMap.get(parent.get()).isClosure(connectorId, session)) {
                    continue;
                }

                PlanNode newNode = node;

                // the returned node is still a max closure (only if there is no new connector added, which does happen but ignored here)
                for (ConnectorPlanOptimizer optimizer : optimizers) {
                    long start = System.nanoTime();
                    newNode = optimizer.optimize(newNode, session.toConnectorSession(connectorId), variableAllocator, idAllocator);
                    if (enableVerboseRuntimeStats || trackOptimizerRuntime(session, optimizer)) {
                        session.getRuntimeStats().addMetricValue(String.format("optimizer%sTimeNanos", getOptimizerNameForLog(optimizer)), NANO, System.nanoTime() - start);
                    }
                }

                if (node != newNode) {
                    // the optimizer has allocated a new PlanNode
                    checkState(
                            containsAll(ImmutableSet.copyOf(newNode.getOutputVariables()), node.getOutputVariables()),
                            "the connector optimizer from %s returns a node that does not cover all output before optimization",
                            connectorId);

                    updates.put(node, newNode);
                }
            }

            if (updates.isEmpty()) {
                // nothing was rewritten for this connector; the plan is untouched
                continue;
            }

            // up to this point, we have a set of updated nodes; need to recursively update their parents
            planChanged = true;
            plan = propagateUpdates(plan, updates, contextMap);
        }

        return PlanOptimizerResult.optimizerResult(plan, planChanged);
    }

    /**
     * Rebuilds every ancestor of the replaced nodes so that it points at the new children.
     * <p>
     * The plan is altered with a bottom-up approach, but it does not have to be strictly bottom-up:
     * a parent rebuilt before one of its descendants was processed is simply rebuilt again later,
     * and the last value recorded for the root wins.
     *
     * @param plan the current plan root, returned as-is if no update reaches the root
     * @param updates original node to replacement node; extended in place with the rebuilt ancestors
     * @param contextMap parent links of the original plan
     * @return the root of the rewritten plan
     */
    private static PlanNode propagateUpdates(PlanNode plan, Map<PlanNode, PlanNode> updates, Map<PlanNode, ConnectorPlanNodeContext> contextMap)
    {
        PlanNode result = plan;

        // use "original nodes" to keep track of the plan structure and "updates" to keep track of the new nodes
        Queue<PlanNode> originalNodes = new LinkedList<>(updates.keySet());
        while (!originalNodes.isEmpty()) {
            PlanNode originalNode = originalNodes.poll();
            Optional<PlanNode> parent = contextMap.get(originalNode).getParent();

            if (!parent.isPresent()) {
                // originalNode must be the root; update the plan
                result = updates.get(originalNode);
                continue;
            }

            PlanNode originalParent = parent.get();

            // need to create a new parent given the child has changed; the new parent needs to point to the new child.
            // if a node has been updated, it will occur in `updates`; otherwise, just use the original node
            ImmutableList.Builder<PlanNode> newChildren = ImmutableList.builder();
            for (PlanNode child : originalParent.getSources()) {
                newChildren.add(updates.getOrDefault(child, child));
            }

            // mark the new parent as updated
            updates.put(originalParent, originalParent.replaceChildren(newChildren.build()));

            // enqueue the parent node in order to recursively update its ancestors
            originalNodes.add(originalParent);
        }
        return result;
    }

    /**
     * Collects the connector id of every leaf below {@code node} into {@code builder}.
     */
    private static void getAllConnectorIds(PlanNode node, ImmutableSet.Builder<ConnectorId> builder)
    {
        if (node.getSources().isEmpty()) {
            builder.add(getLeafConnectorId(node));
            return;
        }

        for (PlanNode child : node.getSources()) {
            getAllConnectorIds(child, builder);
        }
    }

    /**
     * Returns the connector a leaf node reads from, or {@link #EMPTY_CONNECTOR_ID} when the leaf
     * is neither a {@link TableScanNode} nor an {@link IndexSourceNode}.
     */
    private static ConnectorId getLeafConnectorId(PlanNode leaf)
    {
        if (leaf instanceof TableScanNode) {
            return ((TableScanNode) leaf).getTable().getConnectorId();
        }
        if (leaf instanceof IndexSourceNode) {
            return ((IndexSourceNode) leaf).getTableHandle().getConnectorId();
        }
        return EMPTY_CONNECTOR_ID;
    }

    /**
     * Returns the node type recorded for a leaf: table scans and index sources are recorded under
     * their base class; any other leaf is recorded under its runtime class.
     */
    private static Class<? extends PlanNode> getLeafPlanNodeType(PlanNode leaf)
    {
        if (leaf instanceof TableScanNode) {
            return TableScanNode.class;
        }
        if (leaf instanceof IndexSourceNode) {
            return IndexSourceNode.class;
        }
        return leaf.getClass();
    }

    /**
     * Recursively computes, for {@code node} and all of its descendants, the parent link, the set of
     * connectors reachable through the leaves and the set of plan node types in the subtree.
     * Every visited node is registered in {@code contextBuilder}, children before their parent.
     *
     * @param parent the parent of {@code node}, or {@code null} when {@code node} is the root
     * @return the context computed for {@code node}
     */
    private static ConnectorPlanNodeContext buildConnectorPlanNodeContext(
            PlanNode node,
            PlanNode parent,
            ImmutableMap.Builder<PlanNode, ConnectorPlanNodeContext> contextBuilder)
    {
        Set<ConnectorId> connectorIds;
        Set<Class<? extends PlanNode>> planNodeTypes;

        if (node.getSources().isEmpty()) {
            connectorIds = ImmutableSet.of(getLeafConnectorId(node));
            planNodeTypes = ImmutableSet.of(getLeafPlanNodeType(node));
        }
        else {
            // an inner node reaches whatever its children reach, plus its own node type
            connectorIds = new HashSet<>();
            planNodeTypes = new HashSet<>();

            for (PlanNode child : node.getSources()) {
                ConnectorPlanNodeContext childContext = buildConnectorPlanNodeContext(child, node, contextBuilder);
                connectorIds.addAll(childContext.getReachableConnectors());
                planNodeTypes.addAll(childContext.getReachablePlanNodeTypes());
            }
            planNodeTypes.add(node.getClass());
        }

        ConnectorPlanNodeContext connectorPlanNodeContext = new ConnectorPlanNodeContext(
                parent,
                connectorIds,
                planNodeTypes);

        contextBuilder.put(node, connectorPlanNodeContext);
        return connectorPlanNodeContext;
    }

    /**
     * Extra information needed for a plan node: its parent, the connectors reachable from its
     * subtree and the plan node types occurring in its subtree.
     */
    private static final class ConnectorPlanNodeContext
    {
        // null for the root of the plan
        private final PlanNode parent;
        private final Set<ConnectorId> reachableConnectors;
        private final Set<Class<? extends PlanNode>> reachablePlanNodeTypes;

        ConnectorPlanNodeContext(PlanNode parent, Set<ConnectorId> reachableConnectors, Set<Class<? extends PlanNode>> reachablePlanNodeTypes)
        {
            this.parent = parent;
            this.reachableConnectors = requireNonNull(reachableConnectors, "reachableConnectors is null");
            this.reachablePlanNodeTypes = requireNonNull(reachablePlanNodeTypes, "reachablePlanNodeTypes is null");
            checkArgument(!reachableConnectors.isEmpty(), "encountered a PlanNode that reaches no connector");
            checkArgument(!reachablePlanNodeTypes.isEmpty(), "encountered a PlanNode that reaches no plan node");
        }

        Optional<PlanNode> getParent()
        {
            return Optional.ofNullable(parent);
        }

        public Set<ConnectorId> getReachableConnectors()
        {
            return reachableConnectors;
        }

        public Set<Class<? extends PlanNode>> getReachablePlanNodeTypes()
        {
            return reachablePlanNodeTypes;
        }

        /**
         * Tells whether the subtree reaches exactly one connector, namely {@code connectorId}, and
         * contains only node types from {@link #CONNECTOR_ACCESSIBLE_PLAN_NODES}.
         * <p>
         * When the session enables including values nodes, leaves mapped to
         * {@link #EMPTY_CONNECTOR_ID} are left out of the connector count, so a subtree mixing one
         * connector with connector-less leaves still qualifies.
         */
        boolean isClosure(ConnectorId connectorId, Session session)
        {
            // check that the subtree reaches the given connector and no other
            boolean includeValuesNode = isIncludeValuesNodeInConnectorOptimizer(session);
            Set<ConnectorId> connectorIds = includeValuesNode ? reachableConnectors.stream().filter(x -> !x.equals(EMPTY_CONNECTOR_ID)).collect(toImmutableSet()) : reachableConnectors;
            if (connectorIds.size() != 1 || !connectorIds.contains(connectorId)) {
                return false;
            }

            // check that every node type in the subtree is accessible by connectors
            return containsAll(CONNECTOR_ACCESSIBLE_PLAN_NODES, reachablePlanNodeTypes);
        }
    }

    /**
     * Returns {@code true} when every element of {@code test} is contained in {@code container}
     * (trivially {@code true} for an empty {@code test}).
     */
    private static <T> boolean containsAll(Set<T> container, Collection<T> test)
    {
        for (T element : test) {
            if (!container.contains(element)) {
                return false;
            }
        }
        return true;
    }
}