Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/helpers.py: 38%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Various helper functions"""
3import asyncio
4import base64
5import contextlib
6import dataclasses
7import datetime
8import enum
9import functools
10import inspect
11import netrc
12import os
13import platform
14import re
15import sys
16import time
17import warnings
18import weakref
19from collections.abc import Callable, Iterable, Iterator, Mapping
20from contextlib import suppress
21from email.message import EmailMessage
22from email.parser import HeaderParser
23from email.policy import HTTP
24from email.utils import parsedate
25from http.cookies import SimpleCookie
26from math import ceil
27from pathlib import Path
28from types import MappingProxyType, TracebackType
29from typing import (
30 TYPE_CHECKING,
31 Any,
32 ContextManager,
33 Generic,
34 Protocol,
35 TypeVar,
36 Union,
37 final,
38 get_args,
39 overload,
40)
41from urllib.parse import quote
42from urllib.request import getproxies, proxy_bypass
44from multidict import CIMultiDict, MultiDict, MultiDictProxy
45from propcache.api import under_cached_property as reify
46from yarl import URL
48from . import hdrs
49from .log import client_logger
50from .typedefs import PathLike # noqa
52if sys.version_info >= (3, 11):
53 import asyncio as async_timeout
54else:
55 import async_timeout
if TYPE_CHECKING:
    # For static analysis the name is a plain alias of ``dataclasses.dataclass``.
    # NOTE: this alias does not pass frozen=True/slots=True, so static tools
    # will not see those flags on decorated classes.
    from dataclasses import dataclass as frozen_dataclass_decorator
else:
    # At runtime, decorated classes become frozen (immutable) dataclasses
    # that use __slots__.
    frozen_dataclass_decorator = functools.partial(
        dataclasses.dataclass, frozen=True, slots=True
    )

# Public names of this module.
__all__ = ("ChainMapProxy", "ETag", "frozen_dataclass_decorator", "reify")
# This is the default size/limit for several operations.
# Matches the max size we receive from sockets:
# https://github.com/python/cpython/blob/1857a40807daeae3a1bf5efb682de9c9ae6df845/Lib/asyncio/selector_events.py#L766
DEFAULT_CHUNK_SIZE = 2**18  # 256 KiB
# Rendered cookie length above which CookieMixin.set_cookie warns (debug mode only).
COOKIE_MAX_LENGTH = 4096
# Replaces a backslash-escaped character (a quoted-pair) with the bare character.
_QUOTED_PAIR_SUB = re.compile(r"\\(.)")
# A double-quoted string whose body may contain backslash escapes.
_QUOTED_STRING = r'"(?:[^"\\]|\\.)*"'
# Body of a parenthesised comment: no unescaped parentheses inside.
_ESCAPED_COMMENT = r"(?:[^()\\]|\\.)*"
# Matches one element in a comma-separated header list.
# Group 1: content of a top-level quoted-string (quotes stripped).
# Group 2: an unquoted element (may contain parameter quoted-strings / comments).
# A quoted-string is only treated as atomic when it directly follows "x="
# (fixed-width lookbehind), and a comment only when preceded by whitespace.
# Any other text is consumed one non-comma character at a time, lazily, up to
# the next comma or the end of the input.
_LIST_ELEMENT_RE = re.compile(
    rf"""
    [ \t]*
    (?:
        "( (?:[^"\\]|\\.)* )" # group 1: top-level quoted-string
      | ( # group 2: unquoted element
          (?:
              (?<=[^\s]=) {_QUOTED_STRING} # parameter quoted value
            | (?<=\s) \( {_ESCAPED_COMMENT} \) # comment
            | [^,] # any non-comma character
          )+?
        )
    )
    [ \t]* (?:,|\Z)
    """,
    re.VERBOSE,
)
# Finds parameter quoted-strings and comments inside an unquoted element for unescaping.
# Uses the same lookbehind rules as the corresponding alternatives above.
_PROTECTED_RE = re.compile(
    rf"""
    (?<=[^\s]=) {_QUOTED_STRING} # parameter quoted-string
  | (?<=\s) \( {_ESCAPED_COMMENT} \) # comment
    """,
    re.VERBOSE,
)
# Generic type variables shared by the helpers below.
_T = TypeVar("_T")
_S = TypeVar("_S")

# One-member enum used as a "no value yet" marker that is distinct from None.
_SENTINEL = enum.Enum("_SENTINEL", "sentinel")
sentinel = _SENTINEL.sentinel

# True when AIOHTTP_NO_EXTENSIONS is set to a non-empty value (presumably
# consulted elsewhere to skip compiled extensions).
NO_EXTENSIONS = bool(os.environ.get("AIOHTTP_NO_EXTENSIONS"))

# Status codes whose responses never carry a body: every 1xx, 204 and 304.
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
EMPTY_BODY_STATUS_CODES = frozenset(range(100, 200)) | {204, 304}

# Request methods whose responses carry no body.
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.2
EMPTY_BODY_METHODS = hdrs.METH_HEAD_ALL

# Debug mode is on under Python's development mode (-X dev). Otherwise it is
# on when PYTHONASYNCIODEBUG is non-empty and the interpreter is not
# ignoring the environment (-E).
if sys.flags.dev_mode:
    DEBUG = True
else:
    DEBUG = not sys.flags.ignore_environment and bool(
        os.environ.get("PYTHONASYNCIODEBUG")
    )
# Character classes of the HTTP token grammar (RFC 2616 section 2.2; the
# resulting TOKEN set equals ``tchar`` in RFC 9110 section 5.6.2).
CHAR = {chr(i) for i in range(0, 128)}
CTL = {chr(i) for i in range(0, 32)} | {chr(127)}
SEPARATORS = {
    "(",
    ")",
    "<",
    ">",
    "@",
    ",",
    ";",
    ":",
    "\\",
    '"',
    "/",
    "[",
    "]",
    "?",
    "=",
    "{",
    "}",
    " ",
    chr(9),
}
# A token character is any CHAR that is neither a control nor a separator.
# This must be a set difference, not a symmetric difference (``^``).
# Horizontal tab (chr(9)) is in both CTL and SEPARATORS, so ``^`` would remove
# it with CTL and then add it straight back.
TOKEN = CHAR - CTL - SEPARATORS

# Matches JSON media types: "application/json" or any "type/subtype+json".
json_re = re.compile(r"^(?:application/|[\w.-]+/[\w.+-]+?\+)json$", re.IGNORECASE)
def encode_basic_auth(login: str, password: str = "", encoding: str = "utf-8") -> str:
    """Encode HTTP Basic Authentication credentials as an Authorization header value.

    Returns a string of the form ``"Basic <base64>"`` suitable for use as the
    value of the ``Authorization`` (or ``Proxy-Authorization``) header.

    :param login: user name; must not contain ``":"``.
    :param password: password, may be empty.
    :param encoding: charset used to convert ``"login:password"`` to bytes
        before base64-encoding.
    :raises ValueError: if ``login`` contains ``":"``.
    """
    if ":" in login:
        raise ValueError('A ":" is not allowed in login (RFC 7617#section-2)')
    creds = f"{login}:{password}".encode(encoding)
    # Base64 output is always ASCII. Decoding it with ``encoding`` (e.g. UTF-16)
    # would corrupt the header value, so ``encoding`` only applies to the credentials.
    return "Basic " + base64.b64encode(creds).decode("ascii")
def strip_auth_from_url(url: URL) -> tuple[URL, str | None]:
    """Remove user info from ``url`` and convert it into an Authorization value.

    :returns: ``(url_without_credentials, header_value)``. ``header_value``
        is ``None`` and ``url`` is returned unchanged when the URL carries
        neither a user nor a password.
    """
    # The raw_* accessors are checked first because yarl has most likely
    # already parsed and cached them from the netloc.
    has_credentials = url.raw_user is not None or url.raw_password is not None
    if not has_credentials:
        return url, None
    bare_url = url.with_user(None)
    return bare_url, encode_basic_auth(url.user or "", url.password or "")
178def netrc_from_env() -> netrc.netrc | None:
179 """Load netrc from file.
181 Attempt to load it from the path specified by the env-var
182 NETRC or in the default location in the user's home directory.
184 Returns None if it couldn't be found or fails to parse.
185 """
186 netrc_env = os.environ.get("NETRC")
188 if netrc_env is not None:
189 netrc_path = Path(netrc_env)
190 else:
191 try:
192 home_dir = Path.home()
193 except RuntimeError as e:
194 # if pathlib can't resolve home, it may raise a RuntimeError
195 client_logger.debug(
196 "Could not resolve home directory when "
197 "trying to look for .netrc file: %s",
198 e,
199 )
200 return None
202 netrc_path = home_dir / (
203 "_netrc" if platform.system() == "Windows" else ".netrc"
204 )
206 try:
207 return netrc.netrc(str(netrc_path))
208 except netrc.NetrcParseError as e:
209 client_logger.warning("Could not parse .netrc file: %s", e)
210 except OSError as e:
211 netrc_exists = False
212 with contextlib.suppress(OSError):
213 netrc_exists = netrc_path.is_file()
214 # we couldn't read the file (doesn't exist, permissions, etc.)
215 if netrc_env or netrc_exists:
216 # only warn if the environment wanted us to load it,
217 # or it appears like the default file does actually exist
218 client_logger.warning("Could not read .netrc file: %s", e)
220 return None
@frozen_dataclass_decorator
class ProxyInfo:
    """A proxy URL together with the ``Proxy-Authorization`` value to use."""

    # Proxy URL (proxies_from_env stores it with credentials stripped).
    proxy: URL
    # Header value for the proxy, or None when no credentials are known.
    proxy_auth: str | None
229def _auth_header_from_netrc(netrc_obj: netrc.netrc | None, host: str) -> str:
230 """Return a ``Proxy-Authorization`` header value for ``host`` from netrc.
232 :raises LookupError: if ``netrc_obj`` is :py:data:`None` or if no
233 entry is found for the ``host``.
234 """
235 if netrc_obj is None:
236 raise LookupError("No .netrc file found")
237 auth_from_netrc = netrc_obj.authenticators(host)
239 if auth_from_netrc is None:
240 raise LookupError(f"No entry for {host!s} found in the `.netrc` file.")
241 login, account, password = auth_from_netrc
243 # TODO(PY311): username = login or account
244 # Up to python 3.10, account could be None if not specified,
245 # and login will be empty string if not specified. From 3.11,
246 # login and account will be empty string if not specified.
247 username = login if (login or account is None) else account
249 # TODO(PY311): Remove this, as password will be empty string
250 # if not specified
251 if password is None:
252 password = "" # type: ignore[unreachable]
254 return encode_basic_auth(username, password)
def proxies_from_env() -> dict[str, ProxyInfo]:
    """Collect HTTP/HTTPS/WS/WSS proxy settings from the environment.

    Proxy URLs come from :func:`urllib.request.getproxies`. Credentials in a
    proxy URL are stripped and turned into a ``Proxy-Authorization`` value.
    When a URL has none, the ``.netrc`` entry for the proxy host is used if
    one exists. Proxies whose own scheme is ``https`` or ``wss`` are not
    supported; they are logged and skipped.

    :returns: mapping from request scheme to :class:`ProxyInfo`.
    """
    wanted_schemes = ("http", "https", "ws", "wss")
    proxy_urls = {
        scheme: URL(raw_url)
        for scheme, raw_url in getproxies().items()
        if scheme in wanted_schemes
    }
    netrc_obj = netrc_from_env()
    stripped = {
        scheme: strip_auth_from_url(proxy_url)
        for scheme, proxy_url in proxy_urls.items()
    }

    result = {}
    for scheme, (proxy, auth) in stripped.items():
        if proxy.scheme in ("https", "wss"):
            client_logger.warning(
                "%s proxies %s are not supported, ignoring", proxy.scheme.upper(), proxy
            )
            continue
        # Fall back to .netrc credentials when the URL carried none.
        if auth is None and netrc_obj and proxy.host is not None:
            with suppress(LookupError):
                auth = _auth_header_from_netrc(netrc_obj, proxy.host)
        result[scheme] = ProxyInfo(proxy, auth)
    return result
def get_env_proxy_for_url(url: URL) -> tuple[URL, str | None]:
    """Return ``(proxy_url, proxy_auth)`` configured in the environment for ``url``.

    :raises LookupError: if the environment says ``url``'s host must bypass
        proxies, or if no proxy is configured for ``url``'s scheme.
    """
    host = url.host
    if host is not None and proxy_bypass(host):
        raise LookupError(f"Proxying is disallowed for `{host!r}`")

    try:
        info = proxies_from_env()[url.scheme]
    except KeyError:
        raise LookupError(f"No proxies found for `{url!s}` in the env")
    return info.proxy, info.proxy_auth
@frozen_dataclass_decorator
class MimeType:
    """Components of a MIME type as produced by :func:`parse_mimetype`."""

    # Top-level type, e.g. "text".
    type: str
    # Subtype without any "+suffix", e.g. "html".
    subtype: str
    # Structured-syntax suffix following "+", e.g. "json"; "" when absent.
    suffix: str
    # Read-only multidict of the ";"-separated parameters.
    parameters: "MultiDictProxy[str]"
@functools.lru_cache(maxsize=56)
def parse_mimetype(mimetype: str) -> MimeType:
    """Split a MIME type string into its components.

    The type part is lower-cased and a bare ``*`` becomes ``*/*``. Parameter
    names are lower-cased and stripped. Parameter values have surrounding
    spaces and double quotes removed. An empty input gives a
    :class:`MimeType` whose fields are all empty.

    For example, ``parse_mimetype("text/html; charset=utf-8")`` yields
    ``type="text"``, ``subtype="html"``, ``suffix=""`` and a ``charset``
    parameter equal to ``"utf-8"``.

    Results are cached, so callers share the returned object.
    """
    if not mimetype:
        return MimeType(
            type="", subtype="", suffix="", parameters=MultiDictProxy(MultiDict())
        )

    head, *raw_params = mimetype.split(";")
    params: MultiDict[str] = MultiDict()
    for raw in raw_params:
        if raw:
            name, _, value = raw.partition("=")
            params.add(name.lower().strip(), value.strip(' "'))

    full_type = head.strip().lower()
    if full_type == "*":
        full_type = "*/*"

    major, _, rest = full_type.partition("/")
    minor, _, suffix = rest.partition("+")
    return MimeType(
        type=major, subtype=minor, suffix=suffix, parameters=MultiDictProxy(params)
    )
class EnsureOctetStream(EmailMessage):
    """Email message whose fallback content type is application/octet-stream.

    Used to parse a single ``Content-Type`` header. A malformed media type
    then resolves to ``application/octet-stream`` instead of ``text/plain``.
    """

    def __init__(self) -> None:
        super().__init__()
        # https://www.rfc-editor.org/rfc/rfc9110#section-8.3-5
        self.set_default_type("application/octet-stream")

    def get_content_type(self) -> str:
        """Return the lower-cased ``type/subtype`` from the Content-Type header.

        Simplified version of :meth:`email.message.Message.get_content_type`.
        Callers always supply a Content-Type header, so only the part before
        the first ``;`` is examined. If it does not contain exactly one
        ``/``, the default type is returned.
        """
        header = self.get("content-type", "").lower()
        # Same split as the standard library's _splitparam helper.
        media_type = header.partition(";")[0].strip()
        if media_type.count("/") == 1:
            return media_type
        return self.get_default_type()
@functools.lru_cache(maxsize=56)
def parse_content_type(raw: str) -> tuple[str, MappingProxyType[str, str]]:
    """Parse a Content-Type header value.

    :returns: ``(media_type, params)``. ``media_type`` is the lower-cased
        type, or `application/octet-stream` when the value is malformed.
        ``params`` is a read-only mapping of the header's parameters.
    """
    parser = HeaderParser(EnsureOctetStream, policy=HTTP)
    message = parser.parsestr(f"Content-Type: {raw}")
    media_type = message.get_content_type()
    # get_params() lists the media type itself first; keep only real parameters.
    parameters = dict(message.get_params(())[1:])
    return media_type, MappingProxyType(parameters)
385def guess_filename(obj: Any, default: str | None = None) -> str | None:
386 name = getattr(obj, "name", None)
387 if name and isinstance(name, str) and name[0] != "<" and name[-1] != ">":
388 return Path(name).name
389 return default
# Any character outside RFC 5322 ``qtext`` (%d33 / %d35-91 / %d93-126).
# Such characters are backslash-escaped by quoted_string().
not_qtext_re = re.compile(r"[^\041\043-\133\135-\176]")
# Characters allowed in quoted-string content: printable ASCII plus tab.
QCONTENT = {chr(i) for i in range(0x20, 0x7F)} | {"\t"}


def quoted_string(content: str) -> str:
    """Return 7-bit content escaped for use inside a quoted-string.

    Format content into a quoted-string as defined in RFC5322 for
    Internet Message Format. Notice that this is not the 8-bit HTTP
    format, but the 7-bit email format. Every character outside ``qtext``
    (e.g. space, tab, ``"`` and ``\\``) is prefixed with a backslash. The
    surrounding double quotes are not added.

    :raises ValueError: if ``content`` contains a character outside
        QCONTENT (printable US-ASCII and tab).
    """
    # Subset test rather than a strict-superset comparison: content that uses
    # every allowed character is still valid.
    if not QCONTENT.issuperset(content):
        raise ValueError(f"bad content for quoted-string {content!r}")
    return not_qtext_re.sub(lambda m: "\\" + m.group(0), content)
def content_disposition_header(
    disptype: str,
    quote_fields: bool = True,
    _charset: str = "utf-8",
    params: dict[str, str] | None = None,
) -> str:
    """Build a ``Content-Disposition`` header value for MIME.

    This is the MIME payload Content-Disposition header from RFC 2183
    and RFC 7578 section 4.2, not the HTTP Content-Disposition from
    RFC 6266.

    :param disptype: disposition type such as ``inline``, ``attachment`` or
        ``form-data``; must be a valid extension token (see RFC 2183).
    :param quote_fields: when true, quote values for 7-bit MIME headers as
        described in RFC 7578. Set it to ``False`` if the recipient accepts
        8-bit file names and field values.
    :param _charset: charset used for percent-encoding when ``quote_fields``
        is true.
    :param params: disposition parameters; every key must be a token.
    :raises ValueError: if ``disptype`` or a parameter name is empty or not
        a token.
    """
    # Subset tests: a name that uses every token character is still valid.
    if not disptype or not TOKEN.issuperset(disptype):
        raise ValueError(f"bad content disposition type {disptype!r}")

    value = disptype
    if params:
        lparams = []
        for key, val in params.items():
            if not key or not TOKEN.issuperset(key):
                raise ValueError(f"bad content disposition parameter {key!r}={val!r}")
            if quote_fields:
                if key.lower() == "filename":
                    # Filenames are percent-encoded inside an ordinary
                    # quoted-string; no ``filename*`` form is produced.
                    qval = quote(val, "", encoding=_charset)
                    lparams.append((key, '"%s"' % qval))
                else:
                    try:
                        qval = quoted_string(val)
                    except ValueError:
                        # Not representable as a 7-bit quoted-string: use the
                        # extended ``key*=charset''percent-encoded`` notation.
                        qval = "".join(
                            (_charset, "''", quote(val, "", encoding=_charset))
                        )
                        lparams.append((key + "*", qval))
                    else:
                        lparams.append((key, '"%s"' % qval))
            else:
                # 8-bit mode: only backslash and double quote are escaped.
                qval = val.replace("\\", "\\\\").replace('"', '\\"')
                lparams.append((key, '"%s"' % qval))
        sparams = "; ".join("=".join(pair) for pair in lparams)
        value = "; ".join((value, sparams))
    return value
def is_expected_content_type(
    response_content_type: str, expected_content_type: str
) -> bool:
    """Tell whether a received content type can be processed as the expected one.

    For ``application/json``, any JSON media type is accepted (see
    ``json_re``). For anything else, the expected type only has to occur
    inside the received one. Both arguments should be given without
    parameters.
    """
    if expected_content_type != "application/json":
        return expected_content_type in response_content_type
    return json_re.match(response_content_type) is not None
475def is_ip_address(host: str | None) -> bool:
476 """Check if host looks like an IP Address.
478 This check is only meant as a heuristic to ensure that
479 a host is not a domain name.
480 """
481 if not host:
482 return False
483 # For a host to be an ipv4 address, it must be all numeric.
484 # The host must contain a colon to be an IPv6 address.
485 return ":" in host or host.replace(".", "").isdigit()
def is_canonical_ipv4_address(host: str) -> bool:
    """Check if host is a canonical dotted-quad IPv4 address.

    Rejects the legacy numeric forms that ``socket`` still accepts and
    maps onto an address, e.g. ``2130706433``, ``017700000001``, ``127.1``.
    """
    octets = host.split(".")
    if len(octets) != 4:
        return False
    # Each octet: 1-3 ASCII digits (``str.isdigit`` alone would also accept
    # other Unicode digits), no leading zero unless it is exactly "0", and a
    # value of at most 255. ``and`` short-circuits before ``int`` sees non-digits.
    return all(
        1 <= len(octet) <= 3
        and octet.isascii()
        and octet.isdigit()
        and (octet == "0" or octet[0] != "0")
        and int(octet) <= 255
        for octet in octets
    )
510_cached_current_datetime: int | None = None
511_cached_formatted_datetime = ""
514def rfc822_formatted_time() -> str:
515 global _cached_current_datetime
516 global _cached_formatted_datetime
518 now = int(time.time())
519 if now != _cached_current_datetime:
520 # Weekday and month names for HTTP date/time formatting;
521 # always English!
522 # Tuples are constants stored in codeobject!
523 _weekdayname = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")
524 _monthname = (
525 "", # Dummy so we can use 1-based month numbers
526 "Jan",
527 "Feb",
528 "Mar",
529 "Apr",
530 "May",
531 "Jun",
532 "Jul",
533 "Aug",
534 "Sep",
535 "Oct",
536 "Nov",
537 "Dec",
538 )
540 year, month, day, hh, mm, ss, wd, *tail = time.gmtime(now)
541 _cached_formatted_datetime = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
542 _weekdayname[wd],
543 day,
544 _monthname[month],
545 year,
546 hh,
547 mm,
548 ss,
549 )
550 _cached_current_datetime = now
551 return _cached_formatted_datetime
554def _weakref_handle(info: "tuple[weakref.ref[object], str]") -> None:
555 ref, name = info
556 ob = ref()
557 if ob is not None:
558 with suppress(Exception):
559 getattr(ob, name)()
562def weakref_handle(
563 ob: object,
564 name: str,
565 timeout: float | None,
566 loop: asyncio.AbstractEventLoop,
567 timeout_ceil_threshold: float = 5,
568) -> asyncio.TimerHandle | None:
569 if timeout is not None and timeout > 0:
570 when = loop.time() + timeout
571 if timeout >= timeout_ceil_threshold:
572 when = ceil(when)
574 return loop.call_at(when, _weakref_handle, (weakref.ref(ob), name))
575 return None
def call_later(
    cb: Callable[[], Any],
    timeout: float | None,
    loop: asyncio.AbstractEventLoop,
    timeout_ceil_threshold: float = 5,
) -> asyncio.TimerHandle | None:
    """Schedule ``cb()`` on ``loop`` after ``timeout`` seconds.

    The deadline comes from :func:`calculate_timeout_when`.

    :returns: the timer handle, or ``None`` (nothing scheduled) when
        ``timeout`` is ``None`` or not positive.
    """
    if timeout is None or timeout <= 0:
        return None
    deadline = calculate_timeout_when(loop.time(), timeout, timeout_ceil_threshold)
    return loop.call_at(deadline, cb)
def calculate_timeout_when(
    loop_time: float,
    timeout: float,
    timeout_ceiling_threshold: float,
) -> float:
    """Return the loop time at which a ``timeout``-second timer should fire.

    Timeouts strictly longer than ``timeout_ceiling_threshold`` are rounded
    up to the next whole second.
    """
    deadline = loop_time + timeout
    return ceil(deadline) if timeout > timeout_ceiling_threshold else deadline
603class TimeoutHandle:
604 """Timeout handle"""
606 __slots__ = ("_timeout", "_loop", "_ceil_threshold", "_callbacks")
608 def __init__(
609 self,
610 loop: asyncio.AbstractEventLoop,
611 timeout: float | None,
612 ceil_threshold: float = 5,
613 ) -> None:
614 self._timeout = timeout
615 self._loop = loop
616 self._ceil_threshold = ceil_threshold
617 self._callbacks: list[
618 tuple[Callable[..., None], tuple[Any, ...], dict[str, Any]]
619 ] = []
621 def register(
622 self, callback: Callable[..., None], *args: Any, **kwargs: Any
623 ) -> None:
624 self._callbacks.append((callback, args, kwargs))
626 def close(self) -> None:
627 self._callbacks.clear()
629 def start(self) -> asyncio.TimerHandle | None:
630 timeout = self._timeout
631 if timeout is not None and timeout > 0:
632 when = self._loop.time() + timeout
633 if timeout >= self._ceil_threshold:
634 when = ceil(when)
635 return self._loop.call_at(when, self.__call__)
636 else:
637 return None
639 def timer(self) -> "BaseTimerContext":
640 if self._timeout is not None and self._timeout > 0:
641 timer = TimerContext(self._loop)
642 self.register(timer.timeout)
643 return timer
644 else:
645 return TimerNoop()
647 def __call__(self) -> None:
648 for cb, args, kwargs in self._callbacks:
649 with suppress(Exception):
650 cb(*args, **kwargs)
652 self._callbacks.clear()
class BaseTimerContext(ContextManager["BaseTimerContext"]):
    """Common base of the timer context managers (TimerNoop, TimerContext)."""

    __slots__ = ()

    def assert_timeout(self) -> None:
        """Raise TimeoutError if timeout has been exceeded.

        This base version never raises; TimerContext overrides it.
        """
        return None
class TimerNoop(BaseTimerContext):
    """Timer context that never times out; used when no timeout is configured."""

    __slots__ = ()

    def __enter__(self) -> BaseTimerContext:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # Returning None lets any in-flight exception propagate unchanged.
        return None
class TimerContext(BaseTimerContext):
    """Low resolution timeout context manager

    Entering records the current task. Calling :meth:`timeout` cancels every
    task currently inside the context and marks the timer as expired. A task
    leaving the context with the resulting CancelledError then gets
    :class:`asyncio.TimeoutError` instead; on Python 3.11+ the exception is
    an outside cancellation, which is allowed to propagate (see __exit__).
    """

    __slots__ = ("_loop", "_tasks", "_cancelled", "_cancelling")

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        # Tasks currently inside the context; the most recent entry is last
        # and is popped by __exit__.
        self._tasks: list[asyncio.Task[Any]] = []
        # Set once timeout() has fired.
        self._cancelled = False
        # task.cancelling() count captured by the latest __enter__ (3.11+ only).
        self._cancelling = 0

    def assert_timeout(self) -> None:
        """Raise TimeoutError if timer has already been cancelled."""
        if self._cancelled:
            raise asyncio.TimeoutError from None

    def __enter__(self) -> BaseTimerContext:
        """Register the current task.

        :raises RuntimeError: when not called from inside a task.
        :raises asyncio.TimeoutError: when the timer has already fired.
        """
        task = asyncio.current_task(loop=self._loop)
        if task is None:
            raise RuntimeError("Timeout context manager should be used inside a task")

        if sys.version_info >= (3, 11):
            # Remember if the task was already cancelling
            # so when we __exit__ we can decide if we should
            # raise asyncio.TimeoutError or let the cancellation propagate
            self._cancelling = task.cancelling()

        if self._cancelled:
            raise asyncio.TimeoutError from None

        self._tasks.append(task)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """Unregister the task; convert our own cancellation into TimeoutError."""
        enter_task: asyncio.Task[Any] | None = None
        if self._tasks:
            enter_task = self._tasks.pop()

        if exc_type is asyncio.CancelledError and self._cancelled:
            assert enter_task is not None
            # The timeout was hit, and the task was cancelled
            # so we need to uncancel the last task that entered the context manager
            # since the cancellation should not leak out of the context manager
            if sys.version_info >= (3, 11):
                # If the task was already cancelling don't raise
                # asyncio.TimeoutError and instead return None
                # to allow the cancellation to propagate
                if enter_task.uncancel() > self._cancelling:
                    return None
            raise asyncio.TimeoutError from exc_val
        return None

    def timeout(self) -> None:
        """Cancel every task inside the context (only once) and mark the timer expired."""
        if not self._cancelled:
            # set() so that a task which entered more than once is cancelled once.
            for task in set(self._tasks):
                task.cancel()

            self._cancelled = True
744def ceil_timeout(
745 delay: float | None, ceil_threshold: float = 5
746) -> async_timeout.Timeout:
747 if delay is None or delay <= 0:
748 return async_timeout.timeout(None)
750 loop = asyncio.get_running_loop()
751 now = loop.time()
752 when = now + delay
753 if delay > ceil_threshold:
754 when = ceil(when)
755 return async_timeout.timeout_at(when)
class HeadersDictProxy(Mapping[str, str]):
    """Read-only ``str -> str`` view over a CIMultiDict of headers.

    Repeated headers are presented as one comma-joined value. Each header
    name is listed once, compared case-insensitively like lookups are, in
    first-seen order and with its first spelling.
    """

    def __init__(self, md: CIMultiDict[str]):
        self._md = md

    def getall(self, key: str) -> tuple[str, ...]:
        """Split header ``key`` into its comma-separated list elements.

        A top-level quoted-string element is returned without its quotes and
        with backslash escapes removed. In other elements, escapes are only
        removed inside ``name="..."`` parameter values and whitespace-preceded
        ``(...)`` comments, and empty elements are dropped. A missing header
        gives an empty tuple.
        """
        val = self.get(key, "")
        unescape = _QUOTED_PAIR_SUB.sub
        values = []
        for m in _LIST_ELEMENT_RE.finditer(val):
            qs = m.group(1)
            if qs is not None:
                values.append(unescape(r"\1", qs))
            else:
                raw = m.group(2).strip()
                if raw:
                    values.append(
                        _PROTECTED_RE.sub(lambda p: unescape(r"\1", p.group()), raw)
                    )
        return tuple(values)

    def __eq__(self, other: object) -> bool:
        return self._md.__eq__(other)

    def __getitem__(self, key: str) -> str:
        return ", ".join(self._md.getall(key))

    def __iter__(self) -> Iterator[str]:
        # The underlying multidict can hold differently-cased spellings of one
        # header (e.g. "Accept" and "accept"), which __getitem__ treats as the
        # same key. Deduplicate with a CIMultiDict so the comparison follows
        # exactly the same case-insensitive rules, keeping order and the
        # first spelling.
        seen: CIMultiDict[None] = CIMultiDict()
        for k in self._md:
            if k in seen:
                continue
            seen[k] = None
            yield k

    def __len__(self) -> int:
        # Must agree with __iter__, which yields each header name once.
        return sum(1 for _ in self)

    def __repr__(self) -> str:
        body = ", ".join(f"'{k}': {v!r}" for k, v in self.items())
        return f"<{self.__class__.__name__}({body})>"
class HeadersMixin:
    """Mixin for handling headers.

    Provides Content-Type and Content-Length accessors over ``_headers``. The
    parsed Content-Type is cached and only re-parsed when the raw header
    value changes.
    """

    _headers: Mapping[str, str]
    _content_type: str | None = None
    _content_dict: dict[str, str] | None = None
    # Raw Content-Type value the cached fields were parsed from; ``sentinel``
    # means nothing has been parsed yet.
    _stored_content_type: str | None | _SENTINEL = sentinel

    def _parse_content_type(self, raw: str | None) -> None:
        self._stored_content_type = raw
        if raw is not None:
            media_type, params = parse_content_type(raw)
            self._content_type = media_type
            # The cached MappingProxyType is read-only and shared, but
            # _content_dict must be a mutable dict of our own.
            self._content_dict = params.copy()
            return
        # No header at all: the default per RFC 9110 section 8.3.
        self._content_type = "application/octet-stream"
        self._content_dict = {}

    @property
    def content_type(self) -> str:
        """The value of content part for Content-Type HTTP header."""
        raw = self._headers.get(hdrs.CONTENT_TYPE)
        if self._stored_content_type != raw:
            self._parse_content_type(raw)
        assert self._content_type is not None
        return self._content_type

    @property
    def charset(self) -> str | None:
        """The value of charset part for Content-Type HTTP header."""
        raw = self._headers.get(hdrs.CONTENT_TYPE)
        if self._stored_content_type != raw:
            self._parse_content_type(raw)
        assert self._content_dict is not None
        return self._content_dict.get("charset")

    @property
    def content_length(self) -> int | None:
        """The value of Content-Length HTTP header."""
        raw_length = self._headers.get(hdrs.CONTENT_LENGTH)
        if raw_length is None:
            return None
        return int(raw_length)
def set_result(fut: "asyncio.Future[_T]", result: _T) -> None:
    """Set ``fut``'s result; do nothing if the future is already done."""
    if fut.done():
        return
    fut.set_result(result)
852_EXC_SENTINEL = BaseException()
855class ErrorableProtocol(Protocol):
856 def set_exception(
857 self,
858 exc: type[BaseException] | BaseException,
859 exc_cause: BaseException = ...,
860 ) -> None: ...
863def set_exception(
864 fut: Union["asyncio.Future[_T]", ErrorableProtocol],
865 exc: type[BaseException] | BaseException,
866 exc_cause: BaseException = _EXC_SENTINEL,
867) -> None:
868 """Set future exception.
870 If the future is marked as complete, this function is a no-op.
872 :param exc_cause: An exception that is a direct cause of ``exc``.
873 Only set if provided.
874 """
875 if asyncio.isfuture(fut) and fut.done():
876 return
878 exc_is_sentinel = exc_cause is _EXC_SENTINEL
879 exc_causes_itself = exc is exc_cause
880 if not exc_is_sentinel and not exc_causes_itself:
881 exc.__cause__ = exc_cause
883 fut.set_exception(exc)
@functools.total_ordering
class BaseKey(Generic[_T]):
    """Base for concrete context storage key classes.

    Each storage is provided with its own sub-class for the sake of some additional type safety.
    A key's name is qualified with the module that created it. Keys sort by
    that qualified name and compare as less than any non-key object.
    """

    __slots__ = ("_name", "_t", "__orig_class__")

    # Python may set this when the class is instantiated through a
    # parametrised alias (e.g. ``AppKey[Iterable[str]]("x")``). It allows keys
    # for types that are not concrete classes, like Iterable, which can't be
    # passed as the second parameter to __init__.
    __orig_class__: type[object]

    # TODO(PY314): Change Type to TypeForm (this should resolve unreachable below).
    def __init__(self, name: str, t: type[_T] | None = None):
        # Qualify the name with the nearest module-level frame on the call
        # stack (the module creating the key) to help deduplicate key names.
        frame = inspect.currentframe()
        while frame is not None and frame.f_code.co_name != "<module>":
            frame = frame.f_back
        if frame is None:
            raise RuntimeError("Failed to get module name.")
        module: str = frame.f_globals["__name__"]
        self._name = module + "." + name
        self._t = t

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, BaseKey):
            # Keys sort before every non-key object.
            return True
        return self._name < other._name

    def __repr__(self) -> str:
        t = self._t
        if t is None:
            # Fall back to the type argument of a parametrised instantiation.
            with suppress(AttributeError):
                t = get_args(self.__orig_class__)[0]

        if t is None:
            t_repr = "<<Unknown>>"
        elif not isinstance(t, type):
            t_repr = repr(t)  # type: ignore[unreachable]
        elif t.__module__ == "builtins":
            t_repr = t.__qualname__
        else:
            t_repr = f"{t.__module__}.{t.__qualname__}"
        return f"<{self.__class__.__name__}({self._name}, type={t_repr})>"
class AppKey(BaseKey[_T]):
    """Typed key giving static typing support for data stored on an Application."""


class RequestKey(BaseKey[_T]):
    """Typed key giving static typing support for data stored on a Request."""


class ResponseKey(BaseKey[_T]):
    """Typed key giving static typing support for data stored on a Response."""
@final
class ChainMapProxy(Mapping[str | AppKey[Any], Any]):
    """Read-only view that looks keys up in a sequence of mappings.

    The first mapping containing a key supplies its value. Subclassing is
    forbidden.
    """

    __slots__ = ("_maps",)

    def __init__(self, maps: Iterable[Mapping[str | AppKey[Any], Any]]) -> None:
        self._maps = tuple(maps)

    def __init_subclass__(cls) -> None:
        raise TypeError(
            f"Inheritance class {cls.__name__} from ChainMapProxy is forbidden"
        )

    @overload  # type: ignore[override]
    def __getitem__(self, key: AppKey[_T]) -> _T: ...

    @overload
    def __getitem__(self, key: str) -> Any: ...

    def __getitem__(self, key: str | AppKey[_T]) -> Any:
        for mapping in self._maps:
            with suppress(KeyError):
                return mapping[key]
        raise KeyError(key)

    @overload  # type: ignore[override]
    def get(self, key: AppKey[_T], default: _S) -> _T | _S: ...

    @overload
    def get(self, key: AppKey[_T], default: None = ...) -> _T | None: ...

    @overload
    def get(self, key: str, default: Any = ...) -> Any: ...

    def get(self, key: str | AppKey[_T], default: Any = None) -> Any:
        try:
            return self[key]
        except KeyError:
            return default

    def __len__(self) -> int:
        # A set union reuses the keys' stored hash values where possible.
        return len(set().union(*self._maps))

    def __iter__(self) -> Iterator[str | AppKey[Any]]:
        merged: dict[str | AppKey[Any], Any] = {}
        # Walk from the last mapping to the first; dict.update reuses stored
        # hash values where possible.
        for mapping in reversed(self._maps):
            merged.update(mapping)
        return iter(merged)

    def __contains__(self, key: object) -> bool:
        return any(key in mapping for mapping in self._maps)

    def __bool__(self) -> bool:
        return any(self._maps)

    def __repr__(self) -> str:
        content = ", ".join(repr(mapping) for mapping in self._maps)
        return f"ChainMapProxy({content})"
class CookieMixin:
    """Mixin for handling cookies."""

    # Created lazily by the ``cookies`` property or by set_cookie().
    _cookies: SimpleCookie | None = None

    @property
    def cookies(self) -> SimpleCookie:
        """The cookies to send, created empty on first access."""
        if self._cookies is None:
            self._cookies = SimpleCookie()
        return self._cookies

    def set_cookie(
        self,
        name: str,
        value: str,
        *,
        expires: str | None = None,
        domain: str | None = None,
        max_age: int | str | None = None,
        path: str = "/",
        secure: bool | None = None,
        httponly: bool | None = None,
        samesite: str | None = None,
        partitioned: bool | None = None,
    ) -> None:
        """Set or update response cookie.

        Sets new cookie or updates existent with new value.
        Also updates only those params which are not None.

        ``path`` is always written. SimpleCookie keeps the existing Morsel
        when a name is re-assigned, so when ``expires`` is ``None`` a
        leftover epoch expiry (as written by :meth:`del_cookie`) is removed,
        and when ``max_age`` is ``None`` any existing ``max-age`` is removed.
        In debug mode a :class:`UserWarning` is issued if the rendered cookie
        is longer than ``COOKIE_MAX_LENGTH`` characters.
        """
        if self._cookies is None:
            self._cookies = SimpleCookie()

        self._cookies[name] = value
        c = self._cookies[name]

        if expires is not None:
            c["expires"] = expires
        elif c.get("expires") == "Thu, 01 Jan 1970 00:00:00 GMT":
            del c["expires"]

        if domain is not None:
            c["domain"] = domain

        if max_age is not None:
            c["max-age"] = str(max_age)
        elif "max-age" in c:
            del c["max-age"]

        c["path"] = path

        if secure is not None:
            c["secure"] = secure
        if httponly is not None:
            c["httponly"] = httponly
        if samesite is not None:
            c["samesite"] = samesite

        if partitioned is not None:
            c["partitioned"] = partitioned

        if DEBUG:
            # Same rendering as populate_with_cookies: drop the leading space.
            cookie_length = len(c.output(header="")[1:])
            if cookie_length > COOKIE_MAX_LENGTH:
                warnings.warn(
                    f"The size of cookie {name!r} is too large, "
                    "it might get ignored by the client.",
                    UserWarning,
                    stacklevel=2,
                )

    def del_cookie(
        self,
        name: str,
        *,
        domain: str | None = None,
        path: str = "/",
        secure: bool | None = None,
        httponly: bool | None = None,
        samesite: str | None = None,
    ) -> None:
        """Delete cookie.

        Creates new empty expired cookie: the value is empty, ``max-age`` is
        0 and ``expires`` is the Unix epoch.
        """
        # TODO: do we need domain/path here?
        if self._cookies is not None:
            self._cookies.pop(name, None)
        self.set_cookie(
            name,
            "",
            max_age=0,
            expires="Thu, 01 Jan 1970 00:00:00 GMT",
            domain=domain,
            path=path,
            secure=secure,
            httponly=httponly,
            samesite=samesite,
        )
def populate_with_cookies(headers: "CIMultiDict[str]", cookies: SimpleCookie) -> None:
    """Append one ``Set-Cookie`` header to ``headers`` for each cookie in ``cookies``."""
    for morsel in cookies.values():
        # With an empty header prefix the rendering starts with a separator
        # space; drop it.
        rendered = morsel.output(header="")[1:]
        headers.add(hdrs.SET_COOKIE, rendered)
# https://tools.ietf.org/html/rfc7232#section-2.3
# One or more etagc characters: "!", %x23-7E or obs-text (%x80-FF), i.e. any
# visible character except DQUOTE.
# NOTE(review): RFC 7232 defines opaque-tag as DQUOTE *etagc DQUOTE, which
# allows an empty tag; this pattern requires at least one character.
_ETAGC = r"[!\x23-\x7E\x80-\xff]+"
_ETAGC_RE = re.compile(_ETAGC)
# A quoted entity-tag: optional weak marker "W/" (group 1) and the tag
# between the double quotes (group 2).
_QUOTED_ETAG = rf'(W/)?"({_ETAGC})"'
QUOTED_ETAG_RE = re.compile(_QUOTED_ETAG)
# Scans a comma-separated list of entity-tags. Group 1 is a whole quoted etag
# (groups 2 and 3 are its weak marker and tag). Group 4 matches any single
# character where no valid etag starts, presumably so callers can detect
# malformed input.
LIST_QUOTED_ETAG_RE = re.compile(rf"({_QUOTED_ETAG})(?:\s*,\s*|$)|(.)")

# The wildcard entity-tag value "*".
ETAG_ANY = "*"
@frozen_dataclass_decorator
class ETag:
    """An entity-tag value and whether it is a weak validator."""

    # The entity-tag value.
    value: str
    # True for a weak entity-tag (written with the "W/" prefix).
    is_weak: bool = False
def validate_etag_value(value: str) -> None:
    """Raise ValueError unless ``value`` is ``"*"`` or a valid unquoted etag.

    A valid etag consists only of RFC 7232 ``etagc`` characters, so it
    cannot contain a double quote.
    """
    if value == ETAG_ANY or _ETAGC_RE.fullmatch(value):
        return
    raise ValueError(
        f"Value {value!r} is not a valid etag. Maybe it contains '\"'?"
    )
1144def parse_http_date(date_str: str | None) -> datetime.datetime | None:
1145 """Process a date string, return a datetime object"""
1146 if date_str is not None:
1147 timetuple = parsedate(date_str)
1148 if timetuple is not None:
1149 with suppress(ValueError):
1150 return datetime.datetime(*timetuple[:6], tzinfo=datetime.timezone.utc)
1151 return None
@functools.lru_cache
def must_be_empty_body(method: str, code: int) -> bool:
    """Check if a request must return an empty body.

    True for status codes in EMPTY_BODY_STATUS_CODES (1xx, 204, 304), for
    methods in EMPTY_BODY_METHODS, and for 2xx responses to CONNECT
    methods.
    """
    if code in EMPTY_BODY_STATUS_CODES or method in EMPTY_BODY_METHODS:
        return True
    return 200 <= code < 300 and method in hdrs.METH_CONNECT_ALL
def should_remove_content_length(method: str, code: int) -> bool:
    """Check if a Content-Length header should be removed.

    True for status codes in EMPTY_BODY_STATUS_CODES and for 2xx responses
    to CONNECT methods. This should always be a subset of
    must_be_empty_body. Unlike that function it does not include
    EMPTY_BODY_METHODS, since Content-Length may still be sent there (see
    the RFC links below).
    """
    # https://www.rfc-editor.org/rfc/rfc9110.html#section-8.6-8
    # https://www.rfc-editor.org/rfc/rfc9110.html#section-15.4.5-4
    if code in EMPTY_BODY_STATUS_CODES:
        return True
    return 200 <= code < 300 and method in hdrs.METH_CONNECT_ALL