Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/msal/throttled_http_client.py: 56%

34 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:20 +0000

1from threading import Lock 

2from hashlib import sha256 

3 

4from .individual_cache import _IndividualCache as IndividualCache 

5from .individual_cache import _ExpiringMapping as ExpiringMapping 

6 

7 

# grant_type value of the OAuth 2.0 Device Authorization Grant, as registered in
# https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"

10 

11 

12def _hash(raw): 

13 return sha256(repr(raw).encode("utf-8")).hexdigest() 

14 

15 

16def _parse_http_429_5xx_retry_after(result=None, **ignored): 

17 """Return seconds to throttle""" 

18 assert result is not None, """ 

19 The signature defines it with a default value None, 

20 only because the its shape is already decided by the 

21 IndividualCache's.__call__(). 

22 In actual code path, the result parameter here won't be None. 

23 """ 

24 response = result 

25 lowercase_headers = {k.lower(): v for k, v in getattr( 

26 # Historically, MSAL's HttpResponse does not always have headers 

27 response, "headers", {}).items()} 

28 if not (response.status_code == 429 or response.status_code >= 500 

29 or "retry-after" in lowercase_headers): 

30 return 0 # Quick exit 

31 default = 60 # Recommended at the end of 

32 # https://identitydivision.visualstudio.com/devex/_git/AuthLibrariesApiReview?version=GBdev&path=%2FService%20protection%2FIntial%20set%20of%20protection%20measures.md&_a=preview 

33 retry_after = int(lowercase_headers.get("retry-after", default)) 

34 try: 

35 # AAD's retry_after uses integer format only 

36 # https://stackoverflow.microsoft.com/questions/264931/264932 

37 delay_seconds = int(retry_after) 

38 except ValueError: 

39 delay_seconds = default 

40 return min(3600, delay_seconds) 

41 

42 

43def _extract_data(kwargs, key, default=None): 

44 data = kwargs.get("data", {}) # data is usually a dict, but occasionally a string 

45 return data.get(key) if isinstance(data, dict) else default 

46 

47 

class ThrottledHttpClient(object):
    """Wrap an http_client so that its post() and get() go through a cache.

    Selected responses are stored in one shared ExpiringMapping, and an
    identical later request is answered from it. These are throttling
    signals (429/5xx/Retry-After), HTTP 400 errors, and successful GETs.
    """

    def __init__(self, http_client, http_cache):
        """Throttle the given http_client by storing and retrieving data from cache.

        This wrapper exists so that our patching post() and get() would prevent
        re-patching side effect when/if same http_client being reused.

        :param http_client: Object providing post(url, ...), get(url, ...)
            and close(). Its own methods are wrapped, never replaced.
        :param http_cache: Dict-like backing storage, or None to use a new dict.
        """
        store = ExpiringMapping(  # It will automatically clean up
            mapping={} if http_cache is None else http_cache,
            capacity=1024,  # To prevent cache blowing up especially for CCA
            lock=Lock(),  # TODO: This should ideally also allow customization
        )

        def retry_after_key(func, args, kwargs):
            # Internal specs require throttling on at least the token endpoint.
            # This key applies that throttling to POSTs on every endpoint.
            # The nested lookups below approximate the "account" concept,
            # to support per-account throttling. Each lookup is passed as the
            # default of the one wrapping it.
            # TODO: We may want to disable it for confidential client, though
            account_hint = _extract_data(
                kwargs, "refresh_token",  # "account" during refresh
                _extract_data(
                    kwargs, "code",  # "account" of auth code grant
                    _extract_data(kwargs, "username")))  # "account" of ROPC
            return "POST {} client_id={} scope={} hash={} 429/5xx/Retry-After".format(
                args[0],  # The url, typically containing authority and tenant
                _extract_data(kwargs, "client_id"),  # Per internal specs
                _extract_data(kwargs, "scope"),  # Per internal specs
                _hash(account_hint),
            )

        def bad_request_key(func, args, kwargs):
            # Literally all parameters are hashed, even short-lived ones that
            # contain timestamps (WS-Trust or POP assertion), because
            # ExpiringMapping cleans them up automatically.
            #
            # There is no need to implement "interactive requests would reset
            # the cache": acquire_token_silent() is unblocked automatically,
            # because the token cache layer operates on top of this http cache.
            #
            # acquire_token_silent(..., force_refresh=True) deliberately does
            # NOT bypass this cache. There is no real gain from that, and that
            # pattern should not be encouraged.
            whole_request = str(kwargs.get("params")) + str(kwargs.get("data"))
            return "POST {} hash={} 400".format(args[0], _hash(whole_request))

        def bad_request_expiry(result=None, kwargs=None, **ignored):
            if result.status_code != 400:
                # Only exact HTTP 400s are cached (not every 4xx), because those
                # are the ones defined in OAuth2
                # (https://datatracker.ietf.org/doc/html/rfc6749#section-5.2).
                # Other 4xx errors might have different requirements, e.g.
                # "407 Proxy auth required" would need a key including headers.
                return 0
            body = kwargs.get("data")
            if isinstance(body, dict) and body.get("grant_type") == DEVICE_AUTH_GRANT:
                return 0  # Device Flow retries are expected and regulated
            header_names = set(
                name.lower() for name in getattr(result, "headers", {}).keys())
            if "retry-after" in header_names:
                return 0  # Leave it to the Retry-After decorator
            return 60

        def get_key(func, args, kwargs):
            return "GET {} hash={} 2xx".format(
                args[0],  # The url, sometimes containing inline params
                _hash(kwargs.get("params", "")),
            )

        def get_expiry(result=None, **ignored):
            if not 200 <= result.status_code < 300:
                return 0
            return 24 * 3600  # One day

        # Wrap a local reference, so the original post() is left intact.
        # The Retry-After layer is applied first (innermost); the 400 layer,
        # which covers the "UI required cache", is applied on top of it.
        wrapped_post = http_client.post
        wrapped_post = IndividualCache(
            mapping=store,
            key_maker=retry_after_key,
            expires_in=_parse_http_429_5xx_retry_after,
        )(wrapped_post)
        wrapped_post = IndividualCache(
            mapping=store,
            key_maker=bad_request_key,
            expires_in=bad_request_expiry,
        )(wrapped_post)
        self.post = wrapped_post

        self.get = IndividualCache(  # Typically those discovery GETs
            mapping=store,
            key_maker=get_key,
            expires_in=get_expiry,
        )(http_client.get)

        self._http_client = http_client

    # Note: post() and get() are instance attributes assigned in __init__(),
    # not methods defined on the class.

    def close(self):
        """Close the wrapped http_client and return whatever its close() returns.

        MSAL won't need this. But we allow throttled_http_client.close() anyway.
        """
        inner = self._http_client
        return inner.close()

141