1# --------------------------------------------------------------------------
2#
3# Copyright (c) Microsoft Corporation. All rights reserved.
4#
5# The MIT License (MIT)
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the ""Software""), to
9# deal in the Software without restriction, including without limitation the
10# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
11# sell copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in
15# all copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
22# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
23# IN THE SOFTWARE.
24#
25# --------------------------------------------------------------------------
26from __future__ import annotations
27import logging
28import collections.abc
29from typing import (
30 Any,
31 Awaitable,
32 TypeVar,
33 AsyncContextManager,
34 Generator,
35 Generic,
36 Optional,
37 Type,
38 cast,
39)
40from types import TracebackType
41from .configuration import Configuration
42from .pipeline import AsyncPipeline
43from .pipeline.transport._base import PipelineClientBase
44from .pipeline.policies import (
45 ContentDecodePolicy,
46 DistributedTracingPolicy,
47 HttpLoggingPolicy,
48 RequestIdPolicy,
49 AsyncRetryPolicy,
50 SensitiveHeaderCleanupPolicy,
51)
52
53
# Logger for this module, keyed on the module's import name.
_LOGGER = logging.getLogger(__name__)

# Request type parameter shared by the async client and its pipeline.
HTTPRequestType = TypeVar("HTTPRequestType")
# Response type parameter; its bound is a string forward reference to
# ``AsyncContextManager`` so responses are expected to support ``async with``.
AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType", bound="AsyncContextManager")
58
59
class _Coroutine(Awaitable[AsyncHTTPResponseType]):
    """Wrapper that is both awaitable and usable as an async context manager.

    It is named "_Coroutine" so that forgetting to await it produces a readable error message:

    >>> result = client.send_request(request)
    >>> result.text()
    AttributeError: '_Coroutine' object has no attribute 'text'

    For comparison, using a real coroutine without awaiting it would produce:
    AttributeError: 'coroutine' object has no attribute 'text'

    This lets the developer either use the "async with" syntax or await the object directly.
    It is also why "send_request" is not declared as async: an ``async def`` function could not
    easily support both forms.

    "wrapped" must be an awaitable whose result implements the async context manager protocol.
    Each wrapper should be consumed only once (either awaited or entered with ``async with``),
    because a coroutine object cannot be awaited twice.

    This permits both of the following usages:

    .. code-block:: python

        from azure.core import AsyncPipelineClient
        from azure.core.rest import HttpRequest

        async def main():

            request = HttpRequest("GET", "https://httpbin.org/user-agent")
            async with AsyncPipelineClient("https://httpbin.org/") as client:
                # Can be used directly
                result = await client.send_request(request)
                print(result.text())

                # Can be used as an async context manager
                async with client.send_request(request) as result:
                    print(result.text())

    :param wrapped: An awaitable whose result supports ``__aexit__``. When the ``async with``
        block ends, this wrapper delegates to the result's ``__aexit__``; the result's
        ``__aenter__`` is not called.
    :type wrapped: awaitable[AsyncHTTPResponseType]
    """

    def __init__(self, wrapped: Awaitable[AsyncHTTPResponseType]) -> None:
        super().__init__()
        self._wrapped = wrapped
        # If someone tries to use the object without awaiting, they will get a
        # AttributeError: '_Coroutine' object has no attribute 'text'
        # The response is only populated by __aenter__; until then it is a None placeholder.
        self._response: AsyncHTTPResponseType = cast(AsyncHTTPResponseType, None)

    def __await__(self) -> Generator[Any, None, AsyncHTTPResponseType]:
        """Await the wrapped awaitable and return its result.

        :return: A generator that drives the wrapped awaitable.
        :rtype: Generator[Any, None, AsyncHTTPResponseType]
        """
        return self._wrapped.__await__()

    async def __aenter__(self) -> AsyncHTTPResponseType:
        """Await the wrapped awaitable and keep its result for ``__aexit__``.

        :return: The awaited result (for ``send_request``, the response).
        :rtype: AsyncHTTPResponseType
        """
        self._response = await self
        return self._response

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        """Exit the stored response, forwarding the exception details unchanged.

        :param exc_type: The exception type raised in the ``async with`` body, if any.
        :type exc_type: type[BaseException] or None
        :param exc_value: The exception instance raised, if any.
        :type exc_value: BaseException or None
        :param traceback: The traceback of the exception, if any.
        :type traceback: ~types.TracebackType or None
        """
        # The response's return value is discarded, so exceptions are never suppressed here.
        await self._response.__aexit__(exc_type, exc_value, traceback)
120
121
class AsyncPipelineClient(
    PipelineClientBase,
    AsyncContextManager["AsyncPipelineClient"],
    Generic[HTTPRequestType, AsyncHTTPResponseType],
):
    """Service client core methods.

    Builds an AsyncPipeline client.

    :param str base_url: URL for the request.
    :keyword ~azure.core.configuration.Configuration config: If omitted, the standard configuration is used.
    :keyword ~azure.core.pipeline.AsyncPipeline pipeline: If omitted, an AsyncPipeline is built from the
        configuration and the policy keywords below.
    :keyword list[AsyncHTTPPolicy] policies: If omitted, the standard policies of the configuration object are used.
    :keyword per_call_policies: If specified, the policies will be added into the policy list before RetryPolicy
    :paramtype per_call_policies: Union[AsyncHTTPPolicy, SansIOHTTPPolicy,
        list[AsyncHTTPPolicy], list[SansIOHTTPPolicy]]
    :keyword per_retry_policies: If specified, the policies will be added into the policy list after RetryPolicy
    :paramtype per_retry_policies: Union[AsyncHTTPPolicy, SansIOHTTPPolicy,
        list[AsyncHTTPPolicy], list[SansIOHTTPPolicy]]
    :keyword AsyncHttpTransport transport: If omitted, AioHttpTransport is used for asynchronous transport.

    .. admonition:: Example:

        .. literalinclude:: ../samples/test_example_async.py
            :start-after: [START build_async_pipeline_client]
            :end-before: [END build_async_pipeline_client]
            :language: python
            :dedent: 4
            :caption: Builds the async pipeline client.
    """

    def __init__(
        self,
        base_url: str,
        *,
        pipeline: Optional[AsyncPipeline[HTTPRequestType, AsyncHTTPResponseType]] = None,
        config: Optional[Configuration[HTTPRequestType, AsyncHTTPResponseType]] = None,
        **kwargs: Any,
    ):
        super(AsyncPipelineClient, self).__init__(base_url)
        # Remaining kwargs feed both the default Configuration and the default pipeline build.
        self._config: Configuration[HTTPRequestType, AsyncHTTPResponseType] = config or Configuration(**kwargs)
        self._base_url = base_url
        self._pipeline = pipeline or self._build_pipeline(self._config, **kwargs)

    async def __aenter__(self) -> AsyncPipelineClient[HTTPRequestType, AsyncHTTPResponseType]:
        """Enter the underlying pipeline and return this client.

        :return: This client.
        :rtype: AsyncPipelineClient
        """
        await self._pipeline.__aenter__()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        """Exit the underlying pipeline, forwarding any exception details.

        :param exc_type: The exception type raised in the ``async with`` body, if any.
        :type exc_type: type[BaseException] or None
        :param exc_value: The exception instance raised, if any.
        :type exc_value: BaseException or None
        :param traceback: The traceback of the exception, if any.
        :type traceback: ~types.TracebackType or None
        """
        await self._pipeline.__aexit__(exc_type, exc_value, traceback)

    async def close(self) -> None:
        """Close the client by exiting its pipeline, as leaving an ``async with`` block would."""
        await self.__aexit__()

    @staticmethod
    def _as_policy_list(policies: Any) -> list[Any]:
        """Return ``policies`` as a new list, wrapping a single (non-iterable) policy in a list.

        :param policies: A single policy or an iterable of policies.
        :type policies: any
        :return: A new list; the caller's iterable is never mutated.
        :rtype: list
        """
        if isinstance(policies, collections.abc.Iterable):
            return list(policies)
        return [policies]

    @staticmethod
    def _default_policies(
        config: Configuration[HTTPRequestType, AsyncHTTPResponseType],
        per_call_policies: list[Any],
        per_retry_policies: list[Any],
        **kwargs: Any,
    ) -> list[Any]:
        """Assemble the standard policy list from ``config``.

        Order: request-id/headers/user-agent/proxy/content-decode, then ``per_call_policies``,
        then redirect/retry/authentication/custom-hook, then ``per_retry_policies``, then
        logging/tracing/sensitive-header-cleanup/http-logging.

        :param config: Configuration supplying the configured policies.
        :type config: ~azure.core.configuration.Configuration
        :param per_call_policies: Policies inserted before the redirect/retry policies.
        :type per_call_policies: list
        :param per_retry_policies: Policies inserted right after the retry-related policies.
        :type per_retry_policies: list
        :return: The assembled policy list; it may contain ``None`` entries for unset policies.
        :rtype: list
        """
        policies: list[Any] = [
            config.request_id_policy or RequestIdPolicy(**kwargs),
            config.headers_policy,
            config.user_agent_policy,
            config.proxy_policy,
            ContentDecodePolicy(**kwargs),
        ]
        policies.extend(per_call_policies)
        policies.extend(
            [
                config.redirect_policy,
                config.retry_policy,
                config.authentication_policy,
                config.custom_hook_policy,
            ]
        )
        policies.extend(per_retry_policies)
        policies.extend(
            [
                config.logging_policy,
                DistributedTracingPolicy(**kwargs),
                # Sensitive headers only need cleaning when redirects can happen.
                SensitiveHeaderCleanupPolicy(**kwargs) if config.redirect_policy else None,
                config.http_logging_policy or HttpLoggingPolicy(**kwargs),
            ]
        )
        return policies

    @staticmethod
    def _insert_per_retry_policies(policies: list[Any], per_retry_policies: list[Any]) -> list[Any]:
        """Return a new list with ``per_retry_policies`` placed right after the retry policy.

        :param policies: The caller-supplied policy list (already including per-call policies).
        :type policies: list
        :param per_retry_policies: Policies to insert; when empty, ``policies`` is returned as-is.
        :type per_retry_policies: list
        :return: The merged policy list.
        :rtype: list
        :raises ValueError: If ``per_retry_policies`` is non-empty and ``policies`` contains no
            AsyncRetryPolicy.
        """
        if not per_retry_policies:
            return policies
        index_of_retry = -1
        for index, policy in enumerate(policies):
            if isinstance(policy, AsyncRetryPolicy):
                # Keep scanning: when several retry policies exist, the last one wins.
                index_of_retry = index
        if index_of_retry == -1:
            raise ValueError(
                "Failed to add per_retry_policies; no RetryPolicy found in the supplied list of policies. "
            )
        insert_at = index_of_retry + 1
        return policies[:insert_at] + per_retry_policies + policies[insert_at:]

    def _build_pipeline(
        self,
        config: Configuration[HTTPRequestType, AsyncHTTPResponseType],
        *,
        policies=None,
        per_call_policies=None,
        per_retry_policies=None,
        **kwargs: Any,
    ) -> AsyncPipeline[HTTPRequestType, AsyncHTTPResponseType]:
        """Build the AsyncPipeline used by this client.

        :param config: Configuration supplying the default policies.
        :type config: ~azure.core.configuration.Configuration
        :keyword policies: Complete policy list. When None (the default), the standard list is
            assembled from ``config``; an empty list is respected as a valid policy list.
        :keyword per_call_policies: A policy or iterable of policies. With the standard list they go
            before the redirect/retry policies; with a supplied ``policies`` they are prepended to it.
        :keyword per_retry_policies: A policy or iterable of policies placed after the retry policy.
        :return: The pipeline, using ``kwargs["transport"]`` or an AioHttpTransport by default.
        :rtype: ~azure.core.pipeline.AsyncPipeline
        :raises ValueError: If ``policies`` is supplied with non-empty ``per_retry_policies`` but
            contains no AsyncRetryPolicy.
        """
        transport = kwargs.get("transport")
        per_call_policies_list = self._as_policy_list(per_call_policies or [])
        per_retry_policies_list = self._as_policy_list(per_retry_policies or [])

        if policies is None:  # [] is a valid policy list
            policies = self._default_policies(config, per_call_policies_list, per_retry_policies_list, **kwargs)
        else:
            # Build a fresh list so the caller's policy list is never mutated.
            policies = self._insert_per_retry_policies(
                per_call_policies_list + list(policies), per_retry_policies_list
            )

        if not transport:
            # Use private import for better typing, mypy and pyright don't like PEP562
            from .pipeline.transport._aiohttp import AioHttpTransport

            transport = AioHttpTransport(**kwargs)

        return AsyncPipeline[HTTPRequestType, AsyncHTTPResponseType](transport, policies)

    async def _make_pipeline_call(self, request: HTTPRequestType, **kwargs: Any) -> AsyncHTTPResponseType:
        """Run ``request`` through the pipeline and return the HTTP response.

        The private keyword ``_return_pipeline_response`` (default False) makes this return the whole
        pipeline response object instead of only its ``http_response``.

        :param request: The request to send.
        :type request: HTTPRequestType
        :return: The HTTP response (or the pipeline response, see above).
        :rtype: AsyncHTTPResponseType
        """
        return_pipeline_response = kwargs.pop("_return_pipeline_response", False)
        pipeline_response = await self._pipeline.run(request, **kwargs)
        if return_pipeline_response:
            return pipeline_response  # type: ignore # This is a private API we don't want to type in signature
        return pipeline_response.http_response

    def send_request(
        self, request: HTTPRequestType, *, stream: bool = False, **kwargs: Any
    ) -> Awaitable[AsyncHTTPResponseType]:
        """Method that runs the network request through the client's chained policies.

        The returned object can be awaited directly, or used with ``async with`` so that the
        response's ``__aexit__`` is called when the block ends.

        >>> from azure.core.rest import HttpRequest
        >>> request = HttpRequest('GET', 'http://www.example.com')
        >>> request
        <HttpRequest [GET], url: 'http://www.example.com'>
        >>> response = await client.send_request(request)
        >>> response
        <AsyncHttpResponse: 200 OK>

        :param request: The network request you want to make. Required.
        :type request: ~azure.core.rest.HttpRequest
        :keyword bool stream: Whether the response payload will be streamed. Defaults to False.
        :return: The response of your network call. Does not do error handling on your response.
        :rtype: ~azure.core.rest.AsyncHttpResponse
        """
        wrapped = self._make_pipeline_call(request, stream=stream, **kwargs)
        return _Coroutine(wrapped=wrapped)