ReactorNettyHttpClient.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.server.remotetask;

import com.facebook.airlift.http.client.HeaderName;
import com.facebook.airlift.http.client.Request;
import com.facebook.airlift.http.client.RequestStats;
import com.facebook.airlift.http.client.Response;
import com.facebook.airlift.http.client.ResponseHandler;
import com.facebook.airlift.http.client.StaticBodyGenerator;
import com.facebook.airlift.log.Logger;
import com.google.common.base.Splitter;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ListMultimap;
import com.google.common.util.concurrent.SettableFuture;
import com.google.inject.Inject;
import io.airlift.units.Duration;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.jmx.JmxMeterRegistry;
import io.netty.channel.ChannelOption;
import io.netty.channel.epoll.Epoll;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.ssl.ApplicationProtocolConfig;
import io.netty.handler.ssl.OpenSsl;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslProvider;
import reactor.core.Disposable;
import reactor.core.publisher.Mono;
import reactor.netty.ByteBufFlux;
import reactor.netty.channel.MicrometerChannelMetricsRecorder;
import reactor.netty.http.HttpProtocol;
import reactor.netty.http.client.Http2AllocationStrategy;
import reactor.netty.http.client.HttpClient;
import reactor.netty.http.client.HttpClientResponse;
import reactor.netty.resources.ConnectionProvider;
import reactor.netty.resources.LoopResources;

import java.io.Closeable;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.nio.file.Files;
import java.security.GeneralSecurityException;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import static com.facebook.airlift.security.pem.PemReader.loadPrivateKey;
import static com.facebook.airlift.security.pem.PemReader.readCertificateChain;
import static io.micrometer.core.instrument.Clock.SYSTEM;
import static io.micrometer.jmx.JmxConfig.DEFAULT;
import static io.netty.handler.ssl.ApplicationProtocolConfig.Protocol.ALPN;
import static io.netty.handler.ssl.ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT;
import static io.netty.handler.ssl.ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE;
import static io.netty.handler.ssl.ApplicationProtocolNames.HTTP_1_1;
import static io.netty.handler.ssl.ApplicationProtocolNames.HTTP_2;
import static io.netty.handler.ssl.SslProtocols.TLS_v1_2;
import static io.netty.handler.ssl.SslProtocols.TLS_v1_3;
import static io.netty.handler.ssl.SslProvider.JDK;
import static io.netty.handler.ssl.SslProvider.OPENSSL;
import static io.netty.handler.ssl.SslProvider.isAlpnSupported;
import static java.lang.String.format;
import static java.time.temporal.ChronoUnit.MILLIS;

public class ReactorNettyHttpClient
        implements com.facebook.airlift.http.client.HttpClient, Closeable
{
    private static final Logger log = Logger.get(ReactorNettyHttpClient.class);
    private static final HeaderName CONTENT_TYPE_HEADER_NAME = HeaderName.of("Content-Type");
    private static final HeaderName CONTENT_LENGTH_HEADER_NAME = HeaderName.of("Content-Length");

    private final Duration requestTimeout;
    private HttpClient httpClient;

    @Inject
    public ReactorNettyHttpClient(ReactorNettyHttpClientConfig config)
    {
        SslContext sslContext = null;
        if (config.isHttpsEnabled()) {
            try {
                File keyFile = new File(config.getKeyStorePath());
                File trustCertificateFile = new File(config.getTrustStorePath());
                if (!Files.exists(keyFile.toPath()) || !Files.isReadable(keyFile.toPath())) {
                    throw new IllegalArgumentException("KeyStore file path is unreadable or doesn't exist");
                }
                if (!Files.exists(trustCertificateFile.toPath()) || !Files.isReadable(trustCertificateFile.toPath())) {
                    throw new IllegalArgumentException("TrustStore file path is unreadable or doesn't exist");
                }
                PrivateKey privateKey = loadPrivateKey(keyFile, Optional.of(config.getKeyStorePassword()));
                X509Certificate[] certificateChain = readCertificateChain(keyFile).toArray(new X509Certificate[0]);
                X509Certificate[] trustChain = readCertificateChain(trustCertificateFile).toArray(new X509Certificate[0]);

                String os = System.getProperty("os.name");
                if (os.toLowerCase(Locale.ENGLISH).contains("linux")) {
                    // Make sure Open ssl is available for linux deployments
                    if (!OpenSsl.isAvailable()) {
                        throw new UnsupportedOperationException(format("OpenSsl is not unavailable. Stacktrace: %s", Arrays.toString(OpenSsl.unavailabilityCause().getStackTrace()).replace(',', '\n')));
                    }
                    // Make sure epoll threads are used for linux deployments
                    if (!Epoll.isAvailable()) {
                        throw new UnsupportedOperationException(format("Epoll is not unavailable. Stacktrace: %s", Arrays.toString(Epoll.unavailabilityCause().getStackTrace()).replace(',', '\n')));
                    }
                }

                SslProvider provider = isAlpnSupported(OPENSSL) ? OPENSSL : JDK;
                SslContextBuilder sslContextBuilder = SslContextBuilder.forClient()
                        .sslProvider(provider)
                        .protocols(TLS_v1_3, TLS_v1_2)
                        .keyManager(privateKey, certificateChain)
                        .trustManager(trustChain)
                        .applicationProtocolConfig(new ApplicationProtocolConfig(ALPN, NO_ADVERTISE, ACCEPT, HTTP_2, HTTP_1_1));
                if (config.getCipherSuites().isPresent()) {
                    sslContextBuilder.ciphers(Splitter
                            .on(',')
                            .trimResults()
                            .omitEmptyStrings()
                            .splitToList(config.getCipherSuites().get()));
                }

                sslContext = sslContextBuilder.build();
            }
            catch (IOException | GeneralSecurityException e) {
                throw new RuntimeException("Failed to configure SSL context", e);
            }
        }

        /*
         * This is like wrapper and underlying there is a separate pool of connections for http1 and http2 protocols. Basically different pools for different protocols.
         * Reactor Netty's HttpConnectionProvider will wrap this connection provider and handle protocol routing in the acquire() call. It examines
         * the configured protocols and routes requests appropriately. So the http2 allocation strategy defined here will only be used for http2 connections.
         */
        ConnectionProvider pool = ConnectionProvider.builder("shared-pool")
                .maxConnections(config.getMaxConnections())
                .allocationStrategy((Http2AllocationStrategy.builder()
                        .maxConnections(config.getMaxConnections())
                        .maxConcurrentStreams(config.getMaxStreamPerChannel())
                        .minConnections(config.getMinConnections()).build()))
                .build();

        LoopResources loopResources = LoopResources.create("event-loop", config.getSelectorThreadCount(), config.getEventLoopThreadCount(), true, false);

        // Add the JMX MeterRegistry to the global Metrics registry
        JmxMeterRegistry jmxMeterRegistry = new JmxMeterRegistry(DEFAULT, SYSTEM);
        Metrics.addRegistry(jmxMeterRegistry);

        // Create HTTP/2 client
        SslContext finalSslContext = sslContext;
        this.httpClient = HttpClient
                // The custom pool is wrapped with a HttpConnectionProvider over here
                .create(pool)
                .protocol(HttpProtocol.H2, HttpProtocol.HTTP11)
                .runOn(loopResources, true)
                .http2Settings(settings -> settings.maxConcurrentStreams(config.getMaxStreamPerChannel()))
                .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, (int) config.getConnectTimeout().getValue())
                // Track the metrics for all the tcp connections
                .metrics(true, () -> new MicrometerChannelMetricsRecorder("reactor.netty.http.client", "tcp", false));

        if (config.isHttpsEnabled()) {
            if (finalSslContext == null) {
                throw new IllegalStateException("SSL context must be configured for HTTPS");
            }
            httpClient = httpClient.secure(spec -> spec.sslContext(finalSslContext));
        }

        this.requestTimeout = config.getRequestTimeout();
    }

    @Override
    public <T, E extends Exception> T execute(Request request, ResponseHandler<T, E> responseHandler)
            throws E
    {
        throw new UnsupportedOperationException();
    }

    public <T, E extends Exception> HttpResponseFuture<T> executeAsync(Request airliftRequest, ResponseHandler<T, E> responseHandler)
    {
        SettableFuture<Object> listenableFuture = SettableFuture.create();

        // Set the request headers
        HttpClient client = this.httpClient.headers(hdr -> {
            for (Map.Entry<String, String> entry : airliftRequest.getHeaders().entries()) {
                hdr.set(entry.getKey(), entry.getValue());
            }
        });

        URI uri = airliftRequest.getUri();
        Disposable disposable;
        switch (airliftRequest.getMethod()) {
            case "GET":
                disposable = client.get()
                        .uri(uri)
                        .responseSingle((response, bytes) -> bytes.asInputStream().zipWith(Mono.just(response)))
                        // Request timeout
                        .timeout(java.time.Duration.of(requestTimeout.toMillis(), MILLIS))
                        .subscribe(t -> onSuccess(responseHandler, t.getT1(), t.getT2(), listenableFuture), e -> onError(listenableFuture, e), () -> onComplete(listenableFuture));
                break;
            case "POST":
                byte[] postBytes = ((StaticBodyGenerator) airliftRequest.getBodyGenerator()).getBody();
                disposable = client.post()
                        .uri(uri)
                        .send(ByteBufFlux.fromInbound(Mono.just(postBytes)))
                        .responseSingle((response, bytes) -> bytes.asInputStream().zipWith(Mono.just(response)))
                        // Request timeout
                        .timeout(java.time.Duration.of(requestTimeout.toMillis(), MILLIS))
                        .subscribe(t -> onSuccess(responseHandler, t.getT1(), t.getT2(), listenableFuture), e -> onError(listenableFuture, e), () -> onComplete(listenableFuture));
                break;
            case "DELETE":
                disposable = client.delete()
                        .uri(uri)
                        .responseSingle((response, bytes) -> bytes.asInputStream().zipWith(Mono.just(response)))
                        // Request timeout
                        .timeout(java.time.Duration.of(requestTimeout.toMillis(), MILLIS))
                        .subscribe(t -> onSuccess(responseHandler, t.getT1(), t.getT2(), listenableFuture), e -> onError(listenableFuture, e), () -> onComplete(listenableFuture));
                break;
            default:
                throw new UnsupportedOperationException("Unexpected request: " + airliftRequest);
        }

        return new HttpResponseFuture()
        {
            @Override
            public boolean cancel(boolean mayInterruptIfRunning)
            {
                disposable.dispose();
                return listenableFuture.cancel(mayInterruptIfRunning);
            }

            @Override
            public boolean isCancelled()
            {
                return listenableFuture.isCancelled();
            }

            @Override
            public boolean isDone()
            {
                return listenableFuture.isDone();
            }

            @Override
            public Object get()
                    throws InterruptedException, ExecutionException
            {
                return listenableFuture.get();
            }

            @Override
            public Object get(long timeout, TimeUnit unit)
                    throws InterruptedException, ExecutionException, TimeoutException
            {
                return listenableFuture.get(timeout, unit);
            }

            @Override
            public void addListener(Runnable listener, Executor executor)
            {
                listenableFuture.addListener(listener, executor);
            }

            @Override
            public String getState()
            {
                return "";
            }
        };
    }

    public void onSuccess(ResponseHandler responseHandler, InputStream inputStream, HttpClientResponse response, SettableFuture<Object> listenableFuture)
    {
        ListMultimap<HeaderName, String> responseHeaders = ArrayListMultimap.create();
        HttpHeaders headers = response.responseHeaders();
        int status = response.status().code();
        if (status != 200 && status != 204) {
            listenableFuture.setException(new RuntimeException("Invalid response status: " + status));
            return;
        }

        long contentLength = 0;
        // Iterate over the headers
        for (String name : headers.names()) {
            if (name.equalsIgnoreCase(CONTENT_LENGTH_HEADER_NAME.toString())) {
                String val = headers.get(name);
                contentLength = Integer.parseInt(val);
                responseHeaders.put(CONTENT_LENGTH_HEADER_NAME, val);
            }
            else if (name.equalsIgnoreCase(CONTENT_TYPE_HEADER_NAME.toString())) {
                responseHeaders.put(CONTENT_TYPE_HEADER_NAME, headers.get(name));
            }
            else {
                responseHeaders.put(HeaderName.of(name), headers.get(name));
            }
        }

        if (!responseHeaders.containsKey(CONTENT_TYPE_HEADER_NAME) || responseHeaders.get(CONTENT_TYPE_HEADER_NAME).size() != 1) {
            listenableFuture.setException(new RuntimeException("Expected ContentType header: " + responseHeaders));
            return;
        }

        try {
            long finalContentLength = contentLength;
            Object a = responseHandler.handle(null, new Response()
            {
                @Override
                public int getStatusCode()
                {
                    return status;
                }

                @Override
                public ListMultimap<HeaderName, String> getHeaders()
                {
                    return responseHeaders;
                }

                @Override
                public long getBytesRead()
                {
                    return finalContentLength;
                }

                @Override
                public InputStream getInputStream()
                        throws IOException
                {
                    return inputStream;
                }
            });
            // closing it here to prevent memory leak of bytebuf
            inputStream.close();
            listenableFuture.set(a);
        }
        catch (Exception e) {
            listenableFuture.setException(e);
        }
        finally {
            try {
                inputStream.close();
            }
            catch (IOException e) {
                log.warn(e, "Failed to close input stream");
            }
        }
    }

    public void onError(SettableFuture<Object> listenableFuture, Throwable t)
    {
        listenableFuture.setException(t);
    }

    public void onComplete(SettableFuture<Object> listenableFuture)
    {
        if (!listenableFuture.isDone()) {
            listenableFuture.setException(new RuntimeException("completed without success or failure"));
        }
    }

    @Override
    public RequestStats getStats()
    {
        return null;
    }

    @Override
    public long getMaxContentLength()
    {
        return 0;
    }

    @Override
    public void close()
    {
        // void
    }

    @Override
    public boolean isClosed()
    {
        return false;
    }
}