Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/web_protocol.py: 19%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import asyncio
2import asyncio.streams
3import sys
4import traceback
5from collections import deque
6from collections.abc import Awaitable, Callable, Sequence
7from contextlib import suppress
8from html import escape as html_escape
9from http import HTTPStatus
10from logging import Logger
11from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
13import yarl
14from propcache import under_cached_property
16from .abc import AbstractAccessLogger, AbstractAsyncAccessLogger, AbstractStreamWriter
17from .base_protocol import BaseProtocol
18from .helpers import ceil_timeout, frozen_dataclass_decorator
19from .http import (
20 HttpProcessingError,
21 HttpRequestParser,
22 HttpVersion10,
23 RawRequestMessage,
24 StreamWriter,
25)
26from .http_exceptions import BadHttpMethod
27from .log import access_logger, server_logger
28from .streams import EMPTY_PAYLOAD, StreamReader
29from .tcp_helpers import tcp_keepalive
30from .web_exceptions import HTTPException, HTTPInternalServerError
31from .web_log import AccessLogger
32from .web_request import BaseRequest
33from .web_response import Response, StreamResponse
# Public API of this module; everything else is an implementation detail.
__all__ = ("RequestHandler", "RequestPayloadError", "PayloadAccessError")

# Imports used only by type annotations; avoids import cycles at runtime.
if TYPE_CHECKING:
    import ssl

    from .web_server import Server
# Type variable for the concrete request class served by a RequestHandler.
_Request = TypeVar("_Request", bound=BaseRequest)

# Factory signature used to build a request object from a parsed HTTP message.
# The task argument is the protocol's handler task (used for cancellation).
_RequestFactory = Callable[
    [
        RawRequestMessage,
        StreamReader,
        "RequestHandler[_Request]",
        AbstractStreamWriter,
        "asyncio.Task[None]",
    ],
    _Request,
]

# A user request handler: accepts a request, produces a StreamResponse.
_RequestHandler = Callable[[_Request], Awaitable[StreamResponse]]

# Either access-logger flavor (sync or async) accepted by RequestHandler.
_AnyAbstractAccessLogger = Union[
    type[AbstractAsyncAccessLogger],
    type[AbstractAccessLogger],
]
# Placeholder request message substituted when parsing fails, so the
# request factory can still build a request object for error handling.
ERROR = RawRequestMessage(
    "UNKNOWN",
    "/",
    HttpVersion10,
    {},  # type: ignore[arg-type]
    {},  # type: ignore[arg-type]
    True,
    None,
    False,
    False,
    yarl.URL("/"),
)
class RequestPayloadError(Exception):
    """Raised when the request payload cannot be parsed."""
class PayloadAccessError(Exception):
    """Raised when the payload is accessed after the response was sent."""
83_PAYLOAD_ACCESS_ERROR = PayloadAccessError()
class AccessLoggerWrapper(AbstractAsyncAccessLogger):
    """Adapt a synchronous AbstractAccessLogger to the async access-logger API."""

    __slots__ = ("access_logger", "_loop")

    def __init__(
        self, access_logger: AbstractAccessLogger, loop: asyncio.AbstractEventLoop
    ) -> None:
        # Keep references to the wrapped logger and the loop used for timing.
        self.access_logger = access_logger
        self._loop = loop
        super().__init__()

    async def log(
        self, request: BaseRequest, response: StreamResponse, request_start: float
    ) -> None:
        """Log one request/response pair, passing the elapsed time along."""
        elapsed = self._loop.time() - request_start
        self.access_logger.log(request, response, elapsed)

    @property
    def enabled(self) -> bool:
        """Check if logger is enabled."""
        return self.access_logger.enabled
@frozen_dataclass_decorator
class _ErrInfo:
    """Immutable record of an HTTP parsing error awaiting error handling."""

    # HTTP status code to report (e.g. 400 for a bad request line).
    status: int
    # The original exception raised by the parser.
    exc: BaseException
    # Human-readable message included in the error response.
    message: str
116_MsgType = tuple[RawRequestMessage | _ErrInfo, StreamReader]
class RequestHandler(BaseProtocol, Generic[_Request]):
    """HTTP protocol implementation.

    RequestHandler handles incoming HTTP request. It reads request line,
    request headers and request payload and calls handle_request() method.
    By default it always returns with 404 response.

    RequestHandler handles errors in incoming request, like bad
    status line, bad headers or incomplete payload. If any error occurs,
    connection gets closed.

    keepalive_timeout -- number of seconds before closing
                         keep-alive connection

    tcp_keepalive -- TCP keep-alive is on, default is on

    logger -- custom logger object

    access_log_class -- custom class for access_logger

    access_log -- custom logging object

    access_log_format -- access log format string

    loop -- Optional event loop

    max_line_size -- Optional maximum header line size

    max_field_size -- Optional maximum header field size

    timeout_ceil_threshold -- Optional value to specify
                              threshold to ceil() timeout
                              values

    """

    __slots__ = (
        "max_field_size",
        "max_headers",
        "max_line_size",
        "_request_count",
        "_keepalive",
        "_manager",
        "_request_handler",
        "_request_factory",
        "_tcp_keepalive",
        "_next_keepalive_close_time",
        "_keepalive_handle",
        "_keepalive_timeout",
        "_lingering_time",
        "_messages",
        "_message_tail",
        "_handler_waiter",
        "_waiter",
        "_task_handler",
        "_upgrade",
        "_payload_parser",
        "_data_received_cb",
        "_request_parser",
        "logger",
        "access_log",
        "access_logger",
        "_close",
        "_force_close",
        "_current_request",
        "_timeout_ceil_threshold",
        "_request_in_progress",
        "_logging_enabled",
        "_cache",
    )

    def __init__(
        self,
        manager: "Server[_Request]",
        *,
        loop: asyncio.AbstractEventLoop,
        # Default should be high enough that it's likely longer than a reverse proxy.
        keepalive_timeout: float = 3630,
        tcp_keepalive: bool = True,
        logger: Logger = server_logger,
        access_log_class: _AnyAbstractAccessLogger = AccessLogger,
        access_log: Logger | None = access_logger,
        access_log_format: str = AccessLogger.LOG_FORMAT,
        max_line_size: int = 8190,
        max_headers: int = 128,
        max_field_size: int = 8190,
        lingering_time: float = 10.0,
        read_bufsize: int = 2**16,
        auto_decompress: bool = True,
        timeout_ceil_threshold: float = 5,
    ):
        super().__init__(loop)

        # _request_count is the number of requests processed with the same connection.
        self._request_count = 0
        self._keepalive = False
        self._current_request: _Request | None = None
        self._manager: Server[_Request] | None = manager
        self._request_handler: _RequestHandler[_Request] | None = (
            manager.request_handler
        )
        self._request_factory: _RequestFactory[_Request] | None = (
            manager.request_factory
        )

        # Hard limits enforced by the HTTP parser.
        self.max_line_size = max_line_size
        self.max_headers = max_headers
        self.max_field_size = max_field_size

        self._tcp_keepalive = tcp_keepalive
        # placeholder to be replaced on keepalive timeout setup
        self._next_keepalive_close_time = 0.0
        self._keepalive_handle: asyncio.Handle | None = None
        self._keepalive_timeout = keepalive_timeout
        self._lingering_time = float(lingering_time)

        # Queue of parsed-but-not-yet-handled (message, payload) pairs;
        # allows pipelined requests to accumulate while one is in progress.
        self._messages: deque[_MsgType] = deque()
        self._message_tail = b""
        self._data_received_cb: Callable[[], None] | None = None

        # _waiter: future awaited in start() while idle between requests.
        self._waiter: asyncio.Future[None] | None = None
        # _handler_waiter: created only during shutdown() to wait for the
        # in-flight handler (see shutdown()).
        self._handler_waiter: asyncio.Future[None] | None = None
        self._task_handler: asyncio.Task[None] | None = None

        self._upgrade = False
        self._payload_parser: Any = None
        self._request_parser: HttpRequestParser | None = HttpRequestParser(
            self,
            loop,
            read_bufsize,
            max_line_size=max_line_size,
            max_field_size=max_field_size,
            max_headers=max_headers,
            payload_exception=RequestPayloadError,
            auto_decompress=auto_decompress,
        )

        # Fall back to the default (5) if the provided threshold is not
        # convertible to float.
        self._timeout_ceil_threshold: float = 5
        try:
            self._timeout_ceil_threshold = float(timeout_ceil_threshold)
        except (TypeError, ValueError):
            pass

        self.logger = logger
        self.access_log = access_log
        if access_log:
            if issubclass(access_log_class, AbstractAsyncAccessLogger):
                self.access_logger: AbstractAsyncAccessLogger | None = (
                    access_log_class()
                )
            else:
                # Wrap a synchronous logger so both flavors share one API.
                access_logger = access_log_class(access_log, access_log_format)
                self.access_logger = AccessLoggerWrapper(
                    access_logger,
                    self._loop,
                )
            self._logging_enabled = self.access_logger.enabled
        else:
            self.access_logger = None
            self._logging_enabled = False

        self._close = False
        self._force_close = False
        self._request_in_progress = False
        self._cache: dict[str, Any] = {}

    def __repr__(self) -> str:
        return "<{} {}>".format(
            self.__class__.__name__,
            "connected" if self.transport is not None else "disconnected",
        )

    @under_cached_property
    def ssl_context(self) -> Optional["ssl.SSLContext"]:
        """Return SSLContext if available."""
        return (
            None
            if self.transport is None
            else self.transport.get_extra_info("sslcontext")
        )

    @under_cached_property
    def peername(
        self,
    ) -> str | tuple[str, int, int, int] | tuple[str, int] | None:
        """Return peername if available."""
        return (
            None
            if self.transport is None
            else self.transport.get_extra_info("peername")
        )

    @property
    def keepalive_timeout(self) -> float:
        """Seconds an idle keep-alive connection stays open."""
        return self._keepalive_timeout

    async def shutdown(self, timeout: float | None = 15.0) -> None:
        """Do worker process exit preparations.

        We need to clean up everything and stop accepting requests.
        It is especially important for keep-alive connections.
        """
        self._force_close = True

        if self._keepalive_handle is not None:
            self._keepalive_handle.cancel()

        # Wait for graceful handler completion
        if self._request_in_progress:
            # The future is only created when we are shutting
            # down while the handler is still processing a request
            # to avoid creating a future for every request.
            self._handler_waiter = self._loop.create_future()
            try:
                async with ceil_timeout(timeout):
                    await self._handler_waiter
            except (asyncio.CancelledError, asyncio.TimeoutError):
                self._handler_waiter = None
                # Re-raise only if this task itself is being cancelled
                # (cancelling() is 3.11+), so an external cancel propagates.
                if (
                    sys.version_info >= (3, 11)
                    and (task := asyncio.current_task())
                    and task.cancelling()
                ):
                    raise
        # Then cancel handler and wait
        try:
            async with ceil_timeout(timeout):
                if self._current_request is not None:
                    self._current_request._cancel(asyncio.CancelledError())

                if self._task_handler is not None and not self._task_handler.done():
                    # Shield so our own timeout does not cancel the handler.
                    await asyncio.shield(self._task_handler)
        except (asyncio.CancelledError, asyncio.TimeoutError):
            if (
                sys.version_info >= (3, 11)
                and (task := asyncio.current_task())
                and task.cancelling()
            ):
                raise

        # force-close non-idle handler
        if self._task_handler is not None:
            self._task_handler.cancel()

        self.force_close()

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Start serving on a newly established connection.

        Registers with the server manager and spawns the start() task
        that drives the request loop for this connection.
        """
        super().connection_made(transport)

        real_transport = cast(asyncio.Transport, transport)
        if self._tcp_keepalive:
            tcp_keepalive(real_transport)

        assert self._manager is not None
        self._manager.connection_made(self, real_transport)

        loop = self._loop
        if sys.version_info >= (3, 12):
            # Eager start (3.12+) runs the coroutine up to its first await
            # immediately, avoiding an extra loop iteration.
            task = asyncio.Task(self.start(), loop=loop, eager_start=True)
        else:
            task = loop.create_task(self.start())
        self._task_handler = task

    def connection_lost(self, exc: BaseException | None) -> None:
        """Tear down state when the transport is closed.

        Detaches from the manager, cancels the in-flight request (if any),
        and drops references so the protocol can be garbage collected.
        """
        if self._manager is None:
            return
        self._manager.connection_lost(self, exc)

        # Grab value before setting _manager to None.
        handler_cancellation = self._manager.handler_cancellation

        self.force_close()
        super().connection_lost(exc)
        self._manager = None
        self._request_factory = None
        self._request_handler = None
        self._request_parser = None

        if self._keepalive_handle is not None:
            self._keepalive_handle.cancel()

        if self._current_request is not None:
            if exc is None:
                exc = ConnectionResetError("Connection lost")
            self._current_request._cancel(exc)

        if handler_cancellation and self._task_handler is not None:
            self._task_handler.cancel()

        self._task_handler = None

        if self._payload_parser is not None:
            self._payload_parser.feed_eof()
            self._payload_parser = None

    def set_parser(
        self, parser: Any, data_received_cb: Callable[[], None] | None = None
    ) -> None:
        """Install a payload parser that receives subsequent raw data.

        Any bytes buffered after an upgrade are fed to the new parser
        immediately.
        """
        # Actual type is WebReader
        assert self._payload_parser is None

        self._payload_parser = parser
        self._data_received_cb = data_received_cb

        if self._message_tail:
            self._payload_parser.feed_data(self._message_tail)
            self._message_tail = b""

    def eof_received(self) -> None:
        """Handle EOF from the peer.

        Returning None lets asyncio close the transport (the default
        half-close behavior for protocols).
        """
        pass

    def data_received(self, data: bytes) -> None:
        """Feed raw bytes to the request or payload parser.

        Parsed messages are queued and the idle start() loop is woken up.
        After a connection upgrade, data is buffered verbatim until a
        payload parser is installed via set_parser().
        """
        if self._force_close or self._close:
            return
        # parse http messages
        messages: Sequence[_MsgType]
        if self._payload_parser is None and not self._upgrade:
            assert self._request_parser is not None
            try:
                messages, upgraded, tail = self._request_parser.feed_data(data)
            except HttpProcessingError as exc:
                # Turn the parse failure into a queued error marker so the
                # request loop can send a 400 response.
                messages = [
                    (_ErrInfo(status=400, exc=exc, message=exc.message), EMPTY_PAYLOAD)
                ]
                upgraded = False
                tail = b""

            for msg, payload in messages or ():
                self._request_count += 1
                self._messages.append((msg, payload))

            waiter = self._waiter
            if messages and waiter is not None and not waiter.done():
                # don't set result twice
                waiter.set_result(None)

            self._upgrade = upgraded
            if upgraded and tail:
                self._message_tail = tail

        # no parser, just store
        elif self._payload_parser is None and self._upgrade and data:
            self._message_tail += data

        # feed payload
        elif data:
            if self._data_received_cb is not None:
                self._data_received_cb()
            eof, tail = self._payload_parser.feed_data(data)
            if eof:
                self.close()

    def keep_alive(self, val: bool) -> None:
        """Set keep-alive connection mode.

        :param bool val: new state.
        """
        self._keepalive = val
        if self._keepalive_handle:
            self._keepalive_handle.cancel()
            self._keepalive_handle = None

    def close(self) -> None:
        """Close connection.

        Stop accepting new pipelining messages and close
        connection when handlers done processing messages.
        """
        self._close = True
        if self._waiter:
            self._waiter.cancel()

    def force_close(self) -> None:
        """Forcefully close connection."""
        self._force_close = True
        if self._waiter:
            self._waiter.cancel()
        if self.transport is not None:
            self.transport.close()
            self.transport = None

    async def log_access(
        self,
        request: BaseRequest,
        response: StreamResponse,
        request_start: float | None,
    ) -> None:
        """Emit one access-log record if access logging is enabled."""
        if self._logging_enabled and self.access_logger is not None:
            if TYPE_CHECKING:
                assert request_start is not None
            await self.access_logger.log(request, response, request_start)

    def log_debug(self, *args: Any, **kw: Any) -> None:
        """Log at debug level, but only when the loop is in debug mode."""
        if self._loop.get_debug():
            self.logger.debug(*args, **kw)

    def log_exception(self, *args: Any, **kw: Any) -> None:
        """Log an exception via the configured server logger."""
        self.logger.exception(*args, **kw)

    def _process_keepalive(self) -> None:
        """Timer callback: close the connection if idle past the deadline.

        Reschedules itself when the deadline has moved forward since the
        timer was armed (a new request arrived in the meantime).
        """
        self._keepalive_handle = None
        if self._force_close or not self._keepalive:
            return

        loop = self._loop
        now = loop.time()
        close_time = self._next_keepalive_close_time
        if now < close_time:
            # Keep alive close check fired too early, reschedule
            self._keepalive_handle = loop.call_at(close_time, self._process_keepalive)
            return

        # handler in idle state
        if self._waiter and not self._waiter.done():
            self.force_close()

    async def _handle_request(
        self,
        request: _Request,
        start_time: float | None,
        request_handler: Callable[[_Request], Awaitable[StreamResponse]],
    ) -> tuple[StreamResponse, bool]:
        """Run the user handler and finish the response.

        Converts HTTPException to a response, maps timeouts to 504 and
        other errors to 500.  Returns the final response and a flag that
        is True when the client disconnected prematurely.
        """
        self._request_in_progress = True
        try:
            try:
                self._current_request = request
                resp = await request_handler(request)
            finally:
                self._current_request = None
        except HTTPException as exc:
            resp = Response(
                status=exc.status, reason=exc.reason, text=exc.text, headers=exc.headers
            )
            resp._cookies = exc._cookies
            resp, reset = await self.finish_response(request, resp, start_time)
        except asyncio.CancelledError:
            raise
        except asyncio.TimeoutError as exc:
            self.log_debug("Request handler timed out.", exc_info=exc)
            resp = self.handle_error(request, 504)
            resp, reset = await self.finish_response(request, resp, start_time)
        except Exception as exc:
            resp = self.handle_error(request, 500, exc)
            resp, reset = await self.finish_response(request, resp, start_time)
        else:
            resp, reset = await self.finish_response(request, resp, start_time)
        finally:
            self._request_in_progress = False
            # Wake up shutdown() if it is waiting for this handler.
            if self._handler_waiter is not None:
                self._handler_waiter.set_result(None)

        return resp, reset

    async def start(self) -> None:
        """Process incoming request.

        It reads request line, request headers and request payload, then
        calls handle_request() method. Subclass has to override
        handle_request(). start() handles various exceptions in request
        or response handling. Connection is being closed always unless
        keep_alive(True) specified.
        """
        loop = self._loop
        manager = self._manager
        assert manager is not None
        keepalive_timeout = self._keepalive_timeout
        resp = None
        assert self._request_factory is not None
        assert self._request_handler is not None

        while not self._force_close:
            if not self._messages:
                try:
                    # wait for next request
                    self._waiter = loop.create_future()
                    await self._waiter
                finally:
                    self._waiter = None

            message, payload = self._messages.popleft()

            # time is only fetched if logging is enabled as otherwise
            # its thrown away and never used.
            start = loop.time() if self._logging_enabled else None

            manager.requests_count += 1
            writer = StreamWriter(self, loop)
            if not isinstance(message, _ErrInfo):
                request_handler = self._request_handler
            else:
                # make request_factory work
                request_handler = self._make_error_handler(message)
                message = ERROR

            # Important don't hold a reference to the current task
            # as on traceback it will prevent the task from being
            # collected and will cause a memory leak.
            request = self._request_factory(
                message,
                payload,
                self,
                writer,
                self._task_handler or asyncio.current_task(loop),  # type: ignore[arg-type]
            )
            try:
                # a new task is used for copy context vars (#3406)
                coro = self._handle_request(request, start, request_handler)
                if sys.version_info >= (3, 12):
                    task = asyncio.Task(coro, loop=loop, eager_start=True)
                else:
                    task = loop.create_task(coro)
                try:
                    resp, reset = await task
                except ConnectionError:
                    self.log_debug("Ignored premature client disconnection")
                    break

                # Drop the processed task from asyncio.Task.all_tasks() early
                del task
                # https://github.com/python/mypy/issues/14309
                if reset:  # type: ignore[possibly-undefined]
                    self.log_debug("Ignored premature client disconnection 2")
                    break

                # notify server about keep-alive
                self._keepalive = bool(resp.keep_alive)

                # check payload
                if not payload.is_eof():
                    lingering_time = self._lingering_time
                    # Could be force closed while awaiting above tasks.
                    if not self._force_close and lingering_time:  # type: ignore[redundant-expr]
                        self.log_debug(
                            "Start lingering close timer for %s sec.", lingering_time
                        )

                        now = loop.time()
                        end_t = now + lingering_time

                        try:
                            while not payload.is_eof() and now < end_t:
                                async with ceil_timeout(end_t - now):
                                    # read and ignore
                                    await payload.readany()
                                now = loop.time()
                        except (asyncio.CancelledError, asyncio.TimeoutError):
                            if (
                                sys.version_info >= (3, 11)
                                and (t := asyncio.current_task())
                                and t.cancelling()
                            ):
                                raise

                    # if payload still uncompleted
                    if not payload.is_eof() and not self._force_close:
                        self.log_debug("Uncompleted request.")
                        self.close()

                payload.set_exception(_PAYLOAD_ACCESS_ERROR)

            except asyncio.CancelledError:
                self.log_debug("Ignored premature client disconnection")
                self.force_close()
                raise
            except Exception as exc:
                self.log_exception("Unhandled exception", exc_info=exc)
                self.force_close()
            except BaseException:
                self.force_close()
                raise
            finally:
                request._task = None  # type: ignore[assignment] # Break reference cycle in case of exception
                if self.transport is None and resp is not None:
                    self.log_debug("Ignored premature client disconnection.")

            if self._keepalive and not self._close and not self._force_close:
                # start keep-alive timer
                close_time = loop.time() + keepalive_timeout
                self._next_keepalive_close_time = close_time
                if self._keepalive_handle is None:
                    self._keepalive_handle = loop.call_at(
                        close_time, self._process_keepalive
                    )
            else:
                break

        # remove handler, close transport if no handlers left
        if not self._force_close:
            self._task_handler = None
            if self.transport is not None:
                self.transport.close()

    async def finish_response(
        self, request: BaseRequest, resp: StreamResponse, start_time: float | None
    ) -> tuple[StreamResponse, bool]:
        """Prepare the response and write_eof, then log access.

        This has to
        be called within the context of any exception so the access logger
        can get exception information. Returns True if the client disconnects
        prematurely.
        """
        request._finish()
        if self._request_parser is not None:
            self._request_parser.set_upgraded(False)
            self._upgrade = False
            if self._message_tail:
                self._request_parser.feed_data(self._message_tail)
                self._message_tail = b""
        try:
            prepare_meth = resp.prepare
        except AttributeError:
            # Handler returned something that is not a response object;
            # substitute a 500 so the client still gets a valid reply.
            if resp is None:
                self.log_exception("Missing return statement on request handler")  # type: ignore[unreachable]
            else:
                self.log_exception(
                    f"Web-handler should return a response instance, got {resp!r}"
                )
            exc = HTTPInternalServerError()
            resp = Response(
                status=exc.status, reason=exc.reason, text=exc.text, headers=exc.headers
            )
            prepare_meth = resp.prepare
        try:
            await prepare_meth(request)
            await resp.write_eof()
        except ConnectionError:
            # Client went away mid-response; log and signal the reset.
            await self.log_access(request, resp, start_time)
            return resp, True

        await self.log_access(request, resp, start_time)
        return resp, False

    def handle_error(
        self,
        request: BaseRequest,
        status: int = 500,
        exc: BaseException | None = None,
        message: str | None = None,
    ) -> StreamResponse:
        """Handle errors.

        Returns HTTP response with specific status code. Logs additional
        information. It always closes current connection.
        """
        if self._request_count == 1 and isinstance(exc, BadHttpMethod):
            # BadHttpMethod is common when a client sends non-HTTP
            # or encrypted traffic to an HTTP port. This is expected
            # to happen when connected to the public internet so we log
            # it at the debug level as to not fill logs with noise.
            self.logger.debug(
                "Error handling request from %s", request.remote, exc_info=exc
            )
        else:
            self.log_exception(
                "Error handling request from %s", request.remote, exc_info=exc
            )

        # some data already got sent, connection is broken
        if request.writer.output_size > 0:
            raise ConnectionError(
                "Response is sent already, cannot send another response "
                "with the error message"
            )

        ct = "text/plain"
        if status == HTTPStatus.INTERNAL_SERVER_ERROR:
            title = f"{HTTPStatus.INTERNAL_SERVER_ERROR.value} {HTTPStatus.INTERNAL_SERVER_ERROR.phrase}"
            msg = HTTPStatus.INTERNAL_SERVER_ERROR.description
            tb = None
            # Tracebacks are only exposed in debug mode.
            if self._loop.get_debug():
                with suppress(Exception):
                    tb = traceback.format_exc()

            if "text/html" in request.headers.get("Accept", ""):
                if tb:
                    tb = html_escape(tb)
                    msg = f"<h2>Traceback:</h2>\n<pre>{tb}</pre>"
                message = (
                    "<html><head>"
                    f"<title>{title}</title>"
                    f"</head><body>\n<h1>{title}</h1>"
                    f"\n{msg}\n</body></html>\n"
                )
                ct = "text/html"
            else:
                if tb:
                    msg = tb
                message = title + "\n\n" + msg

        resp = Response(status=status, text=message, content_type=ct)
        resp.force_close()

        return resp

    def _make_error_handler(
        self, err_info: _ErrInfo
    ) -> Callable[[BaseRequest], Awaitable[StreamResponse]]:
        """Build a handler that reports the given parse error to the client."""

        async def handler(request: BaseRequest) -> StreamResponse:
            return self.handle_error(
                request, err_info.status, err_info.exc, err_info.message
            )

        return handler