1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
4
5from __future__ import annotations
6
7import abc
8import random
9import typing
10from math import gcd
11
12from cryptography.hazmat.bindings._rust import openssl as rust_openssl
13from cryptography.hazmat.primitives import _serialization, hashes
14from cryptography.hazmat.primitives._asymmetric import AsymmetricPadding
15from cryptography.hazmat.primitives.asymmetric import utils as asym_utils
16
17
class RSAPrivateKey(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    def decrypt(self, ciphertext: bytes, padding: AsymmetricPadding) -> bytes:
        """
        Decrypt ``ciphertext`` using the supplied ``padding`` scheme and
        return the recovered plaintext bytes.
        """

    @property
    @abc.abstractmethod
    def key_size(self) -> int:
        """
        Size of the public modulus, measured in bits.
        """

    @abc.abstractmethod
    def public_key(self) -> RSAPublicKey:
        """
        Return the RSAPublicKey that pairs with this private key.
        """

    @abc.abstractmethod
    def sign(
        self,
        data: bytes,
        padding: AsymmetricPadding,
        algorithm: asym_utils.Prehashed
        | hashes.HashAlgorithm
        | asym_utils.NoDigestInfo,
    ) -> bytes:
        """
        Produce a signature over ``data`` using the given ``padding`` and
        ``algorithm``, returned as bytes.
        """

    @abc.abstractmethod
    def private_numbers(self) -> RSAPrivateNumbers:
        """
        Return this key's components as an RSAPrivateNumbers object.
        """

    @abc.abstractmethod
    def private_bytes(
        self,
        encoding: _serialization.Encoding,
        format: _serialization.PrivateFormat,
        encryption_algorithm: _serialization.KeySerializationEncryption,
    ) -> bytes:
        """
        Serialize this key with the requested ``encoding``, ``format`` and
        ``encryption_algorithm`` and return the resulting bytes.
        """

    @abc.abstractmethod
    def __copy__(self) -> RSAPrivateKey:
        """
        Return a copy of this key (``copy.copy`` protocol).
        """

    @abc.abstractmethod
    def __deepcopy__(self, memo: dict) -> RSAPrivateKey:
        """
        Return a deep copy of this key (``copy.deepcopy`` protocol).
        """
79
80
# Alternate name bound to the same class object as RSAPrivateKey.
# NOTE(review): presumably retained for backwards compatibility with code
# that imports the older name; confirm before removing.
RSAPrivateKeyWithSerialization = RSAPrivateKey
# Register the Rust-backed private key type as a virtual subclass, so that
# isinstance()/issubclass() checks against RSAPrivateKey accept it without
# it inheriting from this ABC.
RSAPrivateKey.register(rust_openssl.rsa.RSAPrivateKey)
83
84
class RSAPublicKey(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    def encrypt(self, plaintext: bytes, padding: AsymmetricPadding) -> bytes:
        """
        Encrypt ``plaintext`` using the supplied ``padding`` scheme and
        return the ciphertext bytes.
        """

    @property
    @abc.abstractmethod
    def key_size(self) -> int:
        """
        Size of the public modulus, measured in bits.
        """

    @abc.abstractmethod
    def public_numbers(self) -> RSAPublicNumbers:
        """
        Return this key's components as an RSAPublicNumbers object.
        """

    @abc.abstractmethod
    def public_bytes(
        self,
        encoding: _serialization.Encoding,
        format: _serialization.PublicFormat,
    ) -> bytes:
        """
        Serialize this key with the requested ``encoding`` and ``format``
        and return the resulting bytes.
        """

    @abc.abstractmethod
    def verify(
        self,
        signature: bytes,
        data: bytes,
        padding: AsymmetricPadding,
        algorithm: asym_utils.Prehashed | hashes.HashAlgorithm,
    ) -> None:
        """
        Check that ``signature`` is valid for ``data`` under the given
        ``padding`` and ``algorithm``. Returns nothing on success.
        """

    @abc.abstractmethod
    def recover_data_from_signature(
        self,
        signature: bytes,
        padding: AsymmetricPadding,
        algorithm: hashes.HashAlgorithm | asym_utils.NoDigestInfo | None,
    ) -> bytes:
        """
        Recover and return the data that was signed to produce
        ``signature``.
        """

    @abc.abstractmethod
    def __eq__(self, other: object) -> bool:
        """
        Compare this key with ``other`` for equality.
        """

    @abc.abstractmethod
    def __copy__(self) -> RSAPublicKey:
        """
        Return a copy of this key (``copy.copy`` protocol).
        """

    @abc.abstractmethod
    def __deepcopy__(self, memo: dict) -> RSAPublicKey:
        """
        Return a deep copy of this key (``copy.deepcopy`` protocol).
        """
155
156
# Alternate name bound to the same class object as RSAPublicKey.
# NOTE(review): presumably retained for backwards compatibility with code
# that imports the older name; confirm before removing.
RSAPublicKeyWithSerialization = RSAPublicKey
# Register the Rust-backed public key type as a virtual subclass, so that
# isinstance()/issubclass() checks against RSAPublicKey accept it.
RSAPublicKey.register(rust_openssl.rsa.RSAPublicKey)

# The private/public "numbers" containers come straight from the Rust
# binding; they are re-exported here under their public names.
RSAPrivateNumbers = rust_openssl.rsa.RSAPrivateNumbers
RSAPublicNumbers = rust_openssl.rsa.RSAPublicNumbers
162
163
def generate_private_key(
    public_exponent: int,
    key_size: int,
    backend: typing.Any = None,
) -> RSAPrivateKey:
    """
    Generate a fresh RSA private key.

    ``public_exponent`` and ``key_size`` are first checked by
    _verify_rsa_parameters (which raises ValueError when they are
    rejected); generation itself is delegated to the Rust OpenSSL binding.
    ``backend`` is accepted but not used.
    """
    _verify_rsa_parameters(public_exponent, key_size)
    new_key = rust_openssl.rsa.generate_private_key(public_exponent, key_size)
    return new_key
171
172
def _verify_rsa_parameters(public_exponent: int, key_size: int) -> None:
    """
    Reject unsupported RSA key-generation parameters with ValueError.

    Only the public exponents 3 and 65537 are accepted, and the key must be
    at least 1024 bits. The exponent is validated before the size.
    """
    problem = None
    if public_exponent not in (3, 65537):
        problem = (
            "public_exponent must be either 3 (for legacy compatibility) or "
            "65537. Almost everyone should choose 65537 here!"
        )
    elif key_size < 1024:
        problem = "key_size must be at least 1024-bits."

    if problem is not None:
        raise ValueError(problem)
182
183
def _modinv(e: int, m: int) -> int:
    """
    Modular Multiplicative Inverse. Returns x such that: (x*e) mod m == 1

    Implemented with the iterative extended Euclidean algorithm; the result
    is reduced into the range [0, m).

    Raises ValueError when e has no inverse modulo m (gcd(e, m) != 1).
    Previously a value that is not an inverse was silently returned in that
    case.
    """
    x1, x2 = 1, 0
    a, b = e, m
    while b > 0:
        q, r = divmod(a, b)
        xn = x1 - q * x2
        a, b, x1, x2 = b, r, x2, xn
    # For m > 0 the loop leaves a == gcd(e, m); x1 is only an inverse of e
    # when that gcd is 1.
    if a != 1:
        raise ValueError("e has no inverse modulo m")
    return x1 % m
195
196
def rsa_crt_iqmp(p: int, q: int) -> int:
    """
    Return the CRT coefficient iqmp = (q ** -1) % p for RSA primes p and q.

    Raises ValueError if p or q is <= 1.
    """
    for value in (p, q):
        if value <= 1:
            raise ValueError("Values can't be <= 1")
    return _modinv(q, p)
204
205
def rsa_crt_dmp1(private_exponent: int, p: int) -> int:
    """
    Return the CRT exponent dmp1 = private_exponent % (p - 1), derived from
    the RSA private exponent (d) and the prime p.

    Raises ValueError if private_exponent or p is <= 1.
    """
    if any(value <= 1 for value in (private_exponent, p)):
        raise ValueError("Values can't be <= 1")
    p_minus_one = p - 1
    return private_exponent % p_minus_one
214
215
def rsa_crt_dmq1(private_exponent: int, q: int) -> int:
    """
    Return the CRT exponent dmq1 = private_exponent % (q - 1), derived from
    the RSA private exponent (d) and the prime q.

    Raises ValueError if private_exponent or q is <= 1.
    """
    for value in (private_exponent, q):
        if value <= 1:
            raise ValueError("Values can't be <= 1")
    q_minus_one = q - 1
    return private_exponent % q_minus_one
224
225
def rsa_recover_private_exponent(e: int, p: int, q: int) -> int:
    """
    Compute the RSA private exponent (d) from the public exponent (e) and
    the RSA primes p and q.

    The exponent is taken modulo the Carmichael totient
    lambda(n) = lcm(p - 1, q - 1), which yields the smallest working value.

    Raises ValueError if e, p or q is <= 1.
    """
    if any(value <= 1 for value in (e, p, q)):
        raise ValueError("Values can't be <= 1")

    # The original RSA paper uses the Euler totient
    # phi_n = (p - 1) * (q - 1) here instead. Either choice gives a working
    # private exponent, but the phi-based one may be larger than necessary,
    # because lambda_n always divides phi_n.
    #
    # TODO: Replace with lcm(p - 1, q - 1) once the minimum
    # supported Python version is >= 3.9.
    p_minus_1 = p - 1
    q_minus_1 = q - 1
    lambda_n = (p_minus_1 * q_minus_1) // gcd(p_minus_1, q_minus_1)
    return _modinv(e, lambda_n)
247
248
# Upper bound on the number of random witnesses rsa_recover_prime_factors
# will try before giving up on factoring n.
_MAX_RECOVERY_ATTEMPTS = 500


def rsa_recover_prime_factors(n: int, e: int, d: int) -> tuple[int, int]:
    """
    Compute factors p and q from the private exponent d. We assume that n has
    no more than two factors. This function is adapted from code in PyCrypto.

    Returns ``(p, q)`` ordered so that ``p >= q``.

    Raises ValueError if d or e is <= 1, or if n, e, d are inconsistent. It
    also raises ValueError when no factor is found within
    _MAX_RECOVERY_ATTEMPTS random witnesses.
    """
    # reject invalid values early
    if d <= 1 or e <= 1:
        raise ValueError("d, e can't be <= 1")
    if 17 != pow(17, e * d, n):
        raise ValueError("n, d, e don't match")
    # See 8.2.2(i) in Handbook of Applied Cryptography.
    ktot = d * e - 1
    # The quantity d*e-1 is a multiple of phi(n), even,
    # and can be represented as t*2^s with t odd.
    t = ktot
    while t % 2 == 0:
        t //= 2
    # Try random witnesses a in [2, n - 1]. For each one, walk the sequence
    # a^t, a^(2t), a^(4t), ... (mod n), looking for a non-trivial square
    # root of 1. The algorithm is non-deterministic, but there is a 50%
    # chance any candidate a leads to successful factoring.
    # See "Digitalized Signatures and Public Key Functions as Intractable
    # as Factorization", M. Rabin, 1979
    for _ in range(_MAX_RECOVERY_ATTEMPTS):
        a = random.randint(2, n - 1)
        k = t
        # Exponentiate once; each later term a^(2k) is the square of a^k,
        # so it is obtained by one modular squaring instead of a full pow().
        cand = pow(a, k, n)
        while k < ktot:
            if cand == 1 or cand == n - 1:
                # Trivial square root of 1: every later term is 1, so this
                # witness cannot succeed. Move on to the next one.
                break
            squared = pow(cand, 2, n)
            if squared == 1:
                # We have found a number such that (cand-1)(cand+1)=0
                # (mod n) with neither factor 0 mod n, so gcd(cand+1, n) is
                # a proper divisor of n.
                p = gcd(cand + 1, n)
                q, r = divmod(n, p)
                # gcd(cand + 1, n) always divides n; sanity check only.
                assert r == 0
                p, q = sorted((p, q), reverse=True)
                return (p, q)
            cand = squared
            k *= 2
    raise ValueError("Unable to compute factors p and q from exponent d.")