SqlQueryScheduler.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.execution.scheduler;

import com.facebook.airlift.concurrent.SetThreadName;
import com.facebook.airlift.log.Logger;
import com.facebook.airlift.stats.TimeStat;
import com.facebook.presto.Session;
import com.facebook.presto.cost.StatsAndCosts;
import com.facebook.presto.execution.BasicStageExecutionStats;
import com.facebook.presto.execution.LocationFactory;
import com.facebook.presto.execution.PartialResultQueryManager;
import com.facebook.presto.execution.QueryState;
import com.facebook.presto.execution.QueryStateMachine;
import com.facebook.presto.execution.RemoteTask;
import com.facebook.presto.execution.RemoteTaskFactory;
import com.facebook.presto.execution.SqlStageExecution;
import com.facebook.presto.execution.StageExecutionInfo;
import com.facebook.presto.execution.StageExecutionState;
import com.facebook.presto.execution.StageId;
import com.facebook.presto.execution.StageInfo;
import com.facebook.presto.execution.TaskId;
import com.facebook.presto.execution.buffer.OutputBuffers;
import com.facebook.presto.execution.buffer.OutputBuffers.OutputBufferId;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.PlanFragmentId;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.PlanFragment;
import com.facebook.presto.sql.planner.SplitSourceFactory;
import com.facebook.presto.sql.planner.SubPlan;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizer;
import com.facebook.presto.sql.planner.sanity.PlanChecker;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.util.concurrent.ListenableFuture;
import com.sun.management.ThreadMXBean;
import io.airlift.units.Duration;
import org.apache.http.client.utils.URIBuilder;

import java.lang.management.ManagementFactory;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;

import static com.facebook.airlift.concurrent.MoreFutures.tryGetFutureValue;
import static com.facebook.airlift.concurrent.MoreFutures.whenAnyComplete;
import static com.facebook.presto.SystemSessionProperties.getMaxConcurrentMaterializations;
import static com.facebook.presto.SystemSessionProperties.getPartialResultsCompletionRatioThreshold;
import static com.facebook.presto.SystemSessionProperties.getPartialResultsMaxExecutionTimeMultiplier;
import static com.facebook.presto.SystemSessionProperties.isEnhancedCTESchedulingEnabled;
import static com.facebook.presto.SystemSessionProperties.isPartialResultsEnabled;
import static com.facebook.presto.SystemSessionProperties.isRuntimeOptimizerEnabled;
import static com.facebook.presto.execution.BasicStageExecutionStats.aggregateBasicStageStats;
import static com.facebook.presto.execution.StageExecutionState.ABORTED;
import static com.facebook.presto.execution.StageExecutionState.CANCELED;
import static com.facebook.presto.execution.StageExecutionState.FAILED;
import static com.facebook.presto.execution.StageExecutionState.FINISHED;
import static com.facebook.presto.execution.StageExecutionState.PLANNED;
import static com.facebook.presto.execution.StageExecutionState.RUNNING;
import static com.facebook.presto.execution.StageExecutionState.SCHEDULED;
import static com.facebook.presto.execution.buffer.OutputBuffers.BROADCAST_PARTITION_ID;
import static com.facebook.presto.execution.buffer.OutputBuffers.createDiscardingOutputBuffers;
import static com.facebook.presto.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers;
import static com.facebook.presto.execution.scheduler.StreamingPlanSection.extractStreamingSections;
import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static com.facebook.presto.sql.planner.PlanFragmenterUtils.ROOT_FRAGMENT_ID;
import static com.facebook.presto.sql.planner.SchedulingOrderVisitor.scheduleOrder;
import static com.facebook.presto.sql.planner.planPrinter.PlanPrinter.jsonFragmentPlan;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.common.collect.Streams.stream;
import static com.google.common.graph.Traverser.forTree;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
import static java.util.function.Function.identity;

public class SqlQueryScheduler
        implements SqlQuerySchedulerInterface
{
    private static final Logger log = Logger.get(SqlQueryScheduler.class);

    // Source of per-thread CPU time; schedule() samples it around each stage scheduling step.
    private static final ThreadMXBean THREAD_MX_BEAN = (ThreadMXBean) ManagementFactory.getThreadMXBean();

    private final LocationFactory locationFactory;
    // Turns the stages of a section into an ExecutionSchedule (see schedule()).
    private final ExecutionPolicy executionPolicy;

    private final SplitSchedulerStats schedulerStats;

    private final QueryStateMachine queryStateMachine;
    // Seeded once in the constructor; tryCostBasedOptimize() calls updatePlan() after a runtime
    // optimization, which presumably swaps the value held here — confirm against updatePlan().
    private final AtomicReference<SubPlan> plan = new AtomicReference<>();
    // Section tree extracted from the initial plan in the constructor.
    private final StreamingPlanSection sectionedPlan;
    // Stage ID of the last stage execution created in the constructor; initialize() treats it as the output stage.
    private final StageId rootStageId;
    private final boolean summarizeTaskInfo;
    // Read from the session; caps how many sections getSectionsReadyForExecution() hands out
    // when enhanced CTE scheduling is disabled.
    private final int maxConcurrentMaterializations;

    // The following fields are required by adaptive optimization in runtime.
    private final Session session;
    private final FunctionAndTypeManager functionAndTypeManager;
    private final List<PlanOptimizer> runtimePlanOptimizers;
    private final WarningCollector warningCollector;
    private final PlanNodeIdAllocator idAllocator;
    private final VariableAllocator variableAllocator;
    private final SectionExecutionFactory sectionExecutionFactory;
    private final RemoteTaskFactory remoteTaskFactory;
    private final SplitSourceFactory splitSourceFactory;
    // IDs of the stages whose fragment was replaced by tryCostBasedOptimize().
    private final Set<StageId> runtimeOptimizedStages = Collections.synchronizedSet(new HashSet<>());
    private final PlanChecker planChecker;
    private final Metadata metadata;

    // Stage executions keyed by stage ID; filled in the constructor.
    private final Map<StageId, StageExecutionAndScheduler> stageExecutions = new ConcurrentHashMap<>();
    // Runs schedule(); see startScheduling().
    private final ExecutorService executor;
    // Flipped by the first call to start() so that scheduling is kicked off only once from there.
    private final AtomicBoolean started = new AtomicBoolean();
    // True while a thread is inside schedule(); used to avoid overlapping scheduling runs.
    private final AtomicBoolean scheduling = new AtomicBoolean();

    // Receives leaf tasks from schedule() when partial results are enabled for the session.
    private final PartialResultQueryTaskTracker partialResultQueryTaskTracker;
    // Told by initialize()'s listeners when a CTE table-finish stage reaches FINISHED;
    // also handed to the section execution factory when stage executions are created.
    private final CTEMaterializationTracker cteMaterializationTracker = new CTEMaterializationTracker();

    /**
     * Creates a fully wired scheduler for a query.
     * <p>
     * Construction happens in two steps: the instance is built first, and only once the
     * constructor has returned are the stage and query state listeners registered via
     * {@code initialize()}, so that the listeners never see a half-constructed scheduler.
     *
     * @return the initialized scheduler; scheduling itself does not begin until {@code start()} is called
     */
    public static SqlQueryScheduler createSqlQueryScheduler(
            LocationFactory locationFactory,
            ExecutionPolicy executionPolicy,
            ExecutorService queryExecutor,
            SplitSchedulerStats schedulerStats,
            SectionExecutionFactory sectionExecutionFactory,
            RemoteTaskFactory remoteTaskFactory,
            SplitSourceFactory splitSourceFactory,
            Session session,
            FunctionAndTypeManager functionAndTypeManager,
            QueryStateMachine queryStateMachine,
            SubPlan plan,
            OutputBuffers rootOutputBuffers,
            boolean summarizeTaskInfo,
            List<PlanOptimizer> runtimePlanOptimizers,
            WarningCollector warningCollector,
            PlanNodeIdAllocator idAllocator,
            VariableAllocator variableAllocator,
            PlanChecker planChecker,
            Metadata metadata,
            SqlParser sqlParser,
            PartialResultQueryManager partialResultQueryManager)
    {
        // note: the constructor takes summarizeTaskInfo before rootOutputBuffers
        SqlQueryScheduler scheduler = new SqlQueryScheduler(
                locationFactory, executionPolicy, queryExecutor, schedulerStats,
                sectionExecutionFactory, remoteTaskFactory, splitSourceFactory,
                session, functionAndTypeManager, queryStateMachine,
                plan, summarizeTaskInfo, rootOutputBuffers,
                runtimePlanOptimizers, warningCollector, idAllocator, variableAllocator,
                planChecker, metadata, sqlParser, partialResultQueryManager);

        scheduler.initialize();
        return scheduler;
    }

    /**
     * Stores the collaborators and eagerly creates the stage executions for every section of the plan.
     * <p>
     * No listeners are registered here; that is done by {@code initialize()}, which the static
     * factory calls after this constructor returns.
     *
     * @throws NullPointerException if any required collaborator is null
     */
    private SqlQueryScheduler(
            LocationFactory locationFactory,
            ExecutionPolicy executionPolicy,
            ExecutorService queryExecutor,
            SplitSchedulerStats schedulerStats,
            SectionExecutionFactory sectionExecutionFactory,
            RemoteTaskFactory remoteTaskFactory,
            SplitSourceFactory splitSourceFactory,
            Session session,
            FunctionAndTypeManager functionAndTypeManager,
            QueryStateMachine queryStateMachine,
            SubPlan plan,
            boolean summarizeTaskInfo,
            OutputBuffers rootOutputBuffers,
            List<PlanOptimizer> runtimePlanOptimizers,
            WarningCollector warningCollector,
            PlanNodeIdAllocator idAllocator,
            VariableAllocator variableAllocator,
            PlanChecker planChecker,
            Metadata metadata,
            SqlParser sqlParser,
            PartialResultQueryManager partialResultQueryManager)
    {
        this.locationFactory = requireNonNull(locationFactory, "locationFactory is null");
        this.executionPolicy = requireNonNull(executionPolicy, "executionPolicy is null");
        // fail fast here rather than with an anonymous NPE when the first scheduling run is submitted
        this.executor = requireNonNull(queryExecutor, "queryExecutor is null");
        this.schedulerStats = requireNonNull(schedulerStats, "schedulerStats is null");
        this.queryStateMachine = requireNonNull(queryStateMachine, "queryStateMachine is null");
        this.plan.compareAndSet(null, requireNonNull(plan, "plan is null"));
        this.session = requireNonNull(session, "session is null");
        this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
        this.runtimePlanOptimizers = requireNonNull(runtimePlanOptimizers, "runtimePlanOptimizers is null");
        this.warningCollector = requireNonNull(warningCollector, "warningCollector is null");
        this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
        this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator is null");
        this.planChecker = requireNonNull(planChecker, "planChecker is null");
        this.metadata = requireNonNull(metadata, "metadata is null");
        this.sectionExecutionFactory = requireNonNull(sectionExecutionFactory, "sectionExecutionFactory is null");
        this.remoteTaskFactory = requireNonNull(remoteTaskFactory, "remoteTaskFactory is null");
        this.splitSourceFactory = requireNonNull(splitSourceFactory, "splitSourceFactory is null");
        requireNonNull(rootOutputBuffers, "rootOutputBuffers is null");
        this.sectionedPlan = extractStreamingSections(plan);
        // must be assigned before createStageExecutions() is called below, which reads this field
        this.summarizeTaskInfo = summarizeTaskInfo;

        // the root output buffers are expected to contain exactly one buffer; getOnlyElement enforces that
        OutputBufferId rootBufferId = getOnlyElement(rootOutputBuffers.getBuffers().keySet());
        List<StageExecutionAndScheduler> createdStageExecutions = createStageExecutions(
                sectionExecutionFactory,
                // tasks of the root section publish their output buffer locations to the query state machine
                (fragmentId, tasks, noMoreExchangeLocations) -> updateQueryOutputLocations(queryStateMachine, rootBufferId, tasks, noMoreExchangeLocations),
                sectionedPlan,
                Optional.of(new int[1]),
                rootOutputBuffers,
                remoteTaskFactory,
                splitSourceFactory,
                session);

        // the last stage execution of the returned list is taken as the root (output) stage
        this.rootStageId = Iterables.getLast(createdStageExecutions).getStageExecution().getStageExecutionId().getStageId();

        for (StageExecutionAndScheduler execution : createdStageExecutions) {
            this.stageExecutions.put(execution.getStageExecution().getStageExecutionId().getStageId(), execution);
        }

        this.maxConcurrentMaterializations = getMaxConcurrentMaterializations(session);
        this.partialResultQueryTaskTracker = new PartialResultQueryTaskTracker(partialResultQueryManager, getPartialResultsCompletionRatioThreshold(session), getPartialResultsMaxExecutionTimeMultiplier(session), warningCollector);
    }

    // Listener registration lives outside the constructor so that `this` is never
    // handed to a listener while the object is still being constructed.
    private void initialize()
    {
        // the output stage drives the terminal transitions of the query
        SqlStageExecution outputStage = stageExecutions.get(rootStageId).getStageExecution();
        outputStage.addStateChangeListener(outputState -> {
            if (outputState == CANCELED) {
                // output stage was canceled
                queryStateMachine.transitionToCanceled();
            }
            else if (outputState == FINISHED) {
                queryStateMachine.transitionToFinishing();
            }
        });

        for (StageExecutionAndScheduler entry : stageExecutions.values()) {
            SqlStageExecution stage = entry.getStageExecution();

            // CTE table-finish stages report completion to the materialization tracker
            if (stage.isCTETableFinishStage()) {
                stage.addStateChangeListener(state -> {
                    if (state != StageExecutionState.FINISHED) {
                        return;
                    }
                    String cteName = stage.getCTEWriterId();
                    log.debug("CTE write completed for: " + cteName);
                    cteMaterializationTracker.markCTEAsMaterialized(cteName);
                });
            }

            // propagate stage state changes to the query state machine
            stage.addStateChangeListener(state -> {
                if (queryStateMachine.isDone()) {
                    return;
                }
                if (state == FAILED) {
                    queryStateMachine.transitionToFailed(stage.getStageExecutionInfo().getFailureCause().get().toException());
                    return;
                }
                if (state == ABORTED) {
                    // this should never happen, since abort can only be triggered in query clean up after the query is finished
                    queryStateMachine.transitionToFailed(new PrestoException(GENERIC_INTERNAL_ERROR, "Query stage was aborted"));
                    return;
                }
                if (state == FINISHED) {
                    // a finished stage may unblock further sections; kick off scheduling to pick them up
                    startScheduling();
                    return;
                }
                // the query is running as soon as some stage has at least one task
                if (queryStateMachine.getQueryState() == QueryState.STARTING && stage.hasTasks()) {
                    queryStateMachine.transitionToRunning();
                }
            });

            stage.addFinalStageInfoListener(finalInfo -> queryStateMachine.updateQueryInfo(Optional.of(getStageInfo())));
        }

        // once the query is done, push the latest stage info so the final query info can be produced
        queryStateMachine.addStateChangeListener(queryState -> {
            if (queryState.isDone()) {
                queryStateMachine.updateQueryInfo(Optional.of(getStageInfo()));
            }
        });
    }

    /**
     * Publishes the result locations of the given root-stage tasks to the query state machine.
     * Each task is keyed by the URI of its {@code rootBufferId} output buffer.
     */
    private static void updateQueryOutputLocations(QueryStateMachine queryStateMachine, OutputBufferId rootBufferId, Set<RemoteTask> tasks, boolean noMoreExchangeLocations)
    {
        Map<URI, TaskId> taskIdByLocation = tasks.stream()
                .collect(toImmutableMap(remoteTask -> getBufferLocation(remoteTask, rootBufferId), remoteTask -> remoteTask.getTaskId()));

        queryStateMachine.updateOutputLocations(taskIdByLocation, noMoreExchangeLocations);
    }

    /**
     * Builds the URI of a task's output buffer: the task's own URI (as reported by its status)
     * with the path segments {@code results} and the buffer ID appended.
     *
     * @param remoteTask task whose self URI is used as the base
     * @param rootBufferId ID of the output buffer to address
     * @return the buffer location
     * @throws RuntimeException wrapping the {@link URISyntaxException} if the URI cannot be built
     */
    private static URI getBufferLocation(RemoteTask remoteTask, OutputBufferId rootBufferId)
    {
        URI location = remoteTask.getTaskStatus().getSelf();
        try {
            URIBuilder builder = new URIBuilder(location);
            // Copy before appending: the list handed out by getPathSegments() is not documented to be
            // mutable (it may be an immutable empty list when the URI has no path segments).
            List<String> segments = new ArrayList<>(builder.getPathSegments());
            segments.add("results");
            segments.add(rootBufferId.toString());
            builder.setPathSegments(segments);
            return builder.build();
        }
        catch (URISyntaxException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Creates the stage executions for {@code section} and, recursively, for all of its child sections.
     * <p>
     * The stages of the child sections are listed first, followed by the stages of {@code section}
     * itself, i.e. the sections appear in postorder.
     */
    private List<StageExecutionAndScheduler> createStageExecutions(
            SectionExecutionFactory sectionExecutionFactory,
            ExchangeLocationsConsumer locationsConsumer,
            StreamingPlanSection section,
            Optional<int[]> bucketToPartition,
            OutputBuffers outputBuffers,
            RemoteTaskFactory remoteTaskFactory,
            SplitSourceFactory splitSourceFactory,
            Session session)
    {
        ImmutableList.Builder<StageExecutionAndScheduler> postOrder = ImmutableList.builder();

        // child sections get a no-op locations consumer, no bucket-to-partition mapping and discarding output buffers
        ExchangeLocationsConsumer ignoreLocations = (fragmentId, tasks, noMoreExchangeLocations) -> {};
        for (StreamingPlanSection child : section.getChildren()) {
            List<StageExecutionAndScheduler> childStages = createStageExecutions(
                    sectionExecutionFactory,
                    ignoreLocations,
                    child,
                    Optional.empty(),
                    createDiscardingOutputBuffers(),
                    remoteTaskFactory,
                    splitSourceFactory,
                    session);
            postOrder.addAll(childStages);
        }

        postOrder.addAll(sectionExecutionFactory
                .createSectionExecutions(
                        session,
                        section,
                        locationsConsumer,
                        bucketToPartition,
                        outputBuffers,
                        summarizeTaskInfo,
                        remoteTaskFactory,
                        splitSourceFactory,
                        0,
                        cteMaterializationTracker)
                .getSectionStages());

        return postOrder.build();
    }

    /**
     * Kicks off scheduling. Only the first invocation has an effect; later calls return immediately.
     */
    public void start()
    {
        boolean firstCall = started.compareAndSet(false, true);
        if (!firstCall) {
            return;
        }
        startScheduling();
    }

    /**
     * Submits a scheduling run to the executor unless one is already in progress.
     * The check here is only a fast path; {@code schedule()} re-checks the flag atomically.
     */
    private void startScheduling()
    {
        requireNonNull(stageExecutions);
        boolean alreadyScheduling = scheduling.get();
        if (!alreadyScheduling) {
            executor.submit(this::schedule);
        }
        // otherwise: still scheduling the previous batch of stages
    }

    /**
     * Runs one scheduling pass on the calling thread.
     * <p>
     * At most one pass runs at a time: the {@code scheduling} flag is claimed atomically on entry and
     * released when the pass ends (normally or with an error). A pass repeatedly pulls the sections that
     * are ready for execution, builds an {@link ExecutionSchedule} for each, and drives their stage
     * schedulers until there is neither a ready section nor an unfinished schedule left, or the thread
     * is interrupted.
     * <p>
     * Any {@link Throwable} raised by the pass fails the query and is rethrown. Regardless of outcome,
     * the stage scheduler of every stage touched by this pass is closed at the end.
     */
    private void schedule()
    {
        // claim the scheduling flag; bail out if another thread already holds it
        if (!scheduling.compareAndSet(false, true)) {
            // still scheduling the previous batch of stages
            return;
        }

        // every stage execution handed to a schedule during this pass; used for linkage updates,
        // the final state check, and closing the stage schedulers in the finally block
        List<StageExecutionAndScheduler> scheduledStageExecutions = new ArrayList<>();

        try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) {
            // stages whose completion has already been propagated through their stage linkage
            Set<StageId> completedStages = new HashSet<>();

            // one schedule per section that is currently being executed
            List<ExecutionSchedule> sectionExecutionSchedules = new LinkedList<>();

            while (!Thread.currentThread().isInterrupted()) {
                // remove finished section
                sectionExecutionSchedules.removeIf(ExecutionSchedule::isFinished);

                // try to pull more section that are ready to be run
                List<StreamingPlanSection> sectionsReadyForExecution = getSectionsReadyForExecution();

                // all finished
                if (sectionsReadyForExecution.isEmpty() && sectionExecutionSchedules.isEmpty()) {
                    break;
                }

                // register the stages of the newly ready sections and create a schedule per section
                List<List<StageExecutionAndScheduler>> sectionStageExecutions = getStageExecutions(sectionsReadyForExecution);
                sectionStageExecutions.forEach(scheduledStageExecutions::addAll);
                sectionStageExecutions.stream()
                        .map(executionInfos -> executionInfos.stream()
                                .collect(toImmutableList()))
                        .map(stages -> executionPolicy.createExecutionSchedule(session, stages))
                        .forEach(sectionExecutionSchedules::add);

                // keep scheduling until some section schedule finishes (or a stage completes, see the break below)
                while (sectionExecutionSchedules.stream().noneMatch(ExecutionSchedule::isFinished)) {
                    // futures of the stages that could not make progress in this iteration
                    List<ListenableFuture<?>> blockedStages = new ArrayList<>();

                    List<StageExecutionAndScheduler> executionsToSchedule = sectionExecutionSchedules.stream()
                            .flatMap(schedule -> schedule.getStagesToSchedule().stream())
                            .collect(toImmutableList());

                    for (StageExecutionAndScheduler stageExecutionAndScheduler : executionsToSchedule) {
                        // CPU and wall time spent scheduling this stage are reported at the end of the iteration
                        long startCpuNanos = THREAD_MX_BEAN.getCurrentThreadCpuTime();
                        long startWallNanos = System.nanoTime();

                        SqlStageExecution stageExecution = stageExecutionAndScheduler.getStageExecution();
                        stageExecution.beginScheduling();

                        // perform some scheduling work
                        ScheduleResult result = stageExecutionAndScheduler.getStageScheduler()
                                .schedule();

                        // Track leaf tasks if partial results are enabled
                        if (isPartialResultsEnabled(session) && stageExecutionAndScheduler.getStageExecution().getFragment().isLeaf()) {
                            for (RemoteTask task : result.getNewTasks()) {
                                partialResultQueryTaskTracker.trackTask(task);
                                task.addFinalTaskInfoListener(partialResultQueryTaskTracker::recordTaskFinish);
                            }
                        }

                        // modify parent and children based on the results of the scheduling
                        if (result.isFinished()) {
                            stageExecution.schedulingComplete();
                        }
                        else if (!result.getBlocked().isDone()) {
                            blockedStages.add(result.getBlocked());
                        }
                        stageExecutionAndScheduler.getStageLinkage()
                                .processScheduleResults(stageExecution.getState(), result.getNewTasks());
                        schedulerStats.getSplitsScheduledPerIteration().add(result.getSplitsScheduled());
                        if (result.getBlockedReason().isPresent()) {
                            // bump the counter matching the reason; WRITER_SCALING has no counter
                            ScheduleResult.BlockedReason blockedReason = result.getBlockedReason().get();
                            switch (blockedReason) {
                                case WRITER_SCALING:
                                    break;
                                case WAITING_FOR_CTE_MATERIALIZATION:
                                    schedulerStats.getWaitingForCTEMaterialization().update(1);
                                    break;
                                case WAITING_FOR_SOURCE:
                                    schedulerStats.getWaitingForSource().update(1);
                                    break;
                                case SPLIT_QUEUES_FULL:
                                    schedulerStats.getSplitQueuesFull().update(1);
                                    break;
                                case MIXED_SPLIT_QUEUES_FULL_AND_WAITING_FOR_SOURCE:
                                    schedulerStats.getMixedSplitQueuesFullAndWaitingForSource().update(1);
                                    break;
                                case NO_ACTIVE_DRIVER_GROUP:
                                    schedulerStats.getNoActiveDriverGroup().update(1);
                                    break;
                                default:
                                    throw new UnsupportedOperationException("Unknown blocked reason: " + blockedReason);
                            }
                            // record how long the stage stays blocked, measured when the blocked future completes
                            if (!result.getBlocked().isDone()) {
                                long startBlockedNanos = System.nanoTime();
                                result.getBlocked().addListener(
                                        () -> stageExecution.recordSchedulerBlockedTime(blockedReason, System.nanoTime() - startBlockedNanos),
                                        directExecutor());
                            }
                        }

                        stageExecution.recordSchedulerRunningTime(
                                THREAD_MX_BEAN.getCurrentThreadCpuTime() - startCpuNanos,
                                System.nanoTime() - startWallNanos);
                    }

                    // make sure to update stage linkage at least once per loop to catch async state changes (e.g., partial cancel)
                    boolean stageFinishedExecution = false;
                    for (StageExecutionAndScheduler stageExecutionInfo : scheduledStageExecutions) {
                        SqlStageExecution stageExecution = stageExecutionInfo.getStageExecution();
                        StageId stageId = stageExecution.getStageExecutionId().getStageId();
                        // each done stage is propagated exactly once per pass
                        if (!completedStages.contains(stageId) && stageExecution.getState().isDone()) {
                            stageExecutionInfo.getStageLinkage()
                                    .processScheduleResults(stageExecution.getState(), ImmutableSet.of());
                            completedStages.add(stageId);
                            stageFinishedExecution = true;
                        }
                    }

                    // if any stage has just finished execution try to pull more sections for scheduling
                    if (stageFinishedExecution) {
                        break;
                    }

                    // wait for a state change and then schedule again
                    if (!blockedStages.isEmpty()) {
                        // wait at most one second for any blocked stage to unblock; the wait is counted as sleep time
                        try (TimeStat.BlockTimer timer = schedulerStats.getSleepTime().time()) {
                            tryGetFutureValue(whenAnyComplete(blockedStages), 1, SECONDS);
                        }
                        // the futures are not reused: the next iteration asks each stage scheduler again
                        for (ListenableFuture<?> blockedStage : blockedStages) {
                            blockedStage.cancel(true);
                        }
                    }
                }
            }

            // sanity check: after the loop every scheduled stage must be SCHEDULED, RUNNING or done
            for (StageExecutionAndScheduler stageExecutionInfo : scheduledStageExecutions) {
                StageExecutionState state = stageExecutionInfo.getStageExecution().getState();
                if (state != SCHEDULED && state != RUNNING && !state.isDone()) {
                    throw new PrestoException(GENERIC_INTERNAL_ERROR, format("Scheduling is complete, but stage execution %s is in state %s", stageExecutionInfo.getStageExecution().getStageExecutionId(), state));
                }
            }

            // release the flag before the re-check below so that startScheduling() can submit a new pass
            scheduling.set(false);

            // Inform the tracker that task scheduling has completed
            partialResultQueryTaskTracker.completeTaskScheduling();

            // sections may have become ready while this pass was winding down
            if (!getSectionsReadyForExecution().isEmpty()) {
                startScheduling();
            }
        }
        catch (Throwable t) {
            // release the flag, fail the query, and let the error propagate to the caller
            scheduling.set(false);
            queryStateMachine.transitionToFailed(t);
            throw t;
        }
        finally {
            // close every stage scheduler used by this pass; close failures fail the query and are
            // collected as suppressed exceptions of a single RuntimeException
            RuntimeException closeError = new RuntimeException();
            for (StageExecutionAndScheduler stageExecutionInfo : scheduledStageExecutions) {
                try {
                    stageExecutionInfo.getStageScheduler().close();
                }
                catch (Throwable t) {
                    queryStateMachine.transitionToFailed(t);
                    // Self-suppression not permitted
                    // NOTE(review): closeError is created inside this finally block, so it cannot be the
                    // throwable caught here; the identity check below always passes.
                    if (closeError != t) {
                        closeError.addSuppressed(t);
                    }
                }
            }
            // NOTE(review): throwing from finally replaces any exception already propagating from the try/catch above
            if (closeError.getSuppressed().length > 0) {
                throw closeError;
            }
        }
    }

    /**
     * Returns the plan sections that can start executing now, after passing each of them through
     * {@code tryCostBasedOptimize}.
     * <p>
     * Sections are visited in depth-first pre-order. Unless enhanced CTE scheduling is enabled for the
     * session, the number of returned sections is capped at {@code maxConcurrentMaterializations}
     * minus the number of sections whose root stage has left {@code PLANNED} but is not done yet.
     *
     * @return the ready sections, possibly empty
     */
    private List<StreamingPlanSection> getSectionsReadyForExecution()
    {
        // sections whose root stage is past PLANNED and not yet done count as running
        long runningPlanSections =
                stream(forTree(StreamingPlanSection::getChildren).depthFirstPreOrder(sectionedPlan))
                        .map(section -> getStageExecution(section.getPlan().getFragment().getId()).getState())
                        .filter(state -> !state.isDone() && state != PLANNED)
                        .count();

        // for enhanced cte blocking we do not need a limit on the sections.
        // Otherwise only the free slots may be used; the value is clamped at zero because
        // Stream.limit throws IllegalArgumentException for a negative size.
        long availableSlots = isEnhancedCTESchedulingEnabled(session)
                ? Long.MAX_VALUE
                : Math.max(0, maxConcurrentMaterializations - runningPlanSections);

        return stream(forTree(StreamingPlanSection::getChildren).depthFirstPreOrder(sectionedPlan))
                // get all sections ready for execution
                .filter(this::isReadyForExecution)
                .limit(availableSlots)
                .map(this::tryCostBasedOptimize)
                .collect(toImmutableList());
    }

    /**
     * General purpose hook that runs the runtime cost-based optimizers over a section before it is scheduled.
     * (at the moment there is a single such optimizer, which decides whether the probe and build side of a
     * JoinNode should be swapped based on the statistics of the temporary table holding materialized
     * exchange outputs from finished children sections)
     * <p>
     * When at least one fragment changes, the affected stage IDs are recorded, and the plan and the
     * stage executions are updated. The passed-in section object is returned in every case.
     */
    private StreamingPlanSection tryCostBasedOptimize(StreamingPlanSection section)
    {
        // a section without children uses no materialized exchange data, so runtime optimization has nothing to work with
        boolean worthOptimizing = isRuntimeOptimizerEnabled(session) && !section.getChildren().isEmpty();
        if (!worthOptimizing) {
            return section;
        }

        // run the runtime optimizers over every sub plan of the section and keep the fragments that changed
        Map<PlanFragment, PlanFragment> replacements = new HashMap<>();
        for (StreamingSubPlan subPlan : forTree(StreamingSubPlan::getChildren).depthFirstPreOrder(section.getPlan())) {
            Optional<PlanFragment> optimized = performRuntimeOptimizations(subPlan);
            if (!optimized.isPresent()) {
                continue;
            }
            PlanFragment replacement = optimized.get();
            planChecker.validatePlanFragment(replacement, session, metadata, warningCollector);
            replacements.put(subPlan.getFragment(), replacement);
        }

        // nothing changed: leave plan and stage executions untouched
        if (replacements.isEmpty()) {
            return section;
        }

        for (PlanFragment original : replacements.keySet()) {
            runtimeOptimizedStages.add(getStageId(original.getId()));
        }

        // Keep the SubPlan current so that getStageInfo reports the optimized plan once the query finishes.
        updatePlan(replacements);

        // Rebuild the affected entries of the stageExecutions map.
        updateStageExecutions(section, replacements);

        String optimizedFragmentIds = replacements.keySet().stream()
                .map(PlanFragment::getId)
                .map(PlanFragmentId::toString)
                .collect(Collectors.joining(", "));
        log.debug("Invoked CBO during runtime, optimized stage IDs: " + optimizedFragmentIds);
        return section;
    }

    /**
     * Runs every runtime plan optimizer, in order, over the root of the given sub plan's fragment.
     *
     * @param subPlan the sub plan whose fragment should be optimized
     * @return a new fragment wrapping the optimized root, or empty when the optimizers returned the very same root instance
     */
    private Optional<PlanFragment> performRuntimeOptimizations(StreamingSubPlan subPlan)
    {
        PlanFragment fragment = subPlan.getFragment();

        // Each optimizer consumes the output of the previous one
        PlanNode optimizedRoot = fragment.getRoot();
        for (PlanOptimizer optimizer : runtimePlanOptimizers) {
            optimizedRoot = optimizer.optimize(optimizedRoot, session, TypeProvider.viewOf(variableAllocator.getVariables()), variableAllocator, idAllocator, warningCollector).getPlanNode();
        }

        // Reference comparison: an untouched plan comes back as the same instance
        if (optimizedRoot == fragment.getRoot()) {
            return Optional.empty();
        }

        Optional<StatsAndCosts> estimatedStatsAndCosts = fragment.getStatsAndCosts();
        // The partitioningScheme is carried over unchanged,
        // even if the output variable layout of the root has changed.
        PlanFragment optimizedFragment = new PlanFragment(
                fragment.getId(),
                optimizedRoot,
                fragment.getVariables(),
                fragment.getPartitioning(),
                scheduleOrder(optimizedRoot),
                fragment.getPartitioningScheme(),
                fragment.getStageExecutionDescriptor(),
                fragment.isOutputTableWriterFragment(),
                estimatedStatsAndCosts,
                Optional.of(jsonFragmentPlan(optimizedRoot, fragment.getVariables(), estimatedStatsAndCosts.orElse(StatsAndCosts.empty()), functionAndTypeManager, session)));
        return Optional.of(optimizedFragment);
    }

    /**
     * Rebuilds the given section with the optimized fragments, creates fresh stage executions and schedulers
     * for all of its stages, registers the state change listeners on them and finally publishes the new
     * entries in the stageExecutions map.
     *
     * @param section the section whose stages must be re-created
     * @param oldToNewFragment mapping from each original fragment to its optimized replacement
     */
    private void updateStageExecutions(StreamingPlanSection section, Map<PlanFragment, PlanFragment> oldToNewFragment)
    {
        StreamingSubPlan rewrittenPlan = rewriteStreamingSubPlan(section.getPlan(), oldToNewFragment);
        StreamingPlanSection rewrittenSection = new StreamingPlanSection(rewrittenPlan, section.getChildren());
        PlanFragment rootFragment = rewrittenSection.getPlan().getFragment();

        Optional<int[]> bucketToPartition;
        OutputBuffers outputBuffers;
        ExchangeLocationsConsumer locationsConsumer;
        if (!isRootFragment(rootFragment)) {
            // Section does not contain the query root: discarding output buffers and a no-op locations consumer
            bucketToPartition = Optional.empty();
            outputBuffers = createDiscardingOutputBuffers();
            locationsConsumer = (fragmentId, tasks, noMoreExchangeLocations) -> {};
        }
        else {
            // Section contains the query root: a single output buffer whose locations are reported as query output
            bucketToPartition = Optional.of(new int[1]);
            outputBuffers = createInitialEmptyOutputBuffers(rootFragment.getPartitioningScheme().getPartitioning().getHandle())
                    .withBuffer(new OutputBufferId(0), BROADCAST_PARTITION_ID)
                    .withNoMoreBufferIds();
            OutputBufferId rootBufferId = getOnlyElement(outputBuffers.getBuffers().keySet());
            locationsConsumer = (fragmentId, tasks, noMoreExchangeLocations) ->
                    updateQueryOutputLocations(queryStateMachine, rootBufferId, tasks, noMoreExchangeLocations);
        }

        SectionExecution rebuiltExecution = sectionExecutionFactory.createSectionExecutions(
                session,
                rewrittenSection,
                locationsConsumer,
                bucketToPartition,
                outputBuffers,
                summarizeTaskInfo,
                remoteTaskFactory,
                splitSourceFactory,
                0,
                cteMaterializationTracker);
        addStateChangeListeners(rebuiltExecution);

        // Index the freshly created stages by stage id
        Map<StageId, StageExecutionAndScheduler> rebuiltStages = rebuiltExecution.getSectionStages().stream()
                .collect(toImmutableMap(stage -> stage.getStageExecution().getStageExecutionId().getStageId(), identity()));
        synchronized (this) {
            stageExecutions.putAll(rebuiltStages);
        }
    }

    /**
     * Replaces the current plan with a copy in which every fragment found in the mapping is swapped for its replacement.
     *
     * @param oldToNewFragments mapping from each original fragment to its optimized replacement
     */
    private void updatePlan(Map<PlanFragment, PlanFragment> oldToNewFragments)
    {
        plan.getAndUpdate(currentPlan -> rewritePlan(currentPlan, oldToNewFragments));
    }

    /**
     * Recursively copies the SubPlan tree, substituting every fragment that has an entry in the mapping.
     *
     * @param root the root of the (sub)tree to rewrite
     * @param oldToNewFragments mapping from each original fragment to its replacement
     * @return a new SubPlan tree; fragments without a mapping entry are reused as they are
     */
    private SubPlan rewritePlan(SubPlan root, Map<PlanFragment, PlanFragment> oldToNewFragments)
    {
        // Children are rewritten first, then the fragment of this node is resolved
        List<SubPlan> rewrittenChildren = root.getChildren().stream()
                .map(child -> rewritePlan(child, oldToNewFragments))
                .collect(toImmutableList());
        PlanFragment fragment = oldToNewFragments.getOrDefault(root.getFragment(), root.getFragment());
        return new SubPlan(fragment, rewrittenChildren);
    }

    /**
     * Registers state change listeners and a final stage info listener on every stage of the given section execution.
     * Only used for adaptive optimization, to register listeners on new stageExecutions generated at runtime.
     *
     * @param sectionExecution the section execution whose stages should report state changes to the query state machine
     */
    private void addStateChangeListeners(SectionExecution sectionExecution)
    {
        for (StageExecutionAndScheduler stageExecutionAndScheduler : sectionExecution.getSectionStages()) {
            SqlStageExecution stageExecution = stageExecutionAndScheduler.getStageExecution();
            // The stage running the root fragment additionally drives the query towards FINISHING / CANCELED
            if (isRootFragment(stageExecution.getFragment())) {
                stageExecution.addStateChangeListener(state -> {
                    if (state == FINISHED) {
                        queryStateMachine.transitionToFinishing();
                    }
                    else if (state == CANCELED) {
                        // output stage was canceled
                        queryStateMachine.transitionToCanceled();
                    }
                });
            }
            // Listener registered on every stage (root included)
            stageExecution.addStateChangeListener(state -> {
                // nothing to propagate once the query itself is done
                if (queryStateMachine.isDone()) {
                    return;
                }
                if (state == FAILED) {
                    // NOTE(review): assumes a failure cause is always present for a FAILED stage; get() throws otherwise - confirm
                    queryStateMachine.transitionToFailed(stageExecution.getStageExecutionInfo().getFailureCause().get().toException());
                }
                else if (state == ABORTED) {
                    // this should never happen, since abort can only be triggered in query clean up after the query is finished
                    queryStateMachine.transitionToFailed(new PrestoException(GENERIC_INTERNAL_ERROR, "Query stage was aborted"));
                }
                else if (state == FINISHED) {
                    // checks if there's any new sections available for execution and starts the scheduling if any
                    startScheduling();
                }
                else if (queryStateMachine.getQueryState() == QueryState.STARTING) {
                    // if the stage has at least one task, we are running
                    if (stageExecution.hasTasks()) {
                        queryStateMachine.transitionToRunning();
                    }
                }
            });
            // Refresh the query info with the latest stage info tree once the final info of this stage is available
            stageExecution.addFinalStageInfoListener(status -> queryStateMachine.updateQueryInfo(Optional.of(getStageInfo())));
        }
    }

    /**
     * Recursively copies the StreamingSubPlan tree, substituting every fragment that has an entry in the mapping.
     *
     * @param root the root of the (sub)tree to rewrite
     * @param oldToNewFragment mapping from each original fragment to its replacement
     * @return a new StreamingSubPlan tree; fragments without a mapping entry are reused as they are
     */
    private StreamingSubPlan rewriteStreamingSubPlan(StreamingSubPlan root, Map<PlanFragment, PlanFragment> oldToNewFragment)
    {
        // Rewrite the children before resolving the fragment of this node
        List<StreamingSubPlan> rewrittenChildren = root.getChildren().stream()
                .map(child -> rewriteStreamingSubPlan(child, oldToNewFragment))
                .collect(toImmutableList());
        PlanFragment fragment = oldToNewFragment.containsKey(root.getFragment())
                ? oldToNewFragment.get(root.getFragment())
                : root.getFragment();
        return new StreamingSubPlan(fragment, rewrittenChildren);
    }

    /**
     * Tells whether the given fragment carries the id reserved for the root fragment.
     *
     * @param fragment the fragment to inspect
     * @return true when the fragment id equals ROOT_FRAGMENT_ID
     */
    private static boolean isRootFragment(PlanFragment fragment)
    {
        return ROOT_FRAGMENT_ID == fragment.getId().getId();
    }

    /**
     * Decides whether a section can be handed to the scheduler.
     *
     * @param section the section to check
     * @return true when the section's root stage is still PLANNED and, unless enhanced CTE scheduling is enabled,
     * the root stages of all child sections are FINISHED
     */
    private boolean isReadyForExecution(StreamingPlanSection section)
    {
        PlanFragmentId sectionRootId = section.getPlan().getFragment().getId();
        if (getStageExecution(sectionRootId).getState() != PLANNED) {
            // already scheduled
            return false;
        }
        if (isEnhancedCTESchedulingEnabled(session)) {
            // with enhanced cte blocking there is no need to wait for the child sections
            return true;
        }
        // Enhanced cte blocking is not enabled so block till child sections are complete
        return section.getChildren().stream()
                .map(child -> getStageExecution(child.getPlan().getFragment().getId()))
                .allMatch(childRootStage -> childRootStage.getState() == FINISHED);
    }

    /**
     * Collects, for each section, the stage execution entries of all its sub plans in depth-first pre-order.
     *
     * @param sections the sections to resolve
     * @return one list of stage execution entries per section, in the order of the given sections
     */
    private List<List<StageExecutionAndScheduler>> getStageExecutions(List<StreamingPlanSection> sections)
    {
        ImmutableList.Builder<List<StageExecutionAndScheduler>> stagesPerSection = ImmutableList.builder();
        for (StreamingPlanSection section : sections) {
            ImmutableList.Builder<StageExecutionAndScheduler> sectionStages = ImmutableList.builder();
            for (StreamingSubPlan subPlan : forTree(StreamingSubPlan::getChildren).depthFirstPreOrder(section.getPlan())) {
                sectionStages.add(getStageExecutionInfo(subPlan.getFragment().getId()));
            }
            stagesPerSection.add(sectionStages.build());
        }
        return stagesPerSection.build();
    }

    /**
     * Looks up the stage execution registered for the given plan fragment.
     *
     * @param planFragmentId id of the fragment whose stage execution is requested
     * @return the stage execution of the matching entry in the stageExecutions map
     */
    private SqlStageExecution getStageExecution(PlanFragmentId planFragmentId)
    {
        StageExecutionAndScheduler stage = getStageExecutionInfo(planFragmentId);
        return stage.getStageExecution();
    }

    /**
     * Looks up the stage execution entry registered for the given plan fragment.
     *
     * @param planFragmentId id of the fragment whose entry is requested
     * @return the entry stored in the stageExecutions map under the corresponding stage id
     */
    private StageExecutionAndScheduler getStageExecutionInfo(PlanFragmentId planFragmentId)
    {
        StageId stageId = getStageId(planFragmentId);
        return stageExecutions.get(stageId);
    }

    /**
     * Builds the stage id for a plan fragment of this query.
     *
     * @param fragmentId id of the plan fragment
     * @return a stage id combining the id of this query with the numeric id of the fragment
     */
    private StageId getStageId(PlanFragmentId fragmentId)
    {
        int stageNumber = fragmentId.getId();
        return new StageId(queryStateMachine.getQueryId(), stageNumber);
    }

    /**
     * @return the sum of the user memory reservations of all stage executions currently registered
     */
    public long getUserMemoryReservation()
    {
        long totalReservation = 0;
        for (StageExecutionAndScheduler stage : stageExecutions.values()) {
            totalReservation += stage.getStageExecution().getUserMemoryReservation();
        }
        return totalReservation;
    }

    /**
     * @return the sum of the total memory reservations of all stage executions currently registered
     */
    public long getTotalMemoryReservation()
    {
        return stageExecutions.values().stream()
                .map(StageExecutionAndScheduler::getStageExecution)
                .mapToLong(SqlStageExecution::getTotalMemoryReservation)
                .sum();
    }

    /**
     * @return the CPU time of all registered stage executions, summed up in milliseconds
     */
    public Duration getTotalCpuTime()
    {
        long totalMillis = 0;
        for (StageExecutionAndScheduler stage : stageExecutions.values()) {
            totalMillis += stage.getStageExecution().getTotalCpuTime().toMillis();
        }
        return new Duration(totalMillis, MILLISECONDS);
    }

    /**
     * @return the sum of the raw input data sizes reported by all registered stage executions
     */
    @Override
    public long getRawInputDataSizeInBytes()
    {
        long totalBytes = 0;
        for (StageExecutionAndScheduler stage : stageExecutions.values()) {
            totalBytes += stage.getStageExecution().getRawInputDataSize();
        }
        return totalBytes;
    }

    /**
     * @return the sum of the written intermediate data sizes reported by all registered stage executions
     */
    @Override
    public long getWrittenIntermediateDataSizeInBytes()
    {
        return stageExecutions.values().stream()
                .map(StageExecutionAndScheduler::getStageExecution)
                .mapToLong(SqlStageExecution::getWrittenIntermediateDataSize)
                .sum();
    }

    /**
     * @return the number of output positions taken from the stats of the root stage execution
     */
    @Override
    public long getOutputPositions()
    {
        SqlStageExecution rootStage = stageExecutions.get(rootStageId).getStageExecution();
        StageExecutionInfo rootStageInfo = rootStage.getStageExecutionInfo();
        return rootStageInfo.getStats().getOutputPositions();
    }

    /**
     * @return the output data size in bytes taken from the stats of the root stage execution
     */
    @Override
    public long getOutputDataSizeInBytes()
    {
        SqlStageExecution rootStage = stageExecutions.get(rootStageId).getStageExecution();
        StageExecutionInfo rootStageInfo = rootStage.getStageExecutionInfo();
        return rootStageInfo.getStats().getOutputDataSizeInBytes();
    }

    /**
     * @return the basic stats of all registered stage executions, combined by aggregateBasicStageStats
     */
    public BasicStageExecutionStats getBasicStageStats()
    {
        ImmutableList.Builder<BasicStageExecutionStats> collected = ImmutableList.builder();
        for (StageExecutionAndScheduler stage : stageExecutions.values()) {
            collected.add(stage.getStageExecution().getBasicStageStats());
        }
        List<BasicStageExecutionStats> stageStats = collected.build();

        return aggregateBasicStageStats(stageStats);
    }

    /**
     * Builds the StageInfo tree of the current plan from the execution info of all registered stage executions.
     *
     * @return the StageInfo of the plan root, with the child stages nested below it
     */
    public StageInfo getStageInfo()
    {
        // Capture the execution info of every registered stage, keyed by stage id
        Map<StageId, StageExecutionInfo> executionInfoByStage = stageExecutions.values().stream()
                .collect(toImmutableMap(
                        stage -> stage.getStageExecution().getStageExecutionId().getStageId(),
                        stage -> stage.getStageExecution().getStageExecutionInfo()));

        SubPlan currentPlan = plan.get();
        return buildStageInfo(currentPlan, executionInfoByStage);
    }

    /**
     * Recursively converts a SubPlan tree into a StageInfo tree.
     *
     * @param subPlan the (sub)plan to convert
     * @param stageExecutionInfos execution info of the stages, keyed by stage id
     * @return the StageInfo for the fragment of the given sub plan, including the StageInfo of all its children
     * @throws IllegalArgumentException if no execution info is available for the stage of this sub plan
     */
    private StageInfo buildStageInfo(SubPlan subPlan, Map<StageId, StageExecutionInfo> stageExecutionInfos)
    {
        PlanFragment fragment = subPlan.getFragment();
        StageId stageId = getStageId(fragment.getId());
        StageExecutionInfo executionInfo = stageExecutionInfos.get(stageId);
        checkArgument(executionInfo != null, "No stageExecutionInfo for %s", stageId);
        return new StageInfo(
                stageId,
                locationFactory.createStageLocation(stageId),
                Optional.of(fragment),
                executionInfo,
                ImmutableList.of(),
                subPlan.getChildren().stream()
                        .map(childPlan -> buildStageInfo(childPlan, stageExecutionInfos))
                        .collect(toImmutableList()),
                runtimeOptimizedStages.contains(stageId));
    }

    /**
     * Cancels the execution of the given stage.
     *
     * @param stageId id of the stage to cancel
     * @throws NullPointerException with the message "Stage ... does not exist" when no stage is registered under the given id
     */
    public void cancelStage(StageId stageId)
    {
        try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) {
            // Validate the map entry itself: dereferencing the lookup result before the null check
            // would fail with a message-less NullPointerException for an unknown stage id.
            StageExecutionAndScheduler stage = requireNonNull(stageExecutions.get(stageId), () -> format("Stage %s does not exist", stageId));
            stage.getStageExecution().cancel();
        }
    }

    /**
     * Aborts every stage execution currently registered with this scheduler.
     */
    public void abort()
    {
        try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) {
            for (StageExecutionAndScheduler stage : stageExecutions.values()) {
                stage.getStageExecution().abort();
            }
        }
    }
}