RsaKeyHelper.java
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.crypto.encrypt;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.StringWriter;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.CharacterCodingException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.KeyFactory;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.KeySpec;
import java.security.spec.RSAPrivateCrtKeySpec;
import java.security.spec.RSAPublicKeySpec;
import java.security.spec.X509EncodedKeySpec;
import java.util.Arrays;
import java.util.Base64;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.bouncycastle.asn1.ASN1Sequence;
import org.jspecify.annotations.Nullable;
/**
 * Reads RSA key pairs using BC provider classes but without the need to specify a crypto
 * provider or have BC added as one.
 * <p>
 * Supported inputs are PEM blocks of type {@code RSA PRIVATE KEY} and
 * {@code RSA PUBLIC KEY} (PKCS#1) or {@code PUBLIC KEY} (X.509
 * {@code SubjectPublicKeyInfo}), and OpenSSH RSA public keys, given either as a full
 * {@code ssh-rsa <base64> [comment]} line or as the bare Base64 key blob.
 *
 * @author Luke Taylor
 * @author Dave Syer
 */
final class RsaKeyHelper {

	private static final Charset UTF8 = StandardCharsets.UTF_8;

	private static final String BEGIN = "-----BEGIN";

	private static final Pattern PEM_DATA = Pattern.compile(".*-----BEGIN (.*)-----(.*)-----END (.*)-----",
			Pattern.DOTALL);

	/**
	 * A line break (LF, CRLF, ...) together with any whitespace that follows it, such as
	 * the indentation of the next line.
	 */
	private static final Pattern LINE_BREAK = Pattern.compile("\\R\\s*");

	/** Any whitespace; removed from Base64 payloads because the basic decoder rejects it. */
	private static final Pattern WHITESPACE = Pattern.compile("\\s+");

	/**
	 * An OpenSSH public key line {@code ssh-<alg> <base64>}, optionally followed by a
	 * space and a free-text comment.
	 */
	private static final Pattern SSH_PUB_KEY = Pattern
		.compile("ssh-([A-Za-z0-9]+) ([A-Za-z0-9/+]+=*)(?: (.*))?");

	/** The 4-byte length (7) followed by {@code "ssh-rsa"}, which starts an SSH RSA key blob. */
	private static final byte[] PREFIX = new byte[] { 0, 0, 0, 7, 's', 's', 'h', '-', 'r', 's', 'a' };

	/** Modulus size, in bits, used by {@link #generateKeyPair()}. */
	private static final int DEFAULT_KEY_SIZE = 1024;

	private RsaKeyHelper() {
	}

	/**
	 * Parses PEM encoded RSA key material, falling back to an OpenSSH RSA public key when
	 * the input is not PEM.
	 * @param pemData a PEM block ({@code RSA PRIVATE KEY}, {@code RSA PUBLIC KEY} or
	 * {@code PUBLIC KEY}) or an SSH RSA public key
	 * @return the parsed key pair; the private key is {@code null} unless the input is an
	 * {@code RSA PRIVATE KEY} block
	 * @throws IllegalArgumentException if the input is neither supported PEM data nor an
	 * SSH RSA public key, or if the encoded key is invalid
	 */
	static KeyPair parseKeyPair(String pemData) {
		Matcher m = PEM_DATA.matcher(LINE_BREAK.matcher(pemData).replaceAll("").trim());
		if (!m.matches()) {
			Exception cause = null;
			try {
				RSAPublicKey publicValue = extractPublicKey(pemData);
				if (publicValue != null) {
					return new KeyPair(publicValue, null);
				}
			}
			catch (Exception ex) {
				// Not an SSH key either; keep the failure as the cause for diagnostics
				cause = ex;
			}
			throw new IllegalArgumentException("String is not PEM encoded data, nor a public key encoded for ssh",
					cause);
		}
		String type = m.group(1);
		// Strip leftover whitespace (e.g. '\r' from CRLF line endings or trailing spaces):
		// the basic Base64 decoder rejects it.
		final byte[] content = base64Decode(WHITESPACE.matcher(m.group(2)).replaceAll(""));
		PublicKey publicKey;
		PrivateKey privateKey = null;
		try {
			KeyFactory fact = KeyFactory.getInstance("RSA");
			switch (type) {
				case "RSA PRIVATE KEY" -> {
					ASN1Sequence seq = ASN1Sequence.getInstance(content);
					if (seq.size() != 9) {
						throw new IllegalArgumentException("Invalid RSA Private Key ASN1 sequence.");
					}
					org.bouncycastle.asn1.pkcs.RSAPrivateKey key = org.bouncycastle.asn1.pkcs.RSAPrivateKey
						.getInstance(seq);
					RSAPublicKeySpec pubSpec = new RSAPublicKeySpec(key.getModulus(), key.getPublicExponent());
					RSAPrivateCrtKeySpec privSpec = new RSAPrivateCrtKeySpec(key.getModulus(),
							key.getPublicExponent(), key.getPrivateExponent(), key.getPrime1(), key.getPrime2(),
							key.getExponent1(), key.getExponent2(), key.getCoefficient());
					publicKey = fact.generatePublic(pubSpec);
					privateKey = fact.generatePrivate(privSpec);
				}
				case "PUBLIC KEY" -> {
					KeySpec keySpec = new X509EncodedKeySpec(content);
					publicKey = fact.generatePublic(keySpec);
				}
				case "RSA PUBLIC KEY" -> {
					ASN1Sequence seq = ASN1Sequence.getInstance(content);
					org.bouncycastle.asn1.pkcs.RSAPublicKey key = org.bouncycastle.asn1.pkcs.RSAPublicKey
						.getInstance(seq);
					RSAPublicKeySpec pubSpec = new RSAPublicKeySpec(key.getModulus(), key.getPublicExponent());
					publicKey = fact.generatePublic(pubSpec);
				}
				default -> throw new IllegalArgumentException(type + " is not a supported format");
			}
			return new KeyPair(publicKey, privateKey);
		}
		catch (InvalidKeySpecException ex) {
			throw new IllegalArgumentException(ex);
		}
		catch (NoSuchAlgorithmException ex) {
			throw new IllegalStateException(ex);
		}
	}

	/**
	 * Decodes a basic (non-MIME) Base64 string.
	 * @param string the Base64 text; must not contain whitespace
	 * @return the decoded bytes
	 */
	private static byte[] base64Decode(String string) {
		try {
			ByteBuffer bytes = UTF8.newEncoder().encode(CharBuffer.wrap(string));
			byte[] bytesCopy = new byte[bytes.limit()];
			System.arraycopy(bytes.array(), 0, bytesCopy, 0, bytes.limit());
			return Base64.getDecoder().decode(bytesCopy);
		}
		catch (CharacterCodingException ex) {
			throw new RuntimeException(ex);
		}
	}

	/**
	 * Encodes bytes as basic (non-MIME, unwrapped) Base64 text.
	 * @param bytes the data to encode
	 * @return the Base64 representation
	 */
	static String base64Encode(byte[] bytes) {
		try {
			return UTF8.newDecoder().decode(ByteBuffer.wrap(Base64.getEncoder().encode(bytes))).toString();
		}
		catch (CharacterCodingException ex) {
			throw new RuntimeException(ex);
		}
	}

	/**
	 * Generates a new RSA key pair with a 1024-bit modulus.
	 * <p>
	 * The size is kept for backward compatibility; 1024-bit RSA is generally considered
	 * too weak for new keys (NIST SP 800-131A calls for at least 2048 bits). Prefer
	 * {@link #generateKeyPair(int)} where the size can be chosen.
	 * @return a freshly generated key pair
	 */
	static KeyPair generateKeyPair() {
		return generateKeyPair(DEFAULT_KEY_SIZE);
	}

	/**
	 * Generates a new RSA key pair.
	 * @param keySize the modulus size in bits
	 * @return a freshly generated key pair
	 * @throws java.security.InvalidParameterException if the provider does not support the
	 * requested size
	 */
	static KeyPair generateKeyPair(int keySize) {
		try {
			KeyPairGenerator keyGen = KeyPairGenerator.getInstance("RSA");
			keyGen.initialize(keySize);
			return keyGen.generateKeyPair();
		}
		catch (NoSuchAlgorithmException ex) {
			throw new IllegalStateException(ex);
		}
	}

	/**
	 * Tries to read the input as an SSH RSA public key.
	 * @param key an {@code ssh-rsa <base64> [comment]} line, a bare Base64 SSH key blob, or
	 * PEM text
	 * @return the public key, or {@code null} if the input looks like PEM (contains
	 * {@code -----BEGIN}) and should be handled by the PEM parser instead
	 * @throws IllegalArgumentException if the input is an SSH key for an algorithm other
	 * than RSA, or is not a valid SSH RSA key
	 */
	private static @Nullable RSAPublicKey extractPublicKey(String key) {
		String trimmed = key.trim();
		Matcher m = SSH_PUB_KEY.matcher(trimmed);
		if (m.matches()) {
			String alg = m.group(1);
			String encKey = m.group(2);
			// group(3), when present, is the trailing comment, which is not used
			if (!"rsa".equalsIgnoreCase(alg)) {
				throw new IllegalArgumentException("Only RSA is currently supported, but algorithm was " + alg);
			}
			return parseSSHPublicKey(encKey);
		}
		if (!trimmed.contains(BEGIN)) {
			// Assume it's the plain Base64 encoded ssh key without the
			// "ssh-rsa" at the start (possibly wrapped over several lines)
			return parseSSHPublicKey(WHITESPACE.matcher(trimmed).replaceAll(""));
		}
		return null;
	}

	/**
	 * Parses an RSA public key from SSH or PEM text.
	 * @param key the encoded key
	 * @return the public key
	 * @throws IllegalArgumentException if no RSA public key can be read from the input
	 */
	static RSAPublicKey parsePublicKey(String key) {
		RSAPublicKey publicKey = extractPublicKey(key);
		if (publicKey != null) {
			return publicKey;
		}
		KeyPair kp = parseKeyPair(key);
		if (kp.getPublic() == null) {
			throw new IllegalArgumentException("Key data does not contain a public key");
		}
		return (RSAPublicKey) kp.getPublic();
	}

	/**
	 * Encodes a public key as an OpenSSH {@code ssh-rsa <base64> <id>} line.
	 * @param key the key to encode
	 * @param id the text appended after the Base64 blob
	 * @return the SSH public key line
	 */
	static String encodePublicKey(RSAPublicKey key, String id) {
		StringWriter output = new StringWriter();
		output.append("ssh-rsa ");
		ByteArrayOutputStream stream = new ByteArrayOutputStream();
		try {
			stream.write(PREFIX);
			writeBigInteger(stream, key.getPublicExponent());
			writeBigInteger(stream, key.getModulus());
		}
		catch (IOException ex) {
			throw new IllegalStateException("Cannot encode key", ex);
		}
		output.append(base64Encode(stream.toByteArray()));
		output.append(" " + id);
		return output.toString();
	}

	/**
	 * Decodes a Base64 SSH RSA key blob: the {@code ssh-rsa} prefix followed by the
	 * length-prefixed public exponent and modulus.
	 * @param encKey the Base64 blob without whitespace
	 * @return the public key
	 * @throws IllegalArgumentException if the blob is not a well-formed SSH RSA key
	 */
	private static RSAPublicKey parseSSHPublicKey(String encKey) {
		ByteArrayInputStream in = new ByteArrayInputStream(base64Decode(encKey));
		byte[] prefix = new byte[PREFIX.length];
		try {
			if (in.read(prefix) != PREFIX.length || !Arrays.equals(PREFIX, prefix)) {
				throw new IllegalArgumentException("SSH key prefix not found");
			}
			BigInteger e = new BigInteger(readBigInteger(in));
			BigInteger n = new BigInteger(readBigInteger(in));
			return createPublicKey(n, e);
		}
		catch (IOException ex) {
			throw new IllegalArgumentException("Malformed SSH public key", ex);
		}
	}

	/**
	 * Builds an RSA public key from its components.
	 * @param n the modulus
	 * @param e the public exponent
	 * @return the public key
	 * @throws IllegalArgumentException if the components are rejected by the key factory
	 */
	static RSAPublicKey createPublicKey(BigInteger n, BigInteger e) {
		try {
			return (RSAPublicKey) KeyFactory.getInstance("RSA").generatePublic(new RSAPublicKeySpec(n, e));
		}
		catch (InvalidKeySpecException ex) {
			throw new IllegalArgumentException(ex);
		}
		catch (NoSuchAlgorithmException ex) {
			throw new IllegalStateException(ex);
		}
	}

	/**
	 * Writes {@code num} as a 4-byte big-endian length followed by its two's-complement
	 * bytes ({@link BigInteger#toByteArray()}).
	 */
	private static void writeBigInteger(ByteArrayOutputStream stream, BigInteger num) throws IOException {
		byte[] bytes = num.toByteArray();
		int length = bytes.length;
		byte[] data = new byte[4];
		data[0] = (byte) ((length >> 24) & 0xFF);
		data[1] = (byte) ((length >> 16) & 0xFF);
		data[2] = (byte) ((length >> 8) & 0xFF);
		data[3] = (byte) (length & 0xFF);
		stream.write(data);
		stream.write(bytes);
	}

	/**
	 * Reads a field written by {@link #writeBigInteger}: a 4-byte big-endian length
	 * followed by that many bytes.
	 * @throws IOException if the data is truncated or the length is not positive or
	 * exceeds the remaining input
	 */
	private static byte[] readBigInteger(ByteArrayInputStream in) throws IOException {
		byte[] b = new byte[4];
		if (in.read(b) != 4) {
			throw new IOException("Expected length data as 4 bytes");
		}
		int l = ((b[0] & 0xFF) << 24) | ((b[1] & 0xFF) << 16) | ((b[2] & 0xFF) << 8) | (b[3] & 0xFF);
		// The length is read from the encoded key itself: validate it before allocating so
		// corrupt data cannot cause a NegativeArraySizeException or a huge allocation.
		if (l <= 0 || l > in.available()) {
			throw new IOException("Invalid key field length " + l);
		}
		b = new byte[l];
		if (in.read(b) != l) {
			throw new IOException("Expected " + l + " key bytes");
		}
		return b;
	}

}