1from __future__ import annotations
2
3import re
4from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, MutableMapping, Sequence, ValuesView
5from shlex import shlex
6from typing import Any, BinaryIO, NamedTuple, TypeVar, cast
7from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit
8
9from starlette.concurrency import run_in_threadpool
10from starlette.types import Scope
11
12
class Address(NamedTuple):
    """A ``(host, port)`` pair."""

    host: str  # host name or address, as given by the caller
    port: int  # port number
16
17
# Key type shared by the multidict mappings in this module.
_KeyType = TypeVar("_KeyType")
# Mapping keys are invariant but their values are covariant since
# you can only read them
# that is, you can't do `Mapping[str, Animal]()["fido"] = Dog()`
_CovariantValueType = TypeVar("_CovariantValueType", covariant=True)

# Allow-list for a Host header: a name/IPv4 made only of letters, digits, dots and hyphens
# (case-insensitive), or a bracketed IPv6 literal, optionally followed by ":<digits>".
# Rejects Host header chars (/, ?, #, @, ...) that would let urlsplit produce a path differing from scope["path"].
# URL.__init__ ignores a Host header that does not fully match and falls back to scope["server"];
# note that this also applies to names containing underscores.
_HOST_RE = re.compile(r"^([a-z0-9.-]+|\[[a-f0-9]*:[a-f0-9.:]+\])(?::[0-9]+)?$", re.IGNORECASE)
26
27
class URL:
    """
    A URL string whose components are parsed lazily (and cached) on first access.

    Build it from exactly one of:

    * a URL string, e.g. ``URL("https://example.com/path?x=1")``;
    * an ASGI ``scope`` (scheme, Host header or server address, path, query string);
    * keyword ``**components`` understood by :meth:`replace`,
      e.g. ``URL(scheme="https", netloc="example.com", path="/")``.
    """

    def __init__(
        self,
        url: str = "",
        scope: Scope | None = None,
        **components: Any,
    ) -> None:
        if scope is not None:
            assert not url, 'Cannot set both "url" and "scope".'
            assert not components, 'Cannot set both "scope" and "**components".'
            scheme = scope.get("scheme", "http")
            server = scope.get("server", None)
            path = scope["path"]
            query_string = scope.get("query_string", b"")

            # Only the first Host header is considered.
            host_header = None
            for key, value in scope["headers"]:
                if key == b"host":
                    host_header = value.decode("latin-1")
                    break

            # Trust the Host header only when it is a plain host[:port]; otherwise use the
            # server address so the resulting URL's path always equals scope["path"].
            if host_header is not None and _HOST_RE.fullmatch(host_header):
                url = f"{scheme}://{host_header}{path}"
            elif server is None:
                url = path
            else:
                host, port = server
                default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}[scheme]
                if port == default_port:
                    url = f"{scheme}://{host}{path}"
                else:
                    url = f"{scheme}://{host}:{port}{path}"

            if query_string:
                url += "?" + query_string.decode()
        elif components:
            assert not url, 'Cannot set both "url" and "**components".'
            url = URL("").replace(**components).components.geturl()

        self._url = url

    @property
    def components(self) -> SplitResult:
        """The `urlsplit` result for this URL, computed on first access and cached."""
        if not hasattr(self, "_components"):
            self._components = urlsplit(self._url)
        return self._components

    @property
    def scheme(self) -> str:
        return self.components.scheme

    @property
    def netloc(self) -> str:
        return self.components.netloc

    @property
    def path(self) -> str:
        return self.components.path

    @property
    def query(self) -> str:
        return self.components.query

    @property
    def fragment(self) -> str:
        return self.components.fragment

    @property
    def username(self) -> None | str:
        return self.components.username

    @property
    def password(self) -> None | str:
        return self.components.password

    @property
    def hostname(self) -> None | str:
        return self.components.hostname

    @property
    def port(self) -> int | None:
        return self.components.port

    @property
    def is_secure(self) -> bool:
        """True for the ``https`` and ``wss`` schemes."""
        return self.scheme in ("https", "wss")

    def replace(self, **kwargs: Any) -> URL:
        """
        Return a new URL with the given components replaced.

        Accepts any `SplitResult` field (``scheme``, ``netloc``, ``path``, ``query``,
        ``fragment``) plus ``username``, ``password``, ``hostname`` and ``port``. Those
        last four are combined into a new ``netloc``. If ``port``, ``username`` or
        ``password`` is omitted, its current value is kept; passing ``None`` drops it.
        If ``hostname`` is omitted, the current host is kept.
        """
        if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs:
            hostname = kwargs.pop("hostname", None)
            port = kwargs.pop("port", self.port)
            username = kwargs.pop("username", self.username)
            password = kwargs.pop("password", self.password)

            if hostname is None:
                # Recover the host from the current netloc: drop any "user:pass@" prefix
                # and any ":port" suffix, keeping bracketed IPv6 literals intact.
                # `endswith` (unlike `hostname[-1]`) also copes with an empty netloc,
                # e.g. for relative URLs such as "/path".
                _, _, hostname = self.netloc.rpartition("@")
                if not hostname.endswith("]"):
                    hostname = hostname.rsplit(":", 1)[0]

            netloc = hostname
            if port is not None:
                netloc += f":{port}"
            if username is not None:
                userpass = username
                if password is not None:
                    userpass += f":{password}"
                netloc = f"{userpass}@{netloc}"

            kwargs["netloc"] = netloc

        components = self.components._replace(**kwargs)
        return self.__class__(components.geturl())

    def include_query_params(self, **kwargs: Any) -> URL:
        """
        Return a copy with the given query params set.

        Every existing value for those keys is replaced; other params are kept.
        """
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        params.update({str(key): str(value) for key, value in kwargs.items()})
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def replace_query_params(self, **kwargs: Any) -> URL:
        """Return a copy whose query string consists only of the given params."""
        query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
        return self.replace(query=query)

    def remove_query_params(self, keys: str | Sequence[str]) -> URL:
        """Return a copy with every value of the given key (or keys) removed from the query."""
        if isinstance(keys, str):
            keys = [keys]
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        for key in keys:
            params.pop(key, None)
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def __eq__(self, other: Any) -> bool:
        # Compared by string form, so a URL also equals the equivalent plain string.
        return str(self) == str(other)

    def __str__(self) -> str:
        return self._url

    def __repr__(self) -> str:
        # Never leak a password into logs or tracebacks.
        url = str(self)
        if self.password:
            url = str(self.replace(password="********"))
        return f"{self.__class__.__name__}({repr(url)})"
173
174
class URLPath(str):
    """
    A path string, as returned by routing's `url_path_for`, that may also record the
    protocol ("http" or "websocket") and host it belongs to.
    """

    def __new__(cls, path: str, protocol: str = "", host: str = "") -> URLPath:
        assert protocol in ("http", "websocket", "")
        return str.__new__(cls, path)

    def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
        self.protocol = protocol
        self.host = host

    def make_absolute_url(self, base_url: str | URL) -> URL:
        """
        Join this path onto `base_url`.

        The scheme follows this path's protocol, with the secure variant chosen when
        the base is secure. If there is no protocol, the base scheme is used. The
        host is this path's own host, falling back to the base netloc.
        """
        base = URL(base_url) if isinstance(base_url, str) else base_url

        if not self.protocol:
            scheme = base.scheme
        else:
            # (insecure, secure) scheme pair for each protocol, indexed by `is_secure`.
            scheme_pair = {
                "http": ("http", "https"),
                "websocket": ("ws", "wss"),
            }[self.protocol]
            scheme = scheme_pair[base.is_secure]

        full_path = base.path.rstrip("/") + str(self)
        return URL(scheme=scheme, netloc=self.host or base.netloc, path=full_path)
203
204
class Secret:
    """
    Wrap a string so that its repr never shows the value (e.g. in tracebacks).

    Convert it with `str()` only at the point where the real value is required.
    """

    def __init__(self, value: str):
        self._value = value

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}('**********')"

    def __str__(self) -> str:
        return self._value

    def __bool__(self) -> bool:
        # Truthiness mirrors the wrapped value, so an empty secret is falsy.
        return bool(self._value)
223
224
class CommaSeparatedStrings(Sequence[str]):
    """
    A read-only sequence of strings built from a comma-separated value.

    String input is tokenised by `shlex` in POSIX mode with "," as the only
    separator, so quoted items may contain commas. Each item is then stripped of
    surrounding whitespace. Any other sequence is copied into a list unchanged.
    """

    def __init__(self, value: str | Sequence[str]):
        if not isinstance(value, str):
            self._items = list(value)
            return

        lexer = shlex(value, posix=True)
        lexer.whitespace = ","
        lexer.whitespace_split = True
        self._items = [token.strip() for token in lexer]

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, index: int | slice) -> Any:
        return self._items[index]

    def __iter__(self) -> Iterator[str]:
        return iter(self._items)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({list(self)!r})"

    def __str__(self) -> str:
        return ", ".join(map(repr, self))
251
252
class ImmutableMultiDict(Mapping[_KeyType, _CovariantValueType]):
    """
    An immutable mapping that can hold several values per key.

    Every (key, value) pair is kept in insertion order in `_list`. `_dict` maps each
    key to its *last* value and backs the plain `Mapping` interface (`[]`, `keys()`,
    `values()`, `items()`, `len()`). Use `getlist()` / `multi_items()` to see all values.
    """

    _dict: dict[_KeyType, _CovariantValueType]

    def __init__(
        self,
        *args: ImmutableMultiDict[_KeyType, _CovariantValueType]
        | Mapping[_KeyType, _CovariantValueType]
        | Iterable[tuple[_KeyType, _CovariantValueType]],
        **kwargs: Any,
    ) -> None:
        assert len(args) < 2, "Too many arguments."

        value: Any = args[0] if args else []
        if kwargs:
            # Keyword pairs are appended after the positional source's pairs.
            value = ImmutableMultiDict(value).multi_items() + ImmutableMultiDict(kwargs).multi_items()

        if not value:
            _items: list[tuple[Any, Any]] = []
        elif hasattr(value, "multi_items"):
            value = cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value)
            _items = list(value.multi_items())
        elif hasattr(value, "items"):
            value = cast(Mapping[_KeyType, _CovariantValueType], value)
            _items = list(value.items())
        else:
            value = cast("list[tuple[Any, Any]]", value)
            _items = list(value)

        self._dict = {k: v for k, v in _items}
        self._list = _items

    def getlist(self, key: Any) -> list[_CovariantValueType]:
        """Return every value stored for `key`, in insertion order (empty if absent)."""
        return [item_value for item_key, item_value in self._list if item_key == key]

    def keys(self) -> KeysView[_KeyType]:
        return self._dict.keys()

    def values(self) -> ValuesView[_CovariantValueType]:
        return self._dict.values()

    def items(self) -> ItemsView[_KeyType, _CovariantValueType]:
        return self._dict.items()

    def multi_items(self) -> list[tuple[_KeyType, _CovariantValueType]]:
        """Return a copy of all (key, value) pairs, duplicates included."""
        return list(self._list)

    def __getitem__(self, key: _KeyType) -> _CovariantValueType:
        return self._dict[key]

    def __contains__(self, key: Any) -> bool:
        return key in self._dict

    def __iter__(self) -> Iterator[_KeyType]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._dict)

    def __eq__(self, other: Any) -> bool:
        """Order-insensitive comparison of all (key, value) pairs, duplicates included."""
        if not isinstance(other, self.__class__):
            return False
        try:
            return sorted(self._list) == sorted(other._list)
        except TypeError:
            # Pairs are not mutually orderable (e.g. several UploadFile values under one
            # key). Fall back to a multiset comparison that only needs `==`.
            if len(self._list) != len(other._list):
                return False
            remaining = list(other._list)
            for item in self._list:
                for index, candidate in enumerate(remaining):
                    if candidate == item:
                        del remaining[index]
                        break
                else:
                    return False
            return True

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        items = self.multi_items()
        return f"{class_name}({items!r})"
320
321
class MultiDict(ImmutableMultiDict[Any, Any]):
    """
    A mutable multidict.

    Mutations keep `_list` (all pairs, insertion order) and `_dict` (last value per
    key) in step with each other.
    """

    def _pairs_without(self, key: Any) -> list[tuple[Any, Any]]:
        """Return the stored pairs whose key is not `key`, in their existing order."""
        return [pair for pair in self._list if pair[0] != key]

    def __setitem__(self, key: Any, value: Any) -> None:
        self.setlist(key, [value])

    def __delitem__(self, key: Any) -> None:
        self._list = self._pairs_without(key)
        del self._dict[key]

    def pop(self, key: Any, default: Any = None) -> Any:
        """Remove every value for `key`; return its last value, or `default` if absent."""
        self._list = self._pairs_without(key)
        return self._dict.pop(key, default)

    def popitem(self) -> tuple[Any, Any]:
        popped_key, popped_value = self._dict.popitem()
        self._list = self._pairs_without(popped_key)
        return popped_key, popped_value

    def poplist(self, key: Any) -> list[Any]:
        """Remove `key` entirely and return all of its values in insertion order."""
        removed = self.getlist(key)
        self.pop(key)
        return removed

    def clear(self) -> None:
        self._dict.clear()
        self._list.clear()

    def setdefault(self, key: Any, default: Any = None) -> Any:
        if key in self:
            return self[key]
        self._dict[key] = default
        self._list.append((key, default))
        return self[key]

    def setlist(self, key: Any, values: list[Any]) -> None:
        """Replace all values for `key` with `values`; an empty list removes the key."""
        if not values:
            self.pop(key, None)
            return
        self._list = self._pairs_without(key) + [(key, item) for item in values]
        self._dict[key] = values[-1]

    def append(self, key: Any, value: Any) -> None:
        """Add another value for `key`, keeping any existing ones."""
        self._list.append((key, value))
        self._dict[key] = value

    def update(
        self,
        *args: MultiDict | Mapping[Any, Any] | list[tuple[Any, Any]],
        **kwargs: Any,
    ) -> None:
        """Replace all values of every key present in the input with the input's values."""
        incoming = MultiDict(*args, **kwargs)
        replaced_keys = incoming.keys()
        kept = [pair for pair in self._list if pair[0] not in replaced_keys]
        self._list = kept + incoming.multi_items()
        self._dict.update(incoming)
376
377
class QueryParams(ImmutableMultiDict[str, str]):
    """
    An immutable multidict of query-string parameters.

    Accepts a raw query string (`str`, or `bytes` decoded as latin-1), or anything
    `ImmutableMultiDict` accepts. Keys and values are always coerced to `str`.
    """

    def __init__(
        self,
        *args: ImmutableMultiDict[Any, Any] | Mapping[Any, Any] | list[tuple[Any, Any]] | str | bytes,
        **kwargs: Any,
    ) -> None:
        assert len(args) < 2, "Too many arguments."

        source = args[0] if args else []
        if isinstance(source, bytes):
            source = source.decode("latin-1")

        if isinstance(source, str):
            super().__init__(parse_qsl(source, keep_blank_values=True), **kwargs)
        else:
            super().__init__(*args, **kwargs)  # type: ignore[arg-type]

        self._list = [(str(k), str(v)) for k, v in self._list]
        self._dict = {str(k): str(v) for k, v in self._dict.items()}

    def __str__(self) -> str:
        return urlencode(self._list)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({str(self)!r})"
408
409
class UploadFile:
    """
    An uploaded file included as part of the request data.

    Wraps a binary file object. File operations are handed to `run_in_threadpool`,
    except for a `SpooledTemporaryFile` that is still held in memory, where they
    run inline.
    """

    def __init__(
        self,
        file: BinaryIO,
        *,
        size: int | None = None,
        filename: str | None = None,
        headers: Headers | None = None,
    ) -> None:
        self.filename = filename
        self.file = file
        self.size = size
        self.headers = headers or Headers()

        # SpooledTemporaryFile's in-memory threshold, cached up front. 0 means
        # "unlimited", as in SpooledTemporaryFile.__init__; other file objects have
        # no `_max_size` and also get 0.
        self._max_mem_size = getattr(self.file, "_max_size", 0)

    @property
    def content_type(self) -> str | None:
        """The upload's ``content-type`` header, or None when it is absent."""
        return self.headers.get("content-type", None)

    @property
    def _in_memory(self) -> bool:
        """True only for a SpooledTemporaryFile that has not rolled over to disk yet."""
        # Objects without `_rolled` are treated as already on disk.
        return not getattr(self.file, "_rolled", True)

    def _will_roll(self, size_to_add: int) -> bool:
        """Whether writing `size_to_add` more bytes goes (or already goes) to disk."""
        if not self._in_memory:
            return True
        projected_size = self.file.tell() + size_to_add
        if not self._max_mem_size:
            return False
        return bool(projected_size > self._max_mem_size)

    async def write(self, data: bytes) -> None:
        chunk_length = len(data)
        if self.size is not None:
            self.size += chunk_length

        if not self._will_roll(chunk_length):
            self.file.write(data)
            return
        await run_in_threadpool(self.file.write, data)

    async def read(self, size: int = -1) -> bytes:
        if not self._in_memory:
            return await run_in_threadpool(self.file.read, size)
        return self.file.read(size)

    async def seek(self, offset: int) -> None:
        if not self._in_memory:
            await run_in_threadpool(self.file.seek, offset)
            return
        self.file.seek(offset)

    async def close(self) -> None:
        if not self._in_memory:
            await run_in_threadpool(self.file.close)
            return
        self.file.close()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(filename={self.filename!r}, size={self.size!r}, headers={self.headers!r})"
480
481
class FormData(ImmutableMultiDict[str, UploadFile | str]):
    """
    An immutable multidict of submitted form fields: `UploadFile` objects for file
    uploads and plain strings for text input.
    """

    def __init__(
        self,
        *args: FormData | Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]],
        **kwargs: str | UploadFile,
    ) -> None:
        super().__init__(*args, **kwargs)

    async def close(self) -> None:
        """Close every `UploadFile` value in turn; string values are left untouched."""
        uploads = [value for _, value in self.multi_items() if isinstance(value, UploadFile)]
        for upload in uploads:
            await upload.close()
498
499
class Headers(Mapping[str, str]):
    """
    An immutable, case-insensitive multidict.

    Headers are stored as a list of raw ``(name, value)`` byte pairs. Names given via
    `headers=` are lower-cased on construction. Names from `raw=` / `scope=` are used
    as-is, so case-insensitive lookup assumes they are already lower-case. Lookups
    lower-case the requested name and encode it as latin-1. `keys()`, `values()`,
    `items()`, iteration and `len()` all include duplicate entries.
    """

    def __init__(
        self,
        headers: Mapping[str, str] | None = None,
        raw: list[tuple[bytes, bytes]] | None = None,
        scope: MutableMapping[str, Any] | None = None,
    ) -> None:
        self._list: list[tuple[bytes, bytes]] = []
        if headers is not None:
            assert raw is None, 'Cannot set both "headers" and "raw".'
            assert scope is None, 'Cannot set both "headers" and "scope".'
            self._list = [(key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in headers.items()]
        elif raw is not None:
            assert scope is None, 'Cannot set both "raw" and "scope".'
            # Shares the caller's list rather than copying it.
            self._list = raw
        elif scope is not None:
            # scope["headers"] isn't necessarily a list
            # it might be a tuple or other iterable
            # The list is written back so this object and the scope share one list.
            self._list = scope["headers"] = list(scope["headers"])

    @property
    def raw(self) -> list[tuple[bytes, bytes]]:
        """A copy of the underlying ``(name, value)`` byte pairs."""
        return list(self._list)

    def keys(self) -> list[str]:  # type: ignore[override]
        return [key.decode("latin-1") for key, value in self._list]

    def values(self) -> list[str]:  # type: ignore[override]
        return [value.decode("latin-1") for key, value in self._list]

    def items(self) -> list[tuple[str, str]]:  # type: ignore[override]
        return [(key.decode("latin-1"), value.decode("latin-1")) for key, value in self._list]

    def getlist(self, key: str) -> list[str]:
        """Return every value for header `key`, in order (empty if absent)."""
        get_header_key = key.lower().encode("latin-1")
        return [item_value.decode("latin-1") for item_key, item_value in self._list if item_key == get_header_key]

    def mutablecopy(self) -> MutableHeaders:
        return MutableHeaders(raw=self._list[:])

    def __getitem__(self, key: str) -> str:
        """Return the first value for header `key`; raise KeyError if absent."""
        get_header_key = key.lower().encode("latin-1")
        for header_key, header_value in self._list:
            if header_key == get_header_key:
                return header_value.decode("latin-1")
        raise KeyError(key)

    def __contains__(self, key: Any) -> bool:
        # Header names are strings; any other key (None, int, bytes, ...) is simply
        # absent instead of raising AttributeError from `.lower()`.
        if not isinstance(key, str):
            return False
        get_header_key = key.lower().encode("latin-1")
        return any(header_key == get_header_key for header_key, _ in self._list)

    def __iter__(self) -> Iterator[Any]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._list)

    def __eq__(self, other: Any) -> bool:
        # Order-insensitive comparison of the raw pairs, duplicates included.
        if not isinstance(other, Headers):
            return False
        return sorted(self._list) == sorted(other._list)

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        as_dict = dict(self.items())
        # A dict repr would hide duplicate headers, so only use it when there are none.
        if len(as_dict) == len(self):
            return f"{class_name}({as_dict!r})"
        return f"{class_name}(raw={self.raw!r})"
575
576
class MutableHeaders(Headers):
    """A `Headers` multidict that can be modified in place (its `raw` list is live)."""

    def __setitem__(self, key: str, value: str) -> None:
        """
        Set the header `key` to `value`, removing any duplicate entries.
        Retains insertion order.
        """
        set_key = key.lower().encode("latin-1")
        set_value = value.encode("latin-1")

        matches = [idx for idx, (item_key, _) in enumerate(self._list) if item_key == set_key]
        if not matches:
            self._list.append((set_key, set_value))
            return

        first, *duplicates = matches
        # Delete from the back so earlier indexes stay valid.
        for idx in reversed(duplicates):
            del self._list[idx]
        self._list[first] = (set_key, set_value)

    def __delitem__(self, key: str) -> None:
        """
        Remove the header `key`.
        """
        del_key = key.lower().encode("latin-1")
        doomed = [idx for idx, (item_key, _) in enumerate(self._list) if item_key == del_key]
        # Delete in place (from the back) so any list shared with a scope sees it.
        for idx in reversed(doomed):
            del self._list[idx]

    def __ior__(self, other: Mapping[str, str]) -> MutableHeaders:
        if not isinstance(other, Mapping):
            raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
        self.update(other)
        return self

    def __or__(self, other: Mapping[str, str]) -> MutableHeaders:
        if not isinstance(other, Mapping):
            raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
        merged = self.mutablecopy()
        merged.update(other)
        return merged

    @property
    def raw(self) -> list[tuple[bytes, bytes]]:
        # Unlike Headers.raw, this is the live list, not a copy.
        return self._list

    def setdefault(self, key: str, value: str) -> str:
        """
        If the header `key` does not exist, then set it to `value`.
        Returns the header value.
        """
        set_key = key.lower().encode("latin-1")
        set_value = value.encode("latin-1")

        for item_key, item_value in self._list:
            if item_key == set_key:
                return item_value.decode("latin-1")

        self._list.append((set_key, set_value))
        return value

    def update(self, other: Mapping[str, str]) -> None:
        """Set each header from `other`, replacing any existing entries for it."""
        for name, header_value in other.items():
            self[name] = header_value

    def append(self, key: str, value: str) -> None:
        """
        Append a header, preserving any duplicate entries.
        """
        self._list.append((key.lower().encode("latin-1"), value.encode("latin-1")))

    def add_vary_header(self, vary: str) -> None:
        """Add `vary` to the Vary header, comma-joining it onto any existing value."""
        current = self.get("vary")
        self["vary"] = vary if current is None else ", ".join([current, vary])
662
663
664class State:
665 """
666 An object that can be used to store arbitrary state.
667
668 Used for `request.state` and `app.state`.
669 """
670
671 _state: dict[str, Any]
672
673 def __init__(self, state: dict[str, Any] | None = None):
674 if state is None:
675 state = {}
676 super().__setattr__("_state", state)
677
678 def __setattr__(self, key: Any, value: Any) -> None:
679 self._state[key] = value
680
681 def __getattr__(self, key: Any) -> Any:
682 try:
683 return self._state[key]
684 except KeyError:
685 message = "'{}' object has no attribute '{}'"
686 raise AttributeError(message.format(self.__class__.__name__, key))
687
688 def __delattr__(self, key: Any) -> None:
689 del self._state[key]
690
691 def __getitem__(self, key: str) -> Any:
692 return self._state[key]
693
694 def __setitem__(self, key: str, value: Any) -> None:
695 self._state[key] = value
696
697 def __delitem__(self, key: str) -> None:
698 del self._state[key]
699
700 def __iter__(self) -> Iterator[str]:
701 return iter(self._state)
702
703 def __len__(self) -> int:
704 return len(self._state)