Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/helpers.py: 39%
554 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:52 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:52 +0000
1"""Various helper functions"""
3import asyncio
4import base64
5import binascii
6import dataclasses
7import datetime
8import enum
9import functools
10import inspect
11import netrc
12import os
13import platform
14import re
15import sys
16import time
17import warnings
18import weakref
19from collections import namedtuple
20from contextlib import suppress
21from email.parser import HeaderParser
22from email.utils import parsedate
23from http.cookies import SimpleCookie
24from math import ceil
25from pathlib import Path
26from types import TracebackType
27from typing import (
28 Any,
29 Callable,
30 ContextManager,
31 Dict,
32 Generator,
33 Generic,
34 Iterable,
35 Iterator,
36 List,
37 Mapping,
38 Optional,
39 Pattern,
40 Tuple,
41 Type,
42 TypeVar,
43 Union,
44 overload,
45)
46from urllib.parse import quote
47from urllib.request import getproxies, proxy_bypass
49import async_timeout
50from multidict import CIMultiDict, MultiDict, MultiDictProxy
51from typing_extensions import Protocol, final
52from yarl import URL
54from . import hdrs
55from .log import client_logger
56from .typedefs import PathLike # noqa
58if sys.version_info >= (3, 8):
59 from typing import get_args
60else:
61 from typing_extensions import get_args
63__all__ = ("BasicAuth", "ChainMapProxy", "ETag")
65PY_38 = sys.version_info >= (3, 8)
66PY_310 = sys.version_info >= (3, 10)
68COOKIE_MAX_LENGTH = 4096
70_T = TypeVar("_T")
71_S = TypeVar("_S")
73_SENTINEL = enum.Enum("_SENTINEL", "sentinel")
74sentinel = _SENTINEL.sentinel
76NO_EXTENSIONS = bool(os.environ.get("AIOHTTP_NO_EXTENSIONS"))
78DEBUG = sys.flags.dev_mode or (
79 not sys.flags.ignore_environment and bool(os.environ.get("PYTHONASYNCIODEBUG"))
80)
83CHAR = {chr(i) for i in range(0, 128)}
84CTL = {chr(i) for i in range(0, 32)} | {
85 chr(127),
86}
87SEPARATORS = {
88 "(",
89 ")",
90 "<",
91 ">",
92 "@",
93 ",",
94 ";",
95 ":",
96 "\\",
97 '"',
98 "/",
99 "[",
100 "]",
101 "?",
102 "=",
103 "{",
104 "}",
105 " ",
106 chr(9),
107}
108TOKEN = CHAR ^ CTL ^ SEPARATORS
class noop:
    """Awaitable that finishes immediately with a result of None."""

    def __await__(self) -> Generator[None, None, None]:
        # Being a generator, this hands control back to the event loop once
        # before the ``await`` expression completes.
        yield None
if PY_38:
    iscoroutinefunction = asyncio.iscoroutinefunction
else:

    def iscoroutinefunction(func: Any) -> bool:  # type: ignore[misc]
        """Return True if *func* (possibly wrapped in partials) is a coroutine function."""
        # Older interpreters do not see through functools.partial, so unwrap
        # every layer before delegating to asyncio.
        target = func
        while isinstance(target, functools.partial):
            target = target.func
        return asyncio.iscoroutinefunction(target)
# Case-insensitive matcher for JSON media types: "application/json" itself,
# or any "<type>/<subtype>+json" structured-syntax suffix type.
json_re = re.compile(r"(?:application/|[\w.-]+/[\w.+-]+?\+)json$", flags=re.I)
class BasicAuth(namedtuple("BasicAuth", ["login", "password", "encoding"])):
    """Http basic authentication helper."""

    def __new__(
        cls, login: str, password: str = "", encoding: str = "latin1"
    ) -> "BasicAuth":
        """Validate and build the (login, password, encoding) tuple.

        :raises ValueError: if *login* or *password* is None, or if *login*
            contains ":" (the separator of the encoded credential pair).
        """
        if login is None:
            raise ValueError("None is not allowed as login value")

        if password is None:
            raise ValueError("None is not allowed as password value")

        if ":" in login:
            raise ValueError('A ":" is not allowed in login (RFC 1945#section-11.1)')

        return super().__new__(cls, login, password, encoding)

    @classmethod
    def decode(cls, auth_header: str, encoding: str = "latin1") -> "BasicAuth":
        """Create a BasicAuth object from an Authorization HTTP header.

        :raises ValueError: if the header cannot be split into scheme and
            credentials, the scheme is not "Basic", the credentials are not
            valid base64 (or not decodable with *encoding*), or the decoded
            text has no ":" separator.
        """
        try:
            auth_type, encoded_credentials = auth_header.split(" ", 1)
        except ValueError:
            raise ValueError("Could not parse authorization header.") from None

        if auth_type.lower() != "basic":
            raise ValueError("Unknown authorization method %s" % auth_type)

        # The credentials grammar allows one or more spaces after the scheme
        # (RFC 7235/7617); drop surrounding whitespace so the strict base64
        # validation below does not reject an otherwise valid header.
        encoded_credentials = encoded_credentials.strip()
        try:
            decoded = base64.b64decode(
                encoded_credentials.encode("ascii"), validate=True
            ).decode(encoding)
        except binascii.Error:
            raise ValueError("Invalid base64 encoding.") from None

        try:
            # RFC 2617 HTTP Authentication
            # https://www.ietf.org/rfc/rfc2617.txt
            # the colon must be present, but the username and password may be
            # otherwise blank.
            username, password = decoded.split(":", 1)
        except ValueError:
            raise ValueError("Invalid credentials.") from None

        return cls(username, password, encoding=encoding)

    @classmethod
    def from_url(cls, url: "URL", *, encoding: str = "latin1") -> Optional["BasicAuth"]:
        """Create BasicAuth from url.

        Returns None when the URL carries no user name; a missing password
        becomes "".

        :raises TypeError: if *url* is not a ``yarl.URL``.
        """
        if not isinstance(url, URL):
            raise TypeError("url should be yarl.URL instance")
        if url.user is None:
            return None
        return cls(url.user, url.password or "", encoding=encoding)

    def encode(self) -> str:
        """Encode credentials as an ``Authorization: Basic`` header value."""
        creds = (f"{self.login}:{self.password}").encode(self.encoding)
        return "Basic %s" % base64.b64encode(creds).decode(self.encoding)
def strip_auth_from_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
    """Separate user credentials from *url*.

    Returns ``(url_without_user, auth)``; when *url* has no user part the
    URL is returned unchanged together with None.
    """
    credentials = BasicAuth.from_url(url)
    if credentials is not None:
        return url.with_user(None), credentials
    return url, None
def netrc_from_env() -> Optional[netrc.netrc]:
    """Load netrc from file.

    The path is taken from the NETRC environment variable when it is set,
    otherwise ``.netrc`` (``_netrc`` on Windows) in the user's home
    directory.

    Returns None if it couldn't be found or fails to parse.
    """
    netrc_env = os.environ.get("NETRC")

    if netrc_env is None:
        try:
            home_dir = Path.home()
        except RuntimeError as e:  # pragma: no cover
            # if pathlib can't resolve home, it may raise a RuntimeError
            client_logger.debug(
                "Could not resolve home directory when "
                "trying to look for .netrc file: %s",
                e,
            )
            return None
        default_name = "_netrc" if platform.system() == "Windows" else ".netrc"
        netrc_path = home_dir / default_name
    else:
        netrc_path = Path(netrc_env)

    try:
        return netrc.netrc(str(netrc_path))
    except netrc.NetrcParseError as e:
        client_logger.warning("Could not parse .netrc file: %s", e)
    except OSError as e:
        # Unreadable file. Stay quiet about a merely absent default file;
        # warn only if NETRC asked for it or the default file does exist.
        if netrc_env or netrc_path.is_file():
            client_logger.warning("Could not read .netrc file: %s", e)

    return None
@dataclasses.dataclass(frozen=True)
class ProxyInfo:
    """Immutable pair of a proxy URL and the credentials to use with it."""

    # Proxy URL; as built by proxies_from_env, any user info is already
    # stripped off and moved into proxy_auth.
    proxy: URL
    # Credentials for the proxy, or None when none were found.
    proxy_auth: Optional[BasicAuth]
def basicauth_from_netrc(netrc_obj: Optional[netrc.netrc], host: str) -> BasicAuth:
    """
    Return :py:class:`~aiohttp.BasicAuth` credentials for ``host`` from ``netrc_obj``.

    :raises LookupError: if ``netrc_obj`` is :py:data:`None` or if no
            entry is found for the ``host``.
    """
    if netrc_obj is None:
        raise LookupError("No .netrc file found")

    entry = netrc_obj.authenticators(host)
    if entry is None:
        raise LookupError(f"No entry for {host!s} found in the `.netrc` file.")
    login, account, password = entry

    # TODO(PY311): username = login or account
    # Up to python 3.10, account could be None if not specified,
    # and login will be empty string if not specified. From 3.11,
    # login and account will be empty string if not specified.
    if login or account is None:
        username = login
    else:
        username = account

    # TODO(PY311): drop the None check; 3.11+ reports a missing password
    # as an empty string already.
    return BasicAuth(username, "" if password is None else password)
def proxies_from_env() -> Dict[str, ProxyInfo]:
    """Collect http/https/ws/wss proxy settings from the environment.

    URLs come from ``urllib.request.getproxies()``. Embedded credentials are
    split off each URL; when there are none, a matching ``.netrc`` entry is
    used if available. Proxies whose own scheme is https or wss are skipped
    with a warning.
    """
    wanted = ("http", "https", "ws", "wss")
    proxy_urls = {k: URL(v) for k, v in getproxies().items() if k in wanted}
    netrc_obj = netrc_from_env()
    stripped = {k: strip_auth_from_url(v) for k, v in proxy_urls.items()}

    ret: Dict[str, ProxyInfo] = {}
    for proto, (proxy, auth) in stripped.items():
        if proxy.scheme in ("https", "wss"):
            client_logger.warning(
                "%s proxies %s are not supported, ignoring", proxy.scheme.upper(), proxy
            )
            continue
        if auth is None and netrc_obj and proxy.host is not None:
            try:
                auth = basicauth_from_netrc(netrc_obj, proxy.host)
            except LookupError:
                auth = None
        ret[proto] = ProxyInfo(proxy, auth)
    return ret
def get_env_proxy_for_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
    """Get a permitted proxy for the given URL from the env.

    :raises LookupError: if ``urllib.request.proxy_bypass`` excludes the
        host, or no proxy is configured for ``url.scheme``.
    """
    if url.host is not None and proxy_bypass(url.host):
        raise LookupError(f"Proxying is disallowed for `{url.host!r}`")

    proxy_info = proxies_from_env().get(url.scheme)
    if proxy_info is None:
        raise LookupError(f"No proxies found for `{url!s}` in the env")
    return proxy_info.proxy, proxy_info.proxy_auth
@dataclasses.dataclass(frozen=True)
class MimeType:
    """Components of a MIME type string, as produced by parse_mimetype."""

    # Main type (e.g. "text"); parse_mimetype lower-cases it, "" if empty.
    type: str
    # Subtype without any "+suffix" part (e.g. "html").
    subtype: str
    # Part after "+" in the subtype (e.g. "json"), or "".
    suffix: str
    # Read-only view of the ";"-separated parameters (names lower-cased by
    # parse_mimetype).
    parameters: "MultiDictProxy[str]"
@functools.lru_cache(maxsize=56)
def parse_mimetype(mimetype: str) -> MimeType:
    """Parses a MIME type into its components.

    mimetype is a MIME type string.

    Returns a MimeType object. Type, subtype, suffix and parameter names are
    lower-cased; parameter values have surrounding spaces and double quotes
    stripped. A bare "*" is treated as "*/*". Results are memoized.

    Example:

    >>> parse_mimetype('text/html; charset=utf-8')
    MimeType(type='text', subtype='html', suffix='',
             parameters={'charset': 'utf-8'})

    """
    if not mimetype:
        return MimeType(
            type="", subtype="", suffix="", parameters=MultiDictProxy(MultiDict())
        )

    parts = mimetype.split(";")
    params: MultiDict[str] = MultiDict()
    for item in parts[1:]:
        # Skip empty and whitespace-only chunks (e.g. from a trailing "; ");
        # they would otherwise add a parameter with an empty name.
        if not item.strip():
            continue
        key, _, value = item.partition("=")
        params.add(key.lower().strip(), value.strip(' "'))

    fulltype = parts[0].strip().lower()
    if fulltype == "*":
        fulltype = "*/*"

    mtype, _, stype = fulltype.partition("/")
    stype, _, suffix = stype.partition("+")

    return MimeType(
        type=mtype, subtype=stype, suffix=suffix, parameters=MultiDictProxy(params)
    )
def guess_filename(obj: Any, default: Optional[str] = None) -> Optional[str]:
    """Return the base name of ``obj.name`` when it looks like a file path.

    Falls back to *default* if *obj* has no non-empty string ``name``, or
    if the name starts with "<" or ends with ">" (pseudo names such as
    "<stdin>").
    """
    candidate = getattr(obj, "name", None)
    if not candidate or not isinstance(candidate, str):
        return default
    if candidate.startswith("<") or candidate.endswith(">"):
        return default
    return Path(candidate).name
# Characters that need a backslash inside a quoted-string: anything outside
# RFC 5322 qtext (%d33 / %d35-91 / %d93-126).  Within QCONTENT that is
# SP, HTAB, '"' and '\'.
not_qtext_re = re.compile(r"[^\041\043-\133\135-\176]")
# Characters allowed in a 7-bit quoted-string: printable US-ASCII plus HTAB.
QCONTENT = {chr(i) for i in range(0x20, 0x7F)} | {"\t"}


def quoted_string(content: str) -> str:
    """Return 7-bit content as quoted-string.

    Format content into a quoted-string as defined in RFC5322 for
    Internet Message Format. Notice that this is not the 8-bit HTTP
    format, but the 7-bit email format. Content must be in usascii or
    a ValueError is raised.

    Only the escaping is performed; the caller adds the surrounding
    double quotes.
    """
    # Subset test, not proper subset: content that happens to contain every
    # allowed character is still valid.
    if not QCONTENT.issuperset(content):
        raise ValueError(f"bad content for quoted-string {content!r}")
    return not_qtext_re.sub(lambda x: "\\" + x.group(0), content)
def content_disposition_header(
    disptype: str, quote_fields: bool = True, _charset: str = "utf-8", **params: str
) -> str:
    """Sets ``Content-Disposition`` header for MIME.

    This is the MIME payload Content-Disposition header from RFC 2183
    and RFC 7578 section 4.2, not the HTTP Content-Disposition from
    RFC 6266.

    disptype is a disposition type: inline, attachment, form-data.
    Should be valid extension token (see RFC 2183)

    quote_fields performs value quoting to 7-bit MIME headers
    according to RFC 7578. Set quote_fields to False if the recipient
    can take 8-bit file names and field values.

    _charset specifies the charset to use when quote_fields is True.

    params is a dict with disposition params.

    Returns the header value, e.g. ``form-data; name="field"``.

    :raises ValueError: if *disptype* or a parameter name is empty or not a
        valid token.
    """
    # issuperset rather than the strict ">" so that a value using every
    # token character is not rejected.
    if not disptype or not TOKEN.issuperset(disptype):
        raise ValueError("bad content disposition type {!r}" "".format(disptype))

    value = disptype
    if params:
        lparams = []
        for key, val in params.items():
            if not key or not TOKEN.issuperset(key):
                raise ValueError(
                    "bad content disposition parameter" " {!r}={!r}".format(key, val)
                )
            if quote_fields:
                if key.lower() == "filename":
                    # File names are percent-encoded in the given charset.
                    qval = quote(val, "", encoding=_charset)
                    lparams.append((key, '"%s"' % qval))
                else:
                    try:
                        qval = quoted_string(val)
                    except ValueError:
                        # Not representable as a 7-bit quoted-string: use the
                        # extended ``key*=charset''percent-encoded`` form.
                        qval = "".join(
                            (_charset, "''", quote(val, "", encoding=_charset))
                        )
                        lparams.append((key + "*", qval))
                    else:
                        lparams.append((key, '"%s"' % qval))
            else:
                # 8-bit capable recipient: only escape backslash and quote.
                qval = val.replace("\\", "\\\\").replace('"', '\\"')
                lparams.append((key, '"%s"' % qval))
        sparams = "; ".join("=".join(pair) for pair in lparams)
        value = "; ".join((value, sparams))
    return value
def is_expected_content_type(
    response_content_type: str, expected_content_type: str
) -> bool:
    """Checks if received content type is processable as an expected one.

    Both arguments should be given without parameters. When JSON is
    expected, any JSON media type (including "+json" suffix types) is
    accepted; otherwise a substring check is used.
    """
    if expected_content_type != "application/json":
        return expected_content_type in response_content_type
    return bool(json_re.match(response_content_type))
class _TSelf(Protocol, Generic[_T]):
    """Structural type of objects that ``reify`` can decorate.

    They must expose a ``_cache`` dict, in which reify stores computed
    values keyed by the wrapped function's name.
    """

    _cache: Dict[str, _T]
class reify(Generic[_T]):
    """Use as a class method decorator.

    Behaves like ``@property``, except that the first computed value is
    stored in the instance's ``_cache`` dict under the method's name and
    returned from there on subsequent accesses. Assigning to the attribute
    raises AttributeError (it defines ``__set__``, so it is a data
    descriptor).
    """

    def __init__(self, wrapped: Callable[..., _T]) -> None:
        self.wrapped = wrapped
        self.name = wrapped.__name__
        self.__doc__ = wrapped.__doc__

    def __get__(self, inst: _TSelf[_T], owner: Optional[Type[Any]] = None) -> _T:
        try:
            try:
                return inst._cache[self.name]
            except KeyError:
                computed = self.wrapped(inst)
                inst._cache[self.name] = computed
                return computed
        except AttributeError:
            # Access through the class (inst is None) has no _cache: return
            # the descriptor itself. Any other AttributeError propagates.
            if inst is not None:
                raise
            return self

    def __set__(self, inst: _TSelf[_T], value: _T) -> None:
        raise AttributeError("reified property is read-only")
# Keep the pure-Python implementation reachable as reify_py, then switch
# ``reify`` to the variant from ._helpers when it imports successfully and
# AIOHTTP_NO_EXTENSIONS is not set.
reify_py = reify

with suppress(ImportError):
    from ._helpers import reify as reify_c

    if not NO_EXTENSIONS:
        reify = reify_c  # type: ignore[misc,assignment]
_ipv4_pattern = (
    r"^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}"
    r"(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$"
)
_ipv6_pattern = (
    r"^(?:(?:(?:[A-F0-9]{1,4}:){6}|(?=(?:[A-F0-9]{0,4}:){0,6}"
    r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}$)(([0-9A-F]{1,4}:){0,5}|:)"
    r"((:[0-9A-F]{1,4}){1,5}:|:)|::(?:[A-F0-9]{1,4}:){5})"
    r"(?:(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])\.){3}"
    r"(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])|(?:[A-F0-9]{1,4}:){7}"
    r"[A-F0-9]{1,4}|(?=(?:[A-F0-9]{0,4}:){0,7}[A-F0-9]{0,4}$)"
    r"(([0-9A-F]{1,4}:){1,7}|:)((:[0-9A-F]{1,4}){1,7}|:)|(?:[A-F0-9]{1,4}:){7}"
    r":|:(:[A-F0-9]{1,4}){7})$"
)
_ipv4_regex = re.compile(_ipv4_pattern)
_ipv6_regex = re.compile(_ipv6_pattern, flags=re.IGNORECASE)
_ipv4_regexb = re.compile(_ipv4_pattern.encode("ascii"))
_ipv6_regexb = re.compile(_ipv6_pattern.encode("ascii"), flags=re.IGNORECASE)


def _is_ip_address(
    regex: Pattern[str],
    regexb: Pattern[bytes],
    host: Optional[Union[str, bytes, bytearray, memoryview]],
) -> bool:
    """Return True if *host* is entirely an address matching the patterns.

    ``str`` hosts are checked with *regex*, bytes-like hosts with *regexb*;
    None is never an address.

    :raises TypeError: if *host* is of any other type.
    """
    if host is None:
        return False
    # fullmatch, not match: "$" also matches just before a trailing "\n",
    # so match() would accept hosts such as "127.0.0.1\n".
    if isinstance(host, str):
        return regex.fullmatch(host) is not None
    elif isinstance(host, (bytes, bytearray, memoryview)):
        return regexb.fullmatch(host) is not None
    else:
        raise TypeError(f"{host} [{type(host)}] is not a str or bytes")


is_ipv4_address = functools.partial(_is_ip_address, _ipv4_regex, _ipv4_regexb)
is_ipv6_address = functools.partial(_is_ip_address, _ipv6_regex, _ipv6_regexb)


def is_ip_address(host: Optional[Union[str, bytes, bytearray, memoryview]]) -> bool:
    """Return True if *host* is an IPv4 or IPv6 address literal."""
    return is_ipv4_address(host) or is_ipv6_address(host)
def next_whole_second() -> datetime.datetime:
    """Return the current UTC time with the fractional second dropped.

    Despite the name, the value is truncated (rounded down), not rounded
    up. The result is timezone-aware (UTC).
    """
    # NOTE(review): the name suggests rounding up; the long-standing
    # truncating behavior is kept because callers elsewhere may rely on it.
    return datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)
_cached_current_datetime: Optional[int] = None
_cached_formatted_datetime = ""


def rfc822_formatted_time() -> str:
    """Return the current time as an HTTP date string in GMT.

    The string is rebuilt at most once per wall-clock second; in between,
    the value cached in module globals is returned.
    """
    global _cached_current_datetime
    global _cached_formatted_datetime

    now = int(time.time())
    if now == _cached_current_datetime:
        return _cached_formatted_datetime

    # Fixed English names, independent of the process locale.
    weekdays = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")
    months = (
        "",  # placeholder: tm_mon is 1-based
        "Jan", "Feb", "Mar", "Apr", "May", "Jun",
        "Jul", "Aug", "Sep", "Oct", "Nov", "Dec",
    )

    tm = time.gmtime(now)
    _cached_formatted_datetime = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
        weekdays[tm.tm_wday],
        tm.tm_mday,
        months[tm.tm_mon],
        tm.tm_year,
        tm.tm_hour,
        tm.tm_min,
        tm.tm_sec,
    )
    _cached_current_datetime = now
    return _cached_formatted_datetime
590def _weakref_handle(info: "Tuple[weakref.ref[object], str]") -> None:
591 ref, name = info
592 ob = ref()
593 if ob is not None:
594 with suppress(Exception):
595 getattr(ob, name)()
def weakref_handle(
    ob: object,
    name: str,
    timeout: float,
    loop: asyncio.AbstractEventLoop,
    timeout_ceil_threshold: float = 5,
) -> Optional[asyncio.TimerHandle]:
    """Schedule ``getattr(ob, name)()`` in *timeout* seconds, holding *ob* weakly.

    Returns the timer handle, or None (nothing scheduled) when *timeout* is
    None or not positive. Timeouts of at least *timeout_ceil_threshold*
    get their deadline rounded up to a whole loop-clock second.
    """
    if timeout is None or not (timeout > 0):
        return None
    deadline = loop.time() + timeout
    if timeout >= timeout_ceil_threshold:
        deadline = ceil(deadline)
    return loop.call_at(deadline, _weakref_handle, (weakref.ref(ob), name))
def call_later(
    cb: Callable[[], Any],
    timeout: float,
    loop: asyncio.AbstractEventLoop,
    timeout_ceil_threshold: float = 5,
) -> Optional[asyncio.TimerHandle]:
    """Schedule *cb* to run *timeout* seconds from now on *loop*.

    Returns None (and schedules nothing) when *timeout* is None or not
    positive. Timeouts strictly above *timeout_ceil_threshold* get their
    deadline rounded up to a whole loop-clock second.
    """
    if timeout is None or not (timeout > 0):
        return None
    deadline = loop.time() + timeout
    if timeout > timeout_ceil_threshold:
        deadline = ceil(deadline)
    return loop.call_at(deadline, cb)
class TimeoutHandle:
    """Timeout handle

    Holds callbacks that are all run when the handle is invoked, either by
    the timer armed with start() or by calling the handle directly.
    """

    def __init__(
        self,
        loop: asyncio.AbstractEventLoop,
        timeout: Optional[float],
        ceil_threshold: float = 5,
    ) -> None:
        self._timeout = timeout
        self._loop = loop
        self._ceil_threshold = ceil_threshold
        self._callbacks: List[
            Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]]
        ] = []

    def register(
        self, callback: Callable[..., None], *args: Any, **kwargs: Any
    ) -> None:
        """Queue ``callback(*args, **kwargs)`` to run when the handle fires."""
        self._callbacks.append((callback, args, kwargs))

    def close(self) -> None:
        """Forget every registered callback."""
        self._callbacks.clear()

    def start(self) -> Optional[asyncio.Handle]:
        """Arm the timer; return its handle, or None without a positive timeout."""
        timeout = self._timeout
        if timeout is None or not (timeout > 0):
            return None
        deadline = self._loop.time() + timeout
        if timeout >= self._ceil_threshold:
            deadline = ceil(deadline)
        return self._loop.call_at(deadline, self.__call__)

    def timer(self) -> "BaseTimerContext":
        """Return a timer context that this handle cancels when it fires.

        Without a positive timeout a no-op context is returned instead.
        """
        if self._timeout is None or not (self._timeout > 0):
            return TimerNoop()
        ctx = TimerContext(self._loop)
        self.register(ctx.timeout)
        return ctx

    def __call__(self) -> None:
        """Run all callbacks (suppressing their exceptions), then drop them."""
        for callback, args, kwargs in self._callbacks:
            with suppress(Exception):
                callback(*args, **kwargs)

        self._callbacks.clear()
class BaseTimerContext(ContextManager["BaseTimerContext"]):
    """Common base of the timer context managers in this module."""

    def assert_timeout(self) -> None:
        """Raise TimeoutError if timeout has been exceeded.

        The base implementation does nothing; TimerContext overrides it.
        """
class TimerNoop(BaseTimerContext):
    """Timer context that never expires (used when no timeout is set)."""

    def __enter__(self) -> BaseTimerContext:
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        # A None result lets any exception raised in the block propagate.
        return None
class TimerContext(BaseTimerContext):
    """Low resolution timeout context manager

    timeout() cancels every task currently inside the context; the
    resulting CancelledError is converted to asyncio.TimeoutError on exit.
    """

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        self._tasks: List[asyncio.Task[Any]] = []
        self._cancelled = False

    def assert_timeout(self) -> None:
        """Raise TimeoutError if timer has already been cancelled."""
        if self._cancelled:
            raise asyncio.TimeoutError from None

    def __enter__(self) -> BaseTimerContext:
        current = asyncio.current_task(loop=self._loop)
        if current is None:
            raise RuntimeError(
                "Timeout context manager should be used " "inside a task"
            )

        # Entering an already expired timer fails immediately.
        if self._cancelled:
            raise asyncio.TimeoutError from None

        self._tasks.append(current)
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> Optional[bool]:
        if self._tasks:
            self._tasks.pop()

        # A cancellation caused by our own timeout() surfaces as a timeout.
        if self._cancelled and exc_type is asyncio.CancelledError:
            raise asyncio.TimeoutError from None
        return None

    def timeout(self) -> None:
        """Cancel all tasks inside the context (once) and mark it expired."""
        if self._cancelled:
            return
        for task in set(self._tasks):
            task.cancel()
        self._cancelled = True
def ceil_timeout(
    delay: Optional[float], ceil_threshold: float = 5
) -> async_timeout.Timeout:
    """Return an async_timeout context expiring *delay* seconds from now.

    A None or non-positive *delay* yields ``async_timeout.timeout(None)``
    (no deadline). Delays above *ceil_threshold* get their deadline rounded
    up to a whole loop-clock second. Requires a running event loop.
    """
    if delay is None or delay <= 0:
        return async_timeout.timeout(None)

    deadline = asyncio.get_running_loop().time() + delay
    if delay > ceil_threshold:
        deadline = ceil(deadline)
    return async_timeout.timeout_at(deadline)
class HeadersMixin:
    """Content-Type / Content-Length accessors for classes with ``_headers``.

    The host class must provide a ``_headers`` mapping. The parsed
    Content-Type is cached and only re-parsed when the raw header changes.
    """

    __slots__ = ("_content_type", "_content_dict", "_stored_content_type")

    def __init__(self) -> None:
        super().__init__()
        self._content_type: Optional[str] = None
        self._content_dict: Optional[Dict[str, str]] = None
        # Start from the sentinel (not None) so that a missing header, whose
        # raw value is None, still triggers the first parse.
        self._stored_content_type: Union[Optional[str], _SENTINEL] = sentinel

    def _parse_content_type(self, raw: Optional[str]) -> None:
        """Parse the raw Content-Type value (None if absent) into the cache."""
        self._stored_content_type = raw
        if raw is None:
            # default value according to RFC 2616
            self._content_type = "application/octet-stream"
            self._content_dict = {}
        else:
            msg = HeaderParser().parsestr("Content-Type: " + raw)
            self._content_type = msg.get_content_type()
            params = msg.get_params()
            self._content_dict = dict(params[1:])  # First element is content type again

    @property
    def content_type(self) -> str:
        """The value of content part for Content-Type HTTP header."""
        raw = self._headers.get(hdrs.CONTENT_TYPE)  # type: ignore[attr-defined]
        if self._stored_content_type != raw:
            self._parse_content_type(raw)
        return self._content_type  # type: ignore[return-value]

    @property
    def charset(self) -> Optional[str]:
        """The value of charset part for Content-Type HTTP header."""
        raw = self._headers.get(hdrs.CONTENT_TYPE)  # type: ignore[attr-defined]
        if self._stored_content_type != raw:
            self._parse_content_type(raw)
        return self._content_dict.get("charset")  # type: ignore[union-attr]

    @property
    def content_length(self) -> Optional[int]:
        """The value of Content-Length HTTP header.

        Returns None when the header is absent.

        :raises ValueError: if the header value is not an integer.
        """
        content_length = self._headers.get(  # type: ignore[attr-defined]
            hdrs.CONTENT_LENGTH
        )

        if content_length is not None:
            return int(content_length)
        else:
            return None
def set_result(fut: "asyncio.Future[_T]", result: _T) -> None:
    """Resolve *fut* with *result*; do nothing if it is already done."""
    if fut.done():
        return
    fut.set_result(result)


def set_exception(fut: "asyncio.Future[_T]", exc: BaseException) -> None:
    """Fail *fut* with *exc*; do nothing if it is already done."""
    if fut.done():
        return
    fut.set_exception(exc)
@functools.total_ordering
class AppKey(Generic[_T]):
    """Keys for static typing support in Application.

    The key name is prefixed with the name of the module that created the
    key. Keys order by that full name and sort above non-AppKey objects.
    """

    __slots__ = ("_name", "_t", "__orig_class__")

    # This may be set by Python when instantiating with a generic type. We need to
    # support this, in order to support types that are not concrete classes,
    # like Iterable, which can't be passed as the second parameter to __init__.
    __orig_class__: Type[object]

    def __init__(self, name: str, t: Optional[Type[_T]] = None):
        # Prefix with module name to help deduplicate key names: walk up to
        # the nearest module-level ("<module>") frame.
        frame = inspect.currentframe()
        while frame:
            if frame.f_code.co_name == "<module>":
                module: str = frame.f_globals["__name__"]
                break
            frame = frame.f_back
        else:
            # No module-level frame on the stack: e.g. the key is created in a
            # function running in another thread, or the interpreter has no
            # frame support. Previously this surfaced as an obscure
            # UnboundLocalError on ``module``.
            raise RuntimeError("Failed to get module name.")

        self._name = module + "." + name
        self._t = t

    def __lt__(self, other: object) -> bool:
        if isinstance(other, AppKey):
            return self._name < other._name
        return True  # Order AppKey above other types.

    def __repr__(self) -> str:
        t = self._t
        if t is None:
            with suppress(AttributeError):
                # Set to type arg.
                t = get_args(self.__orig_class__)[0]

        if t is None:
            t_repr = "<<Unknown>>"
        elif isinstance(t, type):
            if t.__module__ == "builtins":
                t_repr = t.__qualname__
            else:
                t_repr = f"{t.__module__}.{t.__qualname__}"
        else:
            t_repr = repr(t)
        return f"<AppKey({self._name}, type={t_repr})>"
@final
class ChainMapProxy(Mapping[Union[str, AppKey[Any]], Any]):
    """Read-only, non-subclassable view over a sequence of mappings.

    Lookups try the mappings in order and return the first hit.
    """

    __slots__ = ("_maps",)

    def __init__(self, maps: Iterable[Mapping[Union[str, AppKey[Any]], Any]]) -> None:
        # Materialize the iterable so the proxy can be traversed repeatedly.
        self._maps = tuple(maps)

    def __init_subclass__(cls) -> None:
        raise TypeError(
            "Inheritance class {} from ChainMapProxy "
            "is forbidden".format(cls.__name__)
        )

    @overload  # type: ignore[override]
    def __getitem__(self, key: AppKey[_T]) -> _T:
        ...

    @overload
    def __getitem__(self, key: str) -> Any:
        ...

    def __getitem__(self, key: Union[str, AppKey[_T]]) -> Any:
        for candidate in self._maps:
            try:
                return candidate[key]
            except KeyError:
                continue
        raise KeyError(key)

    @overload  # type: ignore[override]
    def get(self, key: AppKey[_T], default: _S) -> Union[_T, _S]:
        ...

    @overload
    def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]:
        ...

    @overload
    def get(self, key: str, default: Any = ...) -> Any:
        ...

    def get(self, key: Union[str, AppKey[_T]], default: Any = None) -> Any:
        try:
            found = self[key]
        except KeyError:
            return default
        return found

    def __len__(self) -> int:
        # Number of distinct keys across all mappings.
        return len({k for candidate in self._maps for k in candidate})

    def __iter__(self) -> Iterator[Union[str, AppKey[Any]]]:
        # Merge from the last mapping to the first; a key keeps the position
        # of its first insertion, as with dict.update.
        merged: Dict[Union[str, AppKey[Any]], Any] = {}
        for candidate in reversed(self._maps):
            merged.update(candidate)
        return iter(merged)

    def __contains__(self, key: object) -> bool:
        return any(key in candidate for candidate in self._maps)

    def __bool__(self) -> bool:
        return any(self._maps)

    def __repr__(self) -> str:
        inner = ", ".join(repr(candidate) for candidate in self._maps)
        return f"ChainMapProxy({inner})"
class CookieMixin:
    """Mixin holding a ``SimpleCookie`` jar with helpers to set/delete cookies."""

    # The `_cookies` slots is not defined here because non-empty slots cannot
    # be combined with an Exception base class, as is done in HTTPException.
    # CookieMixin subclasses with slots should define the `_cookies`
    # slot themselves.
    __slots__ = ()

    def __init__(self) -> None:
        super().__init__()
        # Mypy doesn't like that _cookies isn't in __slots__.
        # See the comment on this class's __slots__ for why this is OK.
        self._cookies: SimpleCookie[str] = SimpleCookie()  # type: ignore[misc]

    @property
    def cookies(self) -> "SimpleCookie[str]":
        """The (mutable) cookie jar."""
        return self._cookies

    def set_cookie(
        self,
        name: str,
        value: str,
        *,
        expires: Optional[str] = None,
        domain: Optional[str] = None,
        max_age: Optional[Union[int, str]] = None,
        path: str = "/",
        secure: Optional[bool] = None,
        httponly: Optional[bool] = None,
        version: Optional[str] = None,
        samesite: Optional[str] = None,
    ) -> None:
        """Set or update response cookie.

        Sets new cookie or updates existent with new value.
        Also updates only those params which are not None.

        ``path`` is always (re)set. A previously deleted cookie (empty value)
        is discarded first so none of its attributes carry over. When not
        given, a leftover epoch ``expires`` and any ``max-age`` are removed.
        In debug mode a UserWarning is emitted if the serialized cookie is
        longer than COOKIE_MAX_LENGTH.
        """
        old = self._cookies.get(name)
        if old is not None and old.coded_value == "":
            # deleted cookie
            self._cookies.pop(name, None)

        self._cookies[name] = value
        c = self._cookies[name]

        if expires is not None:
            c["expires"] = expires
        elif c.get("expires") == "Thu, 01 Jan 1970 00:00:00 GMT":
            del c["expires"]

        if domain is not None:
            c["domain"] = domain

        if max_age is not None:
            c["max-age"] = str(max_age)
        elif "max-age" in c:
            del c["max-age"]

        c["path"] = path

        if secure is not None:
            c["secure"] = secure
        if httponly is not None:
            c["httponly"] = httponly
        if version is not None:
            c["version"] = version
        if samesite is not None:
            c["samesite"] = samesite

        if DEBUG:
            cookie_length = len(c.output(header="")[1:])
            if cookie_length > COOKIE_MAX_LENGTH:
                warnings.warn(
                    f"The size of the cookie {name!r} is too large, "
                    "it might get ignored by the client.",
                    UserWarning,
                    stacklevel=2,
                )

    def del_cookie(
        self, name: str, *, domain: Optional[str] = None, path: str = "/"
    ) -> None:
        """Delete cookie.

        Creates new empty expired cookie (empty value, ``max-age=0`` and an
        epoch ``expires``) for the given domain and path.
        """
        # TODO: do we need domain/path here?
        self._cookies.pop(name, None)
        self.set_cookie(
            name,
            "",
            max_age=0,
            expires="Thu, 01 Jan 1970 00:00:00 GMT",
            domain=domain,
            path=path,
        )
def populate_with_cookies(
    headers: "CIMultiDict[str]", cookies: "SimpleCookie[str]"
) -> None:
    """Append one Set-Cookie header to *headers* for each morsel in *cookies*."""
    for morsel in cookies.values():
        # output(header="") yields " name=value; ..."; drop the leading space.
        headers.add(hdrs.SET_COOKIE, morsel.output(header="")[1:])
# https://tools.ietf.org/html/rfc7232#section-2.3
# etagc = %x21 / %x23-7E / obs-text(%x80-FF).  The range must end at "~"
# (0x7E); the previous "#-}" stopped at 0x7D and wrongly rejected "~".
_ETAGC = r"[!#-~\x80-\xff]+"
_ETAGC_RE = re.compile(_ETAGC)
# Optional weak marker "W/" (group 1) followed by the quoted opaque tag
# (value in group 2).
_QUOTED_ETAG = rf'(W/)?"({_ETAGC})"'
QUOTED_ETAG_RE = re.compile(_QUOTED_ETAG)
# One quoted etag plus its list separator or end of input; the trailing
# "(.)" alternative captures any other character, presumably so a list
# parser can detect malformed input.
LIST_QUOTED_ETAG_RE = re.compile(rf"({_QUOTED_ETAG})(?:\s*,\s*|$)|(.)")

ETAG_ANY = "*"


@dataclasses.dataclass(frozen=True)
class ETag:
    """An entity tag: its opaque value and whether it is weak (``W/``)."""

    value: str
    is_weak: bool = False


def validate_etag_value(value: str) -> None:
    """Check that *value* is ``*`` or a valid unquoted entity-tag value.

    :raises ValueError: otherwise (for example if it contains '"').
    """
    if value != ETAG_ANY and not _ETAGC_RE.fullmatch(value):
        raise ValueError(
            f"Value {value!r} is not a valid etag. Maybe it contains '\"'?"
        )
def parse_http_date(date_str: Optional[str]) -> Optional[datetime.datetime]:
    """Process a date string, return a datetime object

    The result is timezone-aware (UTC). None is returned when *date_str* is
    None, cannot be parsed, or names an impossible date. Any zone offset in
    the string is ignored (``email.utils.parsedate`` drops it).
    """
    if date_str is None:
        return None
    parsed = parsedate(date_str)
    if parsed is None:
        return None
    year, month, day, hour, minute, second = parsed[:6]
    try:
        return datetime.datetime(
            year, month, day, hour, minute, second, tzinfo=datetime.timezone.utc
        )
    except ValueError:
        return None