1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
4
5from __future__ import annotations
6
7import binascii
8import enum
9import os
10import re
11import typing
12import warnings
13from base64 import encodebytes as _base64_encode
14from dataclasses import dataclass
15
16from cryptography import utils
17from cryptography.exceptions import UnsupportedAlgorithm
18from cryptography.hazmat.primitives import hashes
19from cryptography.hazmat.primitives.asymmetric import (
20 dsa,
21 ec,
22 ed25519,
23 padding,
24 rsa,
25)
26from cryptography.hazmat.primitives.asymmetric import utils as asym_utils
27from cryptography.hazmat.primitives.ciphers import (
28 AEADDecryptionContext,
29 Cipher,
30 algorithms,
31 modes,
32)
33from cryptography.hazmat.primitives.serialization import (
34 Encoding,
35 KeySerializationEncryption,
36 NoEncryption,
37 PrivateFormat,
38 PublicFormat,
39 _KeySerializationEncryption,
40)
41
# bcrypt is optional: its kdf is only needed to derive the cipher key for
# password-protected OpenSSH private keys.
try:
    from bcrypt import kdf as _bcrypt_kdf

    _bcrypt_supported = True
except ImportError:
    _bcrypt_supported = False

    def _bcrypt_kdf(
        password: bytes,
        salt: bytes,
        desired_key_bytes: int,
        rounds: int,
        ignore_few_rounds: bool = False,
    ) -> bytes:
        """Placeholder used when bcrypt is not installed; always raises.

        :raises UnsupportedAlgorithm: unconditionally.
        """
        raise UnsupportedAlgorithm("Need bcrypt module")
57
58
# SSH key-type names as they appear in key blobs and public key lines.
_SSH_ED25519 = b"ssh-ed25519"
_SSH_RSA = b"ssh-rsa"
_SSH_DSA = b"ssh-dss"
_ECDSA_NISTP256 = b"ecdsa-sha2-nistp256"
_ECDSA_NISTP384 = b"ecdsa-sha2-nistp384"
_ECDSA_NISTP521 = b"ecdsa-sha2-nistp521"
# Appended to a key-type name to mark an OpenSSH certificate.
_CERT_SUFFIX = b"-cert-v01@openssh.com"

# Security-key ("sk-") public key types; their blobs end with a U2F
# application string.
_SK_SSH_ED25519 = b"sk-ssh-ed25519@openssh.com"
_SK_SSH_ECDSA_NISTP256 = b"sk-ecdsa-sha2-nistp256@openssh.com"

# These are not key types, only algorithms, so they cannot appear
# as a public key type
_SSH_RSA_SHA256 = b"rsa-sha2-256"
_SSH_RSA_SHA512 = b"rsa-sha2-512"

# Matches the leading "<key-type> <base64-blob>" fields of a public key
# line; anything after them (e.g. a comment) is left unmatched.
_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)")
# openssh-key-v1 private key container: binary magic prefix and the PEM
# armour lines around the base64 payload.
_SK_MAGIC = b"openssh-key-v1\0"
_SK_START = b"-----BEGIN OPENSSH PRIVATE KEY-----"
_SK_END = b"-----END OPENSSH PRIVATE KEY-----"
# KDF / cipher names found in the private key header ("none" is used for
# both fields of an unencrypted key).
_BCRYPT = b"bcrypt"
_NONE = b"none"
# Cipher and bcrypt round count used when serializing with a password.
_DEFAULT_CIPHER = b"aes256-ctr"
_DEFAULT_ROUNDS = 16

# A regex is used because it can search bytes-like input directly; group 1
# is the base64 text between the armour lines.
_PEM_RC = re.compile(_SK_START + b"(.*?)" + _SK_END, re.DOTALL)

# Padding bytes 1, 2, ..., 16: long enough for the largest cipher block
# size. Slices are written as padding and compared against when loading.
_PADDING = memoryview(bytearray(range(1, 1 + 16)))
90
91
@dataclass
class _SSHCipher:
    """Parameters of a symmetric cipher used to wrap OpenSSH private keys."""

    # cipher algorithm class, instantiated with the derived key
    alg: type[algorithms.AES]
    # key size in bytes (first key_len bytes of the KDF output)
    key_len: int
    # mode class, instantiated with the derived IV / nonce
    mode: type[modes.CTR] | type[modes.CBC] | type[modes.GCM]
    # encrypted private data must be a whole number of these bytes
    block_len: int
    # IV / nonce size in bytes (KDF output after the key)
    iv_len: int
    # AEAD tag length in bytes; None for non-AEAD modes
    tag_len: int | None
    # True when a tag trailing the ciphertext must be verified on decrypt
    is_aead: bool
101
102
# Ciphers accepted for private-key encryption, keyed by their SSH name.
# Any of these can be decrypted; serialization always writes
# _DEFAULT_CIPHER.
_SSH_CIPHERS: dict[bytes, _SSHCipher] = {
    b"aes256-ctr": _SSHCipher(
        alg=algorithms.AES,
        key_len=32,
        mode=modes.CTR,
        block_len=16,
        iv_len=16,
        tag_len=None,
        is_aead=False,
    ),
    b"aes256-cbc": _SSHCipher(
        alg=algorithms.AES,
        key_len=32,
        mode=modes.CBC,
        block_len=16,
        iv_len=16,
        tag_len=None,
        is_aead=False,
    ),
    # GCM uses a 12-byte nonce and a 16-byte tag stored after the
    # encrypted blob.
    b"aes256-gcm@openssh.com": _SSHCipher(
        alg=algorithms.AES,
        key_len=32,
        mode=modes.GCM,
        block_len=16,
        iv_len=12,
        tag_len=16,
        is_aead=True,
    ),
}
133
# Map cryptography curve names (``EllipticCurve.name``) to SSH key types.
_ECDSA_KEY_TYPE = {
    "secp256r1": _ECDSA_NISTP256,
    "secp384r1": _ECDSA_NISTP384,
    "secp521r1": _ECDSA_NISTP521,
}
140
141
def _get_ssh_key_type(key: SSHPrivateKeyTypes | SSHPublicKeyTypes) -> bytes:
    """Map a cryptography key object to its SSH key-type name.

    Private and public keys of the same family map to the same name; EC
    keys are resolved through their curve.

    :raises ValueError: if the key class, or an EC key's curve, has no SSH
        key type here.
    """
    if isinstance(key, ec.EllipticCurvePrivateKey):
        return _ecdsa_key_type(key.public_key())
    if isinstance(key, ec.EllipticCurvePublicKey):
        return _ecdsa_key_type(key)
    if isinstance(key, (rsa.RSAPrivateKey, rsa.RSAPublicKey)):
        return _SSH_RSA
    if isinstance(key, (dsa.DSAPrivateKey, dsa.DSAPublicKey)):
        return _SSH_DSA
    if isinstance(key, (ed25519.Ed25519PrivateKey, ed25519.Ed25519PublicKey)):
        return _SSH_ED25519
    raise ValueError("Unsupported key type")
159
160
def _ecdsa_key_type(public_key: ec.EllipticCurvePublicKey) -> bytes:
    """Return the SSH key-type name for an ECDSA public key.

    :param public_key: EC public key whose curve selects the name.
    :returns: The key-type name only, e.g. ``b"ecdsa-sha2-nistp256"`` for a
        ``secp256r1`` key.
    :raises ValueError: if the curve has no SSH key type.
    """
    curve_name = public_key.curve.name
    try:
        return _ECDSA_KEY_TYPE[curve_name]
    except KeyError:
        # Suppress the KeyError context so callers see only the ValueError,
        # exactly as when the check was done up front.
        raise ValueError(
            f"Unsupported curve for ssh private key: {curve_name!r}"
        ) from None
169
170
def _ssh_pem_encode(
    data: bytes,
    prefix: bytes = _SK_START + b"\n",
    suffix: bytes = _SK_END + b"\n",
) -> bytes:
    """Armour ``data`` as PEM text.

    The body is ``base64.encodebytes`` output (line-wrapped and
    newline-terminated), framed by ``prefix`` and ``suffix``.
    """
    body = _base64_encode(data)
    return b"".join((prefix, body, suffix))
177
178
179def _check_block_size(data: bytes, block_len: int) -> None:
180 """Require data to be full blocks"""
181 if not data or len(data) % block_len != 0:
182 raise ValueError("Corrupt data: missing padding")
183
184
185def _check_empty(data: bytes) -> None:
186 """All data should have been parsed."""
187 if data:
188 raise ValueError("Corrupt data: unparsed data")
189
190
def _init_cipher(
    ciphername: bytes,
    password: bytes | None,
    salt: bytes,
    rounds: int,
) -> Cipher[modes.CBC | modes.CTR | modes.GCM]:
    """Derive key and IV via bcrypt's kdf and build the matching cipher.

    The kdf output is split into ``key_len`` key bytes followed by the
    IV / nonce bytes of the cipher named by ``ciphername``.

    :raises ValueError: if ``password`` is ``None`` or empty.
    :raises KeyError: if ``ciphername`` is not in ``_SSH_CIPHERS``.
    """
    if not password:
        raise ValueError("Key is password-protected.")

    spec = _SSH_CIPHERS[ciphername]
    key_material = _bcrypt_kdf(
        password, salt, spec.key_len + spec.iv_len, rounds, True
    )
    key, iv = key_material[: spec.key_len], key_material[spec.key_len :]
    return Cipher(spec.alg(key), spec.mode(iv))
209
210
211def _get_u32(data: memoryview) -> tuple[int, memoryview]:
212 """Uint32"""
213 if len(data) < 4:
214 raise ValueError("Invalid data")
215 return int.from_bytes(data[:4], byteorder="big"), data[4:]
216
217
218def _get_u64(data: memoryview) -> tuple[int, memoryview]:
219 """Uint64"""
220 if len(data) < 8:
221 raise ValueError("Invalid data")
222 return int.from_bytes(data[:8], byteorder="big"), data[8:]
223
224
def _get_sshstr(data: memoryview) -> tuple[memoryview, memoryview]:
    """Read a uint32 length and that many bytes; return ``(value, rest)``.

    :raises ValueError: if the length prefix or the payload is truncated.
    """
    length, rest = _get_u32(data)
    if len(rest) < length:
        raise ValueError("Invalid data")
    return rest[:length], rest[length:]
231
232
def _get_mpint(data: memoryview) -> tuple[int, memoryview]:
    """Read an SSH mpint as a non-negative int; return ``(value, rest)``.

    A first byte with the high bit set (a negative mpint) is rejected with
    ValueError.
    """
    raw, rest = _get_sshstr(data)
    is_negative = len(raw) > 0 and raw[0] & 0x80
    if is_negative:
        raise ValueError("Invalid data")
    return int.from_bytes(raw, "big"), rest
239
240
def _to_mpint(val: int) -> bytes:
    """Encode a non-negative int as an SSH mpint payload (no length prefix).

    Zero encodes as the empty string. Otherwise the value is big-endian and
    sized so a leading zero byte is present whenever the top bit would be
    set.

    :raises ValueError: if ``val`` is negative.
    """
    if val < 0:
        raise ValueError("negative mpint not allowed")
    if val == 0:
        return b""
    # bit_length plus one sign bit, rounded up to whole bytes
    return val.to_bytes((val.bit_length() + 8) // 8, "big")
249
250
251class _FragList:
252 """Build recursive structure without data copy."""
253
254 flist: list[bytes]
255
256 def __init__(self, init: list[bytes] | None = None) -> None:
257 self.flist = []
258 if init:
259 self.flist.extend(init)
260
261 def put_raw(self, val: bytes) -> None:
262 """Add plain bytes"""
263 self.flist.append(val)
264
265 def put_u32(self, val: int) -> None:
266 """Big-endian uint32"""
267 self.flist.append(val.to_bytes(length=4, byteorder="big"))
268
269 def put_u64(self, val: int) -> None:
270 """Big-endian uint64"""
271 self.flist.append(val.to_bytes(length=8, byteorder="big"))
272
273 def put_sshstr(self, val: bytes | _FragList) -> None:
274 """Bytes prefixed with u32 length"""
275 if isinstance(val, (bytes, memoryview, bytearray)):
276 self.put_u32(len(val))
277 self.flist.append(val)
278 else:
279 self.put_u32(val.size())
280 self.flist.extend(val.flist)
281
282 def put_mpint(self, val: int) -> None:
283 """Big-endian bigint prefixed with u32 length"""
284 self.put_sshstr(_to_mpint(val))
285
286 def size(self) -> int:
287 """Current number of bytes"""
288 return sum(map(len, self.flist))
289
290 def render(self, dstbuf: memoryview, pos: int = 0) -> int:
291 """Write into bytearray"""
292 for frag in self.flist:
293 flen = len(frag)
294 start, pos = pos, pos + flen
295 dstbuf[start:pos] = frag
296 return pos
297
298 def tobytes(self) -> bytes:
299 """Return as bytes"""
300 buf = memoryview(bytearray(self.size()))
301 self.render(buf)
302 return buf.tobytes()
303
304
class _SSHFormatRSA:
    """Codec for ssh-rsa key blobs.

    Public blob:
        mpint e, n
    Private blob:
        mpint n, e, d, iqmp, p, q
    """

    def get_public(
        self, data: memoryview
    ) -> tuple[tuple[int, int], memoryview]:
        """Parse the exponent and modulus; return ``((e, n), rest)``."""
        exponent, remainder = _get_mpint(data)
        modulus, remainder = _get_mpint(remainder)
        return (exponent, modulus), remainder

    def load_public(
        self, data: memoryview
    ) -> tuple[rsa.RSAPublicKey, memoryview]:
        """Build an RSA public key from the blob; return it and the rest."""
        fields, remainder = self.get_public(data)
        exponent, modulus = fields
        key = rsa.RSAPublicNumbers(exponent, modulus).public_key()
        return key, remainder

    def load_private(
        self, data: memoryview, pubfields
    ) -> tuple[rsa.RSAPrivateKey, memoryview]:
        """Build an RSA private key, checking it against ``pubfields``.

        ``pubfields`` is the ``(e, n)`` pair that :meth:`get_public`
        returned for the key's public blob; a mismatch raises ValueError.
        """
        parsed = []
        remainder = data
        for _ in range(6):
            value, remainder = _get_mpint(remainder)
            parsed.append(value)
        n, e, d, iqmp, p, q = parsed

        if (e, n) != pubfields:
            raise ValueError("Corrupt data: rsa field mismatch")
        # The blob carries no CRT exponents, so derive them from d, p, q.
        dmp1 = rsa.rsa_crt_dmp1(d, p)
        dmq1 = rsa.rsa_crt_dmq1(d, q)
        private_numbers = rsa.RSAPrivateNumbers(
            p, q, d, dmp1, dmq1, iqmp, rsa.RSAPublicNumbers(e, n)
        )
        return private_numbers.private_key(), remainder

    def encode_public(
        self, public_key: rsa.RSAPublicKey, f_pub: _FragList
    ) -> None:
        """Append the public blob fields (e, n) to ``f_pub``."""
        numbers = public_key.public_numbers()
        for value in (numbers.e, numbers.n):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: rsa.RSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Append the private blob fields (n, e, d, iqmp, p, q) to ``f_priv``."""
        priv = private_key.private_numbers()
        pub = priv.public_numbers
        for value in (pub.n, pub.e, priv.d, priv.iqmp, priv.p, priv.q):
            f_priv.put_mpint(value)
375
376
class _SSHFormatDSA:
    """Codec for ssh-dss key blobs.

    Public blob:
        mpint p, q, g, y
    Private blob:
        mpint p, q, g, y, x

    Only keys whose ``p`` is exactly 1024 bits are accepted.
    """

    def get_public(self, data: memoryview) -> tuple[tuple, memoryview]:
        """Parse p, q, g and y; return ``((p, q, g, y), rest)``."""
        values = []
        remainder = data
        for _ in range(4):
            number, remainder = _get_mpint(remainder)
            values.append(number)
        return tuple(values), remainder

    def load_public(
        self, data: memoryview
    ) -> tuple[dsa.DSAPublicKey, memoryview]:
        """Build a DSA public key from the blob; return it and the rest."""
        fields, remainder = self.get_public(data)
        public_numbers = self._make_public_numbers(*fields)
        return public_numbers.public_key(), remainder

    def load_private(
        self, data: memoryview, pubfields
    ) -> tuple[dsa.DSAPrivateKey, memoryview]:
        """Build a DSA private key, checking it against ``pubfields``.

        ``pubfields`` is the ``(p, q, g, y)`` tuple that :meth:`get_public`
        returned for the key's public blob; a mismatch raises ValueError.
        """
        fields, remainder = self.get_public(data)
        x, remainder = _get_mpint(remainder)

        if fields != pubfields:
            raise ValueError("Corrupt data: dsa field mismatch")
        public_numbers = self._make_public_numbers(*fields)
        private_key = dsa.DSAPrivateNumbers(x, public_numbers).private_key()
        return private_key, remainder

    def encode_public(
        self, public_key: dsa.DSAPublicKey, f_pub: _FragList
    ) -> None:
        """Append p, q, g, y to ``f_pub`` after checking the key size."""
        public_numbers = public_key.public_numbers()
        # Validate before writing so nothing is appended for a bad key.
        self._validate(public_numbers)
        params = public_numbers.parameter_numbers
        for value in (params.p, params.q, params.g, public_numbers.y):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: dsa.DSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Append the public fields followed by x to ``f_priv``."""
        self.encode_public(private_key.public_key(), f_priv)
        f_priv.put_mpint(private_key.private_numbers().x)

    def _make_public_numbers(
        self, p: int, q: int, g: int, y: int
    ) -> dsa.DSAPublicNumbers:
        """Assemble public numbers from raw fields and size-check them."""
        params = dsa.DSAParameterNumbers(p, q, g)
        public_numbers = dsa.DSAPublicNumbers(y, params)
        self._validate(public_numbers)
        return public_numbers

    def _validate(self, public_numbers: dsa.DSAPublicNumbers) -> None:
        """Raise ValueError unless ``p`` is exactly 1024 bits long."""
        if public_numbers.parameter_numbers.p.bit_length() != 1024:
            raise ValueError("SSH supports only 1024 bit DSA keys")
445
446
class _SSHFormatECDSA:
    """Format for ECDSA keys.

    Public:
        str curve
        bytes point
    Private:
        str curve
        bytes point
        mpint secret
    """

    def __init__(self, ssh_curve_name: bytes, curve: ec.EllipticCurve):
        # Curve identifier as written inside SSH blobs (e.g. b"nistp256").
        self.ssh_curve_name = ssh_curve_name
        # Matching cryptography curve instance used to build key objects.
        self.curve = curve

    def get_public(
        self, data: memoryview
    ) -> tuple[tuple[memoryview, memoryview], memoryview]:
        """ECDSA public fields.

        :returns: ``((curve_name, point), rest)``.
        :raises ValueError: if the curve name does not match this format or
            the point is empty.
        :raises NotImplementedError: if the point is not uncompressed
            (first byte 0x04).
        """
        curve, data = _get_sshstr(data)
        point, data = _get_sshstr(data)
        if curve != self.ssh_curve_name:
            raise ValueError("Curve name mismatch")
        # An empty point would otherwise escape as IndexError on point[0].
        if not point:
            raise ValueError("Invalid data")
        if point[0] != 4:
            raise NotImplementedError("Need uncompressed point")
        return (curve, point), data

    def load_public(
        self, data: memoryview
    ) -> tuple[ec.EllipticCurvePublicKey, memoryview]:
        """Make ECDSA public key from data; return it and the rest."""
        (_, point), data = self.get_public(data)
        public_key = ec.EllipticCurvePublicKey.from_encoded_point(
            self.curve, point.tobytes()
        )
        return public_key, data

    def load_private(
        self, data: memoryview, pubfields
    ) -> tuple[ec.EllipticCurvePrivateKey, memoryview]:
        """Make ECDSA private key from data.

        ``pubfields`` is the ``(curve_name, point)`` pair that
        :meth:`get_public` returned for the key's public blob; a mismatch
        raises ValueError.
        """
        (curve_name, point), data = self.get_public(data)
        secret, data = _get_mpint(data)

        if (curve_name, point) != pubfields:
            raise ValueError("Corrupt data: ecdsa field mismatch")
        private_key = ec.derive_private_key(secret, self.curve)
        return private_key, data

    def encode_public(
        self, public_key: ec.EllipticCurvePublicKey, f_pub: _FragList
    ) -> None:
        """Write the curve name and the uncompressed point to ``f_pub``."""
        point = public_key.public_bytes(
            Encoding.X962, PublicFormat.UncompressedPoint
        )
        f_pub.put_sshstr(self.ssh_curve_name)
        f_pub.put_sshstr(point)

    def encode_private(
        self, private_key: ec.EllipticCurvePrivateKey, f_priv: _FragList
    ) -> None:
        """Write the public fields followed by the secret scalar."""
        public_key = private_key.public_key()
        private_numbers = private_key.private_numbers()

        self.encode_public(public_key, f_priv)
        f_priv.put_mpint(private_numbers.private_value)
516
517
class _SSHFormatEd25519:
    """Codec for ssh-ed25519 key blobs.

    Public blob:
        bytes point
    Private blob:
        bytes point
        bytes secret_and_point   (raw 32-byte private key, then the point)
    """

    def get_public(
        self, data: memoryview
    ) -> tuple[tuple[memoryview], memoryview]:
        """Split off the public point; return ``((point,), rest)``."""
        pub_point, remainder = _get_sshstr(data)
        return (pub_point,), remainder

    def load_public(
        self, data: memoryview
    ) -> tuple[ed25519.Ed25519PublicKey, memoryview]:
        """Build an Ed25519 public key from the blob; return it and the rest."""
        fields, remainder = self.get_public(data)
        raw_point = fields[0].tobytes()
        key = ed25519.Ed25519PublicKey.from_public_bytes(raw_point)
        return key, remainder

    def load_private(
        self, data: memoryview, pubfields
    ) -> tuple[ed25519.Ed25519PrivateKey, memoryview]:
        """Build an Ed25519 private key, checking it against ``pubfields``.

        Both the point repeated after the private bytes and ``pubfields``
        (from :meth:`get_public` on the public blob) must equal the point
        of this blob, otherwise ValueError is raised.
        """
        fields, remainder = self.get_public(data)
        (pub_point,) = fields
        keypair, remainder = _get_sshstr(remainder)

        raw_secret, trailing_point = keypair[:32], keypair[32:]
        consistent = pub_point == trailing_point and (pub_point,) == pubfields
        if not consistent:
            raise ValueError("Corrupt data: ed25519 field mismatch")
        key = ed25519.Ed25519PrivateKey.from_private_bytes(raw_secret)
        return key, remainder

    def encode_public(
        self, public_key: ed25519.Ed25519PublicKey, f_pub: _FragList
    ) -> None:
        """Append the raw public point to ``f_pub``."""
        f_pub.put_sshstr(
            public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)
        )

    def encode_private(
        self, private_key: ed25519.Ed25519PrivateKey, f_priv: _FragList
    ) -> None:
        """Append the point, then the private bytes concatenated with it."""
        pub = private_key.public_key()
        secret_bytes = private_key.private_bytes(
            Encoding.Raw, PrivateFormat.Raw, NoEncryption()
        )
        point_bytes = pub.public_bytes(Encoding.Raw, PublicFormat.Raw)

        self.encode_public(pub, f_priv)
        f_priv.put_sshstr(_FragList([secret_bytes, point_bytes]))
583
584
def load_application(data: memoryview) -> tuple[memoryview, memoryview]:
    """Read and validate the U2F application string of an ``sk-*`` key.

    :param data: Remaining key blob, positioned at the application string.
    :returns: ``(application, rest)``.
    :raises ValueError: if the application does not start with ``b"ssh:"``
        or the data is truncated.
    """
    application, data = _get_sshstr(data)
    app_bytes = application.tobytes()
    if not app_bytes.startswith(b"ssh:"):
        # Show the actual string: formatting the memoryview itself only
        # yields "<memory at 0x...>".
        raise ValueError(
            "U2F application string does not start with b'ssh:' "
            f"({app_bytes!r})"
        )
    return application, data
596
597
class _SSHFormatSKEd25519:
    """Codec for sk-ssh-ed25519@openssh.com public key blobs.

    After the key-type string the blob holds:

        string public key
        string application (user-specified, but typically "ssh:")

    Only public keys are handled.
    """

    def load_public(
        self, data: memoryview
    ) -> tuple[ed25519.Ed25519PublicKey, memoryview]:
        """Parse the Ed25519 point, then validate and skip the application."""
        ed25519_codec = _lookup_kformat(_SSH_ED25519)
        public_key, remainder = ed25519_codec.load_public(data)
        _application, remainder = load_application(remainder)
        return public_key, remainder
614
615
class _SSHFormatSKECDSA:
    """Codec for sk-ecdsa-sha2-nistp256@openssh.com public key blobs.

    After the key-type string the blob holds:

        string curve name
        ec_point Q
        string application (user-specified, but typically "ssh:")

    Only public keys are handled.
    """

    def load_public(
        self, data: memoryview
    ) -> tuple[ec.EllipticCurvePublicKey, memoryview]:
        """Parse the nistp256 key, then validate and skip the application."""
        nistp256_codec = _lookup_kformat(_ECDSA_NISTP256)
        public_key, remainder = nistp256_codec.load_public(data)
        _application, remainder = load_application(remainder)
        return public_key, remainder
633
634
# SSH key-type name -> codec instance. The sk-* codecs only implement
# load_public.
_KEY_FORMATS = {
    _SSH_RSA: _SSHFormatRSA(),
    _SSH_DSA: _SSHFormatDSA(),
    _SSH_ED25519: _SSHFormatEd25519(),
    _ECDSA_NISTP256: _SSHFormatECDSA(b"nistp256", ec.SECP256R1()),
    _ECDSA_NISTP384: _SSHFormatECDSA(b"nistp384", ec.SECP384R1()),
    _ECDSA_NISTP521: _SSHFormatECDSA(b"nistp521", ec.SECP521R1()),
    _SK_SSH_ED25519: _SSHFormatSKEd25519(),
    _SK_SSH_ECDSA_NISTP256: _SSHFormatSKECDSA(),
}
645
646
def _lookup_kformat(key_type: bytes):
    """Return the codec registered for ``key_type``.

    ``key_type`` may be any bytes-like object; non-``bytes`` values are
    copied to ``bytes`` before the lookup.

    :raises UnsupportedAlgorithm: if no codec is registered for the name.
    """
    if isinstance(key_type, bytes):
        name = key_type
    else:
        name = memoryview(key_type).tobytes()
    codec = _KEY_FORMATS.get(name)
    if codec is None:
        raise UnsupportedAlgorithm(f"Unsupported key type: {name!r}")
    return codec
654
655
# Private key classes that can be loaded from, or serialized to, the
# OpenSSH private key format.
SSHPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    dsa.DSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]
662
663
def load_ssh_private_key(
    data: bytes,
    password: bytes | None,
    backend: typing.Any = None,
) -> SSHPrivateKeyTypes:
    """Load private key from OpenSSH custom encoding.

    :param data: PEM-armoured ``openssh-key-v1`` private key.
    :param password: Passphrase for an encrypted key, or ``None``.
    :param backend: Unused; accepted for backward compatibility.
    :returns: The decoded private key.
    :raises ValueError: if the data is not OpenSSH format, is truncated or
        corrupt, does not contain exactly one key, or is encrypted and no
        password was given.
    :raises UnsupportedAlgorithm: if the cipher, KDF or key type is not
        supported, or bcrypt is unavailable for an encrypted key.
    """
    utils._check_byteslike("data", data)
    if password is not None:
        utils._check_bytes("password", password)

    m = _PEM_RC.search(data)
    if not m:
        raise ValueError("Not OpenSSH private key format")
    p1 = m.start(1)
    p2 = m.end(1)
    data = binascii.a2b_base64(memoryview(data)[p1:p2])
    if not data.startswith(_SK_MAGIC):
        raise ValueError("Not OpenSSH private key format")
    data = memoryview(data)[len(_SK_MAGIC) :]

    # parse header
    ciphername, data = _get_sshstr(data)
    kdfname, data = _get_sshstr(data)
    kdfoptions, data = _get_sshstr(data)
    nkeys, data = _get_u32(data)
    if nkeys != 1:
        raise ValueError("Only one key supported")

    # load public key data
    pubdata, data = _get_sshstr(data)
    pub_key_type, pubdata = _get_sshstr(pubdata)
    kformat = _lookup_kformat(pub_key_type)
    pubfields, pubdata = kformat.get_public(pubdata)
    _check_empty(pubdata)

    if (ciphername, kdfname) != (_NONE, _NONE):
        ciphername_bytes = ciphername.tobytes()
        if ciphername_bytes not in _SSH_CIPHERS:
            raise UnsupportedAlgorithm(
                f"Unsupported cipher: {ciphername_bytes!r}"
            )
        if kdfname != _BCRYPT:
            # Report the name itself; formatting the memoryview directly
            # only yields "<memory at 0x...>".
            raise UnsupportedAlgorithm(
                f"Unsupported KDF: {kdfname.tobytes()!r}"
            )
        cipher_spec = _SSH_CIPHERS[ciphername_bytes]
        blklen = cipher_spec.block_len
        # load secret data
        edata, data = _get_sshstr(data)
        # see https://bugzilla.mindrot.org/show_bug.cgi?id=3553 for
        # information about how OpenSSH handles AEAD tags
        if cipher_spec.is_aead:
            # The tag is whatever follows the length-prefixed ciphertext.
            tag = bytes(data)
            if len(tag) != cipher_spec.tag_len:
                raise ValueError("Corrupt data: invalid tag length for cipher")
        else:
            _check_empty(data)
        _check_block_size(edata, blklen)
        # kdfoptions: string salt, uint32 rounds
        salt, kbuf = _get_sshstr(kdfoptions)
        rounds, kbuf = _get_u32(kbuf)
        _check_empty(kbuf)
        ciph = _init_cipher(ciphername_bytes, password, salt.tobytes(), rounds)
        dec = ciph.decryptor()
        edata = memoryview(dec.update(edata))
        if cipher_spec.is_aead:
            assert isinstance(dec, AEADDecryptionContext)
            _check_empty(dec.finalize_with_tag(tag))
        else:
            # _check_block_size requires data to be a full block so there
            # should be no output from finalize
            _check_empty(dec.finalize())
    else:
        # load secret data
        edata, data = _get_sshstr(data)
        _check_empty(data)
        blklen = 8
        _check_block_size(edata, blklen)
    # Two copies of the same check word; a mismatch usually means a wrong
    # password or corrupt data.
    ck1, edata = _get_u32(edata)
    ck2, edata = _get_u32(edata)
    if ck1 != ck2:
        raise ValueError("Corrupt data: broken checksum")

    # load per-key struct
    key_type, edata = _get_sshstr(edata)
    if key_type != pub_key_type:
        raise ValueError("Corrupt data: key type mismatch")
    private_key, edata = kformat.load_private(edata, pubfields)
    # We don't use the comment
    _, edata = _get_sshstr(edata)

    # yes, SSH does padding check *after* all other parsing is done.
    # need to follow as it writes zero-byte padding too.
    if edata != _PADDING[: len(edata)]:
        raise ValueError("Corrupt data: invalid padding")

    if isinstance(private_key, dsa.DSAPrivateKey):
        warnings.warn(
            "SSH DSA keys are deprecated and will be removed in a future "
            "release.",
            utils.DeprecatedIn40,
            stacklevel=2,
        )

    return private_key
766
767
def _serialize_ssh_private_key(
    private_key: SSHPrivateKeyTypes,
    password: bytes,
    encryption_algorithm: KeySerializationEncryption,
) -> bytes:
    """Serialize private key with OpenSSH custom encoding.

    :param private_key: Key to serialize.
    :param password: Passphrase. An empty value writes the key unencrypted
        (cipher and KDF ``none``). Otherwise the secrets are encrypted with
        ``_DEFAULT_CIPHER`` using a bcrypt-derived key and a random salt.
    :param encryption_algorithm: Only consulted for a custom KDF round count
        (``_kdf_rounds``) when it is a ``_KeySerializationEncryption``.
    :returns: PEM-armoured ``openssh-key-v1`` data.
    """
    utils._check_bytes("password", password)
    if isinstance(private_key, dsa.DSAPrivateKey):
        # NOTE(review): stacklevel=4 presumably skips the public
        # serialization wrappers that call this helper; confirm against
        # the callers.
        warnings.warn(
            "SSH DSA key support is deprecated and will be "
            "removed in a future release",
            utils.DeprecatedIn40,
            stacklevel=4,
        )

    key_type = _get_ssh_key_type(private_key)
    kformat = _lookup_kformat(key_type)

    # setup parameters
    f_kdfoptions = _FragList()
    if password:
        ciphername = _DEFAULT_CIPHER
        blklen = _SSH_CIPHERS[ciphername].block_len
        kdfname = _BCRYPT
        rounds = _DEFAULT_ROUNDS
        if (
            isinstance(encryption_algorithm, _KeySerializationEncryption)
            and encryption_algorithm._kdf_rounds is not None
        ):
            rounds = encryption_algorithm._kdf_rounds
        # kdfoptions: string salt, uint32 rounds
        salt = os.urandom(16)
        f_kdfoptions.put_sshstr(salt)
        f_kdfoptions.put_u32(rounds)
        ciph = _init_cipher(ciphername, password, salt, rounds)
    else:
        ciphername = kdfname = _NONE
        blklen = 8
        ciph = None
    nkeys = 1
    # Written twice; the loader rejects the key if the two copies differ.
    checkval = os.urandom(4)
    comment = b""

    # encode public and private parts together
    f_public_key = _FragList()
    f_public_key.put_sshstr(key_type)
    kformat.encode_public(private_key.public_key(), f_public_key)

    f_secrets = _FragList([checkval, checkval])
    f_secrets.put_sshstr(key_type)
    kformat.encode_private(private_key, f_secrets)
    f_secrets.put_sshstr(comment)
    # Pad with bytes 1, 2, 3, ... up to a whole number of blocks; when the
    # data is already aligned this adds a full block of padding.
    f_secrets.put_raw(_PADDING[: blklen - (f_secrets.size() % blklen)])

    # top-level structure
    f_main = _FragList()
    f_main.put_raw(_SK_MAGIC)
    f_main.put_sshstr(ciphername)
    f_main.put_sshstr(kdfname)
    f_main.put_sshstr(f_kdfoptions)
    f_main.put_u32(nkeys)
    f_main.put_sshstr(f_public_key)
    f_main.put_sshstr(f_secrets)

    # copy result into bytearray
    slen = f_secrets.size()
    mlen = f_main.size()
    # blklen bytes of slack after the data, presumably to satisfy
    # update_into's output-buffer size requirement.
    buf = memoryview(bytearray(mlen + blklen))
    f_main.render(buf)
    # The secrets blob is written last, so it occupies buf[ofs:mlen].
    ofs = mlen - slen

    # encrypt in-place
    if ciph is not None:
        ciph.encryptor().update_into(buf[ofs:mlen], buf[ofs:])

    return _ssh_pem_encode(buf[:mlen])
843
844
# Public key classes handled by the OpenSSH public key functions.
SSHPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    dsa.DSAPublicKey,
    ed25519.Ed25519PublicKey,
]

# Key types exposed by SSHCertificate: SSHPublicKeyTypes without DSA,
# which is not supported for certificates.
SSHCertPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    ed25519.Ed25519PublicKey,
]
857
858
class SSHCertificateType(enum.Enum):
    """Value of an OpenSSH certificate's ``type`` field."""

    USER = 1
    HOST = 2
862
863
class SSHCertificate:
    """An OpenSSH certificate, as built by the public-identity parser.

    Raw fields are stored as parsed (memoryviews where the parser produced
    them), and the accessors convert them to ``bytes``.
    """

    def __init__(
        self,
        _nonce: memoryview,
        _public_key: SSHPublicKeyTypes,
        _serial: int,
        _cctype: int,
        _key_id: memoryview,
        _valid_principals: list[bytes],
        _valid_after: int,
        _valid_before: int,
        _critical_options: dict[bytes, bytes],
        _extensions: dict[bytes, bytes],
        _sig_type: memoryview,
        _sig_key: memoryview,
        _inner_sig_type: memoryview,
        _signature: memoryview,
        _tbs_cert_body: memoryview,
        _cert_key_type: bytes,
        _cert_body: memoryview,
    ):
        try:
            cert_type = SSHCertificateType(_cctype)
        except ValueError:
            raise ValueError("Invalid certificate type")

        self._nonce = _nonce
        self._public_key = _public_key
        self._serial = _serial
        self._type = cert_type
        self._key_id = _key_id
        self._valid_principals = _valid_principals
        self._valid_after = _valid_after
        self._valid_before = _valid_before
        self._critical_options = _critical_options
        self._extensions = _extensions
        # signature-key type name and the key data that follows it
        self._sig_type = _sig_type
        self._sig_key = _sig_key
        # algorithm named inside the signature blob, and the raw signature
        self._inner_sig_type = _inner_sig_type
        self._signature = _signature
        # certificate bytes covered by the signature
        self._tbs_cert_body = _tbs_cert_body
        # full (suffixed) type name and decoded body, for public_bytes()
        self._cert_key_type = _cert_key_type
        self._cert_body = _cert_body

    @property
    def nonce(self) -> bytes:
        """Certificate nonce."""
        return bytes(self._nonce)

    def public_key(self) -> SSHCertPublicKeyTypes:
        """The public key being certified."""
        # The stored type still admits DSA; narrow it for type checkers
        # until DSA support is removed entirely.
        return typing.cast(SSHCertPublicKeyTypes, self._public_key)

    @property
    def serial(self) -> int:
        """Certificate serial number."""
        return self._serial

    @property
    def type(self) -> SSHCertificateType:
        """Whether this is a user or a host certificate."""
        return self._type

    @property
    def key_id(self) -> bytes:
        """Key identifier string."""
        return bytes(self._key_id)

    @property
    def valid_principals(self) -> list[bytes]:
        """Principals the certificate is valid for."""
        return self._valid_principals

    @property
    def valid_before(self) -> int:
        """Upper bound of the validity period, as stored."""
        return self._valid_before

    @property
    def valid_after(self) -> int:
        """Lower bound of the validity period, as stored."""
        return self._valid_after

    @property
    def critical_options(self) -> dict[bytes, bytes]:
        """Critical options, name to value."""
        return self._critical_options

    @property
    def extensions(self) -> dict[bytes, bytes]:
        """Extensions, name to value."""
        return self._extensions

    def signature_key(self) -> SSHCertPublicKeyTypes:
        """Decode the key that signed this certificate."""
        sig_format = _lookup_kformat(self._sig_type)
        key, leftover = sig_format.load_public(self._sig_key)
        _check_empty(leftover)
        return key

    def public_bytes(self) -> bytes:
        """Re-encode as a one-line ``<type> <base64>`` OpenSSH string."""
        encoded_body = binascii.b2a_base64(
            bytes(self._cert_body), newline=False
        )
        return b" ".join((bytes(self._cert_key_type), encoded_body))

    def verify_cert_signature(self) -> None:
        """Check the signature over the certificate body.

        Only the signature is checked; validity period and principals are
        not. The signing key's ``verify`` raises on a bad signature.
        """
        key = self.signature_key()
        tbs = bytes(self._tbs_cert_body)

        if isinstance(key, ed25519.Ed25519PublicKey):
            key.verify(bytes(self._signature), tbs)
            return

        if isinstance(key, ec.EllipticCurvePublicKey):
            # SSH stores ECDSA signatures as two mpints, r then s.
            r, rest = _get_mpint(self._signature)
            s, rest = _get_mpint(rest)
            _check_empty(rest)
            der_sig = asym_utils.encode_dss_signature(r, s)
            key.verify(der_sig, tbs, ec.ECDSA(_get_ec_hash_alg(key.curve)))
            return

        assert isinstance(key, rsa.RSAPublicKey)
        # RSA certificates may be signed with SHA-1, SHA-256 or SHA-512.
        if self._inner_sig_type == _SSH_RSA:
            digest: hashes.HashAlgorithm = hashes.SHA1()
        elif self._inner_sig_type == _SSH_RSA_SHA256:
            digest = hashes.SHA256()
        else:
            assert self._inner_sig_type == _SSH_RSA_SHA512
            digest = hashes.SHA512()
        key.verify(bytes(self._signature), tbs, padding.PKCS1v15(), digest)
991
992
def _get_ec_hash_alg(curve: ec.EllipticCurve) -> hashes.HashAlgorithm:
    """Return the digest used for ECDSA signatures on ``curve``.

    P-256 uses SHA-256, P-384 uses SHA-384, and P-521 uses SHA-512. Any
    other curve trips an assertion.
    """
    if isinstance(curve, ec.SECP256R1):
        return hashes.SHA256()
    if isinstance(curve, ec.SECP384R1):
        return hashes.SHA384()
    assert isinstance(curve, ec.SECP521R1)
    return hashes.SHA512()
1001
1002
def _load_ssh_public_identity(
    data: bytes,
    _legacy_dsa_allowed=False,
) -> SSHCertificate | SSHPublicKeyTypes:
    """Parse an OpenSSH public key line into a key or a certificate.

    :param data: Bytes-like line starting ``<key-type> <base64-blob>``;
        anything after those two fields is ignored.
    :param _legacy_dsa_allowed: When false, DSA key types and DSA-signed
        certificates raise :class:`UnsupportedAlgorithm`.
    :returns: An :class:`SSHCertificate` when the key type ends with
        ``-cert-v01@openssh.com``, otherwise the bare public key.
    :raises ValueError: on a malformed line, bad base64, or a truncated,
        inconsistent or over-long blob.
    :raises UnsupportedAlgorithm: for unknown key types.
    """
    utils._check_byteslike("data", data)

    m = _SSH_PUBKEY_RC.match(data)
    if not m:
        raise ValueError("Invalid line format")
    key_type = orig_key_type = m.group(1)
    key_body = m.group(2)
    with_cert = False
    if key_type.endswith(_CERT_SUFFIX):
        with_cert = True
        # Strip the suffix to find the codec for the embedded key.
        key_type = key_type[: -len(_CERT_SUFFIX)]
    if key_type == _SSH_DSA and not _legacy_dsa_allowed:
        raise UnsupportedAlgorithm(
            "DSA keys aren't supported in SSH certificates"
        )
    kformat = _lookup_kformat(key_type)

    try:
        rest = memoryview(binascii.a2b_base64(key_body))
    except (TypeError, binascii.Error):
        raise ValueError("Invalid format")

    if with_cert:
        # Keep the whole decoded blob; the signed part is sliced from it.
        cert_body = rest
        # The blob repeats the full (suffixed) type name from the line.
        inner_key_type, rest = _get_sshstr(rest)
        if inner_key_type != orig_key_type:
            raise ValueError("Invalid key format")
    if with_cert:
        nonce, rest = _get_sshstr(rest)
    public_key, rest = kformat.load_public(rest)
    if with_cert:
        serial, rest = _get_u64(rest)
        cctype, rest = _get_u32(rest)
        key_id, rest = _get_sshstr(rest)
        # Principals: one string that itself contains a list of strings.
        principals, rest = _get_sshstr(rest)
        valid_principals = []
        while principals:
            principal, principals = _get_sshstr(principals)
            valid_principals.append(bytes(principal))
        valid_after, rest = _get_u64(rest)
        valid_before, rest = _get_u64(rest)
        crit_options, rest = _get_sshstr(rest)
        critical_options = _parse_exts_opts(crit_options)
        exts, rest = _get_sshstr(rest)
        extensions = _parse_exts_opts(exts)
        # Get the reserved field, which is unused.
        _, rest = _get_sshstr(rest)
        # Signature key blob: its own type string followed by key data.
        sig_key_raw, rest = _get_sshstr(rest)
        sig_type, sig_key = _get_sshstr(sig_key_raw)
        if sig_type == _SSH_DSA and not _legacy_dsa_allowed:
            raise UnsupportedAlgorithm(
                "DSA signatures aren't supported in SSH certificates"
            )
        # Get the entire cert body and subtract the signature.
        # (If rest is empty this slice is empty too, but the _get_sshstr
        # call below then raises, so that case never reaches the result.)
        tbs_cert_body = cert_body[: -len(rest)]
        signature_raw, rest = _get_sshstr(rest)
        _check_empty(rest)
        inner_sig_type, sig_rest = _get_sshstr(signature_raw)
        # RSA certs can have multiple algorithm types
        if (
            sig_type == _SSH_RSA
            and inner_sig_type
            not in [_SSH_RSA_SHA256, _SSH_RSA_SHA512, _SSH_RSA]
        ) or (sig_type != _SSH_RSA and inner_sig_type != sig_type):
            raise ValueError("Signature key type does not match")
        signature, sig_rest = _get_sshstr(sig_rest)
        _check_empty(sig_rest)
        return SSHCertificate(
            nonce,
            public_key,
            serial,
            cctype,
            key_id,
            valid_principals,
            valid_after,
            valid_before,
            critical_options,
            extensions,
            sig_type,
            sig_key,
            inner_sig_type,
            signature,
            tbs_cert_body,
            orig_key_type,
            cert_body,
        )
    else:
        _check_empty(rest)
        return public_key
1096
1097
def load_ssh_public_identity(
    data: bytes,
) -> SSHCertificate | SSHPublicKeyTypes:
    """Parse one line of OpenSSH public-identity text.

    Returns an :class:`SSHCertificate` when ``data`` holds a certificate,
    otherwise the bare public key object. All parsing and validation are
    delegated to ``_load_ssh_public_identity`` with its default options.
    """
    identity = _load_ssh_public_identity(data)
    return identity
1102
1103
1104def _parse_exts_opts(exts_opts: memoryview) -> dict[bytes, bytes]:
1105 result: dict[bytes, bytes] = {}
1106 last_name = None
1107 while exts_opts:
1108 name, exts_opts = _get_sshstr(exts_opts)
1109 bname: bytes = bytes(name)
1110 if bname in result:
1111 raise ValueError("Duplicate name")
1112 if last_name is not None and bname < last_name:
1113 raise ValueError("Fields not lexically sorted")
1114 value, exts_opts = _get_sshstr(exts_opts)
1115 if len(value) > 0:
1116 value, extra = _get_sshstr(value)
1117 if len(extra) > 0:
1118 raise ValueError("Unexpected extra data after value")
1119 result[bname] = bytes(value)
1120 last_name = bname
1121 return result
1122
1123
def load_ssh_public_key(
    data: bytes, backend: typing.Any = None
) -> SSHPublicKeyTypes:
    """Load a public key from one line of OpenSSH text.

    If ``data`` is a certificate, the certified (subject) public key is
    returned rather than the certificate itself. Legacy DSA input is
    accepted, but a ``DeprecatedIn40`` warning is emitted for DSA keys.
    ``backend`` is accepted for API compatibility and is not used.
    """
    loaded = _load_ssh_public_identity(data, _legacy_dsa_allowed=True)
    public_key: SSHPublicKeyTypes = (
        loaded.public_key() if isinstance(loaded, SSHCertificate) else loaded
    )

    if isinstance(public_key, dsa.DSAPublicKey):
        warnings.warn(
            "SSH DSA keys are deprecated and will be removed in a future "
            "release.",
            utils.DeprecatedIn40,
            stacklevel=2,
        )
    return public_key
1142
1143
def serialize_ssh_public_key(public_key: SSHPublicKeyTypes) -> bytes:
    """One-line public key format for OpenSSH.

    Produces ``b"<key-type> <base64 blob>"``. The blob is the sshstr-encoded
    key type followed by the type-specific public key encoding. DSA keys
    still serialize but trigger a ``DeprecatedIn40`` warning.
    """
    if isinstance(public_key, dsa.DSAPublicKey):
        # NOTE(review): stacklevel=4 presumably points past internal
        # wrapper frames to the user's call site -- confirm against callers.
        warnings.warn(
            "SSH DSA key support is deprecated and will be "
            "removed in a future release",
            utils.DeprecatedIn40,
            stacklevel=4,
        )

    key_type = _get_ssh_key_type(public_key)
    key_format = _lookup_kformat(key_type)

    blob = _FragList()
    blob.put_sshstr(key_type)
    key_format.encode_public(public_key, blob)

    encoded_blob = binascii.b2a_base64(blob.tobytes()).strip()
    return b" ".join((key_type, encoded_blob))
1162
1163
# Private-key types accepted as the CA (signing) key by
# SSHCertificateBuilder.sign(); any other key type is rejected there with
# TypeError. DSA is not included.
SSHCertPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]


# This is an undocumented limit enforced in the openssh codebase for sshd and
# ssh-keygen, but it is undefined in the ssh certificates spec.
# SSHCertificateBuilder.valid_principals() rejects lists longer than this.
_SSHKEY_CERT_MAX_PRINCIPALS = 256
1174
1175
class SSHCertificateBuilder:
    """Fluent, immutable builder for OpenSSH certificates.

    Each setter validates its argument, refuses to overwrite a field that is
    already set, and returns a *new* builder; the receiver is never modified.
    Call :meth:`sign` with the CA private key to produce an
    :class:`SSHCertificate`.
    """

    def __init__(
        self,
        _public_key: SSHCertPublicKeyTypes | None = None,
        _serial: int | None = None,
        _type: SSHCertificateType | None = None,
        _key_id: bytes | None = None,
        _valid_principals: list[bytes] | None = None,
        _valid_for_all_principals: bool = False,
        _valid_before: int | None = None,
        _valid_after: int | None = None,
        _critical_options: list[tuple[bytes, bytes]] | None = None,
        _extensions: list[tuple[bytes, bytes]] | None = None,
    ):
        self._public_key = _public_key
        self._serial = _serial
        self._type = _type
        self._key_id = _key_id
        # List-valued fields default to None rather than ``[]``. A ``[]``
        # default is a single list object shared by every builder. Lists
        # passed in are always copied, so no two builders (and no caller)
        # alias the same mutable list.
        self._valid_principals: list[bytes] = (
            [] if _valid_principals is None else list(_valid_principals)
        )
        self._valid_for_all_principals = _valid_for_all_principals
        self._valid_before = _valid_before
        self._valid_after = _valid_after
        self._critical_options: list[tuple[bytes, bytes]] = (
            [] if _critical_options is None else list(_critical_options)
        )
        self._extensions: list[tuple[bytes, bytes]] = (
            [] if _extensions is None else list(_extensions)
        )

    def _replace(self, **changes: typing.Any) -> SSHCertificateBuilder:
        """Return a new builder equal to this one except for ``changes``.

        Keys of ``changes`` are ``__init__`` parameter names (``_serial``,
        ``_extensions``, ...).
        """
        fields: dict[str, typing.Any] = {
            "_public_key": self._public_key,
            "_serial": self._serial,
            "_type": self._type,
            "_key_id": self._key_id,
            "_valid_principals": self._valid_principals,
            "_valid_for_all_principals": self._valid_for_all_principals,
            "_valid_before": self._valid_before,
            "_valid_after": self._valid_after,
            "_critical_options": self._critical_options,
            "_extensions": self._extensions,
        }
        fields.update(changes)
        return SSHCertificateBuilder(**fields)

    @staticmethod
    def _check_timestamp(name: str, value: int | float) -> int:
        """Validate a validity timestamp and truncate it to an int.

        Raises TypeError for non-numeric input. Raises ValueError when the
        value is NaN, infinite, or outside [0, 2**64).
        """
        if not isinstance(value, (int, float)):
            raise TypeError(f"{name} must be an int or float")
        try:
            as_int = int(value)
        except (OverflowError, ValueError):
            # int() raises OverflowError for +/-inf and ValueError for NaN.
            # Report both as the same out-of-range ValueError.
            raise ValueError(f"{name} must [0, 2**64)") from None
        if as_int < 0 or as_int >= 2**64:
            raise ValueError(f"{name} must [0, 2**64)")
        return as_int

    @staticmethod
    def _encode_exts_opts(items: list[tuple[bytes, bytes]]) -> bytes:
        """Encode (name, value) pairs as a certificate options field.

        Each name is written as an sshstr. A non-empty value is wrapped in
        an sshstr nested inside another sshstr. An empty value is written
        as an empty sshstr.
        """
        encoded = _FragList()
        for name, value in items:
            encoded.put_sshstr(name)
            if len(value) > 0:
                inner = _FragList()
                inner.put_sshstr(value)
                encoded.put_sshstr(inner.tobytes())
            else:
                encoded.put_sshstr(value)
        return encoded.tobytes()

    def public_key(
        self, public_key: SSHCertPublicKeyTypes
    ) -> SSHCertificateBuilder:
        """Set the subject public key to be certified (EC, RSA or Ed25519)."""
        if not isinstance(
            public_key,
            (
                ec.EllipticCurvePublicKey,
                rsa.RSAPublicKey,
                ed25519.Ed25519PublicKey,
            ),
        ):
            raise TypeError("Unsupported key type")
        if self._public_key is not None:
            raise ValueError("public_key already set")

        return self._replace(_public_key=public_key)

    def serial(self, serial: int) -> SSHCertificateBuilder:
        """Set the serial number, which must fit in an unsigned 64-bit int."""
        if not isinstance(serial, int):
            raise TypeError("serial must be an integer")
        if not 0 <= serial < 2**64:
            raise ValueError("serial must be between 0 and 2**64")
        if self._serial is not None:
            raise ValueError("serial already set")

        return self._replace(_serial=serial)

    def type(self, type: SSHCertificateType) -> SSHCertificateBuilder:
        """Set the certificate type."""
        if not isinstance(type, SSHCertificateType):
            raise TypeError("type must be an SSHCertificateType")
        if self._type is not None:
            raise ValueError("type already set")

        return self._replace(_type=type)

    def key_id(self, key_id: bytes) -> SSHCertificateBuilder:
        """Set the key identifier string."""
        if not isinstance(key_id, bytes):
            raise TypeError("key_id must be bytes")
        if self._key_id is not None:
            raise ValueError("key_id already set")

        return self._replace(_key_id=key_id)

    def valid_principals(
        self, valid_principals: list[bytes]
    ) -> SSHCertificateBuilder:
        """Restrict the certificate to a non-empty list of principals.

        Mutually exclusive with :meth:`valid_for_all_principals`. At most
        ``_SSHKEY_CERT_MAX_PRINCIPALS`` entries are accepted.
        """
        if self._valid_for_all_principals:
            raise ValueError(
                "Principals can't be set because the cert is valid "
                "for all principals"
            )
        # Snapshot the argument before validating it. If we kept a reference
        # to the caller's list, later mutations would bypass the checks below.
        # A one-shot iterator would also be exhausted by the type check.
        principals = list(valid_principals)
        if not all(isinstance(x, bytes) for x in principals) or not principals:
            raise TypeError(
                "principals must be a list of bytes and can't be empty"
            )
        if self._valid_principals:
            raise ValueError("valid_principals already set")

        if len(principals) > _SSHKEY_CERT_MAX_PRINCIPALS:
            raise ValueError(
                "Reached or exceeded the maximum number of valid_principals"
            )

        return self._replace(_valid_principals=principals)

    def valid_for_all_principals(self) -> SSHCertificateBuilder:
        """Mark the certificate valid for any principal.

        This is encoded as an empty principals list. It is mutually
        exclusive with :meth:`valid_principals`.
        """
        if self._valid_principals:
            raise ValueError(
                "valid_principals already set, can't set "
                "valid_for_all_principals"
            )
        if self._valid_for_all_principals:
            raise ValueError("valid_for_all_principals already set")

        return self._replace(_valid_for_all_principals=True)

    def valid_before(self, valid_before: int | float) -> SSHCertificateBuilder:
        """Set the end of the validity window (truncated to an int)."""
        checked = self._check_timestamp("valid_before", valid_before)
        if self._valid_before is not None:
            raise ValueError("valid_before already set")

        return self._replace(_valid_before=checked)

    def valid_after(self, valid_after: int | float) -> SSHCertificateBuilder:
        """Set the start of the validity window (truncated to an int)."""
        checked = self._check_timestamp("valid_after", valid_after)
        if self._valid_after is not None:
            raise ValueError("valid_after already set")

        return self._replace(_valid_after=checked)

    def add_critical_option(
        self, name: bytes, value: bytes
    ) -> SSHCertificateBuilder:
        """Append a critical option; option names must be unique."""
        if not isinstance(name, bytes) or not isinstance(value, bytes):
            raise TypeError("name and value must be bytes")
        # Linear scan per call, so building n options is O(n**2) overall.
        if any(existing == name for existing, _ in self._critical_options):
            raise ValueError("Duplicate critical option name")

        return self._replace(
            _critical_options=[*self._critical_options, (name, value)]
        )

    def add_extension(
        self, name: bytes, value: bytes
    ) -> SSHCertificateBuilder:
        """Append an extension; extension names must be unique."""
        if not isinstance(name, bytes) or not isinstance(value, bytes):
            raise TypeError("name and value must be bytes")
        # Linear scan per call, so building n extensions is O(n**2) overall.
        if any(existing == name for existing, _ in self._extensions):
            raise ValueError("Duplicate extension name")

        return self._replace(_extensions=[*self._extensions, (name, value)])

    def sign(self, private_key: SSHCertPrivateKeyTypes) -> SSHCertificate:
        """Marshal and sign the certificate with the CA's ``private_key``.

        ``serial`` defaults to 0 and ``key_id`` to ``b""`` when unset. The
        builder itself is left unchanged.

        Raises:
            TypeError: ``private_key`` is not an EC, RSA or Ed25519 key.
            ValueError: public_key, type, valid_before or valid_after is
                unset; neither principals nor valid_for_all_principals is
                set; or valid_after is later than valid_before.
        """
        if not isinstance(
            private_key,
            (
                ec.EllipticCurvePrivateKey,
                rsa.RSAPrivateKey,
                ed25519.Ed25519PrivateKey,
            ),
        ):
            raise TypeError("Unsupported private key type")

        if self._public_key is None:
            raise ValueError("public_key must be set")

        # Not required
        serial = 0 if self._serial is None else self._serial

        if self._type is None:
            raise ValueError("type must be set")

        # Not required
        key_id = b"" if self._key_id is None else self._key_id

        # A zero length list is valid, but means the certificate
        # is valid for any principal of the specified type. We require
        # the user to explicitly set valid_for_all_principals to get
        # that behavior.
        if not self._valid_principals and not self._valid_for_all_principals:
            raise ValueError(
                "valid_principals must be set if valid_for_all_principals "
                "is False"
            )

        if self._valid_before is None:
            raise ValueError("valid_before must be set")

        if self._valid_after is None:
            raise ValueError("valid_after must be set")

        if self._valid_after > self._valid_before:
            raise ValueError("valid_after must be earlier than valid_before")

        # Option and extension names must be lexically sorted (the parser
        # rejects unsorted fields). Sort copies so that signing does not
        # mutate the builder's own state.
        critical_options = sorted(self._critical_options, key=lambda x: x[0])
        extensions = sorted(self._extensions, key=lambda x: x[0])

        key_type = _get_ssh_key_type(self._public_key)
        cert_prefix = key_type + _CERT_SUFFIX

        # Marshal the bytes to be signed
        nonce = os.urandom(32)
        kformat = _lookup_kformat(key_type)
        f = _FragList()
        f.put_sshstr(cert_prefix)
        f.put_sshstr(nonce)
        kformat.encode_public(self._public_key, f)
        f.put_u64(serial)
        f.put_u32(self._type.value)
        f.put_sshstr(key_id)
        fprincipals = _FragList()
        for p in self._valid_principals:
            fprincipals.put_sshstr(p)
        f.put_sshstr(fprincipals.tobytes())
        f.put_u64(self._valid_after)
        f.put_u64(self._valid_before)
        f.put_sshstr(self._encode_exts_opts(critical_options))
        f.put_sshstr(self._encode_exts_opts(extensions))
        f.put_sshstr(b"")  # RESERVED FIELD
        # encode CA public key
        ca_type = _get_ssh_key_type(private_key)
        caformat = _lookup_kformat(ca_type)
        caf = _FragList()
        caf.put_sshstr(ca_type)
        caformat.encode_public(private_key.public_key(), caf)
        f.put_sshstr(caf.tobytes())

        # Everything marshalled so far is the to-be-signed data; the
        # signature blob is appended after it.
        tbs = f.tobytes()
        # Sigs according to the rules defined for the CA's public key
        # (RFC4253 section 6.6 for ssh-rsa, RFC5656 for ECDSA,
        # and RFC8032 for Ed25519).
        fsig = _FragList()
        if isinstance(private_key, ed25519.Ed25519PrivateKey):
            fsig.put_sshstr(ca_type)
            fsig.put_sshstr(private_key.sign(tbs))
        elif isinstance(private_key, ec.EllipticCurvePrivateKey):
            hash_alg = _get_ec_hash_alg(private_key.curve)
            signature = private_key.sign(tbs, ec.ECDSA(hash_alg))
            r, s = asym_utils.decode_dss_signature(signature)
            fsig.put_sshstr(ca_type)
            fsigblob = _FragList()
            fsigblob.put_mpint(r)
            fsigblob.put_mpint(s)
            fsig.put_sshstr(fsigblob.tobytes())
        else:
            # Unreachable for other types thanks to the isinstance check at
            # the top; the assert only narrows the type for mypy.
            assert isinstance(private_key, rsa.RSAPrivateKey)
            # Just like Golang, we're going to use SHA512 for RSA
            # https://cs.opensource.google/go/x/crypto/+/refs/tags/
            # v0.4.0:ssh/certs.go;l=445
            # RFC 8332 defines SHA256 and 512 as options
            fsig.put_sshstr(_SSH_RSA_SHA512)
            fsig.put_sshstr(
                private_key.sign(tbs, padding.PKCS1v15(), hashes.SHA512())
            )
        f.put_sshstr(fsig.tobytes())

        cert_data = binascii.b2a_base64(f.tobytes()).strip()
        # load_ssh_public_identity returns a union, but this is
        # guaranteed to be an SSHCertificate, so we cast to make
        # mypy happy.
        return typing.cast(
            SSHCertificate,
            load_ssh_public_identity(b"".join([cert_prefix, b" ", cert_data])),
        )