1# --------------------------------------------------------------------------
2#
3# Copyright (c) Microsoft Corporation. All rights reserved.
4#
5# The MIT License (MIT)
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the ""Software""), to
9# deal in the Software without restriction, including without limitation the
10# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
11# sell copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in
15# all copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
22# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
23# IN THE SOFTWARE.
24#
25# --------------------------------------------------------------------------
26from typing import TypeVar, Awaitable, Optional
27import inspect
28
29from azure.core.pipeline.policies import (
30 AsyncBearerTokenCredentialPolicy,
31 AsyncHTTPPolicy,
32)
33from azure.core.pipeline import PipelineRequest, PipelineResponse
34
35from ._authentication import _parse_claims_challenge, _AuxiliaryAuthenticationPolicyBase
36
37
# Generic type parameters for the pipeline request / async response types used
# in the PipelineRequest / PipelineResponse annotations of the policies below.
HTTPRequestType = TypeVar("HTTPRequestType")
AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType")
40
41
async def await_result(func, *args, **kwargs):
    """Invoke ``func`` and await its result only when that result is awaitable.

    This lets a caller run a hook that may be either a plain function or a
    coroutine function through a single code path.

    :param func: the callable to invoke
    :param args: positional arguments forwarded to ``func``
    :param kwargs: keyword arguments forwarded to ``func``
    :returns: the awaited value when ``func`` returned an awaitable, otherwise
        ``func``'s return value unchanged
    """
    outcome = func(*args, **kwargs)
    if not inspect.isawaitable(outcome):
        # Synchronous hook: hand the value straight back.
        return outcome
    return await outcome
48
49
class AsyncARMChallengeAuthenticationPolicy(AsyncBearerTokenCredentialPolicy):
    """Adds a bearer token Authorization header to requests.

    This policy internally handles Continuous Access Evaluation (CAE) challenges. When it can't complete a challenge,
    it will return the 401 (unauthorized) response from ARM.

    :param ~azure.core.credentials.TokenCredential credential: credential for authorizing requests
    :param str scopes: required authentication scopes
    """

    # pylint:disable=unused-argument
    async def on_challenge(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType],
    ) -> bool:
        """Authorize request according to an ARM authentication challenge

        The challenge is read from the response's ``WWW-Authenticate`` header. If it
        carries a claims challenge, the request is re-authorized with those claims.

        :param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
        :param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
        :returns: a bool indicating whether the policy should send the request
        :rtype: bool
        """
        challenge = response.http_response.headers.get("WWW-Authenticate")
        # A challenged response is not guaranteed to include this header. Without it
        # there is nothing to parse, so don't hand None to the parser; report the
        # challenge as unhandled instead.
        if not challenge:
            return False

        claims = _parse_claims_challenge(challenge)
        if not claims:
            return False

        # Re-acquire a token that satisfies the claims and re-authorize the request
        # so the caller can resend it.
        await self.authorize_request(request, *self._scopes, claims=claims)
        return True
80
81
class AsyncAuxiliaryAuthenticationPolicy(
    _AuxiliaryAuthenticationPolicyBase,
    AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType],
):
    """Async pipeline policy that attaches auxiliary authorization tokens to requests.

    Token acquisition is asynchronous here. HTTPS enforcement, token-refresh
    tracking and header construction are inherited from
    ``_AuxiliaryAuthenticationPolicyBase``.
    """

    async def _get_auxiliary_tokens(self, *scopes, **kwargs):
        """Acquire one token from each configured auxiliary credential.

        Tokens are requested one at a time, in credential order.

        :param scopes: scopes to request from every auxiliary credential
        :returns: a list of tokens, or None when no auxiliary credentials are configured
        """
        if self._auxiliary_credentials:
            return [await cred.get_token(*scopes, **kwargs) for cred in self._auxiliary_credentials]
        return None

    async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Called before the policy sends a request.

        The base implementation authorizes the request with an auxiliary authorization token.
        Tokens are (re)acquired only when the base class reports they are needed.

        :param ~azure.core.pipeline.PipelineRequest request: the request
        """
        self._enforce_https(request)

        if self._need_new_aux_tokens:
            self._aux_tokens = await self._get_auxiliary_tokens(*self._scopes)

        self._update_headers(request.http_request.headers)

    def on_response(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType],
    ) -> Optional[Awaitable[None]]:
        """Executed after the request comes back from the next policy.

        No-op by default; subclasses may override it, either synchronously or as a coroutine.

        :param request: Request to be modified after returning from the policy.
        :type request: ~azure.core.pipeline.PipelineRequest
        :param response: Pipeline response object
        :type response: ~azure.core.pipeline.PipelineResponse
        """

    def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> Optional[bool]:
        """Executed when an exception is raised while executing the next policy.

        This method is executed inside the exception handler.

        :param request: The Pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: a falsy value (the default) to let the exception propagate. An override
            may return True to suppress an exception raised by ``on_response``. An
            exception from the next policy always propagates, because no response exists.
        :rtype: bool or None
        """
        # pylint: disable=no-self-use,unused-argument
        return

    async def send(
        self, request: PipelineRequest[HTTPRequestType]
    ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]:
        """Authorize request with a bearer token and send it to the next policy

        :param request: The pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: The pipeline response object
        :rtype: ~azure.core.pipeline.PipelineResponse
        """
        await await_result(self.on_request, request)
        try:
            response = await self.next.send(request)
        except Exception:  # pylint:disable=broad-except
            # There is no response to return, so the original exception must propagate
            # even if an overriding on_exception reports it as handled. Previously this
            # path fell through to `return response` and raised UnboundLocalError instead.
            await await_result(self.on_exception, request)
            raise
        try:
            await await_result(self.on_response, request, response)
        except Exception:  # pylint:disable=broad-except
            handled = await await_result(self.on_exception, request)
            if not handled:
                raise
        return response