Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/web_protocol.py: 18%

354 statements  

coverage.py v7.4.0, created at 2024-01-26 06:16 +0000

import asyncio
import asyncio.streams
import dataclasses
import traceback
from collections import deque
from contextlib import suppress
from html import escape as html_escape
from http import HTTPStatus
from logging import Logger
from typing import (
    TYPE_CHECKING,
    Any,
    Awaitable,
    Callable,
    Deque,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    cast,
)

import yarl

from .abc import AbstractAccessLogger, AbstractAsyncAccessLogger, AbstractStreamWriter
from .base_protocol import BaseProtocol
from .helpers import ceil_timeout
from .http import (
    HttpProcessingError,
    HttpRequestParser,
    HttpVersion10,
    RawRequestMessage,
    StreamWriter,
)
from .log import access_logger, server_logger
from .streams import EMPTY_PAYLOAD, StreamReader
from .tcp_helpers import tcp_keepalive
from .web_exceptions import HTTPException
from .web_log import AccessLogger
from .web_request import BaseRequest
from .web_response import Response, StreamResponse

__all__ = ("RequestHandler", "RequestPayloadError", "PayloadAccessError")

if TYPE_CHECKING:
    from .web_server import Server


_RequestFactory = Callable[
    [
        RawRequestMessage,
        StreamReader,
        "RequestHandler",
        AbstractStreamWriter,
        "asyncio.Task[None]",
    ],
    BaseRequest,
]

_RequestHandler = Callable[[BaseRequest], Awaitable[StreamResponse]]
_AnyAbstractAccessLogger = Union[
    Type[AbstractAsyncAccessLogger],
    Type[AbstractAccessLogger],
]
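
# A minimal sketch of a coroutine matching the _RequestHandler signature: any
# async callable that accepts a BaseRequest and returns a StreamResponse.  The
# name `example_handler` is illustrative only and is not part of aiohttp.
async def example_handler(request: BaseRequest) -> StreamResponse:
    return Response(text="Hello from " + request.path)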

ERROR = RawRequestMessage(
    "UNKNOWN",
    "/",
    HttpVersion10,
    {},  # type: ignore[arg-type]
    {},  # type: ignore[arg-type]
    True,
    None,
    False,
    False,
    yarl.URL("/"),
)


class RequestPayloadError(Exception):
    """Payload parsing error."""


class PayloadAccessError(Exception):
    """Payload was accessed after response was sent."""


class AccessLoggerWrapper(AbstractAsyncAccessLogger):
    """Wrap an AbstractAccessLogger so it behaves like an AbstractAsyncAccessLogger."""

    def __init__(
        self, access_logger: AbstractAccessLogger, loop: asyncio.AbstractEventLoop
    ) -> None:
        self.access_logger = access_logger
        self._loop = loop
        super().__init__()

    async def log(
        self, request: BaseRequest, response: StreamResponse, request_start: float
    ) -> None:
        self.access_logger.log(request, response, self._loop.time() - request_start)
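
# A minimal sketch of the synchronous access logger that AccessLoggerWrapper
# adapts.  The class name ExampleAccessLogger is illustrative and not part of
# aiohttp; it assumes AbstractAccessLogger exposes the configured logger as
# `self.logger`, as documented.
class ExampleAccessLogger(AbstractAccessLogger):
    def log(
        self, request: BaseRequest, response: StreamResponse, time: float
    ) -> None:
        # Called by AccessLoggerWrapper.log() with the elapsed time computed
        # from the event-loop clock.
        self.logger.info(
            "%s %s -> %s (%.6f s)", request.method, request.path, response.status, time
        )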

@dataclasses.dataclass(frozen=True)
class _ErrInfo:
    status: int
    exc: BaseException
    message: str


_MsgType = Tuple[Union[RawRequestMessage, _ErrInfo], StreamReader]
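
# Descriptive note: each queued _MsgType pair is either a successfully parsed
# (RawRequestMessage, StreamReader) or, when parsing fails, the substitute
# (_ErrInfo(status=400, ...), EMPTY_PAYLOAD) produced by data_received(), so
# errors travel through the same request pipeline as normal messages.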


class RequestHandler(BaseProtocol):
    """HTTP protocol implementation.

    RequestHandler handles incoming HTTP requests. It reads the request line,
    request headers and request payload, then calls the handle_request()
    method. By default it always returns a 404 response.

    RequestHandler handles errors in the incoming request, such as a bad
    status line, bad headers or an incomplete payload. If any error occurs,
    the connection gets closed.

    keepalive_timeout -- number of seconds before closing
                         keep-alive connection

    tcp_keepalive -- TCP keep-alive is on, default is on

    logger -- custom logger object

    access_log_class -- custom class for access_logger

    access_log -- custom logging object

    access_log_format -- access log format string

    loop -- Optional event loop

    max_line_size -- Optional maximum header line size

    max_field_size -- Optional maximum header field size

    timeout_ceil_threshold -- Optional value to specify
                              threshold to ceil() timeout
                              values

    """

    KEEPALIVE_RESCHEDULE_DELAY = 1

    __slots__ = (
        "_request_count",
        "_keepalive",
        "_manager",
        "_request_handler",
        "_request_factory",
        "_tcp_keepalive",
        "_keepalive_time",
        "_keepalive_handle",
        "_keepalive_timeout",
        "_lingering_time",
        "_messages",
        "_message_tail",
        "_waiter",
        "_task_handler",
        "_upgrade",
        "_payload_parser",
        "_request_parser",
        "logger",
        "access_log",
        "access_logger",
        "_close",
        "_force_close",
        "_current_request",
        "_timeout_ceil_threshold",
    )

    def __init__(
        self,
        manager: "Server",
        *,
        loop: asyncio.AbstractEventLoop,
        keepalive_timeout: float = 75.0,  # NGINX default is 75 secs
        tcp_keepalive: bool = True,
        logger: Logger = server_logger,
        access_log_class: _AnyAbstractAccessLogger = AccessLogger,
        access_log: Optional[Logger] = access_logger,
        access_log_format: str = AccessLogger.LOG_FORMAT,
        max_line_size: int = 8190,
        max_field_size: int = 8190,
        lingering_time: float = 10.0,
        read_bufsize: int = 2**16,
        auto_decompress: bool = True,
        timeout_ceil_threshold: float = 5,
    ):
        super().__init__(loop)

        self._request_count = 0
        self._keepalive = False
        self._current_request: Optional[BaseRequest] = None
        self._manager: Optional[Server] = manager
        self._request_handler: Optional[_RequestHandler] = manager.request_handler
        self._request_factory: Optional[_RequestFactory] = manager.request_factory

        self._tcp_keepalive = tcp_keepalive
        # placeholder to be replaced on keepalive timeout setup
        self._keepalive_time = 0.0
        self._keepalive_handle: Optional[asyncio.Handle] = None
        self._keepalive_timeout = keepalive_timeout
        self._lingering_time = float(lingering_time)

        self._messages: Deque[_MsgType] = deque()
        self._message_tail = b""

        self._waiter: Optional[asyncio.Future[None]] = None
        self._task_handler: Optional[asyncio.Task[None]] = None

        self._upgrade = False
        self._payload_parser: Any = None
        self._request_parser: Optional[HttpRequestParser] = HttpRequestParser(
            self,
            loop,
            read_bufsize,
            max_line_size=max_line_size,
            max_field_size=max_field_size,
            payload_exception=RequestPayloadError,
            auto_decompress=auto_decompress,
        )

        self._timeout_ceil_threshold: float = 5
        try:
            self._timeout_ceil_threshold = float(timeout_ceil_threshold)
        except (TypeError, ValueError):
            pass

        self.logger = logger
        self.access_log = access_log
        if access_log:
            if issubclass(access_log_class, AbstractAsyncAccessLogger):
                self.access_logger: Optional[
                    AbstractAsyncAccessLogger
                ] = access_log_class()
            else:
                access_logger = access_log_class(access_log, access_log_format)
                self.access_logger = AccessLoggerWrapper(
                    access_logger,
                    self._loop,
                )
        else:
            self.access_logger = None

        self._close = False
        self._force_close = False

    def __repr__(self) -> str:
        return "<{} {}>".format(
            self.__class__.__name__,
            "connected" if self.transport is not None else "disconnected",
        )

    @property
    def keepalive_timeout(self) -> float:
        return self._keepalive_timeout

    async def shutdown(self, timeout: Optional[float] = 15.0) -> None:
        """Do worker process exit preparations.

        We need to clean up everything and stop accepting requests.
        It is especially important for keep-alive connections.
        """
        self._force_close = True

        if self._keepalive_handle is not None:
            self._keepalive_handle.cancel()

        if self._waiter:
            self._waiter.cancel()

        # wait for handlers
        with suppress(asyncio.CancelledError, asyncio.TimeoutError):
            async with ceil_timeout(timeout):
                if self._current_request is not None:
                    self._current_request._cancel(asyncio.CancelledError())

                if self._task_handler is not None and not self._task_handler.done():
                    await self._task_handler

        # force-close non-idle handler
        if self._task_handler is not None:
            self._task_handler.cancel()

        if self.transport is not None:
            self.transport.close()
            self.transport = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        super().connection_made(transport)

        real_transport = cast(asyncio.Transport, transport)
        if self._tcp_keepalive:
            tcp_keepalive(real_transport)

        self._task_handler = self._loop.create_task(self.start())
        assert self._manager is not None
        self._manager.connection_made(self, real_transport)

    def connection_lost(self, exc: Optional[BaseException]) -> None:
        if self._manager is None:
            return
        self._manager.connection_lost(self, exc)

        super().connection_lost(exc)

        # Grab value before setting _manager to None.
        handler_cancellation = self._manager.handler_cancellation

        self._manager = None
        self._force_close = True
        self._request_factory = None
        self._request_handler = None
        self._request_parser = None

        if self._keepalive_handle is not None:
            self._keepalive_handle.cancel()

        if self._current_request is not None:
            if exc is None:
                exc = ConnectionResetError("Connection lost")
            self._current_request._cancel(exc)

        if self._waiter is not None:
            self._waiter.cancel()

        if handler_cancellation and self._task_handler is not None:
            self._task_handler.cancel()

        self._task_handler = None

        if self._payload_parser is not None:
            self._payload_parser.feed_eof()
            self._payload_parser = None

    def set_parser(self, parser: Any) -> None:
        # Actual type is WebReader
        assert self._payload_parser is None

        self._payload_parser = parser

        if self._message_tail:
            self._payload_parser.feed_data(self._message_tail)
            self._message_tail = b""

    def eof_received(self) -> None:
        pass

    def data_received(self, data: bytes) -> None:
        if self._force_close or self._close:
            return
        # parse http messages
        messages: Sequence[_MsgType]
        if self._payload_parser is None and not self._upgrade:
            assert self._request_parser is not None
            try:
                messages, upgraded, tail = self._request_parser.feed_data(data)
            except HttpProcessingError as exc:
                messages = [
                    (_ErrInfo(status=400, exc=exc, message=exc.message), EMPTY_PAYLOAD)
                ]
                upgraded = False
                tail = b""

            for msg, payload in messages or ():
                self._request_count += 1
                self._messages.append((msg, payload))

            waiter = self._waiter
            if messages and waiter is not None and not waiter.done():
                # don't set result twice
                waiter.set_result(None)

            self._upgrade = upgraded
            if upgraded and tail:
                self._message_tail = tail

        # no parser, just store
        elif self._payload_parser is None and self._upgrade and data:
            self._message_tail += data

        # feed payload
        elif data:
            eof, tail = self._payload_parser.feed_data(data)
            if eof:
                self.close()
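
    # Descriptive note on the upgrade path: once the request parser reports an
    # upgrade (e.g. to WebSocket), data_received() stops parsing and buffers
    # raw bytes in self._message_tail; a later set_parser() call feeds that
    # tail into the newly installed payload parser.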

396 

397 def keep_alive(self, val: bool) -> None: 

398 """Set keep-alive connection mode. 

399 

400 :param bool val: new state. 

401 """ 

402 self._keepalive = val 

403 if self._keepalive_handle: 

404 self._keepalive_handle.cancel() 

405 self._keepalive_handle = None 

406 

407 def close(self) -> None: 

408 """Close connection. 

409 

410 Stop accepting new pipelining messages and close 

411 connection when handlers done processing messages. 

412 """ 

413 self._close = True 

414 if self._waiter: 

415 self._waiter.cancel() 

416 

417 def force_close(self) -> None: 

418 """Forcefully close connection.""" 

419 self._force_close = True 

420 if self._waiter: 

421 self._waiter.cancel() 

422 if self.transport is not None: 

423 self.transport.close() 

424 self.transport = None 

425 

426 async def log_access( 

427 self, request: BaseRequest, response: StreamResponse, request_start: float 

428 ) -> None: 

429 if self.access_logger is not None: 

430 await self.access_logger.log(request, response, request_start) 

431 

432 def log_debug(self, *args: Any, **kw: Any) -> None: 

433 if self._loop.get_debug(): 

434 self.logger.debug(*args, **kw) 

435 

436 def log_exception(self, *args: Any, **kw: Any) -> None: 

437 self.logger.exception(*args, **kw) 

438 

439 def _process_keepalive(self) -> None: 

440 if self._force_close or not self._keepalive: 

441 return 

442 

443 next = self._keepalive_time + self._keepalive_timeout 

444 

445 # handler in idle state 

446 if self._waiter: 

447 if self._loop.time() > next: 

448 self.force_close() 

449 return 

450 

451 # not all request handlers are done, 

452 # reschedule itself to next second 

453 self._keepalive_handle = self._loop.call_later( 

454 self.KEEPALIVE_RESCHEDULE_DELAY, 

455 self._process_keepalive, 

456 ) 

457 

    async def _handle_request(
        self,
        request: BaseRequest,
        start_time: float,
        request_handler: Callable[[BaseRequest], Awaitable[StreamResponse]],
    ) -> Tuple[StreamResponse, bool]:
        assert self._request_handler is not None
        try:
            try:
                self._current_request = request
                resp = await request_handler(request)
            finally:
                self._current_request = None
        except HTTPException as exc:
            resp = Response(
                status=exc.status, reason=exc.reason, text=exc.text, headers=exc.headers
            )
            resp._cookies = exc._cookies
            reset = await self.finish_response(request, resp, start_time)
        except asyncio.CancelledError:
            raise
        except asyncio.TimeoutError as exc:
            self.log_debug("Request handler timed out.", exc_info=exc)
            resp = self.handle_error(request, 504)
            reset = await self.finish_response(request, resp, start_time)
        except Exception as exc:
            resp = self.handle_error(request, 500, exc)
            reset = await self.finish_response(request, resp, start_time)
        else:
            reset = await self.finish_response(request, resp, start_time)

        return resp, reset

    async def start(self) -> None:
        """Process incoming request.

        It reads the request line, request headers and request payload, then
        calls the handle_request() method. Subclasses have to override
        handle_request(). start() handles various exceptions in request
        or response handling. The connection is always closed unless
        keep_alive(True) is specified.
        """
        loop = self._loop
        handler = self._task_handler
        assert handler is not None
        manager = self._manager
        assert manager is not None
        keepalive_timeout = self._keepalive_timeout
        resp = None
        assert self._request_factory is not None
        assert self._request_handler is not None

        while not self._force_close:
            if not self._messages:
                try:
                    # wait for next request
                    self._waiter = loop.create_future()
                    await self._waiter
                except asyncio.CancelledError:
                    break
                finally:
                    self._waiter = None

            message, payload = self._messages.popleft()

            start = loop.time()

            manager.requests_count += 1
            writer = StreamWriter(self, loop)
            if isinstance(message, _ErrInfo):
                # make request_factory work
                request_handler = self._make_error_handler(message)
                message = ERROR
            else:
                request_handler = self._request_handler

            request = self._request_factory(message, payload, self, writer, handler)
            try:
                # a new task is used to copy context vars (#3406)
                task = self._loop.create_task(
                    self._handle_request(request, start, request_handler)
                )
                try:
                    resp, reset = await task
                except (asyncio.CancelledError, ConnectionError):
                    self.log_debug("Ignored premature client disconnection")
                    break

                # Drop the processed task from asyncio.Task.all_tasks() early
                del task
                # https://github.com/python/mypy/issues/14309
                if reset:  # type: ignore[possibly-undefined]
                    self.log_debug("Ignored premature client disconnection 2")
                    break

                # notify server about keep-alive
                self._keepalive = bool(resp.keep_alive)

                # check payload
                if not payload.is_eof():
                    lingering_time = self._lingering_time
                    # Could be force closed while awaiting above tasks.
                    if not self._force_close and lingering_time:  # type: ignore[redundant-expr]
                        self.log_debug(
                            "Start lingering close timer for %s sec.", lingering_time
                        )

                        now = loop.time()
                        end_t = now + lingering_time

                        with suppress(asyncio.TimeoutError, asyncio.CancelledError):
                            while not payload.is_eof() and now < end_t:
                                async with ceil_timeout(end_t - now):
                                    # read and ignore
                                    await payload.readany()
                                now = loop.time()

                    # if payload still uncompleted
                    if not payload.is_eof() and not self._force_close:
                        self.log_debug("Uncompleted request.")
                        self.close()

                payload.set_exception(PayloadAccessError())

            except asyncio.CancelledError:
                self.log_debug("Ignored premature client disconnection ")
                break
            except RuntimeError as exc:
                if self._loop.get_debug():
                    self.log_exception("Unhandled runtime exception", exc_info=exc)
                self.force_close()
            except Exception as exc:
                self.log_exception("Unhandled exception", exc_info=exc)
                self.force_close()
            finally:
                if self.transport is None and resp is not None:
                    self.log_debug("Ignored premature client disconnection.")
                elif not self._force_close:
                    if self._keepalive and not self._close:
                        # start keep-alive timer
                        if keepalive_timeout is not None:
                            now = self._loop.time()
                            self._keepalive_time = now
                            if self._keepalive_handle is None:
                                self._keepalive_handle = loop.call_at(
                                    now + keepalive_timeout, self._process_keepalive
                                )
                    else:
                        break

        # remove handler, close transport if no handlers left
        if not self._force_close:
            self._task_handler = None
            if self.transport is not None:
                self.transport.close()

    async def finish_response(
        self, request: BaseRequest, resp: StreamResponse, start_time: float
    ) -> bool:
        """Prepare the response and write_eof, then log access.

        This has to be called within the context of any exception so the
        access logger can get exception information. Returns True if the
        client disconnects prematurely.
        """
        request._finish()
        if self._request_parser is not None:
            self._request_parser.set_upgraded(False)
            self._upgrade = False
            if self._message_tail:
                self._request_parser.feed_data(self._message_tail)
                self._message_tail = b""
        try:
            prepare_meth = resp.prepare
        except AttributeError:
            if resp is None:
                raise RuntimeError("Missing return statement on request handler")
            else:
                raise RuntimeError(
                    "Web-handler should return a response instance, "
                    "got {!r}".format(resp)
                )
        try:
            await prepare_meth(request)
            await resp.write_eof()
        except ConnectionError:
            await self.log_access(request, resp, start_time)
            return True
        else:
            await self.log_access(request, resp, start_time)
            return False
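
    # A hedged illustration of the AttributeError branch above: a handler that
    # never returns a response leaves resp as None, so finish_response() raises
    # "Missing return statement on request handler".  The name broken_handler
    # below is illustrative only:
    #
    #     async def broken_handler(request: BaseRequest) -> StreamResponse:
    #         Response(text="oops")  # built but never returned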

651 

652 def handle_error( 

653 self, 

654 request: BaseRequest, 

655 status: int = 500, 

656 exc: Optional[BaseException] = None, 

657 message: Optional[str] = None, 

658 ) -> StreamResponse: 

659 """Handle errors. 

660 

661 Returns HTTP response with specific status code. Logs additional 

662 information. It always closes current connection. 

663 """ 

664 self.log_exception("Error handling request", exc_info=exc) 

665 

666 # some data already got sent, connection is broken 

667 if request.writer.output_size > 0: 

668 raise ConnectionError( 

669 "Response is sent already, cannot send another response " 

670 "with the error message" 

671 ) 

672 

673 ct = "text/plain" 

674 if status == HTTPStatus.INTERNAL_SERVER_ERROR: 

675 title = "{0.value} {0.phrase}".format(HTTPStatus.INTERNAL_SERVER_ERROR) 

676 msg = HTTPStatus.INTERNAL_SERVER_ERROR.description 

677 tb = None 

678 if self._loop.get_debug(): 

679 with suppress(Exception): 

680 tb = traceback.format_exc() 

681 

682 if "text/html" in request.headers.get("Accept", ""): 

683 if tb: 

684 tb = html_escape(tb) 

685 msg = f"<h2>Traceback:</h2>\n<pre>{tb}</pre>" 

686 message = ( 

687 "<html><head>" 

688 "<title>{title}</title>" 

689 "</head><body>\n<h1>{title}</h1>" 

690 "\n{msg}\n</body></html>\n" 

691 ).format(title=title, msg=msg) 

692 ct = "text/html" 

693 else: 

694 if tb: 

695 msg = tb 

696 message = title + "\n\n" + msg 

697 

698 resp = Response(status=status, text=message, content_type=ct) 

699 resp.force_close() 

700 

701 return resp 

702 

    def _make_error_handler(
        self, err_info: _ErrInfo
    ) -> Callable[[BaseRequest], Awaitable[StreamResponse]]:
        async def handler(request: BaseRequest) -> StreamResponse:
            return self.handle_error(
                request, err_info.status, err_info.exc, err_info.message
            )

        return handler
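

# A minimal, self-contained usage sketch (illustrative only, not part of this
# module): the public aiohttp API below is what ultimately creates one
# RequestHandler per client connection and drives the machinery defined above.
if __name__ == "__main__":  # pragma: no cover
    from aiohttp import web

    async def hello(request: web.Request) -> web.Response:
        # Every request reaching this handler was parsed and dispatched by a
        # RequestHandler instance bound to the client's connection.
        return web.Response(text="Hello, world")

    app = web.Application()
    app.router.add_get("/", hello)
    web.run_app(app)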