1# -------------------------------------------------------------------------- 
    2# 
    3# Copyright (c) Microsoft Corporation. All rights reserved. 
    4# 
    5# The MIT License (MIT) 
    6# 
    7# Permission is hereby granted, free of charge, to any person obtaining a copy 
    8# of this software and associated documentation files (the ""Software""), to 
    9# deal in the Software without restriction, including without limitation the 
    10# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 
    11# sell copies of the Software, and to permit persons to whom the Software is 
    12# furnished to do so, subject to the following conditions: 
    13# 
    14# The above copyright notice and this permission notice shall be included in 
    15# all copies or substantial portions of the Software. 
    16# 
    17# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 
    18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 
    19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 
    20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 
    21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 
    22# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 
    23# IN THE SOFTWARE. 
    24# 
    25# -------------------------------------------------------------------------- 
    26from __future__ import annotations 
    27from types import TracebackType 
    28from typing import ( 
    29    Any, 
    30    Union, 
    31    Generic, 
    32    TypeVar, 
    33    List, 
    34    Dict, 
    35    Optional, 
    36    Iterable, 
    37    Type, 
    38    AsyncContextManager, 
    39) 
    40 
    41from azure.core.pipeline import PipelineRequest, PipelineResponse, PipelineContext 
    42from azure.core.pipeline.policies import AsyncHTTPPolicy, SansIOHTTPPolicy 
    43from ._tools_async import await_result as _await_result 
    44from ._base import cleanup_kwargs_for_transport 
    45from .transport import AsyncHttpTransport 
    46 
# Generic type parameter for the response type produced by the async transport
# and carried through PipelineResponse objects in this module.
AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType")
# Generic type parameter for the HTTP request type that flows through the
# pipeline (wrapped in PipelineRequest objects by AsyncPipeline.run).
HTTPRequestType = TypeVar("HTTPRequestType")
    49 
    50 
    51class _SansIOAsyncHTTPPolicyRunner(AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]): 
    52    """Async implementation of the SansIO policy. 
    53 
    54    Modifies the request and sends to the next policy in the chain. 
    55 
    56    :param policy: A SansIO policy. 
    57    :type policy: ~azure.core.pipeline.policies.SansIOHTTPPolicy 
    58    """ 
    59 
    60    def __init__(self, policy: SansIOHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]) -> None: 
    61        super(_SansIOAsyncHTTPPolicyRunner, self).__init__() 
    62        self._policy = policy 
    63 
    64    async def send( 
    65        self, request: PipelineRequest[HTTPRequestType] 
    66    ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]: 
    67        """Modifies the request and sends to the next policy in the chain. 
    68 
    69        :param request: The PipelineRequest object. 
    70        :type request: ~azure.core.pipeline.PipelineRequest 
    71        :return: The PipelineResponse object. 
    72        :rtype: ~azure.core.pipeline.PipelineResponse 
    73        """ 
    74        await _await_result(self._policy.on_request, request) 
    75        response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType] 
    76        try: 
    77            response = await self.next.send(request) 
    78        except Exception: 
    79            await _await_result(self._policy.on_exception, request) 
    80            raise 
    81        await _await_result(self._policy.on_response, request, response) 
    82        return response 
    83 
    84 
    85class _AsyncTransportRunner(AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]): 
    86    """Async Transport runner. 
    87 
    88    Uses specified HTTP transport type to send request and returns response. 
    89 
    90    :param sender: The async Http Transport instance. 
    91    :type sender: ~azure.core.pipeline.transport.AsyncHttpTransport 
    92    """ 
    93 
    94    def __init__(self, sender: AsyncHttpTransport[HTTPRequestType, AsyncHTTPResponseType]) -> None: 
    95        super(_AsyncTransportRunner, self).__init__() 
    96        self._sender = sender 
    97 
    98    async def send( 
    99        self, request: PipelineRequest[HTTPRequestType] 
    100    ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]: 
    101        """Async HTTP transport send method. 
    102 
    103        :param request: The PipelineRequest object. 
    104        :type request: ~azure.core.pipeline.PipelineRequest 
    105        :return: The PipelineResponse object. 
    106        :rtype: ~azure.core.pipeline.PipelineResponse 
    107        """ 
    108        cleanup_kwargs_for_transport(request.context.options) 
    109        return PipelineResponse( 
    110            request.http_request, 
    111            await self._sender.send(request.http_request, **request.context.options), 
    112            request.context, 
    113        ) 
    114 
    115 
    116class AsyncPipeline( 
    117    AsyncContextManager["AsyncPipeline"], 
    118    Generic[HTTPRequestType, AsyncHTTPResponseType], 
    119): 
    120    """Async pipeline implementation. 
    121 
    122    This is implemented as a context manager, that will activate the context 
    123    of the HTTP sender. 
    124 
    125    :param transport: The async Http Transport instance. 
    126    :type transport: ~azure.core.pipeline.transport.AsyncHttpTransport 
    127    :param list policies: List of configured policies. 
    128 
    129    .. admonition:: Example: 
    130 
    131        .. literalinclude:: ../samples/test_example_async.py 
    132            :start-after: [START build_async_pipeline] 
    133            :end-before: [END build_async_pipeline] 
    134            :language: python 
    135            :dedent: 4 
    136            :caption: Builds the async pipeline for asynchronous transport. 
    137    """ 
    138 
    139    def __init__( 
    140        self, 
    141        transport: AsyncHttpTransport[HTTPRequestType, AsyncHTTPResponseType], 
    142        policies: Optional[ 
    143            Iterable[ 
    144                Union[ 
    145                    AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType], 
    146                    SansIOHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType], 
    147                ] 
    148            ] 
    149        ] = None, 
    150    ) -> None: 
    151        self._impl_policies: List[AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]] = [] 
    152        self._transport = transport 
    153 
    154        for policy in policies or []: 
    155            if isinstance(policy, SansIOHTTPPolicy): 
    156                self._impl_policies.append(_SansIOAsyncHTTPPolicyRunner(policy)) 
    157            elif policy: 
    158                self._impl_policies.append(policy) 
    159        for index in range(len(self._impl_policies) - 1): 
    160            self._impl_policies[index].next = self._impl_policies[index + 1] 
    161        if self._impl_policies: 
    162            self._impl_policies[-1].next = _AsyncTransportRunner(self._transport) 
    163 
    164    async def __aenter__(self) -> AsyncPipeline[HTTPRequestType, AsyncHTTPResponseType]: 
    165        await self._transport.__aenter__() 
    166        return self 
    167 
    168    async def __aexit__( 
    169        self, 
    170        exc_type: Optional[Type[BaseException]] = None, 
    171        exc_value: Optional[BaseException] = None, 
    172        traceback: Optional[TracebackType] = None, 
    173    ) -> None: 
    174        await self._transport.__aexit__(exc_type, exc_value, traceback) 
    175 
    176    async def _prepare_multipart_mixed_request(self, request: HTTPRequestType) -> None: 
    177        """Will execute the multipart policies. 
    178 
    179        Does nothing if "set_multipart_mixed" was never called. 
    180 
    181        :param request: The HTTP request object. 
    182        :type request: ~azure.core.rest.HttpRequest 
    183        """ 
    184        multipart_mixed_info = request.multipart_mixed_info  # type: ignore 
    185        if not multipart_mixed_info: 
    186            return 
    187 
    188        requests: List[HTTPRequestType] = multipart_mixed_info[0] 
    189        policies: List[SansIOHTTPPolicy] = multipart_mixed_info[1] 
    190        pipeline_options: Dict[str, Any] = multipart_mixed_info[3] 
    191 
    192        async def prepare_requests(req): 
    193            if req.multipart_mixed_info: 
    194                # Recursively update changeset "sub requests" 
    195                await self._prepare_multipart_mixed_request(req) 
    196            context = PipelineContext(None, **pipeline_options) 
    197            pipeline_request = PipelineRequest(req, context) 
    198            for policy in policies: 
    199                await _await_result(policy.on_request, pipeline_request) 
    200 
    201        # Not happy to make this code asyncio specific, but that's multipart only for now 
    202        # If we need trio and multipart, let's reinvesitgate that later 
    203        import asyncio  # pylint: disable=do-not-import-asyncio 
    204 
    205        await asyncio.gather(*[prepare_requests(req) for req in requests]) 
    206 
    207    async def _prepare_multipart(self, request: HTTPRequestType) -> None: 
    208        # This code is fine as long as HTTPRequestType is actually 
    209        # azure.core.pipeline.transport.HTTPRequest, bu we don't check it in here 
    210        # since we didn't see (yet) pipeline usage where it's not this actual instance 
    211        # class used 
    212        await self._prepare_multipart_mixed_request(request) 
    213        request.prepare_multipart_body()  # type: ignore 
    214 
    215    async def run( 
    216        self, request: HTTPRequestType, **kwargs: Any 
    217    ) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]: 
    218        """Runs the HTTP Request through the chained policies. 
    219 
    220        :param request: The HTTP request object. 
    221        :type request: ~azure.core.pipeline.transport.HttpRequest 
    222        :return: The PipelineResponse object. 
    223        :rtype: ~azure.core.pipeline.PipelineResponse 
    224        """ 
    225        await self._prepare_multipart(request) 
    226        context = PipelineContext(self._transport, **kwargs) 
    227        pipeline_request = PipelineRequest(request, context) 
    228        first_node = self._impl_policies[0] if self._impl_policies else _AsyncTransportRunner(self._transport) 
    229        return await first_node.send(pipeline_request)