Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/adal/token_cache.py: 29%
65 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 06:05 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 06:05 +0000
1#------------------------------------------------------------------------------
2#
3# Copyright (c) Microsoft Corporation.
4# All rights reserved.
5#
6# This code is licensed under the MIT License.
7#
8# Permission is hereby granted, free of charge, to any person obtaining a copy
9# of this software and associated documentation files(the "Software"), to deal
10# in the Software without restriction, including without limitation the rights
11# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell
12# copies of the Software, and to permit persons to whom the Software is
13# furnished to do so, subject to the following conditions :
14#
15# The above copyright notice and this permission notice shall be included in
16# all copies or substantial portions of the Software.
17#
18# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE
21# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24# THE SOFTWARE.
25#
26#------------------------------------------------------------------------------
28import json
29import threading
31from .constants import TokenResponseFields
33def _string_cmp(str1, str2):
34 '''Case insensitive comparison. Return true if both are None'''
35 str1 = str1 if str1 is not None else ''
36 str2 = str2 if str2 is not None else ''
37 return str1.lower() == str2.lower()
class TokenCacheKey(object): # pylint: disable=too-few-public-methods
    '''Identity of a cached token entry.

    Two keys are equal when authority, resource, client_id and user_id all
    match case-insensitively, with None treated as '' (the same rule that
    ``_string_cmp`` applies). ``__hash__`` is computed from the same
    normalized values, so keys that compare equal always hash equally, as
    required for use as dict keys.
    '''

    def __init__(self, authority, resource, client_id, user_id):
        self.authority = authority
        self.resource = resource
        self.client_id = client_id
        self.user_id = user_id

    def _normalized(self):
        '''Return the four key fields lower-cased, with None mapped to ''.'''
        return tuple(('' if value is None else value).lower()
                     for value in (self.authority, self.resource,
                                   self.client_id, self.user_id))

    def __hash__(self):
        # Hash the normalized fields. Hashing the raw values let keys that
        # compare equal (e.g. differing only in case, or None vs '') land in
        # different buckets, so dict lookups and pop() could miss them.
        return hash(self._normalized())

    def __eq__(self, other):
        if not isinstance(other, TokenCacheKey):
            # Let Python fall back to its default comparison instead of
            # raising AttributeError on unrelated types.
            return NotImplemented
        return self._normalized() == other._normalized()

    def __ne__(self, other):
        # Explicit __ne__ is kept for Python 2 style classes.
        return not self == other
58# pylint: disable=protected-access
def _get_cache_key(entry):
    '''Build the TokenCacheKey that identifies a token-response entry.

    The authority, resource, client id and user id are read with
    ``entry.get``. This presumably expects a dict; absent fields would then
    yield None (verify against callers).
    '''
    authority = entry.get(TokenResponseFields._AUTHORITY)
    resource = entry.get(TokenResponseFields.RESOURCE)
    client_id = entry.get(TokenResponseFields._CLIENT_ID)
    user_id = entry.get(TokenResponseFields.USER_ID)
    return TokenCacheKey(authority, resource, client_id, user_id)
class TokenCache(object):
    '''In-memory store of token-response entries keyed by TokenCacheKey.

    Every public method holds ``self._lock`` (a re-entrant lock) while it
    reads or mutates ``self._cache``. ``has_state_changed`` becomes True
    when ``add`` stores an entry or ``remove`` drops one. Presumably this
    lets callers know to persist ``serialize()`` output again; confirm
    against callers.
    '''

    def __init__(self, state=None):
        '''Create the cache, optionally pre-loaded from ``serialize`` output.

        :param state: JSON string as produced by ``serialize``; a falsy value
            gives an empty cache.
        '''
        self._cache = {}
        self._lock = threading.RLock()
        if state:
            self.deserialize(state)
        # Loading the initial state does not count as a change.
        self.has_state_changed = False

    def find(self, query):
        '''Return the cached entries matching ``query``.

        ``query`` is read with ``.get`` for the IS_MRRT, USER_ID and
        _CLIENT_ID fields; a None (or missing) field matches any value.

        :returns: list of matching entries (possibly empty).
        '''
        with self._lock:
            return self._query_cache(
                query.get(TokenResponseFields.IS_MRRT),
                query.get(TokenResponseFields.USER_ID),
                query.get(TokenResponseFields._CLIENT_ID))

    def remove(self, entries):
        '''Drop each entry in ``entries`` from the cache, if present.

        Entries are located by their TokenCacheKey; unknown entries are
        ignored. ``has_state_changed`` is set only if something was removed.
        '''
        with self._lock:
            for entry in entries:
                if self._cache.pop(_get_cache_key(entry), None) is not None:
                    self.has_state_changed = True

    def add(self, entries):
        '''Store each entry, replacing any existing entry with an equal key.

        ``has_state_changed`` is set only when at least one entry is stored.
        This matches ``remove``, which flags a change only when it drops
        something.
        '''
        with self._lock:
            for entry in entries:
                self._cache[_get_cache_key(entry)] = entry
                self.has_state_changed = True

    def serialize(self):
        '''Return all cached entries as a JSON array string.

        Only the entries are written; ``deserialize`` rebuilds the keys from
        them.
        '''
        with self._lock:
            return json.dumps(list(self._cache.values()))

    def deserialize(self, state):
        '''Replace the cache contents with the entries encoded in ``state``.

        :param state: JSON array string as produced by ``serialize``; a falsy
            value empties the cache.
        :raises ValueError: if ``state`` is not valid JSON. In that case the
            existing cache contents are left untouched.

        Does not modify ``has_state_changed``.
        '''
        with self._lock:
            # Parse and key everything before touching the live cache, so a
            # malformed state does not wipe the existing contents.
            loaded = {}
            if state:
                for entry in json.loads(state):
                    loaded[_get_cache_key(entry)] = entry
            # Mutate in place rather than rebinding, keeping the same dict.
            self._cache.clear()
            self._cache.update(loaded)

    def read_items(self):
        '''Return a list of (key, authentication-result) tuples.'''
        with self._lock:
            # Return a snapshot. The live dict view returned previously was
            # consumed after the lock was released, so concurrent add/remove
            # could make iteration raise RuntimeError.
            return list(self._cache.items())

    def _query_cache(self, is_mrrt, user_id, client_id):
        '''Return the entries matching every non-None criterion.

        A None criterion acts as a wildcard. ``user_id`` and ``client_id``
        are compared case-insensitively via ``_string_cmp``; ``is_mrrt`` is
        compared with ``==``. Callers are expected to hold ``self._lock``.
        '''
        matches = []
        for entry in self._cache.values():
            #pylint: disable=too-many-boolean-expressions
            if ((is_mrrt is None or is_mrrt == entry.get(TokenResponseFields.IS_MRRT)) and
                    (user_id is None or
                     _string_cmp(user_id, entry.get(TokenResponseFields.USER_ID))) and
                    (client_id is None or
                     _string_cmp(client_id, entry.get(TokenResponseFields._CLIENT_ID)))):
                matches.append(entry)
        return matches