NioStartTLSTcpConnectionTestCase.java

/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2013 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.xnio.nio.test;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

import org.junit.Before;
import org.junit.Test;
import org.xnio.ChannelListener;
import org.xnio.IoUtils;
import org.xnio.OptionMap;
import org.xnio.Options;
import org.xnio.channels.ConnectedChannel;
import org.xnio.conduits.ConduitStreamSinkChannel;
import org.xnio.conduits.ConduitStreamSourceChannel;
import org.xnio.ssl.SslConnection;


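/**
 * Tests {@code XnioSsl} connections created with the {@link Options#SSL_STARTTLS} option enabled, where data
 * is exchanged before {@code startHandshake()} is called explicitly in the middle of the connection.
 *
 * @author <a href="mailto:frainone@redhat.com">Flavia Rainone</a>
 */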
public class NioStartTLSTcpConnectionTestCase extends NioSslTcpConnectionTestCase {

    /**
     * Enables {@link Options#SSL_STARTTLS} on both the server and the client option maps. The tests below
     * exchange data first and then call {@code startHandshake()} explicitly, relying on this option to defer
     * the handshake until that call.
     */
    @Before
    public void setStartTLSOption() {
        final OptionMap optionMap = OptionMap.create(Options.SSL_STARTTLS, true);
        super.setServerOptionMap(optionMap);
        super.setClientOptionMap(optionMap);
    }

    /**
     * Creates the task handed to {@code doConnectionTest} that asserts {@code latch} reaches zero within 500 ms,
     * i.e. that the close listeners registered via {@link #countDownOnClose} have all fired.
     *
     * @param latch latch counted down once per closed connection
     * @return a task that fails the test if the latch does not reach zero in time
     */
    private static Runnable awaitBothConnectionsClosed(final CountDownLatch latch) {
        return new Runnable() {
            public void run() {
                try {
                    assertTrue(latch.await(500L, TimeUnit.MILLISECONDS));
                } catch (InterruptedException e) {
                    // restore the interrupt flag before propagating, so callers can still observe it
                    Thread.currentThread().interrupt();
                    throw new RuntimeException(e);
                }
            }
        };
    }

    /**
     * Creates a close listener that counts {@code latch} down when the connection is closed.
     *
     * @param latch latch to count down
     * @return the close listener
     */
    private static ChannelListener<SslConnection> countDownOnClose(final CountDownLatch latch) {
        return new ChannelListener<SslConnection>() {
            public void handleEvent(final SslConnection channel) {
                latch.countDown();
            }
        };
    }

    /**
     * Client-to-server transfer that starts the handshake mid-stream.
     * <p>
     * The client writes freely until more than 100 bytes have been sent. It then waits until the server has
     * received all of them and calls {@code startHandshake()}. The server calls {@code startHandshake()} as soon
     * as it has received more than 100 bytes. The client resumes writing only after both sides have started the
     * handshake. Once more than 1000 bytes are sent, it flushes, waits for the server to receive everything,
     * shuts down writes and closes. The server closes on end-of-stream. The test passes when both connections
     * close within the latch timeout and the byte counts match.
     */
    @Test
    public void oneWayTransfer3() throws Exception {
        log.info("Test: oneWayTransfer3");
        final CountDownLatch latch = new CountDownLatch(2);
        final AtomicInteger clientSent = new AtomicInteger(0);
        final AtomicInteger serverReceived = new AtomicInteger(0);
        final AtomicBoolean clientHandshakeStarted = new AtomicBoolean(false);
        final AtomicBoolean serverHandshakeStarted = new AtomicBoolean(false);
        doConnectionTest(awaitBothConnectionsClosed(latch), new ChannelListener<SslConnection>() {
            public void handleEvent(final SslConnection connection) {
                connection.getCloseSetter().set(countDownOnClose(latch));
                connection.getSinkChannel().setWriteListener(new ChannelListener<ConduitStreamSinkChannel>() {
                    /**
                     * Gate for the client write loop. It always allows writing during the first 100 bytes. After
                     * that, it starts the client handshake once the server has received every byte sent so far, and
                     * afterwards allows writing only once the server has started its handshake too.
                     */
                    private boolean continueWriting() throws IOException {
                        if (clientSent.get() <= 100) {
                            return true;
                        }
                        if (!clientHandshakeStarted.get()) {
                            if (serverReceived.get() == clientSent.get()) {
                                connection.startHandshake();
                                log.info("client starting handshake");
                                clientHandshakeStarted.set(true);
                                return true;
                            }
                            return false;
                        }
                        return serverHandshakeStarted.get();
                    }

                    public void handleEvent(final ConduitStreamSinkChannel channel) {
                        try {
                            // a fresh buffer per write event: at most one message is written per invocation
                            final ByteBuffer buffer = ByteBuffer.allocate(100);
                            buffer.put("This Is A Test\r\n".getBytes("UTF-8")).flip();
                            int c;
                            try {
                                while (continueWriting() && (c = channel.write(buffer)) > 0) {
                                    log.info("client wrote " + (c + clientSent.get()));
                                    if (clientSent.addAndGet(c) > 1000) {
                                        // done writing: switch to a listener that flushes, then shuts down
                                        final ChannelListener<ConduitStreamSinkChannel> listener = new ChannelListener<ConduitStreamSinkChannel>() {
                                            public void handleEvent(final ConduitStreamSinkChannel channel) {
                                                try {
                                                    if (channel.flush()) {
                                                        final ChannelListener<ConduitStreamSinkChannel> listener = new ChannelListener<ConduitStreamSinkChannel>() {
                                                            public void handleEvent(final ConduitStreamSinkChannel channel) {
                                                                // really lame, but due to the way SSL shuts down...
                                                                if (serverReceived.get() == clientSent.get()) {
                                                                    try {
                                                                        log.info("client shutting down writes");
                                                                        channel.shutdownWrites();
                                                                        if (connection.isWriteShutdown()) {
                                                                            log.info("client write handler closing connection");
                                                                            connection.close();
                                                                        }
                                                                    } catch (Throwable t) {
                                                                        t.printStackTrace();
                                                                        throw new RuntimeException(t);
                                                                    }
                                                                }
                                                            }
                                                        };
                                                        channel.getWriteSetter().set(listener);
                                                        listener.handleEvent(channel);
                                                    }
                                                } catch (Throwable t) {
                                                    t.printStackTrace();
                                                    throw new RuntimeException(t);
                                                }
                                            }
                                        };
                                        channel.setWriteListener(listener);
                                        listener.handleEvent(channel);
                                        return;
                                    }
                                }
                            } catch (ClosedChannelException e) {
                                try {
                                    channel.shutdownWrites();
                                } catch (Exception ignored) {
                                    // best effort: the original ClosedChannelException is rethrown below
                                }
                                throw e;
                            }
                        } catch (Throwable t) {
                            t.printStackTrace();
                            throw new RuntimeException(t);
                        }
                    }
                });
                connection.getSinkChannel().resumeWrites();
            }
        }, new ChannelListener<SslConnection>() {
            public void handleEvent(final SslConnection connection) {
                connection.getCloseSetter().set(countDownOnClose(latch));
                connection.getSourceChannel().setReadListener(new ChannelListener<ConduitStreamSourceChannel>() {
                    public void handleEvent(final ConduitStreamSourceChannel channel) {
                        try {
                            int c;
                            while ((c = channel.read(ByteBuffer.allocate(100))) > 0) {
                                log.info("server received " +  (c + serverReceived.get()));
                                // the server starts its own handshake once more than 100 bytes have arrived
                                if (serverReceived.addAndGet(c) > 100 && !serverHandshakeStarted.get()) {
                                    connection.startHandshake();
                                    serverHandshakeStarted.set(true);
                                }
                            }
                            if (c == -1) {
                                log.info("server shutting down reads");
                                channel.shutdownReads();
                                if (connection.isReadShutdown()) {
                                    log.info("server read handler closing connection");
                                    connection.close();
                                }
                            }
                        } catch (Throwable t) {
                            t.printStackTrace();
                            throw new RuntimeException(t);
                        }
                    }
                });
                connection.getSourceChannel().resumeReads();
            }
        });
        assertEquals(clientSent.get(), serverReceived.get());
    }

    /**
     * Server-to-client transfer that starts the handshake mid-stream. It mirrors {@link #oneWayTransfer3()} with
     * the roles swapped.
     * <p>
     * The server writes freely until more than 100 bytes have been sent. It then waits until the client has
     * received all of them and calls {@code startHandshake()}. The client calls {@code startHandshake()} as soon
     * as it has received more than 100 bytes. The server resumes writing only after both sides have started the
     * handshake. Once more than 1000 bytes are sent, it flushes, waits for the client to receive everything,
     * shuts down writes and closes. The client closes on end-of-stream.
     */
    @Test
    public void oneWayTransfer4() throws Exception {
        log.info("Test: oneWayTransfer4");
        final CountDownLatch latch = new CountDownLatch(2);
        final AtomicInteger clientReceived = new AtomicInteger(0);
        final AtomicInteger serverSent = new AtomicInteger(0);
        final AtomicBoolean clientHandshakeStarted = new AtomicBoolean(false);
        final AtomicBoolean serverHandshakeStarted = new AtomicBoolean(false);
        doConnectionTest(awaitBothConnectionsClosed(latch), new ChannelListener<SslConnection>() {
            public void handleEvent(final SslConnection connection) {
                connection.getCloseSetter().set(countDownOnClose(latch));
                connection.getSourceChannel().setReadListener(new ChannelListener<ConduitStreamSourceChannel>() {
                    public void handleEvent(final ConduitStreamSourceChannel channel) {
                        try {
                            int c;
                            while ((c = channel.read(ByteBuffer.allocate(100))) > 0) {
                                log.info("client received " +  (c + clientReceived.get()));
                                // the client starts its own handshake once more than 100 bytes have arrived
                                if (clientReceived.addAndGet(c) > 100 && !clientHandshakeStarted.get()) {
                                    connection.startHandshake();
                                    clientHandshakeStarted.set(true);
                                }
                            }
                            if (c == -1) {
                                channel.shutdownReads();
                                if (connection.isReadShutdown()) {
                                    connection.close();
                                }
                            }
                        } catch (Throwable t) {
                            t.printStackTrace();
                            throw new RuntimeException(t);
                        }
                    }
                });

                connection.getSourceChannel().resumeReads();
            }
        }, new ChannelListener<SslConnection>() {
            public void handleEvent(final SslConnection connection) {
                connection.getCloseSetter().set(countDownOnClose(latch));
                connection.getSinkChannel().setWriteListener(new ChannelListener<ConduitStreamSinkChannel>() {
                    /**
                     * Gate for the server write loop. It always allows writing during the first 100 bytes. After
                     * that, it starts the server handshake once the client has received every byte sent so far, and
                     * afterwards allows writing only once the client has started its handshake too. This is the
                     * only place the server starts its handshake.
                     */
                    private boolean continueWriting() throws IOException {
                        if (serverSent.get() <= 100) {
                            return true;
                        }
                        if (!serverHandshakeStarted.get()) {
                            if (clientReceived.get() == serverSent.get()) {
                                connection.startHandshake();
                                log.info("server starting handshake");
                                serverHandshakeStarted.set(true);
                                return true;
                            }
                            return false;
                        }
                        return clientHandshakeStarted.get();
                    }

                    public void handleEvent(final ConduitStreamSinkChannel channel) {
                        try {
                            final ByteBuffer buffer = ByteBuffer.allocate(100);
                            buffer.put("This Is A Test\r\n".getBytes("UTF-8")).flip();
                            int c;
                            try {
                                while (continueWriting() && (c = channel.write(buffer)) > 0) {
                                    log.info("server wrote " + (c + serverSent.get()));
                                    if (serverSent.addAndGet(c) > 1000) {
                                        // done writing: switch to a listener that flushes, then shuts down
                                        final ChannelListener<ConduitStreamSinkChannel> listener = new ChannelListener<ConduitStreamSinkChannel>() {
                                            public void handleEvent(final ConduitStreamSinkChannel channel) {
                                                try {
                                                    if (channel.flush()) {
                                                        final ChannelListener<ConduitStreamSinkChannel> listener = new ChannelListener<ConduitStreamSinkChannel>() {
                                                            public void handleEvent(final ConduitStreamSinkChannel channel) {
                                                                // really lame, but due to the way SSL shuts down...
                                                                if (clientReceived.get() == serverSent.get()) {
                                                                    try {
                                                                        channel.shutdownWrites();
                                                                        if (connection.isWriteShutdown()) {
                                                                            connection.close();
                                                                        }
                                                                    } catch (Throwable t) {
                                                                        t.printStackTrace();
                                                                        throw new RuntimeException(t);
                                                                    }
                                                                }
                                                            }
                                                        };
                                                        channel.setWriteListener(listener);
                                                        listener.handleEvent(channel);
                                                    }
                                                } catch (Throwable t) {
                                                    t.printStackTrace();
                                                    throw new RuntimeException(t);
                                                }
                                            }
                                        };
                                        channel.getWriteSetter().set(listener);
                                        listener.handleEvent(channel);
                                        return;
                                    }
                                    // resend the same message on the next iteration
                                    buffer.rewind();
                                }
                            } catch (ClosedChannelException e) {
                                try {
                                    channel.shutdownWrites();
                                } catch (Exception ignored) {
                                    // best effort: the original ClosedChannelException is rethrown below
                                }
                                throw e;
                            }
                        } catch (Throwable t) {
                            t.printStackTrace();
                            throw new RuntimeException(t);
                        }
                    }
                });
                connection.getSinkChannel().resumeWrites();
            }
        });
        assertEquals(serverSent.get(), clientReceived.get());
    }

    /**
     * Bidirectional transfer in which both sides start the handshake mid-stream.
     * <p>
     * Each side writes freely until more than 100 bytes have been sent. It then starts its handshake once:
     * <ul>
     * <li>it has sent more than 100 bytes,</li>
     * <li>the peer has received all of them, and</li>
     * <li>the peer has also sent more than 100 bytes, all of which have been received.</li>
     * </ul>
     * Each reader stops reading after 101 bytes until its own side's handshake has started. After more than
     * 1000 bytes, each writer flushes and shuts down writes. Each side closes its connection on end-of-stream.
     */
    @Test
    public void twoWayTransferWithHandshake() throws Exception {
        log.info("Test: twoWayTransferWithHandshake");
        final CountDownLatch latch = new CountDownLatch(2);
        final AtomicInteger clientSent = new AtomicInteger(0);
        final AtomicInteger clientReceived = new AtomicInteger(0);
        final AtomicInteger serverSent = new AtomicInteger(0);
        final AtomicInteger serverReceived = new AtomicInteger(0);
        final AtomicBoolean clientHandshakeStarted = new AtomicBoolean(false);
        final AtomicBoolean serverHandshakeStarted = new AtomicBoolean(false);
        doConnectionTest(awaitBothConnectionsClosed(latch), new ChannelListener<SslConnection>() {
            public void handleEvent(final SslConnection connection) {
                connection.getCloseSetter().set(countDownOnClose(latch));
                final ConduitStreamSourceChannel sourceChannel = connection.getSourceChannel();
                sourceChannel.setReadListener(new ChannelListener<ConduitStreamSourceChannel>() {

                    /** Keep reading until 101+ bytes have arrived, then pause until the client handshake has started. */
                    private boolean continueReading() throws IOException {
                        return clientHandshakeStarted.get() || clientReceived.get() < 101;
                    }

                    public void handleEvent(final ConduitStreamSourceChannel sourceChannel) {
                        log.info("client handle read events");
                        try {
                            int c = 0;
                            while (continueReading() && (c = sourceChannel.read(ByteBuffer.allocate(100))) > 0) {
                                log.info("client received: "+ (clientReceived.get() + c));
                                clientReceived.addAndGet(c);
                            }
                            if (c == -1) {
                                log.info("client shutdown reads");
                                connection.close();
                            }
                        } catch (Throwable t) {
                            // rethrow like every other listener in this test instead of silently swallowing
                            t.printStackTrace();
                            throw new RuntimeException(t);
                        }
                    }
                });
                final ConduitStreamSinkChannel sinkChannel = connection.getSinkChannel();
                sinkChannel.setWriteListener(new ChannelListener<ConduitStreamSinkChannel>() {
                    /**
                     * Gate for the client write loop. It always allows writing during the first 100 bytes and after
                     * the client handshake has started. Otherwise it starts the handshake once both directions have
                     * carried more than 100 bytes and each peer has received everything sent to it.
                     */
                    private boolean continueWriting() throws IOException {
                        if (clientSent.get() <= 100 || clientHandshakeStarted.get()) {
                            return true;
                        }
                        if (serverReceived.get() == clientSent.get() && serverSent.get() > 100
                                && clientReceived.get() == serverSent.get()) {
                            connection.startHandshake();
                            log.info("client starting handshake");
                            clientHandshakeStarted.set(true);
                            return true;
                        }
                        return false;
                    }

                    public void handleEvent(final ConduitStreamSinkChannel sinkChannel) {
                        try {
                            final ByteBuffer buffer = ByteBuffer.allocate(100);
                            buffer.put("This Is A Test\r\n".getBytes("UTF-8")).flip();
                            int c;
                            try {
                                while (continueWriting() && (c = sinkChannel.write(buffer)) > 0) {
                                    log.info("clientSent: " + (clientSent.get() + c));
                                    if (clientSent.addAndGet(c) > 1000) {
                                        // done writing: replace this listener with one that flushes, then shuts down
                                        sinkChannel.setWriteListener(new ChannelListener<ConduitStreamSinkChannel>() {
                                            public void handleEvent(final ConduitStreamSinkChannel sinkChannel) {
                                                try {
                                                    if (sinkChannel.flush()) {
                                                        log.info("client shutdown writes on " + sinkChannel);
                                                        sinkChannel.shutdownWrites();
                                                    }
                                                } catch (Throwable t) {
                                                    t.printStackTrace();
                                                    throw new RuntimeException(t);
                                                }
                                            }
                                        });
                                        return;
                                    }
                                    // resend the same message on the next iteration
                                    buffer.rewind();
                                }
                            } catch (ClosedChannelException e) {
                                try {
                                    sinkChannel.shutdownWrites();
                                } catch (Exception ignored) {
                                    // best effort: the original ClosedChannelException is rethrown below
                                }
                                throw e;
                            }
                        } catch (Throwable t) {
                            t.printStackTrace();
                            throw new RuntimeException(t);
                        }
                    }
                });

                sourceChannel.resumeReads();
                sinkChannel.resumeWrites();
            }
        }, new ChannelListener<SslConnection>() {
            public void handleEvent(final SslConnection connection) {
                connection.getCloseSetter().set(countDownOnClose(latch));
                final ConduitStreamSourceChannel sourceChannel = connection.getSourceChannel();
                sourceChannel.setReadListener(new ChannelListener<ConduitStreamSourceChannel>() {
                    /** Keep reading until 101+ bytes have arrived, then pause until the server handshake has started. */
                    private boolean continueReading() throws IOException {
                        return serverHandshakeStarted.get() || serverReceived.get() < 101;
                    }

                    public void handleEvent(final ConduitStreamSourceChannel sourceChannel) {
                        try {
                            int c = 0;
                            while (continueReading() && (c = sourceChannel.read(ByteBuffer.allocate(100))) > 0) {
                                log.info("server received: "+ (serverReceived.get() + c));
                                serverReceived.addAndGet(c);
                            }
                            if (c == -1) {
                                log.info("server shutdown reads");
                                connection.close();
                            }
                        } catch (Throwable t) {
                            t.printStackTrace();
                            throw new RuntimeException(t);
                        }
                    }
                });
                final ConduitStreamSinkChannel sinkChannel = connection.getSinkChannel();
                sinkChannel.setWriteListener(new ChannelListener<ConduitStreamSinkChannel>() {

                    /**
                     * Gate for the server write loop. It always allows writing during the first 100 bytes and after
                     * the server handshake has started. Otherwise it starts the handshake once both directions have
                     * carried more than 100 bytes and each peer has received everything sent to it.
                     */
                    private boolean continueWriting() throws IOException {
                        if (serverSent.get() <= 100 || serverHandshakeStarted.get()) {
                            return true;
                        }
                        if (clientReceived.get() == serverSent.get() && clientSent.get() > 100
                                && serverReceived.get() == clientSent.get()) {
                            connection.startHandshake();
                            log.info("server starting handshake");
                            serverHandshakeStarted.set(true);
                            return true;
                        }
                        return false;
                    }

                    public void handleEvent(final ConduitStreamSinkChannel sinkChannel) {
                        try {
                            // a fresh buffer per write event: at most one message is written per invocation
                            final ByteBuffer buffer = ByteBuffer.allocate(100);
                            buffer.put("This Is A Test\r\n".getBytes("UTF-8")).flip();
                            int c;
                            try {
                                while (continueWriting() && (c = sinkChannel.write(buffer)) > 0) {
                                    log.info("server sent: "+ (serverSent.get() + c));
                                    if (serverSent.addAndGet(c) > 1000) {
                                        // done writing: replace this listener with one that flushes, then shuts down
                                        sinkChannel.setWriteListener(new ChannelListener<ConduitStreamSinkChannel>() {
                                            public void handleEvent(final ConduitStreamSinkChannel sinkChannel) {
                                                try {
                                                    if (sinkChannel.flush()) {
                                                        log.info("server shutdown writes");
                                                        sinkChannel.shutdownWrites();
                                                    }
                                                } catch (Throwable t) {
                                                    t.printStackTrace();
                                                    throw new RuntimeException(t);
                                                }
                                            }
                                        });
                                        return;
                                    }
                                }
                            } catch (ClosedChannelException e) {
                                try {
                                    sinkChannel.shutdownWrites();
                                } catch (Exception ignored) {
                                    // best effort: the original ClosedChannelException is rethrown below
                                }
                                throw e;
                            }
                        } catch (Throwable t) {
                            t.printStackTrace();
                            throw new RuntimeException(t);
                        }
                    }
                });
                sourceChannel.resumeReads();
                sinkChannel.resumeWrites();
            }
        });
        assertEquals(serverSent.get(), clientReceived.get());
        assertEquals(clientSent.get(), serverReceived.get());
    }
}