Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/starlette/requests.py: 34%
195 statements
« prev ^ index » next coverage.py v7.2.2, created at 2023-03-26 06:12 +0000
1import json
2import typing
3from http import cookies as http_cookies
5import anyio
7from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper
8from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
9from starlette.exceptions import HTTPException
10from starlette.formparsers import FormParser, MultiPartException, MultiPartParser
11from starlette.types import Message, Receive, Scope, Send
13try:
14 from multipart.multipart import parse_options_header
15except ImportError: # pragma: nocover
16 parse_options_header = None
19if typing.TYPE_CHECKING:
20 from starlette.routing import Router
# Names of request headers that `Request.send_push_promise` copies (every
# value of each) from the current request into an ASGI ``http.response.push``
# message.
SERVER_PUSH_HEADERS_TO_COPY = {
    "accept",
    "accept-encoding",
    "accept-language",
    "cache-control",
    "user-agent",
}
def cookie_parser(cookie_string: str) -> typing.Dict[str, str]:
    """
    Parse the value of a ``Cookie`` HTTP header into a ``{name: value}`` dict.

    Parsing is deliberately lenient. It mimics how browsers actually treat
    cookies rather than following RFC 6265 strictly, because clients and
    servers frequently deviate from the spec. Adapted from Django 3.1.0.

    Each ``;``-separated segment is handled as follows:

    * text before the first ``=`` is the name, everything after it the value;
    * a segment without ``=`` is a value with an empty name
      (see https://bugzilla.mozilla.org/show_bug.cgi?id=169091);
    * whitespace around the name and the value is stripped;
    * segments whose name and value are both empty are skipped;
    * values are unquoted with Python's ``http.cookies`` algorithm;
    * when a name repeats, its last occurrence wins.

    ``SimpleCookie.load`` is intentionally *not* used. It follows an outdated
    spec and fails on much of the input we want to accept.
    """
    parsed: typing.Dict[str, str] = {}
    for segment in cookie_string.split(";"):
        name, separator, value = segment.partition("=")
        if not separator:
            # No "=" at all: the whole segment is a value with an empty name.
            name, value = "", segment
        name = name.strip()
        value = value.strip()
        if not (name or value):
            continue
        parsed[name] = http_cookies._unquote(value)
    return parsed
class ClientDisconnect(Exception):
    """Raised while reading a request body when an ``http.disconnect`` message arrives."""
class HTTPConnection(typing.Mapping[str, typing.Any]):
    """
    Base class for incoming HTTP connections, holding the functionality
    shared by `Request` and `WebSocket`.

    An instance is a read-only mapping over its ASGI ``scope``. The properties
    expose commonly used parts of that scope. The wrapper objects (URL,
    headers, query params, cookies, state) are built on first access and
    cached on the instance.
    """

    def __init__(self, scope: Scope, receive: typing.Optional[Receive] = None) -> None:
        # ``receive`` is accepted for signature compatibility but is not used
        # by this base class.
        assert scope["type"] in ("http", "websocket")
        self.scope = scope

    # --- Mapping protocol: delegate straight to the scope -----------------

    def __getitem__(self, key: str) -> typing.Any:
        return self.scope[key]

    def __iter__(self) -> typing.Iterator[str]:
        return iter(self.scope)

    def __len__(self) -> int:
        return len(self.scope)

    # Identity semantics: two connection objects compare equal only when they
    # are the very same object, so the content-based ``Mapping.__eq__`` is
    # deliberately replaced by ``object``'s (and hashing follows suit).
    __eq__ = object.__eq__
    __hash__ = object.__hash__

    # --- Scope accessors ----------------------------------------------------

    @property
    def app(self) -> typing.Any:
        """The ``app`` entry of the scope."""
        return self.scope["app"]

    @property
    def url(self) -> URL:
        """The request URL, built from the scope once and then cached."""
        if hasattr(self, "_url"):
            return self._url
        self._url = URL(scope=self.scope)
        return self._url

    @property
    def base_url(self) -> URL:
        """
        A URL built from a copy of the scope with the path set to ``/`` and
        the query string emptied. ``app_root_path`` is used as the root path
        when present; otherwise ``root_path`` is used, falling back to ``""``.
        """
        if hasattr(self, "_base_url"):
            return self._base_url
        scope = self.scope
        root_path = scope.get("app_root_path", scope.get("root_path", ""))
        base_scope = {
            **scope,
            "path": "/",
            "query_string": b"",
            "root_path": root_path,
        }
        self._base_url = URL(scope=base_scope)
        return self._base_url

    @property
    def headers(self) -> Headers:
        """Request headers wrapped in `Headers`, cached after first access."""
        if hasattr(self, "_headers"):
            return self._headers
        self._headers = Headers(scope=self.scope)
        return self._headers

    @property
    def query_params(self) -> QueryParams:
        """Parsed ``query_string`` from the scope, cached after first access."""
        if hasattr(self, "_query_params"):
            return self._query_params
        self._query_params = QueryParams(self.scope["query_string"])
        return self._query_params

    @property
    def path_params(self) -> typing.Dict[str, typing.Any]:
        """The scope's ``path_params``, or an empty dict when absent."""
        return self.scope.get("path_params", {})

    @property
    def cookies(self) -> typing.Dict[str, str]:
        """Cookies parsed from the ``cookie`` header (empty dict if none)."""
        if hasattr(self, "_cookies"):
            return self._cookies
        header = self.headers.get("cookie")
        self._cookies = cookie_parser(header) if header else {}
        return self._cookies

    @property
    def client(self) -> typing.Optional[Address]:
        """The scope's ``(host, port)`` client pair as an `Address`, or None."""
        # The scope value may be a 2-item tuple, None, or missing entirely.
        host_port = self.scope.get("client")
        return None if host_port is None else Address(*host_port)

    @property
    def session(self) -> typing.Dict[str, typing.Any]:
        """The scope's ``session`` entry; asserts that it is present."""
        assert "session" in self.scope, (
            "SessionMiddleware must be installed to access request.session"
        )
        return self.scope["session"]

    @property
    def auth(self) -> typing.Any:
        """The scope's ``auth`` entry; asserts that it is present."""
        assert "auth" in self.scope, (
            "AuthenticationMiddleware must be installed to access request.auth"
        )
        return self.scope["auth"]

    @property
    def user(self) -> typing.Any:
        """The scope's ``user`` entry; asserts that it is present."""
        assert "user" in self.scope, (
            "AuthenticationMiddleware must be installed to access request.user"
        )
        return self.scope["user"]

    @property
    def state(self) -> State:
        """
        A `State` object wrapping the scope's ``state`` dict. The dict is
        created empty in the scope if it is missing.
        """
        if hasattr(self, "_state"):
            return self._state
        # Wrap the very dict stored in the scope so that State stores its
        # data there rather than in a private copy.
        self._state = State(self.scope.setdefault("state", {}))
        return self._state

    def url_for(self, name: str, **path_params: typing.Any) -> str:
        """Resolve route ``name`` via the scope's router, made absolute against `base_url`."""
        router: Router = self.scope["router"]
        return router.url_path_for(name, **path_params).make_absolute_url(
            base_url=self.base_url
        )
async def empty_receive() -> typing.NoReturn:
    """
    Placeholder ASGI ``receive`` callable for connections built without one.

    Always raises ``RuntimeError``.
    """
    reason = "Receive channel has not been made available"
    raise RuntimeError(reason)
async def empty_send(message: Message) -> typing.NoReturn:
    """
    Placeholder ASGI ``send`` callable for connections built without one.

    Ignores ``message`` and always raises ``RuntimeError``.
    """
    reason = "Send channel has not been made available"
    raise RuntimeError(reason)
class Request(HTTPConnection):
    """
    An incoming HTTP request.

    Extends `HTTPConnection` with the request method and access to the body.
    The body can be consumed incrementally with `stream()`, or read whole with
    `body()`, `json()` or `form()`. Those whole-body readers cache their
    result on the instance, so repeated calls return the same value.
    """

    _form: typing.Optional[FormData]

    def __init__(
        self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send
    ):
        super().__init__(scope)
        # The base class allows "http" or "websocket"; a Request is HTTP only.
        assert scope["type"] == "http"
        self._send = send
        self._receive = receive
        # Body-stream and disconnect bookkeeping, plus the cached form.
        self._is_disconnected = False
        self._stream_consumed = False
        self._form = None

    @property
    def method(self) -> str:
        """The scope's ``method`` entry."""
        return self.scope["method"]

    @property
    def receive(self) -> Receive:
        """The ASGI ``receive`` callable this request reads from."""
        return self._receive

    async def stream(self) -> typing.AsyncGenerator[bytes, None]:
        """
        Yield the request body chunk by chunk, always ending with ``b""``.

        If the body is already buffered (``_body`` is set, as `body()` does),
        the buffered bytes are replayed as a single chunk. Otherwise, chunks
        are pulled from ``receive``. That can happen only once; a second
        attempt raises ``RuntimeError("Stream consumed")``.

        Raises `ClientDisconnect` if an ``http.disconnect`` message arrives.
        """
        if hasattr(self, "_body"):
            yield self._body
            yield b""
            return
        if self._stream_consumed:
            raise RuntimeError("Stream consumed")
        self._stream_consumed = True

        more_body = True
        while more_body:
            message = await self._receive()
            message_type = message["type"]
            if message_type == "http.disconnect":
                self._is_disconnected = True
                raise ClientDisconnect()
            if message_type == "http.request":
                chunk = message.get("body", b"")
                if chunk:  # empty frames are not yielded
                    yield chunk
                more_body = message.get("more_body", False)
            # Any other message type is ignored and reading continues.
        yield b""

    async def body(self) -> bytes:
        """Return the entire request body, reading it once and caching it."""
        if not hasattr(self, "_body"):
            self._body = b"".join([chunk async for chunk in self.stream()])
        return self._body

    async def json(self) -> typing.Any:
        """Return the body decoded with ``json.loads``, cached after first call."""
        if not hasattr(self, "_json"):
            self._json = json.loads(await self.body())
        return self._json

    async def _get_form(
        self,
        *,
        max_files: typing.Union[int, float] = 1000,
        max_fields: typing.Union[int, float] = 1000,
    ) -> FormData:
        """
        Parse the body as form data and cache the result.

        Dispatch on the ``Content-Type`` header:

        * ``multipart/form-data`` uses `MultiPartParser`, limited by
          ``max_files`` / ``max_fields``;
        * ``application/x-www-form-urlencoded`` uses `FormParser`;
        * anything else yields an empty `FormData`.
        """
        if self._form is not None:
            return self._form

        assert (
            parse_options_header is not None
        ), "The `python-multipart` library must be installed to use form parsing."
        content_type: bytes
        content_type, _ = parse_options_header(self.headers.get("Content-Type"))

        if content_type == b"multipart/form-data":
            try:
                parser = MultiPartParser(
                    self.headers,
                    self.stream(),
                    max_files=max_files,
                    max_fields=max_fields,
                )
                self._form = await parser.parse()
            except MultiPartException as exc:
                # With an "app" key in the scope (presumably running inside an
                # application), report the failure as a 400 HTTPException.
                # Otherwise, let the parser's exception propagate.
                if "app" in self.scope:
                    raise HTTPException(status_code=400, detail=exc.message)
                raise exc
        elif content_type == b"application/x-www-form-urlencoded":
            self._form = await FormParser(self.headers, self.stream()).parse()
        else:
            self._form = FormData()
        return self._form

    def form(
        self,
        *,
        max_files: typing.Union[int, float] = 1000,
        max_fields: typing.Union[int, float] = 1000,
    ) -> AwaitableOrContextManager[FormData]:
        """
        Return the parsed form data, wrapped in `AwaitableOrContextManagerWrapper`.

        Going by that wrapper's name, the result can be awaited or used as a
        context manager; confirm against ``starlette._utils``.
        """
        pending = self._get_form(max_files=max_files, max_fields=max_fields)
        return AwaitableOrContextManagerWrapper(pending)

    async def close(self) -> None:
        """Close the cached form data, if a form has been parsed."""
        form = self._form
        if form is not None:
            await form.close()

    async def is_disconnected(self) -> bool:
        """
        Report whether the client has disconnected, without waiting.

        Once a disconnect has been seen, the answer stays True. Otherwise, one
        message is read inside an already-cancelled ``CancelScope``, so the
        call moves on if no message is immediately available.

        NOTE(review): a message read here that is not ``http.disconnect``
        (for example a body chunk) is discarded.
        """
        if self._is_disconnected:
            return True

        message: Message = {}
        with anyio.CancelScope() as cancel_scope:
            cancel_scope.cancel()
            message = await self._receive()

        if message.get("type") == "http.disconnect":
            self._is_disconnected = True
        return self._is_disconnected

    async def send_push_promise(self, path: str) -> None:
        """
        Send an ``http.response.push`` message for ``path``.

        The message carries every value of each header named in
        ``SERVER_PUSH_HEADERS_TO_COPY``, encoded as latin-1. It is sent only
        when the scope's ``extensions`` advertises ``http.response.push``;
        otherwise this does nothing.
        """
        if "http.response.push" not in self.scope.get("extensions", {}):
            return
        raw_headers: "typing.List[typing.Tuple[bytes, bytes]]" = [
            (name.encode("latin-1"), value.encode("latin-1"))
            for name in SERVER_PUSH_HEADERS_TO_COPY
            for value in self.headers.getlist(name)
        ]
        await self._send(
            {"type": "http.response.push", "path": path, "headers": raw_headers}
        )