Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/helpers.py: 39%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Various helper functions"""
3import asyncio
4import base64
5import contextlib
6import dataclasses
7import datetime
8import enum
9import functools
10import inspect
11import netrc
12import os
13import platform
14import re
15import sys
16import time
17import warnings
18import weakref
19from collections.abc import Callable, Iterable, Iterator, Mapping
20from contextlib import suppress
21from email.message import EmailMessage
22from email.parser import HeaderParser
23from email.policy import HTTP
24from email.utils import parsedate
25from http.cookies import SimpleCookie
26from math import ceil
27from pathlib import Path
28from types import MappingProxyType, TracebackType
29from typing import (
30 TYPE_CHECKING,
31 Any,
32 ContextManager,
33 Generic,
34 Protocol,
35 TypeVar,
36 Union,
37 final,
38 get_args,
39 overload,
40)
41from urllib.parse import quote
42from urllib.request import getproxies, proxy_bypass
44from multidict import CIMultiDict, MultiDict, MultiDictProxy
45from propcache.api import under_cached_property as reify
46from yarl import URL
48from . import hdrs
49from .log import client_logger
50from .typedefs import PathLike # noqa
52if sys.version_info >= (3, 11):
53 import asyncio as async_timeout
54else:
55 import async_timeout
57if TYPE_CHECKING:
58 from dataclasses import dataclass as frozen_dataclass_decorator
59else:
60 frozen_dataclass_decorator = functools.partial(
61 dataclasses.dataclass, frozen=True, slots=True
62 )
__all__ = ("ChainMapProxy", "ETag", "frozen_dataclass_decorator", "reify")

# This is the default size/limit for several operations.
# Matches the max size we receive from sockets:
# https://github.com/python/cpython/blob/1857a40807daeae3a1bf5efb682de9c9ae6df845/Lib/asyncio/selector_events.py#L766
DEFAULT_CHUNK_SIZE = 2**18  # 256 KiB
# Length above which CookieMixin.set_cookie warns (in DEBUG mode only) that a
# cookie may be ignored by the client.
COOKIE_MAX_LENGTH = 4096
# Matches one backslash quoted-pair; ``sub(r"\1", ...)`` removes the escape.
_QUOTED_PAIR_SUB = re.compile(r"\\(.)")
# A double-quoted string whose body may contain backslash escapes.
_QUOTED_STRING = r'"(?:[^"\\]|\\.)*"'
# The inside of a parenthesised comment (parentheses not included); escapes are
# allowed, unescaped nested parentheses are not.
_ESCAPED_COMMENT = r"(?:[^()\\]|\\.)*"
# Matches one element in a comma-separated header list.
# Group 1: content of a top-level quoted-string (quotes stripped).
# Group 2: an unquoted element (may contain parameter quoted-strings / comments).
# A quoted-string only counts as a parameter value when it directly follows
# "name=" (non-whitespace before "="), and a comment only when it follows
# whitespace; inside those, commas do not end the element.
# Used by HeadersDictProxy.getall.
_LIST_ELEMENT_RE = re.compile(
    rf"""
    [ \t]*
    (?:
        "( (?:[^"\\]|\\.)* )"  # group 1: top-level quoted-string
      | (  # group 2: unquoted element
            (?:
                (?<=[^\s]=) {_QUOTED_STRING}  # parameter quoted value
              | (?<=\s) \( {_ESCAPED_COMMENT} \)  # comment
              | [^,]  # any non-comma character
            )+?
        )
    )
    [ \t]* (?:,|\Z)
    """,
    re.VERBOSE,
)
# Finds parameter quoted-strings and comments inside an unquoted element for unescaping.
# Uses the same lookbehind rules as the alternatives in _LIST_ELEMENT_RE.
_PROTECTED_RE = re.compile(
    rf"""
    (?<=[^\s]=) {_QUOTED_STRING}  # parameter quoted-string
  | (?<=\s) \( {_ESCAPED_COMMENT} \)  # comment
    """,
    re.VERBOSE,
)
_T = TypeVar("_T")
_S = TypeVar("_S")

# Single-member enum used as a "not set" marker; ``sentinel`` is that member.
_SENTINEL = enum.Enum("_SENTINEL", "sentinel")
sentinel = _SENTINEL.sentinel

# Any non-empty AIOHTTP_NO_EXTENSIONS value turns the flag on.
NO_EXTENSIONS = bool(os.environ.get("AIOHTTP_NO_EXTENSIONS"))

# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
EMPTY_BODY_STATUS_CODES = frozenset((*range(100, 200), 204, 304))
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.2
EMPTY_BODY_METHODS = hdrs.METH_HEAD_ALL

# Debug mode: Python dev mode, or PYTHONASYNCIODEBUG set to a non-empty value
# (the environment variable is honoured only when -E is not in effect).
DEBUG = sys.flags.dev_mode or (
    not sys.flags.ignore_environment
    and bool(os.environ.get("PYTHONASYNCIODEBUG"))
)
# Character classes for the RFC 2616 (section 2.2) grammar:
#   token = 1*<any CHAR except CTLs or separators>
CHAR = {chr(i) for i in range(0, 128)}
CTL = {chr(i) for i in range(0, 32)} | {
    chr(127),
}
SEPARATORS = {
    "(",
    ")",
    "<",
    ">",
    "@",
    ",",
    ";",
    ":",
    "\\",
    '"',
    "/",
    "[",
    "]",
    "?",
    "=",
    "{",
    "}",
    " ",
    chr(9),
}
# Set difference, not symmetric difference.  HT (chr(9)) is both a CTL and a
# separator, so ``CHAR ^ CTL ^ SEPARATORS`` removes it together with the CTLs
# and then adds it back when XOR-ing the separators.  That wrongly made tab a
# token character.
TOKEN = CHAR - CTL - SEPARATORS

# Matches "application/json" and structured-syntax "+json" media types such as
# "application/vnd.api+json" (case-insensitive, no parameters).
json_re = re.compile(r"^(?:application/|[\w.-]+/[\w.+-]+?\+)json$", re.IGNORECASE)
def encode_basic_auth(login: str, password: str = "", encoding: str = "utf-8") -> str:
    """Encode HTTP Basic Authentication credentials as an Authorization header value.

    ``login`` and ``password`` are joined with ``":"``, encoded with
    ``encoding`` and base64-encoded.

    Returns a string of the form ``"Basic <base64>"`` suitable for use as the
    value of the ``Authorization`` (or ``Proxy-Authorization``) header.

    :raises ValueError: if ``login`` contains ``":"`` (RFC 7617, section 2).
    """
    if ":" in login:
        raise ValueError('A ":" is not allowed in login (RFC 7617#section-2)')
    creds = f"{login}:{password}".encode(encoding)
    # base64 output is always ASCII.  Decoding it with the credential
    # encoding would garble the header for encodings such as utf-16.
    return "Basic " + base64.b64encode(creds).decode("ascii")
def strip_auth_from_url(url: URL) -> tuple[URL, str | None]:
    """Split credentials out of ``url``.

    Returns ``(url_without_credentials, authorization_header_value)``.  The
    second item is ``None`` when the URL carries no user or password, and in
    that case the URL is returned unchanged.
    """
    # raw_user/raw_password are checked first because yarl has most likely
    # already parsed them out of the netloc and cached them.
    has_credentials = url.raw_user is not None or url.raw_password is not None
    if not has_credentials:
        return url, None
    stripped_url = url.with_user(None)
    header_value = encode_basic_auth(url.user or "", url.password or "")
    return stripped_url, header_value
def netrc_from_env() -> netrc.netrc | None:
    """Load the user's netrc file, if there is one.

    The path is taken from the ``NETRC`` environment variable when it is set.
    Otherwise ``.netrc`` (``_netrc`` on Windows) in the home directory is used.

    Returns the parsed :class:`netrc.netrc`, or ``None`` when the home
    directory cannot be resolved, the file cannot be read, or it fails to
    parse.  Parse errors are always logged as warnings.  Read errors are
    logged as warnings only when ``NETRC`` is non-empty or the file appears
    to exist.
    """
    explicit_path = os.environ.get("NETRC")

    if explicit_path is None:
        try:
            home_dir = Path.home()
        except RuntimeError as err:
            # pathlib raises RuntimeError when it cannot resolve the home dir.
            client_logger.debug(
                "Could not resolve home directory when "
                "trying to look for .netrc file: %s",
                err,
            )
            return None
        default_name = "_netrc" if platform.system() == "Windows" else ".netrc"
        netrc_path = home_dir / default_name
    else:
        netrc_path = Path(explicit_path)

    try:
        return netrc.netrc(str(netrc_path))
    except netrc.NetrcParseError as err:
        client_logger.warning("Could not parse .netrc file: %s", err)
    except OSError as err:
        # The file could not be read (missing, permissions, ...).  Only warn
        # if the user asked for it explicitly or the file seems to be there.
        try:
            file_present = netrc_path.is_file()
        except OSError:
            file_present = False
        if explicit_path or file_present:
            client_logger.warning("Could not read .netrc file: %s", err)
    return None
@frozen_dataclass_decorator
class ProxyInfo:
    """A proxy URL together with its optional proxy authorization value."""

    # Proxy URL.  Instances built by proxies_from_env have any embedded
    # credentials stripped from it.
    proxy: URL
    # "Basic ..." header value for the proxy, or None when no credentials
    # are known.
    proxy_auth: str | None
def _auth_header_from_netrc(netrc_obj: netrc.netrc | None, host: str) -> str:
    """Build a ``Proxy-Authorization`` header value for ``host`` from netrc data.

    :raises LookupError: if ``netrc_obj`` is :py:data:`None` or has no entry
        for ``host``.
    """
    if netrc_obj is None:
        raise LookupError("No .netrc file found")

    entry = netrc_obj.authenticators(host)
    if entry is None:
        raise LookupError(f"No entry for {host!s} found in the `.netrc` file.")
    login, account, password = entry

    # TODO(PY311): username = login or account
    # Up to python 3.10, account could be None if not specified,
    # and login will be empty string if not specified. From 3.11,
    # login and account will be empty string if not specified.
    if login or account is None:
        username = login
    else:
        username = account

    # TODO(PY311): Remove this, as password will be empty string
    # if not specified
    if password is None:
        password = ""  # type: ignore[unreachable]

    return encode_basic_auth(username, password)
def proxies_from_env() -> dict[str, ProxyInfo]:
    """Collect proxy settings for HTTP and WebSocket schemes from the environment.

    Proxy URLs come from :func:`urllib.request.getproxies` for the ``http``,
    ``https``, ``ws`` and ``wss`` keys.  Credentials embedded in a proxy URL
    are turned into its authorization value.  Failing that, the netrc file
    (if any) is consulted for the proxy host.  Proxies whose own scheme is
    ``https`` or ``wss`` are logged and skipped.

    Returns a mapping from the environment key (e.g. ``"http"``) to
    :class:`ProxyInfo`.
    """
    wanted = ("http", "https", "ws", "wss")
    env_proxies = {
        scheme: URL(raw_url)
        for scheme, raw_url in getproxies().items()
        if scheme in wanted
    }
    netrc_obj = netrc_from_env()
    stripped = {
        scheme: strip_auth_from_url(proxy_url)
        for scheme, proxy_url in env_proxies.items()
    }

    result = {}
    for scheme, (proxy, auth) in stripped.items():
        if proxy.scheme in ("https", "wss"):
            client_logger.warning(
                "%s proxies %s are not supported, ignoring", proxy.scheme.upper(), proxy
            )
            continue
        if auth is None and netrc_obj and proxy.host is not None:
            try:
                auth = _auth_header_from_netrc(netrc_obj, proxy.host)
            except LookupError:
                auth = None
        result[scheme] = ProxyInfo(proxy, auth)
    return result
def get_env_proxy_for_url(url: URL) -> tuple[URL, str | None]:
    """Return ``(proxy_url, proxy_auth)`` configured in the environment for ``url``.

    :raises LookupError: if the environment's proxy bypass rules exclude
        ``url.host``, or no proxy is configured for ``url.scheme``.
    """
    host = url.host
    if host is not None and proxy_bypass(host):
        raise LookupError(f"Proxying is disallowed for `{host!r}`")

    env_proxies = proxies_from_env()
    try:
        info = env_proxies[url.scheme]
    except KeyError:
        raise LookupError(f"No proxies found for `{url!s}` in the env")
    return info.proxy, info.proxy_auth
@frozen_dataclass_decorator
class MimeType:
    """Components of a MIME type, as produced by ``parse_mimetype``."""

    # Major type, lower-cased (e.g. "text"); "" for empty input.
    type: str
    # Subtype without any "+suffix" part (e.g. "html"), lower-cased.
    subtype: str
    # Text after "+" in the subtype (e.g. "json"), or "".
    suffix: str
    # Parameters in order of appearance.  Keys are lower-cased; values have
    # surrounding spaces and double quotes stripped.
    parameters: "MultiDictProxy[str]"
@functools.lru_cache(maxsize=56)
def parse_mimetype(mimetype: str) -> MimeType:
    """Split a MIME type string into its components.

    The part before the first ``;`` is lower-cased and split into type,
    subtype and ``+suffix``; a bare ``*`` is treated as ``*/*``.  Each
    following ``;``-separated item is parsed as ``key=value`` into the
    parameters multidict.

    For example, ``parse_mimetype('text/html; charset=utf-8')`` yields
    ``type='text'``, ``subtype='html'``, ``suffix=''`` and parameters
    ``{'charset': 'utf-8'}``.
    """
    if not mimetype:
        return MimeType(
            type="", subtype="", suffix="", parameters=MultiDictProxy(MultiDict())
        )

    head, *raw_params = mimetype.split(";")
    params: MultiDict[str] = MultiDict()
    for raw_param in raw_params:
        if not raw_param:
            continue
        name, _, value = raw_param.partition("=")
        params.add(name.lower().strip(), value.strip(' "'))

    full_type = head.strip().lower()
    if full_type == "*":
        full_type = "*/*"

    main_type, _, rest = full_type.partition("/")
    sub_type, _, suffix = rest.partition("+")
    return MimeType(
        type=main_type,
        subtype=sub_type,
        suffix=suffix,
        parameters=MultiDictProxy(params),
    )
class EnsureOctetStream(EmailMessage):
    """EmailMessage whose fallback content type is ``application/octet-stream``."""

    def __init__(self) -> None:
        super().__init__()
        # https://www.rfc-editor.org/rfc/rfc9110#section-8.3-5
        self.set_default_type("application/octet-stream")

    def get_content_type(self) -> str:
        """Return the lower-cased media type, or the default type if malformed.

        A simplified version of ``Message.get_content_type``.  A value that is
        not of the form ``type/subtype`` yields ``application/octet-stream``
        rather than ``text/plain``.  This class is only used with a
        Content-Type header present, so fewer checks are needed than in the
        base implementation.
        """
        # Like the stdlib's _splitparam: the media type is everything before
        # the first ';'.
        media_type = self.get("content-type", "").lower().partition(";")[0].strip()
        if media_type.count("/") == 1:
            return media_type
        return self.get_default_type()
@functools.lru_cache(maxsize=56)
def parse_content_type(raw: str) -> tuple[str, MappingProxyType[str, str]]:
    """Parse a Content-Type header value with the stdlib email machinery.

    Returns ``(content_type, params)``: the lower-cased media type, which
    falls back to ``application/octet-stream`` when malformed, and a
    read-only mapping of its parameters.
    """
    parser = HeaderParser(EnsureOctetStream, policy=HTTP)
    header_msg = parser.parsestr(f"Content-Type: {raw}")
    media_type = header_msg.get_content_type()
    # get_params() repeats the media type as its first pair; skip it.
    param_pairs = header_msg.get_params(())[1:]
    return media_type, MappingProxyType(dict(param_pairs))
385def guess_filename(obj: Any, default: str | None = None) -> str | None:
386 name = getattr(obj, "name", None)
387 if name and isinstance(name, str) and name[0] != "<" and name[-1] != ">":
388 return Path(name).name
389 return default
# Anything that is not RFC 5322 qtext (%d33 / %d35-91 / %d93-126).
not_qtext_re = re.compile(r"[^\041\043-\133\135-\176]")
# Characters allowed inside a quoted-string here: printable ASCII plus HT.
QCONTENT = set(map(chr, range(0x20, 0x7F))) | {"\t"}


def quoted_string(content: str) -> str:
    """Escape ``content`` for use inside an RFC 5322 quoted-string.

    This is the 7-bit email format, not the 8-bit HTTP one.  Every character
    that is not qtext gets a backslash in front of it.  The surrounding
    double quotes are not added.

    :raises ValueError: if ``content`` has characters outside printable
        US-ASCII and tab.
    """
    if not set(content) < QCONTENT:
        raise ValueError(f"bad content for quoted-string {content!r}")
    return not_qtext_re.sub(r"\\\g<0>", content)
def content_disposition_header(
    disptype: str,
    quote_fields: bool = True,
    _charset: str = "utf-8",
    params: dict[str, str] | None = None,
) -> str:
    """Build a MIME ``Content-Disposition`` header value.

    This is the MIME payload Content-Disposition header from RFC 2183 and
    RFC 7578 section 4.2, not the HTTP Content-Disposition from RFC 6266.

    disptype is a disposition type such as inline, attachment or form-data.
    It must be a valid token.

    When quote_fields is True, values are quoted for 7-bit MIME headers
    (RFC 7578).  ``filename`` is percent-encoded with ``_charset`` inside
    double quotes.  Other values become quoted-strings, or fall back to the
    RFC 2231 ``key*=charset''...`` form when they are not 7-bit clean.  Pass
    quote_fields=False if the recipient accepts 8-bit file names and field
    values; then only backslashes and double quotes are escaped.

    params is a dict of disposition parameters; each key must be a token.

    :raises ValueError: on an invalid disposition type or parameter name.
    """
    if not disptype or not (TOKEN > set(disptype)):
        raise ValueError(f"bad content disposition type {disptype!r}")
    if not params:
        return disptype

    rendered: list[str] = []
    for key, val in params.items():
        if not key or not (TOKEN > set(key)):
            raise ValueError(f"bad content disposition parameter {key!r}={val!r}")
        if not quote_fields:
            escaped = val.replace("\\", "\\\\").replace('"', '\\"')
            rendered.append(f'{key}="{escaped}"')
        elif key.lower() == "filename":
            encoded = quote(val, "", encoding=_charset)
            rendered.append(f'{key}="{encoded}"')
        else:
            try:
                qval = quoted_string(val)
            except ValueError:
                extended = _charset + "''" + quote(val, "", encoding=_charset)
                rendered.append(f"{key}*={extended}")
            else:
                rendered.append(f'{key}="{qval}"')
    return disptype + "; " + "; ".join(rendered)
def is_expected_content_type(
    response_content_type: str, expected_content_type: str
) -> bool:
    """Tell whether a received content type can be handled as the expected one.

    For ``application/json`` any JSON media type (including ``+json``
    suffixes) is accepted.  Otherwise the expected type must occur as a
    substring of the received one.  Both arguments are given without
    parameters.
    """
    if expected_content_type != "application/json":
        return expected_content_type in response_content_type
    return bool(json_re.match(response_content_type))
475def is_ip_address(host: str | None) -> bool:
476 """Check if host looks like an IP Address.
478 This check is only meant as a heuristic to ensure that
479 a host is not a domain name.
480 """
481 if not host:
482 return False
483 # For a host to be an ipv4 address, it must be all numeric.
484 # The host must contain a colon to be an IPv6 address.
485 return ":" in host or host.replace(".", "").isdigit()
488_cached_current_datetime: int | None = None
489_cached_formatted_datetime = ""
492def rfc822_formatted_time() -> str:
493 global _cached_current_datetime
494 global _cached_formatted_datetime
496 now = int(time.time())
497 if now != _cached_current_datetime:
498 # Weekday and month names for HTTP date/time formatting;
499 # always English!
500 # Tuples are constants stored in codeobject!
501 _weekdayname = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")
502 _monthname = (
503 "", # Dummy so we can use 1-based month numbers
504 "Jan",
505 "Feb",
506 "Mar",
507 "Apr",
508 "May",
509 "Jun",
510 "Jul",
511 "Aug",
512 "Sep",
513 "Oct",
514 "Nov",
515 "Dec",
516 )
518 year, month, day, hh, mm, ss, wd, *tail = time.gmtime(now)
519 _cached_formatted_datetime = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
520 _weekdayname[wd],
521 day,
522 _monthname[month],
523 year,
524 hh,
525 mm,
526 ss,
527 )
528 _cached_current_datetime = now
529 return _cached_formatted_datetime
def _weakref_handle(info: "tuple[weakref.ref[object], str]") -> None:
    """Call method ``name`` on the weakly referenced object, if it still exists.

    Does nothing once the referent has been collected.  Any ``Exception``
    raised by the method is swallowed.
    """
    target_ref, method_name = info
    target = target_ref()
    if target is None:
        return
    with suppress(Exception):
        getattr(target, method_name)()
540def weakref_handle(
541 ob: object,
542 name: str,
543 timeout: float | None,
544 loop: asyncio.AbstractEventLoop,
545 timeout_ceil_threshold: float = 5,
546) -> asyncio.TimerHandle | None:
547 if timeout is not None and timeout > 0:
548 when = loop.time() + timeout
549 if timeout >= timeout_ceil_threshold:
550 when = ceil(when)
552 return loop.call_at(when, _weakref_handle, (weakref.ref(ob), name))
553 return None
556def call_later(
557 cb: Callable[[], Any],
558 timeout: float | None,
559 loop: asyncio.AbstractEventLoop,
560 timeout_ceil_threshold: float = 5,
561) -> asyncio.TimerHandle | None:
562 if timeout is None or timeout <= 0:
563 return None
564 now = loop.time()
565 when = calculate_timeout_when(now, timeout, timeout_ceil_threshold)
566 return loop.call_at(when, cb)
def calculate_timeout_when(
    loop_time: float,
    timeout: float,
    timeout_ceiling_threshold: float,
) -> float:
    """Return the loop-time deadline ``loop_time + timeout``.

    The deadline is rounded up to a whole second (so timers can be batched)
    only when ``timeout`` is strictly greater than ``timeout_ceiling_threshold``.
    """
    deadline = loop_time + timeout
    return ceil(deadline) if timeout > timeout_ceiling_threshold else deadline
581class TimeoutHandle:
582 """Timeout handle"""
584 __slots__ = ("_timeout", "_loop", "_ceil_threshold", "_callbacks")
586 def __init__(
587 self,
588 loop: asyncio.AbstractEventLoop,
589 timeout: float | None,
590 ceil_threshold: float = 5,
591 ) -> None:
592 self._timeout = timeout
593 self._loop = loop
594 self._ceil_threshold = ceil_threshold
595 self._callbacks: list[
596 tuple[Callable[..., None], tuple[Any, ...], dict[str, Any]]
597 ] = []
599 def register(
600 self, callback: Callable[..., None], *args: Any, **kwargs: Any
601 ) -> None:
602 self._callbacks.append((callback, args, kwargs))
604 def close(self) -> None:
605 self._callbacks.clear()
607 def start(self) -> asyncio.TimerHandle | None:
608 timeout = self._timeout
609 if timeout is not None and timeout > 0:
610 when = self._loop.time() + timeout
611 if timeout >= self._ceil_threshold:
612 when = ceil(when)
613 return self._loop.call_at(when, self.__call__)
614 else:
615 return None
617 def timer(self) -> "BaseTimerContext":
618 if self._timeout is not None and self._timeout > 0:
619 timer = TimerContext(self._loop)
620 self.register(timer.timeout)
621 return timer
622 else:
623 return TimerNoop()
625 def __call__(self) -> None:
626 for cb, args, kwargs in self._callbacks:
627 with suppress(Exception):
628 cb(*args, **kwargs)
630 self._callbacks.clear()
class BaseTimerContext(ContextManager["BaseTimerContext"]):
    """Common interface of the timer context managers."""

    __slots__ = ()

    def assert_timeout(self) -> None:
        """Raise TimeoutError if timeout has been exceeded.

        The base implementation never raises.
        """
        return None
class TimerNoop(BaseTimerContext):
    """Timer context used when there is no timeout: it never fires."""

    __slots__ = ()

    def __enter__(self) -> BaseTimerContext:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # Never suppresses exceptions.
        return None
class TimerContext(BaseTimerContext):
    """Low resolution timeout context manager.

    Each task that enters the context is pushed onto a stack.  When
    ``timeout()`` is called, every task still inside is cancelled, and the
    resulting ``CancelledError`` is normally turned into
    ``asyncio.TimeoutError`` on exit.  Once the timer has fired, entering the
    context again raises ``asyncio.TimeoutError`` straight away.
    """

    __slots__ = ("_loop", "_tasks", "_cancelled", "_cancelling")

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        # Tasks currently inside the context; pushed in __enter__, popped in __exit__.
        self._tasks: list[asyncio.Task[Any]] = []
        # Set once timeout() has fired.
        self._cancelled = False
        # task.cancelling() as seen at __enter__ (only recorded on 3.11+).
        self._cancelling = 0

    def assert_timeout(self) -> None:
        """Raise TimeoutError if timer has already been cancelled."""
        if self._cancelled:
            raise asyncio.TimeoutError from None

    def __enter__(self) -> BaseTimerContext:
        task = asyncio.current_task(loop=self._loop)
        if task is None:
            raise RuntimeError("Timeout context manager should be used inside a task")

        if sys.version_info >= (3, 11):
            # Remember if the task was already cancelling
            # so when we __exit__ we can decide if we should
            # raise asyncio.TimeoutError or let the cancellation propagate
            self._cancelling = task.cancelling()

        # The timer already fired: fail immediately instead of entering.
        if self._cancelled:
            raise asyncio.TimeoutError from None

        self._tasks.append(task)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        enter_task: asyncio.Task[Any] | None = None
        if self._tasks:
            enter_task = self._tasks.pop()

        # Only a cancellation caused by our own timeout() is translated.
        if exc_type is asyncio.CancelledError and self._cancelled:
            assert enter_task is not None
            # The timeout was hit, and the task was cancelled
            # so we need to uncancel the last task that entered the context manager
            # since the cancellation should not leak out of the context manager
            if sys.version_info >= (3, 11):
                # If the task was already cancelling don't raise
                # asyncio.TimeoutError and instead return None
                # to allow the cancellation to propagate
                if enter_task.uncancel() > self._cancelling:
                    return None
            raise asyncio.TimeoutError from exc_val
        return None

    def timeout(self) -> None:
        """Fire the timer: cancel tasks inside the context (first call only)."""
        if not self._cancelled:
            # set() so a task that entered more than once is cancelled only once.
            for task in set(self._tasks):
                task.cancel()

            self._cancelled = True
def ceil_timeout(
    delay: float | None, ceil_threshold: float = 5
) -> async_timeout.Timeout:
    """Return a timeout context manager expiring ``delay`` seconds from now.

    ``None`` or a non-positive ``delay`` disables the timeout.  The deadline
    is rounded up to a whole loop-time second when ``delay`` is strictly
    greater than ``ceil_threshold``.  Must be called with a running loop.
    """
    if delay is None or delay <= 0:
        return async_timeout.timeout(None)

    deadline = asyncio.get_running_loop().time() + delay
    if delay > ceil_threshold:
        return async_timeout.timeout_at(ceil(deadline))
    return async_timeout.timeout_at(deadline)
class HeadersDictProxy(Mapping[str, str]):
    """Single-value, read-only mapping view over a ``CIMultiDict`` of headers.

    Looking up a key joins every value stored for it with ``", "``; iterating
    yields each key once, in first-seen order.  :meth:`getall` splits a
    joined value back into its comma-separated list elements.
    """

    def __init__(self, md: CIMultiDict[str]):
        # Kept by reference, not copied: later changes to ``md`` are visible.
        self._md = md

    def getall(self, key: str) -> tuple[str, ...]:
        """Return the comma-separated elements of header ``key``.

        A top-level quoted-string element is returned without its quotes and
        with backslash escapes removed.  An unquoted element is stripped of
        surrounding whitespace.  Escapes inside its parameter quoted-strings
        (``name="..."``) and comments (``(...)``) are removed, but those
        quotes and parentheses are kept.  Empty elements are dropped.  A
        missing header gives ``()``, presumably because
        ``CIMultiDict.getall`` raises ``KeyError``, which ``Mapping.get``
        turns into the ``""`` default.
        """
        val = self.get(key, "")
        unescape = _QUOTED_PAIR_SUB.sub
        values = []
        for m in _LIST_ELEMENT_RE.finditer(val):
            qs = m.group(1)
            if qs is not None:
                values.append(unescape(r"\1", qs))
            else:
                raw = m.group(2).strip()
                if raw:
                    # Unescape only inside protected spans (parameter
                    # quoted-strings and comments).
                    values.append(
                        _PROTECTED_RE.sub(lambda p: unescape(r"\1", p.group()), raw)
                    )
        return tuple(values)

    def __eq__(self, other: object) -> bool:
        # Equality is delegated to the underlying multidict.
        return self._md.__eq__(other)

    def __getitem__(self, key: str) -> str:
        # All values for ``key`` joined with ", ".
        return ", ".join(self._md.getall(key))

    def __iter__(self) -> Iterator[str]:
        # We need to deduplicate keys from MultiDict
        # But, we also need to retain ordering
        seen: set[str] = set()
        for k in self._md.__iter__():
            if k in seen:
                continue
            seen.add(k)
            yield k

    def __len__(self) -> int:
        # Number of distinct keys, matching the deduplication in __iter__.
        return len(set(self._md.keys()))

    def __repr__(self) -> str:
        body = ", ".join(f"'{k}': {v!r}" for k, v in self.items())
        return f"<{self.__class__.__name__}({body})>"
class HeadersMixin:
    """Mixin adding Content-Type/Content-Length accessors over ``self._headers``.

    The parsed Content-Type is cached together with the raw value it came
    from, and is re-parsed whenever the raw header changes.
    """

    _headers: Mapping[str, str]
    _content_type: str | None = None
    _content_dict: dict[str, str] | None = None
    _stored_content_type: str | None | _SENTINEL = sentinel

    def _parse_content_type(self, raw: str | None) -> None:
        self._stored_content_type = raw
        if raw is not None:
            parsed_type, params_proxy = parse_content_type(raw)
            self._content_type = parsed_type
            # A private mutable copy, so it can be updated later.
            self._content_dict = params_proxy.copy()
            return
        # default value according to RFC 2616
        self._content_type = "application/octet-stream"
        self._content_dict = {}

    @property
    def content_type(self) -> str:
        """The value of content part for Content-Type HTTP header."""
        current = self._headers.get(hdrs.CONTENT_TYPE)
        if self._stored_content_type != current:
            self._parse_content_type(current)
        assert self._content_type is not None
        return self._content_type

    @property
    def charset(self) -> str | None:
        """The value of charset part for Content-Type HTTP header."""
        current = self._headers.get(hdrs.CONTENT_TYPE)
        if self._stored_content_type != current:
            self._parse_content_type(current)
        assert self._content_dict is not None
        return self._content_dict.get("charset")

    @property
    def content_length(self) -> int | None:
        """The value of Content-Length HTTP header.

        ``None`` when the header is absent.  ``int()`` raises ``ValueError``
        if the header is not an integer literal.
        """
        raw_length = self._headers.get(hdrs.CONTENT_LENGTH)
        if raw_length is None:
            return None
        return int(raw_length)
def set_result(fut: "asyncio.Future[_T]", result: _T) -> None:
    """Set ``result`` on ``fut``; does nothing if the future is already done."""
    if fut.done():
        return
    fut.set_result(result)
830_EXC_SENTINEL = BaseException()
833class ErrorableProtocol(Protocol):
834 def set_exception(
835 self,
836 exc: type[BaseException] | BaseException,
837 exc_cause: BaseException = ...,
838 ) -> None: ...
841def set_exception(
842 fut: Union["asyncio.Future[_T]", ErrorableProtocol],
843 exc: type[BaseException] | BaseException,
844 exc_cause: BaseException = _EXC_SENTINEL,
845) -> None:
846 """Set future exception.
848 If the future is marked as complete, this function is a no-op.
850 :param exc_cause: An exception that is a direct cause of ``exc``.
851 Only set if provided.
852 """
853 if asyncio.isfuture(fut) and fut.done():
854 return
856 exc_is_sentinel = exc_cause is _EXC_SENTINEL
857 exc_causes_itself = exc is exc_cause
858 if not exc_is_sentinel and not exc_causes_itself:
859 exc.__cause__ = exc_cause
861 fut.set_exception(exc)
@functools.total_ordering
class BaseKey(Generic[_T]):
    """Base class for typed keys of the context storages.

    Every storage gets its own subclass, which adds some type safety.
    Keys are named ``"<creating module>.<name>"`` and order by that name;
    keys sort before objects of any other type.
    """

    __slots__ = ("_name", "_t", "__orig_class__")

    # This may be set by Python when instantiating with a generic type. We need to
    # support this, in order to support types that are not concrete classes,
    # like Iterable, which can't be passed as the second parameter to __init__.
    __orig_class__: type[object]

    # TODO(PY314): Change Type to TypeForm (this should resolve unreachable below).
    def __init__(self, name: str, t: type[_T] | None = None):
        # Prefix with the defining module's name to help deduplicate key
        # names: walk out to the nearest module-level frame.
        frame = inspect.currentframe()
        while frame and frame.f_code.co_name != "<module>":
            frame = frame.f_back
        if not frame:
            raise RuntimeError("Failed to get module name.")
        module: str = frame.f_globals["__name__"]

        self._name = module + "." + name
        self._t = t

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, BaseKey):
            return True  # Order BaseKey above other types.
        return self._name < other._name

    def __repr__(self) -> str:
        key_type = self._t
        if key_type is None:
            with suppress(AttributeError):
                # Fall back to the subscripted type argument, if any.
                key_type = get_args(self.__orig_class__)[0]

        if key_type is None:
            type_label = "<<Unknown>>"
        elif not isinstance(key_type, type):
            type_label = repr(key_type)  # type: ignore[unreachable]
        elif key_type.__module__ == "builtins":
            type_label = key_type.__qualname__
        else:
            type_label = f"{key_type.__module__}.{key_type.__qualname__}"
        return f"<{self.__class__.__name__}({self._name}, type={type_label})>"
class AppKey(BaseKey[_T]):
    """Keys for static typing support in Application.

    The type parameter lets type checkers infer the value type when the key
    is used for lookups (see the ``AppKey`` overloads of
    ``ChainMapProxy.__getitem__``/``get``).
    """
class RequestKey(BaseKey[_T]):
    """Keys for static typing support in Request.

    A separate subclass of ``BaseKey`` so request keys are distinct types
    from application and response keys.
    """
class ResponseKey(BaseKey[_T]):
    """Keys for static typing support in Response.

    A separate subclass of ``BaseKey`` so response keys are distinct types
    from application and request keys.
    """
@final
class ChainMapProxy(Mapping[str | AppKey[Any], Any]):
    """Read-only view that searches a sequence of mappings in order.

    The first mapping that contains a key wins, so earlier mappings shadow
    later ones.  Subclassing is forbidden.
    """

    __slots__ = ("_maps",)

    def __init__(self, maps: Iterable[Mapping[str | AppKey[Any], Any]]) -> None:
        self._maps = tuple(maps)

    def __init_subclass__(cls) -> None:
        raise TypeError(
            f"Inheritance class {cls.__name__} from ChainMapProxy is forbidden"
        )

    @overload  # type: ignore[override]
    def __getitem__(self, key: AppKey[_T]) -> _T: ...

    @overload
    def __getitem__(self, key: str) -> Any: ...

    def __getitem__(self, key: str | AppKey[_T]) -> Any:
        for mapping in self._maps:
            try:
                return mapping[key]
            except KeyError:
                continue
        raise KeyError(key)

    @overload  # type: ignore[override]
    def get(self, key: AppKey[_T], default: _S) -> _T | _S: ...

    @overload
    def get(self, key: AppKey[_T], default: None = ...) -> _T | None: ...

    @overload
    def get(self, key: str, default: Any = ...) -> Any: ...

    def get(self, key: str | AppKey[_T], default: Any = None) -> Any:
        try:
            return self[key]
        except KeyError:
            return default

    def __len__(self) -> int:
        # set().union() reuses hash values already stored in the mappings
        # when possible.
        distinct_keys = set().union(*self._maps)
        return len(distinct_keys)

    def __iter__(self) -> Iterator[str | AppKey[Any]]:
        merged: dict[str | AppKey[Any], Any] = {}
        # Update from the last mapping to the first (reusing stored hash
        # values where possible).
        for mapping in self._maps[::-1]:
            merged.update(mapping)
        return iter(merged)

    def __contains__(self, key: object) -> bool:
        return any(key in mapping for mapping in self._maps)

    def __bool__(self) -> bool:
        return any(map(bool, self._maps))

    def __repr__(self) -> str:
        return "ChainMapProxy({})".format(", ".join(repr(m) for m in self._maps))
993class CookieMixin:
994 """Mixin for handling cookies."""
996 _cookies: SimpleCookie | None = None
998 @property
999 def cookies(self) -> SimpleCookie:
1000 if self._cookies is None:
1001 self._cookies = SimpleCookie()
1002 return self._cookies
1004 def set_cookie(
1005 self,
1006 name: str,
1007 value: str,
1008 *,
1009 expires: str | None = None,
1010 domain: str | None = None,
1011 max_age: int | str | None = None,
1012 path: str = "/",
1013 secure: bool | None = None,
1014 httponly: bool | None = None,
1015 samesite: str | None = None,
1016 partitioned: bool | None = None,
1017 ) -> None:
1018 """Set or update response cookie.
1020 Sets new cookie or updates existent with new value.
1021 Also updates only those params which are not None.
1022 """
1023 if self._cookies is None:
1024 self._cookies = SimpleCookie()
1026 self._cookies[name] = value
1027 c = self._cookies[name]
1029 if expires is not None:
1030 c["expires"] = expires
1031 elif c.get("expires") == "Thu, 01 Jan 1970 00:00:00 GMT":
1032 del c["expires"]
1034 if domain is not None:
1035 c["domain"] = domain
1037 if max_age is not None:
1038 c["max-age"] = str(max_age)
1039 elif "max-age" in c:
1040 del c["max-age"]
1042 c["path"] = path
1044 if secure is not None:
1045 c["secure"] = secure
1046 if httponly is not None:
1047 c["httponly"] = httponly
1048 if samesite is not None:
1049 c["samesite"] = samesite
1051 if partitioned is not None:
1052 c["partitioned"] = partitioned
1054 if DEBUG:
1055 cookie_length = len(c.output(header="")[1:])
1056 if cookie_length > COOKIE_MAX_LENGTH:
1057 warnings.warn(
1058 "The size of is too large, it might get ignored by the client.",
1059 UserWarning,
1060 stacklevel=2,
1061 )
1063 def del_cookie(
1064 self,
1065 name: str,
1066 *,
1067 domain: str | None = None,
1068 path: str = "/",
1069 secure: bool | None = None,
1070 httponly: bool | None = None,
1071 samesite: str | None = None,
1072 ) -> None:
1073 """Delete cookie.
1075 Creates new empty expired cookie.
1076 """
1077 # TODO: do we need domain/path here?
1078 if self._cookies is not None:
1079 self._cookies.pop(name, None)
1080 self.set_cookie(
1081 name,
1082 "",
1083 max_age=0,
1084 expires="Thu, 01 Jan 1970 00:00:00 GMT",
1085 domain=domain,
1086 path=path,
1087 secure=secure,
1088 httponly=httponly,
1089 samesite=samesite,
1090 )
def populate_with_cookies(headers: "CIMultiDict[str]", cookies: SimpleCookie) -> None:
    """Add one ``Set-Cookie`` header to ``headers`` per morsel in ``cookies``."""
    for morsel in cookies.values():
        # output(header="") renders " name=value; ..."; drop the leading space.
        rendered = morsel.output(header="")
        headers.add(hdrs.SET_COOKIE, rendered[1:])
# https://tools.ietf.org/html/rfc7232#section-2.3
# One or more etagc characters: "!" (%x21), %x23-7E (visible ASCII except the
# double quote), or obs-text %x80-FF.
_ETAGC = r"[!\x23-\x7E\x80-\xff]+"
_ETAGC_RE = re.compile(_ETAGC)
# One entity-tag: group 1 is the optional weak prefix "W/", group 2 the opaque
# value between the double quotes.
_QUOTED_ETAG = rf'(W/)?"({_ETAGC})"'
QUOTED_ETAG_RE = re.compile(_QUOTED_ETAG)
# Scans a comma-separated list of entity-tags.  Group 1 is a whole quoted etag
# (its inner groups 2/3 come from _QUOTED_ETAG), followed by a comma separator
# or end of input.  The trailing ``(.)`` alternative (group 4) matches any
# other single character, presumably so callers can detect malformed input.
LIST_QUOTED_ETAG_RE = re.compile(rf"({_QUOTED_ETAG})(?:\s*,\s*|$)|(.)")

# Wildcard value that validate_etag_value accepts as-is.
ETAG_ANY = "*"
@frozen_dataclass_decorator
class ETag:
    """An HTTP entity-tag."""

    # The tag text; presumably without the surrounding quotes (parsing
    # happens elsewhere) -- TODO confirm against callers.
    value: str
    # True for a weak validator (the "W/" prefix).
    is_weak: bool = False
def validate_etag_value(value: str) -> None:
    """Check that ``value`` is ``"*"`` or consists solely of etagc characters.

    :raises ValueError: if ``value`` is not a valid (unquoted) etag value.
    """
    if value == ETAG_ANY or _ETAGC_RE.fullmatch(value):
        return
    raise ValueError(
        f"Value {value!r} is not a valid etag. Maybe it contains '\"'?"
    )
1122def parse_http_date(date_str: str | None) -> datetime.datetime | None:
1123 """Process a date string, return a datetime object"""
1124 if date_str is not None:
1125 timetuple = parsedate(date_str)
1126 if timetuple is not None:
1127 with suppress(ValueError):
1128 return datetime.datetime(*timetuple[:6], tzinfo=datetime.timezone.utc)
1129 return None
@functools.lru_cache
def must_be_empty_body(method: str, code: int) -> bool:
    """Return True if a response to ``method`` with status ``code`` must have no body.

    Applies to 1xx/204/304 statuses, methods in ``EMPTY_BODY_METHODS``, and
    2xx responses to CONNECT.
    """
    if code in EMPTY_BODY_STATUS_CODES or method in EMPTY_BODY_METHODS:
        return True
    return 200 <= code < 300 and method in hdrs.METH_CONNECT_ALL
def should_remove_content_length(method: str, code: int) -> bool:
    """Return True if a Content-Length header should be removed.

    True for 1xx/204/304 statuses and for 2xx responses to CONNECT.  This
    should always be a subset of must_be_empty_body.
    """
    # https://www.rfc-editor.org/rfc/rfc9110.html#section-8.6-8
    # https://www.rfc-editor.org/rfc/rfc9110.html#section-15.4.5-4
    if code in EMPTY_BODY_STATUS_CODES:
        return True
    return 200 <= code < 300 and method in hdrs.METH_CONNECT_ALL