AmazonSigningService.java

/*
 * Copyright 2022 Emmanuel Bourg
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.jsign.jca;

import java.io.IOException;
import java.net.HttpURLConnection;
import java.net.URL;
import java.security.GeneralSecurityException;
import java.security.InvalidAlgorithmParameterException;
import java.security.KeyStoreException;
import java.security.MessageDigest;
import java.security.UnrecoverableKeyException;
import java.security.cert.Certificate;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TimeZone;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;

import org.bouncycastle.util.encoders.Hex;

import net.jsign.DigestAlgorithm;

import static java.nio.charset.StandardCharsets.*;

/**
 * Signing service using the AWS API.
 *
 * @since 5.0
 * @see <a href="https://docs.aws.amazon.com/kms/latest/APIReference/">AWS Key Management Service API Reference</a>
 * @see <a href="https://docs.aws.amazon.com/general/latest/gr/signing_aws_api_requests.html">Signing AWS API Requests</a>
 */
public class AmazonSigningService implements SigningService {

    /** Source for the certificates */
    private final Function<String, Certificate[]> certificateStore;

    /** Cache of private keys indexed by id */
    private final Map<String, SigningServicePrivateKey> keys = new HashMap<>();

    private final RESTClient client;

    /** Mapping between Java and AWS signing algorithms */
    private final Map<String, String> algorithmMapping = new HashMap<>();
    {
        algorithmMapping.put("SHA256withRSA", "RSASSA_PKCS1_V1_5_SHA_256");
        algorithmMapping.put("SHA384withRSA", "RSASSA_PKCS1_V1_5_SHA_384");
        algorithmMapping.put("SHA512withRSA", "RSASSA_PKCS1_V1_5_SHA_512");
        algorithmMapping.put("SHA256withECDSA", "ECDSA_SHA_256");
        algorithmMapping.put("SHA384withECDSA", "ECDSA_SHA_384");
        algorithmMapping.put("SHA512withECDSA", "ECDSA_SHA_512");
        algorithmMapping.put("SHA256withRSA/PSS", "RSASSA_PSS_SHA_256");
        algorithmMapping.put("SHA384withRSA/PSS", "RSASSA_PSS_SHA_384");
        algorithmMapping.put("SHA512withRSA/PSS", "RSASSA_PSS_SHA_512");
    }

    /**
     * Creates a new AWS signing service.
     *
     * @param region           the AWS region holding the keys (for example <tt>eu-west-3</tt>)
     * @param credentials      the AWS credentials provider
     * @param certificateStore provides the certificate chain for the keys
     * @since 6.0
     */
    public AmazonSigningService(String region, Supplier<AmazonCredentials> credentials, Function<String, Certificate[]> certificateStore) {
        this(credentials, certificateStore, getEndpointUrl(region));
    }

    /**
     * Creates a new AWS signing service.
     *
     * @param region           the AWS region holding the keys (for example <tt>eu-west-3</tt>)
     * @param credentials      the AWS credentials
     * @param certificateStore provides the certificate chain for the keys
     */
    public AmazonSigningService(String region, AmazonCredentials credentials, Function<String, Certificate[]> certificateStore) {
        this(region, () -> credentials, certificateStore);
    }

    AmazonSigningService(Supplier<AmazonCredentials> credentials, Function<String, Certificate[]> certificateStore, String endpoint) {
        this.certificateStore = certificateStore;
        this.client = new RESTClient(endpoint)
                .authentication((conn, data) -> sign(conn, credentials.get(), data, null))
                .errorHandler(response -> response.get("__type") + ": " + response.get("message"));
    }

    /**
     * Creates a new AWS signing service.
     *
     * @param region           the AWS region holding the keys (for example <tt>eu-west-3</tt>)
     * @param credentials      the AWS credentials: <tt>accessKey|secretKey|sessionToken</tt> (the session token is optional)
     * @param certificateStore provides the certificate chain for the keys
     */
    @Deprecated
    public AmazonSigningService(String region, String credentials, Function<String, Certificate[]> certificateStore) {
        this(region, AmazonCredentials.parse(credentials), certificateStore);
    }

    /**
     * Returns the endpoint URL for the given AWS region.
     *
     * @param region the AWS region
     * @return the endpoint URL
     */
    static String getEndpointUrl(String region) {
        String domain = "true".equalsIgnoreCase(getenv("AWS_USE_FIPS_ENDPOINT")) ? "kms-fips" : "kms";
        return "https://" + domain + "." + region + ".amazonaws.com";
    }

    @Override
    public String getName() {
        return "AWS";
    }

    @Override
    public List<String> aliases() throws KeyStoreException {
        List<String> aliases = new ArrayList<>();

        try {
            // kms:ListKeys (https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeys.html)
            Map<String, ?> response = query("TrentService.ListKeys", "{}");
            Object[] keys = (Object[]) response.get("Keys");
            for (Object key : keys) {
                aliases.add((String) ((Map) key).get("KeyId"));
            }
        } catch (IOException e) {
            throw new KeyStoreException(e);
        }

        return aliases;
    }

    @Override
    public Certificate[] getCertificateChain(String alias) throws KeyStoreException {
        return certificateStore.apply(alias);
    }

    @Override
    public SigningServicePrivateKey getPrivateKey(String alias, char[] password) throws UnrecoverableKeyException {
        if (keys.containsKey(alias)) {
            return keys.get(alias);
        }

        String algorithm;

        try {
            // kms:DescribeKey (https://docs.aws.amazon.com/kms/latest/APIReference/API_DescribeKey.html)
            Map<String, ?> response = query("TrentService.DescribeKey", "{\"KeyId\":\"" + normalizeKeyId(alias) + "\"}");
            Map<String, ?> keyMetadata = (Map<String, ?>) response.get("KeyMetadata");

            String keyUsage = (String) keyMetadata.get("KeyUsage");
            if (!"SIGN_VERIFY".equals(keyUsage)) {
                throw new UnrecoverableKeyException("The key '" + alias + "' is not a signing key");
            }

            String keyState = (String) keyMetadata.get("KeyState");
            if (!"Enabled".equals(keyState)) {
                throw new UnrecoverableKeyException("The key '" + alias + "' is not enabled (" + keyState + ")");
            }

            String keySpec = (String) keyMetadata.get("KeySpec");
            algorithm = keySpec.substring(0, keySpec.indexOf('_'));
            if ("ECC".equals(algorithm)) {
                algorithm = "EC";
            }
        } catch (IOException e) {
            throw (UnrecoverableKeyException) new UnrecoverableKeyException("Unable to fetch AWS key '" + alias + "'").initCause(e);
        }

        SigningServicePrivateKey key = new SigningServicePrivateKey(alias, algorithm, this);
        keys.put(alias, key);
        return key;
    }

    @Override
    public byte[] sign(SigningServicePrivateKey privateKey, String algorithm, byte[] data) throws GeneralSecurityException {
        String alg = algorithmMapping.get(algorithm);
        if (alg == null) {
            throw new InvalidAlgorithmParameterException("Unsupported signing algorithm: " + algorithm);
        }

        DigestAlgorithm digestAlgorithm = DigestAlgorithm.of(algorithm.substring(0, algorithm.toLowerCase().indexOf("with")));
        data = digestAlgorithm.getMessageDigest().digest(data);

        // kms:Sign (https://docs.aws.amazon.com/kms/latest/APIReference/API_Sign.html)
        Map<String, String> request = new HashMap<>();
        request.put("KeyId", normalizeKeyId(privateKey.getId()));
        request.put("MessageType", "DIGEST");
        request.put("Message", Base64.getEncoder().encodeToString(data));
        request.put("SigningAlgorithm", alg);

        try {
            Map<String, ?> response = query("TrentService.Sign", JsonWriter.format(request));
            String signature = (String) response.get("Signature");
            return Base64.getDecoder().decode(signature);
        } catch (IOException e) {
            throw new GeneralSecurityException(e);
        }
    }

    /**
     * Sends a request to the AWS API.
     */
    private Map<String, ?> query(String target, String body) throws IOException {
        Map<String, String> headers = new HashMap<>();
        headers.put("X-Amz-Target", target);
        headers.put("Content-Type", "application/x-amz-json-1.1");
        return client.post("/", body, headers);
    }

    /**
     * Prefixes the key id with <tt>alias/</tt> if necessary.
     */
    private String normalizeKeyId(String keyId) {
        if (keyId.startsWith("arn:") || keyId.startsWith("alias/")) {
            return keyId;
        }

        if (!keyId.matches("^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$")) {
            return "alias/" + keyId;
        } else {
            return keyId;
        }
    }

    /**
     * Signs the request
     *
     * @see <a href="https://docs.aws.amazon.com/general/latest/gr/signature-version-4.html">Signature Version 4 signing process</a>
     */
    void sign(HttpURLConnection conn, AmazonCredentials credentials, byte[] content, Date date) {
        DateFormat dateFormat = new SimpleDateFormat("yyyyMMdd");
        dateFormat.setTimeZone(TimeZone.getTimeZone("UTC"));
        DateFormat dateTimeFormat = new SimpleDateFormat("yyyyMMdd'T'HHmmss'Z'");
        dateTimeFormat.setTimeZone(TimeZone.getTimeZone("UTC"));
        if (date == null) {
            date = new Date();
        }

        // Extract the service name and the region from the endpoint
        URL endpoint = conn.getURL();
        Pattern hostnamePattern = Pattern.compile("^([^.]+)\\.([^.]+)\\.amazonaws\\.com$");
        String host = endpoint.getHost();
        Matcher matcher = hostnamePattern.matcher(host);
        String regionName = matcher.matches() ? matcher.group(2) : "us-east-1";
        String serviceName = matcher.matches() ? matcher.group(1).replace("-fips", "") : "kms";

        String credentialScope = dateFormat.format(date) + "/" + regionName + "/" + serviceName + "/" + "aws4_request";

        conn.addRequestProperty("X-Amz-Date", dateTimeFormat.format(date));

        // Create the canonical request (https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html)
        Map<String, List<String>> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
        headers.putAll(conn.getRequestProperties());
        headers.put("Host", Collections.singletonList(host));

        String canonicalRequest = conn.getRequestMethod() + "\n"
                + endpoint.getPath() + (endpoint.getPath().endsWith("/") ? "" : "/") + "\n"
                + /* canonical query string, not used for kms operations */ "\n"
                + canonicalHeaders(headers) + "\n"
                + signedHeaders(headers) + "\n"
                + sha256(content);

        // Create the string to sign (https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html)
        String stringToSign = "AWS4-HMAC-SHA256" + "\n"
                + dateTimeFormat.format(date) + "\n"
                + credentialScope + "\n"
                + sha256(canonicalRequest.getBytes(UTF_8));

        // Derive the signing key (https://docs.aws.amazon.com/general/latest/gr/sigv4-calculate-signature.html)
        byte[] key = ("AWS4" + credentials.getSecretKey()).getBytes(UTF_8);
        byte[] signingKey = hmac("aws4_request", hmac(serviceName, hmac(regionName, hmac(dateFormat.format(date), key))));

        // Compute the signature (https://docs.aws.amazon.com/general/latest/gr/sigv4-calculate-signature.html)
        byte[] signature = hmac(stringToSign, signingKey);

        conn.setRequestProperty("Authorization",
                "AWS4-HMAC-SHA256 Credential=" + credentials.getAccessKey() + "/" + credentialScope
                + ", SignedHeaders=" + signedHeaders(headers)
                + ", Signature=" + Hex.toHexString(signature).toLowerCase());

        if (credentials.getSessionToken() != null) {
            conn.setRequestProperty("X-Amz-Security-Token", credentials.getSessionToken());
        }
    }

    private String canonicalHeaders(Map<String, List<String>> headers) {
        return headers.entrySet().stream()
                .map(entry -> entry.getKey().toLowerCase() + ":" + String.join(",", entry.getValue()).replaceAll("\\s+", " "))
                .collect(Collectors.joining("\n")) + "\n";
    }

    private String signedHeaders(Map<String, List<String>> headers) {
        return headers.keySet().stream()
                .map(String::toLowerCase)
                .collect(Collectors.joining(";"));
    }

    private byte[] hmac(String data, byte[] key) {
        return hmac(data.getBytes(UTF_8), key);
    }

    private byte[] hmac(byte[] data, byte[] key) {
        try {
            Mac mac = Mac.getInstance("HmacSHA256");
            mac.init(new SecretKeySpec(key, mac.getAlgorithm()));
            return mac.doFinal(data);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    private String sha256(byte[] data) {
        MessageDigest digest =  DigestAlgorithm.SHA256.getMessageDigest();
        digest.update(data);
        return Hex.toHexString(digest.digest()).toLowerCase();
    }

    static String getenv(String name) {
        return System.getenv(name);
    }
}