ContinuousTaskStatusFetcher.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.server.remotetask;
import com.facebook.airlift.concurrent.SetThreadName;
import com.facebook.airlift.http.client.HttpClient;
import com.facebook.airlift.http.client.Request;
import com.facebook.airlift.http.client.ResponseHandler;
import com.facebook.airlift.http.client.thrift.ThriftRequestUtils;
import com.facebook.airlift.http.client.thrift.ThriftResponseHandler;
import com.facebook.airlift.json.Codec;
import com.facebook.airlift.json.JsonCodec;
import com.facebook.airlift.json.smile.SmileCodec;
import com.facebook.airlift.log.Logger;
import com.facebook.drift.transport.netty.codec.Protocol;
import com.facebook.presto.execution.StateMachine;
import com.facebook.presto.execution.TaskId;
import com.facebook.presto.execution.TaskStatus;
import com.facebook.presto.server.RequestErrorTracker;
import com.facebook.presto.server.SimpleHttpResponseCallback;
import com.facebook.presto.server.SimpleHttpResponseHandler;
import com.facebook.presto.server.smile.BaseResponse;
import com.facebook.presto.server.thrift.ThriftHttpResponseHandler;
import com.facebook.presto.spi.HostAddress;
import com.facebook.presto.spi.PrestoException;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.units.Duration;
import javax.annotation.concurrent.GuardedBy;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import static com.facebook.airlift.http.client.HttpUriBuilder.uriBuilderFrom;
import static com.facebook.airlift.http.client.Request.Builder.prepareGet;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_CURRENT_STATE;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_MAX_WAIT;
import static com.facebook.presto.server.RequestErrorTracker.taskRequestErrorTracker;
import static com.facebook.presto.server.RequestHelpers.getBinaryTransportBuilder;
import static com.facebook.presto.server.RequestHelpers.getJsonTransportBuilder;
import static com.facebook.presto.server.smile.AdaptingJsonResponseHandler.createAdaptingJsonResponseHandler;
import static com.facebook.presto.server.smile.FullSmileResponseHandler.createFullSmileResponseHandler;
import static com.facebook.presto.server.thrift.ThriftCodecWrapper.unwrapThriftCodec;
import static com.facebook.presto.spi.StandardErrorCode.REMOTE_TASK_ERROR;
import static com.facebook.presto.spi.StandardErrorCode.REMOTE_TASK_MISMATCH;
import static com.facebook.presto.util.Failures.REMOTE_TASK_MISMATCH_ERROR;
import static io.airlift.units.Duration.nanosSince;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
/**
 * Repeatedly fetches the {@link TaskStatus} of a remote task and publishes every accepted update
 * through an internal {@link StateMachine}.
 *
 * <p>After {@link #start()} a GET request is sent to the {@code status} path below the task's self URI,
 * carrying the last known task state and the maximum wait as headers. The {@code success} and
 * {@code failed} callbacks each schedule the next request; {@code fatal} does not, so the request
 * chain ends there. The chain also ends once {@link #stop()} has been called or the task status
 * reports a done state.
 *
 * <p>The transport is selected per request: Thrift when {@code thriftTransportEnabled} is set, otherwise
 * Smile when {@code binaryTransportEnabled} is set, otherwise JSON. The supplied codec is cast to
 * {@code SmileCodec} or {@code JsonCodec} accordingly, so it has to match the selected transport.
 *
 * <p>{@code running} and {@code future} are guarded by this instance's monitor.
 */
class ContinuousTaskStatusFetcher
        implements SimpleHttpResponseCallback<TaskStatus>
{
    private static final Logger log = Logger.get(ContinuousTaskStatusFetcher.class);

    // format handed to SetThreadName while a response callback is being processed
    private static final String CALLBACK_THREAD_NAME_FORMAT = "ContinuousTaskStatusFetcher-%s";

    private final TaskId taskId;
    private final Consumer<Throwable> onFail;
    private final StateMachine<TaskStatus> taskStatus;
    private final Codec<TaskStatus> taskStatusCodec;

    private final Duration refreshMaxWait;
    private final Executor executor;
    private final HttpClient httpClient;
    private final RequestErrorTracker errorTracker;
    private final RemoteTaskStats stats;
    private final boolean binaryTransportEnabled;
    private final boolean thriftTransportEnabled;
    private final Protocol thriftProtocol;

    // System.nanoTime() reading taken right after the most recent request was dispatched
    private final AtomicLong currentRequestStartNanos = new AtomicLong();

    @GuardedBy("this")
    private boolean running;

    @GuardedBy("this")
    private ListenableFuture<BaseResponse<TaskStatus>> future;

    /**
     * Creates a fetcher that is initially stopped; call {@link #start()} to begin fetching.
     *
     * @param onFail receives fatal errors, unexpected exceptions raised while handling a failure, and task instance mismatches
     * @param taskId id of the task, used for the state machine name, error tracking and thread naming
     * @param initialTaskStatus status the state machine starts with; its self URI is used for error tracking
     * @param refreshMaxWait value sent in the {@code PRESTO_MAX_WAIT} header of every request
     * @param taskStatusCodec codec for the response body; must match the enabled transport
     * @param executor runs response callbacks, rate-limit retries and state machine notifications
     * @param httpClient client used to send the status requests
     * @param maxErrorDuration passed to the request error tracker
     * @param errorScheduledExecutor passed to the request error tracker
     * @param stats receives HTTP response statistics and the status round trip time
     * @param binaryTransportEnabled use Smile instead of JSON (ignored when Thrift is enabled)
     * @param thriftTransportEnabled use Thrift; takes precedence over {@code binaryTransportEnabled}
     * @param thriftProtocol Thrift protocol used to prepare requests when Thrift is enabled
     * @throws NullPointerException if any of the arguments validated below is null
     */
    public ContinuousTaskStatusFetcher(
            Consumer<Throwable> onFail,
            TaskId taskId,
            TaskStatus initialTaskStatus,
            Duration refreshMaxWait,
            Codec<TaskStatus> taskStatusCodec,
            Executor executor,
            HttpClient httpClient,
            Duration maxErrorDuration,
            ScheduledExecutorService errorScheduledExecutor,
            RemoteTaskStats stats,
            boolean binaryTransportEnabled,
            boolean thriftTransportEnabled,
            Protocol thriftProtocol)
    {
        requireNonNull(initialTaskStatus, "initialTaskStatus is null");

        this.taskId = requireNonNull(taskId, "taskId is null");
        this.onFail = requireNonNull(onFail, "onFail is null");
        // validate the executor before it is handed to the state machine, so a null is reported here
        this.executor = requireNonNull(executor, "executor is null");
        this.taskStatus = new StateMachine<>("task-" + taskId, executor, initialTaskStatus);
        this.refreshMaxWait = requireNonNull(refreshMaxWait, "refreshMaxWait is null");
        this.taskStatusCodec = requireNonNull(taskStatusCodec, "taskStatusCodec is null");
        this.httpClient = requireNonNull(httpClient, "httpClient is null");

        this.errorTracker = taskRequestErrorTracker(taskId, initialTaskStatus.getSelf(), maxErrorDuration, errorScheduledExecutor, "getting task status");
        this.stats = requireNonNull(stats, "stats is null");
        this.binaryTransportEnabled = binaryTransportEnabled;
        this.thriftTransportEnabled = thriftTransportEnabled;
        this.thriftProtocol = requireNonNull(thriftProtocol, "thriftProtocol is null");
    }

    /**
     * Starts the fetch loop. Calling this while already running has no effect.
     */
    public synchronized void start()
    {
        if (running) {
            // already running
            return;
        }
        running = true;
        scheduleNextRequest();
    }

    /**
     * Stops the fetch loop and cancels the outstanding request, if there is one.
     */
    public synchronized void stop()
    {
        running = false;
        if (future != null) {
            // do not terminate if the request is already running to avoid closing pooled connections
            future.cancel(false);
            future = null;
        }
    }

    /**
     * Sends the next status request unless the fetcher is stopped, the task is done, a request is
     * still outstanding, or the error tracker asks to wait (in which case this method is re-run on
     * {@code executor} once the permit becomes available).
     */
    private synchronized void scheduleNextRequest()
    {
        // stopped or done?
        TaskStatus currentStatus = getTaskStatus();
        if (!running || currentStatus.getState().isDone()) {
            return;
        }

        // outstanding request?
        if (future != null && !future.isDone()) {
            // this should never happen
            log.error("Can not reschedule update because an update is already running");
            return;
        }

        // if throttled due to error, asynchronously wait for timeout and try again
        ListenableFuture<?> errorRateLimit = errorTracker.acquireRequestPermit();
        if (!errorRateLimit.isDone()) {
            errorRateLimit.addListener(this::scheduleNextRequest, executor);
            return;
        }

        // NOTE(review): the handler and the callback are deliberately left as raw types; presumably the
        // Thrift handler's response type does not fit the declared element type of `future` -- confirm
        // before adding type arguments here.
        Request.Builder requestBuilder;
        ResponseHandler responseHandler;
        if (thriftTransportEnabled) {
            requestBuilder = ThriftRequestUtils.prepareThriftGet(thriftProtocol);
            responseHandler = new ThriftResponseHandler(unwrapThriftCodec(taskStatusCodec));
        }
        else if (binaryTransportEnabled) {
            requestBuilder = getBinaryTransportBuilder(prepareGet());
            responseHandler = createFullSmileResponseHandler((SmileCodec<TaskStatus>) taskStatusCodec);
        }
        else {
            requestBuilder = getJsonTransportBuilder(prepareGet());
            responseHandler = createAdaptingJsonResponseHandler((JsonCodec<TaskStatus>) taskStatusCodec);
        }

        // the request reports the last known state and how long the caller is prepared to wait
        Request request = requestBuilder.setUri(uriBuilderFrom(currentStatus.getSelf()).appendPath("status").build())
                .setHeader(PRESTO_CURRENT_STATE, currentStatus.getState().toString())
                .setHeader(PRESTO_MAX_WAIT, refreshMaxWait.toString())
                .build();

        errorTracker.startRequest();
        future = httpClient.executeAsync(request, responseHandler);
        currentRequestStartNanos.set(System.nanoTime());

        // the Thrift transport has its own callback adapter; both adapters delegate to this instance
        FutureCallback callback;
        if (thriftTransportEnabled) {
            callback = new ThriftHttpResponseHandler(this, request.getUri(), stats.getHttpResponseStats(), REMOTE_TASK_ERROR);
        }
        else {
            callback = new SimpleHttpResponseHandler<>(this, request.getUri(), stats.getHttpResponseStats(), REMOTE_TASK_ERROR);
        }

        Futures.addCallback(
                future,
                callback,
                executor);
    }

    /**
     * Returns the most recently accepted task status.
     */
    TaskStatus getTaskStatus()
    {
        return taskStatus.get();
    }

    /**
     * Handles a successful response: records the round trip time, applies the received status and
     * marks the request as succeeded. The next request is scheduled even if applying the status throws.
     */
    @Override
    public void success(TaskStatus value)
    {
        try (SetThreadName ignored = new SetThreadName(CALLBACK_THREAD_NAME_FORMAT, taskId)) {
            updateStats(currentRequestStartNanos.get());
            try {
                updateTaskStatus(value);
                errorTracker.requestSucceeded();
            }
            finally {
                scheduleNextRequest();
            }
        }
    }

    /**
     * Handles a failed request: records the round trip time and, unless the task is already done,
     * reports the failure to the error tracker. Anything thrown while doing so is passed to
     * {@code onFail}; an {@link Error} is additionally rethrown. The next request is scheduled in
     * every case.
     */
    @Override
    public void failed(Throwable cause)
    {
        try (SetThreadName ignored = new SetThreadName(CALLBACK_THREAD_NAME_FORMAT, taskId)) {
            updateStats(currentRequestStartNanos.get());
            try {
                // if task not already done, record error
                TaskStatus currentStatus = getTaskStatus();
                if (!currentStatus.getState().isDone()) {
                    errorTracker.requestFailed(cause);
                }
            }
            catch (Error e) {
                onFail.accept(e);
                throw e;
            }
            catch (RuntimeException e) {
                onFail.accept(e);
            }
            finally {
                scheduleNextRequest();
            }
        }
    }

    /**
     * Handles a fatal error: records the round trip time and passes the cause to {@code onFail}.
     * No further request is scheduled from here.
     */
    @Override
    public void fatal(Throwable cause)
    {
        try (SetThreadName ignored = new SetThreadName(CALLBACK_THREAD_NAME_FORMAT, taskId)) {
            updateStats(currentRequestStartNanos.get());

            onFail.accept(cause);
        }
    }

    /**
     * Offers a new status to the state machine. The new value is rejected when the task instance id
     * differs from a previously recorded (non-zero) one, when the current status is already done, or
     * when the new value carries an older version. An instance id mismatch is also reported to
     * {@code onFail} as a {@code REMOTE_TASK_MISMATCH} error.
     */
    void updateTaskStatus(TaskStatus newValue)
    {
        // accept the new value unless the instance id changed, the task is done, or the version went backwards
        AtomicBoolean taskMismatch = new AtomicBoolean();
        taskStatus.setIf(newValue, oldValue -> {
            // did the task instance id change (an all-zero id means none has been recorded yet)
            boolean isEmpty = oldValue.getTaskInstanceIdLeastSignificantBits() == 0 && oldValue.getTaskInstanceIdMostSignificantBits() == 0;
            if (!isEmpty &&
                    !(oldValue.getTaskInstanceIdLeastSignificantBits() == newValue.getTaskInstanceIdLeastSignificantBits() &&
                            oldValue.getTaskInstanceIdMostSignificantBits() == newValue.getTaskInstanceIdMostSignificantBits())) {
                taskMismatch.set(true);
                return false;
            }

            if (oldValue.getState().isDone()) {
                // never update if the task has reached a terminal state
                return false;
            }
            if (newValue.getVersion() < oldValue.getVersion()) {
                // don't update to an older version (same version is ok)
                return false;
            }
            return true;
        });

        if (taskMismatch.get()) {
            // This will also set the task status to FAILED state directly.
            // Additionally, this will issue a DELETE for the task to the worker.
            // While sending the DELETE is not required, it is preferred because a task was created by the previous request.
            onFail.accept(new PrestoException(REMOTE_TASK_MISMATCH, format("%s (%s)", REMOTE_TASK_MISMATCH_ERROR, HostAddress.fromUri(getTaskStatus().getSelf()))));
        }
    }

    /**
     * Returns whether the fetcher has been started and not stopped since.
     */
    public synchronized boolean isRunning()
    {
        return running;
    }

    /**
     * Listener is always notified asynchronously using a dedicated notification thread pool, so care should
     * be taken to avoid leaking {@code this} when adding a listener in a constructor. Additionally, it is
     * possible notifications are observed out of order due to the asynchronous execution.
     */
    public void addStateChangeListener(StateMachine.StateChangeListener<TaskStatus> stateChangeListener)
    {
        taskStatus.addStateChangeListener(stateChangeListener);
    }

    /**
     * Reports the time elapsed since {@code requestStartNanos} as the status round trip time in milliseconds.
     */
    private void updateStats(long requestStartNanos)
    {
        stats.statusRoundTripMillis(nanosSince(requestStartNanos).toMillis());
    }
}