AuthXManager.java
package redis.clients.jedis.authentication;
import java.lang.ref.WeakReference;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Supplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import redis.clients.authentication.core.Token;
import redis.clients.authentication.core.TokenAuthConfig;
import redis.clients.authentication.core.TokenListener;
import redis.clients.authentication.core.TokenManager;
import redis.clients.jedis.Connection;
import redis.clients.jedis.RedisCredentials;
public final class AuthXManager implements Supplier<RedisCredentials> {
private static final Logger log = LoggerFactory.getLogger(AuthXManager.class);
private TokenManager tokenManager;
private List<WeakReference<Connection>> connections = Collections
.synchronizedList(new ArrayList<>());
private Token currentToken;
private AuthXEventListener listener = AuthXEventListener.NOOP_LISTENER;
private List<Consumer<Token>> postAuthenticateHooks = new ArrayList<>();
private AtomicReference<CompletableFuture<Void>> uniqueStarterTask = new AtomicReference<>();
protected AuthXManager(TokenManager tokenManager) {
this.tokenManager = tokenManager;
}
public AuthXManager(TokenAuthConfig tokenAuthConfig) {
this(new TokenManager(tokenAuthConfig.getIdentityProviderConfig().getProvider(),
tokenAuthConfig.getTokenManagerConfig()));
}
public void start() {
Future<Void> safeStarter = safeStart(this::tokenManagerStart);
try {
safeStarter.get();
} catch (InterruptedException | ExecutionException e) {
log.error("AuthXManager failed to start!", e);
throw new JedisAuthenticationException("AuthXManager failed to start!",
(e instanceof ExecutionException) ? e.getCause() : e);
}
}
private Future<Void> safeStart(Runnable starter) {
if (uniqueStarterTask.compareAndSet(null, new CompletableFuture<Void>())) {
try {
starter.run();
uniqueStarterTask.get().complete(null);
} catch (Exception e) {
uniqueStarterTask.get().completeExceptionally(e);
}
}
return uniqueStarterTask.get();
}
private void tokenManagerStart() {
tokenManager.start(new TokenListener() {
@Override
public void onTokenRenewed(Token token) {
currentToken = token;
authenticateConnections(token);
}
@Override
public void onError(Exception reason) {
listener.onIdentityProviderError(reason);
}
}, true);
}
public void authenticateConnections(Token token) {
RedisCredentials credentialsFromToken = new TokenCredentials(token);
for (WeakReference<Connection> connectionRef : connections) {
Connection connection = connectionRef.get();
if (connection != null) {
connection.setCredentials(credentialsFromToken);
} else {
connections.remove(connectionRef);
}
}
postAuthenticateHooks.forEach(hook -> hook.accept(token));
}
public Connection addConnection(Connection connection) {
connections.add(new WeakReference<>(connection));
return connection;
}
public void stop() {
tokenManager.stop();
}
public void setListener(AuthXEventListener listener) {
if (listener != null) {
this.listener = listener;
}
}
public void addPostAuthenticationHook(Consumer<Token> postAuthenticateHook) {
postAuthenticateHooks.add(postAuthenticateHook);
}
public void removePostAuthenticationHook(Consumer<Token> postAuthenticateHook) {
postAuthenticateHooks.remove(postAuthenticateHook);
}
public AuthXEventListener getListener() {
return listener;
}
@Override
public RedisCredentials get() {
return new TokenCredentials(this.currentToken);
}
}