ConcurrentPhaseExecutor.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.benchmark.framework;
import com.facebook.airlift.event.client.EventClient;
import com.facebook.airlift.log.Logger;
import com.facebook.presto.benchmark.event.BenchmarkPhaseEvent;
import com.facebook.presto.benchmark.event.BenchmarkQueryEvent;
import com.facebook.presto.benchmark.prestoaction.PrestoActionFactory;
import com.facebook.presto.sql.parser.ParsingOptions;
import com.facebook.presto.sql.parser.SqlParser;
import com.google.inject.Inject;
import java.util.EnumMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import static com.facebook.presto.benchmark.event.BenchmarkQueryEvent.Status;
import static com.facebook.presto.benchmark.event.BenchmarkQueryEvent.Status.FAILED;
import static com.facebook.presto.benchmark.event.BenchmarkQueryEvent.Status.SUCCEEDED;
import static java.lang.String.format;
import static java.lang.Thread.currentThread;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.Executors.newFixedThreadPool;
/**
 * Executes a {@link ConcurrentExecutionPhase} by submitting every query of the phase to a fixed-size
 * thread pool and collecting the resulting {@link BenchmarkQueryEvent}s in completion order.
 * <p>
 * The pool size is resolved, in order of precedence, from the runner-level max concurrency in
 * {@link BenchmarkRunnerConfig}, the phase's own max concurrency, or {@code DEFAULT_MAX_CONCURRENCY}.
 */
public class ConcurrentPhaseExecutor
        extends AbstractPhaseExecutor<ConcurrentExecutionPhase>
{
    // Used when neither the runner config nor the phase specifies a max concurrency.
    private static final int DEFAULT_MAX_CONCURRENCY = 70;
    private static final Logger log = Logger.get(ConcurrentPhaseExecutor.class);

    // When true, a failed query does not abort the phase; failures are counted and reported at the end.
    private final boolean continueOnFailure;
    // Runner-level override; when present it takes precedence over the phase's own max concurrency.
    private final Optional<Integer> maxConcurrency;

    @Inject
    public ConcurrentPhaseExecutor(
            SqlParser sqlParser,
            ParsingOptions parsingOptions,
            PrestoActionFactory prestoActionFactory,
            Set<EventClient> eventClients,
            BenchmarkRunnerConfig config)
    {
        super(sqlParser, parsingOptions, prestoActionFactory, eventClients, config.getTestId());
        this.continueOnFailure = config.isContinueOnFailure();
        this.maxConcurrency = requireNonNull(config.getMaxConcurrency(), "maxConcurrency is null");
    }

    /**
     * Runs every query of {@code phase} concurrently, using the suite's session properties, and
     * reports the outcome of the phase.
     *
     * @param phase the phase whose queries are executed
     * @param suite the suite providing query definitions and session properties
     * @return a failed event on the first failure (unless continue-on-failure is enabled), a
     * completed-with-failures event if any query failed, or a succeeded event otherwise
     * @throws IllegalArgumentException if the resolved max concurrency is not positive
     */
    @Override
    public BenchmarkPhaseEvent runPhase(ConcurrentExecutionPhase phase, BenchmarkSuite suite)
    {
        int maxConcurrency = this.maxConcurrency.orElseGet(() -> phase.getMaxConcurrency().orElse(DEFAULT_MAX_CONCURRENCY));
        if (maxConcurrency <= 0) {
            throw new IllegalArgumentException(format("maxConcurrency must be positive for phase '%s': %s", phase.getName(), maxConcurrency));
        }
        log.info("Starting concurrent phase '%s' with max concurrency %s", phase.getName(), maxConcurrency);
        ExecutorService executor = newFixedThreadPool(maxConcurrency);
        try {
            CompletionService<BenchmarkQueryEvent> completionService = new ExecutorCompletionService<>(executor);
            for (String queryName : phase.getQueries()) {
                BenchmarkQuery benchmarkQuery = overrideSessionProperties(suite.getQueries().get(queryName), suite.getSessionProperties());
                completionService.submit(() -> runQuery(benchmarkQuery));
            }
            return reportProgressUntilFinished(phase, completionService);
        }
        finally {
            // Attempts to stop queries still running, e.g. after an early return on the first failure.
            executor.shutdownNow();
        }
    }

    /**
     * Waits for all submitted queries, posting each query event and periodically logging progress.
     *
     * @return the phase outcome; see {@link #runPhase}
     */
    private BenchmarkPhaseEvent reportProgressUntilFinished(
            ConcurrentExecutionPhase phase,
            CompletionService<BenchmarkQueryEvent> completionService)
    {
        String phaseName = phase.getName();
        int queriesSubmitted = phase.getQueries().size();
        Map<Status, Integer> statusCount = new EnumMap<>(Status.class);
        int completed = 0;
        double lastProgress = 0;

        while (completed < queriesSubmitted) {
            try {
                BenchmarkQueryEvent event = completionService.take().get();
                postEvent(event);
                completed++;
                statusCount.merge(event.getEventStatus(), 1, Integer::sum);
                if (event.getEventStatus() == FAILED && !continueOnFailure) {
                    return BenchmarkPhaseEvent.failed(phaseName, event.getErrorMessage());
                }
            }
            catch (InterruptedException e) {
                // Restore the interrupt flag. With the flag set, take() would throw again immediately,
                // so the phase cannot make progress regardless of continueOnFailure: stop here.
                currentThread().interrupt();
                return BenchmarkPhaseEvent.failed(phaseName, e.toString());
            }
            catch (ExecutionException e) {
                if (!continueOnFailure) {
                    return BenchmarkPhaseEvent.failed(phaseName, e.toString());
                }
                // The task did finish (abnormally). Count it as a failed query so the loop terminates
                // and the failure shows up in the final report.
                Throwable cause = e.getCause() == null ? e : e.getCause();
                log.warn(cause, "Query in phase '%s' failed unexpectedly", phaseName);
                completed++;
                statusCount.merge(FAILED, 1, Integer::sum);
            }
            lastProgress = logProgressIfNeeded(statusCount, completed, queriesSubmitted, lastProgress);
        }

        if (statusCount.getOrDefault(FAILED, 0) > 0) {
            return BenchmarkPhaseEvent.completedWithFailures(phaseName, format("%s out of %s submitted queries failed", statusCount.get(FAILED), queriesSubmitted));
        }
        return BenchmarkPhaseEvent.succeeded(phaseName);
    }

    /**
     * Logs a progress line when progress has advanced by more than half a percentage point since
     * the last logged value, or when all queries have completed.
     *
     * @return the progress percentage to treat as the last logged value
     */
    private static double logProgressIfNeeded(Map<Status, Integer> statusCount, int completed, int queriesSubmitted, double lastProgress)
    {
        double progress = ((double) completed) / queriesSubmitted * 100;
        if (progress - lastProgress <= 0.5 && completed != queriesSubmitted) {
            return lastProgress;
        }
        log.info("Progress: %s succeeded, %s failed, %s submitted, %.2f%% done",
                statusCount.getOrDefault(SUCCEEDED, 0),
                statusCount.getOrDefault(FAILED, 0),
                queriesSubmitted,
                progress);
        return progress;
    }
}