1from __future__ import annotations
2
3import json
4import typing
5from http import cookies as http_cookies
6
7import anyio
8
9from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper
10from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
11from starlette.exceptions import HTTPException
12from starlette.formparsers import FormParser, MultiPartException, MultiPartParser
13from starlette.types import Message, Receive, Scope, Send
14
15try:
16 from multipart.multipart import parse_options_header
17except ModuleNotFoundError: # pragma: nocover
18 parse_options_header = None
19
20
21if typing.TYPE_CHECKING:
22 from starlette.routing import Router
23
24
# Names of the request headers that `Request.send_push_promise` copies onto
# an outgoing ``http.response.push`` message.
SERVER_PUSH_HEADERS_TO_COPY = set(
    "accept accept-encoding accept-language cache-control user-agent".split()
)
32
33
def cookie_parser(cookie_string: str) -> dict[str, str]:
    """
    Parse the value of a ``Cookie`` HTTP header into a ``{name: value}`` dict.

    Parsing is deliberately lenient and tries to mimic browser behavior:
    browsers and web servers frequently disregard the spec (RFC 6265) when
    setting and reading cookies, so the common real-world shapes are accepted.

    Adapted from Django 3.1.0. Note: ``SimpleCookie.load`` is intentionally
    _NOT_ used, because it is based on an outdated spec and rejects a lot of
    input we want to support.

    Each ``;``-separated fragment is handled as follows:

    * it is split at its first ``=`` into name and value; a fragment with no
      ``=`` is treated as a value with an empty name
      (see https://bugzilla.mozilla.org/show_bug.cgi?id=169091);
    * surrounding whitespace is stripped from both parts;
    * fragments whose name and value are both empty are skipped;
    * the value is unquoted using Python's ``http.cookies`` algorithm;
    * a later fragment with the same name overwrites an earlier one.
    """
    parsed: dict[str, str] = {}
    for fragment in cookie_string.split(";"):
        name, separator, value = fragment.partition("=")
        if not separator:
            # No "=" at all: the whole fragment is the value, the name is empty.
            name, value = "", fragment
        name = name.strip()
        value = value.strip()
        if not (name or value):
            continue
        parsed[name] = http_cookies._unquote(value)
    return parsed
59
60
class ClientDisconnect(Exception):
    """Raised while reading a request body once an ``http.disconnect`` message is received."""
63
64
class HTTPConnection(typing.Mapping[str, typing.Any]):
    """
    Shared base class for incoming connections, providing the functionality
    common to both `Request` and `WebSocket`.

    Instances behave as a read-only mapping over the ASGI ``scope`` and expose
    lazily computed, per-instance cached views of it (URL, headers, query
    params, cookies, state).
    """

    def __init__(self, scope: Scope, receive: Receive | None = None) -> None:
        # Only HTTP and WebSocket scopes are accepted. `receive` is accepted
        # but not used by this base class.
        assert scope["type"] in ("http", "websocket")
        self.scope = scope

    # --- Mapping protocol: delegate straight to the scope. ---

    def __getitem__(self, key: str) -> typing.Any:
        return self.scope[key]

    def __iter__(self) -> typing.Iterator[str]:
        return iter(self.scope)

    def __len__(self) -> int:
        return len(self.scope)

    # Identity semantics: `abc.Mapping.__eq__` would compare scopes by content,
    # but two connection objects must only be equal when `self is other`.
    __eq__ = object.__eq__
    __hash__ = object.__hash__

    @property
    def app(self) -> typing.Any:
        """The application object stored under ``scope["app"]``."""
        return self.scope["app"]

    @property
    def url(self) -> URL:
        """The full request URL, built from the scope once and then cached."""
        if hasattr(self, "_url"):
            return self._url
        self._url = URL(scope=self.scope)
        return self._url

    @property
    def base_url(self) -> URL:
        """
        The root URL of the top-level application (no query string, trailing
        slash guaranteed), cached after the first access.
        """
        if hasattr(self, "_base_url"):
            return self._base_url
        # `url_for` relies on this. Inside a Mount the child scope has its own
        # `root_path`, but URLs must still be rooted at the top-level app, so
        # `app_root_path` wins over `root_path` when present.
        scope = dict(self.scope)
        root = scope.get("app_root_path", scope.get("root_path", ""))
        scope["path"] = root if root.endswith("/") else f"{root}/"
        scope["query_string"] = b""
        scope["root_path"] = root
        self._base_url = URL(scope=scope)
        return self._base_url

    @property
    def headers(self) -> Headers:
        """The request headers, cached after the first access."""
        if hasattr(self, "_headers"):
            return self._headers
        self._headers = Headers(scope=self.scope)
        return self._headers

    @property
    def query_params(self) -> QueryParams:
        """Query parameters parsed from ``scope["query_string"]``, cached."""
        if hasattr(self, "_query_params"):
            return self._query_params
        self._query_params = QueryParams(self.scope["query_string"])
        return self._query_params

    @property
    def path_params(self) -> dict[str, typing.Any]:
        """``scope["path_params"]``, or a fresh empty dict when absent."""
        return self.scope.get("path_params", {})

    @property
    def cookies(self) -> dict[str, str]:
        """Cookies parsed from the ``cookie`` header (empty when missing), cached."""
        if not hasattr(self, "_cookies"):
            raw_cookie = self.headers.get("cookie")
            self._cookies = cookie_parser(raw_cookie) if raw_cookie else {}
        return self._cookies

    @property
    def client(self) -> Address | None:
        """The peer address, or ``None`` when the scope has no ``client`` entry."""
        # scope["client"] is a 2-item (host, port) tuple when present.
        peer = self.scope.get("client")
        return None if peer is None else Address(*peer)

    @property
    def session(self) -> dict[str, typing.Any]:
        assert "session" in self.scope, "SessionMiddleware must be installed to access request.session"
        return self.scope["session"]  # type: ignore[no-any-return]

    @property
    def auth(self) -> typing.Any:
        assert "auth" in self.scope, "AuthenticationMiddleware must be installed to access request.auth"
        return self.scope["auth"]

    @property
    def user(self) -> typing.Any:
        assert "user" in self.scope, "AuthenticationMiddleware must be installed to access request.user"
        return self.scope["user"]

    @property
    def state(self) -> State:
        """
        A `State` wrapper around ``scope["state"]``; the dict is created in the
        scope on first access if missing, and `State` holds a reference to it.
        """
        if not hasattr(self, "_state"):
            self._state = State(self.scope.setdefault("state", {}))
        return self._state

    def url_for(self, name: str, /, **path_params: typing.Any) -> URL:
        """Reverse route ``name`` via ``scope["router"]`` into an absolute URL."""
        router: Router = self.scope["router"]
        return router.url_path_for(name, **path_params).make_absolute_url(base_url=self.base_url)
181
182
async def empty_receive() -> typing.NoReturn:
    """
    Default ``receive`` callable for `Request` when none is supplied.

    Raises:
        RuntimeError: always, since there is no channel to receive from.
    """
    raise RuntimeError("Receive channel has not been made available")
185
186
async def empty_send(message: Message) -> typing.NoReturn:
    """
    Default ``send`` callable for `Request` when none is supplied.

    ``message`` is ignored.

    Raises:
        RuntimeError: always, since there is no channel to send on.
    """
    raise RuntimeError("Send channel has not been made available")
189
190
class Request(HTTPConnection):
    """
    An incoming HTTP request.

    Wraps an ASGI ``http`` scope together with its ``receive``/``send``
    callables and adds helpers to read the body (raw, streamed, as JSON or as
    form data), to check for client disconnects and to send push promises.
    """

    # Parsed form data; ``None`` until ``_get_form`` has completed successfully.
    _form: FormData | None

    def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send):
        super().__init__(scope)
        assert scope["type"] == "http"
        self._receive = receive
        self._send = send
        # Set once an ``http.request`` message without ``more_body`` arrives.
        self._stream_consumed = False
        # Set once an ``http.disconnect`` message has been observed; never reset.
        self._is_disconnected = False
        self._form = None

    @property
    def method(self) -> str:
        """The HTTP method, taken from ``scope["method"]``."""
        return typing.cast(str, self.scope["method"])

    @property
    def receive(self) -> Receive:
        """The ASGI ``receive`` callable this request reads from."""
        return self._receive

    async def stream(self) -> typing.AsyncGenerator[bytes, None]:
        """
        Yield the request body chunk by chunk.

        If ``body()`` has already cached the full body, that cached value is
        replayed instead of reading from ``receive``. Otherwise ``http.request``
        messages are consumed until one arrives without ``more_body``; empty
        chunks are not yielded and messages of other types are ignored. Both
        paths finish by yielding ``b""``.

        Raises:
            RuntimeError: if the stream was already consumed and no cached body
                is available.
            ClientDisconnect: if an ``http.disconnect`` message is received.
        """
        if hasattr(self, "_body"):
            yield self._body
            yield b""
            return
        if self._stream_consumed:
            raise RuntimeError("Stream consumed")
        while not self._stream_consumed:
            message = await self._receive()
            if message["type"] == "http.request":
                body = message.get("body", b"")
                if not message.get("more_body", False):
                    self._stream_consumed = True
                if body:
                    yield body
            elif message["type"] == "http.disconnect":
                self._is_disconnected = True
                raise ClientDisconnect()
        yield b""

    async def body(self) -> bytes:
        """Return the complete request body, reading the stream once and caching it."""
        if not hasattr(self, "_body"):
            chunks: list[bytes] = [chunk async for chunk in self.stream()]
            self._body = b"".join(chunks)
        return self._body

    async def json(self) -> typing.Any:
        """
        Decode the request body as JSON, caching the result.

        Errors raised by ``json.loads`` (e.g. ``json.JSONDecodeError``) propagate.
        """
        if not hasattr(self, "_json"):
            body = await self.body()
            self._json = json.loads(body)
        return self._json

    async def _get_form(self, *, max_files: int | float = 1000, max_fields: int | float = 1000) -> FormData:
        """
        Parse the body as form data and cache it on ``self._form``.

        ``multipart/form-data`` bodies go through ``MultiPartParser`` (bounded
        by ``max_files``/``max_fields``); ``application/x-www-form-urlencoded``
        bodies go through ``FormParser``; any other content type yields an
        empty ``FormData``. Because the result is cached, the limits only
        apply to the first successful call.

        Raises:
            HTTPException: status 400 wrapping a ``MultiPartException`` when the
                scope has an ``"app"`` key.
            MultiPartException: re-raised unchanged when the scope has no
                ``"app"`` key.
        """
        if self._form is None:
            assert (
                parse_options_header is not None
            ), "The `python-multipart` library must be installed to use form parsing."
            content_type_header = self.headers.get("Content-Type")
            content_type: bytes
            content_type, _ = parse_options_header(content_type_header)
            if content_type == b"multipart/form-data":
                try:
                    multipart_parser = MultiPartParser(
                        self.headers,
                        self.stream(),
                        max_files=max_files,
                        max_fields=max_fields,
                    )
                    self._form = await multipart_parser.parse()
                except MultiPartException as exc:
                    # NOTE(review): presumably an "app" key means we run inside an
                    # application that turns HTTPException into a 400 response;
                    # without one, surface the parser error itself.
                    if "app" in self.scope:
                        raise HTTPException(status_code=400, detail=exc.message) from exc
                    # Bare ``raise`` re-raises the active exception unchanged.
                    raise
            elif content_type == b"application/x-www-form-urlencoded":
                form_parser = FormParser(self.headers, self.stream())
                self._form = await form_parser.parse()
            else:
                self._form = FormData()
        return self._form

    def form(
        self, *, max_files: int | float = 1000, max_fields: int | float = 1000
    ) -> AwaitableOrContextManager[FormData]:
        """
        Return the parsed form data (see ``_get_form``) wrapped in an
        ``AwaitableOrContextManagerWrapper``.

        NOTE(review): per the wrapper's name it can be awaited or used as a
        context manager; confirm against ``starlette._utils``.
        """
        return AwaitableOrContextManagerWrapper(self._get_form(max_files=max_files, max_fields=max_fields))

    async def close(self) -> None:
        """Close the cached form data, if any form was parsed."""
        if self._form is not None:
            await self._form.close()

    async def is_disconnected(self) -> bool:
        """
        Return whether the client has disconnected, without waiting.

        ``receive`` is called inside an already-cancelled ``CancelScope``, so a
        message is only obtained if one is immediately available. A received
        message that is not ``http.disconnect`` is not retained. Once a
        disconnect has been seen, ``True`` is returned without receiving again.
        """
        if not self._is_disconnected:
            message: Message = {}

            # If message isn't immediately available, move on
            with anyio.CancelScope() as cs:
                cs.cancel()
                message = await self._receive()

            if message.get("type") == "http.disconnect":
                self._is_disconnected = True

        return self._is_disconnected

    async def send_push_promise(self, path: str) -> None:
        """
        Send an ``http.response.push`` message for ``path``.

        Does nothing unless ``scope["extensions"]`` contains
        ``"http.response.push"``. Every value of each header named in
        ``SERVER_PUSH_HEADERS_TO_COPY`` is copied (latin-1 encoded). Names are
        iterated in sorted order because iterating the set directly would give
        an order that varies between processes (str hash randomization).
        """
        if "http.response.push" not in self.scope.get("extensions", {}):
            return
        raw_headers: list[tuple[bytes, bytes]] = [
            (name.encode("latin-1"), value.encode("latin-1"))
            for name in sorted(SERVER_PUSH_HEADERS_TO_COPY)
            for value in self.headers.getlist(name)
        ]
        await self._send({"type": "http.response.push", "path": path, "headers": raw_headers})