1# -------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See LICENSE.txt in the project root for
4# license information.
5# -------------------------------------------------------------------------
6import random
7import base64
8from typing import Any, Awaitable, Optional, cast, TypeVar, Union
9
10from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions
11from azure.core.credentials_async import (
12 AsyncTokenCredential,
13 AsyncSupportsTokenInfo,
14 AsyncTokenProvider,
15)
16from azure.core.exceptions import HttpResponseError
17from azure.core.pipeline import PipelineRequest, PipelineResponse
18from azure.core.pipeline.policies import AsyncHTTPPolicy
19from azure.core.pipeline.policies._authentication import (
20 _enforce_https,
21 _should_refresh_token,
22 MAX_REFRESH_JITTER_SECONDS,
23)
24from azure.core.pipeline.transport import (
25 AsyncHttpResponse as LegacyAsyncHttpResponse,
26 HttpRequest as LegacyHttpRequest,
27)
28from azure.core.rest import AsyncHttpResponse, HttpRequest
29from azure.core.utils._utils import get_running_async_lock
30from ._utils import get_challenge_parameter
31
32from .._tools_async import await_result
33
# Response types this policy may be parameterized with: the ``azure.core.rest``
# async response or the legacy ``azure.core.pipeline.transport`` async response.
AsyncHTTPResponseType = TypeVar(
    "AsyncHTTPResponseType",
    AsyncHttpResponse,
    LegacyAsyncHttpResponse,
)
# Request types this policy may be parameterized with: the ``azure.core.rest``
# request or the legacy ``azure.core.pipeline.transport`` request.
HTTPRequestType = TypeVar(
    "HTTPRequestType",
    HttpRequest,
    LegacyHttpRequest,
)
36
37
class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]):
    """Adds a bearer token Authorization header to requests.

    The token obtained from the credential is cached on the policy and reused until
    ``_should_refresh_token`` reports that a new one is needed (a random jitter of up to
    ``MAX_REFRESH_JITTER_SECONDS`` is re-rolled every time a token is acquired). A 401 response
    clears the cached token; when that response also carries a ``WWW-Authenticate`` header,
    :meth:`on_challenge` decides whether the request is re-authorized and sent one more time.

    :param credential: The credential.
    :type credential: ~azure.core.credentials_async.AsyncTokenProvider
    :param str scopes: Lets you specify the type of access needed.
    :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested
        tokens. Defaults to False.
    """

    def __init__(self, credential: AsyncTokenProvider, *scopes: str, **kwargs: Any) -> None:
        super().__init__()
        self._credential = credential
        self._scopes = scopes
        # Created on first use by the ``_lock`` property rather than here -- presumably so the lock
        # belongs to whichever async runtime is running at that time.
        self._lock_instance = None
        self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None
        self._enable_cae: bool = kwargs.get("enable_cae", False)
        # Jitter (seconds) handed to ``_should_refresh_token``; re-rolled in ``_request_token``.
        self._refresh_jitter = 0

    @property
    def _lock(self):
        """Return the lock serializing token acquisition, creating it on first access.

        :return: The lock returned by ``get_running_async_lock``.
        """
        if self._lock_instance is None:
            self._lock_instance = get_running_async_lock()
        return self._lock_instance

    async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Adds a bearer token Authorization header to request and sends request to next policy.

        A new token is requested only when none is cached or the cached one needs refreshing.

        :param request: The pipeline request object to be modified.
        :type request: ~azure.core.pipeline.PipelineRequest
        :raises ~azure.core.exceptions.ServiceRequestError: If the request fails.
        """
        _enforce_https(request)

        if self._token is None or self._need_new_token():
            async with self._lock:
                # double check because another coroutine may have acquired a token while we waited to acquire the lock
                if self._token is None or self._need_new_token():
                    await self._request_token(*self._scopes)
        bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token
        request.http_request.headers["Authorization"] = "Bearer " + bearer_token

    async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
        """Acquire a token from the credential and authorize the request with it.

        Keyword arguments are passed to the credential's get_token method. The token will be cached and used to
        authorize future requests.

        Unlike :meth:`on_request`, this always requests a new token, even if a cached one is still valid.

        :param ~azure.core.pipeline.PipelineRequest request: the request
        :param str scopes: required scopes of authentication
        """

        async with self._lock:
            await self._request_token(*scopes, **kwargs)
        bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token
        request.http_request.headers["Authorization"] = "Bearer " + bearer_token

    async def send(
        self, request: PipelineRequest[HTTPRequestType]
    ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]:
        """Authorize request with a bearer token and send it to the next policy

        On a 401 response the cached token is discarded. If the response also contains a
        ``WWW-Authenticate`` header, :meth:`on_challenge` is consulted; when it returns True the request
        is sent once more and that second response is returned.

        :param request: The pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: The pipeline response object
        :rtype: ~azure.core.pipeline.PipelineResponse
        """
        await await_result(self.on_request, request)
        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]
        try:
            response = await self.next.send(request)
        except Exception:
            await await_result(self.on_exception, request)
            raise
        await await_result(self.on_response, request, response)

        if response.http_response.status_code == 401:
            self._token = None  # any cached token is invalid
            if "WWW-Authenticate" in response.http_response.headers:
                try:
                    request_authorized = await self.on_challenge(request, response)
                except Exception as ex:
                    # If the response is streamed, read it so the error message is immediately available to the user.
                    # Otherwise, a generic error message will be given and the user will have to read the response
                    # body to see the actual error.
                    if response.context.options.get("stream"):
                        try:
                            await response.http_response.read()  # type: ignore
                        except Exception:  # pylint:disable=broad-except
                            pass

                    # Raise the exception from the token request with the original 401 response
                    raise ex from HttpResponseError(response=response.http_response)

                if request_authorized:
                    try:
                        response = await self.next.send(request)
                    except Exception:
                        await await_result(self.on_exception, request)
                        raise
                    await await_result(self.on_response, request, response)

        return response

    async def on_challenge(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType],
    ) -> bool:
        """Authorize request according to an authentication challenge

        This method is called when the resource provider responds 401 with a WWW-Authenticate header.

        Only a Bearer challenge whose ``error`` is ``insufficient_claims`` is handled: its ``claims``
        parameter is base64url-decoded and a new token is requested with those claims. A challenge
        without claims, with claims that cannot be decoded, or with claims that decode to an empty
        string is not handled, and False is returned so the original 401 response reaches the caller.

        :param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
        :param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
        :returns: a bool indicating whether the policy should send the request
        :rtype: bool
        """
        headers = response.http_response.headers
        error = get_challenge_parameter(headers, "Bearer", "error")
        if error != "insufficient_claims":
            return False

        encoded_claims = get_challenge_parameter(headers, "Bearer", "claims")
        if not encoded_claims:
            return False

        try:
            # The encoded claims may arrive without trailing '=' padding; restore it before decoding.
            padding_needed = -len(encoded_claims) % 4
            claims = base64.urlsafe_b64decode(encoded_claims + "=" * padding_needed).decode("utf-8")
        except ValueError:
            # binascii.Error (malformed base64, non-ASCII input) and UnicodeDecodeError are both
            # ValueError subclasses. An unparseable challenge cannot be satisfied, so report it as
            # unhandled instead of surfacing a decoding error from inside the auth policy.
            return False

        if not claims:
            return False

        # Errors from the credential are deliberately not caught here; ``send`` re-raises them
        # chained to the original 401 response.
        await self.authorize_request(request, *self._scopes, claims=claims)
        return True

    def on_response(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType],
    ) -> Optional[Awaitable[None]]:
        """Executed after the request comes back from the next policy.

        The default implementation does nothing.

        :param request: Request to be modified after returning from the policy.
        :type request: ~azure.core.pipeline.PipelineRequest
        :param response: Pipeline response object
        :type response: ~azure.core.pipeline.PipelineResponse
        """

    def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Executed when an exception is raised while executing the next policy.

        This method is executed inside the exception handler. The default implementation does nothing;
        ``send`` re-raises the exception afterwards.

        :param request: The Pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        # pylint: disable=unused-argument
        return

    def _need_new_token(self) -> bool:
        """Return whether the cached token should be replaced, per ``_should_refresh_token``.

        :return: True if a new token should be requested.
        :rtype: bool
        """
        return _should_refresh_token(self._token, self._refresh_jitter)

    async def _get_token(self, *scopes: str, **kwargs: Any) -> Union["AccessToken", "AccessTokenInfo"]:
        """Obtain a token from the credential without caching it.

        When the credential exposes ``get_token_info``, only keyword arguments whose names are keys of
        ``TokenRequestOptions`` are forwarded (as ``options``); any other keyword arguments are not
        forwarded. Otherwise all keyword arguments are passed to ``get_token``.

        :param str scopes: The type of access needed.
        :return: The token returned by the credential.
        :rtype: ~azure.core.credentials.AccessToken or ~azure.core.credentials.AccessTokenInfo
        """
        if self._enable_cae:
            # An explicit enable_cae from the caller takes precedence over the policy-level setting.
            kwargs.setdefault("enable_cae", self._enable_cae)

        if hasattr(self._credential, "get_token_info"):
            options: TokenRequestOptions = {}
            # Loop through all the keyword arguments and check if they are part of the TokenRequestOptions.
            for key in list(kwargs.keys()):
                if key in TokenRequestOptions.__annotations__:  # pylint: disable=no-member
                    options[key] = kwargs.pop(key)  # type: ignore[literal-required]

            return await await_result(
                cast(AsyncSupportsTokenInfo, self._credential).get_token_info,
                *scopes,
                options=options,
            )
        return await await_result(
            cast(AsyncTokenCredential, self._credential).get_token,
            *scopes,
            **kwargs,
        )

    async def _request_token(self, *scopes: str, **kwargs: Any) -> None:
        """Request a new token from the credential.

        This will call the credential's appropriate method to get a token and store it in the policy.
        A fresh refresh jitter is chosen for each newly stored token.

        :param str scopes: The type of access needed.
        """
        self._token = await self._get_token(*scopes, **kwargs)
        self._refresh_jitter = random.randint(0, MAX_REFRESH_JITTER_SECONDS)