Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.10/site-packages/msal/token_cache.py: 28%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import base64
2import hashlib
3import json
4import threading
5import time
6import logging
7import warnings
9from .authority import canonicalize
10from .oauth2cli.oidc import decode_part, decode_id_token
11from .oauth2cli.oauth2 import Client
# Module-level logger named after this module; TokenCache.add() writes a
# masked copy of each token event to it at DEBUG level.
logger = logging.getLogger(__name__)

# Grant-type value that TokenCache counts among the grant types that
# establish an account (it is recorded as the account's "account_source").
_GRANT_TYPE_BROKER = "broker"
# Request-body fields (keys of the ``data`` dict) that must NOT feed the
# extended cache key hash. Every other key found in ``data`` DOES feed it,
# since it is an extra body parameter sent on the wire and therefore has to
# keep the resulting tokens apart in the cache.
#
# Why each field is excluded:
# - "client_id" : Standard OAuth2 client identifier, same for every request
# - "grant_type" : Grants can be combined to get tokens,
#                  e.g. obo + refresh_token, auth_code + refresh_token etc.
# - "scope" : Already represented as "target" in the AT cache key
# - "claims" : Handled separately; its presence forces a token refresh
# - "username" : Standard ROPC grant parameter. Tokens are cached by
#                user ID (subject or oid+tid) instead
# - "password" : Standard ROPC grant parameter. Tokens are tied to credentials.
# - "refresh_token" : Standard refresh grant parameter
# - "code" : Standard authorization code grant parameter
# - "redirect_uri" : Standard authorization code grant parameter
# - "code_verifier" : Standard PKCE parameter
# - "device_code" : Standard device flow parameter
# - "assertion" : Standard OBO/SAML assertion (RFC 7521)
# - "requested_token_use" : OBO indicator ("on_behalf_of"), not an extra param
# - "client_assertion" : Client authentication credential (RFC 7521 §4.2)
# - "client_assertion_type" : Client authentication type (RFC 7521 §4.2)
# - "client_secret" : Client authentication secret
# - "token_type" : Used for SSH-cert/POP detection; AT entry stores separately
# - "req_cnf" : Ephemeral proof-of-possession nonce, changes per request
# - "key_id" : Already handled as a separate cache lookup field
# - "user_federated_identity_credential", "user_id", "client_info" :
#       Standard body parameters of the user_fic flow; FIC tokens use the
#       normal user cache keys (not extended).
#
# Examples of fields that ARE hashed (anything absent from this set):
# - "fmi_path" : Federated Managed Identity credential path
# - any future non-standard body parameter that should isolate cache entries
_EXT_CACHE_KEY_EXCLUDED_FIELDS = frozenset((
    # Standard OAuth2 body parameters. They show up in every token request,
    # so they must not influence the extended cache key; only non-standard
    # fields (e.g. fmi_path) are meant to contribute to the hash.
    "client_id", "grant_type", "scope", "claims",
    "username", "password",
    "refresh_token",
    "code", "redirect_uri", "code_verifier",
    "device_code",
    "assertion", "requested_token_use",
    "client_assertion", "client_assertion_type", "client_secret",
    "token_type", "req_cnf", "key_id",
    # Parameters of the user_fic grant.
    "user_federated_identity_credential", "user_id", "client_info",
))
def _compute_ext_cache_key(data):
    """Derive the extended cache key from the extra body parameters in *data*.

    Every entry of *data* takes part in the hash unless its key is listed in
    ``_EXT_CACHE_KEY_EXCLUDED_FIELDS`` or its value is falsy. Tokens acquired
    with different parameter values (e.g., different FMI paths) therefore end
    up with different keys and are cached separately.

    The algorithm matches the Go MSAL implementation (CacheExtKeyGenerator):
    key+value pairs, ordered by key, are concatenated, SHA256 hashed and then
    base64url encoded.

    :param data: The request body dict, or a falsy value such as ``None``.
    :return: The lower-cased, unpadded base64url digest,
        or an empty string when *data* has no hashable fields.
    """
    if not data:
        return ""
    # Keys are unique, so sorting the (key, value) pairs orders them by key
    # alone, which keeps the hash independent of dict insertion order.
    pairs = sorted(
        (name, str(value))
        for name, value in data.items()
        if name not in _EXT_CACHE_KEY_EXCLUDED_FIELDS and value
    )
    if not pairs:
        return ""
    material = "".join(name + text for name, text in pairs)
    digest = hashlib.sha256(material.encode("utf-8")).digest()
    unpadded = base64.urlsafe_b64encode(digest).rstrip(b"=")
    return unpadded.decode("ascii").lower()
def is_subdict_of(small, big):
    """Tell whether every key/value pair of *small* is also present in *big*.

    Unlike the ``dict(big, **small) == big`` idiom, this neither copies *big*
    nor requires the keys of *small* to be strings.

    :param small: The dict whose items are looked for.
    :param big: The dict that is expected to contain them.
    :return: ``True`` when *small* is empty or all of its items are in *big*.
    """
    missing = object()  # Sentinel, so that a stored None is not mistaken for absence
    for key, expected in small.items():
        actual = big.get(key, missing)
        # Identity first, then equality: the same rule dict comparison applies
        if not (actual is expected or actual == expected):
            return False
    return True
def _get_username(id_token_claims):
    """Pick the username out of the given ID token claims.

    The "preferred_username" claim (AAD) wins whenever it is present,
    whatever its value; otherwise the "upn" claim (ADFS 2019) is used.

    :param id_token_claims: A dict of ID token claims.
    :return: The chosen claim value, or ``None`` when neither claim exists.
    """
    if "preferred_username" in id_token_claims:  # AAD
        return id_token_claims["preferred_username"]
    return id_token_claims.get("upn")  # ADFS 2019
class TokenCache(object):
    """This is considered as a base class containing minimal cache behavior.

    Although it maintains tokens using unified schema across all MSAL libraries,
    this class does not serialize/persist them.
    See subclass :class:`SerializableTokenCache` for details on serialization.

    Entries live in memory as ``{credential_type: {key: entry}}``, where each
    ``key`` is produced by the callable registered for that credential type
    in ``self.key_makers``.
    """

    class CredentialType:
        ACCESS_TOKEN = "AccessToken"
        ACCESS_TOKEN_EXTENDED = "atext"  # Used when ext_cache_key is present (matches Go/dotnet)
        REFRESH_TOKEN = "RefreshToken"
        ACCOUNT = "Account"  # Not exactly a credential type, but we put it here
        ID_TOKEN = "IdToken"
        APP_METADATA = "AppMetadata"

    class AuthorityType:
        ADFS = "ADFS"
        MSSTS = "MSSTS"  # MSSTS means AAD v2 for both AAD & MSA

    def __init__(self):
        self._lock = threading.RLock()  # Guards access to self._cache
        self._cache = {}  # {credential_type: {key: entry}}
        # Maps a credential type to a callable which accepts an entry's fields
        # as keyword arguments and returns that entry's cache key.
        self.key_makers = {
            # Note: We have changed token key format before when ordering scopes;
            # changing token key won't result in cache miss.
            self.CredentialType.REFRESH_TOKEN:
                lambda home_account_id=None, environment=None, client_id=None,
                        target=None, **ignored_payload_from_a_real_token:
                    "-".join([
                        home_account_id or "",
                        environment or "",
                        self.CredentialType.REFRESH_TOKEN,
                        client_id or "",
                        "",  # RT is cross-tenant in AAD
                        target or "",  # raw value could be None if deserialized from other SDK
                    ]).lower(),
            self.CredentialType.ACCESS_TOKEN:
                lambda home_account_id=None, environment=None, client_id=None,
                        realm=None, target=None,
                        ext_cache_key=None,
                        # Note: New field(s) can be added here
                        #key_id=None,
                        **ignored_payload_from_a_real_token:
                    "-".join([  # Note: Could use a hash here to shorten key length
                        home_account_id or "",
                        environment or "",
                        # Use the "atext" credential type when ext_cache_key is
                        # present, matching MSAL Go and MSAL .NET behaviour.
                        # The constants are used (like the sibling key makers
                        # do); the key is lower-cased as a whole below.
                        self.CredentialType.ACCESS_TOKEN_EXTENDED
                            if ext_cache_key
                            else self.CredentialType.ACCESS_TOKEN,
                        client_id or "",
                        realm or "",
                        target or "",
                        #key_id or "",  # So ATs of different key_id can coexist
                    ] + ([ext_cache_key] if ext_cache_key else [])
                    ).lower(),
            self.CredentialType.ID_TOKEN:
                lambda home_account_id=None, environment=None, client_id=None,
                        realm=None, **ignored_payload_from_a_real_token:
                    "-".join([
                        home_account_id or "",
                        environment or "",
                        self.CredentialType.ID_TOKEN,
                        client_id or "",
                        realm or "",
                        ""  # Albeit irrelevant, schema requires an empty scope here
                    ]).lower(),
            self.CredentialType.ACCOUNT:
                lambda home_account_id=None, environment=None, realm=None,
                        **ignored_payload_from_a_real_entry:
                    "-".join([
                        home_account_id or "",
                        environment or "",
                        realm or "",
                    ]).lower(),
            self.CredentialType.APP_METADATA:
                lambda environment=None, client_id=None, **kwargs:
                    "appmetadata-{}-{}".format(environment or "", client_id or ""),
        }

    def _get_access_token(
        self,
        home_account_id, environment, client_id, realm, target,  # Together they form a compound key
        ext_cache_key=None,
        default=None,
    ):  # O(1)
        """Look up one access token entry directly by its compound key.

        :param target: An iterable of scope strings; they are joined by a
            space, in the given order, to form the "target" part of the key.
        :param ext_cache_key: Optional extended cache key of the entry.
        :param default: The value to return when no such entry exists.
        """
        return self._get(
            self.CredentialType.ACCESS_TOKEN,
            self.key_makers[TokenCache.CredentialType.ACCESS_TOKEN](
                home_account_id=home_account_id,
                environment=environment,
                client_id=client_id,
                realm=realm,
                target=" ".join(target),
                ext_cache_key=ext_cache_key,
            ),
            default=default)

    def _get_app_metadata(self, environment, client_id, default=None):  # O(1)
        """Look up the app metadata entry of the given environment and client."""
        return self._get(
            self.CredentialType.APP_METADATA,
            self.key_makers[TokenCache.CredentialType.APP_METADATA](
                environment=environment,
                client_id=client_id,
            ),
            default=default)

    def _get(self, credential_type, key, default=None):  # O(1)
        """Return the entry stored under *key*, or *default* when absent."""
        with self._lock:
            return self._cache.get(credential_type, {}).get(key, default)

    @staticmethod
    def _is_matching(entry: dict, query: dict, target_set: set = None) -> bool:
        """Tell whether *entry* satisfies *query* and covers *target_set*.

        :param entry: A cache entry.
        :param query: Key/value pairs which must all be present in *entry*.
            A falsy query matches every entry.
        :param target_set: Scopes which must all appear in the entry's
            space-separated "target". A falsy value skips this check.
        """
        query_with_lowercase_environment = {
            # __add() canonicalized entry's environment value to lower case,
            # so we do the same here.
            k: v.lower() if k == "environment" and isinstance(v, str) else v
            for k, v in query.items()
            } if query else {}
        return is_subdict_of(query_with_lowercase_environment, entry) and (
            target_set <= set(entry.get("target", "").split())
            if target_set else True)

    def search(self, credential_type, target=None, query=None, *, now=None):  # O(n) generator
        """Returns a generator of matching entries.

        It is O(1) for AT hits, and O(n) for other types.
        Note that it holds a lock during the entire search.

        :param credential_type: One of the :class:`CredentialType` values.
        :param target: Scopes that a matching entry must contain.
        :param query: Key/value pairs that a matching entry must contain.
        :param now: Current time in epoch seconds, used for detecting expired
            access tokens. Defaults to ``time.time()``.

        Expired access tokens met during the scan are not yielded; they are
        removed from the cache once the scan has been run to completion.
        """
        target = sorted(target or [])  # Match the order sorted by add()
        # sorted() always returns a list, so this merely states the invariant
        assert isinstance(target, list), "Invalid parameter type"

        preferred_result = None
        if (credential_type == self.CredentialType.ACCESS_TOKEN
                and isinstance(query, dict)
                and "home_account_id" in query and "environment" in query
                and "client_id" in query and "realm" in query and target
                ):  # Special case for O(1) AT lookup
            preferred_result = self._get_access_token(
                query["home_account_id"], query["environment"],
                query["client_id"], query["realm"], target,
                ext_cache_key=query.get("ext_cache_key"))
            if preferred_result and self._is_matching(
                preferred_result, query,
                # Needs no target_set here because it is satisfied by dict key
            ):
                yield preferred_result

        target_set = set(target)
        with self._lock:
            # O(n) search. The key is NOT used in search.
            now = int(time.time() if now is None else now)
            expired_access_tokens = [
                # Especially when/if we key ATs by ephemeral fields such as key_id,
                # stale ATs keyed by an old key_id would stay forever.
                # Here we collect them for their removal.
            ]
            for entry in self._cache.get(credential_type, {}).values():
                if (  # Automatically delete expired access tokens
                    credential_type == self.CredentialType.ACCESS_TOKEN
                    and int(entry["expires_on"]) < now
                ):
                    expired_access_tokens.append(entry)  # Can't delete them within current for-loop
                    continue
                if (entry != preferred_result  # Avoid yielding the same entry twice
                    and self._is_matching(entry, query, target_set=target_set)
                ):
                    # Cache isolation for extended cache keys (e.g., FMI path).
                    # Entries with ext_cache_key must not match queries without one.
                    if (credential_type == self.CredentialType.ACCESS_TOKEN
                        and "ext_cache_key" in entry
                        and "ext_cache_key" not in (query or {})
                    ):
                        continue
                    yield entry
            for at in expired_access_tokens:
                self.remove_at(at)

    def find(self, credential_type, target=None, query=None, *, now=None):
        """Equivalent to list(search(...))."""
        warnings.warn(
            "Use list(search(...)) instead to explicitly get a list.",
            DeprecationWarning,
            stacklevel=2,  # Attribute the warning to the caller of find()
        )
        return list(self.search(credential_type, target=target, query=query, now=now))

    def add(self, event, now=None):
        """Handle a token obtaining event, and add tokens into cache.

        A copy of the event, with its sensitive fields masked, is written to
        the debug log before the tokens are stored.

        :param event: A dict describing the token request and its response.
        :param now: Current time in epoch seconds. Defaults to ``time.time()``.
        """
        def make_clean_copy(dictionary, sensitive_fields):  # Masks sensitive info
            return {
                k: "********" if k in sensitive_fields else v
                for k, v in dictionary.items()
                }
        clean_event = dict(
            event,
            data=make_clean_copy(event.get("data", {}), (
                "password", "client_secret", "refresh_token", "assertion",
                "user_federated_identity_credential",
                )),
            response=make_clean_copy(event.get("response", {}), (
                "id_token_claims",  # Provided by broker
                "access_token", "refresh_token", "id_token", "username",
                )),
            )
        logger.debug("event=%s", json.dumps(
        # We examined and concluded that this log won't have Log Injection risk,
        # because the event payload is already in JSON so CR/LF will be escaped.
            clean_event,
            indent=4, sort_keys=True,
            default=str,  # assertion is in bytes in Python 3
            ))
        return self.__add(event, now=now)

    def __parse_account(self, response, id_token_claims):
        """Return client_info and home_account_id"""
        if "client_info" in response:  # It happens when client_info and profile are in request
            client_info = json.loads(decode_part(response["client_info"]))
            if "uid" in client_info and "utid" in client_info:
                return client_info, "{uid}.{utid}".format(**client_info)
            # https://github.com/AzureAD/microsoft-authentication-library-for-python/issues/387
        if id_token_claims:  # This would be an end user on ADFS-direct scenario
            sub = id_token_claims["sub"]  # "sub" always exists, per OIDC specs
            return {"uid": sub}, sub
        # client_credentials flow will reach this code path
        return {}, None

    def __add(self, event, now=None):
        """Store whatever tokens the event carries, plus account and app metadata.

        All the cache writes happen while holding the lock.
        """
        # event typically contains: client_id, scope, token_endpoint,
        # response, params, data, grant_type
        environment = realm = None
        if "token_endpoint" in event:
            _, environment, realm = canonicalize(event["token_endpoint"])
        if "environment" in event:  # Always available unless in legacy test cases
            environment = event["environment"]  # Set by application.py
        response = event.get("response", {})
        data = event.get("data", {})
        access_token = response.get("access_token")
        refresh_token = response.get("refresh_token")
        id_token = response.get("id_token")
        id_token_claims = response.get("id_token_claims") or (  # Prefer the claims from broker
            # Only use decode_id_token() when necessary, it contains time-sensitive validation
            decode_id_token(id_token, client_id=event["client_id"]) if id_token else {})
        client_info, home_account_id = self.__parse_account(response, id_token_claims)

        target = ' '.join(sorted(event.get("scope") or []))  # Schema should have required sorting

        with self._lock:
            now = int(time.time() if now is None else now)

            if access_token:
                default_expires_in = (  # https://www.rfc-editor.org/rfc/rfc6749#section-5.1
                    int(response.get("expires_on")) - now  # Some Managed Identity emits this
                    ) if response.get("expires_on") else 600
                expires_in = int(  # AADv1-like endpoint returns a string
                    response.get("expires_in", default_expires_in))
                ext_expires_in = int(  # AADv1-like endpoint returns a string
                    response.get("ext_expires_in", expires_in))
                at = {
                    "credential_type": self.CredentialType.ACCESS_TOKEN,
                    "secret": access_token,
                    "home_account_id": home_account_id,
                    "environment": environment,
                    "client_id": event.get("client_id"),
                    "target": target,
                    "realm": realm,
                    "token_type": response.get("token_type", "Bearer"),
                    "cached_at": str(now),  # Schema defines it as a string
                    "expires_on": str(now + expires_in),  # Same here
                    "extended_expires_on": str(now + ext_expires_in)  # Same here
                    }
                at.update({k: data[k] for k in data if k in {
                    # Also store extra data which we explicitly allow
                    # So that we won't accidentally store a user's password etc.
                    "key_id",  # It happens in SSH-cert or POP scenario
                    }})
                # Compute and store extended cache key for cache isolation
                # (e.g., different FMI paths should have separate cache entries)
                ext_cache_key = _compute_ext_cache_key(data)

                if ext_cache_key:
                    at["ext_cache_key"] = ext_cache_key
                if "refresh_in" in response:
                    # Coerced like expires_in above, so that a numeric string
                    # does not break the addition.
                    refresh_in = int(response["refresh_in"])
                    at["refresh_on"] = str(now + refresh_in)  # Schema wants a string
                self.modify(self.CredentialType.ACCESS_TOKEN, at, at)

            if client_info and not event.get("skip_account_creation"):
                account = {
                    "home_account_id": home_account_id,
                    "environment": environment,
                    "realm": realm,
                    "local_account_id": event.get(
                        "_account_id",  # Came from mid-tier code path.
                            # Emperically, it is the oid in AAD or cid in MSA.
                        id_token_claims.get("oid", id_token_claims.get("sub"))),
                    "username": _get_username(id_token_claims)
                        or data.get("username")  # Falls back to ROPC username
                        or event.get("username")  # Falls back to Federated ROPC username
                        or "",  # The schema does not like null
                    "authority_type": event.get(
                        "authority_type",  # Honor caller's choice of authority_type
                        self.AuthorityType.ADFS if realm == "adfs"
                            else self.AuthorityType.MSSTS),
                    # "client_info": response.get("client_info"),  # Optional
                    }
                grant_types_that_establish_an_account = (
                    _GRANT_TYPE_BROKER, "authorization_code", "password",
                    Client.DEVICE_FLOW["GRANT_TYPE"], "user_fic")
                if event.get("grant_type") in grant_types_that_establish_an_account:
                    account["account_source"] = event["grant_type"]
                self.modify(self.CredentialType.ACCOUNT, account, account)

            if id_token:
                idt = {
                    "credential_type": self.CredentialType.ID_TOKEN,
                    "secret": id_token,
                    "home_account_id": home_account_id,
                    "environment": environment,
                    "realm": realm,
                    "client_id": event.get("client_id"),
                    # "authority": "it is optional",
                    }
                self.modify(self.CredentialType.ID_TOKEN, idt, idt)

            if refresh_token:
                rt = {
                    "credential_type": self.CredentialType.REFRESH_TOKEN,
                    "secret": refresh_token,
                    "home_account_id": home_account_id,
                    "environment": environment,
                    "client_id": event.get("client_id"),
                    "target": target,  # Optional per schema though
                    "last_modification_time": str(now),  # Optional. Schema defines it as a string.
                    }
                if "foci" in response:
                    rt["family_id"] = response["foci"]
                self.modify(self.CredentialType.REFRESH_TOKEN, rt, rt)

            app_metadata = {
                "client_id": event.get("client_id"),
                "environment": environment,
                }
            if "foci" in response:
                app_metadata["family_id"] = response.get("foci")
            self.modify(self.CredentialType.APP_METADATA, app_metadata, app_metadata)

    def modify(self, credential_type, old_entry, new_key_value_pairs=None):
        """Update *old_entry* with *new_key_value_pairs*, or remove it.

        :param credential_type: Selects the key maker and the cache section.
        :param old_entry: The entry to work on; its fields determine the key.
        :param new_key_value_pairs: Fields to merge on top of *old_entry*.
            When it is ``None`` (or empty), the entry is removed instead.
        """
        # Modify the specified old_entry with new_key_value_pairs,
        # or remove the old_entry if the new_key_value_pairs is None.

        # This helper exists to consolidate all token add/modify/remove behaviors,
        # so that the sub-classes will have only one method to work on,
        # instead of patching a pair of update_xx() and remove_xx() per type.
        # You can monkeypatch self.key_makers to support more types on-the-fly.
        key = self.key_makers[credential_type](**old_entry)
        with self._lock:
            if new_key_value_pairs:  # Update with them
                entries = self._cache.setdefault(credential_type, {})
                entries[key] = dict(
                    old_entry,  # Do not use entries[key] b/c it might not exist
                    **new_key_value_pairs)
            else:  # Remove old_entry
                self._cache.setdefault(credential_type, {}).pop(key, None)

    def remove_rt(self, rt_item):
        """Remove the given refresh token entry from the cache."""
        assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN
        return self.modify(self.CredentialType.REFRESH_TOKEN, rt_item)

    def update_rt(self, rt_item, new_rt):
        """Replace the secret of the given refresh token entry with *new_rt*."""
        assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN
        return self.modify(self.CredentialType.REFRESH_TOKEN, rt_item, {
            "secret": new_rt,
            "last_modification_time": str(int(time.time())),  # Optional. Schema defines it as a string.
            })

    def remove_at(self, at_item):
        """Remove the given access token entry from the cache."""
        assert at_item.get("credential_type") == self.CredentialType.ACCESS_TOKEN
        return self.modify(self.CredentialType.ACCESS_TOKEN, at_item)

    def remove_idt(self, idt_item):
        """Remove the given ID token entry from the cache."""
        assert idt_item.get("credential_type") == self.CredentialType.ID_TOKEN
        return self.modify(self.CredentialType.ID_TOKEN, idt_item)

    def remove_account(self, account_item):
        """Remove the given account entry from the cache."""
        assert "authority_type" in account_item
        return self.modify(self.CredentialType.ACCOUNT, account_item)
class SerializableTokenCache(TokenCache):
    """This serialization can be a starting point to implement your own persistence.

    This class does NOT actually persist the cache on disk/db/etc..
    Depending on your need,
    the following simple recipe for file-based, unencrypted persistence may be sufficient::

        import os, atexit, msal
        cache_filename = os.path.join(  # Persist cache into this file
            os.getenv(
                # Automatically wipe out the cache from Linux when user's ssh session ends.
                # See also https://github.com/AzureAD/microsoft-authentication-library-for-python/issues/690
                "XDG_RUNTIME_DIR", ""),
            "my_cache.bin")
        cache = msal.SerializableTokenCache()
        if os.path.exists(cache_filename):
            cache.deserialize(open(cache_filename, "r").read())
        atexit.register(lambda:
            open(cache_filename, "w").write(cache.serialize())
            # Hint: The following optional line persists only when state changed
            if cache.has_state_changed else None
            )
        app = msal.ClientApplication(..., token_cache=cache)
        ...

    Alternatively, you may use a more sophisticated cache persistence library,
    `MSAL Extensions <https://github.com/AzureAD/microsoft-authentication-extensions-for-python>`_,
    which provides token cache persistence with encryption, and more.

    :var bool has_state_changed:
        Indicates whether the cache state in the memory has changed since last
        :func:`~serialize` or :func:`~deserialize` call.
    """
    # Class-level default; add() and modify() set it to True on the instance,
    # serialize() and deserialize() set it back to False.
    has_state_changed = False

    def add(self, event, **kwargs):
        """Add tokens the way the base class does, then mark the state as changed."""
        super().add(event, **kwargs)
        self.has_state_changed = True

    def modify(self, credential_type, old_entry, new_key_value_pairs=None):
        """Modify or remove an entry the way the base class does, then mark the state as changed."""
        super().modify(credential_type, old_entry, new_key_value_pairs)
        self.has_state_changed = True

    def deserialize(self, state):
        # type: (Optional[str]) -> None
        """Deserialize the cache from a state previously obtained by serialize()"""
        with self._lock:
            if state:
                loaded = json.loads(state)
            else:  # None or an empty string means an empty cache
                loaded = {}
            self._cache = loaded
            self.has_state_changed = False  # reset

    def serialize(self):
        # type: () -> str
        """Serialize the current cache state into a string."""
        with self._lock:
            self.has_state_changed = False
            snapshot = json.dumps(self._cache, indent=4)
            return snapshot