Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/starlette/requests.py: 34%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from __future__ import annotations
3import json
4import sys
5from collections.abc import AsyncGenerator, Iterator, Mapping
6from http import cookies as http_cookies
7from typing import TYPE_CHECKING, Any, Generic, NoReturn, cast
9import anyio
11from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper
12from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
13from starlette.exceptions import HTTPException
14from starlette.formparsers import FormParser, MultiPartException, MultiPartParser
15from starlette.types import Message, Receive, Scope, Send
17if TYPE_CHECKING:
18 from python_multipart.multipart import parse_options_header
20 from starlette.applications import Starlette
21 from starlette.middleware.sessions import Session
22 from starlette.routing import Router
23else:
24 try:
25 try:
26 from python_multipart.multipart import parse_options_header
27 except ModuleNotFoundError: # pragma: no cover
28 from multipart.multipart import parse_options_header
29 except ModuleNotFoundError: # pragma: no cover
30 parse_options_header = None
32if sys.version_info >= (3, 13): # pragma: no cover
33 from typing import TypeVar
34else: # pragma: no cover
35 from typing_extensions import TypeVar
# Names of incoming request headers whose values `Request.send_push_promise`
# copies onto an ``http.response.push`` message.
SERVER_PUSH_HEADERS_TO_COPY = {
    "accept",
    "accept-encoding",
    "accept-language",
    "cache-control",
    "user-agent",
}
def cookie_parser(cookie_string: str) -> dict[str, str]:
    """
    Parse the value of a ``Cookie`` HTTP header into a ``{name: value}`` dict.

    Browsers and servers routinely deviate from RFC 6265, so the parsing here
    is deliberately lenient and mimics what browsers actually do rather than
    what the spec says. The approach is adapted from Django 3.1.0.

    Rules applied to every ``;``-separated chunk:

    * Text before the first ``=`` is the name; everything after it is the
      value (so values may themselves contain ``=``).
    * A chunk with no ``=`` is treated as a value with an empty name, see
      https://bugzilla.mozilla.org/show_bug.cgi?id=169091
    * Surrounding whitespace is stripped from both name and value, and a chunk
      whose name and value are both empty is dropped.
    * Values are unquoted with Python's own ``http.cookies`` algorithm.
    * When a name repeats, the last occurrence wins.

    Note: `SimpleCookie.load` is intentionally avoided: it follows an outdated
    spec and rejects much of the input this function needs to accept.
    """
    parsed: dict[str, str] = {}
    for piece in cookie_string.split(";"):
        name, separator, value = piece.partition("=")
        if not separator:
            # No "=" at all: the whole chunk is a value with an empty name.
            name, value = "", piece
        name = name.strip()
        value = value.strip()
        if not (name or value):
            continue
        parsed[name] = http_cookies._unquote(value)
    return parsed
class ClientDisconnect(Exception):
    """Raised by `Request.stream` when an ``http.disconnect`` message is received."""
# Type of `HTTPConnection.state`: any mapping or `State`, defaulting to `State`
# when the connection class is used without a type argument.
StateT = TypeVar("StateT", default=State, bound=Mapping[str, Any] | State)
class HTTPConnection(Mapping[str, Any], Generic[StateT]):
    """
    Common base for incoming HTTP connections.

    Gathers the behaviour shared by `Request` and `WebSocket`. An instance is a
    read-only mapping view over the ASGI ``scope`` it wraps, plus convenience
    properties; several of them are computed on first access and then cached
    on the instance.
    """

    def __init__(self, scope: Scope, receive: Receive | None = None) -> None:
        # `receive` is part of the shared signature; this base class neither
        # stores nor uses it.
        assert scope["type"] in ("http", "websocket")
        self.scope = scope

    # Mapping protocol: every lookup is forwarded to the wrapped scope.

    def __getitem__(self, key: str) -> Any:
        return self.scope[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self.scope)

    def __len__(self) -> int:
        return len(self.scope)

    # Opt out of `abc.Mapping.__eq__`: two connection objects compare equal
    # only when they are the very same object, and hashing follows identity.
    __eq__ = object.__eq__
    __hash__ = object.__hash__

    @property
    def app(self) -> Any:
        """The object stored under ``scope["app"]``."""
        return self.scope["app"]

    @property
    def url(self) -> URL:
        """The full URL of this connection, built from the scope and cached."""
        if hasattr(self, "_url"):  # pragma: no branch
            return self._url
        self._url = URL(scope=self.scope)
        return self._url

    @property
    def base_url(self) -> URL:
        """
        The application's root URL with an empty query string (cached).

        `url_for` resolves against this. Inside a Mount the scope carries the
        child app's own ``root_path``, so ``app_root_path`` (the top-level
        root) is preferred when present; ``root_path`` is the fallback.
        """
        if not hasattr(self, "_base_url"):
            root = self.scope.get("app_root_path", self.scope.get("root_path", ""))
            self._base_url = URL(
                scope={
                    **self.scope,
                    "path": root if root.endswith("/") else root + "/",
                    "query_string": b"",
                    "root_path": root,
                }
            )
        return self._base_url

    @property
    def headers(self) -> Headers:
        """Request headers, parsed from the scope once and cached."""
        if hasattr(self, "_headers"):
            return self._headers
        self._headers = Headers(scope=self.scope)
        return self._headers

    @property
    def query_params(self) -> QueryParams:
        """Parsed ``scope["query_string"]``, cached after first access."""
        if hasattr(self, "_query_params"):  # pragma: no branch
            return self._query_params
        self._query_params = QueryParams(self.scope["query_string"])
        return self._query_params

    @property
    def path_params(self) -> dict[str, Any]:
        """``scope["path_params"]``, or a fresh empty dict when it is absent."""
        return cast(dict[str, Any], self.scope.get("path_params", {}))

    @property
    def cookies(self) -> dict[str, str]:
        """
        Cookies from every ``Cookie`` header merged into one dict (cached).

        Headers are processed in order, so a later header's value for a name
        overrides an earlier one.
        """
        if not hasattr(self, "_cookies"):
            merged: dict[str, str] = {}
            for raw_header in self.headers.getlist("cookie"):
                merged.update(cookie_parser(raw_header))
            self._cookies = merged
        return self._cookies

    @property
    def client(self) -> Address | None:
        """The ``(host, port)`` pair from ``scope["client"]``, or None if missing."""
        host_port = self.scope.get("client")
        return None if host_port is None else Address(*host_port)

    @property
    def session(self) -> dict[str, Any]:
        """The session object from the scope; requires `SessionMiddleware`."""
        assert "session" in self.scope, "SessionMiddleware must be installed to access request.session"
        current: Session = self.scope["session"]
        # Checked with hasattr because a custom `SessionMiddleware` may store
        # a session object that lacks `mark_accessed`.
        if hasattr(current, "mark_accessed"):  # pragma: no branch
            current.mark_accessed()
        return current

    @property
    def auth(self) -> Any:
        """``scope["auth"]``; requires `AuthenticationMiddleware`."""
        assert "auth" in self.scope, "AuthenticationMiddleware must be installed to access request.auth"
        return self.scope["auth"]

    @property
    def user(self) -> Any:
        """``scope["user"]``; requires `AuthenticationMiddleware`."""
        assert "user" in self.scope, "AuthenticationMiddleware must be installed to access request.user"
        return self.scope["user"]

    @property
    def state(self) -> StateT:
        """
        A `State` wrapper around ``scope["state"]`` (cached).

        ``scope["state"]`` is created as an empty dict if missing; the wrapper
        references that same dict, so data stored through it lives in the scope.
        """
        if not hasattr(self, "_state"):
            self._state = State(self.scope.setdefault("state", {}))
        return cast(StateT, self._state)

    def url_for(self, name: str, /, **path_params: Any) -> URL:
        """
        Build an absolute URL for the route called ``name``.

        The route is resolved via ``scope["router"]``, falling back to
        ``scope["app"]``, and made absolute against `base_url`.

        Raises:
            RuntimeError: if the scope holds neither a router nor an app.
        """
        provider: Router | Starlette | None = self.scope.get("router") or self.scope.get("app")
        if provider is None:
            raise RuntimeError("The `url_for` method can only be used inside a Starlette application or with a router.")
        return provider.url_path_for(name, **path_params).make_absolute_url(base_url=self.base_url)
async def empty_receive() -> NoReturn:
    """Default ``receive`` for `Request`: always fails because no channel was given."""
    reason = "Receive channel has not been made available"
    raise RuntimeError(reason)
async def empty_send(message: Message) -> NoReturn:
    """Default ``send`` for `Request`: always fails, whatever ``message`` is."""
    reason = "Send channel has not been made available"
    raise RuntimeError(reason)
class Request(HTTPConnection[StateT]):
    """
    An incoming HTTP request.

    Wraps an ASGI ``http`` scope together with its ``receive`` and ``send``
    callables, and provides cached access to the request body as raw bytes,
    decoded JSON, or parsed form data.
    """

    _form: FormData | None

    def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send):
        super().__init__(scope)
        assert scope["type"] == "http"
        self._receive = receive
        self._send = send
        # Book-keeping for the single-pass ASGI body stream.
        self._stream_consumed = False
        self._is_disconnected = False
        # Filled in by the first successful `form()` call.
        self._form = None

    @property
    def method(self) -> str:
        """The HTTP method from ``scope["method"]``."""
        return cast(str, self.scope["method"])

    @property
    def receive(self) -> Receive:
        """The ASGI ``receive`` callable this request reads from."""
        return self._receive

    async def stream(self) -> AsyncGenerator[bytes, None]:
        """
        Yield the body in chunks, always finishing with an empty ``b""`` chunk.

        If the body was already read by `body()`, the cached bytes are replayed.
        Otherwise ``http.request`` messages are consumed until one arrives with
        a false ``more_body``; empty body chunks are not yielded.

        Raises:
            RuntimeError: if the stream was already consumed and no cached body
                exists.
            ClientDisconnect: if an ``http.disconnect`` message is received.
        """
        if hasattr(self, "_body"):
            yield self._body
            yield b""
            return
        if self._stream_consumed:
            raise RuntimeError("Stream consumed")
        while not self._stream_consumed:
            incoming = await self._receive()
            if incoming["type"] == "http.request":
                chunk = incoming.get("body", b"")
                # Flag the stream as finished before handing out the last
                # chunk, so the loop stops once it has been yielded.
                if not incoming.get("more_body", False):
                    self._stream_consumed = True
                if chunk:
                    yield chunk
            elif incoming["type"] == "http.disconnect":  # pragma: no branch
                self._is_disconnected = True
                raise ClientDisconnect()
        yield b""

    async def body(self) -> bytes:
        """Read the whole body into memory once and return the cached bytes."""
        if not hasattr(self, "_body"):
            self._body = b"".join([piece async for piece in self.stream()])
        return self._body

    async def json(self) -> Any:
        """Decode the body with `json.loads`, caching the result."""
        if not hasattr(self, "_json"):  # pragma: no branch
            self._json = json.loads(await self.body())
        return self._json

    async def _get_form(
        self,
        *,
        max_files: int | float = 1000,
        max_fields: int | float = 1000,
        max_part_size: int = 1024 * 1024,
    ) -> FormData:
        """
        Parse the body according to its ``Content-Type`` and cache the result.

        ``multipart/form-data`` goes through `MultiPartParser` (honouring the
        limits), ``application/x-www-form-urlencoded`` through `FormParser`;
        any other type yields an empty `FormData`.

        Raises:
            HTTPException: status 400 when multipart parsing fails and the scope
                contains an ``"app"`` key.
            MultiPartException: when multipart parsing fails outside an app.
        """
        if self._form is not None:  # pragma: no branch
            return self._form
        assert parse_options_header is not None, (
            "The `python-multipart` library must be installed to use form parsing."
        )
        raw_content_type = self.headers.get("Content-Type")
        content_type: bytes
        content_type, _ = parse_options_header(raw_content_type)
        if content_type == b"multipart/form-data":
            try:
                parser = MultiPartParser(
                    self.headers,
                    self.stream(),
                    max_files=max_files,
                    max_fields=max_fields,
                    max_part_size=max_part_size,
                )
                self._form = await parser.parse()
            except MultiPartException as exc:
                # Inside an application, surface the failure as a client error.
                if "app" in self.scope:
                    raise HTTPException(status_code=400, detail=exc.message)
                raise exc
        elif content_type == b"application/x-www-form-urlencoded":
            self._form = await FormParser(self.headers, self.stream()).parse()
        else:
            self._form = FormData()
        return self._form

    def form(
        self,
        *,
        max_files: int | float = 1000,
        max_fields: int | float = 1000,
        max_part_size: int = 1024 * 1024,
    ) -> AwaitableOrContextManager[FormData]:
        """
        Parse the body as form data.

        Returns an `AwaitableOrContextManagerWrapper` around the parsing
        coroutine; the keyword limits are forwarded to the multipart parser.
        """
        pending = self._get_form(
            max_files=max_files,
            max_fields=max_fields,
            max_part_size=max_part_size,
        )
        return AwaitableOrContextManagerWrapper(pending)

    async def close(self) -> None:
        """Close the parsed form data, if any form has been parsed."""
        if self._form is not None:  # pragma: no branch
            await self._form.close()

    async def is_disconnected(self) -> bool:
        """
        Report whether the client has disconnected, without blocking.

        Once a disconnect has been observed the answer is sticky.
        """
        if self._is_disconnected:
            return self._is_disconnected
        message: Message = {}
        # The scope is cancelled before the receive call, so if no message is
        # immediately available we move on instead of waiting.
        with anyio.CancelScope() as cancel_scope:
            cancel_scope.cancel()
            message = await self._receive()
        if message.get("type") == "http.disconnect":
            self._is_disconnected = True
        return self._is_disconnected

    async def send_push_promise(self, path: str) -> None:
        """
        Send an ``http.response.push`` message for ``path``.

        Does nothing unless the scope advertises the ``http.response.push``
        extension. The headers listed in `SERVER_PUSH_HEADERS_TO_COPY` are
        copied from this request, latin-1 encoded.
        """
        if "http.response.push" not in self.scope.get("extensions", {}):
            return
        raw_headers: list[tuple[bytes, bytes]] = [
            (name.encode("latin-1"), value.encode("latin-1"))
            for name in SERVER_PUSH_HEADERS_TO_COPY
            for value in self.headers.getlist(name)
        ]
        await self._send({"type": "http.response.push", "path": path, "headers": raw_headers})