Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/httpcore/_async/http2.py: 2%

256 statements  

coverage.py v7.2.7, created at 2023-06-07 07:19 +0000

import enum
import logging
import time
import types
import typing

import h2.config
import h2.connection
import h2.events
import h2.exceptions
import h2.settings

from .._exceptions import (
    ConnectionNotAvailable,
    LocalProtocolError,
    RemoteProtocolError,
)
from .._models import Origin, Request, Response
from .._synchronization import AsyncLock, AsyncSemaphore
from .._trace import Trace
from ..backends.base import AsyncNetworkStream
from .interfaces import AsyncConnectionInterface

logger = logging.getLogger("httpcore.http2")


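# A request counts as having a body when it carries either a "Content-Length"
# or a "Transfer-Encoding" header. This decides whether the HEADERS frame ends
# the stream, or whether DATA frames carrying the request body will follow.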

def has_body_headers(request: Request) -> bool:
    return any(
        k.lower() == b"content-length" or k.lower() == b"transfer-encoding"
        for k, v in request.headers
    )


class HTTPConnectionState(enum.IntEnum):
    ACTIVE = 1
    IDLE = 2
    CLOSED = 3


class AsyncHTTP2Connection(AsyncConnectionInterface):
    READ_NUM_BYTES = 64 * 1024
    CONFIG = h2.config.H2Configuration(validate_inbound_headers=False)

    def __init__(
        self,
        origin: Origin,
        stream: AsyncNetworkStream,
        keepalive_expiry: typing.Optional[float] = None,
    ):
        self._origin = origin
        self._network_stream = stream
        self._keepalive_expiry: typing.Optional[float] = keepalive_expiry
        self._h2_state = h2.connection.H2Connection(config=self.CONFIG)
        self._state = HTTPConnectionState.IDLE
        self._expire_at: typing.Optional[float] = None
        self._request_count = 0
        self._init_lock = AsyncLock()
        self._state_lock = AsyncLock()
        self._read_lock = AsyncLock()
        self._write_lock = AsyncLock()
        self._sent_connection_init = False
        self._used_all_stream_ids = False
        self._connection_error = False
        self._events: typing.Dict[int, typing.List[h2.events.Event]] = {}
        self._read_exception: typing.Optional[Exception] = None
        self._write_exception: typing.Optional[Exception] = None
        self._connection_error_event: typing.Optional[h2.events.Event] = None

    async def handle_async_request(self, request: Request) -> Response:
        if not self.can_handle_request(request.url.origin):
            # This cannot occur in normal operation, since the connection pool
            # will only send requests on connections that handle them.
            # It's in place simply for resilience as a guard against incorrect
            # usage, for anyone working directly with httpcore connections.
            raise RuntimeError(
                f"Attempted to send request to {request.url.origin} on connection "
                f"to {self._origin}"
            )

        async with self._state_lock:
            if self._state in (HTTPConnectionState.ACTIVE, HTTPConnectionState.IDLE):
                self._request_count += 1
                self._expire_at = None
                self._state = HTTPConnectionState.ACTIVE
            else:
                raise ConnectionNotAvailable()

        async with self._init_lock:
            if not self._sent_connection_init:
                kwargs = {"request": request}
                async with Trace("send_connection_init", logger, request, kwargs):
                    await self._send_connection_init(**kwargs)
                self._sent_connection_init = True

                # Initially start with just 1 until the remote server provides
                # its max_concurrent_streams value
                self._max_streams = 1

                local_settings_max_streams = (
                    self._h2_state.local_settings.max_concurrent_streams
                )
                self._max_streams_semaphore = AsyncSemaphore(local_settings_max_streams)

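                # The loop below drains all but `self._max_streams` permits from
                # the semaphore, so only one stream may be in flight until the
                # server's SETTINGS frame arrives. `_receive_remote_settings_change`
                # then releases (or re-acquires) permits to track the negotiated
                # MAX_CONCURRENT_STREAMS value.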

                for _ in range(local_settings_max_streams - self._max_streams):
                    await self._max_streams_semaphore.acquire()

        await self._max_streams_semaphore.acquire()

        try:
            stream_id = self._h2_state.get_next_available_stream_id()
            self._events[stream_id] = []
        except h2.exceptions.NoAvailableStreamIDError:  # pragma: nocover
            self._used_all_stream_ids = True
            raise ConnectionNotAvailable()

        try:
            kwargs = {"request": request, "stream_id": stream_id}
            async with Trace("send_request_headers", logger, request, kwargs):
                await self._send_request_headers(request=request, stream_id=stream_id)
            async with Trace("send_request_body", logger, request, kwargs):
                await self._send_request_body(request=request, stream_id=stream_id)
            async with Trace(
                "receive_response_headers", logger, request, kwargs
            ) as trace:
                status, headers = await self._receive_response(
                    request=request, stream_id=stream_id
                )
                trace.return_value = (status, headers)

            return Response(
                status=status,
                headers=headers,
                content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id),
                extensions={"stream_id": stream_id, "http_version": b"HTTP/2"},
            )
        except Exception as exc:  # noqa: PIE786
            kwargs = {"stream_id": stream_id}
            async with Trace("response_closed", logger, request, kwargs):
                await self._response_closed(stream_id=stream_id)

            if isinstance(exc, h2.exceptions.ProtocolError):
                # One case where h2 can raise a protocol error is when a
                # closed frame has been seen by the state machine.
                #
                # This happens when one stream is reading, and encounters
                # a GOAWAY event. Other flows of control may then raise
                # a protocol error at any point they interact with the 'h2_state'.
                #
                # In this case we'll have stored the event, and should raise
                # it as a RemoteProtocolError.
                if self._connection_error_event:
                    raise RemoteProtocolError(self._connection_error_event)
                # If h2 raises a protocol error in some other state then we
                # must somehow have made a protocol violation.
                raise LocalProtocolError(exc)  # pragma: nocover

            raise exc

    async def _send_connection_init(self, request: Request) -> None:
        """
        The HTTP/2 connection requires some initial setup before we can start
        using individual request/response streams on it.
        """
        # Need to set these manually here instead of manipulating via
        # __setitem__() otherwise the H2Connection will emit SettingsUpdate
        # frames in addition to sending the undesired defaults.
        self._h2_state.local_settings = h2.settings.Settings(
            client=True,
            initial_values={
                # Disable PUSH_PROMISE frames from the server since we don't do anything
                # with them for now. Maybe when we support caching?
                h2.settings.SettingCodes.ENABLE_PUSH: 0,
                # These two are taken from h2 for safe defaults
                h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 100,
                h2.settings.SettingCodes.MAX_HEADER_LIST_SIZE: 65536,
            },
        )

        # Some websites (*cough* Yahoo *cough*) balk at this setting being
        # present in the initial handshake since it's not defined in the original
        # RFC despite the RFC mandating ignoring settings you don't know about.
        del self._h2_state.local_settings[
            h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL
        ]

        self._h2_state.initiate_connection()
        self._h2_state.increment_flow_control_window(2**24)
        await self._write_outgoing_data(request)

    # Sending the request...

    async def _send_request_headers(self, request: Request, stream_id: int) -> None:
        end_stream = not has_body_headers(request)

        # In HTTP/2 the ':authority' pseudo-header is used instead of 'Host'.
        # In order to gracefully handle HTTP/1.1 and HTTP/2 we always require
        # HTTP/1.1 style headers, and map them appropriately if we end up on
        # an HTTP/2 connection.
        authority = [v for k, v in request.headers if k.lower() == b"host"][0]

        headers = [
            (b":method", request.method),
            (b":authority", authority),
            (b":scheme", request.url.scheme),
            (b":path", request.url.target),
        ] + [
            (k.lower(), v)
            for k, v in request.headers
            if k.lower()
            not in (
                b"host",
                b"transfer-encoding",
            )
        ]

        self._h2_state.send_headers(stream_id, headers, end_stream=end_stream)
        self._h2_state.increment_flow_control_window(2**24, stream_id=stream_id)
        await self._write_outgoing_data(request)

    async def _send_request_body(self, request: Request, stream_id: int) -> None:
        if not has_body_headers(request):
            return

        assert isinstance(request.stream, typing.AsyncIterable)
        async for data in request.stream:
            while data:
                max_flow = await self._wait_for_outgoing_flow(request, stream_id)
                chunk_size = min(len(data), max_flow)
                chunk, data = data[:chunk_size], data[chunk_size:]
                self._h2_state.send_data(stream_id, chunk)
                await self._write_outgoing_data(request)

        self._h2_state.end_stream(stream_id)
        await self._write_outgoing_data(request)

    # Receiving the response...

    async def _receive_response(
        self, request: Request, stream_id: int
    ) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]:
        while True:
            event = await self._receive_stream_event(request, stream_id)
            if isinstance(event, h2.events.ResponseReceived):
                break

        status_code = 200
        headers = []
        for k, v in event.headers:
            if k == b":status":
                status_code = int(v.decode("ascii", errors="ignore"))
            elif not k.startswith(b":"):
                headers.append((k, v))

        return (status_code, headers)

    async def _receive_response_body(
        self, request: Request, stream_id: int
    ) -> typing.AsyncIterator[bytes]:
        while True:
            event = await self._receive_stream_event(request, stream_id)
            if isinstance(event, h2.events.DataReceived):
                amount = event.flow_controlled_length
                self._h2_state.acknowledge_received_data(amount, stream_id)
                await self._write_outgoing_data(request)
                yield event.data
            elif isinstance(event, (h2.events.StreamEnded, h2.events.StreamReset)):
                break

    async def _receive_stream_event(
        self, request: Request, stream_id: int
    ) -> h2.events.Event:
        while not self._events.get(stream_id):
            await self._receive_events(request, stream_id)
        event = self._events[stream_id].pop(0)
        # The StreamReset event applies to a single stream.
        if hasattr(event, "error_code"):
            raise RemoteProtocolError(event)
        return event

    async def _receive_events(
        self, request: Request, stream_id: typing.Optional[int] = None
    ) -> None:
        async with self._read_lock:
            if self._connection_error_event is not None:  # pragma: nocover
                raise RemoteProtocolError(self._connection_error_event)

            # This conditional is a bit icky. We don't want to block reading if we've
            # actually got an event to return for a given stream. We need to do that
            # check *within* the atomic read lock. Though it also needs to be optional,
            # because when we call it from `_wait_for_outgoing_flow` we *do* want to
            # block until we have available flow control, even when we have events
            # pending for the stream ID we're attempting to send on.
            if stream_id is None or not self._events.get(stream_id):
                events = await self._read_incoming_data(request)
                for event in events:
                    if isinstance(event, h2.events.RemoteSettingsChanged):
                        async with Trace(
                            "receive_remote_settings", logger, request
                        ) as trace:
                            await self._receive_remote_settings_change(event)
                            trace.return_value = event

                    event_stream_id = getattr(event, "stream_id", 0)

                    # The ConnectionTerminatedEvent applies to the entire connection,
                    # and should be saved so it can be raised on all streams.
                    if hasattr(event, "error_code") and event_stream_id == 0:
                        self._connection_error_event = event
                        raise RemoteProtocolError(event)

                    if event_stream_id in self._events:
                        self._events[event_stream_id].append(event)

        await self._write_outgoing_data(request)

    async def _receive_remote_settings_change(self, event: h2.events.Event) -> None:
        max_concurrent_streams = event.changed_settings.get(
            h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS
        )
        if max_concurrent_streams:
            new_max_streams = min(
                max_concurrent_streams.new_value,
                self._h2_state.local_settings.max_concurrent_streams,
            )
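            # Resize the semaphore so that the number of available permits tracks
            # the effective concurrency limit: the server's advertised value,
            # capped by our own local MAX_CONCURRENT_STREAMS setting.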

            if new_max_streams and new_max_streams != self._max_streams:
                while new_max_streams > self._max_streams:
                    await self._max_streams_semaphore.release()
                    self._max_streams += 1
                while new_max_streams < self._max_streams:
                    await self._max_streams_semaphore.acquire()
                    self._max_streams -= 1

    async def _response_closed(self, stream_id: int) -> None:
        await self._max_streams_semaphore.release()
        del self._events[stream_id]
        async with self._state_lock:
            if self._state == HTTPConnectionState.ACTIVE and not self._events:
                self._state = HTTPConnectionState.IDLE
                if self._keepalive_expiry is not None:
                    now = time.monotonic()
                    self._expire_at = now + self._keepalive_expiry
                if self._used_all_stream_ids:  # pragma: nocover
                    await self.aclose()

    async def aclose(self) -> None:
        # Note that this method unilaterally closes the connection, and does
        # not have any kind of locking in place around it.
        self._h2_state.close_connection()
        self._state = HTTPConnectionState.CLOSED
        await self._network_stream.aclose()

    # Wrappers around network read/write operations...

    async def _read_incoming_data(
        self, request: Request
    ) -> typing.List[h2.events.Event]:
        timeouts = request.extensions.get("timeout", {})
        timeout = timeouts.get("read", None)

        if self._read_exception is not None:
            raise self._read_exception  # pragma: nocover

        try:
            data = await self._network_stream.read(self.READ_NUM_BYTES, timeout)
            if data == b"":
                raise RemoteProtocolError("Server disconnected")
        except Exception as exc:
            # If we get a network error we should:
            #
            # 1. Save the exception and just raise it immediately on any future reads.
            #    (For example, this means that a single read timeout or disconnect will
            #    immediately close all pending streams. Without requiring multiple
            #    sequential timeouts.)
            # 2. Mark the connection as errored, so that we don't accept any other
            #    incoming requests.
            self._read_exception = exc
            self._connection_error = True
            raise exc

        events: typing.List[h2.events.Event] = self._h2_state.receive_data(data)

        return events

    async def _write_outgoing_data(self, request: Request) -> None:
        timeouts = request.extensions.get("timeout", {})
        timeout = timeouts.get("write", None)

        async with self._write_lock:
            data_to_send = self._h2_state.data_to_send()

            if self._write_exception is not None:
                raise self._write_exception  # pragma: nocover

            try:
                await self._network_stream.write(data_to_send, timeout)
            except Exception as exc:  # pragma: nocover
                # If we get a network error we should:
                #
                # 1. Save the exception and just raise it immediately on any future write.
                #    (For example, this means that a single write timeout or disconnect will
                #    immediately close all pending streams. Without requiring multiple
                #    sequential timeouts.)
                # 2. Mark the connection as errored, so that we don't accept any other
                #    incoming requests.
                self._write_exception = exc
                self._connection_error = True
                raise exc

    # Flow control...

    async def _wait_for_outgoing_flow(self, request: Request, stream_id: int) -> int:
        """
        Returns the maximum allowable outgoing flow for a given stream.

        If the allowable flow is zero, then waits on the network until
        WindowUpdated frames have increased the flow rate.
        https://tools.ietf.org/html/rfc7540#section-6.9
        """

        local_flow: int = self._h2_state.local_flow_control_window(stream_id)
        max_frame_size: int = self._h2_state.max_outbound_frame_size
        flow = min(local_flow, max_frame_size)
        while flow == 0:
            await self._receive_events(request)
            local_flow = self._h2_state.local_flow_control_window(stream_id)
            max_frame_size = self._h2_state.max_outbound_frame_size
            flow = min(local_flow, max_frame_size)
        return flow

    # Interface for connection pooling...

    def can_handle_request(self, origin: Origin) -> bool:
        return origin == self._origin

    def is_available(self) -> bool:
        return (
            self._state != HTTPConnectionState.CLOSED
            and not self._connection_error
            and not self._used_all_stream_ids
            and not (
                self._h2_state.state_machine.state
                == h2.connection.ConnectionState.CLOSED
            )
        )

    def has_expired(self) -> bool:
        now = time.monotonic()
        return self._expire_at is not None and now > self._expire_at

    def is_idle(self) -> bool:
        return self._state == HTTPConnectionState.IDLE

    def is_closed(self) -> bool:
        return self._state == HTTPConnectionState.CLOSED

    def info(self) -> str:
        origin = str(self._origin)
        return (
            f"{origin!r}, HTTP/2, {self._state.name}, "
            f"Request Count: {self._request_count}"
        )

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        origin = str(self._origin)
        return (
            f"<{class_name} [{origin!r}, {self._state.name}, "
            f"Request Count: {self._request_count}]>"
        )

    # These context managers are not used in the standard flow, but are
    # useful for testing or working with connection instances directly.

    async def __aenter__(self) -> "AsyncHTTP2Connection":
        return self

    async def __aexit__(
        self,
        exc_type: typing.Optional[typing.Type[BaseException]] = None,
        exc_value: typing.Optional[BaseException] = None,
        traceback: typing.Optional[types.TracebackType] = None,
    ) -> None:
        await self.aclose()


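# Instances of the class below are handed out as the `content` of responses
# returned from `AsyncHTTP2Connection.handle_async_request()`. Iterating one
# pulls DATA frames for its stream on demand, and closing it releases the
# stream slot back to the connection.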

class HTTP2ConnectionByteStream:
    def __init__(
        self, connection: AsyncHTTP2Connection, request: Request, stream_id: int
    ) -> None:
        self._connection = connection
        self._request = request
        self._stream_id = stream_id
        self._closed = False

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        kwargs = {"request": self._request, "stream_id": self._stream_id}
        try:
            async with Trace("receive_response_body", logger, self._request, kwargs):
                async for chunk in self._connection._receive_response_body(
                    request=self._request, stream_id=self._stream_id
                ):
                    yield chunk
        except BaseException as exc:
            # If we get an exception while streaming the response,
            # we want to close the response (and possibly the connection)
            # before raising that exception.
            await self.aclose()
            raise exc

    async def aclose(self) -> None:
        if not self._closed:
            self._closed = True
            kwargs = {"stream_id": self._stream_id}
            async with Trace("response_closed", logger, self._request, kwargs):
                await self._connection._response_closed(stream_id=self._stream_id)
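

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the library source). The connection
# class above is normally driven by the connection pool, so a minimal way to
# exercise it end-to-end is via the public API with HTTP/2 enabled, run from a
# separate script. This assumes the `httpcore` AsyncConnectionPool API and
# `anyio` to run the coroutine:
#
#     import anyio
#     import httpcore
#
#     async def main() -> None:
#         async with httpcore.AsyncConnectionPool(http2=True) as pool:
#             response = await pool.request("GET", "https://example.org/")
#             print(response.status, response.extensions.get("http_version"))
#
#     anyio.run(main)
# ---------------------------------------------------------------------------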