1from threading import Lock
2from hashlib import sha256
3
4from .individual_cache import _IndividualCache as IndividualCache
5from .individual_cache import _ExpiringMapping as ExpiringMapping
6from .oauth2cli.http import Response
7from .exceptions import MsalServiceError
8
9
# Grant type identifier of the OAuth 2.0 Device Authorization Grant, see
# https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
# It is compared against the "grant_type" field of POST data further below,
# so that Device Flow requests are excluded from the HTTP 400 cache.
DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"
12
13
def _get_headers(response):
    """Return the headers of ``response``, degrading gracefully to ``{}``.

    MSAL's HttpResponse did not have headers until 1.23.0
    https://github.com/AzureAD/microsoft-authentication-library-for-python/pull/581/files#diff-28866b706bc3830cd20485685f20fe79d45b58dce7050e68032e9d9372d68654R61
    so the attribute may be missing altogether.

    :param response: Any response-like object.
    :return: ``response.headers`` when it is present and non-empty,
        otherwise an empty dict. The result is always safe to iterate.
    """
    # `or {}` additionally covers a `headers` attribute that exists but is None,
    # which would otherwise blow up in the callers' `.items()` / iteration.
    return getattr(response, "headers", None) or {}
19
20
class RetryAfterParser(object):
    """Work out how long a request should be throttled, based on its response.

    :param default_value:
        Seconds to throttle when the response calls for throttling
        but carries no usable Retry-After value. Defaults to 5.
    """

    FIELD_NAME_LOWER = "Retry-After".lower()

    def __init__(self, default_value=None):
        self._default_value = 5 if default_value is None else default_value

    def parse(self, *, result, **ignored):
        """Return seconds to throttle

        :param result: The http response; its ``status_code`` is read,
            and its headers are read when it has any.
        :param ignored: Other keyword arguments are accepted and not used.
        :return: 0 when the response is neither HTTP 429, nor HTTP 5xx,
            nor carrying a Retry-After header.
            Otherwise a number of seconds between 0 and 3600, inclusive.
        """
        response = result
        lowercase_headers = {k.lower(): v for k, v in _get_headers(response).items()}
        if not (response.status_code == 429 or response.status_code >= 500
                or self.FIELD_NAME_LOWER in lowercase_headers):
            return 0  # Quick exit
        retry_after = lowercase_headers.get(self.FIELD_NAME_LOWER, self._default_value)
        try:
            # AAD's retry_after uses integer format only
            # https://stackoverflow.microsoft.com/questions/264931/264932
            delay_seconds = int(retry_after)
        except (TypeError, ValueError):
            # ValueError: not an integer literal, such as an HTTP-date.
            # TypeError: not even a string or number, such as None.
            delay_seconds = self._default_value
        # Cap at one hour, and never report a negative throttle time
        return max(0, min(3600, delay_seconds))
41
42
def _extract_data(kwargs, key, default=None):
    """Look up ``key`` in the "data" payload of a request's keyword arguments.

    :param kwargs: The keyword arguments of an http request call.
    :param key: The name of the field to look up inside ``kwargs["data"]``.
    :param default: The value to return when the field is unavailable.
    :return: The field's value. ``default`` is returned when the payload
        is absent, is not a dict, or does not contain ``key``.
    """
    data = kwargs.get("data", {})  # data is usually a dict, but occasionally a string
    # Pass the default into get(), so that callers chaining several
    # _extract_data() calls as fallbacks get the fallback for a missing key too.
    return data.get(key, default) if isinstance(data, dict) else default
46
47
class NormalizedResponse(Response):
    """A trimmed-down http response, following the shape defined in Response.

    It keeps only the pieces that will be stored in cache:
    the status code, the body text, and the headers with lower-cased names.
    """

    def __init__(self, raw_response):
        """Copy the cacheable parts out of ``raw_response``."""
        super().__init__()
        self.status_code = raw_response.status_code
        self.text = raw_response.text
        # All headers are kept. Keeping only a small set (such as Retry-After)
        # was attempted, but tends to lose information (such as WWW-Authenticate).
        # Headers are expected to contain only public info,
        # because we throttle only error responses and public responses.
        lowercased = {}
        for name, value in _get_headers(raw_response).items():
            lowercased[name.lower()] = value
        self.headers = lowercased

    ## Note: Do not alias raw_response.raise_for_status onto this object,
    ## because when being pickled, it would indirectly pickle the whole raw_response
    def raise_for_status(self):
        """Raise MsalServiceError when the status code is 400 or above."""
        if self.status_code < 400:
            return
        message = "HTTP Error: {}".format(self.status_code)
        # error and error_description are historically required, keeping them for now
        raise MsalServiceError(message, error=None, error_description=None)
73
74
class ThrottledHttpClientBase(object):
    """Throttle the given http_client by storing and retrieving data from cache.

    The purposes of this base are:

    1. Its post() and get() hand back a NormalizedResponse.
    2. Its __init__() will NOT re-throttle, even when the caller
       accidentally nested one ThrottledHttpClientBase inside another.

    Subclasses are expected to do nothing more than dynamically decorating
    their post() and get() methods, inside their __init__() method.
    """

    def __init__(self, http_client, *, http_cache=None):
        """Remember the raw http client, and set up the expiring cache.

        :param http_client: The http client to wrap. When it is already
            a ThrottledHttpClientBase, its raw (unthrottled) client is used.
        :param http_cache: Optional mapping to back the cache.
            A new dict is used when it is None.
        """
        if isinstance(http_client, ThrottledHttpClientBase):
            http_client = http_client.http_client  # Unwrap, to avoid nesting
        self.http_client = http_client

        backing_store = {} if http_cache is None else http_cache
        self._expiring_mapping = ExpiringMapping(  # It will automatically clean up
            mapping=backing_store,
            capacity=1024,  # To prevent cache blowing up especially for CCA
            lock=Lock(),  # TODO: This should ideally also allow customization
        )

    def post(self, *args, **kwargs):
        """Forward to the raw client's post(), and normalize its response."""
        raw_response = self.http_client.post(*args, **kwargs)
        return NormalizedResponse(raw_response)

    def get(self, *args, **kwargs):
        """Forward to the raw client's get(), and normalize its response."""
        raw_response = self.http_client.get(*args, **kwargs)
        return NormalizedResponse(raw_response)

    def close(self):
        """Close the raw http client, and return whatever it returns."""
        return self.http_client.close()

    @staticmethod
    def _hash(raw):
        """Return the hex SHA-256 digest of the utf-8 encoded repr() of ``raw``."""
        encoded = repr(raw).encode("utf-8")
        return sha256(encoded).hexdigest()
107
108
class ThrottledHttpClient(ThrottledHttpClientBase):
    """A throttled http client that is used by MSAL's non-managed identity clients."""

    def __init__(self, *args, default_throttle_time=None, **kwargs):
        """Decorate self.post() and self.get() dynamically

        :param default_throttle_time:
            Handed to RetryAfterParser as its default value;
            5 is used when this is None or otherwise falsy.

        All other arguments are forwarded to the base class.
        """
        super().__init__(*args, **kwargs)

        def retry_after_key(func, args, kwargs):
            url = args[0]  # Typically containing authority and tenant
            client_id = _extract_data(kwargs, "client_id")  # Per internal specs
            scope = _extract_data(kwargs, "scope")  # Per internal specs
            # The followings are all approximations of the "account" concept
            # to support per-account throttling.
            # TODO: We may want to disable it for confidential client, though
            account = _extract_data(kwargs, "user_id")  # user_fic (OID path)
            for field in (
                    "username",  # "account" of ROPC
                    "code",  # "account" of auth code grant
                    "refresh_token",  # "account" during refresh
                    ):
                account = _extract_data(kwargs, field, account)
            return "POST {} client_id={} scope={} hash={} 429/5xx/Retry-After".format(
                url, client_id, scope, self._hash(account))

        def ui_required_key(func, args, kwargs):
            url = args[0]  # Typically containing authority and tenant
            # Literally all parameters are used here, even those short-lived
            # parameters containing timestamps (WS-Trust or POP assertion),
            # because they will automatically be cleaned up by ExpiringMapping.
            #
            # Furthermore, there is no need to implement
            # "interactive requests would reset the cache",
            # because acquire_token_silent()'s would be automatically unblocked
            # due to token cache layer operates on top of http cache layer.
            #
            # And, acquire_token_silent(..., force_refresh=True) will NOT
            # bypass http cache, because there is no real gain from that.
            # We won't bother implement it, nor do we want to encourage
            # acquire_token_silent(..., force_refresh=True) pattern.
            everything = str(kwargs.get("params")) + str(kwargs.get("data"))
            return "POST {} hash={} 400".format(url, self._hash(everything))

        def ui_required_ttl(result=None, kwargs=None, **ignored):
            # Only exact HTTP 400 errors are cached (rather than 4xx)
            # because they are the ones defined in OAuth2
            # (https://datatracker.ietf.org/doc/html/rfc6749#section-5.2)
            # Other 4xx errors might have different requirements e.g.
            # "407 Proxy auth required" would need a key including http headers.
            if result.status_code != 400:
                return 0
            data = kwargs.get("data")
            if isinstance(data, dict) and data.get("grant_type") == DEVICE_AUTH_GRANT:
                return 0  # Exclude Device Flow whose retry is expected and regulated
            header_names = {h.lower() for h in _get_headers(result)}
            if RetryAfterParser.FIELD_NAME_LOWER in header_names:
                return 0  # Leave it to the Retry-After decorator
            return 60

        def discovery_key(func, args, kwargs):
            url = args[0]  # Sometimes containing inline params
            return "GET {} hash={} 2xx".format(
                url, self._hash(kwargs.get("params", "")))

        def discovery_ttl(result=None, **ignored):
            if 200 <= result.status_code < 300:
                return 3600*24
            return 0

        # Internal specs requires throttling on at least token endpoint,
        # here we have a generic patch for POST on all endpoints.
        throttle_on_retry_after = IndividualCache(
            mapping=self._expiring_mapping,
            key_maker=retry_after_key,
            expires_in=RetryAfterParser(default_throttle_time or 5).parse,
        )
        self.post = throttle_on_retry_after(self.post)

        # It covers the "UI required cache", wrapped around the decorator above
        throttle_on_http_400 = IndividualCache(
            mapping=self._expiring_mapping,
            key_maker=ui_required_key,
            expires_in=ui_required_ttl,
        )
        self.post = throttle_on_http_400(self.post)

        # Typically those discovery GETs
        cache_successful_get = IndividualCache(
            mapping=self._expiring_mapping,
            key_maker=discovery_key,
            expires_in=discovery_ttl,
        )
        self.get = cache_successful_get(self.get)