1# -------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See LICENSE.txt in the project root for
4# license information.
5# -------------------------------------------------------------------------
6import time
7import base64
8from typing import Any, Awaitable, Optional, cast, TypeVar, Union
9
10from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions
11from azure.core.credentials_async import (
12 AsyncTokenCredential,
13 AsyncSupportsTokenInfo,
14 AsyncTokenProvider,
15)
16from azure.core.exceptions import HttpResponseError
17from azure.core.pipeline import PipelineRequest, PipelineResponse
18from azure.core.pipeline.policies import AsyncHTTPPolicy
19from azure.core.pipeline.policies._authentication import (
20 _BearerTokenCredentialPolicyBase,
21)
22from azure.core.pipeline.transport import (
23 AsyncHttpResponse as LegacyAsyncHttpResponse,
24 HttpRequest as LegacyHttpRequest,
25)
26from azure.core.rest import AsyncHttpResponse, HttpRequest
27from azure.core.utils._utils import get_running_async_lock
28from ._utils import get_challenge_parameter
29
30from .._tools_async import await_result
31
# Request type the policy operates on: the ``azure.core.rest`` request or the legacy transport request.
HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)
# Async response type produced further down the pipeline: the ``azure.core.rest`` response or the legacy
# transport response.
AsyncHTTPResponseType = TypeVar(
    "AsyncHTTPResponseType",
    AsyncHttpResponse,
    LegacyAsyncHttpResponse,
)
34
35
class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]):
    """Adds a bearer token Authorization header to requests.

    :param credential: The credential.
    :type credential: ~azure.core.credentials_async.AsyncTokenProvider
    :param str scopes: Lets you specify the type of access needed.
    :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested
        tokens. Defaults to False.
    """

    def __init__(self, credential: AsyncTokenProvider, *scopes: str, **kwargs: Any) -> None:
        super().__init__()
        self._credential = credential
        self._scopes = scopes
        # Not created here: the ``_lock`` property creates it on first access via ``get_running_async_lock``
        # (presumably so it matches the async runtime active at request time -- confirm).
        self._lock_instance = None
        self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None
        self._enable_cae: bool = kwargs.get("enable_cae", False)

    @property
    def _lock(self):
        """Return the lock that serializes token requests, creating it on first access.

        :return: The lock obtained from ``get_running_async_lock``.
        """
        if self._lock_instance is None:
            self._lock_instance = get_running_async_lock()
        return self._lock_instance

    async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Adds a bearer token Authorization header to request and sends request to next policy.

        A new token is requested only when none is cached or the cached one needs refreshing
        (see ``_need_new_token``).

        :param request: The pipeline request object to be modified.
        :type request: ~azure.core.pipeline.PipelineRequest
        :raises ~azure.core.exceptions.ServiceRequestError: If the request fails.
        """
        _BearerTokenCredentialPolicyBase._enforce_https(request)  # pylint:disable=protected-access

        if self._token is None or self._need_new_token():
            async with self._lock:
                # double check because another coroutine may have acquired a token while we waited to acquire the lock
                if self._token is None or self._need_new_token():
                    await self._request_token(*self._scopes)
        bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token
        request.http_request.headers["Authorization"] = "Bearer " + bearer_token

    async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
        """Acquire a token from the credential and authorize the request with it.

        Keyword arguments are passed to the credential's get_token method. The token will be cached and used to
        authorize future requests.

        :param ~azure.core.pipeline.PipelineRequest request: the request
        :param str scopes: required scopes of authentication
        """

        async with self._lock:
            await self._request_token(*scopes, **kwargs)
        bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token
        request.http_request.headers["Authorization"] = "Bearer " + bearer_token

    async def send(
        self, request: PipelineRequest[HTTPRequestType]
    ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]:
        """Authorize request with a bearer token and send it to the next policy

        On a 401 response the cached token is discarded. If the response carries a ``WWW-Authenticate``
        header, ``on_challenge`` decides whether the request is re-authorized and sent exactly once more.

        :param request: The pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: The pipeline response object
        :rtype: ~azure.core.pipeline.PipelineResponse
        """
        await await_result(self.on_request, request)
        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]
        try:
            response = await self.next.send(request)
        except Exception:
            await await_result(self.on_exception, request)
            raise
        await await_result(self.on_response, request, response)

        if response.http_response.status_code == 401:
            self._token = None  # any cached token is invalid
            if "WWW-Authenticate" in response.http_response.headers:
                try:
                    request_authorized = await self.on_challenge(request, response)
                except Exception as ex:
                    # If the response is streamed, read it so the error message is immediately available to the user.
                    # Otherwise, a generic error message will be given and the user will have to read the response
                    # body to see the actual error.
                    if response.context.options.get("stream"):
                        try:
                            await response.http_response.read()  # type: ignore
                        except Exception:  # pylint:disable=broad-except
                            pass

                    # Raise the exception from the token request with the original 401 response
                    raise ex from HttpResponseError(response=response.http_response)

                if request_authorized:
                    # if we receive a challenge response, we retrieve a new token
                    # which matches the new target. In this case, we don't want to remove
                    # token from the request so clear the 'insecure_domain_change' tag
                    request.context.options.pop("insecure_domain_change", False)
                    try:
                        response = await self.next.send(request)
                    except Exception:
                        await await_result(self.on_exception, request)
                        raise
                    await await_result(self.on_response, request, response)

        return response

    async def on_challenge(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType],
    ) -> bool:
        """Authorize request according to an authentication challenge

        This method is called when the resource provider responds 401 with a WWW-Authenticate header.
        Only Bearer challenges with ``error="insufficient_claims"`` are handled. Their base64url-encoded
        ``claims`` parameter is decoded, and a new token is requested with those claims. A challenge
        whose claims are missing, empty, or not decodable as base64url-encoded UTF-8 is not handled.

        :param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
        :param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
        :returns: a bool indicating whether the policy should send the request
        :rtype: bool
        """
        headers = response.http_response.headers
        error = get_challenge_parameter(headers, "Bearer", "error")
        if error != "insufficient_claims":
            return False
        encoded_claims = get_challenge_parameter(headers, "Bearer", "claims")
        if not encoded_claims:
            return False
        # The claims value may arrive with its "=" padding stripped; restore it before decoding.
        padding_needed = -len(encoded_claims) % 4
        try:
            claims = base64.urlsafe_b64decode(encoded_claims + "=" * padding_needed).decode("utf-8")
        except ValueError:
            # binascii.Error (malformed base64), the non-ASCII input error, and UnicodeDecodeError
            # (non-UTF-8 payload) all subclass ValueError. The header is server-controlled, so treat an
            # undecodable value as an unhandled challenge: send() then returns the original 401 response
            # instead of surfacing an opaque decoding error.
            return False
        if not claims:
            return False
        await self.authorize_request(request, *self._scopes, claims=claims)
        return True

    def on_response(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType],
    ) -> Optional[Awaitable[None]]:
        """Executed after the request comes back from the next policy.

        :param request: Request to be modified after returning from the policy.
        :type request: ~azure.core.pipeline.PipelineRequest
        :param response: Pipeline response object
        :type response: ~azure.core.pipeline.PipelineResponse
        """

    def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Executed when an exception is raised while executing the next policy.

        This method is executed inside the exception handler.

        :param request: The Pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        # pylint: disable=unused-argument
        return

    def _need_new_token(self) -> bool:
        """Return whether a new token should be requested.

        True when there is no cached token, when the token's optional ``refresh_on`` time has passed,
        or when the token expires within 300 seconds.

        :return: Whether to request a new token.
        :rtype: bool
        """
        now = time.time()
        refresh_on = getattr(self._token, "refresh_on", None)
        return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300

    async def _get_token(self, *scopes: str, **kwargs: Any) -> Union["AccessToken", "AccessTokenInfo"]:
        """Obtain a token from the credential, preferring ``get_token_info`` when it is available.

        For ``get_token_info``, only keyword arguments whose names are ``TokenRequestOptions`` keys are
        forwarded, packed into ``options``; all other keyword arguments are dropped. For ``get_token``,
        every keyword argument is forwarded unchanged.

        :param str scopes: The type of access needed.
        :return: The token returned by the credential.
        :rtype: ~azure.core.credentials.AccessToken or ~azure.core.credentials.AccessTokenInfo
        """
        if self._enable_cae:
            kwargs.setdefault("enable_cae", self._enable_cae)

        if hasattr(self._credential, "get_token_info"):
            options: TokenRequestOptions = {}
            # Loop through all the keyword arguments and check if they are part of the TokenRequestOptions.
            for key in list(kwargs.keys()):
                if key in TokenRequestOptions.__annotations__:  # pylint: disable=no-member
                    options[key] = kwargs.pop(key)  # type: ignore[literal-required]

            return await await_result(
                cast(AsyncSupportsTokenInfo, self._credential).get_token_info,
                *scopes,
                options=options,
            )
        return await await_result(
            cast(AsyncTokenCredential, self._credential).get_token,
            *scopes,
            **kwargs,
        )

    async def _request_token(self, *scopes: str, **kwargs: Any) -> None:
        """Request a new token from the credential.

        This will call the credential's appropriate method to get a token and store it in the policy.

        :param str scopes: The type of access needed.
        """
        self._token = await self._get_token(*scopes, **kwargs)