1# --------------------------------------------------------------------------
2#
3# Copyright (c) Microsoft Corporation. All rights reserved.
4#
5# The MIT License (MIT)
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the ""Software""), to
9# deal in the Software without restriction, including without limitation the
10# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
11# sell copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in
15# all copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
22# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
23# IN THE SOFTWARE.
24#
25# --------------------------------------------------------------------------
26import base64
27import time
28from typing import Optional, Union, MutableMapping, List, Any, Sequence, TypeVar, Generic
29
30from azure.core.credentials import AccessToken, TokenCredential
31from azure.core.credentials_async import AsyncTokenCredential
32from azure.core.pipeline.policies import BearerTokenCredentialPolicy, SansIOHTTPPolicy
33from azure.core.pipeline import PipelineRequest, PipelineResponse
34from azure.core.exceptions import ServiceRequestError
35from azure.core.pipeline.transport import (
36 HttpRequest as LegacyHttpRequest,
37 HttpResponse as LegacyHttpResponse,
38)
39from azure.core.rest import HttpRequest, HttpResponse
40
41
# Request/response types used to annotate the policies in this module: either the legacy
# azure.core.pipeline.transport classes or the azure.core.rest classes.
HTTPRequestType = Union[LegacyHttpRequest, HttpRequest]
HTTPResponseType = Union[LegacyHttpResponse, HttpResponse]
# Type parameter for the auxiliary-authentication base class; bound to either a sync
# (TokenCredential) or async (AsyncTokenCredential) credential type.
TokenCredentialType = TypeVar("TokenCredentialType", bound=Union[TokenCredential, AsyncTokenCredential])
45
46
class ARMChallengeAuthenticationPolicy(BearerTokenCredentialPolicy):
    """Bearer token authentication policy for Azure Resource Manager (ARM).

    Adds a bearer token Authorization header to requests and internally handles Continuous Access
    Evaluation (CAE) claims challenges. When a challenge can't be completed, the 401 (unauthorized)
    response from ARM is returned.

    :param ~azure.core.credentials.TokenCredential credential: credential for authorizing requests
    :param str scopes: required authentication scopes
    """

    def on_challenge(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, HTTPResponseType],
    ) -> bool:
        """Try to answer an ARM authentication challenge by re-authorizing the request.

        The ``WWW-Authenticate`` header of ``response`` is searched for a "claims" parameter; when one
        is found, the request is re-authorized for this policy's scopes with those claims.

        :param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
        :param ~azure.core.pipeline.PipelineResponse response: ARM's response
        :returns: True if the request was re-authorized and the policy should send it; False otherwise
        :rtype: bool
        """
        www_authenticate = response.http_response.headers.get("WWW-Authenticate")
        if not www_authenticate:
            # Nothing to answer: the 401 response is handed back to the caller.
            return False

        requested_claims = _parse_claims_challenge(www_authenticate)
        if not requested_claims:
            # Only claims challenges are handled by this policy.
            return False

        self.authorize_request(request, *self._scopes, claims=requested_claims)
        return True
78
79
# pylint:disable=too-few-public-methods
class _AuxiliaryAuthenticationPolicyBase(Generic[TokenCredentialType]):
    """Shared logic for policies that attach auxiliary authorization tokens to requests.

    Tokens from every auxiliary credential are sent together in the
    ``x-ms-authorization-auxiliary`` header.

    :param ~azure.core.credentials.TokenCredential auxiliary_credentials: auxiliary credential for authorizing requests
    :param str scopes: required authentication scopes
    """

    def __init__(  # pylint: disable=unused-argument
        self, auxiliary_credentials: Sequence[TokenCredentialType], *scopes: str, **kwargs: Any
    ) -> None:
        # Cached auxiliary tokens; None until tokens have been acquired.
        self._aux_tokens: Optional[List[AccessToken]] = None
        self._scopes = scopes
        self._auxiliary_credentials = auxiliary_credentials

    @staticmethod
    def _enforce_https(request: PipelineRequest[HTTPRequestType]) -> None:
        """Refuse to send tokens over a non-https URL unless the caller opted out.

        :param ~azure.core.pipeline.PipelineRequest request: the outgoing request
        :raises ~azure.core.exceptions.ServiceRequestError: if https is enforced and the URL is not https
        """
        # The "enforce_https" option is moved from options to context so that it persists
        # across retries without being forwarded to the transport implementation.
        explicit_setting = request.context.options.pop("enforce_https", None)

        # Only an explicit opt-out needs recording; True is already the default.
        if explicit_setting is False:
            request.context["enforce_https"] = explicit_setting

        if not request.context.get("enforce_https", True):
            return
        if request.http_request.url.lower().startswith("https"):
            return
        raise ServiceRequestError(
            "Bearer token authentication is not permitted for non-TLS protected (non-https) URLs."
        )

    def _update_headers(self, headers: MutableMapping[str, str]) -> None:
        """Write the cached auxiliary tokens into the x-ms-authorization-auxiliary header.

        The header is left untouched when no tokens are cached.

        :param dict headers: The HTTP Request headers
        """
        cached_tokens = self._aux_tokens
        if not cached_tokens:
            return
        bearer_values = ["Bearer {}".format(access_token.token) for access_token in cached_tokens]
        headers["x-ms-authorization-auxiliary"] = ", ".join(bearer_values)

    @property
    def _need_new_aux_tokens(self) -> bool:
        """Whether the auxiliary tokens must be (re)acquired.

        :return: True if no tokens are cached or any cached token expires in less than 300 seconds
        :rtype: bool
        """
        cached_tokens = self._aux_tokens
        if not cached_tokens:
            return True
        return any(access_token.expires_on - time.time() < 300 for access_token in cached_tokens)
129
130
class AuxiliaryAuthenticationPolicy(
    _AuxiliaryAuthenticationPolicyBase[TokenCredential],
    SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType],
):
    """Adds auxiliary authorization tokens from synchronous credentials to requests.

    :param ~azure.core.credentials.TokenCredential auxiliary_credentials: auxiliary credential for authorizing requests
    :param str scopes: required authentication scopes
    """

    def _get_auxiliary_tokens(self, *scopes: str, **kwargs: Any) -> Optional[List[AccessToken]]:
        """Request one token per auxiliary credential.

        :param str scopes: scopes to request the tokens for
        :return: one token per credential, or None when there are no auxiliary credentials
        :rtype: list[~azure.core.credentials.AccessToken] or None
        """
        credentials = self._auxiliary_credentials
        if not credentials:
            return None
        return [credential.get_token(*scopes, **kwargs) for credential in credentials]

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Called before the policy sends a request.

        Checks the https requirement, refreshes the cached auxiliary tokens when they are missing or
        about to expire, and authorizes the request with them.

        :param ~azure.core.pipeline.PipelineRequest request: the request
        """
        self._enforce_https(request)

        tokens_stale = self._need_new_aux_tokens
        if tokens_stale:
            self._aux_tokens = self._get_auxiliary_tokens(*self._scopes)

        self._update_headers(request.http_request.headers)
153
154
def _parse_claims_challenge(challenge: str) -> Optional[str]:
    """Parse the "claims" parameter from an authentication challenge.

    Example challenge with claims:
        Bearer authorization_uri="https://login.windows-ppe.net/", error="invalid_token",
        error_description="User session has been revoked",
        claims="eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwgInZhbHVlIjoiMTYwMzc0MjgwMCJ9fX0="

    The parameter name is matched exactly, ignoring case (auth-param names are case-insensitive
    per RFC 7235). Parameters whose names merely contain "claims" are therefore not mistaken for it.
    The value may be base64 or base64url, with or without padding.

    :param str challenge: The authentication challenge (value of a WWW-Authenticate header)
    :return: the challenge's decoded "claims" parameter, or None if the challenge has no such parameter,
        has more than one, or its value is not valid base64-encoded UTF-8
    :rtype: str or None
    """
    encoded_claims = None
    for parameter in challenge.split(","):
        name, separator, value = parameter.partition("=")
        # When "claims" is the first parameter, its name is preceded by the auth scheme
        # ("Bearer claims"), so only the last whitespace-separated token is the parameter name.
        name_tokens = name.split()
        if not separator or not name_tokens or name_tokens[-1].lower() != "claims":
            continue
        if encoded_claims:
            # multiple claims challenges, e.g. for cross-tenant auth, would require special handling
            return None
        encoded_claims = value.strip(" \"'")

    if not encoded_claims:
        return None

    # Restore padding that base64url encoders commonly omit; urlsafe_b64decode also accepts "+" and "/".
    padding_needed = -len(encoded_claims) % 4
    try:
        return base64.urlsafe_b64decode(encoded_claims + "=" * padding_needed).decode()
    except ValueError:
        # Three failures are caught here, all ValueError subclasses or ValueError itself:
        # binascii.Error (malformed base64), UnicodeDecodeError (non-UTF-8 payload),
        # and the ValueError raised for non-ASCII input. All mean "no usable claims".
        return None