Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/adal/cache_driver.py: 20%

147 statements  

coverage.py v7.3.2, created at 2023-12-08 06:05 +0000

#------------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation.
# All rights reserved.
#
# This code is licensed under the MIT License.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
#------------------------------------------------------------------------------

import base64
import copy
import hashlib
from datetime import datetime, timedelta
from dateutil import parser

from .adal_error import AdalError
from .constants import TokenResponseFields, Misc
from . import log

# Suppress warnings about access to protected members such as "_AUTHORITY".
# pylint: disable=W0212

def _create_token_hash(token):
    hash_object = hashlib.sha256()
    hash_object.update(token.encode('utf8'))
    return base64.b64encode(hash_object.digest())

def _create_token_id_message(entry):
    access_token_hash = _create_token_hash(entry[TokenResponseFields.ACCESS_TOKEN])
    message = 'AccessTokenId: ' + str(access_token_hash)
    if entry.get(TokenResponseFields.REFRESH_TOKEN):
        refresh_token_hash = _create_token_hash(entry[TokenResponseFields.REFRESH_TOKEN])
        message += ', RefreshTokenId: ' + str(refresh_token_hash)
    return message
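
# An MRRT entry holds a multi-resource refresh token, i.e. a refresh token that can be
# redeemed for access tokens to resources other than the one it was originally issued for.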

def _is_mrrt(entry):
    return bool(entry.get(TokenResponseFields.RESOURCE, None))

def _entry_has_metadata(entry):
    return (TokenResponseFields._CLIENT_ID in entry and
            TokenResponseFields._AUTHORITY in entry)
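
# CacheDriver wraps a token cache for a single (authority, resource, client_id) combination:
# it looks entries up, refreshes expired ones or exchanges MRRT entries via the supplied
# refresh_function, and keeps the cache consistent as new tokens are added.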

class CacheDriver(object):
    def __init__(self, call_context, authority, resource, client_id, cache,
                 refresh_function):
        self._call_context = call_context
        self._log = log.Logger("CacheDriver", call_context['log_context'])
        self._authority = authority
        self._resource = resource
        self._client_id = client_id
        self._cache = cache
        self._refresh_function = refresh_function

    def _get_potential_entries(self, query):
        potential_entries_query = {}

        if query.get(TokenResponseFields._CLIENT_ID):
            potential_entries_query[TokenResponseFields._CLIENT_ID] = query[TokenResponseFields._CLIENT_ID]

        if query.get(TokenResponseFields.USER_ID):
            potential_entries_query[TokenResponseFields.USER_ID] = query[TokenResponseFields.USER_ID]

        self._log.debug(
            'Looking for potential cache entries: %(query)s',
            {"query": log.scrub_pii(potential_entries_query)})
        entries = self._cache.find(potential_entries_query)
        self._log.debug(
            'Found %(quantity)s potential entries.', {"quantity": len(entries)})
        return entries

    def _find_mrrt_tokens_for_user(self, user):
        return self._cache.find({
            TokenResponseFields.IS_MRRT: True,
            TokenResponseFields.USER_ID: user,
            TokenResponseFields._CLIENT_ID: self._client_id
        })
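
    # Returns (entry, is_resource_tenant_specific): prefers the single entry that matches
    # this driver's resource and authority exactly, and otherwise falls back to an MRRT entry.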

    def _load_single_entry_from_cache(self, query):
        return_val = []
        is_resource_tenant_specific = False

        potential_entries = self._get_potential_entries(query)
        if potential_entries:
            resource_tenant_specific_entries = [
                x for x in potential_entries
                if x[TokenResponseFields.RESOURCE] == self._resource and
                x[TokenResponseFields._AUTHORITY] == self._authority]

            if not resource_tenant_specific_entries:
                self._log.debug('No resource specific cache entries found.')

                # There are no resource specific entries. Find an MRRT token.
                mrrt_tokens = (x for x in potential_entries if x[TokenResponseFields.IS_MRRT])
                token = next(mrrt_tokens, None)
                if token:
                    self._log.debug('Found an MRRT token.')
                    return_val = token
                else:
                    self._log.debug('No MRRT tokens found.')
            elif len(resource_tenant_specific_entries) == 1:
                self._log.debug('Resource specific token found.')
                return_val = resource_tenant_specific_entries[0]
                is_resource_tenant_specific = True
            else:
                raise AdalError('More than one token matches the criteria. The result is ambiguous.')

        if return_val:
            self._log.debug('Returning token from cache lookup, %(token_hash)s',
                            {"token_hash": _create_token_id_message(return_val)})

        return return_val, is_resource_tenant_specific

    def _create_entry_from_refresh(self, entry, refresh_response):
        new_entry = copy.deepcopy(entry)
        new_entry.update(refresh_response)

        # It is possible the response payload has no 'resource' field, like in ADFS, so we
        # manually fill it here. Note, 'resource' is part of the token cache key, so we have
        # to set it to avoid corrupting the cache.
        if 'resource' not in refresh_response:
            new_entry['resource'] = self._resource

        if entry[TokenResponseFields.IS_MRRT] and self._authority != entry[TokenResponseFields._AUTHORITY]:
            new_entry[TokenResponseFields._AUTHORITY] = self._authority

        self._log.debug('Created new cache entry from refresh response.')
        return new_entry

    def _replace_entry(self, entry_to_replace, new_entry):
        self.remove(entry_to_replace)
        self.add(new_entry)

    def _refresh_expired_entry(self, entry):
        token_response = self._refresh_function(entry, None)
        new_entry = self._create_entry_from_refresh(entry, token_response)
        self._replace_entry(entry, new_entry)
        self._log.info('Returning token refreshed after expiry.')
        return new_entry

    def _acquire_new_token_from_mrrt(self, entry):
        token_response = self._refresh_function(entry, self._resource)
        new_entry = self._create_entry_from_refresh(entry, token_response)
        self.add(new_entry)
        self._log.info('Returning token derived from mrrt refresh.')
        return new_entry
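
    # Decides whether a cached entry can be returned as-is, must be refreshed (an expired
    # resource-specific token), or must be exchanged (an MRRT entry found for a different
    # resource). Entries that need refreshing but carry no refresh token are evicted.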

    def _refresh_entry_if_necessary(self, entry, is_resource_specific):
        expiry_date = parser.parse(entry[TokenResponseFields.EXPIRES_ON])
        now = datetime.now(expiry_date.tzinfo)

        # Add some buffer into the time comparison to account for clock skew or latency.
        now_plus_buffer = now + timedelta(minutes=Misc.CLOCK_BUFFER)

        if is_resource_specific and now_plus_buffer > expiry_date:
            if TokenResponseFields.REFRESH_TOKEN in entry:
                self._log.info('Cached token is expired at %(date)s. Refreshing',
                               {"date": expiry_date})
                return self._refresh_expired_entry(entry)
            else:
                self.remove(entry)
                return None
        elif not is_resource_specific and entry.get(TokenResponseFields.IS_MRRT):
            if TokenResponseFields.REFRESH_TOKEN in entry:
                self._log.info('Acquiring new access token from MRRT token.')
                return self._acquire_new_token_from_mrrt(entry)
            else:
                self.remove(entry)
                return None
        else:
            return entry
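
    # Public lookup: finds a token matching the query (e.g. keyed by TokenResponseFields.USER_ID),
    # refreshing or exchanging it if necessary; returns None when nothing usable is cached.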

    def find(self, query):
        if query is None:
            query = {}
        self._log.debug('finding with query keys: %(query)s',
                        {"query": log.scrub_pii(query)})
        entry, is_resource_tenant_specific = self._load_single_entry_from_cache(query)
        if entry:
            return self._refresh_entry_if_necessary(entry,
                                                    is_resource_tenant_specific)
        else:
            return None

    def remove(self, entry):
        self._log.debug('Removing entry.')
        self._cache.remove([entry])

    def _remove_many(self, entries):
        self._log.debug('Remove many: %(number)s', {"number": len(entries)})
        self._cache.remove(entries)

    def _add_many(self, entries):
        self._log.debug('Add many: %(number)s', {"number": len(entries)})
        self._cache.add(entries)
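
    # When a fresh MRRT refresh token is added, copy it onto every cached MRRT entry for the
    # same user and client so that all of them carry the latest refresh token.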

    def _update_refresh_tokens(self, entry):
        if _is_mrrt(entry) and entry.get(TokenResponseFields.REFRESH_TOKEN):
            mrrt_tokens = self._find_mrrt_tokens_for_user(entry.get(TokenResponseFields.USER_ID))
            if mrrt_tokens:
                self._log.debug('Updating %(number)s cached refresh tokens',
                                {"number": len(mrrt_tokens)})
                self._remove_many(mrrt_tokens)

                for t in mrrt_tokens:
                    t[TokenResponseFields.REFRESH_TOKEN] = entry[TokenResponseFields.REFRESH_TOKEN]

                self._add_many(mrrt_tokens)

    def _argument_entry_with_cached_metadata(self, entry):
        if _entry_has_metadata(entry):
            return

        if _is_mrrt(entry):
            self._log.debug('Added entry is MRRT')
            entry[TokenResponseFields.IS_MRRT] = True
        else:
            entry[TokenResponseFields.RESOURCE] = self._resource

        entry[TokenResponseFields._CLIENT_ID] = self._client_id
        entry[TokenResponseFields._AUTHORITY] = self._authority

    def add(self, entry):
        self._log.debug('Adding entry %(token_hash)s',
                        {"token_hash": _create_token_id_message(entry)})
        self._argument_entry_with_cached_metadata(entry)
        self._update_refresh_tokens(entry)
        self._cache.add([entry])
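

# ----------------------------------------------------------------------------------------
# Usage sketch (not part of cache_driver.py): a minimal illustration of how CacheDriver is
# typically wired up, assuming adal's TokenCache and log.create_log_context. The refresh
# callback below is a hypothetical placeholder; in ADAL itself it is supplied by the
# token-request layer.
from adal import log as adal_log
from adal.cache_driver import CacheDriver
from adal.constants import TokenResponseFields
from adal.token_cache import TokenCache

def my_refresh_function(entry, resource):
    # Hypothetical placeholder: redeem the entry's refresh token for a new token-response
    # dict, scoped to `resource` when one is given (MRRT exchange) or to the original
    # resource when `resource` is None (plain refresh after expiry).
    raise NotImplementedError

call_context = {'log_context': adal_log.create_log_context()}
driver = CacheDriver(
    call_context,
    authority='https://login.microsoftonline.com/common',
    resource='https://graph.microsoft.com',
    client_id='my-client-id',          # hypothetical application (client) id
    cache=TokenCache(),
    refresh_function=my_refresh_function)

# Look up a cached token for a user; CacheDriver refreshes or exchanges it if needed.
token = driver.find({TokenResponseFields.USER_ID: 'user@example.com'})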