1from __future__ import annotations
2
3from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, MutableMapping, Sequence, ValuesView
4from shlex import shlex
5from typing import (
6 Any,
7 BinaryIO,
8 NamedTuple,
9 TypeVar,
10 cast,
11)
12from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit
13
14from starlette.concurrency import run_in_threadpool
15from starlette.types import Scope
16
17
class Address(NamedTuple):
    """A network address made of a host string and an integer port."""

    host: str  # host name or IP address, as given by the caller
    port: int
21
22
# Type parameters shared by the multidict classes below.
#
# Keys stay invariant, as they are for any Mapping.
_KeyType = TypeVar("_KeyType")
# Values are covariant: a read-only Mapping only ever hands values out and never
# accepts them, so e.g. a Mapping[str, Dog] can safely stand in for a
# Mapping[str, Animal] (you can't do `Mapping[str, Animal]()["fido"] = Dog()`).
_CovariantValueType = TypeVar("_CovariantValueType", covariant=True)
28
29
class URL:
    """
    A URL built from a URL string, an ASGI scope, or keyword components.

    Exactly one source may be supplied. Keyword components are applied with
    `replace()` to an empty URL. The string form is stored as-is and only split
    (via `urllib.parse.urlsplit`) the first time `components`, or any property
    derived from it, is read.
    """

    def __init__(
        self,
        url: str = "",
        scope: Scope | None = None,
        **components: Any,
    ) -> None:
        """
        Args:
            url: A URL string. Mutually exclusive with `scope` and `components`.
            scope: An ASGI connection scope. The URL is rebuilt from its
                `scheme`, `server`, `path`, `query_string` and `headers` entries.
            **components: Fields accepted by `replace()`, such as `scheme`,
                `netloc`, `path`, `query`, `hostname` or `port`.
        """
        if scope is not None:
            assert not url, 'Cannot set both "url" and "scope".'
            assert not components, 'Cannot set both "scope" and "**components".'
            scheme = scope.get("scheme", "http")
            server = scope.get("server", None)
            path = scope["path"]
            query_string = scope.get("query_string", b"")

            # The first Host header, if any, takes priority over `server`.
            host_header = None
            for key, value in scope["headers"]:
                if key == b"host":
                    host_header = value.decode("latin-1")
                    break

            if host_header is not None:
                url = f"{scheme}://{host_header}{path}"
            elif server is None:
                url = path
            else:
                host, port = server
                # Leave the port out when it equals the scheme's default. A scheme
                # missing from this table has no default (rather than raising
                # KeyError), so its port is always written out.
                default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}.get(scheme)
                if port == default_port:
                    url = f"{scheme}://{host}{path}"
                else:
                    url = f"{scheme}://{host}:{port}{path}"

            if query_string:
                url += "?" + query_string.decode()
        elif components:
            assert not url, 'Cannot set both "url" and "**components".'
            url = URL("").replace(**components).components.geturl()

        self._url = url

    @property
    def components(self) -> SplitResult:
        """The `urlsplit` result for this URL, computed once and cached."""
        if not hasattr(self, "_components"):
            self._components = urlsplit(self._url)
        return self._components

    @property
    def scheme(self) -> str:
        return self.components.scheme

    @property
    def netloc(self) -> str:
        return self.components.netloc

    @property
    def path(self) -> str:
        return self.components.path

    @property
    def query(self) -> str:
        return self.components.query

    @property
    def fragment(self) -> str:
        return self.components.fragment

    @property
    def username(self) -> None | str:
        return self.components.username

    @property
    def password(self) -> None | str:
        return self.components.password

    @property
    def hostname(self) -> None | str:
        return self.components.hostname

    @property
    def port(self) -> int | None:
        return self.components.port

    @property
    def is_secure(self) -> bool:
        """True for the `https` and `wss` schemes."""
        return self.scheme in ("https", "wss")

    def replace(self, **kwargs: Any) -> URL:
        """
        Return a new URL with some of its parts replaced.

        Accepts any `SplitResult` field (`scheme`, `netloc`, `path`, `query`,
        `fragment`). It also accepts the netloc parts `username`, `password`,
        `hostname` and `port` individually. When any of those is given, the
        netloc is rebuilt from them, and each omitted part keeps this URL's
        current value.
        """
        if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs:
            hostname = kwargs.pop("hostname", None)
            port = kwargs.pop("port", self.port)
            username = kwargs.pop("username", self.username)
            password = kwargs.pop("password", self.password)

            if hostname is None:
                # Take the host from the raw netloc rather than `self.hostname`,
                # which urllib lower-cases and strips of IPv6 brackets.
                netloc = self.netloc
                _, _, hostname = netloc.rpartition("@")

                # Drop a trailing ":port", unless the host is a bracketed IPv6
                # literal with no port (its colons are not a port separator).
                # `endswith` also copes with an empty netloc, e.g. a relative URL,
                # where indexing `hostname[-1]` raised IndexError.
                if not hostname.endswith("]"):
                    hostname = hostname.rsplit(":", 1)[0]

            netloc = hostname
            if port is not None:
                netloc += f":{port}"
            if username is not None:
                userpass = username
                if password is not None:
                    userpass += f":{password}"
                netloc = f"{userpass}@{netloc}"

            kwargs["netloc"] = netloc

        components = self.components._replace(**kwargs)
        return self.__class__(components.geturl())

    def include_query_params(self, **kwargs: Any) -> URL:
        """Return a URL with the given params set, replacing same-named ones."""
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        params.update({str(key): str(value) for key, value in kwargs.items()})
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def replace_query_params(self, **kwargs: Any) -> URL:
        """Return a URL whose query string consists only of the given params."""
        query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
        return self.replace(query=query)

    def remove_query_params(self, keys: str | Sequence[str]) -> URL:
        """Return a URL with every occurrence of the given param name(s) removed."""
        if isinstance(keys, str):
            keys = [keys]
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        for key in keys:
            params.pop(key, None)
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def __eq__(self, other: Any) -> bool:
        # Compares string forms, so a URL also compares equal to a plain str.
        return str(self) == str(other)

    def __str__(self) -> str:
        return self._url

    def __repr__(self) -> str:
        # Mask any password so it does not leak into logs or tracebacks.
        url = str(self)
        if self.password:
            url = str(self.replace(password="********"))
        return f"{self.__class__.__name__}({repr(url)})"
175
176
class URLPath(str):
    """
    A path string that can also carry the protocol ("http" or "websocket")
    and host it belongs to. Returned by routing from `url_path_for` lookups.
    """

    def __new__(cls, path: str, protocol: str = "", host: str = "") -> URLPath:
        assert protocol in ("http", "websocket", "")
        return str.__new__(cls, path)

    def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
        self.protocol = protocol
        self.host = host

    def make_absolute_url(self, base_url: str | URL) -> URL:
        """
        Join this path onto `base_url`.

        The scheme comes from `self.protocol` when set, using the secure variant
        when the base URL is secure; otherwise it is the base URL's own scheme.
        The netloc is `self.host` if non-empty, else the base URL's netloc.
        """
        base = URL(base_url) if isinstance(base_url, str) else base_url

        if not self.protocol:
            scheme = base.scheme
        else:
            # (insecure, secure) scheme pair per protocol; the bool indexes it.
            scheme_pairs = {"http": ("http", "https"), "websocket": ("ws", "wss")}
            scheme = scheme_pairs[self.protocol][base.is_secure]

        return URL(
            scheme=scheme,
            netloc=self.host or base.netloc,
            path=base.path.rstrip("/") + str(self),
        )
205
206
class Secret:
    """
    Wraps a string value so that its repr never shows the value, keeping it out
    of tracebacks and logs. Convert with `str()` only where the value is needed.
    """

    def __init__(self, value: str):
        self._value = value

    def __repr__(self) -> str:
        return f"{type(self).__name__}('**********')"

    def __str__(self) -> str:
        return self._value

    def __bool__(self) -> bool:
        # Truthiness mirrors the wrapped value.
        return bool(self._value)
225
226
class CommaSeparatedStrings(Sequence[str]):
    """
    A read-only sequence of strings, parsed from a comma-separated string or
    copied from an existing sequence.

    String input is tokenized with `shlex` in POSIX mode, with `,` as the only
    separator, so quoted items may contain commas. Each item is stripped of
    surrounding whitespace.
    """

    def __init__(self, value: str | Sequence[str]):
        if not isinstance(value, str):
            self._items = list(value)
            return

        lexer = shlex(value, posix=True)
        lexer.whitespace = ","
        lexer.whitespace_split = True
        self._items = [token.strip() for token in lexer]

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, index: int | slice) -> Any:
        return self._items[index]

    def __iter__(self) -> Iterator[str]:
        return iter(self._items)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({list(self)!r})"

    def __str__(self) -> str:
        return ", ".join(map(repr, self))
253
254
class ImmutableMultiDict(Mapping[_KeyType, _CovariantValueType]):
    """
    A read-only mapping that can hold several values per key.

    All (key, value) pairs are kept in insertion order in `_list`. The ordinary
    Mapping interface (`[]`, `keys()`, `values()`, `items()`, `len()`) sees
    only the last value stored for each key, held in `_dict`. Use `getlist()`
    or `multi_items()` to reach every value.
    """

    _dict: dict[_KeyType, _CovariantValueType]

    def __init__(
        self,
        *args: ImmutableMultiDict[_KeyType, _CovariantValueType]
        | Mapping[_KeyType, _CovariantValueType]
        | Iterable[tuple[_KeyType, _CovariantValueType]],
        **kwargs: Any,
    ) -> None:
        assert len(args) < 2, "Too many arguments."

        source: Any = args[0] if args else []
        if kwargs:
            # Keyword pairs are appended after the positional source's pairs.
            source = ImmutableMultiDict(source).multi_items() + ImmutableMultiDict(kwargs).multi_items()

        # Accept, in this order of preference: anything falsy (empty), another
        # multidict, any object with `.items()`, or an iterable of pairs.
        pairs: list[tuple[Any, Any]]
        if not source:
            pairs = []
        elif hasattr(source, "multi_items"):
            pairs = list(cast(ImmutableMultiDict[_KeyType, _CovariantValueType], source).multi_items())
        elif hasattr(source, "items"):
            pairs = list(cast(Mapping[_KeyType, _CovariantValueType], source).items())
        else:
            pairs = list(cast("list[tuple[Any, Any]]", source))

        # Later pairs overwrite earlier ones, so `_dict` keeps the last value per key.
        self._dict = dict(pairs)
        self._list = pairs

    def getlist(self, key: Any) -> list[_CovariantValueType]:
        """Every value stored under `key`, in insertion order (empty if absent)."""
        return [v for k, v in self._list if k == key]

    def keys(self) -> KeysView[_KeyType]:
        return self._dict.keys()

    def values(self) -> ValuesView[_CovariantValueType]:
        return self._dict.values()

    def items(self) -> ItemsView[_KeyType, _CovariantValueType]:
        return self._dict.items()

    def multi_items(self) -> list[tuple[_KeyType, _CovariantValueType]]:
        """A fresh copy of every (key, value) pair, duplicates included."""
        return self._list[:]

    def __getitem__(self, key: _KeyType) -> _CovariantValueType:
        return self._dict[key]

    def __contains__(self, key: Any) -> bool:
        return key in self._dict

    def __iter__(self) -> Iterator[_KeyType]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._dict)

    def __eq__(self, other: Any) -> bool:
        # Order-insensitive comparison of all pairs; only same-class instances match.
        if isinstance(other, self.__class__):
            return sorted(self._list) == sorted(other._list)
        return False

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.multi_items()!r})"
322
323
class MultiDict(ImmutableMultiDict[Any, Any]):
    """
    A mutable multidict. Every mutator keeps `_list` (all pairs, in order) and
    `_dict` (last value per key) in step with each other.
    """

    def __setitem__(self, key: Any, value: Any) -> None:
        # Assignment replaces every existing value for `key` with this one.
        self.setlist(key, [value])

    def __delitem__(self, key: Any) -> None:
        self._list = [(k, v) for k, v in self._list if k != key]
        del self._dict[key]

    def pop(self, key: Any, default: Any = None) -> Any:
        """Remove all pairs for `key`; return its last value, or `default`."""
        self._list = [(k, v) for k, v in self._list if k != key]
        return self._dict.pop(key, default)

    def popitem(self) -> tuple[Any, Any]:
        """Remove the most recently inserted key (per `_dict`) and all its pairs."""
        popped_key, popped_value = self._dict.popitem()
        self._list = [(k, v) for k, v in self._list if k != popped_key]
        return popped_key, popped_value

    def poplist(self, key: Any) -> list[Any]:
        """Remove `key` entirely and return all of its values in order."""
        removed = [v for k, v in self._list if k == key]
        self.pop(key)
        return removed

    def clear(self) -> None:
        self._dict.clear()
        self._list.clear()

    def setdefault(self, key: Any, default: Any = None) -> Any:
        if key not in self:
            self._dict[key] = default
            self._list.append((key, default))
        return self[key]

    def setlist(self, key: Any, values: list[Any]) -> None:
        """Replace every value for `key` with `values`; an empty list removes it."""
        if values:
            kept = [(k, v) for k, v in self._list if k != key]
            self._list = kept + [(key, v) for v in values]
            self._dict[key] = values[-1]
        else:
            self.pop(key, None)

    def append(self, key: Any, value: Any) -> None:
        """Add one more value for `key`, keeping any existing ones."""
        self._list.append((key, value))
        self._dict[key] = value

    def update(
        self,
        *args: MultiDict | Mapping[Any, Any] | list[tuple[Any, Any]],
        **kwargs: Any,
    ) -> None:
        """Replace all values of each incoming key with the incoming values."""
        incoming = MultiDict(*args, **kwargs)
        incoming_keys = incoming.keys()
        kept = [(k, v) for k, v in self._list if k not in incoming_keys]
        self._list = kept + incoming.multi_items()
        self._dict.update(incoming)
378
379
class QueryParams(ImmutableMultiDict[str, str]):
    """
    An immutable multidict of query parameters, with every key and value stored
    as `str`. Can be built from a raw query string (str, or latin-1 bytes) or
    from anything `ImmutableMultiDict` accepts.
    """

    def __init__(
        self,
        *args: ImmutableMultiDict[Any, Any] | Mapping[Any, Any] | list[tuple[Any, Any]] | str | bytes,
        **kwargs: Any,
    ) -> None:
        assert len(args) < 2, "Too many arguments."

        raw = args[0] if args else []
        if isinstance(raw, bytes):
            raw = raw.decode("latin-1")

        if isinstance(raw, str):
            # Blank values ("a=&b") are kept rather than dropped.
            super().__init__(parse_qsl(raw, keep_blank_values=True), **kwargs)
        else:
            super().__init__(*args, **kwargs)  # type: ignore[arg-type]

        self._list = [(str(k), str(v)) for k, v in self._list]
        self._dict = {str(k): str(v) for k, v in self._dict.items()}

    def __str__(self) -> str:
        return urlencode(self._list)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({str(self)!r})"
410
411
class UploadFile:
    """
    A file uploaded as part of the request data, with async I/O helpers.

    While the underlying file reports itself as still in memory (a false
    `_rolled` attribute, as on `SpooledTemporaryFile`), I/O is done directly.
    Otherwise, or when a write would exceed the in-memory limit, the call is
    dispatched to a worker thread via `run_in_threadpool`.
    """

    def __init__(
        self,
        file: BinaryIO,
        *,
        size: int | None = None,
        filename: str | None = None,
        headers: Headers | None = None,
    ) -> None:
        self.filename = filename
        self.file = file
        self.size = size
        self.headers = headers or Headers()

        # In-memory size limit read once from the file's `_max_size` (as set by
        # SpooledTemporaryFile). 0 means "unlimited", as in SpooledTemporaryFile.
        self._max_mem_size = getattr(self.file, "_max_size", 0)

    @property
    def content_type(self) -> str | None:
        return self.headers.get("content-type", None)

    @property
    def _in_memory(self) -> bool:
        # A file without `_rolled` is treated as already on disk.
        return not getattr(self.file, "_rolled", True)

    def _will_roll(self, size_to_add: int) -> bool:
        """Whether writing `size_to_add` more bytes may touch the disk."""
        if not self._in_memory:
            return True

        projected = self.file.tell() + size_to_add
        if self._max_mem_size:
            return bool(projected > self._max_mem_size)
        return False

    async def write(self, data: bytes) -> None:
        added = len(data)
        if self.size is not None:
            self.size += added

        if not self._will_roll(added):
            self.file.write(data)
        else:
            await run_in_threadpool(self.file.write, data)

    async def read(self, size: int = -1) -> bytes:
        if not self._in_memory:
            return await run_in_threadpool(self.file.read, size)
        return self.file.read(size)

    async def seek(self, offset: int) -> None:
        if not self._in_memory:
            await run_in_threadpool(self.file.seek, offset)
        else:
            self.file.seek(offset)

    async def close(self) -> None:
        if not self._in_memory:
            await run_in_threadpool(self.file.close)
        else:
            self.file.close()

    def __repr__(self) -> str:
        name = self.__class__.__name__
        return f"{name}(filename={self.filename!r}, size={self.size!r}, headers={self.headers!r})"
482
483
class FormData(ImmutableMultiDict[str, UploadFile | str]):
    """
    An immutable multidict of submitted form fields. Values are either plain
    strings (text inputs) or `UploadFile` instances (file uploads).
    """

    def __init__(
        self,
        *args: FormData | Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]],
        **kwargs: str | UploadFile,
    ) -> None:
        super().__init__(*args, **kwargs)

    async def close(self) -> None:
        """Close every uploaded file, including repeated values under one key."""
        uploads = [value for _, value in self.multi_items() if isinstance(value, UploadFile)]
        for upload in uploads:
            await upload.close()
500
501
502class Headers(Mapping[str, str]):
503 """
504 An immutable, case-insensitive multidict.
505 """
506
507 def __init__(
508 self,
509 headers: Mapping[str, str] | None = None,
510 raw: list[tuple[bytes, bytes]] | None = None,
511 scope: MutableMapping[str, Any] | None = None,
512 ) -> None:
513 self._list: list[tuple[bytes, bytes]] = []
514 if headers is not None:
515 assert raw is None, 'Cannot set both "headers" and "raw".'
516 assert scope is None, 'Cannot set both "headers" and "scope".'
517 self._list = [(key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in headers.items()]
518 elif raw is not None:
519 assert scope is None, 'Cannot set both "raw" and "scope".'
520 self._list = raw
521 elif scope is not None:
522 # scope["headers"] isn't necessarily a list
523 # it might be a tuple or other iterable
524 self._list = scope["headers"] = list(scope["headers"])
525
526 @property
527 def raw(self) -> list[tuple[bytes, bytes]]:
528 return list(self._list)
529
530 def keys(self) -> list[str]: # type: ignore[override]
531 return [key.decode("latin-1") for key, value in self._list]
532
533 def values(self) -> list[str]: # type: ignore[override]
534 return [value.decode("latin-1") for key, value in self._list]
535
536 def items(self) -> list[tuple[str, str]]: # type: ignore[override]
537 return [(key.decode("latin-1"), value.decode("latin-1")) for key, value in self._list]
538
539 def getlist(self, key: str) -> list[str]:
540 get_header_key = key.lower().encode("latin-1")
541 return [item_value.decode("latin-1") for item_key, item_value in self._list if item_key == get_header_key]
542
543 def mutablecopy(self) -> MutableHeaders:
544 return MutableHeaders(raw=self._list[:])
545
546 def __getitem__(self, key: str) -> str:
547 get_header_key = key.lower().encode("latin-1")
548 for header_key, header_value in self._list:
549 if header_key == get_header_key:
550 return header_value.decode("latin-1")
551 raise KeyError(key)
552
553 def __contains__(self, key: Any) -> bool:
554 get_header_key = key.lower().encode("latin-1")
555 for header_key, header_value in self._list:
556 if header_key == get_header_key:
557 return True
558 return False
559
560 def __iter__(self) -> Iterator[Any]:
561 return iter(self.keys())
562
563 def __len__(self) -> int:
564 return len(self._list)
565
566 def __eq__(self, other: Any) -> bool:
567 if not isinstance(other, Headers):
568 return False
569 return sorted(self._list) == sorted(other._list)
570
571 def __repr__(self) -> str:
572 class_name = self.__class__.__name__
573 as_dict = dict(self.items())
574 if len(as_dict) == len(self):
575 return f"{class_name}({as_dict!r})"
576 return f"{class_name}(raw={self.raw!r})"
577
578
class MutableHeaders(Headers):
    """
    A mutable variant of Headers. Every mutation edits the underlying list in
    place, so a list shared with an ASGI scope sees the changes too.
    """

    def __setitem__(self, key: str, value: str) -> None:
        """
        Set the header `key` to `value`, removing any duplicate entries.
        The first existing occurrence keeps its position.
        """
        set_key = key.lower().encode("latin-1")
        set_value = value.encode("latin-1")

        positions = [i for i, (name, _) in enumerate(self._list) if name == set_key]

        # Remove the extra occurrences back-to-front so earlier indexes stay valid.
        for i in reversed(positions[1:]):
            del self._list[i]

        if not positions:
            self._list.append((set_key, set_value))
        else:
            self._list[positions[0]] = (set_key, set_value)

    def __delitem__(self, key: str) -> None:
        """
        Remove every occurrence of the header `key`.
        """
        del_key = key.lower().encode("latin-1")
        doomed = [i for i, (name, _) in enumerate(self._list) if name == del_key]
        for i in reversed(doomed):
            del self._list[i]

    def __ior__(self, other: Mapping[str, str]) -> MutableHeaders:
        if not isinstance(other, Mapping):
            raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
        self.update(other)
        return self

    def __or__(self, other: Mapping[str, str]) -> MutableHeaders:
        if not isinstance(other, Mapping):
            raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
        merged = self.mutablecopy()
        merged.update(other)
        return merged

    @property
    def raw(self) -> list[tuple[bytes, bytes]]:
        # Unlike Headers.raw, this is the live list, not a copy.
        return self._list

    def setdefault(self, key: str, value: str) -> str:
        """
        Return the first existing value of header `key`. If there is none,
        append `key: value` and return `value`.
        """
        set_key = key.lower().encode("latin-1")
        set_value = value.encode("latin-1")

        for name, existing in self._list:
            if name == set_key:
                return existing.decode("latin-1")

        self._list.append((set_key, set_value))
        return value

    def update(self, other: Mapping[str, str]) -> None:
        """Set each header from `other`, replacing any existing occurrences."""
        for name, new_value in other.items():
            self[name] = new_value

    def append(self, key: str, value: str) -> None:
        """
        Append a header, preserving any duplicate entries.
        """
        entry = (key.lower().encode("latin-1"), value.encode("latin-1"))
        self._list.append(entry)

    def add_vary_header(self, vary: str) -> None:
        """Add `vary` to the Vary header, comma-joining it onto any existing value."""
        current = self.get("vary")
        self["vary"] = vary if current is None else ", ".join([current, vary])
664
665
class State:
    """
    An object that can be used to store arbitrary state.

    Used for `request.state` and `app.state`. Attribute reads, writes and
    deletes are redirected to the wrapped `_state` dict.
    """

    _state: dict[str, Any]

    def __init__(self, state: dict[str, Any] | None = None):
        if state is None:
            state = {}
        # Bypass our own __setattr__, which would try to write into `_state`.
        super().__setattr__("_state", state)

    def __setattr__(self, key: Any, value: Any) -> None:
        self._state[key] = value

    def __getattr__(self, key: Any) -> Any:
        # Read `_state` directly from the instance __dict__. `copy`, `deepcopy`
        # and `pickle` create the instance without calling __init__, then probe
        # attributes such as `__setstate__`. Using `self._state` at that point
        # would re-enter __getattr__("_state") and recurse until RecursionError.
        state = self.__dict__.get("_state", {})
        try:
            return state[key]
        except KeyError:
            message = "'{}' object has no attribute '{}'"
            raise AttributeError(message.format(self.__class__.__name__, key))

    def __delattr__(self, key: Any) -> None:
        # Raises KeyError (not AttributeError) when `key` is missing.
        del self._state[key]