AsyncPageTransportServlet.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.server;

import com.facebook.airlift.concurrent.BoundedExecutor;
import com.facebook.airlift.log.Logger;
import com.facebook.airlift.stats.TimeStat;
import com.facebook.presto.execution.TaskId;
import com.facebook.presto.execution.TaskManager;
import com.facebook.presto.execution.buffer.BufferInfo;
import com.facebook.presto.execution.buffer.BufferResult;
import com.facebook.presto.execution.buffer.OutputBuffers.OutputBufferId;
import com.facebook.presto.execution.buffer.PageBufferInfo;
import com.facebook.presto.operator.ExchangeClientConfig;
import com.facebook.presto.spi.page.SerializedPage;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import org.weakref.jmx.Managed;
import org.weakref.jmx.Nested;

import javax.annotation.security.RolesAllowed;
import javax.inject.Inject;
import javax.servlet.AsyncContext;
import javax.servlet.AsyncEvent;
import javax.servlet.AsyncListener;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import java.io.IOException;
import java.util.Enumeration;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;

import static com.facebook.airlift.concurrent.MoreFutures.addTimeout;
import static com.facebook.presto.PrestoMediaTypes.PRESTO_PAGES;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_BUFFER_COMPLETE;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_BUFFER_REMAINING_BYTES;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_MAX_SIZE;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_PAGE_NEXT_TOKEN;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_PAGE_TOKEN;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_TASK_INSTANCE_ID;
import static com.facebook.presto.server.security.RoleType.INTERNAL;
import static com.facebook.presto.spi.page.PagesSerdeUtil.PAGE_METADATA_SIZE;
import static com.facebook.presto.util.TaskUtils.DEFAULT_MAX_WAIT_TIME;
import static com.facebook.presto.util.TaskUtils.randomizeWaitTime;
import static com.google.common.net.HttpHeaders.CONTENT_LENGTH;
import static com.google.common.net.HttpHeaders.CONTENT_TYPE;
import static com.google.common.util.concurrent.Futures.addCallback;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static java.lang.Long.parseLong;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static javax.servlet.http.HttpServletResponse.SC_BAD_REQUEST;
import static javax.servlet.http.HttpServletResponse.SC_INTERNAL_SERVER_ERROR;
import static javax.servlet.http.HttpServletResponse.SC_NO_CONTENT;

@RolesAllowed(INTERNAL)
public class AsyncPageTransportServlet
        extends HttpServlet
{
    private static final Logger log = Logger.get(AsyncPageTransportServlet.class);

    private final Duration pageTransportTimeout;
    private final TaskManager taskManager;
    private final Executor responseExecutor;
    private final ScheduledExecutorService timeoutExecutor;

    private final TimeStat readFromOutputBufferTime = new TimeStat();
    private final TimeStat resultsRequestTime = new TimeStat();

    @Inject
    public AsyncPageTransportServlet(
            TaskManager taskManager,
            ExchangeClientConfig exchangeClientConfig,
            @ForAsyncRpc BoundedExecutor responseExecutor,
            @ForAsyncRpc ScheduledExecutorService timeoutExecutor)
    {
        this.taskManager = requireNonNull(taskManager, "taskManager is null");
        this.pageTransportTimeout = requireNonNull(exchangeClientConfig.getAsyncPageTransportTimeout(), "asyncPageTransportTimeout is null");
        this.responseExecutor = requireNonNull(responseExecutor, "responseExecutor is null");
        this.timeoutExecutor = requireNonNull(timeoutExecutor, "timeoutExecutor is null");
    }

    @VisibleForTesting
    protected AsyncPageTransportServlet()
    {
        this.taskManager = null;
        this.pageTransportTimeout = null;
        this.responseExecutor = null;
        this.timeoutExecutor = null;
    }

    @Override
    protected void doGet(HttpServletRequest request, HttpServletResponse response)
            throws IOException
    {
        parseURI(request.getRequestURI(), request, response);
    }

    protected void reportFailure(HttpServletResponse response, String message)
            throws IOException
    {
        response.sendError(SC_BAD_REQUEST, message);
    }

    protected void parseURI(String requestURI, HttpServletRequest request, HttpServletResponse response)
            throws IOException
    {
        // Split a task URI without allocating a list and unnecessary strings
        // Example:  /v1/task/async/{taskId}/results/{bufferId}/{token}
        TaskId taskId = null;
        OutputBufferId bufferId = null;
        long token = 0;

        if (request != null) {
            Enumeration<String> headerNames = request.getHeaderNames();
            while (headerNames.hasMoreElements()) {
                String headerName = headerNames.nextElement();
                String headerValue = request.getHeader(headerName);
                if (headerName.contains("\r") || headerName.contains("\n")) {
                    throw new IllegalArgumentException(format("Invalid header name: %s", headerName));
                }
                if (headerValue.contains("\r") || headerValue.contains("\n")) {
                    throw new IllegalArgumentException(format("Invalid header value: %s", headerValue));
                }
            }
        }

        int previousIndex = -1;
        for (int part = 0; part < 8; part++) {
            int nextIndex = requestURI.indexOf('/', previousIndex + 1);

            if (nextIndex == -1 && part != 7 || nextIndex != -1 && part == 7) {
                reportFailure(response, format("Unexpected URI for task result request in async mode: %s", requestURI));
                return;
            }

            switch (part) {
                case 4:
                    taskId = TaskId.valueOf(requestURI.substring(previousIndex + 1, nextIndex));
                    break;
                case 6:
                    bufferId = OutputBufferId.fromString(requestURI.substring(previousIndex + 1, nextIndex));
                    break;
                case 7:
                    token = parseLong(requestURI.substring(previousIndex + 1));
                    break;
            }

            previousIndex = nextIndex;
        }
        // This is sent forward instead of returned to avoid allocations
        processRequest(requestURI, taskId, bufferId, token, request, response);
    }

    protected void processRequest(
            String requestURI, TaskId taskId, OutputBufferId bufferId, long token,
            HttpServletRequest request, HttpServletResponse response)
            throws IOException
    {
        long start = System.nanoTime();
        DataSize maxSize = DataSize.valueOf(request.getHeader(PRESTO_MAX_SIZE));

        AsyncContext asyncContext = request.startAsync(request, response);

        // wait time to get results
        Duration waitTime = randomizeWaitTime(DEFAULT_MAX_WAIT_TIME);
        asyncContext.setTimeout(waitTime.toMillis() + pageTransportTimeout.toMillis());

        asyncContext.addListener(new AsyncListener()
        {
            public void onComplete(AsyncEvent event)
            {
                resultsRequestTime.add(Duration.nanosSince(start));
            }

            public void onError(AsyncEvent event)
                    throws IOException
            {
                String errorMessage = format("Server error to process task result request %s : %s", requestURI, event.getThrowable().getMessage());
                log.error(event.getThrowable(), errorMessage);
                response.sendError(SC_INTERNAL_SERVER_ERROR, errorMessage);
            }

            public void onStartAsync(AsyncEvent event)
            {
            }

            public void onTimeout(AsyncEvent event)
                    throws IOException
            {
                String errorMessage = format("Server timeout to process task result request: %s", requestURI);
                log.error(event.getThrowable(), errorMessage);
                response.sendError(SC_INTERNAL_SERVER_ERROR, errorMessage);
            }
        });

        ListenableFuture<BufferResult> bufferResultFuture = taskManager.getTaskResults(taskId, bufferId, token, maxSize.toBytes());
        bufferResultFuture = addTimeout(
                bufferResultFuture,
                () -> BufferResult.emptyResults(
                        taskManager.getTaskInstanceId(taskId),
                        token,
                        taskManager.getOutputBufferInfo(taskId).getBuffers().stream()
                                .filter(bufferInfo -> bufferInfo.getBufferId().equals(bufferId))
                                .map(BufferInfo::getPageBufferInfo)
                                .map(PageBufferInfo::getBufferedBytes)
                                .findFirst()
                                .orElse(0L),
                        false),
                waitTime,
                timeoutExecutor);

        bufferResultFuture.addListener(() -> readFromOutputBufferTime.add(Duration.nanosSince(start)), directExecutor());

        ServletOutputStream out = response.getOutputStream();
        addCallback(bufferResultFuture, new FutureCallback<BufferResult>()
                {
                    @Override
                    public void onSuccess(BufferResult bufferResult)
                    {
                        response.setHeader(CONTENT_TYPE, PRESTO_PAGES);
                        response.setHeader(PRESTO_TASK_INSTANCE_ID, bufferResult.getTaskInstanceId());
                        response.setHeader(PRESTO_PAGE_TOKEN, String.valueOf(bufferResult.getToken()));
                        response.setHeader(PRESTO_PAGE_NEXT_TOKEN, String.valueOf(bufferResult.getNextToken()));
                        response.setHeader(PRESTO_BUFFER_COMPLETE, String.valueOf(bufferResult.isBufferComplete()));
                        response.setHeader(PRESTO_BUFFER_REMAINING_BYTES, String.valueOf(bufferResult.getBufferedBytes()));

                        List<SerializedPage> serializedPages = bufferResult.getSerializedPages();
                        if (serializedPages.isEmpty()) {
                            response.setStatus(SC_NO_CONTENT);
                            asyncContext.complete();
                        }
                        else {
                            int contentLength = (serializedPages.size() * PAGE_METADATA_SIZE) + serializedPages.stream()
                                    .mapToInt(SerializedPage::getSizeInBytes)
                                    .sum();
                            response.setHeader(CONTENT_LENGTH, String.valueOf(contentLength));
                            out.setWriteListener(new SerializedPageWriteListener(serializedPages, asyncContext, out));
                        }
                    }

                    @Override
                    public void onFailure(Throwable thrown)
                    {
                        String errorMessage = format("Error getting task result from TaskManager for request %s : %s", requestURI, thrown.getMessage());
                        log.error(thrown, errorMessage);
                        try {
                            response.sendError(SC_INTERNAL_SERVER_ERROR, errorMessage);
                        }
                        catch (IOException e) {
                            log.error(e, "Failed to send response with error code: %s", e.getMessage());
                        }
                        asyncContext.complete();
                    }
                },
                responseExecutor);
    }

    @Managed
    @Nested
    public TimeStat getReadFromOutputBufferTime()
    {
        return readFromOutputBufferTime;
    }

    @Managed
    @Nested
    public TimeStat getResultsRequestTime()
    {
        return resultsRequestTime;
    }
}