Coverage report for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/web_protocol.py: 18% of 352 statements covered.
« prev ^ index » next — coverage.py v7.2.7, created at 2023-06-07 06:52 +0000
1import asyncio
2import asyncio.streams
3import dataclasses
4import traceback
5from collections import deque
6from contextlib import suppress
7from html import escape as html_escape
8from http import HTTPStatus
9from logging import Logger
10from typing import (
11 TYPE_CHECKING,
12 Any,
13 Awaitable,
14 Callable,
15 Deque,
16 Optional,
17 Sequence,
18 Tuple,
19 Type,
20 Union,
21 cast,
22)
24import yarl
26from .abc import AbstractAccessLogger, AbstractAsyncAccessLogger, AbstractStreamWriter
27from .base_protocol import BaseProtocol
28from .helpers import ceil_timeout
29from .http import (
30 HttpProcessingError,
31 HttpRequestParser,
32 HttpVersion10,
33 RawRequestMessage,
34 StreamWriter,
35)
36from .log import access_logger, server_logger
37from .streams import EMPTY_PAYLOAD, StreamReader
38from .tcp_helpers import tcp_keepalive
39from .web_exceptions import HTTPException
40from .web_log import AccessLogger
41from .web_request import BaseRequest
42from .web_response import Response, StreamResponse
44__all__ = ("RequestHandler", "RequestPayloadError", "PayloadAccessError")
46if TYPE_CHECKING: # pragma: no cover
47 from .web_server import Server
# Factory signature the Server supplies to build a BaseRequest for each
# parsed message: (message, payload, protocol, writer, handler_task).
_RequestFactory = Callable[
    [
        RawRequestMessage,
        StreamReader,
        "RequestHandler",
        AbstractStreamWriter,
        "asyncio.Task[None]",
    ],
    BaseRequest,
]

# User-supplied coroutine that turns a request into a response.
_RequestHandler = Callable[[BaseRequest], Awaitable[StreamResponse]]

# Either access-logger flavour is accepted by RequestHandler; synchronous
# loggers are adapted via AccessLoggerWrapper.
_AnyAbstractAccessLogger = Union[
    Type[AbstractAsyncAccessLogger],
    Type[AbstractAccessLogger],
]
# Placeholder request message substituted for the real one when parsing
# failed; it lets the request factory still construct a request object so
# the error response path can run (see RequestHandler.start()).
ERROR = RawRequestMessage(
    "UNKNOWN",
    "/",
    HttpVersion10,
    {},  # type: ignore[arg-type]
    {},  # type: ignore[arg-type]
    True,
    None,
    False,
    False,
    yarl.URL("/"),
)
class RequestPayloadError(Exception):
    """Payload parsing error."""
class PayloadAccessError(Exception):
    """Payload was accessed after response was sent."""
class AccessLoggerWrapper(AbstractAsyncAccessLogger):
    """Adapt a synchronous AbstractAccessLogger to the async logger API."""

    def __init__(
        self, access_logger: AbstractAccessLogger, loop: asyncio.AbstractEventLoop
    ) -> None:
        super().__init__()
        self.access_logger = access_logger
        self._loop = loop

    async def log(
        self, request: BaseRequest, response: StreamResponse, request_start: float
    ) -> None:
        # The async API passes an absolute start time; the sync logger
        # expects an elapsed duration, so convert before delegating.
        elapsed = self._loop.time() - request_start
        self.access_logger.log(request, response, elapsed)
@dataclasses.dataclass(frozen=True)
class _ErrInfo:
    # Details of a request-parsing failure, queued in place of a
    # RawRequestMessage and later turned into an error response by
    # RequestHandler._make_error_handler().
    status: int  # HTTP status code to respond with (e.g. 400)
    exc: BaseException  # the parser exception that was raised
    message: str  # human-readable error message for the response
# One queued unit of work: either a successfully parsed request message or
# the error info describing why parsing failed, plus the payload stream.
_MsgType = Tuple[Union[RawRequestMessage, _ErrInfo], StreamReader]
class RequestHandler(BaseProtocol):
    """HTTP protocol implementation.

    RequestHandler handles incoming HTTP request. It reads request line,
    request headers and request payload and calls handle_request() method.
    By default it always returns with 404 response.

    RequestHandler handles errors in incoming request, like bad
    status line, bad headers or incomplete payload. If any error occurs,
    connection gets closed.

    keepalive_timeout -- number of seconds before closing
                         keep-alive connection

    tcp_keepalive -- TCP keep-alive is on, default is on

    logger -- custom logger object

    access_log_class -- custom class for access_logger

    access_log -- custom logging object

    access_log_format -- access log format string

    loop -- Optional event loop

    max_line_size -- Optional maximum header line size

    max_field_size -- Optional maximum header field size

    timeout_ceil_threshold -- Optional value to specify
                              threshold to ceil() timeout
                              values
    """

    # Interval (seconds) between re-checks while handlers are still busy
    # and the keep-alive deadline cannot be enforced yet.
    KEEPALIVE_RESCHEDULE_DELAY = 1

    __slots__ = (
        "_request_count",
        "_keepalive",
        "_manager",
        "_request_handler",
        "_request_factory",
        "_tcp_keepalive",
        "_keepalive_time",
        "_keepalive_handle",
        "_keepalive_timeout",
        "_lingering_time",
        "_messages",
        "_message_tail",
        "_waiter",
        "_task_handler",
        "_upgrade",
        "_payload_parser",
        "_request_parser",
        "logger",
        "access_log",
        "access_logger",
        "_close",
        "_force_close",
        "_current_request",
        "_timeout_ceil_threshold",
    )

    def __init__(
        self,
        manager: "Server",
        *,
        loop: asyncio.AbstractEventLoop,
        keepalive_timeout: float = 75.0,  # NGINX default is 75 secs
        tcp_keepalive: bool = True,
        logger: Logger = server_logger,
        access_log_class: _AnyAbstractAccessLogger = AccessLogger,
        access_log: Logger = access_logger,
        access_log_format: str = AccessLogger.LOG_FORMAT,
        max_line_size: int = 8190,
        max_field_size: int = 8190,
        lingering_time: float = 10.0,
        read_bufsize: int = 2**16,
        auto_decompress: bool = True,
        timeout_ceil_threshold: float = 5,
    ):
        super().__init__(loop)

        self._request_count = 0
        self._keepalive = False
        self._current_request: Optional[BaseRequest] = None
        self._manager: Optional[Server] = manager
        self._request_handler: Optional[_RequestHandler] = manager.request_handler
        self._request_factory: Optional[_RequestFactory] = manager.request_factory

        self._tcp_keepalive = tcp_keepalive
        # placeholder to be replaced on keepalive timeout setup
        self._keepalive_time = 0.0
        self._keepalive_handle: Optional[asyncio.Handle] = None
        self._keepalive_timeout = keepalive_timeout
        self._lingering_time = float(lingering_time)

        # FIFO of parsed-but-not-yet-handled messages (HTTP pipelining).
        self._messages: Deque[_MsgType] = deque()
        # Bytes received after a protocol upgrade, buffered until a
        # payload parser is installed via set_parser().
        self._message_tail = b""

        # Future the start() loop awaits while idle; resolved/cancelled
        # from data_received()/close()/shutdown().
        self._waiter: Optional[asyncio.Future[None]] = None
        # Task running start(); created in connection_made().
        self._task_handler: Optional[asyncio.Task[None]] = None

        self._upgrade = False
        self._payload_parser: Any = None
        self._request_parser: Optional[HttpRequestParser] = HttpRequestParser(
            self,
            loop,
            read_bufsize,
            max_line_size=max_line_size,
            max_field_size=max_field_size,
            payload_exception=RequestPayloadError,
            auto_decompress=auto_decompress,
        )

        # Fall back to the default (5) if the caller passed something
        # that cannot be converted to float.
        self._timeout_ceil_threshold: float = 5
        try:
            self._timeout_ceil_threshold = float(timeout_ceil_threshold)
        except (TypeError, ValueError):
            pass

        self.logger = logger
        self.access_log = access_log
        if access_log:
            if issubclass(access_log_class, AbstractAsyncAccessLogger):
                self.access_logger: Optional[
                    AbstractAsyncAccessLogger
                ] = access_log_class()
            else:
                # Synchronous logger class: instantiate it and adapt it to
                # the async interface.
                access_logger = access_log_class(access_log, access_log_format)
                self.access_logger = AccessLoggerWrapper(
                    access_logger,
                    self._loop,
                )
        else:
            self.access_logger = None

        self._close = False
        self._force_close = False

    def __repr__(self) -> str:
        return "<{} {}>".format(
            self.__class__.__name__,
            "connected" if self.transport is not None else "disconnected",
        )

    @property
    def keepalive_timeout(self) -> float:
        return self._keepalive_timeout

    async def shutdown(self, timeout: Optional[float] = 15.0) -> None:
        """Do worker process exit preparations.

        We need to clean up everything and stop accepting requests.
        It is especially important for keep-alive connections.
        """
        self._force_close = True

        if self._keepalive_handle is not None:
            self._keepalive_handle.cancel()

        if self._waiter:
            self._waiter.cancel()

        # wait for handlers
        with suppress(asyncio.CancelledError, asyncio.TimeoutError):
            async with ceil_timeout(timeout):
                if self._current_request is not None:
                    self._current_request._cancel(asyncio.CancelledError())

                if self._task_handler is not None and not self._task_handler.done():
                    await self._task_handler

        # force-close non-idle handler
        if self._task_handler is not None:
            self._task_handler.cancel()

        if self.transport is not None:
            self.transport.close()
            self.transport = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        super().connection_made(transport)

        real_transport = cast(asyncio.Transport, transport)
        if self._tcp_keepalive:
            tcp_keepalive(real_transport)

        # Kick off the request-processing loop for this connection.
        self._task_handler = self._loop.create_task(self.start())
        assert self._manager is not None
        self._manager.connection_made(self, real_transport)

    def connection_lost(self, exc: Optional[BaseException]) -> None:
        # A second call (manager already detached) is a no-op.
        if self._manager is None:
            return
        self._manager.connection_lost(self, exc)

        super().connection_lost(exc)

        # Grab value before setting _manager to None.
        handler_cancellation = self._manager.handler_cancellation

        self._manager = None
        self._force_close = True
        self._request_factory = None
        self._request_handler = None
        self._request_parser = None

        if self._keepalive_handle is not None:
            self._keepalive_handle.cancel()

        if self._current_request is not None:
            if exc is None:
                exc = ConnectionResetError("Connection lost")
            self._current_request._cancel(exc)

        if self._waiter is not None:
            self._waiter.cancel()

        # Only cancel the running handler task if the server was
        # configured to do so on disconnect.
        if handler_cancellation and self._task_handler is not None:
            self._task_handler.cancel()

        self._task_handler = None

        if self._payload_parser is not None:
            self._payload_parser.feed_eof()
            self._payload_parser = None

    def set_parser(self, parser: Any) -> None:
        # Actual type is WebReader
        assert self._payload_parser is None

        self._payload_parser = parser

        # Flush any bytes that arrived before the parser was installed
        # (buffered in data_received during an upgrade).
        if self._message_tail:
            self._payload_parser.feed_data(self._message_tail)
            self._message_tail = b""

    def eof_received(self) -> None:
        pass

    def data_received(self, data: bytes) -> None:
        # Ignore incoming bytes once the connection is (being) closed.
        if self._force_close or self._close:
            return
        # parse http messages
        messages: Sequence[_MsgType]
        if self._payload_parser is None and not self._upgrade:
            assert self._request_parser is not None
            try:
                messages, upgraded, tail = self._request_parser.feed_data(data)
            except HttpProcessingError as exc:
                # Queue an error marker instead of a message so start()
                # can produce a 400 response through the normal path.
                messages = [
                    (_ErrInfo(status=400, exc=exc, message=exc.message), EMPTY_PAYLOAD)
                ]
                upgraded = False
                tail = b""

            for msg, payload in messages or ():
                self._request_count += 1
                self._messages.append((msg, payload))

            waiter = self._waiter
            if messages and waiter is not None and not waiter.done():
                # don't set result twice
                waiter.set_result(None)

            self._upgrade = upgraded
            if upgraded and tail:
                # Bytes past the upgrade point belong to the new protocol;
                # keep them until set_parser() installs a payload parser.
                self._message_tail = tail

        # no parser, just store
        elif self._payload_parser is None and self._upgrade and data:
            self._message_tail += data

        # feed payload
        elif data:
            eof, tail = self._payload_parser.feed_data(data)
            if eof:
                self.close()

    def keep_alive(self, val: bool) -> None:
        """Set keep-alive connection mode.

        :param bool val: new state.
        """
        self._keepalive = val
        if self._keepalive_handle:
            self._keepalive_handle.cancel()
            self._keepalive_handle = None

    def close(self) -> None:
        """Close connection.

        Stop accepting new pipelining messages and close
        connection when handlers done processing messages.
        """
        self._close = True
        # Wake the idle start() loop so it can exit.
        if self._waiter:
            self._waiter.cancel()

    def force_close(self) -> None:
        """Forcefully close connection."""
        self._force_close = True
        if self._waiter:
            self._waiter.cancel()
        if self.transport is not None:
            self.transport.close()
            self.transport = None

    async def log_access(
        self, request: BaseRequest, response: StreamResponse, request_start: float
    ) -> None:
        if self.access_logger is not None:
            await self.access_logger.log(request, response, request_start)

    def log_debug(self, *args: Any, **kw: Any) -> None:
        # Debug logging only when the event loop runs in debug mode.
        if self._loop.get_debug():
            self.logger.debug(*args, **kw)

    def log_exception(self, *args: Any, **kw: Any) -> None:
        self.logger.exception(*args, **kw)

    def _process_keepalive(self) -> None:
        """Timer callback: close the connection if idle past the deadline,
        otherwise re-check shortly."""
        if self._force_close or not self._keepalive:
            return

        next = self._keepalive_time + self._keepalive_timeout

        # handler in idle state
        if self._waiter:
            if self._loop.time() > next:
                self.force_close()
                return

        # not all request handlers are done,
        # reschedule itself to next second
        self._keepalive_handle = self._loop.call_later(
            self.KEEPALIVE_RESCHEDULE_DELAY,
            self._process_keepalive,
        )

    async def _handle_request(
        self,
        request: BaseRequest,
        start_time: float,
        request_handler: Callable[[BaseRequest], Awaitable[StreamResponse]],
    ) -> Tuple[StreamResponse, bool]:
        """Run one handler invocation, translate errors into responses,
        and finish the response.

        Returns (response, reset) where reset is True if the client
        disconnected before the response could be completed.
        """
        assert self._request_handler is not None
        try:
            try:
                self._current_request = request
                resp = await request_handler(request)
            finally:
                # Always clear, so shutdown()/connection_lost() don't try
                # to cancel a request that already finished.
                self._current_request = None
        except HTTPException as exc:
            # Web exceptions carry their own status/body; turn them into a
            # regular response rather than letting them propagate.
            resp = Response(
                status=exc.status, reason=exc.reason, text=exc.text, headers=exc.headers
            )
            resp._cookies = exc._cookies
            reset = await self.finish_response(request, resp, start_time)
        except asyncio.CancelledError:
            raise
        except asyncio.TimeoutError as exc:
            self.log_debug("Request handler timed out.", exc_info=exc)
            resp = self.handle_error(request, 504)
            reset = await self.finish_response(request, resp, start_time)
        except Exception as exc:
            resp = self.handle_error(request, 500, exc)
            reset = await self.finish_response(request, resp, start_time)
        else:
            reset = await self.finish_response(request, resp, start_time)

        return resp, reset

    async def start(self) -> None:
        """Process incoming request.

        It reads request line, request headers and request payload, then
        calls handle_request() method. Subclass has to override
        handle_request(). start() handles various exceptions in request
        or response handling. Connection is being closed always unless
        keep_alive(True) specified.
        """
        loop = self._loop
        handler = self._task_handler
        assert handler is not None
        manager = self._manager
        assert manager is not None
        keepalive_timeout = self._keepalive_timeout
        resp = None
        assert self._request_factory is not None
        assert self._request_handler is not None

        while not self._force_close:
            if not self._messages:
                try:
                    # wait for next request
                    self._waiter = loop.create_future()
                    await self._waiter
                except asyncio.CancelledError:
                    break
                finally:
                    self._waiter = None

            message, payload = self._messages.popleft()

            start = loop.time()

            manager.requests_count += 1
            writer = StreamWriter(self, loop)
            if isinstance(message, _ErrInfo):
                # make request_factory work
                request_handler = self._make_error_handler(message)
                message = ERROR
            else:
                request_handler = self._request_handler

            request = self._request_factory(message, payload, self, writer, handler)
            try:
                # a new task is used for copy context vars (#3406)
                task = self._loop.create_task(
                    self._handle_request(request, start, request_handler)
                )
                try:
                    resp, reset = await task
                except (asyncio.CancelledError, ConnectionError):
                    self.log_debug("Ignored premature client disconnection")
                    break

                # Drop the processed task from asyncio.Task.all_tasks() early
                del task
                if reset:
                    self.log_debug("Ignored premature client disconnection 2")
                    break

                # notify server about keep-alive
                self._keepalive = bool(resp.keep_alive)

                # check payload
                if not payload.is_eof():
                    lingering_time = self._lingering_time
                    if not self._force_close and lingering_time:
                        self.log_debug(
                            "Start lingering close timer for %s sec.", lingering_time
                        )

                        now = loop.time()
                        end_t = now + lingering_time

                        # Drain the remaining request body for a bounded
                        # time so the connection can be reused cleanly.
                        with suppress(asyncio.TimeoutError, asyncio.CancelledError):
                            while not payload.is_eof() and now < end_t:
                                async with ceil_timeout(end_t - now):
                                    # read and ignore
                                    await payload.readany()
                                now = loop.time()

                    # if payload still uncompleted
                    if not payload.is_eof() and not self._force_close:
                        self.log_debug("Uncompleted request.")
                        self.close()

                # Any later read of this payload is a programming error.
                payload.set_exception(PayloadAccessError())

            except asyncio.CancelledError:
                self.log_debug("Ignored premature client disconnection ")
                break
            except RuntimeError as exc:
                if self._loop.get_debug():
                    self.log_exception("Unhandled runtime exception", exc_info=exc)
                self.force_close()
            except Exception as exc:
                self.log_exception("Unhandled exception", exc_info=exc)
                self.force_close()
            finally:
                if self.transport is None and resp is not None:
                    self.log_debug("Ignored premature client disconnection.")
                elif not self._force_close:
                    if self._keepalive and not self._close:
                        # start keep-alive timer
                        if keepalive_timeout is not None:
                            now = self._loop.time()
                            self._keepalive_time = now
                            if self._keepalive_handle is None:
                                self._keepalive_handle = loop.call_at(
                                    now + keepalive_timeout, self._process_keepalive
                                )
                    else:
                        break

        # remove handler, close transport if no handlers left
        if not self._force_close:
            self._task_handler = None
            if self.transport is not None:
                self.transport.close()

    async def finish_response(
        self, request: BaseRequest, resp: StreamResponse, start_time: float
    ) -> bool:
        """Prepare the response and write_eof, then log access.

        This has to
        be called within the context of any exception so the access logger
        can get exception information. Returns True if the client disconnects
        prematurely.
        """
        request._finish()
        if self._request_parser is not None:
            # The request is done: leave upgrade mode and feed any buffered
            # tail bytes back to the request parser.
            self._request_parser.set_upgraded(False)
            self._upgrade = False
            if self._message_tail:
                self._request_parser.feed_data(self._message_tail)
                self._message_tail = b""
        try:
            prepare_meth = resp.prepare
        except AttributeError:
            # The handler returned something that is not a response object;
            # distinguish "returned None" from "returned wrong type".
            if resp is None:
                raise RuntimeError("Missing return " "statement on request handler")
            else:
                raise RuntimeError(
                    "Web-handler should return "
                    "a response instance, "
                    "got {!r}".format(resp)
                )
        try:
            await prepare_meth(request)
            await resp.write_eof()
        except ConnectionError:
            await self.log_access(request, resp, start_time)
            return True
        else:
            await self.log_access(request, resp, start_time)
            return False

    def handle_error(
        self,
        request: BaseRequest,
        status: int = 500,
        exc: Optional[BaseException] = None,
        message: Optional[str] = None,
    ) -> StreamResponse:
        """Handle errors.

        Returns HTTP response with specific status code. Logs additional
        information. It always closes current connection.
        """
        self.log_exception("Error handling request", exc_info=exc)

        # some data already got sent, connection is broken
        if request.writer.output_size > 0:
            raise ConnectionError(
                "Response is sent already, cannot send another response "
                "with the error message"
            )

        ct = "text/plain"
        if status == HTTPStatus.INTERNAL_SERVER_ERROR:
            title = "{0.value} {0.phrase}".format(HTTPStatus.INTERNAL_SERVER_ERROR)
            msg = HTTPStatus.INTERNAL_SERVER_ERROR.description
            tb = None
            if self._loop.get_debug():
                with suppress(Exception):
                    tb = traceback.format_exc()

            if "text/html" in request.headers.get("Accept", ""):
                if tb:
                    # Traceback is escaped before embedding in HTML.
                    tb = html_escape(tb)
                    msg = f"<h2>Traceback:</h2>\n<pre>{tb}</pre>"
                message = (
                    "<html><head>"
                    "<title>{title}</title>"
                    "</head><body>\n<h1>{title}</h1>"
                    "\n{msg}\n</body></html>\n"
                ).format(title=title, msg=msg)
                ct = "text/html"
            else:
                if tb:
                    msg = tb
                message = title + "\n\n" + msg

        resp = Response(status=status, text=message, content_type=ct)
        resp.force_close()

        return resp

    def _make_error_handler(
        self, err_info: _ErrInfo
    ) -> Callable[[BaseRequest], Awaitable[StreamResponse]]:
        # Build a stand-in handler that responds with the stored parsing
        # error, so failed requests flow through the normal handler path.
        async def handler(request: BaseRequest) -> StreamResponse:
            return self.handle_error(
                request, err_info.status, err_info.exc, err_info.message
            )

        return handler