Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/cookiejar.py: 22%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

278 statements  

1import calendar 

2import contextlib 

3import datetime 

4import heapq 

5import itertools 

6import os # noqa 

7import pathlib 

8import pickle 

9import re 

10import time 

11import warnings 

12from collections import defaultdict 

13from http.cookies import BaseCookie, Morsel, SimpleCookie 

14from typing import ( 

15 DefaultDict, 

16 Dict, 

17 FrozenSet, 

18 Iterable, 

19 Iterator, 

20 List, 

21 Mapping, 

22 Optional, 

23 Set, 

24 Tuple, 

25 Union, 

26 cast, 

27) 

28 

29from yarl import URL 

30 

31from .abc import AbstractCookieJar, ClearCookiePredicate 

32from .helpers import is_ip_address 

33from .typedefs import LooseCookies, PathLike, StrOrURL 

34 

__all__ = ("CookieJar", "DummyCookieJar")


CookieItem = Union[str, "Morsel[str]"]

# Bound ``str.format`` methods, looked up once at import time because they are
# called from performance critical code.
# ``_FORMAT_PATH(a, b)`` joins two pieces with a slash: "a/b".
_FORMAT_PATH = "{}/{}".format
# ``_FORMAT_DOMAIN_REVERSED(a, b)`` joins two labels in swapped order: "b.a".
_FORMAT_DOMAIN_REVERSED = "{1}.{0}".format

# Number of scheduled cookie expirations that must be exceeded before the
# expiration heap is considered for cleanup. Keeping a floor avoids rebuilding
# the heap too often when only a few expirations are scheduled.
_MIN_SCHEDULED_COOKIE_EXPIRATION = 100

48 

49 

class CookieJar(AbstractCookieJar):
    """Implements cookie storage adhering to RFC 6265.

    Cookies are stored per ``(domain, path)`` pair. Expiration is tracked
    lazily: every scheduled expiration is pushed onto a heap and the heap is
    drained by ``_do_expiration()`` whenever the jar is iterated, updated or
    filtered.
    """

    # Splits a cookie date string into tokens, skipping delimiter characters.
    DATE_TOKENS_RE = re.compile(
        r"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*"
        r"(?P<token>[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)"
    )

    DATE_HMS_TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})")

    DATE_DAY_OF_MONTH_RE = re.compile(r"(\d{1,2})")

    # One capture group per month so that ``lastindex`` is the month number.
    DATE_MONTH_RE = re.compile(
        "(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|(aug)|(sep)|(oct)|(nov)|(dec)",
        re.I,
    )

    DATE_YEAR_RE = re.compile(r"(\d{2,4})")

    # calendar.timegm() fails for timestamps after datetime.datetime.max
    # Minus one as a loss of precision occurs when timestamp() is called.
    MAX_TIME = (
        int(datetime.datetime.max.replace(tzinfo=datetime.timezone.utc).timestamp()) - 1
    )
    try:
        calendar.timegm(time.gmtime(MAX_TIME))
    except (OSError, ValueError):
        # Hit the maximum representable time on Windows
        # https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/localtime-localtime32-localtime64
        # Throws ValueError on PyPy 3.9, OSError elsewhere
        MAX_TIME = calendar.timegm((3000, 12, 31, 23, 59, 59, -1, -1, -1))
    except OverflowError:
        # #4515: datetime.max may not be representable on 32-bit platforms
        MAX_TIME = 2**31 - 1
    # Avoid minuses in the future, 3x faster
    SUB_MAX_TIME = MAX_TIME - 1

    def __init__(
        self,
        *,
        unsafe: bool = False,
        quote_cookie: bool = True,
        treat_as_secure_origin: Union[StrOrURL, Iterable[StrOrURL], None] = None,
    ) -> None:
        """Create an empty jar.

        :param unsafe: when true, cookies are also accepted from and sent to
            hosts that are IP addresses.
        :param quote_cookie: when true, ``filter_cookies()`` returns a
            ``SimpleCookie``; otherwise a plain ``BaseCookie``.
        :param treat_as_secure_origin: a single URL/string or an iterable of
            them; requests to these origins may receive ``secure`` cookies
            even when the scheme is not ``https``/``wss``.
        """
        # (domain, path) -> cookies stored for that pair.
        self._cookies: DefaultDict[Tuple[str, str], SimpleCookie] = defaultdict(
            SimpleCookie
        )
        # (domain, path) -> name -> Morsel already built by filter_cookies().
        self._morsel_cache: DefaultDict[Tuple[str, str], Dict[str, Morsel[str]]] = (
            defaultdict(dict)
        )
        # (domain, name) of cookies that were set without a Domain attribute.
        self._host_only_cookies: Set[Tuple[str, str]] = set()
        self._unsafe = unsafe
        self._quote_cookie = quote_cookie
        if treat_as_secure_origin is None:
            self._treat_as_secure_origin: FrozenSet[URL] = frozenset()
        elif isinstance(treat_as_secure_origin, URL):
            self._treat_as_secure_origin = frozenset({treat_as_secure_origin.origin()})
        elif isinstance(treat_as_secure_origin, str):
            self._treat_as_secure_origin = frozenset(
                {URL(treat_as_secure_origin).origin()}
            )
        else:
            self._treat_as_secure_origin = frozenset(
                {
                    URL(url).origin() if isinstance(url, str) else url.origin()
                    for url in treat_as_secure_origin
                }
            )
        # Min-heap of (expiration time, (domain, path, name)). It may contain
        # outdated entries; ``_expirations`` holds the authoritative time.
        self._expire_heap: List[Tuple[float, Tuple[str, str, str]]] = []
        self._expirations: Dict[Tuple[str, str, str], float] = {}

    @property
    def quote_cookie(self) -> bool:
        """Return the ``quote_cookie`` flag the jar was created with."""
        return self._quote_cookie

    def save(self, file_path: PathLike) -> None:
        """Pickle the stored cookies to *file_path*.

        Only the cookie mapping is written; the scheduled expirations and the
        host-only flags are not part of the file.
        """
        file_path = pathlib.Path(file_path)
        with file_path.open(mode="wb") as f:
            pickle.dump(self._cookies, f, pickle.HIGHEST_PROTOCOL)

    def load(self, file_path: PathLike) -> None:
        """Replace the stored cookies with the ones pickled in *file_path*.

        SECURITY: the file is read with ``pickle.load()``, which can execute
        arbitrary code. Only load files that come from a trusted source.
        """
        file_path = pathlib.Path(file_path)
        with file_path.open(mode="rb") as f:
            self._cookies = pickle.load(f)
        # Morsels cached by filter_cookies() were built from the previous
        # cookie set; drop them so stale values are never served.
        self._morsel_cache.clear()

    def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
        """Remove cookies from the jar.

        Without *predicate* everything is removed. With *predicate*, a cookie
        is removed when its scheduled expiration has passed or when
        ``predicate(morsel)`` is true.
        """
        if predicate is None:
            self._expire_heap.clear()
            self._cookies.clear()
            self._morsel_cache.clear()
            self._host_only_cookies.clear()
            self._expirations.clear()
            return

        now = time.time()
        to_del = [
            key
            for (domain, path), cookie in self._cookies.items()
            for name, morsel in cookie.items()
            if (
                (key := (domain, path, name)) in self._expirations
                and self._expirations[key] <= now
            )
            or predicate(morsel)
        ]
        if to_del:
            self._delete_cookies(to_del)

    def clear_domain(self, domain: str) -> None:
        """Remove cookies whose domain attribute domain-matches *domain*."""
        self.clear(lambda x: self._is_domain_match(domain, x["domain"]))

    def __iter__(self) -> "Iterator[Morsel[str]]":
        """Iterate over all stored morsels after dropping expired ones."""
        self._do_expiration()
        for val in self._cookies.values():
            yield from val.values()

    def __len__(self) -> int:
        """Return number of cookies.

        This function does not iterate self to avoid unnecessary expiration
        checks.
        """
        return sum(len(cookie.values()) for cookie in self._cookies.values())

    def _do_expiration(self) -> None:
        """Remove expired cookies."""
        if not (expire_heap_len := len(self._expire_heap)):
            return

        # If the expiration heap grows larger than the number expirations
        # times two, we clean it up to avoid keeping expired entries in
        # the heap and consuming memory. We guard this with a minimum
        # threshold to avoid cleaning up the heap too often when there are
        # only a few scheduled expirations.
        if (
            expire_heap_len > _MIN_SCHEDULED_COOKIE_EXPIRATION
            and expire_heap_len > len(self._expirations) * 2
        ):
            # Remove any expired entries from the expiration heap
            # that do not match the expiration time in the expirations
            # as it means the cookie has been re-added to the heap
            # with a different expiration time.
            self._expire_heap = [
                entry
                for entry in self._expire_heap
                if self._expirations.get(entry[1]) == entry[0]
            ]
            heapq.heapify(self._expire_heap)

        now = time.time()
        to_del: List[Tuple[str, str, str]] = []
        # Find any expired cookies and add them to the to-delete list
        while self._expire_heap:
            when, cookie_key = self._expire_heap[0]
            if when > now:
                break
            heapq.heappop(self._expire_heap)
            # Check if the cookie hasn't been re-added to the heap
            # with a different expiration time as it will be removed
            # later when it reaches the top of the heap and its
            # expiration time is met.
            if self._expirations.get(cookie_key) == when:
                to_del.append(cookie_key)

        if to_del:
            self._delete_cookies(to_del)

    def _delete_cookies(self, to_del: List[Tuple[str, str, str]]) -> None:
        """Drop every ``(domain, path, name)`` in *to_del* from all indexes."""
        for domain, path, name in to_del:
            self._host_only_cookies.discard((domain, name))
            self._cookies[(domain, path)].pop(name, None)
            self._morsel_cache[(domain, path)].pop(name, None)
            self._expirations.pop((domain, path, name), None)

    def _expire_cookie(self, when: float, domain: str, path: str, name: str) -> None:
        """Schedule the cookie ``(domain, path, name)`` to expire at *when*."""
        cookie_key = (domain, path, name)
        if self._expirations.get(cookie_key) == when:
            # Avoid adding duplicates to the heap
            return
        heapq.heappush(self._expire_heap, (when, cookie_key))
        self._expirations[cookie_key] = when

    def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
        """Update cookies.

        :param cookies: a mapping or an iterable of ``(name, cookie)`` pairs;
            values that are not ``Morsel`` instances are converted through a
            temporary ``SimpleCookie``.
        :param response_url: URL of the response that set the cookies. Its
            host and path supply the defaults for cookies lacking a domain
            or an absolute path.
        """
        hostname = response_url.raw_host

        if not self._unsafe and is_ip_address(hostname):
            # Don't accept cookies from IPs
            return

        if isinstance(cookies, Mapping):
            cookies = cookies.items()

        for name, cookie in cookies:
            if not isinstance(cookie, Morsel):
                tmp = SimpleCookie()
                tmp[name] = cookie  # type: ignore[assignment]
                cookie = tmp[name]

            domain = cookie["domain"]

            # ignore domains with trailing dots
            if domain and domain[-1] == ".":
                domain = ""
                del cookie["domain"]

            if not domain and hostname is not None:
                # Set the cookie's domain to the response hostname
                # and set its host-only-flag
                self._host_only_cookies.add((hostname, name))
                domain = cookie["domain"] = hostname

            if domain and domain[0] == ".":
                # Remove leading dot
                domain = domain[1:]
                cookie["domain"] = domain

            if hostname and not self._is_domain_match(domain, hostname):
                # Setting cookies for different domains is not allowed
                continue

            path = cookie["path"]
            if not path or path[0] != "/":
                # Set the cookie's path to the response path
                path = response_url.path
                if not path.startswith("/"):
                    path = "/"
                else:
                    # Cut everything from the last slash to the end
                    path = "/" + path[1 : path.rfind("/")]
                cookie["path"] = path
            # The storage key uses the path without trailing slashes.
            path = path.rstrip("/")

            # Max-Age takes precedence over Expires.
            if max_age := cookie["max-age"]:
                try:
                    delta_seconds = int(max_age)
                    max_age_expiration = min(time.time() + delta_seconds, self.MAX_TIME)
                    self._expire_cookie(max_age_expiration, domain, path, name)
                except ValueError:
                    cookie["max-age"] = ""

            elif expires := cookie["expires"]:
                if expire_time := self._parse_date(expires):
                    self._expire_cookie(expire_time, domain, path, name)
                else:
                    cookie["expires"] = ""

            key = (domain, path)
            if self._cookies[key].get(name) != cookie:
                # Don't blow away the cache if the same
                # cookie gets set again
                self._cookies[key][name] = cookie
                self._morsel_cache[key].pop(name, None)

        self._do_expiration()

    def filter_cookies(self, request_url: URL) -> "BaseCookie[str]":
        """Returns this jar's cookies filtered by their attributes.

        :param request_url: URL of the outgoing request. Anything that is not
            a ``yarl.URL`` is converted, with a ``DeprecationWarning``.
        :return: a ``SimpleCookie`` (or ``BaseCookie`` when ``quote_cookie``
            is false) holding the cookies applicable to *request_url*.
        """
        if not isinstance(request_url, URL):
            warnings.warn(  # type: ignore[unreachable]
                "The method accepts yarl.URL instances only, got {}".format(
                    type(request_url)
                ),
                DeprecationWarning,
            )
            request_url = URL(request_url)
        filtered: Union[SimpleCookie, "BaseCookie[str]"] = (
            SimpleCookie() if self._quote_cookie else BaseCookie()
        )
        if not self._cookies:
            # Skip do_expiration() if there are no cookies.
            return filtered
        self._do_expiration()
        if not self._cookies:
            # Skip rest of function if no non-expired cookies.
            return filtered
        hostname = request_url.raw_host or ""

        is_not_secure = request_url.scheme not in ("https", "wss")
        if is_not_secure and self._treat_as_secure_origin:
            request_origin = URL()
            with contextlib.suppress(ValueError):
                request_origin = request_url.origin()
            is_not_secure = request_origin not in self._treat_as_secure_origin

        # Send shared cookie
        # NOTE: ``_cookies`` and ``_morsel_cache`` are defaultdicts; plain
        # indexing would insert an empty entry for every key that is merely
        # looked up, so read-only lookups below go through ``.get()``.
        shared = self._cookies.get(("", ""))
        if shared is not None:
            for c in shared.values():
                filtered[c.key] = c.value

        if is_ip_address(hostname):
            if not self._unsafe:
                return filtered
            domains: Iterable[str] = (hostname,)
        else:
            # Get all the subdomains that might match a cookie (e.g. "foo.bar.com", "bar.com", "com")
            domains = itertools.accumulate(
                reversed(hostname.split(".")), _FORMAT_DOMAIN_REVERSED
            )

        # Get all the path prefixes that might match a cookie (e.g. "", "/foo", "/foo/bar")
        paths = itertools.accumulate(request_url.path.split("/"), _FORMAT_PATH)
        # Create every combination of (domain, path) pairs.
        pairs = itertools.product(domains, paths)

        path_len = len(request_url.path)
        # Point 2: https://www.rfc-editor.org/rfc/rfc6265.html#section-5.4
        for p in pairs:
            stored = self._cookies.get(p)
            if not stored:
                # Nothing stored for this (domain, path) pair.
                continue
            # Safe to index here: an entry is created only for pairs that
            # actually hold cookies.
            morsel_cache = self._morsel_cache[p]
            for name, cookie in stored.items():
                domain = cookie["domain"]

                if (domain, name) in self._host_only_cookies and domain != hostname:
                    continue

                # Skip edge case when the cookie has a trailing slash but request doesn't.
                if len(cookie["path"]) > path_len:
                    continue

                if is_not_secure and cookie["secure"]:
                    continue

                # We already built the Morsel so reuse it here
                if name in morsel_cache:
                    filtered[name] = morsel_cache[name]
                    continue

                # It's critical we use the Morsel so the coded_value
                # (based on cookie version) is preserved
                mrsl_val = cast("Morsel[str]", cookie.get(cookie.key, Morsel()))
                mrsl_val.set(cookie.key, cookie.value, cookie.coded_value)
                morsel_cache[name] = mrsl_val
                filtered[name] = mrsl_val

        return filtered

    @staticmethod
    def _is_domain_match(domain: str, hostname: str) -> bool:
        """Implements domain matching adhering to RFC 6265.

        Return True when *hostname* equals *domain*, or when *hostname* ends
        with ``"." + domain`` and is not an IP address.
        """
        if hostname == domain:
            return True

        if not hostname.endswith(domain):
            return False

        non_matching = hostname[: -len(domain)]

        if not non_matching.endswith("."):
            return False

        return not is_ip_address(hostname)

    @classmethod
    def _parse_date(cls, date_str: str) -> Optional[int]:
        """Implements date string parsing adhering to RFC 6265.

        :return: the parsed date as seconds since the epoch (UTC), or None
            when the string is empty, a component is missing, or a component
            is out of range.
        """
        if not date_str:
            return None

        found_time = False
        found_day = False
        found_month = False
        found_year = False

        hour = minute = second = 0
        day = 0
        month = 0
        year = 0

        # Each token fills the first still-missing component it matches, in
        # the order time, day-of-month, month, year.
        for token_match in cls.DATE_TOKENS_RE.finditer(date_str):
            token = token_match.group("token")

            if not found_time:
                time_match = cls.DATE_HMS_TIME_RE.match(token)
                if time_match:
                    found_time = True
                    hour, minute, second = (int(s) for s in time_match.groups())
                    continue

            if not found_day:
                day_match = cls.DATE_DAY_OF_MONTH_RE.match(token)
                if day_match:
                    found_day = True
                    day = int(day_match.group())
                    continue

            if not found_month:
                month_match = cls.DATE_MONTH_RE.match(token)
                if month_match:
                    found_month = True
                    assert month_match.lastindex is not None
                    # The index of the matching group is the month number.
                    month = month_match.lastindex
                    continue

            if not found_year:
                year_match = cls.DATE_YEAR_RE.match(token)
                if year_match:
                    found_year = True
                    year = int(year_match.group())

        # Two-digit years: 70-99 map to 19xx, 0-69 map to 20xx.
        if 70 <= year <= 99:
            year += 1900
        elif 0 <= year <= 69:
            year += 2000

        if False in (found_day, found_month, found_year, found_time):
            return None

        if not 1 <= day <= 31:
            return None

        if year < 1601 or hour > 23 or minute > 59 or second > 59:
            return None

        return calendar.timegm((year, month, day, hour, minute, second, -1, -1, -1))

462 

463 

class DummyCookieJar(AbstractCookieJar):
    """A cookie jar that never stores anything.

    It can be used with the ClientSession when no cookie processing is needed.
    """

    def __iter__(self) -> "Iterator[Morsel[str]]":
        """Return a generator that yields no morsels."""
        yield from ()

    def __len__(self) -> int:
        """Return 0; the jar is always empty."""
        return 0

    @property
    def quote_cookie(self) -> bool:
        """Always report True."""
        return True

    def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
        """Do nothing; *predicate* is ignored."""

    def clear_domain(self, domain: str) -> None:
        """Do nothing; *domain* is ignored."""

    def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
        """Do nothing; the given cookies are discarded."""

    def filter_cookies(self, request_url: URL) -> "BaseCookie[str]":
        """Return a new, empty ``SimpleCookie`` for every request URL."""
        empty = SimpleCookie()
        return empty