Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/aiohttp/helpers.py: 44%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Various helper functions"""
3import asyncio
4import base64
5import binascii
6import contextlib
7import datetime
8import enum
9import functools
10import inspect
11import netrc
12import os
13import platform
14import re
15import sys
16import time
17import weakref
18from collections import namedtuple
19from contextlib import suppress
20from email.parser import HeaderParser
21from email.utils import parsedate
22from math import ceil
23from pathlib import Path
24from types import MappingProxyType, TracebackType
25from typing import (
26 Any,
27 Callable,
28 ContextManager,
29 Dict,
30 Generator,
31 Generic,
32 Iterable,
33 Iterator,
34 List,
35 Mapping,
36 Optional,
37 Protocol,
38 Tuple,
39 Type,
40 TypeVar,
41 Union,
42 get_args,
43 overload,
44)
45from urllib.parse import quote
46from urllib.request import getproxies, proxy_bypass
48import attr
49from multidict import MultiDict, MultiDictProxy, MultiMapping
50from propcache.api import under_cached_property as reify
51from yarl import URL
53from . import hdrs
54from .log import client_logger
56if sys.version_info >= (3, 11):
57 import asyncio as async_timeout
58else:
59 import async_timeout
# Names exported by a star-import of this module.
__all__ = ("BasicAuth", "ChainMapProxy", "ETag", "reify")

# Host platform flags, derived from platform.system().
IS_MACOS = platform.system() == "Darwin"
IS_WINDOWS = platform.system() == "Windows"

# Interpreter version flags.
PY_310 = sys.version_info >= (3, 10)
PY_311 = sys.version_info >= (3, 11)


# Generic type variables shared by the helpers in this module.
_T = TypeVar("_T")
_S = TypeVar("_S")

# Single-member enum used as a "no value" marker that is distinct from None.
_SENTINEL = enum.Enum("_SENTINEL", "sentinel")
sentinel = _SENTINEL.sentinel

# True when the AIOHTTP_NO_EXTENSIONS environment variable is set and non-empty.
NO_EXTENSIONS = bool(os.environ.get("AIOHTTP_NO_EXTENSIONS"))

# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
# 204, 304 and every 1xx status code.
EMPTY_BODY_STATUS_CODES = frozenset({204, 304}.union(range(100, 200)))
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.2
EMPTY_BODY_METHODS = hdrs.METH_HEAD_ALL

# Debug mode: Python dev mode, or PYTHONASYNCIODEBUG set while the
# interpreter is not ignoring environment variables.
DEBUG = sys.flags.dev_mode or not (
    sys.flags.ignore_environment or not os.environ.get("PYTHONASYNCIODEBUG")
)
# Character classes used to validate header tokens.
# CHAR: every 7-bit US-ASCII character.
CHAR = {chr(i) for i in range(0, 128)}
# CTL: control characters 0-31 plus DEL (127).
CTL = {chr(i) for i in range(0, 32)} | {
    chr(127),
}
# SEPARATORS: characters that may not appear inside a token.
SEPARATORS = {
    "(",
    ")",
    "<",
    ">",
    "@",
    ",",
    ";",
    ":",
    "\\",
    '"',
    "/",
    "[",
    "]",
    "?",
    "=",
    "{",
    "}",
    " ",
    chr(9),
}
# TOKEN: any CHAR except CTLs and separators.
# Set difference is required here: TAB (chr(9)) is a member of both CTL and
# SEPARATORS, so a symmetric difference (``^``) would remove it once and then
# add it back, wrongly making TAB a valid token character.
TOKEN = CHAR - CTL - SEPARATORS
class noop:
    """Awaitable whose ``__await__`` generator yields exactly once."""

    def __await__(self) -> Generator[None, None, None]:
        # A single bare yield: suspends once, then the generator finishes.
        yield None
class BasicAuth(namedtuple("BasicAuth", ["login", "password", "encoding"])):
    """Http basic authentication helper.

    A named tuple of ``(login, password, encoding)`` that can be converted to
    and from the value of an ``Authorization: Basic ...`` header.
    """

    def __new__(
        cls, login: str, password: str = "", encoding: str = "latin1"
    ) -> "BasicAuth":
        """Validate the credentials and build the tuple.

        :raises ValueError: if ``login`` or ``password`` is ``None``, or if
            ``login`` contains a colon.
        """
        if login is None:
            raise ValueError("None is not allowed as login value")

        if password is None:
            raise ValueError("None is not allowed as password value")

        # The colon separates login from password in the encoded form, so it
        # cannot be part of the login itself.
        if ":" in login:
            raise ValueError('A ":" is not allowed in login (RFC 1945#section-11.1)')

        return super().__new__(cls, login, password, encoding)

    @classmethod
    def decode(cls, auth_header: str, encoding: str = "latin1") -> "BasicAuth":
        """Create a BasicAuth object from an Authorization HTTP header.

        :raises ValueError: if the header has no scheme/credentials split, the
            scheme is not ``basic``, the payload is not valid Base64, or the
            decoded credentials contain no colon.
        """
        try:
            auth_type, encoded_credentials = auth_header.split(" ", 1)
        except ValueError:
            raise ValueError("Could not parse authorization header.")

        if auth_type.lower() != "basic":
            raise ValueError("Unknown authorization method %s" % auth_type)

        try:
            decoded = base64.b64decode(
                encoded_credentials.encode("ascii"), validate=True
            ).decode(encoding)
        except binascii.Error:
            raise ValueError("Invalid base64 encoding.")

        try:
            # RFC 2617 HTTP Authentication
            # https://www.ietf.org/rfc/rfc2617.txt
            # the colon must be present, but the username and password may be
            # otherwise blank.
            username, password = decoded.split(":", 1)
        except ValueError:
            raise ValueError("Invalid credentials.")

        return cls(username, password, encoding=encoding)

    @classmethod
    def from_url(
        cls, url: "URL", *, encoding: str = "latin1"
    ) -> Optional["BasicAuth"]:
        """Create BasicAuth from url.

        Returns ``None`` when the URL carries neither a user nor a password.

        :raises TypeError: if ``url`` is not a ``yarl.URL`` instance.
        """
        if not isinstance(url, URL):
            raise TypeError("url should be yarl.URL instance")
        # Check raw_user and raw_password first as yarl is likely
        # to already have these values parsed from the netloc in the cache.
        if url.raw_user is None and url.raw_password is None:
            return None
        return cls(url.user or "", url.password or "", encoding=encoding)

    def encode(self) -> str:
        """Encode credentials as the value of an ``Authorization`` header."""
        creds = f"{self.login}:{self.password}".encode(self.encoding)
        # Base64 output consists solely of ASCII characters, so it must be
        # turned back into text as ASCII. Decoding it with ``self.encoding``
        # corrupts the header for codecs that are not ASCII-compatible
        # (e.g. utf-16).
        return "Basic %s" % base64.b64encode(creds).decode("ascii")
def strip_auth_from_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
    """Split credentials off a URL.

    Returns ``(url, None)`` unchanged when the URL has neither user nor
    password; otherwise returns the URL without user info together with a
    BasicAuth built from the user and password (empty string when missing).
    """
    # The raw_* attributes are checked first as yarl is likely to already
    # have them parsed from the netloc in its cache.
    has_credentials = url.raw_user is not None or url.raw_password is not None
    if not has_credentials:
        return url, None
    bare_url = url.with_user(None)
    return bare_url, BasicAuth(url.user or "", url.password or "")
def netrc_from_env() -> Optional[netrc.netrc]:
    """Load a netrc file.

    The path comes from the ``NETRC`` environment variable when it is set;
    otherwise the default file in the user's home directory is used
    (``_netrc`` on Windows, ``.netrc`` elsewhere).

    Returns None if the file couldn't be found or fails to parse.
    """
    env_path = os.environ.get("NETRC")

    if env_path is None:
        try:
            home = Path.home()
        except RuntimeError as exc:  # pragma: no cover
            # pathlib may raise a RuntimeError if it can't resolve home
            client_logger.debug(
                "Could not resolve home directory when "
                "trying to look for .netrc file: %s",
                exc,
            )
            return None
        candidate = home / ("_netrc" if IS_WINDOWS else ".netrc")
    else:
        candidate = Path(env_path)

    try:
        return netrc.netrc(str(candidate))
    except netrc.NetrcParseError as exc:
        client_logger.warning("Could not parse .netrc file: %s", exc)
    except OSError as exc:
        # we couldn't read the file (doesn't exist, permissions, etc.)
        try:
            file_present = candidate.is_file()
        except OSError:
            file_present = False
        # only warn if the environment wanted us to load it,
        # or it appears like the default file does actually exist
        if env_path or file_present:
            client_logger.warning("Could not read .netrc file: %s", exc)

    return None
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ProxyInfo:
    """Immutable pair of a proxy URL and its optional credentials."""

    # URL of the proxy.
    proxy: URL
    # Credentials for the proxy, or None when there are none.
    proxy_auth: Optional[BasicAuth]
def basicauth_from_netrc(netrc_obj: Optional[netrc.netrc], host: str) -> BasicAuth:
    """
    Build :py:class:`~aiohttp.BasicAuth` credentials for ``host`` from ``netrc_obj``.

    :raises LookupError: if ``netrc_obj`` is :py:data:`None` or if no
            entry is found for the ``host``.
    """
    if netrc_obj is None:
        raise LookupError("No .netrc file found")

    entry = netrc_obj.authenticators(host)
    if entry is None:
        raise LookupError(f"No entry for {host!s} found in the `.netrc` file.")
    login, account, password = entry

    # TODO(PY311): username = login or account
    # Up to python 3.10, account could be None if not specified,
    # and login will be empty string if not specified. From 3.11,
    # login and account will be empty string if not specified.
    if login or account is None:
        username = login
    else:
        username = account

    # TODO(PY311): Remove this, as password will be empty string
    # if not specified
    return BasicAuth(username, "" if password is None else password)
def proxies_from_env() -> Dict[str, ProxyInfo]:
    """Collect proxy settings from the environment.

    Reads ``getproxies()`` for the http, https, ws and wss keys. Proxies
    whose own scheme is https or wss are logged and skipped. When a proxy URL
    carries no credentials, they are looked up in the netrc file by host.
    """
    wanted = ("http", "https", "ws", "wss")
    env_urls = {
        proto: URL(raw) for proto, raw in getproxies().items() if proto in wanted
    }
    netrc_obj = netrc_from_env()
    # Strip credentials from every URL before any of them is processed.
    split = [(proto, strip_auth_from_url(url)) for proto, url in env_urls.items()]

    result: Dict[str, ProxyInfo] = {}
    for proto, (proxy, auth) in split:
        if proxy.scheme in ("https", "wss"):
            client_logger.warning(
                "%s proxies %s are not supported, ignoring", proxy.scheme.upper(), proxy
            )
            continue
        if netrc_obj and auth is None and proxy.host is not None:
            try:
                auth = basicauth_from_netrc(netrc_obj, proxy.host)
            except LookupError:
                # No netrc entry for this host: proceed without credentials.
                auth = None
        result[proto] = ProxyInfo(proxy, auth)
    return result
def get_env_proxy_for_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
    """Look up the environment proxy that applies to ``url``.

    :raises LookupError: if the URL's host is on the proxy bypass list, or
        if the environment defines no proxy for the URL's scheme.
    """
    bypassed = url.host is not None and proxy_bypass(url.host)
    if bypassed:
        raise LookupError(f"Proxying is disallowed for `{url.host!r}`")

    env_proxies = proxies_from_env()
    try:
        info = env_proxies[url.scheme]
    except KeyError:
        raise LookupError(f"No proxies found for `{url!s}` in the env")
    return info.proxy, info.proxy_auth
@attr.s(slots=True, frozen=True, auto_attribs=True)
class MimeType:
    """Immutable record of a parsed MIME type, as built by parse_mimetype."""

    # Part before the "/".
    type: str
    # Part between the "/" and an optional "+".
    subtype: str
    # Part after the "+", empty when there is none.
    suffix: str
    # Read-only multidict of the ";"-separated parameters.
    parameters: "MultiDictProxy[str]"
@functools.lru_cache(maxsize=56)
def parse_mimetype(mimetype: str) -> MimeType:
    """Split a MIME type string into its components.

    mimetype is a MIME type string.

    Returns a MimeType object; an empty input yields a MimeType whose
    fields are all empty.

    Example:

    >>> parse_mimetype('text/html; charset=utf-8')
    MimeType(type='text', subtype='html', suffix='',
             parameters={'charset': 'utf-8'})

    """
    if not mimetype:
        return MimeType(
            type="", subtype="", suffix="", parameters=MultiDictProxy(MultiDict())
        )

    head, *raw_params = mimetype.split(";")

    params: MultiDict[str] = MultiDict()
    for raw_param in raw_params:
        if raw_param:
            name, _, val = raw_param.partition("=")
            # Keys are lower-cased; values lose surrounding spaces and quotes.
            params.add(name.lower().strip(), val.strip(' "'))

    fulltype = head.strip().lower()
    if fulltype == "*":
        # A bare wildcard is shorthand for "any type, any subtype".
        fulltype = "*/*"

    major, _, remainder = fulltype.partition("/")
    minor, _, suffix = remainder.partition("+")

    return MimeType(
        type=major, subtype=minor, suffix=suffix, parameters=MultiDictProxy(params)
    )
@functools.lru_cache(maxsize=56)
def parse_content_type(raw: str) -> Tuple[str, MappingProxyType[str, str]]:
    """Parse the value of a Content-Type header.

    Returns a tuple of the parsed content type and a read-only
    MappingProxyType of its parameters.
    """
    message = HeaderParser().parsestr(f"Content-Type: {raw}")
    parsed_type = message.get_content_type()
    # The first entry of get_params() is the content type itself; only the
    # remaining entries are parameters.
    extra_params = message.get_params(())[1:]
    return parsed_type, MappingProxyType(dict(extra_params))
def guess_filename(obj: Any, default: Optional[str] = None) -> Optional[str]:
    """Return the base name of ``obj.name``, or ``default``.

    ``default`` is returned when ``obj`` has no ``name``, when the name is
    empty or not a string, or when it starts with "<" or ends with ">".
    """
    candidate = getattr(obj, "name", None)
    if not candidate or not isinstance(candidate, str):
        return default
    if candidate[0] == "<" or candidate[-1] == ">":
        return default
    return Path(candidate).name
# Matches every character that needs a backslash escape in a quoted-string.
not_qtext_re = re.compile(r"[^\041\043-\133\135-\176]")
# Characters accepted as quoted-string content: 0x20-0x7E plus TAB.
QCONTENT = set(map(chr, range(0x20, 0x7F))) | {"\t"}


def quoted_string(content: str) -> str:
    """Return 7-bit content as quoted-string.

    Formats content into a quoted-string as defined in RFC5322 for
    Internet Message Format. Notice that this is not the 8-bit HTTP
    format, but the 7-bit email format.

    :raises ValueError: if content holds a character outside QCONTENT.
    """
    if not set(content) < QCONTENT:
        raise ValueError(f"bad content for quoted-string {content!r}")
    # Prefix each character matched by not_qtext_re with a backslash.
    return not_qtext_re.sub(r"\\\g<0>", content)
def content_disposition_header(
    disptype: str, quote_fields: bool = True, _charset: str = "utf-8", **params: str
) -> str:
    """Build the value of a MIME ``Content-Disposition`` header.

    This is the MIME payload Content-Disposition header from RFC 2183
    and RFC 7578 section 4.2, not the HTTP Content-Disposition from
    RFC 6266.

    disptype is a disposition type: inline, attachment, form-data.
    Should be valid extension token (see RFC 2183)

    quote_fields performs value quoting to 7-bit MIME headers
    according to RFC 7578. Set quote_fields to False if the recipient
    can take 8-bit file names and field values.

    _charset specifies the charset to use when quote_fields is True.

    params is a dict with disposition params.

    :raises ValueError: if disptype or a parameter name is empty or
        contains characters outside TOKEN.
    """
    if not disptype or not set(disptype) < TOKEN:
        raise ValueError(f"bad content disposition type {disptype!r}")
    if not params:
        return disptype

    pieces = [disptype]
    for key, val in params.items():
        if not key or not set(key) < TOKEN:
            raise ValueError(f"bad content disposition parameter {key!r}={val!r}")

        if not quote_fields:
            # Only backslashes and double quotes are escaped.
            escaped = val.replace("\\", "\\\\").replace('"', '\\"')
            pieces.append(f'{key}="{escaped}"')
        elif key.lower() == "filename":
            # File names are always percent-encoded.
            encoded = quote(val, "", encoding=_charset)
            pieces.append(f'{key}="{encoded}"')
        else:
            try:
                quoted = quoted_string(val)
            except ValueError:
                # Not representable as a quoted-string: emit the extended
                # ``key*=charset''percent-encoded`` form instead.
                extended = _charset + "''" + quote(val, "", encoding=_charset)
                pieces.append(f"{key}*={extended}")
            else:
                pieces.append(f'{key}="{quoted}"')
    return "; ".join(pieces)
def is_ip_address(host: Optional[str]) -> bool:
    """Tell whether host looks like an IP address.

    This is only a heuristic meant to ensure that a host is not a
    domain name; it does not validate the address.
    """
    if not host:
        return False
    # Only an IPv6 address can contain a colon.
    if ":" in host:
        return True
    # An IPv4 address is all digits once the dots are removed.
    return host.replace(".", "").isdigit()
# One-entry cache: the last whole second formatted, and its string.
_cached_current_datetime: Optional[int] = None
_cached_formatted_datetime = ""


def rfc822_formatted_time() -> str:
    """Return the current GMT time as ``Mon, 01 Jan 2024 00:00:00 GMT``.

    The string is rebuilt only when the whole-second timestamp differs from
    the cached one; otherwise the cached string is returned.
    """
    global _cached_current_datetime, _cached_formatted_datetime

    current = int(time.time())
    if current == _cached_current_datetime:
        return _cached_formatted_datetime

    # Weekday and month names for HTTP date/time formatting;
    # always English!
    # Tuples are constants stored in codeobject!
    day_names = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")
    month_names = (
        "",  # Dummy so we can use 1-based month numbers
        "Jan",
        "Feb",
        "Mar",
        "Apr",
        "May",
        "Jun",
        "Jul",
        "Aug",
        "Sep",
        "Oct",
        "Nov",
        "Dec",
    )

    utc = time.gmtime(current)
    _cached_formatted_datetime = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
        day_names[utc.tm_wday],
        utc.tm_mday,
        month_names[utc.tm_mon],
        utc.tm_year,
        utc.tm_hour,
        utc.tm_min,
        utc.tm_sec,
    )
    _cached_current_datetime = current
    return _cached_formatted_datetime
def _weakref_handle(info: "Tuple[weakref.ref[object], str]") -> None:
    """Call the named method on a weakly referenced object, if it is alive.

    ``info`` is a ``(weak reference, method name)`` pair. Nothing happens
    when the referent is gone, and any Exception raised while looking up or
    calling the method is swallowed.
    """
    ref, name = info
    target = ref()
    if target is None:
        return
    try:
        getattr(target, name)()
    except Exception:
        # Deliberately best-effort: failures are ignored.
        pass
def weakref_handle(
    ob: object,
    name: str,
    timeout: float,
    loop: asyncio.AbstractEventLoop,
    timeout_ceil_threshold: float = 5,
) -> Optional[asyncio.TimerHandle]:
    """Schedule ``ob.<name>()`` after ``timeout`` via a weak reference.

    Returns the timer handle from ``loop.call_at``, or None when ``timeout``
    is None or not greater than zero. The deadline is rounded up to a whole
    number when ``timeout`` is at least ``timeout_ceil_threshold``.
    """
    if timeout is None or not timeout > 0:
        return None

    deadline = loop.time() + timeout
    if timeout >= timeout_ceil_threshold:
        deadline = ceil(deadline)

    # Only a weak reference is handed to the loop.
    return loop.call_at(deadline, _weakref_handle, (weakref.ref(ob), name))
def call_later(
    cb: Callable[[], Any],
    timeout: float,
    loop: asyncio.AbstractEventLoop,
    timeout_ceil_threshold: float = 5,
) -> Optional[asyncio.TimerHandle]:
    """Schedule ``cb`` on ``loop`` after ``timeout``.

    Returns the timer handle from ``loop.call_at``, or None (scheduling
    nothing) when ``timeout`` is None or not positive.
    """
    if timeout is not None and not timeout <= 0:
        deadline = calculate_timeout_when(
            loop.time(), timeout, timeout_ceil_threshold
        )
        return loop.call_at(deadline, cb)
    return None
def calculate_timeout_when(
    loop_time: float,
    timeout: float,
    timeout_ceiling_threshold: float,
) -> float:
    """Compute the deadline at which a timeout should fire.

    The deadline is ``loop_time + timeout``, rounded up to a whole number
    when ``timeout`` is strictly greater than ``timeout_ceiling_threshold``.
    """
    deadline = loop_time + timeout
    return ceil(deadline) if timeout > timeout_ceiling_threshold else deadline
class TimeoutHandle:
    """Timeout handle.

    Holds a list of callbacks that are all invoked, then discarded, when the
    handle itself is called.
    """

    __slots__ = ("_timeout", "_loop", "_ceil_threshold", "_callbacks")

    def __init__(
        self,
        loop: asyncio.AbstractEventLoop,
        timeout: Optional[float],
        ceil_threshold: float = 5,
    ) -> None:
        self._timeout = timeout
        self._loop = loop
        self._ceil_threshold = ceil_threshold
        # Each entry is (callback, positional args, keyword args).
        self._callbacks: List[
            Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]]
        ] = []

    def register(
        self, callback: Callable[..., None], *args: Any, **kwargs: Any
    ) -> None:
        """Remember ``callback`` and its arguments for when the handle fires."""
        entry = (callback, args, kwargs)
        self._callbacks.append(entry)

    def close(self) -> None:
        """Forget every registered callback."""
        self._callbacks.clear()

    def start(self) -> Optional[asyncio.TimerHandle]:
        """Schedule this handle on the loop.

        Returns the loop's timer handle, or None when the timeout is None or
        not greater than zero. The deadline is rounded up to a whole number
        when the timeout is at least the ceil threshold.
        """
        timeout = self._timeout
        if timeout is None or not timeout > 0:
            return None
        deadline = self._loop.time() + timeout
        if timeout >= self._ceil_threshold:
            deadline = ceil(deadline)
        return self._loop.call_at(deadline, self.__call__)

    def timer(self) -> "BaseTimerContext":
        """Return a timer context tied to this handle.

        With a positive timeout this is a TimerContext whose ``timeout``
        method is registered as a callback; otherwise it is a TimerNoop.
        """
        if self._timeout is None or not self._timeout > 0:
            return TimerNoop()
        context = TimerContext(self._loop)
        self.register(context.timeout)
        return context

    def __call__(self) -> None:
        """Run every registered callback, then clear the list."""
        for callback, args, kwargs in self._callbacks:
            try:
                callback(*args, **kwargs)
            except Exception:
                # A failing callback must not stop the remaining ones.
                pass

        self._callbacks.clear()
class BaseTimerContext(ContextManager["BaseTimerContext"]):
    """Base class for the timer context managers in this module."""

    __slots__ = ()

    def assert_timeout(self) -> None:
        """Raise TimeoutError if timeout has been exceeded.

        This base implementation does nothing.
        """
        return None
class TimerNoop(BaseTimerContext):
    """Timer context that does nothing on entry or exit."""

    __slots__ = ()

    def __enter__(self) -> BaseTimerContext:
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        # Returning None lets any in-flight exception propagate.
        return None
class TimerContext(BaseTimerContext):
    """Low resolution timeout context manager"""

    __slots__ = ("_loop", "_tasks", "_cancelled", "_cancelling")

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        # Tasks currently inside the context, most recent last.
        self._tasks: List[asyncio.Task[Any]] = []
        # Set once timeout() has fired.
        self._cancelled = False
        # Value of task.cancelling() captured on entry (Python 3.11+).
        self._cancelling = 0

    def assert_timeout(self) -> None:
        """Raise TimeoutError if timer has already been cancelled."""
        if not self._cancelled:
            return
        raise asyncio.TimeoutError from None

    def __enter__(self) -> BaseTimerContext:
        current = asyncio.current_task(loop=self._loop)
        if current is None:
            raise RuntimeError("Timeout context manager should be used inside a task")

        if sys.version_info >= (3, 11):
            # Remember if the task was already cancelling
            # so when we __exit__ we can decide if we should
            # raise asyncio.TimeoutError or let the cancellation propagate
            self._cancelling = current.cancelling()

        if self._cancelled:
            raise asyncio.TimeoutError from None

        self._tasks.append(current)
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> Optional[bool]:
        entered: Optional[asyncio.Task[Any]] = (
            self._tasks.pop() if self._tasks else None
        )

        timed_out = exc_type is asyncio.CancelledError and self._cancelled
        if not timed_out:
            return None

        assert entered is not None
        # The timeout was hit, and the task was cancelled
        # so we need to uncancel the last task that entered the context manager
        # since the cancellation should not leak out of the context manager
        if sys.version_info >= (3, 11):
            # If the task was already cancelling don't raise
            # asyncio.TimeoutError and instead return None
            # to allow the cancellation to propagate
            if entered.uncancel() > self._cancelling:
                return None
        raise asyncio.TimeoutError from exc_val

    def timeout(self) -> None:
        """Cancel every task inside the context and mark the timer fired."""
        if self._cancelled:
            return
        for task in set(self._tasks):
            task.cancel()

        self._cancelled = True
def ceil_timeout(
    delay: Optional[float], ceil_threshold: float = 5
) -> async_timeout.Timeout:
    """Build a timeout context manager for ``delay``.

    A ``delay`` that is None or not positive yields a timeout with no
    deadline. Otherwise the deadline is the running loop's time plus
    ``delay``, rounded up to a whole number when ``delay`` is strictly
    greater than ``ceil_threshold``.
    """
    if delay is None or delay <= 0:
        return async_timeout.timeout(None)

    deadline = asyncio.get_running_loop().time() + delay
    if delay > ceil_threshold:
        deadline = ceil(deadline)
    return async_timeout.timeout_at(deadline)
class HeadersMixin:
    """Mixin for handling headers."""

    ATTRS = frozenset({"_content_type", "_content_dict", "_stored_content_type"})

    _headers: MultiMapping[str]
    # Parsed content type and parameters; filled in lazily.
    _content_type: Optional[str] = None
    _content_dict: Optional[Dict[str, str]] = None
    # Raw header value the parsed fields were computed from; ``sentinel``
    # means nothing has been parsed yet.
    _stored_content_type: Union[str, None, _SENTINEL] = sentinel

    def _parse_content_type(self, raw: Optional[str]) -> None:
        """Parse ``raw`` and store the resulting type and parameters."""
        self._stored_content_type = raw
        if raw is not None:
            parsed_type, parsed_params = parse_content_type(raw)
            self._content_type = parsed_type
            # _content_dict needs to be mutable so we can update it
            self._content_dict = parsed_params.copy()
            return
        # default value according to RFC 2616
        self._content_type = "application/octet-stream"
        self._content_dict = {}

    @property
    def content_type(self) -> str:
        """The value of content part for Content-Type HTTP header."""
        current = self._headers.get(hdrs.CONTENT_TYPE)
        # Re-parse only when the header differs from what was parsed last.
        if self._stored_content_type != current:
            self._parse_content_type(current)
        assert self._content_type is not None
        return self._content_type

    @property
    def charset(self) -> Optional[str]:
        """The value of charset part for Content-Type HTTP header."""
        current = self._headers.get(hdrs.CONTENT_TYPE)
        # Re-parse only when the header differs from what was parsed last.
        if self._stored_content_type != current:
            self._parse_content_type(current)
        assert self._content_dict is not None
        return self._content_dict.get("charset")

    @property
    def content_length(self) -> Optional[int]:
        """The value of Content-Length HTTP header."""
        value = self._headers.get(hdrs.CONTENT_LENGTH)
        if value is None:
            return None
        return int(value)
def set_result(fut: "asyncio.Future[_T]", result: _T) -> None:
    """Set the future's result unless the future is already done."""
    if fut.done():
        return
    fut.set_result(result)
# Unique default marking "no cause given"; only ever compared by identity.
_EXC_SENTINEL = BaseException()


class ErrorableProtocol(Protocol):
    """Structural type for objects that accept an exception."""

    def set_exception(
        self,
        exc: BaseException,
        exc_cause: BaseException = ...,
    ) -> None: ...  # pragma: no cover


def set_exception(
    fut: "asyncio.Future[_T] | ErrorableProtocol",
    exc: BaseException,
    exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
    """Set future exception.

    If the future is marked as complete, this function is a no-op.

    :param exc_cause: An exception that is a direct cause of ``exc``.
                      Only set if provided.
    """
    if asyncio.isfuture(fut) and fut.done():
        return

    # Attach the cause only when one was passed and it is not ``exc`` itself.
    if exc_cause is not _EXC_SENTINEL and exc_cause is not exc:
        exc.__cause__ = exc_cause

    fut.set_exception(exc)
@functools.total_ordering
class AppKey(Generic[_T]):
    """Keys for static typing support in Application.

    A key's name is prefixed with the name of the module it was created in.
    Keys order by that full name, and sort before objects of other types.
    """

    __slots__ = ("_name", "_t", "__orig_class__")

    # This may be set by Python when instantiating with a generic type. We need to
    # support this, in order to support types that are not concrete classes,
    # like Iterable, which can't be passed as the second parameter to __init__.
    __orig_class__: Type[object]

    def __init__(self, name: str, t: Optional[Type[_T]] = None):
        """Create a key called ``name``, optionally tagged with type ``t``.

        :raises RuntimeError: if no module-level frame is found on the call
            stack, so the module name cannot be determined.
        """
        # Prefix with module name to help deduplicate key names.
        frame = inspect.currentframe()
        while frame:
            if frame.f_code.co_name == "<module>":
                module: str = frame.f_globals["__name__"]
                break
            frame = frame.f_back
        else:
            # Reached when the stack has no "<module>" frame (for example
            # when the key is created inside a thread) or when
            # inspect.currentframe() returned None. Without this branch
            # ``module`` would be unbound and the line below would fail with
            # an opaque UnboundLocalError.
            raise RuntimeError("Failed to get module name.")

        self._name = module + "." + name
        self._t = t

    def __lt__(self, other: object) -> bool:
        if isinstance(other, AppKey):
            return self._name < other._name
        return True  # Order AppKey above other types.

    def __repr__(self) -> str:
        t = self._t
        if t is None:
            # __orig_class__ is only present when the key was instantiated
            # through a subscripted generic; its absence is not an error.
            with suppress(AttributeError):
                # Set to type arg.
                t = get_args(self.__orig_class__)[0]

        if t is None:
            t_repr = "<<Unknown>>"
        elif isinstance(t, type):
            # Builtins are shown bare; other classes get their module path.
            if t.__module__ == "builtins":
                t_repr = t.__qualname__
            else:
                t_repr = f"{t.__module__}.{t.__qualname__}"
        else:
            t_repr = repr(t)
        return f"<AppKey({self._name}, type={t_repr})>"
class ChainMapProxy(Mapping[Union[str, AppKey[Any]], Any]):
    """Read-only view over several mappings, searched in order.

    A lookup returns the value from the first mapping that holds the key.
    Subclassing is forbidden.
    """

    __slots__ = ("_maps",)

    def __init__(self, maps: Iterable[Mapping[Union[str, AppKey[Any]], Any]]) -> None:
        self._maps = tuple(maps)

    def __init_subclass__(cls) -> None:
        raise TypeError(
            f"Inheritance class {cls.__name__} from ChainMapProxy is forbidden"
        )

    @overload  # type: ignore[override]
    def __getitem__(self, key: AppKey[_T]) -> _T: ...

    @overload
    def __getitem__(self, key: str) -> Any: ...

    def __getitem__(self, key: Union[str, AppKey[_T]]) -> Any:
        for candidate in self._maps:
            try:
                return candidate[key]
            except KeyError:
                continue
        raise KeyError(key)

    @overload  # type: ignore[override]
    def get(self, key: AppKey[_T], default: _S) -> Union[_T, _S]: ...

    @overload
    def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]: ...

    @overload
    def get(self, key: str, default: Any = ...) -> Any: ...

    def get(self, key: Union[str, AppKey[_T]], default: Any = None) -> Any:
        try:
            value = self[key]
        except KeyError:
            return default
        return value

    def __len__(self) -> int:
        # Number of distinct keys across all mappings.
        # reuses stored hash values if possible
        distinct_keys = set().union(*self._maps)
        return len(distinct_keys)

    def __iter__(self) -> Iterator[Union[str, AppKey[Any]]]:
        merged: Dict[Union[str, AppKey[Any]], Any] = {}
        # Merge from the last mapping to the first.
        for candidate in self._maps[::-1]:
            # reuses stored hash values if possible
            merged.update(candidate)
        return iter(merged)

    def __contains__(self, key: object) -> bool:
        return any(key in candidate for candidate in self._maps)

    def __bool__(self) -> bool:
        # True when at least one of the mappings is non-empty.
        return any(self._maps)

    def __repr__(self) -> str:
        inner = ", ".join(repr(candidate) for candidate in self._maps)
        return f"ChainMapProxy({inner})"
# https://tools.ietf.org/html/rfc7232#section-2.3
# One or more characters allowed inside an entity tag.
_ETAGC = r"[!\x23-\x7E\x80-\xff]+"
_ETAGC_RE = re.compile(_ETAGC)
# Optional weak marker, then the tag body between double quotes.
_QUOTED_ETAG = '(W/)?"(' + _ETAGC + ')"'
QUOTED_ETAG_RE = re.compile(_QUOTED_ETAG)
# A quoted tag followed by a comma separator or end of input; the final
# alternative captures any single other character.
LIST_QUOTED_ETAG_RE = re.compile("(" + _QUOTED_ETAG + r")(?:\s*,\s*|$)|(.)")

# Wildcard value that stands for any entity tag.
ETAG_ANY = "*"
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ETag:
    """Immutable entity-tag record."""

    # The tag text.
    value: str
    # Weakness flag; False by default.
    is_weak: bool = False
def validate_etag_value(value: str) -> None:
    """Check that ``value`` is usable as an entity tag.

    The wildcard ETAG_ANY is always accepted; anything else must fully match
    the allowed etag character pattern.

    :raises ValueError: if ``value`` is not a valid etag.
    """
    if value == ETAG_ANY or _ETAGC_RE.fullmatch(value):
        return
    raise ValueError(
        f"Value {value!r} is not a valid etag. Maybe it contains '\"'?"
    )
def parse_http_date(date_str: Optional[str]) -> Optional[datetime.datetime]:
    """Turn a date string into a UTC-tagged datetime.

    Returns None when ``date_str`` is None, cannot be parsed, or parses into
    field values that do not form a valid datetime.
    """
    if date_str is None:
        return None
    fields = parsedate(date_str)
    if fields is None:
        return None
    try:
        # Only year..second are used; the result is tagged as UTC.
        return datetime.datetime(*fields[:6], tzinfo=datetime.timezone.utc)
    except ValueError:
        return None
@functools.lru_cache
def must_be_empty_body(method: str, code: int) -> bool:
    """Tell whether a method/status-code pair requires an empty body.

    True when the status code is in EMPTY_BODY_STATUS_CODES, the method is
    in EMPTY_BODY_METHODS, or the code is 2xx and the method is a CONNECT.
    """
    if code in EMPTY_BODY_STATUS_CODES or method in EMPTY_BODY_METHODS:
        return True
    return 200 <= code < 300 and method in hdrs.METH_CONNECT_ALL
def should_remove_content_length(method: str, code: int) -> bool:
    """Tell whether a Content-Length header should be removed.

    This should always be a subset of must_be_empty_body
    """
    # https://www.rfc-editor.org/rfc/rfc9110.html#section-8.6-8
    # https://www.rfc-editor.org/rfc/rfc9110.html#section-15.4.5-4
    if code in EMPTY_BODY_STATUS_CODES:
        return True
    return 200 <= code < 300 and method in hdrs.METH_CONNECT_ALL