Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/cryptography/hazmat/primitives/serialization/ssh.py: 20%

761 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-08 06:05 +0000

1# This file is dual licensed under the terms of the Apache License, Version 

2# 2.0, and the BSD License. See the LICENSE file in the root of this repository 

3# for complete details. 

4 

5from __future__ import annotations 

6 

7import binascii 

8import enum 

9import os 

10import re 

11import typing 

12import warnings 

13from base64 import encodebytes as _base64_encode 

14from dataclasses import dataclass 

15 

16from cryptography import utils 

17from cryptography.exceptions import UnsupportedAlgorithm 

18from cryptography.hazmat.primitives import hashes 

19from cryptography.hazmat.primitives.asymmetric import ( 

20 dsa, 

21 ec, 

22 ed25519, 

23 padding, 

24 rsa, 

25) 

26from cryptography.hazmat.primitives.asymmetric import utils as asym_utils 

27from cryptography.hazmat.primitives.ciphers import ( 

28 AEADDecryptionContext, 

29 Cipher, 

30 algorithms, 

31 modes, 

32) 

33from cryptography.hazmat.primitives.serialization import ( 

34 Encoding, 

35 KeySerializationEncryption, 

36 NoEncryption, 

37 PrivateFormat, 

38 PublicFormat, 

39 _KeySerializationEncryption, 

40) 

41 

42try: 

43 from bcrypt import kdf as _bcrypt_kdf 

44 

45 _bcrypt_supported = True 

46except ImportError: 

47 _bcrypt_supported = False 

48 

49 def _bcrypt_kdf( 

50 password: bytes, 

51 salt: bytes, 

52 desired_key_bytes: int, 

53 rounds: int, 

54 ignore_few_rounds: bool = False, 

55 ) -> bytes: 

56 raise UnsupportedAlgorithm("Need bcrypt module") 

57 

58 

59_SSH_ED25519 = b"ssh-ed25519" 

60_SSH_RSA = b"ssh-rsa" 

61_SSH_DSA = b"ssh-dss" 

62_ECDSA_NISTP256 = b"ecdsa-sha2-nistp256" 

63_ECDSA_NISTP384 = b"ecdsa-sha2-nistp384" 

64_ECDSA_NISTP521 = b"ecdsa-sha2-nistp521" 

65_CERT_SUFFIX = b"-cert-v01@openssh.com" 

66 

67# These are not key types, only algorithms, so they cannot appear 

68# as a public key type 

69_SSH_RSA_SHA256 = b"rsa-sha2-256" 

70_SSH_RSA_SHA512 = b"rsa-sha2-512" 

71 

72_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)") 

73_SK_MAGIC = b"openssh-key-v1\0" 

74_SK_START = b"-----BEGIN OPENSSH PRIVATE KEY-----" 

75_SK_END = b"-----END OPENSSH PRIVATE KEY-----" 

76_BCRYPT = b"bcrypt" 

77_NONE = b"none" 

78_DEFAULT_CIPHER = b"aes256-ctr" 

79_DEFAULT_ROUNDS = 16 

80 

81# re is only way to work on bytes-like data 

82_PEM_RC = re.compile(_SK_START + b"(.*?)" + _SK_END, re.DOTALL) 

83 

84# padding for max blocksize 

85_PADDING = memoryview(bytearray(range(1, 1 + 16))) 

86 

87 

88@dataclass 

89class _SSHCipher: 

90 alg: typing.Type[algorithms.AES] 

91 key_len: int 

92 mode: typing.Union[ 

93 typing.Type[modes.CTR], 

94 typing.Type[modes.CBC], 

95 typing.Type[modes.GCM], 

96 ] 

97 block_len: int 

98 iv_len: int 

99 tag_len: typing.Optional[int] 

100 is_aead: bool 

101 

102 

# Ciphers accepted for wrapping (encrypting) the private-key section.
# All are AES-256 with 16-byte blocks; they differ in mode, IV/nonce length
# and whether an AEAD tag is present (GCM only).
_SSH_CIPHERS: typing.Dict[bytes, _SSHCipher] = {
    name: _SSHCipher(
        alg=algorithms.AES,
        key_len=32,
        mode=mode,
        block_len=16,
        iv_len=iv_len,
        tag_len=tag_len,
        is_aead=tag_len is not None,
    )
    for name, mode, iv_len, tag_len in (
        (b"aes256-ctr", modes.CTR, 16, None),
        (b"aes256-cbc", modes.CBC, 16, None),
        (b"aes256-gcm@openssh.com", modes.GCM, 12, 16),
    )
}

133 

# cryptography curve name -> SSH ECDSA key-type name.
_ECDSA_KEY_TYPE = dict(
    zip(
        ("secp256r1", "secp384r1", "secp521r1"),
        (_ECDSA_NISTP256, _ECDSA_NISTP384, _ECDSA_NISTP521),
    )
)

140 

141 

def _get_ssh_key_type(
    key: typing.Union[SSHPrivateKeyTypes, SSHPublicKeyTypes]
) -> bytes:
    """Return the SSH key-type name for a private or public key object.

    EC keys are resolved per curve via ``_ecdsa_key_type``.

    :raises ValueError: the key class is not one SSH serialization handles.
    """
    if isinstance(key, ec.EllipticCurvePrivateKey):
        return _ecdsa_key_type(key.public_key())
    if isinstance(key, ec.EllipticCurvePublicKey):
        return _ecdsa_key_type(key)
    if isinstance(key, (rsa.RSAPrivateKey, rsa.RSAPublicKey)):
        return _SSH_RSA
    if isinstance(key, (dsa.DSAPrivateKey, dsa.DSAPublicKey)):
        return _SSH_DSA
    ed25519_classes = (ed25519.Ed25519PrivateKey, ed25519.Ed25519PublicKey)
    if isinstance(key, ed25519_classes):
        return _SSH_ED25519
    raise ValueError("Unsupported key type")

161 

162 

def _ecdsa_key_type(public_key: ec.EllipticCurvePublicKey) -> bytes:
    """Return the SSH ECDSA key-type name for an EC public key.

    :param public_key: EC public key whose curve selects the key type.
    :returns: e.g. ``b"ecdsa-sha2-nistp256"`` for a secp256r1 key.
    :raises ValueError: the curve has no SSH ECDSA key type.
    """
    curve_name = public_key.curve.name
    try:
        return _ECDSA_KEY_TYPE[curve_name]
    except KeyError:
        # Suppress the KeyError context; callers only see the ValueError.
        raise ValueError(
            f"Unsupported curve for ssh private key: {curve_name!r}"
        ) from None

171 

172 

def _ssh_pem_encode(
    data: bytes,
    prefix: bytes = _SK_START + b"\n",
    suffix: bytes = _SK_END + b"\n",
) -> bytes:
    """Armor *data*: base64 body framed by the *prefix* and *suffix* lines.

    ``base64.encodebytes`` already splits its output into newline-terminated
    lines, so the body can be placed directly between the two markers.
    """
    body = _base64_encode(data)
    return prefix + body + suffix

179 

180 

181def _check_block_size(data: bytes, block_len: int) -> None: 

182 """Require data to be full blocks""" 

183 if not data or len(data) % block_len != 0: 

184 raise ValueError("Corrupt data: missing padding") 

185 

186 

187def _check_empty(data: bytes) -> None: 

188 """All data should have been parsed.""" 

189 if data: 

190 raise ValueError("Corrupt data: unparsed data") 

191 

192 

def _init_cipher(
    ciphername: bytes,
    password: typing.Optional[bytes],
    salt: bytes,
    rounds: int,
) -> Cipher[typing.Union[modes.CBC, modes.CTR, modes.GCM]]:
    """Derive key and IV from *password* with the bcrypt KDF; build the cipher.

    :param ciphername: a key of ``_SSH_CIPHERS``.
    :param password: passphrase; ``None`` or empty is rejected.
    :param salt: KDF salt from the key's kdfoptions.
    :param rounds: KDF round count.
    :raises ValueError: no password was supplied.
    """
    if not password:
        raise ValueError("Key is password-protected.")

    params = _SSH_CIPHERS[ciphername]
    # The KDF output is the key immediately followed by the IV / nonce.
    key_and_iv = _bcrypt_kdf(
        password, salt, params.key_len + params.iv_len, rounds, True
    )
    key = key_and_iv[: params.key_len]
    iv = key_and_iv[params.key_len :]
    return Cipher(params.alg(key), params.mode(iv))

211 

212 

213def _get_u32(data: memoryview) -> typing.Tuple[int, memoryview]: 

214 """Uint32""" 

215 if len(data) < 4: 

216 raise ValueError("Invalid data") 

217 return int.from_bytes(data[:4], byteorder="big"), data[4:] 

218 

219 

220def _get_u64(data: memoryview) -> typing.Tuple[int, memoryview]: 

221 """Uint64""" 

222 if len(data) < 8: 

223 raise ValueError("Invalid data") 

224 return int.from_bytes(data[:8], byteorder="big"), data[8:] 

225 

226 

def _get_sshstr(data: memoryview) -> typing.Tuple[memoryview, memoryview]:
    """Split an SSH ``string`` (u32 length, then that many bytes) off *data*.

    Returns ``(value, remaining_data)`` as slices of the input.

    :raises ValueError: the header is truncated or the declared length
        exceeds the available bytes.
    """
    size, body = _get_u32(data)
    if len(body) < size:
        raise ValueError("Invalid data")
    value, rest = body[:size], body[size:]
    return value, rest

233 

234 

def _get_mpint(data: memoryview) -> typing.Tuple[int, memoryview]:
    """Parse a non-negative SSH ``mpint`` off the front of *data*.

    The payload is a big-endian string. A payload whose first byte has its
    high bit set would encode a negative number, so it is rejected.

    :raises ValueError: malformed string or a negative value.
    """
    raw, rest = _get_sshstr(data)
    if len(raw) > 0 and raw[0] >= 0x80:
        raise ValueError("Invalid data")
    return int.from_bytes(raw, "big"), rest

241 

242 

def _to_mpint(val: int) -> bytes:
    """Encode a non-negative int as the payload of an SSH ``mpint``.

    Zero becomes the empty string. Other values are written big-endian with
    enough bytes that the top bit stays clear, so a leading zero byte is
    added whenever the most significant bit would otherwise be set.

    :raises ValueError: *val* is negative.
    """
    if val < 0:
        raise ValueError("negative mpint not allowed")
    if val == 0:
        return b""
    # bit_length() // 8 + 1 bytes always leaves at least one spare high bit.
    return utils.int_to_bytes(val, val.bit_length() // 8 + 1)

251 

252 

253class _FragList: 

254 """Build recursive structure without data copy.""" 

255 

256 flist: typing.List[bytes] 

257 

258 def __init__( 

259 self, init: typing.Optional[typing.List[bytes]] = None 

260 ) -> None: 

261 self.flist = [] 

262 if init: 

263 self.flist.extend(init) 

264 

265 def put_raw(self, val: bytes) -> None: 

266 """Add plain bytes""" 

267 self.flist.append(val) 

268 

269 def put_u32(self, val: int) -> None: 

270 """Big-endian uint32""" 

271 self.flist.append(val.to_bytes(length=4, byteorder="big")) 

272 

273 def put_u64(self, val: int) -> None: 

274 """Big-endian uint64""" 

275 self.flist.append(val.to_bytes(length=8, byteorder="big")) 

276 

277 def put_sshstr(self, val: typing.Union[bytes, _FragList]) -> None: 

278 """Bytes prefixed with u32 length""" 

279 if isinstance(val, (bytes, memoryview, bytearray)): 

280 self.put_u32(len(val)) 

281 self.flist.append(val) 

282 else: 

283 self.put_u32(val.size()) 

284 self.flist.extend(val.flist) 

285 

286 def put_mpint(self, val: int) -> None: 

287 """Big-endian bigint prefixed with u32 length""" 

288 self.put_sshstr(_to_mpint(val)) 

289 

290 def size(self) -> int: 

291 """Current number of bytes""" 

292 return sum(map(len, self.flist)) 

293 

294 def render(self, dstbuf: memoryview, pos: int = 0) -> int: 

295 """Write into bytearray""" 

296 for frag in self.flist: 

297 flen = len(frag) 

298 start, pos = pos, pos + flen 

299 dstbuf[start:pos] = frag 

300 return pos 

301 

302 def tobytes(self) -> bytes: 

303 """Return as bytes""" 

304 buf = memoryview(bytearray(self.size())) 

305 self.render(buf) 

306 return buf.tobytes() 

307 

308 

class _SSHFormatRSA:
    """Format for RSA keys.

    Public:
        mpint e, n
    Private:
        mpint n, e, d, iqmp, p, q
    """

    def get_public(
        self, data: memoryview
    ) -> typing.Tuple[typing.Tuple[int, int], memoryview]:
        """Read the public fields; return ``((e, n), rest)``."""
        exponent, rest = _get_mpint(data)
        modulus, rest = _get_mpint(rest)
        return (exponent, modulus), rest

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[rsa.RSAPublicKey, memoryview]:
        """Build an RSA public key from its wire encoding."""
        (e, n), rest = self.get_public(data)
        return rsa.RSAPublicNumbers(e, n).public_key(), rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[rsa.RSAPrivateKey, memoryview]:
        """Build an RSA private key and check it against *pubfields*.

        :param pubfields: ``(e, n)`` from the file's public-key section.
        :raises ValueError: the private section's e/n differ from it.
        """
        values = []
        rest = data
        for _ in range(6):
            value, rest = _get_mpint(rest)
            values.append(value)
        n, e, d, iqmp, p, q = values

        if (e, n) != pubfields:
            raise ValueError("Corrupt data: rsa field mismatch")
        # SSH omits the CRT exponents, so recompute them from d, p, q.
        private_numbers = rsa.RSAPrivateNumbers(
            p,
            q,
            d,
            rsa.rsa_crt_dmp1(d, p),
            rsa.rsa_crt_dmq1(d, q),
            iqmp,
            rsa.RSAPublicNumbers(e, n),
        )
        return private_numbers.private_key(), rest

    def encode_public(
        self, public_key: rsa.RSAPublicKey, f_pub: _FragList
    ) -> None:
        """Append the public fields (e, then n) to *f_pub*."""
        numbers = public_key.public_numbers()
        for value in (numbers.e, numbers.n):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: rsa.RSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Append n, e, d, iqmp, p, q to *f_priv*."""
        numbers = private_key.private_numbers()
        pub = numbers.public_numbers
        for value in (
            pub.n,
            pub.e,
            numbers.d,
            numbers.iqmp,
            numbers.p,
            numbers.q,
        ):
            f_priv.put_mpint(value)

378 

class _SSHFormatDSA:
    """Format for DSA keys.

    Public:
        mpint p, q, g, y
    Private:
        mpint p, q, g, y, x
    """

    def get_public(
        self, data: memoryview
    ) -> typing.Tuple[typing.Tuple, memoryview]:
        """Read p, q, g, y; return ``((p, q, g, y), rest)``."""
        fields = []
        rest = data
        for _ in range(4):
            value, rest = _get_mpint(rest)
            fields.append(value)
        return tuple(fields), rest

    def _public_numbers(self, p, q, g, y) -> dsa.DSAPublicNumbers:
        """Assemble DSA public numbers and enforce the SSH size limit."""
        public_numbers = dsa.DSAPublicNumbers(
            y, dsa.DSAParameterNumbers(p, q, g)
        )
        self._validate(public_numbers)
        return public_numbers

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[dsa.DSAPublicKey, memoryview]:
        """Build a DSA public key from its wire encoding."""
        fields, rest = self.get_public(data)
        return self._public_numbers(*fields).public_key(), rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[dsa.DSAPrivateKey, memoryview]:
        """Build a DSA private key and check it against *pubfields*.

        :raises ValueError: p/q/g/y differ from the public section, or the
            modulus is not 1024 bits.
        """
        fields, rest = self.get_public(data)
        x, rest = _get_mpint(rest)

        if fields != pubfields:
            raise ValueError("Corrupt data: dsa field mismatch")
        public_numbers = self._public_numbers(*fields)
        private_key = dsa.DSAPrivateNumbers(x, public_numbers).private_key()
        return private_key, rest

    def encode_public(
        self, public_key: dsa.DSAPublicKey, f_pub: _FragList
    ) -> None:
        """Append p, q, g, y to *f_pub* (after the size check)."""
        public_numbers = public_key.public_numbers()
        self._validate(public_numbers)
        params = public_numbers.parameter_numbers
        for value in (params.p, params.q, params.g, public_numbers.y):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: dsa.DSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Append the public fields followed by x to *f_priv*."""
        self.encode_public(private_key.public_key(), f_priv)
        f_priv.put_mpint(private_key.private_numbers().x)

    def _validate(self, public_numbers: dsa.DSAPublicNumbers) -> None:
        """Reject DSA parameters whose modulus p is not exactly 1024 bits."""
        if public_numbers.parameter_numbers.p.bit_length() != 1024:
            raise ValueError("SSH supports only 1024 bit DSA keys")

449 

450 

class _SSHFormatECDSA:
    """Format for ECDSA keys.

    Public:
        str curve
        bytes point
    Private:
        str curve
        bytes point
        mpint secret
    """

    def __init__(self, ssh_curve_name: bytes, curve: ec.EllipticCurve):
        # ssh_curve_name is the identifier that must appear in the key blob
        # (e.g. b"nistp256"); curve is the matching cryptography curve.
        self.ssh_curve_name = ssh_curve_name
        self.curve = curve

    def get_public(
        self, data: memoryview
    ) -> typing.Tuple[typing.Tuple, memoryview]:
        """Read the curve name and point; return ``((curve, point), rest)``.

        :raises ValueError: the curve name does not match this format, or
            the point string is empty.
        :raises NotImplementedError: the point is not uncompressed (0x04).
        """
        curve, data = _get_sshstr(data)
        point, data = _get_sshstr(data)
        if curve != self.ssh_curve_name:
            raise ValueError("Curve name mismatch")
        # Without this check an empty point would escape as IndexError
        # from point[0] instead of the ValueError used for malformed input.
        if not point:
            raise ValueError("Invalid data")
        if point[0] != 4:
            raise NotImplementedError("Need uncompressed point")
        return (curve, point), data

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[ec.EllipticCurvePublicKey, memoryview]:
        """Build an EC public key from its wire encoding."""
        (_, point), data = self.get_public(data)
        public_key = ec.EllipticCurvePublicKey.from_encoded_point(
            self.curve, point.tobytes()
        )
        return public_key, data

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[ec.EllipticCurvePrivateKey, memoryview]:
        """Build an EC private key and check it against *pubfields*.

        :raises ValueError: curve/point differ from the public section.
        """
        (curve_name, point), data = self.get_public(data)
        secret, data = _get_mpint(data)

        if (curve_name, point) != pubfields:
            raise ValueError("Corrupt data: ecdsa field mismatch")
        # NOTE(review): only the two stored copies of the point are
        # compared; the point derived from `secret` is not checked here.
        private_key = ec.derive_private_key(secret, self.curve)
        return private_key, data

    def encode_public(
        self, public_key: ec.EllipticCurvePublicKey, f_pub: _FragList
    ) -> None:
        """Append the curve name and the uncompressed point to *f_pub*."""
        point = public_key.public_bytes(
            Encoding.X962, PublicFormat.UncompressedPoint
        )
        f_pub.put_sshstr(self.ssh_curve_name)
        f_pub.put_sshstr(point)

    def encode_private(
        self, private_key: ec.EllipticCurvePrivateKey, f_priv: _FragList
    ) -> None:
        """Append the public fields followed by the private scalar."""
        public_key = private_key.public_key()
        private_numbers = private_key.private_numbers()

        self.encode_public(public_key, f_priv)
        f_priv.put_mpint(private_numbers.private_value)

520 

521 

class _SSHFormatEd25519:
    """Format for Ed25519 keys.

    Public:
        bytes point
    Private:
        bytes point
        bytes secret_and_point
    """

    def get_public(
        self, data: memoryview
    ) -> typing.Tuple[typing.Tuple, memoryview]:
        """Read the public point; return ``((point,), rest)``."""
        point, rest = _get_sshstr(data)
        return (point,), rest

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[ed25519.Ed25519PublicKey, memoryview]:
        """Build an Ed25519 public key from its wire encoding."""
        fields, rest = self.get_public(data)
        raw_point = fields[0].tobytes()
        return ed25519.Ed25519PublicKey.from_public_bytes(raw_point), rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[ed25519.Ed25519PrivateKey, memoryview]:
        """Build an Ed25519 private key and check it against *pubfields*.

        :raises ValueError: the embedded point copies disagree with each
            other or with the public section.
        """
        fields, rest = self.get_public(data)
        keypair, rest = _get_sshstr(rest)
        (point,) = fields

        # The private blob is the 32-byte secret followed by a copy of the
        # public point.
        secret, embedded_point = keypair[:32], keypair[32:]
        if point != embedded_point or fields != pubfields:
            raise ValueError("Corrupt data: ed25519 field mismatch")
        return ed25519.Ed25519PrivateKey.from_private_bytes(secret), rest

    def encode_public(
        self, public_key: ed25519.Ed25519PublicKey, f_pub: _FragList
    ) -> None:
        """Append the raw 32-byte public point to *f_pub*."""
        f_pub.put_sshstr(
            public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)
        )

    def encode_private(
        self, private_key: ed25519.Ed25519PrivateKey, f_priv: _FragList
    ) -> None:
        """Append the public point, then the secret+point blob."""
        public_key = private_key.public_key()
        secret = private_key.private_bytes(
            Encoding.Raw, PrivateFormat.Raw, NoEncryption()
        )
        point = public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)
        keypair = _FragList([secret, point])

        self.encode_public(public_key, f_priv)
        f_priv.put_sshstr(keypair)

587 

588 

# SSH key-type name -> format handler, consulted by _lookup_kformat.
_KEY_FORMATS = {
    _SSH_RSA: _SSHFormatRSA(),
    _SSH_DSA: _SSHFormatDSA(),
    _SSH_ED25519: _SSHFormatEd25519(),
    **{
        key_type: _SSHFormatECDSA(ssh_curve_name, curve)
        for key_type, ssh_curve_name, curve in (
            (_ECDSA_NISTP256, b"nistp256", ec.SECP256R1()),
            (_ECDSA_NISTP384, b"nistp384", ec.SECP384R1()),
            (_ECDSA_NISTP521, b"nistp521", ec.SECP521R1()),
        )
    },
}

597 

598 

def _lookup_kformat(key_type: bytes):
    """Return the format handler registered for *key_type*.

    *key_type* may be ``bytes`` or any bytes-like object (such as a
    memoryview slice of a parsed blob).

    :raises UnsupportedAlgorithm: no handler exists for the key type.
    """
    name = (
        key_type
        if isinstance(key_type, bytes)
        else memoryview(key_type).tobytes()
    )
    kformat = _KEY_FORMATS.get(name)
    if kformat is None:
        raise UnsupportedAlgorithm(f"Unsupported key type: {name!r}")
    return kformat

606 

607 

# Private key classes this module's OpenSSH private-key code handles:
# load_ssh_private_key returns one of these and _serialize_ssh_private_key
# accepts one of these.
SSHPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    dsa.DSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]

614 

615 

def load_ssh_private_key(
    data: bytes,
    password: typing.Optional[bytes],
    backend: typing.Any = None,
) -> SSHPrivateKeyTypes:
    """Load private key from OpenSSH custom encoding.

    :param data: bytes-like input containing a PEM-armored
        ``openssh-key-v1`` private key; text outside the
        BEGIN/END OPENSSH PRIVATE KEY markers is ignored.
    :param password: passphrase for an encrypted key, or ``None``. It is
        not consulted when the key is stored unencrypted.
    :param backend: not used by this function.
    :returns: the private key object (EC, RSA, DSA or Ed25519).
    :raises ValueError: missing armor or magic, structurally corrupt data,
        a key count other than one, a checksum mismatch after decryption
        (the usual symptom of a wrong password for non-AEAD ciphers), or an
        encrypted key loaded without a password.
    :raises UnsupportedAlgorithm: unknown key type, cipher or KDF.
    """
    utils._check_byteslike("data", data)
    if password is not None:
        utils._check_bytes("password", password)

    m = _PEM_RC.search(data)
    if not m:
        raise ValueError("Not OpenSSH private key format")
    p1 = m.start(1)
    p2 = m.end(1)
    data = binascii.a2b_base64(memoryview(data)[p1:p2])
    if not data.startswith(_SK_MAGIC):
        raise ValueError("Not OpenSSH private key format")
    data = memoryview(data)[len(_SK_MAGIC) :]

    # parse header
    ciphername, data = _get_sshstr(data)
    kdfname, data = _get_sshstr(data)
    kdfoptions, data = _get_sshstr(data)
    nkeys, data = _get_u32(data)
    if nkeys != 1:
        raise ValueError("Only one key supported")

    # load public key data
    pubdata, data = _get_sshstr(data)
    pub_key_type, pubdata = _get_sshstr(pubdata)
    kformat = _lookup_kformat(pub_key_type)
    pubfields, pubdata = kformat.get_public(pubdata)
    _check_empty(pubdata)

    if (ciphername, kdfname) != (_NONE, _NONE):
        ciphername_bytes = ciphername.tobytes()
        if ciphername_bytes not in _SSH_CIPHERS:
            raise UnsupportedAlgorithm(
                f"Unsupported cipher: {ciphername_bytes!r}"
            )
        if kdfname != _BCRYPT:
            # kdfname is a memoryview; convert it so the message shows the
            # KDF name instead of "<memory at 0x...>".
            raise UnsupportedAlgorithm(
                f"Unsupported KDF: {kdfname.tobytes()!r}"
            )
        blklen = _SSH_CIPHERS[ciphername_bytes].block_len
        tag_len = _SSH_CIPHERS[ciphername_bytes].tag_len
        # load secret data
        edata, data = _get_sshstr(data)
        # see https://bugzilla.mindrot.org/show_bug.cgi?id=3553 for
        # information about how OpenSSH handles AEAD tags
        # (the tag follows the encrypted string rather than being part of it)
        if _SSH_CIPHERS[ciphername_bytes].is_aead:
            tag = bytes(data)
            if len(tag) != tag_len:
                raise ValueError(
                    "Corrupt data: invalid tag length for cipher"
                )
        else:
            _check_empty(data)
        _check_block_size(edata, blklen)
        # kdfoptions for bcrypt: string salt, u32 rounds
        salt, kbuf = _get_sshstr(kdfoptions)
        rounds, kbuf = _get_u32(kbuf)
        _check_empty(kbuf)
        ciph = _init_cipher(ciphername_bytes, password, salt.tobytes(), rounds)
        dec = ciph.decryptor()
        edata = memoryview(dec.update(edata))
        if _SSH_CIPHERS[ciphername_bytes].is_aead:
            assert isinstance(dec, AEADDecryptionContext)
            _check_empty(dec.finalize_with_tag(tag))
        else:
            # _check_block_size requires data to be a full block so there
            # should be no output from finalize
            _check_empty(dec.finalize())
    else:
        # load secret data
        edata, data = _get_sshstr(data)
        _check_empty(data)
        blklen = 8
        _check_block_size(edata, blklen)
    # Two copies of the same random check value open the private section;
    # a mismatch means corrupt data or (for encrypted keys) a bad key.
    ck1, edata = _get_u32(edata)
    ck2, edata = _get_u32(edata)
    if ck1 != ck2:
        raise ValueError("Corrupt data: broken checksum")

    # load per-key struct
    key_type, edata = _get_sshstr(edata)
    if key_type != pub_key_type:
        raise ValueError("Corrupt data: key type mismatch")
    private_key, edata = kformat.load_private(edata, pubfields)
    comment, edata = _get_sshstr(edata)

    # yes, SSH does padding check *after* all other parsing is done.
    # need to follow as it writes zero-byte padding too.
    if edata != _PADDING[: len(edata)]:
        raise ValueError("Corrupt data: invalid padding")

    if isinstance(private_key, dsa.DSAPrivateKey):
        warnings.warn(
            "SSH DSA keys are deprecated and will be removed in a future "
            "release.",
            utils.DeprecatedIn40,
            stacklevel=2,
        )

    return private_key

717 

718 

def _serialize_ssh_private_key(
    private_key: SSHPrivateKeyTypes,
    password: bytes,
    encryption_algorithm: KeySerializationEncryption,
) -> bytes:
    """Serialize private key with OpenSSH custom encoding.

    :param private_key: key to serialize; its class selects the format
        handler via ``_get_ssh_key_type`` and ``_lookup_kformat``.
    :param password: passphrase. An empty value writes an unencrypted key
        (cipher and KDF ``none``); otherwise the private section is
        encrypted with ``_DEFAULT_CIPHER`` under a bcrypt-derived key.
    :param encryption_algorithm: only consulted for an optional
        ``_kdf_rounds`` override of the bcrypt round count.
    :returns: the PEM-armored ``openssh-key-v1`` blob.
    :raises ValueError: unsupported key class (from ``_get_ssh_key_type``).
    """
    utils._check_bytes("password", password)
    if isinstance(private_key, dsa.DSAPrivateKey):
        # NOTE(review): stacklevel=4 presumably skips the public
        # serialization wrapper(s) to reach the user's call — confirm.
        warnings.warn(
            "SSH DSA key support is deprecated and will be "
            "removed in a future release",
            utils.DeprecatedIn40,
            stacklevel=4,
        )

    key_type = _get_ssh_key_type(private_key)
    kformat = _lookup_kformat(key_type)

    # setup parameters
    f_kdfoptions = _FragList()
    if password:
        ciphername = _DEFAULT_CIPHER
        blklen = _SSH_CIPHERS[ciphername].block_len
        kdfname = _BCRYPT
        rounds = _DEFAULT_ROUNDS
        if (
            isinstance(encryption_algorithm, _KeySerializationEncryption)
            and encryption_algorithm._kdf_rounds is not None
        ):
            rounds = encryption_algorithm._kdf_rounds
        # bcrypt kdfoptions: string salt, u32 rounds.
        salt = os.urandom(16)
        f_kdfoptions.put_sshstr(salt)
        f_kdfoptions.put_u32(rounds)
        ciph = _init_cipher(ciphername, password, salt, rounds)
    else:
        ciphername = kdfname = _NONE
        # Unencrypted keys still pad the private section to 8-byte units.
        blklen = 8
        ciph = None
    nkeys = 1
    # Written twice at the start of the private section; the loader
    # rejects the key if the two copies differ after decryption.
    checkval = os.urandom(4)
    comment = b""

    # encode public and private parts together
    f_public_key = _FragList()
    f_public_key.put_sshstr(key_type)
    kformat.encode_public(private_key.public_key(), f_public_key)

    f_secrets = _FragList([checkval, checkval])
    f_secrets.put_sshstr(key_type)
    kformat.encode_private(private_key, f_secrets)
    f_secrets.put_sshstr(comment)
    # Pad with 1, 2, 3, ... up to the next multiple of blklen. A section
    # that is already aligned still receives a full block of padding.
    f_secrets.put_raw(_PADDING[: blklen - (f_secrets.size() % blklen)])

    # top-level structure
    f_main = _FragList()
    f_main.put_raw(_SK_MAGIC)
    f_main.put_sshstr(ciphername)
    f_main.put_sshstr(kdfname)
    f_main.put_sshstr(f_kdfoptions)
    f_main.put_u32(nkeys)
    f_main.put_sshstr(f_public_key)
    f_main.put_sshstr(f_secrets)

    # copy result into bytearray
    slen = f_secrets.size()
    mlen = f_main.size()
    # blklen bytes of slack past the data — presumably because update_into
    # needs an output buffer larger than its input; confirm against the
    # cryptography CipherContext.update_into docs.
    buf = memoryview(bytearray(mlen + blklen))
    f_main.render(buf)
    # f_secrets was written last, so its bytes (after their u32 length
    # prefix) occupy buf[ofs:mlen].
    ofs = mlen - slen

    # encrypt in-place
    # NOTE(review): finalize() is never called on the encryptor; this relies
    # on the CTR default producing all output from update_into.
    if ciph is not None:
        ciph.encryptor().update_into(buf[ofs:mlen], buf[ofs:])

    # Drop the slack before armoring.
    return _ssh_pem_encode(buf[:mlen])

794 

795 

# Public key classes produced by this module's SSH public-key parsing and
# accepted by _get_ssh_key_type (DSA included).
SSHPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    dsa.DSAPublicKey,
    ed25519.Ed25519PublicKey,
]

802 

# Like SSHPublicKeyTypes but without DSA: the types advertised by
# SSHCertificate.public_key() and SSHCertificate.signature_key().
SSHCertPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    ed25519.Ed25519PublicKey,
]

808 

809 

class SSHCertificateType(enum.Enum):
    """Kind of OpenSSH certificate.

    The values are the integer type codes read from the certificate;
    SSHCertificate converts the parsed code with ``SSHCertificateType(n)``.
    """

    USER = 1
    HOST = 2

813 

814 

class SSHCertificate:
    """A parsed OpenSSH certificate.

    The constructor only stores fields that have already been decoded
    (keeping memoryview slices of the original blob where given) and
    converts the numeric certificate type.
    """

    def __init__(
        self,
        _nonce: memoryview,
        _public_key: SSHPublicKeyTypes,
        _serial: int,
        _cctype: int,
        _key_id: memoryview,
        _valid_principals: typing.List[bytes],
        _valid_after: int,
        _valid_before: int,
        _critical_options: typing.Dict[bytes, bytes],
        _extensions: typing.Dict[bytes, bytes],
        _sig_type: memoryview,
        _sig_key: memoryview,
        _inner_sig_type: memoryview,
        _signature: memoryview,
        _tbs_cert_body: memoryview,
        _cert_key_type: bytes,
        _cert_body: memoryview,
    ):
        self._nonce = _nonce
        self._public_key = _public_key
        self._serial = _serial
        try:
            self._type = SSHCertificateType(_cctype)
        except ValueError:
            raise ValueError("Invalid certificate type")
        self._key_id = _key_id
        self._valid_principals = _valid_principals
        self._valid_after = _valid_after
        self._valid_before = _valid_before
        self._critical_options = _critical_options
        self._extensions = _extensions
        self._sig_type = _sig_type
        self._sig_key = _sig_key
        self._inner_sig_type = _inner_sig_type
        self._signature = _signature
        self._cert_key_type = _cert_key_type
        self._cert_body = _cert_body
        self._tbs_cert_body = _tbs_cert_body

    @property
    def nonce(self) -> bytes:
        """The certificate nonce."""
        return bytes(self._nonce)

    def public_key(self) -> SSHCertPublicKeyTypes:
        """The public key being certified."""
        # make mypy happy until we remove DSA support entirely and
        # the underlying union won't have a disallowed type
        return typing.cast(SSHCertPublicKeyTypes, self._public_key)

    @property
    def serial(self) -> int:
        return self._serial

    @property
    def type(self) -> SSHCertificateType:
        return self._type

    @property
    def key_id(self) -> bytes:
        return bytes(self._key_id)

    @property
    def valid_principals(self) -> typing.List[bytes]:
        return self._valid_principals

    @property
    def valid_before(self) -> int:
        return self._valid_before

    @property
    def valid_after(self) -> int:
        return self._valid_after

    @property
    def critical_options(self) -> typing.Dict[bytes, bytes]:
        return self._critical_options

    @property
    def extensions(self) -> typing.Dict[bytes, bytes]:
        return self._extensions

    def signature_key(self) -> SSHCertPublicKeyTypes:
        """Decode and return the CA key that signed this certificate.

        :raises ValueError: trailing bytes after the encoded key.
        :raises UnsupportedAlgorithm: unknown signing key type.
        """
        sigformat = _lookup_kformat(self._sig_type)
        signature_key, sigkey_rest = sigformat.load_public(self._sig_key)
        _check_empty(sigkey_rest)
        return signature_key

    def public_bytes(self) -> bytes:
        """Return the certificate as ``<key-type> <base64-body>``."""
        return (
            bytes(self._cert_key_type)
            + b" "
            + binascii.b2a_base64(bytes(self._cert_body), newline=False)
        )

    def verify_cert_signature(self) -> None:
        """Verify the certificate signature with its signing (CA) key.

        Returns None on success. An invalid signature propagates the
        exception raised by the key's ``verify`` method.

        :raises ValueError: malformed ECDSA signature blob, or an RSA
            signature algorithm name not handled here.
        :raises UnsupportedAlgorithm: the signing key is not Ed25519,
            ECDSA or RSA.
        """
        signature_key = self.signature_key()
        if isinstance(signature_key, ed25519.Ed25519PublicKey):
            signature_key.verify(
                bytes(self._signature), bytes(self._tbs_cert_body)
            )
        elif isinstance(signature_key, ec.EllipticCurvePublicKey):
            # The signature is encoded as a pair of big-endian integers
            r, data = _get_mpint(self._signature)
            s, data = _get_mpint(data)
            _check_empty(data)
            computed_sig = asym_utils.encode_dss_signature(r, s)
            hash_alg = _get_ec_hash_alg(signature_key.curve)
            signature_key.verify(
                computed_sig, bytes(self._tbs_cert_body), ec.ECDSA(hash_alg)
            )
        elif isinstance(signature_key, rsa.RSAPublicKey):
            if self._inner_sig_type == _SSH_RSA:
                hash_alg = hashes.SHA1()
            elif self._inner_sig_type == _SSH_RSA_SHA256:
                hash_alg = hashes.SHA256()
            elif self._inner_sig_type == _SSH_RSA_SHA512:
                hash_alg = hashes.SHA512()
            else:
                # Explicit error instead of an assert (stripped by -O).
                raise ValueError("Unsupported RSA signature type")
            signature_key.verify(
                bytes(self._signature),
                bytes(self._tbs_cert_body),
                padding.PKCS1v15(),
                hash_alg,
            )
        else:
            # Reachable e.g. for DSA-signed certificates parsed with the
            # legacy DSA flag; previously an assert that -O would remove.
            raise UnsupportedAlgorithm(
                "Unsupported signature key type for certificate verification"
            )

942 

943 

def _get_ec_hash_alg(curve: ec.EllipticCurve) -> hashes.HashAlgorithm:
    """Return the hash SSH pairs with an ECDSA curve.

    secp256r1 -> SHA-256, secp384r1 -> SHA-384, secp521r1 -> SHA-512. Any
    other curve fails the trailing assertion.
    """
    for curve_cls, hash_cls in (
        (ec.SECP256R1, hashes.SHA256),
        (ec.SECP384R1, hashes.SHA384),
    ):
        if isinstance(curve, curve_cls):
            return hash_cls()
    assert isinstance(curve, ec.SECP521R1)
    return hashes.SHA512()

952 

953 

def _load_ssh_public_identity(
    data: bytes,
    _legacy_dsa_allowed=False,
) -> typing.Union[SSHCertificate, SSHPublicKeyTypes]:
    """Parse one OpenSSH public-key or certificate line.

    :param data: bytes-like text starting with ``<key-type> <base64>``;
        anything after the base64 field (e.g. a comment) is ignored.
    :param _legacy_dsa_allowed: when False, DSA keys and DSA-signed
        certificates are rejected with UnsupportedAlgorithm.
    :returns: an SSHCertificate if the key type ends with the
        ``-cert-v01@openssh.com`` suffix, otherwise the public key.
    :raises ValueError: malformed line, base64, or binary structure.
    :raises UnsupportedAlgorithm: unknown or disallowed key type.
    """
    utils._check_byteslike("data", data)

    m = _SSH_PUBKEY_RC.match(data)
    if not m:
        raise ValueError("Invalid line format")
    key_type = orig_key_type = m.group(1)
    key_body = m.group(2)
    with_cert = False
    if key_type.endswith(_CERT_SUFFIX):
        with_cert = True
        key_type = key_type[: -len(_CERT_SUFFIX)]
    # NOTE(review): this also rejects plain (non-certificate) DSA keys,
    # despite the message wording.
    if key_type == _SSH_DSA and not _legacy_dsa_allowed:
        raise UnsupportedAlgorithm(
            "DSA keys aren't supported in SSH certificates"
        )
    kformat = _lookup_kformat(key_type)

    try:
        rest = memoryview(binascii.a2b_base64(key_body))
    except (TypeError, binascii.Error):
        raise ValueError("Invalid format")

    if with_cert:
        cert_body = rest
        # A certificate blob repeats its full key type and carries a nonce
        # ahead of the embedded public key.
        inner_key_type, rest = _get_sshstr(rest)
        if inner_key_type != orig_key_type:
            raise ValueError("Invalid key format")
        nonce, rest = _get_sshstr(rest)
    public_key, rest = kformat.load_public(rest)
    if with_cert:
        serial, rest = _get_u64(rest)
        cctype, rest = _get_u32(rest)
        key_id, rest = _get_sshstr(rest)
        principals, rest = _get_sshstr(rest)
        valid_principals = []
        while principals:
            principal, principals = _get_sshstr(principals)
            valid_principals.append(bytes(principal))
        valid_after, rest = _get_u64(rest)
        valid_before, rest = _get_u64(rest)
        crit_options, rest = _get_sshstr(rest)
        critical_options = _parse_exts_opts(crit_options)
        exts, rest = _get_sshstr(rest)
        extensions = _parse_exts_opts(exts)
        # Get the reserved field, which is unused.
        _, rest = _get_sshstr(rest)
        sig_key_raw, rest = _get_sshstr(rest)
        sig_type, sig_key = _get_sshstr(sig_key_raw)
        if sig_type == _SSH_DSA and not _legacy_dsa_allowed:
            raise UnsupportedAlgorithm(
                "DSA signatures aren't supported in SSH certificates"
            )
        # Everything before the trailing signature field is the signed
        # body. An explicit end index avoids the `[:-0]` pitfall, which
        # would yield an empty slice when `rest` is empty.
        tbs_cert_body = cert_body[: len(cert_body) - len(rest)]
        signature_raw, rest = _get_sshstr(rest)
        _check_empty(rest)
        inner_sig_type, sig_rest = _get_sshstr(signature_raw)
        # RSA certs can have multiple algorithm types
        if (
            sig_type == _SSH_RSA
            and inner_sig_type
            not in [_SSH_RSA_SHA256, _SSH_RSA_SHA512, _SSH_RSA]
        ) or (sig_type != _SSH_RSA and inner_sig_type != sig_type):
            raise ValueError("Signature key type does not match")
        signature, sig_rest = _get_sshstr(sig_rest)
        _check_empty(sig_rest)
        return SSHCertificate(
            nonce,
            public_key,
            serial,
            cctype,
            key_id,
            valid_principals,
            valid_after,
            valid_before,
            critical_options,
            extensions,
            sig_type,
            sig_key,
            inner_sig_type,
            signature,
            tbs_cert_body,
            orig_key_type,
            cert_body,
        )
    else:
        _check_empty(rest)
        return public_key

1047 

1048 

def load_ssh_public_identity(
    data: bytes,
) -> typing.Union[SSHCertificate, SSHPublicKeyTypes]:
    """Load an OpenSSH public key or certificate from a single line.

    Returns an :class:`SSHCertificate` when the line's key type carries the
    OpenSSH certificate suffix, and the bare public key otherwise. DSA keys,
    and certificates signed by a DSA CA key, are rejected with
    ``UnsupportedAlgorithm``. ``load_ssh_public_key`` still accepts legacy
    DSA public keys.
    """
    return _load_ssh_public_identity(data, _legacy_dsa_allowed=False)

1053 

1054 

1055def _parse_exts_opts(exts_opts: memoryview) -> typing.Dict[bytes, bytes]: 

1056 result: typing.Dict[bytes, bytes] = {} 

1057 last_name = None 

1058 while exts_opts: 

1059 name, exts_opts = _get_sshstr(exts_opts) 

1060 bname: bytes = bytes(name) 

1061 if bname in result: 

1062 raise ValueError("Duplicate name") 

1063 if last_name is not None and bname < last_name: 

1064 raise ValueError("Fields not lexically sorted") 

1065 value, exts_opts = _get_sshstr(exts_opts) 

1066 if len(value) > 0: 

1067 try: 

1068 value, extra = _get_sshstr(value) 

1069 except ValueError: 

1070 warnings.warn( 

1071 "This certificate has an incorrect encoding for critical " 

1072 "options or extensions. This will be an exception in " 

1073 "cryptography 42", 

1074 utils.DeprecatedIn41, 

1075 stacklevel=4, 

1076 ) 

1077 else: 

1078 if len(extra) > 0: 

1079 raise ValueError("Unexpected extra data after value") 

1080 result[bname] = bytes(value) 

1081 last_name = bname 

1082 return result 

1083 

1084 

def load_ssh_public_key(
    data: bytes, backend: typing.Any = None
) -> SSHPublicKeyTypes:
    """Load an OpenSSH public key from a single line.

    If the line holds an OpenSSH certificate, the certificate's subject
    public key is returned. DSA keys are accepted but emit a
    ``DeprecatedIn40`` warning. ``backend`` is accepted for compatibility
    and is not used.
    """
    loaded = _load_ssh_public_identity(data, _legacy_dsa_allowed=True)
    public_key: SSHPublicKeyTypes = (
        loaded.public_key() if isinstance(loaded, SSHCertificate) else loaded
    )

    if isinstance(public_key, dsa.DSAPublicKey):
        warnings.warn(
            "SSH DSA keys are deprecated and will be removed in a future "
            "release.",
            utils.DeprecatedIn40,
            stacklevel=2,
        )
    return public_key

1103 

1104 

def serialize_ssh_public_key(public_key: SSHPublicKeyTypes) -> bytes:
    """Encode ``public_key`` as a single OpenSSH public key line.

    The result is ``<key type> <base64 blob>``, where the blob holds the key
    type as an SSH string followed by the type-specific public key fields.
    DSA keys are still serialized, but emit a ``DeprecatedIn40`` warning.
    """
    if isinstance(public_key, dsa.DSAPublicKey):
        warnings.warn(
            "SSH DSA key support is deprecated and will be "
            "removed in a future release",
            utils.DeprecatedIn40,
            stacklevel=4,
        )
    type_name = _get_ssh_key_type(public_key)
    key_format = _lookup_kformat(type_name)

    blob = _FragList()
    blob.put_sshstr(type_name)
    key_format.encode_public(public_key, blob)

    encoded_blob = binascii.b2a_base64(blob.tobytes()).strip()
    return b" ".join((type_name, encoded_blob))

1123 

1124 

# Private key types that SSHCertificateBuilder.sign accepts as the CA key.
# DSA is deliberately absent.
SSHCertPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]


# This is an undocumented limit enforced in the openssh codebase for sshd and
# ssh-keygen, but it is undefined in the ssh certificates spec.
# SSHCertificateBuilder.valid_principals rejects lists longer than this.
_SSHKEY_CERT_MAX_PRINCIPALS = 256

1135 

1136 

class SSHCertificateBuilder:
    """Immutable builder for OpenSSH certificates.

    Every setter validates its argument and returns a new
    ``SSHCertificateBuilder``; the receiving builder is never modified.
    Each field may be set at most once (setting it again raises
    ``ValueError``). Call :meth:`sign` with the CA private key to produce
    an :class:`SSHCertificate`.
    """

    def __init__(
        self,
        _public_key: typing.Optional[SSHCertPublicKeyTypes] = None,
        _serial: typing.Optional[int] = None,
        _type: typing.Optional[SSHCertificateType] = None,
        _key_id: typing.Optional[bytes] = None,
        _valid_principals: typing.Optional[typing.List[bytes]] = None,
        _valid_for_all_principals: bool = False,
        _valid_before: typing.Optional[int] = None,
        _valid_after: typing.Optional[int] = None,
        _critical_options: typing.Optional[
            typing.List[typing.Tuple[bytes, bytes]]
        ] = None,
        _extensions: typing.Optional[
            typing.List[typing.Tuple[bytes, bytes]]
        ] = None,
    ):
        self._public_key = _public_key
        self._serial = _serial
        self._type = _type
        self._key_id = _key_id
        # List arguments are copied, and ``None`` defaults are replaced by
        # fresh lists. A builder therefore never shares a mutable list with
        # its caller, with another builder, or across calls via a default
        # argument.
        self._valid_principals: typing.List[bytes] = (
            [] if _valid_principals is None else list(_valid_principals)
        )
        self._valid_for_all_principals = _valid_for_all_principals
        self._valid_before = _valid_before
        self._valid_after = _valid_after
        self._critical_options: typing.List[typing.Tuple[bytes, bytes]] = (
            [] if _critical_options is None else list(_critical_options)
        )
        self._extensions: typing.List[typing.Tuple[bytes, bytes]] = (
            [] if _extensions is None else list(_extensions)
        )

    def public_key(
        self, public_key: SSHCertPublicKeyTypes
    ) -> SSHCertificateBuilder:
        """Return a builder with the certified (subject) public key set.

        :raises TypeError: If the key is not an EC, RSA or Ed25519 key.
        :raises ValueError: If the public key was already set.
        """
        if not isinstance(
            public_key,
            (
                ec.EllipticCurvePublicKey,
                rsa.RSAPublicKey,
                ed25519.Ed25519PublicKey,
            ),
        ):
            raise TypeError("Unsupported key type")
        if self._public_key is not None:
            raise ValueError("public_key already set")

        return SSHCertificateBuilder(
            _public_key=public_key,
            _serial=self._serial,
            _type=self._type,
            _key_id=self._key_id,
            _valid_principals=self._valid_principals,
            _valid_for_all_principals=self._valid_for_all_principals,
            _valid_before=self._valid_before,
            _valid_after=self._valid_after,
            _critical_options=self._critical_options,
            _extensions=self._extensions,
        )

    def serial(self, serial: int) -> SSHCertificateBuilder:
        """Return a builder with the serial number set (unsigned 64-bit).

        :raises TypeError: If ``serial`` is not an int.
        :raises ValueError: If out of range or already set.
        """
        if not isinstance(serial, int):
            raise TypeError("serial must be an integer")
        if not 0 <= serial < 2**64:
            raise ValueError("serial must be between 0 and 2**64")
        if self._serial is not None:
            raise ValueError("serial already set")

        return SSHCertificateBuilder(
            _public_key=self._public_key,
            _serial=serial,
            _type=self._type,
            _key_id=self._key_id,
            _valid_principals=self._valid_principals,
            _valid_for_all_principals=self._valid_for_all_principals,
            _valid_before=self._valid_before,
            _valid_after=self._valid_after,
            _critical_options=self._critical_options,
            _extensions=self._extensions,
        )

    def type(self, type: SSHCertificateType) -> SSHCertificateBuilder:
        """Return a builder with the certificate type set.

        :raises TypeError: If ``type`` is not an ``SSHCertificateType``.
        :raises ValueError: If the type was already set.
        """
        if not isinstance(type, SSHCertificateType):
            raise TypeError("type must be an SSHCertificateType")
        if self._type is not None:
            raise ValueError("type already set")

        return SSHCertificateBuilder(
            _public_key=self._public_key,
            _serial=self._serial,
            _type=type,
            _key_id=self._key_id,
            _valid_principals=self._valid_principals,
            _valid_for_all_principals=self._valid_for_all_principals,
            _valid_before=self._valid_before,
            _valid_after=self._valid_after,
            _critical_options=self._critical_options,
            _extensions=self._extensions,
        )

    def key_id(self, key_id: bytes) -> SSHCertificateBuilder:
        """Return a builder with the key identifier set.

        :raises TypeError: If ``key_id`` is not bytes.
        :raises ValueError: If the key id was already set.
        """
        if not isinstance(key_id, bytes):
            raise TypeError("key_id must be bytes")
        if self._key_id is not None:
            raise ValueError("key_id already set")

        return SSHCertificateBuilder(
            _public_key=self._public_key,
            _serial=self._serial,
            _type=self._type,
            _key_id=key_id,
            _valid_principals=self._valid_principals,
            _valid_for_all_principals=self._valid_for_all_principals,
            _valid_before=self._valid_before,
            _valid_after=self._valid_after,
            _critical_options=self._critical_options,
            _extensions=self._extensions,
        )

    def valid_principals(
        self, valid_principals: typing.List[bytes]
    ) -> SSHCertificateBuilder:
        """Return a builder restricted to the given principals.

        The list is copied, so later changes to the caller's list do not
        affect the builder.

        :raises TypeError: If the list is empty or holds non-bytes items.
        :raises ValueError: If valid_for_all_principals or valid_principals
            was already set, or the list exceeds
            ``_SSHKEY_CERT_MAX_PRINCIPALS`` entries.
        """
        if self._valid_for_all_principals:
            raise ValueError(
                "Principals can't be set because the cert is valid "
                "for all principals"
            )
        if (
            not all(isinstance(x, bytes) for x in valid_principals)
            or not valid_principals
        ):
            raise TypeError(
                "principals must be a list of bytes and can't be empty"
            )
        if self._valid_principals:
            raise ValueError("valid_principals already set")

        if len(valid_principals) > _SSHKEY_CERT_MAX_PRINCIPALS:
            raise ValueError(
                "Reached or exceeded the maximum number of valid_principals"
            )

        return SSHCertificateBuilder(
            _public_key=self._public_key,
            _serial=self._serial,
            _type=self._type,
            _key_id=self._key_id,
            _valid_principals=valid_principals,
            _valid_for_all_principals=self._valid_for_all_principals,
            _valid_before=self._valid_before,
            _valid_after=self._valid_after,
            _critical_options=self._critical_options,
            _extensions=self._extensions,
        )

    def valid_for_all_principals(self) -> SSHCertificateBuilder:
        """Return a builder whose certificate is valid for any principal.

        This is encoded as an empty principals list when signing.

        :raises ValueError: If principals or this flag were already set.
        """
        if self._valid_principals:
            raise ValueError(
                "valid_principals already set, can't set "
                "valid_for_all_principals"
            )
        if self._valid_for_all_principals:
            raise ValueError("valid_for_all_principals already set")

        return SSHCertificateBuilder(
            _public_key=self._public_key,
            _serial=self._serial,
            _type=self._type,
            _key_id=self._key_id,
            _valid_principals=self._valid_principals,
            _valid_for_all_principals=True,
            _valid_before=self._valid_before,
            _valid_after=self._valid_after,
            _critical_options=self._critical_options,
            _extensions=self._extensions,
        )

    def valid_before(
        self, valid_before: typing.Union[int, float]
    ) -> SSHCertificateBuilder:
        """Return a builder with the expiry time set.

        Floats are truncated with ``int()``.

        :raises TypeError: If not an int or float.
        :raises ValueError: If outside [0, 2**64) or already set.
        """
        if not isinstance(valid_before, (int, float)):
            raise TypeError("valid_before must be an int or float")
        valid_before = int(valid_before)
        if valid_before < 0 or valid_before >= 2**64:
            raise ValueError("valid_before must [0, 2**64)")
        if self._valid_before is not None:
            raise ValueError("valid_before already set")

        return SSHCertificateBuilder(
            _public_key=self._public_key,
            _serial=self._serial,
            _type=self._type,
            _key_id=self._key_id,
            _valid_principals=self._valid_principals,
            _valid_for_all_principals=self._valid_for_all_principals,
            _valid_before=valid_before,
            _valid_after=self._valid_after,
            _critical_options=self._critical_options,
            _extensions=self._extensions,
        )

    def valid_after(
        self, valid_after: typing.Union[int, float]
    ) -> SSHCertificateBuilder:
        """Return a builder with the start-of-validity time set.

        Floats are truncated with ``int()``.

        :raises TypeError: If not an int or float.
        :raises ValueError: If outside [0, 2**64) or already set.
        """
        if not isinstance(valid_after, (int, float)):
            raise TypeError("valid_after must be an int or float")
        valid_after = int(valid_after)
        if valid_after < 0 or valid_after >= 2**64:
            raise ValueError("valid_after must [0, 2**64)")
        if self._valid_after is not None:
            raise ValueError("valid_after already set")

        return SSHCertificateBuilder(
            _public_key=self._public_key,
            _serial=self._serial,
            _type=self._type,
            _key_id=self._key_id,
            _valid_principals=self._valid_principals,
            _valid_for_all_principals=self._valid_for_all_principals,
            _valid_before=self._valid_before,
            _valid_after=valid_after,
            _critical_options=self._critical_options,
            _extensions=self._extensions,
        )

    def add_critical_option(
        self, name: bytes, value: bytes
    ) -> SSHCertificateBuilder:
        """Return a builder with one more critical option appended.

        :raises TypeError: If name or value is not bytes.
        :raises ValueError: If an option with this name already exists.
        """
        if not isinstance(name, bytes) or not isinstance(value, bytes):
            raise TypeError("name and value must be bytes")
        # Linear scan per call, so adding n options is O(n**2) overall.
        if any(existing == name for existing, _ in self._critical_options):
            raise ValueError("Duplicate critical option name")

        return SSHCertificateBuilder(
            _public_key=self._public_key,
            _serial=self._serial,
            _type=self._type,
            _key_id=self._key_id,
            _valid_principals=self._valid_principals,
            _valid_for_all_principals=self._valid_for_all_principals,
            _valid_before=self._valid_before,
            _valid_after=self._valid_after,
            _critical_options=self._critical_options + [(name, value)],
            _extensions=self._extensions,
        )

    def add_extension(
        self, name: bytes, value: bytes
    ) -> SSHCertificateBuilder:
        """Return a builder with one more extension appended.

        :raises TypeError: If name or value is not bytes.
        :raises ValueError: If an extension with this name already exists.
        """
        if not isinstance(name, bytes) or not isinstance(value, bytes):
            raise TypeError("name and value must be bytes")
        # Linear scan per call, so adding n extensions is O(n**2) overall.
        if any(existing == name for existing, _ in self._extensions):
            raise ValueError("Duplicate extension name")

        return SSHCertificateBuilder(
            _public_key=self._public_key,
            _serial=self._serial,
            _type=self._type,
            _key_id=self._key_id,
            _valid_principals=self._valid_principals,
            _valid_for_all_principals=self._valid_for_all_principals,
            _valid_before=self._valid_before,
            _valid_after=self._valid_after,
            _critical_options=self._critical_options,
            _extensions=self._extensions + [(name, value)],
        )

    @staticmethod
    def _encode_name_value_pairs(
        pairs: typing.List[typing.Tuple[bytes, bytes]],
    ) -> bytes:
        """Encode critical options or extensions for the certificate body.

        Each name is written as an SSH string. A non-empty value is first
        wrapped in its own SSH string, and that wrapper is written as an SSH
        string. An empty value is written as an empty SSH string.
        """
        buf = _FragList()
        for name, value in pairs:
            buf.put_sshstr(name)
            if len(value) > 0:
                inner = _FragList()
                inner.put_sshstr(value)
                buf.put_sshstr(inner.tobytes())
            else:
                buf.put_sshstr(value)
        return buf.tobytes()

    def sign(self, private_key: SSHCertPrivateKeyTypes) -> SSHCertificate:
        """Encode and sign the certificate with ``private_key`` (the CA key).

        The builder itself is not modified.

        :raises TypeError: If ``private_key`` is not an EC, RSA or Ed25519
            private key.
        :raises ValueError: If the public key, type, valid_before or
            valid_after is missing. Also raised if neither valid_principals
            nor valid_for_all_principals was set, or if valid_after is later
            than valid_before.
        :returns: The signed certificate, obtained by re-parsing the encoded
            bytes with ``load_ssh_public_identity``.
        """
        if not isinstance(
            private_key,
            (
                ec.EllipticCurvePrivateKey,
                rsa.RSAPrivateKey,
                ed25519.Ed25519PrivateKey,
            ),
        ):
            raise TypeError("Unsupported private key type")

        if self._public_key is None:
            raise ValueError("public_key must be set")

        # Not required
        serial = 0 if self._serial is None else self._serial

        if self._type is None:
            raise ValueError("type must be set")

        # Not required
        key_id = b"" if self._key_id is None else self._key_id

        # A zero length list is valid, but means the certificate
        # is valid for any principal of the specified type. We require
        # the user to explicitly set valid_for_all_principals to get
        # that behavior.
        if not self._valid_principals and not self._valid_for_all_principals:
            raise ValueError(
                "valid_principals must be set if valid_for_all_principals "
                "is False"
            )

        if self._valid_before is None:
            raise ValueError("valid_before must be set")

        if self._valid_after is None:
            raise ValueError("valid_after must be set")

        if self._valid_after > self._valid_before:
            raise ValueError("valid_after must be earlier than valid_before")

        # Names must be lexically sorted (the parser rejects unsorted
        # fields). Sort copies rather than in place: the builder's lists may
        # be shared with other builders derived from the same chain.
        critical_options = sorted(self._critical_options, key=lambda x: x[0])
        extensions = sorted(self._extensions, key=lambda x: x[0])

        key_type = _get_ssh_key_type(self._public_key)
        cert_prefix = key_type + _CERT_SUFFIX

        # Marshal the bytes to be signed
        nonce = os.urandom(32)
        kformat = _lookup_kformat(key_type)
        f = _FragList()
        f.put_sshstr(cert_prefix)
        f.put_sshstr(nonce)
        kformat.encode_public(self._public_key, f)
        f.put_u64(serial)
        f.put_u32(self._type.value)
        f.put_sshstr(key_id)
        fprincipals = _FragList()
        for p in self._valid_principals:
            fprincipals.put_sshstr(p)
        f.put_sshstr(fprincipals.tobytes())
        f.put_u64(self._valid_after)
        f.put_u64(self._valid_before)
        f.put_sshstr(self._encode_name_value_pairs(critical_options))
        f.put_sshstr(self._encode_name_value_pairs(extensions))
        f.put_sshstr(b"")  # RESERVED FIELD
        # encode CA public key
        ca_type = _get_ssh_key_type(private_key)
        caformat = _lookup_kformat(ca_type)
        caf = _FragList()
        caf.put_sshstr(ca_type)
        caformat.encode_public(private_key.public_key(), caf)
        f.put_sshstr(caf.tobytes())
        # Sigs according to the rules defined for the CA's public key
        # (RFC4253 section 6.6 for ssh-rsa, RFC5656 for ECDSA,
        # and RFC8032 for Ed25519).
        if isinstance(private_key, ed25519.Ed25519PrivateKey):
            signature = private_key.sign(f.tobytes())
            fsig = _FragList()
            fsig.put_sshstr(ca_type)
            fsig.put_sshstr(signature)
            f.put_sshstr(fsig.tobytes())
        elif isinstance(private_key, ec.EllipticCurvePrivateKey):
            hash_alg = _get_ec_hash_alg(private_key.curve)
            signature = private_key.sign(f.tobytes(), ec.ECDSA(hash_alg))
            # SSH encodes ECDSA signatures as two mpints (r, s) rather than
            # the DER structure returned by sign().
            r, s = asym_utils.decode_dss_signature(signature)
            fsig = _FragList()
            fsig.put_sshstr(ca_type)
            fsigblob = _FragList()
            fsigblob.put_mpint(r)
            fsigblob.put_mpint(s)
            fsig.put_sshstr(fsigblob.tobytes())
            f.put_sshstr(fsig.tobytes())

        else:
            assert isinstance(private_key, rsa.RSAPrivateKey)
            # Just like Golang, we're going to use SHA512 for RSA
            # https://cs.opensource.google/go/x/crypto/+/refs/tags/
            # v0.4.0:ssh/certs.go;l=445
            # RFC 8332 defines SHA256 and 512 as options
            fsig = _FragList()
            fsig.put_sshstr(_SSH_RSA_SHA512)
            signature = private_key.sign(
                f.tobytes(), padding.PKCS1v15(), hashes.SHA512()
            )
            fsig.put_sshstr(signature)
            f.put_sshstr(fsig.tobytes())

        cert_data = binascii.b2a_base64(f.tobytes()).strip()
        # load_ssh_public_identity returns a union, but this is
        # guaranteed to be an SSHCertificate, so we cast to make
        # mypy happy.
        return typing.cast(
            SSHCertificate,
            load_ssh_public_identity(b"".join([cert_prefix, b" ", cert_data])),
        )