1# -------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See LICENSE.txt in the project root for
4# license information.
5# -------------------------------------------------------------------------
6import time
7import base64
8from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any, Union, cast
9
10from azure.core.credentials import (
11 TokenCredential,
12 SupportsTokenInfo,
13 TokenRequestOptions,
14 TokenProvider,
15)
16from azure.core.exceptions import HttpResponseError
17from azure.core.pipeline import PipelineRequest, PipelineResponse
18from azure.core.pipeline.transport import (
19 HttpResponse as LegacyHttpResponse,
20 HttpRequest as LegacyHttpRequest,
21)
22from azure.core.rest import HttpResponse, HttpRequest
23from . import HTTPPolicy, SansIOHTTPPolicy
24from ...exceptions import ServiceRequestError
25from ._utils import get_challenge_parameter
26
27if TYPE_CHECKING:
28
29 from azure.core.credentials import (
30 AccessToken,
31 AccessTokenInfo,
32 AzureKeyCredential,
33 AzureSasCredential,
34 )
35
# Constrained type variables used to parameterize the policies in this module:
# each may be bound to either the azure.core.rest type or the legacy
# azure.core.pipeline.transport type.
HTTPResponseType = TypeVar(
    "HTTPResponseType",
    HttpResponse,
    LegacyHttpResponse,
)
HTTPRequestType = TypeVar(
    "HTTPRequestType",
    HttpRequest,
    LegacyHttpRequest,
)
38
39
# pylint:disable=too-few-public-methods
class _BearerTokenCredentialPolicyBase:
    """Base class for a Bearer Token Credential Policy.

    Stores the credential, the requested scopes and the cached token, and provides the helpers
    shared by the concrete bearer token policies (https enforcement, header update, token
    staleness check and token acquisition).

    :param credential: The credential.
    :type credential: ~azure.core.credentials.TokenProvider
    :param str scopes: Lets you specify the type of access needed.
    :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested
        tokens. Defaults to False.
    """

    # A cached token is treated as stale once it is within this many seconds of expiring.
    _REFRESH_WINDOW_SECONDS = 300

    def __init__(self, credential: TokenProvider, *scopes: str, **kwargs: Any) -> None:
        super().__init__()
        self._scopes = scopes
        self._credential = credential
        self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None
        self._enable_cae: bool = kwargs.get("enable_cae", False)

    @staticmethod
    def _enforce_https(request: PipelineRequest[HTTPRequestType]) -> None:
        """Reject non-https requests unless https enforcement was explicitly disabled.

        :param ~azure.core.pipeline.PipelineRequest request: The request about to be authorized.
        :raises ~azure.core.exceptions.ServiceRequestError: If enforcement is on and the URL does not start
            with "https".
        """
        # move 'enforce_https' from options to context so it persists
        # across retries but isn't passed to a transport implementation
        option = request.context.options.pop("enforce_https", None)

        # True is the default setting; we needn't preserve an explicit opt in to the default behavior
        if option is False:
            request.context["enforce_https"] = option

        enforce_https = request.context.get("enforce_https", True)
        if enforce_https and not request.http_request.url.lower().startswith("https"):
            raise ServiceRequestError(
                "Bearer token authentication is not permitted for non-TLS protected (non-https) URLs."
            )

    @staticmethod
    def _update_headers(headers: MutableMapping[str, str], token: str) -> None:
        """Updates the Authorization header with the bearer token.

        :param MutableMapping[str, str] headers: The HTTP Request headers
        :param str token: The OAuth token.
        """
        headers["Authorization"] = f"Bearer {token}"

    @property
    def _need_new_token(self) -> bool:
        """Whether the cached token must be replaced before it is used.

        :return: True when there is no cached token, when the token's ``refresh_on`` time (if it has a
            truthy one) has passed, or when the token expires within ``_REFRESH_WINDOW_SECONDS``.
        :rtype: bool
        """
        if not self._token:
            return True
        now = time.time()
        # Only AccessTokenInfo-style tokens may carry refresh_on; a missing/falsy value is ignored.
        refresh_on = getattr(self._token, "refresh_on", None)
        if refresh_on and refresh_on <= now:
            return True
        return self._token.expires_on - now < self._REFRESH_WINDOW_SECONDS

    def _get_token(self, *scopes: str, **kwargs: Any) -> Union["AccessToken", "AccessTokenInfo"]:
        """Fetch a token from the credential without caching it.

        If the credential exposes ``get_token_info``, only keyword arguments whose names are keys of
        ``TokenRequestOptions`` are forwarded (as ``options``); any other keyword arguments are dropped.
        Otherwise ``get_token`` is called with every keyword argument.

        :param str scopes: The type of access needed.
        :return: The token returned by the credential.
        :rtype: ~azure.core.credentials.AccessToken or ~azure.core.credentials.AccessTokenInfo
        """
        if self._enable_cae:
            # An explicit enable_cae from the caller takes precedence over the policy default.
            kwargs.setdefault("enable_cae", self._enable_cae)

        if hasattr(self._credential, "get_token_info"):
            options: TokenRequestOptions = {}
            # Loop through all the keyword arguments and check if they are part of the TokenRequestOptions.
            for key in list(kwargs.keys()):
                if key in TokenRequestOptions.__annotations__:  # pylint: disable=no-member
                    options[key] = kwargs.pop(key)  # type: ignore[literal-required]

            return cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options)
        return cast(TokenCredential, self._credential).get_token(*scopes, **kwargs)

    def _request_token(self, *scopes: str, **kwargs: Any) -> None:
        """Request a new token from the credential.

        This will call the credential's appropriate method to get a token and store it in the policy.

        :param str scopes: The type of access needed.
        """
        self._token = self._get_token(*scopes, **kwargs)
111
112
class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """Adds a bearer token Authorization header to requests.

    A token is obtained from the credential, cached on the policy and reused until it needs
    refreshing. A 401 response invalidates the cached token; if that response carries a
    ``WWW-Authenticate`` header, :meth:`on_challenge` decides whether the request is
    re-authorized and sent one more time.

    :param credential: The credential.
    :type credential: ~azure.core.credentials.TokenProvider
    :param str scopes: Lets you specify the type of access needed.
    :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested
        tokens. Defaults to False.
    :raises ~azure.core.exceptions.ServiceRequestError: If the request URL is not https and https
        enforcement has not been disabled.
    """

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Authorize the request before it is sent.

        Refreshes the cached token first when it is missing or stale.

        :param ~azure.core.pipeline.PipelineRequest request: the request
        """
        self._enforce_https(request)

        if self._token is None or self._need_new_token:
            self._request_token(*self._scopes)
        current = cast(Union["AccessToken", "AccessTokenInfo"], self._token)
        self._update_headers(request.http_request.headers, current.token)

    def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
        """Fetch a fresh token for the given scopes and put it on the request.

        Keyword arguments are passed through to the credential. The new token replaces the cached
        one, so later requests reuse it.

        :param ~azure.core.pipeline.PipelineRequest request: the request
        :param str scopes: required scopes of authentication
        """
        self._request_token(*scopes, **kwargs)
        current = cast(Union["AccessToken", "AccessTokenInfo"], self._token)
        self._update_headers(request.http_request.headers, current.token)

    def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]:
        """Authorize the request, send it down the pipeline, and answer an auth challenge at most once.

        :param request: The pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: The pipeline response object
        :rtype: ~azure.core.pipeline.PipelineResponse
        """
        self.on_request(request)
        try:
            response = self.next.send(request)
        except Exception:
            self.on_exception(request)
            raise

        self.on_response(request, response)
        if response.http_response.status_code != 401:
            return response

        # any cached token is invalid
        self._token = None
        if "WWW-Authenticate" not in response.http_response.headers:
            return response

        try:
            request_authorized = self.on_challenge(request, response)
        except Exception as ex:
            # Read a streamed body now so the error details are available without the caller
            # having to read the response themselves; failures while reading are ignored.
            if response.context.options.get("stream"):
                try:
                    response.http_response.read()  # type: ignore
                except Exception:  # pylint:disable=broad-except
                    pass
            # Surface the token-request failure, chained to the original 401 response
            raise ex from HttpResponseError(response=response.http_response)

        if not request_authorized:
            return response

        # The challenge produced a token for the new target, so the Authorization header
        # must survive; drop the 'insecure_domain_change' marker before resending.
        request.context.options.pop("insecure_domain_change", False)
        try:
            response = self.next.send(request)
            self.on_response(request, response)
        except Exception:
            self.on_exception(request)
            raise
        return response

    def on_challenge(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, HTTPResponseType],
    ) -> bool:
        """Try to satisfy a 401 ``WWW-Authenticate`` challenge.

        Only a Bearer challenge with ``error="insufficient_claims"`` and a non-empty, base64url-encoded
        ``claims`` parameter is handled: the claims are decoded and a new token requested with them.

        :param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
        :param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
        :returns: a bool indicating whether the policy should send the request
        :rtype: bool
        """
        challenge_headers = response.http_response.headers
        if get_challenge_parameter(challenge_headers, "Bearer", "error") != "insufficient_claims":
            return False

        encoded_claims = get_challenge_parameter(challenge_headers, "Bearer", "claims")
        if not encoded_claims:
            return False

        # base64url values are often sent unpadded; restore '=' padding to a multiple of 4
        padded_claims = encoded_claims + "=" * (-len(encoded_claims) % 4)
        claims = base64.urlsafe_b64decode(padded_claims).decode("utf-8")
        if not claims:
            return False

        self.authorize_request(request, *self._scopes, claims=claims)
        return True

    def on_response(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, HTTPResponseType],
    ) -> None:
        """Hook invoked with each response from the next policy; does nothing by default.

        :param request: Request to be modified after returning from the policy.
        :type request: ~azure.core.pipeline.PipelineRequest
        :param response: Pipeline response object
        :type response: ~azure.core.pipeline.PipelineResponse
        """

    def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Hook invoked from the exception handler when the next policy raises; does nothing by default.

        :param request: The Pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        # pylint: disable=unused-argument
        return
248
249
class AzureKeyCredentialPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """Writes the key of an AzureKeyCredential into a named request header.

    The key is read from the credential on every request, optionally preceded by a prefix
    and a single space.

    :param credential: The credential used to authenticate requests.
    :type credential: ~azure.core.credentials.AzureKeyCredential
    :param str name: The name of the key header used for the credential.
    :keyword str prefix: The name of the prefix for the header value if any.
    :raises ValueError: if name is None or empty.
    :raises TypeError: if name is not a string or if credential is not an instance of AzureKeyCredential.
    """

    def __init__(  # pylint: disable=unused-argument
        self,
        credential: "AzureKeyCredential",
        name: str,
        *,
        prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        # Validation order matters: the credential is checked before the header name.
        if not hasattr(credential, "key"):
            raise TypeError("String is not a supported credential input type. Use an instance of AzureKeyCredential.")
        if not name:
            raise ValueError("name can not be None or empty")
        if not isinstance(name, str):
            raise TypeError("name must be a string.")

        self._credential = credential
        self._name = name
        if prefix:
            self._prefix = prefix + " "
        else:
            self._prefix = ""

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Set the configured header to the (prefixed) current credential key.

        :param request: The request to be modified before sending.
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        header_value = f"{self._prefix}{self._credential.key}"
        request.http_request.headers[self._name] = header_value
287
288
class AzureSasCredentialPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """Appends the shared access signature of an AzureSasCredential to the request URL's query string.

    :param credential: The credential used to authenticate requests.
    :type credential: ~azure.core.credentials.AzureSasCredential
    :raises ValueError: if credential is None.
    """

    def __init__(
        self,  # pylint: disable=unused-argument
        credential: "AzureSasCredential",
        **kwargs: Any,
    ) -> None:
        super().__init__()
        if not credential:
            raise ValueError("credential can not be None")
        self._credential = credential

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Add the signature to the request URL before sending.

        With no existing query the signature starts the query string. With an existing query it
        is joined with '&', unless the URL already contains it (e.g. on a retry).

        :param request: The request to be modified before sending.
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        http_request = request.http_request
        new_url = http_request.url
        existing_query = http_request.query
        signature = self._credential.signature
        if signature.startswith("?"):
            # tolerate signatures supplied with a leading '?'
            signature = signature[1:]

        if not existing_query:
            new_url += signature if new_url.endswith("?") else "?" + signature
        elif signature not in new_url:
            new_url += "&" + signature

        http_request.url = new_url