Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/helpers.py: 38%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

571 statements  

1"""Various helper functions""" 

2 

3import asyncio 

4import base64 

5import contextlib 

6import dataclasses 

7import datetime 

8import enum 

9import functools 

10import inspect 

11import netrc 

12import os 

13import platform 

14import re 

15import sys 

16import time 

17import warnings 

18import weakref 

19from collections.abc import Callable, Iterable, Iterator, Mapping 

20from contextlib import suppress 

21from email.message import EmailMessage 

22from email.parser import HeaderParser 

23from email.policy import HTTP 

24from email.utils import parsedate 

25from http.cookies import SimpleCookie 

26from math import ceil 

27from pathlib import Path 

28from types import MappingProxyType, TracebackType 

29from typing import ( 

30 TYPE_CHECKING, 

31 Any, 

32 ContextManager, 

33 Generic, 

34 Protocol, 

35 TypeVar, 

36 Union, 

37 final, 

38 get_args, 

39 overload, 

40) 

41from urllib.parse import quote 

42from urllib.request import getproxies, proxy_bypass 

43 

44from multidict import CIMultiDict, MultiDict, MultiDictProxy 

45from propcache.api import under_cached_property as reify 

46from yarl import URL 

47 

48from . import hdrs 

49from .log import client_logger 

50from .typedefs import PathLike # noqa 

51 

52if sys.version_info >= (3, 11): 

53 import asyncio as async_timeout 

54else: 

55 import async_timeout 

56 

57if TYPE_CHECKING: 

58 from dataclasses import dataclass as frozen_dataclass_decorator 

59else: 

60 frozen_dataclass_decorator = functools.partial( 

61 dataclasses.dataclass, frozen=True, slots=True 

62 ) 

63 

# Public names of this module.
# NOTE(review): ``ETag`` is not defined in the visible part of this module;
# presumably it is defined further down -- confirm.
__all__ = ("ChainMapProxy", "ETag", "frozen_dataclass_decorator", "reify")

# This is the default size/limit for several operations.
# Matches the max size we receive from sockets:
# https://github.com/python/cpython/blob/1857a40807daeae3a1bf5efb682de9c9ae6df845/Lib/asyncio/selector_events.py#L766
DEFAULT_CHUNK_SIZE = 2**18  # 256 KiB
COOKIE_MAX_LENGTH = 4096
# A quoted-pair: a backslash followed by any character. Substituting r"\1"
# drops the backslash and keeps the escaped character.
_QUOTED_PAIR_SUB = re.compile(r"\\(.)")
# Pattern fragment: a double-quoted string whose body may contain
# backslash escapes (including escaped quotes).
_QUOTED_STRING = r'"(?:[^"\\]|\\.)*"'
# Pattern fragment: the body of a parenthesized comment. It has no unescaped
# parentheses and may contain backslash escapes.
_ESCAPED_COMMENT = r"(?:[^()\\]|\\.)*"
# Matches one element in a comma-separated header list.
# Group 1: content of a top-level quoted-string (quotes stripped).
# Group 2: an unquoted element (may contain parameter quoted-strings / comments).
# A quoted-string counts as a parameter value only when it directly follows
# "<non-space>=". "(" starts a comment only after whitespace. Both are consumed
# as whole units so that commas inside them are not taken as list separators.
_LIST_ELEMENT_RE = re.compile(
    rf"""
    [ \t]*
    (?:
        "( (?:[^"\\]|\\.)* )"  # group 1: top-level quoted-string
      | (                       # group 2: unquoted element
            (?:
                (?<=[^\s]=) {_QUOTED_STRING}      # parameter quoted value
              | (?<=\s) \( {_ESCAPED_COMMENT} \)  # comment
              | [^,]                              # any non-comma character
            )+?
        )
    )
    [ \t]* (?:,|\Z)
    """,
    re.VERBOSE,
)
# Finds parameter quoted-strings and comments inside an unquoted element for unescaping.
# Uses the same lookbehind rules as the protected alternatives in _LIST_ELEMENT_RE.
_PROTECTED_RE = re.compile(
    rf"""
    (?<=[^\s]=) {_QUOTED_STRING}        # parameter quoted-string
  | (?<=\s) \( {_ESCAPED_COMMENT} \)    # comment
    """,
    re.VERBOSE,
)

102 

# Generic type variables shared by the helpers in this module.
_T = TypeVar("_T")
_S = TypeVar("_S")

# Single-member enum used as a "no value yet" marker that is distinct from None.
_SENTINEL = enum.Enum("_SENTINEL", "sentinel")
sentinel = _SENTINEL.sentinel

# Any non-empty value of AIOHTTP_NO_EXTENSIONS turns the flag on.
NO_EXTENSIONS = os.environ.get("AIOHTTP_NO_EXTENSIONS", "") != ""

# Status codes whose responses never carry a body: every 1xx, 204 and 304.
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
EMPTY_BODY_STATUS_CODES = frozenset([*range(100, 200), 204, 304])
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.2
EMPTY_BODY_METHODS = hdrs.METH_HEAD_ALL

# Debug mode is on under ``python -X dev``. Otherwise it follows
# PYTHONASYNCIODEBUG, unless ``-E`` tells Python to ignore the environment.
if sys.flags.dev_mode:
    DEBUG = True
elif sys.flags.ignore_environment:
    DEBUG = False
else:
    DEBUG = bool(os.environ.get("PYTHONASYNCIODEBUG"))

120 

121 

# Character classes from the HTTP/1.1 grammar (RFC 2616 section 2.2).
CHAR = {chr(i) for i in range(0, 128)}  # any US-ASCII character
CTL = {chr(i) for i in range(0, 32)} | {  # control characters, incl. HT and DEL
    chr(127),
}
SEPARATORS = {
    "(",
    ")",
    "<",
    ">",
    "@",
    ",",
    ";",
    ":",
    "\\",
    '"',
    "/",
    "[",
    "]",
    "?",
    "=",
    "{",
    "}",
    " ",
    chr(9),
}
# token = 1*<any CHAR except CTLs or separators>
# Set difference is required: HT (chr(9)) is both a CTL and a separator.
# A symmetric difference (CHAR ^ CTL ^ SEPARATORS) would remove it and then
# add it straight back, wrongly admitting "\t" as a token character.
TOKEN = CHAR - CTL - SEPARATORS


# "application/json" or any "<type>/<subtype>+json" structured-syntax type.
json_re = re.compile(r"^(?:application/|[\w.-]+/[\w.+-]+?\+)json$", re.IGNORECASE)

151 

152 

def encode_basic_auth(login: str, password: str = "", encoding: str = "utf-8") -> str:
    """Encode HTTP Basic Authentication credentials as an Authorization header value.

    Returns a string of the form ``"Basic <base64>"`` suitable for use as the
    value of the ``Authorization`` (or ``Proxy-Authorization``) header.

    :param login: user name; must not contain ``":"`` (RFC 7617 section 2).
    :param password: password, may be empty.
    :param encoding: codec used to turn ``"login:password"`` into bytes
        before base64-encoding.
    :raises ValueError: if ``login`` contains a colon.
    """
    if ":" in login:
        raise ValueError('A ":" is not allowed in login (RFC 7617#section-2)')
    creds = f"{login}:{password}".encode(encoding)
    # Base64 output is pure ASCII no matter which charset the credentials used.
    # Decoding it with ``encoding`` instead would garble the header for codecs
    # that are not ASCII-compatible (e.g. UTF-16).
    return "Basic " + base64.b64encode(creds).decode("ascii")

163 

164 

def strip_auth_from_url(url: URL) -> tuple[URL, str | None]:
    """Split embedded credentials out of ``url``.

    :returns: ``(url_without_credentials, authorization_header_value)``. The
        second item is a ``"Basic ..."`` header value, or ``None`` when the URL
        carried neither a user nor a password (the URL is then returned as is).
    """
    # The raw components are checked first because yarl has most likely
    # already parsed them from the netloc and cached them.
    has_credentials = url.raw_user is not None or url.raw_password is not None
    if not has_credentials:
        return url, None
    header_value = encode_basic_auth(url.user or "", url.password or "")
    return url.with_user(None), header_value

176 

177 

def netrc_from_env() -> netrc.netrc | None:
    """Load the user's netrc file.

    The path comes from the ``NETRC`` environment variable when it is set.
    Otherwise it is ``~/.netrc`` (``~/_netrc`` on Windows).

    :returns: the parsed ``netrc.netrc`` object, or ``None`` when the home
        directory can't be resolved, or the file can't be read or parsed.
        Parse errors are logged as warnings. Read errors are logged only when
        ``NETRC`` was set or the default file appears to exist.
    """
    netrc_env = os.environ.get("NETRC")

    if netrc_env is None:
        try:
            home_dir = Path.home()
        except RuntimeError as e:
            # pathlib raises RuntimeError when the home directory can't be resolved.
            client_logger.debug(
                "Could not resolve home directory when "
                "trying to look for .netrc file: %s",
                e,
            )
            return None
        default_name = "_netrc" if platform.system() == "Windows" else ".netrc"
        netrc_path = home_dir / default_name
    else:
        netrc_path = Path(netrc_env)

    try:
        return netrc.netrc(str(netrc_path))
    except netrc.NetrcParseError as e:
        client_logger.warning("Could not parse .netrc file: %s", e)
    except OSError as e:
        # The file could not be read (missing, permissions, ...). A missing
        # default file is the normal case, so stay quiet unless the user
        # pointed NETRC at it explicitly or the default file seems to exist.
        file_present = False
        with contextlib.suppress(OSError):
            file_present = netrc_path.is_file()
        if netrc_env or file_present:
            client_logger.warning("Could not read .netrc file: %s", e)

    return None

221 

222 

@frozen_dataclass_decorator
class ProxyInfo:
    """Immutable pairing of a proxy URL with its ``Proxy-Authorization`` value."""

    proxy: URL  # proxy URL with any embedded credentials removed
    proxy_auth: str | None  # "Basic ..." header value, or None without credentials

227 

228 

def _auth_header_from_netrc(netrc_obj: netrc.netrc | None, host: str) -> str:
    """Build a ``Proxy-Authorization`` header value for ``host`` from netrc.

    The login is used as the user name. If it is empty, the account field is
    used instead. A missing password becomes an empty string.

    :raises LookupError: if ``netrc_obj`` is :py:data:`None` or if no
                         entry is found for the ``host``.
    """
    if netrc_obj is None:
        raise LookupError("No .netrc file found")

    entry = netrc_obj.authenticators(host)
    if entry is None:
        raise LookupError(f"No entry for {host!s} found in the `.netrc` file.")
    login, account, password = entry

    # TODO(PY311): username = login or account
    # Up to python 3.10, account could be None if not specified,
    # and login will be empty string if not specified. From 3.11,
    # login and account will be empty string if not specified.
    if login or account is None:
        username = login
    else:
        username = account

    # TODO(PY311): Remove this, as password will be empty string
    # if not specified
    if password is None:
        password = ""  # type: ignore[unreachable]

    return encode_basic_auth(username, password)

255 

256 

def proxies_from_env() -> dict[str, ProxyInfo]:
    """Collect HTTP/WebSocket proxy settings from the environment.

    Only the ``http``, ``https``, ``ws`` and ``wss`` entries returned by
    ``urllib.request.getproxies()`` are considered. Proxies whose own URL
    uses the ``https`` or ``wss`` scheme are unsupported: they are logged
    and skipped. Credentials come from the proxy URL itself or, failing
    that, from the netrc entry for the proxy host.

    :returns: mapping of request scheme to :class:`ProxyInfo`.
    """
    proxy_urls = {
        scheme: URL(raw_url)
        for scheme, raw_url in getproxies().items()
        if scheme in ("http", "https", "ws", "wss")
    }
    netrc_obj = netrc_from_env()
    stripped = {scheme: strip_auth_from_url(url) for scheme, url in proxy_urls.items()}

    result = {}
    for scheme, (proxy, auth) in stripped.items():
        if proxy.scheme in ("https", "wss"):
            client_logger.warning(
                "%s proxies %s are not supported, ignoring", proxy.scheme.upper(), proxy
            )
            continue
        if auth is None and netrc_obj and proxy.host is not None:
            # A missing netrc entry is not an error: ``auth`` simply stays None.
            with suppress(LookupError):
                auth = _auth_header_from_netrc(netrc_obj, proxy.host)
        result[scheme] = ProxyInfo(proxy, auth)
    return result

281 

282 

def get_env_proxy_for_url(url: URL) -> tuple[URL, str | None]:
    """Get a permitted proxy for the given URL from the env.

    :returns: ``(proxy_url, proxy_auth)`` for the proxy configured for
        ``url.scheme``; ``proxy_auth`` may be ``None``.
    :raises LookupError: if ``urllib.request.proxy_bypass`` says the host
        must not be proxied, or no proxy is configured for the URL's scheme.
    """
    if url.host is not None and proxy_bypass(url.host):
        raise LookupError(f"Proxying is disallowed for `{url.host!r}`")

    proxies_in_env = proxies_from_env()
    try:
        proxy_info = proxies_in_env[url.scheme]
    except KeyError:
        # The dict miss is an implementation detail. Don't chain it onto the
        # LookupError callers see ("During handling of the above exception...").
        raise LookupError(f"No proxies found for `{url!s}` in the env") from None
    return proxy_info.proxy, proxy_info.proxy_auth

295 

296 

@frozen_dataclass_decorator
class MimeType:
    """Components of a MIME type string, as produced by ``parse_mimetype``."""

    type: str  # main type, e.g. "text" in "text/html"
    subtype: str  # subtype without any "+suffix", e.g. "html"
    suffix: str  # text after "+" in the subtype, e.g. "json" in "vnd.api+json"
    parameters: "MultiDictProxy[str]"  # read-only ";key=value" parameters

303 

304 

@functools.lru_cache(maxsize=56)
def parse_mimetype(mimetype: str) -> MimeType:
    """Parse a MIME type string into its components.

    The type, subtype, suffix and parameter names are lower-cased.
    Parameter values are stripped of surrounding spaces and double quotes.
    A bare ``"*"`` is treated as ``"*/*"``. An empty input yields a MimeType
    with all parts empty.

    For example, ``parse_mimetype('text/html; charset=utf-8')`` gives
    type ``'text'``, subtype ``'html'``, suffix ``''`` and parameters
    ``{'charset': 'utf-8'}``.
    """
    if not mimetype:
        return MimeType(
            type="", subtype="", suffix="", parameters=MultiDictProxy(MultiDict())
        )

    head, *param_items = mimetype.split(";")
    params: MultiDict[str] = MultiDict()
    # filter(None, ...) drops empty segments such as the one after a trailing ";".
    for item in filter(None, param_items):
        name, _, raw_value = item.partition("=")
        params.add(name.lower().strip(), raw_value.strip(' "'))

    fulltype = head.strip().lower()
    if fulltype == "*":
        fulltype = "*/*"

    main_type, _, rest = fulltype.partition("/")
    sub_type, _, suffix = rest.partition("+")

    return MimeType(
        type=main_type, subtype=sub_type, suffix=suffix, parameters=MultiDictProxy(params)
    )

343 

344 

class EnsureOctetStream(EmailMessage):
    """``EmailMessage`` whose fallback content type is ``application/octet-stream``.

    The stdlib falls back to ``text/plain``. This subclass substitutes the
    octet-stream default that HTTP calls for.
    """

    def __init__(self) -> None:
        super().__init__()
        # https://www.rfc-editor.org/rfc/rfc9110#section-8.3-5
        self.set_default_type("application/octet-stream")

    def get_content_type(self) -> str:
        """Return the lower-cased ``type/subtype`` of the Content-Type header.

        A value with anything other than exactly one ``/`` before the first
        ``;`` is treated as invalid, and the default type is returned instead.
        This is a trimmed-down version of ``Message.get_content_type`` that
        expects the header to be present.
        """
        raw = self.get("content-type", "").lower()
        # Like the stdlib's _splitparam: keep only what precedes the first ";".
        ctype = raw.split(";", 1)[0].strip()
        return ctype if ctype.count("/") == 1 else self.get_default_type()

368 

369 

@functools.lru_cache(maxsize=56)
def parse_content_type(raw: str) -> tuple[str, MappingProxyType[str, str]]:
    """Parse a Content-Type header value.

    :returns: ``(content_type, params)``. ``content_type`` is the lower-cased
        ``type/subtype``, or ``application/octet-stream`` when the value is
        malformed. ``params`` is a read-only mapping of the header parameters.
        Results are cached, which is why the mapping is immutable.
    """
    parser = HeaderParser(EnsureOctetStream, policy=HTTP)
    msg = parser.parsestr(f"Content-Type: {raw}")
    # get_params() repeats the content type as its first (key, value) pair;
    # only the pairs after it are real parameters.
    param_pairs = msg.get_params(())[1:]
    return msg.get_content_type(), MappingProxyType(dict(param_pairs))

383 

384 

385def guess_filename(obj: Any, default: str | None = None) -> str | None: 

386 name = getattr(obj, "name", None) 

387 if name and isinstance(name, str) and name[0] != "<" and name[-1] != ">": 

388 return Path(name).name 

389 return default 

390 

391 

# Any character that is NOT RFC 5322 qtext (%d33 / %d35-91 / %d93-126).
# quoted_string() backslash-escapes these.
not_qtext_re = re.compile(r"[^\041\043-\133\135-\176]")
# Characters allowed inside a quoted-string: printable ASCII, space and tab.
QCONTENT = {chr(i) for i in range(0x20, 0x7F)} | {"\t"}


def quoted_string(content: str) -> str:
    """Return 7-bit content escaped for use inside a quoted-string.

    Format content into a quoted-string body as defined in RFC 5322 for
    Internet Message Format. Notice that this is not the 8-bit HTTP
    format, but the 7-bit email format. Every non-qtext character (space,
    tab, ``"`` and ``\\``) is escaped with a backslash. The surrounding
    double quotes are NOT added.

    :raises ValueError: if ``content`` contains a character outside
        printable US-ASCII, space and tab.
    """
    # Subset test: a strict-superset check (QCONTENT > set(content)) would
    # also reject content that happens to use every allowed character.
    if not (set(content) <= QCONTENT):
        raise ValueError(f"bad content for quoted-string {content!r}")
    return not_qtext_re.sub(lambda x: "\\" + x.group(0), content)

407 

408 

def content_disposition_header(
    disptype: str,
    quote_fields: bool = True,
    _charset: str = "utf-8",
    params: dict[str, str] | None = None,
) -> str:
    """Sets ``Content-Disposition`` header for MIME.

    This is the MIME payload Content-Disposition header from RFC 2183
    and RFC 7578 section 4.2, not the HTTP Content-Disposition from
    RFC 6266.

    disptype is a disposition type: inline, attachment, form-data.
    Should be valid extension token (see RFC 2183)

    quote_fields performs value quoting to 7-bit MIME headers
    according to RFC 7578. Set quote_fields to False if recipient
    can take 8-bit file names and field values.

    _charset specifies the charset to use when quote_fields is True.

    params is a dict with disposition params.

    Returns the header value, e.g. ``form-data; name="field"``.

    Raises ValueError if disptype or a parameter name is empty or
    contains characters outside TOKEN.
    """
    # Subset tests: a strict superset (TOKEN > set(...)) would wrongly reject
    # a name that happened to use every token character.
    if not disptype or not (set(disptype) <= TOKEN):
        raise ValueError(f"bad content disposition type {disptype!r}")

    value = disptype
    if params:
        lparams = []
        for key, val in params.items():
            if not key or not (set(key) <= TOKEN):
                raise ValueError(f"bad content disposition parameter {key!r}={val!r}")
            if quote_fields:
                if key.lower() == "filename":
                    # File names are percent-encoded inside plain double quotes.
                    qval = quote(val, "", encoding=_charset)
                    lparams.append((key, '"%s"' % qval))
                else:
                    try:
                        qval = quoted_string(val)
                    except ValueError:
                        # Not representable as 7-bit: fall back to the RFC 2231
                        # extended form  key*=charset''percent-encoded-value
                        qval = "".join(
                            (_charset, "''", quote(val, "", encoding=_charset))
                        )
                        lparams.append((key + "*", qval))
                    else:
                        lparams.append((key, '"%s"' % qval))
            else:
                # 8-bit mode: only backslashes and double quotes are escaped.
                qval = val.replace("\\", "\\\\").replace('"', '\\"')
                lparams.append((key, '"%s"' % qval))
        sparams = "; ".join("=".join(pair) for pair in lparams)
        value = "; ".join((value, sparams))
    return value

461 

462 

def is_expected_content_type(
    response_content_type: str, expected_content_type: str
) -> bool:
    """Tell whether a response content type can be handled as the expected one.

    Both arguments should be given without parameters. When JSON is expected,
    ``application/json`` and any ``<type>/<subtype>+json`` are accepted
    (matched with ``json_re``). Any other expected type is a plain substring
    test against the response type.
    """
    if expected_content_type != "application/json":
        return expected_content_type in response_content_type
    return json_re.match(response_content_type) is not None

473 

474 

475def is_ip_address(host: str | None) -> bool: 

476 """Check if host looks like an IP Address. 

477 

478 This check is only meant as a heuristic to ensure that 

479 a host is not a domain name. 

480 """ 

481 if not host: 

482 return False 

483 # For a host to be an ipv4 address, it must be all numeric. 

484 # The host must contain a colon to be an IPv6 address. 

485 return ":" in host or host.replace(".", "").isdigit() 

486 

487 

def is_canonical_ipv4_address(host: str) -> bool:
    """Check if host is a canonical dotted-quad IPv4 address.

    Rejects the legacy numeric forms that ``socket`` still accepts and
    maps onto an address, e.g. ``2130706433``, ``017700000001``, ``127.1``.
    Exactly four octets are required. Each is 1-3 ASCII digits, has no
    leading zero (except ``"0"`` itself) and is at most 255. Non-ASCII
    digits are rejected even though ``str.isdigit`` accepts them.
    """
    octets = host.split(".")
    if len(octets) != 4:
        return False
    return all(
        0 < len(octet) <= 3
        and octet.isascii()
        and octet.isdigit()
        and (octet == "0" or not octet.startswith("0"))
        and int(octet) <= 255
        for octet in octets
    )

508 

509 

510_cached_current_datetime: int | None = None 

511_cached_formatted_datetime = "" 

512 

513 

514def rfc822_formatted_time() -> str: 

515 global _cached_current_datetime 

516 global _cached_formatted_datetime 

517 

518 now = int(time.time()) 

519 if now != _cached_current_datetime: 

520 # Weekday and month names for HTTP date/time formatting; 

521 # always English! 

522 # Tuples are constants stored in codeobject! 

523 _weekdayname = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun") 

524 _monthname = ( 

525 "", # Dummy so we can use 1-based month numbers 

526 "Jan", 

527 "Feb", 

528 "Mar", 

529 "Apr", 

530 "May", 

531 "Jun", 

532 "Jul", 

533 "Aug", 

534 "Sep", 

535 "Oct", 

536 "Nov", 

537 "Dec", 

538 ) 

539 

540 year, month, day, hh, mm, ss, wd, *tail = time.gmtime(now) 

541 _cached_formatted_datetime = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( 

542 _weekdayname[wd], 

543 day, 

544 _monthname[month], 

545 year, 

546 hh, 

547 mm, 

548 ss, 

549 ) 

550 _cached_current_datetime = now 

551 return _cached_formatted_datetime 

552 

553 

554def _weakref_handle(info: "tuple[weakref.ref[object], str]") -> None: 

555 ref, name = info 

556 ob = ref() 

557 if ob is not None: 

558 with suppress(Exception): 

559 getattr(ob, name)() 

560 

561 

562def weakref_handle( 

563 ob: object, 

564 name: str, 

565 timeout: float | None, 

566 loop: asyncio.AbstractEventLoop, 

567 timeout_ceil_threshold: float = 5, 

568) -> asyncio.TimerHandle | None: 

569 if timeout is not None and timeout > 0: 

570 when = loop.time() + timeout 

571 if timeout >= timeout_ceil_threshold: 

572 when = ceil(when) 

573 

574 return loop.call_at(when, _weakref_handle, (weakref.ref(ob), name)) 

575 return None 

576 

577 

578def call_later( 

579 cb: Callable[[], Any], 

580 timeout: float | None, 

581 loop: asyncio.AbstractEventLoop, 

582 timeout_ceil_threshold: float = 5, 

583) -> asyncio.TimerHandle | None: 

584 if timeout is None or timeout <= 0: 

585 return None 

586 now = loop.time() 

587 when = calculate_timeout_when(now, timeout, timeout_ceil_threshold) 

588 return loop.call_at(when, cb) 

589 

590 

def calculate_timeout_when(
    loop_time: float,
    timeout: float,
    timeout_ceiling_threshold: float,
) -> float:
    """Return the loop time at which a ``timeout``-second timer should fire.

    Timeouts strictly longer than ``timeout_ceiling_threshold`` are rounded
    up to the next whole second (``math.ceil``, which returns an ``int``).
    Shorter timeouts keep their exact deadline.
    """
    deadline = loop_time + timeout
    return ceil(deadline) if timeout > timeout_ceiling_threshold else deadline

601 

602 

603class TimeoutHandle: 

604 """Timeout handle""" 

605 

606 __slots__ = ("_timeout", "_loop", "_ceil_threshold", "_callbacks") 

607 

608 def __init__( 

609 self, 

610 loop: asyncio.AbstractEventLoop, 

611 timeout: float | None, 

612 ceil_threshold: float = 5, 

613 ) -> None: 

614 self._timeout = timeout 

615 self._loop = loop 

616 self._ceil_threshold = ceil_threshold 

617 self._callbacks: list[ 

618 tuple[Callable[..., None], tuple[Any, ...], dict[str, Any]] 

619 ] = [] 

620 

621 def register( 

622 self, callback: Callable[..., None], *args: Any, **kwargs: Any 

623 ) -> None: 

624 self._callbacks.append((callback, args, kwargs)) 

625 

626 def close(self) -> None: 

627 self._callbacks.clear() 

628 

629 def start(self) -> asyncio.TimerHandle | None: 

630 timeout = self._timeout 

631 if timeout is not None and timeout > 0: 

632 when = self._loop.time() + timeout 

633 if timeout >= self._ceil_threshold: 

634 when = ceil(when) 

635 return self._loop.call_at(when, self.__call__) 

636 else: 

637 return None 

638 

639 def timer(self) -> "BaseTimerContext": 

640 if self._timeout is not None and self._timeout > 0: 

641 timer = TimerContext(self._loop) 

642 self.register(timer.timeout) 

643 return timer 

644 else: 

645 return TimerNoop() 

646 

647 def __call__(self) -> None: 

648 for cb, args, kwargs in self._callbacks: 

649 with suppress(Exception): 

650 cb(*args, **kwargs) 

651 

652 self._callbacks.clear() 

653 

654 

class BaseTimerContext(ContextManager["BaseTimerContext"]):
    """Common base of the timer context managers returned by ``TimeoutHandle.timer()``."""

    __slots__ = ()

    def assert_timeout(self) -> None:
        """Raise TimeoutError if timeout has been exceeded.

        The base implementation never raises; subclasses override it.
        """

661 

662 

class TimerNoop(BaseTimerContext):
    """Timer context that never times out (used when no positive timeout is set)."""

    __slots__ = ()

    def __enter__(self) -> BaseTimerContext:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # Returning None never suppresses an in-flight exception.
        return None

677 

678 

class TimerContext(BaseTimerContext):
    """Low resolution timeout context manager

    ``timeout()`` cancels every task currently inside the context.
    ``__exit__`` then turns the resulting ``CancelledError`` into
    ``asyncio.TimeoutError``.
    """

    __slots__ = ("_loop", "_tasks", "_cancelled", "_cancelling")

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        # Tasks inside the context: appended by __enter__, popped by __exit__.
        self._tasks: list[asyncio.Task[Any]] = []
        # Set by timeout() and never reset, so later entries fail immediately.
        self._cancelled = False
        # task.cancelling() value captured by __enter__ (only set on 3.11+).
        self._cancelling = 0

    def assert_timeout(self) -> None:
        """Raise TimeoutError if timer has already been cancelled."""
        if self._cancelled:
            raise asyncio.TimeoutError from None

    def __enter__(self) -> BaseTimerContext:
        """Register the current task so that ``timeout()`` can cancel it.

        :raises RuntimeError: if not called from inside a task.
        :raises asyncio.TimeoutError: if the timeout has already fired.
        """
        task = asyncio.current_task(loop=self._loop)
        if task is None:
            raise RuntimeError("Timeout context manager should be used inside a task")

        if sys.version_info >= (3, 11):
            # Remember if the task was already cancelling
            # so when we __exit__ we can decide if we should
            # raise asyncio.TimeoutError or let the cancellation propagate
            self._cancelling = task.cancelling()

        if self._cancelled:
            raise asyncio.TimeoutError from None

        self._tasks.append(task)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """Unregister the task and translate a timeout-induced cancellation.

        Only an exception whose type is exactly ``asyncio.CancelledError``,
        seen after ``timeout()`` fired, is converted to
        ``asyncio.TimeoutError``. On Python 3.11+, the task's cancellation
        request is first withdrawn with ``uncancel()``. If requests beyond
        the count recorded on entry remain, the task was also cancelled by
        someone else. In that case ``None`` is returned and the
        ``CancelledError`` propagates.
        """
        enter_task: asyncio.Task[Any] | None = None
        if self._tasks:
            enter_task = self._tasks.pop()

        if exc_type is asyncio.CancelledError and self._cancelled:
            assert enter_task is not None
            # The timeout was hit, and the task was cancelled
            # so we need to uncancel the last task that entered the context manager
            # since the cancellation should not leak out of the context manager
            if sys.version_info >= (3, 11):
                # If the task was already cancelling don't raise
                # asyncio.TimeoutError and instead return None
                # to allow the cancellation to propagate
                if enter_task.uncancel() > self._cancelling:
                    return None
            raise asyncio.TimeoutError from exc_val
        return None

    def timeout(self) -> None:
        """Fire the timeout.

        On the first call, cancel each task inside the context (deduplicated
        via ``set``). Then mark the context as cancelled.
        """
        if not self._cancelled:
            for task in set(self._tasks):
                task.cancel()

        self._cancelled = True

742 

743 

def ceil_timeout(
    delay: float | None, ceil_threshold: float = 5
) -> async_timeout.Timeout:
    """Return a timeout context manager that expires ``delay`` seconds from now.

    A ``None`` or non-positive ``delay`` gives a timeout with no deadline.
    Delays strictly longer than ``ceil_threshold`` have their deadline
    rounded up to a whole loop second. A running event loop is required
    (``asyncio.get_running_loop`` raises RuntimeError otherwise).
    """
    if delay is None or delay <= 0:
        return async_timeout.timeout(None)

    deadline = asyncio.get_running_loop().time() + delay
    if delay > ceil_threshold:
        deadline = ceil(deadline)
    return async_timeout.timeout_at(deadline)

756 

757 

class HeadersDictProxy(Mapping[str, str]):
    """Read-only ``Mapping`` view over a ``CIMultiDict`` of headers.

    A header that occurs several times is exposed as a single value, with
    its occurrences joined by ``", "``.
    """

    def __init__(self, md: CIMultiDict[str]):
        self._md = md

    def getall(self, key: str) -> tuple[str, ...]:
        """Split the combined value of header ``key`` into list elements.

        Elements are separated by top-level commas. A top-level quoted-string
        element is returned without its quotes and with quoted-pairs
        unescaped. In an unquoted element, quoted-pairs are unescaped only
        inside parameter quoted-strings and comments, which keep their
        delimiters. Empty elements are dropped. A missing header yields ``()``.
        """
        combined = self.get(key, "")
        unescape = _QUOTED_PAIR_SUB.sub

        def _unescape_protected(span: "re.Match[str]") -> str:
            return unescape(r"\1", span.group())

        elements = []
        for match in _LIST_ELEMENT_RE.finditer(combined):
            quoted = match.group(1)
            if quoted is not None:
                elements.append(unescape(r"\1", quoted))
                continue
            element = match.group(2).strip()
            if element:
                elements.append(_PROTECTED_RE.sub(_unescape_protected, element))
        return tuple(elements)

    def __eq__(self, other: object) -> bool:
        # Equality is whatever the wrapped multidict says (may be NotImplemented).
        return self._md.__eq__(other)

    def __getitem__(self, key: str) -> str:
        # getall() raises KeyError for an absent header, as Mapping requires.
        occurrences = self._md.getall(key)
        return ", ".join(occurrences)

    def __iter__(self) -> Iterator[str]:
        # The multidict may repeat a key: yield each one once, in first-seen order.
        seen = set()
        for key in self._md:
            if key not in seen:
                seen.add(key)
                yield key

    def __len__(self) -> int:
        return len({*self._md.keys()})

    def __repr__(self) -> str:
        pairs = ", ".join(f"'{k}': {v!r}" for k, v in self.items())
        return f"<{self.__class__.__name__}({pairs})>"

800 

801 

class HeadersMixin:
    """Mixin for handling headers.

    The host class provides ``_headers``. The parsed Content-Type is cached
    and re-parsed only when the raw header value changes.
    """

    _headers: Mapping[str, str]
    _content_type: str | None = None
    _content_dict: dict[str, str] | None = None
    # Raw Content-Type value last parsed; ``sentinel`` means "never parsed".
    _stored_content_type: str | None | _SENTINEL = sentinel

    def _parse_content_type(self, raw: str | None) -> None:
        self._stored_content_type = raw
        if raw is not None:
            content_type, params = parse_content_type(raw)
            self._content_type = content_type
            # parse_content_type() returns a read-only (cached) mapping; copy it
            # so this instance owns a mutable dict.
            self._content_dict = params.copy()
        else:
            # No Content-Type header: default to application/octet-stream
            # (https://www.rfc-editor.org/rfc/rfc9110#section-8.3-5).
            self._content_type = "application/octet-stream"
            self._content_dict = {}

    @property
    def content_type(self) -> str:
        """The value of content part for Content-Type HTTP header."""
        raw = self._headers.get(hdrs.CONTENT_TYPE)
        if raw != self._stored_content_type:
            self._parse_content_type(raw)
        assert self._content_type is not None
        return self._content_type

    @property
    def charset(self) -> str | None:
        """The value of charset part for Content-Type HTTP header."""
        raw = self._headers.get(hdrs.CONTENT_TYPE)
        if raw != self._stored_content_type:
            self._parse_content_type(raw)
        assert self._content_dict is not None
        return self._content_dict.get("charset")

    @property
    def content_length(self) -> int | None:
        """The value of Content-Length HTTP header (``int()`` raises on a malformed value)."""
        raw_length = self._headers.get(hdrs.CONTENT_LENGTH)
        if raw_length is None:
            return None
        return int(raw_length)

845 

846 

def set_result(fut: "asyncio.Future[_T]", result: _T) -> None:
    """Set ``result`` on ``fut``; a no-op if the future is already done."""
    if fut.done():
        return
    fut.set_result(result)

850 

851 

852_EXC_SENTINEL = BaseException() 

853 

854 

855class ErrorableProtocol(Protocol): 

856 def set_exception( 

857 self, 

858 exc: type[BaseException] | BaseException, 

859 exc_cause: BaseException = ..., 

860 ) -> None: ... 

861 

862 

863def set_exception( 

864 fut: Union["asyncio.Future[_T]", ErrorableProtocol], 

865 exc: type[BaseException] | BaseException, 

866 exc_cause: BaseException = _EXC_SENTINEL, 

867) -> None: 

868 """Set future exception. 

869 

870 If the future is marked as complete, this function is a no-op. 

871 

872 :param exc_cause: An exception that is a direct cause of ``exc``. 

873 Only set if provided. 

874 """ 

875 if asyncio.isfuture(fut) and fut.done(): 

876 return 

877 

878 exc_is_sentinel = exc_cause is _EXC_SENTINEL 

879 exc_causes_itself = exc is exc_cause 

880 if not exc_is_sentinel and not exc_causes_itself: 

881 exc.__cause__ = exc_cause 

882 

883 fut.set_exception(exc) 

884 

885 

@functools.total_ordering
class BaseKey(Generic[_T]):
    """Base for concrete context storage key classes.

    Each storage is provided with its own sub-class for the sake of some
    additional type safety. A key's name is prefixed with the ``__name__``
    of the nearest enclosing frame that runs module-level code, and keys
    are ordered by that full name.
    """

    __slots__ = ("_name", "_t", "__orig_class__")

    # This may be set by Python when instantiating with a generic type. We need to
    # support this, in order to support types that are not concrete classes,
    # like Iterable, which can't be passed as the second parameter to __init__.
    __orig_class__: type[object]

    # TODO(PY314): Change Type to TypeForm (this should resolve unreachable below).
    def __init__(self, name: str, t: type[_T] | None = None):
        # Prefix with module name to help deduplicate key names: walk outwards
        # to the first frame that is executing module-level code.
        frame = inspect.currentframe()
        while frame is not None and frame.f_code.co_name != "<module>":
            frame = frame.f_back
        if frame is None:
            raise RuntimeError("Failed to get module name.")
        module: str = frame.f_globals["__name__"]

        self._name = module + "." + name
        self._t = t

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, BaseKey):
            return True  # A BaseKey sorts before any non-BaseKey object.
        return self._name < other._name

    def __repr__(self) -> str:
        t = self._t
        if t is None:
            with suppress(AttributeError):
                # Fall back to the subscripted type argument, e.g. int for AppKey[int].
                t = get_args(self.__orig_class__)[0]

        if t is None:
            t_repr = "<<Unknown>>"
        elif not isinstance(t, type):
            t_repr = repr(t)  # type: ignore[unreachable]
        elif t.__module__ == "builtins":
            t_repr = t.__qualname__
        else:
            t_repr = f"{t.__module__}.{t.__qualname__}"
        return f"<{self.__class__.__name__}({self._name}, type={t_repr})>"

938 

939 

class AppKey(BaseKey[_T]):
    """Typed key for values stored on an Application (enables static type checking)."""


class RequestKey(BaseKey[_T]):
    """Typed key for values stored on a Request (enables static type checking)."""


class ResponseKey(BaseKey[_T]):
    """Typed key for values stored on a Response (enables static type checking)."""

950 

951 

@final
class ChainMapProxy(Mapping[str | AppKey[Any], Any]):
    """Read-only mapping that looks keys up across several mappings.

    A key resolves to its value in the first mapping (in constructor order)
    that contains it. Subclassing is forbidden.
    """

    __slots__ = ("_maps",)

    def __init__(self, maps: Iterable[Mapping[str | AppKey[Any], Any]]) -> None:
        # Snapshot the sequence of mappings; the mappings themselves are not copied.
        self._maps = tuple(maps)

    def __init_subclass__(cls) -> None:
        raise TypeError(
            f"Inheritance class {cls.__name__} from ChainMapProxy is forbidden"
        )

    @overload  # type: ignore[override]
    def __getitem__(self, key: AppKey[_T]) -> _T: ...

    @overload
    def __getitem__(self, key: str) -> Any: ...

    def __getitem__(self, key: str | AppKey[_T]) -> Any:
        for mapping in self._maps:
            with suppress(KeyError):
                return mapping[key]
        raise KeyError(key)

    @overload  # type: ignore[override]
    def get(self, key: AppKey[_T], default: _S) -> _T | _S: ...

    @overload
    def get(self, key: AppKey[_T], default: None = ...) -> _T | None: ...

    @overload
    def get(self, key: str, default: Any = ...) -> Any: ...

    def get(self, key: str | AppKey[_T], default: Any = None) -> Any:
        with suppress(KeyError):
            return self[key]
        return default

    def __len__(self) -> int:
        # Count distinct keys; set.update can reuse dicts' stored hash values.
        distinct: set[str | AppKey[Any]] = set()
        distinct.update(*self._maps)
        return len(distinct)

    def __iter__(self) -> Iterator[str | AppKey[Any]]:
        # Merge from the lowest-priority mapping up, so that a key's position
        # is set by its first insertion (dict.update reuses stored hashes).
        merged: dict[str | AppKey[Any], Any] = {}
        for mapping in reversed(self._maps):
            merged.update(mapping)
        return iter(merged)

    def __contains__(self, key: object) -> bool:
        return any(key in mapping for mapping in self._maps)

    def __bool__(self) -> bool:
        return any(self._maps)

    def __repr__(self) -> str:
        content = ", ".join(repr(mapping) for mapping in self._maps)
        return f"ChainMapProxy({content})"

1013 

1014 

class CookieMixin:
    """Mixin that lets a response-like object collect outgoing cookies.

    Cookies are kept in a :class:`~http.cookies.SimpleCookie` that is created
    lazily on first use.
    """

    # Lazily created cookie store; stays None until a cookie is read or set.
    _cookies: SimpleCookie | None = None

    @property
    def cookies(self) -> SimpleCookie:
        """Return the cookie store, creating an empty one on first access."""
        if self._cookies is None:
            self._cookies = SimpleCookie()
        return self._cookies

    def set_cookie(
        self,
        name: str,
        value: str,
        *,
        expires: str | None = None,
        domain: str | None = None,
        max_age: int | str | None = None,
        path: str = "/",
        secure: bool | None = None,
        httponly: bool | None = None,
        samesite: str | None = None,
        partitioned: bool | None = None,
    ) -> None:
        """Set or update response cookie.

        Creates cookie *name* or overwrites the value of an existing one.
        On an existing cookie, attributes passed as ``None`` keep their
        current value, with two exceptions:

        * a leftover deletion ``expires`` (the Unix-epoch date written by
          :meth:`del_cookie`) is removed;
        * any previous ``max-age`` is removed.

        ``path`` is always (re)set.

        When ``DEBUG`` is enabled, a :class:`UserWarning` is emitted if the
        serialized cookie is longer than ``COOKIE_MAX_LENGTH``.
        """
        if self._cookies is None:
            self._cookies = SimpleCookie()

        # SimpleCookie reuses the existing Morsel for ``name``, so attributes
        # set by earlier calls survive unless overwritten or cleared below.
        self._cookies[name] = value
        c = self._cookies[name]

        if expires is not None:
            c["expires"] = expires
        elif c.get("expires") == "Thu, 01 Jan 1970 00:00:00 GMT":
            # Re-setting a previously deleted cookie: drop the epoch expiry,
            # otherwise the client would discard the new value immediately.
            del c["expires"]

        if domain is not None:
            c["domain"] = domain

        if max_age is not None:
            c["max-age"] = str(max_age)
        elif "max-age" in c:
            del c["max-age"]

        c["path"] = path

        if secure is not None:
            c["secure"] = secure
        if httponly is not None:
            c["httponly"] = httponly
        if samesite is not None:
            c["samesite"] = samesite

        if partitioned is not None:
            c["partitioned"] = partitioned

        if DEBUG:
            # output() prefixes the header text plus a space; strip the space.
            cookie_length = len(c.output(header="")[1:])
            if cookie_length > COOKIE_MAX_LENGTH:
                warnings.warn(
                    f"The size of cookie {name!r} is too large, "
                    "it might get ignored by the client.",
                    UserWarning,
                    stacklevel=2,
                )

    def del_cookie(
        self,
        name: str,
        *,
        domain: str | None = None,
        path: str = "/",
        secure: bool | None = None,
        httponly: bool | None = None,
        samesite: str | None = None,
        partitioned: bool | None = None,
    ) -> None:
        """Delete cookie.

        Discards any pending cookie called *name* and replaces it with an
        empty, already-expired one (``Max-Age=0`` and an epoch ``Expires``).

        Clients identify a cookie by its name, domain and path, so *domain*
        and *path* should match the values used when the cookie was set.
        Likewise, pass ``partitioned=True`` to remove a partitioned (CHIPS)
        cookie, which clients keep in a separate per-partition store.
        """
        if self._cookies is not None:
            # Start from a fresh Morsel so no attributes of the old cookie
            # leak into the deletion header.
            self._cookies.pop(name, None)
        self.set_cookie(
            name,
            "",
            max_age=0,
            expires="Thu, 01 Jan 1970 00:00:00 GMT",
            domain=domain,
            path=path,
            secure=secure,
            httponly=httponly,
            samesite=samesite,
            partitioned=partitioned,
        )

1113 

1114 

def populate_with_cookies(headers: "CIMultiDict[str]", cookies: SimpleCookie) -> None:
    """Add one ``Set-Cookie`` header to *headers* for each morsel in *cookies*."""
    for morsel in cookies.values():
        # With header="" output() returns " name=value; ..."; drop the space.
        headers.add(hdrs.SET_COOKIE, morsel.output(header="")[1:])

1119 

1120 

# Entity-tag grammar, https://tools.ietf.org/html/rfc7232#section-2.3 :
#   entity-tag = [ weak ] opaque-tag
#   weak       = %x57.2F ; "W/"
#   opaque-tag = DQUOTE *etagc DQUOTE
#   etagc      = %x21 / %x23-7E / obs-text   (obs-text = %x80-FF)
# NOTE(review): ``+`` below requires at least one etagc, so the empty
# opaque-tag ``""`` allowed by ``*etagc`` is rejected — confirm intended.
_ETAGC = r"[!\x23-\x7E\x80-\xff]+"
# Groups: 1 = optional "W/" weak prefix, 2 = tag text between the quotes.
_QUOTED_ETAG = rf'(W/)?"({_ETAGC})"'

_ETAGC_RE = re.compile(_ETAGC)
QUOTED_ETAG_RE = re.compile(_QUOTED_ETAG)
# Scans a comma-separated list of quoted etags. Groups: 1 = whole quoted tag,
# 2 = "W/" prefix, 3 = tag text, 4 = any single character that does not start
# a well-formed entry (presumably so callers can detect malformed lists).
LIST_QUOTED_ETAG_RE = re.compile(rf"({_QUOTED_ETAG})(?:\s*,\s*|$)|(.)")

# Wildcard entity tag.
ETAG_ANY = "*"

1129 

1130 

@frozen_dataclass_decorator
class ETag:
    """Immutable entity-tag value.

    ``is_weak`` presumably mirrors the ``W/`` weak-validator prefix of the
    quoted header form — confirm against the callers that build instances.
    """

    # Tag text; assumed to be stored without surrounding quotes (TODO confirm).
    value: str
    # True for a weak validator.
    is_weak: bool = False

1135 

1136 

def validate_etag_value(value: str) -> None:
    """Ensure *value* is usable as the opaque part of an entity tag.

    The wildcard ``ETAG_ANY`` is always accepted. Any other value must be a
    non-empty run of characters matched by ``_ETAGC_RE``, which excludes the
    double quote.

    :raises ValueError: if *value* is not acceptable.
    """
    if value == ETAG_ANY or _ETAGC_RE.fullmatch(value):
        return
    raise ValueError(
        f"Value {value!r} is not a valid etag. Maybe it contains '\"'?"
    )

1142 

1143 

1144def parse_http_date(date_str: str | None) -> datetime.datetime | None: 

1145 """Process a date string, return a datetime object""" 

1146 if date_str is not None: 

1147 timetuple = parsedate(date_str) 

1148 if timetuple is not None: 

1149 with suppress(ValueError): 

1150 return datetime.datetime(*timetuple[:6], tzinfo=datetime.timezone.utc) 

1151 return None 

1152 

1153 

@functools.lru_cache
def must_be_empty_body(method: str, code: int) -> bool:
    """Tell whether a response must be sent without a body.

    True when *code* is in ``EMPTY_BODY_STATUS_CODES``, when *method* is in
    ``EMPTY_BODY_METHODS``, or for a 2xx status answering a method in
    ``hdrs.METH_CONNECT_ALL``. Results are memoized by ``lru_cache``
    (default ``maxsize`` of 128).
    """
    if code in EMPTY_BODY_STATUS_CODES or method in EMPTY_BODY_METHODS:
        return True
    # A successful CONNECT turns the connection into a tunnel; no body follows.
    return 200 <= code < 300 and method in hdrs.METH_CONNECT_ALL

1162 

1163 

def should_remove_content_length(method: str, code: int) -> bool:
    """Tell whether a Content-Length header must be dropped from a response.

    True when *code* is in ``EMPTY_BODY_STATUS_CODES`` or for a 2xx status
    answering a method in ``hdrs.METH_CONNECT_ALL``. Unlike
    ``must_be_empty_body``, ``EMPTY_BODY_METHODS`` is not consulted, so the
    result must always be a subset of ``must_be_empty_body``.
    """
    # https://www.rfc-editor.org/rfc/rfc9110.html#section-8.6-8
    # https://www.rfc-editor.org/rfc/rfc9110.html#section-15.4.5-4
    if code in EMPTY_BODY_STATUS_CODES:
        return True
    return 200 <= code < 300 and method in hdrs.METH_CONNECT_ALL