ResourceManagerProxy.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.resourcemanager;

import com.facebook.airlift.http.client.BodyGenerator;
import com.facebook.airlift.http.client.HeaderName;
import com.facebook.airlift.http.client.HttpClient;
import com.facebook.airlift.http.client.Request;
import com.facebook.presto.metadata.InternalNodeManager;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ListMultimap;
import com.google.common.io.ByteStreams;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import io.airlift.units.Duration;

import javax.inject.Inject;
import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.container.AsyncResponse;
import javax.ws.rs.core.Response;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.net.URI;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean;

import static com.facebook.airlift.http.client.HttpStatus.INTERNAL_SERVER_ERROR;
import static com.facebook.airlift.http.server.AsyncResponseHandler.bindAsyncResponse;
import static com.google.common.base.Verify.verify;
import static com.google.common.net.HttpHeaders.CONTENT_TYPE;
import static com.google.common.net.HttpHeaders.COOKIE;
import static com.google.common.net.HttpHeaders.USER_AGENT;
import static com.google.common.net.HttpHeaders.X_FORWARDED_FOR;
import static com.google.common.util.concurrent.Futures.transform;
import static java.lang.String.format;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Collections.list;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;
import static javax.ws.rs.core.MediaType.TEXT_PLAIN;
import static javax.ws.rs.core.Response.Status.GATEWAY_TIMEOUT;
import static javax.ws.rs.core.Response.status;

/**
 * Forwards an incoming servlet request to a remote URI using the {@code @ForResourceManager}
 * {@link HttpClient}, and relays the remote status code, headers and body back to the caller
 * through a suspended JAX-RS {@link AsyncResponse}.
 *
 * <p>Only a subset of request headers is forwarded: {@code X-Presto-*} headers and
 * {@code Cookie} are copied verbatim, {@code User-Agent} is prefixed with
 * {@code "[Resource Manager] "}, and an {@code X-Forwarded-For} header is built from the
 * incoming chain plus the client's address. All other request headers are dropped.
 */
@SuppressWarnings("UnstableApiUsage")
public class ResourceManagerProxy
{
    private final InternalNodeManager internalNodeManager;
    private final HttpClient httpClient;
    private final Duration asyncTimeout;
    private final Executor executor;

    @Inject
    private ResourceManagerProxy(
            InternalNodeManager internalNodeManager,
            @ForResourceManager HttpClient httpClient,
            ResourceManagerConfig resourceManagerConfig,
            @ForResourceManager ListeningExecutorService executor)
    {
        this.internalNodeManager = requireNonNull(internalNodeManager, "internalNodeManager is null");
        this.httpClient = requireNonNull(httpClient, "httpClient is null");
        this.asyncTimeout = requireNonNull(resourceManagerConfig, "resourceManagerConfig is null").getProxyAsyncTimeout();
        this.executor = requireNonNull(executor, "executor is null");
    }

    /**
     * Proxies {@code servletRequest} to {@code remoteUri} and completes {@code asyncResponse}
     * with the remote server's response.
     *
     * <p>The request body is streamed from the servlet input stream. If no response is
     * available within the configured proxy async timeout, {@code asyncResponse} is resumed
     * with a {@code 504 Gateway Timeout} plain-text response instead.
     *
     * @param servletRequest the incoming request whose method, selected headers and body are forwarded
     * @param asyncResponse the suspended JAX-RS response to complete; resumed with the
     *        {@link IOException} if the servlet input stream cannot be obtained
     * @param remoteUri the full URI the request is forwarded to
     */
    public void performRequest(
            HttpServletRequest servletRequest,
            AsyncResponse asyncResponse,
            URI remoteUri)
    {
        try {
            BodyGenerator bodyGenerator = new InputStreamBodyGenerator(servletRequest.getInputStream());
            Request request = createRequest(servletRequest, servletRequest.getMethod(), remoteUri, bodyGenerator);
            ListenableFuture<ProxyResponse> proxyResponse = httpClient.executeAsync(request, new ProxyResponseHandler());
            ListenableFuture<Response> future = transform(proxyResponse, this::toResponse, executor);
            setupAsyncResponse(asyncResponse, future, remoteUri);
        }
        catch (IOException e) {
            asyncResponse.resume(e);
        }
    }

    /**
     * Builds the outgoing HTTP request: same method, target {@code remoteUri}, the given body,
     * and the filtered/rewritten header set described on the class.
     */
    private Request createRequest(HttpServletRequest servletRequest, String httpMethod, URI remoteUri, BodyGenerator bodyGenerator)
    {
        Request.Builder requestBuilder = new Request.Builder()
                .setMethod(httpMethod)
                .setUri(remoteUri)
                .setPreserveAuthorizationOnRedirect(true)
                .setBodyGenerator(bodyGenerator);

        for (String name : list(servletRequest.getHeaderNames())) {
            if (isPrestoHeader(name) || name.equalsIgnoreCase(COOKIE)) {
                for (String value : list(servletRequest.getHeaders(name))) {
                    requestBuilder.addHeader(name, value);
                }
            }
            else if (name.equalsIgnoreCase(USER_AGENT)) {
                // Tag the user agent so the remote side can tell the request was proxied.
                for (String value : list(servletRequest.getHeaders(name))) {
                    requestBuilder.addHeader(name, "[Resource Manager] " + value);
                }
            }
        }

        // Extend the incoming X-Forwarded-For chain with the immediate client's address.
        // All incoming X-Forwarded-For fields are kept (getHeader would return only the first);
        // repeated fields of a comma-separated list header may be merged with commas.
        StringBuilder xForwardedFor = new StringBuilder();
        for (String value : list(servletRequest.getHeaders(X_FORWARDED_FOR))) {
            xForwardedFor.append(value).append(',');
        }
        xForwardedFor.append(servletRequest.getRemoteAddr());
        requestBuilder.addHeader(X_FORWARDED_FOR, xForwardedFor.toString());

        return requestBuilder.build();
    }

    /** Returns true for headers in the {@code X-Presto-} namespace (case-insensitive). */
    private static boolean isPrestoHeader(String name)
    {
        return name.toLowerCase(ENGLISH).startsWith("x-presto-");
    }

    /**
     * Converts the remote response into a JAX-RS response with the same status code, body
     * stream and headers.
     */
    private Response toResponse(ProxyResponse input)
    {
        Response.ResponseBuilder entity = status(input.getStatusCode()).entity(input.getBody());
        // NOTE(review): every upstream header is copied verbatim, including any hop-by-hop
        // headers the remote server sent; confirm this is intended.
        input.getHeaders().forEach((headerName, value) -> entity.header(headerName.toString(), value));
        return entity.build();
    }

    /**
     * Binds {@code future} to {@code asyncResponse}, replacing it with a 504 plain-text
     * response naming the remote URI and the current node if it does not finish within
     * {@code asyncTimeout}.
     */
    private void setupAsyncResponse(AsyncResponse asyncResponse, ListenableFuture<Response> future, URI remoteUri)
    {
        bindAsyncResponse(asyncResponse, future, executor)
                .withTimeout(asyncTimeout, () -> status(GATEWAY_TIMEOUT)
                        .type(TEXT_PLAIN)
                        .entity(format("Request to remote Presto server (%s), current node (%s), timed out after %s",
                                remoteUri,
                                internalNodeManager.getCurrentNode().getNodeIdentifier(),
                                asyncTimeout.toString()))
                        .build());
    }

    /**
     * Streams the servlet request body into the outgoing request. The body can be written
     * only once; the source stream is closed after copying, even on failure.
     */
    private static final class InputStreamBodyGenerator
            implements BodyGenerator
    {
        private final InputStream inputStream;
        private final AtomicBoolean called = new AtomicBoolean();

        public InputStreamBodyGenerator(InputStream inputStream)
        {
            this.inputStream = requireNonNull(inputStream, "inputStream is null");
        }

        @Override
        public void write(OutputStream outputStream)
                throws Exception
        {
            // The servlet input stream cannot be rewound, so a second write is a programming error.
            verify(called.compareAndSet(false, true), "Already read servlet request body");
            try {
                ByteStreams.copy(inputStream, outputStream);
            }
            finally {
                inputStream.close();
            }
        }
    }

    /**
     * Captures the remote response as a {@link ProxyResponse}. Transport failures are turned
     * into a synthetic 500 text/plain response instead of a failed future.
     */
    private static final class ProxyResponseHandler
            implements com.facebook.airlift.http.client.ResponseHandler<ProxyResponse, RuntimeException>
    {
        @Override
        public ProxyResponse handleException(Request request, Exception exception)
        {
            // NOTE(review): the full stack trace is returned to the caller in the body;
            // confirm that exposing it is acceptable.
            StringWriter sw = new StringWriter();
            exception.printStackTrace(new PrintWriter(sw));
            String message = format("Exception receiving response from %s: %s", request.getUri(), sw.toString());
            InputStream inputStream = new ByteArrayInputStream(message.getBytes(UTF_8));
            return new ProxyResponse(INTERNAL_SERVER_ERROR.code(), ImmutableListMultimap.of(HeaderName.of(CONTENT_TYPE), TEXT_PLAIN), inputStream);
        }

        @Override
        public ProxyResponse handle(Request request, com.facebook.airlift.http.client.Response response)
        {
            try {
                return new ProxyResponse(response.getStatusCode(), response.getHeaders(), response.getInputStream());
            }
            catch (IOException e) {
                return handleException(request, e);
            }
        }
    }

    /** Status code, headers and body stream received from the remote server. */
    private static final class ProxyResponse
    {
        private final int statusCode;
        private final ListMultimap<HeaderName, String> headers;
        private final InputStream body;

        ProxyResponse(int statusCode, ListMultimap<HeaderName, String> headers, InputStream body)
        {
            this.statusCode = statusCode;
            this.headers = requireNonNull(headers, "headers is null");
            this.body = requireNonNull(body, "body is null");
        }

        public int getStatusCode()
        {
            return statusCode;
        }

        public ListMultimap<HeaderName, String> getHeaders()
        {
            return headers;
        }

        public InputStream getBody()
        {
            return body;
        }
    }
}