Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/cryptography/hazmat/primitives/serialization/ssh.py: 20%
761 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 06:05 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 06:05 +0000
1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
5from __future__ import annotations
7import binascii
8import enum
9import os
10import re
11import typing
12import warnings
13from base64 import encodebytes as _base64_encode
14from dataclasses import dataclass
16from cryptography import utils
17from cryptography.exceptions import UnsupportedAlgorithm
18from cryptography.hazmat.primitives import hashes
19from cryptography.hazmat.primitives.asymmetric import (
20 dsa,
21 ec,
22 ed25519,
23 padding,
24 rsa,
25)
26from cryptography.hazmat.primitives.asymmetric import utils as asym_utils
27from cryptography.hazmat.primitives.ciphers import (
28 AEADDecryptionContext,
29 Cipher,
30 algorithms,
31 modes,
32)
33from cryptography.hazmat.primitives.serialization import (
34 Encoding,
35 KeySerializationEncryption,
36 NoEncryption,
37 PrivateFormat,
38 PublicFormat,
39 _KeySerializationEncryption,
40)
42try:
43 from bcrypt import kdf as _bcrypt_kdf
45 _bcrypt_supported = True
46except ImportError:
47 _bcrypt_supported = False
49 def _bcrypt_kdf(
50 password: bytes,
51 salt: bytes,
52 desired_key_bytes: int,
53 rounds: int,
54 ignore_few_rounds: bool = False,
55 ) -> bytes:
56 raise UnsupportedAlgorithm("Need bcrypt module")
59_SSH_ED25519 = b"ssh-ed25519"
60_SSH_RSA = b"ssh-rsa"
61_SSH_DSA = b"ssh-dss"
62_ECDSA_NISTP256 = b"ecdsa-sha2-nistp256"
63_ECDSA_NISTP384 = b"ecdsa-sha2-nistp384"
64_ECDSA_NISTP521 = b"ecdsa-sha2-nistp521"
65_CERT_SUFFIX = b"-cert-v01@openssh.com"
67# These are not key types, only algorithms, so they cannot appear
68# as a public key type
69_SSH_RSA_SHA256 = b"rsa-sha2-256"
70_SSH_RSA_SHA512 = b"rsa-sha2-512"
72_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)")
73_SK_MAGIC = b"openssh-key-v1\0"
74_SK_START = b"-----BEGIN OPENSSH PRIVATE KEY-----"
75_SK_END = b"-----END OPENSSH PRIVATE KEY-----"
76_BCRYPT = b"bcrypt"
77_NONE = b"none"
78_DEFAULT_CIPHER = b"aes256-ctr"
79_DEFAULT_ROUNDS = 16
81# re is only way to work on bytes-like data
82_PEM_RC = re.compile(_SK_START + b"(.*?)" + _SK_END, re.DOTALL)
84# padding for max blocksize
85_PADDING = memoryview(bytearray(range(1, 1 + 16)))
88@dataclass
89class _SSHCipher:
90 alg: typing.Type[algorithms.AES]
91 key_len: int
92 mode: typing.Union[
93 typing.Type[modes.CTR],
94 typing.Type[modes.CBC],
95 typing.Type[modes.GCM],
96 ]
97 block_len: int
98 iv_len: int
99 tag_len: typing.Optional[int]
100 is_aead: bool
# Ciphers actually used for wrapping OpenSSH private keys, keyed by their
# OpenSSH cipher name. All of them are AES-256; only the mode differs.
_SSH_CIPHERS: typing.Dict[bytes, _SSHCipher] = {
    # Counter mode; this is _DEFAULT_CIPHER, used when serializing.
    b"aes256-ctr": _SSHCipher(
        alg=algorithms.AES,
        mode=modes.CTR,
        key_len=32,
        iv_len=16,
        block_len=16,
        tag_len=None,
        is_aead=False,
    ),
    # Cipher block chaining.
    b"aes256-cbc": _SSHCipher(
        alg=algorithms.AES,
        mode=modes.CBC,
        key_len=32,
        iv_len=16,
        block_len=16,
        tag_len=None,
        is_aead=False,
    ),
    # GCM (AEAD): 12-byte nonce plus a 16-byte authentication tag.
    b"aes256-gcm@openssh.com": _SSHCipher(
        alg=algorithms.AES,
        mode=modes.GCM,
        key_len=32,
        iv_len=12,
        block_len=16,
        tag_len=16,
        is_aead=True,
    ),
}
# Maps cryptography's curve names to the SSH ECDSA key type names.
_ECDSA_KEY_TYPE = dict(
    secp256r1=_ECDSA_NISTP256,
    secp384r1=_ECDSA_NISTP384,
    secp521r1=_ECDSA_NISTP521,
)
def _get_ssh_key_type(
    key: typing.Union[SSHPrivateKeyTypes, SSHPublicKeyTypes]
) -> bytes:
    """Return the SSH key type name for a private or public key object.

    EC keys are resolved through ``_ecdsa_key_type``, which rejects curves
    that have no SSH name.

    Raises:
        ValueError: if the key class is not one of the supported families.
    """
    if isinstance(key, ec.EllipticCurvePrivateKey):
        return _ecdsa_key_type(key.public_key())
    if isinstance(key, ec.EllipticCurvePublicKey):
        return _ecdsa_key_type(key)
    if isinstance(key, (rsa.RSAPrivateKey, rsa.RSAPublicKey)):
        return _SSH_RSA
    if isinstance(key, (dsa.DSAPrivateKey, dsa.DSAPublicKey)):
        return _SSH_DSA
    ed_classes = (ed25519.Ed25519PrivateKey, ed25519.Ed25519PublicKey)
    if isinstance(key, ed_classes):
        return _SSH_ED25519
    raise ValueError("Unsupported key type")
def _ecdsa_key_type(public_key: ec.EllipticCurvePublicKey) -> bytes:
    """Return the SSH key type name for an elliptic-curve public key.

    Only curves listed in ``_ECDSA_KEY_TYPE`` have an SSH key type name.

    Args:
        public_key: the EC public key whose curve is looked up.

    Returns:
        The SSH key type name, e.g. ``b"ecdsa-sha2-nistp256"``.

    Raises:
        ValueError: if the key's curve has no SSH key type name.
    """
    curve_name = public_key.curve.name
    if curve_name not in _ECDSA_KEY_TYPE:
        raise ValueError(
            f"Unsupported curve for ssh private key: {curve_name!r}"
        )
    return _ECDSA_KEY_TYPE[curve_name]
def _ssh_pem_encode(
    data: bytes,
    prefix: bytes = _SK_START + b"\n",
    suffix: bytes = _SK_END + b"\n",
) -> bytes:
    """Wrap *data* as ``base64.encodebytes`` output between *prefix* and *suffix*.

    By default this produces the OpenSSH private key armour.
    """
    body = _base64_encode(data)
    return prefix + body + suffix
181def _check_block_size(data: bytes, block_len: int) -> None:
182 """Require data to be full blocks"""
183 if not data or len(data) % block_len != 0:
184 raise ValueError("Corrupt data: missing padding")
187def _check_empty(data: bytes) -> None:
188 """All data should have been parsed."""
189 if data:
190 raise ValueError("Corrupt data: unparsed data")
def _init_cipher(
    ciphername: bytes,
    password: typing.Optional[bytes],
    salt: bytes,
    rounds: int,
) -> Cipher[typing.Union[modes.CBC, modes.CTR, modes.GCM]]:
    """Derive key material with bcrypt-KDF and build the matching Cipher.

    The KDF output is split into the cipher key followed by the IV.

    Raises:
        ValueError: if *password* is missing or empty.
        KeyError: if *ciphername* is not in ``_SSH_CIPHERS``.
    """
    if not password:
        raise ValueError("Key is password-protected.")

    spec = _SSH_CIPHERS[ciphername]
    key_len = spec.key_len
    seed = _bcrypt_kdf(password, salt, key_len + spec.iv_len, rounds, True)
    key, iv = seed[:key_len], seed[key_len:]
    return Cipher(spec.alg(key), spec.mode(iv))
213def _get_u32(data: memoryview) -> typing.Tuple[int, memoryview]:
214 """Uint32"""
215 if len(data) < 4:
216 raise ValueError("Invalid data")
217 return int.from_bytes(data[:4], byteorder="big"), data[4:]
220def _get_u64(data: memoryview) -> typing.Tuple[int, memoryview]:
221 """Uint64"""
222 if len(data) < 8:
223 raise ValueError("Invalid data")
224 return int.from_bytes(data[:8], byteorder="big"), data[8:]
def _get_sshstr(data: memoryview) -> typing.Tuple[memoryview, memoryview]:
    """Split a uint32-length-prefixed byte string off the front of *data*.

    Returns (payload, rest). Raises ValueError if the data is truncated.
    """
    length, payload = _get_u32(data)
    if len(payload) < length:
        raise ValueError("Invalid data")
    return payload[:length], payload[length:]
def _get_mpint(data: memoryview) -> typing.Tuple[int, memoryview]:
    """Parse an SSH mpint from the front of *data*; return (value, rest).

    A set high bit in the first byte means the value is negative, which is
    rejected with ValueError.
    """
    raw, rest = _get_sshstr(data)
    if len(raw) > 0 and raw[0] & 0x80:
        raise ValueError("Invalid data")
    return int.from_bytes(raw, "big"), rest
243def _to_mpint(val: int) -> bytes:
244 """Storage format for signed bigint."""
245 if val < 0:
246 raise ValueError("negative mpint not allowed")
247 if not val:
248 return b""
249 nbytes = (val.bit_length() + 8) // 8
250 return utils.int_to_bytes(val, nbytes)
253class _FragList:
254 """Build recursive structure without data copy."""
256 flist: typing.List[bytes]
258 def __init__(
259 self, init: typing.Optional[typing.List[bytes]] = None
260 ) -> None:
261 self.flist = []
262 if init:
263 self.flist.extend(init)
265 def put_raw(self, val: bytes) -> None:
266 """Add plain bytes"""
267 self.flist.append(val)
269 def put_u32(self, val: int) -> None:
270 """Big-endian uint32"""
271 self.flist.append(val.to_bytes(length=4, byteorder="big"))
273 def put_u64(self, val: int) -> None:
274 """Big-endian uint64"""
275 self.flist.append(val.to_bytes(length=8, byteorder="big"))
277 def put_sshstr(self, val: typing.Union[bytes, _FragList]) -> None:
278 """Bytes prefixed with u32 length"""
279 if isinstance(val, (bytes, memoryview, bytearray)):
280 self.put_u32(len(val))
281 self.flist.append(val)
282 else:
283 self.put_u32(val.size())
284 self.flist.extend(val.flist)
286 def put_mpint(self, val: int) -> None:
287 """Big-endian bigint prefixed with u32 length"""
288 self.put_sshstr(_to_mpint(val))
290 def size(self) -> int:
291 """Current number of bytes"""
292 return sum(map(len, self.flist))
294 def render(self, dstbuf: memoryview, pos: int = 0) -> int:
295 """Write into bytearray"""
296 for frag in self.flist:
297 flen = len(frag)
298 start, pos = pos, pos + flen
299 dstbuf[start:pos] = frag
300 return pos
302 def tobytes(self) -> bytes:
303 """Return as bytes"""
304 buf = memoryview(bytearray(self.size()))
305 self.render(buf)
306 return buf.tobytes()
class _SSHFormatRSA:
    """Wire format for RSA keys.

    Public part:  mpint e, mpint n
    Private part: mpint n, e, d, iqmp, p, q
    """

    def get_public(self, data: memoryview):
        """Parse the public fields; return ``((e, n), remaining_data)``."""
        exponent, data = _get_mpint(data)
        modulus, data = _get_mpint(data)
        return (exponent, modulus), data

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[rsa.RSAPublicKey, memoryview]:
        """Build an RSA public key from the fields at the front of *data*."""
        (e, n), rest = self.get_public(data)
        return rsa.RSAPublicNumbers(e, n).public_key(), rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[rsa.RSAPrivateKey, memoryview]:
        """Build an RSA private key and check it against *pubfields*.

        *pubfields* is the ``(e, n)`` tuple that ``get_public`` returned for
        the key's public section; a mismatch raises ValueError.
        """
        values = []
        for _ in range(6):
            value, data = _get_mpint(data)
            values.append(value)
        n, e, d, iqmp, p, q = values

        if pubfields != (e, n):
            raise ValueError("Corrupt data: rsa field mismatch")
        # The wire format does not carry the CRT exponents; derive them.
        dmp1 = rsa.rsa_crt_dmp1(d, p)
        dmq1 = rsa.rsa_crt_dmq1(d, q)
        private_numbers = rsa.RSAPrivateNumbers(
            p, q, d, dmp1, dmq1, iqmp, rsa.RSAPublicNumbers(e, n)
        )
        return private_numbers.private_key(), data

    def encode_public(
        self, public_key: rsa.RSAPublicKey, f_pub: _FragList
    ) -> None:
        """Append the public fields (e, n) to *f_pub*."""
        numbers = public_key.public_numbers()
        for value in (numbers.e, numbers.n):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: rsa.RSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Append the private fields (n, e, d, iqmp, p, q) to *f_priv*."""
        priv = private_key.private_numbers()
        pub = priv.public_numbers
        for value in (pub.n, pub.e, priv.d, priv.iqmp, priv.p, priv.q):
            f_priv.put_mpint(value)
class _SSHFormatDSA:
    """Wire format for DSA keys.

    Public part:  mpint p, q, g, y
    Private part: mpint p, q, g, y, x
    """

    def get_public(
        self, data: memoryview
    ) -> typing.Tuple[typing.Tuple, memoryview]:
        """Parse the public fields; return ``((p, q, g, y), remaining_data)``."""
        values = []
        for _ in range(4):
            value, data = _get_mpint(data)
            values.append(value)
        return tuple(values), data

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[dsa.DSAPublicKey, memoryview]:
        """Build a DSA public key (1024-bit p only) from *data*."""
        (p, q, g, y), rest = self.get_public(data)
        public_numbers = dsa.DSAPublicNumbers(
            y, dsa.DSAParameterNumbers(p, q, g)
        )
        self._validate(public_numbers)
        return public_numbers.public_key(), rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[dsa.DSAPrivateKey, memoryview]:
        """Build a DSA private key and check it against *pubfields*."""
        public_values, rest = self.get_public(data)
        x, rest = _get_mpint(rest)

        if public_values != pubfields:
            raise ValueError("Corrupt data: dsa field mismatch")
        p, q, g, y = public_values
        public_numbers = dsa.DSAPublicNumbers(
            y, dsa.DSAParameterNumbers(p, q, g)
        )
        self._validate(public_numbers)
        return dsa.DSAPrivateNumbers(x, public_numbers).private_key(), rest

    def encode_public(
        self, public_key: dsa.DSAPublicKey, f_pub: _FragList
    ) -> None:
        """Append the public fields (p, q, g, y) to *f_pub*."""
        public_numbers = public_key.public_numbers()
        self._validate(public_numbers)
        params = public_numbers.parameter_numbers
        for value in (params.p, params.q, params.g, public_numbers.y):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: dsa.DSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Append the public fields followed by the private value x."""
        self.encode_public(private_key.public_key(), f_priv)
        f_priv.put_mpint(private_key.private_numbers().x)

    def _validate(self, public_numbers: dsa.DSAPublicNumbers) -> None:
        """Reject keys whose modulus p is not exactly 1024 bits long."""
        if public_numbers.parameter_numbers.p.bit_length() != 1024:
            raise ValueError("SSH supports only 1024 bit DSA keys")
class _SSHFormatECDSA:
    """Wire format for ECDSA keys.

    Public part:  string curve name, string point
    Private part: string curve name, string point, mpint secret
    """

    def __init__(self, ssh_curve_name: bytes, curve: ec.EllipticCurve):
        # SSH-side curve identifier (e.g. b"nistp256") checked on parse and
        # written on encode, plus the cryptography curve object to use.
        self.ssh_curve_name = ssh_curve_name
        self.curve = curve

    def get_public(
        self, data: memoryview
    ) -> typing.Tuple[typing.Tuple, memoryview]:
        """Parse the public fields; return ``((curve, point), remaining_data)``.

        Raises:
            ValueError: truncated data, a curve name that does not match
                this format, or an empty point.
            NotImplementedError: the point is not in uncompressed form.
        """
        curve, data = _get_sshstr(data)
        point, data = _get_sshstr(data)
        if curve != self.ssh_curve_name:
            raise ValueError("Curve name mismatch")
        # Without this guard an empty point would escape as IndexError from
        # point[0]; treat it as malformed input like every other parse error.
        if not point:
            raise ValueError("Invalid data")
        # 0x04 is the SEC 1 prefix byte of an uncompressed point.
        if point[0] != 4:
            raise NotImplementedError("Need uncompressed point")
        return (curve, point), data

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[ec.EllipticCurvePublicKey, memoryview]:
        """Build an ECDSA public key on ``self.curve`` from *data*."""
        (curve_name, point), data = self.get_public(data)
        public_key = ec.EllipticCurvePublicKey.from_encoded_point(
            self.curve, point.tobytes()
        )
        return public_key, data

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[ec.EllipticCurvePrivateKey, memoryview]:
        """Build an ECDSA private key and check it against *pubfields*.

        *pubfields* is the ``(curve, point)`` tuple that ``get_public``
        returned for the key's public section.
        """
        (curve_name, point), data = self.get_public(data)
        secret, data = _get_mpint(data)

        if (curve_name, point) != pubfields:
            raise ValueError("Corrupt data: ecdsa field mismatch")
        private_key = ec.derive_private_key(secret, self.curve)
        return private_key, data

    def encode_public(
        self, public_key: ec.EllipticCurvePublicKey, f_pub: _FragList
    ) -> None:
        """Append the curve name and the uncompressed X9.62 point."""
        point = public_key.public_bytes(
            Encoding.X962, PublicFormat.UncompressedPoint
        )
        f_pub.put_sshstr(self.ssh_curve_name)
        f_pub.put_sshstr(point)

    def encode_private(
        self, private_key: ec.EllipticCurvePrivateKey, f_priv: _FragList
    ) -> None:
        """Append the public fields followed by the private scalar."""
        public_key = private_key.public_key()
        private_numbers = private_key.private_numbers()

        self.encode_public(public_key, f_priv)
        f_priv.put_mpint(private_numbers.private_value)
class _SSHFormatEd25519:
    """Wire format for Ed25519 keys.

    Public part:  string point
    Private part: string point, string (private bytes || point)
    """

    def get_public(
        self, data: memoryview
    ) -> typing.Tuple[typing.Tuple, memoryview]:
        """Parse the public point; return ``((point,), remaining_data)``."""
        point, rest = _get_sshstr(data)
        return (point,), rest

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[ed25519.Ed25519PublicKey, memoryview]:
        """Build an Ed25519 public key from the point at the front of *data*."""
        (point,), rest = self.get_public(data)
        key = ed25519.Ed25519PublicKey.from_public_bytes(point.tobytes())
        return key, rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[ed25519.Ed25519PrivateKey, memoryview]:
        """Build an Ed25519 private key and check it against *pubfields*.

        The second string holds the 32 private bytes followed by a copy of
        the public point; both copies must match, and must match *pubfields*.
        """
        (point,), rest = self.get_public(data)
        keypair, rest = _get_sshstr(rest)

        secret, point_copy = keypair[:32], keypair[32:]
        if point != point_copy or (point,) != pubfields:
            raise ValueError("Corrupt data: ed25519 field mismatch")
        key = ed25519.Ed25519PrivateKey.from_private_bytes(secret)
        return key, rest

    def encode_public(
        self, public_key: ed25519.Ed25519PublicKey, f_pub: _FragList
    ) -> None:
        """Append the raw public point as an SSH string."""
        f_pub.put_sshstr(
            public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)
        )

    def encode_private(
        self, private_key: ed25519.Ed25519PrivateKey, f_priv: _FragList
    ) -> None:
        """Append the public point, then (private bytes || public point)."""
        public_key = private_key.public_key()
        secret = private_key.private_bytes(
            Encoding.Raw, PrivateFormat.Raw, NoEncryption()
        )
        point = public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)

        self.encode_public(public_key, f_priv)
        f_priv.put_sshstr(_FragList([secret, point]))
# Wire-format handlers keyed by SSH key type name.
_KEY_FORMATS = {
    # Non-EC key families.
    _SSH_RSA: _SSHFormatRSA(),
    _SSH_DSA: _SSHFormatDSA(),
    _SSH_ED25519: _SSHFormatEd25519(),
    # ECDSA: one handler per supported NIST curve.
    _ECDSA_NISTP256: _SSHFormatECDSA(b"nistp256", ec.SECP256R1()),
    _ECDSA_NISTP384: _SSHFormatECDSA(b"nistp384", ec.SECP384R1()),
    _ECDSA_NISTP521: _SSHFormatECDSA(b"nistp521", ec.SECP521R1()),
}
def _lookup_kformat(key_type: bytes):
    """Return the wire-format handler for *key_type*.

    *key_type* may be any bytes-like object; non-bytes values are copied
    to bytes before the lookup.

    Raises:
        UnsupportedAlgorithm: if no handler exists for the key type.
    """
    if not isinstance(key_type, bytes):
        key_type = memoryview(key_type).tobytes()
    kformat = _KEY_FORMATS.get(key_type)
    if kformat is None:
        raise UnsupportedAlgorithm(f"Unsupported key type: {key_type!r}")
    return kformat
# Private key classes that load_ssh_private_key can return and the OpenSSH
# private key serializer accepts.
SSHPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    dsa.DSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]
def load_ssh_private_key(
    data: bytes,
    password: typing.Optional[bytes],
    backend: typing.Any = None,
) -> SSHPrivateKeyTypes:
    """Load a private key stored in OpenSSH's "openssh-key-v1" format.

    Args:
        data: PEM-armoured key text (bytes-like).
        password: passphrase for an encrypted key, or None.
        backend: unused; accepted for backwards compatibility.

    Raises:
        ValueError: malformed or corrupt data, or no password supplied for
            an encrypted key.
        UnsupportedAlgorithm: unknown key type, cipher or KDF.
    """
    utils._check_byteslike("data", data)
    if password is not None:
        utils._check_bytes("password", password)

    match = _PEM_RC.search(data)
    if match is None:
        raise ValueError("Not OpenSSH private key format")
    armoured = memoryview(data)[match.start(1) : match.end(1)]
    decoded = binascii.a2b_base64(armoured)
    if not decoded.startswith(_SK_MAGIC):
        raise ValueError("Not OpenSSH private key format")
    rest = memoryview(decoded)[len(_SK_MAGIC) :]

    # Container header.
    ciphername, rest = _get_sshstr(rest)
    kdfname, rest = _get_sshstr(rest)
    kdfoptions, rest = _get_sshstr(rest)
    nkeys, rest = _get_u32(rest)
    if nkeys != 1:
        raise ValueError("Only one key supported")

    # Public key section.
    pubdata, rest = _get_sshstr(rest)
    pub_key_type, pubdata = _get_sshstr(pubdata)
    kformat = _lookup_kformat(pub_key_type)
    pubfields, pubdata = kformat.get_public(pubdata)
    _check_empty(pubdata)

    if (ciphername, kdfname) == (_NONE, _NONE):
        # Unencrypted: the secret section is stored in the clear.
        edata, rest = _get_sshstr(rest)
        _check_empty(rest)
        _check_block_size(edata, 8)
    else:
        cipher_key = ciphername.tobytes()
        if cipher_key not in _SSH_CIPHERS:
            raise UnsupportedAlgorithm(
                f"Unsupported cipher: {cipher_key!r}"
            )
        if kdfname != _BCRYPT:
            raise UnsupportedAlgorithm(f"Unsupported KDF: {kdfname!r}")
        spec = _SSH_CIPHERS[cipher_key]

        edata, rest = _get_sshstr(rest)
        # For AEAD ciphers OpenSSH stores the tag after the encrypted
        # section; see https://bugzilla.mindrot.org/show_bug.cgi?id=3553
        if spec.is_aead:
            tag = bytes(rest)
            if len(tag) != spec.tag_len:
                raise ValueError("Corrupt data: invalid tag length for cipher")
        else:
            _check_empty(rest)
        _check_block_size(edata, spec.block_len)

        salt, kdf_rest = _get_sshstr(kdfoptions)
        rounds, kdf_rest = _get_u32(kdf_rest)
        _check_empty(kdf_rest)

        cipher = _init_cipher(cipher_key, password, salt.tobytes(), rounds)
        decryptor = cipher.decryptor()
        edata = memoryview(decryptor.update(edata))
        if spec.is_aead:
            assert isinstance(decryptor, AEADDecryptionContext)
            _check_empty(decryptor.finalize_with_tag(tag))
        else:
            # Whole blocks were verified above, so finalize() emits nothing.
            _check_empty(decryptor.finalize())

    # Two copies of the same check value must match.
    check1, edata = _get_u32(edata)
    check2, edata = _get_u32(edata)
    if check1 != check2:
        raise ValueError("Corrupt data: broken checksum")

    # Per-key record: type, private fields, comment.
    key_type, edata = _get_sshstr(edata)
    if key_type != pub_key_type:
        raise ValueError("Corrupt data: key type mismatch")
    private_key, edata = kformat.load_private(edata, pubfields)
    _comment, edata = _get_sshstr(edata)

    # Padding is checked only after everything else has been parsed, to
    # follow OpenSSH; the remaining bytes must be 1, 2, 3, ... (possibly none).
    if edata != _PADDING[: len(edata)]:
        raise ValueError("Corrupt data: invalid padding")

    if isinstance(private_key, dsa.DSAPrivateKey):
        warnings.warn(
            "SSH DSA keys are deprecated and will be removed in a future "
            "release.",
            utils.DeprecatedIn40,
            stacklevel=2,
        )

    return private_key
def _serialize_ssh_private_key(
    private_key: SSHPrivateKeyTypes,
    password: bytes,
    encryption_algorithm: KeySerializationEncryption,
) -> bytes:
    """Serialize *private_key* in OpenSSH's "openssh-key-v1" format.

    An empty *password* writes the key unencrypted. Otherwise the secret
    section is encrypted with _DEFAULT_CIPHER using a bcrypt-derived key.
    The KDF round count comes from *encryption_algorithm* when it is a
    ``_KeySerializationEncryption`` with ``_kdf_rounds`` set, otherwise
    _DEFAULT_ROUNDS.
    """
    utils._check_bytes("password", password)
    if isinstance(private_key, dsa.DSAPrivateKey):
        warnings.warn(
            "SSH DSA key support is deprecated and will be "
            "removed in a future release",
            utils.DeprecatedIn40,
            stacklevel=4,
        )

    key_type = _get_ssh_key_type(private_key)
    kformat = _lookup_kformat(key_type)

    # Encryption parameters.
    f_kdfoptions = _FragList()
    ciph = None
    if not password:
        ciphername = kdfname = _NONE
        blklen = 8
    else:
        ciphername = _DEFAULT_CIPHER
        kdfname = _BCRYPT
        blklen = _SSH_CIPHERS[ciphername].block_len
        rounds = _DEFAULT_ROUNDS
        if isinstance(encryption_algorithm, _KeySerializationEncryption):
            if encryption_algorithm._kdf_rounds is not None:
                rounds = encryption_algorithm._kdf_rounds
        salt = os.urandom(16)
        f_kdfoptions.put_sshstr(salt)
        f_kdfoptions.put_u32(rounds)
        ciph = _init_cipher(ciphername, password, salt, rounds)
    checkval = os.urandom(4)
    comment = b""

    # Public section: key type followed by the format-specific fields.
    f_public_key = _FragList()
    f_public_key.put_sshstr(key_type)
    kformat.encode_public(private_key.public_key(), f_public_key)

    # Secret section: doubled check value, key record, comment, padding.
    f_secrets = _FragList([checkval, checkval])
    f_secrets.put_sshstr(key_type)
    kformat.encode_private(private_key, f_secrets)
    f_secrets.put_sshstr(comment)
    pad_len = blklen - (f_secrets.size() % blklen)
    f_secrets.put_raw(_PADDING[:pad_len])

    # Outer container: magic, cipher, kdf, kdf options, key count, sections.
    f_main = _FragList([_SK_MAGIC])
    for header_field in (ciphername, kdfname, f_kdfoptions):
        f_main.put_sshstr(header_field)
    f_main.put_u32(1)  # number of keys
    f_main.put_sshstr(f_public_key)
    f_main.put_sshstr(f_secrets)

    secrets_len = f_secrets.size()
    total_len = f_main.size()
    # blklen bytes of slack so update_into's output buffer is large enough.
    buf = memoryview(bytearray(total_len + blklen))
    f_main.render(buf)
    secrets_start = total_len - secrets_len

    # The secret section is the tail of the buffer; encrypt it in place.
    if ciph is not None:
        ciph.encryptor().update_into(
            buf[secrets_start:total_len], buf[secrets_start:]
        )

    return _ssh_pem_encode(buf[:total_len])
# Public key classes handled by the one-line OpenSSH public key helpers.
SSHPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    dsa.DSAPublicKey,
    ed25519.Ed25519PublicKey,
]

# The same families minus DSA: key types allowed in SSH certificates.
SSHCertPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    ed25519.Ed25519PublicKey,
]
class SSHCertificateType(enum.Enum):
    """Certificate type, as carried in the certificate's numeric type field."""

    USER = 1
    HOST = 2
class SSHCertificate:
    """An OpenSSH certificate parsed from its one-line text form.

    The constructor only stores already-parsed fields (most of them
    memoryviews into the decoded certificate) and converts the numeric
    certificate type into an ``SSHCertificateType``. In this module,
    instances are created by ``_load_ssh_public_identity``.
    """

    def __init__(
        self,
        _nonce: memoryview,
        _public_key: SSHPublicKeyTypes,
        _serial: int,
        _cctype: int,
        _key_id: memoryview,
        _valid_principals: typing.List[bytes],
        _valid_after: int,
        _valid_before: int,
        _critical_options: typing.Dict[bytes, bytes],
        _extensions: typing.Dict[bytes, bytes],
        _sig_type: memoryview,
        _sig_key: memoryview,
        _inner_sig_type: memoryview,
        _signature: memoryview,
        _tbs_cert_body: memoryview,
        _cert_key_type: bytes,
        _cert_body: memoryview,
    ):
        self._nonce = _nonce
        self._public_key = _public_key
        self._serial = _serial
        try:
            self._type = SSHCertificateType(_cctype)
        except ValueError:
            raise ValueError("Invalid certificate type")
        self._key_id = _key_id
        self._valid_principals = _valid_principals
        self._valid_after = _valid_after
        self._valid_before = _valid_before
        self._critical_options = _critical_options
        self._extensions = _extensions
        self._sig_type = _sig_type
        self._sig_key = _sig_key
        self._inner_sig_type = _inner_sig_type
        self._signature = _signature
        self._cert_key_type = _cert_key_type
        self._cert_body = _cert_body
        self._tbs_cert_body = _tbs_cert_body

    @property
    def nonce(self) -> bytes:
        """The certificate's nonce field, as bytes."""
        return bytes(self._nonce)

    def public_key(self) -> SSHCertPublicKeyTypes:
        """Return the certified public key."""
        # The stored union still admits DSA; the cast only narrows the type
        # for checkers until DSA support is removed entirely.
        return typing.cast(SSHCertPublicKeyTypes, self._public_key)

    @property
    def serial(self) -> int:
        """Certificate serial number."""
        return self._serial

    @property
    def type(self) -> SSHCertificateType:
        """Whether this is a user or a host certificate."""
        return self._type

    @property
    def key_id(self) -> bytes:
        """The certificate's key-id field, as bytes."""
        return bytes(self._key_id)

    @property
    def valid_principals(self) -> typing.List[bytes]:
        """Principals listed in the certificate (may be empty)."""
        return self._valid_principals

    @property
    def valid_before(self) -> int:
        """The certificate's valid-before field."""
        return self._valid_before

    @property
    def valid_after(self) -> int:
        """The certificate's valid-after field."""
        return self._valid_after

    @property
    def critical_options(self) -> typing.Dict[bytes, bytes]:
        """Critical options as a name -> value mapping."""
        return self._critical_options

    @property
    def extensions(self) -> typing.Dict[bytes, bytes]:
        """Extensions as a name -> value mapping."""
        return self._extensions

    def signature_key(self) -> SSHCertPublicKeyTypes:
        """Decode and return the public key that signed this certificate."""
        kformat = _lookup_kformat(self._sig_type)
        ca_key, leftover = kformat.load_public(self._sig_key)
        _check_empty(leftover)
        return ca_key

    def public_bytes(self) -> bytes:
        """Re-encode the certificate as ``<type> <base64 body>``."""
        body_b64 = binascii.b2a_base64(bytes(self._cert_body), newline=False)
        return b" ".join([bytes(self._cert_key_type), body_b64])

    def verify_cert_signature(self) -> None:
        """Verify the signature over the to-be-signed certificate body.

        Returns None on success. On failure, whatever the key's ``verify``
        method raises propagates; a malformed ECDSA signature blob raises
        ValueError.
        """
        signature_key = self.signature_key()

        if isinstance(signature_key, ed25519.Ed25519PublicKey):
            signature_key.verify(
                bytes(self._signature), bytes(self._tbs_cert_body)
            )
            return

        if isinstance(signature_key, ec.EllipticCurvePublicKey):
            # SSH stores an ECDSA signature as two mpints, r then s.
            r, rest = _get_mpint(self._signature)
            s, rest = _get_mpint(rest)
            _check_empty(rest)
            encoded_sig = asym_utils.encode_dss_signature(r, s)
            hash_alg = _get_ec_hash_alg(signature_key.curve)
            signature_key.verify(
                encoded_sig, bytes(self._tbs_cert_body), ec.ECDSA(hash_alg)
            )
            return

        assert isinstance(signature_key, rsa.RSAPublicKey)
        if self._inner_sig_type == _SSH_RSA:
            hash_alg = hashes.SHA1()
        elif self._inner_sig_type == _SSH_RSA_SHA256:
            hash_alg = hashes.SHA256()
        else:
            assert self._inner_sig_type == _SSH_RSA_SHA512
            hash_alg = hashes.SHA512()
        signature_key.verify(
            bytes(self._signature),
            bytes(self._tbs_cert_body),
            padding.PKCS1v15(),
            hash_alg,
        )
def _get_ec_hash_alg(curve: ec.EllipticCurve) -> hashes.HashAlgorithm:
    """Return the hash used with *curve* for SSH ECDSA signatures.

    P-256 uses SHA-256, P-384 uses SHA-384 and P-521 uses SHA-512. Any
    other curve fails the assertion.
    """
    if isinstance(curve, ec.SECP256R1):
        return hashes.SHA256()
    if isinstance(curve, ec.SECP384R1):
        return hashes.SHA384()
    assert isinstance(curve, ec.SECP521R1)
    return hashes.SHA512()
def _load_ssh_public_identity(
    data: bytes,
    _legacy_dsa_allowed=False,
) -> typing.Union[SSHCertificate, SSHPublicKeyTypes]:
    """Parse a one-line OpenSSH public key or certificate.

    The line starts ``<type> <base64 blob>``. A type ending in
    ``-cert-v01@openssh.com`` is parsed as a certificate; any other type
    as a plain public key.

    Args:
        data: the line (bytes-like).
        _legacy_dsa_allowed: when False, DSA key types and DSA certificate
            signatures raise UnsupportedAlgorithm.

    Returns:
        An SSHCertificate for certificate lines, otherwise the public key.
    """
    utils._check_byteslike("data", data)

    match = _SSH_PUBKEY_RC.match(data)
    if not match:
        raise ValueError("Invalid line format")
    orig_key_type = match.group(1)
    key_body = match.group(2)
    with_cert = orig_key_type.endswith(_CERT_SUFFIX)
    if with_cert:
        key_type = orig_key_type[: -len(_CERT_SUFFIX)]
    else:
        key_type = orig_key_type
    if key_type == _SSH_DSA and not _legacy_dsa_allowed:
        raise UnsupportedAlgorithm(
            "DSA keys aren't supported in SSH certificates"
        )
    kformat = _lookup_kformat(key_type)

    try:
        rest = memoryview(binascii.a2b_base64(key_body))
    except (TypeError, binascii.Error):
        raise ValueError("Invalid format")

    if not with_cert:
        public_key, rest = kformat.load_public(rest)
        _check_empty(rest)
        return public_key

    # Certificate: the blob repeats the full certificate type name first.
    cert_body = rest
    inner_key_type, rest = _get_sshstr(rest)
    if inner_key_type != orig_key_type:
        raise ValueError("Invalid key format")
    nonce, rest = _get_sshstr(rest)
    public_key, rest = kformat.load_public(rest)
    serial, rest = _get_u64(rest)
    cctype, rest = _get_u32(rest)
    key_id, rest = _get_sshstr(rest)

    principals, rest = _get_sshstr(rest)
    valid_principals = []
    while principals:
        principal, principals = _get_sshstr(principals)
        valid_principals.append(bytes(principal))

    valid_after, rest = _get_u64(rest)
    valid_before, rest = _get_u64(rest)
    crit_options, rest = _get_sshstr(rest)
    critical_options = _parse_exts_opts(crit_options)
    exts, rest = _get_sshstr(rest)
    extensions = _parse_exts_opts(exts)
    # Reserved field: present on the wire but unused.
    _, rest = _get_sshstr(rest)

    sig_key_raw, rest = _get_sshstr(rest)
    sig_type, sig_key = _get_sshstr(sig_key_raw)
    if sig_type == _SSH_DSA and not _legacy_dsa_allowed:
        raise UnsupportedAlgorithm(
            "DSA signatures aren't supported in SSH certificates"
        )
    # Everything before the trailing signature string is the signed data.
    tbs_cert_body = cert_body[: -len(rest)]
    signature_raw, rest = _get_sshstr(rest)
    _check_empty(rest)
    inner_sig_type, sig_rest = _get_sshstr(signature_raw)

    # An ssh-rsa signing key may sign with any of the three RSA signature
    # algorithms; every other key type must sign with its own name.
    if sig_type == _SSH_RSA:
        sig_type_ok = inner_sig_type in [
            _SSH_RSA_SHA256,
            _SSH_RSA_SHA512,
            _SSH_RSA,
        ]
    else:
        sig_type_ok = inner_sig_type == sig_type
    if not sig_type_ok:
        raise ValueError("Signature key type does not match")
    signature, sig_rest = _get_sshstr(sig_rest)
    _check_empty(sig_rest)

    return SSHCertificate(
        nonce,
        public_key,
        serial,
        cctype,
        key_id,
        valid_principals,
        valid_after,
        valid_before,
        critical_options,
        extensions,
        sig_type,
        sig_key,
        inner_sig_type,
        signature,
        tbs_cert_body,
        orig_key_type,
        cert_body,
    )
def load_ssh_public_identity(
    data: bytes,
) -> typing.Union[SSHCertificate, SSHPublicKeyTypes]:
    """Parse a one-line OpenSSH public key or certificate.

    Returns an SSHCertificate for certificate lines, otherwise the public
    key. DSA keys and DSA-signed certificates are not accepted here.
    """
    return _load_ssh_public_identity(data, _legacy_dsa_allowed=False)
def _parse_exts_opts(exts_opts: memoryview) -> typing.Dict[bytes, bytes]:
    """Decode a certificate's critical-options or extensions block.

    The block is a sequence of (name, value) SSH strings. Names must be
    unique and sorted. A non-empty value is expected to wrap one inner
    SSH string, which is unwrapped. Values that cannot be unwrapped are
    kept as-is with a deprecation warning.
    """
    parsed: typing.Dict[bytes, bytes] = {}
    previous: typing.Optional[bytes] = None
    remaining = exts_opts
    while remaining:
        raw_name, remaining = _get_sshstr(remaining)
        name = bytes(raw_name)
        if name in parsed:
            raise ValueError("Duplicate name")
        if previous is not None and name < previous:
            raise ValueError("Fields not lexically sorted")

        value, remaining = _get_sshstr(remaining)
        if value:
            try:
                inner, extra = _get_sshstr(value)
            except ValueError:
                warnings.warn(
                    "This certificate has an incorrect encoding for critical "
                    "options or extensions. This will be an exception in "
                    "cryptography 42",
                    utils.DeprecatedIn41,
                    stacklevel=4,
                )
            else:
                if extra:
                    raise ValueError("Unexpected extra data after value")
                value = inner
        parsed[name] = bytes(value)
        previous = name
    return parsed
def load_ssh_public_key(
    data: bytes, backend: typing.Any = None
) -> SSHPublicKeyTypes:
    """Load a public key from a one-line OpenSSH public key or certificate.

    For a certificate line, the certified public key is returned. Legacy
    DSA keys are accepted but raise a deprecation warning. *backend* is
    unused.
    """
    loaded = _load_ssh_public_identity(data, _legacy_dsa_allowed=True)
    public_key: SSHPublicKeyTypes = (
        loaded.public_key() if isinstance(loaded, SSHCertificate) else loaded
    )

    if isinstance(public_key, dsa.DSAPublicKey):
        warnings.warn(
            "SSH DSA keys are deprecated and will be removed in a future "
            "release.",
            utils.DeprecatedIn40,
            stacklevel=2,
        )
    return public_key
def serialize_ssh_public_key(public_key: SSHPublicKeyTypes) -> bytes:
    """Return the one-line OpenSSH encoding of *public_key*.

    The result is the key type, a single space, then the base64 of the SSH
    wire-format key blob. The blob itself begins with the key type. DSA keys
    are still serialized, but a DeprecatedIn40 warning is emitted first.
    """
    if isinstance(public_key, dsa.DSAPublicKey):
        warnings.warn(
            "SSH DSA key support is deprecated and will be "
            "removed in a future release",
            utils.DeprecatedIn40,
            stacklevel=4,
        )
    key_type = _get_ssh_key_type(public_key)
    key_format = _lookup_kformat(key_type)

    wire = _FragList()
    wire.put_sshstr(key_type)
    key_format.encode_public(public_key, wire)

    b64_blob = binascii.b2a_base64(wire.tobytes()).strip()
    return b" ".join([key_type, b64_blob])
# Private key types accepted as the CA (signing) key by
# SSHCertificateBuilder.sign: EC, RSA and Ed25519.
SSHCertPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]


# This is an undocumented limit enforced in the openssh codebase for sshd and
# ssh-keygen, but it is undefined in the ssh certificates spec.
# SSHCertificateBuilder.valid_principals rejects lists longer than this value;
# a list of exactly this length is accepted.
_SSHKEY_CERT_MAX_PRINCIPALS = 256
class SSHCertificateBuilder:
    """Builder for OpenSSH certificates.

    Every setter validates its argument and returns a *new* builder rather
    than modifying the receiver, and :meth:`sign` does not modify the builder
    either. Partially configured builders can therefore be shared and reused.
    Call :meth:`sign` once the public key, type, principals and validity
    window have been set.
    """

    def __init__(
        self,
        _public_key: typing.Optional[SSHCertPublicKeyTypes] = None,
        _serial: typing.Optional[int] = None,
        _type: typing.Optional[SSHCertificateType] = None,
        _key_id: typing.Optional[bytes] = None,
        _valid_principals: typing.Optional[typing.List[bytes]] = None,
        _valid_for_all_principals: bool = False,
        _valid_before: typing.Optional[int] = None,
        _valid_after: typing.Optional[int] = None,
        _critical_options: typing.Optional[
            typing.List[typing.Tuple[bytes, bytes]]
        ] = None,
        _extensions: typing.Optional[
            typing.List[typing.Tuple[bytes, bytes]]
        ] = None,
    ):
        # List arguments default to None instead of ``[]`` so builders never
        # share one mutable default list. They are also copied, so a caller
        # mutating its own list afterwards cannot alter this builder.
        self._public_key = _public_key
        self._serial = _serial
        self._type = _type
        self._key_id = _key_id
        self._valid_principals: typing.List[bytes] = (
            [] if _valid_principals is None else list(_valid_principals)
        )
        self._valid_for_all_principals = _valid_for_all_principals
        self._valid_before = _valid_before
        self._valid_after = _valid_after
        self._critical_options: typing.List[typing.Tuple[bytes, bytes]] = (
            [] if _critical_options is None else list(_critical_options)
        )
        self._extensions: typing.List[typing.Tuple[bytes, bytes]] = (
            [] if _extensions is None else list(_extensions)
        )

    def _replace(self, **changes: typing.Any) -> SSHCertificateBuilder:
        """Return a copy of this builder with the given fields overridden.

        *changes* uses the private ``__init__`` keyword names
        (``_serial``, ``_extensions``, ...).
        """
        fields: typing.Dict[str, typing.Any] = {
            "_public_key": self._public_key,
            "_serial": self._serial,
            "_type": self._type,
            "_key_id": self._key_id,
            "_valid_principals": self._valid_principals,
            "_valid_for_all_principals": self._valid_for_all_principals,
            "_valid_before": self._valid_before,
            "_valid_after": self._valid_after,
            "_critical_options": self._critical_options,
            "_extensions": self._extensions,
        }
        fields.update(changes)
        return SSHCertificateBuilder(**fields)

    def public_key(
        self, public_key: SSHCertPublicKeyTypes
    ) -> SSHCertificateBuilder:
        """Set the subject public key (EC, RSA or Ed25519) to certify.

        Raises TypeError for other key types and ValueError if already set.
        """
        if not isinstance(
            public_key,
            (
                ec.EllipticCurvePublicKey,
                rsa.RSAPublicKey,
                ed25519.Ed25519PublicKey,
            ),
        ):
            raise TypeError("Unsupported key type")
        if self._public_key is not None:
            raise ValueError("public_key already set")

        return self._replace(_public_key=public_key)

    def serial(self, serial: int) -> SSHCertificateBuilder:
        """Set the certificate serial number, an int in ``[0, 2**64)``."""
        if not isinstance(serial, int):
            raise TypeError("serial must be an integer")
        if not 0 <= serial < 2**64:
            raise ValueError("serial must be between 0 and 2**64")
        if self._serial is not None:
            raise ValueError("serial already set")

        return self._replace(_serial=serial)

    def type(self, type: SSHCertificateType) -> SSHCertificateBuilder:
        """Set the certificate type; it must be an SSHCertificateType."""
        if not isinstance(type, SSHCertificateType):
            raise TypeError("type must be an SSHCertificateType")
        if self._type is not None:
            raise ValueError("type already set")

        return self._replace(_type=type)

    def key_id(self, key_id: bytes) -> SSHCertificateBuilder:
        """Set the certificate key identifier (bytes)."""
        if not isinstance(key_id, bytes):
            raise TypeError("key_id must be bytes")
        if self._key_id is not None:
            raise ValueError("key_id already set")

        return self._replace(_key_id=key_id)

    def valid_principals(
        self, valid_principals: typing.List[bytes]
    ) -> SSHCertificateBuilder:
        """Restrict the certificate to a non-empty list of byte principals.

        This setting is mutually exclusive with
        :meth:`valid_for_all_principals`. The list is copied, so later changes
        to the caller's list do not affect the builder.
        """
        if self._valid_for_all_principals:
            raise ValueError(
                "Principals can't be set because the cert is valid "
                "for all principals"
            )
        if (
            not all(isinstance(x, bytes) for x in valid_principals)
            or not valid_principals
        ):
            raise TypeError(
                "principals must be a list of bytes and can't be empty"
            )
        if self._valid_principals:
            raise ValueError("valid_principals already set")

        # A list of exactly _SSHKEY_CERT_MAX_PRINCIPALS entries is accepted;
        # only longer lists are rejected.
        if len(valid_principals) > _SSHKEY_CERT_MAX_PRINCIPALS:
            raise ValueError(
                "Reached or exceeded the maximum number of valid_principals"
            )

        return self._replace(_valid_principals=valid_principals)

    def valid_for_all_principals(self) -> SSHCertificateBuilder:
        """Mark the certificate as valid for any principal.

        This is encoded as an empty principals list. It is mutually exclusive
        with :meth:`valid_principals`.
        """
        if self._valid_principals:
            raise ValueError(
                "valid_principals already set, can't set "
                "valid_for_all_principals"
            )
        if self._valid_for_all_principals:
            raise ValueError("valid_for_all_principals already set")

        return self._replace(_valid_for_all_principals=True)

    def valid_before(
        self, valid_before: typing.Union[int, float]
    ) -> SSHCertificateBuilder:
        """Set the end of the validity window.

        Floats are truncated with ``int()``. The result must lie in
        ``[0, 2**64)``.
        """
        if not isinstance(valid_before, (int, float)):
            raise TypeError("valid_before must be an int or float")
        valid_before = int(valid_before)
        if valid_before < 0 or valid_before >= 2**64:
            raise ValueError("valid_before must [0, 2**64)")
        if self._valid_before is not None:
            raise ValueError("valid_before already set")

        return self._replace(_valid_before=valid_before)

    def valid_after(
        self, valid_after: typing.Union[int, float]
    ) -> SSHCertificateBuilder:
        """Set the start of the validity window.

        Floats are truncated with ``int()``. The result must lie in
        ``[0, 2**64)``.
        """
        if not isinstance(valid_after, (int, float)):
            raise TypeError("valid_after must be an int or float")
        valid_after = int(valid_after)
        if valid_after < 0 or valid_after >= 2**64:
            raise ValueError("valid_after must [0, 2**64)")
        if self._valid_after is not None:
            raise ValueError("valid_after already set")

        return self._replace(_valid_after=valid_after)

    def add_critical_option(
        self, name: bytes, value: bytes
    ) -> SSHCertificateBuilder:
        """Append a critical option; a duplicate name raises ValueError."""
        if not isinstance(name, bytes) or not isinstance(value, bytes):
            raise TypeError("name and value must be bytes")
        # Linear scan per call, so building n options is O(n**2) overall.
        if any(existing == name for existing, _ in self._critical_options):
            raise ValueError("Duplicate critical option name")

        return self._replace(
            _critical_options=self._critical_options + [(name, value)]
        )

    def add_extension(
        self, name: bytes, value: bytes
    ) -> SSHCertificateBuilder:
        """Append an extension; a duplicate name raises ValueError."""
        if not isinstance(name, bytes) or not isinstance(value, bytes):
            raise TypeError("name and value must be bytes")
        # Linear scan per call, so building n extensions is O(n**2) overall.
        if any(existing == name for existing, _ in self._extensions):
            raise ValueError("Duplicate extension name")

        return self._replace(_extensions=self._extensions + [(name, value)])

    @staticmethod
    def _encode_exts_opts(
        exts_opts: typing.List[typing.Tuple[bytes, bytes]]
    ) -> bytes:
        """Serialize (name, value) pairs in their given order.

        Each name is written as an SSH string. A non-empty value is first
        wrapped in an inner SSH string; an empty value is written as an
        empty SSH string.
        """
        fbuf = _FragList()
        for name, value in exts_opts:
            fbuf.put_sshstr(name)
            if len(value) > 0:
                fval = _FragList()
                fval.put_sshstr(value)
                fbuf.put_sshstr(fval.tobytes())
            else:
                fbuf.put_sshstr(value)
        return fbuf.tobytes()

    def sign(self, private_key: SSHCertPrivateKeyTypes) -> SSHCertificate:
        """Marshal the certificate, sign it with *private_key* (the CA key).

        Returns the result as parsed back by ``load_ssh_public_identity``.
        A fresh 32-byte random nonce is used for each call. RSA CA keys sign
        with PKCS#1 v1.5 and SHA-512 (``rsa-sha2-512``).

        Raises TypeError for an unsupported CA key type. Raises ValueError
        when a required field is missing or valid_after > valid_before.
        """
        if not isinstance(
            private_key,
            (
                ec.EllipticCurvePrivateKey,
                rsa.RSAPrivateKey,
                ed25519.Ed25519PrivateKey,
            ),
        ):
            raise TypeError("Unsupported private key type")

        if self._public_key is None:
            raise ValueError("public_key must be set")

        # Not required
        serial = 0 if self._serial is None else self._serial

        if self._type is None:
            raise ValueError("type must be set")

        # Not required
        key_id = b"" if self._key_id is None else self._key_id

        # A zero length list is valid, but means the certificate
        # is valid for any principal of the specified type. We require
        # the user to explicitly set valid_for_all_principals to get
        # that behavior.
        if not self._valid_principals and not self._valid_for_all_principals:
            raise ValueError(
                "valid_principals must be set if valid_for_all_principals "
                "is False"
            )

        if self._valid_before is None:
            raise ValueError("valid_before must be set")

        if self._valid_after is None:
            raise ValueError("valid_after must be set")

        if self._valid_after > self._valid_before:
            raise ValueError("valid_after must be earlier than valid_before")

        # Names must be lexically sorted on the wire. Sort copies rather
        # than the builder's own lists so that signing has no side effects.
        critical_options = sorted(self._critical_options, key=lambda x: x[0])
        extensions = sorted(self._extensions, key=lambda x: x[0])

        key_type = _get_ssh_key_type(self._public_key)
        cert_prefix = key_type + _CERT_SUFFIX

        # Marshal the bytes to be signed
        nonce = os.urandom(32)
        kformat = _lookup_kformat(key_type)
        f = _FragList()
        f.put_sshstr(cert_prefix)
        f.put_sshstr(nonce)
        kformat.encode_public(self._public_key, f)
        f.put_u64(serial)
        f.put_u32(self._type.value)
        f.put_sshstr(key_id)
        fprincipals = _FragList()
        for p in self._valid_principals:
            fprincipals.put_sshstr(p)
        f.put_sshstr(fprincipals.tobytes())
        f.put_u64(self._valid_after)
        f.put_u64(self._valid_before)
        f.put_sshstr(self._encode_exts_opts(critical_options))
        f.put_sshstr(self._encode_exts_opts(extensions))
        f.put_sshstr(b"")  # RESERVED FIELD
        # encode CA public key
        ca_type = _get_ssh_key_type(private_key)
        caformat = _lookup_kformat(ca_type)
        caf = _FragList()
        caf.put_sshstr(ca_type)
        caformat.encode_public(private_key.public_key(), caf)
        f.put_sshstr(caf.tobytes())
        # Sigs according to the rules defined for the CA's public key
        # (RFC4253 section 6.6 for ssh-rsa, RFC5656 for ECDSA,
        # and RFC8032 for Ed25519).
        if isinstance(private_key, ed25519.Ed25519PrivateKey):
            signature = private_key.sign(f.tobytes())
            fsig = _FragList()
            fsig.put_sshstr(ca_type)
            fsig.put_sshstr(signature)
            f.put_sshstr(fsig.tobytes())
        elif isinstance(private_key, ec.EllipticCurvePrivateKey):
            hash_alg = _get_ec_hash_alg(private_key.curve)
            signature = private_key.sign(f.tobytes(), ec.ECDSA(hash_alg))
            # DER-encoded (r, s) is re-encoded as two SSH mpints.
            r, s = asym_utils.decode_dss_signature(signature)
            fsig = _FragList()
            fsig.put_sshstr(ca_type)
            fsigblob = _FragList()
            fsigblob.put_mpint(r)
            fsigblob.put_mpint(s)
            fsig.put_sshstr(fsigblob.tobytes())
            f.put_sshstr(fsig.tobytes())
        else:
            assert isinstance(private_key, rsa.RSAPrivateKey)
            # Just like Golang, we're going to use SHA512 for RSA
            # https://cs.opensource.google/go/x/crypto/+/refs/tags/
            # v0.4.0:ssh/certs.go;l=445
            # RFC 8332 defines SHA256 and 512 as options
            fsig = _FragList()
            fsig.put_sshstr(_SSH_RSA_SHA512)
            signature = private_key.sign(
                f.tobytes(), padding.PKCS1v15(), hashes.SHA512()
            )
            fsig.put_sshstr(signature)
            f.put_sshstr(fsig.tobytes())

        cert_data = binascii.b2a_base64(f.tobytes()).strip()
        # load_ssh_public_identity returns a union, but this is
        # guaranteed to be an SSHCertificate, so we cast to make
        # mypy happy.
        return typing.cast(
            SSHCertificate,
            load_ssh_public_identity(b"".join([cert_prefix, b" ", cert_data])),
        )