Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/helpers.py: 39%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

559 statements  

1"""Various helper functions""" 

2 

3import asyncio 

4import base64 

5import contextlib 

6import dataclasses 

7import datetime 

8import enum 

9import functools 

10import inspect 

11import netrc 

12import os 

13import platform 

14import re 

15import sys 

16import time 

17import warnings 

18import weakref 

19from collections.abc import Callable, Iterable, Iterator, Mapping 

20from contextlib import suppress 

21from email.message import EmailMessage 

22from email.parser import HeaderParser 

23from email.policy import HTTP 

24from email.utils import parsedate 

25from http.cookies import SimpleCookie 

26from math import ceil 

27from pathlib import Path 

28from types import MappingProxyType, TracebackType 

29from typing import ( 

30 TYPE_CHECKING, 

31 Any, 

32 ContextManager, 

33 Generic, 

34 Protocol, 

35 TypeVar, 

36 Union, 

37 final, 

38 get_args, 

39 overload, 

40) 

41from urllib.parse import quote 

42from urllib.request import getproxies, proxy_bypass 

43 

44from multidict import CIMultiDict, MultiDict, MultiDictProxy 

45from propcache.api import under_cached_property as reify 

46from yarl import URL 

47 

48from . import hdrs 

49from .log import client_logger 

50from .typedefs import PathLike # noqa 

51 

52if sys.version_info >= (3, 11): 

53 import asyncio as async_timeout 

54else: 

55 import async_timeout 

56 

if TYPE_CHECKING:
    # For static analysis the plain ``dataclass`` decorator is bound to this
    # name (note: without the frozen/slots arguments used at runtime).
    from dataclasses import dataclass as frozen_dataclass_decorator
else:
    # At runtime the decorator pre-binds ``frozen=True, slots=True``, so every
    # class decorated with it is immutable and uses __slots__.
    frozen_dataclass_decorator = functools.partial(
        dataclasses.dataclass, frozen=True, slots=True
    )

# Names exported by ``from ... import *``.
__all__ = ("ChainMapProxy", "ETag", "frozen_dataclass_decorator", "reify")

65 

# This is the default size/limit for several operations.
# Matches the max size we receive from sockets:
# https://github.com/python/cpython/blob/1857a40807daeae3a1bf5efb682de9c9ae6df845/Lib/asyncio/selector_events.py#L766
DEFAULT_CHUNK_SIZE = 2**18  # 256 KiB
# Rendered cookie length (in characters) above which set_cookie warns when
# running in debug mode.
COOKIE_MAX_LENGTH = 4096
# Unescapes a quoted-pair (backslash followed by any character) when used with
# the ``\1`` replacement template.
_QUOTED_PAIR_SUB = re.compile(r"\\(.)")
# Regex source for a complete double-quoted string (quotes included) whose
# interior may contain backslash escapes.
_QUOTED_STRING = r'"(?:[^"\\]|\\.)*"'
# Regex source for the inside of a parenthesised comment: bare parentheses are
# not allowed, but any backslash-escaped character is.
_ESCAPED_COMMENT = r"(?:[^()\\]|\\.)*"
# Matches one element in a comma-separated header list.
# Group 1: content of a top-level quoted-string (quotes stripped).
# Group 2: an unquoted element (may contain parameter quoted-strings / comments).
# Spaces/tabs around the element and the terminating comma (or end of input)
# are consumed by the match.  Inside group 2 a quoted-string that directly
# follows ``<non-space>=`` and a comment preceded by whitespace are matched as
# units, so commas inside them normally do not split the element.
_LIST_ELEMENT_RE = re.compile(
    rf"""
    [ \t]*
    (?:
        "( (?:[^"\\]|\\.)* )" # group 1: top-level quoted-string
        | ( # group 2: unquoted element
            (?:
                (?<=[^\s]=) {_QUOTED_STRING} # parameter quoted value
                | (?<=\s) \( {_ESCAPED_COMMENT} \) # comment
                | [^,] # any non-comma character
            )+?
        )
    )
    [ \t]* (?:,|\Z)
    """,
    re.VERBOSE,
)
# Finds parameter quoted-strings and comments inside an unquoted element for unescaping.
# Used so that quoted-pairs are only unescaped inside those protected spans.
_PROTECTED_RE = re.compile(
    rf"""
    (?<=[^\s]=) {_QUOTED_STRING} # parameter quoted-string
    | (?<=\s) \( {_ESCAPED_COMMENT} \) # comment
    """,
    re.VERBOSE,
)

102 

# Type variables used by the generic annotations in this module.
_T = TypeVar("_T")
_S = TypeVar("_S")

# Single-member enum used as a "not set yet" marker where None is itself a
# meaningful value (e.g. HeadersMixin._stored_content_type).
_SENTINEL = enum.Enum("_SENTINEL", "sentinel")
sentinel = _SENTINEL.sentinel

# True when the AIOHTTP_NO_EXTENSIONS environment variable is non-empty.
# NOTE(review): presumably consulted elsewhere to skip compiled speedups —
# confirm at the use sites.
NO_EXTENSIONS = bool(os.environ.get("AIOHTTP_NO_EXTENSIONS"))

# Response status codes that never carry a body: all 1xx, 204 and 304.
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
EMPTY_BODY_STATUS_CODES = frozenset((204, 304, *range(100, 200)))
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.2
EMPTY_BODY_METHODS = hdrs.METH_HEAD_ALL

# Debug mode: on under ``python -X dev``, or when PYTHONASYNCIODEBUG is set to
# a non-empty value (environment ignored when Python runs with -E).
DEBUG = sys.flags.dev_mode or (
    not sys.flags.ignore_environment and bool(os.environ.get("PYTHONASYNCIODEBUG"))
)

120 

121 

# Character classes from RFC 2616 section 2.2 (token rules now in RFC 9110
# section 5.6.2).
CHAR = {chr(i) for i in range(0, 128)}
CTL = {chr(i) for i in range(0, 32)} | {
    chr(127),
}
SEPARATORS = {
    "(",
    ")",
    "<",
    ">",
    "@",
    ",",
    ";",
    ":",
    "\\",
    '"',
    "/",
    "[",
    "]",
    "?",
    "=",
    "{",
    "}",
    " ",
    chr(9),
}
# token = 1*<any CHAR except CTLs or separators>.
# Set difference is required: the former ``CHAR ^ CTL ^ SEPARATORS`` put TAB
# back in, because TAB (chr(9)) is a member of both CTL and SEPARATORS.
TOKEN = CHAR - CTL - SEPARATORS

148 

149 

# Matches JSON media types case-insensitively: ``application/json`` and any
# ``<type>/<subtype>+json`` structured-syntax type (e.g. application/problem+json).
json_re = re.compile(r"^(?:application/|[\w.-]+/[\w.+-]+?\+)json$", re.IGNORECASE)

151 

152 

def encode_basic_auth(login: str, password: str = "", encoding: str = "utf-8") -> str:
    """Encode HTTP Basic Authentication credentials as an Authorization header value.

    Returns a string of the form ``"Basic <base64>"`` suitable for use as the
    value of the ``Authorization`` (or ``Proxy-Authorization``) header.

    :param login: user name; must not contain ``":"``.
    :param password: password, empty by default.
    :param encoding: charset used to convert ``"login:password"`` to bytes
        before base64-encoding.
    :raises ValueError: if ``login`` contains ``":"`` (RFC 7617 section 2).
    """
    if ":" in login:
        raise ValueError('A ":" is not allowed in login (RFC 7617#section-2)')
    creds = f"{login}:{password}".encode(encoding)
    # Base64 output is always plain ASCII.  Decoding it with the *credential*
    # encoding produced garbage for non-ASCII-compatible charsets (e.g. UTF-16).
    return "Basic " + base64.b64encode(creds).decode("ascii")

163 

164 

def strip_auth_from_url(url: URL) -> tuple[URL, str | None]:
    """Separate embedded credentials from *url*.

    Returns ``(url_without_credentials, authorization_header_value)``.  The
    second element is ``None`` when the URL has neither a user nor a password.
    """
    # raw_user / raw_password are checked first because yarl has most likely
    # already parsed them out of the netloc and cached them.
    has_credentials = url.raw_user is not None or url.raw_password is not None
    if not has_credentials:
        return url, None
    without_auth = url.with_user(None)
    header_value = encode_basic_auth(url.user or "", url.password or "")
    return without_auth, header_value

176 

177 

178def netrc_from_env() -> netrc.netrc | None: 

179 """Load netrc from file. 

180 

181 Attempt to load it from the path specified by the env-var 

182 NETRC or in the default location in the user's home directory. 

183 

184 Returns None if it couldn't be found or fails to parse. 

185 """ 

186 netrc_env = os.environ.get("NETRC") 

187 

188 if netrc_env is not None: 

189 netrc_path = Path(netrc_env) 

190 else: 

191 try: 

192 home_dir = Path.home() 

193 except RuntimeError as e: 

194 # if pathlib can't resolve home, it may raise a RuntimeError 

195 client_logger.debug( 

196 "Could not resolve home directory when " 

197 "trying to look for .netrc file: %s", 

198 e, 

199 ) 

200 return None 

201 

202 netrc_path = home_dir / ( 

203 "_netrc" if platform.system() == "Windows" else ".netrc" 

204 ) 

205 

206 try: 

207 return netrc.netrc(str(netrc_path)) 

208 except netrc.NetrcParseError as e: 

209 client_logger.warning("Could not parse .netrc file: %s", e) 

210 except OSError as e: 

211 netrc_exists = False 

212 with contextlib.suppress(OSError): 

213 netrc_exists = netrc_path.is_file() 

214 # we couldn't read the file (doesn't exist, permissions, etc.) 

215 if netrc_env or netrc_exists: 

216 # only warn if the environment wanted us to load it, 

217 # or it appears like the default file does actually exist 

218 client_logger.warning("Could not read .netrc file: %s", e) 

219 

220 return None 

221 

222 

@frozen_dataclass_decorator
class ProxyInfo:
    """A proxy URL together with the ``Proxy-Authorization`` value to send to it."""

    proxy: URL  # proxy URL; proxies_from_env passes it with credentials stripped
    proxy_auth: str | None  # "Basic ..." header value, or None for no authentication

227 

228 

229def _auth_header_from_netrc(netrc_obj: netrc.netrc | None, host: str) -> str: 

230 """Return a ``Proxy-Authorization`` header value for ``host`` from netrc. 

231 

232 :raises LookupError: if ``netrc_obj`` is :py:data:`None` or if no 

233 entry is found for the ``host``. 

234 """ 

235 if netrc_obj is None: 

236 raise LookupError("No .netrc file found") 

237 auth_from_netrc = netrc_obj.authenticators(host) 

238 

239 if auth_from_netrc is None: 

240 raise LookupError(f"No entry for {host!s} found in the `.netrc` file.") 

241 login, account, password = auth_from_netrc 

242 

243 # TODO(PY311): username = login or account 

244 # Up to python 3.10, account could be None if not specified, 

245 # and login will be empty string if not specified. From 3.11, 

246 # login and account will be empty string if not specified. 

247 username = login if (login or account is None) else account 

248 

249 # TODO(PY311): Remove this, as password will be empty string 

250 # if not specified 

251 if password is None: 

252 password = "" # type: ignore[unreachable] 

253 

254 return encode_basic_auth(username, password) 

255 

256 

def proxies_from_env() -> dict[str, ProxyInfo]:
    """Collect HTTP/WebSocket proxy settings from the environment.

    Only the ``http``, ``https``, ``ws`` and ``wss`` entries reported by
    :func:`urllib.request.getproxies` are used.  Proxies whose own URL scheme
    is ``https`` or ``wss`` are skipped with a warning.  When a proxy URL
    carries no credentials, a matching netrc entry (if any) supplies the
    authorization value.
    """
    supported = ("http", "https", "ws", "wss")
    proxy_urls = {}
    for scheme, raw_url in getproxies().items():
        if scheme in supported:
            proxy_urls[scheme] = URL(raw_url)
    netrc_obj = netrc_from_env()
    stripped = {scheme: strip_auth_from_url(u) for scheme, u in proxy_urls.items()}

    result: dict[str, ProxyInfo] = {}
    for proto, (proxy, auth) in stripped.items():
        if proxy.scheme in ("https", "wss"):
            client_logger.warning(
                "%s proxies %s are not supported, ignoring", proxy.scheme.upper(), proxy
            )
            continue
        if netrc_obj and auth is None and proxy.host is not None:
            try:
                auth = _auth_header_from_netrc(netrc_obj, proxy.host)
            except LookupError:
                auth = None
        result[proto] = ProxyInfo(proxy, auth)
    return result

281 

282 

def get_env_proxy_for_url(url: URL) -> tuple[URL, str | None]:
    """Get a permitted proxy for the given URL from the env.

    Returns ``(proxy_url, proxy_auth)`` for the proxy configured for the
    URL's scheme.

    :raises LookupError: if the host is excluded by the proxy bypass rules,
        or no proxy is configured for ``url.scheme``.
    """
    host = url.host
    if host is not None and proxy_bypass(host):
        raise LookupError(f"Proxying is disallowed for `{url.host!r}`")

    configured = proxies_from_env()
    try:
        info = configured[url.scheme]
    except KeyError:
        raise LookupError(f"No proxies found for `{url!s}` in the env")
    return info.proxy, info.proxy_auth

295 

296 

@frozen_dataclass_decorator
class MimeType:
    """Components of a MIME type, as produced by ``parse_mimetype``."""

    type: str  # major type, e.g. "text"
    subtype: str  # subtype without any "+suffix", e.g. "html"
    suffix: str  # text after "+" in the subtype (e.g. "json"); "" when absent
    parameters: "MultiDictProxy[str]"  # read-only view of the ";"-separated parameters

303 

304 

@functools.lru_cache(maxsize=56)
def parse_mimetype(mimetype: str) -> MimeType:
    """Split a MIME type string into a :class:`MimeType`.

    The type, subtype, suffix and parameter names are lower-cased; parameter
    values have surrounding spaces and double quotes stripped.  A bare ``*``
    is read as ``*/*``.  An empty input yields an all-empty MimeType.
    Results are cached.

    Example:

    >>> parse_mimetype('text/html; charset=utf-8')
    MimeType(type='text', subtype='html', suffix='',
             parameters={'charset': 'utf-8'})

    """
    if not mimetype:
        return MimeType(
            type="", subtype="", suffix="", parameters=MultiDictProxy(MultiDict())
        )

    head, *param_items = mimetype.split(";")
    params: MultiDict[str] = MultiDict()
    for item in param_items:
        if item:
            name, _, raw_value = item.partition("=")
            params.add(name.lower().strip(), raw_value.strip(' "'))

    fulltype = head.strip().lower()
    if fulltype == "*":
        fulltype = "*/*"

    major, _, rest = fulltype.partition("/")
    minor, _, suffix = rest.partition("+")

    return MimeType(
        type=major, subtype=minor, suffix=suffix, parameters=MultiDictProxy(params)
    )

343 

344 

class EnsureOctetStream(EmailMessage):
    """EmailMessage whose fallback content type is application/octet-stream."""

    def __init__(self) -> None:
        super().__init__()
        # A missing or malformed media type is treated as octet-stream:
        # https://www.rfc-editor.org/rfc/rfc9110#section-8.3-5
        self.set_default_type("application/octet-stream")

    def get_content_type(self) -> str:
        """Return the lower-cased ``type/subtype`` from the Content-Type header.

        Simplified re-implementation of ``Message.get_content_type``: a value
        without exactly one ``/`` before the first ``;`` yields the default type
        set in ``__init__`` (application/octet-stream) instead of text/plain.
        How this class is used guarantees a Content-Type header is present.
        """
        header = self.get("content-type", "").lower()
        # Same idea as the standard library's _splitparam: keep text before ";".
        media_type = header.split(";", 1)[0].strip()
        if media_type.count("/") == 1:
            return media_type
        return self.get_default_type()

368 

369 

@functools.lru_cache(maxsize=56)
def parse_content_type(raw: str) -> tuple[str, MappingProxyType[str, str]]:
    """Parse a Content-Type header value.

    Returns ``(media_type, parameters)`` where ``parameters`` is a read-only
    mapping.  Invalid media types fall back to ``application/octet-stream``.
    Results are cached, which is why a read-only mapping is returned.
    """
    parser = HeaderParser(EnsureOctetStream, policy=HTTP)
    message = parser.parsestr(f"Content-Type: {raw}")
    media_type = message.get_content_type()
    # get_params() lists the media type itself first, then (name, value) pairs.
    param_pairs = message.get_params(())[1:]
    return media_type, MappingProxyType(dict(param_pairs))

383 

384 

385def guess_filename(obj: Any, default: str | None = None) -> str | None: 

386 name = getattr(obj, "name", None) 

387 if name and isinstance(name, str) and name[0] != "<" and name[-1] != ">": 

388 return Path(name).name 

389 return default 

390 

391 

# Matches every character that is NOT RFC 5322 qtext (%d33 / %d35-91 /
# %d93-126); such characters must be escaped inside a quoted-string.
not_qtext_re = re.compile(r"[^\041\043-\133\135-\176]")
# Characters permitted in the content: printable US-ASCII, space and tab.
QCONTENT = {chr(i) for i in range(0x20, 0x7F)} | {"\t"}


def quoted_string(content: str) -> str:
    """Return 7-bit content as quoted-string.

    Format content into a quoted-string as defined in RFC5322 for
    Internet Message Format. Notice that this is not the 8-bit HTTP
    format, but the 7-bit email format. Content must be in usascii or
    a ValueError is raised.

    Only the *inside* of the quoted-string is returned: every non-qtext
    character (double quote, backslash, space, tab) is prefixed with a
    backslash; the surrounding double quotes are not added.

    :raises ValueError: if ``content`` contains a character outside
        printable US-ASCII, space or tab.
    """
    # Subset test.  The former proper-superset test (``QCONTENT > set(...)``)
    # wrongly rejected content that happened to use every allowed character.
    if not set(content) <= QCONTENT:
        raise ValueError(f"bad content for quoted-string {content!r}")
    return not_qtext_re.sub(lambda x: "\\" + x.group(0), content)

407 

408 

def content_disposition_header(
    disptype: str,
    quote_fields: bool = True,
    _charset: str = "utf-8",
    params: dict[str, str] | None = None,
) -> str:
    """Sets ``Content-Disposition`` header for MIME.

    This is the MIME payload Content-Disposition header from RFC 2183
    and RFC 7578 section 4.2, not the HTTP Content-Disposition from
    RFC 6266.

    disptype is a disposition type: inline, attachment, form-data.
    Should be valid extension token (see RFC 2183)

    quote_fields performs value quoting to 7-bit MIME headers
    according to RFC 7578. Set quote_fields to False if recipient
    can take 8-bit file names and field values.

    _charset specifies the charset to use when quote_fields is True.

    params is a dict with disposition params.

    Returns the header value, e.g. ``form-data; name="field"``.

    :raises ValueError: if ``disptype`` or a parameter name is empty or
        contains a non-token character.
    """
    # Subset tests.  The former proper-superset tests (``TOKEN > set(...)``)
    # rejected otherwise valid values that used every token character.
    if not disptype or not set(disptype) <= TOKEN:
        raise ValueError(f"bad content disposition type {disptype!r}")

    value = disptype
    if params:
        lparams = []
        for key, val in params.items():
            if not key or not set(key) <= TOKEN:
                raise ValueError(f"bad content disposition parameter {key!r}={val!r}")
            if quote_fields:
                if key.lower() == "filename":
                    # File names are percent-encoded inside a plain quoted-string.
                    qval = quote(val, "", encoding=_charset)
                    lparams.append((key, '"%s"' % qval))
                else:
                    try:
                        qval = quoted_string(val)
                    except ValueError:
                        # Not 7-bit safe: use the RFC 2231 extended form
                        # key*=<charset>''<percent-encoded value>.
                        qval = "".join(
                            (_charset, "''", quote(val, "", encoding=_charset))
                        )
                        lparams.append((key + "*", qval))
                    else:
                        lparams.append((key, '"%s"' % qval))
            else:
                # 8-bit values are kept; only backslash and double quote are escaped.
                qval = val.replace("\\", "\\\\").replace('"', '\\"')
                lparams.append((key, '"%s"' % qval))
        sparams = "; ".join("=".join(pair) for pair in lparams)
        value = "; ".join((value, sparams))
    return value

461 

462 

def is_expected_content_type(
    response_content_type: str, expected_content_type: str
) -> bool:
    """Tell whether ``response_content_type`` can be handled as the expected type.

    For ``application/json`` any JSON media type (including ``+json``
    suffixes) is accepted; otherwise a substring match is used.  Both
    arguments should be given without parameters.
    """
    if expected_content_type != "application/json":
        return expected_content_type in response_content_type
    return json_re.match(response_content_type) is not None

473 

474 

475def is_ip_address(host: str | None) -> bool: 

476 """Check if host looks like an IP Address. 

477 

478 This check is only meant as a heuristic to ensure that 

479 a host is not a domain name. 

480 """ 

481 if not host: 

482 return False 

483 # For a host to be an ipv4 address, it must be all numeric. 

484 # The host must contain a colon to be an IPv6 address. 

485 return ":" in host or host.replace(".", "").isdigit() 

486 

487 

488_cached_current_datetime: int | None = None 

489_cached_formatted_datetime = "" 

490 

491 

492def rfc822_formatted_time() -> str: 

493 global _cached_current_datetime 

494 global _cached_formatted_datetime 

495 

496 now = int(time.time()) 

497 if now != _cached_current_datetime: 

498 # Weekday and month names for HTTP date/time formatting; 

499 # always English! 

500 # Tuples are constants stored in codeobject! 

501 _weekdayname = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun") 

502 _monthname = ( 

503 "", # Dummy so we can use 1-based month numbers 

504 "Jan", 

505 "Feb", 

506 "Mar", 

507 "Apr", 

508 "May", 

509 "Jun", 

510 "Jul", 

511 "Aug", 

512 "Sep", 

513 "Oct", 

514 "Nov", 

515 "Dec", 

516 ) 

517 

518 year, month, day, hh, mm, ss, wd, *tail = time.gmtime(now) 

519 _cached_formatted_datetime = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( 

520 _weekdayname[wd], 

521 day, 

522 _monthname[month], 

523 year, 

524 hh, 

525 mm, 

526 ss, 

527 ) 

528 _cached_current_datetime = now 

529 return _cached_formatted_datetime 

530 

531 

def _weakref_handle(info: "tuple[weakref.ref[object], str]") -> None:
    """Invoke method ``name`` on the weakly referenced object if it still exists.

    ``info`` is ``(weak_reference, method_name)``.  Any ``Exception`` raised
    while looking up or calling the method is swallowed.
    """
    ref, name = info
    target = ref()
    if target is None:
        return
    with suppress(Exception):
        getattr(target, name)()

538 

539 

540def weakref_handle( 

541 ob: object, 

542 name: str, 

543 timeout: float | None, 

544 loop: asyncio.AbstractEventLoop, 

545 timeout_ceil_threshold: float = 5, 

546) -> asyncio.TimerHandle | None: 

547 if timeout is not None and timeout > 0: 

548 when = loop.time() + timeout 

549 if timeout >= timeout_ceil_threshold: 

550 when = ceil(when) 

551 

552 return loop.call_at(when, _weakref_handle, (weakref.ref(ob), name)) 

553 return None 

554 

555 

556def call_later( 

557 cb: Callable[[], Any], 

558 timeout: float | None, 

559 loop: asyncio.AbstractEventLoop, 

560 timeout_ceil_threshold: float = 5, 

561) -> asyncio.TimerHandle | None: 

562 if timeout is None or timeout <= 0: 

563 return None 

564 now = loop.time() 

565 when = calculate_timeout_when(now, timeout, timeout_ceil_threshold) 

566 return loop.call_at(when, cb) 

567 

568 

def calculate_timeout_when(
    loop_time: float,
    timeout: float,
    timeout_ceiling_threshold: float,
) -> float:
    """Return the absolute loop time at which a ``timeout`` should fire.

    The deadline is ``loop_time + timeout``.  When ``timeout`` is strictly
    greater than ``timeout_ceiling_threshold`` it is rounded up with
    :func:`math.ceil`.
    """
    deadline = loop_time + timeout
    return ceil(deadline) if timeout > timeout_ceiling_threshold else deadline

579 

580 

581class TimeoutHandle: 

582 """Timeout handle""" 

583 

584 __slots__ = ("_timeout", "_loop", "_ceil_threshold", "_callbacks") 

585 

586 def __init__( 

587 self, 

588 loop: asyncio.AbstractEventLoop, 

589 timeout: float | None, 

590 ceil_threshold: float = 5, 

591 ) -> None: 

592 self._timeout = timeout 

593 self._loop = loop 

594 self._ceil_threshold = ceil_threshold 

595 self._callbacks: list[ 

596 tuple[Callable[..., None], tuple[Any, ...], dict[str, Any]] 

597 ] = [] 

598 

599 def register( 

600 self, callback: Callable[..., None], *args: Any, **kwargs: Any 

601 ) -> None: 

602 self._callbacks.append((callback, args, kwargs)) 

603 

604 def close(self) -> None: 

605 self._callbacks.clear() 

606 

607 def start(self) -> asyncio.TimerHandle | None: 

608 timeout = self._timeout 

609 if timeout is not None and timeout > 0: 

610 when = self._loop.time() + timeout 

611 if timeout >= self._ceil_threshold: 

612 when = ceil(when) 

613 return self._loop.call_at(when, self.__call__) 

614 else: 

615 return None 

616 

617 def timer(self) -> "BaseTimerContext": 

618 if self._timeout is not None and self._timeout > 0: 

619 timer = TimerContext(self._loop) 

620 self.register(timer.timeout) 

621 return timer 

622 else: 

623 return TimerNoop() 

624 

625 def __call__(self) -> None: 

626 for cb, args, kwargs in self._callbacks: 

627 with suppress(Exception): 

628 cb(*args, **kwargs) 

629 

630 self._callbacks.clear() 

631 

632 

class BaseTimerContext(ContextManager["BaseTimerContext"]):
    """Common base class of the timer context managers."""

    __slots__ = ()

    def assert_timeout(self) -> None:
        """Raise TimeoutError if timeout has been exceeded.

        The base implementation never raises.
        """
        return None

639 

640 

class TimerNoop(BaseTimerContext):
    """Timer context that never times out; used when no timeout is configured."""

    __slots__ = ()

    def __enter__(self) -> BaseTimerContext:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # Never suppresses the exception, if any.
        return None

655 

656 

class TimerContext(BaseTimerContext):
    """Low resolution timeout context manager

    Entering registers the current task.  ``timeout()`` cancels every
    registered task, and ``__exit__`` converts the resulting CancelledError
    into ``asyncio.TimeoutError`` for the task that was inside the context.
    """

    __slots__ = ("_loop", "_tasks", "_cancelled", "_cancelling")

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        # Tasks currently inside the context; the most recently entered is last.
        self._tasks: list[asyncio.Task[Any]] = []
        # Becomes True once timeout() has fired.
        self._cancelled = False
        # task.cancelling() value recorded on __enter__ (Python 3.11+ only).
        self._cancelling = 0

    def assert_timeout(self) -> None:
        """Raise TimeoutError if timer has already been cancelled."""
        if self._cancelled:
            raise asyncio.TimeoutError from None

    def __enter__(self) -> BaseTimerContext:
        """Register the current task with this timer.

        :raises RuntimeError: if not called from within a task.
        :raises asyncio.TimeoutError: if the timer has already fired.
        """
        task = asyncio.current_task(loop=self._loop)
        if task is None:
            raise RuntimeError("Timeout context manager should be used inside a task")

        if sys.version_info >= (3, 11):
            # Remember if the task was already cancelling
            # so when we __exit__ we can decide if we should
            # raise asyncio.TimeoutError or let the cancellation propagate
            self._cancelling = task.cancelling()

        if self._cancelled:
            raise asyncio.TimeoutError from None

        self._tasks.append(task)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """Unregister the most recently entered task.

        If the timer fired and the block ended with CancelledError, raise
        ``asyncio.TimeoutError`` chained from it instead.  On Python 3.11+
        the CancelledError is left to propagate when the task still has more
        pending cancellations than it had on ``__enter__``.  Never suppresses
        an exception (always returns None when it does not raise).
        """
        enter_task: asyncio.Task[Any] | None = None
        if self._tasks:
            enter_task = self._tasks.pop()

        if exc_type is asyncio.CancelledError and self._cancelled:
            assert enter_task is not None
            # The timeout was hit, and the task was cancelled
            # so we need to uncancel the last task that entered the context manager
            # since the cancellation should not leak out of the context manager
            if sys.version_info >= (3, 11):
                # If the task was already cancelling don't raise
                # asyncio.TimeoutError and instead return None
                # to allow the cancellation to propagate
                if enter_task.uncancel() > self._cancelling:
                    return None
            raise asyncio.TimeoutError from exc_val
        return None

    def timeout(self) -> None:
        """Fire the timer: cancel the registered tasks once and mark as fired."""
        if not self._cancelled:
            # set() so a task registered more than once is cancelled only once.
            for task in set(self._tasks):
                task.cancel()

            self._cancelled = True

720 

721 

722def ceil_timeout( 

723 delay: float | None, ceil_threshold: float = 5 

724) -> async_timeout.Timeout: 

725 if delay is None or delay <= 0: 

726 return async_timeout.timeout(None) 

727 

728 loop = asyncio.get_running_loop() 

729 now = loop.time() 

730 when = now + delay 

731 if delay > ceil_threshold: 

732 when = ceil(when) 

733 return async_timeout.timeout_at(when) 

734 

735 

class HeadersDictProxy(Mapping[str, str]):
    """Read-only ``Mapping`` view of a header multidict.

    A header that occurs several times is exposed as one entry whose value is
    all of its values joined with ``", "``.
    """

    def __init__(self, md: CIMultiDict[str]):
        self._md = md

    def getall(self, key: str) -> tuple[str, ...]:
        """Split header ``key`` into its comma-separated list elements.

        A top-level quoted-string element is returned without its quotes and
        with quoted-pairs unescaped.  Other elements are whitespace-stripped,
        and quoted-pairs are unescaped only inside parameter quoted-strings
        and comments; empty elements are dropped.  An absent header gives ().
        """
        header = self.get(key, "")
        unescape = _QUOTED_PAIR_SUB.sub
        elements = []
        for match in _LIST_ELEMENT_RE.finditer(header):
            quoted = match.group(1)
            if quoted is not None:
                elements.append(unescape(r"\1", quoted))
                continue
            raw = match.group(2).strip()
            if raw:
                elements.append(
                    _PROTECTED_RE.sub(lambda p: unescape(r"\1", p.group()), raw)
                )
        return tuple(elements)

    def __eq__(self, other: object) -> bool:
        return self._md.__eq__(other)

    def __getitem__(self, key: str) -> str:
        return ", ".join(self._md.getall(key))

    def __iter__(self) -> Iterator[str]:
        # Yield each key once, keeping the multidict's first-seen order.
        seen: set[str] = set()
        for key in self._md.__iter__():
            if key not in seen:
                seen.add(key)
                yield key

    def __len__(self) -> int:
        return len(set(self._md.keys()))

    def __repr__(self) -> str:
        body = ", ".join(f"'{k}': {v!r}" for k, v in self.items())
        return f"<{self.__class__.__name__}({body})>"

778 

779 

class HeadersMixin:
    """Mixin for handling headers.

    Derives ``content_type``, ``charset`` and ``content_length`` from the
    ``_headers`` mapping supplied by the concrete class.  The parsed
    Content-Type is cached and re-parsed only when the raw header changes.
    """

    _headers: Mapping[str, str]
    _content_type: str | None = None
    _content_dict: dict[str, str] | None = None
    # Raw header the two fields above were parsed from; ``sentinel`` means
    # nothing has been parsed yet (None means the header was absent).
    _stored_content_type: str | None | _SENTINEL = sentinel

    def _parse_content_type(self, raw: str | None) -> None:
        """Parse ``raw`` and cache the media type and its parameters."""
        self._stored_content_type = raw
        if raw is not None:
            media_type, params_view = parse_content_type(raw)
            self._content_type = media_type
            # The shared cached result is read-only; take a mutable copy.
            self._content_dict = params_view.copy()
        else:
            # default value according to RFC 2616
            self._content_type = "application/octet-stream"
            self._content_dict = {}

    @property
    def content_type(self) -> str:
        """The value of content part for Content-Type HTTP header."""
        raw = self._headers.get(hdrs.CONTENT_TYPE)
        if self._stored_content_type != raw:
            self._parse_content_type(raw)
        assert self._content_type is not None
        return self._content_type

    @property
    def charset(self) -> str | None:
        """The ``charset`` parameter of the Content-Type header, or None."""
        raw = self._headers.get(hdrs.CONTENT_TYPE)
        if self._stored_content_type != raw:
            self._parse_content_type(raw)
        assert self._content_dict is not None
        return self._content_dict.get("charset")

    @property
    def content_length(self) -> int | None:
        """The Content-Length header as an int, or None when it is absent."""
        raw_length = self._headers.get(hdrs.CONTENT_LENGTH)
        if raw_length is None:
            return None
        return int(raw_length)

823 

824 

def set_result(fut: "asyncio.Future[_T]", result: _T) -> None:
    """Set ``result`` on ``fut``; do nothing if the future is already done."""
    if fut.done():
        return
    fut.set_result(result)

828 

829 

830_EXC_SENTINEL = BaseException() 

831 

832 

833class ErrorableProtocol(Protocol): 

834 def set_exception( 

835 self, 

836 exc: type[BaseException] | BaseException, 

837 exc_cause: BaseException = ..., 

838 ) -> None: ... 

839 

840 

841def set_exception( 

842 fut: Union["asyncio.Future[_T]", ErrorableProtocol], 

843 exc: type[BaseException] | BaseException, 

844 exc_cause: BaseException = _EXC_SENTINEL, 

845) -> None: 

846 """Set future exception. 

847 

848 If the future is marked as complete, this function is a no-op. 

849 

850 :param exc_cause: An exception that is a direct cause of ``exc``. 

851 Only set if provided. 

852 """ 

853 if asyncio.isfuture(fut) and fut.done(): 

854 return 

855 

856 exc_is_sentinel = exc_cause is _EXC_SENTINEL 

857 exc_causes_itself = exc is exc_cause 

858 if not exc_is_sentinel and not exc_causes_itself: 

859 exc.__cause__ = exc_cause 

860 

861 fut.set_exception(exc) 

862 

863 

@functools.total_ordering
class BaseKey(Generic[_T]):
    """Base for concrete context storage key classes.

    Each storage is provided with its own sub-class for the sake of some
    additional type safety.  A key's name is prefixed with the name of the
    module in which it was created; keys order by that full name.
    """

    __slots__ = ("_name", "_t", "__orig_class__")

    # This may be set by Python when instantiating with a generic type. We need to
    # support this, in order to support types that are not concrete classes,
    # like Iterable, which can't be passed as the second parameter to __init__.
    __orig_class__: type[object]

    # TODO(PY314): Change Type to TypeForm (this should resolve unreachable below).
    def __init__(self, name: str, t: type[_T] | None = None):
        # Walk outwards to the nearest module-level frame and use its module
        # name as a prefix, to help tell apart keys sharing a short name.
        frame = inspect.currentframe()
        while frame and frame.f_code.co_name != "<module>":
            frame = frame.f_back
        if not frame:
            raise RuntimeError("Failed to get module name.")
        module: str = frame.f_globals["__name__"]

        self._name = module + "." + name
        self._t = t

    def __lt__(self, other: object) -> bool:
        # Any BaseKey sorts before objects of other types.
        if not isinstance(other, BaseKey):
            return True
        return self._name < other._name

    def __repr__(self) -> str:
        t = self._t
        if t is None:
            # Fall back to the subscripted type argument, e.g. AppKey[int](...).
            with suppress(AttributeError):
                t = get_args(self.__orig_class__)[0]

        if t is None:
            t_repr = "<<Unknown>>"
        elif not isinstance(t, type):
            t_repr = repr(t)  # type: ignore[unreachable]
        elif t.__module__ == "builtins":
            t_repr = t.__qualname__
        else:
            t_repr = f"{t.__module__}.{t.__qualname__}"
        return f"<{self.__class__.__name__}({self._name}, type={t_repr})>"

916 

917 

class AppKey(BaseKey[_T]):
    """Keys for static typing support in Application.

    The type parameter is the type of the value stored under the key; e.g.
    ``ChainMapProxy.__getitem__(AppKey[_T])`` is typed to return ``_T``.
    """


class RequestKey(BaseKey[_T]):
    """Keys for static typing support in Request.

    The type parameter is the type of the value stored under the key.
    """


class ResponseKey(BaseKey[_T]):
    """Keys for static typing support in Response.

    The type parameter is the type of the value stored under the key.
    """

928 

929 

@final
class ChainMapProxy(Mapping[str | AppKey[Any], Any]):
    """Read-only mapping that searches several mappings in order.

    The first mapping containing a key supplies its value.  Subclassing is
    forbidden.
    """

    __slots__ = ("_maps",)

    def __init__(self, maps: Iterable[Mapping[str | AppKey[Any], Any]]) -> None:
        self._maps = tuple(maps)

    def __init_subclass__(cls) -> None:
        raise TypeError(
            f"Inheritance class {cls.__name__} from ChainMapProxy is forbidden"
        )

    @overload  # type: ignore[override]
    def __getitem__(self, key: AppKey[_T]) -> _T: ...

    @overload
    def __getitem__(self, key: str) -> Any: ...

    def __getitem__(self, key: str | AppKey[_T]) -> Any:
        for mapping in self._maps:
            try:
                return mapping[key]
            except KeyError:
                continue
        raise KeyError(key)

    @overload  # type: ignore[override]
    def get(self, key: AppKey[_T], default: _S) -> _T | _S: ...

    @overload
    def get(self, key: AppKey[_T], default: None = ...) -> _T | None: ...

    @overload
    def get(self, key: str, default: Any = ...) -> Any: ...

    def get(self, key: str | AppKey[_T], default: Any = None) -> Any:
        try:
            found = self[key]
        except KeyError:
            return default
        return found

    def __len__(self) -> int:
        # set.union reuses the stored hash values of the mappings if possible.
        distinct_keys = set().union(*self._maps)
        return len(distinct_keys)

    def __iter__(self) -> Iterator[str | AppKey[Any]]:
        merged: dict[str | AppKey[Any], Any] = {}
        # Merge back to front; update() reuses stored hash values if possible.
        for mapping in reversed(self._maps):
            merged.update(mapping)
        return iter(merged)

    def __contains__(self, key: object) -> bool:
        return any(key in mapping for mapping in self._maps)

    def __bool__(self) -> bool:
        return any(self._maps)

    def __repr__(self) -> str:
        inner = ", ".join(repr(mapping) for mapping in self._maps)
        return f"ChainMapProxy({inner})"

991 

992 

993class CookieMixin: 

994 """Mixin for handling cookies.""" 

995 

996 _cookies: SimpleCookie | None = None 

997 

998 @property 

999 def cookies(self) -> SimpleCookie: 

1000 if self._cookies is None: 

1001 self._cookies = SimpleCookie() 

1002 return self._cookies 

1003 

1004 def set_cookie( 

1005 self, 

1006 name: str, 

1007 value: str, 

1008 *, 

1009 expires: str | None = None, 

1010 domain: str | None = None, 

1011 max_age: int | str | None = None, 

1012 path: str = "/", 

1013 secure: bool | None = None, 

1014 httponly: bool | None = None, 

1015 samesite: str | None = None, 

1016 partitioned: bool | None = None, 

1017 ) -> None: 

1018 """Set or update response cookie. 

1019 

1020 Sets new cookie or updates existent with new value. 

1021 Also updates only those params which are not None. 

1022 """ 

1023 if self._cookies is None: 

1024 self._cookies = SimpleCookie() 

1025 

1026 self._cookies[name] = value 

1027 c = self._cookies[name] 

1028 

1029 if expires is not None: 

1030 c["expires"] = expires 

1031 elif c.get("expires") == "Thu, 01 Jan 1970 00:00:00 GMT": 

1032 del c["expires"] 

1033 

1034 if domain is not None: 

1035 c["domain"] = domain 

1036 

1037 if max_age is not None: 

1038 c["max-age"] = str(max_age) 

1039 elif "max-age" in c: 

1040 del c["max-age"] 

1041 

1042 c["path"] = path 

1043 

1044 if secure is not None: 

1045 c["secure"] = secure 

1046 if httponly is not None: 

1047 c["httponly"] = httponly 

1048 if samesite is not None: 

1049 c["samesite"] = samesite 

1050 

1051 if partitioned is not None: 

1052 c["partitioned"] = partitioned 

1053 

1054 if DEBUG: 

1055 cookie_length = len(c.output(header="")[1:]) 

1056 if cookie_length > COOKIE_MAX_LENGTH: 

1057 warnings.warn( 

1058 "The size of is too large, it might get ignored by the client.", 

1059 UserWarning, 

1060 stacklevel=2, 

1061 ) 

1062 

1063 def del_cookie( 

1064 self, 

1065 name: str, 

1066 *, 

1067 domain: str | None = None, 

1068 path: str = "/", 

1069 secure: bool | None = None, 

1070 httponly: bool | None = None, 

1071 samesite: str | None = None, 

1072 ) -> None: 

1073 """Delete cookie. 

1074 

1075 Creates new empty expired cookie. 

1076 """ 

1077 # TODO: do we need domain/path here? 

1078 if self._cookies is not None: 

1079 self._cookies.pop(name, None) 

1080 self.set_cookie( 

1081 name, 

1082 "", 

1083 max_age=0, 

1084 expires="Thu, 01 Jan 1970 00:00:00 GMT", 

1085 domain=domain, 

1086 path=path, 

1087 secure=secure, 

1088 httponly=httponly, 

1089 samesite=samesite, 

1090 ) 

1091 

1092 

def populate_with_cookies(headers: "CIMultiDict[str]", cookies: SimpleCookie) -> None:
    """Append one ``Set-Cookie`` entry to *headers* per morsel in *cookies*."""
    # Morsel.output(header="") yields " name=value; attrs" -- drop the space.
    rendered = (morsel.output(header="")[1:] for morsel in cookies.values())
    for header_value in rendered:
        headers.add(hdrs.SET_COOKIE, header_value)

1097 

1098 

# https://tools.ietf.org/html/rfc7232#section-2.3
# etagc = %x21 / %x23-7E / obs-text (obs-text = %x80-FF): every visible ASCII
# character except the double quote, plus the high bytes 0x80-0xFF.
# NOTE(review): RFC 7232 defines opaque-tag as DQUOTE *etagc DQUOTE, so an
# empty tag ("") is valid per the RFC, but the "+" below requires at least
# one character -- confirm that rejecting empty etags is intended.
_ETAGC = r"[!\x23-\x7E\x80-\xff]+"
_ETAGC_RE = re.compile(_ETAGC)
# A single entity-tag: optional weak prefix "W/" (group 1) followed by the
# opaque tag between double quotes (group 2, quotes not included).
_QUOTED_ETAG = rf'(W/)?"({_ETAGC})"'
QUOTED_ETAG_RE = re.compile(_QUOTED_ETAG)
# Tokenizer for a comma-separated list of entity-tags. Group 1 is one whole
# quoted etag (groups 2 and 3 are its "W/" prefix and opaque value) followed
# by a comma separator or end of input; group 4 matches any single character
# that does not start a valid etag, presumably so callers can detect
# malformed lists.
LIST_QUOTED_ETAG_RE = re.compile(rf"({_QUOTED_ETAG})(?:\s*,\s*|$)|(.)")

# Wildcard value accepted in place of an etag (e.g. "If-Match: *").
ETAG_ANY = "*"

1107 

1108 

@frozen_dataclass_decorator
class ETag:
    """Immutable HTTP entity tag.

    ``value`` holds the opaque tag text and ``is_weak`` records whether the
    tag is a weak validator (the ``W/`` prefix). Presumably ``value`` excludes
    the surrounding double quotes, since ``validate_etag_value`` rejects
    them; confirm against the parsing code.
    """

    # Opaque tag text.
    value: str
    # True for a weak validator; strong by default.
    is_weak: bool = dataclasses.field(default=False)

1113 

1114 

def validate_etag_value(value: str) -> None:
    """Ensure *value* is the wildcard ``ETAG_ANY`` or a well-formed bare etag.

    A bare etag must consist solely of ``etagc`` characters as matched by
    ``_ETAGC_RE`` -- in particular it must not contain a double quote.

    Raises:
        ValueError: if *value* is neither the wildcard nor a valid etag.
    """
    if value == ETAG_ANY or _ETAGC_RE.fullmatch(value):
        return
    raise ValueError(
        f"Value {value!r} is not a valid etag. Maybe it contains '\"'?"
    )

1120 

1121 

1122def parse_http_date(date_str: str | None) -> datetime.datetime | None: 

1123 """Process a date string, return a datetime object""" 

1124 if date_str is not None: 

1125 timetuple = parsedate(date_str) 

1126 if timetuple is not None: 

1127 with suppress(ValueError): 

1128 return datetime.datetime(*timetuple[:6], tzinfo=datetime.timezone.utc) 

1129 return None 

1130 

1131 

@functools.lru_cache
def must_be_empty_body(method: str, code: int) -> bool:
    """Check if a request must return an empty body.

    True when *code* is in ``EMPTY_BODY_STATUS_CODES``, when *method* is in
    ``EMPTY_BODY_METHODS``, or for any 2xx reply to a method in
    ``hdrs.METH_CONNECT_ALL``. Results are memoized per (method, code).
    """
    if code in EMPTY_BODY_STATUS_CODES:
        return True
    if method in EMPTY_BODY_METHODS:
        return True
    return 200 <= code < 300 and method in hdrs.METH_CONNECT_ALL

1140 

1141 

def should_remove_content_length(method: str, code: int) -> bool:
    """Check if a Content-Length header should be removed.

    True when *code* is in ``EMPTY_BODY_STATUS_CODES`` or for a 2xx reply to
    a method in ``hdrs.METH_CONNECT_ALL``.

    This should always be a subset of must_be_empty_body
    """
    # https://www.rfc-editor.org/rfc/rfc9110.html#section-8.6-8
    # https://www.rfc-editor.org/rfc/rfc9110.html#section-15.4.5-4
    if code in EMPTY_BODY_STATUS_CODES:
        return True
    is_success = 200 <= code < 300
    return is_success and method in hdrs.METH_CONNECT_ALL