Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/azure/core/pipeline/policies/_authentication_async.py: 28%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

104 statements  

1# ------------------------------------------------------------------------- 

2# Copyright (c) Microsoft Corporation. All rights reserved. 

3# Licensed under the MIT License. See LICENSE.txt in the project root for 

4# license information. 

5# ------------------------------------------------------------------------- 

6import time 

7import base64 

8from typing import Any, Awaitable, Optional, cast, TypeVar, Union 

9 

10from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions 

11from azure.core.credentials_async import ( 

12 AsyncTokenCredential, 

13 AsyncSupportsTokenInfo, 

14 AsyncTokenProvider, 

15) 

16from azure.core.exceptions import HttpResponseError 

17from azure.core.pipeline import PipelineRequest, PipelineResponse 

18from azure.core.pipeline.policies import AsyncHTTPPolicy 

19from azure.core.pipeline.policies._authentication import ( 

20 _BearerTokenCredentialPolicyBase, 

21) 

22from azure.core.pipeline.transport import ( 

23 AsyncHttpResponse as LegacyAsyncHttpResponse, 

24 HttpRequest as LegacyHttpRequest, 

25) 

26from azure.core.rest import AsyncHttpResponse, HttpRequest 

27from azure.core.utils._utils import get_running_async_lock 

28from ._utils import get_challenge_parameter 

29 

30from .._tools_async import await_result 

31 

# Request types the policy is generic over: the azure.core.rest request or the legacy transport request.
HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)

# Async response types the policy is generic over: the azure.core.rest response or the legacy transport response.
AsyncHTTPResponseType = TypeVar(
    "AsyncHTTPResponseType",
    AsyncHttpResponse,
    LegacyAsyncHttpResponse,
)

34 

35 

class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]):
    """Pipeline policy that puts a ``Bearer`` token in the Authorization header of each request.

    The token is obtained from the supplied credential, cached on the policy, and requested again
    when :meth:`_need_new_token` reports that the cached one should no longer be used.

    :param credential: The credential.
    :type credential: ~azure.core.credentials_async.AsyncTokenProvider
    :param str scopes: Lets you specify the type of access needed.
    :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested
        tokens. Defaults to False.
    """

    def __init__(self, credential: AsyncTokenProvider, *scopes: str, **kwargs: Any) -> None:
        super().__init__()
        self._enable_cae: bool = kwargs.get("enable_cae", False)
        self._scopes = scopes
        self._credential = credential
        # Nothing is cached until the first request needs a token.
        self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None
        # The lock is created on first use by the ``_lock`` property, not here.
        self._lock_instance = None

    @property
    def _lock(self):
        """Return the policy's lock, creating it with ``get_running_async_lock()`` on first access.

        :return: The lock guarding token acquisition.
        """
        lock = self._lock_instance
        if lock is None:
            lock = self._lock_instance = get_running_async_lock()
        return lock

    async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Set the Authorization header on the request, fetching a token first when required.

        :param request: The pipeline request object to be modified.
        :type request: ~azure.core.pipeline.PipelineRequest
        :raises ~azure.core.exceptions.ServiceRequestError: If the request fails.
        """
        _BearerTokenCredentialPolicyBase._enforce_https(request)  # pylint:disable=protected-access

        def token_is_stale() -> bool:
            return self._token is None or self._need_new_token()

        if token_is_stale():
            async with self._lock:
                # Check again under the lock: another coroutine may have stored a token while we were waiting.
                if token_is_stale():
                    await self._request_token(*self._scopes)

        cached = cast(Union[AccessToken, AccessTokenInfo], self._token)
        request.http_request.headers["Authorization"] = "Bearer " + cached.token

    async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
        """Request a token for the given scopes and write it into the request's Authorization header.

        The keyword arguments are forwarded to the token request. The token obtained replaces the cached
        one, so later requests are authorized with it as well.

        :param ~azure.core.pipeline.PipelineRequest request: the request
        :param str scopes: required scopes of authentication
        """
        async with self._lock:
            await self._request_token(*scopes, **kwargs)

        cached = cast(Union[AccessToken, AccessTokenInfo], self._token)
        request.http_request.headers["Authorization"] = "Bearer " + cached.token

    async def send(
        self, request: PipelineRequest[HTTPRequestType]
    ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]:
        """Authorize the request, pass it to the next policy, and answer a 401 challenge once if possible.

        :param request: The pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: The pipeline response object
        :rtype: ~azure.core.pipeline.PipelineResponse
        """
        await await_result(self.on_request, request)
        result: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]
        try:
            result = await self.next.send(request)
        except Exception:
            await await_result(self.on_exception, request)
            raise
        await await_result(self.on_response, request, result)

        if result.http_response.status_code != 401:
            return result

        # A 401 means whatever token is cached can no longer be trusted.
        self._token = None
        if "WWW-Authenticate" not in result.http_response.headers:
            return result

        try:
            should_retry = await self.on_challenge(request, result)
        except Exception as challenge_error:
            # For a streamed response, read the body now so the service's error text is available
            # to the caller right away instead of only a generic message.
            if result.context.options.get("stream"):
                try:
                    await result.http_response.read()  # type: ignore
                except Exception:  # pylint:disable=broad-except
                    pass

            # Surface the challenge failure, chained to the original 401 response.
            raise challenge_error from HttpResponseError(response=result.http_response)

        if not should_retry:
            return result

        # The challenge produced a token for the new target, so the token must stay on the
        # request: drop the 'insecure_domain_change' tag before sending again.
        request.context.options.pop("insecure_domain_change", False)
        try:
            result = await self.next.send(request)
        except Exception:
            await await_result(self.on_exception, request)
            raise
        await await_result(self.on_response, request, result)
        return result

    async def on_challenge(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType],
    ) -> bool:
        """Try to satisfy an authentication challenge by re-authorizing the request.

        ``send`` calls this when a response has status 401 and carries a WWW-Authenticate header. Only a
        Bearer challenge whose error is ``insufficient_claims`` and which supplies non-empty claims is
        handled; the claims are base64url-decoded and passed to :meth:`authorize_request`.

        :param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
        :param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
        :returns: a bool indicating whether the policy should send the request
        :rtype: bool
        """
        challenge_headers = response.http_response.headers
        if get_challenge_parameter(challenge_headers, "Bearer", "error") != "insufficient_claims":
            return False

        encoded = get_challenge_parameter(challenge_headers, "Bearer", "claims")
        if not encoded:
            return False

        # Restore any '=' padding missing from the value before decoding.
        padded = encoded + "=" * (-len(encoded) % 4)
        claims = base64.urlsafe_b64decode(padded).decode("utf-8")
        if not claims:
            return False

        await self.authorize_request(request, *self._scopes, claims=claims)
        return True

    def on_response(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType],
    ) -> Optional[Awaitable[None]]:
        """Hook invoked by ``send`` after the next policy returns a response; does nothing here.

        :param request: Request to be modified after returning from the policy.
        :type request: ~azure.core.pipeline.PipelineRequest
        :param response: Pipeline response object
        :type response: ~azure.core.pipeline.PipelineResponse
        """

    def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Hook invoked by ``send`` when the next policy raises; does nothing here.

        It runs inside the exception handler, before the exception is re-raised.

        :param request: The Pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        # pylint: disable=unused-argument
        return

    def _need_new_token(self) -> bool:
        """Tell whether the cached token should be replaced.

        :return: True when no token is cached, when the token has a ``refresh_on`` time that has been
            reached, or when fewer than 300 seconds remain before ``expires_on``.
        :rtype: bool
        """
        now = time.time()
        refresh_on = getattr(self._token, "refresh_on", None)
        if not self._token:
            return True
        if refresh_on and refresh_on <= now:
            return True
        return self._token.expires_on - now < 300

    async def _get_token(self, *scopes: str, **kwargs: Any) -> Union["AccessToken", "AccessTokenInfo"]:
        """Fetch a token from the credential through whichever method it offers.

        A credential with a ``get_token_info`` attribute is called with an ``options`` mapping holding only
        the keyword arguments named in ``TokenRequestOptions``; other keyword arguments are not passed on.
        Any other credential has ``get_token`` called with all keyword arguments.

        :param str scopes: The type of access needed.
        :return: The token returned by the credential.
        :rtype: ~azure.core.credentials.AccessToken or ~azure.core.credentials.AccessTokenInfo
        """
        if self._enable_cae:
            # An explicit enable_cae supplied by the caller takes precedence over the policy setting.
            kwargs.setdefault("enable_cae", self._enable_cae)

        if not hasattr(self._credential, "get_token_info"):
            return await await_result(
                cast(AsyncTokenCredential, self._credential).get_token,
                *scopes,
                **kwargs,
            )

        recognized = TokenRequestOptions.__annotations__  # pylint: disable=no-member
        options: TokenRequestOptions = {  # type: ignore[assignment, misc]
            name: value for name, value in kwargs.items() if name in recognized
        }
        return await await_result(
            cast(AsyncSupportsTokenInfo, self._credential).get_token_info,
            *scopes,
            options=options,
        )

    async def _request_token(self, *scopes: str, **kwargs: Any) -> None:
        """Obtain a token via :meth:`_get_token` and cache it on the policy.

        :param str scopes: The type of access needed.
        """
        self._token = await self._get_token(*scopes, **kwargs)