Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/cookiejar.py: 21%
252 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-26 06:16 +0000
1import calendar
2import contextlib
3import datetime
4import itertools
5import os # noqa
6import pathlib
7import pickle
8import re
9import time
10import warnings
11from collections import defaultdict
12from http.cookies import BaseCookie, Morsel, SimpleCookie
13from math import ceil
14from typing import (
15 DefaultDict,
16 Dict,
17 Iterable,
18 Iterator,
19 List,
20 Mapping,
21 Optional,
22 Set,
23 Tuple,
24 Union,
25 cast,
26)
28from yarl import URL
30from .abc import AbstractCookieJar, ClearCookiePredicate
31from .helpers import is_ip_address
32from .typedefs import LooseCookies, PathLike, StrOrURL
34__all__ = ("CookieJar", "DummyCookieJar")
37CookieItem = Union[str, "Morsel[str]"]
class CookieJar(AbstractCookieJar):
    """Implements cookie storage adhering to RFC 6265.

    Cookies are kept in ``self._cookies``, a mapping from a
    ``(domain, path)`` pair to a ``SimpleCookie`` holding every morsel stored
    for that pair.  Two side tables complement it:

    * ``self._host_only_cookies`` -- ``(domain, name)`` pairs for cookies that
      were stored without an explicit domain and therefore must only be sent
      back to the exact host that set them.
    * ``self._expirations`` -- ``(domain, path, name)`` -> absolute expiry
      timestamp (seconds since the epoch) for cookies that carry a usable
      ``max-age`` or ``expires`` attribute.
    """

    # Splits a cookie date string into tokens, skipping delimiter characters.
    DATE_TOKENS_RE = re.compile(
        r"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*"
        r"(?P<token>[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)"
    )

    DATE_HMS_TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})")

    DATE_DAY_OF_MONTH_RE = re.compile(r"(\d{1,2})")

    # One capture group per month, so ``match.lastindex`` is the month number.
    DATE_MONTH_RE = re.compile(
        "(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|(aug)|(sep)|(oct)|(nov)|(dec)",
        re.I,
    )

    DATE_YEAR_RE = re.compile(r"(\d{2,4})")

    # calendar.timegm() fails for timestamps after datetime.datetime.max
    # Minus one as a loss of precision occurs when timestamp() is called.
    MAX_TIME = (
        int(datetime.datetime.max.replace(tzinfo=datetime.timezone.utc).timestamp()) - 1
    )
    try:
        calendar.timegm(time.gmtime(MAX_TIME))
    except (OSError, ValueError):
        # Hit the maximum representable time on Windows
        # https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/localtime-localtime32-localtime64
        # Throws ValueError on PyPy 3.8 and 3.9, OSError elsewhere
        MAX_TIME = calendar.timegm((3000, 12, 31, 23, 59, 59, -1, -1, -1))
    except OverflowError:
        # #4515: datetime.max may not be representable on 32-bit platforms
        MAX_TIME = 2**31 - 1
    # Avoid minuses in the future, 3x faster
    SUB_MAX_TIME = MAX_TIME - 1

    def __init__(
        self,
        *,
        unsafe: bool = False,
        quote_cookie: bool = True,
        treat_as_secure_origin: Union[StrOrURL, List[StrOrURL], None] = None,
    ) -> None:
        """Create an empty jar.

        :param unsafe: when true, cookies are accepted from and sent to
            hosts identified by an IP address.
        :param quote_cookie: when true, ``filter_cookies`` returns a
            ``SimpleCookie``; otherwise a plain ``BaseCookie``.
        :param treat_as_secure_origin: a URL (or string), or a list of them,
            whose origins are treated as secure even when the request scheme
            is not ``https``/``wss``.  Every entry is normalised to its
            origin.
        """
        self._cookies: DefaultDict[Tuple[str, str], SimpleCookie] = defaultdict(
            SimpleCookie
        )
        self._host_only_cookies: Set[Tuple[str, str]] = set()
        self._unsafe = unsafe
        self._quote_cookie = quote_cookie
        if treat_as_secure_origin is None:
            treat_as_secure_origin = []
        elif isinstance(treat_as_secure_origin, URL):
            treat_as_secure_origin = [treat_as_secure_origin.origin()]
        elif isinstance(treat_as_secure_origin, str):
            treat_as_secure_origin = [URL(treat_as_secure_origin).origin()]
        else:
            treat_as_secure_origin = [
                URL(url).origin() if isinstance(url, str) else url.origin()
                for url in treat_as_secure_origin
            ]
        self._treat_as_secure_origin = treat_as_secure_origin
        self._next_expiration: float = ceil(time.time())
        self._expirations: Dict[Tuple[str, str, str], float] = {}

    def save(self, file_path: PathLike) -> None:
        """Pickle the cookie store to *file_path*.

        Only ``self._cookies`` is written; the host-only flags and the
        expiration table are not part of the saved data.
        """
        file_path = pathlib.Path(file_path)
        with file_path.open(mode="wb") as f:
            pickle.dump(self._cookies, f, pickle.HIGHEST_PROTOCOL)

    def load(self, file_path: PathLike) -> None:
        """Replace the cookie store with the object pickled in *file_path*.

        ``self._host_only_cookies`` and ``self._expirations`` are left
        untouched.
        """
        file_path = pathlib.Path(file_path)
        with file_path.open(mode="rb") as f:
            # SECURITY: unpickling can execute arbitrary code.  Only load
            # files that were produced by ``save()`` and come from a trusted
            # location.
            self._cookies = pickle.load(f)

    def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
        """Remove cookies from the jar.

        Without a *predicate* everything is dropped.  With a *predicate*,
        a cookie is removed when its recorded expiration time has passed or
        when ``predicate(morsel)`` is true.
        """
        if predicate is None:
            self._next_expiration = ceil(time.time())
            self._cookies.clear()
            self._host_only_cookies.clear()
            self._expirations.clear()
            return

        # Collect first, delete afterwards: the containers must not change
        # size while they are being iterated.
        to_del = []
        now = time.time()
        for (domain, path), cookie in self._cookies.items():
            for name, morsel in cookie.items():
                key = (domain, path, name)
                if (
                    key in self._expirations and self._expirations[key] <= now
                ) or predicate(morsel):
                    to_del.append(key)

        for domain, path, name in to_del:
            self._host_only_cookies.discard((domain, name))
            self._expirations.pop((domain, path, name), None)
            self._cookies[(domain, path)].pop(name, None)

        self._next_expiration = (
            min(*self._expirations.values(), self.SUB_MAX_TIME) + 1
            if self._expirations
            else self.MAX_TIME
        )

    def clear_domain(self, domain: str) -> None:
        """Remove every cookie whose ``domain`` attribute domain-matches *domain*."""
        self.clear(lambda x: self._is_domain_match(domain, x["domain"]))

    def __iter__(self) -> "Iterator[Morsel[str]]":
        """Yield every stored morsel after dropping expired cookies."""
        self._do_expiration()
        for val in self._cookies.values():
            yield from val.values()

    def __len__(self) -> int:
        """Return number of cookies.

        This function does not iterate self to avoid unnecessary expiration
        checks.
        """
        return sum(len(cookie) for cookie in self._cookies.values())

    def _do_expiration(self) -> None:
        """Drop expired cookies (a ``clear`` whose predicate matches nothing)."""
        self.clear(lambda x: False)

    def _expire_cookie(self, when: float, domain: str, path: str, name: str) -> None:
        """Record that cookie ``(domain, path, name)`` expires at time *when*."""
        self._next_expiration = min(self._next_expiration, when)
        self._expirations[(domain, path, name)] = when

    def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
        """Update cookies.

        *cookies* is a mapping or an iterable of ``(name, value)`` pairs where
        each value is either a ``Morsel`` or something assignable into a
        ``SimpleCookie``.  *response_url* supplies the default domain and path
        and is used to reject cookies set for a foreign domain.
        """
        hostname = response_url.raw_host

        if not self._unsafe and is_ip_address(hostname):
            # Don't accept cookies from IPs
            return

        if isinstance(cookies, Mapping):
            cookies = cookies.items()

        for name, cookie in cookies:
            if not isinstance(cookie, Morsel):
                # Let SimpleCookie build the Morsel for a plain value.
                tmp = SimpleCookie()
                tmp[name] = cookie  # type: ignore[assignment]
                cookie = tmp[name]

            domain = cookie["domain"]

            # ignore domains with trailing dots
            if domain.endswith("."):
                domain = ""
                del cookie["domain"]

            if not domain and hostname is not None:
                # Set the cookie's domain to the response hostname
                # and set its host-only-flag
                self._host_only_cookies.add((hostname, name))
                domain = cookie["domain"] = hostname

            if domain.startswith("."):
                # Remove leading dot
                domain = domain[1:]
                cookie["domain"] = domain

            if hostname and not self._is_domain_match(domain, hostname):
                # Setting cookies for different domains is not allowed
                continue

            path = cookie["path"]
            if not path or not path.startswith("/"):
                # Set the cookie's path to the response path
                path = response_url.path
                if not path.startswith("/"):
                    path = "/"
                else:
                    # Cut everything from the last slash to the end
                    path = "/" + path[1 : path.rfind("/")]
                cookie["path"] = path
            # The storage key uses the path without trailing slashes; the
            # morsel keeps the path assigned above.
            path = path.rstrip("/")

            max_age = cookie["max-age"]
            if max_age:
                # max-age takes precedence: expires is only consulted when
                # max-age is empty.
                try:
                    delta_seconds = int(max_age)
                    max_age_expiration = min(time.time() + delta_seconds, self.MAX_TIME)
                    self._expire_cookie(max_age_expiration, domain, path, name)
                except ValueError:
                    cookie["max-age"] = ""

            else:
                expires = cookie["expires"]
                if expires:
                    expire_time = self._parse_date(expires)
                    if expire_time:
                        self._expire_cookie(expire_time, domain, path, name)
                    else:
                        cookie["expires"] = ""

            self._cookies[(domain, path)][name] = cookie

        self._do_expiration()

    def filter_cookies(self, request_url: URL = URL()) -> "BaseCookie[str]":
        """Returns this jar's cookies filtered by their attributes.

        The result contains the cookies stored under the ``("", "")`` key
        plus every stored cookie whose ``(domain, path)`` key is a suffix of
        the request host / prefix of the request path, subject to the
        host-only, path-length and ``secure`` checks below.

        Buckets are read with ``dict.get`` rather than indexing so that a
        lookup never inserts an empty ``SimpleCookie`` into the underlying
        ``defaultdict``.
        """
        if not isinstance(request_url, URL):
            warnings.warn(
                "The method accepts yarl.URL instances only, got {}".format(
                    type(request_url)
                ),
                DeprecationWarning,
                stacklevel=2,
            )
            request_url = URL(request_url)
        filtered: Union[SimpleCookie, "BaseCookie[str]"] = (
            SimpleCookie() if self._quote_cookie else BaseCookie()
        )
        if not self._cookies:
            # Skip do_expiration() if there are no cookies.
            return filtered
        self._do_expiration()
        if not self._cookies:
            # Skip rest of function if no non-expired cookies.
            return filtered
        hostname = request_url.raw_host or ""

        is_not_secure = request_url.scheme not in ("https", "wss")
        if is_not_secure and self._treat_as_secure_origin:
            request_origin = URL()
            with contextlib.suppress(ValueError):
                request_origin = request_url.origin()
            is_not_secure = request_origin not in self._treat_as_secure_origin

        # Send shared cookie
        shared = self._cookies.get(("", ""))
        if shared is not None:
            for c in shared.values():
                filtered[c.key] = c.value

        if is_ip_address(hostname):
            if not self._unsafe:
                return filtered
            domains: Iterable[str] = (hostname,)
        else:
            # Get all the subdomains that might match a cookie (e.g. "foo.bar.com", "bar.com", "com")
            domains = itertools.accumulate(
                reversed(hostname.split(".")), lambda x, y: f"{y}.{x}"
            )
        # Get all the path prefixes that might match a cookie (e.g. "", "/foo", "/foo/bar")
        paths = itertools.accumulate(
            request_url.path.split("/"), lambda x, y: f"{x}/{y}"
        )
        # Create every combination of (domain, path) pairs.
        pairs = itertools.product(domains, paths)

        path_len = len(request_url.path)
        # Point 2: https://www.rfc-editor.org/rfc/rfc6265.html#section-5.4
        for pair in pairs:
            bucket = self._cookies.get(pair)
            if bucket is None:
                continue
            for cookie in bucket.values():
                name = cookie.key
                domain = cookie["domain"]

                if (domain, name) in self._host_only_cookies:
                    if domain != hostname:
                        continue

                # Skip edge case when the cookie has a trailing slash but request doesn't.
                if len(cookie["path"]) > path_len:
                    continue

                if is_not_secure and cookie["secure"]:
                    continue

                # It's critical we use the Morsel so the coded_value
                # (based on cookie version) is preserved
                mrsl_val = cast("Morsel[str]", cookie.get(cookie.key, Morsel()))
                mrsl_val.set(cookie.key, cookie.value, cookie.coded_value)
                filtered[name] = mrsl_val

        return filtered

    @staticmethod
    def _is_domain_match(domain: str, hostname: str) -> bool:
        """Implements domain matching adhering to RFC 6265.

        Returns True when *hostname* equals *domain*, or when *hostname* ends
        with *domain*, the part before it ends with a dot, and *hostname* is
        not an IP address.
        """
        if hostname == domain:
            return True

        if not hostname.endswith(domain):
            return False

        non_matching = hostname[: -len(domain)]

        if not non_matching.endswith("."):
            return False

        return not is_ip_address(hostname)

    @classmethod
    def _parse_date(cls, date_str: str) -> Optional[int]:
        """Implements date string parsing adhering to RFC 6265.

        Returns the parsed moment as seconds since the epoch (UTC), or None
        when *date_str* is empty, lacks a time, day, month or year token, or
        contains an out-of-range field.
        """
        if not date_str:
            return None

        found_time = False
        found_day = False
        found_month = False
        found_year = False

        hour = minute = second = 0
        day = 0
        month = 0
        year = 0

        # Each token fills the first still-missing field it matches, tried in
        # the order time, day-of-month, month, year.
        for token_match in cls.DATE_TOKENS_RE.finditer(date_str):
            token = token_match.group("token")

            if not found_time:
                time_match = cls.DATE_HMS_TIME_RE.match(token)
                if time_match:
                    found_time = True
                    hour, minute, second = (int(s) for s in time_match.groups())
                    continue

            if not found_day:
                day_match = cls.DATE_DAY_OF_MONTH_RE.match(token)
                if day_match:
                    found_day = True
                    day = int(day_match.group())
                    continue

            if not found_month:
                month_match = cls.DATE_MONTH_RE.match(token)
                if month_match:
                    found_month = True
                    assert month_match.lastindex is not None
                    month = month_match.lastindex
                    continue

            if not found_year:
                year_match = cls.DATE_YEAR_RE.match(token)
                if year_match:
                    found_year = True
                    year = int(year_match.group())

        # Two-digit years: 70-99 map to 1970-1999, 0-69 map to 2000-2069.
        if 70 <= year <= 99:
            year += 1900
        elif 0 <= year <= 69:
            year += 2000

        if False in (found_day, found_month, found_year, found_time):
            return None

        if not 1 <= day <= 31:
            return None

        if year < 1601 or hour > 23 or minute > 59 or second > 59:
            return None

        return calendar.timegm((year, month, day, hour, minute, second, -1, -1, -1))
class DummyCookieJar(AbstractCookieJar):
    """A cookie jar that never stores anything.

    Intended for use with the ClientSession when no cookie processing is
    needed: every mutating call is a no-op and every query reports an empty
    jar.
    """

    def __iter__(self) -> "Iterator[Morsel[str]]":
        # Still a generator function, but one that produces no items.
        yield from ()

    def __len__(self) -> int:
        # The jar is always empty.
        return 0

    def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
        """Do nothing; there is never anything to remove."""
        return None

    def clear_domain(self, domain: str) -> None:
        """Do nothing; there is never anything to remove."""
        return None

    def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
        """Discard the supplied cookies."""
        return None

    def filter_cookies(self, request_url: URL) -> "BaseCookie[str]":
        """Return a fresh, empty ``SimpleCookie`` for any request URL."""
        empty = SimpleCookie()
        return empty