Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/starlette/requests.py: 34%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

200 statements  

1from __future__ import annotations 

2 

3import json 

4import typing 

5from http import cookies as http_cookies 

6 

7import anyio 

8 

9from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper 

10from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State 

11from starlette.exceptions import HTTPException 

12from starlette.formparsers import FormParser, MultiPartException, MultiPartParser 

13from starlette.types import Message, Receive, Scope, Send 

14 

15try: 

16 from multipart.multipart import parse_options_header 

17except ModuleNotFoundError: # pragma: nocover 

18 parse_options_header = None 

19 

20 

21if typing.TYPE_CHECKING: 

22 from starlette.routing import Router 

23 

24 

# Names of incoming request headers whose values are copied onto the
# ``http.response.push`` message sent by ``Request.send_push_promise``.
SERVER_PUSH_HEADERS_TO_COPY = {
    "accept",
    "accept-encoding",
    "accept-language",
    "cache-control",
    "user-agent",
}

32 

33 

def cookie_parser(cookie_string: str) -> dict[str, str]:
    """
    Parse a ``Cookie`` HTTP header into a dict of name/value pairs.

    Browsers and web servers frequently disregard RFC 6265 when setting and
    reading cookies, so this parser mimics lenient browser behaviour rather
    than the strict spec. It is adapted from Django 3.1.0.

    ``SimpleCookie.load`` is deliberately *not* used: it is based on an
    outdated spec and fails on a lot of input that should be accepted.

    Rules applied to each ``;``-separated chunk:

    * a chunk without ``=`` is treated as a value with an empty name
      (see https://bugzilla.mozilla.org/show_bug.cgi?id=169091);
    * only the first ``=`` separates name from value;
    * surrounding whitespace is stripped from both name and value;
    * chunks whose name and value are both empty are skipped;
    * when a name repeats, the last occurrence wins.
    """
    parsed: dict[str, str] = {}
    for fragment in cookie_string.split(";"):
        name, separator, value = fragment.partition("=")
        if not separator:
            # No "=" at all: the whole fragment is the value of a nameless cookie.
            name, value = "", name
        name = name.strip()
        value = value.strip()
        if not (name or value):
            continue
        # Unquote using Python's own cookie-unquoting algorithm.
        parsed[name] = http_cookies._unquote(value)
    return parsed

59 

60 

class ClientDisconnect(Exception):
    """Raised by ``Request.stream`` when an ``http.disconnect`` message is received."""

63 

64 

class HTTPConnection(typing.Mapping[str, typing.Any]):
    """
    Base class for incoming connections, holding the functionality common to
    both `Request` and `WebSocket`.

    The instance exposes its connection ``scope`` through the read-only
    ``Mapping`` interface, and offers cached, parsed views of it (URL,
    headers, query parameters, cookies, ...).
    """

    def __init__(self, scope: Scope, receive: Receive | None = None) -> None:
        # ``receive`` is accepted but not used by this base class.
        assert scope["type"] in ("http", "websocket")
        self.scope = scope

    # --- Mapping protocol: everything is delegated to the scope. ---

    def __getitem__(self, key: str) -> typing.Any:
        return self.scope[key]

    def __iter__(self) -> typing.Iterator[str]:
        return iter(self.scope)

    def __len__(self) -> int:
        return len(self.scope)

    # Opt out of ``abc.Mapping.__eq__``: two connection objects compare
    # equal only when they are the very same object.
    __eq__ = object.__eq__
    __hash__ = object.__hash__

    @property
    def app(self) -> typing.Any:
        return self.scope["app"]

    @property
    def url(self) -> URL:
        if hasattr(self, "_url"):
            return self._url
        self._url = URL(scope=self.scope)
        return self._url

    @property
    def base_url(self) -> URL:
        if hasattr(self, "_base_url"):
            return self._base_url
        # ``url_for`` resolves against this URL. Inside a Mount the child scope
        # carries its own ``root_path``, but the base URL must still point at
        # the top-level app root, so ``app_root_path`` takes priority.
        root = self.scope.get("app_root_path", self.scope.get("root_path", ""))
        base_scope = {
            **self.scope,
            "path": root if root.endswith("/") else root + "/",
            "query_string": b"",
            "root_path": root,
        }
        self._base_url = URL(scope=base_scope)
        return self._base_url

    @property
    def headers(self) -> Headers:
        if hasattr(self, "_headers"):
            return self._headers
        self._headers = Headers(scope=self.scope)
        return self._headers

    @property
    def query_params(self) -> QueryParams:
        if hasattr(self, "_query_params"):
            return self._query_params
        self._query_params = QueryParams(self.scope["query_string"])
        return self._query_params

    @property
    def path_params(self) -> dict[str, typing.Any]:
        return self.scope.get("path_params", {})

    @property
    def cookies(self) -> dict[str, str]:
        if hasattr(self, "_cookies"):
            return self._cookies
        raw_cookie = self.headers.get("cookie")
        # An absent or empty Cookie header yields an empty dict.
        self._cookies = cookie_parser(raw_cookie) if raw_cookie else {}
        return self._cookies

    @property
    def client(self) -> Address | None:
        # The scope's ``client`` entry is a 2-item (host, port) pair, or missing.
        host_port = self.scope.get("client")
        return None if host_port is None else Address(*host_port)

    @property
    def session(self) -> dict[str, typing.Any]:
        assert "session" in self.scope, "SessionMiddleware must be installed to access request.session"
        return self.scope["session"]  # type: ignore[no-any-return]

    @property
    def auth(self) -> typing.Any:
        assert "auth" in self.scope, "AuthenticationMiddleware must be installed to access request.auth"
        return self.scope["auth"]

    @property
    def user(self) -> typing.Any:
        assert "user" in self.scope, "AuthenticationMiddleware must be installed to access request.user"
        return self.scope["user"]

    @property
    def state(self) -> State:
        if hasattr(self, "_state"):
            return self._state
        # Make sure the scope holds a dict for state, then hand ``State`` a
        # reference to that very dict so it stores its data there.
        self.scope.setdefault("state", {})
        self._state = State(self.scope["state"])
        return self._state

    def url_for(self, name: str, /, **path_params: typing.Any) -> URL:
        router: Router = self.scope["router"]
        return router.url_path_for(name, **path_params).make_absolute_url(base_url=self.base_url)

181 

182 

async def empty_receive() -> typing.NoReturn:
    """Placeholder ``receive`` callable that always raises ``RuntimeError`` when awaited."""
    error_text = "Receive channel has not been made available"
    raise RuntimeError(error_text)

185 

186 

async def empty_send(message: Message) -> typing.NoReturn:
    """Placeholder ``send`` callable that ignores ``message`` and always raises ``RuntimeError``."""
    error_text = "Send channel has not been made available"
    raise RuntimeError(error_text)

189 

190 

class Request(HTTPConnection):
    """
    An incoming HTTP request.

    Extends ``HTTPConnection`` with access to the request body (raw stream,
    bytes, JSON, form data), disconnect detection and server push.
    """

    _form: FormData | None

    def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send):
        super().__init__(scope)
        assert scope["type"] == "http"
        self._receive = receive
        self._send = send
        # Book-keeping for body consumption and client disconnects.
        self._stream_consumed = False
        self._is_disconnected = False
        # Parsed form data; populated on the first call to ``form()``.
        self._form = None

    @property
    def method(self) -> str:
        return typing.cast(str, self.scope["method"])

    @property
    def receive(self) -> Receive:
        return self._receive

    async def stream(self) -> typing.AsyncGenerator[bytes, None]:
        """
        Yield the request body chunk by chunk, always finishing with ``b""``.

        If ``body()`` has already cached the body, that cached value is
        replayed instead of calling ``receive`` again.

        Raises ``RuntimeError`` if the stream was already consumed without
        being cached, and ``ClientDisconnect`` on an ``http.disconnect``
        message.
        """
        if hasattr(self, "_body"):
            yield self._body
            yield b""
            return
        if self._stream_consumed:
            raise RuntimeError("Stream consumed")
        while not self._stream_consumed:
            message = await self._receive()
            kind = message["type"]
            if kind == "http.disconnect":
                self._is_disconnected = True
                raise ClientDisconnect()
            if kind != "http.request":
                # Any other message type is ignored.
                continue
            chunk = message.get("body", b"")
            if not message.get("more_body", False):
                self._stream_consumed = True
            if chunk:
                yield chunk
        yield b""

    async def body(self) -> bytes:
        """Return the whole request body, reading and caching it on first use."""
        if not hasattr(self, "_body"):
            self._body = b"".join([chunk async for chunk in self.stream()])
        return self._body

    async def json(self) -> typing.Any:
        """Return the request body decoded with ``json.loads``, cached after the first call."""
        if hasattr(self, "_json"):
            return self._json
        self._json = json.loads(await self.body())
        return self._json

    async def _get_form(self, *, max_files: int | float = 1000, max_fields: int | float = 1000) -> FormData:
        if self._form is not None:
            return self._form
        assert (
            parse_options_header is not None
        ), "The `python-multipart` library must be installed to use form parsing."
        content_type: bytes
        content_type, _ = parse_options_header(self.headers.get("Content-Type"))
        if content_type == b"multipart/form-data":
            try:
                parser = MultiPartParser(
                    self.headers,
                    self.stream(),
                    max_files=max_files,
                    max_fields=max_fields,
                )
                self._form = await parser.parse()
            except MultiPartException as exc:
                # With an app in the scope, report the parse failure as an
                # HTTPException(400); otherwise re-raise the parser error.
                if "app" in self.scope:
                    raise HTTPException(status_code=400, detail=exc.message)
                raise exc
        elif content_type == b"application/x-www-form-urlencoded":
            self._form = await FormParser(self.headers, self.stream()).parse()
        else:
            # Unrecognised content type: expose an empty form.
            self._form = FormData()
        return self._form

    def form(
        self, *, max_files: int | float = 1000, max_fields: int | float = 1000
    ) -> AwaitableOrContextManager[FormData]:
        pending = self._get_form(max_files=max_files, max_fields=max_fields)
        return AwaitableOrContextManagerWrapper(pending)

    async def close(self) -> None:
        """Close the parsed form data, if a form has been parsed."""
        if self._form is not None:
            await self._form.close()

    async def is_disconnected(self) -> bool:
        if self._is_disconnected:
            return True
        message: Message = {}
        # Poll ``receive`` without waiting: the scope is cancelled up-front,
        # so if no message is immediately available we simply move on.
        with anyio.CancelScope() as cancel_scope:
            cancel_scope.cancel()
            message = await self._receive()
        if message.get("type") == "http.disconnect":
            self._is_disconnected = True
        return self._is_disconnected

    async def send_push_promise(self, path: str) -> None:
        """Send an ``http.response.push`` message for ``path`` when the scope advertises that extension; otherwise do nothing."""
        if "http.response.push" not in self.scope.get("extensions", {}):
            return
        raw_headers: list[tuple[bytes, bytes]] = [
            (name.encode("latin-1"), value.encode("latin-1"))
            for name in SERVER_PUSH_HEADERS_TO_COPY
            for value in self.headers.getlist(name)
        ]
        await self._send({"type": "http.response.push", "path": path, "headers": raw_headers})