ExtendedExtensionTest.java

/*
 * Copyright (c) 2014, 2017 Oracle and/or its affiliates. All rights reserved.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License v. 2.0, which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * This Source Code may also be made available under the following Secondary
 * Licenses when the conditions for such availability set forth in the
 * Eclipse Public License v. 2.0 are satisfied: GNU General Public License,
 * version 2 with the GNU Classpath Exception, which is available at
 * https://www.gnu.org/software/classpath/license.html.
 *
 * SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
 */

package org.glassfish.tyrus.test.e2e.appconfig;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import javax.websocket.ClientEndpointConfig;
import javax.websocket.CloseReason;
import javax.websocket.DeploymentException;
import javax.websocket.EncodeException;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.Extension;
import javax.websocket.HandshakeResponse;
import javax.websocket.MessageHandler;
import javax.websocket.Session;
import javax.websocket.server.ServerApplicationConfig;
import javax.websocket.server.ServerEndpointConfig;

import org.glassfish.tyrus.client.ClientManager;
import org.glassfish.tyrus.core.extension.ExtendedExtension;
import org.glassfish.tyrus.core.frame.Frame;
import org.glassfish.tyrus.server.Server;
import org.glassfish.tyrus.test.tools.TestContainer;

import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

/**
 * End-to-end tests for {@link ExtendedExtension} support.
 * <p>
 * Every test extension XORs exactly one byte (selected by its {@code index}) of each non-control frame with
 * {@link TestExtendedExtension#MASK}, both on the way in and on the way out. The server endpoints echo the received
 * message back only when the expected leading bytes are masked. The client succeeds only when the echoed message
 * equals the original payload again. {@link #extendedExtensionOrderingTest()} additionally verifies, from inside the
 * extensions, the order in which multiple extensions are applied.
 *
 * @author Pavel Bucek (pavel.bucek at oracle.com)
 */
public class ExtendedExtensionTest extends TestContainer {

    public ExtendedExtensionTest() {
        this.setContextPath("/e2e-test-appconfig");
    }

    /**
     * Deploys the two programmatic endpoints used by the tests, each with its own server-side extensions.
     */
    public static class ExtendedExtensionApplicationConfig implements ServerApplicationConfig {

        @Override
        public Set<ServerEndpointConfig> getEndpointConfigs(Set<Class<? extends Endpoint>> endpointClasses) {
            Set<ServerEndpointConfig> endpointConfigs = new HashSet<ServerEndpointConfig>();
            // server-side extension masks byte 1 (the client-side one in extendedExtensionTest masks byte 0)
            endpointConfigs.add(
                    ServerEndpointConfig.Builder
                            .create(ExtendedExtensionEndpoint.class, "/extendedExtensionEndpoint")
                            .extensions(Collections.<Extension>singletonList(new TestExtendedExtension(1)))
                            .build());
            // server-side extensions mask bytes 2 and 3 (the client-side ones in the ordering test mask 0 and 1)
            endpointConfigs.add(
                    ServerEndpointConfig.Builder
                            .create(ExtendedExtensionOrderedEndpoint.class, "/extendedExtensionOrderedEndpoint")
                            .extensions(Arrays.<Extension>asList(
                                    new TestExtendedServerExtension(2, "ext1"),
                                    new TestExtendedServerExtension(3, "ext2")))
                            .build());
            return endpointConfigs;
        }

        @Override
        public Set<Class<?>> getAnnotatedEndpointClasses(Set<Class<?>> scanned) {
            return Collections.<Class<?>>emptySet();
        }
    }

    /**
     * Holds the shared test payload.
     * <p>
     * {@link #MESSAGE} cannot be declared directly in {@link ExtendedExtensionTest}, because its superclass
     * {@link TestContainer} is not available at runtime on the server side, which needs the payload as well.
     */
    private static final class Constants {
        static final byte[] MESSAGE = {'h', 'e', 'l', 'l', 'o'};

        private Constants() {
        }
    }

    @Test
    public void extendedExtensionTest() throws DeploymentException {

        Server server = startServer(ExtendedExtensionApplicationConfig.class);
        final CountDownLatch messageLatch = new CountDownLatch(1);

        try {
            final List<Extension> extensions = new ArrayList<Extension>();
            final TestExtendedExtension clientExtension = new TestExtendedExtension(0);
            extensions.add(clientExtension);

            final ClientEndpointConfig clientConfiguration =
                    ClientEndpointConfig.Builder.create().extensions(extensions)
                                                .configurator(new LoggingClientEndpointConfigurator()).build();

            ClientManager client = createClient();
            final Session session = client.connectToServer(createSendAndVerifyEndpoint(messageLatch),
                                                           clientConfiguration,
                                                           getURI("/extendedExtensionEndpoint"));

            assertEquals(1, session.getNegotiatedExtensions().size());
            // the negotiated extension is expected to be the very instance configured on the client
            assertEquals(clientExtension, session.getNegotiatedExtensions().get(0));

            assertTrue(messageLatch.await(3, TimeUnit.SECONDS));
        } catch (Exception e) {
            // the cause (with its stack trace) is preserved in the rethrown exception
            throw new RuntimeException(e.getMessage(), e);
        } finally {
            stopServer(server);
        }
    }

    @Test
    public void extendedExtensionOrderingTest() throws DeploymentException {

        Server server = startServer(ExtendedExtensionApplicationConfig.class);
        final CountDownLatch messageLatch = new CountDownLatch(1);

        try {
            final List<Extension> extensions = new ArrayList<Extension>();
            final TestExtendedExtension clientExtension1 = new TestExtendedClientExtension(0, "ext1");
            final TestExtendedExtension clientExtension2 = new TestExtendedClientExtension(1, "ext2");
            extensions.add(clientExtension1);
            extensions.add(clientExtension2);

            final ClientEndpointConfig clientConfiguration =
                    ClientEndpointConfig.Builder
                            .create().extensions(extensions)
                            .configurator(new LoggingClientEndpointConfigurator()).build();

            ClientManager client = createClient();
            final Session session = client.connectToServer(createSendAndVerifyEndpoint(messageLatch),
                                                           clientConfiguration,
                                                           getURI("/extendedExtensionOrderedEndpoint"));

            // negotiated extensions must keep the order in which they were configured
            assertEquals(2, session.getNegotiatedExtensions().size());
            assertEquals(clientExtension1, session.getNegotiatedExtensions().get(0));
            assertEquals(clientExtension2, session.getNegotiatedExtensions().get(1));

            assertTrue(messageLatch.await(3, TimeUnit.SECONDS));
        } catch (Exception e) {
            // the cause (with its stack trace) is preserved in the rethrown exception
            throw new RuntimeException(e.getMessage(), e);
        } finally {
            stopServer(server);
        }
    }

    /**
     * Creates a client endpoint that sends {@link Constants#MESSAGE} as soon as the session is opened and counts
     * down {@code messageLatch} once exactly that payload is received back.
     *
     * @param messageLatch latch released when the echoed message equals {@link Constants#MESSAGE}.
     * @return new client endpoint instance.
     */
    private static Endpoint createSendAndVerifyEndpoint(final CountDownLatch messageLatch) {
        return new Endpoint() {
            @Override
            public void onOpen(Session session, EndpointConfig config) {
                session.addMessageHandler(new MessageHandler.Whole<byte[]>() {
                    @Override
                    public void onMessage(byte[] message) {
                        System.out.println("client onMessage.");
                        if (Arrays.equals(Constants.MESSAGE, message)) {
                            messageLatch.countDown();
                        }
                    }
                });

                try {
                    // send a copy so that nothing along the send path can modify the shared constant
                    session.getBasicRemote().sendObject(Constants.MESSAGE.clone());
                } catch (IOException e) {
                    e.printStackTrace();
                } catch (EncodeException e) {
                    e.printStackTrace();
                }
            }
        };
    }

    /**
     * Server endpoint which echoes a binary message back to the client, but only when it is
     * {@link Constants#MESSAGE} with its first {@link #maskedPrefixLength()} bytes masked and the remaining bytes
     * untouched.
     */
    public abstract static class MaskVerifyingEchoEndpoint extends Endpoint {

        /**
         * Returns the number of leading payload bytes expected to be masked when a message reaches this endpoint.
         *
         * @return number of leading masked bytes.
         */
        protected abstract int maskedPrefixLength();

        @Override
        public void onOpen(final Session session, EndpointConfig config) {
            print("onOpen " + session);
            session.addMessageHandler(new MessageHandler.Whole<byte[]>() {
                @Override
                public void onMessage(byte[] message) {
                    try {
                        print("server onMessage.");

                        if (TestExtendedExtension.isMessageMaskedUpTo(message, maskedPrefixLength())) {
                            session.getBasicRemote().sendObject(message);
                        }
                    } catch (IOException e) {
                        e.printStackTrace();
                    } catch (EncodeException e) {
                        e.printStackTrace();
                    }
                }
            });
        }

        @Override
        public void onClose(Session session, CloseReason closeReason) {
            print("onClose " + session);
        }

        @Override
        public void onError(Session session, Throwable thr) {
            print("onError " + session);
            thr.printStackTrace();
        }

        private void print(String s) {
            System.out.println(this.getClass().getName() + " " + s);
        }
    }

    /**
     * Echo endpoint for {@link #extendedExtensionTest()}: bytes 0 (client extension) and 1 (server extension) are
     * expected to be masked on arrival.
     */
    public static class ExtendedExtensionEndpoint extends MaskVerifyingEchoEndpoint {

        @Override
        protected int maskedPrefixLength() {
            return 2;
        }
    }

    /**
     * Extended extension which XORs the byte at position {@code index} of every non-control frame with
     * {@link #MASK}, for both incoming and outgoing frames. Applying the transformation twice to the same byte
     * restores it.
     */
    public static class TestExtendedExtension implements ExtendedExtension {

        public static final byte MASK = 0x55;
        public static final String NAME = "TestExtendedExtension";

        protected final int index;
        private final String name;

        /**
         * Creates an extension named {@link #NAME}.
         *
         * @param index position of the payload byte this extension masks.
         */
        public TestExtendedExtension(int index) {
            this(index, NAME);
        }

        /**
         * @param index position of the payload byte this extension masks.
         * @param name  extension name used during negotiation.
         */
        public TestExtendedExtension(int index, String name) {
            this.index = index;
            this.name = name;
        }

        @Override
        public Frame processIncoming(ExtendedExtension.ExtensionContext context, Frame frame) {
            return maskFrame(frame);
        }

        @Override
        public Frame processOutgoing(ExtendedExtension.ExtensionContext context, Frame frame) {
            return maskFrame(frame);
        }

        /**
         * Returns a frame whose payload byte at {@link #index} is XORed with {@link #MASK}. Control frames and
         * payloads too short to contain that byte are returned unchanged.
         */
        private Frame maskFrame(Frame frame) {
            if (frame.isControlFrame()) {
                return frame;
            }

            final byte[] payloadData = frame.getPayloadData();
            if (payloadData == null || payloadData.length <= index) {
                return frame;
            }

            // mask a copy, so the array backing the original frame is never modified
            final byte[] maskedPayload = payloadData.clone();
            maskedPayload[index] = (byte) (maskedPayload[index] ^ MASK);
            return Frame.builder(frame).payloadData(maskedPayload).build();
        }

        /**
         * Checks whether {@code payload} is {@link Constants#MESSAGE} with its first {@code maskedCount} bytes
         * XORed with {@link #MASK} and all remaining bytes unchanged.
         * <p>
         * Used on the server side too, so it must not depend on JUnit.
         *
         * @param payload     payload to check, may be {@code null}.
         * @param maskedCount number of leading bytes expected to be masked.
         * @return {@code true} if the payload matches the expected pattern, {@code false} otherwise.
         */
        static boolean isMessageMaskedUpTo(byte[] payload, int maskedCount) {
            final byte[] expected = Constants.MESSAGE;
            if (payload == null || payload.length != expected.length) {
                return false;
            }

            for (int i = 0; i < expected.length; i++) {
                final byte expectedByte = i < maskedCount ? (byte) (expected[i] ^ MASK) : expected[i];
                if (payload[i] != expectedByte) {
                    return false;
                }
            }
            return true;
        }

        @Override
        public List<Extension.Parameter> onExtensionNegotiation(ExtendedExtension.ExtensionContext context,
                                                                List<Extension.Parameter> requestedParameters) {
            print("onExtensionNegotiation :: " + context + " :: " + requestedParameters);
            return requestedParameters;
        }

        @Override
        public void onHandshakeResponse(ExtensionContext context, List<Parameter> responseParameters) {
            print("onHandshakeResponse :: " + context + " :: " + responseParameters);
        }

        @Override
        public String getName() {
            print("getName: " + name);
            return name;
        }

        /**
         * This extension has no parameters; an empty list is returned instead of {@code null}.
         */
        @Override
        public List<Parameter> getParameters() {
            return Collections.<Extension.Parameter>emptyList();
        }

        @Override
        public void destroy(ExtensionContext context) {
            print("destroy :: " + context);
        }

        private void print(String s) {
            // log the instance name, so that multiple configured extensions can be told apart
            System.out.println("##### " + name + " " + s);
        }
    }

    /**
     * Client configurator which prints the handshake request and response headers to standard output.
     */
    public static class LoggingClientEndpointConfigurator extends ClientEndpointConfig.Configurator {
        @Override
        public void beforeRequest(Map<String, List<String>> headers) {
            System.out.println("##### beforeRequest");
            System.out.println(headers);
            System.out.println();
        }

        @Override
        public void afterResponse(HandshakeResponse hr) {
            System.out.println("##### afterResponse");
            System.out.println(hr.getHeaders());
            System.out.println();
        }
    }

    /**
     * Client-side extension (index 0 = "ext1", index 1 = "ext2") which additionally verifies the order in which
     * the extensions are applied. Runs only on the client, so JUnit assertions may be used here.
     */
    public static class TestExtendedClientExtension extends TestExtendedExtension {
        public TestExtendedClientExtension(int index, String name) {
            super(index, name);
        }

        @Override
        public Frame processIncoming(ExtensionContext context, Frame frame) {
            if (frame.isControlFrame()) {
                return frame;
            }

            final Frame processed = super.processIncoming(context, frame);
            final byte[] payload = processed.getPayloadData();
            if (index == 0) {
                // the extension with index 1 must not have processed this frame yet: byte 1 is still masked
                assertEquals(Constants.MESSAGE[1] ^ MASK, payload[1]);
            } else if (index == 1) {
                // the extension with index 0 must already have restored byte 0
                assertEquals(Constants.MESSAGE[0], payload[0]);
            } else {
                throw new IllegalArgumentException("Unexpected extension index: " + index);
            }
            return processed;
        }

        @Override
        public Frame processOutgoing(ExtensionContext context, Frame frame) {
            if (frame.isControlFrame()) {
                return frame;
            }

            // bytes [0, index) must already be masked by the previously applied extensions, the rest untouched
            if (!isMessageMaskedUpTo(frame.getPayloadData(), index)) {
                throw new IllegalArgumentException("Unexpected outgoing payload for extension index " + index);
            }
            return super.processOutgoing(context, frame);
        }
    }

    /**
     * Server-side extension (index 2 = "ext1", index 3 = "ext2") which additionally verifies the order in which
     * the extensions are applied. Runs on the server, so it must not use JUnit.
     */
    public static class TestExtendedServerExtension extends TestExtendedExtension {
        public TestExtendedServerExtension(int index, String name) {
            super(index, name);
        }

        @Override
        public Frame processIncoming(ExtensionContext context, Frame frame) {
            if (frame.isControlFrame()) {
                return frame;
            }

            // bytes [0, index) must already be masked (by the client and previous server extensions)
            if (!isMessageMaskedUpTo(frame.getPayloadData(), index)) {
                throw new IllegalArgumentException("Unexpected incoming payload for extension index " + index);
            }
            return super.processIncoming(context, frame);
        }

        @Override
        public Frame processOutgoing(ExtensionContext context, Frame frame) {
            if (frame.isControlFrame()) {
                return frame;
            }

            // no junit on server side.
            final byte[] payload = frame.getPayloadData();
            if (index == 2) {
                // the extension with index 3 must not have processed this frame yet: byte 3 is still masked
                if (payload[3] != (Constants.MESSAGE[3] ^ MASK)) {
                    throw new IllegalArgumentException("Byte 3 is expected to be still masked");
                }
            } else if (index == 3) {
                // the extension with index 2 must already have restored byte 2
                if (payload[2] != Constants.MESSAGE[2]) {
                    throw new IllegalArgumentException("Byte 2 is expected to be already unmasked");
                }
            } else {
                throw new IllegalArgumentException("Unexpected extension index: " + index);
            }

            return super.processOutgoing(context, frame);
        }
    }

    /**
     * Echo endpoint for {@link #extendedExtensionOrderingTest()}: bytes 0 and 1 (client extensions) and 2 and 3
     * (server extensions) are expected to be masked on arrival.
     */
    public static class ExtendedExtensionOrderedEndpoint extends MaskVerifyingEchoEndpoint {

        @Override
        protected int maskedPrefixLength() {
            return 4;
        }
    }
}