1# This file is dual licensed under the terms of the Apache License, Version 
    2# 2.0, and the BSD License. See the LICENSE file in the root of this repository 
    3# for complete details. 
    4 
    5from __future__ import annotations 
    6 
    7import abc 
    8import random 
    9import typing 
    10from math import gcd 
    11 
    12from cryptography.hazmat.bindings._rust import openssl as rust_openssl 
    13from cryptography.hazmat.primitives import _serialization, hashes 
    14from cryptography.hazmat.primitives._asymmetric import AsymmetricPadding 
    15from cryptography.hazmat.primitives.asymmetric import utils as asym_utils 
    16 
    17 
    18class RSAPrivateKey(metaclass=abc.ABCMeta): 
    19    @abc.abstractmethod 
    20    def decrypt(self, ciphertext: bytes, padding: AsymmetricPadding) -> bytes: 
    21        """ 
    22        Decrypts the provided ciphertext. 
    23        """ 
    24 
    25    @property 
    26    @abc.abstractmethod 
    27    def key_size(self) -> int: 
    28        """ 
    29        The bit length of the public modulus. 
    30        """ 
    31 
    32    @abc.abstractmethod 
    33    def public_key(self) -> RSAPublicKey: 
    34        """ 
    35        The RSAPublicKey associated with this private key. 
    36        """ 
    37 
    38    @abc.abstractmethod 
    39    def sign( 
    40        self, 
    41        data: bytes, 
    42        padding: AsymmetricPadding, 
    43        algorithm: asym_utils.Prehashed | hashes.HashAlgorithm, 
    44    ) -> bytes: 
    45        """ 
    46        Signs the data. 
    47        """ 
    48 
    49    @abc.abstractmethod 
    50    def private_numbers(self) -> RSAPrivateNumbers: 
    51        """ 
    52        Returns an RSAPrivateNumbers. 
    53        """ 
    54 
    55    @abc.abstractmethod 
    56    def private_bytes( 
    57        self, 
    58        encoding: _serialization.Encoding, 
    59        format: _serialization.PrivateFormat, 
    60        encryption_algorithm: _serialization.KeySerializationEncryption, 
    61    ) -> bytes: 
    62        """ 
    63        Returns the key serialized as bytes. 
    64        """ 
    65 
    66 
    67RSAPrivateKeyWithSerialization = RSAPrivateKey 
    68RSAPrivateKey.register(rust_openssl.rsa.RSAPrivateKey) 
    69 
    70 
    71class RSAPublicKey(metaclass=abc.ABCMeta): 
    72    @abc.abstractmethod 
    73    def encrypt(self, plaintext: bytes, padding: AsymmetricPadding) -> bytes: 
    74        """ 
    75        Encrypts the given plaintext. 
    76        """ 
    77 
    78    @property 
    79    @abc.abstractmethod 
    80    def key_size(self) -> int: 
    81        """ 
    82        The bit length of the public modulus. 
    83        """ 
    84 
    85    @abc.abstractmethod 
    86    def public_numbers(self) -> RSAPublicNumbers: 
    87        """ 
    88        Returns an RSAPublicNumbers 
    89        """ 
    90 
    91    @abc.abstractmethod 
    92    def public_bytes( 
    93        self, 
    94        encoding: _serialization.Encoding, 
    95        format: _serialization.PublicFormat, 
    96    ) -> bytes: 
    97        """ 
    98        Returns the key serialized as bytes. 
    99        """ 
    100 
    101    @abc.abstractmethod 
    102    def verify( 
    103        self, 
    104        signature: bytes, 
    105        data: bytes, 
    106        padding: AsymmetricPadding, 
    107        algorithm: asym_utils.Prehashed | hashes.HashAlgorithm, 
    108    ) -> None: 
    109        """ 
    110        Verifies the signature of the data. 
    111        """ 
    112 
    113    @abc.abstractmethod 
    114    def recover_data_from_signature( 
    115        self, 
    116        signature: bytes, 
    117        padding: AsymmetricPadding, 
    118        algorithm: hashes.HashAlgorithm | None, 
    119    ) -> bytes: 
    120        """ 
    121        Recovers the original data from the signature. 
    122        """ 
    123 
    124    @abc.abstractmethod 
    125    def __eq__(self, other: object) -> bool: 
    126        """ 
    127        Checks equality. 
    128        """ 
    129 
    130 
    131RSAPublicKeyWithSerialization = RSAPublicKey 
    132RSAPublicKey.register(rust_openssl.rsa.RSAPublicKey) 
    133 
    134RSAPrivateNumbers = rust_openssl.rsa.RSAPrivateNumbers 
    135RSAPublicNumbers = rust_openssl.rsa.RSAPublicNumbers 
    136 
    137 
    138def generate_private_key( 
    139    public_exponent: int, 
    140    key_size: int, 
    141    backend: typing.Any = None, 
    142) -> RSAPrivateKey: 
    143    _verify_rsa_parameters(public_exponent, key_size) 
    144    return rust_openssl.rsa.generate_private_key(public_exponent, key_size) 
    145 
    146 
    147def _verify_rsa_parameters(public_exponent: int, key_size: int) -> None: 
    148    if public_exponent not in (3, 65537): 
    149        raise ValueError( 
    150            "public_exponent must be either 3 (for legacy compatibility) or " 
    151            "65537. Almost everyone should choose 65537 here!" 
    152        ) 
    153 
    154    if key_size < 1024: 
    155        raise ValueError("key_size must be at least 1024-bits.") 
    156 
    157 
    158def _modinv(e: int, m: int) -> int: 
    159    """ 
    160    Modular Multiplicative Inverse. Returns x such that: (x*e) mod m == 1 
    161    """ 
    162    x1, x2 = 1, 0 
    163    a, b = e, m 
    164    while b > 0: 
    165        q, r = divmod(a, b) 
    166        xn = x1 - q * x2 
    167        a, b, x1, x2 = b, r, x2, xn 
    168    return x1 % m 
    169 
    170 
    171def rsa_crt_iqmp(p: int, q: int) -> int: 
    172    """ 
    173    Compute the CRT (q ** -1) % p value from RSA primes p and q. 
    174    """ 
    175    return _modinv(q, p) 
    176 
    177 
    178def rsa_crt_dmp1(private_exponent: int, p: int) -> int: 
    179    """ 
    180    Compute the CRT private_exponent % (p - 1) value from the RSA 
    181    private_exponent (d) and p. 
    182    """ 
    183    return private_exponent % (p - 1) 
    184 
    185 
    186def rsa_crt_dmq1(private_exponent: int, q: int) -> int: 
    187    """ 
    188    Compute the CRT private_exponent % (q - 1) value from the RSA 
    189    private_exponent (d) and q. 
    190    """ 
    191    return private_exponent % (q - 1) 
    192 
    193 
    194def rsa_recover_private_exponent(e: int, p: int, q: int) -> int: 
    195    """ 
    196    Compute the RSA private_exponent (d) given the public exponent (e) 
    197    and the RSA primes p and q. 
    198 
    199    This uses the Carmichael totient function to generate the 
    200    smallest possible working value of the private exponent. 
    201    """ 
    202    # This lambda_n is the Carmichael totient function. 
    203    # The original RSA paper uses the Euler totient function 
    204    # here: phi_n = (p - 1) * (q - 1) 
    205    # Either version of the private exponent will work, but the 
    206    # one generated by the older formulation may be larger 
    207    # than necessary. (lambda_n always divides phi_n) 
    208    # 
    209    # TODO: Replace with lcm(p - 1, q - 1) once the minimum 
    210    # supported Python version is >= 3.9. 
    211    lambda_n = (p - 1) * (q - 1) // gcd(p - 1, q - 1) 
    212    return _modinv(e, lambda_n) 
    213 
    214 
# Upper bound on how many random candidates rsa_recover_prime_factors draws
# while searching for the prime factors before giving up with ValueError.
_MAX_RECOVERY_ATTEMPTS = 500
    218 
    219 
    220def rsa_recover_prime_factors(n: int, e: int, d: int) -> tuple[int, int]: 
    221    """ 
    222    Compute factors p and q from the private exponent d. We assume that n has 
    223    no more than two factors. This function is adapted from code in PyCrypto. 
    224    """ 
    225    # reject invalid values early 
    226    if 17 != pow(17, e * d, n): 
    227        raise ValueError("n, d, e don't match") 
    228    # See 8.2.2(i) in Handbook of Applied Cryptography. 
    229    ktot = d * e - 1 
    230    # The quantity d*e-1 is a multiple of phi(n), even, 
    231    # and can be represented as t*2^s. 
    232    t = ktot 
    233    while t % 2 == 0: 
    234        t = t // 2 
    235    # Cycle through all multiplicative inverses in Zn. 
    236    # The algorithm is non-deterministic, but there is a 50% chance 
    237    # any candidate a leads to successful factoring. 
    238    # See "Digitalized Signatures and Public Key Functions as Intractable 
    239    # as Factorization", M. Rabin, 1979 
    240    spotted = False 
    241    tries = 0 
    242    while not spotted and tries < _MAX_RECOVERY_ATTEMPTS: 
    243        a = random.randint(2, n - 1) 
    244        tries += 1 
    245        k = t 
    246        # Cycle through all values a^{t*2^i}=a^k 
    247        while k < ktot: 
    248            cand = pow(a, k, n) 
    249            # Check if a^k is a non-trivial root of unity (mod n) 
    250            if cand != 1 and cand != (n - 1) and pow(cand, 2, n) == 1: 
    251                # We have found a number such that (cand-1)(cand+1)=0 (mod n). 
    252                # Either of the terms divides n. 
    253                p = gcd(cand + 1, n) 
    254                spotted = True 
    255                break 
    256            k *= 2 
    257    if not spotted: 
    258        raise ValueError("Unable to compute factors p and q from exponent d.") 
    259    # Found ! 
    260    q, r = divmod(n, p) 
    261    assert r == 0 
    262    p, q = sorted((p, q), reverse=True) 
    263    return (p, q)