Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/starlette/requests.py: 35%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

208 statements  

1from __future__ import annotations 

2 

3import json 

4import sys 

5from collections.abc import AsyncGenerator, Iterator, Mapping 

6from http import cookies as http_cookies 

7from typing import TYPE_CHECKING, Any, Generic, NoReturn, cast 

8 

9import anyio 

10 

11from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper 

12from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State 

13from starlette.exceptions import HTTPException 

14from starlette.formparsers import FormParser, MultiPartException, MultiPartParser 

15from starlette.types import Message, Receive, Scope, Send 

16 

17if TYPE_CHECKING: 

18 from python_multipart.multipart import parse_options_header 

19 

20 from starlette.applications import Starlette 

21 from starlette.routing import Router 

22else: 

23 try: 

24 try: 

25 from python_multipart.multipart import parse_options_header 

26 except ModuleNotFoundError: # pragma: no cover 

27 from multipart.multipart import parse_options_header 

28 except ModuleNotFoundError: # pragma: no cover 

29 parse_options_header = None 

30 

31if sys.version_info >= (3, 13): # pragma: no cover 

32 from typing import TypeVar 

33else: # pragma: no cover 

34 from typing_extensions import TypeVar 

35 

# Names of incoming request headers whose values are copied onto a server-push
# promise (consumed by ``Request.send_push_promise``).
SERVER_PUSH_HEADERS_TO_COPY = {
    "accept", "accept-encoding", "accept-language",
    "cache-control", "user-agent",
}

43 

44 

def cookie_parser(cookie_string: str) -> dict[str, str]:
    """
    Parse a ``Cookie`` HTTP header into a ``{name: value}`` dict.

    The parsing deliberately mimics what browsers do rather than following
    RFC 6265 strictly, because browsers and servers commonly ignore the spec
    when setting and reading cookies.

    Adapted from Django 3.1.0. ``SimpleCookie.load`` is intentionally not used:
    it implements an outdated spec and rejects a lot of input we want to accept.

    Pieces are separated by ``;``; a piece without ``=`` is stored under the
    empty name. Pieces whose name and value are both blank are skipped, and a
    later duplicate name overwrites an earlier one.
    """
    parsed: dict[str, str] = {}
    for piece in cookie_string.split(";"):
        name, separator, value = piece.partition("=")
        if not separator:
            # No "=": treat the whole piece as a value with an empty name, per
            # https://bugzilla.mozilla.org/show_bug.cgi?id=169091
            name, value = "", piece
        name = name.strip()
        value = value.strip()
        if not (name or value):
            continue
        # Unquote with the standard library's own algorithm.
        parsed[name] = http_cookies._unquote(value)
    return parsed

70 

71 

class ClientDisconnect(Exception):
    """Raised by ``Request.stream()`` when an ``http.disconnect`` message arrives while reading the body."""

74 

75 

# Type of ``HTTPConnection.state``: any str-keyed mapping or a ``State``.
# When a connection class is used unparametrized, it defaults to ``State``.
StateT = TypeVar(
    "StateT",
    default=State,
    bound=Mapping[str, Any] | State,
)

77 

78 

class HTTPConnection(Mapping[str, Any], Generic[StateT]):
    """
    Common base for incoming HTTP and WebSocket connections.

    Wraps an ASGI ``scope`` and behaves as a read-only mapping over it. It also
    offers convenience properties shared by ``Request`` and ``WebSocket``: URL,
    headers, query and path parameters, cookies, client address, state, etc.
    Several derived values are built on first access and cached on the instance.
    """

    def __init__(self, scope: Scope, receive: Receive | None = None) -> None:
        # ``receive`` is accepted in the signature but is not stored by this base class.
        assert scope["type"] in ("http", "websocket")
        self.scope = scope

    # Mapping protocol: every lookup is delegated to the underlying scope.

    def __getitem__(self, key: str) -> Any:
        return self.scope[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self.scope)

    def __len__(self) -> int:
        return len(self.scope)

    # ``abc.Mapping.__eq__`` compares contents. Connections should only ever
    # equal themselves, so use object identity for equality and hashing.
    __eq__ = object.__eq__
    __hash__ = object.__hash__

    @property
    def app(self) -> Any:
        """The object stored under ``scope["app"]``."""
        return self.scope["app"]

    @property
    def url(self) -> URL:
        """The full URL built from the scope; computed once and cached."""
        try:
            return self._url
        except AttributeError:
            self._url = URL(scope=self.scope)
            return self._url

    @property
    def base_url(self) -> URL:
        """
        The application's base URL: root path ending in ``/``, empty query string.

        This is what ``url_for`` resolves against. Inside a ``Mount``, the child
        scope carries its own ``root_path``, so ``app_root_path`` is preferred to
        keep the base URL at the top-level application root.
        """
        try:
            return self._base_url
        except AttributeError:
            pass
        scope_copy = dict(self.scope)
        root = scope_copy.get("app_root_path", scope_copy.get("root_path", ""))
        scope_copy["path"] = root if root.endswith("/") else root + "/"
        scope_copy["query_string"] = b""
        scope_copy["root_path"] = root
        self._base_url = URL(scope=scope_copy)
        return self._base_url

    @property
    def headers(self) -> Headers:
        """Headers parsed from the scope; computed once and cached."""
        try:
            return self._headers
        except AttributeError:
            self._headers = Headers(scope=self.scope)
            return self._headers

    @property
    def query_params(self) -> QueryParams:
        """Query parameters parsed from ``scope["query_string"]``; cached."""
        try:
            return self._query_params
        except AttributeError:
            self._query_params = QueryParams(self.scope["query_string"])
            return self._query_params

    @property
    def path_params(self) -> dict[str, Any]:
        """``scope["path_params"]``, or a fresh empty dict when absent."""
        return self.scope.get("path_params", {})

    @property
    def cookies(self) -> dict[str, str]:
        """
        Cookies from every ``cookie`` header, merged into one dict; cached.

        When several headers define the same name, the later header wins.
        """
        try:
            return self._cookies
        except AttributeError:
            pass
        merged: dict[str, str] = {}
        for raw_header in self.headers.getlist("cookie"):
            merged.update(cookie_parser(raw_header))
        self._cookies = merged
        return self._cookies

    @property
    def client(self) -> Address | None:
        """The ``(host, port)`` pair from ``scope["client"]`` as an ``Address``, or ``None`` if missing."""
        host_port = self.scope.get("client")
        return None if host_port is None else Address(*host_port)

    @property
    def session(self) -> dict[str, Any]:
        """``scope["session"]``; asserts that ``SessionMiddleware`` populated it."""
        scope = self.scope
        assert "session" in scope, "SessionMiddleware must be installed to access request.session"
        return scope["session"]  # type: ignore[no-any-return]

    @property
    def auth(self) -> Any:
        """``scope["auth"]``; asserts that ``AuthenticationMiddleware`` populated it."""
        scope = self.scope
        assert "auth" in scope, "AuthenticationMiddleware must be installed to access request.auth"
        return scope["auth"]

    @property
    def user(self) -> Any:
        """``scope["user"]``; asserts that ``AuthenticationMiddleware`` populated it."""
        scope = self.scope
        assert "user" in scope, "AuthenticationMiddleware must be installed to access request.user"
        return scope["user"]

    @property
    def state(self) -> StateT:
        """
        Per-connection state wrapper around ``scope["state"]``.

        An empty dict is placed in the scope if none exists. The ``State`` object
        is constructed over that same dict, so it stores its data there.
        """
        if not hasattr(self, "_state"):
            self.scope.setdefault("state", {})
            self._state = State(self.scope["state"])
        return cast(StateT, self._state)

    def url_for(self, name: str, /, **path_params: Any) -> URL:
        """
        Build an absolute URL for the route called ``name``.

        The route is looked up on ``scope["router"]``, falling back to
        ``scope["app"]``. The result is made absolute against ``base_url``.

        Raises:
            RuntimeError: if the scope has neither a router nor an app.
        """
        provider: Router | Starlette | None = self.scope.get("router") or self.scope.get("app")
        if provider is None:
            raise RuntimeError("The `url_for` method can only be used inside a Starlette application or with a router.")
        return provider.url_path_for(name, **path_params).make_absolute_url(base_url=self.base_url)

198 

199 

async def empty_receive() -> NoReturn:
    """
    Default ``receive`` callable for a ``Request`` built without one.

    Raises:
        RuntimeError: always, because no receive channel was supplied.
    """
    raise RuntimeError("Receive channel has not been made available")

202 

203 

async def empty_send(message: Message) -> NoReturn:
    """
    Default ``send`` callable for a ``Request`` built without one.

    Raises:
        RuntimeError: always; ``message`` is ignored.
    """
    raise RuntimeError("Send channel has not been made available")

206 

207 

class Request(HTTPConnection[StateT]):
    """
    An incoming HTTP request.

    Adds body access to ``HTTPConnection``: raw chunk stream, full bytes, JSON,
    and parsed form data. It also provides disconnect polling and server-push
    promises. The body is read through the ASGI ``receive`` callable and can
    only be streamed once, unless ``body()`` has cached it.
    """

    _form: FormData | None

    def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send):
        super().__init__(scope)
        assert scope["type"] == "http"
        self._receive = receive
        self._send = send
        # Body-stream bookkeeping, updated by ``stream`` and ``is_disconnected``.
        self._stream_consumed = False
        self._is_disconnected = False
        # Parsed form data, filled lazily by ``_get_form``.
        self._form = None

    @property
    def method(self) -> str:
        """The request method from ``scope["method"]``."""
        raw_method = self.scope["method"]
        return cast(str, raw_method)

    @property
    def receive(self) -> Receive:
        """The ASGI ``receive`` callable this request reads from."""
        return self._receive

    async def stream(self) -> AsyncGenerator[bytes, None]:
        """
        Yield the request body chunk by chunk, finishing with ``b""``.

        If ``body()`` already cached the full body, that value is replayed
        instead of reading from ``receive``. Empty ``http.request`` chunks are
        not yielded.

        Raises:
            RuntimeError: if the stream was already consumed and nothing is cached.
            ClientDisconnect: if an ``http.disconnect`` message arrives.
        """
        if hasattr(self, "_body"):
            yield self._body
            yield b""
            return
        if self._stream_consumed:
            raise RuntimeError("Stream consumed")
        while not self._stream_consumed:
            incoming = await self._receive()
            incoming_type = incoming["type"]
            if incoming_type == "http.request":
                chunk = incoming.get("body", b"")
                # The final chunk is the one without ``more_body`` set.
                if not incoming.get("more_body", False):
                    self._stream_consumed = True
                if chunk:
                    yield chunk
            elif incoming_type == "http.disconnect":  # pragma: no branch
                self._is_disconnected = True
                raise ClientDisconnect()
        yield b""

    async def body(self) -> bytes:
        """Read the whole body via ``stream()`` and cache it as ``bytes``."""
        if not hasattr(self, "_body"):
            self._body = b"".join([chunk async for chunk in self.stream()])
        return self._body

    async def json(self) -> Any:
        """Decode the body with ``json.loads`` and cache the result."""
        if not hasattr(self, "_json"):  # pragma: no branch
            self._json = json.loads(await self.body())
        return self._json

    async def _get_form(
        self,
        *,
        max_files: int | float = 1000,
        max_fields: int | float = 1000,
        max_part_size: int = 1024 * 1024,
    ) -> FormData:
        """
        Parse the body as form data according to ``Content-Type``, and cache it.

        ``multipart/form-data`` is handled by ``MultiPartParser``, using the given
        limits. ``application/x-www-form-urlencoded`` is handled by ``FormParser``.
        Any other content type yields an empty ``FormData``.

        Raises:
            AssertionError: if no ``parse_options_header`` implementation is available.
            HTTPException: status 400 for a multipart parse error when the scope has an ``app``.
            MultiPartException: re-raised for a multipart parse error otherwise.
        """
        if self._form is not None:
            return self._form
        assert parse_options_header is not None, (
            "The `python-multipart` library must be installed to use form parsing."
        )
        media_type: bytes
        media_type, _ = parse_options_header(self.headers.get("Content-Type"))
        if media_type == b"multipart/form-data":
            try:
                parser = MultiPartParser(
                    self.headers,
                    self.stream(),
                    max_files=max_files,
                    max_fields=max_fields,
                    max_part_size=max_part_size,
                )
                self._form = await parser.parse()
            except MultiPartException as exc:
                if "app" in self.scope:
                    raise HTTPException(status_code=400, detail=exc.message)
                raise exc
        elif media_type == b"application/x-www-form-urlencoded":
            self._form = await FormParser(self.headers, self.stream()).parse()
        else:
            self._form = FormData()
        return self._form

    def form(
        self,
        *,
        max_files: int | float = 1000,
        max_fields: int | float = 1000,
        max_part_size: int = 1024 * 1024,
    ) -> AwaitableOrContextManager[FormData]:
        """Return the ``_get_form`` coroutine wrapped in ``AwaitableOrContextManagerWrapper``."""
        pending_form = self._get_form(max_files=max_files, max_fields=max_fields, max_part_size=max_part_size)
        return AwaitableOrContextManagerWrapper(pending_form)

    async def close(self) -> None:
        """Await ``close()`` on the parsed form data, if any has been parsed."""
        parsed_form = self._form
        if parsed_form is not None:  # pragma: no branch
            await parsed_form.close()

    async def is_disconnected(self) -> bool:
        """
        Report whether the client has disconnected.

        Once a disconnect has been seen, the cached flag is returned. Otherwise
        ``receive`` is polled once inside an already-cancelled cancel scope. A
        message that is not available immediately is skipped, and the flag is
        set only if the polled message is ``http.disconnect``.
        """
        if self._is_disconnected:
            return self._is_disconnected
        message: Message = {}
        # If a message isn't immediately available, move on.
        with anyio.CancelScope() as cancel_scope:
            cancel_scope.cancel()
            message = await self._receive()
        if message.get("type") == "http.disconnect":
            self._is_disconnected = True
        return self._is_disconnected

    async def send_push_promise(self, path: str) -> None:
        """
        Send an ``http.response.push`` message for ``path``.

        This is a no-op unless the scope advertises the ``http.response.push``
        extension. Values of the headers in ``SERVER_PUSH_HEADERS_TO_COPY`` are
        copied from this request, latin-1 encoded.
        """
        if "http.response.push" not in self.scope.get("extensions", {}):
            return
        raw_headers: list[tuple[bytes, bytes]] = [
            (header_name.encode("latin-1"), header_value.encode("latin-1"))
            for header_name in SERVER_PUSH_HEADERS_TO_COPY
            for header_value in self.headers.getlist(header_name)
        ]
        await self._send({"type": "http.response.push", "path": path, "headers": raw_headers})