1# --------------------------------------------------------------------------
2#
3# Copyright (c) Microsoft Corporation. All rights reserved.
4#
5# The MIT License (MIT)
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the ""Software""), to
9# deal in the Software without restriction, including without limitation the
10# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
11# sell copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in
15# all copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
22# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
23# IN THE SOFTWARE.
24#
25# --------------------------------------------------------------------------
26from __future__ import annotations
27import logging
28import collections.abc
29from typing import (
30 Any,
31 Awaitable,
32 TypeVar,
33 AsyncContextManager,
34 Generator,
35 Generic,
36 Optional,
37 Type,
38 cast,
39)
40from types import TracebackType
41from .configuration import Configuration
42from .pipeline import AsyncPipeline
43from .pipeline.transport._base import PipelineClientBase
44from .pipeline.policies import (
45 ContentDecodePolicy,
46 DistributedTracingPolicy,
47 HttpLoggingPolicy,
48 RequestIdPolicy,
49 AsyncRetryPolicy,
50 SensitiveHeaderCleanupPolicy,
51)
52
53
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)

# Generic parameter for the request type flowing through the pipeline.
HTTPRequestType = TypeVar("HTTPRequestType")

# Generic parameter for the response type. It is bounded (via a forward-reference
# string) to AsyncContextManager so responses can be used with "async with".
AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType", bound="AsyncContextManager")
58
59
class _Coroutine(Awaitable[AsyncHTTPResponseType]):
    """Wrapper that makes a pending response both awaitable and an async context manager.

    It is named "_Coroutine" so that forgetting to await it produces a readable error message:
    >>> result = client.send_request(request)
    >>> result.text()
    AttributeError: '_Coroutine' object has no attribute 'text'

    By comparison, using a coroutine without awaiting it would produce:
    AttributeError: 'coroutine' object has no attribute 'text'

    This allows the developer to either use the "async with" syntax, or simply await the object directly.
    It is also why "send_request" is not declared as async, since a coroutine function could not easily
    support both forms.

    "wrapped" must be an awaitable object that resolves to an object implementing the async context
    manager protocol.

    This permits both of the following usages to work:

    ```python
    from azure.core import AsyncPipelineClient
    from azure.core.rest import HttpRequest

    async def main():

        request = HttpRequest("GET", "https://httpbin.org/user-agent")
        async with AsyncPipelineClient("https://httpbin.org/") as client:
            # Can be used directly
            result = await client.send_request(request)
            print(result.text())

            # Can be used as an async context manager
            async with client.send_request(request) as result:
                print(result.text())
    ```

    :param wrapped: Must be an awaitable that resolves to an async context manager that supports async "close()"
    :type wrapped: awaitable[AsyncHTTPResponseType]
    """

    def __init__(self, wrapped: Awaitable[AsyncHTTPResponseType]) -> None:
        super().__init__()
        self._wrapped = wrapped
        # If someone tries to use the object without awaiting, they will get a
        # AttributeError: '_Coroutine' object has no attribute 'text'
        # The response itself is only stored once __aenter__ has awaited the wrapped object.
        self._response: AsyncHTTPResponseType = cast(AsyncHTTPResponseType, None)

    def __await__(self) -> Generator[Any, None, AsyncHTTPResponseType]:
        # Delegate straight to the wrapped awaitable so "await wrapper" yields its result.
        return self._wrapped.__await__()

    async def __aenter__(self) -> AsyncHTTPResponseType:
        """Await the wrapped object and enter the resulting response's own async context.

        :return: The awaited response object.
        :rtype: AsyncHTTPResponseType
        """
        self._response = await self
        # __aexit__ below forwards to the response's __aexit__, so the response's __aenter__
        # must be called here to keep enter/exit paired as the async context manager protocol
        # requires. The response object itself is still what "async with ... as x" binds.
        await self._response.__aenter__()
        return self._response

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        """Exit the response's async context, forwarding any in-flight exception details.

        :param exc_type: The exception type, if an exception was raised inside the block.
        :param exc_value: The exception instance, if any.
        :param traceback: The traceback, if any.
        """
        await self._response.__aexit__(exc_type, exc_value, traceback)
120
121
class AsyncPipelineClient(
    PipelineClientBase,
    AsyncContextManager["AsyncPipelineClient"],
    Generic[HTTPRequestType, AsyncHTTPResponseType],
):
    """Service client core methods.

    Builds an AsyncPipeline client.

    :param str base_url: URL for the request.
    :keyword ~azure.core.configuration.Configuration config: If omitted, the standard configuration is used.
    :keyword AsyncPipeline pipeline: If omitted, an AsyncPipeline object is built from the configuration
        and the policy/transport keywords below.
    :keyword list[AsyncHTTPPolicy] policies: If omitted, the standard policies of the configuration object are used.
    :keyword per_call_policies: If specified, the policies will be added into the policy list before RetryPolicy
    :paramtype per_call_policies: Union[AsyncHTTPPolicy, SansIOHTTPPolicy,
        list[AsyncHTTPPolicy], list[SansIOHTTPPolicy]]
    :keyword per_retry_policies: If specified, the policies will be added into the policy list after RetryPolicy
    :paramtype per_retry_policies: Union[AsyncHTTPPolicy, SansIOHTTPPolicy,
        list[AsyncHTTPPolicy], list[SansIOHTTPPolicy]]
    :keyword AsyncHttpTransport transport: If omitted, AioHttpTransport is used for asynchronous transport.
    :raises ValueError: If no ``pipeline`` is given, and ``per_retry_policies`` is combined with a custom
        ``policies`` list that contains no AsyncRetryPolicy.

    .. admonition:: Example:

        .. literalinclude:: ../samples/test_example_async.py
            :start-after: [START build_async_pipeline_client]
            :end-before: [END build_async_pipeline_client]
            :language: python
            :dedent: 4
            :caption: Builds the async pipeline client.
    """

    def __init__(
        self,
        base_url: str,
        *,
        pipeline: Optional[AsyncPipeline[HTTPRequestType, AsyncHTTPResponseType]] = None,
        config: Optional[Configuration[HTTPRequestType, AsyncHTTPResponseType]] = None,
        **kwargs: Any,
    ):
        super().__init__(base_url)
        self._config: Configuration[HTTPRequestType, AsyncHTTPResponseType] = config or Configuration(**kwargs)
        self._base_url = base_url
        # An explicitly supplied pipeline is used as-is; _build_pipeline is only called without one.
        self._pipeline = pipeline or self._build_pipeline(self._config, **kwargs)

    async def __aenter__(
        self,
    ) -> AsyncPipelineClient[HTTPRequestType, AsyncHTTPResponseType]:
        """Enter the underlying pipeline's async context and return this client."""
        await self._pipeline.__aenter__()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        """Exit the underlying pipeline's async context, forwarding any exception details."""
        await self._pipeline.__aexit__(exc_type, exc_value, traceback)

    async def close(self) -> None:
        """Close the client; equivalent to leaving an ``async with`` block without an exception."""
        await self.__aexit__()

    def _build_pipeline(
        self,
        config: Configuration[HTTPRequestType, AsyncHTTPResponseType],
        *,
        policies=None,
        per_call_policies=None,
        per_retry_policies=None,
        **kwargs,
    ) -> AsyncPipeline[HTTPRequestType, AsyncHTTPResponseType]:
        """Assemble the policy list and transport into an AsyncPipeline.

        When ``policies`` is None, the default policy order is built from ``config``. The per-call
        policies go before the redirect/retry policies, and the per-retry policies go after them.
        When ``policies`` is given, per-call policies are prepended, and per-retry policies are
        inserted right after the last AsyncRetryPolicy in the list.

        :param config: Configuration supplying the default policies.
        :keyword policies: Full custom policy list, or None to use the defaults.
        :keyword per_call_policies: A single policy or an iterable of policies.
        :keyword per_retry_policies: A single policy or an iterable of policies.
        :return: The assembled pipeline.
        :raises ValueError: If per-retry policies are given with a custom list lacking an AsyncRetryPolicy.
        """

        def _as_list(value: Any) -> list:
            # Callers may pass either a single policy object or an iterable of policies.
            if isinstance(value, collections.abc.Iterable):
                return list(value)
            return [value]

        transport = kwargs.get("transport")
        per_call_policies = per_call_policies or []
        per_retry_policies = per_retry_policies or []

        if policies is None:  # [] is a valid policy list
            policies = [
                config.request_id_policy or RequestIdPolicy(**kwargs),
                config.headers_policy,
                config.user_agent_policy,
                config.proxy_policy,
                ContentDecodePolicy(**kwargs),
            ]
            policies.extend(_as_list(per_call_policies))
            policies.extend(
                [
                    config.redirect_policy,
                    config.retry_policy,
                    config.authentication_policy,
                    config.custom_hook_policy,
                ]
            )
            policies.extend(_as_list(per_retry_policies))
            policies.extend(
                [
                    config.logging_policy,
                    DistributedTracingPolicy(**kwargs),
                    # Added only when a redirect policy is configured. Otherwise a None entry is left
                    # in the list (presumably skipped by AsyncPipeline; confirm).
                    (SensitiveHeaderCleanupPolicy(**kwargs) if config.redirect_policy else None),
                    config.http_logging_policy or HttpLoggingPolicy(**kwargs),
                ]
            )
        else:
            policies = _as_list(per_call_policies) + list(policies)
            per_retry_policies_list = _as_list(per_retry_policies)
            if per_retry_policies_list:
                # Find the LAST AsyncRetryPolicy; the per-retry policies are inserted right after it.
                index_of_retry = -1
                for index, policy in enumerate(policies):
                    if isinstance(policy, AsyncRetryPolicy):
                        index_of_retry = index
                if index_of_retry == -1:
                    raise ValueError(
                        "Failed to add per_retry_policies; no RetryPolicy found in the supplied list of policies. "
                    )
                insert_at = index_of_retry + 1
                policies = policies[:insert_at] + per_retry_policies_list + policies[insert_at:]

        if not transport:
            # Use private import for better typing, mypy and pyright don't like PEP562
            from .pipeline.transport._aiohttp import AioHttpTransport

            transport = AioHttpTransport(**kwargs)

        return AsyncPipeline[HTTPRequestType, AsyncHTTPResponseType](transport, policies)

    async def _make_pipeline_call(self, request: HTTPRequestType, **kwargs) -> AsyncHTTPResponseType:
        """Run ``request`` through the pipeline and return its HTTP response.

        If the private ``_return_pipeline_response`` keyword is truthy, the whole pipeline
        response is returned instead of only its ``http_response``.
        """
        return_pipeline_response = kwargs.pop("_return_pipeline_response", False)
        pipeline_response = await self._pipeline.run(request, **kwargs)
        if return_pipeline_response:
            return pipeline_response  # type: ignore # This is a private API we don't want to type in signature
        return pipeline_response.http_response

    def send_request(
        self, request: HTTPRequestType, *, stream: bool = False, **kwargs: Any
    ) -> Awaitable[AsyncHTTPResponseType]:
        """Method that runs the network request through the client's chained policies.

        The returned object can either be awaited directly or used with ``async with``.

        >>> from azure.core.rest import HttpRequest
        >>> request = HttpRequest('GET', 'http://www.example.com')
        >>> request
        <HttpRequest [GET], url: 'http://www.example.com'>
        >>> response = await client.send_request(request)
        >>> response
        <AsyncHttpResponse: 200 OK>

        :param request: The network request you want to make. Required.
        :type request: ~azure.core.rest.HttpRequest
        :keyword bool stream: Whether the response payload will be streamed. Defaults to False.
        :return: The response of your network call. Does not do error handling on your response.
        :rtype: ~azure.core.rest.AsyncHttpResponse
        """
        # Not declared async: the _Coroutine wrapper supports both "await" and "async with".
        wrapped = self._make_pipeline_call(request, stream=stream, **kwargs)
        return _Coroutine(wrapped=wrapped)