Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/cryptography/hazmat/primitives/serialization/ssh.py: 21%
758 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 07:26 +0000
1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
5from __future__ import annotations
7import binascii
8import enum
9import os
10import re
11import typing
12import warnings
13from base64 import encodebytes as _base64_encode
14from dataclasses import dataclass
16from cryptography import utils
17from cryptography.exceptions import UnsupportedAlgorithm
18from cryptography.hazmat.primitives import hashes
19from cryptography.hazmat.primitives.asymmetric import (
20 dsa,
21 ec,
22 ed25519,
23 padding,
24 rsa,
25)
26from cryptography.hazmat.primitives.asymmetric import utils as asym_utils
27from cryptography.hazmat.primitives.ciphers import (
28 AEADDecryptionContext,
29 Cipher,
30 algorithms,
31 modes,
32)
33from cryptography.hazmat.primitives.serialization import (
34 Encoding,
35 KeySerializationEncryption,
36 NoEncryption,
37 PrivateFormat,
38 PublicFormat,
39 _KeySerializationEncryption,
40)
# bcrypt is optional: it provides the KDF that ``_init_cipher`` uses to derive
# the cipher key and IV for password-protected private keys.
try:
    from bcrypt import kdf as _bcrypt_kdf

    _bcrypt_supported = True
except ImportError:
    _bcrypt_supported = False

    def _bcrypt_kdf(
        password: bytes,
        salt: bytes,
        desired_key_bytes: int,
        rounds: int,
        ignore_few_rounds: bool = False,
    ) -> bytes:
        """Stand-in used when bcrypt is not installed; always raises.

        Raises:
            UnsupportedAlgorithm: unconditionally.
        """
        raise UnsupportedAlgorithm("Need bcrypt module")
# SSH public key type identifiers.
_SSH_ED25519 = b"ssh-ed25519"
_SSH_RSA = b"ssh-rsa"
_SSH_DSA = b"ssh-dss"
_ECDSA_NISTP256 = b"ecdsa-sha2-nistp256"
_ECDSA_NISTP384 = b"ecdsa-sha2-nistp384"
_ECDSA_NISTP521 = b"ecdsa-sha2-nistp521"
# Appended to a key type to form the matching certificate type.
_CERT_SUFFIX = b"-cert-v01@openssh.com"

# These are not key types, only algorithms, so they cannot appear
# as a public key type
_SSH_RSA_SHA256 = b"rsa-sha2-256"
_SSH_RSA_SHA512 = b"rsa-sha2-512"

# Matches "<key-type><spaces/tabs><base64-body>" at the start of a line;
# anything after the body is not matched.
_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)")
# Magic prefix of the base64-decoded OpenSSH private key blob.
_SK_MAGIC = b"openssh-key-v1\0"
_SK_START = b"-----BEGIN OPENSSH PRIVATE KEY-----"
_SK_END = b"-----END OPENSSH PRIVATE KEY-----"
# KDF / cipher names as they appear in the private key header.
_BCRYPT = b"bcrypt"
_NONE = b"none"
# Defaults used when serializing a key with a password.
_DEFAULT_CIPHER = b"aes256-ctr"
_DEFAULT_ROUNDS = 16

# A regex is used because it can search any bytes-like object directly.
_PEM_RC = re.compile(_SK_START + b"(.*?)" + _SK_END, re.DOTALL)

# Padding bytes 1, 2, ..., 16: enough for the largest block size in use.
_PADDING = memoryview(bytearray(range(1, 1 + 16)))
@dataclass
class _SSHCipher:
    """Parameters of a symmetric cipher used to wrap OpenSSH private keys."""

    # Cipher algorithm class, instantiated with the derived key.
    alg: type[algorithms.AES]
    # Key length in bytes.
    key_len: int
    # Mode class, instantiated with the derived IV / nonce.
    mode: type[modes.CTR] | type[modes.CBC] | type[modes.GCM]
    # Block size in bytes; encrypted data must be a multiple of it.
    block_len: int
    # IV / nonce length in bytes.
    iv_len: int
    # Authentication tag length in bytes, or None for non-AEAD modes.
    tag_len: int | None
    # True when the mode is AEAD (the tag follows the encrypted section).
    is_aead: bool
# ciphers that are actually used in key wrapping, keyed by the cipher name
# stored in the OpenSSH private key header
_SSH_CIPHERS: dict[bytes, _SSHCipher] = {
    b"aes256-ctr": _SSHCipher(
        alg=algorithms.AES,
        key_len=32,
        mode=modes.CTR,
        block_len=16,
        iv_len=16,
        tag_len=None,
        is_aead=False,
    ),
    b"aes256-cbc": _SSHCipher(
        alg=algorithms.AES,
        key_len=32,
        mode=modes.CBC,
        block_len=16,
        iv_len=16,
        tag_len=None,
        is_aead=False,
    ),
    b"aes256-gcm@openssh.com": _SSHCipher(
        alg=algorithms.AES,
        key_len=32,
        mode=modes.GCM,
        block_len=16,
        # GCM uses a 12-byte nonce and a 16-byte tag.
        iv_len=12,
        tag_len=16,
        is_aead=True,
    ),
}
# map local curve name (``EllipticCurve.name``) to SSH ECDSA key type
_ECDSA_KEY_TYPE = {
    "secp256r1": _ECDSA_NISTP256,
    "secp384r1": _ECDSA_NISTP384,
    "secp521r1": _ECDSA_NISTP521,
}
def _get_ssh_key_type(key: SSHPrivateKeyTypes | SSHPublicKeyTypes) -> bytes:
    """Return the SSH key type identifier (e.g. ``b"ssh-rsa"``) for *key*.

    Both private and public keys are accepted; for EC keys the identifier
    depends on the curve.

    Raises:
        ValueError: if *key* is not an SSH-supported key type, or is an EC
            key on a curve SSH does not support.
    """
    if isinstance(key, ec.EllipticCurvePrivateKey):
        return _ecdsa_key_type(key.public_key())
    if isinstance(key, ec.EllipticCurvePublicKey):
        return _ecdsa_key_type(key)
    if isinstance(key, (rsa.RSAPrivateKey, rsa.RSAPublicKey)):
        return _SSH_RSA
    if isinstance(key, (dsa.DSAPrivateKey, dsa.DSAPublicKey)):
        return _SSH_DSA
    if isinstance(key, (ed25519.Ed25519PrivateKey, ed25519.Ed25519PublicKey)):
        return _SSH_ED25519
    raise ValueError("Unsupported key type")
def _ecdsa_key_type(public_key: ec.EllipticCurvePublicKey) -> bytes:
    """Return the SSH ECDSA key type (e.g. ``b"ecdsa-sha2-nistp256"``).

    Args:
        public_key: EC public key whose curve selects the key type.

    Raises:
        ValueError: if the key's curve has no SSH ECDSA key type.
    """
    curve = public_key.curve
    try:
        return _ECDSA_KEY_TYPE[curve.name]
    except KeyError:
        raise ValueError(
            f"Unsupported curve for ssh private key: {curve.name!r}"
        ) from None
def _ssh_pem_encode(
    data: bytes,
    prefix: bytes = _SK_START + b"\n",
    suffix: bytes = _SK_END + b"\n",
) -> bytes:
    """Armor *data* as an OpenSSH private key block.

    The payload is base64-encoded with ``base64.encodebytes``, which emits
    newline-terminated lines. The result is placed between *prefix* and
    *suffix*.
    """
    body = _base64_encode(data)
    pieces = (prefix, body, suffix)
    return b"".join(pieces)
175def _check_block_size(data: bytes, block_len: int) -> None:
176 """Require data to be full blocks"""
177 if not data or len(data) % block_len != 0:
178 raise ValueError("Corrupt data: missing padding")
181def _check_empty(data: bytes) -> None:
182 """All data should have been parsed."""
183 if data:
184 raise ValueError("Corrupt data: unparsed data")
def _init_cipher(
    ciphername: bytes,
    password: bytes | None,
    salt: bytes,
    rounds: int,
) -> Cipher[modes.CBC | modes.CTR | modes.GCM]:
    """Derive key material from *password* with bcrypt and build the cipher.

    The KDF output is ``key_len + iv_len`` bytes: the leading ``key_len``
    bytes are the key, and the remainder is the IV (or GCM nonce).

    Raises:
        ValueError: if *password* is None or empty.
        KeyError: if *ciphername* is not a key of ``_SSH_CIPHERS``.
        UnsupportedAlgorithm: propagated from ``_bcrypt_kdf`` when bcrypt
            is unavailable.
    """
    if not password:
        raise ValueError("Key is password-protected.")

    spec = _SSH_CIPHERS[ciphername]
    derived = _bcrypt_kdf(
        password, salt, spec.key_len + spec.iv_len, rounds, True
    )
    key = derived[: spec.key_len]
    iv = derived[spec.key_len :]
    return Cipher(spec.alg(key), spec.mode(iv))
207def _get_u32(data: memoryview) -> tuple[int, memoryview]:
208 """Uint32"""
209 if len(data) < 4:
210 raise ValueError("Invalid data")
211 return int.from_bytes(data[:4], byteorder="big"), data[4:]
214def _get_u64(data: memoryview) -> tuple[int, memoryview]:
215 """Uint64"""
216 if len(data) < 8:
217 raise ValueError("Invalid data")
218 return int.from_bytes(data[:8], byteorder="big"), data[8:]
def _get_sshstr(data: memoryview) -> tuple[memoryview, memoryview]:
    """Split off an SSH ``string``: a uint32 length then that many bytes.

    Returns:
        ``(value, remaining)``.

    Raises:
        ValueError: if the length prefix is truncated or exceeds the data.
    """
    length, body = _get_u32(data)
    if length <= len(body):
        return body[:length], body[length:]
    raise ValueError("Invalid data")
def _get_mpint(data: memoryview) -> tuple[int, memoryview]:
    """Decode an SSH ``mpint`` (length-prefixed big-endian integer).

    Only non-negative values are accepted: a first byte with its high bit
    set (a negative mpint) is rejected.

    Raises:
        ValueError: on truncated data or a negative value.
    """
    raw, remaining = _get_sshstr(data)
    is_negative = len(raw) > 0 and raw[0] >= 0x80
    if is_negative:
        raise ValueError("Invalid data")
    return int.from_bytes(raw, "big"), remaining
def _to_mpint(val: int) -> bytes:
    """Encode a non-negative int as SSH ``mpint`` content (no length prefix).

    Zero encodes as empty bytes. For other values the byte count reserves
    one extra bit, so a leading zero byte is added whenever the top bit
    would otherwise be set.

    Raises:
        ValueError: if *val* is negative.
    """
    if val < 0:
        raise ValueError("negative mpint not allowed")
    if val == 0:
        return b""
    return utils.int_to_bytes(val, val.bit_length() // 8 + 1)
247class _FragList:
248 """Build recursive structure without data copy."""
250 flist: list[bytes]
252 def __init__(self, init: list[bytes] | None = None) -> None:
253 self.flist = []
254 if init:
255 self.flist.extend(init)
257 def put_raw(self, val: bytes) -> None:
258 """Add plain bytes"""
259 self.flist.append(val)
261 def put_u32(self, val: int) -> None:
262 """Big-endian uint32"""
263 self.flist.append(val.to_bytes(length=4, byteorder="big"))
265 def put_u64(self, val: int) -> None:
266 """Big-endian uint64"""
267 self.flist.append(val.to_bytes(length=8, byteorder="big"))
269 def put_sshstr(self, val: bytes | _FragList) -> None:
270 """Bytes prefixed with u32 length"""
271 if isinstance(val, (bytes, memoryview, bytearray)):
272 self.put_u32(len(val))
273 self.flist.append(val)
274 else:
275 self.put_u32(val.size())
276 self.flist.extend(val.flist)
278 def put_mpint(self, val: int) -> None:
279 """Big-endian bigint prefixed with u32 length"""
280 self.put_sshstr(_to_mpint(val))
282 def size(self) -> int:
283 """Current number of bytes"""
284 return sum(map(len, self.flist))
286 def render(self, dstbuf: memoryview, pos: int = 0) -> int:
287 """Write into bytearray"""
288 for frag in self.flist:
289 flen = len(frag)
290 start, pos = pos, pos + flen
291 dstbuf[start:pos] = frag
292 return pos
294 def tobytes(self) -> bytes:
295 """Return as bytes"""
296 buf = memoryview(bytearray(self.size()))
297 self.render(buf)
298 return buf.tobytes()
class _SSHFormatRSA:
    """Wire format for RSA keys.

    Public:
        mpint e, n
    Private:
        mpint n, e, d, iqmp, p, q
    """

    def get_public(self, data: memoryview):
        """Parse the RSA public fields; return ``((e, n), remaining)``."""
        e, rest = _get_mpint(data)
        n, rest = _get_mpint(rest)
        return (e, n), rest

    def load_public(
        self, data: memoryview
    ) -> tuple[rsa.RSAPublicKey, memoryview]:
        """Build an RSA public key from the wire data."""
        (e, n), rest = self.get_public(data)
        return rsa.RSAPublicNumbers(e, n).public_key(), rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> tuple[rsa.RSAPrivateKey, memoryview]:
        """Build an RSA private key; *pubfields* must equal ``(e, n)``.

        Raises:
            ValueError: if the private section's (e, n) differ from the
                public section's.
        """
        values = []
        rest = data
        for _ in range(6):
            value, rest = _get_mpint(rest)
            values.append(value)
        n, e, d, iqmp, p, q = values

        if (e, n) != pubfields:
            raise ValueError("Corrupt data: rsa field mismatch")
        # The CRT exponents are not stored in the SSH format; derive them.
        dmp1 = rsa.rsa_crt_dmp1(d, p)
        dmq1 = rsa.rsa_crt_dmq1(d, q)
        public_numbers = rsa.RSAPublicNumbers(e, n)
        private_numbers = rsa.RSAPrivateNumbers(
            p, q, d, dmp1, dmq1, iqmp, public_numbers
        )
        return private_numbers.private_key(), rest

    def encode_public(
        self, public_key: rsa.RSAPublicKey, f_pub: _FragList
    ) -> None:
        """Write the RSA public fields (e, n)."""
        numbers = public_key.public_numbers()
        for value in (numbers.e, numbers.n):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: rsa.RSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Write the RSA private fields (n, e, d, iqmp, p, q)."""
        priv = private_key.private_numbers()
        pub = priv.public_numbers
        for value in (pub.n, pub.e, priv.d, priv.iqmp, priv.p, priv.q):
            f_priv.put_mpint(value)
class _SSHFormatDSA:
    """Wire format for DSA keys.

    Public:
        mpint p, q, g, y
    Private:
        mpint p, q, g, y, x
    """

    def get_public(self, data: memoryview) -> tuple[tuple, memoryview]:
        """Parse the DSA public fields; return ``((p, q, g, y), remaining)``."""
        values = []
        for _ in range(4):
            value, data = _get_mpint(data)
            values.append(value)
        return tuple(values), data

    def load_public(
        self, data: memoryview
    ) -> tuple[dsa.DSAPublicKey, memoryview]:
        """Build a DSA public key from the wire data (1024-bit p only)."""
        (p, q, g, y), rest = self.get_public(data)
        public_numbers = dsa.DSAPublicNumbers(
            y, dsa.DSAParameterNumbers(p, q, g)
        )
        self._validate(public_numbers)
        return public_numbers.public_key(), rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> tuple[dsa.DSAPrivateKey, memoryview]:
        """Build a DSA private key; *pubfields* must equal ``(p, q, g, y)``.

        Raises:
            ValueError: on a public-field mismatch or a non-1024-bit p.
        """
        fields, rest = self.get_public(data)
        x, rest = _get_mpint(rest)

        if fields != pubfields:
            raise ValueError("Corrupt data: dsa field mismatch")
        p, q, g, y = fields
        public_numbers = dsa.DSAPublicNumbers(
            y, dsa.DSAParameterNumbers(p, q, g)
        )
        self._validate(public_numbers)
        private_key = dsa.DSAPrivateNumbers(x, public_numbers).private_key()
        return private_key, rest

    def encode_public(
        self, public_key: dsa.DSAPublicKey, f_pub: _FragList
    ) -> None:
        """Write the DSA public fields (p, q, g, y)."""
        numbers = public_key.public_numbers()
        params = numbers.parameter_numbers
        self._validate(numbers)
        for value in (params.p, params.q, params.g, numbers.y):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: dsa.DSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Write the DSA public fields followed by the private value x."""
        self.encode_public(private_key.public_key(), f_priv)
        f_priv.put_mpint(private_key.private_numbers().x)

    def _validate(self, public_numbers: dsa.DSAPublicNumbers) -> None:
        """Reject DSA keys whose p is not exactly 1024 bits."""
        if public_numbers.parameter_numbers.p.bit_length() == 1024:
            return
        raise ValueError("SSH supports only 1024 bit DSA keys")
class _SSHFormatECDSA:
    """Wire format for ECDSA keys.

    Public:
        str curve
        bytes point
    Private:
        str curve
        bytes point
        mpint secret
    """

    def __init__(self, ssh_curve_name: bytes, curve: ec.EllipticCurve):
        # SSH curve identifier (e.g. b"nistp256") expected in the wire data.
        self.ssh_curve_name = ssh_curve_name
        # Curve object used when constructing keys.
        self.curve = curve

    def get_public(self, data: memoryview) -> tuple[tuple, memoryview]:
        """Parse the ECDSA public fields; return ``((curve, point), rest)``.

        Raises:
            ValueError: if the curve name does not match this format, or
                the point is empty.
            NotImplementedError: if the point is not uncompressed (its first
                byte is not 0x04).
        """
        curve, data = _get_sshstr(data)
        point, data = _get_sshstr(data)
        if curve != self.ssh_curve_name:
            raise ValueError("Curve name mismatch")
        # An empty point would otherwise escape as IndexError from point[0].
        if not point:
            raise ValueError("Invalid data")
        if point[0] != 4:
            raise NotImplementedError("Need uncompressed point")
        return (curve, point), data

    def load_public(
        self, data: memoryview
    ) -> tuple[ec.EllipticCurvePublicKey, memoryview]:
        """Make an ECDSA public key from the wire data."""
        (_, point), data = self.get_public(data)
        public_key = ec.EllipticCurvePublicKey.from_encoded_point(
            self.curve, point.tobytes()
        )
        return public_key, data

    def load_private(
        self, data: memoryview, pubfields
    ) -> tuple[ec.EllipticCurvePrivateKey, memoryview]:
        """Make an ECDSA private key; *pubfields* must equal (curve, point).

        Raises:
            ValueError: if the private section's public fields differ from
                *pubfields*.
        """
        (curve_name, point), data = self.get_public(data)
        secret, data = _get_mpint(data)

        if (curve_name, point) != pubfields:
            raise ValueError("Corrupt data: ecdsa field mismatch")
        private_key = ec.derive_private_key(secret, self.curve)
        return private_key, data

    def encode_public(
        self, public_key: ec.EllipticCurvePublicKey, f_pub: _FragList
    ) -> None:
        """Write the curve name and the uncompressed X9.62 point."""
        point = public_key.public_bytes(
            Encoding.X962, PublicFormat.UncompressedPoint
        )
        f_pub.put_sshstr(self.ssh_curve_name)
        f_pub.put_sshstr(point)

    def encode_private(
        self, private_key: ec.EllipticCurvePrivateKey, f_priv: _FragList
    ) -> None:
        """Write the public fields followed by the private scalar."""
        public_key = private_key.public_key()
        private_numbers = private_key.private_numbers()

        self.encode_public(public_key, f_priv)
        f_priv.put_mpint(private_numbers.private_value)
class _SSHFormatEd25519:
    """Wire format for Ed25519 keys.

    Public:
        bytes point
    Private:
        bytes point
        bytes secret_and_point
    """

    def get_public(self, data: memoryview) -> tuple[tuple, memoryview]:
        """Parse the Ed25519 public field; return ``((point,), remaining)``."""
        point, rest = _get_sshstr(data)
        return (point,), rest

    def load_public(
        self, data: memoryview
    ) -> tuple[ed25519.Ed25519PublicKey, memoryview]:
        """Build an Ed25519 public key from the wire data."""
        fields, rest = self.get_public(data)
        raw_point = fields[0].tobytes()
        return ed25519.Ed25519PublicKey.from_public_bytes(raw_point), rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> tuple[ed25519.Ed25519PrivateKey, memoryview]:
        """Build an Ed25519 private key.

        The key-pair field holds the 32-byte secret followed by the public
        point, which must match both the preceding point field and
        *pubfields*.

        Raises:
            ValueError: on any point mismatch.
        """
        (point,), rest = self.get_public(data)
        keypair, rest = _get_sshstr(rest)

        secret, embedded_point = keypair[:32], keypair[32:]
        if point != embedded_point or (point,) != pubfields:
            raise ValueError("Corrupt data: ed25519 field mismatch")
        return ed25519.Ed25519PrivateKey.from_private_bytes(secret), rest

    def encode_public(
        self, public_key: ed25519.Ed25519PublicKey, f_pub: _FragList
    ) -> None:
        """Write the raw 32-byte public point."""
        f_pub.put_sshstr(public_key.public_bytes(Encoding.Raw, PublicFormat.Raw))

    def encode_private(
        self, private_key: ed25519.Ed25519PrivateKey, f_priv: _FragList
    ) -> None:
        """Write the public point, then the secret+point key-pair field."""
        raw_private = private_key.private_bytes(
            Encoding.Raw, PrivateFormat.Raw, NoEncryption()
        )
        public_key = private_key.public_key()
        raw_public = public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)

        self.encode_public(public_key, f_priv)
        f_priv.put_sshstr(_FragList([raw_private, raw_public]))
# Wire-format handler for each SSH key type. Each ECDSA handler also carries
# the SSH curve name it expects and the curve used to build keys.
_KEY_FORMATS = {
    _SSH_RSA: _SSHFormatRSA(),
    _SSH_DSA: _SSHFormatDSA(),
    _SSH_ED25519: _SSHFormatEd25519(),
    _ECDSA_NISTP256: _SSHFormatECDSA(b"nistp256", ec.SECP256R1()),
    _ECDSA_NISTP384: _SSHFormatECDSA(b"nistp384", ec.SECP384R1()),
    _ECDSA_NISTP521: _SSHFormatECDSA(b"nistp521", ec.SECP521R1()),
}
def _lookup_kformat(key_type: bytes):
    """Return the wire-format handler registered for *key_type*.

    *key_type* may be any bytes-like object (e.g. a memoryview slice); it
    is converted to ``bytes`` before the lookup.

    Raises:
        UnsupportedAlgorithm: if no handler exists for the key type.
    """
    name = (
        key_type
        if isinstance(key_type, bytes)
        else memoryview(key_type).tobytes()
    )
    kformat = _KEY_FORMATS.get(name)
    if kformat is None:
        raise UnsupportedAlgorithm(f"Unsupported key type: {name!r}")
    return kformat
# Private key types that can be loaded from or serialized to the OpenSSH
# private key format.
SSHPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    dsa.DSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]
def load_ssh_private_key(
    data: bytes,
    password: bytes | None,
    backend: typing.Any = None,
) -> SSHPrivateKeyTypes:
    """Load private key from OpenSSH custom encoding.

    Args:
        data: bytes-like object containing a
            ``-----BEGIN OPENSSH PRIVATE KEY-----`` block.
        password: passphrase for encrypted keys; may be None for
            unencrypted keys.
        backend: unused.

    Returns:
        The decoded private key.

    Raises:
        ValueError: if the data is not in OpenSSH format, is corrupt, holds
            more than one key, or is encrypted and no password was given.
            A wrong password is normally caught by the check-int
            comparison after decryption.
        UnsupportedAlgorithm: for unknown key types, ciphers or KDFs, or
            when bcrypt is required but not installed.
    """
    utils._check_byteslike("data", data)
    if password is not None:
        utils._check_bytes("password", password)

    # Locate the armored block and base64-decode its body.
    m = _PEM_RC.search(data)
    if not m:
        raise ValueError("Not OpenSSH private key format")
    p1 = m.start(1)
    p2 = m.end(1)
    data = binascii.a2b_base64(memoryview(data)[p1:p2])
    if not data.startswith(_SK_MAGIC):
        raise ValueError("Not OpenSSH private key format")
    data = memoryview(data)[len(_SK_MAGIC) :]

    # parse header
    ciphername, data = _get_sshstr(data)
    kdfname, data = _get_sshstr(data)
    kdfoptions, data = _get_sshstr(data)
    nkeys, data = _get_u32(data)
    if nkeys != 1:
        raise ValueError("Only one key supported")

    # load public key data
    pubdata, data = _get_sshstr(data)
    pub_key_type, pubdata = _get_sshstr(pubdata)
    kformat = _lookup_kformat(pub_key_type)
    pubfields, pubdata = kformat.get_public(pubdata)
    _check_empty(pubdata)

    if (ciphername, kdfname) != (_NONE, _NONE):
        # Encrypted key: validate cipher and KDF, then decrypt the secrets.
        ciphername_bytes = ciphername.tobytes()
        if ciphername_bytes not in _SSH_CIPHERS:
            raise UnsupportedAlgorithm(
                f"Unsupported cipher: {ciphername_bytes!r}"
            )
        if kdfname != _BCRYPT:
            raise UnsupportedAlgorithm(f"Unsupported KDF: {kdfname!r}")
        blklen = _SSH_CIPHERS[ciphername_bytes].block_len
        tag_len = _SSH_CIPHERS[ciphername_bytes].tag_len
        # load secret data
        edata, data = _get_sshstr(data)
        # see https://bugzilla.mindrot.org/show_bug.cgi?id=3553 for
        # information about how OpenSSH handles AEAD tags
        if _SSH_CIPHERS[ciphername_bytes].is_aead:
            # The tag is whatever follows the length-prefixed secret section.
            tag = bytes(data)
            if len(tag) != tag_len:
                raise ValueError("Corrupt data: invalid tag length for cipher")
        else:
            _check_empty(data)
        _check_block_size(edata, blklen)
        # kdfoptions holds: string salt, uint32 rounds.
        salt, kbuf = _get_sshstr(kdfoptions)
        rounds, kbuf = _get_u32(kbuf)
        _check_empty(kbuf)
        ciph = _init_cipher(ciphername_bytes, password, salt.tobytes(), rounds)
        dec = ciph.decryptor()
        edata = memoryview(dec.update(edata))
        if _SSH_CIPHERS[ciphername_bytes].is_aead:
            assert isinstance(dec, AEADDecryptionContext)
            _check_empty(dec.finalize_with_tag(tag))
        else:
            # _check_block_size requires data to be a full block so there
            # should be no output from finalize
            _check_empty(dec.finalize())
    else:
        # load secret data
        edata, data = _get_sshstr(data)
        _check_empty(data)
        blklen = 8
        _check_block_size(edata, blklen)

    # The secret section starts with two copies of the same check value;
    # a mismatch means corrupt data or a wrong password.
    ck1, edata = _get_u32(edata)
    ck2, edata = _get_u32(edata)
    if ck1 != ck2:
        raise ValueError("Corrupt data: broken checksum")

    # load per-key struct
    key_type, edata = _get_sshstr(edata)
    if key_type != pub_key_type:
        raise ValueError("Corrupt data: key type mismatch")
    private_key, edata = kformat.load_private(edata, pubfields)
    # We don't use the comment
    _, edata = _get_sshstr(edata)

    # yes, SSH does padding check *after* all other parsing is done.
    # need to follow as it writes zero-byte padding too.
    # The remainder must be a prefix of the 1, 2, 3, ... padding sequence.
    if edata != _PADDING[: len(edata)]:
        raise ValueError("Corrupt data: invalid padding")

    if isinstance(private_key, dsa.DSAPrivateKey):
        warnings.warn(
            "SSH DSA keys are deprecated and will be removed in a future "
            "release.",
            utils.DeprecatedIn40,
            stacklevel=2,
        )

    return private_key
def _serialize_ssh_private_key(
    private_key: SSHPrivateKeyTypes,
    password: bytes,
    encryption_algorithm: KeySerializationEncryption,
) -> bytes:
    """Serialize private key with OpenSSH custom encoding.

    Args:
        private_key: key to serialize.
        password: passphrase. An empty value produces an unencrypted key
            (cipher and KDF "none"); otherwise the secrets are encrypted
            with ``_DEFAULT_CIPHER`` using a bcrypt-derived key.
        encryption_algorithm: consulted only for a custom bcrypt round
            count (``_kdf_rounds``) when a password is given.

    Returns:
        The ``-----BEGIN OPENSSH PRIVATE KEY-----`` armored block as bytes.
    """
    utils._check_bytes("password", password)
    if isinstance(private_key, dsa.DSAPrivateKey):
        warnings.warn(
            "SSH DSA key support is deprecated and will be "
            "removed in a future release",
            utils.DeprecatedIn40,
            stacklevel=4,
        )

    key_type = _get_ssh_key_type(private_key)
    kformat = _lookup_kformat(key_type)

    # setup parameters
    f_kdfoptions = _FragList()
    if password:
        ciphername = _DEFAULT_CIPHER
        blklen = _SSH_CIPHERS[ciphername].block_len
        kdfname = _BCRYPT
        rounds = _DEFAULT_ROUNDS
        if (
            isinstance(encryption_algorithm, _KeySerializationEncryption)
            and encryption_algorithm._kdf_rounds is not None
        ):
            rounds = encryption_algorithm._kdf_rounds
        salt = os.urandom(16)
        # kdfoptions = string salt, uint32 rounds
        f_kdfoptions.put_sshstr(salt)
        f_kdfoptions.put_u32(rounds)
        ciph = _init_cipher(ciphername, password, salt, rounds)
    else:
        ciphername = kdfname = _NONE
        blklen = 8
        ciph = None
    nkeys = 1
    # Random check value, written twice and compared by the loader.
    checkval = os.urandom(4)
    comment = b""

    # encode public and private parts together
    f_public_key = _FragList()
    f_public_key.put_sshstr(key_type)
    kformat.encode_public(private_key.public_key(), f_public_key)

    f_secrets = _FragList([checkval, checkval])
    f_secrets.put_sshstr(key_type)
    kformat.encode_private(private_key, f_secrets)
    f_secrets.put_sshstr(comment)
    # Pad with bytes 1, 2, 3, ... up to a multiple of blklen (a full block
    # is added when the size is already aligned).
    f_secrets.put_raw(_PADDING[: blklen - (f_secrets.size() % blklen)])

    # top-level structure
    f_main = _FragList()
    f_main.put_raw(_SK_MAGIC)
    f_main.put_sshstr(ciphername)
    f_main.put_sshstr(kdfname)
    f_main.put_sshstr(f_kdfoptions)
    f_main.put_u32(nkeys)
    f_main.put_sshstr(f_public_key)
    f_main.put_sshstr(f_secrets)

    # copy result info bytearray
    slen = f_secrets.size()
    mlen = f_main.size()
    # NOTE: the extra blklen bytes look like headroom for update_into, whose
    # output buffer must be larger than its input; they are not emitted.
    buf = memoryview(bytearray(mlen + blklen))
    f_main.render(buf)
    # The secrets are the last field, so they occupy buf[ofs:mlen].
    ofs = mlen - slen

    # encrypt in-place
    if ciph is not None:
        ciph.encryptor().update_into(buf[ofs:mlen], buf[ofs:])

    return _ssh_pem_encode(buf[:mlen])
# Public key types that can be loaded from or serialized to OpenSSH format.
SSHPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    dsa.DSAPublicKey,
    ed25519.Ed25519PublicKey,
]

# Public key types exposed by SSH certificates (DSA is excluded).
SSHCertPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    ed25519.Ed25519PublicKey,
]
class SSHCertificateType(enum.Enum):
    """Certificate type; values are the numeric type stored in the cert."""

    USER = 1
    HOST = 2
class SSHCertificate:
    """An OpenSSH certificate as produced by ``_load_ssh_public_identity``.

    The constructor stores the already-decoded fields as given (several are
    memoryview slices of the decoded certificate) and converts the numeric
    certificate type to ``SSHCertificateType``.
    """

    def __init__(
        self,
        _nonce: memoryview,
        _public_key: SSHPublicKeyTypes,
        _serial: int,
        _cctype: int,
        _key_id: memoryview,
        _valid_principals: list[bytes],
        _valid_after: int,
        _valid_before: int,
        _critical_options: dict[bytes, bytes],
        _extensions: dict[bytes, bytes],
        _sig_type: memoryview,
        _sig_key: memoryview,
        _inner_sig_type: memoryview,
        _signature: memoryview,
        _tbs_cert_body: memoryview,
        _cert_key_type: bytes,
        _cert_body: memoryview,
    ):
        self._nonce = _nonce
        self._public_key = _public_key
        self._serial = _serial
        try:
            self._type = SSHCertificateType(_cctype)
        except ValueError:
            # Hide the enum lookup error; callers only need this message.
            raise ValueError("Invalid certificate type") from None
        self._key_id = _key_id
        self._valid_principals = _valid_principals
        self._valid_after = _valid_after
        self._valid_before = _valid_before
        self._critical_options = _critical_options
        self._extensions = _extensions
        self._sig_type = _sig_type
        self._sig_key = _sig_key
        self._inner_sig_type = _inner_sig_type
        self._signature = _signature
        self._cert_key_type = _cert_key_type
        self._cert_body = _cert_body
        self._tbs_cert_body = _tbs_cert_body

    @property
    def nonce(self) -> bytes:
        """The certificate nonce, as bytes."""
        return bytes(self._nonce)

    def public_key(self) -> SSHCertPublicKeyTypes:
        """Return the public key embedded in the certificate.

        This is not the key that signed the certificate; see
        ``signature_key()``.
        """
        # make mypy happy until we remove DSA support entirely and
        # the underlying union won't have a disallowed type
        return typing.cast(SSHCertPublicKeyTypes, self._public_key)

    @property
    def serial(self) -> int:
        """The certificate serial number."""
        return self._serial

    @property
    def type(self) -> SSHCertificateType:
        """Whether this is a user or host certificate."""
        return self._type

    @property
    def key_id(self) -> bytes:
        """The certificate key ID, as bytes."""
        return bytes(self._key_id)

    @property
    def valid_principals(self) -> list[bytes]:
        """Principal names listed in the certificate."""
        return self._valid_principals

    @property
    def valid_before(self) -> int:
        """Raw uint64 "valid before" value from the certificate."""
        return self._valid_before

    @property
    def valid_after(self) -> int:
        """Raw uint64 "valid after" value from the certificate."""
        return self._valid_after

    @property
    def critical_options(self) -> dict[bytes, bytes]:
        """Critical options as a name -> value mapping."""
        return self._critical_options

    @property
    def extensions(self) -> dict[bytes, bytes]:
        """Extensions as a name -> value mapping."""
        return self._extensions

    def signature_key(self) -> SSHCertPublicKeyTypes:
        """Decode the public key that signed this certificate.

        Raises:
            ValueError: if the encoded key has trailing data.
        """
        sigformat = _lookup_kformat(self._sig_type)
        signature_key, sigkey_rest = sigformat.load_public(self._sig_key)
        _check_empty(sigkey_rest)
        return signature_key

    def public_bytes(self) -> bytes:
        """Return the certificate as ``b"<cert-key-type> <base64-body>"``."""
        return (
            bytes(self._cert_key_type)
            + b" "
            + binascii.b2a_base64(bytes(self._cert_body), newline=False)
        )

    def verify_cert_signature(self) -> None:
        """Verify the signature over the certificate's to-be-signed body.

        Only the cryptographic signature is checked, using
        ``signature_key()``. Validity period, principals and options are not
        examined here.

        Raises:
            cryptography.exceptions.InvalidSignature: if verification fails.
        """
        signature_key = self.signature_key()
        if isinstance(signature_key, ed25519.Ed25519PublicKey):
            signature_key.verify(
                bytes(self._signature), bytes(self._tbs_cert_body)
            )
        elif isinstance(signature_key, ec.EllipticCurvePublicKey):
            # The signature is encoded as a pair of big-endian integers
            r, data = _get_mpint(self._signature)
            s, data = _get_mpint(data)
            _check_empty(data)
            computed_sig = asym_utils.encode_dss_signature(r, s)
            hash_alg = _get_ec_hash_alg(signature_key.curve)
            signature_key.verify(
                computed_sig, bytes(self._tbs_cert_body), ec.ECDSA(hash_alg)
            )
        else:
            assert isinstance(signature_key, rsa.RSAPublicKey)
            # The inner signature type selects the PKCS#1 v1.5 hash.
            if self._inner_sig_type == _SSH_RSA:
                hash_alg = hashes.SHA1()
            elif self._inner_sig_type == _SSH_RSA_SHA256:
                hash_alg = hashes.SHA256()
            else:
                assert self._inner_sig_type == _SSH_RSA_SHA512
                hash_alg = hashes.SHA512()
            signature_key.verify(
                bytes(self._signature),
                bytes(self._tbs_cert_body),
                padding.PKCS1v15(),
                hash_alg,
            )
def _get_ec_hash_alg(curve: ec.EllipticCurve) -> hashes.HashAlgorithm:
    """Return the hash paired with *curve* for SSH ECDSA signatures.

    P-256 uses SHA-256 and P-384 uses SHA-384. Any other curve must be
    P-521, which uses SHA-512.
    """
    if isinstance(curve, ec.SECP256R1):
        return hashes.SHA256()
    if isinstance(curve, ec.SECP384R1):
        return hashes.SHA384()
    assert isinstance(curve, ec.SECP521R1)
    return hashes.SHA512()
def _load_ssh_public_identity(
    data: bytes,
    _legacy_dsa_allowed: bool = False,
) -> SSHCertificate | SSHPublicKeyTypes:
    """Parse an OpenSSH public key or certificate line.

    The line must start with ``<key-type> <base64-blob>``; anything after
    the blob (e.g. a comment) is ignored. Key types ending in
    ``-cert-v01@openssh.com`` are parsed as certificates.

    Args:
        data: bytes-like object holding the line.
        _legacy_dsa_allowed: when False, DSA keys and DSA-signed
            certificates are rejected.

    Returns:
        An ``SSHCertificate`` for certificate lines, otherwise the key.

    Raises:
        ValueError: on a malformed line, bad base64, or corrupt fields.
        UnsupportedAlgorithm: for unknown or disallowed key types.
    """
    utils._check_byteslike("data", data)

    m = _SSH_PUBKEY_RC.match(data)
    if not m:
        raise ValueError("Invalid line format")
    key_type = orig_key_type = m.group(1)
    key_body = m.group(2)
    with_cert = False
    if key_type.endswith(_CERT_SUFFIX):
        with_cert = True
        key_type = key_type[: -len(_CERT_SUFFIX)]
    if key_type == _SSH_DSA and not _legacy_dsa_allowed:
        raise UnsupportedAlgorithm(
            "DSA keys aren't supported in SSH certificates"
        )
    kformat = _lookup_kformat(key_type)

    try:
        rest = memoryview(binascii.a2b_base64(key_body))
    except (TypeError, binascii.Error):
        raise ValueError("Invalid format") from None

    if with_cert:
        # A certificate blob repeats the full key type, then a nonce,
        # before the certified public key.
        cert_body = rest
        inner_key_type, rest = _get_sshstr(rest)
        if inner_key_type != orig_key_type:
            raise ValueError("Invalid key format")
        nonce, rest = _get_sshstr(rest)
    public_key, rest = kformat.load_public(rest)
    if with_cert:
        serial, rest = _get_u64(rest)
        cctype, rest = _get_u32(rest)
        key_id, rest = _get_sshstr(rest)
        principals, rest = _get_sshstr(rest)
        valid_principals = []
        while principals:
            principal, principals = _get_sshstr(principals)
            valid_principals.append(bytes(principal))
        valid_after, rest = _get_u64(rest)
        valid_before, rest = _get_u64(rest)
        crit_options, rest = _get_sshstr(rest)
        critical_options = _parse_exts_opts(crit_options)
        exts, rest = _get_sshstr(rest)
        extensions = _parse_exts_opts(exts)
        # Get the reserved field, which is unused.
        _, rest = _get_sshstr(rest)
        sig_key_raw, rest = _get_sshstr(rest)
        sig_type, sig_key = _get_sshstr(sig_key_raw)
        if sig_type == _SSH_DSA and not _legacy_dsa_allowed:
            raise UnsupportedAlgorithm(
                "DSA signatures aren't supported in SSH certificates"
            )
        # Everything before the signature field is the signed body. Slice by
        # explicit length: ``cert_body[: -len(rest)]`` would be empty when
        # ``rest`` is empty.
        tbs_cert_body = cert_body[: len(cert_body) - len(rest)]
        signature_raw, rest = _get_sshstr(rest)
        _check_empty(rest)
        inner_sig_type, sig_rest = _get_sshstr(signature_raw)
        # RSA certs can have multiple algorithm types
        if (
            sig_type == _SSH_RSA
            and inner_sig_type
            not in [_SSH_RSA_SHA256, _SSH_RSA_SHA512, _SSH_RSA]
        ) or (sig_type != _SSH_RSA and inner_sig_type != sig_type):
            raise ValueError("Signature key type does not match")
        signature, sig_rest = _get_sshstr(sig_rest)
        _check_empty(sig_rest)
        return SSHCertificate(
            nonce,
            public_key,
            serial,
            cctype,
            key_id,
            valid_principals,
            valid_after,
            valid_before,
            critical_options,
            extensions,
            sig_type,
            sig_key,
            inner_sig_type,
            signature,
            tbs_cert_body,
            orig_key_type,
            cert_body,
        )
    else:
        _check_empty(rest)
        return public_key
def load_ssh_public_identity(
    data: bytes,
) -> SSHCertificate | SSHPublicKeyTypes:
    """Parse an OpenSSH public key or certificate line.

    DSA keys and DSA-signed certificates are rejected.
    """
    return _load_ssh_public_identity(data, _legacy_dsa_allowed=False)
def _parse_exts_opts(exts_opts: memoryview) -> dict[bytes, bytes]:
    """Parse a certificate critical-options or extensions section.

    The section is a sequence of (name, data) SSH strings. Names must be
    unique and in ascending byte order. A non-empty data field must contain
    exactly one SSH string, which becomes the value; an empty data field
    maps to ``b""``.

    Raises:
        ValueError: on duplicates, misordering, truncation, or extra bytes
            inside a data field.
    """
    parsed: dict[bytes, bytes] = {}
    previous: bytes | None = None
    remaining = exts_opts
    while len(remaining) > 0:
        raw_name, remaining = _get_sshstr(remaining)
        name = bytes(raw_name)
        if name in parsed:
            raise ValueError("Duplicate name")
        if previous is not None and name < previous:
            raise ValueError("Fields not lexically sorted")
        raw_value, remaining = _get_sshstr(remaining)
        if len(raw_value) > 0:
            raw_value, trailing = _get_sshstr(raw_value)
            if len(trailing) > 0:
                raise ValueError("Unexpected extra data after value")
        parsed[name] = bytes(raw_value)
        previous = name
    return parsed
def load_ssh_public_key(
    data: bytes, backend: typing.Any = None
) -> SSHPublicKeyTypes:
    """Load an OpenSSH public key line.

    Certificate lines are accepted too, and yield their embedded public
    key. DSA keys are allowed but emit a deprecation warning. *backend* is
    unused.
    """
    parsed = _load_ssh_public_identity(data, _legacy_dsa_allowed=True)
    public_key: SSHPublicKeyTypes = (
        parsed.public_key() if isinstance(parsed, SSHCertificate) else parsed
    )

    if isinstance(public_key, dsa.DSAPublicKey):
        warnings.warn(
            "SSH DSA keys are deprecated and will be removed in a future "
            "release.",
            utils.DeprecatedIn40,
            stacklevel=2,
        )
    return public_key
def serialize_ssh_public_key(public_key: SSHPublicKeyTypes) -> bytes:
    """Return the one-line OpenSSH form ``b"<key-type> <base64-blob>"``.

    Serializing a DSA key emits a deprecation warning.
    """
    if isinstance(public_key, dsa.DSAPublicKey):
        warnings.warn(
            "SSH DSA key support is deprecated and will be "
            "removed in a future release",
            utils.DeprecatedIn40,
            stacklevel=4,
        )
    key_type = _get_ssh_key_type(public_key)
    kformat = _lookup_kformat(key_type)

    blob = _FragList()
    blob.put_sshstr(key_type)
    kformat.encode_public(public_key, blob)

    # newline=False drops the trailing newline b2a_base64 would append.
    encoded = binascii.b2a_base64(blob.tobytes(), newline=False)
    return key_type + b" " + encoded
# Private key types that SSHCertificateBuilder.sign() accepts as the CA key
# (ECDSA, RSA and Ed25519; DSA is not included).
# Note: this alias is evaluated at runtime, and ``X | Y`` between classes
# would require Python 3.10+, whereas typing.Union also works on older
# versions.
SSHCertPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]
# Maximum number of principals SSHCertificateBuilder.valid_principals()
# accepts.  Longer lists raise ValueError; exactly this many is allowed.
# This mirrors an undocumented limit enforced in the OpenSSH codebase for
# sshd and ssh-keygen.  The SSH certificates spec itself does not define one.
_SSHKEY_CERT_MAX_PRINCIPALS = 256
class SSHCertificateBuilder:
    """Immutable builder for OpenSSH certificates.

    Every setter validates its argument and returns a *new* builder; the
    receiver is never modified.  Scalar fields and the principals setting
    can be set at most once (setting them again raises ``ValueError``).
    Critical options and extensions may be added repeatedly, but names must
    be unique.  Call :meth:`sign` with a CA private key to produce an
    ``SSHCertificate``.
    """

    def __init__(
        self,
        _public_key: SSHCertPublicKeyTypes | None = None,
        _serial: int | None = None,
        _type: SSHCertificateType | None = None,
        _key_id: bytes | None = None,
        _valid_principals: list[bytes] | None = None,
        _valid_for_all_principals: bool = False,
        _valid_before: int | None = None,
        _valid_after: int | None = None,
        _critical_options: list[tuple[bytes, bytes]] | None = None,
        _extensions: list[tuple[bytes, bytes]] | None = None,
    ):
        # The underscore-prefixed arguments carry state from one builder to
        # the next (see _replace).  Callers normally start from a bare
        # SSHCertificateBuilder().
        self._public_key = _public_key
        self._serial = _serial
        self._type = _type
        self._key_id = _key_id
        # Copy list arguments (and use None rather than a shared mutable
        # default) so that every builder owns its lists.  Later changes to a
        # list the caller passed in therefore cannot alter this builder.
        self._valid_principals: list[bytes] = (
            [] if _valid_principals is None else list(_valid_principals)
        )
        self._valid_for_all_principals = _valid_for_all_principals
        self._valid_before = _valid_before
        self._valid_after = _valid_after
        self._critical_options: list[tuple[bytes, bytes]] = (
            [] if _critical_options is None else list(_critical_options)
        )
        self._extensions: list[tuple[bytes, bytes]] = (
            [] if _extensions is None else list(_extensions)
        )

    def _replace(self, **changes: typing.Any) -> SSHCertificateBuilder:
        """Return a copy of this builder with ``changes`` applied.

        Keys of ``changes`` are ``__init__`` argument names such as
        ``_serial``.  Every field not mentioned is carried over unchanged.
        """
        fields: dict[str, typing.Any] = {
            "_public_key": self._public_key,
            "_serial": self._serial,
            "_type": self._type,
            "_key_id": self._key_id,
            "_valid_principals": self._valid_principals,
            "_valid_for_all_principals": self._valid_for_all_principals,
            "_valid_before": self._valid_before,
            "_valid_after": self._valid_after,
            "_critical_options": self._critical_options,
            "_extensions": self._extensions,
        }
        fields.update(changes)
        return SSHCertificateBuilder(**fields)

    def public_key(
        self, public_key: SSHCertPublicKeyTypes
    ) -> SSHCertificateBuilder:
        """Set the public key being certified (ECDSA, RSA or Ed25519).

        Raises:
            TypeError: if the key type is unsupported.
            ValueError: if the public key is already set.
        """
        if not isinstance(
            public_key,
            (
                ec.EllipticCurvePublicKey,
                rsa.RSAPublicKey,
                ed25519.Ed25519PublicKey,
            ),
        ):
            raise TypeError("Unsupported key type")
        if self._public_key is not None:
            raise ValueError("public_key already set")

        return self._replace(_public_key=public_key)

    def serial(self, serial: int) -> SSHCertificateBuilder:
        """Set the serial number, an integer in ``[0, 2**64)``.

        Raises:
            TypeError: if ``serial`` is not an int.
            ValueError: if it is out of range or already set.
        """
        if not isinstance(serial, int):
            raise TypeError("serial must be an integer")
        if not 0 <= serial < 2**64:
            raise ValueError("serial must be between 0 and 2**64")
        if self._serial is not None:
            raise ValueError("serial already set")

        return self._replace(_serial=serial)

    def type(self, type: SSHCertificateType) -> SSHCertificateBuilder:
        """Set the certificate type (an ``SSHCertificateType`` member).

        Raises:
            TypeError: if ``type`` is not an ``SSHCertificateType``.
            ValueError: if the type is already set.
        """
        if not isinstance(type, SSHCertificateType):
            raise TypeError("type must be an SSHCertificateType")
        if self._type is not None:
            raise ValueError("type already set")

        return self._replace(_type=type)

    def key_id(self, key_id: bytes) -> SSHCertificateBuilder:
        """Set the key identifier.

        Raises:
            TypeError: if ``key_id`` is not bytes.
            ValueError: if the key identifier is already set.
        """
        if not isinstance(key_id, bytes):
            raise TypeError("key_id must be bytes")
        if self._key_id is not None:
            raise ValueError("key_id already set")

        return self._replace(_key_id=key_id)

    def valid_principals(
        self, valid_principals: list[bytes]
    ) -> SSHCertificateBuilder:
        """Restrict the certificate to a non-empty list of principals.

        Mutually exclusive with valid_for_all_principals().  The list is
        copied, so later changes to the caller's list do not affect the
        builder.

        Raises:
            TypeError: if the list is empty or contains non-bytes items.
            ValueError: if the certificate is already valid for all
                principals, principals are already set, or more than
                ``_SSHKEY_CERT_MAX_PRINCIPALS`` entries are given.
        """
        if self._valid_for_all_principals:
            raise ValueError(
                "Principals can't be set because the cert is valid "
                "for all principals"
            )
        if (
            not all(isinstance(x, bytes) for x in valid_principals)
            or not valid_principals
        ):
            raise TypeError(
                "principals must be a list of bytes and can't be empty"
            )
        if self._valid_principals:
            raise ValueError("valid_principals already set")

        if len(valid_principals) > _SSHKEY_CERT_MAX_PRINCIPALS:
            raise ValueError(
                "Reached or exceeded the maximum number of valid_principals"
            )

        return self._replace(_valid_principals=valid_principals)

    def valid_for_all_principals(self) -> SSHCertificateBuilder:
        """Make the certificate valid for every principal of its type.

        Mutually exclusive with valid_principals().  The certificate is then
        encoded with an empty principals list.

        Raises:
            ValueError: if principals are already set, or this flag is
                already set.
        """
        if self._valid_principals:
            raise ValueError(
                "valid_principals already set, can't set "
                "valid_for_all_principals"
            )
        if self._valid_for_all_principals:
            raise ValueError("valid_for_all_principals already set")

        return self._replace(_valid_for_all_principals=True)

    def valid_before(self, valid_before: int | float) -> SSHCertificateBuilder:
        """Set the end of the validity window (truncated to an int).

        Raises:
            TypeError: if the value is not an int or float.
            ValueError: if it is outside ``[0, 2**64)`` or already set.
        """
        if not isinstance(valid_before, (int, float)):
            raise TypeError("valid_before must be an int or float")
        valid_before = int(valid_before)
        if valid_before < 0 or valid_before >= 2**64:
            raise ValueError("valid_before must [0, 2**64)")
        if self._valid_before is not None:
            raise ValueError("valid_before already set")

        return self._replace(_valid_before=valid_before)

    def valid_after(self, valid_after: int | float) -> SSHCertificateBuilder:
        """Set the start of the validity window (truncated to an int).

        Raises:
            TypeError: if the value is not an int or float.
            ValueError: if it is outside ``[0, 2**64)`` or already set.
        """
        if not isinstance(valid_after, (int, float)):
            raise TypeError("valid_after must be an int or float")
        valid_after = int(valid_after)
        if valid_after < 0 or valid_after >= 2**64:
            raise ValueError("valid_after must [0, 2**64)")
        if self._valid_after is not None:
            raise ValueError("valid_after already set")

        return self._replace(_valid_after=valid_after)

    def add_critical_option(
        self, name: bytes, value: bytes
    ) -> SSHCertificateBuilder:
        """Add a critical option.  Option names must be unique.

        Options are sorted by name when the certificate is signed.

        Raises:
            TypeError: if ``name`` or ``value`` is not bytes.
            ValueError: if an option with this name already exists.
        """
        if not isinstance(name, bytes) or not isinstance(value, bytes):
            raise TypeError("name and value must be bytes")
        # Linear scan per call, so adding n options costs O(n**2) overall.
        if any(existing == name for existing, _ in self._critical_options):
            raise ValueError("Duplicate critical option name")

        return self._replace(
            _critical_options=[*self._critical_options, (name, value)]
        )

    def add_extension(
        self, name: bytes, value: bytes
    ) -> SSHCertificateBuilder:
        """Add an extension.  Extension names must be unique.

        Extensions are sorted by name when the certificate is signed.

        Raises:
            TypeError: if ``name`` or ``value`` is not bytes.
            ValueError: if an extension with this name already exists.
        """
        if not isinstance(name, bytes) or not isinstance(value, bytes):
            raise TypeError("name and value must be bytes")
        # Linear scan per call, so adding n extensions costs O(n**2) overall.
        if any(existing == name for existing, _ in self._extensions):
            raise ValueError("Duplicate extension name")

        return self._replace(_extensions=[*self._extensions, (name, value)])

    @staticmethod
    def _encode_options(options: list[tuple[bytes, bytes]]) -> bytes:
        """Encode (name, value) pairs for a critical-options/extensions field.

        Pairs are written in lexical name order, from a sorted copy; the
        input list is not modified.  Each name is an SSH string.  A non-empty
        value is wrapped in an inner SSH string; an empty value is written as
        an empty SSH string.
        """
        fopts = _FragList()
        for name, value in sorted(options, key=lambda option: option[0]):
            fopts.put_sshstr(name)
            if len(value) > 0:
                fvalue = _FragList()
                fvalue.put_sshstr(value)
                fopts.put_sshstr(fvalue.tobytes())
            else:
                fopts.put_sshstr(value)
        return fopts.tobytes()

    def sign(self, private_key: SSHCertPrivateKeyTypes) -> SSHCertificate:
        """Sign the certificate with the CA key ``private_key``.

        Before signing, these must be set: public_key, type, valid_before,
        valid_after, and either valid_principals or
        valid_for_all_principals.  serial defaults to 0 and key_id to b"".
        A fresh random 32-byte nonce is embedded in every certificate.

        The signature scheme depends on the CA key type:
          - Ed25519 signs the data directly.
          - ECDSA uses the curve's hash from ``_get_ec_hash_alg`` and encodes
            (r, s) as mpints.
          - RSA uses PKCS#1 v1.5 with SHA-512.

        This builder is not modified.

        Returns:
            The encoded certificate, re-parsed as an ``SSHCertificate``.

        Raises:
            TypeError: if ``private_key`` is not an ECDSA, RSA or Ed25519
                private key.
            ValueError: if a required field is missing, or valid_after is
                later than valid_before.
        """
        if not isinstance(
            private_key,
            (
                ec.EllipticCurvePrivateKey,
                rsa.RSAPrivateKey,
                ed25519.Ed25519PrivateKey,
            ),
        ):
            raise TypeError("Unsupported private key type")

        if self._public_key is None:
            raise ValueError("public_key must be set")

        # Not required
        serial = 0 if self._serial is None else self._serial

        if self._type is None:
            raise ValueError("type must be set")

        # Not required
        key_id = b"" if self._key_id is None else self._key_id

        # A zero length list is valid, but means the certificate
        # is valid for any principal of the specified type. We require
        # the user to explicitly set valid_for_all_principals to get
        # that behavior.
        if not self._valid_principals and not self._valid_for_all_principals:
            raise ValueError(
                "valid_principals must be set if valid_for_all_principals "
                "is False"
            )

        if self._valid_before is None:
            raise ValueError("valid_before must be set")

        if self._valid_after is None:
            raise ValueError("valid_after must be set")

        if self._valid_after > self._valid_before:
            raise ValueError("valid_after must be earlier than valid_before")

        key_type = _get_ssh_key_type(self._public_key)
        cert_prefix = key_type + _CERT_SUFFIX

        # Marshal the bytes to be signed
        nonce = os.urandom(32)
        kformat = _lookup_kformat(key_type)
        f = _FragList()
        f.put_sshstr(cert_prefix)
        f.put_sshstr(nonce)
        kformat.encode_public(self._public_key, f)
        f.put_u64(serial)
        f.put_u32(self._type.value)
        f.put_sshstr(key_id)
        fprincipals = _FragList()
        for p in self._valid_principals:
            fprincipals.put_sshstr(p)
        f.put_sshstr(fprincipals.tobytes())
        f.put_u64(self._valid_after)
        f.put_u64(self._valid_before)
        # Both fields must be lexically sorted by name; _encode_options sorts
        # a copy, so the builder's own lists (possibly shared with other
        # builders) are left untouched.
        f.put_sshstr(self._encode_options(self._critical_options))
        f.put_sshstr(self._encode_options(self._extensions))
        f.put_sshstr(b"")  # RESERVED FIELD
        # encode CA public key
        ca_type = _get_ssh_key_type(private_key)
        caformat = _lookup_kformat(ca_type)
        caf = _FragList()
        caf.put_sshstr(ca_type)
        caformat.encode_public(private_key.public_key(), caf)
        f.put_sshstr(caf.tobytes())
        # Sigs according to the rules defined for the CA's public key
        # (RFC4253 section 6.6 for ssh-rsa, RFC5656 for ECDSA,
        # and RFC8032 for Ed25519).
        if isinstance(private_key, ed25519.Ed25519PrivateKey):
            signature = private_key.sign(f.tobytes())
            fsig = _FragList()
            fsig.put_sshstr(ca_type)
            fsig.put_sshstr(signature)
            f.put_sshstr(fsig.tobytes())
        elif isinstance(private_key, ec.EllipticCurvePrivateKey):
            hash_alg = _get_ec_hash_alg(private_key.curve)
            signature = private_key.sign(f.tobytes(), ec.ECDSA(hash_alg))
            r, s = asym_utils.decode_dss_signature(signature)
            fsig = _FragList()
            fsig.put_sshstr(ca_type)
            fsigblob = _FragList()
            fsigblob.put_mpint(r)
            fsigblob.put_mpint(s)
            fsig.put_sshstr(fsigblob.tobytes())
            f.put_sshstr(fsig.tobytes())

        else:
            assert isinstance(private_key, rsa.RSAPrivateKey)
            # Just like Golang, we're going to use SHA512 for RSA
            # https://cs.opensource.google/go/x/crypto/+/refs/tags/
            # v0.4.0:ssh/certs.go;l=445
            # RFC 8332 defines SHA256 and 512 as options
            fsig = _FragList()
            fsig.put_sshstr(_SSH_RSA_SHA512)
            signature = private_key.sign(
                f.tobytes(), padding.PKCS1v15(), hashes.SHA512()
            )
            fsig.put_sshstr(signature)
            f.put_sshstr(fsig.tobytes())

        cert_data = binascii.b2a_base64(f.tobytes()).strip()
        # load_ssh_public_identity returns a union, but this is
        # guaranteed to be an SSHCertificate, so we cast to make
        # mypy happy.
        return typing.cast(
            SSHCertificate,
            load_ssh_public_identity(b"".join([cert_prefix, b" ", cert_data])),
        )