Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/jose/backends/rsa_backend.py: 21%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

160 statements  

1import binascii 

2import warnings 

3 

4import rsa as pyrsa 

5import rsa.pem as pyrsa_pem 

6from pyasn1.error import PyAsn1Error 

7from rsa import DecryptionError 

8 

9from jose.backends._asn1 import ( 

10 rsa_private_key_pkcs1_to_pkcs8, 

11 rsa_private_key_pkcs8_to_pkcs1, 

12 rsa_public_key_pkcs1_to_pkcs8, 

13) 

14from jose.backends.base import Key 

15from jose.constants import ALGORITHMS 

16from jose.exceptions import JWEError, JWKError 

17from jose.utils import base64_to_long, long_to_base64 

18 

# RSA-OAEP is not supported by the pure-Python rsa backend, so stop advertising it.
# The membership guard keeps this module-level side effect idempotent (e.g. when the
# module is reloaded) instead of failing at import time if the entry is already gone.
if ALGORITHMS.RSA_OAEP in ALGORITHMS.SUPPORTED:
    ALGORITHMS.SUPPORTED.remove(ALGORITHMS.RSA_OAEP)

20 

# DER prefix written by the old (incorrect) PKCS1-to-PKCS8 conversion. Its outer
# SEQUENCE length is a hard-coded constant rather than the real content length,
# so this prefix is only used to recognise and strip legacy-encoded keys.
LEGACY_INVALID_PKCS8_RSA_HEADER = bytes.fromhex(
    "30"  # SEQUENCE tag
    "82 04 BD"  # length in 2 bytes: 0x04BD = 1213 bytes -- INCORRECT STATIC LENGTH
    "02 01 00"  # INTEGER 0 -- Version
    "30"  # SEQUENCE tag -- PrivateKeyAlgorithmIdentifier
    "0D"  # length: 13 bytes
    "06 09 2A 86 48 86 F7 0D 01 01 01"  # OID 1.2.840.113549.1.1.1 -- rsaEncryption
    "05 00"  # NULL -- parameters
)

# DER tag byte that starts an ASN.1 SEQUENCE.
ASN1_SEQUENCE_ID = bytes.fromhex("30")

# Dotted form of the rsaEncryption algorithm identifier.
RSA_ENCRYPTION_ASN1_OID = "1.2.840.113549.1.1.1"

32 

33# Functions gcd and rsa_recover_prime_factors were copied from cryptography 1.9 

34# to enable pure python rsa module to be in compliance with section 6.3.1 of RFC7518 

35# which requires only private exponent (d) for private key. 

36 

37 

38def _gcd(a, b): 

39 """Calculate the Greatest Common Divisor of a and b. 

40 

41 Unless b==0, the result will have the same sign as b (so that when 

42 b is divided by it, the result comes out positive). 

43 """ 

44 while b: 

45 a, b = b, (a % b) 

46 return a 

47 

48 

49# Controls the number of iterations rsa_recover_prime_factors will perform 

50# to obtain the prime factors. Each iteration increments by 2 so the actual 

51# maximum attempts is half this number. 

52_MAX_RECOVERY_ATTEMPTS = 1000 

53 

54 

55def _rsa_recover_prime_factors(n, e, d): 

56 """ 

57 Compute factors p and q from the private exponent d. We assume that n has 

58 no more than two factors. This function is adapted from code in PyCrypto. 

59 """ 

60 # See 8.2.2(i) in Handbook of Applied Cryptography. 

61 ktot = d * e - 1 

62 # The quantity d*e-1 is a multiple of phi(n), even, 

63 # and can be represented as t*2^s. 

64 t = ktot 

65 while t % 2 == 0: 

66 t = t // 2 

67 # Cycle through all multiplicative inverses in Zn. 

68 # The algorithm is non-deterministic, but there is a 50% chance 

69 # any candidate a leads to successful factoring. 

70 # See "Digitalized Signatures and Public Key Functions as Intractable 

71 # as Factorization", M. Rabin, 1979 

72 spotted = False 

73 a = 2 

74 while not spotted and a < _MAX_RECOVERY_ATTEMPTS: 

75 k = t 

76 # Cycle through all values a^{t*2^i}=a^k 

77 while k < ktot: 

78 cand = pow(a, k, n) 

79 # Check if a^k is a non-trivial root of unity (mod n) 

80 if cand != 1 and cand != (n - 1) and pow(cand, 2, n) == 1: 

81 # We have found a number such that (cand-1)(cand+1)=0 (mod n). 

82 # Either of the terms divides n. 

83 p = _gcd(cand + 1, n) 

84 spotted = True 

85 break 

86 k *= 2 

87 # This value was not any good... let's try another! 

88 a += 2 

89 if not spotted: 

90 raise ValueError("Unable to compute factors p and q from exponent d.") 

91 # Found ! 

92 q, r = divmod(n, p) 

93 assert r == 0 

94 p, q = sorted((p, q), reverse=True) 

95 return (p, q) 

96 

97 

def pem_to_spki(pem, fmt="PKCS8"):
    """Load *pem* as an RS256 ``RSAKey`` and return its ``to_pem(fmt)`` output."""
    return RSAKey(pem, ALGORITHMS.RS256).to_pem(fmt)

101 

102 

def _legacy_private_key_pkcs8_to_pkcs1(pkcs8_key):
    """Strip the legacy (malformed) PKCS8 wrapper from an RSA private key.

    .. warning::

        This is incorrect parsing and only works because the legacy PKCS1-to-PKCS8
        encoding was also incorrect.

    :raises ValueError: if *pkcs8_key* does not start with the legacy header
        immediately followed by an ASN.1 SEQUENCE tag.
    """
    # Requiring the SEQUENCE tag right after the fixed header restricts this
    # path to keys that really look like the old encoding.
    legacy_prefix = LEGACY_INVALID_PKCS8_RSA_HEADER + ASN1_SEQUENCE_ID
    if pkcs8_key.startswith(legacy_prefix):
        # Everything after the fixed header is the embedded PKCS1 key.
        return pkcs8_key[len(LEGACY_INVALID_PKCS8_RSA_HEADER) :]
    raise ValueError("Invalid private key encoding")

118 

119 

class RSAKey(Key):
    """RSA key implementation backed by the pure-Python ``rsa`` package.

    ``key`` may be a JWK ``dict``, an ``rsa.PublicKey`` / ``rsa.PrivateKey``,
    or PEM text (``str`` or ``bytes``). Signing uses the hash selected by the
    RS256/RS384/RS512 algorithm; among key-wrapping algorithms only RSA1_5 is
    accepted.
    """

    SHA256 = "SHA-256"
    SHA384 = "SHA-384"
    SHA512 = "SHA-512"

    def __init__(self, key, algorithm):
        """Validate *algorithm* and parse *key* into ``self._prepared_key``.

        Raises:
            JWKError: if *algorithm* is not an RSA algorithm, is a key-wrapping
                algorithm other than RSA1_5, or *key* cannot be parsed.
        """
        if algorithm not in ALGORITHMS.RSA:
            # 1-tuple so %-formatting cannot misfire when the value is itself a tuple.
            raise JWKError("hash_alg: %s is not a valid hash algorithm" % (algorithm,))

        if algorithm in ALGORITHMS.RSA_KW and algorithm != ALGORITHMS.RSA1_5:
            raise JWKError("alg: %s is not supported by the RSA backend" % (algorithm,))

        # Hash name handed to rsa.sign(); None for algorithms not listed here (e.g. RSA1_5).
        self.hash_alg = {
            ALGORITHMS.RS256: self.SHA256,
            ALGORITHMS.RS384: self.SHA384,
            ALGORITHMS.RS512: self.SHA512,
        }.get(algorithm)
        self._algorithm = algorithm

        if isinstance(key, dict):
            self._prepared_key = self._process_jwk(key)
            return

        if isinstance(key, (pyrsa.PublicKey, pyrsa.PrivateKey)):
            self._prepared_key = key
            return

        if isinstance(key, str):
            key = key.encode("utf-8")

        if isinstance(key, bytes):
            self._prepared_key = self._load_pem_bytes(key)
            return

        raise JWKError("Unable to parse an RSA_JWK from key: %s" % (key,))

    @staticmethod
    def _load_pem_bytes(key):
        """Parse PEM *key* bytes into an ``rsa`` public or private key.

        Encodings are tried in order: "RSA PUBLIC KEY" (PKCS#1), OpenSSL
        "PUBLIC KEY", "RSA PRIVATE KEY" (PKCS#1), and finally PKCS#8
        "PRIVATE KEY", including the legacy mis-encoded PKCS#8 form.

        Raises:
            JWKError: if none of the encodings can be parsed.
        """
        for loader in (
            pyrsa.PublicKey.load_pkcs1,
            pyrsa.PublicKey.load_pkcs1_openssl_pem,
            pyrsa.PrivateKey.load_pkcs1,
        ):
            try:
                return loader(key)
            except (ValueError, PyAsn1Error):
                # Malformed DER under a matching PEM marker may surface as a
                # pyasn1 error rather than ValueError; treat both as
                # "not this encoding" and try the next one.
                continue

        try:
            der = pyrsa_pem.load_pem(key, b"PRIVATE KEY")
            try:
                pkcs1_key = rsa_private_key_pkcs8_to_pkcs1(der)
            except PyAsn1Error:
                # If the key was encoded using the old, invalid, encoding then
                # pyasn1 will throw an error attempting to parse the key.
                pkcs1_key = _legacy_private_key_pkcs8_to_pkcs1(der)
            return pyrsa.PrivateKey.load_pkcs1(pkcs1_key, format="DER")
        except (ValueError, PyAsn1Error) as e:
            raise JWKError(e) from e

    def _process_jwk(self, jwk_dict):
        """Build an ``rsa`` key from a JWK dict (RFC 7518, section 6.3).

        Returns a ``PublicKey`` when "d" is absent, otherwise a ``PrivateKey``.
        If none of the precomputed members (p, q, dp, dq, qi) are present,
        p and q are recovered from n, e and d.

        Raises:
            JWKError: if "kty" is not "RSA" or the precomputed members are incomplete.
        """
        if jwk_dict.get("kty") != "RSA":
            raise JWKError("Incorrect key type. Expected: 'RSA', Received: %s" % (jwk_dict.get("kty"),))

        e = base64_to_long(jwk_dict.get("e"))
        n = base64_to_long(jwk_dict.get("n"))

        if "d" not in jwk_dict:
            return pyrsa.PublicKey(e=e, n=n)

        d = base64_to_long(jwk_dict.get("d"))
        extra_params = ["p", "q", "dp", "dq", "qi"]

        if not any(k in jwk_dict for k in extra_params):
            # Only "d" was supplied (allowed by section 6.3.1); derive p and q.
            p, q = _rsa_recover_prime_factors(n, e, d)
            return pyrsa.PrivateKey(n=n, e=e, d=d, p=p, q=q)

        # Section 6.3.2: once any precomputed member is present, all must be.
        if not all(k in jwk_dict for k in extra_params):
            raise JWKError("Precomputed private key parameters are incomplete.")

        p = base64_to_long(jwk_dict["p"])
        q = base64_to_long(jwk_dict["q"])
        # dp/dq/qi are required to be present but only p and q are passed on.
        return pyrsa.PrivateKey(e=e, n=n, d=d, p=p, q=q)

    def sign(self, msg):
        """Sign *msg* with ``rsa.sign`` using the hash named by ``self.hash_alg``."""
        return pyrsa.sign(msg, self._prepared_key, self.hash_alg)

    def verify(self, msg, sig):
        """Return True if *sig* is a valid signature of *msg*, otherwise False.

        Warns when called on a private key. A signature made with a different
        hash than ``self.hash_alg`` is rejected.
        """
        if not self.is_public():
            warnings.warn("Attempting to verify a message with a private key. " "This is not recommended.")
        try:
            used_hash = pyrsa.verify(msg, sig, self._prepared_key)
        except pyrsa.pkcs1.VerificationError:
            return False
        # rsa >= 4.0 returns the name of the hash found in the signature; older
        # releases presumably returned True, hence the isinstance guard. Without
        # this check an RS256 key would also accept e.g. a SHA-512 signature.
        if self.hash_alg is not None and isinstance(used_hash, str):
            return used_hash == self.hash_alg
        return True

    def is_public(self):
        """Return True when the wrapped key is an ``rsa.PublicKey``."""
        return isinstance(self._prepared_key, pyrsa.PublicKey)

    def public_key(self):
        """Return this key if public, else a new RSAKey holding only (n, e)."""
        if isinstance(self._prepared_key, pyrsa.PublicKey):
            return self
        return self.__class__(pyrsa.PublicKey(n=self._prepared_key.n, e=self._prepared_key.e), self._algorithm)

    def to_pem(self, pem_format="PKCS8"):
        """Serialize the key as PEM bytes.

        Args:
            pem_format: "PKCS8" (default) or "PKCS1".

        Raises:
            ValueError: for any other *pem_format*.
        """
        if pem_format not in ("PKCS8", "PKCS1"):
            raise ValueError(f"Invalid pem format specified: {pem_format!r}")

        pkcs1_der = self._prepared_key.save_pkcs1(format="DER")
        if isinstance(self._prepared_key, pyrsa.PrivateKey):
            if pem_format == "PKCS8":
                return pyrsa_pem.save_pem(rsa_private_key_pkcs1_to_pkcs8(pkcs1_der), pem_marker="PRIVATE KEY")
            return pyrsa_pem.save_pem(pkcs1_der, pem_marker="RSA PRIVATE KEY")

        if pem_format == "PKCS8":
            return pyrsa_pem.save_pem(rsa_public_key_pkcs1_to_pkcs8(pkcs1_der), pem_marker="PUBLIC KEY")
        return pyrsa_pem.save_pem(pkcs1_der, pem_marker="RSA PUBLIC KEY")

    def to_dict(self):
        """Return the key as a JWK dict; private keys also include d, p, q, dp, dq and qi."""
        key = self._prepared_key
        # Both rsa.PublicKey and rsa.PrivateKey expose n and e.
        data = {
            "alg": self._algorithm,
            "kty": "RSA",
            "n": long_to_base64(key.n).decode("ASCII"),
            "e": long_to_base64(key.e).decode("ASCII"),
        }

        if not self.is_public():
            data.update(
                {
                    "d": long_to_base64(key.d).decode("ASCII"),
                    "p": long_to_base64(key.p).decode("ASCII"),
                    "q": long_to_base64(key.q).decode("ASCII"),
                    "dp": long_to_base64(key.exp1).decode("ASCII"),
                    "dq": long_to_base64(key.exp2).decode("ASCII"),
                    "qi": long_to_base64(key.coef).decode("ASCII"),
                }
            )

        return data

    def wrap_key(self, key_data):
        """Encrypt *key_data* with ``rsa.encrypt`` (warns when called on a private key)."""
        if not self.is_public():
            warnings.warn("Attempting to encrypt a message with a private key." " This is not recommended.")
        return pyrsa.encrypt(key_data, self._prepared_key)

    def unwrap_key(self, wrapped_key):
        """Decrypt *wrapped_key* with ``rsa.decrypt``.

        Raises:
            JWEError: if decryption fails.
        """
        try:
            return pyrsa.decrypt(wrapped_key, self._prepared_key)
        except DecryptionError as e:
            raise JWEError(e) from e