Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/cryptography/hazmat/primitives/serialization/ssh.py: 20%

723 statements  

« prev     ^ index     » next       coverage.py v7.2.2, created at 2023-03-26 06:36 +0000

1# This file is dual licensed under the terms of the Apache License, Version 

2# 2.0, and the BSD License. See the LICENSE file in the root of this repository 

3# for complete details. 

4 

5 

6import binascii 

7import enum 

8import os 

9import re 

10import typing 

11import warnings 

12from base64 import encodebytes as _base64_encode 

13 

14from cryptography import utils 

15from cryptography.exceptions import UnsupportedAlgorithm 

16from cryptography.hazmat.primitives import hashes 

17from cryptography.hazmat.primitives.asymmetric import ( 

18 dsa, 

19 ec, 

20 ed25519, 

21 padding, 

22 rsa, 

23) 

24from cryptography.hazmat.primitives.asymmetric import utils as asym_utils 

25from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes 

26from cryptography.hazmat.primitives.serialization import ( 

27 Encoding, 

28 KeySerializationEncryption, 

29 NoEncryption, 

30 PrivateFormat, 

31 PublicFormat, 

32 _KeySerializationEncryption, 

33) 

34 

# bcrypt is an optional dependency: it is only required for reading or
# writing password-protected private keys.
try:
    from bcrypt import kdf as _bcrypt_kdf
except ImportError:
    _bcrypt_supported = False

    def _bcrypt_kdf(
        password: bytes,
        salt: bytes,
        desired_key_bytes: int,
        rounds: int,
        ignore_few_rounds: bool = False,
    ) -> bytes:
        """Placeholder with the bcrypt.kdf signature; always fails."""
        raise UnsupportedAlgorithm("Need bcrypt module")

else:
    _bcrypt_supported = True

50 

51 

# Key type identifiers as they appear on the wire.
_SSH_ED25519 = b"ssh-ed25519"
_SSH_RSA = b"ssh-rsa"
_SSH_DSA = b"ssh-dss"
_ECDSA_NISTP256 = b"ecdsa-sha2-nistp256"
_ECDSA_NISTP384 = b"ecdsa-sha2-nistp384"
_ECDSA_NISTP521 = b"ecdsa-sha2-nistp521"
# Suffix appended to a key type to name the matching certificate type.
_CERT_SUFFIX = b"-cert-v01@openssh.com"

# Signature algorithm names only: these never appear as a public key type.
_SSH_RSA_SHA256 = b"rsa-sha2-256"
_SSH_RSA_SHA512 = b"rsa-sha2-512"

# One-line public key format: "<key type> <base64 body> [comment]".
_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)")
_SK_MAGIC = b"openssh-key-v1\0"
_SK_START = b"-----BEGIN OPENSSH PRIVATE KEY-----"
_SK_END = b"-----END OPENSSH PRIVATE KEY-----"
_BCRYPT = b"bcrypt"
_NONE = b"none"
_DEFAULT_CIPHER = b"aes256-ctr"
_DEFAULT_ROUNDS = 16

# A regex is the only convenient way to search arbitrary bytes-like data.
_PEM_RC = re.compile(_SK_START + b"(.*?)" + _SK_END, re.DOTALL)

# Padding bytes 1, 2, ..., 16: enough for the largest supported block size.
_PADDING = memoryview(bytearray(range(1, 17)))

79 

# Ciphers actually used for key wrapping, keyed by their SSH name.
# Each entry is (algorithm class, key length, mode class, block length).
_SSH_CIPHERS: typing.Dict[
    bytes,
    typing.Tuple[
        typing.Type[algorithms.AES],
        int,
        typing.Union[typing.Type[modes.CTR], typing.Type[modes.CBC]],
        int,
    ],
] = {
    b"aes256-ctr": (algorithms.AES, 32, modes.CTR, 16),
    b"aes256-cbc": (algorithms.AES, 32, modes.CBC, 16),
}

# Local curve name -> SSH key type.
_ECDSA_KEY_TYPE = dict(
    secp256r1=_ECDSA_NISTP256,
    secp384r1=_ECDSA_NISTP384,
    secp521r1=_ECDSA_NISTP521,
)

100 

101 

def _get_ssh_key_type(
    key: typing.Union["SSHPrivateKeyTypes", "SSHPublicKeyTypes"]
) -> bytes:
    """Return the SSH key type identifier for a public or private key.

    Raises ValueError when the key class (or, for EC keys, the curve)
    is not one that SSH serialization supports.
    """
    if isinstance(key, ec.EllipticCurvePrivateKey):
        return _ecdsa_key_type(key.public_key())
    if isinstance(key, ec.EllipticCurvePublicKey):
        return _ecdsa_key_type(key)
    if isinstance(key, (rsa.RSAPrivateKey, rsa.RSAPublicKey)):
        return _SSH_RSA
    if isinstance(key, (dsa.DSAPrivateKey, dsa.DSAPublicKey)):
        return _SSH_DSA
    if isinstance(
        key, (ed25519.Ed25519PrivateKey, ed25519.Ed25519PublicKey)
    ):
        return _SSH_ED25519
    raise ValueError("Unsupported key type")

121 

122 

def _ecdsa_key_type(public_key: ec.EllipticCurvePublicKey) -> bytes:
    """Return the SSH key type for the curve of an EC public key.

    Raises ValueError if the curve has no SSH key type.
    """
    name = public_key.curve.name
    if name in _ECDSA_KEY_TYPE:
        return _ECDSA_KEY_TYPE[name]
    raise ValueError(f"Unsupported curve for ssh private key: {name!r}")

131 

132 

def _ssh_pem_encode(
    data: bytes,
    prefix: bytes = _SK_START + b"\n",
    suffix: bytes = _SK_END + b"\n",
) -> bytes:
    """Base64-encode data and wrap it between the prefix and suffix lines."""
    encoded = _base64_encode(data)
    return b"".join((prefix, encoded, suffix))

139 

140 

def _check_block_size(data: bytes, block_len: int) -> None:
    """Raise ValueError unless data is a non-empty run of whole blocks."""
    if data and len(data) % block_len == 0:
        return
    raise ValueError("Corrupt data: missing padding")

145 

146 

def _check_empty(data: bytes) -> None:
    """Raise ValueError if any bytes are left over after parsing."""
    if not data:
        return
    raise ValueError("Corrupt data: unparsed data")

151 

152 

def _init_cipher(
    ciphername: bytes,
    password: typing.Optional[bytes],
    salt: bytes,
    rounds: int,
) -> Cipher[typing.Union[modes.CBC, modes.CTR]]:
    """Derive key and IV from the password and build the cipher.

    Raises ValueError when no (or an empty) password is supplied.
    """
    if not password:
        raise ValueError("Key is password-protected.")

    algo_cls, key_size, mode_cls, iv_size = _SSH_CIPHERS[ciphername]
    # A single KDF call yields the key followed by the IV.
    material = _bcrypt_kdf(password, salt, key_size + iv_size, rounds, True)
    key, iv = material[:key_size], material[key_size:]
    return Cipher(algo_cls(key), mode_cls(iv))

166 

167 

def _get_u32(data: memoryview) -> typing.Tuple[int, memoryview]:
    """Read a big-endian uint32; return (value, remaining data)."""
    if len(data) >= 4:
        head, tail = data[:4], data[4:]
        return int.from_bytes(head, byteorder="big"), tail
    raise ValueError("Invalid data")

173 

174 

def _get_u64(data: memoryview) -> typing.Tuple[int, memoryview]:
    """Read a big-endian uint64; return (value, remaining data)."""
    if len(data) >= 8:
        head, tail = data[:8], data[8:]
        return int.from_bytes(head, byteorder="big"), tail
    raise ValueError("Invalid data")

180 

181 

def _get_sshstr(data: memoryview) -> typing.Tuple[memoryview, memoryview]:
    """Read a u32-length-prefixed byte string; return (string, remainder)."""
    length, body = _get_u32(data)
    if length <= len(body):
        return body[:length], body[length:]
    raise ValueError("Invalid data")

188 

189 

def _get_mpint(data: memoryview) -> typing.Tuple[int, memoryview]:
    """Read a non-negative big integer; return (value, remaining data)."""
    raw, rest = _get_sshstr(data)
    # A set top bit would mean a negative number, which is rejected.
    if len(raw) > 0 and raw[0] & 0x80:
        raise ValueError("Invalid data")
    return int.from_bytes(raw, "big"), rest

196 

197 

def _to_mpint(val: int) -> bytes:
    """Encode a non-negative integer in the signed bigint storage format.

    Zero encodes as the empty string; negative values raise ValueError.
    """
    if val < 0:
        raise ValueError("negative mpint not allowed")
    if val == 0:
        return b""
    # One bit more than bit_length() keeps the leading (sign) bit clear.
    nbytes = val.bit_length() // 8 + 1
    return utils.int_to_bytes(val, nbytes)

206 

207 

class _FragList:
    """Collect byte fragments so nested structures are built without copies."""

    # Fragments in output order.
    flist: typing.List[bytes]

    def __init__(
        self, init: typing.Optional[typing.List[bytes]] = None
    ) -> None:
        self.flist = list(init) if init else []

    def put_raw(self, val: bytes) -> None:
        """Append bytes as-is."""
        self.flist.append(val)

    def put_u32(self, val: int) -> None:
        """Append a big-endian uint32."""
        self.flist.append(val.to_bytes(4, "big"))

    def put_u64(self, val: int) -> None:
        """Append a big-endian uint64."""
        self.flist.append(val.to_bytes(8, "big"))

    def put_sshstr(self, val: typing.Union[bytes, "_FragList"]) -> None:
        """Append a u32 length followed by the bytes or nested fragments."""
        if not isinstance(val, (bytes, memoryview, bytearray)):
            # Nested fragment list: splice its fragments in directly.
            self.put_u32(val.size())
            self.flist.extend(val.flist)
            return
        self.put_u32(len(val))
        self.flist.append(val)

    def put_mpint(self, val: int) -> None:
        """Append a big-endian bigint prefixed with its u32 length."""
        self.put_sshstr(_to_mpint(val))

    def size(self) -> int:
        """Return the total number of bytes collected so far."""
        return sum(len(frag) for frag in self.flist)

    def render(self, dstbuf: memoryview, pos: int = 0) -> int:
        """Copy all fragments into dstbuf from pos; return the end offset."""
        for frag in self.flist:
            end = pos + len(frag)
            dstbuf[pos:end] = frag
            pos = end
        return pos

    def tobytes(self) -> bytes:
        """Return all fragments joined into one bytes object."""
        out = memoryview(bytearray(self.size()))
        self.render(out)
        return out.tobytes()

262 

263 

class _SSHFormatRSA:
    """Format for RSA keys.

    Public:
        mpint e, n
    Private:
        mpint n, e, d, iqmp, p, q
    """

    def get_public(self, data: memoryview):
        """Read the RSA public fields; return ((e, n), remaining data)."""
        e, rest = _get_mpint(data)
        n, rest = _get_mpint(rest)
        return (e, n), rest

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[rsa.RSAPublicKey, memoryview]:
        """Build an RSA public key; return (key, remaining data)."""
        (e, n), rest = self.get_public(data)
        return rsa.RSAPublicNumbers(e, n).public_key(), rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[rsa.RSAPrivateKey, memoryview]:
        """Build an RSA private key; return (key, remaining data).

        Raises ValueError if (e, n) differs from pubfields.
        """
        values = []
        for _ in range(6):
            value, data = _get_mpint(data)
            values.append(value)
        n, e, d, iqmp, p, q = values

        if (e, n) != pubfields:
            raise ValueError("Corrupt data: rsa field mismatch")
        # The CRT exponents are not stored, so they are recomputed.
        numbers = rsa.RSAPrivateNumbers(
            p,
            q,
            d,
            rsa.rsa_crt_dmp1(d, p),
            rsa.rsa_crt_dmq1(d, q),
            iqmp,
            rsa.RSAPublicNumbers(e, n),
        )
        return numbers.private_key(), data

    def encode_public(
        self, public_key: rsa.RSAPublicKey, f_pub: _FragList
    ) -> None:
        """Append the RSA public fields (e, n) to f_pub."""
        numbers = public_key.public_numbers()
        for value in (numbers.e, numbers.n):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: rsa.RSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Append the RSA private fields (n, e, d, iqmp, p, q) to f_priv."""
        priv = private_key.private_numbers()
        pub = priv.public_numbers
        for value in (pub.n, pub.e, priv.d, priv.iqmp, priv.p, priv.q):
            f_priv.put_mpint(value)

332 

333 

class _SSHFormatDSA:
    """Format for DSA keys.

    Public:
        mpint p, q, g, y
    Private:
        mpint p, q, g, y, x
    """

    def get_public(
        self, data: memoryview
    ) -> typing.Tuple[typing.Tuple, memoryview]:
        """Read the DSA public fields; return ((p, q, g, y), remaining)."""
        fields = []
        for _ in range(4):
            value, data = _get_mpint(data)
            fields.append(value)
        return tuple(fields), data

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[dsa.DSAPublicKey, memoryview]:
        """Build a DSA public key; return (key, remaining data)."""
        (p, q, g, y), rest = self.get_public(data)
        numbers = dsa.DSAPublicNumbers(y, dsa.DSAParameterNumbers(p, q, g))
        self._validate(numbers)
        return numbers.public_key(), rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[dsa.DSAPrivateKey, memoryview]:
        """Build a DSA private key; return (key, remaining data).

        Raises ValueError if the public fields differ from pubfields.
        """
        (p, q, g, y), rest = self.get_public(data)
        x, rest = _get_mpint(rest)

        if (p, q, g, y) != pubfields:
            raise ValueError("Corrupt data: dsa field mismatch")
        pub_numbers = dsa.DSAPublicNumbers(
            y, dsa.DSAParameterNumbers(p, q, g)
        )
        self._validate(pub_numbers)
        return dsa.DSAPrivateNumbers(x, pub_numbers).private_key(), rest

    def encode_public(
        self, public_key: dsa.DSAPublicKey, f_pub: _FragList
    ) -> None:
        """Append the DSA public fields (p, q, g, y) to f_pub."""
        numbers = public_key.public_numbers()
        params = numbers.parameter_numbers
        self._validate(numbers)

        for value in (params.p, params.q, params.g, numbers.y):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: dsa.DSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Append the DSA public fields followed by x to f_priv."""
        self.encode_public(private_key.public_key(), f_priv)
        f_priv.put_mpint(private_key.private_numbers().x)

    def _validate(self, public_numbers: dsa.DSAPublicNumbers) -> None:
        """Raise ValueError unless the modulus p is exactly 1024 bits."""
        if public_numbers.parameter_numbers.p.bit_length() != 1024:
            raise ValueError("SSH supports only 1024 bit DSA keys")

404 

405 

class _SSHFormatECDSA:
    """Format for ECDSA keys.

    Public:
        str curve
        bytes point
    Private:
        str curve
        bytes point
        mpint secret
    """

    def __init__(self, ssh_curve_name: bytes, curve: ec.EllipticCurve):
        self.ssh_curve_name = ssh_curve_name
        self.curve = curve

    def get_public(
        self, data: memoryview
    ) -> typing.Tuple[typing.Tuple, memoryview]:
        """Read the ECDSA public fields; return ((curve, point), remaining).

        Raises ValueError if the curve name does not match this format or
        the point is empty, and NotImplementedError if the point is not in
        uncompressed form.
        """
        curve, data = _get_sshstr(data)
        point, data = _get_sshstr(data)
        if curve != self.ssh_curve_name:
            raise ValueError("Curve name mismatch")
        # A zero-length point would make point[0] raise IndexError; report
        # it as malformed input like every other parse failure.
        if not point:
            raise ValueError("Invalid data")
        if point[0] != 4:
            raise NotImplementedError("Need uncompressed point")
        return (curve, point), data

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[ec.EllipticCurvePublicKey, memoryview]:
        """Build an ECDSA public key; return (key, remaining data)."""
        (curve_name, point), data = self.get_public(data)
        public_key = ec.EllipticCurvePublicKey.from_encoded_point(
            self.curve, point.tobytes()
        )
        return public_key, data

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[ec.EllipticCurvePrivateKey, memoryview]:
        """Build an ECDSA private key; return (key, remaining data).

        Raises ValueError if the public fields differ from pubfields.
        """
        (curve_name, point), data = self.get_public(data)
        secret, data = _get_mpint(data)

        if (curve_name, point) != pubfields:
            raise ValueError("Corrupt data: ecdsa field mismatch")
        private_key = ec.derive_private_key(secret, self.curve)
        return private_key, data

    def encode_public(
        self, public_key: ec.EllipticCurvePublicKey, f_pub: _FragList
    ) -> None:
        """Append the curve name and uncompressed point to f_pub."""
        point = public_key.public_bytes(
            Encoding.X962, PublicFormat.UncompressedPoint
        )
        f_pub.put_sshstr(self.ssh_curve_name)
        f_pub.put_sshstr(point)

    def encode_private(
        self, private_key: ec.EllipticCurvePrivateKey, f_priv: _FragList
    ) -> None:
        """Append the public fields followed by the secret to f_priv."""
        public_key = private_key.public_key()
        private_numbers = private_key.private_numbers()

        self.encode_public(public_key, f_priv)
        f_priv.put_mpint(private_numbers.private_value)

475 

476 

class _SSHFormatEd25519:
    """Format for Ed25519 keys.

    Public:
        bytes point
    Private:
        bytes point
        bytes secret_and_point
    """

    def get_public(
        self, data: memoryview
    ) -> typing.Tuple[typing.Tuple, memoryview]:
        """Read the Ed25519 public field; return ((point,), remaining)."""
        point, rest = _get_sshstr(data)
        return (point,), rest

    def load_public(
        self, data: memoryview
    ) -> typing.Tuple[ed25519.Ed25519PublicKey, memoryview]:
        """Build an Ed25519 public key; return (key, remaining data)."""
        (point,), rest = self.get_public(data)
        key = ed25519.Ed25519PublicKey.from_public_bytes(point.tobytes())
        return key, rest

    def load_private(
        self, data: memoryview, pubfields
    ) -> typing.Tuple[ed25519.Ed25519PrivateKey, memoryview]:
        """Build an Ed25519 private key; return (key, remaining data).

        Raises ValueError if the stored points disagree with each other
        or with pubfields.
        """
        (point,), rest = self.get_public(data)
        keypair, rest = _get_sshstr(rest)

        # The keypair blob is the 32-byte secret followed by the point.
        secret, trailing_point = keypair[:32], keypair[32:]
        if point != trailing_point or (point,) != pubfields:
            raise ValueError("Corrupt data: ed25519 field mismatch")
        return ed25519.Ed25519PrivateKey.from_private_bytes(secret), rest

    def encode_public(
        self, public_key: ed25519.Ed25519PublicKey, f_pub: _FragList
    ) -> None:
        """Append the raw Ed25519 public key to f_pub."""
        f_pub.put_sshstr(
            public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)
        )

    def encode_private(
        self, private_key: ed25519.Ed25519PrivateKey, f_priv: _FragList
    ) -> None:
        """Append the public key, then secret + public key, to f_priv."""
        public_key = private_key.public_key()
        raw_private = private_key.private_bytes(
            Encoding.Raw, PrivateFormat.Raw, NoEncryption()
        )
        raw_public = public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)
        keypair = _FragList([raw_private, raw_public])

        self.encode_public(public_key, f_priv)
        f_priv.put_sshstr(keypair)

542 

543 

# SSH key type -> object that parses and writes that key type.
_KEY_FORMATS = {
    _SSH_RSA: _SSHFormatRSA(),
    _SSH_DSA: _SSHFormatDSA(),
    _SSH_ED25519: _SSHFormatEd25519(),
    _ECDSA_NISTP256: _SSHFormatECDSA(b"nistp256", ec.SECP256R1()),
    _ECDSA_NISTP384: _SSHFormatECDSA(b"nistp384", ec.SECP384R1()),
    _ECDSA_NISTP521: _SSHFormatECDSA(b"nistp521", ec.SECP521R1()),
}

552 

553 

def _lookup_kformat(key_type: bytes):
    """Return the format handler for key_type.

    Accepts any bytes-like key type. Raises UnsupportedAlgorithm when the
    key type has no handler.
    """
    if not isinstance(key_type, bytes):
        key_type = memoryview(key_type).tobytes()
    kformat = _KEY_FORMATS.get(key_type)
    if kformat is None:
        raise UnsupportedAlgorithm(f"Unsupported key type: {key_type!r}")
    return kformat

561 

562 

# Private key classes accepted and returned by the SSH private key functions.
SSHPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    dsa.DSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]

569 

570 

def load_ssh_private_key(
    data: bytes,
    password: typing.Optional[bytes],
    backend: typing.Any = None,
) -> SSHPrivateKeyTypes:
    """Load private key from OpenSSH custom encoding.

    data is the PEM-style text of the key; password is the passphrase, or
    None for an unencrypted key. backend is accepted but not used here.

    Raises ValueError for malformed data, a missing password or a failed
    checksum, and UnsupportedAlgorithm for an unknown key type, cipher
    or KDF.
    """
    utils._check_byteslike("data", data)
    if password is not None:
        utils._check_bytes("password", password)

    m = _PEM_RC.search(data)
    if not m:
        raise ValueError("Not OpenSSH private key format")
    p1 = m.start(1)
    p2 = m.end(1)
    data = binascii.a2b_base64(memoryview(data)[p1:p2])
    if not data.startswith(_SK_MAGIC):
        raise ValueError("Not OpenSSH private key format")
    data = memoryview(data)[len(_SK_MAGIC) :]

    # parse header
    ciphername, data = _get_sshstr(data)
    kdfname, data = _get_sshstr(data)
    kdfoptions, data = _get_sshstr(data)
    nkeys, data = _get_u32(data)
    if nkeys != 1:
        raise ValueError("Only one key supported")

    # load public key data
    pubdata, data = _get_sshstr(data)
    pub_key_type, pubdata = _get_sshstr(pubdata)
    kformat = _lookup_kformat(pub_key_type)
    pubfields, pubdata = kformat.get_public(pubdata)
    _check_empty(pubdata)

    # load secret data
    edata, data = _get_sshstr(data)
    _check_empty(data)

    if (ciphername, kdfname) != (_NONE, _NONE):
        ciphername_bytes = ciphername.tobytes()
        if ciphername_bytes not in _SSH_CIPHERS:
            raise UnsupportedAlgorithm(
                f"Unsupported cipher: {ciphername_bytes!r}"
            )
        if kdfname != _BCRYPT:
            # kdfname is a memoryview; convert it so the message shows the
            # actual name instead of "<memory at 0x...>".
            raise UnsupportedAlgorithm(
                f"Unsupported KDF: {kdfname.tobytes()!r}"
            )
        blklen = _SSH_CIPHERS[ciphername_bytes][3]
        _check_block_size(edata, blklen)
        salt, kbuf = _get_sshstr(kdfoptions)
        rounds, kbuf = _get_u32(kbuf)
        _check_empty(kbuf)
        ciph = _init_cipher(ciphername_bytes, password, salt.tobytes(), rounds)
        edata = memoryview(ciph.decryptor().update(edata))
    else:
        blklen = 8
        _check_block_size(edata, blklen)
    # The two check values match only if decryption used the right key.
    ck1, edata = _get_u32(edata)
    ck2, edata = _get_u32(edata)
    if ck1 != ck2:
        raise ValueError("Corrupt data: broken checksum")

    # load per-key struct
    key_type, edata = _get_sshstr(edata)
    if key_type != pub_key_type:
        raise ValueError("Corrupt data: key type mismatch")
    private_key, edata = kformat.load_private(edata, pubfields)
    # The comment is parsed to validate the structure but is not returned.
    comment, edata = _get_sshstr(edata)

    # yes, SSH does padding check *after* all other parsing is done.
    # need to follow as it writes zero-byte padding too.
    if edata != _PADDING[: len(edata)]:
        raise ValueError("Corrupt data: invalid padding")

    if isinstance(private_key, dsa.DSAPrivateKey):
        warnings.warn(
            "SSH DSA keys are deprecated and will be removed in a future "
            "release.",
            utils.DeprecatedIn40,
            stacklevel=2,
        )

    return private_key

654 

655 

def _serialize_ssh_private_key(
    private_key: SSHPrivateKeyTypes,
    password: bytes,
    encryption_algorithm: KeySerializationEncryption,
) -> bytes:
    """Serialize a private key in the OpenSSH custom encoding.

    An empty password produces an unencrypted key; otherwise the secret
    section is encrypted with the default cipher and a bcrypt-derived key.
    """
    utils._check_bytes("password", password)
    if isinstance(private_key, dsa.DSAPrivateKey):
        warnings.warn(
            "SSH DSA key support is deprecated and will be "
            "removed in a future release",
            utils.DeprecatedIn40,
            stacklevel=4,
        )

    key_type = _get_ssh_key_type(private_key)
    kformat = _lookup_kformat(key_type)

    # Choose cipher and KDF parameters.
    f_kdfoptions = _FragList()
    if not password:
        ciphername = kdfname = _NONE
        blklen = 8
        ciph = None
    else:
        ciphername = _DEFAULT_CIPHER
        blklen = _SSH_CIPHERS[ciphername][3]
        kdfname = _BCRYPT
        rounds = _DEFAULT_ROUNDS
        if (
            isinstance(encryption_algorithm, _KeySerializationEncryption)
            and encryption_algorithm._kdf_rounds is not None
        ):
            rounds = encryption_algorithm._kdf_rounds
        salt = os.urandom(16)
        f_kdfoptions.put_sshstr(salt)
        f_kdfoptions.put_u32(rounds)
        ciph = _init_cipher(ciphername, password, salt, rounds)
    checkval = os.urandom(4)

    # Public section: key type followed by the public fields.
    f_public_key = _FragList()
    f_public_key.put_sshstr(key_type)
    kformat.encode_public(private_key.public_key(), f_public_key)

    # Secret section: check value twice, key, empty comment, padding.
    f_secrets = _FragList([checkval, checkval])
    f_secrets.put_sshstr(key_type)
    kformat.encode_private(private_key, f_secrets)
    f_secrets.put_sshstr(b"")
    pad_len = blklen - (f_secrets.size() % blklen)
    f_secrets.put_raw(_PADDING[:pad_len])

    # Top-level structure; exactly one key is written.
    f_main = _FragList()
    f_main.put_raw(_SK_MAGIC)
    f_main.put_sshstr(ciphername)
    f_main.put_sshstr(kdfname)
    f_main.put_sshstr(f_kdfoptions)
    f_main.put_u32(1)
    f_main.put_sshstr(f_public_key)
    f_main.put_sshstr(f_secrets)

    # Render into a bytearray with one spare block after the data.
    secrets_len = f_secrets.size()
    total_len = f_main.size()
    buf = memoryview(bytearray(total_len + blklen))
    f_main.render(buf)
    secrets_ofs = total_len - secrets_len

    # Encrypt the secret section in place.
    if ciph is not None:
        ciph.encryptor().update_into(
            buf[secrets_ofs:total_len], buf[secrets_ofs:]
        )

    return _ssh_pem_encode(buf[:total_len])

731 

732 

# Public key classes handled by the SSH public key functions.
SSHPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    dsa.DSAPublicKey,
    ed25519.Ed25519PublicKey,
]

# Public key classes exposed by SSH certificates (no DSA).
SSHCertPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    ed25519.Ed25519PublicKey,
]

745 

746 

class SSHCertificateType(enum.Enum):
    """Kind of SSH certificate: issued to a user or to a host."""

    USER = 1
    HOST = 2

750 

751 

class SSHCertificate:
    """Parsed SSH certificate with read-only access to its fields."""

    def __init__(
        self,
        _nonce: memoryview,
        _public_key: SSHPublicKeyTypes,
        _serial: int,
        _cctype: int,
        _key_id: memoryview,
        _valid_principals: typing.List[bytes],
        _valid_after: int,
        _valid_before: int,
        _critical_options: typing.Dict[bytes, bytes],
        _extensions: typing.Dict[bytes, bytes],
        _sig_type: memoryview,
        _sig_key: memoryview,
        _inner_sig_type: memoryview,
        _signature: memoryview,
        _tbs_cert_body: memoryview,
        _cert_key_type: bytes,
        _cert_body: memoryview,
    ):
        """Store the parsed fields.

        Raises ValueError if _cctype is not a known certificate type.
        """
        self._nonce = _nonce
        self._public_key = _public_key
        self._serial = _serial
        try:
            self._type = SSHCertificateType(_cctype)
        except ValueError:
            raise ValueError("Invalid certificate type")
        self._key_id = _key_id
        self._valid_principals = _valid_principals
        self._valid_after = _valid_after
        self._valid_before = _valid_before
        self._critical_options = _critical_options
        self._extensions = _extensions
        # Signing key and signature material.
        self._sig_type = _sig_type
        self._sig_key = _sig_key
        self._inner_sig_type = _inner_sig_type
        self._signature = _signature
        # Raw encodings kept for re-serialization and verification.
        self._cert_key_type = _cert_key_type
        self._cert_body = _cert_body
        self._tbs_cert_body = _tbs_cert_body

    @property
    def nonce(self) -> bytes:
        """The certificate nonce as bytes."""
        return bytes(self._nonce)

    def public_key(self) -> SSHCertPublicKeyTypes:
        """Return the public key carried by the certificate."""
        # The cast keeps mypy happy until DSA support is removed and the
        # underlying union no longer contains a disallowed type.
        return typing.cast(SSHCertPublicKeyTypes, self._public_key)

    @property
    def serial(self) -> int:
        """The certificate serial number."""
        return self._serial

    @property
    def type(self) -> SSHCertificateType:
        """The certificate type (user or host)."""
        return self._type

    @property
    def key_id(self) -> bytes:
        """The key id as bytes."""
        return bytes(self._key_id)

    @property
    def valid_principals(self) -> typing.List[bytes]:
        """The list of valid principals."""
        return self._valid_principals

    @property
    def valid_before(self) -> int:
        """The valid-before value as stored in the certificate."""
        return self._valid_before

    @property
    def valid_after(self) -> int:
        """The valid-after value as stored in the certificate."""
        return self._valid_after

    @property
    def critical_options(self) -> typing.Dict[bytes, bytes]:
        """The critical options, name to value."""
        return self._critical_options

    @property
    def extensions(self) -> typing.Dict[bytes, bytes]:
        """The extensions, name to value."""
        return self._extensions

    def signature_key(self) -> SSHCertPublicKeyTypes:
        """Load and return the public key that signed the certificate."""
        sigformat = _lookup_kformat(self._sig_type)
        key, leftover = sigformat.load_public(self._sig_key)
        _check_empty(leftover)
        return key

    def public_bytes(self) -> bytes:
        """Return the one-line form: key type, a space, base64 body."""
        encoded_body = binascii.b2a_base64(
            bytes(self._cert_body), newline=False
        )
        return b" ".join([bytes(self._cert_key_type), encoded_body])

    def verify_cert_signature(self) -> None:
        """Verify the signature over the to-be-signed certificate body.

        Returns None on success; on failure the exception raised by the
        signing key's verify() propagates.
        """
        signature_key = self.signature_key()
        tbs = bytes(self._tbs_cert_body)

        if isinstance(signature_key, ed25519.Ed25519PublicKey):
            signature_key.verify(bytes(self._signature), tbs)
            return

        if isinstance(signature_key, ec.EllipticCurvePublicKey):
            # The signature is encoded as a pair of big-endian integers
            r, rest = _get_mpint(self._signature)
            s, rest = _get_mpint(rest)
            _check_empty(rest)
            der_sig = asym_utils.encode_dss_signature(r, s)
            hash_alg = _get_ec_hash_alg(signature_key.curve)
            signature_key.verify(der_sig, tbs, ec.ECDSA(hash_alg))
            return

        assert isinstance(signature_key, rsa.RSAPublicKey)
        if self._inner_sig_type == _SSH_RSA:
            hash_alg = hashes.SHA1()
        elif self._inner_sig_type == _SSH_RSA_SHA256:
            hash_alg = hashes.SHA256()
        else:
            assert self._inner_sig_type == _SSH_RSA_SHA512
            hash_alg = hashes.SHA512()
        signature_key.verify(
            bytes(self._signature),
            tbs,
            padding.PKCS1v15(),
            hash_alg,
        )

879 

880 

def _get_ec_hash_alg(curve: ec.EllipticCurve) -> hashes.HashAlgorithm:
    """Return the hash algorithm paired with an SSH ECDSA curve."""
    if isinstance(curve, ec.SECP256R1):
        return hashes.SHA256()
    if isinstance(curve, ec.SECP384R1):
        return hashes.SHA384()
    # Only the three NIST curves are ever loaded by this module.
    assert isinstance(curve, ec.SECP521R1)
    return hashes.SHA512()

889 

890 

def _load_ssh_public_identity(
    data: bytes,
    _legacy_dsa_allowed=False,
) -> typing.Union[SSHCertificate, SSHPublicKeyTypes]:
    """Parse a one-line SSH public key or certificate.

    Returns an SSHCertificate when the key type carries the certificate
    suffix, otherwise the public key. DSA is rejected with
    UnsupportedAlgorithm unless _legacy_dsa_allowed is true. Malformed
    input raises ValueError.
    """
    utils._check_byteslike("data", data)

    m = _SSH_PUBKEY_RC.match(data)
    if not m:
        raise ValueError("Invalid line format")
    orig_key_type = m.group(1)
    key_body = m.group(2)
    with_cert = orig_key_type.endswith(_CERT_SUFFIX)
    if with_cert:
        key_type = orig_key_type[: -len(_CERT_SUFFIX)]
    else:
        key_type = orig_key_type
    if key_type == _SSH_DSA and not _legacy_dsa_allowed:
        raise UnsupportedAlgorithm(
            "DSA keys aren't supported in SSH certificates"
        )
    kformat = _lookup_kformat(key_type)

    try:
        rest = memoryview(binascii.a2b_base64(key_body))
    except (TypeError, binascii.Error):
        raise ValueError("Invalid format")

    # Only the certificate path uses this view of the whole body.
    cert_body = rest
    inner_key_type, rest = _get_sshstr(rest)
    if inner_key_type != orig_key_type:
        raise ValueError("Invalid key format")

    if not with_cert:
        public_key, rest = kformat.load_public(rest)
        _check_empty(rest)
        return public_key

    nonce, rest = _get_sshstr(rest)
    public_key, rest = kformat.load_public(rest)
    serial, rest = _get_u64(rest)
    cctype, rest = _get_u32(rest)
    key_id, rest = _get_sshstr(rest)
    principals, rest = _get_sshstr(rest)
    valid_principals = []
    while principals:
        principal, principals = _get_sshstr(principals)
        valid_principals.append(bytes(principal))
    valid_after, rest = _get_u64(rest)
    valid_before, rest = _get_u64(rest)
    crit_options, rest = _get_sshstr(rest)
    critical_options = _parse_exts_opts(crit_options)
    exts, rest = _get_sshstr(rest)
    extensions = _parse_exts_opts(exts)
    # Skip the reserved field, which is unused.
    _, rest = _get_sshstr(rest)
    sig_key_raw, rest = _get_sshstr(rest)
    sig_type, sig_key = _get_sshstr(sig_key_raw)
    if sig_type == _SSH_DSA and not _legacy_dsa_allowed:
        raise UnsupportedAlgorithm(
            "DSA signatures aren't supported in SSH certificates"
        )
    # The signed portion is the whole body minus the trailing signature.
    tbs_cert_body = cert_body[: -len(rest)]
    signature_raw, rest = _get_sshstr(rest)
    _check_empty(rest)
    inner_sig_type, sig_rest = _get_sshstr(signature_raw)
    # RSA certs can have multiple algorithm types
    if sig_type == _SSH_RSA:
        sig_type_ok = inner_sig_type in [
            _SSH_RSA_SHA256,
            _SSH_RSA_SHA512,
            _SSH_RSA,
        ]
    else:
        sig_type_ok = inner_sig_type == sig_type
    if not sig_type_ok:
        raise ValueError("Signature key type does not match")
    signature, sig_rest = _get_sshstr(sig_rest)
    _check_empty(sig_rest)
    return SSHCertificate(
        nonce,
        public_key,
        serial,
        cctype,
        key_id,
        valid_principals,
        valid_after,
        valid_before,
        critical_options,
        extensions,
        sig_type,
        sig_key,
        inner_sig_type,
        signature,
        tbs_cert_body,
        orig_key_type,
        cert_body,
    )

984 

985 

def load_ssh_public_identity(
    data: bytes,
) -> typing.Union[SSHCertificate, SSHPublicKeyTypes]:
    """Load an SSH public key or certificate from its one-line form.

    Legacy DSA support is left disabled.
    """
    identity = _load_ssh_public_identity(data)
    return identity

990 

991 

def _parse_exts_opts(exts_opts: memoryview) -> typing.Dict[bytes, bytes]:
    """Parse a sequence of (name, value) strings into a dict.

    Raises ValueError on a duplicate name or when names are not in
    ascending byte order.
    """
    parsed: typing.Dict[bytes, bytes] = {}
    previous = None
    remaining = exts_opts
    while remaining:
        raw_name, remaining = _get_sshstr(remaining)
        name: bytes = bytes(raw_name)
        if name in parsed:
            raise ValueError("Duplicate name")
        if previous is not None and name < previous:
            raise ValueError("Fields not lexically sorted")
        raw_value, remaining = _get_sshstr(remaining)
        parsed[name] = bytes(raw_value)
        previous = name
    return parsed

1006 

1007 

def load_ssh_public_key(
    data: bytes, backend: typing.Any = None
) -> SSHPublicKeyTypes:
    """Load the public key from OpenSSH public key or certificate data.

    When ``data`` holds a certificate, the key embedded in that certificate
    is returned. Loading a DSA key still works but emits a
    ``utils.DeprecatedIn40`` warning.

    :param data: the serialized public key or certificate.
    :param backend: not used by this function.
    :returns: the parsed public key.
    """
    identity = _load_ssh_public_identity(data, _legacy_dsa_allowed=True)
    key: SSHPublicKeyTypes
    if not isinstance(identity, SSHCertificate):
        key = identity
    else:
        key = identity.public_key()

    if not isinstance(key, dsa.DSAPublicKey):
        return key

    warnings.warn(
        "SSH DSA keys are deprecated and will be removed in a future "
        "release.",
        utils.DeprecatedIn40,
        stacklevel=2,
    )
    return key

1026 

1027 

def serialize_ssh_public_key(public_key: SSHPublicKeyTypes) -> bytes:
    """Serialize ``public_key`` in the one-line OpenSSH public key format.

    The result is the SSH key type name, a single space, and the base64
    encoding of the key blob (the key type followed by the key fields).
    DSA keys are still serialized but trigger a ``utils.DeprecatedIn40``
    warning.
    """
    is_dsa = isinstance(public_key, dsa.DSAPublicKey)
    if is_dsa:
        warnings.warn(
            "SSH DSA key support is deprecated and will be "
            "removed in a future release",
            utils.DeprecatedIn40,
            stacklevel=4,
        )

    ssh_type = _get_ssh_key_type(public_key)
    encoder = _lookup_kformat(ssh_type)

    blob = _FragList()
    blob.put_sshstr(ssh_type)
    encoder.encode_public(public_key, blob)

    # b2a_base64 appends a newline; strip it to keep the output on one line.
    encoded = binascii.b2a_base64(blob.tobytes()).strip()
    return b" ".join((ssh_type, encoded))

1046 

1047 

# Private key types accepted as the signing key by
# SSHCertificateBuilder.sign.
SSHCertPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]


# This is an undocumented limit enforced in the openssh codebase for sshd and
# ssh-keygen, but it is undefined in the ssh certificates spec.
# SSHCertificateBuilder.valid_principals rejects lists longer than this.
_SSHKEY_CERT_MAX_PRINCIPALS = 256

1058 

1059 

class SSHCertificateBuilder:
    """Builder for OpenSSH certificates.

    Each setter validates its argument and returns a *new* builder carrying
    the updated field instead of modifying the receiver. Single-valued
    fields may only be set once; critical options and extensions are
    accumulated one entry at a time. ``sign`` serializes the collected
    fields, signs them with the given private key and returns the parsed
    certificate.
    """

    def __init__(
        self,
        _public_key: typing.Optional[SSHCertPublicKeyTypes] = None,
        _serial: typing.Optional[int] = None,
        _type: typing.Optional[SSHCertificateType] = None,
        _key_id: typing.Optional[bytes] = None,
        _valid_principals: typing.Optional[typing.List[bytes]] = None,
        _valid_for_all_principals: bool = False,
        _valid_before: typing.Optional[int] = None,
        _valid_after: typing.Optional[int] = None,
        _critical_options: typing.Optional[
            typing.List[typing.Tuple[bytes, bytes]]
        ] = None,
        _extensions: typing.Optional[
            typing.List[typing.Tuple[bytes, bytes]]
        ] = None,
    ):
        self._public_key = _public_key
        self._serial = _serial
        self._type = _type
        self._key_id = _key_id
        # The list arguments default to None rather than to a shared mutable
        # list, and are copied so that no builder aliases a list owned by
        # the caller or by another builder.
        self._valid_principals: typing.List[bytes] = list(
            _valid_principals or []
        )
        self._valid_for_all_principals = _valid_for_all_principals
        self._valid_before = _valid_before
        self._valid_after = _valid_after
        self._critical_options: typing.List[
            typing.Tuple[bytes, bytes]
        ] = list(_critical_options or [])
        self._extensions: typing.List[typing.Tuple[bytes, bytes]] = list(
            _extensions or []
        )

    def _replace(self, **changes: typing.Any) -> "SSHCertificateBuilder":
        # Return a new builder identical to this one except for ``changes``,
        # which maps constructor keyword names to their new values.
        fields: typing.Dict[str, typing.Any] = {
            "_public_key": self._public_key,
            "_serial": self._serial,
            "_type": self._type,
            "_key_id": self._key_id,
            "_valid_principals": self._valid_principals,
            "_valid_for_all_principals": self._valid_for_all_principals,
            "_valid_before": self._valid_before,
            "_valid_after": self._valid_after,
            "_critical_options": self._critical_options,
            "_extensions": self._extensions,
        }
        fields.update(changes)
        return SSHCertificateBuilder(**fields)

    def public_key(
        self, public_key: SSHCertPublicKeyTypes
    ) -> "SSHCertificateBuilder":
        """Set the public key to be certified.

        :raises TypeError: if the key is not an EC, RSA or Ed25519 public
            key.
        :raises ValueError: if a public key was already set.
        """
        if not isinstance(
            public_key,
            (
                ec.EllipticCurvePublicKey,
                rsa.RSAPublicKey,
                ed25519.Ed25519PublicKey,
            ),
        ):
            raise TypeError("Unsupported key type")
        if self._public_key is not None:
            raise ValueError("public_key already set")

        return self._replace(_public_key=public_key)

    def serial(self, serial: int) -> "SSHCertificateBuilder":
        """Set the certificate serial number.

        :raises TypeError: if ``serial`` is not an integer.
        :raises ValueError: if ``serial`` is outside ``[0, 2**64)`` or a
            serial was already set.
        """
        if not isinstance(serial, int):
            raise TypeError("serial must be an integer")
        if not 0 <= serial < 2**64:
            raise ValueError("serial must be between 0 and 2**64")
        if self._serial is not None:
            raise ValueError("serial already set")

        return self._replace(_serial=serial)

    def type(self, type: SSHCertificateType) -> "SSHCertificateBuilder":
        """Set the certificate type.

        :raises TypeError: if ``type`` is not an ``SSHCertificateType``.
        :raises ValueError: if a type was already set.
        """
        if not isinstance(type, SSHCertificateType):
            raise TypeError("type must be an SSHCertificateType")
        if self._type is not None:
            raise ValueError("type already set")

        return self._replace(_type=type)

    def key_id(self, key_id: bytes) -> "SSHCertificateBuilder":
        """Set the key identifier.

        :raises TypeError: if ``key_id`` is not ``bytes``.
        :raises ValueError: if a key id was already set.
        """
        if not isinstance(key_id, bytes):
            raise TypeError("key_id must be bytes")
        if self._key_id is not None:
            raise ValueError("key_id already set")

        return self._replace(_key_id=key_id)

    def valid_principals(
        self, valid_principals: typing.List[bytes]
    ) -> "SSHCertificateBuilder":
        """Set the principals the certificate is valid for.

        The list is copied when the new builder is constructed, so later
        changes to the caller's list do not affect the builder.

        :raises TypeError: if the list is empty or contains a non-``bytes``
            item.
        :raises ValueError: if the builder is already valid for all
            principals, if principals were already set, or if more than
            ``_SSHKEY_CERT_MAX_PRINCIPALS`` entries are given.
        """
        if self._valid_for_all_principals:
            raise ValueError(
                "Principals can't be set because the cert is valid "
                "for all principals"
            )
        if (
            not all(isinstance(x, bytes) for x in valid_principals)
            or not valid_principals
        ):
            raise TypeError(
                "principals must be a list of bytes and can't be empty"
            )
        if self._valid_principals:
            raise ValueError("valid_principals already set")

        if len(valid_principals) > _SSHKEY_CERT_MAX_PRINCIPALS:
            raise ValueError(
                "Reached or exceeded the maximum number of valid_principals"
            )

        return self._replace(_valid_principals=valid_principals)

    def valid_for_all_principals(self) -> "SSHCertificateBuilder":
        """Mark the certificate as valid for every principal.

        :raises ValueError: if principals were already set or this flag was
            already set.
        """
        if self._valid_principals:
            raise ValueError(
                "valid_principals already set, can't set "
                "valid_for_all_principals"
            )
        if self._valid_for_all_principals:
            raise ValueError("valid_for_all_principals already set")

        return self._replace(_valid_for_all_principals=True)

    def valid_before(
        self, valid_before: typing.Union[int, float]
    ) -> "SSHCertificateBuilder":
        """Set the end of the validity period.

        The value is truncated to an integer with ``int()`` before it is
        range checked and stored.

        :raises TypeError: if ``valid_before`` is not an int or float.
        :raises ValueError: if the value is outside ``[0, 2**64)`` or
            valid_before was already set.
        """
        if not isinstance(valid_before, (int, float)):
            raise TypeError("valid_before must be an int or float")
        valid_before = int(valid_before)
        if valid_before < 0 or valid_before >= 2**64:
            raise ValueError("valid_before must [0, 2**64)")
        if self._valid_before is not None:
            raise ValueError("valid_before already set")

        return self._replace(_valid_before=valid_before)

    def valid_after(
        self, valid_after: typing.Union[int, float]
    ) -> "SSHCertificateBuilder":
        """Set the start of the validity period.

        The value is truncated to an integer with ``int()`` before it is
        range checked and stored.

        :raises TypeError: if ``valid_after`` is not an int or float.
        :raises ValueError: if the value is outside ``[0, 2**64)`` or
            valid_after was already set.
        """
        if not isinstance(valid_after, (int, float)):
            raise TypeError("valid_after must be an int or float")
        valid_after = int(valid_after)
        if valid_after < 0 or valid_after >= 2**64:
            raise ValueError("valid_after must [0, 2**64)")
        if self._valid_after is not None:
            raise ValueError("valid_after already set")

        return self._replace(_valid_after=valid_after)

    def add_critical_option(
        self, name: bytes, value: bytes
    ) -> "SSHCertificateBuilder":
        """Append one critical option.

        :raises TypeError: if ``name`` or ``value`` is not ``bytes``.
        :raises ValueError: if an option with this name already exists.
        """
        if not isinstance(name, bytes) or not isinstance(value, bytes):
            raise TypeError("name and value must be bytes")
        # This is O(n**2) over a sequence of add_critical_option calls.
        if any(name == existing for existing, _ in self._critical_options):
            raise ValueError("Duplicate critical option name")

        return self._replace(
            _critical_options=[*self._critical_options, (name, value)]
        )

    def add_extension(
        self, name: bytes, value: bytes
    ) -> "SSHCertificateBuilder":
        """Append one extension.

        :raises TypeError: if ``name`` or ``value`` is not ``bytes``.
        :raises ValueError: if an extension with this name already exists.
        """
        if not isinstance(name, bytes) or not isinstance(value, bytes):
            raise TypeError("name and value must be bytes")
        # This is O(n**2) over a sequence of add_extension calls.
        if any(name == existing for existing, _ in self._extensions):
            raise ValueError("Duplicate extension name")

        return self._replace(_extensions=[*self._extensions, (name, value)])

    def sign(self, private_key: SSHCertPrivateKeyTypes) -> SSHCertificate:
        """Sign the configured certificate with ``private_key``.

        The serial defaults to 0 and the key id to ``b""`` when unset. A
        fresh 32-byte nonce from ``os.urandom`` is used for every call. The
        builder itself is left unmodified.

        :param private_key: the signing key (EC, RSA or Ed25519).
        :returns: the signed certificate, obtained by parsing the
            serialized result with ``load_ssh_public_identity``.
        :raises TypeError: if ``private_key`` is of an unsupported type.
        :raises ValueError: if the public key, type, valid_before or
            valid_after is unset, if no principals were given and
            valid_for_all_principals was not requested, or if valid_after
            is greater than valid_before.
        """
        if not isinstance(
            private_key,
            (
                ec.EllipticCurvePrivateKey,
                rsa.RSAPrivateKey,
                ed25519.Ed25519PrivateKey,
            ),
        ):
            raise TypeError("Unsupported private key type")

        if self._public_key is None:
            raise ValueError("public_key must be set")

        # Not required
        serial = 0 if self._serial is None else self._serial

        if self._type is None:
            raise ValueError("type must be set")

        # Not required
        key_id = b"" if self._key_id is None else self._key_id

        # A zero length list is valid, but means the certificate
        # is valid for any principal of the specified type. We require
        # the user to explicitly set valid_for_all_principals to get
        # that behavior.
        if not self._valid_principals and not self._valid_for_all_principals:
            raise ValueError(
                "valid_principals must be set if valid_for_all_principals "
                "is False"
            )

        if self._valid_before is None:
            raise ValueError("valid_before must be set")

        if self._valid_after is None:
            raise ValueError("valid_after must be set")

        if self._valid_after > self._valid_before:
            raise ValueError("valid_after must be earlier than valid_before")

        # Lexically sort our byte strings. Sort into local copies rather
        # than in place: the builder must not be modified by signing, and an
        # in-place list.sort() can make the list look empty to concurrent
        # readers while the sort is running.
        critical_options = sorted(self._critical_options, key=lambda x: x[0])
        extensions = sorted(self._extensions, key=lambda x: x[0])

        key_type = _get_ssh_key_type(self._public_key)
        cert_prefix = key_type + _CERT_SUFFIX

        # Marshal the bytes to be signed
        nonce = os.urandom(32)
        kformat = _lookup_kformat(key_type)
        f = _FragList()
        f.put_sshstr(cert_prefix)
        f.put_sshstr(nonce)
        kformat.encode_public(self._public_key, f)
        f.put_u64(serial)
        f.put_u32(self._type.value)
        f.put_sshstr(key_id)
        fprincipals = _FragList()
        for p in self._valid_principals:
            fprincipals.put_sshstr(p)
        f.put_sshstr(fprincipals.tobytes())
        f.put_u64(self._valid_after)
        f.put_u64(self._valid_before)
        fcrit = _FragList()
        for name, value in critical_options:
            fcrit.put_sshstr(name)
            fcrit.put_sshstr(value)
        f.put_sshstr(fcrit.tobytes())
        fext = _FragList()
        for name, value in extensions:
            fext.put_sshstr(name)
            fext.put_sshstr(value)
        f.put_sshstr(fext.tobytes())
        f.put_sshstr(b"")  # RESERVED FIELD
        # encode CA public key
        ca_type = _get_ssh_key_type(private_key)
        caformat = _lookup_kformat(ca_type)
        caf = _FragList()
        caf.put_sshstr(ca_type)
        caformat.encode_public(private_key.public_key(), caf)
        f.put_sshstr(caf.tobytes())
        # Sigs according to the rules defined for the CA's public key
        # (RFC4253 section 6.6 for ssh-rsa, RFC5656 for ECDSA,
        # and RFC8032 for Ed25519).
        if isinstance(private_key, ed25519.Ed25519PrivateKey):
            signature = private_key.sign(f.tobytes())
            fsig = _FragList()
            fsig.put_sshstr(ca_type)
            fsig.put_sshstr(signature)
            f.put_sshstr(fsig.tobytes())
        elif isinstance(private_key, ec.EllipticCurvePrivateKey):
            hash_alg = _get_ec_hash_alg(private_key.curve)
            signature = private_key.sign(f.tobytes(), ec.ECDSA(hash_alg))
            r, s = asym_utils.decode_dss_signature(signature)
            fsig = _FragList()
            fsig.put_sshstr(ca_type)
            # The ECDSA signature blob is the two integers r and s, each
            # encoded as an mpint.
            fsigblob = _FragList()
            fsigblob.put_mpint(r)
            fsigblob.put_mpint(s)
            fsig.put_sshstr(fsigblob.tobytes())
            f.put_sshstr(fsig.tobytes())

        else:
            assert isinstance(private_key, rsa.RSAPrivateKey)
            # Just like Golang, we're going to use SHA512 for RSA
            # https://cs.opensource.google/go/x/crypto/+/refs/tags/
            # v0.4.0:ssh/certs.go;l=445
            # RFC 8332 defines SHA256 and 512 as options
            fsig = _FragList()
            fsig.put_sshstr(_SSH_RSA_SHA512)
            signature = private_key.sign(
                f.tobytes(), padding.PKCS1v15(), hashes.SHA512()
            )
            fsig.put_sshstr(signature)
            f.put_sshstr(fsig.tobytes())

        cert_data = binascii.b2a_base64(f.tobytes()).strip()
        # load_ssh_public_identity returns a union, but this is
        # guaranteed to be an SSHCertificate, so we cast to make
        # mypy happy.
        return typing.cast(
            SSHCertificate,
            load_ssh_public_identity(b"".join([cert_prefix, b" ", cert_data])),
        )