SqlStageExecution.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.execution;
import com.facebook.presto.Session;
import com.facebook.presto.common.ErrorCode;
import com.facebook.presto.execution.StateMachine.StateChangeListener;
import com.facebook.presto.execution.buffer.OutputBuffers;
import com.facebook.presto.execution.scheduler.ScheduleResult;
import com.facebook.presto.execution.scheduler.SplitSchedulerStats;
import com.facebook.presto.execution.scheduler.TableWriteInfo;
import com.facebook.presto.failureDetector.FailureDetector;
import com.facebook.presto.metadata.InternalNode;
import com.facebook.presto.metadata.RemoteTransactionHandle;
import com.facebook.presto.metadata.Split;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.plan.CteMaterializationInfo;
import com.facebook.presto.spi.plan.PlanFragmentId;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.plan.TableFinishNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.split.RemoteSplit;
import com.facebook.presto.sql.planner.PlanFragment;
import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher;
import com.facebook.presto.sql.planner.plan.RemoteSourceNode;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import io.airlift.units.Duration;
import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.ThreadSafe;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import static com.facebook.presto.SystemSessionProperties.getMaxFailedTaskPercentage;
import static com.facebook.presto.SystemSessionProperties.isEnhancedCTESchedulingEnabled;
import static com.facebook.presto.failureDetector.FailureDetector.State.GONE;
import static com.facebook.presto.operator.ExchangeOperator.REMOTE_CONNECTOR_ID;
import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static com.facebook.presto.spi.StandardErrorCode.GENERIC_RECOVERY_ERROR;
import static com.facebook.presto.spi.StandardErrorCode.PAGE_TRANSPORT_ERROR;
import static com.facebook.presto.spi.StandardErrorCode.PAGE_TRANSPORT_TIMEOUT;
import static com.facebook.presto.spi.StandardErrorCode.REMOTE_HOST_GONE;
import static com.facebook.presto.spi.StandardErrorCode.REMOTE_TASK_ERROR;
import static com.facebook.presto.spi.StandardErrorCode.REMOTE_TASK_MISMATCH;
import static com.facebook.presto.spi.StandardErrorCode.TOO_MANY_REQUESTS_FAILED;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.common.collect.Sets.newConcurrentHashSet;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.NANOSECONDS;
@ThreadSafe
public final class SqlStageExecution
{
// Error codes accepted by isRecoverable(): a failed task is only a recovery candidate
// when every one of its failures carries one of these codes.
public static final Set<ErrorCode> RECOVERABLE_ERROR_CODES = ImmutableSet.of(
TOO_MANY_REQUESTS_FAILED.toErrorCode(),
PAGE_TRANSPORT_ERROR.toErrorCode(),
PAGE_TRANSPORT_TIMEOUT.toErrorCode(),
REMOTE_TASK_MISMATCH.toErrorCode(),
REMOTE_TASK_ERROR.toErrorCode());
// Attempt number used for every TaskId constructed in this class.
public static final int DEFAULT_TASK_ATTEMPT_NUMBER = 0;
private final Session session;
private final StageExecutionStateMachine stateMachine;
private final PlanFragment planFragment;
private final RemoteTaskFactory remoteTaskFactory;
private final NodeTaskMap nodeTaskMap;
private final boolean summarizeTaskInfo;
// Executor on which completed-lifespan listeners are invoked (see StageTaskListener).
private final Executor executor;
// Consulted by rewriteTransportFailure() to detect failures caused by a host that is GONE.
private final FailureDetector failureDetector;
// Multiplied with allTasks.size() in isRecoverable() to bound the number of failed tasks.
private final double maxFailedTaskPercentage;
// Source fragment id -> the RemoteSourceNode of this fragment that reads from it (built in the constructor).
private final Map<PlanFragmentId, RemoteSourceNode> exchangeSources;
private final TableWriteInfo tableWriteInfo;
// Tasks created by this stage, grouped by the node they were created for.
private final Map<InternalNode, Set<RemoteTask>> tasks = new ConcurrentHashMap<>();
// NOTE(review): several of the @GuardedBy("this") fields below are concurrent/atomic types and are also
// touched from unsynchronized methods (e.g. updateTaskStatus, updateFinalTaskInfo) — confirm the annotation intent.
// Source of task ids for tasks created by scheduleSplits().
@GuardedBy("this")
private final AtomicInteger nextTaskId = new AtomicInteger();
// Ids of every task created by scheduleTask().
@GuardedBy("this")
private final Set<TaskId> allTasks = newConcurrentHashSet();
// Ids of tasks observed FINISHED, plus failed tasks whose recovery callback completed (see updateTaskStatus).
@GuardedBy("this")
private final Set<TaskId> finishedTasks = newConcurrentHashSet();
// Ids of tasks observed FAILED, whether or not they were recovered.
@GuardedBy("this")
private final Set<TaskId> failedTasks = newConcurrentHashSet();
// Ids of tasks whose final TaskInfo has not been received yet (removed in updateFinalTaskInfo).
@GuardedBy("this")
private final Set<TaskId> runningTasks = newConcurrentHashSet();
// Lifespans reported complete through completedLifespansChangeListeners (registered in initialize()).
private final Set<Lifespan> finishedLifespans = ConcurrentHashMap.newKeySet();
private final int totalLifespans;
// Set once scheduleSplits() has run; scheduleTask(node, partition) is rejected afterwards.
@GuardedBy("this")
private final AtomicBoolean splitsScheduled = new AtomicBoolean();
// Remote source plan node id -> upstream tasks feeding it. HashMultimap is not thread-safe;
// it is only accessed from synchronized methods (addExchangeLocations, scheduleTask).
@GuardedBy("this")
private final Multimap<PlanNodeId, RemoteTask> sourceTasks = HashMultimap.create();
// Plan nodes for which noMoreSplits has been announced; replayed to tasks created later.
@GuardedBy("this")
private final Set<PlanNodeId> completeSources = newConcurrentHashSet();
// Source fragments for which no more exchange locations will be added.
@GuardedBy("this")
private final Set<PlanFragmentId> completeSourceFragments = newConcurrentHashSet();
// Latest output buffers; null until setOutputBuffers() is called for the first time.
private final AtomicReference<OutputBuffers> outputBuffers = new AtomicReference<>();
private final ListenerManager<Set<Lifespan>> completedLifespansChangeListeners = new ListenerManager<>();
// Callback used to recover a failed task; may be registered at most once.
@GuardedBy("this")
private Optional<StageTaskRecoveryCallback> stageTaskRecoveryCallback = Optional.empty();
/**
 * Creates a stage execution for {@code fragment}, builds its state machine, and registers the
 * internal listeners after construction has completed.
 *
 * @throws NullPointerException if any reference argument is null
 */
public static SqlStageExecution createSqlStageExecution(
        StageExecutionId stageExecutionId,
        PlanFragment fragment,
        RemoteTaskFactory remoteTaskFactory,
        Session session,
        boolean summarizeTaskInfo,
        NodeTaskMap nodeTaskMap,
        ExecutorService executor,
        FailureDetector failureDetector,
        SplitSchedulerStats schedulerStats,
        TableWriteInfo tableWriteInfo)
{
    requireNonNull(stageExecutionId, "stageId is null");
    requireNonNull(fragment, "fragment is null");
    requireNonNull(remoteTaskFactory, "remoteTaskFactory is null");
    requireNonNull(session, "session is null");
    requireNonNull(nodeTaskMap, "nodeTaskMap is null");
    requireNonNull(executor, "executor is null");
    requireNonNull(failureDetector, "failureDetector is null");
    requireNonNull(schedulerStats, "schedulerStats is null");
    requireNonNull(tableWriteInfo, "tableWriteInfo is null");

    // The state machine is told whether the fragment has any table scan scheduling entries
    boolean hasTableScanSources = !fragment.getTableScanSchedulingOrder().isEmpty();
    StageExecutionStateMachine stageStateMachine = new StageExecutionStateMachine(stageExecutionId, executor, schedulerStats, hasTableScanSources);
    double failedTaskPercentageLimit = getMaxFailedTaskPercentage(session);

    SqlStageExecution stage = new SqlStageExecution(
            session,
            stageStateMachine,
            fragment,
            remoteTaskFactory,
            nodeTaskMap,
            summarizeTaskInfo,
            executor,
            failureDetector,
            failedTaskPercentageLimit,
            tableWriteInfo);
    // Listener registration happens here, not in the constructor, so `this` is not leaked
    stage.initialize();
    return stage;
}
/**
 * Stores the collaborators and derives the fragment-id to remote-source lookup from the plan fragment.
 * Listeners are registered separately in {@code initialize()}.
 */
private SqlStageExecution(
        Session session,
        StageExecutionStateMachine stateMachine,
        PlanFragment planFragment,
        RemoteTaskFactory remoteTaskFactory,
        NodeTaskMap nodeTaskMap,
        boolean summarizeTaskInfo,
        Executor executor,
        FailureDetector failureDetector,
        double maxFailedTaskPercentage,
        TableWriteInfo tableWriteInfo)
{
    this.session = requireNonNull(session, "session is null");
    this.stateMachine = stateMachine;
    this.planFragment = requireNonNull(planFragment, "planFragment is null");
    this.remoteTaskFactory = requireNonNull(remoteTaskFactory, "remoteTaskFactory is null");
    this.nodeTaskMap = requireNonNull(nodeTaskMap, "nodeTaskMap is null");
    this.summarizeTaskInfo = summarizeTaskInfo;
    this.executor = requireNonNull(executor, "executor is null");
    this.failureDetector = requireNonNull(failureDetector, "failureDetector is null");
    this.tableWriteInfo = requireNonNull(tableWriteInfo);
    this.maxFailedTaskPercentage = maxFailedTaskPercentage;

    // Every source fragment id maps to the remote source node that consumes it
    ImmutableMap.Builder<PlanFragmentId, RemoteSourceNode> sourcesByFragment = ImmutableMap.builder();
    planFragment.getRemoteSourceNodes().forEach(remoteSource ->
            remoteSource.getSourceFragmentIds().forEach(sourceFragmentId -> sourcesByFragment.put(sourceFragmentId, remoteSource)));
    this.exchangeSources = sourcesByFragment.build();

    this.totalLifespans = planFragment.getStageExecutionDescriptor().getTotalLifespans();
}
// Listener registration lives outside the constructor so that the `this` reference
// is never handed out while the object is still being constructed.
private void initialize()
{
    stateMachine.addStateChangeListener(state -> {
        if (!state.isDone()) {
            return;
        }
        checkAllTaskFinal();
    });
    // Accumulate every batch of newly completed lifespans reported by the task listeners
    completedLifespansChangeListeners.addListener(finishedLifespans::addAll);
}
/**
 * Returns the stage execution id held by the state machine.
 */
public StageExecutionId getStageExecutionId()
{
return stateMachine.getStageExecutionId();
}
/**
 * Returns the current state reported by the stage state machine.
 */
public StageExecutionState getState()
{
return stateMachine.getState();
}
/**
 * Registers a listener for stage state changes on the underlying state machine.
 * The listener is always notified asynchronously using a dedicated notification thread pool, so care
 * should be taken to avoid leaking {@code this} when adding a listener in a constructor.
 */
public void addStateChangeListener(StateChangeListener<StageExecutionState> stateChangeListener)
{
stateMachine.addStateChangeListener(stateChangeListener);
}
/**
 * Adds a listener for the final stage info. This notification is guaranteed to be fired only once.
 * The listener is always notified asynchronously using a dedicated notification thread pool, so care
 * should be taken to avoid leaking {@code this} when adding a listener in a constructor. Additionally,
 * notifications may be observed out of order due to the asynchronous execution.
 */
public void addFinalStageInfoListener(StateChangeListener<StageExecutionInfo> stateChangeListener)
{
stateMachine.addFinalStageInfoListener(stateChangeListener);
}
/**
 * Registers a consumer that receives each set of newly completed lifespans reported by the tasks.
 * Must be called before the first notification is dispatched: the listener manager rejects
 * registrations (IllegalStateException) once listeners have been invoked.
 */
public void addCompletedDriverGroupsChangedListener(Consumer<Set<Lifespan>> newlyCompletedDriverGroupConsumer)
{
completedLifespansChangeListeners.addListener(newlyCompletedDriverGroupConsumer);
}
/**
 * Installs the callback used to recover failed tasks of this stage.
 *
 * @throws IllegalStateException if a callback has already been registered
 * @throws NullPointerException if {@code stageTaskRecoveryCallback} is null
 */
public synchronized void registerStageTaskRecoveryCallback(StageTaskRecoveryCallback stageTaskRecoveryCallback)
{
    boolean alreadyRegistered = this.stageTaskRecoveryCallback.isPresent();
    checkState(!alreadyRegistered, "stageTaskRecoveryCallback should be registered only once");
    StageTaskRecoveryCallback callback = requireNonNull(stageTaskRecoveryCallback, "stageTaskRecoveryCallback is null");
    this.stageTaskRecoveryCallback = Optional.of(callback);
}
/**
 * Returns the plan fragment this stage executes.
 */
public PlanFragment getFragment()
{
return planFragment;
}
/**
 * Returns the most recently accepted output buffers, or {@code null} if
 * {@code setOutputBuffers} has not been called yet.
 */
public OutputBuffers getOutputBuffers()
{
return outputBuffers.get();
}
/**
 * Asks the state machine to transition to the scheduling state.
 */
public void beginScheduling()
{
stateMachine.transitionToScheduling();
}
/**
 * Asks the state machine to transition to the finished-task-scheduling state while holding this stage's monitor.
 */
public synchronized void transitionToFinishedTaskScheduling()
{
stateMachine.transitionToFinishedTaskScheduling();
}
/**
 * Asks the state machine to transition to the scheduling-splits state while holding this stage's monitor.
 */
public synchronized void transitionToSchedulingSplits()
{
stateMachine.transitionToSchedulingSplits();
}
/**
 * Marks scheduling of the whole stage as complete. Does nothing when the state machine refuses the
 * transition to scheduled. Otherwise the stage is finished immediately if every known task has already
 * finished, and every table scan source is declared to have no more splits.
 */
public synchronized void schedulingComplete()
{
    boolean movedToScheduled = stateMachine.transitionToScheduled();
    if (movedToScheduled) {
        if (finishedTasks.size() == allTasks.size()) {
            stateMachine.transitionToFinished();
        }
        planFragment.getTableScanSchedulingOrder()
                .forEach(tableScanSource -> schedulingComplete(tableScanSource));
    }
}
/**
 * Tells every existing task that {@code partitionedSource} will receive no more splits and records the
 * source as complete so that tasks created later are told the same.
 */
public synchronized void schedulingComplete(PlanNodeId partitionedSource)
{
    getAllTasks().forEach(remoteTask -> remoteTask.noMoreSplits(partitionedSource));
    completeSources.add(partitionedSource);
}
/**
 * Transitions the stage to canceled and then cancels every task created so far.
 */
public synchronized void cancel()
{
    stateMachine.transitionToCanceled();
    for (RemoteTask remoteTask : getAllTasks()) {
        remoteTask.cancel();
    }
}
/**
 * Transitions the stage to aborted and then aborts every task created so far.
 */
public synchronized void abort()
{
    stateMachine.transitionToAborted();
    for (RemoteTask remoteTask : getAllTasks()) {
        remoteTask.abort();
    }
}
/**
 * Returns the user memory reservation tracked by the state machine.
 */
public long getUserMemoryReservation()
{
return stateMachine.getUserMemoryReservation();
}
/**
 * Returns the total memory reservation tracked by the state machine.
 */
public long getTotalMemoryReservation()
{
return stateMachine.getTotalMemoryReservation();
}
/**
 * Sums the CPU time of all tasks. Each task's nanosecond value is converted to whole milliseconds
 * before it is added, so sub-millisecond remainders are dropped per task.
 */
public Duration getTotalCpuTime()
{
    long totalMillis = 0;
    for (RemoteTask remoteTask : getAllTasks()) {
        long cpuNanos = remoteTask.getTaskInfo().getStats().getTotalCpuTimeInNanos();
        totalMillis += NANOSECONDS.toMillis(cpuNanos);
    }
    return new Duration(totalMillis, TimeUnit.MILLISECONDS);
}
/**
 * Returns the raw input bytes summed over all tasks, or 0 when the fragment has no table scan
 * scheduling entries.
 */
public synchronized long getRawInputDataSize()
{
    boolean noTableScans = planFragment.getTableScanSchedulingOrder().isEmpty();
    if (noTableScans) {
        return 0L;
    }
    long rawInputBytes = 0;
    for (RemoteTask remoteTask : getAllTasks()) {
        rawInputBytes += remoteTask.getTaskInfo().getStats().getRawInputDataSizeInBytes();
    }
    return rawInputBytes;
}
/**
 * Returns the physical written bytes summed over all tasks whose plan fragment is not an
 * output table writer fragment.
 */
public synchronized long getWrittenIntermediateDataSize()
{
    long writtenBytes = 0;
    for (RemoteTask remoteTask : getAllTasks()) {
        if (remoteTask.getPlanFragment().isOutputTableWriterFragment()) {
            continue;
        }
        writtenBytes += remoteTask.getTaskInfo().getStats().getPhysicalWrittenDataSizeInBytes();
    }
    return writtenBytes;
}
/**
 * Returns basic stage statistics computed by the state machine from the current task infos.
 */
public BasicStageExecutionStats getBasicStageStats()
{
return stateMachine.getBasicStageStats(this::getAllTaskInfo);
}
/**
 * Returns the stage execution info built by the state machine from the current task infos, the number
 * of finished lifespans recorded so far, and the total lifespan count of the fragment.
 */
public StageExecutionInfo getStageExecutionInfo()
{
return stateMachine.getStageExecutionInfo(this::getAllTaskInfo, finishedLifespans.size(), totalLifespans);
}
// Snapshot of the TaskInfo of every task currently known to this stage.
private Iterable<TaskInfo> getAllTaskInfo()
{
    List<RemoteTask> currentTasks = getAllTasks();
    return currentTasks.stream()
            .map(remoteTask -> remoteTask.getTaskInfo())
            .collect(toImmutableList());
}
/**
 * Registers upstream tasks of {@code fragmentId} as exchange sources and hands a remote split for each
 * of them to every task of this stage.
 *
 * @param fragmentId the source fragment the tasks belong to; must be a known source of this fragment
 * @param sourceTasks upstream tasks to read from
 * @param noMoreExchangeLocations whether this is the last batch of tasks for {@code fragmentId}
 * @throws NullPointerException if {@code fragmentId} or {@code sourceTasks} is null
 * @throws IllegalArgumentException if {@code fragmentId} is not a source fragment of this stage
 */
public synchronized void addExchangeLocations(PlanFragmentId fragmentId, Set<RemoteTask> sourceTasks, boolean noMoreExchangeLocations)
{
    requireNonNull(fragmentId, "fragmentId is null");
    requireNonNull(sourceTasks, "sourceTasks is null");
    RemoteSourceNode remoteSource = exchangeSources.get(fragmentId);
    checkArgument(remoteSource != null, "Unknown remote source %s. Known sources are %s", fragmentId, exchangeSources.keySet());
    PlanNodeId remoteSourceId = remoteSource.getId();

    // Remember the source tasks so that tasks created later also receive these splits
    this.sourceTasks.putAll(remoteSourceId, sourceTasks);

    for (RemoteTask task : getAllTasks()) {
        // The split location embeds the consuming task's id, so splits are built per task
        ImmutableMultimap.Builder<PlanNodeId, Split> newSplits = ImmutableMultimap.builder();
        for (RemoteTask sourceTask : sourceTasks) {
            newSplits.put(remoteSourceId, createRemoteSplitFor(task.getTaskId(), sourceTask.getRemoteTaskLocation(), sourceTask.getTaskId()));
        }
        task.addSplits(newSplits.build());
    }

    if (noMoreExchangeLocations) {
        completeSourceFragments.add(fragmentId);
        // The remote source is complete only once all of its source fragments are complete
        if (completeSourceFragments.containsAll(remoteSource.getSourceFragmentIds())) {
            completeSources.add(remoteSourceId);
            for (RemoteTask task : getAllTasks()) {
                task.noMoreSplits(remoteSourceId);
            }
        }
    }
}
/**
 * Installs a newer version of the output buffers and forwards it to all existing tasks.
 * A call whose version is not greater than the currently installed version is ignored.
 *
 * @throws NullPointerException if {@code outputBuffers} is null
 */
public synchronized void setOutputBuffers(OutputBuffers outputBuffers)
{
requireNonNull(outputBuffers, "outputBuffers is null");
// Retry until either the update is found to be stale or the compare-and-set succeeds
while (true) {
OutputBuffers currentOutputBuffers = this.outputBuffers.get();
if (currentOutputBuffers != null) {
if (outputBuffers.getVersion() <= currentOutputBuffers.getVersion()) {
// stale or duplicate update
return;
}
// validate the transition from the current buffers to the new ones
currentOutputBuffers.checkValidTransition(outputBuffers);
}
if (this.outputBuffers.compareAndSet(currentOutputBuffers, outputBuffers)) {
for (RemoteTask task : getAllTasks()) {
task.setOutputBuffers(outputBuffers);
}
return;
}
}
}
// do not synchronize
// this is used for query info building which should be independent of scheduling work
/**
 * Returns whether at least one node has an entry in the task map.
 */
public boolean hasTasks()
{
return !tasks.isEmpty();
}
// Intentionally unsynchronized: query info building calls this and must not
// wait on scheduling work that holds this stage's monitor.
/**
 * Returns an immutable snapshot of all tasks across all nodes.
 */
public List<RemoteTask> getAllTasks()
{
    Collection<Set<RemoteTask>> tasksPerNode = tasks.values();
    return tasksPerNode.stream()
            .flatMap(nodeTasks -> nodeTasks.stream())
            .collect(toImmutableList());
}
// We only support removeRemoteSource for single task stage because stages with many tasks introduce coordinator to worker HTTP requests in bursty manner.
// See https://github.com/prestodb/presto/pull/11065 for a similar issue.
/**
 * Removes {@code remoteSourceTaskId} from the stage's task when the stage has exactly one task.
 * Does nothing when the stage has no tasks yet or more than one task.
 */
public void removeRemoteSourceIfSingleTaskStage(TaskId remoteSourceTaskId)
{
    List<RemoteTask> stageTasks = getAllTasks();
    // A stage without tasks has nothing to remove the source from; previously this case
    // fell through to getOnlyElement, which throws on an empty collection.
    if (stageTasks.size() != 1) {
        return;
    }
    stageTasks.get(0).removeRemoteSource(remoteSourceTaskId);
}
/**
 * Creates a task with the given partition as its id on {@code node}, with no initial source splits.
 *
 * @return the new task, or empty if the stage is already done
 * @throws NullPointerException if {@code node} is null
 * @throws IllegalStateException if splits have already been scheduled for this stage
 */
public synchronized Optional<RemoteTask> scheduleTask(InternalNode node, int partition)
{
    requireNonNull(node, "node is null");
    if (stateMachine.getState().isDone()) {
        return Optional.empty();
    }
    checkState(!splitsScheduled.get(), "scheduleTask can not be called once splits have been scheduled");

    TaskId newTaskId = new TaskId(stateMachine.getStageExecutionId(), partition, DEFAULT_TASK_ATTEMPT_NUMBER);
    RemoteTask newTask = scheduleTask(node, newTaskId, ImmutableMultimap.of());
    return Optional.of(newTask);
}
/**
 * Assigns {@code splits} to a task on {@code node}, creating the task if the node has none yet, and
 * forwards the per-lifespan no-more-splits notification to that task.
 *
 * @param node the node that should process the splits
 * @param splits splits keyed by plan node id; every key must be a table scan of this fragment
 * @param noMoreSplitsNotification at most one (plan node, lifespan) entry announcing the end of splits
 * @return the newly created task, or an empty set if an existing task was reused or the stage is done
 * @throws NullPointerException if any argument is null
 * @throws IllegalArgumentException if {@code splits} references a plan node that is not a table scan of this fragment
 * @throws UnsupportedOperationException if {@code noMoreSplitsNotification} has more than one entry
 */
public synchronized Set<RemoteTask> scheduleSplits(InternalNode node, Multimap<PlanNodeId, Split> splits, Multimap<PlanNodeId, Lifespan> noMoreSplitsNotification)
{
    requireNonNull(node, "node is null");
    requireNonNull(splits, "splits is null");
    // Fail fast rather than with a bare NPE after a task has already been created
    requireNonNull(noMoreSplitsNotification, "noMoreSplitsNotification is null");
    if (stateMachine.getState().isDone()) {
        return ImmutableSet.of();
    }
    splitsScheduled.set(true);
    checkArgument(planFragment.getTableScanSchedulingOrder().containsAll(splits.keySet()), "Invalid splits");

    ImmutableSet.Builder<RemoteTask> newTasks = ImmutableSet.builder();
    Collection<RemoteTask> nodeTasks = this.tasks.get(node);
    RemoteTask task;
    if (nodeTasks == null) {
        // The output buffer depends on the task id starting from 0 and being sequential, since each
        // task is assigned a private buffer based on task id.
        TaskId taskId = new TaskId(stateMachine.getStageExecutionId(), nextTaskId.getAndIncrement(), DEFAULT_TASK_ATTEMPT_NUMBER);
        task = scheduleTask(node, taskId, splits);
        newTasks.add(task);
    }
    else {
        task = nodeTasks.iterator().next();
        task.addSplits(splits);
    }

    if (noMoreSplitsNotification.size() > 1) {
        // The assumption that `noMoreSplitsNotification.size() <= 1` currently holds.
        // If this assumption no longer holds, we should consider calling task.noMoreSplits with multiple entries in one shot.
        // These kind of methods can be expensive since they are grabbing locks and/or sending HTTP requests on change.
        throw new UnsupportedOperationException("This assumption no longer holds: noMoreSplitsNotification.size() <= 1");
    }
    for (Entry<PlanNodeId, Lifespan> entry : noMoreSplitsNotification.entries()) {
        task.noMoreSplits(entry.getKey(), entry.getValue());
    }
    return newTasks.build();
}
/**
 * Creates, registers and starts a remote task with id {@code taskId} on {@code node}.
 * The task's initial splits are {@code sourceSplits} plus one remote split for every known source task
 * that is not FINISHED. If the stage is already done by the time the task is registered, the task is
 * aborted instead of started.
 *
 * @throws IllegalArgumentException if a task with {@code taskId} already exists
 * @throws IllegalStateException if no output buffers have been set yet
 */
private synchronized RemoteTask scheduleTask(InternalNode node, TaskId taskId, Multimap<PlanNodeId, Split> sourceSplits)
{
checkArgument(!allTasks.contains(taskId), "A task with id %s already exists", taskId);
ImmutableMultimap.Builder<PlanNodeId, Split> initialSplits = ImmutableMultimap.builder();
initialSplits.putAll(sourceSplits);
// add a remote split for every registered source task that has not finished
sourceTasks.forEach((planNodeId, task) -> {
TaskStatus status = task.getTaskStatus();
if (status.getState() != TaskState.FINISHED) {
initialSplits.put(planNodeId, createRemoteSplitFor(taskId, task.getRemoteTaskLocation(), task.getTaskId()));
}
});
OutputBuffers outputBuffers = this.outputBuffers.get();
checkState(outputBuffers != null, "Initial output buffers must be set before a task can be scheduled");
RemoteTask task = remoteTaskFactory.createRemoteTask(
session,
taskId,
node,
planFragment,
initialSplits.build(),
outputBuffers,
nodeTaskMap.createTaskStatsTracker(node, taskId),
summarizeTaskInfo,
tableWriteInfo,
stateMachine);
// replay noMoreSplits for sources that were completed before this task existed
completeSources.forEach(task::noMoreSplits);
// register the task in the bookkeeping collections before listeners are attached
allTasks.add(taskId);
runningTasks.add(taskId);
tasks.computeIfAbsent(node, key -> newConcurrentHashSet()).add(task);
nodeTaskMap.addTask(node, task);
task.addStateChangeListener(new StageTaskListener(taskId));
task.addFinalTaskInfoListener(this::updateFinalTaskInfo);
if (!stateMachine.getState().isDone()) {
task.start();
}
else {
// stage finished while we were scheduling this task
task.abort();
}
return task;
}
/**
 * Returns an immutable copy of the nodes that currently have an entry in the task map.
 */
public Set<InternalNode> getScheduledNodes()
{
return ImmutableSet.copyOf(tasks.keySet());
}
/**
 * Forwards the split retrieval start value to the state machine for recording.
 */
public void recordGetSplitTime(long start)
{
stateMachine.recordGetSplitTime(start);
}
/**
 * Records scheduler running time on the state machine. For leaf fragments the time is additionally
 * recorded through the leaf-stage-specific method.
 */
public void recordSchedulerRunningTime(long cpuTimeNanos, long wallTimeNanos)
{
if (planFragment.isLeaf()) {
stateMachine.recordLeafStageSchedulerRunningTime(cpuTimeNanos, wallTimeNanos);
}
stateMachine.recordSchedulerRunningTime(cpuTimeNanos, wallTimeNanos);
}
/**
 * Records scheduler blocked time for {@code reason} on the state machine. For leaf fragments the time
 * is additionally recorded through the leaf-stage-specific method.
 */
public void recordSchedulerBlockedTime(ScheduleResult.BlockedReason reason, long nanos)
{
if (planFragment.isLeaf()) {
stateMachine.recordLeafStageSchedulerBlockedTime(reason, nanos);
}
stateMachine.recordSchedulerBlockedTime(reason, nanos);
}
// Builds the remote split through which task `taskId` reads the output of `remoteSourceTaskId`.
private static Split createRemoteSplitFor(TaskId taskId, URI remoteSourceTaskLocation, TaskId remoteSourceTaskId)
{
    // The consumer fetches from the source task's result buffer that is addressed by the consumer's own task id
    Location bufferLocation = new Location(remoteSourceTaskLocation.toASCIIString() + "/results/" + taskId.getId());
    RemoteSplit remoteSplit = new RemoteSplit(bufferLocation, remoteSourceTaskId);
    return new Split(REMOTE_CONNECTOR_ID, new RemoteTransactionHandle(), remoteSplit);
}
/**
 * Returns the CTE id carried by the first {@code TableFinishNode} under {@code source} that has CTE
 * materialization info. Uses the same predicate as {@code isCTETableFinishStage()} so that a
 * {@code TableFinishNode} without CTE info cannot shadow one that has it.
 *
 * @throws IllegalStateException if no such node exists
 */
private static String getCteIdFromSource(PlanNode source)
{
    return PlanNodeSearcher.searchFrom(source)
            .where(planNode -> planNode instanceof TableFinishNode &&
                    ((TableFinishNode) planNode).getCteMaterializationInfo().isPresent())
            .findFirst()
            .flatMap(planNode -> ((TableFinishNode) planNode).getCteMaterializationInfo())
            .map(CteMaterializationInfo::getCteId)
            .orElseThrow(() -> new IllegalStateException("TemporaryTableInfo has no CTE ID"));
}
/**
 * Returns whether the fragment contains a {@code TableFinishNode} that carries CTE materialization info.
 */
public boolean isCTETableFinishStage()
{
    return PlanNodeSearcher.searchFrom(planFragment.getRoot())
            .where(planNode -> {
                if (!(planNode instanceof TableFinishNode)) {
                    return false;
                }
                return ((TableFinishNode) planNode).getCteMaterializationInfo().isPresent();
            })
            .findSingle()
            .isPresent();
}
/**
 * Returns the id of the CTE written by this stage.
 *
 * @throws IllegalStateException if this stage is not a CTE table-finish stage
 */
public String getCTEWriterId()
{
    // Only a CTE table-finish stage has a CTE id to report
    checkState(isCTETableFinishStage(), "This stage is not a CTE writer stage");
    return getCteIdFromSource(planFragment.getRoot());
}
/**
 * Returns whether enhanced CTE scheduling is enabled for the session and at least one
 * {@code TableScanNode} of the fragment carries CTE materialization info.
 */
public boolean requiresMaterializedCTE()
{
    if (!isEnhancedCTESchedulingEnabled(session)) {
        return false;
    }
    List<PlanNode> tableScans = PlanNodeSearcher.searchFrom(planFragment.getRoot())
            .where(planNode -> planNode instanceof TableScanNode)
            .findAll();
    for (PlanNode tableScan : tableScans) {
        if (((TableScanNode) tableScan).getCteMaterializationInfo().isPresent()) {
            return true;
        }
    }
    return false;
}
/**
 * Returns the CTE ids referenced by the fragment's {@code TableScanNode}s that carry CTE
 * materialization info, in search order. Table scans without CTE info are skipped, so a fragment
 * that mixes regular table scans with CTE scans no longer fails with an IllegalStateException.
 *
 * @return the referenced CTE ids; empty if no table scan reads a materialized CTE
 */
public List<String> getRequiredCTEList()
{
    return PlanNodeSearcher.searchFrom(planFragment.getRoot())
            .where(planNode -> planNode instanceof TableScanNode)
            .findAll().stream()
            .map(planNode -> ((TableScanNode) planNode).getCteMaterializationInfo())
            // keep only the scans that actually read a materialized CTE
            .filter(Optional::isPresent)
            .map(Optional::get)
            .map(CteMaterializationInfo::getCteId)
            .collect(Collectors.toList());
}
/**
 * Reacts to a task status update: records failed/finished tasks, attempts recovery of recoverable
 * failures, fails the stage on unrecoverable failures or unexpected aborts, and advances the stage to
 * RUNNING/FINISHED when appropriate. Updates are ignored once the stage is done.
 */
// NOTE(review): this method is not synchronized although it reads stageTaskRecoveryCallback and
// mutates failedTasks/finishedTasks, which are annotated @GuardedBy("this") — confirm this is intended.
private void updateTaskStatus(TaskId taskId, TaskStatus taskStatus)
{
StageExecutionState stageExecutionState = getState();
if (stageExecutionState.isDone()) {
return;
}
TaskState taskState = taskStatus.getState();
if (taskState == TaskState.FAILED) {
// no matter if it is possible to recover - the task is failed
failedTasks.add(taskId);
// use the first reported failure (rewritten if the remote host is gone) as the stage failure cause
RuntimeException failure = taskStatus.getFailures().stream()
.findFirst()
.map(this::rewriteTransportFailure)
.map(ExecutionFailureInfo::toException)
.orElseGet(() -> new PrestoException(GENERIC_INTERNAL_ERROR, "A task failed for an unknown reason"));
if (isRecoverable(taskStatus.getFailures())) {
try {
// isRecoverable() only returns true when the callback is present
stageTaskRecoveryCallback.get().recover(taskId);
// a successfully recovered task counts as finished for stage completion
finishedTasks.add(taskId);
}
catch (Throwable t) {
// In an ideal world, this exception is not supposed to happen.
// However, it could happen, for example, if connector throws exception.
// We need to handle the exception in order to fail the query properly, otherwise the failed task will hang in RUNNING/SCHEDULING state.
failure.addSuppressed(new PrestoException(GENERIC_RECOVERY_ERROR, format("Encountered error when trying to recover task %s", taskId), t));
stateMachine.transitionToFailed(failure);
}
}
else {
stateMachine.transitionToFailed(failure);
}
}
else if (taskState == TaskState.ABORTED) {
// A task should only be in the aborted state if the STAGE is done (ABORTED or FAILED)
stateMachine.transitionToFailed(new PrestoException(GENERIC_INTERNAL_ERROR, "A task is in the ABORTED state but stage is " + stageExecutionState));
}
else if (taskState == TaskState.FINISHED) {
finishedTasks.add(taskId);
}
// The finishedTasks.add(taskId) calls above must happen before this getState() (see schedulingComplete)
stageExecutionState = getState();
if (stageExecutionState == StageExecutionState.SCHEDULED || stageExecutionState == StageExecutionState.RUNNING) {
if (taskState == TaskState.RUNNING) {
stateMachine.transitionToRunning();
}
// every created task has finished (or was recovered): the stage is finished
if (finishedTasks.size() == allTasks.size()) {
stateMachine.transitionToFinished();
}
}
}
// A failure set is recoverable when every failure has a recoverable error code, a recovery callback
// is registered, and the number of failed tasks is below allTasks.size() * maxFailedTaskPercentage.
private boolean isRecoverable(List<ExecutionFailureInfo> failures)
{
    boolean onlyRecoverableErrors = failures.stream()
            .allMatch(failureInfo -> RECOVERABLE_ERROR_CODES.contains(failureInfo.getErrorCode()));
    if (!onlyRecoverableErrors) {
        return false;
    }
    if (!stageTaskRecoveryCallback.isPresent()) {
        return false;
    }
    return failedTasks.size() < allTasks.size() * maxFailedTaskPercentage;
}
// Final-task-info listener registered on every task in scheduleTask(): the task is no longer
// considered running, and the stage may now be able to publish its final task infos.
private void updateFinalTaskInfo(TaskInfo finalTaskInfo)
{
runningTasks.remove(finalTaskInfo.getTaskId());
checkAllTaskFinal();
}
// Once the stage is done and no task is still awaiting its final info, validates the lifespan
// bookkeeping and hands the final task infos to the state machine.
private void checkAllTaskFinal()
{
    if (!stateMachine.getState().isDone() || !runningTasks.isEmpty()) {
        return;
    }

    boolean groupedExecution = getFragment().getStageExecutionDescriptor().isStageGroupedExecution();
    if (!groupedExecution) {
        // ungrouped execution will not update finished lifespans
        checkState(finishedLifespans.isEmpty());
    }
    else {
        // in case stage is CANCELLED/ABORTED/FAILED, number of finished lifespans can be less than total lifespans
        checkState(finishedLifespans.size() <= totalLifespans, format("Number of finished lifespans (%s) exceeds number of total lifespans (%s)", finishedLifespans.size(), totalLifespans));
    }

    List<TaskInfo> taskInfos = getAllTasks().stream()
            .map(remoteTask -> remoteTask.getTaskInfo())
            .collect(toImmutableList());
    stateMachine.setAllTasksFinal(taskInfos, totalLifespans);
}
// Returns the failure unchanged unless it names a remote host that the failure detector reports as GONE;
// in that case a copy with the REMOTE_HOST_GONE error code is returned.
private ExecutionFailureInfo rewriteTransportFailure(ExecutionFailureInfo executionFailureInfo)
{
    boolean remoteHostGone = executionFailureInfo.getRemoteHost() != null
            && failureDetector.getState(executionFailureInfo.getRemoteHost()) == GONE;
    if (!remoteHostGone) {
        return executionFailureInfo;
    }

    // Same failure, only the error code is replaced
    return new ExecutionFailureInfo(
            executionFailureInfo.getType(),
            executionFailureInfo.getMessage(),
            executionFailureInfo.getCause(),
            executionFailureInfo.getSuppressed(),
            executionFailureInfo.getStack(),
            executionFailureInfo.getErrorLocation(),
            REMOTE_HOST_GONE.toErrorCode(),
            executionFailureInfo.getRemoteHost(),
            executionFailureInfo.getErrorCause());
}
/**
 * Returns the string representation of the underlying state machine.
 */
@Override
public String toString()
{
return stateMachine.toString();
}
/**
 * Per-task status listener: reports memory deltas and newly completed driver groups to the stage,
 * then lets the stage react to the new task status.
 */
private class StageTaskListener
implements StateChangeListener<TaskStatus>
{
// memory values from the previous status update, used to compute deltas
private long previousUserMemory;
private long previousSystemMemory;
// driver groups already reported to the completed-lifespan listeners for this task
private final Set<Lifespan> completedDriverGroups = new HashSet<>();
private final TaskId taskId;
public StageTaskListener(TaskId taskId)
{
this.taskId = requireNonNull(taskId, "taskId is null");
}
@Override
public void stateChanged(TaskStatus taskStatus)
{
try {
updateMemoryUsage(taskStatus);
updateCompletedDriverGroups(taskStatus);
}
finally {
// always propagate the task status, even if the bookkeeping above throws
updateTaskStatus(taskId, taskStatus);
}
}
// Reports the change in user and total (user + system) memory since the previous update.
private synchronized void updateMemoryUsage(TaskStatus taskStatus)
{
long currentUserMemory = taskStatus.getMemoryReservationInBytes();
long currentSystemMemory = taskStatus.getSystemMemoryReservationInBytes();
long deltaUserMemoryInBytes = currentUserMemory - previousUserMemory;
long deltaTotalMemoryInBytes = (currentUserMemory + currentSystemMemory) - (previousUserMemory + previousSystemMemory);
previousUserMemory = currentUserMemory;
previousSystemMemory = currentSystemMemory;
stateMachine.updateMemoryUsage(deltaUserMemoryInBytes, deltaTotalMemoryInBytes, taskStatus.getPeakNodeTotalMemoryReservationInBytes());
}
// Notifies the completed-lifespan listeners about driver groups not reported before.
private synchronized void updateCompletedDriverGroups(TaskStatus taskStatus)
{
// Sets.difference returns a view.
// Once we add the difference into `completedDriverGroups`, the view would become empty.
// `completedLifespansChangeListeners.invoke` happens asynchronously, so the listeners
// could observe the emptied view regardless of whether they are invoked before or after the update.
// That's why an immutable copy must be made here.
Set<Lifespan> newlyCompletedDriverGroups = ImmutableSet.copyOf(Sets.difference(taskStatus.getCompletedDriverGroups(), this.completedDriverGroups));
if (newlyCompletedDriverGroups.isEmpty()) {
return;
}
completedLifespansChangeListeners.invoke(newlyCompletedDriverGroups, executor);
// newlyCompletedDriverGroups is an immutable copy (not a view), so updating
// completedDriverGroups here does not affect the set passed to the listeners.
completedDriverGroups.addAll(newlyCompletedDriverGroups);
}
}
/**
 * Callback invoked by the stage to recover a failed task. If {@code recover} throws, the stage is
 * transitioned to failed (see updateTaskStatus).
 */
@FunctionalInterface
public interface StageTaskRecoveryCallback
{
void recover(TaskId taskId);
}
/**
 * Holds a list of consumers and dispatches payloads to them on an executor.
 * Registration is only allowed until the first dispatch.
 */
private static class ListenerManager<T>
{
    private final List<Consumer<T>> consumers = new ArrayList<>();
    // becomes true on the first invoke(); later registrations are rejected
    private boolean invoked;

    public synchronized void addListener(Consumer<T> listener)
    {
        checkState(!invoked, "Listeners have been invoked");
        consumers.add(listener);
    }

    public synchronized void invoke(T payload, Executor executor)
    {
        invoked = true;
        // each consumer gets its own task on the executor
        consumers.forEach(consumer -> executor.execute(() -> consumer.accept(payload)));
    }
}
}