Coverage for /pythoncovmergedfiles/medio/medio/src/paramiko/paramiko/ecdsakey.py: 30%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

179 statements  

1# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com> 

2# 

3# This file is part of paramiko. 

4# 

5# Paramiko is free software; you can redistribute it and/or modify it under the 

6# terms of the GNU Lesser General Public License as published by the Free 

7# Software Foundation; either version 2.1 of the License, or (at your option) 

8# any later version. 

9# 

10# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY 

11# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 

12# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more 

13# details. 

14# 

15# You should have received a copy of the GNU Lesser General Public License 

16# along with Paramiko; if not, write to the Free Software Foundation, Inc., 

17# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 

18 

19""" 

20ECDSA keys 

21""" 

22 

23from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm 

24from cryptography.hazmat.backends import default_backend 

25from cryptography.hazmat.primitives import hashes, serialization 

26from cryptography.hazmat.primitives.asymmetric import ec 

27from cryptography.hazmat.primitives.asymmetric.utils import ( 

28 decode_dss_signature, 

29 encode_dss_signature, 

30) 

31 

32from paramiko.common import four_byte 

33from paramiko.message import Message 

34from paramiko.pkey import PKey 

35from paramiko.ssh_exception import SSHException 

36from paramiko.util import deflate_long 

37 

38 

class _ECDSACurve:
    """
    Describes one supported ECDSA curve (nistp256, nistp384, etc).

    Builds the SSH key format identifier from the NIST name, selects the hash
    class that matches the curve's key size, and keeps a reference to the
    curve class it was constructed from.
    """

    def __init__(self, curve_class, nist_name):
        self.curve_class = curve_class
        self.nist_name = nist_name
        self.key_length = curve_class.key_size

        # Defined in RFC 5656 6.2
        self.key_format_identifier = "ecdsa-sha2-" + nist_name

        # Defined in RFC 5656 6.2.1: the hash is chosen from the key size,
        # checked here from the largest bracket down.
        bits = self.key_length
        if bits > 384:
            digest = hashes.SHA512
        elif bits > 256:
            digest = hashes.SHA384
        else:
            digest = hashes.SHA256
        # Stored as a class, not an instance; callers instantiate it.
        self.hash_object = digest

64 

65 

66class _ECDSACurveSet: 

67 """ 

68 A collection to hold the ECDSA curves. Allows querying by oid and by key 

69 format identifier. The two ways in which ECDSAKey needs to be able to look 

70 up curves. 

71 """ 

72 

73 def __init__(self, ecdsa_curves): 

74 self.ecdsa_curves = ecdsa_curves 

75 

76 def get_key_format_identifier_list(self): 

77 return [curve.key_format_identifier for curve in self.ecdsa_curves] 

78 

79 def get_by_curve_class(self, curve_class): 

80 for curve in self.ecdsa_curves: 

81 if curve.curve_class == curve_class: 

82 return curve 

83 

84 def get_by_key_format_identifier(self, key_format_identifier): 

85 for curve in self.ecdsa_curves: 

86 if curve.key_format_identifier == key_format_identifier: 

87 return curve 

88 

89 def get_by_key_length(self, key_length): 

90 for curve in self.ecdsa_curves: 

91 if curve.key_length == key_length: 

92 return curve 

93 

94 

class ECDSAKey(PKey):
    """
    Representation of an ECDSA key which can be used to sign and verify SSH2
    data.
    """

    # The curves this class can load, generate, and serialize.
    _ECDSA_CURVES = _ECDSACurveSet(
        [
            _ECDSACurve(ec.SECP256R1, "nistp256"),
            _ECDSACurve(ec.SECP384R1, "nistp384"),
            _ECDSACurve(ec.SECP521R1, "nistp521"),
        ]
    )

    def __init__(
        self,
        msg=None,
        data=None,
        filename=None,
        password=None,
        vals=None,
        file_obj=None,
        # TODO 4.0: remove; it does nothing since porting to cryptography.io
        validate_point=True,
    ):
        """
        Create a key from one of several sources, tried in this order:
        ``file_obj``, ``filename``, ``vals``, then ``msg`` (or ``data``).

        :param msg: a `.Message` holding a serialized public key.
        :param data: raw bytes wrapped in a `.Message` when ``msg`` is None.
        :param filename: path of a private key file to load.
        :param password: password handed to the private key loaders.
        :param vals: a ``(signing_key, verifying_key)`` pair used as-is.
        :param file_obj: file-like object holding a private key to load.
        :param validate_point: unused.
        :raises SSHException:
            if no key source was supplied, the curve name inside the message
            does not match the key type, or the public point is invalid.
        """
        self.verifying_key = None
        self.signing_key = None
        self.public_blob = None
        if file_obj is not None:
            self._from_private_key(file_obj, password)
            return
        if filename is not None:
            self._from_private_key_file(filename, password)
            return
        if (msg is None) and (data is not None):
            msg = Message(data)
        if vals is not None:
            self.signing_key, self.verifying_key = vals
            c_class = self.signing_key.curve.__class__
            # NOTE(review): this is None for a curve outside _ECDSA_CURVES.
            self.ecdsa_curve = self._ECDSA_CURVES.get_by_curve_class(c_class)
        else:
            # Without any key source there is nothing to parse; fail with the
            # library's exception type instead of an AttributeError on None.
            if msg is None:
                raise SSHException("Key object may not be empty")
            # Must set ecdsa_curve first; subroutines called herein may need to
            # spit out our get_name(), which relies on this.
            key_type = msg.get_text()
            # But this also means we need to hand it a real key/curve
            # identifier, so strip out any cert business. (NOTE: could push
            # that into _ECDSACurveSet.get_by_key_format_identifier(), but it
            # feels more correct to do it here?)
            suffix = "-cert-v01@openssh.com"
            if key_type.endswith(suffix):
                key_type = key_type[: -len(suffix)]
            self.ecdsa_curve = self._ECDSA_CURVES.get_by_key_format_identifier(
                key_type
            )
            key_types = self._ECDSA_CURVES.get_key_format_identifier_list()
            cert_types = [
                "{}-cert-v01@openssh.com".format(x) for x in key_types
            ]
            # Helper defined outside this file; presumably rejects unknown
            # key types and loads any certificate -- verify against PKey.
            self._check_type_and_load_cert(
                msg=msg, key_type=key_types, cert_type=cert_types
            )
            curvename = msg.get_text()
            if curvename != self.ecdsa_curve.nist_name:
                raise SSHException(
                    "Can't handle curve of type {}".format(curvename)
                )

            pointinfo = msg.get_binary()
            try:
                key = ec.EllipticCurvePublicKey.from_encoded_point(
                    self.ecdsa_curve.curve_class(), pointinfo
                )
            except ValueError as e:
                raise SSHException("Invalid public key") from e
            self.verifying_key = key

    @classmethod
    def identifiers(cls):
        """Return the key format identifiers of all supported curves."""
        return cls._ECDSA_CURVES.get_key_format_identifier_list()

    # TODO 4.0: deprecate/remove
    @classmethod
    def supported_key_format_identifiers(cls):
        """Alias of `identifiers`."""
        return cls.identifiers()

    def asbytes(self):
        """
        Serialize the public key as bytes: key format identifier, curve name,
        then the encoded public point.
        """
        key = self.verifying_key
        m = Message()
        m.add_string(self.ecdsa_curve.key_format_identifier)
        m.add_string(self.ecdsa_curve.nist_name)

        numbers = key.public_numbers()

        # Round the curve size in bits up to whole bytes.
        key_size_bytes = (key.curve.key_size + 7) // 8

        # Left-pad each coordinate with zero bytes to the full curve width.
        x_bytes = deflate_long(numbers.x, add_sign_padding=False)
        x_bytes = b"\x00" * (key_size_bytes - len(x_bytes)) + x_bytes

        y_bytes = deflate_long(numbers.y, add_sign_padding=False)
        y_bytes = b"\x00" * (key_size_bytes - len(y_bytes)) + y_bytes

        # four_byte comes from paramiko.common; presumably the 0x04 marker
        # for an uncompressed point -- confirm against that module.
        point_str = four_byte + x_bytes + y_bytes
        m.add_string(point_str)
        return m.asbytes()

    def __str__(self):
        # NOTE(review): asbytes() returns the result of Message.asbytes(); if
        # that is bytes, str(key) raises TypeError on Python 3. Left unchanged
        # because no text representation is defined for this key anywhere in
        # this file.
        return self.asbytes()

    @property
    def _fields(self):
        """Tuple of (key name, public x, public y)."""
        numbers = self.verifying_key.public_numbers()
        return (self.get_name(), numbers.x, numbers.y)

    def get_name(self):
        """Return the key format identifier of this key's curve."""
        return self.ecdsa_curve.key_format_identifier

    def get_bits(self):
        """Return the key length of this key's curve."""
        return self.ecdsa_curve.key_length

    def can_sign(self):
        """Return True when a private (signing) key is present."""
        return self.signing_key is not None

    def sign_ssh_data(self, data, algorithm=None):
        """
        Sign ``data`` with the private key.

        :param data: the data to sign.
        :param algorithm: unused by this key type.
        :returns:
            a `.Message` holding the key format identifier followed by the
            encoded ``(r, s)`` signature pair.
        """
        ecdsa = ec.ECDSA(self.ecdsa_curve.hash_object())
        sig = self.signing_key.sign(data, ecdsa)
        r, s = decode_dss_signature(sig)

        m = Message()
        m.add_string(self.ecdsa_curve.key_format_identifier)
        m.add_string(self._sigencode(r, s))
        return m

    def verify_ssh_sig(self, data, msg):
        """
        Check the signature in ``msg`` against ``data``.

        :returns:
            False when the signature's type string differs from this key's
            format identifier or the signature does not verify; True
            otherwise.
        """
        if msg.get_text() != self.ecdsa_curve.key_format_identifier:
            return False
        sig = msg.get_binary()
        sigR, sigS = self._sigdecode(sig)
        signature = encode_dss_signature(sigR, sigS)

        try:
            self.verifying_key.verify(
                signature, data, ec.ECDSA(self.ecdsa_curve.hash_object())
            )
        except InvalidSignature:
            return False
        else:
            return True

    def write_private_key_file(self, filename, password=None):
        """Write the private key to ``filename`` in TraditionalOpenSSL format."""
        self._write_private_key_file(
            filename,
            self.signing_key,
            serialization.PrivateFormat.TraditionalOpenSSL,
            password=password,
        )

    def write_private_key(self, file_obj, password=None):
        """Write the private key to ``file_obj`` in TraditionalOpenSSL format."""
        self._write_private_key(
            file_obj,
            self.signing_key,
            serialization.PrivateFormat.TraditionalOpenSSL,
            password=password,
        )

    @classmethod
    def generate(cls, curve=ec.SECP256R1(), progress_func=None, bits=None):
        """
        Generate a new private ECDSA key. This factory function can be used to
        generate a new host key or authentication key.

        :param curve:
            Curve instance to generate the key on. Ignored when ``bits`` is
            given.
        :param progress_func: Not used for this type of key.
        :param bits:
            Optional key length. When given, the supported curve with that
            key length is used instead of ``curve``.
        :returns: A new private key (`.ECDSAKey`) object
        :raises ValueError: if ``bits`` matches no supported curve.
        """
        if bits is not None:
            spec = cls._ECDSA_CURVES.get_by_key_length(bits)
            if spec is None:
                raise ValueError("Unsupported key length: {:d}".format(bits))
            curve = spec.curve_class()

        private_key = ec.generate_private_key(curve, backend=default_backend())
        return ECDSAKey(vals=(private_key, private_key.public_key()))

    # ...internals...

    def _from_private_key_file(self, filename, password):
        """Load this key from the private key file at ``filename``."""
        data = self._read_private_key_file("EC", filename, password)
        self._decode_key(data)

    def _from_private_key(self, file_obj, password):
        """Load this key from the private key in ``file_obj``."""
        data = self._read_private_key("EC", file_obj, password)
        self._decode_key(data)

    def _decode_key(self, data):
        """
        Populate the signing key, verifying key, and curve from ``data``.

        :param data: a ``(format id, key bytes)`` pair.
        :raises SSHException:
            if the key bytes cannot be parsed or the key's curve is not one
            of the supported curves.
        """
        pkformat, data = data
        if pkformat == self._PRIVATE_KEY_FORMAT_ORIGINAL:
            try:
                key = serialization.load_der_private_key(
                    data, password=None, backend=default_backend()
                )
            except (
                ValueError,
                AssertionError,
                TypeError,
                UnsupportedAlgorithm,
            ) as e:
                raise SSHException(str(e)) from e
        elif pkformat == self._PRIVATE_KEY_FORMAT_OPENSSH:
            try:
                msg = Message(data)
                curve_name = msg.get_text()
                verkey = msg.get_binary()  # noqa: F841
                sigkey = msg.get_mpint()
                name = "ecdsa-sha2-" + curve_name
                curve = self._ECDSA_CURVES.get_by_key_format_identifier(name)
                if not curve:
                    raise SSHException("Invalid key curve identifier")
                key = ec.derive_private_key(
                    sigkey, curve.curve_class(), default_backend()
                )
            except Exception as e:
                # PKey._read_private_key_openssh() should check or return
                # keytype - parsing could fail for any reason due to wrong type
                raise SSHException(str(e)) from e
        else:
            # Defined outside this file; presumably always raises. If it ever
            # returned, ``key`` below would be unbound -- verify against PKey.
            self._got_bad_key_format_id(pkformat)

        curve_class = key.curve.__class__
        ecdsa_curve = self._ECDSA_CURVES.get_by_curve_class(curve_class)
        # The OpenSSH branch already rejects unknown curves; apply the same
        # rule to DER keys so a key is never left with ecdsa_curve = None.
        if ecdsa_curve is None:
            raise SSHException("Invalid key curve identifier")

        self.signing_key = key
        self.verifying_key = key.public_key()
        self.ecdsa_curve = ecdsa_curve

    def _sigencode(self, r, s):
        """Pack the integers ``r`` and ``s`` as two consecutive mpints."""
        msg = Message()
        msg.add_mpint(r)
        msg.add_mpint(s)
        return msg.asbytes()

    def _sigdecode(self, sig):
        """Read two consecutive mpints from ``sig`` and return them as (r, s)."""
        msg = Message(sig)
        r = msg.get_mpint()
        s = msg.get_mpint()
        return r, s