Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/starlette/datastructures.py: 31%

431 statements  

« prev     ^ index     » next       coverage.py v7.2.2, created at 2023-03-26 06:12 +0000

1import typing 

2from collections.abc import Sequence 

3from shlex import shlex 

4from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit 

5 

6from starlette.concurrency import run_in_threadpool 

7from starlette.types import Scope 

8 

9 

# A network endpoint represented as a (host, port) pair; behaves like a tuple
# with named fields `host` (str) and `port` (int).
Address = typing.NamedTuple("Address", [("host", str), ("port", int)])

13 

14 

# Type variables for the read-only multidict mappings.
#
# Keys are invariant. Values are declared covariant because a read-only
# mapping never accepts assignments: you cannot write
# `Mapping[str, Animal]()["fido"] = Dog()`, so only reading values is possible.
_KeyType = typing.TypeVar("_KeyType")
_CovariantValueType = typing.TypeVar("_CovariantValueType", covariant=True)

20 

21 

class URL:
    """
    A URL string with lazily parsed components and helpers that build
    modified copies.

    Construct it from exactly one of:

    * ``url`` - a URL string, stored as-is;
    * ``scope`` - an ASGI connection scope; the URL is rebuilt from the first
      ``host`` header, or from ``server`` when there is no such header, or
      from the bare path when neither is present;
    * ``**components`` - keyword arguments accepted by :meth:`replace`
      (``scheme``, ``netloc``, ``path``, ``query``, ``fragment``, ``username``,
      ``password``, ``hostname``, ``port``).

    Passing more than one source fails an ``assert``.
    """

    def __init__(
        self,
        url: str = "",
        scope: "typing.Optional[Scope]" = None,
        **components: typing.Any,
    ) -> None:
        if scope is not None:
            assert not url, 'Cannot set both "url" and "scope".'
            assert not components, 'Cannot set both "scope" and "**components".'
            scheme = scope.get("scheme", "http")
            server = scope.get("server", None)
            path = scope.get("root_path", "") + scope["path"]
            query_string = scope.get("query_string", b"")

            # Use the first "host" request header; scope header names are bytes.
            host_header = None
            for key, value in scope["headers"]:
                if key == b"host":
                    host_header = value.decode("latin-1")
                    break

            if host_header is not None:
                url = f"{scheme}://{host_header}{path}"
            elif server is None:
                url = path
            else:
                host, port = server
                # Omit the port only when it is the scheme's well-known default.
                # Unknown schemes have no default (None), so their port is kept
                # instead of the lookup raising KeyError.
                default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}.get(
                    scheme
                )
                if port == default_port:
                    url = f"{scheme}://{host}{path}"
                else:
                    url = f"{scheme}://{host}:{port}{path}"

            if query_string:
                url += "?" + query_string.decode()
        elif components:
            assert not url, 'Cannot set both "url" and "**components".'
            url = URL("").replace(**components).components.geturl()

        self._url = url

    @property
    def components(self) -> SplitResult:
        """The parsed URL, computed on first access and then cached."""
        if not hasattr(self, "_components"):
            self._components = urlsplit(self._url)
        return self._components

    @property
    def scheme(self) -> str:
        return self.components.scheme

    @property
    def netloc(self) -> str:
        return self.components.netloc

    @property
    def path(self) -> str:
        return self.components.path

    @property
    def query(self) -> str:
        return self.components.query

    @property
    def fragment(self) -> str:
        return self.components.fragment

    @property
    def username(self) -> typing.Union[None, str]:
        return self.components.username

    @property
    def password(self) -> typing.Union[None, str]:
        return self.components.password

    @property
    def hostname(self) -> typing.Union[None, str]:
        return self.components.hostname

    @property
    def port(self) -> typing.Optional[int]:
        return self.components.port

    @property
    def is_secure(self) -> bool:
        """True for the ``https`` and ``wss`` schemes."""
        return self.scheme in ("https", "wss")

    def replace(self, **kwargs: typing.Any) -> "URL":
        """
        Return a new URL with the given components replaced.

        Accepts any ``SplitResult`` field (``scheme``, ``netloc``, ``path``,
        ``query``, ``fragment``) plus ``username``, ``password``, ``hostname``
        and ``port``. If any of the latter four is given, ``netloc`` is rebuilt
        from them, and omitted ones keep their current values. A password is
        only emitted when a username is present.
        """
        if (
            "username" in kwargs
            or "password" in kwargs
            or "hostname" in kwargs
            or "port" in kwargs
        ):
            hostname = kwargs.pop("hostname", None)
            port = kwargs.pop("port", self.port)
            username = kwargs.pop("username", self.username)
            password = kwargs.pop("password", self.password)

            if hostname is None:
                netloc = self.netloc
                _, _, hostname = netloc.rpartition("@")

                # Strip a trailing ":port" unless the host is a bracketed IPv6
                # literal such as "[::1]". endswith() also copes with an empty
                # netloc (e.g. a path-only URL), where hostname[-1] raised IndexError.
                if not hostname.endswith("]"):
                    hostname = hostname.rsplit(":", 1)[0]

            netloc = hostname
            if port is not None:
                netloc += f":{port}"
            if username is not None:
                userpass = username
                if password is not None:
                    userpass += f":{password}"
                netloc = f"{userpass}@{netloc}"

            kwargs["netloc"] = netloc

        components = self.components._replace(**kwargs)
        return self.__class__(components.geturl())

    def include_query_params(self, **kwargs: typing.Any) -> "URL":
        """
        Return a copy with the given query parameters set.

        Values are converted with ``str()``. Existing values for those keys are
        replaced, and all other parameters are kept.
        """
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        params.update({str(key): str(value) for key, value in kwargs.items()})
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def replace_query_params(self, **kwargs: typing.Any) -> "URL":
        """Return a copy whose query string holds only the given parameters."""
        query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
        return self.replace(query=query)

    def remove_query_params(
        self, keys: typing.Union[str, typing.Sequence[str]]
    ) -> "URL":
        """Return a copy without any query parameters named in ``keys``."""
        if isinstance(keys, str):
            keys = [keys]
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        for key in keys:
            params.pop(key, None)
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def __eq__(self, other: typing.Any) -> bool:
        # Compares string forms, so a URL also equals its plain string.
        return str(self) == str(other)

    def __str__(self) -> str:
        return self._url

    def __repr__(self) -> str:
        # Mask any password so it does not leak into logs or tracebacks.
        url = str(self)
        if self.password:
            url = str(self.replace(password="********"))
        return f"{self.__class__.__name__}({repr(url)})"

174 

175 

class URLPath(str):
    """
    A path string, optionally tagged with the protocol ("http" or "websocket")
    and host it should be served from. Routing returns these from
    `url_path_for`.
    """

    def __new__(cls, path: str, protocol: str = "", host: str = "") -> "URLPath":
        # Only these protocol tags are understood by make_absolute_url().
        assert protocol in ("http", "websocket", "")
        return str.__new__(cls, path)

    def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
        self.host = host
        self.protocol = protocol

    def make_absolute_url(self, base_url: typing.Union[str, URL]) -> str:
        """
        Join this path onto ``base_url`` and return the absolute URL string.

        The base URL's path is kept as a prefix (minus any trailing slash).
        When ``protocol`` is set, the scheme is its plain or secure variant,
        matching whether the base URL is secure. Otherwise the base URL's
        scheme is reused. A non-empty ``host`` overrides the base URL's netloc.
        """
        base = URL(base_url) if isinstance(base_url, str) else base_url

        if self.protocol:
            # (insecure, secure) scheme pair, indexed by the boolean is_secure.
            scheme_pair = {
                "http": ("http", "https"),
                "websocket": ("ws", "wss"),
            }[self.protocol]
            scheme = scheme_pair[base.is_secure]
        else:
            scheme = base.scheme

        return str(
            URL(
                scheme=scheme,
                netloc=self.host or base.netloc,
                path=base.path.rstrip("/") + str(self),
            )
        )

204 

205 

class Secret:
    """
    Wraps a sensitive string so that ``repr()`` masks it (keeping it out of
    tracebacks and similar output). Call ``str()`` on the object at the point
    the real value is needed.
    """

    def __init__(self, value: str):
        self._value = value

    def __repr__(self) -> str:
        return "{}('**********')".format(type(self).__name__)

    def __str__(self) -> str:
        return self._value

    def __bool__(self) -> bool:
        # Truthiness follows the wrapped value (an empty secret is falsy).
        return bool(self._value)

224 

225 

class CommaSeparatedStrings(Sequence):
    """
    An immutable sequence of strings.

    Built either from an existing sequence (copied into a list) or by
    splitting a string on commas with POSIX shell-style quoting, stripping
    surrounding whitespace from each item.
    """

    def __init__(self, value: typing.Union[str, typing.Sequence[str]]):
        if not isinstance(value, str):
            self._items = list(value)
            return

        # Commas are the only separators; quoted sections may contain commas.
        lexer = shlex(value, posix=True)
        lexer.whitespace = ","
        lexer.whitespace_split = True
        self._items = list(map(str.strip, lexer))

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, index: typing.Union[int, slice]) -> typing.Any:
        return self._items[index]

    def __iter__(self) -> typing.Iterator[str]:
        return iter(self._items)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({list(self)!r})"

    def __str__(self) -> str:
        return ", ".join(map(repr, self))

249 

250 def __str__(self) -> str: 

251 return ", ".join(repr(item) for item in self) 

252 

253 

class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]):
    """
    A read-only mapping that can hold several values for the same key.

    ``_list`` keeps every ``(key, value)`` pair in insertion order. ``_dict``
    maps each key to the last value given for it, and backs the plain mapping
    interface (``[]``, ``in``, ``keys()``, ``len()`` ...). Use ``getlist()`` or
    ``multi_items()`` to see every value.
    """

    _dict: typing.Dict[_KeyType, _CovariantValueType]

    def __init__(
        self,
        *args: typing.Union[
            "ImmutableMultiDict[_KeyType, _CovariantValueType]",
            typing.Mapping[_KeyType, _CovariantValueType],
            typing.Iterable[typing.Tuple[_KeyType, _CovariantValueType]],
        ],
        **kwargs: typing.Any,
    ) -> None:
        assert len(args) < 2, "Too many arguments."

        source: typing.Any = args[0] if args else []
        if kwargs:
            # Keyword arguments are appended after the positional items.
            source = (
                ImmutableMultiDict(source).multi_items()
                + ImmutableMultiDict(kwargs).multi_items()  # type: ignore[operator]
            )

        pairs: typing.List[typing.Tuple[typing.Any, typing.Any]]
        if not source:
            pairs = []
        elif hasattr(source, "multi_items"):
            # Another multidict: keep its duplicate keys.
            pairs = list(source.multi_items())
        elif hasattr(source, "items"):
            # A plain mapping: one value per key.
            pairs = list(source.items())
        else:
            # Any other iterable of (key, value) pairs.
            pairs = list(source)

        self._dict = {key: val for key, val in pairs}
        self._list = pairs

    def getlist(self, key: typing.Any) -> typing.List[_CovariantValueType]:
        """Return every value stored for ``key`` in insertion order ([] if absent)."""
        return [val for stored_key, val in self._list if stored_key == key]

    def keys(self) -> typing.KeysView[_KeyType]:
        return self._dict.keys()

    def values(self) -> typing.ValuesView[_CovariantValueType]:
        return self._dict.values()

    def items(self) -> typing.ItemsView[_KeyType, _CovariantValueType]:
        return self._dict.items()

    def multi_items(self) -> typing.List[typing.Tuple[_KeyType, _CovariantValueType]]:
        """Return a new list of every (key, value) pair, duplicates included."""
        return list(self._list)

    def __getitem__(self, key: _KeyType) -> _CovariantValueType:
        return self._dict[key]

    def __contains__(self, key: typing.Any) -> bool:
        return key in self._dict

    def __iter__(self) -> typing.Iterator[_KeyType]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._dict)

    def __eq__(self, other: typing.Any) -> bool:
        # Equal only to instances of this class (or a subclass) that hold the
        # same pairs, regardless of order.
        if isinstance(other, self.__class__):
            return sorted(self._list) == sorted(other._list)
        return False

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.multi_items()!r})"

330 

331 

class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]):
    """
    A mutable multidict.

    Every mutator keeps ``_list`` (all pairs, in order) and ``_dict`` (last
    value per key) in step with each other.
    """

    def _pairs_without(self, key: typing.Any) -> typing.List[typing.Tuple[typing.Any, typing.Any]]:
        """Return the stored pairs, as tuples, whose key differs from ``key``."""
        return [(name, val) for name, val in self._list if name != key]

    def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
        # Assignment replaces every existing value for the key.
        self.setlist(key, [value])

    def __delitem__(self, key: typing.Any) -> None:
        self._list = self._pairs_without(key)
        del self._dict[key]

    def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
        """Remove all values for ``key``; return its last value, or ``default``."""
        self._list = self._pairs_without(key)
        return self._dict.pop(key, default)

    def popitem(self) -> typing.Tuple:
        """Remove the most recently added key and return (key, last value)."""
        key, value = self._dict.popitem()
        self._list = self._pairs_without(key)
        return key, value

    def poplist(self, key: typing.Any) -> typing.List:
        """Remove ``key`` and return all of its values in order."""
        values = [val for name, val in self._list if name == key]
        self.pop(key)
        return values

    def clear(self) -> None:
        self._dict.clear()
        self._list.clear()

    def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
        if key not in self:
            self._dict[key] = default
            self._list.append((key, default))

        return self[key]

    def setlist(self, key: typing.Any, values: typing.List) -> None:
        """Replace all values for ``key`` with ``values``; an empty list removes it."""
        if not values:
            self.pop(key, None)
            return
        self._list = self._pairs_without(key) + [(key, val) for val in values]
        self._dict[key] = values[-1]

    def append(self, key: typing.Any, value: typing.Any) -> None:
        """Add one more value for ``key``, keeping any existing ones."""
        self._list.append((key, value))
        self._dict[key] = value

    def update(
        self,
        *args: typing.Union[
            "MultiDict",
            typing.Mapping,
            typing.List[typing.Tuple[typing.Any, typing.Any]],
        ],
        **kwargs: typing.Any,
    ) -> None:
        """
        Merge in new items. Every key present in the update loses all of its
        old values, and the new pairs are appended at the end.
        """
        incoming = MultiDict(*args, **kwargs)
        incoming_keys = incoming.keys()
        kept = [(name, val) for (name, val) in self._list if name not in incoming_keys]
        self._list = kept + incoming.multi_items()
        self._dict.update(incoming)

390 

391 

class QueryParams(ImmutableMultiDict[str, str]):
    """
    An immutable multidict of query parameters.

    Accepts a raw query string (``str``, or ``bytes`` decoded as latin-1),
    a mapping, another multidict or a list of pairs. All keys and values are
    converted to ``str``, and ``str()`` re-encodes them as a query string.
    """

    def __init__(
        self,
        *args: typing.Union[
            "ImmutableMultiDict",
            typing.Mapping,
            typing.List[typing.Tuple[typing.Any, typing.Any]],
            str,
            bytes,
        ],
        **kwargs: typing.Any,
    ) -> None:
        assert len(args) < 2, "Too many arguments."

        value = args[0] if args else []

        if isinstance(value, bytes):
            value = value.decode("latin-1")

        if isinstance(value, str):
            # Blank values (e.g. "a=&b=1") are kept as empty strings.
            super().__init__(parse_qsl(value, keep_blank_values=True), **kwargs)
        else:
            super().__init__(*args, **kwargs)  # type: ignore[arg-type]

        self._list = [(str(name), str(val)) for name, val in self._list]
        self._dict = {str(name): str(val) for name, val in self._dict.items()}

    def __str__(self) -> str:
        return urlencode(self._list)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({str(self)!r})"

430 

431 

class UploadFile:
    """
    An uploaded file included as part of the request data.

    The async helpers call the underlying file directly while its data is
    held in memory. Otherwise they hand the blocking call to
    ``run_in_threadpool``.
    """

    def __init__(
        self,
        file: typing.BinaryIO,
        *,
        size: typing.Optional[int] = None,
        filename: typing.Optional[str] = None,
        headers: "typing.Optional[Headers]" = None,
    ) -> None:
        self.filename = filename
        self.file = file
        self.size = size
        self.headers = headers if headers else Headers()

    @property
    def content_type(self) -> typing.Optional[str]:
        return self.headers.get("content-type", None)

    @property
    def _in_memory(self) -> bool:
        # SpooledTemporaryFile sets `_rolled` once its data moves to disk. A
        # file without that attribute is conservatively treated as on-disk.
        return not getattr(self.file, "_rolled", True)

    async def _run_file_op(
        self, operation: typing.Callable[..., typing.Any], *args: typing.Any
    ) -> typing.Any:
        """Call ``operation(*args)`` inline when in memory, else in the threadpool."""
        if self._in_memory:
            return operation(*args)
        return await run_in_threadpool(operation, *args)

    async def write(self, data: bytes) -> None:
        # Only track the size when the caller started with a known size.
        if self.size is not None:
            self.size += len(data)
        await self._run_file_op(self.file.write, data)

    async def read(self, size: int = -1) -> bytes:
        return await self._run_file_op(self.file.read, size)

    async def seek(self, offset: int) -> None:
        await self._run_file_op(self.file.seek, offset)

    async def close(self) -> None:
        await self._run_file_op(self.file.close)

485 

486 

class FormData(ImmutableMultiDict[str, typing.Union[UploadFile, str]]):
    """
    An immutable multidict of submitted form fields. Each value is either a
    plain string or an ``UploadFile``.
    """

    def __init__(
        self,
        *args: typing.Union[
            "FormData",
            typing.Mapping[str, typing.Union[str, UploadFile]],
            typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]],
        ],
        **kwargs: typing.Union[str, UploadFile],
    ) -> None:
        super().__init__(*args, **kwargs)

    async def close(self) -> None:
        """Close every uploaded file value, in the order the fields were stored."""
        uploads = [
            item for _, item in self.multi_items() if isinstance(item, UploadFile)
        ]
        for upload in uploads:
            await upload.close()

507 

508 

class Headers(typing.Mapping[str, str]):
    """
    An immutable, case-insensitive multidict of headers.

    Headers are held as a list of raw ``(name, value)`` byte pairs and decoded
    as latin-1 on access. Lookups lower-case the requested name before
    comparing. Names given through the ``headers`` mapping are lower-cased on
    the way in; ``raw`` and ``scope`` pairs are used as-is.
    """

    def __init__(
        self,
        headers: typing.Optional[typing.Mapping[str, str]] = None,
        raw: typing.Optional[typing.List[typing.Tuple[bytes, bytes]]] = None,
        scope: typing.Optional[typing.MutableMapping[str, typing.Any]] = None,
    ) -> None:
        self._list: typing.List[typing.Tuple[bytes, bytes]] = []

        if headers is not None:
            assert raw is None, 'Cannot set both "headers" and "raw".'
            assert scope is None, 'Cannot set both "headers" and "scope".'
            self._list = [
                (name.lower().encode("latin-1"), val.encode("latin-1"))
                for name, val in headers.items()
            ]
            return

        if raw is not None:
            assert scope is None, 'Cannot set both "raw" and "scope".'
            # Kept by reference, not copied.
            self._list = raw
            return

        if scope is not None:
            # scope["headers"] may be any iterable (e.g. a tuple). Replace it
            # with a list that this object also holds, so both share one list.
            shared = list(scope["headers"])
            self._list = shared
            scope["headers"] = shared

    @property
    def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
        """A shallow copy of the raw (name, value) byte pairs."""
        return list(self._list)

    def keys(self) -> typing.List[str]:  # type: ignore[override]
        return [name.decode("latin-1") for name, _ in self._list]

    def values(self) -> typing.List[str]:  # type: ignore[override]
        return [val.decode("latin-1") for _, val in self._list]

    def items(self) -> typing.List[typing.Tuple[str, str]]:  # type: ignore[override]
        return [
            (name.decode("latin-1"), val.decode("latin-1"))
            for name, val in self._list
        ]

    def getlist(self, key: str) -> typing.List[str]:
        """Return every value for ``key`` (case-insensitive), in order."""
        wanted = key.lower().encode("latin-1")
        return [val.decode("latin-1") for name, val in self._list if name == wanted]

    def mutablecopy(self) -> "MutableHeaders":
        return MutableHeaders(raw=self._list[:])

    def __getitem__(self, key: str) -> str:
        # Returns the first matching header's value.
        wanted = key.lower().encode("latin-1")
        for name, val in self._list:
            if name == wanted:
                break
        else:
            raise KeyError(key)
        return val.decode("latin-1")

    def __contains__(self, key: typing.Any) -> bool:
        wanted = key.lower().encode("latin-1")
        return any(name == wanted for name, _ in self._list)

    def __iter__(self) -> typing.Iterator[typing.Any]:
        return iter(self.keys())

    def __len__(self) -> int:
        # Counts every pair, duplicates included.
        return len(self._list)

    def __eq__(self, other: typing.Any) -> bool:
        # Order-insensitive comparison of the raw pairs.
        if isinstance(other, Headers):
            return sorted(self._list) == sorted(other._list)
        return False

    def __repr__(self) -> str:
        # Show a plain dict when that is lossless, i.e. there are no duplicate names.
        name = type(self).__name__
        as_dict = dict(self.items())
        if len(as_dict) != len(self):
            return f"{name}(raw={self.raw!r})"
        return f"{name}({as_dict!r})"

594 

595 

class MutableHeaders(Headers):
    """
    A case-insensitive header multidict that can be modified in place.

    ``raw`` returns the underlying list itself rather than a copy, so changes
    made through this object are visible to anything sharing that list.
    """

    def __setitem__(self, key: str, value: str) -> None:
        """
        Set the header `key` to `value`, removing any duplicate entries.
        Retains insertion order.
        """
        name = key.lower().encode("latin-1")
        encoded = value.encode("latin-1")

        positions = [
            index for index, (stored, _) in enumerate(self._list) if stored == name
        ]
        if not positions:
            self._list.append((name, encoded))
            return

        # Delete later duplicates back-to-front so earlier indexes stay valid,
        # then overwrite the first occurrence so the header keeps its position.
        for index in reversed(positions[1:]):
            del self._list[index]
        self._list[positions[0]] = (name, encoded)

    def __delitem__(self, key: str) -> None:
        """
        Remove the header `key`.
        """
        name = key.lower().encode("latin-1")
        positions = [
            index for index, (stored, _) in enumerate(self._list) if stored == name
        ]
        # Back-to-front deletion keeps the remaining indexes valid and mutates
        # the existing list object in place.
        for index in reversed(positions):
            del self._list[index]

    def __ior__(self, other: typing.Mapping[str, str]) -> "MutableHeaders":
        if isinstance(other, typing.Mapping):
            self.update(other)
            return self
        raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")

    def __or__(self, other: typing.Mapping[str, str]) -> "MutableHeaders":
        if isinstance(other, typing.Mapping):
            merged = self.mutablecopy()
            merged.update(other)
            return merged
        raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")

    @property
    def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
        return self._list

    def setdefault(self, key: str, value: str) -> str:
        """
        If the header `key` does not exist, then set it to `value`.
        Returns the header value.
        """
        name = key.lower().encode("latin-1")
        encoded = value.encode("latin-1")

        for stored, stored_value in self._list:
            if stored == name:
                return stored_value.decode("latin-1")
        self._list.append((name, encoded))
        return value

    def update(self, other: typing.Mapping[str, str]) -> None:
        """Set each header from ``other``, replacing existing values."""
        for name, val in other.items():
            self[name] = val

    def append(self, key: str, value: str) -> None:
        """
        Append a header, preserving any duplicate entries.
        """
        self._list.append((key.lower().encode("latin-1"), value.encode("latin-1")))

    def add_vary_header(self, vary: str) -> None:
        """Add ``vary`` to the Vary header, comma-joining with any existing value."""
        current = self.get("vary")
        self["vary"] = vary if current is None else ", ".join([current, vary])

681 

682 

class State:
    """
    An object that can be used to store arbitrary state.

    Used for `request.state` and `app.state`. Attribute reads, writes and
    deletes are forwarded to an internal dict.
    """

    _state: typing.Dict[str, typing.Any]

    def __init__(self, state: typing.Optional[typing.Dict[str, typing.Any]] = None):
        if state is None:
            state = {}
        # Bypass our own __setattr__, which would try to store "_state" inside
        # the not-yet-existing state dict.
        super().__setattr__("_state", state)

    def __setattr__(self, key: typing.Any, value: typing.Any) -> None:
        self._state[key] = value

    def __getattr__(self, key: typing.Any) -> typing.Any:
        # Only reached when normal attribute lookup fails. Read "_state"
        # straight from the instance __dict__: copy and pickle create instances
        # without calling __init__ and then probe for "__setstate__". On such an
        # instance `self._state` would re-enter __getattr__ and recurse until
        # RecursionError. Here a missing "_state" is a plain AttributeError.
        try:
            return self.__dict__["_state"][key]
        except KeyError:
            message = "'{}' object has no attribute '{}'"
            raise AttributeError(message.format(self.__class__.__name__, key)) from None

    def __delattr__(self, key: typing.Any) -> None:
        # Raises KeyError (not AttributeError) for a missing key.
        del self._state[key]