1from threading import Lock
2from hashlib import sha256
3
4from .individual_cache import _IndividualCache as IndividualCache
5from .individual_cache import _ExpiringMapping as ExpiringMapping
6
7
# Value of the "grant_type" request parameter that identifies the
# OAuth 2.0 Device Authorization Grant (a.k.a. Device Flow).
# https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"
10
11
class RetryAfterParser(object):
    """Work out, from an HTTP response, how many seconds to throttle.

    :param default_value:
        Seconds to use when a throttling response carries no usable
        ``Retry-After`` header. When omitted (``None``), 5 is used.
    """

    def __init__(self, default_value=None):
        self._default_value = 5 if default_value is None else default_value

    def parse(self, *, result, **ignored):
        """Return seconds to throttle.

        :param result:
            The HTTP response. It must expose ``status_code``;
            ``headers`` is optional and may be missing or ``None``.
        :param ignored:
            Any other keyword arguments are accepted and ignored.
        :return:
            0 when the response is neither a 429, a 5xx, nor carries a
            ``Retry-After`` header. Otherwise the header's integer value
            (or the default when the header is absent or not an integer),
            capped at 3600 and never negative.
        """
        response = result
        # Historically, MSAL's HttpResponse does not always have headers,
        # so tolerate both a missing attribute and an explicit None.
        headers = getattr(response, "headers", None) or {}
        lowercase_headers = {k.lower(): v for k, v in headers.items()}
        should_throttle = (
            response.status_code == 429 or response.status_code >= 500
            or "retry-after" in lowercase_headers)
        if not should_throttle:
            return 0  # Quick exit
        retry_after = lowercase_headers.get("retry-after", self._default_value)
        try:
            # AAD's retry_after uses integer format only
            # https://stackoverflow.microsoft.com/questions/264931/264932
            delay_seconds = int(retry_after)
        except (TypeError, ValueError):
            # ValueError: not an integer literal (e.g. an HTTP-date).
            # TypeError: the value is not a string or number (e.g. None).
            delay_seconds = self._default_value
        # Cap at one hour, and never report a negative duration.
        return max(0, min(3600, delay_seconds))
33
34
def _extract_data(kwargs, key, default=None):
    """Return the ``key`` field of the request's ``data``, else ``default``.

    :param kwargs: The keyword arguments of an http call; may contain "data".
    :param key: The name of the field to look up inside ``kwargs["data"]``.
    :param default:
        Returned when "data" is not a dict, or when it is a dict that
        does not contain ``key``.
    """
    data = kwargs.get("data", {})  # data is usually a dict, but occasionally a string
    if not isinstance(data, dict):
        return default
    # Honor the default for a missing key too, so that callers can chain
    # fallbacks, e.g. refresh_token -> code -> username.
    return data.get(key, default)
38
39
class ThrottledHttpClientBase(object):
    """Wrap the given http_client and hold the cache used to throttle it.

    The patched post() and get() live on this wrapper rather than on the
    http_client itself, which prevents re-patching side effects when/if
    the same http_client is being reused.

    The subclass should implement post() and/or get()
    """

    def __init__(self, http_client, *, http_cache=None):
        # Fall back to a fresh in-memory dict when no cache is supplied
        backing_store = {} if http_cache is None else http_cache
        self.http_client = http_client
        self._expiring_mapping = ExpiringMapping(  # It will automatically clean up
            mapping=backing_store,
            capacity=1024,  # To prevent cache blowing up especially for CCA
            lock=Lock(),  # TODO: This should ideally also allow customization
        )

    def post(self, *args, **kwargs):
        """Forward the call, unchanged, to the wrapped client's post()."""
        wrapped = self.http_client
        return wrapped.post(*args, **kwargs)

    def get(self, *args, **kwargs):
        """Forward the call, unchanged, to the wrapped client's get()."""
        wrapped = self.http_client
        return wrapped.get(*args, **kwargs)

    def close(self):
        """Close the wrapped client and return whatever its close() returns."""
        wrapped = self.http_client
        return wrapped.close()

    @staticmethod
    def _hash(raw):
        """Return the hex SHA-256 digest of the UTF-8 encoded repr() of raw."""
        encoded = repr(raw).encode("utf-8")
        return sha256(encoded).hexdigest()
68
69
class ThrottledHttpClient(ThrottledHttpClientBase):
    """An http client whose post() and get() go through IndividualCache.

    __init__() builds three IndividualCache decorators that all share
    this instance's expiring mapping:

    * POST, keyed by url/client_id/scope/"account" hash, for 429/5xx/Retry-After
    * POST, keyed by url and a hash of params + data, for HTTP 400
    * GET, keyed by url and a hash of params, for 2xx
    """

    def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
        super().__init__(http_client, **kwargs)

        def throttling_key(func, args, kwargs):
            # The followings are all approximations of the "account" concept
            # to support per-account throttling.
            # TODO: We may want to disable it for confidential client, though
            url = args[0]  # Typically containing authority and tenant
            client_id = _extract_data(kwargs, "client_id")  # Per internal specs
            scope = _extract_data(kwargs, "scope")  # Per internal specs
            username = _extract_data(kwargs, "username")  # "account" of ROPC
            code = _extract_data(  # "account" of auth code grant
                kwargs, "code", username)
            account = _extract_data(  # "account" during refresh
                kwargs, "refresh_token", code)
            return "POST {} client_id={} scope={} hash={} 429/5xx/Retry-After".format(
                url, client_id, scope, self._hash(account))

        def ui_required_key(func, args, kwargs):
            # Here we use literally all parameters, even those short-lived
            # parameters containing timestamps (WS-Trust or POP assertion),
            # because they will automatically be cleaned up by ExpiringMapping.
            #
            # Furthermore, there is no need to implement
            # "interactive requests would reset the cache",
            # because acquire_token_silent() would be automatically unblocked,
            # since the token cache layer operates on top of http cache layer.
            #
            # And, acquire_token_silent(..., force_refresh=True) will NOT
            # bypass http cache, because there is no real gain from that.
            # We won't bother implementing it, nor do we want to encourage
            # the acquire_token_silent(..., force_refresh=True) pattern.
            url = args[0]  # Typically containing authority and tenant
            everything = str(kwargs.get("params")) + str(kwargs.get("data"))
            return "POST {} hash={} 400".format(url, self._hash(everything))

        def ui_required_expires_in(result=None, kwargs=None, **ignored):
            # Here we choose to cache exact HTTP 400 errors only (rather than 4xx)
            # because they are the ones defined in OAuth2
            # (https://datatracker.ietf.org/doc/html/rfc6749#section-5.2)
            # Other 4xx errors might have different requirements e.g.
            # "407 Proxy auth required" would need a key including http headers.
            if result.status_code != 400:
                return 0
            # Exclude Device Flow whose retry is expected and regulated
            data = kwargs.get("data")
            if isinstance(data, dict) and data.get("grant_type") == DEVICE_AUTH_GRANT:
                return 0
            # Leave it to the Retry-After decorator
            header_names = {
                h.lower() for h in getattr(result, "headers", {}).keys()}
            if "retry-after" in header_names:
                return 0
            return 60

        def discovery_key(func, args, kwargs):
            url = args[0]  # Sometimes containing inline params
            return "GET {} hash={} 2xx".format(
                url, self._hash(kwargs.get("params", "")))

        def discovery_expires_in(result=None, **ignored):
            if 200 <= result.status_code < 300:
                return 3600*24
            return 0

        # We'll patch a local reference, and keep the original post() intact
        original_post = http_client.post

        # Internal specs requires throttling on at least token endpoint,
        # here we have a generic patch for POST on all endpoints.
        throttle_decorator = IndividualCache(
            mapping=self._expiring_mapping,
            key_maker=throttling_key,
            expires_in=RetryAfterParser(default_throttle_time or 5).parse,
        )
        throttled_post = throttle_decorator(original_post)

        # It covers the "UI required cache"
        ui_required_decorator = IndividualCache(
            mapping=self._expiring_mapping,
            key_maker=ui_required_key,
            expires_in=ui_required_expires_in,
        )
        self.post = ui_required_decorator(throttled_post)

        # Typically those discovery GETs
        discovery_decorator = IndividualCache(
            mapping=self._expiring_mapping,
            key_maker=discovery_key,
            expires_in=discovery_expires_in,
        )
        self.get = discovery_decorator(http_client.get)

    # The following 2 methods have been defined dynamically by __init__()
    #def post(self, *args, **kwargs): pass
    #def get(self, *args, **kwargs): pass
148