Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/cryptography/hazmat/primitives/serialization/ssh.py: 21%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

796 statements  

1# This file is dual licensed under the terms of the Apache License, Version 

2# 2.0, and the BSD License. See the LICENSE file in the root of this repository 

3# for complete details. 

4 

5from __future__ import annotations 

6 

7import binascii 

8import enum 

9import os 

10import re 

11import typing 

12import warnings 

13from base64 import encodebytes as _base64_encode 

14from dataclasses import dataclass 

15 

16from cryptography import utils 

17from cryptography.exceptions import UnsupportedAlgorithm 

18from cryptography.hazmat.primitives import hashes 

19from cryptography.hazmat.primitives.asymmetric import ( 

20 dsa, 

21 ec, 

22 ed25519, 

23 padding, 

24 rsa, 

25) 

26from cryptography.hazmat.primitives.asymmetric import utils as asym_utils 

27from cryptography.hazmat.primitives.ciphers import ( 

28 AEADDecryptionContext, 

29 Cipher, 

30 algorithms, 

31 modes, 

32) 

33from cryptography.hazmat.primitives.serialization import ( 

34 Encoding, 

35 KeySerializationEncryption, 

36 NoEncryption, 

37 PrivateFormat, 

38 PublicFormat, 

39 _KeySerializationEncryption, 

40) 

41 

# bcrypt is an optional dependency. When it cannot be imported, a stand-in
# with the same parameters is defined that raises as soon as it is called.
try:
    from bcrypt import kdf as _bcrypt_kdf
except ImportError:
    _bcrypt_supported = False

    def _bcrypt_kdf(
        password: bytes,
        salt: bytes,
        desired_key_bytes: int,
        rounds: int,
        ignore_few_rounds: bool = False,
    ) -> bytes:
        """Placeholder used when the bcrypt module is not importable."""
        raise UnsupportedAlgorithm("Need bcrypt module")

else:
    _bcrypt_supported = True

57 

58 

# Key type identifiers
_SSH_ED25519 = b"ssh-ed25519"
_SSH_RSA = b"ssh-rsa"
_SSH_DSA = b"ssh-dss"
_ECDSA_NISTP256 = b"ecdsa-sha2-nistp256"
_ECDSA_NISTP384 = b"ecdsa-sha2-nistp384"
_ECDSA_NISTP521 = b"ecdsa-sha2-nistp521"
# Suffix that turns a key type into the matching certificate type
_CERT_SUFFIX = b"-cert-v01@openssh.com"

# U2F application string suffixed pubkey
_SK_SSH_ED25519 = b"sk-ssh-ed25519@openssh.com"
_SK_SSH_ECDSA_NISTP256 = b"sk-ecdsa-sha2-nistp256@openssh.com"

# These are not key types, only algorithms, so they cannot appear
# as a public key type
_SSH_RSA_SHA256 = b"rsa-sha2-256"
_SSH_RSA_SHA512 = b"rsa-sha2-512"

# Public key line: key type, blanks, then the encoded key body
_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)")
# Magic prefix of the decoded private key blob and its PEM-style markers
_SK_MAGIC = b"openssh-key-v1\0"
_SK_START = b"-----BEGIN OPENSSH PRIVATE KEY-----"
_SK_END = b"-----END OPENSSH PRIVATE KEY-----"
# KDF and cipher names as they appear in the private key header
_BCRYPT = b"bcrypt"
_NONE = b"none"
# Defaults used when serializing an encrypted private key
_DEFAULT_CIPHER = b"aes256-ctr"
_DEFAULT_ROUNDS = 16

# A regex is used because it is the only way to search bytes-like data
_PEM_RC = re.compile(_SK_START + b"(.*?)" + _SK_END, re.DOTALL)

# Padding bytes 1, 2, ..., 16: enough for the largest block size
_PADDING = memoryview(bytearray(range(1, 17)))

90 

91 

@dataclass
class _SSHCipher:
    """Parameters of one cipher usable for private key encryption."""

    # Algorithm class; called with the key bytes to build the algorithm.
    alg: type[algorithms.AES]
    # Number of key bytes taken from the derived key material.
    key_len: int
    # Mode class; called with the bytes that follow the key.
    mode: type[modes.CTR] | type[modes.CBC] | type[modes.GCM]
    # Encrypted data must be a whole number of blocks of this size.
    block_len: int
    # Number of IV bytes requested from the KDF in addition to the key.
    iv_len: int
    # Expected authentication tag length, or None when there is no tag.
    tag_len: int | None
    # Whether decryption is finalized with an authentication tag.
    is_aead: bool

101 

102 

# Ciphers that are actually used in key wrapping, keyed by the cipher name
# stored in the private key header.
_SSH_CIPHERS: dict[bytes, _SSHCipher] = {
    b"aes256-ctr": _SSHCipher(
        alg=algorithms.AES,
        mode=modes.CTR,
        key_len=32,
        iv_len=16,
        block_len=16,
        tag_len=None,
        is_aead=False,
    ),
    b"aes256-cbc": _SSHCipher(
        alg=algorithms.AES,
        mode=modes.CBC,
        key_len=32,
        iv_len=16,
        block_len=16,
        tag_len=None,
        is_aead=False,
    ),
    b"aes256-gcm@openssh.com": _SSHCipher(
        alg=algorithms.AES,
        mode=modes.GCM,
        key_len=32,
        iv_len=12,
        block_len=16,
        tag_len=16,
        is_aead=True,
    ),
}

133 

# Map the local curve name (the curve object's ``name``) to the SSH key type
_ECDSA_KEY_TYPE = {
    "secp256r1": _ECDSA_NISTP256,
    "secp384r1": _ECDSA_NISTP384,
    "secp521r1": _ECDSA_NISTP521,
}

140 

141 

def _get_ssh_key_type(key: SSHPrivateKeyTypes | SSHPublicKeyTypes) -> bytes:
    """Return the SSH key type name for a private or public key object.

    Raises ValueError when the key is not an EC, RSA, DSA or Ed25519 key
    (or, for EC keys, when the curve has no SSH key type).
    """
    # Private EC keys are mapped through their public half.
    if isinstance(key, ec.EllipticCurvePrivateKey):
        return _ecdsa_key_type(key.public_key())
    if isinstance(key, ec.EllipticCurvePublicKey):
        return _ecdsa_key_type(key)
    if isinstance(key, (rsa.RSAPrivateKey, rsa.RSAPublicKey)):
        return _SSH_RSA
    if isinstance(key, (dsa.DSAPrivateKey, dsa.DSAPublicKey)):
        return _SSH_DSA
    if isinstance(
        key, (ed25519.Ed25519PrivateKey, ed25519.Ed25519PublicKey)
    ):
        return _SSH_ED25519
    raise ValueError("Unsupported key type")

159 

160 

def _ecdsa_key_type(public_key: ec.EllipticCurvePublicKey) -> bytes:
    """Return the SSH key type for the curve of an ECDSA public key.

    Raises ValueError when the curve name has no entry in
    ``_ECDSA_KEY_TYPE``.
    """
    curve_name = public_key.curve.name
    key_type = _ECDSA_KEY_TYPE.get(curve_name)
    if key_type is None:
        raise ValueError(
            f"Unsupported curve for ssh private key: {curve_name!r}"
        )
    return key_type

169 

170 

def _ssh_pem_encode(
    data: utils.Buffer,
    prefix: bytes = _SK_START + b"\n",
    suffix: bytes = _SK_END + b"\n",
) -> bytes:
    """Return ``data`` base64-encoded between ``prefix`` and ``suffix``.

    The defaults are the OpenSSH private key begin and end marker lines.
    """
    encoded_body = _base64_encode(data)
    return b"".join((prefix, encoded_body, suffix))

177 

178 

def _check_block_size(data: utils.Buffer, block_len: int) -> None:
    """Require data to be a non-empty whole number of blocks.

    Raises ValueError when ``data`` is empty or its length is not a
    multiple of ``block_len``.
    """
    is_whole_blocks = bool(data) and len(data) % block_len == 0
    if not is_whole_blocks:
        raise ValueError("Corrupt data: missing padding")

183 

184 

def _check_empty(data: utils.Buffer) -> None:
    """Raise ValueError if any data is left over after parsing."""
    if not data:
        return
    raise ValueError("Corrupt data: unparsed data")

189 

190 

def _init_cipher(
    ciphername: bytes,
    password: bytes | None,
    salt: bytes,
    rounds: int,
) -> Cipher[modes.CBC | modes.CTR | modes.GCM]:
    """Derive key and IV from the password and return the cipher.

    ``ciphername`` must be a key of ``_SSH_CIPHERS``. Raises TypeError when
    ``password`` is ``None`` or empty.
    """
    if not password:
        raise TypeError(
            "Key is password-protected, but password was not provided."
        )

    params = _SSH_CIPHERS[ciphername]
    # One KDF call yields the key followed by the IV. The trailing True is
    # the ``ignore_few_rounds`` argument.
    seed = _bcrypt_kdf(
        password, salt, params.key_len + params.iv_len, rounds, True
    )
    key_bytes = seed[: params.key_len]
    iv_bytes = seed[params.key_len :]
    algorithm = params.alg(key_bytes)
    mode = params.mode(iv_bytes)
    return Cipher(algorithm, mode)

211 

212 

def _get_u32(data: memoryview) -> tuple[int, memoryview]:
    """Read a big-endian uint32; return (value, remaining data).

    Raises ValueError when fewer than 4 bytes are available.
    """
    if len(data) < 4:
        raise ValueError("Invalid data")
    head, rest = data[:4], data[4:]
    return int.from_bytes(head, "big"), rest

218 

219 

def _get_u64(data: memoryview) -> tuple[int, memoryview]:
    """Read a big-endian uint64; return (value, remaining data).

    Raises ValueError when fewer than 8 bytes are available.
    """
    if len(data) < 8:
        raise ValueError("Invalid data")
    head, rest = data[:8], data[8:]
    return int.from_bytes(head, "big"), rest

225 

226 

def _get_sshstr(data: memoryview) -> tuple[memoryview, memoryview]:
    """Read bytes with a u32 length prefix; return (string, remaining data).

    Raises ValueError when the prefix is missing or announces more bytes
    than are available.
    """
    length, payload = _get_u32(data)
    if length > len(payload):
        raise ValueError("Invalid data")
    return payload[:length], payload[length:]

233 

234 

def _get_mpint(data: memoryview) -> tuple[int, memoryview]:
    """Read a length-prefixed big integer; return (value, remaining data).

    Raises ValueError when the first byte of the encoded value is above
    0x7F.
    """
    raw, rest = _get_sshstr(data)
    # Values whose leading byte has the high bit set are rejected;
    # _to_mpint never produces them.
    if raw and raw[0] > 0x7F:
        raise ValueError("Invalid data")
    return int.from_bytes(raw, "big"), rest

241 

242 

def _to_mpint(val: int) -> bytes:
    """Encode a non-negative int in the storage format for signed bigints.

    Zero encodes as the empty string. Raises ValueError for negative
    values.
    """
    if val < 0:
        raise ValueError("negative mpint not allowed")
    if val == 0:
        return b""
    # Always leave room for a clear top bit: a value that fills its last
    # byte completely gets one extra leading zero byte.
    nbytes = val.bit_length() // 8 + 1
    return val.to_bytes(nbytes, "big")

251 

252 

class _FragList:
    """Build recursive structure without data copy.

    Fragments are collected in ``flist`` and only copied once, when the
    list is rendered into a destination buffer.
    """

    flist: list[utils.Buffer]

    def __init__(self, init: list[utils.Buffer] | None = None) -> None:
        """Start with a copy of ``init``, or empty when it is not given."""
        self.flist = list(init) if init else []

    def put_raw(self, val: utils.Buffer) -> None:
        """Add plain bytes"""
        self.flist.append(val)

    def put_u32(self, val: int) -> None:
        """Big-endian uint32"""
        self.flist.append(val.to_bytes(4, "big"))

    def put_u64(self, val: int) -> None:
        """Big-endian uint64"""
        self.flist.append(val.to_bytes(8, "big"))

    def put_sshstr(self, val: bytes | _FragList) -> None:
        """Bytes prefixed with u32 length

        A nested fragment list is spliced in fragment by fragment, so its
        content is not copied.
        """
        if isinstance(val, (bytes, memoryview, bytearray)):
            self.put_u32(len(val))
            self.flist.append(val)
            return
        self.put_u32(val.size())
        self.flist.extend(val.flist)

    def put_mpint(self, val: int) -> None:
        """Big-endian bigint prefixed with u32 length"""
        encoded = _to_mpint(val)
        self.put_sshstr(encoded)

    def size(self) -> int:
        """Current number of bytes"""
        return sum(len(frag) for frag in self.flist)

    def render(self, dstbuf: memoryview, pos: int = 0) -> int:
        """Write all fragments into ``dstbuf`` starting at ``pos``.

        Returns the position just past the last byte written.
        """
        for frag in self.flist:
            end = pos + len(frag)
            dstbuf[pos:end] = frag
            pos = end
        return pos

    def tobytes(self) -> bytes:
        """Return as bytes"""
        out = memoryview(bytearray(self.size()))
        self.render(out)
        return out.tobytes()

305 

306 

class _SSHFormatRSA:
    """Format for RSA keys.

    Public:
        mpint e, n
    Private:
        mpint n, e, d, iqmp, p, q
    """

    def get_public(
        self, data: memoryview
    ) -> tuple[tuple[int, int], memoryview]:
        """RSA public fields: return ((e, n), remaining data)."""
        exponent, rest = _get_mpint(data)
        modulus, rest = _get_mpint(rest)
        return (exponent, modulus), rest

    def load_public(
        self, data: memoryview
    ) -> tuple[rsa.RSAPublicKey, memoryview]:
        """Make RSA public key from data."""
        (e, n), rest = self.get_public(data)
        return rsa.RSAPublicNumbers(e, n).public_key(), rest

    def load_private(
        self, data: memoryview, pubfields, unsafe_skip_rsa_key_validation: bool
    ) -> tuple[rsa.RSAPrivateKey, memoryview]:
        """Make RSA private key from data.

        Raises ValueError when (e, n) read here differ from ``pubfields``.
        """
        # Fields are stored in the order n, e, d, iqmp, p, q.
        fields = []
        for _ in range(6):
            value, data = _get_mpint(data)
            fields.append(value)
        n, e, d, iqmp, p, q = fields

        if (e, n) != pubfields:
            raise ValueError("Corrupt data: rsa field mismatch")
        # The CRT exponents are not stored; they are derived from d, p, q.
        dmp1 = rsa.rsa_crt_dmp1(d, p)
        dmq1 = rsa.rsa_crt_dmq1(d, q)
        private_numbers = rsa.RSAPrivateNumbers(
            p, q, d, dmp1, dmq1, iqmp, rsa.RSAPublicNumbers(e, n)
        )
        private_key = private_numbers.private_key(
            unsafe_skip_rsa_key_validation=unsafe_skip_rsa_key_validation
        )
        return private_key, data

    def encode_public(
        self, public_key: rsa.RSAPublicKey, f_pub: _FragList
    ) -> None:
        """Write RSA public key"""
        numbers = public_key.public_numbers()
        f_pub.put_mpint(numbers.e)
        f_pub.put_mpint(numbers.n)

    def encode_private(
        self, private_key: rsa.RSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Write RSA private key"""
        priv = private_key.private_numbers()
        pub = priv.public_numbers
        for value in (pub.n, pub.e, priv.d, priv.iqmp, priv.p, priv.q):
            f_priv.put_mpint(value)

379 

380 

class _SSHFormatDSA:
    """Format for DSA keys.

    Public:
        mpint p, q, g, y
    Private:
        mpint p, q, g, y, x
    """

    def get_public(self, data: memoryview) -> tuple[tuple, memoryview]:
        """DSA public fields: return ((p, q, g, y), remaining data)."""
        fields = []
        for _ in range(4):
            value, data = _get_mpint(data)
            fields.append(value)
        return tuple(fields), data

    def load_public(
        self, data: memoryview
    ) -> tuple[dsa.DSAPublicKey, memoryview]:
        """Make DSA public key from data."""
        (p, q, g, y), rest = self.get_public(data)
        numbers = dsa.DSAPublicNumbers(y, dsa.DSAParameterNumbers(p, q, g))
        self._validate(numbers)
        return numbers.public_key(), rest

    def load_private(
        self, data: memoryview, pubfields, unsafe_skip_rsa_key_validation: bool
    ) -> tuple[dsa.DSAPrivateKey, memoryview]:
        """Make DSA private key from data.

        Raises ValueError when the public fields read here differ from
        ``pubfields``. ``unsafe_skip_rsa_key_validation`` is not used.
        """
        (p, q, g, y), rest = self.get_public(data)
        x, rest = _get_mpint(rest)

        if (p, q, g, y) != pubfields:
            raise ValueError("Corrupt data: dsa field mismatch")
        public_numbers = dsa.DSAPublicNumbers(
            y, dsa.DSAParameterNumbers(p, q, g)
        )
        self._validate(public_numbers)
        private_key = dsa.DSAPrivateNumbers(x, public_numbers).private_key()
        return private_key, rest

    def encode_public(
        self, public_key: dsa.DSAPublicKey, f_pub: _FragList
    ) -> None:
        """Write DSA public key"""
        numbers = public_key.public_numbers()
        params = numbers.parameter_numbers
        # Validate before anything is written to the fragment list.
        self._validate(numbers)

        for value in (params.p, params.q, params.g, numbers.y):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: dsa.DSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Write DSA private key"""
        # The private encoding is the public fields followed by x.
        self.encode_public(private_key.public_key(), f_priv)
        f_priv.put_mpint(private_key.private_numbers().x)

    def _validate(self, public_numbers: dsa.DSAPublicNumbers) -> None:
        """Raise ValueError unless the parameter p is exactly 1024 bits."""
        p = public_numbers.parameter_numbers.p
        if p.bit_length() != 1024:
            raise ValueError("SSH supports only 1024 bit DSA keys")

449 

450 

class _SSHFormatECDSA:
    """Format for ECDSA keys.

    Public:
        str curve
        bytes point
    Private:
        str curve
        bytes point
        mpint secret
    """

    def __init__(self, ssh_curve_name: bytes, curve: ec.EllipticCurve):
        """Bind the SSH curve name to the curve object used for loading."""
        self.ssh_curve_name = ssh_curve_name
        self.curve = curve

    def get_public(
        self, data: memoryview
    ) -> tuple[tuple[memoryview, memoryview], memoryview]:
        """ECDSA public fields: return ((curve, point), remaining data).

        Raises ValueError for a curve name mismatch or an empty point, and
        NotImplementedError when the point's first byte is not 4.
        """
        curve, rest = _get_sshstr(data)
        point, rest = _get_sshstr(rest)
        if curve != self.ssh_curve_name:
            raise ValueError("Curve name mismatch")
        if not len(point):
            raise ValueError("Invalid EC point: empty data")
        # Only points whose leading byte is 4 (uncompressed) are accepted.
        if point[0] != 4:
            raise NotImplementedError("Need uncompressed point")
        return (curve, point), rest

    def load_public(
        self, data: memoryview
    ) -> tuple[ec.EllipticCurvePublicKey, memoryview]:
        """Make ECDSA public key from data."""
        (_, point), rest = self.get_public(data)
        encoded_point = point.tobytes()
        public_key = ec.EllipticCurvePublicKey.from_encoded_point(
            self.curve, encoded_point
        )
        return public_key, rest

    def load_private(
        self, data: memoryview, pubfields, unsafe_skip_rsa_key_validation: bool
    ) -> tuple[ec.EllipticCurvePrivateKey, memoryview]:
        """Make ECDSA private key from data.

        Raises ValueError when the curve name and point read here differ
        from ``pubfields``. ``unsafe_skip_rsa_key_validation`` is not used.
        """
        public_fields, rest = self.get_public(data)
        secret, rest = _get_mpint(rest)

        if public_fields != pubfields:
            raise ValueError("Corrupt data: ecdsa field mismatch")
        return ec.derive_private_key(secret, self.curve), rest

    def encode_public(
        self, public_key: ec.EllipticCurvePublicKey, f_pub: _FragList
    ) -> None:
        """Write ECDSA public key"""
        encoded_point = public_key.public_bytes(
            Encoding.X962, PublicFormat.UncompressedPoint
        )
        f_pub.put_sshstr(self.ssh_curve_name)
        f_pub.put_sshstr(encoded_point)

    def encode_private(
        self, private_key: ec.EllipticCurvePrivateKey, f_priv: _FragList
    ) -> None:
        """Write ECDSA private key"""
        public_key = private_key.public_key()
        secret = private_key.private_numbers().private_value

        # The private encoding is the public fields followed by the secret.
        self.encode_public(public_key, f_priv)
        f_priv.put_mpint(secret)

522 

523 

class _SSHFormatEd25519:
    """Format for Ed25519 keys.

    Public:
        bytes point
    Private:
        bytes point
        bytes secret_and_point
    """

    def get_public(
        self, data: memoryview
    ) -> tuple[tuple[memoryview], memoryview]:
        """Ed25519 public fields: return ((point,), remaining data)."""
        point, rest = _get_sshstr(data)
        return (point,), rest

    def load_public(
        self, data: memoryview
    ) -> tuple[ed25519.Ed25519PublicKey, memoryview]:
        """Make Ed25519 public key from data."""
        (point,), rest = self.get_public(data)
        raw_point = point.tobytes()
        return ed25519.Ed25519PublicKey.from_public_bytes(raw_point), rest

    def load_private(
        self, data: memoryview, pubfields, unsafe_skip_rsa_key_validation: bool
    ) -> tuple[ed25519.Ed25519PrivateKey, memoryview]:
        """Make Ed25519 private key from data.

        Raises ValueError when the point stored after the secret, or the
        point read here, does not match the public fields.
        ``unsafe_skip_rsa_key_validation`` is not used.
        """
        (point,), rest = self.get_public(data)
        keypair, rest = _get_sshstr(rest)

        # The keypair string is the 32-byte secret followed by the point.
        secret, trailing_point = keypair[:32], keypair[32:]
        if point != trailing_point or (point,) != pubfields:
            raise ValueError("Corrupt data: ed25519 field mismatch")
        private_key = ed25519.Ed25519PrivateKey.from_private_bytes(secret)
        return private_key, rest

    def encode_public(
        self, public_key: ed25519.Ed25519PublicKey, f_pub: _FragList
    ) -> None:
        """Write Ed25519 public key"""
        f_pub.put_sshstr(
            public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)
        )

    def encode_private(
        self, private_key: ed25519.Ed25519PrivateKey, f_priv: _FragList
    ) -> None:
        """Write Ed25519 private key"""
        public_key = private_key.public_key()
        raw_secret = private_key.private_bytes(
            Encoding.Raw, PrivateFormat.Raw, NoEncryption()
        )
        raw_point = public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)
        # Secret and point are written together as one length-prefixed
        # string.
        keypair = _FragList([raw_secret, raw_point])

        self.encode_public(public_key, f_priv)
        f_priv.put_sshstr(keypair)

589 

590 

def load_application(data: memoryview) -> tuple[memoryview, memoryview]:
    """
    U2F application strings

    Reads one length-prefixed string from ``data`` and returns
    ``(application, remaining data)``.

    Raises ValueError when the string is malformed or does not start with
    ``b"ssh:"``.
    """
    application, data = _get_sshstr(data)
    application_bytes = application.tobytes()
    if not application_bytes.startswith(b"ssh:"):
        # Format the bytes, not the memoryview: a memoryview renders as
        # "<memory at 0x...>", which hides the offending value.
        raise ValueError(
            "U2F application string does not start with b'ssh:' "
            f"({application_bytes!r})"
        )
    return application, data

602 

603 

class _SSHFormatSKEd25519:
    """
    The format of a sk-ssh-ed25519@openssh.com public key is:

        string          "sk-ssh-ed25519@openssh.com"
        string          public key
        string          application (user-specified, but typically "ssh:")
    """

    def load_public(
        self, data: memoryview
    ) -> tuple[ed25519.Ed25519PublicKey, memoryview]:
        """Make Ed25519 public key from data."""
        # The key itself uses the plain ssh-ed25519 encoding; the
        # application string that follows is checked and then discarded.
        ed25519_format = _lookup_kformat(_SSH_ED25519)
        public_key, rest = ed25519_format.load_public(data)
        _, rest = load_application(rest)
        return public_key, rest

    def get_public(self, data: memoryview) -> typing.NoReturn:
        """Always raise: private keys of this type are not supported."""
        # Confusingly `get_public` is an entry point used by private key
        # loading.
        raise UnsupportedAlgorithm(
            "sk-ssh-ed25519 private keys cannot be loaded"
        )

627 

628 

class _SSHFormatSKECDSA:
    """
    The format of a sk-ecdsa-sha2-nistp256@openssh.com public key is:

        string          "sk-ecdsa-sha2-nistp256@openssh.com"
        string          curve name
        ec_point        Q
        string          application (user-specified, but typically "ssh:")
    """

    def load_public(
        self, data: memoryview
    ) -> tuple[ec.EllipticCurvePublicKey, memoryview]:
        """Make ECDSA public key from data."""
        # The key itself uses the plain nistp256 encoding; the application
        # string that follows is checked and then discarded.
        ecdsa_format = _lookup_kformat(_ECDSA_NISTP256)
        public_key, rest = ecdsa_format.load_public(data)
        _, rest = load_application(rest)
        return public_key, rest

    def get_public(self, data: memoryview) -> typing.NoReturn:
        """Always raise: private keys of this type are not supported."""
        # Confusingly `get_public` is an entry point used by private key
        # loading.
        raise UnsupportedAlgorithm(
            "sk-ecdsa-sha2-nistp256 private keys cannot be loaded"
        )

653 

654 

# One format handler instance per supported SSH key type name
_KEY_FORMATS = {
    _SSH_RSA: _SSHFormatRSA(),
    _SSH_DSA: _SSHFormatDSA(),
    _SSH_ED25519: _SSHFormatEd25519(),
    _ECDSA_NISTP256: _SSHFormatECDSA(b"nistp256", ec.SECP256R1()),
    _ECDSA_NISTP384: _SSHFormatECDSA(b"nistp384", ec.SECP384R1()),
    _ECDSA_NISTP521: _SSHFormatECDSA(b"nistp521", ec.SECP521R1()),
    _SK_SSH_ED25519: _SSHFormatSKEd25519(),
    _SK_SSH_ECDSA_NISTP256: _SSHFormatSKECDSA(),
}

665 

666 

def _lookup_kformat(key_type: utils.Buffer):
    """Return the format handler for ``key_type``.

    Raises UnsupportedAlgorithm when the key type has no entry in
    ``_KEY_FORMATS``.
    """
    # Dictionary keys are bytes, so other buffer types are converted first.
    if not isinstance(key_type, bytes):
        key_type = memoryview(key_type).tobytes()
    kformat = _KEY_FORMATS.get(key_type)
    if kformat is None:
        raise UnsupportedAlgorithm(f"Unsupported key type: {key_type!r}")
    return kformat

674 

675 

# Private key classes that load_ssh_private_key is annotated to return
SSHPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    dsa.DSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]

682 

683 

def load_ssh_private_key(
    data: utils.Buffer,
    password: bytes | None,
    backend: typing.Any = None,
    *,
    unsafe_skip_rsa_key_validation: bool = False,
) -> SSHPrivateKeyTypes:
    """Load private key from OpenSSH custom encoding.

    ``data`` must contain a block delimited by the OpenSSH private key
    begin and end markers. ``password`` is required when the key is
    encrypted and must be ``None`` or empty when it is not. ``backend`` is
    not used by this function. ``unsafe_skip_rsa_key_validation`` is
    passed on to the key format's ``load_private``.

    Raises ValueError for malformed or inconsistent data, TypeError for a
    missing or unexpected password, and UnsupportedAlgorithm for an
    unknown key type, cipher or KDF.
    """
    utils._check_byteslike("data", data)
    if password is not None:
        utils._check_bytes("password", password)

    pem_match = _PEM_RC.search(data)
    if not pem_match:
        raise ValueError("Not OpenSSH private key format")
    body_start, body_end = pem_match.span(1)
    decoded = binascii.a2b_base64(memoryview(data)[body_start:body_end])
    if not decoded.startswith(_SK_MAGIC):
        raise ValueError("Not OpenSSH private key format")
    data = memoryview(decoded)[len(_SK_MAGIC) :]

    # parse header
    ciphername, data = _get_sshstr(data)
    kdfname, data = _get_sshstr(data)
    kdfoptions, data = _get_sshstr(data)
    nkeys, data = _get_u32(data)
    if nkeys != 1:
        raise ValueError("Only one key supported")

    # load public key data
    pubdata, data = _get_sshstr(data)
    pub_key_type, pubdata = _get_sshstr(pubdata)
    kformat = _lookup_kformat(pub_key_type)
    pubfields, pubdata = kformat.get_public(pubdata)
    _check_empty(pubdata)

    if ciphername != _NONE or kdfname != _NONE:
        ciphername_bytes = ciphername.tobytes()
        if ciphername_bytes not in _SSH_CIPHERS:
            raise UnsupportedAlgorithm(
                f"Unsupported cipher: {ciphername_bytes!r}"
            )
        if kdfname != _BCRYPT:
            # Format the bytes, not the memoryview, so that the message
            # shows the KDF name instead of "<memory at 0x...>".
            raise UnsupportedAlgorithm(
                f"Unsupported KDF: {kdfname.tobytes()!r}"
            )
        cipher_params = _SSH_CIPHERS[ciphername_bytes]
        # load secret data
        edata, data = _get_sshstr(data)
        # see https://bugzilla.mindrot.org/show_bug.cgi?id=3553 for
        # information about how OpenSSH handles AEAD tags
        if cipher_params.is_aead:
            # Whatever follows the encrypted data is taken as the tag.
            tag = bytes(data)
            if len(tag) != cipher_params.tag_len:
                raise ValueError("Corrupt data: invalid tag length for cipher")
        else:
            _check_empty(data)
        _check_block_size(edata, cipher_params.block_len)
        salt, kbuf = _get_sshstr(kdfoptions)
        rounds, kbuf = _get_u32(kbuf)
        _check_empty(kbuf)
        ciph = _init_cipher(ciphername_bytes, password, salt.tobytes(), rounds)
        dec = ciph.decryptor()
        edata = memoryview(dec.update(edata))
        if cipher_params.is_aead:
            assert isinstance(dec, AEADDecryptionContext)
            _check_empty(dec.finalize_with_tag(tag))
        else:
            # _check_block_size requires data to be a full block so there
            # should be no output from finalize
            _check_empty(dec.finalize())
    else:
        if password:
            raise TypeError(
                "Password was given but private key is not encrypted."
            )
        # load secret data
        edata, data = _get_sshstr(data)
        _check_empty(data)
        # Unencrypted secrets must be a whole number of 8-byte blocks.
        _check_block_size(edata, 8)

    # The secret section starts with the same u32 check value twice.
    ck1, edata = _get_u32(edata)
    ck2, edata = _get_u32(edata)
    if ck1 != ck2:
        raise ValueError("Corrupt data: broken checksum")

    # load per-key struct
    key_type, edata = _get_sshstr(edata)
    if key_type != pub_key_type:
        raise ValueError("Corrupt data: key type mismatch")
    private_key, edata = kformat.load_private(
        edata,
        pubfields,
        unsafe_skip_rsa_key_validation=unsafe_skip_rsa_key_validation,
    )
    # We don't use the comment
    _, edata = _get_sshstr(edata)

    # yes, SSH does padding check *after* all other parsing is done.
    # need to follow as it writes zero-byte padding too.
    if edata != _PADDING[: len(edata)]:
        raise ValueError("Corrupt data: invalid padding")

    if isinstance(private_key, dsa.DSAPrivateKey):
        warnings.warn(
            "SSH DSA keys are deprecated and will be removed in a future "
            "release.",
            utils.DeprecatedIn40,
            stacklevel=2,
        )

    return private_key

796 

797 

def _serialize_ssh_private_key(
    private_key: SSHPrivateKeyTypes,
    password: bytes,
    encryption_algorithm: KeySerializationEncryption,
) -> bytes:
    """Serialize private key with OpenSSH custom encoding.

    An empty ``password`` produces an unencrypted key; otherwise the
    secret section is encrypted with the default cipher and a bcrypt
    derived key. ``encryption_algorithm`` is only consulted for a custom
    KDF round count.
    """
    utils._check_bytes("password", password)
    if isinstance(private_key, dsa.DSAPrivateKey):
        warnings.warn(
            "SSH DSA key support is deprecated and will be "
            "removed in a future release",
            utils.DeprecatedIn40,
            stacklevel=4,
        )

    key_type = _get_ssh_key_type(private_key)
    kformat = _lookup_kformat(key_type)

    # setup parameters
    f_kdfoptions = _FragList()
    if password:
        ciphername = _DEFAULT_CIPHER
        kdfname = _BCRYPT
        blklen = _SSH_CIPHERS[ciphername].block_len
        has_custom_rounds = (
            isinstance(encryption_algorithm, _KeySerializationEncryption)
            and encryption_algorithm._kdf_rounds is not None
        )
        if has_custom_rounds:
            rounds = encryption_algorithm._kdf_rounds
        else:
            rounds = _DEFAULT_ROUNDS
        salt = os.urandom(16)
        f_kdfoptions.put_sshstr(salt)
        f_kdfoptions.put_u32(rounds)
        ciph = _init_cipher(ciphername, password, salt, rounds)
    else:
        ciphername = kdfname = _NONE
        blklen = 8
        ciph = None
    nkeys = 1
    checkval = os.urandom(4)
    comment = b""

    # encode public and private parts together
    f_public_key = _FragList()
    f_public_key.put_sshstr(key_type)
    kformat.encode_public(private_key.public_key(), f_public_key)

    # The secret section starts with the same random check value twice.
    f_secrets = _FragList([checkval, checkval])
    f_secrets.put_sshstr(key_type)
    kformat.encode_private(private_key, f_secrets)
    f_secrets.put_sshstr(comment)
    # Pad up to the next block boundary. A section that is already
    # aligned still gets one full block of padding.
    pad_len = blklen - (f_secrets.size() % blklen)
    f_secrets.put_raw(_PADDING[:pad_len])

    # top-level structure
    f_main = _FragList()
    f_main.put_raw(_SK_MAGIC)
    f_main.put_sshstr(ciphername)
    f_main.put_sshstr(kdfname)
    f_main.put_sshstr(f_kdfoptions)
    f_main.put_u32(nkeys)
    f_main.put_sshstr(f_public_key)
    f_main.put_sshstr(f_secrets)

    # copy result into bytearray
    secrets_len = f_secrets.size()
    main_len = f_main.size()
    # One spare block beyond the rendered data gives the in-place
    # encryption below a larger output window.
    buf = memoryview(bytearray(main_len + blklen))
    f_main.render(buf)
    # The secrets are the last thing rendered, so they end at main_len.
    secrets_ofs = main_len - secrets_len

    # encrypt in-place
    if ciph is not None:
        ciph.encryptor().update_into(
            buf[secrets_ofs:main_len], buf[secrets_ofs:]
        )

    return _ssh_pem_encode(buf[:main_len])

873 

874 

# Public key classes handled by the SSH public key code
SSHPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    dsa.DSAPublicKey,
    ed25519.Ed25519PublicKey,
]

# Public key classes used in certificate annotations: as above, minus DSA
SSHCertPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    ed25519.Ed25519PublicKey,
]

887 

888 

class SSHCertificateType(enum.Enum):
    """Kind of SSH certificate; values are the numeric type codes."""

    USER = 1
    HOST = 2

892 

893 

894class SSHCertificate: 

895 def __init__( 

896 self, 

897 _nonce: memoryview, 

898 _public_key: SSHPublicKeyTypes, 

899 _serial: int, 

900 _cctype: int, 

901 _key_id: memoryview, 

902 _valid_principals: list[bytes], 

903 _valid_after: int, 

904 _valid_before: int, 

905 _critical_options: dict[bytes, bytes], 

906 _extensions: dict[bytes, bytes], 

907 _sig_type: memoryview, 

908 _sig_key: memoryview, 

909 _inner_sig_type: memoryview, 

910 _signature: memoryview, 

911 _tbs_cert_body: memoryview, 

912 _cert_key_type: bytes, 

913 _cert_body: memoryview, 

914 ): 

915 self._nonce = _nonce 

916 self._public_key = _public_key 

917 self._serial = _serial 

918 try: 

919 self._type = SSHCertificateType(_cctype) 

920 except ValueError: 

921 raise ValueError("Invalid certificate type") 

922 self._key_id = _key_id 

923 self._valid_principals = _valid_principals 

924 self._valid_after = _valid_after 

925 self._valid_before = _valid_before 

926 self._critical_options = _critical_options 

927 self._extensions = _extensions 

928 self._sig_type = _sig_type 

929 self._sig_key = _sig_key 

930 self._inner_sig_type = _inner_sig_type 

931 self._signature = _signature 

932 self._cert_key_type = _cert_key_type 

933 self._cert_body = _cert_body 

934 self._tbs_cert_body = _tbs_cert_body 

935 

936 @property 

937 def nonce(self) -> bytes: 

938 return bytes(self._nonce) 

939 

940 def public_key(self) -> SSHCertPublicKeyTypes: 

941 # make mypy happy until we remove DSA support entirely and 

942 # the underlying union won't have a disallowed type 

943 return typing.cast(SSHCertPublicKeyTypes, self._public_key) 

944 

945 @property 

946 def serial(self) -> int: 

947 return self._serial 

948 

949 @property 

950 def type(self) -> SSHCertificateType: 

951 return self._type 

952 

953 @property 

954 def key_id(self) -> bytes: 

955 return bytes(self._key_id) 

956 

957 @property 

958 def valid_principals(self) -> list[bytes]: 

959 return self._valid_principals 

960 

961 @property 

962 def valid_before(self) -> int: 

963 return self._valid_before 

964 

965 @property 

966 def valid_after(self) -> int: 

967 return self._valid_after 

968 

969 @property 

970 def critical_options(self) -> dict[bytes, bytes]: 

971 return self._critical_options 

972 

973 @property 

974 def extensions(self) -> dict[bytes, bytes]: 

975 return self._extensions 

976 

977 def signature_key(self) -> SSHCertPublicKeyTypes: 

978 sigformat = _lookup_kformat(self._sig_type) 

979 signature_key, sigkey_rest = sigformat.load_public(self._sig_key) 

980 _check_empty(sigkey_rest) 

981 return signature_key 

982 

983 def public_bytes(self) -> bytes: 

984 return ( 

985 bytes(self._cert_key_type) 

986 + b" " 

987 + binascii.b2a_base64(bytes(self._cert_body), newline=False) 

988 ) 

989 

990 def verify_cert_signature(self) -> None: 

991 signature_key = self.signature_key() 

992 if isinstance(signature_key, ed25519.Ed25519PublicKey): 

993 signature_key.verify( 

994 bytes(self._signature), bytes(self._tbs_cert_body) 

995 ) 

996 elif isinstance(signature_key, ec.EllipticCurvePublicKey): 

997 # The signature is encoded as a pair of big-endian integers 

998 r, data = _get_mpint(self._signature) 

999 s, data = _get_mpint(data) 

1000 _check_empty(data) 

1001 computed_sig = asym_utils.encode_dss_signature(r, s) 

1002 hash_alg = _get_ec_hash_alg(signature_key.curve) 

1003 signature_key.verify( 

1004 computed_sig, bytes(self._tbs_cert_body), ec.ECDSA(hash_alg) 

1005 ) 

1006 else: 

1007 assert isinstance(signature_key, rsa.RSAPublicKey) 

1008 if self._inner_sig_type == _SSH_RSA: 

1009 hash_alg = hashes.SHA1() 

1010 elif self._inner_sig_type == _SSH_RSA_SHA256: 

1011 hash_alg = hashes.SHA256() 

1012 else: 

1013 assert self._inner_sig_type == _SSH_RSA_SHA512 

1014 hash_alg = hashes.SHA512() 

1015 signature_key.verify( 

1016 bytes(self._signature), 

1017 bytes(self._tbs_cert_body), 

1018 padding.PKCS1v15(), 

1019 hash_alg, 

1020 ) 

1021 

1022 

def _get_ec_hash_alg(curve: ec.EllipticCurve) -> hashes.HashAlgorithm:
    """Map an elliptic curve to the hash used with it for SSH ECDSA.

    SECP256R1 -> SHA256, SECP384R1 -> SHA384, SECP521R1 -> SHA512. Any other
    curve trips the assertion.
    """
    if isinstance(curve, ec.SECP256R1):
        return hashes.SHA256()
    if isinstance(curve, ec.SECP384R1):
        return hashes.SHA384()
    assert isinstance(curve, ec.SECP521R1)
    return hashes.SHA512()

1031 

1032 

def _load_ssh_public_identity(
    data: utils.Buffer,
    _legacy_dsa_allowed=False,
) -> SSHCertificate | SSHPublicKeyTypes:
    """Parse a one-line OpenSSH public key or certificate.

    *data* must match ``_SSH_PUBKEY_RC``; its first group is taken as the key
    type and its second group is base64-decoded as the body. A key type that
    ends with ``_CERT_SUFFIX`` is parsed as a certificate and returned as an
    ``SSHCertificate``; otherwise the bare public key is returned.

    Raises ``UnsupportedAlgorithm`` for a DSA key type or DSA signing key
    unless *_legacy_dsa_allowed* is true, and ``ValueError`` for input that
    does not parse.
    """
    utils._check_byteslike("data", data)

    match = _SSH_PUBKEY_RC.match(data)
    if match is None:
        raise ValueError("Invalid line format")
    orig_key_type = match.group(1)
    encoded_body = match.group(2)

    with_cert = orig_key_type.endswith(_CERT_SUFFIX)
    if with_cert:
        key_type = orig_key_type[: -len(_CERT_SUFFIX)]
    else:
        key_type = orig_key_type
    if not _legacy_dsa_allowed and key_type == _SSH_DSA:
        raise UnsupportedAlgorithm(
            "DSA keys aren't supported in SSH certificates"
        )
    kformat = _lookup_kformat(key_type)

    try:
        rest = memoryview(binascii.a2b_base64(encoded_body))
    except (TypeError, binascii.Error):
        raise ValueError("Invalid format")

    # Keep a view of the complete decoded body; only the certificate path
    # uses it (for the signed portion and for the stored raw body).
    cert_body = rest

    # The decoded body repeats the key type and must agree with the line.
    inner_key_type, rest = _get_sshstr(rest)
    if inner_key_type != orig_key_type:
        raise ValueError("Invalid key format")
    if with_cert:
        nonce, rest = _get_sshstr(rest)
    public_key, rest = kformat.load_public(rest)

    if not with_cert:
        _check_empty(rest)
        return public_key

    serial, rest = _get_u64(rest)
    cctype, rest = _get_u32(rest)
    key_id, rest = _get_sshstr(rest)

    packed_principals, rest = _get_sshstr(rest)
    valid_principals = []
    while packed_principals:
        one_principal, packed_principals = _get_sshstr(packed_principals)
        valid_principals.append(bytes(one_principal))

    valid_after, rest = _get_u64(rest)
    valid_before, rest = _get_u64(rest)
    raw_crit_options, rest = _get_sshstr(rest)
    critical_options = _parse_exts_opts(raw_crit_options)
    raw_extensions, rest = _get_sshstr(rest)
    extensions = _parse_exts_opts(raw_extensions)
    # Get the reserved field, which is unused.
    _, rest = _get_sshstr(rest)

    sig_key_raw, rest = _get_sshstr(rest)
    sig_type, sig_key = _get_sshstr(sig_key_raw)
    if not _legacy_dsa_allowed and sig_type == _SSH_DSA:
        raise UnsupportedAlgorithm(
            "DSA signatures aren't supported in SSH certificates"
        )

    # Get the entire cert body and subtract the signature
    tbs_cert_body = cert_body[: -len(rest)]
    signature_raw, rest = _get_sshstr(rest)
    _check_empty(rest)
    inner_sig_type, sig_rest = _get_sshstr(signature_raw)

    # RSA certs can have multiple algorithm types; every other key type
    # must be signed with exactly its own algorithm name.
    if sig_type == _SSH_RSA:
        sig_type_ok = inner_sig_type in [
            _SSH_RSA_SHA256,
            _SSH_RSA_SHA512,
            _SSH_RSA,
        ]
    else:
        sig_type_ok = inner_sig_type == sig_type
    if not sig_type_ok:
        raise ValueError("Signature key type does not match")

    signature, sig_rest = _get_sshstr(sig_rest)
    _check_empty(sig_rest)
    return SSHCertificate(
        nonce,
        public_key,
        serial,
        cctype,
        key_id,
        valid_principals,
        valid_after,
        valid_before,
        critical_options,
        extensions,
        sig_type,
        sig_key,
        inner_sig_type,
        signature,
        tbs_cert_body,
        orig_key_type,
        cert_body,
    )

1126 

1127 

def load_ssh_public_identity(
    data: utils.Buffer,
) -> SSHCertificate | SSHPublicKeyTypes:
    """Load an SSH public key or certificate from *data*.

    Delegates to ``_load_ssh_public_identity`` with its default settings
    (legacy DSA not allowed).
    """
    identity = _load_ssh_public_identity(data)
    return identity

1132 

1133 

1134def _parse_exts_opts(exts_opts: memoryview) -> dict[bytes, bytes]: 

1135 result: dict[bytes, bytes] = {} 

1136 last_name = None 

1137 while exts_opts: 

1138 name, exts_opts = _get_sshstr(exts_opts) 

1139 bname: bytes = bytes(name) 

1140 if bname in result: 

1141 raise ValueError("Duplicate name") 

1142 if last_name is not None and bname < last_name: 

1143 raise ValueError("Fields not lexically sorted") 

1144 value, exts_opts = _get_sshstr(exts_opts) 

1145 if len(value) > 0: 

1146 value, extra = _get_sshstr(value) 

1147 if len(extra) > 0: 

1148 raise ValueError("Unexpected extra data after value") 

1149 result[bname] = bytes(value) 

1150 last_name = bname 

1151 return result 

1152 

1153 

def ssh_key_fingerprint(
    key: SSHPublicKeyTypes,
    hash_algorithm: hashes.MD5 | hashes.SHA256,
) -> bytes:
    """Return the raw digest of *key*'s SSH wire encoding.

    The key is serialized as its key-type string followed by the
    format-specific public key fields, then hashed with *hash_algorithm*.
    Raises ``TypeError`` unless *hash_algorithm* is MD5 or SHA256.
    """
    if not isinstance(hash_algorithm, (hashes.MD5, hashes.SHA256)):
        raise TypeError("hash_algorithm must be either MD5 or SHA256")

    key_type = _get_ssh_key_type(key)
    key_format = _lookup_kformat(key_type)

    # Build the binary public key blob.
    blob = _FragList()
    blob.put_sshstr(key_type)
    key_format.encode_public(key, blob)

    # Hash the binary data
    digest = hashes.Hash(hash_algorithm)
    digest.update(blob.tobytes())
    return digest.finalize()

1174 

1175 

def load_ssh_public_key(
    data: utils.Buffer, backend: typing.Any = None
) -> SSHPublicKeyTypes:
    """Load an SSH public key, unwrapping a certificate if one is given.

    DSA is still accepted here, but loading a DSA key emits a
    ``DeprecatedIn40`` warning. *backend* is not used by this function.
    """
    identity = _load_ssh_public_identity(data, _legacy_dsa_allowed=True)
    public_key: SSHPublicKeyTypes = (
        identity.public_key()
        if isinstance(identity, SSHCertificate)
        else identity
    )

    if isinstance(public_key, dsa.DSAPublicKey):
        warnings.warn(
            "SSH DSA keys are deprecated and will be removed in a future "
            "release.",
            utils.DeprecatedIn40,
            stacklevel=2,
        )
    return public_key

1194 

1195 

def serialize_ssh_public_key(public_key: SSHPublicKeyTypes) -> bytes:
    """One-line public key format for OpenSSH"""
    if isinstance(public_key, dsa.DSAPublicKey):
        warnings.warn(
            "SSH DSA key support is deprecated and will be "
            "removed in a future release",
            utils.DeprecatedIn40,
            stacklevel=4,
        )
    key_type = _get_ssh_key_type(public_key)
    key_format = _lookup_kformat(key_type)

    # Wire encoding: key type string followed by the key's own fields.
    blob = _FragList()
    blob.put_sshstr(key_type)
    key_format.encode_public(public_key, blob)

    # newline=False yields the same bytes as stripping the trailing newline.
    encoded = binascii.b2a_base64(blob.tobytes(), newline=False)
    return key_type + b" " + encoded

1214 

1215 

# Private key types accepted for signing by SSHCertificateBuilder.sign.
SSHCertPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]


# This is an undocumented limit enforced in the openssh codebase for sshd and
# ssh-keygen, but it is undefined in the ssh certificates spec.
_SSHKEY_CERT_MAX_PRINCIPALS = 256

1226 

1227 

class SSHCertificateBuilder:
    """Step-by-step builder for SSH certificates.

    Every setter validates its argument and returns a *new* builder with
    that one field changed; the builder it was called on is left untouched.
    ``sign`` serializes the collected fields, signs them and returns the
    resulting ``SSHCertificate``.
    """

    def __init__(
        self,
        _public_key: SSHCertPublicKeyTypes | None = None,
        _serial: int | None = None,
        _type: SSHCertificateType | None = None,
        _key_id: bytes | None = None,
        _valid_principals: list[bytes] | None = None,
        _valid_for_all_principals: bool = False,
        _valid_before: int | None = None,
        _valid_after: int | None = None,
        _critical_options: list[tuple[bytes, bytes]] | None = None,
        _extensions: list[tuple[bytes, bytes]] | None = None,
    ):
        self._public_key = _public_key
        self._serial = _serial
        self._type = _type
        self._key_id = _key_id
        # Always store private copies of the list arguments so that neither
        # a shared default, a sibling builder, nor the caller's own list can
        # alias this builder's state.
        self._valid_principals: list[bytes] = list(_valid_principals or [])
        self._valid_for_all_principals = _valid_for_all_principals
        self._valid_before = _valid_before
        self._valid_after = _valid_after
        self._critical_options: list[tuple[bytes, bytes]] = list(
            _critical_options or []
        )
        self._extensions: list[tuple[bytes, bytes]] = list(_extensions or [])

    def _replace(self, **changes: typing.Any) -> SSHCertificateBuilder:
        """Return a new builder with this builder's fields plus *changes*."""
        fields: dict[str, typing.Any] = {
            "_public_key": self._public_key,
            "_serial": self._serial,
            "_type": self._type,
            "_key_id": self._key_id,
            "_valid_principals": self._valid_principals,
            "_valid_for_all_principals": self._valid_for_all_principals,
            "_valid_before": self._valid_before,
            "_valid_after": self._valid_after,
            "_critical_options": self._critical_options,
            "_extensions": self._extensions,
        }
        fields.update(changes)
        return SSHCertificateBuilder(**fields)

    @staticmethod
    def _encode_name_value_pairs(pairs: list[tuple[bytes, bytes]]) -> bytes:
        """Serialize (name, value) pairs as used for options and extensions.

        A non-empty value is wrapped in an extra SSH string; an empty value
        is written as an empty string.
        """
        packed = _FragList()
        for name, value in pairs:
            packed.put_sshstr(name)
            if len(value) > 0:
                wrapped = _FragList()
                wrapped.put_sshstr(value)
                packed.put_sshstr(wrapped.tobytes())
            else:
                packed.put_sshstr(value)
        return packed.tobytes()

    def public_key(
        self, public_key: SSHCertPublicKeyTypes
    ) -> SSHCertificateBuilder:
        """Set the key to be certified (EC, RSA or Ed25519 public key).

        Raises ``TypeError`` for other key types and ``ValueError`` if a
        public key was already set.
        """
        if not isinstance(
            public_key,
            (
                ec.EllipticCurvePublicKey,
                rsa.RSAPublicKey,
                ed25519.Ed25519PublicKey,
            ),
        ):
            raise TypeError("Unsupported key type")
        if self._public_key is not None:
            raise ValueError("public_key already set")

        return self._replace(_public_key=public_key)

    def serial(self, serial: int) -> SSHCertificateBuilder:
        """Set the serial number; must be an int in ``[0, 2**64)``."""
        if not isinstance(serial, int):
            raise TypeError("serial must be an integer")
        if not 0 <= serial < 2**64:
            raise ValueError("serial must be between 0 and 2**64")
        if self._serial is not None:
            raise ValueError("serial already set")

        return self._replace(_serial=serial)

    def type(self, type: SSHCertificateType) -> SSHCertificateBuilder:
        """Set the certificate type; must be an ``SSHCertificateType``."""
        if not isinstance(type, SSHCertificateType):
            raise TypeError("type must be an SSHCertificateType")
        if self._type is not None:
            raise ValueError("type already set")

        return self._replace(_type=type)

    def key_id(self, key_id: bytes) -> SSHCertificateBuilder:
        """Set the key identifier; must be ``bytes``."""
        if not isinstance(key_id, bytes):
            raise TypeError("key_id must be bytes")
        if self._key_id is not None:
            raise ValueError("key_id already set")

        return self._replace(_key_id=key_id)

    def valid_principals(
        self, valid_principals: list[bytes]
    ) -> SSHCertificateBuilder:
        """Set the principals the certificate is valid for.

        The list must be non-empty, contain only ``bytes`` and hold at most
        ``_SSHKEY_CERT_MAX_PRINCIPALS`` entries. It cannot be combined with
        ``valid_for_all_principals`` and can only be set once.
        """
        if self._valid_for_all_principals:
            raise ValueError(
                "Principals can't be set because the cert is valid "
                "for all principals"
            )
        if (
            not all(isinstance(x, bytes) for x in valid_principals)
            or not valid_principals
        ):
            raise TypeError(
                "principals must be a list of bytes and can't be empty"
            )
        if self._valid_principals:
            raise ValueError("valid_principals already set")

        if len(valid_principals) > _SSHKEY_CERT_MAX_PRINCIPALS:
            raise ValueError(
                "Reached or exceeded the maximum number of valid_principals"
            )

        return self._replace(_valid_principals=valid_principals)

    def valid_for_all_principals(self) -> SSHCertificateBuilder:
        """Mark the certificate as valid for every principal.

        Cannot be combined with ``valid_principals`` and can only be set
        once.
        """
        if self._valid_principals:
            raise ValueError(
                "valid_principals already set, can't set "
                "valid_for_all_principals"
            )
        if self._valid_for_all_principals:
            raise ValueError("valid_for_all_principals already set")

        return self._replace(_valid_for_all_principals=True)

    def valid_before(self, valid_before: int | float) -> SSHCertificateBuilder:
        """Set ``valid_before``; truncated to int, must be in ``[0, 2**64)``."""
        if not isinstance(valid_before, (int, float)):
            raise TypeError("valid_before must be an int or float")
        valid_before = int(valid_before)
        if valid_before < 0 or valid_before >= 2**64:
            raise ValueError("valid_before must [0, 2**64)")
        if self._valid_before is not None:
            raise ValueError("valid_before already set")

        return self._replace(_valid_before=valid_before)

    def valid_after(self, valid_after: int | float) -> SSHCertificateBuilder:
        """Set ``valid_after``; truncated to int, must be in ``[0, 2**64)``."""
        if not isinstance(valid_after, (int, float)):
            raise TypeError("valid_after must be an int or float")
        valid_after = int(valid_after)
        if valid_after < 0 or valid_after >= 2**64:
            raise ValueError("valid_after must [0, 2**64)")
        if self._valid_after is not None:
            raise ValueError("valid_after already set")

        return self._replace(_valid_after=valid_after)

    def add_critical_option(
        self, name: bytes, value: bytes
    ) -> SSHCertificateBuilder:
        """Append a critical option; names must be unique ``bytes``."""
        if not isinstance(name, bytes) or not isinstance(value, bytes):
            raise TypeError("name and value must be bytes")
        # Linear scan per call, so adding n options is O(n**2) overall.
        if any(name == existing for existing, _ in self._critical_options):
            raise ValueError("Duplicate critical option name")

        return self._replace(
            _critical_options=[*self._critical_options, (name, value)]
        )

    def add_extension(
        self, name: bytes, value: bytes
    ) -> SSHCertificateBuilder:
        """Append an extension; names must be unique ``bytes``."""
        if not isinstance(name, bytes) or not isinstance(value, bytes):
            raise TypeError("name and value must be bytes")
        # Linear scan per call, so adding n extensions is O(n**2) overall.
        if any(name == existing for existing, _ in self._extensions):
            raise ValueError("Duplicate extension name")

        return self._replace(_extensions=[*self._extensions, (name, value)])

    def sign(self, private_key: SSHCertPrivateKeyTypes) -> SSHCertificate:
        """Sign the collected fields with *private_key*.

        Returns the ``SSHCertificate`` obtained by parsing the freshly
        serialized certificate. The builder itself is not modified.

        Raises ``TypeError`` if *private_key* is not an EC, RSA or Ed25519
        private key, and ``ValueError`` if a required field (public key,
        type, principals, validity bounds) is missing or ``valid_after`` is
        later than ``valid_before``.
        """
        if not isinstance(
            private_key,
            (
                ec.EllipticCurvePrivateKey,
                rsa.RSAPrivateKey,
                ed25519.Ed25519PrivateKey,
            ),
        ):
            raise TypeError("Unsupported private key type")

        if self._public_key is None:
            raise ValueError("public_key must be set")

        # Not required
        serial = 0 if self._serial is None else self._serial

        if self._type is None:
            raise ValueError("type must be set")

        # Not required
        key_id = b"" if self._key_id is None else self._key_id

        # A zero length list is valid, but means the certificate
        # is valid for any principal of the specified type. We require
        # the user to explicitly set valid_for_all_principals to get
        # that behavior.
        if not self._valid_principals and not self._valid_for_all_principals:
            raise ValueError(
                "valid_principals must be set if valid_for_all_principals "
                "is False"
            )

        if self._valid_before is None:
            raise ValueError("valid_before must be set")

        if self._valid_after is None:
            raise ValueError("valid_after must be set")

        if self._valid_after > self._valid_before:
            raise ValueError("valid_after must be earlier than valid_before")

        # Lexically sort our byte strings. Sorted copies are used so that
        # signing never mutates the builder's own lists.
        critical_options = sorted(self._critical_options, key=lambda x: x[0])
        extensions = sorted(self._extensions, key=lambda x: x[0])

        key_type = _get_ssh_key_type(self._public_key)
        cert_prefix = key_type + _CERT_SUFFIX

        # Marshal the bytes to be signed
        nonce = os.urandom(32)
        kformat = _lookup_kformat(key_type)
        f = _FragList()
        f.put_sshstr(cert_prefix)
        f.put_sshstr(nonce)
        kformat.encode_public(self._public_key, f)
        f.put_u64(serial)
        f.put_u32(self._type.value)
        f.put_sshstr(key_id)
        fprincipals = _FragList()
        for p in self._valid_principals:
            fprincipals.put_sshstr(p)
        f.put_sshstr(fprincipals.tobytes())
        f.put_u64(self._valid_after)
        f.put_u64(self._valid_before)
        f.put_sshstr(self._encode_name_value_pairs(critical_options))
        f.put_sshstr(self._encode_name_value_pairs(extensions))
        f.put_sshstr(b"")  # RESERVED FIELD
        # encode CA public key
        ca_type = _get_ssh_key_type(private_key)
        caformat = _lookup_kformat(ca_type)
        caf = _FragList()
        caf.put_sshstr(ca_type)
        caformat.encode_public(private_key.public_key(), caf)
        f.put_sshstr(caf.tobytes())

        # Everything marshalled so far is what gets signed.
        # Sigs according to the rules defined for the CA's public key
        # (RFC4253 section 6.6 for ssh-rsa, RFC5656 for ECDSA,
        # and RFC8032 for Ed25519).
        tbs = f.tobytes()
        if isinstance(private_key, ed25519.Ed25519PrivateKey):
            sig_algorithm = ca_type
            sig_blob = private_key.sign(tbs)
        elif isinstance(private_key, ec.EllipticCurvePrivateKey):
            hash_alg = _get_ec_hash_alg(private_key.curve)
            signature = private_key.sign(tbs, ec.ECDSA(hash_alg))
            r, s = asym_utils.decode_dss_signature(signature)
            fsigblob = _FragList()
            fsigblob.put_mpint(r)
            fsigblob.put_mpint(s)
            sig_algorithm = ca_type
            sig_blob = fsigblob.tobytes()
        else:
            assert isinstance(private_key, rsa.RSAPrivateKey)
            # Just like Golang, we're going to use SHA512 for RSA
            # https://cs.opensource.google/go/x/crypto/+/refs/tags/
            # v0.4.0:ssh/certs.go;l=445
            # RFC 8332 defines SHA256 and 512 as options
            sig_algorithm = _SSH_RSA_SHA512
            sig_blob = private_key.sign(
                tbs, padding.PKCS1v15(), hashes.SHA512()
            )

        fsig = _FragList()
        fsig.put_sshstr(sig_algorithm)
        fsig.put_sshstr(sig_blob)
        f.put_sshstr(fsig.tobytes())

        cert_data = binascii.b2a_base64(f.tobytes()).strip()
        # load_ssh_public_identity returns a union, but this is
        # guaranteed to be an SSHCertificate, so we cast to make
        # mypy happy.
        return typing.cast(
            SSHCertificate,
            load_ssh_public_identity(b"".join([cert_prefix, b" ", cert_data])),
        )