1# -------------------------------------------------------------------------- 
    2# 
    3# Copyright (c) Microsoft Corporation. All rights reserved. 
    4# 
    5# The MIT License (MIT) 
    6# 
    7# Permission is hereby granted, free of charge, to any person obtaining a copy 
    8# of this software and associated documentation files (the ""Software""), to 
    9# deal in the Software without restriction, including without limitation the 
    10# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 
    11# sell copies of the Software, and to permit persons to whom the Software is 
    12# furnished to do so, subject to the following conditions: 
    13# 
    14# The above copyright notice and this permission notice shall be included in 
    15# all copies or substantial portions of the Software. 
    16# 
    17# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 
    18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 
    19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 
    20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 
    21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 
    22# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 
    23# IN THE SOFTWARE. 
    24# 
    25# -------------------------------------------------------------------------- 
    26from __future__ import annotations 
    27import logging 
    28import collections.abc 
    29from typing import ( 
    30    Any, 
    31    Awaitable, 
    32    TypeVar, 
    33    AsyncContextManager, 
    34    Generator, 
    35    Generic, 
    36    Optional, 
    37    Type, 
    38    cast, 
    39) 
    40from types import TracebackType 
    41from .configuration import Configuration 
    42from .pipeline import AsyncPipeline 
    43from .pipeline.transport._base import PipelineClientBase 
    44from .pipeline.policies import ( 
    45    ContentDecodePolicy, 
    46    DistributedTracingPolicy, 
    47    HttpLoggingPolicy, 
    48    RequestIdPolicy, 
    49    AsyncRetryPolicy, 
    50    SensitiveHeaderCleanupPolicy, 
    51) 
    52 
    53 
    54HTTPRequestType = TypeVar("HTTPRequestType") 
    55AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType", bound="AsyncContextManager") 
    56 
    57_LOGGER = logging.getLogger(__name__) 
    58 
    59 
    60class _Coroutine(Awaitable[AsyncHTTPResponseType]): 
    61    """Wrapper to get both context manager and awaitable in place. 
    62 
    63    Naming it "_Coroutine" because if you don't await it makes the error message easier: 
    64    >>> result = client.send_request(request) 
    65    >>> result.text() 
    66    AttributeError: '_Coroutine' object has no attribute 'text' 
    67 
    68    Indeed, the message for calling a coroutine without waiting would be: 
    69    AttributeError: 'coroutine' object has no attribute 'text' 
    70 
    71    This allows the dev to either use the "async with" syntax, or simply the object directly. 
    72    It's also why "send_request" is not declared as async, since it couldn't be both easily. 
    73 
    74    "wrapped" must be an awaitable object that returns an object implements the async context manager protocol. 
    75 
    76    This permits this code to work for both following requests. 
    77 
    78    ```python 
    79    from azure.core import AsyncPipelineClient 
    80    from azure.core.rest import HttpRequest 
    81 
    82    async def main(): 
    83 
    84        request = HttpRequest("GET", "https://httpbin.org/user-agent") 
    85        async with AsyncPipelineClient("https://httpbin.org/") as client: 
    86            # Can be used directly 
    87            result = await client.send_request(request) 
    88            print(result.text()) 
    89 
    90            # Can be used as an async context manager 
    91            async with client.send_request(request) as result: 
    92                print(result.text()) 
    93    ``` 
    94 
    95    :param wrapped: Must be an awaitable the returns an async context manager that supports async "close()" 
    96    :type wrapped: awaitable[AsyncHTTPResponseType] 
    97    """ 
    98 
    99    def __init__(self, wrapped: Awaitable[AsyncHTTPResponseType]) -> None: 
    100        super().__init__() 
    101        self._wrapped = wrapped 
    102        # If someone tries to use the object without awaiting, they will get a 
    103        # AttributeError: '_Coroutine' object has no attribute 'text' 
    104        self._response: AsyncHTTPResponseType = cast(AsyncHTTPResponseType, None) 
    105 
    106    def __await__(self) -> Generator[Any, None, AsyncHTTPResponseType]: 
    107        return self._wrapped.__await__() 
    108 
    109    async def __aenter__(self) -> AsyncHTTPResponseType: 
    110        self._response = await self 
    111        return self._response 
    112 
    113    async def __aexit__( 
    114        self, 
    115        exc_type: Optional[Type[BaseException]] = None, 
    116        exc_value: Optional[BaseException] = None, 
    117        traceback: Optional[TracebackType] = None, 
    118    ) -> None: 
    119        await self._response.__aexit__(exc_type, exc_value, traceback) 
    120 
    121 
    122class AsyncPipelineClient( 
    123    PipelineClientBase, 
    124    AsyncContextManager["AsyncPipelineClient"], 
    125    Generic[HTTPRequestType, AsyncHTTPResponseType], 
    126): 
    127    """Service client core methods. 
    128 
    129    Builds an AsyncPipeline client. 
    130 
    131    :param str base_url: URL for the request. 
    132    :keyword ~azure.core.configuration.Configuration config: If omitted, the standard configuration is used. 
    133    :keyword Pipeline pipeline: If omitted, a Pipeline object is created and returned. 
    134    :keyword list[AsyncHTTPPolicy] policies: If omitted, the standard policies of the configuration object is used. 
    135    :keyword per_call_policies: If specified, the policies will be added into the policy list before RetryPolicy 
    136    :paramtype per_call_policies: Union[AsyncHTTPPolicy, SansIOHTTPPolicy, 
    137        list[AsyncHTTPPolicy], list[SansIOHTTPPolicy]] 
    138    :keyword per_retry_policies: If specified, the policies will be added into the policy list after RetryPolicy 
    139    :paramtype per_retry_policies: Union[AsyncHTTPPolicy, SansIOHTTPPolicy, 
    140        list[AsyncHTTPPolicy], list[SansIOHTTPPolicy]] 
    141    :keyword AsyncHttpTransport transport: If omitted, AioHttpTransport is used for asynchronous transport. 
    142    :return: An async pipeline object. 
    143    :rtype: ~azure.core.pipeline.AsyncPipeline 
    144 
    145    .. admonition:: Example: 
    146 
    147        .. literalinclude:: ../samples/test_example_async.py 
    148            :start-after: [START build_async_pipeline_client] 
    149            :end-before: [END build_async_pipeline_client] 
    150            :language: python 
    151            :dedent: 4 
    152            :caption: Builds the async pipeline client. 
    153    """ 
    154 
    155    def __init__( 
    156        self, 
    157        base_url: str, 
    158        *, 
    159        pipeline: Optional[AsyncPipeline[HTTPRequestType, AsyncHTTPResponseType]] = None, 
    160        config: Optional[Configuration[HTTPRequestType, AsyncHTTPResponseType]] = None, 
    161        **kwargs: Any, 
    162    ): 
    163        super(AsyncPipelineClient, self).__init__(base_url) 
    164        self._config: Configuration[HTTPRequestType, AsyncHTTPResponseType] = config or Configuration(**kwargs) 
    165        self._base_url = base_url 
    166        self._pipeline = pipeline or self._build_pipeline(self._config, **kwargs) 
    167 
    168    async def __aenter__( 
    169        self, 
    170    ) -> AsyncPipelineClient[HTTPRequestType, AsyncHTTPResponseType]: 
    171        await self._pipeline.__aenter__() 
    172        return self 
    173 
    174    async def __aexit__( 
    175        self, 
    176        exc_type: Optional[Type[BaseException]] = None, 
    177        exc_value: Optional[BaseException] = None, 
    178        traceback: Optional[TracebackType] = None, 
    179    ) -> None: 
    180        await self._pipeline.__aexit__(exc_type, exc_value, traceback) 
    181 
    182    async def close(self) -> None: 
    183        await self.__aexit__() 
    184 
    185    def _build_pipeline( 
    186        self, 
    187        config: Configuration[HTTPRequestType, AsyncHTTPResponseType], 
    188        *, 
    189        policies=None, 
    190        per_call_policies=None, 
    191        per_retry_policies=None, 
    192        **kwargs, 
    193    ) -> AsyncPipeline[HTTPRequestType, AsyncHTTPResponseType]: 
    194        transport = kwargs.get("transport") 
    195        per_call_policies = per_call_policies or [] 
    196        per_retry_policies = per_retry_policies or [] 
    197 
    198        if policies is None:  # [] is a valid policy list 
    199            policies = [ 
    200                config.request_id_policy or RequestIdPolicy(**kwargs), 
    201                config.headers_policy, 
    202                config.user_agent_policy, 
    203                config.proxy_policy, 
    204                ContentDecodePolicy(**kwargs), 
    205            ] 
    206            if isinstance(per_call_policies, collections.abc.Iterable): 
    207                policies.extend(per_call_policies) 
    208            else: 
    209                policies.append(per_call_policies) 
    210 
    211            policies.extend( 
    212                [ 
    213                    config.redirect_policy, 
    214                    config.retry_policy, 
    215                    config.authentication_policy, 
    216                    config.custom_hook_policy, 
    217                ] 
    218            ) 
    219            if isinstance(per_retry_policies, collections.abc.Iterable): 
    220                policies.extend(per_retry_policies) 
    221            else: 
    222                policies.append(per_retry_policies) 
    223 
    224            policies.extend( 
    225                [ 
    226                    config.logging_policy, 
    227                    DistributedTracingPolicy(**kwargs), 
    228                    (SensitiveHeaderCleanupPolicy(**kwargs) if config.redirect_policy else None), 
    229                    config.http_logging_policy or HttpLoggingPolicy(**kwargs), 
    230                ] 
    231            ) 
    232        else: 
    233            if isinstance(per_call_policies, collections.abc.Iterable): 
    234                per_call_policies_list = list(per_call_policies) 
    235            else: 
    236                per_call_policies_list = [per_call_policies] 
    237            per_call_policies_list.extend(policies) 
    238            policies = per_call_policies_list 
    239            if isinstance(per_retry_policies, collections.abc.Iterable): 
    240                per_retry_policies_list = list(per_retry_policies) 
    241            else: 
    242                per_retry_policies_list = [per_retry_policies] 
    243            if len(per_retry_policies_list) > 0: 
    244                index_of_retry = -1 
    245                for index, policy in enumerate(policies): 
    246                    if isinstance(policy, AsyncRetryPolicy): 
    247                        index_of_retry = index 
    248                if index_of_retry == -1: 
    249                    raise ValueError( 
    250                        "Failed to add per_retry_policies; no RetryPolicy found in the supplied list of policies. " 
    251                    ) 
    252                policies_1 = policies[: index_of_retry + 1] 
    253                policies_2 = policies[index_of_retry + 1 :] 
    254                policies_1.extend(per_retry_policies_list) 
    255                policies_1.extend(policies_2) 
    256                policies = policies_1 
    257 
    258        if not transport: 
    259            # Use private import for better typing, mypy and pyright don't like PEP562 
    260            from .pipeline.transport._aiohttp import AioHttpTransport 
    261 
    262            transport = AioHttpTransport(**kwargs) 
    263 
    264        return AsyncPipeline[HTTPRequestType, AsyncHTTPResponseType](transport, policies) 
    265 
    266    async def _make_pipeline_call(self, request: HTTPRequestType, **kwargs) -> AsyncHTTPResponseType: 
    267        return_pipeline_response = kwargs.pop("_return_pipeline_response", False) 
    268        pipeline_response = await self._pipeline.run(request, **kwargs) 
    269        if return_pipeline_response: 
    270            return pipeline_response  # type: ignore  # This is a private API we don't want to type in signature 
    271        return pipeline_response.http_response 
    272 
    273    def send_request( 
    274        self, request: HTTPRequestType, *, stream: bool = False, **kwargs: Any 
    275    ) -> Awaitable[AsyncHTTPResponseType]: 
    276        """Method that runs the network request through the client's chained policies. 
    277 
    278        >>> from azure.core.rest import HttpRequest 
    279        >>> request = HttpRequest('GET', 'http://www.example.com') 
    280        <HttpRequest [GET], url: 'http://www.example.com'> 
    281        >>> response = await client.send_request(request) 
    282        <AsyncHttpResponse: 200 OK> 
    283 
    284        :param request: The network request you want to make. Required. 
    285        :type request: ~azure.core.rest.HttpRequest 
    286        :keyword bool stream: Whether the response payload will be streamed. Defaults to False. 
    287        :return: The response of your network call. Does not do error handling on your response. 
    288        :rtype: ~azure.core.rest.AsyncHttpResponse 
    289        """ 
    290        wrapped = self._make_pipeline_call(request, stream=stream, **kwargs) 
    291        return _Coroutine(wrapped=wrapped)