1# --------------------------------------------------------------------------
2#
3# Copyright (c) Microsoft Corporation. All rights reserved.
4#
5# The MIT License (MIT)
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to
9# deal in the Software without restriction, including without limitation the
10# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
11# sell copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in
15# all copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
22# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
23# IN THE SOFTWARE.
24#
25# --------------------------------------------------------------------------
26import time
27from typing import Optional, Union, MutableMapping, List, Any, Sequence, TypeVar, Generic
28
29from azure.core.credentials import AccessToken, TokenCredential
30from azure.core.credentials_async import AsyncTokenCredential
31from azure.core.pipeline.policies import BearerTokenCredentialPolicy, SansIOHTTPPolicy
32from azure.core.pipeline import PipelineRequest
33from azure.core.exceptions import ServiceRequestError
34from azure.core.pipeline.transport import (
35 HttpRequest as LegacyHttpRequest,
36 HttpResponse as LegacyHttpResponse,
37)
38from azure.core.rest import HttpRequest, HttpResponse
39
40
# Request/response aliases covering both azure-core transport generations: the
# legacy ``azure.core.pipeline.transport`` classes and the ``azure.core.rest`` ones.
HTTPRequestType = Union[
    LegacyHttpRequest,
    HttpRequest,
]
HTTPResponseType = Union[
    LegacyHttpResponse,
    HttpResponse,
]

# Credential type parameter for the generic auxiliary-policy base class; the
# bound admits either a synchronous or an asynchronous token credential.
TokenCredentialType = TypeVar(
    "TokenCredentialType",
    bound=Union[TokenCredential, AsyncTokenCredential],
)
44
45
class ARMChallengeAuthenticationPolicy(BearerTokenCredentialPolicy):
    """Bearer-token authentication policy for Azure Resource Manager (ARM) requests.

    Requests passing through the policy get a bearer token ``Authorization``
    header. Continuous Access Evaluation (CAE) challenges are handled
    internally; when a challenge cannot be completed, the policy returns the
    401 (unauthorized) response from ARM.

    The class body defines no members of its own; its behavior is inherited from
    :class:`~azure.core.pipeline.policies.BearerTokenCredentialPolicy`.
    """
52
53
54# pylint:disable=too-few-public-methods
class _AuxiliaryAuthenticationPolicyBase(Generic[TokenCredentialType]):
    """Shared state and helpers for policies that send auxiliary authorization tokens.

    The class stores the auxiliary credentials, the scopes to request and the
    most recently stored tokens. It never acquires tokens itself; that is left
    to the concrete policy built on top of it.

    :param auxiliary_credentials: auxiliary credentials for authorizing requests
    :type auxiliary_credentials: Sequence[~azure.core.credentials.TokenCredential] or
        Sequence[~azure.core.credentials_async.AsyncTokenCredential]
    :param str scopes: required authentication scopes
    """

    def __init__(  # pylint: disable=unused-argument
        self, auxiliary_credentials: Sequence[TokenCredentialType], *scopes: str, **kwargs: Any
    ) -> None:
        # Nothing is cached yet, so ``_need_new_aux_tokens`` reports True until
        # a token list is stored here.
        self._aux_tokens: Optional[List[AccessToken]] = None
        self._scopes = scopes
        self._auxiliary_credentials = auxiliary_credentials

    @staticmethod
    def _enforce_https(request: PipelineRequest[HTTPRequestType]) -> None:
        """Reject requests to non-https URLs unless enforcement was explicitly disabled.

        :param ~azure.core.pipeline.PipelineRequest request: the request to check
        :raises ~azure.core.exceptions.ServiceRequestError: if enforcement is active and
            the lower-cased request URL does not start with ``https``
        """
        # The option is removed from ``options`` (so the transport never sees it)
        # and remembered on the context (so it persists across retries). Only an
        # explicit ``False`` is recorded, because True is the default anyway.
        if request.context.options.pop("enforce_https", None) is False:
            request.context["enforce_https"] = False

        if not request.context.get("enforce_https", True):
            return

        url = request.http_request.url
        if not url.lower().startswith("https"):
            raise ServiceRequestError(
                "Bearer token authentication is not permitted for non-TLS protected (non-https) URLs."
            )

    def _update_headers(self, headers: MutableMapping[str, str]) -> None:
        """Write the cached auxiliary tokens into the x-ms-authorization-auxiliary header.

        The header is left untouched when no tokens are cached.

        :param dict headers: The HTTP Request headers
        """
        cached_tokens = self._aux_tokens
        if not cached_tokens:
            return

        bearer_values = ["Bearer {}".format(entry.token) for entry in cached_tokens]
        headers["x-ms-authorization-auxiliary"] = ", ".join(bearer_values)

    @property
    def _need_new_aux_tokens(self) -> bool:
        """Whether the cached auxiliary tokens are missing or close to expiry.

        :return: True when no tokens are cached, or when any cached token has fewer
            than 300 seconds left before its ``expires_on`` time; False otherwise
        :rtype: bool
        """
        if not self._aux_tokens:
            return True
        return any(token.expires_on - time.time() < 300 for token in self._aux_tokens)
103
104
class AuxiliaryAuthenticationPolicy(
    _AuxiliaryAuthenticationPolicyBase[TokenCredential],
    SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType],
):
    """Synchronous policy that attaches auxiliary bearer tokens to outgoing requests."""

    def _get_auxiliary_tokens(self, *scopes: str, **kwargs: Any) -> Optional[List[AccessToken]]:
        """Request one token from each auxiliary credential.

        :param str scopes: scopes forwarded to each credential's ``get_token``
        :return: the tokens in credential order, or None when no auxiliary
            credentials are configured
        :rtype: list[~azure.core.credentials.AccessToken] or None
        """
        credentials = self._auxiliary_credentials
        if not credentials:
            return None

        tokens = []
        for credential in credentials:
            tokens.append(credential.get_token(*scopes, **kwargs))
        return tokens

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Authorize the request with auxiliary tokens before it is sent.

        Checks the https requirement, refreshes the cached auxiliary tokens when
        they are missing or about to expire, then writes them to the request headers.

        :param ~azure.core.pipeline.PipelineRequest request: the request
        """
        self._enforce_https(request)

        if self._need_new_aux_tokens:
            refreshed = self._get_auxiliary_tokens(*self._scopes)
            self._aux_tokens = refreshed

        outgoing_headers = request.http_request.headers
        self._update_headers(outgoing_headers)