1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
4
5from __future__ import annotations
6
7import abc
8import random
9import typing
10from math import gcd
11
12from cryptography.hazmat.bindings._rust import openssl as rust_openssl
13from cryptography.hazmat.primitives import _serialization, hashes
14from cryptography.hazmat.primitives._asymmetric import AsymmetricPadding
15from cryptography.hazmat.primitives.asymmetric import utils as asym_utils
16
17
class RSAPrivateKey(metaclass=abc.ABCMeta):
    """
    Abstract interface for an RSA private key.

    Concrete key classes must implement every abstract member declared here
    before they can be instantiated.
    """

    @abc.abstractmethod
    def decrypt(self, ciphertext: bytes, padding: AsymmetricPadding) -> bytes:
        """
        Decrypt ``ciphertext`` using ``padding`` and return the plaintext.
        """

    @property
    @abc.abstractmethod
    def key_size(self) -> int:
        """
        Size of the public modulus, measured in bits.
        """

    @abc.abstractmethod
    def public_key(self) -> RSAPublicKey:
        """
        Return the RSAPublicKey that pairs with this private key.
        """

    @abc.abstractmethod
    def sign(
        self,
        data: bytes,
        padding: AsymmetricPadding,
        algorithm: asym_utils.Prehashed | hashes.HashAlgorithm,
    ) -> bytes:
        """
        Produce a signature over ``data`` with the given ``padding``.

        ``algorithm`` is either a hash algorithm or an ``asym_utils.Prehashed``
        wrapper.
        """

    @abc.abstractmethod
    def private_numbers(self) -> RSAPrivateNumbers:
        """
        Return the components of this key as an RSAPrivateNumbers.
        """

    @abc.abstractmethod
    def private_bytes(
        self,
        encoding: _serialization.Encoding,
        format: _serialization.PrivateFormat,
        encryption_algorithm: _serialization.KeySerializationEncryption,
    ) -> bytes:
        """
        Serialize this key according to ``encoding``, ``format`` and
        ``encryption_algorithm``, returning the encoded bytes.
        """
65
66
# Register the Rust-backed private key type as a virtual subclass so that
# isinstance()/issubclass() checks against the ABC accept it.
RSAPrivateKey.register(rust_openssl.rsa.RSAPrivateKey)
# Second name bound to the very same class object as RSAPrivateKey.
RSAPrivateKeyWithSerialization = RSAPrivateKey
69
70
class RSAPublicKey(metaclass=abc.ABCMeta):
    """
    Abstract interface for an RSA public key.

    Concrete key classes must implement every abstract member declared here,
    including ``__eq__``, before they can be instantiated.
    """

    @abc.abstractmethod
    def encrypt(self, plaintext: bytes, padding: AsymmetricPadding) -> bytes:
        """
        Encrypt ``plaintext`` using ``padding`` and return the ciphertext.
        """

    @property
    @abc.abstractmethod
    def key_size(self) -> int:
        """
        Size of the public modulus, measured in bits.
        """

    @abc.abstractmethod
    def public_numbers(self) -> RSAPublicNumbers:
        """
        Return the components of this key as an RSAPublicNumbers.
        """

    @abc.abstractmethod
    def public_bytes(
        self,
        encoding: _serialization.Encoding,
        format: _serialization.PublicFormat,
    ) -> bytes:
        """
        Serialize this key according to ``encoding`` and ``format``,
        returning the encoded bytes.
        """

    @abc.abstractmethod
    def verify(
        self,
        signature: bytes,
        data: bytes,
        padding: AsymmetricPadding,
        algorithm: asym_utils.Prehashed | hashes.HashAlgorithm,
    ) -> None:
        """
        Check ``signature`` against ``data`` using ``padding`` and
        ``algorithm`` (a hash algorithm or an ``asym_utils.Prehashed``).
        """

    @abc.abstractmethod
    def recover_data_from_signature(
        self,
        signature: bytes,
        padding: AsymmetricPadding,
        algorithm: hashes.HashAlgorithm | None,
    ) -> bytes:
        """
        Return the data recovered from ``signature``; ``algorithm`` may be
        None.
        """

    # Defining __eq__ in the class body without __hash__ makes Python set
    # RSAPublicKey.__hash__ to None.
    @abc.abstractmethod
    def __eq__(self, other: object) -> bool:
        """
        Report whether ``other`` equals this key.
        """
129
130
# Register the Rust-backed public key type as a virtual subclass so that
# isinstance()/issubclass() checks against the ABC accept it.
RSAPublicKey.register(rust_openssl.rsa.RSAPublicKey)
# Second name bound to the very same class object as RSAPublicKey.
RSAPublicKeyWithSerialization = RSAPublicKey

# The numbers classes are taken directly from the Rust bindings module and
# re-exported under these names.
RSAPrivateNumbers = rust_openssl.rsa.RSAPrivateNumbers
RSAPublicNumbers = rust_openssl.rsa.RSAPublicNumbers
136
137
def generate_private_key(
    public_exponent: int,
    key_size: int,
    backend: typing.Any = None,
) -> RSAPrivateKey:
    """
    Generate a fresh RSA private key.

    ``public_exponent`` and ``key_size`` are validated first (ValueError is
    raised for unsupported values); generation itself is delegated to the
    Rust bindings. ``backend`` is accepted but not used by this function.
    """
    _verify_rsa_parameters(public_exponent, key_size)
    new_key = rust_openssl.rsa.generate_private_key(public_exponent, key_size)
    return new_key
145
146
147def _verify_rsa_parameters(public_exponent: int, key_size: int) -> None:
148 if public_exponent not in (3, 65537):
149 raise ValueError(
150 "public_exponent must be either 3 (for legacy compatibility) or "
151 "65537. Almost everyone should choose 65537 here!"
152 )
153
154 if key_size < 1024:
155 raise ValueError("key_size must be at least 1024-bits.")
156
157
158def _modinv(e: int, m: int) -> int:
159 """
160 Modular Multiplicative Inverse. Returns x such that: (x*e) mod m == 1
161 """
162 x1, x2 = 1, 0
163 a, b = e, m
164 while b > 0:
165 q, r = divmod(a, b)
166 xn = x1 - q * x2
167 a, b, x1, x2 = b, r, x2, xn
168 return x1 % m
169
170
def rsa_crt_iqmp(p: int, q: int) -> int:
    """
    Return the CRT coefficient ``iqmp = (q ** -1) % p`` for RSA primes p, q.
    """
    iqmp = _modinv(q, p)
    return iqmp
176
177
def rsa_crt_dmp1(private_exponent: int, p: int) -> int:
    """
    Return the CRT exponent ``dmp1 = d % (p - 1)``, where d is
    ``private_exponent`` and p is an RSA prime.
    """
    p_minus_one = p - 1
    return private_exponent % p_minus_one
184
185
def rsa_crt_dmq1(private_exponent: int, q: int) -> int:
    """
    Return the CRT exponent ``dmq1 = d % (q - 1)``, where d is
    ``private_exponent`` and q is an RSA prime.
    """
    q_minus_one = q - 1
    return private_exponent % q_minus_one
192
193
def rsa_recover_private_exponent(e: int, p: int, q: int) -> int:
    """
    Compute the RSA private_exponent (d) from the public exponent (e) and
    the RSA primes p and q.

    d is the inverse of e modulo Carmichael's lambda(n) rather than Euler's
    phi(n), giving the smallest possible working private exponent.
    """
    # Carmichael totient: lambda(n) = lcm(p - 1, q - 1). The original RSA
    # paper used phi(n) = (p - 1) * (q - 1) instead. Either yields a valid d,
    # but lambda(n) always divides phi(n), so the phi-based d may be larger
    # than necessary.
    #
    # TODO: Replace with lcm(p - 1, q - 1) once the minimum
    # supported Python version is >= 3.9.
    p_minus_one = p - 1
    q_minus_one = q - 1
    lambda_n = (p_minus_one * q_minus_one) // gcd(p_minus_one, q_minus_one)
    return _modinv(e, lambda_n)
213
214
# Upper bound on how many random bases rsa_recover_prime_factors tries before
# giving up. Per the comment in that function each base succeeds with roughly
# 50% probability, so exhausting this budget for a valid (n, e, d) is
# astronomically unlikely.
_MAX_RECOVERY_ATTEMPTS = 500


def rsa_recover_prime_factors(n: int, e: int, d: int) -> tuple[int, int]:
    """
    Compute factors p and q from the private exponent d. We assume that n has
    no more than two factors. This function is adapted from code in PyCrypto.

    Returns ``(p, q)`` with ``p >= q`` and ``p * q == n``.

    Raises ValueError if ``e`` or ``d`` is <= 1, if ``n``, ``e`` and ``d`` are
    inconsistent, or if no factor was found after ``_MAX_RECOVERY_ATTEMPTS``
    random bases.
    """
    # Reject degenerate exponents first. If e * d == 1 (e.g. e == d == 1),
    # the consistency check below passes for any n > 17, ktot becomes 0, and
    # the loop that strips factors of two from ktot would never terminate.
    if d <= 1 or e <= 1:
        raise ValueError("d, e can't be <= 1")
    # reject invalid values early
    if 17 != pow(17, e * d, n):
        raise ValueError("n, d, e don't match")
    # See 8.2.2(i) in Handbook of Applied Cryptography.
    ktot = d * e - 1
    # The quantity d*e-1 is a multiple of phi(n), even,
    # and can be represented as t*2^s.
    t = ktot
    while t % 2 == 0:
        t = t // 2
    # Cycle through all multiplicative inverses in Zn.
    # The algorithm is non-deterministic, but there is a 50% chance
    # any candidate a leads to successful factoring.
    # See "Digitalized Signatures and Public Key Functions as Intractable
    # as Factorization", M. Rabin, 1979
    spotted = False
    tries = 0
    while not spotted and tries < _MAX_RECOVERY_ATTEMPTS:
        a = random.randint(2, n - 1)
        tries += 1
        k = t
        # Cycle through all values a^{t*2^i}=a^k
        while k < ktot:
            cand = pow(a, k, n)
            # Check if a^k is a non-trivial root of unity (mod n)
            if cand != 1 and cand != (n - 1) and pow(cand, 2, n) == 1:
                # We have found a number such that (cand-1)(cand+1)=0 (mod n).
                # Either of the terms divides n.
                p = gcd(cand + 1, n)
                spotted = True
                break
            k *= 2
    if not spotted:
        raise ValueError("Unable to compute factors p and q from exponent d.")
    # Found !
    q, r = divmod(n, p)
    # p is a gcd with n, so it always divides n: this is an internal
    # invariant, not input validation.
    assert r == 0
    p, q = sorted((p, q), reverse=True)
    return (p, q)