Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/starlette/datastructures.py: 31%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

433 statements  

1from __future__ import annotations 

2 

3import typing 

4from shlex import shlex 

5from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit 

6 

7from starlette.concurrency import run_in_threadpool 

8from starlette.types import Scope 

9 

10 

class Address(typing.NamedTuple):
    """A named ``(host, port)`` pair."""

    # Host name or address, as a string.
    host: str
    # Port number.
    port: int

14 

15 

16_KeyType = typing.TypeVar("_KeyType") 

17# Mapping keys are invariant but their values are covariant since 

18# you can only read them 

19# that is, you can't do `Mapping[str, Animal]()["fido"] = Dog()` 

20_CovariantValueType = typing.TypeVar("_CovariantValueType", covariant=True) 

21 

22 

class URL:
    """
    A URL, stored as a string and parsed lazily with ``urllib.parse.urlsplit``.

    Build it from exactly one of:

    * ``url`` -- a URL string, kept verbatim;
    * ``scope`` -- a connection scope mapping; the URL is reassembled from its
      ``scheme``, ``host`` header (or ``server`` address), ``path`` and
      ``query_string`` entries;
    * ``**components`` -- keyword arguments accepted by :meth:`replace`
      (``SplitResult`` field names plus ``username``/``password``/
      ``hostname``/``port``).
    """

    def __init__(
        self,
        url: str = "",
        scope: Scope | None = None,
        **components: typing.Any,
    ) -> None:
        if scope is not None:
            assert not url, 'Cannot set both "url" and "scope".'
            assert not components, 'Cannot set both "scope" and "**components".'
            scheme = scope.get("scheme", "http")
            server = scope.get("server", None)
            path = scope["path"]
            query_string = scope.get("query_string", b"")

            # The first "host" header, if present, wins over the server address.
            host_header = None
            for key, value in scope["headers"]:
                if key == b"host":
                    host_header = value.decode("latin-1")
                    break

            if host_header is not None:
                url = f"{scheme}://{host_header}{path}"
            elif server is None:
                url = path
            else:
                host, port = server
                # Omit the port only when it is the default for a known scheme.
                # `.get` (rather than indexing) keeps an unrecognised scheme
                # from raising KeyError; such URLs always spell out the port.
                default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}.get(
                    scheme
                )
                if port == default_port:
                    url = f"{scheme}://{host}{path}"
                else:
                    url = f"{scheme}://{host}:{port}{path}"

            if query_string:
                url += "?" + query_string.decode()
        elif components:
            assert not url, 'Cannot set both "url" and "**components".'
            url = URL("").replace(**components).components.geturl()

        self._url = url

    @property
    def components(self) -> SplitResult:
        """The parsed URL; computed on first access and cached."""
        if not hasattr(self, "_components"):
            self._components = urlsplit(self._url)
        return self._components

    @property
    def scheme(self) -> str:
        return self.components.scheme

    @property
    def netloc(self) -> str:
        return self.components.netloc

    @property
    def path(self) -> str:
        return self.components.path

    @property
    def query(self) -> str:
        return self.components.query

    @property
    def fragment(self) -> str:
        return self.components.fragment

    @property
    def username(self) -> None | str:
        return self.components.username

    @property
    def password(self) -> None | str:
        return self.components.password

    @property
    def hostname(self) -> None | str:
        return self.components.hostname

    @property
    def port(self) -> int | None:
        return self.components.port

    @property
    def is_secure(self) -> bool:
        """True for the ``https`` and ``wss`` schemes."""
        return self.scheme in ("https", "wss")

    def replace(self, **kwargs: typing.Any) -> URL:
        """
        Return a new URL with the given components replaced.

        Accepts ``SplitResult`` field names (``scheme``, ``netloc``, ``path``,
        ``query``, ``fragment``) plus ``username``, ``password``, ``hostname``
        and ``port``. The last four are folded into a rebuilt ``netloc``. Any of
        them not given keeps its current value; the hostname is recovered from
        the existing netloc.
        """
        if (
            "username" in kwargs
            or "password" in kwargs
            or "hostname" in kwargs
            or "port" in kwargs
        ):
            hostname = kwargs.pop("hostname", None)
            port = kwargs.pop("port", self.port)
            username = kwargs.pop("username", self.username)
            password = kwargs.pop("password", self.password)

            if hostname is None:
                netloc = self.netloc
                _, _, hostname = netloc.rpartition("@")

                # Strip a trailing ":port", but leave a bare bracketed IPv6
                # literal such as "[::1]" alone. `endswith` (instead of
                # `hostname[-1]`) also copes with an empty netloc, e.g. when
                # the URL is relative like "/path".
                if not hostname.endswith("]"):
                    hostname = hostname.rsplit(":", 1)[0]

            netloc = hostname
            if port is not None:
                netloc += f":{port}"
            if username is not None:
                userpass = username
                if password is not None:
                    userpass += f":{password}"
                netloc = f"{userpass}@{netloc}"

            kwargs["netloc"] = netloc

        components = self.components._replace(**kwargs)
        return self.__class__(components.geturl())

    def include_query_params(self, **kwargs: typing.Any) -> URL:
        """
        Return a copy with ``kwargs`` merged into the query string.

        Values are converted with ``str``. Any existing values for the same
        keys are replaced; other parameters are kept.
        """
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        params.update({str(key): str(value) for key, value in kwargs.items()})
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def replace_query_params(self, **kwargs: typing.Any) -> URL:
        """Return a copy whose query string consists solely of ``kwargs``."""
        query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
        return self.replace(query=query)

    def remove_query_params(self, keys: str | typing.Sequence[str]) -> URL:
        """Return a copy with every value of the given key(s) dropped from the query."""
        if isinstance(keys, str):
            keys = [keys]
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        for key in keys:
            params.pop(key, None)
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def __eq__(self, other: typing.Any) -> bool:
        # Compared by string form, so a URL equals the equivalent plain str.
        return str(self) == str(other)

    def __str__(self) -> str:
        return self._url

    def __repr__(self) -> str:
        # Never reveal a password in the representation.
        url = str(self)
        if self.password:
            url = str(self.replace(password="********"))
        return f"{self.__class__.__name__}({repr(url)})"

173 

174 

class URLPath(str):
    """
    A ``str`` containing a URL path, optionally tagged with a protocol
    (``"http"``, ``"websocket"`` or ``""``) and a host.

    Used by the routing to return `url_path_for` matches.
    """

    def __new__(cls, path: str, protocol: str = "", host: str = "") -> URLPath:
        assert protocol in ("http", "websocket", "")
        return str.__new__(cls, path)

    def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
        self.protocol = protocol
        self.host = host

    def make_absolute_url(self, base_url: str | URL) -> URL:
        """
        Resolve this path against ``base_url``.

        The scheme follows this path's protocol, upgraded to its secure form
        when the base is secure. With no protocol, the base's scheme is used.
        The host defaults to the base's netloc. The path is the base path
        (trailing slashes trimmed) followed by this path.
        """
        base = URL(base_url) if isinstance(base_url, str) else base_url

        if not self.protocol:
            scheme = base.scheme
        else:
            # (insecure, secure) scheme pair per protocol, picked by is_secure.
            scheme_pairs = {
                "http": ("http", "https"),
                "websocket": ("ws", "wss"),
            }
            insecure, secure = scheme_pairs[self.protocol]
            scheme = secure if base.is_secure else insecure

        return URL(
            scheme=scheme,
            netloc=self.host or base.netloc,
            path=base.path.rstrip("/") + str(self),
        )

203 

204 

class Secret:
    """
    A string value that ``repr()`` masks, so it stays out of tracebacks and logs.

    Convert it with ``str()`` at the point the real value is needed.
    """

    def __init__(self, value: str):
        self._value = value

    def __repr__(self) -> str:
        # Only the class name is shown; the value is replaced by asterisks.
        return f"{type(self).__name__}('**********')"

    def __str__(self) -> str:
        return self._value

    def __bool__(self) -> bool:
        # Truthiness mirrors the wrapped value (empty string -> False).
        return bool(self._value)

223 

224 

class CommaSeparatedStrings(typing.Sequence[str]):
    """
    A read-only sequence of strings.

    A ``str`` argument is split on commas using shell-style quoting (so a
    quoted item may itself contain commas), and each item is stripped of
    surrounding whitespace. Any other sequence is copied into a list unchanged.
    """

    def __init__(self, value: str | typing.Sequence[str]):
        if not isinstance(value, str):
            self._items = list(value)
            return

        lexer = shlex(value, posix=True)
        lexer.whitespace = ","
        lexer.whitespace_split = True
        self._items = list(map(str.strip, lexer))

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, index: int | slice) -> typing.Any:
        return self._items[index]

    def __iter__(self) -> typing.Iterator[str]:
        return iter(self._items)

    def __repr__(self) -> str:
        return f"{type(self).__name__}({list(self)!r})"

    def __str__(self) -> str:
        return ", ".join(map(repr, self))

251 

252 

class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]):
    """
    A read-only mapping that can hold several values for the same key.

    Every ``(key, value)`` pair is kept, in insertion order, in ``_list``; it is
    exposed by :meth:`multi_items` and :meth:`getlist`. Ordinary mapping
    access (``[]``, ``get``, ``keys``, ``values``, ``items``, ``len``) goes
    through ``_dict``, which holds only the last value seen for each key.
    """

    _dict: dict[_KeyType, _CovariantValueType]

    def __init__(
        self,
        *args: ImmutableMultiDict[_KeyType, _CovariantValueType]
        | typing.Mapping[_KeyType, _CovariantValueType]
        | typing.Iterable[tuple[_KeyType, _CovariantValueType]],
        **kwargs: typing.Any,
    ) -> None:
        """
        Accept at most one positional source and optional keyword pairs.

        The positional source may be another multidict, a mapping, or an
        iterable of ``(key, value)`` pairs. Keyword pairs are appended after
        the positional ones.
        """
        assert len(args) < 2, "Too many arguments."

        value: typing.Any = args[0] if args else []
        if kwargs:
            value = (
                ImmutableMultiDict(value).multi_items()
                + ImmutableMultiDict(kwargs).multi_items()
            )

        if not value:
            _items: list[tuple[typing.Any, typing.Any]] = []
        elif hasattr(value, "multi_items"):
            # Another multidict: keep every pair, not just the last per key.
            value = typing.cast(
                ImmutableMultiDict[_KeyType, _CovariantValueType], value
            )
            _items = list(value.multi_items())
        elif hasattr(value, "items"):
            value = typing.cast(typing.Mapping[_KeyType, _CovariantValueType], value)
            _items = list(value.items())
        else:
            value = typing.cast("list[tuple[typing.Any, typing.Any]]", value)
            _items = list(value)

        # Later pairs overwrite earlier ones, so _dict keeps the last value.
        self._dict = {k: v for k, v in _items}
        self._list = _items

    def getlist(self, key: typing.Any) -> list[_CovariantValueType]:
        """Return every value stored under ``key``, in order (empty if absent)."""
        return [item_value for item_key, item_value in self._list if item_key == key]

    def keys(self) -> typing.KeysView[_KeyType]:
        return self._dict.keys()

    def values(self) -> typing.ValuesView[_CovariantValueType]:
        return self._dict.values()

    def items(self) -> typing.ItemsView[_KeyType, _CovariantValueType]:
        return self._dict.items()

    def multi_items(self) -> list[tuple[_KeyType, _CovariantValueType]]:
        """Return a copy of all ``(key, value)`` pairs, duplicates included."""
        return list(self._list)

    def __getitem__(self, key: _KeyType) -> _CovariantValueType:
        return self._dict[key]

    def __contains__(self, key: typing.Any) -> bool:
        return key in self._dict

    def __iter__(self) -> typing.Iterator[_KeyType]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._dict)

    def __eq__(self, other: typing.Any) -> bool:
        """
        Equal when ``other`` is an instance of this class holding the same
        pairs, regardless of order.
        """
        if not isinstance(other, self.__class__):
            return False
        try:
            return sorted(self._list) == sorted(other._list)
        except TypeError:
            # Some pairs are not mutually orderable, e.g. two UploadFile
            # values under the same form field, so sorting fails. Fall back to
            # an order-insensitive multiset comparison that only needs `==`.
            if len(self._list) != len(other._list):
                return False
            unmatched = list(other._list)
            for item in self._list:
                try:
                    unmatched.remove(item)
                except ValueError:
                    return False
            return True

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        items = self.multi_items()
        return f"{class_name}({items!r})"

325 

326 

class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]):
    """
    A mutable multidict.

    Each mutation updates both the ordered pair list (``_list``) and the
    last-value-per-key dict (``_dict``) so they stay consistent.
    """

    def _pairs_without(self, key: typing.Any) -> list[tuple[typing.Any, typing.Any]]:
        # A fresh list of the stored pairs, minus every pair whose key == `key`.
        return [(k, v) for k, v in self._list if k != key]

    def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
        """Replace all values stored under ``key`` with the single ``value``."""
        self.setlist(key, [value])

    def __delitem__(self, key: typing.Any) -> None:
        """Remove every value for ``key``; KeyError if the key is absent."""
        self._list = self._pairs_without(key)
        del self._dict[key]

    def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
        """Remove every value for ``key``; return its last value or ``default``."""
        self._list = self._pairs_without(key)
        return self._dict.pop(key, default)

    def popitem(self) -> tuple[typing.Any, typing.Any]:
        """
        Remove the key chosen by ``dict.popitem`` on ``_dict``, with all of its
        values, and return ``(key, last_value)``.
        """
        key, value = self._dict.popitem()
        self._list = self._pairs_without(key)
        return key, value

    def poplist(self, key: typing.Any) -> list[typing.Any]:
        """Remove ``key`` and return all of its values in order ([] if absent)."""
        removed = [stored for name, stored in self._list if name == key]
        self.pop(key)
        return removed

    def clear(self) -> None:
        self._list.clear()
        self._dict.clear()

    def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
        """If ``key`` is absent, store ``default`` for it; return the key's value."""
        if key in self:
            return self[key]
        self._dict[key] = default
        self._list.append((key, default))
        return self[key]

    def setlist(self, key: typing.Any, values: list[typing.Any]) -> None:
        """
        Replace all values for ``key`` with ``values``. The new pairs go at the
        end. An empty ``values`` removes the key.
        """
        if not values:
            self.pop(key, None)
            return
        self._list = self._pairs_without(key) + [(key, item) for item in values]
        self._dict[key] = values[-1]

    def append(self, key: typing.Any, value: typing.Any) -> None:
        """Add one more value for ``key``, keeping any existing ones."""
        self._dict[key] = value
        self._list.append((key, value))

    def update(
        self,
        *args: MultiDict
        | typing.Mapping[typing.Any, typing.Any]
        | list[tuple[typing.Any, typing.Any]],
        **kwargs: typing.Any,
    ) -> None:
        """
        Merge in new pairs. Every key present in the incoming data loses its old
        values; the incoming pairs are appended in order.
        """
        incoming = MultiDict(*args, **kwargs)
        incoming_keys = incoming.keys()
        kept = [(k, v) for (k, v) in self._list if k not in incoming_keys]
        self._list = kept + incoming.multi_items()
        self._dict.update(incoming)

383 

384 

class QueryParams(ImmutableMultiDict[str, str]):
    """
    An immutable multidict.

    Built from a query string (``str``, or ``bytes`` decoded as latin-1) or
    from anything ``ImmutableMultiDict`` accepts. Every key and value is
    coerced to ``str``.
    """

    def __init__(
        self,
        *args: ImmutableMultiDict[typing.Any, typing.Any]
        | typing.Mapping[typing.Any, typing.Any]
        | list[tuple[typing.Any, typing.Any]]
        | str
        | bytes,
        **kwargs: typing.Any,
    ) -> None:
        assert len(args) < 2, "Too many arguments."

        source = args[0] if args else []
        if isinstance(source, bytes):
            source = source.decode("latin-1")

        if isinstance(source, str):
            # Blank values ("a=&b") are preserved as empty strings.
            super().__init__(parse_qsl(source, keep_blank_values=True), **kwargs)
        else:
            super().__init__(*args, **kwargs)  # type: ignore[arg-type]

        self._list = [(str(name), str(text)) for name, text in self._list]
        self._dict = {str(name): str(text) for name, text in self._dict.items()}

    def __str__(self) -> str:
        return urlencode(self._list)

    def __repr__(self) -> str:
        return f"{type(self).__name__}({str(self)!r})"

421 

422 

class UploadFile:
    """
    An uploaded file included as part of the request data.

    Wraps a binary file object together with its ``filename``, ``size`` and
    ``headers``. The async ``write``/``read``/``seek``/``close`` helpers call
    the file directly while it is still in memory, i.e. while
    ``SpooledTemporaryFile._rolled`` is false. Otherwise they run the call
    through ``run_in_threadpool``.
    """

    def __init__(
        self,
        file: typing.BinaryIO,
        *,
        size: int | None = None,
        filename: str | None = None,
        headers: Headers | None = None,
    ) -> None:
        self.file = file
        self.filename = filename
        self.size = size
        # Missing (or empty) headers fall back to a fresh, empty Headers.
        self.headers = headers or Headers()

    @property
    def content_type(self) -> str | None:
        return self.headers.get("content-type")

    @property
    def _in_memory(self) -> bool:
        # Only SpooledTemporaryFile exposes `_rolled`; a file object without
        # it is treated as rolled to disk, i.e. potentially blocking.
        return not getattr(self.file, "_rolled", True)

    async def write(self, data: bytes) -> None:
        """Write ``data``; when ``size`` is tracked (not None), add its length."""
        if self.size is not None:
            self.size += len(data)

        if not self._in_memory:
            await run_in_threadpool(self.file.write, data)
            return
        self.file.write(data)

    async def read(self, size: int = -1) -> bytes:
        """Read up to ``size`` bytes (everything when ``size`` is -1)."""
        if not self._in_memory:
            return await run_in_threadpool(self.file.read, size)
        return self.file.read(size)

    async def seek(self, offset: int) -> None:
        if not self._in_memory:
            await run_in_threadpool(self.file.seek, offset)
            return
        self.file.seek(offset)

    async def close(self) -> None:
        if not self._in_memory:
            await run_in_threadpool(self.file.close)
            return
        self.file.close()

    def __repr__(self) -> str:
        return "{}(filename={!r}, size={!r}, headers={!r})".format(
            type(self).__name__, self.filename, self.size, self.headers
        )

484 

485 

class FormData(ImmutableMultiDict[str, typing.Union[UploadFile, str]]):
    """
    An immutable multidict, containing both file uploads and text input.
    """

    def __init__(
        self,
        *args: FormData
        | typing.Mapping[str, str | UploadFile]
        | list[tuple[str, str | UploadFile]],
        **kwargs: str | UploadFile,
    ) -> None:
        super().__init__(*args, **kwargs)

    async def close(self) -> None:
        """Close each UploadFile value, one at a time, in insertion order."""
        uploads = [
            field_value
            for _, field_value in self.multi_items()
            if isinstance(field_value, UploadFile)
        ]
        for upload in uploads:
            await upload.close()

504 

505 

class Headers(typing.Mapping[str, str]):
    """
    An immutable, case-insensitive multidict.

    Headers are stored as an ordered list of ``(name, value)`` byte pairs.
    Lookups lower-case and latin-1 encode the requested name before comparing.
    Names and values are latin-1 decoded when read back out.
    """

    def __init__(
        self,
        headers: typing.Mapping[str, str] | None = None,
        raw: list[tuple[bytes, bytes]] | None = None,
        scope: typing.MutableMapping[str, typing.Any] | None = None,
    ) -> None:
        """
        Build from at most one source; with none given, the headers are empty.

        * ``headers`` -- a str mapping; names are lower-cased.
        * ``raw`` -- a list of byte pairs, used as-is (not copied).
        * ``scope`` -- its ``"headers"`` entry is converted to a list, which is
          written back into the scope and shared.
        """
        self._list: list[tuple[bytes, bytes]] = []
        if headers is not None:
            assert raw is None, 'Cannot set both "headers" and "raw".'
            assert scope is None, 'Cannot set both "headers" and "scope".'
            self._list = [
                (name.lower().encode("latin-1"), text.encode("latin-1"))
                for name, text in headers.items()
            ]
        elif raw is not None:
            assert scope is None, 'Cannot set both "raw" and "scope".'
            self._list = raw
        elif scope is not None:
            # The scope may carry headers as a tuple or other iterable; a list
            # is stored back so in-place changes to _list are seen by the scope.
            self._list = scope["headers"] = list(scope["headers"])

    @property
    def raw(self) -> list[tuple[bytes, bytes]]:
        """A shallow copy of the underlying byte pairs."""
        return list(self._list)

    def keys(self) -> list[str]:  # type: ignore[override]
        return [name.decode("latin-1") for name, _ in self._list]

    def values(self) -> list[str]:  # type: ignore[override]
        return [text.decode("latin-1") for _, text in self._list]

    def items(self) -> list[tuple[str, str]]:  # type: ignore[override]
        return [
            (name.decode("latin-1"), text.decode("latin-1"))
            for name, text in self._list
        ]

    def getlist(self, key: str) -> list[str]:
        """Return every value for ``key`` (case-insensitive), in order."""
        wanted = key.lower().encode("latin-1")
        return [text.decode("latin-1") for name, text in self._list if name == wanted]

    def mutablecopy(self) -> MutableHeaders:
        return MutableHeaders(raw=list(self._list))

    def __getitem__(self, key: str) -> str:
        # The first matching header wins.
        wanted = key.lower().encode("latin-1")
        for name, text in self._list:
            if name == wanted:
                return text.decode("latin-1")
        raise KeyError(key)

    def __contains__(self, key: typing.Any) -> bool:
        wanted = key.lower().encode("latin-1")
        return any(name == wanted for name, _ in self._list)

    def __iter__(self) -> typing.Iterator[typing.Any]:
        return iter(self.keys())

    def __len__(self) -> int:
        # Counts every pair, duplicates included.
        return len(self._list)

    def __eq__(self, other: typing.Any) -> bool:
        """Order-insensitive comparison of the raw pairs; only against Headers."""
        if isinstance(other, Headers):
            return sorted(self._list) == sorted(other._list)
        return False

    def __repr__(self) -> str:
        # Use the compact dict form unless duplicate names would be lost.
        name = type(self).__name__
        as_dict = dict(self.items())
        if len(as_dict) != len(self):
            return f"{name}(raw={self.raw!r})"
        return f"{name}({as_dict!r})"

591 

592 

class MutableHeaders(Headers):
    """
    A mutable variant of Headers.

    All changes are made in place on the underlying pair list, which
    :attr:`raw` exposes directly.
    """

    def __setitem__(self, key: str, value: str) -> None:
        """
        Set the header `key` to `value`, removing any duplicate entries.
        Retains insertion order.
        """
        name = key.lower().encode("latin-1")
        encoded = value.encode("latin-1")

        matches = [i for i, (stored, _) in enumerate(self._list) if stored == name]
        if not matches:
            self._list.append((name, encoded))
            return

        first, *duplicates = matches
        # Delete from the back so the remaining indexes stay valid.
        for i in reversed(duplicates):
            del self._list[i]
        self._list[first] = (name, encoded)

    def __delitem__(self, key: str) -> None:
        """
        Remove the header `key` (every occurrence; no error if absent).
        """
        name = key.lower().encode("latin-1")
        positions = [i for i, (stored, _) in enumerate(self._list) if stored == name]
        for i in positions[::-1]:
            del self._list[i]

    def __ior__(self, other: typing.Mapping[str, str]) -> MutableHeaders:
        if not isinstance(other, typing.Mapping):
            raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
        self.update(other)
        return self

    def __or__(self, other: typing.Mapping[str, str]) -> MutableHeaders:
        if not isinstance(other, typing.Mapping):
            raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
        merged = self.mutablecopy()
        merged.update(other)
        return merged

    @property
    def raw(self) -> list[tuple[bytes, bytes]]:
        """The live underlying pair list (not a copy)."""
        return self._list

    def setdefault(self, key: str, value: str) -> str:
        """
        If the header `key` does not exist, then set it to `value`.
        Returns the header value.
        """
        name = key.lower().encode("latin-1")
        encoded = value.encode("latin-1")

        for stored, stored_value in self._list:
            if stored == name:
                return stored_value.decode("latin-1")
        self._list.append((name, encoded))
        return value

    def update(self, other: typing.Mapping[str, str]) -> None:
        """Set each header from ``other`` as with ``headers[key] = value``."""
        for key, val in other.items():
            self[key] = val

    def append(self, key: str, value: str) -> None:
        """
        Append a header, preserving any duplicate entries.
        """
        self._list.append((key.lower().encode("latin-1"), value.encode("latin-1")))

    def add_vary_header(self, vary: str) -> None:
        """Add ``vary`` to the Vary header, joining onto any existing value with ", "."""
        existing = self.get("vary")
        self["vary"] = vary if existing is None else ", ".join([existing, vary])

678 

679 

class State:
    """
    An object that can be used to store arbitrary state.

    Used for `request.state` and `app.state`.

    Attribute reads, writes and deletes are forwarded to the ``_state`` dict,
    which may be supplied by the caller and is used without copying.
    """

    _state: dict[str, typing.Any]

    def __init__(self, state: dict[str, typing.Any] | None = None):
        if state is None:
            state = {}
        # Bypass our own __setattr__, which would store into _state itself.
        super().__setattr__("_state", state)

    def __setattr__(self, key: typing.Any, value: typing.Any) -> None:
        self._state[key] = value

    def __getattr__(self, key: typing.Any) -> typing.Any:
        # Read `_state` from the instance __dict__ instead of via `self._state`.
        # copy.copy / pickle create instances without running __init__ and then
        # probe attributes such as `__setstate__`. On such an instance,
        # `self._state` would re-enter __getattr__ and recurse until
        # RecursionError. An empty dict turns those probes into AttributeError.
        state = self.__dict__.get("_state", {})
        try:
            return state[key]
        except KeyError:
            message = "'{}' object has no attribute '{}'"
            raise AttributeError(message.format(self.__class__.__name__, key))

    def __delattr__(self, key: typing.Any) -> None:
        # A missing key raises KeyError (from the dict), not AttributeError.
        del self._state[key]