1from __future__ import annotations
2
3import json
4import typing
5from http import cookies as http_cookies
6
7import anyio
8
9from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper
10from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
11from starlette.exceptions import HTTPException
12from starlette.formparsers import FormParser, MultiPartException, MultiPartParser
13from starlette.types import Message, Receive, Scope, Send
14
15try:
16 from multipart.multipart import parse_options_header
17except ModuleNotFoundError: # pragma: nocover
18 parse_options_header = None
19
20
21if typing.TYPE_CHECKING:
22 from starlette.routing import Router
23
24
# Names of incoming request headers whose values are copied onto the
# ``http.response.push`` message built by ``Request.send_push_promise``.
# This is a set, so iterating it directly gives no guaranteed order.
SERVER_PUSH_HEADERS_TO_COPY = {
    "accept",
    "accept-encoding",
    "accept-language",
    "cache-control",
    "user-agent",
}
32
33
def cookie_parser(cookie_string: str) -> dict[str, str]:
    """
    Turn the value of a ``Cookie`` request header into a ``{name: value}`` dict.

    The parsing is deliberately lenient so that it matches what browsers
    actually send. Real-world cookies often ignore RFC 6265, so strict
    conformance is not the goal.

    * Pairs are separated by ``;``.
    * Each pair is split on its first ``=``.
    * A chunk with no ``=`` is stored under the empty name ``""``.
    * Names and values are stripped of surrounding whitespace.
    * Chunks where both the name and the value are empty are dropped.
    * Values are unquoted using the standard library's cookie algorithm.
    * Later duplicates overwrite earlier ones.

    Adapted from Django 3.1.0. ``SimpleCookie.load`` is intentionally NOT used:
    it follows an outdated spec and rejects much of the input we want to accept.
    """
    parsed: dict[str, str] = {}
    for piece in cookie_string.split(";"):
        name, separator, value = piece.partition("=")
        if not separator:
            # No "=" present: the whole chunk is a value with an empty name,
            # mirroring https://bugzilla.mozilla.org/show_bug.cgi?id=169091
            name, value = "", name
        name = name.strip()
        value = value.strip()
        if not (name or value):
            continue
        # Unquote with the same routine Python's http.cookies uses.
        parsed[name] = http_cookies._unquote(value)
    return parsed
59
60
class ClientDisconnect(Exception):
    """Raised by ``Request.stream()`` when an ``http.disconnect`` message arrives
    while the request body is being read."""
63
64
class HTTPConnection(typing.Mapping[str, typing.Any]):
    """
    Common base for incoming connections, shared by `Request` and `WebSocket`.

    The ASGI ``scope`` is exposed both as a read-only mapping (``conn["type"]``)
    and through convenience properties. Derived objects are built on first
    access and then memoized on the instance. These are the URL, base URL,
    headers, query params, cookies and state.
    """

    def __init__(self, scope: Scope, receive: Receive | None = None) -> None:
        # ``receive`` is accepted for signature compatibility; this base class
        # does not use it.
        assert scope["type"] in ("http", "websocket")
        self.scope = scope

    # -- Mapping protocol: every lookup is delegated to the ASGI scope. --

    def __getitem__(self, key: str) -> typing.Any:
        return self.scope[key]

    def __iter__(self) -> typing.Iterator[str]:
        return iter(self.scope)

    def __len__(self) -> int:
        return len(self.scope)

    # ``abc.Mapping`` supplies value-based equality; connections must instead
    # compare (and hash) by identity, so restore object's implementations.
    __eq__ = object.__eq__
    __hash__ = object.__hash__

    @property
    def app(self) -> typing.Any:
        return self.scope["app"]

    @property
    def url(self) -> URL:
        if hasattr(self, "_url"):
            return self._url
        self._url = URL(scope=self.scope)
        return self._url

    @property
    def base_url(self) -> URL:
        if hasattr(self, "_base_url"):
            return self._base_url
        scope_copy = dict(self.scope)
        # request.url_for relies on this. Inside a Mount the child scope has
        # its own root_path, but url_for must still be rooted at the top-level
        # app, so "app_root_path" wins over "root_path" when present.
        root = scope_copy.get("app_root_path", scope_copy.get("root_path", ""))
        scope_copy.update(
            path=root if root.endswith("/") else root + "/",
            query_string=b"",
            root_path=root,
        )
        self._base_url = URL(scope=scope_copy)
        return self._base_url

    @property
    def headers(self) -> Headers:
        if hasattr(self, "_headers"):
            return self._headers
        self._headers = Headers(scope=self.scope)
        return self._headers

    @property
    def query_params(self) -> QueryParams:
        if hasattr(self, "_query_params"):
            return self._query_params
        self._query_params = QueryParams(self.scope["query_string"])
        return self._query_params

    @property
    def path_params(self) -> dict[str, typing.Any]:
        return self.scope.get("path_params", {})

    @property
    def cookies(self) -> dict[str, str]:
        if hasattr(self, "_cookies"):
            return self._cookies
        raw_cookie = self.headers.get("cookie")
        # A missing or empty Cookie header produces an empty dict.
        self._cookies = cookie_parser(raw_cookie) if raw_cookie else {}
        return self._cookies

    @property
    def client(self) -> Address | None:
        # scope["client"] is a (host, port) pair, None, or absent entirely.
        peer = self.scope.get("client")
        return None if peer is None else Address(*peer)

    @property
    def session(self) -> dict[str, typing.Any]:
        assert "session" in self.scope, (
            "SessionMiddleware must be installed to access request.session"
        )
        return self.scope["session"]  # type: ignore[no-any-return]

    @property
    def auth(self) -> typing.Any:
        assert "auth" in self.scope, (
            "AuthenticationMiddleware must be installed to access request.auth"
        )
        return self.scope["auth"]

    @property
    def user(self) -> typing.Any:
        assert "user" in self.scope, (
            "AuthenticationMiddleware must be installed to access request.user"
        )
        return self.scope["user"]

    @property
    def state(self) -> State:
        if hasattr(self, "_state"):
            return self._state
        # Guarantee a backing dict in the scope, then wrap it so that values
        # stored on the State object are kept in scope["state"].
        self.scope.setdefault("state", {})
        self._state = State(self.scope["state"])
        return self._state

    def url_for(self, name: str, /, **path_params: typing.Any) -> URL:
        """Build an absolute URL for the named route, relative to ``base_url``."""
        router: Router = self.scope["router"]
        return router.url_path_for(name, **path_params).make_absolute_url(
            base_url=self.base_url
        )
189
190
async def empty_receive() -> typing.NoReturn:
    """Default ``receive`` for a `Request` built without one; always fails."""
    reason = "Receive channel has not been made available"
    raise RuntimeError(reason)
193
194
async def empty_send(message: Message) -> typing.NoReturn:
    """Default ``send`` for a `Request` built without one.

    ``message`` is ignored; every call raises ``RuntimeError``.
    """
    reason = "Send channel has not been made available"
    raise RuntimeError(reason)
197
198
class Request(HTTPConnection):
    """
    An incoming HTTP request.

    Wraps an ASGI ``http`` scope together with its ``receive`` and ``send``
    callables. It provides helpers to:

    * read the body raw, streamed, as JSON or as form data;
    * poll for a client disconnect;
    * emit ``http.response.push`` messages when the server supports them.
    """

    _form: FormData | None

    def __init__(
        self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send
    ):
        super().__init__(scope)
        assert scope["type"] == "http"
        self._receive = receive
        self._send = send
        # Set once an ``http.request`` message without ``more_body`` arrives.
        self._stream_consumed = False
        # Latches to True once an ``http.disconnect`` message has been seen.
        self._is_disconnected = False
        # Parsed form data; filled lazily by ``_get_form`` / ``form()``.
        self._form = None

    @property
    def method(self) -> str:
        return typing.cast(str, self.scope["method"])

    @property
    def receive(self) -> Receive:
        return self._receive

    async def stream(self) -> typing.AsyncGenerator[bytes, None]:
        """
        Yield the request body in chunks as it is received.

        If ``body()`` has already cached the body, that cached value is
        replayed instead of calling ``receive``. Both paths finish by yielding
        a trailing ``b""``.

        Raises:
            RuntimeError: the stream was already consumed and no body is cached.
            ClientDisconnect: an ``http.disconnect`` message was received.
        """
        if hasattr(self, "_body"):
            yield self._body
            yield b""
            return
        if self._stream_consumed:
            raise RuntimeError("Stream consumed")
        while not self._stream_consumed:
            message = await self._receive()
            if message["type"] == "http.request":
                body = message.get("body", b"")
                if not message.get("more_body", False):
                    self._stream_consumed = True
                # Empty chunks from the server are not forwarded.
                if body:
                    yield body
            elif message["type"] == "http.disconnect":
                self._is_disconnected = True
                raise ClientDisconnect()
            # Any other message type is ignored and receive() is called again.
        yield b""

    async def body(self) -> bytes:
        """Return the full request body, reading and caching it on first call.

        Later calls to ``stream()`` replay this cached body.
        """
        if not hasattr(self, "_body"):
            chunks: list[bytes] = []
            async for chunk in self.stream():
                chunks.append(chunk)
            self._body = b"".join(chunks)
        return self._body

    async def json(self) -> typing.Any:
        """Decode the body with ``json.loads`` and cache the result.

        Invalid JSON propagates the error raised by ``json.loads``.
        """
        if not hasattr(self, "_json"):
            body = await self.body()
            self._json = json.loads(body)
        return self._json

    async def _get_form(
        self, *, max_files: int | float = 1000, max_fields: int | float = 1000
    ) -> FormData:
        """
        Parse the body as form data and cache the result.

        The parser is chosen by the ``Content-Type`` header:

        * ``multipart/form-data`` uses ``MultiPartParser``, honouring
          ``max_files`` and ``max_fields``;
        * ``application/x-www-form-urlencoded`` uses ``FormParser``;
        * anything else yields an empty ``FormData``.

        Only the first call parses, so the limits passed on later calls have
        no effect.

        Raises:
            HTTPException: status 400 on a multipart parse error when the scope
                contains ``"app"``.
            MultiPartException: on a multipart parse error otherwise.
        """
        if self._form is None:
            assert (
                parse_options_header is not None
            ), "The `python-multipart` library must be installed to use form parsing."
            content_type_header = self.headers.get("Content-Type")
            content_type: bytes
            content_type, _ = parse_options_header(content_type_header)
            if content_type == b"multipart/form-data":
                try:
                    multipart_parser = MultiPartParser(
                        self.headers,
                        self.stream(),
                        max_files=max_files,
                        max_fields=max_fields,
                    )
                    self._form = await multipart_parser.parse()
                except MultiPartException as exc:
                    # Running inside an app: report a client error instead of
                    # letting the parser exception escape as a 500.
                    if "app" in self.scope:
                        raise HTTPException(
                            status_code=400, detail=exc.message
                        ) from exc
                    raise
            elif content_type == b"application/x-www-form-urlencoded":
                form_parser = FormParser(self.headers, self.stream())
                self._form = await form_parser.parse()
            else:
                self._form = FormData()
        return self._form

    def form(
        self, *, max_files: int | float = 1000, max_fields: int | float = 1000
    ) -> AwaitableOrContextManager[FormData]:
        """Return the parsed form data.

        The form is parsed by ``_get_form`` and returned wrapped in
        ``AwaitableOrContextManagerWrapper`` (see ``starlette._utils`` for how
        the wrapper can be used).
        """
        return AwaitableOrContextManagerWrapper(
            self._get_form(max_files=max_files, max_fields=max_fields)
        )

    async def close(self) -> None:
        """Close the parsed form data, if a form has been parsed."""
        if self._form is not None:
            await self._form.close()

    async def is_disconnected(self) -> bool:
        """Report whether the client has disconnected, without blocking.

        Once a disconnect has been observed the answer stays True.
        """
        if not self._is_disconnected:
            message: Message = {}

            # The scope is cancelled before the await, so receive() is only
            # given a chance to return a message that is immediately
            # available; otherwise we move on with ``message`` still empty.
            with anyio.CancelScope() as cs:
                cs.cancel()
                message = await self._receive()

            if message.get("type") == "http.disconnect":
                self._is_disconnected = True

        return self._is_disconnected

    async def send_push_promise(self, path: str) -> None:
        """Send an ``http.response.push`` message for ``path``.

        Does nothing unless ``"http.response.push"`` appears in the scope's
        ``extensions``. Each header in ``SERVER_PUSH_HEADERS_TO_COPY`` that is
        present on this request is copied (latin-1 encoded) onto the push.
        """
        if "http.response.push" not in self.scope.get("extensions", {}):
            return
        raw_headers: list[tuple[bytes, bytes]] = [
            (name.encode("latin-1"), value.encode("latin-1"))
            # Sorted: iterating the set directly yields an order that can
            # differ between processes because of string hash randomization.
            for name in sorted(SERVER_PUSH_HEADERS_TO_COPY)
            for value in self.headers.getlist(name)
        ]
        await self._send(
            {"type": "http.response.push", "path": path, "headers": raw_headers}
        )