Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/aiohttp/helpers.py: 46%

456 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:56 +0000

1"""Various helper functions""" 

2 

3import asyncio 

4import base64 

5import binascii 

6import datetime 

7import functools 

8import inspect 

9import netrc 

10import os 

11import platform 

12import re 

13import sys 

14import time 

15import warnings 

16import weakref 

17from collections import namedtuple 

18from contextlib import suppress 

19from email.parser import HeaderParser 

20from email.utils import parsedate 

21from math import ceil 

22from pathlib import Path 

23from types import TracebackType 

24from typing import ( 

25 Any, 

26 Callable, 

27 ContextManager, 

28 Dict, 

29 Generator, 

30 Generic, 

31 Iterable, 

32 Iterator, 

33 List, 

34 Mapping, 

35 Optional, 

36 Pattern, 

37 Set, 

38 Tuple, 

39 Type, 

40 TypeVar, 

41 Union, 

42 cast, 

43) 

44from urllib.parse import quote 

45from urllib.request import getproxies, proxy_bypass 

46 

47import async_timeout 

48import attr 

49from multidict import MultiDict, MultiDictProxy 

50from yarl import URL 

51 

52from . import hdrs 

53from .log import client_logger, internal_logger 

54from .typedefs import PathLike, Protocol # noqa 

55 

56__all__ = ("BasicAuth", "ChainMapProxy", "ETag") 

57 

# Host operating system, as reported by platform.system().
IS_MACOS = "Darwin" == platform.system()
IS_WINDOWS = "Windows" == platform.system()

# Interpreter feature gates: each flag is True when the running
# interpreter is at least the named version.
PY_36 = (3, 6) <= sys.version_info
PY_37 = (3, 7) <= sys.version_info
PY_38 = (3, 8) <= sys.version_info
PY_310 = (3, 10) <= sys.version_info
PY_311 = (3, 11) <= sys.version_info

66 

if sys.version_info >= (3, 7):
    # asyncio provides the helper directly from 3.7 onwards.
    all_tasks = asyncio.all_tasks

else:
    import idna_ssl

    idna_ssl.patch_match_hostname()

    def all_tasks(
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> Set["asyncio.Task[Any]"]:
        """Return the tasks of *loop* that have not finished yet."""
        snapshot = list(asyncio.Task.all_tasks(loop))
        return {task for task in snapshot if not task.done()}

80 

81 

_T = TypeVar("_T")
_S = TypeVar("_S")


# Unique placeholder object for "no value supplied".
sentinel: Any = object()
# True when the AIOHTTP_NO_EXTENSIONS environment variable is set to a
# non-empty value.
NO_EXTENSIONS: bool = bool(os.getenv("AIOHTTP_NO_EXTENSIONS"))

# sys.flags.dev_mode only exists on Python 3.7+, hence the getattr()
# fallback. PYTHONASYNCIODEBUG is honoured unless the interpreter was
# told to ignore environment variables.
DEBUG: bool = getattr(sys.flags, "dev_mode", False) or (
    not sys.flags.ignore_environment and bool(os.getenv("PYTHONASYNCIODEBUG"))
)

94 

95 

# Every 7-bit US-ASCII character.
CHAR = {chr(i) for i in range(0, 128)}
# Control characters: code points 0-31 plus DEL (127).
CTL = {chr(i) for i in range(0, 32)} | {
    chr(127),
}
# Separator characters that may not appear inside a token.
SEPARATORS = {
    "(",
    ")",
    "<",
    ">",
    "@",
    ",",
    ";",
    ":",
    "\\",
    '"',
    "/",
    "[",
    "]",
    "?",
    "=",
    "{",
    "}",
    " ",
    chr(9),
}
# Characters allowed in a token: CHAR minus controls and separators.
# This must be a set *difference*: TAB (chr(9)) belongs to CHAR, CTL and
# SEPARATORS, so the symmetric difference ``CHAR ^ CTL ^ SEPARATORS``
# removed it once and then toggled it back in.
TOKEN = CHAR - CTL - SEPARATORS

122 

123 

class noop:
    """Awaitable whose ``__await__`` yields once and then finishes."""

    def __await__(self) -> Generator[None, None, None]:
        # A single bare yield; the generator then ends without a value.
        yield None

127 

128 

class BasicAuth(namedtuple("BasicAuth", ["login", "password", "encoding"])):
    """Http basic authentication helper.

    A ``(login, password, encoding)`` named tuple that can be built from
    an ``Authorization`` header or a URL and rendered back to a header
    value.
    """

    def __new__(
        cls, login: str, password: str = "", encoding: str = "latin1"
    ) -> "BasicAuth":
        """Validate the credentials and build the tuple.

        Raises ValueError when login or password is None, or when the
        login contains a colon.
        """
        if login is None:
            raise ValueError("None is not allowed as login value")
        if password is None:
            raise ValueError("None is not allowed as password value")
        if ":" in login:
            raise ValueError('A ":" is not allowed in login (RFC 1945#section-11.1)')

        return super().__new__(cls, login, password, encoding)

    @classmethod
    def decode(cls, auth_header: str, encoding: str = "latin1") -> "BasicAuth":
        """Create a BasicAuth object from an Authorization HTTP header.

        Raises ValueError when the header has no space-separated scheme,
        uses a scheme other than ``basic``, is not valid base64, or the
        decoded text has no colon.
        """
        scheme, space, payload = auth_header.partition(" ")
        if not space:
            raise ValueError("Could not parse authorization header.")

        if scheme.lower() != "basic":
            raise ValueError("Unknown authorization method %s" % scheme)

        try:
            raw = base64.b64decode(payload.encode("ascii"), validate=True)
            text = raw.decode(encoding)
        except binascii.Error:
            raise ValueError("Invalid base64 encoding.")

        # RFC 2617 HTTP Authentication
        # https://www.ietf.org/rfc/rfc2617.txt
        # The colon is mandatory; either side of it may be empty.
        username, colon, password = text.partition(":")
        if not colon:
            raise ValueError("Invalid credentials.")

        return cls(username, password, encoding=encoding)

    @classmethod
    def from_url(cls, url: URL, *, encoding: str = "latin1") -> Optional["BasicAuth"]:
        """Create BasicAuth from url.

        Returns None when the URL has no user part. Raises TypeError when
        *url* is not a ``yarl.URL``.
        """
        if not isinstance(url, URL):
            raise TypeError("url should be yarl.URL instance")
        user = url.user
        if user is None:
            return None
        return cls(user, url.password or "", encoding=encoding)

    def encode(self) -> str:
        """Encode credentials as a ``Basic <base64>`` header value."""
        raw = f"{self.login}:{self.password}".encode(self.encoding)
        token = base64.b64encode(raw).decode(self.encoding)
        return "Basic %s" % token

188 

189 

def strip_auth_from_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
    """Separate credentials embedded in *url* from the URL itself.

    Returns ``(url, None)`` unchanged when the URL carries no user part;
    otherwise returns the URL with the user removed plus the extracted
    BasicAuth.
    """
    credentials = BasicAuth.from_url(url)
    if credentials is not None:
        return url.with_user(None), credentials
    return url, None

196 

197 

def netrc_from_env() -> Optional[netrc.netrc]:
    """Load netrc from file.

    The path comes from the NETRC environment variable when it is set;
    otherwise the default file in the user's home directory is used
    (``_netrc`` on Windows, ``.netrc`` elsewhere).

    Returns None if the file couldn't be found or fails to parse.
    """
    explicit_path = os.environ.get("NETRC")

    if explicit_path is None:
        try:
            home = Path.home()
        except RuntimeError as exc:  # pragma: no cover
            # pathlib may raise RuntimeError when it can't resolve home
            client_logger.debug(
                "Could not resolve home directory when "
                "trying to look for .netrc file: %s",
                exc,
            )
            return None

        candidate = home / ("_netrc" if IS_WINDOWS else ".netrc")
    else:
        candidate = Path(explicit_path)

    try:
        return netrc.netrc(str(candidate))
    except netrc.NetrcParseError as exc:
        client_logger.warning("Could not parse .netrc file: %s", exc)
    except OSError as exc:
        # The file could not be read (missing, permissions, etc.).
        # Warn only if the environment asked for this file explicitly,
        # or the default file does appear to exist.
        if explicit_path or candidate.is_file():
            client_logger.warning("Could not read .netrc file: %s", exc)

    return None

236 

237 

@attr.s(auto_attribs=True, frozen=True, slots=True)
class ProxyInfo:
    """Immutable pairing of a proxy URL with its optional credentials."""

    # URL of the proxy.
    proxy: URL
    # Credentials for the proxy, or None when there are none.
    proxy_auth: Optional[BasicAuth]

242 

243 

def proxies_from_env() -> Dict[str, ProxyInfo]:
    """Collect proxy settings reported by ``urllib.request.getproxies()``.

    Only the ``http``, ``https``, ``ws`` and ``wss`` entries are
    considered. Proxies whose own URL scheme is ``https`` or ``wss`` are
    logged and skipped. When a proxy URL carries no credentials, the
    netrc file (if one could be loaded) is consulted for the proxy host.

    Returns a mapping from protocol name to ProxyInfo.
    """
    env_proxies = {
        scheme: URL(raw)
        for scheme, raw in getproxies().items()
        if scheme in ("http", "https", "ws", "wss")
    }
    netrc_obj = netrc_from_env()
    split = {scheme: strip_auth_from_url(u) for scheme, u in env_proxies.items()}

    result = {}
    for scheme, (proxy, auth) in split.items():
        if proxy.scheme in ("https", "wss"):
            client_logger.warning(
                "%s proxies %s are not supported, ignoring", proxy.scheme.upper(), proxy
            )
            continue

        if netrc_obj and auth is None and proxy.host is not None:
            entry = netrc_obj.authenticators(proxy.host)
            if entry is not None:
                # entry is a (`user`, `account`, `password`) tuple;
                # either of the first two may hold the username, so
                # fall back to `account` when `user` is empty.
                *logins, password = entry
                login = logins[0] or logins[-1]
                auth = BasicAuth(cast(str, login), cast(str, password))

        result[scheme] = ProxyInfo(proxy, auth)
    return result

273 

274 

def current_task(
    loop: Optional[asyncio.AbstractEventLoop] = None,
) -> "Optional[asyncio.Task[Any]]":
    """Return the currently running task for *loop*.

    Uses ``asyncio.current_task`` on Python 3.7+ and the older
    ``asyncio.Task.current_task`` classmethod otherwise.
    """
    if sys.version_info < (3, 7):
        return asyncio.Task.current_task(loop=loop)
    return asyncio.current_task(loop=loop)

282 

283 

def get_running_loop(
    loop: Optional[asyncio.AbstractEventLoop] = None,
) -> asyncio.AbstractEventLoop:
    """Return *loop*, defaulting to ``asyncio.get_event_loop()``.

    If the loop is not running, a DeprecationWarning is issued, and a
    warning with stack info is also logged when the loop is in debug
    mode. The loop is returned either way.
    """
    if loop is None:
        loop = asyncio.get_event_loop()

    if loop.is_running():
        return loop

    warnings.warn(
        "The object should be created within an async function",
        DeprecationWarning,
        stacklevel=3,
    )
    if loop.get_debug():
        internal_logger.warning(
            "The object should be created within an async function", stack_info=True
        )
    return loop

300 

301 

def isasyncgenfunction(obj: Any) -> bool:
    """Return whether *obj* is an async generator function.

    Delegates to ``inspect.isasyncgenfunction`` when the inspect module
    provides it, and returns False otherwise.
    """
    checker = getattr(inspect, "isasyncgenfunction", None)
    if checker is None:
        return False
    return checker(obj)  # type: ignore[no-any-return]

308 

309 

def get_env_proxy_for_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
    """Get a permitted proxy for the given URL from the env.

    Returns the ``(proxy_url, proxy_auth)`` pair registered for the URL's
    scheme. Raises LookupError when the host is excluded from proxying
    or when no proxy is configured for the scheme.
    """
    if url.host is not None and proxy_bypass(url.host):
        raise LookupError(f"Proxying is disallowed for `{url.host!r}`")

    info = proxies_from_env().get(url.scheme)
    if info is None:
        raise LookupError(f"No proxies found for `{url!s}` in the env")
    return info.proxy, info.proxy_auth

322 

323 

@attr.s(auto_attribs=True, frozen=True, slots=True)
class MimeType:
    """Immutable record of the components of a parsed MIME type."""

    # Part before the "/", e.g. "text" in "text/html".
    type: str
    # Part after the "/" and before any "+".
    subtype: str
    # Part after the "+" in the subtype, or "" when absent.
    suffix: str
    # Parameters that followed the type, e.g. charset.
    parameters: "MultiDictProxy[str]"

330 

331 

@functools.lru_cache(maxsize=56)
def parse_mimetype(mimetype: str) -> MimeType:
    """Split a MIME type string into its components.

    mimetype is a MIME type string such as ``text/html; charset=utf-8``.

    Returns a MimeType object. An empty string yields a MimeType whose
    fields are all empty, and a bare ``*`` is treated as ``*/*``.

    Example:

    >>> parse_mimetype('text/html; charset=utf-8')
    MimeType(type='text', subtype='html', suffix='',
             parameters={'charset': 'utf-8'})

    """
    if not mimetype:
        return MimeType(
            type="", subtype="", suffix="", parameters=MultiDictProxy(MultiDict())
        )

    head, *raw_params = mimetype.split(";")

    params: MultiDict[str] = MultiDict()
    for chunk in raw_params:
        if not chunk:
            continue
        # A parameter without "=" is kept with an empty value.
        name, _, value = chunk.partition("=")
        params.add(name.lower().strip(), value.strip(' "'))

    fulltype = head.strip().lower()
    if fulltype == "*":
        fulltype = "*/*"

    # partition() leaves the trailing pieces empty when the separator is
    # missing, which is the fallback wanted for both splits.
    mtype, _, rest = fulltype.partition("/")
    stype, _, suffix = rest.partition("+")

    return MimeType(
        type=mtype, subtype=stype, suffix=suffix, parameters=MultiDictProxy(params)
    )

378 

379 

def guess_filename(obj: Any, default: Optional[str] = None) -> Optional[str]:
    """Derive a file name from the ``name`` attribute of *obj*.

    Returns the final path component of ``obj.name`` when it is a
    non-empty string that neither starts with ``<`` nor ends with ``>``;
    otherwise returns *default*.
    """
    candidate = getattr(obj, "name", None)
    if not candidate or not isinstance(candidate, str):
        return default
    if candidate[0] == "<" or candidate[-1] == ">":
        return default
    return Path(candidate).name

385 

386 

# Matches any single character that is not allowed bare inside a
# quoted-string and therefore needs a backslash in front of it.
not_qtext_re = re.compile(r"[^\041\043-\133\135-\176]")
# Characters accepted as input: 0x20-0x7E plus TAB.
QCONTENT = {chr(i) for i in range(0x20, 0x7F)} | {"\t"}


def quoted_string(content: str) -> str:
    """Return 7-bit content as quoted-string.

    Escapes content for use inside a quoted-string as defined in RFC5322
    for Internet Message Format (the 7-bit email format, not the 8-bit
    HTTP one) by prefixing every non-qtext character with a backslash.

    Raises ValueError when *content* is not a proper subset of QCONTENT.
    """
    if not (set(content) < QCONTENT):
        raise ValueError(f"bad content for quoted-string {content!r}")
    # r"\\\g<0>" is a literal backslash followed by the matched character.
    return not_qtext_re.sub(r"\\\g<0>", content)

402 

403 

def content_disposition_header(
    disptype: str, quote_fields: bool = True, _charset: str = "utf-8", **params: str
) -> str:
    """Build the value of a ``Content-Disposition`` header for MIME.

    This is the MIME payload Content-Disposition header from RFC 2183
    and RFC 7578 section 4.2, not the HTTP Content-Disposition from
    RFC 6266.

    disptype is a disposition type: inline, attachment, form-data.
    It must be a non-empty string made only of TOKEN characters.

    quote_fields controls value quoting. When True, a ``filename`` value
    is percent-quoted, other values are emitted as quoted-strings, and
    values that cannot be quoted that way are emitted under the ``key*``
    name in ``charset''percent-quoted`` form. When False, only backslash
    and double quote are escaped.

    _charset specifies the charset to use when quote_fields is True.

    params is a dict with disposition params.

    Raises ValueError for an invalid disposition type or parameter name.
    """
    if not disptype or not (TOKEN > set(disptype)):
        raise ValueError("bad content disposition type {!r}" "".format(disptype))

    if not params:
        return disptype

    pieces = [disptype]
    for key, val in params.items():
        if not key or not (TOKEN > set(key)):
            raise ValueError(
                "bad content disposition parameter" " {!r}={!r}".format(key, val)
            )

        if not quote_fields:
            escaped = val.replace("\\", "\\\\").replace('"', '\\"')
            pieces.append('%s="%s"' % (key, escaped))
            continue

        if key.lower() == "filename":
            escaped = quote(val, "", encoding=_charset)
            pieces.append('%s="%s"' % (key, escaped))
            continue

        try:
            escaped = quoted_string(val)
        except ValueError:
            # Value cannot be a quoted-string: use the extended form.
            extended = "".join((_charset, "''", quote(val, "", encoding=_charset)))
            pieces.append("%s*=%s" % (key, extended))
        else:
            pieces.append('%s="%s"' % (key, escaped))

    return "; ".join(pieces)

455 

456 

class _TSelf(Protocol, Generic[_T]):
    # Shape required of instances using reify: a per-instance cache dict.
    _cache: Dict[str, _T]


class reify(Generic[_T]):
    """Caching, read-only property decorator.

    Works much like ``@property``, but the first computed value is
    stored in the instance's ``_cache`` dict under the method's name and
    returned from there on later accesses. Assignment is rejected.
    """

    def __init__(self, wrapped: Callable[..., _T]) -> None:
        self.wrapped = wrapped
        self.__doc__ = wrapped.__doc__
        self.name = wrapped.__name__

    def __get__(self, inst: _TSelf[_T], owner: Optional[Type[Any]] = None) -> _T:
        try:
            try:
                return inst._cache[self.name]
            except KeyError:
                # First access: compute, remember, and hand back.
                result = self.wrapped(inst)
                inst._cache[self.name] = result
                return result
        except AttributeError:
            # Access through the class (inst is None) returns the
            # descriptor itself; any other AttributeError is genuine.
            if inst is not None:
                raise
            return self

    def __set__(self, inst: _TSelf[_T], value: _T) -> None:
        raise AttributeError("reified property is read-only")


# Keep the pure-Python implementation reachable under its own name.
reify_py = reify

try:
    from ._helpers import reify as reify_c
except ImportError:
    pass
else:
    # Switch to the extension version unless extensions are disabled.
    if not NO_EXTENSIONS:
        reify = reify_c  # type: ignore[misc,assignment]

502 

# Dotted-quad IPv4 address: four octets in the range 0-255.
_ipv4_pattern = (
    r"^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}"
    r"(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$"
)
# IPv6 address, including forms with "::" and an embedded IPv4 tail.
_ipv6_pattern = (
    r"^(?:(?:(?:[A-F0-9]{1,4}:){6}|(?=(?:[A-F0-9]{0,4}:){0,6}"
    r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}$)(([0-9A-F]{1,4}:){0,5}|:)"
    r"((:[0-9A-F]{1,4}){1,5}:|:)|::(?:[A-F0-9]{1,4}:){5})"
    r"(?:(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])\.){3}"
    r"(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])|(?:[A-F0-9]{1,4}:){7}"
    r"[A-F0-9]{1,4}|(?=(?:[A-F0-9]{0,4}:){0,7}[A-F0-9]{0,4}$)"
    r"(([0-9A-F]{1,4}:){1,7}|:)((:[0-9A-F]{1,4}){1,7}|:)|(?:[A-F0-9]{1,4}:){7}"
    r":|:(:[A-F0-9]{1,4}){7})$"
)

# str and bytes variants of each pattern; the IPv6 ones ignore case so
# that lower-case hex digits are accepted too.
_ipv4_regex = re.compile(_ipv4_pattern)
_ipv4_regexb = re.compile(_ipv4_pattern.encode("ascii"))
_ipv6_regex = re.compile(_ipv6_pattern, re.IGNORECASE)
_ipv6_regexb = re.compile(_ipv6_pattern.encode("ascii"), re.IGNORECASE)

521 

522 

def _is_ip_address(
    regex: Pattern[str], regexb: Pattern[bytes], host: Optional[Union[str, bytes]]
) -> bool:
    """Test *host* against the text or bytes pattern matching its type.

    Returns False for None. Raises TypeError when *host* is neither str
    nor a bytes-like object (bytes, bytearray, memoryview).
    """
    if host is None:
        return False

    if isinstance(host, str):
        matched = regex.match(host)
    elif isinstance(host, (bytes, bytearray, memoryview)):
        matched = regexb.match(host)
    else:
        raise TypeError(f"{host} [{type(host)}] is not a str or bytes")
    return bool(matched)

534 

535 

# Single-family checks: _is_ip_address bound to that family's patterns.
is_ipv4_address = functools.partial(_is_ip_address, _ipv4_regex, _ipv4_regexb)
is_ipv6_address = functools.partial(_is_ip_address, _ipv6_regex, _ipv6_regexb)


def is_ip_address(host: Optional[Union[str, bytes, bytearray, memoryview]]) -> bool:
    """Return True when *host* is an IPv4 or an IPv6 address literal."""
    if is_ipv4_address(host):
        return True
    return is_ipv6_address(host)

542 

543 

def next_whole_second() -> datetime.datetime:
    """Return current time rounded up to the next whole second.

    The result is a timezone-aware UTC datetime with ``microsecond == 0``
    that is never earlier than the moment of the call. A time already on
    a whole second is returned unchanged.
    """
    now = datetime.datetime.now(datetime.timezone.utc)
    truncated = now.replace(microsecond=0)
    if truncated == now:
        return truncated
    # Truncation rounded down, so step forward one second to round up.
    # (The previous code added timedelta(seconds=0) and never did.)
    return truncated + datetime.timedelta(seconds=1)

549 

550 

# One-entry cache: the whole second last formatted and its rendering.
_cached_current_datetime: Optional[int] = None
_cached_formatted_datetime = ""


def rfc822_formatted_time() -> str:
    """Return the current GMT time formatted as an HTTP date string.

    The string looks like ``Wed, 07 Jun 2023 06:56:00 GMT``. The result
    is cached in module globals and rebuilt only when the integer
    second changes.
    """
    global _cached_current_datetime
    global _cached_formatted_datetime

    now = int(time.time())
    if now == _cached_current_datetime:
        return _cached_formatted_datetime

    # Weekday and month names for HTTP date/time formatting;
    # always English!
    # Tuples are constants stored in codeobject!
    weekdays = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")
    months = (
        "",  # Dummy so we can use 1-based month numbers
        "Jan",
        "Feb",
        "Mar",
        "Apr",
        "May",
        "Jun",
        "Jul",
        "Aug",
        "Sep",
        "Oct",
        "Nov",
        "Dec",
    )

    stamp = time.gmtime(now)
    _cached_formatted_datetime = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
        weekdays[stamp.tm_wday],
        stamp.tm_mday,
        months[stamp.tm_mon],
        stamp.tm_year,
        stamp.tm_hour,
        stamp.tm_min,
        stamp.tm_sec,
    )
    _cached_current_datetime = now
    return _cached_formatted_datetime

593 

594 

def _weakref_handle(info: "Tuple[weakref.ref[object], str]") -> None:
    """Call the named method on a weakly referenced object, if it is alive.

    *info* is a ``(weak reference, method name)`` pair. Nothing happens
    when the referent is gone, and any Exception raised while looking up
    or calling the method is suppressed.
    """
    ref, name = info
    target = ref()
    if target is None:
        return
    with suppress(Exception):
        getattr(target, name)()

601 

602 

def weakref_handle(
    ob: object, name: str, timeout: float, loop: asyncio.AbstractEventLoop
) -> Optional[asyncio.TimerHandle]:
    """Schedule ``ob.<name>()`` on *loop* after *timeout* seconds.

    Only a weak reference to *ob* is handed to the loop. Returns the
    timer handle, or None when *timeout* is None or not positive.
    Deadlines for timeouts of 5 seconds or more are rounded up to a
    whole number.
    """
    if timeout is None or not timeout > 0:
        return None

    deadline = loop.time() + timeout
    if timeout >= 5:
        deadline = ceil(deadline)

    return loop.call_at(deadline, _weakref_handle, (weakref.ref(ob), name))

613 

614 

def call_later(
    cb: Callable[[], Any], timeout: float, loop: asyncio.AbstractEventLoop
) -> Optional[asyncio.TimerHandle]:
    """Schedule *cb* on *loop* after *timeout* seconds.

    Returns the timer handle, or None when *timeout* is None or not
    positive. Deadlines for timeouts greater than 5 seconds are rounded
    up to a whole number.
    """
    if timeout is None or not timeout > 0:
        return None

    deadline = loop.time() + timeout
    if timeout > 5:
        deadline = ceil(deadline)
    return loop.call_at(deadline, cb)

624 

625 

class TimeoutHandle:
    """Timeout handle.

    Holds callbacks to run once a timeout elapses on the given loop.
    """

    def __init__(
        self, loop: asyncio.AbstractEventLoop, timeout: Optional[float]
    ) -> None:
        self._timeout = timeout
        self._loop = loop
        # Entries are (callback, positional args, keyword args).
        self._callbacks: List[
            Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]]
        ] = []

    def register(
        self, callback: Callable[..., None], *args: Any, **kwargs: Any
    ) -> None:
        """Remember *callback* and its arguments for the timeout."""
        self._callbacks.append((callback, args, kwargs))

    def close(self) -> None:
        """Forget every registered callback."""
        self._callbacks.clear()

    def start(self) -> Optional[asyncio.Handle]:
        """Schedule this handle on the loop.

        Returns the loop's handle, or None when the timeout is None or
        not positive. Deadlines for timeouts of 5 seconds or more are
        rounded up to a whole number.
        """
        timeout = self._timeout
        if timeout is None or not timeout > 0:
            return None

        deadline = self._loop.time() + timeout
        if timeout >= 5:
            deadline = ceil(deadline)
        return self._loop.call_at(deadline, self.__call__)

    def timer(self) -> "BaseTimerContext":
        """Return a timer context tied to this handle.

        A TimerContext registered with this handle when a positive
        timeout is configured, a TimerNoop otherwise.
        """
        if self._timeout is None or not self._timeout > 0:
            return TimerNoop()

        context = TimerContext(self._loop)
        self.register(context.timeout)
        return context

    def __call__(self) -> None:
        """Run every registered callback, then drop them all."""
        for callback, args, kwargs in self._callbacks:
            # A failing callback must not stop the remaining ones.
            with suppress(Exception):
                callback(*args, **kwargs)

        self._callbacks.clear()

670 

671 

class BaseTimerContext(ContextManager["BaseTimerContext"]):
    """Common base type for the timer context managers in this module."""

674 

675 

class TimerNoop(BaseTimerContext):
    """Timer context that does nothing on entry or exit."""

    def __enter__(self) -> BaseTimerContext:
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        # Returning None lets any in-flight exception propagate.
        return None

687 

688 

class TimerContext(BaseTimerContext):
    """Low resolution timeout context manager.

    Tracks the tasks currently inside the context so that ``timeout()``
    can cancel them; the resulting cancellation is reported as
    ``asyncio.TimeoutError`` on exit.
    """

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        self._tasks: List[asyncio.Task[Any]] = []
        self._cancelled = False

    def __enter__(self) -> BaseTimerContext:
        running = current_task(loop=self._loop)

        if running is None:
            raise RuntimeError(
                "Timeout context manager should be used " "inside a task"
            )

        # The timeout already fired before this context was entered.
        if self._cancelled:
            raise asyncio.TimeoutError from None

        self._tasks.append(running)
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> Optional[bool]:
        if self._tasks:
            self._tasks.pop()

        # A cancellation seen after timeout() ran is reported as a timeout.
        if self._cancelled and exc_type is asyncio.CancelledError:
            raise asyncio.TimeoutError from None
        return None

    def timeout(self) -> None:
        """Cancel every tracked task and mark the context as timed out."""
        if self._cancelled:
            return

        for task in set(self._tasks):
            task.cancel()

        self._cancelled = True

730 

731 

def ceil_timeout(delay: Optional[float]) -> async_timeout.Timeout:
    """Build an ``async_timeout`` context for *delay* seconds.

    A None or non-positive delay gives ``async_timeout.timeout(None)``.
    Otherwise the deadline is computed from the running loop's clock and,
    for delays greater than 5 seconds, rounded up to a whole number.
    """
    if delay is None or delay <= 0:
        return async_timeout.timeout(None)

    loop = get_running_loop()
    deadline = loop.time() + delay
    if delay > 5:
        deadline = ceil(deadline)
    return async_timeout.timeout_at(deadline)

742 

743 

class HeadersMixin:
    """Mixin exposing Content-Type and Content-Length header values.

    Reads headers from ``self._headers`` and caches the parsed
    Content-Type until the raw header value changes.
    """

    ATTRS = frozenset(["_content_type", "_content_dict", "_stored_content_type"])

    _content_type: Optional[str] = None
    _content_dict: Optional[Dict[str, str]] = None
    # Raw header value the cached fields were parsed from; starts as the
    # sentinel so that the first access always parses.
    _stored_content_type = sentinel

    def _parse_content_type(self, raw: str) -> None:
        """Parse *raw* and refresh the cached content type and params."""
        self._stored_content_type = raw
        if raw is not None:
            parsed = HeaderParser().parsestr("Content-Type: " + raw)
            self._content_type = parsed.get_content_type()
            fields = parsed.get_params()
            # First element is content type again
            self._content_dict = dict(fields[1:])
            return

        # default value according to RFC 2616
        self._content_type = "application/octet-stream"
        self._content_dict = {}

    @property
    def content_type(self) -> str:
        """The value of content part for Content-Type HTTP header."""
        current = self._headers.get(hdrs.CONTENT_TYPE)  # type: ignore[attr-defined]
        if current != self._stored_content_type:
            self._parse_content_type(current)
        return self._content_type  # type: ignore[return-value]

    @property
    def charset(self) -> Optional[str]:
        """The value of charset part for Content-Type HTTP header."""
        current = self._headers.get(hdrs.CONTENT_TYPE)  # type: ignore[attr-defined]
        if current != self._stored_content_type:
            self._parse_content_type(current)
        return self._content_dict.get("charset")  # type: ignore[union-attr]

    @property
    def content_length(self) -> Optional[int]:
        """The value of Content-Length HTTP header."""
        raw_length = self._headers.get(  # type: ignore[attr-defined]
            hdrs.CONTENT_LENGTH
        )
        return None if raw_length is None else int(raw_length)

791 

792 

def set_result(fut: "asyncio.Future[_T]", result: _T) -> None:
    """Set *result* on *fut* unless the future is already done."""
    if fut.done():
        return
    fut.set_result(result)

796 

797 

def set_exception(fut: "asyncio.Future[_T]", exc: BaseException) -> None:
    """Set *exc* on *fut* unless the future is already done."""
    if fut.done():
        return
    fut.set_exception(exc)

801 

802 

class ChainMapProxy(Mapping[str, Any]):
    """Read-only view over several mappings searched in order.

    Lookups return the value from the first mapping that has the key.
    The class cannot be subclassed.
    """

    __slots__ = ("_maps",)

    def __init__(self, maps: Iterable[Mapping[str, Any]]) -> None:
        self._maps = tuple(maps)

    def __init_subclass__(cls) -> None:
        raise TypeError(
            f"Inheritance class {cls.__name__} from ChainMapProxy is forbidden"
        )

    def __getitem__(self, key: str) -> Any:
        for layer in self._maps:
            try:
                return layer[key]
            except KeyError:
                continue
        raise KeyError(key)

    def get(self, key: str, default: Any = None) -> Any:
        if key in self:
            return self[key]
        return default

    def __len__(self) -> int:
        # reuses stored hash values if possible
        return len(set().union(*self._maps))  # type: ignore[arg-type]

    def __iter__(self) -> Iterator[str]:
        # Merge from the last mapping to the first so that the result
        # holds every distinct key exactly once.
        merged: Dict[str, Any] = {}
        for layer in self._maps[::-1]:
            # reuses stored hash values if possible
            merged.update(layer)
        return iter(merged)

    def __contains__(self, key: object) -> bool:
        for layer in self._maps:
            if key in layer:
                return True
        return False

    def __bool__(self) -> bool:
        # Truthy as soon as one underlying mapping is non-empty.
        for layer in self._maps:
            if layer:
                return True
        return False

    def __repr__(self) -> str:
        content = ", ".join(repr(layer) for layer in self._maps)
        return f"ChainMapProxy({content})"

846 

847 

848# https://tools.ietf.org/html/rfc7232#section-2.3 

849_ETAGC = r"[!#-}\x80-\xff]+" 

850_ETAGC_RE = re.compile(_ETAGC) 

851_QUOTED_ETAG = rf'(W/)?"({_ETAGC})"' 

852QUOTED_ETAG_RE = re.compile(_QUOTED_ETAG) 

853LIST_QUOTED_ETAG_RE = re.compile(rf"({_QUOTED_ETAG})(?:\s*,\s*|$)|(.)") 

854 

855ETAG_ANY = "*" 

856 

857 

@attr.s(auto_attribs=True, frozen=True, slots=True)
class ETag:
    """Immutable entity tag: a value plus a weak-validator flag."""

    # The etag text.
    value: str
    # Whether the tag is a weak validator; defaults to False.
    is_weak: bool = False

862 

863 

def validate_etag_value(value: str) -> None:
    """Raise ValueError unless *value* is a usable etag value.

    The wildcard ``*`` is accepted as is; anything else must consist
    entirely of characters allowed by the etag pattern.
    """
    if value == ETAG_ANY:
        return
    if _ETAGC_RE.fullmatch(value):
        return
    raise ValueError(
        f"Value {value!r} is not a valid etag. Maybe it contains '\"'?"
    )

869 

870 

def parse_http_date(date_str: Optional[str]) -> Optional[datetime.datetime]:
    """Process a date string, return a datetime object.

    The result is a UTC-aware datetime built from the first six fields
    produced by ``email.utils.parsedate``. Returns None when *date_str*
    is None, cannot be parsed, or holds out-of-range field values.
    """
    if date_str is None:
        return None

    fields = parsedate(date_str)
    if fields is None:
        return None

    try:
        return datetime.datetime(*fields[:6], tzinfo=datetime.timezone.utc)
    except ValueError:
        return None