Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/starlette/requests.py: 34%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from __future__ import annotations
3import json
4from collections.abc import AsyncGenerator, Iterator, Mapping
5from http import cookies as http_cookies
6from typing import TYPE_CHECKING, Any, NoReturn, cast
8import anyio
10from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper
11from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
12from starlette.exceptions import HTTPException
13from starlette.formparsers import FormParser, MultiPartException, MultiPartParser
14from starlette.types import Message, Receive, Scope, Send
16if TYPE_CHECKING:
17 from python_multipart.multipart import parse_options_header
19 from starlette.applications import Starlette
20 from starlette.routing import Router
21else:
22 try:
23 try:
24 from python_multipart.multipart import parse_options_header
25 except ModuleNotFoundError: # pragma: no cover
26 from multipart.multipart import parse_options_header
27 except ModuleNotFoundError: # pragma: no cover
28 parse_options_header = None
# Names of the request headers that `Request.send_push_promise` copies from the
# incoming request into the "http.response.push" message.
SERVER_PUSH_HEADERS_TO_COPY = {"accept", "accept-encoding", "accept-language", "cache-control", "user-agent"}
def cookie_parser(cookie_string: str) -> dict[str, str]:
    """
    Parse the value of a ``Cookie`` HTTP header into a ``{name: value}`` dict.

    The parsing is deliberately lenient and tries to behave like browsers do,
    because browsers and web servers frequently disregard the spec (RFC 6265)
    when setting and reading cookies.

    Adapted from Django 3.1.0.
    Note: `SimpleCookie.load` is explicitly _NOT_ used here because it is based
    on an outdated spec and fails on lots of input we want to support.
    """
    parsed: dict[str, str] = {}
    for piece in cookie_string.split(";"):
        name, separator, value = piece.partition("=")
        if not separator:
            # No "=" at all: treat the whole piece as the value of an
            # empty-named cookie, per
            # https://bugzilla.mozilla.org/show_bug.cgi?id=169091
            name, value = "", piece
        name = name.strip()
        value = value.strip()
        if not (name or value):
            # Nothing but whitespace between two ";" separators.
            continue
        # Unquote using Python's own cookie algorithm.
        parsed[name] = http_cookies._unquote(value)
    return parsed
class ClientDisconnect(Exception):
    """Raised by `Request.stream` when an "http.disconnect" message is received."""
class HTTPConnection(Mapping[str, Any]):
    """
    Base class for incoming HTTP connections.

    It holds the functionality shared by `Request` and `WebSocket`, and exposes
    the underlying ASGI scope through the read-only mapping interface.
    """

    def __init__(self, scope: Scope, receive: Receive | None = None) -> None:
        connection_type = scope["type"]
        assert connection_type in ("http", "websocket")
        self.scope = scope

    def __getitem__(self, key: str) -> Any:
        return self.scope[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self.scope)

    def __len__(self) -> int:
        return len(self.scope)

    # The `abc.Mapping.__eq__` implementation is deliberately not used:
    # two connection instances must only compare equal when they are the
    # very same object (`self is other`).
    __eq__ = object.__eq__
    __hash__ = object.__hash__

    @property
    def app(self) -> Any:
        """The object stored under the scope's "app" key."""
        return self.scope["app"]

    @property
    def url(self) -> URL:
        """The `URL` built from the scope; created on first access, then cached."""
        if hasattr(self, "_url"):  # pragma: no branch
            return self._url
        self._url = URL(scope=self.scope)
        return self._url

    @property
    def base_url(self) -> URL:
        """The `URL` of the application root; created on first access, then cached."""
        if hasattr(self, "_base_url"):
            return self._base_url

        scope_copy = dict(self.scope)
        # `request.url_for` relies on this value. It may be called from inside a
        # Mount, whose child scope carries its own root_path, but the base URL
        # for url_for must still be the root path of the top level app.
        root = scope_copy.get("app_root_path", scope_copy.get("root_path", ""))
        scope_copy.update(
            path=root if root.endswith("/") else root + "/",
            query_string=b"",
            root_path=root,
        )
        self._base_url = URL(scope=scope_copy)
        return self._base_url

    @property
    def headers(self) -> Headers:
        """The `Headers` built from the scope; created on first access, then cached."""
        if hasattr(self, "_headers"):
            return self._headers
        self._headers = Headers(scope=self.scope)
        return self._headers

    @property
    def query_params(self) -> QueryParams:
        """The `QueryParams` parsed from the scope's "query_string"; cached."""
        if hasattr(self, "_query_params"):  # pragma: no branch
            return self._query_params
        self._query_params = QueryParams(self.scope["query_string"])
        return self._query_params

    @property
    def path_params(self) -> dict[str, Any]:
        """The scope's "path_params", or an empty dict when the key is absent."""
        return self.scope.get("path_params", {})

    @property
    def cookies(self) -> dict[str, str]:
        """Cookies merged from every ``Cookie`` header; parsed once, then cached."""
        if hasattr(self, "_cookies"):
            return self._cookies

        jar: dict[str, str] = {}
        for raw_header in self.headers.getlist("cookie"):
            # Later headers overwrite names already seen in earlier ones.
            jar.update(cookie_parser(raw_header))
        self._cookies = jar
        return self._cookies

    @property
    def client(self) -> Address | None:
        """The scope's "client" entry wrapped in an `Address`, or None if missing."""
        # The scope value is a 2 item tuple of (host, port).
        peer = self.scope.get("client")
        return None if peer is None else Address(*peer)

    @property
    def session(self) -> dict[str, Any]:
        """The scope's "session" entry; asserts that the key is present."""
        assert "session" in self.scope, "SessionMiddleware must be installed to access request.session"
        return self.scope["session"]  # type: ignore[no-any-return]

    @property
    def auth(self) -> Any:
        """The scope's "auth" entry; asserts that the key is present."""
        assert "auth" in self.scope, "AuthenticationMiddleware must be installed to access request.auth"
        return self.scope["auth"]

    @property
    def user(self) -> Any:
        """The scope's "user" entry; asserts that the key is present."""
        assert "user" in self.scope, "AuthenticationMiddleware must be installed to access request.user"
        return self.scope["user"]

    @property
    def state(self) -> State:
        """A `State` wrapper around the scope's "state" dict; created once, then cached."""
        if hasattr(self, "_state"):
            return self._state
        # Make sure the scope carries a "state" dict (empty if it had none), and
        # hand that very dict to the State instance so it stores its data there.
        backing = self.scope.setdefault("state", {})
        self._state = State(backing)
        return self._state

    def url_for(self, name: str, /, **path_params: Any) -> URL:
        """
        Return the absolute URL for the route called `name`.

        The scope's "router" entry is used for the lookup, falling back to its
        "app" entry. A `RuntimeError` is raised when neither is available.
        """
        provider: Router | Starlette | None = self.scope.get("router") or self.scope.get("app")
        if provider is None:
            raise RuntimeError("The `url_for` method can only be used inside a Starlette application or with a router.")
        relative = provider.url_path_for(name, **path_params)
        return relative.make_absolute_url(base_url=self.base_url)
async def empty_receive() -> NoReturn:
    """Placeholder receive callable: always raises `RuntimeError`."""
    reason = "Receive channel has not been made available"
    raise RuntimeError(reason)
async def empty_send(message: Message) -> NoReturn:
    """Placeholder send callable: ignores `message` and always raises `RuntimeError`."""
    reason = "Send channel has not been made available"
    raise RuntimeError(reason)
class Request(HTTPConnection):
    """
    An incoming HTTP request.

    Adds access to the request body (raw stream, bytes, JSON, form data),
    client-disconnect detection and server push on top of `HTTPConnection`.
    """

    # Parsed form data; stays None until `form()` has been awaited.
    _form: FormData | None

    def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send):
        super().__init__(scope)
        assert scope["type"] == "http"
        self._receive = receive
        self._send = send
        self._form = None
        self._is_disconnected = False
        self._stream_consumed = False

    @property
    def method(self) -> str:
        """The scope's "method" entry."""
        return cast(str, self.scope["method"])

    @property
    def receive(self) -> Receive:
        """The receive callable this request was created with."""
        return self._receive

    async def stream(self) -> AsyncGenerator[bytes, None]:
        """
        Yield the request body chunk by chunk, finishing with an empty chunk.

        When the body has already been read through `body()`, the cached bytes
        are yielded instead. Raises `RuntimeError` if the stream was already
        consumed without the body being cached, and `ClientDisconnect` when an
        "http.disconnect" message arrives before the body is complete.
        """
        if hasattr(self, "_body"):
            yield self._body
            yield b""
            return
        if self._stream_consumed:
            raise RuntimeError("Stream consumed")

        while not self._stream_consumed:
            event = await self._receive()
            event_type = event["type"]
            if event_type == "http.disconnect":  # pragma: no branch
                self._is_disconnected = True
                raise ClientDisconnect()
            if event_type != "http.request":
                continue
            data = event.get("body", b"")
            # The flag is set before yielding, so it is already up to date
            # while the consumer is handling the final chunk.
            if not event.get("more_body", False):
                self._stream_consumed = True
            if data:
                yield data
        yield b""

    async def body(self) -> bytes:
        """Return the complete request body; read once, then cached."""
        if hasattr(self, "_body"):
            return self._body
        pieces: list[bytes] = [piece async for piece in self.stream()]
        self._body = b"".join(pieces)
        return self._body

    async def json(self) -> Any:
        """Return the request body decoded with `json.loads`; decoded once, then cached."""
        if hasattr(self, "_json"):  # pragma: no branch
            return self._json
        raw = await self.body()
        self._json = json.loads(raw)
        return self._json

    async def _get_form(
        self,
        *,
        max_files: int | float = 1000,
        max_fields: int | float = 1000,
        max_part_size: int = 1024 * 1024,
    ) -> FormData:
        """
        Parse the request body as form data; parsed once, then cached.

        The parser is picked from the ``Content-Type`` header; any type other
        than multipart or urlencoded form data yields an empty `FormData`.
        A `MultiPartException` is turned into a 400 `HTTPException` when the
        scope has an "app" entry, and re-raised unchanged otherwise.
        """
        if self._form is not None:  # pragma: no branch
            return self._form

        assert parse_options_header is not None, (
            "The `python-multipart` library must be installed to use form parsing."
        )
        content_type: bytes
        content_type, _ = parse_options_header(self.headers.get("Content-Type"))

        if content_type == b"multipart/form-data":
            try:
                parser = MultiPartParser(
                    self.headers,
                    self.stream(),
                    max_files=max_files,
                    max_fields=max_fields,
                    max_part_size=max_part_size,
                )
                self._form = await parser.parse()
            except MultiPartException as exc:
                if "app" in self.scope:
                    raise HTTPException(status_code=400, detail=exc.message)
                raise exc
        elif content_type == b"application/x-www-form-urlencoded":
            self._form = await FormParser(self.headers, self.stream()).parse()
        else:
            self._form = FormData()
        return self._form

    def form(
        self,
        *,
        max_files: int | float = 1000,
        max_fields: int | float = 1000,
        max_part_size: int = 1024 * 1024,
    ) -> AwaitableOrContextManager[FormData]:
        """Return the form parsing coroutine wrapped in `AwaitableOrContextManagerWrapper`."""
        pending = self._get_form(
            max_files=max_files,
            max_fields=max_fields,
            max_part_size=max_part_size,
        )
        return AwaitableOrContextManagerWrapper(pending)

    async def close(self) -> None:
        """Close the parsed form data, if any form was parsed."""
        if self._form is None:  # pragma: no branch
            return
        await self._form.close()

    async def is_disconnected(self) -> bool:
        """Report whether an "http.disconnect" message has been seen for this request."""
        if self._is_disconnected:
            return self._is_disconnected

        message: Message = {}
        # If a message isn't immediately available, move on.
        with anyio.CancelScope() as cancel_scope:
            cancel_scope.cancel()
            message = await self._receive()

        if message.get("type") == "http.disconnect":
            self._is_disconnected = True
        return self._is_disconnected

    async def send_push_promise(self, path: str) -> None:
        """
        Send an "http.response.push" message for `path`.

        Does nothing unless the scope's "extensions" list "http.response.push".
        The headers named in `SERVER_PUSH_HEADERS_TO_COPY` are copied from this
        request into the message.
        """
        if "http.response.push" not in self.scope.get("extensions", {}):
            return
        raw_headers: list[tuple[bytes, bytes]] = [
            (name.encode("latin-1"), value.encode("latin-1"))
            for name in SERVER_PUSH_HEADERS_TO_COPY
            for value in self.headers.getlist(name)
        ]
        await self._send({"type": "http.response.push", "path": path, "headers": raw_headers})