NettyServerCnxnTest.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.zookeeper.server;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.when;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelId;
import io.netty.channel.ChannelPipeline;
import io.netty.util.Attribute;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.ProtocolException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.zookeeper.AsyncCallback.DataCallback;
import org.apache.zookeeper.ClientCnxnSocketNetty;
import org.apache.zookeeper.CreateMode;
import org.apache.zookeeper.KeeperException;
import org.apache.zookeeper.ZooDefs.Ids;
import org.apache.zookeeper.ZooKeeper;
import org.apache.zookeeper.client.ZKClientConfig;
import org.apache.zookeeper.common.ClientX509Util;
import org.apache.zookeeper.common.NettyUtils;
import org.apache.zookeeper.data.Stat;
import org.apache.zookeeper.server.quorum.BufferStats;
import org.apache.zookeeper.server.quorum.LeaderZooKeeperServer;
import org.apache.zookeeper.test.ClientBase;
import org.apache.zookeeper.test.SSLAuthTest;
import org.apache.zookeeper.test.TestByteBufAllocator;
import org.apache.zookeeper.test.TestUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.mockito.Mockito;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Test verifies the behavior of NettyServerCnxn which represents a connection
 * from a client to the server.
 */
public class NettyServerCnxnTest extends ClientBase {

    private static final Logger LOG = LoggerFactory.getLogger(NettyServerCnxnTest.class);

    /**
     * Prepares the fixture for a Netty-backed server: sets the inherited {@code maxCnxns}
     * and {@code exceptionOnFailedConnect} fields, installs the test ByteBuf allocator,
     * selects {@code NettyServerCnxnFactory} through the connection-factory system property,
     * and then runs the base-class setup.
     *
     * @throws Exception if the base-class setup fails
     */
    @BeforeEach
    @Override
    public void setUp() throws Exception {
        // Inherited ClientBase fields, configured ahead of super.setUp().
        maxCnxns = 1;
        exceptionOnFailedConnect = true;

        NettyServerCnxnFactory.setTestAllocator(TestByteBufAllocator.getInstance());
        System.setProperty(
                ServerCnxnFactory.ZOOKEEPER_SERVER_CNXN_FACTORY,
                "org.apache.zookeeper.server.NettyServerCnxnFactory");

        super.setUp();
    }

    /**
     * Runs the base-class teardown, uninstalls the test allocator from
     * {@code NettyServerCnxnFactory}, and finally checks the allocator for leaks.
     *
     * <p>The allocator is uninstalled in a {@code finally} block: it is static state set in
     * {@link #setUp()}, so it must be removed even when the base-class teardown throws.
     * The leak check still runs only after a successful teardown, so a teardown failure is
     * reported as-is instead of being replaced by a follow-up leak assertion.
     *
     * @throws Exception if the base-class teardown fails
     */
    @AfterEach
    @Override
    public void tearDown() throws Exception {
        try {
            super.tearDown();
        } finally {
            NettyServerCnxnFactory.clearTestAllocator();
        }
        TestByteBufAllocator.checkForLeaks();
    }

    /**
     * Verifies channel closure: when a channel is closed the ServerCnxnFactory must drop
     * every reference to it, so that the channel is not closed twice. A duplicate close
     * may hang indefinitely because of an open Netty issue.
     *
     * <p>The test creates a znode, registers a watch on it, asks every server-side
     * connection to send a close-session, waits for the live-connection count to reach
     * zero and then checks that the watch is gone.
     *
     * @see <a href="https://issues.jboss.org/browse/NETTY-412">NETTY-412</a>
     */
    @Test
    @Timeout(value = 40)
    public void testSendCloseSession() throws Exception {
        assertTrue(serverFactory instanceof NettyServerCnxnFactory, "Didn't instantiate ServerCnxnFactory with NettyServerCnxnFactory!");

        final ZooKeeper client = createClient();
        final ZooKeeperServer server = serverFactory.getZooKeeperServer();
        final String znode = "/a";
        try {
            // A successful create proves the client is usable.
            client.create(znode, "test".getBytes(StandardCharsets.UTF_8), Ids.OPEN_ACL_UNSAFE, CreateMode.PERSISTENT);

            // exists(..., true) registers the watch counted below.
            final Stat created = client.exists(znode, true);
            assertNotNull(created, "Didn't create znode:" + znode);
            assertEquals(1, server.getZKDatabase().getDataTree().getWatchCount());

            final Iterable<ServerCnxn> openConnections = serverFactory.getConnections();
            assertEquals(1, serverFactory.getNumAliveConnections(), "Mismatch in number of live connections!");
            for (ServerCnxn open : openConnections) {
                open.sendCloseSession();
            }

            LOG.info("Waiting for the channel disconnected event");
            // Poll once per second; give up once more than CONNECTION_TIMEOUT ms have been slept.
            for (int waitedMs = 1000; serverFactory.getNumAliveConnections() != 0; waitedMs += 1000) {
                Thread.sleep(1000);
                if (waitedMs > CONNECTION_TIMEOUT) {
                    fail("The number of live connections should be 0");
                }
            }

            // Closing the connection must also have removed its watch.
            assertEquals(0, server.getZKDatabase().getDataTree().getWatchCount());
        } finally {
            client.close();
        }
    }

    /**
     * {@link #setUp()} limits the number of connections per IP to 1. Opening a second
     * client while the first one is still open is therefore expected to fail with a
     * {@link ProtocolException}.
     */
    @Test
    @Timeout(value = 40)
    public void testMaxConnectionPerIpSurpased() {
        assertTrue(serverFactory instanceof NettyServerCnxnFactory, "Did not instantiate ServerCnxnFactory with NettyServerCnxnFactory!");
        assertThrows(ProtocolException.class, () -> {
            try (final ZooKeeper first = createClient()) {
                try (final ZooKeeper second = createClient()) {
                    // Nothing to do: only the act of opening both clients is under test.
                }
            }
        });
    }

    /**
     * Checks that the server's client-response buffer statistics start at
     * {@code BufferStats.INIT_VALUE} and report a positive last buffer size once a
     * client request has been served; also reads the znode back to confirm its content.
     */
    @Test
    public void testClientResponseStatsUpdate() throws IOException, InterruptedException, KeeperException {
        final byte[] payload = "test".getBytes(StandardCharsets.UTF_8);
        try (ZooKeeper client = createClient()) {
            final ServerStats stats = serverFactory.getZooKeeperServer().serverStats();
            final BufferStats responseStats = stats.getClientResponseStats();
            assertThat(
                    "Last client response size should be initialized with INIT_VALUE",
                    responseStats.getLastBufferSize(),
                    equalTo(BufferStats.INIT_VALUE));

            client.create("/a", payload, Ids.OPEN_ACL_UNSAFE, CreateMode.PERSISTENT);

            assertThat(
                    "Last client response size should be greater than 0 after client request was performed",
                    responseStats.getLastBufferSize(),
                    greaterThan(0));

            final byte[] stored = client.getData("/a", null, null);
            assertArrayEquals(payload, stored, "unexpected data");
        }
    }

    /**
     * With one client connected, the server stats are expected to report two non-mTLS
     * local connections and no non-mTLS remote connection.
     */
    @Test
    public void testNonMTLSLocalConn() throws IOException, InterruptedException, KeeperException {
        try (ZooKeeper client = createClient()) {
            final ZooKeeperServer server = serverFactory.getZooKeeperServer();
            final ServerStats stats = server.serverStats();
            // Two local connections: the local stat connection plus this client.
            assertEquals(2, stats.getNonMTLSLocalConnCount());
            assertEquals(0, stats.getNonMTLSRemoteConnCount());
        }
    }

    /**
     * Non-secure connection against a mocked, running {@link LeaderZooKeeperServer} that
     * is backed by a real {@link ServerStats} instance, so the helper can check the
     * connection counters.
     */
    @Test
    public void testNonMTLSRemoteConn() throws Exception {
        final ServerStats.Provider statsProvider = mock(ServerStats.Provider.class);
        final ServerStats realStats = new ServerStats(statsProvider);

        final LeaderZooKeeperServer runningLeader = mock(LeaderZooKeeperServer.class);
        when(runningLeader.isRunning()).thenReturn(true);
        when(runningLeader.serverStats()).thenReturn(realStats);

        testNonMTLSRemoteConn(runningLeader, false, false);
    }

    /** Non-secure connection, early drop off, and no ZooKeeperServer attached to the factory. */
    @Test
    public void testNonMTLSRemoteConnZookKeeperServerNotReady() throws Exception {
        final ZooKeeperServer noServer = null;
        final boolean secure = false;
        final boolean earlyDrop = false;
        testNonMTLSRemoteConn(noServer, secure, earlyDrop);
    }

    /** Non-secure connection, early drop on, and no ZooKeeperServer attached to the factory. */
    @Test
    public void testNonMTLSRemoteConnZookKeeperServerNotReadyEarlyDropEnabled() throws Exception {
        final ZooKeeperServer noServer = null;
        final boolean secure = false;
        final boolean earlyDrop = true;
        testNonMTLSRemoteConn(noServer, secure, earlyDrop);
    }

    /**
     * Secure connection, early drop on, and no ZooKeeperServer attached to the factory;
     * the shared helper expects the channel to be closed exactly once in this case.
     */
    @Test
    public void testMTLSRemoteConnZookKeeperServerNotReadyEarlyDropEnabled() throws Exception {
        final ZooKeeperServer noServer = null;
        final boolean secure = true;
        final boolean earlyDrop = true;
        testNonMTLSRemoteConn(noServer, secure, earlyDrop);
    }

    /**
     * Secure connection with no ZooKeeperServer attached to the factory.
     *
     * <p>NOTE(review): despite the "EarlyDropDisabled" suffix this test passes
     * {@code earlyDrop = true}, which makes it identical to
     * {@code testMTLSRemoteConnZookKeeperServerNotReadyEarlyDropEnabled}; the disabled case
     * is not actually exercised. Flipping the flag presumably requires additional stubbing
     * in the shared helper for the secure, non-dropped path - confirm against
     * NettyServerCnxnFactory before changing it.
     */
    @Test
    public void testMTLSRemoteConnZookKeeperServerNotReadyEarlyDropDisabled() throws Exception {
        final ZooKeeperServer noServer = null;
        final boolean secure = true;
        final boolean earlyDrop = true;
        testNonMTLSRemoteConn(noServer, secure, earlyDrop);
    }

    /**
     * Fires {@code channelActive} on a fresh {@link NettyServerCnxnFactory} using a fully
     * mocked Netty channel, then checks the outcome.
     *
     * <ul>
     *   <li>With a server: the server stats must show 0 non-mTLS local and 1 non-mTLS
     *       remote connection.</li>
     *   <li>Without a server: the channel must have been closed exactly once when both
     *       {@code secure} and {@code earlyDrop} are set, and never otherwise.</li>
     * </ul>
     *
     * <p>The early-drop system property is set for the duration of the call and cleared
     * afterwards.
     *
     * @param zks server to attach to the factory, or {@code null} for none
     * @param secure value passed to {@code factory.setSecure}
     * @param earlyDrop value written to the early-drop system property
     * @throws Exception if {@code channelActive} throws
     */
    @SuppressWarnings("unchecked")
    private void testNonMTLSRemoteConn(ZooKeeperServer zks, boolean secure, boolean earlyDrop) throws Exception {
        try {
            System.setProperty(NettyServerCnxnFactory.EARLY_DROP_SECURE_CONNECTION_HANDSHAKES, String.valueOf(earlyDrop));

            final Channel channel = mock(Channel.class);
            final ChannelHandlerContext context = mock(ChannelHandlerContext.class);
            final ChannelFuture closeFuture = mock(ChannelFuture.class);
            final ChannelPipeline pipeline = mock(ChannelPipeline.class);
            final ChannelId channelId = mock(ChannelId.class);
            final Attribute attribute = mock(Attribute.class);
            final InetSocketAddress remote = new InetSocketAddress(0);

            // Wiring of the handler context and the close future back to the channel.
            when(context.channel()).thenReturn(channel);
            when(closeFuture.channel()).thenReturn(channel);

            // Everything the channel itself is asked for.
            when(channel.pipeline()).thenReturn(pipeline);
            when(channel.closeFuture()).thenReturn(closeFuture);
            when(channel.remoteAddress()).thenReturn(remote);
            when(channel.id()).thenReturn(channelId);
            Mockito.doReturn(attribute).when(channel).attr(Mockito.any());
            doNothing().when(attribute).set(Mockito.any());

            final NettyServerCnxnFactory factory = new NettyServerCnxnFactory();
            factory.setSecure(secure);
            factory.setZooKeeperServer(zks);

            factory.channelHandler.channelActive(context);

            if (zks == null) {
                // Only a secure connection with early drop enabled must have been closed.
                final int expectedCloses = (earlyDrop && secure) ? 1 : 0;
                Mockito.verify(channel, times(expectedCloses)).close();
            } else {
                assertEquals(0, zks.serverStats().getNonMTLSLocalConnCount());
                assertEquals(1, zks.serverStats().getNonMTLSRemoteConnCount());
            }
        } finally {
            System.clearProperty(NettyServerCnxnFactory.EARLY_DROP_SECURE_CONNECTION_HANDSHAKES);
        }
    }

    /**
     * Checks that a read request still completes when receive is disabled on the server
     * side and re-enabled later, in two variants: with a forced channel read while
     * throttled, and without one (request bytes wait until throttling is turned off).
     */
    @Test
    public void testServerSideThrottling() throws IOException, InterruptedException, KeeperException {
        final byte[] payload = "test".getBytes(StandardCharsets.UTF_8);
        try (ZooKeeper client = createClient()) {
            final BufferStats responseStats = serverFactory.getZooKeeperServer().serverStats().getClientResponseStats();
            assertThat(
                    "Last client response size should be initialized with INIT_VALUE",
                    responseStats.getLastBufferSize(),
                    equalTo(BufferStats.INIT_VALUE));

            client.create("/a", payload, Ids.OPEN_ACL_UNSAFE, CreateMode.PERSISTENT);

            assertThat(
                    "Last client response size should be greater than 0 after client request was performed",
                    responseStats.getLastBufferSize(),
                    greaterThan(0));

            // Variant 1: throttle, force a read while throttled, then un-throttle.
            for (final ServerCnxn open : serverFactory.cnxns) {
                final NettyServerCnxn throttled = (NettyServerCnxn) open;
                final Runnable forceRead = () -> throttled.getChannel().read();
                final Runnable resume = () -> throttled.enableRecv();

                // Stop receiving on this connection ...
                throttled.disableRecv();
                // ... force a throttled read after 1 second (this puts the read into queuedBuffer) ...
                throttled.getChannel().eventLoop().schedule(forceRead, 1, TimeUnit.SECONDS);
                // ... and lift the throttling after 2 seconds.
                throttled.getChannel().eventLoop().schedule(resume, 2, TimeUnit.SECONDS);
            }

            byte[] stored = client.getData("/a", null, null);
            assertArrayEquals(payload, stored, "unexpected data");

            // Variant 2: no throttled read. The request bytes wait in the socket input
            // buffer until throttling is turned off. Both modes need to work.
            for (final ServerCnxn open : serverFactory.cnxns) {
                final NettyServerCnxn throttled = (NettyServerCnxn) open;
                final Runnable resume = () -> throttled.enableRecv();

                // Stop receiving on this connection, then lift the throttling after 2 seconds.
                throttled.disableRecv();
                throttled.getChannel().eventLoop().schedule(resume, 2, TimeUnit.SECONDS);
            }

            stored = client.getData("/a", null, null);
            assertArrayEquals(payload, stored, "unexpected data");
        }
    }

    /** Throttling stress run in secure mode with random recv toggling. */
    @Test
    public void testEnableDisableThrottling_secure_random() throws Exception {
        final boolean secure = true;
        final boolean randomToggle = true;
        runEnableDisableThrottling(secure, randomToggle);
    }

    /** Throttling stress run in secure mode with sequential disable/enable toggling. */
    @Test
    public void testEnableDisableThrottling_secure_sequentially() throws Exception {
        final boolean secure = true;
        final boolean randomToggle = false;
        runEnableDisableThrottling(secure, randomToggle);
    }

    /** Throttling stress run in non-secure mode with random recv toggling. */
    @Test
    public void testEnableDisableThrottling_nonSecure_random() throws Exception {
        final boolean secure = false;
        final boolean randomToggle = true;
        runEnableDisableThrottling(secure, randomToggle);
    }

    /** Throttling stress run in non-secure mode with sequential disable/enable toggling. */
    @Test
    public void testEnableDisableThrottling_nonSecure_sequentially() throws Exception {
        final boolean secure = false;
        final boolean randomToggle = false;
        runEnableDisableThrottling(secure, randomToggle);
    }

    /**
     * With Netty used on both the client and the server side, every live thread whose
     * name starts with {@code NettyUtils.THREAD_POOL_NAME_PREFIX} must be a daemon thread,
     * and at least one such thread must exist.
     */
    @Test
    public void testNettyUsesDaemonThreads() throws Exception {
        assertTrue(serverFactory instanceof NettyServerCnxnFactory,
                "Didn't instantiate ServerCnxnFactory with NettyServerCnxnFactory!");

        // Switch the client to Netty as well so that client-side threads are inspected too.
        final String nettyClientSocket = ClientCnxnSocketNetty.class.getName();
        System.setProperty(ZKClientConfig.ZOOKEEPER_CLIENT_CNXN_SOCKET, nettyClientSocket);
        try {
            final ZooKeeperServer server = serverFactory.getZooKeeperServer();
            try (ZooKeeper client = createClient()) {
                // A completed create guarantees that the connection is established.
                client.create("/a", "test".getBytes(StandardCharsets.UTF_8), Ids.OPEN_ACL_UNSAFE, CreateMode.PERSISTENT);

                int nettyThreadCount = 0;
                for (Thread candidate : TestUtils.getAllThreads()) {
                    if (!candidate.getName().startsWith(NettyUtils.THREAD_POOL_NAME_PREFIX)) {
                        continue;
                    }
                    nettyThreadCount++;
                    assertTrue(candidate.isDaemon(), "All Netty threads started by ZK must daemon threads");
                }
                assertTrue(nettyThreadCount > 0, "Did not find any Netty ZK Threads");
            } finally {
                server.shutdown();
            }
        } finally {
            System.clearProperty(ZKClientConfig.ZOOKEEPER_CLIENT_CNXN_SOCKET);
        }
    }

    /**
     * Issues a large number of asynchronous reads while a background thread keeps
     * disabling and enabling receive on every server-side connection, and verifies that
     * every single request is answered successfully.
     *
     * <p>The toggling thread is stopped and joined in a {@code finally} block, so it
     * cannot outlive the test (and keep toggling recv while the client is being closed)
     * when waiting for the responses fails.
     *
     * @param secure when {@code true}, {@code SSLAuthTest.setUpSecure()} is applied before
     *        the run (and undone afterwards) and the factory is switched to secure mode
     * @param randomDisableEnable {@code true} to pick disable/enable at random for each
     *        connection every 10 ms; {@code false} to alternate disable and enable on each
     *        connection with 10 ms pauses
     * @throws Exception on any client, server or assertion failure
     */
    private void runEnableDisableThrottling(boolean secure, boolean randomDisableEnable) throws Exception {
        ClientX509Util x509Util = null;
        if (secure) {
            x509Util = SSLAuthTest.setUpSecure();
        }
        try {
            NettyServerCnxnFactory factory = (NettyServerCnxnFactory) serverFactory;
            factory.setAdvancedFlowControlEnabled(true);
            if (secure) {
                factory.setSecure(true);
            }

            final String path = "/testEnableDisableThrottling";
            try (ZooKeeper zk = createClient()) {
                zk.create(path, new byte[1], Ids.OPEN_ACL_UNSAFE, CreateMode.PERSISTENT);

                final int totalRequestsNum = 100000;
                final AtomicInteger successResponse = new AtomicInteger();
                final CountDownLatch responseReceivedLatch = new CountDownLatch(totalRequestsNum);

                // Background thread that keeps enabling and disabling recv.
                final AtomicBoolean stopped = new AtomicBoolean(false);
                final Thread enableDisableThread = randomDisableEnable
                        ? newRandomRecvToggler(stopped)
                        : newSequentialRecvToggler(stopped);
                enableDisableThread.start();
                LOG.info("started thread to enable and disable recv");

                try {
                    // Counts successes; the latch is released for every response, good or bad.
                    final DataCallback countResponse = (rc, znode, ctx, data, stat) -> {
                        if (rc == KeeperException.Code.OK.intValue()) {
                            successResponse.addAndGet(1);
                        } else {
                            LOG.info("failed response is {}", rc);
                        }
                        responseReceivedLatch.countDown();
                    };

                    // Separate thread that keeps sending requests.
                    final Thread clientThread = new Thread(() -> {
                        for (int issued = 0; issued < totalRequestsNum; issued++) {
                            zk.getData(path, null, countResponse, null);
                        }
                    });
                    clientThread.start();
                    LOG.info("started thread to issue {} async requests", totalRequestsNum);

                    // Every issued request must produce a response.
                    assertTrue(responseReceivedLatch.await(60, TimeUnit.SECONDS));
                    LOG.info("received all {} responses", totalRequestsNum);
                } finally {
                    // Always stop the toggler, also when the wait above failed.
                    stopped.set(true);
                    enableDisableThread.join();
                    LOG.info("enable and disable recv thread exited");
                }

                // wait another second for the left requests to finish
                LOG.info("waiting another 1s for the requests to go through");
                Thread.sleep(1000);
                assertEquals(totalRequestsNum, successResponse.get());
            }
        } finally {
            if (secure) {
                SSLAuthTest.clearSecureSetting(x509Util);
            }
        }
    }

    /**
     * Creates (without starting) a thread that, until {@code stopped} is set, randomly
     * disables or enables recv on each server connection every 10 ms, and enables recv on
     * all connections before it exits.
     */
    private Thread newRandomRecvToggler(AtomicBoolean stopped) {
        final Random random = new Random();
        return new Thread(() -> {
            while (!stopped.get()) {
                for (final ServerCnxn cnxn : serverFactory.cnxns) {
                    if (random.nextBoolean()) {
                        cnxn.disableRecv();
                    } else {
                        cnxn.enableRecv();
                    }
                }
                try {
                    Thread.sleep(10);
                } catch (InterruptedException ignored) {
                    // Deliberately ignored: the loop only ends through the stopped flag.
                }
            }
            // always enable the recv at end
            for (final ServerCnxn cnxn : serverFactory.cnxns) {
                cnxn.enableRecv();
            }
        });
    }

    /**
     * Creates (without starting) a thread that, until {@code stopped} is set, walks over
     * the server connections and disables then enables recv on each one, pausing 10 ms
     * after every step.
     */
    private Thread newSequentialRecvToggler(AtomicBoolean stopped) {
        return new Thread(() -> {
            while (!stopped.get()) {
                for (final ServerCnxn cnxn : serverFactory.cnxns) {
                    try {
                        cnxn.disableRecv();
                        Thread.sleep(10);
                        cnxn.enableRecv();
                        Thread.sleep(10);
                    } catch (InterruptedException ignored) {
                        // Deliberately ignored: the loop only ends through the stopped flag.
                    }
                }
            }
        });
    }
}