Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/starlette/datastructures.py: 31%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

441 statements  

1from __future__ import annotations 

2 

3from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, MutableMapping, Sequence, ValuesView 

4from shlex import shlex 

5from typing import ( 

6 Any, 

7 BinaryIO, 

8 NamedTuple, 

9 TypeVar, 

10 cast, 

11) 

12from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit 

13 

14from starlette.concurrency import run_in_threadpool 

15from starlette.types import Scope 

16 

17 

# Functional NamedTuple form: identical fields, ordering and repr as the
# class-statement form.
Address = NamedTuple("Address", [("host", str), ("port", int)])
Address.__doc__ = """A named ``(host, port)`` pair."""

21 

22 

# Key type parameter for the generic multidict containers. Mapping keys are
# invariant (they are both passed in and returned), so no variance is declared.
_KeyType = TypeVar("_KeyType")
# Value type parameter. Values of a read-only Mapping are covariant because
# they can only be read, never assigned: you cannot do
# `Mapping[str, Animal]()["fido"] = Dog()`, so a Mapping[str, Dog] can be
# used wherever a Mapping[str, Animal] is expected.
_CovariantValueType = TypeVar("_CovariantValueType", covariant=True)

28 

29 

class URL:
    """
    A URL string whose components are parsed lazily with ``urlsplit``.

    Build one from a plain URL string, from an ASGI connection ``scope``, or
    from keyword ``**components`` (any arguments accepted by :meth:`replace`).
    """

    def __init__(
        self,
        url: str = "",
        scope: Scope | None = None,
        **components: Any,
    ) -> None:
        """
        :param url: A URL string. Mutually exclusive with ``scope`` and
            ``components``.
        :param scope: An ASGI scope. The URL is rebuilt from its ``scheme``
            (default ``"http"``), its ``host`` header (or ``server`` when there
            is no host header), ``path`` and ``query_string``.
        :param components: Keyword components applied via :meth:`replace` to an
            empty URL, e.g. ``scheme=``, ``netloc=``, ``path=``.
        """
        if scope is not None:
            assert not url, 'Cannot set both "url" and "scope".'
            assert not components, 'Cannot set both "scope" and "**components".'
            scheme = scope.get("scheme", "http")
            server = scope.get("server", None)
            path = scope["path"]
            query_string = scope.get("query_string", b"")

            host_header = None
            for key, value in scope["headers"]:
                if key == b"host":
                    host_header = value.decode("latin-1")
                    break

            if host_header is not None:
                url = f"{scheme}://{host_header}{path}"
            elif server is None:
                # No host information at all: produce a relative URL.
                url = path
            else:
                host, port = server
                # Schemes without a known default port always keep the port
                # explicit, instead of failing with a KeyError.
                default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}.get(scheme)
                if port == default_port:
                    url = f"{scheme}://{host}{path}"
                else:
                    url = f"{scheme}://{host}:{port}{path}"

            if query_string:
                url += "?" + query_string.decode()
        elif components:
            assert not url, 'Cannot set both "url" and "**components".'
            url = URL("").replace(**components).components.geturl()

        self._url = url

    @property
    def components(self) -> SplitResult:
        """The parsed ``SplitResult``, computed on first access and cached."""
        if not hasattr(self, "_components"):
            self._components = urlsplit(self._url)
        return self._components

    @property
    def scheme(self) -> str:
        return self.components.scheme

    @property
    def netloc(self) -> str:
        return self.components.netloc

    @property
    def path(self) -> str:
        return self.components.path

    @property
    def query(self) -> str:
        return self.components.query

    @property
    def fragment(self) -> str:
        return self.components.fragment

    @property
    def username(self) -> None | str:
        return self.components.username

    @property
    def password(self) -> None | str:
        return self.components.password

    @property
    def hostname(self) -> None | str:
        return self.components.hostname

    @property
    def port(self) -> int | None:
        return self.components.port

    @property
    def is_secure(self) -> bool:
        """True for the ``https`` and ``wss`` schemes."""
        return self.scheme in ("https", "wss")

    def replace(self, **kwargs: Any) -> URL:
        """
        Return a new URL with some components replaced.

        Accepts ``SplitResult`` field names (``scheme``, ``netloc``, ``path``,
        ``query``, ``fragment``). It also accepts ``username``, ``password``,
        ``hostname`` and ``port``, which are combined into a new ``netloc``.
        Any of those four that is not given keeps its current value.
        """
        if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs:
            hostname = kwargs.pop("hostname", None)
            port = kwargs.pop("port", self.port)
            username = kwargs.pop("username", self.username)
            password = kwargs.pop("password", self.password)

            if hostname is None:
                netloc = self.netloc
                _, _, hostname = netloc.rpartition("@")

                # Strip any existing ":port", except from a bare bracketed IPv6
                # literal such as "[::1]". ``endswith`` also copes with an empty
                # netloc, where ``hostname[-1]`` would raise IndexError.
                if not hostname.endswith("]"):
                    hostname = hostname.rsplit(":", 1)[0]

            netloc = hostname
            if port is not None:
                netloc += f":{port}"
            if username is not None:
                userpass = username
                if password is not None:
                    userpass += f":{password}"
                netloc = f"{userpass}@{netloc}"

            kwargs["netloc"] = netloc

        components = self.components._replace(**kwargs)
        return self.__class__(components.geturl())

    def include_query_params(self, **kwargs: Any) -> URL:
        """
        Return a copy with ``kwargs`` merged into the query string.

        Each key in ``kwargs`` replaces all existing values for that key. Other
        parameters, including blank values, are kept. Keys and values are
        converted with ``str``.
        """
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        params.update({str(key): str(value) for key, value in kwargs.items()})
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def replace_query_params(self, **kwargs: Any) -> URL:
        """Return a copy whose query string consists only of ``kwargs``."""
        query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
        return self.replace(query=query)

    def remove_query_params(self, keys: str | Sequence[str]) -> URL:
        """Return a copy without the query parameter(s) named by ``keys``."""
        if isinstance(keys, str):
            keys = [keys]
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        for key in keys:
            params.pop(key, None)
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def __eq__(self, other: Any) -> bool:
        # Compared by string form, so a URL equals the equivalent plain str.
        return str(self) == str(other)

    def __str__(self) -> str:
        return self._url

    def __repr__(self) -> str:
        url = str(self)
        if self.password:
            # Never leak credentials through repr().
            url = str(self.replace(password="********"))
        return f"{self.__class__.__name__}({repr(url)})"

175 

176 

class URLPath(str):
    """
    A URL path string that may also carry an associated protocol and/or host.

    The routing layer returns these from `url_path_for` matches. ``protocol``
    is one of ``"http"``, ``"websocket"`` or ``""`` (unspecified).
    """

    def __new__(cls, path: str, protocol: str = "", host: str = "") -> URLPath:
        assert protocol in ("http", "websocket", "")
        return str.__new__(cls, path)

    def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
        self.host = host
        self.protocol = protocol

    def make_absolute_url(self, base_url: str | URL) -> URL:
        """
        Join this path onto ``base_url`` and return the absolute URL.

        With a protocol set, the scheme is derived from it plus the base URL's
        security (http/https or ws/wss). Without one, the base scheme is used.
        ``self.host`` overrides the base netloc when non-empty.
        """
        base = URL(base_url) if isinstance(base_url, str) else base_url

        if not self.protocol:
            scheme = base.scheme
        else:
            scheme_for = {
                ("http", False): "http",
                ("http", True): "https",
                ("websocket", False): "ws",
                ("websocket", True): "wss",
            }
            scheme = scheme_for[(self.protocol, base.is_secure)]

        netloc = self.host or base.netloc
        joined_path = base.path.rstrip("/") + str(self)
        return URL(scheme=scheme, netloc=netloc, path=joined_path)

205 

206 

class Secret:
    """
    Holds a string value that should not be revealed in tracebacks etc.

    ``repr()`` always shows a fixed mask. Call ``str()`` on the instance at
    the point the real value is actually needed.
    """

    def __init__(self, value: str):
        self._value = value

    def __repr__(self) -> str:
        masked = "('**********')"
        return type(self).__name__ + masked

    def __str__(self) -> str:
        return self._value

    def __bool__(self) -> bool:
        # Truthiness mirrors the wrapped value, e.g. an empty secret is falsy.
        return bool(self._value)

225 

226 

class CommaSeparatedStrings(Sequence[str]):
    """
    A read-only sequence of strings, usually parsed from a comma-separated string.

    String input is tokenised with ``shlex`` in POSIX mode, with ``,`` as the
    only separator, so quoted items may contain commas. Each item is then
    stripped of surrounding whitespace. Any other sequence is copied into a
    list unchanged.
    """

    def __init__(self, value: str | Sequence[str]):
        if not isinstance(value, str):
            self._items = list(value)
            return

        lexer = shlex(value, posix=True)
        lexer.whitespace = ","
        lexer.whitespace_split = True
        self._items = list(map(str.strip, lexer))

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, index: int | slice) -> Any:
        return self._items[index]

    def __iter__(self) -> Iterator[str]:
        return iter(self._items)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({list(self)!r})"

    def __str__(self) -> str:
        return ", ".join(map(repr, self))

253 

254 

class ImmutableMultiDict(Mapping[_KeyType, _CovariantValueType]):
    """
    An immutable multidict: a mapping that can hold several values per key.

    ``_list`` keeps every ``(key, value)`` pair in insertion order. ``_dict``
    maps each key to its *last* value, which is what item access, ``keys()``,
    ``values()`` and ``items()`` expose. Use :meth:`getlist` or
    :meth:`multi_items` to see every value.
    """

    _dict: dict[_KeyType, _CovariantValueType]

    def __init__(
        self,
        *args: ImmutableMultiDict[_KeyType, _CovariantValueType]
        | Mapping[_KeyType, _CovariantValueType]
        | Iterable[tuple[_KeyType, _CovariantValueType]],
        **kwargs: Any,
    ) -> None:
        """
        Accept at most one positional argument, which can be:

        - another multidict (anything with ``multi_items``);
        - a mapping (anything with ``items``);
        - an iterable of ``(key, value)`` pairs.

        Keyword arguments are appended after the positional pairs.
        """
        assert len(args) < 2, "Too many arguments."

        value: Any = args[0] if args else []
        if kwargs:
            value = ImmutableMultiDict(value).multi_items() + ImmutableMultiDict(kwargs).multi_items()

        if not value:
            _items: list[tuple[Any, Any]] = []
        elif hasattr(value, "multi_items"):
            value = cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value)
            _items = list(value.multi_items())
        elif hasattr(value, "items"):
            value = cast(Mapping[_KeyType, _CovariantValueType], value)
            _items = list(value.items())
        else:
            value = cast("list[tuple[Any, Any]]", value)
            _items = list(value)

        # Later pairs overwrite earlier ones, so _dict keeps the last value per key.
        self._dict = {k: v for k, v in _items}
        self._list = _items

    def getlist(self, key: Any) -> list[_CovariantValueType]:
        """Return every value stored for ``key``, in insertion order (possibly empty)."""
        return [item_value for item_key, item_value in self._list if item_key == key]

    def keys(self) -> KeysView[_KeyType]:
        return self._dict.keys()

    def values(self) -> ValuesView[_CovariantValueType]:
        return self._dict.values()

    def items(self) -> ItemsView[_KeyType, _CovariantValueType]:
        return self._dict.items()

    def multi_items(self) -> list[tuple[_KeyType, _CovariantValueType]]:
        """Return a copy of all ``(key, value)`` pairs, duplicates included."""
        return list(self._list)

    def __getitem__(self, key: _KeyType) -> _CovariantValueType:
        return self._dict[key]

    def __contains__(self, key: Any) -> bool:
        return key in self._dict

    def __iter__(self) -> Iterator[_KeyType]:
        return iter(self.keys())

    def __len__(self) -> int:
        # Counts distinct keys, not pairs.
        return len(self._dict)

    def __eq__(self, other: Any) -> bool:
        """
        Compare as multisets of ``(key, value)`` pairs.

        Order does not matter, but each pair must occur the same number of
        times. Only instances of this object's class (or a subclass) can
        compare equal.
        """
        if not isinstance(other, self.__class__):
            return False
        if len(self._list) != len(other._list):
            return False
        try:
            return sorted(self._list) == sorted(other._list)
        except TypeError:
            # Some pairs cannot be ordered (e.g. repeated keys whose values do
            # not define "<", such as UploadFile objects). Fall back to an
            # order-free pairwise match instead of raising.
            remaining = list(other._list)
            for pair in self._list:
                try:
                    remaining.remove(pair)
                except ValueError:
                    return False
            return not remaining

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        items = self.multi_items()
        return f"{class_name}({items!r})"

322 

323 

class MultiDict(ImmutableMultiDict[Any, Any]):
    """
    A mutable multidict.

    Every mutation updates both ``_list`` (all pairs, in order) and ``_dict``
    (last value per key) so the two stay consistent.
    """

    def _pairs_without(self, key: Any) -> list[tuple[Any, Any]]:
        # New list holding every pair of ``_list`` except those for ``key``.
        return [pair for pair in self._list if pair[0] != key]

    def __setitem__(self, key: Any, value: Any) -> None:
        """Replace all values for ``key`` with the single ``value``."""
        self.setlist(key, [value])

    def __delitem__(self, key: Any) -> None:
        """Remove every value for ``key``; ``KeyError`` if the key is absent."""
        self._list = self._pairs_without(key)
        del self._dict[key]

    def pop(self, key: Any, default: Any = None) -> Any:
        """Remove every value for ``key``; return its last value, or ``default``."""
        self._list = self._pairs_without(key)
        return self._dict.pop(key, default)

    def popitem(self) -> tuple[Any, Any]:
        """
        Remove the key chosen by ``dict.popitem`` on ``_dict`` (LIFO order),
        together with all of its values, and return ``(key, last_value)``.
        """
        popped_key, last_value = self._dict.popitem()
        self._list = self._pairs_without(popped_key)
        return popped_key, last_value

    def poplist(self, key: Any) -> list[Any]:
        """Remove ``key`` and return all of its values (empty list if absent)."""
        removed = self.getlist(key)
        self.pop(key)
        return removed

    def clear(self) -> None:
        self._dict.clear()
        self._list.clear()

    def setdefault(self, key: Any, default: Any = None) -> Any:
        """Insert ``(key, default)`` if ``key`` is absent; return ``self[key]``."""
        if key in self:
            return self[key]
        self._dict[key] = default
        self._list.append((key, default))
        return self[key]

    def setlist(self, key: Any, values: list[Any]) -> None:
        """
        Replace all values for ``key`` with ``values``, appended after the
        remaining pairs. An empty ``values`` removes ``key`` entirely.
        """
        if values:
            self._list = self._pairs_without(key) + [(key, item) for item in values]
            self._dict[key] = values[-1]
        else:
            self.pop(key, None)

    def append(self, key: Any, value: Any) -> None:
        """Add one more value for ``key``, keeping any existing ones."""
        self._list.append((key, value))
        self._dict[key] = value

    def update(
        self,
        *args: MultiDict | Mapping[Any, Any] | list[tuple[Any, Any]],
        **kwargs: Any,
    ) -> None:
        """
        Merge in new pairs. Every key present in the update has all of its old
        values dropped and replaced by the new ones, appended at the end.
        """
        incoming = MultiDict(*args, **kwargs)
        replaced_keys = incoming.keys()
        kept = [pair for pair in self._list if pair[0] not in replaced_keys]
        self._list = kept + incoming.multi_items()
        self._dict.update(incoming)

378 

379 

class QueryParams(ImmutableMultiDict[str, str]):
    """
    An immutable multidict of query parameters whose keys and values are all ``str``.

    Besides the usual multidict inputs, a raw query string (``str``, or
    ``bytes`` decoded as latin-1) is parsed with blank values kept.
    """

    def __init__(
        self,
        *args: ImmutableMultiDict[Any, Any] | Mapping[Any, Any] | list[tuple[Any, Any]] | str | bytes,
        **kwargs: Any,
    ) -> None:
        assert len(args) < 2, "Too many arguments."

        source = args[0] if args else []
        if isinstance(source, bytes):
            source = source.decode("latin-1")

        if isinstance(source, str):
            super().__init__(parse_qsl(source, keep_blank_values=True), **kwargs)
        else:
            super().__init__(*args, **kwargs)  # type: ignore[arg-type]

        # Coerce everything to str so non-string inputs round-trip consistently.
        self._list = [(str(k), str(v)) for k, v in self._list]
        self._dict = {str(k): str(v) for k, v in self._dict.items()}

    def __str__(self) -> str:
        return urlencode(self._list)

    def __repr__(self) -> str:
        return f"{type(self).__name__}({str(self)!r})"

410 

411 

class UploadFile:
    """
    An uploaded file included as part of the request data.

    Wraps a binary file object behind async ``write``/``read``/``seek``/``close``.
    These run inline while the file is an in-memory ``SpooledTemporaryFile``.
    Otherwise they are dispatched with ``run_in_threadpool``.
    """

    def __init__(
        self,
        file: BinaryIO,
        *,
        size: int | None = None,
        filename: str | None = None,
        headers: Headers | None = None,
    ) -> None:
        self.file = file
        self.filename = filename
        self.size = size
        # A missing *or empty* headers object is replaced by a fresh Headers().
        self.headers = headers if headers else Headers()

        # Cache SpooledTemporaryFile's private ``_max_size`` to speed up later
        # checks; 0 means "no limit", mirroring SpooledTemporaryFile.__init__.
        self._max_mem_size = getattr(self.file, "_max_size", 0)

    @property
    def content_type(self) -> str | None:
        """The ``content-type`` header value, or ``None`` when absent."""
        return self.headers.get("content-type", None)

    @property
    def _in_memory(self) -> bool:
        # Only a SpooledTemporaryFile whose ``_rolled`` flag is False counts as
        # in memory; objects without that attribute are treated as on disk.
        return not getattr(self.file, "_rolled", True)

    def _will_roll(self, size_to_add: int) -> bool:
        """Return whether a write of ``size_to_add`` bytes should go to the thread pool."""
        if not self._in_memory:
            return True
        projected_size = self.file.tell() + size_to_add
        if not self._max_mem_size:
            # Unlimited spool size: it never rolls over to disk.
            return False
        return projected_size > self._max_mem_size

    async def write(self, data: bytes) -> None:
        """Write ``data`` at the current position, growing ``size`` if it is tracked."""
        chunk_len = len(data)
        if self.size is not None:
            self.size += chunk_len

        if not self._will_roll(chunk_len):
            self.file.write(data)
            return
        await run_in_threadpool(self.file.write, data)

    async def read(self, size: int = -1) -> bytes:
        """Read up to ``size`` bytes (all remaining bytes when negative)."""
        if not self._in_memory:
            return await run_in_threadpool(self.file.read, size)
        return self.file.read(size)

    async def seek(self, offset: int) -> None:
        """Move to absolute position ``offset``."""
        if not self._in_memory:
            await run_in_threadpool(self.file.seek, offset)
            return
        self.file.seek(offset)

    async def close(self) -> None:
        """Close the underlying file object."""
        if not self._in_memory:
            await run_in_threadpool(self.file.close)
            return
        self.file.close()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(filename={self.filename!r}, size={self.size!r}, headers={self.headers!r})"

482 

483 

class FormData(ImmutableMultiDict[str, UploadFile | str]):
    """
    An immutable multidict holding both file uploads and text inputs.
    """

    def __init__(
        self,
        *args: FormData | Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]],
        **kwargs: str | UploadFile,
    ) -> None:
        super().__init__(*args, **kwargs)

    async def close(self) -> None:
        """Close every ``UploadFile`` value, including duplicates under one key."""
        uploads = [value for _, value in self.multi_items() if isinstance(value, UploadFile)]
        for upload in uploads:
            await upload.close()

500 

501 

class Headers(Mapping[str, str]):
    """
    An immutable, case-insensitive multidict of HTTP headers.

    Stored as a list of raw ``(bytes, bytes)`` pairs with lower-cased names.
    Lookups lower-case and latin-1-encode the requested name. Item access
    returns the first matching value, while :meth:`getlist` returns all of them.
    """

    def __init__(
        self,
        headers: Mapping[str, str] | None = None,
        raw: list[tuple[bytes, bytes]] | None = None,
        scope: MutableMapping[str, Any] | None = None,
    ) -> None:
        """
        Build from at most one source:

        - ``headers``: a str-to-str mapping (names lower-cased, both sides
          latin-1 encoded);
        - ``raw``: a list of byte pairs, used as-is without copying;
        - ``scope``: an ASGI scope, whose ``headers`` are converted to a list
          that is stored back into the scope and shared with this object.

        With no source the headers are empty.
        """
        self._list: list[tuple[bytes, bytes]] = []
        if headers is not None:
            assert raw is None, 'Cannot set both "headers" and "raw".'
            assert scope is None, 'Cannot set both "headers" and "scope".'
            self._list = [(self._header_key(name), val.encode("latin-1")) for name, val in headers.items()]
        elif raw is not None:
            assert scope is None, 'Cannot set both "raw" and "scope".'
            self._list = raw
        elif scope is not None:
            # scope["headers"] may be a tuple or another iterable, so normalise
            # it to a list and share that list with the scope.
            shared = list(scope["headers"])
            scope["headers"] = shared
            self._list = shared

    @staticmethod
    def _header_key(key: str) -> bytes:
        # Normalised lookup form of a header name: lower-cased latin-1 bytes.
        return key.lower().encode("latin-1")

    @property
    def raw(self) -> list[tuple[bytes, bytes]]:
        """A copy of the raw ``(name, value)`` byte pairs."""
        return list(self._list)

    def keys(self) -> list[str]:  # type: ignore[override]
        return [name.decode("latin-1") for name, _ in self._list]

    def values(self) -> list[str]:  # type: ignore[override]
        return [val.decode("latin-1") for _, val in self._list]

    def items(self) -> list[tuple[str, str]]:  # type: ignore[override]
        return [(name.decode("latin-1"), val.decode("latin-1")) for name, val in self._list]

    def getlist(self, key: str) -> list[str]:
        """Return every value for header ``key`` (case-insensitive), in order."""
        wanted = self._header_key(key)
        return [val.decode("latin-1") for name, val in self._list if name == wanted]

    def mutablecopy(self) -> MutableHeaders:
        return MutableHeaders(raw=self._list[:])

    def __getitem__(self, key: str) -> str:
        wanted = self._header_key(key)
        found = next((val for name, val in self._list if name == wanted), None)
        if found is None:
            raise KeyError(key)
        return found.decode("latin-1")

    def __contains__(self, key: Any) -> bool:
        wanted = self._header_key(key)
        return any(name == wanted for name, _ in self._list)

    def __iter__(self) -> Iterator[Any]:
        return iter(self.keys())

    def __len__(self) -> int:
        # Counts raw pairs, so duplicated header names count more than once.
        return len(self._list)

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, Headers):
            return False
        return sorted(self._list) == sorted(other._list)

    def __repr__(self) -> str:
        name = self.__class__.__name__
        as_dict = dict(self.items())
        if len(as_dict) != len(self):
            # Duplicate names would be lost in dict form; show raw pairs.
            return f"{name}(raw={self.raw!r})"
        return f"{name}({as_dict!r})"

577 

578 

class MutableHeaders(Headers):
    """
    A mutable, case-insensitive multidict of headers.

    All mutations happen in place on ``self._list``. For headers built from an
    ASGI scope this is the same list stored in ``scope["headers"]``, so changes
    are visible through the scope.
    """

    @staticmethod
    def _check_mapping(other: Any) -> None:
        # Shared argument guard for the ``|`` and ``|=`` operators.
        if not isinstance(other, Mapping):
            raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")

    def __setitem__(self, key: str, value: str) -> None:
        """
        Set the header `key` to `value`, removing any duplicate entries.
        The first existing occurrence keeps its position (insertion order);
        a new header is appended.
        """
        new_key = key.lower().encode("latin-1")
        new_value = value.encode("latin-1")

        matches = [i for i, (name, _) in enumerate(self._list) if name == new_key]
        if not matches:
            self._list.append((new_key, new_value))
            return

        first, *duplicates = matches
        # Delete from the back so the remaining indexes stay valid.
        for i in reversed(duplicates):
            del self._list[i]
        self._list[first] = (new_key, new_value)

    def __delitem__(self, key: str) -> None:
        """
        Remove every occurrence of header `key`; a missing header is ignored.
        """
        del_key = key.lower().encode("latin-1")
        doomed = [i for i, (name, _) in enumerate(self._list) if name == del_key]
        for i in reversed(doomed):
            del self._list[i]

    def __ior__(self, other: Mapping[str, str]) -> MutableHeaders:
        self._check_mapping(other)
        self.update(other)
        return self

    def __or__(self, other: Mapping[str, str]) -> MutableHeaders:
        self._check_mapping(other)
        merged = self.mutablecopy()
        merged.update(other)
        return merged

    @property
    def raw(self) -> list[tuple[bytes, bytes]]:
        """The live list of raw byte pairs (not a copy)."""
        return self._list

    def setdefault(self, key: str, value: str) -> str:
        """
        If the header `key` does not exist, then set it to `value`.
        Returns the header value (the first existing one, if present).
        """
        set_key = key.lower().encode("latin-1")
        set_value = value.encode("latin-1")

        existing = next((val for name, val in self._list if name == set_key), None)
        if existing is not None:
            return existing.decode("latin-1")
        self._list.append((set_key, set_value))
        return value

    def update(self, other: Mapping[str, str]) -> None:
        """Set each header in ``other``, replacing any existing values."""
        for name, new_value in other.items():
            self[name] = new_value

    def append(self, key: str, value: str) -> None:
        """
        Append a header, preserving any duplicate entries.
        """
        self._list.append((key.lower().encode("latin-1"), value.encode("latin-1")))

    def add_vary_header(self, vary: str) -> None:
        """Add ``vary`` to the ``Vary`` header, comma-joined with any existing value."""
        existing = self.get("vary")
        self["vary"] = vary if existing is None else ", ".join((existing, vary))

664 

665 

class State:
    """
    An object that can be used to store arbitrary state.

    Used for `request.state` and `app.state`. Attributes live in the ``_state``
    dict: assigning ``state.x`` stores ``_state["x"]``, and reading a missing
    attribute raises ``AttributeError``.
    """

    _state: dict[str, Any]

    def __init__(self, state: dict[str, Any] | None = None):
        if state is None:
            state = {}
        # Bypass our own __setattr__, which would store "_state" inside the dict.
        super().__setattr__("_state", state)

    def __setattr__(self, key: Any, value: Any) -> None:
        self._state[key] = value

    def __getattr__(self, key: Any) -> Any:
        message = "'{}' object has no attribute '{}'"
        if key == "_state":
            # Only reached when "_state" was never set, i.e. the instance was
            # created without __init__ (copy/deepcopy/pickle reconstruction).
            # Without this guard, ``self._state`` below would re-enter
            # __getattr__ endlessly and raise RecursionError.
            raise AttributeError(message.format(self.__class__.__name__, key))
        try:
            return self._state[key]
        except KeyError:
            raise AttributeError(message.format(self.__class__.__name__, key))

    def __delattr__(self, key: Any) -> None:
        # NOTE: deleting a missing attribute raises KeyError, not AttributeError.
        del self._state[key]