1import enum
2import logging
3import ssl
4import time
5from types import TracebackType
6from typing import (
7 Any,
8 Iterable,
9 Iterator,
10 List,
11 Optional,
12 Tuple,
13 Type,
14 Union,
15)
16
17import h11
18
19from .._backends.base import NetworkStream
20from .._exceptions import (
21 ConnectionNotAvailable,
22 LocalProtocolError,
23 RemoteProtocolError,
24 WriteError,
25 map_exceptions,
26)
27from .._models import Origin, Request, Response
28from .._synchronization import Lock, ShieldCancellation
29from .._trace import Trace
30from .interfaces import ConnectionInterface
31
# Module-level logger; it is handed to every `Trace` created in this module.
logger = logging.getLogger("httpcore.http11")


# A subset of `h11.Event` types supported by `_send_event`: the request
# line and headers, chunks of request body, and the end-of-message marker.
H11SendEvent = Union[
    h11.Request,
    h11.Data,
    h11.EndOfMessage,
]
41
42
class HTTPConnectionState(enum.IntEnum):
    # Lifecycle states tracked by `HTTP11Connection._state`.

    # Initial state assigned when the connection object is constructed.
    NEW = 0
    # Set by `handle_request()` while a request/response cycle is in flight.
    ACTIVE = 1
    # Set once both sides of the h11 state machine are DONE, so the
    # connection may accept another request.
    IDLE = 2
    # Set by `close()`.
    CLOSED = 3
48
49
class HTTP11Connection(ConnectionInterface):
    """
    An HTTP/1.1 connection layered over an already established network stream.

    The wire protocol is driven by an `h11.Connection` in the client role.
    The connection handles one request at a time and tracks its lifecycle
    in `self._state` (see `HTTPConnectionState`).
    """

    # Maximum number of bytes requested from the network stream per read.
    READ_NUM_BYTES = 64 * 1024
    # Passed to h11 as its `max_incomplete_event_size` limit.
    MAX_INCOMPLETE_EVENT_SIZE = 100 * 1024

    def __init__(
        self,
        origin: Origin,
        stream: NetworkStream,
        keepalive_expiry: Optional[float] = None,
    ) -> None:
        """
        Parameters:
            origin: The origin this connection talks to; requests for any
                other origin are rejected by `handle_request()`.
            stream: The network stream used for all reads and writes.
            keepalive_expiry: Number of seconds the connection may stay idle
                before `has_expired()` reports it as expired. `None` means
                no keep-alive deadline is ever set.
        """
        self._origin = origin
        self._network_stream = stream
        self._keepalive_expiry: Optional[float] = keepalive_expiry
        # `time.monotonic()` deadline, only set while the connection is idle.
        self._expire_at: Optional[float] = None
        self._state = HTTPConnectionState.NEW
        self._state_lock = Lock()
        self._request_count = 0
        self._h11_state = h11.Connection(
            our_role=h11.CLIENT,
            max_incomplete_event_size=self.MAX_INCOMPLETE_EVENT_SIZE,
        )

    def handle_request(self, request: Request) -> Response:
        """
        Send `request` and return the response once its headers have arrived.

        The response body is not read here; it is exposed lazily through an
        `HTTP11ConnectionByteStream`.

        Raises:
            RuntimeError: If the request origin does not match this connection.
            ConnectionNotAvailable: If the connection is neither NEW nor IDLE.
        """
        if not self.can_handle_request(request.url.origin):
            raise RuntimeError(
                f"Attempted to send request to {request.url.origin} on connection "
                f"to {self._origin}"
            )

        with self._state_lock:
            if self._state in (HTTPConnectionState.NEW, HTTPConnectionState.IDLE):
                self._request_count += 1
                self._state = HTTPConnectionState.ACTIVE
                # An in-flight request cancels any pending keep-alive deadline.
                self._expire_at = None
            else:
                raise ConnectionNotAvailable()

        try:
            kwargs = {"request": request}
            try:
                with Trace("send_request_headers", logger, request, kwargs):
                    self._send_request_headers(**kwargs)
                with Trace("send_request_body", logger, request, kwargs):
                    self._send_request_body(**kwargs)
            except WriteError:
                # If we get a write error while we're writing the request,
                # then we suppress this error and move on to attempting to
                # read the response. Servers can sometimes close the request
                # pre-emptively and then respond with a well formed HTTP
                # error response.
                pass

            with Trace(
                "receive_response_headers", logger, request, kwargs
            ) as trace:
                (
                    http_version,
                    status,
                    reason_phrase,
                    headers,
                    trailing_data,
                ) = self._receive_response_headers(**kwargs)
                # The trailing data is deliberately left out of the traced
                # return value.
                trace.return_value = (
                    http_version,
                    status,
                    reason_phrase,
                    headers,
                )

            network_stream = self._network_stream

            # CONNECT or Upgrade request
            if (status == 101) or (
                (request.method == b"CONNECT") and (200 <= status < 300)
            ):
                # Bytes that h11 already buffered past the response headers
                # are replayed to whoever reads from the raw stream.
                network_stream = HTTP11UpgradeStream(network_stream, trailing_data)

            return Response(
                status=status,
                headers=headers,
                content=HTTP11ConnectionByteStream(self, request),
                extensions={
                    "http_version": http_version,
                    "reason_phrase": reason_phrase,
                    "network_stream": network_stream,
                },
            )
        except BaseException as exc:
            # Release (or close) the connection before propagating any
            # failure, shielded so that the cleanup itself is not cancelled.
            with ShieldCancellation():
                with Trace("response_closed", logger, request):
                    self._response_closed()
            raise exc

    # Sending the request...

    def _send_request_headers(self, request: Request) -> None:
        """
        Send the request line and headers.

        h11's `LocalProtocolError` is mapped onto this package's
        `LocalProtocolError`.
        """
        timeouts = request.extensions.get("timeout", {})
        timeout = timeouts.get("write", None)

        with map_exceptions({h11.LocalProtocolError: LocalProtocolError}):
            event = h11.Request(
                method=request.method,
                target=request.url.target,
                headers=request.headers,
            )
        self._send_event(event, timeout=timeout)

    def _send_request_body(self, request: Request) -> None:
        """Send every chunk of the request body, then the end-of-message event."""
        timeouts = request.extensions.get("timeout", {})
        timeout = timeouts.get("write", None)

        assert isinstance(request.stream, Iterable)
        for chunk in request.stream:
            event = h11.Data(data=chunk)
            self._send_event(event, timeout=timeout)

        self._send_event(h11.EndOfMessage(), timeout=timeout)

    def _send_event(
        self, event: h11.Event, timeout: Optional[float] = None
    ) -> None:
        """Serialise `event` through h11 and write the resulting bytes, if any."""
        bytes_to_send = self._h11_state.send(event)
        if bytes_to_send is not None:
            self._network_stream.write(bytes_to_send, timeout=timeout)

    # Receiving the response...

    def _receive_response_headers(
        self, request: Request
    ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], bytes]:
        """
        Read events until the response headers arrive.

        Returns:
            A `(http_version, status_code, reason, headers, trailing_data)`
            tuple, where `trailing_data` is taken from h11's `trailing_data`
            buffer.
        """
        timeouts = request.extensions.get("timeout", {})
        timeout = timeouts.get("read", None)

        # Informational responses are skipped, with the exception of
        # 101 which is returned to the caller like a final response.
        while True:
            event = self._receive_event(timeout=timeout)
            if isinstance(event, h11.Response):
                break
            if (
                isinstance(event, h11.InformationalResponse)
                and event.status_code == 101
            ):
                break

        http_version = b"HTTP/" + event.http_version

        # h11 version 0.11+ supports a `raw_items` interface to get the
        # raw header casing, rather than the enforced lowercase headers.
        headers = event.headers.raw_items()

        trailing_data, _ = self._h11_state.trailing_data

        return http_version, event.status_code, event.reason, headers, trailing_data

    def _receive_response_body(self, request: Request) -> Iterator[bytes]:
        """Yield response body chunks until h11 reports EndOfMessage or PAUSED."""
        timeouts = request.extensions.get("timeout", {})
        timeout = timeouts.get("read", None)

        while True:
            event = self._receive_event(timeout=timeout)
            if isinstance(event, h11.Data):
                yield bytes(event.data)
            elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)):
                break

    def _receive_event(
        self, timeout: Optional[float] = None
    ) -> Union[h11.Event, Type[h11.PAUSED]]:
        """
        Return the next h11 event, reading from the network whenever h11
        asks for more data.

        Raises:
            RemoteProtocolError: If h11 rejects the received data, or if the
                server disconnects before sending any response.
        """
        while True:
            with map_exceptions({h11.RemoteProtocolError: RemoteProtocolError}):
                event = self._h11_state.next_event()

            if event is h11.NEED_DATA:
                data = self._network_stream.read(
                    self.READ_NUM_BYTES, timeout=timeout
                )

                # If we feed this case through h11 we'll raise an exception like:
                #
                #     httpcore.RemoteProtocolError: can't handle event type
                #     ConnectionClosed when role=SERVER and state=SEND_RESPONSE
                #
                # Which is accurate, but not very informative from an end-user
                # perspective. Instead we handle this case distinctly and raise
                # a RemoteProtocolError with a clearer message.
                if data == b"" and self._h11_state.their_state == h11.SEND_RESPONSE:
                    msg = "Server disconnected without sending a response."
                    raise RemoteProtocolError(msg)

                self._h11_state.receive_data(data)
            else:
                # mypy fails to narrow the type in the if statement above
                return event  # type: ignore[return-value]

    def _response_closed(self) -> None:
        """
        Finish the current request/response cycle.

        When both sides of the h11 state machine are DONE the connection
        becomes IDLE and, if configured, a keep-alive deadline is set.
        Otherwise the connection is closed.
        """
        with self._state_lock:
            if (
                self._h11_state.our_state is h11.DONE
                and self._h11_state.their_state is h11.DONE
            ):
                self._state = HTTPConnectionState.IDLE
                self._h11_state.start_next_cycle()
                if self._keepalive_expiry is not None:
                    now = time.monotonic()
                    self._expire_at = now + self._keepalive_expiry
            else:
                self.close()

    # Once the connection is no longer required...

    def close(self) -> None:
        """Mark the connection CLOSED and close the underlying network stream."""
        # Note that this method unilaterally closes the connection, and does
        # not have any kind of locking in place around it.
        self._state = HTTPConnectionState.CLOSED
        self._network_stream.close()

    # The ConnectionInterface methods provide information about the state of
    # the connection, allowing for a connection pooling implementation to
    # determine when to reuse and when to close the connection...

    def can_handle_request(self, origin: Origin) -> bool:
        """Return True if `origin` equals the origin of this connection."""
        return origin == self._origin

    def is_available(self) -> bool:
        """Return True if the connection is IDLE."""
        # Note that HTTP/1.1 connections in the "NEW" state are not treated as
        # being "available". The control flow which created the connection will
        # be able to send an outgoing request, but the connection will not be
        # acquired from the connection pool for any other request.
        return self._state == HTTPConnectionState.IDLE

    def has_expired(self) -> bool:
        """
        Return True if the keep-alive deadline has passed, or if the
        connection is idle and its socket reports itself as readable.
        """
        now = time.monotonic()
        keepalive_expired = self._expire_at is not None and now > self._expire_at

        # If the HTTP connection is idle but the socket is readable, then the
        # only valid state is that the socket is about to return b"", indicating
        # a server-initiated disconnect.
        server_disconnected = (
            self._state == HTTPConnectionState.IDLE
            and self._network_stream.get_extra_info("is_readable")
        )

        # `get_extra_info()` returns `Any` (for example `None` from a stream
        # that has no "is_readable" entry), so coerce the result in order to
        # honour the declared `bool` return type.
        return bool(keepalive_expired or server_disconnected)

    def is_idle(self) -> bool:
        """Return True if the connection is IDLE."""
        return self._state == HTTPConnectionState.IDLE

    def is_closed(self) -> bool:
        """Return True if the connection is CLOSED."""
        return self._state == HTTPConnectionState.CLOSED

    def info(self) -> str:
        """Return a one-line summary: origin, protocol, state and request count."""
        origin = str(self._origin)
        return (
            f"{origin!r}, HTTP/1.1, {self._state.name}, "
            f"Request Count: {self._request_count}"
        )

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        origin = str(self._origin)
        return (
            f"<{class_name} [{origin!r}, {self._state.name}, "
            f"Request Count: {self._request_count}]>"
        )

    # These context managers are not used in the standard flow, but are
    # useful for testing or working with connection instances directly.

    def __enter__(self) -> "HTTP11Connection":
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        self.close()
329
330
class HTTP11ConnectionByteStream:
    """
    Iterable over the body of a response received on an `HTTP11Connection`.

    Iterating pulls body chunks from the connection. `close()` tells the
    connection that the response has been closed, and does so at most once.
    """

    def __init__(self, connection: HTTP11Connection, request: Request) -> None:
        self._connection = connection
        self._request = request
        self._closed = False

    def __iter__(self) -> Iterator[bytes]:
        trace_kwargs = {"request": self._request}
        try:
            with Trace("receive_response_body", logger, self._request, trace_kwargs):
                body = self._connection._receive_response_body(**trace_kwargs)
                for chunk in body:
                    yield chunk
        except BaseException as exc:
            # Any failure while streaming closes the response (and possibly
            # the connection) first, shielded from cancellation, and is then
            # re-raised to the caller.
            with ShieldCancellation():
                self.close()
            raise exc

    def close(self) -> None:
        # Repeated calls after the first one are no-ops.
        if self._closed:
            return

        self._closed = True
        with Trace("response_closed", logger, self._request):
            self._connection._response_closed()
356
357
class HTTP11UpgradeStream(NetworkStream):
    """
    Network stream wrapper that first replays `leading_data` to readers.

    Reads are served from the leading data until it is used up, after which
    they go to the wrapped stream. Every other operation is delegated to the
    wrapped stream unchanged.
    """

    def __init__(self, stream: NetworkStream, leading_data: bytes) -> None:
        self._stream = stream
        self._leading_data = leading_data

    def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes:
        pending = self._leading_data
        if not pending:
            # Nothing buffered (any more): read from the wrapped stream.
            return self._stream.read(max_bytes, timeout)

        # Hand out at most `max_bytes` of the buffered data and keep the rest.
        head, self._leading_data = pending[:max_bytes], pending[max_bytes:]
        return head

    def write(self, buffer: bytes, timeout: Optional[float] = None) -> None:
        self._stream.write(buffer, timeout)

    def close(self) -> None:
        self._stream.close()

    def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: Optional[str] = None,
        timeout: Optional[float] = None,
    ) -> NetworkStream:
        inner = self._stream
        return inner.start_tls(ssl_context, server_hostname, timeout)

    def get_extra_info(self, info: str) -> Any:
        inner = self._stream
        return inner.get_extra_info(info)