ScramScheme.java

/*
 * ====================================================================
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 * ====================================================================
 *
 * This software consists of voluntary contributions made by many
 * individuals on behalf of the Apache Software Foundation.  For more
 * information on the Apache Software Foundation, please see
 * <http://www.apache.org/>.
 *
 */
package org.apache.hc.client5.http.impl.auth;

import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.security.Principal;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import javax.crypto.Mac;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.PBEKeySpec;
import javax.crypto.spec.SecretKeySpec;

import org.apache.hc.client5.http.auth.AuthChallenge;
import org.apache.hc.client5.http.auth.AuthScheme;
import org.apache.hc.client5.http.auth.AuthScope;
import org.apache.hc.client5.http.auth.AuthenticationException;
import org.apache.hc.client5.http.auth.Credentials;
import org.apache.hc.client5.http.auth.CredentialsProvider;
import org.apache.hc.client5.http.auth.MalformedChallengeException;
import org.apache.hc.client5.http.auth.StandardAuthScheme;
import org.apache.hc.client5.http.auth.UsernamePasswordCredentials;
import org.apache.hc.client5.http.protocol.HttpClientContext;
import org.apache.hc.core5.annotation.Contract;
import org.apache.hc.core5.annotation.Experimental;
import org.apache.hc.core5.annotation.ThreadingBehavior;
import org.apache.hc.core5.http.HttpHost;
import org.apache.hc.core5.http.HttpRequest;
import org.apache.hc.core5.http.NameValuePair;
import org.apache.hc.core5.http.protocol.HttpContext;
import org.apache.hc.core5.util.Args;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


/**
 * Strict HTTP SCRAM client implementing {@code SCRAM-SHA-256} per RFC&nbsp;7804
 * with SCRAM core per RFC&nbsp;5802/7677.
 * <p>HTTP SCRAM uses <em>no channel binding</em> (GS2 header {@code "n,,"}; {@code c=biws}).</p>
 * <p><strong>Experimental:</strong> This API is work in progress and may change without notice in a future release.</p>
 *
 * @since 5.6
 */
@Contract(threading = ThreadingBehavior.UNSAFE)
@Experimental
public final class ScramScheme implements AuthScheme {

    private static final Logger LOG = LoggerFactory.getLogger(ScramScheme.class);

    // RFC 7804 / RFC 5802 fixed no-CB GS2 header and its base64 value for 'c='
    private static final String GS2_HEADER = "n,,";
    private static final String C_BIND_B64 = "biws"; // base64("n,,")

    private static final Base64.Encoder B64 = Base64.getEncoder().withoutPadding();
    private static final Base64.Decoder B64D = Base64.getDecoder();

    private enum State {
        INIT,
        ANNOUNCED,            // after 401 challenge without data
        CLIENT_FIRST_SENT,    // after Authorization with client-first
        SERVER_FIRST_RCVD,    // after 401 with data (r,s,i)
        CLIENT_FINAL_SENT,    // after Authorization with client-final (p=...)
        COMPLETE,             // after 2xx with matching v
        FAILED
    }

    private final SecureRandom secureRandom;
    private final int warnMinIterations;
    private final int minIterationsRequired;

    private State state = State.INIT;
    private boolean complete;

    private String realm;
    private String sid;

    private String username;     // SASLprep (query)
    private char[] password;     // SASLprep (stored), zeroed after use
    private Principal principal;

    private String clientNonce;
    private String clientFirstBare;
    private String serverFirstRaw;
    private String serverNonce;
    private byte[] salt;
    private int iterations;

    // Expected server signature (raw bytes) for constant-time check on Authentication-Info
    // (may appear on any final response status code)
    private byte[] expectedV;

    /**
     * Default policy: warn if {@code i < 4096}, no hard enforcement; SHA-256 only.
     *
     * @since 5.6
     */
    public ScramScheme() {
        this(4096, 0, null);
    }

    /**
     * Constructor with custom iteration policy.
     *
     * @param warnMinIterations     warn if iteration count is lower than this (0 disables warnings)
     * @param minIterationsRequired fail if iteration count is lower than this (0 disables enforcement)
     * @param rnd                   optional secure random source (null uses system default)
     * @since 5.6
     */
    public ScramScheme(final int warnMinIterations, final int minIterationsRequired, final SecureRandom rnd) {
        this.warnMinIterations = Math.max(0, warnMinIterations);
        this.minIterationsRequired = Math.max(0, minIterationsRequired);
        this.secureRandom = rnd != null ? rnd : new SecureRandom();
    }

    /**
     * Returns textual designation of the scheme.
     *
     * @since 5.6
     */
    @Override
    public String getName() {
        return StandardAuthScheme.SCRAM_SHA_256;
    }

    /**
     * SCRAM is per-request (no connection binding).
     *
     * @since 5.6
     */
    @Override
    public boolean isConnectionBased() {
        return false;
    }

    /**
     * SCRAM must inspect final responses to verify {@code v=} in {@code Authentication-Info}.
     *
     * @since 5.6
     */
    @Override
    public boolean isChallengeExpected() {
        return true;
    }

    /**
     * Legacy entry point: wraps {@link AuthenticationException} as {@link MalformedChallengeException}.
     *
     * @since 5.6
     */
    @Override
    public void processChallenge(final AuthChallenge authChallenge, final HttpContext context)
            throws MalformedChallengeException {
        try {
            processChallenge(null, true, authChallenge, context);
        } catch (final AuthenticationException ex) {
            throw new MalformedChallengeException(ex.getMessage(), ex);
        }
    }

    /**
     * Handles 401 challenges (with/without {@code data}) and final responses carrying
     * {@code Authentication-Info} (any status code).
     *
     * @since 5.6
     */
    @Override
    public void processChallenge(
            final HttpHost host,
            final boolean challenged,
            final AuthChallenge authChallenge,
            final HttpContext context) throws MalformedChallengeException, AuthenticationException {

        Args.notNull(context, "HTTP context");

        if (authChallenge == null) {
            if (!challenged) {
                // Final response with no Authentication-Info: nothing to do
                return;
            }
            throw new MalformedChallengeException("Null SCRAM challenge");
        }

        final Map<String, String> params = toParamMap(authChallenge.getParams());

        if (challenged) {
            // --- 401 path (WWW-Authenticate) ---
            final String scheme = authChallenge.getSchemeName();
            if (scheme == null || !StandardAuthScheme.SCRAM_SHA_256.equalsIgnoreCase(scheme)) {
                throw new MalformedChallengeException("Unexpected scheme: " + scheme);
            }

            final String data = params.get("data");
            if (data == null) {
                // initial announce (no data)
                this.realm = params.get("realm");
                this.state = State.ANNOUNCED;
                this.complete = false;
                zeroAndClearExpectedV();
                return;
            }

            // server-first (data present)
            final String decoded = b64ToString(data);
            this.serverFirstRaw = decoded;
            final Map<String, String> attrs = parseAttrs(decoded);

            final String r = attrs.get("r");
            final String s = attrs.get("s");
            final String i = attrs.get("i");
            if (r == null || r.isEmpty() || s == null || s.isEmpty() || i == null || i.isEmpty()) {
                this.state = State.FAILED;
                throw new MalformedChallengeException("SCRAM server-first missing r/s/i");
            }
            if (this.clientNonce == null || !r.startsWith(this.clientNonce)) {
                this.state = State.FAILED;
                throw new AuthenticationException("SCRAM server nonce does not start with client nonce");
            }

            this.sid = params.get("sid");
            try {
                this.salt = B64D.decode(s);
                if (this.salt.length == 0) {
                    throw new IllegalArgumentException("empty salt");
                }
            } catch (final IllegalArgumentException e) {
                this.state = State.FAILED;
                throw new MalformedChallengeException("Invalid base64 salt", e);
            }
            try {
                this.iterations = Integer.parseInt(i);
                if (this.iterations <= 0) {
                    throw new NumberFormatException("i<=0");
                }
            } catch (final NumberFormatException e) {
                this.state = State.FAILED;
                throw new MalformedChallengeException("Invalid iteration count", e);
            }

            if (this.minIterationsRequired > 0 && this.iterations < this.minIterationsRequired) {
                this.state = State.FAILED;
                throw new AuthenticationException(
                        "SCRAM iteration count below required minimum: " + this.iterations + " < " + this.minIterationsRequired);
            }
            if (this.warnMinIterations > 0 && this.iterations < this.warnMinIterations && LOG.isWarnEnabled()) {
                LOG.warn("SCRAM iteration count ({}) lower than recommended ({})", this.iterations, warnMinIterations);
            }

            this.serverNonce = r;
            this.state = State.SERVER_FIRST_RCVD;
            this.complete = false;
            zeroAndClearExpectedV();
            return;
        }

        // --- final-response path (Authentication-Info on any status) ---
        // For Authentication-Info, RFC 7804 does NOT mandate a scheme token; do NOT enforce scheme name here.
        final String data = params.get("data");
        if (data == null) {
            return;
        }
        final String decoded = b64ToString(data);
        final Map<String, String> attrs = parseAttrs(decoded);
        final String err = attrs.get("e");
        if (err != null) {
            this.state = State.FAILED;
            if (err.isEmpty()) {
                throw new MalformedChallengeException("SCRAM server error attribute 'e' is empty");
            }
            throw new AuthenticationException("SCRAM server error: " + err);
        }
        final String vB64 = attrs.get("v");
        if (vB64 == null) {
            return;
        }

        // compare 'v' in constant time; treat bad base64 for v as a signature mismatch (tests expect "signature")
        final byte[] expected = this.expectedV;
        this.expectedV = null; // clear reference early

        byte[] vBytes = null;
        boolean match;
        try {
            vBytes = B64D.decode(vB64);
            match = expected != null && MessageDigest.isEqual(expected, vBytes);
        } catch (final IllegalArgumentException e) {
            match = false; // invalid base64 for v -> treat as mismatch
        } finally {
            zero(vBytes);
            zero(expected);
        }

        if (!match) {
            this.state = State.FAILED;
            throw new MalformedChallengeException("SCRAM server signature mismatch");
        }
        this.complete = true;
        this.state = State.COMPLETE;
    }

    /**
     * @since 5.6
     */
    @Override
    public boolean isChallengeComplete() {
        return this.complete || this.state == State.COMPLETE || this.state == State.FAILED;
    }

    /**
     * @since 5.6
     */
    @Override
    public String getRealm() {
        return this.realm;
    }

    /**
     * Allow response when:
     * - INIT  (preemptive client-first) ��� only if creds have been prepared
     * - ANNOUNCED (401 without data)
     * - SERVER_FIRST_RCVD (ready to send client-final)
     *
     * @since 5.6
     */
    @Override
    public boolean isResponseReady(
            final HttpHost host,
            final CredentialsProvider credentialsProvider,
            final HttpContext context) throws AuthenticationException {

        Args.notNull(credentialsProvider, "Credentials provider");

        final HttpClientContext clientContext = HttpClientContext.cast(context);
        final AuthScope scope = new AuthScope(host, this.realm, getName());
        final Credentials creds = credentialsProvider.getCredentials(scope, clientContext);
        if (!(creds instanceof UsernamePasswordCredentials)) {
            return false;
        }
        final UsernamePasswordCredentials up = (UsernamePasswordCredentials) creds;

        // SASLprep: username as query, password as stored
        final String preppedUser;
        final char[] passChars = up.getUserPassword() != null ? up.getUserPassword().clone() : null;
        try {
            preppedUser = SaslPrep.INSTANCE.prepAsQueryString(up.getUserName());
            final String preppedPassStr = SaslPrep.INSTANCE.prepAsStoredString(
                    passChars != null ? new String(passChars) : null);
            this.username = preppedUser;
            this.password = preppedPassStr != null ? preppedPassStr.toCharArray() : null;
        } catch (final Exception e) {
            throw new AuthenticationException("SASLprep failed", e);
        } finally {
            if (passChars != null) {
                Arrays.fill(passChars, '\0');
            }
        }

        this.principal = new SimplePrincipal(this.username);

        switch (this.state) {
            case INIT: // allow preemptive
            case ANNOUNCED:
            case SERVER_FIRST_RCVD:
                return true;
            default:
                return false;
        }
    }

    /**
     * @since 5.6
     */
    @Override
    public String generateAuthResponse(
            final HttpHost host,
            final HttpRequest request,
            final HttpContext context) throws AuthenticationException {

        switch (this.state) {
            case INIT:
                // Allow preemptive only if credentials were prepared already
                if (this.username == null) {
                    this.state = State.FAILED;
                    throw new AuthenticationException("SCRAM state out of sequence: INIT without credentials");
                }
                return buildClientFirst();
            case ANNOUNCED:
                return buildClientFirst();
            case SERVER_FIRST_RCVD:
                return buildClientFinalAndExpectV();
            default:
                this.state = State.FAILED;
                throw new AuthenticationException("SCRAM state out of sequence: " + this.state);
        }
    }

    /**
     * Returns {@link Principal} whose credentials are used.
     *
     * @since 5.6
     */
    @Override
    public Principal getPrincipal() {
        return this.principal;
    }

    // ---------------- internals ----------------

    private String buildClientFirst() {
        this.clientNonce = genNonce();
        final String escUser = escapeUser(this.username);
        this.clientFirstBare = "n=" + escUser + ",r=" + this.clientNonce;

        // RFC 7804: data = base64(GS2 header + client-first-bare)
        final String data = stringToB64(GS2_HEADER + this.clientFirstBare);

        final StringBuilder sb = new StringBuilder(64);
        sb.append(StandardAuthScheme.SCRAM_SHA_256).append(' ');
        if (this.realm != null) {
            sb.append("realm=").append(quoteParam(this.realm)).append(", ");
        }
        sb.append("data=").append(quoteParam(data)); // quoted per RFC 7804
        this.state = State.CLIENT_FIRST_SENT;
        this.complete = false;
        zeroAndClearExpectedV();
        return sb.toString();
    }

    private String buildClientFinalAndExpectV() throws AuthenticationException {
        byte[] salted = null, clientKey = null, storedKey = null, clientSignature = null,
                clientProof = null, serverKey = null, serverSignature = null;

        try {
            // HTTP SCRAM: no CB -> c=biws
            final String clientFinalNoProof = "c=" + C_BIND_B64 + ",r=" + this.serverNonce;
            final String authMessage = this.clientFirstBare + "," + this.serverFirstRaw + "," + clientFinalNoProof;

            salted = hiPBKDF2(this.password, this.salt, this.iterations, 32);
            clientKey = hmac(salted, "Client Key");
            storedKey = sha256(clientKey);
            clientSignature = hmac(storedKey, authMessage);
            clientProof = xor(clientKey, clientSignature);
            final String pB64 = B64.encodeToString(clientProof);

            serverKey = hmac(salted, "Server Key");
            serverSignature = hmac(serverKey, authMessage);

            // Stash expected v (raw) for constant-time check on 2xx
            zeroAndClearExpectedV();
            this.expectedV = serverSignature.clone();

            final String clientFinal = clientFinalNoProof + ",p=" + pB64;
            final String data = stringToB64(clientFinal);

            final StringBuilder sb = new StringBuilder(64);
            sb.append(StandardAuthScheme.SCRAM_SHA_256).append(' ');
            if (this.sid != null) {
                sb.append("sid=").append(quoteParam(this.sid)).append(", ");
            }
            sb.append("data=").append(quoteParam(data)); // quoted

            this.state = State.CLIENT_FINAL_SENT;
            this.complete = false;
            return sb.toString();
        } catch (final GeneralSecurityException e) {
            this.state = State.FAILED;
            throw new AuthenticationException("SCRAM crypto failure", e);
        } finally {
            if (this.password != null) {
                Arrays.fill(this.password, '\0');
                this.password = null;
            }
            zero(salted);
            zero(clientKey);
            zero(storedKey);
            zero(clientSignature);
            zero(clientProof);
            zero(serverKey);
            // clone held in expectedV
            zero(serverSignature);
        }
    }

    private static void zero(final byte[] a) {
        if (a != null) {
            Arrays.fill(a, (byte) 0);
        }
    }

    private void zeroAndClearExpectedV() {
        if (this.expectedV != null) {
            zero(this.expectedV);
            this.expectedV = null;
        }
    }

    private static Map<String, String> toParamMap(final List<NameValuePair> pairs) {
        final Map<String, String> m = new HashMap<>();
        if (pairs != null) {
            for (final NameValuePair p : pairs) {
                if (p != null && p.getName() != null) {
                    m.put(p.getName().toLowerCase(Locale.ROOT), p.getValue());
                }
            }
        }
        return m;
    }

    // Parse "k=v(,k=v)*" with RFC 5802 escapes: "=2C"/"=2c" -> ",", "=3D"/"=3d" -> "="
    private static Map<String, String> parseAttrs(final String s) throws MalformedChallengeException {
        final Map<String, String> out = new LinkedHashMap<>();
        int i = 0;
        while (i < s.length()) {
            if (i + 2 > s.length() || s.charAt(i + 1) != '=') {
                throw new MalformedChallengeException("Bad SCRAM attr at index " + i);
            }
            final String k = String.valueOf(s.charAt(i));
            i += 2;
            final StringBuilder v = new StringBuilder();
            while (i < s.length()) {
                final char c = s.charAt(i);
                if (c == ',') {
                    i++;
                    break;
                }
                if (c == '=' && i + 2 < s.length()) {
                    final String esc = s.substring(i, i + 3);
                    if ("=2C".equalsIgnoreCase(esc)) {
                        v.append(',');
                        i += 3;
                        continue;
                    }
                    if ("=3D".equalsIgnoreCase(esc)) {
                        v.append('=');
                        i += 3;
                        continue;
                    }
                }
                v.append(c);
                i++;
            }
            out.put(k, v.toString());
        }
        return out;
    }

    private String genNonce() {
        final byte[] buf = new byte[16];
        this.secureRandom.nextBytes(buf);
        final StringBuilder sb = new StringBuilder(buf.length * 2);
        for (final byte b : buf) {
            final int v = b & 0xff;
            if (v < 0x10) {
                sb.append('0');
            }
            sb.append(Integer.toHexString(v));
        }
        return sb.toString();
    }

    private static String escapeUser(final String user) {
        if (user == null) {
            // Defensive for tests that skip isResponseReady() before client-first
            return "";
        }
        final StringBuilder sb = new StringBuilder(user.length() + 8);
        for (int i = 0; i < user.length(); i++) {
            final char c = user.charAt(i);
            if (c == ',') {
                sb.append("=2C");
            } else if (c == '=') {
                sb.append("=3D");
            } else {
                sb.append(c);
            }
        }
        return sb.toString();
    }

    private static String quoteParam(final String v) {
        if (v == null) {
            return "\"\"";
        }
        final StringBuilder sb = new StringBuilder(v.length() + 2);
        sb.append('"');
        for (int i = 0; i < v.length(); i++) {
            final char c = v.charAt(i);
            if (c == '\\' || c == '"') {
                sb.append('\\');
            }
            sb.append(c);
        }
        sb.append('"');
        return sb.toString();
    }

    private static byte[] hiPBKDF2(final char[] password, final byte[] salt, final int iterations, final int dkLen)
            throws GeneralSecurityException {
        final PBEKeySpec spec = new PBEKeySpec(password, salt, iterations, dkLen * 8);
        return SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256").generateSecret(spec).getEncoded();
    }

    private static byte[] hmac(final byte[] key, final String msg) throws GeneralSecurityException {
        final Mac mac = Mac.getInstance("HmacSHA256");
        mac.init(new SecretKeySpec(key, "HmacSHA256"));
        return mac.doFinal(msg.getBytes(StandardCharsets.UTF_8));
    }

    private static byte[] sha256(final byte[] in) throws GeneralSecurityException {
        return MessageDigest.getInstance("SHA-256").digest(in);
    }

    private static byte[] xor(final byte[] a, final byte[] b) {
        final int len = Math.min(a.length, b.length);
        final byte[] out = new byte[len];
        for (int i = 0; i < len; i++) {
            out[i] = (byte) (a[i] ^ b[i]);
        }
        return out;
    }

    private static String stringToB64(final String s) {
        return B64.encodeToString(s.getBytes(StandardCharsets.UTF_8));
    }

    private static String b64ToString(final String b64) throws MalformedChallengeException {
        try {
            return new String(B64D.decode(b64), StandardCharsets.UTF_8);
        } catch (final IllegalArgumentException e) {
            throw new MalformedChallengeException("Bad base64 'data' value", e);
        }
    }

    private static final class SimplePrincipal implements Principal {
        private final String name;

        private SimplePrincipal(final String name) {
            this.name = name;
        }

        @Override
        public String getName() {
            return this.name;
        }

        @Override
        public String toString() {
            return this.name;
        }
    }

}