Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/aiohttp/helpers.py: 44%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

485 statements  

1"""Various helper functions""" 

2 

3import asyncio 

4import base64 

5import binascii 

6import contextlib 

7import datetime 

8import enum 

9import functools 

10import inspect 

11import netrc 

12import os 

13import platform 

14import re 

15import sys 

16import time 

17import weakref 

18from collections import namedtuple 

19from contextlib import suppress 

20from email.parser import HeaderParser 

21from email.utils import parsedate 

22from math import ceil 

23from pathlib import Path 

24from types import MappingProxyType, TracebackType 

25from typing import ( 

26 Any, 

27 Callable, 

28 ContextManager, 

29 Dict, 

30 Generator, 

31 Generic, 

32 Iterable, 

33 Iterator, 

34 List, 

35 Mapping, 

36 Optional, 

37 Protocol, 

38 Tuple, 

39 Type, 

40 TypeVar, 

41 Union, 

42 get_args, 

43 overload, 

44) 

45from urllib.parse import quote 

46from urllib.request import getproxies, proxy_bypass 

47 

48import attr 

49from multidict import MultiDict, MultiDictProxy, MultiMapping 

50from propcache.api import under_cached_property as reify 

51from yarl import URL 

52 

53from . import hdrs 

54from .log import client_logger 

55 

# On Python 3.11+ the name ``async_timeout`` is bound to ``asyncio`` itself;
# older interpreters import the separate ``async_timeout`` package instead.
if sys.version_info >= (3, 11):
    import asyncio as async_timeout
else:
    import async_timeout

__all__ = ("BasicAuth", "ChainMapProxy", "ETag", "reify")

# Platform flags derived from ``platform.system()``.
IS_MACOS = platform.system() == "Darwin"
IS_WINDOWS = platform.system() == "Windows"

# Interpreter version flags.
PY_310 = sys.version_info >= (3, 10)
PY_311 = sys.version_info >= (3, 11)


# Type variables used by generic annotations in this module.
_T = TypeVar("_T")
_S = TypeVar("_S")

# Single-member enum; its only member is exported as ``sentinel``.
_SENTINEL = enum.Enum("_SENTINEL", "sentinel")
sentinel = _SENTINEL.sentinel

# True when AIOHTTP_NO_EXTENSIONS is set to a non-empty value.
NO_EXTENSIONS = bool(os.environ.get("AIOHTTP_NO_EXTENSIONS"))

# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
EMPTY_BODY_STATUS_CODES = frozenset([204, 304, *range(100, 200)])
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.2
EMPTY_BODY_METHODS = hdrs.METH_HEAD_ALL

# Enabled by Python's dev mode, or by a non-empty PYTHONASYNCIODEBUG unless
# the interpreter was told to ignore environment variables.
DEBUG = sys.flags.dev_mode or (
    not sys.flags.ignore_environment and bool(os.environ.get("PYTHONASYNCIODEBUG"))
)

87 

88 

# Character classes used to validate header tokens.
# CHAR: every 7-bit character (code points 0..127).
CHAR = {chr(i) for i in range(0, 128)}
# CTL: control characters (code points 0..31) plus DEL (127).
CTL = {chr(i) for i in range(0, 32)} | {
    chr(127),
}
# SEPARATORS: punctuation that may not appear in a token, plus space and
# horizontal tab (chr(9)).
SEPARATORS = {
    "(",
    ")",
    "<",
    ">",
    "@",
    ",",
    ";",
    ":",
    "\\",
    '"',
    "/",
    "[",
    "]",
    "?",
    "=",
    "{",
    "}",
    " ",
    chr(9),
}
# TOKEN: any CHAR that is neither a CTL nor a separator.
# Set difference is required here.  Horizontal tab is a member of both CTL
# and SEPARATORS, so a symmetric difference (``^``) would remove it once and
# then add it back, leaving "\t" accepted as a token character.
TOKEN = CHAR - CTL - SEPARATORS

115 

116 

class noop:
    """Awaitable whose ``__await__`` generator yields exactly once."""

    def __await__(self) -> Generator[None, None, None]:
        # One bare yield, then the generator finishes with no return value.
        yield None

120 

121 

class BasicAuth(namedtuple("BasicAuth", ["login", "password", "encoding"])):
    """Http basic authentication helper.

    A named tuple of ``(login, password, encoding)``.  ``encoding`` is the
    text encoding applied to ``login:password`` before base64 encoding.
    """

    def __new__(
        cls, login: str, password: str = "", encoding: str = "latin1"
    ) -> "BasicAuth":
        """Validate the credentials and build the tuple.

        :raises ValueError: if ``login`` or ``password`` is ``None`` or if
                            ``login`` contains a colon.
        """
        if login is None:
            raise ValueError("None is not allowed as login value")

        if password is None:
            raise ValueError("None is not allowed as password value")

        if ":" in login:
            raise ValueError('A ":" is not allowed in login (RFC 1945#section-11.1)')

        return super().__new__(cls, login, password, encoding)

    @classmethod
    def decode(cls, auth_header: str, encoding: str = "latin1") -> "BasicAuth":
        """Create a BasicAuth object from an Authorization HTTP header.

        :raises ValueError: if the header has no ``<scheme> <credentials>``
                            shape, the scheme is not ``basic``, the payload
                            is not valid base64, or the decoded payload has
                            no colon.
        """
        try:
            auth_type, encoded_credentials = auth_header.split(" ", 1)
        except ValueError as exc:
            raise ValueError("Could not parse authorization header.") from exc

        if auth_type.lower() != "basic":
            raise ValueError("Unknown authorization method %s" % auth_type)

        try:
            decoded = base64.b64decode(
                encoded_credentials.encode("ascii"), validate=True
            ).decode(encoding)
        except binascii.Error as exc:
            raise ValueError("Invalid base64 encoding.") from exc

        try:
            # RFC 2617 HTTP Authentication
            # https://www.ietf.org/rfc/rfc2617.txt
            # the colon must be present, but the username and password may be
            # otherwise blank.
            username, password = decoded.split(":", 1)
        except ValueError as exc:
            raise ValueError("Invalid credentials.") from exc

        return cls(username, password, encoding=encoding)

    @classmethod
    def from_url(
        cls, url: "URL", *, encoding: str = "latin1"
    ) -> Optional["BasicAuth"]:
        """Create BasicAuth from url.

        Returns ``None`` when the URL carries neither a user nor a password.

        :raises TypeError: if ``url`` is not a ``yarl.URL`` instance.
        """
        if not isinstance(url, URL):
            raise TypeError("url should be yarl.URL instance")
        # Check raw_user and raw_password first as yarl is likely
        # to already have these values parsed from the netloc in the cache.
        if url.raw_user is None and url.raw_password is None:
            return None
        return cls(url.user or "", url.password or "", encoding=encoding)

    def encode(self) -> str:
        """Encode credentials as a ``Basic <base64>`` header value."""
        creds = (f"{self.login}:{self.password}").encode(self.encoding)
        # base64 output is always ASCII.  Decoding it with ``self.encoding``
        # would corrupt the header for encodings that are not ASCII
        # compatible (for example utf-16).
        return "Basic %s" % base64.b64encode(creds).decode("ascii")

183 

184 

def strip_auth_from_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
    """Split credentials off ``url``.

    Returns ``(url, None)`` unchanged when the URL has neither a user nor a
    password; otherwise returns the URL without its user part together with a
    BasicAuth built from the user and password.
    """
    # The raw_* attributes are checked first as yarl is likely to already
    # have these values parsed from the netloc in the cache.
    has_credentials = not (url.raw_user is None and url.raw_password is None)
    if not has_credentials:
        return url, None
    bare_url = url.with_user(None)
    credentials = BasicAuth(url.user or "", url.password or "")
    return bare_url, credentials

192 

193 

def netrc_from_env() -> Optional[netrc.netrc]:
    """Load the user's netrc file.

    The path comes from the ``NETRC`` environment variable when it is set;
    otherwise it is ``_netrc`` (Windows) or ``.netrc`` in the home directory.

    Returns None if the file cannot be located, read or parsed.
    """
    netrc_env = os.environ.get("NETRC")

    if netrc_env is None:
        try:
            home_dir = Path.home()
        except RuntimeError as e:  # pragma: no cover
            # if pathlib can't resolve home, it may raise a RuntimeError
            client_logger.debug(
                "Could not resolve home directory when trying to look for .netrc file: %s",
                e,
            )
            return None
        default_name = "_netrc" if IS_WINDOWS else ".netrc"
        netrc_path = home_dir / default_name
    else:
        netrc_path = Path(netrc_env)

    try:
        return netrc.netrc(str(netrc_path))
    except netrc.NetrcParseError as e:
        client_logger.warning("Could not parse .netrc file: %s", e)
    except OSError as e:
        # we couldn't read the file (doesn't exist, permissions, etc.)
        try:
            netrc_exists = netrc_path.is_file()
        except OSError:
            netrc_exists = False
        # only warn if the environment wanted us to load it,
        # or it appears like the default file does actually exist
        if netrc_env or netrc_exists:
            client_logger.warning("Could not read .netrc file: %s", e)

    return None

235 

236 

@attr.s(auto_attribs=True, frozen=True, slots=True)
class ProxyInfo:
    """Frozen pair of a proxy URL and its optional credentials."""

    # Address of the proxy.
    proxy: URL
    # Credentials for the proxy, or None.
    proxy_auth: Optional[BasicAuth]

241 

242 

def basicauth_from_netrc(netrc_obj: Optional[netrc.netrc], host: str) -> BasicAuth:
    """
    Look up ``host`` in ``netrc_obj`` and return matching
    :py:class:`~aiohttp.BasicAuth` credentials.

    :raises LookupError: if ``netrc_obj`` is :py:data:`None` or if no
            entry is found for the ``host``.
    """
    if netrc_obj is None:
        raise LookupError("No .netrc file found")

    entry = netrc_obj.authenticators(host)
    if entry is None:
        raise LookupError(f"No entry for {host!s} found in the `.netrc` file.")
    login, account, password = entry

    # TODO(PY311): username = login or account
    # Up to python 3.10, account could be None if not specified,
    # and login will be empty string if not specified. From 3.11,
    # login and account will be empty string if not specified.
    if login or account is None:
        username = login
    else:
        username = account

    # TODO(PY311): Remove the None fallback, as password will be empty
    # string if not specified
    return BasicAuth(username, "" if password is None else password)

270 

271 

def proxies_from_env() -> Dict[str, ProxyInfo]:
    """Build a scheme -> ProxyInfo mapping from the proxies ``getproxies()`` reports.

    Only the ``http``, ``https``, ``ws`` and ``wss`` entries are considered.
    Proxies whose own URL scheme is ``https`` or ``wss`` are skipped with a
    warning.  When a proxy URL has no embedded credentials, the netrc file
    is consulted for the proxy host.
    """
    wanted = ("http", "https", "ws", "wss")
    proxy_urls = {}
    for scheme, raw_url in getproxies().items():
        if scheme in wanted:
            proxy_urls[scheme] = URL(raw_url)

    netrc_obj = netrc_from_env()
    stripped = {
        scheme: strip_auth_from_url(proxy_url)
        for scheme, proxy_url in proxy_urls.items()
    }

    ret = {}
    for proto, (proxy, auth) in stripped.items():
        if proxy.scheme in ("https", "wss"):
            client_logger.warning(
                "%s proxies %s are not supported, ignoring", proxy.scheme.upper(), proxy
            )
            continue
        if netrc_obj and auth is None and proxy.host is not None:
            try:
                auth = basicauth_from_netrc(netrc_obj, proxy.host)
            except LookupError:
                # No netrc entry for this host: leave the proxy unauthenticated.
                auth = None
        ret[proto] = ProxyInfo(proxy, auth)
    return ret

296 

297 

def get_env_proxy_for_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
    """Return the ``(proxy, proxy_auth)`` pair the environment allows for ``url``.

    :raises LookupError: if ``proxy_bypass`` reports the host as bypassed or
            if no proxy is configured for the URL's scheme.
    """
    host = url.host
    if host is not None and proxy_bypass(host):
        raise LookupError(f"Proxying is disallowed for `{host!r}`")

    try:
        proxy_info = proxies_from_env()[url.scheme]
    except KeyError:
        raise LookupError(f"No proxies found for `{url!s}` in the env")
    return proxy_info.proxy, proxy_info.proxy_auth

310 

311 

@attr.s(auto_attribs=True, frozen=True, slots=True)
class MimeType:
    """Frozen record holding the parts of a parsed MIME type string."""

    # Part before the "/".
    type: str
    # Part after the "/", up to an optional "+".
    subtype: str
    # Part after the "+", or "" when there is none.
    suffix: str
    # Read-only multidict of the ";"-separated parameters.
    parameters: "MultiDictProxy[str]"

318 

319 

@functools.lru_cache(maxsize=56)
def parse_mimetype(mimetype: str) -> MimeType:
    """Split a MIME type string into type, subtype, suffix and parameters.

    mimetype is a MIME type string.

    Returns a MimeType object; an empty ``mimetype`` yields a MimeType
    whose fields are all empty.

    Example:

    >>> parse_mimetype('text/html; charset=utf-8')
    MimeType(type='text', subtype='html', suffix='',
             parameters={'charset': 'utf-8'})

    """
    if not mimetype:
        return MimeType(
            type="", subtype="", suffix="", parameters=MultiDictProxy(MultiDict())
        )

    # Everything before the first ";" is the type; the rest are parameters.
    fulltype, *raw_params = mimetype.split(";")

    params: MultiDict[str] = MultiDict()
    for raw_param in raw_params:
        if raw_param:
            name, _, value = raw_param.partition("=")
            params.add(name.lower().strip(), value.strip(' "'))

    fulltype = fulltype.strip().lower()
    if fulltype == "*":
        # A bare "*" is shorthand for "*/*".
        fulltype = "*/*"

    mtype, _, remainder = fulltype.partition("/")
    stype, _, suffix = remainder.partition("+")

    return MimeType(
        type=mtype, subtype=stype, suffix=suffix, parameters=MultiDictProxy(params)
    )

358 

359 

@functools.lru_cache(maxsize=56)
def parse_content_type(raw: str) -> Tuple[str, MappingProxyType[str, str]]:
    """Parse a raw Content-Type header value.

    Returns a ``(content_type, parameters)`` tuple where ``parameters`` is a
    read-only MappingProxyType of the header parameters.
    """
    parsed = HeaderParser().parsestr(f"Content-Type: {raw}")
    header_params = parsed.get_params(())
    # The first entry repeats the content type itself, so it is skipped.
    extra_params = dict(header_params[1:])
    return parsed.get_content_type(), MappingProxyType(extra_params)

372 

373 

def guess_filename(obj: Any, default: Optional[str] = None) -> Optional[str]:
    """Return the final path component of ``obj.name``, or ``default``.

    ``default`` is returned when ``obj`` has no ``name`` attribute, when the
    name is empty or not a string, or when it starts with "<" or ends
    with ">".
    """
    name = getattr(obj, "name", None)
    if not name or not isinstance(name, str):
        return default
    if name[0] == "<" or name[-1] == ">":
        return default
    return Path(name).name

379 

380 

# Matches every character that has to be backslash-escaped inside a
# quoted-string: anything outside \041, \043-\133 and \135-\176.
not_qtext_re = re.compile(r"[^\041\043-\133\135-\176]")
# Characters accepted as quoted-string content: 0x20..0x7E plus horizontal tab.
QCONTENT = {chr(i) for i in range(0x20, 0x7F)} | {"\t"}


def quoted_string(content: str) -> str:
    """Return 7-bit content as quoted-string.

    Format content into a quoted-string as defined in RFC5322 for
    Internet Message Format. Notice that this is not the 8-bit HTTP
    format, but the 7-bit email format.

    The surrounding double quotes are not added; only the characters that
    need escaping are prefixed with a backslash.

    :raises ValueError: if ``content`` contains a character outside
                        ``QCONTENT``.
    """
    # ``issuperset`` rather than ``>``: a proper-superset test would reject
    # content that happens to use every allowed character.
    if not QCONTENT.issuperset(content):
        raise ValueError(f"bad content for quoted-string {content!r}")
    return not_qtext_re.sub(lambda x: "\\" + x.group(0), content)

396 

397 

def content_disposition_header(
    disptype: str, quote_fields: bool = True, _charset: str = "utf-8", **params: str
) -> str:
    """Build the value of a MIME ``Content-Disposition`` header.

    This is the MIME payload Content-Disposition header from RFC 2183
    and RFC 7578 section 4.2, not the HTTP Content-Disposition from
    RFC 6266.

    disptype is a disposition type: inline, attachment, form-data.
    Should be valid extension token (see RFC 2183)

    quote_fields performs value quoting to 7-bit MIME headers
    according to RFC 7578. Set quote_fields to False if recipient
    can take 8-bit file names and field values.

    _charset specifies the charset to use when quote_fields is True.

    params is a dict with disposition params.

    Raises ValueError if disptype or a parameter name is empty or not
    made of token characters.
    """
    if not disptype or not (TOKEN > set(disptype)):
        raise ValueError(f"bad content disposition type {disptype!r}")

    if not params:
        return disptype

    rendered = []
    for key, val in params.items():
        if not key or not (TOKEN > set(key)):
            raise ValueError(f"bad content disposition parameter {key!r}={val!r}")

        if not quote_fields:
            # 8-bit capable recipient: only escape backslashes and quotes.
            escaped = val.replace("\\", "\\\\").replace('"', '\\"')
            rendered.append("=".join((key, '"%s"' % escaped)))
            continue

        if key.lower() == "filename":
            # File names are always percent-encoded.
            encoded = quote(val, "", encoding=_charset)
            rendered.append("=".join((key, '"%s"' % encoded)))
            continue

        try:
            quoted = quoted_string(val)
        except ValueError:
            # Not representable as a quoted-string: emit the extended
            # ``key*=charset''percent-encoded`` form instead.
            extended = "".join((_charset, "''", quote(val, "", encoding=_charset)))
            rendered.append("=".join((key + "*", extended)))
        else:
            rendered.append("=".join((key, '"%s"' % quoted)))

    return "; ".join((disptype, "; ".join(rendered)))

447 

448 

def is_ip_address(host: Optional[str]) -> bool:
    """Tell whether ``host`` looks like an IP address.

    This check is only meant as a heuristic to ensure that
    a host is not a domain name.
    """
    if not host:
        return False
    # The host must contain a colon to be an IPv6 address.
    if ":" in host:
        return True
    # For a host to be an ipv4 address, it must be all numeric.
    return host.replace(".", "").isdigit()

460 

461 

# One-entry cache: the whole second last formatted and its formatted string.
_cached_current_datetime: Optional[int] = None
_cached_formatted_datetime = ""


def rfc822_formatted_time() -> str:
    """Return the current GMT time as ``Day, DD Mon YYYY HH:MM:SS GMT``.

    The formatted string is cached in module globals and only rebuilt when
    the current whole second differs from the cached one.
    """
    global _cached_current_datetime
    global _cached_formatted_datetime

    now = int(time.time())
    if now == _cached_current_datetime:
        return _cached_formatted_datetime

    # Weekday and month names for HTTP date/time formatting;
    # always English!
    # Tuples are constants stored in codeobject!
    weekday_names = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")
    month_names = (
        "",  # Dummy so we can use 1-based month numbers
        "Jan",
        "Feb",
        "Mar",
        "Apr",
        "May",
        "Jun",
        "Jul",
        "Aug",
        "Sep",
        "Oct",
        "Nov",
        "Dec",
    )

    stamp = time.gmtime(now)
    _cached_formatted_datetime = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
        weekday_names[stamp.tm_wday],
        stamp.tm_mday,
        month_names[stamp.tm_mon],
        stamp.tm_year,
        stamp.tm_hour,
        stamp.tm_min,
        stamp.tm_sec,
    )
    _cached_current_datetime = now
    return _cached_formatted_datetime

504 

505 

506def _weakref_handle(info: "Tuple[weakref.ref[object], str]") -> None: 

507 ref, name = info 

508 ob = ref() 

509 if ob is not None: 

510 with suppress(Exception): 

511 getattr(ob, name)() 

512 

513 

def weakref_handle(
    ob: object,
    name: str,
    timeout: float,
    loop: asyncio.AbstractEventLoop,
    timeout_ceil_threshold: float = 5,
) -> Optional[asyncio.TimerHandle]:
    """Schedule ``getattr(ob, name)()`` after ``timeout``, holding ``ob`` weakly.

    Returns the timer handle from ``loop.call_at``, or None when ``timeout``
    is None or not greater than zero.  The deadline is rounded up to a whole
    number when ``timeout`` is at least ``timeout_ceil_threshold``.
    """
    if timeout is None or not timeout > 0:
        return None

    deadline = loop.time() + timeout
    if timeout >= timeout_ceil_threshold:
        deadline = ceil(deadline)
    # Only a weak reference is passed to the callback, so the pending timer
    # does not keep ``ob`` alive.
    return loop.call_at(deadline, _weakref_handle, (weakref.ref(ob), name))

528 

529 

def call_later(
    cb: Callable[[], Any],
    timeout: float,
    loop: asyncio.AbstractEventLoop,
    timeout_ceil_threshold: float = 5,
) -> Optional[asyncio.TimerHandle]:
    """Schedule ``cb`` on ``loop`` after ``timeout``.

    Returns the timer handle from ``loop.call_at``, or None when ``timeout``
    is None or not positive.  The deadline comes from
    ``calculate_timeout_when``.
    """
    if timeout is None or timeout <= 0:
        return None
    deadline = calculate_timeout_when(loop.time(), timeout, timeout_ceil_threshold)
    return loop.call_at(deadline, cb)

541 

542 

def calculate_timeout_when(
    loop_time: float,
    timeout: float,
    timeout_ceiling_threshold: float,
) -> float:
    """Return the loop time at which a timeout should fire.

    The result is ``loop_time + timeout``, rounded up to a whole number when
    ``timeout`` is strictly greater than ``timeout_ceiling_threshold``.
    """
    deadline = loop_time + timeout
    return ceil(deadline) if timeout > timeout_ceiling_threshold else deadline

553 

554 

class TimeoutHandle:
    """Timeout handle

    Collects callbacks and runs them all when the handle is called, which
    ``start()`` arranges to happen once the timeout has elapsed.
    """

    __slots__ = ("_timeout", "_loop", "_ceil_threshold", "_callbacks")

    def __init__(
        self,
        loop: asyncio.AbstractEventLoop,
        timeout: Optional[float],
        ceil_threshold: float = 5,
    ) -> None:
        self._timeout = timeout
        self._loop = loop
        self._ceil_threshold = ceil_threshold
        # Pending (callback, positional args, keyword args) triples.
        self._callbacks: List[
            Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]]
        ] = []

    def register(
        self, callback: Callable[..., None], *args: Any, **kwargs: Any
    ) -> None:
        """Remember ``callback`` and its arguments for when the timeout fires."""
        entry = (callback, args, kwargs)
        self._callbacks.append(entry)

    def close(self) -> None:
        """Forget every registered callback."""
        self._callbacks.clear()

    def start(self) -> Optional[asyncio.TimerHandle]:
        """Schedule this handle on the loop.

        Returns the timer handle, or None when the timeout is None or not
        greater than zero.  The deadline is rounded up to a whole number when
        the timeout is at least the ceil threshold.
        """
        timeout = self._timeout
        if timeout is None or not timeout > 0:
            return None
        deadline = self._loop.time() + timeout
        if timeout >= self._ceil_threshold:
            deadline = ceil(deadline)
        return self._loop.call_at(deadline, self.__call__)

    def timer(self) -> "BaseTimerContext":
        """Return a timer context tied to this handle.

        With a positive timeout this is a TimerContext whose ``timeout``
        method is registered as a callback; otherwise it is a TimerNoop.
        """
        if self._timeout is None or not self._timeout > 0:
            return TimerNoop()
        timer = TimerContext(self._loop)
        self.register(timer.timeout)
        return timer

    def __call__(self) -> None:
        """Run every registered callback, then clear the list.

        An ``Exception`` raised by one callback is swallowed so the
        remaining callbacks still run.
        """
        for callback, args, kwargs in self._callbacks:
            try:
                callback(*args, **kwargs)
            except Exception:
                pass

        self._callbacks.clear()

605 

606 

class BaseTimerContext(ContextManager["BaseTimerContext"]):
    """Common base of the timer context managers in this module."""

    __slots__ = ()

    def assert_timeout(self) -> None:
        """Raise TimeoutError if timeout has been exceeded."""
        # The base implementation does nothing.
        return None

613 

614 

class TimerNoop(BaseTimerContext):
    """Timer context that does nothing on entry or exit."""

    __slots__ = ()

    def __enter__(self) -> BaseTimerContext:
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        # Returning None lets any exception from the body propagate.
        return None

629 

630 

class TimerContext(BaseTimerContext):
    """Low resolution timeout context manager"""

    __slots__ = ("_loop", "_tasks", "_cancelled", "_cancelling")

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        # Tasks currently inside the context, most recent last.
        self._tasks: List[asyncio.Task[Any]] = []
        # Set once timeout() has been called.
        self._cancelled = False
        # task.cancelling() count recorded on entry (Python 3.11+ only).
        self._cancelling = 0

    def assert_timeout(self) -> None:
        """Raise TimeoutError if timer has already been cancelled."""
        if not self._cancelled:
            return
        raise asyncio.TimeoutError from None

    def __enter__(self) -> BaseTimerContext:
        current = asyncio.current_task(loop=self._loop)
        if current is None:
            raise RuntimeError("Timeout context manager should be used inside a task")

        if sys.version_info >= (3, 11):
            # Remember if the task was already cancelling
            # so when we __exit__ we can decide if we should
            # raise asyncio.TimeoutError or let the cancellation propagate
            self._cancelling = current.cancelling()

        if self._cancelled:
            raise asyncio.TimeoutError from None

        self._tasks.append(current)
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> Optional[bool]:
        enter_task: Optional[asyncio.Task[Any]] = (
            self._tasks.pop() if self._tasks else None
        )

        timed_out = exc_type is asyncio.CancelledError and self._cancelled
        if not timed_out:
            return None

        assert enter_task is not None
        # The timeout was hit, and the task was cancelled
        # so we need to uncancel the last task that entered the context manager
        # since the cancellation should not leak out of the context manager
        if sys.version_info >= (3, 11):
            # If the task was already cancelling don't raise
            # asyncio.TimeoutError and instead return None
            # to allow the cancellation to propagate
            if enter_task.uncancel() > self._cancelling:
                return None
        raise asyncio.TimeoutError from exc_val

    def timeout(self) -> None:
        """Cancel every task inside the context and mark the timer as cancelled.

        Does nothing if the timer has already been cancelled.
        """
        if self._cancelled:
            return
        for task in set(self._tasks):
            task.cancel()

        self._cancelled = True

694 

695 

696def ceil_timeout( 

697 delay: Optional[float], ceil_threshold: float = 5 

698) -> async_timeout.Timeout: 

699 if delay is None or delay <= 0: 

700 return async_timeout.timeout(None) 

701 

702 loop = asyncio.get_running_loop() 

703 now = loop.time() 

704 when = now + delay 

705 if delay > ceil_threshold: 

706 when = ceil(when) 

707 return async_timeout.timeout_at(when) 

708 

709 

class HeadersMixin:
    """Mixin for handling headers.

    Reads ``self._headers`` and caches the parsed Content-Type so it is only
    re-parsed when the raw header value changes.
    """

    ATTRS = frozenset(["_content_type", "_content_dict", "_stored_content_type"])

    _headers: MultiMapping[str]
    # Parsed content type and its parameters.
    _content_type: Optional[str] = None
    _content_dict: Optional[Dict[str, str]] = None
    # Raw header value the cached fields were parsed from; ``sentinel`` until
    # the first parse.
    _stored_content_type: Union[str, None, _SENTINEL] = sentinel

    def _parse_content_type(self, raw: Optional[str]) -> None:
        """Parse ``raw`` and store the resulting type and parameters."""
        self._stored_content_type = raw
        if raw is not None:
            content_type, content_mapping_proxy = parse_content_type(raw)
            self._content_type = content_type
            # _content_dict needs to be mutable so we can update it
            self._content_dict = content_mapping_proxy.copy()
            return
        # default value according to RFC 2616
        self._content_type = "application/octet-stream"
        self._content_dict = {}

    @property
    def content_type(self) -> str:
        """The value of content part for Content-Type HTTP header."""
        current = self._headers.get(hdrs.CONTENT_TYPE)
        if self._stored_content_type != current:
            self._parse_content_type(current)
        assert self._content_type is not None
        return self._content_type

    @property
    def charset(self) -> Optional[str]:
        """The value of charset part for Content-Type HTTP header."""
        current = self._headers.get(hdrs.CONTENT_TYPE)
        if self._stored_content_type != current:
            self._parse_content_type(current)
        assert self._content_dict is not None
        return self._content_dict.get("charset")

    @property
    def content_length(self) -> Optional[int]:
        """The value of Content-Length HTTP header."""
        raw_length = self._headers.get(hdrs.CONTENT_LENGTH)
        if raw_length is None:
            return None
        return int(raw_length)

755 

756 

def set_result(fut: "asyncio.Future[_T]", result: _T) -> None:
    """Set ``result`` on ``fut`` unless the future is already done."""
    if fut.done():
        return
    fut.set_result(result)

760 

761 

762_EXC_SENTINEL = BaseException() 

763 

764 

765class ErrorableProtocol(Protocol): 

766 def set_exception( 

767 self, 

768 exc: BaseException, 

769 exc_cause: BaseException = ..., 

770 ) -> None: ... # pragma: no cover 

771 

772 

def set_exception(
    fut: "asyncio.Future[_T] | ErrorableProtocol",
    exc: BaseException,
    exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
    """Set ``exc`` on ``fut``.

    If ``fut`` is an asyncio future that is already done, this function is a
    no-op.

    :param exc_cause: An exception that is a direct cause of ``exc``.
                      Only set if provided and different from ``exc``.
    """
    if asyncio.isfuture(fut) and fut.done():
        return

    cause_provided = exc_cause is not _EXC_SENTINEL
    if cause_provided and exc is not exc_cause:
        exc.__cause__ = exc_cause

    fut.set_exception(exc)

794 

795 

@functools.total_ordering
class AppKey(Generic[_T]):
    """Keys for static typing support in Application.

    The key name is prefixed with the name of the module whose top-level
    code is on the call stack when the key is created.
    """

    __slots__ = ("_name", "_t", "__orig_class__")

    # This may be set by Python when instantiating with a generic type. We need to
    # support this, in order to support types that are not concrete classes,
    # like Iterable, which can't be passed as the second parameter to __init__.
    __orig_class__: Type[object]

    def __init__(self, name: str, t: Optional[Type[_T]] = None):
        """Create a key called ``name`` with optional value type ``t``.

        :raises RuntimeError: if no module-level frame is found on the call
                              stack, so the module prefix cannot be determined.
        """
        # Prefix with module name to help deduplicate key names.
        frame = inspect.currentframe()
        while frame:
            if frame.f_code.co_name == "<module>":
                module: str = frame.f_globals["__name__"]
                break
            frame = frame.f_back
        else:
            # Without this branch ``module`` would be unbound below and the
            # caller would get an UnboundLocalError instead of a clear error.
            raise RuntimeError("Failed to get module name.")

        self._name = module + "." + name
        self._t = t

    def __lt__(self, other: object) -> bool:
        if isinstance(other, AppKey):
            return self._name < other._name
        return True  # Order AppKey above other types.

    def __repr__(self) -> str:
        t = self._t
        if t is None:
            # __orig_class__ is an unset slot unless the key was created via
            # a subscripted generic, hence the AttributeError guard.
            with suppress(AttributeError):
                # Set to type arg.
                t = get_args(self.__orig_class__)[0]

        if t is None:
            t_repr = "<<Unknown>>"
        elif isinstance(t, type):
            if t.__module__ == "builtins":
                t_repr = t.__qualname__
            else:
                t_repr = f"{t.__module__}.{t.__qualname__}"
        else:
            t_repr = repr(t)
        return f"<AppKey({self._name}, type={t_repr})>"

841 

842 

class ChainMapProxy(Mapping[Union[str, AppKey[Any]], Any]):
    """Read-only view over several mappings, searched in order.

    A lookup returns the value from the first mapping that contains the key.
    Subclassing is forbidden.
    """

    __slots__ = ("_maps",)

    def __init__(self, maps: Iterable[Mapping[Union[str, AppKey[Any]], Any]]) -> None:
        self._maps = tuple(maps)

    def __init_subclass__(cls) -> None:
        raise TypeError(
            f"Inheritance class {cls.__name__} from ChainMapProxy is forbidden"
        )

    @overload  # type: ignore[override]
    def __getitem__(self, key: AppKey[_T]) -> _T: ...

    @overload
    def __getitem__(self, key: str) -> Any: ...

    def __getitem__(self, key: Union[str, AppKey[_T]]) -> Any:
        for candidate in self._maps:
            try:
                return candidate[key]
            except KeyError:
                continue
        raise KeyError(key)

    @overload  # type: ignore[override]
    def get(self, key: AppKey[_T], default: _S) -> Union[_T, _S]: ...

    @overload
    def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]: ...

    @overload
    def get(self, key: str, default: Any = ...) -> Any: ...

    def get(self, key: Union[str, AppKey[_T]], default: Any = None) -> Any:
        try:
            value = self[key]
        except KeyError:
            return default
        return value

    def __len__(self) -> int:
        # reuses stored hash values if possible
        distinct_keys = set().union(*self._maps)
        return len(distinct_keys)

    def __iter__(self) -> Iterator[Union[str, AppKey[Any]]]:
        merged: Dict[Union[str, AppKey[Any]], Any] = {}
        # Later maps are applied first, so earlier maps overwrite them.
        for candidate in self._maps[::-1]:
            # reuses stored hash values if possible
            merged.update(candidate)
        return iter(merged)

    def __contains__(self, key: object) -> bool:
        for candidate in self._maps:
            if key in candidate:
                return True
        return False

    def __bool__(self) -> bool:
        # True as soon as any underlying mapping is non-empty.
        return any(self._maps)

    def __repr__(self) -> str:
        content = ", ".join(repr(candidate) for candidate in self._maps)
        return f"ChainMapProxy({content})"

904 

905 

906# https://tools.ietf.org/html/rfc7232#section-2.3 

907_ETAGC = r"[!\x23-\x7E\x80-\xff]+" 

908_ETAGC_RE = re.compile(_ETAGC) 

909_QUOTED_ETAG = rf'(W/)?"({_ETAGC})"' 

910QUOTED_ETAG_RE = re.compile(_QUOTED_ETAG) 

911LIST_QUOTED_ETAG_RE = re.compile(rf"({_QUOTED_ETAG})(?:\s*,\s*|$)|(.)") 

912 

913ETAG_ANY = "*" 

914 

915 

@attr.s(auto_attribs=True, frozen=True, slots=True)
class ETag:
    """Frozen record for an entity tag."""

    # The tag value.
    value: str
    # Whether the tag is a weak one; defaults to False.
    is_weak: bool = False

920 

921 

def validate_etag_value(value: str) -> None:
    """Check that ``value`` is usable as an etag.

    ``ETAG_ANY`` is always accepted; any other value must fully match
    ``_ETAGC_RE``.

    :raises ValueError: if ``value`` is not valid.
    """
    if value == ETAG_ANY:
        return
    if _ETAGC_RE.fullmatch(value):
        return
    raise ValueError(
        f"Value {value!r} is not a valid etag. Maybe it contains '\"'?"
    )

927 

928 

def parse_http_date(date_str: Optional[str]) -> Optional[datetime.datetime]:
    """Convert a date string into a UTC datetime.

    Returns None when ``date_str`` is None, cannot be parsed, or parses to
    field values that ``datetime`` rejects.
    """
    if date_str is None:
        return None
    timetuple = parsedate(date_str)
    if timetuple is None:
        return None
    try:
        # Only year..second are used; the result is always tagged as UTC.
        return datetime.datetime(*timetuple[:6], tzinfo=datetime.timezone.utc)
    except ValueError:
        return None

937 

938 

@functools.lru_cache
def must_be_empty_body(method: str, code: int) -> bool:
    """Tell whether a response to ``method`` with status ``code`` must have no body.

    True when the status code is in EMPTY_BODY_STATUS_CODES, the method is in
    EMPTY_BODY_METHODS, or a CONNECT request received a 2xx status.
    """
    if code in EMPTY_BODY_STATUS_CODES or method in EMPTY_BODY_METHODS:
        return True
    return 200 <= code < 300 and method in hdrs.METH_CONNECT_ALL

947 

948 

def should_remove_content_length(method: str, code: int) -> bool:
    """Tell whether a Content-Length header should be dropped.

    This should always be a subset of must_be_empty_body
    """
    # https://www.rfc-editor.org/rfc/rfc9110.html#section-8.6-8
    # https://www.rfc-editor.org/rfc/rfc9110.html#section-15.4.5-4
    if code in EMPTY_BODY_STATUS_CODES:
        return True
    return 200 <= code < 300 and method in hdrs.METH_CONNECT_ALL