Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/azure/core/pipeline/policies/_authentication.py: 26%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

140 statements  

1# ------------------------------------------------------------------------- 

2# Copyright (c) Microsoft Corporation. All rights reserved. 

3# Licensed under the MIT License. See LICENSE.txt in the project root for 

4# license information. 

5# ------------------------------------------------------------------------- 

6import time 

7import base64 

8from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any, Union, cast 

9 

10from azure.core.credentials import ( 

11 TokenCredential, 

12 SupportsTokenInfo, 

13 TokenRequestOptions, 

14 TokenProvider, 

15) 

16from azure.core.exceptions import HttpResponseError 

17from azure.core.pipeline import PipelineRequest, PipelineResponse 

18from azure.core.pipeline.transport import ( 

19 HttpResponse as LegacyHttpResponse, 

20 HttpRequest as LegacyHttpRequest, 

21) 

22from azure.core.rest import HttpResponse, HttpRequest 

23from . import HTTPPolicy, SansIOHTTPPolicy 

24from ...exceptions import ServiceRequestError 

25from ._utils import get_challenge_parameter 

26 

27if TYPE_CHECKING: 

28 

29 from azure.core.credentials import ( 

30 AccessToken, 

31 AccessTokenInfo, 

32 AzureKeyCredential, 

33 AzureSasCredential, 

34 ) 

35 

# Type variables the policies below are generic over: each is constrained to either the
# azure.core.rest type or the legacy azure.core.pipeline.transport type.
HTTPResponseType = TypeVar("HTTPResponseType", HttpResponse, LegacyHttpResponse)
HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)

38 

39 

# pylint:disable=too-few-public-methods
class _BearerTokenCredentialPolicyBase:
    """Shared token acquisition and caching logic for bearer token credential policies.

    :param credential: The credential used to obtain access tokens.
    :type credential: ~azure.core.credentials.TokenProvider
    :param str scopes: Lets you specify the type of access needed.
    :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested
        tokens. Defaults to False.
    """

    def __init__(self, credential: TokenProvider, *scopes: str, **kwargs: Any) -> None:
        super().__init__()
        self._credential = credential
        self._scopes = scopes
        # Cached token; None means no token is currently held.
        self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None
        self._enable_cae: bool = kwargs.get("enable_cae", False)

    @staticmethod
    def _enforce_https(request: PipelineRequest[HTTPRequestType]) -> None:
        """Reject non-HTTPS requests unless HTTPS enforcement was explicitly turned off.

        The per-call ``enforce_https`` option is moved from the request options into the request
        context, so it persists across retries but is not passed on to a transport implementation.

        :param request: The pipeline request about to be authorized.
        :type request: ~azure.core.pipeline.PipelineRequest
        :raises ~azure.core.exceptions.ServiceRequestError: If HTTPS is enforced and the URL is not https.
        """
        requested = request.context.options.pop("enforce_https", None)

        # Only an explicit opt-out needs to be remembered; enforcement (True) is the default.
        if requested is False:
            request.context["enforce_https"] = False

        if not request.context.get("enforce_https", True):
            return
        if request.http_request.url.lower().startswith("https"):
            return
        raise ServiceRequestError(
            "Bearer token authentication is not permitted for non-TLS protected (non-https) URLs."
        )

    @staticmethod
    def _update_headers(headers: MutableMapping[str, str], token: str) -> None:
        """Set the Authorization header to carry the given bearer token.

        :param MutableMapping[str, str] headers: The HTTP request headers to modify in place.
        :param str token: The OAuth token.
        """
        headers["Authorization"] = f"Bearer {token}"

    @property
    def _need_new_token(self) -> bool:
        """Whether the cached token is missing, past its refresh time, or within 300 seconds of expiry.

        :return: True when a new token should be requested.
        :rtype: bool
        """
        cached = self._token
        if not cached:
            return True
        current_time = time.time()
        # Only AccessTokenInfo carries refresh_on; AccessToken yields None here.
        refresh_at = getattr(cached, "refresh_on", None)
        if refresh_at and refresh_at <= current_time:
            return True
        return cached.expires_on - current_time < 300

    def _get_token(self, *scopes: str, **kwargs: Any) -> Union["AccessToken", "AccessTokenInfo"]:
        """Obtain a token from the credential without caching it.

        If the credential exposes ``get_token_info``, only keyword arguments that are declared keys of
        ``TokenRequestOptions`` are forwarded (as ``options``); any other keyword arguments are not passed.
        Otherwise all keyword arguments go to ``get_token``.

        :param str scopes: The type of access needed.
        :return: The token returned by the credential.
        :rtype: ~azure.core.credentials.AccessToken or ~azure.core.credentials.AccessTokenInfo
        """
        if self._enable_cae:
            kwargs.setdefault("enable_cae", self._enable_cae)

        if not hasattr(self._credential, "get_token_info"):
            return cast(TokenCredential, self._credential).get_token(*scopes, **kwargs)

        allowed_keys = TokenRequestOptions.__annotations__  # pylint: disable=no-member
        options = cast(
            TokenRequestOptions,
            {key: value for key, value in kwargs.items() if key in allowed_keys},
        )
        return cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options)

    def _request_token(self, *scopes: str, **kwargs: Any) -> None:
        """Request a new token from the credential and cache it on the policy.

        :param str scopes: The type of access needed.
        """
        fresh_token = self._get_token(*scopes, **kwargs)
        self._token = fresh_token

111 

112 

class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """Adds a bearer token Authorization header to requests.

    :param credential: The credential.
    :type credential: ~azure.core.TokenCredential
    :param str scopes: Lets you specify the type of access needed.
    :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested
        tokens. Defaults to False.
    :raises ~azure.core.exceptions.ServiceRequestError: If the request fails.
    """

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Authorize the request with the cached bearer token, refreshing it first if needed.

        :param ~azure.core.pipeline.PipelineRequest request: the request
        """
        self._enforce_https(request)

        stale = self._token is None or self._need_new_token
        if stale:
            self._request_token(*self._scopes)
        current = cast(Union["AccessToken", "AccessTokenInfo"], self._token)
        self._update_headers(request.http_request.headers, current.token)

    def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
        """Acquire a token for the given scopes, cache it, and authorize the request with it.

        Keyword arguments are passed through to the credential. The acquired token replaces the cached
        token and is used to authorize future requests.

        :param ~azure.core.pipeline.PipelineRequest request: the request
        :param str scopes: required scopes of authentication
        """
        self._request_token(*scopes, **kwargs)
        acquired = cast(Union["AccessToken", "AccessTokenInfo"], self._token)
        self._update_headers(request.http_request.headers, acquired.token)

    def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]:
        """Authorize the request, send it, and answer a single authentication challenge if one is returned.

        :param request: The pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: The pipeline response object
        :rtype: ~azure.core.pipeline.PipelineResponse
        """
        self.on_request(request)
        try:
            outcome = self.next.send(request)
        except Exception:
            self.on_exception(request)
            raise
        self.on_response(request, outcome)

        if outcome.http_response.status_code != 401:
            return outcome

        # A 401 means whatever token is cached can no longer be trusted.
        self._token = None
        if "WWW-Authenticate" not in outcome.http_response.headers:
            return outcome

        try:
            authorized = self.on_challenge(request, outcome)
        except Exception as challenge_error:
            # Read a streamed body so the error details are available to the caller right away,
            # rather than a generic message that forces them to read the body themselves.
            if outcome.context.options.get("stream"):
                try:
                    outcome.http_response.read()  # type: ignore
                except Exception:  # pylint:disable=broad-except
                    pass
            # Surface the token-request failure, chained to the original 401 response.
            raise challenge_error from HttpResponseError(response=outcome.http_response)

        if not authorized:
            return outcome

        # The challenge produced a token matching the new target, so keep the Authorization header
        # on the retried request by clearing the 'insecure_domain_change' tag.
        request.context.options.pop("insecure_domain_change", False)
        try:
            retried = self.next.send(request)
            self.on_response(request, retried)
        except Exception:
            self.on_exception(request)
            raise
        return retried

    def on_challenge(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, HTTPResponseType],
    ) -> bool:
        """Authorize request according to an authentication challenge.

        Called when the resource provider responds 401 with a WWW-Authenticate header. Only an
        ``insufficient_claims`` Bearer challenge with a non-empty base64url ``claims`` parameter is handled:
        the claims are decoded and a new token is requested with them.

        :param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
        :param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
        :returns: a bool indicating whether the policy should send the request
        :rtype: bool
        """
        challenge_headers = response.http_response.headers
        if get_challenge_parameter(challenge_headers, "Bearer", "error") != "insufficient_claims":
            return False

        encoded = get_challenge_parameter(challenge_headers, "Bearer", "claims")
        if not encoded:
            return False

        # Restore any base64 padding the server stripped before decoding.
        padded = encoded + "=" * (-len(encoded) % 4)
        claims = base64.urlsafe_b64decode(padded).decode("utf-8")
        if not claims:
            return False

        self.authorize_request(request, *self._scopes, claims=claims)
        return True

    def on_response(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, HTTPResponseType],
    ) -> None:
        """Hook executed after the request comes back from the next policy; does nothing by default.

        :param request: Request to be modified after returning from the policy.
        :type request: ~azure.core.pipeline.PipelineRequest
        :param response: Pipeline response object
        :type response: ~azure.core.pipeline.PipelineResponse
        """

    def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Hook executed inside the exception handler when the next policy raises; does nothing by default.

        :param request: The Pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        # pylint: disable=unused-argument
        return None

248 

249 

class AzureKeyCredentialPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """Adds a key header for the provided credential.

    :param credential: The credential used to authenticate requests.
    :type credential: ~azure.core.credentials.AzureKeyCredential
    :param str name: The name of the key header used for the credential.
    :keyword str prefix: The name of the prefix for the header value if any.
    :raises ValueError: if name is None or empty.
    :raises TypeError: if name is not a string or if credential is not an instance of AzureKeyCredential.
    """

    def __init__(  # pylint: disable=unused-argument
        self,
        credential: "AzureKeyCredential",
        name: str,
        *,
        prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        # Duck-typed check: anything exposing a ``key`` attribute is accepted as the credential.
        if not hasattr(credential, "key"):
            raise TypeError("String is not a supported credential input type. Use an instance of AzureKeyCredential.")
        if not name:
            raise ValueError("name can not be None or empty")
        if not isinstance(name, str):
            raise TypeError("name must be a string.")

        self._credential = credential
        self._name = name
        # Store the prefix with its trailing separator so on_request can concatenate directly.
        if prefix:
            self._prefix = prefix + " "
        else:
            self._prefix = ""

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Write the credential's current key (with optional prefix) into the configured header.

        :param request: The request to be modified before sending.
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        header_value = "{}{}".format(self._prefix, self._credential.key)
        request.http_request.headers[self._name] = header_value

287 

288 

class AzureSasCredentialPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """Adds a shared access signature to query for the provided credential.

    :param credential: The credential used to authenticate requests.
    :type credential: ~azure.core.credentials.AzureSasCredential
    :raises ValueError: if credential is None.
    """

    def __init__(
        self,  # pylint: disable=unused-argument
        credential: "AzureSasCredential",
        **kwargs: Any,
    ) -> None:
        super().__init__()
        if not credential:
            raise ValueError("credential can not be None")
        self._credential = credential

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Append the credential's signature to the request URL's query string.

        A leading ``?`` on the signature is dropped. When the URL already has query parameters, the
        signature is appended with ``&`` only if its text does not already appear in the URL (so a
        re-run of this policy on the same request does not duplicate it).

        :param request: The request to be modified before sending.
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        target = request.http_request.url
        has_query = request.http_request.query
        sas = self._credential.signature
        if sas.startswith("?"):
            sas = sas[1:]

        if not has_query:
            separator = "" if target.endswith("?") else "?"
            target = target + separator + sas
        elif sas not in target:
            target = target + "&" + sas

        request.http_request.url = target