Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/jwt/jwks_client.py: 27%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from __future__ import annotations
3import json
4import urllib.request
5from functools import lru_cache
6from ssl import SSLContext
7from typing import Any
8from urllib.error import URLError
10from .api_jwk import PyJWK, PyJWKSet
11from .api_jwt import decode_complete as decode_token
12from .exceptions import PyJWKClientConnectionError, PyJWKClientError
13from .jwk_set_cache import JWKSetCache
class PyJWKClient:
    """Resolve JWT signing keys from a remote JWKS (JSON Web Key Set) endpoint.

    Lookups go through up to two caches, configured in :meth:`__init__`: a
    cache of the whole decoded JWKS document, and an optional per-instance
    LRU memoization of :meth:`get_signing_key` keyed by ``kid``.
    """

    def __init__(
        self,
        uri: str,
        cache_keys: bool = False,
        max_cached_keys: int = 16,
        cache_jwk_set: bool = True,
        lifespan: float = 300,
        headers: dict[str, Any] | None = None,
        timeout: float = 30,
        ssl_context: SSLContext | None = None,
    ):
        """Configure the endpoint, HTTP options and the two cache layers.

        **JWK Set cache** (on by default): the decoded JWKS document is kept
        in a :class:`JWKSetCache` created with ``lifespan``. The network is
        hit only when that cache yields nothing (empty or expired) or a
        refresh is requested.

        **Signing-key cache** (off by default): :meth:`get_signing_key` is
        replaced, on this instance only, by a ``functools.lru_cache``
        wrapper of size ``max_cached_keys``. Entries have **no** time-based
        expiry; they leave only by LRU eviction. Lookups that raise are not
        memoized.

        :param uri: URL of the JWKS endpoint.
        :type uri: str
        :param cache_keys: Enable the per-``kid`` LRU cache. Defaults to ``False``.
        :type cache_keys: bool
        :param max_cached_keys: ``maxsize`` of the LRU cache. Defaults to ``16``.
        :type max_cached_keys: int
        :param cache_jwk_set: Enable the JWK Set cache. Defaults to ``True``.
        :type cache_jwk_set: bool
        :param lifespan: Lifetime in seconds of a cached JWK Set. Defaults to
            ``300``. Must be greater than 0 when ``cache_jwk_set`` is true.
        :type lifespan: float
        :param headers: Extra HTTP request headers; ``None`` means none. The
            dict is stored by reference, not copied.
        :type headers: dict or None
        :param timeout: Timeout in seconds handed to ``urlopen``. Defaults to ``30``.
        :type timeout: float
        :param ssl_context: SSL context handed to ``urlopen``.
        :type ssl_context: ssl.SSLContext or None
        :raises PyJWKClientError: If ``cache_jwk_set`` is true and
            ``lifespan`` is not greater than 0.
        """
        self.uri = uri
        self.headers = {} if headers is None else headers
        self.timeout = timeout
        self.ssl_context = ssl_context
        self.jwk_set_cache: JWKSetCache | None = None

        if cache_jwk_set:
            # lifespan is only meaningful (and only validated) when the
            # JWK Set cache is actually in use.
            if lifespan <= 0:
                raise PyJWKClientError(
                    f'Lifespan must be greater than 0, the input is "{lifespan}"'
                )
            self.jwk_set_cache = JWKSetCache(lifespan)

        if cache_keys:
            # Shadow the bound method with a memoized copy on this instance.
            # mypy cannot model method reassignment
            # (https://github.com/python/mypy/issues/2427), hence the ignore.
            memoized = lru_cache(maxsize=max_cached_keys)(self.get_signing_key)
            self.get_signing_key = memoized  # type: ignore[method-assign]

    def fetch_data(self) -> Any:
        """Download ``self.uri`` and decode the body as JSON.

        The request sends ``self.headers`` and uses ``self.timeout`` and
        ``self.ssl_context``. The decoded value is returned without shape
        validation; :meth:`get_jwk_set` performs that check.

        If a JWK Set cache is configured, the outcome is written to it on
        every call. On success it receives the decoded value. On failure it
        receives ``None``; what ``put(None)`` does is defined by
        ``JWKSetCache`` (presumably it clears the entry — confirm there).

        :returns: Whatever ``json.load`` produced from the response body.
        :raises PyJWKClientConnectionError: If opening or reading the URL
            raises ``URLError`` (which includes ``HTTPError``) or
            ``TimeoutError``.

        Other errors propagate unchanged, e.g. ``json.JSONDecodeError`` for a
        non-JSON body.
        """
        payload: Any = None
        try:
            request = urllib.request.Request(url=self.uri, headers=self.headers)
            with urllib.request.urlopen(
                request, timeout=self.timeout, context=self.ssl_context
            ) as response:
                payload = json.load(response)
        except (URLError, TimeoutError) as err:
            raise PyJWKClientConnectionError(
                f'Fail to fetch data from the url, err: "{err}"'
            ) from err
        finally:
            # Executed on both paths; payload is still None if the download
            # or the JSON decoding failed.
            if self.jwk_set_cache is not None:
                self.jwk_set_cache.put(payload)
        return payload

    def get_jwk_set(self, refresh: bool = False) -> PyJWKSet:
        """Return the JWK Set, preferring the cached copy.

        :param refresh: Skip the cache and always download a fresh copy.
        :type refresh: bool
        :returns: The parsed key set.
        :rtype: PyJWKSet
        :raises PyJWKClientError: If the cached or downloaded document is not
            a JSON object (``dict``).
        :raises PyJWKClientConnectionError: Propagated from :meth:`fetch_data`.
        """
        cache = self.jwk_set_cache
        data = cache.get() if cache is not None and not refresh else None

        if data is None:
            # Cache disabled, bypassed, or returned nothing: go to the network.
            data = self.fetch_data()

        if not isinstance(data, dict):
            raise PyJWKClientError("The JWKS endpoint did not return a JSON object")

        return PyJWKSet.from_dict(data)

    def get_signing_keys(self, refresh: bool = False) -> list[PyJWK]:
        """Return the keys from the JWK Set that are usable for verification.

        A key qualifies when its ``use`` is ``"sig"`` or absent (``None``) and
        it carries a truthy ``kid``.

        :param refresh: Skip the JWK Set cache and download a fresh copy.
        :type refresh: bool
        :returns: Qualifying keys, in JWK Set order.
        :rtype: list[PyJWK]
        :raises PyJWKClientError: If no key qualifies.
        """
        candidates = self.get_jwk_set(refresh).keys
        signing_keys = [
            jwk
            for jwk in candidates
            if jwk.public_key_use in ("sig", None) and jwk.key_id
        ]

        if not signing_keys:
            raise PyJWKClientError("The JWKS endpoint did not contain any signing keys")

        return signing_keys

    def get_signing_key(self, kid: str) -> PyJWK:
        """Return the signing key whose ``kid`` equals *kid*.

        The first lookup may use the cached JWK Set. If it misses, the set is
        fetched again (bypassing the cache) and the lookup is retried once.
        Every miss therefore costs one network request.

        :param kid: Key ID to look up.
        :type kid: str
        :returns: The matching key.
        :rtype: PyJWK
        :raises PyJWKClientError: If neither lookup finds the key.
        """
        found = self.match_kid(self.get_signing_keys(), kid)
        if found:
            return found

        # The cached set may be stale; try once more against a fresh copy.
        found = self.match_kid(self.get_signing_keys(refresh=True), kid)
        if found:
            return found

        raise PyJWKClientError(
            f'Unable to find a signing key that matches: "{kid}"'
        )

    def get_signing_key_from_jwt(self, token: str | bytes) -> PyJWK:
        """Return the signing key named by the ``kid`` header of *token*.

        The token is decoded **without** signature verification, only to read
        its header, and the lookup is delegated to :meth:`get_signing_key`.

        A token with no ``kid`` header passes ``None`` through. It can never
        match, because :meth:`get_signing_keys` keeps only keys with a truthy
        ``kid``. It therefore ends in ``PyJWKClientError`` after a forced
        refresh.

        :param token: Encoded JWT.
        :type token: str or bytes
        :returns: The matching signing key.
        :rtype: PyJWK
        """
        header = decode_token(token, options={"verify_signature": False})["header"]
        return self.get_signing_key(header.get("kid"))

    @staticmethod
    def match_kid(signing_keys: list[PyJWK], kid: str) -> PyJWK | None:
        """Return the first key in *signing_keys* whose ``key_id`` equals *kid*.

        :param signing_keys: Keys to search, in order.
        :type signing_keys: list[PyJWK]
        :param kid: Key ID to compare against (with ``==``).
        :type kid: str
        :returns: The first match, or ``None`` if there is none.
        :rtype: PyJWK or None
        """
        return next((key for key in signing_keys if key.key_id == kid), None)