JSSEProviderIntegrationTest.java

/*
 * ====================================================================
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 * ====================================================================
 *
 * This software consists of voluntary contributions made by many
 * individuals on behalf of the Apache Software Foundation.  For more
 * information on the Apache Software Foundation, please see
 * <http://www.apache.org/>.
 *
 */

package org.apache.hc.core5.testing.nio;

import java.net.InetSocketAddress;
import java.net.URL;
import java.security.Provider;
import java.security.SecureRandom;
import java.security.Security;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.Future;

import org.apache.hc.core5.http.HttpHeaders;
import org.apache.hc.core5.http.HttpHost;
import org.apache.hc.core5.http.HttpResponse;
import org.apache.hc.core5.http.Message;
import org.apache.hc.core5.http.message.BasicHttpRequest;
import org.apache.hc.core5.http.nio.entity.StringAsyncEntityConsumer;
import org.apache.hc.core5.http.nio.support.AsyncRequestBuilder;
import org.apache.hc.core5.http.nio.support.BasicRequestProducer;
import org.apache.hc.core5.http.nio.support.BasicResponseConsumer;
import org.apache.hc.core5.http.protocol.DefaultHttpProcessor;
import org.apache.hc.core5.http.protocol.RequestValidateHost;
import org.apache.hc.core5.http.support.BasicRequestBuilder;
import org.apache.hc.core5.reactor.IOReactorConfig;
import org.apache.hc.core5.ssl.SSLContextBuilder;
import org.apache.hc.core5.util.TimeValue;
import org.apache.hc.core5.util.Timeout;
import org.conscrypt.Conscrypt;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Assumptions;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.RegisterExtension;

/**
 * Base class for HTTPS round-trip integration tests that exercise {@link Http1TestServer} and
 * {@link Http1TestClient} over TLS with a configurable JSSE security provider and protocol version.
 * <p>
 * Concrete subclasses pass the provider name and protocol version to the constructor. Only
 * {@code "Conscrypt"} (case-insensitive) causes an additional provider to be installed; any other
 * name runs with the providers already registered in the JVM. A {@code null} protocol version leaves
 * the SSL engines' enabled protocols unchanged.
 * </p>
 */
abstract class JSSEProviderIntegrationTest {

    private final String securityProviderName;
    private final String protocolVersion;

    /**
     * @param securityProviderName name of the provider under test; {@code "Conscrypt"} is the only value
     *                             that installs an extra provider, any other value uses the JVM defaults
     * @param protocolVersion      single protocol to enable on both client and server SSL engines, or
     *                             {@code null} to keep the engines' default enabled protocols
     */
    public JSSEProviderIntegrationTest(final String securityProviderName, final String protocolVersion) {
        super();
        this.securityProviderName = securityProviderName;
        this.protocolVersion = protocolVersion;
    }

    private static final Timeout TIMEOUT = Timeout.ofMinutes(1);
    private static final int REQ_NUM = 25;

    /** Provider installed for the current test, or {@code null} when the JVM defaults are used. */
    private Provider securityProvider;

    /**
     * Installs the requested security provider at the highest preference before each test and removes
     * it again afterwards so it does not stay registered for later tests in the same JVM.
     */
    class SecurityProviderResource implements BeforeEachCallback, AfterEachCallback {

        @Override
        public void beforeEach(final ExtensionContext context) throws Exception {
            if ("Conscrypt".equalsIgnoreCase(securityProviderName)) {
                // Conscrypt is loaded via a native library (see the UnsatisfiedLinkError handling below);
                // only attempt it on the architectures listed here and skip the test elsewhere.
                final Set<String> supportedArchitectures = new HashSet<>(Arrays.asList("x86", "x86_64",
                        "x86-64", "amd64", "aarch64", "armeabi-v7a", "arm64-v8a"));
                final String arch = System.getProperty("os.arch");
                Assumptions.assumeTrue(supportedArchitectures.contains(arch),
                        "Conscrypt is not supported on architecture: " + arch);
                try {
                    securityProvider = Conscrypt.newProviderBuilder().provideTrustManager(true).build();
                } catch (final UnsatisfiedLinkError e) {
                    // Keep the original error as the cause so the native loading failure is diagnosable.
                    Assertions.fail("Conscrypt provider failed to be loaded: " + e.getMessage(), e);
                }
            } else {
                securityProvider = null;
            }
            if (securityProvider != null) {
                // Position 1 is the most preferred slot, so provider lookups resolve to it first.
                Security.insertProviderAt(securityProvider, 1);
            }
        }

        @Override
        public void afterEach(final ExtensionContext context) throws Exception {
            if (securityProvider != null) {
                Security.removeProvider(securityProvider.getName());
                securityProvider = null;
            }
        }

    }

    @RegisterExtension
    @Order(1)
    private final SecurityProviderResource securityProviderResource = new SecurityProviderResource();

    private Http1TestServer server;

    /**
     * Creates the TLS-enabled HTTP/1.1 test server before each test (after the provider has been
     * installed) and shuts it down afterwards.
     */
    class ServerResource implements BeforeEachCallback, AfterEachCallback {

        @Override
        public void beforeEach(final ExtensionContext context) throws Exception {
            final URL keyStoreURL = getClass().getResource("/test-server.p12");
            Assertions.assertNotNull(keyStoreURL, "Test key store /test-server.p12 not found on the classpath");
            final String storePassword = "nopassword";

            // The same PKCS#12 store supplies both the server's key material and its trust material.
            server = new Http1TestServer(
                    IOReactorConfig.custom()
                            .setSoTimeout(TIMEOUT)
                            .build(),
                    SSLContextBuilder.create()
                            .setProvider(securityProvider)
                            .setKeyStoreType("pkcs12")
                            .loadTrustMaterial(keyStoreURL, storePassword.toCharArray())
                            .loadKeyMaterial(keyStoreURL, storePassword.toCharArray(), storePassword.toCharArray())
                            .setSecureRandom(new SecureRandom())
                            .build(),
                    (endpoint, sslEngine) -> {
                        if (protocolVersion != null) {
                            sslEngine.setEnabledProtocols(new String[]{protocolVersion});
                        }
                    },
                    null);
        }

        @Override
        public void afterEach(final ExtensionContext context) throws Exception {
            if (server != null) {
                server.shutdown(TimeValue.ofSeconds(5));
                server = null;
            }
        }

    }

    @RegisterExtension
    @Order(2)
    private final ServerResource serverResource = new ServerResource();

    private Http1TestClient client;

    /**
     * Creates the TLS-enabled HTTP/1.1 test client before each test and shuts it down afterwards.
     */
    class ClientResource implements BeforeEachCallback, AfterEachCallback {

        @Override
        public void beforeEach(final ExtensionContext context) throws Exception {
            final URL keyStoreURL = getClass().getResource("/test-client.p12");
            Assertions.assertNotNull(keyStoreURL, "Test key store /test-client.p12 not found on the classpath");
            final String storePassword = "nopassword";

            // Trust material only: the client does not load key material of its own.
            client = new Http1TestClient(
                    IOReactorConfig.custom()
                            .setSoTimeout(TIMEOUT)
                            .build(),
                    SSLContextBuilder.create()
                            .setProvider(securityProvider)
                            .setKeyStoreType("pkcs12")
                            .loadTrustMaterial(keyStoreURL, storePassword.toCharArray())
                            .setSecureRandom(new SecureRandom())
                            .build(),
                    (endpoint, sslEngine) -> {
                        if (protocolVersion != null) {
                            sslEngine.setEnabledProtocols(new String[]{protocolVersion});
                        }
                    },
                    null);
        }

        @Override
        public void afterEach(final ExtensionContext context) throws Exception {
            if (client != null) {
                client.shutdown(TimeValue.ofSeconds(5));
                client = null;
            }
        }

    }

    @RegisterExtension
    @Order(3)
    private final ClientResource clientResource = new ClientResource();

    /**
     * Builds the HTTPS target for the given server endpoint, addressed by host name {@code localhost}.
     */
    private HttpHost target(final InetSocketAddress serverEndpoint) {
        return new HttpHost("https", null, "localhost", serverEndpoint.getPort());
    }

    /**
     * Asserts that a completed exchange produced status {@code 200} with body {@code "Hi there"}.
     */
    private static void assertHelloResponse(final Message<HttpResponse, String> result) {
        Assertions.assertNotNull(result);
        final HttpResponse response = result.getHead();
        Assertions.assertNotNull(response);
        Assertions.assertEquals(200, response.getCode());
        Assertions.assertEquals("Hi there", result.getBody());
    }

    /**
     * Sends {@link #REQ_NUM} sequential GET requests over a single TLS connection.
     */
    @Test
    void testSimpleGet() throws Exception {
        server.register("/hello", () -> new SingleLineResponseHandler("Hi there"));
        final InetSocketAddress serverEndpoint = server.start();

        final HttpHost target = target(serverEndpoint);

        client.start();
        final Future<ClientSessionEndpoint> connectFuture = client.connect(target, TIMEOUT);
        // Bound the connect/handshake wait and release the endpoint when done, like the other tests.
        try (final ClientSessionEndpoint streamEndpoint = connectFuture.get(
                TIMEOUT.getDuration(), TIMEOUT.getTimeUnit())) {
            for (int i = 0; i < REQ_NUM; i++) {
                final BasicHttpRequest request = BasicRequestBuilder.get()
                        .setHttpHost(target)
                        .setPath("/hello")
                        .build();
                final Future<Message<HttpResponse, String>> future = streamEndpoint.execute(
                        new BasicRequestProducer(request, null),
                        new BasicResponseConsumer<>(new StringAsyncEntityConsumer()), null);
                assertHelloResponse(future.get(TIMEOUT.getDuration(), TIMEOUT.getTimeUnit()));
            }
        }
    }

    /**
     * Sends {@link #REQ_NUM} GET requests carrying {@code Connection: close}, each on a freshly
     * established TLS connection.
     */
    @Test
    void testSimpleGetConnectionClose() throws Exception {
        server.register("/hello", () -> new SingleLineResponseHandler("Hi there"));
        final InetSocketAddress serverEndpoint = server.start();

        final HttpHost target = target(serverEndpoint);

        client.start();
        for (int i = 0; i < REQ_NUM; i++) {
            final Future<ClientSessionEndpoint> connectFuture = client.connect(
                    "localhost", serverEndpoint.getPort(), TIMEOUT);
            try (final ClientSessionEndpoint streamEndpoint = connectFuture.get(
                    TIMEOUT.getDuration(), TIMEOUT.getTimeUnit())) {
                final Future<Message<HttpResponse, String>> future = streamEndpoint.execute(
                        AsyncRequestBuilder.get()
                                .setHttpHost(target)
                                .setPath("/hello")
                                .addHeader(HttpHeaders.CONNECTION, "close")
                                .build(),
                        new BasicResponseConsumer<>(new StringAsyncEntityConsumer()), null);
                assertHelloResponse(future.get(TIMEOUT.getDuration(), TIMEOUT.getTimeUnit()));
            }
        }
    }

    /**
     * Sends {@link #REQ_NUM} GET requests, each on a fresh TLS connection, to a server whose HTTP
     * processor is limited to {@link RequestValidateHost}.
     * <p>
     * NOTE(review): presumably this makes the server omit content-length / chunked framing so the
     * response body is delimited by connection close (identity transfer) — confirm against
     * {@code Http1TestServer#configure}.
     * </p>
     */
    @Test
    void testSimpleGetIdentityTransfer() throws Exception {
        server.register("/hello", () -> new SingleLineResponseHandler("Hi there"));
        server.configure(new DefaultHttpProcessor(new RequestValidateHost()));
        final InetSocketAddress serverEndpoint = server.start();

        final HttpHost target = target(serverEndpoint);

        client.start();

        for (int i = 0; i < REQ_NUM; i++) {
            final Future<ClientSessionEndpoint> connectFuture = client.connect(
                    "localhost", serverEndpoint.getPort(), TIMEOUT);
            try (final ClientSessionEndpoint streamEndpoint = connectFuture.get(
                    TIMEOUT.getDuration(), TIMEOUT.getTimeUnit())) {
                final BasicHttpRequest request = BasicRequestBuilder.get()
                        .setHttpHost(target)
                        .setPath("/hello")
                        .build();
                final Future<Message<HttpResponse, String>> future = streamEndpoint.execute(
                        new BasicRequestProducer(request, null),
                        new BasicResponseConsumer<>(new StringAsyncEntityConsumer()), null);
                assertHelloResponse(future.get(TIMEOUT.getDuration(), TIMEOUT.getTimeUnit()));
            }
        }
    }

}