Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.10/site-packages/msal/token_cache.py: 28%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

177 statements  

1import base64 

2import hashlib 

3import json 

4import threading 

5import time 

6import logging 

7import warnings 

8 

9from .authority import canonicalize 

10from .oauth2cli.oidc import decode_part, decode_id_token 

11from .oauth2cli.oauth2 import Client 

12 

13 

14logger = logging.getLogger(__name__) 

15_GRANT_TYPE_BROKER = "broker" 

16 

17# Fields in the request data dict that should NOT be included in the extended 

18# cache key hash. Everything else in data IS included, because those are extra 

19# body parameters going on the wire and must differentiate cached tokens. 

20# 

21# Excluded fields and reasons: 

22# - "client_id" : Standard OAuth2 client identifier, same for every request 

23# - "grant_type" : It is possible to combine grants to get tokens, e.g. obo + refresh_token, auth_code + refresh_token etc. 

24# - "scope" : Already represented as "target" in the AT cache key 

25# - "claims" : Handled separately; its presence forces a token refresh 

26# - "username" : Standard ROPC grant parameter. Tokens are cached by user ID (subject or oid+tid) instead 

27# - "password" : Standard ROPC grant parameter. Tokens are tied to credentials. 

28# - "refresh_token" : Standard refresh grant parameter 

29# - "code" : Standard authorization code grant parameter 

30# - "redirect_uri" : Standard authorization code grant parameter 

31# - "code_verifier" : Standard PKCE parameter 

32# - "device_code" : Standard device flow parameter 

33# - "assertion" : Standard OBO/SAML assertion (RFC 7521) 

34# - "requested_token_use" : OBO indicator ("on_behalf_of"), not an extra param 

35# - "client_assertion" : Client authentication credential (RFC 7521 §4.2) 

36# - "client_assertion_type" : Client authentication type (RFC 7521 §4.2) 

37# - "client_secret" : Client authentication secret 

38# - "token_type" : Used for SSH-cert/POP detection; AT entry stores separately 

39# - "req_cnf" : Ephemeral proof-of-possession nonce, changes per request 

40# - "key_id" : Already handled as a separate cache lookup field 

41# 

42# Included fields (examples — anything NOT in this set is included): 

43# - "fmi_path" : Federated Managed Identity credential path 

44# - any future non-standard body parameter that should isolate cache entries 

45_EXT_CACHE_KEY_EXCLUDED_FIELDS = frozenset({ 

46 # Standard OAuth2 body parameters — these appear in every token request 

47 # and must NOT influence the extended cache key. 

48 # Only non-standard fields (e.g. fmi_path) should contribute to the hash. 

49 "client_id", 

50 "grant_type", 

51 "scope", 

52 "claims", 

53 "username", 

54 "password", 

55 "refresh_token", 

56 "code", 

57 "redirect_uri", 

58 "code_verifier", 

59 "device_code", 

60 "assertion", 

61 "requested_token_use", 

62 "client_assertion", 

63 "client_assertion_type", 

64 "client_secret", 

65 "token_type", 

66 "req_cnf", 

67 "key_id", 

68 # user_fic grant parameters — these are standard body params for the 

69 # user_fic flow; FIC tokens use normal user cache keys (not extended). 

70 "user_federated_identity_credential", 

71 "user_id", 

72 "client_info", 

73}) 

74 

75 

76def _compute_ext_cache_key(data): 

77 """Compute an extended cache key hash from extra body parameters in *data*. 

78 

79 All fields in *data* that go on the wire are included in the hash, 

80 EXCEPT those listed in ``_EXT_CACHE_KEY_EXCLUDED_FIELDS``. 

81 This ensures tokens acquired with different parameter values 

82 (e.g., different FMI paths) are cached separately. 

83 

84 Returns an empty string when *data* has no hashable fields. 

85 

86 The algorithm matches the Go MSAL implementation (CacheExtKeyGenerator): 

87 sorted key+value pairs are concatenated and SHA256 hashed, then base64url encoded. 

88 """ 

89 if not data: 

90 return "" 

91 cache_components = { 

92 k: str(v) for k, v in data.items() 

93 if k not in _EXT_CACHE_KEY_EXCLUDED_FIELDS and v 

94 } 

95 if not cache_components: 

96 return "" 

97 # Sort keys for consistent hashing (matches Go implementation) 

98 key_str = "".join( 

99 k + cache_components[k] for k in sorted(cache_components.keys()) 

100 ) 

101 hash_bytes = hashlib.sha256(key_str.encode("utf-8")).digest() 

102 return base64.urlsafe_b64encode(hash_bytes).rstrip(b"=").decode("ascii").lower() 

103 

104 

def is_subdict_of(small, big):
    """Return True if every key/value pair of *small* also appears in *big*.

    The check overlays *small* onto a copy of *big*. The copy stays equal to
    *big* exactly when *small* neither adds a key nor changes a value.
    Keys of *small* must be strings, because they are passed as keyword
    arguments.
    """
    overlaid = dict(big)
    overlaid.update(**small)
    return overlaid == big

107 

108def _get_username(id_token_claims): 

109 return id_token_claims.get( 

110 "preferred_username", # AAD 

111 id_token_claims.get("upn")) # ADFS 2019 

112 

class TokenCache(object):
    """This is considered as a base class containing minimal cache behavior.

    Although it maintains tokens using unified schema across all MSAL libraries,
    this class does not serialize/persist them.
    See subclass :class:`SerializableTokenCache` for details on serialization.

    Entries live in ``self._cache``, a dict shaped like
    ``{credential_type: {key: entry}}``. Each key is produced by the matching
    function in ``self.key_makers``. Within this class, every write goes through
    :meth:`modify`, under a re-entrant lock.
    """

    class CredentialType:
        ACCESS_TOKEN = "AccessToken"
        ACCESS_TOKEN_EXTENDED = "atext"  # Used when ext_cache_key is present (matches Go/dotnet)
        REFRESH_TOKEN = "RefreshToken"
        ACCOUNT = "Account"  # Not exactly a credential type, but we put it here
        ID_TOKEN = "IdToken"
        APP_METADATA = "AppMetadata"

    class AuthorityType:
        ADFS = "ADFS"
        MSSTS = "MSSTS"  # MSSTS means AAD v2 for both AAD & MSA

    def __init__(self):
        self._lock = threading.RLock()
        self._cache = {}
        self.key_makers = {
            # Note: We have changed token key format before when ordering scopes;
            # changing token key won't result in cache miss.
            self.CredentialType.REFRESH_TOKEN:
                lambda home_account_id=None, environment=None, client_id=None,
                        target=None, **ignored_payload_from_a_real_token:
                    "-".join([
                        home_account_id or "",
                        environment or "",
                        self.CredentialType.REFRESH_TOKEN,
                        client_id or "",
                        "",  # RT is cross-tenant in AAD
                        target or "",  # raw value could be None if deserialized from other SDK
                    ]).lower(),
            self.CredentialType.ACCESS_TOKEN:
                lambda home_account_id=None, environment=None, client_id=None,
                        realm=None, target=None,
                        ext_cache_key=None,
                        # Note: New field(s) can be added here
                        #key_id=None,
                        **ignored_payload_from_a_real_token:
                    "-".join([  # Note: Could use a hash here to shorten key length
                        home_account_id or "",
                        environment or "",
                        # Use the "atext" credential type when ext_cache_key is
                        # present, matching MSAL Go and MSAL .NET behaviour.
                        self.CredentialType.ACCESS_TOKEN_EXTENDED if ext_cache_key
                            else self.CredentialType.ACCESS_TOKEN,
                        client_id or "",
                        realm or "",
                        target or "",
                        #key_id or "",  # So ATs of different key_id can coexist
                    ] + ([ext_cache_key] if ext_cache_key else [])
                    ).lower(),
            self.CredentialType.ID_TOKEN:
                lambda home_account_id=None, environment=None, client_id=None,
                        realm=None, **ignored_payload_from_a_real_token:
                    "-".join([
                        home_account_id or "",
                        environment or "",
                        self.CredentialType.ID_TOKEN,
                        client_id or "",
                        realm or "",
                        ""  # Albeit irrelevant, schema requires an empty scope here
                    ]).lower(),
            self.CredentialType.ACCOUNT:
                lambda home_account_id=None, environment=None, realm=None,
                        **ignored_payload_from_a_real_entry:
                    "-".join([
                        home_account_id or "",
                        environment or "",
                        realm or "",
                    ]).lower(),
            self.CredentialType.APP_METADATA:
                lambda environment=None, client_id=None, **kwargs:
                    "appmetadata-{}-{}".format(environment or "", client_id or ""),
        }

    def _get_access_token(
            self,
            home_account_id, environment, client_id, realm, target,  # Together they form a compound key
            ext_cache_key=None,
            default=None,
    ):  # O(1)
        """Return the AT stored under the exact compound key, or *default*.

        *target* is an iterable of scopes that is joined with spaces. Callers
        are expected to pass it already sorted, as :meth:`search` does.
        """
        return self._get(
            self.CredentialType.ACCESS_TOKEN,
            self.key_makers[TokenCache.CredentialType.ACCESS_TOKEN](
                home_account_id=home_account_id,
                environment=environment,
                client_id=client_id,
                realm=realm,
                target=" ".join(target),
                ext_cache_key=ext_cache_key,
            ),
            default=default)

    def _get_app_metadata(self, environment, client_id, default=None):  # O(1)
        """Return the app metadata entry for (environment, client_id), or *default*."""
        return self._get(
            self.CredentialType.APP_METADATA,
            self.key_makers[TokenCache.CredentialType.APP_METADATA](
                environment=environment,
                client_id=client_id,
            ),
            default=default)

    def _get(self, credential_type, key, default=None):  # O(1)
        """Look up one entry by credential type and precomputed key, under the lock."""
        with self._lock:
            return self._cache.get(credential_type, {}).get(key, default)

    @staticmethod
    def _is_matching(entry: dict, query: dict, target_set: set = None) -> bool:
        """Return True when *entry* contains every pair in *query* and covers *target_set*.

        :param entry: A cache entry.
        :param query: Key/value pairs that must all appear in *entry*. The
            "environment" value is lower-cased before comparing.
        :param target_set: Scopes that must all appear in the entry's
            space-separated "target". Empty or ``None`` skips this check.
        """
        query_with_lowercase_environment = {
            # __add() canonicalized entry's environment value to lower case,
            # so we do the same here.
            k: v.lower() if k == "environment" and isinstance(v, str) else v
            for k, v in query.items()
        } if query else {}
        return is_subdict_of(query_with_lowercase_environment, entry) and (
            # "target" may be present but None when the cache was written by
            # another SDK; treat that the same as an empty target.
            target_set <= set((entry.get("target") or "").split())
            if target_set else True)

    def search(self, credential_type, target=None, query=None, *, now=None):  # O(n) generator
        """Returns a generator of matching entries.

        Fast path: for an AccessToken query that specifies home_account_id,
        environment, client_id and realm, plus a non-empty *target*, the entry
        under that exact key is looked up in O(1) and yielded first. This
        yield happens without an expiry check.

        After that, all entries of *credential_type* are scanned in O(n):

        - Expired access tokens are skipped, and removed once the scan ends.
        - Access tokens that carry an ext_cache_key are yielded only when
          *query* also contains "ext_cache_key".

        The lock is held during the O(n) scan, i.e. until the generator is
        exhausted or closed.
        """
        target = sorted(target or [])  # Match the order sorted by add()
        assert isinstance(target, list), "Invalid parameter type"

        preferred_result = None
        if (credential_type == self.CredentialType.ACCESS_TOKEN
            and isinstance(query, dict)
            and "home_account_id" in query and "environment" in query
            and "client_id" in query and "realm" in query and target
        ):  # Special case for O(1) AT lookup
            preferred_result = self._get_access_token(
                query["home_account_id"], query["environment"],
                query["client_id"], query["realm"], target,
                ext_cache_key=query.get("ext_cache_key"))
            if preferred_result and self._is_matching(
                preferred_result, query,
                # Needs no target_set here because it is satisfied by dict key
            ):
                yield preferred_result

        target_set = set(target)
        with self._lock:
            # O(n) search. The key is NOT used in search.
            now = int(time.time() if now is None else now)
            expired_access_tokens = [
                # Especially when/if we key ATs by ephemeral fields such as key_id,
                # stale ATs keyed by an old key_id would stay forever.
                # Here we collect them for their removal.
            ]
            for entry in self._cache.get(credential_type, {}).values():
                if (  # Automatically delete expired access tokens
                    credential_type == self.CredentialType.ACCESS_TOKEN
                    and int(entry["expires_on"]) < now
                ):
                    expired_access_tokens.append(entry)  # Can't delete them within current for-loop
                    continue
                if (entry != preferred_result  # Avoid yielding the same entry twice
                    and self._is_matching(entry, query, target_set=target_set)
                ):
                    # Cache isolation for extended cache keys (e.g., FMI path).
                    # Entries with ext_cache_key must not match queries without one.
                    if (credential_type == self.CredentialType.ACCESS_TOKEN
                        and "ext_cache_key" in entry
                        and "ext_cache_key" not in (query or {})
                    ):
                        continue
                    yield entry
            for at in expired_access_tokens:
                self.remove_at(at)

    def find(self, credential_type, target=None, query=None, *, now=None):
        """Equivalent to list(search(...))."""
        warnings.warn(
            "Use list(search(...)) instead to explicitly get a list.",
            DeprecationWarning)
        return list(self.search(credential_type, target=target, query=query, now=now))

    def add(self, event, now=None):
        """Handle a token obtaining event, and add tokens into cache.

        A copy of *event* with credential-bearing fields masked is logged at
        DEBUG level before the tokens are stored.
        """
        def make_clean_copy(dictionary, sensitive_fields):  # Masks sensitive info
            return {
                k: "********" if k in sensitive_fields else v
                for k, v in dictionary.items()
            }
        clean_event = dict(
            event,
            data=make_clean_copy(event.get("data", {}), (
                "password", "client_secret", "refresh_token", "assertion",
                "client_assertion",  # A client authentication credential (RFC 7521 §4.2)
                "user_federated_identity_credential",
            )),
            response=make_clean_copy(event.get("response", {}), (
                "id_token_claims",  # Provided by broker
                "access_token", "refresh_token", "id_token", "username",
            )),
        )
        logger.debug("event=%s", json.dumps(
            # We examined and concluded that this log won't have Log Injection risk,
            # because the event payload is already in JSON so CR/LF will be escaped.
            clean_event,
            indent=4, sort_keys=True,
            default=str,  # assertion is in bytes in Python 3
        ))
        return self.__add(event, now=now)

    def __parse_account(self, response, id_token_claims):
        """Return client_info and home_account_id"""
        if "client_info" in response:  # It happens when client_info and profile are in request
            client_info = json.loads(decode_part(response["client_info"]))
            if "uid" in client_info and "utid" in client_info:
                return client_info, "{uid}.{utid}".format(**client_info)
        # https://github.com/AzureAD/microsoft-authentication-library-for-python/issues/387
        if id_token_claims:  # This would be an end user on ADFS-direct scenario
            sub = id_token_claims["sub"]  # "sub" always exists, per OIDC specs
            return {"uid": sub}, sub
        # client_credentials flow will reach this code path
        return {}, None

    def __add(self, event, now=None):
        # event typically contains: client_id, scope, token_endpoint,
        # response, params, data, grant_type
        environment = realm = None
        if "token_endpoint" in event:
            _, environment, realm = canonicalize(event["token_endpoint"])
        if "environment" in event:  # Always available unless in legacy test cases
            environment = event["environment"]  # Set by application.py
        response = event.get("response", {})
        data = event.get("data", {})
        access_token = response.get("access_token")
        refresh_token = response.get("refresh_token")
        id_token = response.get("id_token")
        id_token_claims = response.get("id_token_claims") or (  # Prefer the claims from broker
            # Only use decode_id_token() when necessary, it contains time-sensitive validation
            decode_id_token(id_token, client_id=event["client_id"]) if id_token else {})
        client_info, home_account_id = self.__parse_account(response, id_token_claims)

        target = ' '.join(sorted(event.get("scope") or []))  # Schema should have required sorting

        with self._lock:
            now = int(time.time() if now is None else now)

            if access_token:
                default_expires_in = (  # https://www.rfc-editor.org/rfc/rfc6749#section-5.1
                    int(response.get("expires_on")) - now  # Some Managed Identity emits this
                ) if response.get("expires_on") else 600
                expires_in = int(  # AADv1-like endpoint returns a string
                    response.get("expires_in", default_expires_in))
                ext_expires_in = int(  # AADv1-like endpoint returns a string
                    response.get("ext_expires_in", expires_in))
                at = {
                    "credential_type": self.CredentialType.ACCESS_TOKEN,
                    "secret": access_token,
                    "home_account_id": home_account_id,
                    "environment": environment,
                    "client_id": event.get("client_id"),
                    "target": target,
                    "realm": realm,
                    "token_type": response.get("token_type", "Bearer"),
                    "cached_at": str(now),  # Schema defines it as a string
                    "expires_on": str(now + expires_in),  # Same here
                    "extended_expires_on": str(now + ext_expires_in)  # Same here
                }
                at.update({k: data[k] for k in data if k in {
                    # Also store extra data which we explicitly allow
                    # So that we won't accidentally store a user's password etc.
                    "key_id",  # It happens in SSH-cert or POP scenario
                }})
                # Compute and store extended cache key for cache isolation
                # (e.g., different FMI paths should have separate cache entries)
                ext_cache_key = _compute_ext_cache_key(data)

                if ext_cache_key:
                    at["ext_cache_key"] = ext_cache_key
                if "refresh_in" in response:
                    # Coerce like expires_in above, in case an endpoint sends a string
                    refresh_in = int(response["refresh_in"])
                    at["refresh_on"] = str(now + refresh_in)  # Schema wants a string
                self.modify(self.CredentialType.ACCESS_TOKEN, at, at)

            if client_info and not event.get("skip_account_creation"):
                account = {
                    "home_account_id": home_account_id,
                    "environment": environment,
                    "realm": realm,
                    "local_account_id": event.get(
                        "_account_id",  # Came from mid-tier code path.
                        # Emperically, it is the oid in AAD or cid in MSA.
                        id_token_claims.get("oid", id_token_claims.get("sub"))),
                    "username": _get_username(id_token_claims)
                        or data.get("username")  # Falls back to ROPC username
                        or event.get("username")  # Falls back to Federated ROPC username
                        or "",  # The schema does not like null
                    "authority_type": event.get(
                        "authority_type",  # Honor caller's choice of authority_type
                        self.AuthorityType.ADFS if realm == "adfs"
                            else self.AuthorityType.MSSTS),
                    # "client_info": response.get("client_info"),  # Optional
                }
                grant_types_that_establish_an_account = (
                    _GRANT_TYPE_BROKER, "authorization_code", "password",
                    Client.DEVICE_FLOW["GRANT_TYPE"], "user_fic")
                if event.get("grant_type") in grant_types_that_establish_an_account:
                    account["account_source"] = event["grant_type"]
                self.modify(self.CredentialType.ACCOUNT, account, account)

            if id_token:
                idt = {
                    "credential_type": self.CredentialType.ID_TOKEN,
                    "secret": id_token,
                    "home_account_id": home_account_id,
                    "environment": environment,
                    "realm": realm,
                    "client_id": event.get("client_id"),
                    # "authority": "it is optional",
                }
                self.modify(self.CredentialType.ID_TOKEN, idt, idt)

            if refresh_token:
                rt = {
                    "credential_type": self.CredentialType.REFRESH_TOKEN,
                    "secret": refresh_token,
                    "home_account_id": home_account_id,
                    "environment": environment,
                    "client_id": event.get("client_id"),
                    "target": target,  # Optional per schema though
                    "last_modification_time": str(now),  # Optional. Schema defines it as a string.
                }
                if "foci" in response:
                    rt["family_id"] = response["foci"]
                self.modify(self.CredentialType.REFRESH_TOKEN, rt, rt)

            app_metadata = {
                "client_id": event.get("client_id"),
                "environment": environment,
            }
            if "foci" in response:
                app_metadata["family_id"] = response.get("foci")
            self.modify(self.CredentialType.APP_METADATA, app_metadata, app_metadata)

    def modify(self, credential_type, old_entry, new_key_value_pairs=None):
        # Modify the specified old_entry with new_key_value_pairs,
        # or remove the old_entry if the new_key_value_pairs is None.

        # This helper exists to consolidate all token add/modify/remove behaviors,
        # so that the sub-classes will have only one method to work on,
        # instead of patching a pair of update_xx() and remove_xx() per type.
        # You can monkeypatch self.key_makers to support more types on-the-fly.
        key = self.key_makers[credential_type](**old_entry)
        with self._lock:
            if new_key_value_pairs:  # Update with them
                entries = self._cache.setdefault(credential_type, {})
                entries[key] = dict(
                    old_entry,  # Do not use entries[key] b/c it might not exist
                    **new_key_value_pairs)
            else:  # Remove old_entry
                self._cache.setdefault(credential_type, {}).pop(key, None)

    def remove_rt(self, rt_item):
        assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN
        return self.modify(self.CredentialType.REFRESH_TOKEN, rt_item)

    def update_rt(self, rt_item, new_rt):
        """Replace the secret of *rt_item* with *new_rt* and stamp the modification time."""
        assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN
        return self.modify(self.CredentialType.REFRESH_TOKEN, rt_item, {
            "secret": new_rt,
            "last_modification_time": str(int(time.time())),  # Optional. Schema defines it as a string.
        })

    def remove_at(self, at_item):
        assert at_item.get("credential_type") == self.CredentialType.ACCESS_TOKEN
        return self.modify(self.CredentialType.ACCESS_TOKEN, at_item)

    def remove_idt(self, idt_item):
        assert idt_item.get("credential_type") == self.CredentialType.ID_TOKEN
        return self.modify(self.CredentialType.ID_TOKEN, idt_item)

    def remove_account(self, account_item):
        assert "authority_type" in account_item
        return self.modify(self.CredentialType.ACCOUNT, account_item)

498 

499 

class SerializableTokenCache(TokenCache):
    """A :class:`TokenCache` that can be dumped to, and restored from, a string.

    This class does NOT persist the cache anywhere by itself (disk, db, etc.).
    It provides :meth:`serialize`, :meth:`deserialize` and a
    ``has_state_changed`` flag, so that you can plug in your own storage.
    Depending on your need, a file-based, unencrypted recipe like this one
    may be sufficient::

        import os, atexit, msal
        cache_filename = os.path.join(  # Persist cache into this file
            os.getenv(
                # Automatically wipe out the cache from Linux when user's ssh session ends.
                # See also https://github.com/AzureAD/microsoft-authentication-library-for-python/issues/690
                "XDG_RUNTIME_DIR", ""),
            "my_cache.bin")
        cache = msal.SerializableTokenCache()
        if os.path.exists(cache_filename):
            with open(cache_filename, "r") as f:
                cache.deserialize(f.read())

        def _save_cache():
            # Hint: Only write when the state changed since the last (de)serialization
            if cache.has_state_changed:
                with open(cache_filename, "w") as f:
                    f.write(cache.serialize())
        atexit.register(_save_cache)

        app = msal.ClientApplication(..., token_cache=cache)
        ...

    Alternatively, you may use a more sophisticated cache persistence library,
    `MSAL Extensions <https://github.com/AzureAD/microsoft-authentication-extensions-for-python>`_,
    which provides token cache persistence with encryption, and more.

    :var bool has_state_changed:
        Indicates whether the cache state in the memory has changed since last
        :func:`~serialize` or :func:`~deserialize` call.
    """
    has_state_changed = False

    def add(self, event, **kwargs):
        """Store the tokens from *event* via the base class, then flag the cache as changed."""
        super().add(event, **kwargs)
        self.has_state_changed = True

    def modify(self, credential_type, old_entry, new_key_value_pairs=None):
        """Update or remove an entry via the base class, then flag the cache as changed."""
        super().modify(credential_type, old_entry, new_key_value_pairs)
        self.has_state_changed = True

    def deserialize(self, state):
        # type: (Optional[str]) -> None
        """Replace the in-memory cache with *state* previously produced by serialize().

        A falsy *state* (``None`` or ``""``) empties the cache.
        ``has_state_changed`` is cleared afterwards.
        """
        with self._lock:
            if state:
                self._cache = json.loads(state)
            else:
                self._cache = {}
            self.has_state_changed = False  # A fresh load counts as "in sync"

    def serialize(self):
        # type: () -> str
        """Return the current cache as an indented JSON string, clearing ``has_state_changed``."""
        with self._lock:
            self.has_state_changed = False
            return json.dumps(self._cache, indent=4)

557