Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/msal/token_cache.py: 29%

126 statements  

coverage.py v7.2.7, created at 2023-06-07 06:20 +0000

import json
import threading
import time
import logging

from .authority import canonicalize
from .oauth2cli.oidc import decode_part, decode_id_token


logger = logging.getLogger(__name__)


def is_subdict_of(small, big):
    return dict(big, **small) == big
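
# Illustrative sketch (comment only, not part of the original module): the
# dict(big, **small) trick makes is_subdict_of() return True exactly when every
# key/value pair of `small` also appears unchanged in `big`, e.g.
#     is_subdict_of({"client_id": "my-app"}, {"client_id": "my-app", "realm": "contoso"})  # True
#     is_subdict_of({"client_id": "my-app"}, {"client_id": "other", "realm": "contoso"})   # False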


def _get_username(id_token_claims):
    return id_token_claims.get(
        "preferred_username",  # AAD
        id_token_claims.get("upn"))  # ADFS 2019

class TokenCache(object):
    """This is considered a base class containing minimal cache behavior.

    Although it maintains tokens using a unified schema across all MSAL libraries,
    this class does not serialize/persist them.
    See subclass :class:`SerializableTokenCache` for details on serialization.
    """


    class CredentialType:
        ACCESS_TOKEN = "AccessToken"
        REFRESH_TOKEN = "RefreshToken"
        ACCOUNT = "Account"  # Not exactly a credential type, but we put it here
        ID_TOKEN = "IdToken"
        APP_METADATA = "AppMetadata"

    class AuthorityType:
        ADFS = "ADFS"
        MSSTS = "MSSTS"  # MSSTS means AAD v2 for both AAD & MSA

    def __init__(self):
        self._lock = threading.RLock()
        self._cache = {}
        self.key_makers = {
            self.CredentialType.REFRESH_TOKEN:
                lambda home_account_id=None, environment=None, client_id=None,
                        target=None, **ignored_payload_from_a_real_token:
                    "-".join([
                        home_account_id or "",
                        environment or "",
                        self.CredentialType.REFRESH_TOKEN,
                        client_id or "",
                        "",  # RT is cross-tenant in AAD
                        target or "",  # raw value could be None if deserialized from other SDK
                        ]).lower(),
            self.CredentialType.ACCESS_TOKEN:
                lambda home_account_id=None, environment=None, client_id=None,
                        realm=None, target=None, **ignored_payload_from_a_real_token:
                    "-".join([
                        home_account_id or "",
                        environment or "",
                        self.CredentialType.ACCESS_TOKEN,
                        client_id or "",
                        realm or "",
                        target or "",
                        ]).lower(),
            self.CredentialType.ID_TOKEN:
                lambda home_account_id=None, environment=None, client_id=None,
                        realm=None, **ignored_payload_from_a_real_token:
                    "-".join([
                        home_account_id or "",
                        environment or "",
                        self.CredentialType.ID_TOKEN,
                        client_id or "",
                        realm or "",
                        ""  # Albeit irrelevant, schema requires an empty scope here
                        ]).lower(),
            self.CredentialType.ACCOUNT:
                lambda home_account_id=None, environment=None, realm=None,
                        **ignored_payload_from_a_real_entry:
                    "-".join([
                        home_account_id or "",
                        environment or "",
                        realm or "",
                        ]).lower(),
            self.CredentialType.APP_METADATA:
                lambda environment=None, client_id=None, **kwargs:
                    "appmetadata-{}-{}".format(environment or "", client_id or ""),
            }
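
        # Illustrative sketch (comment only; the values below are made-up assumptions):
        # with the AccessToken key maker above, a payload such as
        #     {"home_account_id": "uid.utid", "environment": "login.microsoftonline.com",
        #      "client_id": "my-client-id", "realm": "contoso", "target": "User.Read"}
        # would be keyed as
        #     "uid.utid-login.microsoftonline.com-accesstoken-my-client-id-contoso-user.read"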


    def find(self, credential_type, target=None, query=None):
        target = target or []
        assert isinstance(target, list), "Invalid parameter type"
        target_set = set(target)
        with self._lock:
            # Since the target inside a token cache key is (per schema) unsorted,
            # there is no point in attempting an O(1) key-value lookup here,
            # so we always do an O(n) in-memory search.
            return [entry
                for entry in self._cache.get(credential_type, {}).values()
                if is_subdict_of(query or {}, entry)
                and (target_set <= set(entry.get("target", "").split())
                    if target else True)
                ]
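
    # Illustrative query sketch (comment only; the literal values are assumptions):
    # callers typically combine an exact-match query with a scope subset check, e.g.
    #     cache.find(
    #         TokenCache.CredentialType.ACCESS_TOKEN,
    #         target=["User.Read"],
    #         query={"home_account_id": "uid.utid", "client_id": "my-client-id"})
    # returns every cached AccessToken entry whose fields match the query dict and
    # whose space-delimited "target" field contains at least the requested scopes.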


    def add(self, event, now=None):
        # type: (dict) -> None
        """Handle a token-obtaining event, and add the tokens into the cache."""
        def make_clean_copy(dictionary, sensitive_fields):  # Masks sensitive info
            return {
                k: "********" if k in sensitive_fields else v
                for k, v in dictionary.items()
            }
        clean_event = dict(
            event,
            data=make_clean_copy(event.get("data", {}), (
                "password", "client_secret", "refresh_token", "assertion",
            )),
            response=make_clean_copy(event.get("response", {}), (
                "id_token_claims",  # Provided by broker
                "access_token", "refresh_token", "id_token", "username",
            )),
        )
        logger.debug("event=%s", json.dumps(
            # We examined and concluded that this log won't have Log Injection risk,
            # because the event payload is already in JSON so CR/LF will be escaped.
            clean_event,
            indent=4, sort_keys=True,
            default=str,  # assertion is in bytes in Python 3
            ))
        return self.__add(event, now=now)
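
    # Illustrative event sketch (comment only; field values are made-up assumptions):
    # a minimal token-obtaining event handled by add() could look like
    #     cache.add({
    #         "client_id": "my-client-id",
    #         "scope": ["User.Read"],
    #         "token_endpoint": "https://login.microsoftonline.com/contoso/oauth2/v2.0/token",
    #         "response": {"access_token": "...", "expires_in": 3600, "token_type": "Bearer"},
    #     })
    # which would create an AccessToken entry (and, with more response fields,
    # Account/IdToken/RefreshToken entries as well).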


    def __parse_account(self, response, id_token_claims):
        """Return client_info and home_account_id"""
        if "client_info" in response:  # It happens when client_info and profile are in request
            client_info = json.loads(decode_part(response["client_info"]))
            if "uid" in client_info and "utid" in client_info:
                return client_info, "{uid}.{utid}".format(**client_info)
            # https://github.com/AzureAD/microsoft-authentication-library-for-python/issues/387
        if id_token_claims:  # This would be an end user on ADFS-direct scenario
            sub = id_token_claims["sub"]  # "sub" always exists, per OIDC specs
            return {"uid": sub}, sub
        # client_credentials flow will reach this code path
        return {}, None
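
    # Illustrative note (comment only): on AAD, the "client_info" value is a
    # base64url-encoded JSON document, so decode_part(...) yields something like
    #     {"uid": "object-id-of-the-user", "utid": "tenant-id"}
    # and the home_account_id above becomes "object-id-of-the-user.tenant-id".
    # These identifiers are placeholders, not real values.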


    def __add(self, event, now=None):
        # event typically contains: client_id, scope, token_endpoint,
        # response, params, data, grant_type
        environment = realm = None
        if "token_endpoint" in event:
            _, environment, realm = canonicalize(event["token_endpoint"])
        if "environment" in event:  # Always available unless in legacy test cases
            environment = event["environment"]  # Set by application.py
        response = event.get("response", {})
        data = event.get("data", {})
        access_token = response.get("access_token")
        refresh_token = response.get("refresh_token")
        id_token = response.get("id_token")
        id_token_claims = response.get("id_token_claims") or (  # Prefer the claims from broker
            # Only use decode_id_token() when necessary; it contains time-sensitive validation
            decode_id_token(id_token, client_id=event["client_id"]) if id_token else {})
        client_info, home_account_id = self.__parse_account(response, id_token_claims)

        target = ' '.join(event.get("scope") or [])  # Per schema, we don't sort it

        with self._lock:
            now = int(time.time() if now is None else now)

            if access_token:
                default_expires_in = (  # https://www.rfc-editor.org/rfc/rfc6749#section-5.1
                    int(response.get("expires_on")) - now  # Some Managed Identity emits this
                    ) if response.get("expires_on") else 600
                expires_in = int(  # AADv1-like endpoint returns a string
                    response.get("expires_in", default_expires_in))
                ext_expires_in = int(  # AADv1-like endpoint returns a string
                    response.get("ext_expires_in", expires_in))
                at = {
                    "credential_type": self.CredentialType.ACCESS_TOKEN,
                    "secret": access_token,
                    "home_account_id": home_account_id,
                    "environment": environment,
                    "client_id": event.get("client_id"),
                    "target": target,
                    "realm": realm,
                    "token_type": response.get("token_type", "Bearer"),
                    "cached_at": str(now),  # Schema defines it as a string
                    "expires_on": str(now + expires_in),  # Same here
                    "extended_expires_on": str(now + ext_expires_in)  # Same here
                    }
                if data.get("key_id"):  # It happens in SSH-cert or POP scenario
                    at["key_id"] = data.get("key_id")
                if "refresh_in" in response:
                    refresh_in = response["refresh_in"]  # It is an integer
                    at["refresh_on"] = str(now + refresh_in)  # Schema wants a string
                self.modify(self.CredentialType.ACCESS_TOKEN, at, at)

            if client_info and not event.get("skip_account_creation"):
                account = {
                    "home_account_id": home_account_id,
                    "environment": environment,
                    "realm": realm,
                    "local_account_id": event.get(
                        "_account_id",  # Came from mid-tier code path.
                        # Empirically, it is the oid in AAD or cid in MSA.
                        id_token_claims.get("oid", id_token_claims.get("sub"))),
                    "username": _get_username(id_token_claims)
                        or data.get("username")  # Falls back to ROPC username
                        or event.get("username")  # Falls back to Federated ROPC username
                        or "",  # The schema does not like null
                    "authority_type": event.get(
                        "authority_type",  # Honor caller's choice of authority_type
                        self.AuthorityType.ADFS if realm == "adfs"
                            else self.AuthorityType.MSSTS),
                    # "client_info": response.get("client_info"),  # Optional
                    }
                self.modify(self.CredentialType.ACCOUNT, account, account)

            if id_token:
                idt = {
                    "credential_type": self.CredentialType.ID_TOKEN,
                    "secret": id_token,
                    "home_account_id": home_account_id,
                    "environment": environment,
                    "realm": realm,
                    "client_id": event.get("client_id"),
                    # "authority": "it is optional",
                    }
                self.modify(self.CredentialType.ID_TOKEN, idt, idt)

            if refresh_token:
                rt = {
                    "credential_type": self.CredentialType.REFRESH_TOKEN,
                    "secret": refresh_token,
                    "home_account_id": home_account_id,
                    "environment": environment,
                    "client_id": event.get("client_id"),
                    "target": target,  # Optional per schema though
                    "last_modification_time": str(now),  # Optional. Schema defines it as a string.
                    }
                if "foci" in response:
                    rt["family_id"] = response["foci"]
                self.modify(self.CredentialType.REFRESH_TOKEN, rt, rt)

            app_metadata = {
                "client_id": event.get("client_id"),
                "environment": environment,
                }
            if "foci" in response:
                app_metadata["family_id"] = response.get("foci")
            self.modify(self.CredentialType.APP_METADATA, app_metadata, app_metadata)


    def modify(self, credential_type, old_entry, new_key_value_pairs=None):
        # Modify the specified old_entry with new_key_value_pairs,
        # or remove the old_entry if new_key_value_pairs is None.

        # This helper exists to consolidate all token add/modify/remove behaviors,
        # so that the sub-classes will have only one method to work on,
        # instead of patching a pair of update_xx() and remove_xx() per type.
        # You can monkeypatch self.key_makers to support more types on-the-fly.
        key = self.key_makers[credential_type](**old_entry)
        with self._lock:
            if new_key_value_pairs:  # Update with them
                entries = self._cache.setdefault(credential_type, {})
                entries[key] = dict(
                    old_entry,  # Do not use entries[key] b/c it might not exist
                    **new_key_value_pairs)
            else:  # Remove old_entry
                self._cache.setdefault(credential_type, {}).pop(key, None)
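
    # Illustrative sketch (comment only; the entry below is a made-up assumption):
    # the same helper serves both upserts and removals, e.g.
    #     cache.modify(TokenCache.CredentialType.APP_METADATA, entry, entry)   # add/update
    #     cache.modify(TokenCache.CredentialType.APP_METADATA, entry)          # remove
    # where `entry` carries at least the fields used by the matching key maker
    # (environment and client_id for AppMetadata).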


    def remove_rt(self, rt_item):
        assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN
        return self.modify(self.CredentialType.REFRESH_TOKEN, rt_item)

    def update_rt(self, rt_item, new_rt):
        assert rt_item.get("credential_type") == self.CredentialType.REFRESH_TOKEN
        return self.modify(self.CredentialType.REFRESH_TOKEN, rt_item, {
            "secret": new_rt,
            "last_modification_time": str(int(time.time())),  # Optional. Schema defines it as a string.
            })

    def remove_at(self, at_item):
        assert at_item.get("credential_type") == self.CredentialType.ACCESS_TOKEN
        return self.modify(self.CredentialType.ACCESS_TOKEN, at_item)

    def remove_idt(self, idt_item):
        assert idt_item.get("credential_type") == self.CredentialType.ID_TOKEN
        return self.modify(self.CredentialType.ID_TOKEN, idt_item)

    def remove_account(self, account_item):
        assert "authority_type" in account_item
        return self.modify(self.CredentialType.ACCOUNT, account_item)



class SerializableTokenCache(TokenCache):
    """This serialization can be a starting point to implement your own persistence.

    This class does NOT actually persist the cache on disk/db/etc.
    Depending on your need,
    the following simple recipe for file-based persistence may be sufficient::

        import os, atexit, msal
        cache = msal.SerializableTokenCache()
        if os.path.exists("my_cache.bin"):
            cache.deserialize(open("my_cache.bin", "r").read())
        atexit.register(lambda:
            open("my_cache.bin", "w").write(cache.serialize())
            # Hint: The following optional line persists only when state changed
            if cache.has_state_changed else None
            )
        app = msal.ClientApplication(..., token_cache=cache)
        ...

    :var bool has_state_changed:
        Indicates whether the cache state in memory has changed since the last
        :func:`~serialize` or :func:`~deserialize` call.
    """
    has_state_changed = False

    def add(self, event, **kwargs):
        super(SerializableTokenCache, self).add(event, **kwargs)
        self.has_state_changed = True

    def modify(self, credential_type, old_entry, new_key_value_pairs=None):
        super(SerializableTokenCache, self).modify(
            credential_type, old_entry, new_key_value_pairs)
        self.has_state_changed = True

    def deserialize(self, state):
        # type: (Optional[str]) -> None
        """Deserialize the cache from a state previously obtained by serialize()."""
        with self._lock:
            self._cache = json.loads(state) if state else {}
            self.has_state_changed = False  # reset

    def serialize(self):
        # type: () -> str
        """Serialize the current cache state into a string."""
        with self._lock:
            self.has_state_changed = False
            return json.dumps(self._cache, indent=4)
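
# Illustrative round-trip sketch (comment only, not part of the original module):
#     cache = SerializableTokenCache()
#     blob = cache.serialize()          # "{}" for a still-empty cache
#     restored = SerializableTokenCache()
#     restored.deserialize(blob)
#     assert restored.has_state_changed is False   # deserialize() resets the flag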