Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/jwt/jwks_client.py: 26%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

80 statements  

1from __future__ import annotations 

2 

3import json 

4import urllib.request 

5from functools import lru_cache 

6from ssl import SSLContext 

7from typing import Any 

8from urllib.error import HTTPError, URLError 

9from urllib.parse import urlparse 

10 

11from .api_jwk import PyJWK, PyJWKSet 

12from .api_jwt import decode_complete as decode_token 

13from .exceptions import PyJWKClientConnectionError, PyJWKClientError 

14from .jwk_set_cache import JWKSetCache 

15 

16 

class PyJWKClient:
    def __init__(
        self,
        uri: str,
        cache_keys: bool = False,
        max_cached_keys: int = 16,
        cache_jwk_set: bool = True,
        lifespan: float = 300,
        headers: dict[str, Any] | None = None,
        timeout: float = 30,
        ssl_context: SSLContext | None = None,
    ):
        """A client for retrieving signing keys from a JWKS endpoint.

        ``PyJWKClient`` uses a two-tier caching system to avoid unnecessary
        network requests:

        **Tier 1 — JWK Set cache** (enabled by default):
            Caches the entire JSON Web Key Set response from the endpoint.
            Controlled by:

            - ``cache_jwk_set``: Set to ``True`` (the default) to enable this
              cache. When enabled, the JWK Set is fetched from the network only
              when the cache is empty or expired.
            - ``lifespan``: Time in seconds before the cached JWK Set expires.
              Defaults to ``300`` (5 minutes). Must be greater than 0.

        **Tier 2 — Signing key cache** (disabled by default):
            Caches individual signing keys (looked up by ``kid``) using an LRU
            cache with **no time-based expiration**. Keys are evicted only when
            the cache reaches its maximum size. Only successful lookups are
            cached; a ``kid`` that fails to resolve is retried (and may trigger
            a JWK Set refresh) on every call. Controlled by:

            - ``cache_keys``: Set to ``True`` to enable this cache.
              Defaults to ``False``.
            - ``max_cached_keys``: Maximum number of signing keys to keep in
              the LRU cache. Defaults to ``16``.

        :param uri: The URL of the JWKS endpoint. Must use ``http`` or
            ``https``.
        :type uri: str
        :param cache_keys: Enable the per-key LRU cache (Tier 2).
        :type cache_keys: bool
        :param max_cached_keys: Max entries in the signing key LRU cache.
        :type max_cached_keys: int
        :param cache_jwk_set: Enable the JWK Set response cache (Tier 1).
        :type cache_jwk_set: bool
        :param lifespan: TTL in seconds for the JWK Set cache.
        :type lifespan: float
        :param headers: Optional HTTP headers to include in requests.
        :type headers: dict or None
        :param timeout: HTTP request timeout in seconds.
        :type timeout: float
        :param ssl_context: Optional SSL context for the request.
        :type ssl_context: ssl.SSLContext or None
        :raises PyJWKClientError: If ``uri`` is not http(s), or if
            ``cache_jwk_set`` is enabled and ``lifespan`` is not positive.
        """
        if headers is None:
            headers = {}
        # urllib's default OpenerDirector also handles file://, ftp://, and
        # data: URIs. Reject anything that isn't http(s) eagerly so a caller
        # passing an attacker-influenced URL (e.g. taken from a `jku` token
        # header) can't read local files or reach other unintended schemes.
        scheme = urlparse(uri).scheme.lower()
        if scheme not in ("http", "https"):
            raise PyJWKClientError(
                f"Invalid JWKS URI scheme {scheme!r}: only 'http' and 'https' "
                f"are supported."
            )
        self.uri = uri
        self.jwk_set_cache: JWKSetCache | None = None
        self.headers = headers
        self.timeout = timeout
        self.ssl_context = ssl_context

        if cache_jwk_set:
            # Tier 1: cache the whole JWK Set for ``lifespan`` seconds.
            if lifespan <= 0:
                raise PyJWKClientError(
                    f'Lifespan must be greater than 0, the input is "{lifespan}"'
                )
            self.jwk_set_cache = JWKSetCache(lifespan)

        if cache_keys:
            # Tier 2: wrap the bound method in an LRU cache and store it on the
            # instance, shadowing the class method for this client only.
            # lru_cache requires hashable arguments and does not cache calls
            # that raise, so unknown kids are never memoized.
            get_signing_key = lru_cache(maxsize=max_cached_keys)(self.get_signing_key)
            # Ignore mypy (https://github.com/python/mypy/issues/2427)
            self.get_signing_key = get_signing_key  # type: ignore[method-assign]

    def fetch_data(self) -> Any:
        """Fetch the JWK Set from the JWKS endpoint.

        Makes an HTTP request to the configured ``uri`` and returns the
        parsed JSON response. If the JWK Set cache is enabled, the
        response is stored in the cache.

        :returns: The parsed JSON response (normally a dictionary; callers
            such as :meth:`get_jwk_set` validate the type).
        :raises PyJWKClientConnectionError: If the HTTP request fails or
            times out.
        :raises json.JSONDecodeError: If the response body is not valid JSON
            (this is not wrapped in a ``PyJWKClientError``).
        """
        try:
            request = urllib.request.Request(url=self.uri, headers=self.headers)
            with urllib.request.urlopen(
                request, timeout=self.timeout, context=self.ssl_context
            ) as response:
                jwk_set = json.load(response)
        except (URLError, TimeoutError) as e:
            # HTTPError carries an open response body; release it before
            # re-raising so the underlying connection is not leaked.
            if isinstance(e, HTTPError):
                e.close()
            raise PyJWKClientConnectionError(
                f'Fail to fetch data from the url, err: "{e}"'
            ) from e

        # Only update the cache on a successful fetch. Writing in a
        # `finally` block with `jwk_set=None` on error clears any
        # previously-cached JWKS, turning a transient outage into a cache
        # wipe that breaks legitimate auth.
        if self.jwk_set_cache is not None:
            self.jwk_set_cache.put(jwk_set)
        return jwk_set

    def get_jwk_set(self, refresh: bool = False) -> PyJWKSet:
        """Return the JWK Set, using the cache when available.

        :param refresh: Force a fresh fetch from the endpoint, bypassing
            the cache.
        :type refresh: bool
        :returns: The JWK Set.
        :rtype: PyJWKSet
        :raises PyJWKClientError: If the endpoint does not return a JSON
            object.
        """
        data = None
        if self.jwk_set_cache is not None and not refresh:
            data = self.jwk_set_cache.get()

        # Cache disabled, empty, expired, or bypassed: go to the network.
        if data is None:
            data = self.fetch_data()

        if not isinstance(data, dict):
            raise PyJWKClientError("The JWKS endpoint did not return a JSON object")

        return PyJWKSet.from_dict(data)

    def get_signing_keys(self, refresh: bool = False) -> list[PyJWK]:
        """Return all signing keys from the JWK Set.

        Filters the JWK Set to keys whose ``use`` is ``"sig"`` (or
        unspecified) and that have a truthy ``kid``.

        :param refresh: Force a fresh fetch from the endpoint, bypassing
            the cache.
        :type refresh: bool
        :returns: A list of signing keys.
        :rtype: list[PyJWK]
        :raises PyJWKClientError: If no signing keys are found.
        """
        jwk_set = self.get_jwk_set(refresh)
        signing_keys = [
            jwk_set_key
            for jwk_set_key in jwk_set.keys
            if jwk_set_key.public_key_use in ("sig", None) and jwk_set_key.key_id
        ]

        if not signing_keys:
            raise PyJWKClientError("The JWKS endpoint did not contain any signing keys")

        return signing_keys

    def get_signing_key(self, kid: str) -> PyJWK:
        """Return the signing key matching the given ``kid``.

        If no match is found in the current JWK Set, the set is
        refreshed from the endpoint and the lookup is retried once.

        :param kid: The key ID to look up.
        :type kid: str
        :returns: The matching signing key.
        :rtype: PyJWK
        :raises PyJWKClientError: If no matching key is found after
            refreshing.
        """
        signing_keys = self.get_signing_keys()
        signing_key = self.match_kid(signing_keys, kid)

        if signing_key is None:
            # The key may have been rotated in after the set was cached:
            # bypass the cache once and try again.
            signing_keys = self.get_signing_keys(refresh=True)
            signing_key = self.match_kid(signing_keys, kid)

            if signing_key is None:
                raise PyJWKClientError(
                    f'Unable to find a signing key that matches: "{kid}"'
                )

        return signing_key

    def get_signing_key_from_jwt(self, token: str | bytes) -> PyJWK:
        """Return the signing key for a JWT by reading its ``kid`` header.

        Extracts the ``kid`` from the token's unverified header and
        delegates to :meth:`get_signing_key`.

        :param token: The encoded JWT.
        :type token: str or bytes
        :returns: The matching signing key.
        :rtype: PyJWK
        :raises PyJWKClientError: If the header has no usable ``kid`` or no
            matching key is found.
        """
        unverified = decode_token(token, options={"verify_signature": False})
        header = unverified["header"]
        kid = header.get("kid")
        # The header is unverified (attacker-controlled) at this point.
        # A falsy kid can never match, because get_signing_keys keeps only
        # keys with a truthy key_id, so fail fast instead of forcing a
        # network refresh. A JSON list/object kid is rejected too: it is
        # unhashable and would make the lru_cache wrapper (cache_keys=True)
        # raise a raw TypeError instead of PyJWKClientError.
        if not kid or isinstance(kid, (dict, list)):
            raise PyJWKClientError(
                f'Unable to find a signing key that matches: "{kid}"'
            )
        return self.get_signing_key(kid)

    @staticmethod
    def match_kid(signing_keys: list[PyJWK], kid: str) -> PyJWK | None:
        """Find a key in *signing_keys* that matches *kid*.

        :param signing_keys: The list of keys to search.
        :type signing_keys: list[PyJWK]
        :param kid: The key ID to match.
        :type kid: str
        :returns: The first matching key, or ``None`` if not found.
        :rtype: PyJWK or None
        """
        return next((key for key in signing_keys if key.key_id == kid), None)