1# -------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See LICENSE.txt in the project root for
4# license information.
5# -------------------------------------------------------------------------
6import time
7import base64
8import random
9from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any, Union, cast
10
11from azure.core.credentials import (
12 TokenCredential,
13 SupportsTokenInfo,
14 TokenRequestOptions,
15 TokenProvider,
16)
17from azure.core.exceptions import HttpResponseError
18from azure.core.pipeline import PipelineRequest, PipelineResponse
19from azure.core.pipeline.transport import (
20 HttpResponse as LegacyHttpResponse,
21 HttpRequest as LegacyHttpRequest,
22)
23from azure.core.rest import HttpResponse, HttpRequest
24from . import HTTPPolicy, SansIOHTTPPolicy
25from ...exceptions import ServiceRequestError
26from ._utils import get_challenge_parameter
27
28if TYPE_CHECKING:
29
30 from azure.core.credentials import (
31 AccessToken,
32 AccessTokenInfo,
33 AzureKeyCredential,
34 AzureSasCredential,
35 )
36
# The policies accept either the azure.core.rest types or the legacy transport types.
HTTPResponseType = TypeVar("HTTPResponseType", HttpResponse, LegacyHttpResponse)
HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)

# When a token carries no refresh_on hint, it is refreshed once its remaining lifetime drops below
# this window (shortened by the policy's current jitter).
DEFAULT_REFRESH_WINDOW_SECONDS = 5 * 60
# Upper bound for the random jitter drawn each time the bearer policy acquires a token.
MAX_REFRESH_JITTER_SECONDS = 60
42
43
44def _should_refresh_token(token: Optional[Union["AccessToken", "AccessTokenInfo"]], refresh_jitter: int) -> bool:
45 """Check if a new token is needed based on expiry and refresh logic.
46
47 :param token: The current token or None if no token exists
48 :type token: Optional[Union[~azure.core.credentials.AccessToken, ~azure.core.credentials.AccessTokenInfo]]
49 :param int refresh_jitter: The jitter to apply to refresh timing
50 :return: True if a new token is needed, False otherwise
51 :rtype: bool
52 """
53 if not token:
54 return True
55
56 now = time.time()
57 if token.expires_on <= now:
58 return True
59
60 refresh_on = getattr(token, "refresh_on", None)
61
62 if refresh_on:
63 # Apply jitter, but ensure that adding it doesn't push the refresh time past the actual expiration.
64 # This is a safeguard, as refresh_on is typically well before expires_on.
65 effective_refresh_time = min(refresh_on + refresh_jitter, token.expires_on)
66 return effective_refresh_time <= now
67
68 time_until_expiry = token.expires_on - now
69 # Reduce refresh window by jitter to delay refresh and distribute load
70 return time_until_expiry < (DEFAULT_REFRESH_WINDOW_SECONDS - refresh_jitter)
71
72
def _enforce_https(request: PipelineRequest[HTTPRequestType]) -> None:
    """Refuse to attach a bearer token to a request whose URL is not HTTPS.

    The per-call ``enforce_https`` option is removed from the request options and, only when it is an
    explicit ``False``, recorded in the pipeline context instead. This lets it persist across retries
    without being handed to the transport. ``True`` is the default, so it is never stored.

    :param request: The pipeline request about to be authorized.
    :type request: ~azure.core.pipeline.PipelineRequest
    :raises ~azure.core.exceptions.ServiceRequestError: If HTTPS is enforced and the URL does not start
        with "https" (case-insensitive).
    """
    context = request.context
    if context.options.pop("enforce_https", None) is False:
        context["enforce_https"] = False

    if not context.get("enforce_https", True):
        return
    if request.http_request.url.lower().startswith("https"):
        return
    raise ServiceRequestError(
        "Bearer token authentication is not permitted for non-TLS protected (non-https) URLs."
    )
87
88
# pylint:disable=too-few-public-methods
class _BearerTokenCredentialPolicyBase:
    """Token acquisition and caching shared by the bearer token credential policies.

    :param credential: The credential that supplies access tokens.
    :type credential: ~azure.core.credentials.TokenProvider
    :param str scopes: Lets you specify the type of access needed.
    :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested
        tokens. Defaults to False.
    """

    def __init__(self, credential: TokenProvider, *scopes: str, **kwargs: Any) -> None:
        super().__init__()
        self._credential = credential
        self._scopes = scopes
        # Most recently acquired token; None until the first token request.
        self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None
        self._enable_cae: bool = kwargs.get("enable_cae", False)
        # Random refresh delay in seconds; re-drawn after every token request (see _request_token).
        self._refresh_jitter = 0

    @staticmethod
    def _update_headers(headers: MutableMapping[str, str], token: str) -> None:
        """Write the bearer token into the Authorization header.

        :param MutableMapping[str, str] headers: The HTTP request headers to modify in place.
        :param str token: The OAuth access token.
        """
        headers["Authorization"] = f"Bearer {token}"

    @property
    def _need_new_token(self) -> bool:
        """Whether the cached token is missing or due for refresh, according to _should_refresh_token.

        :rtype: bool
        """
        return _should_refresh_token(self._token, self._refresh_jitter)

    def _get_token(self, *scopes: str, **kwargs: Any) -> Union["AccessToken", "AccessTokenInfo"]:
        """Obtain a token from the credential using whichever protocol it supports.

        :param str scopes: The type of access needed.
        :return: The token returned by the credential.
        :rtype: ~azure.core.credentials.AccessToken or ~azure.core.credentials.AccessTokenInfo
        """
        if self._enable_cae:
            kwargs.setdefault("enable_cae", self._enable_cae)

        if not hasattr(self._credential, "get_token_info"):
            # TokenCredential protocol: forward every keyword argument as-is.
            return cast(TokenCredential, self._credential).get_token(*scopes, **kwargs)

        # SupportsTokenInfo protocol: only keys declared by TokenRequestOptions are forwarded;
        # any other keyword arguments are dropped.
        known_options = TokenRequestOptions.__annotations__  # pylint: disable=no-member
        options = cast(
            TokenRequestOptions,
            {key: value for key, value in kwargs.items() if key in known_options},
        )
        return cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options)

    def _request_token(self, *scopes: str, **kwargs: Any) -> None:
        """Fetch a new token from the credential and cache it on the policy.

        A new refresh jitter is drawn on every call, so refresh timing varies from token to token.

        :param str scopes: The type of access needed.
        """
        self._token = self._get_token(*scopes, **kwargs)
        self._refresh_jitter = random.randint(0, MAX_REFRESH_JITTER_SECONDS)
144
145
class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """Adds a bearer token Authorization header to requests.

    The token is cached on the policy and reused until it needs refreshing. A 401 response discards the
    cached token. If that response also carries a ``WWW-Authenticate`` header, it is passed to
    :meth:`on_challenge`, which may re-authorize the request so that it is sent one more time.

    :param credential: The credential.
    :type credential: ~azure.core.credentials.TokenProvider
    :param str scopes: Lets you specify the type of access needed.
    :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested
        tokens. Defaults to False.
    :raises ~azure.core.exceptions.ServiceRequestError: If HTTPS is enforced and the request URL is not https.
    """

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Called before the policy sends a request.

        Enforces HTTPS, then requests a token for the policy's scopes when none is cached or the cached
        one is due for refresh, and finally writes the Authorization header.

        :param ~azure.core.pipeline.PipelineRequest request: the request
        """
        _enforce_https(request)

        must_refresh = self._token is None or self._need_new_token
        if must_refresh:
            self._request_token(*self._scopes)
        cached = cast(Union["AccessToken", "AccessTokenInfo"], self._token)
        bearer = cached.token
        self._update_headers(request.http_request.headers, bearer)

    def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
        """Acquire a token from the credential and authorize the request with it.

        Keyword arguments are handed to the credential: all of them for ``get_token``, or only the
        TokenRequestOptions keys for ``get_token_info``. The resulting token is cached and used to
        authorize future requests.

        :param ~azure.core.pipeline.PipelineRequest request: the request
        :param str scopes: required scopes of authentication
        """
        self._request_token(*scopes, **kwargs)
        fresh = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token
        self._update_headers(request.http_request.headers, fresh)

    def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]:
        """Authorize the request with a bearer token and pass it to the next policy.

        On a 401 that carries a challenge, the request is re-sent once if :meth:`on_challenge` re-authorizes it.

        :param request: The pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: The pipeline response object
        :rtype: ~azure.core.pipeline.PipelineResponse
        """
        self.on_request(request)
        try:
            response = self.next.send(request)
        except Exception:
            self.on_exception(request)
            raise

        self.on_response(request, response)
        if response.http_response.status_code != 401:
            return response

        # The service rejected us, so whatever token is cached cannot be trusted any more.
        self._token = None
        if "WWW-Authenticate" not in response.http_response.headers:
            return response

        try:
            request_authorized = self.on_challenge(request, response)
        except Exception as ex:
            # Read a streamed body now so the service's error details are available to the caller
            # immediately, rather than requiring the caller to read the body to see them.
            if response.context.options.get("stream"):
                try:
                    response.http_response.read()  # type: ignore
                except Exception:  # pylint:disable=broad-except
                    pass
            # Surface the token-request failure, chained to the original 401 response.
            raise ex from HttpResponseError(response=response.http_response)

        if not request_authorized:
            return response

        try:
            response = self.next.send(request)
            self.on_response(request, response)
        except Exception:
            self.on_exception(request)
            raise
        return response

    def on_challenge(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, HTTPResponseType],
    ) -> bool:
        """Authorize the request according to an authentication challenge.

        This method is called when the resource provider responds 401 with a WWW-Authenticate header.
        Only a Bearer ``insufficient_claims`` challenge with non-empty base64url-encoded ``claims`` is
        handled. The decoded claims are passed to :meth:`authorize_request` for the policy's scopes.

        :param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
        :param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
        :returns: a bool indicating whether the policy should send the request
        :rtype: bool
        """
        challenge_headers = response.http_response.headers
        if get_challenge_parameter(challenge_headers, "Bearer", "error") != "insufficient_claims":
            return False

        encoded = get_challenge_parameter(challenge_headers, "Bearer", "claims")
        if not encoded:
            return False

        # Restore any stripped "=" padding so the base64url payload decodes.
        padded = encoded + "=" * (-len(encoded) % 4)
        decoded_claims = base64.urlsafe_b64decode(padded).decode("utf-8")
        if not decoded_claims:
            return False

        self.authorize_request(request, *self._scopes, claims=decoded_claims)
        return True

    def on_response(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, HTTPResponseType],
    ) -> None:
        """Hook invoked after a response returns from the next policy; does nothing by default.

        :param request: Request to be modified after returning from the policy.
        :type request: ~azure.core.pipeline.PipelineRequest
        :param response: Pipeline response object
        :type response: ~azure.core.pipeline.PipelineResponse
        """

    def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Hook invoked inside the exception handler when the next policy raises; does nothing by default.

        The exception is re-raised by the caller after this hook returns.

        :param request: The Pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        # pylint: disable=unused-argument
        return
277
278
class AzureKeyCredentialPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """Sets a header carrying the key of an AzureKeyCredential on every request.

    :param credential: The credential used to authenticate requests.
    :type credential: ~azure.core.credentials.AzureKeyCredential
    :param str name: The name of the key header used for the credential.
    :keyword str prefix: The name of the prefix for the header value if any.
    :raises ValueError: if name is None or empty.
    :raises TypeError: if name is not a string or if credential has no ``key`` attribute.
    """

    def __init__(  # pylint: disable=unused-argument
        self,
        credential: "AzureKeyCredential",
        name: str,
        *,
        prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        # Validation order matters: a credential lacking "key" is reported before any name problem.
        if not hasattr(credential, "key"):
            raise TypeError("String is not a supported credential input type. Use an instance of AzureKeyCredential.")
        if not name:
            raise ValueError("name can not be None or empty")
        if not isinstance(name, str):
            raise TypeError("name must be a string.")

        self._credential = credential
        self._name = name
        # Store the prefix together with its separating space so on_request only concatenates.
        self._prefix = "" if not prefix else prefix + " "

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Write the (optionally prefixed) credential key into the configured header.

        The key is read from the credential on each call.

        :param request: The request to be modified before sending.
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        header_value = "{}{}".format(self._prefix, self._credential.key)
        request.http_request.headers[self._name] = header_value
316
317
class AzureSasCredentialPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """Appends the shared access signature of an AzureSasCredential to each request URL's query string.

    :param credential: The credential used to authenticate requests.
    :type credential: ~azure.core.credentials.AzureSasCredential
    :raises ValueError: if credential is None (or otherwise falsy).
    """

    def __init__(  # pylint: disable=unused-argument
        self,
        credential: "AzureSasCredential",
        **kwargs: Any,
    ) -> None:
        super().__init__()
        if not credential:
            raise ValueError("credential can not be None")
        self._credential = credential

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Add the signature to the request URL.

        A leading "?" on the signature is dropped. If the URL already has a query, the signature is
        joined with "&", unless the signature text already appears in the URL. Otherwise it is joined
        with "?" (or appended directly when the URL already ends in "?").

        :param request: The request to be modified before sending.
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        http_request = request.http_request
        url = http_request.url
        query = http_request.query
        sas = self._credential.signature
        sas = sas[1:] if sas.startswith("?") else sas

        if query:
            # NOTE(review): the containment check presumably avoids appending the signature twice when
            # the same request passes through this policy again (e.g. on retry) — confirm.
            if sas not in url:
                url = url + "&" + sas
        elif url.endswith("?"):
            url = url + sas
        else:
            url = url + "?" + sas
        http_request.url = url