NioUdpTestCase.java

/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2013 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.xnio.nio.test;

import java.io.IOException;
import java.net.Inet4Address;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import junit.framework.TestCase;
import org.jboss.logging.Logger;
import org.junit.Ignore;
import org.xnio.Buffers;
import org.xnio.IoUtils;
import org.xnio.Xnio;
import org.xnio.OptionMap;
import org.xnio.ChannelListener;
import org.xnio.Options;
import org.xnio.XnioWorker;
import org.xnio.channels.MulticastMessageChannel;
import org.xnio.channels.SocketAddressBuffer;

/**
 * Tests for UDP (datagram) server channels created through the "nio" XNIO provider.
 *
 * @author <a href="mailto:david.lloyd@redhat.com">David M. Lloyd</a>
 *
 */
public final class NioUdpTestCase extends TestCase {
    /** Fixed port that the server side binds to on the loopback address. */
    private static final int SERVER_PORT = 12345;
    /** Server bind address: 127.0.0.1:{@link #SERVER_PORT}. */
    private static final InetSocketAddress SERVER_SOCKET_ADDRESS;
    /** Client bind address: 127.0.0.1 with port 0, i.e. an ephemeral port picked at bind time. */
    private static final InetSocketAddress CLIENT_SOCKET_ADDRESS;

    private static final Logger log = Logger.getLogger("TEST");

    static {
        try {
            SERVER_SOCKET_ADDRESS = new InetSocketAddress(Inet4Address.getByAddress(new byte[] {127, 0, 0, 1}), SERVER_PORT);
            CLIENT_SOCKET_ADDRESS = new InetSocketAddress(Inet4Address.getByAddress(new byte[] {127, 0, 0, 1}), 0);
        } catch (UnknownHostException e) {
            // getByAddress only throws for an address of illegal length, which a 4-byte literal cannot be
            throw new RuntimeException(e);
        }
    }

    /**
     * Runs a server-only test part on a freshly created "nio" worker and shuts that worker down afterwards.
     *
     * @param multicast whether {@link Options#MULTICAST} is enabled on the server channel
     * @param handler listener passed to {@code createUdpServer} for the new server channel
     * @param body test body, run while the server channel is open
     * @throws IOException if the worker or server channel cannot be created, or the channel fails to close
     */
    private synchronized void doServerSideTest(final boolean multicast, final ChannelListener<MulticastMessageChannel> handler, final Runnable body) throws IOException {
        final Xnio xnio = Xnio.getInstance("nio");
        final XnioWorker worker = xnio.createWorker(OptionMap.EMPTY);
        try {
            doServerSidePart(multicast, handler, body, worker);
        } finally {
            // release the worker's threads; without this every call leaks a worker into later tests
            shutdownWorker(worker);
        }
    }

    /**
     * Runs {@code body} with a UDP server channel bound to {@link #SERVER_SOCKET_ADDRESS}.
     */
    private void doServerSidePart(final boolean multicast, final ChannelListener<MulticastMessageChannel> handler, final Runnable body, final XnioWorker worker) throws IOException {
        doPart(multicast, handler, body, SERVER_SOCKET_ADDRESS, worker);
    }

    /**
     * Runs {@code body} with a UDP channel bound to {@link #CLIENT_SOCKET_ADDRESS} (ephemeral port).
     */
    private void doClientSidePart(final boolean multicast, final ChannelListener<MulticastMessageChannel> handler, final Runnable body, final XnioWorker worker) throws IOException {
        doPart(multicast, handler, body, CLIENT_SOCKET_ADDRESS, worker);
    }

    /**
     * Creates a UDP channel on {@code worker} bound to {@code bindAddress}, runs {@code body}, then closes the channel.
     * Any exception or error from {@code body} or from closing is logged and rethrown; the channel is always closed.
     *
     * @param multicast value for {@link Options#MULTICAST} on the new channel
     * @param handler listener passed to {@code createUdpServer}
     * @param body code to run while the channel is open
     * @param bindAddress local address to bind the channel to
     * @param worker worker that creates the channel
     * @throws IOException if creating or closing the channel fails
     */
    private synchronized void doPart(final boolean multicast, final ChannelListener<MulticastMessageChannel> handler, final Runnable body, final InetSocketAddress bindAddress, final XnioWorker worker) throws IOException {
        final MulticastMessageChannel server = worker.createUdpServer(bindAddress, handler, OptionMap.create(Options.MULTICAST, Boolean.valueOf(multicast)));
        try {
            body.run();
            // explicit close so that a close failure is reported; the finally below covers the failure paths
            server.close();
        } catch (RuntimeException e) {
            log.errorf(e, "Error running part");
            throw e;
        } catch (IOException e) {
            log.errorf(e, "Error running part");
            throw e;
        } catch (Error e) {
            log.errorf(e, "Error running part");
            throw e;
        } finally {
            IoUtils.safeClose(server);
        }
    }

    /**
     * Runs {@code body} with both a server channel (on {@link #SERVER_SOCKET_ADDRESS}) and a client channel
     * (on {@link #CLIENT_SOCKET_ADDRESS}) open on one shared, freshly created "nio" worker.
     * The worker is shut down afterwards.
     *
     * @throws IOException if the worker or either channel cannot be created, or a channel fails to close
     */
    private synchronized void doClientServerSide(final boolean clientMulticast, final boolean serverMulticast, final ChannelListener<MulticastMessageChannel> serverHandler, final ChannelListener<MulticastMessageChannel> clientHandler, final Runnable body) throws IOException {
        final Xnio xnio = Xnio.getInstance("nio");
        final XnioWorker worker = xnio.createWorker(OptionMap.EMPTY);
        try {
            doServerSidePart(serverMulticast, serverHandler, new Runnable() {
                public void run() {
                    try {
                        doClientSidePart(clientMulticast, clientHandler, body, worker);
                    } catch (IOException e) {
                        throw new RuntimeException(e);
                    }
                }
            }, worker);
        } finally {
            shutdownWorker(worker);
        }
    }

    /**
     * Shuts down {@code worker} and waits up to one minute for it to terminate.
     * If the wait is interrupted, the current thread's interrupt status is restored rather than lost.
     */
    private static void shutdownWorker(final XnioWorker worker) {
        worker.shutdown();
        try {
            worker.awaitTermination(1L, TimeUnit.MINUTES);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Checks that the handler passed to {@code createUdpServer} is invoked for the new channel, and that the
     * close listener it registers fires when the channel is closed.
     *
     * @param multicast whether {@link Options#MULTICAST} is enabled on the server channel
     */
    private void doServerCreate(boolean multicast) throws Exception {
        final CountDownLatch latch = new CountDownLatch(2);
        final AtomicBoolean openedOk = new AtomicBoolean(false);
        final AtomicBoolean closedOk = new AtomicBoolean(false);
        doServerSideTest(multicast, new ChannelListener<MulticastMessageChannel>() {
            public void handleEvent(final MulticastMessageChannel channel) {
                channel.getCloseSetter().set(new ChannelListener<MulticastMessageChannel>() {
                    public void handleEvent(final MulticastMessageChannel channel) {
                        closedOk.set(true);
                        latch.countDown();
                    }
                });
                log.infof("In handleEvent for %s", channel);
                openedOk.set(true);
                latch.countDown();
            }
        }, new Runnable() {
            public void run() {
            }
        });
        assertTrue(latch.await(500L, TimeUnit.MILLISECONDS));
        assertTrue(openedOk.get());
        assertTrue(closedOk.get());
    }

    public void testServerCreate() throws Exception {
        log.info("Test: testServerCreate");
        doServerCreate(false);
    }

    public void testServerCreateMulticast() throws Exception {
        log.info("Test: testServerCreateMulticast");
        doServerCreate(true);
    }

    /**
     * Sends one datagram from a client channel to the server channel and checks that it arrives intact.
     * <p>
     * Disabled: this class is a JUnit 3 {@code TestCase}, whose runner does not honor {@code org.junit.Ignore},
     * so the {@code if (true) return;} guard is what actually skips it.
     */
    @SuppressWarnings("unused")
    @Ignore /* XXX - depends on each server getting a separate thread */
    public void testClientToServerTransmitNioToNio() throws Exception {
        if (true) return;
        log.info("Test: testClientToServerTransmitNioToNio");
        final AtomicBoolean clientOK = new AtomicBoolean(false);
        final AtomicBoolean serverOK = new AtomicBoolean(false);
        final CountDownLatch startLatch = new CountDownLatch(1);
        final CountDownLatch receivedLatch = new CountDownLatch(1);
        final CountDownLatch doneLatch = new CountDownLatch(2);
        final byte[] payload = new byte[] { 10, 5, 15, 10, 100, -128, 30, 0, 0 };
        doClientServerSide(true, true, new ChannelListener<MulticastMessageChannel>() {
            public void handleEvent(final MulticastMessageChannel channel) {
                log.infof("In handleEvent for %s", channel);
                channel.getReadSetter().set(new ChannelListener<MulticastMessageChannel>() {
                    public void handleEvent(final MulticastMessageChannel channel) {
                        log.infof("In handleReadable for %s", channel);
                        try {
                            final ByteBuffer buffer = ByteBuffer.allocate(50);
                            final SocketAddressBuffer addressBuffer = new SocketAddressBuffer();
                            final int result = channel.receiveFrom(addressBuffer, buffer);
                            if (result == 0) {
                                log.infof("Whoops, spurious read notification for %s", channel);
                                channel.resumeReads();
                                return;
                            }
                            try {
                                final byte[] testPayload = new byte[payload.length];
                                Buffers.flip(buffer).get(testPayload);
                                log.infof("We received the packet on %s", channel);
                                assertTrue(Arrays.equals(testPayload, payload));
                                assertFalse(buffer.hasRemaining());
                                assertNotNull(addressBuffer.getSourceAddress());
                                try {
                                    channel.close();
                                    serverOK.set(true);
                                } finally {
                                    IoUtils.safeClose(channel);
                                }
                            } finally {
                                receivedLatch.countDown();
                                doneLatch.countDown();
                            }
                        } catch (IOException e) {
                            IoUtils.safeClose(channel);
                            throw new RuntimeException(e);
                        }
                    }
                });
                channel.resumeReads();
                startLatch.countDown();
            }
        }, new ChannelListener<MulticastMessageChannel>() {
            public void handleEvent(final MulticastMessageChannel channel) {
                log.infof("In handleEvent for %s", channel);
                channel.getWriteSetter().set(new ChannelListener<MulticastMessageChannel>() {
                    public void handleEvent(final MulticastMessageChannel channel) {
                        log.infof("In handleWritable for %s", channel);
                        try {
                            if (clientOK.get()) {
                                log.infof("Extra writable notification on %s (?!)", channel);
                            } else if (! channel.sendTo(SERVER_SOCKET_ADDRESS, ByteBuffer.wrap(payload))) {
                                log.infof("Whoops, spurious write notification for %s", channel);
                                channel.resumeWrites();
                            } else {
                                log.infof("We sent the packet on %s", channel);
                                try {
                                    assertTrue(receivedLatch.await(500000L, TimeUnit.MILLISECONDS));
                                    channel.close();
                                } finally {
                                    IoUtils.safeClose(channel);
                                }
                                clientOK.set(true);
                                doneLatch.countDown();
                            }
                        } catch (IOException e) {
                            IoUtils.safeClose(channel);
                            // report through the test logger instead of printing to stderr
                            log.errorf(e, "Error sending packet on %s", channel);
                        } catch (InterruptedException e) {
                            throw new RuntimeException(e);
                        }
                    }
                });
                try {
                    // wait until server is ready
                    assertTrue(startLatch.await(500000L, TimeUnit.MILLISECONDS));
                } catch (InterruptedException e) {
                    throw new RuntimeException(e);
                }
                channel.resumeWrites();
            }
        }, new Runnable() {
            public void run() {
                try {
                    assertTrue(doneLatch.await(500000L, TimeUnit.MILLISECONDS));
                } catch (InterruptedException e) {
                    throw new RuntimeException(e);
                }
            }
        });
        assertTrue(clientOK.get());
        assertTrue(serverOK.get());
    }
}