1# This file is dual licensed under the terms of the Apache License, Version 
    2# 2.0, and the BSD License. See the LICENSE file in the root of this repository 
    3# for complete details. 
    4 
    5from __future__ import annotations 
    6 
    7import binascii 
    8import enum 
    9import os 
    10import re 
    11import typing 
    12import warnings 
    13from base64 import encodebytes as _base64_encode 
    14from dataclasses import dataclass 
    15 
    16from cryptography import utils 
    17from cryptography.exceptions import UnsupportedAlgorithm 
    18from cryptography.hazmat.primitives import hashes 
    19from cryptography.hazmat.primitives.asymmetric import ( 
    20    dsa, 
    21    ec, 
    22    ed25519, 
    23    padding, 
    24    rsa, 
    25) 
    26from cryptography.hazmat.primitives.asymmetric import utils as asym_utils 
    27from cryptography.hazmat.primitives.ciphers import ( 
    28    AEADDecryptionContext, 
    29    Cipher, 
    30    algorithms, 
    31    modes, 
    32) 
    33from cryptography.hazmat.primitives.serialization import ( 
    34    Encoding, 
    35    KeySerializationEncryption, 
    36    NoEncryption, 
    37    PrivateFormat, 
    38    PublicFormat, 
    39    _KeySerializationEncryption, 
    40) 
    41 
# bcrypt is optional. When it cannot be imported, _bcrypt_kdf is replaced
# by a stub taking the same arguments that raises UnsupportedAlgorithm, so
# only operations that actually need the KDF fail. _bcrypt_supported
# records which of the two definitions is in effect.
try:
    from bcrypt import kdf as _bcrypt_kdf

    _bcrypt_supported = True
except ImportError:
    _bcrypt_supported = False

    def _bcrypt_kdf(
        password: bytes,
        salt: bytes,
        desired_key_bytes: int,
        rounds: int,
        ignore_few_rounds: bool = False,
    ) -> bytes:
        """Stand-in used when bcrypt is not installed; always raises."""
        raise UnsupportedAlgorithm("Need bcrypt module")
    57 
    58 
# SSH key type names, as they appear at the start of a key blob.
_SSH_ED25519 = b"ssh-ed25519"
_SSH_RSA = b"ssh-rsa"
_SSH_DSA = b"ssh-dss"
_ECDSA_NISTP256 = b"ecdsa-sha2-nistp256"
_ECDSA_NISTP384 = b"ecdsa-sha2-nistp384"
_ECDSA_NISTP521 = b"ecdsa-sha2-nistp521"
# Suffix that turns a key type name into an OpenSSH certificate type name
# (presumably used by certificate handling elsewhere in this module).
_CERT_SUFFIX = b"-cert-v01@openssh.com"

# U2F application string suffixed pubkey: these sk-* key blobs carry a
# trailing application string after the ordinary public key fields.
_SK_SSH_ED25519 = b"sk-ssh-ed25519@openssh.com"
_SK_SSH_ECDSA_NISTP256 = b"sk-ecdsa-sha2-nistp256@openssh.com"

# These are not key types, only algorithms, so they cannot appear
# as a public key type
_SSH_RSA_SHA256 = b"rsa-sha2-256"
_SSH_RSA_SHA512 = b"rsa-sha2-512"

# Matches the first two whitespace-separated fields at the start of a
# public key line (presumably key type and base64 blob).
_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)")
# Magic prefix of the base64-decoded OpenSSH private key blob.
_SK_MAGIC = b"openssh-key-v1\0"
# Armor lines surrounding the base64 private key body.
_SK_START = b"-----BEGIN OPENSSH PRIVATE KEY-----"
_SK_END = b"-----END OPENSSH PRIVATE KEY-----"
# KDF and cipher names that can appear in the private key header.
_BCRYPT = b"bcrypt"
_NONE = b"none"
# Cipher and bcrypt round count used when serializing with a password.
_DEFAULT_CIPHER = b"aes256-ctr"
_DEFAULT_ROUNDS = 16

# re is only way to work on bytes-like data
# Group 1 captures everything between the armor lines; DOTALL lets the
# non-greedy match span the newlines inside the base64 body.
_PEM_RC = re.compile(_SK_START + b"(.*?)" + _SK_END, re.DOTALL)

# padding for max blocksize: the byte sequence 1, 2, ..., 16. A prefix of
# it is appended when writing and compared against when reading.
_PADDING = memoryview(bytearray(range(1, 1 + 16)))
    90 
    91 
@dataclass
class _SSHCipher:
    """Parameters of a symmetric cipher used to wrap OpenSSH private keys."""

    # Block cipher class; instantiated with the derived key.
    alg: type[algorithms.AES]
    # Key size in bytes.
    key_len: int
    # Mode class; instantiated with the derived IV / nonce.
    mode: type[modes.CTR] | type[modes.CBC] | type[modes.GCM]
    # The encrypted section's length must be a multiple of this.
    block_len: int
    # IV / nonce size in bytes (taken from the KDF output after the key).
    iv_len: int
    # Authentication tag size in bytes; None for non-AEAD modes.
    tag_len: int | None
    # True when an authentication tag follows the encrypted section.
    is_aead: bool
    101 
    102 
# ciphers that are actually used in key wrapping
# Keyed by OpenSSH cipher name. Loading rejects any cipher name not listed
# here; serializing with a password uses _DEFAULT_CIPHER from this table.
_SSH_CIPHERS: dict[bytes, _SSHCipher] = {
    b"aes256-ctr": _SSHCipher(
        alg=algorithms.AES,
        key_len=32,
        mode=modes.CTR,
        block_len=16,
        iv_len=16,
        tag_len=None,
        is_aead=False,
    ),
    b"aes256-cbc": _SSHCipher(
        alg=algorithms.AES,
        key_len=32,
        mode=modes.CBC,
        block_len=16,
        iv_len=16,
        tag_len=None,
        is_aead=False,
    ),
    # GCM uses a 12-byte nonce and a 16-byte tag stored after the
    # encrypted section.
    b"aes256-gcm@openssh.com": _SSHCipher(
        alg=algorithms.AES,
        key_len=32,
        mode=modes.GCM,
        block_len=16,
        iv_len=12,
        tag_len=16,
        is_aead=True,
    ),
}
    133 
# map local curve name to key type
# (cryptography curve ``.name`` -> SSH ECDSA key type name; curves that are
# not listed here are rejected by _ecdsa_key_type)
_ECDSA_KEY_TYPE = {
    "secp256r1": _ECDSA_NISTP256,
    "secp384r1": _ECDSA_NISTP384,
    "secp521r1": _ECDSA_NISTP521,
}
    140 
    141 
    142def _get_ssh_key_type(key: SSHPrivateKeyTypes | SSHPublicKeyTypes) -> bytes: 
    143    if isinstance(key, ec.EllipticCurvePrivateKey): 
    144        key_type = _ecdsa_key_type(key.public_key()) 
    145    elif isinstance(key, ec.EllipticCurvePublicKey): 
    146        key_type = _ecdsa_key_type(key) 
    147    elif isinstance(key, (rsa.RSAPrivateKey, rsa.RSAPublicKey)): 
    148        key_type = _SSH_RSA 
    149    elif isinstance(key, (dsa.DSAPrivateKey, dsa.DSAPublicKey)): 
    150        key_type = _SSH_DSA 
    151    elif isinstance( 
    152        key, (ed25519.Ed25519PrivateKey, ed25519.Ed25519PublicKey) 
    153    ): 
    154        key_type = _SSH_ED25519 
    155    else: 
    156        raise ValueError("Unsupported key type") 
    157 
    158    return key_type 
    159 
    160 
    161def _ecdsa_key_type(public_key: ec.EllipticCurvePublicKey) -> bytes: 
    162    """Return SSH key_type and curve_name for private key.""" 
    163    curve = public_key.curve 
    164    if curve.name not in _ECDSA_KEY_TYPE: 
    165        raise ValueError( 
    166            f"Unsupported curve for ssh private key: {curve.name!r}" 
    167        ) 
    168    return _ECDSA_KEY_TYPE[curve.name] 
    169 
    170 
    171def _ssh_pem_encode( 
    172    data: utils.Buffer, 
    173    prefix: bytes = _SK_START + b"\n", 
    174    suffix: bytes = _SK_END + b"\n", 
    175) -> bytes: 
    176    return b"".join([prefix, _base64_encode(data), suffix]) 
    177 
    178 
    179def _check_block_size(data: utils.Buffer, block_len: int) -> None: 
    180    """Require data to be full blocks""" 
    181    if not data or len(data) % block_len != 0: 
    182        raise ValueError("Corrupt data: missing padding") 
    183 
    184 
    185def _check_empty(data: utils.Buffer) -> None: 
    186    """All data should have been parsed.""" 
    187    if data: 
    188        raise ValueError("Corrupt data: unparsed data") 
    189 
    190 
    191def _init_cipher( 
    192    ciphername: bytes, 
    193    password: bytes | None, 
    194    salt: bytes, 
    195    rounds: int, 
    196) -> Cipher[modes.CBC | modes.CTR | modes.GCM]: 
    197    """Generate key + iv and return cipher.""" 
    198    if not password: 
    199        raise TypeError( 
    200            "Key is password-protected, but password was not provided." 
    201        ) 
    202 
    203    ciph = _SSH_CIPHERS[ciphername] 
    204    seed = _bcrypt_kdf( 
    205        password, salt, ciph.key_len + ciph.iv_len, rounds, True 
    206    ) 
    207    return Cipher( 
    208        ciph.alg(seed[: ciph.key_len]), 
    209        ciph.mode(seed[ciph.key_len :]), 
    210    ) 
    211 
    212 
    213def _get_u32(data: memoryview) -> tuple[int, memoryview]: 
    214    """Uint32""" 
    215    if len(data) < 4: 
    216        raise ValueError("Invalid data") 
    217    return int.from_bytes(data[:4], byteorder="big"), data[4:] 
    218 
    219 
    220def _get_u64(data: memoryview) -> tuple[int, memoryview]: 
    221    """Uint64""" 
    222    if len(data) < 8: 
    223        raise ValueError("Invalid data") 
    224    return int.from_bytes(data[:8], byteorder="big"), data[8:] 
    225 
    226 
    227def _get_sshstr(data: memoryview) -> tuple[memoryview, memoryview]: 
    228    """Bytes with u32 length prefix""" 
    229    n, data = _get_u32(data) 
    230    if n > len(data): 
    231        raise ValueError("Invalid data") 
    232    return data[:n], data[n:] 
    233 
    234 
    235def _get_mpint(data: memoryview) -> tuple[int, memoryview]: 
    236    """Big integer.""" 
    237    val, data = _get_sshstr(data) 
    238    if val and val[0] > 0x7F: 
    239        raise ValueError("Invalid data") 
    240    return int.from_bytes(val, "big"), data 
    241 
    242 
    243def _to_mpint(val: int) -> bytes: 
    244    """Storage format for signed bigint.""" 
    245    if val < 0: 
    246        raise ValueError("negative mpint not allowed") 
    247    if not val: 
    248        return b"" 
    249    nbytes = (val.bit_length() + 8) // 8 
    250    return utils.int_to_bytes(val, nbytes) 
    251 
    252 
    253class _FragList: 
    254    """Build recursive structure without data copy.""" 
    255 
    256    flist: list[utils.Buffer] 
    257 
    258    def __init__(self, init: list[utils.Buffer] | None = None) -> None: 
    259        self.flist = [] 
    260        if init: 
    261            self.flist.extend(init) 
    262 
    263    def put_raw(self, val: utils.Buffer) -> None: 
    264        """Add plain bytes""" 
    265        self.flist.append(val) 
    266 
    267    def put_u32(self, val: int) -> None: 
    268        """Big-endian uint32""" 
    269        self.flist.append(val.to_bytes(length=4, byteorder="big")) 
    270 
    271    def put_u64(self, val: int) -> None: 
    272        """Big-endian uint64""" 
    273        self.flist.append(val.to_bytes(length=8, byteorder="big")) 
    274 
    275    def put_sshstr(self, val: bytes | _FragList) -> None: 
    276        """Bytes prefixed with u32 length""" 
    277        if isinstance(val, (bytes, memoryview, bytearray)): 
    278            self.put_u32(len(val)) 
    279            self.flist.append(val) 
    280        else: 
    281            self.put_u32(val.size()) 
    282            self.flist.extend(val.flist) 
    283 
    284    def put_mpint(self, val: int) -> None: 
    285        """Big-endian bigint prefixed with u32 length""" 
    286        self.put_sshstr(_to_mpint(val)) 
    287 
    288    def size(self) -> int: 
    289        """Current number of bytes""" 
    290        return sum(map(len, self.flist)) 
    291 
    292    def render(self, dstbuf: memoryview, pos: int = 0) -> int: 
    293        """Write into bytearray""" 
    294        for frag in self.flist: 
    295            flen = len(frag) 
    296            start, pos = pos, pos + flen 
    297            dstbuf[start:pos] = frag 
    298        return pos 
    299 
    300    def tobytes(self) -> bytes: 
    301        """Return as bytes""" 
    302        buf = memoryview(bytearray(self.size())) 
    303        self.render(buf) 
    304        return buf.tobytes() 
    305 
    306 
    307class _SSHFormatRSA: 
    308    """Format for RSA keys. 
    309 
    310    Public: 
    311        mpint e, n 
    312    Private: 
    313        mpint n, e, d, iqmp, p, q 
    314    """ 
    315 
    316    def get_public( 
    317        self, data: memoryview 
    318    ) -> tuple[tuple[int, int], memoryview]: 
    319        """RSA public fields""" 
    320        e, data = _get_mpint(data) 
    321        n, data = _get_mpint(data) 
    322        return (e, n), data 
    323 
    324    def load_public( 
    325        self, data: memoryview 
    326    ) -> tuple[rsa.RSAPublicKey, memoryview]: 
    327        """Make RSA public key from data.""" 
    328        (e, n), data = self.get_public(data) 
    329        public_numbers = rsa.RSAPublicNumbers(e, n) 
    330        public_key = public_numbers.public_key() 
    331        return public_key, data 
    332 
    333    def load_private( 
    334        self, data: memoryview, pubfields, unsafe_skip_rsa_key_validation: bool 
    335    ) -> tuple[rsa.RSAPrivateKey, memoryview]: 
    336        """Make RSA private key from data.""" 
    337        n, data = _get_mpint(data) 
    338        e, data = _get_mpint(data) 
    339        d, data = _get_mpint(data) 
    340        iqmp, data = _get_mpint(data) 
    341        p, data = _get_mpint(data) 
    342        q, data = _get_mpint(data) 
    343 
    344        if (e, n) != pubfields: 
    345            raise ValueError("Corrupt data: rsa field mismatch") 
    346        dmp1 = rsa.rsa_crt_dmp1(d, p) 
    347        dmq1 = rsa.rsa_crt_dmq1(d, q) 
    348        public_numbers = rsa.RSAPublicNumbers(e, n) 
    349        private_numbers = rsa.RSAPrivateNumbers( 
    350            p, q, d, dmp1, dmq1, iqmp, public_numbers 
    351        ) 
    352        private_key = private_numbers.private_key( 
    353            unsafe_skip_rsa_key_validation=unsafe_skip_rsa_key_validation 
    354        ) 
    355        return private_key, data 
    356 
    357    def encode_public( 
    358        self, public_key: rsa.RSAPublicKey, f_pub: _FragList 
    359    ) -> None: 
    360        """Write RSA public key""" 
    361        pubn = public_key.public_numbers() 
    362        f_pub.put_mpint(pubn.e) 
    363        f_pub.put_mpint(pubn.n) 
    364 
    365    def encode_private( 
    366        self, private_key: rsa.RSAPrivateKey, f_priv: _FragList 
    367    ) -> None: 
    368        """Write RSA private key""" 
    369        private_numbers = private_key.private_numbers() 
    370        public_numbers = private_numbers.public_numbers 
    371 
    372        f_priv.put_mpint(public_numbers.n) 
    373        f_priv.put_mpint(public_numbers.e) 
    374 
    375        f_priv.put_mpint(private_numbers.d) 
    376        f_priv.put_mpint(private_numbers.iqmp) 
    377        f_priv.put_mpint(private_numbers.p) 
    378        f_priv.put_mpint(private_numbers.q) 
    379 
    380 
    381class _SSHFormatDSA: 
    382    """Format for DSA keys. 
    383 
    384    Public: 
    385        mpint p, q, g, y 
    386    Private: 
    387        mpint p, q, g, y, x 
    388    """ 
    389 
    390    def get_public(self, data: memoryview) -> tuple[tuple, memoryview]: 
    391        """DSA public fields""" 
    392        p, data = _get_mpint(data) 
    393        q, data = _get_mpint(data) 
    394        g, data = _get_mpint(data) 
    395        y, data = _get_mpint(data) 
    396        return (p, q, g, y), data 
    397 
    398    def load_public( 
    399        self, data: memoryview 
    400    ) -> tuple[dsa.DSAPublicKey, memoryview]: 
    401        """Make DSA public key from data.""" 
    402        (p, q, g, y), data = self.get_public(data) 
    403        parameter_numbers = dsa.DSAParameterNumbers(p, q, g) 
    404        public_numbers = dsa.DSAPublicNumbers(y, parameter_numbers) 
    405        self._validate(public_numbers) 
    406        public_key = public_numbers.public_key() 
    407        return public_key, data 
    408 
    409    def load_private( 
    410        self, data: memoryview, pubfields, unsafe_skip_rsa_key_validation: bool 
    411    ) -> tuple[dsa.DSAPrivateKey, memoryview]: 
    412        """Make DSA private key from data.""" 
    413        (p, q, g, y), data = self.get_public(data) 
    414        x, data = _get_mpint(data) 
    415 
    416        if (p, q, g, y) != pubfields: 
    417            raise ValueError("Corrupt data: dsa field mismatch") 
    418        parameter_numbers = dsa.DSAParameterNumbers(p, q, g) 
    419        public_numbers = dsa.DSAPublicNumbers(y, parameter_numbers) 
    420        self._validate(public_numbers) 
    421        private_numbers = dsa.DSAPrivateNumbers(x, public_numbers) 
    422        private_key = private_numbers.private_key() 
    423        return private_key, data 
    424 
    425    def encode_public( 
    426        self, public_key: dsa.DSAPublicKey, f_pub: _FragList 
    427    ) -> None: 
    428        """Write DSA public key""" 
    429        public_numbers = public_key.public_numbers() 
    430        parameter_numbers = public_numbers.parameter_numbers 
    431        self._validate(public_numbers) 
    432 
    433        f_pub.put_mpint(parameter_numbers.p) 
    434        f_pub.put_mpint(parameter_numbers.q) 
    435        f_pub.put_mpint(parameter_numbers.g) 
    436        f_pub.put_mpint(public_numbers.y) 
    437 
    438    def encode_private( 
    439        self, private_key: dsa.DSAPrivateKey, f_priv: _FragList 
    440    ) -> None: 
    441        """Write DSA private key""" 
    442        self.encode_public(private_key.public_key(), f_priv) 
    443        f_priv.put_mpint(private_key.private_numbers().x) 
    444 
    445    def _validate(self, public_numbers: dsa.DSAPublicNumbers) -> None: 
    446        parameter_numbers = public_numbers.parameter_numbers 
    447        if parameter_numbers.p.bit_length() != 1024: 
    448            raise ValueError("SSH supports only 1024 bit DSA keys") 
    449 
    450 
    451class _SSHFormatECDSA: 
    452    """Format for ECDSA keys. 
    453 
    454    Public: 
    455        str curve 
    456        bytes point 
    457    Private: 
    458        str curve 
    459        bytes point 
    460        mpint secret 
    461    """ 
    462 
    463    def __init__(self, ssh_curve_name: bytes, curve: ec.EllipticCurve): 
    464        self.ssh_curve_name = ssh_curve_name 
    465        self.curve = curve 
    466 
    467    def get_public( 
    468        self, data: memoryview 
    469    ) -> tuple[tuple[memoryview, memoryview], memoryview]: 
    470        """ECDSA public fields""" 
    471        curve, data = _get_sshstr(data) 
    472        point, data = _get_sshstr(data) 
    473        if curve != self.ssh_curve_name: 
    474            raise ValueError("Curve name mismatch") 
    475        if point[0] != 4: 
    476            raise NotImplementedError("Need uncompressed point") 
    477        return (curve, point), data 
    478 
    479    def load_public( 
    480        self, data: memoryview 
    481    ) -> tuple[ec.EllipticCurvePublicKey, memoryview]: 
    482        """Make ECDSA public key from data.""" 
    483        (_, point), data = self.get_public(data) 
    484        public_key = ec.EllipticCurvePublicKey.from_encoded_point( 
    485            self.curve, point.tobytes() 
    486        ) 
    487        return public_key, data 
    488 
    489    def load_private( 
    490        self, data: memoryview, pubfields, unsafe_skip_rsa_key_validation: bool 
    491    ) -> tuple[ec.EllipticCurvePrivateKey, memoryview]: 
    492        """Make ECDSA private key from data.""" 
    493        (curve_name, point), data = self.get_public(data) 
    494        secret, data = _get_mpint(data) 
    495 
    496        if (curve_name, point) != pubfields: 
    497            raise ValueError("Corrupt data: ecdsa field mismatch") 
    498        private_key = ec.derive_private_key(secret, self.curve) 
    499        return private_key, data 
    500 
    501    def encode_public( 
    502        self, public_key: ec.EllipticCurvePublicKey, f_pub: _FragList 
    503    ) -> None: 
    504        """Write ECDSA public key""" 
    505        point = public_key.public_bytes( 
    506            Encoding.X962, PublicFormat.UncompressedPoint 
    507        ) 
    508        f_pub.put_sshstr(self.ssh_curve_name) 
    509        f_pub.put_sshstr(point) 
    510 
    511    def encode_private( 
    512        self, private_key: ec.EllipticCurvePrivateKey, f_priv: _FragList 
    513    ) -> None: 
    514        """Write ECDSA private key""" 
    515        public_key = private_key.public_key() 
    516        private_numbers = private_key.private_numbers() 
    517 
    518        self.encode_public(public_key, f_priv) 
    519        f_priv.put_mpint(private_numbers.private_value) 
    520 
    521 
    522class _SSHFormatEd25519: 
    523    """Format for Ed25519 keys. 
    524 
    525    Public: 
    526        bytes point 
    527    Private: 
    528        bytes point 
    529        bytes secret_and_point 
    530    """ 
    531 
    532    def get_public( 
    533        self, data: memoryview 
    534    ) -> tuple[tuple[memoryview], memoryview]: 
    535        """Ed25519 public fields""" 
    536        point, data = _get_sshstr(data) 
    537        return (point,), data 
    538 
    539    def load_public( 
    540        self, data: memoryview 
    541    ) -> tuple[ed25519.Ed25519PublicKey, memoryview]: 
    542        """Make Ed25519 public key from data.""" 
    543        (point,), data = self.get_public(data) 
    544        public_key = ed25519.Ed25519PublicKey.from_public_bytes( 
    545            point.tobytes() 
    546        ) 
    547        return public_key, data 
    548 
    549    def load_private( 
    550        self, data: memoryview, pubfields, unsafe_skip_rsa_key_validation: bool 
    551    ) -> tuple[ed25519.Ed25519PrivateKey, memoryview]: 
    552        """Make Ed25519 private key from data.""" 
    553        (point,), data = self.get_public(data) 
    554        keypair, data = _get_sshstr(data) 
    555 
    556        secret = keypair[:32] 
    557        point2 = keypair[32:] 
    558        if point != point2 or (point,) != pubfields: 
    559            raise ValueError("Corrupt data: ed25519 field mismatch") 
    560        private_key = ed25519.Ed25519PrivateKey.from_private_bytes(secret) 
    561        return private_key, data 
    562 
    563    def encode_public( 
    564        self, public_key: ed25519.Ed25519PublicKey, f_pub: _FragList 
    565    ) -> None: 
    566        """Write Ed25519 public key""" 
    567        raw_public_key = public_key.public_bytes( 
    568            Encoding.Raw, PublicFormat.Raw 
    569        ) 
    570        f_pub.put_sshstr(raw_public_key) 
    571 
    572    def encode_private( 
    573        self, private_key: ed25519.Ed25519PrivateKey, f_priv: _FragList 
    574    ) -> None: 
    575        """Write Ed25519 private key""" 
    576        public_key = private_key.public_key() 
    577        raw_private_key = private_key.private_bytes( 
    578            Encoding.Raw, PrivateFormat.Raw, NoEncryption() 
    579        ) 
    580        raw_public_key = public_key.public_bytes( 
    581            Encoding.Raw, PublicFormat.Raw 
    582        ) 
    583        f_keypair = _FragList([raw_private_key, raw_public_key]) 
    584 
    585        self.encode_public(public_key, f_priv) 
    586        f_priv.put_sshstr(f_keypair) 
    587 
    588 
    589def load_application(data) -> tuple[memoryview, memoryview]: 
    590    """ 
    591    U2F application strings 
    592    """ 
    593    application, data = _get_sshstr(data) 
    594    if not application.tobytes().startswith(b"ssh:"): 
    595        raise ValueError( 
    596            "U2F application string does not start with b'ssh:' " 
    597            f"({application})" 
    598        ) 
    599    return application, data 
    600 
    601 
    602class _SSHFormatSKEd25519: 
    603    """ 
    604    The format of a sk-ssh-ed25519@openssh.com public key is: 
    605 
    606        string          "sk-ssh-ed25519@openssh.com" 
    607        string          public key 
    608        string          application (user-specified, but typically "ssh:") 
    609    """ 
    610 
    611    def load_public( 
    612        self, data: memoryview 
    613    ) -> tuple[ed25519.Ed25519PublicKey, memoryview]: 
    614        """Make Ed25519 public key from data.""" 
    615        public_key, data = _lookup_kformat(_SSH_ED25519).load_public(data) 
    616        _, data = load_application(data) 
    617        return public_key, data 
    618 
    619    def get_public(self, data: memoryview) -> typing.NoReturn: 
    620        # Confusingly `get_public` is an entry point used by private key 
    621        # loading. 
    622        raise UnsupportedAlgorithm( 
    623            "sk-ssh-ed25519 private keys cannot be loaded" 
    624        ) 
    625 
    626 
    627class _SSHFormatSKECDSA: 
    628    """ 
    629    The format of a sk-ecdsa-sha2-nistp256@openssh.com public key is: 
    630 
    631        string          "sk-ecdsa-sha2-nistp256@openssh.com" 
    632        string          curve name 
    633        ec_point        Q 
    634        string          application (user-specified, but typically "ssh:") 
    635    """ 
    636 
    637    def load_public( 
    638        self, data: memoryview 
    639    ) -> tuple[ec.EllipticCurvePublicKey, memoryview]: 
    640        """Make ECDSA public key from data.""" 
    641        public_key, data = _lookup_kformat(_ECDSA_NISTP256).load_public(data) 
    642        _, data = load_application(data) 
    643        return public_key, data 
    644 
    645    def get_public(self, data: memoryview) -> typing.NoReturn: 
    646        # Confusingly `get_public` is an entry point used by private key 
    647        # loading. 
    648        raise UnsupportedAlgorithm( 
    649            "sk-ecdsa-sha2-nistp256 private keys cannot be loaded" 
    650        ) 
    651 
    652 
# Wire-format handlers keyed by SSH key type name. Each ECDSA handler is
# bound to the SSH curve name that must appear inside its key blobs. The
# sk-* handlers only support public keys: their get_public (used on the
# private-key path) raises UnsupportedAlgorithm.
_KEY_FORMATS = {
    _SSH_RSA: _SSHFormatRSA(),
    _SSH_DSA: _SSHFormatDSA(),
    _SSH_ED25519: _SSHFormatEd25519(),
    _ECDSA_NISTP256: _SSHFormatECDSA(b"nistp256", ec.SECP256R1()),
    _ECDSA_NISTP384: _SSHFormatECDSA(b"nistp384", ec.SECP384R1()),
    _ECDSA_NISTP521: _SSHFormatECDSA(b"nistp521", ec.SECP521R1()),
    _SK_SSH_ED25519: _SSHFormatSKEd25519(),
    _SK_SSH_ECDSA_NISTP256: _SSHFormatSKECDSA(),
}
    663 
    664 
    665def _lookup_kformat(key_type: utils.Buffer): 
    666    """Return valid format or throw error""" 
    667    if not isinstance(key_type, bytes): 
    668        key_type = memoryview(key_type).tobytes() 
    669    if key_type in _KEY_FORMATS: 
    670        return _KEY_FORMATS[key_type] 
    671    raise UnsupportedAlgorithm(f"Unsupported key type: {key_type!r}") 
    672 
    673 
# Private key classes that can be (de)serialized in the OpenSSH format.
# This is a runtime assignment, not an annotation, so
# `from __future__ import annotations` does not defer it; typing.Union is
# used instead of `X | Y`, which needs Python 3.10+ at runtime.
SSHPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    dsa.DSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]
    680 
    681 
    682def load_ssh_private_key( 
    683    data: utils.Buffer, 
    684    password: bytes | None, 
    685    backend: typing.Any = None, 
    686    *, 
    687    unsafe_skip_rsa_key_validation: bool = False, 
    688) -> SSHPrivateKeyTypes: 
    689    """Load private key from OpenSSH custom encoding.""" 
    690    utils._check_byteslike("data", data) 
    691    if password is not None: 
    692        utils._check_bytes("password", password) 
    693 
    694    m = _PEM_RC.search(data) 
    695    if not m: 
    696        raise ValueError("Not OpenSSH private key format") 
    697    p1 = m.start(1) 
    698    p2 = m.end(1) 
    699    data = binascii.a2b_base64(memoryview(data)[p1:p2]) 
    700    if not data.startswith(_SK_MAGIC): 
    701        raise ValueError("Not OpenSSH private key format") 
    702    data = memoryview(data)[len(_SK_MAGIC) :] 
    703 
    704    # parse header 
    705    ciphername, data = _get_sshstr(data) 
    706    kdfname, data = _get_sshstr(data) 
    707    kdfoptions, data = _get_sshstr(data) 
    708    nkeys, data = _get_u32(data) 
    709    if nkeys != 1: 
    710        raise ValueError("Only one key supported") 
    711 
    712    # load public key data 
    713    pubdata, data = _get_sshstr(data) 
    714    pub_key_type, pubdata = _get_sshstr(pubdata) 
    715    kformat = _lookup_kformat(pub_key_type) 
    716    pubfields, pubdata = kformat.get_public(pubdata) 
    717    _check_empty(pubdata) 
    718 
    719    if ciphername != _NONE or kdfname != _NONE: 
    720        ciphername_bytes = ciphername.tobytes() 
    721        if ciphername_bytes not in _SSH_CIPHERS: 
    722            raise UnsupportedAlgorithm( 
    723                f"Unsupported cipher: {ciphername_bytes!r}" 
    724            ) 
    725        if kdfname != _BCRYPT: 
    726            raise UnsupportedAlgorithm(f"Unsupported KDF: {kdfname!r}") 
    727        blklen = _SSH_CIPHERS[ciphername_bytes].block_len 
    728        tag_len = _SSH_CIPHERS[ciphername_bytes].tag_len 
    729        # load secret data 
    730        edata, data = _get_sshstr(data) 
    731        # see https://bugzilla.mindrot.org/show_bug.cgi?id=3553 for 
    732        # information about how OpenSSH handles AEAD tags 
    733        if _SSH_CIPHERS[ciphername_bytes].is_aead: 
    734            tag = bytes(data) 
    735            if len(tag) != tag_len: 
    736                raise ValueError("Corrupt data: invalid tag length for cipher") 
    737        else: 
    738            _check_empty(data) 
    739        _check_block_size(edata, blklen) 
    740        salt, kbuf = _get_sshstr(kdfoptions) 
    741        rounds, kbuf = _get_u32(kbuf) 
    742        _check_empty(kbuf) 
    743        ciph = _init_cipher(ciphername_bytes, password, salt.tobytes(), rounds) 
    744        dec = ciph.decryptor() 
    745        edata = memoryview(dec.update(edata)) 
    746        if _SSH_CIPHERS[ciphername_bytes].is_aead: 
    747            assert isinstance(dec, AEADDecryptionContext) 
    748            _check_empty(dec.finalize_with_tag(tag)) 
    749        else: 
    750            # _check_block_size requires data to be a full block so there 
    751            # should be no output from finalize 
    752            _check_empty(dec.finalize()) 
    753    else: 
    754        if password: 
    755            raise TypeError( 
    756                "Password was given but private key is not encrypted." 
    757            ) 
    758        # load secret data 
    759        edata, data = _get_sshstr(data) 
    760        _check_empty(data) 
    761        blklen = 8 
    762        _check_block_size(edata, blklen) 
    763    ck1, edata = _get_u32(edata) 
    764    ck2, edata = _get_u32(edata) 
    765    if ck1 != ck2: 
    766        raise ValueError("Corrupt data: broken checksum") 
    767 
    768    # load per-key struct 
    769    key_type, edata = _get_sshstr(edata) 
    770    if key_type != pub_key_type: 
    771        raise ValueError("Corrupt data: key type mismatch") 
    772    private_key, edata = kformat.load_private( 
    773        edata, 
    774        pubfields, 
    775        unsafe_skip_rsa_key_validation=unsafe_skip_rsa_key_validation, 
    776    ) 
    777    # We don't use the comment 
    778    _, edata = _get_sshstr(edata) 
    779 
    780    # yes, SSH does padding check *after* all other parsing is done. 
    781    # need to follow as it writes zero-byte padding too. 
    782    if edata != _PADDING[: len(edata)]: 
    783        raise ValueError("Corrupt data: invalid padding") 
    784 
    785    if isinstance(private_key, dsa.DSAPrivateKey): 
    786        warnings.warn( 
    787            "SSH DSA keys are deprecated and will be removed in a future " 
    788            "release.", 
    789            utils.DeprecatedIn40, 
    790            stacklevel=2, 
    791        ) 
    792 
    793    return private_key 
    794 
    795 
def _serialize_ssh_private_key(
    private_key: SSHPrivateKeyTypes,
    password: bytes,
    encryption_algorithm: KeySerializationEncryption,
) -> bytes:
    """Serialize private key with OpenSSH custom encoding.

    Builds the OpenSSH private-key container in this order: magic, cipher
    name, KDF name, KDF options, key count, public key blob and private
    section. It returns the result passed through ``_ssh_pem_encode``.

    Args:
        private_key: the key to serialize. DSA keys are still accepted but
            trigger a deprecation warning.
        password: passphrase, validated as ``bytes`` by
            ``utils._check_bytes``. If it is empty, the private section is
            written unencrypted and the cipher and KDF names are ``_NONE``.
            Otherwise the section is encrypted with ``_DEFAULT_CIPHER``. The
            cipher is set up by ``_init_cipher`` from the password, a fresh
            16-byte random salt and the bcrypt round count.
        encryption_algorithm: consulted only for a custom KDF round count.
            If it is a ``_KeySerializationEncryption`` with ``_kdf_rounds``
            set, that value replaces ``_DEFAULT_ROUNDS``.

    Returns:
        The encoded key, as produced by ``_ssh_pem_encode``.
    """
    utils._check_bytes("password", password)
    if isinstance(private_key, dsa.DSAPrivateKey):
        # NOTE(review): stacklevel=4 presumably attributes the warning to
        # user code beyond the public serialization wrappers -- confirm
        # against the actual call chain.
        warnings.warn(
            "SSH DSA key support is deprecated and will be "
            "removed in a future release",
            utils.DeprecatedIn40,
            stacklevel=4,
        )

    key_type = _get_ssh_key_type(private_key)
    kformat = _lookup_kformat(key_type)

    # setup parameters
    f_kdfoptions = _FragList()
    if password:
        ciphername = _DEFAULT_CIPHER
        blklen = _SSH_CIPHERS[ciphername].block_len
        kdfname = _BCRYPT
        rounds = _DEFAULT_ROUNDS
        # Honour a caller-requested KDF round count when one is configured.
        if (
            isinstance(encryption_algorithm, _KeySerializationEncryption)
            and encryption_algorithm._kdf_rounds is not None
        ):
            rounds = encryption_algorithm._kdf_rounds
        # KDF options blob: the salt as a string, then the rounds as a u32.
        salt = os.urandom(16)
        f_kdfoptions.put_sshstr(salt)
        f_kdfoptions.put_u32(rounds)
        ciph = _init_cipher(ciphername, password, salt, rounds)
    else:
        # Unencrypted: the KDF options stay empty, and the private section
        # is padded to an 8-byte boundary.
        ciphername = kdfname = _NONE
        blklen = 8
        ciph = None
    nkeys = 1
    # Random check value, written twice at the start of the private section.
    # Presumably a loader compares the two copies after decryption to detect
    # a wrong passphrase -- verify against the loading code.
    checkval = os.urandom(4)
    comment = b""

    # encode public and private parts together
    f_public_key = _FragList()
    f_public_key.put_sshstr(key_type)
    kformat.encode_public(private_key.public_key(), f_public_key)

    f_secrets = _FragList([checkval, checkval])
    f_secrets.put_sshstr(key_type)
    kformat.encode_private(private_key, f_secrets)
    f_secrets.put_sshstr(comment)
    # Pad the private section to a multiple of blklen. The count
    # blklen - size % blklen is between 1 and blklen, so a section that is
    # already aligned still gets a full block of padding. This assumes
    # _PADDING holds at least blklen bytes.
    f_secrets.put_raw(_PADDING[: blklen - (f_secrets.size() % blklen)])

    # top-level structure
    f_main = _FragList()
    f_main.put_raw(_SK_MAGIC)
    f_main.put_sshstr(ciphername)
    f_main.put_sshstr(kdfname)
    f_main.put_sshstr(f_kdfoptions)
    f_main.put_u32(nkeys)
    f_main.put_sshstr(f_public_key)
    f_main.put_sshstr(f_secrets)

    # copy result into bytearray. f_secrets is the last field of f_main, so
    # its contents occupy the final slen bytes, [ofs, mlen). The extra
    # blklen bytes of slack give update_into() the output room it needs
    # beyond the input length.
    slen = f_secrets.size()
    mlen = f_main.size()
    buf = memoryview(bytearray(mlen + blklen))
    f_main.render(buf)
    ofs = mlen - slen

    # encrypt in-place; only the private-section bytes [ofs, mlen) are
    # encrypted, and everything before ofs stays as rendered
    if ciph is not None:
        ciph.encryptor().update_into(buf[ofs:mlen], buf[ofs:])

    # Drop the slack bytes before encoding.
    return _ssh_pem_encode(buf[:mlen])
    871 
    872 
# Public key types this module can load from, or serialize to, the one-line
# OpenSSH public key format. DSA is still included here, but using it emits
# a deprecation warning.
SSHPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    dsa.DSAPublicKey,
    ed25519.Ed25519PublicKey,
]

# Public key types that SSHCertificateBuilder accepts as a certificate's
# subject key. DSA is deliberately absent.
SSHCertPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    ed25519.Ed25519PublicKey,
]
    885 
    886 
    887class SSHCertificateType(enum.Enum): 
    888    USER = 1 
    889    HOST = 2 
    890 
    891 
    892class SSHCertificate: 
    893    def __init__( 
    894        self, 
    895        _nonce: memoryview, 
    896        _public_key: SSHPublicKeyTypes, 
    897        _serial: int, 
    898        _cctype: int, 
    899        _key_id: memoryview, 
    900        _valid_principals: list[bytes], 
    901        _valid_after: int, 
    902        _valid_before: int, 
    903        _critical_options: dict[bytes, bytes], 
    904        _extensions: dict[bytes, bytes], 
    905        _sig_type: memoryview, 
    906        _sig_key: memoryview, 
    907        _inner_sig_type: memoryview, 
    908        _signature: memoryview, 
    909        _tbs_cert_body: memoryview, 
    910        _cert_key_type: bytes, 
    911        _cert_body: memoryview, 
    912    ): 
    913        self._nonce = _nonce 
    914        self._public_key = _public_key 
    915        self._serial = _serial 
    916        try: 
    917            self._type = SSHCertificateType(_cctype) 
    918        except ValueError: 
    919            raise ValueError("Invalid certificate type") 
    920        self._key_id = _key_id 
    921        self._valid_principals = _valid_principals 
    922        self._valid_after = _valid_after 
    923        self._valid_before = _valid_before 
    924        self._critical_options = _critical_options 
    925        self._extensions = _extensions 
    926        self._sig_type = _sig_type 
    927        self._sig_key = _sig_key 
    928        self._inner_sig_type = _inner_sig_type 
    929        self._signature = _signature 
    930        self._cert_key_type = _cert_key_type 
    931        self._cert_body = _cert_body 
    932        self._tbs_cert_body = _tbs_cert_body 
    933 
    934    @property 
    935    def nonce(self) -> bytes: 
    936        return bytes(self._nonce) 
    937 
    938    def public_key(self) -> SSHCertPublicKeyTypes: 
    939        # make mypy happy until we remove DSA support entirely and 
    940        # the underlying union won't have a disallowed type 
    941        return typing.cast(SSHCertPublicKeyTypes, self._public_key) 
    942 
    943    @property 
    944    def serial(self) -> int: 
    945        return self._serial 
    946 
    947    @property 
    948    def type(self) -> SSHCertificateType: 
    949        return self._type 
    950 
    951    @property 
    952    def key_id(self) -> bytes: 
    953        return bytes(self._key_id) 
    954 
    955    @property 
    956    def valid_principals(self) -> list[bytes]: 
    957        return self._valid_principals 
    958 
    959    @property 
    960    def valid_before(self) -> int: 
    961        return self._valid_before 
    962 
    963    @property 
    964    def valid_after(self) -> int: 
    965        return self._valid_after 
    966 
    967    @property 
    968    def critical_options(self) -> dict[bytes, bytes]: 
    969        return self._critical_options 
    970 
    971    @property 
    972    def extensions(self) -> dict[bytes, bytes]: 
    973        return self._extensions 
    974 
    975    def signature_key(self) -> SSHCertPublicKeyTypes: 
    976        sigformat = _lookup_kformat(self._sig_type) 
    977        signature_key, sigkey_rest = sigformat.load_public(self._sig_key) 
    978        _check_empty(sigkey_rest) 
    979        return signature_key 
    980 
    981    def public_bytes(self) -> bytes: 
    982        return ( 
    983            bytes(self._cert_key_type) 
    984            + b" " 
    985            + binascii.b2a_base64(bytes(self._cert_body), newline=False) 
    986        ) 
    987 
    988    def verify_cert_signature(self) -> None: 
    989        signature_key = self.signature_key() 
    990        if isinstance(signature_key, ed25519.Ed25519PublicKey): 
    991            signature_key.verify( 
    992                bytes(self._signature), bytes(self._tbs_cert_body) 
    993            ) 
    994        elif isinstance(signature_key, ec.EllipticCurvePublicKey): 
    995            # The signature is encoded as a pair of big-endian integers 
    996            r, data = _get_mpint(self._signature) 
    997            s, data = _get_mpint(data) 
    998            _check_empty(data) 
    999            computed_sig = asym_utils.encode_dss_signature(r, s) 
    1000            hash_alg = _get_ec_hash_alg(signature_key.curve) 
    1001            signature_key.verify( 
    1002                computed_sig, bytes(self._tbs_cert_body), ec.ECDSA(hash_alg) 
    1003            ) 
    1004        else: 
    1005            assert isinstance(signature_key, rsa.RSAPublicKey) 
    1006            if self._inner_sig_type == _SSH_RSA: 
    1007                hash_alg = hashes.SHA1() 
    1008            elif self._inner_sig_type == _SSH_RSA_SHA256: 
    1009                hash_alg = hashes.SHA256() 
    1010            else: 
    1011                assert self._inner_sig_type == _SSH_RSA_SHA512 
    1012                hash_alg = hashes.SHA512() 
    1013            signature_key.verify( 
    1014                bytes(self._signature), 
    1015                bytes(self._tbs_cert_body), 
    1016                padding.PKCS1v15(), 
    1017                hash_alg, 
    1018            ) 
    1019 
    1020 
    1021def _get_ec_hash_alg(curve: ec.EllipticCurve) -> hashes.HashAlgorithm: 
    1022    if isinstance(curve, ec.SECP256R1): 
    1023        return hashes.SHA256() 
    1024    elif isinstance(curve, ec.SECP384R1): 
    1025        return hashes.SHA384() 
    1026    else: 
    1027        assert isinstance(curve, ec.SECP521R1) 
    1028        return hashes.SHA512() 
    1029 
    1030 
    1031def _load_ssh_public_identity( 
    1032    data: utils.Buffer, 
    1033    _legacy_dsa_allowed=False, 
    1034) -> SSHCertificate | SSHPublicKeyTypes: 
    1035    utils._check_byteslike("data", data) 
    1036 
    1037    m = _SSH_PUBKEY_RC.match(data) 
    1038    if not m: 
    1039        raise ValueError("Invalid line format") 
    1040    key_type = orig_key_type = m.group(1) 
    1041    key_body = m.group(2) 
    1042    with_cert = False 
    1043    if key_type.endswith(_CERT_SUFFIX): 
    1044        with_cert = True 
    1045        key_type = key_type[: -len(_CERT_SUFFIX)] 
    1046    if key_type == _SSH_DSA and not _legacy_dsa_allowed: 
    1047        raise UnsupportedAlgorithm( 
    1048            "DSA keys aren't supported in SSH certificates" 
    1049        ) 
    1050    kformat = _lookup_kformat(key_type) 
    1051 
    1052    try: 
    1053        rest = memoryview(binascii.a2b_base64(key_body)) 
    1054    except (TypeError, binascii.Error): 
    1055        raise ValueError("Invalid format") 
    1056 
    1057    if with_cert: 
    1058        cert_body = rest 
    1059    inner_key_type, rest = _get_sshstr(rest) 
    1060    if inner_key_type != orig_key_type: 
    1061        raise ValueError("Invalid key format") 
    1062    if with_cert: 
    1063        nonce, rest = _get_sshstr(rest) 
    1064    public_key, rest = kformat.load_public(rest) 
    1065    if with_cert: 
    1066        serial, rest = _get_u64(rest) 
    1067        cctype, rest = _get_u32(rest) 
    1068        key_id, rest = _get_sshstr(rest) 
    1069        principals, rest = _get_sshstr(rest) 
    1070        valid_principals = [] 
    1071        while principals: 
    1072            principal, principals = _get_sshstr(principals) 
    1073            valid_principals.append(bytes(principal)) 
    1074        valid_after, rest = _get_u64(rest) 
    1075        valid_before, rest = _get_u64(rest) 
    1076        crit_options, rest = _get_sshstr(rest) 
    1077        critical_options = _parse_exts_opts(crit_options) 
    1078        exts, rest = _get_sshstr(rest) 
    1079        extensions = _parse_exts_opts(exts) 
    1080        # Get the reserved field, which is unused. 
    1081        _, rest = _get_sshstr(rest) 
    1082        sig_key_raw, rest = _get_sshstr(rest) 
    1083        sig_type, sig_key = _get_sshstr(sig_key_raw) 
    1084        if sig_type == _SSH_DSA and not _legacy_dsa_allowed: 
    1085            raise UnsupportedAlgorithm( 
    1086                "DSA signatures aren't supported in SSH certificates" 
    1087            ) 
    1088        # Get the entire cert body and subtract the signature 
    1089        tbs_cert_body = cert_body[: -len(rest)] 
    1090        signature_raw, rest = _get_sshstr(rest) 
    1091        _check_empty(rest) 
    1092        inner_sig_type, sig_rest = _get_sshstr(signature_raw) 
    1093        # RSA certs can have multiple algorithm types 
    1094        if ( 
    1095            sig_type == _SSH_RSA 
    1096            and inner_sig_type 
    1097            not in [_SSH_RSA_SHA256, _SSH_RSA_SHA512, _SSH_RSA] 
    1098        ) or (sig_type != _SSH_RSA and inner_sig_type != sig_type): 
    1099            raise ValueError("Signature key type does not match") 
    1100        signature, sig_rest = _get_sshstr(sig_rest) 
    1101        _check_empty(sig_rest) 
    1102        return SSHCertificate( 
    1103            nonce, 
    1104            public_key, 
    1105            serial, 
    1106            cctype, 
    1107            key_id, 
    1108            valid_principals, 
    1109            valid_after, 
    1110            valid_before, 
    1111            critical_options, 
    1112            extensions, 
    1113            sig_type, 
    1114            sig_key, 
    1115            inner_sig_type, 
    1116            signature, 
    1117            tbs_cert_body, 
    1118            orig_key_type, 
    1119            cert_body, 
    1120        ) 
    1121    else: 
    1122        _check_empty(rest) 
    1123        return public_key 
    1124 
    1125 
    1126def load_ssh_public_identity( 
    1127    data: utils.Buffer, 
    1128) -> SSHCertificate | SSHPublicKeyTypes: 
    1129    return _load_ssh_public_identity(data) 
    1130 
    1131 
    1132def _parse_exts_opts(exts_opts: memoryview) -> dict[bytes, bytes]: 
    1133    result: dict[bytes, bytes] = {} 
    1134    last_name = None 
    1135    while exts_opts: 
    1136        name, exts_opts = _get_sshstr(exts_opts) 
    1137        bname: bytes = bytes(name) 
    1138        if bname in result: 
    1139            raise ValueError("Duplicate name") 
    1140        if last_name is not None and bname < last_name: 
    1141            raise ValueError("Fields not lexically sorted") 
    1142        value, exts_opts = _get_sshstr(exts_opts) 
    1143        if len(value) > 0: 
    1144            value, extra = _get_sshstr(value) 
    1145            if len(extra) > 0: 
    1146                raise ValueError("Unexpected extra data after value") 
    1147        result[bname] = bytes(value) 
    1148        last_name = bname 
    1149    return result 
    1150 
    1151 
    1152def ssh_key_fingerprint( 
    1153    key: SSHPublicKeyTypes, 
    1154    hash_algorithm: hashes.MD5 | hashes.SHA256, 
    1155) -> bytes: 
    1156    if not isinstance(hash_algorithm, (hashes.MD5, hashes.SHA256)): 
    1157        raise TypeError("hash_algorithm must be either MD5 or SHA256") 
    1158 
    1159    key_type = _get_ssh_key_type(key) 
    1160    kformat = _lookup_kformat(key_type) 
    1161 
    1162    f_pub = _FragList() 
    1163    f_pub.put_sshstr(key_type) 
    1164    kformat.encode_public(key, f_pub) 
    1165 
    1166    ssh_binary_data = f_pub.tobytes() 
    1167 
    1168    # Hash the binary data 
    1169    hash_obj = hashes.Hash(hash_algorithm) 
    1170    hash_obj.update(ssh_binary_data) 
    1171    return hash_obj.finalize() 
    1172 
    1173 
    1174def load_ssh_public_key( 
    1175    data: utils.Buffer, backend: typing.Any = None 
    1176) -> SSHPublicKeyTypes: 
    1177    cert_or_key = _load_ssh_public_identity(data, _legacy_dsa_allowed=True) 
    1178    public_key: SSHPublicKeyTypes 
    1179    if isinstance(cert_or_key, SSHCertificate): 
    1180        public_key = cert_or_key.public_key() 
    1181    else: 
    1182        public_key = cert_or_key 
    1183 
    1184    if isinstance(public_key, dsa.DSAPublicKey): 
    1185        warnings.warn( 
    1186            "SSH DSA keys are deprecated and will be removed in a future " 
    1187            "release.", 
    1188            utils.DeprecatedIn40, 
    1189            stacklevel=2, 
    1190        ) 
    1191    return public_key 
    1192 
    1193 
    1194def serialize_ssh_public_key(public_key: SSHPublicKeyTypes) -> bytes: 
    1195    """One-line public key format for OpenSSH""" 
    1196    if isinstance(public_key, dsa.DSAPublicKey): 
    1197        warnings.warn( 
    1198            "SSH DSA key support is deprecated and will be " 
    1199            "removed in a future release", 
    1200            utils.DeprecatedIn40, 
    1201            stacklevel=4, 
    1202        ) 
    1203    key_type = _get_ssh_key_type(public_key) 
    1204    kformat = _lookup_kformat(key_type) 
    1205 
    1206    f_pub = _FragList() 
    1207    f_pub.put_sshstr(key_type) 
    1208    kformat.encode_public(public_key, f_pub) 
    1209 
    1210    pub = binascii.b2a_base64(f_pub.tobytes()).strip() 
    1211    return b"".join([key_type, b" ", pub]) 
    1212 
    1213 
# Private key types that SSHCertificateBuilder.sign accepts as the signing
# (CA) key. DSA is not among them.
SSHCertPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]


# This is an undocumented limit enforced in the openssh codebase for sshd and
# ssh-keygen, but it is undefined in the ssh certificates spec.
# SSHCertificateBuilder.valid_principals rejects lists longer than this; a
# list of exactly this length is accepted.
_SSHKEY_CERT_MAX_PRINCIPALS = 256
    1224 
    1225 
    1226class SSHCertificateBuilder: 
    def __init__(
        self,
        _public_key: SSHCertPublicKeyTypes | None = None,
        _serial: int | None = None,
        _type: SSHCertificateType | None = None,
        _key_id: bytes | None = None,
        _valid_principals: list[bytes] | None = None,
        _valid_for_all_principals: bool = False,
        _valid_before: int | None = None,
        _valid_after: int | None = None,
        _critical_options: list[tuple[bytes, bytes]] | None = None,
        _extensions: list[tuple[bytes, bytes]] | None = None,
    ):
        """Create a builder holding the given (private) field values.

        Callers normally start from ``SSHCertificateBuilder()`` and chain the
        setter methods, each of which returns a new builder. ``None`` for a
        list argument means an empty list.
        """
        self._public_key = _public_key
        self._serial = _serial
        self._type = _type
        self._key_id = _key_id
        # Copy list arguments, and avoid mutable default values, so that no
        # two builders (and no caller) ever share a list object. Setters
        # pass their lists straight through to new builders, and ``sign``
        # sorts the option lists in place.
        self._valid_principals = (
            [] if _valid_principals is None else list(_valid_principals)
        )
        self._valid_for_all_principals = _valid_for_all_principals
        self._valid_before = _valid_before
        self._valid_after = _valid_after
        self._critical_options = (
            [] if _critical_options is None else list(_critical_options)
        )
        self._extensions = [] if _extensions is None else list(_extensions)
    1250 
    def public_key(
        self, public_key: SSHCertPublicKeyTypes
    ) -> SSHCertificateBuilder:
        """Return a copy of this builder that certifies ``public_key``.

        Raises:
            TypeError: if the key is not an EC, RSA or Ed25519 public key.
            ValueError: if a public key has already been set.
        """
        supported_key_types = (
            ec.EllipticCurvePublicKey,
            rsa.RSAPublicKey,
            ed25519.Ed25519PublicKey,
        )
        if not isinstance(public_key, supported_key_types):
            raise TypeError("Unsupported key type")
        if self._public_key is not None:
            raise ValueError("public_key already set")

        return SSHCertificateBuilder(
            _public_key=public_key,
            _serial=self._serial,
            _type=self._type,
            _key_id=self._key_id,
            _valid_principals=self._valid_principals,
            _valid_for_all_principals=self._valid_for_all_principals,
            _valid_before=self._valid_before,
            _valid_after=self._valid_after,
            _critical_options=self._critical_options,
            _extensions=self._extensions,
        )
    1278 
    def serial(self, serial: int) -> SSHCertificateBuilder:
        """Return a copy of this builder with the certificate serial set.

        The serial is written as an unsigned 64-bit field, so it must satisfy
        ``0 <= serial < 2**64``. Setting it is optional, because ``sign``
        falls back to 0.

        Raises:
            TypeError: if ``serial`` is not an ``int``.
            ValueError: if it is out of range or was already set.
        """
        if not isinstance(serial, int):
            raise TypeError("serial must be an integer")
        fits_u64 = 0 <= serial < 2**64
        if not fits_u64:
            raise ValueError("serial must be between 0 and 2**64")
        if self._serial is not None:
            raise ValueError("serial already set")

        return SSHCertificateBuilder(
            _public_key=self._public_key,
            _serial=serial,
            _type=self._type,
            _key_id=self._key_id,
            _valid_principals=self._valid_principals,
            _valid_for_all_principals=self._valid_for_all_principals,
            _valid_before=self._valid_before,
            _valid_after=self._valid_after,
            _critical_options=self._critical_options,
            _extensions=self._extensions,
        )
    1299 
    def type(self, type: SSHCertificateType) -> SSHCertificateBuilder:
        """Return a copy of this builder with the certificate type set.

        Raises:
            TypeError: if ``type`` is not an ``SSHCertificateType`` member.
            ValueError: if the type has already been set.
        """
        # The public parameter name shadows the builtin, so give it a local
        # alias that reads unambiguously.
        cert_type = type
        if not isinstance(cert_type, SSHCertificateType):
            raise TypeError("type must be an SSHCertificateType")
        if self._type is not None:
            raise ValueError("type already set")

        return SSHCertificateBuilder(
            _public_key=self._public_key,
            _serial=self._serial,
            _type=cert_type,
            _key_id=self._key_id,
            _valid_principals=self._valid_principals,
            _valid_for_all_principals=self._valid_for_all_principals,
            _valid_before=self._valid_before,
            _valid_after=self._valid_after,
            _critical_options=self._critical_options,
            _extensions=self._extensions,
        )
    1318 
    def key_id(self, key_id: bytes) -> SSHCertificateBuilder:
        """Return a copy of this builder with the key identifier set.

        Setting it is optional, because ``sign`` uses ``b""`` when none was
        given.

        Raises:
            TypeError: if ``key_id`` is not ``bytes``.
            ValueError: if the key identifier has already been set.
        """
        if not isinstance(key_id, bytes):
            raise TypeError("key_id must be bytes")
        already_set = self._key_id is not None
        if already_set:
            raise ValueError("key_id already set")

        return SSHCertificateBuilder(
            _public_key=self._public_key,
            _serial=self._serial,
            _type=self._type,
            _key_id=key_id,
            _valid_principals=self._valid_principals,
            _valid_for_all_principals=self._valid_for_all_principals,
            _valid_before=self._valid_before,
            _valid_after=self._valid_after,
            _critical_options=self._critical_options,
            _extensions=self._extensions,
        )
    1337 
    def valid_principals(
        self, valid_principals: list[bytes]
    ) -> SSHCertificateBuilder:
        """Return a new builder restricted to the given principals.

        ``valid_principals`` must be a non-empty list of ``bytes`` with at
        most ``_SSHKEY_CERT_MAX_PRINCIPALS`` entries. Any iterable is
        accepted. A private copy is taken before validation, so later
        changes to the caller's list cannot bypass these checks or alter
        the builder.

        Raises:
            ValueError: if the builder is already valid for all principals,
                if principals were already set, or if there are too many.
            TypeError: if the list is empty or contains non-``bytes`` items.
        """
        if self._valid_for_all_principals:
            raise ValueError(
                "Principals can't be set because the cert is valid "
                "for all principals"
            )
        # Snapshot first. Validating the caller's list and then keeping a
        # reference to it would let later mutation slip past the checks.
        principals = list(valid_principals)
        if not all(isinstance(x, bytes) for x in principals) or not principals:
            raise TypeError(
                "principals must be a list of bytes and can't be empty"
            )
        if self._valid_principals:
            raise ValueError("valid_principals already set")

        if len(principals) > _SSHKEY_CERT_MAX_PRINCIPALS:
            raise ValueError(
                "Reached or exceeded the maximum number of valid_principals"
            )

        return SSHCertificateBuilder(
            _public_key=self._public_key,
            _serial=self._serial,
            _type=self._type,
            _key_id=self._key_id,
            _valid_principals=principals,
            _valid_for_all_principals=self._valid_for_all_principals,
            _valid_before=self._valid_before,
            _valid_after=self._valid_after,
            _critical_options=self._critical_options,
            _extensions=self._extensions,
        )
    1373 
    def valid_for_all_principals(self):
        """Return a copy of this builder marked valid for every principal.

        This is mutually exclusive with ``valid_principals``. The signed
        certificate then carries an empty principal list, which means "any
        principal".

        Raises:
            ValueError: if principals were already set, or if this flag was
                already set.
        """
        if self._valid_principals:
            raise ValueError(
                "valid_principals already set, can't set "
                "valid_for_all_principals"
            )
        if self._valid_for_all_principals:
            raise ValueError("valid_for_all_principals already set")

        return SSHCertificateBuilder(
            _public_key=self._public_key,
            _serial=self._serial,
            _type=self._type,
            _key_id=self._key_id,
            _valid_principals=self._valid_principals,
            _valid_for_all_principals=True,
            _valid_before=self._valid_before,
            _valid_after=self._valid_after,
            _critical_options=self._critical_options,
            _extensions=self._extensions,
        )
    1395 
    def valid_before(self, valid_before: int | float) -> SSHCertificateBuilder:
        """Return a new builder with the end of the validity window set.

        The value is converted with ``int()``, which truncates floats toward
        zero, and must then lie in ``[0, 2**64)``.

        Raises:
            TypeError: if ``valid_before`` is not an int or float.
            ValueError: if the value is out of range (including +/-inf and
                NaN) or was already set.
        """
        if not isinstance(valid_before, (int, float)):
            raise TypeError("valid_before must be an int or float")
        try:
            valid_before = int(valid_before)
        except (OverflowError, ValueError):
            # int() raises OverflowError for +/-inf and ValueError for NaN.
            # Neither value is in range, so report them like any other
            # out-of-range timestamp instead of leaking OverflowError.
            raise ValueError("valid_before must [0, 2**64)") from None
        if valid_before < 0 or valid_before >= 2**64:
            raise ValueError("valid_before must [0, 2**64)")
        if self._valid_before is not None:
            raise ValueError("valid_before already set")

        return SSHCertificateBuilder(
            _public_key=self._public_key,
            _serial=self._serial,
            _type=self._type,
            _key_id=self._key_id,
            _valid_principals=self._valid_principals,
            _valid_for_all_principals=self._valid_for_all_principals,
            _valid_before=valid_before,
            _valid_after=self._valid_after,
            _critical_options=self._critical_options,
            _extensions=self._extensions,
        )
    1417 
    def valid_after(self, valid_after: int | float) -> SSHCertificateBuilder:
        """Return a new builder with the start of the validity window set.

        The value is converted with ``int()``, which truncates floats toward
        zero, and must then lie in ``[0, 2**64)``.

        Raises:
            TypeError: if ``valid_after`` is not an int or float.
            ValueError: if the value is out of range (including +/-inf and
                NaN) or was already set.
        """
        if not isinstance(valid_after, (int, float)):
            raise TypeError("valid_after must be an int or float")
        try:
            valid_after = int(valid_after)
        except (OverflowError, ValueError):
            # int() raises OverflowError for +/-inf and ValueError for NaN.
            # Neither value is in range, so report them like any other
            # out-of-range timestamp instead of leaking OverflowError.
            raise ValueError("valid_after must [0, 2**64)") from None
        if valid_after < 0 or valid_after >= 2**64:
            raise ValueError("valid_after must [0, 2**64)")
        if self._valid_after is not None:
            raise ValueError("valid_after already set")

        return SSHCertificateBuilder(
            _public_key=self._public_key,
            _serial=self._serial,
            _type=self._type,
            _key_id=self._key_id,
            _valid_principals=self._valid_principals,
            _valid_for_all_principals=self._valid_for_all_principals,
            _valid_before=self._valid_before,
            _valid_after=valid_after,
            _critical_options=self._critical_options,
            _extensions=self._extensions,
        )
    1439 
    def add_critical_option(
        self, name: bytes, value: bytes
    ) -> SSHCertificateBuilder:
        """Return a new builder with one more critical option appended.

        Options may be added in any order, because ``sign`` sorts them by
        name.

        Raises:
            TypeError: if ``name`` or ``value`` is not ``bytes``.
            ValueError: if an option with this name was already added.
        """
        if not isinstance(name, bytes) or not isinstance(value, bytes):
            raise TypeError("name and value must be bytes")
        # A linear scan per call makes adding n options O(n**2) in total,
        # which is fine for the handful of options a certificate carries.
        if any(existing == name for existing, _ in self._critical_options):
            raise ValueError("Duplicate critical option name")

        return SSHCertificateBuilder(
            _public_key=self._public_key,
            _serial=self._serial,
            _type=self._type,
            _key_id=self._key_id,
            _valid_principals=self._valid_principals,
            _valid_for_all_principals=self._valid_for_all_principals,
            _valid_before=self._valid_before,
            _valid_after=self._valid_after,
            _critical_options=[*self._critical_options, (name, value)],
            _extensions=self._extensions,
        )
    1461 
    def add_extension(
        self, name: bytes, value: bytes
    ) -> SSHCertificateBuilder:
        """Return a new builder with one more extension appended.

        Extensions may be added in any order, because ``sign`` sorts them by
        name.

        Raises:
            TypeError: if ``name`` or ``value`` is not ``bytes``.
            ValueError: if an extension with this name was already added.
        """
        if not isinstance(name, bytes) or not isinstance(value, bytes):
            raise TypeError("name and value must be bytes")
        # A linear scan per call makes adding n extensions O(n**2) in total,
        # which is fine for the handful of extensions a certificate carries.
        if any(existing == name for existing, _ in self._extensions):
            raise ValueError("Duplicate extension name")

        return SSHCertificateBuilder(
            _public_key=self._public_key,
            _serial=self._serial,
            _type=self._type,
            _key_id=self._key_id,
            _valid_principals=self._valid_principals,
            _valid_for_all_principals=self._valid_for_all_principals,
            _valid_before=self._valid_before,
            _valid_after=self._valid_after,
            _critical_options=self._critical_options,
            _extensions=[*self._extensions, (name, value)],
        )
    1483 
    1484    def sign(self, private_key: SSHCertPrivateKeyTypes) -> SSHCertificate: 
    1485        if not isinstance( 
    1486            private_key, 
    1487            ( 
    1488                ec.EllipticCurvePrivateKey, 
    1489                rsa.RSAPrivateKey, 
    1490                ed25519.Ed25519PrivateKey, 
    1491            ), 
    1492        ): 
    1493            raise TypeError("Unsupported private key type") 
    1494 
    1495        if self._public_key is None: 
    1496            raise ValueError("public_key must be set") 
    1497 
    1498        # Not required 
    1499        serial = 0 if self._serial is None else self._serial 
    1500 
    1501        if self._type is None: 
    1502            raise ValueError("type must be set") 
    1503 
    1504        # Not required 
    1505        key_id = b"" if self._key_id is None else self._key_id 
    1506 
    1507        # A zero length list is valid, but means the certificate 
    1508        # is valid for any principal of the specified type. We require 
    1509        # the user to explicitly set valid_for_all_principals to get 
    1510        # that behavior. 
    1511        if not self._valid_principals and not self._valid_for_all_principals: 
    1512            raise ValueError( 
    1513                "valid_principals must be set if valid_for_all_principals " 
    1514                "is False" 
    1515            ) 
    1516 
    1517        if self._valid_before is None: 
    1518            raise ValueError("valid_before must be set") 
    1519 
    1520        if self._valid_after is None: 
    1521            raise ValueError("valid_after must be set") 
    1522 
    1523        if self._valid_after > self._valid_before: 
    1524            raise ValueError("valid_after must be earlier than valid_before") 
    1525 
    1526        # lexically sort our byte strings 
    1527        self._critical_options.sort(key=lambda x: x[0]) 
    1528        self._extensions.sort(key=lambda x: x[0]) 
    1529 
    1530        key_type = _get_ssh_key_type(self._public_key) 
    1531        cert_prefix = key_type + _CERT_SUFFIX 
    1532 
    1533        # Marshal the bytes to be signed 
    1534        nonce = os.urandom(32) 
    1535        kformat = _lookup_kformat(key_type) 
    1536        f = _FragList() 
    1537        f.put_sshstr(cert_prefix) 
    1538        f.put_sshstr(nonce) 
    1539        kformat.encode_public(self._public_key, f) 
    1540        f.put_u64(serial) 
    1541        f.put_u32(self._type.value) 
    1542        f.put_sshstr(key_id) 
    1543        fprincipals = _FragList() 
    1544        for p in self._valid_principals: 
    1545            fprincipals.put_sshstr(p) 
    1546        f.put_sshstr(fprincipals.tobytes()) 
    1547        f.put_u64(self._valid_after) 
    1548        f.put_u64(self._valid_before) 
    1549        fcrit = _FragList() 
    1550        for name, value in self._critical_options: 
    1551            fcrit.put_sshstr(name) 
    1552            if len(value) > 0: 
    1553                foptval = _FragList() 
    1554                foptval.put_sshstr(value) 
    1555                fcrit.put_sshstr(foptval.tobytes()) 
    1556            else: 
    1557                fcrit.put_sshstr(value) 
    1558        f.put_sshstr(fcrit.tobytes()) 
    1559        fext = _FragList() 
    1560        for name, value in self._extensions: 
    1561            fext.put_sshstr(name) 
    1562            if len(value) > 0: 
    1563                fextval = _FragList() 
    1564                fextval.put_sshstr(value) 
    1565                fext.put_sshstr(fextval.tobytes()) 
    1566            else: 
    1567                fext.put_sshstr(value) 
    1568        f.put_sshstr(fext.tobytes()) 
    1569        f.put_sshstr(b"")  # RESERVED FIELD 
    1570        # encode CA public key 
    1571        ca_type = _get_ssh_key_type(private_key) 
    1572        caformat = _lookup_kformat(ca_type) 
    1573        caf = _FragList() 
    1574        caf.put_sshstr(ca_type) 
    1575        caformat.encode_public(private_key.public_key(), caf) 
    1576        f.put_sshstr(caf.tobytes()) 
    1577        # Sigs according to the rules defined for the CA's public key 
    1578        # (RFC4253 section 6.6 for ssh-rsa, RFC5656 for ECDSA, 
    1579        # and RFC8032 for Ed25519). 
    1580        if isinstance(private_key, ed25519.Ed25519PrivateKey): 
    1581            signature = private_key.sign(f.tobytes()) 
    1582            fsig = _FragList() 
    1583            fsig.put_sshstr(ca_type) 
    1584            fsig.put_sshstr(signature) 
    1585            f.put_sshstr(fsig.tobytes()) 
    1586        elif isinstance(private_key, ec.EllipticCurvePrivateKey): 
    1587            hash_alg = _get_ec_hash_alg(private_key.curve) 
    1588            signature = private_key.sign(f.tobytes(), ec.ECDSA(hash_alg)) 
    1589            r, s = asym_utils.decode_dss_signature(signature) 
    1590            fsig = _FragList() 
    1591            fsig.put_sshstr(ca_type) 
    1592            fsigblob = _FragList() 
    1593            fsigblob.put_mpint(r) 
    1594            fsigblob.put_mpint(s) 
    1595            fsig.put_sshstr(fsigblob.tobytes()) 
    1596            f.put_sshstr(fsig.tobytes()) 
    1597 
    1598        else: 
    1599            assert isinstance(private_key, rsa.RSAPrivateKey) 
    1600            # Just like Golang, we're going to use SHA512 for RSA 
    1601            # https://cs.opensource.google/go/x/crypto/+/refs/tags/ 
    1602            # v0.4.0:ssh/certs.go;l=445 
    1603            # RFC 8332 defines SHA256 and 512 as options 
    1604            fsig = _FragList() 
    1605            fsig.put_sshstr(_SSH_RSA_SHA512) 
    1606            signature = private_key.sign( 
    1607                f.tobytes(), padding.PKCS1v15(), hashes.SHA512() 
    1608            ) 
    1609            fsig.put_sshstr(signature) 
    1610            f.put_sshstr(fsig.tobytes()) 
    1611 
    1612        cert_data = binascii.b2a_base64(f.tobytes()).strip() 
    1613        # load_ssh_public_identity returns a union, but this is 
    1614        # guaranteed to be an SSHCertificate, so we cast to make 
    1615        # mypy happy. 
    1616        return typing.cast( 
    1617            SSHCertificate, 
    1618            load_ssh_public_identity(b"".join([cert_prefix, b" ", cert_data])), 
    1619        )