Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/cryptography/hazmat/primitives/serialization/ssh.py: 22%
415 statements
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
6import binascii
7import os
8import re
9import typing
10from base64 import encodebytes as _base64_encode
12from cryptography import utils
13from cryptography.exceptions import UnsupportedAlgorithm
14from cryptography.hazmat.primitives.asymmetric import dsa, ec, ed25519, rsa
15from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
16from cryptography.hazmat.primitives.serialization import (
17 Encoding,
18 KeySerializationEncryption,
19 NoEncryption,
20 PrivateFormat,
21 PublicFormat,
22 _KeySerializationEncryption,
23)
try:
    from bcrypt import kdf as _bcrypt_kdf
except ImportError:
    _bcrypt_supported = False

    def _bcrypt_kdf(
        password: bytes,
        salt: bytes,
        desired_key_bytes: int,
        rounds: int,
        ignore_few_rounds: bool = False,
    ) -> bytes:
        """Placeholder installed when the optional ``bcrypt`` package is absent.

        Takes the same parameters as the bcrypt ``kdf`` function it stands in
        for, and always fails with UnsupportedAlgorithm.
        """
        raise UnsupportedAlgorithm("Need bcrypt module")

else:
    # bcrypt imported cleanly; _init_cipher can derive key material.
    _bcrypt_supported = True
# SSH key type identifiers, as they appear in a public key line and blob.
_SSH_ED25519 = b"ssh-ed25519"
_SSH_RSA = b"ssh-rsa"
_SSH_DSA = b"ssh-dss"
_ECDSA_NISTP256 = b"ecdsa-sha2-nistp256"
_ECDSA_NISTP384 = b"ecdsa-sha2-nistp384"
_ECDSA_NISTP521 = b"ecdsa-sha2-nistp521"
# Appended to a key type name to denote an OpenSSH certificate.
_CERT_SUFFIX = b"-cert-v01@openssh.com"

# Public key line: "<type> <base64-body>"; anything after the body is ignored.
_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)")

# openssh-key-v1 private key container framing and defaults.
_SK_MAGIC = b"openssh-key-v1\0"
_SK_START = b"-----BEGIN OPENSSH PRIVATE KEY-----"
_SK_END = b"-----END OPENSSH PRIVATE KEY-----"
_BCRYPT = b"bcrypt"
_NONE = b"none"
_DEFAULT_CIPHER = b"aes256-ctr"
_DEFAULT_ROUNDS = 16

# re is the only way to search bytes-like data; DOTALL lets the non-greedy
# body group span the line breaks inside the base64 payload.
_PEM_RC = re.compile(_SK_START + b"(.*?)" + _SK_END, re.DOTALL)

# Padding bytes 0x01..0x10: enough for the largest (16-byte) block size.
_PADDING = memoryview(bytearray(range(1, 17)))
# Ciphers that are actually used in key wrapping:
#   name -> (algorithm class, key length in bytes, mode class, IV length).
# The final entry also serves as the block length for padding checks.
_SSH_CIPHERS: typing.Dict[
    bytes,
    typing.Tuple[
        typing.Type[algorithms.AES],
        int,
        typing.Union[typing.Type[modes.CTR], typing.Type[modes.CBC]],
        int,
    ],
] = {
    b"aes256-ctr": (algorithms.AES, 32, modes.CTR, 16),
    b"aes256-cbc": (algorithms.AES, 32, modes.CBC, 16),
}

# cryptography curve name -> SSH ECDSA key type.
_ECDSA_KEY_TYPE = dict(
    secp256r1=_ECDSA_NISTP256,
    secp384r1=_ECDSA_NISTP384,
    secp521r1=_ECDSA_NISTP521,
)
def _ecdsa_key_type(public_key: ec.EllipticCurvePublicKey) -> bytes:
    """Return the SSH key type name for an EC public key.

    For example, a P-256 key maps to ``b"ecdsa-sha2-nistp256"``. Both the
    private- and public-key serializers call this with the public half of
    the key.

    :param public_key: EC public key whose curve selects the key type.
    :returns: SSH key type bytes taken from ``_ECDSA_KEY_TYPE``.
    :raises ValueError: if the key's curve has no entry in
        ``_ECDSA_KEY_TYPE``.
    """
    curve_name = public_key.curve.name
    key_type = _ECDSA_KEY_TYPE.get(curve_name)
    if key_type is None:
        # NOTE(review): the message says "private key" even when serializing
        # a public key; kept unchanged in case callers match on it.
        raise ValueError(
            f"Unsupported curve for ssh private key: {curve_name!r}"
        )
    return key_type
def _ssh_pem_encode(
    data: bytes,
    prefix: bytes = _SK_START + b"\n",
    suffix: bytes = _SK_END + b"\n",
) -> bytes:
    """Armour *data* as base64 between *prefix* and *suffix* lines.

    The body comes from ``base64.encodebytes``, which wraps its output into
    newline-terminated lines.
    """
    body = _base64_encode(data)
    return prefix + body + suffix
def _check_block_size(data: bytes, block_len: int) -> None:
    """Raise ValueError unless *data* is non-empty and whole blocks long."""
    if data and len(data) % block_len == 0:
        return
    raise ValueError("Corrupt data: missing padding")
def _check_empty(data: bytes) -> None:
    """Ensure the parser consumed everything; leftovers mean corrupt input."""
    if not data:
        return
    raise ValueError("Corrupt data: unparsed data")
def _init_cipher(
    ciphername: bytes,
    password: typing.Optional[bytes],
    salt: bytes,
    rounds: int,
) -> Cipher[typing.Union[modes.CBC, modes.CTR]]:
    """Derive key and IV with bcrypt-pbkdf and build the named cipher.

    :param ciphername: key into ``_SSH_CIPHERS`` (callers validate it).
    :param password: passphrase; None or empty is rejected.
    :param salt: KDF salt from the key file or freshly generated.
    :param rounds: bcrypt round count.
    :raises ValueError: if no password was supplied.
    """
    if password:
        cipher_cls, key_len, mode_cls, iv_len = _SSH_CIPHERS[ciphername]
        # One KDF call yields key || IV; the trailing True is bcrypt's
        # ignore_few_rounds flag.
        material = _bcrypt_kdf(password, salt, key_len + iv_len, rounds, True)
        key, iv = material[:key_len], material[key_len:]
        return Cipher(cipher_cls(key), mode_cls(iv))
    raise ValueError("Key is password-protected.")
def _get_u32(data: memoryview) -> typing.Tuple[int, memoryview]:
    """Split a big-endian uint32 off the front of *data*.

    Returns ``(value, remainder)``; raises ValueError on short input.
    """
    width = 4
    if len(data) < width:
        raise ValueError("Invalid data")
    head, tail = data[:width], data[width:]
    return int.from_bytes(head, "big"), tail
def _get_u64(data: memoryview) -> typing.Tuple[int, memoryview]:
    """Split a big-endian uint64 off the front of *data*.

    Returns ``(value, remainder)``; raises ValueError on short input.
    """
    width = 8
    if len(data) < width:
        raise ValueError("Invalid data")
    head, tail = data[:width], data[width:]
    return int.from_bytes(head, "big"), tail
def _get_sshstr(data: memoryview) -> typing.Tuple[memoryview, memoryview]:
    """Read a uint32-length-prefixed byte string from the front of *data*.

    Returns ``(payload, remainder)`` as zero-copy views; raises ValueError
    when the declared length exceeds the remaining input.
    """
    length, rest = _get_u32(data)
    if length <= len(rest):
        return rest[:length], rest[length:]
    raise ValueError("Invalid data")
def _get_mpint(data: memoryview) -> typing.Tuple[int, memoryview]:
    """Read a non-negative SSH mpint from the front of *data*.

    Returns ``(value, remainder)``. An empty string decodes as 0.
    """
    raw, rest = _get_sshstr(data)
    # A set high bit in the first byte would make the value negative, which
    # none of the key formats allow.
    if len(raw) > 0 and raw[0] >= 0x80:
        raise ValueError("Invalid data")
    return int.from_bytes(raw, byteorder="big"), rest
def _to_mpint(val: int) -> bytes:
    """Encode a non-negative integer as the payload of an SSH ``mpint``.

    Zero becomes the empty string. Otherwise the value is written big-endian
    in the fewest bytes that keep the top bit clear, so a leading 0x00 is
    added whenever the most significant byte would be >= 0x80.

    :raises ValueError: if *val* is negative.
    """
    if val < 0:
        raise ValueError("negative mpint not allowed")
    if not val:
        return b""
    # "+ 8" rather than "+ 7" reserves one extra bit for the sign, which
    # produces the leading zero byte when the high bit is set.
    nbytes = (val.bit_length() + 8) // 8
    return val.to_bytes(nbytes, byteorder="big")
# Byte fragments _FragList stores: bytes plus zero-copy views and buffers.
_FragBytes = typing.Union[bytes, bytearray, memoryview]


class _FragList:
    """Build recursive structure without data copy.

    Fragments are kept as a list of bytes-like objects and only concatenated
    by ``render``/``tobytes``. A nested ``_FragList`` can be embedded as a
    length-prefixed string via ``put_sshstr``.
    """

    flist: typing.List[_FragBytes]

    def __init__(self, init: typing.Optional[typing.List[_FragBytes]] = None) -> None:
        # Copy the initial fragments so later appends never mutate the
        # caller's list.
        self.flist = []
        if init:
            self.flist.extend(init)

    def put_raw(self, val: _FragBytes) -> None:
        """Add plain bytes"""
        self.flist.append(val)

    def put_u32(self, val: int) -> None:
        """Append *val* as a big-endian uint32.

        ``int.to_bytes`` raises OverflowError if *val* is negative or does
        not fit in 4 bytes.
        """
        self.flist.append(val.to_bytes(length=4, byteorder="big"))

    def put_sshstr(self, val: typing.Union[_FragBytes, "_FragList"]) -> None:
        """Append *val* prefixed with its uint32 length.

        *val* is either bytes-like, or another ``_FragList`` whose fragments
        are spliced in without copying.
        """
        if isinstance(val, (bytes, memoryview, bytearray)):
            self.put_u32(len(val))
            self.flist.append(val)
        else:
            self.put_u32(val.size())
            self.flist.extend(val.flist)

    def put_mpint(self, val: int) -> None:
        """Big-endian bigint prefixed with u32 length"""
        self.put_sshstr(_to_mpint(val))

    def size(self) -> int:
        """Current number of bytes"""
        return sum(map(len, self.flist))

    def render(self, dstbuf: memoryview, pos: int = 0) -> int:
        """Copy all fragments into *dstbuf* starting at *pos*.

        *dstbuf* must have room for ``size()`` bytes after *pos*. Returns the
        offset just past the last byte written.
        """
        for frag in self.flist:
            flen = len(frag)
            start, pos = pos, pos + flen
            dstbuf[start:pos] = frag
        return pos

    def tobytes(self) -> bytes:
        """Return the concatenated fragments as one ``bytes`` object."""
        buf = memoryview(bytearray(self.size()))
        self.render(buf)
        return buf.tobytes()
class _SSHFormatRSA:
    """Format for RSA keys.

    Public:
        mpint e, n
    Private:
        mpint n, e, d, iqmp, p, q
    """

    def get_public(self, data: memoryview):
        """Read the public exponent and modulus.

        Returns ``((e, n), remainder)``.
        """
        exponent, rest = _get_mpint(data)
        modulus, rest = _get_mpint(rest)
        return (exponent, modulus), rest

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[rsa.RSAPublicKey, memoryview]:
        """Build an RSA public key from the fields at the front of *data*."""
        (exponent, modulus), rest = self.get_public(data)
        key = rsa.RSAPublicNumbers(exponent, modulus).public_key()
        return key, rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[rsa.RSAPrivateKey, memoryview]:
        """Build an RSA private key from the private section.

        *pubfields* is the ``(e, n)`` pair parsed from the public section;
        the private section must repeat the same values.
        """
        # Six mpints in wire order: n, e, d, iqmp, p, q.
        values = []
        rest = data
        for _ in range(6):
            value, rest = _get_mpint(rest)
            values.append(value)
        n, e, d, iqmp, p, q = values

        if pubfields != (e, n):
            raise ValueError("Corrupt data: rsa field mismatch")
        # The wire format omits the CRT exponents; derive them from d, p, q.
        dmp1 = rsa.rsa_crt_dmp1(d, p)
        dmq1 = rsa.rsa_crt_dmq1(d, q)
        private_numbers = rsa.RSAPrivateNumbers(
            p, q, d, dmp1, dmq1, iqmp, rsa.RSAPublicNumbers(e, n)
        )
        return private_numbers.private_key(), rest

    def encode_public(
        self, public_key: rsa.RSAPublicKey, f_pub: _FragList
    ) -> None:
        """Append e then n as mpints."""
        numbers = public_key.public_numbers()
        for value in (numbers.e, numbers.n):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: rsa.RSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Append n, e, d, iqmp, p, q as mpints."""
        priv = private_key.private_numbers()
        pub = priv.public_numbers
        for value in (pub.n, pub.e, priv.d, priv.iqmp, priv.p, priv.q):
            f_priv.put_mpint(value)
class _SSHFormatDSA:
    """Format for DSA keys.

    Public:
        mpint p, q, g, y
    Private:
        mpint p, q, g, y, x
    """

    def get_public(
        self, data: memoryview
    ) -> typing.Tuple[typing.Tuple, memoryview]:
        """Read the four public mpints; returns ``((p, q, g, y), remainder)``."""
        p, rest = _get_mpint(data)
        q, rest = _get_mpint(rest)
        g, rest = _get_mpint(rest)
        y, rest = _get_mpint(rest)
        return (p, q, g, y), rest

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[dsa.DSAPublicKey, memoryview]:
        """Build a (1024-bit) DSA public key from *data*."""
        (p, q, g, y), rest = self.get_public(data)
        public_numbers = dsa.DSAPublicNumbers(
            y, dsa.DSAParameterNumbers(p, q, g)
        )
        self._validate(public_numbers)
        return public_numbers.public_key(), rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[dsa.DSAPrivateKey, memoryview]:
        """Build a DSA private key; public fields must equal *pubfields*."""
        fields, rest = self.get_public(data)
        x, rest = _get_mpint(rest)

        if fields != pubfields:
            raise ValueError("Corrupt data: dsa field mismatch")
        p, q, g, y = fields
        public_numbers = dsa.DSAPublicNumbers(
            y, dsa.DSAParameterNumbers(p, q, g)
        )
        self._validate(public_numbers)
        return dsa.DSAPrivateNumbers(x, public_numbers).private_key(), rest

    def encode_public(
        self, public_key: dsa.DSAPublicKey, f_pub: _FragList
    ) -> None:
        """Append p, q, g, y after checking the key size."""
        numbers = public_key.public_numbers()
        params = numbers.parameter_numbers
        self._validate(numbers)
        for value in (params.p, params.q, params.g, numbers.y):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: dsa.DSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Append the public fields followed by the private exponent x."""
        self.encode_public(private_key.public_key(), f_priv)
        f_priv.put_mpint(private_key.private_numbers().x)

    def _validate(self, public_numbers: dsa.DSAPublicNumbers) -> None:
        """Reject keys whose modulus p is not exactly 1024 bits."""
        if public_numbers.parameter_numbers.p.bit_length() == 1024:
            return
        raise ValueError("SSH supports only 1024 bit DSA keys")
class _SSHFormatECDSA:
    """Format for ECDSA keys.

    Public:
        str curve
        bytes point
    Private:
        str curve
        bytes point
        mpint secret

    Constructed with the SSH curve identifier (e.g. ``b"nistp256"``) and the
    matching cryptography curve object.
    """

    def __init__(self, ssh_curve_name: bytes, curve: ec.EllipticCurve):
        self.ssh_curve_name = ssh_curve_name
        self.curve = curve

    def get_public(
        self, data: memoryview
    ) -> typing.Tuple[typing.Tuple, memoryview]:
        """Read the curve name and point.

        Returns ``((curve_name, point), remainder)``.

        :raises ValueError: on a curve-name mismatch or an empty point.
        :raises NotImplementedError: if the point is not in uncompressed
            (0x04-prefixed) form.
        """
        curve, data = _get_sshstr(data)
        point, data = _get_sshstr(data)
        if curve != self.ssh_curve_name:
            raise ValueError("Curve name mismatch")
        # Check before indexing: a zero-length point string is corrupt input
        # and must be reported as ValueError, not IndexError.
        if not point:
            raise ValueError("Invalid data")
        if point[0] != 4:
            raise NotImplementedError("Need uncompressed point")
        return (curve, point), data

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[ec.EllipticCurvePublicKey, memoryview]:
        """Make ECDSA public key from data."""
        (curve_name, point), data = self.get_public(data)
        public_key = ec.EllipticCurvePublicKey.from_encoded_point(
            self.curve, point.tobytes()
        )
        return public_key, data

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[ec.EllipticCurvePrivateKey, memoryview]:
        """Make ECDSA private key from data.

        The curve/point pair must equal *pubfields* from the public section.
        The key is derived from the secret scalar alone.
        """
        (curve_name, point), data = self.get_public(data)
        secret, data = _get_mpint(data)

        if (curve_name, point) != pubfields:
            raise ValueError("Corrupt data: ecdsa field mismatch")
        private_key = ec.derive_private_key(secret, self.curve)
        return private_key, data

    def encode_public(
        self, public_key: ec.EllipticCurvePublicKey, f_pub: _FragList
    ) -> None:
        """Append the SSH curve name and the uncompressed X9.62 point."""
        point = public_key.public_bytes(
            Encoding.X962, PublicFormat.UncompressedPoint
        )
        f_pub.put_sshstr(self.ssh_curve_name)
        f_pub.put_sshstr(point)

    def encode_private(
        self, private_key: ec.EllipticCurvePrivateKey, f_priv: _FragList
    ) -> None:
        """Append the public fields followed by the secret scalar."""
        public_key = private_key.public_key()
        private_numbers = private_key.private_numbers()

        self.encode_public(public_key, f_priv)
        f_priv.put_mpint(private_numbers.private_value)
class _SSHFormatEd25519:
    """Format for Ed25519 keys.

    Public:
        bytes point
    Private:
        bytes point
        bytes secret_and_point
    """

    def get_public(
        self, data: memoryview
    ) -> typing.Tuple[typing.Tuple, memoryview]:
        """Read the public point; returns ``((point,), remainder)``."""
        point, rest = _get_sshstr(data)
        return (point,), rest

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[ed25519.Ed25519PublicKey, memoryview]:
        """Build an Ed25519 public key from the raw point."""
        (point,), rest = self.get_public(data)
        key = ed25519.Ed25519PublicKey.from_public_bytes(point.tobytes())
        return key, rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[ed25519.Ed25519PrivateKey, memoryview]:
        """Build an Ed25519 private key from the private section.

        The keypair string holds the 32-byte secret followed by a copy of
        the public point; that copy and *pubfields* must both match.
        """
        (point,), rest = self.get_public(data)
        keypair, rest = _get_sshstr(rest)

        secret, point_copy = keypair[:32], keypair[32:]
        if point_copy != point or pubfields != (point,):
            raise ValueError("Corrupt data: ed25519 field mismatch")
        key = ed25519.Ed25519PrivateKey.from_private_bytes(secret)
        return key, rest

    def encode_public(
        self, public_key: ed25519.Ed25519PublicKey, f_pub: _FragList
    ) -> None:
        """Append the raw public point as a string."""
        f_pub.put_sshstr(public_key.public_bytes(Encoding.Raw, PublicFormat.Raw))

    def encode_private(
        self, private_key: ed25519.Ed25519PrivateKey, f_priv: _FragList
    ) -> None:
        """Append the public point, then the secret||point keypair string."""
        public_key = private_key.public_key()
        seed = private_key.private_bytes(
            Encoding.Raw, PrivateFormat.Raw, NoEncryption()
        )
        point = public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)

        self.encode_public(public_key, f_priv)
        f_priv.put_sshstr(_FragList([seed, point]))
# SSH key type name -> wire-format handler.
_KEY_FORMATS = {
    _SSH_RSA: _SSHFormatRSA(),
    _SSH_DSA: _SSHFormatDSA(),
    _SSH_ED25519: _SSHFormatEd25519(),
}
# One ECDSA handler per supported NIST curve (insertion order preserved).
_KEY_FORMATS.update(
    (key_type, _SSHFormatECDSA(ssh_curve, curve))
    for key_type, ssh_curve, curve in (
        (_ECDSA_NISTP256, b"nistp256", ec.SECP256R1()),
        (_ECDSA_NISTP384, b"nistp384", ec.SECP384R1()),
        (_ECDSA_NISTP521, b"nistp521", ec.SECP521R1()),
    )
)
def _lookup_kformat(key_type: bytes):
    """Return the wire-format handler registered for *key_type*.

    Any bytes-like value (e.g. a memoryview slice) is converted to bytes
    first. Raises UnsupportedAlgorithm for unknown key types.
    """
    name = key_type if isinstance(key_type, bytes) else memoryview(key_type).tobytes()
    kformat = _KEY_FORMATS.get(name)
    if kformat is None:
        raise UnsupportedAlgorithm(f"Unsupported key type: {name!r}")
    return kformat
# Private key classes this module can load and serialize.
_SSH_PRIVATE_KEY_TYPES = typing.Union[
    ec.EllipticCurvePrivateKey, rsa.RSAPrivateKey,
    dsa.DSAPrivateKey, ed25519.Ed25519PrivateKey,
]
def load_ssh_private_key(
    data: bytes,
    password: typing.Optional[bytes],
    backend: typing.Any = None,
) -> _SSH_PRIVATE_KEY_TYPES:
    """Load private key from OpenSSH custom encoding.

    :param data: bytes-like buffer containing an
        ``-----BEGIN OPENSSH PRIVATE KEY-----`` block; surrounding text is
        ignored.
    :param password: passphrase for an encrypted key, or None. It is not
        consulted when the key is stored unencrypted.
    :param backend: unused.
    :returns: the decoded private key.
    :raises ValueError: if the container is malformed, does not hold exactly
        one key, fails the checksum or padding checks (a wrong password
        typically shows up as a checksum failure), or is encrypted and no
        password was given.
    :raises UnsupportedAlgorithm: for an unknown key type, cipher or KDF, or
        when bcrypt is required but not installed.
    """
    utils._check_byteslike("data", data)
    if password is not None:
        utils._check_bytes("password", password)

    m = _PEM_RC.search(data)
    if not m:
        raise ValueError("Not OpenSSH private key format")
    p1 = m.start(1)
    p2 = m.end(1)
    data = binascii.a2b_base64(memoryview(data)[p1:p2])
    if not data.startswith(_SK_MAGIC):
        raise ValueError("Not OpenSSH private key format")
    data = memoryview(data)[len(_SK_MAGIC) :]

    # parse header
    ciphername, data = _get_sshstr(data)
    kdfname, data = _get_sshstr(data)
    kdfoptions, data = _get_sshstr(data)
    nkeys, data = _get_u32(data)
    if nkeys != 1:
        raise ValueError("Only one key supported")

    # load public key data
    pubdata, data = _get_sshstr(data)
    pub_key_type, pubdata = _get_sshstr(pubdata)
    kformat = _lookup_kformat(pub_key_type)
    pubfields, pubdata = kformat.get_public(pubdata)
    _check_empty(pubdata)

    # load secret data
    edata, data = _get_sshstr(data)
    _check_empty(data)

    if (ciphername, kdfname) != (_NONE, _NONE):
        ciphername_bytes = ciphername.tobytes()
        if ciphername_bytes not in _SSH_CIPHERS:
            raise UnsupportedAlgorithm(
                f"Unsupported cipher: {ciphername_bytes!r}"
            )
        if kdfname != _BCRYPT:
            # Convert the memoryview so the message names the KDF instead of
            # showing "<memory at 0x...>".
            raise UnsupportedAlgorithm(
                f"Unsupported KDF: {kdfname.tobytes()!r}"
            )
        blklen = _SSH_CIPHERS[ciphername_bytes][3]
        _check_block_size(edata, blklen)
        salt, kbuf = _get_sshstr(kdfoptions)
        rounds, kbuf = _get_u32(kbuf)
        _check_empty(kbuf)
        ciph = _init_cipher(ciphername_bytes, password, salt.tobytes(), rounds)
        edata = memoryview(ciph.decryptor().update(edata))
    else:
        blklen = 8
        _check_block_size(edata, blklen)
    # Two copies of the same random check value; a mismatch means corrupt
    # data or (for encrypted keys) a bad decryption.
    ck1, edata = _get_u32(edata)
    ck2, edata = _get_u32(edata)
    if ck1 != ck2:
        raise ValueError("Corrupt data: broken checksum")

    # load per-key struct
    key_type, edata = _get_sshstr(edata)
    if key_type != pub_key_type:
        raise ValueError("Corrupt data: key type mismatch")
    private_key, edata = kformat.load_private(edata, pubfields)
    comment, edata = _get_sshstr(edata)

    # yes, SSH does padding check *after* all other parsing is done.
    # need to follow as it writes zero-byte padding too.
    if edata != _PADDING[: len(edata)]:
        raise ValueError("Corrupt data: invalid padding")

    return private_key
def _serialize_ssh_private_key(
    private_key: _SSH_PRIVATE_KEY_TYPES,
    password: bytes,
    encryption_algorithm: KeySerializationEncryption,
) -> bytes:
    """Serialize private key with OpenSSH custom encoding.

    :param private_key: EC, RSA, DSA or Ed25519 private key.
    :param password: passphrase; empty bytes writes the key unencrypted.
    :param encryption_algorithm: only consulted for a custom bcrypt round
        count (``_kdf_rounds``) when *password* is non-empty.
    :returns: PEM-armoured ``openssh-key-v1`` blob.
    :raises ValueError: for an unsupported key type or EC curve.
    """
    utils._check_bytes("password", password)

    if isinstance(private_key, ec.EllipticCurvePrivateKey):
        key_type = _ecdsa_key_type(private_key.public_key())
    elif isinstance(private_key, rsa.RSAPrivateKey):
        key_type = _SSH_RSA
    elif isinstance(private_key, dsa.DSAPrivateKey):
        key_type = _SSH_DSA
    elif isinstance(private_key, ed25519.Ed25519PrivateKey):
        key_type = _SSH_ED25519
    else:
        raise ValueError("Unsupported key type")
    kformat = _lookup_kformat(key_type)

    # setup parameters
    f_kdfoptions = _FragList()
    if password:
        ciphername = _DEFAULT_CIPHER
        blklen = _SSH_CIPHERS[ciphername][3]
        kdfname = _BCRYPT
        rounds = _DEFAULT_ROUNDS
        if (
            isinstance(encryption_algorithm, _KeySerializationEncryption)
            and encryption_algorithm._kdf_rounds is not None
        ):
            rounds = encryption_algorithm._kdf_rounds
        salt = os.urandom(16)
        f_kdfoptions.put_sshstr(salt)
        f_kdfoptions.put_u32(rounds)
        ciph = _init_cipher(ciphername, password, salt, rounds)
    else:
        ciphername = kdfname = _NONE
        blklen = 8
        ciph = None
    nkeys = 1
    checkval = os.urandom(4)
    comment = b""

    # encode public and private parts together
    f_public_key = _FragList()
    f_public_key.put_sshstr(key_type)
    kformat.encode_public(private_key.public_key(), f_public_key)

    f_secrets = _FragList([checkval, checkval])
    f_secrets.put_sshstr(key_type)
    kformat.encode_private(private_key, f_secrets)
    f_secrets.put_sshstr(comment)
    # Pad with 1, 2, 3, ... only up to the next multiple of the block size.
    # Like OpenSSH, nothing is added when the section is already aligned.
    pad_len = -f_secrets.size() % blklen
    f_secrets.put_raw(_PADDING[:pad_len])

    # top-level structure
    f_main = _FragList()
    f_main.put_raw(_SK_MAGIC)
    f_main.put_sshstr(ciphername)
    f_main.put_sshstr(kdfname)
    f_main.put_sshstr(f_kdfoptions)
    f_main.put_u32(nkeys)
    f_main.put_sshstr(f_public_key)
    f_main.put_sshstr(f_secrets)

    # Copy the result into a bytearray. The secrets are the final field, so
    # they occupy the tail starting at ofs. The extra blklen bytes presumably
    # give update_into the output slack it requires beyond the input length.
    slen = f_secrets.size()
    mlen = f_main.size()
    buf = memoryview(bytearray(mlen + blklen))
    f_main.render(buf)
    ofs = mlen - slen

    # encrypt in-place
    if ciph is not None:
        ciph.encryptor().update_into(buf[ofs:mlen], buf[ofs:])

    return _ssh_pem_encode(buf[:mlen])
# Public key classes this module can load and serialize.
_SSH_PUBLIC_KEY_TYPES = typing.Union[
    ec.EllipticCurvePublicKey, rsa.RSAPublicKey,
    dsa.DSAPublicKey, ed25519.Ed25519PublicKey,
]
def load_ssh_public_key(
    data: bytes, backend: typing.Any = None
) -> _SSH_PUBLIC_KEY_TYPES:
    """Parse an OpenSSH one-line public key (``<type> <base64> [comment]``).

    ``*-cert-v01@openssh.com`` certificate lines are accepted. Only the
    embedded public key is returned; the certificate fields are read past
    and discarded. *backend* is unused.

    :raises ValueError: on a malformed line, base64 body or blob.
    :raises UnsupportedAlgorithm: for an unknown key type.
    """
    utils._check_byteslike("data", data)

    match = _SSH_PUBKEY_RC.match(data)
    if match is None:
        raise ValueError("Invalid line format")
    orig_key_type, key_body = match.group(1, 2)

    with_cert = orig_key_type.endswith(_CERT_SUFFIX)
    key_type = orig_key_type[: -len(_CERT_SUFFIX)] if with_cert else orig_key_type
    kformat = _lookup_kformat(key_type)

    try:
        rest = memoryview(binascii.a2b_base64(key_body))
    except (TypeError, binascii.Error):
        raise ValueError("Invalid key format")

    # The blob repeats the (possibly cert-suffixed) type from the line.
    inner_key_type, rest = _get_sshstr(rest)
    if inner_key_type != orig_key_type:
        raise ValueError("Invalid key format")

    if not with_cert:
        public_key, rest = kformat.load_public(rest)
    else:
        _nonce, rest = _get_sshstr(rest)
        public_key, rest = kformat.load_public(rest)
        # Remaining certificate fields, consumed in order and ignored.
        cert_field_readers = (
            _get_u64,  # serial
            _get_u32,  # certificate type
            _get_sshstr,  # key id
            _get_sshstr,  # principals
            _get_u64,  # valid after
            _get_u64,  # valid before
            _get_sshstr,  # critical options
            _get_sshstr,  # extensions
            _get_sshstr,  # reserved
            _get_sshstr,  # signature key
            _get_sshstr,  # signature
        )
        for read_field in cert_field_readers:
            _skipped, rest = read_field(rest)
    _check_empty(rest)
    return public_key
def serialize_ssh_public_key(public_key: _SSH_PUBLIC_KEY_TYPES) -> bytes:
    """Render *public_key* as an OpenSSH ``<type> <base64>`` line.

    No comment field and no trailing newline are written.

    :raises ValueError: for unsupported key classes or EC curves.
    """
    if isinstance(public_key, ec.EllipticCurvePublicKey):
        key_type = _ecdsa_key_type(public_key)
    else:
        for key_cls, ssh_name in (
            (rsa.RSAPublicKey, _SSH_RSA),
            (dsa.DSAPublicKey, _SSH_DSA),
            (ed25519.Ed25519PublicKey, _SSH_ED25519),
        ):
            if isinstance(public_key, key_cls):
                key_type = ssh_name
                break
        else:
            raise ValueError("Unsupported key type")
    kformat = _lookup_kformat(key_type)

    blob = _FragList()
    blob.put_sshstr(key_type)
    kformat.encode_public(public_key, blob)

    encoded = binascii.b2a_base64(blob.tobytes(), newline=False)
    return key_type + b" " + encoded