Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/starlette/datastructures.py: 31%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

453 statements  

1from __future__ import annotations 

2 

3import re 

4from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, MutableMapping, Sequence, ValuesView 

5from shlex import shlex 

6from typing import Any, BinaryIO, NamedTuple, TypeVar, cast 

7from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit 

8 

9from starlette.concurrency import run_in_threadpool 

10from starlette.types import Scope 

11 

12 

class Address(NamedTuple):
    """A `(host, port)` pair."""

    host: str  # host part of the address
    port: int  # port part of the address

16 

17 

# Type variable used for the key type of the generic multidict classes.
_KeyType = TypeVar("_KeyType")
# Mapping keys are invariant but their values are covariant since
# you can only read them
# that is, you can't do `Mapping[str, Animal]()["fido"] = Dog()`
_CovariantValueType = TypeVar("_CovariantValueType", covariant=True)

# Rejects Host header chars (/, ?, #, @, ...) that would let urlsplit produce a path differing from scope["path"].
# Accepted forms (case-insensitive): a run of letters, digits, dots and hyphens, or a
# bracketed literal made of hex digits, dots and colons; either may be followed by `:<digits>`.
_HOST_RE = re.compile(r"^([a-z0-9.-]+|\[[a-f0-9]*:[a-f0-9.:]+\])(?::[0-9]+)?$", re.IGNORECASE)

26 

27 

class URL:
    """
    A URL string whose components are parsed on first access.

    Build an instance from exactly one of:

    * `url` - a URL string, stored as given.
    * `scope` - a mapping providing `path` and `headers`, and optionally
      `scheme`, `server` and `query_string`, from which the URL is assembled.
    * `**components` - keyword arguments accepted by `replace()`, applied to
      an empty URL.
    """

    def __init__(
        self,
        url: str = "",
        scope: Scope | None = None,
        **components: Any,
    ) -> None:
        if scope is not None:
            assert not url, 'Cannot set both "url" and "scope".'
            assert not components, 'Cannot set both "scope" and "**components".'
            scheme = scope.get("scheme", "http")
            server = scope.get("server", None)
            path = scope["path"]
            query_string = scope.get("query_string", b"")

            # Use the first `host` header, if there is one.
            host_header = None
            for key, value in scope["headers"]:
                if key == b"host":
                    host_header = value.decode("latin-1")
                    break

            if host_header is not None and _HOST_RE.fullmatch(host_header):
                # Only a Host header that passes validation is trusted.
                url = f"{scheme}://{host_header}{path}"
            elif server is None:
                # Nothing to build a network location from: path-only URL.
                url = path
            else:
                host, port = server
                default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}[scheme]
                if port == default_port:
                    # The scheme's default port is left implicit.
                    url = f"{scheme}://{host}{path}"
                else:
                    url = f"{scheme}://{host}:{port}{path}"

            if query_string:
                url += "?" + query_string.decode()
        elif components:
            assert not url, 'Cannot set both "url" and "**components".'
            url = URL("").replace(**components).components.geturl()

        self._url = url

    @property
    def components(self) -> SplitResult:
        """The `urlsplit()` result for this URL, computed once and cached."""
        if not hasattr(self, "_components"):
            self._components = urlsplit(self._url)
        return self._components

    @property
    def scheme(self) -> str:
        return self.components.scheme

    @property
    def netloc(self) -> str:
        return self.components.netloc

    @property
    def path(self) -> str:
        return self.components.path

    @property
    def query(self) -> str:
        return self.components.query

    @property
    def fragment(self) -> str:
        return self.components.fragment

    @property
    def username(self) -> None | str:
        return self.components.username

    @property
    def password(self) -> None | str:
        return self.components.password

    @property
    def hostname(self) -> None | str:
        return self.components.hostname

    @property
    def port(self) -> int | None:
        return self.components.port

    @property
    def is_secure(self) -> bool:
        """True when the scheme is `https` or `wss`."""
        return self.scheme in ("https", "wss")

    def replace(self, **kwargs: Any) -> URL:
        """
        Return a new URL with the given components replaced.

        `username`, `password`, `hostname` and `port` are combined into a new
        `netloc`; every other keyword is passed to `SplitResult._replace()`
        (for example `scheme`, `netloc`, `path`, `query`, `fragment`).
        """
        if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs:
            hostname = kwargs.pop("hostname", None)
            port = kwargs.pop("port", self.port)
            username = kwargs.pop("username", self.username)
            password = kwargs.pop("password", self.password)

            if hostname is None:
                # Recover the current host from the netloc: drop a
                # "user:pass@" prefix, then a ":port" suffix.
                _, _, hostname = self.netloc.rpartition("@")

                # A trailing "]" means a bracketed host without a port, so
                # there is nothing to strip. `endswith` (rather than
                # `hostname[-1]`) also copes with a URL that has no netloc,
                # where `hostname` is the empty string.
                if not hostname.endswith("]"):
                    hostname = hostname.rsplit(":", 1)[0]

            netloc = hostname
            if port is not None:
                netloc += f":{port}"
            if username is not None:
                userpass = username
                if password is not None:
                    userpass += f":{password}"
                netloc = f"{userpass}@{netloc}"

            kwargs["netloc"] = netloc

        components = self.components._replace(**kwargs)
        return self.__class__(components.geturl())

    def include_query_params(self, **kwargs: Any) -> URL:
        """Return a new URL whose query has the given params added, replacing existing params of the same name."""
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        params.update({str(key): str(value) for key, value in kwargs.items()})
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def replace_query_params(self, **kwargs: Any) -> URL:
        """Return a new URL whose query consists only of the given params."""
        query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
        return self.replace(query=query)

    def remove_query_params(self, keys: str | Sequence[str]) -> URL:
        """Return a new URL with the named query param(s) removed."""
        if isinstance(keys, str):
            keys = [keys]
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        for key in keys:
            params.pop(key, None)
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def __eq__(self, other: Any) -> bool:
        # Compared by string form, so a URL also equals a plain `str`.
        return str(self) == str(other)

    def __str__(self) -> str:
        return self._url

    def __repr__(self) -> str:
        url = str(self)
        if self.password:
            # Never show the real password in the repr.
            url = str(self.replace(password="********"))
        return f"{self.__class__.__name__}({repr(url)})"

173 

174 

class URLPath(str):
    """
    A path string that can additionally carry a protocol and/or a host.

    Routing returns instances of this type from `url_path_for` lookups.
    """

    def __new__(cls, path: str, protocol: str = "", host: str = "") -> URLPath:
        assert protocol in ("http", "websocket", "")
        return super().__new__(cls, path)

    def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
        self.host = host
        self.protocol = protocol

    def make_absolute_url(self, base_url: str | URL) -> URL:
        """Combine this path with `base_url` into an absolute `URL`."""
        base = URL(base_url) if isinstance(base_url, str) else base_url

        if not self.protocol:
            scheme = base.scheme
        else:
            # (insecure, secure) scheme per protocol, indexed by `is_secure`.
            variants = {
                "http": ("http", "https"),
                "websocket": ("ws", "wss"),
            }
            scheme = variants[self.protocol][base.is_secure]

        return URL(
            scheme=scheme,
            netloc=self.host or base.netloc,
            path=base.path.rstrip("/") + str(self),
        )

203 

204 

class Secret:
    """
    Wraps a string value so that it is not shown in tracebacks and the like.

    Convert the instance to `str` only where the real value is needed.
    """

    def __init__(self, value: str):
        self._value = value

    def __bool__(self) -> bool:
        return True if self._value else False

    def __str__(self) -> str:
        return self._value

    def __repr__(self) -> str:
        # The wrapped value is masked on purpose.
        return "{}('**********')".format(self.__class__.__name__)

223 

224 

class CommaSeparatedStrings(Sequence[str]):
    """
    A sequence of strings, built either by splitting a comma-delimited string
    (with shell-style quoting) or by copying an existing sequence.
    """

    def __init__(self, value: str | Sequence[str]):
        if not isinstance(value, str):
            self._items = list(value)
            return

        # Tokenise on commas only; surrounding blanks are stripped per token.
        lexer = shlex(value, posix=True)
        lexer.whitespace = ","
        lexer.whitespace_split = True
        self._items = [token.strip() for token in lexer]

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, index: int | slice) -> Any:
        return self._items[index]

    def __iter__(self) -> Iterator[str]:
        return iter(self._items)

    def __repr__(self) -> str:
        entries = list(self)
        return f"{self.__class__.__name__}({entries!r})"

    def __str__(self) -> str:
        return ", ".join(map(repr, self))

251 

252 

class ImmutableMultiDict(Mapping[_KeyType, _CovariantValueType]):
    """
    A read-only mapping that can hold several values per key.

    Every `(key, value)` pair is kept, in order, in `_list`. `_dict` maps each
    key to the value of its last pair and backs the regular `Mapping` API.
    """

    _dict: dict[_KeyType, _CovariantValueType]

    def __init__(
        self,
        *args: ImmutableMultiDict[_KeyType, _CovariantValueType]
        | Mapping[_KeyType, _CovariantValueType]
        | Iterable[tuple[_KeyType, _CovariantValueType]],
        **kwargs: Any,
    ) -> None:
        """
        Accept at most one positional source (another multidict, a mapping, or
        an iterable of pairs) plus keyword items, which are appended after it.
        """
        assert len(args) < 2, "Too many arguments."

        value: Any = args[0] if args else []
        if kwargs:
            # Normalise both sources to pair lists and concatenate them.
            value = ImmutableMultiDict(value).multi_items() + ImmutableMultiDict(kwargs).multi_items()

        if not value:
            _items: list[tuple[Any, Any]] = []
        elif hasattr(value, "multi_items"):
            value = cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value)
            _items = list(value.multi_items())
        elif hasattr(value, "items"):
            value = cast(Mapping[_KeyType, _CovariantValueType], value)
            _items = list(value.items())
        else:
            value = cast("list[tuple[Any, Any]]", value)
            _items = list(value)

        # Later pairs win in the plain-dict view.
        self._dict = {k: v for k, v in _items}
        self._list = _items

    def getlist(self, key: Any) -> list[_CovariantValueType]:
        """Return every value stored under `key`, in insertion order."""
        return [item_value for item_key, item_value in self._list if item_key == key]

    def keys(self) -> KeysView[_KeyType]:
        return self._dict.keys()

    def values(self) -> ValuesView[_CovariantValueType]:
        return self._dict.values()

    def items(self) -> ItemsView[_KeyType, _CovariantValueType]:
        return self._dict.items()

    def multi_items(self) -> list[tuple[_KeyType, _CovariantValueType]]:
        """Return a copy of the full list of `(key, value)` pairs, duplicates included."""
        return list(self._list)

    def __getitem__(self, key: _KeyType) -> _CovariantValueType:
        return self._dict[key]

    def __contains__(self, key: Any) -> bool:
        return key in self._dict

    def __iter__(self) -> Iterator[_KeyType]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._dict)

    def __eq__(self, other: Any) -> bool:
        """
        Two multidicts are equal when `other` is an instance of this object's
        class and both hold the same pairs, irrespective of order.
        """
        if not isinstance(other, self.__class__):
            return False
        try:
            return sorted(self._list) == sorted(other._list)
        except TypeError:
            # Some pairs cannot be ordered against each other (e.g. two values
            # of unrelated types under the same key). Compare as multisets
            # using equality only.
            if len(self._list) != len(other._list):
                return False
            unmatched = list(other._list)
            for pair in self._list:
                try:
                    unmatched.remove(pair)
                except ValueError:
                    return False
            return True

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        items = self.multi_items()
        return f"{class_name}({items!r})"

320 

321 

class MultiDict(ImmutableMultiDict[Any, Any]):
    """
    The mutable counterpart of `ImmutableMultiDict`.

    Every mutator keeps the ordered pair list (`_list`) and the last-value
    dict (`_dict`) in step with each other.
    """

    def __setitem__(self, key: Any, value: Any) -> None:
        self.setlist(key, [value])

    def __delitem__(self, key: Any) -> None:
        self._list = [(name, val) for name, val in self._list if name != key]
        del self._dict[key]

    def pop(self, key: Any, default: Any = None) -> Any:
        """Remove every pair for `key`; return its dict value, or `default` if absent."""
        self._list = [(name, val) for name, val in self._list if name != key]
        return self._dict.pop(key, default)

    def popitem(self) -> tuple[Any, Any]:
        """Remove one key (as chosen by `dict.popitem`) with all of its pairs."""
        removed_key, removed_value = self._dict.popitem()
        self._list = [(name, val) for name, val in self._list if name != removed_key]
        return removed_key, removed_value

    def poplist(self, key: Any) -> list[Any]:
        """Remove `key` and return all the values it held."""
        collected = self.getlist(key)
        self.pop(key)
        return collected

    def clear(self) -> None:
        self._dict.clear()
        self._list.clear()

    def setdefault(self, key: Any, default: Any = None) -> Any:
        """Insert `key` with `default` if it is missing; return the value for `key`."""
        if key not in self:
            self._dict[key] = default
            self._list.append((key, default))

        return self[key]

    def setlist(self, key: Any, values: list[Any]) -> None:
        """Replace all values of `key` with `values`; an empty list removes the key."""
        if not values:
            self.pop(key, None)
            return

        kept = [(name, val) for (name, val) in self._list if name != key]
        self._list = kept + [(key, item) for item in values]
        self._dict[key] = values[-1]

    def append(self, key: Any, value: Any) -> None:
        """Add one more `(key, value)` pair, keeping any existing ones."""
        self._list.append((key, value))
        self._dict[key] = value

    def update(
        self,
        *args: MultiDict | Mapping[Any, Any] | list[tuple[Any, Any]],
        **kwargs: Any,
    ) -> None:
        """Merge in new items; keys present in the new data lose their old pairs."""
        incoming = MultiDict(*args, **kwargs)
        kept = [(name, val) for (name, val) in self._list if name not in incoming.keys()]
        self._list = kept + incoming.multi_items()
        self._dict.update(incoming)

376 

377 

class QueryParams(ImmutableMultiDict[str, str]):
    """
    An immutable multidict whose keys and values are all strings.

    Besides the inputs of the base class, a `str` or `bytes` query string is
    accepted and parsed (blank values are kept).
    """

    def __init__(
        self,
        *args: ImmutableMultiDict[Any, Any] | Mapping[Any, Any] | list[tuple[Any, Any]] | str | bytes,
        **kwargs: Any,
    ) -> None:
        assert len(args) < 2, "Too many arguments."

        source = args[0] if args else []

        if isinstance(source, (str, bytes)):
            # Bytes are decoded as latin-1 before parsing.
            text = source if isinstance(source, str) else source.decode("latin-1")
            super().__init__(parse_qsl(text, keep_blank_values=True), **kwargs)
        else:
            super().__init__(*args, **kwargs)  # type: ignore[arg-type]

        # Coerce every key and value to `str`.
        self._list = [(str(name), str(val)) for name, val in self._list]
        self._dict = {str(name): str(val) for name, val in self._dict.items()}

    def __str__(self) -> str:
        return urlencode(self._list)

    def __repr__(self) -> str:
        encoded = str(self)
        return f"{self.__class__.__name__}({encoded!r})"

408 

409 

class UploadFile:
    """
    A file uploaded as part of the request data.

    File operations run inline while the underlying file reports that it has
    not rolled over (its `_rolled` attribute is false); otherwise they are
    dispatched through `run_in_threadpool`.
    """

    def __init__(
        self,
        file: BinaryIO,
        *,
        size: int | None = None,
        filename: str | None = None,
        headers: Headers | None = None,
    ) -> None:
        self.filename = filename
        self.file = file
        self.size = size
        self.headers = headers or Headers()

        # Remember the file's `_max_size` (as SpooledTemporaryFile has) so later checks are cheaper.
        # A value of 0 means "no limit", as in SpooledTemporaryFile's __init__.
        self._max_mem_size = getattr(self.file, "_max_size", 0)

    @property
    def content_type(self) -> str | None:
        return self.headers.get("content-type", None)

    @property
    def _in_memory(self) -> bool:
        # Looks at SpooledTemporaryFile._rolled; a file without that attribute is treated as rolled.
        return not getattr(self.file, "_rolled", True)

    def _will_roll(self, size_to_add: int) -> bool:
        # Anything that is no longer in memory counts as rolled already.
        if not self._in_memory:
            return True

        # Compare the projected size against SpooledTemporaryFile._max_size.
        projected = self.file.tell() + size_to_add
        if not self._max_mem_size:
            return False
        return bool(projected > self._max_mem_size)

    async def write(self, data: bytes) -> None:
        added = len(data)
        if self.size is not None:
            self.size += added

        if not self._will_roll(added):
            self.file.write(data)
            return
        await run_in_threadpool(self.file.write, data)

    async def read(self, size: int = -1) -> bytes:
        if not self._in_memory:
            return await run_in_threadpool(self.file.read, size)
        return self.file.read(size)

    async def seek(self, offset: int) -> None:
        if not self._in_memory:
            await run_in_threadpool(self.file.seek, offset)
            return
        self.file.seek(offset)

    async def close(self) -> None:
        if not self._in_memory:
            await run_in_threadpool(self.file.close)
            return
        self.file.close()

    def __repr__(self) -> str:
        name = self.__class__.__name__
        return f"{name}(filename={self.filename!r}, size={self.size!r}, headers={self.headers!r})"

480 

481 

class FormData(ImmutableMultiDict[str, UploadFile | str]):
    """
    An immutable multidict whose values are text fields or uploaded files.
    """

    def __init__(
        self,
        *args: FormData | Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]],
        **kwargs: str | UploadFile,
    ) -> None:
        super().__init__(*args, **kwargs)

    async def close(self) -> None:
        """Close every `UploadFile` held in this form, in stored order."""
        for _name, entry in self.multi_items():
            if not isinstance(entry, UploadFile):
                continue
            await entry.close()

498 

499 

class Headers(Mapping[str, str]):
    """
    An immutable, case-insensitive multidict.

    Headers are stored as a list of `(name, value)` byte pairs. Names given
    as `str` are lower-cased, and text is encoded and decoded as latin-1.
    """

    def __init__(
        self,
        headers: Mapping[str, str] | None = None,
        raw: list[tuple[bytes, bytes]] | None = None,
        scope: MutableMapping[str, Any] | None = None,
    ) -> None:
        """
        Build from at most one source:

        * `headers` - a `str` to `str` mapping; names are lower-cased.
        * `raw` - a list of byte pairs, used as-is (not copied).
        * `scope` - a mapping whose `"headers"` entry is converted to a list
          that is stored back into the scope and shared with this object.
        """
        self._list: list[tuple[bytes, bytes]] = []
        if headers is not None:
            assert raw is None, 'Cannot set both "headers" and "raw".'
            assert scope is None, 'Cannot set both "headers" and "scope".'
            self._list = [(key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in headers.items()]
        elif raw is not None:
            assert scope is None, 'Cannot set both "raw" and "scope".'
            self._list = raw
        elif scope is not None:
            # scope["headers"] isn't necessarily a list
            # it might be a tuple or other iterable
            self._list = scope["headers"] = list(scope["headers"])

    @property
    def raw(self) -> list[tuple[bytes, bytes]]:
        """A copy of the underlying list of byte pairs."""
        return list(self._list)

    def keys(self) -> list[str]:  # type: ignore[override]
        return [key.decode("latin-1") for key, value in self._list]

    def values(self) -> list[str]:  # type: ignore[override]
        return [value.decode("latin-1") for key, value in self._list]

    def items(self) -> list[tuple[str, str]]:  # type: ignore[override]
        return [(key.decode("latin-1"), value.decode("latin-1")) for key, value in self._list]

    def getlist(self, key: str) -> list[str]:
        """Return every value stored for header `key` (case-insensitive), in order."""
        get_header_key = key.lower().encode("latin-1")
        return [item_value.decode("latin-1") for item_key, item_value in self._list if item_key == get_header_key]

    def mutablecopy(self) -> MutableHeaders:
        """Return a `MutableHeaders` built from a copy of this object's list."""
        return MutableHeaders(raw=self._list[:])

    def __getitem__(self, key: str) -> str:
        # The first matching header wins.
        get_header_key = key.lower().encode("latin-1")
        for header_key, header_value in self._list:
            if header_key == get_header_key:
                return header_value.decode("latin-1")
        raise KeyError(key)

    def __contains__(self, key: Any) -> bool:
        # Header names are strings; anything else (None, int, bytes, ...)
        # cannot be a member, and must not fail on `.lower()` / `.encode()`.
        if not isinstance(key, str):
            return False
        get_header_key = key.lower().encode("latin-1")
        return any(header_key == get_header_key for header_key, _ in self._list)

    def __iter__(self) -> Iterator[Any]:
        return iter(self.keys())

    def __len__(self) -> int:
        # Counts entries, so repeated header names count more than once.
        return len(self._list)

    def __eq__(self, other: Any) -> bool:
        # Order-insensitive comparison of the raw pairs.
        if not isinstance(other, Headers):
            return False
        return sorted(self._list) == sorted(other._list)

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        as_dict = dict(self.items())
        # A dict can only represent the headers when no name is repeated.
        if len(as_dict) == len(self):
            return f"{class_name}({as_dict!r})"
        return f"{class_name}(raw={self.raw!r})"

575 

576 

class MutableHeaders(Headers):
    """
    A `Headers` variant that can be modified in place.

    All mutators edit the underlying list object itself, which `raw` exposes
    directly.
    """

    def __setitem__(self, key: str, value: str) -> None:
        """
        Assign `value` to header `key`, dropping any repeated entries.
        The first existing entry keeps its position.
        """
        name = key.lower().encode("latin-1")
        encoded = value.encode("latin-1")

        positions = [pos for pos, (entry_name, _) in enumerate(self._list) if entry_name == name]

        if not positions:
            self._list.append((name, encoded))
            return

        first, *extras = positions
        # Delete from the back so earlier indexes stay valid.
        for pos in reversed(extras):
            del self._list[pos]
        self._list[first] = (name, encoded)

    def __delitem__(self, key: str) -> None:
        """
        Delete every entry for header `key`.
        """
        name = key.lower().encode("latin-1")

        positions = [pos for pos, (entry_name, _) in enumerate(self._list) if entry_name == name]

        # Delete from the back so earlier indexes stay valid.
        for pos in reversed(positions):
            del self._list[pos]

    def __ior__(self, other: Mapping[str, str]) -> MutableHeaders:
        if not isinstance(other, Mapping):
            raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
        self.update(other)
        return self

    def __or__(self, other: Mapping[str, str]) -> MutableHeaders:
        if not isinstance(other, Mapping):
            raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
        merged = self.mutablecopy()
        merged.update(other)
        return merged

    @property
    def raw(self) -> list[tuple[bytes, bytes]]:
        # The live list, not a copy.
        return self._list

    def setdefault(self, key: str, value: str) -> str:
        """
        Add header `key` with `value` unless it is already present.
        Returns the resulting header value.
        """
        name = key.lower().encode("latin-1")
        encoded = value.encode("latin-1")

        for entry_name, entry_value in self._list:
            if entry_name == name:
                return entry_value.decode("latin-1")

        self._list.append((name, encoded))
        return value

    def update(self, other: Mapping[str, str]) -> None:
        for name, text in other.items():
            self[name] = text

    def append(self, key: str, value: str) -> None:
        """
        Add a header entry without touching existing entries of the same name.
        """
        entry = (key.lower().encode("latin-1"), value.encode("latin-1"))
        self._list.append(entry)

    def add_vary_header(self, vary: str) -> None:
        """Append `vary` to the `vary` header, or set the header if it is absent."""
        current = self.get("vary")
        self["vary"] = vary if current is None else ", ".join([current, vary])

662 

663 

class State:
    """
    An object that can be used to store arbitrary state.

    Used for `request.state` and `app.state`.

    Attribute access and item access both read and write the same
    underlying dict.
    """

    _state: dict[str, Any]

    def __init__(self, state: dict[str, Any] | None = None):
        if state is None:
            state = {}
        # Bypass our own `__setattr__`, which would write into `_state` itself.
        super().__setattr__("_state", state)

    def __setattr__(self, key: Any, value: Any) -> None:
        self._state[key] = value

    def __getattr__(self, key: Any) -> Any:
        # `__getattr__` only runs after normal lookup has failed. Getting here
        # for `_state` means the instance was created without `__init__`
        # (as `copy` and `pickle` do); looking it up again through
        # `self._state` would recurse without end.
        if key == "_state":
            message = "'{}' object has no attribute '{}'"
            raise AttributeError(message.format(self.__class__.__name__, key))
        try:
            return self._state[key]
        except KeyError:
            message = "'{}' object has no attribute '{}'"
            raise AttributeError(message.format(self.__class__.__name__, key))

    def __delattr__(self, key: Any) -> None:
        del self._state[key]

    def __getitem__(self, key: str) -> Any:
        return self._state[key]

    def __setitem__(self, key: str, value: Any) -> None:
        self._state[key] = value

    def __delitem__(self, key: str) -> None:
        del self._state[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._state)

    def __len__(self) -> int:
        return len(self._state)