Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/adal/token_cache.py: 29%

65 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-08 06:05 +0000

1#------------------------------------------------------------------------------ 

2# 

3# Copyright (c) Microsoft Corporation.  

4# All rights reserved. 

5#  

6# This code is licensed under the MIT License. 

7#  

8# Permission is hereby granted, free of charge, to any person obtaining a copy 

9# of this software and associated documentation files(the "Software"), to deal 

10# in the Software without restriction, including without limitation the rights 

11# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell 

12# copies of the Software, and to permit persons to whom the Software is 

13# furnished to do so, subject to the following conditions : 

14#  

15# The above copyright notice and this permission notice shall be included in 

16# all copies or substantial portions of the Software. 

17#  

18# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

19# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

20# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE 

21# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

22# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

23# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 

24# THE SOFTWARE. 

25# 

26#------------------------------------------------------------------------------ 

27 

28import json 

29import threading 

30 

31from .constants import TokenResponseFields 

32 

33def _string_cmp(str1, str2): 

34 '''Case insensitive comparison. Return true if both are None''' 

35 str1 = str1 if str1 is not None else '' 

36 str2 = str2 if str2 is not None else '' 

37 return str1.lower() == str2.lower() 

38 

class TokenCacheKey(object): # pylint: disable=too-few-public-methods
    '''Identity of a cache entry: (authority, resource, client_id, user_id).

    Keys compare case-insensitively, and a None field is equal to an empty
    string. __hash__ is computed from the same normalized form that __eq__
    compares, so keys that are equal also hash equal -- a requirement for
    using them as dict keys or set members.
    '''

    def __init__(self, authority, resource, client_id, user_id):
        self.authority = authority
        self.resource = resource
        self.client_id = client_id
        self.user_id = user_id

    @staticmethod
    def _fold(value):
        '''Normalize one field for comparison: None -> '', then lower-case.'''
        return (value if value is not None else '').lower()

    def _normalized(self):
        '''Return the folded field tuple shared by __eq__ and __hash__.'''
        return (self._fold(self.authority),
                self._fold(self.resource),
                self._fold(self.client_id),
                self._fold(self.user_id))

    def __hash__(self):
        # Hash the normalized fields, not the raw ones: keys differing only in
        # case (or None vs '') compare equal, so they must hash equal too or
        # dict lookups silently miss.
        return hash(self._normalized())

    def __eq__(self, other):
        if not isinstance(other, TokenCacheKey):
            # Let Python try the reflected comparison / fall back to identity
            # instead of raising AttributeError on a foreign operand.
            return NotImplemented
        return self._normalized() == other._normalized()

    def __ne__(self, other):
        return not self == other

57 

58# pylint: disable=protected-access 

59 

def _get_cache_key(entry):
    '''Build the TokenCacheKey under which *entry* is stored in the cache.

    Reads the authority, resource, client id and user id fields from *entry*
    via .get(), so a missing field becomes None in the key. Presumably
    *entry* is a token-response dict -- TODO confirm against callers.
    '''
    lookup = entry.get
    return TokenCacheKey(
        authority=lookup(TokenResponseFields._AUTHORITY),
        resource=lookup(TokenResponseFields.RESOURCE),
        client_id=lookup(TokenResponseFields._CLIENT_ID),
        user_id=lookup(TokenResponseFields.USER_ID),
    )

66 

67 

class TokenCache(object):
    '''In-memory store of token entries keyed by TokenCacheKey.

    Every public method takes a re-entrant lock (self._lock) around its
    access to the underlying dict. ``has_state_changed`` is set to True by
    add() and by a remove() that actually drops an entry; within this class
    it is reset to False only by the constructor (presumably the owner
    resets it after persisting -- TODO confirm).
    '''

    def __init__(self, state=None):
        '''Create a cache, optionally pre-populated from *state*.

        :param state: JSON text as produced by serialize(); falsy means empty.
        '''
        self._cache = {}
        self._lock = threading.RLock()
        if state:
            self.deserialize(state)
        # Loading the initial state is not a modification to report.
        self.has_state_changed = False

    def find(self, query):
        '''Return a list of cached entries matching *query*.

        Only the IS_MRRT, USER_ID and _CLIENT_ID fields of *query* are
        consulted (via .get()); a missing/None field acts as a wildcard.
        '''
        with self._lock:
            return self._query_cache(
                query.get(TokenResponseFields.IS_MRRT),
                query.get(TokenResponseFields.USER_ID),
                query.get(TokenResponseFields._CLIENT_ID))

    def remove(self, entries):
        '''Remove the cached entry sharing each given entry's key, if any.'''
        with self._lock:
            for e in entries:
                key = _get_cache_key(e)
                removed = self._cache.pop(key, None)
                if removed is not None:
                    self.has_state_changed = True

    def add(self, entries):
        '''Insert each entry, replacing any existing entry with the same key.

        Sets has_state_changed to True unconditionally (even for no entries).
        '''
        with self._lock:
            for e in entries:
                key = _get_cache_key(e)
                self._cache[key] = e
            self.has_state_changed = True

    def serialize(self):
        '''Return the cached entries as a JSON array string.'''
        with self._lock:
            return json.dumps(list(self._cache.values()))

    def deserialize(self, state):
        '''Replace the cache contents with the entries encoded in *state*.

        *state* is parsed and every entry keyed before the current contents
        are touched, so malformed input raises (e.g. ValueError from
        json.loads) and leaves the cache unchanged instead of wiped.
        Does not modify has_state_changed.
        '''
        replacement = {}
        if state:
            for t in json.loads(state):
                # Later duplicates overwrite earlier ones, as before.
                replacement[_get_cache_key(t)] = t
        with self._lock:
            # Mutate in place so self._cache remains the same dict object.
            self._cache.clear()
            self._cache.update(replacement)

    def read_items(self):
        '''output list of tuples in (key, authentication-result)'''
        with self._lock:
            # Snapshot while locked: returning the live items() view would let
            # callers iterate it after the lock is released, racing add/remove
            # ("dictionary changed size during iteration").
            return list(self._cache.items())

    def _query_cache(self, is_mrrt, user_id, client_id):
        '''Return the entries matching every criterion that is not None.

        user_id and client_id are compared case-insensitively. Takes no lock
        itself; find() calls it while holding self._lock.
        '''
        #None value will be taken as wildcard match
        #pylint: disable=too-many-boolean-expressions
        return [
            v for v in self._cache.values()
            if (is_mrrt is None or is_mrrt == v.get(TokenResponseFields.IS_MRRT)) and
            (user_id is None or _string_cmp(user_id, v.get(TokenResponseFields.USER_ID))) and
            (client_id is None or _string_cmp(client_id, v.get(TokenResponseFields._CLIENT_ID)))
        ]