1from __future__ import annotations 
    2 
    3import enum 
    4import logging 
    5import ssl 
    6import time 
    7import types 
    8import typing 
    9 
    10import h11 
    11 
    12from .._backends.base import AsyncNetworkStream 
    13from .._exceptions import ( 
    14    ConnectionNotAvailable, 
    15    LocalProtocolError, 
    16    RemoteProtocolError, 
    17    WriteError, 
    18    map_exceptions, 
    19) 
    20from .._models import Origin, Request, Response 
    21from .._synchronization import AsyncLock, AsyncShieldCancellation 
    22from .._trace import Trace 
    23from .interfaces import AsyncConnectionInterface 
    24 
    25logger = logging.getLogger("httpcore.http11") 
    26 
    27 
    28# A subset of `h11.Event` types supported by `_send_event` 
    29H11SendEvent = typing.Union[ 
    30    h11.Request, 
    31    h11.Data, 
    32    h11.EndOfMessage, 
    33] 
    34 
    35 
    36class HTTPConnectionState(enum.IntEnum): 
    37    NEW = 0 
    38    ACTIVE = 1 
    39    IDLE = 2 
    40    CLOSED = 3 
    41 
    42 
    43class AsyncHTTP11Connection(AsyncConnectionInterface): 
    44    READ_NUM_BYTES = 64 * 1024 
    45    MAX_INCOMPLETE_EVENT_SIZE = 100 * 1024 
    46 
    47    def __init__( 
    48        self, 
    49        origin: Origin, 
    50        stream: AsyncNetworkStream, 
    51        keepalive_expiry: float | None = None, 
    52    ) -> None: 
    53        self._origin = origin 
    54        self._network_stream = stream 
    55        self._keepalive_expiry: float | None = keepalive_expiry 
    56        self._expire_at: float | None = None 
    57        self._state = HTTPConnectionState.NEW 
    58        self._state_lock = AsyncLock() 
    59        self._request_count = 0 
    60        self._h11_state = h11.Connection( 
    61            our_role=h11.CLIENT, 
    62            max_incomplete_event_size=self.MAX_INCOMPLETE_EVENT_SIZE, 
    63        ) 
    64 
    65    async def handle_async_request(self, request: Request) -> Response: 
    66        if not self.can_handle_request(request.url.origin): 
    67            raise RuntimeError( 
    68                f"Attempted to send request to {request.url.origin} on connection " 
    69                f"to {self._origin}" 
    70            ) 
    71 
    72        async with self._state_lock: 
    73            if self._state in (HTTPConnectionState.NEW, HTTPConnectionState.IDLE): 
    74                self._request_count += 1 
    75                self._state = HTTPConnectionState.ACTIVE 
    76                self._expire_at = None 
    77            else: 
    78                raise ConnectionNotAvailable() 
    79 
    80        try: 
    81            kwargs = {"request": request} 
    82            try: 
    83                async with Trace( 
    84                    "send_request_headers", logger, request, kwargs 
    85                ) as trace: 
    86                    await self._send_request_headers(**kwargs) 
    87                async with Trace("send_request_body", logger, request, kwargs) as trace: 
    88                    await self._send_request_body(**kwargs) 
    89            except WriteError: 
    90                # If we get a write error while we're writing the request, 
    91                # then we supress this error and move on to attempting to 
    92                # read the response. Servers can sometimes close the request 
    93                # pre-emptively and then respond with a well formed HTTP 
    94                # error response. 
    95                pass 
    96 
    97            async with Trace( 
    98                "receive_response_headers", logger, request, kwargs 
    99            ) as trace: 
    100                ( 
    101                    http_version, 
    102                    status, 
    103                    reason_phrase, 
    104                    headers, 
    105                    trailing_data, 
    106                ) = await self._receive_response_headers(**kwargs) 
    107                trace.return_value = ( 
    108                    http_version, 
    109                    status, 
    110                    reason_phrase, 
    111                    headers, 
    112                ) 
    113 
    114            network_stream = self._network_stream 
    115 
    116            # CONNECT or Upgrade request 
    117            if (status == 101) or ( 
    118                (request.method == b"CONNECT") and (200 <= status < 300) 
    119            ): 
    120                network_stream = AsyncHTTP11UpgradeStream(network_stream, trailing_data) 
    121 
    122            return Response( 
    123                status=status, 
    124                headers=headers, 
    125                content=HTTP11ConnectionByteStream(self, request), 
    126                extensions={ 
    127                    "http_version": http_version, 
    128                    "reason_phrase": reason_phrase, 
    129                    "network_stream": network_stream, 
    130                }, 
    131            ) 
    132        except BaseException as exc: 
    133            with AsyncShieldCancellation(): 
    134                async with Trace("response_closed", logger, request) as trace: 
    135                    await self._response_closed() 
    136            raise exc 
    137 
    138    # Sending the request... 
    139 
    140    async def _send_request_headers(self, request: Request) -> None: 
    141        timeouts = request.extensions.get("timeout", {}) 
    142        timeout = timeouts.get("write", None) 
    143 
    144        with map_exceptions({h11.LocalProtocolError: LocalProtocolError}): 
    145            event = h11.Request( 
    146                method=request.method, 
    147                target=request.url.target, 
    148                headers=request.headers, 
    149            ) 
    150        await self._send_event(event, timeout=timeout) 
    151 
    152    async def _send_request_body(self, request: Request) -> None: 
    153        timeouts = request.extensions.get("timeout", {}) 
    154        timeout = timeouts.get("write", None) 
    155 
    156        assert isinstance(request.stream, typing.AsyncIterable) 
    157        async for chunk in request.stream: 
    158            event = h11.Data(data=chunk) 
    159            await self._send_event(event, timeout=timeout) 
    160 
    161        await self._send_event(h11.EndOfMessage(), timeout=timeout) 
    162 
    163    async def _send_event(self, event: h11.Event, timeout: float | None = None) -> None: 
    164        bytes_to_send = self._h11_state.send(event) 
    165        if bytes_to_send is not None: 
    166            await self._network_stream.write(bytes_to_send, timeout=timeout) 
    167 
    168    # Receiving the response... 
    169 
    170    async def _receive_response_headers( 
    171        self, request: Request 
    172    ) -> tuple[bytes, int, bytes, list[tuple[bytes, bytes]], bytes]: 
    173        timeouts = request.extensions.get("timeout", {}) 
    174        timeout = timeouts.get("read", None) 
    175 
    176        while True: 
    177            event = await self._receive_event(timeout=timeout) 
    178            if isinstance(event, h11.Response): 
    179                break 
    180            if ( 
    181                isinstance(event, h11.InformationalResponse) 
    182                and event.status_code == 101 
    183            ): 
    184                break 
    185 
    186        http_version = b"HTTP/" + event.http_version 
    187 
    188        # h11 version 0.11+ supports a `raw_items` interface to get the 
    189        # raw header casing, rather than the enforced lowercase headers. 
    190        headers = event.headers.raw_items() 
    191 
    192        trailing_data, _ = self._h11_state.trailing_data 
    193 
    194        return http_version, event.status_code, event.reason, headers, trailing_data 
    195 
    196    async def _receive_response_body( 
    197        self, request: Request 
    198    ) -> typing.AsyncIterator[bytes]: 
    199        timeouts = request.extensions.get("timeout", {}) 
    200        timeout = timeouts.get("read", None) 
    201 
    202        while True: 
    203            event = await self._receive_event(timeout=timeout) 
    204            if isinstance(event, h11.Data): 
    205                yield bytes(event.data) 
    206            elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)): 
    207                break 
    208 
    209    async def _receive_event( 
    210        self, timeout: float | None = None 
    211    ) -> h11.Event | type[h11.PAUSED]: 
    212        while True: 
    213            with map_exceptions({h11.RemoteProtocolError: RemoteProtocolError}): 
    214                event = self._h11_state.next_event() 
    215 
    216            if event is h11.NEED_DATA: 
    217                data = await self._network_stream.read( 
    218                    self.READ_NUM_BYTES, timeout=timeout 
    219                ) 
    220 
    221                # If we feed this case through h11 we'll raise an exception like: 
    222                # 
    223                #     httpcore.RemoteProtocolError: can't handle event type 
    224                #     ConnectionClosed when role=SERVER and state=SEND_RESPONSE 
    225                # 
    226                # Which is accurate, but not very informative from an end-user 
    227                # perspective. Instead we handle this case distinctly and treat 
    228                # it as a ConnectError. 
    229                if data == b"" and self._h11_state.their_state == h11.SEND_RESPONSE: 
    230                    msg = "Server disconnected without sending a response." 
    231                    raise RemoteProtocolError(msg) 
    232 
    233                self._h11_state.receive_data(data) 
    234            else: 
    235                # mypy fails to narrow the type in the above if statement above 
    236                return event  # type: ignore[return-value] 
    237 
    238    async def _response_closed(self) -> None: 
    239        async with self._state_lock: 
    240            if ( 
    241                self._h11_state.our_state is h11.DONE 
    242                and self._h11_state.their_state is h11.DONE 
    243            ): 
    244                self._state = HTTPConnectionState.IDLE 
    245                self._h11_state.start_next_cycle() 
    246                if self._keepalive_expiry is not None: 
    247                    now = time.monotonic() 
    248                    self._expire_at = now + self._keepalive_expiry 
    249            else: 
    250                await self.aclose() 
    251 
    252    # Once the connection is no longer required... 
    253 
    254    async def aclose(self) -> None: 
    255        # Note that this method unilaterally closes the connection, and does 
    256        # not have any kind of locking in place around it. 
    257        self._state = HTTPConnectionState.CLOSED 
    258        await self._network_stream.aclose() 
    259 
    260    # The AsyncConnectionInterface methods provide information about the state of 
    261    # the connection, allowing for a connection pooling implementation to 
    262    # determine when to reuse and when to close the connection... 
    263 
    264    def can_handle_request(self, origin: Origin) -> bool: 
    265        return origin == self._origin 
    266 
    267    def is_available(self) -> bool: 
    268        # Note that HTTP/1.1 connections in the "NEW" state are not treated as 
    269        # being "available". The control flow which created the connection will 
    270        # be able to send an outgoing request, but the connection will not be 
    271        # acquired from the connection pool for any other request. 
    272        return self._state == HTTPConnectionState.IDLE 
    273 
    274    def has_expired(self) -> bool: 
    275        now = time.monotonic() 
    276        keepalive_expired = self._expire_at is not None and now > self._expire_at 
    277 
    278        # If the HTTP connection is idle but the socket is readable, then the 
    279        # only valid state is that the socket is about to return b"", indicating 
    280        # a server-initiated disconnect. 
    281        server_disconnected = ( 
    282            self._state == HTTPConnectionState.IDLE 
    283            and self._network_stream.get_extra_info("is_readable") 
    284        ) 
    285 
    286        return keepalive_expired or server_disconnected 
    287 
    288    def is_idle(self) -> bool: 
    289        return self._state == HTTPConnectionState.IDLE 
    290 
    291    def is_closed(self) -> bool: 
    292        return self._state == HTTPConnectionState.CLOSED 
    293 
    294    def info(self) -> str: 
    295        origin = str(self._origin) 
    296        return ( 
    297            f"{origin!r}, HTTP/1.1, {self._state.name}, " 
    298            f"Request Count: {self._request_count}" 
    299        ) 
    300 
    301    def __repr__(self) -> str: 
    302        class_name = self.__class__.__name__ 
    303        origin = str(self._origin) 
    304        return ( 
    305            f"<{class_name} [{origin!r}, {self._state.name}, " 
    306            f"Request Count: {self._request_count}]>" 
    307        ) 
    308 
    309    # These context managers are not used in the standard flow, but are 
    310    # useful for testing or working with connection instances directly. 
    311 
    312    async def __aenter__(self) -> AsyncHTTP11Connection: 
    313        return self 
    314 
    315    async def __aexit__( 
    316        self, 
    317        exc_type: type[BaseException] | None = None, 
    318        exc_value: BaseException | None = None, 
    319        traceback: types.TracebackType | None = None, 
    320    ) -> None: 
    321        await self.aclose() 
    322 
    323 
    324class HTTP11ConnectionByteStream: 
    325    def __init__(self, connection: AsyncHTTP11Connection, request: Request) -> None: 
    326        self._connection = connection 
    327        self._request = request 
    328        self._closed = False 
    329 
    330    async def __aiter__(self) -> typing.AsyncIterator[bytes]: 
    331        kwargs = {"request": self._request} 
    332        try: 
    333            async with Trace("receive_response_body", logger, self._request, kwargs): 
    334                async for chunk in self._connection._receive_response_body(**kwargs): 
    335                    yield chunk 
    336        except BaseException as exc: 
    337            # If we get an exception while streaming the response, 
    338            # we want to close the response (and possibly the connection) 
    339            # before raising that exception. 
    340            with AsyncShieldCancellation(): 
    341                await self.aclose() 
    342            raise exc 
    343 
    344    async def aclose(self) -> None: 
    345        if not self._closed: 
    346            self._closed = True 
    347            async with Trace("response_closed", logger, self._request): 
    348                await self._connection._response_closed() 
    349 
    350 
    351class AsyncHTTP11UpgradeStream(AsyncNetworkStream): 
    352    def __init__(self, stream: AsyncNetworkStream, leading_data: bytes) -> None: 
    353        self._stream = stream 
    354        self._leading_data = leading_data 
    355 
    356    async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: 
    357        if self._leading_data: 
    358            buffer = self._leading_data[:max_bytes] 
    359            self._leading_data = self._leading_data[max_bytes:] 
    360            return buffer 
    361        else: 
    362            return await self._stream.read(max_bytes, timeout) 
    363 
    364    async def write(self, buffer: bytes, timeout: float | None = None) -> None: 
    365        await self._stream.write(buffer, timeout) 
    366 
    367    async def aclose(self) -> None: 
    368        await self._stream.aclose() 
    369 
    370    async def start_tls( 
    371        self, 
    372        ssl_context: ssl.SSLContext, 
    373        server_hostname: str | None = None, 
    374        timeout: float | None = None, 
    375    ) -> AsyncNetworkStream: 
    376        return await self._stream.start_tls(ssl_context, server_hostname, timeout) 
    377 
    378    def get_extra_info(self, info: str) -> typing.Any: 
    379        return self._stream.get_extra_info(info)