Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/web_protocol.py: 19%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import asyncio
2import asyncio.streams
3import sys
4import traceback
5from collections import deque
6from collections.abc import Awaitable, Callable, Sequence
7from contextlib import suppress
8from html import escape as html_escape
9from http import HTTPStatus
10from logging import Logger
11from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
13import yarl
14from propcache import under_cached_property
16from .abc import AbstractAccessLogger, AbstractAsyncAccessLogger, AbstractStreamWriter
17from .base_protocol import BaseProtocol
18from .helpers import DEFAULT_CHUNK_SIZE, ceil_timeout, frozen_dataclass_decorator
19from .http import (
20 HttpProcessingError,
21 HttpRequestParser,
22 HttpVersion10,
23 RawRequestMessage,
24 StreamWriter,
25 WebSocketReader,
26)
27from .http_exceptions import BadHttpMethod
28from .log import access_logger, server_logger
29from .streams import EMPTY_PAYLOAD, StreamReader
30from .tcp_helpers import tcp_keepalive
31from .web_exceptions import HTTPException, HTTPInternalServerError
32from .web_log import AccessLogger
33from .web_request import BaseRequest
34from .web_response import Response, StreamResponse
# Names exported by ``from ... import *``.
__all__ = ("RequestHandler", "RequestPayloadError", "PayloadAccessError")

if TYPE_CHECKING:
    # Imported for annotations only; never executed at runtime.
    import ssl

    from .web_server import Server


# Concrete request type produced by the request factory and consumed by the
# request handler.
_Request = TypeVar("_Request", bound=BaseRequest)

# Signature of the callable that builds a request object. The arguments, in
# order, match the call made in RequestHandler.start(): the parsed message,
# the payload stream, the protocol instance, the response writer and the task
# that serves the connection.
_RequestFactory = Callable[
    [
        RawRequestMessage,
        StreamReader,
        "RequestHandler[_Request]",
        AbstractStreamWriter,
        "asyncio.Task[None]",
    ],
    _Request,
]

# Signature of the user-facing coroutine that turns a request into a response.
_RequestHandler = Callable[[_Request], Awaitable[StreamResponse]]

# Either flavour of access-logger class may be passed as ``access_log_class``;
# RequestHandler.__init__ tells them apart with issubclass().
_AnyAbstractAccessLogger = Union[
    type[AbstractAsyncAccessLogger],
    type[AbstractAccessLogger],
]
# Placeholder message handed to the request factory when the incoming data
# could not be parsed (see RequestHandler.start(), which swaps an _ErrInfo for
# this constant so that a request object can still be built).
# NOTE(review): the per-field comments below assume the positional order
# method, path, version, headers, raw_headers, should_close, compression,
# upgrade, chunked, url -- confirm against RawRequestMessage.
ERROR = RawRequestMessage(
    "UNKNOWN",  # method (assumed)
    "/",  # path (assumed)
    HttpVersion10,  # version (assumed)
    {},  # type: ignore[arg-type]
    {},  # type: ignore[arg-type]
    True,  # should_close (assumed)
    None,  # compression (assumed)
    False,  # upgrade (assumed)
    False,  # chunked (assumed)
    yarl.URL("/"),  # url (assumed)
)
class RequestPayloadError(Exception):
    """Error raised while parsing a request payload.

    RequestHandler passes this class to the HTTP parser as its
    ``payload_exception``.
    """
class PayloadAccessError(Exception):
    """Error raised when the payload is read after the response was sent."""


# One shared instance, installed on every payload once its request has been
# answered, so no new exception object is built per request.
_PAYLOAD_ACCESS_ERROR = PayloadAccessError()
class AccessLoggerWrapper(AbstractAsyncAccessLogger):
    """Adapter exposing a synchronous access logger through the async interface."""

    __slots__ = ("access_logger", "_loop")

    def __init__(
        self, access_logger: AbstractAccessLogger, loop: asyncio.AbstractEventLoop
    ) -> None:
        """Remember the wrapped logger and the loop used as the time source."""
        self.access_logger = access_logger
        self._loop = loop
        super().__init__()

    async def log(
        self, request: BaseRequest, response: StreamResponse, request_start: float
    ) -> None:
        """Forward the entry to the wrapped logger.

        The wrapped logger receives the elapsed time (current loop time minus
        *request_start*) rather than the start timestamp itself.
        """
        wrapped = self.access_logger
        elapsed = self._loop.time() - request_start
        wrapped.log(request, response, elapsed)

    @property
    def enabled(self) -> bool:
        """Report whether the wrapped logger is enabled."""
        wrapped = self.access_logger
        return wrapped.enabled
@frozen_dataclass_decorator
class _ErrInfo:
    """Description of a request that could not be parsed.

    Queued in place of a parsed message by RequestHandler.data_received() and
    later turned into an error response by RequestHandler.start().
    """

    # HTTP status code to answer with.
    status: int
    # Exception raised by the parser.
    exc: BaseException
    # Text forwarded to RequestHandler.handle_error() as ``message``.
    message: str


# One queued item: a parsed message (or the description of a parse failure)
# together with the stream its payload is read from.
_MsgType = tuple[RawRequestMessage | _ErrInfo, StreamReader]
class RequestHandler(BaseProtocol, Generic[_Request]):
    """HTTP protocol implementation.

    RequestHandler handles incoming HTTP requests. It reads the request
    line, request headers and request payload, then passes each parsed
    request to the ``request_handler`` callable supplied by the manager.

    RequestHandler also handles errors in the incoming request, like a bad
    status line, bad headers or an incomplete payload. A request that
    cannot be parsed is answered with an error response and the
    connection gets closed.

    manager -- server object providing ``request_handler`` and
               ``request_factory``

    keepalive_timeout -- number of seconds before closing
                         keep-alive connection

    tcp_keepalive -- TCP keep-alive is on, default is on

    logger -- custom logger object

    access_log_class -- custom class for access_logger

    access_log -- custom logging object; a falsy value disables
                  access logging

    access_log_format -- access log format string

    loop -- event loop (required, keyword-only)

    max_line_size -- Optional maximum header line size

    max_headers -- Optional maximum number of headers, passed to
                   the request parser

    max_field_size -- Optional maximum header field size

    lingering_time -- number of seconds spent reading and discarding
                      an unread request payload after the response
                      has been produced

    read_bufsize -- passed to the request parser

    auto_decompress -- passed to the request parser

    timeout_ceil_threshold -- Optional value to specify
                              threshold to ceil() timeout
                              values

    """

    __slots__ = (
        "max_field_size",
        "max_headers",
        "max_line_size",
        "_request_count",
        "_keepalive",
        "_manager",
        "_request_handler",
        "_request_factory",
        "_tcp_keepalive",
        "_next_keepalive_close_time",
        "_keepalive_handle",
        "_keepalive_timeout",
        "_lingering_time",
        "_messages",
        "_message_tail",
        "_handler_waiter",
        "_waiter",
        "_task_handler",
        "_payload_parser",
        "_data_received_cb",
        "logger",
        "access_log",
        "access_logger",
        "_close",
        "_force_close",
        "_current_request",
        "_timeout_ceil_threshold",
        "_request_in_progress",
        "_logging_enabled",
        "_cache",
    )
def __init__(
    self,
    manager: "Server[_Request]",
    *,
    loop: asyncio.AbstractEventLoop,
    # Default should be high enough that it's likely longer than a reverse proxy.
    keepalive_timeout: float = 3630,
    tcp_keepalive: bool = True,
    logger: Logger = server_logger,
    access_log_class: _AnyAbstractAccessLogger = AccessLogger,
    access_log: Logger | None = access_logger,
    access_log_format: str = AccessLogger.LOG_FORMAT,
    max_line_size: int = 8190,
    max_headers: int = 128,
    max_field_size: int = 8190,
    lingering_time: float = 10.0,
    read_bufsize: int = DEFAULT_CHUNK_SIZE,
    auto_decompress: bool = True,
    timeout_ceil_threshold: float = 5,
):
    """Build the request parser and reset all per-connection state.

    The handler and the request factory are taken from *manager*. Access
    logging is set up only when *access_log* is truthy; a synchronous
    access-logger class is wrapped so that logging is always awaited.
    """
    super().__init__(
        loop,
        HttpRequestParser(
            self,
            loop,
            read_bufsize,
            max_line_size=max_line_size,
            max_field_size=max_field_size,
            max_headers=max_headers,
            payload_exception=RequestPayloadError,
            auto_decompress=auto_decompress,
        ),
    )

    # Number of requests processed on this connection so far.
    self._request_count = 0
    self._keepalive = False
    self._current_request: _Request | None = None
    self._manager: Server[_Request] | None = manager
    self._request_handler: _RequestHandler[_Request] | None = (
        manager.request_handler
    )
    self._request_factory: _RequestFactory[_Request] | None = (
        manager.request_factory
    )

    # Parser limits, kept as public attributes.
    self.max_line_size = max_line_size
    self.max_headers = max_headers
    self.max_field_size = max_field_size

    # Keep-alive bookkeeping; the close time is a placeholder until the
    # keep-alive timer is armed for the first time.
    self._tcp_keepalive = tcp_keepalive
    self._next_keepalive_close_time = 0.0
    self._keepalive_handle: asyncio.Handle | None = None
    self._keepalive_timeout = keepalive_timeout
    self._lingering_time = float(lingering_time)

    # Parsed messages waiting to be served, plus raw bytes received after
    # an upgrade but before a payload parser was installed.
    self._messages: deque[_MsgType] = deque()
    self._message_tail = b""
    self._data_received_cb: Callable[[], None] | None = None

    self._waiter: asyncio.Future[None] | None = None
    self._handler_waiter: asyncio.Future[None] | None = None
    self._task_handler: asyncio.Task[None] | None = None
    self._payload_parser: Any = None

    # Fall back to 5 when the supplied threshold cannot be read as a float.
    self._timeout_ceil_threshold: float = 5
    with suppress(TypeError, ValueError):
        self._timeout_ceil_threshold = float(timeout_ceil_threshold)

    self.logger = logger
    self.access_log = access_log
    if not access_log:
        self.access_logger: AbstractAsyncAccessLogger | None = None
        self._logging_enabled = False
    elif issubclass(access_log_class, AbstractAsyncAccessLogger):
        self.access_logger = access_log_class()
        self._logging_enabled = self.access_logger.enabled
    else:
        sync_logger = access_log_class(access_log, access_log_format)
        self.access_logger = AccessLoggerWrapper(sync_logger, self._loop)
        self._logging_enabled = self.access_logger.enabled

    self._close = False
    self._force_close = False
    self._request_in_progress = False
    self._cache: dict[str, Any] = {}
282 def __repr__(self) -> str:
283 return "<{} {}>".format(
284 self.__class__.__name__,
285 "connected" if self.transport is not None else "disconnected",
286 )
@under_cached_property
def ssl_context(self) -> Optional["ssl.SSLContext"]:
    """Return the transport's ``sslcontext`` extra info, or None without a transport."""
    transport = self.transport
    if transport is None:
        return None
    return transport.get_extra_info("sslcontext")
@under_cached_property
def peername(
    self,
) -> str | tuple[str, int, int, int] | tuple[str, int] | None:
    """Return the transport's ``peername`` extra info, or None without a transport."""
    transport = self.transport
    if transport is None:
        return None
    return transport.get_extra_info("peername")
@under_cached_property
def sockname(
    self,
) -> str | tuple[str, int, int, int] | tuple[str, int] | None:
    """Return the transport's ``sockname`` extra info, or None without a transport."""
    transport = self.transport
    if transport is None:
        return None
    return transport.get_extra_info("sockname")
@property
def keepalive_timeout(self) -> float:
    """Keep-alive timeout, in seconds, given to the constructor."""
    timeout = self._keepalive_timeout
    return timeout
async def shutdown(self, timeout: float | None = 15.0) -> None:
    """Do worker process exit preparations.

    We need to clean up everything and stop accepting requests.
    It is especially important for keep-alive connections.

    The shutdown runs in up to three steps:

    1. If a request is being handled, wait for it to finish.
    2. Cancel the current request and wait for the connection task.
    3. Cancel the connection task if it still exists and force-close.

    *timeout* bounds step 1 and step 2 separately, so the call may take up
    to twice that long. A timeout or cancellation inside a step is
    swallowed and the next step runs, unless (Python 3.11+) the task
    running shutdown() is itself being cancelled, in which case the
    exception propagates.
    """
    # Stops start() from serving further requests and makes
    # data_received() ignore new data.
    self._force_close = True

    if self._keepalive_handle is not None:
        self._keepalive_handle.cancel()

    # Wait for graceful handler completion
    if self._request_in_progress:
        # The future is only created when we are shutting
        # down while the handler is still processing a request
        # to avoid creating a future for every request.
        self._handler_waiter = self._loop.create_future()
        try:
            async with ceil_timeout(timeout):
                await self._handler_waiter
        except (asyncio.CancelledError, asyncio.TimeoutError):
            # Drop the future so _handle_request() does not try to
            # resolve it later.
            self._handler_waiter = None
            if (
                sys.version_info >= (3, 11)
                and (task := asyncio.current_task())
                and task.cancelling()
            ):
                raise
    # Then cancel handler and wait
    try:
        async with ceil_timeout(timeout):
            if self._current_request is not None:
                self._current_request._cancel(asyncio.CancelledError())

            if self._task_handler is not None and not self._task_handler.done():
                # shield() keeps a timeout here from cancelling the
                # connection task; it is cancelled explicitly below.
                await asyncio.shield(self._task_handler)
    except (asyncio.CancelledError, asyncio.TimeoutError):
        if (
            sys.version_info >= (3, 11)
            and (task := asyncio.current_task())
            and task.cancelling()
        ):
            raise

    # force-close non-idle handler
    if self._task_handler is not None:
        self._task_handler.cancel()

    self.force_close()
def connection_made(self, transport: asyncio.BaseTransport) -> None:
    """Register the new connection and start the task that serves it.

    TCP keep-alive is enabled on the transport when requested, the manager
    is notified, and start() is scheduled as the connection task (started
    eagerly on Python 3.12+).
    """
    super().connection_made(transport)

    stream_transport = cast(asyncio.Transport, transport)
    if self._tcp_keepalive:
        tcp_keepalive(stream_transport)

    manager = self._manager
    assert manager is not None
    manager.connection_made(self, stream_transport)

    if sys.version_info >= (3, 12):
        self._task_handler = asyncio.Task(
            self.start(), loop=self._loop, eager_start=True
        )
    else:
        self._task_handler = self._loop.create_task(self.start())
def connection_lost(self, exc: BaseException | None) -> None:
    """Tear down all per-connection state after the transport is gone.

    Does nothing when the manager reference was already dropped, so a
    second call is a no-op. The request in flight, if any, is cancelled
    with *exc*, or with a ConnectionResetError when *exc* is None.
    """
    manager = self._manager
    if manager is None:
        return
    manager.connection_lost(self, exc)

    # Read the flag while the manager is still at hand.
    cancel_handler = manager.handler_cancellation

    self.force_close()
    super().connection_lost(exc)

    # Drop references that are no longer needed.
    self._manager = None
    self._request_factory = None
    self._request_handler = None
    self._parser = None

    keepalive_handle = self._keepalive_handle
    if keepalive_handle is not None:
        keepalive_handle.cancel()

    current_request = self._current_request
    if current_request is not None:
        reason = exc if exc is not None else ConnectionResetError("Connection lost")
        current_request._cancel(reason)

    task_handler = self._task_handler
    if cancel_handler and task_handler is not None:
        task_handler.cancel()

    self._task_handler = None

    payload_parser = self._payload_parser
    if payload_parser is not None:
        payload_parser.feed_eof()
        self._payload_parser = None
def set_parser(
    self,
    parser: WebSocketReader,
    data_received_cb: Callable[[], None] | None = None,
) -> None:
    """Install the payload parser that receives all further incoming data.

    Bytes buffered before the parser was installed are fed to it right
    away. *data_received_cb*, when given, is called before each later
    chunk is fed. Only one payload parser may be installed.
    """
    assert self._payload_parser is None

    self._payload_parser = parser
    self._data_received_cb = data_received_cb

    buffered = self._message_tail
    if not buffered:
        return
    self._payload_parser.feed_data(buffered)
    self._message_tail = b""
def eof_received(self) -> None:
    """Ignore the end-of-stream notification; nothing is done here."""
    return None
def data_received(self, data: bytes) -> None:
    """Route incoming bytes to the HTTP parser, the tail buffer or the payload parser.

    Data is dropped once the connection is closing. Otherwise:

    * no payload parser and not upgraded: parse HTTP messages and queue
      them, waking start() if it is waiting; a parsing failure is queued
      as a 400 error entry;
    * upgraded but no payload parser yet: buffer the bytes;
    * payload parser installed: feed it, and close once it reports EOF.
    """
    if self._force_close or self._close:
        return

    parsed: Sequence[_MsgType]
    if self._payload_parser is None and not self._upgraded:
        assert self._parser is not None
        try:
            parsed, upgraded, tail = self._parser.feed_data(data)
        except HttpProcessingError as exc:
            failure = _ErrInfo(status=400, exc=exc, message=exc.message)
            parsed = [(failure, EMPTY_PAYLOAD)]
            upgraded = False
            tail = b""

        for message, stream in parsed or ():
            self._request_count += 1
            self._messages.append((message, stream))

        pending = self._waiter
        # Never resolve a waiter that is already done.
        if parsed and pending is not None and not pending.done():
            pending.set_result(None)

        self._upgraded = upgraded
        if upgraded and tail:
            self._message_tail = tail
        return

    # Past this point an empty chunk has nothing to contribute.
    if not data:
        return

    if self._payload_parser is None:
        # Upgraded, but nobody is reading yet: keep the bytes for later.
        self._message_tail += data
        return

    notify = self._data_received_cb
    if notify is not None:
        notify()
    eof, tail = self._payload_parser.feed_data(data)
    if eof:
        self.close()
def keep_alive(self, val: bool) -> None:
    """Set keep-alive connection mode.

    Any pending keep-alive timer is cancelled and forgotten.

    :param bool val: new state.
    """
    self._keepalive = val
    timer = self._keepalive_handle
    if not timer:
        return
    timer.cancel()
    self._keepalive_handle = None
def close(self) -> None:
    """Close connection.

    Stop accepting new pipelining messages and close
    connection when handlers done processing messages.
    A pending wait for the next request is cancelled.
    """
    self._close = True
    pending = self._waiter
    if pending:
        pending.cancel()
def force_close(self) -> None:
    """Forcefully close connection.

    Cancels a pending wait for the next request, then closes the
    transport and drops the reference to it.
    """
    self._force_close = True
    pending = self._waiter
    if pending:
        pending.cancel()
    transport = self.transport
    if transport is None:
        return
    transport.close()
    self.transport = None
async def log_access(
    self,
    request: BaseRequest,
    response: StreamResponse,
    request_start: float | None,
) -> None:
    """Write an access-log entry when access logging is enabled.

    Does nothing when logging is disabled or no access logger exists.
    """
    sink = self.access_logger
    if not self._logging_enabled or sink is None:
        return
    if TYPE_CHECKING:
        assert request_start is not None
    await sink.log(request, response, request_start)
def log_debug(self, *args: Any, **kw: Any) -> None:
    """Emit a debug record, but only while the event loop is in debug mode."""
    if not self._loop.get_debug():
        return
    self.logger.debug(*args, **kw)
def log_exception(self, *args: Any, **kw: Any) -> None:
    """Pass all arguments through to ``self.logger.exception``."""
    log = self.logger
    log.exception(*args, **kw)
527 def _process_keepalive(self) -> None:
528 self._keepalive_handle = None
529 if self._force_close or not self._keepalive:
530 return
532 loop = self._loop
533 now = loop.time()
534 close_time = self._next_keepalive_close_time
535 if now < close_time:
536 # Keep alive close check fired too early, reschedule
537 self._keepalive_handle = loop.call_at(close_time, self._process_keepalive)
538 return
540 # handler in idle state
541 if self._waiter and not self._waiter.done():
542 self.force_close()
async def _handle_request(
    self,
    request: _Request,
    start_time: float | None,
    request_handler: Callable[[_Request], Awaitable[StreamResponse]],
) -> tuple[StreamResponse, bool]:
    """Run *request_handler* for one request and send its response.

    Exceptions raised by the handler are mapped as follows:

    * HTTPException -- turned into a Response carrying its status,
      reason, text, headers and cookies;
    * asyncio.CancelledError -- propagated unchanged;
    * asyncio.TimeoutError -- answered with a 504 error response;
    * any other Exception -- answered with a 500 error response.

    Returns the ``(response, reset)`` pair produced by finish_response().
    """
    self._request_in_progress = True
    try:
        try:
            # Exposed so shutdown() and connection_lost() can cancel the
            # request while the handler is running.
            self._current_request = request
            resp = await request_handler(request)
        finally:
            self._current_request = None
    except HTTPException as exc:
        resp = Response(
            status=exc.status, reason=exc.reason, text=exc.text, headers=exc.headers
        )
        resp._cookies = exc._cookies
        resp, reset = await self.finish_response(request, resp, start_time)
    except asyncio.CancelledError:
        # Listed before the broader handlers so cancellation is never
        # converted into an error response.
        raise
    except asyncio.TimeoutError as exc:
        self.log_debug("Request handler timed out.", exc_info=exc)
        resp = self.handle_error(request, 504)
        resp, reset = await self.finish_response(request, resp, start_time)
    except Exception as exc:
        resp = self.handle_error(request, 500, exc)
        resp, reset = await self.finish_response(request, resp, start_time)
    else:
        resp, reset = await self.finish_response(request, resp, start_time)
    finally:
        self._request_in_progress = False
        # Set only while shutdown() is waiting for this request to finish.
        if self._handler_waiter is not None:
            self._handler_waiter.set_result(None)

    return resp, reset
async def start(self) -> None:
    """Process incoming request.

    It takes parsed requests from the message queue one at a time,
    builds a request object with the request factory and runs the
    request handler in a separate task. A queue entry describing a
    parsing failure is answered by an error handler instead.
    start() handles various exceptions in request or response
    handling. The loop ends, and the transport is closed, unless the
    response asks for keep-alive and the connection is not closing.
    """
    loop = self._loop
    manager = self._manager
    assert manager is not None
    keepalive_timeout = self._keepalive_timeout
    resp = None
    assert self._request_factory is not None
    assert self._request_handler is not None

    while not self._force_close:
        if not self._messages:
            try:
                # wait for next request
                self._waiter = loop.create_future()
                await self._waiter
            finally:
                self._waiter = None

        message, payload = self._messages.popleft()

        # time is only fetched if logging is enabled as otherwise
        # its thrown away and never used.
        start = loop.time() if self._logging_enabled else None

        manager.requests_count += 1
        writer = StreamWriter(self, loop)
        if not isinstance(message, _ErrInfo):
            request_handler = self._request_handler
        else:
            # make request_factory work
            request_handler = self._make_error_handler(message)
            message = ERROR

        # Important don't hold a reference to the current task
        # as on traceback it will prevent the task from being
        # collected and will cause a memory leak.
        request = self._request_factory(
            message,
            payload,
            self,
            writer,
            self._task_handler or asyncio.current_task(loop),  # type: ignore[arg-type]
        )
        try:
            # a new task is used for copy context vars (#3406)
            coro = self._handle_request(request, start, request_handler)
            if sys.version_info >= (3, 12):
                task = asyncio.Task(coro, loop=loop, eager_start=True)
            else:
                task = loop.create_task(coro)
            try:
                resp, reset = await task
            except ConnectionError:
                self.log_debug("Ignored premature client disconnection")
                break

            # Drop the processed task from asyncio.Task.all_tasks() early
            del task
            # reset is True when the client went away while the response
            # was being written.
            if reset:
                self.log_debug("Ignored premature client disconnection 2")
                break

            # notify server about keep-alive
            self._keepalive = bool(resp.keep_alive)

            # check payload
            if not payload.is_eof():
                lingering_time = self._lingering_time
                # Could be force closed while awaiting above tasks.
                if not self._force_close and lingering_time:  # type: ignore[redundant-expr]
                    self.log_debug(
                        "Start lingering close timer for %s sec.", lingering_time
                    )

                    now = loop.time()
                    end_t = now + lingering_time

                    try:
                        # Drain the unread request body until it ends or
                        # the lingering deadline passes.
                        while not payload.is_eof() and now < end_t:
                            async with ceil_timeout(end_t - now):
                                # read and ignore
                                await payload.readany()
                            now = loop.time()
                    except (asyncio.CancelledError, asyncio.TimeoutError):
                        # Swallowed unless this task itself is being
                        # cancelled (Python 3.11+).
                        if (
                            sys.version_info >= (3, 11)
                            and (t := asyncio.current_task())
                            and t.cancelling()
                        ):
                            raise

                # if payload still uncompleted
                if not payload.is_eof() and not self._force_close:
                    self.log_debug("Uncompleted request.")
                    self.close()

            # Any later read of this payload raises PayloadAccessError.
            payload.set_exception(_PAYLOAD_ACCESS_ERROR)

        except asyncio.CancelledError:
            self.log_debug("Ignored premature client disconnection")
            self.force_close()
            raise
        except Exception as exc:
            self.log_exception("Unhandled exception", exc_info=exc)
            self.force_close()
        except BaseException:
            self.force_close()
            raise
        finally:
            request._task = None  # type: ignore[assignment]  # Break reference cycle in case of exception
            if self.transport is None and resp is not None:
                self.log_debug("Ignored premature client disconnection.")

        # NOTE(review): nesting reconstructed from a listing that lost its
        # indentation -- confirm this section runs after the ``finally``
        # block rather than inside it.
        if self._keepalive and not self._close and not self._force_close:
            # start keep-alive timer
            close_time = loop.time() + keepalive_timeout
            self._next_keepalive_close_time = close_time
            if self._keepalive_handle is None:
                self._keepalive_handle = loop.call_at(
                    close_time, self._process_keepalive
                )
        else:
            break

    # remove handler, close transport if no handlers left
    if not self._force_close:
        self._task_handler = None
        if self.transport is not None:
            self.transport.close()
async def finish_response(
    self, request: BaseRequest, resp: StreamResponse, start_time: float | None
) -> tuple[StreamResponse, bool]:
    """Prepare the response and write_eof, then log access.

    This has to be called within the context of any exception so the
    access logger can get exception information.

    An object without a ``prepare`` attribute (for example None, when the
    handler forgot to return) is replaced by a 500 response after the
    problem is logged.

    Returns ``(response, reset)``; ``reset`` is True if the client
    disconnected prematurely while the response was being sent.
    """
    request._finish()

    parser = self._parser
    if parser is not None:
        # Back to plain HTTP parsing; replay bytes buffered meanwhile.
        parser.set_upgraded(False)
        self._upgraded = False
        if self._message_tail:
            parser.feed_data(self._message_tail)
            self._message_tail = b""

    try:
        prepare = resp.prepare
    except AttributeError:
        if resp is None:
            problem = "Missing return statement on request handler"  # type: ignore[unreachable]
        else:
            problem = f"Web-handler should return a response instance, got {resp!r}"
        self.log_exception(problem)
        server_error = HTTPInternalServerError()
        resp = Response(
            status=server_error.status,
            reason=server_error.reason,
            text=server_error.text,
            headers=server_error.headers,
        )
        prepare = resp.prepare

    try:
        await prepare(request)
        await resp.write_eof()
    except ConnectionError:
        # Logged inside the handler so the logger can see the exception.
        await self.log_access(request, resp, start_time)
        return resp, True

    await self.log_access(request, resp, start_time)
    return resp, False
def handle_error(
    self,
    request: BaseRequest,
    status: int = 500,
    exc: BaseException | None = None,
    message: str | None = None,
) -> StreamResponse:
    """Handle errors.

    Returns HTTP response with specific status code. Logs additional
    information. It always closes current connection.

    For status 500 the body is generated here (HTML when the request's
    Accept header mentions ``text/html``, plain text otherwise, with a
    traceback when the loop is in debug mode) and *message* is ignored;
    for any other status *message* is used as the body.

    Raises ConnectionError when part of a response was already written.
    """
    if self._request_count == 1 and isinstance(exc, BadHttpMethod):
        # BadHttpMethod is common when a client sends non-HTTP
        # or encrypted traffic to an HTTP port. This is expected
        # to happen when connected to the public internet so we log
        # it at the debug level as to not fill logs with noise.
        report = self.logger.debug
    else:
        report = self.log_exception
    report("Error handling request from %s", request.remote, exc_info=exc)

    # some data already got sent, connection is broken
    if request.writer.output_size > 0:
        raise ConnectionError(
            "Response is sent already, cannot send another response "
            "with the error message"
        )

    content_type = "text/plain"
    server_error = HTTPStatus.INTERNAL_SERVER_ERROR
    if status == server_error:
        title = f"{server_error.value} {server_error.phrase}"
        detail = server_error.description

        trace = None
        if self._loop.get_debug():
            with suppress(Exception):
                trace = traceback.format_exc()

        wants_html = "text/html" in request.headers.get("Accept", "")
        if wants_html:
            if trace:
                trace = html_escape(trace)
                detail = f"<h2>Traceback:</h2>\n<pre>{trace}</pre>"
            message = (
                "<html><head>"
                f"<title>{title}</title>"
                f"</head><body>\n<h1>{title}</h1>"
                f"\n{detail}\n</body></html>\n"
            )
            content_type = "text/html"
        else:
            if trace:
                detail = trace
            message = f"{title}\n\n{detail}"

    resp = Response(status=status, text=message, content_type=content_type)
    resp.force_close()

    return resp
def _make_error_handler(
    self, err_info: _ErrInfo
) -> Callable[[BaseRequest], Awaitable[StreamResponse]]:
    """Build a request handler that answers with the error in *err_info*.

    The returned coroutine function ignores normal dispatch and simply
    delegates to handle_error() with the recorded status, exception and
    message.
    """
    status, exc, message = err_info.status, err_info.exc, err_info.message

    async def handler(request: BaseRequest) -> StreamResponse:
        return self.handle_error(request, status, exc, message)

    return handler