Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/cryptography/hazmat/primitives/serialization/ssh.py: 21%
746 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:50 +0000
1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
5from __future__ import annotations
7import binascii
8import enum
9import os
10import re
11import typing
12import warnings
13from base64 import encodebytes as _base64_encode
14from dataclasses import dataclass
16from cryptography import utils
17from cryptography.exceptions import UnsupportedAlgorithm
18from cryptography.hazmat.primitives import hashes
19from cryptography.hazmat.primitives.asymmetric import (
20 dsa,
21 ec,
22 ed25519,
23 padding,
24 rsa,
25)
26from cryptography.hazmat.primitives.asymmetric import utils as asym_utils
27from cryptography.hazmat.primitives.ciphers import (
28 AEADDecryptionContext,
29 Cipher,
30 algorithms,
31 modes,
32)
33from cryptography.hazmat.primitives.serialization import (
34 Encoding,
35 KeySerializationEncryption,
36 NoEncryption,
37 PrivateFormat,
38 PublicFormat,
39 _KeySerializationEncryption,
40)
42try:
43 from bcrypt import kdf as _bcrypt_kdf
45 _bcrypt_supported = True
46except ImportError:
47 _bcrypt_supported = False
49 def _bcrypt_kdf(
50 password: bytes,
51 salt: bytes,
52 desired_key_bytes: int,
53 rounds: int,
54 ignore_few_rounds: bool = False,
55 ) -> bytes:
56 raise UnsupportedAlgorithm("Need bcrypt module")
59_SSH_ED25519 = b"ssh-ed25519"
60_SSH_RSA = b"ssh-rsa"
61_SSH_DSA = b"ssh-dss"
62_ECDSA_NISTP256 = b"ecdsa-sha2-nistp256"
63_ECDSA_NISTP384 = b"ecdsa-sha2-nistp384"
64_ECDSA_NISTP521 = b"ecdsa-sha2-nistp521"
65_CERT_SUFFIX = b"-cert-v01@openssh.com"
67# These are not key types, only algorithms, so they cannot appear
68# as a public key type
69_SSH_RSA_SHA256 = b"rsa-sha2-256"
70_SSH_RSA_SHA512 = b"rsa-sha2-512"
72_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)")
73_SK_MAGIC = b"openssh-key-v1\0"
74_SK_START = b"-----BEGIN OPENSSH PRIVATE KEY-----"
75_SK_END = b"-----END OPENSSH PRIVATE KEY-----"
76_BCRYPT = b"bcrypt"
77_NONE = b"none"
78_DEFAULT_CIPHER = b"aes256-ctr"
79_DEFAULT_ROUNDS = 16
81# re is only way to work on bytes-like data
82_PEM_RC = re.compile(_SK_START + b"(.*?)" + _SK_END, re.DOTALL)
84# padding for max blocksize
85_PADDING = memoryview(bytearray(range(1, 1 + 16)))
88@dataclass
89class _SSHCipher:
90 alg: typing.Type[algorithms.AES]
91 key_len: int
92 mode: typing.Union[
93 typing.Type[modes.CTR],
94 typing.Type[modes.CBC],
95 typing.Type[modes.GCM],
96 ]
97 block_len: int
98 iv_len: int
99 tag_len: typing.Optional[int]
100 is_aead: bool
# Ciphers that can wrap an OpenSSH private key, keyed by the cipher name
# stored in the key file. Every entry is AES with a 32-byte (256-bit) key
# and a 16-byte block.
_SSH_CIPHERS: typing.Dict[bytes, _SSHCipher] = {
    b"aes256-ctr": _SSHCipher(
        alg=algorithms.AES,
        mode=modes.CTR,
        key_len=32,
        iv_len=16,
        block_len=16,
        tag_len=None,
        is_aead=False,
    ),
    b"aes256-cbc": _SSHCipher(
        alg=algorithms.AES,
        mode=modes.CBC,
        key_len=32,
        iv_len=16,
        block_len=16,
        tag_len=None,
        is_aead=False,
    ),
    b"aes256-gcm@openssh.com": _SSHCipher(
        alg=algorithms.AES,
        mode=modes.GCM,
        key_len=32,
        iv_len=12,
        block_len=16,
        tag_len=16,
        is_aead=True,
    ),
}

# cryptography curve name -> SSH ECDSA key type name.
_ECDSA_KEY_TYPE = dict(
    zip(
        ("secp256r1", "secp384r1", "secp521r1"),
        (_ECDSA_NISTP256, _ECDSA_NISTP384, _ECDSA_NISTP521),
    )
)
def _get_ssh_key_type(
    key: typing.Union[SSHPrivateKeyTypes, SSHPublicKeyTypes]
) -> bytes:
    """Return the SSH wire key-type name for a private or public key.

    Raises:
        ValueError: if the key class has no SSH representation here.
    """
    if isinstance(key, ec.EllipticCurvePrivateKey):
        return _ecdsa_key_type(key.public_key())
    if isinstance(key, ec.EllipticCurvePublicKey):
        return _ecdsa_key_type(key)
    if isinstance(key, (rsa.RSAPrivateKey, rsa.RSAPublicKey)):
        return _SSH_RSA
    if isinstance(key, (dsa.DSAPrivateKey, dsa.DSAPublicKey)):
        return _SSH_DSA
    if isinstance(
        key, (ed25519.Ed25519PrivateKey, ed25519.Ed25519PublicKey)
    ):
        return _SSH_ED25519
    raise ValueError("Unsupported key type")
def _ecdsa_key_type(public_key: ec.EllipticCurvePublicKey) -> bytes:
    """Return the SSH ECDSA key-type name for an elliptic-curve public key.

    Only curves listed in ``_ECDSA_KEY_TYPE`` are supported.

    Raises:
        ValueError: if the key's curve has no SSH key-type name.
    """
    curve_name = public_key.curve.name
    try:
        return _ECDSA_KEY_TYPE[curve_name]
    except KeyError:
        raise ValueError(
            f"Unsupported curve for ssh private key: {curve_name!r}"
        ) from None
def _ssh_pem_encode(
    data: bytes,
    prefix: bytes = _SK_START + b"\n",
    suffix: bytes = _SK_END + b"\n",
) -> bytes:
    """Armor *data*: the prefix line, the base64 body, then the suffix line.

    The body comes from ``base64.encodebytes``, which wraps lines and ends
    with a newline.
    """
    body = _base64_encode(data)
    return b"".join((prefix, body, suffix))
181def _check_block_size(data: bytes, block_len: int) -> None:
182 """Require data to be full blocks"""
183 if not data or len(data) % block_len != 0:
184 raise ValueError("Corrupt data: missing padding")
187def _check_empty(data: bytes) -> None:
188 """All data should have been parsed."""
189 if data:
190 raise ValueError("Corrupt data: unparsed data")
def _init_cipher(
    ciphername: bytes,
    password: typing.Optional[bytes],
    salt: bytes,
    rounds: int,
) -> Cipher[typing.Union[modes.CBC, modes.CTR, modes.GCM]]:
    """Derive key material from *password* with bcrypt-pbkdf and build the
    cipher named *ciphername*.

    The KDF output is split into the key (``key_len`` bytes) followed by
    the IV / nonce (``iv_len`` bytes).

    Raises:
        ValueError: if *password* is None or empty.
        KeyError: if *ciphername* is not in ``_SSH_CIPHERS``.
    """
    if not password:
        raise ValueError("Key is password-protected.")

    params = _SSH_CIPHERS[ciphername]
    key_len = params.key_len
    derived = _bcrypt_kdf(
        password, salt, key_len + params.iv_len, rounds, True
    )
    key, iv = derived[:key_len], derived[key_len:]
    return Cipher(params.alg(key), params.mode(iv))
213def _get_u32(data: memoryview) -> typing.Tuple[int, memoryview]:
214 """Uint32"""
215 if len(data) < 4:
216 raise ValueError("Invalid data")
217 return int.from_bytes(data[:4], byteorder="big"), data[4:]
220def _get_u64(data: memoryview) -> typing.Tuple[int, memoryview]:
221 """Uint64"""
222 if len(data) < 8:
223 raise ValueError("Invalid data")
224 return int.from_bytes(data[:8], byteorder="big"), data[8:]
def _get_sshstr(data: memoryview) -> typing.Tuple[memoryview, memoryview]:
    """Split an SSH ``string`` (uint32 length plus that many bytes) off the
    front of *data*.

    Returns ``(payload, remainder)``. Raises ValueError if the data is
    shorter than the declared length.
    """
    length, body = _get_u32(data)
    if length <= len(body):
        return body[:length], body[length:]
    raise ValueError("Invalid data")
def _get_mpint(data: memoryview) -> typing.Tuple[int, memoryview]:
    """Read an SSH ``mpint`` from *data* as a non-negative integer.

    Returns ``(value, remainder)``. A payload whose first byte is above
    0x7F (a negative two's-complement number) raises ValueError.
    """
    raw, remainder = _get_sshstr(data)
    is_negative = len(raw) > 0 and raw[0] > 0x7F
    if is_negative:
        raise ValueError("Invalid data")
    return int.from_bytes(raw, "big"), remainder
def _to_mpint(val: int) -> bytes:
    """Encode a non-negative int as the payload of an SSH ``mpint``.

    Zero encodes as empty bytes. Otherwise the value is written big-endian
    with room for one sign bit, so a set top bit gets a leading zero byte.

    Raises:
        ValueError: if *val* is negative.
    """
    if val < 0:
        raise ValueError("negative mpint not allowed")
    if val == 0:
        return b""
    # Magnitude bits plus one sign bit, rounded up to whole bytes.
    nbytes = (val.bit_length() + 8) // 8
    return utils.int_to_bytes(val, nbytes)
253class _FragList:
254 """Build recursive structure without data copy."""
256 flist: typing.List[bytes]
258 def __init__(
259 self, init: typing.Optional[typing.List[bytes]] = None
260 ) -> None:
261 self.flist = []
262 if init:
263 self.flist.extend(init)
265 def put_raw(self, val: bytes) -> None:
266 """Add plain bytes"""
267 self.flist.append(val)
269 def put_u32(self, val: int) -> None:
270 """Big-endian uint32"""
271 self.flist.append(val.to_bytes(length=4, byteorder="big"))
273 def put_u64(self, val: int) -> None:
274 """Big-endian uint64"""
275 self.flist.append(val.to_bytes(length=8, byteorder="big"))
277 def put_sshstr(self, val: typing.Union[bytes, _FragList]) -> None:
278 """Bytes prefixed with u32 length"""
279 if isinstance(val, (bytes, memoryview, bytearray)):
280 self.put_u32(len(val))
281 self.flist.append(val)
282 else:
283 self.put_u32(val.size())
284 self.flist.extend(val.flist)
286 def put_mpint(self, val: int) -> None:
287 """Big-endian bigint prefixed with u32 length"""
288 self.put_sshstr(_to_mpint(val))
290 def size(self) -> int:
291 """Current number of bytes"""
292 return sum(map(len, self.flist))
294 def render(self, dstbuf: memoryview, pos: int = 0) -> int:
295 """Write into bytearray"""
296 for frag in self.flist:
297 flen = len(frag)
298 start, pos = pos, pos + flen
299 dstbuf[start:pos] = frag
300 return pos
302 def tobytes(self) -> bytes:
303 """Return as bytes"""
304 buf = memoryview(bytearray(self.size()))
305 self.render(buf)
306 return buf.tobytes()
class _SSHFormatRSA:
    """Wire format for ``ssh-rsa`` keys.

    Public blob:   mpint e, mpint n
    Private blob:  mpint n, e, d, iqmp, p, q
    """

    def get_public(self, data: memoryview):
        """Parse the public fields; return ``((e, n), remainder)``."""
        e, rest = _get_mpint(data)
        n, rest = _get_mpint(rest)
        return (e, n), rest

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[rsa.RSAPublicKey, memoryview]:
        """Build an RSA public key from the wire fields."""
        (e, n), rest = self.get_public(data)
        return rsa.RSAPublicNumbers(e, n).public_key(), rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[rsa.RSAPrivateKey, memoryview]:
        """Build an RSA private key from the wire fields.

        *pubfields* is the ``(e, n)`` tuple parsed from the public blob.
        A mismatch raises ValueError.
        """
        values = []
        rest = data
        for _ in range(6):
            value, rest = _get_mpint(rest)
            values.append(value)
        n, e, d, iqmp, p, q = values

        if (e, n) != pubfields:
            raise ValueError("Corrupt data: rsa field mismatch")
        public_numbers = rsa.RSAPublicNumbers(e, n)
        private_numbers = rsa.RSAPrivateNumbers(
            p,
            q,
            d,
            rsa.rsa_crt_dmp1(d, p),
            rsa.rsa_crt_dmq1(d, q),
            iqmp,
            public_numbers,
        )
        return private_numbers.private_key(), rest

    def encode_public(
        self, public_key: rsa.RSAPublicKey, f_pub: _FragList
    ) -> None:
        """Append the public fields ``e`` then ``n``."""
        numbers = public_key.public_numbers()
        for value in (numbers.e, numbers.n):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: rsa.RSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Append the private fields ``n, e, d, iqmp, p, q``."""
        priv = private_key.private_numbers()
        pub = priv.public_numbers
        for value in (pub.n, pub.e, priv.d, priv.iqmp, priv.p, priv.q):
            f_priv.put_mpint(value)
class _SSHFormatDSA:
    """Wire format for ``ssh-dss`` keys.

    Public blob:   mpint p, q, g, y
    Private blob:  mpint p, q, g, y, x
    """

    def get_public(
        self, data: memoryview
    ) -> typing.Tuple[typing.Tuple, memoryview]:
        """Parse the public fields; return ``((p, q, g, y), remainder)``."""
        values = []
        for _ in range(4):
            value, data = _get_mpint(data)
            values.append(value)
        return tuple(values), data

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[dsa.DSAPublicKey, memoryview]:
        """Build a DSA public key; only 1024-bit moduli are accepted."""
        (p, q, g, y), rest = self.get_public(data)
        public_numbers = dsa.DSAPublicNumbers(
            y, dsa.DSAParameterNumbers(p, q, g)
        )
        self._validate(public_numbers)
        return public_numbers.public_key(), rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[dsa.DSAPrivateKey, memoryview]:
        """Build a DSA private key.

        The public fields in the private blob must equal *pubfields*;
        otherwise ValueError is raised.
        """
        fields, rest = self.get_public(data)
        x, rest = _get_mpint(rest)

        if fields != pubfields:
            raise ValueError("Corrupt data: dsa field mismatch")
        p, q, g, y = fields
        public_numbers = dsa.DSAPublicNumbers(
            y, dsa.DSAParameterNumbers(p, q, g)
        )
        self._validate(public_numbers)
        private_key = dsa.DSAPrivateNumbers(x, public_numbers).private_key()
        return private_key, rest

    def encode_public(
        self, public_key: dsa.DSAPublicKey, f_pub: _FragList
    ) -> None:
        """Append ``p, q, g, y`` after checking the modulus size."""
        numbers = public_key.public_numbers()
        params = numbers.parameter_numbers
        self._validate(numbers)
        for value in (params.p, params.q, params.g, numbers.y):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: dsa.DSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Append the public fields followed by the private value ``x``."""
        self.encode_public(private_key.public_key(), f_priv)
        f_priv.put_mpint(private_key.private_numbers().x)

    def _validate(self, public_numbers: dsa.DSAPublicNumbers) -> None:
        """Reject DSA parameters whose modulus ``p`` is not 1024 bits."""
        if public_numbers.parameter_numbers.p.bit_length() == 1024:
            return
        raise ValueError("SSH supports only 1024 bit DSA keys")
class _SSHFormatECDSA:
    """Format for ECDSA keys.

    Public:
        str curve
        bytes point
    Private:
        str curve
        bytes point
        mpint secret
    """

    def __init__(self, ssh_curve_name: bytes, curve: ec.EllipticCurve):
        # SSH curve identifier expected on the wire (e.g. b"nistp256").
        self.ssh_curve_name = ssh_curve_name
        # cryptography curve instance used to construct keys.
        self.curve = curve

    def get_public(
        self, data: memoryview
    ) -> typing.Tuple[typing.Tuple, memoryview]:
        """Parse the ECDSA public fields.

        Returns ``((curve_name, point), remainder)``.

        Raises:
            ValueError: the curve name does not match this format, or the
                point is empty.
            NotImplementedError: the point is not in uncompressed (0x04)
                form.
        """
        curve, data = _get_sshstr(data)
        point, data = _get_sshstr(data)
        if curve != self.ssh_curve_name:
            raise ValueError("Curve name mismatch")
        if not point:
            # An empty point would otherwise fail with IndexError below.
            raise ValueError("Invalid data")
        if point[0] != 4:
            raise NotImplementedError("Need uncompressed point")
        return (curve, point), data

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[ec.EllipticCurvePublicKey, memoryview]:
        """Make an ECDSA public key from the wire fields."""
        (curve_name, point), data = self.get_public(data)
        public_key = ec.EllipticCurvePublicKey.from_encoded_point(
            self.curve, point.tobytes()
        )
        return public_key, data

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[ec.EllipticCurvePrivateKey, memoryview]:
        """Make an ECDSA private key from the wire fields.

        The curve name and point in the private blob must equal
        *pubfields*; otherwise ValueError is raised.
        """
        (curve_name, point), data = self.get_public(data)
        secret, data = _get_mpint(data)

        if (curve_name, point) != pubfields:
            raise ValueError("Corrupt data: ecdsa field mismatch")
        private_key = ec.derive_private_key(secret, self.curve)
        return private_key, data

    def encode_public(
        self, public_key: ec.EllipticCurvePublicKey, f_pub: _FragList
    ) -> None:
        """Append the curve name and the uncompressed X9.62 point."""
        point = public_key.public_bytes(
            Encoding.X962, PublicFormat.UncompressedPoint
        )
        f_pub.put_sshstr(self.ssh_curve_name)
        f_pub.put_sshstr(point)

    def encode_private(
        self, private_key: ec.EllipticCurvePrivateKey, f_priv: _FragList
    ) -> None:
        """Append the public fields followed by the private scalar."""
        public_key = private_key.public_key()
        private_numbers = private_key.private_numbers()

        self.encode_public(public_key, f_priv)
        f_priv.put_mpint(private_numbers.private_value)
class _SSHFormatEd25519:
    """Wire format for ``ssh-ed25519`` keys.

    Public blob:   string point (raw public key)
    Private blob:  string point, string (raw private key || point)
    """

    def get_public(
        self, data: memoryview
    ) -> typing.Tuple[typing.Tuple, memoryview]:
        """Parse the public field; return ``((point,), remainder)``."""
        point, rest = _get_sshstr(data)
        return (point,), rest

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[ed25519.Ed25519PublicKey, memoryview]:
        """Build an Ed25519 public key from the raw point bytes."""
        (point,), rest = self.get_public(data)
        key = ed25519.Ed25519PublicKey.from_public_bytes(point.tobytes())
        return key, rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[ed25519.Ed25519PrivateKey, memoryview]:
        """Build an Ed25519 private key.

        The keypair string holds the 32-byte private value followed by the
        point. That embedded point and the leading point must both match
        *pubfields*; otherwise ValueError is raised.
        """
        (point,), rest = self.get_public(data)
        keypair, rest = _get_sshstr(rest)

        secret, embedded_point = keypair[:32], keypair[32:]
        if point != embedded_point or (point,) != pubfields:
            raise ValueError("Corrupt data: ed25519 field mismatch")
        return ed25519.Ed25519PrivateKey.from_private_bytes(secret), rest

    def encode_public(
        self, public_key: ed25519.Ed25519PublicKey, f_pub: _FragList
    ) -> None:
        """Append the raw public key as an SSH string."""
        f_pub.put_sshstr(
            public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)
        )

    def encode_private(
        self, private_key: ed25519.Ed25519PrivateKey, f_priv: _FragList
    ) -> None:
        """Append the point, then the (private || public) keypair string."""
        public_key = private_key.public_key()
        raw_private = private_key.private_bytes(
            Encoding.Raw, PrivateFormat.Raw, NoEncryption()
        )
        raw_public = public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)

        self.encode_public(public_key, f_priv)
        f_priv.put_sshstr(_FragList([raw_private, raw_public]))
# SSH key type name -> handler that parses and encodes that key format.
_KEY_FORMATS = {
    _SSH_RSA: _SSHFormatRSA(),
    _SSH_DSA: _SSHFormatDSA(),
    _SSH_ED25519: _SSHFormatEd25519(),
    # One ECDSA handler per supported NIST curve.
    **{
        key_type: _SSHFormatECDSA(ssh_curve_name, curve)
        for key_type, ssh_curve_name, curve in (
            (_ECDSA_NISTP256, b"nistp256", ec.SECP256R1()),
            (_ECDSA_NISTP384, b"nistp384", ec.SECP384R1()),
            (_ECDSA_NISTP521, b"nistp521", ec.SECP521R1()),
        )
    },
}
def _lookup_kformat(key_type: bytes):
    """Return the format handler registered for *key_type*.

    *key_type* may be any bytes-like object. Non-bytes values are copied to
    bytes first so they can be used as a dict key.

    Raises:
        UnsupportedAlgorithm: if no handler exists for the key type.
    """
    if isinstance(key_type, bytes):
        name = key_type
    else:
        name = memoryview(key_type).tobytes()
    kformat = _KEY_FORMATS.get(name)
    if kformat is None:
        raise UnsupportedAlgorithm(f"Unsupported key type: {name!r}")
    return kformat
# Private key classes that can be loaded from / written to the OpenSSH
# private key format.
SSHPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    dsa.DSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]
def load_ssh_private_key(
    data: bytes,
    password: typing.Optional[bytes],
    backend: typing.Any = None,
) -> SSHPrivateKeyTypes:
    """Load a private key from the OpenSSH ``openssh-key-v1`` PEM encoding.

    Args:
        data: bytes-like object containing the armored key.
        password: passphrase for encrypted keys; may be None when the key
            is not encrypted.
        backend: unused; accepted for backward compatibility.

    Returns:
        The decoded EC, RSA, DSA or Ed25519 private key. DSA keys also
        emit a deprecation warning.

    Raises:
        ValueError: malformed or corrupt data, a missing password for an
            encrypted key, or a container that does not hold exactly one
            key.
        UnsupportedAlgorithm: unknown key type, cipher or KDF.
    """
    utils._check_byteslike("data", data)
    if password is not None:
        utils._check_bytes("password", password)

    m = _PEM_RC.search(data)
    if not m:
        raise ValueError("Not OpenSSH private key format")
    p1 = m.start(1)
    p2 = m.end(1)
    data = binascii.a2b_base64(memoryview(data)[p1:p2])
    if not data.startswith(_SK_MAGIC):
        raise ValueError("Not OpenSSH private key format")
    data = memoryview(data)[len(_SK_MAGIC) :]

    # parse header
    ciphername, data = _get_sshstr(data)
    kdfname, data = _get_sshstr(data)
    kdfoptions, data = _get_sshstr(data)
    nkeys, data = _get_u32(data)
    if nkeys != 1:
        raise ValueError("Only one key supported")

    # load public key data
    pubdata, data = _get_sshstr(data)
    pub_key_type, pubdata = _get_sshstr(pubdata)
    kformat = _lookup_kformat(pub_key_type)
    pubfields, pubdata = kformat.get_public(pubdata)
    _check_empty(pubdata)

    if (ciphername, kdfname) != (_NONE, _NONE):
        ciphername_bytes = ciphername.tobytes()
        if ciphername_bytes not in _SSH_CIPHERS:
            raise UnsupportedAlgorithm(
                f"Unsupported cipher: {ciphername_bytes!r}"
            )
        if kdfname != _BCRYPT:
            # kdfname is a memoryview; convert it so the message shows the
            # KDF name rather than "<memory at 0x...>".
            raise UnsupportedAlgorithm(
                f"Unsupported KDF: {kdfname.tobytes()!r}"
            )
        cipher_params = _SSH_CIPHERS[ciphername_bytes]
        blklen = cipher_params.block_len
        tag_len = cipher_params.tag_len
        # load secret data
        edata, data = _get_sshstr(data)
        # see https://bugzilla.mindrot.org/show_bug.cgi?id=3553 for
        # information about how OpenSSH handles AEAD tags
        if cipher_params.is_aead:
            # The tag follows the encrypted section, so everything left
            # must be exactly the tag.
            tag = bytes(data)
            if len(tag) != tag_len:
                raise ValueError("Corrupt data: invalid tag length for cipher")
        else:
            _check_empty(data)
        _check_block_size(edata, blklen)
        salt, kbuf = _get_sshstr(kdfoptions)
        rounds, kbuf = _get_u32(kbuf)
        _check_empty(kbuf)
        ciph = _init_cipher(ciphername_bytes, password, salt.tobytes(), rounds)
        dec = ciph.decryptor()
        edata = memoryview(dec.update(edata))
        if cipher_params.is_aead:
            assert isinstance(dec, AEADDecryptionContext)
            _check_empty(dec.finalize_with_tag(tag))
        else:
            # _check_block_size requires data to be a full block so there
            # should be no output from finalize
            _check_empty(dec.finalize())
    else:
        # load secret data
        edata, data = _get_sshstr(data)
        _check_empty(data)
        blklen = 8
        _check_block_size(edata, blklen)

    # The (decrypted) section starts with two check values that must match.
    ck1, edata = _get_u32(edata)
    ck2, edata = _get_u32(edata)
    if ck1 != ck2:
        raise ValueError("Corrupt data: broken checksum")

    # load per-key struct
    key_type, edata = _get_sshstr(edata)
    if key_type != pub_key_type:
        raise ValueError("Corrupt data: key type mismatch")
    private_key, edata = kformat.load_private(edata, pubfields)
    # The comment string is parsed only to reach the padding; it is unused.
    _comment, edata = _get_sshstr(edata)

    # yes, SSH does padding check *after* all other parsing is done.
    # need to follow as it writes zero-byte padding too.
    if edata != _PADDING[: len(edata)]:
        raise ValueError("Corrupt data: invalid padding")

    if isinstance(private_key, dsa.DSAPrivateKey):
        warnings.warn(
            "SSH DSA keys are deprecated and will be removed in a future "
            "release.",
            utils.DeprecatedIn40,
            stacklevel=2,
        )

    return private_key
def _serialize_ssh_private_key(
    private_key: SSHPrivateKeyTypes,
    password: bytes,
    encryption_algorithm: KeySerializationEncryption,
) -> bytes:
    """Encode *private_key* in the OpenSSH ``openssh-key-v1`` PEM format.

    With a non-empty *password*, the key section is encrypted with the
    default cipher. The key comes from bcrypt with a random 16-byte salt,
    and the round count comes from *encryption_algorithm* when it is a
    ``_KeySerializationEncryption`` with ``_kdf_rounds`` set. An empty
    password produces an unencrypted ("none"/"none") container.
    """
    utils._check_bytes("password", password)
    if isinstance(private_key, dsa.DSAPrivateKey):
        warnings.warn(
            "SSH DSA key support is deprecated and will be "
            "removed in a future release",
            utils.DeprecatedIn40,
            stacklevel=4,
        )

    key_type = _get_ssh_key_type(private_key)
    kformat = _lookup_kformat(key_type)

    # Encryption parameters and the KDF options blob.
    kdf_options = _FragList()
    if not password:
        ciphername = kdfname = _NONE
        blklen = 8
        cipher = None
    else:
        ciphername = _DEFAULT_CIPHER
        kdfname = _BCRYPT
        blklen = _SSH_CIPHERS[ciphername].block_len
        rounds = _DEFAULT_ROUNDS
        if (
            isinstance(encryption_algorithm, _KeySerializationEncryption)
            and encryption_algorithm._kdf_rounds is not None
        ):
            rounds = encryption_algorithm._kdf_rounds
        salt = os.urandom(16)
        kdf_options.put_sshstr(salt)
        kdf_options.put_u32(rounds)
        cipher = _init_cipher(ciphername, password, salt, rounds)
    checkval = os.urandom(4)
    comment = b""

    # Public key blob: key type followed by the public fields.
    f_public_key = _FragList()
    f_public_key.put_sshstr(key_type)
    kformat.encode_public(private_key.public_key(), f_public_key)

    # Private section: the check value twice, key type, private fields,
    # comment, then padding up to a multiple of the block length.
    f_secrets = _FragList([checkval, checkval])
    f_secrets.put_sshstr(key_type)
    kformat.encode_private(private_key, f_secrets)
    f_secrets.put_sshstr(comment)
    pad_len = blklen - (f_secrets.size() % blklen)
    f_secrets.put_raw(_PADDING[:pad_len])

    # Container layout; the key count is always 1.
    f_main = _FragList([_SK_MAGIC])
    f_main.put_sshstr(ciphername)
    f_main.put_sshstr(kdfname)
    f_main.put_sshstr(kdf_options)
    f_main.put_u32(1)
    f_main.put_sshstr(f_public_key)
    f_main.put_sshstr(f_secrets)

    # Render into a buffer that has blklen spare bytes after the data,
    # presumably as the extra output room update_into needs.
    secrets_len = f_secrets.size()
    total_len = f_main.size()
    buf = memoryview(bytearray(total_len + blklen))
    f_main.render(buf)
    secrets_start = total_len - secrets_len

    # The private section is the tail of the container; encrypt in place.
    if cipher is not None:
        cipher.encryptor().update_into(
            buf[secrets_start:total_len], buf[secrets_start:]
        )

    return _ssh_pem_encode(buf[:total_len])
# Public key classes supported by the OpenSSH public key line format.
SSHPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    dsa.DSAPublicKey,
    ed25519.Ed25519PublicKey,
]

# Public key classes allowed in SSH certificates (no DSA).
SSHCertPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    ed25519.Ed25519PublicKey,
]
class SSHCertificateType(enum.Enum):
    """Certificate type field of an OpenSSH certificate."""

    USER = 1
    HOST = 2
class SSHCertificate:
    """An OpenSSH certificate decoded from its wire form.

    The constructor receives the already-parsed fields plus the raw buffers
    needed to re-serialize the certificate and verify its signature.
    Byte-valued accessors return fresh ``bytes`` copies.
    """

    def __init__(
        self,
        _nonce: memoryview,
        _public_key: SSHPublicKeyTypes,
        _serial: int,
        _cctype: int,
        _key_id: memoryview,
        _valid_principals: typing.List[bytes],
        _valid_after: int,
        _valid_before: int,
        _critical_options: typing.Dict[bytes, bytes],
        _extensions: typing.Dict[bytes, bytes],
        _sig_type: memoryview,
        _sig_key: memoryview,
        _inner_sig_type: memoryview,
        _signature: memoryview,
        _tbs_cert_body: memoryview,
        _cert_key_type: bytes,
        _cert_body: memoryview,
    ):
        try:
            self._type = SSHCertificateType(_cctype)
        except ValueError:
            raise ValueError("Invalid certificate type")
        self._nonce = _nonce
        self._public_key = _public_key
        self._serial = _serial
        self._key_id = _key_id
        self._valid_principals = _valid_principals
        self._valid_after = _valid_after
        self._valid_before = _valid_before
        self._critical_options = _critical_options
        self._extensions = _extensions
        self._sig_type = _sig_type
        self._sig_key = _sig_key
        self._inner_sig_type = _inner_sig_type
        self._signature = _signature
        self._cert_key_type = _cert_key_type
        self._cert_body = _cert_body
        self._tbs_cert_body = _tbs_cert_body

    @property
    def nonce(self) -> bytes:
        """The certificate nonce."""
        return bytes(self._nonce)

    def public_key(self) -> SSHCertPublicKeyTypes:
        """The public key being certified."""
        # The cast narrows the stored union, which still includes DSA, to
        # the types certificates allow; it performs no runtime check.
        return typing.cast(SSHCertPublicKeyTypes, self._public_key)

    @property
    def serial(self) -> int:
        """The certificate serial number."""
        return self._serial

    @property
    def type(self) -> SSHCertificateType:
        """Whether this is a user or a host certificate."""
        return self._type

    @property
    def key_id(self) -> bytes:
        """The key identifier string."""
        return bytes(self._key_id)

    @property
    def valid_principals(self) -> typing.List[bytes]:
        """The principals the certificate is valid for."""
        return self._valid_principals

    @property
    def valid_before(self) -> int:
        """The ``valid before`` timestamp field."""
        return self._valid_before

    @property
    def valid_after(self) -> int:
        """The ``valid after`` timestamp field."""
        return self._valid_after

    @property
    def critical_options(self) -> typing.Dict[bytes, bytes]:
        """Critical options as a name -> value mapping."""
        return self._critical_options

    @property
    def extensions(self) -> typing.Dict[bytes, bytes]:
        """Extensions as a name -> value mapping."""
        return self._extensions

    def signature_key(self) -> SSHCertPublicKeyTypes:
        """Decode and return the CA public key that signed the certificate."""
        key, leftover = _lookup_kformat(self._sig_type).load_public(
            self._sig_key
        )
        _check_empty(leftover)
        return key

    def public_bytes(self) -> bytes:
        """Re-encode as an OpenSSH ``<cert-type> <base64>`` line."""
        encoded_body = binascii.b2a_base64(
            bytes(self._cert_body), newline=False
        )
        return b" ".join([bytes(self._cert_key_type), encoded_body])

    def verify_cert_signature(self) -> None:
        """Check the CA signature over the to-be-signed certificate body.

        Returns None on success. On failure it raises whatever the signing
        key's ``verify`` raises.
        """
        signature_key = self.signature_key()
        tbs = bytes(self._tbs_cert_body)

        if isinstance(signature_key, ed25519.Ed25519PublicKey):
            signature_key.verify(bytes(self._signature), tbs)
            return

        if isinstance(signature_key, ec.EllipticCurvePublicKey):
            # SSH carries the ECDSA signature as two mpints (r, s); re-encode
            # them as DER for verify().
            r, rest = _get_mpint(self._signature)
            s, rest = _get_mpint(rest)
            _check_empty(rest)
            der_signature = asym_utils.encode_dss_signature(r, s)
            ec_hash = _get_ec_hash_alg(signature_key.curve)
            signature_key.verify(der_signature, tbs, ec.ECDSA(ec_hash))
            return

        assert isinstance(signature_key, rsa.RSAPublicKey)
        # RSA signatures may use any of three hashes, selected by the inner
        # signature type name.
        rsa_hash_by_name = {
            _SSH_RSA: hashes.SHA1,
            _SSH_RSA_SHA256: hashes.SHA256,
            _SSH_RSA_SHA512: hashes.SHA512,
        }
        inner_name = bytes(self._inner_sig_type)
        assert inner_name in rsa_hash_by_name
        hash_alg = rsa_hash_by_name.get(inner_name, hashes.SHA512)()
        signature_key.verify(
            bytes(self._signature),
            tbs,
            padding.PKCS1v15(),
            hash_alg,
        )
def _get_ec_hash_alg(curve: ec.EllipticCurve) -> hashes.HashAlgorithm:
    """Return the hash used for SSH ECDSA signatures on *curve*.

    P-256 uses SHA-256, P-384 uses SHA-384, and P-521 uses SHA-512.
    """
    if isinstance(curve, ec.SECP256R1):
        return hashes.SHA256()
    if isinstance(curve, ec.SECP384R1):
        return hashes.SHA384()
    assert isinstance(curve, ec.SECP521R1)
    return hashes.SHA512()
def _load_ssh_public_identity(
    data: bytes,
    _legacy_dsa_allowed=False,
) -> typing.Union[SSHCertificate, SSHPublicKeyTypes]:
    """Parse one OpenSSH public key line into a key or a certificate.

    The line starts with ``<key-type> <base64-blob>``. A key type ending in
    ``-cert-v01@openssh.com`` yields an SSHCertificate; otherwise the bare
    public key is returned. DSA keys and DSA CA signatures are rejected
    with UnsupportedAlgorithm unless *_legacy_dsa_allowed* is true.
    """
    utils._check_byteslike("data", data)

    match = _SSH_PUBKEY_RC.match(data)
    if not match:
        raise ValueError("Invalid line format")
    orig_key_type = match.group(1)
    key_body = match.group(2)

    with_cert = orig_key_type.endswith(_CERT_SUFFIX)
    if with_cert:
        key_type = orig_key_type[: -len(_CERT_SUFFIX)]
    else:
        key_type = orig_key_type
    if key_type == _SSH_DSA and not _legacy_dsa_allowed:
        raise UnsupportedAlgorithm(
            "DSA keys aren't supported in SSH certificates"
        )
    kformat = _lookup_kformat(key_type)

    try:
        blob = memoryview(binascii.a2b_base64(key_body))
    except (TypeError, binascii.Error):
        raise ValueError("Invalid format")

    # The blob repeats the key type from the line prefix.
    inner_key_type, rest = _get_sshstr(blob)
    if inner_key_type != orig_key_type:
        raise ValueError("Invalid key format")

    if not with_cert:
        public_key, rest = kformat.load_public(rest)
        _check_empty(rest)
        return public_key

    # Certificate body; the field order follows the wire layout.
    nonce, rest = _get_sshstr(rest)
    public_key, rest = kformat.load_public(rest)
    serial, rest = _get_u64(rest)
    cctype, rest = _get_u32(rest)
    key_id, rest = _get_sshstr(rest)
    principals, rest = _get_sshstr(rest)
    valid_principals = []
    while principals:
        principal, principals = _get_sshstr(principals)
        valid_principals.append(bytes(principal))
    valid_after, rest = _get_u64(rest)
    valid_before, rest = _get_u64(rest)
    crit_options, rest = _get_sshstr(rest)
    critical_options = _parse_exts_opts(crit_options)
    exts, rest = _get_sshstr(rest)
    extensions = _parse_exts_opts(exts)
    # Reserved field, currently unused.
    _reserved, rest = _get_sshstr(rest)
    sig_key_raw, rest = _get_sshstr(rest)
    sig_type, sig_key = _get_sshstr(sig_key_raw)
    if sig_type == _SSH_DSA and not _legacy_dsa_allowed:
        raise UnsupportedAlgorithm(
            "DSA signatures aren't supported in SSH certificates"
        )

    # Everything before the trailing signature field is the signed data.
    tbs_cert_body = blob[: -len(rest)]
    signature_raw, rest = _get_sshstr(rest)
    _check_empty(rest)
    inner_sig_type, sig_rest = _get_sshstr(signature_raw)
    if sig_type == _SSH_RSA:
        # RSA CA keys may sign with several hash variants.
        sig_matches = inner_sig_type in (
            _SSH_RSA_SHA256,
            _SSH_RSA_SHA512,
            _SSH_RSA,
        )
    else:
        sig_matches = inner_sig_type == sig_type
    if not sig_matches:
        raise ValueError("Signature key type does not match")
    signature, sig_rest = _get_sshstr(sig_rest)
    _check_empty(sig_rest)

    return SSHCertificate(
        nonce,
        public_key,
        serial,
        cctype,
        key_id,
        valid_principals,
        valid_after,
        valid_before,
        critical_options,
        extensions,
        sig_type,
        sig_key,
        inner_sig_type,
        signature,
        tbs_cert_body,
        orig_key_type,
        blob,
    )
def load_ssh_public_identity(
    data: bytes,
) -> typing.Union[SSHCertificate, SSHPublicKeyTypes]:
    """Load an OpenSSH public key or certificate line.

    DSA keys and DSA-signed certificates are rejected.
    """
    return _load_ssh_public_identity(data, _legacy_dsa_allowed=False)
def _parse_exts_opts(exts_opts: memoryview) -> typing.Dict[bytes, bytes]:
    """Decode a certificate critical-options / extensions list.

    The encoding is a sequence of (name, value) SSH string pairs. Names
    must be unique and in strictly increasing byte order; violations raise
    ValueError.
    """
    result: typing.Dict[bytes, bytes] = {}
    previous: typing.Optional[bytes] = None
    remaining = exts_opts
    while remaining:
        raw_name, remaining = _get_sshstr(remaining)
        name = bytes(raw_name)
        if name in result:
            raise ValueError("Duplicate name")
        if previous is not None and name < previous:
            raise ValueError("Fields not lexically sorted")
        raw_value, remaining = _get_sshstr(remaining)
        result[name] = bytes(raw_value)
        previous = name
    return result
def load_ssh_public_key(
    data: bytes, backend: typing.Any = None
) -> SSHPublicKeyTypes:
    """Load an OpenSSH public key line.

    For a certificate line, the certified public key is returned. DSA keys
    are accepted but emit a deprecation warning. *backend* is unused.
    """
    loaded = _load_ssh_public_identity(data, _legacy_dsa_allowed=True)
    public_key: SSHPublicKeyTypes = (
        loaded.public_key() if isinstance(loaded, SSHCertificate) else loaded
    )

    if isinstance(public_key, dsa.DSAPublicKey):
        warnings.warn(
            "SSH DSA keys are deprecated and will be removed in a future "
            "release.",
            utils.DeprecatedIn40,
            stacklevel=2,
        )
    return public_key
def serialize_ssh_public_key(public_key: SSHPublicKeyTypes) -> bytes:
    """Return *public_key* as a one-line OpenSSH ``<type> <base64>`` entry.

    DSA keys are still serialized but emit a deprecation warning.
    """
    if isinstance(public_key, dsa.DSAPublicKey):
        warnings.warn(
            "SSH DSA key support is deprecated and will be "
            "removed in a future release",
            utils.DeprecatedIn40,
            stacklevel=4,
        )
    key_type = _get_ssh_key_type(public_key)
    kformat = _lookup_kformat(key_type)

    blob = _FragList()
    blob.put_sshstr(key_type)
    kformat.encode_public(public_key, blob)

    encoded = binascii.b2a_base64(blob.tobytes(), newline=False)
    return key_type + b" " + encoded
# CA private key types that SSHCertificateBuilder.sign() accepts; the order
# matches the isinstance() check performed there.
SSHCertPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey, rsa.RSAPrivateKey, ed25519.Ed25519PrivateKey
]

# Maximum number of principals a certificate may list. OpenSSH enforces this
# limit in sshd and ssh-keygen, although the certificate spec never defines
# it.
_SSHKEY_CERT_MAX_PRINCIPALS = 256
class SSHCertificateBuilder:
    """Immutable builder for OpenSSH certificates.

    Every setter validates its argument and returns a *new* builder that
    carries over all other fields; the receiver is never modified.
    Single-valued fields raise ``ValueError`` if set twice, while critical
    options and extensions accumulate and reject duplicate names. Once the
    required fields are present, :meth:`sign` encodes the certificate, signs
    it with a CA private key and returns the parsed ``SSHCertificate``.
    """

    def __init__(
        self,
        _public_key: typing.Optional[SSHCertPublicKeyTypes] = None,
        _serial: typing.Optional[int] = None,
        _type: typing.Optional[SSHCertificateType] = None,
        _key_id: typing.Optional[bytes] = None,
        _valid_principals: typing.Optional[typing.List[bytes]] = None,
        _valid_for_all_principals: bool = False,
        _valid_before: typing.Optional[int] = None,
        _valid_after: typing.Optional[int] = None,
        _critical_options: typing.Optional[
            typing.List[typing.Tuple[bytes, bytes]]
        ] = None,
        _extensions: typing.Optional[
            typing.List[typing.Tuple[bytes, bytes]]
        ] = None,
    ):
        # The underscore-prefixed keyword arguments carry state from one
        # builder into the new builder returned by each setter.
        self._public_key = _public_key
        self._serial = _serial
        self._type = _type
        self._key_id = _key_id
        # List-valued state is always copied: ``None`` defaults avoid one
        # mutable list being shared by every builder, and copying keeps a
        # caller's later mutation of its own list out of the builder.
        self._valid_principals: typing.List[bytes] = (
            [] if _valid_principals is None else list(_valid_principals)
        )
        self._valid_for_all_principals = _valid_for_all_principals
        self._valid_before = _valid_before
        self._valid_after = _valid_after
        self._critical_options: typing.List[typing.Tuple[bytes, bytes]] = (
            [] if _critical_options is None else list(_critical_options)
        )
        self._extensions: typing.List[typing.Tuple[bytes, bytes]] = (
            [] if _extensions is None else list(_extensions)
        )

    def public_key(
        self, public_key: SSHCertPublicKeyTypes
    ) -> SSHCertificateBuilder:
        """Return a builder with the key being certified set.

        Raises ``TypeError`` for key types other than EC, RSA or Ed25519
        public keys, and ``ValueError`` if a public key is already set.
        """
        if not isinstance(
            public_key,
            (
                ec.EllipticCurvePublicKey,
                rsa.RSAPublicKey,
                ed25519.Ed25519PublicKey,
            ),
        ):
            raise TypeError("Unsupported key type")
        if self._public_key is not None:
            raise ValueError("public_key already set")

        return SSHCertificateBuilder(
            _public_key=public_key,
            _serial=self._serial,
            _type=self._type,
            _key_id=self._key_id,
            _valid_principals=self._valid_principals,
            _valid_for_all_principals=self._valid_for_all_principals,
            _valid_before=self._valid_before,
            _valid_after=self._valid_after,
            _critical_options=self._critical_options,
            _extensions=self._extensions,
        )

    def serial(self, serial: int) -> SSHCertificateBuilder:
        """Return a builder with the certificate serial set.

        *serial* must be an ``int`` in ``[0, 2**64)`` (it is encoded as a
        u64). Raises ``ValueError`` if a serial is already set.
        """
        if not isinstance(serial, int):
            raise TypeError("serial must be an integer")
        if not 0 <= serial < 2**64:
            raise ValueError("serial must be between 0 and 2**64")
        if self._serial is not None:
            raise ValueError("serial already set")

        return SSHCertificateBuilder(
            _public_key=self._public_key,
            _serial=serial,
            _type=self._type,
            _key_id=self._key_id,
            _valid_principals=self._valid_principals,
            _valid_for_all_principals=self._valid_for_all_principals,
            _valid_before=self._valid_before,
            _valid_after=self._valid_after,
            _critical_options=self._critical_options,
            _extensions=self._extensions,
        )

    def type(self, type: SSHCertificateType) -> SSHCertificateBuilder:
        """Return a builder with the certificate type set.

        Raises ``TypeError`` unless *type* is an ``SSHCertificateType`` and
        ``ValueError`` if the type is already set.
        """
        if not isinstance(type, SSHCertificateType):
            raise TypeError("type must be an SSHCertificateType")
        if self._type is not None:
            raise ValueError("type already set")

        return SSHCertificateBuilder(
            _public_key=self._public_key,
            _serial=self._serial,
            _type=type,
            _key_id=self._key_id,
            _valid_principals=self._valid_principals,
            _valid_for_all_principals=self._valid_for_all_principals,
            _valid_before=self._valid_before,
            _valid_after=self._valid_after,
            _critical_options=self._critical_options,
            _extensions=self._extensions,
        )

    def key_id(self, key_id: bytes) -> SSHCertificateBuilder:
        """Return a builder with the key identifier set.

        Raises ``TypeError`` unless *key_id* is ``bytes`` and ``ValueError``
        if a key id is already set.
        """
        if not isinstance(key_id, bytes):
            raise TypeError("key_id must be bytes")
        if self._key_id is not None:
            raise ValueError("key_id already set")

        return SSHCertificateBuilder(
            _public_key=self._public_key,
            _serial=self._serial,
            _type=self._type,
            _key_id=key_id,
            _valid_principals=self._valid_principals,
            _valid_for_all_principals=self._valid_for_all_principals,
            _valid_before=self._valid_before,
            _valid_after=self._valid_after,
            _critical_options=self._critical_options,
            _extensions=self._extensions,
        )

    def valid_principals(
        self, valid_principals: typing.List[bytes]
    ) -> SSHCertificateBuilder:
        """Return a builder restricted to the given principals.

        *valid_principals* must be a non-empty list of ``bytes`` with at most
        ``_SSHKEY_CERT_MAX_PRINCIPALS`` entries. It cannot be combined with
        :meth:`valid_for_all_principals` and cannot be set twice.
        """
        if self._valid_for_all_principals:
            raise ValueError(
                "Principals can't be set because the cert is valid "
                "for all principals"
            )
        if (
            not all(isinstance(x, bytes) for x in valid_principals)
            or not valid_principals
        ):
            raise TypeError(
                "principals must be a list of bytes and can't be empty"
            )
        if self._valid_principals:
            raise ValueError("valid_principals already set")

        if len(valid_principals) > _SSHKEY_CERT_MAX_PRINCIPALS:
            raise ValueError(
                "Reached or exceeded the maximum number of valid_principals"
            )

        return SSHCertificateBuilder(
            _public_key=self._public_key,
            _serial=self._serial,
            _type=self._type,
            _key_id=self._key_id,
            _valid_principals=valid_principals,
            _valid_for_all_principals=self._valid_for_all_principals,
            _valid_before=self._valid_before,
            _valid_after=self._valid_after,
            _critical_options=self._critical_options,
            _extensions=self._extensions,
        )

    def valid_for_all_principals(self) -> SSHCertificateBuilder:
        """Return a builder whose certificate is valid for any principal.

        This encodes an empty principals list. It is mutually exclusive with
        :meth:`valid_principals` and cannot be set twice.
        """
        if self._valid_principals:
            raise ValueError(
                "valid_principals already set, can't set "
                "valid_for_all_principals"
            )
        if self._valid_for_all_principals:
            raise ValueError("valid_for_all_principals already set")

        return SSHCertificateBuilder(
            _public_key=self._public_key,
            _serial=self._serial,
            _type=self._type,
            _key_id=self._key_id,
            _valid_principals=self._valid_principals,
            _valid_for_all_principals=True,
            _valid_before=self._valid_before,
            _valid_after=self._valid_after,
            _critical_options=self._critical_options,
            _extensions=self._extensions,
        )

    def valid_before(
        self, valid_before: typing.Union[int, float]
    ) -> SSHCertificateBuilder:
        """Return a builder with the end of the validity window set.

        Floats are truncated with ``int()``; the result must lie in
        ``[0, 2**64)``. Raises ``ValueError`` if already set.
        """
        if not isinstance(valid_before, (int, float)):
            raise TypeError("valid_before must be an int or float")
        valid_before = int(valid_before)
        if valid_before < 0 or valid_before >= 2**64:
            raise ValueError("valid_before must [0, 2**64)")
        if self._valid_before is not None:
            raise ValueError("valid_before already set")

        return SSHCertificateBuilder(
            _public_key=self._public_key,
            _serial=self._serial,
            _type=self._type,
            _key_id=self._key_id,
            _valid_principals=self._valid_principals,
            _valid_for_all_principals=self._valid_for_all_principals,
            _valid_before=valid_before,
            _valid_after=self._valid_after,
            _critical_options=self._critical_options,
            _extensions=self._extensions,
        )

    def valid_after(
        self, valid_after: typing.Union[int, float]
    ) -> SSHCertificateBuilder:
        """Return a builder with the start of the validity window set.

        Floats are truncated with ``int()``; the result must lie in
        ``[0, 2**64)``. Raises ``ValueError`` if already set.
        """
        if not isinstance(valid_after, (int, float)):
            raise TypeError("valid_after must be an int or float")
        valid_after = int(valid_after)
        if valid_after < 0 or valid_after >= 2**64:
            raise ValueError("valid_after must [0, 2**64)")
        if self._valid_after is not None:
            raise ValueError("valid_after already set")

        return SSHCertificateBuilder(
            _public_key=self._public_key,
            _serial=self._serial,
            _type=self._type,
            _key_id=self._key_id,
            _valid_principals=self._valid_principals,
            _valid_for_all_principals=self._valid_for_all_principals,
            _valid_before=self._valid_before,
            _valid_after=valid_after,
            _critical_options=self._critical_options,
            _extensions=self._extensions,
        )

    def add_critical_option(
        self, name: bytes, value: bytes
    ) -> SSHCertificateBuilder:
        """Return a builder with one more critical option appended.

        Raises ``TypeError`` unless both arguments are ``bytes`` and
        ``ValueError`` if *name* is already present.
        """
        if not isinstance(name, bytes) or not isinstance(value, bytes):
            raise TypeError("name and value must be bytes")
        # Linear scan per call, so adding n options is O(n**2) overall.
        if any(existing == name for existing, _ in self._critical_options):
            raise ValueError("Duplicate critical option name")

        return SSHCertificateBuilder(
            _public_key=self._public_key,
            _serial=self._serial,
            _type=self._type,
            _key_id=self._key_id,
            _valid_principals=self._valid_principals,
            _valid_for_all_principals=self._valid_for_all_principals,
            _valid_before=self._valid_before,
            _valid_after=self._valid_after,
            _critical_options=self._critical_options + [(name, value)],
            _extensions=self._extensions,
        )

    def add_extension(
        self, name: bytes, value: bytes
    ) -> SSHCertificateBuilder:
        """Return a builder with one more extension appended.

        Raises ``TypeError`` unless both arguments are ``bytes`` and
        ``ValueError`` if *name* is already present.
        """
        if not isinstance(name, bytes) or not isinstance(value, bytes):
            raise TypeError("name and value must be bytes")
        # Linear scan per call, so adding n extensions is O(n**2) overall.
        if any(existing == name for existing, _ in self._extensions):
            raise ValueError("Duplicate extension name")

        return SSHCertificateBuilder(
            _public_key=self._public_key,
            _serial=self._serial,
            _type=self._type,
            _key_id=self._key_id,
            _valid_principals=self._valid_principals,
            _valid_for_all_principals=self._valid_for_all_principals,
            _valid_before=self._valid_before,
            _valid_after=self._valid_after,
            _critical_options=self._critical_options,
            _extensions=self._extensions + [(name, value)],
        )

    def sign(self, private_key: SSHCertPrivateKeyTypes) -> SSHCertificate:
        """Encode, sign and parse the certificate described by this builder.

        The serial defaults to 0 and the key id to ``b""`` when unset.
        A public key, type, validity window, and either explicit principals
        or ``valid_for_all_principals`` are required.

        Ed25519 CAs sign the raw blob. ECDSA CAs use the curve's hash and
        encode ``r``/``s`` as mpints. RSA CAs use PKCS#1 v1.5 with SHA-512
        under the ``_SSH_RSA_SHA512`` signature type.

        Raises ``TypeError`` for an unsupported CA key type and
        ``ValueError`` when a required field is missing or
        ``valid_after > valid_before``. The builder itself is not modified.
        """
        if not isinstance(
            private_key,
            (
                ec.EllipticCurvePrivateKey,
                rsa.RSAPrivateKey,
                ed25519.Ed25519PrivateKey,
            ),
        ):
            raise TypeError("Unsupported private key type")

        if self._public_key is None:
            raise ValueError("public_key must be set")

        # Not required
        serial = 0 if self._serial is None else self._serial

        if self._type is None:
            raise ValueError("type must be set")

        # Not required
        key_id = b"" if self._key_id is None else self._key_id

        # A zero length list is valid, but means the certificate
        # is valid for any principal of the specified type. We require
        # the user to explicitly set valid_for_all_principals to get
        # that behavior.
        if not self._valid_principals and not self._valid_for_all_principals:
            raise ValueError(
                "valid_principals must be set if valid_for_all_principals "
                "is False"
            )

        if self._valid_before is None:
            raise ValueError("valid_before must be set")

        if self._valid_after is None:
            raise ValueError("valid_after must be set")

        if self._valid_after > self._valid_before:
            raise ValueError("valid_after must be earlier than valid_before")

        # Lexically sort by name into new lists. Sorting the builder's own
        # lists in place would mutate this (supposedly immutable) builder.
        critical_options = sorted(self._critical_options, key=lambda x: x[0])
        extensions = sorted(self._extensions, key=lambda x: x[0])

        key_type = _get_ssh_key_type(self._public_key)
        cert_prefix = key_type + _CERT_SUFFIX

        # Marshal the bytes to be signed
        nonce = os.urandom(32)
        kformat = _lookup_kformat(key_type)
        f = _FragList()
        f.put_sshstr(cert_prefix)
        f.put_sshstr(nonce)
        kformat.encode_public(self._public_key, f)
        f.put_u64(serial)
        f.put_u32(self._type.value)
        f.put_sshstr(key_id)
        fprincipals = _FragList()
        for p in self._valid_principals:
            fprincipals.put_sshstr(p)
        f.put_sshstr(fprincipals.tobytes())
        f.put_u64(self._valid_after)
        f.put_u64(self._valid_before)
        fcrit = _FragList()
        for name, value in critical_options:
            fcrit.put_sshstr(name)
            fcrit.put_sshstr(value)
        f.put_sshstr(fcrit.tobytes())
        fext = _FragList()
        for name, value in extensions:
            fext.put_sshstr(name)
            fext.put_sshstr(value)
        f.put_sshstr(fext.tobytes())
        f.put_sshstr(b"")  # RESERVED FIELD
        # encode CA public key
        ca_type = _get_ssh_key_type(private_key)
        caformat = _lookup_kformat(ca_type)
        caf = _FragList()
        caf.put_sshstr(ca_type)
        caformat.encode_public(private_key.public_key(), caf)
        f.put_sshstr(caf.tobytes())
        # Sigs according to the rules defined for the CA's public key
        # (RFC4253 section 6.6 for ssh-rsa, RFC5656 for ECDSA,
        # and RFC8032 for Ed25519).
        if isinstance(private_key, ed25519.Ed25519PrivateKey):
            signature = private_key.sign(f.tobytes())
            fsig = _FragList()
            fsig.put_sshstr(ca_type)
            fsig.put_sshstr(signature)
            f.put_sshstr(fsig.tobytes())
        elif isinstance(private_key, ec.EllipticCurvePrivateKey):
            hash_alg = _get_ec_hash_alg(private_key.curve)
            signature = private_key.sign(f.tobytes(), ec.ECDSA(hash_alg))
            r, s = asym_utils.decode_dss_signature(signature)
            fsig = _FragList()
            fsig.put_sshstr(ca_type)
            fsigblob = _FragList()
            fsigblob.put_mpint(r)
            fsigblob.put_mpint(s)
            fsig.put_sshstr(fsigblob.tobytes())
            f.put_sshstr(fsig.tobytes())
        else:
            assert isinstance(private_key, rsa.RSAPrivateKey)
            # Just like Golang, we're going to use SHA512 for RSA
            # https://cs.opensource.google/go/x/crypto/+/refs/tags/
            # v0.4.0:ssh/certs.go;l=445
            # RFC 8332 defines SHA256 and 512 as options
            fsig = _FragList()
            fsig.put_sshstr(_SSH_RSA_SHA512)
            signature = private_key.sign(
                f.tobytes(), padding.PKCS1v15(), hashes.SHA512()
            )
            fsig.put_sshstr(signature)
            f.put_sshstr(fsig.tobytes())

        cert_data = binascii.b2a_base64(f.tobytes()).strip()
        # load_ssh_public_identity returns a union, but this is
        # guaranteed to be an SSHCertificate, so we cast to make
        # mypy happy.
        return typing.cast(
            SSHCertificate,
            load_ssh_public_identity(b"".join([cert_prefix, b" ", cert_data])),
        )