# This file is dual licensed under the terms of the Apache License, Version
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.
    4 
    5from __future__ import annotations 
    6 
    7import abc 
    8import random 
    9import typing 
    10from math import gcd 
    11 
    12from cryptography.hazmat.bindings._rust import openssl as rust_openssl 
    13from cryptography.hazmat.primitives import _serialization, hashes 
    14from cryptography.hazmat.primitives._asymmetric import AsymmetricPadding 
    15from cryptography.hazmat.primitives.asymmetric import utils as asym_utils 
    16 
    17 
    18class RSAPrivateKey(metaclass=abc.ABCMeta): 
    19    @abc.abstractmethod 
    20    def decrypt(self, ciphertext: bytes, padding: AsymmetricPadding) -> bytes: 
    21        """ 
    22        Decrypts the provided ciphertext. 
    23        """ 
    24 
    25    @property 
    26    @abc.abstractmethod 
    27    def key_size(self) -> int: 
    28        """ 
    29        The bit length of the public modulus. 
    30        """ 
    31 
    32    @abc.abstractmethod 
    33    def public_key(self) -> RSAPublicKey: 
    34        """ 
    35        The RSAPublicKey associated with this private key. 
    36        """ 
    37 
    38    @abc.abstractmethod 
    39    def sign( 
    40        self, 
    41        data: bytes, 
    42        padding: AsymmetricPadding, 
    43        algorithm: asym_utils.Prehashed | hashes.HashAlgorithm, 
    44    ) -> bytes: 
    45        """ 
    46        Signs the data. 
    47        """ 
    48 
    49    @abc.abstractmethod 
    50    def private_numbers(self) -> RSAPrivateNumbers: 
    51        """ 
    52        Returns an RSAPrivateNumbers. 
    53        """ 
    54 
    55    @abc.abstractmethod 
    56    def private_bytes( 
    57        self, 
    58        encoding: _serialization.Encoding, 
    59        format: _serialization.PrivateFormat, 
    60        encryption_algorithm: _serialization.KeySerializationEncryption, 
    61    ) -> bytes: 
    62        """ 
    63        Returns the key serialized as bytes. 
    64        """ 
    65 
    66    @abc.abstractmethod 
    67    def __copy__(self) -> RSAPrivateKey: 
    68        """ 
    69        Returns a copy. 
    70        """ 
    71 
    72 
    73RSAPrivateKeyWithSerialization = RSAPrivateKey 
    74RSAPrivateKey.register(rust_openssl.rsa.RSAPrivateKey) 
    75 
    76 
    77class RSAPublicKey(metaclass=abc.ABCMeta): 
    78    @abc.abstractmethod 
    79    def encrypt(self, plaintext: bytes, padding: AsymmetricPadding) -> bytes: 
    80        """ 
    81        Encrypts the given plaintext. 
    82        """ 
    83 
    84    @property 
    85    @abc.abstractmethod 
    86    def key_size(self) -> int: 
    87        """ 
    88        The bit length of the public modulus. 
    89        """ 
    90 
    91    @abc.abstractmethod 
    92    def public_numbers(self) -> RSAPublicNumbers: 
    93        """ 
    94        Returns an RSAPublicNumbers 
    95        """ 
    96 
    97    @abc.abstractmethod 
    98    def public_bytes( 
    99        self, 
    100        encoding: _serialization.Encoding, 
    101        format: _serialization.PublicFormat, 
    102    ) -> bytes: 
    103        """ 
    104        Returns the key serialized as bytes. 
    105        """ 
    106 
    107    @abc.abstractmethod 
    108    def verify( 
    109        self, 
    110        signature: bytes, 
    111        data: bytes, 
    112        padding: AsymmetricPadding, 
    113        algorithm: asym_utils.Prehashed | hashes.HashAlgorithm, 
    114    ) -> None: 
    115        """ 
    116        Verifies the signature of the data. 
    117        """ 
    118 
    119    @abc.abstractmethod 
    120    def recover_data_from_signature( 
    121        self, 
    122        signature: bytes, 
    123        padding: AsymmetricPadding, 
    124        algorithm: hashes.HashAlgorithm | None, 
    125    ) -> bytes: 
    126        """ 
    127        Recovers the original data from the signature. 
    128        """ 
    129 
    130    @abc.abstractmethod 
    131    def __eq__(self, other: object) -> bool: 
    132        """ 
    133        Checks equality. 
    134        """ 
    135 
    136    @abc.abstractmethod 
    137    def __copy__(self) -> RSAPublicKey: 
    138        """ 
    139        Returns a copy. 
    140        """ 
    141 
    142 
    143RSAPublicKeyWithSerialization = RSAPublicKey 
    144RSAPublicKey.register(rust_openssl.rsa.RSAPublicKey) 
    145 
    146RSAPrivateNumbers = rust_openssl.rsa.RSAPrivateNumbers 
    147RSAPublicNumbers = rust_openssl.rsa.RSAPublicNumbers 
    148 
    149 
    150def generate_private_key( 
    151    public_exponent: int, 
    152    key_size: int, 
    153    backend: typing.Any = None, 
    154) -> RSAPrivateKey: 
    155    _verify_rsa_parameters(public_exponent, key_size) 
    156    return rust_openssl.rsa.generate_private_key(public_exponent, key_size) 
    157 
    158 
    159def _verify_rsa_parameters(public_exponent: int, key_size: int) -> None: 
    160    if public_exponent not in (3, 65537): 
    161        raise ValueError( 
    162            "public_exponent must be either 3 (for legacy compatibility) or " 
    163            "65537. Almost everyone should choose 65537 here!" 
    164        ) 
    165 
    166    if key_size < 1024: 
    167        raise ValueError("key_size must be at least 1024-bits.") 
    168 
    169 
    170def _modinv(e: int, m: int) -> int: 
    171    """ 
    172    Modular Multiplicative Inverse. Returns x such that: (x*e) mod m == 1 
    173    """ 
    174    x1, x2 = 1, 0 
    175    a, b = e, m 
    176    while b > 0: 
    177        q, r = divmod(a, b) 
    178        xn = x1 - q * x2 
    179        a, b, x1, x2 = b, r, x2, xn 
    180    return x1 % m 
    181 
    182 
    183def rsa_crt_iqmp(p: int, q: int) -> int: 
    184    """ 
    185    Compute the CRT (q ** -1) % p value from RSA primes p and q. 
    186    """ 
    187    if p <= 1 or q <= 1: 
    188        raise ValueError("Values can't be <= 1") 
    189    return _modinv(q, p) 
    190 
    191 
    192def rsa_crt_dmp1(private_exponent: int, p: int) -> int: 
    193    """ 
    194    Compute the CRT private_exponent % (p - 1) value from the RSA 
    195    private_exponent (d) and p. 
    196    """ 
    197    if private_exponent <= 1 or p <= 1: 
    198        raise ValueError("Values can't be <= 1") 
    199    return private_exponent % (p - 1) 
    200 
    201 
    202def rsa_crt_dmq1(private_exponent: int, q: int) -> int: 
    203    """ 
    204    Compute the CRT private_exponent % (q - 1) value from the RSA 
    205    private_exponent (d) and q. 
    206    """ 
    207    if private_exponent <= 1 or q <= 1: 
    208        raise ValueError("Values can't be <= 1") 
    209    return private_exponent % (q - 1) 
    210 
    211 
    212def rsa_recover_private_exponent(e: int, p: int, q: int) -> int: 
    213    """ 
    214    Compute the RSA private_exponent (d) given the public exponent (e) 
    215    and the RSA primes p and q. 
    216 
    217    This uses the Carmichael totient function to generate the 
    218    smallest possible working value of the private exponent. 
    219    """ 
    220    # This lambda_n is the Carmichael totient function. 
    221    # The original RSA paper uses the Euler totient function 
    222    # here: phi_n = (p - 1) * (q - 1) 
    223    # Either version of the private exponent will work, but the 
    224    # one generated by the older formulation may be larger 
    225    # than necessary. (lambda_n always divides phi_n) 
    226    # 
    227    # TODO: Replace with lcm(p - 1, q - 1) once the minimum 
    228    # supported Python version is >= 3.9. 
    229    if e <= 1 or p <= 1 or q <= 1: 
    230        raise ValueError("Values can't be <= 1") 
    231    lambda_n = (p - 1) * (q - 1) // gcd(p - 1, q - 1) 
    232    return _modinv(e, lambda_n) 
    233 
    234 
    235# Controls the number of iterations rsa_recover_prime_factors will perform 
    236# to obtain the prime factors. 
    237_MAX_RECOVERY_ATTEMPTS = 500 
    238 
    239 
    240def rsa_recover_prime_factors(n: int, e: int, d: int) -> tuple[int, int]: 
    241    """ 
    242    Compute factors p and q from the private exponent d. We assume that n has 
    243    no more than two factors. This function is adapted from code in PyCrypto. 
    244    """ 
    245    # reject invalid values early 
    246    if d <= 1 or e <= 1: 
    247        raise ValueError("d, e can't be <= 1") 
    248    if 17 != pow(17, e * d, n): 
    249        raise ValueError("n, d, e don't match") 
    250    # See 8.2.2(i) in Handbook of Applied Cryptography. 
    251    ktot = d * e - 1 
    252    # The quantity d*e-1 is a multiple of phi(n), even, 
    253    # and can be represented as t*2^s. 
    254    t = ktot 
    255    while t % 2 == 0: 
    256        t = t // 2 
    257    # Cycle through all multiplicative inverses in Zn. 
    258    # The algorithm is non-deterministic, but there is a 50% chance 
    259    # any candidate a leads to successful factoring. 
    260    # See "Digitalized Signatures and Public Key Functions as Intractable 
    261    # as Factorization", M. Rabin, 1979 
    262    spotted = False 
    263    tries = 0 
    264    while not spotted and tries < _MAX_RECOVERY_ATTEMPTS: 
    265        a = random.randint(2, n - 1) 
    266        tries += 1 
    267        k = t 
    268        # Cycle through all values a^{t*2^i}=a^k 
    269        while k < ktot: 
    270            cand = pow(a, k, n) 
    271            # Check if a^k is a non-trivial root of unity (mod n) 
    272            if cand != 1 and cand != (n - 1) and pow(cand, 2, n) == 1: 
    273                # We have found a number such that (cand-1)(cand+1)=0 (mod n). 
    274                # Either of the terms divides n. 
    275                p = gcd(cand + 1, n) 
    276                spotted = True 
    277                break 
    278            k *= 2 
    279    if not spotted: 
    280        raise ValueError("Unable to compute factors p and q from exponent d.") 
    281    # Found ! 
    282    q, r = divmod(n, p) 
    283    assert r == 0 
    284    p, q = sorted((p, q), reverse=True) 
    285    return (p, q)