WebSocketTestClient.java

/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2014 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */
package io.undertow.websockets.utils;

import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketClientHandshakerFactory;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
import io.netty.util.CharsetUtil;
import io.netty.util.ReferenceCountUtil;

import java.net.InetSocketAddress;
import java.net.URI;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Client which can be used to Test a websocket server
 *
 * @author <a href="mailto:nmaurer@redhat.com">Norman Maurer</a>
 */
public final class WebSocketTestClient {
    private final Bootstrap bootstrap = new Bootstrap();
    private Channel ch;
    private final URI uri;
    private final WebSocketVersion version;
    private volatile boolean closed;

    private static final AtomicInteger count = new AtomicInteger();

    public WebSocketTestClient(WebSocketVersion version, URI uri) {
        this.uri = uri;
        this.version = version;
    }

    /**
     * Connect the WebSocket client
     *
     * @throws Exception
     */
    public WebSocketTestClient connect() throws Exception {
        String protocol = uri.getScheme();
        if (!"ws".equals(protocol)) {
            throw new IllegalArgumentException("Unsupported protocol: " + protocol);
        }
        final WebSocketClientHandshaker handshaker =
                WebSocketClientHandshakerFactory.newHandshaker(
                        uri, version, null, false, new DefaultHttpHeaders());
        WSClientHandler handler = new WSClientHandler(handshaker);
        EventLoopGroup group = new NioEventLoopGroup();
        bootstrap.group(group)
                .channel(NioSocketChannel.class)
                .handler(new ChannelInitializer() {
                    @Override
                    protected void initChannel(Channel channel) throws Exception {
                        ChannelPipeline p = channel.pipeline();
                        p.addLast(
                                new HttpClientCodec(),
                                new HttpObjectAggregator(8192), handler);
                    }
                });

        // Connect
        ChannelFuture future =
                bootstrap.connect(
                        new InetSocketAddress(uri.getHost(), uri.getPort()));
        future.syncUninterruptibly();
        handler.handshakeFuture.syncUninterruptibly();
        ch = future.channel();
        return this;
    }

    /**
     * Send the WebSocketFrame and call the FrameListener once a frame was received as response or
     * when an Exception was caught.
     */
    public WebSocketTestClient send(WebSocketFrame frame, final FrameListener listener) {
        ch.pipeline().addLast("responseHandler" + count.incrementAndGet(), new SimpleChannelInboundHandler<Object>() {

            @Override
            protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
                if (msg instanceof CloseWebSocketFrame) {
                    closed = true;
                }
                listener.onFrame((WebSocketFrame) msg);
                ctx.pipeline().remove(this);
            }

            @Override
            public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
                cause.printStackTrace();
                listener.onError(cause);
                ctx.pipeline().remove(this);
            }
        });
        ChannelFuture cf = ch.writeAndFlush(frame).syncUninterruptibly();
        if (!cf.isSuccess()) {
            listener.onError(cf.cause());
        }
        return this;
    }

    /**
     * Destroy the client and also close open connections if any exist
     */
    public void destroy() {
         this.destroy(false);
     }

    public void destroy(boolean dirty) {
        if (!closed && !dirty) {
            final CountDownLatch latch = new CountDownLatch(1);
            send(new CloseWebSocketFrame(), new FrameListener() {
                @Override
                public void onFrame(WebSocketFrame frame) {
                    latch.countDown();
                }

                @Override
                public void onError(Throwable t) {
                    latch.countDown();
                }
            });
            try {
                latch.await(10, TimeUnit.SECONDS);
            } catch (InterruptedException e) {
                throw new RuntimeException(e);
            }
        }
        //bootstrap.releaseExternalResources();
        if (ch != null) {
            ch.close().syncUninterruptibly();
        }
        try {
            bootstrap.group().shutdownGracefully(0, 1, TimeUnit.SECONDS).get();
        } catch (InterruptedException e) {
            throw new RuntimeException(e);
        } catch (ExecutionException e) {
            throw new RuntimeException(e);
        }
    }

    public boolean isActive() {
        if(this.ch != null) {
            return this.ch.isActive();
        }
        return false;
    }

    public boolean isOpen() {
        if(this.ch != null) {
            return this.ch.isOpen();
        }
        return false;
    }

    public boolean isWritable() {
        if(this.ch != null) {
            return this.ch.isWritable();
        }
        return false;
    }

    public interface FrameListener {
        /**
         * Is called if an WebSocketFrame was received
         */
        void onFrame(WebSocketFrame frame);

        /**
         * Is called if an error occurred
         */
        void onError(Throwable t);
    }
    private static final class WSClientHandler extends SimpleChannelInboundHandler<Object> {

        private final WebSocketClientHandshaker handshaker;
        private ChannelPromise handshakeFuture;

        WSClientHandler(WebSocketClientHandshaker handshaker) {
            super(false);
            this.handshaker = handshaker;
        }

        public ChannelFuture handshakeFuture() {
            return handshakeFuture;
        }

        @Override
        public void handlerAdded(ChannelHandlerContext ctx) {
            handshakeFuture = ctx.newPromise();
        }

        @Override
        public void channelActive(ChannelHandlerContext ctx) {
            handshaker.handshake(ctx.channel());
        }

        @Override
        protected void channelRead0(ChannelHandlerContext ctx, Object o) throws Exception {

            Channel ch = ctx.channel();

            if (!handshaker.isHandshakeComplete()) {
                handshaker.finishHandshake(ch, (FullHttpResponse) o);
                // the handshake response was processed upgrade is complete
                handshakeFuture.setSuccess();
                ReferenceCountUtil.release(o);
                return;
            }

            if (o instanceof FullHttpResponse) {
                FullHttpResponse response = (FullHttpResponse) o;
                ReferenceCountUtil.release(o);
                throw new Exception("Unexpected HttpResponse (status=" + response.getStatus() + ", content="
                        + response.content().toString(CharsetUtil.UTF_8) + ')');
            }
            ctx.fireChannelRead(o);
        }
        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
            cause.printStackTrace();
            if (!handshakeFuture.isDone()) {
                handshakeFuture.setFailure(cause);
            }
            ctx.close();
        }
    }
}