Coverage for /pythoncovmergedfiles/medio/medio/src/paramiko/paramiko/ecdsakey.py: 30%

175 statements  

« prev     ^ index     » next       coverage.py v7.2.2, created at 2023-03-26 06:36 +0000

1# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com> 

2# 

3# This file is part of paramiko. 

4# 

5# Paramiko is free software; you can redistribute it and/or modify it under the 

6# terms of the GNU Lesser General Public License as published by the Free 

7# Software Foundation; either version 2.1 of the License, or (at your option) 

8# any later version. 

9# 

10# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY 

11# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 

12# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more 

13# details. 

14# 

15# You should have received a copy of the GNU Lesser General Public License 

16# along with Paramiko; if not, write to the Free Software Foundation, Inc., 

17# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 

18 

19""" 

20ECDSA keys 

21""" 

22 

23from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm 

24from cryptography.hazmat.backends import default_backend 

25from cryptography.hazmat.primitives import hashes, serialization 

26from cryptography.hazmat.primitives.asymmetric import ec 

27from cryptography.hazmat.primitives.asymmetric.utils import ( 

28 decode_dss_signature, 

29 encode_dss_signature, 

30) 

31 

32from paramiko.common import four_byte 

33from paramiko.message import Message 

34from paramiko.pkey import PKey 

35from paramiko.ssh_exception import SSHException 

36from paramiko.util import deflate_long 

37 

38 

class _ECDSACurve:
    """
    Describes one specific ECDSA curve (nistp256, nistp384, etc).

    Keeps the curve class together with the values derived from it: the key
    length, the key format identifier, and the hash class paired with the
    curve.
    """

    def __init__(self, curve_class, nist_name):
        self.nist_name = nist_name
        self.key_length = bits = curve_class.key_size
        self.curve_class = curve_class

        # Defined in RFC 5656 6.2
        self.key_format_identifier = "ecdsa-sha2-" + nist_name

        # Defined in RFC 5656 6.2.1
        # The hash class itself is stored (not an instance); callers
        # instantiate it when they need a digest.
        self.hash_object = (
            hashes.SHA256
            if bits <= 256
            else hashes.SHA384
            if bits <= 384
            else hashes.SHA512
        )

64 

65 

66class _ECDSACurveSet: 

67 """ 

68 A collection to hold the ECDSA curves. Allows querying by oid and by key 

69 format identifier. The two ways in which ECDSAKey needs to be able to look 

70 up curves. 

71 """ 

72 

73 def __init__(self, ecdsa_curves): 

74 self.ecdsa_curves = ecdsa_curves 

75 

76 def get_key_format_identifier_list(self): 

77 return [curve.key_format_identifier for curve in self.ecdsa_curves] 

78 

79 def get_by_curve_class(self, curve_class): 

80 for curve in self.ecdsa_curves: 

81 if curve.curve_class == curve_class: 

82 return curve 

83 

84 def get_by_key_format_identifier(self, key_format_identifier): 

85 for curve in self.ecdsa_curves: 

86 if curve.key_format_identifier == key_format_identifier: 

87 return curve 

88 

89 def get_by_key_length(self, key_length): 

90 for curve in self.ecdsa_curves: 

91 if curve.key_length == key_length: 

92 return curve 

93 

94 

class ECDSAKey(PKey):
    """
    Representation of an ECDSA key which can be used to sign and verify SSH2
    data.
    """

    # The curves this class knows how to handle.
    _ECDSA_CURVES = _ECDSACurveSet(
        [
            _ECDSACurve(ec.SECP256R1, "nistp256"),
            _ECDSACurve(ec.SECP384R1, "nistp384"),
            _ECDSACurve(ec.SECP521R1, "nistp521"),
        ]
    )

    def __init__(
        self,
        msg=None,
        data=None,
        filename=None,
        password=None,
        vals=None,
        file_obj=None,
        validate_point=True,
    ):
        """
        Create a key from one of several sources, tried in this order:
        ``file_obj``, ``filename``, ``vals``, then ``msg`` (or ``data``).

        :param msg: a `.Message` to read the public key from
        :param data: raw data wrapped in a `.Message` when ``msg`` is None
        :param filename: name of a private key file to load
        :param password: password handed to the private key loaders
        :param vals: a ``(signing_key, verifying_key)`` pair
        :param file_obj: file-like object to load a private key from
        :param validate_point: accepted but not referenced in this constructor
        :raises: `.SSHException` -- if no key source was supplied, if the
            curve name does not match the key type, or if the public point
            cannot be decoded
        """
        self.verifying_key = None
        self.signing_key = None
        self.public_blob = None
        if file_obj is not None:
            self._from_private_key(file_obj, password)
            return
        if filename is not None:
            self._from_private_key_file(filename, password)
            return
        if (msg is None) and (data is not None):
            msg = Message(data)
        if vals is not None:
            self.signing_key, self.verifying_key = vals
            c_class = self.signing_key.curve.__class__
            self.ecdsa_curve = self._ECDSA_CURVES.get_by_curve_class(c_class)
        else:
            # Without a message there is nothing to parse; fail with the
            # library's own exception type rather than an AttributeError on
            # None below.
            if msg is None:
                raise SSHException("Key object may not be empty")
            # Must set ecdsa_curve first; subroutines called herein may need to
            # spit out our get_name(), which relies on this.
            key_type = msg.get_text()
            # But this also means we need to hand it a real key/curve
            # identifier, so strip out any cert business. (NOTE: could push
            # that into _ECDSACurveSet.get_by_key_format_identifier(), but it
            # feels more correct to do it here?)
            suffix = "-cert-v01@openssh.com"
            if key_type.endswith(suffix):
                key_type = key_type[: -len(suffix)]
            self.ecdsa_curve = self._ECDSA_CURVES.get_by_key_format_identifier(
                key_type
            )
            key_types = self._ECDSA_CURVES.get_key_format_identifier_list()
            cert_types = [
                "{}-cert-v01@openssh.com".format(x) for x in key_types
            ]
            self._check_type_and_load_cert(
                msg=msg, key_type=key_types, cert_type=cert_types
            )
            curvename = msg.get_text()
            if curvename != self.ecdsa_curve.nist_name:
                raise SSHException(
                    "Can't handle curve of type {}".format(curvename)
                )

            pointinfo = msg.get_binary()
            try:
                key = ec.EllipticCurvePublicKey.from_encoded_point(
                    self.ecdsa_curve.curve_class(), pointinfo
                )
                self.verifying_key = key
            except ValueError:
                raise SSHException("Invalid public key")

    @classmethod
    def supported_key_format_identifiers(cls):
        """Return the key format identifiers of all supported curves."""
        return cls._ECDSA_CURVES.get_key_format_identifier_list()

    def asbytes(self):
        """
        Serialize the public key into a `.Message` holding, in order, the key
        format identifier, the curve name and the encoded point, and return
        that message's ``asbytes()`` result.
        """
        key = self.verifying_key
        m = Message()
        m.add_string(self.ecdsa_curve.key_format_identifier)
        m.add_string(self.ecdsa_curve.nist_name)

        numbers = key.public_numbers()

        # Curve size rounded up to whole bytes; both coordinates are
        # left-padded with zero bytes to this width.
        key_size_bytes = (key.curve.key_size + 7) // 8

        x_bytes = deflate_long(numbers.x, add_sign_padding=False)
        x_bytes = b"\x00" * (key_size_bytes - len(x_bytes)) + x_bytes

        y_bytes = deflate_long(numbers.y, add_sign_padding=False)
        y_bytes = b"\x00" * (key_size_bytes - len(y_bytes)) + y_bytes

        # four_byte is presumably the uncompressed-point marker -- confirm
        # against paramiko.common.
        point_str = four_byte + x_bytes + y_bytes
        m.add_string(point_str)
        return m.asbytes()

    def __str__(self):
        # NOTE(review): this hands back asbytes() unchanged. If that value is
        # bytes, str(key) raises TypeError on Python 3 -- confirm whether this
        # override is still wanted.
        return self.asbytes()

    @property
    def _fields(self):
        """Tuple of the key name and the public point's x and y numbers."""
        return (
            self.get_name(),
            self.verifying_key.public_numbers().x,
            self.verifying_key.public_numbers().y,
        )

    def get_name(self):
        """Return the key format identifier of this key's curve."""
        return self.ecdsa_curve.key_format_identifier

    def get_bits(self):
        """Return the key length of this key's curve."""
        return self.ecdsa_curve.key_length

    def can_sign(self):
        """Return True when a signing (private) key is present."""
        return self.signing_key is not None

    def sign_ssh_data(self, data, algorithm=None):
        """
        Sign ``data`` with the signing key and the curve's hash.

        :param data: the data to sign
        :param algorithm: accepted but not referenced by this method
        :returns: a `.Message` holding the key format identifier followed by
            the encoded ``(r, s)`` signature pair
        """
        ecdsa = ec.ECDSA(self.ecdsa_curve.hash_object())
        sig = self.signing_key.sign(data, ecdsa)
        r, s = decode_dss_signature(sig)

        m = Message()
        m.add_string(self.ecdsa_curve.key_format_identifier)
        m.add_string(self._sigencode(r, s))
        return m

    def verify_ssh_sig(self, data, msg):
        """
        Check the signature carried in ``msg`` against ``data``.

        :param data: the data that was signed
        :param msg: a `.Message` holding the key type and the signature
        :returns: False when the key type in ``msg`` differs from this key's
            identifier or the signature is invalid, True otherwise
        """
        if msg.get_text() != self.ecdsa_curve.key_format_identifier:
            return False
        sig = msg.get_binary()
        sigR, sigS = self._sigdecode(sig)
        signature = encode_dss_signature(sigR, sigS)

        try:
            self.verifying_key.verify(
                signature, data, ec.ECDSA(self.ecdsa_curve.hash_object())
            )
        except InvalidSignature:
            return False
        else:
            return True

    def write_private_key_file(self, filename, password=None):
        """Write the signing key to ``filename`` in TraditionalOpenSSL format."""
        self._write_private_key_file(
            filename,
            self.signing_key,
            serialization.PrivateFormat.TraditionalOpenSSL,
            password=password,
        )

    def write_private_key(self, file_obj, password=None):
        """Write the signing key to ``file_obj`` in TraditionalOpenSSL format."""
        self._write_private_key(
            file_obj,
            self.signing_key,
            serialization.PrivateFormat.TraditionalOpenSSL,
            password=password,
        )

    @classmethod
    def generate(cls, curve=ec.SECP256R1(), progress_func=None, bits=None):
        """
        Generate a new private ECDSA key. This factory function can be used to
        generate a new host key or authentication key.

        :param curve: curve instance to generate the key on; ignored when
            ``bits`` is given
        :param progress_func: Not used for this type of key.
        :param bits: if given, selects the supported curve with this key
            length instead of ``curve``
        :returns: A new private key (`.ECDSAKey`) object
        :raises: ``ValueError`` -- if ``bits`` matches no supported curve
        """
        if bits is not None:
            curve = cls._ECDSA_CURVES.get_by_key_length(bits)
            if curve is None:
                raise ValueError("Unsupported key length: {:d}".format(bits))
            curve = curve.curve_class()

        private_key = ec.generate_private_key(curve, backend=default_backend())
        return ECDSAKey(vals=(private_key, private_key.public_key()))

    # ...internals...

    def _from_private_key_file(self, filename, password):
        """Read an "EC" private key from ``filename`` and decode it."""
        data = self._read_private_key_file("EC", filename, password)
        self._decode_key(data)

    def _from_private_key(self, file_obj, password):
        """Read an "EC" private key from ``file_obj`` and decode it."""
        data = self._read_private_key("EC", file_obj, password)
        self._decode_key(data)

    def _decode_key(self, data):
        """
        Load the signing key, verifying key and curve from a
        ``(format, payload)`` pair.

        :param data: two-item sequence of the private key format id and the
            key payload
        :raises: `.SSHException` -- if the payload cannot be parsed
        """
        pkformat, data = data
        if pkformat == self._PRIVATE_KEY_FORMAT_ORIGINAL:
            try:
                key = serialization.load_der_private_key(
                    data, password=None, backend=default_backend()
                )
            except (
                ValueError,
                AssertionError,
                TypeError,
                UnsupportedAlgorithm,
            ) as e:
                raise SSHException(str(e)) from e
        elif pkformat == self._PRIVATE_KEY_FORMAT_OPENSSH:
            try:
                msg = Message(data)
                curve_name = msg.get_text()
                # The public part is read only to advance past it.
                verkey = msg.get_binary()  # noqa: F841
                sigkey = msg.get_mpint()
                name = "ecdsa-sha2-" + curve_name
                curve = self._ECDSA_CURVES.get_by_key_format_identifier(name)
                if not curve:
                    raise SSHException("Invalid key curve identifier")
                key = ec.derive_private_key(
                    sigkey, curve.curve_class(), default_backend()
                )
            except Exception as e:
                # PKey._read_private_key_openssh() should check or return
                # keytype - parsing could fail for any reason due to wrong type
                raise SSHException(str(e)) from e
        else:
            # Presumably raises; otherwise ``key`` is unbound below -- confirm
            # against PKey.
            self._got_bad_key_format_id(pkformat)

        self.signing_key = key
        self.verifying_key = key.public_key()
        curve_class = key.curve.__class__
        self.ecdsa_curve = self._ECDSA_CURVES.get_by_curve_class(curve_class)

    def _sigencode(self, r, s):
        """Pack ``r`` and ``s`` as two mpints and return the message bytes."""
        msg = Message()
        msg.add_mpint(r)
        msg.add_mpint(s)
        return msg.asbytes()

    def _sigdecode(self, sig):
        """Read two mpints from ``sig`` and return them as ``(r, s)``."""
        msg = Message(sig)
        r = msg.get_mpint()
        s = msg.get_mpint()
        return r, s