Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/starlette/datastructures.py: 31%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

441 statements  

1from __future__ import annotations 

2 

3from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, MutableMapping, Sequence, ValuesView 

4from shlex import shlex 

5from typing import ( 

6 Any, 

7 BinaryIO, 

8 NamedTuple, 

9 TypeVar, 

10 Union, 

11 cast, 

12) 

13from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit 

14 

15from starlette.concurrency import run_in_threadpool 

16from starlette.types import Scope 

17 

18 

class Address(NamedTuple):
    """A ``(host, port)`` pair describing a network endpoint."""

    host: str  # host part of the endpoint
    port: int  # numeric port of the endpoint

22 

23 

# Mapping keys are invariant, but their values are covariant because a Mapping is
# read-only: a `Mapping[str, Dog]` can stand in for a `Mapping[str, Animal]`, since
# nothing can do `Mapping[str, Animal]()["fido"] = Dog()` through it.
_CovariantValueType = TypeVar("_CovariantValueType", covariant=True)
_KeyType = TypeVar("_KeyType")

29 

30 

class URL:
    """
    A URL built from a string, a connection scope, or keyword components.

    The string form is stored verbatim and is only split into components (via
    `urllib.parse.urlsplit`) on first access. Equality compares string forms, so a URL
    also compares equal to the plain string it wraps. Because `__eq__` is defined
    without `__hash__`, URL instances are unhashable.
    """

    def __init__(
        self,
        url: str = "",
        scope: Scope | None = None,
        **components: Any,
    ) -> None:
        """
        Create a URL from exactly one of `url`, `scope` or `**components`.

        Args:
            url: A URL string, stored verbatim.
            scope: A scope mapping. The URL is rebuilt from its "scheme" (default
                "http"), "server", "path", "query_string" and "headers" entries. The
                first `host` header takes precedence over "server". When falling back
                to "server", the port is omitted if it is None or is the default port
                of a known scheme (http/https/ws/wss).
            **components: Fields accepted by `replace` (e.g. scheme, netloc, path,
                query, hostname, port) used to assemble the URL from scratch.

        Raises:
            AssertionError: If more than one source is supplied.
        """
        if scope is not None:
            assert not url, 'Cannot set both "url" and "scope".'
            assert not components, 'Cannot set both "scope" and "**components".'
            scheme = scope.get("scheme", "http")
            server = scope.get("server", None)
            path = scope["path"]
            query_string = scope.get("query_string", b"")

            # The first `host` header, if any, wins over the "server" entry.
            host_header = None
            for key, value in scope["headers"]:
                if key == b"host":
                    host_header = value.decode("latin-1")
                    break

            if host_header is not None:
                url = f"{scheme}://{host_header}{path}"
            elif server is None:
                url = path
            else:
                host, port = server
                # `.get` so schemes without a known default port keep their explicit
                # port instead of raising KeyError; a missing (None) port is omitted
                # rather than rendered as ":None".
                default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}.get(scheme)
                if port is None or port == default_port:
                    url = f"{scheme}://{host}{path}"
                else:
                    url = f"{scheme}://{host}:{port}{path}"

            if query_string:
                url += "?" + query_string.decode()
        elif components:
            assert not url, 'Cannot set both "url" and "**components".'
            url = URL("").replace(**components).components.geturl()

        self._url = url

    @property
    def components(self) -> SplitResult:
        """The parsed URL, computed with `urlsplit` on first access and then cached."""
        if not hasattr(self, "_components"):
            self._components = urlsplit(self._url)
        return self._components

    @property
    def scheme(self) -> str:
        """The URL scheme, e.g. "https" ("" when absent)."""
        return self.components.scheme

    @property
    def netloc(self) -> str:
        """The raw network location, including any userinfo and port."""
        return self.components.netloc

    @property
    def path(self) -> str:
        """The path component."""
        return self.components.path

    @property
    def query(self) -> str:
        """The raw query string, without the leading "?"."""
        return self.components.query

    @property
    def fragment(self) -> str:
        """The fragment, without the leading "#"."""
        return self.components.fragment

    @property
    def username(self) -> None | str:
        """The username from the netloc, or None."""
        return self.components.username

    @property
    def password(self) -> None | str:
        """The password from the netloc, or None."""
        return self.components.password

    @property
    def hostname(self) -> None | str:
        """The host name as reported by `SplitResult.hostname`, or None."""
        return self.components.hostname

    @property
    def port(self) -> int | None:
        """The explicit port as an int, or None when no port is given."""
        return self.components.port

    @property
    def is_secure(self) -> bool:
        """True for the "https" and "wss" schemes."""
        return self.scheme in ("https", "wss")

    def replace(self, **kwargs: Any) -> URL:
        """
        Return a new URL with the given components replaced.

        Accepts any `SplitResult` field (scheme, netloc, path, query, fragment). The
        netloc parts hostname, port, username and password are also accepted; they are
        recombined into a new netloc, and any part not given keeps its current value.
        """
        if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs:
            hostname = kwargs.pop("hostname", None)
            port = kwargs.pop("port", self.port)
            username = kwargs.pop("username", self.username)
            password = kwargs.pop("password", self.password)

            if hostname is None:
                netloc = self.netloc
                _, _, hostname = netloc.rpartition("@")

                # Drop an existing ":port" suffix, but leave the colons inside a
                # bracketed IPv6 literal alone. `endswith` (instead of indexing
                # `hostname[-1]`) also copes with an empty netloc such as that of a
                # relative URL like "/path", which previously raised IndexError.
                if not hostname.endswith("]"):
                    hostname = hostname.rsplit(":", 1)[0]

            netloc = hostname
            if port is not None:
                netloc += f":{port}"
            if username is not None:
                userpass = username
                if password is not None:
                    userpass += f":{password}"
                netloc = f"{userpass}@{netloc}"

            kwargs["netloc"] = netloc

        components = self.components._replace(**kwargs)
        return self.__class__(components.geturl())

    def include_query_params(self, **kwargs: Any) -> URL:
        """
        Return a copy with the given query parameters set.

        Every existing occurrence of a given name is replaced by the new value, which is
        appended after the untouched parameters. Keys and values are converted with `str`.
        """
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        params.update({str(key): str(value) for key, value in kwargs.items()})
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def replace_query_params(self, **kwargs: Any) -> URL:
        """Return a copy whose query string consists only of the given parameters."""
        query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
        return self.replace(query=query)

    def remove_query_params(self, keys: str | Sequence[str]) -> URL:
        """Return a copy with every occurrence of the given parameter name(s) removed."""
        if isinstance(keys, str):
            keys = [keys]
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        for key in keys:
            params.pop(key, None)
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def __eq__(self, other: Any) -> bool:
        # String comparison, so URLs compare equal to plain strings too.
        return str(self) == str(other)

    def __str__(self) -> str:
        return self._url

    def __repr__(self) -> str:
        # Never leak a password into logs or tracebacks.
        url = str(self)
        if self.password:
            url = str(self.replace(password="********"))
        return f"{self.__class__.__name__}({repr(url)})"

176 

177 

class URLPath(str):
    """
    A URL path string that may also carry the protocol ("http" or "websocket") and the
    host it belongs to. Routing returns these from `url_path_for` matches.
    """

    def __new__(cls, path: str, protocol: str = "", host: str = "") -> URLPath:
        # Only these protocols can be mapped onto a URL scheme by make_absolute_url.
        assert protocol in ("http", "websocket", "")
        return str.__new__(cls, path)

    def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
        self.protocol = protocol
        self.host = host

    def make_absolute_url(self, base_url: str | URL) -> URL:
        """
        Resolve this path against `base_url`.

        The scheme is derived from `self.protocol`, upgraded to its secure variant when
        the base URL is secure, or copied from the base when no protocol is set. The
        host defaults to the base URL's netloc, and the path is appended to the base
        path (with trailing slashes trimmed).
        """
        base = URL(base_url) if isinstance(base_url, str) else base_url

        if self.protocol:
            # (insecure, secure) scheme pair, indexed by the boolean `is_secure`.
            scheme_pair = {
                "http": ("http", "https"),
                "websocket": ("ws", "wss"),
            }[self.protocol]
            scheme = scheme_pair[base.is_secure]
        else:
            scheme = base.scheme

        return URL(
            scheme=scheme,
            netloc=self.host or base.netloc,
            path=base.path.rstrip("/") + str(self),
        )

206 

207 

class Secret:
    """
    Wraps a string value so it is not revealed by `repr()` (e.g. in tracebacks).

    Convert with `str()` at the point where the real value is needed.
    """

    def __init__(self, value: str):
        self._value = value

    def __repr__(self) -> str:
        # Always masked, regardless of the wrapped value.
        return f"{type(self).__name__}('**********')"

    def __str__(self) -> str:
        return self._value

    def __bool__(self) -> bool:
        # Truthiness follows the wrapped value.
        return bool(self._value)

226 

227 

class CommaSeparatedStrings(Sequence[str]):
    """
    A read-only sequence of strings parsed from a comma-separated value.

    A string input is tokenised with a POSIX `shlex` lexer that splits only on commas,
    so quoted items may contain commas. Each item has surrounding whitespace stripped.
    Any other sequence is copied as-is.
    """

    def __init__(self, value: str | Sequence[str]):
        if not isinstance(value, str):
            self._items = list(value)
            return

        lexer = shlex(value, posix=True)
        lexer.whitespace = ","
        lexer.whitespace_split = True
        self._items = list(map(str.strip, lexer))

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, index: int | slice) -> Any:
        return self._items[index]

    def __iter__(self) -> Iterator[str]:
        return iter(self._items)

    def __repr__(self) -> str:
        return f"{type(self).__name__}({list(self)!r})"

    def __str__(self) -> str:
        return ", ".join(map(repr, self))

254 

255 

class ImmutableMultiDict(Mapping[_KeyType, _CovariantValueType]):
    """
    An immutable multidict: a mapping that may hold several values per key.

    `_list` keeps every (key, value) pair in insertion order. `_dict` maps each key to
    the last value seen for it, which is what plain `Mapping` access returns. Use
    `getlist` / `multi_items` to see all values.
    """

    _dict: dict[_KeyType, _CovariantValueType]

    def __init__(
        self,
        *args: ImmutableMultiDict[_KeyType, _CovariantValueType]
        | Mapping[_KeyType, _CovariantValueType]
        | Iterable[tuple[_KeyType, _CovariantValueType]],
        **kwargs: Any,
    ) -> None:
        """
        Build from at most one positional source plus optional keyword pairs.

        The source may be another multidict (all pairs kept), a mapping, or an iterable
        of (key, value) pairs. Keyword pairs are appended after the source's pairs.

        Raises:
            AssertionError: If more than one positional argument is given.
        """
        assert len(args) < 2, "Too many arguments."

        value: Any = args[0] if args else []
        if kwargs:
            value = ImmutableMultiDict(value).multi_items() + ImmutableMultiDict(kwargs).multi_items()

        if not value:
            _items: list[tuple[Any, Any]] = []
        elif hasattr(value, "multi_items"):
            # Another multidict: keep every pair, not just the last value per key.
            value = cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value)
            _items = list(value.multi_items())
        elif hasattr(value, "items"):
            value = cast(Mapping[_KeyType, _CovariantValueType], value)
            _items = list(value.items())
        else:
            value = cast("list[tuple[Any, Any]]", value)
            _items = list(value)

        # Later pairs overwrite earlier ones, so `_dict` holds the last value per key.
        self._dict = {k: v for k, v in _items}
        self._list = _items

    def getlist(self, key: Any) -> list[_CovariantValueType]:
        """All values stored for `key`, in insertion order (empty if absent)."""
        return [item_value for item_key, item_value in self._list if item_key == key]

    def keys(self) -> KeysView[_KeyType]:
        """The distinct keys."""
        return self._dict.keys()

    def values(self) -> ValuesView[_CovariantValueType]:
        """The last value stored for each distinct key."""
        return self._dict.values()

    def items(self) -> ItemsView[_KeyType, _CovariantValueType]:
        """(key, last value) pairs for each distinct key."""
        return self._dict.items()

    def multi_items(self) -> list[tuple[_KeyType, _CovariantValueType]]:
        """A copy of every (key, value) pair, in insertion order."""
        return list(self._list)

    def __getitem__(self, key: _KeyType) -> _CovariantValueType:
        return self._dict[key]

    def __contains__(self, key: Any) -> bool:
        return key in self._dict

    def __iter__(self) -> Iterator[_KeyType]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._dict)

    def __eq__(self, other: Any) -> bool:
        """
        Compare all (key, value) pairs, ignoring order.

        Only instances of this object's class (or subclasses) can compare equal.
        """
        if not isinstance(other, self.__class__):
            return False
        if len(self._list) != len(other._list):
            return False
        try:
            return sorted(self._list) == sorted(other._list)
        except TypeError:
            # Some keys/values (e.g. file uploads) do not support ordering, which made
            # sorting raise. Fall back to an order-insensitive multiset comparison that
            # only needs `==`; quadratic, but only reached for unorderable content.
            remaining = list(other._list)
            for item in self._list:
                try:
                    remaining.remove(item)
                except ValueError:
                    return False
            return True

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        items = self.multi_items()
        return f"{class_name}({items!r})"

323 

324 

class MultiDict(ImmutableMultiDict[Any, Any]):
    """
    A mutable multidict.

    Every mutation keeps `_list` (all pairs, in order) and `_dict` (last value per key)
    in step with each other.
    """

    def __setitem__(self, key: Any, value: Any) -> None:
        # Assignment replaces every existing value for `key` with this single one.
        self.setlist(key, [value])

    def __delitem__(self, key: Any) -> None:
        self._list = [(name, val) for name, val in self._list if name != key]
        del self._dict[key]

    def pop(self, key: Any, default: Any = None) -> Any:
        """Remove every value for `key`; return its last value, or `default` if absent."""
        self._list = [(name, val) for name, val in self._list if name != key]
        return self._dict.pop(key, default)

    def popitem(self) -> tuple[Any, Any]:
        """
        Remove and return a (key, last value) pair chosen by `dict.popitem`, discarding
        every other value for that key. Raises KeyError when empty.
        """
        popped_key, popped_value = self._dict.popitem()
        self._list = [(name, val) for name, val in self._list if name != popped_key]
        return popped_key, popped_value

    def poplist(self, key: Any) -> list[Any]:
        """Remove `key` and return all of its values (empty list if absent)."""
        removed = self.getlist(key)
        self.pop(key)
        return removed

    def clear(self) -> None:
        self._dict.clear()
        self._list.clear()

    def setdefault(self, key: Any, default: Any = None) -> Any:
        """Return the value for `key`, first inserting `default` if `key` is absent."""
        if key in self:
            return self[key]
        self._dict[key] = default
        self._list.append((key, default))
        return self[key]

    def setlist(self, key: Any, values: list[Any]) -> None:
        """
        Replace all values for `key` with `values`, appended at the end of the pair list.
        An empty `values` removes the key.
        """
        if not values:
            self.pop(key, None)
            return
        retained = [(name, val) for (name, val) in self._list if name != key]
        retained.extend((key, item) for item in values)
        self._list = retained
        self._dict[key] = values[-1]

    def append(self, key: Any, value: Any) -> None:
        """Add another value for `key`, keeping the existing ones."""
        self._list.append((key, value))
        self._dict[key] = value

    def update(
        self,
        *args: MultiDict | Mapping[Any, Any] | list[tuple[Any, Any]],
        **kwargs: Any,
    ) -> None:
        """
        Merge in new pairs: every key present in the update has all of its old values
        dropped, and the incoming pairs are appended after the remaining ones.
        """
        incoming = MultiDict(*args, **kwargs)
        incoming_keys = incoming.keys()
        kept = [(name, val) for (name, val) in self._list if name not in incoming_keys]
        self._list = kept + incoming.multi_items()
        self._dict.update(incoming)

379 

380 

class QueryParams(ImmutableMultiDict[str, str]):
    """
    An immutable multidict.

    Accepts a query string (str, or bytes decoded as latin-1), a mapping, a list of
    pairs, or another multidict. Keys and values are always coerced to `str`.
    """

    def __init__(
        self,
        *args: ImmutableMultiDict[Any, Any] | Mapping[Any, Any] | list[tuple[Any, Any]] | str | bytes,
        **kwargs: Any,
    ) -> None:
        assert len(args) < 2, "Too many arguments."

        source = args[0] if args else []
        if isinstance(source, bytes):
            # latin-1 maps every byte to exactly one code point, so decoding never fails.
            source = source.decode("latin-1")

        if isinstance(source, str):
            # Blank values ("a=&b") are kept rather than dropped.
            super().__init__(parse_qsl(source, keep_blank_values=True), **kwargs)
        else:
            super().__init__(*args, **kwargs)  # type: ignore[arg-type]

        self._list = [(str(name), str(val)) for name, val in self._list]
        self._dict = {str(name): str(val) for name, val in self._dict.items()}

    def __str__(self) -> str:
        # Re-encode every pair, duplicates included, as a query string.
        return urlencode(self._list)

    def __repr__(self) -> str:
        return f"{type(self).__name__}({str(self)!r})"

411 

412 

class UploadFile:
    """
    An uploaded file included as part of the request data.

    Wraps a binary file object. When the file reports (via `SpooledTemporaryFile`'s
    private `_rolled` flag) that it is still held in memory, I/O runs directly on the
    event loop; otherwise it is dispatched to a worker thread.
    """

    def __init__(
        self,
        file: BinaryIO,
        *,
        size: int | None = None,
        filename: str | None = None,
        headers: Headers | None = None,
    ) -> None:
        self.filename = filename
        self.file = file
        self.size = size
        self.headers = headers or Headers()

        # Cache SpooledTemporaryFile's in-memory limit so later checks are cheap.
        # 0 means "no limit", mirroring SpooledTemporaryFile's own convention.
        self._max_mem_size = getattr(self.file, "_max_size", 0)

    @property
    def content_type(self) -> str | None:
        """The `content-type` header value, or None if absent."""
        return self.headers.get("content-type", None)

    @property
    def _in_memory(self) -> bool:
        # Files without SpooledTemporaryFile's `_rolled` flag are treated as on disk.
        return not getattr(self.file, "_rolled", True)

    def _will_roll(self, size_to_add: int) -> bool:
        # Anything already on disk counts as "rolled".
        if not self._in_memory:
            return True

        # Would writing `size_to_add` more bytes exceed the in-memory limit?
        projected = self.file.tell() + size_to_add
        return bool(self._max_mem_size) and projected > self._max_mem_size

    async def write(self, data: bytes) -> None:
        """Append `data`, updating `size` when it is being tracked."""
        if self.size is not None:
            self.size += len(data)

        if not self._will_roll(len(data)):
            self.file.write(data)
            return
        await run_in_threadpool(self.file.write, data)

    async def read(self, size: int = -1) -> bytes:
        """Read up to `size` bytes (everything when -1)."""
        if not self._in_memory:
            return await run_in_threadpool(self.file.read, size)
        return self.file.read(size)

    async def seek(self, offset: int) -> None:
        """Move the file position to `offset`."""
        if not self._in_memory:
            await run_in_threadpool(self.file.seek, offset)
            return
        self.file.seek(offset)

    async def close(self) -> None:
        """Close the underlying file."""
        if not self._in_memory:
            await run_in_threadpool(self.file.close)
            return
        self.file.close()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(filename={self.filename!r}, size={self.size!r}, headers={self.headers!r})"

483 

484 

class FormData(ImmutableMultiDict[str, Union[UploadFile, str]]):
    """
    An immutable multidict, containing both file uploads and text input.
    """

    def __init__(
        self,
        *args: FormData | Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]],
        **kwargs: str | UploadFile,
    ) -> None:
        super().__init__(*args, **kwargs)

    async def close(self) -> None:
        """Close every `UploadFile` value, in insertion order; text values are skipped."""
        uploads = [item for _, item in self.multi_items() if isinstance(item, UploadFile)]
        for upload in uploads:
            await upload.close()

501 

502 

class Headers(Mapping[str, str]):
    """
    An immutable, case-insensitive multidict.

    Headers are stored as a list of raw (name, value) byte pairs. Lookups lowercase and
    latin-1 encode the requested name before comparing. Duplicate names are kept:
    `__getitem__` returns the first match and `getlist` returns all of them.
    """

    def __init__(
        self,
        headers: Mapping[str, str] | None = None,
        raw: list[tuple[bytes, bytes]] | None = None,
        scope: MutableMapping[str, Any] | None = None,
    ) -> None:
        """
        Build headers from at most one source; with none of them the headers are empty.

        Args:
            headers: A str-to-str mapping. Names are lowercased and both sides are
                latin-1 encoded.
            raw: A list of (name, value) byte pairs, used as-is (not copied).
            scope: A scope mapping. Its "headers" entry is converted to a list, written
                back into the scope, and shared with this object.

        Raises:
            AssertionError: If more than one source is given.
        """
        self._list: list[tuple[bytes, bytes]] = []
        if headers is not None:
            assert raw is None, 'Cannot set both "headers" and "raw".'
            assert scope is None, 'Cannot set both "headers" and "scope".'
            self._list = [(key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in headers.items()]
        elif raw is not None:
            assert scope is None, 'Cannot set both "raw" and "scope".'
            self._list = raw
        elif scope is not None:
            # scope["headers"] isn't necessarily a list
            # it might be a tuple or other iterable
            self._list = scope["headers"] = list(scope["headers"])

    @staticmethod
    def _encode_key(key: Any) -> bytes | None:
        """
        Return the lowercased latin-1 wire form of `key`, or None when no stored header
        can have that name (non-str key, or characters outside latin-1).
        """
        if not isinstance(key, str):
            return None
        try:
            return key.lower().encode("latin-1")
        except UnicodeEncodeError:
            return None

    @property
    def raw(self) -> list[tuple[bytes, bytes]]:
        """A copy of the raw (name, value) byte pairs."""
        return list(self._list)

    def keys(self) -> list[str]:  # type: ignore[override]
        """All header names, latin-1 decoded, including duplicates, in order."""
        return [key.decode("latin-1") for key, value in self._list]

    def values(self) -> list[str]:  # type: ignore[override]
        """All header values, latin-1 decoded, including duplicates, in order."""
        return [value.decode("latin-1") for key, value in self._list]

    def items(self) -> list[tuple[str, str]]:  # type: ignore[override]
        """All (name, value) pairs, latin-1 decoded, including duplicates, in order."""
        return [(key.decode("latin-1"), value.decode("latin-1")) for key, value in self._list]

    def getlist(self, key: str) -> list[str]:
        """Every value for `key` (case-insensitive), in order; empty if none."""
        get_header_key = self._encode_key(key)
        if get_header_key is None:
            return []
        return [item_value.decode("latin-1") for item_key, item_value in self._list if item_key == get_header_key]

    def mutablecopy(self) -> MutableHeaders:
        """Return a `MutableHeaders` holding a shallow copy of the raw pairs."""
        return MutableHeaders(raw=self._list[:])

    def __getitem__(self, key: str) -> str:
        # KeyError (not AttributeError/UnicodeEncodeError) for impossible names, so
        # `Mapping.get` falls back to its default as it does for any missing key.
        get_header_key = self._encode_key(key)
        if get_header_key is not None:
            for header_key, header_value in self._list:
                if header_key == get_header_key:
                    return header_value.decode("latin-1")
        raise KeyError(key)

    def __contains__(self, key: Any) -> bool:
        get_header_key = self._encode_key(key)
        if get_header_key is None:
            return False
        return any(header_key == get_header_key for header_key, _ in self._list)

    def __iter__(self) -> Iterator[Any]:
        return iter(self.keys())

    def __len__(self) -> int:
        # Counts duplicates, consistent with `keys()`.
        return len(self._list)

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, Headers):
            return False
        return sorted(self._list) == sorted(other._list)

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        as_dict = dict(self.items())
        # A dict repr is only faithful when there are no duplicate names.
        if len(as_dict) == len(self):
            return f"{class_name}({as_dict!r})"
        return f"{class_name}(raw={self.raw!r})"

578 

579 

class MutableHeaders(Headers):
    """
    Case-insensitive headers that can be modified in place.

    All mutations act directly on `self._list`. When the headers were built from a
    scope, that is the same list object as `scope["headers"]`, so changes show up there.
    """

    def __setitem__(self, key: str, value: str) -> None:
        """
        Set the header `key` to `value`, removing any duplicate entries.
        Retains insertion order.
        """
        encoded_key = key.lower().encode("latin-1")
        encoded_value = value.encode("latin-1")

        positions = [pos for pos, (name, _) in enumerate(self._list) if name == encoded_key]
        if not positions:
            self._list.append((encoded_key, encoded_value))
            return

        first, *duplicates = positions
        # Delete from the back so the remaining positions stay valid.
        for pos in reversed(duplicates):
            del self._list[pos]
        self._list[first] = (encoded_key, encoded_value)

    def __delitem__(self, key: str) -> None:
        """
        Remove the header `key` (every occurrence). Missing keys are ignored.
        """
        encoded_key = key.lower().encode("latin-1")

        doomed = [pos for pos, (name, _) in enumerate(self._list) if name == encoded_key]
        for pos in reversed(doomed):
            del self._list[pos]

    def __ior__(self, other: Mapping[str, str]) -> MutableHeaders:
        if not isinstance(other, Mapping):
            raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
        self.update(other)
        return self

    def __or__(self, other: Mapping[str, str]) -> MutableHeaders:
        if not isinstance(other, Mapping):
            raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
        merged = self.mutablecopy()
        merged.update(other)
        return merged

    @property
    def raw(self) -> list[tuple[bytes, bytes]]:
        # Unlike Headers.raw, this is the live list, not a copy.
        return self._list

    def setdefault(self, key: str, value: str) -> str:
        """
        If the header `key` does not exist, then set it to `value`.
        Returns the header value.
        """
        encoded_key = key.lower().encode("latin-1")
        encoded_value = value.encode("latin-1")

        for name, stored in self._list:
            if name == encoded_key:
                return stored.decode("latin-1")
        self._list.append((encoded_key, encoded_value))
        return value

    def update(self, other: Mapping[str, str]) -> None:
        """Set each header from `other`, replacing any existing values for its names."""
        for name, val in other.items():
            self[name] = val

    def append(self, key: str, value: str) -> None:
        """
        Append a header, preserving any duplicate entries.
        """
        self._list.append((key.lower().encode("latin-1"), value.encode("latin-1")))

    def add_vary_header(self, vary: str) -> None:
        """Add `vary` to the `vary` header, comma-joined with any existing value."""
        current = self.get("vary")
        self["vary"] = vary if current is None else f"{current}, {vary}"

665 

666 

class State:
    """
    An object that can be used to store arbitrary state.

    Used for `request.state` and `app.state`. Attribute reads, writes and deletes are
    forwarded to a backing dict held in the instance attribute `_state`.
    """

    _state: dict[str, Any]

    def __init__(self, state: dict[str, Any] | None = None):
        if state is None:
            state = {}
        # Bypass our own __setattr__ so `_state` lands in the instance __dict__
        # instead of inside the state dict itself.
        super().__setattr__("_state", state)

    def __setattr__(self, key: Any, value: Any) -> None:
        self._state[key] = value

    def __getattr__(self, key: Any) -> Any:
        # Only reached when normal attribute lookup fails. Read `_state` from the
        # instance __dict__ rather than via `self._state`: instances created without
        # __init__ (copy.copy/deepcopy and pickle probe them for `__setstate__`) have no
        # `_state` yet, and `self._state` would re-enter __getattr__ until it hit
        # RecursionError. Now such lookups raise a plain AttributeError.
        try:
            return self.__dict__["_state"][key]
        except KeyError:
            message = "'{}' object has no attribute '{}'"
            raise AttributeError(message.format(self.__class__.__name__, key)) from None

    def __delattr__(self, key: Any) -> None:
        # Note: a missing key raises KeyError here, not AttributeError.
        del self._state[key]