// RedisEntraIDIntegrationTests.java
package redis.clients.jedis.authentication;
import static org.awaitility.Awaitility.await;
import static org.awaitility.Durations.TWO_SECONDS;
import static org.awaitility.Durations.FIVE_SECONDS;
import static org.awaitility.Durations.ONE_HUNDRED_MILLISECONDS;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.junit.jupiter.api.Assumptions.assumeTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockConstruction;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.hamcrest.Matchers.in;
import static org.hamcrest.Matchers.is;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import org.awaitility.Awaitility;
import org.awaitility.Durations;
import org.hamcrest.MatcherAssert;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.MethodOrderer;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestMethodOrder;
import org.mockito.MockedConstruction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.azure.identity.DefaultAzureCredential;
import com.azure.identity.DefaultAzureCredentialBuilder;
import redis.clients.authentication.core.IdentityProvider;
import redis.clients.authentication.core.IdentityProviderConfig;
import redis.clients.authentication.core.SimpleToken;
import redis.clients.authentication.core.Token;
import redis.clients.authentication.core.TokenAuthConfig;
import redis.clients.authentication.entraid.AzureTokenAuthConfigBuilder;
import redis.clients.authentication.entraid.EntraIDIdentityProvider;
import redis.clients.authentication.entraid.EntraIDIdentityProviderConfig;
import redis.clients.authentication.entraid.EntraIDTokenAuthConfigBuilder;
import redis.clients.authentication.entraid.ServicePrincipalInfo;
import redis.clients.jedis.Connection;
import redis.clients.jedis.DefaultJedisClientConfig;
import redis.clients.jedis.EndpointConfig;
import redis.clients.jedis.HostAndPort;
import redis.clients.jedis.HostAndPorts;
import redis.clients.jedis.JedisPooled;
import redis.clients.jedis.exceptions.JedisAccessControlException;
import redis.clients.jedis.exceptions.JedisConnectionException;
import redis.clients.jedis.scenario.FaultInjectionClient;
import redis.clients.jedis.scenario.FaultInjectionClient.TriggerActionResponse;
@TestMethodOrder(MethodOrderer.MethodName.class)
public class RedisEntraIDIntegrationTests {
// Class-wide logger.
private static final Logger log = LoggerFactory.getLogger(RedisEntraIDIntegrationTests.class);
// Entra ID settings (client id, secret, authority, scopes) read by the tests; assigned in before().
private static EntraIDTestContext testCtx;
// Configuration of the "standalone-entraid-acl" endpoint; assigned in before().
private static EndpointConfig endpointConfig;
// Host and port taken from endpointConfig; assigned in before().
private static HostAndPort hnp;
// Client used by triggerNetworkFailure() to request fault-injection actions.
private final FaultInjectionClient faultClient = new FaultInjectionClient();
/**
 * Resolves the shared test context and the {@code standalone-entraid-acl} endpoint once for the
 * whole class. When the endpoint lookup is rejected with an {@link IllegalArgumentException}, a
 * warning is logged and the class is skipped through a failed assumption.
 */
@BeforeAll
public static void before() {
  testCtx = EntraIDTestContext.DEFAULT;
  try {
    EndpointConfig resolved = HostAndPorts.getRedisEndpoint("standalone-entraid-acl");
    endpointConfig = resolved;
    hnp = resolved.getHostAndPort();
  } catch (IllegalArgumentException missingEndpoint) {
    log.warn("Skipping test because no Redis endpoint is configured");
    assumeTrue(false, "No Redis endpoint 'standalone-entraid-acl' is configured!");
  }
}
/**
 * Verifies how the Entra ID configuration is wired into Jedis: the mocked
 * {@link EntraIDIdentityProvider} must be constructed with the service principal and scopes taken
 * from the test context, and exactly one token must have been requested by the time the client
 * has been created.
 */
@Test
public void testJedisConfig() {
  AtomicInteger counter = new AtomicInteger(0);
  try (MockedConstruction<EntraIDIdentityProvider> mockedConstructor = mockConstruction(
    EntraIDIdentityProvider.class, (mock, context) -> {
      // The constructor arguments must carry exactly what was passed to the builder below.
      ServicePrincipalInfo info = (ServicePrincipalInfo) context.arguments().get(0);
      assertEquals(testCtx.getClientId(), info.getClientId());
      assertEquals(testCtx.getAuthority(), info.getAuthority());
      assertEquals(testCtx.getClientSecret(), info.getSecret());
      assertEquals(testCtx.getRedisScopes(), context.arguments().get(1));
      assertNotNull(mock);
      // Count every token request and answer with a fixed dummy token.
      doAnswer(invocation -> {
        counter.incrementAndGet();
        return new SimpleToken("default", "token1", System.currentTimeMillis() + 5 * 60 * 1000,
            System.currentTimeMillis(), null);
      }).when(mock).requestToken();
    })) {
    TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder()
        .authority(testCtx.getAuthority()).clientId(testCtx.getClientId())
        .secret(testCtx.getClientSecret()).scopes(testCtx.getRedisScopes()).build();
    DefaultJedisClientConfig jedisConfig = DefaultJedisClientConfig.builder()
        .authXManager(new AuthXManager(tokenAuthConfig)).build();
    // try-with-resources so the client is released even when an assertion below fails; the
    // other tests in this class close their clients the same way.
    try (JedisPooled jedis = new JedisPooled(new HostAndPort("localhost", 6379), jedisConfig)) {
      assertNotNull(jedis);
      assertEquals(1, counter.get());
    }
  }
}
/**
 * T.1.1 - Verifies authentication using Azure AD with a service principal that is identified by
 * its client id and client secret: a value written under a random key must be read back
 * unchanged, after which the key is deleted.
 */
@Test
public void withSecret_azureServicePrincipalIntegrationTest() {
  AuthXManager authXManager = new AuthXManager(EntraIDTokenAuthConfigBuilder.builder()
      .clientId(testCtx.getClientId())
      .secret(testCtx.getClientSecret())
      .authority(testCtx.getAuthority())
      .scopes(testCtx.getRedisScopes())
      .build());
  DefaultJedisClientConfig clientConfig = DefaultJedisClientConfig.builder()
      .authXManager(authXManager)
      .build();

  try (JedisPooled client = new JedisPooled(hnp, clientConfig)) {
    final String uniqueKey = UUID.randomUUID().toString();
    client.set(uniqueKey, "value");
    String stored = client.get(uniqueKey);
    assertEquals("value", stored);
    client.del(uniqueKey);
  }
}
/**
 * T.1.1 - Verifies authentication using Azure AD with service principals by performing a
 * SET/GET/DEL round trip with a random key.
 *
 * <p>NOTE(review): despite its name, this test configures the builder with the client secret,
 * exactly like {@code withSecret_azureServicePrincipalIntegrationTest}; no certificate or private
 * key is involved. TODO confirm whether certificate-based credentials were intended here.
 */
@Test
public void withCertificate_azureServicePrincipalIntegrationTest() {
  TokenAuthConfig authConfig = EntraIDTokenAuthConfigBuilder.builder()
      .clientId(testCtx.getClientId())
      .secret(testCtx.getClientSecret())
      .authority(testCtx.getAuthority())
      .scopes(testCtx.getRedisScopes())
      .build();
  AuthXManager manager = new AuthXManager(authConfig);
  DefaultJedisClientConfig config = DefaultJedisClientConfig.builder().authXManager(manager)
      .build();

  try (JedisPooled redis = new JedisPooled(hnp, config)) {
    String randomKey = UUID.randomUUID().toString();
    redis.set(randomKey, "value");
    assertEquals("value", redis.get(randomKey));
    redis.del(randomKey);
  }
}
/**
 * T.2.2 - Verifies that the Redis client is not blocked or interrupted during token renewal.
 *
 * <p>Five worker threads keep issuing SET/GET/DEL commands with unique keys through one shared
 * {@link JedisPooled} instance while the token manager is configured to renew the token
 * approximately every 10 ms. The test waits until the post-authentication hook has been invoked
 * at least 10 times, then stops the workers and fails if any of them hit an error.
 *
 * <p>Timing assumptions (based on the current implementation; they may change in the future):
 * <ul>
 * <li>the TTL of the token is 2 hours</li>
 * <li>expirationRefreshRatio is 0.000001F</li>
 * <li>the renewal delay is 7 ms each time a token is acquired</li>
 * <li>each AUTH command takes 40 ms in total to complete (considering the cloud test
 * environments)</li>
 * <li>each AUTH command has to wait for an ongoing client operation (GET/SET/DEL) to complete,
 * which takes another 40 ms</li>
 * <li>so each renewal happens in 40 + 40 + 7 = 87 ms, and 10 renewals take 87 * 10 = 870 ms</li>
 * </ul>
 *
 * @throws InterruptedException if the test thread is interrupted while collecting the result
 * @throws ExecutionException if a worker failed while executing Redis commands
 */
@Test
public void renewalDuringOperationsTest() throws InterruptedException, ExecutionException {
  TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder()
      .clientId(testCtx.getClientId()).secret(testCtx.getClientSecret())
      .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes())
      .expirationRefreshRatio(0.000001F).build();
  AuthXManager authXManager = new AuthXManager(tokenAuthConfig);
  // mock(Consumer.class) can only produce a raw Consumer; the hook merely records invocations.
  @SuppressWarnings("unchecked")
  Consumer<Token> hook = mock(Consumer.class);
  authXManager.addPostAuthenticationHook(hook);
  DefaultJedisClientConfig jedisClientConfig = DefaultJedisClientConfig.builder()
      .authXManager(authXManager).build();

  ExecutorService jedisExecutors = Executors.newFixedThreadPool(5);
  ExecutorService runner = Executors.newSingleThreadExecutor();
  AtomicBoolean completed = new AtomicBoolean(false);

  // The runner owns the client: it starts the workers, waits for all of them and reports the
  // first worker failure through its own Future instead of only printing it.
  Future<?> running = runner.submit(() -> {
    try (JedisPooled jedis = new JedisPooled(hnp, jedisClientConfig)) {
      List<Future<?>> futures = new ArrayList<>();
      for (int i = 0; i < 5; i++) {
        futures.add(jedisExecutors.submit(() -> {
          while (!completed.get()) {
            String key = UUID.randomUUID().toString();
            jedis.set(key, "value");
            assertEquals("value", jedis.get(key));
            jedis.del(key);
          }
        }));
      }
      ExecutionException firstFailure = null;
      for (Future<?> task : futures) {
        try {
          task.get();
        } catch (ExecutionException e) {
          log.error("Redis operation failed during token renewal", e);
          // Keep waiting for the remaining workers so the client is closed only after all
          // of them have stopped.
          if (firstFailure == null) {
            firstFailure = e;
          }
        }
      }
      if (firstFailure != null) {
        throw firstFailure;
      }
      return null;
    }
  });

  try {
    await().pollInterval(ONE_HUNDRED_MILLISECONDS).atMost(FIVE_SECONDS).untilAsserted(() -> {
      verify(hook, atLeast(10)).accept(any());
    });
  } finally {
    // Always stop the workers and release the threads, also when the wait above times out;
    // otherwise the workers would keep looping forever.
    completed.set(true);
    runner.shutdown();
    jedisExecutors.shutdown();
  }

  // Surface any failure that happened inside the workers while tokens were being renewed.
  await().atMost(FIVE_SECONDS).until(() -> running.isDone());
  running.get();
}
/**
 * T.3.2 - Verifies that all existing connections can be re-authenticated when a new token is
 * received.
 *
 * <p>Every connection handed to {@code AuthXManager.addConnection} is replaced by a Mockito spy
 * and remembered. Five workers then run SET/GET/DEL round trips until two seconds have passed
 * since the start timestamp, and afterwards each remembered connection must have seen at least
 * one {@code reAuthenticate()} call.
 *
 * @throws InterruptedException if the test thread is interrupted while waiting for the workers
 * @throws ExecutionException if a worker failed while executing Redis commands
 */
@Test
public void allConnectionsReauthTest() throws InterruptedException, ExecutionException {
  TokenAuthConfig tokenAuthConfig = EntraIDTokenAuthConfigBuilder.builder()
      .clientId(testCtx.getClientId()).secret(testCtx.getClientSecret())
      .authority(testCtx.getAuthority()).scopes(testCtx.getRedisScopes())
      .expirationRefreshRatio(0.000001F).build();
  AuthXManager authXManager = spy(new AuthXManager(tokenAuthConfig));

  // addConnection is presumably invoked from the worker threads that open pooled connections
  // (the call site is not in this file), so the collecting list has to be thread-safe.
  List<Connection> connections = new CopyOnWriteArrayList<>();
  doAnswer(invocation -> {
    // Swap the real connection for a spy before delegating, so re-authentication can be verified.
    Connection connection = spy((Connection) invocation.getArgument(0));
    invocation.getArguments()[0] = connection;
    connections.add(connection);
    return invocation.callRealMethod();
  }).when(authXManager).addConnection(any(Connection.class));

  DefaultJedisClientConfig jedisClientConfig = DefaultJedisClientConfig.builder()
      .authXManager(authXManager).build();
  long startTime = System.currentTimeMillis();
  List<Future<?>> futures = new ArrayList<>();
  ExecutorService executor = Executors.newFixedThreadPool(5);
  try (JedisPooled jedis = new JedisPooled(hnp, jedisClientConfig)) {
    for (int i = 0; i < 5; i++) {
      futures.add(executor.submit(() -> {
        while (System.currentTimeMillis() - startTime < 2000) {
          String key = UUID.randomUUID().toString();
          jedis.set(key, "value");
          assertEquals("value", jedis.get(key));
          jedis.del(key);
        }
      }));
    }
    // Propagates the first worker failure, if any.
    for (Future<?> task : futures) {
      task.get();
    }
    connections.forEach(conn -> verify(conn, atLeast(1)).reAuthenticate());
  } finally {
    // Release the worker threads even when a worker or a verification failed.
    executor.shutdown();
  }
}
/**
 * T.3.3 - Verifies the behavior when a connection is authenticated with an expired token.
 *
 * <p>A mocked identity provider hands out whatever token is currently stored in an
 * {@link AtomicReference}, filling it from the real Entra ID provider on first use. After 50
 * successful SET/GET/DEL round trips the stored token is replaced by a {@link SimpleToken} with
 * the value {@code "token1"} and an expiry one millisecond in the past; a second batch of round
 * trips is then expected to fail with a {@link JedisAccessControlException} whose message starts
 * with the WRONGPASS error.
 */
@Test
public void connectionAuthWithExpiredTokenTest() {
  ServicePrincipalInfo principal = new ServicePrincipalInfo(testCtx.getClientId(),
      testCtx.getClientSecret(), testCtx.getAuthority());
  IdentityProvider realProvider = new EntraIDIdentityProviderConfig(principal,
      testCtx.getRedisScopes(), 1000).getProvider();

  // Token served by the mocked provider; the test swaps it for an expired one later on.
  AtomicReference<Token> currentToken = new AtomicReference<>();
  IdentityProvider mockIdentityProvider = mock(IdentityProvider.class);
  when(mockIdentityProvider.requestToken()).thenAnswer(invocation -> {
    if (currentToken.get() == null) {
      currentToken.set(realProvider.requestToken());
    }
    return currentToken.get();
  });

  IdentityProviderConfig idpConfig = mock(IdentityProviderConfig.class);
  when(idpConfig.getProvider()).thenReturn(mockIdentityProvider);

  TokenAuthConfig tokenAuthConfig = TokenAuthConfig.builder()
      .tokenRequestExecTimeoutInMs(4000)
      .identityProviderConfig(idpConfig)
      .expirationRefreshRatio(0.000001F)
      .build();
  DefaultJedisClientConfig jedisClientConfig = DefaultJedisClientConfig.builder()
      .authXManager(new AuthXManager(tokenAuthConfig)).build();

  try (JedisPooled jedis = new JedisPooled(hnp, jedisClientConfig)) {
    // Fifty SET/GET/DEL round trips, each with a fresh random key.
    Runnable fiftyRoundTrips = () -> {
      for (int attempt = 0; attempt < 50; attempt++) {
        String key = UUID.randomUUID().toString();
        jedis.set(key, "value");
        assertEquals("value", jedis.get(key));
        jedis.del(key);
      }
    };

    // With the genuine token every round trip succeeds.
    fiftyRoundTrips.run();

    // From now on the provider serves a token that is already expired.
    currentToken.set(new SimpleToken(realProvider.requestToken().getUser(), "token1",
        System.currentTimeMillis() - 1, System.currentTimeMillis(), null));

    JedisAccessControlException rejected = assertThrows(JedisAccessControlException.class,
      fiftyRoundTrips::run);

    String expectedError = "WRONGPASS invalid username-password pair";
    boolean hasExpectedPrefix = rejected.getMessage().startsWith(expectedError);
    assertTrue(hasExpectedPrefix,
      "Expected '" + rejected.getMessage() + "' to start with '" + expectedError + "'");
  }
}
/**
 * T.3.4 - Verifies handling of reconnection and re-authentication after a network partition
 * (using the cached token).
 *
 * <p>After five successful round trips a network failure is triggered. While that action is in
 * progress, round trips are expected to fail with a {@link JedisConnectionException} carrying one
 * of two known messages. Finally the test polls for up to two seconds until a round trip
 * succeeds again.
 */
@Test
public void networkPartitionEvictionTest() {
  TokenAuthConfig authConfig = EntraIDTokenAuthConfigBuilder.builder()
      .clientId(testCtx.getClientId())
      .secret(testCtx.getClientSecret())
      .authority(testCtx.getAuthority())
      .scopes(testCtx.getRedisScopes())
      .expirationRefreshRatio(0.5F)
      .build();
  DefaultJedisClientConfig jedisClientConfig = DefaultJedisClientConfig.builder()
      .authXManager(new AuthXManager(authConfig)).build();

  try (JedisPooled jedis = new JedisPooled(hnp, jedisClientConfig)) {
    // One SET/GET/DEL round trip with a fresh random key.
    Runnable roundTrip = () -> {
      String key = UUID.randomUUID().toString();
      jedis.set(key, "value");
      assertEquals("value", jedis.get(key));
      jedis.del(key);
    };

    // Baseline: commands work before the partition.
    for (int warmUp = 0; warmUp < 5; warmUp++) {
      roundTrip.run();
    }

    TriggerActionResponse actionResponse = triggerNetworkFailure();

    // Keep the client busy while the fault is active; a connection error must show up.
    JedisConnectionException connectionFailure = assertThrows(JedisConnectionException.class,
      () -> {
        while (!actionResponse.isCompleted(ONE_HUNDRED_MILLISECONDS, TWO_SECONDS, FIVE_SECONDS)) {
          for (int attempt = 0; attempt < 50; attempt++) {
            roundTrip.run();
          }
        }
      });

    String[] acceptedMessages = new String[] { "Unexpected end of stream.",
        "java.net.SocketException: Connection reset" };
    MatcherAssert.assertThat(connectionFailure.getMessage(), is(in(acceptedMessages)));

    // Poll until commands succeed again; exceptions thrown by an attempt mean "not yet".
    await().pollDelay(ONE_HUNDRED_MILLISECONDS).atMost(TWO_SECONDS).until(() -> {
      try {
        roundTrip.run();
        return true;
      } catch (Exception e) {
        log.debug("attempt to reconnect after network failure, connection has not been re-established yet:"
            + e.getMessage());
        return false;
      }
    });
  }
}
/**
 * Asks the fault-injection client to run the {@code network_failure} action against the database
 * of the configured endpoint (passed as {@code bdb_id}) and logs the resulting action id.
 *
 * <p>An {@link IOException} from the fault-injection client fails the calling test; the
 * exception is attached as the cause so its stack trace is not lost.
 *
 * @return the response describing the triggered action
 */
private TriggerActionResponse triggerNetworkFailure() {
  HashMap<String, Object> params = new HashMap<>();
  params.put("bdb_id", endpointConfig.getBdbId());
  String action = "network_failure";
  try {
    log.info("Triggering {}", action);
    TriggerActionResponse actionResponse = faultClient.triggerAction(action, params);
    log.info("Action id: {}", actionResponse.getActionId());
    return actionResponse;
  } catch (IOException e) {
    // fail() always throws; returning its (generic) result just satisfies the compiler.
    return fail("Fault Injection Server error:" + e.getMessage(), e);
  }
}
/**
 * Verifies authentication with a {@link DefaultAzureCredential}: the token configuration is built
 * from the default credential with a token request timeout of 2000 ms, and a SET/GET/DEL round
 * trip with a random key must succeed.
 */
@Test
public void withDefaultCredentials_azureCredentialsIntegrationTest() {
  DefaultAzureCredential defaultCredential = new DefaultAzureCredentialBuilder().build();
  AuthXManager authXManager = new AuthXManager(AzureTokenAuthConfigBuilder.builder()
      .defaultAzureCredential(defaultCredential)
      .tokenRequestExecTimeoutInMs(2000)
      .build());
  DefaultJedisClientConfig clientConfig = DefaultJedisClientConfig.builder()
      .authXManager(authXManager)
      .build();

  try (JedisPooled client = new JedisPooled(hnp, clientConfig)) {
    final String probeKey = UUID.randomUUID().toString();
    client.set(probeKey, "value");
    String readBack = client.get(probeKey);
    assertEquals("value", readBack);
    client.del(probeKey);
  }
}
}