BCEcdhEsAlgorithmProvider.java
/*
* Copyright 2023 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.crypto.def;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.math.BigInteger;
import java.nio.charset.Charset;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.KeyFactory;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.security.PublicKey;
import java.security.SecureRandom;
import java.security.interfaces.ECPrivateKey;
import java.security.interfaces.ECPublicKey;
import java.security.spec.ECParameterSpec;
import java.security.spec.ECPoint;
import java.security.spec.ECPublicKeySpec;
import java.security.spec.InvalidKeySpecException;
import javax.crypto.KeyAgreement;
import org.bouncycastle.crypto.Wrapper;
import org.bouncycastle.crypto.agreement.kdf.ConcatenationKDFGenerator;
import org.bouncycastle.crypto.engines.AESWrapEngine;
import org.bouncycastle.crypto.params.KDFParameters;
import org.bouncycastle.crypto.params.KeyParameter;
import org.bouncycastle.crypto.util.DigestFactory;
import org.bouncycastle.jce.ECNamedCurveTable;
import org.bouncycastle.jce.spec.ECNamedCurveParameterSpec;
import org.bouncycastle.jce.spec.ECNamedCurveSpec;
import org.keycloak.common.util.Base64Url;
import org.keycloak.crypto.Algorithm;
import org.keycloak.crypto.KeyType;
import org.keycloak.jose.jwe.JWEHeader;
import org.keycloak.jose.jwe.JWEHeader.JWEHeaderBuilder;
import org.keycloak.jose.jwe.JWEKeyStorage;
import org.keycloak.jose.jwe.alg.JWEAlgorithmProvider;
import org.keycloak.jose.jwe.enc.JWEEncryptionProvider;
import org.keycloak.jose.jwk.ECPublicJWK;
import org.keycloak.jose.jwk.JWKUtil;
/**
* ECDH Ephemeral Static Algorithm Provider.
*
* @author Justin Tay
* @see <a href=
* "https://datatracker.ietf.org/doc/html/rfc7518#section-4.6.2">Key
* Derivation for ECDH Key Agreement</a>
*/
public class BCEcdhEsAlgorithmProvider implements JWEAlgorithmProvider {
@Override
public byte[] decodeCek(byte[] encodedCek, Key encryptionKey, JWEHeader header,
JWEEncryptionProvider encryptionProvider) throws Exception {
int keyDataLength = getKeyDataLength(header.getAlgorithm(), encryptionProvider);
PublicKey sharedPublicKey = toPublicKey(header.getEphemeralPublicKey());
String algorithmID = getAlgorithmID(header.getAlgorithm(), header.getEncryptionAlgorithm());
byte[] derivedKey = deriveKey(sharedPublicKey, encryptionKey, keyDataLength, algorithmID,
base64UrlDecode(header.getAgreementPartyUInfo()), base64UrlDecode(header.getAgreementPartyVInfo()));
if (Algorithm.ECDH_ES.equals(header.getAlgorithm())) {
return derivedKey;
} else {
Wrapper encrypter = new AESWrapEngine();
encrypter.init(false, new KeyParameter(derivedKey));
return encrypter.unwrap(encodedCek, 0, encodedCek.length);
}
}
@Override
public byte[] encodeCek(JWEEncryptionProvider encryptionProvider, JWEKeyStorage keyStorage, Key encryptionKey,
JWEHeaderBuilder headerBuilder) throws Exception {
JWEHeader header = headerBuilder.build();
int keyDataLength = getKeyDataLength(header.getAlgorithm(), encryptionProvider);
ECParameterSpec params = ((ECPublicKey) encryptionKey).getParams();
KeyPair ephemeralKeyPair = generateEcKeyPair(params);
ECPublicKey ephemeralPublicKey = (ECPublicKey) ephemeralKeyPair.getPublic();
ECPrivateKey ephemeralPrivateKey = (ECPrivateKey) ephemeralKeyPair.getPrivate();
byte[] agreementPartyUInfo = header.getAgreementPartyUInfo() != null
? base64UrlDecode(header.getAgreementPartyUInfo())
: new byte[0];
byte[] agreementPartyVInfo = header.getAgreementPartyVInfo() != null
? base64UrlDecode(header.getAgreementPartyVInfo())
: new byte[0];
headerBuilder.ephemeralPublicKey(toECPublicJWK(ephemeralPublicKey));
String algorithmID = getAlgorithmID(header.getAlgorithm(), header.getEncryptionAlgorithm());
byte[] derivedKey = deriveKey(encryptionKey, ephemeralPrivateKey, keyDataLength, algorithmID,
agreementPartyUInfo, agreementPartyVInfo);
if (Algorithm.ECDH_ES.equals(header.getAlgorithm())) {
keyStorage.setCEKBytes(derivedKey);
encryptionProvider.deserializeCEK(keyStorage);
return new byte[0];
} else {
Wrapper encrypter = new AESWrapEngine();
encrypter.init(true, new KeyParameter(derivedKey));
byte[] cekBytes = keyStorage.getCekBytes();
return encrypter.wrap(cekBytes, 0, cekBytes.length);
}
}
private byte[] base64UrlDecode(String encoded) {
return Base64Url.decode(encoded == null ? "" : encoded);
}
private static KeyPair generateEcKeyPair(ECParameterSpec params) {
try {
KeyPairGenerator keyGen = KeyPairGenerator.getInstance("EC");
SecureRandom randomGen = SecureRandom.getInstance("SHA1PRNG");
keyGen.initialize(params, randomGen);
return keyGen.generateKeyPair();
} catch (InvalidAlgorithmParameterException | NoSuchAlgorithmException e) {
throw new IllegalArgumentException(e);
}
}
private static byte[] deriveOtherInfo(int keyDataLength, String algorithmID, byte[] agreementPartyUInfo,
byte[] agreementPartyVInfo) {
byte[] algorithmId = encodeDataLengthData(algorithmID.getBytes(Charset.forName("ASCII")));
byte[] partyUInfo = encodeDataLengthData(agreementPartyUInfo);
byte[] partyVInfo = encodeDataLengthData(agreementPartyVInfo);
byte[] suppPubInfo = toByteArray(keyDataLength);
byte[] suppPrivInfo = emptyBytes();
return concat(algorithmId, partyUInfo, partyVInfo, suppPubInfo, suppPrivInfo);
}
public static byte[] deriveKey(Key publicKey, Key privateKey, int keyDataLength, String algorithmID,
byte[] agreementPartyUInfo, byte[] agreementPartyVInfo)
throws InvalidKeyException, NoSuchAlgorithmException, IllegalStateException {
byte[] z = deriveSharedSecret(publicKey, privateKey);
byte[] otherInfo = deriveOtherInfo(keyDataLength, algorithmID, agreementPartyUInfo, agreementPartyVInfo);
KDFParameters param = new KDFParameters(z, otherInfo);
ConcatenationKDFGenerator concatKdf = new ConcatenationKDFGenerator(DigestFactory.createSHA256());
concatKdf.init(param);
int derivedKeyLength = keyDataLength / 8;
byte[] derivedKeyBytes = new byte[derivedKeyLength];
concatKdf.generateBytes(derivedKeyBytes, 0, derivedKeyLength);
return derivedKeyBytes;
}
private static ECPublicJWK toECPublicJWK(ECPublicKey ecKey) {
ECPublicJWK k = new ECPublicJWK();
int fieldSize = ecKey.getParams().getCurve().getField().getFieldSize();
k.setCrv("P-" + fieldSize);
k.setKeyType(KeyType.EC);
k.setX(Base64Url.encode(JWKUtil.toIntegerBytes(ecKey.getW().getAffineX(), fieldSize)));
k.setY(Base64Url.encode(JWKUtil.toIntegerBytes(ecKey.getW().getAffineY(), fieldSize)));
return k;
}
private static PublicKey toPublicKey(ECPublicJWK jwk) {
String crv = jwk.getCrv();
String xStr = jwk.getX();
String yStr = jwk.getY();
if (crv == null) {
throw new IllegalArgumentException("JWK crv must be set");
}
if (xStr == null) {
throw new IllegalArgumentException("JWK x must be set");
}
if (yStr == null) {
throw new IllegalArgumentException("JWK y must be set");
}
BigInteger x = new BigInteger(1, Base64Url.decode(xStr));
BigInteger y = new BigInteger(1, Base64Url.decode(yStr));
String name = nistToSecCurveName(crv);
try {
ECPoint point = new ECPoint(x, y);
ECNamedCurveParameterSpec spec = ECNamedCurveTable.getParameterSpec(name);
ECParameterSpec params = new ECNamedCurveSpec(name, spec.getCurve(), spec.getG(), spec.getN());
ECPublicKeySpec pubKeySpec = new ECPublicKeySpec(point, params);
KeyFactory keyFactory = KeyFactory.getInstance("EC");
return keyFactory.generatePublic(pubKeySpec);
} catch (InvalidKeySpecException | NoSuchAlgorithmException e) {
throw new IllegalArgumentException(e);
}
}
private static byte[] deriveSharedSecret(Key publicKey, Key privateKey)
throws NoSuchAlgorithmException, InvalidKeyException, IllegalStateException {
KeyAgreement keyAgreement = KeyAgreement.getInstance("ECDH");
keyAgreement.init(privateKey);
keyAgreement.doPhase(publicKey, true);
return keyAgreement.generateSecret();
}
private static String getAlgorithmID(String alg, String enc) {
if (Algorithm.ECDH_ES_A128KW.equals(alg) || Algorithm.ECDH_ES_A192KW.equals(alg)
|| Algorithm.ECDH_ES_A256KW.equals(alg)) {
return alg;
} else if (Algorithm.ECDH_ES.equals(alg)) {
return enc;
} else {
throw new IllegalArgumentException("Unsupported algorithm");
}
}
private static String nistToSecCurveName(String nistCurveName) {
switch (nistCurveName) {
case "P-256":
return "secp256r1";
case "P-384":
return "secp384r1";
case "P-521":
return "secp521r1";
default:
throw new IllegalArgumentException("Unsupported curve");
}
}
private static int getKeyDataLength(String alg, JWEEncryptionProvider encryptionProvider) {
if (Algorithm.ECDH_ES_A128KW.equals(alg)) {
return 128;
} else if (Algorithm.ECDH_ES_A192KW.equals(alg)) {
return 192;
} else if (Algorithm.ECDH_ES_A256KW.equals(alg)) {
return 256;
} else if (Algorithm.ECDH_ES.equals(alg)) {
return encryptionProvider.getExpectedCEKLength() * 8;
} else {
throw new IllegalArgumentException("Unsupported algorithm");
}
}
private static byte[] encodeDataLengthData(final byte[] data) {
byte[] databytes = data != null ? data : new byte[0];
byte[] datalen = toByteArray(databytes.length);
return concat(datalen, databytes);
}
private static byte[] emptyBytes() {
return new byte[0];
}
private static byte[] toByteArray(int intValue) {
return new byte[] { (byte) (intValue >> 24), (byte) (intValue >> 16), (byte) (intValue >> 8), (byte) intValue };
}
private static byte[] concat(byte[]... byteArrays) {
try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
for (byte[] bytes : byteArrays) {
if (bytes != null) {
baos.write(bytes);
}
}
return baos.toByteArray();
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
}