Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/starlette/requests.py: 34%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

212 statements  

1from __future__ import annotations 

2 

3import json 

4import sys 

5from collections.abc import AsyncGenerator, Iterator, Mapping 

6from http import cookies as http_cookies 

7from typing import TYPE_CHECKING, Any, Generic, NoReturn, cast 

8 

9import anyio 

10 

11from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper 

12from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State 

13from starlette.exceptions import HTTPException 

14from starlette.formparsers import FormParser, MultiPartException, MultiPartParser 

15from starlette.types import Message, Receive, Scope, Send 

16 

17if TYPE_CHECKING: 

18 from python_multipart.multipart import parse_options_header 

19 

20 from starlette.applications import Starlette 

21 from starlette.middleware.sessions import Session 

22 from starlette.routing import Router 

23else: 

24 try: 

25 try: 

26 from python_multipart.multipart import parse_options_header 

27 except ModuleNotFoundError: # pragma: no cover 

28 from multipart.multipart import parse_options_header 

29 except ModuleNotFoundError: # pragma: no cover 

30 parse_options_header = None 

31 

32if sys.version_info >= (3, 13): # pragma: no cover 

33 from typing import TypeVar 

34else: # pragma: no cover 

35 from typing_extensions import TypeVar 

36 

# Names of incoming request headers that `Request.send_push_promise` copies
# into the ``headers`` of the ``http.response.push`` message it sends.
SERVER_PUSH_HEADERS_TO_COPY = {
    "accept",
    "accept-encoding",
    "accept-language",
    "cache-control",
    "user-agent",
}

44 

45 

def cookie_parser(cookie_string: str) -> dict[str, str]:
    """
    Parse the value of a ``Cookie`` HTTP header into a ``{name: value}`` dict.

    The goal is to behave like browsers do rather than follow RFC 6265 to the
    letter, because clients and servers routinely ignore the spec when setting
    and reading cookies.

    Adapted from Django 3.1.0. ``SimpleCookie.load`` is deliberately NOT used:
    it implements an outdated spec and rejects much of the input we want to
    accept.

    Rules applied to each ``;``-separated fragment:
      * text before the first ``=`` is the name, the rest is the value;
      * a fragment without ``=`` is stored as the value of an empty name;
      * names and values are stripped of surrounding whitespace;
      * fragments whose name and value are both empty are skipped;
      * values are unquoted with the algorithm from ``http.cookies``;
      * on duplicate names, the last occurrence wins.
    """
    parsed: dict[str, str] = {}
    for fragment in cookie_string.split(";"):
        name, separator, raw_value = fragment.partition("=")
        if not separator:
            # No "=" at all: treat the whole fragment as the value of a
            # cookie with an empty name, per
            # https://bugzilla.mozilla.org/show_bug.cgi?id=169091
            name, raw_value = "", fragment
        name = name.strip()
        raw_value = raw_value.strip()
        if not (name or raw_value):
            continue
        # Unquote using Python's own cookie-value algorithm.
        parsed[name] = http_cookies._unquote(raw_value)
    return parsed

71 

72 

class ClientDisconnect(Exception):
    """Raised by `Request.stream` when an ``http.disconnect`` message is received."""

75 

76 

# Type parameter for `HTTPConnection.state`: any str-keyed Mapping or a `State`; falls back to `State` when unparameterized.
StateT = TypeVar("StateT", default=State, bound=Mapping[str, Any] | State)

78 

79 

class HTTPConnection(Mapping[str, Any], Generic[StateT]):
    """
    Shared base for incoming HTTP connections, holding the functionality that
    `Request` and `WebSocket` have in common.

    The connection is a read-only Mapping view over its ASGI ``scope`` and adds
    lazily computed, cached accessors for URL, headers, query params, cookies, etc.
    """

    def __init__(self, scope: Scope, receive: Receive | None = None) -> None:
        # `receive` is accepted but not stored at this level.
        assert scope["type"] in ("http", "websocket")
        self.scope = scope

    # Mapping protocol: every lookup is delegated to the underlying scope.

    def __getitem__(self, key: str) -> Any:
        return self.scope[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self.scope)

    def __len__(self) -> int:
        return len(self.scope)

    # Deliberately bypass `abc.Mapping.__eq__`: two connections are equal only
    # when they are the very same object (`self is other`).
    __eq__ = object.__eq__
    __hash__ = object.__hash__

    @property
    def app(self) -> Any:
        """The value stored under ``scope["app"]``."""
        return self.scope["app"]

    @property
    def url(self) -> URL:
        """The full URL, built from the scope on first access and then cached."""
        if not hasattr(self, "_url"):  # pragma: no branch
            self._url = URL(scope=self.scope)
        return self._url

    @property
    def base_url(self) -> URL:
        """
        URL of the application root (path always ends with ``/``, no query string).

        Computed once and cached.
        """
        if not hasattr(self, "_base_url"):
            # Used by `url_for`. Inside a Mount the child scope has its own
            # `root_path`, yet `url_for` must still resolve against the top-level
            # app root, so `app_root_path` takes precedence when present.
            root = self.scope.get("app_root_path", self.scope.get("root_path", ""))
            self._base_url = URL(
                scope={
                    **self.scope,
                    "path": root if root.endswith("/") else root + "/",
                    "query_string": b"",
                    "root_path": root,
                }
            )
        return self._base_url

    @property
    def headers(self) -> Headers:
        """The connection headers, wrapped in `Headers` once and cached."""
        if not hasattr(self, "_headers"):
            self._headers = Headers(scope=self.scope)
        return self._headers

    @property
    def query_params(self) -> QueryParams:
        """Query parameters parsed from ``scope["query_string"]`` and cached."""
        if not hasattr(self, "_query_params"):  # pragma: no branch
            self._query_params = QueryParams(self.scope["query_string"])
        return self._query_params

    @property
    def path_params(self) -> dict[str, Any]:
        """``scope["path_params"]``, or an empty dict when the key is absent."""
        return self.scope.get("path_params", {})

    @property
    def cookies(self) -> dict[str, str]:
        """
        Cookies from every ``cookie`` header, merged into one dict and cached.

        When the same name appears more than once, the last value seen wins.
        """
        if not hasattr(self, "_cookies"):
            self._cookies = {
                name: value
                for raw_header in self.headers.getlist("cookie")
                for name, value in cookie_parser(raw_header).items()
            }
        return self._cookies

    @property
    def client(self) -> Address | None:
        """The ``(host, port)`` pair from ``scope["client"]`` as an `Address`, or None."""
        host_port = self.scope.get("client")
        return None if host_port is None else Address(*host_port)

    @property
    def session(self) -> dict[str, Any]:
        """
        The session placed in the scope by SessionMiddleware.

        Calls ``mark_accessed()`` on it when that method exists.
        """
        assert "session" in self.scope, "SessionMiddleware must be installed to access request.session"
        current_session: Session = self.scope["session"]
        # The hasattr guard stays so user-supplied SessionMiddleware
        # implementations without `mark_accessed` keep working.
        if hasattr(current_session, "mark_accessed"):  # pragma: no branch
            current_session.mark_accessed()
        return current_session

    @property
    def auth(self) -> Any:
        """``scope["auth"]``; asserts that AuthenticationMiddleware populated it."""
        assert "auth" in self.scope, "AuthenticationMiddleware must be installed to access request.auth"
        return self.scope["auth"]

    @property
    def user(self) -> Any:
        """``scope["user"]``; asserts that AuthenticationMiddleware populated it."""
        assert "user" in self.scope, "AuthenticationMiddleware must be installed to access request.user"
        return self.scope["user"]

    @property
    def state(self) -> StateT:
        """A `State` wrapper around ``scope["state"]``, created once and cached."""
        if not hasattr(self, "_state"):
            # `setdefault` guarantees the scope holds a dict under "state" (empty if
            # not populated yet) and returns that same dict, which the State
            # instance references to store its info.
            self._state = State(self.scope.setdefault("state", {}))
        return cast(StateT, self._state)

    def url_for(self, name: str, /, **path_params: Any) -> URL:
        """
        Build an absolute URL for the route called `name`.

        The path is resolved through ``scope["router"]`` (falling back to
        ``scope["app"]``) and made absolute against `base_url`.

        Raises:
            RuntimeError: if the scope carries neither a router nor an app.
        """
        provider: Router | Starlette | None = self.scope.get("router") or self.scope.get("app")
        if provider is None:
            raise RuntimeError("The `url_for` method can only be used inside a Starlette application or with a router.")
        return provider.url_path_for(name, **path_params).make_absolute_url(base_url=self.base_url)

203 

204 

async def empty_receive() -> NoReturn:
    """
    Stand-in ``receive`` callable for requests built without a receive channel.

    Raises:
        RuntimeError: on every call.
    """
    reason = "Receive channel has not been made available"
    raise RuntimeError(reason)

207 

208 

async def empty_send(message: Message) -> NoReturn:
    """
    Stand-in ``send`` callable for requests built without a send channel.

    `message` is ignored.

    Raises:
        RuntimeError: on every call.
    """
    reason = "Send channel has not been made available"
    raise RuntimeError(reason)

211 

212 

class Request(HTTPConnection[StateT]):
    """
    An incoming HTTP request.

    Adds the ASGI ``receive``/``send`` callables to `HTTPConnection`, plus helpers
    to read the body as a stream, as bytes, as JSON, or as form data.
    """

    _form: FormData | None

    def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send):
        super().__init__(scope)
        assert scope["type"] == "http"
        self._receive = receive
        self._send = send
        # Set once the final body chunk (no `more_body`) has been received.
        self._stream_consumed = False
        # Set once an `http.disconnect` message has been observed.
        self._is_disconnected = False
        # Parsed form data, populated lazily by `_get_form()`.
        self._form = None

    @property
    def method(self) -> str:
        """The value of ``scope["method"]``."""
        return cast(str, self.scope["method"])

    @property
    def receive(self) -> Receive:
        """The ASGI receive callable this request was built with."""
        return self._receive

    async def stream(self) -> AsyncGenerator[bytes, None]:
        """
        Yield the request body chunk by chunk, finishing with an empty ``b""``.

        If `body()` already cached the full body, that body is replayed instead.

        Raises:
            RuntimeError: if the stream was already consumed and nothing is cached.
            ClientDisconnect: if an ``http.disconnect`` message arrives.
        """
        if hasattr(self, "_body"):
            yield self._body
            yield b""
            return
        if self._stream_consumed:
            raise RuntimeError("Stream consumed")
        while not self._stream_consumed:
            message = await self._receive()
            kind = message["type"]
            if kind == "http.request":
                chunk = message.get("body", b"")
                # A message without `more_body` (or with it false) is the last one.
                if not message.get("more_body", False):
                    self._stream_consumed = True
                # Empty chunks are not forwarded to the consumer.
                if chunk:
                    yield chunk
            elif kind == "http.disconnect":  # pragma: no branch
                self._is_disconnected = True
                raise ClientDisconnect()
        yield b""

    async def body(self) -> bytes:
        """Return the complete request body, reading it on the first call only."""
        if not hasattr(self, "_body"):
            self._body = b"".join([piece async for piece in self.stream()])
        return self._body

    async def json(self) -> Any:
        """Decode the request body with `json.loads`, caching the result."""
        if not hasattr(self, "_json"):  # pragma: no branch
            raw = await self.body()
            self._json = json.loads(raw)
        return self._json

    async def _get_form(
        self,
        *,
        max_files: int | float = 1000,
        max_fields: int | float = 1000,
        max_part_size: int = 1024 * 1024,
    ) -> FormData:
        """
        Parse the body into `FormData` according to the Content-Type header.

        * ``multipart/form-data`` is parsed with `MultiPartParser`, honouring
          the three limit arguments.
        * ``application/x-www-form-urlencoded`` is parsed with `FormParser`.
        * Any other content type yields an empty `FormData`.

        The result is cached in ``self._form``.

        Raises:
            HTTPException: status 400 on a multipart parse error when the scope
                has an ``app``.
            MultiPartException: on a multipart parse error otherwise.
        """
        if self._form is not None:  # pragma: no branch
            return self._form

        assert parse_options_header is not None, (
            "The `python-multipart` library must be installed to use form parsing."
        )
        content_type: bytes
        content_type, _ = parse_options_header(self.headers.get("Content-Type"))

        if content_type == b"multipart/form-data":
            try:
                parser = MultiPartParser(
                    self.headers,
                    self.stream(),
                    max_files=max_files,
                    max_fields=max_fields,
                    max_part_size=max_part_size,
                )
                self._form = await parser.parse()
            except MultiPartException as exc:
                # Inside an application, convert the parser error into an
                # HTTPException with status 400; otherwise propagate it unchanged.
                if "app" in self.scope:
                    raise HTTPException(status_code=400, detail=exc.message)
                raise exc
        elif content_type == b"application/x-www-form-urlencoded":
            self._form = await FormParser(self.headers, self.stream()).parse()
        else:
            self._form = FormData()
        return self._form

    def form(
        self,
        *,
        max_files: int | float = 1000,
        max_fields: int | float = 1000,
        max_part_size: int = 1024 * 1024,
    ) -> AwaitableOrContextManager[FormData]:
        """
        Return the parsed form data, wrapped by `AwaitableOrContextManagerWrapper`.

        The limits are forwarded to `_get_form()`. Judging by the wrapper's name,
        the result can be awaited or used as a context manager (confirm against
        `starlette._utils`).
        """
        pending = self._get_form(max_files=max_files, max_fields=max_fields, max_part_size=max_part_size)
        return AwaitableOrContextManagerWrapper(pending)

    async def close(self) -> None:
        """Close the parsed form data, if a form was ever parsed."""
        if self._form is None:  # pragma: no branch
            return
        await self._form.close()

    async def is_disconnected(self) -> bool:
        """
        Report whether the client has disconnected.

        When no disconnect has been recorded yet, this polls `receive()` once
        without waiting. If that poll returns ``http.disconnect``, the disconnect
        is remembered.
        """
        if self._is_disconnected:
            return True

        message: Message = {}
        # The scope is cancelled up front, so if a message isn't immediately
        # available the await is abandoned and we move on.
        with anyio.CancelScope() as cancel_scope:
            cancel_scope.cancel()
            message = await self._receive()

        if message.get("type") == "http.disconnect":
            self._is_disconnected = True
        return self._is_disconnected

    async def send_push_promise(self, path: str) -> None:
        """
        Send an ``http.response.push`` message for `path`.

        This only happens when ``scope["extensions"]`` contains
        ``http.response.push``; otherwise it is a no-op. Request headers named in
        ``SERVER_PUSH_HEADERS_TO_COPY`` are copied into the message, latin-1
        encoded.
        """
        if "http.response.push" not in self.scope.get("extensions", {}):
            return
        raw_headers: list[tuple[bytes, bytes]] = [
            (name.encode("latin-1"), value.encode("latin-1"))
            for name in SERVER_PUSH_HEADERS_TO_COPY
            for value in self.headers.getlist(name)
        ]
        await self._send({"type": "http.response.push", "path": path, "headers": raw_headers})