AbstractPrestoSparkQueryExecution.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.spark.execution;
import com.facebook.airlift.json.Codec;
import com.facebook.airlift.json.JsonCodec;
import com.facebook.airlift.log.Logger;
import com.facebook.presto.Session;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.cost.HistoryBasedPlanStatisticsTracker;
import com.facebook.presto.event.QueryMonitor;
import com.facebook.presto.execution.ExecutionFailureInfo;
import com.facebook.presto.execution.QueryInfo;
import com.facebook.presto.execution.QueryManagerConfig;
import com.facebook.presto.execution.QueryState;
import com.facebook.presto.execution.QueryStateTimer;
import com.facebook.presto.execution.StageInfo;
import com.facebook.presto.execution.TaskInfo;
import com.facebook.presto.execution.scheduler.ExecutionWriterTarget;
import com.facebook.presto.execution.scheduler.StreamingPlanSection;
import com.facebook.presto.execution.scheduler.StreamingSubPlan;
import com.facebook.presto.execution.scheduler.TableWriteInfo;
import com.facebook.presto.memory.NodeMemoryConfig;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spark.ErrorClassifier;
import com.facebook.presto.spark.PrestoSparkBroadcastDependency;
import com.facebook.presto.spark.PrestoSparkMemoryBasedBroadcastDependency;
import com.facebook.presto.spark.PrestoSparkMetadataStorage;
import com.facebook.presto.spark.PrestoSparkNativeStorageBasedDependency;
import com.facebook.presto.spark.PrestoSparkQueryData;
import com.facebook.presto.spark.PrestoSparkQueryExecutionFactory;
import com.facebook.presto.spark.PrestoSparkQueryStatusInfo;
import com.facebook.presto.spark.PrestoSparkServiceWaitTimeMetrics;
import com.facebook.presto.spark.PrestoSparkStorageBasedBroadcastDependency;
import com.facebook.presto.spark.PrestoSparkTaskDescriptor;
import com.facebook.presto.spark.RddAndMore;
import com.facebook.presto.spark.classloader_interface.IPrestoSparkQueryExecution;
import com.facebook.presto.spark.classloader_interface.IPrestoSparkTaskExecutor;
import com.facebook.presto.spark.classloader_interface.MutablePartitionId;
import com.facebook.presto.spark.classloader_interface.PrestoSparkExecutionException;
import com.facebook.presto.spark.classloader_interface.PrestoSparkJavaExecutionTaskInputs;
import com.facebook.presto.spark.classloader_interface.PrestoSparkMutableRow;
import com.facebook.presto.spark.classloader_interface.PrestoSparkPartitioner;
import com.facebook.presto.spark.classloader_interface.PrestoSparkSerializedPage;
import com.facebook.presto.spark.classloader_interface.PrestoSparkShuffleSerializer;
import com.facebook.presto.spark.classloader_interface.PrestoSparkShuffleStats;
import com.facebook.presto.spark.classloader_interface.PrestoSparkStorageHandle;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskExecutorFactoryProvider;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskOutput;
import com.facebook.presto.spark.classloader_interface.SerializedPrestoSparkTaskDescriptor;
import com.facebook.presto.spark.classloader_interface.SerializedTaskInfo;
import com.facebook.presto.spark.execution.task.PrestoSparkTaskExecutorFactory;
import com.facebook.presto.spark.planner.PrestoSparkPlanFragmenter;
import com.facebook.presto.spark.planner.PrestoSparkQueryPlanner.PlanAndMore;
import com.facebook.presto.spark.planner.PrestoSparkRddFactory;
import com.facebook.presto.spark.util.PrestoSparkTransactionUtils;
import com.facebook.presto.spi.ConnectorId;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.analyzer.UpdateInfo;
import com.facebook.presto.spi.connector.ConnectorCapabilities;
import com.facebook.presto.spi.connector.ConnectorNodePartitioningProvider;
import com.facebook.presto.spi.page.PagesSerde;
import com.facebook.presto.spi.plan.PartitioningHandle;
import com.facebook.presto.spi.plan.PartitioningScheme;
import com.facebook.presto.spi.plan.PlanFragmentId;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.storage.StorageCapabilities;
import com.facebook.presto.spi.storage.TempDataOperationContext;
import com.facebook.presto.spi.storage.TempStorage;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.planner.PartitioningProviderManager;
import com.facebook.presto.sql.planner.PlanFragment;
import com.facebook.presto.sql.planner.SubPlan;
import com.facebook.presto.transaction.TransactionManager;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ComparisonChain;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import org.apache.spark.MapOutputStatistics;
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SimpleFutureAction;
import org.apache.spark.SparkException;
import org.apache.spark.api.java.JavaFutureAction;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.rdd.ShuffledRDD;
import org.apache.spark.util.CollectionAccumulator;
import scala.Tuple2;
import javax.annotation.concurrent.GuardedBy;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.IntStream;
import static com.facebook.presto.SystemSessionProperties.getQueryMaxBroadcastMemory;
import static com.facebook.presto.SystemSessionProperties.getQueryMaxTotalMemoryPerNode;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.execution.QueryState.FAILED;
import static com.facebook.presto.execution.QueryState.FINISHED;
import static com.facebook.presto.execution.scheduler.StreamingPlanSection.extractStreamingSections;
import static com.facebook.presto.execution.scheduler.TableWriteInfo.createTableWriteInfo;
import static com.facebook.presto.spark.PrestoSparkSessionProperties.getSparkBroadcastJoinMaxMemoryOverride;
import static com.facebook.presto.spark.PrestoSparkSessionProperties.isStorageBasedBroadcastJoinEnabled;
import static com.facebook.presto.spark.PrestoSparkSettingsRequirements.SPARK_DYNAMIC_ALLOCATION_MAX_EXECUTORS_CONFIG;
import static com.facebook.presto.spark.SparkErrorCode.EXCEEDED_SPARK_DRIVER_MAX_RESULT_SIZE;
import static com.facebook.presto.spark.SparkErrorCode.GENERIC_SPARK_ERROR;
import static com.facebook.presto.spark.SparkErrorCode.SPARK_EXECUTOR_LOST;
import static com.facebook.presto.spark.SparkErrorCode.SPARK_EXECUTOR_OOM;
import static com.facebook.presto.spark.SparkErrorCode.UNSUPPORTED_STORAGE_TYPE;
import static com.facebook.presto.spark.classloader_interface.ScalaUtils.collectScalaIterator;
import static com.facebook.presto.spark.classloader_interface.ScalaUtils.emptyScalaIterator;
import static com.facebook.presto.spark.planner.PrestoSparkRddFactory.getRDDName;
import static com.facebook.presto.spark.util.PrestoSparkFailureUtils.toPrestoSparkFailure;
import static com.facebook.presto.spark.util.PrestoSparkUtils.classTag;
import static com.facebook.presto.spark.util.PrestoSparkUtils.computeNextTimeout;
import static com.facebook.presto.spark.util.PrestoSparkUtils.deserializeZstdCompressed;
import static com.facebook.presto.spark.util.PrestoSparkUtils.toSerializedPage;
import static com.facebook.presto.spi.StandardErrorCode.EXCEEDED_TIME_LIMIT;
import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED;
import static com.facebook.presto.spi.connector.ConnectorCapabilities.SUPPORTS_PAGE_SINK_COMMIT;
import static com.facebook.presto.spi.storage.StorageCapabilities.REMOTELY_ACCESSIBLE;
import static com.facebook.presto.sql.planner.PlanFragmenterUtils.isRootFragment;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
import static com.facebook.presto.util.Failures.toFailure;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.util.concurrent.Futures.getUnchecked;
import static io.airlift.units.DataSize.Unit.BYTE;
import static java.lang.Math.min;
import static java.util.Collections.unmodifiableList;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static org.apache.spark.util.Utils.isLocalMaster;
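// Base class for Presto-on-Spark query executions. Holds the machinery shared by concrete execution
// strategies: plan fragmentation state, RDD and broadcast dependency creation, result collection on the
// driver, failure translation, and query lifecycle/event reporting. Subclasses implement doExecute().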
public abstract class AbstractPrestoSparkQueryExecution
implements IPrestoSparkQueryExecution
{
private static final Logger log = Logger.get(AbstractPrestoSparkQueryExecution.class);
protected final Session session;
protected final QueryMonitor queryMonitor;
protected final CollectionAccumulator<SerializedTaskInfo> taskInfoCollector;
protected final CollectionAccumulator<PrestoSparkShuffleStats> shuffleStatsCollector;
// used to create tasks on the Driver
protected final PrestoSparkTaskExecutorFactory taskExecutorFactory;
// used to create tasks on executors; must be serializable
protected final PrestoSparkTaskExecutorFactoryProvider taskExecutorFactoryProvider;
protected final QueryStateTimer queryStateTimer;
protected final WarningCollector warningCollector;
protected final String query;
protected final PlanAndMore planAndMore;
protected final Optional<String> sparkQueueName;
protected final Codec<TaskInfo> taskInfoCodec;
protected final JsonCodec<PrestoSparkTaskDescriptor> sparkTaskDescriptorJsonCodec;
protected final JsonCodec<PrestoSparkQueryStatusInfo> queryStatusInfoJsonCodec;
protected final JsonCodec<PrestoSparkQueryData> queryDataJsonCodec;
protected final PrestoSparkRddFactory rddFactory;
protected final TransactionManager transactionManager;
protected final PagesSerde pagesSerde;
protected final PrestoSparkExecutionExceptionFactory executionExceptionFactory;
protected final Duration queryTimeout;
protected final Metadata metadata;
protected final PrestoSparkMetadataStorage metadataStorage;
protected final Optional<String> queryStatusInfoOutputLocation;
protected final Optional<String> queryDataOutputLocation;
protected final long queryCompletionDeadline;
protected final TempStorage tempStorage;
protected final NodeMemoryConfig nodeMemoryConfig;
protected final FeaturesConfig featuresConfig;
protected final QueryManagerConfig queryManagerConfig;
protected final Set<PrestoSparkServiceWaitTimeMetrics> waitTimeMetrics;
protected final Optional<ErrorClassifier> errorClassifier;
protected final JavaSparkContext sparkContext;
protected final PrestoSparkPlanFragmenter planFragmenter;
protected final PartitioningProviderManager partitioningProviderManager;
protected final HistoryBasedPlanStatisticsTracker historyBasedPlanStatisticsTracker;
private AtomicReference<SubPlan> finalFragmentedPlan = new AtomicReference<>();
@GuardedBy("this")
private final Map<PlanFragmentId, RddAndMore> fragmentIdToRdd = new HashMap<>();
private final Optional<CollectionAccumulator<Map<String, Long>>> bootstrapMetricsCollector;
public AbstractPrestoSparkQueryExecution(
JavaSparkContext sparkContext,
Session session,
QueryMonitor queryMonitor,
CollectionAccumulator<SerializedTaskInfo> taskInfoCollector,
CollectionAccumulator<PrestoSparkShuffleStats> shuffleStatsCollector,
PrestoSparkTaskExecutorFactory taskExecutorFactory,
PrestoSparkTaskExecutorFactoryProvider taskExecutorFactoryProvider,
QueryStateTimer queryStateTimer,
WarningCollector warningCollector,
String query,
PlanAndMore planAndMore,
Optional<String> sparkQueueName,
Codec<TaskInfo> taskInfoCodec,
JsonCodec<PrestoSparkTaskDescriptor> sparkTaskDescriptorJsonCodec,
JsonCodec<PrestoSparkQueryStatusInfo> queryStatusInfoJsonCodec,
JsonCodec<PrestoSparkQueryData> queryDataJsonCodec,
PrestoSparkRddFactory rddFactory,
TransactionManager transactionManager,
PagesSerde pagesSerde,
PrestoSparkExecutionExceptionFactory executionExceptionFactory,
Duration queryTimeout,
long queryCompletionDeadline,
PrestoSparkMetadataStorage metadataStorage,
Optional<String> queryStatusInfoOutputLocation,
Optional<String> queryDataOutputLocation,
TempStorage tempStorage,
NodeMemoryConfig nodeMemoryConfig,
FeaturesConfig featuresConfig,
QueryManagerConfig queryManagerConfig,
Set<PrestoSparkServiceWaitTimeMetrics> waitTimeMetrics,
Optional<ErrorClassifier> errorClassifier,
PrestoSparkPlanFragmenter planFragmenter,
Metadata metadata,
PartitioningProviderManager partitioningProviderManager,
HistoryBasedPlanStatisticsTracker historyBasedPlanStatisticsTracker,
Optional<CollectionAccumulator<Map<String, Long>>> bootstrapMetricsCollector)
{
this.sparkContext = requireNonNull(sparkContext, "sparkContext is null");
this.session = requireNonNull(session, "session is null");
this.queryMonitor = requireNonNull(queryMonitor, "queryMonitor is null");
this.taskInfoCollector = requireNonNull(taskInfoCollector, "taskInfoCollector is null");
this.shuffleStatsCollector = requireNonNull(shuffleStatsCollector, "shuffleStatsCollector is null");
this.taskExecutorFactory = requireNonNull(taskExecutorFactory, "taskExecutorFactory is null");
this.taskExecutorFactoryProvider = requireNonNull(taskExecutorFactoryProvider, "taskExecutorFactoryProvider is null");
this.queryStateTimer = requireNonNull(queryStateTimer, "queryStateTimer is null");
this.warningCollector = requireNonNull(warningCollector, "warningCollector is null");
this.query = requireNonNull(query, "query is null");
this.planAndMore = requireNonNull(planAndMore, "planAndMore is null");
this.sparkQueueName = requireNonNull(sparkQueueName, "sparkQueueName is null");
this.taskInfoCodec = requireNonNull(taskInfoCodec, "taskInfoCodec is null");
this.sparkTaskDescriptorJsonCodec = requireNonNull(sparkTaskDescriptorJsonCodec, "sparkTaskDescriptorJsonCodec is null");
this.queryStatusInfoJsonCodec = requireNonNull(queryStatusInfoJsonCodec, "queryStatusInfoJsonCodec is null");
this.queryDataJsonCodec = requireNonNull(queryDataJsonCodec, "queryDataJsonCodec is null");
this.rddFactory = requireNonNull(rddFactory, "rddFactory is null");
this.transactionManager = requireNonNull(transactionManager, "transactionManager is null");
this.pagesSerde = requireNonNull(pagesSerde, "pagesSerde is null");
this.executionExceptionFactory = requireNonNull(executionExceptionFactory, "executionExceptionFactory is null");
this.queryTimeout = requireNonNull(queryTimeout, "queryTimeout is null");
this.queryCompletionDeadline = queryCompletionDeadline;
this.metadataStorage = requireNonNull(metadataStorage, "metadataStorage is null");
this.queryStatusInfoOutputLocation = requireNonNull(queryStatusInfoOutputLocation, "queryStatusInfoOutputLocation is null");
this.queryDataOutputLocation = requireNonNull(queryDataOutputLocation, "queryDataOutputLocation is null");
this.tempStorage = requireNonNull(tempStorage, "tempStorage is null");
this.nodeMemoryConfig = requireNonNull(nodeMemoryConfig, "nodeMemoryConfig is null");
this.featuresConfig = requireNonNull(featuresConfig, "featuresConfig is null");
this.queryManagerConfig = requireNonNull(queryManagerConfig, "queryManagerConfig is null");
this.waitTimeMetrics = requireNonNull(waitTimeMetrics, "waitTimeMetrics is null");
this.errorClassifier = requireNonNull(errorClassifier, "errorClassifier is null");
this.planFragmenter = requireNonNull(planFragmenter, "planFragmenter is null");
this.metadata = requireNonNull(metadata, "metadata is null");
this.partitioningProviderManager = requireNonNull(partitioningProviderManager, "partitioningProviderManager is null");
this.historyBasedPlanStatisticsTracker = requireNonNull(historyBasedPlanStatisticsTracker, "historyBasedPlanStatisticsTracker is null");
this.bootstrapMetricsCollector = requireNonNull(bootstrapMetricsCollector, "bootstrapMetricsCollector is null");
}
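// Repartitions a fragment's output RDD according to the given partitioning scheme, installs the Presto
// shuffle serializer on the resulting ShuffledRDD, and names it after the producing fragment id.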
protected static JavaPairRDD<MutablePartitionId, PrestoSparkMutableRow> partitionBy(
int planFragmentId,
JavaPairRDD<MutablePartitionId, PrestoSparkMutableRow> rdd,
PartitioningScheme partitioningScheme)
{
Partitioner partitioner = createPartitioner(partitioningScheme);
JavaPairRDD<MutablePartitionId, PrestoSparkMutableRow> javaPairRdd = rdd.partitionBy(partitioner);
ShuffledRDD<MutablePartitionId, PrestoSparkMutableRow, PrestoSparkMutableRow> shuffledRdd = (ShuffledRDD<MutablePartitionId, PrestoSparkMutableRow, PrestoSparkMutableRow>) javaPairRdd.rdd();
shuffledRdd.setSerializer(new PrestoSparkShuffleSerializer());
shuffledRdd.setName(getRDDName(planFragmentId));
return JavaPairRDD.fromRDD(
shuffledRdd,
classTag(MutablePartitionId.class),
classTag(PrestoSparkMutableRow.class));
}
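// Maps a Presto partitioning scheme to a Spark Partitioner: a single partition for SINGLE_DISTRIBUTION,
// otherwise the number of partitions implied by the bucketToPartition mapping.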
protected static Partitioner createPartitioner(PartitioningScheme partitioningScheme)
{
PartitioningHandle partitioning = partitioningScheme.getPartitioning().getHandle();
if (partitioning.equals(SINGLE_DISTRIBUTION)) {
return new PrestoSparkPartitioner(1);
}
if (partitioning.equals(FIXED_HASH_DISTRIBUTION)
|| partitioning.equals(FIXED_ARBITRARY_DISTRIBUTION)
|| partitioning.getConnectorId().isPresent()) {
int[] bucketToPartition = partitioningScheme.getBucketToPartition().orElseThrow(
() -> new IllegalArgumentException("bucketToPartition is expected to be assigned at this point"));
checkArgument(bucketToPartition.length > 0, "bucketToPartition is expected to be non empty");
int numberOfPartitions = IntStream.of(bucketToPartition)
.max()
.getAsInt() + 1;
return new PrestoSparkPartitioner(numberOfPartitions);
}
throw new IllegalArgumentException("Unexpected partitioning: " + partitioning);
}
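// Drives the full query lifecycle: tunes the executor count, runs doExecute(), commits (or rolls back) the
// transaction, translates Spark and timeout failures into Presto failures, converts the returned pages into
// row values, derives the update count, and publishes the query completed event.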
@Override
public List<List<Object>> execute()
{
List<Tuple2<MutablePartitionId, PrestoSparkSerializedPage>> rddResults;
try {
tuneMaxExecutorsCount();
rddResults = doExecute();
queryStateTimer.beginFinishing();
PrestoSparkTransactionUtils.commit(session, transactionManager);
queryStateTimer.endQuery();
}
catch (Throwable executionException) {
queryStateTimer.beginFinishing();
try {
PrestoSparkTransactionUtils.rollback(session, transactionManager);
}
catch (RuntimeException rollbackFailure) {
log.error(rollbackFailure, "Encountered error when performing rollback");
}
Optional<ExecutionFailureInfo> failureInfo = Optional.empty();
if (executionException instanceof SparkException) {
SparkException sparkException = (SparkException) executionException;
failureInfo = executionExceptionFactory.extractExecutionFailureInfo(sparkException);
if (!failureInfo.isPresent()) {
// not a SparkException with Presto failure info encoded
PrestoException wrappedPrestoException;
if (sparkException.getMessage().contains("most recent failure: JVM_OOM")) {
wrappedPrestoException = new PrestoException(SPARK_EXECUTOR_OOM, executionException);
}
else if (sparkException.getMessage().matches(".*Total size of serialized results .* is bigger than allowed maxResultSize.*")) {
wrappedPrestoException = new PrestoException(EXCEEDED_SPARK_DRIVER_MAX_RESULT_SIZE, executionException);
}
else if (sparkException.getMessage().contains("Executor heartbeat timed out") ||
sparkException.getMessage().contains("Unable to talk to the executor")) {
wrappedPrestoException = new PrestoException(SPARK_EXECUTOR_LOST, executionException);
}
else if (errorClassifier.isPresent()) {
wrappedPrestoException = errorClassifier.get().classify(executionException);
}
else {
wrappedPrestoException = new PrestoException(GENERIC_SPARK_ERROR, executionException);
}
failureInfo = Optional.of(toFailure(wrappedPrestoException));
}
}
else if (executionException instanceof PrestoSparkExecutionException) {
failureInfo = executionExceptionFactory.extractExecutionFailureInfo((PrestoSparkExecutionException) executionException);
}
else if (executionException instanceof TimeoutException) {
failureInfo = Optional.of(toFailure(new PrestoException(EXCEEDED_TIME_LIMIT, "Query exceeded maximum time limit of " + queryTimeout, executionException)));
}
if (!failureInfo.isPresent()) {
failureInfo = Optional.of(toFailure(executionException));
}
queryStateTimer.endQuery();
try {
queryCompletedEvent(failureInfo, OptionalLong.empty());
}
catch (RuntimeException eventFailure) {
log.error(eventFailure, "Error publishing query completed event");
}
throw toPrestoSparkFailure(session, failureInfo.get());
}
processShuffleStats();
ConnectorSession connectorSession = session.toConnectorSession();
List<Type> types = getOutputTypes();
ImmutableList.Builder<List<Object>> result = ImmutableList.builder();
for (Tuple2<MutablePartitionId, PrestoSparkSerializedPage> tuple : rddResults) {
Page page = pagesSerde.deserialize(toSerializedPage(tuple._2));
checkArgument(page.getChannelCount() == types.size(), "expected %s channels, got %s", types.size(), page.getChannelCount());
for (int position = 0; position < page.getPositionCount(); position++) {
List<Object> columns = new ArrayList<>();
for (int channel = 0; channel < page.getChannelCount(); channel++) {
columns.add(types.get(channel).getObjectValue(connectorSession.getSqlFunctionProperties(), page.getBlock(channel), position));
}
result.add(unmodifiableList(columns));
}
}
List<List<Object>> results = result.build();
// Based on com.facebook.presto.server.protocol.Query#getNextResult
OptionalLong updateCount = OptionalLong.empty();
if (planAndMore.getUpdateInfo().isPresent() &&
types.size() == 1 &&
types.get(0).equals(BIGINT) &&
results.size() == 1 &&
results.get(0).size() == 1 &&
results.get(0).get(0) != null) {
updateCount = OptionalLong.of(((Number) results.get(0).get(0)).longValue());
}
// successfully finished
try {
queryCompletedEvent(Optional.empty(), updateCount);
}
catch (RuntimeException eventFailure) {
log.error(eventFailure, "Error publishing query completed event");
}
if (queryDataOutputLocation.isPresent()) {
metadataStorage.write(
queryDataOutputLocation.get(),
queryDataJsonCodec.toJsonBytes(new PrestoSparkQueryData(PrestoSparkQueryExecutionFactory.getOutputColumns(planAndMore), results)));
}
return results;
}
public List<Type> getOutputTypes()
{
Optional<SubPlan> subPlanOptional = getFinalFragmentedPlan();
verify(subPlanOptional.isPresent(), "finalFragmentedPlan is null");
return subPlanOptional.get().getFragment().getTypes();
}
public Optional<UpdateInfo> getUpdateType()
{
return planAndMore.getUpdateInfo();
}
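// Subclasses implement the actual execution strategy and return the serialized pages produced by the root fragment.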
protected abstract List<Tuple2<MutablePartitionId, PrestoSparkSerializedPage>> doExecute()
throws SparkException, TimeoutException;
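// Collects the output of the input RDDs on the driver and runs the root fragment locally as a single task
// over those pages, logging per-fragment and total page counts and sizes.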
protected List<Tuple2<MutablePartitionId, PrestoSparkSerializedPage>> collectPages(TableWriteInfo tableWriteInfo, PlanFragment rootFragment, Map<PlanFragmentId, RddAndMore<PrestoSparkSerializedPage>> inputRdds)
throws SparkException, TimeoutException
{
PrestoSparkTaskDescriptor taskDescriptor = new PrestoSparkTaskDescriptor(
session.toSessionRepresentation(),
session.getIdentity().getExtraCredentials(),
rootFragment,
tableWriteInfo);
SerializedPrestoSparkTaskDescriptor serializedTaskDescriptor = new SerializedPrestoSparkTaskDescriptor(sparkTaskDescriptorJsonCodec.toJsonBytes(taskDescriptor));
Map<String, JavaFutureAction<List<Tuple2<MutablePartitionId, PrestoSparkSerializedPage>>>> inputFutures = inputRdds.entrySet().stream()
.collect(toImmutableMap(entry -> entry.getKey().toString(), entry -> entry.getValue().getRdd().collectAsync()));
PrestoSparkQueryExecutionFactory.waitForActionsCompletionWithTimeout(inputFutures.values(), computeNextTimeout(queryCompletionDeadline), MILLISECONDS, waitTimeMetrics);
// release memory retained by the RDDs (splits and dependencies)
inputRdds = null;
ImmutableMap.Builder<String, List<PrestoSparkSerializedPage>> inputs = ImmutableMap.builder();
long totalNumberOfPagesReceived = 0;
long totalCompressedSizeInBytes = 0;
long totalUncompressedSizeInBytes = 0;
for (Map.Entry<String, JavaFutureAction<List<Tuple2<MutablePartitionId, PrestoSparkSerializedPage>>>> inputFuture : inputFutures.entrySet()) {
// Use a mutable list so memory can be released on a per-page basis
List<PrestoSparkSerializedPage> pages = new ArrayList<>();
List<Tuple2<MutablePartitionId, PrestoSparkSerializedPage>> tuples = getUnchecked(inputFuture.getValue());
long currentFragmentOutputCompressedSizeInBytes = 0;
long currentFragmentOutputUncompressedSizeInBytes = 0;
for (Tuple2<MutablePartitionId, PrestoSparkSerializedPage> tuple : tuples) {
PrestoSparkSerializedPage page = tuple._2;
currentFragmentOutputCompressedSizeInBytes += page.getSize();
currentFragmentOutputUncompressedSizeInBytes += page.getUncompressedSizeInBytes();
log.info("Received %s rows from partition %s in fragment %s", page.getPositionCount(), tuple._1.getPartition(), inputFuture.getKey());
pages.add(page);
}
log.info(
"Received %s pages from fragment %s. Compressed size: %s. Uncompressed size: %s.",
pages.size(),
inputFuture.getKey(),
DataSize.succinctBytes(currentFragmentOutputCompressedSizeInBytes),
DataSize.succinctBytes(currentFragmentOutputUncompressedSizeInBytes));
totalNumberOfPagesReceived += pages.size();
totalCompressedSizeInBytes += currentFragmentOutputCompressedSizeInBytes;
totalUncompressedSizeInBytes += currentFragmentOutputUncompressedSizeInBytes;
inputs.put(inputFuture.getKey(), pages);
}
log.info(
"Received %s pages in total. Compressed size: %s. Uncompressed size: %s.",
totalNumberOfPagesReceived,
DataSize.succinctBytes(totalCompressedSizeInBytes),
DataSize.succinctBytes(totalUncompressedSizeInBytes));
IPrestoSparkTaskExecutor<PrestoSparkSerializedPage> prestoSparkTaskExecutor = taskExecutorFactory.create(
0,
0,
serializedTaskDescriptor,
emptyScalaIterator(),
new PrestoSparkJavaExecutionTaskInputs(ImmutableMap.of(), ImmutableMap.of(), inputs.build()),
taskInfoCollector,
shuffleStatsCollector,
PrestoSparkSerializedPage.class);
return collectScalaIterator(prestoSparkTaskExecutor);
}
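// Recursively builds the RDD tree for a sub-plan: broadcast-distributed children become broadcast
// dependencies, all other children are shuffled with partitionBy and wired in as RDD inputs.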
@VisibleForTesting
public <T extends PrestoSparkTaskOutput> RddAndMore<T> createRdd(SubPlan subPlan, Class<T> outputType, TableWriteInfo tableWriteInfo)
throws SparkException, TimeoutException
{
ImmutableMap.Builder<PlanFragmentId, JavaPairRDD<MutablePartitionId, PrestoSparkMutableRow>> rddInputs = ImmutableMap.builder();
ImmutableMap.Builder<PlanFragmentId, Broadcast<?>> broadcastInputs = ImmutableMap.builder();
ImmutableList.Builder<PrestoSparkBroadcastDependency<?>> broadcastDependencies = ImmutableList.builder();
for (SubPlan child : subPlan.getChildren()) {
PlanFragment childFragment = child.getFragment();
if (childFragment.getPartitioningScheme().getPartitioning().getHandle().equals(FIXED_BROADCAST_DISTRIBUTION)) {
RddAndMore<?> childRdd;
PrestoSparkBroadcastDependency<?> broadcastDependency;
if (isStorageBasedBroadcastJoinEnabled(session)) {
childRdd = createRdd(child, PrestoSparkStorageHandle.class, tableWriteInfo);
}
else {
childRdd = createRdd(child, PrestoSparkSerializedPage.class, tableWriteInfo);
}
broadcastDependency = createBroadcastDependency(childRdd);
broadcastInputs.put(childFragment.getId(), broadcastDependency.executeBroadcast(sparkContext));
broadcastDependencies.add(broadcastDependency);
}
else {
RddAndMore<PrestoSparkMutableRow> childRdd = createRdd(child, PrestoSparkMutableRow.class, tableWriteInfo);
rddInputs.put(childFragment.getId(), partitionBy(childFragment.getId().getId(), childRdd.getRdd(), child.getFragment().getPartitioningScheme()));
broadcastDependencies.addAll(childRdd.getBroadcastDependencies());
}
}
JavaPairRDD<MutablePartitionId, T> rdd = rddFactory.createSparkRdd(
sparkContext,
session,
subPlan.getFragment(),
rddInputs.build(),
broadcastInputs.build(),
taskExecutorFactoryProvider,
taskInfoCollector,
shuffleStatsCollector,
tableWriteInfo,
outputType);
return new RddAndMore<>(rdd, broadcastDependencies.build());
}
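// Storage-based broadcast requires remotely accessible temp storage unless Spark is running in local mode.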
protected void validateStorageCapabilities(TempStorage tempStorage)
{
boolean isLocalMode = isLocalMaster(sparkContext.getConf());
List<StorageCapabilities> storageCapabilities = tempStorage.getStorageCapabilities();
if (!isLocalMode && !storageCapabilities.contains(REMOTELY_ACCESSIBLE)) {
throw new PrestoException(UNSUPPORTED_STORAGE_TYPE, "Configured TempStorage does not support remote access required for distributing broadcast tables.");
}
}
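// Deserializes the collected task infos, builds the final QueryInfo, publishes the query completed event,
// updates history-based plan statistics, and optionally writes the query status info to metadata storage.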
protected void queryCompletedEvent(Optional<ExecutionFailureInfo> failureInfo, OptionalLong updateCount)
{
List<SerializedTaskInfo> serializedTaskInfos = taskInfoCollector.value();
ImmutableList.Builder<TaskInfo> taskInfos = ImmutableList.builder();
long totalSerializedTaskInfoSizeInBytes = 0;
for (SerializedTaskInfo serializedTaskInfo : serializedTaskInfos) {
byte[] bytes = serializedTaskInfo.getBytesAndClear();
totalSerializedTaskInfoSizeInBytes += bytes.length;
TaskInfo taskInfo = deserializeZstdCompressed(taskInfoCodec, bytes);
taskInfos.add(taskInfo);
}
taskInfoCollector.reset();
log.info("Total serialized task info size: %s", DataSize.succinctBytes(totalSerializedTaskInfoSizeInBytes));
Optional<StageInfo> stageInfoOptional = getFinalFragmentedPlan().map(finalFragmentedPlan ->
PrestoSparkQueryExecutionFactory.createStageInfo(session.getQueryId(), finalFragmentedPlan, taskInfos.build()));
QueryState queryState = failureInfo.isPresent() ? FAILED : FINISHED;
QueryInfo queryInfo = PrestoSparkQueryExecutionFactory.createQueryInfo(
session,
query,
queryState,
Optional.of(planAndMore),
sparkQueueName,
failureInfo,
queryStateTimer,
stageInfoOptional,
warningCollector);
queryMonitor.queryCompletedEvent(queryInfo);
historyBasedPlanStatisticsTracker.updateStatistics(queryInfo);
if (queryStatusInfoOutputLocation.isPresent()) {
PrestoSparkQueryStatusInfo prestoSparkQueryStatusInfo = PrestoSparkQueryExecutionFactory.createPrestoSparkQueryInfo(
queryInfo,
Optional.of(planAndMore),
warningCollector,
updateCount);
metadataStorage.write(
queryStatusInfoOutputLocation.get(),
queryStatusInfoJsonCodec.toJsonBytes(prestoSparkQueryStatusInfo));
}
processBootstrapStats();
}
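// The final fragmented plan is set exactly once (by subclasses) and later read for output types and reporting.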
protected final void setFinalFragmentedPlan(SubPlan subPlan)
{
verify(subPlan != null, "subPlan is null");
boolean updated = finalFragmentedPlan.compareAndSet(null, subPlan);
verify(updated, "finalFragmentedPlan is already non-null");
}
public final Optional<SubPlan> getFinalFragmentedPlan()
{
return Optional.ofNullable(finalFragmentedPlan.get());
}
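// Aggregates the shuffle statistics reported by executors by (fragment, operation) and logs a summary.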
protected void processShuffleStats()
{
List<PrestoSparkShuffleStats> statsList = shuffleStatsCollector.value();
Map<ShuffleStatsKey, List<PrestoSparkShuffleStats>> statsMap = new TreeMap<>();
for (PrestoSparkShuffleStats stats : statsList) {
ShuffleStatsKey key = new ShuffleStatsKey(stats.getFragmentId(), stats.getOperation());
statsMap.computeIfAbsent(key, (ignored) -> new ArrayList<>()).add(stats);
}
log.info("Shuffle statistics summary:");
for (Map.Entry<ShuffleStatsKey, List<PrestoSparkShuffleStats>> fragment : statsMap.entrySet()) {
logShuffleStatsSummary(fragment.getKey(), fragment.getValue());
}
shuffleStatsCollector.reset();
}
protected void logShuffleStatsSummary(ShuffleStatsKey key, List<PrestoSparkShuffleStats> statsList)
{
long totalProcessedRows = 0;
long totalProcessedRowBatches = 0;
long totalProcessedBytes = 0;
long totalElapsedWallTimeMills = 0;
for (PrestoSparkShuffleStats stats : statsList) {
totalProcessedRows += stats.getProcessedRows();
totalProcessedRowBatches += stats.getProcessedRowBatches();
totalProcessedBytes += stats.getProcessedBytes();
totalElapsedWallTimeMills += stats.getElapsedWallTimeMills();
}
long totalElapsedWallTimeSeconds = totalElapsedWallTimeMills / 1000;
long rowsPerSecond = totalProcessedRows;
long rowBatchesPerSecond = totalProcessedRowBatches;
long bytesPerSecond = totalProcessedBytes;
if (totalElapsedWallTimeSeconds > 0) {
rowsPerSecond = totalProcessedRows / totalElapsedWallTimeSeconds;
rowBatchesPerSecond = totalProcessedRowBatches / totalElapsedWallTimeSeconds;
bytesPerSecond = totalProcessedBytes / totalElapsedWallTimeSeconds;
}
long averageRowSize = 0;
if (totalProcessedRows > 0) {
averageRowSize = totalProcessedBytes / totalProcessedRows;
}
long averageRowBatchSize = 0;
if (totalProcessedRowBatches > 0) {
averageRowBatchSize = totalProcessedBytes / totalProcessedRowBatches;
}
log.info(
"Fragment: %s, Operation: %s, Rows: %s, Row Batches: %s, Size: %s, Avg Row Size: %s, Avg Row Batch Size: %s, Time: %s, %s rows/s, %s batches/s, %s/s",
key.getFragmentId(),
key.getOperation(),
totalProcessedRows,
totalProcessedRowBatches,
DataSize.succinctBytes(totalProcessedBytes),
DataSize.succinctBytes(averageRowSize),
DataSize.succinctBytes(averageRowBatchSize),
Duration.succinctDuration(totalElapsedWallTimeMills, MILLISECONDS),
rowsPerSecond,
rowBatchesPerSecond,
DataSize.succinctBytes(bytesPerSecond));
}
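// Computes the bucket-to-partition assignment for the given partitioning handle, if one is required.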
protected Optional<int[]> getBucketToPartition(Session session, PartitioningHandle partitioningHandle, int hashPartitionCount)
{
if (partitioningHandle.equals(FIXED_HASH_DISTRIBUTION)) {
return Optional.of(IntStream.range(0, hashPartitionCount).toArray());
}
// FIXED_ARBITRARY_DISTRIBUTION is used for UNION ALL
// UNION ALL inputs could be source inputs or shuffle inputs
if (partitioningHandle.equals(FIXED_ARBITRARY_DISTRIBUTION)) {
// given a modular hash function, the partition count can be arbitrary;
// simply reuse hash_partition_count for convenience
// (it could also be driven by a separate session property if needed)
return Optional.of(IntStream.range(0, hashPartitionCount).toArray());
}
if (partitioningHandle.getConnectorId().isPresent()) {
int connectorPartitionCount = getPartitionCount(session, partitioningHandle);
return Optional.of(IntStream.range(0, connectorPartitionCount).toArray());
}
return Optional.empty();
}
protected int getPartitionCount(Session session, PartitioningHandle partitioning)
{
ConnectorNodePartitioningProvider partitioningProvider = getPartitioningProvider(partitioning);
return partitioningProvider.getBucketCount(
partitioning.getTransactionHandle().orElse(null),
session.toConnectorSession(),
partitioning.getConnectorHandle());
}
protected ConnectorNodePartitioningProvider getPartitioningProvider(PartitioningHandle partitioning)
{
ConnectorId connectorId = partitioning.getConnectorId()
.orElseThrow(() -> new IllegalArgumentException("Unexpected partitioning: " + partitioning));
return partitioningProviderManager.getPartitioningProvider(connectorId);
}
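// Recursively assigns a bucketToPartition mapping to every fragment that does not already have one.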
protected SubPlan configureOutputPartitioning(Session session, SubPlan subPlan, int hashPartitionCount)
{
PlanFragment fragment = subPlan.getFragment();
if (!fragment.getPartitioningScheme().getBucketToPartition().isPresent()) {
PartitioningHandle partitioningHandle = fragment.getPartitioningScheme().getPartitioning().getHandle();
Optional<int[]> bucketToPartition = getBucketToPartition(session, partitioningHandle, hashPartitionCount);
if (bucketToPartition.isPresent()) {
fragment = fragment.withBucketToPartition(bucketToPartition);
}
}
return new SubPlan(
fragment,
subPlan.getChildren().stream()
.map(child -> configureOutputPartitioning(session, child, hashPartitionCount))
.collect(toImmutableList()));
}
@VisibleForTesting
public TableWriteInfo getTableWriteInfo(Session session, SubPlan plan)
{
StreamingPlanSection streamingPlanSection = extractStreamingSections(plan);
StreamingSubPlan streamingSubPlan = streamingPlanSection.getPlan();
TableWriteInfo tableWriteInfo = createTableWriteInfo(streamingSubPlan, metadata, session);
if (tableWriteInfo.getWriterTarget().isPresent()) {
checkPageSinkCommitIsSupported(session, tableWriteInfo.getWriterTarget().get());
}
return tableWriteInfo;
}
@VisibleForTesting
public TableWriteInfo getTableWriteInfo(Session session, PlanNode planNode)
{
TableWriteInfo tableWriteInfo = createTableWriteInfo(planNode, metadata, session);
if (tableWriteInfo.getWriterTarget().isPresent()) {
checkPageSinkCommitIsSupported(session, tableWriteInfo.getWriterTarget().get());
}
return tableWriteInfo;
}
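// Presto on Spark requires writes to go through connectors that support page sink commit;
// delete queries are rejected outright.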
private void checkPageSinkCommitIsSupported(Session session, ExecutionWriterTarget writerTarget)
{
ConnectorId connectorId;
if (writerTarget instanceof ExecutionWriterTarget.DeleteHandle) {
throw new PrestoException(NOT_SUPPORTED, "delete queries are not supported by presto on spark");
}
else if (writerTarget instanceof ExecutionWriterTarget.CreateHandle) {
connectorId = ((ExecutionWriterTarget.CreateHandle) writerTarget).getHandle().getConnectorId();
}
else if (writerTarget instanceof ExecutionWriterTarget.InsertHandle) {
connectorId = ((ExecutionWriterTarget.InsertHandle) writerTarget).getHandle().getConnectorId();
}
else if (writerTarget instanceof ExecutionWriterTarget.RefreshMaterializedViewHandle) {
connectorId = ((ExecutionWriterTarget.RefreshMaterializedViewHandle) writerTarget).getHandle().getConnectorId();
}
else {
throw new IllegalArgumentException("unexpected writer target type: " + writerTarget.getClass());
}
verify(connectorId != null, "connectorId is null");
Set<ConnectorCapabilities> connectorCapabilities = metadata.getConnectorCapabilities(session, connectorId);
if (!connectorCapabilities.contains(SUPPORTS_PAGE_SINK_COMMIT)) {
throw new PrestoException(NOT_SUPPORTED, "catalog does not support page sink commit: " + connectorId);
}
}
// Returns the RDD for the specified fragmented SubPlan.
// The RDD is created only once per sub-plan; identity is determined by the fragment id.
// For broadcast fragments, the returned RDD is the one to be broadcast.
protected synchronized <T extends PrestoSparkTaskOutput> RddAndMore<T> createRddForSubPlan(SubPlan subPlan,
TableWriteInfo tableWriteInfo,
Optional<Class<?>> outputTypeOptional)
throws SparkException, TimeoutException
{
if (fragmentIdToRdd.containsKey(subPlan.getFragment().getId())) {
return fragmentIdToRdd.get(subPlan.getFragment().getId());
}
ImmutableMap.Builder<PlanFragmentId, JavaPairRDD<MutablePartitionId, PrestoSparkMutableRow>> rddInputs = ImmutableMap.builder();
ImmutableMap.Builder<PlanFragmentId, Broadcast<?>> broadcastInputs = ImmutableMap.builder();
ImmutableList.Builder<PrestoSparkBroadcastDependency<?>> broadcastDependencies = ImmutableList.builder();
for (SubPlan child : subPlan.getChildren()) {
RddAndMore<?> childRdd = createRddForSubPlan(child, tableWriteInfo, Optional.empty());
if (childRdd.isBroadcastDistribution()) {
PrestoSparkBroadcastDependency<?> broadcastDependency = createBroadcastDependency(childRdd);
broadcastInputs.put(child.getFragment().getId(), broadcastDependency.executeBroadcast(sparkContext));
broadcastDependencies.add(broadcastDependency);
}
else {
rddInputs.put(child.getFragment().getId(), (JavaPairRDD<MutablePartitionId, PrestoSparkMutableRow>) childRdd.getRdd());
broadcastDependencies.addAll(childRdd.getBroadcastDependencies());
}
}
Class outputType = outputTypeOptional.orElseGet(() -> getOutputType(subPlan));
JavaPairRDD<MutablePartitionId, T> rdd = rddFactory.createSparkRdd(
sparkContext,
session,
subPlan.getFragment(),
rddInputs.build(),
broadcastInputs.build(),
taskExecutorFactoryProvider,
taskInfoCollector,
shuffleStatsCollector,
tableWriteInfo,
outputType);
// Intermediate, non-broadcast stages produce PrestoSparkMutableRow,
// so their output is shuffled through a partitioned RDD
if (outputType == PrestoSparkMutableRow.class) {
rdd = (JavaPairRDD<MutablePartitionId, T>) partitionBy(subPlan.getFragment().getId().getId(), (JavaPairRDD<MutablePartitionId, PrestoSparkMutableRow>) rdd, subPlan.getFragment().getPartitioningScheme());
}
RddAndMore rddAndMore = new RddAndMore<T>(rdd, broadcastDependencies.build(), Optional.ofNullable(subPlan.getFragment().getPartitioningScheme().getPartitioning().getHandle()));
fragmentIdToRdd.put(subPlan.getFragment().getId(), rddAndMore);
return rddAndMore;
}
protected Optional<RddAndMore> getRdd(PlanFragmentId planFragmentId)
{
return Optional.ofNullable(fragmentIdToRdd.get(planFragmentId));
}
// Returns output type of RDD for a subPlan
private Class getOutputType(SubPlan subPlan)
{
// Root node has SerializedPage as output
if (isRootFragment(subPlan.getFragment())) {
return PrestoSparkSerializedPage.class;
}
// A broadcast fragment outputs either SerializedPage or a storage handle, depending on how the broadcast is performed
if (isBroadcastDistribution(subPlan)) {
return getOutputTypeForBroadcastNode();
}
// Everything else produces PrestoSparkMutableRow
return PrestoSparkMutableRow.class;
}
private Class getOutputTypeForBroadcastNode()
{
if (isStorageBasedBroadcastJoinEnabled(session)) {
return PrestoSparkStorageHandle.class; // handle to data persisted in temp storage
}
else {
return PrestoSparkSerializedPage.class; // in-memory broadcast
}
}
private boolean isBroadcastDistribution(SubPlan subPlan)
{
return subPlan.getFragment().getPartitioningScheme().getPartitioning().getHandle().equals(FIXED_BROADCAST_DISTRIBUTION);
}
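// Chooses the broadcast strategy for a child RDD: native-execution storage-based, temp-storage based, or
// in-memory, each bounded by the effective maximum broadcast memory.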
private PrestoSparkBroadcastDependency<?> createBroadcastDependency(RddAndMore<?> childRdd)
{
PrestoSparkBroadcastDependency<?> broadcastDependency;
DataSize maxBroadcastMemory = getSparkBroadcastJoinMaxMemoryOverride(session);
if (maxBroadcastMemory == null) {
maxBroadcastMemory = new DataSize(min(nodeMemoryConfig.getMaxQueryBroadcastMemory().toBytes(), getQueryMaxBroadcastMemory(session).toBytes()), BYTE);
}
if (featuresConfig.isNativeExecutionEnabled()) {
return new PrestoSparkNativeStorageBasedDependency(
(RddAndMore<PrestoSparkSerializedPage>) childRdd,
maxBroadcastMemory,
queryCompletionDeadline,
waitTimeMetrics,
pagesSerde);
}
if (isStorageBasedBroadcastJoinEnabled(session)) {
validateStorageCapabilities(tempStorage);
TempDataOperationContext tempDataOperationContext = new TempDataOperationContext(
session.getSource(),
session.getQueryId().getId(),
session.getClientInfo(),
Optional.of(session.getClientTags()),
session.getIdentity());
broadcastDependency = new PrestoSparkStorageBasedBroadcastDependency(
(RddAndMore<PrestoSparkStorageHandle>) childRdd,
maxBroadcastMemory,
getQueryMaxTotalMemoryPerNode(session),
queryCompletionDeadline,
tempStorage,
tempDataOperationContext,
waitTimeMetrics);
}
else {
broadcastDependency = new PrestoSparkMemoryBasedBroadcastDependency(
(RddAndMore<PrestoSparkSerializedPage>) childRdd,
maxBroadcastMemory,
queryCompletionDeadline,
waitTimeMetrics);
}
return broadcastDependency;
}
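// Creates the RDD for a sub-plan and, when it has a non-empty shuffle dependency, submits the corresponding
// map stage so that MapOutputStatistics become available through the returned future.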
@VisibleForTesting
public FragmentExecutionResult executeFragment(SubPlan plan,
TableWriteInfo tableWriteInfo,
Optional<Class<?>> outputType)
throws SparkException, TimeoutException
{
RddAndMore rddAndMore = createRddForSubPlan(plan, tableWriteInfo, outputType);
List<ShuffleDependency> shuffleDependencies = rddAndMore.getShuffleDependencies();
SimpleFutureAction<MapOutputStatistics> mapOutputStatisticsFutureAction = null;
// For Presto on Spark we don't expect more than one shuffle dependency per fragment.
verify(shuffleDependencies.size() <= 1, "More than 1 shuffle dependency found");
if (!shuffleDependencies.isEmpty()
// We can only submit a map stage for an RDD with at least one partition (non-empty tables)
&& shuffleDependencies.get(0).rdd().partitions().length > 0) {
ShuffleDependency shuffleDependency = shuffleDependencies.get(0);
mapOutputStatisticsFutureAction = sparkContext.sc().submitMapStage(shuffleDependency);
}
return new FragmentExecutionResult(rddAndMore, Optional.ofNullable(mapOutputStatisticsFutureAction));
}
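// Groups shuffle statistics by fragment id and operation; Comparable so the TreeMap logs summaries in order.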
private static class ShuffleStatsKey
implements Comparable<ShuffleStatsKey>
{
private final int fragmentId;
private final PrestoSparkShuffleStats.Operation operation;
public ShuffleStatsKey(int fragmentId, PrestoSparkShuffleStats.Operation operation)
{
this.fragmentId = fragmentId;
this.operation = requireNonNull(operation, "operation is null");
}
public int getFragmentId()
{
return fragmentId;
}
public PrestoSparkShuffleStats.Operation getOperation()
{
return operation;
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
ShuffleStatsKey that = (ShuffleStatsKey) o;
return fragmentId == that.fragmentId &&
operation == that.operation;
}
@Override
public int hashCode()
{
return Objects.hash(fragmentId, operation);
}
@Override
public int compareTo(ShuffleStatsKey that)
{
return ComparisonChain.start()
.compare(this.fragmentId, that.fragmentId)
.compare(this.operation, that.operation)
.result();
}
}
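// Logs the average executor bootstrap durations, if a bootstrap metrics accumulator was provided.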
private void processBootstrapStats()
{
if (!this.bootstrapMetricsCollector.isPresent()) {
return;
}
List<Map<String, Long>> bootstrapStats = this.bootstrapMetricsCollector.get().value();
int loggedBootstrapCount = bootstrapStats.size();
if (loggedBootstrapCount > 0) {
Set<String> statsKeySet = bootstrapStats.get(0).keySet();
StringBuilder metricsLog = new StringBuilder();
metricsLog.append("Average executor bootstrap durations in milliseconds: \n");
for (String statsKey : statsKeySet) {
double avgDuration = 0.0;
for (int i = 0; i < loggedBootstrapCount; i++) {
avgDuration = (avgDuration * i + bootstrapStats.get(i).get(statsKey)) / (i + 1);
}
metricsLog.append(String.format("%s: %.2f \n", statsKey, avgDuration));
}
log.info(metricsLog.toString());
}
else {
log.info("No entry found in bootstrapMetricsCollector");
}
}
private void tuneMaxExecutorsCount()
{
// Executor allocation is currently only supported at the root level of the plan.
// In the future this could be extended to fragment-level configuration.
if (planAndMore.getPhysicalResourceSettings().isMaxExecutorCountAutoTuned()) {
sparkContext.sc().conf().set(SPARK_DYNAMIC_ALLOCATION_MAX_EXECUTORS_CONFIG,
Integer.toString(planAndMore.getPhysicalResourceSettings().getMaxExecutorCount()));
}
}
}