1from __future__ import annotations
2
3import typing
4from shlex import shlex
5from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit
6
7from starlette.concurrency import run_in_threadpool
8from starlette.types import Scope
9
10
class Address(typing.NamedTuple):
    """A named ``(host, port)`` pair."""

    # Host portion, as a string.
    host: str
    # Port number.
    port: int
14
15
# Type variables used to parameterize ImmutableMultiDict and its subclasses.
_KeyType = typing.TypeVar("_KeyType")
# Mapping keys are invariant, but values are covariant because a read-only
# Mapping only ever hands values out. That is, you can't do
# `Mapping[str, Animal]()["fido"] = Dog()`.
_CovariantValueType = typing.TypeVar("_CovariantValueType", covariant=True)
21
22
class URL:
    """
    A URL, built from a string, an ASGI-style ``scope`` mapping, or components.

    * ``URL("https://example.com/path?q=1")`` wraps an existing URL string.
    * ``URL(scope=scope)`` rebuilds the request URL from a scope mapping.
    * ``URL(scheme=..., netloc=..., path=...)`` assembles a URL from any
      keyword arguments accepted by :meth:`replace`.

    The string is split with ``urllib.parse.urlsplit`` lazily, on first access
    to :attr:`components`, and the result is cached.
    """

    def __init__(
        self,
        url: str = "",
        scope: Scope | None = None,
        **components: typing.Any,
    ) -> None:
        if scope is not None:
            assert not url, 'Cannot set both "url" and "scope".'
            assert not components, 'Cannot set both "scope" and "**components".'
            scheme = scope.get("scheme", "http")
            server = scope.get("server", None)
            path = scope["path"]
            query_string = scope.get("query_string", b"")

            # The first header whose name is exactly b"host" wins.
            host_header = None
            for key, value in scope["headers"]:
                if key == b"host":
                    host_header = value.decode("latin-1")
                    break

            if host_header is not None:
                url = f"{scheme}://{host_header}{path}"
            elif server is None:
                url = path
            else:
                host, port = server
                # Omit the port when it is the default for the scheme. A
                # scheme missing from this table raises KeyError.
                default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}[scheme]
                if port == default_port:
                    url = f"{scheme}://{host}{path}"
                else:
                    url = f"{scheme}://{host}:{port}{path}"

            if query_string:
                url += "?" + query_string.decode()
        elif components:
            assert not url, 'Cannot set both "url" and "**components".'
            url = URL("").replace(**components).components.geturl()

        self._url = url

    @property
    def components(self) -> SplitResult:
        """The ``urlsplit`` result for this URL, computed once and cached."""
        if not hasattr(self, "_components"):
            self._components = urlsplit(self._url)
        return self._components

    @property
    def scheme(self) -> str:
        return self.components.scheme

    @property
    def netloc(self) -> str:
        return self.components.netloc

    @property
    def path(self) -> str:
        return self.components.path

    @property
    def query(self) -> str:
        return self.components.query

    @property
    def fragment(self) -> str:
        return self.components.fragment

    @property
    def username(self) -> None | str:
        return self.components.username

    @property
    def password(self) -> None | str:
        return self.components.password

    @property
    def hostname(self) -> None | str:
        return self.components.hostname

    @property
    def port(self) -> int | None:
        return self.components.port

    @property
    def is_secure(self) -> bool:
        """True for the ``https`` and ``wss`` schemes."""
        return self.scheme in ("https", "wss")

    def replace(self, **kwargs: typing.Any) -> URL:
        """
        Return a new URL with the given components replaced.

        Accepts any ``SplitResult`` field (``scheme``, ``netloc``, ``path``,
        ``query``, ``fragment``). It also accepts ``username``, ``password``,
        ``hostname`` and ``port``, which are recombined into a new ``netloc``.
        Any of those four not supplied is taken from this URL.
        """
        if (
            "username" in kwargs
            or "password" in kwargs
            or "hostname" in kwargs
            or "port" in kwargs
        ):
            hostname = kwargs.pop("hostname", None)
            port = kwargs.pop("port", self.port)
            username = kwargs.pop("username", self.username)
            password = kwargs.pop("password", self.password)

            if hostname is None:
                # Recover the host from the current netloc. First drop any
                # "user:pass@" prefix. Then drop a trailing ":port", unless
                # the host is a bracketed IPv6 literal such as "[::1]".
                netloc = self.netloc
                _, _, hostname = netloc.rpartition("@")

                # Use endswith() rather than hostname[-1]. The netloc may be
                # empty (e.g. for a relative URL like "/path"), and indexing
                # an empty string raises IndexError.
                if not hostname.endswith("]"):
                    hostname = hostname.rsplit(":", 1)[0]

            netloc = hostname
            if port is not None:
                netloc += f":{port}"
            if username is not None:
                userpass = username
                if password is not None:
                    userpass += f":{password}"
                netloc = f"{userpass}@{netloc}"

            kwargs["netloc"] = netloc

        components = self.components._replace(**kwargs)
        return self.__class__(components.geturl())

    def include_query_params(self, **kwargs: typing.Any) -> URL:
        """Return a URL with ``kwargs`` added to (or overriding) the query."""
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        params.update({str(key): str(value) for key, value in kwargs.items()})
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def replace_query_params(self, **kwargs: typing.Any) -> URL:
        """Return a URL whose query consists solely of ``kwargs``."""
        query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
        return self.replace(query=query)

    def remove_query_params(self, keys: str | typing.Sequence[str]) -> URL:
        """Return a URL with every query parameter named in ``keys`` removed."""
        if isinstance(keys, str):
            keys = [keys]
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        for key in keys:
            params.pop(key, None)
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def __eq__(self, other: typing.Any) -> bool:
        # Compares string forms, so a URL also equals an identical plain str.
        return str(self) == str(other)

    def __str__(self) -> str:
        return self._url

    def __repr__(self) -> str:
        # Mask any password so it does not leak into logs or tracebacks.
        url = str(self)
        if self.password:
            url = str(self.replace(password="********"))
        return f"{self.__class__.__name__}({repr(url)})"
173
174
class URLPath(str):
    """
    A URL path string that may also hold an associated protocol and/or host.
    Used by the routing to return `url_path_for` matches.

    ``protocol`` must be ``"http"``, ``"websocket"`` or ``""``. When set, it
    decides the scheme used by :meth:`make_absolute_url`.
    """

    def __new__(cls, path: str, protocol: str = "", host: str = "") -> URLPath:
        assert protocol in ("http", "websocket", "")
        return str.__new__(cls, path)

    def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
        self.protocol = protocol
        self.host = host

    def make_absolute_url(self, base_url: str | URL) -> URL:
        """
        Join this path onto ``base_url`` and return the absolute URL.

        The scheme comes from ``protocol`` (secure or not, matching the base
        URL) or, when no protocol is set, from the base URL itself. The host
        comes from ``self.host`` when set, otherwise from the base netloc.
        """
        base = URL(base_url) if isinstance(base_url, str) else base_url

        if not self.protocol:
            scheme = base.scheme
        else:
            # (insecure, secure) scheme pair, indexed by the base URL's
            # is_secure flag.
            insecure_secure = {
                "http": ("http", "https"),
                "websocket": ("ws", "wss"),
            }[self.protocol]
            scheme = insecure_secure[base.is_secure]

        return URL(
            scheme=scheme,
            netloc=self.host or base.netloc,
            path=base.path.rstrip("/") + str(self),
        )
203
204
class Secret:
    """
    Holds a string value that should not be revealed in tracebacks etc.
    You should cast the value to `str` at the point it is required.

    ``repr()`` always shows a fixed mask; ``str()`` returns the real value;
    truthiness follows the wrapped value.
    """

    def __init__(self, value: str):
        self._value = value

    def __repr__(self) -> str:
        masked = "'**********'"
        return f"{self.__class__.__name__}({masked})"

    def __str__(self) -> str:
        return self._value

    def __bool__(self) -> bool:
        return bool(self._value)
223
224
class CommaSeparatedStrings(typing.Sequence[str]):
    """
    A read-only sequence of strings.

    A ``str`` argument is split on commas with POSIX ``shlex`` rules, so
    quoted items may themselves contain commas; each item is stripped of
    surrounding whitespace. Any other sequence is copied as-is.
    """

    def __init__(self, value: str | typing.Sequence[str]):
        if not isinstance(value, str):
            self._items = list(value)
            return

        lexer = shlex(value, posix=True)
        # Treat only commas as separators; spaces stay inside tokens and are
        # stripped below.
        lexer.whitespace = ","
        lexer.whitespace_split = True
        self._items = [token.strip() for token in lexer]

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, index: int | slice) -> typing.Any:
        return self._items[index]

    def __iter__(self) -> typing.Iterator[str]:
        return iter(self._items)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({list(self)!r})"

    def __str__(self) -> str:
        return ", ".join(map(repr, self))
251
252
class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]):
    """
    A read-only mapping that may hold several values per key.

    ``_list`` keeps every ``(key, value)`` pair in insertion order.
    ``_dict`` maps each key to the last value seen for it, and is what plain
    ``Mapping`` access (``[]``, ``keys``, ``values``, ``items``) reads.
    """

    _dict: dict[_KeyType, _CovariantValueType]

    def __init__(
        self,
        *args: ImmutableMultiDict[_KeyType, _CovariantValueType]
        | typing.Mapping[_KeyType, _CovariantValueType]
        | typing.Iterable[tuple[_KeyType, _CovariantValueType]],
        **kwargs: typing.Any,
    ) -> None:
        assert len(args) < 2, "Too many arguments."

        source: typing.Any = args[0] if args else []
        if kwargs:
            # Keyword arguments are appended after the positional source.
            source = (
                ImmutableMultiDict(source).multi_items()
                + ImmutableMultiDict(kwargs).multi_items()
            )

        pairs: list[tuple[typing.Any, typing.Any]]
        if not source:
            pairs = []
        elif hasattr(source, "multi_items"):
            # Another multidict: keep its duplicate keys.
            pairs = list(source.multi_items())
        elif hasattr(source, "items"):
            # A plain mapping: exactly one pair per key.
            pairs = list(source.items())
        else:
            # Any other iterable of (key, value) pairs.
            pairs = list(source)

        # Later pairs overwrite earlier ones, so _dict holds the last value.
        self._dict = {k: v for k, v in pairs}
        self._list = pairs

    def getlist(self, key: typing.Any) -> list[_CovariantValueType]:
        """Return every value stored for ``key``, in insertion order."""
        return [v for k, v in self._list if k == key]

    def keys(self) -> typing.KeysView[_KeyType]:
        return self._dict.keys()

    def values(self) -> typing.ValuesView[_CovariantValueType]:
        return self._dict.values()

    def items(self) -> typing.ItemsView[_KeyType, _CovariantValueType]:
        return self._dict.items()

    def multi_items(self) -> list[tuple[_KeyType, _CovariantValueType]]:
        """Return a copy of all ``(key, value)`` pairs, duplicates included."""
        return self._list.copy()

    def __getitem__(self, key: _KeyType) -> _CovariantValueType:
        return self._dict[key]

    def __contains__(self, key: typing.Any) -> bool:
        return key in self._dict

    def __iter__(self) -> typing.Iterator[_KeyType]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._dict)

    def __eq__(self, other: typing.Any) -> bool:
        # Order-insensitive comparison of all pairs, for the same class only.
        return isinstance(other, self.__class__) and sorted(self._list) == sorted(
            other._list
        )

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.multi_items()!r})"
325
326
class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]):
    """
    A mutable multidict.

    Every mutation updates both ``_list`` (all pairs, in order) and ``_dict``
    (last value per key) so the two views stay consistent.
    """

    def _without(self, key: typing.Any) -> list[tuple[typing.Any, typing.Any]]:
        """Return a new pair list with every entry for ``key`` removed."""
        return [(k, v) for k, v in self._list if k != key]

    def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
        """Replace all values for ``key`` with the single ``value``."""
        self.setlist(key, [value])

    def __delitem__(self, key: typing.Any) -> None:
        """Remove every value for ``key``; raises KeyError if it is absent."""
        self._list = self._without(key)
        del self._dict[key]

    def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
        """Remove ``key`` entirely and return its last value (or ``default``)."""
        self._list = self._without(key)
        return self._dict.pop(key, default)

    def popitem(self) -> tuple[typing.Any, typing.Any]:
        """Remove the most recently inserted key (per ``_dict``) and its values."""
        key, value = self._dict.popitem()
        self._list = self._without(key)
        return key, value

    def poplist(self, key: typing.Any) -> list[typing.Any]:
        """Remove ``key`` and return all of its values."""
        values = self.getlist(key)
        self.pop(key)
        return values

    def clear(self) -> None:
        self._dict.clear()
        self._list.clear()

    def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
        """Insert ``key`` with ``default`` if absent, then return its value."""
        if key not in self:
            self._dict[key] = default
            self._list.append((key, default))
        return self[key]

    def setlist(self, key: typing.Any, values: list[typing.Any]) -> None:
        """Replace all values for ``key``; an empty list removes the key."""
        if not values:
            self.pop(key, None)
            return
        self._list = self._without(key) + [(key, v) for v in values]
        self._dict[key] = values[-1]

    def append(self, key: typing.Any, value: typing.Any) -> None:
        """Add another value for ``key``, keeping existing ones."""
        self._list.append((key, value))
        self._dict[key] = value

    def update(
        self,
        *args: MultiDict
        | typing.Mapping[typing.Any, typing.Any]
        | list[tuple[typing.Any, typing.Any]],
        **kwargs: typing.Any,
    ) -> None:
        """Replace every key present in the arguments with its new value(s)."""
        incoming = MultiDict(*args, **kwargs)
        replaced = incoming.keys()
        kept = [(k, v) for k, v in self._list if k not in replaced]
        self._list = kept + incoming.multi_items()
        self._dict.update(incoming)
383
384
class QueryParams(ImmutableMultiDict[str, str]):
    """
    An immutable multidict.

    A ``str`` argument is parsed as a query string with blank values kept.
    A ``bytes`` argument is decoded as latin-1 first. Any other argument is
    handled like ImmutableMultiDict. Keys and values are always coerced to
    ``str``.
    """

    def __init__(
        self,
        *args: ImmutableMultiDict[typing.Any, typing.Any]
        | typing.Mapping[typing.Any, typing.Any]
        | list[tuple[typing.Any, typing.Any]]
        | str
        | bytes,
        **kwargs: typing.Any,
    ) -> None:
        assert len(args) < 2, "Too many arguments."

        source = args[0] if args else []

        # Raw bytes are treated as latin-1 text before parsing.
        if isinstance(source, bytes):
            source = source.decode("latin-1")

        if isinstance(source, str):
            super().__init__(parse_qsl(source, keep_blank_values=True), **kwargs)
        else:
            super().__init__(*args, **kwargs)  # type: ignore[arg-type]

        self._list = [(str(k), str(v)) for k, v in self._list]
        self._dict = {str(k): str(v) for k, v in self._dict.items()}

    def __str__(self) -> str:
        return urlencode(self._list)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({str(self)!r})"
421
422
class UploadFile:
    """
    An uploaded file included as part of the request data.

    File operations run directly while ``file`` is still held in memory
    (a SpooledTemporaryFile that has not rolled over). Otherwise they are
    offloaded with ``run_in_threadpool``.
    """

    def __init__(
        self,
        file: typing.BinaryIO,
        *,
        size: int | None = None,
        filename: str | None = None,
        headers: Headers | None = None,
    ) -> None:
        self.filename = filename
        self.file = file
        # None means "size unknown"; write() only tracks a known size.
        self.size = size
        # A missing (or empty, hence falsy) headers object becomes Headers().
        self.headers = headers or Headers()

    @property
    def content_type(self) -> str | None:
        return self.headers.get("content-type", None)

    @property
    def _in_memory(self) -> bool:
        # SpooledTemporaryFile exposes _rolled; objects without it are
        # treated as already on disk.
        return not getattr(self.file, "_rolled", True)

    async def write(self, data: bytes) -> None:
        if self.size is not None:
            self.size += len(data)

        if not self._in_memory:
            await run_in_threadpool(self.file.write, data)
            return
        self.file.write(data)

    async def read(self, size: int = -1) -> bytes:
        if not self._in_memory:
            return await run_in_threadpool(self.file.read, size)
        return self.file.read(size)

    async def seek(self, offset: int) -> None:
        if not self._in_memory:
            await run_in_threadpool(self.file.seek, offset)
            return
        self.file.seek(offset)

    async def close(self) -> None:
        if not self._in_memory:
            await run_in_threadpool(self.file.close)
            return
        self.file.close()

    def __repr__(self) -> str:
        name = self.__class__.__name__
        fields = (
            f"filename={self.filename!r}, "
            f"size={self.size!r}, "
            f"headers={self.headers!r}"
        )
        return f"{name}({fields})"
484
485
class FormData(ImmutableMultiDict[str, typing.Union[UploadFile, str]]):
    """
    An immutable multidict, containing both file uploads and text input.
    """

    def __init__(
        self,
        *args: FormData
        | typing.Mapping[str, str | UploadFile]
        | list[tuple[str, str | UploadFile]],
        **kwargs: str | UploadFile,
    ) -> None:
        super().__init__(*args, **kwargs)

    async def close(self) -> None:
        """Close every UploadFile value, in order; text values are skipped."""
        uploads = [
            value for _, value in self.multi_items() if isinstance(value, UploadFile)
        ]
        for upload in uploads:
            await upload.close()
504
505
class Headers(typing.Mapping[str, str]):
    """
    An immutable, case-insensitive multidict.

    Entries are kept in ``_list`` as ``(name, value)`` byte pairs. Names given
    through ``headers=`` are lowercased. Lookups lowercase the requested key
    before comparing. Text is encoded and decoded as latin-1.
    """

    def __init__(
        self,
        headers: typing.Mapping[str, str] | None = None,
        raw: list[tuple[bytes, bytes]] | None = None,
        scope: typing.MutableMapping[str, typing.Any] | None = None,
    ) -> None:
        self._list: list[tuple[bytes, bytes]] = []
        if headers is not None:
            assert raw is None, 'Cannot set both "headers" and "raw".'
            assert scope is None, 'Cannot set both "headers" and "scope".'
            self._list = [
                (self._encode_key(name), text.encode("latin-1"))
                for name, text in headers.items()
            ]
        elif raw is not None:
            assert scope is None, 'Cannot set both "raw" and "scope".'
            # Shared, not copied: the caller's list backs this instance.
            self._list = raw
        elif scope is not None:
            # scope["headers"] may be any iterable (e.g. a tuple). Normalize it
            # to a list and store it back, so this object and the scope share it.
            self._list = scope["headers"] = list(scope["headers"])

    @staticmethod
    def _encode_key(key: str) -> bytes:
        """Normalize a header name to lowercase latin-1 bytes."""
        return key.lower().encode("latin-1")

    @property
    def raw(self) -> list[tuple[bytes, bytes]]:
        """A copy of the underlying ``(name, value)`` byte pairs."""
        return list(self._list)

    def keys(self) -> list[str]:  # type: ignore[override]
        return [name.decode("latin-1") for name, _ in self._list]

    def values(self) -> list[str]:  # type: ignore[override]
        return [text.decode("latin-1") for _, text in self._list]

    def items(self) -> list[tuple[str, str]]:  # type: ignore[override]
        return [
            (name.decode("latin-1"), text.decode("latin-1"))
            for name, text in self._list
        ]

    def getlist(self, key: str) -> list[str]:
        """Return every value for ``key`` (case-insensitive), in order."""
        wanted = self._encode_key(key)
        return [text.decode("latin-1") for name, text in self._list if name == wanted]

    def mutablecopy(self) -> MutableHeaders:
        return MutableHeaders(raw=list(self._list))

    def __getitem__(self, key: str) -> str:
        # Returns the first matching value; use getlist() for all of them.
        wanted = self._encode_key(key)
        for name, text in self._list:
            if name == wanted:
                return text.decode("latin-1")
        raise KeyError(key)

    def __contains__(self, key: typing.Any) -> bool:
        wanted = self._encode_key(key)
        return any(name == wanted for name, _ in self._list)

    def __iter__(self) -> typing.Iterator[typing.Any]:
        return iter(self.keys())

    def __len__(self) -> int:
        # Counts duplicate header names separately.
        return len(self._list)

    def __eq__(self, other: typing.Any) -> bool:
        # Order-insensitive comparison of the raw pairs.
        return isinstance(other, Headers) and sorted(self._list) == sorted(other._list)

    def __repr__(self) -> str:
        name = self.__class__.__name__
        as_dict = dict(self.items())
        # Fall back to the raw form when duplicates would be lost in a dict.
        if len(as_dict) != len(self):
            return f"{name}(raw={self.raw!r})"
        return f"{name}({as_dict!r})"
591
592
class MutableHeaders(Headers):
    """
    A mutable variant of Headers.

    Every mutation edits ``_list`` in place, so a list shared with a ``raw=``
    argument or a ``scope`` observes the changes.
    """

    @staticmethod
    def _require_mapping(other: typing.Any) -> None:
        """Raise TypeError unless ``other`` is a Mapping (used by ``|``/``|=``)."""
        if not isinstance(other, typing.Mapping):
            raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")

    def __setitem__(self, key: str, value: str) -> None:
        """
        Set the header `key` to `value`, removing any duplicate entries.
        Retains insertion order.
        """
        name = key.lower().encode("latin-1")
        encoded = value.encode("latin-1")

        matches = [
            position
            for position, (item_key, _) in enumerate(self._list)
            if item_key == name
        ]
        # Delete duplicates after the first, back to front, so earlier
        # positions remain valid during deletion.
        for position in reversed(matches[1:]):
            del self._list[position]

        if matches:
            # Overwrite the first occurrence so it keeps its position.
            self._list[matches[0]] = (name, encoded)
        else:
            self._list.append((name, encoded))

    def __delitem__(self, key: str) -> None:
        """
        Remove the header `key`.
        """
        name = key.lower().encode("latin-1")
        matches = [
            position
            for position, (item_key, _) in enumerate(self._list)
            if item_key == name
        ]
        # Back to front so earlier positions remain valid.
        for position in reversed(matches):
            del self._list[position]

    def __ior__(self, other: typing.Mapping[str, str]) -> MutableHeaders:
        self._require_mapping(other)
        self.update(other)
        return self

    def __or__(self, other: typing.Mapping[str, str]) -> MutableHeaders:
        self._require_mapping(other)
        merged = self.mutablecopy()
        merged.update(other)
        return merged

    @property
    def raw(self) -> list[tuple[bytes, bytes]]:
        """The live underlying list (not a copy)."""
        return self._list

    def setdefault(self, key: str, value: str) -> str:
        """
        If the header `key` does not exist, then set it to `value`.
        Returns the header value.
        """
        name = key.lower().encode("latin-1")
        encoded = value.encode("latin-1")

        for item_key, item_value in self._list:
            if item_key == name:
                return item_value.decode("latin-1")
        self._list.append((name, encoded))
        return value

    def update(self, other: typing.Mapping[str, str]) -> None:
        """Set each header in ``other``, replacing existing duplicates."""
        for key, val in other.items():
            self[key] = val

    def append(self, key: str, value: str) -> None:
        """
        Append a header, preserving any duplicate entries.
        """
        self._list.append((key.lower().encode("latin-1"), value.encode("latin-1")))

    def add_vary_header(self, vary: str) -> None:
        """Add ``vary`` to the Vary header, comma-joining any existing value."""
        existing = self.get("vary")
        self["vary"] = vary if existing is None else ", ".join([existing, vary])
678
679
class State:
    """
    An object that can be used to store arbitrary state.

    Used for `request.state` and `app.state`.

    Attribute reads, writes and deletes are forwarded to the backing
    ``_state`` dict, so ``state.foo = 1`` stores ``{"foo": 1}``.
    """

    _state: dict[str, typing.Any]

    def __init__(self, state: dict[str, typing.Any] | None = None):
        if state is None:
            state = {}
        # Bypass our own __setattr__, which forwards into self._state.
        super().__setattr__("_state", state)

    def __setattr__(self, key: typing.Any, value: typing.Any) -> None:
        self._state[key] = value

    def __getattr__(self, key: typing.Any) -> typing.Any:
        # __getattr__ only runs after normal lookup fails. If "_state" itself
        # is missing, evaluating self._state below would re-enter __getattr__
        # without end (RecursionError). This happens when the instance was
        # created without __init__, as copy.copy()/deepcopy() do before
        # restoring __dict__. Report it as a plain missing attribute instead.
        if key == "_state":
            message = "'{}' object has no attribute '{}'"
            raise AttributeError(message.format(self.__class__.__name__, key))
        try:
            return self._state[key]
        except KeyError:
            message = "'{}' object has no attribute '{}'"
            raise AttributeError(message.format(self.__class__.__name__, key))

    def __delattr__(self, key: typing.Any) -> None:
        del self._state[key]