Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/cryptography/hazmat/primitives/serialization/ssh.py: 21%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

758 statements  

1# This file is dual licensed under the terms of the Apache License, Version 

2# 2.0, and the BSD License. See the LICENSE file in the root of this repository 

3# for complete details. 

4 

5from __future__ import annotations 

6 

7import binascii 

8import enum 

9import os 

10import re 

11import typing 

12import warnings 

13from base64 import encodebytes as _base64_encode 

14from dataclasses import dataclass 

15 

16from cryptography import utils 

17from cryptography.exceptions import UnsupportedAlgorithm 

18from cryptography.hazmat.primitives import hashes 

19from cryptography.hazmat.primitives.asymmetric import ( 

20 dsa, 

21 ec, 

22 ed25519, 

23 padding, 

24 rsa, 

25) 

26from cryptography.hazmat.primitives.asymmetric import utils as asym_utils 

27from cryptography.hazmat.primitives.ciphers import ( 

28 AEADDecryptionContext, 

29 Cipher, 

30 algorithms, 

31 modes, 

32) 

33from cryptography.hazmat.primitives.serialization import ( 

34 Encoding, 

35 KeySerializationEncryption, 

36 NoEncryption, 

37 PrivateFormat, 

38 PublicFormat, 

39 _KeySerializationEncryption, 

40) 

41 

# bcrypt is optional: it provides the bcrypt-pbkdf KDF needed to derive
# keys for password-protected OpenSSH private keys.
try:
    from bcrypt import kdf as _bcrypt_kdf

    _bcrypt_supported = True
except ImportError:
    _bcrypt_supported = False

    def _bcrypt_kdf(
        password: bytes,
        salt: bytes,
        desired_key_bytes: int,
        rounds: int,
        ignore_few_rounds: bool = False,
    ) -> bytes:
        """Fallback used when bcrypt is not installed; always raises.

        Mirrors the signature of ``bcrypt.kdf`` so callers need no special
        casing. Any attempt to load or write an encrypted key then fails
        with UnsupportedAlgorithm.
        """
        raise UnsupportedAlgorithm("Need bcrypt module")

57 

58 

# SSH public key type names
_SSH_ED25519 = b"ssh-ed25519"
_SSH_RSA = b"ssh-rsa"
_SSH_DSA = b"ssh-dss"
_ECDSA_NISTP256 = b"ecdsa-sha2-nistp256"
_ECDSA_NISTP384 = b"ecdsa-sha2-nistp384"
_ECDSA_NISTP521 = b"ecdsa-sha2-nistp521"
# appended to a key type name to denote an OpenSSH certificate
_CERT_SUFFIX = b"-cert-v01@openssh.com"

# These are not key types, only algorithms, so they cannot appear
# as a public key type
_SSH_RSA_SHA256 = b"rsa-sha2-256"
_SSH_RSA_SHA512 = b"rsa-sha2-512"

# "<key type><space/tab><base64 body>" at the start of a public-key line;
# anything after the body (e.g. a comment) is not matched
_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)")
# magic prefix of the decoded private-key blob
_SK_MAGIC = b"openssh-key-v1\0"
# armor lines around the base64 private-key blob
_SK_START = b"-----BEGIN OPENSSH PRIVATE KEY-----"
_SK_END = b"-----END OPENSSH PRIVATE KEY-----"
# KDF / cipher names stored in the private-key header
_BCRYPT = b"bcrypt"
_NONE = b"none"
# cipher and bcrypt round count used when serializing with a password
# (the round count may be overridden by the encryption algorithm)
_DEFAULT_CIPHER = b"aes256-ctr"
_DEFAULT_ROUNDS = 16

# re works directly on bytes-like data; captures the base64 body between
# the armor lines, across newlines (DOTALL), non-greedily
_PEM_RC = re.compile(_SK_START + b"(.*?)" + _SK_END, re.DOTALL)

# padding for max blocksize: the bytes 1, 2, ..., 16. The writer appends a
# prefix of this and the loader checks trailing bytes against a prefix.
_PADDING = memoryview(bytearray(range(1, 1 + 16)))

86 

87 

@dataclass
class _SSHCipher:
    """Parameters of a symmetric cipher used to wrap OpenSSH private keys."""

    # cipher algorithm class (instantiated with the derived key)
    alg: type[algorithms.AES]
    # key length in bytes; the first key_len bytes of the KDF output
    key_len: int
    # mode class (instantiated with the derived IV/nonce)
    mode: type[modes.CTR] | type[modes.CBC] | type[modes.GCM]
    # encrypted private section must be a multiple of this many bytes
    block_len: int
    # IV/nonce length in bytes; taken from the KDF output after the key
    iv_len: int
    # authentication tag length in bytes, or None for non-AEAD modes
    tag_len: int | None
    # True when the mode is AEAD; the tag then follows the encrypted blob
    is_aead: bool

97 

98 

# ciphers that are actually used in key wrapping, keyed by SSH cipher name
_SSH_CIPHERS: dict[bytes, _SSHCipher] = {
    b"aes256-ctr": _SSHCipher(
        alg=algorithms.AES,
        key_len=32,
        mode=modes.CTR,
        block_len=16,
        iv_len=16,
        tag_len=None,
        is_aead=False,
    ),
    b"aes256-cbc": _SSHCipher(
        alg=algorithms.AES,
        key_len=32,
        mode=modes.CBC,
        block_len=16,
        iv_len=16,
        tag_len=None,
        is_aead=False,
    ),
    # AEAD: 12-byte nonce; the 16-byte tag is stored after the encrypted
    # private section rather than inside it
    b"aes256-gcm@openssh.com": _SSHCipher(
        alg=algorithms.AES,
        key_len=32,
        mode=modes.GCM,
        block_len=16,
        iv_len=12,
        tag_len=16,
        is_aead=True,
    ),
}

# map local (cryptography) curve name to SSH key type
_ECDSA_KEY_TYPE = {
    "secp256r1": _ECDSA_NISTP256,
    "secp384r1": _ECDSA_NISTP384,
    "secp521r1": _ECDSA_NISTP521,
}

136 

137 

def _get_ssh_key_type(key: SSHPrivateKeyTypes | SSHPublicKeyTypes) -> bytes:
    """Return the SSH key-type name for a supported key object.

    Works for both private and public keys. ECDSA keys are mapped by
    curve. Any other key class raises ValueError.
    """
    if isinstance(key, ec.EllipticCurvePrivateKey):
        return _ecdsa_key_type(key.public_key())
    if isinstance(key, ec.EllipticCurvePublicKey):
        return _ecdsa_key_type(key)
    if isinstance(key, (rsa.RSAPrivateKey, rsa.RSAPublicKey)):
        return _SSH_RSA
    if isinstance(key, (dsa.DSAPrivateKey, dsa.DSAPublicKey)):
        return _SSH_DSA
    ed25519_classes = (ed25519.Ed25519PrivateKey, ed25519.Ed25519PublicKey)
    if isinstance(key, ed25519_classes):
        return _SSH_ED25519
    raise ValueError("Unsupported key type")

155 

156 

def _ecdsa_key_type(public_key: ec.EllipticCurvePublicKey) -> bytes:
    """Return the SSH key-type name for an ECDSA public key.

    Raises ValueError when the key's curve has no SSH key-type mapping.
    """
    curve_name = public_key.curve.name
    key_type = _ECDSA_KEY_TYPE.get(curve_name)
    if key_type is None:
        raise ValueError(
            f"Unsupported curve for ssh private key: {curve_name!r}"
        )
    return key_type

165 

166 

def _ssh_pem_encode(
    data: bytes,
    prefix: bytes = _SK_START + b"\n",
    suffix: bytes = _SK_END + b"\n",
) -> bytes:
    """Base64-encode ``data`` and wrap it between armor lines.

    The base64 text comes from ``base64.encodebytes``, which emits
    newline-terminated lines, so the body ends with a newline before
    ``suffix``.
    """
    body = _base64_encode(data)
    return b"".join((prefix, body, suffix))

173 

174 

175def _check_block_size(data: bytes, block_len: int) -> None: 

176 """Require data to be full blocks""" 

177 if not data or len(data) % block_len != 0: 

178 raise ValueError("Corrupt data: missing padding") 

179 

180 

181def _check_empty(data: bytes) -> None: 

182 """All data should have been parsed.""" 

183 if data: 

184 raise ValueError("Corrupt data: unparsed data") 

185 

186 

def _init_cipher(
    ciphername: bytes,
    password: bytes | None,
    salt: bytes,
    rounds: int,
) -> Cipher[modes.CBC | modes.CTR | modes.GCM]:
    """Derive key material with bcrypt-pbkdf and build the wrapping cipher.

    The KDF output is split into the cipher key followed by the IV/nonce.
    Raises ValueError when ``password`` is None or empty, and KeyError for
    a cipher name missing from ``_SSH_CIPHERS``.
    """
    if not password:
        raise ValueError("Key is password-protected.")

    params = _SSH_CIPHERS[ciphername]
    key_len = params.key_len
    # final positional argument is ignore_few_rounds=True
    derived = _bcrypt_kdf(
        password, salt, key_len + params.iv_len, rounds, True
    )
    key, iv = derived[:key_len], derived[key_len:]
    return Cipher(params.alg(key), params.mode(iv))

205 

206 

207def _get_u32(data: memoryview) -> tuple[int, memoryview]: 

208 """Uint32""" 

209 if len(data) < 4: 

210 raise ValueError("Invalid data") 

211 return int.from_bytes(data[:4], byteorder="big"), data[4:] 

212 

213 

214def _get_u64(data: memoryview) -> tuple[int, memoryview]: 

215 """Uint64""" 

216 if len(data) < 8: 

217 raise ValueError("Invalid data") 

218 return int.from_bytes(data[:8], byteorder="big"), data[8:] 

219 

220 

def _get_sshstr(data: memoryview) -> tuple[memoryview, memoryview]:
    """Split a uint32-length-prefixed byte string off the front of ``data``.

    Returns ``(payload, remainder)``. Raises ValueError if the declared
    length exceeds the bytes that follow the prefix.
    """
    length, body = _get_u32(data)
    if len(body) < length:
        raise ValueError("Invalid data")
    return body[:length], body[length:]

227 

228 

def _get_mpint(data: memoryview) -> tuple[int, memoryview]:
    """Parse a length-prefixed big-endian integer off the front of ``data``.

    Only non-negative values are accepted. A payload whose first byte has
    the top bit set is rejected with ValueError. An empty payload is 0.
    """
    raw, rest = _get_sshstr(data)
    if len(raw) > 0 and raw[0] & 0x80:
        raise ValueError("Invalid data")
    return int.from_bytes(raw, "big"), rest

235 

236 

def _to_mpint(val: int) -> bytes:
    """Encode a non-negative integer as an mpint payload (no length prefix).

    Zero encodes as b"". The byte count reserves one extra bit, so a value
    whose top bit would otherwise be set gains a leading 0x00 byte and is
    not read back as negative. Raises ValueError for negative input.
    """
    if val < 0:
        raise ValueError("negative mpint not allowed")
    if val == 0:
        return b""
    return utils.int_to_bytes(val, val.bit_length() // 8 + 1)

245 

246 

247class _FragList: 

248 """Build recursive structure without data copy.""" 

249 

250 flist: list[bytes] 

251 

252 def __init__(self, init: list[bytes] | None = None) -> None: 

253 self.flist = [] 

254 if init: 

255 self.flist.extend(init) 

256 

257 def put_raw(self, val: bytes) -> None: 

258 """Add plain bytes""" 

259 self.flist.append(val) 

260 

261 def put_u32(self, val: int) -> None: 

262 """Big-endian uint32""" 

263 self.flist.append(val.to_bytes(length=4, byteorder="big")) 

264 

265 def put_u64(self, val: int) -> None: 

266 """Big-endian uint64""" 

267 self.flist.append(val.to_bytes(length=8, byteorder="big")) 

268 

269 def put_sshstr(self, val: bytes | _FragList) -> None: 

270 """Bytes prefixed with u32 length""" 

271 if isinstance(val, (bytes, memoryview, bytearray)): 

272 self.put_u32(len(val)) 

273 self.flist.append(val) 

274 else: 

275 self.put_u32(val.size()) 

276 self.flist.extend(val.flist) 

277 

278 def put_mpint(self, val: int) -> None: 

279 """Big-endian bigint prefixed with u32 length""" 

280 self.put_sshstr(_to_mpint(val)) 

281 

282 def size(self) -> int: 

283 """Current number of bytes""" 

284 return sum(map(len, self.flist)) 

285 

286 def render(self, dstbuf: memoryview, pos: int = 0) -> int: 

287 """Write into bytearray""" 

288 for frag in self.flist: 

289 flen = len(frag) 

290 start, pos = pos, pos + flen 

291 dstbuf[start:pos] = frag 

292 return pos 

293 

294 def tobytes(self) -> bytes: 

295 """Return as bytes""" 

296 buf = memoryview(bytearray(self.size())) 

297 self.render(buf) 

298 return buf.tobytes() 

299 

300 

class _SSHFormatRSA:
    """Wire format for RSA keys.

    Public:
        mpint e, n
    Private:
        mpint n, e, d, iqmp, p, q
    """

    def get_public(self, data: memoryview) -> tuple[tuple, memoryview]:
        """Read the public fields ``(e, n)``; return them and the rest."""
        e, rest = _get_mpint(data)
        n, rest = _get_mpint(rest)
        return (e, n), rest

    def load_public(
        self, data: memoryview
    ) -> tuple[rsa.RSAPublicKey, memoryview]:
        """Build an RSA public key from ``data``; return it and the rest."""
        (e, n), rest = self.get_public(data)
        return rsa.RSAPublicNumbers(e, n).public_key(), rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> tuple[rsa.RSAPrivateKey, memoryview]:
        """Build an RSA private key from ``data``; return it and the rest.

        ``pubfields`` is the ``(e, n)`` pair from the public section. A
        mismatch with the private section raises ValueError. The CRT
        exponents are recomputed from d, p and q.
        """
        n, rest = _get_mpint(data)
        e, rest = _get_mpint(rest)
        d, rest = _get_mpint(rest)
        iqmp, rest = _get_mpint(rest)
        p, rest = _get_mpint(rest)
        q, rest = _get_mpint(rest)

        if pubfields != (e, n):
            raise ValueError("Corrupt data: rsa field mismatch")
        dmp1 = rsa.rsa_crt_dmp1(d, p)
        dmq1 = rsa.rsa_crt_dmq1(d, q)
        public_numbers = rsa.RSAPublicNumbers(e, n)
        private_numbers = rsa.RSAPrivateNumbers(
            p, q, d, dmp1, dmq1, iqmp, public_numbers
        )
        return private_numbers.private_key(), rest

    def encode_public(
        self, public_key: rsa.RSAPublicKey, f_pub: _FragList
    ) -> None:
        """Append the public fields ``e, n`` to ``f_pub``."""
        numbers = public_key.public_numbers()
        for value in (numbers.e, numbers.n):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: rsa.RSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Append ``n, e, d, iqmp, p, q`` to ``f_priv``."""
        numbers = private_key.private_numbers()
        pub = numbers.public_numbers
        ordered = (pub.n, pub.e, numbers.d, numbers.iqmp, numbers.p, numbers.q)
        for value in ordered:
            f_priv.put_mpint(value)

369 

370 

class _SSHFormatDSA:
    """Wire format for DSA keys.

    Public:
        mpint p, q, g, y
    Private:
        mpint p, q, g, y, x
    """

    def get_public(self, data: memoryview) -> tuple[tuple, memoryview]:
        """Read ``(p, q, g, y)``; return them and the remaining data."""
        values = []
        rest = data
        for _ in range(4):
            value, rest = _get_mpint(rest)
            values.append(value)
        return tuple(values), rest

    def load_public(
        self, data: memoryview
    ) -> tuple[dsa.DSAPublicKey, memoryview]:
        """Build a DSA public key from ``data``; return it and the rest."""
        (p, q, g, y), rest = self.get_public(data)
        public_numbers = dsa.DSAPublicNumbers(
            y, dsa.DSAParameterNumbers(p, q, g)
        )
        self._validate(public_numbers)
        return public_numbers.public_key(), rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> tuple[dsa.DSAPrivateKey, memoryview]:
        """Build a DSA private key from ``data``; return it and the rest.

        ``pubfields`` is ``(p, q, g, y)`` from the public section. A
        mismatch with the private section raises ValueError.
        """
        (p, q, g, y), rest = self.get_public(data)
        x, rest = _get_mpint(rest)

        if pubfields != (p, q, g, y):
            raise ValueError("Corrupt data: dsa field mismatch")
        public_numbers = dsa.DSAPublicNumbers(
            y, dsa.DSAParameterNumbers(p, q, g)
        )
        self._validate(public_numbers)
        private_key = dsa.DSAPrivateNumbers(x, public_numbers).private_key()
        return private_key, rest

    def encode_public(
        self, public_key: dsa.DSAPublicKey, f_pub: _FragList
    ) -> None:
        """Append ``p, q, g, y`` to ``f_pub`` after validating key size."""
        public_numbers = public_key.public_numbers()
        params = public_numbers.parameter_numbers
        self._validate(public_numbers)
        for value in (params.p, params.q, params.g, public_numbers.y):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: dsa.DSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Append the public fields followed by ``x`` to ``f_priv``."""
        self.encode_public(private_key.public_key(), f_priv)
        f_priv.put_mpint(private_key.private_numbers().x)

    def _validate(self, public_numbers: dsa.DSAPublicNumbers) -> None:
        """Reject keys whose prime ``p`` is not exactly 1024 bits."""
        if public_numbers.parameter_numbers.p.bit_length() != 1024:
            raise ValueError("SSH supports only 1024 bit DSA keys")

439 

440 

class _SSHFormatECDSA:
    """Wire format for ECDSA keys.

    Public:
        str curve
        bytes point
    Private:
        str curve
        bytes point
        mpint secret
    """

    def __init__(
        self, ssh_curve_name: bytes, curve: ec.EllipticCurve
    ) -> None:
        # curve identifier expected in the key blob (e.g. b"nistp256")
        self.ssh_curve_name = ssh_curve_name
        # matching curve object used to construct key objects
        self.curve = curve

    def get_public(self, data: memoryview) -> tuple[tuple, memoryview]:
        """Read ``(curve, point)``; return them and the remaining data.

        Raises ValueError on a curve-name mismatch or an empty point, and
        NotImplementedError for a point not in uncompressed (0x04) form.
        """
        curve, data = _get_sshstr(data)
        point, data = _get_sshstr(data)
        if curve != self.ssh_curve_name:
            raise ValueError("Curve name mismatch")
        if not point:
            # Without this guard point[0] raises IndexError on malformed
            # input instead of the ValueError callers expect.
            raise ValueError("Invalid data")
        if point[0] != 4:
            raise NotImplementedError("Need uncompressed point")
        return (curve, point), data

    def load_public(
        self, data: memoryview
    ) -> tuple[ec.EllipticCurvePublicKey, memoryview]:
        """Build an ECDSA public key from ``data``; return it and the rest."""
        (_, point), data = self.get_public(data)
        public_key = ec.EllipticCurvePublicKey.from_encoded_point(
            self.curve, point.tobytes()
        )
        return public_key, data

    def load_private(
        self, data: memoryview, pubfields
    ) -> tuple[ec.EllipticCurvePrivateKey, memoryview]:
        """Build an ECDSA private key from ``data``; return it and the rest.

        ``pubfields`` is ``(curve, point)`` from the public section. A
        mismatch with the private section raises ValueError. The key is
        derived from the secret scalar on ``self.curve``.
        """
        (curve_name, point), data = self.get_public(data)
        secret, data = _get_mpint(data)

        if (curve_name, point) != pubfields:
            raise ValueError("Corrupt data: ecdsa field mismatch")
        private_key = ec.derive_private_key(secret, self.curve)
        return private_key, data

    def encode_public(
        self, public_key: ec.EllipticCurvePublicKey, f_pub: _FragList
    ) -> None:
        """Append the curve name and uncompressed point to ``f_pub``."""
        point = public_key.public_bytes(
            Encoding.X962, PublicFormat.UncompressedPoint
        )
        f_pub.put_sshstr(self.ssh_curve_name)
        f_pub.put_sshstr(point)

    def encode_private(
        self, private_key: ec.EllipticCurvePrivateKey, f_priv: _FragList
    ) -> None:
        """Append the public fields followed by the secret scalar."""
        public_key = private_key.public_key()
        private_numbers = private_key.private_numbers()

        self.encode_public(public_key, f_priv)
        f_priv.put_mpint(private_numbers.private_value)

508 

509 

class _SSHFormatEd25519:
    """Wire format for Ed25519 keys.

    Public:
        bytes point
    Private:
        bytes point
        bytes secret_and_point
    """

    def get_public(self, data: memoryview) -> tuple[tuple, memoryview]:
        """Read ``(point,)``; return it and the remaining data."""
        point, rest = _get_sshstr(data)
        return (point,), rest

    def load_public(
        self, data: memoryview
    ) -> tuple[ed25519.Ed25519PublicKey, memoryview]:
        """Build an Ed25519 public key from ``data``; return it and the rest."""
        (point,), rest = self.get_public(data)
        key = ed25519.Ed25519PublicKey.from_public_bytes(point.tobytes())
        return key, rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> tuple[ed25519.Ed25519PrivateKey, memoryview]:
        """Build an Ed25519 private key from ``data``; return it and the rest.

        The private blob is the 32-byte secret followed by the public
        point. Both copies of the point and ``pubfields`` must agree, or
        ValueError is raised.
        """
        (point,), rest = self.get_public(data)
        keypair, rest = _get_sshstr(rest)

        secret, point2 = keypair[:32], keypair[32:]
        if point != point2 or (point,) != pubfields:
            raise ValueError("Corrupt data: ed25519 field mismatch")
        key = ed25519.Ed25519PrivateKey.from_private_bytes(secret)
        return key, rest

    def encode_public(
        self, public_key: ed25519.Ed25519PublicKey, f_pub: _FragList
    ) -> None:
        """Append the raw 32-byte public point to ``f_pub``."""
        f_pub.put_sshstr(
            public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)
        )

    def encode_private(
        self, private_key: ed25519.Ed25519PrivateKey, f_priv: _FragList
    ) -> None:
        """Append the public point, then the secret+point blob."""
        public_key = private_key.public_key()
        secret = private_key.private_bytes(
            Encoding.Raw, PrivateFormat.Raw, NoEncryption()
        )
        point = public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)
        secret_and_point = _FragList([secret, point])

        self.encode_public(public_key, f_priv)
        f_priv.put_sshstr(secret_and_point)

573 

574 

# SSH key-type name -> handler that parses and encodes that key type.
# ECDSA handlers carry both the SSH curve identifier and the curve object.
_KEY_FORMATS = {
    _SSH_RSA: _SSHFormatRSA(),
    _SSH_DSA: _SSHFormatDSA(),
    _SSH_ED25519: _SSHFormatEd25519(),
    _ECDSA_NISTP256: _SSHFormatECDSA(b"nistp256", ec.SECP256R1()),
    _ECDSA_NISTP384: _SSHFormatECDSA(b"nistp384", ec.SECP384R1()),
    _ECDSA_NISTP521: _SSHFormatECDSA(b"nistp521", ec.SECP521R1()),
}

583 

584 

def _lookup_kformat(key_type: bytes):
    """Return the format handler registered for ``key_type``.

    Any bytes-like key type is accepted (memoryviews are copied to bytes
    for the lookup). Unknown types raise UnsupportedAlgorithm.
    """
    name = (
        key_type
        if isinstance(key_type, bytes)
        else memoryview(key_type).tobytes()
    )
    kformat = _KEY_FORMATS.get(name)
    if kformat is None:
        raise UnsupportedAlgorithm(f"Unsupported key type: {name!r}")
    return kformat

592 

593 

# Private key classes that can be loaded from / serialized to the OpenSSH
# private key format.
SSHPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    dsa.DSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]

600 

601 

def load_ssh_private_key(
    data: bytes,
    password: bytes | None,
    backend: typing.Any = None,
) -> SSHPrivateKeyTypes:
    """Load private key from OpenSSH custom encoding.

    :param data: Armored ``openssh-key-v1`` key text.
    :param password: Passphrase for an encrypted key, or None. It is not
        used when the key is unencrypted.
    :param backend: Unused; accepted for API compatibility.
    :returns: The private key object.
    :raises ValueError: for malformed or corrupt data, a wrong checksum or
        padding, or an encrypted key loaded without a password.
    :raises UnsupportedAlgorithm: for an unknown key type, cipher or KDF,
        or when bcrypt is unavailable for an encrypted key.
    """
    utils._check_byteslike("data", data)
    if password is not None:
        utils._check_bytes("password", password)

    # locate the base64 body between the armor lines and decode it
    m = _PEM_RC.search(data)
    if not m:
        raise ValueError("Not OpenSSH private key format")
    p1 = m.start(1)
    p2 = m.end(1)
    data = binascii.a2b_base64(memoryview(data)[p1:p2])
    if not data.startswith(_SK_MAGIC):
        raise ValueError("Not OpenSSH private key format")
    data = memoryview(data)[len(_SK_MAGIC) :]

    # parse header
    ciphername, data = _get_sshstr(data)
    kdfname, data = _get_sshstr(data)
    kdfoptions, data = _get_sshstr(data)
    nkeys, data = _get_u32(data)
    if nkeys != 1:
        raise ValueError("Only one key supported")

    # load public key data; its fields are cross-checked later against the
    # copy stored in the (possibly encrypted) private section
    pubdata, data = _get_sshstr(data)
    pub_key_type, pubdata = _get_sshstr(pubdata)
    kformat = _lookup_kformat(pub_key_type)
    pubfields, pubdata = kformat.get_public(pubdata)
    _check_empty(pubdata)

    if (ciphername, kdfname) != (_NONE, _NONE):
        ciphername_bytes = ciphername.tobytes()
        if ciphername_bytes not in _SSH_CIPHERS:
            raise UnsupportedAlgorithm(
                f"Unsupported cipher: {ciphername_bytes!r}"
            )
        if kdfname != _BCRYPT:
            # kdfname is a memoryview; convert it so the message shows the
            # KDF name rather than "<memory at 0x...>".
            raise UnsupportedAlgorithm(
                f"Unsupported KDF: {kdfname.tobytes()!r}"
            )
        cipher_params = _SSH_CIPHERS[ciphername_bytes]
        blklen = cipher_params.block_len
        tag_len = cipher_params.tag_len
        # load secret data
        edata, data = _get_sshstr(data)
        # see https://bugzilla.mindrot.org/show_bug.cgi?id=3553 for
        # information about how OpenSSH handles AEAD tags
        if cipher_params.is_aead:
            tag = bytes(data)
            if len(tag) != tag_len:
                raise ValueError("Corrupt data: invalid tag length for cipher")
        else:
            _check_empty(data)
        _check_block_size(edata, blklen)
        # kdfoptions = string salt, uint32 rounds
        salt, kbuf = _get_sshstr(kdfoptions)
        rounds, kbuf = _get_u32(kbuf)
        _check_empty(kbuf)
        ciph = _init_cipher(ciphername_bytes, password, salt.tobytes(), rounds)
        dec = ciph.decryptor()
        edata = memoryview(dec.update(edata))
        if cipher_params.is_aead:
            assert isinstance(dec, AEADDecryptionContext)
            _check_empty(dec.finalize_with_tag(tag))
        else:
            # _check_block_size requires data to be a full block so there
            # should be no output from finalize
            _check_empty(dec.finalize())
    else:
        # load secret data
        edata, data = _get_sshstr(data)
        _check_empty(data)
        blklen = 8
        _check_block_size(edata, blklen)
    # two equal check-ints; a mismatch means corruption or a bad passphrase
    ck1, edata = _get_u32(edata)
    ck2, edata = _get_u32(edata)
    if ck1 != ck2:
        raise ValueError("Corrupt data: broken checksum")

    # load per-key struct
    key_type, edata = _get_sshstr(edata)
    if key_type != pub_key_type:
        raise ValueError("Corrupt data: key type mismatch")
    private_key, edata = kformat.load_private(edata, pubfields)
    # We don't use the comment
    _, edata = _get_sshstr(edata)

    # yes, SSH does padding check *after* all other parsing is done.
    # need to follow as it writes zero-byte padding too.
    if edata != _PADDING[: len(edata)]:
        raise ValueError("Corrupt data: invalid padding")

    if isinstance(private_key, dsa.DSAPrivateKey):
        warnings.warn(
            "SSH DSA keys are deprecated and will be removed in a future "
            "release.",
            utils.DeprecatedIn40,
            stacklevel=2,
        )

    return private_key

704 

705 

def _serialize_ssh_private_key(
    private_key: SSHPrivateKeyTypes,
    password: bytes,
    encryption_algorithm: KeySerializationEncryption,
) -> bytes:
    """Serialize private key with OpenSSH custom encoding.

    Produces an armored ``openssh-key-v1`` blob holding one key with an
    empty comment.

    :param private_key: Key to serialize.
    :param password: Passphrase bytes. An empty value writes the key
        unencrypted (cipher and KDF "none"). A non-empty value encrypts the
        private section with ``_DEFAULT_CIPHER`` using a bcrypt-derived key.
    :param encryption_algorithm: Only consulted for a custom bcrypt round
        count (``_kdf_rounds``); otherwise ``_DEFAULT_ROUNDS`` is used.
    :returns: The armored key as bytes.
    :raises ValueError: for key types or curves without an SSH encoding, or
        DSA keys whose prime is not 1024 bits.
    """
    utils._check_bytes("password", password)
    if isinstance(private_key, dsa.DSAPrivateKey):
        # NOTE(review): stacklevel=4 presumably points at the caller of
        # the public serialization API a few frames up; confirm.
        warnings.warn(
            "SSH DSA key support is deprecated and will be "
            "removed in a future release",
            utils.DeprecatedIn40,
            stacklevel=4,
        )

    key_type = _get_ssh_key_type(private_key)
    kformat = _lookup_kformat(key_type)

    # setup parameters
    f_kdfoptions = _FragList()
    if password:
        ciphername = _DEFAULT_CIPHER
        blklen = _SSH_CIPHERS[ciphername].block_len
        kdfname = _BCRYPT
        rounds = _DEFAULT_ROUNDS
        if (
            isinstance(encryption_algorithm, _KeySerializationEncryption)
            and encryption_algorithm._kdf_rounds is not None
        ):
            rounds = encryption_algorithm._kdf_rounds
        # fresh random 16-byte salt for every serialization
        salt = os.urandom(16)
        # kdfoptions = string salt, uint32 rounds
        f_kdfoptions.put_sshstr(salt)
        f_kdfoptions.put_u32(rounds)
        ciph = _init_cipher(ciphername, password, salt, rounds)
    else:
        ciphername = kdfname = _NONE
        # unencrypted keys are still padded to 8-byte blocks
        blklen = 8
        ciph = None
    nkeys = 1
    # the loader checks that both copies of this random value are equal
    checkval = os.urandom(4)
    comment = b""

    # encode public and private parts together
    f_public_key = _FragList()
    f_public_key.put_sshstr(key_type)
    kformat.encode_public(private_key.public_key(), f_public_key)

    f_secrets = _FragList([checkval, checkval])
    f_secrets.put_sshstr(key_type)
    kformat.encode_private(private_key, f_secrets)
    f_secrets.put_sshstr(comment)
    # pad with 1, 2, 3, ... up to the next block boundary; always adds
    # between 1 and blklen bytes (a full block when already aligned)
    f_secrets.put_raw(_PADDING[: blklen - (f_secrets.size() % blklen)])

    # top-level structure
    f_main = _FragList()
    f_main.put_raw(_SK_MAGIC)
    f_main.put_sshstr(ciphername)
    f_main.put_sshstr(kdfname)
    f_main.put_sshstr(f_kdfoptions)
    f_main.put_u32(nkeys)
    f_main.put_sshstr(f_public_key)
    f_main.put_sshstr(f_secrets)

    # copy result into bytearray
    slen = f_secrets.size()
    mlen = f_main.size()
    # NOTE(review): the extra blklen bytes look like headroom for the
    # output buffer of update_into below; confirm
    buf = memoryview(bytearray(mlen + blklen))
    f_main.render(buf)
    # the secrets section is rendered last, so it occupies the final slen
    # bytes of the mlen-byte result (after its length prefix)
    ofs = mlen - slen

    # encrypt in-place: only the secrets section is encrypted
    if ciph is not None:
        ciph.encryptor().update_into(buf[ofs:mlen], buf[ofs:])

    return _ssh_pem_encode(buf[:mlen])

781 

782 

# Public key classes that can be loaded from / serialized to SSH format.
SSHPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    dsa.DSAPublicKey,
    ed25519.Ed25519PublicKey,
]

# Public key classes exposed by SSHCertificate accessors. DSA is omitted
# because the certificate loader rejects DSA unless legacy support is on.
SSHCertPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    ed25519.Ed25519PublicKey,
]

795 

796 

class SSHCertificateType(enum.Enum):
    """Type of an OpenSSH certificate.

    Values match the uint32 type field read from the certificate body.
    """

    USER = 1
    HOST = 2

800 

801 

class SSHCertificate:
    """An OpenSSH certificate parsed from its wire encoding.

    Normally created by the public-identity loaders in this module. Fields
    arrive already decoded (many as memoryview slices of the certificate
    blob). The accessors convert them to ``bytes`` where exposed.
    """

    def __init__(
        self,
        _nonce: memoryview,
        _public_key: SSHPublicKeyTypes,
        _serial: int,
        _cctype: int,
        _key_id: memoryview,
        _valid_principals: list[bytes],
        _valid_after: int,
        _valid_before: int,
        _critical_options: dict[bytes, bytes],
        _extensions: dict[bytes, bytes],
        _sig_type: memoryview,
        _sig_key: memoryview,
        _inner_sig_type: memoryview,
        _signature: memoryview,
        _tbs_cert_body: memoryview,
        _cert_key_type: bytes,
        _cert_body: memoryview,
    ):
        self._nonce = _nonce
        self._public_key = _public_key
        self._serial = _serial
        try:
            self._type = SSHCertificateType(_cctype)
        except ValueError:
            raise ValueError("Invalid certificate type")
        # identity and validity constraints
        self._key_id = _key_id
        self._valid_principals = _valid_principals
        self._valid_after = _valid_after
        self._valid_before = _valid_before
        self._critical_options = _critical_options
        self._extensions = _extensions
        # signing (CA) key and its signature over the certificate body
        self._sig_type = _sig_type
        self._sig_key = _sig_key
        self._inner_sig_type = _inner_sig_type
        self._signature = _signature
        # raw encodings used for re-serialization and verification
        self._cert_key_type = _cert_key_type
        self._cert_body = _cert_body
        self._tbs_cert_body = _tbs_cert_body

    @property
    def nonce(self) -> bytes:
        """The certificate nonce as bytes."""
        return bytes(self._nonce)

    def public_key(self) -> SSHCertPublicKeyTypes:
        """The certified public key."""
        # make mypy happy until we remove DSA support entirely and
        # the underlying union won't have a disallowed type
        return typing.cast(SSHCertPublicKeyTypes, self._public_key)

    @property
    def serial(self) -> int:
        """The uint64 serial field."""
        return self._serial

    @property
    def type(self) -> SSHCertificateType:
        """Whether this is a user or a host certificate."""
        return self._type

    @property
    def key_id(self) -> bytes:
        """The key-id string as bytes."""
        return bytes(self._key_id)

    @property
    def valid_principals(self) -> list[bytes]:
        """The list of principal names."""
        return self._valid_principals

    @property
    def valid_before(self) -> int:
        """The uint64 valid-before field."""
        return self._valid_before

    @property
    def valid_after(self) -> int:
        """The uint64 valid-after field."""
        return self._valid_after

    @property
    def critical_options(self) -> dict[bytes, bytes]:
        """Critical options as a name -> value mapping."""
        return self._critical_options

    @property
    def extensions(self) -> dict[bytes, bytes]:
        """Extensions as a name -> value mapping."""
        return self._extensions

    def signature_key(self) -> SSHCertPublicKeyTypes:
        """Decode the signing (CA) public key embedded in the certificate.

        Raises ValueError if trailing data follows the key fields.
        """
        kformat = _lookup_kformat(self._sig_type)
        key, leftover = kformat.load_public(self._sig_key)
        _check_empty(leftover)
        return key

    def public_bytes(self) -> bytes:
        """Re-encode as a ``<cert key type> <base64 blob>`` line (no newline)."""
        encoded = binascii.b2a_base64(bytes(self._cert_body), newline=False)
        return bytes(self._cert_key_type) + b" " + encoded

    def verify_cert_signature(self) -> None:
        """Verify the signing key's signature over the certificate body.

        Returns None on success. Verification failures propagate from the
        signing key's ``verify`` call.
        """
        key = self.signature_key()
        tbs = bytes(self._tbs_cert_body)

        if isinstance(key, ed25519.Ed25519PublicKey):
            key.verify(bytes(self._signature), tbs)
            return

        if isinstance(key, ec.EllipticCurvePublicKey):
            # The signature is encoded as a pair of big-endian integers
            r, rest = _get_mpint(self._signature)
            s, rest = _get_mpint(rest)
            _check_empty(rest)
            der_signature = asym_utils.encode_dss_signature(r, s)
            ecdsa = ec.ECDSA(_get_ec_hash_alg(key.curve))
            key.verify(der_signature, tbs, ecdsa)
            return

        assert isinstance(key, rsa.RSAPublicKey)
        inner = self._inner_sig_type
        if inner == _SSH_RSA:
            hash_alg = hashes.SHA1()
        elif inner == _SSH_RSA_SHA256:
            hash_alg = hashes.SHA256()
        else:
            assert inner == _SSH_RSA_SHA512
            hash_alg = hashes.SHA512()
        key.verify(bytes(self._signature), tbs, padding.PKCS1v15(), hash_alg)

929 

930 

def _get_ec_hash_alg(curve: ec.EllipticCurve) -> hashes.HashAlgorithm:
    """Return the digest used for ECDSA signatures on ``curve``.

    P-256 uses SHA-256, P-384 uses SHA-384, and P-521 uses SHA-512.
    """
    if isinstance(curve, ec.SECP256R1):
        return hashes.SHA256()
    if isinstance(curve, ec.SECP384R1):
        return hashes.SHA384()
    assert isinstance(curve, ec.SECP521R1)
    return hashes.SHA512()

939 

940 

def _load_ssh_public_identity(
    data: bytes,
    _legacy_dsa_allowed=False,
) -> SSHCertificate | SSHPublicKeyTypes:
    """Parse one OpenSSH public-key line into a key or a certificate.

    ``data`` must start with ``<key type> <base64 body>``; anything after
    the body is not examined. A key type ending in
    ``-cert-v01@openssh.com`` is parsed as a certificate.

    :param _legacy_dsa_allowed: When False, DSA key types and DSA-signed
        certificates raise UnsupportedAlgorithm.
    :returns: An SSHCertificate for certificate lines, else the public key.
    :raises ValueError: for a malformed line, bad base64, or a corrupt or
        inconsistent key/certificate blob.
    :raises UnsupportedAlgorithm: for unknown or disallowed key types.
    """
    utils._check_byteslike("data", data)

    m = _SSH_PUBKEY_RC.match(data)
    if not m:
        raise ValueError("Invalid line format")
    key_type = orig_key_type = m.group(1)
    key_body = m.group(2)
    with_cert = False
    if key_type.endswith(_CERT_SUFFIX):
        with_cert = True
        # the bare key type selects the format handler
        key_type = key_type[: -len(_CERT_SUFFIX)]
    if key_type == _SSH_DSA and not _legacy_dsa_allowed:
        raise UnsupportedAlgorithm(
            "DSA keys aren't supported in SSH certificates"
        )
    kformat = _lookup_kformat(key_type)

    try:
        rest = memoryview(binascii.a2b_base64(key_body))
    except (TypeError, binascii.Error):
        raise ValueError("Invalid format")

    if with_cert:
        # keep a view of the whole blob for public_bytes() and for slicing
        # out the signed portion below
        cert_body = rest
        # the blob repeats the full (suffixed) type name, which must match
        inner_key_type, rest = _get_sshstr(rest)
        if inner_key_type != orig_key_type:
            raise ValueError("Invalid key format")
    if with_cert:
        nonce, rest = _get_sshstr(rest)
    public_key, rest = kformat.load_public(rest)
    if with_cert:
        serial, rest = _get_u64(rest)
        cctype, rest = _get_u32(rest)
        key_id, rest = _get_sshstr(rest)
        # principals is itself a sequence of length-prefixed strings
        principals, rest = _get_sshstr(rest)
        valid_principals = []
        while principals:
            principal, principals = _get_sshstr(principals)
            valid_principals.append(bytes(principal))
        valid_after, rest = _get_u64(rest)
        valid_before, rest = _get_u64(rest)
        crit_options, rest = _get_sshstr(rest)
        critical_options = _parse_exts_opts(crit_options)
        exts, rest = _get_sshstr(rest)
        extensions = _parse_exts_opts(exts)
        # Get the reserved field, which is unused.
        _, rest = _get_sshstr(rest)
        sig_key_raw, rest = _get_sshstr(rest)
        sig_type, sig_key = _get_sshstr(sig_key_raw)
        if sig_type == _SSH_DSA and not _legacy_dsa_allowed:
            raise UnsupportedAlgorithm(
                "DSA signatures aren't supported in SSH certificates"
            )
        # Get the entire cert body and subtract the signature
        # (if rest were empty, [:-0] would yield an empty slice, but the
        # _get_sshstr(rest) call below then raises ValueError)
        tbs_cert_body = cert_body[: -len(rest)]
        signature_raw, rest = _get_sshstr(rest)
        _check_empty(rest)
        inner_sig_type, sig_rest = _get_sshstr(signature_raw)
        # RSA certs can have multiple algorithm types
        if (
            sig_type == _SSH_RSA
            and inner_sig_type
            not in [_SSH_RSA_SHA256, _SSH_RSA_SHA512, _SSH_RSA]
        ) or (sig_type != _SSH_RSA and inner_sig_type != sig_type):
            raise ValueError("Signature key type does not match")
        signature, sig_rest = _get_sshstr(sig_rest)
        _check_empty(sig_rest)
        return SSHCertificate(
            nonce,
            public_key,
            serial,
            cctype,
            key_id,
            valid_principals,
            valid_after,
            valid_before,
            critical_options,
            extensions,
            sig_type,
            sig_key,
            inner_sig_type,
            signature,
            tbs_cert_body,
            orig_key_type,
            cert_body,
        )
    else:
        # a plain key line must contain nothing beyond the key fields
        _check_empty(rest)
        return public_key

1034 

1035 

def load_ssh_public_identity(
    data: bytes,
) -> SSHCertificate | SSHPublicKeyTypes:
    """Parse a single OpenSSH public-key line.

    Returns either a bare public key or, for ``*-cert-v01@openssh.com``
    style lines, an :class:`SSHCertificate`. Delegates to
    ``_load_ssh_public_identity`` without passing ``_legacy_dsa_allowed``,
    so whatever default that helper declares is used.
    """
    identity = _load_ssh_public_identity(data)
    return identity

1040 

1041 

def _parse_exts_opts(exts_opts: memoryview) -> dict[bytes, bytes]:
    """Decode an SSH certificate critical-options or extensions blob.

    The blob is a sequence of ``(name, data)`` sshstr pairs. When ``data``
    is non-empty it must itself contain exactly one sshstr, whose payload
    becomes the value; an empty ``data`` yields an empty value.

    Raises:
        ValueError: on a repeated name, names that are not in ascending
            byte order, or trailing bytes after a wrapped value.
    """
    parsed: dict[bytes, bytes] = {}
    previous: bytes | None = None
    remaining = exts_opts
    while remaining:
        raw_name, remaining = _get_sshstr(remaining)
        key = bytes(raw_name)
        # Duplicates are checked before ordering, matching the error a
        # caller sees for a repeated name.
        if key in parsed:
            raise ValueError("Duplicate name")
        if previous is not None and key < previous:
            raise ValueError("Fields not lexically sorted")
        data, remaining = _get_sshstr(remaining)
        if len(data) > 0:
            # Non-empty values carry an extra sshstr wrapper.
            data, trailing = _get_sshstr(data)
            if len(trailing) > 0:
                raise ValueError("Unexpected extra data after value")
        parsed[key] = bytes(data)
        previous = key
    return parsed

1060 

1061 

def load_ssh_public_key(
    data: bytes, backend: typing.Any = None
) -> SSHPublicKeyTypes:
    """Load an OpenSSH public key line, unwrapping certificates.

    DSA is permitted here (``_legacy_dsa_allowed=True``). If the line holds
    a certificate, the certified public key is returned instead of the
    certificate. Loading a DSA key emits a ``utils.DeprecatedIn40`` warning.
    ``backend`` is accepted and ignored.
    """
    loaded = _load_ssh_public_identity(data, _legacy_dsa_allowed=True)
    public_key: SSHPublicKeyTypes = (
        loaded.public_key() if isinstance(loaded, SSHCertificate) else loaded
    )

    if isinstance(public_key, dsa.DSAPublicKey):
        warnings.warn(
            "SSH DSA keys are deprecated and will be removed in a future "
            "release.",
            utils.DeprecatedIn40,
            stacklevel=2,
        )
    return public_key

1080 

1081 

def serialize_ssh_public_key(public_key: SSHPublicKeyTypes) -> bytes:
    """Serialize *public_key* as a one-line OpenSSH public key.

    The result is ``<key-type> <base64 blob>``, where the blob is the
    key-type sshstr followed by the key-format-specific public encoding.
    A DSA key triggers a ``utils.DeprecatedIn40`` warning first.
    """
    if isinstance(public_key, dsa.DSAPublicKey):
        warnings.warn(
            "SSH DSA key support is deprecated and will be "
            "removed in a future release",
            utils.DeprecatedIn40,
            stacklevel=4,
        )

    key_type = _get_ssh_key_type(public_key)
    key_format = _lookup_kformat(key_type)

    blob = _FragList()
    blob.put_sshstr(key_type)
    key_format.encode_public(public_key, blob)

    encoded = binascii.b2a_base64(blob.tobytes()).strip()
    return b"".join((key_type, b" ", encoded))

1100 

1101 

# Private key types that may act as the signing CA for an SSH certificate.
SSHCertPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]


# Maximum number of principals a certificate may list. OpenSSH enforces this
# (undocumented) limit in sshd and ssh-keygen; the SSH certificate spec itself
# does not define one.
_SSHKEY_CERT_MAX_PRINCIPALS = 256

1112 

1113 

class SSHCertificateBuilder:
    """Fluent, immutable builder for OpenSSH certificates.

    Every setter validates its argument and returns a *new* builder with
    that field filled in; the receiver is never modified. Each field may be
    set at most once. Call :meth:`sign` with the CA private key to produce
    an :class:`SSHCertificate`.
    """

    def __init__(
        self,
        _public_key: SSHCertPublicKeyTypes | None = None,
        _serial: int | None = None,
        _type: SSHCertificateType | None = None,
        _key_id: bytes | None = None,
        _valid_principals: list[bytes] | None = None,
        _valid_for_all_principals: bool = False,
        _valid_before: int | None = None,
        _valid_after: int | None = None,
        _critical_options: list[tuple[bytes, bytes]] | None = None,
        _extensions: list[tuple[bytes, bytes]] | None = None,
    ):
        self._public_key = _public_key
        self._serial = _serial
        self._type = _type
        self._key_id = _key_id
        # List arguments are copied (and default to a fresh list) so that no
        # two builders, and no caller, ever share a mutable list. A shared
        # `[]` default would otherwise leak state between builders.
        self._valid_principals: list[bytes] = (
            [] if _valid_principals is None else list(_valid_principals)
        )
        self._valid_for_all_principals = _valid_for_all_principals
        self._valid_before = _valid_before
        self._valid_after = _valid_after
        self._critical_options: list[tuple[bytes, bytes]] = (
            [] if _critical_options is None else list(_critical_options)
        )
        self._extensions: list[tuple[bytes, bytes]] = (
            [] if _extensions is None else list(_extensions)
        )

    def _evolve(self, **changes: typing.Any) -> SSHCertificateBuilder:
        """Return a new builder equal to this one except for *changes*.

        Keys are the constructor's underscore-prefixed parameter names,
        e.g. ``self._evolve(_serial=5)``.
        """
        fields: dict[str, typing.Any] = {
            "_public_key": self._public_key,
            "_serial": self._serial,
            "_type": self._type,
            "_key_id": self._key_id,
            "_valid_principals": self._valid_principals,
            "_valid_for_all_principals": self._valid_for_all_principals,
            "_valid_before": self._valid_before,
            "_valid_after": self._valid_after,
            "_critical_options": self._critical_options,
            "_extensions": self._extensions,
        }
        fields.update(changes)
        return SSHCertificateBuilder(**fields)

    def public_key(
        self, public_key: SSHCertPublicKeyTypes
    ) -> SSHCertificateBuilder:
        """Set the subject public key the certificate will certify.

        Raises:
            TypeError: if *public_key* is not an EC, RSA or Ed25519 key.
            ValueError: if a public key was already set.
        """
        if not isinstance(
            public_key,
            (
                ec.EllipticCurvePublicKey,
                rsa.RSAPublicKey,
                ed25519.Ed25519PublicKey,
            ),
        ):
            raise TypeError("Unsupported key type")
        if self._public_key is not None:
            raise ValueError("public_key already set")

        return self._evolve(_public_key=public_key)

    def serial(self, serial: int) -> SSHCertificateBuilder:
        """Set the certificate serial number (must fit in an unsigned u64).

        Raises:
            TypeError: if *serial* is not an int.
            ValueError: if out of ``[0, 2**64)`` or already set.
        """
        if not isinstance(serial, int):
            raise TypeError("serial must be an integer")
        if not 0 <= serial < 2**64:
            raise ValueError("serial must be between 0 and 2**64")
        if self._serial is not None:
            raise ValueError("serial already set")

        return self._evolve(_serial=serial)

    def type(self, type: SSHCertificateType) -> SSHCertificateBuilder:
        """Set the certificate type.

        Raises:
            TypeError: if *type* is not an ``SSHCertificateType``.
            ValueError: if the type was already set.
        """
        if not isinstance(type, SSHCertificateType):
            raise TypeError("type must be an SSHCertificateType")
        if self._type is not None:
            raise ValueError("type already set")

        return self._evolve(_type=type)

    def key_id(self, key_id: bytes) -> SSHCertificateBuilder:
        """Set the free-form key identifier.

        Raises:
            TypeError: if *key_id* is not bytes.
            ValueError: if the key id was already set.
        """
        if not isinstance(key_id, bytes):
            raise TypeError("key_id must be bytes")
        if self._key_id is not None:
            raise ValueError("key_id already set")

        return self._evolve(_key_id=key_id)

    def valid_principals(
        self, valid_principals: list[bytes]
    ) -> SSHCertificateBuilder:
        """Set the principals the certificate is valid for.

        The list is copied, so later changes by the caller do not affect
        the builder.

        Raises:
            TypeError: if the list is empty or contains non-bytes items.
            ValueError: if ``valid_for_all_principals`` was set, principals
                were already set, or more than
                ``_SSHKEY_CERT_MAX_PRINCIPALS`` are given.
        """
        if self._valid_for_all_principals:
            raise ValueError(
                "Principals can't be set because the cert is valid "
                "for all principals"
            )
        if (
            not all(isinstance(x, bytes) for x in valid_principals)
            or not valid_principals
        ):
            raise TypeError(
                "principals must be a list of bytes and can't be empty"
            )
        if self._valid_principals:
            raise ValueError("valid_principals already set")

        if len(valid_principals) > _SSHKEY_CERT_MAX_PRINCIPALS:
            raise ValueError(
                "Reached or exceeded the maximum number of valid_principals"
            )

        return self._evolve(_valid_principals=list(valid_principals))

    def valid_for_all_principals(self) -> SSHCertificateBuilder:
        """Mark the certificate as valid for any principal.

        This is mutually exclusive with :meth:`valid_principals`.

        Raises:
            ValueError: if principals were set or this flag was already set.
        """
        if self._valid_principals:
            raise ValueError(
                "valid_principals already set, can't set "
                "valid_for_all_principals"
            )
        if self._valid_for_all_principals:
            raise ValueError("valid_for_all_principals already set")

        return self._evolve(_valid_for_all_principals=True)

    def valid_before(self, valid_before: int | float) -> SSHCertificateBuilder:
        """Set the end of the validity window.

        A float is truncated with ``int()``.

        Raises:
            TypeError: if not an int or float.
            ValueError: if out of ``[0, 2**64)`` or already set.
        """
        if not isinstance(valid_before, (int, float)):
            raise TypeError("valid_before must be an int or float")
        valid_before = int(valid_before)
        if valid_before < 0 or valid_before >= 2**64:
            raise ValueError("valid_before must [0, 2**64)")
        if self._valid_before is not None:
            raise ValueError("valid_before already set")

        return self._evolve(_valid_before=valid_before)

    def valid_after(self, valid_after: int | float) -> SSHCertificateBuilder:
        """Set the start of the validity window.

        A float is truncated with ``int()``.

        Raises:
            TypeError: if not an int or float.
            ValueError: if out of ``[0, 2**64)`` or already set.
        """
        if not isinstance(valid_after, (int, float)):
            raise TypeError("valid_after must be an int or float")
        valid_after = int(valid_after)
        if valid_after < 0 or valid_after >= 2**64:
            raise ValueError("valid_after must [0, 2**64)")
        if self._valid_after is not None:
            raise ValueError("valid_after already set")

        return self._evolve(_valid_after=valid_after)

    def add_critical_option(
        self, name: bytes, value: bytes
    ) -> SSHCertificateBuilder:
        """Append a critical option; names must be unique.

        Raises:
            TypeError: if *name* or *value* is not bytes.
            ValueError: if an option with *name* already exists.
        """
        if not isinstance(name, bytes) or not isinstance(value, bytes):
            raise TypeError("name and value must be bytes")
        # Linear scan per call, so O(n**2) across n additions.
        if any(existing == name for existing, _ in self._critical_options):
            raise ValueError("Duplicate critical option name")

        return self._evolve(
            _critical_options=[*self._critical_options, (name, value)]
        )

    def add_extension(
        self, name: bytes, value: bytes
    ) -> SSHCertificateBuilder:
        """Append an extension; names must be unique.

        Raises:
            TypeError: if *name* or *value* is not bytes.
            ValueError: if an extension with *name* already exists.
        """
        if not isinstance(name, bytes) or not isinstance(value, bytes):
            raise TypeError("name and value must be bytes")
        # Linear scan per call, so O(n**2) across n additions.
        if any(existing == name for existing, _ in self._extensions):
            raise ValueError("Duplicate extension name")

        return self._evolve(_extensions=[*self._extensions, (name, value)])

    @staticmethod
    def _encode_exts_opts(items: list[tuple[bytes, bytes]]) -> bytes:
        """Encode ``(name, value)`` pairs as an options/extensions blob.

        Each name is written as an sshstr. A non-empty value is first
        wrapped in its own sshstr and then written as an sshstr; an empty
        value is written directly as an empty sshstr.
        """
        out = _FragList()
        for name, value in items:
            out.put_sshstr(name)
            if value:
                inner = _FragList()
                inner.put_sshstr(value)
                out.put_sshstr(inner.tobytes())
            else:
                out.put_sshstr(value)
        return out.tobytes()

    def sign(self, private_key: SSHCertPrivateKeyTypes) -> SSHCertificate:
        """Encode and sign the certificate with the CA *private_key*.

        ``public_key``, ``type``, ``valid_before``, ``valid_after``, and
        either ``valid_principals`` or ``valid_for_all_principals`` must be
        set. ``serial`` defaults to 0 and ``key_id`` to ``b""``. A fresh
        32-byte random nonce is used on every call. RSA CA keys sign with
        ``rsa-sha2-512``. The builder itself is not modified.

        Raises:
            TypeError: if *private_key* is not an EC, RSA or Ed25519 key.
            ValueError: if a required field is missing or
                ``valid_after > valid_before``.
        """
        if not isinstance(
            private_key,
            (
                ec.EllipticCurvePrivateKey,
                rsa.RSAPrivateKey,
                ed25519.Ed25519PrivateKey,
            ),
        ):
            raise TypeError("Unsupported private key type")

        if self._public_key is None:
            raise ValueError("public_key must be set")

        # Not required
        serial = 0 if self._serial is None else self._serial

        if self._type is None:
            raise ValueError("type must be set")

        # Not required
        key_id = b"" if self._key_id is None else self._key_id

        # A zero length list is valid, but means the certificate
        # is valid for any principal of the specified type. We require
        # the user to explicitly set valid_for_all_principals to get
        # that behavior.
        if not self._valid_principals and not self._valid_for_all_principals:
            raise ValueError(
                "valid_principals must be set if valid_for_all_principals "
                "is False"
            )

        if self._valid_before is None:
            raise ValueError("valid_before must be set")

        if self._valid_after is None:
            raise ValueError("valid_after must be set")

        if self._valid_after > self._valid_before:
            raise ValueError("valid_after must be earlier than valid_before")

        # Names must be lexically sorted (the parser in this module rejects
        # unsorted fields). Sort copies so signing never mutates the builder.
        critical_options = sorted(self._critical_options, key=lambda x: x[0])
        extensions = sorted(self._extensions, key=lambda x: x[0])

        key_type = _get_ssh_key_type(self._public_key)
        cert_prefix = key_type + _CERT_SUFFIX

        # Marshal the bytes to be signed
        nonce = os.urandom(32)
        kformat = _lookup_kformat(key_type)
        f = _FragList()
        f.put_sshstr(cert_prefix)
        f.put_sshstr(nonce)
        kformat.encode_public(self._public_key, f)
        f.put_u64(serial)
        f.put_u32(self._type.value)
        f.put_sshstr(key_id)
        fprincipals = _FragList()
        for p in self._valid_principals:
            fprincipals.put_sshstr(p)
        f.put_sshstr(fprincipals.tobytes())
        f.put_u64(self._valid_after)
        f.put_u64(self._valid_before)
        f.put_sshstr(self._encode_exts_opts(critical_options))
        f.put_sshstr(self._encode_exts_opts(extensions))
        f.put_sshstr(b"")  # RESERVED FIELD
        # encode CA public key
        ca_type = _get_ssh_key_type(private_key)
        caformat = _lookup_kformat(ca_type)
        caf = _FragList()
        caf.put_sshstr(ca_type)
        caformat.encode_public(private_key.public_key(), caf)
        f.put_sshstr(caf.tobytes())
        # Sigs according to the rules defined for the CA's public key
        # (RFC4253 section 6.6 for ssh-rsa, RFC5656 for ECDSA,
        # and RFC8032 for Ed25519).
        if isinstance(private_key, ed25519.Ed25519PrivateKey):
            signature = private_key.sign(f.tobytes())
            fsig = _FragList()
            fsig.put_sshstr(ca_type)
            fsig.put_sshstr(signature)
            f.put_sshstr(fsig.tobytes())
        elif isinstance(private_key, ec.EllipticCurvePrivateKey):
            hash_alg = _get_ec_hash_alg(private_key.curve)
            signature = private_key.sign(f.tobytes(), ec.ECDSA(hash_alg))
            r, s = asym_utils.decode_dss_signature(signature)
            fsig = _FragList()
            fsig.put_sshstr(ca_type)
            fsigblob = _FragList()
            fsigblob.put_mpint(r)
            fsigblob.put_mpint(s)
            fsig.put_sshstr(fsigblob.tobytes())
            f.put_sshstr(fsig.tobytes())
        else:
            assert isinstance(private_key, rsa.RSAPrivateKey)
            # Just like Golang, we're going to use SHA512 for RSA
            # https://cs.opensource.google/go/x/crypto/+/refs/tags/
            # v0.4.0:ssh/certs.go;l=445
            # RFC 8332 defines SHA256 and 512 as options
            fsig = _FragList()
            fsig.put_sshstr(_SSH_RSA_SHA512)
            signature = private_key.sign(
                f.tobytes(), padding.PKCS1v15(), hashes.SHA512()
            )
            fsig.put_sshstr(signature)
            f.put_sshstr(fsig.tobytes())

        cert_data = binascii.b2a_base64(f.tobytes()).strip()
        # load_ssh_public_identity returns a union, but this is
        # guaranteed to be an SSHCertificate, so we cast to make
        # mypy happy.
        return typing.cast(
            SSHCertificate,
            load_ssh_public_identity(b"".join([cert_prefix, b" ", cert_data])),
        )