1import asyncio
2import calendar
3import contextlib
4import datetime
5import heapq
6import itertools
7import os # noqa
8import pathlib
9import pickle
10import re
11import time
12import warnings
13from collections import defaultdict
14from collections.abc import Mapping
15from http.cookies import BaseCookie, Morsel, SimpleCookie
16from typing import (
17 DefaultDict,
18 Dict,
19 Iterable,
20 Iterator,
21 List,
22 Optional,
23 Set,
24 Tuple,
25 Union,
26)
27
28from yarl import URL
29
30from ._cookie_helpers import preserve_morsel_with_coded_value
31from .abc import AbstractCookieJar, ClearCookiePredicate
32from .helpers import is_ip_address
33from .typedefs import LooseCookies, PathLike, StrOrURL
34
__all__ = ("CookieJar", "DummyCookieJar")


# Either a raw cookie string or an already-parsed Morsel.
CookieItem = Union[str, "Morsel[str]"]

# We cache these string methods here as their use is in performance critical code.
_FORMAT_PATH = "{}/{}".format
# Reversed-order join used with itertools.accumulate to build domain suffixes:
# _FORMAT_DOMAIN_REVERSED("com", "example") -> "example.com".
_FORMAT_DOMAIN_REVERSED = "{1}.{0}".format

# The minimum number of scheduled cookie expirations before we start cleaning up
# the expiration heap. This is a performance optimization to avoid cleaning up the
# heap too often when there are only a few scheduled expirations.
_MIN_SCHEDULED_COOKIE_EXPIRATION = 100
# Shared instance used only for its value_encode() helper in
# CookieJar._build_morsel (no per-call state is kept on it).
_SIMPLE_COOKIE = SimpleCookie()
49
50
class CookieJar(AbstractCookieJar):
    """Implements cookie storage adhering to RFC 6265."""

    # Matches the separators RFC 6265 allows between date tokens and captures
    # each run of token characters (digits, letters, ':' and high bytes).
    DATE_TOKENS_RE = re.compile(
        r"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*"
        r"(?P<token>[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)"
    )

    # hh:mm:ss component of an "expires" date.
    DATE_HMS_TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})")

    # 1- or 2-digit day of month.
    DATE_DAY_OF_MONTH_RE = re.compile(r"(\d{1,2})")

    # Month abbreviation; the index of the matched group (1-12) is the month
    # number -- relied upon by _parse_date via lastindex.
    DATE_MONTH_RE = re.compile(
        "(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|(aug)|(sep)|(oct)|(nov)|(dec)",
        re.I,
    )

    # 2- to 4-digit year (2-digit years are normalized in _parse_date).
    DATE_YEAR_RE = re.compile(r"(\d{2,4})")

    # calendar.timegm() fails for timestamps after datetime.datetime.max
    # Minus one as a loss of precision occurs when timestamp() is called.
    MAX_TIME = (
        int(datetime.datetime.max.replace(tzinfo=datetime.timezone.utc).timestamp()) - 1
    )
    try:
        calendar.timegm(time.gmtime(MAX_TIME))
    except (OSError, ValueError):
        # Hit the maximum representable time on Windows
        # https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/localtime-localtime32-localtime64
        # Throws ValueError on PyPy 3.9, OSError elsewhere
        MAX_TIME = calendar.timegm((3000, 12, 31, 23, 59, 59, -1, -1, -1))
    except OverflowError:
        # #4515: datetime.max may not be representable on 32-bit platforms
        MAX_TIME = 2**31 - 1
    # Avoid minuses in the future, 3x faster
    SUB_MAX_TIME = MAX_TIME - 1

    def __init__(
        self,
        *,
        unsafe: bool = False,
        quote_cookie: bool = True,
        treat_as_secure_origin: Union[StrOrURL, List[StrOrURL], None] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        """Initialize the cookie jar.

        unsafe -- also accept cookies from IP-address hosts (normally refused).
        quote_cookie -- quote cookie values when building outgoing morsels
            (see _build_morsel).
        treat_as_secure_origin -- origin(s) for which "secure" cookies are sent
            even when the request scheme is not https/wss.
        loop -- optional event loop forwarded to AbstractCookieJar.
        """
        super().__init__(loop=loop)
        # Cookies stored per (domain, path) key.
        self._cookies: DefaultDict[Tuple[str, str], SimpleCookie] = defaultdict(
            SimpleCookie
        )
        # Pre-built outgoing Morsels keyed like _cookies; entries are dropped
        # in update_cookies whenever the underlying cookie changes.
        self._morsel_cache: DefaultDict[Tuple[str, str], Dict[str, Morsel[str]]] = (
            defaultdict(dict)
        )
        # (domain, name) pairs set without a Domain attribute; such cookies are
        # only sent back to the exact host that set them (host-only flag).
        self._host_only_cookies: Set[Tuple[str, str]] = set()
        self._unsafe = unsafe
        self._quote_cookie = quote_cookie
        # Normalize treat_as_secure_origin into a list of URL origins.
        if treat_as_secure_origin is None:
            treat_as_secure_origin = []
        elif isinstance(treat_as_secure_origin, URL):
            treat_as_secure_origin = [treat_as_secure_origin.origin()]
        elif isinstance(treat_as_secure_origin, str):
            treat_as_secure_origin = [URL(treat_as_secure_origin).origin()]
        else:
            treat_as_secure_origin = [
                URL(url).origin() if isinstance(url, str) else url.origin()
                for url in treat_as_secure_origin
            ]
        self._treat_as_secure_origin = treat_as_secure_origin
        # Min-heap of (when, (domain, path, name)); may hold stale entries for
        # re-scheduled cookies -- _do_expiration filters them out lazily.
        self._expire_heap: List[Tuple[float, Tuple[str, str, str]]] = []
        # Authoritative expiration time per (domain, path, name) key.
        self._expirations: Dict[Tuple[str, str, str], float] = {}

    @property
    def quote_cookie(self) -> bool:
        """Whether outgoing cookie values are quoted (set at construction)."""
        return self._quote_cookie

    def save(self, file_path: PathLike) -> None:
        """Pickle this jar's cookies to *file_path*."""
        file_path = pathlib.Path(file_path)
        with file_path.open(mode="wb") as f:
            pickle.dump(self._cookies, f, pickle.HIGHEST_PROTOCOL)

    def load(self, file_path: PathLike) -> None:
        """Load cookies previously written by save().

        NOTE(review): pickle.load() can execute arbitrary code if the file is
        attacker-controlled -- only load files this application wrote itself.
        """
        file_path = pathlib.Path(file_path)
        with file_path.open(mode="rb") as f:
            self._cookies = pickle.load(f)

    def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
        """Remove all cookies, or only those matching *predicate*.

        Without a predicate the jar is wiped entirely.  With a predicate, a
        cookie is removed if it has already expired or the predicate returns
        true for its morsel.
        """
        if predicate is None:
            # Fast path: drop everything, including all bookkeeping state.
            self._expire_heap.clear()
            self._cookies.clear()
            self._morsel_cache.clear()
            self._host_only_cookies.clear()
            self._expirations.clear()
            return

        now = time.time()
        to_del = [
            key
            for (domain, path), cookie in self._cookies.items()
            for name, morsel in cookie.items()
            if (
                (key := (domain, path, name)) in self._expirations
                and self._expirations[key] <= now
            )
            or predicate(morsel)
        ]
        if to_del:
            self._delete_cookies(to_del)

    def clear_domain(self, domain: str) -> None:
        """Remove every cookie whose domain matches *domain* (RFC 6265 match)."""
        self.clear(lambda x: self._is_domain_match(domain, x["domain"]))

    def __iter__(self) -> "Iterator[Morsel[str]]":
        """Yield all non-expired cookie morsels."""
        self._do_expiration()
        for val in self._cookies.values():
            yield from val.values()

    def __len__(self) -> int:
        """Return number of cookies.

        This function does not iterate self to avoid unnecessary expiration
        checks.
        """
        return sum(len(cookie.values()) for cookie in self._cookies.values())

    def _do_expiration(self) -> None:
        """Remove expired cookies."""
        if not (expire_heap_len := len(self._expire_heap)):
            return

        # If the expiration heap grows larger than the number expirations
        # times two, we clean it up to avoid keeping expired entries in
        # the heap and consuming memory. We guard this with a minimum
        # threshold to avoid cleaning up the heap too often when there are
        # only a few scheduled expirations.
        if (
            expire_heap_len > _MIN_SCHEDULED_COOKIE_EXPIRATION
            and expire_heap_len > len(self._expirations) * 2
        ):
            # Remove any expired entries from the expiration heap
            # that do not match the expiration time in the expirations
            # as it means the cookie has been re-added to the heap
            # with a different expiration time.
            self._expire_heap = [
                entry
                for entry in self._expire_heap
                if self._expirations.get(entry[1]) == entry[0]
            ]
            heapq.heapify(self._expire_heap)

        now = time.time()
        to_del: List[Tuple[str, str, str]] = []
        # Find any expired cookies and add them to the to-delete list
        while self._expire_heap:
            when, cookie_key = self._expire_heap[0]
            if when > now:
                break
            heapq.heappop(self._expire_heap)
            # Check if the cookie hasn't been re-added to the heap
            # with a different expiration time as it will be removed
            # later when it reaches the top of the heap and its
            # expiration time is met.
            if self._expirations.get(cookie_key) == when:
                to_del.append(cookie_key)

        if to_del:
            self._delete_cookies(to_del)

    def _delete_cookies(self, to_del: List[Tuple[str, str, str]]) -> None:
        """Drop the given (domain, path, name) cookies and their bookkeeping."""
        for domain, path, name in to_del:
            self._host_only_cookies.discard((domain, name))
            self._cookies[(domain, path)].pop(name, None)
            self._morsel_cache[(domain, path)].pop(name, None)
            self._expirations.pop((domain, path, name), None)

    def _expire_cookie(self, when: float, domain: str, path: str, name: str) -> None:
        """Schedule the cookie to expire at *when* (epoch seconds)."""
        cookie_key = (domain, path, name)
        if self._expirations.get(cookie_key) == when:
            # Avoid adding duplicates to the heap
            return
        heapq.heappush(self._expire_heap, (when, cookie_key))
        self._expirations[cookie_key] = when

    def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
        """Update cookies.

        Stores the received cookies, applying RFC 6265 domain/path rules
        relative to *response_url* and scheduling expirations.
        """
        hostname = response_url.raw_host

        if not self._unsafe and is_ip_address(hostname):
            # Don't accept cookies from IPs
            return

        if isinstance(cookies, Mapping):
            cookies = cookies.items()

        for name, cookie in cookies:
            if not isinstance(cookie, Morsel):
                # Parse a plain string value through SimpleCookie to get a Morsel.
                tmp = SimpleCookie()
                tmp[name] = cookie  # type: ignore[assignment]
                cookie = tmp[name]

            domain = cookie["domain"]

            # ignore domains with trailing dots
            if domain and domain[-1] == ".":
                domain = ""
                del cookie["domain"]

            if not domain and hostname is not None:
                # Set the cookie's domain to the response hostname
                # and set its host-only-flag
                self._host_only_cookies.add((hostname, name))
                domain = cookie["domain"] = hostname

            if domain and domain[0] == ".":
                # Remove leading dot
                domain = domain[1:]
                cookie["domain"] = domain

            if hostname and not self._is_domain_match(domain, hostname):
                # Setting cookies for different domains is not allowed
                continue

            path = cookie["path"]
            if not path or path[0] != "/":
                # Set the cookie's path to the response path
                path = response_url.path
                if not path.startswith("/"):
                    path = "/"
                else:
                    # Cut everything from the last slash to the end
                    path = "/" + path[1 : path.rfind("/")]
                cookie["path"] = path
            path = path.rstrip("/")

            # Max-Age takes precedence over Expires.
            if max_age := cookie["max-age"]:
                try:
                    delta_seconds = int(max_age)
                    max_age_expiration = min(time.time() + delta_seconds, self.MAX_TIME)
                    self._expire_cookie(max_age_expiration, domain, path, name)
                except ValueError:
                    cookie["max-age"] = ""

            elif expires := cookie["expires"]:
                if expire_time := self._parse_date(expires):
                    self._expire_cookie(expire_time, domain, path, name)
                else:
                    cookie["expires"] = ""

            key = (domain, path)
            if self._cookies[key].get(name) != cookie:
                # Don't blow away the cache if the same
                # cookie gets set again
                self._cookies[key][name] = cookie
                self._morsel_cache[key].pop(name, None)

        self._do_expiration()

    def filter_cookies(self, request_url: URL = URL()) -> "BaseCookie[str]":
        """Returns this jar's cookies filtered by their attributes."""
        # We always use BaseCookie now since all
        # cookies set on filtered are fully constructed
        # Morsels, not just names and values.
        filtered: BaseCookie[str] = BaseCookie()
        if not self._cookies:
            # Skip do_expiration() if there are no cookies.
            return filtered
        self._do_expiration()
        if not self._cookies:
            # Skip rest of function if no non-expired cookies.
            return filtered
        if type(request_url) is not URL:
            warnings.warn(
                "filter_cookies expects yarl.URL instances only,"
                f"and will stop working in 4.x, got {type(request_url)}",
                DeprecationWarning,
                stacklevel=2,
            )
            request_url = URL(request_url)
        hostname = request_url.raw_host or ""

        is_not_secure = request_url.scheme not in ("https", "wss")
        if is_not_secure and self._treat_as_secure_origin:
            request_origin = URL()
            with contextlib.suppress(ValueError):
                request_origin = request_url.origin()
            is_not_secure = request_origin not in self._treat_as_secure_origin

        # Send shared cookie
        key = ("", "")
        for c in self._cookies[key].values():
            # Check cache first
            if c.key in self._morsel_cache[key]:
                filtered[c.key] = self._morsel_cache[key][c.key]
                continue

            # Build and cache the morsel
            mrsl_val = self._build_morsel(c)
            self._morsel_cache[key][c.key] = mrsl_val
            filtered[c.key] = mrsl_val

        if is_ip_address(hostname):
            if not self._unsafe:
                return filtered
            domains: Iterable[str] = (hostname,)
        else:
            # Get all the subdomains that might match a cookie (e.g. "foo.bar.com", "bar.com", "com")
            domains = itertools.accumulate(
                reversed(hostname.split(".")), _FORMAT_DOMAIN_REVERSED
            )

        # Get all the path prefixes that might match a cookie (e.g. "", "/foo", "/foo/bar")
        paths = itertools.accumulate(request_url.path.split("/"), _FORMAT_PATH)
        # Create every combination of (domain, path) pairs.
        pairs = itertools.product(domains, paths)

        path_len = len(request_url.path)
        # Point 2: https://www.rfc-editor.org/rfc/rfc6265.html#section-5.4
        for p in pairs:
            if p not in self._cookies:
                continue
            for name, cookie in self._cookies[p].items():
                domain = cookie["domain"]

                # Host-only cookies go only to the exact host that set them.
                if (domain, name) in self._host_only_cookies and domain != hostname:
                    continue

                # Skip edge case when the cookie has a trailing slash but request doesn't.
                if len(cookie["path"]) > path_len:
                    continue

                if is_not_secure and cookie["secure"]:
                    continue

                # We already built the Morsel so reuse it here
                if name in self._morsel_cache[p]:
                    filtered[name] = self._morsel_cache[p][name]
                    continue

                # Build and cache the morsel
                mrsl_val = self._build_morsel(cookie)
                self._morsel_cache[p][name] = mrsl_val
                filtered[name] = mrsl_val

        return filtered

    def _build_morsel(self, cookie: Morsel[str]) -> Morsel[str]:
        """Build a morsel for sending, respecting quote_cookie setting."""
        # Already-quoted coded values are preserved verbatim.
        if self._quote_cookie and cookie.coded_value and cookie.coded_value[0] == '"':
            return preserve_morsel_with_coded_value(cookie)
        morsel: Morsel[str] = Morsel()
        if self._quote_cookie:
            value, coded_value = _SIMPLE_COOKIE.value_encode(cookie.value)
        else:
            coded_value = value = cookie.value
        # We use __setstate__ instead of the public set() API because it allows us to
        # bypass validation and set already validated state. This is more stable than
        # setting protected attributes directly and unlikely to change since it would
        # break pickling.
        morsel.__setstate__({"key": cookie.key, "value": value, "coded_value": coded_value})  # type: ignore[attr-defined]
        return morsel

    @staticmethod
    def _is_domain_match(domain: str, hostname: str) -> bool:
        """Implements domain matching adhering to RFC 6265."""
        if hostname == domain:
            return True

        if not hostname.endswith(domain):
            return False

        non_matching = hostname[: -len(domain)]

        # The suffix must match on a label boundary ("oo.com" must not
        # match "foo.com").
        if not non_matching.endswith("."):
            return False

        # An IP address can never domain-match a suffix.
        return not is_ip_address(hostname)

    @classmethod
    def _parse_date(cls, date_str: str) -> Optional[int]:
        """Implements date string parsing adhering to RFC 6265.

        Returns the epoch timestamp, or None if the string is not a valid
        cookie date.
        """
        if not date_str:
            return None

        found_time = False
        found_day = False
        found_month = False
        found_year = False

        hour = minute = second = 0
        day = 0
        month = 0
        year = 0

        # Scan tokens left to right; the first token matching each component
        # wins, per the RFC 6265 section 5.1.1 algorithm.
        for token_match in cls.DATE_TOKENS_RE.finditer(date_str):

            token = token_match.group("token")

            if not found_time:
                time_match = cls.DATE_HMS_TIME_RE.match(token)
                if time_match:
                    found_time = True
                    hour, minute, second = (int(s) for s in time_match.groups())
                    continue

            if not found_day:
                day_match = cls.DATE_DAY_OF_MONTH_RE.match(token)
                if day_match:
                    found_day = True
                    day = int(day_match.group())
                    continue

            if not found_month:
                month_match = cls.DATE_MONTH_RE.match(token)
                if month_match:
                    found_month = True
                    assert month_match.lastindex is not None
                    # Group index (1-12) doubles as the month number.
                    month = month_match.lastindex
                    continue

            if not found_year:
                year_match = cls.DATE_YEAR_RE.match(token)
                if year_match:
                    found_year = True
                    year = int(year_match.group())

        # Normalize two-digit years as mandated by RFC 6265 section 5.1.1.
        if 70 <= year <= 99:
            year += 1900
        elif 0 <= year <= 69:
            year += 2000

        if False in (found_day, found_month, found_year, found_time):
            return None

        if not 1 <= day <= 31:
            return None

        if year < 1601 or hour > 23 or minute > 59 or second > 59:
            return None

        return calendar.timegm((year, month, day, hour, minute, second, -1, -1, -1))
489
490
class DummyCookieJar(AbstractCookieJar):
    """A no-op cookie jar.

    Use it with ClientSession when cookies should never be stored,
    filtered, or sent.

    """

    def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
        super().__init__(loop=loop)

    def __iter__(self) -> "Iterator[Morsel[str]]":
        # The jar is permanently empty, so produce an empty iterator.
        return iter(())

    def __len__(self) -> int:
        # Nothing is ever stored.
        return 0

    @property
    def quote_cookie(self) -> bool:
        # Mirrors CookieJar's default quoting behavior.
        return True

    def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
        # Nothing to remove.
        return None

    def clear_domain(self, domain: str) -> None:
        # Nothing to remove.
        return None

    def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
        # Incoming cookies are deliberately discarded.
        return None

    def filter_cookies(self, request_url: URL) -> "BaseCookie[str]":
        # Always report an empty cookie set.
        return SimpleCookie()