PrestoSparkNativeTaskExecutorFactory.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.spark.execution.task;
import com.facebook.airlift.json.Codec;
import com.facebook.airlift.json.JsonCodec;
import com.facebook.airlift.log.Logger;
import com.facebook.presto.Session;
import com.facebook.presto.common.RuntimeUnit;
import com.facebook.presto.common.block.BlockEncodingManager;
import com.facebook.presto.common.type.VarcharType;
import com.facebook.presto.execution.ExecutionFailureInfo;
import com.facebook.presto.execution.Lifespan;
import com.facebook.presto.execution.Location;
import com.facebook.presto.execution.ScheduledSplit;
import com.facebook.presto.execution.StageExecutionId;
import com.facebook.presto.execution.StageId;
import com.facebook.presto.execution.TaskId;
import com.facebook.presto.execution.TaskInfo;
import com.facebook.presto.execution.TaskSource;
import com.facebook.presto.execution.TaskState;
import com.facebook.presto.metadata.RemoteTransactionHandle;
import com.facebook.presto.metadata.SessionPropertyManager;
import com.facebook.presto.metadata.Split;
import com.facebook.presto.spark.PrestoSparkTaskDescriptor;
import com.facebook.presto.spark.accesscontrol.PrestoSparkAuthenticatorProvider;
import com.facebook.presto.spark.classloader_interface.IPrestoSparkTaskExecutor;
import com.facebook.presto.spark.classloader_interface.IPrestoSparkTaskExecutorFactory;
import com.facebook.presto.spark.classloader_interface.MutablePartitionId;
import com.facebook.presto.spark.classloader_interface.PrestoSparkNativeTaskInputs;
import com.facebook.presto.spark.classloader_interface.PrestoSparkSerializedPage;
import com.facebook.presto.spark.classloader_interface.PrestoSparkShuffleReadDescriptor;
import com.facebook.presto.spark.classloader_interface.PrestoSparkShuffleStats;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskInputs;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskOutput;
import com.facebook.presto.spark.classloader_interface.SerializedPrestoSparkTaskDescriptor;
import com.facebook.presto.spark.classloader_interface.SerializedPrestoSparkTaskSource;
import com.facebook.presto.spark.classloader_interface.SerializedTaskInfo;
import com.facebook.presto.spark.execution.BroadcastFileInfo;
import com.facebook.presto.spark.execution.PrestoSparkBroadcastTableCacheManager;
import com.facebook.presto.spark.execution.PrestoSparkExecutionExceptionFactory;
import com.facebook.presto.spark.execution.nativeprocess.NativeExecutionProcess;
import com.facebook.presto.spark.execution.nativeprocess.NativeExecutionProcessFactory;
import com.facebook.presto.spark.execution.shuffle.PrestoSparkShuffleInfoTranslator;
import com.facebook.presto.spark.execution.shuffle.PrestoSparkShuffleWriteInfo;
import com.facebook.presto.spark.util.PrestoSparkStatsCollectionUtils;
import com.facebook.presto.spark.util.PrestoSparkUtils;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.PrestoTransportException;
import com.facebook.presto.spi.page.PagesSerde;
import com.facebook.presto.spi.page.SerializedPage;
import com.facebook.presto.spi.plan.PlanFragmentId;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.plan.TableWriterNode;
import com.facebook.presto.spi.security.TokenAuthenticator;
import com.facebook.presto.split.RemoteSplit;
import com.facebook.presto.sql.planner.PlanFragment;
import com.facebook.presto.sql.planner.plan.RemoteSourceNode;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import com.sun.management.OperatingSystemMXBean;
import io.airlift.units.Duration;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.util.CollectionAccumulator;
import scala.Tuple2;
import scala.collection.AbstractIterator;
import scala.collection.Iterator;
import javax.inject.Inject;
import java.io.IOException;
import java.lang.management.ManagementFactory;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static com.facebook.presto.operator.ExchangeOperator.REMOTE_CONNECTOR_ID;
import static com.facebook.presto.spark.PrestoSparkSessionProperties.getNativeExecutionBroadcastBasePath;
import static com.facebook.presto.spark.PrestoSparkSessionProperties.getNativeTerminateWithCoreTimeout;
import static com.facebook.presto.spark.PrestoSparkSessionProperties.isNativeTerminateWithCoreWhenUnresponsiveEnabled;
import static com.facebook.presto.spark.util.PrestoSparkUtils.deserializeZstdCompressed;
import static com.facebook.presto.spark.util.PrestoSparkUtils.serializeZstdCompressed;
import static com.facebook.presto.spark.util.PrestoSparkUtils.toPrestoSparkSerializedPage;
import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static com.facebook.presto.spi.StandardErrorCode.TOO_MANY_REQUESTS_FAILED;
import static com.facebook.presto.sql.planner.SchedulingOrderVisitor.scheduleOrder;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION;
import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.airlift.units.DataSize.succinctBytes;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
/**
* PrestoSparkNativeTaskExecutorFactory is responsible for launching the external native process and managing the communication
* between the Java process and the native process (by using the {@link NativeExecutionTask}).
* It sends the necessary metadata (e.g., plan fragment, session properties) as part of the
* BatchTaskUpdateRequest, polls the remote CPP task for status and results (pages/data if applicable),
* and sends these back to Spark's RDD API.
* <p>
* PrestoSparkNativeTaskExecutorFactory is a singleton, instantiated once per executor.
* <p>
* For every task it receives, it does the following:
* 1. Creates the native execution process (via the {@link NativeExecutionProcessFactory}), ensuring that it is created only once.
* 2. Serializes and passes the planFragment, source metadata (taskSources) and sink metadata (tableWriteInfo or shuffleWriteInfo),
* and submits a nativeExecutionTask.
* 3. Returns an iterator to the Spark RDD layer. RDD execution calls its next() method, which will:
* 3.a Call {@link NativeExecutionTask}'s pollResult() to retrieve a {@link SerializedPage} from the external process.
* 3.b If no more output is available, check whether the task finished successfully or with an exception.
* If the task finished with an exception - fail the Spark task (throw the exception).
* If the task finished successfully - collect statistics through the taskInfo object and add them to the accumulator.
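* <p>
* A rough usage sketch from the RDD layer (the caller shown here is illustrative, not the
* actual Spark integration):
* <pre>{@code
* IPrestoSparkTaskExecutor<PrestoSparkSerializedPage> executor = factory.create(
*         partitionId, attemptNumber, taskDescriptor, taskSources, inputs,
*         taskInfoCollector, shuffleStatsCollector, PrestoSparkSerializedPage.class);
* while (executor.hasNext()) {
*     Tuple2<MutablePartitionId, PrestoSparkSerializedPage> page = executor.next();
*     // hand the page to Spark
* }
* }</pre>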
*/
public class PrestoSparkNativeTaskExecutorFactory
implements IPrestoSparkTaskExecutorFactory
{
private static final Logger log = Logger.get(PrestoSparkNativeTaskExecutorFactory.class);
// For Presto-on-Spark, we do not have remote source tasks, as the shuffle data
// lives in persistent shuffle storage.
// The current protocol for Split mandates having a remoteSourceTaskId as part
// of the split info. So for shuffle-read splits we set it to a dummy
// value that is ignored by the shuffle reader.
private static final TaskId DUMMY_TASK_ID = TaskId.valueOf("remotesourcetaskid.0.0.0.0");
private final SessionPropertyManager sessionPropertyManager;
private final JsonCodec<PrestoSparkTaskDescriptor> taskDescriptorJsonCodec;
private final JsonCodec<BroadcastFileInfo> broadcastFileInfoJsonCodec;
private final Codec<TaskSource> taskSourceCodec;
private final Codec<TaskInfo> taskInfoCodec;
private final PrestoSparkExecutionExceptionFactory executionExceptionFactory;
private final Set<PrestoSparkAuthenticatorProvider> authenticatorProviders;
private final NativeExecutionProcessFactory nativeExecutionProcessFactory;
private final NativeExecutionTaskFactory nativeExecutionTaskFactory;
private final PrestoSparkShuffleInfoTranslator shuffleInfoTranslator;
private final PagesSerde pagesSerde;
private NativeExecutionProcess nativeExecutionProcess;
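/**
* Measures the JVM process CPU time consumed while a task runs, using the
* com.sun.management extension of java.lang.management.OperatingSystemMXBean.
* If the extended bean is unavailable, {@link #get()} returns OptionalLong.empty().
*/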
private static class CpuTracker
{
private OperatingSystemMXBean operatingSystemMXBean;
private OptionalLong startCpuTime;
public CpuTracker()
{
if (ManagementFactory.getOperatingSystemMXBean() instanceof OperatingSystemMXBean) {
// we want the com.sun.management sub-interface of java.lang.management.OperatingSystemMXBean
operatingSystemMXBean = (OperatingSystemMXBean) ManagementFactory.getOperatingSystemMXBean();
startCpuTime = OptionalLong.of(operatingSystemMXBean.getProcessCpuTime());
}
else {
startCpuTime = OptionalLong.empty();
}
}
OptionalLong get()
{
if (operatingSystemMXBean != null) {
long endCpuTime = operatingSystemMXBean.getProcessCpuTime();
return OptionalLong.of(endCpuTime - startCpuTime.getAsLong());
}
else {
return OptionalLong.empty();
}
}
}
@Inject
public PrestoSparkNativeTaskExecutorFactory(
SessionPropertyManager sessionPropertyManager,
BlockEncodingManager blockEncodingManager,
JsonCodec<PrestoSparkTaskDescriptor> taskDescriptorJsonCodec,
JsonCodec<BroadcastFileInfo> broadcastFileInfoJsonCodec,
Codec<TaskSource> taskSourceCodec,
Codec<TaskInfo> taskInfoCodec,
PrestoSparkExecutionExceptionFactory executionExceptionFactory,
Set<PrestoSparkAuthenticatorProvider> authenticatorProviders,
PrestoSparkBroadcastTableCacheManager prestoSparkBroadcastTableCacheManager,
NativeExecutionProcessFactory nativeExecutionProcessFactory,
NativeExecutionTaskFactory nativeExecutionTaskFactory,
PrestoSparkShuffleInfoTranslator shuffleInfoTranslator)
{
this.sessionPropertyManager = requireNonNull(sessionPropertyManager, "sessionPropertyManager is null");
this.taskDescriptorJsonCodec = requireNonNull(taskDescriptorJsonCodec, "taskDescriptorJsonCodec is null");
this.taskSourceCodec = requireNonNull(taskSourceCodec, "taskSourceCodec is null");
this.taskInfoCodec = requireNonNull(taskInfoCodec, "taskInfoCodec is null");
this.broadcastFileInfoJsonCodec = requireNonNull(broadcastFileInfoJsonCodec, "broadcastFileInfoJsonCodec is null");
this.executionExceptionFactory = requireNonNull(executionExceptionFactory, "executionExceptionFactory is null");
this.authenticatorProviders = ImmutableSet.copyOf(requireNonNull(authenticatorProviders, "authenticatorProviders is null"));
this.nativeExecutionProcessFactory = requireNonNull(nativeExecutionProcessFactory, "nativeExecutionProcessFactory is null");
this.nativeExecutionTaskFactory = requireNonNull(nativeExecutionTaskFactory, "nativeExecutionTaskFactory is null");
this.shuffleInfoTranslator = requireNonNull(shuffleInfoTranslator, "shuffleInfoTranslator is null");
this.pagesSerde = PrestoSparkUtils.createPagesSerde(requireNonNull(blockEncodingManager, "blockEncodingManager is null"));
}
@Override
public <T extends PrestoSparkTaskOutput> IPrestoSparkTaskExecutor<T> create(
int partitionId,
int attemptNumber,
SerializedPrestoSparkTaskDescriptor serializedTaskDescriptor,
Iterator<SerializedPrestoSparkTaskSource> serializedTaskSources,
PrestoSparkTaskInputs inputs,
CollectionAccumulator<SerializedTaskInfo> taskInfoCollector,
CollectionAccumulator<PrestoSparkShuffleStats> shuffleStatsCollector,
Class<T> outputType)
{
try {
return doCreate(
partitionId,
attemptNumber,
serializedTaskDescriptor,
serializedTaskSources,
inputs,
taskInfoCollector,
shuffleStatsCollector,
outputType);
}
catch (RuntimeException e) {
throw executionExceptionFactory.toPrestoSparkExecutionException(e);
}
}
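/**
* Creates the task executor: starts the native process if necessary, builds the task
* sources and shuffle/broadcast write metadata, submits the task to the native process,
* and returns an iterator over its output. Unlike {@link #create}, failures are
* propagated as raw RuntimeExceptions rather than PrestoSparkExecutionExceptions.
*/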
public <T extends PrestoSparkTaskOutput> IPrestoSparkTaskExecutor<T> doCreate(
int partitionId,
int attemptNumber,
SerializedPrestoSparkTaskDescriptor serializedTaskDescriptor,
Iterator<SerializedPrestoSparkTaskSource> serializedTaskSources,
PrestoSparkTaskInputs inputs,
CollectionAccumulator<SerializedTaskInfo> taskInfoCollector,
CollectionAccumulator<PrestoSparkShuffleStats> shuffleStatsCollector,
Class<T> outputType)
{
CpuTracker cpuTracker = new CpuTracker();
PrestoSparkTaskDescriptor taskDescriptor = taskDescriptorJsonCodec.fromJson(serializedTaskDescriptor.getBytes());
ImmutableMap.Builder<String, TokenAuthenticator> extraAuthenticators = ImmutableMap.builder();
authenticatorProviders.forEach(provider -> extraAuthenticators.putAll(provider.getTokenAuthenticators()));
Session session = taskDescriptor.getSession().toSession(
sessionPropertyManager,
taskDescriptor.getExtraCredentials(),
extraAuthenticators.build());
PlanFragment fragment = taskDescriptor.getFragment();
StageId stageId = new StageId(session.getQueryId(), fragment.getId().getId());
TaskId taskId = new TaskId(new StageExecutionId(stageId, 0), partitionId, attemptNumber);
// TODO: Remove this once we can display the plan on Spark UI.
// Currently, `textPlanFragment` throws an exception if json-based UDFs are used in the query, which can only
// happen in native execution mode. To resolve this error, `JsonFileBasedFunctionNamespaceManager` must be
// loaded on the executors as well (which is actually not required for native execution). To do so, we need a
// mechanism to ship the JSON file containing the UDF metadata to workers, which does not exist as of today.
// TODO: Address this issue; more details in https://github.com/prestodb/presto/issues/19600
log.info("Logging the plan fragment is not supported for Presto-on-Spark native execution yet");
if (fragment.getPartitioning().isCoordinatorOnly()) {
throw new UnsupportedOperationException("Coordinator only fragment execution is not supported by native task executor");
}
checkArgument(
inputs instanceof PrestoSparkNativeTaskInputs,
format("PrestoSparkNativeTaskInputs is required for native execution, but %s is provided", inputs.getClass().getName()));
// 1. Start the native process if it hasn't been started yet or has died
createAndStartNativeExecutionProcess(session);
// 2. compute the task info to send to cpp process
PrestoSparkNativeTaskInputs nativeInputs = (PrestoSparkNativeTaskInputs) inputs;
// 2.a Populate Read info
List<TaskSource> taskSources = getTaskSources(serializedTaskSources, fragment, session, nativeInputs);
boolean isFixedBroadcastDistribution = fragment.getPartitioningScheme().getPartitioning().getHandle().equals(FIXED_BROADCAST_DISTRIBUTION);
// 2.b Populate Shuffle Write info
Optional<PrestoSparkShuffleWriteInfo> shuffleWriteInfo = nativeInputs.getShuffleWriteDescriptor()
.map(descriptor -> shuffleInfoTranslator.createShuffleWriteInfo(session, descriptor));
Optional<String> serializedShuffleWriteInfo = shuffleWriteInfo.map(shuffleInfoTranslator::createSerializedWriteInfo);
// 2.c populate broadcast path
Optional<String> broadcastDirectory =
isFixedBroadcastDistribution ? Optional.of(getBroadcastDirectoryPath(session)) : Optional.empty();
boolean terminateWithCoreWhenUnresponsive = isNativeTerminateWithCoreWhenUnresponsiveEnabled(session);
Duration terminateWithCoreTimeout = getNativeTerminateWithCoreTimeout(session);
try {
// 3. Submit the task to cpp process for execution
log.info("Submitting native execution task");
NativeExecutionTask task = nativeExecutionTaskFactory.createNativeExecutionTask(
session,
nativeExecutionProcess.getLocation(),
taskId,
fragment,
ImmutableList.copyOf(taskSources),
taskDescriptor.getTableWriteInfo(),
serializedShuffleWriteInfo,
broadcastDirectory);
log.info("Starting the native execution task and waiting for remote task completion");
TaskInfo taskInfo = task.start();
// task creation might have failed
processTaskInfoForErrorsOrCompletion(taskInfo);
// 4. return output to spark RDD layer
return new PrestoSparkNativeTaskOutputIterator<>(
partitionId,
task,
outputType,
taskInfoCollector,
taskInfoCodec,
executionExceptionFactory,
cpuTracker,
nativeExecutionProcess,
terminateWithCoreWhenUnresponsive,
terminateWithCoreTimeout);
}
catch (RuntimeException e) {
throw processFailure(e, nativeExecutionProcess, terminateWithCoreWhenUnresponsive, terminateWithCoreTimeout);
}
}
private String getBroadcastDirectoryPath(Session session)
{
return format("%s/%s", getNativeExecutionBroadcastBasePath(session), session.getQueryId().getId());
}
@Override
public void close()
{
if (nativeExecutionProcess != null) {
nativeExecutionProcess.close();
}
}
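/**
* Stops the task and, if a TaskInfo is available, records the JVM process CPU time in the
* task's runtime stats, adds the zstd-compressed TaskInfo to the Spark accumulator, and
* updates the Spark-internal metrics.
*/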
private static void completeTask(boolean success, CollectionAccumulator<SerializedTaskInfo> taskInfoCollector, NativeExecutionTask task, Codec<TaskInfo> taskInfoCodec, CpuTracker cpuTracker)
{
// stop the task
task.stop(success);
OptionalLong processCpuTime = cpuTracker.get();
// collect statistics (if available)
Optional<TaskInfo> taskInfoOptional = tryGetTaskInfo(task);
if (!taskInfoOptional.isPresent()) {
log.error("Missing taskInfo. Statistics might be inaccurate");
return;
}
// Record process-wide CPU time spent while executing this task. Since we run one task at a time,
// process-wide CPU time matches task's CPU time.
processCpuTime.ifPresent(cpuTime -> taskInfoOptional.get().getStats().getRuntimeStats()
.addMetricValue("javaProcessCpuTime", RuntimeUnit.NANO, cpuTime));
SerializedTaskInfo serializedTaskInfo = new SerializedTaskInfo(serializeZstdCompressed(taskInfoCodec, taskInfoOptional.get()));
taskInfoCollector.add(serializedTaskInfo);
// Update Spark Accumulators for spark internal metrics
PrestoSparkStatsCollectionUtils.collectMetrics(taskInfoOptional.get());
}
private static Optional<TaskInfo> tryGetTaskInfo(NativeExecutionTask task)
{
try {
return task.getTaskInfo();
}
catch (RuntimeException e) {
log.debug(e, "TaskInfo is not available");
return Optional.empty();
}
}
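/**
* No-op if the task is still running. Otherwise, rethrows the task's first failure (or a
* generic internal error when no failure info is available) if it did not finish
* successfully, and logs completion if it did.
*/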
private static void processTaskInfoForErrorsOrCompletion(TaskInfo taskInfo)
{
if (!taskInfo.getTaskStatus().getState().isDone()) {
log.info("processTaskInfoForErrorsOrCompletion: task is not done yet: %s", taskInfo);
return;
}
if (!taskInfo.getTaskStatus().getState().equals(TaskState.FINISHED)) {
// task failed with errors
RuntimeException failure = taskInfo.getTaskStatus().getFailures().stream()
.findFirst()
.map(ExecutionFailureInfo::toException)
.orElseGet(() -> new PrestoException(GENERIC_INTERNAL_ERROR, "Native task failed for an unknown reason"));
throw failure;
}
log.info("processTaskInfoForErrorsOrCompletion: task completed successfully: %s", taskInfo);
}
private void createAndStartNativeExecutionProcess(Session session)
{
requireNonNull(nativeExecutionProcessFactory, "Trying to instantiate native process but factory is null");
try {
// Create the CPP sidecar process if it doesn't already exist.
// It is created when the first task is scheduled on this executor.
nativeExecutionProcess = nativeExecutionProcessFactory.getNativeExecutionProcess(session);
nativeExecutionProcess.start();
}
catch (ExecutionException | InterruptedException | IOException e) {
throw new RuntimeException(e);
}
}
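/**
* Builds the complete set of task sources for the native task: deserializes the table scan
* sources, adds empty no-more-splits sources for table scans that received no splits, and
* synthesizes remote-split sources for shuffle-read and broadcast inputs, pointing them at
* {@link #DUMMY_TASK_ID} since there is no real remote source task.
*/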
private List<TaskSource> getTaskSources(
Iterator<SerializedPrestoSparkTaskSource> serializedTaskSources,
PlanFragment fragment,
Session session,
PrestoSparkNativeTaskInputs nativeTaskInputs)
{
List<TaskSource> taskSources = new ArrayList<>();
// Populate TableScan sources
long totalSerializedSizeInBytes = 0;
while (serializedTaskSources.hasNext()) {
SerializedPrestoSparkTaskSource serializedTaskSource = serializedTaskSources.next();
taskSources.add(deserializeZstdCompressed(taskSourceCodec, serializedTaskSource.getBytes()));
totalSerializedSizeInBytes += serializedTaskSource.getBytes().length;
}
// When joining a bucketed table with a non-bucketed table with a filter on "$bucket",
// some tasks may not have splits for the bucketed table. In this case we still need
// to send the no-more-splits message to Velox.
Set<PlanNodeId> planNodeIdsWithSources = taskSources.stream().map(TaskSource::getPlanNodeId).collect(Collectors.toSet());
Set<PlanNodeId> tableScanIds = Sets.newHashSet(scheduleOrder(fragment.getRoot()));
tableScanIds.stream()
.filter(id -> !planNodeIdsWithSources.contains(id))
.forEach(id -> taskSources.add(new TaskSource(id, ImmutableSet.of(), true)));
log.info("Total serialized size of all table scan task sources: %s", succinctBytes(totalSerializedSizeInBytes));
// Populate remote sources - ShuffleRead & Broadcast.
ImmutableList.Builder<TaskSource> shuffleTaskSources = ImmutableList.builder();
ImmutableList.Builder<TaskSource> broadcastTaskSources = ImmutableList.builder();
AtomicLong nextSplitId = new AtomicLong();
taskSources.stream()
.flatMap(source -> source.getSplits().stream())
.mapToLong(ScheduledSplit::getSequenceId)
.max()
.ifPresent(id -> nextSplitId.set(id + 1));
for (RemoteSourceNode remoteSource : fragment.getRemoteSourceNodes()) {
for (PlanFragmentId sourceFragmentId : remoteSource.getSourceFragmentIds()) {
PrestoSparkShuffleReadDescriptor shuffleReadDescriptor =
nativeTaskInputs.getShuffleReadDescriptors().get(sourceFragmentId.toString());
if (shuffleReadDescriptor != null) {
ScheduledSplit split = new ScheduledSplit(nextSplitId.getAndIncrement(), remoteSource.getId(), new Split(REMOTE_CONNECTOR_ID, new RemoteTransactionHandle(), new RemoteSplit(
new Location(format("batch://%s?shuffleInfo=%s", DUMMY_TASK_ID,
shuffleInfoTranslator.createSerializedReadInfo(
shuffleInfoTranslator.createShuffleReadInfo(session, shuffleReadDescriptor)))),
DUMMY_TASK_ID)));
TaskSource source = new TaskSource(remoteSource.getId(), ImmutableSet.of(split), ImmutableSet.of(Lifespan.taskWide()), true);
shuffleTaskSources.add(source);
}
Broadcast<?> broadcast = nativeTaskInputs.getBroadcastInputs().get(sourceFragmentId.toString());
if (broadcast != null) {
Set<ScheduledSplit> splits =
((List<?>) broadcast.value()).stream()
.map(PrestoSparkSerializedPage.class::cast)
.map(prestoSparkSerializedPage -> PrestoSparkUtils.toSerializedPage(prestoSparkSerializedPage))
.map(serializedPage -> pagesSerde.deserialize(serializedPage))
// Extract filePath.
.flatMap(page -> IntStream.range(0, page.getPositionCount())
.mapToObj(position -> VarcharType.VARCHAR.getObjectValue(null, page.getBlock(0), position)))
.map(String.class::cast)
.map(filePath -> new BroadcastFileInfo(filePath))
.map(broadcastFileInfo -> new ScheduledSplit(
nextSplitId.getAndIncrement(),
remoteSource.getId(),
new Split(
REMOTE_CONNECTOR_ID,
new RemoteTransactionHandle(),
new RemoteSplit(
new Location(
format("batch://%s?broadcastInfo=%s", DUMMY_TASK_ID, broadcastFileInfoJsonCodec.toJson(broadcastFileInfo))),
DUMMY_TASK_ID))))
.collect(toImmutableSet());
TaskSource source = new TaskSource(remoteSource.getId(), splits, ImmutableSet.of(Lifespan.taskWide()), true);
broadcastTaskSources.add(source);
}
}
}
taskSources.addAll(shuffleTaskSources.build());
taskSources.addAll(broadcastTaskSources.build());
return taskSources;
}
private Optional<TableWriterNode> findTableWriteNode(PlanNode node)
{
return searchFrom(node)
.where(TableWriterNode.class::isInstance)
.findFirst();
}
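/**
* Iterator returned to the Spark RDD layer. Each hasNext() call polls the native task for
* the next output page; once the task is done it collects statistics and, on failure,
* translates the error into a PrestoSparkExecutionException for the RDD layer.
*/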
private static class PrestoSparkNativeTaskOutputIterator<T extends PrestoSparkTaskOutput>
extends AbstractIterator<Tuple2<MutablePartitionId, T>>
implements IPrestoSparkTaskExecutor<T>
{
private final int partitionId;
private final NativeExecutionTask nativeExecutionTask;
private Optional<SerializedPage> next = Optional.empty();
private final CollectionAccumulator<SerializedTaskInfo> taskInfoCollectionAccumulator;
private final Codec<TaskInfo> taskInfoCodec;
private final Class<T> outputType;
private final PrestoSparkExecutionExceptionFactory executionExceptionFactory;
private final CpuTracker cpuTracker;
private final NativeExecutionProcess nativeExecutionProcess;
private final boolean terminateWithCoreWhenUnresponsive;
private final Duration terminateWithCoreTimeout;
public PrestoSparkNativeTaskOutputIterator(
int partitionId,
NativeExecutionTask nativeExecutionTask,
Class<T> outputType,
CollectionAccumulator<SerializedTaskInfo> taskInfoCollectionAccumulator,
Codec<TaskInfo> taskInfoCodec,
PrestoSparkExecutionExceptionFactory executionExceptionFactory,
CpuTracker cpuTracker,
NativeExecutionProcess nativeExecutionProcess,
boolean terminateWithCoreWhenUnresponsive,
Duration terminateWithCoreTimeout)
{
this.partitionId = partitionId;
this.nativeExecutionTask = nativeExecutionTask;
this.taskInfoCollectionAccumulator = taskInfoCollectionAccumulator;
this.taskInfoCodec = taskInfoCodec;
this.outputType = outputType;
this.executionExceptionFactory = executionExceptionFactory;
this.cpuTracker = cpuTracker;
this.nativeExecutionProcess = requireNonNull(nativeExecutionProcess, "nativeExecutionProcess is null");
this.terminateWithCoreWhenUnresponsive = terminateWithCoreWhenUnresponsive;
this.terminateWithCoreTimeout = requireNonNull(terminateWithCoreTimeout, "terminateWithCoreTimeout is null");
}
/**
* This function is called by Spark's RDD layer to check if there are output pages.
* There are 2 scenarios:
* 1. ShuffleMap Task - Always returns false, but the internal call to computeNext() does all the work needed.
* 2. Result Task - Returns true while pages are still available, and false once all pages have been extracted.
*
* @return whether output is available
*/
@Override
public boolean hasNext()
{
next = computeNext();
return next.isPresent();
}
/**
* This function returns the next available page fetched from the CPP process.
* <p>
* It has 3 main responsibilities:
* 1) Wait for pages or completion
* <p>
* The thread running this method will wait until one of 3 conditions happens:
* 1. We get a page
* 2. The task has finished successfully
* 3. The task has finished with an error
* <p>
* For a ShuffleMap Task, as of now, the CPP process returns no pages,
* so the thread will be in WAITING state until the CPP task is done, and
* Optional.empty() is returned once the task has terminated.
* <p>
* For a Result Task, this function returns pages as they are retrieved from the CPP side.
* Once all the pages have been read and the task has terminated, it returns Optional.empty().
* <p>
* 2) Exception handling
* The function also checks if the task has finished
* with exceptions and throws the appropriate exception back to Spark's RDD processing
* layer.
* <p>
* 3) Statistics collection
* Whether the task finished successfully or with an exception, it tries to collect
* statistics if the TaskInfo object is available.
*
* @return Optional<SerializedPage> outputPage
*/
private Optional<SerializedPage> computeNext()
{
try {
Object taskFinishedOrHasResult = nativeExecutionTask.getTaskFinishedOrHasResult();
// Blocking wait if task is still running or hasn't produced any output page
synchronized (taskFinishedOrHasResult) {
while (!nativeExecutionTask.isTaskDone() && !nativeExecutionTask.hasResult()) {
taskFinishedOrHasResult.wait();
}
}
// For ShuffleMap Task, this will always return Optional.empty()
Optional<SerializedPage> pageOptional = nativeExecutionTask.pollResult();
if (pageOptional.isPresent()) {
return pageOptional;
}
// Double-check whether the task is already done (the thread above could have been woken either because output became available or because the task finished)
synchronized (taskFinishedOrHasResult) {
while (!nativeExecutionTask.isTaskDone()) {
taskFinishedOrHasResult.wait();
}
}
Optional<TaskInfo> taskInfo = nativeExecutionTask.getTaskInfo();
processTaskInfoForErrorsOrCompletion(taskInfo.get());
}
catch (RuntimeException ex) {
// For a failed task, if taskInfo is present we still want to log the metrics
completeTask(false, taskInfoCollectionAccumulator, nativeExecutionTask, taskInfoCodec, cpuTracker);
throw executionExceptionFactory.toPrestoSparkExecutionException(processFailure(
ex,
nativeExecutionProcess,
terminateWithCoreWhenUnresponsive,
terminateWithCoreTimeout));
}
catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
}
// Reaching here marks the end of task processing
completeTask(true, taskInfoCollectionAccumulator, nativeExecutionTask, taskInfoCodec, cpuTracker);
return Optional.empty();
}
@Override
public Tuple2<MutablePartitionId, T> next()
{
// Result Tasks only have outputType of PrestoSparkSerializedPage.
checkArgument(outputType == PrestoSparkSerializedPage.class,
format("PrestoSparkNativeTaskExecutorFactory only supports outputType=PrestoSparkSerializedPage, " +
"but tried to extract outputType=%s", outputType));
// Set partition ID to help match the results to the task on the driver for debugging.
MutablePartitionId mutablePartitionId = new MutablePartitionId();
mutablePartitionId.setPartition(partitionId);
return new Tuple2<>(mutablePartitionId, (T) toPrestoSparkSerializedPage(next.get()));
}
}
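/**
* For communication losses with the native process (transport failures with
* TOO_MANY_REQUESTS_FAILED), enriches the failure with process-liveness information:
* optionally terminates an unresponsive process with a core dump for debugging, or appends
* the crash report of a dead process. All other failures are returned unchanged.
*/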
private static RuntimeException processFailure(
RuntimeException failure,
NativeExecutionProcess process,
boolean terminateWithCoreWhenUnresponsive,
Duration terminateWithCoreTimeout)
{
if (isCommunicationLoss(failure)) {
PrestoTransportException transportException = (PrestoTransportException) failure;
String message;
// lost communication with the native execution process
if (process.isAlive()) {
// process is unresponsive
if (terminateWithCoreWhenUnresponsive) {
process.terminateWithCore(terminateWithCoreTimeout);
}
message = "Native execution process is alive but unresponsive";
}
else {
message = "Native execution process is dead";
String crashReport = process.getCrashReport();
if (!crashReport.isEmpty()) {
message += ":\n" + crashReport;
}
}
return new PrestoTransportException(
transportException::getErrorCode,
transportException.getRemoteHost(),
message,
failure);
}
return failure;
}
private static boolean isCommunicationLoss(RuntimeException failure)
{
if (!(failure instanceof PrestoTransportException)) {
return false;
}
PrestoTransportException transportException = (PrestoTransportException) failure;
return TOO_MANY_REQUESTS_FAILED.toErrorCode().equals(transportException.getErrorCode());
}
}