Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/adal/cache_driver.py: 20%
147 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 06:05 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 06:05 +0000
1#------------------------------------------------------------------------------
2#
3# Copyright (c) Microsoft Corporation.
4# All rights reserved.
5#
6# This code is licensed under the MIT License.
7#
8# Permission is hereby granted, free of charge, to any person obtaining a copy
9# of this software and associated documentation files(the "Software"), to deal
10# in the Software without restriction, including without limitation the rights
11# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell
12# copies of the Software, and to permit persons to whom the Software is
13# furnished to do so, subject to the following conditions :
14#
15# The above copyright notice and this permission notice shall be included in
16# all copies or substantial portions of the Software.
17#
18# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE
21# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24# THE SOFTWARE.
25#
26#------------------------------------------------------------------------------
28import base64
29import copy
30import hashlib
31from datetime import datetime, timedelta
32from dateutil import parser
34from .adal_error import AdalError
35from .constants import TokenResponseFields, Misc
36from . import log
# Suppress warnings, e.g. access to a protected member such as "_AUTHORITY".
39# pylint: disable=W0212
41def _create_token_hash(token):
42 hash_object = hashlib.sha256()
43 hash_object.update(token.encode('utf8'))
44 return base64.b64encode(hash_object.digest())
46def _create_token_id_message(entry):
47 access_token_hash = _create_token_hash(entry[TokenResponseFields.ACCESS_TOKEN])
48 message = 'AccessTokenId: ' + str(access_token_hash)
49 if entry.get(TokenResponseFields.REFRESH_TOKEN):
50 refresh_token_hash = _create_token_hash(entry[TokenResponseFields.REFRESH_TOKEN])
51 message += ', RefreshTokenId: ' + str(refresh_token_hash)
52 return message
def _is_mrrt(entry):
    """Report whether *entry* carries a non-empty resource value.

    Returns ``True`` when the entry's resource field is present and truthy,
    ``False`` otherwise.
    """
    resource = entry.get(TokenResponseFields.RESOURCE)
    return True if resource else False
def _entry_has_metadata(entry):
    """Tell whether *entry* already holds both cache metadata keys.

    The metadata keys are the client-id key and the authority key; both must
    be present for the result to be ``True``.
    """
    required_keys = (TokenResponseFields._CLIENT_ID,
                     TokenResponseFields._AUTHORITY)
    return all(key in entry for key in required_keys)
class CacheDriver(object):
    """Find, refresh and store token entries for one authority/resource/client.

    Entries are dict-like token responses. Before an entry is stored it is
    stamped with this driver's client id and authority, and is either flagged
    as MRRT (presumably "multi-resource refresh token") when it carries a
    resource, or tagged with this driver's resource otherwise.
    """

    def __init__(self, call_context, authority, resource, client_id, cache,
                 refresh_function):
        """Bind the driver to its collaborators.

        :param call_context: mapping; its ``'log_context'`` item is handed to
            the logger.
        :param authority: authority value matched against, and written to,
            cache entries.
        :param resource: resource value matched against, and written to,
            cache entries.
        :param client_id: client id written to entries and used in queries.
        :param cache: object exposing ``find(query)``, ``add(entries)`` and
            ``remove(entries)``.
        :param refresh_function: callable invoked as
            ``refresh_function(entry, resource)``; its result is merged into
            a copy of the entry.
        """
        self._call_context = call_context
        self._log = log.Logger("CacheDriver", call_context['log_context'])
        self._authority = authority
        self._resource = resource
        self._client_id = client_id
        self._cache = cache
        self._refresh_function = refresh_function

    def _get_potential_entries(self, query):
        """Return cache entries matching the client id / user id in *query*.

        Only the client-id and user-id items of *query* are forwarded to the
        cache, and only when they are truthy.
        """
        potential_entries_query = {}

        if query.get(TokenResponseFields._CLIENT_ID):
            potential_entries_query[TokenResponseFields._CLIENT_ID] = query[TokenResponseFields._CLIENT_ID]

        if query.get(TokenResponseFields.USER_ID):
            potential_entries_query[TokenResponseFields.USER_ID] = query[TokenResponseFields.USER_ID]

        self._log.debug(
            'Looking for potential cache entries: %(query)s',
            {"query": log.scrub_pii(potential_entries_query)})
        entries = self._cache.find(potential_entries_query)
        self._log.debug(
            'Found %(quantity)s potential entries.', {"quantity": len(entries)})
        return entries

    def _find_mrrt_tokens_for_user(self, user):
        """Return cached MRRT entries for *user* under this driver's client id."""
        return self._cache.find({
            TokenResponseFields.IS_MRRT: True,
            TokenResponseFields.USER_ID: user,
            TokenResponseFields._CLIENT_ID : self._client_id
            })

    def _load_single_entry_from_cache(self, query):
        """Pick the single best cache entry for *query*.

        Preference order: the one entry whose resource and authority both
        match this driver; otherwise the first MRRT entry among the
        candidates.

        :returns: ``(entry, is_resource_tenant_specific)``. ``entry`` is an
            empty list when nothing suitable was found.
        :raises AdalError: if more than one entry matches both resource and
            authority.
        """
        return_val = []
        is_resource_tenant_specific = False

        potential_entries = self._get_potential_entries(query)
        if potential_entries:
            resource_tenant_specific_entries = [
                x for x in potential_entries
                if x[TokenResponseFields.RESOURCE] == self._resource and
                x[TokenResponseFields._AUTHORITY] == self._authority]

            if not resource_tenant_specific_entries:
                self._log.debug('No resource specific cache entries found.')

                # There are no resource specific entries. Find an MRRT token.
                # The IS_MRRT key is only written for MRRT entries (see
                # _argument_entry_with_cached_metadata), so read it with
                # .get() to avoid a KeyError on non-MRRT entries.
                mrrt_tokens = (x for x in potential_entries
                               if x.get(TokenResponseFields.IS_MRRT))
                token = next(mrrt_tokens, None)
                if token:
                    self._log.debug('Found an MRRT token.')
                    return_val = token
                else:
                    self._log.debug('No MRRT tokens found.')
            elif len(resource_tenant_specific_entries) == 1:
                self._log.debug('Resource specific token found.')
                return_val = resource_tenant_specific_entries[0]
                is_resource_tenant_specific = True
            else:
                raise AdalError('More than one token matches the criteria. The result is ambiguous.')

        if return_val:
            self._log.debug('Returning token from cache lookup, %(token_hash)s',
                            {"token_hash": _create_token_id_message(return_val)})

        return return_val, is_resource_tenant_specific

    def _create_entry_from_refresh(self, entry, refresh_response):
        """Return a new entry: a deep copy of *entry* updated with the response.

        *entry* itself is not modified.
        """
        new_entry = copy.deepcopy(entry)
        new_entry.update(refresh_response)

        # It is possible the response payload has no 'resource' field, like in ADFS, so we manually
        # fill it here. Note, 'resource' is part of the token cache key, so we have to set it to avoid
        # corrupting the cache.
        if 'resource' not in refresh_response:
            new_entry['resource'] = self._resource

        # Non-MRRT entries never receive the IS_MRRT key, so use .get() here
        # (as _refresh_entry_if_necessary does) instead of indexing.
        if entry.get(TokenResponseFields.IS_MRRT) and self._authority != entry[TokenResponseFields._AUTHORITY]:
            new_entry[TokenResponseFields._AUTHORITY] = self._authority

        self._log.debug('Created new cache entry from refresh response.')
        return new_entry

    def _replace_entry(self, entry_to_replace, new_entry):
        """Remove *entry_to_replace* from the cache, then add *new_entry*."""
        self.remove(entry_to_replace)
        self.add(new_entry)

    def _refresh_expired_entry(self, entry):
        """Refresh *entry* (no resource passed) and swap it in the cache."""
        token_response = self._refresh_function(entry, None)
        new_entry = self._create_entry_from_refresh(entry, token_response)
        self._replace_entry(entry, new_entry)
        self._log.info('Returning token refreshed after expiry.')
        return new_entry

    def _acquire_new_token_from_mrrt(self, entry):
        """Use MRRT *entry* to obtain a token for this driver's resource.

        The new entry is added to the cache; the MRRT entry is kept.
        """
        token_response = self._refresh_function(entry, self._resource)
        new_entry = self._create_entry_from_refresh(entry, token_response)
        self.add(new_entry)
        self._log.info('Returning token derived from mrrt refresh.')
        return new_entry

    def _refresh_entry_if_necessary(self, entry, is_resource_specific):
        """Return a usable entry for *entry*, refreshing it when required.

        * Resource-specific and expired (within the clock buffer): refresh it
          if it has a refresh token, otherwise remove it and return ``None``.
        * Not resource-specific but MRRT: acquire a new token through it if
          it has a refresh token, otherwise remove it and return ``None``.
        * Anything else: return *entry* unchanged.
        """
        expiry_date = parser.parse(entry[TokenResponseFields.EXPIRES_ON])
        # Use the expiry's own tzinfo so both datetimes are comparable.
        now = datetime.now(expiry_date.tzinfo)

        # Add some buffer in to the time comparison to account for clock skew or latency.
        now_plus_buffer = now + timedelta(minutes=Misc.CLOCK_BUFFER)

        if is_resource_specific and now_plus_buffer > expiry_date:
            if TokenResponseFields.REFRESH_TOKEN in entry:
                self._log.info('Cached token is expired at %(date)s.  Refreshing',
                               {"date": expiry_date})
                return self._refresh_expired_entry(entry)
            else:
                self.remove(entry)
                return None
        elif not is_resource_specific and entry.get(TokenResponseFields.IS_MRRT):
            if TokenResponseFields.REFRESH_TOKEN in entry:
                self._log.info('Acquiring new access token from MRRT token.')
                return self._acquire_new_token_from_mrrt(entry)
            else:
                self.remove(entry)
                return None
        else:
            return entry

    def find(self, query):
        """Look up a token entry for *query*, refreshing it if needed.

        :param query: mapping of lookup fields, or ``None`` for an empty query.
        :returns: the (possibly refreshed) entry, or ``None`` if none is usable.
        """
        if query is None:
            query = {}
        self._log.debug('finding with query keys: %(query)s',
                        {"query": log.scrub_pii(query)})
        entry, is_resource_tenant_specific = self._load_single_entry_from_cache(query)
        if entry:
            return self._refresh_entry_if_necessary(entry,
                                                    is_resource_tenant_specific)
        else:
            return None

    def remove(self, entry):
        """Remove a single *entry* from the underlying cache."""
        self._log.debug('Removing entry.')
        self._cache.remove([entry])

    def _remove_many(self, entries):
        """Remove every entry in *entries* from the underlying cache."""
        self._log.debug('Remove many: %(number)s', {"number": len(entries)})
        self._cache.remove(entries)

    def _add_many(self, entries):
        """Add every entry in *entries* to the underlying cache."""
        self._log.debug('Add many: %(number)s', {"number": len(entries)})
        self._cache.add(entries)

    def _update_refresh_tokens(self, entry):
        """Propagate *entry*'s refresh token to the user's cached MRRT entries.

        Does nothing unless *entry* is MRRT and has a non-empty refresh token.
        Matching entries are removed, updated in place and re-added.
        """
        if _is_mrrt(entry) and entry.get(TokenResponseFields.REFRESH_TOKEN):
            mrrt_tokens = self._find_mrrt_tokens_for_user(entry.get(TokenResponseFields.USER_ID))
            if mrrt_tokens:
                self._log.debug('Updating %(number)s cached refresh tokens',
                                {"number": len(mrrt_tokens)})
                self._remove_many(mrrt_tokens)

                for t in mrrt_tokens:
                    t[TokenResponseFields.REFRESH_TOKEN] = entry[TokenResponseFields.REFRESH_TOKEN]

                self._add_many(mrrt_tokens)

    def _argument_entry_with_cached_metadata(self, entry):
        """Stamp *entry* in place with the metadata the cache relies on.

        Entries that already have both client id and authority are left
        untouched. Otherwise an MRRT entry is flagged as such, a non-MRRT
        entry receives this driver's resource, and both receive this driver's
        client id and authority.
        """
        if _entry_has_metadata(entry):
            return

        if _is_mrrt(entry):
            self._log.debug('Added entry is MRRT')
            entry[TokenResponseFields.IS_MRRT] = True
        else:
            entry[TokenResponseFields.RESOURCE] = self._resource

        entry[TokenResponseFields._CLIENT_ID] = self._client_id
        entry[TokenResponseFields._AUTHORITY] = self._authority

    def add(self, entry):
        """Add *entry* to the cache.

        The entry is first stamped with metadata, and its refresh token is
        propagated to the user's other MRRT entries when applicable.
        """
        self._log.debug('Adding entry %(token_hash)s',
                        {"token_hash": _create_token_id_message(entry)})
        self._argument_entry_with_cached_metadata(entry)
        self._update_refresh_tokens(entry)
        self._cache.add([entry])