Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/starlette/requests.py: 34%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

213 statements  

1from __future__ import annotations 

2 

3import json 

4import sys 

5from collections.abc import AsyncGenerator, Iterator, Mapping 

6from http import cookies as http_cookies 

7from typing import TYPE_CHECKING, Any, Generic, NoReturn, cast 

8 

9import anyio 

10 

11from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper 

12from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State 

13from starlette.exceptions import HTTPException 

14from starlette.formparsers import FormParser, MultiPartException, MultiPartParser 

15from starlette.types import Message, Receive, Scope, Send 

16 

17if TYPE_CHECKING: 

18 from python_multipart.multipart import parse_options_header 

19 

20 from starlette.applications import Starlette 

21 from starlette.middleware.sessions import Session 

22 from starlette.routing import Router 

23else: 

24 try: 

25 try: 

26 from python_multipart.multipart import parse_options_header 

27 except ModuleNotFoundError: # pragma: no cover 

28 from multipart.multipart import parse_options_header 

29 except ModuleNotFoundError: # pragma: no cover 

30 parse_options_header = None 

31 

32if sys.version_info >= (3, 13): # pragma: no cover 

33 from typing import TypeVar 

34else: # pragma: no cover 

35 from typing_extensions import TypeVar 

36 

# Names of request headers whose values `Request.send_push_promise` copies
# onto an outgoing ``http.response.push`` message.
SERVER_PUSH_HEADERS_TO_COPY = {
    "accept", "accept-encoding", "accept-language",
    "cache-control", "user-agent",
}

44 

45 

def cookie_parser(cookie_string: str) -> dict[str, str]:
    """
    Parse the value of a ``Cookie`` HTTP header into a ``{name: value}`` dict.

    Parsing is deliberately lenient. It mimics how browsers behave in practice
    rather than following RFC 6265 strictly, because clients and servers
    routinely deviate from the spec.

    Adapted from Django 3.1.0. ``SimpleCookie.load`` is intentionally *not*
    used: it is based on an outdated spec and fails on much of the input we
    want to accept.

    Rules:

    * Pairs are separated by ``;`` and split on the first ``=`` only.
    * A fragment without ``=`` is stored under the empty name ``""``
      (see https://bugzilla.mozilla.org/show_bug.cgi?id=169091).
    * Names and values are whitespace-stripped. A fragment whose name and
      value are both empty is skipped.
    * Values are unquoted using Python's own cookie unquoting algorithm.
    * When a name repeats, the last occurrence wins.
    """
    parsed: dict[str, str] = {}
    for fragment in cookie_string.split(";"):
        name, separator, value = fragment.partition("=")
        if not separator:
            # No "=" at all: the whole fragment is a value with an empty name.
            name, value = "", fragment
        name = name.strip()
        value = value.strip()
        if not (name or value):
            continue
        parsed[name] = http_cookies._unquote(value)
    return parsed

71 

72 

class ClientDisconnect(Exception):
    """Raised by ``Request.stream()`` when an ``http.disconnect`` message arrives."""

75 

76 

# Type of ``HTTPConnection.state``: any string-keyed mapping or a ``State``,
# defaulting to ``State``. The ``default=`` argument is why ``TypeVar`` comes
# from ``typing`` on 3.13+ and from ``typing_extensions`` on older versions.
StateT = TypeVar(
    "StateT",
    bound=Mapping[str, Any] | State,
    default=State,
)

78 

79 

class HTTPConnection(Mapping[str, Any], Generic[StateT]):
    """
    Common base for incoming connections, shared by `Request` and `WebSocket`.

    The connection is a read-only mapping over its ASGI ``scope``. It also
    offers convenience accessors (URL, headers, query params, cookies, ...),
    several of which are computed on first access and then cached.
    """

    def __init__(self, scope: Scope, receive: Receive | None = None) -> None:
        # ``receive`` is accepted for signature compatibility; it is not stored here.
        assert scope["type"] in ("http", "websocket")
        self.scope = scope

    # Mapping protocol: delegate directly to the ASGI scope.

    def __getitem__(self, key: str) -> Any:
        return self.scope[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self.scope)

    def __len__(self) -> int:
        return len(self.scope)

    # Don't use the `abc.Mapping.__eq__` implementation, which compares contents.
    # A connection is only ever equal to itself, so use identity semantics.
    __eq__ = object.__eq__
    __hash__ = object.__hash__

    @property
    def app(self) -> Any:
        """The ``app`` entry of the scope."""
        return self.scope["app"]

    @property
    def url(self) -> URL:
        """The request URL, built from the scope on first access and cached."""
        if not hasattr(self, "_url"):  # pragma: no branch
            self._url = URL(scope=self.scope)
        return self._url

    @property
    def base_url(self) -> URL:
        """
        The application's root URL, built on first access and cached.

        It is a copy of the scope with three changes:

        * ``path`` becomes the app root path, with a trailing ``/`` added if missing.
        * ``query_string`` is cleared.
        * ``root_path`` is set to the app root path.

        ``app_root_path`` takes precedence over ``root_path``. Inside a
        ``Mount`` the child scope carries its own ``root_path``, but ``url_for``
        should still resolve against the top-level application root.
        """
        if not hasattr(self, "_base_url"):
            root = self.scope.get("app_root_path", self.scope.get("root_path", ""))
            base_scope = {
                **self.scope,
                "path": root if root.endswith("/") else root + "/",
                "query_string": b"",
                "root_path": root,
            }
            self._base_url = URL(scope=base_scope)
        return self._base_url

    @property
    def headers(self) -> Headers:
        """Request headers, built from the scope on first access and cached."""
        if not hasattr(self, "_headers"):
            self._headers = Headers(scope=self.scope)
        return self._headers

    @property
    def query_params(self) -> QueryParams:
        """Query parameters parsed from ``scope["query_string"]`` (cached)."""
        if not hasattr(self, "_query_params"):  # pragma: no branch
            self._query_params = QueryParams(self.scope["query_string"])
        return self._query_params

    @property
    def path_params(self) -> dict[str, Any]:
        """``scope["path_params"]``, or a fresh empty dict when the key is absent."""
        return self.scope.get("path_params", {})

    @property
    def cookies(self) -> dict[str, str]:
        """
        Cookies from every ``Cookie`` header, merged into one dict and cached.

        Headers are merged in the order ``headers.getlist`` returns them, so a
        name repeated in a later header overrides the earlier value.
        """
        if not hasattr(self, "_cookies"):
            merged: dict[str, str] = {}
            for raw_header in self.headers.getlist("cookie"):
                merged.update(cookie_parser(raw_header))
            self._cookies = merged
        return self._cookies

    @property
    def client(self) -> Address | None:
        """``scope["client"]`` (a ``(host, port)`` pair) as an ``Address``, or ``None`` if missing."""
        host_port = self.scope.get("client")
        return None if host_port is None else Address(*host_port)

    @property
    def session(self) -> dict[str, Any]:
        """The session object stored in the scope; it is marked as accessed when supported."""
        assert "session" in self.scope, "SessionMiddleware must be installed to access request.session"
        session: Session = self.scope["session"]
        # Custom `SessionMiddleware` implementations may lack `mark_accessed`.
        if hasattr(session, "mark_accessed"):  # pragma: no branch
            session.mark_accessed()
        return session

    @property
    def auth(self) -> Any:
        """``scope["auth"]``; asserts that the key is present."""
        assert "auth" in self.scope, "AuthenticationMiddleware must be installed to access request.auth"
        return self.scope["auth"]

    @property
    def user(self) -> Any:
        """``scope["user"]``; asserts that the key is present."""
        assert "user" in self.scope, "AuthenticationMiddleware must be installed to access request.user"
        return self.scope["user"]

    @property
    def state(self) -> StateT:
        """
        A ``State`` wrapper around ``scope["state"]``, created on first access.

        An empty dict is installed under ``scope["state"]`` first if none
        exists, and that dict is what the wrapper is given to store into.
        """
        if not hasattr(self, "_state"):
            self.scope.setdefault("state", {})
            self._state = State(self.scope["state"])
        return cast(StateT, self._state)

    def url_for(self, name: str, /, **path_params: Any) -> URL:
        """
        Build an absolute URL for the named route, relative to ``base_url``.

        Route lookup uses the scope's ``router`` when it is truthy, otherwise
        its ``app``.

        Raises:
            RuntimeError: if neither a router nor an app is available.
        """
        provider: Router | Starlette | None = self.scope.get("router") or self.scope.get("app")
        if provider is None:
            raise RuntimeError("The `url_for` method can only be used inside a Starlette application or with a router.")
        return provider.url_path_for(name, **path_params).make_absolute_url(base_url=self.base_url)

204 

205 

async def empty_receive() -> NoReturn:
    """Placeholder ``receive`` callable that always raises ``RuntimeError``."""
    reason = "Receive channel has not been made available"
    raise RuntimeError(reason)

208 

209 

async def empty_send(message: Message) -> NoReturn:
    """Placeholder ``send`` callable; rejects every message with ``RuntimeError``."""
    error = RuntimeError("Send channel has not been made available")
    raise error

212 

213 

class Request(HTTPConnection[StateT]):
    """
    An incoming HTTP request: an ASGI ``http`` scope plus its ``receive`` and
    ``send`` callables.

    ``body()``, ``json()`` and ``form()`` each cache their own result after
    the first call.
    """

    _form: FormData | None

    def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send):
        super().__init__(scope)
        assert scope["type"] == "http"
        self._receive = receive
        self._send = send
        # Bookkeeping read and written by stream(), is_disconnected(), form() and close().
        self._stream_consumed = False
        self._is_disconnected = False
        self._form = None

    @property
    def method(self) -> str:
        """The request's HTTP method, taken from ``scope["method"]``."""
        return cast(str, self.scope["method"])

    @property
    def receive(self) -> Receive:
        """The ASGI ``receive`` callable this request reads from."""
        return self._receive

    async def stream(self) -> AsyncGenerator[bytes, None]:
        """
        Yield the request body chunk by chunk as it arrives from ``receive``.

        If ``body()`` has already cached the full body, that value is replayed
        instead. Both paths finish by yielding an empty ``b""``.

        Raises:
            RuntimeError: if the stream was already consumed and no cached
                body exists.
            ClientDisconnect: if an ``http.disconnect`` message is received.
        """
        if hasattr(self, "_body"):
            yield self._body
            yield b""
            return
        if self._stream_consumed:
            raise RuntimeError("Stream consumed")
        while not self._stream_consumed:
            message = await self._receive()
            kind = message["type"]
            if kind == "http.request":
                chunk = message.get("body", b"")
                # Set the flag before yielding the last chunk. If the consumer
                # stops iterating at that chunk, the stream is still marked consumed.
                if not message.get("more_body", False):
                    self._stream_consumed = True
                if chunk:
                    yield chunk
            elif kind == "http.disconnect":  # pragma: no branch
                self._is_disconnected = True
                raise ClientDisconnect()
        yield b""

    async def body(self) -> bytes:
        """Return the complete request body, reading it from ``stream()`` only once."""
        if not hasattr(self, "_body"):
            self._body = b"".join([chunk async for chunk in self.stream()])
        return self._body

    async def json(self) -> Any:
        """Decode the body with ``json.loads``; the result is cached."""
        if not hasattr(self, "_json"):  # pragma: no branch
            self._json = json.loads(await self.body())
        return self._json

    async def _get_form(
        self,
        *,
        max_files: int | float = 1000,
        max_fields: int | float = 1000,
        max_part_size: int = 1024 * 1024,
    ) -> FormData:
        """
        Parse the body as form data according to ``Content-Type``, and cache it.

        Dispatch by content type:

        * ``multipart/form-data`` goes through ``MultiPartParser``, with the given limits.
        * ``application/x-www-form-urlencoded`` goes through ``FormParser``.
        * Any other content type yields an empty ``FormData``.

        Raises:
            HTTPException: status 400 when multipart parsing fails and the
                scope contains ``"app"``.
            MultiPartException: when multipart parsing fails and there is no
                ``"app"`` in the scope.
        """
        if self._form is not None:  # pragma: no cover
            return self._form

        assert parse_options_header is not None, (
            "The `python-multipart` library must be installed to use form parsing."
        )
        content_type: bytes
        content_type, _ = parse_options_header(self.headers.get("Content-Type"))

        if content_type == b"multipart/form-data":
            try:
                multipart = MultiPartParser(
                    self.headers,
                    self.stream(),
                    max_files=max_files,
                    max_fields=max_fields,
                    max_part_size=max_part_size,
                )
                self._form = await multipart.parse()
            except MultiPartException as exc:
                # Inside an application, report malformed multipart data as an HTTP 400 error.
                if "app" in self.scope:
                    raise HTTPException(status_code=400, detail=exc.message)
                raise exc
        elif content_type == b"application/x-www-form-urlencoded":
            self._form = await FormParser(self.headers, self.stream()).parse()
        else:
            self._form = FormData()
        return self._form

    def form(
        self,
        *,
        max_files: int | float = 1000,
        max_fields: int | float = 1000,
        max_part_size: int = 1024 * 1024,
    ) -> AwaitableOrContextManager[FormData]:
        """
        Return the parsed form wrapped in ``AwaitableOrContextManagerWrapper``.

        As the ``AwaitableOrContextManager`` return type indicates, the result
        is intended to be awaited or used as an async context manager. The
        keyword limits only affect multipart parsing.
        """
        pending = self._get_form(max_files=max_files, max_fields=max_fields, max_part_size=max_part_size)
        return AwaitableOrContextManagerWrapper(pending)

    async def close(self) -> None:
        """Close the parsed form, if one was produced."""
        form = self._form
        if form is not None:  # pragma: no branch
            await form.close()

    async def is_disconnected(self) -> bool:
        """
        Report whether the client has disconnected, polling ``receive`` once.

        Once a disconnect has been observed, the answer is remembered. A
        non-disconnect message picked up by the poll is discarded.
        """
        if self._is_disconnected:
            return True

        message: Message = {}
        # The scope is cancelled before the await, so if a message isn't
        # immediately available we move on with the empty dict.
        with anyio.CancelScope() as poll_scope:
            poll_scope.cancel()
            message = await self._receive()

        if message.get("type") == "http.disconnect":
            self._is_disconnected = True
        return self._is_disconnected

    async def send_push_promise(self, path: str) -> None:
        """
        Send an ``http.response.push`` message for ``path``.

        This only happens when ``scope["extensions"]`` lists that extension;
        otherwise it does nothing. Each header named in
        ``SERVER_PUSH_HEADERS_TO_COPY`` is copied from this request onto the
        push, latin-1 encoded.
        """
        if "http.response.push" not in self.scope.get("extensions", {}):
            return
        raw_headers: list[tuple[bytes, bytes]] = [
            (name.encode("latin-1"), value.encode("latin-1"))
            for name in SERVER_PUSH_HEADERS_TO_COPY
            for value in self.headers.getlist(name)
        ]
        await self._send({"type": "http.response.push", "path": path, "headers": raw_headers})