Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/jwt/api_jws.py: 23%
Shortcuts on this page
r, m, x: toggle line displays
j, k: next/prev highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1from __future__ import annotations
3import binascii
4import json
5import warnings
6from collections.abc import Sequence
7from typing import TYPE_CHECKING, Any
9from .algorithms import (
10 Algorithm,
11 get_default_algorithms,
12 has_crypto,
13 requires_cryptography,
14)
15from .api_jwk import PyJWK
16from .exceptions import (
17 DecodeError,
18 InvalidAlgorithmError,
19 InvalidKeyError,
20 InvalidSignatureError,
21 InvalidTokenError,
22)
23from .utils import base64url_decode, base64url_encode
24from .warnings import InsecureKeyLengthWarning, RemovedInPyjwt3Warning
26if TYPE_CHECKING:
27 from .algorithms import AllowedPrivateKeys, AllowedPublicKeys
28 from .types import SigOptions
# Sentinel distinguishing "algorithm argument omitted" from an explicit
# ``algorithm=None`` in PyJWS.encode (the two select different defaults).
_ALGORITHM_UNSET = object()


class PyJWS:
    """Encode and decode JSON Web Signatures (JWS) in compact serialization.

    Each instance keeps its own algorithm registry (restricted to an optional
    allow-list) and a dict of default signature options; per-call ``options``
    passed to ``decode``/``decode_complete`` are merged over those defaults.
    """

    # Value written to the ``typ`` header by ``encode`` unless overridden.
    header_typ = "JWT"

    def __init__(
        self,
        algorithms: Sequence[str] | None = None,
        options: SigOptions | None = None,
    ) -> None:
        """
        :param algorithms: names of algorithms this instance may use; ``None``
            keeps every algorithm returned by ``get_default_algorithms()``.
        :param options: option overrides merged over the defaults from
            ``_get_default_options()``.
        """
        self._algorithms = get_default_algorithms()
        self._valid_algs = (
            set(algorithms) if algorithms is not None else set(self._algorithms)
        )

        # Remove algorithms that aren't on the whitelist
        for key in list(self._algorithms):
            if key not in self._valid_algs:
                del self._algorithms[key]

        self.options: SigOptions = self._get_default_options()
        if options is not None:
            self.options = {**self.options, **options}

    @staticmethod
    def _get_default_options() -> SigOptions:
        """Return a fresh dict of the default signature options."""
        return {"verify_signature": True, "enforce_minimum_key_length": False}

    def register_algorithm(self, alg_id: str, alg_obj: Algorithm) -> None:
        """
        Registers a new Algorithm for use when creating and verifying tokens.

        :param str alg_id: the ID of the Algorithm
        :param alg_obj: the Algorithm object
        :type alg_obj: Algorithm
        :raises ValueError: if ``alg_id`` already has a handler.
        :raises TypeError: if ``alg_obj`` is not an ``Algorithm`` instance.
        """
        if alg_id in self._algorithms:
            raise ValueError("Algorithm already has a handler.")

        if not isinstance(alg_obj, Algorithm):
            raise TypeError("Object is not of type `Algorithm`")

        self._algorithms[alg_id] = alg_obj
        self._valid_algs.add(alg_id)

    def unregister_algorithm(self, alg_id: str) -> None:
        """
        Unregisters an Algorithm for use when creating and verifying tokens

        :param str alg_id: the ID of the Algorithm
        :raises KeyError: if algorithm is not registered.
        """
        if alg_id not in self._algorithms:
            raise KeyError(
                "The specified algorithm could not be removed"
                " because it is not registered."
            )

        del self._algorithms[alg_id]
        self._valid_algs.remove(alg_id)

    def get_algorithms(self) -> list[str]:
        """
        Returns a list of supported values for the `alg` parameter.

        The list is built from a set, so its order is not guaranteed.

        :rtype: list[str]
        """
        return list(self._valid_algs)

    def get_algorithm_by_name(self, alg_name: str) -> Algorithm:
        """
        For a given string name, return the matching Algorithm object.

        Example usage:
        >>> jws_obj = PyJWS()
        >>> jws_obj.get_algorithm_by_name("RS256")

        :param alg_name: The name of the algorithm to retrieve
        :type alg_name: str
        :rtype: Algorithm
        :raises NotImplementedError: if the algorithm is not registered; the
            message hints at installing ``cryptography`` when the algorithm
            requires it and it is unavailable.
        """
        try:
            return self._algorithms[alg_name]
        except KeyError as e:
            if not has_crypto and alg_name in requires_cryptography:
                raise NotImplementedError(
                    f"Algorithm '{alg_name}' could not be found. Do you have cryptography installed?"
                ) from e
            raise NotImplementedError("Algorithm not supported") from e

    def encode(
        self,
        payload: bytes,
        key: AllowedPrivateKeys | PyJWK | str | bytes,
        algorithm: str | None = _ALGORITHM_UNSET,  # type: ignore[assignment]
        headers: dict[str, Any] | None = None,
        json_encoder: type[json.JSONEncoder] | None = None,
        is_payload_detached: bool = False,
        sort_headers: bool = True,
    ) -> str:
        """
        Sign ``payload`` and return the token in compact serialization.

        :param payload: raw bytes to sign; base64url-encoded into the token
            unless the payload is detached.
        :param key: signing key. A ``PyJWK`` supplies its own algorithm name
            as the default and is unwrapped to its ``key`` before use.
        :param algorithm: algorithm name. When omitted, defaults to the
            ``PyJWK``'s algorithm or ``"HS256"``; when ``None``, defaults to
            the ``PyJWK``'s algorithm or ``"none"``. A truthy ``alg`` in
            ``headers`` takes precedence over either.
        :param headers: extra header fields merged over ``typ``/``alg``. A
            falsy ``typ`` removes that field; ``b64: False`` forces a
            detached payload.
        :param json_encoder: ``JSONEncoder`` subclass used for the header.
        :param is_payload_detached: sign the payload but leave the middle
            token segment empty (also sets ``b64: False`` in the header).
        :param sort_headers: sort header keys when serializing.
        :raises NotImplementedError: if the algorithm is not registered.
        :raises InvalidKeyError: if the key fails the algorithm's length check
            and ``enforce_minimum_key_length`` is enabled in ``self.options``;
            otherwise an ``InsecureKeyLengthWarning`` is emitted instead.
        """
        segments: list[bytes] = []

        # declare a new var to narrow the type for type checkers
        if algorithm is _ALGORITHM_UNSET:
            if isinstance(key, PyJWK):
                algorithm_ = key.algorithm_name
            else:
                algorithm_ = "HS256"
        elif algorithm is None:
            if isinstance(key, PyJWK):
                algorithm_ = key.algorithm_name
            else:
                algorithm_ = "none"
        else:
            algorithm_ = algorithm

        # Prefer headers values if present to function parameters.
        if headers:
            headers_alg = headers.get("alg")
            if headers_alg:
                algorithm_ = headers["alg"]

            headers_b64 = headers.get("b64")
            if headers_b64 is False:
                is_payload_detached = True

        # Header
        header: dict[str, Any] = {"typ": self.header_typ, "alg": algorithm_}

        if headers:
            self._validate_headers(headers, encoding=True)
            header.update(headers)

        if not header["typ"]:
            del header["typ"]

        if is_payload_detached:
            header["b64"] = False
        elif "b64" in header:
            # True is the standard value for b64, so no need for it
            del header["b64"]

        json_header = json.dumps(
            header, separators=(",", ":"), cls=json_encoder, sort_keys=sort_headers
        ).encode()

        segments.append(base64url_encode(json_header))

        # A detached (b64=false) payload is signed as-is, not base64url-encoded.
        if is_payload_detached:
            msg_payload = payload
        else:
            msg_payload = base64url_encode(payload)
        segments.append(msg_payload)

        # Segments
        signing_input = b".".join(segments)

        alg_obj = self.get_algorithm_by_name(algorithm_)
        if isinstance(key, PyJWK):
            key = key.key
        key = alg_obj.prepare_key(key)

        key_length_msg = alg_obj.check_key_length(key)
        if key_length_msg:
            if self.options.get("enforce_minimum_key_length", False):
                raise InvalidKeyError(key_length_msg)
            else:
                warnings.warn(key_length_msg, InsecureKeyLengthWarning, stacklevel=2)

        signature = alg_obj.sign(signing_input, key)

        segments.append(base64url_encode(signature))

        # Don't put the payload content inside the encoded token when detached
        if is_payload_detached:
            segments[1] = b""
        encoded_string = b".".join(segments)

        return encoded_string.decode("utf-8")

    def decode_complete(
        self,
        jwt: str | bytes,
        key: AllowedPublicKeys | PyJWK | str | bytes = "",
        algorithms: Sequence[str] | None = None,
        options: SigOptions | None = None,
        detached_payload: bytes | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """
        Parse ``jwt``, optionally verify its signature, and return all parts.

        :param jwt: compact-serialized token.
        :param key: verification key; a ``PyJWK`` also supplies the allowed
            algorithm when ``algorithms`` is ``None``.
        :param algorithms: allowed ``alg`` header values. Required when
            verifying the signature unless ``key`` is a ``PyJWK``.
        :param options: per-call overrides merged over ``self.options``,
            including ``verify_signature`` and ``enforce_minimum_key_length``.
        :param detached_payload: payload to verify when the token's header
            has ``b64`` set to ``False``.
        :param kwargs: deprecated; triggers a ``RemovedInPyjwt3Warning``.
        :returns: dict with ``payload`` (bytes), ``header`` (dict) and
            ``signature`` (bytes).
        :raises DecodeError: if the token is malformed, ``algorithms`` is
            missing when required, or a required ``detached_payload`` is
            missing.
        :raises InvalidTokenError: if header validation fails.
        :raises InvalidAlgorithmError: if the ``alg`` header is missing, not
            allowed, or not supported.
        :raises InvalidKeyError: if the key is too short and minimum key
            length enforcement is enabled.
        :raises InvalidSignatureError: if the signature does not verify.
        """
        if kwargs:
            warnings.warn(
                "passing additional kwargs to decode_complete() is deprecated "
                "and will be removed in pyjwt version 3. "
                f"Unsupported kwargs: {tuple(kwargs.keys())}",
                RemovedInPyjwt3Warning,
                stacklevel=2,
            )
        merged_options: SigOptions
        if options is None:
            merged_options = self.options
        else:
            merged_options = {**self.options, **options}

        verify_signature = merged_options["verify_signature"]

        if verify_signature and not algorithms and not isinstance(key, PyJWK):
            raise DecodeError(
                'It is required that you pass in a value for the "algorithms" argument when calling decode().'
            )

        payload, signing_input, header, signature = self._load(jwt)

        self._validate_headers(header)

        # RFC 7797: with b64=false the payload travels outside the token, so
        # rebuild the signing input from the header segment and the caller's
        # detached payload.
        if header.get("b64", True) is False:
            if detached_payload is None:
                raise DecodeError(
                    'It is required that you pass in a value for the "detached_payload" argument to decode a message having the b64 header set to false.'
                )
            payload = detached_payload
            signing_input = b".".join([signing_input.rsplit(b".", 1)[0], payload])

        if verify_signature:
            # Forward the merged (per-call) setting so that options passed to
            # decode()/decode_complete() are honored, not just self.options.
            self._verify_signature(
                signing_input,
                header,
                signature,
                key,
                algorithms,
                enforce_minimum_key_length=merged_options.get(
                    "enforce_minimum_key_length", False
                ),
            )

        return {
            "payload": payload,
            "header": header,
            "signature": signature,
        }

    def decode(
        self,
        jwt: str | bytes,
        key: AllowedPublicKeys | PyJWK | str | bytes = "",
        algorithms: Sequence[str] | None = None,
        options: SigOptions | None = None,
        detached_payload: bytes | None = None,
        **kwargs: Any,
    ) -> Any:
        """
        Like ``decode_complete`` but return only the payload.

        Accepts the same arguments and raises the same exceptions as
        ``decode_complete``; extra ``kwargs`` are deprecated and ignored.
        """
        if kwargs:
            warnings.warn(
                "passing additional kwargs to decode() is deprecated "
                "and will be removed in pyjwt version 3. "
                f"Unsupported kwargs: {tuple(kwargs.keys())}",
                RemovedInPyjwt3Warning,
                stacklevel=2,
            )
        decoded = self.decode_complete(
            jwt, key, algorithms, options, detached_payload=detached_payload
        )
        return decoded["payload"]

    def get_unverified_header(self, jwt: str | bytes) -> dict[str, Any]:
        """Returns back the JWT header parameters as a `dict`

        Note: The signature is not verified so the header parameters
        should not be fully trusted until signature verification is complete
        """
        headers = self._load(jwt)[2]
        self._validate_headers(headers)

        return headers

    def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict[str, Any], bytes]:
        """
        Split a compact token into its parts without verifying anything.

        :returns: ``(payload, signing_input, header, signature)`` where
            ``payload`` and ``signature`` are base64url-decoded bytes,
            ``signing_input`` is the raw ``header.payload`` bytes, and
            ``header`` is the decoded JSON object.
        :raises DecodeError: if the token is not str/bytes, has fewer than
            three segments, or any segment fails to decode.
        """
        if isinstance(jwt, str):
            jwt = jwt.encode("utf-8")

        if not isinstance(jwt, bytes):
            raise DecodeError(f"Invalid token type. Token must be a {bytes}")

        try:
            signing_input, crypto_segment = jwt.rsplit(b".", 1)
            header_segment, payload_segment = signing_input.split(b".", 1)
        except ValueError as err:
            raise DecodeError("Not enough segments") from err

        try:
            header_data = base64url_decode(header_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid header padding") from err

        try:
            header: dict[str, Any] = json.loads(header_data)
        except ValueError as e:
            raise DecodeError(f"Invalid header string: {e}") from e

        if not isinstance(header, dict):
            raise DecodeError("Invalid header string: must be a json object")

        try:
            payload = base64url_decode(payload_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid payload padding") from err

        try:
            signature = base64url_decode(crypto_segment)
        except (TypeError, binascii.Error) as err:
            raise DecodeError("Invalid crypto padding") from err

        return (payload, signing_input, header, signature)

    def _verify_signature(
        self,
        signing_input: bytes,
        header: dict[str, Any],
        signature: bytes,
        key: AllowedPublicKeys | PyJWK | str | bytes = "",
        algorithms: Sequence[str] | None = None,
        *,
        enforce_minimum_key_length: bool | None = None,
    ) -> None:
        """
        Check ``signature`` over ``signing_input`` using the header's ``alg``.

        :param algorithms: allowed ``alg`` values; defaults to the ``PyJWK``'s
            algorithm when ``key`` is a ``PyJWK`` and this is ``None``.
        :param enforce_minimum_key_length: raise instead of warning on a short
            key. ``None`` falls back to ``self.options``.
        :raises InvalidAlgorithmError: if ``alg`` is missing, empty, not in
            ``algorithms``, or not supported.
        :raises InvalidKeyError: if the key is too short and enforcement is on.
        :raises InvalidSignatureError: if verification fails.
        """
        if algorithms is None and isinstance(key, PyJWK):
            algorithms = [key.algorithm_name]
        try:
            alg = header["alg"]
        except KeyError:
            raise InvalidAlgorithmError("Algorithm not specified") from None

        if not alg or (algorithms is not None and alg not in algorithms):
            raise InvalidAlgorithmError("The specified alg value is not allowed")

        if isinstance(key, PyJWK):
            # The JWK fixes its own algorithm and already-prepared key.
            alg_obj = key.Algorithm
            prepared_key = key.key
        else:
            try:
                alg_obj = self.get_algorithm_by_name(alg)
            except NotImplementedError as e:
                raise InvalidAlgorithmError("Algorithm not supported") from e
            prepared_key = alg_obj.prepare_key(key)

        if enforce_minimum_key_length is None:
            enforce_minimum_key_length = self.options.get(
                "enforce_minimum_key_length", False
            )

        key_length_msg = alg_obj.check_key_length(prepared_key)
        if key_length_msg:
            if enforce_minimum_key_length:
                raise InvalidKeyError(key_length_msg)
            else:
                # stacklevel=4 targets the caller of decode() when reached via
                # decode() -> decode_complete() -> _verify_signature().
                warnings.warn(key_length_msg, InsecureKeyLengthWarning, stacklevel=4)

        if not alg_obj.verify(signing_input, prepared_key, signature):
            raise InvalidSignatureError("Signature verification failed")

    # Extensions that PyJWT actually understands and supports
    _supported_crit: set[str] = {"b64"}

    def _validate_headers(
        self, headers: dict[str, Any], *, encoding: bool = False
    ) -> None:
        """
        Validate ``kid`` (always) and ``crit`` (only when decoding).

        :param encoding: ``True`` when called from ``encode``; skips ``crit``.
        :raises InvalidTokenError: if a present header field is invalid.
        """
        if "kid" in headers:
            self._validate_kid(headers["kid"])
        if not encoding and "crit" in headers:
            self._validate_crit(headers)

    def _validate_kid(self, kid: Any) -> None:
        """Raise ``InvalidTokenError`` unless ``kid`` is a string."""
        if not isinstance(kid, str):
            raise InvalidTokenError("Key ID header parameter must be a string")

    def _validate_crit(self, headers: dict[str, Any]) -> None:
        """
        Validate the ``crit`` header.

        ``crit`` must be a non-empty list of strings, each naming an
        extension in ``_supported_crit`` that is also present in ``headers``.

        :raises InvalidTokenError: if any of those conditions fails.
        """
        crit = headers["crit"]
        if not isinstance(crit, list) or len(crit) == 0:
            raise InvalidTokenError("Invalid 'crit' header: must be a non-empty list")
        for ext in crit:
            if not isinstance(ext, str):
                raise InvalidTokenError("Invalid 'crit' header: values must be strings")
            if ext not in self._supported_crit:
                raise InvalidTokenError(f"Unsupported critical extension: {ext}")
            if ext not in headers:
                raise InvalidTokenError(
                    f"Critical extension '{ext}' is missing from headers"
                )
# A shared, default-configured instance backs the module-level functional
# API, so callers can use these functions without constructing a PyJWS.
_jws_global_obj = PyJWS()

# Algorithm registry management on the shared instance.
register_algorithm = _jws_global_obj.register_algorithm
unregister_algorithm = _jws_global_obj.unregister_algorithm
get_algorithm_by_name = _jws_global_obj.get_algorithm_by_name

# Token creation and parsing on the shared instance.
encode = _jws_global_obj.encode
decode_complete = _jws_global_obj.decode_complete
decode = _jws_global_obj.decode
get_unverified_header = _jws_global_obj.get_unverified_header