Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/msal/token_cache.py: 29%
126 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:20 +0000
1import json
2import threading
3import time
4import logging
6from .authority import canonicalize
7from .oauth2cli.oidc import decode_part, decode_id_token
# Module-level logger, named after this module so applications can tune it
# through the standard logging hierarchy.
logger = logging.getLogger(__name__)
def is_subdict_of(small, big):
    """Return True if every key/value pair of ``small`` is also in ``big``.

    An empty ``small`` is a sub-dict of anything.

    Unlike the ``dict(big, **small) == big`` formulation, this does not copy
    ``big`` on every call, and it does not require the keys of ``small`` to be
    strings (keyword-argument expansion raises TypeError for other key types).

    :param dict small: The candidate subset.
    :param dict big: The dict that is expected to contain ``small``.
    :rtype: bool
    """
    missing = object()  # Sentinel, so that a stored None is not "absent"
    for key, expected in small.items():
        actual = big.get(key, missing)
        if actual is missing:
            return False
        # Same rule dict equality applies to values: identity first, then ==
        if actual is not expected and not (actual == expected):
            return False
    return True
15def _get_username(id_token_claims):
16 return id_token_claims.get(
17 "preferred_username", # AAD
18 id_token_claims.get("upn")) # ADFS 2019
class TokenCache(object):
    """This is considered as a base class containing minimal cache behavior.

    Although it maintains tokens using unified schema across all MSAL libraries,
    this class does not serialize/persist them.
    See subclass :class:`SerializableTokenCache` for details on serialization.
    """

    class CredentialType:
        ACCESS_TOKEN = "AccessToken"
        REFRESH_TOKEN = "RefreshToken"
        ACCOUNT = "Account"  # Not exactly a credential type, but we put it here
        ID_TOKEN = "IdToken"
        APP_METADATA = "AppMetadata"

    class AuthorityType:
        ADFS = "ADFS"
        MSSTS = "MSSTS"  # MSSTS means AAD v2 for both AAD & MSA

    def __init__(self):
        """Create an empty in-memory cache and the per-type key makers."""
        self._lock = threading.RLock()
        self._cache = {}  # Layout: {credential_type: {key: entry_dict}}
        # Each key maker receives an entry's fields as keyword arguments and
        # returns that entry's key; fields it does not need are ignored.
        self.key_makers = {
            self.CredentialType.REFRESH_TOKEN:
                lambda home_account_id=None, environment=None, client_id=None,
                        target=None, **ignored_payload_from_a_real_token:
                    "-".join([
                        home_account_id or "",
                        environment or "",
                        self.CredentialType.REFRESH_TOKEN,
                        client_id or "",
                        "",  # RT is cross-tenant in AAD
                        target or "",  # raw value could be None if deserialized from other SDK
                        ]).lower(),
            self.CredentialType.ACCESS_TOKEN:
                lambda home_account_id=None, environment=None, client_id=None,
                        realm=None, target=None, **ignored_payload_from_a_real_token:
                    "-".join([
                        home_account_id or "",
                        environment or "",
                        self.CredentialType.ACCESS_TOKEN,
                        client_id or "",
                        realm or "",
                        target or "",
                        ]).lower(),
            self.CredentialType.ID_TOKEN:
                lambda home_account_id=None, environment=None, client_id=None,
                        realm=None, **ignored_payload_from_a_real_token:
                    "-".join([
                        home_account_id or "",
                        environment or "",
                        self.CredentialType.ID_TOKEN,
                        client_id or "",
                        realm or "",
                        ""  # Albeit irrelevant, schema requires an empty scope here
                        ]).lower(),
            self.CredentialType.ACCOUNT:
                lambda home_account_id=None, environment=None, realm=None,
                        **ignored_payload_from_a_real_entry:
                    "-".join([
                        home_account_id or "",
                        environment or "",
                        realm or "",
                        ]).lower(),
            self.CredentialType.APP_METADATA:
                lambda environment=None, client_id=None, **kwargs:
                    "appmetadata-{}-{}".format(environment or "", client_id or ""),
            }

    def find(self, credential_type, target=None, query=None):
        """Return the cached entries of one credential type that match.

        :param credential_type: One of the :class:`CredentialType` values.
        :param list target:
            Scopes that a matching entry's space-separated "target" must all
            contain. Empty or None means the target is not checked.
        :param dict query:
            Key/value pairs that a matching entry must contain.
            Empty or None matches every entry.
        :return: A list of the matching entry dicts.
        """
        target = target or []
        assert isinstance(target, list), "Invalid parameter type"
        target_set = set(target)
        with self._lock:
            # Since the target inside token cache key is (per schema) unsorted,
            # there is no point to attempt an O(1) key-value search here.
            # So we always do an O(n) in-memory search.
            return [entry
                for entry in self._cache.get(credential_type, {}).values()
                if is_subdict_of(query or {}, entry)
                and (
                    # "or" rather than a .get() default, because the stored
                    # target could be None if deserialized from other SDK
                    target_set <= set((entry.get("target") or "").split())
                    if target else True)
                ]

    def add(self, event, now=None):
        # type: (dict, Optional[float]) -> None
        """Handle a token obtaining event, and add tokens into cache.

        :param dict event:
            Typically contains: client_id, scope, token_endpoint,
            response, params, data, grant_type.
        :param now:
            The timestamp (in seconds) to record as the caching time.
            Defaults to the current time.
        """
        def make_clean_copy(dictionary, sensitive_fields):  # Masks sensitive info
            return {
                k: "********" if k in sensitive_fields else v
                for k, v in dictionary.items()
                }
        clean_event = dict(
            event,
            data=make_clean_copy(event.get("data", {}), (
                "password", "client_secret", "refresh_token", "assertion",
                )),
            response=make_clean_copy(event.get("response", {}), (
                "id_token_claims",  # Provided by broker
                "access_token", "refresh_token", "id_token", "username",
                )),
            )
        logger.debug("event=%s", json.dumps(
            # We examined and concluded that this log won't have Log Injection risk,
            # because the event payload is already in JSON so CR/LF will be escaped.
            clean_event,
            indent=4, sort_keys=True,
            default=str,  # assertion is in bytes in Python 3
            ))
        return self.__add(event, now=now)

    def __parse_account(self, response, id_token_claims):
        """Return client_info and home_account_id"""
        if "client_info" in response:  # It happens when client_info and profile are in request
            client_info = json.loads(decode_part(response["client_info"]))
            if "uid" in client_info and "utid" in client_info:
                return client_info, "{uid}.{utid}".format(**client_info)
            # https://github.com/AzureAD/microsoft-authentication-library-for-python/issues/387
        if id_token_claims:  # This would be an end user on ADFS-direct scenario
            sub = id_token_claims["sub"]  # "sub" always exists, per OIDC specs
            return {"uid": sub}, sub
        # client_credentials flow will reach this code path
        return {}, None

    def __add(self, event, now=None):
        """Store the tokens, account and app metadata carried by an event.

        Entries are written through :func:`modify`, under the cache lock.
        """
        # event typically contains: client_id, scope, token_endpoint,
        # response, params, data, grant_type
        environment = realm = None
        if "token_endpoint" in event:
            _, environment, realm = canonicalize(event["token_endpoint"])
        if "environment" in event:  # Always available unless in legacy test cases
            environment = event["environment"]  # Set by application.py
        response = event.get("response", {})
        data = event.get("data", {})
        access_token = response.get("access_token")
        refresh_token = response.get("refresh_token")
        id_token = response.get("id_token")
        id_token_claims = response.get("id_token_claims") or (  # Prefer the claims from broker
            # Only use decode_id_token() when necessary, it contains time-sensitive validation
            decode_id_token(id_token, client_id=event["client_id"]) if id_token else {})
        client_info, home_account_id = self.__parse_account(response, id_token_claims)

        target = ' '.join(event.get("scope") or [])  # Per schema, we don't sort it

        with self._lock:
            now = int(time.time() if now is None else now)

            if access_token:
                default_expires_in = (  # https://www.rfc-editor.org/rfc/rfc6749#section-5.1
                    int(response.get("expires_on")) - now  # Some Managed Identity emits this
                    ) if response.get("expires_on") else 600
                expires_in = int(  # AADv1-like endpoint returns a string
                    response.get("expires_in", default_expires_in))
                ext_expires_in = int(  # AADv1-like endpoint returns a string
                    response.get("ext_expires_in", expires_in))
                at = {
                    "credential_type": self.CredentialType.ACCESS_TOKEN,
                    "secret": access_token,
                    "home_account_id": home_account_id,
                    "environment": environment,
                    "client_id": event.get("client_id"),
                    "target": target,
                    "realm": realm,
                    "token_type": response.get("token_type", "Bearer"),
                    "cached_at": str(now),  # Schema defines it as a string
                    "expires_on": str(now + expires_in),  # Same here
                    "extended_expires_on": str(now + ext_expires_in)  # Same here
                    }
                if data.get("key_id"):  # It happens in SSH-cert or POP scenario
                    at["key_id"] = data.get("key_id")
                if "refresh_in" in response:
                    # Normally an integer; coerced like expires_in above, so
                    # that a string value does not break the addition below
                    refresh_in = int(response["refresh_in"])
                    at["refresh_on"] = str(now + refresh_in)  # Schema wants a string
                self.modify(self.CredentialType.ACCESS_TOKEN, at, at)

            if client_info and not event.get("skip_account_creation"):
                account = {
                    "home_account_id": home_account_id,
                    "environment": environment,
                    "realm": realm,
                    "local_account_id": event.get(
                        "_account_id",  # Came from mid-tier code path.
                        # Empirically, it is the oid in AAD or cid in MSA.
                        id_token_claims.get("oid", id_token_claims.get("sub"))),
                    "username": _get_username(id_token_claims)
                        or data.get("username")  # Falls back to ROPC username
                        or event.get("username")  # Falls back to Federated ROPC username
                        or "",  # The schema does not like null
                    "authority_type": event.get(
                        "authority_type",  # Honor caller's choice of authority_type
                        self.AuthorityType.ADFS if realm == "adfs"
                            else self.AuthorityType.MSSTS),
                    # "client_info": response.get("client_info"),  # Optional
                    }
                self.modify(self.CredentialType.ACCOUNT, account, account)

            if id_token:
                idt = {
                    "credential_type": self.CredentialType.ID_TOKEN,
                    "secret": id_token,
                    "home_account_id": home_account_id,
                    "environment": environment,
                    "realm": realm,
                    "client_id": event.get("client_id"),
                    # "authority": "it is optional",
                    }
                self.modify(self.CredentialType.ID_TOKEN, idt, idt)

            if refresh_token:
                rt = {
                    "credential_type": self.CredentialType.REFRESH_TOKEN,
                    "secret": refresh_token,
                    "home_account_id": home_account_id,
                    "environment": environment,
                    "client_id": event.get("client_id"),
                    "target": target,  # Optional per schema though
                    "last_modification_time": str(now),  # Optional. Schema defines it as a string.
                    }
                if "foci" in response:
                    rt["family_id"] = response["foci"]
                self.modify(self.CredentialType.REFRESH_TOKEN, rt, rt)

            app_metadata = {
                "client_id": event.get("client_id"),
                "environment": environment,
                }
            if "foci" in response:
                app_metadata["family_id"] = response.get("foci")
            self.modify(self.CredentialType.APP_METADATA, app_metadata, app_metadata)

    def modify(self, credential_type, old_entry, new_key_value_pairs=None):
        """Update or remove one cache entry.

        :param credential_type: One of the :class:`CredentialType` values.
        :param dict old_entry: The entry whose fields determine the cache key.
        :param dict new_key_value_pairs:
            Fields to overlay on ``old_entry`` before storing it.
            If empty or None, the entry is removed instead.
        """
        # Modify the specified old_entry with new_key_value_pairs,
        # or remove the old_entry if the new_key_value_pairs is None.

        # This helper exists to consolidate all token add/modify/remove behaviors,
        # so that the sub-classes will have only one method to work on,
        # instead of patching a pair of update_xx() and remove_xx() per type.
        # You can monkeypatch self.key_makers to support more types on-the-fly.
        key = self.key_makers[credential_type](**old_entry)
        with self._lock:
            if new_key_value_pairs:  # Update with them
                entries = self._cache.setdefault(credential_type, {})
                entries[key] = dict(
                    old_entry,  # Do not use entries[key] b/c it might not exist
                    **new_key_value_pairs)
            else:  # Remove old_entry
                self._cache.setdefault(credential_type, {}).pop(key, None)

    def remove_rt(self, rt_item):
        """Remove a refresh token entry from the cache."""
        assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN
        return self.modify(self.CredentialType.REFRESH_TOKEN, rt_item)

    def update_rt(self, rt_item, new_rt):
        """Replace the secret of a refresh token entry and stamp the change time."""
        assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN
        return self.modify(self.CredentialType.REFRESH_TOKEN, rt_item, {
            "secret": new_rt,
            "last_modification_time": str(int(time.time())),  # Optional. Schema defines it as a string.
            })

    def remove_at(self, at_item):
        """Remove an access token entry from the cache."""
        assert at_item.get("credential_type") == self.CredentialType.ACCESS_TOKEN
        return self.modify(self.CredentialType.ACCESS_TOKEN, at_item)

    def remove_idt(self, idt_item):
        """Remove an ID token entry from the cache."""
        assert idt_item.get("credential_type") == self.CredentialType.ID_TOKEN
        return self.modify(self.CredentialType.ID_TOKEN, idt_item)

    def remove_account(self, account_item):
        """Remove an account entry from the cache."""
        assert "authority_type" in account_item
        return self.modify(self.CredentialType.ACCOUNT, account_item)
class SerializableTokenCache(TokenCache):
    """This serialization can be a starting point to implement your own persistence.

    This class does NOT actually persist the cache on disk/db/etc..
    Depending on your need,
    the following simple recipe for file-based persistence may be sufficient::

        import os, atexit, msal
        cache = msal.SerializableTokenCache()
        if os.path.exists("my_cache.bin"):
            cache.deserialize(open("my_cache.bin", "r").read())
        atexit.register(lambda:
            open("my_cache.bin", "w").write(cache.serialize())
            # Hint: The following optional line persists only when state changed
            if cache.has_state_changed else None
            )
        app = msal.ClientApplication(..., token_cache=cache)
        ...

    :var bool has_state_changed:
        Indicates whether the cache state in the memory has changed since last
        :func:`~serialize` or :func:`~deserialize` call.
    """
    has_state_changed = False

    def add(self, event, **kwargs):
        """Add tokens from an event, then mark the cache state as changed."""
        super(SerializableTokenCache, self).add(event, **kwargs)
        self.has_state_changed = True

    def modify(self, credential_type, old_entry, new_key_value_pairs=None):
        """Update or remove an entry, then mark the cache state as changed."""
        super(SerializableTokenCache, self).modify(
            credential_type, old_entry, new_key_value_pairs)
        self.has_state_changed = True

    def deserialize(self, state):
        # type: (Optional[str]) -> None
        """Deserialize the cache from a state previously obtained by serialize()"""
        with self._lock:
            self._cache = json.loads(state) if state else {}
            self.has_state_changed = False  # reset

    def serialize(self):
        # type: () -> str
        """Serialize the current cache state into a string."""
        with self._lock:
            state = json.dumps(self._cache, indent=4)
            # Reset only after a successful dump, so that a failure above
            # does not make unsaved changes look like they were saved
            self.has_state_changed = False
            return state