Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/msal/throttled_http_client.py: 65%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

43 statements  

1from threading import Lock 

2from hashlib import sha256 

3 

4from .individual_cache import _IndividualCache as IndividualCache 

5from .individual_cache import _ExpiringMapping as ExpiringMapping 

6 

7 

# Grant type value of the OAuth 2.0 Device Authorization Grant,
# as defined in https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"

10 

11 

class RetryAfterParser(object):
    """Decide how many seconds a response should be throttled for.

    A response is throttled when it is HTTP 429, any 5xx, or carries a
    Retry-After header (regardless of its status). The duration comes from
    the Retry-After header (integer seconds), falls back to a default when
    the header is absent or unusable, and is capped at one hour.
    """
    def __init__(self, default_value=None):
        """
        :param default_value: Seconds to throttle when Retry-After is absent
            or unusable. Defaults to 5 when None.
        """
        self._default_value = 5 if default_value is None else default_value

    def parse(self, *, result, **ignored):
        """Return seconds to throttle, or 0 when no throttling is needed.

        :param result: The HTTP response. It must have ``status_code``;
            ``headers`` is optional, because historically MSAL's HttpResponse
            does not always have headers (and it may also be None).
        """
        response = result
        headers = getattr(response, "headers", None) or {}
        lowercase_headers = {k.lower(): v for k, v in headers.items()}
        if not (response.status_code == 429 or response.status_code >= 500
                or "retry-after" in lowercase_headers):
            return 0  # Quick exit
        retry_after = lowercase_headers.get("retry-after", self._default_value)
        try:
            # AAD's retry_after uses integer format only
            # https://stackoverflow.microsoft.com/questions/264931/264932
            delay_seconds = int(retry_after)
        except (TypeError, ValueError):
            # E.g. an HTTP-date, a fractional number, or a None header value
            delay_seconds = self._default_value
        if delay_seconds < 0:
            # RFC 7231 delay-seconds is non-negative; treat as unusable
            delay_seconds = self._default_value
        return min(3600, delay_seconds)

33 

34 

35def _extract_data(kwargs, key, default=None): 

36 data = kwargs.get("data", {}) # data is usually a dict, but occasionally a string 

37 return data.get(key) if isinstance(data, dict) else default 

38 

39 

class ThrottledHttpClientBase(object):
    """Wrap an http_client and own the expiring cache used for throttling.

    Wrapping, rather than patching the http_client's own post() and get(),
    keeps any patching side effect from being re-applied when/if the same
    http_client is reused.

    This base merely forwards post(), get() and close() to the wrapped
    client; subclasses should implement post() and/or get().
    """
    def __init__(self, http_client, *, http_cache=None):
        self.http_client = http_client
        backing_store = {} if http_cache is None else http_cache
        # The ExpiringMapping cleans up its entries automatically.
        self._expiring_mapping = ExpiringMapping(
            mapping=backing_store,
            capacity=1024,  # Keep the cache bounded, especially for CCA
            lock=Lock(),  # TODO: This should ideally also allow customization
        )

    def post(self, *args, **kwargs):
        """Forward the call to the wrapped http_client's post()."""
        return self.http_client.post(*args, **kwargs)

    def get(self, *args, **kwargs):
        """Forward the call to the wrapped http_client's get()."""
        return self.http_client.get(*args, **kwargs)

    def close(self):
        """Close the wrapped http_client, returning whatever it returns."""
        return self.http_client.close()

    @staticmethod
    def _hash(raw):
        """Return the hex SHA-256 digest of ``repr(raw)``, for use in cache keys."""
        hasher = sha256()
        hasher.update(repr(raw).encode("utf-8"))
        return hasher.hexdigest()

68 

69 

class ThrottledHttpClient(ThrottledHttpClientBase):
    """An http_client wrapper that throttles POST and caches GET responses.

    __init__() binds post() and get() on the instance by layering
    IndividualCache decorators (all backed by the shared expiring mapping)
    over the wrapped http_client's own post() and get():

    * POST responses that are HTTP 429, 5xx, or carry a Retry-After header
      are cached for as many seconds as RetryAfterParser decides;
    * POST responses that are exactly HTTP 400 are cached for 60 seconds,
      except Device Flow requests and responses carrying Retry-After;
    * GET responses with a 2xx status are cached for 24 hours.

    The wrapped http_client itself is left untouched.
    """
    def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
        """
        :param http_client: The client to wrap; its post() and get() are used.
        :param default_throttle_time: Seconds to throttle a 429/5xx response
            that lacks a usable Retry-After header. Defaults to 5.
        :param kwargs: Passed on to ThrottledHttpClientBase (e.g. http_cache).
        """
        super(ThrottledHttpClient, self).__init__(http_client, **kwargs)

        def throttle_key(func, args, kwargs):
            # Internal specs require throttling on at least the token endpoint;
            # here we have a generic patch for POST on all endpoints.
            return "POST {} client_id={} scope={} hash={} 429/5xx/Retry-After".format(
                self._url_of(args, kwargs),  # Typically contains authority and tenant
                _extract_data(kwargs, "client_id"),  # Per internal specs
                _extract_data(kwargs, "scope"),  # Per internal specs
                self._hash(
                    # The followings are all approximations of the "account"
                    # concept, to support per-account throttling.
                    # TODO: We may want to disable it for confidential client, though
                    _extract_data(kwargs, "refresh_token",  # "account" during refresh
                        _extract_data(kwargs, "code",  # "account" of auth code grant
                            _extract_data(kwargs, "username")))),  # "account" of ROPC
            )

        def ui_required_key(func, args, kwargs):
            return "POST {} hash={} 400".format(
                self._url_of(args, kwargs),  # Typically contains authority and tenant
                self._hash(
                    # Here we use literally all parameters, even those short-lived
                    # parameters containing timestamps (WS-Trust or POP assertion),
                    # because they will automatically be cleaned up by ExpiringMapping.
                    #
                    # Furthermore, there is no need to implement
                    # "interactive requests would reset the cache",
                    # because acquire_token_silent()'s would be automatically unblocked
                    # due to token cache layer operates on top of http cache layer.
                    #
                    # And, acquire_token_silent(..., force_refresh=True) will NOT
                    # bypass http cache, because there is no real gain from that.
                    # We won't bother implement it, nor do we want to encourage
                    # acquire_token_silent(..., force_refresh=True) pattern.
                    str(kwargs.get("params")) + str(kwargs.get("data"))),
            )

        def ui_required_expires_in(result=None, kwargs=None, **ignored):
            # Cache exact HTTP 400 errors only (rather than 4xx), because they
            # are the ones defined in OAuth2
            # (https://datatracker.ietf.org/doc/html/rfc6749#section-5.2).
            # Other 4xx errors might have different requirements, e.g.
            # "407 Proxy auth required" would need a key including http headers.
            if result.status_code != 400:
                return 0
            if _extract_data(kwargs or {}, "grant_type") == DEVICE_AUTH_GRANT:
                return 0  # Device Flow's retry is expected and regulated
            if "retry-after" in self._lowercase_header_names(result):
                return 0  # Leave it to the Retry-After layer
            return 60

        def discovery_key(func, args, kwargs):
            return "GET {} hash={} 2xx".format(
                self._url_of(args, kwargs),  # Sometimes containing inline params
                self._hash(kwargs.get("params", "")),
            )

        def discovery_expires_in(result=None, **ignored):
            return 3600 * 24 if 200 <= result.status_code < 300 else 0

        _post = http_client.post  # We'll patch _post, and keep original post() intact
        _post = IndividualCache(
            mapping=self._expiring_mapping,
            key_maker=throttle_key,
            expires_in=RetryAfterParser(default_throttle_time or 5).parse,
        )(_post)
        _post = IndividualCache(  # It covers the "UI required cache"
            mapping=self._expiring_mapping,
            key_maker=ui_required_key,
            expires_in=ui_required_expires_in,
        )(_post)
        self.post = _post

        self.get = IndividualCache(  # Typically those discovery GETs
            mapping=self._expiring_mapping,
            key_maker=discovery_key,
            expires_in=discovery_expires_in,
        )(http_client.get)

    # Note: post() and get() are bound per instance by __init__() above,
    # shadowing the forwarding methods inherited from ThrottledHttpClientBase.

    @staticmethod
    def _url_of(args, kwargs):
        """Return the request url, whether passed positionally or as ``url=``."""
        return args[0] if args else kwargs.get("url")

    @staticmethod
    def _lowercase_header_names(result):
        """Return the lowercased header names of result (empty if it has none)."""
        headers = getattr(result, "headers", None) or {}
        return {name.lower() for name in headers.keys()}

148