Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/azure/core/pipeline/policies/_authentication_async.py: 28%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

103 statements  

1# ------------------------------------------------------------------------- 

2# Copyright (c) Microsoft Corporation. All rights reserved. 

3# Licensed under the MIT License. See LICENSE.txt in the project root for 

4# license information. 

5# ------------------------------------------------------------------------- 

6import random 

7import base64 

8from typing import Any, Awaitable, Optional, cast, TypeVar, Union 

9 

10from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions 

11from azure.core.credentials_async import ( 

12 AsyncTokenCredential, 

13 AsyncSupportsTokenInfo, 

14 AsyncTokenProvider, 

15) 

16from azure.core.exceptions import HttpResponseError 

17from azure.core.pipeline import PipelineRequest, PipelineResponse 

18from azure.core.pipeline.policies import AsyncHTTPPolicy 

19from azure.core.pipeline.policies._authentication import ( 

20 _enforce_https, 

21 _should_refresh_token, 

22 MAX_REFRESH_JITTER_SECONDS, 

23) 

24from azure.core.pipeline.transport import ( 

25 AsyncHttpResponse as LegacyAsyncHttpResponse, 

26 HttpRequest as LegacyHttpRequest, 

27) 

28from azure.core.rest import AsyncHttpResponse, HttpRequest 

29from azure.core.utils._utils import get_running_async_lock 

30from ._utils import get_challenge_parameter 

31 

32from .._tools_async import await_result 

33 

# Type variables that let the policy be parameterized by either the newer ``azure.core.rest``
# request/response types or the legacy ``azure.core.pipeline.transport`` ones.
HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)
AsyncHTTPResponseType = TypeVar(
    "AsyncHTTPResponseType",
    AsyncHttpResponse,
    LegacyAsyncHttpResponse,
)

36 

37 

class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]):
    """Adds a bearer token Authorization header to requests.

    The token obtained from the credential is cached on the policy and reused until
    ``_should_refresh_token`` reports that a new one is needed. A random jitter of up to
    ``MAX_REFRESH_JITTER_SECONDS`` is re-rolled every time a token is requested.

    :param credential: The credential.
    :type credential: ~azure.core.credentials_async.AsyncTokenProvider
    :param str scopes: Lets you specify the type of access needed.
    :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested
        tokens. Defaults to False.
    """

    def __init__(self, credential: AsyncTokenProvider, *scopes: str, **kwargs: Any) -> None:
        super().__init__()
        self._credential = credential
        self._scopes = scopes
        # Created lazily by the ``_lock`` property rather than here.
        self._lock_instance = None
        self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None
        self._enable_cae: bool = kwargs.get("enable_cae", False)
        # Jitter (seconds) passed to _should_refresh_token; re-rolled in _request_token.
        self._refresh_jitter = 0

    @property
    def _lock(self):
        """Return the async lock that serializes token acquisition, creating it on first access.

        The lock comes from ``get_running_async_lock`` and is created lazily instead of in ``__init__``.
        Presumably this binds it to the async context that is running when the policy is first used
        (NOTE(review): confirm against ``get_running_async_lock``).

        :return: The lock instance shared by all token requests made through this policy.
        """
        if self._lock_instance is None:
            self._lock_instance = get_running_async_lock()
        return self._lock_instance

    async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Adds a bearer token Authorization header to request and sends request to next policy.

        :param request: The pipeline request object to be modified.
        :type request: ~azure.core.pipeline.PipelineRequest
        :raises ~azure.core.exceptions.ServiceRequestError: If the request fails.
        """
        _enforce_https(request)

        if self._token is None or self._need_new_token():
            async with self._lock:
                # double check because another coroutine may have acquired a token while we waited to acquire the lock
                if self._token is None or self._need_new_token():
                    await self._request_token(*self._scopes)
        bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token
        request.http_request.headers["Authorization"] = "Bearer " + bearer_token

    async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
        """Acquire a token from the credential and authorize the request with it.

        Keyword arguments are passed to the credential's get_token method. The token will be cached and used to
        authorize future requests.

        :param ~azure.core.pipeline.PipelineRequest request: the request
        :param str scopes: required scopes of authentication
        """
        async with self._lock:
            await self._request_token(*scopes, **kwargs)
        bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token
        request.http_request.headers["Authorization"] = "Bearer " + bearer_token

    async def send(
        self, request: PipelineRequest[HTTPRequestType]
    ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]:
        """Authorize request with a bearer token and send it to the next policy

        On a 401 response the cached token is discarded. If the response also carries a
        ``WWW-Authenticate`` header, ``on_challenge`` is consulted. When it returns True, the
        request is sent once more.

        :param request: The pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: The pipeline response object
        :rtype: ~azure.core.pipeline.PipelineResponse
        :raises Exception: any exception raised by ``on_challenge`` is re-raised with an
            ``HttpResponseError`` wrapping the 401 response as its cause.
        """
        await await_result(self.on_request, request)
        response = await self._send_to_next(request)

        if response.http_response.status_code == 401:
            self._token = None  # any cached token is invalid
            if "WWW-Authenticate" in response.http_response.headers:
                try:
                    request_authorized = await self.on_challenge(request, response)
                except Exception as ex:
                    # If the response is streamed, read it so the error message is immediately available to the user.
                    # Otherwise, a generic error message will be given and the user will have to read the response
                    # body to see the actual error.
                    if response.context.options.get("stream"):
                        try:
                            await response.http_response.read()  # type: ignore
                        except Exception:  # pylint:disable=broad-except
                            pass  # best effort only; the original exception below is what matters

                    # Raise the exception from the token request with the original 401 response
                    raise ex from HttpResponseError(response=response.http_response)

                if request_authorized:
                    response = await self._send_to_next(request)

        return response

    async def _send_to_next(
        self, request: PipelineRequest[HTTPRequestType]
    ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]:
        """Send ``request`` to the next policy, running the ``on_exception``/``on_response`` hooks.

        :param request: The pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: The response from the next policy.
        :rtype: ~azure.core.pipeline.PipelineResponse
        """
        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]
        try:
            response = await self.next.send(request)
        except Exception:
            await await_result(self.on_exception, request)
            raise
        await await_result(self.on_response, request, response)
        return response

    async def on_challenge(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType],
    ) -> bool:
        """Authorize request according to an authentication challenge

        This method is called when the resource provider responds 401 with a WWW-Authenticate header.
        Only ``Bearer`` challenges with ``error="insufficient_claims"`` and a decodable ``claims``
        parameter are handled. In that case a new token is requested with those claims. Any other
        challenge, including one whose ``claims`` value is not valid base64url/UTF-8, returns False.

        :param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
        :param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
        :returns: a bool indicating whether the policy should send the request
        :rtype: bool
        """
        headers = response.http_response.headers
        error = get_challenge_parameter(headers, "Bearer", "error")
        if error != "insufficient_claims":
            return False

        encoded_claims = get_challenge_parameter(headers, "Bearer", "claims")
        if not encoded_claims:
            return False

        try:
            # The claims value may omit base64 padding; restore it before decoding.
            padding_needed = -len(encoded_claims) % 4
            claims = base64.urlsafe_b64decode(encoded_claims + "=" * padding_needed).decode("utf-8")
        except ValueError:
            # binascii.Error and UnicodeDecodeError both subclass ValueError. Treat a malformed
            # challenge like one without claims so the caller receives the 401 response instead
            # of an opaque decoding error. Credential errors below still propagate.
            return False

        if not claims:
            return False
        await self.authorize_request(request, *self._scopes, claims=claims)
        return True

    def on_response(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType],
    ) -> Optional[Awaitable[None]]:
        """Executed after the request comes back from the next policy.

        The base implementation does nothing; subclasses may override it (sync or async).

        :param request: Request to be modified after returning from the policy.
        :type request: ~azure.core.pipeline.PipelineRequest
        :param response: Pipeline response object
        :type response: ~azure.core.pipeline.PipelineResponse
        :return: None, or an awaitable when overridden as a coroutine.
        """

    def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Executed when an exception is raised while executing the next policy.

        This method is executed inside the exception handler. The base implementation does nothing.

        :param request: The Pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        # pylint: disable=unused-argument
        return

    def _need_new_token(self) -> bool:
        """Return whether the cached token should be refreshed, honoring the current jitter.

        :return: The result of ``_should_refresh_token`` for the cached token.
        :rtype: bool
        """
        return _should_refresh_token(self._token, self._refresh_jitter)

    async def _get_token(self, *scopes: str, **kwargs: Any) -> Union["AccessToken", "AccessTokenInfo"]:
        """Fetch a token from the credential without caching it.

        If the credential exposes ``get_token_info``, keyword arguments whose names are keys of
        ``TokenRequestOptions`` are passed as ``options``. Any other keyword arguments are not
        forwarded on that path. Otherwise all keyword arguments are forwarded to ``get_token``.

        :param str scopes: The type of access needed.
        :return: The token returned by the credential.
        :rtype: ~azure.core.credentials.AccessToken or ~azure.core.credentials.AccessTokenInfo
        """
        if self._enable_cae:
            # setdefault: an explicit enable_cae passed by the caller wins over the policy setting.
            kwargs.setdefault("enable_cae", self._enable_cae)

        if hasattr(self._credential, "get_token_info"):
            options: TokenRequestOptions = {}
            # Loop through all the keyword arguments and check if they are part of the TokenRequestOptions.
            for key in list(kwargs.keys()):
                if key in TokenRequestOptions.__annotations__:  # pylint: disable=no-member
                    options[key] = kwargs.pop(key)  # type: ignore[literal-required]

            return await await_result(
                cast(AsyncSupportsTokenInfo, self._credential).get_token_info,
                *scopes,
                options=options,
            )
        return await await_result(
            cast(AsyncTokenCredential, self._credential).get_token,
            *scopes,
            **kwargs,
        )

    async def _request_token(self, *scopes: str, **kwargs: Any) -> None:
        """Request a new token from the credential.

        This will call the credential's appropriate method to get a token and store it in the policy.
        It also re-rolls the refresh jitter used by ``_need_new_token``.

        :param str scopes: The type of access needed.
        """
        self._token = await self._get_token(*scopes, **kwargs)
        # Jitter is not security-sensitive, so ``random`` (not ``secrets``) is sufficient here.
        self._refresh_jitter = random.randint(0, MAX_REFRESH_JITTER_SECONDS)