1import binascii 
    2import warnings 
    3 
    4import rsa as pyrsa 
    5import rsa.pem as pyrsa_pem 
    6from pyasn1.error import PyAsn1Error 
    7from rsa import DecryptionError 
    8 
    9from jose.backends._asn1 import ( 
    10    rsa_private_key_pkcs1_to_pkcs8, 
    11    rsa_private_key_pkcs8_to_pkcs1, 
    12    rsa_public_key_pkcs1_to_pkcs8, 
    13) 
    14from jose.backends.base import Key 
    15from jose.constants import ALGORITHMS 
    16from jose.exceptions import JWEError, JWKError 
    17from jose.utils import base64_to_long, long_to_base64 
    18 
# The pure-Python ``rsa`` backend cannot do RSA-OAEP, so drop it from the
# globally advertised algorithm set as soon as this module is imported.
# NOTE(review): this mutates the shared ALGORITHMS.SUPPORTED collection as an
# import-time side effect, and ``remove`` raises if the entry is already gone
# (e.g. on a module reload). RSAKey.__init__ also rejects every member of
# ALGORITHMS.RSA_KW except RSA1_5; if RSA_KW holds further OAEP variants they
# remain advertised here -- confirm that is intended.
ALGORITHMS.SUPPORTED.remove(ALGORITHMS.RSA_OAEP)  # RSA OAEP not supported
    20 
    21LEGACY_INVALID_PKCS8_RSA_HEADER = binascii.unhexlify( 
    22    "30"  # sequence 
    23    "8204BD"  # DER-encoded sequence contents length of 1213 bytes -- INCORRECT STATIC LENGTH 
    24    "020100"  # integer: 0 -- Version 
    25    "30"  # sequence 
    26    "0D"  # DER-encoded sequence contents length of 13 bytes -- PrivateKeyAlgorithmIdentifier 
    27    "06092A864886F70D010101"  # OID -- rsaEncryption 
    28    "0500"  # NULL -- parameters 
    29) 
    30ASN1_SEQUENCE_ID = binascii.unhexlify("30") 
    31RSA_ENCRYPTION_ASN1_OID = "1.2.840.113549.1.1.1" 
    32 
# The functions _gcd and _rsa_recover_prime_factors were copied from cryptography 1.9
# so that the pure-Python rsa backend complies with Section 6.3.1 of RFC 7518,
# under which the private exponent (d) is the only required private-key parameter.
    36 
    37 
    38def _gcd(a, b): 
    39    """Calculate the Greatest Common Divisor of a and b. 
    40 
    41    Unless b==0, the result will have the same sign as b (so that when 
    42    b is divided by it, the result comes out positive). 
    43    """ 
    44    while b: 
    45        a, b = b, (a % b) 
    46    return a 
    47 
    48 
    49# Controls the number of iterations rsa_recover_prime_factors will perform 
    50# to obtain the prime factors. Each iteration increments by 2 so the actual 
    51# maximum attempts is half this number. 
    52_MAX_RECOVERY_ATTEMPTS = 1000 
    53 
    54 
    55def _rsa_recover_prime_factors(n, e, d): 
    56    """ 
    57    Compute factors p and q from the private exponent d. We assume that n has 
    58    no more than two factors. This function is adapted from code in PyCrypto. 
    59    """ 
    60    # See 8.2.2(i) in Handbook of Applied Cryptography. 
    61    ktot = d * e - 1 
    62    # The quantity d*e-1 is a multiple of phi(n), even, 
    63    # and can be represented as t*2^s. 
    64    t = ktot 
    65    while t % 2 == 0: 
    66        t = t // 2 
    67    # Cycle through all multiplicative inverses in Zn. 
    68    # The algorithm is non-deterministic, but there is a 50% chance 
    69    # any candidate a leads to successful factoring. 
    70    # See "Digitalized Signatures and Public Key Functions as Intractable 
    71    # as Factorization", M. Rabin, 1979 
    72    spotted = False 
    73    a = 2 
    74    while not spotted and a < _MAX_RECOVERY_ATTEMPTS: 
    75        k = t 
    76        # Cycle through all values a^{t*2^i}=a^k 
    77        while k < ktot: 
    78            cand = pow(a, k, n) 
    79            # Check if a^k is a non-trivial root of unity (mod n) 
    80            if cand != 1 and cand != (n - 1) and pow(cand, 2, n) == 1: 
    81                # We have found a number such that (cand-1)(cand+1)=0 (mod n). 
    82                # Either of the terms divides n. 
    83                p = _gcd(cand + 1, n) 
    84                spotted = True 
    85                break 
    86            k *= 2 
    87        # This value was not any good... let's try another! 
    88        a += 2 
    89    if not spotted: 
    90        raise ValueError("Unable to compute factors p and q from exponent d.") 
    91    # Found ! 
    92    q, r = divmod(n, p) 
    93    assert r == 0 
    94    p, q = sorted((p, q), reverse=True) 
    95    return (p, q) 
    96 
    97 
    98def pem_to_spki(pem, fmt="PKCS8"): 
    99    key = RSAKey(pem, ALGORITHMS.RS256) 
    100    return key.to_pem(fmt) 
    101 
    102 
    103def _legacy_private_key_pkcs8_to_pkcs1(pkcs8_key): 
    104    """Legacy RSA private key PKCS8-to-PKCS1 conversion. 
    105 
    106    .. warning:: 
    107 
    108        This is incorrect parsing and only works because the legacy PKCS1-to-PKCS8 
    109        encoding was also incorrect. 
    110    """ 
    111    # Only allow this processing if the prefix matches 
    112    # AND the following byte indicates an ASN1 sequence, 
    113    # as we would expect with the legacy encoding. 
    114    if not pkcs8_key.startswith(LEGACY_INVALID_PKCS8_RSA_HEADER + ASN1_SEQUENCE_ID): 
    115        raise ValueError("Invalid private key encoding") 
    116 
    117    return pkcs8_key[len(LEGACY_INVALID_PKCS8_RSA_HEADER) :] 
    118 
    119 
    120class RSAKey(Key): 
    121    SHA256 = "SHA-256" 
    122    SHA384 = "SHA-384" 
    123    SHA512 = "SHA-512" 
    124 
    125    def __init__(self, key, algorithm): 
    126        if algorithm not in ALGORITHMS.RSA: 
    127            raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm) 
    128 
    129        if algorithm in ALGORITHMS.RSA_KW and algorithm != ALGORITHMS.RSA1_5: 
    130            raise JWKError("alg: %s is not supported by the RSA backend" % algorithm) 
    131 
    132        self.hash_alg = { 
    133            ALGORITHMS.RS256: self.SHA256, 
    134            ALGORITHMS.RS384: self.SHA384, 
    135            ALGORITHMS.RS512: self.SHA512, 
    136        }.get(algorithm) 
    137        self._algorithm = algorithm 
    138 
    139        if isinstance(key, dict): 
    140            self._prepared_key = self._process_jwk(key) 
    141            return 
    142 
    143        if isinstance(key, (pyrsa.PublicKey, pyrsa.PrivateKey)): 
    144            self._prepared_key = key 
    145            return 
    146 
    147        if isinstance(key, str): 
    148            key = key.encode("utf-8") 
    149 
    150        if isinstance(key, bytes): 
    151            try: 
    152                self._prepared_key = pyrsa.PublicKey.load_pkcs1(key) 
    153            except ValueError: 
    154                try: 
    155                    self._prepared_key = pyrsa.PublicKey.load_pkcs1_openssl_pem(key) 
    156                except ValueError: 
    157                    try: 
    158                        self._prepared_key = pyrsa.PrivateKey.load_pkcs1(key) 
    159                    except ValueError: 
    160                        try: 
    161                            der = pyrsa_pem.load_pem(key, b"PRIVATE KEY") 
    162                            try: 
    163                                pkcs1_key = rsa_private_key_pkcs8_to_pkcs1(der) 
    164                            except PyAsn1Error: 
    165                                # If the key was encoded using the old, invalid, 
    166                                # encoding then pyasn1 will throw an error attempting 
    167                                # to parse the key. 
    168                                pkcs1_key = _legacy_private_key_pkcs8_to_pkcs1(der) 
    169                            self._prepared_key = pyrsa.PrivateKey.load_pkcs1(pkcs1_key, format="DER") 
    170                        except ValueError as e: 
    171                            raise JWKError(e) 
    172            return 
    173        raise JWKError("Unable to parse an RSA_JWK from key: %s" % key) 
    174 
    175    def _process_jwk(self, jwk_dict): 
    176        if not jwk_dict.get("kty") == "RSA": 
    177            raise JWKError("Incorrect key type. Expected: 'RSA', Received: %s" % jwk_dict.get("kty")) 
    178 
    179        e = base64_to_long(jwk_dict.get("e")) 
    180        n = base64_to_long(jwk_dict.get("n")) 
    181 
    182        if "d" not in jwk_dict: 
    183            return pyrsa.PublicKey(e=e, n=n) 
    184        else: 
    185            d = base64_to_long(jwk_dict.get("d")) 
    186            extra_params = ["p", "q", "dp", "dq", "qi"] 
    187 
    188            if any(k in jwk_dict for k in extra_params): 
    189                # Precomputed private key parameters are available. 
    190                if not all(k in jwk_dict for k in extra_params): 
    191                    # These values must be present when 'p' is according to 
    192                    # Section 6.3.2 of RFC7518, so if they are not we raise 
    193                    # an error. 
    194                    raise JWKError("Precomputed private key parameters are incomplete.") 
    195 
    196                p = base64_to_long(jwk_dict["p"]) 
    197                q = base64_to_long(jwk_dict["q"]) 
    198                return pyrsa.PrivateKey(e=e, n=n, d=d, p=p, q=q) 
    199            else: 
    200                p, q = _rsa_recover_prime_factors(n, e, d) 
    201                return pyrsa.PrivateKey(n=n, e=e, d=d, p=p, q=q) 
    202 
    203    def sign(self, msg): 
    204        return pyrsa.sign(msg, self._prepared_key, self.hash_alg) 
    205 
    206    def verify(self, msg, sig): 
    207        if not self.is_public(): 
    208            warnings.warn("Attempting to verify a message with a private key. " "This is not recommended.") 
    209        try: 
    210            pyrsa.verify(msg, sig, self._prepared_key) 
    211            return True 
    212        except pyrsa.pkcs1.VerificationError: 
    213            return False 
    214 
    215    def is_public(self): 
    216        return isinstance(self._prepared_key, pyrsa.PublicKey) 
    217 
    218    def public_key(self): 
    219        if isinstance(self._prepared_key, pyrsa.PublicKey): 
    220            return self 
    221        return self.__class__(pyrsa.PublicKey(n=self._prepared_key.n, e=self._prepared_key.e), self._algorithm) 
    222 
    223    def to_pem(self, pem_format="PKCS8"): 
    224        if isinstance(self._prepared_key, pyrsa.PrivateKey): 
    225            der = self._prepared_key.save_pkcs1(format="DER") 
    226            if pem_format == "PKCS8": 
    227                pkcs8_der = rsa_private_key_pkcs1_to_pkcs8(der) 
    228                pem = pyrsa_pem.save_pem(pkcs8_der, pem_marker="PRIVATE KEY") 
    229            elif pem_format == "PKCS1": 
    230                pem = pyrsa_pem.save_pem(der, pem_marker="RSA PRIVATE KEY") 
    231            else: 
    232                raise ValueError(f"Invalid pem format specified: {pem_format!r}") 
    233        else: 
    234            if pem_format == "PKCS8": 
    235                pkcs1_der = self._prepared_key.save_pkcs1(format="DER") 
    236                pkcs8_der = rsa_public_key_pkcs1_to_pkcs8(pkcs1_der) 
    237                pem = pyrsa_pem.save_pem(pkcs8_der, pem_marker="PUBLIC KEY") 
    238            elif pem_format == "PKCS1": 
    239                der = self._prepared_key.save_pkcs1(format="DER") 
    240                pem = pyrsa_pem.save_pem(der, pem_marker="RSA PUBLIC KEY") 
    241            else: 
    242                raise ValueError(f"Invalid pem format specified: {pem_format!r}") 
    243        return pem 
    244 
    245    def to_dict(self): 
    246        if not self.is_public(): 
    247            public_key = self.public_key()._prepared_key 
    248        else: 
    249            public_key = self._prepared_key 
    250 
    251        data = { 
    252            "alg": self._algorithm, 
    253            "kty": "RSA", 
    254            "n": long_to_base64(public_key.n).decode("ASCII"), 
    255            "e": long_to_base64(public_key.e).decode("ASCII"), 
    256        } 
    257 
    258        if not self.is_public(): 
    259            data.update( 
    260                { 
    261                    "d": long_to_base64(self._prepared_key.d).decode("ASCII"), 
    262                    "p": long_to_base64(self._prepared_key.p).decode("ASCII"), 
    263                    "q": long_to_base64(self._prepared_key.q).decode("ASCII"), 
    264                    "dp": long_to_base64(self._prepared_key.exp1).decode("ASCII"), 
    265                    "dq": long_to_base64(self._prepared_key.exp2).decode("ASCII"), 
    266                    "qi": long_to_base64(self._prepared_key.coef).decode("ASCII"), 
    267                } 
    268            ) 
    269 
    270        return data 
    271 
    272    def wrap_key(self, key_data): 
    273        if not self.is_public(): 
    274            warnings.warn("Attempting to encrypt a message with a private key." " This is not recommended.") 
    275        wrapped_key = pyrsa.encrypt(key_data, self._prepared_key) 
    276        return wrapped_key 
    277 
    278    def unwrap_key(self, wrapped_key): 
    279        try: 
    280            unwrapped_key = pyrsa.decrypt(wrapped_key, self._prepared_key) 
    281        except DecryptionError as e: 
    282            raise JWEError(e) 
    283        return unwrapped_key