ShedConnectionsCommandTest.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.zookeeper.server.admin;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import javax.servlet.http.HttpServletResponse;
import org.apache.zookeeper.server.ServerCnxnFactory;
import org.apache.zookeeper.server.ZooKeeperServer;
import org.junit.jupiter.api.Test;

public class ShedConnectionsCommandTest {

    private static final String VALID_JSON_25_PERCENT = "{\"percentage\": 25}";
    private static final String VALID_JSON_100_PERCENT = "{\"percentage\": 100}";
    private static final String VALID_JSON_1_PERCENT = "{\"percentage\": 1}";
    private static final String VALID_JSON_0_PERCENT = "{\"percentage\": 0}";

    private static final String INVALID_JSON_OVER_100_PERCENT = "{\"percentage\": 101}";

    private static final String INVALID_JSON_MISSING_FIELD = "{\"other\": 25}";
    private static final String INVALID_JSON_MALFORMED = "{\"percentage\": }";
    private static final String INVALID_JSON_EMPTY = "{}";

    @Test
    public void testValidPercentage25() {
        validateSuccessfulShedCommand(25, 50, 30, VALID_JSON_25_PERCENT, true, true);
    }

    @Test
    public void testValidPercentage100() {
        validateSuccessfulShedCommand(100, 20, 10, VALID_JSON_100_PERCENT, true, true);
    }

    @Test
    public void testValidPercentage1() {
        validateSuccessfulShedCommand(1, 100, 0, VALID_JSON_1_PERCENT, true, false);
    }

    @Test
    public void testValidPercentage0() {
        validateSuccessfulShedCommand(0, 100, 50, VALID_JSON_0_PERCENT, false, false);
    }

    @Test
    public void testInvalidPercentage101() {
        validateFailedShedCommand(INVALID_JSON_OVER_100_PERCENT, "Percentage must be between 0 and 100", true);
    }

    @Test
    public void testInvalidNullInputStream() {
        validateFailedShedCommand(null, "Request body is required", true);
    }

    @Test
    public void testEmptyJson() {
        validateFailedShedCommand(INVALID_JSON_EMPTY, "Missing required field: percentage", true);
    }

    @Test
    public void testMissingPercentageParameter() {
        validateFailedShedCommand(INVALID_JSON_MISSING_FIELD, "Missing required field: percentage", true);
    }

    @Test
    public void testMalformedJson() {
        validateFailedShedCommand(INVALID_JSON_MALFORMED, "Invalid JSON or failed to read request body", false);
    }

    @Test
    public void testOnlyInsecureConnections() {
        validateSuccessfulShedCommand(25, 40, 0, VALID_JSON_25_PERCENT, true, false);
    }

    @Test
    public void testOnlySecureConnections() {
        validateSuccessfulShedCommand(25, 0, 60, VALID_JSON_25_PERCENT, false, true);
    }

    @Test
    public void testNoConnections() {
        validateSuccessfulShedCommand(25, 0, 0, VALID_JSON_25_PERCENT, false, false);
    }

    @Test
    public void testMixedConnections() {
        validateSuccessfulShedCommand(25, 30, 20, VALID_JSON_25_PERCENT, true, true);
    }

    @Test
    public void testCommandNames() {
        final Commands.ShedConnectionsCommand command = new Commands.ShedConnectionsCommand();
        assertEquals(2, command.getNames().size());
        assertTrue(command.getNames().contains("shed"));
        assertTrue(command.getNames().contains("shed_connections"));
    }

    @Test
    public void testAuthorizationRequired() {
        final Commands.ShedConnectionsCommand command = new Commands.ShedConnectionsCommand();
        final AuthRequest authRequest = command.getAuthRequest();

        assertNotNull(authRequest);
        assertEquals(org.apache.zookeeper.ZooDefs.Perms.ALL, authRequest.getPermission());
        assertEquals(Commands.ROOT_PATH, authRequest.getPath());
    }

    private void validateSuccessfulShedCommand(
            final int expectedPercentage,
            final int insecureConnections,
            final int secureConnections,
            final String jsonInput,
            final boolean shouldCallInsecureFactory,
            final boolean shouldCallSecureFactory) {

        final Commands.ShedConnectionsCommand command = new Commands.ShedConnectionsCommand();
        final ZooKeeperServer zkServer = createMockZooKeeperServer(insecureConnections, secureConnections);
        final InputStream inputStream = new ByteArrayInputStream(jsonInput.getBytes());
        final int totalConnections = insecureConnections + secureConnections;

        final CommandResponse response = command.runPost(zkServer, inputStream);
        assertSuccessfulResponse(response, expectedPercentage, totalConnections);
        assertFactoryCalls(zkServer, expectedPercentage, shouldCallInsecureFactory, shouldCallSecureFactory);
    }

    private void validateFailedShedCommand(
            final String jsonInput,
            final String expectedError,
            final boolean exactMatch) {

        final Commands.ShedConnectionsCommand command = new Commands.ShedConnectionsCommand();
        final ZooKeeperServer zkServer = createMockZooKeeperServer(10, 10);
        final InputStream inputStream = jsonInput != null ? new ByteArrayInputStream(jsonInput.getBytes()) : null;

        final CommandResponse response = command.runPost(zkServer, inputStream);

        assertNotNull(response);
        assertEquals(HttpServletResponse.SC_BAD_REQUEST, response.getStatusCode());

        final Map<String, Object> result = response.toMap();
        final String actualError = (String) result.get("error");

        if (exactMatch) {
            assertEquals(expectedError, actualError);
        } else {
            assertTrue(actualError.contains(expectedError),
                    String.format("Expected error message to contain '%s', but was '%s'", expectedError, actualError));
        }
    }

    private void assertSuccessfulResponse(
            final CommandResponse response,
            final int expectedPercentage,
            final int totalConnections) {

        assertNotNull(response);
        assertEquals(HttpServletResponse.SC_OK, response.getStatusCode());

        final Map<String, Object> result = response.toMap();
        assertEquals(expectedPercentage, result.get("percentage_requested"));

        assertTrue(result.containsKey("connections_shed"));
        final int actualShed = (Integer) result.get("connections_shed");
        assertTrue(actualShed >= 0, "Shed count should be non-negative");
        assertTrue(actualShed <= totalConnections, "Cannot shed more than total connections");

        // For 0% and 100%, we can make exact assertions
        if (expectedPercentage == 0) {
            assertEquals(0, actualShed, "0% should shed exactly 0 connections");
        } else if (expectedPercentage == 100) {
            assertEquals(totalConnections, actualShed, "100% should shed all connections");
        }
    }

    private void assertFactoryCalls(
            final ZooKeeperServer zkServer,
            final int percentage,
            final boolean shouldCallInsecureFactory,
            final boolean shouldCallSecureFactory) {

        final ServerCnxnFactory factory = zkServer.getServerCnxnFactory();
        final ServerCnxnFactory secureFactory = zkServer.getSecureServerCnxnFactory();

        if (factory != null) {
            if (shouldCallInsecureFactory) {
                verify(factory, times(1)).shedConnections(percentage);
            } else {
                verify(factory, never()).shedConnections(anyInt());
            }
        }

        if (secureFactory != null) {
            if (shouldCallSecureFactory) {
                verify(secureFactory, times(1)).shedConnections(percentage);
            } else {
                verify(secureFactory, never()).shedConnections(anyInt());
            }
        }
    }

    private ZooKeeperServer createMockZooKeeperServer(int insecureConnections, int secureConnections) {
        final ZooKeeperServer zkServer = mock(ZooKeeperServer.class);
        final int totalConnections = insecureConnections + secureConnections;

        when(zkServer.getNumAliveConnections()).thenReturn(totalConnections);

        // Mock insecure factory
        ServerCnxnFactory factory = null;
        if (insecureConnections > 0) {
            factory = createMockServerCnxnFactory(insecureConnections);
        }
        when(zkServer.getServerCnxnFactory()).thenReturn(factory);

        // Mock secure factory
        ServerCnxnFactory secureFactory = null;
        if (secureConnections > 0) {
            secureFactory = createMockServerCnxnFactory(secureConnections);
        }
        when(zkServer.getSecureServerCnxnFactory()).thenReturn(secureFactory);

        return zkServer;
    }

    private ServerCnxnFactory createMockServerCnxnFactory(int connections) {
        final ServerCnxnFactory factory = mock(ServerCnxnFactory.class);
        when(factory.getNumAliveConnections()).thenReturn(connections);
        when(factory.shedConnections(anyInt())).thenAnswer(invocation -> {
            int percentage = invocation.getArgument(0);
            if (percentage == 0) {
                return 0;
            }
            if (percentage == 100) {
                return connections;
            }
            return (int) Math.ceil(connections * percentage / 100.0);
        });
        return factory;
    }
}