1from __future__ import annotations
2
3from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, MutableMapping, Sequence, ValuesView
4from shlex import shlex
5from typing import (
6 Any,
7 BinaryIO,
8 NamedTuple,
9 TypeVar,
10 Union,
11 cast,
12)
13from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit
14
15from starlette.concurrency import run_in_threadpool
16from starlette.types import Scope
17
18
class Address(NamedTuple):
    """A named ``(host, port)`` pair."""

    # Host name or address, as a string.
    host: str
    # Port number.
    port: int
22
23
# Key type parameter for the multidict mappings in this module. Keys stay
# invariant: they appear both as arguments (``__getitem__``) and as results
# (``keys()``, iteration).
_KeyType = TypeVar("_KeyType")

# Value type parameter. Values of a read-only Mapping may be covariant because
# they can only be read out; nothing like
# `Mapping[str, Animal]()["fido"] = Dog()` is possible through a Mapping.
_CovariantValueType = TypeVar("_CovariantValueType", covariant=True)
29
30
class URL:
    """
    A URL string whose components are parsed lazily.

    Build it from exactly one of:

    * a URL string (``url``),
    * an ASGI ``scope`` (scheme, server, path, query string and ``host`` header), or
    * keyword ``components``, as accepted by :meth:`replace`.

    The string is split with ``urlsplit`` on first access to :attr:`components`.
    """

    def __init__(
        self,
        url: str = "",
        scope: Scope | None = None,
        **components: Any,
    ) -> None:
        if scope is not None:
            assert not url, 'Cannot set both "url" and "scope".'
            assert not components, 'Cannot set both "scope" and "**components".'
            scheme = scope.get("scheme", "http")
            server = scope.get("server", None)
            path = scope["path"]
            query_string = scope.get("query_string", b"")

            # The first "host" header, if present, takes precedence over "server".
            host_header = None
            for key, value in scope["headers"]:
                if key == b"host":
                    host_header = value.decode("latin-1")
                    break

            if host_header is not None:
                url = f"{scheme}://{host_header}{path}"
            elif server is None:
                url = path
            else:
                host, port = server
                # Omit the port only when it is the well-known default for the
                # scheme. Schemes without a known default keep their port
                # instead of raising KeyError.
                default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}.get(scheme)
                if port == default_port:
                    url = f"{scheme}://{host}{path}"
                else:
                    url = f"{scheme}://{host}:{port}{path}"

            if query_string:
                url += "?" + query_string.decode()
        elif components:
            assert not url, 'Cannot set both "url" and "**components".'
            url = URL("").replace(**components).components.geturl()

        self._url = url

    @property
    def components(self) -> SplitResult:
        """The ``urlsplit`` result for this URL, computed once and cached."""
        if not hasattr(self, "_components"):
            self._components = urlsplit(self._url)
        return self._components

    @property
    def scheme(self) -> str:
        return self.components.scheme

    @property
    def netloc(self) -> str:
        return self.components.netloc

    @property
    def path(self) -> str:
        return self.components.path

    @property
    def query(self) -> str:
        return self.components.query

    @property
    def fragment(self) -> str:
        return self.components.fragment

    @property
    def username(self) -> None | str:
        return self.components.username

    @property
    def password(self) -> None | str:
        return self.components.password

    @property
    def hostname(self) -> None | str:
        return self.components.hostname

    @property
    def port(self) -> int | None:
        return self.components.port

    @property
    def is_secure(self) -> bool:
        """True for the ``https`` and ``wss`` schemes."""
        return self.scheme in ("https", "wss")

    def replace(self, **kwargs: Any) -> URL:
        """
        Return a new URL with some components replaced.

        Accepts the ``SplitResult`` field names (``scheme``, ``netloc``,
        ``path``, ``query``, ``fragment``). Also accepts ``username``,
        ``password``, ``hostname`` and ``port``, which are recombined into
        ``netloc``. Netloc parts that are not given are kept from this URL.
        """
        if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs:
            hostname = kwargs.pop("hostname", None)
            port = kwargs.pop("port", self.port)
            username = kwargs.pop("username", self.username)
            password = kwargs.pop("password", self.password)

            if hostname is None:
                netloc = self.netloc
                _, _, hostname = netloc.rpartition("@")

                # Drop an existing ":port" suffix, but leave a bracketed IPv6
                # literal such as "[::1]" intact. ``endswith`` also copes with
                # an empty netloc, where ``hostname[-1]`` would raise IndexError.
                if not hostname.endswith("]"):
                    hostname = hostname.rsplit(":", 1)[0]

            netloc = hostname
            if port is not None:
                netloc += f":{port}"
            if username is not None:
                userpass = username
                if password is not None:
                    userpass += f":{password}"
                netloc = f"{userpass}@{netloc}"

            kwargs["netloc"] = netloc

        components = self.components._replace(**kwargs)
        return self.__class__(components.geturl())

    def include_query_params(self, **kwargs: Any) -> URL:
        """Return a copy with ``kwargs`` set as query params, replacing any existing values for those keys."""
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        params.update({str(key): str(value) for key, value in kwargs.items()})
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def replace_query_params(self, **kwargs: Any) -> URL:
        """Return a copy whose query string consists only of ``kwargs``."""
        query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
        return self.replace(query=query)

    def remove_query_params(self, keys: str | Sequence[str]) -> URL:
        """Return a copy without any query params named in ``keys``."""
        if isinstance(keys, str):
            keys = [keys]
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        for key in keys:
            params.pop(key, None)
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def __eq__(self, other: Any) -> bool:
        # Compares string forms, so a URL also equals a plain str with the same text.
        return str(self) == str(other)

    def __str__(self) -> str:
        return self._url

    def __repr__(self) -> str:
        # Mask the password so it does not leak into logs or tracebacks.
        url = str(self)
        if self.password:
            url = str(self.replace(password="********"))
        return f"{self.__class__.__name__}({repr(url)})"
176
177
class URLPath(str):
    """
    A path string returned by routing's `url_path_for`.

    Besides the path text, it may carry the protocol ("http", "websocket" or
    "") and the host it belongs to. Both are used when building an absolute URL.
    """

    def __new__(cls, path: str, protocol: str = "", host: str = "") -> URLPath:
        assert protocol in ("http", "websocket", "")
        return str.__new__(cls, path)

    def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
        self.protocol = protocol
        self.host = host

    def make_absolute_url(self, base_url: str | URL) -> URL:
        """
        Resolve this path against ``base_url`` and return the absolute URL.

        - Scheme: derived from ``protocol`` when set, using the secure variant if
          ``base_url`` is secure; otherwise ``base_url``'s scheme is reused.
        - Netloc: ``host`` if set, otherwise the base URL's netloc.
        - Path: this path appended to the base path, after stripping its trailing slashes.
        """
        base = URL(base_url) if isinstance(base_url, str) else base_url

        if not self.protocol:
            scheme = base.scheme
        else:
            secure_scheme, plain_scheme = {
                "http": ("https", "http"),
                "websocket": ("wss", "ws"),
            }[self.protocol]
            scheme = secure_scheme if base.is_secure else plain_scheme

        return URL(
            scheme=scheme,
            netloc=self.host or base.netloc,
            path=base.path.rstrip("/") + str(self),
        )
206
207
class Secret:
    """
    Wraps a string so that ``repr()`` (and therefore tracebacks and debug
    output) shows a mask instead of the value.

    Convert with ``str()`` exactly where the real value is needed.
    """

    def __init__(self, value: str):
        self._value = value

    def __repr__(self) -> str:
        # Only the class name is shown; the wrapped value never appears.
        return f"{self.__class__.__name__}('**********')"

    def __str__(self) -> str:
        return self._value

    def __bool__(self) -> bool:
        # Truthiness follows the wrapped value, so an empty secret is falsy.
        return bool(self._value)
226
227
class CommaSeparatedStrings(Sequence[str]):
    """
    A read-only sequence of strings.

    A ``str`` argument is split with ``shlex`` in POSIX mode, with ``,`` as the
    only separator. Shell-style quoting is therefore honoured, and each item is
    stripped of surrounding whitespace. Any other sequence is copied into a list
    unchanged.

    NOTE(review): shlex's default ``commenters`` is not cleared here, so a
    ``#`` in the input presumably starts a comment. Confirm this if values may
    contain ``#``.
    """

    def __init__(self, value: str | Sequence[str]):
        if not isinstance(value, str):
            self._items = list(value)
            return

        lexer = shlex(value, posix=True)
        lexer.whitespace = ","
        lexer.whitespace_split = True
        self._items = list(map(str.strip, lexer))

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, index: int | slice) -> Any:
        # Slicing returns a plain list, not a CommaSeparatedStrings.
        return self._items[index]

    def __iter__(self) -> Iterator[str]:
        return iter(self._items)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({list(self)!r})"

    def __str__(self) -> str:
        return ", ".join(map(repr, self))
254
255
class ImmutableMultiDict(Mapping[_KeyType, _CovariantValueType]):
    """
    A read-only mapping that may hold several values per key.

    Mapping access (``[]``, ``keys()``, ``values()``, ``items()``, ``len()``)
    sees one entry per key, holding the *last* value stored for it.
    ``getlist()`` and ``multi_items()`` expose every (key, value) pair in
    insertion order.
    """

    _dict: dict[_KeyType, _CovariantValueType]

    def __init__(
        self,
        *args: ImmutableMultiDict[_KeyType, _CovariantValueType]
        | Mapping[_KeyType, _CovariantValueType]
        | Iterable[tuple[_KeyType, _CovariantValueType]],
        **kwargs: Any,
    ) -> None:
        """
        Build from at most one positional source plus optional keyword pairs.

        The source may be:

        * anything with ``multi_items()`` (all pairs are kept),
        * anything with ``items()`` (a plain mapping), or
        * an iterable of ``(key, value)`` pairs.

        Keyword arguments are appended after the source's pairs.
        """
        assert len(args) < 2, "Too many arguments."

        value: Any = args[0] if args else []
        if kwargs:
            value = ImmutableMultiDict(value).multi_items() + ImmutableMultiDict(kwargs).multi_items()

        if not value:
            _items: list[tuple[Any, Any]] = []
        elif hasattr(value, "multi_items"):
            value = cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value)
            _items = list(value.multi_items())
        elif hasattr(value, "items"):
            value = cast(Mapping[_KeyType, _CovariantValueType], value)
            _items = list(value.items())
        else:
            value = cast("list[tuple[Any, Any]]", value)
            _items = list(value)

        # Later pairs overwrite earlier ones, so _dict keeps the last value per key.
        self._dict = {k: v for k, v in _items}
        self._list = _items

    def getlist(self, key: Any) -> list[_CovariantValueType]:
        """All values stored for ``key``, in insertion order (empty if absent)."""
        return [item_value for item_key, item_value in self._list if item_key == key]

    def keys(self) -> KeysView[_KeyType]:
        return self._dict.keys()

    def values(self) -> ValuesView[_CovariantValueType]:
        return self._dict.values()

    def items(self) -> ItemsView[_KeyType, _CovariantValueType]:
        return self._dict.items()

    def multi_items(self) -> list[tuple[_KeyType, _CovariantValueType]]:
        """A new list of every (key, value) pair, duplicates included."""
        return list(self._list)

    def __getitem__(self, key: _KeyType) -> _CovariantValueType:
        return self._dict[key]

    def __contains__(self, key: Any) -> bool:
        return key in self._dict

    def __iter__(self) -> Iterator[_KeyType]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._dict)

    def __eq__(self, other: Any) -> bool:
        """
        Order-insensitive comparison of all (key, value) pairs, duplicates included.

        ``other`` must be an instance of this object's class (or a subclass of it).
        """
        if not isinstance(other, self.__class__):
            return False
        try:
            return sorted(self._list) == sorted(other._list)
        except TypeError:
            # Keys or values may not be orderable (e.g. uploaded files, or mixed
            # value types under one key). Fall back to an order-free multiset
            # comparison that only needs ``==``.
            if len(self._list) != len(other._list):
                return False
            unmatched = list(other._list)
            for item in self._list:
                try:
                    unmatched.remove(item)
                except ValueError:
                    return False
            return True

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        items = self.multi_items()
        return f"{class_name}({items!r})"
323
324
class MultiDict(ImmutableMultiDict[Any, Any]):
    """
    A mutable multidict.

    Every mutator keeps the ordered pair list (``_list``) and the
    last-value-per-key dict (``_dict``) consistent with each other.
    """

    def __setitem__(self, key: Any, value: Any) -> None:
        """Replace every value for ``key`` with the single ``value``."""
        self.setlist(key, [value])

    def __delitem__(self, key: Any) -> None:
        """Remove all values for ``key``; raises KeyError if it is absent."""
        self._list = [(name, val) for name, val in self._list if name != key]
        del self._dict[key]

    def pop(self, key: Any, default: Any = None) -> Any:
        """Remove all values for ``key`` and return the last one, or ``default``."""
        self._list = [(name, val) for name, val in self._list if name != key]
        return self._dict.pop(key, default)

    def popitem(self) -> tuple[Any, Any]:
        """Remove the most recently inserted key; return it with its last value."""
        last_key, last_value = self._dict.popitem()
        self._list = [(name, val) for name, val in self._list if name != last_key]
        return last_key, last_value

    def poplist(self, key: Any) -> list[Any]:
        """Remove ``key`` and return all of its values in insertion order."""
        removed = [val for name, val in self._list if name == key]
        self.pop(key)
        return removed

    def clear(self) -> None:
        self._dict.clear()
        self._list.clear()

    def setdefault(self, key: Any, default: Any = None) -> Any:
        """Return the value for ``key``, first inserting ``default`` if it is absent."""
        if key in self:
            return self[key]
        self._dict[key] = default
        self._list.append((key, default))
        return self[key]

    def setlist(self, key: Any, values: list[Any]) -> None:
        """Replace all values for ``key`` with ``values``; an empty list removes the key."""
        if values:
            kept = [(name, val) for name, val in self._list if name != key]
            self._list = kept + [(key, item) for item in values]
            self._dict[key] = values[-1]
        else:
            self.pop(key, None)

    def append(self, key: Any, value: Any) -> None:
        """Add another value for ``key``, keeping the existing ones."""
        self._list.append((key, value))
        self._dict[key] = value

    def update(
        self,
        *args: MultiDict | Mapping[Any, Any] | list[tuple[Any, Any]],
        **kwargs: Any,
    ) -> None:
        """
        Merge new pairs. Every key present in the update has all of its old
        values dropped; its new pairs are appended in their given order.
        """
        incoming = MultiDict(*args, **kwargs)
        replaced = incoming.keys()
        kept = [(name, val) for name, val in self._list if name not in replaced]
        self._list = kept + incoming.multi_items()
        self._dict.update(incoming)
379
380
class QueryParams(ImmutableMultiDict[str, str]):
    """
    An immutable multidict.

    Accepts:

    * a query string (``str``, or ``bytes`` decoded as latin-1), parsed with
      blank values kept, or
    * anything ImmutableMultiDict accepts.

    Every key and value is then coerced with ``str()``.
    """

    def __init__(
        self,
        *args: ImmutableMultiDict[Any, Any] | Mapping[Any, Any] | list[tuple[Any, Any]] | str | bytes,
        **kwargs: Any,
    ) -> None:
        assert len(args) < 2, "Too many arguments."

        source = args[0] if args else []
        if isinstance(source, bytes):
            source = source.decode("latin-1")

        if isinstance(source, str):
            super().__init__(parse_qsl(source, keep_blank_values=True), **kwargs)
        else:
            super().__init__(*args, **kwargs)  # type: ignore[arg-type]

        # Normalise to str; _dict is rebuilt from the existing _dict so that
        # its last-value-per-key choice is preserved.
        self._list = [(str(name), str(val)) for name, val in self._list]
        self._dict = {str(name): str(val) for name, val in self._dict.items()}

    def __str__(self) -> str:
        return urlencode(self._list)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({str(self)!r})"
411
412
class UploadFile:
    """
    An uploaded file included as part of the request data.

    File operations are done directly while the underlying file is held in
    memory (a SpooledTemporaryFile that has not rolled over). Otherwise they
    are dispatched through ``run_in_threadpool``.
    """

    def __init__(
        self,
        file: BinaryIO,
        *,
        size: int | None = None,
        filename: str | None = None,
        headers: Headers | None = None,
    ) -> None:
        self.filename = filename
        self.file = file
        self.size = size
        self.headers = headers or Headers()

        # SpooledTemporaryFile keeps its rollover threshold in ``_max_size``;
        # read it once so later checks are cheap. 0 means "unlimited", matching
        # SpooledTemporaryFile's __init__, and is also the fallback for files
        # without that attribute.
        self._max_mem_size = getattr(self.file, "_max_size", 0)

    @property
    def content_type(self) -> str | None:
        """The ``content-type`` header value, or ``None`` when absent."""
        return self.headers.get("content-type", None)

    @property
    def _in_memory(self) -> bool:
        # SpooledTemporaryFile exposes ``_rolled``; a file object without it is
        # treated as already on disk.
        return not getattr(self.file, "_rolled", True)

    def _will_roll(self, size_to_add: int) -> bool:
        """Whether writing ``size_to_add`` more bytes may touch the disk."""
        if not self._in_memory:
            return True

        projected = self.file.tell() + size_to_add
        if not self._max_mem_size:
            return False
        return bool(projected > self._max_mem_size)

    async def write(self, data: bytes) -> None:
        """Write ``data`` at the current position, growing ``size`` if it is tracked."""
        added = len(data)
        if self.size is not None:
            self.size += added

        if not self._will_roll(added):
            self.file.write(data)
        else:
            await run_in_threadpool(self.file.write, data)

    async def read(self, size: int = -1) -> bytes:
        if not self._in_memory:
            return await run_in_threadpool(self.file.read, size)
        return self.file.read(size)

    async def seek(self, offset: int) -> None:
        if not self._in_memory:
            await run_in_threadpool(self.file.seek, offset)
            return
        self.file.seek(offset)

    async def close(self) -> None:
        if not self._in_memory:
            await run_in_threadpool(self.file.close)
            return
        self.file.close()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(filename={self.filename!r}, size={self.size!r}, headers={self.headers!r})"
483
484
class FormData(ImmutableMultiDict[str, Union[UploadFile, str]]):
    """
    An immutable multidict of submitted form fields, whose values are either
    plain strings or UploadFile instances.
    """

    def __init__(
        self,
        *args: FormData | Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]],
        **kwargs: str | UploadFile,
    ) -> None:
        super().__init__(*args, **kwargs)

    async def close(self) -> None:
        """Close every UploadFile value in insertion order; str values are skipped."""
        uploads = [value for _, value in self.multi_items() if isinstance(value, UploadFile)]
        for upload in uploads:
            await upload.close()
501
502
class Headers(Mapping[str, str]):
    """
    An immutable, case-insensitive multidict.

    Headers are held as a list of raw ``(name, value)`` byte pairs, encoded and
    decoded as latin-1, and duplicate names are kept.
    ``keys()``/``values()``/``items()`` and ``len()`` all include duplicates.

    Lookups lowercase the requested name. Names from the ``headers=`` mapping
    are lowercased on the way in, while ``raw=`` and ``scope=`` pairs are stored
    as given (presumably already lowercase — confirm with callers).
    """

    def __init__(
        self,
        headers: Mapping[str, str] | None = None,
        raw: list[tuple[bytes, bytes]] | None = None,
        scope: MutableMapping[str, Any] | None = None,
    ) -> None:
        self._list: list[tuple[bytes, bytes]] = []
        if headers is not None:
            assert raw is None, 'Cannot set both "headers" and "raw".'
            assert scope is None, 'Cannot set both "headers" and "scope".'
            self._list = [
                (name.lower().encode("latin-1"), val.encode("latin-1"))
                for name, val in headers.items()
            ]
        elif raw is not None:
            assert scope is None, 'Cannot set both "raw" and "scope".'
            self._list = raw
        elif scope is not None:
            # scope["headers"] may be any iterable (e.g. a tuple). Convert it to
            # a list and store that same list back, so the scope and this object
            # share it.
            shared = list(scope["headers"])
            self._list = shared
            scope["headers"] = shared

    @property
    def raw(self) -> list[tuple[bytes, bytes]]:
        """A copy of the raw ``(name, value)`` byte pairs."""
        return list(self._list)

    def keys(self) -> list[str]:  # type: ignore[override]
        return [name.decode("latin-1") for name, _ in self._list]

    def values(self) -> list[str]:  # type: ignore[override]
        return [val.decode("latin-1") for _, val in self._list]

    def items(self) -> list[tuple[str, str]]:  # type: ignore[override]
        return [(name.decode("latin-1"), val.decode("latin-1")) for name, val in self._list]

    def getlist(self, key: str) -> list[str]:
        """Every value stored under ``key`` (case-insensitive), in order."""
        wanted = key.lower().encode("latin-1")
        return [val.decode("latin-1") for name, val in self._list if name == wanted]

    def mutablecopy(self) -> MutableHeaders:
        return MutableHeaders(raw=self._list[:])

    def __getitem__(self, key: str) -> str:
        """The first value stored under ``key``; raises KeyError if there is none."""
        wanted = key.lower().encode("latin-1")
        for name, val in self._list:
            if name == wanted:
                return val.decode("latin-1")
        raise KeyError(key)

    def __contains__(self, key: Any) -> bool:
        wanted = key.lower().encode("latin-1")
        return any(name == wanted for name, _ in self._list)

    def __iter__(self) -> Iterator[Any]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._list)

    def __eq__(self, other: Any) -> bool:
        # Order-insensitive over the raw pairs, duplicates included.
        if not isinstance(other, Headers):
            return False
        return sorted(self._list) == sorted(other._list)

    def __repr__(self) -> str:
        name = self.__class__.__name__
        as_dict = dict(self.items())
        # A dict would silently drop duplicates, so fall back to the raw list.
        if len(as_dict) != len(self):
            return f"{name}(raw={self.raw!r})"
        return f"{name}({as_dict!r})"
578
579
class MutableHeaders(Headers):
    """
    A mutable, case-insensitive header multidict.

    Mutators edit ``self._list`` in place (append, ``del``, index assignment).
    When the instance was built from ``raw=`` or ``scope=``, the caller's list
    therefore sees the changes as well.
    """

    def __setitem__(self, key: str, value: str) -> None:
        """
        Set the header `key` to `value`, removing any duplicate entries.
        Retains insertion order.
        """
        name = key.lower().encode("latin-1")
        encoded = value.encode("latin-1")

        matches = [pos for pos, (existing, _) in enumerate(self._list) if existing == name]
        if not matches:
            self._list.append((name, encoded))
            return

        first, *duplicates = matches
        # Delete from the end so the earlier positions stay valid.
        for pos in reversed(duplicates):
            del self._list[pos]
        self._list[first] = (name, encoded)

    def __delitem__(self, key: str) -> None:
        """
        Remove the header `key` (every occurrence). Missing keys are ignored.
        """
        name = key.lower().encode("latin-1")
        matches = [pos for pos, (existing, _) in enumerate(self._list) if existing == name]
        for pos in reversed(matches):
            del self._list[pos]

    def __ior__(self, other: Mapping[str, str]) -> MutableHeaders:
        """In-place ``|=``: apply ``other`` through update() and return self."""
        if not isinstance(other, Mapping):
            raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
        self.update(other)
        return self

    def __or__(self, other: Mapping[str, str]) -> MutableHeaders:
        """``|``: return an updated copy, leaving self untouched."""
        if not isinstance(other, Mapping):
            raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
        merged = self.mutablecopy()
        merged.update(other)
        return merged

    @property
    def raw(self) -> list[tuple[bytes, bytes]]:
        """The live list of raw pairs (not a copy)."""
        return self._list

    def setdefault(self, key: str, value: str) -> str:
        """
        If the header `key` does not exist, then set it to `value`.
        Returns the header value.
        """
        name = key.lower().encode("latin-1")
        encoded = value.encode("latin-1")

        for existing, current in self._list:
            if existing == name:
                return current.decode("latin-1")
        self._list.append((name, encoded))
        return value

    def update(self, other: Mapping[str, str]) -> None:
        """Set each header from ``other``, replacing any existing duplicates."""
        for name, val in other.items():
            self[name] = val

    def append(self, key: str, value: str) -> None:
        """
        Append a header, preserving any duplicate entries.
        """
        self._list.append((key.lower().encode("latin-1"), value.encode("latin-1")))

    def add_vary_header(self, vary: str) -> None:
        """Add ``vary`` to the ``vary`` header, comma-joining with any existing value."""
        existing = self.get("vary")
        self["vary"] = vary if existing is None else ", ".join([existing, vary])
665
666
class State:
    """
    An object that can be used to store arbitrary state.

    Used for `request.state` and `app.state`.

    Setting, reading and deleting ``state.x`` operate on the underlying dict,
    ``state._state["x"]``. A dict passed to the constructor is used directly,
    not copied.
    """

    _state: dict[str, Any]

    def __init__(self, state: dict[str, Any] | None = None):
        if state is None:
            state = {}
        # Bypass our own __setattr__, which would write into the dict instead.
        super().__setattr__("_state", state)

    def __setattr__(self, key: Any, value: Any) -> None:
        self._state[key] = value

    def __getattr__(self, key: Any) -> Any:
        """
        Look ``key`` up in the state dict; raises AttributeError if it is missing.
        """
        message = "'{}' object has no attribute '{}'"
        if key == "_state":
            # ``__init__`` stores ``_state`` in the instance ``__dict__``, so
            # reaching here means the instance was created without ``__init__``
            # (e.g. by ``copy`` or ``pickle``). Fail cleanly instead of
            # recursing forever through ``self._state``.
            raise AttributeError(message.format(self.__class__.__name__, key))
        try:
            return self._state[key]
        except KeyError:
            raise AttributeError(message.format(self.__class__.__name__, key)) from None

    def __delattr__(self, key: Any) -> None:
        # A missing key raises KeyError from the underlying dict.
        del self._state[key]