1# --------------------------------------------------------------------------
2#
3# Copyright (c) Microsoft Corporation. All rights reserved.
4#
5# The MIT License (MIT)
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the ""Software""), to
9# deal in the Software without restriction, including without limitation the
10# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
11# sell copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in
15# all copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
22# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
23# IN THE SOFTWARE.
24#
25# --------------------------------------------------------------------------
26from typing import cast, Awaitable, Optional, List, Union, Any
27import inspect
28
29from azure.core.pipeline.policies import (
30 AsyncBearerTokenCredentialPolicy,
31 AsyncHTTPPolicy,
32)
33from azure.core.pipeline import PipelineRequest, PipelineResponse
34from azure.core.pipeline.transport import (
35 HttpRequest as LegacyHttpRequest,
36 AsyncHttpResponse as LegacyAsyncHttpResponse,
37)
38from azure.core.rest import HttpRequest, AsyncHttpResponse
39from azure.core.credentials import AccessToken
40from azure.core.credentials_async import AsyncTokenCredential
41
42
43from ._authentication import _parse_claims_challenge, _AuxiliaryAuthenticationPolicyBase
44
45
# Request/response types accepted by the policies in this module: either the legacy
# azure.core.pipeline.transport types or the azure.core.rest types.
HTTPRequestType = Union[LegacyHttpRequest, HttpRequest]
AsyncHTTPResponseType = Union[LegacyAsyncHttpResponse, AsyncHttpResponse]
48
49
async def await_result(func, *args, **kwargs):
    """Call ``func`` and return its result, awaiting it first when it is awaitable.

    This lets a hook be written either as a plain function or as a coroutine function;
    the caller always receives the final value.

    :param callable func: Callable to invoke
    :param any args: Positional arguments forwarded to func
    :param any kwargs: Keyword arguments forwarded to func
    :return: The value func returned, or the value it resolved to if it was awaitable
    :rtype: any
    """
    outcome = func(*args, **kwargs)
    if not inspect.isawaitable(outcome):
        # Synchronous callable: hand the value straight back.
        return outcome
    return await outcome
62
63
class AsyncARMChallengeAuthenticationPolicy(AsyncBearerTokenCredentialPolicy):
    """Adds a bearer token Authorization header to requests.

    This policy internally handles Continuous Access Evaluation (CAE) challenges. When it can't complete a challenge,
    it will return the 401 (unauthorized) response from ARM.

    :param ~azure.core.credentials_async.AsyncTokenCredential credential: credential for authorizing requests
    :param str scopes: required authentication scopes
    """

    async def on_challenge(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType],
    ) -> bool:
        """Authorize request according to an ARM authentication challenge

        Only challenges that carry claims are handled: the request is re-authorized for the
        policy's scopes with those claims. A response without a ``WWW-Authenticate`` header,
        or whose challenge yields no claims, is reported as unhandled.

        :param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
        :param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
        :returns: a bool indicating whether the policy should send the request
        :rtype: bool
        """
        challenge: Optional[str] = response.http_response.headers.get("WWW-Authenticate")
        if not challenge:
            # No challenge to parse (e.g. on_challenge invoked directly or by a subclass for a
            # response lacking the header): report it as unhandled rather than handing None
            # to the claims parser.
            return False

        claims = _parse_claims_challenge(challenge)
        if not claims:
            return False

        await self.authorize_request(request, *self._scopes, claims=claims)
        return True
95
96
class AsyncAuxiliaryAuthenticationPolicy(
    _AuxiliaryAuthenticationPolicyBase[AsyncTokenCredential],
    AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType],
):
    """Async policy that authorizes requests with auxiliary authorization tokens.

    A token is requested from every auxiliary credential for the configured scopes, and the
    request headers are then updated by ``_update_headers`` (from
    ``_AuxiliaryAuthenticationPolicyBase``). Whether new tokens are needed is decided by the
    base class's ``_need_new_aux_tokens``.
    """

    async def _get_auxiliary_tokens(self, *scopes: str, **kwargs: Any) -> Optional[List[AccessToken]]:
        """Request a token for ``scopes`` from each auxiliary credential, one after another.

        :param str scopes: scopes to request tokens for
        :return: one token per auxiliary credential, or None when there are no auxiliary credentials
        :rtype: list[~azure.core.credentials.AccessToken] or None
        """
        if self._auxiliary_credentials:
            return [await cred.get_token(*scopes, **kwargs) for cred in self._auxiliary_credentials]
        return None

    async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Called before the policy sends a request.

        The base implementation authorizes the request with an auxiliary authorization token.

        :param ~azure.core.pipeline.PipelineRequest request: the request
        """
        self._enforce_https(request)

        # Only fetch tokens when the base class reports the cached ones as unusable.
        if self._need_new_aux_tokens:
            self._aux_tokens = await self._get_auxiliary_tokens(*self._scopes)

        self._update_headers(request.http_request.headers)

    def on_response(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType],
    ) -> Optional[Awaitable[None]]:
        """Executed after the request comes back from the next policy.

        The default implementation does nothing.

        :param request: Request to be modified after returning from the policy.
        :type request: ~azure.core.pipeline.PipelineRequest
        :param response: Pipeline response object
        :type response: ~azure.core.pipeline.PipelineResponse
        """

    def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Executed when an exception is raised while executing the next policy.

        This method is executed inside the exception handler. ``send`` re-raises the
        exception unless this method returns a truthy value. The default implementation
        returns None, so the exception always propagates.

        :param request: The Pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        # pylint: disable=unused-argument
        return

    async def send(
        self, request: PipelineRequest[HTTPRequestType]
    ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]:
        """Authorize request with a bearer token and send it to the next policy

        :param request: The pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: The pipeline response object
        :rtype: ~azure.core.pipeline.PipelineResponse
        """
        await await_result(self.on_request, request)
        # Pre-bound so the except branch can tell whether the next policy produced a response.
        response: Optional[PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]] = None
        try:
            response = await self.next.send(request)
            await await_result(self.on_response, request, response)
        except Exception:  # pylint:disable=broad-except
            handled = await await_result(self.on_exception, request)
            # An exception can only be suppressed when there is a response to return. Otherwise
            # re-raise the original error rather than failing later with an UnboundLocalError.
            if not handled or response is None:
                raise
        return response