Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/aiohttp/helpers.py: 43%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Various helper functions"""
3import asyncio
4import base64
5import binascii
6import contextlib
7import datetime
8import enum
9import functools
10import inspect
11import netrc
12import os
13import platform
14import re
15import sys
16import time
17import warnings
18import weakref
19from collections import namedtuple
20from collections.abc import Callable, Generator, Iterable, Iterator, Mapping
21from contextlib import suppress
22from email.message import EmailMessage
23from email.parser import HeaderParser
24from email.policy import HTTP
25from email.utils import parsedate
26from math import ceil
27from pathlib import Path
28from types import MappingProxyType, TracebackType
29from typing import (
30 Any,
31 ContextManager,
32 Generic,
33 Optional,
34 Protocol,
35 TypeVar,
36 get_args,
37 overload,
38)
39from urllib.parse import quote
40from urllib.request import getproxies, proxy_bypass
42import attr
43from multidict import MultiDict, MultiDictProxy, MultiMapping
44from propcache.api import under_cached_property as reify
45from yarl import URL
47from . import hdrs
48from .log import client_logger
50if sys.version_info >= (3, 11):
51 import asyncio as async_timeout
52else:
53 import async_timeout
__all__ = ("BasicAuth", "ChainMapProxy", "ETag", "reify")

# Host platform flags, computed once at import time.
IS_MACOS = platform.system() == "Darwin"
IS_WINDOWS = platform.system() == "Windows"

PY_311 = sys.version_info[:2] >= (3, 11)

# Default size/limit shared by several operations; it matches the largest
# amount asyncio receives from a socket in one read:
# https://github.com/python/cpython/blob/1857a40807daeae3a1bf5efb682de9c9ae6df845/Lib/asyncio/selector_events.py#L766
DEFAULT_CHUNK_SIZE = 256 * 1024  # 2**18 bytes == 256 KiB

_T = TypeVar("_T")
_S = TypeVar("_S")

# Single-member enum used as a "no value" marker that is distinct from None.
_SENTINEL = enum.Enum("_SENTINEL", "sentinel")
sentinel = _SENTINEL["sentinel"]

# Any non-empty AIOHTTP_NO_EXTENSIONS value switches the flag on.
NO_EXTENSIONS = os.environ.get("AIOHTTP_NO_EXTENSIONS", "") != ""

# Responses with these status codes never carry a body: every 1xx, 204 and 304.
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
EMPTY_BODY_STATUS_CODES = frozenset(range(100, 200)) | frozenset((204, 304))
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.2
EMPTY_BODY_METHODS = hdrs.METH_HEAD_ALL

# Debug mode is on under ``python -X dev``. Otherwise it is on when
# PYTHONASYNCIODEBUG is non-empty, unless ``-E`` makes Python ignore the
# environment.
if sys.flags.dev_mode:
    DEBUG = True
else:
    DEBUG = not sys.flags.ignore_environment and bool(
        os.environ.get("PYTHONASYNCIODEBUG")
    )
# Character classes for HTTP tokens (RFC 9110 section 5.6.2; originally
# RFC 2616 section 2.2).
CHAR = {chr(i) for i in range(0, 128)}  # every US-ASCII character
CTL = {chr(i) for i in range(0, 32)} | {
    chr(127),
}  # control characters, including HTAB (chr(9)) and DEL
SEPARATORS = {
    "(",
    ")",
    "<",
    ">",
    "@",
    ",",
    ";",
    ":",
    "\\",
    '"',
    "/",
    "[",
    "]",
    "?",
    "=",
    "{",
    "}",
    " ",
    chr(9),
}
# token characters = CHAR minus CTLs minus separators. This must be a set
# *difference*. A symmetric difference (CHAR ^ CTL ^ SEPARATORS) would put
# HTAB back in, because HTAB is both a CTL and a separator.
TOKEN = CHAR - CTL - SEPARATORS
class noop:
    """Awaitable that suspends exactly once (a bare yield) and then completes with None."""

    def __await__(self) -> Generator[None, None, None]:
        yield None
def encode_basic_auth(login: str, password: str = "", encoding: str = "utf-8") -> str:
    """Encode HTTP Basic Authentication credentials as an Authorization header value.

    Returns a string of the form ``"Basic <base64>"`` suitable for use as the
    value of the ``Authorization`` (or ``Proxy-Authorization``) header.

    :param login: user name; must not contain ``":"``.
    :param password: password, may be empty.
    :param encoding: charset used to turn ``"login:password"`` into bytes
        before base64-encoding them.
    :raises ValueError: if ``login`` contains ``":"``.
    """
    if ":" in login:
        raise ValueError('A ":" is not allowed in login (RFC 7617#section-2)')
    creds = f"{login}:{password}".encode(encoding)
    # base64 output is always plain ASCII. Decoding it with the credential
    # charset would garble the token for charsets that are not
    # ASCII-compatible, such as UTF-16.
    return "Basic " + base64.b64encode(creds).decode("ascii")
class BasicAuth(namedtuple("BasicAuth", ["login", "password", "encoding"])):
    """Http basic authentication helper.

    An immutable ``(login, password, encoding)`` triple. Constructing it
    directly emits a DeprecationWarning; ``aiohttp.encode_basic_auth()`` is
    the suggested replacement.
    """

    def __new__(
        cls, login: str, password: str = "", encoding: str = "latin1"
    ) -> "BasicAuth":
        for value, field_name in ((login, "login"), (password, "password")):
            if value is None:
                raise ValueError(f"None is not allowed as {field_name} value")

        if ":" in login:
            raise ValueError('A ":" is not allowed in login (RFC 1945#section-11.1)')

        deprecation_message = (
            "BasicAuth is deprecated and will be removed in aiohttp 4.0; "
            "use aiohttp.encode_basic_auth() with "
            "headers={'Authorization': ...} instead"
        )
        warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
        return super().__new__(cls, login, password, encoding)

    @classmethod
    def decode(cls, auth_header: str, encoding: str = "latin1") -> "BasicAuth":
        """Create a BasicAuth object from an Authorization HTTP header.

        :raises ValueError: on a malformed header, a non-Basic scheme,
            invalid base64, or decoded credentials without a ``":"``.
        """
        try:
            scheme, payload = auth_header.split(" ", 1)
        except ValueError:
            raise ValueError("Could not parse authorization header.")

        if scheme.lower() != "basic":
            raise ValueError(f"Unknown authorization method {scheme}")

        try:
            raw_bytes = base64.b64decode(payload.encode("ascii"), validate=True)
        except binascii.Error:
            raise ValueError("Invalid base64 encoding.")
        decoded = raw_bytes.decode(encoding)

        # RFC 2617 HTTP Authentication (https://www.ietf.org/rfc/rfc2617.txt):
        # the colon is mandatory, while user name and password may be empty.
        try:
            username, password = decoded.split(":", 1)
        except ValueError:
            raise ValueError("Invalid credentials.")

        return _basic_auth_no_warn(username, password, encoding)

    @classmethod
    def from_url(cls, url: "URL", *, encoding: str = "latin1") -> Optional["BasicAuth"]:
        """Create BasicAuth from url, or return None if it carries no credentials."""
        if not isinstance(url, URL):
            raise TypeError("url should be yarl.URL instance")
        # The raw_* attributes are checked first because yarl has most likely
        # already parsed them from the netloc and cached them.
        has_credentials = url.raw_user is not None or url.raw_password is not None
        if not has_credentials:
            return None
        return _basic_auth_no_warn(url.user or "", url.password or "", encoding)

    def encode(self) -> str:
        """Encode credentials as an ``Authorization`` header value."""
        return encode_basic_auth(self.login, self.password, self.encoding)
def _basic_auth_no_warn(
    login: str, password: str = "", encoding: str = "latin1"
) -> BasicAuth:
    """Build a BasicAuth for aiohttp's internal use without a DeprecationWarning.

    Calling ``tuple.__new__`` directly fills the namedtuple fields and skips
    ``BasicAuth.__new__``, which holds both the argument validation and the
    deprecation warning.
    """
    fields = (login, password, encoding)
    return tuple.__new__(BasicAuth, fields)
def strip_auth_from_url(url: URL) -> tuple[URL, BasicAuth | None]:
    """Split credentials off ``url``.

    Returns ``(url_without_userinfo, BasicAuth)``, or ``(url, None)`` when the
    URL has neither a user nor a password.
    """
    # The raw_* attributes are checked first because yarl has most likely
    # already parsed them from the netloc and cached them.
    if url.raw_user is not None or url.raw_password is not None:
        bare_url = url.with_user(None)
        credentials = _basic_auth_no_warn(url.user or "", url.password or "")
        return bare_url, credentials
    return url, None
220def netrc_from_env() -> netrc.netrc | None:
221 """Load netrc from file.
223 Attempt to load it from the path specified by the env-var
224 NETRC or in the default location in the user's home directory.
226 Returns None if it couldn't be found or fails to parse.
227 """
228 netrc_env = os.environ.get("NETRC")
230 if netrc_env is not None:
231 netrc_path = Path(netrc_env)
232 else:
233 try:
234 home_dir = Path.home()
235 except RuntimeError as e: # pragma: no cover
236 # if pathlib can't resolve home, it may raise a RuntimeError
237 client_logger.debug(
238 "Could not resolve home directory when "
239 "trying to look for .netrc file: %s",
240 e,
241 )
242 return None
244 netrc_path = home_dir / ("_netrc" if IS_WINDOWS else ".netrc")
246 try:
247 return netrc.netrc(str(netrc_path))
248 except netrc.NetrcParseError as e:
249 client_logger.warning("Could not parse .netrc file: %s", e)
250 except OSError as e:
251 netrc_exists = False
252 with contextlib.suppress(OSError):
253 netrc_exists = netrc_path.is_file()
254 # we couldn't read the file (doesn't exist, permissions, etc.)
255 if netrc_env or netrc_exists:
256 # only warn if the environment wanted us to load it,
257 # or it appears like the default file does actually exist
258 client_logger.warning("Could not read .netrc file: %s", e)
260 return None
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ProxyInfo:
    """A proxy URL paired with optional BasicAuth credentials for it."""

    proxy: URL
    proxy_auth: BasicAuth | None
269def basicauth_from_netrc(netrc_obj: netrc.netrc | None, host: str) -> BasicAuth:
270 """
271 Return :py:class:`~aiohttp.BasicAuth` credentials for ``host`` from ``netrc_obj``.
273 :raises LookupError: if ``netrc_obj`` is :py:data:`None` or if no
274 entry is found for the ``host``.
275 """
276 if netrc_obj is None:
277 raise LookupError("No .netrc file found")
278 auth_from_netrc = netrc_obj.authenticators(host)
280 if auth_from_netrc is None:
281 raise LookupError(f"No entry for {host!s} found in the `.netrc` file.")
282 login, account, password = auth_from_netrc
284 # TODO(PY311): username = login or account
285 # Up to python 3.10, account could be None if not specified,
286 # and login will be empty string if not specified. From 3.11,
287 # login and account will be empty string if not specified.
288 username = login if (login or account is None) else account
290 # TODO(PY311): Remove this, as password will be empty string
291 # if not specified
292 if password is None:
293 password = ""
295 return _basic_auth_no_warn(username, password)
def proxies_from_env() -> dict[str, ProxyInfo]:
    """Collect http/https/ws/wss proxy settings from the environment.

    Proxy URLs come from ``urllib.request.getproxies()``. Proxies whose own
    scheme is https or wss are skipped with a warning. Credentials are taken
    from the proxy URL. If it has none, they are looked up in .netrc by the
    proxy host.
    """
    wanted = ("http", "https", "ws", "wss")
    proxy_urls = {
        scheme: URL(raw_url)
        for scheme, raw_url in getproxies().items()
        if scheme in wanted
    }
    netrc_obj = netrc_from_env()
    split_urls = {
        scheme: strip_auth_from_url(proxy_url)
        for scheme, proxy_url in proxy_urls.items()
    }

    result = {}
    for scheme, (proxy, auth) in split_urls.items():
        if proxy.scheme in ("https", "wss"):
            client_logger.warning(
                "%s proxies %s are not supported, ignoring", proxy.scheme.upper(), proxy
            )
            continue
        if netrc_obj and auth is None and proxy.host is not None:
            # A missing .netrc entry simply leaves the proxy unauthenticated.
            with suppress(LookupError):
                auth = basicauth_from_netrc(netrc_obj, proxy.host)
        result[scheme] = ProxyInfo(proxy, auth)
    return result
def get_env_proxy_for_url(url: URL) -> tuple[URL, BasicAuth | None]:
    """Get a permitted proxy for the given URL from the env.

    :raises LookupError: if the host is excluded from proxying (see
        ``urllib.request.proxy_bypass``), or if no proxy is configured for
        the URL's scheme.
    """
    target_host = url.host
    if target_host is not None and proxy_bypass(target_host):
        raise LookupError(f"Proxying is disallowed for `{target_host!r}`")

    env_proxies = proxies_from_env()
    try:
        info = env_proxies[url.scheme]
    except KeyError:
        raise LookupError(f"No proxies found for `{url!s}` in the env")
    return info.proxy, info.proxy_auth
@attr.s(slots=True, frozen=True, auto_attribs=True)
class MimeType:
    """Parts of a parsed MIME type ``type/subtype+suffix`` and its parameters."""

    type: str
    subtype: str
    suffix: str
    parameters: "MultiDictProxy[str]"
@functools.lru_cache(maxsize=56)
def parse_mimetype(mimetype: str) -> MimeType:
    """Parses a MIME type into its components.

    ``mimetype`` is a MIME type string such as ``'text/html; charset=utf-8'``.
    Type, subtype and suffix are lower-cased. Parameter names are lower-cased
    and stripped, and parameter values have surrounding spaces and double
    quotes removed. A bare ``*`` is read as ``*/*``. An empty string gives a
    MimeType with all parts empty.

    Example:

    >>> mt = parse_mimetype('text/html; charset=utf-8')
    >>> (mt.type, mt.subtype, mt.suffix, mt.parameters['charset'])
    ('text', 'html', '', 'utf-8')
    """
    if not mimetype:
        return MimeType(
            type="", subtype="", suffix="", parameters=MultiDictProxy(MultiDict())
        )

    head, *param_items = mimetype.split(";")
    params: MultiDict[str] = MultiDict()
    for item in param_items:
        if item:
            name, _, raw_value = item.partition("=")
            params.add(name.lower().strip(), raw_value.strip(' "'))

    fulltype = head.strip().lower()
    if fulltype == "*":
        fulltype = "*/*"

    mtype, _, remainder = fulltype.partition("/")
    stype, _, suffix = remainder.partition("+")

    return MimeType(
        type=mtype, subtype=stype, suffix=suffix, parameters=MultiDictProxy(params)
    )
class EnsureOctetStream(EmailMessage):
    """EmailMessage whose fallback content type is application/octet-stream."""

    def __init__(self) -> None:
        super().__init__()
        # https://www.rfc-editor.org/rfc/rfc9110#section-8.3-5
        self.set_default_type("application/octet-stream")

    def get_content_type(self) -> str:
        """Re-implementation from Message.

        Returns the lower-cased ``type/subtype`` from the Content-Type header.
        Falls back to the default type (application/octet-stream, not
        text/plain) when that value does not contain exactly one ``/``.

        This class is only used with a Content-Type header present, so the
        checks are simpler than in the base implementation.
        """
        header_value = self.get("content-type", "").lower()
        # Modelled on the standard library's _splitparam.
        media_type = header_value.partition(";")[0].strip()
        return media_type if media_type.count("/") == 1 else self.get_default_type()
@functools.lru_cache(maxsize=56)
def parse_content_type(raw: str) -> tuple[str, MappingProxyType[str, str]]:
    """Parse Content-Type header.

    Returns ``(content_type, params)``. ``params`` is a read-only
    MappingProxyType of the header parameters. If the media type is invalid,
    ``application/octet-stream`` is returned as the content type.
    """
    parser = HeaderParser(EnsureOctetStream, policy=HTTP)
    message = parser.parsestr(f"Content-Type: {raw}")
    media_type = message.get_content_type()
    # get_params() repeats the content type as its first entry; skip it.
    param_pairs = message.get_params(())[1:]
    return media_type, MappingProxyType(dict(param_pairs))
426def guess_filename(obj: Any, default: str | None = None) -> str | None:
427 name = getattr(obj, "name", None)
428 if name and isinstance(name, str) and name[0] != "<" and name[-1] != ">":
429 return Path(name).name
430 return default
# Any character that is not qtext (%x21 / %x23-5B / %x5D-7E) has to be
# backslash-escaped inside a quoted-string.
not_qtext_re = re.compile(r"[^\041\043-\133\135-\176]")
# Characters allowed in quoted-string content: printable ASCII plus HTAB.
QCONTENT = {chr(i) for i in range(0x20, 0x7F)} | {"\t"}


def quoted_string(content: str) -> str:
    """Return 7-bit content as quoted-string.

    Format content into a quoted-string as defined in RFC5322 for
    Internet Message Format. Notice that this is not the 8-bit HTTP
    format, but the 7-bit email format. Content must be in usascii or
    a ValueError is raised.

    Characters outside qtext (space, tab, ``"`` and ``\\``) are escaped
    with a preceding backslash. The surrounding double quotes are not added.

    :raises ValueError: if ``content`` has a character outside QCONTENT.
    """
    # Subset test. A proper-superset test (QCONTENT > set(content)) would
    # wrongly reject content that uses every allowed character.
    if not set(content) <= QCONTENT:
        raise ValueError(f"bad content for quoted-string {content!r}")
    return not_qtext_re.sub(lambda x: "\\" + x.group(0), content)
def content_disposition_header(
    disptype: str, quote_fields: bool = True, _charset: str = "utf-8", **params: str
) -> str:
    """Sets ``Content-Disposition`` header for MIME.

    This is the MIME payload Content-Disposition header from RFC 2183
    and RFC 7578 section 4.2, not the HTTP Content-Disposition from
    RFC 6266.

    disptype is a disposition type: inline, attachment, form-data.
    Should be valid extension token (see RFC 2183)

    quote_fields performs value quoting to 7-bit MIME headers
    according to RFC 7578. Set quote_fields to False if the recipient
    can take 8-bit file names and field values.

    _charset specifies the charset to use when quote_fields is True.

    params is a dict with disposition params.

    :raises ValueError: if ``disptype`` or a parameter name is empty or
        contains a non-token character.
    """
    # A token must consist only of TOKEN characters. Use a subset test: a
    # proper-superset test (TOKEN > set(...)) would reject a value that
    # happens to use every token character.
    if not disptype or not set(disptype) <= TOKEN:
        raise ValueError(f"bad content disposition type {disptype!r}")

    value = disptype
    if params:
        lparams = []
        for key, val in params.items():
            if not key or not set(key) <= TOKEN:
                raise ValueError(f"bad content disposition parameter {key!r}={val!r}")
            if quote_fields:
                if key.lower() == "filename":
                    # Percent-encode the whole file name inside a quoted-string.
                    qval = quote(val, "", encoding=_charset)
                    lparams.append((key, '"%s"' % qval))
                else:
                    try:
                        qval = quoted_string(val)
                    except ValueError:
                        # Not representable as a 7-bit quoted-string: use the
                        # RFC 2231 extended form charset''percent-encoded.
                        qval = "".join(
                            (_charset, "''", quote(val, "", encoding=_charset))
                        )
                        lparams.append((key + "*", qval))
                    else:
                        lparams.append((key, '"%s"' % qval))
            else:
                # Only backslash and double quote are escaped in this mode.
                qval = val.replace("\\", "\\\\").replace('"', '\\"')
                lparams.append((key, '"%s"' % qval))
        sparams = "; ".join("=".join(pair) for pair in lparams)
        value = "; ".join((value, sparams))
    return value
501def is_ip_address(host: str | None) -> bool:
502 """Check if host looks like an IP Address.
504 This check is only meant as a heuristic to ensure that
505 a host is not a domain name.
506 """
507 if not host:
508 return False
509 # For a host to be an ipv4 address, it must be all numeric.
510 # The host must contain a colon to be an IPv6 address.
511 return ":" in host or host.replace(".", "").isdigit()
def is_canonical_ipv4_address(host: str) -> bool:
    """Check if host is a canonical dotted-quad IPv4 address.

    Rejects the legacy numeric forms that ``socket`` still accepts and
    maps onto an address, e.g. ``2130706433``, ``017700000001``, ``127.1``.
    """
    octets = host.split(".")
    if len(octets) != 4:
        return False
    # Each octet: 1-3 ASCII digits (str.isdigit alone would also accept other
    # Unicode digits), no leading zero unless the octet is "0", value <= 255.
    return all(
        0 < len(octet) <= 3
        and octet.isascii()
        and octet.isdigit()
        and (len(octet) == 1 or not octet.startswith("0"))
        and int(octet) <= 255
        for octet in octets
    )
536_cached_current_datetime: int | None = None
537_cached_formatted_datetime = ""
540def rfc822_formatted_time() -> str:
541 global _cached_current_datetime
542 global _cached_formatted_datetime
544 now = int(time.time())
545 if now != _cached_current_datetime:
546 # Weekday and month names for HTTP date/time formatting;
547 # always English!
548 # Tuples are constants stored in codeobject!
549 _weekdayname = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")
550 _monthname = (
551 "", # Dummy so we can use 1-based month numbers
552 "Jan",
553 "Feb",
554 "Mar",
555 "Apr",
556 "May",
557 "Jun",
558 "Jul",
559 "Aug",
560 "Sep",
561 "Oct",
562 "Nov",
563 "Dec",
564 )
566 year, month, day, hh, mm, ss, wd, *tail = time.gmtime(now)
567 _cached_formatted_datetime = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
568 _weekdayname[wd],
569 day,
570 _monthname[month],
571 year,
572 hh,
573 mm,
574 ss,
575 )
576 _cached_current_datetime = now
577 return _cached_formatted_datetime
def _weakref_handle(info: "tuple[weakref.ref[object], str]") -> None:
    """Call method ``name`` on the weakly referenced object, if it is still alive.

    ``info`` is ``(weak_reference, method_name)``. Any exception raised while
    looking up or calling the method is suppressed.
    """
    reference, method_name = info
    target = reference()
    if target is None:
        return
    with suppress(Exception):
        getattr(target, method_name)()
588def weakref_handle(
589 ob: object,
590 name: str,
591 timeout: float,
592 loop: asyncio.AbstractEventLoop,
593 timeout_ceil_threshold: float = 5,
594) -> asyncio.TimerHandle | None:
595 if timeout is not None and timeout > 0:
596 when = loop.time() + timeout
597 if timeout >= timeout_ceil_threshold:
598 when = ceil(when)
600 return loop.call_at(when, _weakref_handle, (weakref.ref(ob), name))
601 return None
604def call_later(
605 cb: Callable[[], Any],
606 timeout: float,
607 loop: asyncio.AbstractEventLoop,
608 timeout_ceil_threshold: float = 5,
609) -> asyncio.TimerHandle | None:
610 if timeout is None or timeout <= 0:
611 return None
612 now = loop.time()
613 when = calculate_timeout_when(now, timeout, timeout_ceil_threshold)
614 return loop.call_at(when, cb)
def calculate_timeout_when(
    loop_time: float,
    timeout: float,
    timeout_ceiling_threshold: float,
) -> float:
    """Calculate when to execute a timeout.

    Returns ``loop_time + timeout``. It is rounded up to a whole second when
    ``timeout`` is strictly greater than ``timeout_ceiling_threshold``.
    """
    deadline = loop_time + timeout
    return ceil(deadline) if timeout > timeout_ceiling_threshold else deadline
629class TimeoutHandle:
630 """Timeout handle"""
632 __slots__ = ("_timeout", "_loop", "_ceil_threshold", "_callbacks")
634 def __init__(
635 self,
636 loop: asyncio.AbstractEventLoop,
637 timeout: float | None,
638 ceil_threshold: float = 5,
639 ) -> None:
640 self._timeout = timeout
641 self._loop = loop
642 self._ceil_threshold = ceil_threshold
643 self._callbacks: list[
644 tuple[Callable[..., None], tuple[Any, ...], dict[str, Any]]
645 ] = []
647 def register(
648 self, callback: Callable[..., None], *args: Any, **kwargs: Any
649 ) -> None:
650 self._callbacks.append((callback, args, kwargs))
652 def close(self) -> None:
653 self._callbacks.clear()
655 def start(self) -> asyncio.TimerHandle | None:
656 timeout = self._timeout
657 if timeout is not None and timeout > 0:
658 when = self._loop.time() + timeout
659 if timeout >= self._ceil_threshold:
660 when = ceil(when)
661 return self._loop.call_at(when, self.__call__)
662 else:
663 return None
665 def timer(self) -> "BaseTimerContext":
666 if self._timeout is not None and self._timeout > 0:
667 timer = TimerContext(self._loop)
668 self.register(timer.timeout)
669 return timer
670 else:
671 return TimerNoop()
673 def __call__(self) -> None:
674 for cb, args, kwargs in self._callbacks:
675 with suppress(Exception):
676 cb(*args, **kwargs)
678 self._callbacks.clear()
class BaseTimerContext(ContextManager["BaseTimerContext"]):
    """Shared interface for timer context managers: adds ``assert_timeout()``."""

    __slots__ = ()

    def assert_timeout(self) -> None:
        """Raise TimeoutError if timeout has been exceeded."""
        return None
class TimerNoop(BaseTimerContext):
    """Timer context that never times out. Entering and exiting do nothing."""

    __slots__ = ()

    def __enter__(self) -> BaseTimerContext:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        return None
class TimerContext(BaseTimerContext):
    """Low resolution timeout context manager.

    Tasks that enter the context are recorded. ``timeout()`` cancels them,
    and ``__exit__`` turns the resulting CancelledError into
    ``asyncio.TimeoutError``.
    """

    __slots__ = ("_loop", "_tasks", "_cancelled", "_cancelling")

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        self._tasks: list[asyncio.Task[Any]] = []
        self._cancelled = False  # becomes True once timeout() has fired
        self._cancelling = 0  # task.cancelling() recorded on entry (3.11+)

    def assert_timeout(self) -> None:
        """Raise TimeoutError if timer has already been cancelled."""
        if not self._cancelled:
            return
        raise asyncio.TimeoutError from None

    def __enter__(self) -> BaseTimerContext:
        current = asyncio.current_task(loop=self._loop)
        if current is None:
            raise RuntimeError("Timeout context manager should be used inside a task")

        if sys.version_info >= (3, 11):
            # Record how many cancellations were already pending, so that
            # __exit__ can tell whether to raise asyncio.TimeoutError or to
            # let an outside cancellation propagate.
            self._cancelling = current.cancelling()

        if self._cancelled:
            raise asyncio.TimeoutError from None

        self._tasks.append(current)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        entered: asyncio.Task[Any] | None = self._tasks.pop() if self._tasks else None

        if exc_type is not asyncio.CancelledError or not self._cancelled:
            return None

        assert entered is not None
        # The timeout fired and cancelled the task that entered last. Undo
        # that cancellation so it does not leak out of the context manager.
        if sys.version_info >= (3, 11) and entered.uncancel() > self._cancelling:
            # The task was being cancelled from elsewhere as well: return None
            # so that the CancelledError propagates.
            return None
        raise asyncio.TimeoutError from exc_val

    def timeout(self) -> None:
        """Cancel every task currently inside the context (once only)."""
        if not self._cancelled:
            pending = set(self._tasks)
            for task in pending:
                task.cancel()
        self._cancelled = True
770def ceil_timeout(
771 delay: float | None, ceil_threshold: float = 5
772) -> async_timeout.Timeout:
773 if delay is None or delay <= 0:
774 return async_timeout.timeout(None)
776 loop = asyncio.get_running_loop()
777 now = loop.time()
778 when = now + delay
779 if delay > ceil_threshold:
780 when = ceil(when)
781 return async_timeout.timeout_at(when)
class HeadersMixin:
    """Mixin for handling headers.

    Adds ``content_type``, ``charset`` and ``content_length`` accessors on top
    of a ``_headers`` multi-mapping that the host class must provide. The
    Content-Type header is parsed again only when its raw value changes.
    """

    ATTRS = frozenset(["_content_type", "_content_dict", "_stored_content_type"])

    _headers: MultiMapping[str]
    _content_type: str | None = None
    _content_dict: dict[str, str] | None = None
    # Raw header value the cached fields were parsed from. ``sentinel`` means
    # "never parsed" and differs from both None and every string.
    _stored_content_type: str | None | _SENTINEL = sentinel

    def _parse_content_type(self, raw: str | None) -> None:
        self._stored_content_type = raw
        if raw is not None:
            media_type, params_view = parse_content_type(raw)
            self._content_type = media_type
            # Copy the read-only view into a plain dict so it can be updated.
            self._content_dict = params_view.copy()
            return
        # No Content-Type header: default to application/octet-stream.
        self._content_type = "application/octet-stream"
        self._content_dict = {}

    @property
    def content_type(self) -> str:
        """The value of content part for Content-Type HTTP header."""
        current = self._headers.get(hdrs.CONTENT_TYPE)
        if self._stored_content_type != current:
            self._parse_content_type(current)
        assert self._content_type is not None
        return self._content_type

    @property
    def charset(self) -> str | None:
        """The value of charset part for Content-Type HTTP header."""
        current = self._headers.get(hdrs.CONTENT_TYPE)
        if self._stored_content_type != current:
            self._parse_content_type(current)
        assert self._content_dict is not None
        return self._content_dict.get("charset")

    @property
    def content_length(self) -> int | None:
        """The value of Content-Length HTTP header, as int (None if absent)."""
        raw_length = self._headers.get(hdrs.CONTENT_LENGTH)
        if raw_length is None:
            return None
        return int(raw_length)
def set_result(fut: "asyncio.Future[_T]", result: _T) -> None:
    """Resolve ``fut`` with ``result``. Does nothing if it is already done or cancelled."""
    if fut.done():
        return
    fut.set_result(result)
# Private default marking "no exc_cause supplied" (None cannot be used, as a
# caller may legitimately pass it).
_EXC_SENTINEL = BaseException()


class ErrorableProtocol(Protocol):
    """Anything with a future-like ``set_exception(exc, exc_cause)`` method."""

    def set_exception(
        self,
        exc: BaseException,
        exc_cause: BaseException = ...,
    ) -> None: ...  # pragma: no cover


def set_exception(
    fut: "asyncio.Future[_T] | ErrorableProtocol",
    exc: BaseException,
    exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
    """Set future exception.

    If ``fut`` is an asyncio future that is already done, nothing happens.

    :param exc_cause: An exception that is a direct cause of ``exc``.
                      Only set if provided, and ignored when it is ``exc``
                      itself.
    """
    if asyncio.isfuture(fut) and fut.done():
        return

    cause_given = exc_cause is not _EXC_SENTINEL
    if cause_given and exc_cause is not exc:
        exc.__cause__ = exc_cause

    fut.set_exception(exc)
870@functools.total_ordering
871class BaseKey(Generic[_T]):
872 """Base for concrete context storage key classes.
874 Each storage is provided with its own sub-class for the sake of some additional type safety.
875 """
877 __slots__ = ("_name", "_t", "__orig_class__")
879 # This may be set by Python when instantiating with a generic type. We need to
880 # support this, in order to support types that are not concrete classes,
881 # like Iterable, which can't be passed as the second parameter to __init__.
882 __orig_class__: type[object]
884 def __init__(self, name: str, t: type[_T] | None = None):
885 # Prefix with module name to help deduplicate key names.
886 frame = inspect.currentframe()
887 while frame:
888 if frame.f_code.co_name == "<module>":
889 module: str = frame.f_globals["__name__"]
890 break
891 frame = frame.f_back
893 self._name = module + "." + name
894 self._t = t
896 def __lt__(self, other: object) -> bool:
897 if isinstance(other, BaseKey):
898 return self._name < other._name
899 return True # Order BaseKey above other types.
901 def __repr__(self) -> str:
902 t = self._t
903 if t is None:
904 with suppress(AttributeError):
905 # Set to type arg.
906 t = get_args(self.__orig_class__)[0]
908 if t is None:
909 t_repr = "<<Unknown>>"
910 elif isinstance(t, type):
911 if t.__module__ == "builtins":
912 t_repr = t.__qualname__
913 else:
914 t_repr = f"{t.__module__}.{t.__qualname__}"
915 else:
916 t_repr = repr(t)
917 return f"<{self.__class__.__name__}({self._name}, type={t_repr})>"
class AppKey(BaseKey[_T]):
    """Typed key for values stored in an Application (static type checking aid)."""


class RequestKey(BaseKey[_T]):
    """Typed key for values stored on a Request (static type checking aid)."""


class ResponseKey(BaseKey[_T]):
    """Typed key for values stored on a Response (static type checking aid)."""
class ChainMapProxy(Mapping[str | AppKey[Any], Any]):
    """Read-only view that searches several mappings in order.

    The first mapping holding a key supplies its value. Subclassing is
    forbidden.
    """

    __slots__ = ("_maps",)

    def __init__(self, maps: Iterable[Mapping[str | AppKey[Any], Any]]) -> None:
        self._maps = tuple(maps)

    def __init_subclass__(cls) -> None:
        raise TypeError(
            f"Inheritance class {cls.__name__} from ChainMapProxy is forbidden"
        )

    @overload  # type: ignore[override]
    def __getitem__(self, key: AppKey[_T]) -> _T: ...

    @overload
    def __getitem__(self, key: str) -> Any: ...

    def __getitem__(self, key: str | AppKey[_T]) -> Any:
        for mapping in self._maps:
            with suppress(KeyError):
                return mapping[key]
        raise KeyError(key)

    @overload  # type: ignore[override]
    def get(self, key: AppKey[_T], default: _S) -> _T | _S: ...

    @overload
    def get(self, key: AppKey[_T], default: None = ...) -> _T | None: ...

    @overload
    def get(self, key: str, default: Any = ...) -> Any: ...

    def get(self, key: str | AppKey[_T], default: Any = None) -> Any:
        try:
            value = self[key]
        except KeyError:
            value = default
        return value

    def __len__(self) -> int:
        # Collecting keys into a set reuses their stored hash values if possible.
        distinct: set[str | AppKey[Any]] = set()
        distinct.update(*self._maps)
        return len(distinct)

    def __iter__(self) -> Iterator[str | AppKey[Any]]:
        # Fill from the last mapping to the first so that key order matches
        # the lookup priority; this also reuses stored hash values if possible.
        merged: dict[str | AppKey[Any], Any] = {}
        for mapping in reversed(self._maps):
            merged.update(mapping)
        return iter(merged)

    def __contains__(self, key: object) -> bool:
        for mapping in self._maps:
            if key in mapping:
                return True
        return False

    def __bool__(self) -> bool:
        for mapping in self._maps:
            if mapping:
                return True
        return False

    def __repr__(self) -> str:
        content = ", ".join(repr(mapping) for mapping in self._maps)
        return f"ChainMapProxy({content})"
# Entity-tag grammar, https://tools.ietf.org/html/rfc7232#section-2.3:
# one or more etagc characters (%x21 / %x23-7E / obs-text %x80-FF),
# optionally weak ("W/" prefix), always inside double quotes.
_ETAGC = r"[!\x23-\x7E\x80-\xff]+"
_ETAGC_RE = re.compile(_ETAGC)
_QUOTED_ETAG = r'(W/)?"(' + _ETAGC + r')"'
QUOTED_ETAG_RE = re.compile(_QUOTED_ETAG)
# One quoted etag followed by a comma separator or end of input. The final
# "(.)" alternative captures any character that does not start a valid etag.
LIST_QUOTED_ETAG_RE = re.compile("(" + _QUOTED_ETAG + r")(?:\s*,\s*|$)|(.)")

ETAG_ANY = "*"
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ETag:
    """An entity tag: its opaque ``value`` and whether it is weak."""

    value: str
    is_weak: bool = False
def validate_etag_value(value: str) -> None:
    """Check that ``value`` is ``*`` or a valid unquoted etag body.

    :raises ValueError: if it is neither (e.g. it contains a ``"``).
    """
    if value == ETAG_ANY or _ETAGC_RE.fullmatch(value):
        return
    raise ValueError(
        f"Value {value!r} is not a valid etag. Maybe it contains '\"'?"
    )
1017def parse_http_date(date_str: str | None) -> datetime.datetime | None:
1018 """Process a date string, return a datetime object"""
1019 if date_str is not None:
1020 timetuple = parsedate(date_str)
1021 if timetuple is not None:
1022 with suppress(ValueError):
1023 return datetime.datetime(*timetuple[:6], tzinfo=datetime.timezone.utc)
1024 return None
@functools.lru_cache
def must_be_empty_body(method: str, code: int) -> bool:
    """Check if a request must return an empty body.

    True for the status codes in EMPTY_BODY_STATUS_CODES, for the methods in
    EMPTY_BODY_METHODS, and for 2xx answers to CONNECT.
    """
    if code in EMPTY_BODY_STATUS_CODES or method in EMPTY_BODY_METHODS:
        return True
    return 200 <= code < 300 and method in hdrs.METH_CONNECT_ALL
def should_remove_content_length(method: str, code: int) -> bool:
    """Check if a Content-Length header should be removed.

    This should always be a subset of must_be_empty_body: it is true for the
    status codes in EMPTY_BODY_STATUS_CODES and for 2xx answers to CONNECT.
    """
    # https://www.rfc-editor.org/rfc/rfc9110.html#section-8.6-8
    # https://www.rfc-editor.org/rfc/rfc9110.html#section-15.4.5-4
    if code in EMPTY_BODY_STATUS_CODES:
        return True
    return 200 <= code < 300 and method in hdrs.METH_CONNECT_ALL