EntraIDTestContext.java

package redis.clients.jedis.authentication;

import java.io.ByteArrayInputStream;
import java.security.KeyFactory;
import java.security.PrivateKey;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashSet;
import java.util.Set;

public class EntraIDTestContext {
  private static final String AZURE_CLIENT_ID = "AZURE_CLIENT_ID";
  private static final String AZURE_AUTHORITY = "AZURE_AUTHORITY";
  private static final String AZURE_CLIENT_SECRET = "AZURE_CLIENT_SECRET";
  private static final String AZURE_PRIVATE_KEY = "AZURE_PRIVATE_KEY";
  private static final String AZURE_CERT = "AZURE_CERT";
  private static final String AZURE_REDIS_SCOPES = "AZURE_REDIS_SCOPES";
  private static final String AZURE_USER_ASSIGNED_MANAGED_ID = "AZURE_USER_ASSIGNED_MANAGED_ID";

  private String clientId;
  private String authority;
  private String clientSecret;
  private PrivateKey privateKey;
  private X509Certificate cert;
  private Set<String> redisScopes;
  private String userAssignedManagedIdentity;

  public static final EntraIDTestContext DEFAULT = new EntraIDTestContext();

  private EntraIDTestContext() {
    clientId = System.getenv(AZURE_CLIENT_ID);
    authority = System.getenv(AZURE_AUTHORITY);
    clientSecret = System.getenv(AZURE_CLIENT_SECRET);
    userAssignedManagedIdentity = System.getenv(AZURE_USER_ASSIGNED_MANAGED_ID);
  }

  public EntraIDTestContext(String clientId, String authority, String clientSecret,
      PrivateKey privateKey, X509Certificate cert, Set<String> redisScopes,
      String userAssignedManagedIdentity) {
    this.clientId = clientId;
    this.authority = authority;
    this.clientSecret = clientSecret;
    this.privateKey = privateKey;
    this.cert = cert;
    this.redisScopes = redisScopes;
    this.userAssignedManagedIdentity = userAssignedManagedIdentity;
  }

  public String getClientId() {
    return clientId;
  }

  public String getAuthority() {
    return authority;
  }

  public String getClientSecret() {
    return clientSecret;
  }

  public PrivateKey getPrivateKey() {
    if (privateKey == null) {
      this.privateKey = getPrivateKey(System.getenv(AZURE_PRIVATE_KEY));
    }
    return privateKey;
  }

  public X509Certificate getCert() {
    if (cert == null) {
      this.cert = getCert(System.getenv(AZURE_CERT));
    }
    return cert;
  }

  public Set<String> getRedisScopes() {
    if (redisScopes == null) {
      String redisScopesEnv = System.getenv(AZURE_REDIS_SCOPES);
      this.redisScopes = new HashSet<>(Arrays.asList(redisScopesEnv.split(";")));
    }
    return redisScopes;
  }

  public String getUserAssignedManagedIdentity() {
    return userAssignedManagedIdentity;
  }

  private PrivateKey getPrivateKey(String privateKey) {
    try {
      // Decode the base64 encoded key into a byte array
      byte[] decodedKey = Base64.getDecoder().decode(privateKey);

      // Generate the private key from the decoded byte array using PKCS8EncodedKeySpec
      PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(decodedKey);
      KeyFactory keyFactory = KeyFactory.getInstance("RSA"); // Use the correct algorithm (e.g., "RSA", "EC", "DSA")
      PrivateKey key = keyFactory.generatePrivate(keySpec);
      return key;
    } catch (Exception e) {
      e.printStackTrace();
      throw new RuntimeException(e);
    }
  }

  private X509Certificate getCert(String cert) {
    try {
      // Convert the Base64 encoded string into a byte array
      byte[] encoded = java.util.Base64.getDecoder().decode(cert);

      // Create a CertificateFactory for X.509 certificates
      CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509");

      // Generate the certificate from the byte array
      X509Certificate certificate = (X509Certificate) certificateFactory
          .generateCertificate(new ByteArrayInputStream(encoded));
      return certificate;
    } catch (Exception e) {
      e.printStackTrace();
      throw new RuntimeException(e);
    }
  }
}