PrestoSparkHttpTaskClient.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.spark.execution.http;
import com.facebook.airlift.http.client.HeaderName;
import com.facebook.airlift.http.client.HttpClient;
import com.facebook.airlift.http.client.Request;
import com.facebook.airlift.http.client.Response;
import com.facebook.airlift.http.client.ResponseHandler;
import com.facebook.airlift.json.JsonCodec;
import com.facebook.presto.Session;
import com.facebook.presto.execution.TaskId;
import com.facebook.presto.execution.TaskInfo;
import com.facebook.presto.execution.TaskSource;
import com.facebook.presto.execution.buffer.OutputBuffers;
import com.facebook.presto.execution.scheduler.TableWriteInfo;
import com.facebook.presto.operator.HttpRpcShuffleClient.PageResponseHandler;
import com.facebook.presto.operator.PageBufferClient.PagesResponse;
import com.facebook.presto.server.RequestErrorTracker;
import com.facebook.presto.server.SimpleHttpResponseCallback;
import com.facebook.presto.server.SimpleHttpResponseHandler;
import com.facebook.presto.server.SimpleHttpResponseHandlerStats;
import com.facebook.presto.server.TaskUpdateRequest;
import com.facebook.presto.server.smile.BaseResponse;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.sql.planner.PlanFragment;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ListMultimap;
import com.google.common.io.ByteStreams;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import javax.annotation.concurrent.ThreadSafe;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import static com.facebook.airlift.concurrent.MoreFutures.getFutureValue;
import static com.facebook.airlift.http.client.HttpUriBuilder.uriBuilderFrom;
import static com.facebook.airlift.http.client.Request.Builder.prepareDelete;
import static com.facebook.airlift.http.client.Request.Builder.prepareGet;
import static com.facebook.airlift.http.client.Request.Builder.preparePost;
import static com.facebook.airlift.http.client.ResponseHandlerUtils.propagate;
import static com.facebook.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_MAX_SIZE;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_MAX_WAIT;
import static com.facebook.presto.server.RequestHelpers.setContentTypeHeaders;
import static com.facebook.presto.server.smile.AdaptingJsonResponseHandler.createAdaptingJsonResponseHandler;
import static com.facebook.presto.spi.StandardErrorCode.NATIVE_EXECUTION_TASK_ERROR;
import static com.facebook.presto.spi.StandardErrorCode.REMOTE_TASK_ERROR;
import static com.google.common.util.concurrent.Futures.addCallback;
import static com.google.common.util.concurrent.Futures.transformAsync;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static java.util.Objects.requireNonNull;
/**
* An abstraction over the HTTP client that communicates with the locally running Presto worker process. It exposes worker endpoints as simple method calls.
*/
@ThreadSafe
public class PrestoSparkHttpTaskClient
{
// Path segment appended to the base location (followed by the task id) to form the task URI.
private static final String TASK_URI = "/v1/task/";
// Client used to execute every request issued by this class.
private final HttpClient httpClient;
// Base URI supplied to the constructor; also passed to the error trackers and response handlers.
private final URI location;
// location + TASK_URI + task id, computed once in the constructor.
private final URI taskUri;
// Codec used to decode TaskInfo responses (getTaskInfo, updateTask).
private final JsonCodec<TaskInfo> taskInfoCodec;
// Codec used to serialize the plan fragment in updateTask.
private final JsonCodec<PlanFragment> planFragmentCodec;
// Codec used to serialize the batch update request body in updateTask.
private final JsonCodec<BatchTaskUpdateRequest> taskUpdateRequestCodec;
// Sent as the PRESTO_MAX_WAIT header by getTaskInfo.
private final Duration infoRefreshMaxWait;
// Executor on which requests are issued and response callbacks are run.
private final Executor executor;
// Scheduler handed to every RequestErrorTracker created by this class.
private final ScheduledExecutorService scheduledExecutorService;
// Maximum error duration handed to every RequestErrorTracker created by this class.
private final Duration remoteTaskMaxErrorDuration;
/**
 * Creates a client bound to a single task, addressed as {@code location + "/v1/task/" + taskId}.
 *
 * @param httpClient client used to execute every request
 * @param taskId id of the task; appended to {@code location} to form the task URI
 * @param location base URI the task URI is derived from
 * @param taskInfoCodec codec used to decode {@link TaskInfo} responses
 * @param planFragmentCodec codec used to serialize the plan fragment in {@code updateTask}
 * @param taskUpdateRequestCodec codec used to serialize the batch update request body
 * @param infoRefreshMaxWait value sent as the max-wait header by {@code getTaskInfo}
 * @param executor executor on which requests are issued and response callbacks are run
 * @param scheduledExecutorService scheduler handed to the request error trackers
 * @param remoteTaskMaxErrorDuration maximum error duration handed to the request error trackers
 * @throws NullPointerException if any argument is null
 */
public PrestoSparkHttpTaskClient(
        HttpClient httpClient,
        TaskId taskId,
        URI location,
        JsonCodec<TaskInfo> taskInfoCodec,
        JsonCodec<PlanFragment> planFragmentCodec,
        JsonCodec<BatchTaskUpdateRequest> taskUpdateRequestCodec,
        Duration infoRefreshMaxWait,
        Executor executor,
        ScheduledExecutorService scheduledExecutorService,
        Duration remoteTaskMaxErrorDuration)
{
    this.httpClient = requireNonNull(httpClient, "httpClient is null");
    this.location = requireNonNull(location, "location is null");
    this.taskInfoCodec = requireNonNull(taskInfoCodec, "taskInfoCodec is null");
    this.planFragmentCodec = requireNonNull(planFragmentCodec, "planFragmentCodec is null");
    this.taskUpdateRequestCodec = requireNonNull(taskUpdateRequestCodec, "taskUpdateRequestCodec is null");
    // Validate taskId like every other argument so a null id fails with a named message
    // instead of an anonymous NullPointerException raised while building the URI.
    this.taskUri = createTaskUri(location, requireNonNull(taskId, "taskId is null"));
    this.infoRefreshMaxWait = requireNonNull(infoRefreshMaxWait, "infoRefreshMaxWait is null");
    this.executor = requireNonNull(executor, "executor is null");
    this.scheduledExecutorService = requireNonNull(scheduledExecutorService, "scheduledExecutorService is null");
    this.remoteTaskMaxErrorDuration = requireNonNull(remoteTaskMaxErrorDuration, "remoteTaskMaxErrorDuration is null");
}
/**
 * Get results from a native engine task that ends with a non-shuffle operator. It always fetches from a single buffer.
 * <p>
 * A failed request is rescheduled unless the failure is a {@link PrestoException} or the
 * error tracker itself throws when the failure is reported to it.
 *
 * @param token token appended to the results URI
 * @param maxResponseSize value sent as the max-size request header
 * @return future completed with the pages response, or failed with the terminal error
 */
public ListenableFuture<PagesResponse> getResults(long token, DataSize maxResponseSize)
{
    RequestErrorTracker errorTracker = new RequestErrorTracker(
            "NativeExecution",
            location,
            NATIVE_EXECUTION_TASK_ERROR,
            "getResults encountered too many errors talking to native process",
            remoteTaskMaxErrorDuration,
            scheduledExecutorService,
            // Describe what this request actually does; the previous text
            // ("sending update request ...") was copied from the update path.
            "getting results from native process");
    SettableFuture<PagesResponse> result = SettableFuture.create();
    scheduleGetResultsRequest(prepareGetResultsRequest(token, maxResponseSize), errorTracker, result);
    return result;
}
/**
 * Issues {@code request} once a request permit is granted by {@code errorTracker} and
 * completes {@code result} with the outcome. A failure that is not a {@link PrestoException}
 * is reported to the tracker and the request is scheduled again; if the tracker (or the
 * rescheduling) throws, {@code result} fails with that throwable.
 */
private void scheduleGetResultsRequest(
        Request request,
        RequestErrorTracker errorTracker,
        SettableFuture<PagesResponse> result)
{
    ListenableFuture<PagesResponse> pagesFuture = transformAsync(
            errorTracker.acquireRequestPermit(),
            permit -> {
                errorTracker.startRequest();
                return httpClient.executeAsync(request, new PageResponseHandler());
            },
            executor);

    FutureCallback<PagesResponse> retryingCallback = new FutureCallback<PagesResponse>()
    {
        @Override
        public void onSuccess(PagesResponse pages)
        {
            errorTracker.requestSucceeded();
            result.set(pages);
        }

        @Override
        public void onFailure(Throwable cause)
        {
            if (!(cause instanceof PrestoException)) {
                try {
                    errorTracker.requestFailed(cause);
                    scheduleGetResultsRequest(request, errorTracker, result);
                }
                catch (Throwable t) {
                    result.setException(t);
                }
                return;
            }
            // do not retry on PrestoException
            result.setException(cause);
        }
    };
    addCallback(pagesFuture, retryingCallback, executor);
}
/**
 * Builds the GET request for {@code <taskUri>/results/0/<token>}, carrying
 * {@code maxResponseSize} in the max-size header.
 */
private Request prepareGetResultsRequest(long token, DataSize maxResponseSize)
{
    URI resultsUri = uriBuilderFrom(taskUri)
            .appendPath("/results/0")
            .appendPath(String.valueOf(token))
            .build();
    return prepareGet()
            .setHeader(PRESTO_MAX_SIZE, maxResponseSize.toString())
            .setUri(resultsUri)
            .build();
}
/**
 * Sends a GET to {@code <taskUri>/results/0/<nextToken>/acknowledge}.
 * The future returned by the retrying executor is dropped, so the caller
 * does not observe whether the acknowledgement succeeded.
 */
public void acknowledgeResultsAsync(long nextToken)
{
    Request acknowledgeRequest = prepareGet()
            .setUri(uriBuilderFrom(taskUri)
                    .appendPath("/results/0")
                    .appendPath(String.valueOf(nextToken))
                    .appendPath("acknowledge")
                    .build())
            .build();
    executeWithRetries(
            "acknowledgeResults",
            "acknowledge task results are received",
            acknowledgeRequest,
            new BytesResponseHandler());
}
/**
 * Sends a DELETE to {@code <taskUri>/results/0}.
 *
 * @return future that completes with {@code null} when the request succeeds,
 * or fails with the error produced by the retrying executor
 */
public ListenableFuture<Void> abortResultsAsync()
{
    URI resultsUri = uriBuilderFrom(taskUri)
            .appendPath("/results/0")
            .build();
    Request abortRequest = prepareDelete().setUri(resultsUri).build();
    ListenableFuture<byte[]> response = executeWithRetries(
            "abortResults",
            "abort task results",
            abortRequest,
            new BytesResponseHandler());
    return asVoidFuture(response);
}
/**
 * Adapts any future to a {@code ListenableFuture<Void>} by discarding its value.
 */
private static ListenableFuture<Void> asVoidFuture(ListenableFuture<?> future)
{
    // The value is irrelevant; only completion (or failure) is propagated.
    return Futures.transform(future, discarded -> null, directExecutor());
}
/**
 * Requests the task info from the task URI, sending {@code infoRefreshMaxWait} as the
 * max-wait header, and returns the value of the resulting future via {@code getFutureValue}.
 */
public TaskInfo getTaskInfo()
{
    Request request = setContentTypeHeaders(false, prepareGet())
            .setHeader(PRESTO_MAX_WAIT, infoRefreshMaxWait.toString())
            .setUri(taskUri)
            .build();
    ResponseHandler<BaseResponse<TaskInfo>, RuntimeException> taskInfoHandler = createAdaptingJsonResponseHandler(taskInfoCodec);
    return getFutureValue(executeWithRetries("getTaskInfo", "get remote task info", request, taskInfoHandler));
}
/**
 * Posts a batch task update to {@code <taskUri>/batch} and returns the value of the
 * resulting future via {@code getFutureValue}.
 * <p>
 * The plan fragment is serialized with {@code planFragmentCodec}; the whole request body
 * is serialized with {@code taskUpdateRequestCodec}.
 *
 * @return the {@link TaskInfo} decoded from the response
 */
public TaskInfo updateTask(
        List<TaskSource> sources,
        PlanFragment planFragment,
        TableWriteInfo tableWriteInfo,
        Optional<String> shuffleWriteInfo,
        Optional<String> broadcastBasePath,
        Session session,
        OutputBuffers outputBuffers)
{
    Optional<byte[]> serializedFragment = Optional.of(planFragment.bytesForTaskSerialization(planFragmentCodec));
    Optional<TableWriteInfo> optionalWriteInfo = Optional.of(tableWriteInfo);
    BatchTaskUpdateRequest batchRequest = new BatchTaskUpdateRequest(
            new TaskUpdateRequest(
                    session.toSessionRepresentation(),
                    session.getIdentity().getExtraCredentials(),
                    serializedFragment,
                    sources,
                    outputBuffers,
                    optionalWriteInfo),
            shuffleWriteInfo,
            broadcastBasePath);

    URI batchUri = uriBuilderFrom(taskUri)
            .appendPath("batch")
            .build();
    byte[] body = taskUpdateRequestCodec.toBytes(batchRequest);
    Request request = setContentTypeHeaders(false, preparePost())
            .setUri(batchUri)
            .setBodyGenerator(createStaticBodyGenerator(body))
            .build();

    ResponseHandler<BaseResponse<TaskInfo>, RuntimeException> taskInfoHandler = createAdaptingJsonResponseHandler(taskInfoCodec);
    return getFutureValue(executeWithRetries("updateTask", "create or update remote task", request, taskInfoHandler));
}
/**
 * @return the base location this client was constructed with
 */
public URI getLocation()
{
    return this.location;
}
/**
 * @return the task URI derived from the base location and the task id at construction time
 */
public URI getTaskUri()
{
    return this.taskUri;
}
/**
 * Builds {@code <baseUri>/v1/task/<taskId>}.
 */
private URI createTaskUri(URI baseUri, TaskId taskId)
{
    String taskIdSegment = taskId.toString();
    return uriBuilderFrom(baseUri)
            .appendPath(TASK_URI)
            .appendPath(taskIdSegment)
            .build();
}
/**
 * Runs {@code request} through the retrying scheduler with a fresh error tracker.
 *
 * @param name operation name, used as the prefix of the tracker's "too many errors" message
 * @param description job description handed to the error tracker
 * @param request request to execute
 * @param responseHandler handler that converts the HTTP response
 * @return future completed by the scheduler with the response value or the terminal failure
 */
private <T> ListenableFuture<T> executeWithRetries(
        String name,
        String description,
        Request request,
        ResponseHandler<BaseResponse<T>, RuntimeException> responseHandler)
{
    String tooManyErrorsMessage = name + " encountered too many errors talking to native process";
    RequestErrorTracker tracker = new RequestErrorTracker(
            "NativeExecution",
            location,
            NATIVE_EXECUTION_TASK_ERROR,
            tooManyErrorsMessage,
            remoteTaskMaxErrorDuration,
            scheduledExecutorService,
            description);
    SettableFuture<T> completion = SettableFuture.create();
    scheduleRequest(request, responseHandler, tracker, completion);
    return completion;
}
/**
 * Issues {@code request} once a request permit is granted by {@code errorTracker} and routes
 * the outcome through a {@link SimpleHttpResponseHandler} into {@code result}.
 * <p>
 * A failure that is not a {@link PrestoException} is reported to the tracker and the request
 * is scheduled again; if the tracker (or the rescheduling) throws, {@code result} fails with
 * that throwable. Fatal outcomes fail {@code result} directly.
 */
private <T> void scheduleRequest(
        Request request,
        ResponseHandler<BaseResponse<T>, RuntimeException> responseHandler,
        RequestErrorTracker errorTracker,
        SettableFuture<T> result)
{
    ListenableFuture<BaseResponse<T>> pendingResponse = transformAsync(
            errorTracker.acquireRequestPermit(),
            permit -> {
                errorTracker.startRequest();
                return httpClient.executeAsync(request, responseHandler);
            },
            executor);

    SimpleHttpResponseCallback<T> retryingCallback = new SimpleHttpResponseCallback<T>()
    {
        @Override
        public void success(T value)
        {
            result.set(value);
        }

        @Override
        public void failed(Throwable cause)
        {
            if (!(cause instanceof PrestoException)) {
                try {
                    errorTracker.requestFailed(cause);
                    scheduleRequest(request, responseHandler, errorTracker, result);
                }
                catch (Throwable t) {
                    result.setException(t);
                }
                return;
            }
            // do not retry on PrestoException
            result.setException(cause);
        }

        @Override
        public void fatal(Throwable cause)
        {
            result.setException(cause);
        }
    };

    SimpleHttpResponseHandler<T> futureCallback = new SimpleHttpResponseHandler<>(
            retryingCallback,
            location,
            new SimpleHttpResponseHandlerStats(),
            REMOTE_TASK_ERROR);
    addCallback(pendingResponse, futureCallback, executor);
}
/**
 * Response handler that captures the status code, headers and raw body bytes of a response
 * in a {@link BytesResponse}.
 */
private static class BytesResponseHandler
        implements ResponseHandler<BaseResponse<byte[]>, RuntimeException>
{
    @Override
    public BaseResponse<byte[]> handleException(Request request, Exception exception)
    {
        throw propagate(request, exception);
    }

    @Override
    public BaseResponse<byte[]> handle(Request request, Response response)
    {
        int status = response.getStatusCode();
        ListMultimap<HeaderName, String> responseHeaders = response.getHeaders();
        byte[] body = readResponseBytes(response);
        return new BytesResponse(status, responseHeaders, body);
    }

    /**
     * Reads the whole response body; a response without an input stream yields an empty array.
     */
    private static byte[] readResponseBytes(Response response)
    {
        try {
            InputStream body = response.getInputStream();
            return (body == null) ? new byte[0] : ByteStreams.toByteArray(body);
        }
        catch (IOException e) {
            throw new RuntimeException("Error reading response from server", e);
        }
    }
}
/**
 * {@link BaseResponse} backed by the raw response bytes. It always reports a value
 * ({@link #hasValue()} is {@code true}) and never carries an exception.
 */
private static class BytesResponse
        implements BaseResponse<byte[]>
{
    private final int statusCode;
    private final ListMultimap<HeaderName, String> headers;
    private final byte[] bytes;

    /**
     * @param statusCode HTTP status code of the response
     * @param headers response headers; copied into an immutable multimap
     * @param bytes raw response body; stored by reference, not copied
     * @throws NullPointerException if {@code headers} or {@code bytes} is null
     */
    public BytesResponse(int statusCode, ListMultimap<HeaderName, String> headers, byte[] bytes)
    {
        this.statusCode = statusCode;
        this.headers = ImmutableListMultimap.copyOf(requireNonNull(headers, "headers is null"));
        // getResponseSize() dereferences the array, so reject null here rather than
        // failing later, far from the construction site.
        this.bytes = requireNonNull(bytes, "bytes is null");
    }

    @Override
    public int getStatusCode()
    {
        return statusCode;
    }

    /**
     * @return the first value of the named header, or {@code null} when the header is absent
     */
    @Override
    public String getHeader(String name)
    {
        List<String> values = getHeaders(name);
        return values.isEmpty() ? null : values.get(0);
    }

    @Override
    public List<String> getHeaders(String name)
    {
        return headers.get(HeaderName.of(name));
    }

    @Override
    public ListMultimap<HeaderName, String> getHeaders()
    {
        return headers;
    }

    @Override
    public boolean hasValue()
    {
        return true;
    }

    @Override
    public byte[] getValue()
    {
        return bytes;
    }

    @Override
    public int getResponseSize()
    {
        return bytes.length;
    }

    @Override
    public byte[] getResponseBytes()
    {
        return bytes;
    }

    @Override
    public Exception getException()
    {
        return null;
    }
}
}