// ExtendedDigestAuthenticator.java

package org.asynchttpclient.test;

import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

/**
 * Pure Java Digest authenticator for testing MD5, SHA-256, and SHA-512-256,
 * including their {@code -sess} variants.
 *
 * <p>An instance only carries the algorithm that is advertised in the challenge
 * header; parsing and validation are stateless static helpers.
 */
public class ExtendedDigestAuthenticator {
    private static final String MD5 = "MD5";
    private static final String SHA_256 = "SHA-256";
    private static final String SHA_512_256 = "SHA-512-256";
    private static final String SESSION_SUFFIX = "-sess";
    /** JCA name of the digest that HTTP Digest calls {@code SHA-512-256}. */
    private static final String SHA_512_256_JCA_NAME = "SHA-512/256";

    private static final char[] HEX_DIGITS = "0123456789abcdef".toCharArray();
    /** Shared source of nonce bytes; {@link SecureRandom} is safe to share between threads. */
    private static final SecureRandom NONCE_SOURCE = new SecureRandom();

    private final String advertisedAlgorithm;

    /** Creates an authenticator that advertises the default algorithm (MD5). */
    public ExtendedDigestAuthenticator() {
        this(null);
    }

    /**
     * Creates an authenticator that advertises the given algorithm.
     *
     * @param advertisedAlgorithm algorithm token, may be {@code null} (treated as MD5)
     */
    public ExtendedDigestAuthenticator(String advertisedAlgorithm) {
        this.advertisedAlgorithm = advertisedAlgorithm;
    }

    /**
     * Returns the canonical name of the advertised algorithm.
     *
     * @return {@code "MD5"}, {@code "SHA-256"}, {@code "SHA-512-256"}, or {@code null}
     *         when the configured algorithm is not supported
     */
    public String getAdvertisedAlgorithm() {
        return findAlgorithm(advertisedAlgorithm);
    }

    /**
     * Maps an algorithm token (case-insensitive, with or without {@code -sess}) to its
     * canonical base name.
     *
     * @param algorithm algorithm token, may be {@code null}
     * @return the canonical base name, {@code "MD5"} for {@code null}, or {@code null}
     *         when the token is not supported
     */
    public static String findAlgorithm(String algorithm) {
        if (algorithm == null || matchesFamily(algorithm, MD5)) {
            return MD5;
        }
        if (matchesFamily(algorithm, SHA_256)) {
            return SHA_256;
        }
        if (matchesFamily(algorithm, SHA_512_256)) {
            return SHA_512_256;
        }
        return null;
    }

    /**
     * Creates a {@link MessageDigest} for the given Digest algorithm token.
     *
     * @param algorithm algorithm token, may be {@code null} (treated as MD5)
     * @return a fresh digest instance
     * @throws NoSuchAlgorithmException if the token is unsupported or the JRE has no
     *                                  provider for the underlying digest
     */
    public static MessageDigest getMessageDigest(String algorithm) throws NoSuchAlgorithmException {
        String canonical = findAlgorithm(algorithm);
        if (canonical == null) {
            throw new NoSuchAlgorithmException("Unsupported digest algorithm: " + algorithm);
        }
        // The HTTP token "SHA-512-256" differs from the JCA name "SHA-512/256".
        String jcaName = SHA_512_256.equals(canonical) ? SHA_512_256_JCA_NAME : canonical;
        return MessageDigest.getInstance(jcaName);
    }

    /**
     * Generates a fresh nonce: 16 random bytes, Base64 encoded.
     *
     * @return a 24-character Base64 string
     */
    public static String newNonce() {
        byte[] nonceBytes = new byte[16];
        // Nonces must be unpredictable, so use a CSPRNG rather than java.util.Random.
        NONCE_SOURCE.nextBytes(nonceBytes);
        return Base64.getEncoder().encodeToString(nonceBytes);
    }

    /**
     * Builds the value of a {@code WWW-Authenticate} Digest challenge.
     *
     * @param realm realm to advertise (inserted verbatim between quotes)
     * @param nonce nonce to advertise (inserted verbatim between quotes)
     * @param stale whether to append {@code stale=true}
     * @return the challenge string, always offering {@code qop="auth"}
     */
    public String createAuthenticateHeader(String realm, String nonce, boolean stale) {
        StringBuilder header = new StringBuilder(128);
        header.append("Digest realm=\"").append(realm).append('"');
        header.append(", nonce=\"").append(nonce).append('"');
        String algorithm = getAdvertisedAlgorithm();
        if (algorithm != null) {
            header.append(", algorithm=").append(algorithm);
        }
        header.append(", qop=\"auth\"");
        if (stale) {
            header.append(", stale=true");
        }
        return header.toString();
    }

    /**
     * Validate a Digest response from the client.
     *
     * <p>For {@code -sess} algorithms HA1 is re-hashed together with the nonce and
     * cnonce before the response is computed. An {@code algorithm} value that is not
     * recognised is leniently treated as MD5.
     *
     * @param method HTTP method
     * @param credentials The Authorization header value (without "Digest ")
     * @param password The user's password
     * @return true if valid, false otherwise (including when {@code credentials} is
     *         {@code null}, has no {@code response}, or the digest is unavailable)
     */
    public static boolean validateDigest(String method, String credentials, String password) {
        if (credentials == null) {
            return false;
        }
        Map<String, String> params = parseCredentials(credentials);
        String response = params.get("response");
        if (response == null) {
            return false;
        }
        String username = params.get("username");
        String realm = params.get("realm");
        String nonce = params.get("nonce");
        String uri = params.get("uri");
        String qop = params.get("qop");
        String nc = params.get("nc");
        String cnonce = params.get("cnonce");

        String requestedAlgorithm = params.get("algorithm");
        String algorithm = findAlgorithm(requestedAlgorithm);
        boolean sessionVariant = isSessionVariant(requestedAlgorithm);
        if (algorithm == null) {
            // Unknown token: keep the historical lenient fallback.
            algorithm = MD5;
        }

        try {
            // MessageDigest.digest() resets the instance, so one instance serves every hash.
            MessageDigest md = getMessageDigest(algorithm);

            String ha1Hex = hashHex(md, username + ':' + realm + ':' + password);
            if (sessionVariant) {
                // "-sess": A1 = H(user:realm:password) ":" nonce ":" cnonce
                ha1Hex = hashHex(md, ha1Hex + ':' + nonce + ':' + cnonce);
            }
            String ha2Hex = hashHex(md, method + ':' + uri);

            String kd;
            if (qop != null && !qop.isEmpty()) {
                kd = ha1Hex + ':' + nonce + ':' + nc + ':' + cnonce + ':' + qop + ':' + ha2Hex;
            } else {
                kd = ha1Hex + ':' + nonce + ':' + ha2Hex;
            }

            return hashHex(md, kd).equalsIgnoreCase(response);
        } catch (NoSuchAlgorithmException e) {
            // The JRE lacks a provider for the digest; the response cannot be verified.
            return false;
        }
    }

    /**
     * Parses comma-separated {@code key=value} Digest parameters.
     *
     * <p>Commas inside double-quoted values do not split a parameter, and a
     * backslash-escaped character inside quotes never terminates the quoted value.
     * Surrounding quotes are stripped from values; escape sequences are otherwise
     * left as written. Parts without a key before {@code '='} are ignored.
     *
     * @param credentials parameter list, may be {@code null}
     * @return a mutable map of parameter names to values (empty for {@code null} input)
     */
    public static Map<String, String> parseCredentials(String credentials) {
        Map<String, String> map = new HashMap<>();
        if (credentials == null) {
            return map;
        }
        int partStart = 0;
        boolean inQuotes = false;
        int length = credentials.length();
        for (int i = 0; i < length; i++) {
            char c = credentials.charAt(i);
            if (inQuotes && c == '\\') {
                // Skip the escaped character so an escaped quote does not close the value.
                i++;
            } else if (c == '"') {
                inQuotes = !inQuotes;
            } else if (c == ',' && !inQuotes) {
                putParameter(map, credentials.substring(partStart, i));
                partStart = i + 1;
            }
        }
        putParameter(map, credentials.substring(partStart));
        return map;
    }

    /** Returns whether {@code algorithm} names {@code base} or its {@code -sess} variant. */
    private static boolean matchesFamily(String algorithm, String base) {
        return base.equalsIgnoreCase(algorithm) || (base + SESSION_SUFFIX).equalsIgnoreCase(algorithm);
    }

    /** Returns whether {@code algorithm} is a supported token ending in {@code -sess}. */
    private static boolean isSessionVariant(String algorithm) {
        if (algorithm == null || findAlgorithm(algorithm) == null) {
            return false;
        }
        int suffixLength = SESSION_SUFFIX.length();
        return algorithm.length() > suffixLength
                && algorithm.regionMatches(true, algorithm.length() - suffixLength, SESSION_SUFFIX, 0, suffixLength);
    }

    /** Stores one {@code key=value} part in {@code map}, trimming and unquoting the value. */
    private static void putParameter(Map<String, String> map, String part) {
        int idx = part.indexOf('=');
        if (idx <= 0) {
            return;
        }
        String key = part.substring(0, idx).trim();
        String value = part.substring(idx + 1).trim();
        // Length check: a lone '"' both starts and ends with a quote but has nothing to strip.
        if (value.length() >= 2 && value.startsWith("\"") && value.endsWith("\"")) {
            value = value.substring(1, value.length() - 1);
        }
        map.put(key, value);
    }

    /** Hashes the ISO-8859-1 bytes of {@code input} and returns the lowercase hex digest. */
    private static String hashHex(MessageDigest md, String input) {
        return toHexString(md.digest(input.getBytes(StandardCharsets.ISO_8859_1)));
    }

    /** Encodes {@code bytes} as lowercase hexadecimal, two characters per byte. */
    private static String toHexString(byte[] bytes) {
        char[] hex = new char[bytes.length * 2];
        for (int i = 0; i < bytes.length; i++) {
            int unsigned = bytes[i] & 0xff;
            hex[2 * i] = HEX_DIGITS[unsigned >>> 4];
            hex[2 * i + 1] = HEX_DIGITS[unsigned & 0x0f];
        }
        return new String(hex);
    }
}