Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/cookiejar.py: 19%

255 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:52 +0000

1import contextlib 

2import datetime 

3import os # noqa 

4import pathlib 

5import pickle 

6import re 

7import warnings 

8from collections import defaultdict 

9from http.cookies import BaseCookie, Morsel, SimpleCookie 

10from typing import ( # noqa 

11 DefaultDict, 

12 Dict, 

13 Iterable, 

14 Iterator, 

15 List, 

16 Mapping, 

17 Optional, 

18 Set, 

19 Tuple, 

20 Union, 

21 cast, 

22) 

23 

24from yarl import URL 

25 

26from .abc import AbstractCookieJar, ClearCookiePredicate 

27from .helpers import is_ip_address, next_whole_second 

28from .typedefs import LooseCookies, PathLike, StrOrURL 

29 

# Names exported by ``from <module> import *``.
__all__ = (
    "CookieJar",
    "DummyCookieJar",
)


# Type alias: a cookie given either as a plain string or as a ``Morsel``.
CookieItem = Union[str, "Morsel[str]"]

34 

35 

class CookieJar(AbstractCookieJar):
    """Implements cookie storage adhering to RFC 6265.

    Cookies are stored in ``SimpleCookie`` containers keyed by
    ``(domain, path)``.  Expiration times are tracked separately in a dict
    keyed by ``(domain, path, name)``, and host-only cookies are remembered
    as ``(domain, name)`` pairs.
    """

    # Splits a cookie-date string into tokens: skip a run of delimiter
    # characters, then capture a run of non-delimiter characters.
    DATE_TOKENS_RE = re.compile(
        r"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*"
        r"(?P<token>[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)"
    )

    DATE_HMS_TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})")

    DATE_DAY_OF_MONTH_RE = re.compile(r"(\d{1,2})")

    # One capture group per month, so ``match.lastindex`` is the month number.
    DATE_MONTH_RE = re.compile(
        "(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|" "(aug)|(sep)|(oct)|(nov)|(dec)",
        re.I,
    )

    DATE_YEAR_RE = re.compile(r"(\d{2,4})")

    MAX_TIME = datetime.datetime.max.replace(tzinfo=datetime.timezone.utc)

    MAX_32BIT_TIME = datetime.datetime.fromtimestamp(2**31 - 1, datetime.timezone.utc)

    def __init__(
        self,
        *,
        unsafe: bool = False,
        quote_cookie: bool = True,
        treat_as_secure_origin: Union[StrOrURL, List[StrOrURL], None] = None
    ) -> None:
        """Create an empty cookie jar.

        :param unsafe: when false, cookies received from or destined to an
            IP-address host are ignored.
        :param quote_cookie: when true, :meth:`filter_cookies` returns a
            ``SimpleCookie``; otherwise a ``BaseCookie``.
        :param treat_as_secure_origin: a URL (or string), or a list of them,
            whose origins are allowed to receive ``secure`` cookies even when
            the request scheme is not ``https``/``wss``.
        """
        self._cookies: DefaultDict[Tuple[str, str], SimpleCookie[str]] = defaultdict(
            SimpleCookie
        )
        self._host_only_cookies: Set[Tuple[str, str]] = set()
        self._unsafe = unsafe
        self._quote_cookie = quote_cookie

        # Normalize every accepted input shape to a list of URL origins.
        if treat_as_secure_origin is None:
            treat_as_secure_origin = []
        elif isinstance(treat_as_secure_origin, URL):
            treat_as_secure_origin = [treat_as_secure_origin.origin()]
        elif isinstance(treat_as_secure_origin, str):
            treat_as_secure_origin = [URL(treat_as_secure_origin).origin()]
        else:
            treat_as_secure_origin = [
                URL(url).origin() if isinstance(url, str) else url.origin()
                for url in treat_as_secure_origin
            ]
        self._treat_as_secure_origin = treat_as_secure_origin

        self._next_expiration = next_whole_second()
        self._expirations: Dict[Tuple[str, str, str], datetime.datetime] = {}

        # #4515: datetime.max may not be representable on 32-bit platforms
        self._max_time = self.MAX_TIME
        try:
            self._max_time.timestamp()
        except OverflowError:
            self._max_time = self.MAX_32BIT_TIME

    def save(self, file_path: PathLike) -> None:
        """Pickle the stored cookies to *file_path*.

        Only the cookie mapping is written; the expiration table and the
        host-only set are not part of the file.
        """
        file_path = pathlib.Path(file_path)
        with file_path.open(mode="wb") as f:
            pickle.dump(self._cookies, f, pickle.HIGHEST_PROTOCOL)

    def load(self, file_path: PathLike) -> None:
        """Replace the stored cookies with the pickle found at *file_path*.

        Only ``self._cookies`` is replaced; the expiration table and the
        host-only set are left as they were.
        """
        file_path = pathlib.Path(file_path)
        with file_path.open(mode="rb") as f:
            # SECURITY: unpickling can execute arbitrary code.  Only load
            # files that were produced by ``save()`` and come from a trusted
            # location.
            self._cookies = pickle.load(f)

    def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
        """Remove cookies from the jar.

        With no *predicate* everything is dropped.  Otherwise a cookie is
        removed when its recorded expiration time has passed or when
        ``predicate(morsel)`` is true.  Afterwards the next expiration
        checkpoint is recomputed from the remaining expiration times.
        """
        if predicate is None:
            self._next_expiration = next_whole_second()
            self._cookies.clear()
            self._host_only_cookies.clear()
            self._expirations.clear()
            return

        # Collect first, delete afterwards: the containers must not be
        # mutated while they are being iterated.
        to_del = []
        now = datetime.datetime.now(datetime.timezone.utc)
        for (domain, path), cookie in self._cookies.items():
            for name, morsel in cookie.items():
                key = (domain, path, name)
                expires_at = self._expirations.get(key)
                # The predicate is consulted only for non-expired cookies.
                if (expires_at is not None and expires_at <= now) or predicate(
                    morsel
                ):
                    to_del.append(key)

        for domain, path, name in to_del:
            self._host_only_cookies.discard((domain, name))
            self._expirations.pop((domain, path, name), None)
            self._cookies[(domain, path)].pop(name, None)

        next_expiration = min(self._expirations.values(), default=self._max_time)
        try:
            # Round up to the next whole second.
            self._next_expiration = next_expiration.replace(
                microsecond=0
            ) + datetime.timedelta(seconds=1)
        except OverflowError:
            self._next_expiration = self._max_time

    def clear_domain(self, domain: str) -> None:
        """Remove every cookie whose domain attribute matches *domain*."""
        self.clear(lambda x: self._is_domain_match(domain, x["domain"]))

    def __iter__(self) -> "Iterator[Morsel[str]]":
        """Iterate over all stored morsels after dropping expired ones."""
        self._do_expiration()
        for val in self._cookies.values():
            yield from val.values()

    def __len__(self) -> int:
        """Return the number of non-expired cookies in the jar."""
        return sum(1 for _ in self)

    def _do_expiration(self) -> None:
        """Drop cookies whose expiration time has passed."""
        # A predicate that never matches leaves only the expiry check active.
        self.clear(lambda x: False)

    def _expire_cookie(
        self, when: datetime.datetime, domain: str, path: str, name: str
    ) -> None:
        """Record that cookie ``(domain, path, name)`` expires at *when*."""
        self._next_expiration = min(self._next_expiration, when)
        self._expirations[(domain, path, name)] = when

    def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
        """Update cookies.

        :param cookies: a mapping or an iterable of ``(name, cookie)`` pairs;
            values that are not ``Morsel`` instances are converted to one.
        :param response_url: URL of the response the cookies came from; its
            host and path provide defaults for cookies lacking a usable
            ``domain`` or ``path`` attribute.
        """
        hostname = response_url.raw_host

        if not self._unsafe and is_ip_address(hostname):
            # Don't accept cookies from IPs
            return

        if isinstance(cookies, Mapping):
            cookies = cookies.items()

        for name, cookie in cookies:
            if not isinstance(cookie, Morsel):
                tmp: SimpleCookie[str] = SimpleCookie()
                tmp[name] = cookie  # type: ignore[assignment]
                cookie = tmp[name]

            domain = cookie["domain"]

            # ignore domains with trailing dots
            if domain.endswith("."):
                domain = ""
                del cookie["domain"]

            if not domain and hostname is not None:
                # Set the cookie's domain to the response hostname
                # and set its host-only-flag
                self._host_only_cookies.add((hostname, name))
                domain = cookie["domain"] = hostname

            if domain.startswith("."):
                # Remove leading dot
                domain = domain[1:]
                cookie["domain"] = domain

            if hostname and not self._is_domain_match(domain, hostname):
                # Setting cookies for different domains is not allowed
                continue

            path = cookie["path"]
            if not path or not path.startswith("/"):
                # Set the cookie's path to the response path
                path = response_url.path
                if not path.startswith("/"):
                    path = "/"
                else:
                    # Cut everything from the last slash to the end
                    path = "/" + path[1 : path.rfind("/")]
                cookie["path"] = path

            max_age = cookie["max-age"]
            if max_age:
                try:
                    delta_seconds = int(max_age)
                    try:
                        max_age_expiration = datetime.datetime.now(
                            datetime.timezone.utc
                        ) + datetime.timedelta(seconds=delta_seconds)
                    except OverflowError:
                        max_age_expiration = self._max_time
                    self._expire_cookie(max_age_expiration, domain, path, name)
                except ValueError:
                    # Not an integer: blank the attribute out.
                    cookie["max-age"] = ""

            else:
                expires = cookie["expires"]
                if expires:
                    expire_time = self._parse_date(expires)
                    if expire_time:
                        self._expire_cookie(expire_time, domain, path, name)
                    else:
                        # Unparseable date: blank the attribute out.
                        cookie["expires"] = ""

            self._cookies[(domain, path)][name] = cookie

        self._do_expiration()

    def filter_cookies(
        self, request_url: URL = URL()
    ) -> Union["BaseCookie[str]", "SimpleCookie[str]"]:
        """Returns this jar's cookies filtered by their attributes.

        A cookie is included when its domain, path and ``secure`` attributes
        are compatible with *request_url*.  Cookies without a domain are
        always included.  Passing something other than a ``yarl.URL`` emits a
        ``DeprecationWarning`` and the value is converted.
        """
        self._do_expiration()
        if not isinstance(request_url, URL):
            warnings.warn(
                "The method accepts yarl.URL instances only, got {}".format(
                    type(request_url)
                ),
                DeprecationWarning,
            )
            request_url = URL(request_url)
        filtered: Union["SimpleCookie[str]", "BaseCookie[str]"] = (
            SimpleCookie() if self._quote_cookie else BaseCookie()
        )
        hostname = request_url.raw_host or ""
        request_origin = URL()
        with contextlib.suppress(ValueError):
            request_origin = request_url.origin()

        is_not_secure = (
            request_url.scheme not in ("https", "wss")
            and request_origin not in self._treat_as_secure_origin
        )
        # Loop-invariant: depends only on the request host and jar settings.
        skip_for_ip_host = not self._unsafe and is_ip_address(hostname)

        for cookie in self:
            name = cookie.key
            domain = cookie["domain"]

            # Send shared cookies
            if not domain:
                filtered[name] = cookie.value
                continue

            if skip_for_ip_host:
                continue

            if (domain, name) in self._host_only_cookies:
                if domain != hostname:
                    continue
            elif not self._is_domain_match(domain, hostname):
                continue

            if not self._is_path_match(request_url.path, cookie["path"]):
                continue

            if is_not_secure and cookie["secure"]:
                continue

            # It's critical we use the Morsel so the coded_value
            # (based on cookie version) is preserved
            mrsl_val = cast("Morsel[str]", cookie.get(cookie.key, Morsel()))
            mrsl_val.set(cookie.key, cookie.value, cookie.coded_value)
            filtered[name] = mrsl_val

        return filtered

    @staticmethod
    def _is_domain_match(domain: str, hostname: str) -> bool:
        """Implements domain matching adhering to RFC 6265.

        True when *hostname* equals *domain*, or when *hostname* ends with
        ``"." + domain`` and is not an IP address.
        """
        if hostname == domain:
            return True

        if not hostname.endswith(domain):
            return False

        non_matching = hostname[: -len(domain)]

        # The suffix must start at a label boundary.
        if not non_matching.endswith("."):
            return False

        return not is_ip_address(hostname)

    @staticmethod
    def _is_path_match(req_path: str, cookie_path: str) -> bool:
        """Implements path matching adhering to RFC 6265.

        True when *req_path* equals *cookie_path*, or when *cookie_path* is
        a prefix of *req_path* that ends at a ``/`` boundary.
        """
        if not req_path.startswith("/"):
            req_path = "/"

        if req_path == cookie_path:
            return True

        if not req_path.startswith(cookie_path):
            return False

        if cookie_path.endswith("/"):
            return True

        non_matching = req_path[len(cookie_path) :]

        return non_matching.startswith("/")

    @classmethod
    def _parse_date(cls, date_str: str) -> Optional[datetime.datetime]:
        """Implements date string parsing adhering to RFC 6265.

        Returns a timezone-aware UTC ``datetime``, or ``None`` when
        *date_str* is empty, lacks any of time/day/month/year, has an
        out-of-range component, or names a day that does not exist in the
        given month (for example 31 February).
        """
        if not date_str:
            return None

        found_time = False
        found_day = False
        found_month = False
        found_year = False

        hour = minute = second = 0
        day = 0
        month = 0
        year = 0

        # Each token fills the first still-missing field it matches, tried in
        # the order time, day-of-month, month, year.
        for token_match in cls.DATE_TOKENS_RE.finditer(date_str):
            token = token_match.group("token")

            if not found_time:
                time_match = cls.DATE_HMS_TIME_RE.match(token)
                if time_match:
                    found_time = True
                    hour, minute, second = (int(s) for s in time_match.groups())
                    continue

            if not found_day:
                day_match = cls.DATE_DAY_OF_MONTH_RE.match(token)
                if day_match:
                    found_day = True
                    day = int(day_match.group())
                    continue

            if not found_month:
                month_match = cls.DATE_MONTH_RE.match(token)
                if month_match:
                    found_month = True
                    assert month_match.lastindex is not None
                    # Index of the matching group == month number (1..12).
                    month = month_match.lastindex
                    continue

            if not found_year:
                year_match = cls.DATE_YEAR_RE.match(token)
                if year_match:
                    found_year = True
                    year = int(year_match.group())

        # Two-digit years: 70-99 -> 19xx, 00-69 -> 20xx.
        if 70 <= year <= 99:
            year += 1900
        elif 0 <= year <= 69:
            year += 2000

        if False in (found_day, found_month, found_year, found_time):
            return None

        if not 1 <= day <= 31:
            return None

        if year < 1601 or hour > 23 or minute > 59 or second > 59:
            return None

        try:
            return datetime.datetime(
                year, month, day, hour, minute, second, tzinfo=datetime.timezone.utc
            )
        except ValueError:
            # The day passed the 1..31 range check but does not exist in this
            # month (e.g. "31 Feb", or "29 Feb" in a non-leap year).  RFC 6265
            # section 5.1.1 treats that as a parse failure, not an error.
            return None

390 

391 

class DummyCookieJar(AbstractCookieJar):
    """A cookie jar that never stores anything.

    It can be used with the ClientSession when no cookie processing is needed.
    Every mutating method is a no-op and every query reports an empty jar.
    """

    def __iter__(self) -> "Iterator[Morsel[str]]":
        """Return an iterator that yields nothing."""
        yield from ()

    def __len__(self) -> int:
        """Always report zero cookies."""
        return 0

    def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
        """Do nothing; there is never anything to clear."""

    def clear_domain(self, domain: str) -> None:
        """Do nothing; there is never anything to clear."""

    def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
        """Discard the given cookies."""

    def filter_cookies(self, request_url: URL) -> "BaseCookie[str]":
        """Return a new, empty ``SimpleCookie``."""
        empty: "BaseCookie[str]" = SimpleCookie()
        return empty