TokenBasedAuthenticationIntegrationTests.java

package redis.clients.jedis.authentication;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assumptions.assumeTrue;
import static org.mockito.Mockito.when;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.awaitility.Awaitility.await;
import static org.awaitility.Durations.ONE_HUNDRED_MILLISECONDS;
import static org.awaitility.Durations.ONE_SECOND;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.MatcherAssert.assertThat;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import redis.clients.authentication.core.IdentityProvider;
import redis.clients.authentication.core.IdentityProviderConfig;
import redis.clients.authentication.core.SimpleToken;
import redis.clients.authentication.core.TokenAuthConfig;
import redis.clients.jedis.CommandArguments;
import redis.clients.jedis.Connection;
/*  */
import redis.clients.jedis.DefaultJedisClientConfig;
import redis.clients.jedis.EndpointConfig;
import redis.clients.jedis.HostAndPorts;
import redis.clients.jedis.JedisClientConfig;
import redis.clients.jedis.JedisPooled;
import redis.clients.jedis.JedisPubSub;
import redis.clients.jedis.RedisProtocol;
import redis.clients.jedis.Protocol.Command;
import redis.clients.jedis.exceptions.JedisException;

public class TokenBasedAuthenticationIntegrationTests {
  private static final Logger log = LoggerFactory
      .getLogger(TokenBasedAuthenticationIntegrationTests.class);

  private static EndpointConfig endpointConfig;

  @BeforeAll
  public static void before() {
    try {
      endpointConfig = HostAndPorts.getRedisEndpoint("standalone0");
    } catch (IllegalArgumentException e) {
      log.warn("Skipping test because no Redis endpoint is configured");
      assumeTrue(false);
    }
  }

  @Test
  public void testJedisPooledForInitialAuth() {
    String user = "default";
    String password = endpointConfig.getPassword();

    IdentityProvider idProvider = mock(IdentityProvider.class);
    when(idProvider.requestToken()).thenReturn(new SimpleToken(user, password,
        System.currentTimeMillis() + 100000, System.currentTimeMillis(), null));

    IdentityProviderConfig idProviderConfig = mock(IdentityProviderConfig.class);
    when(idProviderConfig.getProvider()).thenReturn(idProvider);

    TokenAuthConfig tokenAuthConfig = TokenAuthConfig.builder()
        .identityProviderConfig(idProviderConfig).expirationRefreshRatio(0.8F)
        .lowerRefreshBoundMillis(10000).tokenRequestExecTimeoutInMs(1000).build();

    JedisClientConfig clientConfig = DefaultJedisClientConfig.builder()
        .authXManager(new AuthXManager(tokenAuthConfig)).build();

    try (JedisPooled jedis = new JedisPooled(endpointConfig.getHostAndPort(), clientConfig)) {
      jedis.get("key1");
    }
  }

  @Test
  public void testJedisPooledReauth() {
    String user = "default";
    String password = endpointConfig.getPassword();

    IdentityProvider idProvider = mock(IdentityProvider.class);
    when(idProvider.requestToken()).thenAnswer(invocation -> new SimpleToken(user, password,
        System.currentTimeMillis() + 5000, System.currentTimeMillis(), null));

    IdentityProviderConfig idProviderConfig = mock(IdentityProviderConfig.class);
    when(idProviderConfig.getProvider()).thenReturn(idProvider);

    TokenAuthConfig tokenAuthConfig = TokenAuthConfig.builder()
        .identityProviderConfig(idProviderConfig).expirationRefreshRatio(0.8F)
        .lowerRefreshBoundMillis(4800).tokenRequestExecTimeoutInMs(1000).build();

    AuthXManager authXManager = new AuthXManager(tokenAuthConfig);
    authXManager = spy(authXManager);
    List<Connection> connections = new ArrayList<>();
    doAnswer(invocation -> {
      Connection connection = spy((Connection) invocation.getArgument(0));
      invocation.getArguments()[0] = connection;
      connections.add(connection);
      Object result = invocation.callRealMethod();
      return result;
    }).when(authXManager).addConnection(any(Connection.class));

    JedisClientConfig clientConfig = DefaultJedisClientConfig.builder().authXManager(authXManager)
        .build();

    try (JedisPooled jedis = new JedisPooled(endpointConfig.getHostAndPort(), clientConfig)) {
      AtomicBoolean stop = new AtomicBoolean(false);
      ExecutorService executor = Executors.newSingleThreadExecutor();
      executor.submit(() -> {
        while (!stop.get()) {
          jedis.get("key1");
        }
      });

      for (Connection connection : connections) {
        await().pollDelay(ONE_HUNDRED_MILLISECONDS).atMost(ONE_SECOND).untilAsserted(() -> {
          verify(connection, atLeast(3)).reAuthenticate();
        });
      }
      stop.set(true);
      executor.shutdown();
    }
  }

  @Test
  public void testPubSubForInitialAuth() throws InterruptedException {
    String user = "default";
    String password = endpointConfig.getPassword();

    IdentityProvider idProvider = mock(IdentityProvider.class);
    when(idProvider.requestToken()).thenReturn(new SimpleToken(user, password,
        System.currentTimeMillis() + 100000, System.currentTimeMillis(), null));

    IdentityProviderConfig idProviderConfig = mock(IdentityProviderConfig.class);
    when(idProviderConfig.getProvider()).thenReturn(idProvider);

    TokenAuthConfig tokenAuthConfig = TokenAuthConfig.builder()
        .identityProviderConfig(idProviderConfig).expirationRefreshRatio(0.8F)
        .lowerRefreshBoundMillis(10000).tokenRequestExecTimeoutInMs(1000).build();

    JedisClientConfig clientConfig = DefaultJedisClientConfig.builder()
        .authXManager(new AuthXManager(tokenAuthConfig)).protocol(RedisProtocol.RESP3).build();

    JedisPubSub pubSub = new JedisPubSub() {
      public void onSubscribe(String channel, int subscribedChannels) {
        this.unsubscribe();
      }
    };

    try (JedisPooled jedis = new JedisPooled(endpointConfig.getHostAndPort(), clientConfig)) {
      jedis.subscribe(pubSub, "channel1");
    }
  }

  @Test
  public void testJedisPubSubReauth() {
    String user = "default";
    String password = endpointConfig.getPassword();

    IdentityProvider idProvider = mock(IdentityProvider.class);
    when(idProvider.requestToken()).thenAnswer(invocation -> new SimpleToken(user, password,
        System.currentTimeMillis() + 5000, System.currentTimeMillis(), null));

    IdentityProviderConfig idProviderConfig = mock(IdentityProviderConfig.class);
    when(idProviderConfig.getProvider()).thenReturn(idProvider);

    TokenAuthConfig tokenAuthConfig = TokenAuthConfig.builder()
        .identityProviderConfig(idProviderConfig).expirationRefreshRatio(0.8F)
        .lowerRefreshBoundMillis(4800).tokenRequestExecTimeoutInMs(1000).build();

    AuthXManager authXManager = new AuthXManager(tokenAuthConfig);
    authXManager = spy(authXManager);
    List<Connection> connections = new ArrayList<>();
    doAnswer(invocation -> {
      Connection connection = spy((Connection) invocation.getArgument(0));
      invocation.getArguments()[0] = connection;
      connections.add(connection);
      Object result = invocation.callRealMethod();
      return result;
    }).when(authXManager).addConnection(any(Connection.class));

    JedisClientConfig clientConfig = DefaultJedisClientConfig.builder().authXManager(authXManager)
        .protocol(RedisProtocol.RESP3).build();

    JedisPubSub pubSub = new JedisPubSub() {
    };
    try (JedisPooled jedis = new JedisPooled(endpointConfig.getHostAndPort(), clientConfig)) {
      ExecutorService executor = Executors.newSingleThreadExecutor();
      executor.submit(() -> {
        jedis.subscribe(pubSub, "channel1");
      });

      await().pollDelay(ONE_HUNDRED_MILLISECONDS).atMost(ONE_SECOND)
          .until(pubSub::getSubscribedChannels, greaterThan(0));

      assertEquals(1, connections.size());
      for (Connection connection : connections) {
        await().pollDelay(ONE_HUNDRED_MILLISECONDS).atMost(ONE_SECOND).untilAsserted(() -> {
          ArgumentCaptor<CommandArguments> captor = ArgumentCaptor.forClass(CommandArguments.class);

          verify(connection, atLeast(3)).sendCommand(captor.capture());
          assertThat(captor.getAllValues().stream()
              .filter((item) -> item.getCommand() == Command.AUTH).count(),
            greaterThan(3L));

        });
      }
      pubSub.unsubscribe();
      executor.shutdown();
    }
  }

  @Test
  public void testJedisPubSubWithResp2() {
    String user = "default";
    String password = endpointConfig.getPassword();

    IdentityProvider idProvider = mock(IdentityProvider.class);
    when(idProvider.requestToken()).thenReturn(new SimpleToken(user, password,
        System.currentTimeMillis() + 100000, System.currentTimeMillis(), null));

    IdentityProviderConfig idProviderConfig = mock(IdentityProviderConfig.class);
    when(idProviderConfig.getProvider()).thenReturn(idProvider);

    TokenAuthConfig tokenAuthConfig = TokenAuthConfig.builder()
        .identityProviderConfig(idProviderConfig).expirationRefreshRatio(0.8F)
        .lowerRefreshBoundMillis(10000).tokenRequestExecTimeoutInMs(1000).build();

    JedisClientConfig clientConfig = DefaultJedisClientConfig.builder()
        .authXManager(new AuthXManager(tokenAuthConfig)).build();

    try (JedisPooled jedis = new JedisPooled(endpointConfig.getHostAndPort(), clientConfig)) {
      JedisPubSub pubSub = new JedisPubSub() {};
      JedisException e = assertThrows(JedisException.class,
          () -> jedis.subscribe(pubSub, "channel1"));
      assertEquals(
          "Blocking pub/sub operations are not supported on token-based authentication enabled connections with RESP2 protocol!",
          e.getMessage());
    }
  }
}