Coverage for /pythoncovmergedfiles/medio/medio/src/paramiko/paramiko/ecdsakey.py: 31%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
2#
3# This file is part of paramiko.
4#
5# Paramiko is free software; you can redistribute it and/or modify it under the
6# terms of the GNU Lesser General Public License as published by the Free
7# Software Foundation; either version 2.1 of the License, or (at your option)
8# any later version.
9#
10# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
11# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
12# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
13# details.
14#
15# You should have received a copy of the GNU Lesser General Public License
16# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
17# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
19"""
20ECDSA keys
21"""
23from typing import Optional
25from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
26from cryptography.hazmat.backends import default_backend
27from cryptography.hazmat.primitives import hashes, serialization
28from cryptography.hazmat.primitives.asymmetric import ec
29from cryptography.hazmat.primitives.asymmetric.utils import (
30 decode_dss_signature,
31 encode_dss_signature,
32)
34from paramiko.common import four_byte
35from paramiko.message import Message
36from paramiko.pkey import PKey
37from paramiko.ssh_exception import SSHException
38from paramiko.util import deflate_long
class _ECDSACurve:
    """
    Describes one supported ECDSA curve (nistp256, nistp384, etc).

    Derives the SSH key format identifier from the NIST curve name and picks
    the digest algorithm to use for the curve's key size. The curve itself is
    kept as ``curve_class``: the curve class passed in by the caller (e.g. a
    ``cryptography`` ``ec.SECP256R1`` class), whose ``key_size`` is read here.
    """

    def __init__(self, curve_class, nist_name):
        self.nist_name = nist_name
        self.key_length = bits = curve_class.key_size

        # Defined in RFC 5656 6.2
        self.key_format_identifier = "ecdsa-sha2-" + nist_name

        # Defined in RFC 5656 6.2.1: the hash grows with the curve size.
        if bits > 384:
            digest = hashes.SHA512
        elif bits > 256:
            digest = hashes.SHA384
        else:
            digest = hashes.SHA256
        self.hash_object = digest

        self.curve_class = curve_class
class _ECDSACurveSet:
    """
    Registry of the supported `_ECDSACurve` objects.

    ECDSAKey needs to find a curve by its curve class, by its SSH key format
    identifier, or by its key length in bits. Each ``get_by_*`` method returns
    the first registered curve that matches, or ``None`` when none does.
    """

    def __init__(self, ecdsa_curves):
        self.ecdsa_curves = ecdsa_curves

    def get_key_format_identifier_list(self):
        """Return every curve's key format identifier, in registry order."""
        return [entry.key_format_identifier for entry in self.ecdsa_curves]

    def get_by_curve_class(self, curve_class):
        """Return the first curve whose ``curve_class`` equals *curve_class*."""
        matches = (
            entry
            for entry in self.ecdsa_curves
            if entry.curve_class == curve_class
        )
        return next(matches, None)

    def get_by_key_format_identifier(self, key_format_identifier):
        """Return the first curve with the given key format identifier."""
        matches = (
            entry
            for entry in self.ecdsa_curves
            if entry.key_format_identifier == key_format_identifier
        )
        return next(matches, None)

    def get_by_key_length(self, key_length):
        """Return the first curve whose key length (bits) is *key_length*."""
        matches = (
            entry
            for entry in self.ecdsa_curves
            if entry.key_length == key_length
        )
        return next(matches, None)
class ECDSAKey(PKey):
    """
    Representation of an ECDSA key which can be used to sign and verify SSH2
    data.

    Supported curves are nistp256, nistp384 and nistp521 (key types
    ``ecdsa-sha2-nistp256``, ``ecdsa-sha2-nistp384`` and
    ``ecdsa-sha2-nistp521``).
    """

    _ECDSA_CURVES = _ECDSACurveSet(
        [
            _ECDSACurve(ec.SECP256R1, "nistp256"),
            _ECDSACurve(ec.SECP384R1, "nistp384"),
            _ECDSACurve(ec.SECP521R1, "nistp521"),
        ]
    )

    def __init__(
        self,
        msg=None,
        data=None,
        filename=None,
        password=None,
        vals=None,
        file_obj=None,
        # TODO (backwards incompat): remove; it does nothing since porting to
        # cryptography.io
        validate_point=True,
    ):
        """
        Load an ECDSA key from the first source given, checked in this order:

        - ``file_obj``: a private key read via ``_from_private_key``;
        - ``filename``: a private key read from that file path;
        - ``vals``: a ``(signing_key, verifying_key)`` tuple of
          ``cryptography`` key objects. ``signing_key`` may be ``None`` for a
          public-only key;
        - ``msg`` (a `.Message`) or ``data`` (raw bytes): a public key, or an
          ``-cert-v01@openssh.com`` certificate, in SSH wire format.

        :param password: passphrase used when reading a private key.
        :param validate_point: ignored; kept only for backwards compatibility.
        :raises SSHException: if no key source is given, if the curve name in
            the message does not match its key type, or if the encoded public
            point is invalid.
        """
        self.verifying_key = None
        self.signing_key = None
        self.public_blob = None
        if file_obj is not None:
            self._from_private_key(file_obj, password)
            return
        if filename is not None:
            self._from_private_key_file(filename, password)
            return
        if (msg is None) and (data is not None):
            msg = Message(data)
        if vals is not None:
            self.signing_key, self.verifying_key = vals
            # Both halves expose ``.curve``; fall back to the public half so a
            # public-only ``(None, public_key)`` tuple works too.
            curve_source = self.signing_key
            if curve_source is None:
                curve_source = self.verifying_key
            c_class = curve_source.curve.__class__
            self.ecdsa_curve = self._ECDSA_CURVES.get_by_curve_class(c_class)
        else:
            if msg is None:
                # Previously this surfaced as AttributeError on None.get_text().
                raise SSHException("Key object may not be empty")
            # Must set ecdsa_curve first; subroutines called herein may need to
            # spit out our get_name(), which relies on this.
            key_type = msg.get_text()
            # But this also means we need to hand it a real key/curve
            # identifier, so strip out any cert business. (NOTE: could push
            # that into _ECDSACurveSet.get_by_key_format_identifier(), but it
            # feels more correct to do it here?)
            suffix = "-cert-v01@openssh.com"
            if key_type.endswith(suffix):
                key_type = key_type[: -len(suffix)]
            self.ecdsa_curve = self._ECDSA_CURVES.get_by_key_format_identifier(
                key_type
            )
            key_types = self._ECDSA_CURVES.get_key_format_identifier_list()
            cert_types = [
                "{}-cert-v01@openssh.com".format(x) for x in key_types
            ]
            self._check_type_and_load_cert(
                msg=msg, key_type=key_types, cert_type=cert_types
            )
            curvename = msg.get_text()
            if curvename != self.ecdsa_curve.nist_name:
                raise SSHException(
                    "Can't handle curve of type {}".format(curvename)
                )

            pointinfo = msg.get_binary()
            try:
                key = ec.EllipticCurvePublicKey.from_encoded_point(
                    self.ecdsa_curve.curve_class(), pointinfo
                )
                self.verifying_key = key
            except ValueError as e:
                raise SSHException("Invalid public key") from e

    @classmethod
    def identifiers(cls):
        """Return the key format identifiers this key class can handle."""
        return cls._ECDSA_CURVES.get_key_format_identifier_list()

    # TODO (backwards incompat): deprecate/remove
    @classmethod
    def supported_key_format_identifiers(cls):
        """Deprecated alias of `identifiers`."""
        return cls.identifiers()

    def asbytes(self):
        """
        Return the public key in SSH wire format.

        The result is the key format identifier, the NIST curve name, and the
        public point. The point is ``four_byte`` followed by X and Y, each
        left-padded with zero bytes to the curve's byte length.
        """
        key = self.verifying_key
        m = Message()
        m.add_string(self.ecdsa_curve.key_format_identifier)
        m.add_string(self.ecdsa_curve.nist_name)

        numbers = key.public_numbers()

        # Bytes needed to hold one coordinate, rounding bits up.
        key_size_bytes = (key.curve.key_size + 7) // 8

        x_bytes = deflate_long(numbers.x, add_sign_padding=False)
        x_bytes = b"\x00" * (key_size_bytes - len(x_bytes)) + x_bytes

        y_bytes = deflate_long(numbers.y, add_sign_padding=False)
        y_bytes = b"\x00" * (key_size_bytes - len(y_bytes)) + y_bytes

        # NOTE(review): four_byte is presumably the 0x04 "uncompressed point"
        # prefix; confirm in paramiko.common.
        point_str = four_byte + x_bytes + y_bytes
        m.add_string(point_str)
        return m.asbytes()

    def __str__(self):
        # __str__ must return str: returning asbytes() directly made str(key)
        # raise TypeError. The decode is lossy (non-UTF-8 bytes are dropped);
        # use asbytes() for the exact wire encoding.
        return self.asbytes().decode("utf8", errors="ignore")

    @property
    def _fields(self):
        # Identity of the key: its type name plus the public point coordinates.
        numbers = self.verifying_key.public_numbers()
        return (
            self.get_name(),
            numbers.x,
            numbers.y,
        )

    def get_name(self):
        """Return the SSH key type, e.g. ``ecdsa-sha2-nistp256``."""
        return self.ecdsa_curve.key_format_identifier

    def get_bits(self):
        """Return the curve's key size in bits."""
        return self.ecdsa_curve.key_length

    def can_sign(self):
        """Return True if this object holds the private (signing) key."""
        return self.signing_key is not None

    def sign_ssh_data(self, data, algorithm=None):
        """
        Sign ``data`` and return the SSH signature as a `.Message`.

        The message holds the key format identifier followed by a blob of
        mpint ``r`` and mpint ``s``.

        :param algorithm: ignored. The hash is fixed by the curve
            (``ecdsa_curve.hash_object``).
        """
        ecdsa = ec.ECDSA(self.ecdsa_curve.hash_object())
        sig = self.signing_key.sign(data, ecdsa)
        # cryptography returns a DER signature; SSH wants raw (r, s) mpints.
        r, s = decode_dss_signature(sig)

        m = Message()
        m.add_string(self.ecdsa_curve.key_format_identifier)
        m.add_string(self._sigencode(r, s))
        return m

    def verify_ssh_sig(self, data, msg):
        """
        Check that ``msg`` holds a valid signature of ``data`` for this key.

        :returns: False if the signature's key type does not match this key
            or if verification fails; True otherwise.
        """
        if msg.get_text() != self.ecdsa_curve.key_format_identifier:
            return False
        sig = msg.get_binary()
        sigR, sigS = self._sigdecode(sig)
        # Re-encode SSH's (r, s) as DER for cryptography's verify().
        signature = encode_dss_signature(sigR, sigS)

        try:
            self.verifying_key.verify(
                signature, data, ec.ECDSA(self.ecdsa_curve.hash_object())
            )
        except InvalidSignature:
            return False
        else:
            return True

    @property
    def private_key(self) -> Optional[ec.EllipticCurvePrivateKey]:
        """The underlying private key object, or None for a public-only key."""
        return self.signing_key

    @classmethod
    def generate(cls, curve=ec.SECP256R1(), progress_func=None, bits=None):
        """
        Generate a new private ECDSA key. This factory function can be used to
        generate a new host key or authentication key.

        :param curve: a ``cryptography`` curve instance. Defaults to
            ``ec.SECP256R1()``. Ignored when ``bits`` is given.
        :param progress_func: Not used for this type of key.
        :param bits: optional key length (256, 384 or 521). When given, it
            selects the matching supported curve.
        :returns: A new private key (`.ECDSAKey`) object
        :raises ValueError: if ``bits`` matches no supported curve.
        """
        if bits is not None:
            curve = cls._ECDSA_CURVES.get_by_key_length(bits)
            if curve is None:
                raise ValueError("Unsupported key length: {:d}".format(bits))
            curve = curve.curve_class()

        private_key = ec.generate_private_key(curve, backend=default_backend())
        return ECDSAKey(vals=(private_key, private_key.public_key()))

    # ...internals...

    def _from_private_key_file(self, filename, password):
        data = self._read_private_key_file("EC", filename, password)
        self._decode_key(data)

    def _from_private_key(self, file_obj, password):
        data = self._read_private_key("EC", file_obj, password)
        self._decode_key(data)

    def _decode_key(self, data):
        """
        Populate signing/verifying keys and curve from ``(pkformat, data)``.

        Handles the original (DER) format and the OpenSSH format. Parse
        failures are re-raised as `.SSHException`.
        """
        pkformat, data = data
        if pkformat == self._PRIVATE_KEY_FORMAT_ORIGINAL:
            try:
                key = serialization.load_der_private_key(
                    data, password=None, backend=default_backend()
                )
            except (
                ValueError,
                AssertionError,
                TypeError,
                UnsupportedAlgorithm,
            ) as e:
                raise SSHException(str(e)) from e
        elif pkformat == self._PRIVATE_KEY_FORMAT_OPENSSH:
            try:
                msg = Message(data)
                curve_name = msg.get_text()
                verkey = msg.get_binary()  # noqa: F841
                sigkey = msg.get_mpint()
                name = "ecdsa-sha2-" + curve_name
                curve = self._ECDSA_CURVES.get_by_key_format_identifier(name)
                if not curve:
                    raise SSHException("Invalid key curve identifier")
                key = ec.derive_private_key(
                    sigkey, curve.curve_class(), default_backend()
                )
            except Exception as e:
                # PKey._read_private_key_openssh() should check or return
                # keytype - parsing could fail for any reason due to wrong type
                raise SSHException(str(e)) from e
        else:
            # NOTE(review): presumably raises; if it ever returned, ``key``
            # below would be unbound.
            self._got_bad_key_format_id(pkformat)

        self.signing_key = key
        self.verifying_key = key.public_key()
        curve_class = key.curve.__class__
        self.ecdsa_curve = self._ECDSA_CURVES.get_by_curve_class(curve_class)

    def _sigencode(self, r, s):
        # SSH ECDSA signature blob: mpint r followed by mpint s.
        msg = Message()
        msg.add_mpint(r)
        msg.add_mpint(s)
        return msg.asbytes()

    def _sigdecode(self, sig):
        # Inverse of _sigencode: read mpint r then mpint s.
        msg = Message(sig)
        r = msg.get_mpint()
        s = msg.get_mpint()
        return r, s