Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/starlette/requests.py: 34%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

206 statements  

1from __future__ import annotations 

2 

3import json 

4from collections.abc import AsyncGenerator, Iterator, Mapping 

5from http import cookies as http_cookies 

6from typing import TYPE_CHECKING, Any, NoReturn, cast 

7 

8import anyio 

9 

10from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper 

11from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State 

12from starlette.exceptions import HTTPException 

13from starlette.formparsers import FormParser, MultiPartException, MultiPartParser 

14from starlette.types import Message, Receive, Scope, Send 

15 

16if TYPE_CHECKING: 

17 from python_multipart.multipart import parse_options_header 

18 

19 from starlette.applications import Starlette 

20 from starlette.routing import Router 

21else: 

22 try: 

23 try: 

24 from python_multipart.multipart import parse_options_header 

25 except ModuleNotFoundError: # pragma: no cover 

26 from multipart.multipart import parse_options_header 

27 except ModuleNotFoundError: # pragma: no cover 

28 parse_options_header = None 

29 

30 

# Names of the request headers that `Request.send_push_promise` copies from
# the incoming request into the "http.response.push" message it sends.
SERVER_PUSH_HEADERS_TO_COPY = {"accept", "accept-encoding", "accept-language", "cache-control", "user-agent"}

38 

39 

def cookie_parser(cookie_string: str) -> dict[str, str]:
    """
    Parse the value of a ``Cookie`` HTTP header into a dict of name/value pairs.

    The goal is to behave like browsers do: browsers and web servers often
    ignore the spec (RFC 6265) when writing and reading cookies, so this
    function targets the common real-world cases instead.

    Adapted from Django 3.1.0.
    Note: `SimpleCookie.load` is deliberately _NOT_ used because it follows an
    outdated spec and rejects a lot of input that should be accepted.

    The header is split on ``;``. Each piece is split on its first ``=``, and
    both sides are stripped of surrounding whitespace. Pieces whose name and
    value are both empty are dropped. When a name repeats, the last
    occurrence wins.
    """
    parsed: dict[str, str] = {}
    for piece in cookie_string.split(";"):
        name, separator, value = piece.partition("=")
        if not separator:
            # No "=" present: assume an empty name, per
            # https://bugzilla.mozilla.org/show_bug.cgi?id=169091
            name, value = "", piece
        name = name.strip()
        value = value.strip()
        if not (name or value):
            continue
        # Unquote using Python's own cookie algorithm.
        parsed[name] = http_cookies._unquote(value)
    return parsed

65 

66 

class ClientDisconnect(Exception):
    """Raised by `Request.stream()` when an "http.disconnect" message is received."""

69 

70 

class HTTPConnection(Mapping[str, Any]):
    """
    Base class for incoming HTTP connections.

    It carries the functionality shared by `Request` and `WebSocket`, and
    exposes the connection scope through the read-only mapping protocol
    (``conn[key]``, ``iter(conn)``, ``len(conn)``).
    """

    def __init__(self, scope: Scope, receive: Receive | None = None) -> None:
        # `receive` is accepted here but is not used by this base class.
        assert scope["type"] in ("http", "websocket")
        self.scope = scope

    def __getitem__(self, key: str) -> Any:
        return self.scope[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self.scope)

    def __len__(self) -> int:
        return len(self.scope)

    # Bypass the `abc.Mapping.__eq__` implementation: two connection
    # instances must only compare equal when they are the same object.
    __eq__ = object.__eq__
    __hash__ = object.__hash__

    @property
    def app(self) -> Any:
        """The value stored under ``"app"`` in the scope."""
        return self.scope["app"]

    @property
    def url(self) -> URL:
        """The `URL` built from the scope; created on first access, then cached."""
        if hasattr(self, "_url"):  # pragma: no branch
            return self._url
        self._url = URL(scope=self.scope)
        return self._url

    @property
    def base_url(self) -> URL:
        """The `URL` of the application root; created on first access, then cached."""
        if hasattr(self, "_base_url"):
            return self._base_url
        scope_copy = dict(self.scope)
        # `request.url_for` relies on this value. It may be called inside a
        # Mount, whose child scope has its own root_path, yet the base URL
        # for url_for must still be the top level app root path.
        root = scope_copy.get("app_root_path", scope_copy.get("root_path", ""))
        scope_copy["path"] = root if root.endswith("/") else root + "/"
        scope_copy["query_string"] = b""
        scope_copy["root_path"] = root
        self._base_url = URL(scope=scope_copy)
        return self._base_url

    @property
    def headers(self) -> Headers:
        """The `Headers` built from the scope; created on first access, then cached."""
        if hasattr(self, "_headers"):
            return self._headers
        self._headers = Headers(scope=self.scope)
        return self._headers

    @property
    def query_params(self) -> QueryParams:
        """The `QueryParams` parsed from ``scope["query_string"]``; cached after first access."""
        if hasattr(self, "_query_params"):  # pragma: no branch
            return self._query_params
        self._query_params = QueryParams(self.scope["query_string"])
        return self._query_params

    @property
    def path_params(self) -> dict[str, Any]:
        """The ``"path_params"`` entry of the scope, or an empty dict when absent."""
        return self.scope.get("path_params", {})

    @property
    def cookies(self) -> dict[str, str]:
        """Cookies merged from every ``cookie`` header; parsed once, then cached."""
        if hasattr(self, "_cookies"):
            return self._cookies
        jar: dict[str, str] = {}
        for raw_header in self.headers.getlist("cookie"):
            jar.update(cookie_parser(raw_header))
        self._cookies = jar
        return self._cookies

    @property
    def client(self) -> Address | None:
        """The ``"client"`` scope entry as an `Address`, or None if it is missing."""
        # client is a 2 item tuple of (host, port), None if missing
        address = self.scope.get("client")
        return None if address is None else Address(*address)

    @property
    def session(self) -> dict[str, Any]:
        """The ``"session"`` scope entry; asserts that it is present."""
        assert "session" in self.scope, "SessionMiddleware must be installed to access request.session"
        return self.scope["session"]  # type: ignore[no-any-return]

    @property
    def auth(self) -> Any:
        """The ``"auth"`` scope entry; asserts that it is present."""
        assert "auth" in self.scope, "AuthenticationMiddleware must be installed to access request.auth"
        return self.scope["auth"]

    @property
    def user(self) -> Any:
        """The ``"user"`` scope entry; asserts that it is present."""
        assert "user" in self.scope, "AuthenticationMiddleware must be installed to access request.user"
        return self.scope["user"]

    @property
    def state(self) -> State:
        """A `State` wrapper around ``scope["state"]``; created on first access, then cached."""
        if hasattr(self, "_state"):
            return self._state
        # Make sure the scope has a 'state' dict, then hand that very dict to
        # the State instance so it stores its data there.
        self._state = State(self.scope.setdefault("state", {}))
        return self._state

    def url_for(self, name: str, /, **path_params: Any) -> URL:
        """
        Build the absolute URL for the route called `name`.

        The scope's ``"router"`` entry is used when it is truthy, otherwise
        its ``"app"`` entry. Raises `RuntimeError` when neither is available.
        """
        provider: Router | Starlette | None = self.scope.get("router") or self.scope.get("app")
        if provider is None:
            raise RuntimeError("The `url_for` method can only be used inside a Starlette application or with a router.")
        return provider.url_path_for(name, **path_params).make_absolute_url(base_url=self.base_url)

190 

191 

async def empty_receive() -> NoReturn:
    """Default `receive` callable for `Request`; always raises `RuntimeError`."""
    raise RuntimeError("Receive channel has not been made available")

194 

195 

async def empty_send(message: Message) -> NoReturn:
    """Default `send` callable for `Request`; always raises `RuntimeError`."""
    raise RuntimeError("Send channel has not been made available")

198 

199 

class Request(HTTPConnection):
    """
    An incoming HTTP request.

    On top of `HTTPConnection` it adds reading of the request body (as a
    stream, as bytes, as JSON or as form data), disconnect detection and
    sending of "http.response.push" messages.
    """

    _form: FormData | None

    def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send):
        super().__init__(scope)
        assert scope["type"] == "http"
        self._receive = receive
        self._send = send
        self._stream_consumed = False
        self._is_disconnected = False
        self._form = None

    @property
    def method(self) -> str:
        """The ``"method"`` entry of the scope."""
        return cast(str, self.scope["method"])

    @property
    def receive(self) -> Receive:
        """The receive callable this request was constructed with."""
        return self._receive

    async def stream(self) -> AsyncGenerator[bytes, None]:
        """
        Yield the request body chunk by chunk, ending with an empty ``b""``.

        If the body was already read by `body()`, the cached bytes are yielded
        instead. Raises `RuntimeError` when the stream has already been
        consumed, and `ClientDisconnect` when an "http.disconnect" message is
        received.
        """
        if hasattr(self, "_body"):
            yield self._body
            yield b""
            return
        if self._stream_consumed:
            raise RuntimeError("Stream consumed")
        while not self._stream_consumed:
            message = await self._receive()
            kind = message["type"]
            if kind == "http.request":
                # The last message of the body is the one without `more_body`.
                if not message.get("more_body", False):
                    self._stream_consumed = True
                chunk = message.get("body", b"")
                if chunk:
                    yield chunk
            elif kind == "http.disconnect":  # pragma: no branch
                self._is_disconnected = True
                raise ClientDisconnect()
        yield b""

    async def body(self) -> bytes:
        """Return the whole request body as bytes; read once, then cached."""
        if not hasattr(self, "_body"):
            self._body = b"".join([chunk async for chunk in self.stream()])
        return self._body

    async def json(self) -> Any:
        """Return the request body decoded with `json.loads`; decoded once, then cached."""
        if hasattr(self, "_json"):  # pragma: no branch
            return self._json
        raw = await self.body()
        self._json = json.loads(raw)
        return self._json

    async def _get_form(
        self,
        *,
        max_files: int | float = 1000,
        max_fields: int | float = 1000,
        max_part_size: int = 1024 * 1024,
    ) -> FormData:
        """
        Parse the request body as form data; parsed once, then cached.

        The parser is chosen from the ``Content-Type`` header: multipart,
        urlencoded, or an empty `FormData` for anything else. A
        `MultiPartException` becomes an `HTTPException` with status 400 when
        the scope has an ``"app"`` entry; otherwise it is re-raised as is.
        """
        if self._form is not None:  # pragma: no branch
            return self._form
        assert parse_options_header is not None, (
            "The `python-multipart` library must be installed to use form parsing."
        )
        content_type: bytes
        content_type, _ = parse_options_header(self.headers.get("Content-Type"))
        if content_type == b"multipart/form-data":
            try:
                parser = MultiPartParser(
                    self.headers,
                    self.stream(),
                    max_files=max_files,
                    max_fields=max_fields,
                    max_part_size=max_part_size,
                )
                self._form = await parser.parse()
            except MultiPartException as exc:
                if "app" in self.scope:
                    raise HTTPException(status_code=400, detail=exc.message)
                raise exc
        elif content_type == b"application/x-www-form-urlencoded":
            self._form = await FormParser(self.headers, self.stream()).parse()
        else:
            self._form = FormData()
        return self._form

    def form(
        self,
        *,
        max_files: int | float = 1000,
        max_fields: int | float = 1000,
        max_part_size: int = 1024 * 1024,
    ) -> AwaitableOrContextManager[FormData]:
        """Return the form-parsing coroutine wrapped in `AwaitableOrContextManagerWrapper`."""
        pending = self._get_form(max_files=max_files, max_fields=max_fields, max_part_size=max_part_size)
        return AwaitableOrContextManagerWrapper(pending)

    async def close(self) -> None:
        """Close the parsed form data, if any has been parsed."""
        if self._form is None:  # pragma: no branch
            return
        await self._form.close()

    async def is_disconnected(self) -> bool:
        """Report whether an "http.disconnect" message has been seen for this request."""
        if not self._is_disconnected:
            received: Message = {}

            # If message isn't immediately available, move on
            with anyio.CancelScope() as cancel_scope:
                cancel_scope.cancel()
                received = await self._receive()

            if received.get("type") == "http.disconnect":
                self._is_disconnected = True

        return self._is_disconnected

    async def send_push_promise(self, path: str) -> None:
        """
        Send an "http.response.push" message for `path`.

        Nothing is sent unless the scope's ``"extensions"`` entry contains
        "http.response.push". The headers named in
        `SERVER_PUSH_HEADERS_TO_COPY` are copied from this request.
        """
        if "http.response.push" not in self.scope.get("extensions", {}):
            return
        raw_headers: list[tuple[bytes, bytes]] = [
            (name.encode("latin-1"), value.encode("latin-1"))
            for name in SERVER_PUSH_HEADERS_TO_COPY
            for value in self.headers.getlist(name)
        ]
        await self._send({"type": "http.response.push", "path": path, "headers": raw_headers})