1from __future__ import annotations
2
3import typing
4from shlex import shlex
5from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit
6
7from starlette.concurrency import run_in_threadpool
8from starlette.types import Scope
9
10
class Address(typing.NamedTuple):
    """An immutable ``(host, port)`` pair."""

    # Host component of the address.
    host: str
    # Port component of the address.
    port: int
14
15
# Type variables for the read-only multidict below. A Mapping's key type must be
# invariant, but its value type can be covariant: a read-only Mapping never accepts
# values from the caller, so `Mapping[str, Animal]()["fido"] = Dog()` is impossible.
_KeyType = typing.TypeVar(
    "_KeyType",
)
_CovariantValueType = typing.TypeVar(
    "_CovariantValueType",
    covariant=True,
)
21
22
class URL:
    """
    A URL string with lazily parsed, read-only access to its components.

    Construct it in exactly one of three ways:

    * ``URL("https://example.com/path?q=1")``: from a URL string;
    * ``URL(scope=scope)``: from an ASGI connection scope, using the first
      ``host`` header when present and otherwise the ``server`` entry;
    * ``URL(scheme="https", netloc="example.com", path="/")``: from keyword
      components, applied to an empty URL through :meth:`replace`.

    Methods that "modify" the URL return a new ``URL``; ``str(url)`` gives the text.
    """

    def __init__(
        self,
        url: str = "",
        scope: Scope | None = None,
        **components: typing.Any,
    ) -> None:
        if scope is not None:
            assert not url, 'Cannot set both "url" and "scope".'
            assert not components, 'Cannot set both "scope" and "**components".'
            scheme = scope.get("scheme", "http")
            server = scope.get("server", None)
            path = scope["path"]
            query_string = scope.get("query_string", b"")

            # The first ``host`` header, if any, supplies the netloc verbatim.
            host_header = None
            for key, value in scope["headers"]:
                if key == b"host":
                    host_header = value.decode("latin-1")
                    break

            if host_header is not None:
                url = f"{scheme}://{host_header}{path}"
            elif server is None:
                # No host information at all: fall back to a relative URL.
                url = path
            else:
                host, port = server
                # Omit the port when it is the scheme's well-known default, or
                # when the server reports none (ASGI allows ``port`` to be None,
                # e.g. for Unix domain sockets) -- otherwise the URL would read
                # "host:None". Schemes without a known default keep an explicit
                # port instead of failing with KeyError.
                default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}.get(scheme)
                if port is None or port == default_port:
                    url = f"{scheme}://{host}{path}"
                else:
                    url = f"{scheme}://{host}:{port}{path}"

            if query_string:
                url += "?" + query_string.decode()
        elif components:
            assert not url, 'Cannot set both "url" and "**components".'
            url = URL("").replace(**components).components.geturl()

        self._url = url

    @property
    def components(self) -> SplitResult:
        """The ``urlsplit`` result for this URL, computed once and cached."""
        if not hasattr(self, "_components"):
            self._components = urlsplit(self._url)
        return self._components

    @property
    def scheme(self) -> str:
        return self.components.scheme

    @property
    def netloc(self) -> str:
        return self.components.netloc

    @property
    def path(self) -> str:
        return self.components.path

    @property
    def query(self) -> str:
        return self.components.query

    @property
    def fragment(self) -> str:
        return self.components.fragment

    @property
    def username(self) -> None | str:
        return self.components.username

    @property
    def password(self) -> None | str:
        return self.components.password

    @property
    def hostname(self) -> None | str:
        return self.components.hostname

    @property
    def port(self) -> int | None:
        return self.components.port

    @property
    def is_secure(self) -> bool:
        """True for the ``https`` and ``wss`` schemes."""
        return self.scheme in ("https", "wss")

    def replace(self, **kwargs: typing.Any) -> URL:
        """
        Return a new URL with the given components replaced.

        Accepts any ``SplitResult`` field (``scheme``, ``netloc``, ``path``,
        ``query``, ``fragment``). It also accepts ``username``, ``password``,
        ``hostname`` and ``port``; these are merged with the current values to
        rebuild the netloc.
        """
        if any(name in kwargs for name in ("username", "password", "hostname", "port")):
            hostname = kwargs.pop("hostname", None)
            port = kwargs.pop("port", self.port)
            username = kwargs.pop("username", self.username)
            password = kwargs.pop("password", self.password)

            if hostname is None:
                netloc = self.netloc
                _, _, hostname = netloc.rpartition("@")

                # Drop an existing ":port" suffix, but leave a bare bracketed
                # IPv6 literal such as "[::1]" intact. ``endswith`` (rather than
                # indexing ``[-1]``) also copes with an empty netloc, e.g. a
                # relative URL such as "/path".
                if not hostname.endswith("]"):
                    hostname = hostname.rsplit(":", 1)[0]

            netloc = hostname
            if port is not None:
                netloc += f":{port}"
            if username is not None:
                userpass = username
                if password is not None:
                    userpass += f":{password}"
                netloc = f"{userpass}@{netloc}"

            kwargs["netloc"] = netloc

        components = self.components._replace(**kwargs)
        return self.__class__(components.geturl())

    def include_query_params(self, **kwargs: typing.Any) -> URL:
        """Return a new URL with ``kwargs`` added or overriding existing query params."""
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        params.update({str(key): str(value) for key, value in kwargs.items()})
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def replace_query_params(self, **kwargs: typing.Any) -> URL:
        """Return a new URL whose query string consists solely of ``kwargs``."""
        query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
        return self.replace(query=query)

    def remove_query_params(self, keys: str | typing.Sequence[str]) -> URL:
        """Return a new URL with every occurrence of the given query keys removed."""
        if isinstance(keys, str):
            keys = [keys]
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        for key in keys:
            params.pop(key, None)
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def __eq__(self, other: typing.Any) -> bool:
        # Compares string forms, so a URL also equals its plain-string text.
        return str(self) == str(other)

    def __str__(self) -> str:
        return self._url

    def __repr__(self) -> str:
        # Mask any password so it is not leaked into logs or tracebacks.
        url = str(self)
        if self.password:
            url = str(self.replace(password="********"))
        return f"{self.__class__.__name__}({repr(url)})"
168
169
class URLPath(str):
    """
    A path string that can also carry the protocol ("http", "websocket" or "")
    and host it should be served on.

    Returned by routing from `url_path_for`; use `make_absolute_url` to turn
    it into a full `URL` relative to some base URL.
    """

    def __new__(cls, path: str, protocol: str = "", host: str = "") -> URLPath:
        assert protocol in ("http", "websocket", "")
        return str.__new__(cls, path)

    def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
        self.protocol = protocol
        self.host = host

    def make_absolute_url(self, base_url: str | URL) -> URL:
        """
        Join this path onto `base_url`.

        The scheme follows the base URL's security (http/https or ws/wss)
        when a protocol is set, and is copied from the base otherwise. The
        host defaults to the base URL's netloc.
        """
        base = URL(base_url) if isinstance(base_url, str) else base_url

        if not self.protocol:
            scheme = base.scheme
        else:
            # (insecure, secure) scheme pair, indexed by the bool `is_secure`.
            scheme_pairs = {"http": ("http", "https"), "websocket": ("ws", "wss")}
            scheme = scheme_pairs[self.protocol][base.is_secure]

        return URL(
            scheme=scheme,
            netloc=self.host or base.netloc,
            path=base.path.rstrip("/") + str(self),
        )
198
199
class Secret:
    """
    Wrap a sensitive string so `repr()` (and therefore tracebacks) hides it.

    Call `str()` on the instance at the point the real value is needed.
    """

    def __init__(self, value: str):
        self._value = value

    def __repr__(self) -> str:
        # Fixed mask: the wrapped value never appears in the repr.
        return f"{type(self).__name__}('**********')"

    def __str__(self) -> str:
        return self._value

    def __bool__(self) -> bool:
        return bool(self._value)
218
219
class CommaSeparatedStrings(typing.Sequence[str]):
    """
    An immutable sequence of strings, parsed from a comma-separated string.

    A str input is split on commas with shell-style quoting rules (posix-mode
    `shlex`), so quoted items may contain commas. Each item is stripped of
    surrounding whitespace. Any other sequence is copied as-is into a list.
    """

    def __init__(self, value: str | typing.Sequence[str]):
        if not isinstance(value, str):
            self._items = list(value)
            return

        lexer = shlex(value, posix=True)
        # Commas, not spaces, separate items; spaces inside items are kept
        # and trimmed afterwards.
        lexer.whitespace = ","
        lexer.whitespace_split = True
        self._items = [token.strip() for token in lexer]

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, index: int | slice) -> typing.Any:
        return self._items[index]

    def __iter__(self) -> typing.Iterator[str]:
        return iter(self._items)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({list(self)!r})"

    def __str__(self) -> str:
        return ", ".join(map(repr, self))
246
247
class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]):
    """
    A read-only mapping that can hold several values per key.

    Every ``(key, value)`` pair is kept, in insertion order, in ``self._list``.
    ``self._dict`` maps each key to the value of its *last* pair and backs the
    plain ``Mapping`` interface (``[]``, ``in``, ``len``, iteration,
    ``keys``/``values``/``items``). Use ``getlist`` / ``multi_items`` to see
    every value.
    """

    _dict: dict[_KeyType, _CovariantValueType]

    def __init__(
        self,
        *args: ImmutableMultiDict[_KeyType, _CovariantValueType]
        | typing.Mapping[_KeyType, _CovariantValueType]
        | typing.Iterable[tuple[_KeyType, _CovariantValueType]],
        **kwargs: typing.Any,
    ) -> None:
        """
        Build from at most one positional source plus optional keyword pairs.

        The source may be another multidict (all of its pairs are copied), any
        object with ``items()`` (its items are used), or an iterable of
        ``(key, value)`` pairs. Keyword pairs are appended after the source's.
        """
        assert len(args) < 2, "Too many arguments."

        source: typing.Any = args[0] if args else []
        if kwargs:
            # Flatten both inputs to pair lists and concatenate them; keyword
            # pairs come last, so they win in ``_dict``.
            source = ImmutableMultiDict(source).multi_items() + ImmutableMultiDict(kwargs).multi_items()

        pairs: list[tuple[typing.Any, typing.Any]]
        if not source:
            pairs = []
        elif hasattr(source, "multi_items"):
            pairs = list(typing.cast(ImmutableMultiDict[_KeyType, _CovariantValueType], source).multi_items())
        elif hasattr(source, "items"):
            pairs = list(typing.cast(typing.Mapping[_KeyType, _CovariantValueType], source).items())
        else:
            pairs = list(typing.cast("list[tuple[typing.Any, typing.Any]]", source))

        # Later duplicates overwrite earlier ones, so ``_dict`` holds the last value per key.
        self._dict = {key: val for key, val in pairs}
        self._list = pairs

    def getlist(self, key: typing.Any) -> list[_CovariantValueType]:
        """Return every value stored under ``key``, in insertion order (empty if none)."""
        return [v for k, v in self._list if k == key]

    def keys(self) -> typing.KeysView[_KeyType]:
        return self._dict.keys()

    def values(self) -> typing.ValuesView[_CovariantValueType]:
        return self._dict.values()

    def items(self) -> typing.ItemsView[_KeyType, _CovariantValueType]:
        return self._dict.items()

    def multi_items(self) -> list[tuple[_KeyType, _CovariantValueType]]:
        """Return a fresh list of every ``(key, value)`` pair, duplicates included."""
        return list(self._list)

    def __getitem__(self, key: _KeyType) -> _CovariantValueType:
        return self._dict[key]

    def __contains__(self, key: typing.Any) -> bool:
        return key in self._dict

    def __iter__(self) -> typing.Iterator[_KeyType]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._dict)

    def __eq__(self, other: typing.Any) -> bool:
        # Order-insensitive comparison over all pairs; only instances of this
        # class (or a subclass of it) can compare equal.
        return isinstance(other, self.__class__) and sorted(self._list) == sorted(other._list)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.multi_items()!r})"
315
316
class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]):
    """
    A mutable multidict.

    Every mutator keeps ``self._list`` (all pairs) and ``self._dict`` (last
    value per key) in step with each other.
    """

    def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
        # Assignment replaces *all* existing values for ``key`` with one value.
        self.setlist(key, [value])

    def __delitem__(self, key: typing.Any) -> None:
        self._list = [(k, v) for k, v in self._list if k != key]
        del self._dict[key]

    def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
        """Remove every pair for ``key``; return its last value, or ``default``."""
        self._list = [(k, v) for k, v in self._list if k != key]
        return self._dict.pop(key, default)

    def popitem(self) -> tuple[typing.Any, typing.Any]:
        """Remove the key that ``dict.popitem`` picks, along with all its pairs."""
        key, value = self._dict.popitem()
        self._list = [(k, v) for k, v in self._list if k != key]
        return key, value

    def poplist(self, key: typing.Any) -> list[typing.Any]:
        """Remove ``key`` and return all of its values in insertion order."""
        removed = self.getlist(key)
        self.pop(key)
        return removed

    def clear(self) -> None:
        self._dict.clear()
        self._list.clear()

    def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
        """Insert ``(key, default)`` if ``key`` is absent; return the current value."""
        if key not in self:
            self._dict[key] = default
            self._list.append((key, default))
        return self[key]

    def setlist(self, key: typing.Any, values: list[typing.Any]) -> None:
        """Replace all values of ``key`` with ``values`` (an empty list removes it)."""
        if not values:
            self.pop(key, None)
            return
        kept = [(k, v) for (k, v) in self._list if k != key]
        self._list = kept + [(key, item) for item in values]
        self._dict[key] = values[-1]

    def append(self, key: typing.Any, value: typing.Any) -> None:
        """Add one more ``(key, value)`` pair, keeping any existing ones."""
        self._list.append((key, value))
        self._dict[key] = value

    def update(
        self,
        *args: MultiDict | typing.Mapping[typing.Any, typing.Any] | list[tuple[typing.Any, typing.Any]],
        **kwargs: typing.Any,
    ) -> None:
        """
        Merge in new pairs. Every key present in the input has all of its old
        pairs dropped and replaced by the input's pairs for that key.
        """
        incoming = MultiDict(*args, **kwargs)
        incoming_keys = incoming.keys()
        self._list = [(k, v) for (k, v) in self._list if k not in incoming_keys] + incoming.multi_items()
        self._dict.update(incoming)
371
372
class QueryParams(ImmutableMultiDict[str, str]):
    """
    An immutable multidict of query-string parameters.

    A ``str`` or ``bytes`` source is parsed as a query string (blank values
    kept; bytes decoded as latin-1). Any other source is handled by the base
    class. Afterwards every key and value is coerced to ``str``.
    """

    def __init__(
        self,
        *args: ImmutableMultiDict[typing.Any, typing.Any]
        | typing.Mapping[typing.Any, typing.Any]
        | list[tuple[typing.Any, typing.Any]]
        | str
        | bytes,
        **kwargs: typing.Any,
    ) -> None:
        assert len(args) < 2, "Too many arguments."

        value = args[0] if args else []

        if isinstance(value, (str, bytes)):
            text = value.decode("latin-1") if isinstance(value, bytes) else value
            super().__init__(parse_qsl(text, keep_blank_values=True), **kwargs)
        else:
            super().__init__(*args, **kwargs)  # type: ignore[arg-type]

        self._list = [(str(k), str(v)) for k, v in self._list]
        self._dict = {str(k): str(v) for k, v in self._dict.items()}

    def __str__(self) -> str:
        # Re-encode all pairs, duplicates included, as a query string.
        return urlencode(self._list)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({str(self)!r})"
407
408
class UploadFile:
    """
    An uploaded file included as part of the request data.

    The async I/O methods call the underlying file directly while it is still
    held in memory. Otherwise they run the blocking call through
    `run_in_threadpool`.
    """

    def __init__(
        self,
        file: typing.BinaryIO,
        *,
        size: int | None = None,
        filename: str | None = None,
        headers: Headers | None = None,
    ) -> None:
        self.filename = filename
        self.file = file
        self.size = size
        # A falsy (missing or empty) headers argument is replaced by a fresh Headers().
        self.headers = headers or Headers()

    @property
    def content_type(self) -> str | None:
        return self.headers.get("content-type", None)

    @property
    def _in_memory(self) -> bool:
        # `_rolled` is SpooledTemporaryFile's "moved to disk" flag. Files without
        # it are treated as on disk, so they always go through the threadpool.
        return not getattr(self.file, "_rolled", True)

    async def write(self, data: bytes) -> None:
        """Append `data` to the file, growing `size` when it is being tracked."""
        if self.size is not None:
            self.size += len(data)

        if self._in_memory:
            self.file.write(data)
            return
        await run_in_threadpool(self.file.write, data)

    async def read(self, size: int = -1) -> bytes:
        if not self._in_memory:
            return await run_in_threadpool(self.file.read, size)
        return self.file.read(size)

    async def seek(self, offset: int) -> None:
        if not self._in_memory:
            await run_in_threadpool(self.file.seek, offset)
            return
        self.file.seek(offset)

    async def close(self) -> None:
        if not self._in_memory:
            await run_in_threadpool(self.file.close)
            return
        self.file.close()

    def __repr__(self) -> str:
        name = self.__class__.__name__
        return f"{name}(filename={self.filename!r}, size={self.size!r}, headers={self.headers!r})"
470
471
class FormData(ImmutableMultiDict[str, typing.Union[UploadFile, str]]):
    """
    An immutable multidict holding submitted form fields.

    Values are either plain strings (text inputs) or `UploadFile` instances
    (file uploads).
    """

    def __init__(
        self,
        *args: FormData | typing.Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]],
        **kwargs: str | UploadFile,
    ) -> None:
        super().__init__(*args, **kwargs)

    async def close(self) -> None:
        """Close every uploaded file, one after another, in insertion order."""
        uploads = [value for _, value in self.multi_items() if isinstance(value, UploadFile)]
        for upload in uploads:
            await upload.close()
488
489
class Headers(typing.Mapping[str, str]):
    """
    An immutable, case-insensitive multidict.

    Headers are held in ``self._list`` as ``(name, value)`` byte pairs, with
    latin-1 encoding. Names given through the ``headers`` mapping are
    lower-cased. Lookups lower-case and encode the requested name, so they are
    case-insensitive. Duplicate names are kept: ``[]`` returns the first match
    and ``getlist`` returns all of them.
    """

    def __init__(
        self,
        headers: typing.Mapping[str, str] | None = None,
        raw: list[tuple[bytes, bytes]] | None = None,
        scope: typing.MutableMapping[str, typing.Any] | None = None,
    ) -> None:
        """
        Build from at most one of ``headers``, ``raw`` or ``scope`` (none gives empty headers).

        :param headers: str-to-str mapping; names are lower-cased, and names
            and values are latin-1 encoded.
        :param raw: list of already-encoded byte pairs, used as-is (not copied).
        :param scope: mapping with a ``"headers"`` entry; that entry is
            replaced by a list, which this instance then shares.
        """
        self._list: list[tuple[bytes, bytes]] = []
        if headers is not None:
            assert raw is None, 'Cannot set both "headers" and "raw".'
            assert scope is None, 'Cannot set both "headers" and "scope".'
            self._list = [(key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in headers.items()]
        elif raw is not None:
            assert scope is None, 'Cannot set both "raw" and "scope".'
            self._list = raw
        elif scope is not None:
            # scope["headers"] isn't necessarily a list
            # it might be a tuple or other iterable
            self._list = scope["headers"] = list(scope["headers"])

    @property
    def raw(self) -> list[tuple[bytes, bytes]]:
        """A copy of the underlying ``(name, value)`` byte pairs."""
        return list(self._list)

    def keys(self) -> list[str]:  # type: ignore[override]
        return [key.decode("latin-1") for key, value in self._list]

    def values(self) -> list[str]:  # type: ignore[override]
        return [value.decode("latin-1") for key, value in self._list]

    def items(self) -> list[tuple[str, str]]:  # type: ignore[override]
        return [(key.decode("latin-1"), value.decode("latin-1")) for key, value in self._list]

    def getlist(self, key: str) -> list[str]:
        """Return every value for header ``key`` (case-insensitive), in order."""
        get_header_key = key.lower().encode("latin-1")
        return [item_value.decode("latin-1") for item_key, item_value in self._list if item_key == get_header_key]

    def mutablecopy(self) -> MutableHeaders:
        """Return a ``MutableHeaders`` built from a shallow copy of the pairs."""
        return MutableHeaders(raw=self._list[:])

    def __getitem__(self, key: str) -> str:
        # Header names are always str. Raise KeyError (not AttributeError from
        # ``.lower()``) for anything else, so the Mapping contract holds and
        # ``Headers.get(non_str_key)`` returns its default.
        if not isinstance(key, str):
            raise KeyError(key)
        get_header_key = key.lower().encode("latin-1")
        for header_key, header_value in self._list:
            if header_key == get_header_key:
                return header_value.decode("latin-1")
        raise KeyError(key)

    def __contains__(self, key: typing.Any) -> bool:
        # A non-str key can never name a header; answer False instead of crashing.
        if not isinstance(key, str):
            return False
        get_header_key = key.lower().encode("latin-1")
        return any(header_key == get_header_key for header_key, _ in self._list)

    def __iter__(self) -> typing.Iterator[typing.Any]:
        return iter(self.keys())

    def __len__(self) -> int:
        # Counts every pair, duplicates included.
        return len(self._list)

    def __eq__(self, other: typing.Any) -> bool:
        if not isinstance(other, Headers):
            return False
        return sorted(self._list) == sorted(other._list)

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        as_dict = dict(self.items())
        # A dict repr is only faithful when there are no duplicate names.
        if len(as_dict) == len(self):
            return f"{class_name}({as_dict!r})"
        return f"{class_name}(raw={self.raw!r})"
565
566
class MutableHeaders(Headers):
    """
    A case-insensitive header multidict that can be modified in place.

    All mutations operate on ``self._list`` itself, and ``raw`` returns that
    same list rather than a copy.
    """

    def __setitem__(self, key: str, value: str) -> None:
        """
        Set the header `key` to `value`, removing any duplicate entries.
        The first existing occurrence keeps its position; a new key is appended.
        """
        set_key = key.lower().encode("latin-1")
        set_value = value.encode("latin-1")

        matches = [i for i, (item_key, _) in enumerate(self._list) if item_key == set_key]

        # Drop all but the first match, deleting from the back so the earlier
        # indexes stay valid.
        for i in reversed(matches[1:]):
            del self._list[i]

        if not matches:
            self._list.append((set_key, set_value))
        else:
            self._list[matches[0]] = (set_key, set_value)

    def __delitem__(self, key: str) -> None:
        """
        Remove every occurrence of the header `key`.
        """
        del_key = key.lower().encode("latin-1")

        doomed = [i for i, (item_key, _) in enumerate(self._list) if item_key == del_key]
        for i in reversed(doomed):
            del self._list[i]

    def __ior__(self, other: typing.Mapping[str, str]) -> MutableHeaders:
        if not isinstance(other, typing.Mapping):
            raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
        self.update(other)
        return self

    def __or__(self, other: typing.Mapping[str, str]) -> MutableHeaders:
        if not isinstance(other, typing.Mapping):
            raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
        merged = self.mutablecopy()
        merged.update(other)
        return merged

    @property
    def raw(self) -> list[tuple[bytes, bytes]]:
        return self._list

    def setdefault(self, key: str, value: str) -> str:
        """
        If the header `key` does not exist, then set it to `value`.
        Returns the header value (the first existing one, if present).
        """
        set_key = key.lower().encode("latin-1")
        set_value = value.encode("latin-1")

        for item_key, item_value in self._list:
            if item_key == set_key:
                return item_value.decode("latin-1")

        self._list.append((set_key, set_value))
        return value

    def update(self, other: typing.Mapping[str, str]) -> None:
        """Assign each of `other`'s items in turn, using `__setitem__` semantics."""
        for name, val in other.items():
            self[name] = val

    def append(self, key: str, value: str) -> None:
        """
        Append a header, preserving any duplicate entries.
        """
        self._list.append((key.lower().encode("latin-1"), value.encode("latin-1")))

    def add_vary_header(self, vary: str) -> None:
        """Add `vary` to the Vary header, joining with ", " if one exists."""
        existing = self.get("vary")
        if existing is None:
            self["vary"] = vary
        else:
            self["vary"] = ", ".join([existing, vary])
652
653
class State:
    """
    An object that can be used to store arbitrary state.

    Used for `request.state` and `app.state`. Attribute reads, writes and
    deletes go to the backing `_state` dict. A dict passed to `__init__` is
    used directly (not copied).
    """

    _state: dict[str, typing.Any]

    def __init__(self, state: dict[str, typing.Any] | None = None):
        # Go through object.__setattr__ so `_state` becomes a real attribute
        # rather than an entry inside itself.
        super().__setattr__("_state", {} if state is None else state)

    def __setattr__(self, key: typing.Any, value: typing.Any) -> None:
        self._state[key] = value

    def __getattr__(self, key: typing.Any) -> typing.Any:
        # Only called when normal attribute lookup fails.
        try:
            return self._state[key]
        except KeyError:
            message = "'{}' object has no attribute '{}'"
            raise AttributeError(message.format(type(self).__name__, key))

    def __delattr__(self, key: typing.Any) -> None:
        del self._state[key]