from __future__ import annotations

import enum
import logging
import time
import types
import typing

import h2.config
import h2.connection
import h2.events
import h2.exceptions
import h2.settings

from .._backends.base import AsyncNetworkStream
from .._exceptions import (
    ConnectionNotAvailable,
    LocalProtocolError,
    RemoteProtocolError,
)
from .._models import Origin, Request, Response
from .._synchronization import AsyncLock, AsyncSemaphore, AsyncShieldCancellation
from .._trace import Trace
from .interfaces import AsyncConnectionInterface

logger = logging.getLogger("httpcore.http2")


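# A request is treated as having a body whenever it carries either a
# `Content-Length` or a `Transfer-Encoding` header.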
def has_body_headers(request: Request) -> bool:
    return any(
        k.lower() == b"content-length" or k.lower() == b"transfer-encoding"
        for k, v in request.headers
    )


class HTTPConnectionState(enum.IntEnum):
    ACTIVE = 1
    IDLE = 2
    CLOSED = 3


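# A minimal usage sketch, outside the normal connection-pool flow. It assumes
# an `AsyncNetworkStream` has already been opened by one of the network
# backends, and that `request` is an `httpcore.Request`:
#
#     origin = Origin(scheme=b"https", host=b"example.com", port=443)
#     async with AsyncHTTP2Connection(origin=origin, stream=network_stream) as conn:
#         response = await conn.handle_async_request(request)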
class AsyncHTTP2Connection(AsyncConnectionInterface):
    READ_NUM_BYTES = 64 * 1024
    CONFIG = h2.config.H2Configuration(validate_inbound_headers=False)

    def __init__(
        self,
        origin: Origin,
        stream: AsyncNetworkStream,
        keepalive_expiry: float | None = None,
    ):
        self._origin = origin
        self._network_stream = stream
        self._keepalive_expiry: float | None = keepalive_expiry
        self._h2_state = h2.connection.H2Connection(config=self.CONFIG)
        self._state = HTTPConnectionState.IDLE
        self._expire_at: float | None = None
        self._request_count = 0
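        # The init lock guards the one-time connection setup. The state, read
        # and write locks serialize access to the connection state and to the
        # underlying network stream across concurrently running streams.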
        self._init_lock = AsyncLock()
        self._state_lock = AsyncLock()
        self._read_lock = AsyncLock()
        self._write_lock = AsyncLock()
        self._sent_connection_init = False
        self._used_all_stream_ids = False
        self._connection_error = False

        # Mapping from stream ID to response stream events.
        self._events: dict[
            int,
            list[
                h2.events.ResponseReceived
                | h2.events.DataReceived
                | h2.events.StreamEnded
                | h2.events.StreamReset,
            ],
        ] = {}

        # Connection terminated events are stored as state since
        # we need to handle them for all streams.
        self._connection_terminated: h2.events.ConnectionTerminated | None = None

        self._read_exception: Exception | None = None
        self._write_exception: Exception | None = None

    async def handle_async_request(self, request: Request) -> Response:
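        """
        Send an HTTP/2 request on this connection and return the response.

        The response body is not read eagerly: it is exposed on the returned
        `Response` as a lazily-iterated `HTTP2ConnectionByteStream`.
        """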
        if not self.can_handle_request(request.url.origin):
            # This cannot occur in normal operation, since the connection pool
            # will only send requests on connections that handle them.
            # It's in place simply for resilience as a guard against incorrect
            # usage, for anyone working directly with httpcore connections.
            raise RuntimeError(
                f"Attempted to send request to {request.url.origin} on connection "
                f"to {self._origin}"
            )

        async with self._state_lock:
            if self._state in (HTTPConnectionState.ACTIVE, HTTPConnectionState.IDLE):
                self._request_count += 1
                self._expire_at = None
                self._state = HTTPConnectionState.ACTIVE
            else:
                raise ConnectionNotAvailable()

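        # The connection preface and initial SETTINGS are only sent once, on
        # the first request. The stream-concurrency semaphore starts out
        # permitting a single stream, and is resized once the server has
        # advertised its MAX_CONCURRENT_STREAMS setting.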
        async with self._init_lock:
            if not self._sent_connection_init:
                try:
                    sci_kwargs = {"request": request}
                    async with Trace(
                        "send_connection_init", logger, request, sci_kwargs
                    ):
                        await self._send_connection_init(**sci_kwargs)
                except BaseException as exc:
                    with AsyncShieldCancellation():
                        await self.aclose()
                    raise exc

                self._sent_connection_init = True

                # Initially start with just 1 until the remote server provides
                # its max_concurrent_streams value
                self._max_streams = 1

                local_settings_max_streams = (
                    self._h2_state.local_settings.max_concurrent_streams
                )
                self._max_streams_semaphore = AsyncSemaphore(local_settings_max_streams)

                for _ in range(local_settings_max_streams - self._max_streams):
                    await self._max_streams_semaphore.acquire()

        await self._max_streams_semaphore.acquire()

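        # Client-initiated stream IDs are odd-numbered and strictly increasing,
        # so a long-lived connection can eventually exhaust them, at which
        # point it can no longer accept new requests.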
        try:
            stream_id = self._h2_state.get_next_available_stream_id()
            self._events[stream_id] = []
        except h2.exceptions.NoAvailableStreamIDError:  # pragma: nocover
            self._used_all_stream_ids = True
            self._request_count -= 1
            raise ConnectionNotAvailable()

        try:
            kwargs = {"request": request, "stream_id": stream_id}
            async with Trace("send_request_headers", logger, request, kwargs):
                await self._send_request_headers(request=request, stream_id=stream_id)
            async with Trace("send_request_body", logger, request, kwargs):
                await self._send_request_body(request=request, stream_id=stream_id)
            async with Trace(
                "receive_response_headers", logger, request, kwargs
            ) as trace:
                status, headers = await self._receive_response(
                    request=request, stream_id=stream_id
                )
                trace.return_value = (status, headers)

            return Response(
                status=status,
                headers=headers,
                content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id),
                extensions={
                    "http_version": b"HTTP/2",
                    "network_stream": self._network_stream,
                    "stream_id": stream_id,
                },
            )
        except BaseException as exc:  # noqa: PIE786
            with AsyncShieldCancellation():
                kwargs = {"stream_id": stream_id}
                async with Trace("response_closed", logger, request, kwargs):
                    await self._response_closed(stream_id=stream_id)

            if isinstance(exc, h2.exceptions.ProtocolError):
                # One case where h2 can raise a protocol error is when a
                # closed frame has been seen by the state machine.
                #
                # This happens when one stream is reading, and encounters
                # a GOAWAY event. Other flows of control may then raise
                # a protocol error at any point they interact with the 'h2_state'.
                #
                # In this case we'll have stored the event, and should raise
                # it as a RemoteProtocolError.
                if self._connection_terminated:  # pragma: nocover
                    raise RemoteProtocolError(self._connection_terminated)
                # If h2 raises a protocol error in some other state then we
                # must somehow have made a protocol violation.
                raise LocalProtocolError(exc)  # pragma: nocover

            raise exc

    async def _send_connection_init(self, request: Request) -> None:
        """
        The HTTP/2 connection requires some initial setup before we can start
        using individual request/response streams on it.
        """
        # Need to set these manually here instead of manipulating via
        # __setitem__() otherwise the H2Connection will emit SettingsUpdate
        # frames in addition to sending the undesired defaults.
        self._h2_state.local_settings = h2.settings.Settings(
            client=True,
            initial_values={
                # Disable PUSH_PROMISE frames from the server since we don't do anything
                # with them for now. Maybe when we support caching?
                h2.settings.SettingCodes.ENABLE_PUSH: 0,
                # These two are taken from h2 for safe defaults
                h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 100,
                h2.settings.SettingCodes.MAX_HEADER_LIST_SIZE: 65536,
            },
        )

        # Some websites (*cough* Yahoo *cough*) balk at this setting being
        # present in the initial handshake since it's not defined in the original
        # RFC despite the RFC mandating ignoring settings you don't know about.
        del self._h2_state.local_settings[
            h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL
        ]

        self._h2_state.initiate_connection()
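        # Widen the connection-level receive window well beyond the 64 KiB
        # default, so that large response bodies aren't throttled by our own
        # flow-control limits.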
        self._h2_state.increment_flow_control_window(2**24)
        await self._write_outgoing_data(request)

    # Sending the request...

    async def _send_request_headers(self, request: Request, stream_id: int) -> None:
        """
        Send the request headers to a given stream ID.
        """
        end_stream = not has_body_headers(request)

        # In HTTP/2 the ':authority' pseudo-header is used instead of 'Host'.
        # In order to gracefully handle HTTP/1.1 and HTTP/2 we always require
        # HTTP/1.1 style headers, and map them appropriately if we end up on
        # an HTTP/2 connection.
        authority = [v for k, v in request.headers if k.lower() == b"host"][0]

        headers = [
            (b":method", request.method),
            (b":authority", authority),
            (b":scheme", request.url.scheme),
            (b":path", request.url.target),
        ] + [
            (k.lower(), v)
            for k, v in request.headers
            if k.lower()
            not in (
                b"host",
                b"transfer-encoding",
            )
        ]

        self._h2_state.send_headers(stream_id, headers, end_stream=end_stream)
        self._h2_state.increment_flow_control_window(2**24, stream_id=stream_id)
        await self._write_outgoing_data(request)

    async def _send_request_body(self, request: Request, stream_id: int) -> None:
        """
        Iterate over the request body, sending it to a given stream ID.
        """
        if not has_body_headers(request):
            return

        assert isinstance(request.stream, typing.AsyncIterable)
        async for data in request.stream:
            await self._send_stream_data(request, stream_id, data)
        await self._send_end_stream(request, stream_id)

    async def _send_stream_data(
        self, request: Request, stream_id: int, data: bytes
    ) -> None:
        """
        Send a single chunk of data in one or more data frames.
        """
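        # Each write is capped by both the available flow-control window and
        # the peer's maximum frame size. If no window is available we block in
        # `_wait_for_outgoing_flow` until WINDOW_UPDATE frames arrive.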
        while data:
            max_flow = await self._wait_for_outgoing_flow(request, stream_id)
            chunk_size = min(len(data), max_flow)
            chunk, data = data[:chunk_size], data[chunk_size:]
            self._h2_state.send_data(stream_id, chunk)
            await self._write_outgoing_data(request)

    async def _send_end_stream(self, request: Request, stream_id: int) -> None:
        """
        Send an empty data frame on a given stream ID with the END_STREAM flag set.
        """
        self._h2_state.end_stream(stream_id)
        await self._write_outgoing_data(request)

    # Receiving the response...

    async def _receive_response(
        self, request: Request, stream_id: int
    ) -> tuple[int, list[tuple[bytes, bytes]]]:
        """
        Return the response status code and headers for a given stream ID.
        """
        while True:
            event = await self._receive_stream_event(request, stream_id)
            if isinstance(event, h2.events.ResponseReceived):
                break

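        # The ':status' pseudo-header carries the status code. Any other
        # pseudo-headers are stripped before the headers are returned.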
        status_code = 200
        headers = []
        assert event.headers is not None
        for k, v in event.headers:
            if k == b":status":
                status_code = int(v.decode("ascii", errors="ignore"))
            elif not k.startswith(b":"):
                headers.append((k, v))

        return (status_code, headers)

    async def _receive_response_body(
        self, request: Request, stream_id: int
    ) -> typing.AsyncIterator[bytes]:
        """
        Iterator that returns the bytes of the response body for a given stream ID.
        """
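        # Each DataReceived event is acknowledged against the flow-control
        # window, so that the server is permitted to keep sending data.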
        while True:
            event = await self._receive_stream_event(request, stream_id)
            if isinstance(event, h2.events.DataReceived):
                assert event.flow_controlled_length is not None
                assert event.data is not None
                amount = event.flow_controlled_length
                self._h2_state.acknowledge_received_data(amount, stream_id)
                await self._write_outgoing_data(request)
                yield event.data
            elif isinstance(event, h2.events.StreamEnded):
                break

    async def _receive_stream_event(
        self, request: Request, stream_id: int
    ) -> h2.events.ResponseReceived | h2.events.DataReceived | h2.events.StreamEnded:
        """
        Return the next available event for a given stream ID.

        Will read more data from the network if required.
        """
        while not self._events.get(stream_id):
            await self._receive_events(request, stream_id)
        event = self._events[stream_id].pop(0)
        if isinstance(event, h2.events.StreamReset):
            raise RemoteProtocolError(event)
        return event

    async def _receive_events(
        self, request: Request, stream_id: int | None = None
    ) -> None:
        """
        Read some data from the network until we see one or more events
        for a given stream ID.
        """
        async with self._read_lock:
            if self._connection_terminated is not None:
                last_stream_id = self._connection_terminated.last_stream_id
                if stream_id and last_stream_id and stream_id > last_stream_id:
                    self._request_count -= 1
                    raise ConnectionNotAvailable()
                raise RemoteProtocolError(self._connection_terminated)

            # This conditional is a bit icky. We don't want to block reading if we've
            # actually got an event to return for a given stream. We need to do that
            # check *within* the atomic read lock. Though it also needs to be optional,
            # because when we call it from `_wait_for_outgoing_flow` we *do* want to
            # block until we have available flow control, even when we have events
            # pending for the stream ID we're attempting to send on.
            if stream_id is None or not self._events.get(stream_id):
                events = await self._read_incoming_data(request)
                for event in events:
                    if isinstance(event, h2.events.RemoteSettingsChanged):
                        async with Trace(
                            "receive_remote_settings", logger, request
                        ) as trace:
                            await self._receive_remote_settings_change(event)
                            trace.return_value = event

                    elif isinstance(
                        event,
                        (
                            h2.events.ResponseReceived,
                            h2.events.DataReceived,
                            h2.events.StreamEnded,
                            h2.events.StreamReset,
                        ),
                    ):
                        if event.stream_id in self._events:
                            self._events[event.stream_id].append(event)

                    elif isinstance(event, h2.events.ConnectionTerminated):
                        self._connection_terminated = event

        await self._write_outgoing_data(request)

    async def _receive_remote_settings_change(
        self, event: h2.events.RemoteSettingsChanged
    ) -> None:
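        # Resize the stream-concurrency semaphore to follow the server's
        # advertised MAX_CONCURRENT_STREAMS, capped by our own local setting.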
        max_concurrent_streams = event.changed_settings.get(
            h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS
        )
        if max_concurrent_streams:
            new_max_streams = min(
                max_concurrent_streams.new_value,
                self._h2_state.local_settings.max_concurrent_streams,
            )
            if new_max_streams and new_max_streams != self._max_streams:
                while new_max_streams > self._max_streams:
                    await self._max_streams_semaphore.release()
                    self._max_streams += 1
                while new_max_streams < self._max_streams:
                    await self._max_streams_semaphore.acquire()
                    self._max_streams -= 1

    async def _response_closed(self, stream_id: int) -> None:
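        """
        Called once a response stream is complete: release its concurrency
        slot, and either return the connection to an idle state or close it
        if it has terminated or exhausted its stream IDs.
        """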
        await self._max_streams_semaphore.release()
        del self._events[stream_id]
        async with self._state_lock:
            if self._connection_terminated and not self._events:
                await self.aclose()

            elif self._state == HTTPConnectionState.ACTIVE and not self._events:
                self._state = HTTPConnectionState.IDLE
                if self._keepalive_expiry is not None:
                    now = time.monotonic()
                    self._expire_at = now + self._keepalive_expiry
                if self._used_all_stream_ids:  # pragma: nocover
                    await self.aclose()

    async def aclose(self) -> None:
        # Note that this method unilaterally closes the connection, and does
        # not have any kind of locking in place around it.
        self._h2_state.close_connection()
        self._state = HTTPConnectionState.CLOSED
        await self._network_stream.aclose()

    # Wrappers around network read/write operations...

    async def _read_incoming_data(self, request: Request) -> list[h2.events.Event]:
        timeouts = request.extensions.get("timeout", {})
        timeout = timeouts.get("read", None)

        if self._read_exception is not None:
            raise self._read_exception  # pragma: nocover

        try:
            data = await self._network_stream.read(self.READ_NUM_BYTES, timeout)
            if data == b"":
                raise RemoteProtocolError("Server disconnected")
        except Exception as exc:
            # If we get a network error we should:
            #
            # 1. Save the exception and just raise it immediately on any future reads.
            #    (For example, this means that a single read timeout or disconnect will
            #    immediately close all pending streams, without requiring multiple
            #    sequential timeouts.)
            # 2. Mark the connection as errored, so that we don't accept any other
            #    incoming requests.
            self._read_exception = exc
            self._connection_error = True
            raise exc

        events: list[h2.events.Event] = self._h2_state.receive_data(data)

        return events

    async def _write_outgoing_data(self, request: Request) -> None:
        timeouts = request.extensions.get("timeout", {})
        timeout = timeouts.get("write", None)

        async with self._write_lock:
            data_to_send = self._h2_state.data_to_send()

            if self._write_exception is not None:
                raise self._write_exception  # pragma: nocover

            try:
                await self._network_stream.write(data_to_send, timeout)
            except Exception as exc:  # pragma: nocover
                # If we get a network error we should:
                #
                # 1. Save the exception and just raise it immediately on any future writes.
                #    (For example, this means that a single write timeout or disconnect will
                #    immediately close all pending streams, without requiring multiple
                #    sequential timeouts.)
                # 2. Mark the connection as errored, so that we don't accept any other
                #    incoming requests.
                self._write_exception = exc
                self._connection_error = True
                raise exc

    # Flow control...

    async def _wait_for_outgoing_flow(self, request: Request, stream_id: int) -> int:
        """
        Returns the maximum allowable outgoing flow for a given stream.

        If the allowable flow is zero, then waits on the network until
        WindowUpdated frames have increased the flow rate.
        https://tools.ietf.org/html/rfc7540#section-6.9
        """
        local_flow: int = self._h2_state.local_flow_control_window(stream_id)
        max_frame_size: int = self._h2_state.max_outbound_frame_size
        flow = min(local_flow, max_frame_size)
        while flow == 0:
            await self._receive_events(request)
            local_flow = self._h2_state.local_flow_control_window(stream_id)
            max_frame_size = self._h2_state.max_outbound_frame_size
            flow = min(local_flow, max_frame_size)
        return flow

    # Interface for connection pooling...

    def can_handle_request(self, origin: Origin) -> bool:
        return origin == self._origin

    def is_available(self) -> bool:
        return (
            self._state != HTTPConnectionState.CLOSED
            and not self._connection_error
            and not self._used_all_stream_ids
            and not (
                self._h2_state.state_machine.state
                == h2.connection.ConnectionState.CLOSED
            )
        )

    def has_expired(self) -> bool:
        now = time.monotonic()
        return self._expire_at is not None and now > self._expire_at

    def is_idle(self) -> bool:
        return self._state == HTTPConnectionState.IDLE

    def is_closed(self) -> bool:
        return self._state == HTTPConnectionState.CLOSED

    def info(self) -> str:
        origin = str(self._origin)
        return (
            f"{origin!r}, HTTP/2, {self._state.name}, "
            f"Request Count: {self._request_count}"
        )

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        origin = str(self._origin)
        return (
            f"<{class_name} [{origin!r}, {self._state.name}, "
            f"Request Count: {self._request_count}]>"
        )

    # These context managers are not used in the standard flow, but are
    # useful for testing or working with connection instances directly.

    async def __aenter__(self) -> AsyncHTTP2Connection:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> None:
        await self.aclose()


class HTTP2ConnectionByteStream:
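    """
    Lazily iterates the response body for a single HTTP/2 stream, and ensures
    the stream is closed (releasing its concurrency slot) once iteration
    completes or fails.
    """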
    def __init__(
        self, connection: AsyncHTTP2Connection, request: Request, stream_id: int
    ) -> None:
        self._connection = connection
        self._request = request
        self._stream_id = stream_id
        self._closed = False

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        kwargs = {"request": self._request, "stream_id": self._stream_id}
        try:
            async with Trace("receive_response_body", logger, self._request, kwargs):
                async for chunk in self._connection._receive_response_body(
                    request=self._request, stream_id=self._stream_id
                ):
                    yield chunk
        except BaseException as exc:
            # If we get an exception while streaming the response,
            # we want to close the response (and possibly the connection)
            # before raising that exception.
            with AsyncShieldCancellation():
                await self.aclose()
            raise exc

    async def aclose(self) -> None:
        if not self._closed:
            self._closed = True
            kwargs = {"stream_id": self._stream_id}
            async with Trace("response_closed", logger, self._request, kwargs):
                await self._connection._response_closed(stream_id=self._stream_id)