1# -------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See LICENSE.txt in the project root for
4# license information.
5# -------------------------------------------------------------------------
6import time
7from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any, Union, cast
8from azure.core.credentials import TokenCredential, SupportsTokenInfo, TokenRequestOptions, TokenProvider
9from azure.core.pipeline import PipelineRequest, PipelineResponse
10from azure.core.pipeline.transport import HttpResponse as LegacyHttpResponse, HttpRequest as LegacyHttpRequest
11from azure.core.rest import HttpResponse, HttpRequest
12from . import HTTPPolicy, SansIOHTTPPolicy
13from ...exceptions import ServiceRequestError
14
15if TYPE_CHECKING:
16 # pylint:disable=unused-import
17 from azure.core.credentials import (
18 AccessToken,
19 AccessTokenInfo,
20 AzureKeyCredential,
21 AzureSasCredential,
22 )
23
# Generic parameters for the policies below. Each is constrained to exactly the two
# request/response flavours imported above: the ``azure.core.rest`` types and the
# legacy ``azure.core.pipeline.transport`` types.
HTTPResponseType = TypeVar(
    "HTTPResponseType",
    HttpResponse,
    LegacyHttpResponse,
)
HTTPRequestType = TypeVar(
    "HTTPRequestType",
    HttpRequest,
    LegacyHttpRequest,
)
26
27
# pylint:disable=too-few-public-methods
class _BearerTokenCredentialPolicyBase:
    """Token acquisition and caching shared by the bearer token credential policies.

    :param credential: The credential used to obtain access tokens.
    :type credential: ~azure.core.credentials.TokenProvider
    :param str scopes: Lets you specify the type of access needed.
    :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested
        tokens. Defaults to False.
    """

    def __init__(self, credential: TokenProvider, *scopes: str, **kwargs: Any) -> None:
        super().__init__()
        self._credential = credential
        self._scopes = scopes
        self._enable_cae: bool = kwargs.get("enable_cae", False)
        # Cached token; None until one has been requested (subclasses may also reset it to None).
        self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None

    @staticmethod
    def _enforce_https(request: PipelineRequest[HTTPRequestType]) -> None:
        """Refuse to attach a bearer token to a request whose URL is not https.

        The per-call ``enforce_https`` option is moved out of ``context.options`` and into the context itself,
        so it persists across retries but isn't passed to a transport implementation. Only an explicit False is
        recorded, because True is the default.

        :param request: The pipeline request about to be authorized.
        :type request: ~azure.core.pipeline.PipelineRequest
        :raises ~azure.core.exceptions.ServiceRequestError: if HTTPS is enforced and the URL is not https.
        """
        context = request.context
        if context.options.pop("enforce_https", None) is False:
            context["enforce_https"] = False

        if not context.get("enforce_https", True):
            return  # caller explicitly opted out of the check
        if request.http_request.url.lower().startswith("https"):
            return
        raise ServiceRequestError(
            "Bearer token authentication is not permitted for non-TLS protected (non-https) URLs."
        )

    @staticmethod
    def _update_headers(headers: MutableMapping[str, str], token: str) -> None:
        """Write the bearer token into the Authorization header.

        :param MutableMapping[str, str] headers: The HTTP request headers, modified in place.
        :param str token: The OAuth token.
        """
        headers["Authorization"] = f"Bearer {token}"

    @property
    def _need_new_token(self) -> bool:
        """Whether the cached token is missing, past its refresh hint, or within 300 seconds of expiry.

        :rtype: bool
        """
        if not self._token:
            return True
        now = time.time()
        # Read via getattr with a default: not every token type is assumed to carry ``refresh_on``.
        refresh_on = getattr(self._token, "refresh_on", None)
        if refresh_on and refresh_on <= now:
            return True
        return self._token.expires_on - now < 300

    def _request_token(self, *scopes: str, **kwargs: Any) -> None:
        """Obtain a new token from the credential and cache it on the policy.

        Credentials exposing ``get_token_info`` receive only those keyword arguments whose names are
        TokenRequestOptions keys, bundled as ``options``; other keyword arguments are not forwarded to them.
        Other credentials receive every keyword argument through ``get_token``.

        :param str scopes: The type of access needed.
        """
        if self._enable_cae:
            kwargs.setdefault("enable_cae", self._enable_cae)

        if not hasattr(self._credential, "get_token_info"):
            self._token = cast(TokenCredential, self._credential).get_token(*scopes, **kwargs)
            return

        options: TokenRequestOptions = {}
        option_names = TokenRequestOptions.__annotations__  # pylint:disable=no-member
        for option_name in [key for key in kwargs if key in option_names]:
            options[option_name] = kwargs.pop(option_name)  # type: ignore[literal-required]
        self._token = cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options)
98
class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """Adds a bearer token Authorization header to requests.

    :param credential: The credential used to obtain access tokens. Credentials implementing ``get_token_info``
        are used through that method; otherwise ``get_token`` is called.
    :type credential: ~azure.core.credentials.TokenProvider
    :param str scopes: Lets you specify the type of access needed.
    :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested
        tokens. Defaults to False.
    :raises: :class:`~azure.core.exceptions.ServiceRequestError`
    """

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Called before the policy sends a request.

        The base implementation enforces HTTPS (unless the request opted out with ``enforce_https=False``),
        requests a new token when none is cached or the cached one is due for refresh, and authorizes the
        request with the cached token.

        :param ~azure.core.pipeline.PipelineRequest request: the request
        :raises ~azure.core.exceptions.ServiceRequestError: if HTTPS is enforced and the request URL is not https.
        """
        self._enforce_https(request)

        if self._token is None or self._need_new_token:
            self._request_token(*self._scopes)
        bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token
        self._update_headers(request.http_request.headers, bearer_token)

    def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
        """Acquire a token from the credential and authorize the request with it.

        Keyword arguments are forwarded to the credential. For credentials implementing ``get_token_info``,
        only those whose names are :class:`~azure.core.credentials.TokenRequestOptions` keys are passed, as its
        ``options``. Otherwise all of them are passed to ``get_token``. The token will be cached and used to
        authorize future requests.

        Unlike :meth:`on_request`, this always requests a new token and does not check the URL scheme.

        :param ~azure.core.pipeline.PipelineRequest request: the request
        :param str scopes: required scopes of authentication
        """
        self._request_token(*scopes, **kwargs)
        bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token
        self._update_headers(request.http_request.headers, bearer_token)

    def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]:
        """Authorize request with a bearer token and send it to the next policy.

        On a 401 response the cached token is discarded. If that response also carries a ``WWW-Authenticate``
        header, :meth:`on_challenge` may re-authorize the request. When it returns True, the request is sent once
        more and that second response is returned.

        :param request: The pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: The pipeline response object
        :rtype: ~azure.core.pipeline.PipelineResponse
        """
        self.on_request(request)
        try:
            response = self.next.send(request)
        except Exception:  # pylint:disable=broad-except
            self.on_exception(request)
            raise

        self.on_response(request, response)
        if response.http_response.status_code == 401:
            self._token = None  # any cached token is invalid
            if "WWW-Authenticate" in response.http_response.headers:
                request_authorized = self.on_challenge(request, response)
                if request_authorized:
                    # if we receive a challenge response, we retrieve a new token
                    # which matches the new target. In this case, we don't want to remove
                    # token from the request so clear the 'insecure_domain_change' tag
                    request.context.options.pop("insecure_domain_change", False)
                    try:
                        response = self.next.send(request)
                        self.on_response(request, response)
                    except Exception:  # pylint:disable=broad-except
                        self.on_exception(request)
                        raise

        return response

    def on_challenge(
        self, request: PipelineRequest[HTTPRequestType], response: PipelineResponse[HTTPRequestType, HTTPResponseType]
    ) -> bool:
        """Authorize request according to an authentication challenge.

        This method is called when the resource provider responds 401 with a WWW-Authenticate header. The base
        implementation does not handle challenges and always returns False. Subclasses override it to
        re-authorize the request.

        :param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
        :param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
        :returns: a bool indicating whether the policy should send the request
        :rtype: bool
        """
        # pylint:disable=unused-argument
        return False

    def on_response(
        self, request: PipelineRequest[HTTPRequestType], response: PipelineResponse[HTTPRequestType, HTTPResponseType]
    ) -> None:
        """Executed after the request comes back from the next policy.

        The base implementation does nothing; it is a hook for subclasses.

        :param request: Request to be modified after returning from the policy.
        :type request: ~azure.core.pipeline.PipelineRequest
        :param response: Pipeline response object
        :type response: ~azure.core.pipeline.PipelineResponse
        """

    def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Executed when an exception is raised while executing the next policy.

        This method is executed inside the exception handler; :meth:`send` re-raises the exception afterwards.
        The base implementation does nothing.

        :param request: The Pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        # pylint: disable=unused-argument
        return
207
208
class AzureKeyCredentialPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """Sends the key of an AzureKeyCredential in a request header.

    :param credential: The credential used to authenticate requests.
    :type credential: ~azure.core.credentials.AzureKeyCredential
    :param str name: The name of the key header used for the credential.
    :keyword str prefix: The name of the prefix for the header value if any.
    :raises: ValueError or TypeError
    """

    def __init__(  # pylint: disable=unused-argument
        self,
        credential: "AzureKeyCredential",
        name: str,
        *,
        prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        # Duck-typed check: any object exposing ``key`` is accepted; a plain string is rejected.
        if not hasattr(credential, "key"):
            raise TypeError("String is not a supported credential input type. Use an instance of AzureKeyCredential.")
        # Emptiness is tested before type, so a falsy non-string name raises ValueError rather than TypeError.
        if not name:
            raise ValueError("name can not be None or empty")
        if not isinstance(name, str):
            raise TypeError("name must be a string.")
        self._name = name
        self._credential = credential
        # Keep a trailing separator so the prefix can be placed directly in front of the key.
        if prefix:
            self._prefix = prefix + " "
        else:
            self._prefix = ""

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Set the key header on the outgoing request.

        The key is read from the credential on every request, so any change to the credential's key is reflected
        in later requests.

        :param request: The pipeline request to authorize.
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        header_value = f"{self._prefix}{self._credential.key}"
        request.http_request.headers[self._name] = header_value
240
241
class AzureSasCredentialPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """Appends the shared access signature of an AzureSasCredential to the request query string.

    :param credential: The credential used to authenticate requests.
    :type credential: ~azure.core.credentials.AzureSasCredential
    :raises: ValueError or TypeError
    """

    def __init__(self, credential: "AzureSasCredential", **kwargs: Any) -> None:  # pylint: disable=unused-argument
        super().__init__()
        if not credential:
            raise ValueError("credential can not be None")
        self._credential = credential

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Add the signature to the request URL.

        A leading ``?`` on the signature is ignored. If the URL already has a query, the signature is joined with
        ``&``, and is skipped when it already appears in the URL (presumably so that re-running the policy, e.g.
        on retry, doesn't duplicate it). Otherwise it becomes the query, reusing a trailing ``?`` if present.

        :param request: The pipeline request to authorize.
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        http_request = request.http_request
        url = http_request.url
        query = http_request.query
        raw_signature = self._credential.signature
        signature = raw_signature[1:] if raw_signature.startswith("?") else raw_signature

        if query:
            if signature not in url:
                url = url + "&" + signature
        elif url.endswith("?"):
            url = url + signature
        else:
            url = url + "?" + signature
        http_request.url = url