1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
4
5from __future__ import annotations
6
7import binascii
8import enum
9import os
10import re
11import typing
12import warnings
13from base64 import encodebytes as _base64_encode
14from dataclasses import dataclass
15
16from cryptography import utils
17from cryptography.exceptions import UnsupportedAlgorithm
18from cryptography.hazmat.primitives import hashes
19from cryptography.hazmat.primitives.asymmetric import (
20 dsa,
21 ec,
22 ed25519,
23 padding,
24 rsa,
25)
26from cryptography.hazmat.primitives.asymmetric import utils as asym_utils
27from cryptography.hazmat.primitives.ciphers import (
28 AEADDecryptionContext,
29 Cipher,
30 algorithms,
31 modes,
32)
33from cryptography.hazmat.primitives.serialization import (
34 Encoding,
35 KeySerializationEncryption,
36 NoEncryption,
37 PrivateFormat,
38 PublicFormat,
39 _KeySerializationEncryption,
40)
41
# bcrypt is optional: it is only needed by _init_cipher, i.e. for
# password-protected OpenSSH private keys.  When it is missing, a stub with
# the same signature is installed so the failure only happens when an
# encrypted key is actually loaded or written.
try:
    from bcrypt import kdf as _bcrypt_kdf

    _bcrypt_supported = True
except ImportError:
    _bcrypt_supported = False

    def _bcrypt_kdf(
        password: bytes,
        salt: bytes,
        desired_key_bytes: int,
        rounds: int,
        ignore_few_rounds: bool = False,
    ) -> bytes:
        """Stand-in for ``bcrypt.kdf`` when bcrypt is not installed.

        Raises:
            UnsupportedAlgorithm: always.
        """
        raise UnsupportedAlgorithm("Need bcrypt module")
57
58
# SSH key type names, as they appear on the wire and in public key lines.
_SSH_ED25519 = b"ssh-ed25519"
_SSH_RSA = b"ssh-rsa"
_SSH_DSA = b"ssh-dss"
_ECDSA_NISTP256 = b"ecdsa-sha2-nistp256"
_ECDSA_NISTP384 = b"ecdsa-sha2-nistp384"
_ECDSA_NISTP521 = b"ecdsa-sha2-nistp521"
# Appended to a key type name to form the matching certificate type name.
_CERT_SUFFIX = b"-cert-v01@openssh.com"

# Security-key public key types; their encoding carries a U2F application
# string after the key itself.
_SK_SSH_ED25519 = b"sk-ssh-ed25519@openssh.com"
_SK_SSH_ECDSA_NISTP256 = b"sk-ecdsa-sha2-nistp256@openssh.com"

# These are not key types, only algorithms, so they cannot appear
# as a public key type.  They are accepted as RSA certificate signature
# algorithm names.
_SSH_RSA_SHA256 = b"rsa-sha2-256"
_SSH_RSA_SHA512 = b"rsa-sha2-512"

# Matches "<key-type><space/tab><base64-body>" at the start of a public key
# line; anything after the body (e.g. a comment) is not captured.
_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)")
# Magic prefix of the decoded openssh-key-v1 private key container.
_SK_MAGIC = b"openssh-key-v1\0"
_SK_START = b"-----BEGIN OPENSSH PRIVATE KEY-----"
_SK_END = b"-----END OPENSSH PRIVATE KEY-----"
_BCRYPT = b"bcrypt"
_NONE = b"none"
# Cipher and bcrypt round count used when serializing with a password.
_DEFAULT_CIPHER = b"aes256-ctr"
_DEFAULT_ROUNDS = 16

# A regex is used because it can search bytes-like data directly; group 1
# is the base64 body between the armor lines.
_PEM_RC = re.compile(_SK_START + b"(.*?)" + _SK_END, re.DOTALL)

# Padding bytes 1, 2, ..., 16 (the largest block size used here); the
# private section is padded with a prefix of this sequence.
_PADDING = memoryview(bytearray(range(1, 1 + 16)))
90
91
@dataclass
class _SSHCipher:
    """Parameters of a symmetric cipher used to wrap OpenSSH private keys."""

    # Cipher algorithm class; instantiated with the derived key.
    alg: type[algorithms.AES]
    # Key length in bytes, taken from the front of the KDF output.
    key_len: int
    # Mode class; instantiated with the derived IV / nonce.
    mode: type[modes.CTR] | type[modes.CBC] | type[modes.GCM]
    # Block length in bytes; the encrypted section must be a multiple of it.
    block_len: int
    # IV / nonce length in bytes, taken from the KDF output after the key.
    iv_len: int
    # Authentication tag length for AEAD modes, otherwise None.
    tag_len: int | None
    # True for authenticated modes (the tag is stored after the ciphertext).
    is_aead: bool
101
102
# ciphers that are actually used in key wrapping, keyed by the OpenSSH
# cipher name stored in the private key header
_SSH_CIPHERS: dict[bytes, _SSHCipher] = {
    b"aes256-ctr": _SSHCipher(
        alg=algorithms.AES,
        key_len=32,
        mode=modes.CTR,
        block_len=16,
        iv_len=16,
        tag_len=None,
        is_aead=False,
    ),
    b"aes256-cbc": _SSHCipher(
        alg=algorithms.AES,
        key_len=32,
        mode=modes.CBC,
        block_len=16,
        iv_len=16,
        tag_len=None,
        is_aead=False,
    ),
    b"aes256-gcm@openssh.com": _SSHCipher(
        alg=algorithms.AES,
        key_len=32,
        mode=modes.GCM,
        block_len=16,
        iv_len=12,
        tag_len=16,
        is_aead=True,
    ),
}

# map local curve name to SSH key type name
_ECDSA_KEY_TYPE = {
    "secp256r1": _ECDSA_NISTP256,
    "secp384r1": _ECDSA_NISTP384,
    "secp521r1": _ECDSA_NISTP521,
}
140
141
def _get_ssh_key_type(key: SSHPrivateKeyTypes | SSHPublicKeyTypes) -> bytes:
    """Return the SSH key type name (e.g. ``b"ssh-rsa"``) for *key*.

    Both private and public keys are accepted.

    Raises:
        ValueError: if the key type, or an EC key's curve, is unsupported.
    """
    if isinstance(key, ec.EllipticCurvePrivateKey):
        return _ecdsa_key_type(key.public_key())
    if isinstance(key, ec.EllipticCurvePublicKey):
        return _ecdsa_key_type(key)
    if isinstance(key, (rsa.RSAPrivateKey, rsa.RSAPublicKey)):
        return _SSH_RSA
    if isinstance(key, (dsa.DSAPrivateKey, dsa.DSAPublicKey)):
        return _SSH_DSA
    if isinstance(key, (ed25519.Ed25519PrivateKey, ed25519.Ed25519PublicKey)):
        return _SSH_ED25519
    raise ValueError("Unsupported key type")
159
160
def _ecdsa_key_type(public_key: ec.EllipticCurvePublicKey) -> bytes:
    """Return the SSH key type name for an ECDSA public key.

    The name is chosen by the key's curve name via ``_ECDSA_KEY_TYPE``
    (e.g. ``"secp256r1"`` -> ``b"ecdsa-sha2-nistp256"``).

    Raises:
        ValueError: if the curve has no SSH key type.
    """
    curve_name = public_key.curve.name
    try:
        return _ECDSA_KEY_TYPE[curve_name]
    except KeyError:
        # Report a clear error instead of the internal lookup failure.
        raise ValueError(
            f"Unsupported curve for ssh private key: {curve_name!r}"
        ) from None
169
170
def _ssh_pem_encode(
    data: utils.Buffer,
    prefix: bytes = _SK_START + b"\n",
    suffix: bytes = _SK_END + b"\n",
) -> bytes:
    """Base64-encode *data* and wrap it between *prefix* and *suffix*.

    ``base64.encodebytes`` emits newline-terminated lines, so for non-empty
    *data* the suffix starts on its own line.
    """
    armored = [prefix, _base64_encode(data), suffix]
    return b"".join(armored)
177
178
179def _check_block_size(data: utils.Buffer, block_len: int) -> None:
180 """Require data to be full blocks"""
181 if not data or len(data) % block_len != 0:
182 raise ValueError("Corrupt data: missing padding")
183
184
185def _check_empty(data: utils.Buffer) -> None:
186 """All data should have been parsed."""
187 if data:
188 raise ValueError("Corrupt data: unparsed data")
189
190
191def _init_cipher(
192 ciphername: bytes,
193 password: bytes | None,
194 salt: bytes,
195 rounds: int,
196) -> Cipher[modes.CBC | modes.CTR | modes.GCM]:
197 """Generate key + iv and return cipher."""
198 if not password:
199 raise TypeError(
200 "Key is password-protected, but password was not provided."
201 )
202
203 ciph = _SSH_CIPHERS[ciphername]
204 seed = _bcrypt_kdf(
205 password, salt, ciph.key_len + ciph.iv_len, rounds, True
206 )
207 return Cipher(
208 ciph.alg(seed[: ciph.key_len]),
209 ciph.mode(seed[ciph.key_len :]),
210 )
211
212
213def _get_u32(data: memoryview) -> tuple[int, memoryview]:
214 """Uint32"""
215 if len(data) < 4:
216 raise ValueError("Invalid data")
217 return int.from_bytes(data[:4], byteorder="big"), data[4:]
218
219
220def _get_u64(data: memoryview) -> tuple[int, memoryview]:
221 """Uint64"""
222 if len(data) < 8:
223 raise ValueError("Invalid data")
224 return int.from_bytes(data[:8], byteorder="big"), data[8:]
225
226
def _get_sshstr(data: memoryview) -> tuple[memoryview, memoryview]:
    """Read a uint32-length-prefixed byte string from *data*.

    Returns the string as a view into *data* (no copy) and the remaining
    view.

    Raises:
        ValueError: if the length prefix or the string body is truncated.
    """
    length, remaining = _get_u32(data)
    if len(remaining) < length:
        raise ValueError("Invalid data")
    return remaining[:length], remaining[length:]
233
234
def _get_mpint(data: memoryview) -> tuple[int, memoryview]:
    """Read an SSH ``mpint`` (length-prefixed big-endian integer).

    Only non-negative values are accepted: a first byte with its high bit
    set (a negative two's-complement value) is rejected.

    Raises:
        ValueError: if the encoding is truncated or negative.
    """
    raw, remaining = _get_sshstr(data)
    is_negative = len(raw) > 0 and raw[0] & 0x80
    if is_negative:
        raise ValueError("Invalid data")
    return int.from_bytes(raw, "big"), remaining
241
242
243def _to_mpint(val: int) -> bytes:
244 """Storage format for signed bigint."""
245 if val < 0:
246 raise ValueError("negative mpint not allowed")
247 if not val:
248 return b""
249 nbytes = (val.bit_length() + 8) // 8
250 return utils.int_to_bytes(val, nbytes)
251
252
253class _FragList:
254 """Build recursive structure without data copy."""
255
256 flist: list[utils.Buffer]
257
258 def __init__(self, init: list[utils.Buffer] | None = None) -> None:
259 self.flist = []
260 if init:
261 self.flist.extend(init)
262
263 def put_raw(self, val: utils.Buffer) -> None:
264 """Add plain bytes"""
265 self.flist.append(val)
266
267 def put_u32(self, val: int) -> None:
268 """Big-endian uint32"""
269 self.flist.append(val.to_bytes(length=4, byteorder="big"))
270
271 def put_u64(self, val: int) -> None:
272 """Big-endian uint64"""
273 self.flist.append(val.to_bytes(length=8, byteorder="big"))
274
275 def put_sshstr(self, val: bytes | _FragList) -> None:
276 """Bytes prefixed with u32 length"""
277 if isinstance(val, (bytes, memoryview, bytearray)):
278 self.put_u32(len(val))
279 self.flist.append(val)
280 else:
281 self.put_u32(val.size())
282 self.flist.extend(val.flist)
283
284 def put_mpint(self, val: int) -> None:
285 """Big-endian bigint prefixed with u32 length"""
286 self.put_sshstr(_to_mpint(val))
287
288 def size(self) -> int:
289 """Current number of bytes"""
290 return sum(map(len, self.flist))
291
292 def render(self, dstbuf: memoryview, pos: int = 0) -> int:
293 """Write into bytearray"""
294 for frag in self.flist:
295 flen = len(frag)
296 start, pos = pos, pos + flen
297 dstbuf[start:pos] = frag
298 return pos
299
300 def tobytes(self) -> bytes:
301 """Return as bytes"""
302 buf = memoryview(bytearray(self.size()))
303 self.render(buf)
304 return buf.tobytes()
305
306
class _SSHFormatRSA:
    """Wire format of ``ssh-rsa`` keys.

    Public:
        mpint e, n
    Private:
        mpint n, e, d, iqmp, p, q
    """

    def get_public(
        self, data: memoryview
    ) -> tuple[tuple[int, int], memoryview]:
        """Parse the RSA public fields ``(e, n)`` from *data*."""
        exponent, rest = _get_mpint(data)
        modulus, rest = _get_mpint(rest)
        return (exponent, modulus), rest

    def load_public(
        self, data: memoryview
    ) -> tuple[rsa.RSAPublicKey, memoryview]:
        """Build an RSA public key from *data*; return it and the rest."""
        (exponent, modulus), rest = self.get_public(data)
        key = rsa.RSAPublicNumbers(exponent, modulus).public_key()
        return key, rest

    def load_private(
        self, data: memoryview, pubfields, unsafe_skip_rsa_key_validation: bool
    ) -> tuple[rsa.RSAPrivateKey, memoryview]:
        """Build an RSA private key from *data*.

        *pubfields* is the ``(e, n)`` pair from the file's public section
        and must match the private section.  The CRT exponents are not
        stored in the format, so they are recomputed from ``d``, ``p`` and
        ``q``.

        Raises:
            ValueError: on malformed data or a public field mismatch.
        """
        values = []
        for _ in range(6):
            value, data = _get_mpint(data)
            values.append(value)
        n, e, d, iqmp, p, q = values

        if (e, n) != pubfields:
            raise ValueError("Corrupt data: rsa field mismatch")
        numbers = rsa.RSAPrivateNumbers(
            p,
            q,
            d,
            rsa.rsa_crt_dmp1(d, p),
            rsa.rsa_crt_dmq1(d, q),
            iqmp,
            rsa.RSAPublicNumbers(e, n),
        )
        key = numbers.private_key(
            unsafe_skip_rsa_key_validation=unsafe_skip_rsa_key_validation
        )
        return key, data

    def encode_public(
        self, public_key: rsa.RSAPublicKey, f_pub: _FragList
    ) -> None:
        """Append the public fields ``e, n`` to *f_pub*."""
        numbers = public_key.public_numbers()
        for value in (numbers.e, numbers.n):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: rsa.RSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Append the private fields ``n, e, d, iqmp, p, q`` to *f_priv*."""
        priv = private_key.private_numbers()
        pub = priv.public_numbers
        for value in (pub.n, pub.e, priv.d, priv.iqmp, priv.p, priv.q):
            f_priv.put_mpint(value)
379
380
class _SSHFormatDSA:
    """Wire format of ``ssh-dss`` keys.

    Public:
        mpint p, q, g, y
    Private:
        mpint p, q, g, y, x
    """

    def get_public(self, data: memoryview) -> tuple[tuple, memoryview]:
        """Parse the DSA public fields ``(p, q, g, y)`` from *data*."""
        fields = []
        for _ in range(4):
            value, data = _get_mpint(data)
            fields.append(value)
        return tuple(fields), data

    def load_public(
        self, data: memoryview
    ) -> tuple[dsa.DSAPublicKey, memoryview]:
        """Build a DSA public key from *data*; return it and the rest.

        Raises:
            ValueError: on malformed data or a non-1024-bit modulus.
        """
        (p, q, g, y), rest = self.get_public(data)
        params = dsa.DSAParameterNumbers(p, q, g)
        numbers = dsa.DSAPublicNumbers(y, params)
        self._validate(numbers)
        return numbers.public_key(), rest

    def load_private(
        self, data: memoryview, pubfields, unsafe_skip_rsa_key_validation: bool
    ) -> tuple[dsa.DSAPrivateKey, memoryview]:
        """Build a DSA private key from *data*.

        *pubfields* is the ``(p, q, g, y)`` tuple from the file's public
        section and must match the private section.
        *unsafe_skip_rsa_key_validation* is ignored.

        Raises:
            ValueError: on malformed data, a public field mismatch, or a
                non-1024-bit modulus.
        """
        fields, rest = self.get_public(data)
        x, rest = _get_mpint(rest)

        if fields != pubfields:
            raise ValueError("Corrupt data: dsa field mismatch")
        p, q, g, y = fields
        numbers = dsa.DSAPublicNumbers(y, dsa.DSAParameterNumbers(p, q, g))
        self._validate(numbers)
        return dsa.DSAPrivateNumbers(x, numbers).private_key(), rest

    def encode_public(
        self, public_key: dsa.DSAPublicKey, f_pub: _FragList
    ) -> None:
        """Append the public fields ``p, q, g, y`` to *f_pub*."""
        numbers = public_key.public_numbers()
        params = numbers.parameter_numbers
        self._validate(numbers)
        for value in (params.p, params.q, params.g, numbers.y):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: dsa.DSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Append the public fields followed by ``x`` to *f_priv*."""
        self.encode_public(private_key.public_key(), f_priv)
        f_priv.put_mpint(private_key.private_numbers().x)

    def _validate(self, public_numbers: dsa.DSAPublicNumbers) -> None:
        """Reject keys whose modulus ``p`` is not exactly 1024 bits."""
        if public_numbers.parameter_numbers.p.bit_length() != 1024:
            raise ValueError("SSH supports only 1024 bit DSA keys")
448 raise ValueError("SSH supports only 1024 bit DSA keys")
449
450
class _SSHFormatECDSA:
    """Wire format of ``ecdsa-sha2-nistp*`` keys.

    Public:
        str curve
        bytes point
    Private:
        str curve
        bytes point
        mpint secret
    """

    def __init__(self, ssh_curve_name: bytes, curve: ec.EllipticCurve):
        # SSH curve identifier (e.g. b"nistp256") expected in the encoding.
        self.ssh_curve_name = ssh_curve_name
        # Curve object used to build keys for this format.
        self.curve = curve

    def get_public(
        self, data: memoryview
    ) -> tuple[tuple[memoryview, memoryview], memoryview]:
        """Parse the ECDSA public fields ``(curve, point)`` from *data*.

        Raises:
            ValueError: if the data is truncated, the curve name does not
                match this format, or the point is empty.
            NotImplementedError: if the point is not uncompressed.
        """
        curve, data = _get_sshstr(data)
        point, data = _get_sshstr(data)
        if curve != self.ssh_curve_name:
            raise ValueError("Curve name mismatch")
        # The point comes from untrusted input; an empty one would make the
        # index below raise IndexError instead of a parsing error.
        if not point:
            raise ValueError("Invalid data")
        # 0x04 is the SEC1 prefix of an uncompressed point.
        if point[0] != 4:
            raise NotImplementedError("Need uncompressed point")
        return (curve, point), data

    def load_public(
        self, data: memoryview
    ) -> tuple[ec.EllipticCurvePublicKey, memoryview]:
        """Build an ECDSA public key from *data*; return it and the rest."""
        (_, point), data = self.get_public(data)
        public_key = ec.EllipticCurvePublicKey.from_encoded_point(
            self.curve, point.tobytes()
        )
        return public_key, data

    def load_private(
        self, data: memoryview, pubfields, unsafe_skip_rsa_key_validation: bool
    ) -> tuple[ec.EllipticCurvePrivateKey, memoryview]:
        """Build an ECDSA private key from *data*.

        *pubfields* is the ``(curve, point)`` pair from the file's public
        section and must match the private section.
        *unsafe_skip_rsa_key_validation* is ignored.

        Raises:
            ValueError: on malformed data or a public field mismatch.
        """
        (curve_name, point), data = self.get_public(data)
        secret, data = _get_mpint(data)

        if (curve_name, point) != pubfields:
            raise ValueError("Corrupt data: ecdsa field mismatch")
        private_key = ec.derive_private_key(secret, self.curve)
        return private_key, data

    def encode_public(
        self, public_key: ec.EllipticCurvePublicKey, f_pub: _FragList
    ) -> None:
        """Append the curve name and uncompressed point to *f_pub*."""
        point = public_key.public_bytes(
            Encoding.X962, PublicFormat.UncompressedPoint
        )
        f_pub.put_sshstr(self.ssh_curve_name)
        f_pub.put_sshstr(point)

    def encode_private(
        self, private_key: ec.EllipticCurvePrivateKey, f_priv: _FragList
    ) -> None:
        """Append the public fields followed by the secret to *f_priv*."""
        public_key = private_key.public_key()
        private_numbers = private_key.private_numbers()

        self.encode_public(public_key, f_priv)
        f_priv.put_mpint(private_numbers.private_value)
520
521
class _SSHFormatEd25519:
    """Wire format of ``ssh-ed25519`` keys.

    Public:
        bytes point
    Private:
        bytes point
        bytes secret_and_point
    """

    def get_public(
        self, data: memoryview
    ) -> tuple[tuple[memoryview], memoryview]:
        """Parse the Ed25519 public field ``(point,)`` from *data*."""
        point, rest = _get_sshstr(data)
        return (point,), rest

    def load_public(
        self, data: memoryview
    ) -> tuple[ed25519.Ed25519PublicKey, memoryview]:
        """Build an Ed25519 public key from *data*; return it and the rest."""
        (point,), rest = self.get_public(data)
        key = ed25519.Ed25519PublicKey.from_public_bytes(point.tobytes())
        return key, rest

    def load_private(
        self, data: memoryview, pubfields, unsafe_skip_rsa_key_validation: bool
    ) -> tuple[ed25519.Ed25519PrivateKey, memoryview]:
        """Build an Ed25519 private key from *data*.

        The private blob holds the secret in its first 32 bytes followed by
        a copy of the public point.  That copy must equal the point field,
        and the point must equal *pubfields* from the file's public
        section.  *unsafe_skip_rsa_key_validation* is ignored.

        Raises:
            ValueError: on malformed data or a point mismatch.
        """
        (point,), rest = self.get_public(data)
        keypair, rest = _get_sshstr(rest)

        secret, embedded_point = keypair[:32], keypair[32:]
        points_agree = point == embedded_point and (point,) == pubfields
        if not points_agree:
            raise ValueError("Corrupt data: ed25519 field mismatch")
        return ed25519.Ed25519PrivateKey.from_private_bytes(secret), rest

    def encode_public(
        self, public_key: ed25519.Ed25519PublicKey, f_pub: _FragList
    ) -> None:
        """Append the raw 32-byte public point to *f_pub*."""
        f_pub.put_sshstr(
            public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)
        )

    def encode_private(
        self, private_key: ed25519.Ed25519PrivateKey, f_priv: _FragList
    ) -> None:
        """Append the public point and the ``secret || point`` blob."""
        public_key = private_key.public_key()
        secret = private_key.private_bytes(
            Encoding.Raw, PrivateFormat.Raw, NoEncryption()
        )
        point = public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)

        self.encode_public(public_key, f_priv)
        f_priv.put_sshstr(_FragList([secret, point]))
587
588
def load_application(data: memoryview) -> tuple[memoryview, memoryview]:
    """Parse the U2F application string of a security-key public key.

    Returns the application string (as a view) and the remaining data.

    Raises:
        ValueError: if the data is truncated or the application string does
            not start with ``b"ssh:"``.
    """
    application, data = _get_sshstr(data)
    app_bytes = application.tobytes()
    if not app_bytes.startswith(b"ssh:"):
        # Show the actual bytes; interpolating the memoryview itself would
        # only print an opaque "<memory at 0x...>" repr.
        raise ValueError(
            "U2F application string does not start with b'ssh:' "
            f"({app_bytes!r})"
        )
    return application, data
600
601
class _SSHFormatSKEd25519:
    """
    Wire format of a sk-ssh-ed25519@openssh.com public key:

        string "sk-ssh-ed25519@openssh.com"
        string public key
        string application (user-specified, but typically "ssh:")

    Only public keys can be loaded.
    """

    def load_public(
        self, data: memoryview
    ) -> tuple[ed25519.Ed25519PublicKey, memoryview]:
        """Load the Ed25519 key, then consume the application string."""
        inner_format = _lookup_kformat(_SSH_ED25519)
        public_key, rest = inner_format.load_public(data)
        _, rest = load_application(rest)
        return public_key, rest

    def get_public(self, data: memoryview) -> typing.NoReturn:
        """Refuse to load private keys of this type.

        Despite its name, ``get_public`` is the entry point that private
        key loading calls, so raising here blocks sk private keys.
        """
        raise UnsupportedAlgorithm(
            "sk-ssh-ed25519 private keys cannot be loaded"
        )
625
626
class _SSHFormatSKECDSA:
    """
    Wire format of a sk-ecdsa-sha2-nistp256@openssh.com public key:

        string "sk-ecdsa-sha2-nistp256@openssh.com"
        string curve name
        ec_point Q
        string application (user-specified, but typically "ssh:")

    Only public keys can be loaded.
    """

    def load_public(
        self, data: memoryview
    ) -> tuple[ec.EllipticCurvePublicKey, memoryview]:
        """Load the nistp256 key, then consume the application string."""
        inner_format = _lookup_kformat(_ECDSA_NISTP256)
        public_key, rest = inner_format.load_public(data)
        _, rest = load_application(rest)
        return public_key, rest

    def get_public(self, data: memoryview) -> typing.NoReturn:
        """Refuse to load private keys of this type.

        Despite its name, ``get_public`` is the entry point that private
        key loading calls, so raising here blocks sk private keys.
        """
        raise UnsupportedAlgorithm(
            "sk-ecdsa-sha2-nistp256 private keys cannot be loaded"
        )
651
652
# Wire-format handlers keyed by SSH key type name.  The sk-* handlers only
# support loading public keys.
_KEY_FORMATS = {
    _SSH_RSA: _SSHFormatRSA(),
    _SSH_DSA: _SSHFormatDSA(),
    _SSH_ED25519: _SSHFormatEd25519(),
    _ECDSA_NISTP256: _SSHFormatECDSA(b"nistp256", ec.SECP256R1()),
    _ECDSA_NISTP384: _SSHFormatECDSA(b"nistp384", ec.SECP384R1()),
    _ECDSA_NISTP521: _SSHFormatECDSA(b"nistp521", ec.SECP521R1()),
    _SK_SSH_ED25519: _SSHFormatSKEd25519(),
    _SK_SSH_ECDSA_NISTP256: _SSHFormatSKECDSA(),
}
663
664
def _lookup_kformat(key_type: utils.Buffer):
    """Return the format handler registered for *key_type*.

    *key_type* may be any bytes-like object; anything that is not ``bytes``
    is copied into ``bytes`` so it can be used as a dictionary key.

    Raises:
        UnsupportedAlgorithm: if no handler is registered for the type.
    """
    name = (
        key_type
        if isinstance(key_type, bytes)
        else memoryview(key_type).tobytes()
    )
    kformat = _KEY_FORMATS.get(name)
    if kformat is None:
        raise UnsupportedAlgorithm(f"Unsupported key type: {name!r}")
    return kformat
672
673
# Private key types that load_ssh_private_key can return and that can be
# serialized in the OpenSSH private key format.
SSHPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    dsa.DSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]
680
681
def load_ssh_private_key(
    data: utils.Buffer,
    password: bytes | None,
    backend: typing.Any = None,
    *,
    unsafe_skip_rsa_key_validation: bool = False,
) -> SSHPrivateKeyTypes:
    """Load private key from OpenSSH custom encoding.

    *data* must contain a ``-----BEGIN OPENSSH PRIVATE KEY-----`` block
    holding exactly one key.  Encrypted keys need *password* and must use
    the bcrypt KDF with a cipher listed in ``_SSH_CIPHERS``.  *backend* is
    accepted but not used.

    Raises:
        ValueError: if the data is not an OpenSSH private key or is corrupt.
        TypeError: if an encrypted key has no password, or a password is
            given for an unencrypted key.
        UnsupportedAlgorithm: for an unknown key type, cipher or KDF.
    """
    utils._check_byteslike("data", data)
    if password is not None:
        utils._check_bytes("password", password)

    # locate and decode the base64 body between the armor lines
    m = _PEM_RC.search(data)
    if not m:
        raise ValueError("Not OpenSSH private key format")
    p1 = m.start(1)
    p2 = m.end(1)
    data = binascii.a2b_base64(memoryview(data)[p1:p2])
    if not data.startswith(_SK_MAGIC):
        raise ValueError("Not OpenSSH private key format")
    data = memoryview(data)[len(_SK_MAGIC) :]

    # parse header
    ciphername, data = _get_sshstr(data)
    kdfname, data = _get_sshstr(data)
    kdfoptions, data = _get_sshstr(data)
    nkeys, data = _get_u32(data)
    if nkeys != 1:
        raise ValueError("Only one key supported")

    # load public key data; the private section must repeat these fields
    pubdata, data = _get_sshstr(data)
    pub_key_type, pubdata = _get_sshstr(pubdata)
    kformat = _lookup_kformat(pub_key_type)
    pubfields, pubdata = kformat.get_public(pubdata)
    _check_empty(pubdata)

    if ciphername != _NONE or kdfname != _NONE:
        ciphername_bytes = ciphername.tobytes()
        if ciphername_bytes not in _SSH_CIPHERS:
            raise UnsupportedAlgorithm(
                f"Unsupported cipher: {ciphername_bytes!r}"
            )
        if kdfname != _BCRYPT:
            # kdfname is a memoryview whose repr is just "<memory at ...>";
            # report the actual KDF name, as is done for the cipher above.
            raise UnsupportedAlgorithm(
                f"Unsupported KDF: {kdfname.tobytes()!r}"
            )
        cipher_spec = _SSH_CIPHERS[ciphername_bytes]
        blklen = cipher_spec.block_len
        tag_len = cipher_spec.tag_len
        # load secret data
        edata, data = _get_sshstr(data)
        # see https://bugzilla.mindrot.org/show_bug.cgi?id=3553 for
        # information about how OpenSSH handles AEAD tags: the tag is stored
        # after the encrypted string rather than inside it
        if cipher_spec.is_aead:
            tag = bytes(data)
            if len(tag) != tag_len:
                raise ValueError("Corrupt data: invalid tag length for cipher")
        else:
            _check_empty(data)
        _check_block_size(edata, blklen)
        # kdfoptions for bcrypt: string salt, uint32 rounds
        salt, kbuf = _get_sshstr(kdfoptions)
        rounds, kbuf = _get_u32(kbuf)
        _check_empty(kbuf)
        ciph = _init_cipher(ciphername_bytes, password, salt.tobytes(), rounds)
        dec = ciph.decryptor()
        edata = memoryview(dec.update(edata))
        if cipher_spec.is_aead:
            assert isinstance(dec, AEADDecryptionContext)
            _check_empty(dec.finalize_with_tag(tag))
        else:
            # _check_block_size requires data to be a full block so there
            # should be no output from finalize
            _check_empty(dec.finalize())
    else:
        if password:
            raise TypeError(
                "Password was given but private key is not encrypted."
            )
        # load secret data
        edata, data = _get_sshstr(data)
        _check_empty(data)
        blklen = 8
        _check_block_size(edata, blklen)
    # the private section starts with two copies of a check value; a
    # mismatch means corrupt data (or, presumably, a wrong password)
    ck1, edata = _get_u32(edata)
    ck2, edata = _get_u32(edata)
    if ck1 != ck2:
        raise ValueError("Corrupt data: broken checksum")

    # load per-key struct
    key_type, edata = _get_sshstr(edata)
    if key_type != pub_key_type:
        raise ValueError("Corrupt data: key type mismatch")
    private_key, edata = kformat.load_private(
        edata,
        pubfields,
        unsafe_skip_rsa_key_validation=unsafe_skip_rsa_key_validation,
    )
    # We don't use the comment
    _, edata = _get_sshstr(edata)

    # yes, SSH does padding check *after* all other parsing is done.
    # need to follow as it writes zero-byte padding too.
    if edata != _PADDING[: len(edata)]:
        raise ValueError("Corrupt data: invalid padding")

    if isinstance(private_key, dsa.DSAPrivateKey):
        warnings.warn(
            "SSH DSA keys are deprecated and will be removed in a future "
            "release.",
            utils.DeprecatedIn40,
            stacklevel=2,
        )

    return private_key
794
795
def _serialize_ssh_private_key(
    private_key: SSHPrivateKeyTypes,
    password: bytes,
    encryption_algorithm: KeySerializationEncryption,
) -> bytes:
    """Serialize private key with OpenSSH custom encoding.

    Produces a PEM-armored openssh-key-v1 container holding one key.  A
    non-empty *password* encrypts the private section with
    ``_DEFAULT_CIPHER`` under a bcrypt-derived key with a random 16-byte
    salt.  An empty *password* writes the section unencrypted.
    *encryption_algorithm* is only consulted for a custom bcrypt round
    count.

    Raises:
        ValueError: if the key type (or EC curve) is not supported.
    """
    utils._check_bytes("password", password)
    if isinstance(private_key, dsa.DSAPrivateKey):
        warnings.warn(
            "SSH DSA key support is deprecated and will be "
            "removed in a future release",
            utils.DeprecatedIn40,
            stacklevel=4,
        )

    key_type = _get_ssh_key_type(private_key)
    kformat = _lookup_kformat(key_type)

    # setup parameters
    f_kdfoptions = _FragList()
    if password:
        ciphername = _DEFAULT_CIPHER
        blklen = _SSH_CIPHERS[ciphername].block_len
        kdfname = _BCRYPT
        rounds = _DEFAULT_ROUNDS
        # honour a caller-supplied bcrypt round count, if any
        if (
            isinstance(encryption_algorithm, _KeySerializationEncryption)
            and encryption_algorithm._kdf_rounds is not None
        ):
            rounds = encryption_algorithm._kdf_rounds
        salt = os.urandom(16)
        f_kdfoptions.put_sshstr(salt)
        f_kdfoptions.put_u32(rounds)
        ciph = _init_cipher(ciphername, password, salt, rounds)
    else:
        # unencrypted: the private section is still padded to 8 bytes
        ciphername = kdfname = _NONE
        blklen = 8
        ciph = None
    nkeys = 1
    # written twice; the loader only checks that both copies match
    checkval = os.urandom(4)
    comment = b""

    # encode public and private parts together
    f_public_key = _FragList()
    f_public_key.put_sshstr(key_type)
    kformat.encode_public(private_key.public_key(), f_public_key)

    f_secrets = _FragList([checkval, checkval])
    f_secrets.put_sshstr(key_type)
    kformat.encode_private(private_key, f_secrets)
    f_secrets.put_sshstr(comment)
    # pad with 1, 2, 3, ... up to the next multiple of blklen (a whole
    # extra block when the section is already aligned)
    f_secrets.put_raw(_PADDING[: blklen - (f_secrets.size() % blklen)])

    # top-level structure
    f_main = _FragList()
    f_main.put_raw(_SK_MAGIC)
    f_main.put_sshstr(ciphername)
    f_main.put_sshstr(kdfname)
    f_main.put_sshstr(f_kdfoptions)
    f_main.put_u32(nkeys)
    f_main.put_sshstr(f_public_key)
    f_main.put_sshstr(f_secrets)

    # copy result into bytearray; the extra blklen bytes are slack, which
    # presumably satisfies update_into's output-size requirement
    # (len(data) + block_size - 1 per its docs)
    slen = f_secrets.size()
    mlen = f_main.size()
    buf = memoryview(bytearray(mlen + blklen))
    f_main.render(buf)
    # f_secrets was appended last, so its bytes are the tail of the output
    ofs = mlen - slen

    # encrypt in-place
    if ciph is not None:
        ciph.encryptor().update_into(buf[ofs:mlen], buf[ofs:])

    return _ssh_pem_encode(buf[:mlen])
871
872
# Public key types returned when loading SSH public keys.
SSHPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    dsa.DSAPublicKey,
    ed25519.Ed25519PublicKey,
]

# Public key types exposed for SSH certificates (the certified key and the
# signing key); DSA is excluded.
SSHCertPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    ed25519.Ed25519PublicKey,
]
885
886
class SSHCertificateType(enum.Enum):
    """Kind of SSH certificate, from the certificate's uint32 type field."""

    USER = 1
    HOST = 2
890
891
class SSHCertificate:
    """A parsed OpenSSH certificate.

    Built by the public-identity loader from a decoded certificate; several
    constructor arguments are memoryviews into that decoded data.  Parsing
    does not check the CA signature -- call :meth:`verify_cert_signature`.
    """

    def __init__(
        self,
        _nonce: memoryview,
        _public_key: SSHPublicKeyTypes,
        _serial: int,
        _cctype: int,
        _key_id: memoryview,
        _valid_principals: list[bytes],
        _valid_after: int,
        _valid_before: int,
        _critical_options: dict[bytes, bytes],
        _extensions: dict[bytes, bytes],
        _sig_type: memoryview,
        _sig_key: memoryview,
        _inner_sig_type: memoryview,
        _signature: memoryview,
        _tbs_cert_body: memoryview,
        _cert_key_type: bytes,
        _cert_body: memoryview,
    ):
        self._nonce = _nonce
        self._public_key = _public_key
        self._serial = _serial
        try:
            self._type = SSHCertificateType(_cctype)
        except ValueError:
            # Replace the enum's internal lookup error with a clear message
            # instead of chaining it.
            raise ValueError("Invalid certificate type") from None
        self._key_id = _key_id
        self._valid_principals = _valid_principals
        self._valid_after = _valid_after
        self._valid_before = _valid_before
        self._critical_options = _critical_options
        self._extensions = _extensions
        self._sig_type = _sig_type
        self._sig_key = _sig_key
        self._inner_sig_type = _inner_sig_type
        self._signature = _signature
        self._cert_key_type = _cert_key_type
        self._cert_body = _cert_body
        self._tbs_cert_body = _tbs_cert_body

    @property
    def nonce(self) -> bytes:
        """The certificate's nonce field."""
        return bytes(self._nonce)

    def public_key(self) -> SSHCertPublicKeyTypes:
        """Return the certified public key."""
        # make mypy happy until we remove DSA support entirely and
        # the underlying union won't have a disallowed type
        return typing.cast(SSHCertPublicKeyTypes, self._public_key)

    @property
    def serial(self) -> int:
        """The certificate serial number."""
        return self._serial

    @property
    def type(self) -> SSHCertificateType:
        """Whether this is a user or a host certificate."""
        return self._type

    @property
    def key_id(self) -> bytes:
        """The certificate's key identifier."""
        return bytes(self._key_id)

    @property
    def valid_principals(self) -> list[bytes]:
        """The principal names listed in the certificate."""
        return self._valid_principals

    @property
    def valid_before(self) -> int:
        """The ``valid_before`` field (uint64 as stored)."""
        return self._valid_before

    @property
    def valid_after(self) -> int:
        """The ``valid_after`` field (uint64 as stored)."""
        return self._valid_after

    @property
    def critical_options(self) -> dict[bytes, bytes]:
        """Critical options, keyed by option name."""
        return self._critical_options

    @property
    def extensions(self) -> dict[bytes, bytes]:
        """Extensions, keyed by extension name."""
        return self._extensions

    def signature_key(self) -> SSHCertPublicKeyTypes:
        """Return the CA public key that signed the certificate.

        Raises:
            ValueError: if the encoded key has trailing data.
            UnsupportedAlgorithm: if the key type is not supported.
        """
        sigformat = _lookup_kformat(self._sig_type)
        signature_key, sigkey_rest = sigformat.load_public(self._sig_key)
        _check_empty(sigkey_rest)
        return signature_key

    def public_bytes(self) -> bytes:
        """Return the certificate as ``<key-type> <base64-body>``."""
        return (
            bytes(self._cert_key_type)
            + b" "
            + binascii.b2a_base64(bytes(self._cert_body), newline=False)
        )

    def verify_cert_signature(self) -> None:
        """Verify the CA signature over the certificate body.

        Returns ``None`` on success.  On a bad signature, the signing key's
        ``verify`` method raises (``InvalidSignature`` in cryptography).
        """
        signature_key = self.signature_key()
        if isinstance(signature_key, ed25519.Ed25519PublicKey):
            signature_key.verify(
                bytes(self._signature), bytes(self._tbs_cert_body)
            )
        elif isinstance(signature_key, ec.EllipticCurvePublicKey):
            # The signature is encoded as a pair of big-endian integers
            r, data = _get_mpint(self._signature)
            s, data = _get_mpint(data)
            _check_empty(data)
            computed_sig = asym_utils.encode_dss_signature(r, s)
            hash_alg = _get_ec_hash_alg(signature_key.curve)
            signature_key.verify(
                computed_sig, bytes(self._tbs_cert_body), ec.ECDSA(hash_alg)
            )
        else:
            assert isinstance(signature_key, rsa.RSAPublicKey)
            # the inner signature algorithm name selects the RSA hash
            if self._inner_sig_type == _SSH_RSA:
                hash_alg = hashes.SHA1()
            elif self._inner_sig_type == _SSH_RSA_SHA256:
                hash_alg = hashes.SHA256()
            else:
                assert self._inner_sig_type == _SSH_RSA_SHA512
                hash_alg = hashes.SHA512()
            signature_key.verify(
                bytes(self._signature),
                bytes(self._tbs_cert_body),
                padding.PKCS1v15(),
                hash_alg,
            )
1019
1020
def _get_ec_hash_alg(curve: ec.EllipticCurve) -> hashes.HashAlgorithm:
    """Return the hash used for SSH ECDSA signatures on *curve*.

    P-256 -> SHA-256, P-384 -> SHA-384, P-521 -> SHA-512; any other curve
    is treated as a programming error (assertion).
    """
    for curve_cls, hash_cls in (
        (ec.SECP256R1, hashes.SHA256),
        (ec.SECP384R1, hashes.SHA384),
    ):
        if isinstance(curve, curve_cls):
            return hash_cls()
    assert isinstance(curve, ec.SECP521R1)
    return hashes.SHA512()
1029
1030
def _load_ssh_public_identity(
    data: utils.Buffer,
    _legacy_dsa_allowed: bool = False,
) -> SSHCertificate | SSHPublicKeyTypes:
    """Parse an SSH public key line or certificate line.

    *data* starts with ``<key-type> <base64-body>``; text after the body
    (e.g. a comment) is ignored.  Key types ending in
    ``-cert-v01@openssh.com`` are parsed as certificates and returned as
    :class:`SSHCertificate` without verifying the signature.  Other types
    return the public key itself.  DSA certificate keys and DSA certificate
    signatures are rejected unless *_legacy_dsa_allowed* is true.

    Raises:
        ValueError: if the line or the encoded body is malformed.
        UnsupportedAlgorithm: for unknown key types or disallowed DSA use.
    """
    utils._check_byteslike("data", data)

    m = _SSH_PUBKEY_RC.match(data)
    if not m:
        raise ValueError("Invalid line format")
    key_type = orig_key_type = m.group(1)
    key_body = m.group(2)
    with_cert = False
    if key_type.endswith(_CERT_SUFFIX):
        with_cert = True
        key_type = key_type[: -len(_CERT_SUFFIX)]
        if key_type == _SSH_DSA and not _legacy_dsa_allowed:
            raise UnsupportedAlgorithm(
                "DSA keys aren't supported in SSH certificates"
            )
    kformat = _lookup_kformat(key_type)

    try:
        rest = memoryview(binascii.a2b_base64(key_body))
    except (TypeError, binascii.Error):
        raise ValueError("Invalid format")

    if with_cert:
        cert_body = rest
    # the encoded body repeats the key type given on the text line
    inner_key_type, rest = _get_sshstr(rest)
    if inner_key_type != orig_key_type:
        raise ValueError("Invalid key format")
    if with_cert:
        nonce, rest = _get_sshstr(rest)
    public_key, rest = kformat.load_public(rest)
    if with_cert:
        serial, rest = _get_u64(rest)
        cctype, rest = _get_u32(rest)
        key_id, rest = _get_sshstr(rest)
        principals, rest = _get_sshstr(rest)
        valid_principals = []
        while principals:
            principal, principals = _get_sshstr(principals)
            valid_principals.append(bytes(principal))
        valid_after, rest = _get_u64(rest)
        valid_before, rest = _get_u64(rest)
        crit_options, rest = _get_sshstr(rest)
        critical_options = _parse_exts_opts(crit_options)
        exts, rest = _get_sshstr(rest)
        extensions = _parse_exts_opts(exts)
        # Get the reserved field, which is unused.
        _, rest = _get_sshstr(rest)
        sig_key_raw, rest = _get_sshstr(rest)
        sig_type, sig_key = _get_sshstr(sig_key_raw)
        if sig_type == _SSH_DSA and not _legacy_dsa_allowed:
            raise UnsupportedAlgorithm(
                "DSA signatures aren't supported in SSH certificates"
            )
        # The signed portion is everything before the signature field.
        # Slice by absolute length: ``cert_body[: -len(rest)]`` would give
        # an empty body when ``rest`` is empty (``[:-0]`` == ``[:0]``).
        tbs_cert_body = cert_body[: len(cert_body) - len(rest)]
        signature_raw, rest = _get_sshstr(rest)
        _check_empty(rest)
        inner_sig_type, sig_rest = _get_sshstr(signature_raw)
        # RSA certs can have multiple algorithm types
        if (
            sig_type == _SSH_RSA
            and inner_sig_type
            not in [_SSH_RSA_SHA256, _SSH_RSA_SHA512, _SSH_RSA]
        ) or (sig_type != _SSH_RSA and inner_sig_type != sig_type):
            raise ValueError("Signature key type does not match")
        signature, sig_rest = _get_sshstr(sig_rest)
        _check_empty(sig_rest)
        return SSHCertificate(
            nonce,
            public_key,
            serial,
            cctype,
            key_id,
            valid_principals,
            valid_after,
            valid_before,
            critical_options,
            extensions,
            sig_type,
            sig_key,
            inner_sig_type,
            signature,
            tbs_cert_body,
            orig_key_type,
            cert_body,
        )
    else:
        _check_empty(rest)
        return public_key
1124
1125
def load_ssh_public_identity(
    data: bytes,
) -> SSHCertificate | SSHPublicKeyTypes:
    """Load one OpenSSH public key or certificate line.

    Returns an :class:`SSHCertificate` for ``*-cert-v01@openssh.com`` lines
    and a public key object otherwise. Certificates for DSA keys, and
    certificates signed by a DSA CA key, are rejected with
    ``UnsupportedAlgorithm``.
    """
    identity = _load_ssh_public_identity(data, _legacy_dsa_allowed=False)
    return identity
1130
1131
def _parse_exts_opts(exts_opts: memoryview) -> dict[bytes, bytes]:
    """Decode a certificate critical-options or extensions blob.

    The blob is a sequence of (name, data) SSH string pairs. Names must be
    unique and must appear in non-decreasing lexical order. A non-empty
    data field must wrap exactly one inner SSH string, which becomes the
    value; an empty data field maps to ``b""``.
    """
    parsed: dict[bytes, bytes] = {}
    previous: bytes | None = None
    remaining = exts_opts
    while remaining:
        raw_name, remaining = _get_sshstr(remaining)
        key: bytes = bytes(raw_name)
        if key in parsed:
            raise ValueError("Duplicate name")
        if previous is not None and key < previous:
            raise ValueError("Fields not lexically sorted")
        data, remaining = _get_sshstr(remaining)
        if data:
            # Non-empty data is itself a single nested SSH string.
            data, trailing = _get_sshstr(data)
            if trailing:
                raise ValueError("Unexpected extra data after value")
        parsed[key] = bytes(data)
        previous = key
    return parsed
1150
1151
def ssh_key_fingerprint(
    key: SSHPublicKeyTypes,
    hash_algorithm: hashes.MD5 | hashes.SHA256,
) -> bytes:
    """Return the raw fingerprint digest of an SSH public key.

    The key is encoded in SSH wire format (key-type string followed by the
    type-specific public fields), and that encoding is hashed with
    ``hash_algorithm``.

    :raises TypeError: if ``hash_algorithm`` is not an MD5 or SHA256
        instance.
    """
    if not isinstance(hash_algorithm, (hashes.MD5, hashes.SHA256)):
        raise TypeError("hash_algorithm must be either MD5 or SHA256")

    ssh_type = _get_ssh_key_type(key)
    encoder = _lookup_kformat(ssh_type)

    wire = _FragList()
    wire.put_sshstr(ssh_type)
    encoder.encode_public(key, wire)

    digest = hashes.Hash(hash_algorithm)
    digest.update(wire.tobytes())
    return digest.finalize()
1172
1173
def load_ssh_public_key(
    data: utils.Buffer, backend: typing.Any = None
) -> SSHPublicKeyTypes:
    """Load a public key from an OpenSSH public key line.

    Certificate lines are accepted too; the certified public key is
    extracted and returned. DSA keys are permitted here but emit a
    ``DeprecatedIn40`` warning. ``backend`` is not used.
    """
    loaded = _load_ssh_public_identity(data, _legacy_dsa_allowed=True)
    public_key: SSHPublicKeyTypes = (
        loaded.public_key() if isinstance(loaded, SSHCertificate) else loaded
    )

    if isinstance(public_key, dsa.DSAPublicKey):
        warnings.warn(
            "SSH DSA keys are deprecated and will be removed in a future "
            "release.",
            utils.DeprecatedIn40,
            stacklevel=2,
        )
    return public_key
1192
1193
def serialize_ssh_public_key(public_key: SSHPublicKeyTypes) -> bytes:
    """Encode ``public_key`` as a one-line OpenSSH public key.

    The result is ``<key-type> <base64(wire blob)>`` with no trailing
    newline. The wire blob is the key-type string followed by the
    type-specific public fields. DSA keys emit a ``DeprecatedIn40`` warning.
    """
    if isinstance(public_key, dsa.DSAPublicKey):
        warnings.warn(
            "SSH DSA key support is deprecated and will be "
            "removed in a future release",
            utils.DeprecatedIn40,
            stacklevel=4,
        )
    ssh_type = _get_ssh_key_type(public_key)
    encoder = _lookup_kformat(ssh_type)

    blob = _FragList()
    blob.put_sshstr(ssh_type)
    encoder.encode_public(public_key, blob)

    encoded = binascii.b2a_base64(blob.tobytes()).strip()
    return b" ".join([ssh_type, encoded])
1212
1213
# Private key types accepted as the signing (CA) key by
# SSHCertificateBuilder.sign; this matches the isinstance check there.
SSHCertPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]


# This is an undocumented limit enforced in the openssh codebase for sshd and
# ssh-keygen, but it is undefined in the ssh certificates spec.
# SSHCertificateBuilder.valid_principals rejects lists longer than this.
_SSHKEY_CERT_MAX_PRINCIPALS = 256
1224
1225
class SSHCertificateBuilder:
    """Immutable builder for OpenSSH certificates.

    Each setter validates its argument and returns a new builder, leaving
    the receiver unchanged, so partially configured builders can be reused.
    The public key, serial, type, key id, principals and validity bounds
    may each be set only once. Critical options and extensions are
    accumulated with the ``add_*`` methods. Call :meth:`sign` with the CA
    private key to produce an :class:`SSHCertificate`.
    """

    def __init__(
        self,
        _public_key: SSHCertPublicKeyTypes | None = None,
        _serial: int | None = None,
        _type: SSHCertificateType | None = None,
        _key_id: bytes | None = None,
        _valid_principals: list[bytes] | None = None,
        _valid_for_all_principals: bool = False,
        _valid_before: int | None = None,
        _valid_after: int | None = None,
        _critical_options: list[tuple[bytes, bytes]] | None = None,
        _extensions: list[tuple[bytes, bytes]] | None = None,
    ):
        self._public_key = _public_key
        self._serial = _serial
        self._type = _type
        self._key_id = _key_id
        # List-valued state is always copied, so builders never share a
        # list with each other, with a default argument, or with a caller.
        # Otherwise, mutating a caller's list after a setter returned would
        # bypass that setter's validation.
        self._valid_principals: list[bytes] = (
            [] if _valid_principals is None else list(_valid_principals)
        )
        self._valid_for_all_principals = _valid_for_all_principals
        self._valid_before = _valid_before
        self._valid_after = _valid_after
        self._critical_options: list[tuple[bytes, bytes]] = (
            [] if _critical_options is None else list(_critical_options)
        )
        self._extensions: list[tuple[bytes, bytes]] = (
            [] if _extensions is None else list(_extensions)
        )

    def _replace(self, **changes: typing.Any) -> SSHCertificateBuilder:
        """Return a new builder with ``changes`` applied over this state."""
        state: dict[str, typing.Any] = {
            "_public_key": self._public_key,
            "_serial": self._serial,
            "_type": self._type,
            "_key_id": self._key_id,
            "_valid_principals": self._valid_principals,
            "_valid_for_all_principals": self._valid_for_all_principals,
            "_valid_before": self._valid_before,
            "_valid_after": self._valid_after,
            "_critical_options": self._critical_options,
            "_extensions": self._extensions,
        }
        state.update(changes)
        # An unknown key in ``changes`` fails loudly in the constructor.
        return SSHCertificateBuilder(**state)

    def public_key(
        self, public_key: SSHCertPublicKeyTypes
    ) -> SSHCertificateBuilder:
        """Set the public key being certified (EC, RSA or Ed25519).

        :raises TypeError: for any other key type.
        :raises ValueError: if a public key was already set.
        """
        if not isinstance(
            public_key,
            (
                ec.EllipticCurvePublicKey,
                rsa.RSAPublicKey,
                ed25519.Ed25519PublicKey,
            ),
        ):
            raise TypeError("Unsupported key type")
        if self._public_key is not None:
            raise ValueError("public_key already set")

        return self._replace(_public_key=public_key)

    def serial(self, serial: int) -> SSHCertificateBuilder:
        """Set the certificate serial number, in ``[0, 2**64)``.

        :raises TypeError: if ``serial`` is not an int.
        :raises ValueError: if it is out of range or already set.
        """
        if not isinstance(serial, int):
            raise TypeError("serial must be an integer")
        if not 0 <= serial < 2**64:
            raise ValueError("serial must be between 0 and 2**64")
        if self._serial is not None:
            raise ValueError("serial already set")

        return self._replace(_serial=serial)

    def type(self, type: SSHCertificateType) -> SSHCertificateBuilder:
        """Set the certificate type.

        :raises TypeError: if ``type`` is not an SSHCertificateType.
        :raises ValueError: if the type was already set.
        """
        if not isinstance(type, SSHCertificateType):
            raise TypeError("type must be an SSHCertificateType")
        if self._type is not None:
            raise ValueError("type already set")

        return self._replace(_type=type)

    def key_id(self, key_id: bytes) -> SSHCertificateBuilder:
        """Set the key identifier.

        :raises TypeError: if ``key_id`` is not bytes.
        :raises ValueError: if the key id was already set.
        """
        if not isinstance(key_id, bytes):
            raise TypeError("key_id must be bytes")
        if self._key_id is not None:
            raise ValueError("key_id already set")

        return self._replace(_key_id=key_id)

    def valid_principals(
        self, valid_principals: list[bytes]
    ) -> SSHCertificateBuilder:
        """Set the non-empty list of principals the certificate is valid for.

        The new builder stores its own copy of the list, so later changes
        to ``valid_principals`` by the caller do not affect it.

        :raises TypeError: if the list is empty or holds non-bytes items.
        :raises ValueError: if principals or ``valid_for_all_principals``
            were already set, or the list exceeds
            ``_SSHKEY_CERT_MAX_PRINCIPALS``.
        """
        if self._valid_for_all_principals:
            raise ValueError(
                "Principals can't be set because the cert is valid "
                "for all principals"
            )
        if (
            not all(isinstance(x, bytes) for x in valid_principals)
            or not valid_principals
        ):
            raise TypeError(
                "principals must be a list of bytes and can't be empty"
            )
        if self._valid_principals:
            raise ValueError("valid_principals already set")

        if len(valid_principals) > _SSHKEY_CERT_MAX_PRINCIPALS:
            raise ValueError(
                "Reached or exceeded the maximum number of valid_principals"
            )

        return self._replace(_valid_principals=valid_principals)

    def valid_for_all_principals(self) -> SSHCertificateBuilder:
        """Mark the certificate valid for any principal (empty list on wire).

        :raises ValueError: if principals were set, or this flag already was.
        """
        if self._valid_principals:
            raise ValueError(
                "valid_principals already set, can't set "
                "valid_for_all_principals"
            )
        if self._valid_for_all_principals:
            raise ValueError("valid_for_all_principals already set")

        return self._replace(_valid_for_all_principals=True)

    def valid_before(self, valid_before: int | float) -> SSHCertificateBuilder:
        """Set the upper validity bound; floats are truncated with ``int``.

        :raises TypeError: if not an int or float.
        :raises ValueError: if outside ``[0, 2**64)`` or already set.
        """
        if not isinstance(valid_before, (int, float)):
            raise TypeError("valid_before must be an int or float")
        valid_before = int(valid_before)
        if valid_before < 0 or valid_before >= 2**64:
            raise ValueError("valid_before must [0, 2**64)")
        if self._valid_before is not None:
            raise ValueError("valid_before already set")

        return self._replace(_valid_before=valid_before)

    def valid_after(self, valid_after: int | float) -> SSHCertificateBuilder:
        """Set the lower validity bound; floats are truncated with ``int``.

        :raises TypeError: if not an int or float.
        :raises ValueError: if outside ``[0, 2**64)`` or already set.
        """
        if not isinstance(valid_after, (int, float)):
            raise TypeError("valid_after must be an int or float")
        valid_after = int(valid_after)
        if valid_after < 0 or valid_after >= 2**64:
            raise ValueError("valid_after must [0, 2**64)")
        if self._valid_after is not None:
            raise ValueError("valid_after already set")

        return self._replace(_valid_after=valid_after)

    def add_critical_option(
        self, name: bytes, value: bytes
    ) -> SSHCertificateBuilder:
        """Append a critical option; names must be unique.

        :raises TypeError: if ``name`` or ``value`` is not bytes.
        :raises ValueError: if ``name`` was already added.
        """
        if not isinstance(name, bytes) or not isinstance(value, bytes):
            raise TypeError("name and value must be bytes")
        # Linear scan per call (quadratic across many additions).
        if any(existing == name for existing, _ in self._critical_options):
            raise ValueError("Duplicate critical option name")

        return self._replace(
            _critical_options=[*self._critical_options, (name, value)]
        )

    def add_extension(
        self, name: bytes, value: bytes
    ) -> SSHCertificateBuilder:
        """Append an extension; names must be unique.

        :raises TypeError: if ``name`` or ``value`` is not bytes.
        :raises ValueError: if ``name`` was already added.
        """
        if not isinstance(name, bytes) or not isinstance(value, bytes):
            raise TypeError("name and value must be bytes")
        # Linear scan per call (quadratic across many additions).
        if any(existing == name for existing, _ in self._extensions):
            raise ValueError("Duplicate extension name")

        return self._replace(_extensions=[*self._extensions, (name, value)])

    @staticmethod
    def _encode_exts_opts(items: list[tuple[bytes, bytes]]) -> bytes:
        """Encode (name, value) pairs as an options/extensions blob.

        A non-empty value is wrapped in an inner SSH string; an empty value
        is written as an empty SSH string.
        """
        out = _FragList()
        for name, value in items:
            out.put_sshstr(name)
            if len(value) > 0:
                inner = _FragList()
                inner.put_sshstr(value)
                out.put_sshstr(inner.tobytes())
            else:
                out.put_sshstr(value)
        return out.tobytes()

    def sign(self, private_key: SSHCertPrivateKeyTypes) -> SSHCertificate:
        """Encode the certificate, sign it with the CA key and return it.

        ``private_key`` is the CA key; its public half is embedded as the
        signature key. A fresh 32-byte nonce from ``os.urandom`` is
        included. Serial defaults to 0 and key id to ``b""`` when unset.
        The builder itself is not modified.

        :raises TypeError: if ``private_key`` is not an EC, RSA or Ed25519
            private key.
        :raises ValueError: if the public key, type, validity bounds or
            principals (or ``valid_for_all_principals``) are missing, or if
            ``valid_after`` is greater than ``valid_before``.
        """
        if not isinstance(
            private_key,
            (
                ec.EllipticCurvePrivateKey,
                rsa.RSAPrivateKey,
                ed25519.Ed25519PrivateKey,
            ),
        ):
            raise TypeError("Unsupported private key type")

        if self._public_key is None:
            raise ValueError("public_key must be set")

        # Not required
        serial = 0 if self._serial is None else self._serial

        if self._type is None:
            raise ValueError("type must be set")

        # Not required
        key_id = b"" if self._key_id is None else self._key_id

        # A zero length list is valid, but means the certificate
        # is valid for any principal of the specified type. We require
        # the user to explicitly set valid_for_all_principals to get
        # that behavior.
        if not self._valid_principals and not self._valid_for_all_principals:
            raise ValueError(
                "valid_principals must be set if valid_for_all_principals "
                "is False"
            )

        if self._valid_before is None:
            raise ValueError("valid_before must be set")

        if self._valid_after is None:
            raise ValueError("valid_after must be set")

        if self._valid_after > self._valid_before:
            raise ValueError("valid_after must be earlier than valid_before")

        # Names must be lexically sorted (the parser, _parse_exts_opts,
        # rejects unsorted input). Sort copies so signing leaves the
        # builder's own lists untouched.
        critical_options = sorted(self._critical_options, key=lambda x: x[0])
        extensions = sorted(self._extensions, key=lambda x: x[0])

        key_type = _get_ssh_key_type(self._public_key)
        cert_prefix = key_type + _CERT_SUFFIX

        # Marshal the bytes to be signed
        nonce = os.urandom(32)
        kformat = _lookup_kformat(key_type)
        f = _FragList()
        f.put_sshstr(cert_prefix)
        f.put_sshstr(nonce)
        kformat.encode_public(self._public_key, f)
        f.put_u64(serial)
        f.put_u32(self._type.value)
        f.put_sshstr(key_id)
        fprincipals = _FragList()
        for p in self._valid_principals:
            fprincipals.put_sshstr(p)
        f.put_sshstr(fprincipals.tobytes())
        f.put_u64(self._valid_after)
        f.put_u64(self._valid_before)
        f.put_sshstr(self._encode_exts_opts(critical_options))
        f.put_sshstr(self._encode_exts_opts(extensions))
        f.put_sshstr(b"")  # RESERVED FIELD
        # encode CA public key
        ca_type = _get_ssh_key_type(private_key)
        caformat = _lookup_kformat(ca_type)
        caf = _FragList()
        caf.put_sshstr(ca_type)
        caformat.encode_public(private_key.public_key(), caf)
        f.put_sshstr(caf.tobytes())
        # Sigs according to the rules defined for the CA's public key
        # (RFC4253 section 6.6 for ssh-rsa, RFC5656 for ECDSA,
        # and RFC8032 for Ed25519).
        if isinstance(private_key, ed25519.Ed25519PrivateKey):
            signature = private_key.sign(f.tobytes())
            fsig = _FragList()
            fsig.put_sshstr(ca_type)
            fsig.put_sshstr(signature)
            f.put_sshstr(fsig.tobytes())
        elif isinstance(private_key, ec.EllipticCurvePrivateKey):
            hash_alg = _get_ec_hash_alg(private_key.curve)
            signature = private_key.sign(f.tobytes(), ec.ECDSA(hash_alg))
            # SSH encodes ECDSA signatures as a blob of two mpints (r, s)
            # rather than the DER structure returned by sign().
            r, s = asym_utils.decode_dss_signature(signature)
            fsig = _FragList()
            fsig.put_sshstr(ca_type)
            fsigblob = _FragList()
            fsigblob.put_mpint(r)
            fsigblob.put_mpint(s)
            fsig.put_sshstr(fsigblob.tobytes())
            f.put_sshstr(fsig.tobytes())
        else:
            assert isinstance(private_key, rsa.RSAPrivateKey)
            # Just like Golang, we're going to use SHA512 for RSA
            # https://cs.opensource.google/go/x/crypto/+/refs/tags/
            # v0.4.0:ssh/certs.go;l=445
            # RFC 8332 defines SHA256 and 512 as options
            fsig = _FragList()
            fsig.put_sshstr(_SSH_RSA_SHA512)
            signature = private_key.sign(
                f.tobytes(), padding.PKCS1v15(), hashes.SHA512()
            )
            fsig.put_sshstr(signature)
            f.put_sshstr(fsig.tobytes())

        cert_data = binascii.b2a_base64(f.tobytes()).strip()
        # load_ssh_public_identity returns a union, but this is
        # guaranteed to be an SSHCertificate, so we cast to make
        # mypy happy.
        return typing.cast(
            SSHCertificate,
            load_ssh_public_identity(b"".join([cert_prefix, b" ", cert_data])),
        )