Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/jwt/jwks_client.py: 26%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from __future__ import annotations
3import json
4import urllib.request
5from functools import lru_cache
6from ssl import SSLContext
7from typing import Any
8from urllib.error import HTTPError, URLError
9from urllib.parse import urlparse
11from .api_jwk import PyJWK, PyJWKSet
12from .api_jwt import decode_complete as decode_token
13from .exceptions import PyJWKClientConnectionError, PyJWKClientError
14from .jwk_set_cache import JWKSetCache
class PyJWKClient:
    """Retrieve (and optionally cache) signing keys from a JWKS endpoint."""

    def __init__(
        self,
        uri: str,
        cache_keys: bool = False,
        max_cached_keys: int = 16,
        cache_jwk_set: bool = True,
        lifespan: float = 300,
        headers: dict[str, Any] | None = None,
        timeout: float = 30,
        ssl_context: SSLContext | None = None,
    ):
        """A client for retrieving signing keys from a JWKS endpoint.

        ``PyJWKClient`` uses a two-tier caching system to avoid unnecessary
        network requests:

        **Tier 1 — JWK Set cache** (enabled by default):
        Caches the entire JSON Web Key Set response from the endpoint.
        Controlled by:

        - ``cache_jwk_set``: Set to ``True`` (the default) to enable this
          cache. When enabled, the JWK Set is fetched from the network only
          when the cache is empty or expired.
        - ``lifespan``: Time in seconds before the cached JWK Set expires.
          Defaults to ``300`` (5 minutes). Must be greater than 0.

        **Tier 2 — Signing key cache** (disabled by default):
        Caches individual signing keys (looked up by ``kid``) using an LRU
        cache with **no time-based expiration**. Keys are evicted only when
        the cache reaches its maximum size. Controlled by:

        - ``cache_keys``: Set to ``True`` to enable this cache.
          Defaults to ``False``.
        - ``max_cached_keys``: Maximum number of signing keys to keep in
          the LRU cache. Defaults to ``16``.

        :param uri: The URL of the JWKS endpoint. Must use ``http`` or
            ``https``.
        :type uri: str
        :param cache_keys: Enable the per-key LRU cache (Tier 2).
        :type cache_keys: bool
        :param max_cached_keys: Max entries in the signing key LRU cache.
        :type max_cached_keys: int
        :param cache_jwk_set: Enable the JWK Set response cache (Tier 1).
        :type cache_jwk_set: bool
        :param lifespan: TTL in seconds for the JWK Set cache.
        :type lifespan: float
        :param headers: Optional HTTP headers to include in requests.
        :type headers: dict or None
        :param timeout: HTTP request timeout in seconds.
        :type timeout: float
        :param ssl_context: Optional SSL context for the request.
        :type ssl_context: ssl.SSLContext or None
        :raises PyJWKClientError: If ``uri`` does not use the ``http`` or
            ``https`` scheme, or if ``cache_jwk_set`` is enabled and
            ``lifespan`` is not greater than 0.
        """
        if headers is None:
            headers = {}

        # urllib's default OpenerDirector also handles file://, ftp://, and
        # data: URIs. Reject anything that isn't http(s) eagerly so a caller
        # passing an attacker-influenced URL (e.g. taken from a `jku` token
        # header) can't read local files or reach other unintended schemes.
        scheme = urlparse(uri).scheme.lower()
        if scheme not in ("http", "https"):
            raise PyJWKClientError(
                f"Invalid JWKS URI scheme {scheme!r}: only 'http' and 'https' "
                f"are supported."
            )

        self.uri = uri
        self.headers = headers
        self.timeout = timeout
        self.ssl_context = ssl_context

        # Tier 1: cache the whole JWK Set response for `lifespan` seconds.
        self.jwk_set_cache: JWKSetCache | None = None
        if cache_jwk_set:
            if lifespan <= 0:
                raise PyJWKClientError(
                    f'Lifespan must be greater than 0, the input is "{lifespan}"'
                )
            self.jwk_set_cache = JWKSetCache(lifespan)

        if cache_keys:
            # Tier 2: memoize get_signing_key per kid on this instance only.
            # lru_cache does not cache raised exceptions, so a kid that is not
            # found is looked up again (with a refresh) on every call.
            get_signing_key = lru_cache(maxsize=max_cached_keys)(self.get_signing_key)
            # Ignore mypy (https://github.com/python/mypy/issues/2427)
            self.get_signing_key = get_signing_key  # type: ignore[method-assign]

    def fetch_data(self) -> Any:
        """Fetch the JWK Set from the JWKS endpoint.

        Always makes an HTTP request to the configured ``uri`` (the cache is
        not consulted here) and returns the parsed JSON response. If the JWK
        Set cache is enabled, a successful response is stored in it.

        :returns: The parsed JSON body (normally a dict).
        :raises PyJWKClientConnectionError: If the HTTP request fails or
            times out.
        :raises ValueError: If the response body is not valid JSON
            (``json.JSONDecodeError``); this is not wrapped.
        """
        try:
            request = urllib.request.Request(url=self.uri, headers=self.headers)
            with urllib.request.urlopen(
                request, timeout=self.timeout, context=self.ssl_context
            ) as response:
                jwk_set = json.load(response)
        except (URLError, TimeoutError) as e:
            if isinstance(e, HTTPError):
                # HTTPError carries the error response; close it so the
                # underlying connection is released.
                e.close()
            raise PyJWKClientConnectionError(
                f'Fail to fetch data from the url, err: "{e}"'
            ) from e

        # Only update the cache on a successful fetch. Writing in a
        # `finally` block with `jwk_set=None` on error clears any
        # previously-cached JWKS, turning a transient outage into a cache
        # wipe that breaks legitimate auth.
        if self.jwk_set_cache is not None:
            self.jwk_set_cache.put(jwk_set)
        return jwk_set

    def get_jwk_set(self, refresh: bool = False) -> PyJWKSet:
        """Return the JWK Set, using the cache when available.

        :param refresh: Force a fresh fetch from the endpoint, bypassing
            the cache.
        :type refresh: bool
        :returns: The JWK Set.
        :rtype: PyJWKSet
        :raises PyJWKClientError: If the endpoint does not return a JSON
            object.
        """
        data = None
        if self.jwk_set_cache is not None and not refresh:
            data = self.jwk_set_cache.get()

        if data is None:
            data = self.fetch_data()

        if not isinstance(data, dict):
            raise PyJWKClientError("The JWKS endpoint did not return a JSON object")

        return PyJWKSet.from_dict(data)

    def get_signing_keys(self, refresh: bool = False) -> list[PyJWK]:
        """Return all signing keys from the JWK Set.

        Filters the JWK Set to keys whose ``use`` is ``"sig"`` (or
        unspecified) and that have a non-empty ``kid``.

        :param refresh: Force a fresh fetch from the endpoint, bypassing
            the cache.
        :type refresh: bool
        :returns: A list of signing keys.
        :rtype: list[PyJWK]
        :raises PyJWKClientError: If no signing keys are found.
        """
        jwk_set = self.get_jwk_set(refresh)
        signing_keys = [
            jwk_set_key
            for jwk_set_key in jwk_set.keys
            if jwk_set_key.public_key_use in ["sig", None] and jwk_set_key.key_id
        ]

        if not signing_keys:
            raise PyJWKClientError("The JWKS endpoint did not contain any signing keys")

        return signing_keys

    def get_signing_key(self, kid: str) -> PyJWK:
        """Return the signing key matching the given ``kid``.

        If no match is found in the current JWK Set, the set is
        refreshed from the endpoint and the lookup is retried once.

        :param kid: The key ID to look up.
        :type kid: str
        :returns: The matching signing key.
        :rtype: PyJWK
        :raises PyJWKClientError: If ``kid`` is empty/missing, or no
            matching key is found after refreshing.
        """
        # get_signing_keys only keeps keys with a truthy ``kid``, so an empty
        # or missing kid can never match. Fail fast instead of forcing a
        # refresh, which would let every kid-less token trigger a network
        # request to the JWKS endpoint.
        if not kid:
            raise PyJWKClientError(
                f'Unable to find a signing key that matches: "{kid}"'
            )

        signing_key = self.match_kid(self.get_signing_keys(), kid)

        if signing_key is None:
            # The key may have been added since the set was cached:
            # refresh the JWK Set and try once more.
            signing_key = self.match_kid(self.get_signing_keys(refresh=True), kid)

        if signing_key is None:
            raise PyJWKClientError(
                f'Unable to find a signing key that matches: "{kid}"'
            )

        return signing_key

    def get_signing_key_from_jwt(self, token: str | bytes) -> PyJWK:
        """Return the signing key for a JWT by reading its ``kid`` header.

        Extracts the ``kid`` from the token's unverified header and
        delegates to :meth:`get_signing_key`.

        :param token: The encoded JWT.
        :type token: str or bytes
        :returns: The matching signing key.
        :rtype: PyJWK
        :raises PyJWKClientError: If the ``kid`` header is present but not
            a string, or no matching signing key is found.
        """
        unverified = decode_token(token, options={"verify_signature": False})
        kid = unverified["header"].get("kid")
        # The header is attacker-controlled. RFC 7515 §4.1.4 requires "kid"
        # to be a string; non-string values such as lists or dicts are also
        # unhashable and would make the optional lru_cache wrapper raise a
        # bare TypeError instead of a PyJWKClientError.
        if kid is not None and not isinstance(kid, str):
            raise PyJWKClientError(
                f'The "kid" header must be a string, got {type(kid).__name__}'
            )
        return self.get_signing_key(kid)

    @staticmethod
    def match_kid(signing_keys: list[PyJWK], kid: str) -> PyJWK | None:
        """Find the first key in *signing_keys* whose ``key_id`` equals *kid*.

        :param signing_keys: The list of keys to search.
        :type signing_keys: list[PyJWK]
        :param kid: The key ID to match.
        :type kid: str
        :returns: The matching key, or ``None`` if not found.
        :rtype: PyJWK or None
        """
        return next((key for key in signing_keys if key.key_id == kid), None)