1# -------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See LICENSE.txt in the project root for
4# license information.
5# -------------------------------------------------------------------------
6import time
7import base64
8from typing import Any, Awaitable, Optional, cast, TypeVar, Union
9
10from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions
11from azure.core.credentials_async import (
12 AsyncTokenCredential,
13 AsyncSupportsTokenInfo,
14 AsyncTokenProvider,
15)
16from azure.core.pipeline import PipelineRequest, PipelineResponse
17from azure.core.pipeline.policies import AsyncHTTPPolicy
18from azure.core.pipeline.policies._authentication import (
19 _BearerTokenCredentialPolicyBase,
20)
21from azure.core.pipeline.transport import (
22 AsyncHttpResponse as LegacyAsyncHttpResponse,
23 HttpRequest as LegacyHttpRequest,
24)
25from azure.core.rest import AsyncHttpResponse, HttpRequest
26from azure.core.utils._utils import get_running_async_lock
27from ._utils import get_challenge_parameter
28
29from .._tools_async import await_result
30
# Request type the policy is generic over: either the azure.core.rest request or the legacy transport request.
HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)
# Matching async response type: either the azure.core.rest response or the legacy transport response.
AsyncHTTPResponseType = TypeVar(
    "AsyncHTTPResponseType",
    AsyncHttpResponse,
    LegacyAsyncHttpResponse,
)
33
34
class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]):
    """Async pipeline policy that authorizes each request with a bearer token.

    A token is obtained from ``credential`` on demand, cached on the policy instance, and reused until
    ``_need_new_token`` reports that it is missing, past its ``refresh_on`` time, or within 300 seconds
    of ``expires_on``.

    :param credential: The credential.
    :type credential: ~azure.core.credentials_async.AsyncTokenProvider
    :param str scopes: Lets you specify the type of access needed.
    :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested
        tokens. Defaults to False.
    """

    def __init__(self, credential: AsyncTokenProvider, *scopes: str, **kwargs: Any) -> None:
        super().__init__()
        self._credential = credential
        self._scopes = scopes
        self._enable_cae: bool = kwargs.get("enable_cae", False)
        # Most recently acquired token; None until one has been requested.
        self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None
        # Backing field for the ``_lock`` property, which creates the lock on first access.
        self._lock_instance = None

    @property
    def _lock(self):
        """Lock that serializes token acquisition, created on first access.

        NOTE(review): creation is presumably deferred because ``get_running_async_lock`` needs a running
        event loop — confirm.
        """
        lock = self._lock_instance
        if lock is None:
            lock = self._lock_instance = get_running_async_lock()
        return lock

    async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Ensure a usable token is cached, then set it as the request's Authorization header.

        This only prepares the request; sending it to the next policy is done by ``send``.

        :param request: The pipeline request object to be modified.
        :type request: ~azure.core.pipeline.PipelineRequest
        :raises ~azure.core.exceptions.ServiceRequestError: If the request fails.
        """
        _BearerTokenCredentialPolicyBase._enforce_https(request)  # pylint:disable=protected-access

        stale = self._token is None or self._need_new_token()
        if stale:
            async with self._lock:
                # Evaluate again while holding the lock: a concurrent coroutine may have refreshed the
                # token while this one was waiting for the lock.
                still_stale = self._token is None or self._need_new_token()
                if still_stale:
                    await self._request_token(*self._scopes)

        current = cast(Union[AccessToken, AccessTokenInfo], self._token)
        request.http_request.headers["Authorization"] = "Bearer " + current.token

    async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
        """Unconditionally fetch a new token for ``scopes`` and put it on the request.

        The new token replaces the cached one, so subsequent requests reuse it. ``kwargs`` are handed to
        ``_get_token``: when the credential exposes ``get_token_info`` only keys declared by
        ``TokenRequestOptions`` are forwarded (as ``options``); otherwise all of them go to ``get_token``.

        :param ~azure.core.pipeline.PipelineRequest request: the request
        :param str scopes: required scopes of authentication
        """
        async with self._lock:
            await self._request_token(*scopes, **kwargs)
            fresh = cast(Union[AccessToken, AccessTokenInfo], self._token)
            request.http_request.headers["Authorization"] = "Bearer " + fresh.token

    async def send(
        self, request: PipelineRequest[HTTPRequestType]
    ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]:
        """Authorize the request, send it down the pipeline, and retry once after an accepted challenge.

        On a 401 response the cached token is discarded. If the response also carries a
        ``WWW-Authenticate`` header and ``on_challenge`` returns a truthy value, the request is sent again.

        :param request: The pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: The pipeline response object
        :rtype: ~azure.core.pipeline.PipelineResponse
        """

        async def forward() -> "PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]":
            # Pass the request to the next policy, running on_exception on failure and on_response on success.
            try:
                result = await self.next.send(request)
            except Exception:
                await await_result(self.on_exception, request)
                raise
            await await_result(self.on_response, request, result)
            return result

        await await_result(self.on_request, request)
        response = await forward()

        if response.http_response.status_code != 401:
            return response

        self._token = None  # the service rejected the request, so any cached token is invalid
        if "WWW-Authenticate" not in response.http_response.headers:
            return response
        if not await self.on_challenge(request, response):
            return response

        # The challenge produced a token for the new target, so the Authorization header must be kept:
        # drop the 'insecure_domain_change' marker before resending.
        request.context.options.pop("insecure_domain_change", False)
        return await forward()

    async def on_challenge(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType],
    ) -> bool:
        """Try to satisfy a 401 challenge; only a Bearer ``insufficient_claims`` challenge is handled.

        The challenge's ``claims`` parameter is base64url-decoded and used to fetch a new token via
        ``authorize_request``. Any failure while decoding or fetching the token yields False.

        :param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
        :param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
        :returns: a bool indicating whether the policy should send the request
        :rtype: bool
        """
        headers = response.http_response.headers
        if get_challenge_parameter(headers, "Bearer", "error") != "insufficient_claims":
            return False

        encoded_claims = get_challenge_parameter(headers, "Bearer", "claims")
        if not encoded_claims:
            return False

        try:
            # The claims may arrive without "=" padding; restore it so the length is a multiple of 4.
            padded = encoded_claims + "=" * (-len(encoded_claims) % 4)
            claims = base64.urlsafe_b64decode(padded).decode("utf-8")
            if not claims:
                return False
            await self.authorize_request(request, *self._scopes, claims=claims)
        except Exception:  # pylint:disable=broad-except
            return False
        return True

    def on_response(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType],
    ) -> Optional[Awaitable[None]]:
        """Hook invoked by ``send`` after the next policy returns a response; does nothing here.

        ``send`` runs it through ``await_result``, so an override may be a plain or a coroutine function.

        :param request: Request to be modified after returning from the policy.
        :type request: ~azure.core.pipeline.PipelineRequest
        :param response: Pipeline response object
        :type response: ~azure.core.pipeline.PipelineResponse
        """

    def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Hook invoked by ``send`` when the next policy raises; does nothing here.

        It runs inside the exception handler, and ``send`` re-raises the exception afterwards.

        :param request: The Pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        # pylint: disable=unused-argument
        return

    def _need_new_token(self) -> bool:
        """Return True if there is no cached token, its ``refresh_on`` time (when set) has passed, or it
        expires within 300 seconds."""
        now = time.time()
        token = self._token
        if not token:
            return True
        refresh_on = getattr(token, "refresh_on", None)
        if refresh_on and refresh_on <= now:
            return True
        return token.expires_on - now < 300

    async def _get_token(self, *scopes: str, **kwargs: Any) -> Union["AccessToken", "AccessTokenInfo"]:
        """Obtain a token from the credential, preferring ``get_token_info`` when it is available.

        With ``get_token_info``, only kwargs whose names are declared on ``TokenRequestOptions`` are
        forwarded (as ``options``); any other kwargs are not passed on. Otherwise every kwarg goes to
        ``get_token``.
        """
        if self._enable_cae:
            kwargs.setdefault("enable_cae", self._enable_cae)

        if not hasattr(self._credential, "get_token_info"):
            return await await_result(
                cast(AsyncTokenCredential, self._credential).get_token,
                *scopes,
                **kwargs,
            )

        option_keys = TokenRequestOptions.__annotations__  # pylint: disable=no-member
        options = cast(
            TokenRequestOptions,
            {key: value for key, value in kwargs.items() if key in option_keys},
        )
        return await await_result(
            cast(AsyncSupportsTokenInfo, self._credential).get_token_info,
            *scopes,
            options=options,
        )

    async def _request_token(self, *scopes: str, **kwargs: Any) -> None:
        """Fetch a new token from the credential and cache it on the policy.

        :param str scopes: The type of access needed.
        """
        token = await self._get_token(*scopes, **kwargs)
        self._token = token