1import json
2import threading
3import time
4import logging
5import warnings
6
7from .authority import canonicalize
8from .oauth2cli.oidc import decode_part, decode_id_token
9from .oauth2cli.oauth2 import Client
10
11
12logger = logging.getLogger(__name__)
13_GRANT_TYPE_BROKER = "broker"
14
def is_subdict_of(small, big):
    """Return True when every key/value pair of ``small`` also appears in ``big``."""
    # Overlaying small onto big is a no-op exactly when small is a sub-dict.
    overlay = dict(big, **small)
    return overlay == big
17
18def _get_username(id_token_claims):
19 return id_token_claims.get(
20 "preferred_username", # AAD
21 id_token_claims.get("upn")) # ADFS 2019
22
class TokenCache(object):
    """This is considered as a base class containing minimal cache behavior.

    Although it maintains tokens using unified schema across all MSAL libraries,
    this class does not serialize/persist them.
    See subclass :class:`SerializableTokenCache` for details on serialization.

    Internally, entries live in a two-level dict
    ``{credential_type: {compound_key: entry_dict}}`` guarded by a reentrant
    lock; compound keys are produced by the callables in ``self.key_makers``.
    """

    class CredentialType:
        # Discriminator strings used both as top-level partitions of the cache
        # and as each entry's "credential_type" field, per the unified schema.
        ACCESS_TOKEN = "AccessToken"
        REFRESH_TOKEN = "RefreshToken"
        ACCOUNT = "Account"  # Not exactly a credential type, but we put it here
        ID_TOKEN = "IdToken"
        APP_METADATA = "AppMetadata"

    class AuthorityType:
        ADFS = "ADFS"
        MSSTS = "MSSTS"  # MSSTS means AAD v2 for both AAD & MSA

    def __init__(self):
        self._lock = threading.RLock()  # Reentrant: methods may nest under one lock
        self._cache = {}  # {credential_type: {compound_key: entry_dict}}
        self.key_makers = {
            # Each maker builds the compound dict key for one credential type
            # from the fields of a real entry; extra fields are ignored, so a
            # whole entry can be splatted in as keyword arguments.
            # Note: We have changed token key format before when ordering scopes;
            # changing token key won't result in cache miss.
            self.CredentialType.REFRESH_TOKEN:
                lambda home_account_id=None, environment=None, client_id=None,
                        target=None, **ignored_payload_from_a_real_token:
                    "-".join([
                        home_account_id or "",
                        environment or "",
                        self.CredentialType.REFRESH_TOKEN,
                        client_id or "",
                        "",  # RT is cross-tenant in AAD
                        target or "",  # raw value could be None if deserialized from other SDK
                    ]).lower(),
            self.CredentialType.ACCESS_TOKEN:
                lambda home_account_id=None, environment=None, client_id=None,
                        realm=None, target=None,
                        # Note: New field(s) can be added here
                        #key_id=None,
                        **ignored_payload_from_a_real_token:
                    "-".join([  # Note: Could use a hash here to shorten key length
                        home_account_id or "",
                        environment or "",
                        self.CredentialType.ACCESS_TOKEN,
                        client_id or "",
                        realm or "",
                        target or "",
                        #key_id or "",  # So ATs of different key_id can coexist
                    ]).lower(),
            self.CredentialType.ID_TOKEN:
                lambda home_account_id=None, environment=None, client_id=None,
                        realm=None, **ignored_payload_from_a_real_token:
                    "-".join([
                        home_account_id or "",
                        environment or "",
                        self.CredentialType.ID_TOKEN,
                        client_id or "",
                        realm or "",
                        ""  # Albeit irrelevant, schema requires an empty scope here
                    ]).lower(),
            self.CredentialType.ACCOUNT:
                lambda home_account_id=None, environment=None, realm=None,
                        **ignored_payload_from_a_real_entry:
                    "-".join([
                        home_account_id or "",
                        environment or "",
                        realm or "",
                    ]).lower(),
            self.CredentialType.APP_METADATA:
                lambda environment=None, client_id=None, **kwargs:
                    "appmetadata-{}-{}".format(environment or "", client_id or ""),
        }

    def _get_access_token(
            self,
            home_account_id, environment, client_id, realm, target,  # Together they form a compound key
            default=None,
    ):  # O(1)
        """O(1) lookup of one access token entry by its full compound key.

        ``target`` is an iterable of scopes; it is space-joined here because
        the key maker expects a single string.
        """
        return self._get(
            self.CredentialType.ACCESS_TOKEN,
            self.key_makers[TokenCache.CredentialType.ACCESS_TOKEN](
                home_account_id=home_account_id,
                environment=environment,
                client_id=client_id,
                realm=realm,
                target=" ".join(target),
            ),
            default=default)

    def _get_app_metadata(self, environment, client_id, default=None):  # O(1)
        """O(1) lookup of the app metadata entry for (environment, client_id)."""
        return self._get(
            self.CredentialType.APP_METADATA,
            self.key_makers[TokenCache.CredentialType.APP_METADATA](
                environment=environment,
                client_id=client_id,
            ),
            default=default)

    def _get(self, credential_type, key, default=None):  # O(1)
        """Thread-safe direct lookup of one entry by its compound key."""
        with self._lock:
            return self._cache.get(credential_type, {}).get(key, default)

    @staticmethod
    def _is_matching(entry: dict, query: dict, target_set: set = None) -> bool:
        """Tell whether a cache entry satisfies the given query.

        The query matches by sub-dict containment (with the "environment"
        value lower-cased first); ``target_set``, when provided, must be a
        subset of the entry's space-delimited "target" scopes.
        """
        query_with_lowercase_environment = {
            # __add() canonicalized entry's environment value to lower case,
            # so we do the same here.
            k: v.lower() if k == "environment" and isinstance(v, str) else v
            for k, v in query.items()
        } if query else {}
        return is_subdict_of(query_with_lowercase_environment, entry) and (
            target_set <= set(entry.get("target", "").split())
            if target_set else True)

    def search(self, credential_type, target=None, query=None, *, now=None):  # O(n) generator
        """Returns a generator of matching entries.

        It is O(1) for AT hits, and O(n) for other types.
        Note that it holds a lock during the entire search.

        :param now: Unix timestamp used for AT expiry checks;
            defaults to the current time.
        """
        target = sorted(target or [])  # Match the order sorted by add()
        assert isinstance(target, list), "Invalid parameter type"

        preferred_result = None
        if (credential_type == self.CredentialType.ACCESS_TOKEN
                and isinstance(query, dict)
                and "home_account_id" in query and "environment" in query
                and "client_id" in query and "realm" in query and target
        ):  # Special case for O(1) AT lookup
            preferred_result = self._get_access_token(
                query["home_account_id"], query["environment"],
                query["client_id"], query["realm"], target)
            if preferred_result and self._is_matching(
                    preferred_result, query,
                    # Needs no target_set here because it is satisfied by dict key
                    ):
                # NOTE(review): this fast path yields without checking
                # expires_on; presumably callers validate expiry — confirm.
                yield preferred_result

        target_set = set(target)
        with self._lock:
            # O(n) search. The key is NOT used in search.
            now = int(time.time() if now is None else now)
            expired_access_tokens = [
                # Especially when/if we key ATs by ephemeral fields such as key_id,
                # stale ATs keyed by an old key_id would stay forever.
                # Here we collect them for their removal.
            ]
            for entry in self._cache.get(credential_type, {}).values():
                if (  # Automatically delete expired access tokens
                        credential_type == self.CredentialType.ACCESS_TOKEN
                        and int(entry["expires_on"]) < now
                ):
                    expired_access_tokens.append(entry)  # Can't delete them within current for-loop
                    continue
                if (entry != preferred_result  # Avoid yielding the same entry twice
                        and self._is_matching(entry, query, target_set=target_set)
                ):
                    yield entry
            for at in expired_access_tokens:
                self.remove_at(at)

    def find(self, credential_type, target=None, query=None, *, now=None):
        """Equivalent to list(search(...)). Deprecated; prefer search()."""
        warnings.warn(
            "Use list(search(...)) instead to explicitly get a list.",
            DeprecationWarning)
        return list(self.search(credential_type, target=target, query=query, now=now))

    def add(self, event, now=None):
        """Handle a token obtaining event, and add tokens into cache.

        Sensitive fields are masked before the event is logged; the real
        (unmasked) event is then forwarded to the private __add().
        """
        def make_clean_copy(dictionary, sensitive_fields):  # Masks sensitive info
            return {
                k: "********" if k in sensitive_fields else v
                for k, v in dictionary.items()
            }
        clean_event = dict(
            event,
            data=make_clean_copy(event.get("data", {}), (
                "password", "client_secret", "refresh_token", "assertion",
            )),
            response=make_clean_copy(event.get("response", {}), (
                "id_token_claims",  # Provided by broker
                "access_token", "refresh_token", "id_token", "username",
            )),
        )
        logger.debug("event=%s", json.dumps(
            # We examined and concluded that this log won't have Log Injection risk,
            # because the event payload is already in JSON so CR/LF will be escaped.
            clean_event,
            indent=4, sort_keys=True,
            default=str,  # assertion is in bytes in Python 3
        ))
        return self.__add(event, now=now)

    def __parse_account(self, response, id_token_claims):
        """Return client_info and home_account_id"""
        if "client_info" in response:  # It happens when client_info and profile are in request
            client_info = json.loads(decode_part(response["client_info"]))
            if "uid" in client_info and "utid" in client_info:
                return client_info, "{uid}.{utid}".format(**client_info)
            # https://github.com/AzureAD/microsoft-authentication-library-for-python/issues/387
        if id_token_claims:  # This would be an end user on ADFS-direct scenario
            sub = id_token_claims["sub"]  # "sub" always exists, per OIDC specs
            return {"uid": sub}, sub
        # client_credentials flow will reach this code path
        return {}, None

    def __add(self, event, now=None):
        """Derive and store AT/RT/IDT/account/app-metadata entries from an event.

        event typically contains: client_id, scope, token_endpoint,
        response, params, data, grant_type.
        """
        environment = realm = None
        if "token_endpoint" in event:
            _, environment, realm = canonicalize(event["token_endpoint"])
        if "environment" in event:  # Always available unless in legacy test cases
            environment = event["environment"]  # Set by application.py
        response = event.get("response", {})
        data = event.get("data", {})
        access_token = response.get("access_token")
        refresh_token = response.get("refresh_token")
        id_token = response.get("id_token")
        id_token_claims = response.get("id_token_claims") or (  # Prefer the claims from broker
            # Only use decode_id_token() when necessary, it contains time-sensitive validation
            decode_id_token(id_token, client_id=event["client_id"]) if id_token else {})
        client_info, home_account_id = self.__parse_account(response, id_token_claims)

        target = ' '.join(sorted(event.get("scope") or []))  # Schema should have required sorting

        with self._lock:
            now = int(time.time() if now is None else now)

            if access_token:
                default_expires_in = (  # https://www.rfc-editor.org/rfc/rfc6749#section-5.1
                    int(response.get("expires_on")) - now  # Some Managed Identity emits this
                ) if response.get("expires_on") else 600
                expires_in = int(  # AADv1-like endpoint returns a string
                    response.get("expires_in", default_expires_in))
                ext_expires_in = int(  # AADv1-like endpoint returns a string
                    response.get("ext_expires_in", expires_in))
                at = {
                    "credential_type": self.CredentialType.ACCESS_TOKEN,
                    "secret": access_token,
                    "home_account_id": home_account_id,
                    "environment": environment,
                    "client_id": event.get("client_id"),
                    "target": target,
                    "realm": realm,
                    "token_type": response.get("token_type", "Bearer"),
                    "cached_at": str(now),  # Schema defines it as a string
                    "expires_on": str(now + expires_in),  # Same here
                    "extended_expires_on": str(now + ext_expires_in)  # Same here
                }
                at.update({k: data[k] for k in data if k in {
                    # Also store extra data which we explicitly allow
                    # So that we won't accidentally store a user's password etc.
                    "key_id",  # It happens in SSH-cert or POP scenario
                }})
                if "refresh_in" in response:
                    refresh_in = response["refresh_in"]  # It is an integer
                    at["refresh_on"] = str(now + refresh_in)  # Schema wants a string
                self.modify(self.CredentialType.ACCESS_TOKEN, at, at)

            if client_info and not event.get("skip_account_creation"):
                account = {
                    "home_account_id": home_account_id,
                    "environment": environment,
                    "realm": realm,
                    "local_account_id": event.get(
                        "_account_id",  # Came from mid-tier code path.
                        # Emperically, it is the oid in AAD or cid in MSA.
                        id_token_claims.get("oid", id_token_claims.get("sub"))),
                    "username": _get_username(id_token_claims)
                        or data.get("username")  # Falls back to ROPC username
                        or event.get("username")  # Falls back to Federated ROPC username
                        or "",  # The schema does not like null
                    "authority_type": event.get(
                        "authority_type",  # Honor caller's choice of authority_type
                        self.AuthorityType.ADFS if realm == "adfs"
                            else self.AuthorityType.MSSTS),
                    # "client_info": response.get("client_info"),  # Optional
                }
                grant_types_that_establish_an_account = (
                    _GRANT_TYPE_BROKER, "authorization_code", "password",
                    Client.DEVICE_FLOW["GRANT_TYPE"])
                if event.get("grant_type") in grant_types_that_establish_an_account:
                    account["account_source"] = event["grant_type"]
                self.modify(self.CredentialType.ACCOUNT, account, account)

            if id_token:
                idt = {
                    "credential_type": self.CredentialType.ID_TOKEN,
                    "secret": id_token,
                    "home_account_id": home_account_id,
                    "environment": environment,
                    "realm": realm,
                    "client_id": event.get("client_id"),
                    # "authority": "it is optional",
                }
                self.modify(self.CredentialType.ID_TOKEN, idt, idt)

            if refresh_token:
                rt = {
                    "credential_type": self.CredentialType.REFRESH_TOKEN,
                    "secret": refresh_token,
                    "home_account_id": home_account_id,
                    "environment": environment,
                    "client_id": event.get("client_id"),
                    "target": target,  # Optional per schema though
                    "last_modification_time": str(now),  # Optional. Schema defines it as a string.
                }
                if "foci" in response:
                    rt["family_id"] = response["foci"]
                self.modify(self.CredentialType.REFRESH_TOKEN, rt, rt)

            app_metadata = {
                "client_id": event.get("client_id"),
                "environment": environment,
            }
            if "foci" in response:
                app_metadata["family_id"] = response.get("foci")
            self.modify(self.CredentialType.APP_METADATA, app_metadata, app_metadata)

    def modify(self, credential_type, old_entry, new_key_value_pairs=None):
        """Update old_entry with new_key_value_pairs, or remove it when None.

        This helper exists to consolidate all token add/modify/remove behaviors,
        so that the sub-classes will have only one method to work on,
        instead of patching a pair of update_xx() and remove_xx() per type.
        You can monkeypatch self.key_makers to support more types on-the-fly.
        """
        key = self.key_makers[credential_type](**old_entry)
        with self._lock:
            if new_key_value_pairs:  # Update with them
                entries = self._cache.setdefault(credential_type, {})
                entries[key] = dict(
                    old_entry,  # Do not use entries[key] b/c it might not exist
                    **new_key_value_pairs)
            else:  # Remove old_entry
                # pop(..., None) keeps removal idempotent when key is absent
                self._cache.setdefault(credential_type, {}).pop(key, None)

    def remove_rt(self, rt_item):
        """Remove a refresh token entry (as previously returned by search())."""
        assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN
        return self.modify(self.CredentialType.REFRESH_TOKEN, rt_item)

    def update_rt(self, rt_item, new_rt):
        """Replace the secret of an existing refresh token entry with new_rt."""
        assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN
        return self.modify(self.CredentialType.REFRESH_TOKEN, rt_item, {
            "secret": new_rt,
            "last_modification_time": str(int(time.time())),  # Optional. Schema defines it as a string.
        })

    def remove_at(self, at_item):
        """Remove an access token entry (as previously returned by search())."""
        assert at_item.get("credential_type") == self.CredentialType.ACCESS_TOKEN
        return self.modify(self.CredentialType.ACCESS_TOKEN, at_item)

    def remove_idt(self, idt_item):
        """Remove an ID token entry (as previously returned by search())."""
        assert idt_item.get("credential_type") == self.CredentialType.ID_TOKEN
        return self.modify(self.CredentialType.ID_TOKEN, idt_item)

    def remove_account(self, account_item):
        """Remove an account entry (as previously returned by search())."""
        assert "authority_type" in account_item
        return self.modify(self.CredentialType.ACCOUNT, account_item)
386
387
class SerializableTokenCache(TokenCache):
    """A token cache that can be saved to, and restored from, a string.

    Note that this class itself does NOT persist anything to disk/db/etc.;
    it merely offers :func:`serialize`/:func:`deserialize` plus a change flag,
    as a starting point for your own persistence. Depending on your need,
    the following simple recipe for file-based, unencrypted persistence may be sufficient::

        import os, atexit, msal
        cache_filename = os.path.join(  # Persist cache into this file
            os.getenv(
                # Automatically wipe out the cache from Linux when user's ssh session ends.
                # See also https://github.com/AzureAD/microsoft-authentication-library-for-python/issues/690
                "XDG_RUNTIME_DIR", ""),
            "my_cache.bin")
        cache = msal.SerializableTokenCache()
        if os.path.exists(cache_filename):
            cache.deserialize(open(cache_filename, "r").read())
        atexit.register(lambda:
            open(cache_filename, "w").write(cache.serialize())
            # Hint: The following optional line persists only when state changed
            if cache.has_state_changed else None
            )
        app = msal.ClientApplication(..., token_cache=cache)
        ...

    Alternatively, you may use a more sophisticated cache persistence library,
    `MSAL Extensions <https://github.com/AzureAD/microsoft-authentication-extensions-for-python>`_,
    which provides token cache persistence with encryption, and more.

    :var bool has_state_changed:
        Indicates whether the cache state in the memory has changed since last
        :func:`~serialize` or :func:`~deserialize` call.
    """
    # Class-level default; every mutation shadows it with True on the instance.
    has_state_changed = False

    def add(self, event, **kwargs):
        # Delegate to the base class, then record that the in-memory state moved.
        super().add(event, **kwargs)
        self.has_state_changed = True

    def modify(self, credential_type, old_entry, new_key_value_pairs=None):
        # All base-class add/update/remove operations funnel through here,
        # so flipping the flag here covers every kind of mutation.
        super().modify(credential_type, old_entry, new_key_value_pairs)
        self.has_state_changed = True

    def deserialize(self, state):
        # type: (Optional[str]) -> None
        """Deserialize the cache from a state previously obtained by serialize()"""
        with self._lock:
            self._cache = json.loads(state) if state else {}
            self.has_state_changed = False  # In-memory state now equals the loaded state

    def serialize(self):
        # type: () -> str
        """Serialize the current cache state into a string."""
        with self._lock:
            self.has_state_changed = False  # The returned snapshot reflects current state
            return json.dumps(self._cache, indent=4)
445