Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/web_protocol.py: 18%

352 statements  

coverage.py v7.2.7, created at 2023-06-07 06:52 +0000

import asyncio
import asyncio.streams
import dataclasses
import traceback
from collections import deque
from contextlib import suppress
from html import escape as html_escape
from http import HTTPStatus
from logging import Logger
from typing import (
    TYPE_CHECKING,
    Any,
    Awaitable,
    Callable,
    Deque,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    cast,
)

import yarl

from .abc import AbstractAccessLogger, AbstractAsyncAccessLogger, AbstractStreamWriter
from .base_protocol import BaseProtocol
from .helpers import ceil_timeout
from .http import (
    HttpProcessingError,
    HttpRequestParser,
    HttpVersion10,
    RawRequestMessage,
    StreamWriter,
)
from .log import access_logger, server_logger
from .streams import EMPTY_PAYLOAD, StreamReader
from .tcp_helpers import tcp_keepalive
from .web_exceptions import HTTPException
from .web_log import AccessLogger
from .web_request import BaseRequest
from .web_response import Response, StreamResponse

__all__ = ("RequestHandler", "RequestPayloadError", "PayloadAccessError")

if TYPE_CHECKING:  # pragma: no cover
    from .web_server import Server


_RequestFactory = Callable[
    [
        RawRequestMessage,
        StreamReader,
        "RequestHandler",
        AbstractStreamWriter,
        "asyncio.Task[None]",
    ],
    BaseRequest,
]

_RequestHandler = Callable[[BaseRequest], Awaitable[StreamResponse]]
_AnyAbstractAccessLogger = Union[
    Type[AbstractAsyncAccessLogger],
    Type[AbstractAccessLogger],
]

ERROR = RawRequestMessage(
    "UNKNOWN",
    "/",
    HttpVersion10,
    {},  # type: ignore[arg-type]
    {},  # type: ignore[arg-type]
    True,
    None,
    False,
    False,
    yarl.URL("/"),
)


class RequestPayloadError(Exception):
    """Payload parsing error."""


class PayloadAccessError(Exception):
    """Payload was accessed after response was sent."""


class AccessLoggerWrapper(AbstractAsyncAccessLogger):
    """Wrap an AbstractAccessLogger so it behaves like an AbstractAsyncAccessLogger."""

    def __init__(
        self, access_logger: AbstractAccessLogger, loop: asyncio.AbstractEventLoop
    ) -> None:
        self.access_logger = access_logger
        self._loop = loop
        super().__init__()

    async def log(
        self, request: BaseRequest, response: StreamResponse, request_start: float
    ) -> None:
        self.access_logger.log(request, response, self._loop.time() - request_start)

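
# Illustrative sketch (not part of this module; the subclass name is
# hypothetical): a synchronous AbstractAccessLogger of the kind that
# RequestHandler wraps with AccessLoggerWrapper when access_log_class is not
# already an AbstractAsyncAccessLogger.
#
#     class PlainAccessLogger(AbstractAccessLogger):
#         def log(self, request, response, time):
#             # AbstractAccessLogger.__init__ stores the target logger as self.logger
#             self.logger.info(
#                 "%s %s -> %s (%.4fs)",
#                 request.method, request.path, response.status, time,
#             )
#
#     # wrapper.log() is awaitable; it forwards to the synchronous logger with
#     # the elapsed time computed as loop.time() - request_start.
#     wrapper = AccessLoggerWrapper(
#         PlainAccessLogger(access_logger, ""), asyncio.get_event_loop()
#     )
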

@dataclasses.dataclass(frozen=True)
class _ErrInfo:
    status: int
    exc: BaseException
    message: str


_MsgType = Tuple[Union[RawRequestMessage, _ErrInfo], StreamReader]


class RequestHandler(BaseProtocol):
    """HTTP protocol implementation.

    RequestHandler handles incoming HTTP requests. It reads the request line,
    request headers and request payload, then calls handle_request().
    By default it always returns a 404 response.

    RequestHandler handles errors in the incoming request, such as a bad
    status line, bad headers or an incomplete payload. If any error occurs,
    the connection gets closed.

    keepalive_timeout -- number of seconds before closing
                         keep-alive connection

    tcp_keepalive -- TCP keep-alive is on, default is on

    logger -- custom logger object

    access_log_class -- custom class for access_logger

    access_log -- custom logging object

    access_log_format -- access log format string

    loop -- Optional event loop

    max_line_size -- Optional maximum header line size

    max_field_size -- Optional maximum header field size

    timeout_ceil_threshold -- Optional value to specify
                              threshold to ceil() timeout
                              values

    """

    KEEPALIVE_RESCHEDULE_DELAY = 1

    __slots__ = (
        "_request_count",
        "_keepalive",
        "_manager",
        "_request_handler",
        "_request_factory",
        "_tcp_keepalive",
        "_keepalive_time",
        "_keepalive_handle",
        "_keepalive_timeout",
        "_lingering_time",
        "_messages",
        "_message_tail",
        "_waiter",
        "_task_handler",
        "_upgrade",
        "_payload_parser",
        "_request_parser",
        "logger",
        "access_log",
        "access_logger",
        "_close",
        "_force_close",
        "_current_request",
        "_timeout_ceil_threshold",
    )

    def __init__(
        self,
        manager: "Server",
        *,
        loop: asyncio.AbstractEventLoop,
        keepalive_timeout: float = 75.0,  # NGINX default is 75 secs
        tcp_keepalive: bool = True,
        logger: Logger = server_logger,
        access_log_class: _AnyAbstractAccessLogger = AccessLogger,
        access_log: Logger = access_logger,
        access_log_format: str = AccessLogger.LOG_FORMAT,
        max_line_size: int = 8190,
        max_field_size: int = 8190,
        lingering_time: float = 10.0,
        read_bufsize: int = 2**16,
        auto_decompress: bool = True,
        timeout_ceil_threshold: float = 5,
    ):
        super().__init__(loop)

        self._request_count = 0
        self._keepalive = False
        self._current_request: Optional[BaseRequest] = None
        self._manager: Optional[Server] = manager
        self._request_handler: Optional[_RequestHandler] = manager.request_handler
        self._request_factory: Optional[_RequestFactory] = manager.request_factory

        self._tcp_keepalive = tcp_keepalive
        # placeholder to be replaced on keepalive timeout setup
        self._keepalive_time = 0.0
        self._keepalive_handle: Optional[asyncio.Handle] = None
        self._keepalive_timeout = keepalive_timeout
        self._lingering_time = float(lingering_time)

        self._messages: Deque[_MsgType] = deque()
        self._message_tail = b""

        self._waiter: Optional[asyncio.Future[None]] = None
        self._task_handler: Optional[asyncio.Task[None]] = None

        self._upgrade = False
        self._payload_parser: Any = None
        self._request_parser: Optional[HttpRequestParser] = HttpRequestParser(
            self,
            loop,
            read_bufsize,
            max_line_size=max_line_size,
            max_field_size=max_field_size,
            payload_exception=RequestPayloadError,
            auto_decompress=auto_decompress,
        )

        self._timeout_ceil_threshold: float = 5
        try:
            self._timeout_ceil_threshold = float(timeout_ceil_threshold)
        except (TypeError, ValueError):
            pass

        self.logger = logger
        self.access_log = access_log
        if access_log:
            if issubclass(access_log_class, AbstractAsyncAccessLogger):
                self.access_logger: Optional[
                    AbstractAsyncAccessLogger
                ] = access_log_class()
            else:
                access_logger = access_log_class(access_log, access_log_format)
                self.access_logger = AccessLoggerWrapper(
                    access_logger,
                    self._loop,
                )
        else:
            self.access_logger = None

        self._close = False
        self._force_close = False

    def __repr__(self) -> str:
        return "<{} {}>".format(
            self.__class__.__name__,
            "connected" if self.transport is not None else "disconnected",
        )

    @property
    def keepalive_timeout(self) -> float:
        return self._keepalive_timeout

    async def shutdown(self, timeout: Optional[float] = 15.0) -> None:
        """Do worker process exit preparations.

        We need to clean up everything and stop accepting requests.
        It is especially important for keep-alive connections.
        """
        self._force_close = True

        if self._keepalive_handle is not None:
            self._keepalive_handle.cancel()

        if self._waiter:
            self._waiter.cancel()

        # wait for handlers
        with suppress(asyncio.CancelledError, asyncio.TimeoutError):
            async with ceil_timeout(timeout):
                if self._current_request is not None:
                    self._current_request._cancel(asyncio.CancelledError())

                if self._task_handler is not None and not self._task_handler.done():
                    await self._task_handler

        # force-close non-idle handler
        if self._task_handler is not None:
            self._task_handler.cancel()

        if self.transport is not None:
            self.transport.close()
            self.transport = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        super().connection_made(transport)

        real_transport = cast(asyncio.Transport, transport)
        if self._tcp_keepalive:
            tcp_keepalive(real_transport)

        self._task_handler = self._loop.create_task(self.start())
        assert self._manager is not None
        self._manager.connection_made(self, real_transport)

    def connection_lost(self, exc: Optional[BaseException]) -> None:
        if self._manager is None:
            return
        self._manager.connection_lost(self, exc)

        super().connection_lost(exc)

        # Grab value before setting _manager to None.
        handler_cancellation = self._manager.handler_cancellation

        self._manager = None
        self._force_close = True
        self._request_factory = None
        self._request_handler = None
        self._request_parser = None

        if self._keepalive_handle is not None:
            self._keepalive_handle.cancel()

        if self._current_request is not None:
            if exc is None:
                exc = ConnectionResetError("Connection lost")
            self._current_request._cancel(exc)

        if self._waiter is not None:
            self._waiter.cancel()

        if handler_cancellation and self._task_handler is not None:
            self._task_handler.cancel()

        self._task_handler = None

        if self._payload_parser is not None:
            self._payload_parser.feed_eof()
            self._payload_parser = None

    def set_parser(self, parser: Any) -> None:
        # Actual type is WebReader
        assert self._payload_parser is None

        self._payload_parser = parser

        if self._message_tail:
            self._payload_parser.feed_data(self._message_tail)
            self._message_tail = b""

    def eof_received(self) -> None:
        pass

    def data_received(self, data: bytes) -> None:
        if self._force_close or self._close:
            return
        # parse http messages
        messages: Sequence[_MsgType]
        if self._payload_parser is None and not self._upgrade:
            assert self._request_parser is not None
            try:
                messages, upgraded, tail = self._request_parser.feed_data(data)
            except HttpProcessingError as exc:
                messages = [
                    (_ErrInfo(status=400, exc=exc, message=exc.message), EMPTY_PAYLOAD)
                ]
                upgraded = False
                tail = b""

            for msg, payload in messages or ():
                self._request_count += 1
                self._messages.append((msg, payload))

            waiter = self._waiter
            if messages and waiter is not None and not waiter.done():
                # don't set result twice
                waiter.set_result(None)

            self._upgrade = upgraded
            if upgraded and tail:
                self._message_tail = tail

        # no parser, just store
        elif self._payload_parser is None and self._upgrade and data:
            self._message_tail += data

        # feed payload
        elif data:
            eof, tail = self._payload_parser.feed_data(data)
            if eof:
                self.close()

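
    # Illustrative note (assumed traffic, for clarity): if a client pipelines
    # two complete requests in a single TCP segment, feed_data() above returns
    # both (message, payload) pairs in one call; data_received() appends both
    # to self._messages, wakes the idle waiter in start() once, and start()
    # then drains the deque one request at a time.
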

    def keep_alive(self, val: bool) -> None:
        """Set keep-alive connection mode.

        :param bool val: new state.
        """
        self._keepalive = val
        if self._keepalive_handle:
            self._keepalive_handle.cancel()
            self._keepalive_handle = None

    def close(self) -> None:
        """Close connection.

        Stop accepting new pipelined messages and close the
        connection once handlers are done processing messages.
        """
        self._close = True
        if self._waiter:
            self._waiter.cancel()

    def force_close(self) -> None:
        """Forcefully close connection."""
        self._force_close = True
        if self._waiter:
            self._waiter.cancel()
        if self.transport is not None:
            self.transport.close()
            self.transport = None

    async def log_access(
        self, request: BaseRequest, response: StreamResponse, request_start: float
    ) -> None:
        if self.access_logger is not None:
            await self.access_logger.log(request, response, request_start)

    def log_debug(self, *args: Any, **kw: Any) -> None:
        if self._loop.get_debug():
            self.logger.debug(*args, **kw)

    def log_exception(self, *args: Any, **kw: Any) -> None:
        self.logger.exception(*args, **kw)

    def _process_keepalive(self) -> None:
        if self._force_close or not self._keepalive:
            return

        next = self._keepalive_time + self._keepalive_timeout

        # handler in idle state
        if self._waiter:
            if self._loop.time() > next:
                self.force_close()
                return

        # not all request handlers are done,
        # reschedule itself to next second
        self._keepalive_handle = self._loop.call_later(
            self.KEEPALIVE_RESCHEDULE_DELAY,
            self._process_keepalive,
        )

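
    # Worked example with assumed numbers for _process_keepalive() above: if
    # the handler went idle at _keepalive_time = 1000.0 with
    # keepalive_timeout = 75.0, the connection is force-closed the first time
    # the callback runs while idle with loop.time() > 1075.0; otherwise the
    # check is simply re-armed after KEEPALIVE_RESCHEDULE_DELAY (1 second).
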

    async def _handle_request(
        self,
        request: BaseRequest,
        start_time: float,
        request_handler: Callable[[BaseRequest], Awaitable[StreamResponse]],
    ) -> Tuple[StreamResponse, bool]:
        assert self._request_handler is not None
        try:
            try:
                self._current_request = request
                resp = await request_handler(request)
            finally:
                self._current_request = None
        except HTTPException as exc:
            resp = Response(
                status=exc.status, reason=exc.reason, text=exc.text, headers=exc.headers
            )
            resp._cookies = exc._cookies
            reset = await self.finish_response(request, resp, start_time)
        except asyncio.CancelledError:
            raise
        except asyncio.TimeoutError as exc:
            self.log_debug("Request handler timed out.", exc_info=exc)
            resp = self.handle_error(request, 504)
            reset = await self.finish_response(request, resp, start_time)
        except Exception as exc:
            resp = self.handle_error(request, 500, exc)
            reset = await self.finish_response(request, resp, start_time)
        else:
            reset = await self.finish_response(request, resp, start_time)

        return resp, reset

    async def start(self) -> None:
        """Process incoming request.

        It reads the request line, request headers and request payload, then
        calls handle_request(). A subclass has to override handle_request().
        start() handles various exceptions in request or response handling.
        The connection is always closed unless keep_alive(True) is specified.
        """
        loop = self._loop
        handler = self._task_handler
        assert handler is not None
        manager = self._manager
        assert manager is not None
        keepalive_timeout = self._keepalive_timeout
        resp = None
        assert self._request_factory is not None
        assert self._request_handler is not None

        while not self._force_close:
            if not self._messages:
                try:
                    # wait for next request
                    self._waiter = loop.create_future()
                    await self._waiter
                except asyncio.CancelledError:
                    break
                finally:
                    self._waiter = None

            message, payload = self._messages.popleft()

            start = loop.time()

            manager.requests_count += 1
            writer = StreamWriter(self, loop)
            if isinstance(message, _ErrInfo):
                # make request_factory work
                request_handler = self._make_error_handler(message)
                message = ERROR
            else:
                request_handler = self._request_handler

            request = self._request_factory(message, payload, self, writer, handler)
            try:
                # a new task is used to copy context vars (#3406)
                task = self._loop.create_task(
                    self._handle_request(request, start, request_handler)
                )
                try:
                    resp, reset = await task
                except (asyncio.CancelledError, ConnectionError):
                    self.log_debug("Ignored premature client disconnection")
                    break

                # Drop the processed task from asyncio.Task.all_tasks() early
                del task
                if reset:
                    self.log_debug("Ignored premature client disconnection 2")
                    break

                # notify server about keep-alive
                self._keepalive = bool(resp.keep_alive)

                # check payload
                if not payload.is_eof():
                    lingering_time = self._lingering_time
                    if not self._force_close and lingering_time:
                        self.log_debug(
                            "Start lingering close timer for %s sec.", lingering_time
                        )

                        now = loop.time()
                        end_t = now + lingering_time

                        with suppress(asyncio.TimeoutError, asyncio.CancelledError):
                            while not payload.is_eof() and now < end_t:
                                async with ceil_timeout(end_t - now):
                                    # read and ignore
                                    await payload.readany()
                                now = loop.time()

                    # if payload still uncompleted
                    if not payload.is_eof() and not self._force_close:
                        self.log_debug("Uncompleted request.")
                        self.close()

                payload.set_exception(PayloadAccessError())

            except asyncio.CancelledError:
                self.log_debug("Ignored premature client disconnection ")
                break
            except RuntimeError as exc:
                if self._loop.get_debug():
                    self.log_exception("Unhandled runtime exception", exc_info=exc)
                self.force_close()
            except Exception as exc:
                self.log_exception("Unhandled exception", exc_info=exc)
                self.force_close()
            finally:
                if self.transport is None and resp is not None:
                    self.log_debug("Ignored premature client disconnection.")
                elif not self._force_close:
                    if self._keepalive and not self._close:
                        # start keep-alive timer
                        if keepalive_timeout is not None:
                            now = self._loop.time()
                            self._keepalive_time = now
                            if self._keepalive_handle is None:
                                self._keepalive_handle = loop.call_at(
                                    now + keepalive_timeout, self._process_keepalive
                                )
                    else:
                        break

        # remove handler, close transport if no handlers left
        if not self._force_close:
            self._task_handler = None
            if self.transport is not None:
                self.transport.close()

    async def finish_response(
        self, request: BaseRequest, resp: StreamResponse, start_time: float
    ) -> bool:
        """Prepare the response and write_eof, then log access.

        This has to be called within the context of any exception so the
        access logger can get exception information. Returns True if the
        client disconnects prematurely.
        """
        request._finish()
        if self._request_parser is not None:
            self._request_parser.set_upgraded(False)
            self._upgrade = False
            if self._message_tail:
                self._request_parser.feed_data(self._message_tail)
                self._message_tail = b""
        try:
            prepare_meth = resp.prepare
        except AttributeError:
            if resp is None:
                raise RuntimeError("Missing return statement on request handler")
            else:
                raise RuntimeError(
                    "Web-handler should return a response instance, "
                    "got {!r}".format(resp)
                )
        try:
            await prepare_meth(request)
            await resp.write_eof()
        except ConnectionError:
            await self.log_access(request, resp, start_time)
            return True
        else:
            await self.log_access(request, resp, start_time)
            return False

    def handle_error(
        self,
        request: BaseRequest,
        status: int = 500,
        exc: Optional[BaseException] = None,
        message: Optional[str] = None,
    ) -> StreamResponse:
        """Handle errors.

        Returns an HTTP response with a specific status code. Logs additional
        information. It always closes the current connection.
        """
        self.log_exception("Error handling request", exc_info=exc)

        # some data already got sent, connection is broken
        if request.writer.output_size > 0:
            raise ConnectionError(
                "Response is sent already, cannot send another response "
                "with the error message"
            )

        ct = "text/plain"
        if status == HTTPStatus.INTERNAL_SERVER_ERROR:
            title = "{0.value} {0.phrase}".format(HTTPStatus.INTERNAL_SERVER_ERROR)
            msg = HTTPStatus.INTERNAL_SERVER_ERROR.description
            tb = None
            if self._loop.get_debug():
                with suppress(Exception):
                    tb = traceback.format_exc()

            if "text/html" in request.headers.get("Accept", ""):
                if tb:
                    tb = html_escape(tb)
                    msg = f"<h2>Traceback:</h2>\n<pre>{tb}</pre>"
                message = (
                    "<html><head>"
                    "<title>{title}</title>"
                    "</head><body>\n<h1>{title}</h1>"
                    "\n{msg}\n</body></html>\n"
                ).format(title=title, msg=msg)
                ct = "text/html"
            else:
                if tb:
                    msg = tb
                message = title + "\n\n" + msg

        resp = Response(status=status, text=message, content_type=ct)
        resp.force_close()

        return resp

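
    # For illustration: with the defaults above and no "Accept: text/html"
    # request header, a 500 produced by handle_error() carries the plain-text
    # body "500 Internal Server Error\n\nServer got itself in trouble"
    # (title plus HTTPStatus.INTERNAL_SERVER_ERROR.description), and the
    # response is marked force-close.
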

    def _make_error_handler(
        self, err_info: _ErrInfo
    ) -> Callable[[BaseRequest], Awaitable[StreamResponse]]:
        async def handler(request: BaseRequest) -> StreamResponse:
            return self.handle_error(
                request, err_info.status, err_info.exc, err_info.message
            )

        return handler
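

# Usage sketch (illustrative, based on aiohttp's public low-level server API;
# not part of this module): RequestHandler instances are normally created by
# web.Server, which serves as the protocol factory handed to the event loop.
#
#     import asyncio
#     from aiohttp import web
#
#     async def handle(request: web.BaseRequest) -> web.StreamResponse:
#         return web.Response(text="OK")
#
#     async def main() -> None:
#         server = web.Server(handle)        # builds a RequestHandler per connection
#         runner = web.ServerRunner(server)
#         await runner.setup()
#         site = web.TCPSite(runner, "localhost", 8080)
#         await site.start()
#         await asyncio.sleep(3600)          # keep serving for a while
#
#     asyncio.run(main())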