Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/jwt/jwks_client.py: 26%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from __future__ import annotations
3import json
4import urllib.request
5from functools import lru_cache
6from ssl import SSLContext
7from typing import Any
8from urllib.error import HTTPError, URLError
10from .api_jwk import PyJWK, PyJWKSet
11from .api_jwt import decode_complete as decode_token
12from .exceptions import PyJWKClientConnectionError, PyJWKClientError
13from .jwk_set_cache import JWKSetCache
class PyJWKClient:
    """Retrieve (and optionally cache) signing keys published at a JWKS endpoint."""

    def __init__(
        self,
        uri: str,
        cache_keys: bool = False,
        max_cached_keys: int = 16,
        cache_jwk_set: bool = True,
        lifespan: float = 300,
        headers: dict[str, Any] | None = None,
        timeout: float = 30,
        ssl_context: SSLContext | None = None,
    ):
        """A client for retrieving signing keys from a JWKS endpoint.

        ``PyJWKClient`` uses a two-tier caching system to avoid unnecessary
        network requests:

        **Tier 1 — JWK Set cache** (enabled by default):
        Caches the entire JSON Web Key Set response from the endpoint.
        Controlled by:

        - ``cache_jwk_set``: Set to ``True`` (the default) to enable this
          cache. When enabled, the JWK Set is fetched from the network only
          when the cache is empty or expired.
        - ``lifespan``: Time in seconds before the cached JWK Set expires.
          Defaults to ``300`` (5 minutes). Must be greater than 0.

        **Tier 2 — Signing key cache** (disabled by default):
        Caches individual signing keys (looked up by ``kid``) using an LRU
        cache with **no time-based expiration**. Keys are evicted only when
        the cache reaches its maximum size. Failed lookups (exceptions) are
        not cached. Controlled by:

        - ``cache_keys``: Set to ``True`` to enable this cache.
          Defaults to ``False``.
        - ``max_cached_keys``: Maximum number of signing keys to keep in
          the LRU cache. Defaults to ``16``.

        :param uri: The URL of the JWKS endpoint.
        :type uri: str
        :param cache_keys: Enable the per-key LRU cache (Tier 2).
        :type cache_keys: bool
        :param max_cached_keys: Max entries in the signing key LRU cache.
        :type max_cached_keys: int
        :param cache_jwk_set: Enable the JWK Set response cache (Tier 1).
        :type cache_jwk_set: bool
        :param lifespan: TTL in seconds for the JWK Set cache.
        :type lifespan: float
        :param headers: Optional HTTP headers to include in requests.
        :type headers: dict or None
        :param timeout: HTTP request timeout in seconds.
        :type timeout: float
        :param ssl_context: Optional SSL context for the request.
        :type ssl_context: ssl.SSLContext or None
        :raises PyJWKClientError: If ``cache_jwk_set`` is enabled and
            ``lifespan`` is not a positive number.
        """
        if headers is None:
            headers = {}
        self.uri = uri
        self.headers = headers
        self.timeout = timeout
        self.ssl_context = ssl_context

        self.jwk_set_cache: JWKSetCache | None = None
        if cache_jwk_set:
            # Written as ``not (lifespan > 0)`` rather than ``lifespan <= 0`` so
            # that NaN is rejected too (every comparison with NaN is False).
            if not (lifespan > 0):
                raise PyJWKClientError(
                    f'Lifespan must be greater than 0, the input is "{lifespan}"'
                )
            self.jwk_set_cache = JWKSetCache(lifespan)

        if cache_keys:
            # Replace the bound method on this instance with an LRU-cached
            # wrapper keyed by ``kid``. functools.lru_cache does not store
            # results for calls that raise, so unknown kids are retried.
            get_signing_key = lru_cache(maxsize=max_cached_keys)(self.get_signing_key)
            # Ignore mypy (https://github.com/python/mypy/issues/2427)
            self.get_signing_key = get_signing_key  # type: ignore[method-assign]

    def fetch_data(self) -> Any:
        """Fetch the JWK Set from the JWKS endpoint.

        Makes an HTTP request to the configured ``uri`` and returns the
        parsed JSON response. If the JWK Set cache is enabled, the
        response is stored in the cache.

        :returns: The parsed JSON body (expected to be a JWK Set object;
            validated by :meth:`get_jwk_set`, not here).
        :raises PyJWKClientConnectionError: If the HTTP request fails or
            times out.
        """
        jwk_set: Any = None
        try:
            request = urllib.request.Request(url=self.uri, headers=self.headers)
            with urllib.request.urlopen(
                request, timeout=self.timeout, context=self.ssl_context
            ) as response:
                jwk_set = json.load(response)
        except (URLError, TimeoutError) as e:
            # HTTPError is a URLError that also holds an open response body;
            # close it explicitly before re-raising.
            if isinstance(e, HTTPError):
                e.close()
            raise PyJWKClientConnectionError(
                f'Fail to fetch data from the url, err: "{e}"'
            ) from e
        else:
            return jwk_set
        finally:
            # Runs on success and on any failure. On failure ``jwk_set`` is
            # still None, so None is passed to the cache.
            # NOTE(review): presumably JWKSetCache.put(None) invalidates the
            # cached set — confirm in jwk_set_cache.py.
            if self.jwk_set_cache is not None:
                self.jwk_set_cache.put(jwk_set)

    def get_jwk_set(self, refresh: bool = False) -> PyJWKSet:
        """Return the JWK Set, using the cache when available.

        :param refresh: Force a fresh fetch from the endpoint, bypassing
            the cache.
        :type refresh: bool
        :returns: The JWK Set.
        :rtype: PyJWKSet
        :raises PyJWKClientError: If the endpoint does not return a JSON
            object.
        :raises PyJWKClientConnectionError: If a fetch is needed and fails.
        """
        data = None
        if self.jwk_set_cache is not None and not refresh:
            data = self.jwk_set_cache.get()

        # Cache disabled, empty/expired, or refresh requested.
        if data is None:
            data = self.fetch_data()

        if not isinstance(data, dict):
            raise PyJWKClientError("The JWKS endpoint did not return a JSON object")

        return PyJWKSet.from_dict(data)

    def get_signing_keys(self, refresh: bool = False) -> list[PyJWK]:
        """Return all signing keys from the JWK Set.

        Filters the JWK Set to keys whose ``use`` is ``"sig"`` (or
        unspecified) and that have a non-empty ``kid``.

        :param refresh: Force a fresh fetch from the endpoint, bypassing
            the cache.
        :type refresh: bool
        :returns: A list of signing keys.
        :rtype: list[PyJWK]
        :raises PyJWKClientError: If no signing keys are found.
        """
        jwk_set = self.get_jwk_set(refresh)
        signing_keys = [
            jwk_set_key
            for jwk_set_key in jwk_set.keys
            if jwk_set_key.public_key_use in ("sig", None) and jwk_set_key.key_id
        ]

        if not signing_keys:
            raise PyJWKClientError("The JWKS endpoint did not contain any signing keys")

        return signing_keys

    def get_signing_key(self, kid: str) -> PyJWK:
        """Return the signing key matching the given ``kid``.

        If no match is found in the current JWK Set, the set is
        refreshed from the endpoint and the lookup is retried once.

        :param kid: The key ID to look up.
        :type kid: str
        :returns: The matching signing key.
        :rtype: PyJWK
        :raises PyJWKClientError: If no matching key is found after
            refreshing.
        """
        signing_keys = self.get_signing_keys()
        signing_key = self.match_kid(signing_keys, kid)

        if not signing_key:
            # If no matching signing key from the jwk set, refresh the jwk set and try again.
            signing_keys = self.get_signing_keys(refresh=True)
            signing_key = self.match_kid(signing_keys, kid)

            if not signing_key:
                raise PyJWKClientError(
                    f'Unable to find a signing key that matches: "{kid}"'
                )

        return signing_key

    def get_signing_key_from_jwt(self, token: str | bytes) -> PyJWK:
        """Return the signing key for a JWT by reading its ``kid`` header.

        Extracts the ``kid`` from the token's unverified header and
        delegates to :meth:`get_signing_key`. The header is attacker-
        controlled at this point, so ``kid`` is validated first: a missing,
        empty, or non-string ``kid`` can never match a signing key (keys
        without a ``kid`` are filtered out by :meth:`get_signing_keys`), and
        is rejected without forcing a JWKS refresh over the network.

        :param token: The encoded JWT.
        :type token: str or bytes
        :returns: The matching signing key.
        :rtype: PyJWK
        :raises PyJWKClientError: If the token has no usable ``kid`` or no
            matching key is found.
        """
        unverified = decode_token(token, options={"verify_signature": False})
        header = unverified["header"]
        kid = header.get("kid")
        if not self._is_usable_kid(kid):
            # Also prevents a raw TypeError from the LRU cache (cache_keys=True)
            # when kid is an unhashable JSON list/object.
            raise PyJWKClientError(
                f'Unable to find a signing key that matches: "{kid}"'
            )
        return self.get_signing_key(kid)

    @staticmethod
    def _is_usable_kid(kid: object) -> bool:
        """Return True if *kid* is a non-empty string (RFC 7515 §4.1.4)."""
        return isinstance(kid, str) and bool(kid)

    @staticmethod
    def match_kid(signing_keys: list[PyJWK], kid: str) -> PyJWK | None:
        """Find a key in *signing_keys* that matches *kid*.

        :param signing_keys: The list of keys to search.
        :type signing_keys: list[PyJWK]
        :param kid: The key ID to match.
        :type kid: str
        :returns: The first matching key, or ``None`` if not found.
        :rtype: PyJWK or None
        """
        return next((key for key in signing_keys if key.key_id == kid), None)