RedisEntraIDClusterIntegrationTests.java

package redis.clients.jedis.authentication;

import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assumptions.assumeTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.awaitility.Awaitility.await;
import static org.awaitility.Durations.*;

import java.util.Collections;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import redis.clients.authentication.core.TokenAuthConfig;
import redis.clients.authentication.entraid.EntraIDTokenAuthConfigBuilder;
import redis.clients.jedis.Connection;
import redis.clients.jedis.ConnectionPoolConfig;
import redis.clients.jedis.DefaultJedisClientConfig;
import redis.clients.jedis.EndpointConfig;
import redis.clients.jedis.HostAndPort;
import redis.clients.jedis.HostAndPorts;
import redis.clients.jedis.JedisClientConfig;
import redis.clients.jedis.JedisCluster;

/**
 * Integration tests for {@link JedisCluster} authenticating through an {@link AuthXManager}
 * configured with Microsoft Entra ID token settings.
 * <p>
 * The tests run against the endpoint named {@value #ENDPOINT_NAME}. When that endpoint cannot be
 * resolved, every test in this class is skipped.
 */
public class RedisEntraIDClusterIntegrationTests {
    private static final Logger log = LoggerFactory
            .getLogger(RedisEntraIDClusterIntegrationTests.class);

    /** Name of the Redis endpoint (resolved via {@link HostAndPorts}) these tests run against. */
    private static final String ENDPOINT_NAME = "cluster-entraid-acl";

    /** Upper bound for waiting on worker threads to stop before the cluster client is closed. */
    private static final long WORKER_SHUTDOWN_TIMEOUT_SECONDS = 5;

    private static EntraIDTestContext testCtx;
    private static EndpointConfig endpointConfig;
    private static HostAndPort hnp;

    /**
     * Resolves the Entra ID test context and the cluster endpoint once for the whole class.
     * <p>
     * If resolving the endpoint throws {@link IllegalArgumentException}, the endpoint is treated
     * as not configured and all tests are aborted via a failed assumption.
     */
    @BeforeAll
    public static void before() {
        try {
            testCtx = EntraIDTestContext.DEFAULT;
            endpointConfig = HostAndPorts.getRedisEndpoint(ENDPOINT_NAME);
            hnp = endpointConfig.getHostAndPort();
        } catch (IllegalArgumentException e) {
            log.warn("Skipping test because no Redis endpoint '{}' is configured", ENDPOINT_NAME);
            assumeTrue(false, "No Redis endpoint '" + ENDPOINT_NAME + "' is configured!");
        }
    }

    /**
     * Verifies that a cluster client configured with an Entra ID backed {@link AuthXManager} can
     * be created and can execute basic SET/GET/DEL commands.
     */
    @Test
    public void testClusterInitWithAuthXManager() {
        TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder()
                .lowerRefreshBoundMillis(1000).clientId(testCtx.getClientId())
                .secret(testCtx.getClientSecret()).authority(testCtx.getAuthority())
                .scopes(testCtx.getRedisScopes()).build();

        // Passed as JedisCluster's per-command attempts/redirections limit.
        int maxAttempts = 5;
        JedisClientConfig config = DefaultJedisClientConfig.builder()
                .authXManager(new AuthXManager(tokenAuthConfig)).build();

        // Default pool settings.
        ConnectionPoolConfig poolConfig = new ConnectionPoolConfig();

        try (JedisCluster jc = new JedisCluster(hnp, config, maxAttempts, poolConfig)) {
            assertEquals("OK", jc.set("foo", "bar"));
            assertEquals("bar", jc.get("foo"));
            assertEquals(1, jc.del("foo"));
        }
    }

    /**
     * Verifies that every connection registered with the {@link AuthXManager} is re-authenticated
     * at least twice while two worker threads keep issuing commands against the cluster.
     * <p>
     * The token expiration refresh ratio is set extremely low, so tokens are renewed within
     * milliseconds. This forces repeated re-authentication during the test window.
     */
    @Test
    public void testClusterWithReAuth() throws InterruptedException, ExecutionException {
        TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder()
                // 0.00002F makes token renewal fit into 2 seconds: each connection must be
                // re-authenticated at least twice within that window.
                // For token lifetimes between 30 minutes and 12 hours, renewal happens after
                // 36 ms up to 864 ms respectively.
                // If the received token is valid for more than 12 hours, this test will probably
                // fail and the ratio needs to be adjusted.
                .expirationRefreshRatio(0.00002F).clientId(testCtx.getClientId())
                .secret(testCtx.getClientSecret()).authority(testCtx.getAuthority())
                .scopes(testCtx.getRedisScopes()).build();

        // Spy on the manager so each connection handed to it can be swapped for a spy, letting
        // the test count reAuthenticate() calls per connection.
        AuthXManager authXManager = spy(new AuthXManager(tokenAuthConfig));

        List<Connection> connections = new CopyOnWriteArrayList<>();
        doAnswer(invocation -> {
            // Replace the argument with a spy and record it. This relies on callRealMethod()
            // seeing the modified argument array.
            Connection connection = spy((Connection) invocation.getArgument(0));
            invocation.getArguments()[0] = connection;
            connections.add(connection);
            return invocation.callRealMethod();
        }).when(authXManager).addConnection(any(Connection.class));

        JedisClientConfig config = DefaultJedisClientConfig.builder().authXManager(authXManager)
                .build();

        CountDownLatch stopWorkers = new CountDownLatch(1);
        try (JedisCluster jc = new JedisCluster(Collections.singleton(hnp), config)) {
            ExecutorService executorService = Executors.newFixedThreadPool(2);
            try {
                // Keep traffic flowing so connections get created and stay in use while tokens
                // are being renewed.
                Runnable task = () -> {
                    while (stopWorkers.getCount() > 0) {
                        assertEquals("OK", jc.set("foo", "bar"));
                    }
                };
                Future<?> task1 = executorService.submit(task);
                Future<?> task2 = executorService.submit(task);

                await().pollInterval(ONE_HUNDRED_MILLISECONDS).atMost(TWO_SECONDS)
                        .until(connections::size, greaterThanOrEqualTo(2));

                for (Connection conn : connections) {
                    await().pollInterval(ONE_HUNDRED_MILLISECONDS).atMost(TWO_SECONDS)
                            .untilAsserted(() -> verify(conn, atLeast(2)).reAuthenticate());
                }

                stopWorkers.countDown();
                // Surface any assertion failure or exception raised inside the workers.
                task1.get();
                task2.get();
            } finally {
                // Stop and drain the workers BEFORE try-with-resources closes the cluster, so
                // that on a failure path no worker is still issuing commands against a closed
                // client.
                stopWorkers.countDown();
                executorService.shutdown();
                if (!executorService.awaitTermination(WORKER_SHUTDOWN_TIMEOUT_SECONDS,
                        TimeUnit.SECONDS)) {
                    log.warn("Worker threads did not stop within {} seconds",
                            WORKER_SHUTDOWN_TIMEOUT_SECONDS);
                }
            }
        }
    }
}