Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/helpers.py: 39%

556 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-26 06:16 +0000

1"""Various helper functions""" 

2 

3import asyncio 

4import base64 

5import binascii 

6import contextlib 

7import dataclasses 

8import datetime 

9import enum 

10import functools 

11import inspect 

12import netrc 

13import os 

14import platform 

15import re 

16import sys 

17import time 

18import warnings 

19import weakref 

20from collections import namedtuple 

21from contextlib import suppress 

22from email.parser import HeaderParser 

23from email.utils import parsedate 

24from http.cookies import SimpleCookie 

25from math import ceil 

26from pathlib import Path 

27from types import TracebackType 

28from typing import ( 

29 Any, 

30 Callable, 

31 ContextManager, 

32 Dict, 

33 Generator, 

34 Generic, 

35 Iterable, 

36 Iterator, 

37 List, 

38 Mapping, 

39 Optional, 

40 Pattern, 

41 Protocol, 

42 Tuple, 

43 Type, 

44 TypeVar, 

45 Union, 

46 final, 

47 get_args, 

48 overload, 

49) 

50from urllib.parse import quote 

51from urllib.request import getproxies, proxy_bypass 

52 

53from multidict import CIMultiDict, MultiDict, MultiDictProxy 

54from yarl import URL 

55 

56from . import hdrs 

57from .log import client_logger 

58from .typedefs import PathLike # noqa 

59 

60if sys.version_info >= (3, 11): 

61 import asyncio as async_timeout 

62else: 

63 import async_timeout 

64 

# Names exported by a star-import of this module.
__all__ = ("BasicAuth", "ChainMapProxy", "ETag")

66 

# True on Python 3.10 and newer.
PY_310 = sys.version_info[:2] >= (3, 10)

# Upper bound (in characters) for a serialized Set-Cookie value; exceeding it
# only triggers a warning, and only when DEBUG is enabled.
COOKIE_MAX_LENGTH = 4096

_T = TypeVar("_T")
_S = TypeVar("_S")

# Private single-member enum: ``sentinel`` marks "no value yet" in places
# where None is itself a meaningful value.
_SENTINEL = enum.Enum("_SENTINEL", "sentinel")
sentinel = _SENTINEL["sentinel"]

# Any non-empty AIOHTTP_NO_EXTENSIONS value opts out of the C-accelerated
# implementations (see the reify selection further down).
NO_EXTENSIONS = os.environ.get("AIOHTTP_NO_EXTENSIONS", "") != ""

# Debug mode follows Python's development mode, or PYTHONASYNCIODEBUG when
# the interpreter is allowed to read environment variables (no -E flag).
DEBUG = sys.flags.dev_mode or (
    not sys.flags.ignore_environment
    and os.environ.get("PYTHONASYNCIODEBUG", "") != ""
)

82 

83 

# Character classes from the HTTP "token" grammar (RFC 2616 section 2.2).
# All 7-bit US-ASCII characters.
CHAR = {chr(i) for i in range(0, 128)}
# Control characters: 0x00-0x1F plus DEL (0x7F).
CTL = {chr(i) for i in range(0, 32)} | {
    chr(127),
}
# Delimiters that may not appear inside a token (includes SP and HT).
SEPARATORS = {
    "(",
    ")",
    "<",
    ">",
    "@",
    ",",
    ";",
    ":",
    "\\",
    '"',
    "/",
    "[",
    "]",
    "?",
    "=",
    "{",
    "}",
    " ",
    chr(9),
}
# token = 1*<any CHAR except CTLs or separators>
# This must be a set *difference*: HT (chr(9)) is in both CTL and SEPARATORS,
# so a symmetric difference (``^``) of the three sets would put it back in.
TOKEN = CHAR - CTL - SEPARATORS

110 

111 

class noop:
    """Awaitable that suspends exactly once and then finishes with ``None``."""

    def __await__(self) -> Generator[None, None, None]:
        # A single bare yield hands control back to the driver once.
        yield None

115 

116 

# Matches "application/json" and "+json" media types such as
# "application/vnd.api+json", case-insensitively.  The trailing "$" means the
# input must not carry parameters (e.g. "; charset=...").
json_re = re.compile(r"(?:application/|[\w.-]+/[\w.+-]+?\+)json$", flags=re.I)

118 

119 

class BasicAuth(namedtuple("BasicAuth", ["login", "password", "encoding"])):
    """Http basic authentication helper.

    An immutable ``(login, password, encoding)`` tuple.  ``encoding`` is the
    character encoding used to turn ``login:password`` into bytes before
    base64-encoding it (and to decode it again in :meth:`decode`).
    """

    def __new__(
        cls, login: str, password: str = "", encoding: str = "latin1"
    ) -> "BasicAuth":
        """Validate and build the credentials tuple.

        :raises ValueError: if *login* or *password* is None, or if *login*
            contains ":" (the first colon separates login from password).
        """
        if login is None:
            raise ValueError("None is not allowed as login value")

        if password is None:
            raise ValueError("None is not allowed as password value")

        if ":" in login:
            raise ValueError('A ":" is not allowed in login (RFC 1945#section-11.1)')

        return super().__new__(cls, login, password, encoding)

    @classmethod
    def decode(cls, auth_header: str, encoding: str = "latin1") -> "BasicAuth":
        """Create a BasicAuth object from an Authorization HTTP header.

        :param auth_header: header value such as ``"Basic <base64>"``.
        :param encoding: encoding of the decoded ``login:password`` bytes.
        :raises ValueError: if the header is malformed, uses a scheme other
            than Basic, is not valid base64, or lacks the ":" separator.
        """
        try:
            auth_type, encoded_credentials = auth_header.split(" ", 1)
        except ValueError:
            raise ValueError("Could not parse authorization header.")

        if auth_type.lower() != "basic":
            raise ValueError("Unknown authorization method %s" % auth_type)

        try:
            decoded = base64.b64decode(
                encoded_credentials.encode("ascii"), validate=True
            ).decode(encoding)
        except binascii.Error:
            raise ValueError("Invalid base64 encoding.")

        try:
            # RFC 2617 HTTP Authentication
            # https://www.ietf.org/rfc/rfc2617.txt
            # the colon must be present, but the username and password may be
            # otherwise blank.
            username, password = decoded.split(":", 1)
        except ValueError:
            raise ValueError("Invalid credentials.")

        return cls(username, password, encoding=encoding)

    @classmethod
    def from_url(
        cls, url: "URL", *, encoding: str = "latin1"
    ) -> Optional["BasicAuth"]:
        """Create BasicAuth from url.

        Returns None when the URL has no user part; a missing password
        becomes "".

        :raises TypeError: if *url* is not a ``yarl.URL``.
        """
        if not isinstance(url, URL):
            raise TypeError("url should be yarl.URL instance")
        if url.user is None:
            return None
        return cls(url.user, url.password or "", encoding=encoding)

    def encode(self) -> str:
        """Encode credentials as an ``Authorization`` header value."""
        creds = f"{self.login}:{self.password}".encode(self.encoding)
        # Base64 output is always plain ASCII, whatever encoding the
        # credentials used.  Decoding it with self.encoding would corrupt the
        # header for ASCII-incompatible encodings such as UTF-16.
        return "Basic %s" % base64.b64encode(creds).decode("ascii")

179 

180 

def strip_auth_from_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
    """Split *url* into ``(url_without_user, credentials)``.

    When the URL carries no user part it is returned unchanged together with
    None.  Otherwise the user part is removed via ``url.with_user(None)``.
    """
    auth = BasicAuth.from_url(url)
    if auth is not None:
        return url.with_user(None), auth
    return url, None

187 

188 

def netrc_from_env() -> Optional[netrc.netrc]:
    """Load netrc from file.

    The path comes from the NETRC environment variable when it is set;
    otherwise ``~/.netrc`` (``~/_netrc`` on Windows) is used.

    Returns None if the file cannot be found, read or parsed.  A warning is
    logged on parse errors, and on read errors only when NETRC was set or the
    default file appears to exist.
    """
    env_path = os.environ.get("NETRC")

    if env_path is None:
        try:
            home = Path.home()
        except RuntimeError as exc:  # pragma: no cover
            # pathlib raises RuntimeError when it cannot resolve the home dir.
            client_logger.debug(
                "Could not resolve home directory when "
                "trying to look for .netrc file: %s",
                exc,
            )
            return None
        filename = "_netrc" if platform.system() == "Windows" else ".netrc"
        path = home / filename
    else:
        path = Path(env_path)

    try:
        return netrc.netrc(str(path))
    except netrc.NetrcParseError as exc:
        client_logger.warning("Could not parse .netrc file: %s", exc)
    except OSError as exc:
        # Missing or unreadable file: stay quiet unless the user explicitly
        # pointed at it, or the default file seems to be there after all.
        exists = False
        with contextlib.suppress(OSError):
            exists = path.is_file()
        if env_path or exists:
            client_logger.warning("Could not read .netrc file: %s", exc)

    return None

232 

233 

@dataclasses.dataclass(frozen=True, eq=True)
class ProxyInfo:
    """Immutable pairing of a proxy URL with the credentials to send to it."""

    # Proxy URL (proxies_from_env stores it with user info stripped).
    proxy: URL
    # Credentials for the proxy, or None when none are known.
    proxy_auth: Optional[BasicAuth]

238 

239 

def basicauth_from_netrc(netrc_obj: Optional[netrc.netrc], host: str) -> BasicAuth:
    """
    Return :py:class:`~aiohttp.BasicAuth` credentials for ``host`` from ``netrc_obj``.

    :raises LookupError: if ``netrc_obj`` is :py:data:`None` or if no
            entry is found for the ``host``.
    """
    if netrc_obj is None:
        raise LookupError("No .netrc file found")

    entry = netrc_obj.authenticators(host)
    if entry is None:
        raise LookupError(f"No entry for {host!s} found in the `.netrc` file.")
    login, account, password = entry

    # TODO(PY311): username = login or account
    # Up to Python 3.10 an unspecified account is None and an unspecified
    # login is ""; from 3.11 both are "".  Fall back to the account only when
    # the login is empty and an account was actually given.
    if not login and account is not None:
        username = account
    else:
        username = login

    # TODO(PY311): drop this; 3.11+ already reports a missing password as "".
    password = "" if password is None else password

    return BasicAuth(username, password)

267 

268 

def proxies_from_env() -> Dict[str, ProxyInfo]:
    """Collect HTTP/WS proxy settings from the environment.

    Only the "http", "https", "ws" and "wss" entries of ``getproxies()`` are
    considered.  Proxies whose own URL scheme is https/wss are skipped with a
    warning.  When a proxy URL has no embedded credentials, the netrc file
    (if any) is consulted for its host.
    """
    supported = ("http", "https", "ws", "wss")
    env_urls = {
        scheme: URL(raw) for scheme, raw in getproxies().items() if scheme in supported
    }
    netrc_obj = netrc_from_env()
    split = {scheme: strip_auth_from_url(url) for scheme, url in env_urls.items()}

    result = {}
    for scheme, (proxy, auth) in split.items():
        if proxy.scheme in ("https", "wss"):
            client_logger.warning(
                "%s proxies %s are not supported, ignoring", proxy.scheme.upper(), proxy
            )
            continue
        if netrc_obj and auth is None and proxy.host is not None:
            try:
                auth = basicauth_from_netrc(netrc_obj, proxy.host)
            except LookupError:
                auth = None
        result[scheme] = ProxyInfo(proxy, auth)
    return result

293 

294 

def get_env_proxy_for_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
    """Get a permitted proxy for the given URL from the env.

    :raises LookupError: if the host is excluded by the proxy-bypass rules or
        no proxy is configured for the URL's scheme.
    """
    if url.host is not None and proxy_bypass(url.host):
        raise LookupError(f"Proxying is disallowed for `{url.host!r}`")

    info = proxies_from_env().get(url.scheme)
    if info is None:
        raise LookupError(f"No proxies found for `{url!s}` in the env")
    return info.proxy, info.proxy_auth

307 

308 

@dataclasses.dataclass(frozen=True, eq=True)
class MimeType:
    """Immutable components of a MIME type, as produced by parse_mimetype()."""

    # Major type, e.g. "text".
    type: str
    # Subtype without any "+suffix", e.g. "html".
    subtype: str
    # Text after "+" in the subtype (e.g. "json"), or "" when absent.
    suffix: str
    # The ";"-separated parameters.
    parameters: "MultiDictProxy[str]"

315 

316 

@functools.lru_cache(maxsize=56)
def parse_mimetype(mimetype: str) -> MimeType:
    """Parses a MIME type into its components.

    mimetype is a MIME type string.

    Returns a MimeType object.  An empty input yields a MimeType whose fields
    are all empty.  Type, subtype, suffix and parameter names are
    lower-cased; parameter values are stripped of spaces and double quotes.
    Results are memoised (LRU cache of 56 entries).

    Example:

    >>> mt = parse_mimetype('text/html; charset=utf-8')
    >>> (mt.type, mt.subtype, mt.suffix, mt.parameters['charset'])
    ('text', 'html', '', 'utf-8')

    """
    if not mimetype:
        return MimeType(
            type="", subtype="", suffix="", parameters=MultiDictProxy(MultiDict())
        )

    parts = mimetype.split(";")
    params: MultiDict[str] = MultiDict()
    for item in parts[1:]:
        # Skip empty and whitespace-only segments (e.g. a trailing "; "),
        # which would otherwise be stored as a parameter with an empty name.
        if not item.strip():
            continue
        key, _, value = item.partition("=")
        params.add(key.lower().strip(), value.strip(' "'))

    fulltype = parts[0].strip().lower()
    # A bare "*" is treated as the "*/*" wildcard.
    if fulltype == "*":
        fulltype = "*/*"

    mtype, _, stype = fulltype.partition("/")
    stype, _, suffix = stype.partition("+")

    return MimeType(
        type=mtype, subtype=stype, suffix=suffix, parameters=MultiDictProxy(params)
    )

355 

356 

def guess_filename(obj: Any, default: Optional[str] = None) -> Optional[str]:
    """Return the base name of ``obj.name`` when it looks like a real path.

    Falls back to *default* in any of these cases:

    * *obj* has no ``name`` attribute;
    * the name is empty or not a ``str``;
    * the name starts with "<" or ends with ">" (pseudo-names such as
      ``"<stdin>"``).
    """
    candidate = getattr(obj, "name", None)
    if not candidate or not isinstance(candidate, str):
        return default
    if candidate.startswith("<") or candidate.endswith(">"):
        return default
    return Path(candidate).name

362 

363 

# Any character that is NOT RFC 5322 qtext (%d33 / %d35-91 / %d93-126) and so
# needs a backslash escape inside a quoted-string.
not_qtext_re = re.compile(r"[^\041\043-\133\135-\176]")
# Characters permitted in a quoted-string at all: printable ASCII plus HTAB.
QCONTENT = {chr(i) for i in range(0x20, 0x7F)} | {"\t"}


def quoted_string(content: str) -> str:
    """Escape 7-bit content for use inside a quoted-string.

    Escapes content for an RFC 5322 (Internet Message Format) quoted-string.
    Every character that is not qtext gets a backslash prefix.  Only the
    escaped body is returned; the caller adds the surrounding double quotes.

    Notice that this is not the 8-bit HTTP format, but the 7-bit email
    format.

    :raises ValueError: if *content* contains anything outside QCONTENT
        (non-ASCII or control characters other than HTAB).
    """
    # issuperset, not ">": a strict-superset test would wrongly reject
    # content that happens to use every allowed character.
    if not QCONTENT.issuperset(content):
        raise ValueError(f"bad content for quoted-string {content!r}")
    return not_qtext_re.sub(lambda x: "\\" + x.group(0), content)

379 

380 

def content_disposition_header(
    disptype: str, quote_fields: bool = True, _charset: str = "utf-8", **params: str
) -> str:
    """Build a ``Content-Disposition`` header value for MIME.

    This is the MIME payload Content-Disposition header from RFC 2183
    and RFC 7578 section 4.2, not the HTTP Content-Disposition from
    RFC 6266.

    disptype is a disposition type: inline, attachment, form-data.
    Should be valid extension token (see RFC 2183)

    quote_fields performs value quoting to 7-bit MIME headers
    according to RFC 7578. Set quote_fields to False if recipient
    can take 8-bit file names and field values.

    _charset specifies the charset to use when quote_fields is True.

    params is a dict with disposition params.

    Returns the header value as a string.

    :raises ValueError: if disptype or a parameter name is empty or is not
        made only of token characters.
    """
    # issuperset (not ">") so a value using every token character is accepted.
    if not disptype or not TOKEN.issuperset(disptype):
        raise ValueError(f"bad content disposition type {disptype!r}")

    value = disptype
    if params:
        lparams = []
        for key, val in params.items():
            if not key or not TOKEN.issuperset(key):
                raise ValueError(
                    f"bad content disposition parameter {key!r}={val!r}"
                )
            if quote_fields:
                if key.lower() == "filename":
                    # The filename is percent-encoded and then double-quoted.
                    qval = quote(val, "", encoding=_charset)
                    lparams.append((key, '"%s"' % qval))
                else:
                    try:
                        qval = quoted_string(val)
                    except ValueError:
                        # Not representable as a 7-bit quoted-string: emit the
                        # extended form  key*=<charset>''<percent-encoded>.
                        qval = "".join(
                            (_charset, "''", quote(val, "", encoding=_charset))
                        )
                        lparams.append((key + "*", qval))
                    else:
                        lparams.append((key, '"%s"' % qval))
            else:
                # Only backslash and double quote are escaped in this mode.
                qval = val.replace("\\", "\\\\").replace('"', '\\"')
                lparams.append((key, '"%s"' % qval))
        sparams = "; ".join("=".join(pair) for pair in lparams)
        value = "; ".join((value, sparams))
    return value

432 

433 

def is_expected_content_type(
    response_content_type: str, expected_content_type: str
) -> bool:
    """Checks if received content type is processable as an expected one.

    Both arguments should be given without parameters.  When JSON is
    expected, any type matched by ``json_re`` (including "+json" types) is
    accepted.  For every other expected type a plain substring test is used.
    """
    if expected_content_type != "application/json":
        return expected_content_type in response_content_type
    return json_re.match(response_content_type) is not None

444 

445 

class _TSelf(Protocol, Generic[_T]):
    """Structural type of objects that ``reify`` can cache values on."""

    # Per-instance store that reify reads from and writes to, keyed by the
    # decorated method's name.
    _cache: Dict[str, _T]

448 

449 

class reify(Generic[_T]):
    """Read-only, cached property decorator for instance methods.

    It operates almost exactly like the Python ``@property`` decorator.  On
    first access the wrapped method is called and its result is stored in
    the instance's ``_cache`` dict under the method's name.  Later accesses
    read it from there.  Because ``__set__`` is defined this is, in Python
    parlance, a data descriptor, and assignment raises AttributeError.
    """

    def __init__(self, wrapped: Callable[..., _T]) -> None:
        self.wrapped = wrapped
        self.__doc__ = wrapped.__doc__
        self.name = wrapped.__name__

    def __get__(self, inst: _TSelf[_T], owner: Optional[Type[Any]] = None) -> _T:
        if inst is None:
            # Looked up on the class itself: expose the descriptor object.
            return self
        try:
            return inst._cache[self.name]
        except KeyError:
            pass
        value = self.wrapped(inst)
        inst._cache[self.name] = value
        return value

    def __set__(self, inst: _TSelf[_T], value: _T) -> None:
        raise AttributeError("reified property is read-only")


# Always the pure-Python implementation, even if the C one is chosen below.
reify_py = reify

try:
    from ._helpers import reify as reify_c
except ImportError:
    pass  # C extension unavailable: keep the pure-Python reify.
else:
    if not NO_EXTENSIONS:
        reify = reify_c  # type: ignore[misc,assignment]

491 

# Dotted-quad IPv4 address (each octet 0-255).
_ipv4_pattern = (
    r"^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}"
    r"(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$"
)
# IPv6 address in full, "::"-compressed, or IPv4-embedded form.
_ipv6_pattern = (
    r"^(?:(?:(?:[A-F0-9]{1,4}:){6}|(?=(?:[A-F0-9]{0,4}:){0,6}"
    r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}$)(([0-9A-F]{1,4}:){0,5}|:)"
    r"((:[0-9A-F]{1,4}){1,5}:|:)|::(?:[A-F0-9]{1,4}:){5})"
    r"(?:(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])\.){3}"
    r"(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])|(?:[A-F0-9]{1,4}:){7}"
    r"[A-F0-9]{1,4}|(?=(?:[A-F0-9]{0,4}:){0,7}[A-F0-9]{0,4}$)"
    r"(([0-9A-F]{1,4}:){1,7}|:)((:[0-9A-F]{1,4}){1,7}|:)|(?:[A-F0-9]{1,4}:){7}"
    r":|:(:[A-F0-9]{1,4}){7})$"
)
_ipv4_regex = re.compile(_ipv4_pattern)
_ipv6_regex = re.compile(_ipv6_pattern, flags=re.IGNORECASE)
_ipv4_regexb = re.compile(_ipv4_pattern.encode("ascii"))
_ipv6_regexb = re.compile(_ipv6_pattern.encode("ascii"), flags=re.IGNORECASE)


def _is_ip_address(
    regex: Pattern[str],
    regexb: Pattern[bytes],
    host: Optional[Union[str, bytes, bytearray, memoryview]],
) -> bool:
    """Return True if *host* is entirely matched by *regex* / *regexb*.

    ``str`` hosts use *regex*; bytes-like hosts use *regexb*; None is never
    an address.

    :raises TypeError: for any other host type.
    """
    if host is None:
        return False
    # fullmatch, not match: "$" also matches just before a trailing newline,
    # so match() would accept e.g. "127.0.0.1\n" as an IP address.
    if isinstance(host, str):
        return bool(regex.fullmatch(host))
    elif isinstance(host, (bytes, bytearray, memoryview)):
        return bool(regexb.fullmatch(host))
    else:
        raise TypeError(f"{host} [{type(host)}] is not a str or bytes")


is_ipv4_address = functools.partial(_is_ip_address, _ipv4_regex, _ipv4_regexb)
is_ipv6_address = functools.partial(_is_ip_address, _ipv6_regex, _ipv6_regexb)


def is_ip_address(host: Optional[Union[str, bytes, bytearray, memoryview]]) -> bool:
    """Return True if *host* is an IPv4 or IPv6 address literal."""
    return is_ipv4_address(host) or is_ipv6_address(host)

531 

532 

# Second (as int(time.time())) for which _cached_formatted_datetime is valid.
_cached_current_datetime: Optional[int] = None
_cached_formatted_datetime = ""


def rfc822_formatted_time() -> str:
    """Return the current UTC time as an HTTP date string.

    The result looks like ``Sun, 06 Nov 1994 08:49:37 GMT``.  It is rebuilt
    at most once per wall-clock second; calls within the same second return
    the cached string.
    """
    global _cached_current_datetime
    global _cached_formatted_datetime

    now = int(time.time())
    if now == _cached_current_datetime:
        return _cached_formatted_datetime

    # Names are hard-coded in English on purpose: the output must not depend
    # on the process locale.
    weekdays = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")
    months = (
        "",  # placeholder: tm_mon is 1-based
        "Jan",
        "Feb",
        "Mar",
        "Apr",
        "May",
        "Jun",
        "Jul",
        "Aug",
        "Sep",
        "Oct",
        "Nov",
        "Dec",
    )

    tm = time.gmtime(now)
    _cached_formatted_datetime = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
        weekdays[tm.tm_wday],
        tm.tm_mday,
        months[tm.tm_mon],
        tm.tm_year,
        tm.tm_hour,
        tm.tm_min,
        tm.tm_sec,
    )
    _cached_current_datetime = now
    return _cached_formatted_datetime

575 

576 

def _weakref_handle(info: "Tuple[weakref.ref[object], str]") -> None:
    """Timer callback: call method *name* on the referent if it is still alive.

    Any exception raised by that method is deliberately swallowed.
    """
    ref, name = info
    target = ref()
    if target is None:
        return
    with suppress(Exception):
        getattr(target, name)()


def weakref_handle(
    ob: object,
    name: str,
    timeout: Optional[float],
    loop: asyncio.AbstractEventLoop,
    timeout_ceil_threshold: float = 5,
) -> Optional[asyncio.TimerHandle]:
    """Schedule ``ob.<name>()`` after *timeout* seconds without keeping *ob* alive.

    Only a weak reference to *ob* is held by the timer.  Returns None (nothing
    scheduled) when *timeout* is None or not positive.  Deadlines for
    timeouts of at least *timeout_ceil_threshold* seconds are rounded up to
    a whole loop-clock second.
    """
    if timeout is None or not (timeout > 0):
        return None
    deadline = loop.time() + timeout
    # NOTE(review): ">=" here, while call_later/ceil_timeout use a strict ">";
    # confirm which boundary is intended.
    if timeout >= timeout_ceil_threshold:
        deadline = ceil(deadline)
    return loop.call_at(deadline, _weakref_handle, (weakref.ref(ob), name))

599 

600 

def call_later(
    cb: Callable[[], Any],
    timeout: Optional[float],
    loop: asyncio.AbstractEventLoop,
    timeout_ceil_threshold: float = 5,
) -> Optional[asyncio.TimerHandle]:
    """Schedule *cb* to run *timeout* seconds from now on *loop*.

    Returns the TimerHandle, or None (nothing scheduled) when *timeout* is
    None or not positive.  Deadlines for timeouts strictly greater than
    *timeout_ceil_threshold* are rounded up to a whole loop-clock second.
    """
    if timeout is None or not (timeout > 0):
        return None
    deadline = loop.time() + timeout
    # NOTE(review): strict ">" here, while weakref_handle and
    # TimeoutHandle.start use ">="; confirm which boundary is intended.
    if timeout > timeout_ceil_threshold:
        deadline = ceil(deadline)
    return loop.call_at(deadline, cb)

613 

614 

class TimeoutHandle:
    """Timeout handle.

    Collects callbacks via register() and, once started, runs them all when
    the timeout expires.  Each callback runs once; exceptions they raise are
    swallowed.
    """

    def __init__(
        self,
        loop: asyncio.AbstractEventLoop,
        timeout: Optional[float],
        ceil_threshold: float = 5,
    ) -> None:
        self._timeout = timeout
        self._loop = loop
        self._ceil_threshold = ceil_threshold
        # (callback, positional args, keyword args) triples, in registration order.
        self._callbacks: List[
            Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]]
        ] = []

    def register(
        self, callback: Callable[..., None], *args: Any, **kwargs: Any
    ) -> None:
        """Queue ``callback(*args, **kwargs)`` to run when the timeout fires."""
        entry = (callback, args, kwargs)
        self._callbacks.append(entry)

    def close(self) -> None:
        """Forget every registered callback."""
        self._callbacks.clear()

    def start(self) -> Optional[asyncio.Handle]:
        """Arm the timer; returns None when no positive timeout is set.

        Deadlines for timeouts of at least the ceil threshold are rounded up
        to a whole loop-clock second.
        """
        timeout = self._timeout
        if timeout is None or not (timeout > 0):
            return None
        deadline = self._loop.time() + timeout
        if timeout >= self._ceil_threshold:
            deadline = ceil(deadline)
        return self._loop.call_at(deadline, self.__call__)

    def timer(self) -> "BaseTimerContext":
        """Return a timer context that is cancelled when this handle fires."""
        if self._timeout is None or not (self._timeout > 0):
            return TimerNoop()
        ctx = TimerContext(self._loop)
        self.register(ctx.timeout)
        return ctx

    def __call__(self) -> None:
        for callback, args, kwargs in self._callbacks:
            # Best effort: one failing callback must not stop the others.
            with suppress(Exception):
                callback(*args, **kwargs)

        self._callbacks.clear()

663 

664 

class BaseTimerContext(ContextManager["BaseTimerContext"]):
    """Common interface of the timer context managers (TimerNoop, TimerContext)."""

    def assert_timeout(self) -> None:
        """Raise TimeoutError if timeout has been exceeded."""


class TimerNoop(BaseTimerContext):
    """Timer context that never times out."""

    def __enter__(self) -> BaseTimerContext:
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        return


class TimerContext(BaseTimerContext):
    """Low resolution timeout context manager.

    timeout() cancels every task currently inside the context.  When such a
    task leaves the context with the resulting CancelledError, it is turned
    into ``asyncio.TimeoutError``.
    """

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        self._tasks: List[asyncio.Task[Any]] = []
        self._cancelled = False
        # Task.cancelling() count seen on __enter__ (only tracked on 3.11+).
        self._cancelling = 0

    def assert_timeout(self) -> None:
        """Raise TimeoutError if timer has already been cancelled."""
        if self._cancelled:
            raise asyncio.TimeoutError from None

    def __enter__(self) -> BaseTimerContext:
        task = asyncio.current_task(loop=self._loop)

        if task is None:
            raise RuntimeError(
                "Timeout context manager should be used " "inside a task"
            )

        if sys.version_info >= (3, 11):
            # Remember pre-existing cancel requests so __exit__ can tell our
            # own cancel() apart from one issued by somebody else.
            self._cancelling = task.cancelling()

        if self._cancelled:
            raise asyncio.TimeoutError from None

        self._tasks.append(task)
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> Optional[bool]:
        enter_task: Optional[asyncio.Task[Any]] = None
        if self._tasks:
            enter_task = self._tasks.pop()

        if exc_type is asyncio.CancelledError and self._cancelled:
            if enter_task is not None and sys.version_info >= (3, 11):
                # Undo the cancel() issued by timeout().  Otherwise
                # Task.cancelling() stays elevated and the cancellation leaks
                # out of this context.  If other cancel requests remain, let
                # the CancelledError propagate instead of masking it.
                if enter_task.uncancel() > self._cancelling:
                    return None
            raise asyncio.TimeoutError from None
        return None

    def timeout(self) -> None:
        """Cancel all tasks inside the context (only the first call acts)."""
        if not self._cancelled:
            for task in set(self._tasks):
                task.cancel()

            self._cancelled = True

729 

730 

731def ceil_timeout( 

732 delay: Optional[float], ceil_threshold: float = 5 

733) -> async_timeout.Timeout: 

734 if delay is None or delay <= 0: 

735 return async_timeout.timeout(None) 

736 

737 loop = asyncio.get_running_loop() 

738 now = loop.time() 

739 when = now + delay 

740 if delay > ceil_threshold: 

741 when = ceil(when) 

742 return async_timeout.timeout_at(when) 

743 

744 

class HeadersMixin:
    """Cached accessors for the Content-Type and Content-Length headers.

    The class this is mixed into must provide a ``_headers`` mapping; this
    mixin only reads it.  The parsed Content-Type is cached together with
    the raw header value it came from.  It is re-parsed whenever that raw
    value changes.
    """

    __slots__ = ("_content_type", "_content_dict", "_stored_content_type")

    def __init__(self) -> None:
        super().__init__()
        self._content_type: Optional[str] = None
        self._content_dict: Optional[Dict[str, str]] = None
        # Raw Content-Type value the cache was built from.  It is None when
        # the header was absent; ``sentinel`` means "never parsed".
        self._stored_content_type: Union[str, None, _SENTINEL] = sentinel

    def _parse_content_type(self, raw: Optional[str]) -> None:
        """Parse *raw* (a Content-Type value, or None if absent) into the cache."""
        self._stored_content_type = raw
        if raw is None:
            # default value according to RFC 2616 (kept by RFC 9110 section 8.3)
            self._content_type = "application/octet-stream"
            self._content_dict = {}
        else:
            msg = HeaderParser().parsestr("Content-Type: " + raw)
            self._content_type = msg.get_content_type()
            params = msg.get_params(())
            self._content_dict = dict(params[1:])  # First element is content type again

    @property
    def content_type(self) -> str:
        """The value of content part for Content-Type HTTP header."""
        raw = self._headers.get(hdrs.CONTENT_TYPE)  # type: ignore[attr-defined]
        if self._stored_content_type != raw:
            self._parse_content_type(raw)
        return self._content_type  # type: ignore[return-value]

    @property
    def charset(self) -> Optional[str]:
        """The value of charset part for Content-Type HTTP header."""
        raw = self._headers.get(hdrs.CONTENT_TYPE)  # type: ignore[attr-defined]
        if self._stored_content_type != raw:
            self._parse_content_type(raw)
        return self._content_dict.get("charset")  # type: ignore[union-attr]

    @property
    def content_length(self) -> Optional[int]:
        """The value of Content-Length HTTP header.

        Returns None when the header is absent.  Raises ValueError (from
        ``int()``) when it is not an integer.
        """
        content_length = self._headers.get(  # type: ignore[attr-defined]
            hdrs.CONTENT_LENGTH
        )

        if content_length is not None:
            return int(content_length)
        else:
            return None

793 

794 

def set_result(fut: "asyncio.Future[_T]", result: _T) -> None:
    """Resolve *fut* with *result*, doing nothing if it is already done."""
    if fut.done():
        return
    fut.set_result(result)


def set_exception(fut: "asyncio.Future[_T]", exc: BaseException) -> None:
    """Fail *fut* with *exc*, doing nothing if it is already done."""
    if fut.done():
        return
    fut.set_exception(exc)

803 

804 

@functools.total_ordering
class AppKey(Generic[_T]):
    """Keys for static typing support in Application.

    The key's name is prefixed with the name of the module that created it.
    Keys order by that qualified name, and any AppKey sorts before
    non-AppKey objects.
    """

    __slots__ = ("_name", "_t", "__orig_class__")

    # Python may set this when the key is created through a subscripted alias
    # (``AppKey[SomeType]("name")``).  It lets __repr__ describe types that are
    # not concrete classes (e.g. Iterable[int]) and so can't be passed as the
    # second parameter to __init__.
    __orig_class__: Type[object]

    def __init__(self, name: str, t: Optional[Type[_T]] = None):
        # Prefix with the name of the nearest module-level frame on the stack
        # (the module creating the key) to help deduplicate key names.
        found = False
        module = ""
        frame = inspect.currentframe()
        while frame is not None:
            if frame.f_code.co_name == "<module>":
                module = frame.f_globals["__name__"]
                found = True
                break
            frame = frame.f_back
        if not found:
            raise RuntimeError("Failed to get module name.")

        self._name = module + "." + name
        self._t = t

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, AppKey):
            # An AppKey compares less than any non-AppKey object.
            return True
        return self._name < other._name

    def __repr__(self) -> str:
        t = self._t
        if t is None:
            # Fall back to the subscripted type argument, when there is one.
            with suppress(AttributeError):
                t = get_args(self.__orig_class__)[0]

        if t is None:
            t_repr = "<<Unknown>>"
        elif not isinstance(t, type):
            t_repr = repr(t)
        elif t.__module__ == "builtins":
            t_repr = t.__qualname__
        else:
            t_repr = f"{t.__module__}.{t.__qualname__}"
        return f"<AppKey({self._name}, type={t_repr})>"

853 

854 

@final
class ChainMapProxy(Mapping[Union[str, AppKey[Any]], Any]):
    """Read-only view over several mappings, searched in order.

    A lookup returns the value from the first mapping that has the key.
    Subclassing is forbidden.
    """

    __slots__ = ("_maps",)

    def __init__(self, maps: Iterable[Mapping[Union[str, AppKey[Any]], Any]]) -> None:
        self._maps = tuple(maps)

    def __init_subclass__(cls) -> None:
        raise TypeError(
            f"Inheritance class {cls.__name__} from ChainMapProxy is forbidden"
        )

    @overload  # type: ignore[override]
    def __getitem__(self, key: AppKey[_T]) -> _T:
        ...

    @overload
    def __getitem__(self, key: str) -> Any:
        ...

    def __getitem__(self, key: Union[str, AppKey[_T]]) -> Any:
        for mapping in self._maps:
            # Indexing (not "in") keeps each mapping's own __getitem__ semantics.
            try:
                found = mapping[key]
            except KeyError:
                continue
            return found
        raise KeyError(key)

    @overload  # type: ignore[override]
    def get(self, key: AppKey[_T], default: _S) -> Union[_T, _S]:
        ...

    @overload
    def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]:
        ...

    @overload
    def get(self, key: str, default: Any = ...) -> Any:
        ...

    def get(self, key: Union[str, AppKey[_T]], default: Any = None) -> Any:
        try:
            return self[key]
        except KeyError:
            return default

    def __len__(self) -> int:
        # Count distinct keys across all mappings; set.update can reuse the
        # hashes dicts already store.
        keys: set = set()
        for mapping in self._maps:
            keys.update(mapping)
        return len(keys)

    def __iter__(self) -> Iterator[Union[str, AppKey[Any]]]:
        merged: Dict[Union[str, AppKey[Any]], Any] = {}
        # Walk from the last mapping to the first so earlier mappings win.
        for mapping in self._maps[::-1]:
            merged.update(mapping)
        return iter(merged)

    def __contains__(self, key: object) -> bool:
        for mapping in self._maps:
            if key in mapping:
                return True
        return False

    def __bool__(self) -> bool:
        return any(self._maps)

    def __repr__(self) -> str:
        content = ", ".join(repr(mapping) for mapping in self._maps)
        return f"ChainMapProxy({content})"

922 

923 

class CookieMixin:
    """Mixin that manages outgoing cookies in a ``SimpleCookie``."""

    # The `_cookies` slots is not defined here because non-empty slots cannot
    # be combined with an Exception base class, as is done in HTTPException.
    # CookieMixin subclasses with slots should define the `_cookies`
    # slot themselves.
    __slots__ = ()

    def __init__(self) -> None:
        super().__init__()
        # Mypy doesn't like that _cookies isn't in __slots__.
        # See the comment on this class's __slots__ for why this is OK.
        self._cookies = SimpleCookie()  # type: ignore[misc]

    @property
    def cookies(self) -> SimpleCookie:
        """The (mutable) SimpleCookie holding the cookies set so far."""
        return self._cookies

    def set_cookie(
        self,
        name: str,
        value: str,
        *,
        expires: Optional[str] = None,
        domain: Optional[str] = None,
        max_age: Optional[Union[int, str]] = None,
        path: str = "/",
        secure: Optional[bool] = None,
        httponly: Optional[bool] = None,
        version: Optional[str] = None,
        samesite: Optional[str] = None,
    ) -> None:
        """Set or update response cookie.

        Sets a new cookie or updates an existing one with a new value.
        Attributes passed as None are generally left untouched, except:

        * *path* is always written;
        * a None *max_age* removes any existing max-age;
        * a None *expires* removes the epoch expiry left by del_cookie().

        When DEBUG is on, a UserWarning is emitted if the serialized cookie
        is longer than COOKIE_MAX_LENGTH.
        """
        old = self._cookies.get(name)
        if old is not None and old.coded_value == "":
            # Previously deleted (empty value): start from a fresh morsel.
            self._cookies.pop(name, None)

        self._cookies[name] = value
        c = self._cookies[name]

        if expires is not None:
            c["expires"] = expires
        elif c.get("expires") == "Thu, 01 Jan 1970 00:00:00 GMT":
            del c["expires"]

        if domain is not None:
            c["domain"] = domain

        if max_age is not None:
            c["max-age"] = str(max_age)
        elif "max-age" in c:
            del c["max-age"]

        c["path"] = path

        if secure is not None:
            c["secure"] = secure
        if httponly is not None:
            c["httponly"] = httponly
        if version is not None:
            c["version"] = version
        if samesite is not None:
            c["samesite"] = samesite

        if DEBUG:
            cookie_length = len(c.output(header="")[1:])
            if cookie_length > COOKIE_MAX_LENGTH:
                warnings.warn(
                    f"The size of cookie {name!r} is too large, "
                    "it might get ignored by the client.",
                    UserWarning,
                    stacklevel=2,
                )

    def del_cookie(
        self, name: str, *, domain: Optional[str] = None, path: str = "/"
    ) -> None:
        """Delete cookie.

        Creates new empty expired cookie.
        """
        # TODO: do we need domain/path here?
        self._cookies.pop(name, None)
        self.set_cookie(
            name,
            "",
            max_age=0,
            expires="Thu, 01 Jan 1970 00:00:00 GMT",
            domain=domain,
            path=path,
        )

1018 

1019 

def populate_with_cookies(headers: "CIMultiDict[str]", cookies: SimpleCookie) -> None:
    """Append one Set-Cookie header to *headers* for every morsel in *cookies*."""
    for morsel in cookies.values():
        # With an empty header name, output() starts with a space; drop it.
        rendered = morsel.output(header="")
        headers.add(hdrs.SET_COOKIE, rendered[1:])

1024 

1025 

# Entity-tag grammar, https://tools.ietf.org/html/rfc7232#section-2.3
# etagc = %x21 / %x23-7E / obs-text (%x80-FF): any of these except DQUOTE.
_ETAGC = r"[!\x23-\x7E\x80-\xff]+"
_ETAGC_RE = re.compile(_ETAGC)
# Optional "W/" weakness prefix (group 1) followed by the quoted tag (group 2).
_QUOTED_ETAG = r'(W/)?"(' + _ETAGC + r')"'
QUOTED_ETAG_RE = re.compile(_QUOTED_ETAG)
# Tokenizer for comma-separated entity-tag lists; the final "(.)" alternative
# captures any character that does not start a valid list item.
LIST_QUOTED_ETAG_RE = re.compile("(" + _QUOTED_ETAG + r")(?:\s*,\s*|$)|(.)")

# Wildcard that validate_etag_value accepts in place of a tag.
ETAG_ANY = "*"


@dataclasses.dataclass(frozen=True, eq=True)
class ETag:
    """An entity-tag together with its weakness flag."""

    # Tag text; validate_etag_value rejects double quotes in it.
    value: str
    is_weak: bool = False


def validate_etag_value(value: str) -> None:
    """Raise ValueError unless *value* is ``"*"`` or consists only of etagc."""
    if value == ETAG_ANY or _ETAGC_RE.fullmatch(value):
        return
    raise ValueError(
        f"Value {value!r} is not a valid etag. Maybe it contains '\"'?"
    )

1047 

1048 

def parse_http_date(date_str: Optional[str]) -> Optional[datetime.datetime]:
    """Process a date string, return a datetime object.

    Parses an HTTP / RFC 5322 style date such as
    ``Sun, 06 Nov 1994 08:49:37 GMT`` (the RFC 850 form is accepted too).

    Returns a timezone-aware datetime in UTC.  An explicit zone offset such
    as ``+0200`` is converted to UTC rather than ignored.  Returns None when
    *date_str* is None or does not describe a valid date.
    """
    # Function-local import: unlike parsedate, parsedate_tz also reports the
    # zone offset, which this function needs.
    from email.utils import parsedate_tz

    if date_str is None:
        return None
    parsed = parsedate_tz(date_str)
    if parsed is None:
        return None
    # Seconds east of UTC; None (unknown zone or "-0000") is treated as UTC.
    offset = parsed[9] or 0
    with suppress(ValueError, OverflowError):
        # ValueError: out-of-range fields (e.g. Feb 31); OverflowError: the
        # offset pushes the value outside datetime's supported range.
        naive = datetime.datetime(*parsed[:6])
        shifted = naive - datetime.timedelta(seconds=offset)
        return shifted.replace(tzinfo=datetime.timezone.utc)
    return None

1057 

1058 

def must_be_empty_body(method: str, code: int) -> bool:
    """Check if a request must return an empty body.

    True for body-less status codes, for HEAD requests, and for successful
    (2xx) responses to CONNECT.
    """
    if status_code_must_be_empty_body(code):
        return True
    if method_must_be_empty_body(method):
        return True
    return 200 <= code < 300 and method.upper() == hdrs.METH_CONNECT


def method_must_be_empty_body(method: str) -> bool:
    """Check if a method must return an empty body (only HEAD)."""
    # https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
    # https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.2
    return hdrs.METH_HEAD == method.upper()


def status_code_must_be_empty_body(code: int) -> bool:
    """Check if a status code must return an empty body (1xx, 204, 304)."""
    # https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
    return 100 <= code < 200 or code in (204, 304)


def should_remove_content_length(method: str, code: int) -> bool:
    """Check if a Content-Length header should be removed.

    This should always be a subset of must_be_empty_body: body-less status
    codes, plus 2xx responses to CONNECT.
    """
    # https://www.rfc-editor.org/rfc/rfc9110.html#section-8.6-8
    # https://www.rfc-editor.org/rfc/rfc9110.html#section-15.4.5-4
    return status_code_must_be_empty_body(code) or (
        200 <= code < 300 and method.upper() == hdrs.METH_CONNECT
    )