Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/aiohttp/helpers.py: 44%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Various helper functions"""
3import asyncio
4import base64
5import binascii
6import contextlib
7import datetime
8import enum
9import functools
10import inspect
11import netrc
12import os
13import platform
14import re
15import sys
16import time
17import warnings
18import weakref
19from collections import namedtuple
20from contextlib import suppress
21from email.parser import HeaderParser
22from email.utils import parsedate
23from math import ceil
24from pathlib import Path
25from types import TracebackType
26from typing import (
27 Any,
28 Callable,
29 ContextManager,
30 Dict,
31 Generator,
32 Generic,
33 Iterable,
34 Iterator,
35 List,
36 Mapping,
37 Optional,
38 Pattern,
39 Protocol,
40 Tuple,
41 Type,
42 TypeVar,
43 Union,
44 get_args,
45 overload,
46)
47from urllib.parse import quote
48from urllib.request import getproxies, proxy_bypass
50import attr
51from multidict import MultiDict, MultiDictProxy, MultiMapping
52from yarl import URL
54from . import hdrs
55from .log import client_logger, internal_logger
# On Python 3.11+ the stdlib ``asyncio.timeout``/``asyncio.timeout_at`` are
# used; older versions use the third-party ``async_timeout`` package. Both are
# bound to the same name so the rest of the module can call them uniformly.
if sys.version_info >= (3, 11):
    import asyncio as async_timeout
else:
    import async_timeout

__all__ = ("BasicAuth", "ChainMapProxy", "ETag")

IS_MACOS = platform.system() == "Darwin"
IS_WINDOWS = platform.system() == "Windows"

PY_310 = sys.version_info >= (3, 10)
PY_311 = sys.version_info >= (3, 11)


_T = TypeVar("_T")
_S = TypeVar("_S")

# Single-member enum used as a "not set" marker that is distinct from None
# (e.g. HeadersMixin._stored_content_type starts as ``sentinel`` while None is
# a legitimate "header absent" value).
_SENTINEL = enum.Enum("_SENTINEL", "sentinel")
sentinel = _SENTINEL.sentinel

# Any non-empty AIOHTTP_NO_EXTENSIONS value keeps the pure-Python helpers
# instead of the optional ``._helpers`` implementations (see ``reify`` below).
NO_EXTENSIONS = bool(os.environ.get("AIOHTTP_NO_EXTENSIONS"))

# Enabled by Python development mode (``-X dev``), or by PYTHONASYNCIODEBUG
# unless environment variables are ignored (``-E``).
DEBUG = sys.flags.dev_mode or (
    not sys.flags.ignore_environment and bool(os.environ.get("PYTHONASYNCIODEBUG"))
)
# Character classes for the HTTP ``token`` grammar (RFC 2616 section 2.2):
# token = 1*<any CHAR except CTLs or separators>.
CHAR = {chr(i) for i in range(0, 128)}
CTL = {chr(i) for i in range(0, 32)} | {
    chr(127),
}
SEPARATORS = {
    "(",
    ")",
    "<",
    ">",
    "@",
    ",",
    ";",
    ":",
    "\\",
    '"',
    "/",
    "[",
    "]",
    "?",
    "=",
    "{",
    "}",
    " ",
    chr(9),
}
# Set *difference*, not symmetric difference: HT (chr(9)) is both a CTL and a
# separator, so ``CHAR ^ CTL ^ SEPARATORS`` would remove it and then add it
# back, wrongly making tab a token character.
TOKEN = CHAR - CTL - SEPARATORS
class noop:
    """Trivial awaitable: its ``__await__`` yields a single ``None`` and finishes."""

    def __await__(self) -> Generator[None, None, None]:
        yield None
class BasicAuth(namedtuple("BasicAuth", ["login", "password", "encoding"])):
    """Http basic authentication helper.

    An immutable ``(login, password, encoding)`` tuple; ``encoding`` is the
    charset used to turn the ``login:password`` pair into bytes.
    """

    def __new__(
        cls, login: str, password: str = "", encoding: str = "latin1"
    ) -> "BasicAuth":
        """Create credentials after validating *login* and *password*.

        :raises ValueError: if *login* or *password* is ``None``, or if
            *login* contains ``":"`` (the Basic scheme uses ``":"`` to separate
            login from password).
        """
        if login is None:
            raise ValueError("None is not allowed as login value")

        if password is None:
            raise ValueError("None is not allowed as password value")

        if ":" in login:
            raise ValueError('A ":" is not allowed in login (RFC 1945#section-11.1)')

        return super().__new__(cls, login, password, encoding)

    @classmethod
    def decode(cls, auth_header: str, encoding: str = "latin1") -> "BasicAuth":
        """Create a BasicAuth object from an Authorization HTTP header.

        :param auth_header: header value such as ``"Basic dXNlcjpwYXNz"``.
        :param encoding: charset used to decode the base64-decoded bytes.
        :raises ValueError: if the header cannot be split into scheme and
            credentials, the scheme is not ``Basic``, the payload is not valid
            base64, or the decoded text has no ``":"``.
        """
        try:
            auth_type, encoded_credentials = auth_header.split(" ", 1)
        except ValueError:
            raise ValueError("Could not parse authorization header.")

        if auth_type.lower() != "basic":
            raise ValueError("Unknown authorization method %s" % auth_type)

        try:
            decoded = base64.b64decode(
                encoded_credentials.encode("ascii"), validate=True
            ).decode(encoding)
        except binascii.Error:
            raise ValueError("Invalid base64 encoding.")

        try:
            # RFC 2617 HTTP Authentication
            # https://www.ietf.org/rfc/rfc2617.txt
            # the colon must be present, but the username and password may be
            # otherwise blank.
            username, password = decoded.split(":", 1)
        except ValueError:
            raise ValueError("Invalid credentials.")

        return cls(username, password, encoding=encoding)

    @classmethod
    def from_url(
        cls, url: "URL", *, encoding: str = "latin1"
    ) -> Optional["BasicAuth"]:
        """Create BasicAuth from url.

        Returns ``None`` when the URL carries no user information; a missing
        password becomes ``""``.

        :raises TypeError: if *url* is not a ``yarl.URL``.
        """
        if not isinstance(url, URL):
            raise TypeError("url should be yarl.URL instance")
        if url.user is None:
            return None
        return cls(url.user, url.password or "", encoding=encoding)

    def encode(self) -> str:
        """Encode credentials as an ``Authorization`` header value.

        The ``login:password`` pair is encoded with ``self.encoding``. The
        base64 output is always ASCII, so it is decoded as ASCII; decoding it
        with ``self.encoding`` would garble encodings that are not
        ASCII-compatible (e.g. UTF-16).
        """
        creds = f"{self.login}:{self.password}".encode(self.encoding)
        return "Basic %s" % base64.b64encode(creds).decode("ascii")
def strip_auth_from_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
    """Split *url* into a credential-free URL and its BasicAuth (if any).

    When the URL has no user part it is returned unchanged with ``None``.
    """
    credentials = BasicAuth.from_url(url)
    if credentials is not None:
        return url.with_user(None), credentials
    return url, None
def netrc_from_env() -> Optional[netrc.netrc]:
    """Load netrc from file.

    The path comes from the NETRC environment variable when it is set;
    otherwise ``_netrc`` (Windows) or ``.netrc`` in the user's home directory
    is used.

    Returns None if the home directory cannot be resolved, or if the file
    cannot be read or parsed. Read failures are only logged when NETRC was
    set or the file appears to exist.
    """
    env_path = os.environ.get("NETRC")

    if env_path is None:
        try:
            home = Path.home()
        except RuntimeError as exc:  # pragma: no cover
            # if pathlib can't resolve home, it may raise a RuntimeError
            client_logger.debug(
                "Could not resolve home directory when "
                "trying to look for .netrc file: %s",
                exc,
            )
            return None
        candidate = home / ("_netrc" if IS_WINDOWS else ".netrc")
    else:
        candidate = Path(env_path)

    try:
        return netrc.netrc(str(candidate))
    except netrc.NetrcParseError as exc:
        client_logger.warning("Could not parse .netrc file: %s", exc)
    except OSError as exc:
        # The file could not be read (missing, permissions, ...). Stay quiet
        # unless the user asked for it explicitly or it seems to exist.
        looks_present = False
        with contextlib.suppress(OSError):
            looks_present = candidate.is_file()
        if env_path or looks_present:
            client_logger.warning("Could not read .netrc file: %s", exc)

    return None
@attr.s(auto_attribs=True, frozen=True, slots=True)
class ProxyInfo:
    """Immutable pairing of a proxy URL with its optional credentials.

    Instances are built by ``proxies_from_env``.
    """

    # As built by proxies_from_env: the proxy URL with any userinfo stripped.
    proxy: URL
    # Credentials taken from the proxy URL or, failing that, from .netrc.
    proxy_auth: Optional[BasicAuth]
def basicauth_from_netrc(netrc_obj: Optional[netrc.netrc], host: str) -> BasicAuth:
    """
    Return :py:class:`~aiohttp.BasicAuth` credentials for ``host`` from ``netrc_obj``.

    The netrc ``login`` is preferred as the username, falling back to
    ``account``; a missing password becomes ``""``.

    :raises LookupError: if ``netrc_obj`` is :py:data:`None` or if no
            entry is found for the ``host``.
    """
    if netrc_obj is None:
        raise LookupError("No .netrc file found")

    entry = netrc_obj.authenticators(host)
    if entry is None:
        raise LookupError(f"No entry for {host!s} found in the `.netrc` file.")

    login, account, password = entry

    # TODO(PY311): username = login or account
    # Up to python 3.10 an unspecified account is None and an unspecified
    # login is ""; from 3.11 both are "".
    if login or account is None:
        username = login
    else:
        username = account

    # TODO(PY311): drop the None check; from 3.11 a missing password is "".
    return BasicAuth(username, "" if password is None else password)
def proxies_from_env() -> Dict[str, ProxyInfo]:
    """Collect proxy settings from the environment, keyed by scheme.

    Only the ``http``, ``https``, ``ws`` and ``wss`` entries of
    ``urllib.request.getproxies()`` are considered. Proxies whose own URL is
    HTTPS/WSS are skipped with a warning. Credentials come from the proxy URL
    or, when it has none, from the .netrc entry for the proxy host.
    """
    candidates = {
        scheme: URL(raw_url)
        for scheme, raw_url in getproxies().items()
        if scheme in ("http", "https", "ws", "wss")
    }
    netrc_obj = netrc_from_env()
    without_auth = {
        scheme: strip_auth_from_url(url) for scheme, url in candidates.items()
    }

    result = {}
    for scheme, (proxy, auth) in without_auth.items():
        if proxy.scheme in ("https", "wss"):
            client_logger.warning(
                "%s proxies %s are not supported, ignoring", proxy.scheme.upper(), proxy
            )
            continue
        if auth is None and netrc_obj and proxy.host is not None:
            try:
                auth = basicauth_from_netrc(netrc_obj, proxy.host)
            except LookupError:
                auth = None
        result[scheme] = ProxyInfo(proxy, auth)
    return result
def current_task(
    loop: Optional[asyncio.AbstractEventLoop] = None,
) -> "Optional[asyncio.Task[Any]]":
    """Return the task currently running in *loop* (default: the running loop)."""
    running = asyncio.current_task(loop)
    return running
def get_running_loop(
    loop: Optional[asyncio.AbstractEventLoop] = None,
) -> asyncio.AbstractEventLoop:
    """Return *loop*, or ``asyncio.get_event_loop()`` when *loop* is None.

    If the resulting loop is not running (the caller is not inside an async
    function), a DeprecationWarning is emitted. In loop debug mode the same
    message is also logged together with the current stack.
    """
    current = asyncio.get_event_loop() if loop is None else loop
    if current.is_running():
        return current

    warnings.warn(
        "The object should be created within an async function",
        DeprecationWarning,
        stacklevel=3,
    )
    if current.get_debug():
        internal_logger.warning(
            "The object should be created within an async function", stack_info=True
        )
    return current
def isasyncgenfunction(obj: Any) -> bool:
    """Return whether *obj* is an async generator function.

    Falls back to ``False`` when :mod:`inspect` has no ``isasyncgenfunction``.
    """
    checker = getattr(inspect, "isasyncgenfunction", None)
    if checker is None:
        return False
    return checker(obj)  # type: ignore[no-any-return]
def get_env_proxy_for_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
    """Get a permitted proxy for the given URL from the env.

    :raises LookupError: if the host is excluded by the proxy-bypass rules or
        no proxy is configured for the URL's scheme.
    """
    host = url.host
    if host is not None and proxy_bypass(host):
        raise LookupError(f"Proxying is disallowed for `{host!r}`")

    configured = proxies_from_env()
    try:
        info = configured[url.scheme]
    except KeyError:
        raise LookupError(f"No proxies found for `{url!s}` in the env")
    return info.proxy, info.proxy_auth
@attr.s(auto_attribs=True, frozen=True, slots=True)
class MimeType:
    """Immutable components of a MIME type, as returned by ``parse_mimetype``."""

    # As produced by parse_mimetype: lower-cased major type, e.g. "text".
    type: str
    # Lower-cased subtype without any "+suffix", e.g. "html".
    subtype: str
    # Text after "+" in the subtype, e.g. "json" for "application/ld+json".
    suffix: str
    # Read-only parameters; names lower-cased, values stripped of spaces/quotes.
    parameters: "MultiDictProxy[str]"
@functools.lru_cache(maxsize=56)
def parse_mimetype(mimetype: str) -> MimeType:
    """Parses a MIME type into its components.

    mimetype is a MIME type string.

    Returns a MimeType object. For example, ``'text/html; charset=utf-8'``
    yields type ``'text'``, subtype ``'html'``, suffix ``''`` and parameters
    containing ``charset='utf-8'``.

    Type, subtype and parameter names are lower-cased. Parameter values have
    surrounding spaces and double quotes stripped. Empty or whitespace-only
    ``;``-separated segments are ignored. A bare ``*`` means ``*/*``.

    Results are cached with ``lru_cache``, so the same MimeType instance may
    be shared between callers; its ``parameters`` is a read-only proxy.
    """
    if not mimetype:
        return MimeType(
            type="", subtype="", suffix="", parameters=MultiDictProxy(MultiDict())
        )

    parts = mimetype.split(";")
    params: MultiDict[str] = MultiDict()
    for item in parts[1:]:
        # Skip empty segments, including whitespace-only ones such as the one
        # left by a trailing "; ". Otherwise they would add an empty-named
        # parameter, unlike a bare trailing ";".
        if not item.strip():
            continue
        key, _, value = item.partition("=")
        params.add(key.lower().strip(), value.strip(' "'))

    fulltype = parts[0].strip().lower()
    if fulltype == "*":
        fulltype = "*/*"

    mtype, _, stype = fulltype.partition("/")
    stype, _, suffix = stype.partition("+")

    return MimeType(
        type=mtype, subtype=stype, suffix=suffix, parameters=MultiDictProxy(params)
    )
def guess_filename(obj: Any, default: Optional[str] = None) -> Optional[str]:
    """Return the base name of ``obj.name``, or *default* if it is unusable.

    Names that are missing, empty, not strings, or start with ``<`` or end
    with ``>`` (pseudo-names such as ``<stdin>``) yield *default*.
    """
    candidate = getattr(obj, "name", None)
    if not candidate or not isinstance(candidate, str):
        return default
    if candidate.startswith("<") or candidate.endswith(">"):
        return default
    return Path(candidate).name
# Matches characters that must be backslash-escaped inside an RFC 5322
# quoted-string: anything outside qtext (%d33 / %d35-91 / %d93-126). Among
# the allowed content characters these are SP, HT, '"' and '\'.
not_qtext_re = re.compile(r"[^\041\043-\133\135-\176]")
# Characters permitted in the content: printable US-ASCII plus HT.
QCONTENT = {chr(i) for i in range(0x20, 0x7F)} | {"\t"}


def quoted_string(content: str) -> str:
    """Return 7-bit content as quoted-string.

    Format content into a quoted-string as defined in RFC5322 for
    Internet Message Format. Notice that this is not the 8-bit HTTP
    format, but the 7-bit email format. Content must be in usascii or
    a ValueError is raised.

    The surrounding double quotes are not added; only the characters that
    need it are backslash-escaped.
    """
    # Subset test (``<=``), not strict subset: content that happens to use
    # every allowed character must still be accepted.
    if not QCONTENT.issuperset(content):
        raise ValueError(f"bad content for quoted-string {content!r}")
    return not_qtext_re.sub(lambda x: "\\" + x.group(0), content)
def content_disposition_header(
    disptype: str, quote_fields: bool = True, _charset: str = "utf-8", **params: str
) -> str:
    """Sets ``Content-Disposition`` header for MIME.

    This is the MIME payload Content-Disposition header from RFC 2183
    and RFC 7578 section 4.2, not the HTTP Content-Disposition from
    RFC 6266.

    disptype is a disposition type: inline, attachment, form-data.
    Should be valid extension token (see RFC 2183)

    quote_fields performs value quoting to 7-bit MIME headers
    according to RFC 7578. Set quote_fields to False if recipient
    can take 8-bit file names and field values.

    _charset specifies the charset to use when quote_fields is True.

    params is a dict with disposition params.

    :raises ValueError: if *disptype* or a parameter name is empty or
        contains characters outside ``TOKEN``.
    """
    # Subset test (``<=``), not strict subset: a value may legitimately use
    # every token character.
    if not disptype or not TOKEN.issuperset(disptype):
        raise ValueError(f"bad content disposition type {disptype!r}")

    value = disptype
    if params:
        lparams = []
        for key, val in params.items():
            if not key or not TOKEN.issuperset(key):
                raise ValueError(f"bad content disposition parameter {key!r}={val!r}")
            if quote_fields:
                if key.lower() == "filename":
                    # Percent-encode the whole file name (safe="" so "/" is
                    # encoded too) and put it in a plain quoted value.
                    qval = quote(val, "", encoding=_charset)
                    lparams.append((key, '"%s"' % qval))
                else:
                    try:
                        qval = quoted_string(val)
                    except ValueError:
                        # Not representable as a 7-bit quoted-string: use the
                        # extended form key*=charset''percent-encoded-value.
                        qval = "".join(
                            (_charset, "''", quote(val, "", encoding=_charset))
                        )
                        lparams.append((key + "*", qval))
                    else:
                        lparams.append((key, '"%s"' % qval))
            else:
                # The recipient accepts 8-bit values: only escape the
                # characters special inside a quoted value.
                qval = val.replace("\\", "\\\\").replace('"', '\\"')
                lparams.append((key, '"%s"' % qval))
        sparams = "; ".join("=".join(pair) for pair in lparams)
        value = "; ".join((value, sparams))
    return value
class _TSelf(Protocol, Generic[_T]):
    """Structural type for instances usable with ``reify``: they expose ``_cache``."""

    # reify.__get__ stores computed values here, keyed by the wrapped method's name.
    _cache: Dict[str, _T]
class reify(Generic[_T]):
    """Use as a class method decorator.

    Works like the built-in ``@property``, except that the first computed
    value is stored in the instance's ``_cache`` dict under the method's name
    and served from there afterwards. Because it defines ``__set__`` (which
    always raises), it is a data descriptor and the attribute is read-only.
    """

    def __init__(self, wrapped: Callable[..., _T]) -> None:
        self.wrapped = wrapped
        self.__doc__ = wrapped.__doc__
        self.name = wrapped.__name__

    def __get__(self, inst: _TSelf[_T], owner: Optional[Type[Any]] = None) -> _T:
        try:
            cache = inst._cache
        except AttributeError:
            # Looked up on the class itself (inst is None): return the descriptor.
            if inst is None:
                return self
            raise
        key = self.name
        try:
            return cache[key]
        except KeyError:
            value = self.wrapped(inst)
            cache[key] = value
            return value

    def __set__(self, inst: _TSelf[_T], value: _T) -> None:
        raise AttributeError("reified property is read-only")
# Keep a reference to the pure-Python implementation whichever one ends up
# exported as ``reify``.
reify_py = reify

try:
    from ._helpers import reify as reify_c

    # Prefer the ``._helpers`` implementation unless AIOHTTP_NO_EXTENSIONS is set.
    if not NO_EXTENSIONS:
        reify = reify_c  # type: ignore[misc,assignment]
except ImportError:
    # ``._helpers`` is unavailable (presumably not built); keep the Python version.
    pass
_ipv4_pattern = (
    r"^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}"
    r"(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$"
)
_ipv6_pattern = (
    r"^(?:(?:(?:[A-F0-9]{1,4}:){6}|(?=(?:[A-F0-9]{0,4}:){0,6}"
    r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}$)(([0-9A-F]{1,4}:){0,5}|:)"
    r"((:[0-9A-F]{1,4}){1,5}:|:)|::(?:[A-F0-9]{1,4}:){5})"
    r"(?:(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])\.){3}"
    r"(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])|(?:[A-F0-9]{1,4}:){7}"
    r"[A-F0-9]{1,4}|(?=(?:[A-F0-9]{0,4}:){0,7}[A-F0-9]{0,4}$)"
    r"(([0-9A-F]{1,4}:){1,7}|:)((:[0-9A-F]{1,4}){1,7}|:)|(?:[A-F0-9]{1,4}:){7}"
    r":|:(:[A-F0-9]{1,4}){7})$"
)
_ipv4_regex = re.compile(_ipv4_pattern)
_ipv6_regex = re.compile(_ipv6_pattern, flags=re.IGNORECASE)
_ipv4_regexb = re.compile(_ipv4_pattern.encode("ascii"))
_ipv6_regexb = re.compile(_ipv6_pattern.encode("ascii"), flags=re.IGNORECASE)


def _is_ip_address(
    regex: Pattern[str],
    regexb: Pattern[bytes],
    host: Optional[Union[str, bytes, bytearray, memoryview]],
) -> bool:
    """Return whether *host* matches *regex* (str) or *regexb* (bytes-like).

    ``fullmatch`` is used instead of ``match``. The patterns end in ``$``,
    which also matches just before a trailing newline, so ``match`` would
    wrongly accept an address followed by a newline.

    :raises TypeError: if *host* is not None, str or bytes-like.
    """
    if host is None:
        return False
    if isinstance(host, str):
        return regex.fullmatch(host) is not None
    elif isinstance(host, (bytes, bytearray, memoryview)):
        return regexb.fullmatch(host) is not None
    else:
        raise TypeError(f"{host} [{type(host)}] is not a str or bytes")


is_ipv4_address = functools.partial(_is_ip_address, _ipv4_regex, _ipv4_regexb)
is_ipv6_address = functools.partial(_is_ip_address, _ipv6_regex, _ipv6_regexb)


def is_ip_address(host: Optional[Union[str, bytes, bytearray, memoryview]]) -> bool:
    """Return whether *host* is a literal IPv4 or IPv6 address (None -> False)."""
    return is_ipv4_address(host) or is_ipv6_address(host)
_cached_current_datetime: Optional[int] = None
_cached_formatted_datetime = ""


def rfc822_formatted_time() -> str:
    """Return the current UTC time formatted like ``Sun, 06 Nov 1994 08:49:37 GMT``.

    The string is rebuilt only when the whole-second value of ``time.time()``
    changes; otherwise the cached string is returned. Day and month names are
    hard-coded in English and do not depend on the locale.
    """
    global _cached_current_datetime
    global _cached_formatted_datetime

    now = int(time.time())
    if now == _cached_current_datetime:
        return _cached_formatted_datetime

    # Tuples of constants are stored in the code object, so building them is free.
    weekdays = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")
    months = (
        "",  # placeholder so that 1-based month numbers index directly
        "Jan",
        "Feb",
        "Mar",
        "Apr",
        "May",
        "Jun",
        "Jul",
        "Aug",
        "Sep",
        "Oct",
        "Nov",
        "Dec",
    )

    utc = time.gmtime(now)
    _cached_formatted_datetime = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
        weekdays[utc.tm_wday],
        utc.tm_mday,
        months[utc.tm_mon],
        utc.tm_year,
        utc.tm_hour,
        utc.tm_min,
        utc.tm_sec,
    )
    _cached_current_datetime = now
    return _cached_formatted_datetime
592def _weakref_handle(info: "Tuple[weakref.ref[object], str]") -> None:
593 ref, name = info
594 ob = ref()
595 if ob is not None:
596 with suppress(Exception):
597 getattr(ob, name)()
def weakref_handle(
    ob: object,
    name: str,
    timeout: float,
    loop: asyncio.AbstractEventLoop,
    timeout_ceil_threshold: float = 5,
) -> Optional[asyncio.TimerHandle]:
    """Schedule ``ob.<name>()`` on *loop* after *timeout* seconds, via a weak reference.

    Only a weak reference to *ob* is held, so the timer does not keep it
    alive. Timeouts of at least *timeout_ceil_threshold* seconds are rounded
    up to the next whole second of loop time. Returns None (and schedules
    nothing) when *timeout* is None or not positive.
    """
    if timeout is None or not (timeout > 0):
        return None
    deadline = loop.time() + timeout
    if timeout >= timeout_ceil_threshold:
        deadline = ceil(deadline)
    target = (weakref.ref(ob), name)
    return loop.call_at(deadline, _weakref_handle, target)
def call_later(
    cb: Callable[[], Any],
    timeout: float,
    loop: asyncio.AbstractEventLoop,
    timeout_ceil_threshold: float = 5,
) -> Optional[asyncio.TimerHandle]:
    """Schedule *cb* to run on *loop* after *timeout* seconds.

    Timeouts of at least *timeout_ceil_threshold* seconds (``>=``, the same
    rule used by ``weakref_handle`` and ``TimeoutHandle.start``) have their
    deadline rounded up to the next whole second of loop time.

    Returns the timer handle, or None (nothing scheduled) when *timeout* is
    None or not positive.
    """
    if timeout is not None and timeout > 0:
        when = loop.time() + timeout
        if timeout >= timeout_ceil_threshold:
            when = ceil(when)
        return loop.call_at(when, cb)
    return None
class TimeoutHandle:
    """Timeout handle

    Collects callbacks and, once ``start()`` has scheduled it, runs them all
    when the configured timeout elapses on the loop.
    """

    def __init__(
        self,
        loop: asyncio.AbstractEventLoop,
        timeout: Optional[float],
        ceil_threshold: float = 5,
    ) -> None:
        self._timeout = timeout
        self._loop = loop
        self._ceil_threshold = ceil_threshold
        self._callbacks: List[
            Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]]
        ] = []

    def register(
        self, callback: Callable[..., None], *args: Any, **kwargs: Any
    ) -> None:
        """Queue ``callback(*args, **kwargs)`` to run when the timeout fires."""
        entry = (callback, args, kwargs)
        self._callbacks.append(entry)

    def close(self) -> None:
        """Drop every registered callback."""
        self._callbacks.clear()

    def start(self) -> Optional[asyncio.Handle]:
        """Schedule this handle on the loop.

        Returns the timer handle, or None when the timeout is None or not
        positive. Timeouts of at least the ceil threshold are rounded up to
        the next whole second of loop time.
        """
        delay = self._timeout
        if delay is None or not (delay > 0):
            return None
        deadline = self._loop.time() + delay
        if delay >= self._ceil_threshold:
            deadline = ceil(deadline)
        return self._loop.call_at(deadline, self.__call__)

    def timer(self) -> "BaseTimerContext":
        """Return a TimerContext wired to this handle, or a TimerNoop if no timeout."""
        if self._timeout is not None and self._timeout > 0:
            context = TimerContext(self._loop)
            self.register(context.timeout)
            return context
        return TimerNoop()

    def __call__(self) -> None:
        # Run every callback; one failing must not prevent the others.
        for callback, args, kwargs in self._callbacks:
            with suppress(Exception):
                callback(*args, **kwargs)

        self._callbacks.clear()
class BaseTimerContext(ContextManager["BaseTimerContext"]):
    """Common base of the timer context managers returned by ``TimeoutHandle.timer``."""

    def assert_timeout(self) -> None:
        """Raise TimeoutError if timeout has been exceeded."""
        # The base version does nothing; TimerContext overrides it.
class TimerNoop(BaseTimerContext):
    """Timer context that never times out; entering and leaving it does nothing."""

    def __enter__(self) -> BaseTimerContext:
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        # Returning None means exceptions are never suppressed.
        return None
class TimerContext(BaseTimerContext):
    """Low resolution timeout context manager

    ``timeout()`` cancels every task currently inside the context. When the
    resulting CancelledError leaves the ``with`` block it is turned into
    ``asyncio.TimeoutError``.
    """

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        self._tasks: List[asyncio.Task[Any]] = []
        self._cancelled = False
        # ``Task.cancelling()`` of the most recently entered task, recorded on
        # Python 3.11+ so __exit__ can tell our cancellation from others.
        self._cancelling = 0

    def assert_timeout(self) -> None:
        """Raise TimeoutError if timer has already been cancelled."""
        if self._cancelled:
            raise asyncio.TimeoutError from None

    def __enter__(self) -> BaseTimerContext:
        task = current_task(loop=self._loop)

        if task is None:
            raise RuntimeError(
                "Timeout context manager should be used " "inside a task"
            )

        if sys.version_info >= (3, 11):
            # Remember how many cancellation requests were already pending.
            self._cancelling = task.cancelling()

        if self._cancelled:
            raise asyncio.TimeoutError from None

        self._tasks.append(task)
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> Optional[bool]:
        enter_task: Optional["asyncio.Task[Any]"] = None
        if self._tasks:
            enter_task = self._tasks.pop()

        if exc_type is asyncio.CancelledError and self._cancelled:
            if enter_task is not None and sys.version_info >= (3, 11):
                # Withdraw the cancel() issued by timeout(); otherwise the
                # task stays in the "cancelling" state after we convert the
                # error, which breaks later asyncio.timeout()/TaskGroup use.
                # If requests remain above the count seen on entry, someone
                # else also cancelled the task: let CancelledError propagate.
                if enter_task.uncancel() > self._cancelling:
                    return None
            raise asyncio.TimeoutError from None
        return None

    def timeout(self) -> None:
        """Cancel every task inside the context (only the first call has effect)."""
        if not self._cancelled:
            for task in set(self._tasks):
                task.cancel()

            self._cancelled = True
def ceil_timeout(
    delay: Optional[float], ceil_threshold: float = 5
) -> async_timeout.Timeout:
    """Return a timeout context manager expiring *delay* seconds from now.

    A *delay* of None or not positive yields a timeout that never expires.
    Delays of at least *ceil_threshold* seconds (``>=``, the same rule as
    ``weakref_handle`` and ``TimeoutHandle.start``) have their deadline
    rounded up to the next whole second of loop time.
    """
    if delay is None or delay <= 0:
        return async_timeout.timeout(None)

    loop = get_running_loop()
    now = loop.time()
    when = now + delay
    if delay >= ceil_threshold:
        when = ceil(when)
    return async_timeout.timeout_at(when)
class HeadersMixin:
    """Content-Type / Content-Length accessors for classes providing ``_headers``.

    The parsed Content-Type is cached and only re-parsed when the raw header
    value differs from the one parsed last time.
    """

    ATTRS = frozenset(["_content_type", "_content_dict", "_stored_content_type"])

    _headers: MultiMapping[str]

    _content_type: Optional[str] = None
    _content_dict: Optional[Dict[str, str]] = None
    # ``sentinel`` means "never parsed"; None is a real value (header absent).
    _stored_content_type: Union[str, None, _SENTINEL] = sentinel

    def _parse_content_type(self, raw: Optional[str]) -> None:
        self._stored_content_type = raw
        if raw is not None:
            message = HeaderParser().parsestr("Content-Type: " + raw)
            self._content_type = message.get_content_type()
            # get_params() repeats the content type as its first pair; skip it.
            self._content_dict = dict(message.get_params(())[1:])
            return
        # default value according to RFC 2616
        self._content_type = "application/octet-stream"
        self._content_dict = {}

    @property
    def content_type(self) -> str:
        """The value of content part for Content-Type HTTP header."""
        header = self._headers.get(hdrs.CONTENT_TYPE)
        if self._stored_content_type != header:
            self._parse_content_type(header)
        return self._content_type  # type: ignore[return-value]

    @property
    def charset(self) -> Optional[str]:
        """The value of charset part for Content-Type HTTP header."""
        header = self._headers.get(hdrs.CONTENT_TYPE)
        if self._stored_content_type != header:
            self._parse_content_type(header)
        return self._content_dict.get("charset")  # type: ignore[union-attr]

    @property
    def content_length(self) -> Optional[int]:
        """The value of Content-Length HTTP header (ValueError if not an integer)."""
        raw_length = self._headers.get(hdrs.CONTENT_LENGTH)
        return None if raw_length is None else int(raw_length)
def set_result(fut: "asyncio.Future[_T]", result: _T) -> None:
    """Set *result* on *fut* unless the future is already done (then no-op)."""
    if fut.done():
        return
    fut.set_result(result)
# Private default for ``exc_cause`` that no caller passes by accident; it
# means "no cause supplied".
_EXC_SENTINEL = BaseException()


class ErrorableProtocol(Protocol):
    """Anything exposing an asyncio-style ``set_exception`` method."""

    def set_exception(
        self,
        exc: BaseException,
        exc_cause: BaseException = ...,
    ) -> None:
        ...  # pragma: no cover


def set_exception(
    fut: "asyncio.Future[_T] | ErrorableProtocol",
    exc: BaseException,
    exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
    """Set future exception.

    If the future is marked as complete, this function is a no-op.

    :param exc_cause: An exception that is a direct cause of ``exc``.
                      Only set if provided.
    """
    if asyncio.isfuture(fut) and fut.done():
        return

    # Link the cause unless none was given or it would be exc itself.
    if exc_cause is not _EXC_SENTINEL and exc_cause is not exc:
        exc.__cause__ = exc_cause

    fut.set_exception(exc)
@functools.total_ordering
class AppKey(Generic[_T]):
    """Keys for static typing support in Application.

    The key's name is prefixed with the name of the module that created it.
    Ordering compares names, and AppKeys sort before values of other types.
    """

    __slots__ = ("_name", "_t", "__orig_class__")

    # This may be set by Python when instantiating with a generic type. We need to
    # support this, in order to support types that are not concrete classes,
    # like Iterable, which can't be passed as the second parameter to __init__.
    __orig_class__: Type[object]

    def __init__(self, name: str, t: Optional[Type[_T]] = None):
        # Prefix with module name to help deduplicate key names.
        # Walk outwards to the nearest module-level frame. There may be none
        # (e.g. a key created inside a thread's target function, or
        # currentframe() returning None on interpreters without frame
        # support), so start from a placeholder instead of leaving ``module``
        # unbound. Globals may also lack ``__name__`` (exec'd code).
        module: str = "<unknown>"
        frame = inspect.currentframe()
        while frame is not None:
            if frame.f_code.co_name == "<module>":
                module = frame.f_globals.get("__name__", module)
                break
            frame = frame.f_back

        self._name = module + "." + name
        self._t = t

    def __lt__(self, other: object) -> bool:
        if isinstance(other, AppKey):
            return self._name < other._name
        return True  # Order AppKey above other types.

    def __repr__(self) -> str:
        t = self._t
        if t is None:
            with suppress(AttributeError):
                # Set to type arg.
                t = get_args(self.__orig_class__)[0]

        if t is None:
            t_repr = "<<Unknown>>"
        elif isinstance(t, type):
            if t.__module__ == "builtins":
                t_repr = t.__qualname__
            else:
                t_repr = f"{t.__module__}.{t.__qualname__}"
        else:
            t_repr = repr(t)
        return f"<AppKey({self._name}, type={t_repr})>"
class ChainMapProxy(Mapping[Union[str, AppKey[Any]], Any]):
    """Read-only view over a sequence of mappings, searched in order.

    A lookup returns the value from the first mapping containing the key.
    Subclassing is forbidden.
    """

    __slots__ = ("_maps",)

    def __init__(self, maps: Iterable[Mapping[Union[str, AppKey[Any]], Any]]) -> None:
        self._maps = tuple(maps)

    def __init_subclass__(cls) -> None:
        raise TypeError(
            f"Inheritance class {cls.__name__} from ChainMapProxy is forbidden"
        )

    @overload  # type: ignore[override]
    def __getitem__(self, key: AppKey[_T]) -> _T:
        ...

    @overload
    def __getitem__(self, key: str) -> Any:
        ...

    def __getitem__(self, key: Union[str, AppKey[_T]]) -> Any:
        for layer in self._maps:
            try:
                return layer[key]
            except KeyError:
                continue
        raise KeyError(key)

    @overload  # type: ignore[override]
    def get(self, key: AppKey[_T], default: _S) -> Union[_T, _S]:
        ...

    @overload
    def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]:
        ...

    @overload
    def get(self, key: str, default: Any = ...) -> Any:
        ...

    def get(self, key: Union[str, AppKey[_T]], default: Any = None) -> Any:
        try:
            found = self[key]
        except KeyError:
            return default
        return found

    def __len__(self) -> int:
        # Count distinct keys; set.update can reuse hashes cached by dict/set layers.
        distinct = set()
        for layer in self._maps:
            distinct.update(layer)
        return len(distinct)

    def __iter__(self) -> Iterator[Union[str, AppKey[Any]]]:
        merged: Dict[Union[str, AppKey[Any]], Any] = {}
        # Apply lower-priority layers first so key order matches the original.
        for layer in self._maps[::-1]:
            merged.update(layer)
        return iter(merged)

    def __contains__(self, key: object) -> bool:
        for layer in self._maps:
            if key in layer:
                return True
        return False

    def __bool__(self) -> bool:
        for layer in self._maps:
            if layer:
                return True
        return False

    def __repr__(self) -> str:
        content = ", ".join(repr(layer) for layer in self._maps)
        return f"ChainMapProxy({content})"
# https://tools.ietf.org/html/rfc7232#section-2.3
# etagc = "!" / %x23-7E / obs-text (%x80-FF), i.e. any of these except DQUOTE.
_ETAGC = r"[!\x23-\x7E\x80-\xff]+"
_ETAGC_RE = re.compile(_ETAGC)
# Optional weak prefix "W/" followed by the opaque tag in double quotes.
_QUOTED_ETAG = rf'(W/)?"({_ETAGC})"'
QUOTED_ETAG_RE = re.compile(_QUOTED_ETAG)
# Scans a comma-separated list of entity-tags. The trailing ``(.)``
# alternative captures any character that does not start a valid
# entity-tag, presumably so callers can detect malformed lists.
LIST_QUOTED_ETAG_RE = re.compile(rf"({_QUOTED_ETAG})(?:\s*,\s*|$)|(.)")

# Wildcard value "*" (accepted by validate_etag_value).
ETAG_ANY = "*"


@attr.s(auto_attribs=True, frozen=True, slots=True)
class ETag:
    """An entity-tag value plus whether it is weak (``W/`` prefix)."""

    value: str  # presumably the opaque tag without its surrounding quotes
    is_weak: bool = False
def validate_etag_value(value: str) -> None:
    """Raise ValueError unless *value* is ``*`` or made only of etagc characters."""
    if value == ETAG_ANY:
        return
    if _ETAGC_RE.fullmatch(value) is None:
        raise ValueError(
            f"Value {value!r} is not a valid etag. Maybe it contains '\"'?"
        )
def parse_http_date(date_str: Optional[str]) -> Optional[datetime.datetime]:
    """Process a date string, return a datetime object

    The result is a UTC-aware datetime. None is returned when *date_str* is
    None, cannot be parsed, or describes an impossible date.
    """
    if date_str is None:
        return None
    parsed = parsedate(date_str)
    if parsed is None:
        return None
    try:
        return datetime.datetime(*parsed[:6], tzinfo=datetime.timezone.utc)
    except ValueError:
        return None
def must_be_empty_body(method: str, code: int) -> bool:
    """Check if a request must return an empty body."""
    if status_code_must_be_empty_body(code) or method_must_be_empty_body(method):
        return True
    # A 2xx response to CONNECT switches the connection to a tunnel:
    # https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.2
    return 200 <= code < 300 and method.upper() == hdrs.METH_CONNECT
def method_must_be_empty_body(method: str) -> bool:
    """Check if a method must return an empty body."""
    # Responses to HEAD never carry a body:
    # https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
    # https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.2
    normalized = method.upper()
    return normalized == hdrs.METH_HEAD
def status_code_must_be_empty_body(code: int) -> bool:
    """Check if a status code must return an empty body."""
    # 1xx, 204 and 304 responses never carry a body:
    # https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
    if 100 <= code < 200:
        return True
    return code in (204, 304)
def should_remove_content_length(method: str, code: int) -> bool:
    """Check if a Content-Length header should be removed.

    This should always be a subset of must_be_empty_body
    """
    # https://www.rfc-editor.org/rfc/rfc9110.html#section-8.6-8
    # https://www.rfc-editor.org/rfc/rfc9110.html#section-15.4.5-4
    # Same status-code rule as status_code_must_be_empty_body (1xx, 204, 304).
    if status_code_must_be_empty_body(code):
        return True
    return 200 <= code < 300 and method.upper() == hdrs.METH_CONNECT