1import asyncio
2import calendar
3import contextlib
4import datetime
5import heapq
6import itertools
7import os # noqa
8import pathlib
9import pickle
10import re
11import time
12import warnings
13from collections import defaultdict
14from http.cookies import BaseCookie, Morsel, SimpleCookie
15from typing import (
16 DefaultDict,
17 Dict,
18 Iterable,
19 Iterator,
20 List,
21 Mapping,
22 Optional,
23 Set,
24 Tuple,
25 Union,
26 cast,
27)
28
29from yarl import URL
30
31from .abc import AbstractCookieJar, ClearCookiePredicate
32from .helpers import is_ip_address
33from .typedefs import LooseCookies, PathLike, StrOrURL
34
__all__ = ("CookieJar", "DummyCookieJar")


# A single cookie value as accepted by update_cookies(): either a raw string
# or an already-parsed Morsel.
CookieItem = Union[str, "Morsel[str]"]

# We cache these string methods here as their use is in performance critical code.
_FORMAT_PATH = "{}/{}".format
_FORMAT_DOMAIN_REVERSED = "{1}.{0}".format

# The minimum number of scheduled cookie expirations before we start cleaning up
# the expiration heap. This is a performance optimization to avoid cleaning up the
# heap too often when there are only a few scheduled expirations.
_MIN_SCHEDULED_COOKIE_EXPIRATION = 100
48
49
class CookieJar(AbstractCookieJar):
    """Implements cookie storage adhering to RFC 6265."""

    # Splits a cookie date string into tokens, skipping RFC 6265 delimiter
    # characters and capturing the non-delimiter runs.
    DATE_TOKENS_RE = re.compile(
        r"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*"
        r"(?P<token>[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)"
    )

    # hh:mm:ss time component of a cookie date token.
    DATE_HMS_TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})")

    DATE_DAY_OF_MONTH_RE = re.compile(r"(\d{1,2})")

    # Month abbreviation; the index of the matched group (1-12) is the month
    # number — see _parse_date, which uses match.lastindex for this.
    DATE_MONTH_RE = re.compile(
        "(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|(aug)|(sep)|(oct)|(nov)|(dec)",
        re.I,
    )

    DATE_YEAR_RE = re.compile(r"(\d{2,4})")

    # calendar.timegm() fails for timestamps after datetime.datetime.max
    # Minus one as a loss of precision occurs when timestamp() is called.
    MAX_TIME = (
        int(datetime.datetime.max.replace(tzinfo=datetime.timezone.utc).timestamp()) - 1
    )
    try:
        calendar.timegm(time.gmtime(MAX_TIME))
    except (OSError, ValueError):
        # Hit the maximum representable time on Windows
        # https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/localtime-localtime32-localtime64
        # Throws ValueError on PyPy 3.9, OSError elsewhere
        MAX_TIME = calendar.timegm((3000, 12, 31, 23, 59, 59, -1, -1, -1))
    except OverflowError:
        # #4515: datetime.max may not be representable on 32-bit platforms
        MAX_TIME = 2**31 - 1
    # Avoid minuses in the future, 3x faster
    SUB_MAX_TIME = MAX_TIME - 1

    def __init__(
        self,
        *,
        unsafe: bool = False,
        quote_cookie: bool = True,
        treat_as_secure_origin: Union[StrOrURL, List[StrOrURL], None] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        """Initialize the cookie jar.

        unsafe -- accept cookies from/for bare IP addresses.
        quote_cookie -- emit quoted values via SimpleCookie when filtering;
            when False a plain BaseCookie is used instead.
        treat_as_secure_origin -- origin(s) considered secure even over
            plain http/ws, so "secure" cookies are still sent to them.
        """
        super().__init__(loop=loop)
        # Cookies stored per (domain, path) pair.
        self._cookies: DefaultDict[Tuple[str, str], SimpleCookie] = defaultdict(
            SimpleCookie
        )
        # Cache of Morsels built by filter_cookies, keyed the same way;
        # invalidated per-name when a cookie is replaced.
        self._morsel_cache: DefaultDict[Tuple[str, str], Dict[str, Morsel[str]]] = (
            defaultdict(dict)
        )
        # (domain, name) pairs whose cookie had no Domain attribute and must
        # therefore only be sent back to the exact host that set them.
        self._host_only_cookies: Set[Tuple[str, str]] = set()
        self._unsafe = unsafe
        self._quote_cookie = quote_cookie
        # Normalize treat_as_secure_origin to a list of URL origins.
        if treat_as_secure_origin is None:
            treat_as_secure_origin = []
        elif isinstance(treat_as_secure_origin, URL):
            treat_as_secure_origin = [treat_as_secure_origin.origin()]
        elif isinstance(treat_as_secure_origin, str):
            treat_as_secure_origin = [URL(treat_as_secure_origin).origin()]
        else:
            treat_as_secure_origin = [
                URL(url).origin() if isinstance(url, str) else url.origin()
                for url in treat_as_secure_origin
            ]
        self._treat_as_secure_origin = treat_as_secure_origin
        # Min-heap of (expiration time, (domain, path, name)) plus a dict with
        # the authoritative expiration per cookie key; the heap may contain
        # stale entries, which _do_expiration filters out lazily.
        self._expire_heap: List[Tuple[float, Tuple[str, str, str]]] = []
        self._expirations: Dict[Tuple[str, str, str], float] = {}

    @property
    def quote_cookie(self) -> bool:
        """Whether filter_cookies returns quoted values (SimpleCookie)."""
        return self._quote_cookie

    def save(self, file_path: PathLike) -> None:
        """Pickle the stored cookies to *file_path*."""
        file_path = pathlib.Path(file_path)
        with file_path.open(mode="wb") as f:
            pickle.dump(self._cookies, f, pickle.HIGHEST_PROTOCOL)

    def load(self, file_path: PathLike) -> None:
        """Load cookies previously written by save().

        WARNING: this unpickles the file; only load files from a trusted
        source, as pickle can execute arbitrary code.
        """
        file_path = pathlib.Path(file_path)
        with file_path.open(mode="rb") as f:
            # NOTE(review): only self._cookies is restored here; the expiration
            # heap, _expirations and _host_only_cookies are not rebuilt from
            # the loaded data — confirm whether loaded cookies are meant to
            # live until explicitly cleared.
            self._cookies = pickle.load(f)

    def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
        """Remove cookies matching *predicate* (all cookies when None)."""
        if predicate is None:
            # Fast path: drop everything, including all bookkeeping state.
            self._expire_heap.clear()
            self._cookies.clear()
            self._morsel_cache.clear()
            self._host_only_cookies.clear()
            self._expirations.clear()
            return

        now = time.time()
        # Delete cookies that are already expired OR match the predicate.
        to_del = [
            key
            for (domain, path), cookie in self._cookies.items()
            for name, morsel in cookie.items()
            if (
                (key := (domain, path, name)) in self._expirations
                and self._expirations[key] <= now
            )
            or predicate(morsel)
        ]
        if to_del:
            self._delete_cookies(to_del)

    def clear_domain(self, domain: str) -> None:
        """Remove all cookies whose domain matches *domain* per RFC 6265."""
        self.clear(lambda x: self._is_domain_match(domain, x["domain"]))

    def __iter__(self) -> "Iterator[Morsel[str]]":
        """Iterate over all non-expired cookies as Morsels."""
        self._do_expiration()
        for val in self._cookies.values():
            yield from val.values()

    def __len__(self) -> int:
        """Return number of cookies.

        This function does not iterate self to avoid unnecessary expiration
        checks.
        """
        return sum(len(cookie.values()) for cookie in self._cookies.values())

    def _do_expiration(self) -> None:
        """Remove expired cookies."""
        if not (expire_heap_len := len(self._expire_heap)):
            return

        # If the expiration heap grows larger than the number expirations
        # times two, we clean it up to avoid keeping expired entries in
        # the heap and consuming memory. We guard this with a minimum
        # threshold to avoid cleaning up the heap too often when there are
        # only a few scheduled expirations.
        if (
            expire_heap_len > _MIN_SCHEDULED_COOKIE_EXPIRATION
            and expire_heap_len > len(self._expirations) * 2
        ):
            # Remove any expired entries from the expiration heap
            # that do not match the expiration time in the expirations
            # as it means the cookie has been re-added to the heap
            # with a different expiration time.
            self._expire_heap = [
                entry
                for entry in self._expire_heap
                if self._expirations.get(entry[1]) == entry[0]
            ]
            heapq.heapify(self._expire_heap)

        now = time.time()
        to_del: List[Tuple[str, str, str]] = []
        # Find any expired cookies and add them to the to-delete list
        while self._expire_heap:
            when, cookie_key = self._expire_heap[0]
            if when > now:
                break
            heapq.heappop(self._expire_heap)
            # Check if the cookie hasn't been re-added to the heap
            # with a different expiration time as it will be removed
            # later when it reaches the top of the heap and its
            # expiration time is met.
            if self._expirations.get(cookie_key) == when:
                to_del.append(cookie_key)

        if to_del:
            self._delete_cookies(to_del)

    def _delete_cookies(self, to_del: List[Tuple[str, str, str]]) -> None:
        """Drop the given (domain, path, name) cookies and their bookkeeping."""
        for domain, path, name in to_del:
            self._host_only_cookies.discard((domain, name))
            self._cookies[(domain, path)].pop(name, None)
            self._morsel_cache[(domain, path)].pop(name, None)
            self._expirations.pop((domain, path, name), None)

    def _expire_cookie(self, when: float, domain: str, path: str, name: str) -> None:
        """Schedule the cookie to expire at *when* (epoch seconds)."""
        cookie_key = (domain, path, name)
        if self._expirations.get(cookie_key) == when:
            # Avoid adding duplicates to the heap
            return
        heapq.heappush(self._expire_heap, (when, cookie_key))
        self._expirations[cookie_key] = when

    def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
        """Update cookies."""
        hostname = response_url.raw_host

        if not self._unsafe and is_ip_address(hostname):
            # Don't accept cookies from IPs
            return

        if isinstance(cookies, Mapping):
            cookies = cookies.items()

        for name, cookie in cookies:
            if not isinstance(cookie, Morsel):
                # Parse a raw string value into a Morsel via SimpleCookie.
                tmp = SimpleCookie()
                tmp[name] = cookie  # type: ignore[assignment]
                cookie = tmp[name]

            domain = cookie["domain"]

            # ignore domains with trailing dots
            if domain and domain[-1] == ".":
                domain = ""
                del cookie["domain"]

            if not domain and hostname is not None:
                # Set the cookie's domain to the response hostname
                # and set its host-only-flag
                self._host_only_cookies.add((hostname, name))
                domain = cookie["domain"] = hostname

            if domain and domain[0] == ".":
                # Remove leading dot
                domain = domain[1:]
                cookie["domain"] = domain

            if hostname and not self._is_domain_match(domain, hostname):
                # Setting cookies for different domains is not allowed
                continue

            path = cookie["path"]
            if not path or path[0] != "/":
                # Set the cookie's path to the response path
                path = response_url.path
                if not path.startswith("/"):
                    path = "/"
                else:
                    # Cut everything from the last slash to the end
                    path = "/" + path[1 : path.rfind("/")]
                cookie["path"] = path
            path = path.rstrip("/")

            if max_age := cookie["max-age"]:
                try:
                    # Max-Age takes precedence over Expires (RFC 6265 §5.3);
                    # clamp to MAX_TIME so timegm() stays representable.
                    delta_seconds = int(max_age)
                    max_age_expiration = min(time.time() + delta_seconds, self.MAX_TIME)
                    self._expire_cookie(max_age_expiration, domain, path, name)
                except ValueError:
                    cookie["max-age"] = ""

            elif expires := cookie["expires"]:
                if expire_time := self._parse_date(expires):
                    self._expire_cookie(expire_time, domain, path, name)
                else:
                    # Unparseable date: drop the attribute, cookie is a
                    # session cookie.
                    cookie["expires"] = ""

            key = (domain, path)
            if self._cookies[key].get(name) != cookie:
                # Don't blow away the cache if the same
                # cookie gets set again
                self._cookies[key][name] = cookie
                self._morsel_cache[key].pop(name, None)

        self._do_expiration()

    def filter_cookies(self, request_url: URL = URL()) -> "BaseCookie[str]":
        """Returns this jar's cookies filtered by their attributes."""
        filtered: Union[SimpleCookie, "BaseCookie[str]"] = (
            SimpleCookie() if self._quote_cookie else BaseCookie()
        )
        if not self._cookies:
            # Skip do_expiration() if there are no cookies.
            return filtered
        self._do_expiration()
        if not self._cookies:
            # Skip rest of function if no non-expired cookies.
            return filtered
        if type(request_url) is not URL:
            warnings.warn(
                "filter_cookies expects yarl.URL instances only,"
                f"and will stop working in 4.x, got {type(request_url)}",
                DeprecationWarning,
                stacklevel=2,
            )
            request_url = URL(request_url)
        hostname = request_url.raw_host or ""

        is_not_secure = request_url.scheme not in ("https", "wss")
        if is_not_secure and self._treat_as_secure_origin:
            request_origin = URL()
            with contextlib.suppress(ValueError):
                request_origin = request_url.origin()
            is_not_secure = request_origin not in self._treat_as_secure_origin

        # Send shared cookie
        for c in self._cookies[("", "")].values():
            filtered[c.key] = c.value

        if is_ip_address(hostname):
            if not self._unsafe:
                # Cookies are never sent to bare IPs unless unsafe=True.
                return filtered
            domains: Iterable[str] = (hostname,)
        else:
            # Get all the subdomains that might match a cookie (e.g. "foo.bar.com", "bar.com", "com")
            domains = itertools.accumulate(
                reversed(hostname.split(".")), _FORMAT_DOMAIN_REVERSED
            )

        # Get all the path prefixes that might match a cookie (e.g. "", "/foo", "/foo/bar")
        paths = itertools.accumulate(request_url.path.split("/"), _FORMAT_PATH)
        # Create every combination of (domain, path) pairs.
        pairs = itertools.product(domains, paths)

        path_len = len(request_url.path)
        # Point 2: https://www.rfc-editor.org/rfc/rfc6265.html#section-5.4
        for p in pairs:
            for name, cookie in self._cookies[p].items():
                domain = cookie["domain"]

                if (domain, name) in self._host_only_cookies and domain != hostname:
                    continue

                # Skip edge case when the cookie has a trailing slash but request doesn't.
                if len(cookie["path"]) > path_len:
                    continue

                if is_not_secure and cookie["secure"]:
                    continue

                # We already built the Morsel so reuse it here
                if name in self._morsel_cache[p]:
                    filtered[name] = self._morsel_cache[p][name]
                    continue

                # It's critical we use the Morsel so the coded_value
                # (based on cookie version) is preserved
                mrsl_val = cast("Morsel[str]", cookie.get(cookie.key, Morsel()))
                mrsl_val.set(cookie.key, cookie.value, cookie.coded_value)
                self._morsel_cache[p][name] = mrsl_val
                filtered[name] = mrsl_val

        return filtered

    @staticmethod
    def _is_domain_match(domain: str, hostname: str) -> bool:
        """Implements domain matching adhering to RFC 6265."""
        if hostname == domain:
            return True

        if not hostname.endswith(domain):
            return False

        non_matching = hostname[: -len(domain)]

        # The matched suffix must align on a label boundary
        # (e.g. "bar.com" matches "foo.bar.com" but not "foobar.com").
        if not non_matching.endswith("."):
            return False

        return not is_ip_address(hostname)

    @classmethod
    def _parse_date(cls, date_str: str) -> Optional[int]:
        """Implements date string parsing adhering to RFC 6265."""
        if not date_str:
            return None

        found_time = False
        found_day = False
        found_month = False
        found_year = False

        hour = minute = second = 0
        day = 0
        month = 0
        year = 0

        # Per RFC 6265 §5.1.1: take the first token that looks like a time,
        # day, month and year, in any order.
        for token_match in cls.DATE_TOKENS_RE.finditer(date_str):

            token = token_match.group("token")

            if not found_time:
                time_match = cls.DATE_HMS_TIME_RE.match(token)
                if time_match:
                    found_time = True
                    hour, minute, second = (int(s) for s in time_match.groups())
                    continue

            if not found_day:
                day_match = cls.DATE_DAY_OF_MONTH_RE.match(token)
                if day_match:
                    found_day = True
                    day = int(day_match.group())
                    continue

            if not found_month:
                month_match = cls.DATE_MONTH_RE.match(token)
                if month_match:
                    found_month = True
                    assert month_match.lastindex is not None
                    # The index of the matched alternative is the month number.
                    month = month_match.lastindex
                    continue

            if not found_year:
                year_match = cls.DATE_YEAR_RE.match(token)
                if year_match:
                    found_year = True
                    year = int(year_match.group())

        # Two-digit years: 70-99 -> 19xx, 00-69 -> 20xx (RFC 6265 §5.1.1).
        if 70 <= year <= 99:
            year += 1900
        elif 0 <= year <= 69:
            year += 2000

        if False in (found_day, found_month, found_year, found_time):
            return None

        if not 1 <= day <= 31:
            return None

        if year < 1601 or hour > 23 or minute > 59 or second > 59:
            return None

        return calendar.timegm((year, month, day, hour, minute, second, -1, -1, -1))
462
463
class DummyCookieJar(AbstractCookieJar):
    """Implements a dummy cookie storage.

    It can be used with the ClientSession when no cookie processing is needed.

    """

    def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
        super().__init__(loop=loop)

    def __iter__(self) -> "Iterator[Morsel[str]]":
        """Return an empty iterator; the jar never stores cookies."""
        # Empty generator: returning before the first yield produces no items.
        return
        yield

    def __len__(self) -> int:
        """The jar is always empty."""
        return 0

    @property
    def quote_cookie(self) -> bool:
        """Match CookieJar's default quoting behavior."""
        return True

    def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
        """No-op: there is nothing to clear."""

    def clear_domain(self, domain: str) -> None:
        """No-op: there is nothing to clear."""

    def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
        """No-op: incoming cookies are discarded."""

    def filter_cookies(self, request_url: URL = URL()) -> "BaseCookie[str]":
        """Always return an empty cookie set.

        The default for *request_url* matches CookieJar.filter_cookies so the
        two implementations are call-compatible.
        """
        return SimpleCookie()