Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.10/site-packages/msal/throttled_http_client.py: 62%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

55 statements  

1from threading import Lock 

2from hashlib import sha256 

3 

4from .individual_cache import _IndividualCache as IndividualCache 

5from .individual_cache import _ExpiringMapping as ExpiringMapping 

6from .oauth2cli.http import Response 

7from .exceptions import MsalServiceError 

8 

9 

# The grant_type value of a Device Authorization Grant token request, see
# https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"

12 

13 

14def _get_headers(response): 

15 # MSAL's HttpResponse did not have headers until 1.23.0 

16 # https://github.com/AzureAD/microsoft-authentication-library-for-python/pull/581/files#diff-28866b706bc3830cd20485685f20fe79d45b58dce7050e68032e9d9372d68654R61 

17 # This helper ensures graceful degradation to {} without exception 

18 return getattr(response, "headers", {}) 

19 

20 

class RetryAfterParser(object):
    """Decide how many seconds a response should be throttled for.

    The decision is driven by the response status code and by its
    Retry-After header, matched case-insensitively.
    """

    FIELD_NAME_LOWER = "Retry-After".lower()
    _MAX_DELAY_SECONDS = 3600  # Upper bound of the value returned by parse()

    def __init__(self, default_value=None):
        """Remember the fallback delay, in seconds. ``None`` means 5."""
        self._default_value = 5 if default_value is None else default_value

    def parse(self, *, result, **ignored):
        """Return seconds to throttle

        :param result: The http response to inspect. Its ``status_code``
            is read, and so are its headers when it has any.
        :param ignored: Any other keyword arguments are accepted and unused.
        :return: 0 when the response is neither a 429, nor a 5xx,
            nor carrying a Retry-After header. Otherwise the Retry-After
            value, or the default when the header is absent or is not
            an integer, limited to the range 0..3600.
        """
        response = result
        lowercase_headers = {k.lower(): v for k, v in _get_headers(response).items()}
        if not (response.status_code == 429 or response.status_code >= 500
                or self.FIELD_NAME_LOWER in lowercase_headers):
            return 0  # Quick exit
        retry_after = lowercase_headers.get(self.FIELD_NAME_LOWER, self._default_value)
        try:
            # AAD's retry_after uses integer format only
            # https://stackoverflow.microsoft.com/questions/264931/264932
            delay_seconds = int(retry_after)
        except (TypeError, ValueError):
            # ValueError: not an integer literal, such as the HTTP-date format.
            # TypeError: not a string or number at all, such as None.
            delay_seconds = self._default_value
        # A negative delay is meaningless, so it is treated as "no throttling"
        return max(0, min(self._MAX_DELAY_SECONDS, delay_seconds))

41 

42 

43def _extract_data(kwargs, key, default=None): 

44 data = kwargs.get("data", {}) # data is usually a dict, but occasionally a string 

45 return data.get(key) if isinstance(data, dict) else default 

46 

47 

class NormalizedResponse(Response):
    """An http response shaped like Response, holding only what we cache.

    The status code, the body text and the headers (with lower-cased names)
    are copied from the raw response; nothing else is retained.
    """

    def __init__(self, raw_response):
        super().__init__()
        self.status_code = raw_response.status_code
        self.text = raw_response.text
        # All headers are kept. Keeping just a few of them (such as Retry-After)
        # was attempted, but tends to lose information (such as WWW-Authenticate).
        # Headers are expected to contain only public info,
        # because we throttle only error responses and public responses.
        self.headers = dict(
            (name.lower(), value)
            for name, value in _get_headers(raw_response).items()
        )
        # Note: Do not assign raw_response.raise_for_status to this instance,
        # because when being pickled, it would indirectly pickle
        # the whole raw_response. Hence the method below.

    def raise_for_status(self):
        """Raise MsalServiceError when the status code is 400 or above."""
        if self.status_code >= 400:
            message = "HTTP Error: {}".format(self.status_code)
            raise MsalServiceError(
                message,
                # Historically required, keeping them for now
                error=None,
                error_description=None,
            )

73 

74 

class ThrottledHttpClientBase(object):
    """Throttle the given http_client by storing and retrieving data from cache.

    This base exists so that:

    1. Its post() and get() return a NormalizedResponse.
    2. Its __init__() will NOT re-throttle, even if the caller accidentally
       nested one ThrottledHttpClient inside another.

    Subclasses shall only need to dynamically decorate their post() and get()
    methods in their __init__() method.
    """

    def __init__(self, http_client, *, http_cache=None):
        """Wrap ``http_client``, using ``http_cache`` (or a new dict) as storage."""
        if isinstance(http_client, ThrottledHttpClientBase):
            # It is already throttled, so we use its raw (unthrottled) http client
            http_client = http_client.http_client
        self.http_client = http_client

        backing_store = {} if http_cache is None else http_cache
        self._expiring_mapping = ExpiringMapping(  # It will automatically clean up
            mapping=backing_store,
            capacity=1024,  # To prevent cache blowing up especially for CCA
            lock=Lock(),  # TODO: This should ideally also allow customization
        )

    def post(self, *args, **kwargs):
        """Forward to the wrapped client's post() and normalize its response."""
        raw_response = self.http_client.post(*args, **kwargs)
        return NormalizedResponse(raw_response)

    def get(self, *args, **kwargs):
        """Forward to the wrapped client's get() and normalize its response."""
        raw_response = self.http_client.get(*args, **kwargs)
        return NormalizedResponse(raw_response)

    def close(self):
        """Close the wrapped client and return whatever its close() returns."""
        return self.http_client.close()

    @staticmethod
    def _hash(raw):
        """Return the hex SHA-256 digest of the UTF-8 encoded ``repr(raw)``."""
        digest = sha256()
        digest.update(repr(raw).encode("utf-8"))
        return digest.hexdigest()

107 

108 

class ThrottledHttpClient(ThrottledHttpClientBase):
    """A throttled http client that is used by MSAL's non-managed identity clients."""

    def __init__(self, *args, default_throttle_time=None, **kwargs):
        """Decorate self.post() and self.get() dynamically"""
        super(ThrottledHttpClient, self).__init__(*args, **kwargs)

        # Layer 1 on POST: responses that are 429, 5xx or carry Retry-After.
        # Internal specs requires throttling on at least token endpoint,
        # here we have a generic patch for POST on all endpoints.
        def retry_after_key(func, args, kwargs):
            url = args[0]  # Typically containing authority and tenant
            client_id = _extract_data(kwargs, "client_id")  # Per internal specs
            scope = _extract_data(kwargs, "scope")  # Per internal specs
            # The followings are all approximations of the "account" concept
            # to support per-account throttling.
            # TODO: We may want to disable it for confidential client, though
            account = _extract_data(
                kwargs, "refresh_token",  # "account" during refresh
                _extract_data(
                    kwargs, "code",  # "account" of auth code grant
                    _extract_data(kwargs, "username")))  # "account" of ROPC
            return "POST {} client_id={} scope={} hash={} 429/5xx/Retry-After".format(
                url, client_id, scope, self._hash(account))

        self.post = IndividualCache(
            mapping=self._expiring_mapping,
            key_maker=retry_after_key,
            expires_in=RetryAfterParser(default_throttle_time or 5).parse,
        )(self.post)

        # Layer 2 on POST: it covers the "UI required cache".
        def bad_request_key(func, args, kwargs):
            url = args[0]  # Typically containing authority and tenant
            # Here we use literally all parameters, even those short-lived
            # parameters containing timestamps (WS-Trust or POP assertion),
            # because they will automatically be cleaned up by ExpiringMapping.
            #
            # Furthermore, there is no need to implement
            # "interactive requests would reset the cache",
            # because acquire_token_silent() calls would be automatically
            # unblocked, as the token cache layer operates on top of
            # the http cache layer.
            #
            # And, acquire_token_silent(..., force_refresh=True) will NOT
            # bypass http cache, because there is no real gain from that.
            # We won't bother implementing it, nor do we want to encourage
            # the acquire_token_silent(..., force_refresh=True) pattern.
            fingerprint = self._hash(
                str(kwargs.get("params")) + str(kwargs.get("data")))
            return "POST {} hash={} 400".format(url, fingerprint)

        def bad_request_ttl(result=None, kwargs=None, **ignored):
            # Here we choose to cache exact HTTP 400 errors only (rather than 4xx)
            # because they are the ones defined in OAuth2
            # (https://datatracker.ietf.org/doc/html/rfc6749#section-5.2)
            # Other 4xx errors might have different requirements e.g.
            # "407 Proxy auth required" would need a key including http headers.
            if result.status_code != 400:
                return 0
            # Exclude Device Flow whose retry is expected and regulated
            data = kwargs.get("data")
            if isinstance(data, dict) and data.get("grant_type") == DEVICE_AUTH_GRANT:
                return 0
            # A response with Retry-After is left to the Retry-After decorator
            header_names = {h.lower() for h in _get_headers(result)}
            if RetryAfterParser.FIELD_NAME_LOWER in header_names:
                return 0
            return 60

        self.post = IndividualCache(
            mapping=self._expiring_mapping,
            key_maker=bad_request_key,
            expires_in=bad_request_ttl,
        )(self.post)

        # GET: typically those discovery GETs, cached for a day when 2xx.
        def discovery_key(func, args, kwargs):
            url = args[0]  # Sometimes containing inline params
            return "GET {} hash={} 2xx".format(
                url, self._hash(kwargs.get("params", "")))

        def discovery_ttl(result=None, **ignored):
            if 200 <= result.status_code < 300:
                return 3600*24
            return 0

        self.get = IndividualCache(
            mapping=self._expiring_mapping,
            key_maker=discovery_key,
            expires_in=discovery_ttl,
        )(self.get)

179 )(self.get)