Coverage for /pythoncovmergedfiles/medio/medio/src/paramiko/paramiko/ecdsakey.py: 30%
175 statements
« prev ^ index » next coverage.py v7.2.2, created at 2023-03-26 06:36 +0000
« prev ^ index » next coverage.py v7.2.2, created at 2023-03-26 06:36 +0000
1# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
2#
3# This file is part of paramiko.
4#
5# Paramiko is free software; you can redistribute it and/or modify it under the
6# terms of the GNU Lesser General Public License as published by the Free
7# Software Foundation; either version 2.1 of the License, or (at your option)
8# any later version.
9#
10# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
11# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
12# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
13# details.
14#
15# You should have received a copy of the GNU Lesser General Public License
16# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
17# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
19"""
20ECDSA keys
21"""
23from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
24from cryptography.hazmat.backends import default_backend
25from cryptography.hazmat.primitives import hashes, serialization
26from cryptography.hazmat.primitives.asymmetric import ec
27from cryptography.hazmat.primitives.asymmetric.utils import (
28 decode_dss_signature,
29 encode_dss_signature,
30)
32from paramiko.common import four_byte
33from paramiko.message import Message
34from paramiko.pkey import PKey
35from paramiko.ssh_exception import SSHException
36from paramiko.util import deflate_long
class _ECDSACurve:
    """
    Describe one supported ECDSA curve (nistp256, nistp384, nistp521, ...).

    Given a ``cryptography`` curve class and its NIST name, this records the
    curve's key size, derives the SSH key format identifier, and selects the
    hash algorithm used when signing or verifying with that curve.
    """

    def __init__(self, curve_class, nist_name):
        bits = curve_class.key_size
        self.curve_class = curve_class
        self.nist_name = nist_name
        self.key_length = bits

        # RFC 5656 section 6.2: identifier is "ecdsa-sha2-" + curve name.
        self.key_format_identifier = "ecdsa-sha2-" + nist_name

        # RFC 5656 section 6.2.1: pick the hash from the curve size --
        # up to 256 bits -> SHA-256, up to 384 -> SHA-384, larger -> SHA-512.
        if bits > 384:
            self.hash_object = hashes.SHA512
        elif bits > 256:
            self.hash_object = hashes.SHA384
        else:
            self.hash_object = hashes.SHA256
class _ECDSACurveSet:
    """
    A collection of the supported ECDSA curves.

    ECDSAKey needs to find a curve by its ``cryptography`` curve class, by
    its SSH key format identifier, or by its key length. Each ``get_by_*``
    lookup scans the curves in order and returns the first match, or
    ``None`` when no curve matches.
    """

    def __init__(self, ecdsa_curves):
        self.ecdsa_curves = ecdsa_curves

    def get_key_format_identifier_list(self):
        """Return every curve's key format identifier, in stored order."""
        return [entry.key_format_identifier for entry in self.ecdsa_curves]

    def get_by_curve_class(self, curve_class):
        """Return the first curve whose ``curve_class`` equals the argument."""
        hits = (
            entry
            for entry in self.ecdsa_curves
            if entry.curve_class == curve_class
        )
        return next(hits, None)

    def get_by_key_format_identifier(self, key_format_identifier):
        """Return the first curve with the given key format identifier."""
        hits = (
            entry
            for entry in self.ecdsa_curves
            if entry.key_format_identifier == key_format_identifier
        )
        return next(hits, None)

    def get_by_key_length(self, key_length):
        """Return the first curve whose key length equals ``key_length``."""
        hits = (
            entry
            for entry in self.ecdsa_curves
            if entry.key_length == key_length
        )
        return next(hits, None)
class ECDSAKey(PKey):
    """
    Representation of an ECDSA key which can be used to sign and verify SSH2
    data.

    Only the curves listed in ``_ECDSA_CURVES`` (NIST P-256, P-384, P-521)
    are recognised.
    """

    # Curves this class can load, sign and verify with.
    _ECDSA_CURVES = _ECDSACurveSet(
        [
            _ECDSACurve(ec.SECP256R1, "nistp256"),
            _ECDSACurve(ec.SECP384R1, "nistp384"),
            _ECDSACurve(ec.SECP521R1, "nistp521"),
        ]
    )

    # Suffix appended to a key format identifier to form its OpenSSH
    # certificate type name.
    _CERT_SUFFIX = "-cert-v01@openssh.com"

    def __init__(
        self,
        msg=None,
        data=None,
        filename=None,
        password=None,
        vals=None,
        file_obj=None,
        validate_point=True,
    ):
        """
        Load or wrap an ECDSA key from one source.

        Sources are checked in this order: ``file_obj`` (private key file
        object), ``filename`` (private key file path), ``vals`` (a
        ``(private_key, public_key)`` pair of ``cryptography`` key objects),
        then ``msg`` / ``data`` (an SSH wire-format public key or
        certificate; ``data`` is wrapped in a `.Message` when ``msg`` is not
        given).

        :param password: passphrase used when reading a private key file.
        :param validate_point:
            Not used by this implementation; kept for backwards
            compatibility.
        :raises SSHException:
            if no key source is given, or a public key's curve name or point
            is invalid. Errors from the `.PKey` loading helpers propagate.
        """
        self.verifying_key = None
        self.signing_key = None
        self.public_blob = None
        if file_obj is not None:
            self._from_private_key(file_obj, password)
            return
        if filename is not None:
            self._from_private_key_file(filename, password)
            return
        if (msg is None) and (data is not None):
            msg = Message(data)
        if vals is not None:
            self.signing_key, self.verifying_key = vals
            c_class = self.signing_key.curve.__class__
            self.ecdsa_curve = self._ECDSA_CURVES.get_by_curve_class(c_class)
            return
        if msg is None:
            # Without this guard, msg.get_text() below fails with an opaque
            # AttributeError on None.
            raise SSHException("Key object may not be empty")

        # Must set ecdsa_curve first; subroutines called herein may need to
        # spit out our get_name(), which relies on this.
        key_type = msg.get_text()
        # But this also means we need to hand it a real key/curve
        # identifier, so strip out any cert business. (NOTE: could push
        # that into _ECDSACurveSet.get_by_key_format_identifier(), but it
        # feels more correct to do it here?)
        if key_type.endswith(self._CERT_SUFFIX):
            key_type = key_type[: -len(self._CERT_SUFFIX)]
        self.ecdsa_curve = self._ECDSA_CURVES.get_by_key_format_identifier(
            key_type
        )
        key_types = self._ECDSA_CURVES.get_key_format_identifier_list()
        cert_types = [x + self._CERT_SUFFIX for x in key_types]
        self._check_type_and_load_cert(
            msg=msg, key_type=key_types, cert_type=cert_types
        )
        curvename = msg.get_text()
        if curvename != self.ecdsa_curve.nist_name:
            raise SSHException(
                "Can't handle curve of type {}".format(curvename)
            )

        pointinfo = msg.get_binary()
        try:
            key = ec.EllipticCurvePublicKey.from_encoded_point(
                self.ecdsa_curve.curve_class(), pointinfo
            )
        except ValueError as e:
            raise SSHException("Invalid public key") from e
        self.verifying_key = key

    @classmethod
    def supported_key_format_identifiers(cls):
        """Return the key format identifiers of every supported curve."""
        return cls._ECDSA_CURVES.get_key_format_identifier_list()

    def asbytes(self):
        """
        Return the SSH wire encoding of the public key.

        The encoding is the key format identifier, the NIST curve name, and
        the point string: the ``four_byte`` prefix followed by X and Y, each
        left-padded with zero bytes to the curve's byte length.
        """
        key = self.verifying_key
        m = Message()
        m.add_string(self.ecdsa_curve.key_format_identifier)
        m.add_string(self.ecdsa_curve.nist_name)

        numbers = key.public_numbers()

        # Coordinates must be fixed-width, so pad to ceil(key_size / 8).
        key_size_bytes = (key.curve.key_size + 7) // 8

        x_bytes = deflate_long(numbers.x, add_sign_padding=False)
        x_bytes = b"\x00" * (key_size_bytes - len(x_bytes)) + x_bytes

        y_bytes = deflate_long(numbers.y, add_sign_padding=False)
        y_bytes = b"\x00" * (key_size_bytes - len(y_bytes)) + y_bytes

        point_str = four_byte + x_bytes + y_bytes
        m.add_string(point_str)
        return m.asbytes()

    def __str__(self):
        """
        Return the public key as text: ``"<key type> <base64 blob>"``.

        Python 3 requires ``__str__`` to return ``str``. The raw wire bytes
        from `asbytes` caused ``str(key)`` to raise ``TypeError``, so they
        are base64-encoded here.
        """
        import base64

        blob = base64.b64encode(self.asbytes()).decode("ascii")
        return "{} {}".format(self.get_name(), blob)

    @property
    def _fields(self):
        # Identity tuple of key type plus public point; presumably consumed
        # by PKey for equality/hashing -- confirm against pkey.py.
        numbers = self.verifying_key.public_numbers()
        return (self.get_name(), numbers.x, numbers.y)

    def get_name(self):
        """Return the key format identifier, e.g. ``ecdsa-sha2-nistp256``."""
        return self.ecdsa_curve.key_format_identifier

    def get_bits(self):
        """Return the curve's key size in bits."""
        return self.ecdsa_curve.key_length

    def can_sign(self):
        """Return True if a private (signing) key is loaded."""
        return self.signing_key is not None

    def sign_ssh_data(self, data, algorithm=None):
        """
        Sign ``data`` using the curve's hash and return a `.Message`.

        The message holds the key format identifier, followed by the (r, s)
        signature encoded as two mpints.

        :param algorithm: accepted for interface compatibility; not used.
        """
        ecdsa = ec.ECDSA(self.ecdsa_curve.hash_object())
        sig = self.signing_key.sign(data, ecdsa)
        r, s = decode_dss_signature(sig)

        m = Message()
        m.add_string(self.ecdsa_curve.key_format_identifier)
        m.add_string(self._sigencode(r, s))
        return m

    def verify_ssh_sig(self, data, msg):
        """
        Check an SSH signature message against ``data``.

        :returns:
            True if the signature verifies. False if the message's algorithm
            name differs from this key's, or if the signature is invalid.
        """
        if msg.get_text() != self.ecdsa_curve.key_format_identifier:
            return False
        sig = msg.get_binary()
        sigR, sigS = self._sigdecode(sig)
        signature = encode_dss_signature(sigR, sigS)

        try:
            self.verifying_key.verify(
                signature, data, ec.ECDSA(self.ecdsa_curve.hash_object())
            )
        except InvalidSignature:
            return False
        else:
            return True

    def write_private_key_file(self, filename, password=None):
        """
        Write the private key to ``filename`` in TraditionalOpenSSL format.

        ``password`` is passed through to the `.PKey` writer.
        """
        self._write_private_key_file(
            filename,
            self.signing_key,
            serialization.PrivateFormat.TraditionalOpenSSL,
            password=password,
        )

    def write_private_key(self, file_obj, password=None):
        """
        Write the private key to ``file_obj`` in TraditionalOpenSSL format.

        ``password`` is passed through to the `.PKey` writer.
        """
        self._write_private_key(
            file_obj,
            self.signing_key,
            serialization.PrivateFormat.TraditionalOpenSSL,
            password=password,
        )

    @classmethod
    def generate(cls, curve=ec.SECP256R1(), progress_func=None, bits=None):
        """
        Generate a new private ECDSA key. This factory function can be used to
        generate a new host key or authentication key.

        :param curve:
            ``cryptography`` curve instance to generate on; defaults to
            SECP256R1 (nistp256). Ignored when ``bits`` is given.
        :param progress_func: Not used for this type of key.
        :param bits:
            Optional key length selecting one of the supported curves
            (256, 384 or 521) instead of ``curve``.
        :returns: A new private key (`.ECDSAKey`) object
        :raises ValueError: if ``bits`` matches no supported curve.
        """
        if bits is not None:
            curve = cls._ECDSA_CURVES.get_by_key_length(bits)
            if curve is None:
                raise ValueError("Unsupported key length: {:d}".format(bits))
            curve = curve.curve_class()

        private_key = ec.generate_private_key(curve, backend=default_backend())
        return ECDSAKey(vals=(private_key, private_key.public_key()))

    # ...internals...

    def _from_private_key_file(self, filename, password):
        data = self._read_private_key_file("EC", filename, password)
        self._decode_key(data)

    def _from_private_key(self, file_obj, password):
        data = self._read_private_key("EC", file_obj, password)
        self._decode_key(data)

    def _decode_key(self, data):
        """
        Populate signing/verifying keys and curve from a decoded private key.

        :param data:
            ``(pkformat, payload)`` pair as returned by the `.PKey` private
            key readers.
        :raises SSHException:
            if the payload cannot be parsed, is not an EC key, or uses a
            curve this class does not support.
        """
        pkformat, data = data
        if pkformat == self._PRIVATE_KEY_FORMAT_ORIGINAL:
            try:
                key = serialization.load_der_private_key(
                    data, password=None, backend=default_backend()
                )
            except (
                ValueError,
                AssertionError,
                TypeError,
                UnsupportedAlgorithm,
            ) as e:
                raise SSHException(str(e)) from e
        elif pkformat == self._PRIVATE_KEY_FORMAT_OPENSSH:
            try:
                msg = Message(data)
                curve_name = msg.get_text()
                # Public point: read only to advance past it; the public key
                # is re-derived from the private scalar below.
                msg.get_binary()
                sigkey = msg.get_mpint()
                name = "ecdsa-sha2-" + curve_name
                curve = self._ECDSA_CURVES.get_by_key_format_identifier(name)
                if not curve:
                    raise SSHException("Invalid key curve identifier")
                key = ec.derive_private_key(
                    sigkey, curve.curve_class(), default_backend()
                )
            except Exception as e:
                # PKey._read_private_key_openssh() should check or return
                # keytype - parsing could fail for any reason due to wrong type
                raise SSHException(str(e)) from e
        else:
            self._got_bad_key_format_id(pkformat)

        # A DER payload may hold a non-EC key, or an EC key on a curve not in
        # _ECDSA_CURVES. Reject both here instead of failing later with an
        # AttributeError (key.curve / a None ecdsa_curve).
        if not isinstance(key, ec.EllipticCurvePrivateKey):
            raise SSHException("Invalid key type: not an ECDSA private key")
        ecdsa_curve = self._ECDSA_CURVES.get_by_curve_class(
            key.curve.__class__
        )
        if ecdsa_curve is None:
            raise SSHException(
                "Unsupported key curve: {}".format(key.curve.name)
            )
        self.signing_key = key
        self.verifying_key = key.public_key()
        self.ecdsa_curve = ecdsa_curve

    def _sigencode(self, r, s):
        """Encode signature integers r and s as two consecutive mpints."""
        msg = Message()
        msg.add_mpint(r)
        msg.add_mpint(s)
        return msg.asbytes()

    def _sigdecode(self, sig):
        """Decode two consecutive mpints from ``sig`` into ``(r, s)``."""
        msg = Message(sig)
        r = msg.get_mpint()
        s = msg.get_mpint()
        return r, s