1from __future__ import annotations
2
3import json
4from collections.abc import AsyncGenerator, Iterator, Mapping
5from http import cookies as http_cookies
6from typing import TYPE_CHECKING, Any, NoReturn, cast
7
8import anyio
9
10from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper
11from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
12from starlette.exceptions import HTTPException
13from starlette.formparsers import FormParser, MultiPartException, MultiPartParser
14from starlette.types import Message, Receive, Scope, Send
15
16if TYPE_CHECKING:
17 from python_multipart.multipart import parse_options_header
18
19 from starlette.applications import Starlette
20 from starlette.routing import Router
21else:
22 try:
23 try:
24 from python_multipart.multipart import parse_options_header
25 except ModuleNotFoundError: # pragma: no cover
26 from multipart.multipart import parse_options_header
27 except ModuleNotFoundError: # pragma: no cover
28 parse_options_header = None
29
30
# Request headers that `Request.send_push_promise` copies onto an `http.response.push` message.
SERVER_PUSH_HEADERS_TO_COPY = {"accept", "accept-encoding", "accept-language", "cache-control", "user-agent"}
38
39
def cookie_parser(cookie_string: str) -> dict[str, str]:
    """
    Parse the value of a ``Cookie`` request header into a ``{name: value}`` dict.

    Parsing is deliberately lenient, following what browsers actually do rather
    than enforcing RFC 6265 strictly (the algorithm is adapted from Django 3.1.0):

    * pairs are separated by ``;`` and split on the first ``=`` only;
    * a chunk without any ``=`` is stored under the empty name ``""``;
    * whitespace around names and values is stripped;
    * chunks whose name and value are both empty after stripping are skipped;
    * when a name repeats, the later value wins;
    * values are unquoted with the standard library's cookie unquoting routine.

    ``SimpleCookie.load`` is intentionally not used: it follows an outdated spec
    and rejects a lot of real-world input that should be accepted.
    """
    parsed: dict[str, str] = {}
    for piece in cookie_string.split(";"):
        name, separator, value = piece.partition("=")
        if not separator:
            # No "=" at all: treat the whole chunk as a value with an empty name, per
            # https://bugzilla.mozilla.org/show_bug.cgi?id=169091
            name, value = "", piece
        name, value = name.strip(), value.strip()
        if not name and not value:
            continue
        parsed[name] = http_cookies._unquote(value)
    return parsed
65
66
class ClientDisconnect(Exception):
    """Raised by ``Request.stream`` when an ``http.disconnect`` message arrives while reading the body."""
69
70
class HTTPConnection(Mapping[str, Any]):
    """
    Shared base for incoming connections, i.e. both `Request` and `WebSocket`.

    Wraps an ASGI ``scope`` and exposes it read-only through the ``Mapping``
    interface, plus convenience accessors (URL, headers, query params, cookies,
    client address, middleware-provided values, ...). Most accessors are computed
    on first use and cached on the instance.
    """

    def __init__(self, scope: Scope, receive: Receive | None = None) -> None:
        # Only HTTP and WebSocket scopes are accepted. `receive` is part of the
        # signature but is not stored by this base class.
        assert scope["type"] in ("http", "websocket")
        self.scope = scope

    # -- Mapping protocol: everything is delegated to the underlying scope. --

    def __getitem__(self, key: str) -> Any:
        return self.scope[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self.scope)

    def __len__(self) -> int:
        return len(self.scope)

    # Opt out of `abc.Mapping.__eq__` (which compares contents): two connection
    # objects are only equal when they are the very same object.
    __eq__ = object.__eq__
    __hash__ = object.__hash__

    @property
    def app(self) -> Any:
        """The ``app`` entry of the scope."""
        return self.scope["app"]

    @property
    def url(self) -> URL:
        """Full request URL, built from the scope once and then cached."""
        try:
            return self._url
        except AttributeError:
            self._url = URL(scope=self.scope)
            return self._url

    @property
    def base_url(self) -> URL:
        """
        URL of the application root, with an empty query string (cached).

        ``app_root_path`` is preferred over ``root_path``: `url_for` relies on this
        value, and inside a Mount the child scope has its own ``root_path`` while the
        base URL should still point at the top-level application root.
        """
        try:
            return self._base_url
        except AttributeError:
            pass
        scope_copy = dict(self.scope)
        root = scope_copy.get("app_root_path", scope_copy.get("root_path", ""))
        scope_copy.update(
            path=root if root.endswith("/") else root + "/",
            query_string=b"",
            root_path=root,
        )
        self._base_url = URL(scope=scope_copy)
        return self._base_url

    @property
    def headers(self) -> Headers:
        """Request headers, built from the scope once and then cached."""
        try:
            return self._headers
        except AttributeError:
            self._headers = Headers(scope=self.scope)
            return self._headers

    @property
    def query_params(self) -> QueryParams:
        """Query parameters parsed from ``scope["query_string"]`` (cached)."""
        try:
            return self._query_params
        except AttributeError:
            self._query_params = QueryParams(self.scope["query_string"])
            return self._query_params

    @property
    def path_params(self) -> dict[str, Any]:
        """``scope["path_params"]``, or an empty dict when the scope has none."""
        return self.scope.get("path_params", {})

    @property
    def cookies(self) -> dict[str, str]:
        """Cookies parsed from the ``cookie`` header; empty if the header is missing or blank (cached)."""
        try:
            return self._cookies
        except AttributeError:
            pass
        raw_cookie = self.headers.get("cookie")
        self._cookies = cookie_parser(raw_cookie) if raw_cookie else {}
        return self._cookies

    @property
    def client(self) -> Address | None:
        """Client ``(host, port)`` from ``scope["client"]`` as an `Address`, or ``None`` if absent."""
        host_port = self.scope.get("client")
        return None if host_port is None else Address(*host_port)

    @property
    def session(self) -> dict[str, Any]:
        """``scope["session"]``; an assertion fails if no session is present in the scope."""
        scope = self.scope
        assert "session" in scope, "SessionMiddleware must be installed to access request.session"
        return scope["session"]  # type: ignore[no-any-return]

    @property
    def auth(self) -> Any:
        """``scope["auth"]``; an assertion fails if it is missing from the scope."""
        scope = self.scope
        assert "auth" in scope, "AuthenticationMiddleware must be installed to access request.auth"
        return scope["auth"]

    @property
    def user(self) -> Any:
        """``scope["user"]``; an assertion fails if it is missing from the scope."""
        scope = self.scope
        assert "user" in scope, "AuthenticationMiddleware must be installed to access request.user"
        return scope["user"]

    @property
    def state(self) -> State:
        """
        A `State` wrapper around ``scope["state"]`` (cached).

        The dict is created in the scope when missing, and the wrapper is handed
        that same dict object so that stored values live in the scope.
        """
        try:
            return self._state
        except AttributeError:
            self._state = State(self.scope.setdefault("state", {}))
            return self._state

    def url_for(self, name: str, /, **path_params: Any) -> URL:
        """
        Build an absolute URL for the named route.

        The route is resolved through ``scope["router"]``, falling back to
        ``scope["app"]``. The result is made absolute against `base_url`.

        Raises:
            RuntimeError: if the scope provides neither a router nor an app.
        """
        provider: Router | Starlette | None = self.scope.get("router") or self.scope.get("app")
        if provider is None:
            raise RuntimeError("The `url_for` method can only be used inside a Starlette application or with a router.")
        return provider.url_path_for(name, **path_params).make_absolute_url(base_url=self.base_url)
189
190
async def empty_receive() -> NoReturn:
    """Stand-in ``receive`` callable (the default for `Request`) that always raises `RuntimeError`."""
    reason = "Receive channel has not been made available"
    raise RuntimeError(reason)
193
194
async def empty_send(message: Message) -> NoReturn:
    """Stand-in ``send`` callable (the default for `Request`) that ignores ``message`` and always raises `RuntimeError`."""
    reason = "Send channel has not been made available"
    raise RuntimeError(reason)
197
198
class Request(HTTPConnection):
    """
    An incoming HTTP request.

    Adds to `HTTPConnection` access to the request body (streamed, raw bytes,
    JSON, or parsed form data), disconnect detection and server push.
    """

    # Parsed form data; stays None until `_get_form` has run.
    _form: FormData | None

    def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send):
        super().__init__(scope)
        assert scope["type"] == "http"
        self._receive = receive
        self._send = send
        # Bookkeeping for body consumption and client disconnects.
        self._stream_consumed = False
        self._is_disconnected = False
        self._form = None

    @property
    def method(self) -> str:
        """The HTTP method, taken from ``scope["method"]``."""
        return cast(str, self.scope["method"])

    @property
    def receive(self) -> Receive:
        """The ``receive`` callable this request was constructed with."""
        return self._receive

    async def stream(self) -> AsyncGenerator[bytes, None]:
        """
        Yield the request body chunk by chunk, always finishing with ``b""``.

        If `body` has already buffered the payload, that buffer is replayed instead
        of reading from ``receive`` again.

        Raises:
            RuntimeError: if the body stream was already consumed and not buffered.
            ClientDisconnect: if an ``http.disconnect`` message is received.
        """
        if hasattr(self, "_body"):
            yield self._body
            yield b""
            return
        if self._stream_consumed:
            raise RuntimeError("Stream consumed")
        while not self._stream_consumed:
            message = await self._receive()
            message_type = message["type"]
            if message_type == "http.disconnect":
                self._is_disconnected = True
                raise ClientDisconnect()
            if message_type == "http.request":
                chunk = message.get("body", b"")
                # The final chunk is the one without `more_body` (or with it falsy).
                if not message.get("more_body", False):
                    self._stream_consumed = True
                # Empty chunks are skipped; only the trailing b"" below is empty.
                if chunk:
                    yield chunk
        yield b""

    async def body(self) -> bytes:
        """Return the whole request body, reading it on first use and caching it."""
        if not hasattr(self, "_body"):
            self._body = b"".join([chunk async for chunk in self.stream()])
        return self._body

    async def json(self) -> Any:
        """Decode the request body with `json.loads` (cached after the first call)."""
        if not hasattr(self, "_json"):  # pragma: no branch
            self._json = json.loads(await self.body())
        return self._json

    async def _get_form(
        self,
        *,
        max_files: int | float = 1000,
        max_fields: int | float = 1000,
        max_part_size: int = 1024 * 1024,
    ) -> FormData:
        """
        Parse the body as form data, dispatching on the ``Content-Type`` header.

        ``multipart/form-data`` goes to `MultiPartParser` (which receives the
        limits), ``application/x-www-form-urlencoded`` goes to `FormParser`, and
        any other content type yields an empty `FormData`. The first result is
        cached in ``self._form``; later calls return it regardless of the limits.

        Raises:
            HTTPException: (status 400) on a multipart parse error when the scope
                has an ``app`` entry.
            MultiPartException: on a multipart parse error otherwise.
        """
        if self._form is not None:  # pragma: no branch
            return self._form

        assert parse_options_header is not None, (
            "The `python-multipart` library must be installed to use form parsing."
        )
        content_type: bytes
        content_type, _ = parse_options_header(self.headers.get("Content-Type"))

        if content_type == b"multipart/form-data":
            try:
                self._form = await MultiPartParser(
                    self.headers,
                    self.stream(),
                    max_files=max_files,
                    max_fields=max_fields,
                    max_part_size=max_part_size,
                ).parse()
            except MultiPartException as exc:
                # Inside an application, turn the parse error into a 400 response;
                # otherwise re-raise it unchanged.
                if "app" not in self.scope:
                    raise exc
                raise HTTPException(status_code=400, detail=exc.message)
        elif content_type == b"application/x-www-form-urlencoded":
            self._form = await FormParser(self.headers, self.stream()).parse()
        else:
            self._form = FormData()
        return self._form

    def form(
        self,
        *,
        max_files: int | float = 1000,
        max_fields: int | float = 1000,
        max_part_size: int = 1024 * 1024,
    ) -> AwaitableOrContextManager[FormData]:
        """
        Return the parsed form wrapped in `AwaitableOrContextManagerWrapper`.

        Per the return type, the result can be awaited or used as a context
        manager. The limits only apply to multipart bodies.
        """
        pending = self._get_form(
            max_files=max_files,
            max_fields=max_fields,
            max_part_size=max_part_size,
        )
        return AwaitableOrContextManagerWrapper(pending)

    async def close(self) -> None:
        """Call ``close()`` on the cached form data, if any form was parsed."""
        form = self._form
        if form is not None:  # pragma: no branch
            await form.close()

    async def is_disconnected(self) -> bool:
        """
        Report whether the client has disconnected, without waiting for a message.

        NOTE(review): any non-disconnect message received here is discarded.
        """
        if self._is_disconnected:
            return True

        message: Message = {}
        # The scope is cancelled before the await, so if `receive` has no message
        # immediately available the call is abandoned and `message` stays empty.
        with anyio.CancelScope() as poll:
            poll.cancel()
            message = await self._receive()

        if message.get("type") == "http.disconnect":
            self._is_disconnected = True
        return self._is_disconnected

    async def send_push_promise(self, path: str) -> None:
        """
        Send an ``http.response.push`` message for ``path``.

        The headers listed in `SERVER_PUSH_HEADERS_TO_COPY` are copied from this
        request. Does nothing unless the scope's ``extensions`` contain
        ``http.response.push``.
        """
        if "http.response.push" not in self.scope.get("extensions", {}):
            return
        raw_headers: list[tuple[bytes, bytes]] = [
            (name.encode("latin-1"), value.encode("latin-1"))
            for name in SERVER_PUSH_HEADERS_TO_COPY
            for value in self.headers.getlist(name)
        ]
        await self._send({"type": "http.response.push", "path": path, "headers": raw_headers})