Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/web_protocol.py: 18%
354 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-26 06:16 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-26 06:16 +0000
1import asyncio
2import asyncio.streams
3import dataclasses
4import traceback
5from collections import deque
6from contextlib import suppress
7from html import escape as html_escape
8from http import HTTPStatus
9from logging import Logger
10from typing import (
11 TYPE_CHECKING,
12 Any,
13 Awaitable,
14 Callable,
15 Deque,
16 Optional,
17 Sequence,
18 Tuple,
19 Type,
20 Union,
21 cast,
22)
24import yarl
26from .abc import AbstractAccessLogger, AbstractAsyncAccessLogger, AbstractStreamWriter
27from .base_protocol import BaseProtocol
28from .helpers import ceil_timeout
29from .http import (
30 HttpProcessingError,
31 HttpRequestParser,
32 HttpVersion10,
33 RawRequestMessage,
34 StreamWriter,
35)
36from .log import access_logger, server_logger
37from .streams import EMPTY_PAYLOAD, StreamReader
38from .tcp_helpers import tcp_keepalive
39from .web_exceptions import HTTPException
40from .web_log import AccessLogger
41from .web_request import BaseRequest
42from .web_response import Response, StreamResponse
__all__ = ("RequestHandler", "RequestPayloadError", "PayloadAccessError")


if TYPE_CHECKING:
    from .web_server import Server


# Factory signature used to build a BaseRequest for every parsed message.
_RequestFactory = Callable[
    [
        RawRequestMessage,
        StreamReader,
        "RequestHandler",
        AbstractStreamWriter,
        "asyncio.Task[None]",
    ],
    BaseRequest,
]

# Application-level handler: takes a request, returns a prepared response.
_RequestHandler = Callable[[BaseRequest], Awaitable[StreamResponse]]

# Both the sync and the async access-logger class styles are accepted.
_AnyAbstractAccessLogger = Union[
    Type[AbstractAsyncAccessLogger],
    Type[AbstractAccessLogger],
]


# Placeholder message substituted for the real one when request parsing
# failed; it lets the normal request pipeline produce the error response.
ERROR = RawRequestMessage(
    "UNKNOWN",
    "/",
    HttpVersion10,
    {},  # type: ignore[arg-type]
    {},  # type: ignore[arg-type]
    True,
    None,
    False,
    False,
    yarl.URL("/"),
)
class RequestPayloadError(Exception):
    """Payload parsing error."""


class PayloadAccessError(Exception):
    """Payload was accessed after response was sent."""
class AccessLoggerWrapper(AbstractAsyncAccessLogger):
    """Wrap an AbstractAccessLogger so it behaves like an AbstractAsyncAccessLogger."""

    def __init__(
        self, access_logger: AbstractAccessLogger, loop: asyncio.AbstractEventLoop
    ) -> None:
        # Wrapped synchronous logger instance.
        self.access_logger = access_logger
        # Loop is kept only to compute elapsed time in log().
        self._loop = loop
        super().__init__()

    async def log(
        self, request: BaseRequest, response: StreamResponse, request_start: float
    ) -> None:
        # The sync logger expects the elapsed time, not the start timestamp,
        # so convert here using the loop's monotonic clock.
        self.access_logger.log(request, response, self._loop.time() - request_start)
@dataclasses.dataclass(frozen=True)
class _ErrInfo:
    # Details of a request-parsing failure, carried through the message
    # queue in place of a RawRequestMessage.
    status: int
    exc: BaseException
    message: str


# A queued item: either a parsed request or an error placeholder, plus the
# stream the body will be read from.
_MsgType = Tuple[Union[RawRequestMessage, _ErrInfo], StreamReader]
class RequestHandler(BaseProtocol):
    """HTTP protocol implementation.

    RequestHandler handles incoming HTTP request. It reads request line,
    request headers and request payload and calls handle_request() method.
    By default it always returns with 404 response.

    RequestHandler handles errors in incoming request, like bad
    status line, bad headers or incomplete payload. If any error occurs,
    connection gets closed.

    keepalive_timeout -- number of seconds before closing
                         keep-alive connection

    tcp_keepalive -- TCP keep-alive is on, default is on

    logger -- custom logger object

    access_log_class -- custom class for access_logger

    access_log -- custom logging object

    access_log_format -- access log format string

    loop -- Optional event loop

    max_line_size -- Optional maximum header line size

    max_field_size -- Optional maximum header field size

    timeout_ceil_threshold -- Optional value to specify
                              threshold to ceil() timeout
                              values

    """

    # Seconds between re-checks while a keep-alive connection is still busy.
    KEEPALIVE_RESCHEDULE_DELAY = 1

    __slots__ = (
        "_request_count",
        "_keepalive",
        "_manager",
        "_request_handler",
        "_request_factory",
        "_tcp_keepalive",
        "_keepalive_time",
        "_keepalive_handle",
        "_keepalive_timeout",
        "_lingering_time",
        "_messages",
        "_message_tail",
        "_waiter",
        "_task_handler",
        "_upgrade",
        "_payload_parser",
        "_request_parser",
        "logger",
        "access_log",
        "access_logger",
        "_close",
        "_force_close",
        "_current_request",
        "_timeout_ceil_threshold",
    )

    def __init__(
        self,
        manager: "Server",
        *,
        loop: asyncio.AbstractEventLoop,
        keepalive_timeout: float = 75.0,  # NGINX default is 75 secs
        tcp_keepalive: bool = True,
        logger: Logger = server_logger,
        access_log_class: _AnyAbstractAccessLogger = AccessLogger,
        access_log: Optional[Logger] = access_logger,
        access_log_format: str = AccessLogger.LOG_FORMAT,
        max_line_size: int = 8190,
        max_field_size: int = 8190,
        lingering_time: float = 10.0,
        read_bufsize: int = 2**16,
        auto_decompress: bool = True,
        timeout_ceil_threshold: float = 5,
    ):
        super().__init__(loop)

        self._request_count = 0
        self._keepalive = False
        self._current_request: Optional[BaseRequest] = None
        self._manager: Optional[Server] = manager
        self._request_handler: Optional[_RequestHandler] = manager.request_handler
        self._request_factory: Optional[_RequestFactory] = manager.request_factory

        self._tcp_keepalive = tcp_keepalive
        # placeholder to be replaced on keepalive timeout setup
        self._keepalive_time = 0.0
        self._keepalive_handle: Optional[asyncio.Handle] = None
        self._keepalive_timeout = keepalive_timeout
        self._lingering_time = float(lingering_time)

        # Pipelined (message, payload) pairs waiting to be handled by start().
        self._messages: Deque[_MsgType] = deque()
        # Bytes received after an upgrade before a payload parser is set.
        self._message_tail = b""

        # Future start() awaits on while idle; resolved by data_received().
        self._waiter: Optional[asyncio.Future[None]] = None
        # The task running start(); created in connection_made().
        self._task_handler: Optional[asyncio.Task[None]] = None

        self._upgrade = False
        self._payload_parser: Any = None
        self._request_parser: Optional[HttpRequestParser] = HttpRequestParser(
            self,
            loop,
            read_bufsize,
            max_line_size=max_line_size,
            max_field_size=max_field_size,
            payload_exception=RequestPayloadError,
            auto_decompress=auto_decompress,
        )

        # Fall back to the default (5) if the supplied threshold is not
        # convertible to float.
        self._timeout_ceil_threshold: float = 5
        try:
            self._timeout_ceil_threshold = float(timeout_ceil_threshold)
        except (TypeError, ValueError):
            pass

        self.logger = logger
        self.access_log = access_log
        if access_log:
            if issubclass(access_log_class, AbstractAsyncAccessLogger):
                self.access_logger: Optional[
                    AbstractAsyncAccessLogger
                ] = access_log_class()
            else:
                # Legacy sync logger: adapt it to the async interface.
                access_logger = access_log_class(access_log, access_log_format)
                self.access_logger = AccessLoggerWrapper(
                    access_logger,
                    self._loop,
                )
        else:
            self.access_logger = None

        self._close = False
        self._force_close = False

    def __repr__(self) -> str:
        return "<{} {}>".format(
            self.__class__.__name__,
            "connected" if self.transport is not None else "disconnected",
        )

    @property
    def keepalive_timeout(self) -> float:
        return self._keepalive_timeout

    async def shutdown(self, timeout: Optional[float] = 15.0) -> None:
        """Do worker process exit preparations.

        We need to clean up everything and stop accepting requests.
        It is especially important for keep-alive connections.
        """
        self._force_close = True

        if self._keepalive_handle is not None:
            self._keepalive_handle.cancel()

        if self._waiter:
            self._waiter.cancel()

        # wait for handlers
        with suppress(asyncio.CancelledError, asyncio.TimeoutError):
            async with ceil_timeout(timeout):
                if self._current_request is not None:
                    self._current_request._cancel(asyncio.CancelledError())

                if self._task_handler is not None and not self._task_handler.done():
                    await self._task_handler

        # force-close non-idle handler
        if self._task_handler is not None:
            self._task_handler.cancel()

        if self.transport is not None:
            self.transport.close()
            self.transport = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        super().connection_made(transport)

        real_transport = cast(asyncio.Transport, transport)
        if self._tcp_keepalive:
            tcp_keepalive(real_transport)

        # Spawn the long-lived request-processing loop for this connection.
        self._task_handler = self._loop.create_task(self.start())
        assert self._manager is not None
        self._manager.connection_made(self, real_transport)

    def connection_lost(self, exc: Optional[BaseException]) -> None:
        if self._manager is None:
            # connection_lost() already ran (or connection_made() never did).
            return
        self._manager.connection_lost(self, exc)

        super().connection_lost(exc)

        # Grab value before setting _manager to None.
        handler_cancellation = self._manager.handler_cancellation

        self._manager = None
        self._force_close = True
        self._request_factory = None
        self._request_handler = None
        self._request_parser = None

        if self._keepalive_handle is not None:
            self._keepalive_handle.cancel()

        if self._current_request is not None:
            if exc is None:
                exc = ConnectionResetError("Connection lost")
            self._current_request._cancel(exc)

        if self._waiter is not None:
            self._waiter.cancel()

        if handler_cancellation and self._task_handler is not None:
            self._task_handler.cancel()

        self._task_handler = None

        if self._payload_parser is not None:
            self._payload_parser.feed_eof()
            self._payload_parser = None

    def set_parser(self, parser: Any) -> None:
        # Actual type is WebReader
        assert self._payload_parser is None

        self._payload_parser = parser

        # Replay any bytes that arrived before the parser was installed.
        if self._message_tail:
            self._payload_parser.feed_data(self._message_tail)
            self._message_tail = b""

    def eof_received(self) -> None:
        pass

    def data_received(self, data: bytes) -> None:
        if self._force_close or self._close:
            return
        # parse http messages
        messages: Sequence[_MsgType]
        if self._payload_parser is None and not self._upgrade:
            assert self._request_parser is not None
            try:
                messages, upgraded, tail = self._request_parser.feed_data(data)
            except HttpProcessingError as exc:
                # Queue an error placeholder so start() produces a 400.
                messages = [
                    (_ErrInfo(status=400, exc=exc, message=exc.message), EMPTY_PAYLOAD)
                ]
                upgraded = False
                tail = b""

            for msg, payload in messages or ():
                self._request_count += 1
                self._messages.append((msg, payload))

            waiter = self._waiter
            if messages and waiter is not None and not waiter.done():
                # don't set result twice
                waiter.set_result(None)

            self._upgrade = upgraded
            if upgraded and tail:
                # Leftover bytes after the upgrade point are buffered until
                # set_parser() installs the new protocol's parser.
                self._message_tail = tail

        # no parser, just store
        elif self._payload_parser is None and self._upgrade and data:
            self._message_tail += data

        # feed payload
        elif data:
            eof, tail = self._payload_parser.feed_data(data)
            if eof:
                self.close()

    def keep_alive(self, val: bool) -> None:
        """Set keep-alive connection mode.

        :param bool val: new state.
        """
        self._keepalive = val
        if self._keepalive_handle:
            self._keepalive_handle.cancel()
            self._keepalive_handle = None

    def close(self) -> None:
        """Close connection.

        Stop accepting new pipelining messages and close
        connection when handlers done processing messages.
        """
        self._close = True
        # Wake up start() so it can observe _close and finish.
        if self._waiter:
            self._waiter.cancel()

    def force_close(self) -> None:
        """Forcefully close connection."""
        self._force_close = True
        if self._waiter:
            self._waiter.cancel()
        if self.transport is not None:
            self.transport.close()
            self.transport = None

    async def log_access(
        self, request: BaseRequest, response: StreamResponse, request_start: float
    ) -> None:
        if self.access_logger is not None:
            await self.access_logger.log(request, response, request_start)

    def log_debug(self, *args: Any, **kw: Any) -> None:
        # Debug logging only when the loop runs in debug mode.
        if self._loop.get_debug():
            self.logger.debug(*args, **kw)

    def log_exception(self, *args: Any, **kw: Any) -> None:
        self.logger.exception(*args, **kw)

    def _process_keepalive(self) -> None:
        # Timer callback: close the connection once it has been idle for
        # the full keep-alive timeout, otherwise check again later.
        if self._force_close or not self._keepalive:
            return

        next = self._keepalive_time + self._keepalive_timeout

        # handler in idle state
        if self._waiter:
            if self._loop.time() > next:
                self.force_close()
                return

        # not all request handlers are done,
        # reschedule itself to next second
        self._keepalive_handle = self._loop.call_later(
            self.KEEPALIVE_RESCHEDULE_DELAY,
            self._process_keepalive,
        )

    async def _handle_request(
        self,
        request: BaseRequest,
        start_time: float,
        request_handler: Callable[[BaseRequest], Awaitable[StreamResponse]],
    ) -> Tuple[StreamResponse, bool]:
        # Run one request through the handler, translating failures into
        # error responses. Returns (response, client_disconnected).
        assert self._request_handler is not None
        try:
            try:
                self._current_request = request
                resp = await request_handler(request)
            finally:
                self._current_request = None
        except HTTPException as exc:
            # Web exceptions are part of normal flow: render them as-is.
            resp = Response(
                status=exc.status, reason=exc.reason, text=exc.text, headers=exc.headers
            )
            resp._cookies = exc._cookies
            reset = await self.finish_response(request, resp, start_time)
        except asyncio.CancelledError:
            raise
        except asyncio.TimeoutError as exc:
            self.log_debug("Request handler timed out.", exc_info=exc)
            resp = self.handle_error(request, 504)
            reset = await self.finish_response(request, resp, start_time)
        except Exception as exc:
            resp = self.handle_error(request, 500, exc)
            reset = await self.finish_response(request, resp, start_time)
        else:
            reset = await self.finish_response(request, resp, start_time)

        return resp, reset

    async def start(self) -> None:
        """Process incoming request.

        It reads request line, request headers and request payload, then
        calls handle_request() method. Subclass has to override
        handle_request(). start() handles various exceptions in request
        or response handling. Connection is being closed always unless
        keep_alive(True) specified.
        """
        loop = self._loop
        handler = self._task_handler
        assert handler is not None
        manager = self._manager
        assert manager is not None
        keepalive_timeout = self._keepalive_timeout
        resp = None
        assert self._request_factory is not None
        assert self._request_handler is not None

        while not self._force_close:
            if not self._messages:
                try:
                    # wait for next request
                    self._waiter = loop.create_future()
                    await self._waiter
                except asyncio.CancelledError:
                    break
                finally:
                    self._waiter = None

            message, payload = self._messages.popleft()

            start = loop.time()

            manager.requests_count += 1
            writer = StreamWriter(self, loop)
            if isinstance(message, _ErrInfo):
                # make request_factory work
                request_handler = self._make_error_handler(message)
                message = ERROR
            else:
                request_handler = self._request_handler

            request = self._request_factory(message, payload, self, writer, handler)
            try:
                # a new task is used for copy context vars (#3406)
                task = self._loop.create_task(
                    self._handle_request(request, start, request_handler)
                )
                try:
                    resp, reset = await task
                except (asyncio.CancelledError, ConnectionError):
                    self.log_debug("Ignored premature client disconnection")
                    break

                # Drop the processed task from asyncio.Task.all_tasks() early
                del task
                # https://github.com/python/mypy/issues/14309
                if reset:  # type: ignore[possibly-undefined]
                    self.log_debug("Ignored premature client disconnection 2")
                    break

                # notify server about keep-alive
                self._keepalive = bool(resp.keep_alive)

                # check payload
                if not payload.is_eof():
                    lingering_time = self._lingering_time
                    # Could be force closed while awaiting above tasks.
                    if not self._force_close and lingering_time:  # type: ignore[redundant-expr]
                        self.log_debug(
                            "Start lingering close timer for %s sec.", lingering_time
                        )

                        now = loop.time()
                        end_t = now + lingering_time

                        # Drain the remaining body so the client can read our
                        # response before the connection is torn down.
                        with suppress(asyncio.TimeoutError, asyncio.CancelledError):
                            while not payload.is_eof() and now < end_t:
                                async with ceil_timeout(end_t - now):
                                    # read and ignore
                                    await payload.readany()
                                now = loop.time()

                    # if payload still uncompleted
                    if not payload.is_eof() and not self._force_close:
                        self.log_debug("Uncompleted request.")
                        self.close()

                # Any later access to the consumed payload is a bug.
                payload.set_exception(PayloadAccessError())

            except asyncio.CancelledError:
                self.log_debug("Ignored premature client disconnection ")
                break
            except RuntimeError as exc:
                if self._loop.get_debug():
                    self.log_exception("Unhandled runtime exception", exc_info=exc)
                self.force_close()
            except Exception as exc:
                self.log_exception("Unhandled exception", exc_info=exc)
                self.force_close()
            finally:
                if self.transport is None and resp is not None:
                    self.log_debug("Ignored premature client disconnection.")
                elif not self._force_close:
                    if self._keepalive and not self._close:
                        # start keep-alive timer
                        if keepalive_timeout is not None:
                            now = self._loop.time()
                            self._keepalive_time = now
                            if self._keepalive_handle is None:
                                self._keepalive_handle = loop.call_at(
                                    now + keepalive_timeout, self._process_keepalive
                                )
                    else:
                        break

        # remove handler, close transport if no handlers left
        if not self._force_close:
            self._task_handler = None
            if self.transport is not None:
                self.transport.close()

    async def finish_response(
        self, request: BaseRequest, resp: StreamResponse, start_time: float
    ) -> bool:
        """Prepare the response and write_eof, then log access.

        This has to
        be called within the context of any exception so the access logger
        can get exception information. Returns True if the client disconnects
        prematurely.
        """
        request._finish()
        if self._request_parser is not None:
            self._request_parser.set_upgraded(False)
            self._upgrade = False
            if self._message_tail:
                # Hand back any buffered post-upgrade bytes for re-parsing.
                self._request_parser.feed_data(self._message_tail)
                self._message_tail = b""
        try:
            prepare_meth = resp.prepare
        except AttributeError:
            # The handler returned None or a non-response object; raise a
            # clearer error than the AttributeError itself.
            if resp is None:
                raise RuntimeError("Missing return " "statement on request handler")
            else:
                raise RuntimeError(
                    "Web-handler should return "
                    "a response instance, "
                    "got {!r}".format(resp)
                )
        try:
            await prepare_meth(request)
            await resp.write_eof()
        except ConnectionError:
            await self.log_access(request, resp, start_time)
            return True
        else:
            await self.log_access(request, resp, start_time)
            return False

    def handle_error(
        self,
        request: BaseRequest,
        status: int = 500,
        exc: Optional[BaseException] = None,
        message: Optional[str] = None,
    ) -> StreamResponse:
        """Handle errors.

        Returns HTTP response with specific status code. Logs additional
        information. It always closes current connection.
        """
        self.log_exception("Error handling request", exc_info=exc)

        # some data already got sent, connection is broken
        if request.writer.output_size > 0:
            raise ConnectionError(
                "Response is sent already, cannot send another response "
                "with the error message"
            )

        ct = "text/plain"
        if status == HTTPStatus.INTERNAL_SERVER_ERROR:
            title = "{0.value} {0.phrase}".format(HTTPStatus.INTERNAL_SERVER_ERROR)
            msg = HTTPStatus.INTERNAL_SERVER_ERROR.description
            tb = None
            if self._loop.get_debug():
                with suppress(Exception):
                    tb = traceback.format_exc()

            if "text/html" in request.headers.get("Accept", ""):
                if tb:
                    # Escape the traceback before embedding it in HTML.
                    tb = html_escape(tb)
                    msg = f"<h2>Traceback:</h2>\n<pre>{tb}</pre>"
                message = (
                    "<html><head>"
                    "<title>{title}</title>"
                    "</head><body>\n<h1>{title}</h1>"
                    "\n{msg}\n</body></html>\n"
                ).format(title=title, msg=msg)
                ct = "text/html"
            else:
                if tb:
                    msg = tb
                message = title + "\n\n" + msg

        resp = Response(status=status, text=message, content_type=ct)
        resp.force_close()

        return resp

    def _make_error_handler(
        self, err_info: _ErrInfo
    ) -> Callable[[BaseRequest], Awaitable[StreamResponse]]:
        # Build a stand-in handler that renders the stored parse error.
        async def handler(request: BaseRequest) -> StreamResponse:
            return self.handle_error(
                request, err_info.status, err_info.exc, err_info.message
            )

        return handler