LogicalCteOptimizer.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.eventlistener.CTEInformation;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.CteConsumerNode;
import com.facebook.presto.spi.plan.CteProducerNode;
import com.facebook.presto.spi.plan.CteReferenceNode;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.SemiJoinNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.sql.planner.SimplePlanVisitor;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.plan.ApplyNode;
import com.facebook.presto.sql.planner.plan.RemoteSourceNode;
import com.facebook.presto.sql.planner.plan.SequenceNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.graph.Graph;
import com.google.common.graph.GraphBuilder;
import com.google.common.graph.MutableGraph;
import com.google.common.graph.MutableValueGraph;
import com.google.common.graph.Traverser;
import com.google.common.graph.ValueGraphBuilder;

import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.Stack;
import java.util.stream.Collectors;

import static com.facebook.presto.SystemSessionProperties.getCteHeuristicReplicationThreshold;
import static com.facebook.presto.SystemSessionProperties.getCteMaterializationStrategy;
import static com.facebook.presto.SystemSessionProperties.isCteMaterializationApplicable;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.CteMaterializationStrategy.ALL;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.CteMaterializationStrategy.HEURISTIC;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.CteMaterializationStrategy.HEURISTIC_COMPLEX_QUERIES_ONLY;
import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Objects.requireNonNull;

/*
 * Transformation of CTE Reference Nodes:
 * This process converts CTE reference nodes into corresponding CteProducers and CteConsumers.
 * It also makes sure that the resulting plan cannot run into execution deadlocks.
 *
 * Example:
 * Before Transformation:
 *   JOIN
 *   |-- CTEReference(cte2)
 *   |   `-- TABLESCAN2
 *   `-- CTEReference(cte3)
 *       `-- TABLESCAN3
 *
 * After Transformation:
 *   SEQUENCE(cte1)
 *   |-- CTEProducer(cte2)
 *   |   `-- TABLESCAN2
 *   |-- CTEProducer(cte3)
 *   |   `-- TABLESCAN3
 *   `-- JOIN
 *       |-- CTEConsumer(cte2)
 *       `-- CTEConsumer(cte3)
 */
public class LogicalCteOptimizer
        implements PlanOptimizer
{
    // Metadata handle handed in at construction time. It is only stored here; nothing in
    // this file reads it back.
    private final Metadata metadata;

    /**
     * Creates the optimizer.
     *
     * @param metadata metadata handle to retain; stored as-is, without a null check
     */
    public LogicalCteOptimizer(Metadata metadata)
    {
        this.metadata = metadata;
    }

    /**
     * Rewrites CTE reference nodes of the plan into producer/consumer form when CTE
     * materialization is applicable for the session; otherwise returns the plan untouched.
     *
     * @param plan root of the plan to optimize, must not be null
     * @param session current session, must not be null
     * @param types type provider; not used and not null-checked by this optimizer
     * @param variableAllocator allocator handed to the rewrite, must not be null
     * @param idAllocator plan node id allocator handed to the rewrite, must not be null
     * @param warningCollector must not be null; not otherwise used here
     * @return the resulting plan together with a flag telling whether it was rewritten
     * @throws NullPointerException if any of the null-checked arguments is null
     */
    @Override
    public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector)
    {
        // Validation order decides which message is reported when several arguments are null.
        requireNonNull(plan, "plan is null");
        requireNonNull(session, "session is null");
        requireNonNull(variableAllocator, "variableAllocator is null");
        requireNonNull(idAllocator, "idAllocator is null");
        requireNonNull(warningCollector, "warningCollector is null");

        boolean materializationApplicable = isCteMaterializationApplicable(session);
        if (materializationApplicable) {
            CteEnumerator enumerator = new CteEnumerator(idAllocator, variableAllocator);
            PlanNode optimizedPlan = enumerator.transformPersistentCtes(session, plan);
            boolean planChanged = enumerator.isPlanRewritten();
            return PlanOptimizerResult.optimizerResult(optimizedPlan, planChanged);
        }

        // Materialization does not apply: hand the input plan back and report "not rewritten".
        return PlanOptimizerResult.optimizerResult(plan, false);
    }

    /**
     * Drives the CTE rewrite of one plan: picks the CTEs to materialize, rewrites the CTE
     * reference nodes, and, if any producer was created, places the producers and the
     * rewritten plan under a SequenceNode.
     */
    public class CteEnumerator
    {
        private PlanNodeIdAllocator planNodeIdAllocator;
        private VariableAllocator variableAllocator;

        // Outcome of the most recent transformPersistentCtes call.
        private boolean isPlanRewritten;

        /**
         * @param planNodeIdAllocator source of ids for newly created plan nodes, must not be null
         * @param variableAllocator source of new variables for the rewrite, must not be null
         */
        public CteEnumerator(PlanNodeIdAllocator planNodeIdAllocator, VariableAllocator variableAllocator)
        {
            this.planNodeIdAllocator = requireNonNull(planNodeIdAllocator, "planNodeIdAllocator must not be null");
            this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator must not be null");
        }

        /**
         * Rewrites the CTE reference nodes found under {@code root}.
         *
         * @param session session used to select the materialization strategy
         * @param root plan root; must have exactly one source
         * @return the rewritten root when no CTE producer was created, otherwise {@code root}
         *         with its single child replaced by a SequenceNode holding the producers and
         *         the rewritten child
         * @throws IllegalArgumentException if {@code root} does not have exactly one source
         */
        public PlanNode transformPersistentCtes(Session session, PlanNode root)
        {
            checkArgument(root.getSources().size() == 1, "expected newChildren to contain 1 node");

            LogicalCteOptimizerContext optimizerContext = new LogicalCteOptimizerContext();
            determineMaterializationCandidatesAndUpdateContext(session, root, optimizerContext);

            CteConsumerTransformer transformer = new CteConsumerTransformer(session, planNodeIdAllocator, variableAllocator);
            PlanNode rewrittenRoot = SimplePlanRewriter.rewriteWith(transformer, root, optimizerContext);

            List<PlanNode> orderedProducers = optimizerContext.getTopologicalOrdering();
            isPlanRewritten = !orderedProducers.isEmpty();
            if (!isPlanRewritten) {
                // The rewritten root is still returned: the rewrite strips the cte reference
                // nodes whether or not anything ends up being materialized.
                return rewrittenRoot;
            }

            SequenceNode sequence = new SequenceNode(
                    root.getSourceLocation(),
                    planNodeIdAllocator.getNextId(),
                    orderedProducers,
                    rewrittenRoot.getSources().get(0),
                    optimizerContext.createIndexedGraphFromTopologicallySortedCteProducers(orderedProducers));
            return root.replaceChildren(Arrays.asList(sequence));
        }

        /**
         * @return whether the last call to {@link #transformPersistentCtes} produced at least
         *         one CTE producer
         */
        public boolean isPlanRewritten()
        {
            return isPlanRewritten;
        }

        // Fills the context with the CTE ids to materialize, either all of them or the
        // subset chosen by the heuristic analysis.
        private void determineMaterializationCandidatesAndUpdateContext(Session session, PlanNode root, LogicalCteOptimizerContext context)
        {
            if (!shouldPerformHeuristicAnalysis(session)) {
                markAllCtesForMaterialization(session, context);
                return;
            }
            performHeuristicAnalysis(session, root, context);
        }

        // Every strategy other than ALL goes through the heuristic path.
        private boolean shouldPerformHeuristicAnalysis(Session session)
        {
            boolean materializeEverything = getCteMaterializationStrategy(session).equals(ALL);
            return !materializeEverything;
        }

        // Runs the dependency pass, then the complexity pass, then lets the determiner choose.
        private void performHeuristicAnalysis(Session session, PlanNode root, LogicalCteOptimizerContext context)
        {
            WeightedDependencyAnalyzer dependencyPass = new WeightedDependencyAnalyzer();
            ComplexCteAnalyzer complexityPass = new ComplexCteAnalyzer(session);

            root.accept(dependencyPass, context);
            root.accept(complexityPass, context);

            HeuristicCteMaterializationDeterminer determiner = new HeuristicCteMaterializationDeterminer(session);
            determiner.determineHeuristicCandidates(context);
        }

        // Registers every CTE known to the session's CTE information collector as a candidate.
        private void markAllCtesForMaterialization(Session session, LogicalCteOptimizerContext context)
        {
            session.getCteInformationCollector().getCTEInformationList().stream()
                    .forEach(cteInfo -> context.addMaterializationCandidate(cteInfo.getCteId()));
        }
    }

    /**
     * Flags CTEs as "complex". A join, semi join or aggregation node marks the CTE that is
     * currently on top of the active-CTE stack (the nearest enclosing CTE). After a CTE
     * reference has been traversed, its complex flag is dropped again if the subtree rooted
     * at that reference contains neither a TableScanNode nor a RemoteSourceNode.
     */
    public static class ComplexCteAnalyzer
            extends SimplePlanVisitor<LogicalCteOptimizerContext>
    {
        // Stored at construction; not read by any method of this class.
        private final Session session;

        // Node types that count as data sources for the check in visitCteReference.
        private static final List<Class<? extends PlanNode>> DATA_SOURCES_PLAN_NODES = ImmutableList.of(TableScanNode.class, RemoteSourceNode.class);

        public ComplexCteAnalyzer(Session session)
        {
            this.session = requireNonNull(session, "Session is null");
        }

        @Override
        public Void visitCteReference(CteReferenceNode node, LogicalCteOptimizerContext context)
        {
            String cteId = node.getCteId();

            // Visit the CTE body with this CTE as the active one so that complex nodes
            // underneath are attributed to it.
            context.pushActiveCte(cteId);
            node.getSource().accept(this, context);
            context.popActiveCte();

            if (!context.isComplexCte(cteId)) {
                return null;
            }

            // Check for source nodes
            boolean containsDataSource = PlanNodeSearcher.searchFrom(node)
                    .where(ComplexCteAnalyzer::isDataSource)
                    .matches();
            if (!containsDataSource) {
                context.removeComplexCte(cteId);
            }
            return null;
        }

        @Override
        public Void visitJoin(JoinNode node, LogicalCteOptimizerContext context)
        {
            markActiveCteAsComplex(context);
            return super.visitJoin(node, context);
        }

        @Override
        public Void visitSemiJoin(SemiJoinNode node, LogicalCteOptimizerContext context)
        {
            markActiveCteAsComplex(context);
            return super.visitSemiJoin(node, context);
        }

        @Override
        public Void visitAggregation(AggregationNode node, LogicalCteOptimizerContext context)
        {
            markActiveCteAsComplex(context);
            return super.visitAggregation(node, context);
        }

        // Marks the CTE on top of the active stack, if there is one, as complex.
        private static void markActiveCteAsComplex(LogicalCteOptimizerContext context)
        {
            Optional<String> enclosingCte = context.peekActiveCte();
            if (enclosingCte.isPresent()) {
                context.addComplexCte(enclosingCte.get());
            }
        }

        // True when the node is an instance of one of DATA_SOURCES_PLAN_NODES.
        private static boolean isDataSource(PlanNode planNode)
        {
            for (Class<? extends PlanNode> sourceType : DATA_SOURCES_PLAN_NODES) {
                if (sourceType.isInstance(planNode)) {
                    return true;
                }
            }
            return false;
        }
    }

    /**
     * Walks the query plan and records, in the context, a weighted dependency graph of the CTEs.
     * Each CTE reference encountered is reported to the context, which counts on the edge
     * parent -> child how many times the child CTE is referenced by the parent.
     * The body of a given CTE is traversed only the first time that CTE is encountered.
     **/
    public static class WeightedDependencyAnalyzer
            extends SimplePlanVisitor<LogicalCteOptimizerContext>
    {
        // Ids of the CTEs whose body has already been traversed.
        private final Set<String> visited = new HashSet<>();

        public WeightedDependencyAnalyzer()
        {
        }

        @Override
        public Void visitCteReference(CteReferenceNode node, LogicalCteOptimizerContext context)
        {
            String cteId = node.getCteId();

            // Every reference is recorded as a dependency, first occurrence or not.
            context.addCteReferenceDependency(cteId);

            // Set.add returns false when the id was already present: in that case the body
            // was traversed before, so skip it.
            boolean firstEncounter = visited.add(cteId);
            if (!firstEncounter) {
                return null;
            }

            context.pushActiveCte(cteId);
            node.getSource().accept(this, context);
            context.popActiveCte();
            return null;
        }

        @Override
        public Void visitApply(ApplyNode node, LogicalCteOptimizerContext context)
        {
            // Visit the input first and the subquery second.
            PlanNode input = node.getInput();
            input.accept(this, context);
            PlanNode subquery = node.getSubquery();
            subquery.accept(this, context);
            return null;
        }
    }

    /**
     * Picks the CTEs to materialize with a greedy heuristic.
     * CTEs are examined parent-first: the earliest parent CTE that satisfies the heuristic criteria is selected,
     * and the reference counts of its child CTEs are then reduced, on the assumption that they are now reached
     * through the materialized parent.
     * The resulting selection satisfies the heuristic conditions but is not guaranteed to be the optimal one.
     * Example:
     * <p>
     * CTE_A
     * /    \
     * CTE_B  CTE_C
     * |      |
     * CTE_D  CTE_E
     * <p>
     * In this graph, if CTE_B and CTE_C are heavily referenced, the algorithm might choose to materialize these first,
     * reducing the reference count for CTE_D and CTE_E respectively.
     * Later decisions then see the reduced reference counts for CTE_D and CTE_E, which may affect whether they are materialized.
     */
    public static class HeuristicCteMaterializationDeterminer
    {
        private final Session session;

        public HeuristicCteMaterializationDeterminer(Session session)
        {
            this.session = session;
        }

        // Replaces the session's CTEInformation entry for cteId with one whose reference count
        // is lowered by referencesToRemove; fails if the count would become negative.
        private void decrementCteReferenceCount(String cteId, int referencesToRemove)
        {
            Map<String, CTEInformation> infoByCteId = session.getCteInformationCollector().getCteInformationMap();
            CTEInformation current = infoByCteId.get(cteId);
            int remainingReferences = current.getNumberOfReferences() - referencesToRemove;

            checkArgument(remainingReferences >= 0, "CTE Reference count for cteId %s should be >= 0", cteId);
            CTEInformation updated = new CTEInformation(current.getCteName(), current.getCteId(), remainingReferences, current.getIsView(), current.isMaterialized());
            infoByCteId.put(cteId, updated);
        }

        /**
         * Walks down from {@code cteId} and lowers the reference count of every descendant that is
         * not itself selected for materialization. Descendants below a selected CTE are not visited.
         *
         * @param graph dependency graph whose edge values are reference counts (1 if an edge has no value)
         * @param cteId CTE whose successors are adjusted
         * @param currentMultiplier product of the edge values on the path walked so far
         * @param baseRemovalMultiplier constant factor applied at every level
         * @param context used to look up which CTEs are selected for materialization
         */
        void rebaseReferences(MutableValueGraph<String, Integer> graph, String cteId, int currentMultiplier, int baseRemovalMultiplier, LogicalCteOptimizerContext context)
        {
            for (String child : graph.successors(cteId)) {
                if (context.shouldCteBeMaterialized(child)) {
                    continue;
                }
                int edgeWeight = graph.edgeValueOrDefault(cteId, child, 1);
                int pathMultiplier = currentMultiplier * edgeWeight;
                decrementCteReferenceCount(child, baseRemovalMultiplier * pathMultiplier);
                rebaseReferences(graph, child, pathMultiplier, baseRemovalMultiplier, context);
            }
        }

        /**
         * Recursively lowers the reference counts of the CTEs that the materialized CTE references,
         * directly or indirectly, since they no longer need to be recomputed for each reference
         * to the materialized CTE.
         * <p>
         * Example:
         * Let's say A is referenced 3 times in a query and A references B 3 times, and B references C 2 times.
         * The graph would be: Query - (3) - A -(3)-> B -(2)-> C
         * Initial reference counts are 9 for B and 18 for C.
         * If A is materialized it is computed once instead of 3 times, so the factor applied is
         * A's references minus one, i.e. 2:
         * - For B, the decrement is 2 (A's references - 1) * 3 (times A references B) = 6
         * - For C, the decrement is 2 (A's references - 1) * 3 (times A references B) * 2 (times B references C) = 12
         * Therefore, the new reference counts would be 3 for B (9 - 6) and 6 for C (18 - 12).
         */
        private void adjustChildReferenceCounts(String parentCteId, int parentReferences, LogicalCteOptimizerContext context)
        {
            int redundantParentComputations = parentReferences - 1;
            checkArgument(redundantParentComputations >= 0, "adjustment count cannot be negative");
            rebaseReferences(context.cteReferenceDependencyGraph, parentCteId, 1, redundantParentComputations, context);
        }

        /**
         * Processes the CTE dependency graph layer by layer, starting from the CTEs nobody depends on.
         * Each CTE in the current layer that satisfies the heuristic is registered as a
         * materialization candidate and the reference counts of its descendants are adjusted.
         * The layer is then removed from the working graph obtained from
         * {@code context.copyOfCteReferenceDependencyGraph()} and the next layer is computed.
         */
        public void determineHeuristicCandidates(LogicalCteOptimizerContext context)
        {
            MutableValueGraph<String, Integer> workingGraph = context.copyOfCteReferenceDependencyGraph();
            HashMap<String, CTEInformation> cteInformationMap = session.getCteInformationCollector().getCteInformationMap();

            for (List<String> layer = nodesWithoutPredecessors(workingGraph); !layer.isEmpty(); layer = nodesWithoutPredecessors(workingGraph)) {
                for (String cteId : layer) {
                    CTEInformation cteInfo = cteInformationMap.get(cteId);
                    if (!satisfiesHeuristic(cteInfo, context)) {
                        continue;
                    }
                    // should be materialized
                    context.addMaterializationCandidate(cteId);
                    // update child references
                    adjustChildReferenceCounts(cteId, cteInfo.getNumberOfReferences(), context);
                }

                // Drop the processed layer so that its children become the next layer.
                for (String cteId : layer) {
                    workingGraph.removeNode(cteId);
                }
            }
        }

        // Snapshot of the nodes that currently have in-degree zero.
        private static List<String> nodesWithoutPredecessors(MutableValueGraph<String, Integer> graph)
        {
            return graph.nodes().stream()
                    .filter(node -> graph.inDegree(node) == 0)
                    .collect(Collectors.toList());
        }

        // A CTE qualifies when it is referenced at least "threshold" times and either the strategy
        // is HEURISTIC, or the strategy is HEURISTIC_COMPLEX_QUERIES_ONLY and the CTE is complex.
        private boolean satisfiesHeuristic(CTEInformation cteInfo, LogicalCteOptimizerContext context)
        {
            boolean referencedOftenEnough = cteInfo.getNumberOfReferences() >= getCteHeuristicReplicationThreshold(session);
            boolean heuristicStrategy = getCteMaterializationStrategy(session).equals(HEURISTIC);
            boolean complexOnlyStrategy = getCteMaterializationStrategy(session).equals(HEURISTIC_COMPLEX_QUERIES_ONLY);
            boolean complexCte = context.isComplexCte(cteInfo.getCteId());

            if (!referencedOftenEnough) {
                return false;
            }
            return heuristicStrategy || (complexOnlyStrategy && complexCte);
        }
    }

    /**
     * Plan rewriter that removes CTE reference nodes. A reference to a CTE that is not
     * selected for materialization is replaced by its rewritten source. A reference to a
     * selected CTE is replaced by a CteConsumerNode, and a CteProducerNode wrapping the
     * rewritten source is registered in the context.
     */
    public static class CteConsumerTransformer
            extends SimplePlanRewriter<LogicalCteOptimizerContext>
    {
        private final PlanNodeIdAllocator idAllocator;
        private final VariableAllocator variableAllocator;
        private final Session session;

        public CteConsumerTransformer(Session session, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator)
        {
            this.idAllocator = requireNonNull(idAllocator, "idAllocator must not be null");
            this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator must not be null");
            this.session = requireNonNull(session, "Session is null");
        }

        /**
         * Looks up whether the context selected {@code cteId} for materialization and stores that
         * decision in the session's CTE information map by replacing the entry for {@code cteId}.
         *
         * @return true if the CTE is selected for materialization
         */
        public boolean shouldCteBeMaterialized(String cteId, LogicalCteOptimizerContext context)
        {
            Map<String, CTEInformation> infoByCteId = session.getCteInformationCollector().getCteInformationMap();
            CTEInformation previous = infoByCteId.get(cteId);
            boolean materialize = context.shouldCteBeMaterialized(cteId);

            CTEInformation refreshed = new CTEInformation(previous.getCteName(), previous.getCteId(), previous.getNumberOfReferences(), previous.getIsView(), materialize);
            infoByCteId.put(cteId, refreshed);
            return materialize;
        }

        @Override
        public PlanNode visitCteReference(CteReferenceNode node, RewriteContext<LogicalCteOptimizerContext> context)
        {
            LogicalCteOptimizerContext optimizerContext = context.get();
            String cteId = node.getCteId();

            if (!shouldCteBeMaterialized(cteId, optimizerContext)) {
                // Not materialized: the reference node simply disappears.
                return context.rewrite(node.getSource(), optimizerContext);
            }

            optimizerContext.addMaterializedCteDependency(cteId);

            // Rewrite the body with this CTE active so that dependent CTEs are processed properly
            optimizerContext.pushActiveCte(cteId);
            PlanNode rewrittenSource = context.rewrite(node.getSource(), optimizerContext);
            optimizerContext.popActiveCte();

            // The producer id is allocated before the consumer id.
            CteProducerNode producer = new CteProducerNode(
                    node.getSourceLocation(),
                    idAllocator.getNextId(),
                    rewrittenSource,
                    cteId,
                    variableAllocator.newVariable("rows", BIGINT),
                    node.getOutputVariables());
            optimizerContext.addProducer(cteId, producer);

            return new CteConsumerNode(
                    node.getSourceLocation(),
                    idAllocator.getNextId(),
                    Optional.of(rewrittenSource),
                    rewrittenSource.getOutputVariables(),
                    cteId,
                    rewrittenSource);
        }

        @Override
        public PlanNode visitApply(ApplyNode node, RewriteContext<LogicalCteOptimizerContext> context)
        {
            LogicalCteOptimizerContext optimizerContext = context.get();
            // Rebuild the apply node with a fresh id; input is rewritten before the subquery.
            return new ApplyNode(
                    node.getSourceLocation(),
                    idAllocator.getNextId(),
                    context.rewrite(node.getInput(), optimizerContext),
                    context.rewrite(node.getSubquery(), optimizerContext),
                    node.getSubqueryAssignments(),
                    node.getCorrelation(),
                    node.getOriginSubqueryError(),
                    node.getMayParticipateInAntiJoin());
        }
    }

    /**
     * Mutable state shared between the analysis visitors, the heuristic determiner and the
     * plan rewriter: the CTE dependency graphs, the stack of CTEs currently being traversed,
     * the set of complex CTEs, the materialization candidates and the created producers.
     */
    public static class LogicalCteOptimizerContext
    {
        // CTE id -> producer node registered for it (the first registration wins, see addProducer)
        public Map<String, CteProducerNode> cteProducerMap;

        // a -> b indicates that b needs to be processed before a
        // The edge value is the number of times a references b.
        private final MutableValueGraph<String, Integer> cteReferenceDependencyGraph;

        // a -> b indicates that a needs to be processed before b
        private final MutableGraph<String> materializedCteDependencyGraph;

        // CTEs whose body is currently being traversed; the top is the nearest enclosing CTE
        private final Stack<String> activeCteStack;

        private final Set<String> complexCtes;

        private final Set<String> candidatesForMaterilization;

        public LogicalCteOptimizerContext()
        {
            cteProducerMap = new HashMap<>();
            // The cte graph will never have cycles because sql won't allow it
            cteReferenceDependencyGraph = ValueGraphBuilder.directed().allowsSelfLoops(false).build();
            materializedCteDependencyGraph = GraphBuilder.directed().allowsSelfLoops(false).build();
            activeCteStack = new Stack<>();
            complexCtes = new HashSet<>();
            candidatesForMaterilization = new HashSet<>();
        }

        /**
         * @return an immutable snapshot of the producer map
         */
        public Map<String, CteProducerNode> getCteProducerMap()
        {
            return ImmutableMap.copyOf(cteProducerMap);
        }

        /**
         * Builds an independent copy of the CTE reference dependency graph, with the same nodes,
         * edges and edge values. Callers may mutate the returned graph (for example remove
         * nodes) without affecting the graph held by this context.
         *
         * @return a new graph instance equal in content to the context's reference dependency graph
         */
        public MutableValueGraph<String, Integer> copyOfCteReferenceDependencyGraph()
        {
            MutableValueGraph<String, Integer> graphCopy = ValueGraphBuilder.from(cteReferenceDependencyGraph).build();
            // Add the nodes explicitly so that isolated nodes are carried over as well.
            for (String node : cteReferenceDependencyGraph.nodes()) {
                graphCopy.addNode(node);
            }
            for (String node : cteReferenceDependencyGraph.nodes()) {
                cteReferenceDependencyGraph.successors(node).forEach(successor ->
                        graphCopy.putEdgeValue(node, successor, cteReferenceDependencyGraph.edgeValueOrDefault(node, successor, 0)));
            }
            // Return the copy, not the live graph: handing out the live graph would let callers
            // that remove nodes destroy the context's own dependency information.
            return graphCopy;
        }

        /**
         * Registers the producer for a CTE; a producer already registered for the id is kept.
         */
        public void addProducer(String cteId, CteProducerNode cteProducer)
        {
            cteProducerMap.putIfAbsent(cteId, cteProducer);
        }

        public void addMaterializationCandidate(String cteId)
        {
            this.candidatesForMaterilization.add(cteId);
        }

        /**
         * @return true if the CTE has been registered as a materialization candidate
         */
        public boolean shouldCteBeMaterialized(String cteId)
        {
            return this.candidatesForMaterilization.contains(cteId);
        }

        public void pushActiveCte(String cte)
        {
            this.activeCteStack.push(cte);
        }

        /**
         * Removes and returns the innermost active CTE.
         *
         * @throws java.util.EmptyStackException if no CTE is active
         */
        public String popActiveCte()
        {
            return this.activeCteStack.pop();
        }

        /**
         * @return the innermost active CTE, or empty when no CTE is active
         */
        public Optional<String> peekActiveCte()
        {
            return (this.activeCteStack.isEmpty()) ? Optional.empty() : Optional.ofNullable(this.activeCteStack.peek());
        }

        /**
         * Records a reference to {@code currentCte}: the CTE is added as a node and, when a CTE
         * is currently active, the weight of the edge activeCte -> currentCte is incremented
         * (starting at 1 for a new edge).
         */
        public void addCteReferenceDependency(String currentCte)
        {
            cteReferenceDependencyGraph.addNode(currentCte);
            peekActiveCte().ifPresent(parent -> {
                // A missing edge counts as weight 0, so the first reference yields weight 1.
                int existingWeight = cteReferenceDependencyGraph.edgeValueOrDefault(parent, currentCte, 0);
                cteReferenceDependencyGraph.putEdgeValue(parent, currentCte, existingWeight + 1);
            });
        }

        /**
         * Records a materialized CTE: the CTE is added as a node and, when a CTE is currently
         * active, the edge currentCte -> activeCte is added.
         */
        public void addMaterializedCteDependency(String currentCte)
        {
            materializedCteDependencyGraph.addNode(currentCte);
            Optional<String> parentCte = peekActiveCte();
            parentCte.ifPresent(s -> materializedCteDependencyGraph.putEdge(currentCte, s));
        }

        public void addComplexCte(String cteId)
        {
            complexCtes.add(cteId);
        }

        public void removeComplexCte(String cteId)
        {
            complexCtes.remove(cteId);
        }

        public boolean isComplexCte(String cteId)
        {
            return complexCtes.contains(cteId);
        }

        /**
         * @return the registered producers, listed in depth-first post-order of the materialized
         *         CTE dependency graph; empty when no CTE was materialized
         */
        public List<PlanNode> getTopologicalOrdering()
        {
            ImmutableList.Builder<PlanNode> topSortedCteProducerListBuilder = ImmutableList.builder();
            Traverser.forGraph(materializedCteDependencyGraph).depthFirstPostOrder(materializedCteDependencyGraph.nodes())
                    .forEach(cteId -> topSortedCteProducerListBuilder.add(cteProducerMap.get(cteId)));
            return topSortedCteProducerListBuilder.build();
        }

        /**
         * Translates the materialized CTE dependency graph from CTE ids to positions in the
         * given producer list.
         *
         * @param topologicalSortedCteProducerList producers (CteProducerNode instances) whose list
         *        index becomes the node of the returned graph
         * @return a directed graph over list indices with the same edges as the materialized
         *         CTE dependency graph
         */
        public Graph<Integer> createIndexedGraphFromTopologicallySortedCteProducers(List<PlanNode> topologicalSortedCteProducerList)
        {
            Map<String, Integer> cteIdToProducerIndexMap = new HashMap<>();
            MutableGraph<Integer> indexGraph = GraphBuilder
                    .directed()
                    .expectedNodeCount(topologicalSortedCteProducerList.size())
                    .build();
            for (int i = 0; i < topologicalSortedCteProducerList.size(); i++) {
                cteIdToProducerIndexMap.put(((CteProducerNode) topologicalSortedCteProducerList.get(i)).getCteId(), i);
                indexGraph.addNode(i);
            }

            // Populate the new graph with edges based on the index mapping
            for (String cteId : materializedCteDependencyGraph.nodes()) {
                materializedCteDependencyGraph.successors(cteId).forEach(successor ->
                        indexGraph.putEdge(cteIdToProducerIndexMap.get(cteId), cteIdToProducerIndexMap.get(successor)));
            }
            return indexGraph;
        }
    }
}