import enum
import logging
import time
import types
import typing

import h2.config
import h2.connection
import h2.events
import h2.exceptions
import h2.settings

from .._backends.base import AsyncNetworkStream
from .._exceptions import (
    ConnectionNotAvailable,
    LocalProtocolError,
    RemoteProtocolError,
)
from .._models import Origin, Request, Response
from .._synchronization import AsyncLock, AsyncSemaphore, AsyncShieldCancellation
from .._trace import Trace
from .interfaces import AsyncConnectionInterface

logger = logging.getLogger("httpcore.http2")


def has_body_headers(request: Request) -> bool:
    return any(
        k.lower() == b"content-length" or k.lower() == b"transfer-encoding"
        for k, v in request.headers
    )


class HTTPConnectionState(enum.IntEnum):
    ACTIVE = 1
    IDLE = 2
    CLOSED = 3


class AsyncHTTP2Connection(AsyncConnectionInterface):
    READ_NUM_BYTES = 64 * 1024
    CONFIG = h2.config.H2Configuration(validate_inbound_headers=False)

    def __init__(
        self,
        origin: Origin,
        stream: AsyncNetworkStream,
        keepalive_expiry: typing.Optional[float] = None,
    ):
        self._origin = origin
        self._network_stream = stream
        self._keepalive_expiry: typing.Optional[float] = keepalive_expiry
        self._h2_state = h2.connection.H2Connection(config=self.CONFIG)
        self._state = HTTPConnectionState.IDLE
        self._expire_at: typing.Optional[float] = None
        self._request_count = 0
        self._init_lock = AsyncLock()
        self._state_lock = AsyncLock()
        self._read_lock = AsyncLock()
        self._write_lock = AsyncLock()
        self._sent_connection_init = False
        self._used_all_stream_ids = False
        self._connection_error = False

        # Mapping from stream ID to the list of response stream events
        # received so far for that stream.
        self._events: typing.Dict[
            int,
            typing.List[
                typing.Union[
                    h2.events.ResponseReceived,
                    h2.events.DataReceived,
                    h2.events.StreamEnded,
                    h2.events.StreamReset,
                ]
            ],
        ] = {}

        # Connection terminated events are stored as state since
        # we need to handle them for all streams.
        self._connection_terminated: typing.Optional[h2.events.ConnectionTerminated] = (
            None
        )

        self._read_exception: typing.Optional[Exception] = None
        self._write_exception: typing.Optional[Exception] = None

    async def handle_async_request(self, request: Request) -> Response:
        if not self.can_handle_request(request.url.origin):
            # This cannot occur in normal operation, since the connection pool
            # will only send requests on connections that handle them.
            # It's in place simply for resilience as a guard against incorrect
            # usage, for anyone working directly with httpcore connections.
            raise RuntimeError(
                f"Attempted to send request to {request.url.origin} on connection "
                f"to {self._origin}"
            )

        async with self._state_lock:
            if self._state in (HTTPConnectionState.ACTIVE, HTTPConnectionState.IDLE):
                self._request_count += 1
                self._expire_at = None
                self._state = HTTPConnectionState.ACTIVE
            else:
                raise ConnectionNotAvailable()

        async with self._init_lock:
            if not self._sent_connection_init:
                try:
                    kwargs = {"request": request}
                    async with Trace("send_connection_init", logger, request, kwargs):
                        await self._send_connection_init(**kwargs)
                except BaseException as exc:
                    with AsyncShieldCancellation():
                        await self.aclose()
                    raise exc

                self._sent_connection_init = True

                # Initially start with just 1 until the remote server provides
                # its max_concurrent_streams value
                self._max_streams = 1

                local_settings_max_streams = (
                    self._h2_state.local_settings.max_concurrent_streams
                )
                self._max_streams_semaphore = AsyncSemaphore(local_settings_max_streams)

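                # Acquire all but `self._max_streams` permits up-front, so that
                # only a single stream is in flight until the server advertises
                # its own MAX_CONCURRENT_STREAMS value and permits are released
                # again in `_receive_remote_settings_change()`.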
                for _ in range(local_settings_max_streams - self._max_streams):
                    await self._max_streams_semaphore.acquire()

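        # Each request holds one concurrency permit for the lifetime of its
        # stream; the permit is released in `_response_closed()`.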
        await self._max_streams_semaphore.acquire()

        try:
            stream_id = self._h2_state.get_next_available_stream_id()
            self._events[stream_id] = []
        except h2.exceptions.NoAvailableStreamIDError:  # pragma: nocover
            self._used_all_stream_ids = True
            self._request_count -= 1
            raise ConnectionNotAvailable()

        try:
            kwargs = {"request": request, "stream_id": stream_id}
            async with Trace("send_request_headers", logger, request, kwargs):
                await self._send_request_headers(request=request, stream_id=stream_id)
            async with Trace("send_request_body", logger, request, kwargs):
                await self._send_request_body(request=request, stream_id=stream_id)
            async with Trace(
                "receive_response_headers", logger, request, kwargs
            ) as trace:
                status, headers = await self._receive_response(
                    request=request, stream_id=stream_id
                )
                trace.return_value = (status, headers)

            return Response(
                status=status,
                headers=headers,
                content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id),
                extensions={
                    "http_version": b"HTTP/2",
                    "network_stream": self._network_stream,
                    "stream_id": stream_id,
                },
            )
        except BaseException as exc:  # noqa: PIE786
            with AsyncShieldCancellation():
                kwargs = {"stream_id": stream_id}
                async with Trace("response_closed", logger, request, kwargs):
                    await self._response_closed(stream_id=stream_id)

            if isinstance(exc, h2.exceptions.ProtocolError):
                # One case where h2 can raise a protocol error is when a
                # closed frame has been seen by the state machine.
                #
                # This happens when one stream is reading, and encounters
                # a GOAWAY event. Other flows of control may then raise
                # a protocol error at any point they interact with the 'h2_state'.
                #
                # In this case we'll have stored the event, and should raise
                # it as a RemoteProtocolError.
                if self._connection_terminated:  # pragma: nocover
                    raise RemoteProtocolError(self._connection_terminated)
                # If h2 raises a protocol error in some other state then we
                # must somehow have made a protocol violation.
                raise LocalProtocolError(exc)  # pragma: nocover

            raise exc

    async def _send_connection_init(self, request: Request) -> None:
        """
        The HTTP/2 connection requires some initial setup before we can start
        using individual request/response streams on it.
        """
        # Need to set these manually here instead of manipulating via
        # __setitem__() otherwise the H2Connection will emit SettingsUpdate
        # frames in addition to sending the undesired defaults.
        self._h2_state.local_settings = h2.settings.Settings(
            client=True,
            initial_values={
                # Disable PUSH_PROMISE frames from the server since we don't do anything
                # with them for now. Maybe when we support caching?
                h2.settings.SettingCodes.ENABLE_PUSH: 0,
                # These two are taken from h2 for safe defaults
                h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 100,
                h2.settings.SettingCodes.MAX_HEADER_LIST_SIZE: 65536,
            },
        )

        # Some websites (*cough* Yahoo *cough*) balk at this setting being
        # present in the initial handshake since it's not defined in the original
        # RFC despite the RFC mandating ignoring settings you don't know about.
        del self._h2_state.local_settings[
            h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL
        ]

        self._h2_state.initiate_connection()
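        # Bump the connection-level flow control window well above the default
        # 64KB, so that response bodies are not throttled waiting on
        # WINDOW_UPDATE round-trips.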
        self._h2_state.increment_flow_control_window(2**24)
        await self._write_outgoing_data(request)

    # Sending the request...

    async def _send_request_headers(self, request: Request, stream_id: int) -> None:
        """
        Send the request headers to a given stream ID.
        """
        end_stream = not has_body_headers(request)

        # In HTTP/2 the ':authority' pseudo-header is used instead of 'Host'.
        # In order to gracefully handle HTTP/1.1 and HTTP/2 we always require
        # HTTP/1.1 style headers, and map them appropriately if we end up on
        # an HTTP/2 connection.
        authority = [v for k, v in request.headers if k.lower() == b"host"][0]

        headers = [
            (b":method", request.method),
            (b":authority", authority),
            (b":scheme", request.url.scheme),
            (b":path", request.url.target),
        ] + [
            (k.lower(), v)
            for k, v in request.headers
            if k.lower()
            not in (
                b"host",
                b"transfer-encoding",
            )
        ]

        self._h2_state.send_headers(stream_id, headers, end_stream=end_stream)
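        # Raise the per-stream receive window as well, so this stream's
        # response body is not limited by the default initial window size.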
        self._h2_state.increment_flow_control_window(2**24, stream_id=stream_id)
        await self._write_outgoing_data(request)

    async def _send_request_body(self, request: Request, stream_id: int) -> None:
        """
        Iterate over the request body sending it to a given stream ID.
        """
        if not has_body_headers(request):
            return

        assert isinstance(request.stream, typing.AsyncIterable)
        async for data in request.stream:
            await self._send_stream_data(request, stream_id, data)
        await self._send_end_stream(request, stream_id)

    async def _send_stream_data(
        self, request: Request, stream_id: int, data: bytes
    ) -> None:
        """
        Send a single chunk of data in one or more data frames.
        """
        while data:
            max_flow = await self._wait_for_outgoing_flow(request, stream_id)
            chunk_size = min(len(data), max_flow)
            chunk, data = data[:chunk_size], data[chunk_size:]
            self._h2_state.send_data(stream_id, chunk)
            await self._write_outgoing_data(request)

    async def _send_end_stream(self, request: Request, stream_id: int) -> None:
        """
        Send an empty data frame on a given stream ID with the END_STREAM flag set.
        """
        self._h2_state.end_stream(stream_id)
        await self._write_outgoing_data(request)

    # Receiving the response...

    async def _receive_response(
        self, request: Request, stream_id: int
    ) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]:
        """
        Return the response status code and headers for a given stream ID.
        """
        while True:
            event = await self._receive_stream_event(request, stream_id)
            if isinstance(event, h2.events.ResponseReceived):
                break

        status_code = 200
        headers = []
        for k, v in event.headers:
            if k == b":status":
                status_code = int(v.decode("ascii", errors="ignore"))
            elif not k.startswith(b":"):
                headers.append((k, v))

        return (status_code, headers)

    async def _receive_response_body(
        self, request: Request, stream_id: int
    ) -> typing.AsyncIterator[bytes]:
        """
        Iterator that returns the bytes of the response body for a given stream ID.
        """
        while True:
            event = await self._receive_stream_event(request, stream_id)
            if isinstance(event, h2.events.DataReceived):
                amount = event.flow_controlled_length
                self._h2_state.acknowledge_received_data(amount, stream_id)
                await self._write_outgoing_data(request)
                yield event.data
            elif isinstance(event, h2.events.StreamEnded):
                break

    async def _receive_stream_event(
        self, request: Request, stream_id: int
    ) -> typing.Union[
        h2.events.ResponseReceived, h2.events.DataReceived, h2.events.StreamEnded
    ]:
        """
        Return the next available event for a given stream ID.

        Will read more data from the network if required.
        """
        while not self._events.get(stream_id):
            await self._receive_events(request, stream_id)
        event = self._events[stream_id].pop(0)
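        # An RST_STREAM from the peer is surfaced to the caller as a remote
        # protocol error for this stream.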
        if isinstance(event, h2.events.StreamReset):
            raise RemoteProtocolError(event)
        return event

    async def _receive_events(
        self, request: Request, stream_id: typing.Optional[int] = None
    ) -> None:
        """
        Read some data from the network until we see one or more events
        for a given stream ID.
        """
        async with self._read_lock:
            if self._connection_terminated is not None:
                last_stream_id = self._connection_terminated.last_stream_id
                if stream_id and last_stream_id and stream_id > last_stream_id:
                    self._request_count -= 1
                    raise ConnectionNotAvailable()
                raise RemoteProtocolError(self._connection_terminated)

            # This conditional is a bit icky. We don't want to block reading if we've
            # actually got an event to return for a given stream. We need to do that
            # check *within* the atomic read lock. It also needs to be optional,
            # because when we call it from `_wait_for_outgoing_flow` we *do* want to
            # block until flow control is available, even when we have events
            # pending for the stream ID we're attempting to send on.
            if stream_id is None or not self._events.get(stream_id):
                events = await self._read_incoming_data(request)
                for event in events:
                    if isinstance(event, h2.events.RemoteSettingsChanged):
                        async with Trace(
                            "receive_remote_settings", logger, request
                        ) as trace:
                            await self._receive_remote_settings_change(event)
                            trace.return_value = event

                    elif isinstance(
                        event,
                        (
                            h2.events.ResponseReceived,
                            h2.events.DataReceived,
                            h2.events.StreamEnded,
                            h2.events.StreamReset,
                        ),
                    ):
                        if event.stream_id in self._events:
                            self._events[event.stream_id].append(event)

                    elif isinstance(event, h2.events.ConnectionTerminated):
                        self._connection_terminated = event

        await self._write_outgoing_data(request)

    async def _receive_remote_settings_change(self, event: h2.events.Event) -> None:
        max_concurrent_streams = event.changed_settings.get(
            h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS
        )
        if max_concurrent_streams:
            new_max_streams = min(
                max_concurrent_streams.new_value,
                self._h2_state.local_settings.max_concurrent_streams,
            )
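            # Resize the semaphore one permit at a time, so that permits already
            # held by in-flight streams are unaffected while the effective
            # concurrency limit converges on the server-advertised value.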
            if new_max_streams and new_max_streams != self._max_streams:
                while new_max_streams > self._max_streams:
                    await self._max_streams_semaphore.release()
                    self._max_streams += 1
                while new_max_streams < self._max_streams:
                    await self._max_streams_semaphore.acquire()
                    self._max_streams -= 1

    async def _response_closed(self, stream_id: int) -> None:
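        # Release the concurrency permit acquired in `handle_async_request()`,
        # then update the connection state now that this stream has finished.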
        await self._max_streams_semaphore.release()
        del self._events[stream_id]
        async with self._state_lock:
            if self._connection_terminated and not self._events:
                await self.aclose()

            elif self._state == HTTPConnectionState.ACTIVE and not self._events:
                self._state = HTTPConnectionState.IDLE
                if self._keepalive_expiry is not None:
                    now = time.monotonic()
                    self._expire_at = now + self._keepalive_expiry
                if self._used_all_stream_ids:  # pragma: nocover
                    await self.aclose()

    async def aclose(self) -> None:
        # Note that this method unilaterally closes the connection, and does
        # not have any kind of locking in place around it.
        self._h2_state.close_connection()
        self._state = HTTPConnectionState.CLOSED
        await self._network_stream.aclose()

    # Wrappers around network read/write operations...

    async def _read_incoming_data(
        self, request: Request
    ) -> typing.List[h2.events.Event]:
        timeouts = request.extensions.get("timeout", {})
        timeout = timeouts.get("read", None)

        if self._read_exception is not None:
            raise self._read_exception  # pragma: nocover

        try:
            data = await self._network_stream.read(self.READ_NUM_BYTES, timeout)
            if data == b"":
                raise RemoteProtocolError("Server disconnected")
        except Exception as exc:
            # If we get a network error we should:
            #
            # 1. Save the exception and just raise it immediately on any future reads.
            #    (For example, this means that a single read timeout or disconnect will
            #    immediately close all pending streams. Without requiring multiple
            #    sequential timeouts.)
            # 2. Mark the connection as errored, so that we don't accept any other
            #    incoming requests.
            self._read_exception = exc
            self._connection_error = True
            raise exc

        events: typing.List[h2.events.Event] = self._h2_state.receive_data(data)

        return events

    async def _write_outgoing_data(self, request: Request) -> None:
        timeouts = request.extensions.get("timeout", {})
        timeout = timeouts.get("write", None)

        async with self._write_lock:
            data_to_send = self._h2_state.data_to_send()

            if self._write_exception is not None:
                raise self._write_exception  # pragma: nocover

            try:
                await self._network_stream.write(data_to_send, timeout)
            except Exception as exc:  # pragma: nocover
                # If we get a network error we should:
                #
                # 1. Save the exception and just raise it immediately on any future write.
                #    (For example, this means that a single write timeout or disconnect will
                #    immediately close all pending streams. Without requiring multiple
                #    sequential timeouts.)
                # 2. Mark the connection as errored, so that we don't accept any other
                #    incoming requests.
                self._write_exception = exc
                self._connection_error = True
                raise exc

    # Flow control...

    async def _wait_for_outgoing_flow(self, request: Request, stream_id: int) -> int:
        """
        Returns the maximum allowable outgoing flow for a given stream.

        If the allowable flow is zero, then waits on the network until
        WindowUpdated frames have increased the flow rate.
        https://tools.ietf.org/html/rfc7540#section-6.9
        """
        local_flow: int = self._h2_state.local_flow_control_window(stream_id)
        max_frame_size: int = self._h2_state.max_outbound_frame_size
        flow = min(local_flow, max_frame_size)
        while flow == 0:
            await self._receive_events(request)
            local_flow = self._h2_state.local_flow_control_window(stream_id)
            max_frame_size = self._h2_state.max_outbound_frame_size
            flow = min(local_flow, max_frame_size)
        return flow

    # Interface for connection pooling...

    def can_handle_request(self, origin: Origin) -> bool:
        return origin == self._origin

    def is_available(self) -> bool:
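        # The h2 state machine may have moved to CLOSED (for example after a
        # GOAWAY or protocol error) without our own state tracking observing
        # it yet, so check both.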
        return (
            self._state != HTTPConnectionState.CLOSED
            and not self._connection_error
            and not self._used_all_stream_ids
            and not (
                self._h2_state.state_machine.state
                == h2.connection.ConnectionState.CLOSED
            )
        )

    def has_expired(self) -> bool:
        now = time.monotonic()
        return self._expire_at is not None and now > self._expire_at

    def is_idle(self) -> bool:
        return self._state == HTTPConnectionState.IDLE

    def is_closed(self) -> bool:
        return self._state == HTTPConnectionState.CLOSED

    def info(self) -> str:
        origin = str(self._origin)
        return (
            f"{origin!r}, HTTP/2, {self._state.name}, "
            f"Request Count: {self._request_count}"
        )

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        origin = str(self._origin)
        return (
            f"<{class_name} [{origin!r}, {self._state.name}, "
            f"Request Count: {self._request_count}]>"
        )

    # These context managers are not used in the standard flow, but are
    # useful for testing or working with connection instances directly.
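    #
    # A minimal usage sketch (assuming `origin`, `stream` and `request` are an
    # `Origin`, an HTTP/2-negotiated `AsyncNetworkStream`, and a `Request`):
    #
    #     async with AsyncHTTP2Connection(origin, stream) as connection:
    #         response = await connection.handle_async_request(request)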

    async def __aenter__(self) -> "AsyncHTTP2Connection":
        return self

    async def __aexit__(
        self,
        exc_type: typing.Optional[typing.Type[BaseException]] = None,
        exc_value: typing.Optional[BaseException] = None,
        traceback: typing.Optional[types.TracebackType] = None,
    ) -> None:
        await self.aclose()


class HTTP2ConnectionByteStream:
    def __init__(
        self, connection: AsyncHTTP2Connection, request: Request, stream_id: int
    ) -> None:
        self._connection = connection
        self._request = request
        self._stream_id = stream_id
        self._closed = False

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        kwargs = {"request": self._request, "stream_id": self._stream_id}
        try:
            async with Trace("receive_response_body", logger, self._request, kwargs):
                async for chunk in self._connection._receive_response_body(
                    request=self._request, stream_id=self._stream_id
                ):
                    yield chunk
        except BaseException as exc:
            # If we get an exception while streaming the response,
            # we want to close the response (and possibly the connection)
            # before raising that exception.
            with AsyncShieldCancellation():
                await self.aclose()
            raise exc

    async def aclose(self) -> None:
        if not self._closed:
            self._closed = True
            kwargs = {"stream_id": self._stream_id}
            async with Trace("response_closed", logger, self._request, kwargs):
                await self._connection._response_closed(stream_id=self._stream_id)