Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/starlette/requests.py: 34%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

206 statements  

1from __future__ import annotations 

2 

3import json 

4from collections.abc import AsyncGenerator, Iterator, Mapping 

5from http import cookies as http_cookies 

6from typing import TYPE_CHECKING, Any, NoReturn, cast 

7 

8import anyio 

9 

10from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper 

11from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State 

12from starlette.exceptions import HTTPException 

13from starlette.formparsers import FormParser, MultiPartException, MultiPartParser 

14from starlette.types import Message, Receive, Scope, Send 

15 

16if TYPE_CHECKING: 

17 from python_multipart.multipart import parse_options_header 

18 

19 from starlette.applications import Starlette 

20 from starlette.routing import Router 

21else: 

22 try: 

23 try: 

24 from python_multipart.multipart import parse_options_header 

25 except ModuleNotFoundError: # pragma: no cover 

26 from multipart.multipart import parse_options_header 

27 except ModuleNotFoundError: # pragma: no cover 

28 parse_options_header = None 

29 

30 

# Names of request headers whose values are copied from the incoming request
# onto the ``http.response.push`` message sent by ``Request.send_push_promise``.
SERVER_PUSH_HEADERS_TO_COPY = {"accept", "accept-encoding", "accept-language", "cache-control", "user-agent"}

38 

39 

def cookie_parser(cookie_string: str) -> dict[str, str]:
    """
    Parse the value of a ``Cookie`` HTTP header into a ``{name: value}`` dict.

    Parsing is deliberately lenient, in the way browsers and servers tend to be
    in practice rather than strictly following RFC 6265:

    * the header is split on ``;`` and each fragment on its *first* ``=``;
    * a fragment without ``=`` is stored under the empty name ``""``
      (see https://bugzilla.mozilla.org/show_bug.cgi?id=169091);
    * names and values are whitespace-stripped, fragments where both are empty
      are skipped, and values are unquoted with Python's cookie algorithm;
    * a later duplicate name overwrites an earlier one.

    Adapted from Django 3.1.0. ``SimpleCookie.load`` is intentionally not used:
    it follows an outdated spec and rejects much real-world input.
    """
    parsed: dict[str, str] = {}
    for fragment in cookie_string.split(";"):
        name, separator, value = fragment.partition("=")
        if not separator:
            # No '=' at all: treat the whole fragment as a value with an empty name.
            name, value = "", fragment
        name = name.strip()
        value = value.strip()
        if not (name or value):
            continue
        parsed[name] = http_cookies._unquote(value)
    return parsed

65 

66 

class ClientDisconnect(Exception):
    """Raised by ``Request.stream`` when an ``http.disconnect`` message arrives while the body is still being read."""

69 

70 

class HTTPConnection(Mapping[str, Any]):
    """
    Shared base for incoming HTTP and WebSocket connections.

    A connection is a read-only ``Mapping`` view over the ASGI ``scope``.
    Derived views (URL, base URL, headers, query params, cookies, state) are
    built lazily on first access and then cached on the instance. This is the
    functionality common to both `Request` and `WebSocket`.
    """

    def __init__(self, scope: Scope, receive: Receive | None = None) -> None:
        # ``receive`` is not used by this base class.
        assert scope["type"] in ("http", "websocket")
        self.scope = scope

    # --- Mapping protocol: delegate straight to the scope ------------------

    def __getitem__(self, key: str) -> Any:
        return self.scope[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self.scope)

    def __len__(self) -> int:
        return len(self.scope)

    # ``abc.Mapping.__eq__`` would compare connections by content. Instead a
    # connection is only ever equal to itself (identity), and hashes likewise.
    __eq__ = object.__eq__
    __hash__ = object.__hash__

    # --- Scope accessors ----------------------------------------------------

    @property
    def app(self) -> Any:
        """The application object stored at ``scope["app"]``."""
        return self.scope["app"]

    @property
    def url(self) -> URL:
        """The full URL of this connection (built once, then cached)."""
        try:
            return self._url
        except AttributeError:
            pass
        self._url = URL(scope=self.scope)
        return self._url

    @property
    def base_url(self) -> URL:
        """
        The application's base URL (built once, then cached).

        Its path is the application root path with a trailing ``/``, and it
        has an empty query string.
        """
        try:
            return self._base_url
        except AttributeError:
            pass
        # Used by ``url_for``. Inside a Mount the child scope carries its own
        # ``root_path``, but generated URLs must stay rooted at the top-level
        # app, so ``app_root_path`` takes precedence over ``root_path``.
        root = self.scope.get("app_root_path", self.scope.get("root_path", ""))
        path = root if root.endswith("/") else root + "/"
        self._base_url = URL(scope={**self.scope, "path": path, "query_string": b"", "root_path": root})
        return self._base_url

    @property
    def headers(self) -> Headers:
        """The connection's headers (built once, then cached)."""
        try:
            return self._headers
        except AttributeError:
            pass
        self._headers = Headers(scope=self.scope)
        return self._headers

    @property
    def query_params(self) -> QueryParams:
        """Parsed ``scope["query_string"]`` (built once, then cached)."""
        try:
            return self._query_params
        except AttributeError:
            pass
        self._query_params = QueryParams(self.scope["query_string"])
        return self._query_params

    @property
    def path_params(self) -> dict[str, Any]:
        """``scope["path_params"]``, or a new empty dict when the key is absent."""
        return self.scope.get("path_params", {})

    @property
    def cookies(self) -> dict[str, str]:
        """Cookies parsed from the ``cookie`` header; empty if none (built once, then cached)."""
        try:
            return self._cookies
        except AttributeError:
            pass
        raw_cookie = self.headers.get("cookie")
        self._cookies = cookie_parser(raw_cookie) if raw_cookie else {}
        return self._cookies

    @property
    def client(self) -> Address | None:
        """``scope["client"]`` (a ``(host, port)`` pair) as an ``Address``, or ``None`` if missing."""
        peer = self.scope.get("client")
        return None if peer is None else Address(*peer)

    @property
    def session(self) -> dict[str, Any]:
        """The session dict stored in the scope; requires ``SessionMiddleware``."""
        assert "session" in self.scope, "SessionMiddleware must be installed to access request.session"
        return self.scope["session"]  # type: ignore[no-any-return]

    @property
    def auth(self) -> Any:
        """``scope["auth"]``; requires ``AuthenticationMiddleware``."""
        assert "auth" in self.scope, "AuthenticationMiddleware must be installed to access request.auth"
        return self.scope["auth"]

    @property
    def user(self) -> Any:
        """``scope["user"]``; requires ``AuthenticationMiddleware``."""
        assert "user" in self.scope, "AuthenticationMiddleware must be installed to access request.user"
        return self.scope["user"]

    @property
    def state(self) -> State:
        """A ``State`` wrapping ``scope["state"]``, created empty if missing (built once, then cached)."""
        try:
            return self._state
        except AttributeError:
            pass
        # Guarantee the scope holds a dict, then hand that very dict to State
        # so it is where the State stores its information.
        self.scope.setdefault("state", {})
        self._state = State(self.scope["state"])
        return self._state

    def url_for(self, name: str, /, **path_params: Any) -> URL:
        """
        Build an absolute URL for the route called ``name``.

        Route lookup is delegated to ``scope["router"]``, falling back to
        ``scope["app"]``. The resulting path is then made absolute against
        :attr:`base_url`.

        Raises:
            RuntimeError: if the scope provides neither a router nor an app.
        """
        provider: Router | Starlette | None = self.scope.get("router") or self.scope.get("app")
        if provider is None:
            raise RuntimeError("The `url_for` method can only be used inside a Starlette application or with a router.")
        return provider.url_path_for(name, **path_params).make_absolute_url(base_url=self.base_url)

189 

190 

async def empty_receive() -> NoReturn:
    """
    Stand-in ASGI ``receive`` callable for requests built without one.

    Raises:
        RuntimeError: on every call.
    """
    reason = "Receive channel has not been made available"
    raise RuntimeError(reason)

193 

194 

async def empty_send(message: Message) -> NoReturn:
    """
    Stand-in ASGI ``send`` callable for requests built without one.

    ``message`` is ignored.

    Raises:
        RuntimeError: on every call.
    """
    reason = "Send channel has not been made available"
    raise RuntimeError(reason)

197 

198 

class Request(HTTPConnection):
    """
    An incoming HTTP request.

    Extends `HTTPConnection` with body access (:meth:`stream`, :meth:`body`,
    :meth:`json`, :meth:`form`), disconnect detection and push promises. It
    works through the ASGI ``receive`` and ``send`` callables supplied at
    construction.
    """

    # Cached result of form parsing; ``None`` until ``_get_form`` has run.
    _form: FormData | None

    def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send):
        super().__init__(scope)
        assert scope["type"] == "http"
        self._receive = receive
        self._send = send
        # Set once the final ``http.request`` message (``more_body`` false) has arrived.
        self._stream_consumed = False
        # Set once an ``http.disconnect`` message has been observed; never reset.
        self._is_disconnected = False
        self._form = None

    @property
    def method(self) -> str:
        """The request method, taken from ``scope["method"]``."""
        return cast(str, self.scope["method"])

    @property
    def receive(self) -> Receive:
        """The ASGI ``receive`` callable this request reads from."""
        return self._receive

    async def stream(self) -> AsyncGenerator[bytes, None]:
        """
        Yield the request body in chunks as they arrive from ``receive``.

        If :meth:`body` already buffered the body, it is replayed as a single
        chunk. A successful stream always finishes with an empty ``b""`` chunk.

        Raises:
            RuntimeError: if the stream was already consumed and no buffered
                body is available.
            ClientDisconnect: if an ``http.disconnect`` message is received.
        """
        if hasattr(self, "_body"):
            yield self._body
            yield b""
            return
        if self._stream_consumed:
            raise RuntimeError("Stream consumed")
        while not self._stream_consumed:
            message = await self._receive()
            if message["type"] == "http.request":
                body = message.get("body", b"")
                if not message.get("more_body", False):
                    self._stream_consumed = True
                # Empty chunks are not forwarded to the consumer.
                if body:
                    yield body
            elif message["type"] == "http.disconnect":  # pragma: no branch
                self._is_disconnected = True
                raise ClientDisconnect()
        yield b""

    async def body(self) -> bytes:
        """Read the whole body via :meth:`stream` and cache it for later calls."""
        if not hasattr(self, "_body"):
            chunks: list[bytes] = []
            async for chunk in self.stream():
                chunks.append(chunk)
            self._body = b"".join(chunks)
        return self._body

    async def json(self) -> Any:
        """
        Decode the body as JSON (cached after the first call).

        Raises:
            ValueError: (e.g. ``json.JSONDecodeError``) if the body is not valid JSON.
        """
        if not hasattr(self, "_json"):  # pragma: no branch
            body = await self.body()
            self._json = json.loads(body)
        return self._json

    async def _get_form(
        self,
        *,
        max_files: int | float = 1000,
        max_fields: int | float = 1000,
        max_part_size: int = 1024 * 1024,
    ) -> FormData:
        """
        Parse the body as form data according to the ``Content-Type`` header.

        ``multipart/form-data`` goes through ``MultiPartParser`` with the given
        limits, ``application/x-www-form-urlencoded`` through ``FormParser``.
        Any other content type yields an empty ``FormData``. The result is
        cached, so the limits only apply to the first call that parses.

        Raises:
            HTTPException: (status 400) for a malformed multipart body when the
                scope contains an ``app``. The parser error is chained as
                ``__cause__``.
            MultiPartException: for a malformed multipart body otherwise.
            AssertionError: if ``python-multipart`` is not installed.
        """
        if self._form is None:  # pragma: no branch
            assert parse_options_header is not None, (
                "The `python-multipart` library must be installed to use form parsing."
            )
            content_type_header = self.headers.get("Content-Type")
            content_type: bytes
            content_type, _ = parse_options_header(content_type_header)
            if content_type == b"multipart/form-data":
                try:
                    multipart_parser = MultiPartParser(
                        self.headers,
                        self.stream(),
                        max_files=max_files,
                        max_fields=max_fields,
                        max_part_size=max_part_size,
                    )
                    self._form = await multipart_parser.parse()
                except MultiPartException as exc:
                    # Inside an application, surface malformed input as an HTTP 400
                    # error while keeping the parser error as the explicit cause.
                    if "app" in self.scope:
                        raise HTTPException(status_code=400, detail=exc.message) from exc
                    raise
            elif content_type == b"application/x-www-form-urlencoded":
                form_parser = FormParser(self.headers, self.stream())
                self._form = await form_parser.parse()
            else:
                self._form = FormData()
        return self._form

    def form(
        self,
        *,
        max_files: int | float = 1000,
        max_fields: int | float = 1000,
        max_part_size: int = 1024 * 1024,
    ) -> AwaitableOrContextManager[FormData]:
        """
        Return the parsed form data, wrapped in ``AwaitableOrContextManagerWrapper``.

        The limits are forwarded to :meth:`_get_form`. Because the parsed form
        is cached, they only take effect on the first parse.
        """
        return AwaitableOrContextManagerWrapper(
            self._get_form(max_files=max_files, max_fields=max_fields, max_part_size=max_part_size)
        )

    async def close(self) -> None:
        """Close the parsed form data, if a form has been parsed."""
        if self._form is not None:  # pragma: no branch
            await self._form.close()

    async def is_disconnected(self) -> bool:
        """
        Return whether the client has disconnected, without blocking.

        The flag is sticky: once a disconnect has been seen, ``True`` is
        returned without calling ``receive`` again.
        """
        if not self._is_disconnected:
            message: Message = {}

            # The scope is cancelled before awaiting, so if no message is
            # immediately available the wait is abandoned and ``message``
            # stays empty.
            # NOTE(review): a non-disconnect message (e.g. ``http.request``)
            # received here is not stored, so it is lost to later
            # ``stream()``/``body()`` calls.
            with anyio.CancelScope() as cs:
                cs.cancel()
                message = await self._receive()

            if message.get("type") == "http.disconnect":
                self._is_disconnected = True

        return self._is_disconnected

    async def send_push_promise(self, path: str) -> None:
        """
        Send an ``http.response.push`` message for ``path``.

        This is a no-op unless ``http.response.push`` is listed in
        ``scope["extensions"]``. The request's values for the headers in
        ``SERVER_PUSH_HEADERS_TO_COPY`` are forwarded, latin-1 encoded.
        """
        if "http.response.push" in self.scope.get("extensions", {}):
            raw_headers: list[tuple[bytes, bytes]] = []
            for name in SERVER_PUSH_HEADERS_TO_COPY:
                for value in self.headers.getlist(name):
                    raw_headers.append((name.encode("latin-1"), value.encode("latin-1")))
            await self._send({"type": "http.response.push", "path": path, "headers": raw_headers})