1# -------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See LICENSE.txt in the project root for
4# license information.
5# -------------------------------------------------------------------------
6import time
7import base64
8from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any, Union, cast
9from azure.core.credentials import (
10 TokenCredential,
11 SupportsTokenInfo,
12 TokenRequestOptions,
13 TokenProvider,
14)
15from azure.core.pipeline import PipelineRequest, PipelineResponse
16from azure.core.pipeline.transport import (
17 HttpResponse as LegacyHttpResponse,
18 HttpRequest as LegacyHttpRequest,
19)
20from azure.core.rest import HttpResponse, HttpRequest
21from . import HTTPPolicy, SansIOHTTPPolicy
22from ...exceptions import ServiceRequestError
23from ._utils import get_challenge_parameter
24
25if TYPE_CHECKING:
26
27 from azure.core.credentials import (
28 AccessToken,
29 AccessTokenInfo,
30 AzureKeyCredential,
31 AzureSasCredential,
32 )
33
# Constrained type variables used to parameterize the policies below: each may be either the
# azure.core.rest type or the legacy azure.core.pipeline.transport type imported above.
HTTPResponseType = TypeVar("HTTPResponseType", HttpResponse, LegacyHttpResponse)
HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)
36
37
38# pylint:disable=too-few-public-methods
class _BearerTokenCredentialPolicyBase:
    """Shared state and helpers for policies that authorize requests with a bearer token.

    Stores the credential, the requested scopes and the most recently acquired token, and provides
    helpers to enforce HTTPS, to decide when the stored token must be replaced, and to fetch a token.

    :param credential: The credential.
    :type credential: ~azure.core.credentials.TokenProvider
    :param str scopes: Lets you specify the type of access needed.
    :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested
        tokens. Defaults to False.
    """

    def __init__(self, credential: TokenProvider, *scopes: str, **kwargs: Any) -> None:
        super().__init__()
        self._credential = credential
        self._scopes = scopes
        self._enable_cae: bool = kwargs.get("enable_cae", False)
        # No token is held until one is requested for the first time.
        self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None

    @staticmethod
    def _enforce_https(request: PipelineRequest[HTTPRequestType]) -> None:
        """Reject the request when HTTPS is enforced and its URL does not start with "https".

        :param ~azure.core.pipeline.PipelineRequest request: the request
        :raises ~azure.core.exceptions.ServiceRequestError: If enforcement is active and the URL is not https.
        """
        context = request.context

        # 'enforce_https' is taken out of the options (so it never reaches a transport implementation)
        # and, when it is an explicit opt-out, stored on the context so it persists across retries.
        # An explicit opt in needn't be kept because True is already the default.
        explicit_setting = context.options.pop("enforce_https", None)
        if explicit_setting is False:
            context["enforce_https"] = explicit_setting

        if not context.get("enforce_https", True):
            return
        if request.http_request.url.lower().startswith("https"):
            return
        raise ServiceRequestError(
            "Bearer token authentication is not permitted for non-TLS protected (non-https) URLs."
        )

    @staticmethod
    def _update_headers(headers: MutableMapping[str, str], token: str) -> None:
        """Set the Authorization header to carry the given bearer token.

        :param MutableMapping[str, str] headers: The HTTP Request headers
        :param str token: The OAuth token.
        """
        headers["Authorization"] = f"Bearer {token}"

    @property
    def _need_new_token(self) -> bool:
        """Whether the stored token is missing, due for refresh, or within 300 seconds of expiring.

        :return: True when a new token should be requested.
        :rtype: bool
        """
        now = time.time()
        current = self._token
        if not current:
            return True

        # Only some token types carry 'refresh_on'; when present and already reached, refresh.
        refresh_on = getattr(current, "refresh_on", None)
        if refresh_on and refresh_on <= now:
            return True

        return current.expires_on - now < 300

    def _get_token(self, *scopes: str, **kwargs: Any) -> Union["AccessToken", "AccessTokenInfo"]:
        """Fetch a token from the credential without storing it on the policy.

        When the credential has a ``get_token_info`` method, only the keyword arguments named in
        ``TokenRequestOptions`` are forwarded (as ``options``); otherwise every keyword argument is
        passed straight to ``get_token``.

        :param str scopes: The type of access needed.
        :return: The token returned by the credential.
        :rtype: ~azure.core.credentials.AccessToken or ~azure.core.credentials.AccessTokenInfo
        """
        if self._enable_cae:
            kwargs.setdefault("enable_cae", self._enable_cae)

        credential = self._credential
        if not hasattr(credential, "get_token_info"):
            return cast(TokenCredential, credential).get_token(*scopes, **kwargs)

        # Keep only the keyword arguments that are part of the TokenRequestOptions.
        recognized = TokenRequestOptions.__annotations__  # pylint: disable=no-member
        options = cast(
            TokenRequestOptions,
            {name: value for name, value in kwargs.items() if name in recognized},
        )
        return cast(SupportsTokenInfo, credential).get_token_info(*scopes, options=options)

    def _request_token(self, *scopes: str, **kwargs: Any) -> None:
        """Request a new token from the credential.

        This will call the credential's appropriate method to get a token and store it in the policy.

        :param str scopes: The type of access needed.
        """
        fresh_token = self._get_token(*scopes, **kwargs)
        self._token = fresh_token
109
110
class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """Adds a bearer token Authorization header to requests.

    :param credential: The credential.
    :type credential: ~azure.core.credentials.TokenProvider
    :param str scopes: Lets you specify the type of access needed.
    :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested
        tokens. Defaults to False.
    :raises ~azure.core.exceptions.ServiceRequestError: If the request fails.
    """

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Called before the policy sends a request.

        Checks the HTTPS requirement, requests a token for the policy's scopes when none is stored or the
        stored one needs replacing, then writes the token into the request's Authorization header.

        :param ~azure.core.pipeline.PipelineRequest request: the request
        """
        self._enforce_https(request)

        token_unusable = self._token is None or self._need_new_token
        if token_unusable:
            self._request_token(*self._scopes)

        current = cast(Union["AccessToken", "AccessTokenInfo"], self._token)
        self._update_headers(request.http_request.headers, current.token)

    def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
        """Acquire a token from the credential and authorize the request with it.

        Keyword arguments are passed to the credential's get_token method. The token will be cached and used to
        authorize future requests.

        :param ~azure.core.pipeline.PipelineRequest request: the request
        :param str scopes: required scopes of authentication
        """
        self._request_token(*scopes, **kwargs)
        current = cast(Union["AccessToken", "AccessTokenInfo"], self._token)
        self._update_headers(request.http_request.headers, current.token)

    def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]:
        """Authorize request with a bearer token and send it to the next policy

        On a 401 response the stored token is discarded; if the response also carries a
        WWW-Authenticate header and :meth:`on_challenge` authorizes the request, it is sent once more.

        :param request: The pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: The pipeline response object
        :rtype: ~azure.core.pipeline.PipelineResponse
        """
        self.on_request(request)
        try:
            response = self.next.send(request)
        except Exception:
            self.on_exception(request)
            raise
        self.on_response(request, response)

        if response.http_response.status_code != 401:
            return response

        # any cached token is invalid
        self._token = None
        if "WWW-Authenticate" not in response.http_response.headers:
            return response
        if not self.on_challenge(request, response):
            return response

        # if we receive a challenge response, we retrieve a new token
        # which matches the new target. In this case, we don't want to remove
        # token from the request so clear the 'insecure_domain_change' tag
        request.context.options.pop("insecure_domain_change", False)
        try:
            response = self.next.send(request)
            self.on_response(request, response)
        except Exception:
            self.on_exception(request)
            raise
        return response

    def on_challenge(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, HTTPResponseType],
    ) -> bool:
        """Authorize request according to an authentication challenge

        This method is called when the resource provider responds 401 with a WWW-Authenticate header.
        Only a Bearer challenge whose error is "insufficient_claims" is handled: its base64url-encoded
        claims are decoded and used to obtain a token that is written to the Authorization header.

        :param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
        :param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
        :returns: a bool indicating whether the policy should send the request
        :rtype: bool
        """
        challenge_headers = response.http_response.headers
        if get_challenge_parameter(challenge_headers, "Bearer", "error") != "insufficient_claims":
            return False

        encoded_claims = get_challenge_parameter(challenge_headers, "Bearer", "claims")
        if not encoded_claims:
            return False

        try:
            # Re-add the '=' padding needed to make the length a multiple of 4 before decoding.
            padding = "=" * (-len(encoded_claims) % 4)
            claims = base64.urlsafe_b64decode(encoded_claims + padding).decode("utf-8")
            if not claims:
                return False
            token = self._get_token(*self._scopes, claims=claims)
            bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], token).token
            request.http_request.headers["Authorization"] = "Bearer " + bearer_token
        except Exception:  # pylint:disable=broad-except
            # Any failure while decoding or fetching the token means the request is not re-sent.
            return False
        return True

    def on_response(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, HTTPResponseType],
    ) -> None:
        """Executed after the request comes back from the next policy.

        The implementation here does nothing.

        :param request: Request to be modified after returning from the policy.
        :type request: ~azure.core.pipeline.PipelineRequest
        :param response: Pipeline response object
        :type response: ~azure.core.pipeline.PipelineResponse
        """

    def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Executed when an exception is raised while executing the next policy.

        This method is executed inside the exception handler. The implementation here does nothing.

        :param request: The Pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        # pylint: disable=unused-argument
238
239
class AzureKeyCredentialPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """Adds a key header for the provided credential.

    :param credential: The credential used to authenticate requests.
    :type credential: ~azure.core.credentials.AzureKeyCredential
    :param str name: The name of the key header used for the credential.
    :keyword str prefix: The name of the prefix for the header value if any.
    :raises ValueError: if name is None or empty.
    :raises TypeError: if name is not a string or if credential is not an instance of AzureKeyCredential.
    """

    def __init__(  # pylint: disable=unused-argument
        self,
        credential: "AzureKeyCredential",
        name: str,
        *,
        prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()

        # Validate the inputs before storing anything; the credential only needs a 'key' attribute.
        credential_has_key = hasattr(credential, "key")
        if not credential_has_key:
            raise TypeError("String is not a supported credential input type. Use an instance of AzureKeyCredential.")
        if not name:
            raise ValueError("name can not be None or empty")
        if not isinstance(name, str):
            raise TypeError("name must be a string.")

        # A supplied prefix is stored with a trailing space so it can be prepended to the key directly.
        if prefix:
            header_prefix = prefix + " "
        else:
            header_prefix = ""

        self._credential = credential
        self._name = name
        self._prefix = header_prefix

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Called before the policy sends a request.

        Writes the (optionally prefixed) credential key into the configured header.

        :param request: The request to be modified before sending.
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        header_value = f"{self._prefix}{self._credential.key}"
        request.http_request.headers[self._name] = header_value
277
278
class AzureSasCredentialPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """Adds a shared access signature to query for the provided credential.

    :param credential: The credential used to authenticate requests.
    :type credential: ~azure.core.credentials.AzureSasCredential
    :raises ValueError: if credential is None.
    """

    def __init__(
        self,  # pylint: disable=unused-argument
        credential: "AzureSasCredential",
        **kwargs: Any,
    ) -> None:
        super().__init__()
        if not credential:
            raise ValueError("credential can not be None")
        self._credential = credential

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Called before the policy sends a request.

        Appends the credential's signature to the request URL, joining it with "&" when the request
        already has a query (and the signature is not already in the URL) or with "?" otherwise.

        :param request: The request to be modified before sending.
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        http_request = request.http_request
        url = http_request.url
        existing_query = http_request.query

        # The signature may be supplied with a leading "?"; drop it so the separator is chosen here.
        raw_signature = self._credential.signature
        signature = raw_signature[1:] if raw_signature.startswith("?") else raw_signature

        if existing_query:
            # Skip the append when the URL already contains the signature.
            if signature not in url:
                url = f"{url}&{signature}"
        else:
            separator = "" if url.endswith("?") else "?"
            url = f"{url}{separator}{signature}"

        http_request.url = url