FixedSourcePartitionedScheduler.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.execution.scheduler;

import com.facebook.airlift.log.Logger;
import com.facebook.presto.execution.Lifespan;
import com.facebook.presto.execution.RemoteTask;
import com.facebook.presto.execution.SqlStageExecution;
import com.facebook.presto.execution.TaskId;
import com.facebook.presto.execution.scheduler.ScheduleResult.BlockedReason;
import com.facebook.presto.execution.scheduler.group.DynamicLifespanScheduler;
import com.facebook.presto.execution.scheduler.group.FixedLifespanScheduler;
import com.facebook.presto.execution.scheduler.group.LifespanScheduler;
import com.facebook.presto.execution.scheduler.nodeSelection.NodeSelector;
import com.facebook.presto.metadata.InternalNode;
import com.facebook.presto.metadata.Split;
import com.facebook.presto.spi.connector.ConnectorPartitionHandle;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.plan.StageExecutionDescriptor;
import com.facebook.presto.split.SplitSource;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Streams;
import com.google.common.util.concurrent.ListenableFuture;

import javax.annotation.concurrent.GuardedBy;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.Supplier;

import static com.facebook.airlift.concurrent.MoreFutures.whenAnyComplete;
import static com.facebook.presto.execution.scheduler.SourcePartitionedScheduler.newSourcePartitionedSchedulerAsSourceScheduler;
import static com.facebook.presto.spi.connector.NotPartitionedPartitionHandle.NOT_PARTITIONED;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;

/**
 * {@link StageScheduler} that creates one task per node of a fixed node list and then drives one
 * {@link SourceScheduler} per plan node to schedule splits for the stage.
 *
 * <p>The first {@link #schedule()} call that is not blocked on CTE materialization calls
 * {@code stage.scheduleTask} once for every node. Every call then runs the remaining source
 * schedulers in scheduling order; lifespans completely scheduled by one source scheduler are
 * started on the next one in the same pass.
 *
 * <p>When the stage uses grouped execution, a {@link LifespanScheduler} (dynamic or fixed, depending on
 * the {@link BucketNodeMap}) decides when lifespans are started on the first source scheduler.
 *
 * <p>Threading: {@link #close()} is synchronized on this instance and each per-source-scheduler step of
 * {@link #schedule()} runs under the same monitor. {@link #recover(TaskId)} only appends to a concurrent
 * queue. The remaining state is accessed without synchronization.
 */
public class FixedSourcePartitionedScheduler
        implements StageScheduler
{
    private static final Logger log = Logger.get(FixedSourcePartitionedScheduler.class);

    private final SqlStageExecution stage;
    // nodes that receive a task; the position in this list is passed to stage.scheduleTask as the task index
    private final List<InternalNode> nodes;

    // source schedulers that have not finished scheduling yet, in scheduling order
    private final List<SourceScheduler> sourceSchedulers;
    // indexed by Lifespan#getId() for lifespans that are not task-wide (see partitionHandleFor)
    private final List<ConnectorPartitionHandle> partitionHandles;
    // set once the tasks have been created by schedule()
    private boolean scheduledTasks;
    // set once any source scheduler reported that it finished; task recovery is rejected afterwards
    private boolean anySourceSchedulingFinished;
    // present only when the stage uses grouped execution
    private final Optional<LifespanScheduler> groupedLifespanScheduler;

    // ids of tasks handed to recover(); drained by schedule()
    private final Queue<Integer> tasksToRecover = new ConcurrentLinkedQueue<>();

    private final CTEMaterializationTracker cteMaterializationTracker;

    @GuardedBy("this")
    private boolean closed;

    /**
     * Creates the scheduler together with one source scheduler per entry of {@code schedulingOrder}.
     *
     * <p>The first source scheduler is started right away: with a task-wide lifespan for ungrouped
     * stages, or through {@link LifespanScheduler#scheduleInitial} for grouped stages.
     *
     * @param stage stage whose tasks and splits are scheduled
     * @param splitSources split source per plan node; its key set must equal the set of {@code schedulingOrder}
     * @param stageExecutionDescriptor tells whether the stage, and each scan node, uses grouped execution
     * @param schedulingOrder plan nodes in the order their source schedulers are run
     * @param nodes nodes to create tasks on; must not be empty
     * @param bucketNodeMap bucket to node mapping; a dynamic map selects {@link DynamicLifespanScheduler},
     *         otherwise {@link FixedLifespanScheduler} is used
     * @param splitBatchSize total split batch size; it is divided by the number of concurrent lifespans
     *         (with a floor of 1) before being handed to each source scheduler
     * @param concurrentLifespansPerTask optional number of lifespans per task; multiplied by the node count
     *         and used as the number of concurrent lifespans when that product does not exceed the number
     *         of partition handles
     * @param nodeSelector used by the split placement policy to compute split assignments
     * @param partitionHandles must be {@code [NOT_PARTITIONED]} if and only if the stage does not use
     *         grouped execution
     * @param cteMaterializationTracker source of the futures {@link #schedule()} waits on when the stage
     *         requires materialized CTEs
     * @throws NullPointerException if a null-checked argument is null
     * @throws IllegalArgumentException if {@code nodes} is empty, the plan node sets differ, or
     *         {@code partitionHandles} does not match the execution strategy
     */
    public FixedSourcePartitionedScheduler(
            SqlStageExecution stage,
            Map<PlanNodeId, SplitSource> splitSources,
            StageExecutionDescriptor stageExecutionDescriptor,
            List<PlanNodeId> schedulingOrder,
            List<InternalNode> nodes,
            BucketNodeMap bucketNodeMap,
            int splitBatchSize,
            OptionalInt concurrentLifespansPerTask,
            NodeSelector nodeSelector,
            List<ConnectorPartitionHandle> partitionHandles,
            CTEMaterializationTracker cteMaterializationTracker)
    {
        requireNonNull(stage, "stage is null");
        requireNonNull(splitSources, "splitSources is null");
        requireNonNull(bucketNodeMap, "bucketNodeMap is null");
        checkArgument(!requireNonNull(nodes, "nodes is null").isEmpty(), "nodes is empty");
        requireNonNull(partitionHandles, "partitionHandles is null");
        // NOTE(review): unlike the other arguments this one is not null-checked. It is only dereferenced
        // in schedule() when the stage requires materialized CTEs; confirm whether null is a legal value
        // for other stages before tightening this.
        this.cteMaterializationTracker = cteMaterializationTracker;

        this.stage = stage;
        this.nodes = ImmutableList.copyOf(nodes);
        this.partitionHandles = ImmutableList.copyOf(partitionHandles);

        checkArgument(splitSources.keySet().equals(ImmutableSet.copyOf(schedulingOrder)));

        BucketedSplitPlacementPolicy splitPlacementPolicy = new BucketedSplitPlacementPolicy(nodeSelector, nodes, bucketNodeMap, stage::getAllTasks);

        List<SourceScheduler> schedulers = new ArrayList<>();
        checkArgument(
                partitionHandles.equals(ImmutableList.of(NOT_PARTITIONED)) != stageExecutionDescriptor.isStageGroupedExecution(),
                "PartitionHandles should be [NOT_PARTITIONED] if and only if all scan nodes use ungrouped execution strategy");
        int nodeCount = nodes.size();
        int concurrentLifespans;
        if (concurrentLifespansPerTask.isPresent() && concurrentLifespansPerTask.getAsInt() * nodeCount <= partitionHandles.size()) {
            concurrentLifespans = concurrentLifespansPerTask.getAsInt() * nodeCount;
        }
        else {
            concurrentLifespans = partitionHandles.size();
        }

        // The batch size is the same for every source scheduler, so compute it once.
        // concurrentLifespans is 0 when concurrentLifespansPerTask is present with value 0 or when there
        // are no partition handles; dividing by it would fail with an ArithmeticException before any more
        // specific validation can run. Only the zero divisor is replaced, so every other input yields the
        // same value as before.
        int batchSizeDivisor = (concurrentLifespans == 0) ? 1 : concurrentLifespans;
        int splitBatchSizePerScheduler = Math.max(splitBatchSize / batchSizeDivisor, 1);

        boolean firstPlanNode = true;
        Optional<LifespanScheduler> groupedLifespanScheduler = Optional.empty();
        for (PlanNodeId planNodeId : schedulingOrder) {
            SplitSource splitSource = splitSources.get(planNodeId);
            boolean groupedExecutionForScanNode = stageExecutionDescriptor.isScanGroupedExecution(planNodeId);
            SourceScheduler sourceScheduler = newSourcePartitionedSchedulerAsSourceScheduler(
                    stage,
                    planNodeId,
                    splitSource,
                    splitPlacementPolicy,
                    splitBatchSizePerScheduler,
                    groupedExecutionForScanNode);

            // an ungrouped scan inside a grouped stage is adapted so that it can be driven with lifespans
            if (stageExecutionDescriptor.isStageGroupedExecution() && !groupedExecutionForScanNode) {
                sourceScheduler = new AsGroupedSourceScheduler(sourceScheduler);
            }
            schedulers.add(sourceScheduler);

            if (firstPlanNode) {
                firstPlanNode = false;
                if (!stageExecutionDescriptor.isStageGroupedExecution()) {
                    sourceScheduler.startLifespan(Lifespan.taskWide(), NOT_PARTITIONED);
                }
                else {
                    LifespanScheduler lifespanScheduler;
                    if (bucketNodeMap.isDynamic()) {
                        // Caller of the constructor guarantees dynamic bucket node map will only be
                        // used when the stage has no non-replicated remote sources and all scans use grouped
                        // execution.
                        lifespanScheduler = new DynamicLifespanScheduler(bucketNodeMap, nodes, partitionHandles, concurrentLifespansPerTask);
                    }
                    else {
                        lifespanScheduler = new FixedLifespanScheduler(bucketNodeMap, partitionHandles, concurrentLifespansPerTask);
                    }

                    // Schedule the first few lifespans
                    lifespanScheduler.scheduleInitial(sourceScheduler);
                    // Schedule new lifespans for finished ones
                    stage.addCompletedDriverGroupsChangedListener(lifespanScheduler::onLifespanExecutionFinished);
                    groupedLifespanScheduler = Optional.of(lifespanScheduler);
                }
            }
        }
        this.groupedLifespanScheduler = groupedLifespanScheduler;

        // use a CopyOnWriteArrayList to prevent ConcurrentModificationExceptions
        // if close() is called while the main thread is in the scheduling loop
        this.sourceSchedulers = new CopyOnWriteArrayList<>(schedulers);
    }

    /**
     * Returns the partition handle for a lifespan: {@code NOT_PARTITIONED} for a task-wide lifespan,
     * otherwise the handle stored at index {@code lifespan.getId()}.
     */
    private ConnectorPartitionHandle partitionHandleFor(Lifespan lifespan)
    {
        if (lifespan.isTaskWide()) {
            return NOT_PARTITIONED;
        }
        return partitionHandles.get(lifespan.getId());
    }

    /**
     * Runs one scheduling pass.
     *
     * <ol>
     * <li>If the stage requires materialized CTEs and any of their futures is not done, returns a
     * blocked result without scheduling anything.</li>
     * <li>On the first pass that gets further, creates a task on every node.</li>
     * <li>For grouped execution, processes queued task recoveries and lets the lifespan scheduler
     * start new lifespans on the first source scheduler.</li>
     * <li>Runs every remaining source scheduler once and removes and closes the finished ones.</li>
     * </ol>
     *
     * @return a blocked result if every source scheduler that ran reported a blocked reason (or none ran),
     *         otherwise a non-blocked result; the result is finished when no source scheduler remains
     * @throws IllegalStateException if a task recovery is requested after a source scheduler has finished
     */
    @Override
    public ScheduleResult schedule()
    {
        // tasks created by this call; stays empty unless the tasks are created below
        List<RemoteTask> newTasks = ImmutableList.of();

        // CTE Materialization Check
        if (stage.requiresMaterializedCTE()) {
            List<ListenableFuture<?>> blocked = new ArrayList<>();
            List<String> requiredCTEIds = stage.getRequiredCTEList();
            for (String cteId : requiredCTEIds) {
                ListenableFuture<Void> cteFuture = cteMaterializationTracker.getFutureForCTE(cteId);
                if (!cteFuture.isDone()) {
                    // Add CTE materialization future to the blocked list
                    blocked.add(cteFuture);
                }
            }
            // If any CTE is not materialized, return a blocked ScheduleResult
            if (!blocked.isEmpty()) {
                return ScheduleResult.blocked(
                        false,
                        newTasks,
                        whenAnyComplete(blocked),
                        BlockedReason.WAITING_FOR_CTE_MATERIALIZATION,
                        0);
            }
        }
        // schedule a task on every node in the distribution
        if (!scheduledTasks) {
            newTasks = Streams.mapWithIndex(
                    nodes.stream(),
                    (node, id) -> stage.scheduleTask(node, toIntExact(id)))
                    .filter(Optional::isPresent)
                    .map(Optional::get)
                    .collect(toImmutableList());
            scheduledTasks = true;

            // notify listeners that we have scheduled all tasks so they can set no more buffers or exchange splits
            stage.transitionToFinishedTaskScheduling();
        }
        List<ListenableFuture<?>> blocked = new ArrayList<>();
        boolean allBlocked = true;
        BlockedReason blockedReason = BlockedReason.NO_ACTIVE_DRIVER_GROUP;

        if (groupedLifespanScheduler.isPresent()) {
            // hand every queued task failure to the lifespan scheduler before scheduling anything else
            while (!tasksToRecover.isEmpty()) {
                if (anySourceSchedulingFinished) {
                    throw new IllegalStateException("Recover after any source scheduling finished is not supported");
                }
                groupedLifespanScheduler.get().onTaskFailed(tasksToRecover.poll(), sourceSchedulers);
            }

            if (groupedLifespanScheduler.get().allLifespanExecutionFinished()) {
                for (SourceScheduler sourceScheduler : sourceSchedulers) {
                    sourceScheduler.notifyAllLifespansFinishedExecution();
                }
            }
            else {
                // Start new driver groups on the first scheduler if necessary,
                // i.e. when previous ones have finished execution (not finished scheduling).
                //
                // Invoke schedule method to get a new SettableFuture every time.
                // Reusing previously returned SettableFuture could lead to the ListenableFuture retaining too many listeners.
                //
                // NOTE(review): this section runs outside the monitor, so get(0) looks like it could fail
                // if close() empties the list concurrently - confirm whether that interleaving is possible.
                blocked.add(groupedLifespanScheduler.get().schedule(sourceSchedulers.get(0)));
            }
        }

        int splitsScheduled = 0;
        // CopyOnWriteArrayList iterators work on a snapshot, so removing finished schedulers below is safe
        Iterator<SourceScheduler> schedulerIterator = sourceSchedulers.iterator();
        // lifespans completely scheduled by the previous source scheduler in this pass
        List<Lifespan> driverGroupsToStart = ImmutableList.of();
        while (schedulerIterator.hasNext()) {
            synchronized (this) {
                // if a source scheduler is closed while it is scheduling, we can get an error
                // prevent that by checking if scheduling has been cancelled first.
                if (closed) {
                    break;
                }
                SourceScheduler sourceScheduler = schedulerIterator.next();

                for (Lifespan lifespan : driverGroupsToStart) {
                    sourceScheduler.startLifespan(lifespan, partitionHandleFor(lifespan));
                }

                ScheduleResult schedule = sourceScheduler.schedule();
                if (schedule.getSplitsScheduled() > 0) {
                    stage.transitionToSchedulingSplits();
                }
                splitsScheduled += schedule.getSplitsScheduled();
                if (schedule.getBlockedReason().isPresent()) {
                    blocked.add(schedule.getBlocked());
                    blockedReason = blockedReason.combineWith(schedule.getBlockedReason().get());
                }
                else {
                    verify(schedule.getBlocked().isDone(), "blockedReason not provided when scheduler is blocked");
                    allBlocked = false;
                }

                // whatever this scheduler has completely scheduled is started on the next scheduler
                driverGroupsToStart = sourceScheduler.drainCompletelyScheduledLifespans();

                if (schedule.isFinished()) {
                    stage.schedulingComplete(sourceScheduler.getPlanNodeId());
                    sourceSchedulers.remove(sourceScheduler);
                    sourceScheduler.close();
                    anySourceSchedulingFinished = true;
                }
            }
        }

        if (allBlocked) {
            return ScheduleResult.blocked(sourceSchedulers.isEmpty(), newTasks, whenAnyComplete(blocked), blockedReason, splitsScheduled);
        }
        else {
            return ScheduleResult.nonBlocked(sourceSchedulers.isEmpty(), newTasks, splitsScheduled);
        }
    }

    /**
     * Queues a task for recovery. The queue is drained by the next {@link #schedule()} call, and only
     * when a grouped lifespan scheduler is present.
     *
     * @param taskId task whose id is queued
     */
    public void recover(TaskId taskId)
    {
        tasksToRecover.add(taskId.getId());
    }

    /**
     * Marks this scheduler closed, closes every remaining source scheduler and clears the list.
     * A failure to close one source scheduler is logged and does not prevent closing the others.
     */
    @Override
    public synchronized void close()
    {
        closed = true;
        for (SourceScheduler sourceScheduler : sourceSchedulers) {
            try {
                sourceScheduler.close();
            }
            catch (Throwable t) {
                log.warn(t, "Error closing split source");
            }
        }
        sourceSchedulers.clear();
    }

    /**
     * {@link SplitPlacementPolicy} that delegates split assignment to a {@link NodeSelector} together
     * with a {@link BucketNodeMap} and the current list of remote tasks.
     */
    public static class BucketedSplitPlacementPolicy
            implements SplitPlacementPolicy
    {
        private final NodeSelector nodeSelector;
        private final List<InternalNode> activeNodes;
        private final BucketNodeMap bucketNodeMap;
        // queried on every computeAssignments call
        private final Supplier<? extends List<RemoteTask>> remoteTasks;

        /**
         * @param nodeSelector computes the split assignments
         * @param activeNodes nodes reported by {@link #getActiveNodes()}; copied defensively
         * @param bucketNodeMap bucket to node mapping passed to the node selector
         * @param remoteTasks supplies the tasks passed to the node selector
         * @throws NullPointerException if any argument is null
         */
        public BucketedSplitPlacementPolicy(
                NodeSelector nodeSelector,
                List<InternalNode> activeNodes,
                BucketNodeMap bucketNodeMap,
                Supplier<? extends List<RemoteTask>> remoteTasks)
        {
            this.nodeSelector = requireNonNull(nodeSelector, "nodeSelector is null");
            this.activeNodes = ImmutableList.copyOf(requireNonNull(activeNodes, "activeNodes is null"));
            this.bucketNodeMap = requireNonNull(bucketNodeMap, "bucketNodeMap is null");
            this.remoteTasks = requireNonNull(remoteTasks, "remoteTasks is null");
        }

        /**
         * Delegates to the node selector, using the tasks currently returned by the supplier.
         */
        @Override
        public SplitPlacementResult computeAssignments(Set<Split> splits)
        {
            return nodeSelector.computeAssignments(splits, remoteTasks.get(), bucketNodeMap);
        }

        /**
         * Does nothing; the node list is fixed at construction time.
         */
        @Override
        public void lockDownNodes()
        {
        }

        /**
         * Returns the immutable copy of the node list taken at construction time.
         */
        @Override
        public List<InternalNode> getActiveNodes()
        {
            return activeNodes;
        }

        /**
         * Returns the node assigned to the bucket.
         *
         * <p>The assignment is unwrapped with {@code get()} without a presence check, so this presumably
         * fails when the bucket has no node assigned yet (return type of {@code getAssignedNode} is not
         * visible here; verify).
         */
        public InternalNode getNodeForBucket(int bucketId)
        {
            return bucketNodeMap.getAssignedNode(bucketId).get();
        }
    }

    /**
     * Adapter that lets a source scheduler which runs task-wide take part in a stage that is driven
     * lifespan by lifespan.
     *
     * <p>The wrapped scheduler is started exactly once with a task-wide lifespan. The lifespans requested
     * through {@link #startLifespan} are only recorded, and are reported as completely scheduled once the
     * wrapped scheduler has reported its task-wide lifespan as completely scheduled.
     */
    private static class AsGroupedSourceScheduler
            implements SourceScheduler
    {
        private final SourceScheduler sourceScheduler;
        // whether the wrapped scheduler has been started
        private boolean started;
        // whether the wrapped scheduler reported the task-wide lifespan as completely scheduled
        private boolean scheduleCompleted;
        // lifespans started on this adapter that have not been reported by drainCompletelyScheduledLifespans yet
        private final List<Lifespan> pendingCompleted;

        public AsGroupedSourceScheduler(SourceScheduler sourceScheduler)
        {
            this.sourceScheduler = requireNonNull(sourceScheduler, "sourceScheduler is null");
            pendingCompleted = new ArrayList<>();
        }

        @Override
        public ScheduleResult schedule()
        {
            return sourceScheduler.schedule();
        }

        @Override
        public void close()
        {
            sourceScheduler.close();
        }

        @Override
        public PlanNodeId getPlanNodeId()
        {
            return sourceScheduler.getPlanNodeId();
        }

        /**
         * Records the lifespan and, on the first call only, starts the wrapped scheduler with a task-wide
         * lifespan. The {@code partitionHandle} argument is not used.
         */
        @Override
        public void startLifespan(Lifespan lifespan, ConnectorPartitionHandle partitionHandle)
        {
            pendingCompleted.add(lifespan);
            if (started) {
                return;
            }
            started = true;
            sourceScheduler.startLifespan(Lifespan.taskWide(), NOT_PARTITIONED);
        }

        /**
         * Always throws; rewinding is not supported by this adapter.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public void rewindLifespan(Lifespan lifespan, ConnectorPartitionHandle partitionHandle)
        {
            throw new UnsupportedOperationException("rewindLifespan is not supported in AsGroupedSourceScheduler");
        }

        /**
         * Returns nothing until the wrapped scheduler reports its task-wide lifespan as completely
         * scheduled. From then on returns, and forgets, every lifespan recorded since the previous call.
         *
         * @throws IllegalStateException if the wrapped scheduler reports anything other than exactly the
         *         task-wide lifespan
         */
        @Override
        public List<Lifespan> drainCompletelyScheduledLifespans()
        {
            if (!scheduleCompleted) {
                List<Lifespan> lifespans = sourceScheduler.drainCompletelyScheduledLifespans();
                if (lifespans.isEmpty()) {
                    return ImmutableList.of();
                }
                checkState(ImmutableList.of(Lifespan.taskWide()).equals(lifespans));
                scheduleCompleted = true;
            }
            List<Lifespan> result = ImmutableList.copyOf(pendingCompleted);
            pendingCompleted.clear();
            return result;
        }

        /**
         * Only verifies that the wrapped scheduler has completed scheduling; nothing is forwarded to it.
         *
         * @throws IllegalStateException if scheduling has not completed yet
         */
        @Override
        public void notifyAllLifespansFinishedExecution()
        {
            checkState(scheduleCompleted);
            // no-op
        }
    }
}