/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.spark.execution;

import com.facebook.airlift.json.Codec;
import com.facebook.airlift.json.JsonCodec;
import com.facebook.airlift.log.Logger;
import com.facebook.presto.Session;
import com.facebook.presto.cost.FragmentStatsProvider;
import com.facebook.presto.cost.HistoryBasedPlanStatisticsTracker;
import com.facebook.presto.event.QueryMonitor;
import com.facebook.presto.execution.QueryManagerConfig;
import com.facebook.presto.execution.QueryStateTimer;
import com.facebook.presto.execution.TaskInfo;
import com.facebook.presto.execution.scheduler.TableWriteInfo;
import com.facebook.presto.memory.NodeMemoryConfig;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spark.ErrorClassifier;
import com.facebook.presto.spark.PrestoSparkMetadataStorage;
import com.facebook.presto.spark.PrestoSparkQueryData;
import com.facebook.presto.spark.PrestoSparkQueryStatusInfo;
import com.facebook.presto.spark.PrestoSparkServiceWaitTimeMetrics;
import com.facebook.presto.spark.PrestoSparkTaskDescriptor;
import com.facebook.presto.spark.RddAndMore;
import com.facebook.presto.spark.classloader_interface.MutablePartitionId;
import com.facebook.presto.spark.classloader_interface.PrestoSparkSerializedPage;
import com.facebook.presto.spark.classloader_interface.PrestoSparkShuffleStats;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskExecutorFactoryProvider;
import com.facebook.presto.spark.classloader_interface.SerializedTaskInfo;
import com.facebook.presto.spark.execution.task.PrestoSparkTaskExecutorFactory;
import com.facebook.presto.spark.node.PrestoSparkNodePartitioningManager;
import com.facebook.presto.spark.planner.IterativePlanFragmenter;
import com.facebook.presto.spark.planner.PrestoSparkPlanFragmenter;
import com.facebook.presto.spark.planner.PrestoSparkQueryPlanner.PlanAndMore;
import com.facebook.presto.spark.planner.PrestoSparkRddFactory;
import com.facebook.presto.spark.planner.optimizers.AdaptivePlanOptimizers;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.page.PagesSerde;
import com.facebook.presto.spi.plan.OutputNode;
import com.facebook.presto.spi.plan.PlanFragmentId;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.storage.TempStorage;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.planner.PartitioningProviderManager;
import com.facebook.presto.sql.planner.Plan;
import com.facebook.presto.sql.planner.SubPlan;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizer;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.RemoteSourceNode;
import com.facebook.presto.sql.planner.sanity.PlanChecker;
import com.facebook.presto.sql.planner.sanity.PlanCheckerProviderManager;
import com.facebook.presto.transaction.TransactionManager;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.UncheckedExecutionException;
import io.airlift.units.Duration;
import org.apache.spark.MapOutputStatistics;
import org.apache.spark.SimpleFutureAction;
import org.apache.spark.SparkException;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.util.CollectionAccumulator;
import org.apache.spark.util.ThreadUtils;
import scala.Tuple2;
import scala.concurrent.ExecutionContextExecutorService;
import scala.concurrent.impl.ExecutionContextImpl;
import scala.runtime.AbstractFunction1;
import scala.util.Try;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeoutException;
import java.util.function.Function;
import java.util.stream.Collectors;
import static com.facebook.presto.execution.QueryState.PLANNING;
import static com.facebook.presto.spark.PrestoSparkQueryExecutionFactory.createQueryInfo;
import static com.facebook.presto.spark.PrestoSparkQueryExecutionFactory.createStageInfo;
import static com.facebook.presto.spark.execution.RuntimeStatistics.createRuntimeStats;
import static com.facebook.presto.spark.util.PrestoSparkUtils.computeNextTimeout;
import static com.facebook.presto.sql.planner.PlanFragmenterUtils.isCoordinatorOnlyDistribution;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION;
import static com.facebook.presto.sql.planner.planPrinter.PlanPrinter.textLogicalPlan;
import static com.facebook.presto.sql.planner.planPrinter.PlanPrinter.textPlanFragment;
import static com.google.common.base.Throwables.propagateIfPossible;
import static com.google.common.base.Verify.verify;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.MILLISECONDS;

/**
* This class drives the adaptive query execution of a Presto-on-Spark query.
* It iteratively generates fragments for the query, executes each fragment as soon as it becomes ready
* (a fragment is ready when all of its dependencies have been executed), extracts statistics from the
* executed fragments, and uses those statistics to re-optimize the part of the query plan that has not
* yet been executed before continuing the execution.
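* <p>
* The execution loop in {@code doExecute()} is: (1) submit every ready fragment for execution, (2) block
* until one of the submitted fragments completes, (3) publish its map output statistics to the
* {@link FragmentStatsProvider}, (4) re-run the adaptive optimizer rules on the remaining plan, and
* (5) ask the fragmenter for the next set of ready fragments. The loop ends when only the final (root)
* fragment remains.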
*/
public class PrestoSparkAdaptiveQueryExecution
extends AbstractPrestoSparkQueryExecution
{
private static final Logger log = Logger.get(PrestoSparkAdaptiveQueryExecution.class);

private final IterativePlanFragmenter iterativePlanFragmenter;
private final List<PlanOptimizer> adaptivePlanOptimizers;
private final VariableAllocator variableAllocator;
private final PlanNodeIdAllocator idAllocator;
private final FragmentStatsProvider fragmentStatsProvider;

/**
* Set with the IDs of the fragments that have finished execution.
*/
private final Set<PlanFragmentId> executedFragments = ConcurrentHashMap.newKeySet();

/**
* Queue with events related to the execution of plan fragments.
*/
private final BlockingQueue<FragmentCompletionEvent> fragmentEventQueue = new LinkedBlockingQueue<>();

// TODO: Bring over from the AbstractPrestoSparkQueryExecution the methods that are specific to adaptive execution.
public PrestoSparkAdaptiveQueryExecution(
JavaSparkContext sparkContext,
Session session,
QueryMonitor queryMonitor,
CollectionAccumulator<SerializedTaskInfo> taskInfoCollector,
CollectionAccumulator<PrestoSparkShuffleStats> shuffleStatsCollector,
PrestoSparkTaskExecutorFactory taskExecutorFactory,
PrestoSparkTaskExecutorFactoryProvider taskExecutorFactoryProvider,
QueryStateTimer queryStateTimer,
WarningCollector warningCollector,
String query,
PlanAndMore planAndMore,
Optional<String> sparkQueueName,
Codec<TaskInfo> taskInfoCodec,
JsonCodec<PrestoSparkTaskDescriptor> sparkTaskDescriptorJsonCodec,
JsonCodec<PrestoSparkQueryStatusInfo> queryStatusInfoJsonCodec,
JsonCodec<PrestoSparkQueryData> queryDataJsonCodec,
PrestoSparkRddFactory rddFactory,
TransactionManager transactionManager,
PagesSerde pagesSerde,
PrestoSparkExecutionExceptionFactory executionExceptionFactory,
Duration queryTimeout,
long queryCompletionDeadline,
PrestoSparkMetadataStorage metadataStorage,
Optional<String> queryStatusInfoOutputLocation,
Optional<String> queryDataOutputLocation,
TempStorage tempStorage,
NodeMemoryConfig nodeMemoryConfig,
FeaturesConfig featuresConfig,
QueryManagerConfig queryManagerConfig,
Set<PrestoSparkServiceWaitTimeMetrics> waitTimeMetrics,
Optional<ErrorClassifier> errorClassifier,
PrestoSparkPlanFragmenter planFragmenter,
Metadata metadata,
PartitioningProviderManager partitioningProviderManager,
HistoryBasedPlanStatisticsTracker historyBasedPlanStatisticsTracker,
AdaptivePlanOptimizers adaptivePlanOptimizers,
VariableAllocator variableAllocator,
PlanNodeIdAllocator idAllocator,
FragmentStatsProvider fragmentStatsProvider,
Optional<CollectionAccumulator<Map<String, Long>>> bootstrapMetricsCollector,
PlanCheckerProviderManager planCheckerProviderManager)
{
super(
sparkContext,
session,
queryMonitor,
taskInfoCollector,
shuffleStatsCollector,
taskExecutorFactory,
taskExecutorFactoryProvider,
queryStateTimer,
warningCollector,
query,
planAndMore,
sparkQueueName,
taskInfoCodec,
sparkTaskDescriptorJsonCodec,
queryStatusInfoJsonCodec,
queryDataJsonCodec,
rddFactory,
transactionManager,
pagesSerde,
executionExceptionFactory,
queryTimeout,
queryCompletionDeadline,
metadataStorage,
queryStatusInfoOutputLocation,
queryDataOutputLocation,
tempStorage,
nodeMemoryConfig,
featuresConfig,
queryManagerConfig,
waitTimeMetrics,
errorClassifier,
planFragmenter,
metadata,
partitioningProviderManager,
historyBasedPlanStatisticsTracker,
bootstrapMetricsCollector);
this.fragmentStatsProvider = requireNonNull(fragmentStatsProvider, "fragmentStatsProvider is null");
this.adaptivePlanOptimizers = requireNonNull(adaptivePlanOptimizers, "adaptivePlanOptimizers is null").getAdaptiveOptimizers();
this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator is null");
this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
this.iterativePlanFragmenter = createIterativePlanFragmenter(requireNonNull(planCheckerProviderManager, "planCheckerProviderManager is null"));
}
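
/**
* Creates the {@link IterativePlanFragmenter} that incrementally cuts ready-to-execute fragments out of
* the remaining plan. A fragment is considered finished once its ID has been added to
* {@link #executedFragments}.
*/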
private IterativePlanFragmenter createIterativePlanFragmenter(PlanCheckerProviderManager planCheckerProviderManager)
{
boolean noExchange = false;
Function<PlanFragmentId, Boolean> isFragmentFinished = this.executedFragments::contains;
// TODO Create the IterativePlanFragmenter by injection (it has to become stateless first--check PR 18811).
return new IterativePlanFragmenter(
this.planAndMore.getPlan(),
isFragmentFinished,
this.metadata,
new PlanChecker(this.featuresConfig, noExchange, planCheckerProviderManager),
this.idAllocator,
new PrestoSparkNodePartitioningManager(this.partitioningProviderManager),
this.queryManagerConfig,
this.session,
this.warningCollector,
noExchange);
}
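
/**
* Runs the adaptive execution loop described in the class javadoc: ready fragments are submitted for
* execution, their completion events are consumed from {@link #fragmentEventQueue}, and the remaining
* plan is re-optimized and re-fragmented after every completed fragment.
*/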
@Override
protected List<Tuple2<MutablePartitionId, PrestoSparkSerializedPage>> doExecute()
throws SparkException, TimeoutException
{
queryStateTimer.beginRunning();
log.info("Using adaptive query execution");
log.info("Logical plan : %s",
textLogicalPlan(this.planAndMore.getPlan().getRoot(), this.planAndMore.getPlan().getTypes(), this.planAndMore.getPlan().getStatsAndCosts(), metadata.getFunctionAndTypeManager(), session, 0));
queryMonitor.queryUpdatedEvent(
createQueryInfo(
session,
query,
PLANNING,
Optional.of(planAndMore),
sparkQueueName,
Optional.empty(),
queryStateTimer,
Optional.of(createStageInfo(session.getQueryId(), planFragmenter.fragmentQueryPlan(session, planAndMore.getPlan(), warningCollector), ImmutableList.of())),
warningCollector));
IterativePlanFragmenter.PlanAndFragments planAndFragments = iterativePlanFragmenter.createReadySubPlans(this.planAndMore.getPlan().getRoot());
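// The executor service runs the completion callbacks of the Scala shuffle futures registered below.
// It is only needed while part of the plan remains to be fragmented and executed.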
ExecutionContextExecutorService executorService = !planAndFragments.hasRemainingPlan() ? null :
(ExecutionContextExecutorService) ExecutionContextImpl.fromExecutorService(ThreadUtils.newDaemonCachedThreadPool("AdaptiveExecution", 16, 60), null);
TableWriteInfo tableWriteInfo = getTableWriteInfo(session, this.planAndMore.getPlan().getRoot());
while (planAndFragments.hasRemainingPlan()) {
List<SubPlan> readyFragments = planAndFragments.getReadyFragments();
Set<PlanFragmentId> rootChildren = getRootChildNodeFragmentIDs(planAndFragments.getRemainingPlan().get());
for (SubPlan fragment : readyFragments) {
log.info("Executing fragment : %s",
textPlanFragment(fragment.getFragment(), metadata.getFunctionAndTypeManager(), session, true));
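// Fragments that feed a coordinator-only root must materialize their output as serialized pages so
// that executeFinalFragment() can collect them on the driver.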
Optional<Class<?>> outputType = Optional.empty();
if (isCoordinatorOnly(this.planAndMore.getPlan()) && rootChildren.contains(fragment.getFragment().getId())) {
outputType = Optional.of(PrestoSparkSerializedPage.class);
}
SubPlan currentFragment = configureOutputPartitioning(session, fragment, planAndMore.getPhysicalResourceSettings().getHashPartitionCount());
FragmentExecutionResult fragmentExecutionResult = executeFragment(currentFragment, tableWriteInfo, outputType);
// Create the corresponding event when the fragment finishes execution (successfully or not) and place it in the event queue.
// Note that these are Scala futures that we manipulate here in Java.
Optional<SimpleFutureAction<MapOutputStatistics>> fragmentFuture = fragmentExecutionResult.getMapOutputStatisticsFutureAction();
if (fragmentFuture.isPresent()) {
SimpleFutureAction<MapOutputStatistics> mapOutputStatsFuture = fragmentFuture.get();
mapOutputStatsFuture.onComplete(new AbstractFunction1<Try<MapOutputStatistics>, Void>()
{
@Override
public Void apply(Try<MapOutputStatistics> result)
{
if (result.isSuccess()) {
Optional<MapOutputStatistics> mapOutputStats = Optional.ofNullable(result.get());
publishFragmentCompletionEvent(new FragmentCompletionSuccessEvent(currentFragment.getFragment().getId(), mapOutputStats));
}
else {
Throwable throwable = result.failed().get();
publishFragmentCompletionEvent(new FragmentCompletionFailureEvent(currentFragment.getFragment().getId(), throwable));
}
return null;
}
}, executorService);
}
else {
log.info("Fragment %s was not submitted as a shuffle stage, either because no shuffle exchange is involved (e.g., its output is broadcast) or because of an unknown issue.",
fragment.getFragment().getId());
// Mark this fragment/non-shuffle stage as completed to continue next stage plan generation
publishFragmentCompletionEvent(new FragmentCompletionSuccessEvent(currentFragment.getFragment().getId(), Optional.empty()));
}
}
// Consume the next fragment execution completion event (block if no new fragment execution has finished) and re-optimize if possible.
FragmentCompletionEvent fragmentEvent;
try {
fragmentEvent = fragmentEventQueue.poll(computeNextTimeout(queryCompletionDeadline), MILLISECONDS);
}
catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
}
// In case poll() timed out without getting an event in the queue.
if (fragmentEvent == null) {
throw executionExceptionFactory.toPrestoSparkExecutionException(new RuntimeException("Adaptive query execution failed due to timeout."));
}
if (fragmentEvent instanceof FragmentCompletionFailureEvent) {
FragmentCompletionFailureEvent failureEvent = (FragmentCompletionFailureEvent) fragmentEvent;
propagateIfPossible(failureEvent.getExecutionError(), SparkException.class);
propagateIfPossible(failureEvent.getExecutionError(), RuntimeException.class);
throw new UncheckedExecutionException(failureEvent.getExecutionError());
}
verify(fragmentEvent instanceof FragmentCompletionSuccessEvent, "Unexpected FragmentCompletionEvent type: %s", fragmentEvent.getClass().getSimpleName());
FragmentCompletionSuccessEvent successEvent = (FragmentCompletionSuccessEvent) fragmentEvent;
executedFragments.add(successEvent.getFragmentId());
// Publish the fragment's runtime statistics to the fragmentStatsProvider so the adaptive optimizers can use them.
createRuntimeStats(successEvent.getMapOutputStats()).ifPresent(
stats -> fragmentStatsProvider.putStats(session.getQueryId(), successEvent.getFragmentId(), stats));
// Re-optimize the remaining plan using the newly available statistics.
PlanNode optimizedPlan = planAndFragments.getRemainingPlan().get();
for (PlanOptimizer optimizer : adaptivePlanOptimizers) {
optimizedPlan = optimizer.optimize(optimizedPlan, session, TypeProvider.viewOf(variableAllocator.getVariables()), variableAllocator, idAllocator, warningCollector).getPlanNode();
}
if (!optimizedPlan.equals(planAndFragments.getRemainingPlan().get())) {
log.info("Adaptive plan optimizations changed the remaining plan");
}
// Call the iterative fragmenter on the remaining plan that has not yet been submitted for execution.
planAndFragments = iterativePlanFragmenter.createReadySubPlans(optimizedPlan);
}
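
// Only the final (root) fragment remains; execute it and collect the query result.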
verify(planAndFragments.getReadyFragments().size() == 1, "The last step of the adaptive execution is expected to have a single fragment remaining.");
SubPlan finalFragment = planAndFragments.getReadyFragments().get(0);
setFinalFragmentedPlan(finalFragment);
return executeFinalFragment(finalFragment, tableWriteInfo);
}
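
/**
* Returns the IDs of the fragments feeding the remote sources reachable from the given root without
* crossing a remote streaming exchange.
*/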
private static Set<PlanFragmentId> getRootChildNodeFragmentIDs(PlanNode rootPlanNode)
{
return PlanNodeSearcher.searchFrom(rootPlanNode)
.recurseOnlyWhen(node -> !(node instanceof ExchangeNode && ((ExchangeNode) node).getScope() == ExchangeNode.Scope.REMOTE_STREAMING))
.where(node -> node instanceof RemoteSourceNode)
.findAll()
.stream()
.map(node -> ((RemoteSourceNode) node).getSourceFragmentIds())
.flatMap(List::stream)
.collect(Collectors.toSet());
}
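
/**
* Returns true if the output of the plan is produced by a coordinator-only distribution
* (for example, DDL or other metadata-only operations).
*/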
private static boolean isCoordinatorOnly(Plan plan)
{
if (!(plan.getRoot() instanceof OutputNode)) {
return false;
}
PlanNode outputSourceNode = ((OutputNode) plan.getRoot()).getSource();
return isCoordinatorOnlyDistribution(outputSourceNode);
}
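
/**
* Places the given event in {@link #fragmentEventQueue}, from which the event loop in
* {@code doExecute()} consumes it.
*/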
private void publishFragmentCompletionEvent(FragmentCompletionEvent fragmentCompletionEvent)
{
try {
fragmentEventQueue.put(fragmentCompletionEvent);
}
catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
}
}

/**
* Execute the final fragment of the plan and collect the result.
*/
private List<Tuple2<MutablePartitionId, PrestoSparkSerializedPage>> executeFinalFragment(SubPlan finalFragment, TableWriteInfo tableWriteInfo)
throws SparkException, TimeoutException
{
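// A coordinator-only final fragment is executed locally on the Spark driver: collect the serialized
// pages produced by every child fragment and feed them to the fragment as its input.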
if (finalFragment.getFragment().getPartitioning().equals(COORDINATOR_DISTRIBUTION)) {
Map<PlanFragmentId, RddAndMore<PrestoSparkSerializedPage>> inputRdds = new HashMap<>();
for (SubPlan child : finalFragment.getChildren()) {
inputRdds.put(child.getFragment().getId(), getRdd(child.getFragment().getId()).get());
}
return collectPages(tableWriteInfo, finalFragment.getFragment(), inputRdds);
}
RddAndMore<PrestoSparkSerializedPage> rddAndMore = createRddForSubPlan(finalFragment, tableWriteInfo, Optional.of(PrestoSparkSerializedPage.class));
return rddAndMore.collectAndDestroyDependenciesWithTimeout(computeNextTimeout(queryCompletionDeadline), MILLISECONDS, waitTimeMetrics);
}

/**
* Event for the completion of a fragment's execution.
*/
private abstract static class FragmentCompletionEvent
{
protected final PlanFragmentId fragmentId;

private FragmentCompletionEvent(PlanFragmentId fragmentId)
{
this.fragmentId = fragmentId;
}

public PlanFragmentId getFragmentId()
{
return fragmentId;
}
}
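
/**
* Signals that a fragment finished successfully, optionally carrying the map output statistics of its
* shuffle.
*/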
private static class FragmentCompletionSuccessEvent
extends FragmentCompletionEvent
{
private final Optional<MapOutputStatistics> mapOutputStats;

private FragmentCompletionSuccessEvent(PlanFragmentId fragmentId, Optional<MapOutputStatistics> mapOutputStats)
{
super(fragmentId);
this.mapOutputStats = requireNonNull(mapOutputStats, "mapOutputStats is null");
}

public Optional<MapOutputStatistics> getMapOutputStats()
{
return mapOutputStats;
}
}
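
/**
* Signals that the execution of a fragment failed with the given error.
*/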
private static class FragmentCompletionFailureEvent
extends FragmentCompletionEvent
{
private final Throwable executionError;

private FragmentCompletionFailureEvent(PlanFragmentId fragmentId, Throwable executionError)
{
super(fragmentId);
this.executionError = requireNonNull(executionError, "executionError is null");
}

public Throwable getExecutionError()
{
return executionError;
}
}
}