Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/msal/throttled_http_client.py: 56%
34 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:20 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:20 +0000
1from threading import Lock
2from hashlib import sha256
4from .individual_cache import _IndividualCache as IndividualCache
5from .individual_cache import _ExpiringMapping as ExpiringMapping
8# https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
9DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"
12def _hash(raw):
13 return sha256(repr(raw).encode("utf-8")).hexdigest()
16def _parse_http_429_5xx_retry_after(result=None, **ignored):
17 """Return seconds to throttle"""
18 assert result is not None, """
19 The signature defines it with a default value None,
20 only because the its shape is already decided by the
21 IndividualCache's.__call__().
22 In actual code path, the result parameter here won't be None.
23 """
24 response = result
25 lowercase_headers = {k.lower(): v for k, v in getattr(
26 # Historically, MSAL's HttpResponse does not always have headers
27 response, "headers", {}).items()}
28 if not (response.status_code == 429 or response.status_code >= 500
29 or "retry-after" in lowercase_headers):
30 return 0 # Quick exit
31 default = 60 # Recommended at the end of
32 # https://identitydivision.visualstudio.com/devex/_git/AuthLibrariesApiReview?version=GBdev&path=%2FService%20protection%2FIntial%20set%20of%20protection%20measures.md&_a=preview
33 retry_after = int(lowercase_headers.get("retry-after", default))
34 try:
35 # AAD's retry_after uses integer format only
36 # https://stackoverflow.microsoft.com/questions/264931/264932
37 delay_seconds = int(retry_after)
38 except ValueError:
39 delay_seconds = default
40 return min(3600, delay_seconds)
43def _extract_data(kwargs, key, default=None):
44 data = kwargs.get("data", {}) # data is usually a dict, but occasionally a string
45 return data.get(key) if isinstance(data, dict) else default
class ThrottledHttpClient(object):
    """An http client wrapper whose post() and get() go through a cache."""

    def __init__(self, http_client, http_cache):
        """Throttle the given http_client by storing and retrieving data from cache.

        The given http_client is left unpatched. The throttled post() and
        get() live on this wrapper instead, which prevents a re-patching
        side effect when/if the same http_client is being reused.

        :param http_client: An object providing post(), get() and close().
        :param http_cache:
            A mapping used as the cache storage. None means a new dict.
        """
        shared_cache = ExpiringMapping(  # It will automatically clean up
            mapping=http_cache if http_cache is not None else {},
            capacity=1024,  # To prevent cache blowing up especially for CCA
            lock=Lock(),  # TODO: This should ideally also allow customization
        )

        def retry_after_key(func, args, kwargs):
            # Internal specs require url, client_id and scope in the key.
            # The hash part approximates the "account" concept,
            # to support per-account throttling.
            # TODO: We may want to disable it for confidential client, though
            return "POST {} client_id={} scope={} hash={} 429/5xx/Retry-After".format(
                args[0],  # It is the url, typically containing authority and tenant
                _extract_data(kwargs, "client_id"),
                _extract_data(kwargs, "scope"),
                _hash(_extract_data(
                    kwargs, "refresh_token",  # "account" during refresh
                    _extract_data(
                        kwargs, "code",  # "account" of auth code grant
                        _extract_data(kwargs, "username"),  # "account" of ROPC
                    ),
                )),
            )

        def bad_request_key(func, args, kwargs):
            # Here we use literally all parameters, even those short-lived
            # parameters containing timestamps (WS-Trust or POP assertion),
            # because they will automatically be cleaned up by ExpiringMapping.
            #
            # Furthermore, there is no need to implement
            # "interactive requests would reset the cache",
            # because acquire_token_silent() would be automatically unblocked
            # due to token cache layer operating on top of http cache layer.
            #
            # And, acquire_token_silent(..., force_refresh=True) will NOT
            # bypass http cache, because there is no real gain from that.
            # We won't bother to implement it, nor do we want to encourage
            # the acquire_token_silent(..., force_refresh=True) pattern.
            return "POST {} hash={} 400".format(
                args[0],  # It is the url, typically containing authority and tenant
                _hash(str(kwargs.get("params")) + str(kwargs.get("data"))),
            )

        def bad_request_ttl(result=None, kwargs=None, **ignored):
            # Here we choose to cache exact HTTP 400 errors only (rather than 4xx)
            # because they are the ones defined in OAuth2
            # (https://datatracker.ietf.org/doc/html/rfc6749#section-5.2)
            # Other 4xx errors might have different requirements e.g.
            # "407 Proxy auth required" would need a key including http headers.
            if result.status_code != 400:
                return 0
            # Exclude Device Flow whose retry is expected and regulated
            payload = kwargs.get("data")
            if isinstance(payload, dict) and (
                    payload.get("grant_type") == DEVICE_AUTH_GRANT):
                return 0
            # Leave it to the Retry-After decorator
            header_names = set(
                h.lower() for h in getattr(result, "headers", {}).keys())
            if "retry-after" in header_names:
                return 0
            return 60

        def discovery_key(func, args, kwargs):
            return "GET {} hash={} 2xx".format(
                args[0],  # It is the url, sometimes containing inline params
                _hash(kwargs.get("params", "")),
            )

        def success_ttl(result=None, **ignored):
            if 200 <= result.status_code < 300:
                return 3600*24
            return 0

        # We wrap a reference to post(), and keep the original post() intact.
        # Internal specs require throttling on at least the token endpoint;
        # here we have a generic patch for POST on all endpoints.
        throttled_post = IndividualCache(
            mapping=shared_cache,
            key_maker=retry_after_key,
            expires_in=_parse_http_429_5xx_retry_after,
        )(http_client.post)

        throttled_post = IndividualCache(  # It covers the "UI required cache"
            mapping=shared_cache,
            key_maker=bad_request_key,
            expires_in=bad_request_ttl,
        )(throttled_post)

        self.post = throttled_post

        self.get = IndividualCache(  # Typically those discovery GETs
            mapping=shared_cache,
            key_maker=discovery_key,
            expires_in=success_ttl,
        )(http_client.get)

        self._http_client = http_client

    # The following 2 methods have been defined dynamically by __init__()
    #def post(self, *args, **kwargs): pass
    #def get(self, *args, **kwargs): pass

    def close(self):
        """MSAL won't need this. But we allow throttled_http_client.close() anyway"""
        return self._http_client.close()