Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.10/site-packages/msal/token_cache.py: 28%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

159 statements  

1import json 

2import threading 

3import time 

4import logging 

5import warnings 

6 

7from .authority import canonicalize 

8from .oauth2cli.oidc import decode_part, decode_id_token 

9from .oauth2cli.oauth2 import Client 

10 

11 

# Pseudo grant type recorded on events that were fulfilled by a broker;
# __add() treats it as one of the grant types that establish an account.
_GRANT_TYPE_BROKER = "broker"

# Module-level logger, named after this module so applications can tune it.
logger = logging.getLogger(__name__)

14 

def is_subdict_of(small, big):
    """Return True when every key/value pair of ``small`` also exists in ``big``.

    An empty ``small`` is a sub-dict of anything.

    The comparison is done pair by pair, so ``big`` is never copied and
    keys of any hashable type are supported (the former
    ``dict(big, **small)`` form raised TypeError for non-string keys).

    :param dict small: The candidate subset.
    :param dict big: The dictionary that is expected to contain ``small``.
    :rtype: bool
    """
    missing = object()  # Sentinel, so a stored None is not mistaken for "absent"
    for key, wanted in small.items():
        found = big.get(key, missing)
        if found is missing:
            return False
        # The identity shortcut mirrors how dict equality treats values
        if not (found is wanted or found == wanted):
            return False
    return True

17 

def _get_username(id_token_claims):
    """Pick the username out of a mapping of ID token claims.

    The "preferred_username" claim (AAD) wins whenever that key is present;
    otherwise the "upn" claim (ADFS 2019) is returned, or None when
    neither claim exists.
    """
    adfs_upn = id_token_claims.get("upn")  # ADFS 2019
    return id_token_claims.get("preferred_username", adfs_upn)  # AAD

22 

class TokenCache(object):
    """This is considered as a base class containing minimal cache behavior.

    Although it maintains tokens using unified schema across all MSAL libraries,
    this class does not serialize/persist them.
    See subclass :class:`SerializableTokenCache` for details on serialization.

    Internally, ``self._cache`` is a two-level dictionary:
    credential type -> key (built by ``self.key_makers``) -> entry dict.
    Accesses to it are guarded by the re-entrant lock ``self._lock``.
    """

    class CredentialType:
        ACCESS_TOKEN = "AccessToken"
        REFRESH_TOKEN = "RefreshToken"
        ACCOUNT = "Account"  # Not exactly a credential type, but we put it here
        ID_TOKEN = "IdToken"
        APP_METADATA = "AppMetadata"

    class AuthorityType:
        ADFS = "ADFS"
        MSSTS = "MSSTS"  # MSSTS means AAD v2 for both AAD & MSA

    def __init__(self):
        """Create an empty cache, its lock, and the per-type key makers."""
        self._lock = threading.RLock()
        self._cache = {}
        # Each key maker receives a whole entry as keyword arguments,
        # picks the fields that form the compound key, and ignores the rest.
        self.key_makers = {
            # Note: We have changed token key format before when ordering scopes;
            # changing token key won't result in cache miss.
            self.CredentialType.REFRESH_TOKEN:
                lambda home_account_id=None, environment=None, client_id=None,
                        target=None, **ignored_payload_from_a_real_token:
                    "-".join([
                        home_account_id or "",
                        environment or "",
                        self.CredentialType.REFRESH_TOKEN,
                        client_id or "",
                        "",  # RT is cross-tenant in AAD
                        target or "",  # raw value could be None if deserialized from other SDK
                        ]).lower(),
            self.CredentialType.ACCESS_TOKEN:
                lambda home_account_id=None, environment=None, client_id=None,
                        realm=None, target=None,
                        # Note: New field(s) can be added here
                        #key_id=None,
                        **ignored_payload_from_a_real_token:
                    "-".join([  # Note: Could use a hash here to shorten key length
                        home_account_id or "",
                        environment or "",
                        self.CredentialType.ACCESS_TOKEN,
                        client_id or "",
                        realm or "",
                        target or "",
                        #key_id or "",  # So ATs of different key_id can coexist
                        ]).lower(),
            self.CredentialType.ID_TOKEN:
                lambda home_account_id=None, environment=None, client_id=None,
                        realm=None, **ignored_payload_from_a_real_token:
                    "-".join([
                        home_account_id or "",
                        environment or "",
                        self.CredentialType.ID_TOKEN,
                        client_id or "",
                        realm or "",
                        ""  # Albeit irrelevant, schema requires an empty scope here
                        ]).lower(),
            self.CredentialType.ACCOUNT:
                lambda home_account_id=None, environment=None, realm=None,
                        **ignored_payload_from_a_real_entry:
                    "-".join([
                        home_account_id or "",
                        environment or "",
                        realm or "",
                        ]).lower(),
            self.CredentialType.APP_METADATA:
                lambda environment=None, client_id=None, **kwargs:
                    "appmetadata-{}-{}".format(environment or "", client_id or ""),
            }

    def _get_access_token(
        self,
        home_account_id, environment, client_id, realm, target,  # Together they form a compound key
        default=None,
    ):  # O(1)
        """Look up one access token entry by its compound key.

        :param target: An iterable of scope strings. They are joined with
            a space in the given order, so the caller supplies them already
            sorted the way add() stored them.
        :param default: Returned when no entry exists under that key.
        """
        return self._get(
            self.CredentialType.ACCESS_TOKEN,
            self.key_makers[TokenCache.CredentialType.ACCESS_TOKEN](
                home_account_id=home_account_id,
                environment=environment,
                client_id=client_id,
                realm=realm,
                target=" ".join(target),
            ),
            default=default)

    def _get_app_metadata(self, environment, client_id, default=None):  # O(1)
        """Look up the app metadata entry of (environment, client_id)."""
        return self._get(
            self.CredentialType.APP_METADATA,
            self.key_makers[TokenCache.CredentialType.APP_METADATA](
                environment=environment,
                client_id=client_id,
            ),
            default=default)

    def _get(self, credential_type, key, default=None):  # O(1)
        """Return the entry stored under (credential_type, key), or ``default``."""
        with self._lock:
            return self._cache.get(credential_type, {}).get(key, default)

    @staticmethod
    def _is_matching(entry: dict, query: dict, target_set: set = None) -> bool:
        """Tell whether a cache entry satisfies a query.

        :param entry: A cached entry.
        :param query: Key/value pairs which must all be present in ``entry``.
            A falsy query matches every entry.
        :param target_set: When non-empty, these scopes must all be contained
            in the entry's space-delimited "target" value.
        """
        query_with_lowercase_environment = {
            # Cached entries are expected to hold a lower-case environment
            # (presumably normalized upstream before being stored; that is
            # not visible in this class), so we lower-case the query too.
            k: v.lower() if k == "environment" and isinstance(v, str) else v
            for k, v in query.items()
            } if query else {}
        return is_subdict_of(query_with_lowercase_environment, entry) and (
            # "target" may be absent, or explicitly None if the entry
            # was deserialized from other SDK, hence the "or".
            target_set <= set((entry.get("target") or "").split())
            if target_set else True)

    def search(self, credential_type, target=None, query=None, *, now=None):  # O(n) generator
        """Returns a generator of matching entries.

        For access tokens whose query contains home_account_id, environment,
        client_id and realm, an exact-key O(1) lookup is attempted first and
        its hit (if any) is yielded first; note that this hit is not checked
        against ``expires_on`` here. Then all entries of that credential type
        are scanned in O(n).

        Note that it holds a lock during the entire scan, including the time
        when the generator is suspended between items.
        Expired access tokens met during the scan are skipped, and they are
        removed from the cache once the scan has run to its end.

        :param credential_type: One of the :class:`CredentialType` values.
        :param target: An iterable of scopes which an entry must contain.
        :param dict query: Key/value pairs which an entry must contain.
        :param now: Current time as epoch seconds. Defaults to time.time().
        """
        target = sorted(target or [])  # Match the order sorted by add()

        preferred_result = None
        if (credential_type == self.CredentialType.ACCESS_TOKEN
            and isinstance(query, dict)
            and "home_account_id" in query and "environment" in query
            and "client_id" in query and "realm" in query and target
        ):  # Special case for O(1) AT lookup
            preferred_result = self._get_access_token(
                query["home_account_id"], query["environment"],
                query["client_id"], query["realm"], target)
            if preferred_result and self._is_matching(
                preferred_result, query,
                # Needs no target_set here because it is satisfied by dict key
            ):
                yield preferred_result

        target_set = set(target)
        with self._lock:
            # O(n) search. The key is NOT used in search.
            now = int(time.time() if now is None else now)
            expired_access_tokens = [
                # Especially when/if we key ATs by ephemeral fields such as key_id,
                # stale ATs keyed by an old key_id would stay forever.
                # Here we collect them for their removal.
            ]
            for entry in self._cache.get(credential_type, {}).values():
                if (  # Automatically delete expired access tokens
                    credential_type == self.CredentialType.ACCESS_TOKEN
                    and int(entry["expires_on"]) < now
                ):
                    expired_access_tokens.append(entry)  # Can't delete them within current for-loop
                    continue
                if (entry != preferred_result  # Avoid yielding the same entry twice
                    and self._is_matching(entry, query, target_set=target_set)
                ):
                    yield entry
            for at in expired_access_tokens:
                self.remove_at(at)

    def find(self, credential_type, target=None, query=None, *, now=None):
        """Equivalent to list(search(...)).

        Deprecated: it emits a DeprecationWarning on every call.
        """
        warnings.warn(
            "Use list(search(...)) instead to explicitly get a list.",
            DeprecationWarning,
            stacklevel=2,  # Attribute the warning to the caller of find()
        )
        return list(self.search(credential_type, target=target, query=query, now=now))

    def add(self, event, now=None):
        """Handle a token obtaining event, and add tokens into cache.

        :param dict event: Typically contains client_id, scope, token_endpoint,
            response, params, data, grant_type.
        :param now: Current time as epoch seconds. Defaults to time.time().
        """
        def make_clean_copy(dictionary, sensitive_fields):  # Masks sensitive info
            return {
                k: "********" if k in sensitive_fields else v
                for k, v in dictionary.items()
            }
        if logger.isEnabledFor(logging.DEBUG):
            # The masked copy and its JSON dump are only useful for this log line,
            # so we skip building them when debug logging is off.
            clean_event = dict(
                event,
                data=make_clean_copy(event.get("data", {}), (
                    "password", "client_secret", "refresh_token", "assertion",
                    )),
                response=make_clean_copy(event.get("response", {}), (
                    "id_token_claims",  # Provided by broker
                    "access_token", "refresh_token", "id_token", "username",
                    )),
                )
            logger.debug("event=%s", json.dumps(
                # We examined and concluded that this log won't have Log Injection risk,
                # because the event payload is already in JSON so CR/LF will be escaped.
                clean_event,
                indent=4, sort_keys=True,
                default=str,  # assertion is in bytes in Python 3
                ))
        return self.__add(event, now=now)

    def __parse_account(self, response, id_token_claims):
        """Return client_info and home_account_id"""
        if "client_info" in response:  # It happens when client_info and profile are in request
            client_info = json.loads(decode_part(response["client_info"]))
            if "uid" in client_info and "utid" in client_info:
                return client_info, "{uid}.{utid}".format(**client_info)
            # https://github.com/AzureAD/microsoft-authentication-library-for-python/issues/387
        if id_token_claims:  # This would be an end user on ADFS-direct scenario
            sub = id_token_claims["sub"]  # "sub" always exists, per OIDC specs
            return {"uid": sub}, sub
        # client_credentials flow will reach this code path
        return {}, None

    def __add(self, event, now=None):
        """Store the AT, account, IDT, RT and app metadata found in ``event``.

        Each kind of entry is written through modify(), and only when
        its source material is present in the event's response.
        The app metadata entry is always written.
        """
        # event typically contains: client_id, scope, token_endpoint,
        # response, params, data, grant_type
        environment = realm = None
        if "token_endpoint" in event:
            _, environment, realm = canonicalize(event["token_endpoint"])
        if "environment" in event:  # Always available unless in legacy test cases
            environment = event["environment"]  # Set by application.py
        response = event.get("response", {})
        data = event.get("data", {})
        access_token = response.get("access_token")
        refresh_token = response.get("refresh_token")
        id_token = response.get("id_token")
        id_token_claims = response.get("id_token_claims") or (  # Prefer the claims from broker
            # Only use decode_id_token() when necessary, it contains time-sensitive validation
            decode_id_token(id_token, client_id=event["client_id"]) if id_token else {})
        client_info, home_account_id = self.__parse_account(response, id_token_claims)

        target = ' '.join(sorted(event.get("scope") or []))  # Schema should have required sorting

        with self._lock:
            now = int(time.time() if now is None else now)

            if access_token:
                default_expires_in = (  # https://www.rfc-editor.org/rfc/rfc6749#section-5.1
                    int(response.get("expires_on")) - now  # Some Managed Identity emits this
                    ) if response.get("expires_on") else 600
                expires_in = int(  # AADv1-like endpoint returns a string
                    response.get("expires_in", default_expires_in))
                ext_expires_in = int(  # AADv1-like endpoint returns a string
                    response.get("ext_expires_in", expires_in))
                at = {
                    "credential_type": self.CredentialType.ACCESS_TOKEN,
                    "secret": access_token,
                    "home_account_id": home_account_id,
                    "environment": environment,
                    "client_id": event.get("client_id"),
                    "target": target,
                    "realm": realm,
                    "token_type": response.get("token_type", "Bearer"),
                    "cached_at": str(now),  # Schema defines it as a string
                    "expires_on": str(now + expires_in),  # Same here
                    "extended_expires_on": str(now + ext_expires_in)  # Same here
                    }
                at.update({k: data[k] for k in data if k in {
                    # Also store extra data which we explicitly allow
                    # So that we won't accidentally store a user's password etc.
                    "key_id",  # It happens in SSH-cert or POP scenario
                    }})
                if "refresh_in" in response:
                    # Normally an integer, but we coerce it the same way as
                    # expires_in, so that a string value would not break the sum.
                    refresh_in = int(response["refresh_in"])
                    at["refresh_on"] = str(now + refresh_in)  # Schema wants a string
                self.modify(self.CredentialType.ACCESS_TOKEN, at, at)

            if client_info and not event.get("skip_account_creation"):
                account = {
                    "home_account_id": home_account_id,
                    "environment": environment,
                    "realm": realm,
                    "local_account_id": event.get(
                        "_account_id",  # Came from mid-tier code path.
                            # Empirically, it is the oid in AAD or cid in MSA.
                        id_token_claims.get("oid", id_token_claims.get("sub"))),
                    "username": _get_username(id_token_claims)
                        or data.get("username")  # Falls back to ROPC username
                        or event.get("username")  # Falls back to Federated ROPC username
                        or "",  # The schema does not like null
                    "authority_type": event.get(
                        "authority_type",  # Honor caller's choice of authority_type
                        self.AuthorityType.ADFS if realm == "adfs"
                            else self.AuthorityType.MSSTS),
                    # "client_info": response.get("client_info"),  # Optional
                    }
                grant_types_that_establish_an_account = (
                    _GRANT_TYPE_BROKER, "authorization_code", "password",
                    Client.DEVICE_FLOW["GRANT_TYPE"])
                if event.get("grant_type") in grant_types_that_establish_an_account:
                    account["account_source"] = event["grant_type"]
                self.modify(self.CredentialType.ACCOUNT, account, account)

            if id_token:
                idt = {
                    "credential_type": self.CredentialType.ID_TOKEN,
                    "secret": id_token,
                    "home_account_id": home_account_id,
                    "environment": environment,
                    "realm": realm,
                    "client_id": event.get("client_id"),
                    # "authority": "it is optional",
                    }
                self.modify(self.CredentialType.ID_TOKEN, idt, idt)

            if refresh_token:
                rt = {
                    "credential_type": self.CredentialType.REFRESH_TOKEN,
                    "secret": refresh_token,
                    "home_account_id": home_account_id,
                    "environment": environment,
                    "client_id": event.get("client_id"),
                    "target": target,  # Optional per schema though
                    "last_modification_time": str(now),  # Optional. Schema defines it as a string.
                    }
                if "foci" in response:
                    rt["family_id"] = response["foci"]
                self.modify(self.CredentialType.REFRESH_TOKEN, rt, rt)

            app_metadata = {
                "client_id": event.get("client_id"),
                "environment": environment,
                }
            if "foci" in response:
                app_metadata["family_id"] = response.get("foci")
            self.modify(self.CredentialType.APP_METADATA, app_metadata, app_metadata)

    def modify(self, credential_type, old_entry, new_key_value_pairs=None):
        """Update or remove one cache entry.

        :param credential_type: One of the :class:`CredentialType` values.
        :param dict old_entry: The entry to act on. Its key is derived from
            its fields by the matching key maker.
        :param dict new_key_value_pairs: Merged over a copy of ``old_entry``
            and stored. When it is None (or otherwise falsy, such as an
            empty dict), the entry is removed instead.
        """
        # Modify the specified old_entry with new_key_value_pairs,
        # or remove the old_entry if the new_key_value_pairs is None.

        # This helper exists to consolidate all token add/modify/remove behaviors,
        # so that the sub-classes will have only one method to work on,
        # instead of patching a pair of update_xx() and remove_xx() per type.
        # You can monkeypatch self.key_makers to support more types on-the-fly.
        key = self.key_makers[credential_type](**old_entry)
        with self._lock:
            if new_key_value_pairs:  # Update with them
                entries = self._cache.setdefault(credential_type, {})
                entries[key] = dict(
                    old_entry,  # Do not use entries[key] b/c it might not exist
                    **new_key_value_pairs)
            else:  # Remove old_entry
                self._cache.setdefault(credential_type, {}).pop(key, None)

    def remove_rt(self, rt_item):
        """Remove a refresh token entry from the cache."""
        assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN
        return self.modify(self.CredentialType.REFRESH_TOKEN, rt_item)

    def update_rt(self, rt_item, new_rt):
        """Replace the secret of a refresh token entry and stamp the change time."""
        assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN
        return self.modify(self.CredentialType.REFRESH_TOKEN, rt_item, {
            "secret": new_rt,
            "last_modification_time": str(int(time.time())),  # Optional. Schema defines it as a string.
            })

    def remove_at(self, at_item):
        """Remove an access token entry from the cache."""
        assert at_item.get("credential_type") == self.CredentialType.ACCESS_TOKEN
        return self.modify(self.CredentialType.ACCESS_TOKEN, at_item)

    def remove_idt(self, idt_item):
        """Remove an ID token entry from the cache."""
        assert idt_item.get("credential_type") == self.CredentialType.ID_TOKEN
        return self.modify(self.CredentialType.ID_TOKEN, idt_item)

    def remove_account(self, account_item):
        """Remove an account entry from the cache."""
        assert "authority_type" in account_item
        return self.modify(self.CredentialType.ACCOUNT, account_item)

386 

387 

class SerializableTokenCache(TokenCache):
    """A token cache which can dump/load its whole state as a JSON string.

    This class does NOT actually persist the cache on disk/db/etc..
    It is meant as a starting point to implement your own persistence.
    Depending on your need,
    the following simple recipe for file-based, unencrypted persistence may be sufficient::

        import os, atexit, msal
        cache_filename = os.path.join(  # Persist cache into this file
            os.getenv(
                # Automatically wipe out the cache from Linux when user's ssh session ends.
                # See also https://github.com/AzureAD/microsoft-authentication-library-for-python/issues/690
                "XDG_RUNTIME_DIR", ""),
            "my_cache.bin")
        cache = msal.SerializableTokenCache()
        if os.path.exists(cache_filename):
            cache.deserialize(open(cache_filename, "r").read())
        atexit.register(lambda:
            open(cache_filename, "w").write(cache.serialize())
            # Hint: The following optional line persists only when state changed
            if cache.has_state_changed else None
            )
        app = msal.ClientApplication(..., token_cache=cache)
        ...

    Alternatively, you may use a more sophisticated cache persistence library,
    `MSAL Extensions <https://github.com/AzureAD/microsoft-authentication-extensions-for-python>`_,
    which provides token cache persistence with encryption, and more.

    :var bool has_state_changed:
        Indicates whether the cache state in the memory has changed since last
        :func:`~serialize` or :func:`~deserialize` call.
    """
    has_state_changed = False

    def add(self, event, **kwargs):
        """Add tokens the same way as the base class, then flag the state as changed."""
        super().add(event, **kwargs)
        self.has_state_changed = True

    def modify(self, credential_type, old_entry, new_key_value_pairs=None):
        """Update/remove an entry via the base class, then flag the state as changed."""
        super().modify(credential_type, old_entry, new_key_value_pairs)
        self.has_state_changed = True

    def deserialize(self, state):
        # type: (Optional[str]) -> None
        """Replace the in-memory cache with a state previously produced by serialize().

        A falsy ``state`` (None or an empty string) results in an empty cache.
        """
        with self._lock:
            if state:
                self._cache = json.loads(state)
            else:
                self._cache = {}
            self.has_state_changed = False  # The memory now mirrors the given state

    def serialize(self):
        # type: () -> str
        """Return the current cache state as a JSON string, and reset the changed flag."""
        with self._lock:
            self.has_state_changed = False
            return json.dumps(self._cache, indent=4)

445