TokenBasedAuthenticationClusterIntegrationTests.java
package redis.clients.jedis.authentication;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assumptions.assumeTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.awaitility.Awaitility.await;
import static org.awaitility.Durations.*;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import redis.clients.authentication.core.IdentityProvider;
import redis.clients.authentication.core.IdentityProviderConfig;
import redis.clients.authentication.core.SimpleToken;
import redis.clients.authentication.core.Token;
import redis.clients.authentication.core.TokenAuthConfig;
import redis.clients.jedis.Connection;
import redis.clients.jedis.ConnectionPoolConfig;
import redis.clients.jedis.DefaultJedisClientConfig;
import redis.clients.jedis.EndpointConfig;
import redis.clients.jedis.HostAndPort;
import redis.clients.jedis.HostAndPorts;
import redis.clients.jedis.JedisClientConfig;
import redis.clients.jedis.JedisCluster;
public class TokenBasedAuthenticationClusterIntegrationTests {
private static final Logger log = LoggerFactory
.getLogger(TokenBasedAuthenticationClusterIntegrationTests.class);
private static EndpointConfig endpointConfig;
private static HostAndPort hnp;
@BeforeAll
public static void before() {
try {
endpointConfig = HostAndPorts.getRedisEndpoint("cluster");
hnp = endpointConfig.getHostAndPort();
} catch (IllegalArgumentException e) {
log.warn("Skipping test because no Redis endpoint is configured");
assumeTrue(false, "No Redis endpoint 'cluster' is configured!");
}
}
@Test
public void testClusterInitWithAuthXManager() {
IdentityProviderConfig idpConfig = new IdentityProviderConfig() {
@Override
public IdentityProvider getProvider() {
return new IdentityProvider() {
@Override
public Token requestToken() {
return new SimpleToken(endpointConfig.getUsername(),
endpointConfig.getPassword() == null ? ""
: endpointConfig.getPassword(),
System.currentTimeMillis() + 5 * 1000, System.currentTimeMillis(),
null);
}
};
}
};
AuthXManager manager = new AuthXManager(TokenAuthConfig.builder()
.lowerRefreshBoundMillis(1000).tokenRequestExecTimeoutInMs(3000)
.identityProviderConfig(idpConfig).build());
int defaultDirections = 5;
JedisClientConfig config = DefaultJedisClientConfig.builder().authXManager(manager).build();
ConnectionPoolConfig DEFAULT_POOL_CONFIG = new ConnectionPoolConfig();
try (JedisCluster jc = new JedisCluster(hnp, config, defaultDirections,
DEFAULT_POOL_CONFIG)) {
assertEquals("OK", jc.set("foo", "bar"));
assertEquals("bar", jc.get("foo"));
assertEquals(1, jc.del("foo"));
}
}
@Test
public void testClusterWithReAuth() throws InterruptedException, ExecutionException {
IdentityProviderConfig idpConfig = new IdentityProviderConfig() {
@Override
public IdentityProvider getProvider() {
return new IdentityProvider() {
@Override
public Token requestToken() {
return new SimpleToken(endpointConfig.getUsername(),
endpointConfig.getPassword() == null ? ""
: endpointConfig.getPassword(),
System.currentTimeMillis() + 5 * 1000, System.currentTimeMillis(),
null);
}
};
}
};
AuthXManager authXManager = new AuthXManager(TokenAuthConfig.builder()
.lowerRefreshBoundMillis(4600).tokenRequestExecTimeoutInMs(3000)
.identityProviderConfig(idpConfig).build());
authXManager = spy(authXManager);
List<Connection> connections = new CopyOnWriteArrayList<>();
doAnswer(invocation -> {
Connection connection = spy((Connection) invocation.getArgument(0));
invocation.getArguments()[0] = connection;
connections.add(connection);
return invocation.callRealMethod();
}).when(authXManager).addConnection(any(Connection.class));
JedisClientConfig config = DefaultJedisClientConfig.builder().authXManager(authXManager)
.build();
ExecutorService executorService = Executors.newFixedThreadPool(2);
CountDownLatch latch = new CountDownLatch(1);
try (JedisCluster jc = new JedisCluster(Collections.singleton(hnp), config)) {
Runnable task = () -> {
while (latch.getCount() > 0) {
assertEquals("OK", jc.set("foo", "bar"));
}
};
Future<?> task1 = executorService.submit(task);
Future<?> task2 = executorService.submit(task);
await().pollInterval(ONE_HUNDRED_MILLISECONDS).atMost(TWO_SECONDS)
.until(connections::size, greaterThanOrEqualTo(2));
connections.forEach(conn -> {
await().pollInterval(ONE_HUNDRED_MILLISECONDS).atMost(TWO_SECONDS)
.untilAsserted(() -> verify(conn, atLeast(2)).reAuthenticate());
});
latch.countDown();
task1.get();
task2.get();
} finally {
latch.countDown();
executorService.shutdown();
}
}
}