1import hashlib
2
3import ecdsa
4
5from jose.backends.base import Key
6from jose.constants import ALGORITHMS
7from jose.exceptions import JWKError
8from jose.utils import base64_to_long, long_to_base64
9
10
class ECDSAECKey(Key):
    """
    Performs signing and verification operations using
    ECDSA and the specified hash function

    This class requires the ecdsa package to be installed.

    This is based off of the implementation in PyJWT 0.3.2
    """

    SHA256 = hashlib.sha256
    SHA384 = hashlib.sha384
    SHA512 = hashlib.sha512

    # Hash function -> curve, following the ESxxx pairings of RFC 7518 section 3.4
    # (ES256: P-256/SHA-256, ES384: P-384/SHA-384, ES512: P-521/SHA-512).
    CURVE_MAP = {
        SHA256: ecdsa.curves.NIST256p,
        SHA384: ecdsa.curves.NIST384p,
        SHA512: ecdsa.curves.NIST521p,
    }
    # (curve, JWK "crv" name) pairs used by to_dict() to name the key's curve.
    CURVE_NAMES = (
        (ecdsa.curves.NIST256p, "P-256"),
        (ecdsa.curves.NIST384p, "P-384"),
        (ecdsa.curves.NIST521p, "P-521"),
    )

    def __init__(self, key, algorithm):
        """
        Prepare an ecdsa key for use with ``algorithm``.

        :param key: an ``ecdsa.SigningKey`` / ``ecdsa.VerifyingKey`` (used as-is),
            a JWK ``dict``, or a PEM-encoded public or private key as ``str``/``bytes``.
        :param algorithm: one of ``ALGORITHMS.EC``; selects the hash and the curve.
        :raises JWKError: if ``algorithm`` is not an EC algorithm or ``key`` cannot
            be parsed.
        """
        if algorithm not in ALGORITHMS.EC:
            raise JWKError(f"hash_alg: {algorithm} is not a valid hash algorithm")

        self.hash_alg = {
            ALGORITHMS.ES256: self.SHA256,
            ALGORITHMS.ES384: self.SHA384,
            ALGORITHMS.ES512: self.SHA512,
        }.get(algorithm)
        self._algorithm = algorithm

        self.curve = self.CURVE_MAP.get(self.hash_alg)

        if isinstance(key, (ecdsa.SigningKey, ecdsa.VerifyingKey)):
            self.prepared_key = key
            return

        if isinstance(key, dict):
            self.prepared_key = self._process_jwk(key)
            return

        if isinstance(key, str):
            key = key.encode("utf-8")

        if isinstance(key, bytes):
            self.prepared_key = self._load_pem(key)
            return

        # f-string rather than "%s" % key: a tuple key would otherwise raise
        # TypeError while building the message instead of this JWKError.
        raise JWKError(f"Unable to parse an ECKey from key: {key}")

    @staticmethod
    def _load_pem(pem):
        """
        Load a PEM-encoded key, trying the public-key form first.

        :param pem: PEM data as ``bytes``.
        :returns: an ``ecdsa.VerifyingKey`` or ``ecdsa.SigningKey``.
        :raises JWKError: if ``pem`` is neither a public nor a private PEM key.
        """
        # We don't know whether this is a signing or a verifying key, so try the
        # verifying key first and fall back to a signing key only when the DER
        # structure does not match a public key.
        try:
            return ecdsa.VerifyingKey.from_pem(pem)
        except ecdsa.der.UnexpectedDER:
            pass
        except Exception as e:
            raise JWKError(e) from e

        # The fallback gets its own handler so that its failures are also
        # reported as JWKError rather than escaping as raw ecdsa/ValueError errors.
        try:
            return ecdsa.SigningKey.from_pem(pem)
        except Exception as e:
            raise JWKError(e) from e

    def _process_jwk(self, jwk_dict):
        """
        Build an ecdsa key on ``self.curve`` from a JWK dict.

        A dict containing ``d`` produces a ``SigningKey`` from the secret exponent;
        otherwise ``x`` and ``y`` produce a ``VerifyingKey``.

        :raises JWKError: if ``kty`` is not ``"EC"``, ``x``/``y``/``crv`` is missing,
            ``d`` is out of range for the curve, or ``(x, y)`` is not on the curve.
        """
        if jwk_dict.get("kty") != "EC":
            raise JWKError(f"Incorrect key type. Expected: 'EC', Received: {jwk_dict.get('kty')}")

        if not all(k in jwk_dict for k in ("x", "y", "crv")):
            raise JWKError("Mandatory parameters are missing")

        if "d" in jwk_dict:
            # We are dealing with a private key; the secret exponent is enough
            # to create an ecdsa key.
            d = base64_to_long(jwk_dict.get("d"))
            # ecdsa rejects exponents outside [1, order) with its own exception
            # type; report it as a JWKError like the invalid-point case below.
            if not 1 <= d < self.curve.order:
                raise JWKError("Private key exponent is out of range for the curve")
            return ecdsa.keys.SigningKey.from_secret_exponent(d, self.curve)

        x = base64_to_long(jwk_dict.get("x"))
        y = base64_to_long(jwk_dict.get("y"))

        if not ecdsa.ecdsa.point_is_valid(self.curve.generator, x, y):
            raise JWKError(f"Point: {x}, {y} is not a valid point")

        point = ecdsa.ellipticcurve.Point(self.curve.curve, x, y, self.curve.order)
        return ecdsa.keys.VerifyingKey.from_public_point(point, self.curve)

    def sign(self, msg):
        """
        Sign ``msg`` with the prepared key and ``self.hash_alg``.

        :returns: the signature encoded by ``ecdsa.util.sigencode_string``
            (fixed-length ``r`` followed by ``s``).
        """
        return self.prepared_key.sign(
            msg, hashfunc=self.hash_alg, sigencode=ecdsa.util.sigencode_string, allow_truncate=False
        )

    def verify(self, msg, sig):
        """
        Check ``sig`` (``r || s`` string encoding) over ``msg``.

        :returns: the result of the ecdsa ``verify`` call, or ``False`` if it raises.
        """
        try:
            return self.prepared_key.verify(
                sig, msg, hashfunc=self.hash_alg, sigdecode=ecdsa.util.sigdecode_string, allow_truncate=False
            )
        except Exception:
            # Any failure (bad signature, malformed encoding, wrong key type) is
            # deliberately reported as "not verified" rather than raised.
            return False

    def is_public(self):
        """Return True if the prepared key is an ``ecdsa.VerifyingKey``."""
        return isinstance(self.prepared_key, ecdsa.VerifyingKey)

    def public_key(self):
        """
        Return a key object holding only the public part.

        Returns ``self`` when already public; otherwise a new instance built from
        the signing key's verifying key with the same algorithm.
        """
        if self.is_public():
            return self
        return self.__class__(self.prepared_key.get_verifying_key(), self._algorithm)

    def to_pem(self):
        """Return the prepared key's PEM encoding as produced by ecdsa's ``to_pem``."""
        return self.prepared_key.to_pem()

    def to_dict(self):
        """
        Serialise the key as a JWK dict.

        Always includes ``alg``, ``kty``, ``crv``, ``x`` and ``y``; private keys
        also include ``d``.

        :raises KeyError: if the key's curve is not listed in ``CURVE_NAMES``.
        """
        if self.is_public():
            public_key = self.prepared_key
        else:
            public_key = self.prepared_key.get_verifying_key()

        crv = next(
            (name for curve, name in self.CURVE_NAMES if curve == self.prepared_key.curve),
            None,
        )
        if crv is None:
            raise KeyError(f"Can't match {self.prepared_key.curve}")

        # Calculate the key size in bytes. Sections 6.2.1.2, 6.2.1.3 and 6.2.2.1
        # of RFC 7518 prescribe that the 'x', 'y' and 'd' parameters must be
        # encoded as octet strings of this fixed length (leading zeros kept).
        key_size = self.prepared_key.curve.baselen

        data = {
            "alg": self._algorithm,
            "kty": "EC",
            "crv": crv,
            "x": long_to_base64(public_key.pubkey.point.x(), size=key_size).decode("ASCII"),
            "y": long_to_base64(public_key.pubkey.point.y(), size=key_size).decode("ASCII"),
        }

        if not self.is_public():
            data["d"] = long_to_base64(self.prepared_key.privkey.secret_multiplier, size=key_size).decode("ASCII")

        return data