1import enum 
    2import logging 
    3import ssl 
    4import time 
    5from types import TracebackType 
    6from typing import ( 
    7    Any, 
    8    AsyncIterable, 
    9    AsyncIterator, 
    10    List, 
    11    Optional, 
    12    Tuple, 
    13    Type, 
    14    Union, 
    15) 
    16 
    17import h11 
    18 
    19from .._backends.base import AsyncNetworkStream 
    20from .._exceptions import ( 
    21    ConnectionNotAvailable, 
    22    LocalProtocolError, 
    23    RemoteProtocolError, 
    24    WriteError, 
    25    map_exceptions, 
    26) 
    27from .._models import Origin, Request, Response 
    28from .._synchronization import AsyncLock, AsyncShieldCancellation 
    29from .._trace import Trace 
    30from .interfaces import AsyncConnectionInterface 
    31 
    # Shared by every `Trace(...)` call in this module.
    logger = logging.getLogger("httpcore.http11")


    # The subset of `h11.Event` types this module passes to `_send_event`:
    # request heads, body chunks, and the end-of-message marker.
    H11SendEvent = Union[h11.Request, h11.Data, h11.EndOfMessage]
    41 
    42 
    class HTTPConnectionState(enum.IntEnum):
        """Lifecycle states of an HTTP/1.1 connection.

        A request may begin only from NEW or IDLE; the connection is ACTIVE while
        a request/response cycle is in progress and CLOSED once torn down.
        """

        NEW = 0  # created, no request sent yet
        ACTIVE = 1  # a request/response cycle is in progress
        IDLE = 2  # previous cycle finished cleanly; reusable
        CLOSED = 3  # network stream has been closed
    48 
    49 
    class AsyncHTTP11Connection(AsyncConnectionInterface):
        """A single HTTP/1.1 client connection driven by an `h11.Connection`.

        One request/response cycle runs at a time. A request may only start from
        the NEW or IDLE state; the connection is ACTIVE while the request is in
        flight, and becomes IDLE again (or is closed) once the response is closed.
        """

        # Maximum number of bytes requested from the network stream per read.
        READ_NUM_BYTES = 64 * 1024
        # Passed to h11 as `max_incomplete_event_size`: the cap on buffered bytes
        # while h11 waits for a complete event (such as a response head).
        MAX_INCOMPLETE_EVENT_SIZE = 100 * 1024

        def __init__(
            self,
            origin: Origin,
            stream: AsyncNetworkStream,
            keepalive_expiry: Optional[float] = None,
        ) -> None:
            """Wrap an already-established network stream to `origin`.

            Args:
                origin: The only origin this connection will accept requests for.
                stream: The network stream used to read and write bytes.
                keepalive_expiry: Seconds after a cleanly finished response at
                    which `has_expired()` starts reporting True; None disables
                    this time-based expiry.
            """
            self._origin = origin
            self._network_stream = stream
            self._keepalive_expiry: Optional[float] = keepalive_expiry
            self._expire_at: Optional[float] = None
            self._state = HTTPConnectionState.NEW
            self._state_lock = AsyncLock()
            self._request_count = 0
            self._h11_state = h11.Connection(
                our_role=h11.CLIENT,
                max_incomplete_event_size=self.MAX_INCOMPLETE_EVENT_SIZE,
            )

        async def handle_async_request(self, request: Request) -> Response:
            """Send `request` over this connection and return the response.

            The response body is streamed lazily through an
            `HTTP11ConnectionByteStream`; closing that stream releases the
            connection.

            Raises:
                RuntimeError: If `request` targets a different origin.
                ConnectionNotAvailable: If the connection is not NEW or IDLE.
            """
            if not self.can_handle_request(request.url.origin):
                raise RuntimeError(
                    f"Attempted to send request to {request.url.origin} on connection "
                    f"to {self._origin}"
                )

            async with self._state_lock:
                if self._state in (HTTPConnectionState.NEW, HTTPConnectionState.IDLE):
                    self._request_count += 1
                    self._state = HTTPConnectionState.ACTIVE
                    self._expire_at = None
                else:
                    raise ConnectionNotAvailable()

            try:
                kwargs = {"request": request}
                try:
                    async with Trace("send_request_headers", logger, request, kwargs):
                        await self._send_request_headers(**kwargs)
                    async with Trace("send_request_body", logger, request, kwargs):
                        await self._send_request_body(**kwargs)
                except WriteError:
                    # If we get a write error while we're writing the request,
                    # then we suppress this error and move on to attempting to
                    # read the response. Servers can sometimes close the request
                    # pre-emptively and then respond with a well formed HTTP
                    # error response.
                    pass

                async with Trace(
                    "receive_response_headers", logger, request, kwargs
                ) as trace:
                    (
                        http_version,
                        status,
                        reason_phrase,
                        headers,
                        trailing_data,
                    ) = await self._receive_response_headers(**kwargs)
                    # trailing_data is not part of the traced return value.
                    trace.return_value = (
                        http_version,
                        status,
                        reason_phrase,
                        headers,
                    )

                network_stream = self._network_stream

                # CONNECT or Upgrade request: expose a stream that first replays
                # the bytes h11 had already received beyond the response head.
                if (status == 101) or (
                    (request.method == b"CONNECT") and (200 <= status < 300)
                ):
                    network_stream = AsyncHTTP11UpgradeStream(network_stream, trailing_data)

                return Response(
                    status=status,
                    headers=headers,
                    content=HTTP11ConnectionByteStream(self, request),
                    extensions={
                        "http_version": http_version,
                        "reason_phrase": reason_phrase,
                        "network_stream": network_stream,
                    },
                )
            except BaseException as exc:
                # Any failure, including cancellation, must release the
                # connection state; the cleanup runs under AsyncShieldCancellation.
                with AsyncShieldCancellation():
                    async with Trace("response_closed", logger, request):
                        await self._response_closed()
                raise exc

        # Sending the request...

        async def _send_request_headers(self, request: Request) -> None:
            """Build the h11 request head for `request` and write it out.

            Raises:
                LocalProtocolError: If h11 rejects the method, target or headers.
            """
            timeouts = request.extensions.get("timeout", {})
            timeout = timeouts.get("write", None)

            with map_exceptions({h11.LocalProtocolError: LocalProtocolError}):
                event = h11.Request(
                    method=request.method,
                    target=request.url.target,
                    headers=request.headers,
                )
            await self._send_event(event, timeout=timeout)

        async def _send_request_body(self, request: Request) -> None:
            """Write each chunk of the request stream, then end the message."""
            timeouts = request.extensions.get("timeout", {})
            timeout = timeouts.get("write", None)

            assert isinstance(request.stream, AsyncIterable)
            async for chunk in request.stream:
                event = h11.Data(data=chunk)
                await self._send_event(event, timeout=timeout)

            await self._send_event(h11.EndOfMessage(), timeout=timeout)

        async def _send_event(
            self, event: H11SendEvent, timeout: Optional[float] = None
        ) -> None:
            """Feed `event` to h11 and write any bytes it produces.

            h11 rejects events that are invalid in the current state (e.g. body
            data that disagrees with a declared Content-Length) by raising
            `h11.LocalProtocolError`. That error is translated to httpcore's
            `LocalProtocolError`, the same mapping `_send_request_headers` uses,
            so raw h11 exceptions do not leak to callers.
            """
            with map_exceptions({h11.LocalProtocolError: LocalProtocolError}):
                bytes_to_send = self._h11_state.send(event)
            if bytes_to_send is not None:
                await self._network_stream.write(bytes_to_send, timeout=timeout)

        # Receiving the response...

        async def _receive_response_headers(
            self, request: Request
        ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], bytes]:
            """Read up to the final response head.

            Interim 1xx responses are skipped, except 101 (Switching Protocols),
            which is returned like a final response.

            Returns:
                (http_version, status_code, reason, raw headers, trailing_data).
                trailing_data holds bytes already received past the head.
            """
            timeouts = request.extensions.get("timeout", {})
            timeout = timeouts.get("read", None)

            while True:
                event = await self._receive_event(timeout=timeout)
                if isinstance(event, h11.Response):
                    break
                if (
                    isinstance(event, h11.InformationalResponse)
                    and event.status_code == 101
                ):
                    break

            http_version = b"HTTP/" + event.http_version

            # h11 version 0.11+ supports a `raw_items` interface to get the
            # raw header casing, rather than the enforced lowercase headers.
            headers = event.headers.raw_items()

            trailing_data, _ = self._h11_state.trailing_data

            return http_version, event.status_code, event.reason, headers, trailing_data

        async def _receive_response_body(self, request: Request) -> AsyncIterator[bytes]:
            """Yield response body chunks until end-of-message (or h11 pauses)."""
            timeouts = request.extensions.get("timeout", {})
            timeout = timeouts.get("read", None)

            while True:
                event = await self._receive_event(timeout=timeout)
                if isinstance(event, h11.Data):
                    yield bytes(event.data)
                elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)):
                    break

        async def _receive_event(
            self, timeout: Optional[float] = None
        ) -> Union[h11.Event, Type[h11.PAUSED]]:
            """Return the next h11 event, reading from the network as needed.

            Raises:
                RemoteProtocolError: On an h11 protocol violation by the server,
                    or if the server disconnects before sending a response.
            """
            while True:
                with map_exceptions({h11.RemoteProtocolError: RemoteProtocolError}):
                    event = self._h11_state.next_event()

                if event is h11.NEED_DATA:
                    data = await self._network_stream.read(
                        self.READ_NUM_BYTES, timeout=timeout
                    )

                    # If we feed this case through h11 we'll raise an exception like:
                    #
                    #     httpcore.RemoteProtocolError: can't handle event type
                    #     ConnectionClosed when role=SERVER and state=SEND_RESPONSE
                    #
                    # Which is accurate, but not very informative from an end-user
                    # perspective. Instead we handle this case distinctly and raise
                    # a RemoteProtocolError with a clearer message.
                    if data == b"" and self._h11_state.their_state == h11.SEND_RESPONSE:
                        msg = "Server disconnected without sending a response."
                        raise RemoteProtocolError(msg)

                    self._h11_state.receive_data(data)
                else:
                    # mypy fails to narrow the type in the above if statement above
                    return event  # type: ignore[return-value]

        async def _response_closed(self) -> None:
            """Finish the current cycle: reuse the connection if clean, else close.

            If both sides of the h11 state machine reached DONE, the connection
            becomes IDLE, h11 is reset for the next cycle, and the keep-alive
            deadline is set. Otherwise the connection is closed.
            """
            async with self._state_lock:
                if (
                    self._h11_state.our_state is h11.DONE
                    and self._h11_state.their_state is h11.DONE
                ):
                    self._state = HTTPConnectionState.IDLE
                    self._h11_state.start_next_cycle()
                    if self._keepalive_expiry is not None:
                        now = time.monotonic()
                        self._expire_at = now + self._keepalive_expiry
                else:
                    await self.aclose()

        # Once the connection is no longer required...

        async def aclose(self) -> None:
            """Mark the connection CLOSED and close the network stream."""
            # Note that this method unilaterally closes the connection, and does
            # not have any kind of locking in place around it.
            self._state = HTTPConnectionState.CLOSED
            await self._network_stream.aclose()

        # The AsyncConnectionInterface methods provide information about the state of
        # the connection, allowing for a connection pooling implementation to
        # determine when to reuse and when to close the connection...

        def can_handle_request(self, origin: Origin) -> bool:
            """Return True if `origin` is the origin this connection serves."""
            return origin == self._origin

        def is_available(self) -> bool:
            """Return True if the connection is IDLE and can take a new request."""
            # Note that HTTP/1.1 connections in the "NEW" state are not treated as
            # being "available". The control flow which created the connection will
            # be able to send an outgoing request, but the connection will not be
            # acquired from the connection pool for any other request.
            return self._state == HTTPConnectionState.IDLE

        def has_expired(self) -> bool:
            """Return True if the keep-alive deadline passed or the server hung up."""
            now = time.monotonic()
            keepalive_expired = self._expire_at is not None and now > self._expire_at

            # If the HTTP connection is idle but the socket is readable, then the
            # only valid state is that the socket is about to return b"", indicating
            # a server-initiated disconnect. `get_extra_info` may return a
            # non-bool (e.g. None), so coerce to honour the `-> bool` contract.
            server_disconnected = bool(
                self._state == HTTPConnectionState.IDLE
                and self._network_stream.get_extra_info("is_readable")
            )

            return keepalive_expired or server_disconnected

        def is_idle(self) -> bool:
            """Return True if the connection is in the IDLE state."""
            return self._state == HTTPConnectionState.IDLE

        def is_closed(self) -> bool:
            """Return True if the connection has been closed."""
            return self._state == HTTPConnectionState.CLOSED

        def info(self) -> str:
            """Return a short human-readable summary of the connection."""
            origin = str(self._origin)
            return (
                f"{origin!r}, HTTP/1.1, {self._state.name}, "
                f"Request Count: {self._request_count}"
            )

        def __repr__(self) -> str:
            class_name = self.__class__.__name__
            origin = str(self._origin)
            return (
                f"<{class_name} [{origin!r}, {self._state.name}, "
                f"Request Count: {self._request_count}]>"
            )

        # These context managers are not used in the standard flow, but are
        # useful for testing or working with connection instances directly.

        async def __aenter__(self) -> "AsyncHTTP11Connection":
            return self

        async def __aexit__(
            self,
            exc_type: Optional[Type[BaseException]] = None,
            exc_value: Optional[BaseException] = None,
            traceback: Optional[TracebackType] = None,
        ) -> None:
            """Close the connection on exit, whether or not an error occurred."""
            await self.aclose()
    329 
    330 
    class HTTP11ConnectionByteStream:
        """Async iterable over a response body read from an HTTP/1.1 connection.

        Closing the stream, either explicitly or because iteration raised,
        notifies the owning connection via `_response_closed`, at most once.
        """

        def __init__(self, connection: AsyncHTTP11Connection, request: Request) -> None:
            self._connection = connection
            self._request = request
            # Ensures `aclose` notifies the connection only once.
            self._closed = False

        async def __aiter__(self) -> AsyncIterator[bytes]:
            """Yield body chunks; on any exception, close the response first."""
            trace_kwargs = {"request": self._request}
            try:
                async with Trace(
                    "receive_response_body", logger, self._request, trace_kwargs
                ):
                    async for piece in self._connection._receive_response_body(
                        **trace_kwargs
                    ):
                        yield piece
            except BaseException as exc:
                # Release the response (and possibly the connection) before the
                # exception propagates; the cleanup is shielded from cancellation.
                with AsyncShieldCancellation():
                    await self.aclose()
                raise exc

        async def aclose(self) -> None:
            """Release the response; calls after the first are no-ops."""
            if self._closed:
                return
            self._closed = True
            async with Trace("response_closed", logger, self._request):
                await self._connection._response_closed()
    356 
    357 
    class AsyncHTTP11UpgradeStream(AsyncNetworkStream):
        """Network stream that hands out pre-buffered bytes before reading more.

        In this module it wraps the connection's stream after a 101 response or a
        successful CONNECT, so bytes h11 already received past the response head
        are returned first. Every other operation delegates to the wrapped stream.
        """

        def __init__(self, stream: AsyncNetworkStream, leading_data: bytes) -> None:
            self._stream = stream
            # Bytes still owed to the reader before touching `stream`.
            self._leading_data = leading_data

        async def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes:
            """Return up to `max_bytes`, draining the buffered bytes first."""
            if not self._leading_data:
                return await self._stream.read(max_bytes, timeout)
            head = self._leading_data[:max_bytes]
            rest = self._leading_data[max_bytes:]
            self._leading_data = rest
            return head

        async def write(self, buffer: bytes, timeout: Optional[float] = None) -> None:
            """Write straight through to the wrapped stream."""
            await self._stream.write(buffer, timeout)

        async def aclose(self) -> None:
            """Close the wrapped stream."""
            await self._stream.aclose()

        async def start_tls(
            self,
            ssl_context: ssl.SSLContext,
            server_hostname: Optional[str] = None,
            timeout: Optional[float] = None,
        ) -> AsyncNetworkStream:
            """Delegate TLS upgrade to the wrapped stream and return its result."""
            return await self._stream.start_tls(ssl_context, server_hostname, timeout)

        def get_extra_info(self, info: str) -> Any:
            """Delegate extra-info lookups to the wrapped stream."""
            return self._stream.get_extra_info(info)