ChunkedInputWriteErrorSimulationTest.java

/*
 * Copyright (c) 2024, 2025 Oracle and/or its affiliates. All rights reserved.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License v. 2.0, which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * This Source Code may also be made available under the following Secondary
 * Licenses when the conditions for such availability set forth in the
 * Eclipse Public License v. 2.0 are satisfied: GNU General Public License,
 * version 2 with the GNU Classpath Exception, which is available at
 * https://www.gnu.org/software/classpath/license.html.
 *
 * SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
 */

package org.glassfish.jersey.netty.connector;

import io.netty.channel.Channel;
import org.glassfish.jersey.client.ClientConfig;
import org.glassfish.jersey.client.ClientProperties;
import org.glassfish.jersey.client.ClientRequest;
import org.glassfish.jersey.client.spi.Connector;
import org.glassfish.jersey.client.spi.ConnectorProvider;
import org.glassfish.jersey.netty.connector.internal.JerseyChunkedInput;
import org.glassfish.jersey.netty.connector.internal.NettyEntityWriter;
import org.glassfish.jersey.server.ResourceConfig;
import org.glassfish.jersey.test.JerseyTest;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import javax.ws.rs.POST;
import javax.ws.rs.Path;
import javax.ws.rs.client.Client;
import javax.ws.rs.client.ClientBuilder;
import javax.ws.rs.client.Entity;
import javax.ws.rs.client.Invocation;
import javax.ws.rs.client.WebTarget;
import javax.ws.rs.core.Application;
import javax.ws.rs.core.Configuration;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.MultivaluedHashMap;
import javax.ws.rs.core.Response;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Proxy;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Level;
import java.util.logging.Logger;

public class ChunkedInputWriteErrorSimulationTest extends JerseyTest {
    private static final String EXCEPTION_MSG = "BOGUS BUFFER OVERFLOW";
    private static final AtomicReference<Throwable> caught = new AtomicReference<>(null);

    /**
     * Worker thread that repeatedly POSTs a large JSON-like payload to {@link #url} through a
     * client whose Netty connector has been instrumented with a {@link DequeOffer} hook.
     *
     * <p>All configuration is held in static fields that are (re)initialised by
     * {@link #main(DequeOffer, String[])}; every thread started by one {@code main} call shares
     * the same client. The root cause of any failed request is stored in the enclosing test's
     * {@code caught} reference.
     *
     * <p>NOTE(review): the {@link Client} built by {@link #initClient(DequeOffer)} is never
     * closed in this class.
     */
    public static class ClientThread extends Thread {

        public static AtomicInteger count = new AtomicInteger();
        public static String url;
        public static int nLoops;

        /** Number of repeated fragments that make up the request body. */
        private static final int PAYLOAD_FRAGMENTS = 10000;

        /** Request body sent with every call; it never changes, so it is built only once. */
        private static final String JSON_PAYLOAD = buildPayload();

        private static Client client;

        /**
         * Starts the requested number of threads and waits for all of them to finish.
         *
         * @param offer hook installed into the chunked-input queue of the Netty connector
         * @param args  {@code args[0]} target URL, {@code args[1]} number of threads,
         *              {@code args[2]} number of calls per thread
         * @throws InterruptedException if the calling thread is interrupted while joining
         */
        public static void main(DequeOffer offer, String[] args) throws InterruptedException {
            url = args[0];
            int nThreads = Integer.parseInt(args[1]);
            nLoops = Integer.parseInt(args[2]);
            initClient(offer);
            Thread[] threads = new Thread[nThreads];
            for (int i = 0; i < nThreads; i++) {
                threads[i] = new ClientThread();
                threads[i].start();
            }

            for (int i = 0; i < nThreads; i++) {
                threads[i].join();
            }
        }

        /** Builds the shared client with 10 second connect/read timeouts and the instrumented connector. */
        private static void initClient(DequeOffer offer) {
            ClientConfig defaultConfig = new ClientConfig();
            defaultConfig.property(ClientProperties.CONNECT_TIMEOUT, 10 * 1000);
            defaultConfig.property(ClientProperties.READ_TIMEOUT, 10 * 1000);
            defaultConfig.connectorProvider(getJerseyChunkedInputModifiedNettyConnector(offer));
            client = ClientBuilder.newBuilder()
                    .withConfig(defaultConfig)
                    .build();
        }

        /** Concatenates the fixed fragment {@link #PAYLOAD_FRAGMENTS} times between a pair of braces. */
        private static String buildPayload() {
            StringBuilder sb = new StringBuilder("{");
            for (int i = 0; i < PAYLOAD_FRAGMENTS; i++) {
                sb.append("\"fname\":\"foo\", \"lname\":\"bar\"");
            }
            sb.append("}");
            return sb.toString();
        }

        /**
         * Performs one asynchronous POST and blocks until it completes.
         *
         * <p>On failure the root cause is recorded in {@code caught}. Failures of the request are
         * logged or printed rather than propagated to the caller.
         */
        public void doCall() {
            CompletableFuture<Response> cf = invokeResponse().toCompletableFuture()
                    .whenComplete((rsp, t) -> {
                        if (t != null) {
                            // Walk down to the root cause; that is what the tests assert on.
                            Throwable root = t;
                            while (root.getCause() != null) {
                                root = root.getCause();
                            }
                            caught.set(root);
                        }
                    })
                    .handle((rsp, t) -> {
                        if (rsp != null) {
                            rsp.readEntity(String.class);
                        } else {
                            System.out.println(Thread.currentThread().getName() + " response is null");
                        }
                        return rsp;
                    }).exceptionally(t -> {
                        System.out.println("async complete. completed exceptionally " + t);
                        throw new RuntimeException(t);
                    });

            try {
                cf.get();
                System.out.println("Done call " + count.incrementAndGet());
            } catch (InterruptedException ex) {
                // Restore the interrupt flag so run() (and anything else up the stack) can see it.
                Thread.currentThread().interrupt();
                Logger.getLogger(ClientThread.class.getName()).log(Level.SEVERE, null, ex);
            } catch (ExecutionException ex) {
                Logger.getLogger(ClientThread.class.getName()).log(Level.SEVERE, null, ex);
            }
        }

        /** Issues the reactive POST of {@link #JSON_PAYLOAD} against {@link #url} with no extra headers. */
        private static CompletionStage<Response> invokeResponse() {
            WebTarget target = client.target(url);
            MultivaluedHashMap<String, Object> hdrs = new MultivaluedHashMap<>();
            Invocation.Builder builder = target.request().headers(hdrs);
            return builder.rx().method("POST", Entity.entity(JSON_PAYLOAD, MediaType.APPLICATION_JSON_TYPE));
        }

        /** Calls {@link #doCall()} {@link #nLoops} times, stopping early if this thread is interrupted. */
        @Override
        public void run() {
            for (int i = 0; i < nLoops && !Thread.currentThread().isInterrupted(); i++) {
                try {
                    doCall();
                } catch (Throwable t) {
                    throw new RuntimeException(t);
                }
            }
        }
    }

    /**
     * Server-side resource the client threads post to; it is mounted under {@code /console}.
     */
    @Path("/console")
    public static class HangingEndpoint {
        /**
         * Handles {@code POST /console/login}. The request body is accepted but not inspected.
         *
         * @param entity request body
         * @return the constant greeting {@code "Welcome"}
         */
        @POST
        @Path("/login")
        public String post(String entity) {
            final String greeting = "Welcome";
            return greeting;
        }
    }

    /**
     * Supplies the server application for this test: a {@link ResourceConfig} containing only
     * {@link HangingEndpoint}.
     */
    @Override
    protected Application configure() {
        final ResourceConfig application = new ResourceConfig(HangingEndpoint.class);
        return application;
    }

    /**
     * Runs 5 client threads x 10 calls with a queue whose {@code offer} periodically throws
     * {@link InterruptedException}, and expects the recorded root cause to carry
     * {@code EXCEPTION_MSG}. The test also fails if the client threads never return.
     */
    @Test
    public void testNoHangOnOfferInterrupt() throws InterruptedException {
        // 'caught' is static and shared by all tests; clear it so a stale value cannot satisfy the assertion.
        caught.set(null);
        String path = getBaseUri() + "console/login";
        ClientThread.main(new InterruptedExceptionOffer(), new String[] {path, "5", "10"});

        Throwable rootCause = caught.get();
        Assertions.assertNotNull(rootCause, "No request failure was recorded");
        String message = rootCause.getMessage();
        // Null-safe check so a missing message is reported as an assertion failure, not an NPE.
        Assertions.assertTrue(message != null && message.contains(EXCEPTION_MSG),
                "Unexpected root cause: " + rootCause);
    }

    /**
     * Runs 5 client threads x 10 calls with a queue whose {@code poll} periodically throws
     * {@link InterruptedException}, and expects at least one request failure to be recorded.
     * The test also fails if the client threads never return.
     */
    @Test
    public void testNoHangOnPollInterrupt() throws InterruptedException {
        // 'caught' is static and shared by all tests; without this reset a failure recorded by
        // another test would make the assertion below pass vacuously.
        caught.set(null);
        String path = getBaseUri() + "console/login";
        ClientThread.main(new DequePoll(), new String[] {path, "5", "10"});
        Assertions.assertNotNull(caught.get(), "No request failure was recorded");
    }

    /**
     * Runs 5 client threads x 10 calls with a queue whose {@code offer} periodically returns
     * {@code false}, and expects the recorded root cause to mention "Buffer overflow".
     * The test also fails if the client threads never return.
     */
    @Test
    public void testNoHangOnOfferNoData() throws InterruptedException {
        // 'caught' is static and shared by all tests; clear it so a stale value cannot satisfy the assertion.
        caught.set(null);
        String path = getBaseUri() + "console/login";
        try {
            ClientThread.main(new ReturnFalseOffer(), new String[] {path, "5", "10"});

            Throwable rootCause = caught.get();
            Assertions.assertNotNull(rootCause, "No request failure was recorded");
            String message = rootCause.getMessage();
            // Null-safe check so a missing message is reported as an assertion failure, not an NPE.
            Assertions.assertTrue(message != null && message.contains("Buffer overflow"), //JerseyChunkedInput
                    "Unexpected root cause: " + rootCause);
        } finally {
            // Sleep for the server to finish; done in finally so it also happens when an assertion fails.
            Thread.sleep(1_000L);
        }
    }

    /**
     * Hook consulted by the instrumented chunked-input queue. The signature mirrors the timed
     * {@code offer} of a blocking deque so an implementation can return {@code false} or throw
     * {@link InterruptedException} to simulate a failing queue operation.
     */
    private interface DequeOffer {
        /**
         * @param e       buffer being offered; the poll-side hook passes {@code null}
         * @param timeout timeout value handed to the queue operation
         * @param unit    unit of {@code timeout}
         * @return {@code false} to simulate a rejected offer, {@code true} otherwise
         * @throws InterruptedException to simulate an interrupted queue operation
         */
        boolean offer(ByteBuffer e, long timeout, TimeUnit unit) throws InterruptedException;
    }

    /**
     * Hook that throws {@link InterruptedException} on every tenth invocation, starting with the
     * very first one (invocations 0, 10, 20, ...), and reports success otherwise.
     */
    private static class InterruptedExceptionOffer implements DequeOffer {
        /** Number of invocations seen so far. */
        private final AtomicInteger invocations = new AtomicInteger(0);

        @Override
        public boolean offer(ByteBuffer e, long timeout, TimeUnit unit) throws InterruptedException {
            final int invocation = invocations.getAndIncrement();
            if (invocation % 10 != 0) {
                return true;
            }
            throw new InterruptedException(EXCEPTION_MSG);
        }
    }

    /**
     * Hook that returns {@code false} on invocations 1, 11, 21, ... (counting from zero) and
     * {@code true} on all others. It never throws.
     */
    private static class ReturnFalseOffer implements DequeOffer {
        /** Number of invocations seen so far. */
        private final AtomicInteger invocations = new AtomicInteger(0);

        @Override
        public boolean offer(ByteBuffer e, long timeout, TimeUnit unit) throws InterruptedException {
            final int invocation = invocations.getAndIncrement();
            return invocation % 10 != 1;
        }
    }

    /**
     * Marker subclass of {@link InterruptedExceptionOffer} with no behavior of its own.
     * The instrumented queue checks for this type to decide whether the hook is applied on
     * {@code poll} (instances of this class) or on {@code offer} (every other hook).
     */
    private static class DequePoll extends InterruptedExceptionOffer {
    }


    /**
     * Creates a {@link ConnectorProvider} whose connectors are {@link NettyConnector} instances
     * with an overridden {@code nettyEntityWriter} factory method. For each entity writer the
     * override reflectively replaces the {@code queue} field of the writer's
     * {@link JerseyChunkedInput} with a {@link LinkedBlockingDeque} that consults {@code offer}
     * before delegating to the real queue operation.
     *
     * @param offer hook to consult; a {@link DequePoll} instance is applied on {@code poll},
     *              any other implementation is applied on {@code offer}
     * @return the instrumented connector provider
     */
    private static ConnectorProvider getJerseyChunkedInputModifiedNettyConnector(DequeOffer offer) {
        return new ConnectorProvider() {
            @Override
            public Connector getConnector(Client client, Configuration runtimeConfig) {
                return new NettyConnector(client, NettyConnectorProvider.config().rw()) {
                    @Override
                    NettyEntityWriter nettyEntityWriter(
                            ClientRequest clientRequest, Channel channel, NettyConnectorProvider.Config.RW config) {
                        // Create the regular writer first; only its chunked input's queue is swapped below.
                        NettyEntityWriter wrapped = NettyEntityWriter.getInstance(
                                clientRequest, channel, () -> config.requestEntityProcessing(clientRequest));

                        JerseyChunkedInput chunkedInput = (JerseyChunkedInput) wrapped.getChunkedInput();
                        try {
                            Field field = JerseyChunkedInput.class.getDeclaredField("queue");
                            field.setAccessible(true);

                            // Clear the FINAL bit on the reflected field before assigning to it.
                            removeFinal(field);

                            field.set(chunkedInput, new LinkedBlockingDeque<ByteBuffer>() {
                                @Override
                                public boolean offer(ByteBuffer e, long timeout, TimeUnit unit) throws InterruptedException {
                                    // Offer-side hooks run first: a false result is returned without
                                    // enqueuing, and an InterruptedException from the hook propagates.
                                    if (!DequePoll.class.isInstance(offer) && !offer.offer(e, timeout, unit)) {
                                        return false;
                                    }
                                    return super.offer(e, timeout, unit);
                                }

                                @Override
                                public ByteBuffer poll(long timeout, TimeUnit unit) throws InterruptedException {
                                    // Poll-side hook: invoked with a null buffer purely so that it can
                                    // throw; its boolean result is ignored.
                                    if (DequePoll.class.isInstance(offer)) {
                                        offer.offer(null, timeout, unit);
                                    }
                                    return super.poll(timeout, unit);
                                }
                            });

                        } catch (Exception e) {
                            throw new RuntimeException(e);
                        }

                        // Hand out a dynamic proxy that forwards every call to the wrapped writer.
                        NettyEntityWriter proxy = (NettyEntityWriter) Proxy.newProxyInstance(
                                ConnectorProvider.class.getClassLoader(), new Class[]{NettyEntityWriter.class},
                                (proxy1, method, args) -> {
                                    // For a method named "readChunk" a RuntimeException from the first
                                    // attempt is swallowed and the call falls through to a second attempt.
                                    // NOTE(review): Method.invoke reports exceptions thrown by the target
                                    // as InvocationTargetException, so this catch presumably never sees
                                    // them - confirm the intent.
                                    if (method.getName().equals("readChunk")) {
                                        try {
                                            return method.invoke(wrapped, args);
                                        } catch (RuntimeException e) {
                                            // consume
                                        }
                                    }
                                    return method.invoke(wrapped, args);
                                });
                        return proxy;
                    }
                };
            }
        };
    }

    /**
     * Clears the {@code final} bit in the modifiers stored on the given reflected {@link Field}
     * object, so that the field can subsequently be assigned through reflection.
     *
     * <p>The {@code modifiers} field of {@link Field} is located by calling the private
     * {@code getDeclaredFields0} method of {@link Class} reflectively.
     *
     * @param field reflected field whose {@code final} modifier bit should be cleared
     * @throws NoSuchElementException if {@code getDeclaredFields0} or the {@code modifiers}
     *                                field cannot be found
     * @throws RuntimeException       wrapping any checked reflection failure; unchecked
     *                                exceptions propagate unchanged
     */
    public static void removeFinal(Field field) throws RuntimeException {
        try {
            // Locate Class#getDeclaredFields0 by name.
            Method declaredFieldsLookup = null;
            for (Method candidate : Class.class.getDeclaredMethods()) {
                if ("getDeclaredFields0".equals(candidate.getName())) {
                    declaredFieldsLookup = candidate;
                    break;
                }
            }
            if (declaredFieldsLookup == null) {
                throw new NoSuchElementException("No value present");
            }
            declaredFieldsLookup.setAccessible(true);

            // Ask it for the declared fields of java.lang.reflect.Field and pick "modifiers".
            Field modifiersOfField = null;
            for (Field candidate : (Field[]) declaredFieldsLookup.invoke(Field.class, false)) {
                if ("modifiers".equals(candidate.getName())) {
                    modifiersOfField = candidate;
                    break;
                }
            }
            if (modifiersOfField == null) {
                throw new NoSuchElementException("No value present");
            }
            modifiersOfField.setAccessible(true);

            final int withoutFinal = field.getModifiers() & ~Modifier.FINAL;
            modifiersOfField.setInt(field, withoutFinal);
        } catch (ReflectiveOperationException e) {
            // Only checked reflection failures are wrapped; runtime exceptions pass through as-is.
            throw new RuntimeException(e);
        }
    }

}