1from threading import Lock
2from hashlib import sha256
3
4from .individual_cache import _IndividualCache as IndividualCache
5from .individual_cache import _ExpiringMapping as ExpiringMapping
6from .oauth2cli.http import Response
7from .exceptions import MsalServiceError
8
9
# Grant type identifier of the OAuth 2.0 Device Authorization Grant, see
# https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
# ThrottledHttpClient compares it against the "grant_type" field of a POST body,
# so that Device Flow requests are excluded from the HTTP 400 cache.
DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"
12
13
def _get_headers(response):
    """Return ``response.headers``, or an empty dict if there is no such attribute.

    MSAL's HttpResponse did not have headers until 1.23.0
    https://github.com/AzureAD/microsoft-authentication-library-for-python/pull/581/files#diff-28866b706bc3830cd20485685f20fe79d45b58dce7050e68032e9d9372d68654R61
    so this helper degrades gracefully to {} instead of raising.
    """
    try:
        return response.headers
    except AttributeError:
        # Header-less response object: behave as if no headers were sent
        return {}
19
20
class RetryAfterParser(object):
    """Decide how many seconds a response should be throttled for.

    The decision is based on the response's status code and its
    ``Retry-After`` header (matched case-insensitively).

    :param default_value:
        Seconds to use when the response qualifies for throttling but carries
        no usable ``Retry-After`` value. ``None`` means 5.
    """
    FIELD_NAME_LOWER = "Retry-After".lower()
    _MAX_DELAY_SECONDS = 3600  # Upper bound, so a bogus header cannot block for too long

    def __init__(self, default_value=None):
        self._default_value = 5 if default_value is None else default_value

    def _to_delay_seconds(self, retry_after):
        """Convert a raw Retry-After value into a delay within [0, 3600] seconds."""
        try:
            # AAD's retry_after uses integer format only
            # https://stackoverflow.microsoft.com/questions/264931/264932
            delay_seconds = int(retry_after)
        except (TypeError, ValueError, OverflowError):
            # Not an integer (e.g. an HTTP-date, or a None value): use the default
            delay_seconds = self._default_value
        # Clamp on both sides: a negative delay is meaningless as a cache lifetime
        return max(0, min(self._MAX_DELAY_SECONDS, delay_seconds))

    def parse(self, *, result, **ignored):
        """Return seconds to throttle.

        :param result: The http response; its ``status_code`` and headers are read.
        :param ignored: Any other keyword arguments are accepted and ignored.
        :return:
            0 when the response is neither 429, nor 5xx, nor carrying a
            Retry-After header. Otherwise the Retry-After value
            (or the default value), bounded to the range 0..3600.
        """
        response = result
        lowercase_headers = {k.lower(): v for k, v in _get_headers(response).items()}
        if not (response.status_code == 429 or response.status_code >= 500
                or self.FIELD_NAME_LOWER in lowercase_headers):
            return 0  # Quick exit
        return self._to_delay_seconds(
            lowercase_headers.get(self.FIELD_NAME_LOWER, self._default_value))
41
42
def _extract_data(kwargs, key, default=None):
    """Return ``kwargs["data"][key]``, or ``default`` when it is unavailable.

    :param kwargs: The keyword arguments of an http request call.
    :param key: The name of the form field to look up inside ``data``.
    :param default:
        Returned when ``data`` is not a dict, or is a dict without ``key``.
    """
    data = kwargs.get("data", {})  # data is usually a dict, but occasionally a string
    # Pass the default into get(), so that callers can chain fallbacks such as
    # _extract_data(kwargs, "refresh_token", _extract_data(kwargs, "code"))
    return data.get(key, default) if isinstance(data, dict) else default
46
47
class NormalizedResponse(Response):
    """A http response with the shape defined in Response,
    carrying only the pieces we are going to store in cache:
    status_code, text, and headers with lower-cased names.
    """

    def __init__(self, raw_response):
        """Copy the cacheable parts out of raw_response."""
        super().__init__()
        self.status_code = raw_response.status_code
        self.text = raw_response.text

        # Attempted storing only a small set of headers (such as Retry-After),
        # but it tends to lead to missing information (such as WWW-Authenticate).
        # So we store all headers, which are expected to contain only public info,
        # because we throttle only error responses and public responses.
        lowercased = {}
        for name, value in _get_headers(raw_response).items():
            lowercased[name.lower()] = value
        self.headers = lowercased

    ## Note: Do not alias raw_response.raise_for_status here, i.e. avoid
    ##     self.raise_for_status = raw_response.raise_for_status
    ## because when being pickled, it will indirectly pickle the whole raw_response
    def raise_for_status(self):
        """Raise MsalServiceError when the status code is 400 or above."""
        if self.status_code < 400:
            return
        message = "HTTP Error: {}".format(self.status_code)
        # error and error_description were historically required, keeping them for now
        raise MsalServiceError(message, error=None, error_description=None)
73
74
class ThrottledHttpClientBase(object):
    """Throttle the given http_client by storing and retrieving data from cache.

    This base exists so that:
    1. Its post() and get() always return a NormalizedResponse
    2. Its __init__() will NOT re-throttle, even if a caller accidentally
       nested one ThrottledHttpClientBase inside another.

    Subclasses shall only need to dynamically decorate their post() and get()
    methods in their own __init__() method.
    """

    def __init__(self, http_client, *, http_cache=None):
        """
        :param http_client:
            The http client to wrap. If it is already a ThrottledHttpClientBase,
            its raw (unthrottled) http client is used instead.
        :param http_cache:
            Optional mapping used as the storage of the cache.
            A new dict is used when it is None.
        """
        if isinstance(http_client, ThrottledHttpClientBase):
            http_client = http_client.http_client  # Unwrap, to avoid nested throttling
        self.http_client = http_client

        storage = {} if http_cache is None else http_cache
        self._expiring_mapping = ExpiringMapping(  # It will automatically clean up
            mapping=storage,
            capacity=1024,  # To prevent cache blowing up especially for CCA
            lock=Lock(),  # TODO: This should ideally also allow customization
        )

    def post(self, *args, **kwargs):
        """Forward to the underlying client's post(), and normalize its response."""
        raw_response = self.http_client.post(*args, **kwargs)
        return NormalizedResponse(raw_response)

    def get(self, *args, **kwargs):
        """Forward to the underlying client's get(), and normalize its response."""
        raw_response = self.http_client.get(*args, **kwargs)
        return NormalizedResponse(raw_response)

    def close(self):
        """Close the underlying http client, returning whatever its close() returns."""
        return self.http_client.close()

    @staticmethod
    def _hash(raw):
        """Return the hex SHA-256 digest of the UTF-8 encoded repr() of raw."""
        digest = sha256()
        digest.update(repr(raw).encode("utf-8"))
        return digest.hexdigest()
107
108
class ThrottledHttpClient(ThrottledHttpClientBase):
    """A throttled http client that is used by MSAL's non-managed identity clients."""

    def __init__(self, *args, default_throttle_time=None, **kwargs):
        """Decorate self.post() and self.get() dynamically.

        :param default_throttle_time:
            Passed to RetryAfterParser as its default value.
            Any falsy value (None or 0) is replaced by 5.

        All other arguments are forwarded to ThrottledHttpClientBase.
        """
        super(ThrottledHttpClient, self).__init__(*args, **kwargs)
        shared_cache = self._expiring_mapping

        def throttle_key(func, args, kwargs):
            url = args[0]  # Typically containing authority and tenant
            client_id = _extract_data(kwargs, "client_id")  # Per internal specs
            scope = _extract_data(kwargs, "scope")  # Per internal specs
            # What follows are all approximations of the "account" concept,
            # to support per-account throttling.
            # TODO: We may want to disable it for confidential client, though
            username = _extract_data(kwargs, "username")  # "account" of ROPC
            code = _extract_data(kwargs, "code", username)  # "account" of auth code grant
            account = _extract_data(kwargs, "refresh_token", code)  # "account" during refresh
            return "POST {} client_id={} scope={} hash={} 429/5xx/Retry-After".format(
                url, client_id, scope, self._hash(account))

        def ui_required_key(func, args, kwargs):
            url = args[0]  # Typically containing authority and tenant
            # We use literally all parameters here, even those short-lived
            # parameters containing timestamps (WS-Trust or POP assertion),
            # because they will automatically be cleaned up by ExpiringMapping.
            #
            # Furthermore, there is no need to implement
            # "interactive requests would reset the cache",
            # because acquire_token_silent() would be automatically unblocked,
            # since the token cache layer operates on top of the http cache layer.
            #
            # And, acquire_token_silent(..., force_refresh=True) will NOT
            # bypass http cache, because there is no real gain from that.
            # We won't bother implementing it, nor do we want to encourage
            # the acquire_token_silent(..., force_refresh=True) pattern.
            fingerprint = str(kwargs.get("params")) + str(kwargs.get("data"))
            return "POST {} hash={} 400".format(url, self._hash(fingerprint))

        def ui_required_expires_in(result=None, kwargs=None, **ignored):
            # We choose to cache exact HTTP 400 errors only (rather than 4xx)
            # because they are the ones defined in OAuth2
            # (https://datatracker.ietf.org/doc/html/rfc6749#section-5.2)
            # Other 4xx errors might have different requirements e.g.
            # "407 Proxy auth required" would need a key including http headers.
            if result.status_code != 400:
                return 0
            data = kwargs.get("data")
            if isinstance(data, dict) and data.get("grant_type") == DEVICE_AUTH_GRANT:
                return 0  # Exclude Device Flow whose retry is expected and regulated
            header_names = set(h.lower() for h in _get_headers(result))
            if RetryAfterParser.FIELD_NAME_LOWER in header_names:
                return 0  # Leave it to the Retry-After decorator
            return 60

        def discovery_key(func, args, kwargs):
            url = args[0]  # Sometimes containing inline params
            return "GET {} hash={} 2xx".format(
                url, self._hash(kwargs.get("params", "")))

        def discovery_expires_in(result=None, **ignored):
            if 200 <= result.status_code < 300:
                return 3600*24
            return 0

        # Internal specs requires throttling on at least token endpoint,
        # here we have a generic patch for POST on all endpoints.
        retry_after_decorator = IndividualCache(
            mapping=shared_cache,
            key_maker=throttle_key,
            expires_in=RetryAfterParser(default_throttle_time or 5).parse,
        )
        self.post = retry_after_decorator(self.post)

        # It covers the "UI required cache", layered on top of the decorator above
        ui_required_decorator = IndividualCache(
            mapping=shared_cache,
            key_maker=ui_required_key,
            expires_in=ui_required_expires_in,
        )
        self.post = ui_required_decorator(self.post)

        # Typically those discovery GETs
        discovery_decorator = IndividualCache(
            mapping=shared_cache,
            key_maker=discovery_key,
            expires_in=discovery_expires_in,
        )
        self.get = discovery_decorator(self.get)