Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/cryptography/hazmat/primitives/serialization/ssh.py: 20%
723 statements
« prev ^ index » next coverage.py v7.2.2, created at 2023-03-26 06:36 +0000
1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
6import binascii
7import enum
8import os
9import re
10import typing
11import warnings
12from base64 import encodebytes as _base64_encode
14from cryptography import utils
15from cryptography.exceptions import UnsupportedAlgorithm
16from cryptography.hazmat.primitives import hashes
17from cryptography.hazmat.primitives.asymmetric import (
18 dsa,
19 ec,
20 ed25519,
21 padding,
22 rsa,
23)
24from cryptography.hazmat.primitives.asymmetric import utils as asym_utils
25from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
26from cryptography.hazmat.primitives.serialization import (
27 Encoding,
28 KeySerializationEncryption,
29 NoEncryption,
30 PrivateFormat,
31 PublicFormat,
32 _KeySerializationEncryption,
33)
# bcrypt is optional.  Its ``kdf`` is only called from ``_init_cipher``,
# i.e. for password-protected OpenSSH private keys.
try:
    from bcrypt import kdf as _bcrypt_kdf

    _bcrypt_supported = True
except ImportError:
    _bcrypt_supported = False

    # Stand-in mirroring the call made by ``_init_cipher`` so that loading or
    # writing an encrypted key fails with UnsupportedAlgorithm, not NameError.
    def _bcrypt_kdf(
        password: bytes,
        salt: bytes,
        desired_key_bytes: int,
        rounds: int,
        ignore_few_rounds: bool = False,
    ) -> bytes:
        raise UnsupportedAlgorithm("Need bcrypt module")
# SSH key-type names, as written on the wire and at the start of public
# key lines.
_SSH_ED25519 = b"ssh-ed25519"
_SSH_RSA = b"ssh-rsa"
_SSH_DSA = b"ssh-dss"
_ECDSA_NISTP256 = b"ecdsa-sha2-nistp256"
_ECDSA_NISTP384 = b"ecdsa-sha2-nistp384"
_ECDSA_NISTP521 = b"ecdsa-sha2-nistp521"
# Appended to a key-type name to denote an OpenSSH certificate.
_CERT_SUFFIX = b"-cert-v01@openssh.com"

# These are not key types, only algorithms, so they cannot appear
# as a public key type
_SSH_RSA_SHA256 = b"rsa-sha2-256"
_SSH_RSA_SHA512 = b"rsa-sha2-512"

# "<key-type> <base64-body>" at the start of a public key line; it is used
# with match(), so anything after the body (e.g. a comment) is ignored.
_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)")
# Magic prefix of the decoded openssh-key-v1 private key container.
_SK_MAGIC = b"openssh-key-v1\0"
# Armor lines around the base64-encoded private key.
_SK_START = b"-----BEGIN OPENSSH PRIVATE KEY-----"
_SK_END = b"-----END OPENSSH PRIVATE KEY-----"
# The only KDF supported, and the "no cipher / no KDF" marker.
_BCRYPT = b"bcrypt"
_NONE = b"none"
# Cipher and bcrypt round count used when serializing with a password.
_DEFAULT_CIPHER = b"aes256-ctr"
_DEFAULT_ROUNDS = 16

# re is only way to work on bytes-like data
_PEM_RC = re.compile(_SK_START + b"(.*?)" + _SK_END, re.DOTALL)

# padding for max blocksize
# (the byte sequence 1, 2, ..., 16 appended after the encoded secrets)
_PADDING = memoryview(bytearray(range(1, 1 + 16)))

# ciphers that are actually used in key wrapping
# value: (algorithm class, key length, mode class, IV length); the last
# element is also used as the block length for padding checks.
_SSH_CIPHERS: typing.Dict[
    bytes,
    typing.Tuple[
        typing.Type[algorithms.AES],
        int,
        typing.Union[typing.Type[modes.CTR], typing.Type[modes.CBC]],
        int,
    ],
] = {
    b"aes256-ctr": (algorithms.AES, 32, modes.CTR, 16),
    b"aes256-cbc": (algorithms.AES, 32, modes.CBC, 16),
}

# map local curve name to key type
# (keys are the ``name`` attribute of cryptography curve objects)
_ECDSA_KEY_TYPE = {
    "secp256r1": _ECDSA_NISTP256,
    "secp384r1": _ECDSA_NISTP384,
    "secp521r1": _ECDSA_NISTP521,
}
def _get_ssh_key_type(
    key: typing.Union["SSHPrivateKeyTypes", "SSHPublicKeyTypes"]
) -> bytes:
    """Return the SSH key-type name for a private or public key object.

    :raises ValueError: if the key class has no SSH representation here.
    """
    if isinstance(key, ec.EllipticCurvePrivateKey):
        return _ecdsa_key_type(key.public_key())
    if isinstance(key, ec.EllipticCurvePublicKey):
        return _ecdsa_key_type(key)
    if isinstance(key, (rsa.RSAPrivateKey, rsa.RSAPublicKey)):
        return _SSH_RSA
    if isinstance(key, (dsa.DSAPrivateKey, dsa.DSAPublicKey)):
        return _SSH_DSA
    ed25519_classes = (ed25519.Ed25519PrivateKey, ed25519.Ed25519PublicKey)
    if isinstance(key, ed25519_classes):
        return _SSH_ED25519
    raise ValueError("Unsupported key type")
def _ecdsa_key_type(public_key: ec.EllipticCurvePublicKey) -> bytes:
    """Return the SSH ECDSA key-type name for an elliptic-curve public key.

    Only curves present in ``_ECDSA_KEY_TYPE`` are supported.

    :param public_key: The EC public key whose curve is inspected.
    :returns: One of the ``ecdsa-sha2-nistp*`` key-type names.
    :raises ValueError: if the key's curve has no SSH ECDSA key type.
    """
    curve_name = public_key.curve.name
    try:
        return _ECDSA_KEY_TYPE[curve_name]
    except KeyError:
        raise ValueError(
            f"Unsupported curve for ssh private key: {curve_name!r}"
        ) from None
def _ssh_pem_encode(
    data: bytes,
    prefix: bytes = _SK_START + b"\n",
    suffix: bytes = _SK_END + b"\n",
) -> bytes:
    """Base64-encode *data* and wrap it between *prefix* and *suffix*.

    The defaults produce the OpenSSH private-key armor lines.
    """
    encoded_body = _base64_encode(data)
    pieces = (prefix, encoded_body, suffix)
    return b"".join(pieces)
141def _check_block_size(data: bytes, block_len: int) -> None:
142 """Require data to be full blocks"""
143 if not data or len(data) % block_len != 0:
144 raise ValueError("Corrupt data: missing padding")
147def _check_empty(data: bytes) -> None:
148 """All data should have been parsed."""
149 if data:
150 raise ValueError("Corrupt data: unparsed data")
def _init_cipher(
    ciphername: bytes,
    password: typing.Optional[bytes],
    salt: bytes,
    rounds: int,
) -> Cipher[typing.Union[modes.CBC, modes.CTR]]:
    """Derive key and IV with bcrypt and build the matching cipher.

    The KDF output is split: the first ``key_len`` bytes become the cipher
    key and the remaining bytes the IV / initial counter.

    :raises ValueError: if *password* is missing or empty.
    """
    if not password:
        raise ValueError("Key is password-protected.")

    algorithm_cls, key_len, mode_cls, iv_len = _SSH_CIPHERS[ciphername]
    key_material = _bcrypt_kdf(
        password, salt, key_len + iv_len, rounds, True
    )
    cipher_key, iv = key_material[:key_len], key_material[key_len:]
    return Cipher(algorithm_cls(cipher_key), mode_cls(iv))
168def _get_u32(data: memoryview) -> typing.Tuple[int, memoryview]:
169 """Uint32"""
170 if len(data) < 4:
171 raise ValueError("Invalid data")
172 return int.from_bytes(data[:4], byteorder="big"), data[4:]
175def _get_u64(data: memoryview) -> typing.Tuple[int, memoryview]:
176 """Uint64"""
177 if len(data) < 8:
178 raise ValueError("Invalid data")
179 return int.from_bytes(data[:8], byteorder="big"), data[8:]
def _get_sshstr(data: memoryview) -> typing.Tuple[memoryview, memoryview]:
    """Split an SSH ``string`` (uint32 length + payload) off *data*.

    Returns ``(payload, remainder)`` as slices of the input.

    :raises ValueError: if the declared length exceeds the available data.
    """
    length, remainder = _get_u32(data)
    if length <= len(remainder):
        return remainder[:length], remainder[length:]
    raise ValueError("Invalid data")
def _get_mpint(data: memoryview) -> typing.Tuple[int, memoryview]:
    """Parse a non-negative SSH ``mpint`` from the front of *data*.

    A first byte with its top bit set would encode a negative number and
    is rejected.  Returns ``(value, remainder)``.
    """
    raw, remainder = _get_sshstr(data)
    looks_negative = len(raw) > 0 and raw[0] & 0x80
    if looks_negative:
        raise ValueError("Invalid data")
    return int.from_bytes(raw, "big"), remainder
def _to_mpint(val: int) -> bytes:
    """Encode a non-negative integer as an SSH mpint payload (no prefix).

    Zero encodes as the empty string.  One spare bit is always reserved so
    a value whose top bit is set gains a leading zero byte and is not read
    back as negative.
    """
    if val < 0:
        raise ValueError("negative mpint not allowed")
    if val == 0:
        return b""
    return utils.int_to_bytes(val, val.bit_length() // 8 + 1)
208class _FragList:
209 """Build recursive structure without data copy."""
211 flist: typing.List[bytes]
213 def __init__(
214 self, init: typing.Optional[typing.List[bytes]] = None
215 ) -> None:
216 self.flist = []
217 if init:
218 self.flist.extend(init)
220 def put_raw(self, val: bytes) -> None:
221 """Add plain bytes"""
222 self.flist.append(val)
224 def put_u32(self, val: int) -> None:
225 """Big-endian uint32"""
226 self.flist.append(val.to_bytes(length=4, byteorder="big"))
228 def put_u64(self, val: int) -> None:
229 """Big-endian uint64"""
230 self.flist.append(val.to_bytes(length=8, byteorder="big"))
232 def put_sshstr(self, val: typing.Union[bytes, "_FragList"]) -> None:
233 """Bytes prefixed with u32 length"""
234 if isinstance(val, (bytes, memoryview, bytearray)):
235 self.put_u32(len(val))
236 self.flist.append(val)
237 else:
238 self.put_u32(val.size())
239 self.flist.extend(val.flist)
241 def put_mpint(self, val: int) -> None:
242 """Big-endian bigint prefixed with u32 length"""
243 self.put_sshstr(_to_mpint(val))
245 def size(self) -> int:
246 """Current number of bytes"""
247 return sum(map(len, self.flist))
249 def render(self, dstbuf: memoryview, pos: int = 0) -> int:
250 """Write into bytearray"""
251 for frag in self.flist:
252 flen = len(frag)
253 start, pos = pos, pos + flen
254 dstbuf[start:pos] = frag
255 return pos
257 def tobytes(self) -> bytes:
258 """Return as bytes"""
259 buf = memoryview(bytearray(self.size()))
260 self.render(buf)
261 return buf.tobytes()
class _SSHFormatRSA:
    """Wire format for ``ssh-rsa`` keys.

    Public:
        mpint e, n
    Private:
        mpint n, e, d, iqmp, p, q
    """

    def get_public(self, data: memoryview):
        """Parse the public fields; returns ``((e, n), remainder)``."""
        e, rest = _get_mpint(data)
        n, rest = _get_mpint(rest)
        return (e, n), rest

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[rsa.RSAPublicKey, memoryview]:
        """Build an RSA public key; returns ``(key, remainder)``."""
        (e, n), rest = self.get_public(data)
        return rsa.RSAPublicNumbers(e, n).public_key(), rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[rsa.RSAPrivateKey, memoryview]:
        """Build an RSA private key from *data*.

        *pubfields* is the ``(e, n)`` tuple from the key's public section;
        the private section must carry the same values.  The CRT exponents
        are recomputed from ``d``, ``p`` and ``q``.
        """
        values = []
        rest = data
        for _ in range(6):
            value, rest = _get_mpint(rest)
            values.append(value)
        n, e, d, iqmp, p, q = values

        if pubfields != (e, n):
            raise ValueError("Corrupt data: rsa field mismatch")
        private_numbers = rsa.RSAPrivateNumbers(
            p,
            q,
            d,
            rsa.rsa_crt_dmp1(d, p),
            rsa.rsa_crt_dmq1(d, q),
            iqmp,
            rsa.RSAPublicNumbers(e, n),
        )
        return private_numbers.private_key(), rest

    def encode_public(
        self, public_key: rsa.RSAPublicKey, f_pub: _FragList
    ) -> None:
        """Append ``e`` then ``n`` to *f_pub*."""
        numbers = public_key.public_numbers()
        for value in (numbers.e, numbers.n):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: rsa.RSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Append ``n, e, d, iqmp, p, q`` to *f_priv*."""
        priv = private_key.private_numbers()
        pub = priv.public_numbers
        for value in (pub.n, pub.e, priv.d, priv.iqmp, priv.p, priv.q):
            f_priv.put_mpint(value)
class _SSHFormatDSA:
    """Wire format for ``ssh-dss`` keys.

    Public:
        mpint p, q, g, y
    Private:
        mpint p, q, g, y, x

    Only 1024-bit moduli are accepted or written (see ``_validate``).
    """

    def get_public(
        self, data: memoryview
    ) -> typing.Tuple[typing.Tuple, memoryview]:
        """Parse ``(p, q, g, y)``; returns ``(fields, remainder)``."""
        fields = []
        rest = data
        for _ in range(4):
            value, rest = _get_mpint(rest)
            fields.append(value)
        return tuple(fields), rest

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[dsa.DSAPublicKey, memoryview]:
        """Build a DSA public key; returns ``(key, remainder)``."""
        (p, q, g, y), rest = self.get_public(data)
        public_numbers = dsa.DSAPublicNumbers(
            y, dsa.DSAParameterNumbers(p, q, g)
        )
        self._validate(public_numbers)
        return public_numbers.public_key(), rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[dsa.DSAPrivateKey, memoryview]:
        """Build a DSA private key from *data*.

        *pubfields* is the ``(p, q, g, y)`` tuple from the public section;
        the private section must repeat it before the secret ``x``.
        """
        fields, rest = self.get_public(data)
        x, rest = _get_mpint(rest)

        if fields != pubfields:
            raise ValueError("Corrupt data: dsa field mismatch")
        p, q, g, y = fields
        public_numbers = dsa.DSAPublicNumbers(
            y, dsa.DSAParameterNumbers(p, q, g)
        )
        self._validate(public_numbers)
        private_key = dsa.DSAPrivateNumbers(x, public_numbers).private_key()
        return private_key, rest

    def encode_public(
        self, public_key: dsa.DSAPublicKey, f_pub: _FragList
    ) -> None:
        """Append ``p, q, g, y`` to *f_pub* after checking the key size."""
        public_numbers = public_key.public_numbers()
        self._validate(public_numbers)
        params = public_numbers.parameter_numbers
        for value in (params.p, params.q, params.g, public_numbers.y):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: dsa.DSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Append the public fields followed by the secret ``x``."""
        public_key = private_key.public_key()
        self.encode_public(public_key, f_priv)
        f_priv.put_mpint(private_key.private_numbers().x)

    def _validate(self, public_numbers: dsa.DSAPublicNumbers) -> None:
        """Reject DSA parameters whose modulus ``p`` is not 1024 bits."""
        if public_numbers.parameter_numbers.p.bit_length() == 1024:
            return
        raise ValueError("SSH supports only 1024 bit DSA keys")
class _SSHFormatECDSA:
    """Wire format for ``ecdsa-sha2-nistp*`` keys.

    Public:
        str curve
        bytes point
    Private:
        str curve
        bytes point
        mpint secret

    ``ssh_curve_name`` is the curve identifier expected on the wire and
    ``curve`` the matching cryptography curve object.
    """

    def __init__(self, ssh_curve_name: bytes, curve: ec.EllipticCurve):
        self.ssh_curve_name = ssh_curve_name
        self.curve = curve

    def get_public(
        self, data: memoryview
    ) -> typing.Tuple[typing.Tuple, memoryview]:
        """Parse ``(curve_name, point)``; returns ``(fields, remainder)``.

        :raises ValueError: if the curve name does not match this format,
            or the point string is empty.
        :raises NotImplementedError: for points not in uncompressed (0x04)
            form.
        """
        curve, data = _get_sshstr(data)
        point, data = _get_sshstr(data)
        if curve != self.ssh_curve_name:
            raise ValueError("Curve name mismatch")
        # An empty point would otherwise raise IndexError on point[0];
        # report it as malformed input like every other parse failure.
        if not point:
            raise ValueError("Invalid data")
        if point[0] != 4:
            raise NotImplementedError("Need uncompressed point")
        return (curve, point), data

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[ec.EllipticCurvePublicKey, memoryview]:
        """Make ECDSA public key from data; returns ``(key, remainder)``."""
        (curve_name, point), data = self.get_public(data)
        public_key = ec.EllipticCurvePublicKey.from_encoded_point(
            self.curve, point.tobytes()
        )
        return public_key, data

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[ec.EllipticCurvePrivateKey, memoryview]:
        """Make ECDSA private key from data.

        *pubfields* is the ``(curve_name, point)`` tuple from the public
        section; the private section must repeat it before the secret.
        """
        (curve_name, point), data = self.get_public(data)
        secret, data = _get_mpint(data)

        if (curve_name, point) != pubfields:
            raise ValueError("Corrupt data: ecdsa field mismatch")
        # NOTE(review): the secret is not checked against ``point``; the key
        # is derived from the secret alone.
        private_key = ec.derive_private_key(secret, self.curve)
        return private_key, data

    def encode_public(
        self, public_key: ec.EllipticCurvePublicKey, f_pub: _FragList
    ) -> None:
        """Append the curve name and uncompressed point to *f_pub*."""
        point = public_key.public_bytes(
            Encoding.X962, PublicFormat.UncompressedPoint
        )
        f_pub.put_sshstr(self.ssh_curve_name)
        f_pub.put_sshstr(point)

    def encode_private(
        self, private_key: ec.EllipticCurvePrivateKey, f_priv: _FragList
    ) -> None:
        """Append the public fields followed by the private scalar."""
        public_key = private_key.public_key()
        private_numbers = private_key.private_numbers()

        self.encode_public(public_key, f_priv)
        f_priv.put_mpint(private_numbers.private_value)
class _SSHFormatEd25519:
    """Wire format for ``ssh-ed25519`` keys.

    Public:
        bytes point
    Private:
        bytes point
        bytes secret_and_point
    """

    def get_public(
        self, data: memoryview
    ) -> typing.Tuple[typing.Tuple, memoryview]:
        """Parse the 1-tuple ``(point,)``; returns ``(fields, remainder)``."""
        point, rest = _get_sshstr(data)
        return (point,), rest

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[ed25519.Ed25519PublicKey, memoryview]:
        """Build an Ed25519 public key; returns ``(key, remainder)``."""
        fields, rest = self.get_public(data)
        raw_point = fields[0].tobytes()
        return ed25519.Ed25519PublicKey.from_public_bytes(raw_point), rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[ed25519.Ed25519PrivateKey, memoryview]:
        """Build an Ed25519 private key from *data*.

        The private section repeats the point, then holds a string whose
        first 32 bytes are the secret and whose remainder is the point
        again.  Both copies must match each other and *pubfields*.
        """
        (point,), rest = self.get_public(data)
        keypair, rest = _get_sshstr(rest)

        secret, trailing_point = keypair[:32], keypair[32:]
        consistent = point == trailing_point and (point,) == pubfields
        if not consistent:
            raise ValueError("Corrupt data: ed25519 field mismatch")
        return ed25519.Ed25519PrivateKey.from_private_bytes(secret), rest

    def encode_public(
        self, public_key: ed25519.Ed25519PublicKey, f_pub: _FragList
    ) -> None:
        """Append the raw 32-byte public key as an SSH string."""
        f_pub.put_sshstr(
            public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)
        )

    def encode_private(
        self, private_key: ed25519.Ed25519PrivateKey, f_priv: _FragList
    ) -> None:
        """Append the public point, then ``secret || point`` as one string."""
        public_key = private_key.public_key()
        secret_bytes = private_key.private_bytes(
            Encoding.Raw, PrivateFormat.Raw, NoEncryption()
        )
        point_bytes = public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)

        self.encode_public(public_key, f_priv)
        f_priv.put_sshstr(_FragList([secret_bytes, point_bytes]))
# Wire-format handler for each supported SSH key-type name; looked up via
# ``_lookup_kformat``.  ECDSA handlers carry the SSH curve identifier and
# the matching cryptography curve.
_KEY_FORMATS = {
    _SSH_RSA: _SSHFormatRSA(),
    _SSH_DSA: _SSHFormatDSA(),
    _SSH_ED25519: _SSHFormatEd25519(),
    _ECDSA_NISTP256: _SSHFormatECDSA(b"nistp256", ec.SECP256R1()),
    _ECDSA_NISTP384: _SSHFormatECDSA(b"nistp384", ec.SECP384R1()),
    _ECDSA_NISTP521: _SSHFormatECDSA(b"nistp521", ec.SECP521R1()),
}
def _lookup_kformat(key_type: bytes):
    """Return the wire-format handler registered for *key_type*.

    *key_type* may be any bytes-like object (parsers pass ``memoryview``
    slices); it is converted to ``bytes`` before the lookup.

    :raises UnsupportedAlgorithm: if no handler exists for the name.
    """
    if isinstance(key_type, bytes):
        name = key_type
    else:
        name = memoryview(key_type).tobytes()
    handler = _KEY_FORMATS.get(name)
    if handler is None:
        raise UnsupportedAlgorithm(f"Unsupported key type: {name!r}")
    return handler
# Private key classes this module can load from, and serialize to, the
# OpenSSH private key format.
SSHPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    dsa.DSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]
def load_ssh_private_key(
    data: bytes,
    password: typing.Optional[bytes],
    backend: typing.Any = None,
) -> SSHPrivateKeyTypes:
    """Load private key from OpenSSH custom encoding.

    Parses the first ``-----BEGIN OPENSSH PRIVATE KEY-----`` block found in
    *data* (an ``openssh-key-v1`` container holding exactly one key) and
    decrypts it when it is protected with a supported cipher + bcrypt.

    :param data: Bytes-like object containing the armored key.
    :param password: Passphrase for an encrypted key; not consulted for an
        unencrypted one.  A missing or empty password for an encrypted key
        raises ``ValueError``; a wrong one normally surfaces as the
        "broken checksum" ``ValueError``.
    :param backend: Unused; accepted for backwards compatibility.
    :returns: The loaded private key.
    :raises ValueError: for malformed, truncated or inconsistent data.
    :raises UnsupportedAlgorithm: for unsupported key types, ciphers or
        KDFs, or when the optional ``bcrypt`` dependency is missing.
    """
    utils._check_byteslike("data", data)
    if password is not None:
        utils._check_bytes("password", password)

    m = _PEM_RC.search(data)
    if not m:
        raise ValueError("Not OpenSSH private key format")
    p1 = m.start(1)
    p2 = m.end(1)
    data = binascii.a2b_base64(memoryview(data)[p1:p2])
    if not data.startswith(_SK_MAGIC):
        raise ValueError("Not OpenSSH private key format")
    data = memoryview(data)[len(_SK_MAGIC) :]

    # parse header
    ciphername, data = _get_sshstr(data)
    kdfname, data = _get_sshstr(data)
    kdfoptions, data = _get_sshstr(data)
    nkeys, data = _get_u32(data)
    if nkeys != 1:
        raise ValueError("Only one key supported")

    # load public key data
    pubdata, data = _get_sshstr(data)
    pub_key_type, pubdata = _get_sshstr(pubdata)
    kformat = _lookup_kformat(pub_key_type)
    pubfields, pubdata = kformat.get_public(pubdata)
    _check_empty(pubdata)

    # load secret data
    edata, data = _get_sshstr(data)
    _check_empty(data)

    if (ciphername, kdfname) != (_NONE, _NONE):
        ciphername_bytes = ciphername.tobytes()
        if ciphername_bytes not in _SSH_CIPHERS:
            raise UnsupportedAlgorithm(
                f"Unsupported cipher: {ciphername_bytes!r}"
            )
        if kdfname != _BCRYPT:
            # kdfname is a memoryview; convert it so the message shows the
            # KDF name rather than "<memory at 0x...>".
            raise UnsupportedAlgorithm(
                f"Unsupported KDF: {kdfname.tobytes()!r}"
            )
        blklen = _SSH_CIPHERS[ciphername_bytes][3]
        _check_block_size(edata, blklen)
        # bcrypt options: salt string followed by the round count
        salt, kbuf = _get_sshstr(kdfoptions)
        rounds, kbuf = _get_u32(kbuf)
        _check_empty(kbuf)
        ciph = _init_cipher(ciphername_bytes, password, salt.tobytes(), rounds)
        edata = memoryview(ciph.decryptor().update(edata))
    else:
        blklen = 8
        _check_block_size(edata, blklen)
    # Two copies of the same random check value; they only agree when the
    # section was decrypted correctly.
    ck1, edata = _get_u32(edata)
    ck2, edata = _get_u32(edata)
    if ck1 != ck2:
        raise ValueError("Corrupt data: broken checksum")

    # load per-key struct
    key_type, edata = _get_sshstr(edata)
    if key_type != pub_key_type:
        raise ValueError("Corrupt data: key type mismatch")
    private_key, edata = kformat.load_private(edata, pubfields)
    # The comment is parsed only to reach the padding; it is not returned.
    comment, edata = _get_sshstr(edata)

    # yes, SSH does padding check *after* all other parsing is done.
    # need to follow as it writes zero-byte padding too.
    if edata != _PADDING[: len(edata)]:
        raise ValueError("Corrupt data: invalid padding")

    if isinstance(private_key, dsa.DSAPrivateKey):
        warnings.warn(
            "SSH DSA keys are deprecated and will be removed in a future "
            "release.",
            utils.DeprecatedIn40,
            stacklevel=2,
        )

    return private_key
def _serialize_ssh_private_key(
    private_key: SSHPrivateKeyTypes,
    password: bytes,
    encryption_algorithm: KeySerializationEncryption,
) -> bytes:
    """Serialize private key with OpenSSH custom encoding.

    Produces an armored ``openssh-key-v1`` container holding one key with an
    empty comment.

    :param private_key: Key to serialize; DSA keys emit a deprecation
        warning.
    :param password: A non-empty value encrypts the secret section with
        ``_DEFAULT_CIPHER`` using a bcrypt-derived key; an empty value
        writes the key unencrypted.
    :param encryption_algorithm: Only consulted for a custom bcrypt round
        count (``_kdf_rounds``) when it is a ``_KeySerializationEncryption``.
    :returns: The armored key as bytes.
    """
    utils._check_bytes("password", password)
    if isinstance(private_key, dsa.DSAPrivateKey):
        warnings.warn(
            "SSH DSA key support is deprecated and will be "
            "removed in a future release",
            utils.DeprecatedIn40,
            stacklevel=4,
        )

    key_type = _get_ssh_key_type(private_key)
    kformat = _lookup_kformat(key_type)

    # setup parameters
    f_kdfoptions = _FragList()
    if password:
        ciphername = _DEFAULT_CIPHER
        blklen = _SSH_CIPHERS[ciphername][3]
        kdfname = _BCRYPT
        rounds = _DEFAULT_ROUNDS
        if (
            isinstance(encryption_algorithm, _KeySerializationEncryption)
            and encryption_algorithm._kdf_rounds is not None
        ):
            rounds = encryption_algorithm._kdf_rounds
        # fresh random salt; stored with the round count as kdfoptions
        salt = os.urandom(16)
        f_kdfoptions.put_sshstr(salt)
        f_kdfoptions.put_u32(rounds)
        ciph = _init_cipher(ciphername, password, salt, rounds)
    else:
        ciphername = kdfname = _NONE
        blklen = 8
        ciph = None
    nkeys = 1
    # random check value, written twice; the loader compares both copies
    checkval = os.urandom(4)
    comment = b""

    # encode public and private parts together
    f_public_key = _FragList()
    f_public_key.put_sshstr(key_type)
    kformat.encode_public(private_key.public_key(), f_public_key)

    f_secrets = _FragList([checkval, checkval])
    f_secrets.put_sshstr(key_type)
    kformat.encode_private(private_key, f_secrets)
    f_secrets.put_sshstr(comment)
    # pad with 1, 2, 3, ... up to a multiple of blklen (a full block is
    # added when the size is already aligned)
    f_secrets.put_raw(_PADDING[: blklen - (f_secrets.size() % blklen)])

    # top-level structure
    f_main = _FragList()
    f_main.put_raw(_SK_MAGIC)
    f_main.put_sshstr(ciphername)
    f_main.put_sshstr(kdfname)
    f_main.put_sshstr(f_kdfoptions)
    f_main.put_u32(nkeys)
    f_main.put_sshstr(f_public_key)
    f_main.put_sshstr(f_secrets)

    # copy result into bytearray
    # (blklen spare bytes at the end give update_into room beyond its input)
    slen = f_secrets.size()
    mlen = f_main.size()
    buf = memoryview(bytearray(mlen + blklen))
    f_main.render(buf)
    # the secrets are the last slen bytes of the rendered structure
    ofs = mlen - slen

    # encrypt in-place
    if ciph is not None:
        ciph.encryptor().update_into(buf[ofs:mlen], buf[ofs:])

    return _ssh_pem_encode(buf[:mlen])
# Public key classes this module can load from and write to OpenSSH
# public-key lines.
SSHPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    dsa.DSAPublicKey,
    ed25519.Ed25519PublicKey,
]

# Public key classes exposed from SSH certificates; DSA is excluded.
SSHCertPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    ed25519.Ed25519PublicKey,
]
class SSHCertificateType(enum.Enum):
    """OpenSSH certificate type; values match the certificate's u32 type."""

    USER = 1
    HOST = 2
class SSHCertificate:
    """An OpenSSH certificate decoded by this module's public-identity parser.

    The constructor only stores pre-parsed fields, converting the numeric
    certificate type to ``SSHCertificateType``.  Byte fields may be
    ``memoryview`` slices; the accessors return ``bytes`` copies.
    """

    def __init__(
        self,
        _nonce: memoryview,
        _public_key: SSHPublicKeyTypes,
        _serial: int,
        _cctype: int,
        _key_id: memoryview,
        _valid_principals: typing.List[bytes],
        _valid_after: int,
        _valid_before: int,
        _critical_options: typing.Dict[bytes, bytes],
        _extensions: typing.Dict[bytes, bytes],
        _sig_type: memoryview,
        _sig_key: memoryview,
        _inner_sig_type: memoryview,
        _signature: memoryview,
        _tbs_cert_body: memoryview,
        _cert_key_type: bytes,
        _cert_body: memoryview,
    ):
        try:
            cert_type = SSHCertificateType(_cctype)
        except ValueError:
            raise ValueError("Invalid certificate type")

        # certificate contents
        self._nonce = _nonce
        self._public_key = _public_key
        self._serial = _serial
        self._type = cert_type
        self._key_id = _key_id
        self._valid_principals = _valid_principals
        self._valid_after = _valid_after
        self._valid_before = _valid_before
        self._critical_options = _critical_options
        self._extensions = _extensions
        # signing key and signature
        self._sig_type = _sig_type
        self._sig_key = _sig_key
        self._inner_sig_type = _inner_sig_type
        self._signature = _signature
        # raw encodings used for re-serialization and verification
        self._cert_key_type = _cert_key_type
        self._cert_body = _cert_body
        self._tbs_cert_body = _tbs_cert_body

    @property
    def nonce(self) -> bytes:
        return bytes(self._nonce)

    def public_key(self) -> SSHCertPublicKeyTypes:
        # make mypy happy until we remove DSA support entirely and
        # the underlying union won't have a disallowed type
        return typing.cast(SSHCertPublicKeyTypes, self._public_key)

    @property
    def serial(self) -> int:
        return self._serial

    @property
    def type(self) -> SSHCertificateType:
        return self._type

    @property
    def key_id(self) -> bytes:
        return bytes(self._key_id)

    @property
    def valid_principals(self) -> typing.List[bytes]:
        return self._valid_principals

    @property
    def valid_before(self) -> int:
        return self._valid_before

    @property
    def valid_after(self) -> int:
        return self._valid_after

    @property
    def critical_options(self) -> typing.Dict[bytes, bytes]:
        return self._critical_options

    @property
    def extensions(self) -> typing.Dict[bytes, bytes]:
        return self._extensions

    def signature_key(self) -> SSHCertPublicKeyTypes:
        """Decode and return the public key of the certificate's signer."""
        kformat = _lookup_kformat(self._sig_type)
        key, leftover = kformat.load_public(self._sig_key)
        _check_empty(leftover)
        return key

    def public_bytes(self) -> bytes:
        """Return the one-line ``<cert-key-type> <base64-body>`` form."""
        encoded = binascii.b2a_base64(bytes(self._cert_body), newline=False)
        return b" ".join([bytes(self._cert_key_type), encoded])

    def verify_cert_signature(self) -> None:
        """Check the signature over the signed part of the certificate.

        Raises whatever the underlying key's ``verify`` raises on failure.
        """
        signing_key = self.signature_key()
        signed_data = bytes(self._tbs_cert_body)

        if isinstance(signing_key, ed25519.Ed25519PublicKey):
            signing_key.verify(bytes(self._signature), signed_data)
            return

        if isinstance(signing_key, ec.EllipticCurvePublicKey):
            # The signature is encoded as a pair of big-endian integers
            r, remaining = _get_mpint(self._signature)
            s, remaining = _get_mpint(remaining)
            _check_empty(remaining)
            der_signature = asym_utils.encode_dss_signature(r, s)
            hash_alg = _get_ec_hash_alg(signing_key.curve)
            signing_key.verify(
                der_signature, signed_data, ec.ECDSA(hash_alg)
            )
            return

        assert isinstance(signing_key, rsa.RSAPublicKey)
        inner_type = self._inner_sig_type
        if inner_type == _SSH_RSA:
            rsa_hash = hashes.SHA1()
        elif inner_type == _SSH_RSA_SHA256:
            rsa_hash = hashes.SHA256()
        else:
            assert inner_type == _SSH_RSA_SHA512
            rsa_hash = hashes.SHA512()
        signing_key.verify(
            bytes(self._signature),
            signed_data,
            padding.PKCS1v15(),
            rsa_hash,
        )
def _get_ec_hash_alg(curve: ec.EllipticCurve) -> hashes.HashAlgorithm:
    """Return the hash SSH pairs with an ECDSA curve.

    P-256 -> SHA-256, P-384 -> SHA-384, P-521 -> SHA-512.  Any other curve
    fails the assertion (when assertions are enabled).
    """
    if isinstance(curve, ec.SECP384R1):
        return hashes.SHA384()
    if isinstance(curve, ec.SECP256R1):
        return hashes.SHA256()
    assert isinstance(curve, ec.SECP521R1)
    return hashes.SHA512()
def _load_ssh_public_identity(
    data: bytes,
    _legacy_dsa_allowed=False,
) -> typing.Union[SSHCertificate, SSHPublicKeyTypes]:
    """Parse one OpenSSH public-key line: a plain key or a certificate.

    The line must start with ``<key-type> <base64-body>``; anything after
    the body is ignored.  Key types ending in ``-cert-v01@openssh.com`` are
    parsed as certificates.

    :param data: Bytes-like object holding the line.
    :param _legacy_dsa_allowed: When false, DSA is rejected as a certified
        key and as a certificate signing key.  Plain DSA keys are accepted
        regardless.
    :returns: An ``SSHCertificate`` for certificate lines, otherwise the
        decoded public key.
    :raises ValueError: for malformed lines or encodings.
    :raises UnsupportedAlgorithm: for unsupported key types, or DSA inside
        a certificate when not allowed.
    """
    utils._check_byteslike("data", data)

    m = _SSH_PUBKEY_RC.match(data)
    if not m:
        raise ValueError("Invalid line format")
    key_type = orig_key_type = m.group(1)
    key_body = m.group(2)
    with_cert = False
    if key_type.endswith(_CERT_SUFFIX):
        with_cert = True
        key_type = key_type[: -len(_CERT_SUFFIX)]
        if key_type == _SSH_DSA and not _legacy_dsa_allowed:
            raise UnsupportedAlgorithm(
                "DSA keys aren't supported in SSH certificates"
            )
    kformat = _lookup_kformat(key_type)

    try:
        rest = memoryview(binascii.a2b_base64(key_body))
    except (TypeError, binascii.Error):
        raise ValueError("Invalid format")

    if with_cert:
        # A certificate body repeats the full key type, then has a nonce
        # ahead of the public key fields.
        cert_body = rest
        inner_key_type, rest = _get_sshstr(rest)
        if inner_key_type != orig_key_type:
            raise ValueError("Invalid key format")
        nonce, rest = _get_sshstr(rest)
    public_key, rest = kformat.load_public(rest)
    if not with_cert:
        _check_empty(rest)
        return public_key

    serial, rest = _get_u64(rest)
    cctype, rest = _get_u32(rest)
    key_id, rest = _get_sshstr(rest)
    principals, rest = _get_sshstr(rest)
    valid_principals = []
    while principals:
        principal, principals = _get_sshstr(principals)
        valid_principals.append(bytes(principal))
    valid_after, rest = _get_u64(rest)
    valid_before, rest = _get_u64(rest)
    crit_options, rest = _get_sshstr(rest)
    critical_options = _parse_exts_opts(crit_options)
    exts, rest = _get_sshstr(rest)
    extensions = _parse_exts_opts(exts)
    # Get the reserved field, which is unused.
    _, rest = _get_sshstr(rest)
    sig_key_raw, rest = _get_sshstr(rest)
    sig_type, sig_key = _get_sshstr(sig_key_raw)
    if sig_type == _SSH_DSA and not _legacy_dsa_allowed:
        raise UnsupportedAlgorithm(
            "DSA signatures aren't supported in SSH certificates"
        )
    # Everything before the signature field is the signed portion.  Slice
    # by absolute length: ``cert_body[: -len(rest)]`` would yield an empty
    # view if ``rest`` were empty.
    tbs_cert_body = cert_body[: len(cert_body) - len(rest)]
    signature_raw, rest = _get_sshstr(rest)
    _check_empty(rest)
    inner_sig_type, sig_rest = _get_sshstr(signature_raw)
    # RSA certs can have multiple algorithm types
    if (
        sig_type == _SSH_RSA
        and inner_sig_type
        not in [_SSH_RSA_SHA256, _SSH_RSA_SHA512, _SSH_RSA]
    ) or (sig_type != _SSH_RSA and inner_sig_type != sig_type):
        raise ValueError("Signature key type does not match")
    signature, sig_rest = _get_sshstr(sig_rest)
    _check_empty(sig_rest)
    return SSHCertificate(
        nonce,
        public_key,
        serial,
        cctype,
        key_id,
        valid_principals,
        valid_after,
        valid_before,
        critical_options,
        extensions,
        sig_type,
        sig_key,
        inner_sig_type,
        signature,
        tbs_cert_body,
        orig_key_type,
        cert_body,
    )
def load_ssh_public_identity(
    data: bytes,
) -> typing.Union[SSHCertificate, SSHPublicKeyTypes]:
    """Load an OpenSSH public key or certificate line.

    DSA is rejected inside certificates, both as the certified key and as
    the signing key.
    """
    return _load_ssh_public_identity(data, _legacy_dsa_allowed=False)
def _parse_exts_opts(exts_opts: memoryview) -> typing.Dict[bytes, bytes]:
    """Decode a certificate critical-options / extensions blob.

    The blob is a sequence of (name, value) SSH string pairs.  Names must
    be unique and appear in ascending byte order.

    :raises ValueError: on duplicate or unsorted names, or malformed data.
    """
    parsed: typing.Dict[bytes, bytes] = {}
    previous: typing.Optional[bytes] = None
    remaining = exts_opts
    while remaining:
        raw_name, remaining = _get_sshstr(remaining)
        name = bytes(raw_name)
        if name in parsed:
            raise ValueError("Duplicate name")
        if previous is not None and name < previous:
            raise ValueError("Fields not lexically sorted")
        raw_value, remaining = _get_sshstr(remaining)
        parsed[name] = bytes(raw_value)
        previous = name
    return parsed
def load_ssh_public_key(
    data: bytes, backend: typing.Any = None
) -> SSHPublicKeyTypes:
    """Load an OpenSSH public key line.

    If the line holds a certificate, the certified public key is returned.
    DSA keys are accepted but emit a deprecation warning.  ``backend`` is
    unused.
    """
    cert_or_key = _load_ssh_public_identity(data, _legacy_dsa_allowed=True)
    public_key: SSHPublicKeyTypes = (
        cert_or_key.public_key()
        if isinstance(cert_or_key, SSHCertificate)
        else cert_or_key
    )

    if isinstance(public_key, dsa.DSAPublicKey):
        warnings.warn(
            "SSH DSA keys are deprecated and will be removed in a future "
            "release.",
            utils.DeprecatedIn40,
            stacklevel=2,
        )
    return public_key
def serialize_ssh_public_key(public_key: SSHPublicKeyTypes) -> bytes:
    """Return the one-line OpenSSH form ``<key-type> <base64-blob>``.

    DSA keys are still serialized but emit a deprecation warning.
    """
    if isinstance(public_key, dsa.DSAPublicKey):
        warnings.warn(
            "SSH DSA key support is deprecated and will be "
            "removed in a future release",
            utils.DeprecatedIn40,
            stacklevel=4,
        )
    key_type = _get_ssh_key_type(public_key)

    blob = _FragList()
    blob.put_sshstr(key_type)
    _lookup_kformat(key_type).encode_public(public_key, blob)

    encoded = binascii.b2a_base64(blob.tobytes(), newline=False)
    return key_type + b" " + encoded
# Private key counterparts of ``SSHCertPublicKeyTypes`` (DSA excluded).
SSHCertPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]


# This is an undocumented limit enforced in the openssh codebase for sshd and
# ssh-keygen, but it is undefined in the ssh certificates spec.
_SSHKEY_CERT_MAX_PRINCIPALS = 256
class SSHCertificateBuilder:
    """Build and sign OpenSSH certificates.

    Each setter validates its argument and returns a new builder with that
    single field filled in. The receiver is never modified, so a partially
    configured builder can be reused safely. Each field may be set only
    once. :meth:`sign` checks that the required fields are present and
    produces the signed :class:`SSHCertificate`.
    """

    def __init__(
        self,
        _public_key: typing.Optional[SSHCertPublicKeyTypes] = None,
        _serial: typing.Optional[int] = None,
        _type: typing.Optional[SSHCertificateType] = None,
        _key_id: typing.Optional[bytes] = None,
        _valid_principals: typing.Optional[typing.List[bytes]] = None,
        _valid_for_all_principals: bool = False,
        _valid_before: typing.Optional[int] = None,
        _valid_after: typing.Optional[int] = None,
        _critical_options: typing.Optional[
            typing.List[typing.Tuple[bytes, bytes]]
        ] = None,
        _extensions: typing.Optional[
            typing.List[typing.Tuple[bytes, bytes]]
        ] = None,
    ):
        # The setters below pass the accumulated state back in through
        # these underscore-prefixed parameters.
        self._public_key = _public_key
        self._serial = _serial
        self._type = _type
        self._key_id = _key_id
        # Take private copies of the list arguments. Mutable ``[]``
        # defaults would be shared by every builder, and storing a caller's
        # list by reference would let later mutation leak into this builder.
        self._valid_principals: typing.List[bytes] = list(
            _valid_principals or ()
        )
        self._valid_for_all_principals = _valid_for_all_principals
        self._valid_before = _valid_before
        self._valid_after = _valid_after
        self._critical_options: typing.List[
            typing.Tuple[bytes, bytes]
        ] = list(_critical_options or ())
        self._extensions: typing.List[typing.Tuple[bytes, bytes]] = list(
            _extensions or ()
        )

    def _replace(self, **changes: typing.Any) -> "SSHCertificateBuilder":
        """Return a new builder with this builder's state overridden by *changes*.

        Keys of *changes* are ``__init__`` parameter names (e.g. ``_serial``).
        The receiver is left untouched.
        """
        state: typing.Dict[str, typing.Any] = {
            "_public_key": self._public_key,
            "_serial": self._serial,
            "_type": self._type,
            "_key_id": self._key_id,
            "_valid_principals": self._valid_principals,
            "_valid_for_all_principals": self._valid_for_all_principals,
            "_valid_before": self._valid_before,
            "_valid_after": self._valid_after,
            "_critical_options": self._critical_options,
            "_extensions": self._extensions,
        }
        state.update(changes)
        return SSHCertificateBuilder(**state)

    def public_key(
        self, public_key: SSHCertPublicKeyTypes
    ) -> "SSHCertificateBuilder":
        """Set the public key the certificate is issued for.

        Raises:
            TypeError: if *public_key* is not an EC, RSA or Ed25519 public key.
            ValueError: if the public key was already set.
        """
        if not isinstance(
            public_key,
            (
                ec.EllipticCurvePublicKey,
                rsa.RSAPublicKey,
                ed25519.Ed25519PublicKey,
            ),
        ):
            raise TypeError("Unsupported key type")
        if self._public_key is not None:
            raise ValueError("public_key already set")

        return self._replace(_public_key=public_key)

    def serial(self, serial: int) -> "SSHCertificateBuilder":
        """Set the certificate serial number, which must fit in 64 bits.

        Raises:
            TypeError: if *serial* is not an int.
            ValueError: if *serial* is outside ``[0, 2**64)`` or already set.
        """
        if not isinstance(serial, int):
            raise TypeError("serial must be an integer")
        if not 0 <= serial < 2**64:
            raise ValueError("serial must be between 0 and 2**64")
        if self._serial is not None:
            raise ValueError("serial already set")

        return self._replace(_serial=serial)

    def type(self, type: SSHCertificateType) -> "SSHCertificateBuilder":
        """Set the certificate type.

        Raises:
            TypeError: if *type* is not an ``SSHCertificateType``.
            ValueError: if the type was already set.
        """
        if not isinstance(type, SSHCertificateType):
            raise TypeError("type must be an SSHCertificateType")
        if self._type is not None:
            raise ValueError("type already set")

        return self._replace(_type=type)

    def key_id(self, key_id: bytes) -> "SSHCertificateBuilder":
        """Set the certificate key id.

        Raises:
            TypeError: if *key_id* is not bytes.
            ValueError: if the key id was already set.
        """
        if not isinstance(key_id, bytes):
            raise TypeError("key_id must be bytes")
        if self._key_id is not None:
            raise ValueError("key_id already set")

        return self._replace(_key_id=key_id)

    def valid_principals(
        self, valid_principals: typing.List[bytes]
    ) -> "SSHCertificateBuilder":
        """Set the non-empty list of principals the certificate is valid for.

        The list is copied, so later changes to *valid_principals* do not
        affect the returned builder.

        Raises:
            ValueError: if valid_for_all_principals was already set, if the
                principals were already set, or if more than
                ``_SSHKEY_CERT_MAX_PRINCIPALS`` principals are given.
            TypeError: if the list is empty or contains non-bytes entries.
        """
        if self._valid_for_all_principals:
            raise ValueError(
                "Principals can't be set because the cert is valid "
                "for all principals"
            )
        if (
            not all(isinstance(x, bytes) for x in valid_principals)
            or not valid_principals
        ):
            raise TypeError(
                "principals must be a list of bytes and can't be empty"
            )
        if self._valid_principals:
            raise ValueError("valid_principals already set")

        if len(valid_principals) > _SSHKEY_CERT_MAX_PRINCIPALS:
            raise ValueError(
                "Reached or exceeded the maximum number of valid_principals"
            )

        return self._replace(_valid_principals=valid_principals)

    def valid_for_all_principals(self) -> "SSHCertificateBuilder":
        """Mark the certificate as valid for any principal.

        On the wire this is an empty principals list. See the matching
        check in :meth:`sign`.

        Raises:
            ValueError: if explicit principals were set, or this flag was
                already set.
        """
        if self._valid_principals:
            raise ValueError(
                "valid_principals already set, can't set "
                "valid_for_all_principals"
            )
        if self._valid_for_all_principals:
            raise ValueError("valid_for_all_principals already set")

        return self._replace(_valid_for_all_principals=True)

    def valid_before(
        self, valid_before: typing.Union[int, float]
    ) -> "SSHCertificateBuilder":
        """Set the end of the validity window (truncated to an int).

        Raises:
            TypeError: if *valid_before* is not an int or float.
            ValueError: if it is outside ``[0, 2**64)`` or already set.
        """
        if not isinstance(valid_before, (int, float)):
            raise TypeError("valid_before must be an int or float")
        valid_before = int(valid_before)
        if valid_before < 0 or valid_before >= 2**64:
            raise ValueError("valid_before must [0, 2**64)")
        if self._valid_before is not None:
            raise ValueError("valid_before already set")

        return self._replace(_valid_before=valid_before)

    def valid_after(
        self, valid_after: typing.Union[int, float]
    ) -> "SSHCertificateBuilder":
        """Set the start of the validity window (truncated to an int).

        Raises:
            TypeError: if *valid_after* is not an int or float.
            ValueError: if it is outside ``[0, 2**64)`` or already set.
        """
        if not isinstance(valid_after, (int, float)):
            raise TypeError("valid_after must be an int or float")
        valid_after = int(valid_after)
        if valid_after < 0 or valid_after >= 2**64:
            raise ValueError("valid_after must [0, 2**64)")
        if self._valid_after is not None:
            raise ValueError("valid_after already set")

        return self._replace(_valid_after=valid_after)

    def add_critical_option(
        self, name: bytes, value: bytes
    ) -> "SSHCertificateBuilder":
        """Append a critical option; names must be unique.

        Raises:
            TypeError: if *name* or *value* is not bytes.
            ValueError: if an option with *name* already exists.
        """
        if not isinstance(name, bytes) or not isinstance(value, bytes):
            raise TypeError("name and value must be bytes")
        # Linear scan per call, so adding n options is O(n**2) overall.
        if any(existing == name for existing, _ in self._critical_options):
            raise ValueError("Duplicate critical option name")

        return self._replace(
            _critical_options=self._critical_options + [(name, value)]
        )

    def add_extension(
        self, name: bytes, value: bytes
    ) -> "SSHCertificateBuilder":
        """Append an extension; names must be unique.

        Raises:
            TypeError: if *name* or *value* is not bytes.
            ValueError: if an extension with *name* already exists.
        """
        if not isinstance(name, bytes) or not isinstance(value, bytes):
            raise TypeError("name and value must be bytes")
        # Linear scan per call, so adding n extensions is O(n**2) overall.
        if any(existing == name for existing, _ in self._extensions):
            raise ValueError("Duplicate extension name")

        return self._replace(_extensions=self._extensions + [(name, value)])

    def sign(self, private_key: SSHCertPrivateKeyTypes) -> SSHCertificate:
        """Sign the certificate with the CA key *private_key*.

        ``serial`` defaults to 0 and ``key_id`` to ``b""`` when unset.
        Every other field checked below is required. A fresh 32-byte nonce
        from ``os.urandom`` is included in the signed data. This builder is
        not modified.

        Raises:
            TypeError: if *private_key* is not an EC, RSA or Ed25519
                private key.
            ValueError: if a required field is missing, if neither valid
                principals nor valid_for_all_principals was set, or if
                valid_after is later than valid_before.

        Returns:
            The signed certificate, obtained by re-parsing the encoded
            ``<cert type> <base64>`` line with ``load_ssh_public_identity``.
        """
        if not isinstance(
            private_key,
            (
                ec.EllipticCurvePrivateKey,
                rsa.RSAPrivateKey,
                ed25519.Ed25519PrivateKey,
            ),
        ):
            raise TypeError("Unsupported private key type")

        if self._public_key is None:
            raise ValueError("public_key must be set")

        # Not required
        serial = 0 if self._serial is None else self._serial

        if self._type is None:
            raise ValueError("type must be set")

        # Not required
        key_id = b"" if self._key_id is None else self._key_id

        # A zero length list is valid, but means the certificate
        # is valid for any principal of the specified type. We require
        # the user to explicitly set valid_for_all_principals to get
        # that behavior.
        if not self._valid_principals and not self._valid_for_all_principals:
            raise ValueError(
                "valid_principals must be set if valid_for_all_principals "
                "is False"
            )

        if self._valid_before is None:
            raise ValueError("valid_before must be set")

        if self._valid_after is None:
            raise ValueError("valid_after must be set")

        if self._valid_after > self._valid_before:
            raise ValueError("valid_after must be earlier than valid_before")

        # Lexically sort by name. Sort copies rather than in place: the
        # option lists may be shared with other builders in the same chain.
        critical_options = sorted(self._critical_options, key=lambda x: x[0])
        extensions = sorted(self._extensions, key=lambda x: x[0])

        key_type = _get_ssh_key_type(self._public_key)
        cert_prefix = key_type + _CERT_SUFFIX

        # Marshal the bytes to be signed
        nonce = os.urandom(32)
        kformat = _lookup_kformat(key_type)
        f = _FragList()
        f.put_sshstr(cert_prefix)
        f.put_sshstr(nonce)
        kformat.encode_public(self._public_key, f)
        f.put_u64(serial)
        f.put_u32(self._type.value)
        f.put_sshstr(key_id)
        fprincipals = _FragList()
        for p in self._valid_principals:
            fprincipals.put_sshstr(p)
        f.put_sshstr(fprincipals.tobytes())
        f.put_u64(self._valid_after)
        f.put_u64(self._valid_before)
        fcrit = _FragList()
        for name, value in critical_options:
            fcrit.put_sshstr(name)
            fcrit.put_sshstr(value)
        f.put_sshstr(fcrit.tobytes())
        fext = _FragList()
        for name, value in extensions:
            fext.put_sshstr(name)
            fext.put_sshstr(value)
        f.put_sshstr(fext.tobytes())
        f.put_sshstr(b"")  # RESERVED FIELD
        # encode CA public key
        ca_type = _get_ssh_key_type(private_key)
        caformat = _lookup_kformat(ca_type)
        caf = _FragList()
        caf.put_sshstr(ca_type)
        caformat.encode_public(private_key.public_key(), caf)
        f.put_sshstr(caf.tobytes())
        # Sigs according to the rules defined for the CA's public key
        # (RFC4253 section 6.6 for ssh-rsa, RFC5656 for ECDSA,
        # and RFC8032 for Ed25519).
        if isinstance(private_key, ed25519.Ed25519PrivateKey):
            signature = private_key.sign(f.tobytes())
            fsig = _FragList()
            fsig.put_sshstr(ca_type)
            fsig.put_sshstr(signature)
            f.put_sshstr(fsig.tobytes())
        elif isinstance(private_key, ec.EllipticCurvePrivateKey):
            hash_alg = _get_ec_hash_alg(private_key.curve)
            signature = private_key.sign(f.tobytes(), ec.ECDSA(hash_alg))
            # The DER signature is split into r and s, which are then
            # re-encoded as two SSH mpints.
            r, s = asym_utils.decode_dss_signature(signature)
            fsig = _FragList()
            fsig.put_sshstr(ca_type)
            fsigblob = _FragList()
            fsigblob.put_mpint(r)
            fsigblob.put_mpint(s)
            fsig.put_sshstr(fsigblob.tobytes())
            f.put_sshstr(fsig.tobytes())

        else:
            assert isinstance(private_key, rsa.RSAPrivateKey)
            # Just like Golang, we're going to use SHA512 for RSA
            # https://cs.opensource.google/go/x/crypto/+/refs/tags/
            # v0.4.0:ssh/certs.go;l=445
            # RFC 8332 defines SHA256 and 512 as options
            fsig = _FragList()
            fsig.put_sshstr(_SSH_RSA_SHA512)
            signature = private_key.sign(
                f.tobytes(), padding.PKCS1v15(), hashes.SHA512()
            )
            fsig.put_sshstr(signature)
            f.put_sshstr(fsig.tobytes())

        cert_data = binascii.b2a_base64(f.tobytes()).strip()
        # load_ssh_public_identity returns a union, but this is
        # guaranteed to be an SSHCertificate, so we cast to make
        # mypy happy.
        return typing.cast(
            SSHCertificate,
            load_ssh_public_identity(b"".join([cert_prefix, b" ", cert_data])),
        )