1# --------------------------------------------------------------------------
2#
3# Copyright (c) Microsoft Corporation. All rights reserved.
4#
5# The MIT License (MIT)
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the ""Software""), to
9# deal in the Software without restriction, including without limitation the
10# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
11# sell copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in
15# all copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
22# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
23# IN THE SOFTWARE.
24#
25# --------------------------------------------------------------------------
26from __future__ import annotations
27from types import TracebackType
28from typing import Any, Union, Generic, TypeVar, List, Dict, Optional, Iterable, Type, AsyncContextManager
29
30from azure.core.pipeline import PipelineRequest, PipelineResponse, PipelineContext
31from azure.core.pipeline.policies import AsyncHTTPPolicy, SansIOHTTPPolicy
32from ._tools_async import await_result as _await_result
33from ._base import cleanup_kwargs_for_transport
34from .transport import AsyncHttpTransport
35
# Type variable for the HTTP response type; it parametrizes the policy,
# transport and PipelineResponse annotations used throughout this module.
AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType")
# Type variable for the HTTP request type; it parametrizes the policy,
# transport, PipelineRequest and PipelineResponse annotations in this module.
HTTPRequestType = TypeVar("HTTPRequestType")
38
39
class _SansIOAsyncHTTPPolicyRunner(
    AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]
):  # pylint: disable=unsubscriptable-object
    """Adapter that lets a SansIO policy take part in an async policy chain.

    The wrapped policy's ``on_request``, ``on_response`` and ``on_exception``
    hooks are invoked around the call to the next policy in the chain.

    :param policy: A SansIO policy.
    :type policy: ~azure.core.pipeline.policies.SansIOHTTPPolicy
    """

    def __init__(self, policy: SansIOHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]) -> None:
        super().__init__()
        self._policy = policy

    async def send(
        self, request: PipelineRequest[HTTPRequestType]
    ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]:
        """Run the wrapped policy's hooks around the next policy in the chain.

        ``on_request`` is called first, then the request is handed to the next
        policy. On success ``on_response`` is called and the response returned;
        if the next policy raises, ``on_exception`` is called and the original
        exception is re-raised.

        :param request: The PipelineRequest object.
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: The PipelineResponse object.
        :rtype: ~azure.core.pipeline.PipelineResponse
        """
        sansio_policy = self._policy
        await _await_result(sansio_policy.on_request, request)
        try:
            response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType] = await self.next.send(request)
        except Exception:  # pylint: disable=broad-except
            # Let the policy observe the failure, then propagate the original error.
            await _await_result(sansio_policy.on_exception, request)
            raise
        else:
            await _await_result(sansio_policy.on_response, request, response)
            return response
74
75
class _AsyncTransportRunner(
    AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]
):  # pylint: disable=unsubscriptable-object
    """Terminal node of an async policy chain.

    Hands the HTTP request to the configured transport and wraps the result
    in a PipelineResponse.

    :param sender: The async Http Transport instance.
    :type sender: ~azure.core.pipeline.transport.AsyncHttpTransport
    """

    def __init__(self, sender: AsyncHttpTransport[HTTPRequestType, AsyncHTTPResponseType]) -> None:
        super().__init__()
        self._sender = sender

    async def send(
        self, request: PipelineRequest[HTTPRequestType]
    ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]:
        """Send the HTTP request through the transport.

        The context options are passed through ``cleanup_kwargs_for_transport``
        and then forwarded to the transport as keyword arguments.

        :param request: The PipelineRequest object.
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: The PipelineResponse object.
        :rtype: ~azure.core.pipeline.PipelineResponse
        """
        transport_options = request.context.options
        cleanup_kwargs_for_transport(transport_options)

        http_request = request.http_request
        http_response = await self._sender.send(http_request, **transport_options)
        return PipelineResponse(http_request, http_response, request.context)
107
108
class AsyncPipeline(AsyncContextManager["AsyncPipeline"], Generic[HTTPRequestType, AsyncHTTPResponseType]):
    """Async pipeline implementation.

    The pipeline is an async context manager: entering and exiting it enters
    and exits the context of the underlying HTTP transport.

    :param transport: The async Http Transport instance.
    :type transport: ~azure.core.pipeline.transport.AsyncHttpTransport
    :param list policies: List of configured policies.

    .. admonition:: Example:

        .. literalinclude:: ../samples/test_example_async.py
            :start-after: [START build_async_pipeline]
            :end-before: [END build_async_pipeline]
            :language: python
            :dedent: 4
            :caption: Builds the async pipeline for asynchronous transport.
    """

    def __init__(  # pylint: disable=super-init-not-called
        self,
        transport: AsyncHttpTransport[HTTPRequestType, AsyncHTTPResponseType],
        policies: Optional[
            Iterable[
                Union[
                    AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType],
                    SansIOHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType],
                ]
            ]
        ] = None,
    ) -> None:
        self._impl_policies: List[AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]] = []
        self._transport = transport

        chain = self._impl_policies
        for candidate in policies or []:
            if isinstance(candidate, SansIOHTTPPolicy):
                # SansIO policies are wrapped in a runner so they can be chained.
                chain.append(_SansIOAsyncHTTPPolicyRunner(candidate))
                continue
            if candidate:
                # Falsy entries (e.g. None) are silently dropped.
                chain.append(candidate)

        # Link every policy to its successor; the last one hands over to the transport.
        for current, successor in zip(chain, chain[1:]):
            current.next = successor
        if chain:
            chain[-1].next = _AsyncTransportRunner(transport)

    async def __aenter__(self) -> AsyncPipeline[HTTPRequestType, AsyncHTTPResponseType]:
        """Enter the transport's async context and return this pipeline.

        :return: This pipeline instance.
        :rtype: ~azure.core.pipeline.AsyncPipeline
        """
        await self._transport.__aenter__()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        """Exit the transport's async context, forwarding the exception details.

        :param exc_type: The type of the exception being handled, if any.
        :param exc_value: The exception instance being handled, if any.
        :param traceback: The traceback of the exception being handled, if any.
        """
        await self._transport.__aexit__(exc_type, exc_value, traceback)

    async def _prepare_multipart_mixed_request(self, request: HTTPRequestType) -> None:
        """Will execute the multipart policies.

        Does nothing if "set_multipart_mixed" was never called.

        :param request: The HTTP request object.
        :type request: ~azure.core.rest.HttpRequest
        """
        mixed_info = request.multipart_mixed_info  # type: ignore
        if not mixed_info:
            return

        # Index 0 holds the sub requests, 1 the policies to apply to them and
        # 3 the options used to build each sub request's pipeline context.
        sub_requests: List[HTTPRequestType] = mixed_info[0]
        sansio_policies: List[SansIOHTTPPolicy] = mixed_info[1]
        context_options: Dict[str, Any] = mixed_info[3]

        async def apply_policies(sub_request):
            if sub_request.multipart_mixed_info:
                # Recursively update changeset "sub requests"
                await self._prepare_multipart_mixed_request(sub_request)
            sub_pipeline_request = PipelineRequest(sub_request, PipelineContext(None, **context_options))
            for sansio_policy in sansio_policies:
                await _await_result(sansio_policy.on_request, sub_pipeline_request)

        # Not happy to make this code asyncio specific, but that's multipart only for now.
        # If we need trio and multipart, let's reinvestigate that later.
        import asyncio

        await asyncio.gather(*(apply_policies(sub_request) for sub_request in sub_requests))

    async def _prepare_multipart(self, request: HTTPRequestType) -> None:
        """Run the multipart policies, then build the multipart body of the request.

        :param request: The HTTP request object.
        """
        # This code is fine as long as HTTPRequestType is actually
        # azure.core.pipeline.transport.HTTPRequest, but we don't check it in here
        # since we didn't see (yet) pipeline usage where it's not this actual instance
        # class used
        await self._prepare_multipart_mixed_request(request)
        request.prepare_multipart_body()  # type: ignore

    async def run(
        self, request: HTTPRequestType, **kwargs: Any
    ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]:
        """Runs the HTTP Request through the chained policies.

        Any keyword arguments are passed to the PipelineContext created for
        this request. When no policy is configured, the request is sent
        straight to the transport.

        :param request: The HTTP request object.
        :type request: ~azure.core.pipeline.transport.HttpRequest
        :return: The PipelineResponse object.
        :rtype: ~azure.core.pipeline.PipelineResponse
        """
        await self._prepare_multipart(request)
        pipeline_request = PipelineRequest(request, PipelineContext(self._transport, **kwargs))

        entry_point: AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]
        if self._impl_policies:
            entry_point = self._impl_policies[0]
        else:
            entry_point = _AsyncTransportRunner(self._transport)
        return await entry_point.send(pipeline_request)