1# -------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See LICENSE.txt in the project root for
4# license information.
5# -------------------------------------------------------------------------
6import time
7from typing import Any, Awaitable, Optional, cast, TypeVar, Union
8
9from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions
10from azure.core.credentials_async import AsyncTokenCredential, AsyncSupportsTokenInfo, AsyncTokenProvider
11from azure.core.pipeline import PipelineRequest, PipelineResponse
12from azure.core.pipeline.policies import AsyncHTTPPolicy
13from azure.core.pipeline.policies._authentication import (
14 _BearerTokenCredentialPolicyBase,
15)
16from azure.core.pipeline.transport import AsyncHttpResponse as LegacyAsyncHttpResponse, HttpRequest as LegacyHttpRequest
17from azure.core.rest import AsyncHttpResponse, HttpRequest
18from azure.core.utils._utils import get_running_async_lock
19
20from .._tools_async import await_result
21
# Constrained TypeVars letting AsyncBearerTokenCredentialPolicy be parameterized over either the
# azure.core.rest request/response types or the legacy azure.core.pipeline.transport ones.
AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType", AsyncHttpResponse, LegacyAsyncHttpResponse)
HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)
24
25
class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]):
    """Asynchronous pipeline policy that authorizes requests with a bearer token.

    A token for ``scopes`` is obtained from ``credential``, cached on the policy and written into each request's
    ``Authorization`` header. The cached token is replaced when it is missing, when its ``refresh_on`` time
    (for token types that carry one) has passed, or when fewer than 300 seconds remain before ``expires_on``.

    :param credential: The credential.
    :type credential: ~azure.core.credentials_async.AsyncTokenProvider
    :param str scopes: Lets you specify the type of access needed.
    :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested
        tokens. Defaults to False.
    """

    def __init__(self, credential: AsyncTokenProvider, *scopes: str, **kwargs: Any) -> None:
        super().__init__()
        self._credential = credential
        self._scopes = scopes
        self._enable_cae: bool = kwargs.get("enable_cae", False)
        # Cached token; filled in by _request_token and cleared when the service answers 401.
        self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None
        # Created on first use by the ``_lock`` property, presumably because get_running_async_lock needs
        # a running event loop -- NOTE(review): confirm.
        self._lock_instance = None

    @property
    def _lock(self):
        lock = self._lock_instance
        if lock is None:
            lock = get_running_async_lock()
            self._lock_instance = lock
        return lock

    async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Write the cached bearer token into the request's Authorization header.

        A new token is requested first when none is cached or the cached one is due for refresh.

        :param request: The pipeline request object to be modified.
        :type request: ~azure.core.pipeline.PipelineRequest
        :raises: :class:`~azure.core.exceptions.ServiceRequestError`
        """
        _BearerTokenCredentialPolicyBase._enforce_https(request)  # pylint:disable=protected-access

        stale = self._token is None or self._need_new_token()
        if stale:
            async with self._lock:
                # Re-check while holding the lock: another coroutine may have refreshed the token
                # while this one was waiting to acquire it.
                if self._token is None or self._need_new_token():
                    await self._request_token(*self._scopes)

        current = cast(Union[AccessToken, AccessTokenInfo], self._token)
        request.http_request.headers["Authorization"] = "Bearer " + current.token

    async def authorize_request(
        self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any
    ) -> None:
        """Unconditionally fetch a token for ``scopes`` and authorize ``request`` with it.

        Keyword arguments are handed to ``_request_token``, which passes them on to the credential. The new
        token replaces the cached one, so later requests handled by this policy reuse it.

        :param ~azure.core.pipeline.PipelineRequest request: the request
        :param str scopes: required scopes of authentication
        """
        async with self._lock:
            await self._request_token(*scopes, **kwargs)
        current = cast(Union[AccessToken, AccessTokenInfo], self._token)
        request.http_request.headers["Authorization"] = "Bearer " + current.token

    async def send(
        self, request: PipelineRequest[HTTPRequestType]
    ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]:
        """Authorize the request, pass it to the next policy, and resend once if a 401 challenge is handled.

        :param request: The pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: The pipeline response object
        :rtype: ~azure.core.pipeline.PipelineResponse
        """

        async def forward() -> "PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]":
            # Send to the next policy; run on_exception if that raises, on_response if it succeeds.
            try:
                outcome = await self.next.send(request)
            except Exception:  # pylint:disable=broad-except
                await await_result(self.on_exception, request)
                raise
            await await_result(self.on_response, request, outcome)
            return outcome

        await await_result(self.on_request, request)
        response = await forward()

        if response.http_response.status_code != 401:
            return response

        # The service rejected the request, so whatever token is cached can no longer be trusted.
        self._token = None
        if "WWW-Authenticate" not in response.http_response.headers:
            return response
        if not await self.on_challenge(request, response):
            return response

        # on_challenge obtained a token matching the challenge's target, so the Authorization header must be
        # kept on the resend: clear the 'insecure_domain_change' tag that would otherwise strip it.
        request.context.options.pop("insecure_domain_change", False)
        return await forward()

    async def on_challenge(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType],
    ) -> bool:
        """Handle an authentication challenge; subclasses override this to authorize the request.

        Called by ``send`` when the service answers 401 with a WWW-Authenticate header. This base
        implementation never handles the challenge.

        :param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
        :param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
        :returns: a bool indicating whether the policy should send the request
        :rtype: bool
        """
        # pylint:disable=unused-argument
        return False

    def on_response(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType],
    ) -> Optional[Awaitable[None]]:
        """Hook run after the next policy returns a response; does nothing by default.

        :param request: Request to be modified after returning from the policy.
        :type request: ~azure.core.pipeline.PipelineRequest
        :param response: Pipeline response object
        :type response: ~azure.core.pipeline.PipelineResponse
        """
        return None

    def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Hook run inside the exception handler when the next policy raises; does nothing by default.

        :param request: The Pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        # pylint: disable=unused-argument
        return None

    def _need_new_token(self) -> bool:
        now = time.time()
        token = self._token
        if not token:
            return True
        # refresh_on is optional: AccessTokenInfo-style tokens may carry it, plain AccessToken does not.
        refresh_on = getattr(token, "refresh_on", None)
        if refresh_on and refresh_on <= now:
            return True
        # Otherwise refresh once the token is within five minutes of expiring.
        return token.expires_on - now < 300

    async def _request_token(self, *scopes: str, **kwargs: Any) -> None:
        """Ask the credential for a new token and cache it on the policy.

        Credentials exposing ``get_token_info`` receive only the keyword arguments that are keys of
        ``TokenRequestOptions``, bundled into ``options``. Other credentials receive every keyword argument
        through ``get_token``.

        :param str scopes: The type of access needed.
        """
        if self._enable_cae:
            kwargs.setdefault("enable_cae", self._enable_cae)

        if not hasattr(self._credential, "get_token_info"):
            legacy = cast(AsyncTokenCredential, self._credential)
            self._token = await await_result(legacy.get_token, *scopes, **kwargs)
            return

        # Keep only keys declared on the TokenRequestOptions TypedDict; any other keyword is not forwarded.
        allowed = TokenRequestOptions.__annotations__  # pylint:disable=no-member
        options = cast(TokenRequestOptions, {name: value for name, value in kwargs.items() if name in allowed})
        supports_info = cast(AsyncSupportsTokenInfo, self._credential)
        self._token = await await_result(supports_info.get_token_info, *scopes, options=options)