1# This file is dual licensed under the terms of the Apache License, Version 
    2# 2.0, and the BSD License. See the LICENSE file in the root of this repository 
    3# for complete details. 
    4 
    5from __future__ import annotations 
    6 
    7import abc 
    8 
    9from cryptography.hazmat.primitives import hashes 
    10from cryptography.hazmat.primitives._asymmetric import ( 
    11    AsymmetricPadding as AsymmetricPadding, 
    12) 
    13from cryptography.hazmat.primitives.asymmetric import rsa 
    14 
    15 
    16class PKCS1v15(AsymmetricPadding): 
    17    name = "EMSA-PKCS1-v1_5" 
    18 
    19 
    20class _MaxLength: 
    21    "Sentinel value for `MAX_LENGTH`." 
    22 
    23 
    24class _Auto: 
    25    "Sentinel value for `AUTO`." 
    26 
    27 
    28class _DigestLength: 
    29    "Sentinel value for `DIGEST_LENGTH`." 
    30 
    31 
    32class PSS(AsymmetricPadding): 
    33    MAX_LENGTH = _MaxLength() 
    34    AUTO = _Auto() 
    35    DIGEST_LENGTH = _DigestLength() 
    36    name = "EMSA-PSS" 
    37    _salt_length: int | _MaxLength | _Auto | _DigestLength 
    38 
    39    def __init__( 
    40        self, 
    41        mgf: MGF, 
    42        salt_length: int | _MaxLength | _Auto | _DigestLength, 
    43    ) -> None: 
    44        self._mgf = mgf 
    45 
    46        if not isinstance( 
    47            salt_length, (int, _MaxLength, _Auto, _DigestLength) 
    48        ): 
    49            raise TypeError( 
    50                "salt_length must be an integer, MAX_LENGTH, " 
    51                "DIGEST_LENGTH, or AUTO" 
    52            ) 
    53 
    54        if isinstance(salt_length, int) and salt_length < 0: 
    55            raise ValueError("salt_length must be zero or greater.") 
    56 
    57        self._salt_length = salt_length 
    58 
    59    @property 
    60    def mgf(self) -> MGF: 
    61        return self._mgf 
    62 
    63 
    64class OAEP(AsymmetricPadding): 
    65    name = "EME-OAEP" 
    66 
    67    def __init__( 
    68        self, 
    69        mgf: MGF, 
    70        algorithm: hashes.HashAlgorithm, 
    71        label: bytes | None, 
    72    ): 
    73        if not isinstance(algorithm, hashes.HashAlgorithm): 
    74            raise TypeError("Expected instance of hashes.HashAlgorithm.") 
    75 
    76        self._mgf = mgf 
    77        self._algorithm = algorithm 
    78        self._label = label 
    79 
    80    @property 
    81    def algorithm(self) -> hashes.HashAlgorithm: 
    82        return self._algorithm 
    83 
    84    @property 
    85    def mgf(self) -> MGF: 
    86        return self._mgf 
    87 
    88 
    89class MGF(metaclass=abc.ABCMeta): 
    90    _algorithm: hashes.HashAlgorithm 
    91 
    92 
    93class MGF1(MGF): 
    94    def __init__(self, algorithm: hashes.HashAlgorithm): 
    95        if not isinstance(algorithm, hashes.HashAlgorithm): 
    96            raise TypeError("Expected instance of hashes.HashAlgorithm.") 
    97 
    98        self._algorithm = algorithm 
    99 
    100 
    101def calculate_max_pss_salt_length( 
    102    key: rsa.RSAPrivateKey | rsa.RSAPublicKey, 
    103    hash_algorithm: hashes.HashAlgorithm, 
    104) -> int: 
    105    if not isinstance(key, (rsa.RSAPrivateKey, rsa.RSAPublicKey)): 
    106        raise TypeError("key must be an RSA public or private key") 
    107    # bit length - 1 per RFC 3447 
    108    emlen = (key.key_size + 6) // 8 
    109    salt_length = emlen - hash_algorithm.digest_size - 2 
    110    assert salt_length >= 0 
    111    return salt_length