Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/msal/token_cache.py: 30%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import json
2import threading
3import time
4import logging
5import warnings
7from .authority import canonicalize
8from .oauth2cli.oidc import decode_part, decode_id_token
9from .oauth2cli.oauth2 import Client
# The grant_type value recorded for broker-obtained tokens; events carrying it
# are among those that establish an account (stored as "account_source").
_GRANT_TYPE_BROKER = "broker"

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def is_subdict_of(small, big):
    """Return True if every key/value pair of ``small`` is also present in ``big``.

    Values are compared the way dict equality compares them: identity first,
    then ``==``. Unlike the ``dict(big, **small) == big`` idiom, this accepts
    non-string keys (which cannot be passed as ``**kwargs``). It also does not
    copy ``big``, so the cost is proportional to ``len(small)`` only.
    """
    missing = object()  # Distinguishes "absent" from a stored None
    for key, value in small.items():
        candidate = big.get(key, missing)
        if candidate is missing:
            return False
        if not (value is candidate or value == candidate):
            return False
    return True
def _get_username(id_token_claims):
    """Pick a username out of a dict of ID token claims.

    ``preferred_username`` (AAD) wins whenever that key is present, even if
    its value is falsy. Otherwise ``upn`` (ADFS 2019) is used, and None is
    returned when neither key exists.
    """
    upn = id_token_claims.get("upn")  # ADFS 2019
    return id_token_claims.get("preferred_username", upn)  # AAD
class TokenCache(object):
    """This is considered as a base class containing minimal cache behavior.

    Although it maintains tokens using unified schema across all MSAL libraries,
    this class does not serialize/persist them.
    See subclass :class:`SerializableTokenCache` for details on serialization.

    Entries live in ``self._cache`` as a two-level dict,
    ``{credential_type: {key: entry_dict}}``. Each ``key`` is produced by the
    function registered for that type in ``self.key_makers``.
    """

    class CredentialType:
        ACCESS_TOKEN = "AccessToken"
        REFRESH_TOKEN = "RefreshToken"
        ACCOUNT = "Account"  # Not exactly a credential type, but we put it here
        ID_TOKEN = "IdToken"
        APP_METADATA = "AppMetadata"

    class AuthorityType:
        ADFS = "ADFS"
        MSSTS = "MSSTS"  # MSSTS means AAD v2 for both AAD & MSA

    def __init__(self):
        # Reentrant lock: __add() holds it while calling modify(), which locks again.
        self._lock = threading.RLock()
        self._cache = {}  # {credential_type: {key: entry}}
        # One key maker per credential type. Each accepts an entry's fields as
        # keyword arguments (ignoring the extra ones) and returns the string key
        # that the entry is stored under. All but the app-metadata key are
        # lower-cased.
        self.key_makers = {
            self.CredentialType.REFRESH_TOKEN:
                lambda home_account_id=None, environment=None, client_id=None,
                        target=None, **ignored_payload_from_a_real_token:
                    "-".join([
                        home_account_id or "",
                        environment or "",
                        self.CredentialType.REFRESH_TOKEN,
                        client_id or "",
                        "",  # RT is cross-tenant in AAD
                        target or "",  # raw value could be None if deserialized from other SDK
                        ]).lower(),
            self.CredentialType.ACCESS_TOKEN:
                lambda home_account_id=None, environment=None, client_id=None,
                        realm=None, target=None, **ignored_payload_from_a_real_token:
                    "-".join([
                        home_account_id or "",
                        environment or "",
                        self.CredentialType.ACCESS_TOKEN,
                        client_id or "",
                        realm or "",
                        target or "",
                        ]).lower(),
            self.CredentialType.ID_TOKEN:
                lambda home_account_id=None, environment=None, client_id=None,
                        realm=None, **ignored_payload_from_a_real_token:
                    "-".join([
                        home_account_id or "",
                        environment or "",
                        self.CredentialType.ID_TOKEN,
                        client_id or "",
                        realm or "",
                        ""  # Albeit irrelevant, schema requires an empty scope here
                        ]).lower(),
            self.CredentialType.ACCOUNT:
                lambda home_account_id=None, environment=None, realm=None,
                        **ignored_payload_from_a_real_entry:
                    "-".join([
                        home_account_id or "",
                        environment or "",
                        realm or "",
                        ]).lower(),
            self.CredentialType.APP_METADATA:
                lambda environment=None, client_id=None, **kwargs:
                    "appmetadata-{}-{}".format(environment or "", client_id or ""),
            }

    def _get_access_token(
            self,
            home_account_id, environment, client_id, realm, target,  # Together they form a compound key
            default=None,
            ):  # O(1)
        """Look up one access token by its exact compound key.

        ``target`` is a sequence of scopes, joined with single spaces to build
        the key. Returns ``default`` when there is no such entry.
        """
        return self._get(
            self.CredentialType.ACCESS_TOKEN,
            self.key_makers[TokenCache.CredentialType.ACCESS_TOKEN](
                home_account_id=home_account_id,
                environment=environment,
                client_id=client_id,
                realm=realm,
                target=" ".join(target),
                ),
            default=default)

    def _get_app_metadata(self, environment, client_id, default=None):  # O(1)
        """Look up the app-metadata entry for (environment, client_id)."""
        return self._get(
            self.CredentialType.APP_METADATA,
            self.key_makers[TokenCache.CredentialType.APP_METADATA](
                environment=environment,
                client_id=client_id,
                ),
            default=default)

    def _get(self, credential_type, key, default=None):  # O(1)
        """Return the entry stored under ``key`` for ``credential_type``, or ``default``."""
        with self._lock:
            return self._cache.get(credential_type, {}).get(key, default)

    @staticmethod
    def _is_matching(entry: dict, query: dict, target_set: set = None) -> bool:
        """Return True if ``entry`` contains all of ``query``.

        When ``target_set`` is non-empty, the scopes in the entry's
        space-separated ``"target"`` must also be a superset of it.
        """
        if not is_subdict_of(query or {}, entry):
            return False
        if not target_set:
            return True
        # "target" may be present but None in entries deserialized from other SDKs
        return target_set <= set((entry.get("target") or "").split())

    def search(self, credential_type, target=None, query=None):  # O(n) generator
        """Returns a generator of matching entries.

        When searching access tokens with a query that pins home_account_id,
        environment, client_id and realm, plus a non-empty target, the exact-key
        hit is found and yielded first in O(1). The remaining entries of the
        type are still scanned in O(n). The lock is held while that scan's
        generator is suspended.

        :param credential_type: One of the :class:`CredentialType` values.
        :param target: An iterable of scopes (e.g. a list), or None for any.
        :param query: A dict of key/value pairs a matching entry must contain.
        :raises TypeError: On first iteration, if ``target`` is a bare string.
        """
        if isinstance(target, str):
            # sorted() would otherwise silently split a string into characters.
            raise TypeError("Invalid parameter type")
        target = sorted(target or [])  # Match the order sorted by add()

        preferred_result = None
        if (credential_type == self.CredentialType.ACCESS_TOKEN
                and isinstance(query, dict)
                and "home_account_id" in query and "environment" in query
                and "client_id" in query and "realm" in query and target
                ):  # Special case for O(1) AT lookup
            preferred_result = self._get_access_token(
                query["home_account_id"], query["environment"],
                query["client_id"], query["realm"], target)
            if preferred_result and self._is_matching(
                    preferred_result, query,
                    # Needs no target_set here because it is satisfied by dict key
                    ):
                yield preferred_result

        target_set = set(target)
        with self._lock:
            # Since the target inside token cache key is (per schema) unsorted,
            # there is no point to attempt an O(1) key-value search here.
            # So we always do an O(n) in-memory search.
            for entry in self._cache.get(credential_type, {}).values():
                if (entry != preferred_result  # Avoid yielding the same entry twice
                        and self._is_matching(entry, query, target_set=target_set)
                        ):
                    yield entry

    def find(self, credential_type, target=None, query=None):
        """Equivalent to list(search(...))."""
        warnings.warn(
            "Use list(search(...)) instead to explicitly get a list.",
            DeprecationWarning,
            stacklevel=2)  # Attribute the warning to the caller, not to this method
        return list(self.search(credential_type, target=target, query=query))

    def add(self, event, now=None):
        """Handle a token obtaining event, and add tokens into cache.

        :param dict event: Typically contains client_id, scope, token_endpoint,
            response, params, data and grant_type.
        :param now: Optional epoch seconds used as "current time" when computing
            the timestamps stored in the entries. Defaults to ``time.time()``.
        """
        if logger.isEnabledFor(logging.DEBUG):
            # The masked JSON dump is not free to build, so only build it when
            # the debug record would actually be emitted.
            def make_clean_copy(dictionary, sensitive_fields):  # Masks sensitive info
                return {
                    k: "********" if k in sensitive_fields else v
                    for k, v in dictionary.items()
                    }
            clean_event = dict(
                event,
                data=make_clean_copy(event.get("data") or {}, (
                    "password", "client_secret", "refresh_token", "assertion",
                    )),
                response=make_clean_copy(event.get("response") or {}, (
                    "id_token_claims",  # Provided by broker
                    "access_token", "refresh_token", "id_token", "username",
                    )),
                )
            logger.debug("event=%s", json.dumps(
                # We examined and concluded that this log won't have Log Injection risk,
                # because the event payload is already in JSON so CR/LF will be escaped.
                clean_event,
                indent=4, sort_keys=True,
                default=str,  # assertion is in bytes in Python 3
                ))
        return self.__add(event, now=now)

    def __parse_account(self, response, id_token_claims):
        """Return client_info and home_account_id.

        The values come from ``client_info`` ("uid.utid") when the response has
        both uid and utid. Otherwise they come from the ID token's "sub" claim.
        ``({}, None)`` is returned when neither is available.
        """
        if "client_info" in response:  # It happens when client_info and profile are in request
            client_info = json.loads(decode_part(response["client_info"]))
            if "uid" in client_info and "utid" in client_info:
                return client_info, "{uid}.{utid}".format(**client_info)
        # https://github.com/AzureAD/microsoft-authentication-library-for-python/issues/387
        if id_token_claims:  # This would be an end user on ADFS-direct scenario
            sub = id_token_claims["sub"]  # "sub" always exists, per OIDC specs
            return {"uid": sub}, sub
        # client_credentials flow will reach this code path
        return {}, None

    def __add(self, event, now=None):
        # event typically contains: client_id, scope, token_endpoint,
        # response, params, data, grant_type
        environment = realm = None
        if "token_endpoint" in event:
            _, environment, realm = canonicalize(event["token_endpoint"])
        if "environment" in event:  # Always available unless in legacy test cases
            environment = event["environment"]  # Set by application.py
        # "or {}" also tolerates keys that are present but None
        response = event.get("response") or {}
        data = event.get("data") or {}
        access_token = response.get("access_token")
        refresh_token = response.get("refresh_token")
        id_token = response.get("id_token")
        id_token_claims = response.get("id_token_claims") or (  # Prefer the claims from broker
            # Only use decode_id_token() when necessary, it contains time-sensitive validation
            decode_id_token(id_token, client_id=event["client_id"]) if id_token else {})
        client_info, home_account_id = self.__parse_account(response, id_token_claims)

        target = ' '.join(sorted(event.get("scope") or []))  # Schema should have required sorting

        with self._lock:
            now = int(time.time() if now is None else now)

            if access_token:
                default_expires_in = (  # https://www.rfc-editor.org/rfc/rfc6749#section-5.1
                    int(response.get("expires_on")) - now  # Some Managed Identity emits this
                    ) if response.get("expires_on") else 600
                expires_in = int(  # AADv1-like endpoint returns a string
                    response.get("expires_in", default_expires_in))
                ext_expires_in = int(  # AADv1-like endpoint returns a string
                    response.get("ext_expires_in", expires_in))
                at = {
                    "credential_type": self.CredentialType.ACCESS_TOKEN,
                    "secret": access_token,
                    "home_account_id": home_account_id,
                    "environment": environment,
                    "client_id": event.get("client_id"),
                    "target": target,
                    "realm": realm,
                    "token_type": response.get("token_type", "Bearer"),
                    "cached_at": str(now),  # Schema defines it as a string
                    "expires_on": str(now + expires_in),  # Same here
                    "extended_expires_on": str(now + ext_expires_in)  # Same here
                    }
                if data.get("key_id"):  # It happens in SSH-cert or POP scenario
                    at["key_id"] = data.get("key_id")
                if "refresh_in" in response:
                    # Coerced like expires_in above, in case it arrives as a string
                    refresh_in = int(response["refresh_in"])
                    at["refresh_on"] = str(now + refresh_in)  # Schema wants a string
                self.modify(self.CredentialType.ACCESS_TOKEN, at, at)

            if client_info and not event.get("skip_account_creation"):
                account = {
                    "home_account_id": home_account_id,
                    "environment": environment,
                    "realm": realm,
                    "local_account_id": event.get(
                        "_account_id",  # Came from mid-tier code path.
                        # Empirically, it is the oid in AAD or cid in MSA.
                        id_token_claims.get("oid", id_token_claims.get("sub"))),
                    "username": _get_username(id_token_claims)
                        or data.get("username")  # Falls back to ROPC username
                        or event.get("username")  # Falls back to Federated ROPC username
                        or "",  # The schema does not like null
                    "authority_type": event.get(
                        "authority_type",  # Honor caller's choice of authority_type
                        self.AuthorityType.ADFS if realm == "adfs"
                            else self.AuthorityType.MSSTS),
                    # "client_info": response.get("client_info"),  # Optional
                    }
                grant_types_that_establish_an_account = (
                    _GRANT_TYPE_BROKER, "authorization_code", "password",
                    Client.DEVICE_FLOW["GRANT_TYPE"])
                if event.get("grant_type") in grant_types_that_establish_an_account:
                    account["account_source"] = event["grant_type"]
                self.modify(self.CredentialType.ACCOUNT, account, account)

            if id_token:
                idt = {
                    "credential_type": self.CredentialType.ID_TOKEN,
                    "secret": id_token,
                    "home_account_id": home_account_id,
                    "environment": environment,
                    "realm": realm,
                    "client_id": event.get("client_id"),
                    # "authority": "it is optional",
                    }
                self.modify(self.CredentialType.ID_TOKEN, idt, idt)

            if refresh_token:
                rt = {
                    "credential_type": self.CredentialType.REFRESH_TOKEN,
                    "secret": refresh_token,
                    "home_account_id": home_account_id,
                    "environment": environment,
                    "client_id": event.get("client_id"),
                    "target": target,  # Optional per schema though
                    "last_modification_time": str(now),  # Optional. Schema defines it as a string.
                    }
                if "foci" in response:
                    rt["family_id"] = response["foci"]
                self.modify(self.CredentialType.REFRESH_TOKEN, rt, rt)

            app_metadata = {
                "client_id": event.get("client_id"),
                "environment": environment,
                }
            if "foci" in response:
                app_metadata["family_id"] = response.get("foci")
            self.modify(self.CredentialType.APP_METADATA, app_metadata, app_metadata)

    def modify(self, credential_type, old_entry, new_key_value_pairs=None):
        """Modify the specified old_entry with new_key_value_pairs,
        or remove the old_entry if the new_key_value_pairs is None.

        An empty ``new_key_value_pairs`` dict is an update with nothing new.
        It (re)stores ``old_entry`` rather than removing it.

        This helper exists to consolidate all token add/modify/remove behaviors,
        so that the sub-classes will have only one method to work on,
        instead of patching a pair of update_xx() and remove_xx() per type.
        You can monkeypatch self.key_makers to support more types on-the-fly.
        """
        # The key always comes from old_entry, so an update stays in the same slot.
        key = self.key_makers[credential_type](**old_entry)
        with self._lock:
            if new_key_value_pairs is not None:  # Update with them
                entries = self._cache.setdefault(credential_type, {})
                entries[key] = dict(
                    old_entry,  # Do not use entries[key] b/c it might not exist
                    **new_key_value_pairs)
            else:  # Remove old_entry
                self._cache.setdefault(credential_type, {}).pop(key, None)

    def remove_rt(self, rt_item):
        """Remove a refresh token entry from the cache."""
        assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN
        return self.modify(self.CredentialType.REFRESH_TOKEN, rt_item)

    def update_rt(self, rt_item, new_rt):
        """Replace the secret of a refresh token entry and stamp its modification time."""
        assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN
        return self.modify(self.CredentialType.REFRESH_TOKEN, rt_item, {
            "secret": new_rt,
            "last_modification_time": str(int(time.time())),  # Optional. Schema defines it as a string.
            })

    def remove_at(self, at_item):
        """Remove an access token entry from the cache."""
        assert at_item.get("credential_type") == self.CredentialType.ACCESS_TOKEN
        return self.modify(self.CredentialType.ACCESS_TOKEN, at_item)

    def remove_idt(self, idt_item):
        """Remove an ID token entry from the cache."""
        assert idt_item.get("credential_type") == self.CredentialType.ID_TOKEN
        return self.modify(self.CredentialType.ID_TOKEN, idt_item)

    def remove_account(self, account_item):
        """Remove an account entry from the cache."""
        assert "authority_type" in account_item
        return self.modify(self.CredentialType.ACCOUNT, account_item)
class SerializableTokenCache(TokenCache):
    """This serialization can be a starting point to implement your own persistence.

    This class does NOT actually persist the cache on disk/db/etc..
    Depending on your need,
    the following simple recipe for file-based persistence may be sufficient::

        import os, atexit, msal
        cache_filename = os.path.join(  # Persist cache into this file
            os.getenv("XDG_RUNTIME_DIR", ""),  # Automatically wipe out the cache from Linux when user's ssh session ends. See also https://github.com/AzureAD/microsoft-authentication-library-for-python/issues/690
            "my_cache.bin")
        cache = msal.SerializableTokenCache()
        if os.path.exists(cache_filename):
            with open(cache_filename, "r") as f:
                cache.deserialize(f.read())
        def save_cache():
            # Hint: This optional check persists only when state changed
            if cache.has_state_changed:
                with open(cache_filename, "w") as f:
                    f.write(cache.serialize())
        atexit.register(save_cache)
        app = msal.ClientApplication(..., token_cache=cache)
        ...

    :var bool has_state_changed:
        Indicates whether the cache state in the memory has changed since last
        :func:`~serialize` or :func:`~deserialize` call.
    """
    has_state_changed = False

    def add(self, event, **kwargs):
        """Add tokens from an event (see the base class), then mark the cache dirty."""
        super(SerializableTokenCache, self).add(event, **kwargs)
        self.has_state_changed = True

    def modify(self, credential_type, old_entry, new_key_value_pairs=None):
        """Update or remove an entry (see the base class), then mark the cache dirty."""
        super(SerializableTokenCache, self).modify(
            credential_type, old_entry, new_key_value_pairs)
        self.has_state_changed = True

    def deserialize(self, state):
        # type: (Optional[str]) -> None
        """Deserialize the cache from a state previously obtained by serialize()

        An empty/None state, or a JSON document whose value is falsy
        (e.g. "null"), yields an empty cache.
        """
        with self._lock:
            # "or {}" keeps _cache a dict even if the stored JSON was "null"
            self._cache = (json.loads(state) if state else None) or {}
            self.has_state_changed = False  # reset

    def serialize(self):
        # type: () -> str
        """Serialize the current cache state into a string."""
        with self._lock:
            state = json.dumps(self._cache, indent=4)
            # Clear the dirty flag only after the dump succeeded, so a failed
            # serialization does not make unsaved changes look saved.
            self.has_state_changed = False
            return state