Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/starlette/responses.py: 22%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

333 statements  

1from __future__ import annotations 

2 

3import hashlib 

4import http.cookies 

5import json 

6import os 

7import stat 

8import sys 

9from collections.abc import AsyncIterable, Awaitable, Callable, Iterable, Mapping, Sequence 

10from datetime import datetime 

11from email.utils import format_datetime, formatdate 

12from functools import partial 

13from mimetypes import guess_type 

14from secrets import token_hex 

15from typing import Any, Literal 

16from urllib.parse import quote 

17 

18import anyio 

19import anyio.to_thread 

20 

21from starlette._utils import collapse_excgroups 

22from starlette.background import BackgroundTask 

23from starlette.concurrency import iterate_in_threadpool 

24from starlette.datastructures import URL, Headers, MutableHeaders 

25from starlette.requests import ClientDisconnect 

26from starlette.types import Message, Receive, Scope, Send 

27 

28 

29class Response: 

30 media_type = None 

31 charset = "utf-8" 

32 

33 def __init__( 

34 self, 

35 content: Any = None, 

36 status_code: int = 200, 

37 headers: Mapping[str, str] | None = None, 

38 media_type: str | None = None, 

39 background: BackgroundTask | None = None, 

40 ) -> None: 

41 self.status_code = status_code 

42 if media_type is not None: 

43 self.media_type = media_type 

44 self.background = background 

45 self.body = self.render(content) 

46 self.init_headers(headers) 

47 

48 def render(self, content: Any) -> bytes | memoryview: 

49 if content is None: 

50 return b"" 

51 if isinstance(content, bytes | memoryview): 

52 return content 

53 return content.encode(self.charset) # type: ignore 

54 

55 def init_headers(self, headers: Mapping[str, str] | None = None) -> None: 

56 if headers is None: 

57 raw_headers: list[tuple[bytes, bytes]] = [] 

58 populate_content_length = True 

59 populate_content_type = True 

60 else: 

61 raw_headers = [(k.lower().encode("latin-1"), v.encode("latin-1")) for k, v in headers.items()] 

62 keys = [h[0] for h in raw_headers] 

63 populate_content_length = b"content-length" not in keys 

64 populate_content_type = b"content-type" not in keys 

65 

66 body = getattr(self, "body", None) 

67 if ( 

68 body is not None 

69 and populate_content_length 

70 and not (self.status_code < 200 or self.status_code in (204, 304)) 

71 ): 

72 content_length = str(len(body)) 

73 raw_headers.append((b"content-length", content_length.encode("latin-1"))) 

74 

75 content_type = self.media_type 

76 if content_type is not None and populate_content_type: 

77 if content_type.startswith("text/") and "charset=" not in content_type.lower(): 

78 content_type += "; charset=" + self.charset 

79 raw_headers.append((b"content-type", content_type.encode("latin-1"))) 

80 

81 self.raw_headers = raw_headers 

82 

83 @property 

84 def headers(self) -> MutableHeaders: 

85 if not hasattr(self, "_headers"): 

86 self._headers = MutableHeaders(raw=self.raw_headers) 

87 return self._headers 

88 

89 def set_cookie( 

90 self, 

91 key: str, 

92 value: str = "", 

93 max_age: int | None = None, 

94 expires: datetime | str | int | None = None, 

95 path: str | None = "/", 

96 domain: str | None = None, 

97 secure: bool = False, 

98 httponly: bool = False, 

99 samesite: Literal["lax", "strict", "none"] | None = "lax", 

100 partitioned: bool = False, 

101 ) -> None: 

102 cookie: http.cookies.BaseCookie[str] = http.cookies.SimpleCookie() 

103 cookie[key] = value 

104 if max_age is not None: 

105 cookie[key]["max-age"] = max_age 

106 if expires is not None: 

107 if isinstance(expires, datetime): 

108 cookie[key]["expires"] = format_datetime(expires, usegmt=True) 

109 else: 

110 cookie[key]["expires"] = expires 

111 if path is not None: 

112 cookie[key]["path"] = path 

113 if domain is not None: 

114 cookie[key]["domain"] = domain 

115 if secure: 

116 cookie[key]["secure"] = True 

117 if httponly: 

118 cookie[key]["httponly"] = True 

119 if samesite is not None: 

120 assert samesite.lower() in [ 

121 "strict", 

122 "lax", 

123 "none", 

124 ], "samesite must be either 'strict', 'lax' or 'none'" 

125 cookie[key]["samesite"] = samesite 

126 if partitioned: 

127 if sys.version_info < (3, 14): 

128 raise ValueError("Partitioned cookies are only supported in Python 3.14 and above.") # pragma: no cover 

129 cookie[key]["partitioned"] = True # pragma: no cover 

130 

131 cookie_val = cookie.output(header="").strip() 

132 self.raw_headers.append((b"set-cookie", cookie_val.encode("latin-1"))) 

133 

134 def delete_cookie( 

135 self, 

136 key: str, 

137 path: str = "/", 

138 domain: str | None = None, 

139 secure: bool = False, 

140 httponly: bool = False, 

141 samesite: Literal["lax", "strict", "none"] | None = "lax", 

142 ) -> None: 

143 self.set_cookie( 

144 key, 

145 max_age=0, 

146 expires=0, 

147 path=path, 

148 domain=domain, 

149 secure=secure, 

150 httponly=httponly, 

151 samesite=samesite, 

152 ) 

153 

154 def _wrap_websocket_denial_send(self, send: Send) -> Send: 

155 async def wrapped(message: Message) -> None: 

156 message_type = message["type"] 

157 if message_type in {"http.response.start", "http.response.body"}: # pragma: no branch 

158 message = {**message, "type": "websocket." + message_type} 

159 await send(message) 

160 

161 return wrapped 

162 

163 async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: 

164 if scope["type"] == "websocket": 

165 send = self._wrap_websocket_denial_send(send) 

166 await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers}) 

167 await send({"type": "http.response.body", "body": self.body}) 

168 

169 if self.background is not None: 

170 await self.background() 

171 

172 

173class HTMLResponse(Response): 

174 media_type = "text/html" 

175 

176 

177class PlainTextResponse(Response): 

178 media_type = "text/plain" 

179 

180 

181class JSONResponse(Response): 

182 media_type = "application/json" 

183 

184 def __init__( 

185 self, 

186 content: Any, 

187 status_code: int = 200, 

188 headers: Mapping[str, str] | None = None, 

189 media_type: str | None = None, 

190 background: BackgroundTask | None = None, 

191 ) -> None: 

192 super().__init__(content, status_code, headers, media_type, background) 

193 

194 def render(self, content: Any) -> bytes: 

195 return json.dumps( 

196 content, 

197 ensure_ascii=False, 

198 allow_nan=False, 

199 indent=None, 

200 separators=(",", ":"), 

201 ).encode("utf-8") 

202 

203 

204class RedirectResponse(Response): 

205 def __init__( 

206 self, 

207 url: str | URL, 

208 status_code: int = 307, 

209 headers: Mapping[str, str] | None = None, 

210 background: BackgroundTask | None = None, 

211 ) -> None: 

212 super().__init__(content=b"", status_code=status_code, headers=headers, background=background) 

213 self.headers["location"] = quote(str(url), safe=":/%#?=@[]!$&'()*+,;") 

214 

215 

216Content = str | bytes | memoryview 

217SyncContentStream = Iterable[Content] 

218AsyncContentStream = AsyncIterable[Content] 

219ContentStream = AsyncContentStream | SyncContentStream 

220 

221 

222class StreamingResponse(Response): 

223 body_iterator: AsyncContentStream 

224 

225 def __init__( 

226 self, 

227 content: ContentStream, 

228 status_code: int = 200, 

229 headers: Mapping[str, str] | None = None, 

230 media_type: str | None = None, 

231 background: BackgroundTask | None = None, 

232 ) -> None: 

233 if isinstance(content, AsyncIterable): 

234 self.body_iterator = content 

235 else: 

236 self.body_iterator = iterate_in_threadpool(content) 

237 self.status_code = status_code 

238 self.media_type = self.media_type if media_type is None else media_type 

239 self.background = background 

240 self.init_headers(headers) 

241 

242 async def listen_for_disconnect(self, receive: Receive) -> None: 

243 while True: 

244 message = await receive() 

245 if message["type"] == "http.disconnect": 

246 break 

247 

248 async def stream_response(self, send: Send) -> None: 

249 await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers}) 

250 async for chunk in self.body_iterator: 

251 if not isinstance(chunk, bytes | memoryview): 

252 chunk = chunk.encode(self.charset) 

253 await send({"type": "http.response.body", "body": chunk, "more_body": True}) 

254 

255 await send({"type": "http.response.body", "body": b"", "more_body": False}) 

256 

257 async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: 

258 if scope["type"] == "websocket": 

259 send = self._wrap_websocket_denial_send(send) 

260 await self.stream_response(send) 

261 if self.background is not None: 

262 await self.background() 

263 return 

264 

265 spec_version = tuple(map(int, scope.get("asgi", {}).get("spec_version", "2.0").split("."))) 

266 

267 if spec_version >= (2, 4): 

268 try: 

269 await self.stream_response(send) 

270 except OSError: 

271 raise ClientDisconnect() 

272 else: 

273 with collapse_excgroups(): 

274 async with anyio.create_task_group() as task_group: 

275 

276 async def wrap(func: Callable[[], Awaitable[None]]) -> None: 

277 await func() 

278 task_group.cancel_scope.cancel() 

279 

280 task_group.start_soon(wrap, partial(self.stream_response, send)) 

281 await wrap(partial(self.listen_for_disconnect, receive)) 

282 

283 if self.background is not None: 

284 await self.background() 

285 

286 

287class MalformedRangeHeader(Exception): 

288 def __init__(self, content: str = "Malformed range header.") -> None: 

289 self.content = content 

290 

291 

292class RangeNotSatisfiable(Exception): 

293 def __init__(self, max_size: int) -> None: 

294 self.max_size = max_size 

295 

296 

297class FileResponse(Response): 

298 chunk_size = 64 * 1024 

299 

300 def __init__( 

301 self, 

302 path: str | os.PathLike[str], 

303 status_code: int = 200, 

304 headers: Mapping[str, str] | None = None, 

305 media_type: str | None = None, 

306 background: BackgroundTask | None = None, 

307 filename: str | None = None, 

308 stat_result: os.stat_result | None = None, 

309 content_disposition_type: str = "attachment", 

310 ) -> None: 

311 self.path = path 

312 self.status_code = status_code 

313 self.filename = filename 

314 if media_type is None: 

315 media_type = guess_type(filename or path)[0] or "text/plain" 

316 self.media_type = media_type 

317 self.background = background 

318 self.init_headers(headers) 

319 self.headers.setdefault("accept-ranges", "bytes") 

320 if self.filename is not None: 

321 content_disposition_filename = quote(self.filename) 

322 if content_disposition_filename != self.filename: 

323 content_disposition = f"{content_disposition_type}; filename*=utf-8''{content_disposition_filename}" 

324 else: 

325 content_disposition = f'{content_disposition_type}; filename="{self.filename}"' 

326 self.headers.setdefault("content-disposition", content_disposition) 

327 self.stat_result = stat_result 

328 if stat_result is not None: 

329 self.set_stat_headers(stat_result) 

330 

331 def set_stat_headers(self, stat_result: os.stat_result) -> None: 

332 content_length = str(stat_result.st_size) 

333 last_modified = formatdate(stat_result.st_mtime, usegmt=True) 

334 etag_base = str(stat_result.st_mtime) + "-" + str(stat_result.st_size) 

335 etag = f'"{hashlib.md5(etag_base.encode(), usedforsecurity=False).hexdigest()}"' 

336 

337 self.headers.setdefault("content-length", content_length) 

338 self.headers.setdefault("last-modified", last_modified) 

339 self.headers.setdefault("etag", etag) 

340 

341 async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: 

342 scope_type = scope["type"] 

343 send_header_only = scope_type == "http" and scope["method"].upper() == "HEAD" 

344 send_pathsend = scope_type == "http" and "http.response.pathsend" in scope.get("extensions", {}) 

345 if scope_type == "websocket": 

346 send = self._wrap_websocket_denial_send(send) 

347 

348 if self.stat_result is None: 

349 try: 

350 stat_result = await anyio.to_thread.run_sync(os.stat, self.path) 

351 self.set_stat_headers(stat_result) 

352 except FileNotFoundError: 

353 raise RuntimeError(f"File at path {self.path} does not exist.") 

354 else: 

355 mode = stat_result.st_mode 

356 if not stat.S_ISREG(mode): 

357 raise RuntimeError(f"File at path {self.path} is not a file.") 

358 else: 

359 stat_result = self.stat_result 

360 

361 headers = Headers(scope=scope) 

362 http_range = headers.get("range") 

363 http_if_range = headers.get("if-range") 

364 

365 if http_range is None or (http_if_range is not None and not self._should_use_range(http_if_range)): 

366 await self._handle_simple(send, send_header_only, send_pathsend) 

367 else: 

368 try: 

369 ranges = self._parse_range_header(http_range, stat_result.st_size) 

370 except MalformedRangeHeader as exc: 

371 return await PlainTextResponse(exc.content, status_code=400)(scope, receive, send) 

372 except RangeNotSatisfiable as exc: 

373 response = PlainTextResponse(status_code=416, headers={"Content-Range": f"bytes */{exc.max_size}"}) 

374 return await response(scope, receive, send) 

375 

376 if len(ranges) == 1: 

377 start, end = ranges[0] 

378 await self._handle_single_range(send, start, end, stat_result.st_size, send_header_only) 

379 else: 

380 await self._handle_multiple_ranges(send, ranges, stat_result.st_size, send_header_only) 

381 

382 if self.background is not None: 

383 await self.background() 

384 

385 async def _handle_simple(self, send: Send, send_header_only: bool, send_pathsend: bool) -> None: 

386 await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers}) 

387 if send_header_only: 

388 await send({"type": "http.response.body", "body": b"", "more_body": False}) 

389 elif send_pathsend: 

390 await send({"type": "http.response.pathsend", "path": str(self.path)}) 

391 else: 

392 async with await anyio.open_file(self.path, mode="rb") as file: 

393 more_body = True 

394 while more_body: 

395 chunk = await file.read(self.chunk_size) 

396 more_body = len(chunk) == self.chunk_size 

397 await send({"type": "http.response.body", "body": chunk, "more_body": more_body}) 

398 

399 async def _handle_single_range( 

400 self, send: Send, start: int, end: int, file_size: int, send_header_only: bool 

401 ) -> None: 

402 headers = MutableHeaders(raw=list(self.raw_headers)) 

403 headers["content-range"] = f"bytes {start}-{end - 1}/{file_size}" 

404 headers["content-length"] = str(end - start) 

405 await send({"type": "http.response.start", "status": 206, "headers": headers.raw}) 

406 if send_header_only: 

407 await send({"type": "http.response.body", "body": b"", "more_body": False}) 

408 else: 

409 async with await anyio.open_file(self.path, mode="rb") as file: 

410 await file.seek(start) 

411 more_body = True 

412 while more_body: 

413 chunk = await file.read(min(self.chunk_size, end - start)) 

414 start += len(chunk) 

415 more_body = len(chunk) == self.chunk_size and start < end 

416 await send({"type": "http.response.body", "body": chunk, "more_body": more_body}) 

417 

418 async def _handle_multiple_ranges( 

419 self, 

420 send: Send, 

421 ranges: list[tuple[int, int]], 

422 file_size: int, 

423 send_header_only: bool, 

424 ) -> None: 

425 # In firefox and chrome, they use boundary with 95-96 bits entropy (that's roughly 13 bytes). 

426 boundary = token_hex(13) 

427 content_length, header_generator = self.generate_multipart( 

428 ranges, boundary, file_size, self.headers["content-type"] 

429 ) 

430 headers = MutableHeaders(raw=list(self.raw_headers)) 

431 headers["content-type"] = f"multipart/byteranges; boundary={boundary}" 

432 headers["content-length"] = str(content_length) 

433 await send({"type": "http.response.start", "status": 206, "headers": headers.raw}) 

434 if send_header_only: 

435 await send({"type": "http.response.body", "body": b"", "more_body": False}) 

436 else: 

437 async with await anyio.open_file(self.path, mode="rb") as file: 

438 for start, end in ranges: 

439 await send({"type": "http.response.body", "body": header_generator(start, end), "more_body": True}) 

440 await file.seek(start) 

441 while start < end: 

442 chunk = await file.read(min(self.chunk_size, end - start)) 

443 start += len(chunk) 

444 await send({"type": "http.response.body", "body": chunk, "more_body": True}) 

445 await send({"type": "http.response.body", "body": b"\r\n", "more_body": True}) 

446 await send( 

447 { 

448 "type": "http.response.body", 

449 "body": f"--{boundary}--".encode("latin-1"), 

450 "more_body": False, 

451 } 

452 ) 

453 

454 def _should_use_range(self, http_if_range: str) -> bool: 

455 return http_if_range == self.headers["last-modified"] or http_if_range == self.headers["etag"] 

456 

457 @classmethod 

458 def _parse_range_header(cls, http_range: str, file_size: int) -> list[tuple[int, int]]: 

459 ranges: list[tuple[int, int]] = [] 

460 try: 

461 units, range_ = http_range.split("=", 1) 

462 except ValueError: 

463 raise MalformedRangeHeader() 

464 

465 units = units.strip().lower() 

466 

467 if units != "bytes": 

468 raise MalformedRangeHeader("Only support bytes range") 

469 

470 ranges = cls._parse_ranges(range_, file_size) 

471 

472 if len(ranges) == 0: 

473 raise MalformedRangeHeader("Range header: range must be requested") 

474 

475 if any(not (0 <= start < file_size) for start, _ in ranges): 

476 raise RangeNotSatisfiable(file_size) 

477 

478 if any(start > end for start, end in ranges): 

479 raise MalformedRangeHeader("Range header: start must be less than end") 

480 

481 if len(ranges) == 1: 

482 return ranges 

483 

484 # Merge overlapping ranges 

485 ranges.sort() 

486 result: list[tuple[int, int]] = [ranges[0]] 

487 for start, end in ranges[1:]: 

488 last_start, last_end = result[-1] 

489 if start <= last_end: 

490 result[-1] = (last_start, max(last_end, end)) 

491 else: 

492 result.append((start, end)) 

493 

494 return result 

495 

496 @classmethod 

497 def _parse_ranges(cls, range_: str, file_size: int) -> list[tuple[int, int]]: 

498 ranges: list[tuple[int, int]] = [] 

499 

500 for part in range_.split(","): 

501 part = part.strip() 

502 

503 # If the range is empty or a single dash, we ignore it. 

504 if not part or part == "-": 

505 continue 

506 

507 # If the range is not in the format "start-end", we ignore it. 

508 if "-" not in part: 

509 continue 

510 

511 start_str, end_str = part.split("-", 1) 

512 start_str = start_str.strip() 

513 end_str = end_str.strip() 

514 

515 try: 

516 start = int(start_str) if start_str else file_size - int(end_str) 

517 end = int(end_str) + 1 if start_str and end_str and int(end_str) < file_size else file_size 

518 ranges.append((start, end)) 

519 except ValueError: 

520 # If the range is not numeric, we ignore it. 

521 continue 

522 

523 return ranges 

524 

525 def generate_multipart( 

526 self, 

527 ranges: Sequence[tuple[int, int]], 

528 boundary: str, 

529 max_size: int, 

530 content_type: str, 

531 ) -> tuple[int, Callable[[int, int], bytes]]: 

532 r""" 

533 Multipart response headers generator. 

534 

535 ``` 

536 --{boundary}\r\n 

537 Content-Type: {content_type}\r\n 

538 Content-Range: bytes {start}-{end-1}/{max_size}\r\n 

539 \r\n 

540 ..........content...........\r\n 

541 --{boundary}\r\n 

542 Content-Type: {content_type}\r\n 

543 Content-Range: bytes {start}-{end-1}/{max_size}\r\n 

544 \r\n 

545 ..........content...........\r\n 

546 --{boundary}-- 

547 ``` 

548 """ 

549 boundary_len = len(boundary) 

550 static_header_part_len = 49 + boundary_len + len(content_type) + len(str(max_size)) 

551 content_length = sum( 

552 (len(str(start)) + len(str(end - 1)) + static_header_part_len) # Headers 

553 + (end - start) # Content 

554 for start, end in ranges 

555 ) + ( 

556 4 + boundary_len # --boundary-- 

557 ) 

558 return ( 

559 content_length, 

560 lambda start, end: ( 

561 f"""\ 

562--{boundary}\r 

563Content-Type: {content_type}\r 

564Content-Range: bytes {start}-{end - 1}/{max_size}\r 

565\r 

566""" 

567 ).encode("latin-1"), 

568 )