Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/starlette/requests.py: 34%

195 statements  

« prev     ^ index     » next       coverage.py v7.2.2, created at 2023-03-26 06:12 +0000

1import json 

2import typing 

3from http import cookies as http_cookies 

4 

5import anyio 

6 

7from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper 

8from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State 

9from starlette.exceptions import HTTPException 

10from starlette.formparsers import FormParser, MultiPartException, MultiPartParser 

11from starlette.types import Message, Receive, Scope, Send 

12 

13try: 

14 from multipart.multipart import parse_options_header 

15except ImportError: # pragma: nocover 

16 parse_options_header = None 

17 

18 

19if typing.TYPE_CHECKING: 

20 from starlette.routing import Router 

21 

22 

# Names of the request headers that `Request.send_push_promise()` copies from
# the incoming request onto the "http.response.push" message it sends.
SERVER_PUSH_HEADERS_TO_COPY = {
    "accept",
    "accept-encoding",
    "accept-language",
    "cache-control",
    "user-agent",
}

30 

31 

def cookie_parser(cookie_string: str) -> typing.Dict[str, str]:
    """
    This function parses a ``Cookie`` HTTP header into a dict of key/value pairs.

    It attempts to mimic browser cookie parsing behavior: browsers and web servers
    frequently disregard the spec (RFC 6265) when setting and reading cookies,
    so we attempt to suit the common scenarios here.

    This function has been adapted from Django 3.1.0.
    Note: we are explicitly _NOT_ using `SimpleCookie.load` because it is based
    on an outdated spec and will fail on lots of input we want to support

    The header is split on ``;``. Each chunk is split on its first ``=`` into a
    name and a value, both stripped of surrounding whitespace. Chunks whose
    name and value are both empty are skipped. When a name occurs more than
    once, the last occurrence wins.
    """
    cookie_dict: typing.Dict[str, str] = {}
    for chunk in cookie_string.split(";"):
        key, separator, val = chunk.partition("=")
        if not separator:
            # Assume an empty name per
            # https://bugzilla.mozilla.org/show_bug.cgi?id=169091
            key, val = "", chunk
        key, val = key.strip(), val.strip()
        if key or val:
            # unquote using Python's algorithm.
            cookie_dict[key] = http_cookies._unquote(val)
    return cookie_dict

57 

58 

class ClientDisconnect(Exception):
    """
    Raised by `Request.stream()` when an ``http.disconnect`` message is
    received while the request body is still being read.
    """

61 

62 

class HTTPConnection(typing.Mapping[str, typing.Any]):
    """
    A base class for incoming HTTP connections, that is used to provide
    any functionality that is common to both `Request` and `WebSocket`.

    The instance wraps the connection ``scope`` and exposes the scope's keys
    through the `Mapping` interface (`__getitem__`, `__iter__`, `__len__`).
    Values derived from the scope (URL, headers, query params, cookies, state)
    are built on first access and cached on the instance.
    """

    def __init__(self, scope: Scope, receive: typing.Optional[Receive] = None) -> None:
        # `receive` is part of the signature but is not stored or used here.
        assert scope["type"] in ("http", "websocket")
        self.scope = scope

    def __getitem__(self, key: str) -> typing.Any:
        return self.scope[key]

    def __iter__(self) -> typing.Iterator[str]:
        return iter(self.scope)

    def __len__(self) -> int:
        return len(self.scope)

    # Don't use the `abc.Mapping.__eq__` implementation.
    # Connection instances should never be considered equal
    # unless `self is other`.
    __eq__ = object.__eq__
    __hash__ = object.__hash__

    @property
    def app(self) -> typing.Any:
        """The value stored under ``scope["app"]`` (KeyError if absent)."""
        return self.scope["app"]

    @property
    def url(self) -> URL:
        """The `URL` built from the scope; created once and cached."""
        if not hasattr(self, "_url"):
            self._url = URL(scope=self.scope)
        return self._url

    @property
    def base_url(self) -> URL:
        """
        The `URL` of the application root; created once and cached.

        Built from a copy of the scope with ``path`` forced to ``"/"``, an
        empty ``query_string``, and ``root_path`` taken from ``app_root_path``
        when present, else from ``root_path``, else ``""``.
        """
        if not hasattr(self, "_base_url"):
            # Work on a shallow copy so the real scope is left untouched.
            base_url_scope = dict(self.scope)
            base_url_scope["path"] = "/"
            base_url_scope["query_string"] = b""
            base_url_scope["root_path"] = base_url_scope.get(
                "app_root_path", base_url_scope.get("root_path", "")
            )
            self._base_url = URL(scope=base_url_scope)
        return self._base_url

    @property
    def headers(self) -> Headers:
        """The `Headers` built from the scope; created once and cached."""
        if not hasattr(self, "_headers"):
            self._headers = Headers(scope=self.scope)
        return self._headers

    @property
    def query_params(self) -> QueryParams:
        """`QueryParams` built from ``scope["query_string"]``; cached."""
        if not hasattr(self, "_query_params"):
            self._query_params = QueryParams(self.scope["query_string"])
        return self._query_params

    @property
    def path_params(self) -> typing.Dict[str, typing.Any]:
        """``scope["path_params"]``, or a new empty dict when the key is absent."""
        return self.scope.get("path_params", {})

    @property
    def cookies(self) -> typing.Dict[str, str]:
        """
        Cookies parsed from the ``cookie`` header with `cookie_parser`.

        An empty dict when the header is missing or empty. The result is
        computed once and cached.
        """
        if not hasattr(self, "_cookies"):
            cookies: typing.Dict[str, str] = {}
            cookie_header = self.headers.get("cookie")

            if cookie_header:
                cookies = cookie_parser(cookie_header)
            self._cookies = cookies
        return self._cookies

    @property
    def client(self) -> typing.Optional[Address]:
        """The client as an `Address`, or None when the scope has no client."""
        # client is a 2 item tuple of (host, port), None or missing
        host_port = self.scope.get("client")
        if host_port is not None:
            return Address(*host_port)
        return None

    @property
    def session(self) -> typing.Dict[str, typing.Any]:
        """``scope["session"]``; asserts that the key is present."""
        assert (
            "session" in self.scope
        ), "SessionMiddleware must be installed to access request.session"
        return self.scope["session"]

    @property
    def auth(self) -> typing.Any:
        """``scope["auth"]``; asserts that the key is present."""
        assert (
            "auth" in self.scope
        ), "AuthenticationMiddleware must be installed to access request.auth"
        return self.scope["auth"]

    @property
    def user(self) -> typing.Any:
        """``scope["user"]``; asserts that the key is present."""
        assert (
            "user" in self.scope
        ), "AuthenticationMiddleware must be installed to access request.user"
        return self.scope["user"]

    @property
    def state(self) -> State:
        """
        A `State` object backed by ``scope["state"]``; created once and cached.

        Accessing this property adds an empty ``"state"`` dict to the scope if
        the key is not already there.
        """
        if not hasattr(self, "_state"):
            # Ensure 'state' has an empty dict if it's not already populated.
            self.scope.setdefault("state", {})
            # Create a state instance with a reference to the dict in which it should
            # store info
            self._state = State(self.scope["state"])
        return self._state

    def url_for(self, name: str, **path_params: typing.Any) -> str:
        """
        Resolve the route called `name` on ``scope["router"]`` with the given
        path parameters and return it made absolute against `base_url`.
        """
        router: Router = self.scope["router"]
        url_path = router.url_path_for(name, **path_params)
        return url_path.make_absolute_url(base_url=self.base_url)

180 

181 

async def empty_receive() -> typing.NoReturn:
    """
    Default ``receive`` callable for `Request`; it always raises.

    Raises:
        RuntimeError: on every call, because no receive channel was supplied.
    """
    raise RuntimeError("Receive channel has not been made available")

184 

185 

async def empty_send(message: Message) -> typing.NoReturn:
    """
    Default ``send`` callable for `Request`; it always raises.

    Args:
        message: the message that would have been sent; it is ignored.

    Raises:
        RuntimeError: on every call, because no send channel was supplied.
    """
    raise RuntimeError("Send channel has not been made available")

188 

189 

class Request(HTTPConnection):
    """
    An incoming HTTP request: the connection scope plus the ``receive`` and
    ``send`` callables used to read the body and to send push promises.

    The body can be consumed either incrementally with `stream()` or in one
    piece with `body()`, `json()` or `form()`. `body()` and `json()` cache
    their result on the instance.
    """

    # Parsed form data; None until `form()` has parsed a body.
    _form: typing.Optional[FormData]

    def __init__(
        self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send
    ):
        super().__init__(scope)
        assert scope["type"] == "http"
        self._receive = receive
        self._send = send
        self._stream_consumed = False
        self._is_disconnected = False
        self._form = None

    @property
    def method(self) -> str:
        """The value stored under ``scope["method"]``."""
        return self.scope["method"]

    @property
    def receive(self) -> Receive:
        """The ``receive`` callable this request was constructed with."""
        return self._receive

    async def stream(self) -> typing.AsyncGenerator[bytes, None]:
        """
        Yield the request body in chunks, ending with a final ``b""``.

        If the body has already been cached by `body()`, the cached bytes are
        yielded as a single chunk instead of reading from ``receive``.

        Raises:
            RuntimeError: if the stream was already consumed and no cached
                body is available.
            ClientDisconnect: if an ``http.disconnect`` message arrives before
                the body is complete.
        """
        if hasattr(self, "_body"):
            yield self._body
            yield b""
            return
        if self._stream_consumed:
            raise RuntimeError("Stream consumed")
        self._stream_consumed = True
        while True:
            message = await self._receive()
            if message["type"] == "http.request":
                body = message.get("body", b"")
                # Empty chunks are skipped; only the terminating b"" is emitted.
                if body:
                    yield body
                if not message.get("more_body", False):
                    break
            elif message["type"] == "http.disconnect":
                self._is_disconnected = True
                raise ClientDisconnect()
        yield b""

    async def body(self) -> bytes:
        """Return the whole request body; read once via `stream()` and cached."""
        if not hasattr(self, "_body"):
            chunks: "typing.List[bytes]" = []
            async for chunk in self.stream():
                chunks.append(chunk)
            self._body = b"".join(chunks)
        return self._body

    async def json(self) -> typing.Any:
        """Return the body decoded with `json.loads`; decoded once and cached."""
        if not hasattr(self, "_json"):
            body = await self.body()
            self._json = json.loads(body)
        return self._json

    async def _get_form(
        self,
        *,
        max_files: typing.Union[int, float] = 1000,
        max_fields: typing.Union[int, float] = 1000,
    ) -> FormData:
        """
        Parse the body as form data according to the ``Content-Type`` header.

        ``multipart/form-data`` is parsed with `MultiPartParser` (which is
        given `max_files` and `max_fields`), and
        ``application/x-www-form-urlencoded`` with `FormParser`. Any other
        content type yields an empty `FormData`. The result is stored on the
        instance and returned directly on later calls.

        Raises:
            HTTPException: status 400, when multipart parsing fails and the
                scope contains an ``"app"`` key.
            MultiPartException: when multipart parsing fails and the scope has
                no ``"app"`` key.
        """
        if self._form is None:
            assert (
                parse_options_header is not None
            ), "The `python-multipart` library must be installed to use form parsing."
            content_type_header = self.headers.get("Content-Type")
            content_type: bytes
            content_type, _ = parse_options_header(content_type_header)
            if content_type == b"multipart/form-data":
                try:
                    multipart_parser = MultiPartParser(
                        self.headers,
                        self.stream(),
                        max_files=max_files,
                        max_fields=max_fields,
                    )
                    self._form = await multipart_parser.parse()
                except MultiPartException as exc:
                    if "app" in self.scope:
                        raise HTTPException(status_code=400, detail=exc.message)
                    # No "app" in the scope: let the parser error propagate
                    # unchanged.
                    raise
            elif content_type == b"application/x-www-form-urlencoded":
                form_parser = FormParser(self.headers, self.stream())
                self._form = await form_parser.parse()
            else:
                self._form = FormData()
        return self._form

    def form(
        self,
        *,
        max_files: typing.Union[int, float] = 1000,
        max_fields: typing.Union[int, float] = 1000,
    ) -> AwaitableOrContextManager[FormData]:
        """
        Return the parsed form data, wrapped in `AwaitableOrContextManagerWrapper`.

        NOTE(review): going by the wrapper's name, the result is presumably
        usable with either ``await`` or ``async with`` - confirm against
        `starlette._utils`.
        """
        return AwaitableOrContextManagerWrapper(
            self._get_form(max_files=max_files, max_fields=max_fields)
        )

    async def close(self) -> None:
        """Close the parsed form data, if a form has been parsed."""
        if self._form is not None:
            await self._form.close()

    async def is_disconnected(self) -> bool:
        """
        Report whether an ``http.disconnect`` message has been received.

        Once a disconnect has been seen the answer is remembered and
        ``receive`` is not polled again.
        """
        if not self._is_disconnected:
            message: Message = {}

            # If message isn't immediately available, move on
            # The scope is cancelled before the await, so `message` keeps its
            # empty default if the await is cancelled.
            with anyio.CancelScope() as cs:
                cs.cancel()
                message = await self._receive()

            if message.get("type") == "http.disconnect":
                self._is_disconnected = True

        return self._is_disconnected

    async def send_push_promise(self, path: str) -> None:
        """
        Send an ``http.response.push`` message for `path`.

        Nothing is sent unless ``"http.response.push"`` is listed in the
        scope's ``extensions``. The message carries every value of the request
        headers named in `SERVER_PUSH_HEADERS_TO_COPY`, latin-1 encoded.
        """
        if "http.response.push" in self.scope.get("extensions", {}):
            raw_headers: "typing.List[typing.Tuple[bytes, bytes]]" = []
            for name in SERVER_PUSH_HEADERS_TO_COPY:
                for value in self.headers.getlist(name):
                    raw_headers.append(
                        (name.encode("latin-1"), value.encode("latin-1"))
                    )
            await self._send(
                {"type": "http.response.push", "path": path, "headers": raw_headers}
            )