Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/cryptography/hazmat/primitives/serialization/ssh.py: 22%

415 statements  

« prev     ^ index     » next       coverage.py v7.0.1, created at 2022-12-25 06:11 +0000

# This file is dual licensed under the terms of the Apache License, Version
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.

4 

5 

6import binascii 

7import os 

8import re 

9import typing 

10from base64 import encodebytes as _base64_encode 

11 

12from cryptography import utils 

13from cryptography.exceptions import UnsupportedAlgorithm 

14from cryptography.hazmat.primitives.asymmetric import dsa, ec, ed25519, rsa 

15from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes 

16from cryptography.hazmat.primitives.serialization import ( 

17 Encoding, 

18 KeySerializationEncryption, 

19 NoEncryption, 

20 PrivateFormat, 

21 PublicFormat, 

22 _KeySerializationEncryption, 

23) 

24 

25try: 

26 from bcrypt import kdf as _bcrypt_kdf 

27 

28 _bcrypt_supported = True 

29except ImportError: 

30 _bcrypt_supported = False 

31 

32 def _bcrypt_kdf( 

33 password: bytes, 

34 salt: bytes, 

35 desired_key_bytes: int, 

36 rounds: int, 

37 ignore_few_rounds: bool = False, 

38 ) -> bytes: 

39 raise UnsupportedAlgorithm("Need bcrypt module") 

40 

41 

42_SSH_ED25519 = b"ssh-ed25519" 

43_SSH_RSA = b"ssh-rsa" 

44_SSH_DSA = b"ssh-dss" 

45_ECDSA_NISTP256 = b"ecdsa-sha2-nistp256" 

46_ECDSA_NISTP384 = b"ecdsa-sha2-nistp384" 

47_ECDSA_NISTP521 = b"ecdsa-sha2-nistp521" 

48_CERT_SUFFIX = b"-cert-v01@openssh.com" 

49 

50_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)") 

51_SK_MAGIC = b"openssh-key-v1\0" 

52_SK_START = b"-----BEGIN OPENSSH PRIVATE KEY-----" 

53_SK_END = b"-----END OPENSSH PRIVATE KEY-----" 

54_BCRYPT = b"bcrypt" 

55_NONE = b"none" 

56_DEFAULT_CIPHER = b"aes256-ctr" 

57_DEFAULT_ROUNDS = 16 

58 

59# re is only way to work on bytes-like data 

60_PEM_RC = re.compile(_SK_START + b"(.*?)" + _SK_END, re.DOTALL) 

61 

62# padding for max blocksize 

63_PADDING = memoryview(bytearray(range(1, 1 + 16))) 

64 

# ciphers that are actually used in key wrapping
# Each value is (algorithm class, key length in bytes, mode class,
# block/IV length in bytes).
_SSH_CIPHERS: typing.Dict[
    bytes,
    typing.Tuple[
        typing.Type[algorithms.AES],
        int,
        typing.Union[typing.Type[modes.CTR], typing.Type[modes.CBC]],
        int,
    ],
] = {
    b"aes256-ctr": (algorithms.AES, 32, modes.CTR, 16),
    b"aes256-cbc": (algorithms.AES, 32, modes.CBC, 16),
}

78 

# map local curve name to key type
_ECDSA_KEY_TYPE = {
    "secp256r1": _ECDSA_NISTP256,
    "secp384r1": _ECDSA_NISTP384,
    "secp521r1": _ECDSA_NISTP521,
}

85 

86 

def _ecdsa_key_type(public_key: ec.EllipticCurvePublicKey) -> bytes:
    """Return the SSH key type for an elliptic-curve public key.

    The curve name of ``public_key`` is looked up in ``_ECDSA_KEY_TYPE``.

    Raises:
        ValueError: if the curve has no entry in ``_ECDSA_KEY_TYPE``.
    """
    curve = public_key.curve
    if curve.name not in _ECDSA_KEY_TYPE:
        raise ValueError(
            f"Unsupported curve for ssh private key: {curve.name!r}"
        )
    return _ECDSA_KEY_TYPE[curve.name]

95 

96 

def _ssh_pem_encode(
    data: bytes,
    prefix: bytes = _SK_START + b"\n",
    suffix: bytes = _SK_END + b"\n",
) -> bytes:
    """Wrap ``data`` as base64 text between ``prefix`` and ``suffix``.

    By default the markers are the OpenSSH private key BEGIN/END lines,
    each followed by a newline.
    """
    return b"".join([prefix, _base64_encode(data), suffix])

103 

104 

105def _check_block_size(data: bytes, block_len: int) -> None: 

106 """Require data to be full blocks""" 

107 if not data or len(data) % block_len != 0: 

108 raise ValueError("Corrupt data: missing padding") 

109 

110 

111def _check_empty(data: bytes) -> None: 

112 """All data should have been parsed.""" 

113 if data: 

114 raise ValueError("Corrupt data: unparsed data") 

115 

116 

def _init_cipher(
    ciphername: bytes,
    password: typing.Optional[bytes],
    salt: bytes,
    rounds: int,
) -> Cipher[typing.Union[modes.CBC, modes.CTR]]:
    """Generate key + iv and return cipher.

    Args:
        ciphername: key into ``_SSH_CIPHERS``.
        password: passphrase fed to the KDF; must be non-empty.
        salt: KDF salt.
        rounds: KDF round count.

    Raises:
        ValueError: if ``password`` is ``None`` or empty.
        KeyError: if ``ciphername`` is not in ``_SSH_CIPHERS``.
    """
    if not password:
        raise ValueError("Key is password-protected.")

    algo, key_len, mode, iv_len = _SSH_CIPHERS[ciphername]
    # a single KDF call yields the key bytes followed by the IV bytes
    seed = _bcrypt_kdf(password, salt, key_len + iv_len, rounds, True)
    return Cipher(algo(seed[:key_len]), mode(seed[key_len:]))

130 

131 

132def _get_u32(data: memoryview) -> typing.Tuple[int, memoryview]: 

133 """Uint32""" 

134 if len(data) < 4: 

135 raise ValueError("Invalid data") 

136 return int.from_bytes(data[:4], byteorder="big"), data[4:] 

137 

138 

139def _get_u64(data: memoryview) -> typing.Tuple[int, memoryview]: 

140 """Uint64""" 

141 if len(data) < 8: 

142 raise ValueError("Invalid data") 

143 return int.from_bytes(data[:8], byteorder="big"), data[8:] 

144 

145 

def _get_sshstr(data: memoryview) -> typing.Tuple[memoryview, memoryview]:
    """Read bytes with a u32 length prefix.

    Returns the string body and the data that follows it, both as slices
    of ``data``.

    Raises:
        ValueError: if the prefix is truncated or announces more bytes
            than remain.
    """
    n, data = _get_u32(data)
    if n > len(data):
        raise ValueError("Invalid data")
    return data[:n], data[n:]

152 

153 

def _get_mpint(data: memoryview) -> typing.Tuple[int, memoryview]:
    """Read a length-prefixed big-endian big integer.

    Raises:
        ValueError: if the string is malformed or its first byte has the
            top bit set.
    """
    val, data = _get_sshstr(data)
    # a set top bit would mark a negative number, which is rejected
    if val and val[0] > 0x7F:
        raise ValueError("Invalid data")
    return int.from_bytes(val, "big"), data

160 

161 

def _to_mpint(val: int) -> bytes:
    """Storage format for signed bigint.

    Zero is stored as the empty string; other values as big-endian bytes.

    Raises:
        ValueError: if ``val`` is negative.
    """
    if val < 0:
        raise ValueError("negative mpint not allowed")
    if not val:
        return b""
    # "+ 8" rather than "+ 7" reserves one extra bit, so the top bit of the
    # first byte is always clear (a leading zero byte is added when needed)
    nbytes = (val.bit_length() + 8) // 8
    return utils.int_to_bytes(val, nbytes)

170 

171 

172class _FragList: 

173 """Build recursive structure without data copy.""" 

174 

175 flist: typing.List[bytes] 

176 

177 def __init__( 

178 self, init: typing.Optional[typing.List[bytes]] = None 

179 ) -> None: 

180 self.flist = [] 

181 if init: 

182 self.flist.extend(init) 

183 

184 def put_raw(self, val: bytes) -> None: 

185 """Add plain bytes""" 

186 self.flist.append(val) 

187 

188 def put_u32(self, val: int) -> None: 

189 """Big-endian uint32""" 

190 self.flist.append(val.to_bytes(length=4, byteorder="big")) 

191 

192 def put_sshstr(self, val: typing.Union[bytes, "_FragList"]) -> None: 

193 """Bytes prefixed with u32 length""" 

194 if isinstance(val, (bytes, memoryview, bytearray)): 

195 self.put_u32(len(val)) 

196 self.flist.append(val) 

197 else: 

198 self.put_u32(val.size()) 

199 self.flist.extend(val.flist) 

200 

201 def put_mpint(self, val: int) -> None: 

202 """Big-endian bigint prefixed with u32 length""" 

203 self.put_sshstr(_to_mpint(val)) 

204 

205 def size(self) -> int: 

206 """Current number of bytes""" 

207 return sum(map(len, self.flist)) 

208 

209 def render(self, dstbuf: memoryview, pos: int = 0) -> int: 

210 """Write into bytearray""" 

211 for frag in self.flist: 

212 flen = len(frag) 

213 start, pos = pos, pos + flen 

214 dstbuf[start:pos] = frag 

215 return pos 

216 

217 def tobytes(self) -> bytes: 

218 """Return as bytes""" 

219 buf = memoryview(bytearray(self.size())) 

220 self.render(buf) 

221 return buf.tobytes() 

222 

223 

class _SSHFormatRSA:
    """Format for RSA keys.

    Public:
        mpint e, n
    Private:
        mpint n, e, d, iqmp, p, q
    """

    def get_public(
        self, data: memoryview
    ) -> typing.Tuple[typing.Tuple, memoryview]:
        """RSA public fields"""
        e, data = _get_mpint(data)
        n, data = _get_mpint(data)
        return (e, n), data

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[rsa.RSAPublicKey, memoryview]:
        """Make RSA public key from data."""
        (e, n), data = self.get_public(data)
        public_numbers = rsa.RSAPublicNumbers(e, n)
        public_key = public_numbers.public_key()
        return public_key, data

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[rsa.RSAPrivateKey, memoryview]:
        """Make RSA private key from data.

        ``pubfields`` is the ``(e, n)`` tuple returned by ``get_public``;
        a ``ValueError`` is raised if the private section disagrees.
        """
        n, data = _get_mpint(data)
        e, data = _get_mpint(data)
        d, data = _get_mpint(data)
        iqmp, data = _get_mpint(data)
        p, data = _get_mpint(data)
        q, data = _get_mpint(data)

        if (e, n) != pubfields:
            raise ValueError("Corrupt data: rsa field mismatch")
        # the CRT exponents are not stored; recompute them from d, p and q
        dmp1 = rsa.rsa_crt_dmp1(d, p)
        dmq1 = rsa.rsa_crt_dmq1(d, q)
        public_numbers = rsa.RSAPublicNumbers(e, n)
        private_numbers = rsa.RSAPrivateNumbers(
            p, q, d, dmp1, dmq1, iqmp, public_numbers
        )
        private_key = private_numbers.private_key()
        return private_key, data

    def encode_public(
        self, public_key: rsa.RSAPublicKey, f_pub: _FragList
    ) -> None:
        """Write RSA public key"""
        pubn = public_key.public_numbers()
        f_pub.put_mpint(pubn.e)
        f_pub.put_mpint(pubn.n)

    def encode_private(
        self, private_key: rsa.RSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Write RSA private key"""
        private_numbers = private_key.private_numbers()
        public_numbers = private_numbers.public_numbers

        # note the private section stores n before e
        f_priv.put_mpint(public_numbers.n)
        f_priv.put_mpint(public_numbers.e)

        f_priv.put_mpint(private_numbers.d)
        f_priv.put_mpint(private_numbers.iqmp)
        f_priv.put_mpint(private_numbers.p)
        f_priv.put_mpint(private_numbers.q)

292 

293 

class _SSHFormatDSA:
    """Format for DSA keys.

    Public:
        mpint p, q, g, y
    Private:
        mpint p, q, g, y, x
    """

    def get_public(
        self, data: memoryview
    ) -> typing.Tuple[typing.Tuple, memoryview]:
        """DSA public fields"""
        p, data = _get_mpint(data)
        q, data = _get_mpint(data)
        g, data = _get_mpint(data)
        y, data = _get_mpint(data)
        return (p, q, g, y), data

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[dsa.DSAPublicKey, memoryview]:
        """Make DSA public key from data."""
        (p, q, g, y), data = self.get_public(data)
        parameter_numbers = dsa.DSAParameterNumbers(p, q, g)
        public_numbers = dsa.DSAPublicNumbers(y, parameter_numbers)
        self._validate(public_numbers)
        public_key = public_numbers.public_key()
        return public_key, data

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[dsa.DSAPrivateKey, memoryview]:
        """Make DSA private key from data.

        ``pubfields`` is the ``(p, q, g, y)`` tuple returned by
        ``get_public``; a ``ValueError`` is raised if the private section
        disagrees.
        """
        (p, q, g, y), data = self.get_public(data)
        x, data = _get_mpint(data)

        if (p, q, g, y) != pubfields:
            raise ValueError("Corrupt data: dsa field mismatch")
        parameter_numbers = dsa.DSAParameterNumbers(p, q, g)
        public_numbers = dsa.DSAPublicNumbers(y, parameter_numbers)
        self._validate(public_numbers)
        private_numbers = dsa.DSAPrivateNumbers(x, public_numbers)
        private_key = private_numbers.private_key()
        return private_key, data

    def encode_public(
        self, public_key: dsa.DSAPublicKey, f_pub: _FragList
    ) -> None:
        """Write DSA public key"""
        public_numbers = public_key.public_numbers()
        parameter_numbers = public_numbers.parameter_numbers
        self._validate(public_numbers)

        f_pub.put_mpint(parameter_numbers.p)
        f_pub.put_mpint(parameter_numbers.q)
        f_pub.put_mpint(parameter_numbers.g)
        f_pub.put_mpint(public_numbers.y)

    def encode_private(
        self, private_key: dsa.DSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Write DSA private key"""
        # the private section repeats the public fields, then adds x
        self.encode_public(private_key.public_key(), f_priv)
        f_priv.put_mpint(private_key.private_numbers().x)

    def _validate(self, public_numbers: dsa.DSAPublicNumbers) -> None:
        """Raise ValueError unless the modulus p is exactly 1024 bits."""
        parameter_numbers = public_numbers.parameter_numbers
        if parameter_numbers.p.bit_length() != 1024:
            raise ValueError("SSH supports only 1024 bit DSA keys")

364 

365 

class _SSHFormatECDSA:
    """Format for ECDSA keys.

    Public:
        str curve
        bytes point
    Private:
        str curve
        bytes point
        mpint secret
    """

    def __init__(self, ssh_curve_name: bytes, curve: ec.EllipticCurve):
        """Bind the SSH curve name to the matching curve object."""
        self.ssh_curve_name = ssh_curve_name
        self.curve = curve

    def get_public(
        self, data: memoryview
    ) -> typing.Tuple[typing.Tuple, memoryview]:
        """ECDSA public fields

        Raises:
            ValueError: if the curve name differs from ``ssh_curve_name``
                or the point is empty.
            NotImplementedError: if the point does not start with byte 4.
        """
        curve, data = _get_sshstr(data)
        point, data = _get_sshstr(data)
        if curve != self.ssh_curve_name:
            raise ValueError("Curve name mismatch")
        # an empty point would otherwise escape as IndexError from point[0]
        if not point:
            raise ValueError("Invalid data")
        if point[0] != 4:
            raise NotImplementedError("Need uncompressed point")
        return (curve, point), data

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[ec.EllipticCurvePublicKey, memoryview]:
        """Make ECDSA public key from data."""
        (_, point), data = self.get_public(data)
        public_key = ec.EllipticCurvePublicKey.from_encoded_point(
            self.curve, point.tobytes()
        )
        return public_key, data

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[ec.EllipticCurvePrivateKey, memoryview]:
        """Make ECDSA private key from data.

        ``pubfields`` is the ``(curve, point)`` tuple returned by
        ``get_public``; a ``ValueError`` is raised if the private section
        disagrees.
        """
        (curve_name, point), data = self.get_public(data)
        secret, data = _get_mpint(data)

        if (curve_name, point) != pubfields:
            raise ValueError("Corrupt data: ecdsa field mismatch")
        private_key = ec.derive_private_key(secret, self.curve)
        return private_key, data

    def encode_public(
        self, public_key: ec.EllipticCurvePublicKey, f_pub: _FragList
    ) -> None:
        """Write ECDSA public key"""
        point = public_key.public_bytes(
            Encoding.X962, PublicFormat.UncompressedPoint
        )
        f_pub.put_sshstr(self.ssh_curve_name)
        f_pub.put_sshstr(point)

    def encode_private(
        self, private_key: ec.EllipticCurvePrivateKey, f_priv: _FragList
    ) -> None:
        """Write ECDSA private key"""
        public_key = private_key.public_key()
        private_numbers = private_key.private_numbers()

        # the private section repeats the public fields, then the secret
        self.encode_public(public_key, f_priv)
        f_priv.put_mpint(private_numbers.private_value)

435 

436 

class _SSHFormatEd25519:
    """Format for Ed25519 keys.

    Public:
        bytes point
    Private:
        bytes point
        bytes secret_and_point
    """

    def get_public(
        self, data: memoryview
    ) -> typing.Tuple[typing.Tuple, memoryview]:
        """Ed25519 public fields"""
        point, data = _get_sshstr(data)
        return (point,), data

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[ed25519.Ed25519PublicKey, memoryview]:
        """Make Ed25519 public key from data."""
        (point,), data = self.get_public(data)
        public_key = ed25519.Ed25519PublicKey.from_public_bytes(
            point.tobytes()
        )
        return public_key, data

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[ed25519.Ed25519PrivateKey, memoryview]:
        """Make Ed25519 private key from data.

        ``pubfields`` is the ``(point,)`` tuple returned by ``get_public``.
        A ``ValueError`` is raised if it, the leading point and the point
        appended to the secret do not all agree.
        """
        (point,), data = self.get_public(data)
        keypair, data = _get_sshstr(data)

        # keypair is the 32-byte secret followed by a copy of the point
        secret = keypair[:32]
        point2 = keypair[32:]
        if point != point2 or (point,) != pubfields:
            raise ValueError("Corrupt data: ed25519 field mismatch")
        private_key = ed25519.Ed25519PrivateKey.from_private_bytes(secret)
        return private_key, data

    def encode_public(
        self, public_key: ed25519.Ed25519PublicKey, f_pub: _FragList
    ) -> None:
        """Write Ed25519 public key"""
        raw_public_key = public_key.public_bytes(
            Encoding.Raw, PublicFormat.Raw
        )
        f_pub.put_sshstr(raw_public_key)

    def encode_private(
        self, private_key: ed25519.Ed25519PrivateKey, f_priv: _FragList
    ) -> None:
        """Write Ed25519 private key"""
        public_key = private_key.public_key()
        raw_private_key = private_key.private_bytes(
            Encoding.Raw, PrivateFormat.Raw, NoEncryption()
        )
        raw_public_key = public_key.public_bytes(
            Encoding.Raw, PublicFormat.Raw
        )
        # secret and point are written as one length-prefixed string
        f_keypair = _FragList([raw_private_key, raw_public_key])

        self.encode_public(public_key, f_priv)
        f_priv.put_sshstr(f_keypair)

502 

503 

# map SSH key type to the format object that parses and writes its fields
_KEY_FORMATS = {
    _SSH_RSA: _SSHFormatRSA(),
    _SSH_DSA: _SSHFormatDSA(),
    _SSH_ED25519: _SSHFormatEd25519(),
    _ECDSA_NISTP256: _SSHFormatECDSA(b"nistp256", ec.SECP256R1()),
    _ECDSA_NISTP384: _SSHFormatECDSA(b"nistp384", ec.SECP384R1()),
    _ECDSA_NISTP521: _SSHFormatECDSA(b"nistp521", ec.SECP521R1()),
}

512 

513 

def _lookup_kformat(key_type: bytes):
    """Return valid format or throw error

    ``key_type`` may be any bytes-like object; it is converted to ``bytes``
    before the lookup in ``_KEY_FORMATS``.

    Raises:
        UnsupportedAlgorithm: if the key type has no registered format.
    """
    if not isinstance(key_type, bytes):
        key_type = memoryview(key_type).tobytes()
    if key_type in _KEY_FORMATS:
        return _KEY_FORMATS[key_type]
    raise UnsupportedAlgorithm(f"Unsupported key type: {key_type!r}")

521 

522 

# private key classes accepted and returned by the SSH (de)serializers
_SSH_PRIVATE_KEY_TYPES = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    dsa.DSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]

529 

530 

def load_ssh_private_key(
    data: bytes,
    password: typing.Optional[bytes],
    backend: typing.Any = None,
) -> _SSH_PRIVATE_KEY_TYPES:
    """Load private key from OpenSSH custom encoding.

    Args:
        data: bytes-like object containing the BEGIN/END OPENSSH PRIVATE
            KEY block.
        password: passphrase for an encrypted key, or ``None``. It is not
            consulted when the key is unencrypted.
        backend: not used by this function.

    Returns:
        The private key object built by the matching key format.

    Raises:
        ValueError: if the data is not in the expected format, is corrupt,
            or is encrypted and no password was given.
        UnsupportedAlgorithm: if the key type, cipher or KDF is unsupported.
    """
    utils._check_byteslike("data", data)
    if password is not None:
        utils._check_bytes("password", password)

    m = _PEM_RC.search(data)
    if not m:
        raise ValueError("Not OpenSSH private key format")
    p1 = m.start(1)
    p2 = m.end(1)
    data = binascii.a2b_base64(memoryview(data)[p1:p2])
    if not data.startswith(_SK_MAGIC):
        raise ValueError("Not OpenSSH private key format")
    data = memoryview(data)[len(_SK_MAGIC) :]

    # parse header
    ciphername, data = _get_sshstr(data)
    kdfname, data = _get_sshstr(data)
    kdfoptions, data = _get_sshstr(data)
    nkeys, data = _get_u32(data)
    if nkeys != 1:
        raise ValueError("Only one key supported")

    # load public key data
    pubdata, data = _get_sshstr(data)
    pub_key_type, pubdata = _get_sshstr(pubdata)
    kformat = _lookup_kformat(pub_key_type)
    pubfields, pubdata = kformat.get_public(pubdata)
    _check_empty(pubdata)

    # load secret data
    edata, data = _get_sshstr(data)
    _check_empty(data)

    if (ciphername, kdfname) != (_NONE, _NONE):
        ciphername_bytes = ciphername.tobytes()
        if ciphername_bytes not in _SSH_CIPHERS:
            raise UnsupportedAlgorithm(
                f"Unsupported cipher: {ciphername_bytes!r}"
            )
        if kdfname != _BCRYPT:
            # kdfname is a memoryview; convert it so the message shows the
            # name rather than "<memory at 0x...>"
            kdfname_bytes = kdfname.tobytes()
            raise UnsupportedAlgorithm(f"Unsupported KDF: {kdfname_bytes!r}")
        blklen = _SSH_CIPHERS[ciphername_bytes][3]
        _check_block_size(edata, blklen)
        # kdfoptions holds the salt string followed by the u32 round count
        salt, kbuf = _get_sshstr(kdfoptions)
        rounds, kbuf = _get_u32(kbuf)
        _check_empty(kbuf)
        ciph = _init_cipher(ciphername_bytes, password, salt.tobytes(), rounds)
        edata = memoryview(ciph.decryptor().update(edata))
    else:
        blklen = 8
        _check_block_size(edata, blklen)
    # the secret section opens with the same u32 check value written twice
    ck1, edata = _get_u32(edata)
    ck2, edata = _get_u32(edata)
    if ck1 != ck2:
        raise ValueError("Corrupt data: broken checksum")

    # load per-key struct
    key_type, edata = _get_sshstr(edata)
    if key_type != pub_key_type:
        raise ValueError("Corrupt data: key type mismatch")
    private_key, edata = kformat.load_private(edata, pubfields)
    # the comment is parsed only to step past it
    comment, edata = _get_sshstr(edata)

    # yes, SSH does padding check *after* all other parsing is done.
    # need to follow as it writes zero-byte padding too.
    if edata != _PADDING[: len(edata)]:
        raise ValueError("Corrupt data: invalid padding")

    return private_key

606 

607 

def _serialize_ssh_private_key(
    private_key: _SSH_PRIVATE_KEY_TYPES,
    password: bytes,
    encryption_algorithm: KeySerializationEncryption,
) -> bytes:
    """Serialize private key with OpenSSH custom encoding.

    Args:
        private_key: EC, RSA, DSA or Ed25519 private key.
        password: passphrase; an empty value produces an unencrypted key.
        encryption_algorithm: consulted only for an optional KDF round
            count (``_kdf_rounds``) when a password is given.

    Returns:
        The key wrapped in BEGIN/END OPENSSH PRIVATE KEY lines.

    Raises:
        ValueError: if the key type is not supported.
    """
    utils._check_bytes("password", password)

    if isinstance(private_key, ec.EllipticCurvePrivateKey):
        key_type = _ecdsa_key_type(private_key.public_key())
    elif isinstance(private_key, rsa.RSAPrivateKey):
        key_type = _SSH_RSA
    elif isinstance(private_key, dsa.DSAPrivateKey):
        key_type = _SSH_DSA
    elif isinstance(private_key, ed25519.Ed25519PrivateKey):
        key_type = _SSH_ED25519
    else:
        raise ValueError("Unsupported key type")
    kformat = _lookup_kformat(key_type)

    # setup parameters
    f_kdfoptions = _FragList()
    if password:
        ciphername = _DEFAULT_CIPHER
        blklen = _SSH_CIPHERS[ciphername][3]
        kdfname = _BCRYPT
        rounds = _DEFAULT_ROUNDS
        if (
            isinstance(encryption_algorithm, _KeySerializationEncryption)
            and encryption_algorithm._kdf_rounds is not None
        ):
            rounds = encryption_algorithm._kdf_rounds
        salt = os.urandom(16)
        f_kdfoptions.put_sshstr(salt)
        f_kdfoptions.put_u32(rounds)
        ciph = _init_cipher(ciphername, password, salt, rounds)
    else:
        ciphername = kdfname = _NONE
        blklen = 8
        ciph = None
    nkeys = 1
    checkval = os.urandom(4)
    comment = b""

    # encode public and private parts together
    f_public_key = _FragList()
    f_public_key.put_sshstr(key_type)
    kformat.encode_public(private_key.public_key(), f_public_key)

    # the secret section starts with the check value written twice
    f_secrets = _FragList([checkval, checkval])
    f_secrets.put_sshstr(key_type)
    kformat.encode_private(private_key, f_secrets)
    f_secrets.put_sshstr(comment)
    # pad to a multiple of blklen; a full block is added when already aligned
    f_secrets.put_raw(_PADDING[: blklen - (f_secrets.size() % blklen)])

    # top-level structure
    f_main = _FragList()
    f_main.put_raw(_SK_MAGIC)
    f_main.put_sshstr(ciphername)
    f_main.put_sshstr(kdfname)
    f_main.put_sshstr(f_kdfoptions)
    f_main.put_u32(nkeys)
    f_main.put_sshstr(f_public_key)
    f_main.put_sshstr(f_secrets)

    # copy result into bytearray
    slen = f_secrets.size()
    mlen = f_main.size()
    # blklen spare bytes make the output slice longer than the input below
    buf = memoryview(bytearray(mlen + blklen))
    f_main.render(buf)
    # the secret section is the last slen bytes of the rendered structure
    ofs = mlen - slen

    # encrypt in-place
    if ciph is not None:
        ciph.encryptor().update_into(buf[ofs:mlen], buf[ofs:])

    return _ssh_pem_encode(buf[:mlen])

685 

686 

# public key classes accepted and returned by the SSH (de)serializers
_SSH_PUBLIC_KEY_TYPES = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    dsa.DSAPublicKey,
    ed25519.Ed25519PublicKey,
]

693 

694 

def load_ssh_public_key(
    data: bytes, backend: typing.Any = None
) -> _SSH_PUBLIC_KEY_TYPES:
    """Load public key from OpenSSH one-line format.

    Key types ending in ``_CERT_SUFFIX`` are accepted as well: the
    certificate fields are parsed for structure only and the embedded
    public key is returned.

    Args:
        data: bytes-like object holding "<key type> <base64 body>".
        backend: not used by this function.

    Raises:
        ValueError: if the line or the decoded body is malformed.
        UnsupportedAlgorithm: if the key type is unsupported.
    """
    utils._check_byteslike("data", data)

    m = _SSH_PUBKEY_RC.match(data)
    if not m:
        raise ValueError("Invalid line format")
    key_type = orig_key_type = m.group(1)
    key_body = m.group(2)
    with_cert = False
    if _CERT_SUFFIX == key_type[-len(_CERT_SUFFIX) :]:
        with_cert = True
        # the key format is looked up by the type without the suffix
        key_type = key_type[: -len(_CERT_SUFFIX)]
    kformat = _lookup_kformat(key_type)

    try:
        rest = memoryview(binascii.a2b_base64(key_body))
    except (TypeError, binascii.Error):
        raise ValueError("Invalid key format")

    # the type inside the body must equal the full type on the line
    inner_key_type, rest = _get_sshstr(rest)
    if inner_key_type != orig_key_type:
        raise ValueError("Invalid key format")
    if with_cert:
        nonce, rest = _get_sshstr(rest)
    public_key, rest = kformat.load_public(rest)
    if with_cert:
        # certificate fields are read only to step past them; none of the
        # values is checked or used
        serial, rest = _get_u64(rest)
        cctype, rest = _get_u32(rest)
        key_id, rest = _get_sshstr(rest)
        principals, rest = _get_sshstr(rest)
        valid_after, rest = _get_u64(rest)
        valid_before, rest = _get_u64(rest)
        crit_options, rest = _get_sshstr(rest)
        extensions, rest = _get_sshstr(rest)
        reserved, rest = _get_sshstr(rest)
        sig_key, rest = _get_sshstr(rest)
        signature, rest = _get_sshstr(rest)
    _check_empty(rest)
    return public_key

737 

738 

def serialize_ssh_public_key(public_key: _SSH_PUBLIC_KEY_TYPES) -> bytes:
    """One-line public key format for OpenSSH

    Returns ``b"<key type> <base64 body>"`` without a trailing newline.

    Raises:
        ValueError: if the key type is not supported.
    """
    if isinstance(public_key, ec.EllipticCurvePublicKey):
        key_type = _ecdsa_key_type(public_key)
    elif isinstance(public_key, rsa.RSAPublicKey):
        key_type = _SSH_RSA
    elif isinstance(public_key, dsa.DSAPublicKey):
        key_type = _SSH_DSA
    elif isinstance(public_key, ed25519.Ed25519PublicKey):
        key_type = _SSH_ED25519
    else:
        raise ValueError("Unsupported key type")
    kformat = _lookup_kformat(key_type)

    # the body repeats the key type, followed by the public fields
    f_pub = _FragList()
    f_pub.put_sshstr(key_type)
    kformat.encode_public(public_key, f_pub)

    # strip() removes the newline that b2a_base64 appends
    pub = binascii.b2a_base64(f_pub.tobytes()).strip()
    return b"".join([key_type, b" ", pub])