1import hashlib 
    2 
    3import ecdsa 
    4 
    5from jose.backends.base import Key 
    6from jose.constants import ALGORITHMS 
    7from jose.exceptions import JWKError 
    8from jose.utils import base64_to_long, long_to_base64 
    9 
    10 
    11class ECDSAECKey(Key): 
    12    """ 
    13    Performs signing and verification operations using 
    14    ECDSA and the specified hash function 
    15 
    16    This class requires the ecdsa package to be installed. 
    17 
    18    This is based off of the implementation in PyJWT 0.3.2 
    19    """ 
    20 
    21    SHA256 = hashlib.sha256 
    22    SHA384 = hashlib.sha384 
    23    SHA512 = hashlib.sha512 
    24 
    25    CURVE_MAP = { 
    26        SHA256: ecdsa.curves.NIST256p, 
    27        SHA384: ecdsa.curves.NIST384p, 
    28        SHA512: ecdsa.curves.NIST521p, 
    29    } 
    30    CURVE_NAMES = ( 
    31        (ecdsa.curves.NIST256p, "P-256"), 
    32        (ecdsa.curves.NIST384p, "P-384"), 
    33        (ecdsa.curves.NIST521p, "P-521"), 
    34    ) 
    35 
    36    def __init__(self, key, algorithm): 
    37        if algorithm not in ALGORITHMS.EC: 
    38            raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm) 
    39 
    40        self.hash_alg = { 
    41            ALGORITHMS.ES256: self.SHA256, 
    42            ALGORITHMS.ES384: self.SHA384, 
    43            ALGORITHMS.ES512: self.SHA512, 
    44        }.get(algorithm) 
    45        self._algorithm = algorithm 
    46 
    47        self.curve = self.CURVE_MAP.get(self.hash_alg) 
    48 
    49        if isinstance(key, (ecdsa.SigningKey, ecdsa.VerifyingKey)): 
    50            self.prepared_key = key 
    51            return 
    52 
    53        if isinstance(key, dict): 
    54            self.prepared_key = self._process_jwk(key) 
    55            return 
    56 
    57        if isinstance(key, str): 
    58            key = key.encode("utf-8") 
    59 
    60        if isinstance(key, bytes): 
    61            # Attempt to load key. We don't know if it's 
    62            # a Signing Key or a Verifying Key, so we try 
    63            # the Verifying Key first. 
    64            try: 
    65                key = ecdsa.VerifyingKey.from_pem(key) 
    66            except ecdsa.der.UnexpectedDER: 
    67                key = ecdsa.SigningKey.from_pem(key) 
    68            except Exception as e: 
    69                raise JWKError(e) 
    70 
    71            self.prepared_key = key 
    72            return 
    73 
    74        raise JWKError("Unable to parse an ECKey from key: %s" % key) 
    75 
    76    def _process_jwk(self, jwk_dict): 
    77        if not jwk_dict.get("kty") == "EC": 
    78            raise JWKError("Incorrect key type. Expected: 'EC', Received: %s" % jwk_dict.get("kty")) 
    79 
    80        if not all(k in jwk_dict for k in ["x", "y", "crv"]): 
    81            raise JWKError("Mandatory parameters are missing") 
    82 
    83        if "d" in jwk_dict: 
    84            # We are dealing with a private key; the secret exponent is enough 
    85            # to create an ecdsa key. 
    86            d = base64_to_long(jwk_dict.get("d")) 
    87            return ecdsa.keys.SigningKey.from_secret_exponent(d, self.curve) 
    88        else: 
    89            x = base64_to_long(jwk_dict.get("x")) 
    90            y = base64_to_long(jwk_dict.get("y")) 
    91 
    92            if not ecdsa.ecdsa.point_is_valid(self.curve.generator, x, y): 
    93                raise JWKError(f"Point: {x}, {y} is not a valid point") 
    94 
    95            point = ecdsa.ellipticcurve.Point(self.curve.curve, x, y, self.curve.order) 
    96            return ecdsa.keys.VerifyingKey.from_public_point(point, self.curve) 
    97 
    98    def sign(self, msg): 
    99        return self.prepared_key.sign( 
    100            msg, hashfunc=self.hash_alg, sigencode=ecdsa.util.sigencode_string, allow_truncate=False 
    101        ) 
    102 
    103    def verify(self, msg, sig): 
    104        try: 
    105            return self.prepared_key.verify( 
    106                sig, msg, hashfunc=self.hash_alg, sigdecode=ecdsa.util.sigdecode_string, allow_truncate=False 
    107            ) 
    108        except Exception: 
    109            return False 
    110 
    111    def is_public(self): 
    112        return isinstance(self.prepared_key, ecdsa.VerifyingKey) 
    113 
    114    def public_key(self): 
    115        if self.is_public(): 
    116            return self 
    117        return self.__class__(self.prepared_key.get_verifying_key(), self._algorithm) 
    118 
    119    def to_pem(self): 
    120        return self.prepared_key.to_pem() 
    121 
    122    def to_dict(self): 
    123        if not self.is_public(): 
    124            public_key = self.prepared_key.get_verifying_key() 
    125        else: 
    126            public_key = self.prepared_key 
    127        crv = None 
    128        for key, value in self.CURVE_NAMES: 
    129            if key == self.prepared_key.curve: 
    130                crv = value 
    131        if not crv: 
    132            raise KeyError(f"Can't match {self.prepared_key.curve}") 
    133 
    134        # Calculate the key size in bytes. Section 6.2.1.2 and 6.2.1.3 of 
    135        # RFC7518 prescribes that the 'x', 'y' and 'd' parameters of the curve 
    136        # points must be encoded as octed-strings of this length. 
    137        key_size = self.prepared_key.curve.baselen 
    138 
    139        data = { 
    140            "alg": self._algorithm, 
    141            "kty": "EC", 
    142            "crv": crv, 
    143            "x": long_to_base64(public_key.pubkey.point.x(), size=key_size).decode("ASCII"), 
    144            "y": long_to_base64(public_key.pubkey.point.y(), size=key_size).decode("ASCII"), 
    145        } 
    146 
    147        if not self.is_public(): 
    148            data["d"] = long_to_base64(self.prepared_key.privkey.secret_multiplier, size=key_size).decode("ASCII") 
    149 
    150        return data