1# --------------------------------------------------------------------------
2#
3# Copyright (c) Microsoft Corporation. All rights reserved.
4#
5# The MIT License (MIT)
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to
9# deal in the Software without restriction, including without limitation the
10# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
11# sell copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in
15# all copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
22# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
23# IN THE SOFTWARE.
24#
25# --------------------------------------------------------------------------
26from typing import Awaitable, Optional, List, Union, Any
27import inspect
28
29from azure.core.pipeline.policies import (
30 AsyncBearerTokenCredentialPolicy,
31 AsyncHTTPPolicy,
32)
33from azure.core.pipeline import PipelineRequest, PipelineResponse
34from azure.core.pipeline.transport import (
35 HttpRequest as LegacyHttpRequest,
36 AsyncHttpResponse as LegacyAsyncHttpResponse,
37)
38from azure.core.rest import HttpRequest, AsyncHttpResponse
39from azure.core.credentials import AccessToken
40from azure.core.credentials_async import AsyncTokenCredential
41
42
43from ._authentication import _AuxiliaryAuthenticationPolicyBase
44
45
# Request type handled by the policies in this module: either the legacy
# transport request or the azure.core.rest request.
HTTPRequestType = Union[LegacyHttpRequest, HttpRequest]
# Async response counterpart: either the legacy transport response or the
# azure.core.rest response.
AsyncHTTPResponseType = Union[LegacyAsyncHttpResponse, AsyncHttpResponse]
48
49
async def await_result(func, *args, **kwargs):
    """Call ``func`` and, if it returns an awaitable, await it.

    This lets a caller invoke a hook uniformly whether the hook was written
    as a plain function or as a coroutine function.

    :param callable func: Function to call
    :param any args: Positional arguments to pass to func
    :param any kwargs: Keyword arguments to pass to func
    :return: Result of func, or the awaited value when func returned an awaitable
    :rtype: any
    """
    result = func(*args, **kwargs)
    if inspect.isawaitable(result):
        return await result
    return result
62
63
class AsyncARMChallengeAuthenticationPolicy(AsyncBearerTokenCredentialPolicy):
    """Adds a bearer token Authorization header to requests.

    This policy internally handles Continuous Access Evaluation (CAE) challenges. When it can't complete a challenge,
    it will return the 401 (unauthorized) response from ARM.

    The class defines no members of its own; all behavior is inherited from
    ``AsyncBearerTokenCredentialPolicy``.
    """
70
71
class AsyncAuxiliaryAuthenticationPolicy(
    _AuxiliaryAuthenticationPolicyBase[AsyncTokenCredential],
    AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType],
):
    """Async pipeline policy that authorizes requests with auxiliary tokens.

    Tokens are requested from the auxiliary credentials held by the base class
    and applied to the outgoing request headers through the base class's
    ``_update_headers``.
    """

    async def _get_auxiliary_tokens(self, *scopes: str, **kwargs: Any) -> Optional[List[AccessToken]]:
        """Request one token from each auxiliary credential.

        Credentials are queried one after another, in iteration order.

        :param str scopes: Scopes forwarded to each credential's ``get_token``
        :return: One token per credential, or None when there are no auxiliary credentials
        :rtype: list[~azure.core.credentials.AccessToken] or None
        """
        if self._auxiliary_credentials:
            return [await cred.get_token(*scopes, **kwargs) for cred in self._auxiliary_credentials]
        return None

    async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Called before the policy sends a request.

        The base implementation authorizes the request with an auxiliary authorization token.

        :param ~azure.core.pipeline.PipelineRequest request: the request
        """
        self._enforce_https(request)

        # Only fetch tokens when the base class reports that new ones are needed.
        if self._need_new_aux_tokens:
            self._aux_tokens = await self._get_auxiliary_tokens(*self._scopes)

        self._update_headers(request.http_request.headers)

    def on_response(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType],
    ) -> Optional[Awaitable[None]]:
        """Executed after the request comes back from the next policy.

        This implementation does nothing. An override may return an awaitable;
        ``send`` awaits it before returning the response.

        :param request: Request to be modified after returning from the policy.
        :type request: ~azure.core.pipeline.PipelineRequest
        :param response: Pipeline response object
        :type response: ~azure.core.pipeline.PipelineResponse
        """

    def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Executed when an exception is raised while executing the next policy.

        This method is executed inside the exception handler. This
        implementation does nothing, so ``send`` re-raises the exception.

        :param request: The Pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        # pylint: disable=unused-argument
        return

    async def send(
        self, request: PipelineRequest[HTTPRequestType]
    ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]:
        """Authorize request with auxiliary tokens and send it to the next policy

        ``on_request`` runs first; ``on_response`` runs after the next policy
        returns. An exception raised by the next policy or by ``on_response``
        is passed to ``on_exception`` and re-raised unless that hook returns a
        truthy value and a response is available to return.

        :param request: The pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: The pipeline response object
        :rtype: ~azure.core.pipeline.PipelineResponse
        """
        await await_result(self.on_request, request)
        # Pre-initialised so the handler can tell whether the next policy
        # produced a response before the exception was raised.
        response: Optional[PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]] = None
        try:
            response = await self.next.send(request)
            await await_result(self.on_response, request, response)
        except Exception:  # pylint:disable=broad-except
            handled = await await_result(self.on_exception, request)
            # If the next policy itself failed there is no response to hand
            # back, so the original error is re-raised even when an override
            # reports it as handled (previously this path hit an unbound local).
            if not handled or response is None:
                raise
        return response