Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/azure/core/pipeline/policies/_authentication_async.py: 32%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

79 statements  

1# ------------------------------------------------------------------------- 

2# Copyright (c) Microsoft Corporation. All rights reserved. 

3# Licensed under the MIT License. See LICENSE.txt in the project root for 

4# license information. 

5# ------------------------------------------------------------------------- 

6import time 

7from typing import Any, Awaitable, Optional, cast, TypeVar, Union 

8 

9from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions 

10from azure.core.credentials_async import AsyncTokenCredential, AsyncSupportsTokenInfo, AsyncTokenProvider 

11from azure.core.pipeline import PipelineRequest, PipelineResponse 

12from azure.core.pipeline.policies import AsyncHTTPPolicy 

13from azure.core.pipeline.policies._authentication import ( 

14 _BearerTokenCredentialPolicyBase, 

15) 

16from azure.core.pipeline.transport import AsyncHttpResponse as LegacyAsyncHttpResponse, HttpRequest as LegacyHttpRequest 

17from azure.core.rest import AsyncHttpResponse, HttpRequest 

18from azure.core.utils._utils import get_running_async_lock 

19 

20from .._tools_async import await_result 

21 

# Constrained type variables: the policy is generic over both the current
# (azure.core.rest) and the legacy (azure.core.pipeline.transport) request and
# response classes.
AsyncHTTPResponseType = TypeVar(
    "AsyncHTTPResponseType",
    AsyncHttpResponse,
    LegacyAsyncHttpResponse,
)
HTTPRequestType = TypeVar(
    "HTTPRequestType",
    HttpRequest,
    LegacyHttpRequest,
)

24 

25 

class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]):
    """Pipeline policy that authorizes requests with a bearer token.

    A token is obtained from the credential on demand, cached on the policy and
    written to the ``Authorization`` header of each outgoing request. The cached
    token is replaced once :meth:`_need_new_token` reports it as stale, and it is
    discarded whenever the service answers with a 401.

    :param credential: The credential used to obtain tokens.
    :type credential: ~azure.core.credentials_async.AsyncTokenProvider
    :param str scopes: The scopes requested when acquiring a token.
    :keyword bool enable_cae: Whether Continuous Access Evaluation (CAE) is enabled for every token
        requested by this policy. Defaults to False.
    """

    def __init__(self, credential: AsyncTokenProvider, *scopes: str, **kwargs: Any) -> None:
        super().__init__()
        self._enable_cae: bool = kwargs.get("enable_cae", False)
        self._scopes = scopes
        self._credential = credential
        # Nothing is cached until the first request needs a token.
        self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None
        # The lock is built on first access by the ``_lock`` property.
        self._lock_instance = None

    @property
    def _lock(self):
        """Return the lock guarding token acquisition, creating it on first access."""
        lock = self._lock_instance
        if lock is None:
            # NOTE(review): presumably deferred so the lock is created inside the
            # running async framework - confirm against get_running_async_lock.
            lock = self._lock_instance = get_running_async_lock()
        return lock

    async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Set the bearer token Authorization header, fetching a token first if needed.

        :param request: The pipeline request whose HTTP headers are updated.
        :type request: ~azure.core.pipeline.PipelineRequest
        :raises: :class:`~azure.core.exceptions.ServiceRequestError`
        """
        _BearerTokenCredentialPolicyBase._enforce_https(request)  # pylint:disable=protected-access

        if self._token is None or self._need_new_token():
            async with self._lock:
                # Re-evaluate once the lock is held: a coroutine that held it
                # before us may already have stored a usable token.
                if self._token is None or self._need_new_token():
                    await self._request_token(*self._scopes)

        current = cast(Union[AccessToken, AccessTokenInfo], self._token)
        request.http_request.headers["Authorization"] = "Bearer " + current.token

    async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
        """Fetch a fresh token for the given scopes and put it on the request.

        A new token is always requested, regardless of what is cached. Keyword
        arguments are forwarded to the token request, and the resulting token
        replaces the cached one so later requests reuse it.

        :param ~azure.core.pipeline.PipelineRequest request: the request to authorize
        :param str scopes: required scopes of authentication
        """
        async with self._lock:
            await self._request_token(*scopes, **kwargs)

        current = cast(Union[AccessToken, AccessTokenInfo], self._token)
        request.http_request.headers["Authorization"] = "Bearer " + current.token

    async def send(
        self, request: PipelineRequest[HTTPRequestType]
    ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]:
        """Authorize the request, pass it down the pipeline and handle a 401 challenge.

        When the response is a 401 the cached token is dropped. If that response
        also carries a ``WWW-Authenticate`` header, :meth:`on_challenge` is
        consulted and, when it returns a truthy value, the request is sent once more.

        :param request: The pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: The pipeline response object
        :rtype: ~azure.core.pipeline.PipelineResponse
        """

        async def _forward() -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]:
            # One round trip through the next policy, wrapped in the
            # on_exception / on_response hooks.
            try:
                result = await self.next.send(request)
            except Exception:  # pylint:disable=broad-except
                await await_result(self.on_exception, request)
                raise
            await await_result(self.on_response, request, result)
            return result

        await await_result(self.on_request, request)
        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType] = await _forward()

        http_response = response.http_response
        if http_response.status_code != 401:
            return response

        # A 401 means whatever token we cached is no longer usable.
        self._token = None
        if "WWW-Authenticate" not in http_response.headers:
            return response

        if await self.on_challenge(request, response):
            # The challenge handler authorized the request with a token matching
            # the new target, so the token must stay on the request: drop the
            # 'insecure_domain_change' marker before re-sending.
            request.context.options.pop("insecure_domain_change", False)
            response = await _forward()

        return response

    async def on_challenge(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType],
    ) -> bool:
        """Handle an authentication challenge; this base implementation declines.

        Called by :meth:`send` when a response has status 401 and a
        WWW-Authenticate header.

        :param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
        :param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
        :returns: a bool indicating whether the policy should send the request again; always False here
        :rtype: bool
        """
        # pylint:disable=unused-argument
        return False

    def on_response(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType],
    ) -> Optional[Awaitable[None]]:
        """Hook invoked after the next policy has returned a response; does nothing here.

        :param request: The pipeline request that produced the response.
        :type request: ~azure.core.pipeline.PipelineRequest
        :param response: Pipeline response object
        :type response: ~azure.core.pipeline.PipelineResponse
        """
        return None

    def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Hook invoked when the next policy raises; does nothing here.

        :meth:`send` calls this from inside its exception handler, before
        re-raising the exception.

        :param request: The Pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        # pylint: disable=unused-argument
        return None

    def _need_new_token(self) -> bool:
        """Tell whether the cached token is missing or due for replacement.

        :returns: True when there is no token, when the token has a ``refresh_on``
            time that has been reached, or when fewer than 300 seconds remain
            before ``expires_on``.
        :rtype: bool
        """
        now = time.time()
        token = self._token
        if not token:
            return True

        # Only some token types carry ``refresh_on``; treat it as optional.
        refresh_on = getattr(token, "refresh_on", None)
        if refresh_on and refresh_on <= now:
            return True

        return token.expires_on - now < 300

    async def _request_token(self, *scopes: str, **kwargs: Any) -> None:
        """Obtain a token from the credential and cache it on the policy.

        Credentials exposing ``get_token_info`` are called with an ``options``
        mapping built from the keyword arguments that are TokenRequestOptions
        keys; other credentials are called through ``get_token`` with all keyword
        arguments.

        :param str scopes: The type of access needed.
        """
        if self._enable_cae:
            kwargs.setdefault("enable_cae", self._enable_cae)

        credential = self._credential
        if not hasattr(credential, "get_token_info"):
            self._token = await await_result(cast(AsyncTokenCredential, credential).get_token, *scopes, **kwargs)
            return

        # Move every keyword argument that TokenRequestOptions declares out of
        # kwargs and into the options mapping; the rest is not forwarded.
        recognized = TokenRequestOptions.__annotations__  # pylint:disable=no-member
        options: TokenRequestOptions = {}
        for name in tuple(kwargs):
            if name not in recognized:
                continue
            options[name] = kwargs.pop(name)  # type: ignore[literal-required]

        self._token = await await_result(
            cast(AsyncSupportsTokenInfo, credential).get_token_info,
            *scopes,
            options=options,
        )