1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
4
5from __future__ import annotations
6
7import abc
8import random
9import typing
10from math import gcd
11
12from cryptography.hazmat.bindings._rust import openssl as rust_openssl
13from cryptography.hazmat.primitives import _serialization, hashes
14from cryptography.hazmat.primitives._asymmetric import AsymmetricPadding
15from cryptography.hazmat.primitives.asymmetric import utils as asym_utils
16
17
class RSAPrivateKey(metaclass=abc.ABCMeta):
    """
    Abstract interface for an RSA private key.

    Every member declared here is abstract; concrete key implementations
    must provide all of them.
    """

    @property
    @abc.abstractmethod
    def key_size(self) -> int:
        """
        Size of the public modulus, in bits.
        """

    @abc.abstractmethod
    def public_key(self) -> RSAPublicKey:
        """
        Returns the RSAPublicKey that corresponds to this private key.
        """

    @abc.abstractmethod
    def private_numbers(self) -> RSAPrivateNumbers:
        """
        Returns this key as an RSAPrivateNumbers object.
        """

    @abc.abstractmethod
    def private_bytes(
        self,
        encoding: _serialization.Encoding,
        format: _serialization.PrivateFormat,
        encryption_algorithm: _serialization.KeySerializationEncryption,
    ) -> bytes:
        """
        Serializes the key and returns the result as bytes, using the
        requested encoding, format and encryption_algorithm.
        """

    @abc.abstractmethod
    def decrypt(self, ciphertext: bytes, padding: AsymmetricPadding) -> bytes:
        """
        Decrypts ciphertext with the given padding and returns the
        resulting bytes.
        """

    @abc.abstractmethod
    def sign(
        self,
        data: bytes,
        padding: AsymmetricPadding,
        algorithm: asym_utils.Prehashed | hashes.HashAlgorithm,
    ) -> bytes:
        """
        Signs data with the given padding and algorithm (either a
        HashAlgorithm or a Prehashed instance) and returns the result
        as bytes.
        """

    @abc.abstractmethod
    def __copy__(self) -> RSAPrivateKey:
        """
        Returns a copy of this key.
        """
71
72
# Alias: RSAPrivateKeyWithSerialization is the very same class object as
# RSAPrivateKey.
RSAPrivateKeyWithSerialization = RSAPrivateKey
# Register the key type exposed by the Rust bindings as a virtual subclass,
# so isinstance()/issubclass() checks against RSAPrivateKey accept it.
RSAPrivateKey.register(rust_openssl.rsa.RSAPrivateKey)
75
76
class RSAPublicKey(metaclass=abc.ABCMeta):
    """
    Abstract interface for an RSA public key.

    Every member declared here is abstract; concrete key implementations
    must provide all of them.
    """

    @property
    @abc.abstractmethod
    def key_size(self) -> int:
        """
        Size of the public modulus, in bits.
        """

    @abc.abstractmethod
    def public_numbers(self) -> RSAPublicNumbers:
        """
        Returns this key as an RSAPublicNumbers object.
        """

    @abc.abstractmethod
    def public_bytes(
        self,
        encoding: _serialization.Encoding,
        format: _serialization.PublicFormat,
    ) -> bytes:
        """
        Serializes the key and returns the result as bytes, using the
        requested encoding and format.
        """

    @abc.abstractmethod
    def encrypt(self, plaintext: bytes, padding: AsymmetricPadding) -> bytes:
        """
        Encrypts plaintext with the given padding and returns the
        resulting bytes.
        """

    @abc.abstractmethod
    def verify(
        self,
        signature: bytes,
        data: bytes,
        padding: AsymmetricPadding,
        algorithm: asym_utils.Prehashed | hashes.HashAlgorithm,
    ) -> None:
        """
        Verifies signature over data using the given padding and
        algorithm (either a HashAlgorithm or a Prehashed instance).
        Returns None.
        """

    @abc.abstractmethod
    def recover_data_from_signature(
        self,
        signature: bytes,
        padding: AsymmetricPadding,
        algorithm: hashes.HashAlgorithm | None,
    ) -> bytes:
        """
        Recovers the original data from signature and returns it as
        bytes. algorithm may be a HashAlgorithm or None.
        """

    @abc.abstractmethod
    def __eq__(self, other: object) -> bool:
        """
        Compares this key with other for equality.
        """

    @abc.abstractmethod
    def __copy__(self) -> RSAPublicKey:
        """
        Returns a copy of this key.
        """
141
142
# Alias: RSAPublicKeyWithSerialization is the very same class object as
# RSAPublicKey.
RSAPublicKeyWithSerialization = RSAPublicKey
# Register the key type exposed by the Rust bindings as a virtual subclass,
# so isinstance()/issubclass() checks against RSAPublicKey accept it.
RSAPublicKey.register(rust_openssl.rsa.RSAPublicKey)

# The numbers classes are re-exported directly from the Rust bindings under
# this module's namespace.
RSAPrivateNumbers = rust_openssl.rsa.RSAPrivateNumbers
RSAPublicNumbers = rust_openssl.rsa.RSAPublicNumbers
148
149
def generate_private_key(
    public_exponent: int,
    key_size: int,
    backend: typing.Any = None,
) -> RSAPrivateKey:
    """
    Generates a new RSA private key.

    public_exponent and key_size are first checked by
    _verify_rsa_parameters (which raises ValueError on unsupported
    values) and then passed to the Rust OpenSSL binding. backend is
    accepted but not used by this function.
    """
    _verify_rsa_parameters(public_exponent, key_size)
    private_key = rust_openssl.rsa.generate_private_key(
        public_exponent, key_size
    )
    return private_key
157
158
def _verify_rsa_parameters(public_exponent: int, key_size: int) -> None:
    """
    Validates the parameters used for RSA key generation.

    Raises ValueError if public_exponent is neither 3 nor 65537, or if
    key_size is below 1024 bits. Returns None when both are acceptable.
    """
    allowed_exponents = (3, 65537)
    minimum_key_size = 1024

    exponent_is_allowed = public_exponent in allowed_exponents
    if not exponent_is_allowed:
        raise ValueError(
            "public_exponent must be either 3 (for legacy compatibility) or "
            "65537. Almost everyone should choose 65537 here!"
        )

    # Checked second, so an unsupported exponent is always reported first.
    if key_size < minimum_key_size:
        raise ValueError("key_size must be at least 1024-bits.")
168
169
def _modinv(e: int, m: int) -> int:
    """
    Modular Multiplicative Inverse. Returns x such that: (x*e) mod m == 1

    The inverse is computed with the extended Euclidean algorithm and is
    returned reduced into the range [0, m). m is expected to be a
    positive integer.

    Raises ValueError if e and m are not coprime, since no inverse exists
    in that case.
    """
    x1, x2 = 1, 0
    a, b = e, m
    while b > 0:
        q, r = divmod(a, b)
        xn = x1 - q * x2
        a, b, x1, x2 = b, r, x2, xn
    # Loop invariant: x1 * e == a (mod m). Once b reaches 0, a holds
    # gcd(e, m); x1 is an inverse of e only when that gcd is 1. Without
    # this check a non-invertible e silently produced a wrong result.
    if a != 1:
        raise ValueError(
            "e has no multiplicative inverse modulo m (gcd(e, m) != 1)."
        )
    return x1 % m
181
182
def rsa_crt_iqmp(p: int, q: int) -> int:
    """
    Returns the CRT coefficient (q ** -1) % p, i.e. the modular inverse
    of q modulo p, for the RSA primes p and q.
    """
    inverse_of_q_mod_p = _modinv(q, p)
    return inverse_of_q_mod_p
188
189
def rsa_crt_dmp1(private_exponent: int, p: int) -> int:
    """
    Returns the CRT exponent for p: the RSA private_exponent (d) reduced
    modulo (p - 1).
    """
    p_minus_one = p - 1
    return private_exponent % p_minus_one
196
197
def rsa_crt_dmq1(private_exponent: int, q: int) -> int:
    """
    Returns the CRT exponent for q: the RSA private_exponent (d) reduced
    modulo (q - 1).
    """
    q_minus_one = q - 1
    return private_exponent % q_minus_one
204
205
def rsa_recover_private_exponent(e: int, p: int, q: int) -> int:
    """
    Returns the RSA private_exponent (d) for the public exponent (e) and
    the RSA primes p and q.

    d is computed as the inverse of e modulo the Carmichael totient of
    n = p * q, which yields the smallest possible working value of the
    private exponent.
    """
    p_minus_one = p - 1
    q_minus_one = q - 1
    # lambda_n is the Carmichael totient, lcm(p - 1, q - 1). The original
    # RSA paper uses the Euler totient phi_n = (p - 1) * (q - 1) instead.
    # Both give a working private exponent, but lambda_n always divides
    # phi_n, so the exponent derived from phi_n may be larger than
    # necessary.
    #
    # TODO: Replace with lcm(p - 1, q - 1) once the minimum
    # supported Python version is >= 3.9.
    lambda_n = p_minus_one * q_minus_one // gcd(p_minus_one, q_minus_one)
    return _modinv(e, lambda_n)
225
226
# Upper bound on the number of random candidates that
# rsa_recover_prime_factors tries before it gives up.
_MAX_RECOVERY_ATTEMPTS = 500


def rsa_recover_prime_factors(n: int, e: int, d: int) -> tuple[int, int]:
    """
    Recover the factors p and q of n from the exponents e and d. n is
    assumed to have no more than two factors. This function is adapted
    from code in PyCrypto.

    Returns the tuple (p, q) with the larger factor first. Raises
    ValueError if d or e is <= 1, if n, d and e are inconsistent, or if
    no factor is found within _MAX_RECOVERY_ATTEMPTS random attempts.
    """
    # Cheap sanity checks before any real work is done.
    if d <= 1 or e <= 1:
        raise ValueError("d, e can't be <= 1")
    ed = e * d
    if pow(17, ed, n) != 17:
        raise ValueError("n, d, e don't match")

    # See 8.2.2(i) in Handbook of Applied Cryptography.
    # For a valid key e*d - 1 is even; strip the factors of two so that
    # ktot == odd_part * 2^s.
    ktot = ed - 1
    odd_part = ktot
    while odd_part % 2 == 0:
        odd_part //= 2

    # Randomised search: each random base has roughly a 50% chance of
    # exposing a factor. See "Digitalized Signatures and Public Key
    # Functions as Intractable as Factorization", M. Rabin, 1979.
    factor = None
    for _ in range(_MAX_RECOVERY_ATTEMPTS):
        base = random.randint(2, n - 1)
        # Walk through base^(odd_part * 2^i) for increasing i.
        exponent = odd_part
        while exponent < ktot:
            root = pow(base, exponent, n)
            # A non-trivial square root of 1 (mod n) satisfies
            # (root - 1) * (root + 1) == 0 (mod n), so root + 1 shares a
            # proper factor with n.
            if root not in (1, n - 1) and pow(root, 2, n) == 1:
                factor = gcd(root + 1, n)
                break
            exponent *= 2
        if factor is not None:
            break
    if factor is None:
        raise ValueError("Unable to compute factors p and q from exponent d.")

    cofactor, remainder = divmod(n, factor)
    assert remainder == 0
    # Report the larger factor first.
    return (max(factor, cofactor), min(factor, cofactor))