Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/aiohttp/helpers.py: 44%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Various helper functions"""
3import asyncio
4import base64
5import binascii
6import contextlib
7import datetime
8import enum
9import functools
10import inspect
11import netrc
12import os
13import platform
14import re
15import sys
16import time
17import weakref
18from collections import namedtuple
19from contextlib import suppress
20from email.message import EmailMessage
21from email.parser import HeaderParser
22from email.policy import HTTP
23from email.utils import parsedate
24from math import ceil
25from pathlib import Path
26from types import MappingProxyType, TracebackType
27from typing import (
28 Any,
29 Callable,
30 ContextManager,
31 Dict,
32 Generator,
33 Generic,
34 Iterable,
35 Iterator,
36 List,
37 Mapping,
38 Optional,
39 Protocol,
40 Tuple,
41 Type,
42 TypeVar,
43 Union,
44 get_args,
45 overload,
46)
47from urllib.parse import quote
48from urllib.request import getproxies, proxy_bypass
50import attr
51from multidict import MultiDict, MultiDictProxy, MultiMapping
52from propcache.api import under_cached_property as reify
53from yarl import URL
55from . import hdrs
56from .log import client_logger
58if sys.version_info >= (3, 11):
59 import asyncio as async_timeout
60else:
61 import async_timeout
# Public API of this module.
__all__ = ("BasicAuth", "ChainMapProxy", "ETag", "reify")

# Operating-system flags, evaluated once at import time.
IS_MACOS = platform.system() == "Darwin"
IS_WINDOWS = platform.system() == "Windows"

# Interpreter-version flags.
PY_310 = sys.version_info >= (3, 10)
PY_311 = sys.version_info >= (3, 11)


# Generic type variables shared by the helpers below.
_T = TypeVar("_T")
_S = TypeVar("_S")

# Single-member enum; its only member marks "no value supplied".
_SENTINEL = enum.Enum("_SENTINEL", "sentinel")
sentinel = _SENTINEL.sentinel

# True when the AIOHTTP_NO_EXTENSIONS environment variable is set and non-empty.
NO_EXTENSIONS = bool(os.getenv("AIOHTTP_NO_EXTENSIONS"))

# Status codes whose responses never carry a body: 1xx, 204 and 304.
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
EMPTY_BODY_STATUS_CODES = frozenset(range(100, 200)) | {204, 304}
# Request methods whose responses never carry a body.
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.2
EMPTY_BODY_METHODS = hdrs.METH_HEAD_ALL

# Debug mode: Python dev mode, or PYTHONASYNCIODEBUG set while the
# interpreter is not ignoring environment variables.
DEBUG = sys.flags.dev_mode or (
    not sys.flags.ignore_environment and bool(os.getenv("PYTHONASYNCIODEBUG"))
)
# Character classes used to validate header tokens (for example the
# disposition type and parameter names of a Content-Disposition header).
#
# CHAR: every 7-bit US-ASCII character.
CHAR = {chr(i) for i in range(0, 128)}
# CTL: control characters 0-31 plus DEL (127).
CTL = {chr(i) for i in range(0, 32)} | {
    chr(127),
}
# SEPARATORS: characters that delimit tokens; includes SP and HT (chr(9)).
SEPARATORS = {
    "(",
    ")",
    "<",
    ">",
    "@",
    ",",
    ";",
    ":",
    "\\",
    '"',
    "/",
    "[",
    "]",
    "?",
    "=",
    "{",
    "}",
    " ",
    chr(9),
}
# TOKEN: any CHAR that is neither a CTL nor a separator.
#
# This must be a set *difference*, not a symmetric difference: HT (chr(9))
# is a member of both CTL and SEPARATORS, so ``CHAR ^ CTL ^ SEPARATORS``
# removes it once and then toggles it back in, wrongly accepting a tab
# character inside tokens.
TOKEN = CHAR - CTL - SEPARATORS
class noop:
    """Awaitable whose ``__await__`` generator yields once and returns nothing."""

    def __await__(self) -> Generator[None, None, None]:
        # Exactly one yield; the generator then finishes without a value.
        yield None
class BasicAuth(namedtuple("BasicAuth", ["login", "password", "encoding"])):
    """Http basic authentication helper."""

    def __new__(
        cls, login: str, password: str = "", encoding: str = "latin1"
    ) -> "BasicAuth":
        """Validate the credentials and build the tuple.

        :raises ValueError: if ``login`` or ``password`` is ``None``, or if
                            ``login`` contains a colon.
        """
        if login is None:
            raise ValueError("None is not allowed as login value")

        if password is None:
            raise ValueError("None is not allowed as password value")

        # The colon separates login from password on the wire, so it
        # cannot appear inside the login itself.
        if ":" in login:
            raise ValueError('A ":" is not allowed in login (RFC 1945#section-11.1)')

        return super().__new__(cls, login, password, encoding)

    @classmethod
    def decode(cls, auth_header: str, encoding: str = "latin1") -> "BasicAuth":
        """Create a BasicAuth object from an Authorization HTTP header.

        :raises ValueError: if the header has no scheme/credentials split,
                            uses a scheme other than ``basic``, is not valid
                            base64, or has no colon in the decoded payload.
        """
        try:
            scheme, payload = auth_header.split(" ", 1)
        except ValueError:
            raise ValueError("Could not parse authorization header.")

        if scheme.lower() != "basic":
            raise ValueError("Unknown authorization method %s" % scheme)

        try:
            raw = base64.b64decode(payload.encode("ascii"), validate=True)
            credentials = raw.decode(encoding)
        except binascii.Error:
            raise ValueError("Invalid base64 encoding.")

        try:
            # RFC 2617 HTTP Authentication
            # https://www.ietf.org/rfc/rfc2617.txt
            # the colon must be present, but the username and password may be
            # otherwise blank.
            user, secret = credentials.split(":", 1)
        except ValueError:
            raise ValueError("Invalid credentials.")

        return cls(user, secret, encoding=encoding)

    @classmethod
    def from_url(cls, url: URL, *, encoding: str = "latin1") -> Optional["BasicAuth"]:
        """Create BasicAuth from url.

        Returns ``None`` when the URL carries neither a user nor a password.

        :raises TypeError: if ``url`` is not a ``yarl.URL`` instance.
        """
        if not isinstance(url, URL):
            raise TypeError("url should be yarl.URL instance")
        # Check raw_user and raw_password first as yarl is likely
        # to already have these values parsed from the netloc in the cache.
        if url.raw_user is not None or url.raw_password is not None:
            return cls(url.user or "", url.password or "", encoding=encoding)
        return None

    def encode(self) -> str:
        """Encode credentials as a ``Basic ...`` header value."""
        raw = f"{self.login}:{self.password}".encode(self.encoding)
        token = base64.b64encode(raw).decode(self.encoding)
        return "Basic %s" % token
def strip_auth_from_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
    """Split credentials off ``url``.

    Returns the URL unchanged together with ``None`` when it has neither a
    user nor a password; otherwise returns the URL with the user removed
    and a BasicAuth built from the user and password parts.
    """
    # Check raw_user and raw_password first as yarl is likely
    # to already have these values parsed from the netloc in the cache.
    has_credentials = url.raw_user is not None or url.raw_password is not None
    if not has_credentials:
        return url, None
    stripped = url.with_user(None)
    credentials = BasicAuth(url.user or "", url.password or "")
    return stripped, credentials
def netrc_from_env() -> Optional[netrc.netrc]:
    """Load netrc from file.

    The path comes from the ``NETRC`` environment variable when it is set;
    otherwise the default file in the user's home directory is used
    (``_netrc`` on Windows, ``.netrc`` elsewhere).

    Returns None if it couldn't be found or fails to parse.
    """
    env_value = os.environ.get("NETRC")

    if env_value is None:
        try:
            home = Path.home()
        except RuntimeError as exc:  # pragma: no cover
            # if pathlib can't resolve home, it may raise a RuntimeError
            client_logger.debug(
                "Could not resolve home directory when "
                "trying to look for .netrc file: %s",
                exc,
            )
            return None

        candidate = home / ("_netrc" if IS_WINDOWS else ".netrc")
    else:
        candidate = Path(env_value)

    try:
        return netrc.netrc(str(candidate))
    except netrc.NetrcParseError as exc:
        client_logger.warning("Could not parse .netrc file: %s", exc)
    except OSError as exc:
        # we couldn't read the file (doesn't exist, permissions, etc.)
        try:
            present = candidate.is_file()
        except OSError:
            present = False
        # only warn if the environment wanted us to load it,
        # or it appears like the default file does actually exist
        if env_value or present:
            client_logger.warning("Could not read .netrc file: %s", exc)

    return None
@attr.s(auto_attribs=True, frozen=True, slots=True)
class ProxyInfo:
    # Frozen, slotted attrs record pairing a proxy with its credentials.

    # URL of the proxy.
    proxy: URL
    # Credentials for the proxy, or None when there are none.
    proxy_auth: Optional[BasicAuth]
def basicauth_from_netrc(netrc_obj: Optional[netrc.netrc], host: str) -> BasicAuth:
    """
    Return :py:class:`~aiohttp.BasicAuth` credentials for ``host`` from ``netrc_obj``.

    :raises LookupError: if ``netrc_obj`` is :py:data:`None` or if no
            entry is found for the ``host``.
    """
    if netrc_obj is None:
        raise LookupError("No .netrc file found")

    entry = netrc_obj.authenticators(host)
    if entry is None:
        raise LookupError(f"No entry for {host!s} found in the `.netrc` file.")
    login, account, password = entry

    # TODO(PY311): username = login or account
    # Up to python 3.10, account could be None if not specified,
    # and login will be empty string if not specified. From 3.11,
    # login and account will be empty string if not specified.
    if login or account is None:
        username = login
    else:
        username = account

    # TODO(PY311): Remove this, as password will be empty string
    # if not specified
    return BasicAuth(username, "" if password is None else password)
def proxies_from_env() -> Dict[str, ProxyInfo]:
    """Build a scheme -> ProxyInfo mapping from the environment proxies.

    Only the ``http``, ``https``, ``ws`` and ``wss`` entries reported by
    ``getproxies()`` are considered. Proxies whose own URL scheme is
    ``https`` or ``wss`` are skipped with a warning. When the proxy URL
    carries no credentials, they are looked up in netrc by proxy host.
    """
    wanted = ("http", "https", "ws", "wss")
    parsed = {
        scheme: URL(raw) for scheme, raw in getproxies().items() if scheme in wanted
    }
    netrc_obj = netrc_from_env()
    split = {scheme: strip_auth_from_url(url) for scheme, url in parsed.items()}

    result = {}
    for scheme, (proxy, auth) in split.items():
        if proxy.scheme in ("https", "wss"):
            client_logger.warning(
                "%s proxies %s are not supported, ignoring", proxy.scheme.upper(), proxy
            )
            continue
        if netrc_obj and auth is None and proxy.host is not None:
            try:
                auth = basicauth_from_netrc(netrc_obj, proxy.host)
            except LookupError:
                # No netrc entry for this host: leave the proxy unauthenticated.
                auth = None
        result[scheme] = ProxyInfo(proxy, auth)
    return result
def get_env_proxy_for_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
    """Get a permitted proxy for the given URL from the env.

    :raises LookupError: if the URL's host is on the proxy bypass list, or
                         if the environment has no proxy for the URL's scheme.
    """
    host = url.host
    if host is not None and proxy_bypass(host):
        raise LookupError(f"Proxying is disallowed for `{host!r}`")

    env_proxies = proxies_from_env()
    try:
        info = env_proxies[url.scheme]
    except KeyError:
        raise LookupError(f"No proxies found for `{url!s}` in the env")
    return info.proxy, info.proxy_auth
@attr.s(auto_attribs=True, frozen=True, slots=True)
class MimeType:
    # Frozen, slotted attrs record holding the pieces of a parsed MIME type.

    # Part before the "/" (e.g. "text").
    type: str
    # Part after the "/" and before any "+" (e.g. "html").
    subtype: str
    # Part after the "+" in the subtype, or "" when there is none.
    suffix: str
    # Read-only view of the ";"-separated parameters.
    parameters: "MultiDictProxy[str]"
@functools.lru_cache(maxsize=56)
def parse_mimetype(mimetype: str) -> MimeType:
    """Parses a MIME type into its components.

    mimetype is a MIME type string.

    Returns a MimeType object. An empty input yields a MimeType whose
    fields are all empty; a bare ``*`` is treated as ``*/*``.

    Example:

    >>> parse_mimetype('text/html; charset=utf-8')
    MimeType(type='text', subtype='html', suffix='',
             parameters={'charset': 'utf-8'})

    """
    if not mimetype:
        return MimeType(
            type="", subtype="", suffix="", parameters=MultiDictProxy(MultiDict())
        )

    head, *tail = mimetype.split(";")

    # Everything after the first ";" is a "key=value" parameter.
    parameters: MultiDict[str] = MultiDict()
    for chunk in tail:
        if chunk:
            name, _, raw_value = chunk.partition("=")
            parameters.add(name.lower().strip(), raw_value.strip(' "'))

    full = head.strip().lower()
    if full == "*":
        full = "*/*"

    major, _, rest = full.partition("/")
    minor, _, suffix = rest.partition("+")

    return MimeType(
        type=major, subtype=minor, suffix=suffix, parameters=MultiDictProxy(parameters)
    )
class EnsureOctetStream(EmailMessage):
    """EmailMessage whose default content type is application/octet-stream."""

    def __init__(self) -> None:
        super().__init__()
        # https://www.rfc-editor.org/rfc/rfc9110#section-8.3-5
        self.set_default_type("application/octet-stream")

    def get_content_type(self) -> str:
        """Re-implementation from Message

        Returns the default type (application/octet-stream, instead of the
        base class's text/plain) when the header value does not contain
        exactly one ``/``.

        The way this class is used guarantees that content-type will
        be present so simplify the checks wrt to the base implementation.
        """
        header = self.get("content-type", "").lower()

        # Based on the implementation of _splitparam in the standard library
        ctype = header.split(";", 1)[0].strip()
        return ctype if ctype.count("/") == 1 else self.get_default_type()
@functools.lru_cache(maxsize=56)
def parse_content_type(raw: str) -> Tuple[str, MappingProxyType[str, str]]:
    """Parse Content-Type header.

    Returns a tuple of the parsed content type and a
    MappingProxyType of parameters. The default returned value
    is `application/octet-stream`
    """
    parser = HeaderParser(EnsureOctetStream, policy=HTTP)
    message = parser.parsestr(f"Content-Type: {raw}")
    content_type = message.get_content_type()
    all_params = message.get_params(())
    # First element is content type again, so skip it.
    return content_type, MappingProxyType(dict(all_params[1:]))
def guess_filename(obj: Any, default: Optional[str] = None) -> Optional[str]:
    """Return the final path component of ``obj.name``, or ``default``.

    ``default`` is returned when ``obj`` has no ``name`` attribute, when the
    name is empty or not a string, or when it starts with ``<`` or ends
    with ``>``.
    """
    candidate = getattr(obj, "name", None)
    if not candidate or not isinstance(candidate, str):
        return default
    if candidate[0] == "<" or candidate[-1] == ">":
        return default
    return Path(candidate).name
# Matches every character that needs a backslash inside a quoted-string.
not_qtext_re = re.compile(r"[^\041\043-\133\135-\176]")
# Characters permitted as input: 0x20-0x7E plus horizontal tab.
QCONTENT = set(map(chr, range(0x20, 0x7F))) | {"\t"}


def quoted_string(content: str) -> str:
    """Return 7-bit content as quoted-string.

    Format content into a quoted-string as defined in RFC5322 for
    Internet Message Format. Notice that this is not the 8-bit HTTP
    format, but the 7-bit email format. Every character matched by
    ``not_qtext_re`` is prefixed with a backslash; surrounding quotes
    are not added.

    :raises ValueError: if ``content`` has a character outside ``QCONTENT``.
    """
    if not set(content) < QCONTENT:
        raise ValueError(f"bad content for quoted-string {content!r}")
    # Template form: a literal backslash followed by the whole match.
    return not_qtext_re.sub(r"\\\g<0>", content)
def content_disposition_header(
    disptype: str, quote_fields: bool = True, _charset: str = "utf-8", **params: str
) -> str:
    """Sets ``Content-Disposition`` header for MIME.

    This is the MIME payload Content-Disposition header from RFC 2183
    and RFC 7578 section 4.2, not the HTTP Content-Disposition from
    RFC 6266.

    disptype is a disposition type: inline, attachment, form-data.
    Should be valid extension token (see RFC 2183)

    quote_fields performs value quoting to 7-bit MIME headers
    according to RFC 7578. Set quote_fields to False if recipient
    can take 8-bit file names and field values.

    _charset specifies the charset to use when quote_fields is True.

    params is a dict with disposition params.

    Raises ValueError when disptype or a parameter name is empty or
    contains a character outside TOKEN.
    """
    if not disptype or not (TOKEN > set(disptype)):
        raise ValueError(f"bad content disposition type {disptype!r}")

    if not params:
        return disptype

    rendered = []
    for key, val in params.items():
        if not key or not (TOKEN > set(key)):
            raise ValueError(f"bad content disposition parameter {key!r}={val!r}")

        if not quote_fields:
            # Only escape backslashes and double quotes; keep 8-bit text as is.
            escaped = val.replace("\\", "\\\\").replace('"', '\\"')
            pair = (key, '"%s"' % escaped)
        elif key.lower() == "filename":
            # File names are always percent-encoded.
            encoded = quote(val, "", encoding=_charset)
            pair = (key, '"%s"' % encoded)
        else:
            try:
                plain = quoted_string(val)
            except ValueError:
                # Not representable as a quoted-string: emit the extended
                # ``key*=charset''percent-encoded`` form instead.
                extended = "".join(
                    (_charset, "''", quote(val, "", encoding=_charset))
                )
                pair = (key + "*", extended)
            else:
                pair = (key, '"%s"' % plain)
        rendered.append("=".join(pair))

    return "; ".join((disptype, "; ".join(rendered)))
def is_ip_address(host: Optional[str]) -> bool:
    """Check if host looks like an IP Address.

    This check is only meant as a heuristic to ensure that
    a host is not a domain name.
    """
    if not host:
        return False
    # The host must contain a colon to be an IPv6 address.
    if ":" in host:
        return True
    # For a host to be an ipv4 address, it must be all numeric.
    return host.replace(".", "").isdigit()
# One-second cache: the integer timestamp last formatted and its rendering.
_cached_current_datetime: Optional[int] = None
_cached_formatted_datetime = ""


def rfc822_formatted_time() -> str:
    """Return the current GMT time as ``Www, DD Mon YYYY HH:MM:SS GMT``.

    The string is rebuilt only when the whole-second timestamp differs from
    the cached one; otherwise the cached string is returned.
    """
    global _cached_current_datetime
    global _cached_formatted_datetime

    now = int(time.time())
    if now == _cached_current_datetime:
        return _cached_formatted_datetime

    # Weekday and month names for HTTP date/time formatting;
    # always English!
    # Tuples are constants stored in codeobject!
    weekdays = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")
    months = (
        "",  # Dummy so we can use 1-based month numbers
        "Jan",
        "Feb",
        "Mar",
        "Apr",
        "May",
        "Jun",
        "Jul",
        "Aug",
        "Sep",
        "Oct",
        "Nov",
        "Dec",
    )

    stamp = time.gmtime(now)
    _cached_formatted_datetime = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
        weekdays[stamp.tm_wday],
        stamp.tm_mday,
        months[stamp.tm_mon],
        stamp.tm_year,
        stamp.tm_hour,
        stamp.tm_min,
        stamp.tm_sec,
    )
    _cached_current_datetime = now
    return _cached_formatted_datetime
def _weakref_handle(info: "Tuple[weakref.ref[object], str]") -> None:
    """Call the named method on the referenced object if it is still alive.

    ``info`` is a ``(weak reference, method name)`` pair. Any ``Exception``
    raised while looking up or calling the method is swallowed.
    """
    ref, name = info
    target = ref()
    if target is None:
        return
    with suppress(Exception):
        getattr(target, name)()


def weakref_handle(
    ob: object,
    name: str,
    timeout: float,
    loop: asyncio.AbstractEventLoop,
    timeout_ceil_threshold: float = 5,
) -> Optional[asyncio.TimerHandle]:
    """Schedule ``ob.<name>()`` after ``timeout`` seconds via a weak reference.

    Returns ``None`` without scheduling when ``timeout`` is ``None`` or not
    greater than zero. Otherwise returns the result of ``loop.call_at``.
    The target time is rounded up to a whole number when ``timeout`` is at
    least ``timeout_ceil_threshold``.
    """
    if timeout is None or not timeout > 0:
        return None

    deadline = loop.time() + timeout
    if timeout >= timeout_ceil_threshold:
        deadline = ceil(deadline)

    # Only a weak reference is handed to the loop, so the pending timer
    # holds no strong reference to ``ob``.
    return loop.call_at(deadline, _weakref_handle, (weakref.ref(ob), name))
def call_later(
    cb: Callable[[], Any],
    timeout: float,
    loop: asyncio.AbstractEventLoop,
    timeout_ceil_threshold: float = 5,
) -> Optional[asyncio.TimerHandle]:
    """Schedule ``cb`` on ``loop`` after ``timeout`` seconds.

    Returns ``None`` without scheduling when ``timeout`` is ``None`` or not
    positive. Otherwise returns the result of ``loop.call_at``, using the
    time computed by ``calculate_timeout_when``.
    """
    if timeout is not None and not timeout <= 0:
        when = calculate_timeout_when(loop.time(), timeout, timeout_ceil_threshold)
        return loop.call_at(when, cb)
    return None
def calculate_timeout_when(
    loop_time: float,
    timeout: float,
    timeout_ceiling_threshold: float,
) -> float:
    """Calculate when to execute a timeout.

    The result is ``loop_time + timeout``, rounded up to a whole number
    only when ``timeout`` is strictly greater than the threshold.
    """
    deadline = loop_time + timeout
    return ceil(deadline) if timeout > timeout_ceiling_threshold else deadline
class TimeoutHandle:
    """Timeout handle

    Collects callbacks and, once started, asks the loop to invoke all of
    them when the timeout elapses.
    """

    __slots__ = ("_timeout", "_loop", "_ceil_threshold", "_callbacks")

    def __init__(
        self,
        loop: asyncio.AbstractEventLoop,
        timeout: Optional[float],
        ceil_threshold: float = 5,
    ) -> None:
        self._timeout = timeout
        self._loop = loop
        self._ceil_threshold = ceil_threshold
        # Each entry is (callback, positional args, keyword args).
        self._callbacks: List[
            Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]]
        ] = []

    def register(
        self, callback: Callable[..., None], *args: Any, **kwargs: Any
    ) -> None:
        """Remember ``callback`` and its arguments for when the timeout fires."""
        self._callbacks.append((callback, args, kwargs))

    def close(self) -> None:
        """Forget every registered callback."""
        self._callbacks.clear()

    def start(self) -> Optional[asyncio.TimerHandle]:
        """Schedule this handle on the loop.

        Returns ``None`` when the timeout is ``None`` or not greater than
        zero; otherwise the result of ``loop.call_at``. The target time is
        rounded up when the timeout is at least the ceil threshold.
        """
        timeout = self._timeout
        if timeout is None or not timeout > 0:
            return None

        when = self._loop.time() + timeout
        if timeout >= self._ceil_threshold:
            when = ceil(when)
        return self._loop.call_at(when, self.__call__)

    def timer(self) -> "BaseTimerContext":
        """Return a timer context tied to this handle.

        A ``TimerNoop`` is returned when no positive timeout is configured;
        otherwise a ``TimerContext`` whose ``timeout`` method is registered
        as a callback.
        """
        if self._timeout is None or not self._timeout > 0:
            return TimerNoop()

        timer = TimerContext(self._loop)
        self.register(timer.timeout)
        return timer

    def __call__(self) -> None:
        """Run every registered callback, then drop them all."""
        for callback, args, kwargs in self._callbacks:
            # One failing callback must not stop the remaining ones.
            try:
                callback(*args, **kwargs)
            except Exception:
                pass

        self._callbacks.clear()
class BaseTimerContext(ContextManager["BaseTimerContext"]):
    # Base class for the timer context managers defined in this module.

    __slots__ = ()

    def assert_timeout(self) -> None:
        """Raise TimeoutError if timeout has been exceeded."""
        # The base implementation never raises.
        return None
class TimerNoop(BaseTimerContext):
    # Timer context that does nothing on enter or exit.

    __slots__ = ()

    def __enter__(self) -> BaseTimerContext:
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        # Returning None lets any in-flight exception propagate.
        return None
class TimerContext(BaseTimerContext):
    """Low resolution timeout context manager"""

    __slots__ = ("_loop", "_tasks", "_cancelled", "_cancelling")

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        # Tasks currently inside the context, in entry order.
        self._tasks: List[asyncio.Task[Any]] = []
        # Set once timeout() has fired.
        self._cancelled = False
        # Value of task.cancelling() recorded at __enter__ (Python 3.11+).
        self._cancelling = 0

    def assert_timeout(self) -> None:
        """Raise TimeoutError if timer has already been cancelled."""
        if self._cancelled:
            raise asyncio.TimeoutError from None

    def __enter__(self) -> BaseTimerContext:
        current = asyncio.current_task(loop=self._loop)
        if current is None:
            raise RuntimeError("Timeout context manager should be used inside a task")

        if sys.version_info >= (3, 11):
            # Remember if the task was already cancelling
            # so when we __exit__ we can decide if we should
            # raise asyncio.TimeoutError or let the cancellation propagate
            self._cancelling = current.cancelling()

        if self._cancelled:
            raise asyncio.TimeoutError from None

        self._tasks.append(current)
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> Optional[bool]:
        entered: Optional[asyncio.Task[Any]] = None
        if self._tasks:
            entered = self._tasks.pop()

        timed_out = exc_type is asyncio.CancelledError and self._cancelled
        if not timed_out:
            return None

        assert entered is not None
        # The timeout was hit, and the task was cancelled
        # so we need to uncancel the last task that entered the context manager
        # since the cancellation should not leak out of the context manager
        if sys.version_info >= (3, 11):
            # If the task was already cancelling don't raise
            # asyncio.TimeoutError and instead return None
            # to allow the cancellation to propagate
            if entered.uncancel() > self._cancelling:
                return None
        raise asyncio.TimeoutError from exc_val

    def timeout(self) -> None:
        """Cancel every task inside the context and mark the timer as fired."""
        if self._cancelled:
            return

        for task in set(self._tasks):
            task.cancel()

        self._cancelled = True
724def ceil_timeout(
725 delay: Optional[float], ceil_threshold: float = 5
726) -> async_timeout.Timeout:
727 if delay is None or delay <= 0:
728 return async_timeout.timeout(None)
730 loop = asyncio.get_running_loop()
731 now = loop.time()
732 when = now + delay
733 if delay > ceil_threshold:
734 when = ceil(when)
735 return async_timeout.timeout_at(when)
class HeadersMixin:
    """Mixin for handling headers."""

    ATTRS = frozenset(["_content_type", "_content_dict", "_stored_content_type"])

    _headers: MultiMapping[str]
    # Parsed content type and its parameters; None until first parsed.
    _content_type: Optional[str] = None
    _content_dict: Optional[Dict[str, str]] = None
    # Raw Content-Type value the parsed fields were derived from.
    _stored_content_type: Union[str, None, _SENTINEL] = sentinel

    def _parse_content_type(self, raw: Optional[str]) -> None:
        """Parse ``raw`` and store the content type and its parameters."""
        self._stored_content_type = raw
        if raw is not None:
            content_type, params_proxy = parse_content_type(raw)
            self._content_type = content_type
            # _content_dict needs to be mutable so we can update it
            self._content_dict = params_proxy.copy()
            return

        # default value according to RFC 2616
        self._content_type = "application/octet-stream"
        self._content_dict = {}

    @property
    def content_type(self) -> str:
        """The value of content part for Content-Type HTTP header."""
        current = self._headers.get(hdrs.CONTENT_TYPE)
        # Re-parse only when the header differs from the value parsed last.
        if self._stored_content_type != current:
            self._parse_content_type(current)
        assert self._content_type is not None
        return self._content_type

    @property
    def charset(self) -> Optional[str]:
        """The value of charset part for Content-Type HTTP header."""
        current = self._headers.get(hdrs.CONTENT_TYPE)
        # Re-parse only when the header differs from the value parsed last.
        if self._stored_content_type != current:
            self._parse_content_type(current)
        assert self._content_dict is not None
        return self._content_dict.get("charset")

    @property
    def content_length(self) -> Optional[int]:
        """The value of Content-Length HTTP header."""
        raw_length = self._headers.get(hdrs.CONTENT_LENGTH)
        if raw_length is None:
            return None
        return int(raw_length)
def set_result(fut: "asyncio.Future[_T]", result: _T) -> None:
    """Set ``result`` on ``fut``; do nothing if the future is already done."""
    if fut.done():
        return
    fut.set_result(result)
# Unique default for ``exc_cause`` meaning "no cause was supplied".
_EXC_SENTINEL = BaseException()


class ErrorableProtocol(Protocol):
    # Structural type for any object exposing a compatible ``set_exception``.

    def set_exception(
        self,
        exc: BaseException,
        exc_cause: BaseException = ...,
    ) -> None: ...  # pragma: no cover


def set_exception(
    fut: "asyncio.Future[_T] | ErrorableProtocol",
    exc: BaseException,
    exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
    """Set future exception.

    If the future is marked as complete, this function is a no-op.

    :param exc_cause: An exception that is a direct cause of ``exc``.
                      Only set if provided.
    """
    if asyncio.isfuture(fut) and fut.done():
        return

    cause_given = exc_cause is not _EXC_SENTINEL
    # An exception is never recorded as its own cause.
    if cause_given and exc is not exc_cause:
        exc.__cause__ = exc_cause

    fut.set_exception(exc)
@functools.total_ordering
class AppKey(Generic[_T]):
    """Keys for static typing support in Application.

    The key name is prefixed with the ``__name__`` of the module from whose
    top-level code the key is being created, which helps keep names from
    different modules distinct.
    """

    __slots__ = ("_name", "_t", "__orig_class__")

    # This may be set by Python when instantiating with a generic type. We need to
    # support this, in order to support types that are not concrete classes,
    # like Iterable, which can't be passed as the second parameter to __init__.
    __orig_class__: Type[object]

    def __init__(self, name: str, t: Optional[Type[_T]] = None):
        """Create a key called ``name`` with optional value type ``t``.

        :raises RuntimeError: if no module-level frame is found on the
                              call stack, so no module name can be derived.
        """
        # Prefix with module name to help deduplicate key names.
        frame = inspect.currentframe()
        while frame:
            if frame.f_code.co_name == "<module>":
                module: str = frame.f_globals["__name__"]
                break
            frame = frame.f_back
        else:
            # The stack has no "<module>" frame (for example when the key is
            # created inside a thread target). Fail with a clear error rather
            # than an UnboundLocalError on ``module`` below.
            raise RuntimeError("Failed to get module name.")

        self._name = module + "." + name
        self._t = t

    def __lt__(self, other: object) -> bool:
        if isinstance(other, AppKey):
            return self._name < other._name
        return True  # Order AppKey above other types.

    def __repr__(self) -> str:
        t = self._t
        if t is None:
            # No explicit type given: fall back to the generic argument, if
            # the instance was created as ``AppKey[X](...)``.
            with suppress(AttributeError):
                # Set to type arg.
                t = get_args(self.__orig_class__)[0]

        if t is None:
            t_repr = "<<Unknown>>"
        elif isinstance(t, type):
            # Builtins are shown bare; other classes are module-qualified.
            if t.__module__ == "builtins":
                t_repr = t.__qualname__
            else:
                t_repr = f"{t.__module__}.{t.__qualname__}"
        else:
            t_repr = repr(t)
        return f"<AppKey({self._name}, type={t_repr})>"
class ChainMapProxy(Mapping[Union[str, AppKey[Any]], Any]):
    # Read-only lookup over several mappings; earlier mappings win.

    __slots__ = ("_maps",)

    def __init__(self, maps: Iterable[Mapping[Union[str, AppKey[Any]], Any]]) -> None:
        self._maps = tuple(maps)

    def __init_subclass__(cls) -> None:
        # Subclassing is rejected outright.
        raise TypeError(
            f"Inheritance class {cls.__name__} from ChainMapProxy is forbidden"
        )

    @overload  # type: ignore[override]
    def __getitem__(self, key: AppKey[_T]) -> _T: ...

    @overload
    def __getitem__(self, key: str) -> Any: ...

    def __getitem__(self, key: Union[str, AppKey[_T]]) -> Any:
        # The first mapping that has the key supplies the value.
        for candidate in self._maps:
            try:
                return candidate[key]
            except KeyError:
                continue
        raise KeyError(key)

    @overload  # type: ignore[override]
    def get(self, key: AppKey[_T], default: _S) -> Union[_T, _S]: ...

    @overload
    def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]: ...

    @overload
    def get(self, key: str, default: Any = ...) -> Any: ...

    def get(self, key: Union[str, AppKey[_T]], default: Any = None) -> Any:
        try:
            value = self[key]
        except KeyError:
            return default
        return value

    def __len__(self) -> int:
        # reuses stored hash values if possible
        distinct_keys = set().union(*self._maps)
        return len(distinct_keys)

    def __iter__(self) -> Iterator[Union[str, AppKey[Any]]]:
        merged: Dict[Union[str, AppKey[Any]], Any] = {}
        # Walk from the last mapping to the first.
        for candidate in reversed(self._maps):
            # reuses stored hash values if possible
            merged.update(candidate)
        return iter(merged)

    def __contains__(self, key: object) -> bool:
        for candidate in self._maps:
            if key in candidate:
                return True
        return False

    def __bool__(self) -> bool:
        # True as soon as any underlying mapping is non-empty.
        return any(self._maps)

    def __repr__(self) -> str:
        content = ", ".join(repr(candidate) for candidate in self._maps)
        return f"ChainMapProxy({content})"
# Wildcard value that stands for "any entity tag".
ETAG_ANY = "*"

# https://tools.ietf.org/html/rfc7232#section-2.3
# One or more characters allowed inside an entity tag.
_ETAGC = r"[!\x23-\x7E\x80-\xff]+"
_ETAGC_RE = re.compile(_ETAGC)
# Optional weak marker followed by a double-quoted tag
# (group 1: "W/" or None, group 2: the tag value).
_QUOTED_ETAG = '(W/)?"(' + _ETAGC + ')"'
QUOTED_ETAG_RE = re.compile(_QUOTED_ETAG)
# One quoted tag followed by a comma separator or end of input; the final
# alternative captures any single character that is not part of such a tag.
LIST_QUOTED_ETAG_RE = re.compile(rf"({_QUOTED_ETAG})(?:\s*,\s*|$)|(.)")
@attr.s(auto_attribs=True, frozen=True, slots=True)
class ETag:
    # Frozen, slotted attrs record describing one entity tag.

    # The tag value itself.
    value: str
    # Whether the tag is weak; defaults to False.
    is_weak: bool = False
def validate_etag_value(value: str) -> None:
    """Check that ``value`` is usable as an entity tag.

    The wildcard ``ETAG_ANY`` is accepted; anything else must consist
    entirely of characters matched by ``_ETAGC_RE``.

    :raises ValueError: if ``value`` fails that check.
    """
    if value == ETAG_ANY or _ETAGC_RE.fullmatch(value):
        return
    raise ValueError(
        f"Value {value!r} is not a valid etag. Maybe it contains '\"'?"
    )
def parse_http_date(date_str: Optional[str]) -> Optional[datetime.datetime]:
    """Process a date string, return a datetime object

    Returns ``None`` when ``date_str`` is ``None``, cannot be parsed, or
    describes an invalid date. The result is always tagged as UTC.
    """
    if date_str is None:
        return None

    parsed = parsedate(date_str)
    if parsed is None:
        return None

    try:
        # Only the first six fields (year .. second) are used.
        return datetime.datetime(*parsed[:6], tzinfo=datetime.timezone.utc)
    except ValueError:
        return None
@functools.lru_cache
def must_be_empty_body(method: str, code: int) -> bool:
    """Check if a request must return an empty body.

    True when the status code is in EMPTY_BODY_STATUS_CODES, when the
    method is in EMPTY_BODY_METHODS, or for a 2xx response to a method in
    ``hdrs.METH_CONNECT_ALL``.
    """
    if code in EMPTY_BODY_STATUS_CODES or method in EMPTY_BODY_METHODS:
        return True
    return 200 <= code < 300 and method in hdrs.METH_CONNECT_ALL
def should_remove_content_length(method: str, code: int) -> bool:
    """Check if a Content-Length header should be removed.

    This should always be a subset of must_be_empty_body
    """
    # https://www.rfc-editor.org/rfc/rfc9110.html#section-8.6-8
    # https://www.rfc-editor.org/rfc/rfc9110.html#section-15.4.5-4
    if code in EMPTY_BODY_STATUS_CODES:
        return True
    return 200 <= code < 300 and method in hdrs.METH_CONNECT_ALL