import binascii
import math
import warnings

import rsa as pyrsa
import rsa.pem as pyrsa_pem
from pyasn1.error import PyAsn1Error
from rsa import DecryptionError

from jose.backends._asn1 import (
    rsa_private_key_pkcs1_to_pkcs8,
    rsa_private_key_pkcs8_to_pkcs1,
    rsa_public_key_pkcs1_to_pkcs8,
)
from jose.backends.base import Key
from jose.constants import ALGORITHMS
from jose.exceptions import JWEError, JWKError
from jose.utils import base64_to_long, long_to_base64
18
# This backend does not implement RSA-OAEP, so it is dropped from the shared
# ALGORITHMS.SUPPORTED collection as a side effect of importing this module.
# NOTE(review): this mutates state owned by jose.constants for every user of it,
# and `.remove` presumably raises if the entry is already absent (e.g. on a module
# reload) -- confirm whether that can happen.
ALGORITHMS.SUPPORTED.remove(ALGORITHMS.RSA_OAEP)  # RSA OAEP not supported
20
# Fixed DER prefix produced by the legacy (incorrect) PKCS#1-to-PKCS#8 wrapping of
# RSA private keys. The outer SEQUENCE length is a hard-coded constant rather than
# the real length, so this is kept only to recognise keys written by that encoder.
LEGACY_INVALID_PKCS8_RSA_HEADER = bytes.fromhex(
    "30"  # SEQUENCE tag
    "8204BD"  # long-form length 0x04BD = 1213 content bytes -- INCORRECT STATIC LENGTH
    "020100"  # INTEGER 0 -- Version
    "30"  # SEQUENCE tag -- PrivateKeyAlgorithmIdentifier
    "0D"  # length: 13 content bytes
    "06092A864886F70D010101"  # OBJECT IDENTIFIER -- rsaEncryption
    "0500"  # NULL -- parameters
)
# Single DER tag byte (0x30) that introduces an ASN.1 SEQUENCE.
ASN1_SEQUENCE_ID = b"\x30"
# Dotted-decimal rsaEncryption OID (the same OID encoded in the header above).
RSA_ENCRYPTION_ASN1_OID = "1.2.840.113549.1.1.1"
32
# The functions _gcd and _rsa_recover_prime_factors were copied from cryptography 1.9
# so that this pure-Python rsa backend complies with Section 6.3.2 of RFC 7518, under
# which the private exponent (d) is the only required private-key parameter; p and q
# therefore have to be recoverable from n, e, and d alone.
36
37
def _gcd(a, b):
    """Return the greatest common divisor of *a* and *b* (Euclid's algorithm).

    Unless ``b == 0``, the result carries the sign of *b*, so dividing *b* by
    the result always yields a positive quotient.
    """
    x, y = a, b
    while y != 0:
        # Python's % takes the sign of the divisor, which is what propagates b's sign.
        x, y = y, x % y
    return x
47
48
# Exclusive upper bound on the witness value `a` tried by _rsa_recover_prime_factors.
# Witnesses start at 2 and advance by 2, so at most 499 candidates (2, 4, ..., 998)
# are tried before giving up -- roughly half of this number.
_MAX_RECOVERY_ATTEMPTS = 1_000
53
54
def _rsa_recover_prime_factors(n, e, d, max_attempts=None):
    """
    Compute factors p and q from the private exponent d. We assume that n has
    no more than two factors. This function is adapted from code in PyCrypto.

    :param n: RSA modulus.
    :param e: public exponent.
    :param d: private exponent.
    :param max_attempts: exclusive upper bound on the witness ``a`` (which starts
        at 2 and steps by 2). Defaults to the module's ``_MAX_RECOVERY_ATTEMPTS``.
    :returns: ``(p, q)`` with ``p >= q``.
    :raises ValueError: if the factors cannot be recovered, including when
        ``d * e - 1`` is not positive (which no valid key can produce).
    """
    if max_attempts is None:
        max_attempts = _MAX_RECOVERY_ATTEMPTS

    # See 8.2.2(i) in Handbook of Applied Cryptography.
    ktot = d * e - 1
    if ktot <= 0:
        # For a real key d*e-1 is a positive multiple of phi(n). ktot == 0 (e.g. a
        # JWK with e = d = 1) would make the halving loop below spin forever, since
        # 0 % 2 == 0; a negative value can never yield factors either.
        raise ValueError("Unable to compute factors p and q from exponent d.")

    # The quantity d*e-1 is a multiple of phi(n), even,
    # and can be represented as t*2^s; strip the powers of two to get t.
    t = ktot
    while t % 2 == 0:
        t //= 2

    # Cycle through candidate witnesses a.
    # The algorithm is non-deterministic, but there is a 50% chance
    # any candidate a leads to successful factoring.
    # See "Digitalized Signatures and Public Key Functions as Intractable
    # as Factorization", M. Rabin, 1979
    for a in range(2, max_attempts, 2):
        k = t
        # Cycle through all values a^{t*2^i}=a^k
        while k < ktot:
            cand = pow(a, k, n)
            # Check if a^k is a non-trivial root of unity (mod n)
            if cand != 1 and cand != n - 1 and pow(cand, 2, n) == 1:
                # (cand-1)(cand+1) = 0 (mod n) with neither factor 0 (mod n),
                # so gcd(cand+1, n) is a non-trivial divisor of n.
                p = math.gcd(cand + 1, n)
                q, r = divmod(n, p)
                if r != 0:
                    # Cannot happen for a gcd of n; explicit (not assert) so it
                    # survives `python -O`.
                    raise ValueError("Unable to compute factors p and q from exponent d.")
                p, q = sorted((p, q), reverse=True)
                return (p, q)
            k *= 2
    raise ValueError("Unable to compute factors p and q from exponent d.")
96
97
def pem_to_spki(pem, fmt="PKCS8"):
    """Parse an RSA key and re-emit it as PEM in format *fmt*.

    The key is loaded with ``RSAKey`` under ``RS256`` (``to_pem`` does not depend
    on the algorithm) and serialised with ``RSAKey.to_pem(fmt)``. Despite the
    name, a private-key input produces a private-key PEM, not a public SPKI.
    """
    parsed = RSAKey(pem, ALGORITHMS.RS256)
    return parsed.to_pem(fmt)
101
102
def _legacy_private_key_pkcs8_to_pkcs1(pkcs8_key):
    """Strip the legacy, mis-encoded PKCS#8 wrapper from an RSA private key.

    .. warning::

        This is incorrect parsing and only works because the legacy PKCS1-to-PKCS8
        encoding was also incorrect.

    :param pkcs8_key: DER bytes written by the legacy PKCS1-to-PKCS8 encoder.
    :returns: the bytes following ``LEGACY_INVALID_PKCS8_RSA_HEADER``.
    :raises ValueError: unless *pkcs8_key* starts with the legacy header
        immediately followed by an ASN.1 SEQUENCE tag.
    """
    header_len = len(LEGACY_INVALID_PKCS8_RSA_HEADER)
    # Require the SEQUENCE tag right after the fixed header, as the legacy
    # encoding always produced; anything else is rejected.
    expected_prefix = LEGACY_INVALID_PKCS8_RSA_HEADER + ASN1_SEQUENCE_ID
    if pkcs8_key.startswith(expected_prefix):
        return pkcs8_key[header_len:]
    raise ValueError("Invalid private key encoding")
118
119
class RSAKey(Key):
    """RSA key implementation backed by the pure-Python ``rsa`` package.

    ``key`` may be a JWK ``dict``, an ``rsa.PublicKey`` / ``rsa.PrivateKey``
    instance, or PEM text (``str`` or ``bytes``) containing a PKCS#1 public key,
    an OpenSSL ``PUBLIC KEY``, a PKCS#1 private key, or a PKCS#8 private key.
    Of the RSA key-wrapping algorithms, only RSA1_5 is accepted.
    """

    SHA256 = "SHA-256"
    SHA384 = "SHA-384"
    SHA512 = "SHA-512"

    def __init__(self, key, algorithm):
        """Validate *algorithm* and load *key*.

        :raises JWKError: if the algorithm is not an RSA algorithm, is a
            key-wrapping algorithm other than RSA1_5, or the key cannot be parsed.
        :raises ValueError: from ``_process_jwk`` when a JWK has ``d`` but no
            CRT parameters and ``p``/``q`` cannot be recovered.
        """
        # ``(algorithm,)``: a tuple value would otherwise be unpacked by ``%`` and
        # raise TypeError instead of the intended JWKError.
        if algorithm not in ALGORITHMS.RSA:
            raise JWKError("hash_alg: %s is not a valid hash algorithm" % (algorithm,))

        if algorithm in ALGORITHMS.RSA_KW and algorithm != ALGORITHMS.RSA1_5:
            raise JWKError("alg: %s is not supported by the RSA backend" % (algorithm,))

        # Hash name as understood by rsa.sign / rsa.verify; algorithms without an
        # entry here (e.g. RSA1_5) get None.
        self.hash_alg = {
            ALGORITHMS.RS256: self.SHA256,
            ALGORITHMS.RS384: self.SHA384,
            ALGORITHMS.RS512: self.SHA512,
        }.get(algorithm)
        self._algorithm = algorithm

        if isinstance(key, dict):
            self._prepared_key = self._process_jwk(key)
            return

        if isinstance(key, (pyrsa.PublicKey, pyrsa.PrivateKey)):
            self._prepared_key = key
            return

        if isinstance(key, str):
            key = key.encode("utf-8")

        if isinstance(key, bytes):
            self._prepared_key = self._load_key_bytes(key)
            return

        # ``(key,)`` for the same reason as above: a tuple key must not be
        # spread into the format arguments.
        raise JWKError("Unable to parse an RSA_JWK from key: %s" % (key,))

    @staticmethod
    def _load_key_bytes(key):
        """Parse PEM bytes into an ``rsa.PublicKey`` or ``rsa.PrivateKey``.

        Formats are tried in order: PKCS#1 public key, OpenSSL ``PUBLIC KEY``,
        PKCS#1 private key, then PKCS#8 ``PRIVATE KEY`` (including the legacy,
        mis-encoded PKCS#8 form).

        :raises JWKError: if no format can parse *key*.
        """
        loaders = (
            pyrsa.PublicKey.load_pkcs1,
            pyrsa.PublicKey.load_pkcs1_openssl_pem,
            pyrsa.PrivateKey.load_pkcs1,
        )
        for loader in loaders:
            try:
                return loader(key)
            except (ValueError, PyAsn1Error):
                # ValueError: PEM marker not found. PyAsn1Error: the marker matched
                # but the DER inside is malformed -- previously this escaped raw
                # instead of falling through to the remaining formats / JWKError.
                continue

        try:
            der = pyrsa_pem.load_pem(key, b"PRIVATE KEY")
            try:
                pkcs1_key = rsa_private_key_pkcs8_to_pkcs1(der)
            except PyAsn1Error:
                # If the key was encoded using the old, invalid,
                # encoding then pyasn1 will throw an error attempting
                # to parse the key.
                pkcs1_key = _legacy_private_key_pkcs8_to_pkcs1(der)
            return pyrsa.PrivateKey.load_pkcs1(pkcs1_key, format="DER")
        except (ValueError, PyAsn1Error) as e:
            raise JWKError(e) from e

    def _process_jwk(self, jwk_dict):
        """Build an ``rsa`` key object from a JWK dictionary.

        Without ``d`` the result is a public key. With ``d``, the precomputed
        parameters ``p``, ``q``, ``dp``, ``dq``, ``qi`` must be all present or all
        absent. When absent, ``p`` and ``q`` are recovered from ``n``, ``e``, ``d``.
        Only ``p`` and ``q`` are read; ``dp``, ``dq``, ``qi`` are checked for
        presence only.

        :raises JWKError: if ``kty`` is not ``"RSA"`` or the precomputed
            parameters are incomplete.
        :raises ValueError: if ``p`` and ``q`` cannot be recovered.
        """
        if jwk_dict.get("kty") != "RSA":
            raise JWKError("Incorrect key type. Expected: 'RSA', Received: %s" % jwk_dict.get("kty"))

        e = base64_to_long(jwk_dict.get("e"))
        n = base64_to_long(jwk_dict.get("n"))

        if "d" not in jwk_dict:
            return pyrsa.PublicKey(e=e, n=n)

        d = base64_to_long(jwk_dict.get("d"))
        extra_params = ["p", "q", "dp", "dq", "qi"]

        if any(k in jwk_dict for k in extra_params):
            # Precomputed private key parameters are available.
            if not all(k in jwk_dict for k in extra_params):
                # These values must be present when 'p' is according to
                # Section 6.3.2 of RFC7518, so if they are not we raise
                # an error.
                raise JWKError("Precomputed private key parameters are incomplete.")

            p = base64_to_long(jwk_dict["p"])
            q = base64_to_long(jwk_dict["q"])
            return pyrsa.PrivateKey(e=e, n=n, d=d, p=p, q=q)

        p, q = _rsa_recover_prime_factors(n, e, d)
        return pyrsa.PrivateKey(n=n, e=e, d=d, p=p, q=q)

    def sign(self, msg):
        """Sign *msg* with ``rsa.sign`` using this key's hash (``self.hash_alg``)."""
        return pyrsa.sign(msg, self._prepared_key, self.hash_alg)

    def verify(self, msg, sig):
        """Return True iff *sig* is a valid signature of *msg* under this key.

        ``rsa.verify`` accepts whatever hash the signature itself declares, so the
        hash it reports is also required to match ``self.hash_alg``; otherwise a
        key built for e.g. RS512 would accept an RS256 signature. When
        ``hash_alg`` is None (key-wrapping algorithms) no hash check is applied.
        """
        if not self.is_public():
            warnings.warn("Attempting to verify a message with a private key. " "This is not recommended.")
        try:
            used_hash = pyrsa.verify(msg, sig, self._prepared_key)
        except pyrsa.pkcs1.VerificationError:
            return False
        if self.hash_alg is not None and used_hash != self.hash_alg:
            return False
        return True

    def is_public(self):
        """Return True if the wrapped key is an ``rsa.PublicKey``."""
        return isinstance(self._prepared_key, pyrsa.PublicKey)

    def public_key(self):
        """Return an RSAKey holding only the public half (``self`` if already public)."""
        if isinstance(self._prepared_key, pyrsa.PublicKey):
            return self
        return self.__class__(pyrsa.PublicKey(n=self._prepared_key.n, e=self._prepared_key.e), self._algorithm)

    def to_pem(self, pem_format="PKCS8"):
        """Serialise the key as PEM.

        :param pem_format: ``"PKCS8"`` (``PRIVATE KEY`` / ``PUBLIC KEY`` markers)
            or ``"PKCS1"`` (``RSA PRIVATE KEY`` / ``RSA PUBLIC KEY`` markers).
        :raises ValueError: for any other *pem_format*.
        """
        pkcs1_der = self._prepared_key.save_pkcs1(format="DER")
        is_private = isinstance(self._prepared_key, pyrsa.PrivateKey)
        if pem_format == "PKCS8":
            if is_private:
                return pyrsa_pem.save_pem(rsa_private_key_pkcs1_to_pkcs8(pkcs1_der), pem_marker="PRIVATE KEY")
            return pyrsa_pem.save_pem(rsa_public_key_pkcs1_to_pkcs8(pkcs1_der), pem_marker="PUBLIC KEY")
        if pem_format == "PKCS1":
            marker = "RSA PRIVATE KEY" if is_private else "RSA PUBLIC KEY"
            return pyrsa_pem.save_pem(pkcs1_der, pem_marker=marker)
        raise ValueError(f"Invalid pem format specified: {pem_format!r}")

    def to_dict(self):
        """Return the key as a JWK dict; private keys also include d, p, q, dp, dq, qi."""
        key = self._prepared_key
        # Both rsa.PublicKey and rsa.PrivateKey expose n and e directly, so no
        # intermediate public-key object is needed.
        data = {
            "alg": self._algorithm,
            "kty": "RSA",
            "n": long_to_base64(key.n).decode("ASCII"),
            "e": long_to_base64(key.e).decode("ASCII"),
        }

        if not self.is_public():
            data.update(
                {
                    "d": long_to_base64(key.d).decode("ASCII"),
                    "p": long_to_base64(key.p).decode("ASCII"),
                    "q": long_to_base64(key.q).decode("ASCII"),
                    "dp": long_to_base64(key.exp1).decode("ASCII"),
                    "dq": long_to_base64(key.exp2).decode("ASCII"),
                    "qi": long_to_base64(key.coef).decode("ASCII"),
                }
            )

        return data

    def wrap_key(self, key_data):
        """Encrypt *key_data* with ``rsa.encrypt`` (warns if this key is private)."""
        if not self.is_public():
            warnings.warn("Attempting to encrypt a message with a private key." " This is not recommended.")
        wrapped_key = pyrsa.encrypt(key_data, self._prepared_key)
        return wrapped_key

    def unwrap_key(self, wrapped_key):
        """Decrypt *wrapped_key* with ``rsa.decrypt``.

        :raises JWEError: if ``rsa.decrypt`` raises ``DecryptionError``.
        """
        try:
            unwrapped_key = pyrsa.decrypt(wrapped_key, self._prepared_key)
        except DecryptionError as e:
            raise JWEError(e) from e
        return unwrapped_key