Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/cryptography/hazmat/primitives/serialization/ssh.py: 21%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

775 statements  

1# This file is dual licensed under the terms of the Apache License, Version 

2# 2.0, and the BSD License. See the LICENSE file in the root of this repository 

3# for complete details. 

4 

5from __future__ import annotations 

6 

7import binascii 

8import enum 

9import os 

10import re 

11import typing 

12import warnings 

13from base64 import encodebytes as _base64_encode 

14from dataclasses import dataclass 

15 

16from cryptography import utils 

17from cryptography.exceptions import UnsupportedAlgorithm 

18from cryptography.hazmat.primitives import hashes 

19from cryptography.hazmat.primitives.asymmetric import ( 

20 dsa, 

21 ec, 

22 ed25519, 

23 padding, 

24 rsa, 

25) 

26from cryptography.hazmat.primitives.asymmetric import utils as asym_utils 

27from cryptography.hazmat.primitives.ciphers import ( 

28 AEADDecryptionContext, 

29 Cipher, 

30 algorithms, 

31 modes, 

32) 

33from cryptography.hazmat.primitives.serialization import ( 

34 Encoding, 

35 KeySerializationEncryption, 

36 NoEncryption, 

37 PrivateFormat, 

38 PublicFormat, 

39 _KeySerializationEncryption, 

40) 

41 

try:
    from bcrypt import kdf as _bcrypt_kdf
except ImportError:
    # bcrypt is optional. Without it, any operation that needs the
    # bcrypt-pbkdf KDF (encrypted OpenSSH private keys) fails when the
    # stand-in below is called.
    _bcrypt_supported = False

    def _bcrypt_kdf(
        password: bytes,
        salt: bytes,
        desired_key_bytes: int,
        rounds: int,
        ignore_few_rounds: bool = False,
    ) -> bytes:
        """Placeholder for ``bcrypt.kdf`` that always raises UnsupportedAlgorithm."""
        raise UnsupportedAlgorithm("Need bcrypt module")
else:
    _bcrypt_supported = True

57 

58 

# SSH public key type identifiers.
_SSH_ED25519 = b"ssh-ed25519"
_SSH_RSA = b"ssh-rsa"
_SSH_DSA = b"ssh-dss"
_ECDSA_NISTP256 = b"ecdsa-sha2-nistp256"
_ECDSA_NISTP384 = b"ecdsa-sha2-nistp384"
_ECDSA_NISTP521 = b"ecdsa-sha2-nistp521"
# Appended to a key type to form the matching certificate type.
_CERT_SUFFIX = b"-cert-v01@openssh.com"

# Security-key public key types; their encoding is followed by a U2F
# application string.
_SK_SSH_ED25519 = b"sk-ssh-ed25519@openssh.com"
_SK_SSH_ECDSA_NISTP256 = b"sk-ecdsa-sha2-nistp256@openssh.com"

# RSA signature algorithm names. They only select a signature hash and
# never appear as the type of a public key.
_SSH_RSA_SHA256 = b"rsa-sha2-256"
_SSH_RSA_SHA512 = b"rsa-sha2-512"

# Leading "<key type><spaces/tabs><base64 body>" of a public key line.
_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)")

# Framing of the OpenSSH private key container.
_SK_MAGIC = b"openssh-key-v1\0"
_SK_START = b"-----BEGIN OPENSSH PRIVATE KEY-----"
_SK_END = b"-----END OPENSSH PRIVATE KEY-----"
_BCRYPT = b"bcrypt"
_NONE = b"none"
_DEFAULT_CIPHER = b"aes256-ctr"
_DEFAULT_ROUNDS = 16

# A bytes regex is used because the armored body has to be found inside
# bytes-like input.
_PEM_RC = re.compile(b"".join((_SK_START, b"(.*?)", _SK_END)), re.DOTALL)

# Padding bytes 1, 2, ..., 16: enough for the largest block size in use.
_PADDING = memoryview(bytearray(range(1, 17)))

90 

91 

@dataclass
class _SSHCipher:
    """Parameters of a symmetric cipher used to wrap OpenSSH private keys."""

    # Cipher algorithm class.
    alg: type[algorithms.AES]
    # Key length in bytes.
    key_len: int
    # Cipher mode class.
    mode: type[modes.CTR] | type[modes.CBC] | type[modes.GCM]
    # Block size in bytes; the encrypted section must be a whole number of blocks.
    block_len: int
    # Length in bytes of the IV/nonce derived together with the key.
    iv_len: int
    # Authentication tag length for AEAD modes, otherwise None.
    tag_len: int | None
    # True for AEAD modes, whose tag is stored after the encrypted section.
    is_aead: bool

101 

102 

# Ciphers supported for private-key wrapping, keyed by their OpenSSH name.
# All are AES-256 with 16-byte blocks. They differ in mode, IV length and
# whether an authentication tag is present.
_SSH_CIPHERS: dict[bytes, _SSHCipher] = {
    name: _SSHCipher(
        alg=algorithms.AES,
        key_len=32,
        mode=mode,
        block_len=16,
        iv_len=iv_len,
        tag_len=tag_len,
        is_aead=tag_len is not None,
    )
    for name, mode, iv_len, tag_len in (
        (b"aes256-ctr", modes.CTR, 16, None),
        (b"aes256-cbc", modes.CBC, 16, None),
        (b"aes256-gcm@openssh.com", modes.GCM, 12, 16),
    )
}

# Maps cryptography curve names to the corresponding SSH key type.
_ECDSA_KEY_TYPE = {
    "secp256r1": _ECDSA_NISTP256,
    "secp384r1": _ECDSA_NISTP384,
    "secp521r1": _ECDSA_NISTP521,
}

140 

141 

def _get_ssh_key_type(key: SSHPrivateKeyTypes | SSHPublicKeyTypes) -> bytes:
    """Return the SSH key type identifier for a private or public key.

    Raises:
        ValueError: if the key is not an EC, RSA, DSA or Ed25519 key, or
            (via ``_ecdsa_key_type``) uses an unsupported curve.
    """
    if isinstance(key, ec.EllipticCurvePrivateKey):
        return _ecdsa_key_type(key.public_key())
    if isinstance(key, ec.EllipticCurvePublicKey):
        return _ecdsa_key_type(key)
    if isinstance(key, (rsa.RSAPrivateKey, rsa.RSAPublicKey)):
        return _SSH_RSA
    if isinstance(key, (dsa.DSAPrivateKey, dsa.DSAPublicKey)):
        return _SSH_DSA
    if isinstance(key, (ed25519.Ed25519PrivateKey, ed25519.Ed25519PublicKey)):
        return _SSH_ED25519
    raise ValueError("Unsupported key type")

159 

160 

def _ecdsa_key_type(public_key: ec.EllipticCurvePublicKey) -> bytes:
    """Return the SSH key type (e.g. ``ecdsa-sha2-nistp256``) for an EC public key.

    Only curves listed in ``_ECDSA_KEY_TYPE`` (NIST P-256, P-384, P-521)
    have an SSH key type.

    Raises:
        ValueError: if the key's curve has no SSH key type.
    """
    curve_name = public_key.curve.name
    try:
        return _ECDSA_KEY_TYPE[curve_name]
    except KeyError:
        raise ValueError(
            f"Unsupported curve for ssh private key: {curve_name!r}"
        ) from None

169 

170 

def _ssh_pem_encode(
    data: bytes,
    prefix: bytes = _SK_START + b"\n",
    suffix: bytes = _SK_END + b"\n",
) -> bytes:
    """Base64-encode *data* and wrap it between *prefix* and *suffix*.

    ``base64.encodebytes`` already splits its output into newline-terminated
    lines, so the armor lines need only their own trailing newline.
    """
    body = _base64_encode(data)
    return prefix + body + suffix

177 

178 

def _check_block_size(data: bytes, block_len: int) -> None:
    """Raise ValueError unless *data* is non-empty and a whole number of blocks."""
    whole_blocks = bool(data) and len(data) % block_len == 0
    if not whole_blocks:
        raise ValueError("Corrupt data: missing padding")

183 

184 

def _check_empty(data: bytes) -> None:
    """Raise ValueError if *data* still contains unparsed bytes."""
    if not data:
        return
    raise ValueError("Corrupt data: unparsed data")

189 

190 

def _init_cipher(
    ciphername: bytes,
    password: bytes | None,
    salt: bytes,
    rounds: int,
) -> Cipher[modes.CBC | modes.CTR | modes.GCM]:
    """Derive key and IV from *password* with bcrypt and return the cipher.

    Args:
        ciphername: key into ``_SSH_CIPHERS``.
        password: passphrase; must be non-empty.
        salt: KDF salt.
        rounds: KDF round count.

    Raises:
        ValueError: if *password* is None or empty.
    """
    if not password:
        raise ValueError("Key is password-protected.")

    spec = _SSH_CIPHERS[ciphername]
    key_len = spec.key_len
    # One KDF call yields key and IV back to back.
    material = _bcrypt_kdf(password, salt, key_len + spec.iv_len, rounds, True)
    key, iv = material[:key_len], material[key_len:]
    return Cipher(spec.alg(key), spec.mode(iv))

209 

210 

def _get_u32(data: memoryview) -> tuple[int, memoryview]:
    """Split a big-endian uint32 off the front of *data*.

    Returns the value and the remaining bytes. Raises ValueError if fewer
    than 4 bytes are available.
    """
    head, tail = data[:4], data[4:]
    if len(head) != 4:
        raise ValueError("Invalid data")
    return int.from_bytes(head, "big"), tail

216 

217 

def _get_u64(data: memoryview) -> tuple[int, memoryview]:
    """Split a big-endian uint64 off the front of *data*.

    Returns the value and the remaining bytes. Raises ValueError if fewer
    than 8 bytes are available.
    """
    head, tail = data[:8], data[8:]
    if len(head) != 8:
        raise ValueError("Invalid data")
    return int.from_bytes(head, "big"), tail

223 

224 

def _get_sshstr(data: memoryview) -> tuple[memoryview, memoryview]:
    """Split an SSH ``string`` (uint32 length, then payload) off *data*.

    Returns the payload and the remaining bytes. Raises ValueError if the
    length prefix is truncated or claims more bytes than are present.
    """
    if len(data) < 4:
        raise ValueError("Invalid data")
    length = int.from_bytes(data[:4], "big")
    body = data[4:]
    if length > len(body):
        raise ValueError("Invalid data")
    return body[:length], body[length:]

231 

232 

def _get_mpint(data: memoryview) -> tuple[int, memoryview]:
    """Parse an SSH ``mpint`` off the front of *data*.

    Only non-negative values are accepted. A payload whose first byte has
    the sign bit set raises ValueError.
    """
    raw, rest = _get_sshstr(data)
    if len(raw) > 0 and raw[0] >= 0x80:
        raise ValueError("Invalid data")
    return int.from_bytes(raw, "big"), rest

239 

240 

def _to_mpint(val: int) -> bytes:
    """Encode a non-negative integer as SSH ``mpint`` payload bytes.

    Zero encodes as empty bytes. Any other value is big-endian, with a
    leading zero byte when its top bit would otherwise be set, so the result
    is never read back as negative.

    Raises:
        ValueError: if *val* is negative.
    """
    if val < 0:
        raise ValueError("negative mpint not allowed")
    if val == 0:
        return b""
    # Reserve one extra bit for the sign, rounded up to whole bytes.
    return val.to_bytes(val.bit_length() // 8 + 1, "big")

249 

250 

class _FragList:
    """Collect the byte fragments of an SSH structure without copying them.

    Fragments are only joined when :meth:`render` or :meth:`tobytes` is
    called.
    """

    flist: list[bytes]

    def __init__(self, init: list[bytes] | None = None) -> None:
        self.flist = list(init) if init else []

    def put_raw(self, val: bytes) -> None:
        """Append *val* verbatim."""
        self.flist.append(val)

    def put_u32(self, val: int) -> None:
        """Append *val* as a 4-byte big-endian unsigned integer."""
        self.flist.append(val.to_bytes(4, "big"))

    def put_u64(self, val: int) -> None:
        """Append *val* as an 8-byte big-endian unsigned integer."""
        self.flist.append(val.to_bytes(8, "big"))

    def put_sshstr(self, val: bytes | _FragList) -> None:
        """Append *val* preceded by its uint32 byte length.

        *val* may be bytes-like, or another _FragList whose fragments are
        spliced in by reference.
        """
        if not isinstance(val, (bytes, memoryview, bytearray)):
            self.put_u32(val.size())
            self.flist.extend(val.flist)
            return
        self.put_u32(len(val))
        self.flist.append(val)

    def put_mpint(self, val: int) -> None:
        """Append *val* as an SSH ``mpint`` (length-prefixed big integer)."""
        self.put_sshstr(_to_mpint(val))

    def size(self) -> int:
        """Return the total number of bytes collected so far."""
        return sum(len(frag) for frag in self.flist)

    def render(self, dstbuf: memoryview, pos: int = 0) -> int:
        """Copy every fragment into *dstbuf* starting at *pos*; return the end offset."""
        for frag in self.flist:
            end = pos + len(frag)
            dstbuf[pos:end] = frag
            pos = end
        return pos

    def tobytes(self) -> bytes:
        """Return all fragments joined into one bytes object."""
        out = memoryview(bytearray(self.size()))
        self.render(out)
        return out.tobytes()

303 

304 

class _SSHFormatRSA:
    """SSH wire format for RSA keys.

    Public blob:  mpint e, mpint n
    Private blob: mpint n, e, d, iqmp, p, q
    """

    def get_public(
        self, data: memoryview
    ) -> tuple[tuple[int, int], memoryview]:
        """Parse the public fields; return ``((e, n), rest)``."""
        exponent, rest = _get_mpint(data)
        modulus, rest = _get_mpint(rest)
        return (exponent, modulus), rest

    def load_public(
        self, data: memoryview
    ) -> tuple[rsa.RSAPublicKey, memoryview]:
        """Build an RSA public key from *data*; return ``(key, rest)``."""
        (exponent, modulus), rest = self.get_public(data)
        key = rsa.RSAPublicNumbers(exponent, modulus).public_key()
        return key, rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> tuple[rsa.RSAPrivateKey, memoryview]:
        """Build an RSA private key from *data*; return ``(key, rest)``.

        *pubfields* is the ``(e, n)`` tuple from the file's public section.
        The private section must repeat the same values.
        """
        values = []
        rest = data
        for _ in range(6):
            value, rest = _get_mpint(rest)
            values.append(value)
        n, e, d, iqmp, p, q = values

        if (e, n) != pubfields:
            raise ValueError("Corrupt data: rsa field mismatch")
        # SSH stores only d, iqmp, p and q; the CRT exponents are recomputed.
        private_numbers = rsa.RSAPrivateNumbers(
            p,
            q,
            d,
            rsa.rsa_crt_dmp1(d, p),
            rsa.rsa_crt_dmq1(d, q),
            iqmp,
            rsa.RSAPublicNumbers(e, n),
        )
        return private_numbers.private_key(), rest

    def encode_public(
        self, public_key: rsa.RSAPublicKey, f_pub: _FragList
    ) -> None:
        """Append the public blob (e, n) to *f_pub*."""
        numbers = public_key.public_numbers()
        for value in (numbers.e, numbers.n):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: rsa.RSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Append the private blob (n, e, d, iqmp, p, q) to *f_priv*."""
        priv = private_key.private_numbers()
        pub = priv.public_numbers
        for value in (pub.n, pub.e, priv.d, priv.iqmp, priv.p, priv.q):
            f_priv.put_mpint(value)

376 

class _SSHFormatDSA:
    """SSH wire format for DSA keys (only 1024-bit p is accepted).

    Public blob:  mpint p, q, g, y
    Private blob: mpint p, q, g, y, x
    """

    def get_public(self, data: memoryview) -> tuple[tuple, memoryview]:
        """Parse the public fields; return ``((p, q, g, y), rest)``."""
        fields = []
        for _ in range(4):
            value, data = _get_mpint(data)
            fields.append(value)
        return tuple(fields), data

    def load_public(
        self, data: memoryview
    ) -> tuple[dsa.DSAPublicKey, memoryview]:
        """Build a DSA public key from *data*; return ``(key, rest)``."""
        (p, q, g, y), rest = self.get_public(data)
        return self._public_numbers(p, q, g, y).public_key(), rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> tuple[dsa.DSAPrivateKey, memoryview]:
        """Build a DSA private key from *data*; return ``(key, rest)``.

        *pubfields* is the ``(p, q, g, y)`` tuple from the file's public
        section. The private section must repeat the same values.
        """
        public_fields, rest = self.get_public(data)
        x, rest = _get_mpint(rest)

        if public_fields != pubfields:
            raise ValueError("Corrupt data: dsa field mismatch")
        public_numbers = self._public_numbers(*public_fields)
        return dsa.DSAPrivateNumbers(x, public_numbers).private_key(), rest

    def encode_public(
        self, public_key: dsa.DSAPublicKey, f_pub: _FragList
    ) -> None:
        """Append the public blob (p, q, g, y) to *f_pub* after validating it."""
        numbers = public_key.public_numbers()
        params = numbers.parameter_numbers
        self._validate(numbers)
        for value in (params.p, params.q, params.g, numbers.y):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: dsa.DSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Append the private blob (public fields, then x) to *f_priv*."""
        self.encode_public(private_key.public_key(), f_priv)
        f_priv.put_mpint(private_key.private_numbers().x)

    def _public_numbers(self, p, q, g, y) -> dsa.DSAPublicNumbers:
        """Assemble DSA public numbers from raw fields and validate them."""
        numbers = dsa.DSAPublicNumbers(y, dsa.DSAParameterNumbers(p, q, g))
        self._validate(numbers)
        return numbers

    def _validate(self, public_numbers: dsa.DSAPublicNumbers) -> None:
        """Raise ValueError unless the DSA modulus p is exactly 1024 bits."""
        if public_numbers.parameter_numbers.p.bit_length() != 1024:
            raise ValueError("SSH supports only 1024 bit DSA keys")

445 

446 

class _SSHFormatECDSA:
    """SSH wire format for ECDSA keys on one specific curve.

    Public blob:
        string curve       (SSH curve identifier, e.g. ``nistp256``)
        string point       (uncompressed X9.62 point)
    Private blob:
        string curve
        string point
        mpint  secret
    """

    def __init__(self, ssh_curve_name: bytes, curve: ec.EllipticCurve):
        # ssh_curve_name: identifier expected in the blob's curve field.
        # curve: cryptography curve object used to build keys.
        self.ssh_curve_name = ssh_curve_name
        self.curve = curve

    def get_public(
        self, data: memoryview
    ) -> tuple[tuple[memoryview, memoryview], memoryview]:
        """Parse the public fields; return ``((curve, point), rest)``.

        Raises:
            ValueError: if the data is truncated, names a different curve,
                or the point is empty.
            NotImplementedError: if the point is not in uncompressed form.
        """
        curve, data = _get_sshstr(data)
        point, data = _get_sshstr(data)
        if curve != self.ssh_curve_name:
            raise ValueError("Curve name mismatch")
        if not point:
            # An empty point has no format byte to inspect; report corrupt
            # input instead of letting point[0] raise IndexError.
            raise ValueError("Invalid data")
        # 0x04 is the X9.62 prefix for an uncompressed point.
        if point[0] != 4:
            raise NotImplementedError("Need uncompressed point")
        return (curve, point), data

    def load_public(
        self, data: memoryview
    ) -> tuple[ec.EllipticCurvePublicKey, memoryview]:
        """Build an ECDSA public key from *data*; return ``(key, rest)``."""
        (_, point), data = self.get_public(data)
        public_key = ec.EllipticCurvePublicKey.from_encoded_point(
            self.curve, point.tobytes()
        )
        return public_key, data

    def load_private(
        self, data: memoryview, pubfields
    ) -> tuple[ec.EllipticCurvePrivateKey, memoryview]:
        """Build an ECDSA private key from *data*; return ``(key, rest)``.

        *pubfields* is the ``(curve, point)`` tuple from the file's public
        section. The private section must repeat the same values.
        """
        (curve_name, point), data = self.get_public(data)
        secret, data = _get_mpint(data)

        if (curve_name, point) != pubfields:
            raise ValueError("Corrupt data: ecdsa field mismatch")
        private_key = ec.derive_private_key(secret, self.curve)
        return private_key, data

    def encode_public(
        self, public_key: ec.EllipticCurvePublicKey, f_pub: _FragList
    ) -> None:
        """Append the public blob (curve name, uncompressed point) to *f_pub*."""
        point = public_key.public_bytes(
            Encoding.X962, PublicFormat.UncompressedPoint
        )
        f_pub.put_sshstr(self.ssh_curve_name)
        f_pub.put_sshstr(point)

    def encode_private(
        self, private_key: ec.EllipticCurvePrivateKey, f_priv: _FragList
    ) -> None:
        """Append the private blob (public fields, then the secret) to *f_priv*."""
        public_key = private_key.public_key()
        private_numbers = private_key.private_numbers()

        self.encode_public(public_key, f_priv)
        f_priv.put_mpint(private_numbers.private_value)

516 

517 

class _SSHFormatEd25519:
    """SSH wire format for Ed25519 keys.

    Public blob:
        string point             (32-byte public key)
    Private blob:
        string point
        string secret_and_point  (32-byte seed followed by the public key)
    """

    def get_public(
        self, data: memoryview
    ) -> tuple[tuple[memoryview], memoryview]:
        """Parse the public fields; return ``((point,), rest)``."""
        point, rest = _get_sshstr(data)
        return (point,), rest

    def load_public(
        self, data: memoryview
    ) -> tuple[ed25519.Ed25519PublicKey, memoryview]:
        """Build an Ed25519 public key from *data*; return ``(key, rest)``."""
        (point,), rest = self.get_public(data)
        key = ed25519.Ed25519PublicKey.from_public_bytes(point.tobytes())
        return key, rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> tuple[ed25519.Ed25519PrivateKey, memoryview]:
        """Build an Ed25519 private key from *data*; return ``(key, rest)``.

        The public key embedded after the seed must match both the private
        section's own public field and *pubfields* from the public section.
        """
        (point,), rest = self.get_public(data)
        keypair, rest = _get_sshstr(rest)

        seed, embedded_point = keypair[:32], keypair[32:]
        if point != embedded_point or (point,) != pubfields:
            raise ValueError("Corrupt data: ed25519 field mismatch")
        return ed25519.Ed25519PrivateKey.from_private_bytes(seed), rest

    def encode_public(
        self, public_key: ed25519.Ed25519PublicKey, f_pub: _FragList
    ) -> None:
        """Append the raw 32-byte public key as an SSH string to *f_pub*."""
        f_pub.put_sshstr(public_key.public_bytes(Encoding.Raw, PublicFormat.Raw))

    def encode_private(
        self, private_key: ed25519.Ed25519PrivateKey, f_priv: _FragList
    ) -> None:
        """Append the private blob (public key, then seed + public key) to *f_priv*."""
        public_key = private_key.public_key()
        seed = private_key.private_bytes(
            Encoding.Raw, PrivateFormat.Raw, NoEncryption()
        )
        point = public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)
        keypair = _FragList([seed, point])

        self.encode_public(public_key, f_priv)
        f_priv.put_sshstr(keypair)

583 

584 

def load_application(data) -> tuple[memoryview, memoryview]:
    """Parse the U2F application string of a security-key public key.

    Args:
        data: buffer positioned at the application ``string`` field.

    Returns:
        ``(application, rest)``, where both are views into *data*.

    Raises:
        ValueError: if the field is truncated or the application does not
            start with ``b"ssh:"``.
    """
    application, data = _get_sshstr(data)
    application_bytes = application.tobytes()
    if not application_bytes.startswith(b"ssh:"):
        # Report the actual bytes. Interpolating the memoryview itself would
        # only show its object repr ("<memory at 0x...>").
        raise ValueError(
            "U2F application string does not start with b'ssh:' "
            f"({application_bytes!r})"
        )
    return application, data

596 

597 

class _SSHFormatSKEd25519:
    """
    The format of a sk-ssh-ed25519@openssh.com public key is:

        string "sk-ssh-ed25519@openssh.com"
        string public key
        string application (user-specified, but typically "ssh:")

    Only public-key loading is implemented.
    """

    def load_public(
        self, data: memoryview
    ) -> tuple[ed25519.Ed25519PublicKey, memoryview]:
        """Parse the Ed25519 key, then check and skip the application string."""
        ed25519_format = _lookup_kformat(_SSH_ED25519)
        public_key, rest = ed25519_format.load_public(data)
        _application, rest = load_application(rest)
        return public_key, rest

614 

615 

class _SSHFormatSKECDSA:
    """
    The format of a sk-ecdsa-sha2-nistp256@openssh.com public key is:

        string "sk-ecdsa-sha2-nistp256@openssh.com"
        string curve name
        ec_point Q
        string application (user-specified, but typically "ssh:")

    Only public-key loading is implemented.
    """

    def load_public(
        self, data: memoryview
    ) -> tuple[ec.EllipticCurvePublicKey, memoryview]:
        """Parse the P-256 key, then check and skip the application string."""
        p256_format = _lookup_kformat(_ECDSA_NISTP256)
        public_key, rest = p256_format.load_public(data)
        _application, rest = load_application(rest)
        return public_key, rest

633 

634 

# Wire-format handlers keyed by SSH key type. The security-key (sk-*)
# handlers implement only ``load_public``.
_KEY_FORMATS = {
    _SSH_RSA: _SSHFormatRSA(),
    _SSH_DSA: _SSHFormatDSA(),
    _SSH_ED25519: _SSHFormatEd25519(),
    # One ECDSA handler per supported NIST curve.
    **{
        key_type: _SSHFormatECDSA(ssh_curve, curve)
        for key_type, ssh_curve, curve in (
            (_ECDSA_NISTP256, b"nistp256", ec.SECP256R1()),
            (_ECDSA_NISTP384, b"nistp384", ec.SECP384R1()),
            (_ECDSA_NISTP521, b"nistp521", ec.SECP521R1()),
        )
    },
    _SK_SSH_ED25519: _SSHFormatSKEd25519(),
    _SK_SSH_ECDSA_NISTP256: _SSHFormatSKECDSA(),
}

645 

646 

def _lookup_kformat(key_type: bytes):
    """Return the wire-format handler registered for *key_type*.

    *key_type* may be any bytes-like object (for example a memoryview into
    parsed input). It is copied to ``bytes`` before the lookup.

    Raises:
        UnsupportedAlgorithm: if no handler is registered for the type.
    """
    if isinstance(key_type, bytes):
        name = key_type
    else:
        name = memoryview(key_type).tobytes()
    handler = _KEY_FORMATS.get(name)
    if handler is None:
        raise UnsupportedAlgorithm(f"Unsupported key type: {name!r}")
    return handler

654 

655 

# Private key classes accepted by the OpenSSH private-key serializer and
# returned by load_ssh_private_key.
SSHPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    dsa.DSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]

662 

663 

def load_ssh_private_key(
    data: bytes,
    password: bytes | None,
    backend: typing.Any = None,
) -> SSHPrivateKeyTypes:
    """Load a private key from OpenSSH's ``openssh-key-v1`` encoding.

    Args:
        data: PEM-armored key text ("-----BEGIN OPENSSH PRIVATE KEY-----").
        password: passphrase for an encrypted key; not used when the key is
            unencrypted.
        backend: not used by this function.

    Returns:
        The decoded private key.

    Raises:
        ValueError: if the data is not in OpenSSH private key format, is
            corrupt, contains more than one key, or is encrypted and no
            password was supplied.
        UnsupportedAlgorithm: for an unknown or security-key key type, an
            unknown cipher or KDF, or when bcrypt is unavailable.
    """
    utils._check_byteslike("data", data)
    if password is not None:
        utils._check_bytes("password", password)

    # Locate the armored body and base64-decode it.
    m = _PEM_RC.search(data)
    if not m:
        raise ValueError("Not OpenSSH private key format")
    p1 = m.start(1)
    p2 = m.end(1)
    data = binascii.a2b_base64(memoryview(data)[p1:p2])
    if not data.startswith(_SK_MAGIC):
        raise ValueError("Not OpenSSH private key format")
    data = memoryview(data)[len(_SK_MAGIC) :]

    # parse header
    ciphername, data = _get_sshstr(data)
    kdfname, data = _get_sshstr(data)
    kdfoptions, data = _get_sshstr(data)
    nkeys, data = _get_u32(data)
    if nkeys != 1:
        raise ValueError("Only one key supported")

    # load public key data
    pubdata, data = _get_sshstr(data)
    pub_key_type, pubdata = _get_sshstr(pubdata)
    if pub_key_type in (_SK_SSH_ED25519, _SK_SSH_ECDSA_NISTP256):
        # Security-key handlers only implement load_public, so reject these
        # explicitly instead of failing later with AttributeError.
        raise UnsupportedAlgorithm(
            f"Unsupported key type: {pub_key_type.tobytes()!r}"
        )
    kformat = _lookup_kformat(pub_key_type)
    pubfields, pubdata = kformat.get_public(pubdata)
    _check_empty(pubdata)

    if (ciphername, kdfname) != (_NONE, _NONE):
        ciphername_bytes = ciphername.tobytes()
        if ciphername_bytes not in _SSH_CIPHERS:
            raise UnsupportedAlgorithm(
                f"Unsupported cipher: {ciphername_bytes!r}"
            )
        if kdfname != _BCRYPT:
            # Format the bytes, not the memoryview (whose str() is only
            # "<memory at 0x...>").
            raise UnsupportedAlgorithm(
                f"Unsupported KDF: {kdfname.tobytes()!r}"
            )
        blklen = _SSH_CIPHERS[ciphername_bytes].block_len
        tag_len = _SSH_CIPHERS[ciphername_bytes].tag_len
        # load secret data
        edata, data = _get_sshstr(data)
        # see https://bugzilla.mindrot.org/show_bug.cgi?id=3553 for
        # information about how OpenSSH handles AEAD tags
        if _SSH_CIPHERS[ciphername_bytes].is_aead:
            # The tag is stored after the encrypted section, outside the
            # length-prefixed string.
            tag = bytes(data)
            if len(tag) != tag_len:
                raise ValueError("Corrupt data: invalid tag length for cipher")
        else:
            _check_empty(data)
        _check_block_size(edata, blklen)
        # KDF options: string salt, uint32 rounds.
        salt, kbuf = _get_sshstr(kdfoptions)
        rounds, kbuf = _get_u32(kbuf)
        _check_empty(kbuf)
        ciph = _init_cipher(ciphername_bytes, password, salt.tobytes(), rounds)
        dec = ciph.decryptor()
        edata = memoryview(dec.update(edata))
        if _SSH_CIPHERS[ciphername_bytes].is_aead:
            assert isinstance(dec, AEADDecryptionContext)
            _check_empty(dec.finalize_with_tag(tag))
        else:
            # _check_block_size requires data to be a full block so there
            # should be no output from finalize
            _check_empty(dec.finalize())
    else:
        # load secret data
        edata, data = _get_sshstr(data)
        _check_empty(data)
        blklen = 8
        _check_block_size(edata, blklen)
    # Two copies of the same check value; a mismatch indicates a wrong
    # passphrase or corrupted data.
    ck1, edata = _get_u32(edata)
    ck2, edata = _get_u32(edata)
    if ck1 != ck2:
        raise ValueError("Corrupt data: broken checksum")

    # load per-key struct
    key_type, edata = _get_sshstr(edata)
    if key_type != pub_key_type:
        raise ValueError("Corrupt data: key type mismatch")
    private_key, edata = kformat.load_private(edata, pubfields)
    # We don't use the comment
    _, edata = _get_sshstr(edata)

    # yes, SSH does padding check *after* all other parsing is done.
    # need to follow as it writes zero-byte padding too.
    if edata != _PADDING[: len(edata)]:
        raise ValueError("Corrupt data: invalid padding")

    if isinstance(private_key, dsa.DSAPrivateKey):
        warnings.warn(
            "SSH DSA keys are deprecated and will be removed in a future "
            "release.",
            utils.DeprecatedIn40,
            stacklevel=2,
        )

    return private_key

766 

767 

def _serialize_ssh_private_key(
    private_key: SSHPrivateKeyTypes,
    password: bytes,
    encryption_algorithm: KeySerializationEncryption,
) -> bytes:
    """Serialize a private key with OpenSSH's ``openssh-key-v1`` encoding.

    Args:
        private_key: the key to write.
        password: passphrase. Empty bytes write the key unencrypted.
            Otherwise the private section is encrypted with the default
            cipher (aes256-ctr) under a bcrypt-derived key.
        encryption_algorithm: only consulted for a custom bcrypt round
            count (``_kdf_rounds``) when it is a ``_KeySerializationEncryption``.

    Returns:
        The PEM-armored key as bytes.
    """
    utils._check_bytes("password", password)
    if isinstance(private_key, dsa.DSAPrivateKey):
        # NOTE(review): stacklevel=4 presumably skips the public
        # serialization wrapper(s) to reach the user's call site; confirm.
        warnings.warn(
            "SSH DSA key support is deprecated and will be "
            "removed in a future release",
            utils.DeprecatedIn40,
            stacklevel=4,
        )

    key_type = _get_ssh_key_type(private_key)
    kformat = _lookup_kformat(key_type)

    # setup parameters
    f_kdfoptions = _FragList()
    if password:
        ciphername = _DEFAULT_CIPHER
        blklen = _SSH_CIPHERS[ciphername].block_len
        kdfname = _BCRYPT
        rounds = _DEFAULT_ROUNDS
        if (
            isinstance(encryption_algorithm, _KeySerializationEncryption)
            and encryption_algorithm._kdf_rounds is not None
        ):
            rounds = encryption_algorithm._kdf_rounds
        # KDF options: string salt, uint32 rounds.
        salt = os.urandom(16)
        f_kdfoptions.put_sshstr(salt)
        f_kdfoptions.put_u32(rounds)
        ciph = _init_cipher(ciphername, password, salt, rounds)
    else:
        # Unencrypted: the private section is padded to 8-byte blocks.
        ciphername = kdfname = _NONE
        blklen = 8
        ciph = None
    nkeys = 1
    # Random check value, written twice; the loader compares both copies.
    checkval = os.urandom(4)
    comment = b""

    # encode public and private parts together
    f_public_key = _FragList()
    f_public_key.put_sshstr(key_type)
    kformat.encode_public(private_key.public_key(), f_public_key)

    f_secrets = _FragList([checkval, checkval])
    f_secrets.put_sshstr(key_type)
    kformat.encode_private(private_key, f_secrets)
    f_secrets.put_sshstr(comment)
    # Pad with bytes 1, 2, 3, ... up to the next block boundary; a section
    # that is already aligned gets one full extra block.
    f_secrets.put_raw(_PADDING[: blklen - (f_secrets.size() % blklen)])

    # top-level structure
    f_main = _FragList()
    f_main.put_raw(_SK_MAGIC)
    f_main.put_sshstr(ciphername)
    f_main.put_sshstr(kdfname)
    f_main.put_sshstr(f_kdfoptions)
    f_main.put_u32(nkeys)
    f_main.put_sshstr(f_public_key)
    f_main.put_sshstr(f_secrets)

    # copy result info bytearray
    # The secrets are the last field, so they occupy buf[ofs:mlen]; the
    # extra blklen bytes give update_into its required output headroom.
    slen = f_secrets.size()
    mlen = f_main.size()
    buf = memoryview(bytearray(mlen + blklen))
    f_main.render(buf)
    ofs = mlen - slen

    # encrypt in-place
    if ciph is not None:
        ciph.encryptor().update_into(buf[ofs:mlen], buf[ofs:])

    return _ssh_pem_encode(buf[:mlen])

843 

844 

# Public key classes this module can read from or write to SSH format.
SSHPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    dsa.DSAPublicKey,
    ed25519.Ed25519PublicKey,
]

# Public key classes exposed by SSH certificates (DSA is excluded).
SSHCertPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    ed25519.Ed25519PublicKey,
]


class SSHCertificateType(enum.Enum):
    """Value of the certificate-type field of an OpenSSH certificate."""

    USER = 1
    HOST = 2

862 

863 

class SSHCertificate:
    """A parsed OpenSSH certificate.

    The constructor takes already-parsed certificate fields. Byte fields
    are kept as memoryviews and converted to ``bytes`` by the accessors.
    """

    def __init__(
        self,
        _nonce: memoryview,
        _public_key: SSHPublicKeyTypes,
        _serial: int,
        _cctype: int,
        _key_id: memoryview,
        _valid_principals: list[bytes],
        _valid_after: int,
        _valid_before: int,
        _critical_options: dict[bytes, bytes],
        _extensions: dict[bytes, bytes],
        _sig_type: memoryview,
        _sig_key: memoryview,
        _inner_sig_type: memoryview,
        _signature: memoryview,
        _tbs_cert_body: memoryview,
        _cert_key_type: bytes,
        _cert_body: memoryview,
    ):
        self._nonce = _nonce
        self._public_key = _public_key
        self._serial = _serial
        try:
            self._type = SSHCertificateType(_cctype)
        except ValueError:
            # Replace the enum's internal error rather than chaining it.
            raise ValueError("Invalid certificate type") from None
        self._key_id = _key_id
        self._valid_principals = _valid_principals
        self._valid_after = _valid_after
        self._valid_before = _valid_before
        self._critical_options = _critical_options
        self._extensions = _extensions
        # Signing (CA) key: its type and its encoded public-key fields.
        self._sig_type = _sig_type
        self._sig_key = _sig_key
        # Signature algorithm name and raw signature blob.
        self._inner_sig_type = _inner_sig_type
        self._signature = _signature
        self._cert_key_type = _cert_key_type
        self._cert_body = _cert_body
        # The signed portion of the certificate body.
        self._tbs_cert_body = _tbs_cert_body

    @property
    def nonce(self) -> bytes:
        return bytes(self._nonce)

    def public_key(self) -> SSHCertPublicKeyTypes:
        # make mypy happy until we remove DSA support entirely and
        # the underlying union won't have a disallowed type
        return typing.cast(SSHCertPublicKeyTypes, self._public_key)

    @property
    def serial(self) -> int:
        return self._serial

    @property
    def type(self) -> SSHCertificateType:
        return self._type

    @property
    def key_id(self) -> bytes:
        return bytes(self._key_id)

    @property
    def valid_principals(self) -> list[bytes]:
        return self._valid_principals

    @property
    def valid_before(self) -> int:
        return self._valid_before

    @property
    def valid_after(self) -> int:
        return self._valid_after

    @property
    def critical_options(self) -> dict[bytes, bytes]:
        return self._critical_options

    @property
    def extensions(self) -> dict[bytes, bytes]:
        return self._extensions

    def signature_key(self) -> SSHCertPublicKeyTypes:
        """Decode and return the public key that signed this certificate.

        Raises:
            UnsupportedAlgorithm: for an unknown signing key type.
            ValueError: if the encoded key has trailing data.
        """
        sigformat = _lookup_kformat(self._sig_type)
        signature_key, sigkey_rest = sigformat.load_public(self._sig_key)
        _check_empty(sigkey_rest)
        return signature_key

    def public_bytes(self) -> bytes:
        """Return the certificate as a ``<key type> <base64 body>`` line (no newline)."""
        return (
            bytes(self._cert_key_type)
            + b" "
            + binascii.b2a_base64(bytes(self._cert_body), newline=False)
        )

    def verify_cert_signature(self) -> None:
        """Verify the signing key's signature over the to-be-signed body.

        Returns None on success. On a bad signature, the exception raised by
        the signing key's ``verify`` method propagates.
        """
        signature_key = self.signature_key()
        if isinstance(signature_key, ed25519.Ed25519PublicKey):
            signature_key.verify(
                bytes(self._signature), bytes(self._tbs_cert_body)
            )
        elif isinstance(signature_key, ec.EllipticCurvePublicKey):
            # The signature is encoded as a pair of big-endian integers
            r, data = _get_mpint(self._signature)
            s, data = _get_mpint(data)
            _check_empty(data)
            computed_sig = asym_utils.encode_dss_signature(r, s)
            hash_alg = _get_ec_hash_alg(signature_key.curve)
            signature_key.verify(
                computed_sig, bytes(self._tbs_cert_body), ec.ECDSA(hash_alg)
            )
        else:
            assert isinstance(signature_key, rsa.RSAPublicKey)
            # The RSA signature algorithm name selects the hash.
            if self._inner_sig_type == _SSH_RSA:
                hash_alg = hashes.SHA1()
            elif self._inner_sig_type == _SSH_RSA_SHA256:
                hash_alg = hashes.SHA256()
            else:
                assert self._inner_sig_type == _SSH_RSA_SHA512
                hash_alg = hashes.SHA512()
            signature_key.verify(
                bytes(self._signature),
                bytes(self._tbs_cert_body),
                padding.PKCS1v15(),
                hash_alg,
            )

991 

992 

def _get_ec_hash_alg(curve: ec.EllipticCurve) -> hashes.HashAlgorithm:
    """Return the hash used to verify ECDSA signatures made on *curve*.

    P-256 maps to SHA-256, P-384 to SHA-384, and P-521 to SHA-512.
    """
    if isinstance(curve, ec.SECP256R1):
        return hashes.SHA256()
    if isinstance(curve, ec.SECP384R1):
        return hashes.SHA384()
    assert isinstance(curve, ec.SECP521R1)
    return hashes.SHA512()

1001 

1002 

def _load_ssh_public_identity(
    data: bytes,
    _legacy_dsa_allowed: bool = False,
) -> SSHCertificate | SSHPublicKeyTypes:
    """Parse an OpenSSH one-line public key or certificate.

    The line must match ``_SSH_PUBKEY_RC``. Group 1 is the key type and
    group 2 is the base64-encoded blob. If the key type ends with
    ``_CERT_SUFFIX``, the blob is parsed as a certificate and an
    :class:`SSHCertificate` is returned. Otherwise the decoded public key
    itself is returned.

    Args:
        data: the bytes-like line to parse.
        _legacy_dsa_allowed: when false, a DSA key type or a DSA signature
            key type raises ``UnsupportedAlgorithm``.

    Raises:
        ValueError: on any of the following. The exact checks for
            truncated fields live in the ``_get_*`` / ``_check_empty``
            helpers.
            * a line that does not match the pattern;
            * an undecodable base64 body;
            * a blob whose embedded key type differs from the line's;
            * mismatched signature types;
            * leftover bytes.
        UnsupportedAlgorithm: for DSA material when ``_legacy_dsa_allowed``
            is false. Presumably also raised via ``_lookup_kformat`` for an
            unknown key type (not visible here; confirm).
    """
    utils._check_byteslike("data", data)

    m = _SSH_PUBKEY_RC.match(data)
    if not m:
        raise ValueError("Invalid line format")
    key_type = orig_key_type = m.group(1)
    key_body = m.group(2)
    with_cert = False
    if key_type.endswith(_CERT_SUFFIX):
        # Certificate line: the underlying key format is the type with the
        # certificate suffix stripped.
        with_cert = True
        key_type = key_type[: -len(_CERT_SUFFIX)]
    if key_type == _SSH_DSA and not _legacy_dsa_allowed:
        raise UnsupportedAlgorithm(
            "DSA keys aren't supported in SSH certificates"
        )
    kformat = _lookup_kformat(key_type)

    try:
        rest = memoryview(binascii.a2b_base64(key_body))
    except (TypeError, binascii.Error):
        raise ValueError("Invalid format")

    if with_cert:
        # Keep a view of the whole blob; the signed portion is sliced
        # from it once the signature field is reached.
        cert_body = rest
    # The blob repeats the (full) key type from the text line; they must
    # agree.
    inner_key_type, rest = _get_sshstr(rest)
    if inner_key_type != orig_key_type:
        raise ValueError("Invalid key format")
    if with_cert:
        nonce, rest = _get_sshstr(rest)
    public_key, rest = kformat.load_public(rest)
    if with_cert:
        serial, rest = _get_u64(rest)
        cctype, rest = _get_u32(rest)
        key_id, rest = _get_sshstr(rest)
        principals, rest = _get_sshstr(rest)
        # The principals field is itself a packed sequence of SSH strings.
        valid_principals = []
        while principals:
            principal, principals = _get_sshstr(principals)
            valid_principals.append(bytes(principal))
        valid_after, rest = _get_u64(rest)
        valid_before, rest = _get_u64(rest)
        crit_options, rest = _get_sshstr(rest)
        critical_options = _parse_exts_opts(crit_options)
        exts, rest = _get_sshstr(rest)
        extensions = _parse_exts_opts(exts)
        # Get the reserved field, which is unused.
        _, rest = _get_sshstr(rest)
        # The CA ("signature") key is an SSH string holding its own
        # type-prefixed key encoding.
        sig_key_raw, rest = _get_sshstr(rest)
        sig_type, sig_key = _get_sshstr(sig_key_raw)
        if sig_type == _SSH_DSA and not _legacy_dsa_allowed:
            raise UnsupportedAlgorithm(
                "DSA signatures aren't supported in SSH certificates"
            )
        # Get the entire cert body and subtract the signature
        # NOTE(review): ``[: -len(rest)]`` yields an EMPTY slice if ``rest``
        # is empty (the ``-0`` pitfall). That case presumably fails in the
        # ``_get_sshstr(rest)`` call just below, but
        # ``cert_body[: len(cert_body) - len(rest)]`` would be robust.
        tbs_cert_body = cert_body[: -len(rest)]
        signature_raw, rest = _get_sshstr(rest)
        _check_empty(rest)
        inner_sig_type, sig_rest = _get_sshstr(signature_raw)
        # RSA certs can have multiple algorithm types: an ssh-rsa CA key may
        # sign with _SSH_RSA, _SSH_RSA_SHA256 or _SSH_RSA_SHA512. For any
        # other CA key type the inner signature type must equal it.
        if (
            sig_type == _SSH_RSA
            and inner_sig_type
            not in [_SSH_RSA_SHA256, _SSH_RSA_SHA512, _SSH_RSA]
        ) or (sig_type != _SSH_RSA and inner_sig_type != sig_type):
            raise ValueError("Signature key type does not match")
        signature, sig_rest = _get_sshstr(sig_rest)
        _check_empty(sig_rest)
        return SSHCertificate(
            nonce,
            public_key,
            serial,
            cctype,
            key_id,
            valid_principals,
            valid_after,
            valid_before,
            critical_options,
            extensions,
            sig_type,
            sig_key,
            inner_sig_type,
            signature,
            tbs_cert_body,
            orig_key_type,
            cert_body,
        )
    else:
        # Plain public key: no bytes may follow the encoded key.
        _check_empty(rest)
        return public_key

1096 

1097 

def load_ssh_public_identity(
    data: bytes,
) -> SSHCertificate | SSHPublicKeyTypes:
    """Load an OpenSSH public key or certificate from one text line.

    Returns an :class:`SSHCertificate` for certificate lines and the bare
    public key otherwise. Legacy DSA handling is switched off, so DSA
    certificate material raises ``UnsupportedAlgorithm``.
    """
    return _load_ssh_public_identity(data, _legacy_dsa_allowed=False)

1102 

1103 

def _parse_exts_opts(exts_opts: memoryview) -> dict[bytes, bytes]:
    """Decode a certificate critical-options or extensions field.

    The field is a run of (name, data) SSH-string pairs. Names must be
    unique and appear in ascending byte order. Non-empty data must contain
    exactly one nested SSH string, which becomes the value. Empty data maps
    to ``b""``.

    Raises:
        ValueError: on a repeated name, out-of-order names, or bytes left
            over after a nested value.
    """
    parsed: dict[bytes, bytes] = {}
    previous: bytes | None = None
    remaining = exts_opts
    while remaining:
        raw_name, remaining = _get_sshstr(remaining)
        key = bytes(raw_name)
        if key in parsed:
            raise ValueError("Duplicate name")
        if previous is not None and key < previous:
            raise ValueError("Fields not lexically sorted")
        payload, remaining = _get_sshstr(remaining)
        if len(payload) > 0:
            # Non-empty data wraps the real value in one more SSH string.
            payload, leftover = _get_sshstr(payload)
            if len(leftover) > 0:
                raise ValueError("Unexpected extra data after value")
        parsed[key] = bytes(payload)
        previous = key
    return parsed

1122 

1123 

def load_ssh_public_key(
    data: bytes, backend: typing.Any = None
) -> SSHPublicKeyTypes:
    """Load an OpenSSH public key from its one-line text form.

    Certificate lines are accepted too; only the certified public key is
    returned from them. DSA keys are allowed but emit a ``DeprecatedIn40``
    warning. ``backend`` is unused.
    """
    loaded = _load_ssh_public_identity(data, _legacy_dsa_allowed=True)
    public_key: SSHPublicKeyTypes = (
        loaded.public_key() if isinstance(loaded, SSHCertificate) else loaded
    )

    if isinstance(public_key, dsa.DSAPublicKey):
        warnings.warn(
            "SSH DSA keys are deprecated and will be removed in a future "
            "release.",
            utils.DeprecatedIn40,
            stacklevel=2,
        )
    return public_key

1142 

1143 

def serialize_ssh_public_key(public_key: SSHPublicKeyTypes) -> bytes:
    """Encode ``public_key`` as a single OpenSSH public-key line.

    The result is ``<key type> <base64 blob>``. The blob is the SSH-string
    key type followed by the key's format-specific public encoding. DSA
    keys emit a ``DeprecatedIn40`` warning before being encoded.
    """
    if isinstance(public_key, dsa.DSAPublicKey):
        warnings.warn(
            "SSH DSA key support is deprecated and will be "
            "removed in a future release",
            utils.DeprecatedIn40,
            stacklevel=4,
        )
    key_type = _get_ssh_key_type(public_key)
    key_format = _lookup_kformat(key_type)

    blob = _FragList()
    blob.put_sshstr(key_type)
    key_format.encode_public(public_key, blob)

    # newline=False yields the same bytes as stripping the trailing "\n".
    encoded_blob = binascii.b2a_base64(blob.tobytes(), newline=False)
    return b" ".join((key_type, encoded_blob))

1162 

1163 

# Private-key classes accepted as the CA (signing) key by
# SSHCertificateBuilder.sign, which checks these same three classes with
# isinstance. DSA private keys are not included.
SSHCertPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]

1169 

1170 

# Maximum number of principals a certificate may list. This is an
# undocumented limit enforced in the OpenSSH codebase for sshd and
# ssh-keygen; the SSH certificate specification itself does not define one.
# SSHCertificateBuilder.valid_principals rejects lists longer than this.
_SSHKEY_CERT_MAX_PRINCIPALS = 256

1174 

1175 

class SSHCertificateBuilder:
    """Builder for OpenSSH certificates.

    Every setter validates its argument and returns a *new* builder; the
    receiver is never modified, so partially configured builders can be
    reused safely. Call :meth:`sign` with the CA private key to produce an
    :class:`SSHCertificate`.
    """

    def __init__(
        self,
        _public_key: SSHCertPublicKeyTypes | None = None,
        _serial: int | None = None,
        _type: SSHCertificateType | None = None,
        _key_id: bytes | None = None,
        _valid_principals: list[bytes] | None = None,
        _valid_for_all_principals: bool = False,
        _valid_before: int | None = None,
        _valid_after: int | None = None,
        _critical_options: list[tuple[bytes, bytes]] | None = None,
        _extensions: list[tuple[bytes, bytes]] | None = None,
    ):
        self._public_key = _public_key
        self._serial = _serial
        self._type = _type
        self._key_id = _key_id
        # List arguments are copied, and default to a fresh list instead of
        # a shared mutable default. No two builders, and no caller, can
        # then alias or mutate each other's state after validation.
        self._valid_principals = (
            [] if _valid_principals is None else list(_valid_principals)
        )
        self._valid_for_all_principals = _valid_for_all_principals
        self._valid_before = _valid_before
        self._valid_after = _valid_after
        self._critical_options = (
            [] if _critical_options is None else list(_critical_options)
        )
        self._extensions = [] if _extensions is None else list(_extensions)

    def _copy_with(self, **changes: typing.Any) -> SSHCertificateBuilder:
        """Return a new builder with this one's state, overridden by ``changes``.

        Keys in ``changes`` are ``__init__`` keyword names (e.g.
        ``_serial=...``); an unknown key raises ``TypeError``.
        """
        state: dict[str, typing.Any] = {
            "_public_key": self._public_key,
            "_serial": self._serial,
            "_type": self._type,
            "_key_id": self._key_id,
            "_valid_principals": self._valid_principals,
            "_valid_for_all_principals": self._valid_for_all_principals,
            "_valid_before": self._valid_before,
            "_valid_after": self._valid_after,
            "_critical_options": self._critical_options,
            "_extensions": self._extensions,
        }
        state.update(changes)
        return SSHCertificateBuilder(**state)

    def public_key(
        self, public_key: SSHCertPublicKeyTypes
    ) -> SSHCertificateBuilder:
        """Set the subject public key (EC, RSA, or Ed25519) to certify.

        Raises TypeError for other key types and ValueError if already set.
        """
        if not isinstance(
            public_key,
            (
                ec.EllipticCurvePublicKey,
                rsa.RSAPublicKey,
                ed25519.Ed25519PublicKey,
            ),
        ):
            raise TypeError("Unsupported key type")
        if self._public_key is not None:
            raise ValueError("public_key already set")

        return self._copy_with(_public_key=public_key)

    def serial(self, serial: int) -> SSHCertificateBuilder:
        """Set the certificate serial number, an int in [0, 2**64).

        Raises TypeError for non-ints and ValueError if out of range or
        already set.
        """
        if not isinstance(serial, int):
            raise TypeError("serial must be an integer")
        if not 0 <= serial < 2**64:
            raise ValueError("serial must be between 0 and 2**64")
        if self._serial is not None:
            raise ValueError("serial already set")

        return self._copy_with(_serial=serial)

    def type(self, type: SSHCertificateType) -> SSHCertificateBuilder:
        """Set the certificate type; written via its ``.value`` on signing.

        Raises TypeError if not an SSHCertificateType and ValueError if
        already set.
        """
        if not isinstance(type, SSHCertificateType):
            raise TypeError("type must be an SSHCertificateType")
        if self._type is not None:
            raise ValueError("type already set")

        return self._copy_with(_type=type)

    def key_id(self, key_id: bytes) -> SSHCertificateBuilder:
        """Set the key identifier bytes.

        Raises TypeError if not bytes and ValueError if already set.
        """
        if not isinstance(key_id, bytes):
            raise TypeError("key_id must be bytes")
        if self._key_id is not None:
            raise ValueError("key_id already set")

        return self._copy_with(_key_id=key_id)

    def valid_principals(
        self, valid_principals: list[bytes]
    ) -> SSHCertificateBuilder:
        """Restrict the certificate to a non-empty list of principals.

        The list is copied, so later changes by the caller do not affect
        the builder.

        Raises:
            ValueError: valid_for_all_principals was already set, principals
                were already set, or the list exceeds
                ``_SSHKEY_CERT_MAX_PRINCIPALS`` entries.
            TypeError: the list is empty or contains non-bytes items.
        """
        if self._valid_for_all_principals:
            raise ValueError(
                "Principals can't be set because the cert is valid "
                "for all principals"
            )
        if (
            not all(isinstance(x, bytes) for x in valid_principals)
            or not valid_principals
        ):
            raise TypeError(
                "principals must be a list of bytes and can't be empty"
            )
        if self._valid_principals:
            raise ValueError("valid_principals already set")

        if len(valid_principals) > _SSHKEY_CERT_MAX_PRINCIPALS:
            raise ValueError(
                "Reached or exceeded the maximum number of valid_principals"
            )

        return self._copy_with(_valid_principals=valid_principals)

    def valid_for_all_principals(self) -> SSHCertificateBuilder:
        """Mark the certificate valid for any principal (empty list on the wire).

        Raises ValueError if principals or this flag were already set.
        """
        if self._valid_principals:
            raise ValueError(
                "valid_principals already set, can't set "
                "valid_for_all_principals"
            )
        if self._valid_for_all_principals:
            raise ValueError("valid_for_all_principals already set")

        return self._copy_with(_valid_for_all_principals=True)

    def valid_before(self, valid_before: int | float) -> SSHCertificateBuilder:
        """Set the valid_before timestamp (floats are truncated with int()).

        The value is written as a uint64, so it must lie in [0, 2**64).
        Raises TypeError for non-numbers and ValueError if out of range or
        already set.
        """
        if not isinstance(valid_before, (int, float)):
            raise TypeError("valid_before must be an int or float")
        valid_before = int(valid_before)
        if valid_before < 0 or valid_before >= 2**64:
            raise ValueError("valid_before must [0, 2**64)")
        if self._valid_before is not None:
            raise ValueError("valid_before already set")

        return self._copy_with(_valid_before=valid_before)

    def valid_after(self, valid_after: int | float) -> SSHCertificateBuilder:
        """Set the valid_after timestamp (floats are truncated with int()).

        The value is written as a uint64, so it must lie in [0, 2**64).
        Raises TypeError for non-numbers and ValueError if out of range or
        already set.
        """
        if not isinstance(valid_after, (int, float)):
            raise TypeError("valid_after must be an int or float")
        valid_after = int(valid_after)
        if valid_after < 0 or valid_after >= 2**64:
            raise ValueError("valid_after must [0, 2**64)")
        if self._valid_after is not None:
            raise ValueError("valid_after already set")

        return self._copy_with(_valid_after=valid_after)

    def add_critical_option(
        self, name: bytes, value: bytes
    ) -> SSHCertificateBuilder:
        """Append a critical option; names must be unique bytes.

        Raises TypeError for non-bytes and ValueError for a duplicate name.
        """
        if not isinstance(name, bytes) or not isinstance(value, bytes):
            raise TypeError("name and value must be bytes")
        # Linear scan per add, so O(n**2) over many additions.
        if any(existing == name for existing, _ in self._critical_options):
            raise ValueError("Duplicate critical option name")

        return self._copy_with(
            _critical_options=[*self._critical_options, (name, value)]
        )

    def add_extension(
        self, name: bytes, value: bytes
    ) -> SSHCertificateBuilder:
        """Append an extension; names must be unique bytes.

        Raises TypeError for non-bytes and ValueError for a duplicate name.
        """
        if not isinstance(name, bytes) or not isinstance(value, bytes):
            raise TypeError("name and value must be bytes")
        # Linear scan per add, so O(n**2) over many additions.
        if any(existing == name for existing, _ in self._extensions):
            raise ValueError("Duplicate extension name")

        return self._copy_with(
            _extensions=[*self._extensions, (name, value)]
        )

    @staticmethod
    def _encode_exts_opts(pairs: list[tuple[bytes, bytes]]) -> bytes:
        """Serialize (name, value) pairs in the critical-options/extensions layout.

        Each name is written as an SSH string. A non-empty value is first
        wrapped in its own SSH string and that wrapper is written as an SSH
        string; an empty value is written as an empty SSH string.
        """
        encoded = _FragList()
        for name, value in pairs:
            encoded.put_sshstr(name)
            if len(value) > 0:
                wrapped = _FragList()
                wrapped.put_sshstr(value)
                encoded.put_sshstr(wrapped.tobytes())
            else:
                encoded.put_sshstr(value)
        return encoded.tobytes()

    def sign(self, private_key: SSHCertPrivateKeyTypes) -> SSHCertificate:
        """Encode the configured fields and sign them with a CA private key.

        Required fields:
          * public_key;
          * type;
          * valid_before and valid_after;
          * either valid_principals or valid_for_all_principals.

        serial defaults to 0 and key_id to b"". Critical options and
        extensions are written sorted by name. The builder itself is not
        modified. The signed blob is re-parsed with
        ``load_ssh_public_identity`` and returned.

        Raises:
            TypeError: private_key is not an EC, RSA, or Ed25519 private key.
            ValueError: a required field is missing, or valid_after is later
                than valid_before.
        """
        if not isinstance(
            private_key,
            (
                ec.EllipticCurvePrivateKey,
                rsa.RSAPrivateKey,
                ed25519.Ed25519PrivateKey,
            ),
        ):
            raise TypeError("Unsupported private key type")

        if self._public_key is None:
            raise ValueError("public_key must be set")

        # Not required
        serial = 0 if self._serial is None else self._serial

        if self._type is None:
            raise ValueError("type must be set")

        # Not required
        key_id = b"" if self._key_id is None else self._key_id

        # A zero length list is valid, but means the certificate
        # is valid for any principal of the specified type. We require
        # the user to explicitly set valid_for_all_principals to get
        # that behavior.
        if not self._valid_principals and not self._valid_for_all_principals:
            raise ValueError(
                "valid_principals must be set if valid_for_all_principals "
                "is False"
            )

        if self._valid_before is None:
            raise ValueError("valid_before must be set")

        if self._valid_after is None:
            raise ValueError("valid_after must be set")

        if self._valid_after > self._valid_before:
            raise ValueError("valid_after must be earlier than valid_before")

        # Lexically sort by name into new lists. Signing must not mutate
        # the builder's own state.
        critical_options = sorted(self._critical_options, key=lambda x: x[0])
        extensions = sorted(self._extensions, key=lambda x: x[0])

        key_type = _get_ssh_key_type(self._public_key)
        cert_prefix = key_type + _CERT_SUFFIX

        # Marshal the bytes to be signed
        nonce = os.urandom(32)
        kformat = _lookup_kformat(key_type)
        f = _FragList()
        f.put_sshstr(cert_prefix)
        f.put_sshstr(nonce)
        kformat.encode_public(self._public_key, f)
        f.put_u64(serial)
        f.put_u32(self._type.value)
        f.put_sshstr(key_id)
        fprincipals = _FragList()
        for p in self._valid_principals:
            fprincipals.put_sshstr(p)
        f.put_sshstr(fprincipals.tobytes())
        f.put_u64(self._valid_after)
        f.put_u64(self._valid_before)
        f.put_sshstr(self._encode_exts_opts(critical_options))
        f.put_sshstr(self._encode_exts_opts(extensions))
        f.put_sshstr(b"")  # RESERVED FIELD
        # encode CA public key
        ca_type = _get_ssh_key_type(private_key)
        caformat = _lookup_kformat(ca_type)
        caf = _FragList()
        caf.put_sshstr(ca_type)
        caformat.encode_public(private_key.public_key(), caf)
        f.put_sshstr(caf.tobytes())
        # Sigs according to the rules defined for the CA's public key
        # (RFC4253 section 6.6 for ssh-rsa, RFC5656 for ECDSA,
        # and RFC8032 for Ed25519).
        if isinstance(private_key, ed25519.Ed25519PrivateKey):
            signature = private_key.sign(f.tobytes())
            fsig = _FragList()
            fsig.put_sshstr(ca_type)
            fsig.put_sshstr(signature)
            f.put_sshstr(fsig.tobytes())
        elif isinstance(private_key, ec.EllipticCurvePrivateKey):
            hash_alg = _get_ec_hash_alg(private_key.curve)
            signature = private_key.sign(f.tobytes(), ec.ECDSA(hash_alg))
            # SSH carries ECDSA signatures as two mpints (r, s) rather than
            # the DER structure returned by sign().
            r, s = asym_utils.decode_dss_signature(signature)
            fsig = _FragList()
            fsig.put_sshstr(ca_type)
            fsigblob = _FragList()
            fsigblob.put_mpint(r)
            fsigblob.put_mpint(s)
            fsig.put_sshstr(fsigblob.tobytes())
            f.put_sshstr(fsig.tobytes())

        else:
            assert isinstance(private_key, rsa.RSAPrivateKey)
            # Just like Golang, we're going to use SHA512 for RSA
            # https://cs.opensource.google/go/x/crypto/+/refs/tags/
            # v0.4.0:ssh/certs.go;l=445
            # RFC 8332 defines SHA256 and 512 as options
            fsig = _FragList()
            fsig.put_sshstr(_SSH_RSA_SHA512)
            signature = private_key.sign(
                f.tobytes(), padding.PKCS1v15(), hashes.SHA512()
            )
            fsig.put_sshstr(signature)
            f.put_sshstr(fsig.tobytes())

        cert_data = binascii.b2a_base64(f.tobytes()).strip()
        # load_ssh_public_identity returns a union, but this is
        # guaranteed to be an SSHCertificate, so we cast to make
        # mypy happy.
        return typing.cast(
            SSHCertificate,
            load_ssh_public_identity(b"".join([cert_prefix, b" ", cert_data])),
        )