Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/web_protocol.py: 19%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import asyncio
2import asyncio.streams
3import sys
4import traceback
5from collections import deque
6from contextlib import suppress
7from html import escape as html_escape
8from http import HTTPStatus
9from logging import Logger
10from typing import (
11 TYPE_CHECKING,
12 Any,
13 Awaitable,
14 Callable,
15 Deque,
16 Generic,
17 Optional,
18 Sequence,
19 Tuple,
20 Type,
21 TypeVar,
22 Union,
23 cast,
24)
26import yarl
27from propcache import under_cached_property
29from .abc import AbstractAccessLogger, AbstractAsyncAccessLogger, AbstractStreamWriter
30from .base_protocol import BaseProtocol
31from .helpers import ceil_timeout, frozen_dataclass_decorator
32from .http import (
33 HttpProcessingError,
34 HttpRequestParser,
35 HttpVersion10,
36 RawRequestMessage,
37 StreamWriter,
38)
39from .http_exceptions import BadHttpMethod
40from .log import access_logger, server_logger
41from .streams import EMPTY_PAYLOAD, StreamReader
42from .tcp_helpers import tcp_keepalive
43from .web_exceptions import HTTPException, HTTPInternalServerError
44from .web_log import AccessLogger
45from .web_request import BaseRequest
46from .web_response import Response, StreamResponse
48__all__ = ("RequestHandler", "RequestPayloadError", "PayloadAccessError")
if TYPE_CHECKING:
    import ssl

    from .web_server import Server


_Request = TypeVar("_Request", bound=BaseRequest)
# Factory that builds a request object from a parsed message, its payload
# stream, this protocol, the response writer, and the handler task.
_RequestFactory = Callable[
    [
        RawRequestMessage,
        StreamReader,
        "RequestHandler[_Request]",
        AbstractStreamWriter,
        "asyncio.Task[None]",
    ],
    _Request,
]

# User-supplied coroutine that turns a request into a response.
_RequestHandler = Callable[[_Request], Awaitable[StreamResponse]]
# Either access-logger flavor (sync or async) may be configured.
_AnyAbstractAccessLogger = Union[
    Type[AbstractAsyncAccessLogger],
    Type[AbstractAccessLogger],
]
# Placeholder message substituted for the real one when request parsing
# fails, so the error handler can still construct a request object.
ERROR = RawRequestMessage(
    "UNKNOWN",
    "/",
    HttpVersion10,
    {},  # type: ignore[arg-type]
    {},  # type: ignore[arg-type]
    True,
    None,
    False,
    False,
    yarl.URL("/"),
)
class RequestPayloadError(Exception):
    """Payload parsing error."""


class PayloadAccessError(Exception):
    """Payload was accessed after response was sent."""


# Single shared instance set on payloads once the response is done,
# avoiding a new exception object per request.
_PAYLOAD_ACCESS_ERROR = PayloadAccessError()
class AccessLoggerWrapper(AbstractAsyncAccessLogger):
    """Wrap an AbstractAccessLogger so it behaves like an AbstractAsyncAccessLogger."""

    __slots__ = ("access_logger", "_loop")

    def __init__(
        self, access_logger: AbstractAccessLogger, loop: asyncio.AbstractEventLoop
    ) -> None:
        # Keep the synchronous logger and the loop used to measure elapsed time.
        self.access_logger = access_logger
        self._loop = loop
        super().__init__()

    async def log(
        self, request: BaseRequest, response: StreamResponse, request_start: float
    ) -> None:
        # Adapt the async interface to the sync one: the wrapped logger
        # receives the elapsed duration, not the start timestamp.
        elapsed = self._loop.time() - request_start
        self.access_logger.log(request, response, elapsed)

    @property
    def enabled(self) -> bool:
        """Check if logger is enabled."""
        return self.access_logger.enabled
@frozen_dataclass_decorator
class _ErrInfo:
    # Details of a request-parsing failure, queued in place of a
    # RawRequestMessage so the handler loop can emit an error response.
    status: int
    exc: BaseException
    message: str
129_MsgType = Tuple[Union[RawRequestMessage, _ErrInfo], StreamReader]
class RequestHandler(BaseProtocol, Generic[_Request]):
    """HTTP protocol implementation.

    RequestHandler handles incoming HTTP request. It reads request line,
    request headers and request payload and calls handle_request() method.
    By default it always returns with 404 response.

    RequestHandler handles errors in incoming request, like bad
    status line, bad headers or incomplete payload. If any error occurs,
    connection gets closed.

    keepalive_timeout -- number of seconds before closing
                         keep-alive connection

    tcp_keepalive -- TCP keep-alive is on, default is on

    logger -- custom logger object

    access_log_class -- custom class for access_logger

    access_log -- custom logging object

    access_log_format -- access log format string

    loop -- Optional event loop

    max_line_size -- Optional maximum header line size

    max_field_size -- Optional maximum header field size

    timeout_ceil_threshold -- Optional value to specify
                              threshold to ceil() timeout
                              values
    """

    # __slots__ keeps per-connection instances small; a busy server holds
    # one RequestHandler per open connection.
    __slots__ = (
        "_request_count",
        "_keepalive",
        "_manager",
        "_request_handler",
        "_request_factory",
        "_tcp_keepalive",
        "_next_keepalive_close_time",
        "_keepalive_handle",
        "_keepalive_timeout",
        "_lingering_time",
        "_messages",
        "_message_tail",
        "_handler_waiter",
        "_waiter",
        "_task_handler",
        "_upgrade",
        "_payload_parser",
        "_request_parser",
        "logger",
        "access_log",
        "access_logger",
        "_close",
        "_force_close",
        "_current_request",
        "_timeout_ceil_threshold",
        "_request_in_progress",
        "_logging_enabled",
        "_cache",
    )
    def __init__(
        self,
        manager: "Server[_Request]",
        *,
        loop: asyncio.AbstractEventLoop,
        # Default should be high enough that it's likely longer than a reverse proxy.
        keepalive_timeout: float = 3630,
        tcp_keepalive: bool = True,
        logger: Logger = server_logger,
        access_log_class: _AnyAbstractAccessLogger = AccessLogger,
        access_log: Optional[Logger] = access_logger,
        access_log_format: str = AccessLogger.LOG_FORMAT,
        max_line_size: int = 8190,
        max_field_size: int = 8190,
        lingering_time: float = 10.0,
        read_bufsize: int = 2**16,
        auto_decompress: bool = True,
        timeout_ceil_threshold: float = 5,
    ):
        """Initialize per-connection state and the HTTP request parser."""
        super().__init__(loop)

        # _request_count is the number of requests processed with the same connection.
        self._request_count = 0
        self._keepalive = False
        self._current_request: Optional[_Request] = None
        self._manager: Optional[Server[_Request]] = manager
        self._request_handler: Optional[_RequestHandler[_Request]] = (
            manager.request_handler
        )
        self._request_factory: Optional[_RequestFactory[_Request]] = (
            manager.request_factory
        )

        self._tcp_keepalive = tcp_keepalive
        # placeholder to be replaced on keepalive timeout setup
        self._next_keepalive_close_time = 0.0
        self._keepalive_handle: Optional[asyncio.Handle] = None
        self._keepalive_timeout = keepalive_timeout
        self._lingering_time = float(lingering_time)

        # Parsed but not yet handled (message, payload) pairs (HTTP pipelining).
        self._messages: Deque[_MsgType] = deque()
        # Raw bytes received during an upgrade before a payload parser is set.
        self._message_tail = b""

        self._waiter: Optional[asyncio.Future[None]] = None
        self._handler_waiter: Optional[asyncio.Future[None]] = None
        self._task_handler: Optional[asyncio.Task[None]] = None

        self._upgrade = False
        self._payload_parser: Any = None
        self._request_parser: Optional[HttpRequestParser] = HttpRequestParser(
            self,
            loop,
            read_bufsize,
            max_line_size=max_line_size,
            max_field_size=max_field_size,
            payload_exception=RequestPayloadError,
            auto_decompress=auto_decompress,
        )

        # Keep the default of 5 if the supplied value is not numeric.
        self._timeout_ceil_threshold: float = 5
        try:
            self._timeout_ceil_threshold = float(timeout_ceil_threshold)
        except (TypeError, ValueError):
            pass

        self.logger = logger
        self.access_log = access_log
        if access_log:
            if issubclass(access_log_class, AbstractAsyncAccessLogger):
                self.access_logger: Optional[AbstractAsyncAccessLogger] = (
                    access_log_class()
                )
            else:
                # Wrap sync loggers so both flavors expose the async interface.
                access_logger = access_log_class(access_log, access_log_format)
                self.access_logger = AccessLoggerWrapper(
                    access_logger,
                    self._loop,
                )
            self._logging_enabled = self.access_logger.enabled
        else:
            self.access_logger = None
            self._logging_enabled = False

        self._close = False
        self._force_close = False
        self._request_in_progress = False
        self._cache: dict[str, Any] = {}
287 def __repr__(self) -> str:
288 return "<{} {}>".format(
289 self.__class__.__name__,
290 "connected" if self.transport is not None else "disconnected",
291 )
293 @under_cached_property
294 def ssl_context(self) -> Optional["ssl.SSLContext"]:
295 """Return SSLContext if available."""
296 return (
297 None
298 if self.transport is None
299 else self.transport.get_extra_info("sslcontext")
300 )
302 @under_cached_property
303 def peername(
304 self,
305 ) -> Optional[Union[str, Tuple[str, int, int, int], Tuple[str, int]]]:
306 """Return peername if available."""
307 return (
308 None
309 if self.transport is None
310 else self.transport.get_extra_info("peername")
311 )
    @property
    def keepalive_timeout(self) -> float:
        """Seconds an idle keep-alive connection stays open between requests."""
        return self._keepalive_timeout
    async def shutdown(self, timeout: Optional[float] = 15.0) -> None:
        """Do worker process exit preparations.

        We need to clean up everything and stop accepting requests.
        It is especially important for keep-alive connections.
        """
        self._force_close = True

        if self._keepalive_handle is not None:
            self._keepalive_handle.cancel()

        # Wait for graceful handler completion
        if self._request_in_progress:
            # The future is only created when we are shutting
            # down while the handler is still processing a request
            # to avoid creating a future for every request.
            self._handler_waiter = self._loop.create_future()
            try:
                async with ceil_timeout(timeout):
                    await self._handler_waiter
            except (asyncio.CancelledError, asyncio.TimeoutError):
                self._handler_waiter = None
                # Propagate only if shutdown() itself is being cancelled
                # (cancelling() is 3.11+); a plain timeout is absorbed so we
                # can fall through to the forceful phase below.
                if (
                    sys.version_info >= (3, 11)
                    and (task := asyncio.current_task())
                    and task.cancelling()
                ):
                    raise
        # Then cancel handler and wait
        try:
            async with ceil_timeout(timeout):
                if self._current_request is not None:
                    self._current_request._cancel(asyncio.CancelledError())

                if self._task_handler is not None and not self._task_handler.done():
                    # shield: we want to wait, not to cancel it via our timeout.
                    await asyncio.shield(self._task_handler)
        except (asyncio.CancelledError, asyncio.TimeoutError):
            if (
                sys.version_info >= (3, 11)
                and (task := asyncio.current_task())
                and task.cancelling()
            ):
                raise

        # force-close non-idle handler
        if self._task_handler is not None:
            self._task_handler.cancel()

        self.force_close()
    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Register the new connection and start the request-processing task."""
        super().connection_made(transport)

        real_transport = cast(asyncio.Transport, transport)
        if self._tcp_keepalive:
            tcp_keepalive(real_transport)

        assert self._manager is not None
        self._manager.connection_made(self, real_transport)

        loop = self._loop
        if sys.version_info >= (3, 12):
            # eager_start runs the coroutine immediately instead of scheduling it.
            task = asyncio.Task(self.start(), loop=loop, eager_start=True)
        else:
            task = loop.create_task(self.start())
        self._task_handler = task
    def connection_lost(self, exc: Optional[BaseException]) -> None:
        """Tear down per-connection state once the transport is gone."""
        if self._manager is None:
            # Already cleaned up (callback may fire more than once).
            return
        self._manager.connection_lost(self, exc)

        # Grab value before setting _manager to None.
        handler_cancellation = self._manager.handler_cancellation

        self.force_close()
        super().connection_lost(exc)
        # Drop references so the protocol object can be collected promptly.
        self._manager = None
        self._request_factory = None
        self._request_handler = None
        self._request_parser = None

        if self._keepalive_handle is not None:
            self._keepalive_handle.cancel()

        if self._current_request is not None:
            if exc is None:
                exc = ConnectionResetError("Connection lost")
            self._current_request._cancel(exc)

        if handler_cancellation and self._task_handler is not None:
            self._task_handler.cancel()

        self._task_handler = None

        if self._payload_parser is not None:
            self._payload_parser.feed_eof()
            self._payload_parser = None
416 def set_parser(self, parser: Any) -> None:
417 # Actual type is WebReader
418 assert self._payload_parser is None
420 self._payload_parser = parser
422 if self._message_tail:
423 self._payload_parser.feed_data(self._message_tail)
424 self._message_tail = b""
    def eof_received(self) -> None:
        # Intentionally a no-op; teardown is driven by connection_lost()
        # rather than by the EOF signal itself.
        pass
    def data_received(self, data: bytes) -> None:
        """Feed raw transport bytes into the HTTP parsing state machine."""
        if self._force_close or self._close:
            return
        # parse http messages
        messages: Sequence[_MsgType]
        if self._payload_parser is None and not self._upgrade:
            assert self._request_parser is not None
            try:
                messages, upgraded, tail = self._request_parser.feed_data(data)
            except HttpProcessingError as exc:
                # Queue the parse failure as a synthetic 400 message so the
                # handler loop can produce a proper error response.
                messages = [
                    (_ErrInfo(status=400, exc=exc, message=exc.message), EMPTY_PAYLOAD)
                ]
                upgraded = False
                tail = b""

            for msg, payload in messages or ():
                self._request_count += 1
                self._messages.append((msg, payload))

            waiter = self._waiter
            if messages and waiter is not None and not waiter.done():
                # don't set result twice
                waiter.set_result(None)

            self._upgrade = upgraded
            if upgraded and tail:
                self._message_tail = tail

        # no parser, just store
        elif self._payload_parser is None and self._upgrade and data:
            self._message_tail += data

        # feed payload
        elif data:
            eof, tail = self._payload_parser.feed_data(data)
            if eof:
                self.close()
468 def keep_alive(self, val: bool) -> None:
469 """Set keep-alive connection mode.
471 :param bool val: new state.
472 """
473 self._keepalive = val
474 if self._keepalive_handle:
475 self._keepalive_handle.cancel()
476 self._keepalive_handle = None
478 def close(self) -> None:
479 """Close connection.
481 Stop accepting new pipelining messages and close
482 connection when handlers done processing messages.
483 """
484 self._close = True
485 if self._waiter:
486 self._waiter.cancel()
488 def force_close(self) -> None:
489 """Forcefully close connection."""
490 self._force_close = True
491 if self._waiter:
492 self._waiter.cancel()
493 if self.transport is not None:
494 self.transport.close()
495 self.transport = None
497 async def log_access(
498 self,
499 request: BaseRequest,
500 response: StreamResponse,
501 request_start: Optional[float],
502 ) -> None:
503 if self._logging_enabled and self.access_logger is not None:
504 if TYPE_CHECKING:
505 assert request_start is not None
506 await self.access_logger.log(request, response, request_start)
508 def log_debug(self, *args: Any, **kw: Any) -> None:
509 if self._loop.get_debug():
510 self.logger.debug(*args, **kw)
    def log_exception(self, *args: Any, **kw: Any) -> None:
        """Log an exception (with traceback) via the configured server logger."""
        self.logger.exception(*args, **kw)
    def _process_keepalive(self) -> None:
        """Timer callback: close the connection once idle past the deadline."""
        self._keepalive_handle = None
        if self._force_close or not self._keepalive:
            return

        loop = self._loop
        now = loop.time()
        close_time = self._next_keepalive_close_time
        if now < close_time:
            # Keep alive close check fired too early, reschedule
            self._keepalive_handle = loop.call_at(close_time, self._process_keepalive)
            return

        # handler in idle state
        if self._waiter and not self._waiter.done():
            self.force_close()
    async def _handle_request(
        self,
        request: _Request,
        start_time: Optional[float],
        request_handler: Callable[[_Request], Awaitable[StreamResponse]],
    ) -> Tuple[StreamResponse, bool]:
        """Run the user handler for one request and finalize its response.

        Returns ``(response, reset)`` where ``reset`` is True when the client
        disconnected before the response could be fully written.
        """
        self._request_in_progress = True
        try:
            try:
                self._current_request = request
                resp = await request_handler(request)
            finally:
                self._current_request = None
        except HTTPException as exc:
            # HTTP exceptions are converted into regular responses, not errors.
            resp = Response(
                status=exc.status, reason=exc.reason, text=exc.text, headers=exc.headers
            )
            resp._cookies = exc._cookies
            resp, reset = await self.finish_response(request, resp, start_time)
        except asyncio.CancelledError:
            raise
        except asyncio.TimeoutError as exc:
            self.log_debug("Request handler timed out.", exc_info=exc)
            resp = self.handle_error(request, 504)
            resp, reset = await self.finish_response(request, resp, start_time)
        except Exception as exc:
            resp = self.handle_error(request, 500, exc)
            resp, reset = await self.finish_response(request, resp, start_time)
        else:
            resp, reset = await self.finish_response(request, resp, start_time)
        finally:
            self._request_in_progress = False
            if self._handler_waiter is not None:
                # Unblock shutdown(), which waits for the in-flight request.
                self._handler_waiter.set_result(None)

        return resp, reset
    async def start(self) -> None:
        """Process incoming request.

        It reads request line, request headers and request payload, then
        calls handle_request() method. Subclass has to override
        handle_request(). start() handles various exceptions in request
        or response handling. Connection is being closed always unless
        keep_alive(True) specified.
        """
        loop = self._loop
        manager = self._manager
        assert manager is not None
        keepalive_timeout = self._keepalive_timeout
        resp = None
        assert self._request_factory is not None
        assert self._request_handler is not None

        while not self._force_close:
            if not self._messages:
                try:
                    # wait for next request
                    self._waiter = loop.create_future()
                    await self._waiter
                finally:
                    self._waiter = None

            message, payload = self._messages.popleft()

            # time is only fetched if logging is enabled as otherwise
            # its thrown away and never used.
            start = loop.time() if self._logging_enabled else None

            manager.requests_count += 1
            writer = StreamWriter(self, loop)
            if not isinstance(message, _ErrInfo):
                request_handler = self._request_handler
            else:
                # make request_factory work
                request_handler = self._make_error_handler(message)
                message = ERROR

            # Important don't hold a reference to the current task
            # as on traceback it will prevent the task from being
            # collected and will cause a memory leak.
            request = self._request_factory(
                message,
                payload,
                self,
                writer,
                self._task_handler or asyncio.current_task(loop),  # type: ignore[arg-type]
            )
            try:
                # a new task is used for copy context vars (#3406)
                coro = self._handle_request(request, start, request_handler)
                if sys.version_info >= (3, 12):
                    task = asyncio.Task(coro, loop=loop, eager_start=True)
                else:
                    task = loop.create_task(coro)
                try:
                    resp, reset = await task
                except ConnectionError:
                    self.log_debug("Ignored premature client disconnection")
                    break

                # Drop the processed task from asyncio.Task.all_tasks() early
                del task
                # https://github.com/python/mypy/issues/14309
                if reset:  # type: ignore[possibly-undefined]
                    self.log_debug("Ignored premature client disconnection 2")
                    break

                # notify server about keep-alive
                self._keepalive = bool(resp.keep_alive)

                # check payload
                if not payload.is_eof():
                    lingering_time = self._lingering_time
                    # Could be force closed while awaiting above tasks.
                    if not self._force_close and lingering_time:  # type: ignore[redundant-expr]
                        self.log_debug(
                            "Start lingering close timer for %s sec.", lingering_time
                        )

                        now = loop.time()
                        end_t = now + lingering_time

                        try:
                            # Drain the unread body so the connection can be
                            # reused, but never longer than lingering_time.
                            while not payload.is_eof() and now < end_t:
                                async with ceil_timeout(end_t - now):
                                    # read and ignore
                                    await payload.readany()
                                now = loop.time()
                        except (asyncio.CancelledError, asyncio.TimeoutError):
                            if (
                                sys.version_info >= (3, 11)
                                and (t := asyncio.current_task())
                                and t.cancelling()
                            ):
                                raise

                    # if payload still uncompleted
                    if not payload.is_eof() and not self._force_close:
                        self.log_debug("Uncompleted request.")
                        self.close()

                payload.set_exception(_PAYLOAD_ACCESS_ERROR)

            except asyncio.CancelledError:
                self.log_debug("Ignored premature client disconnection")
                self.force_close()
                raise
            except Exception as exc:
                self.log_exception("Unhandled exception", exc_info=exc)
                self.force_close()
            except BaseException:
                self.force_close()
                raise
            finally:
                request._task = None  # type: ignore[assignment] # Break reference cycle in case of exception
                if self.transport is None and resp is not None:
                    self.log_debug("Ignored premature client disconnection.")

            if self._keepalive and not self._close and not self._force_close:
                # start keep-alive timer
                close_time = loop.time() + keepalive_timeout
                self._next_keepalive_close_time = close_time
                if self._keepalive_handle is None:
                    self._keepalive_handle = loop.call_at(
                        close_time, self._process_keepalive
                    )
            else:
                break

        # remove handler, close transport if no handlers left
        if not self._force_close:
            self._task_handler = None
            if self.transport is not None:
                self.transport.close()
    async def finish_response(
        self, request: BaseRequest, resp: StreamResponse, start_time: Optional[float]
    ) -> Tuple[StreamResponse, bool]:
        """Prepare the response and write_eof, then log access.

        This has to
        be called within the context of any exception so the access logger
        can get exception information. Returns True if the client disconnects
        prematurely.
        """
        request._finish()
        if self._request_parser is not None:
            self._request_parser.set_upgraded(False)
            self._upgrade = False
            if self._message_tail:
                # Replay bytes buffered while the connection looked upgraded.
                self._request_parser.feed_data(self._message_tail)
                self._message_tail = b""
        try:
            prepare_meth = resp.prepare
        except AttributeError:
            # The handler returned something that is not a response object;
            # log the mistake and substitute a 500.
            if resp is None:
                self.log_exception("Missing return statement on request handler")  # type: ignore[unreachable]
            else:
                self.log_exception(
                    "Web-handler should return a response instance, "
                    "got {!r}".format(resp)
                )
            exc = HTTPInternalServerError()
            resp = Response(
                status=exc.status, reason=exc.reason, text=exc.text, headers=exc.headers
            )
            prepare_meth = resp.prepare
        try:
            await prepare_meth(request)
            await resp.write_eof()
        except ConnectionError:
            await self.log_access(request, resp, start_time)
            return resp, True

        await self.log_access(request, resp, start_time)
        return resp, False
    def handle_error(
        self,
        request: BaseRequest,
        status: int = 500,
        exc: Optional[BaseException] = None,
        message: Optional[str] = None,
    ) -> StreamResponse:
        """Handle errors.

        Returns HTTP response with specific status code. Logs additional
        information. It always closes current connection.
        """
        if self._request_count == 1 and isinstance(exc, BadHttpMethod):
            # BadHttpMethod is common when a client sends non-HTTP
            # or encrypted traffic to an HTTP port. This is expected
            # to happen when connected to the public internet so we log
            # it at the debug level as to not fill logs with noise.
            self.logger.debug(
                "Error handling request from %s", request.remote, exc_info=exc
            )
        else:
            self.log_exception(
                "Error handling request from %s", request.remote, exc_info=exc
            )

        # some data already got sent, connection is broken
        if request.writer.output_size > 0:
            raise ConnectionError(
                "Response is sent already, cannot send another response "
                "with the error message"
            )

        ct = "text/plain"
        if status == HTTPStatus.INTERNAL_SERVER_ERROR:
            title = "{0.value} {0.phrase}".format(HTTPStatus.INTERNAL_SERVER_ERROR)
            msg = HTTPStatus.INTERNAL_SERVER_ERROR.description
            tb = None
            if self._loop.get_debug():
                # Traceback is only exposed in debug mode.
                with suppress(Exception):
                    tb = traceback.format_exc()

            if "text/html" in request.headers.get("Accept", ""):
                if tb:
                    # Escape the traceback before embedding it in HTML.
                    tb = html_escape(tb)
                    msg = f"<h2>Traceback:</h2>\n<pre>{tb}</pre>"
                message = (
                    "<html><head>"
                    "<title>{title}</title>"
                    "</head><body>\n<h1>{title}</h1>"
                    "\n{msg}\n</body></html>\n"
                ).format(title=title, msg=msg)
                ct = "text/html"
            else:
                if tb:
                    msg = tb
                message = title + "\n\n" + msg

        resp = Response(status=status, text=message, content_type=ct)
        resp.force_close()

        return resp
812 def _make_error_handler(
813 self, err_info: _ErrInfo
814 ) -> Callable[[BaseRequest], Awaitable[StreamResponse]]:
815 async def handler(request: BaseRequest) -> StreamResponse:
816 return self.handle_error(
817 request, err_info.status, err_info.exc, err_info.message
818 )
820 return handler