Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/starlette/requests.py: 34%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

200 statements  

1from __future__ import annotations 

2 

3import json 

4import typing 

5from http import cookies as http_cookies 

6 

7import anyio 

8 

9from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper 

10from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State 

11from starlette.exceptions import HTTPException 

12from starlette.formparsers import FormParser, MultiPartException, MultiPartParser 

13from starlette.types import Message, Receive, Scope, Send 

14 

15try: 

16 from multipart.multipart import parse_options_header 

17except ModuleNotFoundError: # pragma: nocover 

18 parse_options_header = None 

19 

20 

21if typing.TYPE_CHECKING: 

22 from starlette.routing import Router 

23 

24 

# Names of the request headers that `Request.send_push_promise` copies from the
# incoming request into the "http.response.push" message it sends.
SERVER_PUSH_HEADERS_TO_COPY = {
    "accept",
    "accept-encoding",
    "accept-language",
    "cache-control",
    "user-agent",
}

32 

33 

def cookie_parser(cookie_string: str) -> dict[str, str]:
    """
    This function parses a ``Cookie`` HTTP header into a dict of key/value pairs.

    It attempts to mimic browser cookie parsing behavior: browsers and web servers
    frequently disregard the spec (RFC 6265) when setting and reading cookies,
    so we attempt to suit the common scenarios here.

    This function has been adapted from Django 3.1.0.
    Note: we are explicitly _NOT_ using `SimpleCookie.load` because it is based
    on an outdated spec and will fail on lots of input we want to support

    The header is split on ``;``. Each chunk is split on its first ``=``; a chunk
    without ``=`` is stored under the empty-string key. Chunks whose key and value
    are both empty after stripping whitespace are skipped. When a key repeats,
    the last occurrence wins.
    """
    cookie_dict: dict[str, str] = {}
    for chunk in cookie_string.split(";"):
        if "=" in chunk:
            # Only the first "=" separates name from value, so values may
            # themselves contain "=".
            key, val = chunk.split("=", 1)
        else:
            # Assume an empty name per
            # https://bugzilla.mozilla.org/show_bug.cgi?id=169091
            key, val = "", chunk
        key, val = key.strip(), val.strip()
        if key or val:
            # unquote using Python's algorithm.
            cookie_dict[key] = http_cookies._unquote(val)
    return cookie_dict

59 

60 

class ClientDisconnect(Exception):
    """Raised by `Request.stream` when an "http.disconnect" message is received."""

63 

64 

class HTTPConnection(typing.Mapping[str, typing.Any]):
    """
    A base class for incoming HTTP connections, that is used to provide
    any functionality that is common to both `Request` and `WebSocket`.

    The instance is a read-only mapping view over the ASGI ``scope`` dict:
    item access, iteration and ``len()`` all delegate to ``self.scope``.
    Derived objects (URL, headers, query params, cookies, state) are built
    on first access and cached on the instance.
    """

    def __init__(self, scope: Scope, receive: Receive | None = None) -> None:
        """
        Store ``scope`` after checking that its type is "http" or "websocket".

        ``receive`` is accepted for signature compatibility but is not used here.
        """
        assert scope["type"] in ("http", "websocket")
        self.scope = scope

    def __getitem__(self, key: str) -> typing.Any:
        return self.scope[key]

    def __iter__(self) -> typing.Iterator[str]:
        return iter(self.scope)

    def __len__(self) -> int:
        return len(self.scope)

    # Don't use the `abc.Mapping.__eq__` implementation.
    # Connection instances should never be considered equal
    # unless `self is other`.
    __eq__ = object.__eq__
    __hash__ = object.__hash__

    @property
    def app(self) -> typing.Any:
        """The value stored under ``scope["app"]``; raises KeyError if absent."""
        return self.scope["app"]

    @property
    def url(self) -> URL:
        """The `URL` built from the scope, created once and cached."""
        if not hasattr(self, "_url"):
            self._url = URL(scope=self.scope)
        return self._url

    @property
    def base_url(self) -> URL:
        """
        The `URL` of the application root, created once and cached.

        Built from a copy of the scope whose path is the app root path (with a
        trailing slash ensured) and whose query string is empty.
        """
        if not hasattr(self, "_base_url"):
            base_url_scope = dict(self.scope)
            # This is used by request.url_for, it might be used inside a Mount which
            # would have its own child scope with its own root_path, but the base URL
            # for url_for should still be the top level app root path.
            app_root_path = base_url_scope.get(
                "app_root_path", base_url_scope.get("root_path", "")
            )
            path = app_root_path
            if not path.endswith("/"):
                path += "/"
            base_url_scope["path"] = path
            base_url_scope["query_string"] = b""
            base_url_scope["root_path"] = app_root_path
            self._base_url = URL(scope=base_url_scope)
        return self._base_url

    @property
    def headers(self) -> Headers:
        """The `Headers` built from the scope, created once and cached."""
        if not hasattr(self, "_headers"):
            self._headers = Headers(scope=self.scope)
        return self._headers

    @property
    def query_params(self) -> QueryParams:
        """The `QueryParams` built from ``scope["query_string"]``, cached."""
        if not hasattr(self, "_query_params"):
            self._query_params = QueryParams(self.scope["query_string"])
        return self._query_params

    @property
    def path_params(self) -> dict[str, typing.Any]:
        """``scope["path_params"]``, or a new empty dict when the key is missing."""
        return self.scope.get("path_params", {})

    @property
    def cookies(self) -> dict[str, str]:
        """
        Cookies parsed from the ``cookie`` header with `cookie_parser`, cached.

        An absent or empty header yields an empty dict.
        """
        if not hasattr(self, "_cookies"):
            cookies: dict[str, str] = {}
            cookie_header = self.headers.get("cookie")

            if cookie_header:
                cookies = cookie_parser(cookie_header)
            self._cookies = cookies
        return self._cookies

    @property
    def client(self) -> Address | None:
        """The client `Address`, or None when the scope has no client."""
        # client is a 2 item tuple of (host, port), None or missing
        host_port = self.scope.get("client")
        if host_port is not None:
            return Address(*host_port)
        return None

    @property
    def session(self) -> dict[str, typing.Any]:
        """``scope["session"]``; asserts that the key is present."""
        assert (
            "session" in self.scope
        ), "SessionMiddleware must be installed to access request.session"
        return self.scope["session"]  # type: ignore[no-any-return]

    @property
    def auth(self) -> typing.Any:
        """``scope["auth"]``; asserts that the key is present."""
        assert (
            "auth" in self.scope
        ), "AuthenticationMiddleware must be installed to access request.auth"
        return self.scope["auth"]

    @property
    def user(self) -> typing.Any:
        """``scope["user"]``; asserts that the key is present."""
        assert (
            "user" in self.scope
        ), "AuthenticationMiddleware must be installed to access request.user"
        return self.scope["user"]

    @property
    def state(self) -> State:
        """
        A `State` wrapper around ``scope["state"]``, created once and cached.

        The scope entry is initialised to an empty dict when missing, so this
        accessor may mutate the scope.
        """
        if not hasattr(self, "_state"):
            # Ensure 'state' has an empty dict if it's not already populated.
            self.scope.setdefault("state", {})
            # Create a state instance with a reference to the dict in which it should
            # store info
            self._state = State(self.scope["state"])
        return self._state

    def url_for(self, name: str, /, **path_params: typing.Any) -> URL:
        """
        Resolve the route ``name`` with ``path_params`` to an absolute URL.

        The path comes from ``scope["router"].url_path_for`` and is made
        absolute against `base_url`. Raises KeyError if the scope has no router.
        """
        router: Router = self.scope["router"]
        url_path = router.url_path_for(name, **path_params)
        return url_path.make_absolute_url(base_url=self.base_url)

189 

190 

async def empty_receive() -> typing.NoReturn:
    """Default `receive` callable for `Request`; always raises RuntimeError."""
    raise RuntimeError("Receive channel has not been made available")

193 

194 

async def empty_send(message: Message) -> typing.NoReturn:
    """Default `send` callable for `Request`; always raises RuntimeError."""
    raise RuntimeError("Send channel has not been made available")

197 

198 

class Request(HTTPConnection):
    """
    An incoming HTTP request: the connection scope plus access to the body.

    The body can be read incrementally with `stream`, or in full with `body`,
    `json` or `form`. Results of `body`, `json` and `form` are cached on the
    instance.
    """

    # Parsed form data; None until `form()` has been awaited.
    _form: FormData | None

    def __init__(
        self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send
    ):
        """
        Store the channels and reset the stream/disconnect/form bookkeeping.

        Asserts that ``scope["type"]`` is "http". The default ``receive`` and
        ``send`` raise RuntimeError when called.
        """
        super().__init__(scope)
        assert scope["type"] == "http"
        self._receive = receive
        self._send = send
        self._stream_consumed = False
        self._is_disconnected = False
        self._form = None

    @property
    def method(self) -> str:
        """The value of ``scope["method"]``."""
        return typing.cast(str, self.scope["method"])

    @property
    def receive(self) -> Receive:
        """The `receive` callable this request was constructed with."""
        return self._receive

    async def stream(self) -> typing.AsyncGenerator[bytes, None]:
        """
        Yield the request body in chunks, followed by a final ``b""``.

        If the body was already read via `body`, the cached bytes are yielded
        instead of reading from the channel.

        Raises:
            RuntimeError: the stream was already consumed and no body is cached.
            ClientDisconnect: an "http.disconnect" message was received.
        """
        if hasattr(self, "_body"):
            yield self._body
            yield b""
            return
        if self._stream_consumed:
            raise RuntimeError("Stream consumed")
        while not self._stream_consumed:
            message = await self._receive()
            if message["type"] == "http.request":
                body = message.get("body", b"")
                # A message without "more_body" set is the last one.
                if not message.get("more_body", False):
                    self._stream_consumed = True
                if body:
                    yield body
            elif message["type"] == "http.disconnect":
                self._is_disconnected = True
                raise ClientDisconnect()
        yield b""

    async def body(self) -> bytes:
        """Return the whole request body, reading the stream once and caching it."""
        if not hasattr(self, "_body"):
            chunks: list[bytes] = []
            async for chunk in self.stream():
                chunks.append(chunk)
            self._body = b"".join(chunks)
        return self._body

    async def json(self) -> typing.Any:
        """Return the body decoded with `json.loads`, cached after the first call."""
        if not hasattr(self, "_json"):
            body = await self.body()
            self._json = json.loads(body)
        return self._json

    async def _get_form(
        self, *, max_files: int | float = 1000, max_fields: int | float = 1000
    ) -> FormData:
        """
        Parse the body as form data according to the Content-Type header.

        "multipart/form-data" uses `MultiPartParser` (with the given limits),
        "application/x-www-form-urlencoded" uses `FormParser`, and any other
        content type yields an empty `FormData`. The result is cached.

        Raises:
            HTTPException: status 400, when multipart parsing fails and the
                scope contains an "app" key.
            MultiPartException: when multipart parsing fails and the scope has
                no "app" key.
        """
        if self._form is None:
            assert (
                parse_options_header is not None
            ), "The `python-multipart` library must be installed to use form parsing."
            content_type_header = self.headers.get("Content-Type")
            content_type: bytes
            content_type, _ = parse_options_header(content_type_header)
            if content_type == b"multipart/form-data":
                try:
                    multipart_parser = MultiPartParser(
                        self.headers,
                        self.stream(),
                        max_files=max_files,
                        max_fields=max_fields,
                    )
                    self._form = await multipart_parser.parse()
                except MultiPartException as exc:
                    if "app" in self.scope:
                        # Chain explicitly so the parser error stays attached
                        # as the cause of the 400 response.
                        raise HTTPException(
                            status_code=400, detail=exc.message
                        ) from exc
                    raise exc
            elif content_type == b"application/x-www-form-urlencoded":
                form_parser = FormParser(self.headers, self.stream())
                self._form = await form_parser.parse()
            else:
                self._form = FormData()
        return self._form

    def form(
        self, *, max_files: int | float = 1000, max_fields: int | float = 1000
    ) -> AwaitableOrContextManager[FormData]:
        """
        Return the parsed form wrapped in `AwaitableOrContextManagerWrapper`.

        NOTE(review): judging by the wrapper's name the result can be awaited
        or used as an async context manager; confirm in `starlette._utils`.
        """
        return AwaitableOrContextManagerWrapper(
            self._get_form(max_files=max_files, max_fields=max_fields)
        )

    async def close(self) -> None:
        """Close the parsed form data, if a form has been parsed."""
        if self._form is not None:
            await self._form.close()

    async def is_disconnected(self) -> bool:
        """
        Report whether an "http.disconnect" message has been seen.

        Once a disconnect has been observed the result stays True and the
        channel is no longer polled.
        """
        if not self._is_disconnected:
            message: Message = {}

            # If message isn't immediately available, move on
            with anyio.CancelScope() as cs:
                cs.cancel()
                message = await self._receive()

            if message.get("type") == "http.disconnect":
                self._is_disconnected = True

        return self._is_disconnected

    async def send_push_promise(self, path: str) -> None:
        """
        Send an "http.response.push" message for ``path``.

        Does nothing unless the scope's "extensions" contain
        "http.response.push". The message carries every value of the request
        headers named in `SERVER_PUSH_HEADERS_TO_COPY`.
        """
        if "http.response.push" in self.scope.get("extensions", {}):
            raw_headers: list[tuple[bytes, bytes]] = []
            for name in SERVER_PUSH_HEADERS_TO_COPY:
                for value in self.headers.getlist(name):
                    raw_headers.append(
                        (name.encode("latin-1"), value.encode("latin-1"))
                    )
            await self._send(
                {"type": "http.response.push", "path": path, "headers": raw_headers}
            )