PrestoSparkTaskExecutorFactory.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.spark.execution.task;

import com.facebook.airlift.json.Codec;
import com.facebook.airlift.json.JsonCodec;
import com.facebook.airlift.log.Logger;
import com.facebook.airlift.stats.TestingGcMonitor;
import com.facebook.presto.Session;
import com.facebook.presto.SessionRepresentation;
import com.facebook.presto.common.RuntimeStats;
import com.facebook.presto.common.block.BlockEncodingManager;
import com.facebook.presto.common.io.DataOutput;
import com.facebook.presto.event.SplitMonitor;
import com.facebook.presto.execution.ExecutionFailureInfo;
import com.facebook.presto.execution.MemoryRevokingSchedulerUtils;
import com.facebook.presto.execution.ScheduledSplit;
import com.facebook.presto.execution.StageExecutionId;
import com.facebook.presto.execution.StageId;
import com.facebook.presto.execution.TaskId;
import com.facebook.presto.execution.TaskInfo;
import com.facebook.presto.execution.TaskManagerConfig;
import com.facebook.presto.execution.TaskSource;
import com.facebook.presto.execution.TaskState;
import com.facebook.presto.execution.TaskStateMachine;
import com.facebook.presto.execution.TaskStatus;
import com.facebook.presto.execution.buffer.OutputBufferInfo;
import com.facebook.presto.execution.buffer.OutputBufferMemoryManager;
import com.facebook.presto.execution.executor.TaskExecutor;
import com.facebook.presto.memory.MemoryPool;
import com.facebook.presto.memory.NodeMemoryConfig;
import com.facebook.presto.memory.QueryContext;
import com.facebook.presto.memory.TraversingQueryContextVisitor;
import com.facebook.presto.memory.VoidTraversingQueryContextVisitor;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.SessionPropertyManager;
import com.facebook.presto.operator.OperatorContext;
import com.facebook.presto.operator.OutputFactory;
import com.facebook.presto.operator.TaskContext;
import com.facebook.presto.operator.TaskMemoryReservationSummary;
import com.facebook.presto.operator.TaskStats;
import com.facebook.presto.spark.BasicPrincipal;
import com.facebook.presto.spark.PrestoSparkConfig;
import com.facebook.presto.spark.PrestoSparkTaskDescriptor;
import com.facebook.presto.spark.accesscontrol.PrestoSparkAuthenticatorProvider;
import com.facebook.presto.spark.classloader_interface.IPrestoSparkTaskExecutor;
import com.facebook.presto.spark.classloader_interface.IPrestoSparkTaskExecutorFactory;
import com.facebook.presto.spark.classloader_interface.MutablePartitionId;
import com.facebook.presto.spark.classloader_interface.PrestoSparkJavaExecutionTaskInputs;
import com.facebook.presto.spark.classloader_interface.PrestoSparkMutableRow;
import com.facebook.presto.spark.classloader_interface.PrestoSparkSerializedPage;
import com.facebook.presto.spark.classloader_interface.PrestoSparkShuffleStats;
import com.facebook.presto.spark.classloader_interface.PrestoSparkStorageHandle;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskInputs;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskOutput;
import com.facebook.presto.spark.classloader_interface.SerializedPrestoSparkTaskDescriptor;
import com.facebook.presto.spark.classloader_interface.SerializedPrestoSparkTaskSource;
import com.facebook.presto.spark.classloader_interface.SerializedTaskInfo;
import com.facebook.presto.spark.execution.PrestoSparkBroadcastTableCacheManager;
import com.facebook.presto.spark.execution.PrestoSparkBufferedSerializedPage;
import com.facebook.presto.spark.execution.PrestoSparkExecutionExceptionFactory;
import com.facebook.presto.spark.execution.PrestoSparkOutputBuffer;
import com.facebook.presto.spark.execution.PrestoSparkPageOutputOperator.PrestoSparkPageOutputFactory;
import com.facebook.presto.spark.execution.PrestoSparkRemoteSourceFactory;
import com.facebook.presto.spark.execution.PrestoSparkRowBatch;
import com.facebook.presto.spark.execution.PrestoSparkRowBatch.RowTupleSupplier;
import com.facebook.presto.spark.execution.PrestoSparkRowOutputOperator.PreDeterminedPartitionFunction;
import com.facebook.presto.spark.execution.PrestoSparkRowOutputOperator.PrestoSparkRowOutputFactory;
import com.facebook.presto.spark.execution.shuffle.PrestoSparkShuffleInput;
import com.facebook.presto.spark.execution.shuffle.PrestoSparkShuffleReadInfo;
import com.facebook.presto.spark.execution.shuffle.PrestoSparkShuffleWriteInfo;
import com.facebook.presto.spark.util.PrestoSparkStatsCollectionUtils;
import com.facebook.presto.spi.ConnectorSplit;
import com.facebook.presto.spi.QueryId;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.memory.MemoryPoolId;
import com.facebook.presto.spi.page.PageDataOutput;
import com.facebook.presto.spi.plan.PlanFragmentId;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.security.Identity;
import com.facebook.presto.spi.security.TokenAuthenticator;
import com.facebook.presto.spi.storage.TempDataOperationContext;
import com.facebook.presto.spi.storage.TempDataSink;
import com.facebook.presto.spi.storage.TempStorage;
import com.facebook.presto.spi.storage.TempStorageHandle;
import com.facebook.presto.spiller.NodeSpillConfig;
import com.facebook.presto.spiller.SpillSpaceTracker;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.planner.LocalExecutionPlanner;
import com.facebook.presto.sql.planner.LocalExecutionPlanner.LocalExecutionPlan;
import com.facebook.presto.sql.planner.OutputPartitioning;
import com.facebook.presto.sql.planner.PlanFragment;
import com.facebook.presto.sql.planner.plan.RemoteSourceNode;
import com.facebook.presto.sql.planner.planPrinter.PlanPrinter;
import com.facebook.presto.storage.TempStorageManager;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.airlift.slice.Slice;
import io.airlift.units.DataSize;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.util.CollectionAccumulator;
import scala.Tuple2;
import scala.collection.AbstractIterator;
import scala.collection.Iterator;

import javax.inject.Inject;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URI;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.OptionalLong;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.zip.CRC32;

import static com.facebook.presto.ExceededMemoryLimitException.exceededLocalTotalMemoryLimit;
import static com.facebook.presto.SystemSessionProperties.QUERY_MAX_MEMORY_PER_NODE;
import static com.facebook.presto.SystemSessionProperties.QUERY_MAX_REVOCABLE_MEMORY_PER_NODE;
import static com.facebook.presto.SystemSessionProperties.QUERY_MAX_TOTAL_MEMORY_PER_NODE;
import static com.facebook.presto.SystemSessionProperties.getHashPartitionCount;
import static com.facebook.presto.SystemSessionProperties.getHeapDumpFileDirectory;
import static com.facebook.presto.SystemSessionProperties.getQueryMaxBroadcastMemory;
import static com.facebook.presto.SystemSessionProperties.getQueryMaxMemoryPerNode;
import static com.facebook.presto.SystemSessionProperties.getQueryMaxRevocableMemoryPerNode;
import static com.facebook.presto.SystemSessionProperties.getQueryMaxTotalMemoryPerNode;
import static com.facebook.presto.SystemSessionProperties.isHeapDumpOnExceededMemoryLimitEnabled;
import static com.facebook.presto.SystemSessionProperties.isSpillEnabled;
import static com.facebook.presto.SystemSessionProperties.isVerboseExceededMemoryLimitErrorsEnabled;
import static com.facebook.presto.execution.TaskState.FAILED;
import static com.facebook.presto.execution.TaskStatus.STARTING_VERSION;
import static com.facebook.presto.execution.buffer.BufferState.FINISHED;
import static com.facebook.presto.metadata.MetadataUpdates.DEFAULT_METADATA_UPDATES;
import static com.facebook.presto.spark.PrestoSparkSessionProperties.getAttemptNumberToApplyDynamicMemoryPoolTuning;
import static com.facebook.presto.spark.PrestoSparkSessionProperties.getDynamicPrestoMemoryPoolTuningFraction;
import static com.facebook.presto.spark.PrestoSparkSessionProperties.getMemoryRevokingTarget;
import static com.facebook.presto.spark.PrestoSparkSessionProperties.getMemoryRevokingThreshold;
import static com.facebook.presto.spark.PrestoSparkSessionProperties.getShuffleOutputTargetAverageRowSize;
import static com.facebook.presto.spark.PrestoSparkSessionProperties.getSparkBroadcastJoinMaxMemoryOverride;
import static com.facebook.presto.spark.PrestoSparkSessionProperties.getStorageBasedBroadcastJoinWriteBufferSize;
import static com.facebook.presto.spark.PrestoSparkSessionProperties.isDynamicPrestoMemoryPoolTuningEnabled;
import static com.facebook.presto.spark.classloader_interface.PrestoSparkShuffleStats.Operation.WRITE;
import static com.facebook.presto.spark.util.PrestoSparkUtils.deserializeZstdCompressed;
import static com.facebook.presto.spark.util.PrestoSparkUtils.getNullifyingIterator;
import static com.facebook.presto.spark.util.PrestoSparkUtils.serializeZstdCompressed;
import static com.facebook.presto.spark.util.PrestoSparkUtils.toPrestoSparkSerializedPage;
import static com.facebook.presto.spi.ErrorCause.UNKNOWN;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION;
import static com.facebook.presto.util.Failures.toFailures;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Throwables.propagateIfPossible;
import static com.google.common.collect.Iterables.getFirst;
import static io.airlift.units.DataSize.Unit.BYTE;
import static io.airlift.units.DataSize.succinctBytes;
import static java.lang.Math.min;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static java.util.UUID.randomUUID;

public class PrestoSparkTaskExecutorFactory
        implements IPrestoSparkTaskExecutorFactory
{
    private static final Logger log = Logger.get(PrestoSparkTaskExecutorFactory.class);

    private final SessionPropertyManager sessionPropertyManager;
    private final BlockEncodingManager blockEncodingManager;
    private final FunctionAndTypeManager functionAndTypeManager;

    // Decodes the serialized task descriptor at the start of doCreate
    private final JsonCodec<PrestoSparkTaskDescriptor> taskDescriptorJsonCodec;
    private final Codec<TaskSource> taskSourceCodec;
    private final Codec<TaskInfo> taskInfoCodec;
    private final JsonCodec<List<TaskMemoryReservationSummary>> memoryReservationSummaryJsonCodec;

    private final Executor notificationExecutor;
    private final ScheduledExecutorService yieldExecutor;
    private final ScheduledExecutorService memoryUpdateExecutor;
    // Runs the memory revoking requests submitted by the memory pool listener registered in doCreate
    private final ExecutorService memoryRevocationExecutor;

    private final LocalExecutionPlanner localExecutionPlanner;
    // Converts RuntimeExceptions thrown by doCreate into the exception rethrown by create
    private final PrestoSparkExecutionExceptionFactory executionExceptionFactory;
    private final TaskExecutor taskExecutor;
    private final SplitMonitor splitMonitor;

    // Immutable copy; the token authenticators of every provider are merged into the task session in doCreate
    private final Set<PrestoSparkAuthenticatorProvider> authenticatorProviders;

    private final NodeMemoryConfig nodeMemoryConfig;

    // Taken from FeaturesConfig.isNativeExecutionEnabled() when built through the injected constructor
    private final boolean nativeExecution;

    // Taken from NodeSpillConfig.getQueryMaxSpillPerNode() when built through the injected constructor
    private final DataSize maxQuerySpillPerNode;
    // Taken from TaskManagerConfig.getSinkMaxBufferSize(); bounds the OutputBufferMemoryManager created in doCreate
    private final DataSize sinkMaxBufferSize;

    private final boolean perOperatorCpuTimerEnabled;
    private final boolean cpuTimerEnabled;

    private final boolean perOperatorAllocationTrackingEnabled;
    private final boolean allocationTrackingEnabled;
    private final TempStorageManager tempStorageManager;
    private final PrestoSparkBroadcastTableCacheManager prestoSparkBroadcastTableCacheManager;
    // Name handed to tempStorageManager.getTempStorage(...) in doCreate
    private final String storageBasedBroadcastJoinStorage;

    // Both flags below are instance state of the factory, so they are shared by every doCreate call on this instance.

    // Set to true by the memory pool listener in doCreate when an operator reports a positive number of bytes to revoke
    private final AtomicBoolean memoryRevokePending = new AtomicBoolean();
    // Flipped false -> true (compareAndSet) by the memory pool listener before it schedules a revoking request and
    // set back to false at the end of that request; also checked before failing a query for exceeding its total memory.
    // NOTE(review): the reset sits inside the try block of the revoking task, so an exception there leaves the flag set.
    private final AtomicBoolean memoryRevokeRequestInProgress = new AtomicBoolean();

    /**
     * Injection entry point. Unpacks the individual settings needed by the factory from the
     * supplied configuration objects and delegates to the full-argument constructor, which
     * performs the null checks for all non-config collaborators.
     *
     * @throws NullPointerException if any argument is null
     */
    @Inject
    public PrestoSparkTaskExecutorFactory(
            SessionPropertyManager sessionPropertyManager,
            BlockEncodingManager blockEncodingManager,
            FunctionAndTypeManager functionAndTypeManager,
            JsonCodec<PrestoSparkTaskDescriptor> taskDescriptorJsonCodec,
            Codec<TaskSource> taskSourceCodec,
            Codec<TaskInfo> taskInfoCodec,
            JsonCodec<List<TaskMemoryReservationSummary>> memoryReservationSummaryJsonCodec,
            Executor notificationExecutor,
            ScheduledExecutorService yieldExecutor,
            ScheduledExecutorService memoryUpdateExecutor,
            ExecutorService memoryRevocationExecutor,
            LocalExecutionPlanner localExecutionPlanner,
            PrestoSparkExecutionExceptionFactory executionExceptionFactory,
            TaskExecutor taskExecutor,
            SplitMonitor splitMonitor,
            Set<PrestoSparkAuthenticatorProvider> authenticatorProviders,
            FeaturesConfig featuresConfig,
            TaskManagerConfig taskManagerConfig,
            NodeMemoryConfig nodeMemoryConfig,
            NodeSpillConfig nodeSpillConfig,
            TempStorageManager tempStorageManager,
            PrestoSparkBroadcastTableCacheManager prestoSparkBroadcastTableCacheManager,
            PrestoSparkConfig prestoSparkConfig)
    {
        // The config objects are dereferenced inside the this(...) call, so each one is null-checked
        // inline; statements are not allowed before an explicit constructor invocation.
        this(
                sessionPropertyManager,
                blockEncodingManager,
                functionAndTypeManager,
                taskDescriptorJsonCodec,
                taskSourceCodec,
                taskInfoCodec,
                memoryReservationSummaryJsonCodec,
                notificationExecutor,
                yieldExecutor,
                memoryUpdateExecutor,
                memoryRevocationExecutor,
                localExecutionPlanner,
                executionExceptionFactory,
                taskExecutor,
                splitMonitor,
                authenticatorProviders,
                nodeMemoryConfig,
                requireNonNull(featuresConfig, "featuresConfig is null").isNativeExecutionEnabled(),
                requireNonNull(nodeSpillConfig, "nodeSpillConfig is null").getQueryMaxSpillPerNode(),
                requireNonNull(taskManagerConfig, "taskManagerConfig is null").getSinkMaxBufferSize(),
                requireNonNull(taskManagerConfig, "taskManagerConfig is null").isPerOperatorCpuTimerEnabled(),
                requireNonNull(taskManagerConfig, "taskManagerConfig is null").isTaskCpuTimerEnabled(),
                requireNonNull(taskManagerConfig, "taskManagerConfig is null").isPerOperatorAllocationTrackingEnabled(),
                requireNonNull(taskManagerConfig, "taskManagerConfig is null").isTaskAllocationTrackingEnabled(),
                tempStorageManager,
                requireNonNull(prestoSparkConfig, "prestoSparkConfig is null").getStorageBasedBroadcastJoinStorage(),
                prestoSparkBroadcastTableCacheManager);
    }

    /**
     * Creates a factory from already-unpacked settings. Every reference argument is null-checked
     * and stored; {@code authenticatorProviders} is stored as an immutable copy.
     *
     * @throws NullPointerException if any reference argument is null; the message names the offending parameter
     */
    public PrestoSparkTaskExecutorFactory(
            SessionPropertyManager sessionPropertyManager,
            BlockEncodingManager blockEncodingManager,
            FunctionAndTypeManager functionAndTypeManager,
            JsonCodec<PrestoSparkTaskDescriptor> taskDescriptorJsonCodec,
            Codec<TaskSource> taskSourceCodec,
            Codec<TaskInfo> taskInfoCodec,
            JsonCodec<List<TaskMemoryReservationSummary>> memoryReservationSummaryJsonCodec,
            Executor notificationExecutor,
            ScheduledExecutorService yieldExecutor,
            ScheduledExecutorService memoryUpdateExecutor,
            ExecutorService memoryRevocationExecutor,
            LocalExecutionPlanner localExecutionPlanner,
            PrestoSparkExecutionExceptionFactory executionExceptionFactory,
            TaskExecutor taskExecutor,
            SplitMonitor splitMonitor,
            Set<PrestoSparkAuthenticatorProvider> authenticatorProviders,
            NodeMemoryConfig nodeMemoryConfig,
            boolean nativeExecution,
            DataSize maxQuerySpillPerNode,
            DataSize sinkMaxBufferSize,
            boolean perOperatorCpuTimerEnabled,
            boolean cpuTimerEnabled,
            boolean perOperatorAllocationTrackingEnabled,
            boolean allocationTrackingEnabled,
            TempStorageManager tempStorageManager,
            String storageBasedBroadcastJoinStorage,
            PrestoSparkBroadcastTableCacheManager prestoSparkBroadcastTableCacheManager)
    {
        this.sessionPropertyManager = requireNonNull(sessionPropertyManager, "sessionPropertyManager is null");
        this.blockEncodingManager = requireNonNull(blockEncodingManager, "blockEncodingManager is null");
        // Null-check messages use the actual parameter names so a failure points at the right argument
        this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
        this.taskDescriptorJsonCodec = requireNonNull(taskDescriptorJsonCodec, "taskDescriptorJsonCodec is null");
        this.taskSourceCodec = requireNonNull(taskSourceCodec, "taskSourceCodec is null");
        this.taskInfoCodec = requireNonNull(taskInfoCodec, "taskInfoCodec is null");
        this.memoryReservationSummaryJsonCodec = requireNonNull(memoryReservationSummaryJsonCodec, "memoryReservationSummaryJsonCodec is null");
        this.notificationExecutor = requireNonNull(notificationExecutor, "notificationExecutor is null");
        this.yieldExecutor = requireNonNull(yieldExecutor, "yieldExecutor is null");
        this.memoryUpdateExecutor = requireNonNull(memoryUpdateExecutor, "memoryUpdateExecutor is null");
        this.memoryRevocationExecutor = requireNonNull(memoryRevocationExecutor, "memoryRevocationExecutor is null");
        this.localExecutionPlanner = requireNonNull(localExecutionPlanner, "localExecutionPlanner is null");
        this.executionExceptionFactory = requireNonNull(executionExceptionFactory, "executionExceptionFactory is null");
        this.taskExecutor = requireNonNull(taskExecutor, "taskExecutor is null");
        this.splitMonitor = requireNonNull(splitMonitor, "splitMonitor is null");
        // Defensive immutable copy so later changes to the caller's set are not observed here
        this.authenticatorProviders = ImmutableSet.copyOf(requireNonNull(authenticatorProviders, "authenticatorProviders is null"));
        this.nodeMemoryConfig = requireNonNull(nodeMemoryConfig, "nodeMemoryConfig is null");
        this.nativeExecution = nativeExecution;
        this.maxQuerySpillPerNode = requireNonNull(maxQuerySpillPerNode, "maxQuerySpillPerNode is null");
        this.sinkMaxBufferSize = requireNonNull(sinkMaxBufferSize, "sinkMaxBufferSize is null");
        this.perOperatorCpuTimerEnabled = perOperatorCpuTimerEnabled;
        this.cpuTimerEnabled = cpuTimerEnabled;
        this.perOperatorAllocationTrackingEnabled = perOperatorAllocationTrackingEnabled;
        this.allocationTrackingEnabled = allocationTrackingEnabled;
        this.tempStorageManager = requireNonNull(tempStorageManager, "tempStorageManager is null");
        this.storageBasedBroadcastJoinStorage = requireNonNull(storageBasedBroadcastJoinStorage, "storageBasedBroadcastJoinStorage is null");
        this.prestoSparkBroadcastTableCacheManager = requireNonNull(prestoSparkBroadcastTableCacheManager, "prestoSparkBroadcastTableCacheManager is null");
    }

    /**
     * Builds a task executor by forwarding all arguments unchanged to {@link #doCreate}.
     * Any {@link RuntimeException} raised while doing so is translated through
     * {@code executionExceptionFactory.toPrestoSparkExecutionException} and the result is thrown instead.
     *
     * @return the executor produced by {@code doCreate}
     */
    @Override
    public <T extends PrestoSparkTaskOutput> IPrestoSparkTaskExecutor<T> create(
            int partitionId,
            int attemptNumber,
            SerializedPrestoSparkTaskDescriptor serializedTaskDescriptor,
            Iterator<SerializedPrestoSparkTaskSource> serializedTaskSources,
            PrestoSparkTaskInputs inputs,
            CollectionAccumulator<SerializedTaskInfo> taskInfoCollector,
            CollectionAccumulator<PrestoSparkShuffleStats> shuffleStatsCollector,
            Class<T> outputType)
    {
        IPrestoSparkTaskExecutor<T> executor;
        try {
            executor = doCreate(
                    partitionId,
                    attemptNumber,
                    serializedTaskDescriptor,
                    serializedTaskSources,
                    inputs,
                    taskInfoCollector,
                    shuffleStatsCollector,
                    outputType);
        }
        catch (RuntimeException failure) {
            // Only runtime failures are translated; anything else propagates untouched
            throw executionExceptionFactory.toPrestoSparkExecutionException(failure);
        }
        return executor;
    }

    /**
     * No-op: this implementation performs no cleanup when closed.
     */
    @Override
    public void close()
    {
        // Intentionally empty
    }

    public <T extends PrestoSparkTaskOutput> IPrestoSparkTaskExecutor<T> doCreate(
            int partitionId,
            int attemptNumber,
            SerializedPrestoSparkTaskDescriptor serializedTaskDescriptor,
            Iterator<SerializedPrestoSparkTaskSource> serializedTaskSources,
            PrestoSparkTaskInputs inputs,
            CollectionAccumulator<SerializedTaskInfo> taskInfoCollector,
            CollectionAccumulator<PrestoSparkShuffleStats> shuffleStatsCollector,
            Class<T> outputType)
    {
        PrestoSparkTaskDescriptor taskDescriptor = taskDescriptorJsonCodec.fromJson(serializedTaskDescriptor.getBytes());
        ImmutableMap.Builder<String, TokenAuthenticator> extraAuthenticators = ImmutableMap.builder();
        authenticatorProviders.forEach(provider -> extraAuthenticators.putAll(provider.getTokenAuthenticators()));

        SessionRepresentation sessionRepresentation = taskDescriptor.getSession();
        Session session = sessionRepresentation.toSession(
                sessionPropertyManager,
                taskDescriptor.getExtraCredentials(),
                extraAuthenticators.build());
        DataSize maxUserMemory = getDynamicallyComputedMemory(session, getQueryMaxMemoryPerNode(session), attemptNumber);
        DataSize maxTotalMemory = getDynamicallyComputedMemory(session, getQueryMaxTotalMemoryPerNode(session), attemptNumber);
        DataSize maxRevocableMemory = getDynamicallyComputedMemory(session, getQueryMaxRevocableMemoryPerNode(session), attemptNumber);

        ImmutableMap.Builder<String, String> extraSessionProperties = ImmutableMap.<String, String>builder()
                .put(QUERY_MAX_MEMORY_PER_NODE, maxUserMemory.toString())
                .put(QUERY_MAX_TOTAL_MEMORY_PER_NODE, maxTotalMemory.toString())
                .put(QUERY_MAX_REVOCABLE_MEMORY_PER_NODE, maxRevocableMemory.toString());

        session = createSessionWithExtraSessionProperties(
                 sessionRepresentation,
                 taskDescriptor.getExtraCredentials(),
                 extraAuthenticators.build(),
                 extraSessionProperties.build());

        PlanFragment fragment = taskDescriptor.getFragment();
        StageId stageId = new StageId(session.getQueryId(), fragment.getId().getId());

        // Clear the cache if the cache does not have broadcast table for current stageId.
        // We will only cache 1 HT at any time. If the stageId changes, we will drop the old cached HT
        prestoSparkBroadcastTableCacheManager.removeCachedTablesForStagesOtherThan(stageId);

        TaskId taskId = new TaskId(new StageExecutionId(stageId, 0), partitionId, attemptNumber);

        log.info(PlanPrinter.textPlanFragment(fragment, functionAndTypeManager, session, true));

        DataSize maxBroadcastMemory = getSparkBroadcastJoinMaxMemoryOverride(session);
        if (maxBroadcastMemory == null) {
            maxBroadcastMemory = new DataSize(min(nodeMemoryConfig.getMaxQueryBroadcastMemory().toBytes(), getQueryMaxBroadcastMemory(session).toBytes()), BYTE);
        }

        MemoryPool memoryPool = new MemoryPool(new MemoryPoolId("spark-executor-memory-pool"), maxTotalMemory);

        SpillSpaceTracker spillSpaceTracker = new SpillSpaceTracker(maxQuerySpillPerNode);

        QueryContext queryContext = new QueryContext(
                session.getQueryId(),
                maxUserMemory,
                maxTotalMemory,
                maxBroadcastMemory,
                maxRevocableMemory,
                memoryPool,
                new TestingGcMonitor(),
                notificationExecutor,
                yieldExecutor,
                maxQuerySpillPerNode,
                spillSpaceTracker,
                memoryReservationSummaryJsonCodec);
        queryContext.setVerboseExceededMemoryLimitErrorsEnabled(isVerboseExceededMemoryLimitErrorsEnabled(session));
        queryContext.setHeapDumpOnExceededMemoryLimitEnabled(isHeapDumpOnExceededMemoryLimitEnabled(session));
        String heapDumpFilePath = Paths.get(
                getHeapDumpFileDirectory(session),
                format("%s_%s.hprof", session.getQueryId().getId(), stageId.getId())).toString();
        queryContext.setHeapDumpFilePath(heapDumpFilePath);

        TaskStateMachine taskStateMachine = new TaskStateMachine(taskId, notificationExecutor);
        TaskContext taskContext = queryContext.addTaskContext(
                taskStateMachine,
                session,
                // Plan has to be retained only if verbose memory exceeded errors are requested
                isVerboseExceededMemoryLimitErrorsEnabled(session) ? Optional.of(fragment.getRoot()) : Optional.empty(),
                perOperatorCpuTimerEnabled,
                cpuTimerEnabled,
                perOperatorAllocationTrackingEnabled,
                allocationTrackingEnabled,
                false);

        final double memoryRevokingThreshold = getMemoryRevokingThreshold(session);
        final double memoryRevokingTarget = getMemoryRevokingTarget(session);
        checkArgument(
                memoryRevokingTarget <= memoryRevokingThreshold,
                "memoryRevokingTarget should be less than or equal memoryRevokingThreshold, but got %s and %s respectively",
                memoryRevokingTarget, memoryRevokingThreshold);
        boolean heapDumpOnExceededMemoryLimitEnabled = isHeapDumpOnExceededMemoryLimitEnabled(session);
        if (isSpillEnabled(session)) {
            memoryPool.addListener((pool, queryId, totalMemoryReservationBytes) -> {
                if (totalMemoryReservationBytes > queryContext.getPeakNodeTotalMemory()) {
                    queryContext.setPeakNodeTotalMemory(totalMemoryReservationBytes);
                }

                if (totalMemoryReservationBytes > pool.getMaxBytes() * memoryRevokingThreshold && memoryRevokeRequestInProgress.compareAndSet(false, true)) {
                    memoryRevocationExecutor.execute(() -> {
                        try {
                            AtomicLong remainingBytesToRevoke = new AtomicLong(totalMemoryReservationBytes - (long) (memoryRevokingTarget * pool.getMaxBytes()));
                            remainingBytesToRevoke.addAndGet(-MemoryRevokingSchedulerUtils.getMemoryAlreadyBeingRevoked(ImmutableList.of(taskContext), remainingBytesToRevoke.get()));
                            taskContext.accept(new VoidTraversingQueryContextVisitor<AtomicLong>()
                            {
                                @Override
                                public Void visitOperatorContext(OperatorContext operatorContext, AtomicLong remainingBytesToRevoke)
                                {
                                    if (remainingBytesToRevoke.get() > 0) {
                                        long revokedBytes = operatorContext.requestMemoryRevoking();
                                        if (revokedBytes > 0) {
                                            memoryRevokePending.set(true);
                                            remainingBytesToRevoke.addAndGet(-revokedBytes);
                                        }
                                    }
                                    return null;
                                }
                            }, remainingBytesToRevoke);
                            memoryRevokeRequestInProgress.set(false);
                        }
                        catch (Exception e) {
                            log.error(e, "Error requesting memory revoking");
                        }
                    });
                }

                // Get the latest memory reservation info since it might have changed due to revoke
                long totalReservedMemory = pool.getQueryMemoryReservation(queryId) + pool.getQueryRevocableMemoryReservation(queryId);

                // If total memory usage is over maxTotalMemory and memory revoke request is not pending, fail the query with EXCEEDED_MEMORY_LIMIT error
                if (totalReservedMemory > maxTotalMemory.toBytes() && !memoryRevokeRequestInProgress.get() && !isMemoryRevokePending(taskContext)) {
                    throw exceededLocalTotalMemoryLimit(
                            maxTotalMemory,
                            queryContext.getAdditionalFailureInfo(totalReservedMemory, 0, "test-operator") +
                                    format("Total reserved memory: %s, Total revocable memory: %s",
                                            succinctBytes(pool.getQueryMemoryReservation(queryId)),
                                            succinctBytes(pool.getQueryRevocableMemoryReservation(queryId))),
                            heapDumpOnExceededMemoryLimitEnabled,
                            Optional.ofNullable(heapDumpFilePath),
                            UNKNOWN);
                }
            });
        }

        ImmutableMap.Builder<PlanNodeId, List<PrestoSparkShuffleInput>> shuffleInputs = ImmutableMap.builder();
        ImmutableMap.Builder<PlanNodeId, List<java.util.Iterator<PrestoSparkSerializedPage>>> pageInputs = ImmutableMap.builder();
        ImmutableMap.Builder<PlanNodeId, List<?>> broadcastInputs = ImmutableMap.builder();
        ImmutableMap.Builder<PlanNodeId, PrestoSparkShuffleReadInfo> shuffleReadInfos = ImmutableMap.builder();

        List<TaskSource> taskSources;
        Optional<PrestoSparkShuffleWriteInfo> shuffleWriteInfo = Optional.empty();

        checkArgument(
                inputs instanceof PrestoSparkJavaExecutionTaskInputs,
                format("PrestoSparkJavaExecutionTaskInputs is required for java execution, but %s is provided", inputs.getClass().getName()));
        PrestoSparkJavaExecutionTaskInputs taskInputs = (PrestoSparkJavaExecutionTaskInputs) inputs;
        fillJavaExecutionTaskInputs(fragment, taskInputs, shuffleInputs, pageInputs, broadcastInputs);
        taskSources = getTaskSources(serializedTaskSources);

        OutputBufferMemoryManager memoryManager = new OutputBufferMemoryManager(
                sinkMaxBufferSize.toBytes(),
                () -> queryContext.getTaskContextByTaskId(taskId).localSystemMemoryContext(),
                notificationExecutor);

        Optional<OutputPartitioning> preDeterminedPartition = Optional.empty();
        if (fragment.getPartitioningScheme().getPartitioning().getHandle().equals(FIXED_ARBITRARY_DISTRIBUTION)) {
            int partitionCount = getHashPartitionCount(session);
            preDeterminedPartition = Optional.of(new OutputPartitioning(
                    new PreDeterminedPartitionFunction(partitionId % partitionCount, partitionCount),
                    ImmutableList.of(),
                    ImmutableList.of(),
                    false,
                    OptionalInt.empty()));
        }

        TempDataOperationContext tempDataOperationContext = new TempDataOperationContext(
                session.getSource(),
                session.getQueryId().getId(),
                session.getClientInfo(),
                Optional.of(session.getClientTags()),
                session.getIdentity());
        TempStorage tempStorage = tempStorageManager.getTempStorage(storageBasedBroadcastJoinStorage);

        Output<T> output = configureOutput(
                outputType,
                blockEncodingManager,
                memoryManager,
                getShuffleOutputTargetAverageRowSize(session),
                preDeterminedPartition,
                tempStorage,
                tempDataOperationContext,
                getStorageBasedBroadcastJoinWriteBufferSize(session));
        PrestoSparkOutputBuffer<?> outputBuffer = output.getOutputBuffer();

        LocalExecutionPlan localExecutionPlan = localExecutionPlanner.plan(
                taskContext,
                fragment,
                output.getOutputFactory(),
                new PrestoSparkRemoteSourceFactory(
                        blockEncodingManager,
                        shuffleInputs.build(),
                        pageInputs.build(),
                        broadcastInputs.build(),
                        partitionId,
                        shuffleStatsCollector,
                        tempStorage,
                        tempDataOperationContext,
                        prestoSparkBroadcastTableCacheManager,
                        stageId),
                taskDescriptor.getTableWriteInfo(),
                true,
                ImmutableList.of());

        taskStateMachine.addStateChangeListener(state -> {
            if (state.isDone()) {
                outputBuffer.setNoMoreRows();
            }
        });

        PrestoSparkTaskExecution taskExecution = new PrestoSparkTaskExecution(
                taskStateMachine,
                taskContext,
                localExecutionPlan,
                taskExecutor,
                splitMonitor,
                notificationExecutor,
                memoryUpdateExecutor,
                nativeExecution);

        log.info("Task [%s] received %d splits.",
                taskId,
                taskSources.stream()
                        .mapToInt(taskSource -> taskSource.getSplits().size())
                        .sum());

        OptionalLong totalSplitSize = computeAllSplitsSize(taskSources);
        if (totalSplitSize.isPresent()) {
            log.info("Total split size: %s bytes.", totalSplitSize.getAsLong());
        }
        taskExecution.start(taskSources);

        return new PrestoSparkTaskExecutor<>(
                taskContext,
                taskStateMachine,
                output.getOutputSupplier(),
                taskInfoCodec,
                taskInfoCollector,
                shuffleStatsCollector,
                executionExceptionFactory,
                output.getOutputBufferType(),
                outputBuffer,
                tempStorage,
                tempDataOperationContext);
    }

    /**
     * Returns the memory limit to use for the given attempt.
     * <p>
     * When dynamic memory pool tuning is enabled for the session and {@code attemptNumber}
     * has reached the configured attempt number, the limit is the larger of the statically
     * configured {@code memory} and the configured fraction of the JVM max memory.
     * Otherwise {@code memory} is returned unchanged.
     *
     * @param session session supplying the tuning properties
     * @param memory statically configured memory limit
     * @param attemptNumber attempt number of the current task
     * @return the memory limit to apply
     */
    private DataSize getDynamicallyComputedMemory(Session session, DataSize memory, int attemptNumber)
    {
        boolean tuningApplies = isDynamicPrestoMemoryPoolTuningEnabled(session)
                && attemptNumber >= getAttemptNumberToApplyDynamicMemoryPoolTuning(session);
        if (!tuningApplies) {
            // keep the static value taken from the session properties
            return memory;
        }

        // kept as doubles: both values are logged with %f and multiplied below
        double jvmMaxMemoryInBytes = Runtime.getRuntime().maxMemory();
        double tuningFraction = getDynamicPrestoMemoryPoolTuningFraction(session);
        log.info("Dynamically Tuning Presto Memory Configs. Configured JVM Memory: %f; Dynamic Memory Tuning fraction: %f", jvmMaxMemoryInBytes, tuningFraction);

        // never go below the statically configured value
        double tunedMemoryInBytes = Math.max(tuningFraction * jvmMaxMemoryInBytes, memory.toBytes());
        return DataSize.succinctDataSize(tunedMemoryInBytes, DataSize.Unit.BYTE);
    }

    /**
     * Rebuilds a {@link Session} from its serialized representation, overlaying additional
     * system properties and attaching the supplied credentials and authenticators to the identity.
     *
     * @param sessionRepresentation serialized session to restore
     * @param extraCredentials credentials placed on the restored identity
     * @param extraAuthenticators token authenticators placed on the restored identity
     * @param extraSystemProperties system properties that override those in the representation
     * @return the restored session
     */
    private Session createSessionWithExtraSessionProperties(
            SessionRepresentation sessionRepresentation,
            Map<String, String> extraCredentials,
            Map<String, TokenAuthenticator> extraAuthenticators,
            Map<String, String> extraSystemProperties)
    {
        // entries in extraSystemProperties win over the ones carried by the representation
        Map<String, String> mergedSystemProperties = new HashMap<>(sessionRepresentation.getSystemProperties());
        mergedSystemProperties.putAll(extraSystemProperties);

        QueryId queryId = new QueryId(sessionRepresentation.getQueryId());
        Identity identity = new Identity(
                sessionRepresentation.getUser(),
                sessionRepresentation.getPrincipal().map(BasicPrincipal::new),
                sessionRepresentation.getRoles(),
                extraCredentials,
                extraAuthenticators,
                Optional.empty(),
                Optional.empty());

        return new Session(
                queryId,
                sessionRepresentation.getTransactionId(),
                sessionRepresentation.isClientTransactionSupport(),
                identity,
                sessionRepresentation.getSource(),
                sessionRepresentation.getCatalog(),
                sessionRepresentation.getSchema(),
                sessionRepresentation.getTraceToken(),
                sessionRepresentation.getTimeZoneKey(),
                sessionRepresentation.getLocale(),
                sessionRepresentation.getRemoteUserAddress(),
                sessionRepresentation.getUserAgent(),
                sessionRepresentation.getClientInfo(),
                sessionRepresentation.getClientTags(),
                sessionRepresentation.getResourceEstimates(),
                sessionRepresentation.getStartTime(),
                ImmutableMap.copyOf(mergedSystemProperties),
                sessionRepresentation.getCatalogProperties(),
                sessionRepresentation.getUnprocessedCatalogProperties(),
                sessionPropertyManager,
                sessionRepresentation.getPreparedStatements(),
                sessionRepresentation.getSessionFunctions(),
                Optional.empty(),
                // the worker side does not need warning collectors, so a NOOP collector is used
                WarningCollector.NOOP,
                new RuntimeStats(),
                Optional.empty());
    }

    /**
     * Reports whether any operator of the task has a memory revocation request outstanding.
     * <p>
     * The traversal of the task context is submitted to {@code memoryRevocationExecutor};
     * this method does not wait for it and returns whatever value {@code memoryRevokePending}
     * currently holds.
     *
     * @param taskContext task whose operators are inspected
     * @return the currently recorded revoke-pending flag
     */
    public boolean isMemoryRevokePending(TaskContext taskContext)
    {
        TraversingQueryContextVisitor<Void, Boolean> revokeRequestedVisitor = new TraversingQueryContextVisitor<Void, Boolean>()
        {
            @Override
            public Boolean visitOperatorContext(OperatorContext operatorContext, Void context)
            {
                return operatorContext.isMemoryRevokingRequested();
            }

            @Override
            public Boolean mergeResults(List<Boolean> childrenResults)
            {
                // pending as soon as a single child reports a pending request
                return childrenResults.stream().anyMatch(Boolean.TRUE::equals);
            }
        };

        // refresh the flag in the background; the caller gets the previously stored value
        memoryRevocationExecutor.execute(() -> memoryRevokePending.set(taskContext.accept(revokeRequestedVisitor, null)));
        return memoryRevokePending.get();
    }

    /**
     * Adds up the byte sizes of every split in the given task sources.
     *
     * @param taskSources task sources whose splits are summed
     * @return the total size in bytes, or empty as soon as one split does not report its size
     */
    private static OptionalLong computeAllSplitsSize(List<TaskSource> taskSources)
    {
        long totalBytes = 0;
        for (TaskSource source : taskSources) {
            for (ScheduledSplit split : source.getSplits()) {
                OptionalLong splitBytes = split.getSplit().getConnectorSplit().getSplitSizeInBytes();
                if (!splitBytes.isPresent()) {
                    // a single unknown size makes the total unknown
                    return OptionalLong.empty();
                }
                totalBytes += splitBytes.getAsLong();
            }
        }
        return OptionalLong.of(totalBytes);
    }

    /**
     * Distributes the task inputs over the three builders, keyed by remote source node id.
     * <p>
     * For every source fragment of every remote source node exactly one kind of input
     * (shuffle, broadcast or in-memory) is expected; the lookups use the string form of the
     * source fragment id.
     *
     * @param fragment plan fragment whose remote source nodes are resolved
     * @param inputs inputs provided for java execution
     * @param shuffleInputs receives the shuffle (row) inputs per remote source
     * @param pageInputs receives the in-memory page iterators per remote source
     * @param broadcastInputs receives the broadcast value lists per remote source
     * @throws IllegalArgumentException if a source fragment has more than one kind of input
     * @throws IllegalStateException if a source fragment has no input at all
     */
    private void fillJavaExecutionTaskInputs(
            PlanFragment fragment,
            PrestoSparkJavaExecutionTaskInputs inputs,
            ImmutableMap.Builder<PlanNodeId, List<PrestoSparkShuffleInput>> shuffleInputs,
            ImmutableMap.Builder<PlanNodeId, List<java.util.Iterator<PrestoSparkSerializedPage>>> pageInputs,
            ImmutableMap.Builder<PlanNodeId, List<?>> broadcastInputs)
    {
        for (RemoteSourceNode remoteSource : fragment.getRemoteSourceNodes()) {
            ImmutableList.Builder<PrestoSparkShuffleInput> rowInputsBuilder = ImmutableList.builder();
            ImmutableList.Builder<java.util.Iterator<PrestoSparkSerializedPage>> pageIteratorsBuilder = ImmutableList.builder();
            ImmutableList.Builder<List<?>> broadcastListsBuilder = ImmutableList.builder();

            for (PlanFragmentId sourceFragmentId : remoteSource.getSourceFragmentIds()) {
                String sourceKey = sourceFragmentId.toString();
                Iterator<Tuple2<MutablePartitionId, PrestoSparkMutableRow>> shuffleInput = inputs.getShuffleInputs().get(sourceKey);
                Broadcast<?> broadcastInput = inputs.getBroadcastInputs().get(sourceKey);
                List<PrestoSparkSerializedPage> inMemoryInput = inputs.getInMemoryInputs().get(sourceKey);

                if (shuffleInput != null) {
                    checkArgument(broadcastInput == null, "single remote source is not expected to accept different kind of inputs");
                    checkArgument(inMemoryInput == null, "single remote source is not expected to accept different kind of inputs");
                    rowInputsBuilder.add(new PrestoSparkShuffleInput(sourceFragmentId.getId(), shuffleInput));
                }
                else if (broadcastInput != null) {
                    checkArgument(inMemoryInput == null, "single remote source is not expected to accept different kind of inputs");
                    // TODO: Enable NullifyingIterator once migrated to one task per JVM model.
                    // A nullifying iterator drops each element from the list as it is returned,
                    // which would let GC reclaim the broadcast memory gradually.
                    broadcastListsBuilder.add((List<?>) broadcastInput.value());
                }
                else if (inMemoryInput != null) {
                    // in-memory pages can be released one by one while iterating, to save memory
                    pageIteratorsBuilder.add(getNullifyingIterator(inMemoryInput));
                }
                else {
                    throw new IllegalStateException("Input not found for sourceFragmentId: " + sourceFragmentId);
                }
            }

            // only register the kinds of input that are actually present for this remote source
            List<PrestoSparkShuffleInput> rowInputs = rowInputsBuilder.build();
            if (!rowInputs.isEmpty()) {
                shuffleInputs.put(remoteSource.getId(), rowInputs);
            }

            List<java.util.Iterator<PrestoSparkSerializedPage>> pageIterators = pageIteratorsBuilder.build();
            if (!pageIterators.isEmpty()) {
                pageInputs.put(remoteSource.getId(), pageIterators);
            }

            List<List<?>> broadcastLists = broadcastListsBuilder.build();
            if (!broadcastLists.isEmpty()) {
                broadcastInputs.put(remoteSource.getId(), broadcastLists);
            }
        }
    }

    /**
     * Deserializes all zstd-compressed task sources and logs their total serialized size.
     *
     * @param serializedTaskSources iterator over the serialized task sources
     * @return the deserialized task sources, in iteration order
     */
    private List<TaskSource> getTaskSources(Iterator<SerializedPrestoSparkTaskSource> serializedTaskSources)
    {
        ImmutableList.Builder<TaskSource> taskSources = ImmutableList.builder();
        long serializedBytesTotal = 0;
        while (serializedTaskSources.hasNext()) {
            byte[] serializedBytes = serializedTaskSources.next().getBytes();
            serializedBytesTotal += serializedBytes.length;
            taskSources.add(deserializeZstdCompressed(taskSourceCodec, serializedBytes));
        }
        log.info("Total serialized size of all task sources: %s", succinctBytes(serializedBytesTotal));
        return taskSources.build();
    }

    /**
     * Creates the output buffer, output factory and output supplier matching the requested output type.
     * <p>
     * Supported types are {@code PrestoSparkMutableRow} (row output),
     * {@code PrestoSparkSerializedPage} (page output) and {@code PrestoSparkStorageHandle}
     * (pages written through temp storage).
     *
     * @param outputType class of the task output requested by the caller
     * @param blockEncodingManager used by the page based output factories
     * @param memoryManager memory manager backing the created output buffer
     * @param targetAverageRowSize passed to the row output factory
     * @param preDeterminedPartition passed to the row output factory
     * @param tempStorage storage used by the disk page output supplier
     * @param tempDataOperationContext context used by the disk page output supplier
     * @param writeBufferSize write buffer size used by the disk page output supplier
     * @return the configured output components
     * @throws IllegalArgumentException if the output type is not one of the supported types
     */
    @SuppressWarnings("unchecked")
    private static <T extends PrestoSparkTaskOutput> Output<T> configureOutput(
            Class<T> outputType,
            BlockEncodingManager blockEncodingManager,
            OutputBufferMemoryManager memoryManager,
            DataSize targetAverageRowSize,
            Optional<OutputPartitioning> preDeterminedPartition,
            TempStorage tempStorage,
            TempDataOperationContext tempDataOperationContext,
            DataSize writeBufferSize)
    {
        if (outputType.equals(PrestoSparkMutableRow.class)) {
            PrestoSparkOutputBuffer<PrestoSparkRowBatch> rowBuffer = new PrestoSparkOutputBuffer<>(memoryManager);
            OutputFactory rowFactory = new PrestoSparkRowOutputFactory(rowBuffer, targetAverageRowSize, preDeterminedPartition);
            // safe: T is PrestoSparkMutableRow in this branch
            OutputSupplier<T> rowSupplier = (OutputSupplier<T>) new RowOutputSupplier(rowBuffer);
            return new Output<>(OutputBufferType.SPARK_ROW_OUTPUT_BUFFER, rowBuffer, rowFactory, rowSupplier);
        }

        if (outputType.equals(PrestoSparkSerializedPage.class)) {
            PrestoSparkOutputBuffer<PrestoSparkBufferedSerializedPage> pageBuffer = new PrestoSparkOutputBuffer<>(memoryManager);
            OutputFactory pageFactory = new PrestoSparkPageOutputFactory(pageBuffer, blockEncodingManager);
            // safe: T is PrestoSparkSerializedPage in this branch
            OutputSupplier<T> pageSupplier = (OutputSupplier<T>) new PageOutputSupplier(pageBuffer);
            return new Output<>(OutputBufferType.SPARK_PAGE_OUTPUT_BUFFER, pageBuffer, pageFactory, pageSupplier);
        }

        if (outputType.equals(PrestoSparkStorageHandle.class)) {
            PrestoSparkOutputBuffer<PrestoSparkBufferedSerializedPage> diskPageBuffer = new PrestoSparkOutputBuffer<>(memoryManager);
            OutputFactory diskPageFactory = new PrestoSparkPageOutputFactory(diskPageBuffer, blockEncodingManager);
            // safe: T is PrestoSparkStorageHandle in this branch
            OutputSupplier<T> diskPageSupplier = (OutputSupplier<T>) new DiskPageOutputSupplier(diskPageBuffer, tempStorage, tempDataOperationContext, writeBufferSize);
            return new Output<>(OutputBufferType.SPARK_DISK_PAGE_OUTPUT_BUFFER, diskPageBuffer, diskPageFactory, diskPageSupplier);
        }

        throw new IllegalArgumentException("Unexpected output type: " + outputType.getName());
    }

    /**
     * Iterator over the output tuples produced by a running task.
     * <p>
     * Output is pulled from the {@link OutputSupplier}. Once the supplier returns {@code null}
     * the task is expected to be done: task info and shuffle statistics are published to the
     * accumulators, and the first recorded failure (if any) is rethrown.
     */
    private static class PrestoSparkTaskExecutor<T extends PrestoSparkTaskOutput>
            extends AbstractIterator<Tuple2<MutablePartitionId, T>>
            implements IPrestoSparkTaskExecutor<T>
    {
        private final TaskContext taskContext;
        private final TaskStateMachine taskStateMachine;
        private final OutputSupplier<T> outputSupplier;
        private final Codec<TaskInfo> taskInfoCodec;
        private final CollectionAccumulator<SerializedTaskInfo> taskInfoCollector;
        private final CollectionAccumulator<PrestoSparkShuffleStats> shuffleStatsCollector;
        private final PrestoSparkExecutionExceptionFactory executionExceptionFactory;
        private final OutputBufferType outputBufferType;
        private final PrestoSparkOutputBuffer<?> outputBuffer;
        private final TempStorage tempStorage;
        private final TempDataOperationContext tempDataOperationContext;

        private final UUID taskInstanceId = randomUUID();

        // Storage handles already handed to the caller by a disk page output supplier.
        // They are remembered so that the backing files can be removed if the task fails.
        private final List<PrestoSparkStorageHandle> emittedStorageHandles = new ArrayList<>();

        // element computed by hasNext() and not yet handed out by next()
        private Tuple2<MutablePartitionId, T> next;

        // wall-clock time of the first poll; null until the first call to doComputeNext()
        private Long start;
        private long processedRows;
        private long processedRowBatches;
        private long processedBytes;

        private PrestoSparkTaskExecutor(
                TaskContext taskContext,
                TaskStateMachine taskStateMachine,
                OutputSupplier<T> outputSupplier,
                Codec<TaskInfo> taskInfoCodec,
                CollectionAccumulator<SerializedTaskInfo> taskInfoCollector,
                CollectionAccumulator<PrestoSparkShuffleStats> shuffleStatsCollector,
                PrestoSparkExecutionExceptionFactory executionExceptionFactory,
                OutputBufferType outputBufferType,
                PrestoSparkOutputBuffer<?> outputBuffer,
                TempStorage tempStorage,
                TempDataOperationContext tempDataOperationContext)
        {
            this.taskContext = requireNonNull(taskContext, "taskContext is null");
            this.taskStateMachine = requireNonNull(taskStateMachine, "taskStateMachine is null");
            this.outputSupplier = requireNonNull(outputSupplier, "outputSupplier is null");
            this.taskInfoCodec = requireNonNull(taskInfoCodec, "taskInfoCodec is null");
            this.taskInfoCollector = requireNonNull(taskInfoCollector, "taskInfoCollector is null");
            this.shuffleStatsCollector = requireNonNull(shuffleStatsCollector, "shuffleStatsCollector is null");
            this.executionExceptionFactory = requireNonNull(executionExceptionFactory, "executionExceptionFactory is null");
            this.outputBufferType = requireNonNull(outputBufferType, "outputBufferType is null");
            this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null");
            this.tempStorage = requireNonNull(tempStorage, "tempStorage is null");
            this.tempDataOperationContext = requireNonNull(tempDataOperationContext, "tempDataOperationContext is null");
        }

        /**
         * Computes the next element if none is buffered and reports whether one is available.
         */
        @Override
        public boolean hasNext()
        {
            if (next == null) {
                next = computeNext();
            }
            return next != null;
        }

        /**
         * Returns the buffered element, computing it first if necessary.
         *
         * @throws NoSuchElementException if no further output is available
         */
        @Override
        public Tuple2<MutablePartitionId, T> next()
        {
            if (next == null) {
                next = computeNext();
            }
            if (next == null) {
                throw new NoSuchElementException();
            }
            Tuple2<MutablePartitionId, T> result = next;
            next = null;
            return result;
        }

        /**
         * Wraps {@link #doComputeNext()} with exception translation: runtime exceptions are
         * converted through the execution exception factory; an interrupt restores the
         * interrupt flag, aborts the task and is rethrown as a {@link RuntimeException}.
         */
        protected Tuple2<MutablePartitionId, T> computeNext()
        {
            try {
                return doComputeNext();
            }
            catch (RuntimeException e) {
                throw executionExceptionFactory.toPrestoSparkExecutionException(e);
            }
            catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                taskStateMachine.abort();
                throw new RuntimeException(e);
            }
        }

        /**
         * Polls the output supplier once.
         *
         * @return the next output tuple, or {@code null} when the task finished without failures
         * @throws InterruptedException if polling the supplier is interrupted, or if the
         * recorded task failure is an {@link InterruptedException}
         */
        private Tuple2<MutablePartitionId, T> doComputeNext()
                throws InterruptedException
        {
            if (start == null) {
                start = System.currentTimeMillis();
            }
            Tuple2<MutablePartitionId, T> output = outputSupplier.getNext();

            if (output != null) {
                processedRows += output._2.getPositionCount();
                processedRowBatches++;
                processedBytes += output._2.getSize();
                if (outputSupplier instanceof DiskPageOutputSupplier) {
                    // Remember the handle: by the time a failure is detected below, `output`
                    // is null, so the handle has to be captured when it is emitted.
                    emittedStorageHandles.add((PrestoSparkStorageHandle) output._2);
                }
                return output;
            }

            // task finished
            TaskState taskState = taskStateMachine.getState();
            checkState(taskState.isDone(), "task is expected to be done");

            collectTaskStatsOnCompletion();

            LinkedBlockingQueue<Throwable> failures = taskStateMachine.getFailureCauses();
            if (failures.isEmpty()) {
                return null;
            }

            Throwable failure = getFirst(failures, null);
            // Delete the storage file, if task is not successful
            removeEmittedStorageFiles(failure);

            propagateIfPossible(failure, Error.class);
            propagateIfPossible(failure, RuntimeException.class);
            propagateIfPossible(failure, InterruptedException.class);
            throw new RuntimeException(failure);
        }

        /**
         * Removes every storage file previously emitted by this executor. An I/O error during
         * removal is attached to {@code failure} as a suppressed exception instead of replacing it.
         */
        private void removeEmittedStorageFiles(Throwable failure)
        {
            for (PrestoSparkStorageHandle sparkStorageHandle : emittedStorageHandles) {
                TempStorageHandle tempStorageHandle = tempStorage.deserialize(sparkStorageHandle.getSerializedStorageHandle());
                try {
                    tempStorage.remove(tempDataOperationContext, tempStorageHandle);
                    log.info("Removed broadcast spill file: " + tempStorageHandle.toString());
                }
                catch (IOException e) {
                    // self suppression is not allowed
                    if (e != failure) {
                        failure.addSuppressed(e);
                    }
                }
            }
            // do not attempt the removal again if the iterator is polled once more
            emittedStorageHandles.clear();
        }

        /**
         * Publishes the final task info and the shuffle write statistics to the accumulators.
         * The reported processing time excludes the time spent waiting for output.
         */
        private void collectTaskStatsOnCompletion()
        {
            TaskInfo taskInfo = createTaskInfo(taskContext, taskStateMachine, taskInstanceId, outputBufferType, outputBuffer);
            SerializedTaskInfo serializedTaskInfo = new SerializedTaskInfo(serializeZstdCompressed(taskInfoCodec, taskInfo));
            taskInfoCollector.add(serializedTaskInfo);

            PrestoSparkStatsCollectionUtils.collectMetrics(taskInfo);

            long end = System.currentTimeMillis();
            PrestoSparkShuffleStats shuffleStats = new PrestoSparkShuffleStats(
                    taskContext.getTaskId().getStageExecutionId().getStageId().getId(),
                    taskContext.getTaskId().getId(),
                    WRITE,
                    processedRows,
                    processedRowBatches,
                    processedBytes,
                    end - start - outputSupplier.getTimeSpentWaitingForOutputInMillis());
            shuffleStatsCollector.add(shuffleStats);
        }

        /**
         * Builds a {@link TaskInfo} snapshot from the task context. Failure details are
         * included only when the task state is {@code FAILED}; the output buffer is reported
         * as {@code FINISHED} with the row and page totals taken from {@code outputBuffer}.
         */
        private static TaskInfo createTaskInfo(
                TaskContext taskContext,
                TaskStateMachine taskStateMachine,
                UUID taskInstanceId,
                OutputBufferType outputBufferType,
                PrestoSparkOutputBuffer<?> outputBuffer)
        {
            TaskId taskId = taskContext.getTaskId();
            TaskState taskState = taskContext.getState();
            TaskStats taskStats = taskContext.getTaskStats().summarizeFinal();

            List<ExecutionFailureInfo> failures = ImmutableList.of();
            if (taskState == FAILED) {
                failures = toFailures(taskStateMachine.getFailureCauses());
            }

            TaskStatus taskStatus = new TaskStatus(
                    taskInstanceId.getLeastSignificantBits(),
                    taskInstanceId.getMostSignificantBits(),
                    STARTING_VERSION,
                    taskState,
                    // placeholder location built from the task id
                    URI.create("http://fake.invalid/task/" + taskId),
                    taskContext.getCompletedDriverGroups(),
                    failures,
                    taskStats.getQueuedPartitionedDrivers(),
                    taskStats.getRunningPartitionedDrivers(),
                    0,
                    false,
                    taskStats.getPhysicalWrittenDataSizeInBytes(),
                    taskStats.getUserMemoryReservationInBytes(),
                    taskStats.getSystemMemoryReservationInBytes(),
                    taskStats.getPeakNodeTotalMemoryInBytes(),
                    taskStats.getFullGcCount(),
                    taskStats.getFullGcTimeInMillis(),
                    taskStats.getTotalCpuTimeInNanos(),
                    System.currentTimeMillis() - taskStats.getCreateTimeInMillis(),
                    taskStats.getQueuedPartitionedSplitsWeight(),
                    taskStats.getRunningPartitionedSplitsWeight());

            OutputBufferInfo outputBufferInfo = new OutputBufferInfo(
                    outputBufferType.name(),
                    FINISHED,
                    false,
                    false,
                    0,
                    0,
                    outputBuffer.getTotalRowsProcessed(),
                    outputBuffer.getTotalPagesProcessed(),
                    ImmutableList.of());

            return new TaskInfo(
                    taskId,
                    taskStatus,
                    System.currentTimeMillis(),
                    outputBufferInfo,
                    ImmutableSet.of(),
                    taskStats,
                    false,
                    DEFAULT_METADATA_UPDATES,
                    "");
        }
    }

    /**
     * Immutable bundle of the components that make up a task output: the buffer type,
     * the output buffer itself, the factory writing into it and the supplier reading from it.
     */
    private static class Output<T extends PrestoSparkTaskOutput>
    {
        private final OutputBufferType outputBufferType;
        private final PrestoSparkOutputBuffer<?> outputBuffer;
        private final OutputFactory outputFactory;
        private final OutputSupplier<T> outputSupplier;

        /**
         * @throws NullPointerException if any argument is null
         */
        private Output(
                OutputBufferType outputBufferType,
                PrestoSparkOutputBuffer<?> outputBuffer,
                OutputFactory outputFactory,
                OutputSupplier<T> outputSupplier)
        {
            this.outputBufferType = requireNonNull(outputBufferType, "outputBufferType is null");
            this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null");
            this.outputFactory = requireNonNull(outputFactory, "outputFactory is null");
            this.outputSupplier = requireNonNull(outputSupplier, "outputSupplier is null");
        }

        /** Returns the supplier used to read the produced output. */
        public OutputSupplier<T> getOutputSupplier()
        {
            return this.outputSupplier;
        }

        /** Returns the factory used to write output into the buffer. */
        public OutputFactory getOutputFactory()
        {
            return this.outputFactory;
        }

        /** Returns the output buffer shared by the factory and the supplier. */
        public PrestoSparkOutputBuffer<?> getOutputBuffer()
        {
            return this.outputBuffer;
        }

        /** Returns the type of the output buffer. */
        public OutputBufferType getOutputBufferType()
        {
            return this.outputBufferType;
        }
    }

    /**
     * Source of output tuples consumed by {@code PrestoSparkTaskExecutor}.
     */
    private interface OutputSupplier<T extends PrestoSparkTaskOutput>
    {
        /**
         * Returns the next output tuple. A {@code null} result is treated by the
         * task executor as the end of the output.
         *
         * @throws InterruptedException if the calling thread is interrupted while waiting for output
         */
        Tuple2<MutablePartitionId, T> getNext()
                throws InterruptedException;

        /**
         * Returns the accumulated time, in milliseconds, this supplier spent waiting for output.
         * The task executor subtracts it from the elapsed time when reporting shuffle statistics.
         */
        long getTimeSpentWaitingForOutputInMillis();
    }

    /**
     * Supplies individual rows by draining row batches from the output buffer one batch at a time.
     */
    private static class RowOutputSupplier
            implements OutputSupplier<PrestoSparkMutableRow>
    {
        private final PrestoSparkOutputBuffer<PrestoSparkRowBatch> outputBuffer;
        // supplier for the batch currently being drained; null when a new batch must be fetched
        private RowTupleSupplier currentRowTupleSupplier;

        private long timeSpentWaitingForOutputInMillis;

        private RowOutputSupplier(PrestoSparkOutputBuffer<PrestoSparkRowBatch> outputBuffer)
        {
            this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null");
        }

        /**
         * Returns the next row, fetching further batches from the buffer as needed.
         *
         * @return the next row tuple, or {@code null} once the buffer returns no more batches
         */
        @Override
        public Tuple2<MutablePartitionId, PrestoSparkMutableRow> getNext()
                throws InterruptedException
        {
            while (true) {
                if (currentRowTupleSupplier == null) {
                    // only the time spent inside outputBuffer.get() counts as waiting
                    long waitStart = System.currentTimeMillis();
                    PrestoSparkRowBatch batch = outputBuffer.get();
                    long waitEnd = System.currentTimeMillis();
                    timeSpentWaitingForOutputInMillis += (waitEnd - waitStart);
                    if (batch == null) {
                        return null;
                    }
                    currentRowTupleSupplier = batch.createRowTupleSupplier();
                }

                Tuple2<MutablePartitionId, PrestoSparkMutableRow> row = currentRowTupleSupplier.getNext();
                if (row != null) {
                    return row;
                }
                // current batch is exhausted; move on to the next one
                currentRowTupleSupplier = null;
            }
        }

        @Override
        public long getTimeSpentWaitingForOutputInMillis()
        {
            return timeSpentWaitingForOutputInMillis;
        }
    }

    /**
     * Supplies serialized pages taken from the output buffer, all tagged with the same default partition id.
     */
    private static class PageOutputSupplier
            implements OutputSupplier<PrestoSparkSerializedPage>
    {
        private static final MutablePartitionId DEFAULT_PARTITION = new MutablePartitionId();

        private final PrestoSparkOutputBuffer<PrestoSparkBufferedSerializedPage> outputBuffer;

        private long timeSpentWaitingForOutputInMillis;

        private PageOutputSupplier(PrestoSparkOutputBuffer<PrestoSparkBufferedSerializedPage> outputBuffer)
        {
            this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null");
        }

        /**
         * Takes the next page from the buffer and converts it to a Spark serialized page.
         *
         * @return the next page tuple, or {@code null} when the buffer returns no page
         */
        @Override
        public Tuple2<MutablePartitionId, PrestoSparkSerializedPage> getNext()
                throws InterruptedException
        {
            long waitStart = System.currentTimeMillis();
            PrestoSparkBufferedSerializedPage bufferedPage = outputBuffer.get();
            long waitEnd = System.currentTimeMillis();
            timeSpentWaitingForOutputInMillis += (waitEnd - waitStart);

            if (bufferedPage != null) {
                return new Tuple2<>(DEFAULT_PARTITION, toPrestoSparkSerializedPage(bufferedPage.getSerializedPage()));
            }
            return null;
        }

        @Override
        public long getTimeSpentWaitingForOutputInMillis()
        {
            return timeSpentWaitingForOutputInMillis;
        }
    }

    private static class DiskPageOutputSupplier
            implements OutputSupplier<PrestoSparkStorageHandle>
    {
        private static final MutablePartitionId DEFAULT_PARTITION = new MutablePartitionId();

        private final PrestoSparkOutputBuffer<PrestoSparkBufferedSerializedPage> outputBuffer;
        private final TempStorage tempStorage;
        private final TempDataOperationContext tempDataOperationContext;
        private final long writeBufferSizeInBytes;

        private TempDataSink tempDataSink;
        private long timeSpentWaitingForOutputInMillis;

        private DiskPageOutputSupplier(PrestoSparkOutputBuffer<PrestoSparkBufferedSerializedPage> outputBuffer,
                TempStorage tempStorage,
                TempDataOperationContext tempDataOperationContext,
                DataSize writeBufferSize)
        {
            this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null");
            this.tempStorage = requireNonNull(tempStorage, "tempStorage is null");
            this.tempDataOperationContext = requireNonNull(tempDataOperationContext, "tempDataOperationContext is null");
            this.writeBufferSizeInBytes = requireNonNull(writeBufferSize, "writeBufferSize is null").toBytes();
        }

        /**
         * Drains every page the output buffer hands out into a single temp storage file and
         * returns a handle describing that file.
         *
         * <p>Pages are accumulated in memory and flushed to the sink whenever adding the next page
         * would exceed {@code writeBufferSizeInBytes}; a page larger than the write buffer is still
         * written, on its own. While draining, the method totals the compressed, uncompressed and
         * deserialized retained sizes and the position count, and computes a CRC32 over the
         * serialized page bytes; all of these are stored in the returned handle.
         *
         * <p>On any failure the sink is rolled back before the failure is propagated, and the sink
         * is always closed. Secondary failures from rollback or close are attached as suppressed
         * exceptions so the original cause is not lost.
         *
         * @return a tuple of {@code DEFAULT_PARTITION} and the storage handle, or {@code null} when
         *         the output buffer has no page to offer
         * @throws InterruptedException if interrupted while obtaining a page from the output buffer
         * @throws UncheckedIOException if writing, committing or closing the temp storage file fails
         */
        @Override
        public Tuple2<MutablePartitionId, PrestoSparkStorageHandle> getNext()
                throws InterruptedException
        {
            long start = System.currentTimeMillis();
            PrestoSparkBufferedSerializedPage page = outputBuffer.get();
            if (page == null) {
                return null;
            }

            long compressedBroadcastSizeInBytes = 0;
            long uncompressedBroadcastSizeInBytes = 0;
            long deserializedBroadcastRetainedSizeInBytes = 0;
            int positionCount = 0;
            CRC32 checksum = new CRC32();
            IOException ioException = null;
            // Set when the write is aborted by something other than an IOException, so that the
            // finally block does not replace that failure with a close() failure.
            Exception nonIoFailure = null;
            try {
                this.tempDataSink = tempStorage.create(tempDataOperationContext);
                List<DataOutput> bufferedPages = new ArrayList<>();
                long bufferedBytes = 0;

                while (page != null) {
                    PageDataOutput pageDataOutput = new PageDataOutput(page.getSerializedPage());
                    long writtenSize = pageDataOutput.size();

                    // Flush first if this page does not fit in what is left of the write buffer.
                    // The isEmpty() check lets an oversized page through as a batch of one.
                    if ((writeBufferSizeInBytes - bufferedBytes) < writtenSize && !bufferedPages.isEmpty()) {
                        tempDataSink.write(bufferedPages);
                        bufferedPages.clear();
                        bufferedBytes = 0;
                    }

                    bufferedPages.add(pageDataOutput);
                    bufferedBytes += writtenSize;
                    compressedBroadcastSizeInBytes += page.getSerializedPage().getSizeInBytes();
                    uncompressedBroadcastSizeInBytes += page.getSerializedPage().getUncompressedSizeInBytes();
                    deserializedBroadcastRetainedSizeInBytes += page.getDeserializedRetainedSizeInBytes();
                    positionCount += page.getPositionCount();
                    // NOTE(review): byteArray() is read directly, which presumably requires the
                    // serialized page slice to be backed by a heap array - confirm.
                    Slice slice = page.getSerializedPage().getSlice();
                    checksum.update(slice.byteArray(), slice.byteArrayOffset(), slice.length());
                    page = outputBuffer.get();
                }

                // Write whatever is left over from the last partial batch.
                if (!bufferedPages.isEmpty()) {
                    tempDataSink.write(bufferedPages);
                    bufferedPages.clear();
                }

                TempStorageHandle tempStorageHandle = tempDataSink.commit();
                log.info("Created broadcast spill file: " + tempStorageHandle.toString() + " deserialized size: " + deserializedBroadcastRetainedSizeInBytes);
                PrestoSparkStorageHandle prestoSparkStorageHandle =
                        new PrestoSparkStorageHandle(
                                tempStorage.serializeHandle(tempStorageHandle),
                                uncompressedBroadcastSizeInBytes,
                                compressedBroadcastSizeInBytes,
                                deserializedBroadcastRetainedSizeInBytes,
                                checksum.getValue(),
                                positionCount);
                // The recorded interval spans the whole drain-and-write, not only the time spent
                // blocked on the output buffer.
                long end = System.currentTimeMillis();
                timeSpentWaitingForOutputInMillis += (end - start);
                return new Tuple2<>(DEFAULT_PARTITION, prestoSparkStorageHandle);
            }
            catch (IOException e) {
                // Remember the failure; it is rethrown as UncheckedIOException after the sink is closed.
                ioException = e;
                rollbackTempDataSink(e);
            }
            catch (InterruptedException | RuntimeException e) {
                // An interrupt while waiting for the next page (or an unchecked failure) would
                // otherwise leave the partially written sink un-rolled-back.
                nonIoFailure = e;
                rollbackTempDataSink(e);
                throw e;
            }
            finally {
                try {
                    if (tempDataSink != null) {
                        tempDataSink.close();
                    }
                }
                catch (IOException e) {
                    if (nonIoFailure != null) {
                        // Let the pending interrupt / unchecked failure keep propagating; throwing
                        // from here would discard it.
                        nonIoFailure.addSuppressed(e);
                    }
                    else {
                        if (ioException == null) {
                            ioException = e;
                        }
                        else if (ioException != e) {
                            ioException.addSuppressed(e);
                        }
                        throw new UncheckedIOException("Unable to dump data to disk: ", ioException);
                    }
                }
            }

            // Only reachable after the IOException catch above completed normally.
            throw new UncheckedIOException("Unable to dump data to disk: ", ioException);
        }

        /**
         * Rolls back the current temp data sink, if any, recording an {@link IOException} raised by
         * the rollback as a suppressed exception of {@code failure} instead of propagating it.
         */
        private void rollbackTempDataSink(Throwable failure)
        {
            try {
                if (tempDataSink != null) {
                    tempDataSink.rollback();
                }
            }
            catch (IOException rollbackFailure) {
                // addSuppressed rejects self-suppression, so skip if the sink rethrew the same instance.
                if (failure != rollbackFailure) {
                    failure.addSuppressed(rollbackFailure);
                }
            }
        }

        /**
         * Returns the running total held in {@code timeSpentWaitingForOutputInMillis}.
         *
         * <p>{@code getNext()} adds to this total the wall-clock milliseconds elapsed between the
         * start of a call and the point where it is about to return a storage handle.
         *
         * @return the accumulated time in milliseconds
         */
        @Override
        public long getTimeSpentWaitingForOutputInMillis()
        {
            return this.timeSpentWaitingForOutputInMillis;
        }
    }

    /**
     * Enumerates the kinds of output buffer distinguished inside the enclosing class.
     *
     * <p>NOTE(review): the per-constant notes below are inferred from the constant names alone;
     * where each value is selected and consumed is not visible here - confirm against the usages.
     */
    private enum OutputBufferType
    {
        // Presumably a buffer that emits individual rows - TODO confirm.
        SPARK_ROW_OUTPUT_BUFFER,
        // Presumably a buffer that emits serialized pages held in memory - TODO confirm.
        SPARK_PAGE_OUTPUT_BUFFER,
        // Presumably a buffer whose pages are written out to temp storage - TODO confirm.
        SPARK_DISK_PAGE_OUTPUT_BUFFER,
    }
}