Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/helpers.py: 36%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Various helper functions"""
3import asyncio
4import base64
5import binascii
6import contextlib
7import dataclasses
8import datetime
9import enum
10import functools
11import inspect
12import netrc
13import os
14import platform
15import re
16import sys
17import time
18import warnings
19import weakref
20from collections import namedtuple
21from collections.abc import Callable, Iterable, Iterator, Mapping
22from contextlib import suppress
23from email.message import EmailMessage
24from email.parser import HeaderParser
25from email.policy import HTTP
26from email.utils import parsedate
27from http.cookies import SimpleCookie
28from math import ceil
29from pathlib import Path
30from types import MappingProxyType, TracebackType
31from typing import (
32 TYPE_CHECKING,
33 Any,
34 ContextManager,
35 Generic,
36 Optional,
37 Protocol,
38 TypeVar,
39 Union,
40 final,
41 get_args,
42 overload,
43)
44from urllib.parse import quote
45from urllib.request import getproxies, proxy_bypass
47from multidict import CIMultiDict, MultiDict, MultiDictProxy, MultiMapping
48from propcache.api import under_cached_property as reify
49from yarl import URL
51from . import hdrs
52from .log import client_logger
53from .typedefs import PathLike # noqa
55if sys.version_info >= (3, 11):
56 import asyncio as async_timeout
57else:
58 import async_timeout
60if TYPE_CHECKING:
61 from dataclasses import dataclass as frozen_dataclass_decorator
62else:
63 frozen_dataclass_decorator = functools.partial(
64 dataclasses.dataclass, frozen=True, slots=True
65 )
__all__ = (
    "BasicAuth",
    "ChainMapProxy",
    "ETag",
    "frozen_dataclass_decorator",
    "reify",
)

# Serialized Set-Cookie length above which set_cookie() warns (DEBUG only).
COOKIE_MAX_LENGTH = 4096

_T = TypeVar("_T")
_S = TypeVar("_S")

# Single-member enum used as a "no value yet" marker that is distinct from None.
_SENTINEL = enum.Enum("_SENTINEL", "sentinel")
sentinel = _SENTINEL.sentinel

# Any non-empty AIOHTTP_NO_EXTENSIONS value switches this on.
NO_EXTENSIONS = bool(os.environ.get("AIOHTTP_NO_EXTENSIONS"))

# Status codes whose responses never carry a body: all 1xx, 204 and 304.
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
EMPTY_BODY_STATUS_CODES = frozenset(range(100, 200)) | {204, 304}
# Request methods whose responses never carry a body (hdrs.METH_HEAD_ALL).
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
# The CONNECT + 2xx rule (section 6.3-2.2) is checked in must_be_empty_body().
EMPTY_BODY_METHODS = hdrs.METH_HEAD_ALL

# On in Python Development Mode, or when PYTHONASYNCIODEBUG is set and
# environment variables are not being ignored (-E).
DEBUG = sys.flags.dev_mode or (
    not sys.flags.ignore_environment
    and bool(os.environ.get("PYTHONASYNCIODEBUG"))
)
# Character classes for the HTTP ``token`` grammar (RFC 9110 section 5.6.2).
CHAR = {chr(i) for i in range(0, 128)}
CTL = {chr(i) for i in range(0, 32)} | {
    chr(127),
}
SEPARATORS = {
    "(",
    ")",
    "<",
    ">",
    "@",
    ",",
    ";",
    ":",
    "\\",
    '"',
    "/",
    "[",
    "]",
    "?",
    "=",
    "{",
    "}",
    " ",
    chr(9),
}
# Set difference, not symmetric difference: TAB (chr(9)) belongs to both CTL
# and SEPARATORS, so XOR-ing the sets would put it back into TOKEN.
TOKEN = CHAR - CTL - SEPARATORS


# Matches "application/json" and structured-syntax "<type>/<subtype>+json".
json_re = re.compile(r"^(?:application/|[\w.-]+/[\w.+-]+?\+)json$", re.IGNORECASE)
class BasicAuth(namedtuple("BasicAuth", ["login", "password", "encoding"])):
    """Http basic authentication helper.

    Immutable ``(login, password, encoding)`` triple; *encoding* is the
    charset used to turn the credentials into bytes before base64.
    """

    def __new__(
        cls, login: str, password: str = "", encoding: str = "latin1"
    ) -> "BasicAuth":
        """Validate and build credentials.

        :raises ValueError: if *login* or *password* is None, or if *login*
            contains ":" (which would make the header ambiguous).
        """
        if login is None:
            raise ValueError("None is not allowed as login value")

        if password is None:
            raise ValueError("None is not allowed as password value")

        if ":" in login:
            raise ValueError('A ":" is not allowed in login (RFC 1945#section-11.1)')

        return super().__new__(cls, login, password, encoding)

    @classmethod
    def decode(cls, auth_header: str, encoding: str = "latin1") -> "BasicAuth":
        """Create a BasicAuth object from an Authorization HTTP header.

        :raises ValueError: if the header is not ``Basic <base64>``, the
            payload is not valid base64, or it lacks a ":" separator.
        """
        try:
            auth_type, encoded_credentials = auth_header.split(" ", 1)
        except ValueError:
            raise ValueError("Could not parse authorization header.")

        if auth_type.lower() != "basic":
            raise ValueError("Unknown authorization method %s" % auth_type)

        try:
            decoded = base64.b64decode(
                encoded_credentials.encode("ascii"), validate=True
            ).decode(encoding)
        except binascii.Error:
            raise ValueError("Invalid base64 encoding.")

        try:
            # RFC 2617 HTTP Authentication
            # https://www.ietf.org/rfc/rfc2617.txt
            # the colon must be present, but the username and password may be
            # otherwise blank.
            username, password = decoded.split(":", 1)
        except ValueError:
            raise ValueError("Invalid credentials.")

        return cls(username, password, encoding=encoding)

    @classmethod
    def from_url(
        cls, url: "URL", *, encoding: str = "latin1"
    ) -> Optional["BasicAuth"]:
        """Create BasicAuth from the userinfo of *url*, or None if it has none.

        :raises TypeError: if *url* is not a ``yarl.URL``.
        """
        if not isinstance(url, URL):
            raise TypeError("url should be yarl.URL instance")
        # Check raw_user and raw_password first as yarl is likely
        # to already have these values parsed from the netloc in the cache.
        if url.raw_user is None and url.raw_password is None:
            return None
        return cls(url.user or "", url.password or "", encoding=encoding)

    def encode(self) -> str:
        """Return the ``Authorization`` header value for these credentials."""
        creds = (f"{self.login}:{self.password}").encode(self.encoding)
        # base64 output is pure ASCII. Decoding it with self.encoding (as was
        # done before) yields garbage for non-ASCII-compatible charsets such
        # as UTF-16.
        return "Basic %s" % base64.b64encode(creds).decode("ascii")
def strip_auth_from_url(url: URL) -> tuple[URL, BasicAuth | None]:
    """Separate userinfo from *url*.

    Returns ``(url_without_user, BasicAuth)`` when *url* carries a user or
    password, otherwise ``(url, None)`` with *url* unchanged.
    """
    # raw_user/raw_password are checked first: yarl has likely already parsed
    # them from the netloc and cached them.
    has_userinfo = url.raw_user is not None or url.raw_password is not None
    if not has_userinfo:
        return url, None
    bare_url = url.with_user(None)
    return bare_url, BasicAuth(url.user or "", url.password or "")
193def netrc_from_env() -> netrc.netrc | None:
194 """Load netrc from file.
196 Attempt to load it from the path specified by the env-var
197 NETRC or in the default location in the user's home directory.
199 Returns None if it couldn't be found or fails to parse.
200 """
201 netrc_env = os.environ.get("NETRC")
203 if netrc_env is not None:
204 netrc_path = Path(netrc_env)
205 else:
206 try:
207 home_dir = Path.home()
208 except RuntimeError as e:
209 # if pathlib can't resolve home, it may raise a RuntimeError
210 client_logger.debug(
211 "Could not resolve home directory when "
212 "trying to look for .netrc file: %s",
213 e,
214 )
215 return None
217 netrc_path = home_dir / (
218 "_netrc" if platform.system() == "Windows" else ".netrc"
219 )
221 try:
222 return netrc.netrc(str(netrc_path))
223 except netrc.NetrcParseError as e:
224 client_logger.warning("Could not parse .netrc file: %s", e)
225 except OSError as e:
226 netrc_exists = False
227 with contextlib.suppress(OSError):
228 netrc_exists = netrc_path.is_file()
229 # we couldn't read the file (doesn't exist, permissions, etc.)
230 if netrc_env or netrc_exists:
231 # only warn if the environment wanted us to load it,
232 # or it appears like the default file does actually exist
233 client_logger.warning("Could not read .netrc file: %s", e)
235 return None
@frozen_dataclass_decorator
class ProxyInfo:
    """A proxy URL together with the BasicAuth to present to it (or None)."""

    proxy: URL
    proxy_auth: BasicAuth | None
def basicauth_from_netrc(netrc_obj: netrc.netrc | None, host: str) -> BasicAuth:
    """
    Return :py:class:`~aiohttp.BasicAuth` credentials for ``host`` from ``netrc_obj``.

    :raises LookupError: if ``netrc_obj`` is :py:data:`None` or if no
            entry is found for the ``host``.
    """
    if netrc_obj is None:
        raise LookupError("No .netrc file found")

    entry = netrc_obj.authenticators(host)
    if entry is None:
        raise LookupError(f"No entry for {host!s} found in the `.netrc` file.")

    login, account, password = entry

    # TODO(PY311): username = login or account
    # Up to python 3.10, account could be None if not specified,
    # and login will be empty string if not specified. From 3.11,
    # login and account will be empty string if not specified.
    if login or account is None:
        username = login
    else:
        username = account

    # TODO(PY311): Remove this, as password will be empty string
    # if not specified
    if password is None:
        password = ""  # type: ignore[unreachable]

    return BasicAuth(username, password)
def proxies_from_env() -> dict[str, ProxyInfo]:
    """Build per-scheme proxy settings from the environment.

    Uses urllib's ``getproxies()`` and keeps only the http, https, ws and wss
    entries. Credentials embedded in a proxy URL are moved into
    ``ProxyInfo.proxy_auth``. When there are none, the proxy host is looked
    up in the .netrc file. Proxies whose own URL uses https or wss are skipped
    with a warning.
    """
    wanted = ("http", "https", "ws", "wss")
    proxy_urls = {
        scheme: URL(address)
        for scheme, address in getproxies().items()
        if scheme in wanted
    }
    netrc_obj = netrc_from_env()
    stripped = {scheme: strip_auth_from_url(u) for scheme, u in proxy_urls.items()}

    result: dict[str, ProxyInfo] = {}
    for proto, (proxy, auth) in stripped.items():
        if proxy.scheme in ("https", "wss"):
            client_logger.warning(
                "%s proxies %s are not supported, ignoring", proxy.scheme.upper(), proxy
            )
            continue
        if netrc_obj and auth is None and proxy.host is not None:
            try:
                auth = basicauth_from_netrc(netrc_obj, proxy.host)
            except LookupError:
                auth = None
        result[proto] = ProxyInfo(proxy, auth)
    return result
def get_env_proxy_for_url(url: URL) -> tuple[URL, BasicAuth | None]:
    """Get a permitted proxy for the given URL from the env.

    :raises LookupError: if urllib's ``proxy_bypass`` excludes the host, or
        if no proxy is configured for the URL's scheme.
    """
    host = url.host
    if host is not None and proxy_bypass(host):
        raise LookupError(f"Proxying is disallowed for `{host!r}`")

    try:
        info = proxies_from_env()[url.scheme]
    except KeyError:
        raise LookupError(f"No proxies found for `{url!s}` in the env")
    return info.proxy, info.proxy_auth
@frozen_dataclass_decorator
class MimeType:
    """Components of a parsed MIME type (see ``parse_mimetype``)."""

    type: str
    subtype: str
    # Structured-syntax suffix after "+", e.g. "json" in "ld+json"; may be "".
    suffix: str
    parameters: "MultiDictProxy[str]"
@functools.lru_cache(maxsize=56)
def parse_mimetype(mimetype: str) -> MimeType:
    """Split a MIME type string into type, subtype, suffix and parameters.

    For example ``parse_mimetype('application/ld+json; charset=utf-8')``
    gives type ``'application'``, subtype ``'ld'``, suffix ``'json'`` and
    parameters ``{'charset': 'utf-8'}``.

    The type and parameter names are lower-cased, parameter values have
    surrounding spaces and double quotes stripped, and a bare ``*`` is read
    as ``*/*``. An empty string yields a MimeType with every field empty.
    Results are memoised (up to 56 entries).
    """
    if not mimetype:
        return MimeType(
            type="", subtype="", suffix="", parameters=MultiDictProxy(MultiDict())
        )

    head, *raw_params = mimetype.split(";")
    params: MultiDict[str] = MultiDict()
    for chunk in raw_params:
        if chunk:
            name, _, val = chunk.partition("=")
            params.add(name.lower().strip(), val.strip(' "'))

    fulltype = head.strip().lower()
    if fulltype == "*":
        fulltype = "*/*"

    major, _, rest = fulltype.partition("/")
    minor, _, suffix = rest.partition("+")
    return MimeType(
        type=major, subtype=minor, suffix=suffix, parameters=MultiDictProxy(params)
    )
class EnsureOctetStream(EmailMessage):
    """EmailMessage that falls back to application/octet-stream.

    Replaces the email package's usual fallback content type.
    """

    def __init__(self) -> None:
        super().__init__()
        # https://www.rfc-editor.org/rfc/rfc9110#section-8.3-5
        self.set_default_type("application/octet-stream")

    def get_content_type(self) -> str:
        """Return the lower-cased ``maintype/subtype`` from Content-Type.

        Parameters are dropped. If the bare type does not contain exactly one
        "/", the default type (application/octet-stream) is returned.
        Simplified from ``Message.get_content_type``, because callers always
        put a Content-Type header on the message.
        """
        header = self.get("content-type", "").lower()
        # Same split as the standard library's _splitparam: keep text before ";".
        bare = header.partition(";")[0].strip()
        return bare if bare.count("/") == 1 else self.get_default_type()
@functools.lru_cache(maxsize=56)
def parse_content_type(raw: str) -> tuple[str, MappingProxyType[str, str]]:
    """Parse a Content-Type header value.

    Returns ``(content_type, params)``:
    - ``content_type`` is the lower-cased bare type, or
      ``application/octet-stream`` when the value is unusable.
    - ``params`` is a read-only mapping of the header parameters.

    Results are memoised (up to 56 entries).
    """
    header_parser = HeaderParser(EnsureOctetStream, policy=HTTP)
    message = header_parser.parsestr(f"Content-Type: {raw}")
    content_type = message.get_content_type()
    # get_params() lists the content type itself first; skip that entry.
    parameters = dict(message.get_params(())[1:])
    return content_type, MappingProxyType(parameters)
401def guess_filename(obj: Any, default: str | None = None) -> str | None:
402 name = getattr(obj, "name", None)
403 if name and isinstance(name, str) and name[0] != "<" and name[-1] != ">":
404 return Path(name).name
405 return default
# Characters that must be backslash-escaped inside a quoted-string: anything
# outside qtext (%d33 / %d35-91 / %d93-126).
not_qtext_re = re.compile(r"[^\041\043-\133\135-\176]")
# Characters allowed in quoted-string content: printable 7-bit ASCII plus TAB.
QCONTENT = {chr(i) for i in range(0x20, 0x7F)} | {"\t"}


def quoted_string(content: str) -> str:
    """Return 7-bit content as quoted-string.

    Format content into a quoted-string as defined in RFC5322 for
    Internet Message Format. Notice that this is not the 8-bit HTTP
    format, but the 7-bit email format. Content must be in usascii or
    a ValueError is raised.

    The surrounding double quotes are not added; only the characters that
    need a quoted-pair (space, TAB, '"' and '\\') are escaped.

    :raises ValueError: if *content* contains a character outside QCONTENT.
    """
    # Subset test: the old strict-superset check (QCONTENT > set(content))
    # wrongly rejected content that happened to use every allowed character.
    if not set(content) <= QCONTENT:
        raise ValueError(f"bad content for quoted-string {content!r}")
    return not_qtext_re.sub(lambda x: "\\" + x.group(0), content)
def content_disposition_header(
    disptype: str,
    quote_fields: bool = True,
    _charset: str = "utf-8",
    params: dict[str, str] | None = None,
) -> str:
    """Return a ``Content-Disposition`` header value for a MIME part.

    This is the MIME payload Content-Disposition header from RFC 2183
    and RFC 7578 section 4.2, not the HTTP Content-Disposition from
    RFC 6266.

    disptype is a disposition type: inline, attachment, form-data.
    It must be a valid token (see RFC 2183).

    quote_fields performs value quoting to 7-bit MIME headers
    according to RFC 7578. Set quote_fields to False if the recipient
    can take 8-bit file names and field values.

    _charset specifies the charset to use when quote_fields is True.

    params is a dict with disposition params.

    :raises ValueError: if *disptype* or any parameter name is empty or
        contains a character outside TOKEN.
    """
    # Subset test (<=), not a strict-superset test: a name that uses every
    # token character is still a valid token.
    if not disptype or not set(disptype) <= TOKEN:
        raise ValueError(f"bad content disposition type {disptype!r}")

    value = disptype
    if params:
        lparams = []
        for key, val in params.items():
            if not key or not set(key) <= TOKEN:
                raise ValueError(f"bad content disposition parameter {key!r}={val!r}")
            if quote_fields:
                if key.lower() == "filename":
                    # Filenames are percent-encoded, then wrapped in quotes.
                    qval = quote(val, "", encoding=_charset)
                    lparams.append((key, '"%s"' % qval))
                else:
                    try:
                        qval = quoted_string(val)
                    except ValueError:
                        # Not 7-bit safe: use the extended form
                        # key*=charset''percent-encoded-value.
                        qval = "".join(
                            (_charset, "''", quote(val, "", encoding=_charset))
                        )
                        lparams.append((key + "*", qval))
                    else:
                        lparams.append((key, '"%s"' % qval))
            else:
                # 8-bit mode: only escape backslashes and double quotes.
                qval = val.replace("\\", "\\\\").replace('"', '\\"')
                lparams.append((key, '"%s"' % qval))
        sparams = "; ".join("=".join(pair) for pair in lparams)
        value = "; ".join((value, sparams))
    return value
def is_expected_content_type(
    response_content_type: str, expected_content_type: str
) -> bool:
    """Tell whether *response_content_type* can be handled as *expected_content_type*.

    Both arguments should be given without parameters. For
    ``application/json``, any JSON media type matched by ``json_re`` is
    accepted. Otherwise the expected type only has to appear as a substring
    of the response type.
    """
    if expected_content_type != "application/json":
        return expected_content_type in response_content_type
    return json_re.match(response_content_type) is not None
491def is_ip_address(host: str | None) -> bool:
492 """Check if host looks like an IP Address.
494 This check is only meant as a heuristic to ensure that
495 a host is not a domain name.
496 """
497 if not host:
498 return False
499 # For a host to be an ipv4 address, it must be all numeric.
500 # The host must contain a colon to be an IPv6 address.
501 return ":" in host or host.replace(".", "").isdigit()
504_cached_current_datetime: int | None = None
505_cached_formatted_datetime = ""
508def rfc822_formatted_time() -> str:
509 global _cached_current_datetime
510 global _cached_formatted_datetime
512 now = int(time.time())
513 if now != _cached_current_datetime:
514 # Weekday and month names for HTTP date/time formatting;
515 # always English!
516 # Tuples are constants stored in codeobject!
517 _weekdayname = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")
518 _monthname = (
519 "", # Dummy so we can use 1-based month numbers
520 "Jan",
521 "Feb",
522 "Mar",
523 "Apr",
524 "May",
525 "Jun",
526 "Jul",
527 "Aug",
528 "Sep",
529 "Oct",
530 "Nov",
531 "Dec",
532 )
534 year, month, day, hh, mm, ss, wd, *tail = time.gmtime(now)
535 _cached_formatted_datetime = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
536 _weekdayname[wd],
537 day,
538 _monthname[month],
539 year,
540 hh,
541 mm,
542 ss,
543 )
544 _cached_current_datetime = now
545 return _cached_formatted_datetime
548def _weakref_handle(info: "tuple[weakref.ref[object], str]") -> None:
549 ref, name = info
550 ob = ref()
551 if ob is not None:
552 with suppress(Exception):
553 getattr(ob, name)()
556def weakref_handle(
557 ob: object,
558 name: str,
559 timeout: float | None,
560 loop: asyncio.AbstractEventLoop,
561 timeout_ceil_threshold: float = 5,
562) -> asyncio.TimerHandle | None:
563 if timeout is not None and timeout > 0:
564 when = loop.time() + timeout
565 if timeout >= timeout_ceil_threshold:
566 when = ceil(when)
568 return loop.call_at(when, _weakref_handle, (weakref.ref(ob), name))
569 return None
572def call_later(
573 cb: Callable[[], Any],
574 timeout: float | None,
575 loop: asyncio.AbstractEventLoop,
576 timeout_ceil_threshold: float = 5,
577) -> asyncio.TimerHandle | None:
578 if timeout is None or timeout <= 0:
579 return None
580 now = loop.time()
581 when = calculate_timeout_when(now, timeout, timeout_ceil_threshold)
582 return loop.call_at(when, cb)
def calculate_timeout_when(
    loop_time: float,
    timeout: float,
    timeout_ceiling_threshold: float,
) -> float:
    """Return the loop time at which a *timeout* started at *loop_time* expires.

    When *timeout* is strictly greater than *timeout_ceiling_threshold*, the
    result is rounded up to a whole second.
    """
    deadline = loop_time + timeout
    return ceil(deadline) if timeout > timeout_ceiling_threshold else deadline
597class TimeoutHandle:
598 """Timeout handle"""
600 __slots__ = ("_timeout", "_loop", "_ceil_threshold", "_callbacks")
602 def __init__(
603 self,
604 loop: asyncio.AbstractEventLoop,
605 timeout: float | None,
606 ceil_threshold: float = 5,
607 ) -> None:
608 self._timeout = timeout
609 self._loop = loop
610 self._ceil_threshold = ceil_threshold
611 self._callbacks: list[
612 tuple[Callable[..., None], tuple[Any, ...], dict[str, Any]]
613 ] = []
615 def register(
616 self, callback: Callable[..., None], *args: Any, **kwargs: Any
617 ) -> None:
618 self._callbacks.append((callback, args, kwargs))
620 def close(self) -> None:
621 self._callbacks.clear()
623 def start(self) -> asyncio.TimerHandle | None:
624 timeout = self._timeout
625 if timeout is not None and timeout > 0:
626 when = self._loop.time() + timeout
627 if timeout >= self._ceil_threshold:
628 when = ceil(when)
629 return self._loop.call_at(when, self.__call__)
630 else:
631 return None
633 def timer(self) -> "BaseTimerContext":
634 if self._timeout is not None and self._timeout > 0:
635 timer = TimerContext(self._loop)
636 self.register(timer.timeout)
637 return timer
638 else:
639 return TimerNoop()
641 def __call__(self) -> None:
642 for cb, args, kwargs in self._callbacks:
643 with suppress(Exception):
644 cb(*args, **kwargs)
646 self._callbacks.clear()
class BaseTimerContext(ContextManager["BaseTimerContext"]):
    """Common interface of the timer contexts returned by TimeoutHandle.timer()."""

    __slots__ = ()

    def assert_timeout(self) -> None:
        """Raise TimeoutError if timeout has been exceeded.

        The base implementation never raises.
        """
class TimerNoop(BaseTimerContext):
    """Timer context that never times out; entering and exiting do nothing."""

    __slots__ = ()

    def __enter__(self) -> BaseTimerContext:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # Falsy return: any exception raised inside the block propagates.
        return None
class TimerContext(BaseTimerContext):
    """Low resolution timeout context manager.

    ``__enter__`` records the current task. ``timeout()`` cancels every
    recorded task. ``__exit__`` turns the resulting CancelledError into
    ``asyncio.TimeoutError``.
    """

    __slots__ = ("_loop", "_tasks", "_cancelled", "_cancelling")

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        # Tasks that entered and have not exited yet (innermost last).
        self._tasks: list[asyncio.Task[Any]] = []
        # Set once timeout() has fired.
        self._cancelled = False
        # Task.cancelling() value captured by the latest __enter__ (3.11+).
        self._cancelling = 0

    def assert_timeout(self) -> None:
        """Raise asyncio.TimeoutError if timeout() has already fired."""
        if self._cancelled:
            raise asyncio.TimeoutError from None

    def __enter__(self) -> BaseTimerContext:
        current = asyncio.current_task(loop=self._loop)
        if current is None:
            raise RuntimeError("Timeout context manager should be used inside a task")

        if sys.version_info >= (3, 11):
            # Remember whether the task was already being cancelled, so that
            # __exit__ can let such a cancellation propagate instead of
            # converting it into asyncio.TimeoutError.
            self._cancelling = current.cancelling()

        if self._cancelled:
            raise asyncio.TimeoutError from None

        self._tasks.append(current)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        enter_task: asyncio.Task[Any] | None = (
            self._tasks.pop() if self._tasks else None
        )

        timed_out = exc_type is asyncio.CancelledError and self._cancelled
        if not timed_out:
            return None

        assert enter_task is not None
        # Our own cancel() must not leak out of the context manager, so it is
        # undone on the task that entered last.
        if sys.version_info >= (3, 11):
            if enter_task.uncancel() > self._cancelling:
                # Another cancellation was already pending: let it propagate.
                return None
        raise asyncio.TimeoutError from exc_val

    def timeout(self) -> None:
        """Cancel every task currently inside the context (only once)."""
        if self._cancelled:
            return
        for task in set(self._tasks):
            task.cancel()
        self._cancelled = True
738def ceil_timeout(
739 delay: float | None, ceil_threshold: float = 5
740) -> async_timeout.Timeout:
741 if delay is None or delay <= 0:
742 return async_timeout.timeout(None)
744 loop = asyncio.get_running_loop()
745 now = loop.time()
746 when = now + delay
747 if delay > ceil_threshold:
748 when = ceil(when)
749 return async_timeout.timeout_at(when)
class HeadersMixin:
    """Mixin exposing parsed Content-Type / Content-Length from ``_headers``.

    The raw Content-Type value last parsed is remembered in
    ``_stored_content_type``. Parsing is redone only when the header
    value changes.
    """

    _headers: MultiMapping[str]
    _content_type: str | None = None
    _content_dict: dict[str, str] | None = None
    _stored_content_type: str | None | _SENTINEL = sentinel

    def _parse_content_type(self, raw: str | None) -> None:
        self._stored_content_type = raw
        if raw is not None:
            ctype, params = parse_content_type(raw)
            self._content_type = ctype
            # parse_content_type returns a read-only (cached) mapping;
            # _content_dict needs to be a mutable copy.
            self._content_dict = params.copy()
        else:
            # default value according to RFC 2616
            self._content_type = "application/octet-stream"
            self._content_dict = {}

    @property
    def content_type(self) -> str:
        """The value of content part for Content-Type HTTP header."""
        raw = self._headers.get(hdrs.CONTENT_TYPE)
        if self._stored_content_type != raw:
            self._parse_content_type(raw)
        ctype = self._content_type
        assert ctype is not None
        return ctype

    @property
    def charset(self) -> str | None:
        """The value of charset part for Content-Type HTTP header."""
        raw = self._headers.get(hdrs.CONTENT_TYPE)
        if self._stored_content_type != raw:
            self._parse_content_type(raw)
        params = self._content_dict
        assert params is not None
        return params.get("charset")

    @property
    def content_length(self) -> int | None:
        """The Content-Length header as an int, or None when absent.

        A non-integer header value makes ``int()`` raise ValueError.
        """
        raw_length = self._headers.get(hdrs.CONTENT_LENGTH)
        if raw_length is None:
            return None
        return int(raw_length)
def set_result(fut: "asyncio.Future[_T]", result: _T) -> None:
    """Set *result* on *fut* unless it is already done (result, error or cancel)."""
    if fut.done():
        return
    fut.set_result(result)
802_EXC_SENTINEL = BaseException()
805class ErrorableProtocol(Protocol):
806 def set_exception(
807 self,
808 exc: type[BaseException] | BaseException,
809 exc_cause: BaseException = ...,
810 ) -> None: ...
813def set_exception(
814 fut: Union["asyncio.Future[_T]", ErrorableProtocol],
815 exc: type[BaseException] | BaseException,
816 exc_cause: BaseException = _EXC_SENTINEL,
817) -> None:
818 """Set future exception.
820 If the future is marked as complete, this function is a no-op.
822 :param exc_cause: An exception that is a direct cause of ``exc``.
823 Only set if provided.
824 """
825 if asyncio.isfuture(fut) and fut.done():
826 return
828 exc_is_sentinel = exc_cause is _EXC_SENTINEL
829 exc_causes_itself = exc is exc_cause
830 if not exc_is_sentinel and not exc_causes_itself:
831 exc.__cause__ = exc_cause
833 fut.set_exception(exc)
@functools.total_ordering
class BaseKey(Generic[_T]):
    """Base for concrete context storage key classes.

    Each storage is provided with its own sub-class for the sake of some additional type safety.
    A key's name is prefixed with the ``__name__`` of the module whose
    top-level code created it. Keys sort by that full name and come before
    any non-key object.
    """

    __slots__ = ("_name", "_t", "__orig_class__")

    # This may be set by Python when instantiating with a generic type. We need to
    # support this, in order to support types that are not concrete classes,
    # like Iterable, which can't be passed as the second parameter to __init__.
    __orig_class__: type[object]

    # TODO(PY314): Change Type to TypeForm (this should resolve unreachable below).
    def __init__(self, name: str, t: type[_T] | None = None):
        # Walk outwards to the nearest module-level frame to find which module
        # is creating the key; the prefix helps deduplicate key names.
        frame = inspect.currentframe()
        while frame is not None and frame.f_code.co_name != "<module>":
            frame = frame.f_back
        if frame is None:
            raise RuntimeError("Failed to get module name.")
        module: str = frame.f_globals["__name__"]
        del frame  # drop the frame reference promptly

        self._name = module + "." + name
        self._t = t

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, BaseKey):
            return True  # Order BaseKey above other types.
        return self._name < other._name

    def __repr__(self) -> str:
        type_arg = self._t
        if type_arg is None:
            with suppress(AttributeError):
                # Fall back to the subscripted type argument, e.g. AppKey[int].
                type_arg = get_args(self.__orig_class__)[0]

        if type_arg is None:
            type_desc = "<<Unknown>>"
        elif not isinstance(type_arg, type):
            type_desc = repr(type_arg)  # type: ignore[unreachable]
        elif type_arg.__module__ == "builtins":
            type_desc = type_arg.__qualname__
        else:
            type_desc = f"{type_arg.__module__}.{type_arg.__qualname__}"
        return f"<{self.__class__.__name__}({self._name}, type={type_desc})>"
class AppKey(BaseKey[_T]):
    """Typed key for values stored on an Application (static typing support)."""


class RequestKey(BaseKey[_T]):
    """Typed key for values stored on a Request (static typing support)."""


class ResponseKey(BaseKey[_T]):
    """Typed key for values stored on a Response (static typing support)."""
@final
class ChainMapProxy(Mapping[str | AppKey[Any], Any]):
    """Read-only view that looks keys up across several mappings in order.

    The first mapping containing a key supplies its value. Subclassing is
    forbidden.
    """

    __slots__ = ("_maps",)

    def __init__(self, maps: Iterable[Mapping[str | AppKey[Any], Any]]) -> None:
        self._maps = tuple(maps)

    def __init_subclass__(cls) -> None:
        raise TypeError(
            f"Inheritance class {cls.__name__} from ChainMapProxy is forbidden"
        )

    @overload  # type: ignore[override]
    def __getitem__(self, key: AppKey[_T]) -> _T: ...

    @overload
    def __getitem__(self, key: str) -> Any: ...

    def __getitem__(self, key: str | AppKey[_T]) -> Any:
        for candidate in self._maps:
            try:
                return candidate[key]
            except KeyError:
                continue
        raise KeyError(key)

    @overload  # type: ignore[override]
    def get(self, key: AppKey[_T], default: _S) -> _T | _S: ...

    @overload
    def get(self, key: AppKey[_T], default: None = ...) -> _T | None: ...

    @overload
    def get(self, key: str, default: Any = ...) -> Any: ...

    def get(self, key: str | AppKey[_T], default: Any = None) -> Any:
        try:
            return self[key]
        except KeyError:
            return default

    def __len__(self) -> int:
        # Number of distinct keys; set.union reuses stored hash values if possible.
        return len(set().union(*self._maps))

    def __iter__(self) -> Iterator[str | AppKey[Any]]:
        # Merge back-to-front into a dict so each key is yielded exactly once;
        # update() reuses stored hash values if possible.
        merged: dict[str | AppKey[Any], Any] = {}
        for mapping in reversed(self._maps):
            merged.update(mapping)
        return iter(merged)

    def __contains__(self, key: object) -> bool:
        return any(key in candidate for candidate in self._maps)

    def __bool__(self) -> bool:
        # True if at least one underlying mapping is non-empty.
        return any(self._maps)

    def __repr__(self) -> str:
        content = ", ".join(repr(m) for m in self._maps)
        return f"ChainMapProxy({content})"
class CookieMixin:
    """Mixin for handling cookies (a lazily created ``SimpleCookie`` jar)."""

    _cookies: SimpleCookie | None = None

    @property
    def cookies(self) -> SimpleCookie:
        """Return the cookie jar, creating an empty one on first access."""
        if self._cookies is None:
            self._cookies = SimpleCookie()
        return self._cookies

    def set_cookie(
        self,
        name: str,
        value: str,
        *,
        expires: str | None = None,
        domain: str | None = None,
        max_age: int | str | None = None,
        path: str = "/",
        secure: bool | None = None,
        httponly: bool | None = None,
        samesite: str | None = None,
        partitioned: bool | None = None,
    ) -> None:
        """Set or update response cookie.

        Sets a new cookie or updates an existing one with a new value.
        Attributes passed as None are left unchanged, with these exceptions:
        - ``path`` is always written.
        - ``max-age`` is removed when *max_age* is None.
        - ``expires`` is removed when *expires* is None and it still holds
          the epoch value written by ``del_cookie``.

        In DEBUG mode, a UserWarning is emitted when the serialized cookie
        exceeds COOKIE_MAX_LENGTH characters.
        """
        if self._cookies is None:
            self._cookies = SimpleCookie()

        self._cookies[name] = value
        c = self._cookies[name]

        if expires is not None:
            c["expires"] = expires
        elif c.get("expires") == "Thu, 01 Jan 1970 00:00:00 GMT":
            # Re-setting a previously deleted cookie must not keep the
            # epoch expiry left by del_cookie().
            del c["expires"]

        if domain is not None:
            c["domain"] = domain

        if max_age is not None:
            c["max-age"] = str(max_age)
        elif "max-age" in c:
            del c["max-age"]

        c["path"] = path

        if secure is not None:
            c["secure"] = secure
        if httponly is not None:
            c["httponly"] = httponly
        if samesite is not None:
            c["samesite"] = samesite

        if partitioned is not None:
            c["partitioned"] = partitioned

        if DEBUG:
            cookie_length = len(c.output(header="")[1:])
            if cookie_length > COOKIE_MAX_LENGTH:
                # Name the offending cookie; the old message read
                # "The size of is too large" with the subject missing.
                warnings.warn(
                    f"The size of cookie {name!r} is too large, "
                    "it might get ignored by the client.",
                    UserWarning,
                    stacklevel=2,
                )

    def del_cookie(
        self,
        name: str,
        *,
        domain: str | None = None,
        path: str = "/",
        secure: bool | None = None,
        httponly: bool | None = None,
        samesite: str | None = None,
    ) -> None:
        """Delete cookie.

        Replaces any existing entry with an empty cookie that has
        ``max-age=0`` and an expiry at the Unix epoch.
        """
        # TODO: do we need domain/path here?
        if self._cookies is not None:
            self._cookies.pop(name, None)
        self.set_cookie(
            name,
            "",
            max_age=0,
            expires="Thu, 01 Jan 1970 00:00:00 GMT",
            domain=domain,
            path=path,
            secure=secure,
            httponly=httponly,
            samesite=samesite,
        )
def populate_with_cookies(headers: "CIMultiDict[str]", cookies: SimpleCookie) -> None:
    """Append one Set-Cookie header to *headers* per morsel in *cookies*."""
    for morsel in cookies.values():
        # output(header="") yields " name=value; ..."; drop the leading space.
        headers.add(hdrs.SET_COOKIE, morsel.output(header="")[1:])
# Entity-tag grammar: https://tools.ietf.org/html/rfc7232#section-2.3
# etagc = "!" / %x23-7E / %x80-FF
_ETAGC = r"[!\x23-\x7E\x80-\xff]+"
_ETAGC_RE = re.compile(_ETAGC)

# Group 1: optional weakness prefix "W/"; group 2: the opaque tag.
_QUOTED_ETAG = rf'(W/)?"({_ETAGC})"'
QUOTED_ETAG_RE = re.compile(_QUOTED_ETAG)

# Comma-separated entity-tag list. The trailing "(.)" alternative matches any
# single character that does not start a quoted entity-tag.
LIST_QUOTED_ETAG_RE = re.compile(rf"({_QUOTED_ETAG})(?:\s*,\s*|$)|(.)")

# Wildcard entity-tag value.
ETAG_ANY = "*"
@frozen_dataclass_decorator
class ETag:
    """An entity-tag value plus whether it is weak (``W/`` prefixed)."""

    value: str
    is_weak: bool = False
def validate_etag_value(value: str) -> None:
    """Check that *value* is ``*`` or consists solely of etag characters.

    :raises ValueError: otherwise (for example, if it contains a double quote).
    """
    if value == ETAG_ANY or _ETAGC_RE.fullmatch(value):
        return
    raise ValueError(
        f"Value {value!r} is not a valid etag. Maybe it contains '\"'?"
    )
1094def parse_http_date(date_str: str | None) -> datetime.datetime | None:
1095 """Process a date string, return a datetime object"""
1096 if date_str is not None:
1097 timetuple = parsedate(date_str)
1098 if timetuple is not None:
1099 with suppress(ValueError):
1100 return datetime.datetime(*timetuple[:6], tzinfo=datetime.timezone.utc)
1101 return None
@functools.lru_cache
def must_be_empty_body(method: str, code: int) -> bool:
    """Tell whether a response to *method* with status *code* must have no body.

    True for 1xx/204/304 statuses, for methods in EMPTY_BODY_METHODS, and
    for 2xx responses to CONNECT.
    """
    if code in EMPTY_BODY_STATUS_CODES:
        return True
    if method in EMPTY_BODY_METHODS:
        return True
    return 200 <= code < 300 and method in hdrs.METH_CONNECT_ALL
def should_remove_content_length(method: str, code: int) -> bool:
    """Tell whether the Content-Length header should be dropped.

    This should always be a subset of must_be_empty_body. Unlike that
    function, it ignores HEAD: only 1xx/204/304 statuses and 2xx responses
    to CONNECT qualify.
    """
    # https://www.rfc-editor.org/rfc/rfc9110.html#section-8.6-8
    # https://www.rfc-editor.org/rfc/rfc9110.html#section-15.4.5-4
    if code in EMPTY_BODY_STATUS_CODES:
        return True
    return 200 <= code < 300 and method in hdrs.METH_CONNECT_ALL