Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/starlette/responses.py: 23%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

322 statements  

1from __future__ import annotations 

2 

3import hashlib 

4import http.cookies 

5import json 

6import os 

7import stat 

8import sys 

9import warnings 

10from collections.abc import AsyncIterable, Awaitable, Callable, Iterable, Mapping, Sequence 

11from datetime import datetime 

12from email.utils import format_datetime, formatdate 

13from functools import partial 

14from mimetypes import guess_type 

15from secrets import token_hex 

16from typing import Any, Literal 

17from urllib.parse import quote 

18 

19import anyio 

20import anyio.to_thread 

21 

22from starlette._utils import collapse_excgroups 

23from starlette.background import BackgroundTask 

24from starlette.concurrency import iterate_in_threadpool 

25from starlette.datastructures import URL, Headers, MutableHeaders 

26from starlette.requests import ClientDisconnect 

27from starlette.types import Receive, Scope, Send 

28 

29 

30class Response: 

31 media_type = None 

32 charset = "utf-8" 

33 

34 def __init__( 

35 self, 

36 content: Any = None, 

37 status_code: int = 200, 

38 headers: Mapping[str, str] | None = None, 

39 media_type: str | None = None, 

40 background: BackgroundTask | None = None, 

41 ) -> None: 

42 self.status_code = status_code 

43 if media_type is not None: 

44 self.media_type = media_type 

45 self.background = background 

46 self.body = self.render(content) 

47 self.init_headers(headers) 

48 

49 def render(self, content: Any) -> bytes | memoryview: 

50 if content is None: 

51 return b"" 

52 if isinstance(content, bytes | memoryview): 

53 return content 

54 return content.encode(self.charset) # type: ignore 

55 

56 def init_headers(self, headers: Mapping[str, str] | None = None) -> None: 

57 if headers is None: 

58 raw_headers: list[tuple[bytes, bytes]] = [] 

59 populate_content_length = True 

60 populate_content_type = True 

61 else: 

62 raw_headers = [(k.lower().encode("latin-1"), v.encode("latin-1")) for k, v in headers.items()] 

63 keys = [h[0] for h in raw_headers] 

64 populate_content_length = b"content-length" not in keys 

65 populate_content_type = b"content-type" not in keys 

66 

67 body = getattr(self, "body", None) 

68 if ( 

69 body is not None 

70 and populate_content_length 

71 and not (self.status_code < 200 or self.status_code in (204, 304)) 

72 ): 

73 content_length = str(len(body)) 

74 raw_headers.append((b"content-length", content_length.encode("latin-1"))) 

75 

76 content_type = self.media_type 

77 if content_type is not None and populate_content_type: 

78 if content_type.startswith("text/") and "charset=" not in content_type.lower(): 

79 content_type += "; charset=" + self.charset 

80 raw_headers.append((b"content-type", content_type.encode("latin-1"))) 

81 

82 self.raw_headers = raw_headers 

83 

84 @property 

85 def headers(self) -> MutableHeaders: 

86 if not hasattr(self, "_headers"): 

87 self._headers = MutableHeaders(raw=self.raw_headers) 

88 return self._headers 

89 

90 def set_cookie( 

91 self, 

92 key: str, 

93 value: str = "", 

94 max_age: int | None = None, 

95 expires: datetime | str | int | None = None, 

96 path: str | None = "/", 

97 domain: str | None = None, 

98 secure: bool = False, 

99 httponly: bool = False, 

100 samesite: Literal["lax", "strict", "none"] | None = "lax", 

101 partitioned: bool = False, 

102 ) -> None: 

103 cookie: http.cookies.BaseCookie[str] = http.cookies.SimpleCookie() 

104 cookie[key] = value 

105 if max_age is not None: 

106 cookie[key]["max-age"] = max_age 

107 if expires is not None: 

108 if isinstance(expires, datetime): 

109 cookie[key]["expires"] = format_datetime(expires, usegmt=True) 

110 else: 

111 cookie[key]["expires"] = expires 

112 if path is not None: 

113 cookie[key]["path"] = path 

114 if domain is not None: 

115 cookie[key]["domain"] = domain 

116 if secure: 

117 cookie[key]["secure"] = True 

118 if httponly: 

119 cookie[key]["httponly"] = True 

120 if samesite is not None: 

121 assert samesite.lower() in [ 

122 "strict", 

123 "lax", 

124 "none", 

125 ], "samesite must be either 'strict', 'lax' or 'none'" 

126 cookie[key]["samesite"] = samesite 

127 if partitioned: 

128 if sys.version_info < (3, 14): 

129 raise ValueError("Partitioned cookies are only supported in Python 3.14 and above.") # pragma: no cover 

130 cookie[key]["partitioned"] = True # pragma: no cover 

131 

132 cookie_val = cookie.output(header="").strip() 

133 self.raw_headers.append((b"set-cookie", cookie_val.encode("latin-1"))) 

134 

135 def delete_cookie( 

136 self, 

137 key: str, 

138 path: str = "/", 

139 domain: str | None = None, 

140 secure: bool = False, 

141 httponly: bool = False, 

142 samesite: Literal["lax", "strict", "none"] | None = "lax", 

143 ) -> None: 

144 self.set_cookie( 

145 key, 

146 max_age=0, 

147 expires=0, 

148 path=path, 

149 domain=domain, 

150 secure=secure, 

151 httponly=httponly, 

152 samesite=samesite, 

153 ) 

154 

155 async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: 

156 prefix = "websocket." if scope["type"] == "websocket" else "" 

157 await send( 

158 { 

159 "type": prefix + "http.response.start", 

160 "status": self.status_code, 

161 "headers": self.raw_headers, 

162 } 

163 ) 

164 await send({"type": prefix + "http.response.body", "body": self.body}) 

165 

166 if self.background is not None: 

167 await self.background() 

168 

169 

class HTMLResponse(Response):
    """Response whose default media type is ``text/html`` (charset appended by the base class)."""

    media_type = "text/html"

172 

173 

class PlainTextResponse(Response):
    """Response whose default media type is ``text/plain`` (charset appended by the base class)."""

    media_type = "text/plain"

176 

177 

class JSONResponse(Response):
    """Response that serializes ``content`` to compact UTF-8 JSON."""

    media_type = "application/json"

    def __init__(
        self,
        content: Any,
        status_code: int = 200,
        headers: Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        # Re-declared only so that ``content`` is a required argument here.
        super().__init__(content, status_code, headers, media_type, background)

    def render(self, content: Any) -> bytes:
        """Return ``content`` as JSON bytes.

        Output has no indentation and no spaces after separators; non-ASCII
        characters are kept verbatim (then UTF-8 encoded); NaN/Infinity are
        rejected by ``json.dumps`` with ``ValueError``.
        """
        serialized = json.dumps(
            content,
            ensure_ascii=False,
            allow_nan=False,
            indent=None,
            separators=(",", ":"),
        )
        return serialized.encode("utf-8")

199 

200 

class RedirectResponse(Response):
    """Empty-bodied response that sends the client to ``url`` via the ``location`` header.

    Defaults to ``307`` (temporary redirect, method preserved).
    """

    def __init__(
        self,
        url: str | URL,
        status_code: int = 307,
        headers: Mapping[str, str] | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        super().__init__(content=b"", status_code=status_code, headers=headers, background=background)
        # Percent-encode the URL while leaving URL delimiters and existing %-escapes intact.
        location = quote(str(url), safe=":/%#?=@[]!$&'()*+,;")
        self.headers["location"] = location

211 

212 

213Content = str | bytes | memoryview 

214SyncContentStream = Iterable[Content] 

215AsyncContentStream = AsyncIterable[Content] 

216ContentStream = AsyncContentStream | SyncContentStream 

217 

218 

class StreamingResponse(Response):
    """Response whose body is sent chunk by chunk from a sync or async iterable.

    ``str`` chunks are encoded with ``charset``; ``bytes``/``memoryview`` chunks
    are sent unchanged.  ``body`` is never set, so the base ``init_headers``
    adds no ``content-length``.
    """

    body_iterator: AsyncContentStream

    def __init__(
        self,
        content: ContentStream,
        status_code: int = 200,
        headers: Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        if isinstance(content, AsyncIterable):
            self.body_iterator = content
        else:
            # Sync iterables are adapted with iterate_in_threadpool (presumably so that
            # blocking iteration runs off the event loop — see starlette.concurrency).
            self.body_iterator = iterate_in_threadpool(content)
        self.status_code = status_code
        self.media_type = self.media_type if media_type is None else media_type
        self.background = background
        self.init_headers(headers)

    async def listen_for_disconnect(self, receive: Receive) -> None:
        """Consume incoming messages until an ``http.disconnect`` arrives."""
        while True:
            message = await receive()
            if message["type"] == "http.disconnect":
                break

    async def stream_response(self, send: Send) -> None:
        """Send the start message, one body message per chunk, then an empty final message."""
        await send(
            {
                "type": "http.response.start",
                "status": self.status_code,
                "headers": self.raw_headers,
            }
        )
        async for chunk in self.body_iterator:
            if not isinstance(chunk, bytes | memoryview):
                chunk = chunk.encode(self.charset)
            await send({"type": "http.response.body", "body": chunk, "more_body": True})

        await send({"type": "http.response.body", "body": b"", "more_body": False})

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Stream the response, stopping early if the client disconnects.

        Raises:
            ClientDisconnect: on ASGI >= 2.4 when ``send`` raises ``OSError``.
        """
        spec_version = tuple(map(int, scope.get("asgi", {}).get("spec_version", "2.0").split(".")))

        if spec_version >= (2, 4):
            # Servers implementing ASGI spec 2.4+ report a disconnect by raising OSError from
            # send(), so no separate disconnect listener is needed.
            try:
                await self.stream_response(send)
            except OSError as exc:
                raise ClientDisconnect() from exc
        else:
            # Older servers: race streaming against a disconnect listener; whichever finishes
            # first cancels the other through the shared task group's cancel scope.
            with collapse_excgroups():
                async with anyio.create_task_group() as task_group:

                    async def wrap(func: Callable[[], Awaitable[None]]) -> None:
                        await func()
                        task_group.cancel_scope.cancel()

                    task_group.start_soon(wrap, partial(self.stream_response, send))
                    await wrap(partial(self.listen_for_disconnect, receive))

        if self.background is not None:
            await self.background()

281 

282 

class MalformedRangeHeader(Exception):
    """Raised when a ``Range`` request header cannot be parsed or is invalid.

    ``content`` holds the human-readable reason (FileResponse sends it back as a
    400 body).
    """

    def __init__(self, content: str = "Malformed range header.") -> None:
        # Forward the message to Exception so ``str(exc)`` and ``exc.args`` carry it.
        super().__init__(content)
        self.content = content

286 

287 

class RangeNotSatisfiable(Exception):
    """Raised when a requested byte range lies outside the file.

    ``max_size`` is the file size, reported back in the 416 ``Content-Range``.
    """

    def __init__(self, max_size: int) -> None:
        # Forward to Exception so ``args`` is populated; pickling re-creates the exception
        # as ``cls(*args)`` and would otherwise fail for lack of ``max_size``.
        super().__init__(max_size)
        self.max_size = max_size

291 

292 

class FileResponse(Response):
    """Serve a file from disk, honouring ``HEAD``, ``Range`` and ``If-Range``.

    Headers derived from ``os.stat`` (``content-length``, ``last-modified``,
    ``etag``) are set at construction when ``stat_result`` is given, otherwise
    on the first call.  Range requests are answered with ``206``: a single
    range carries a ``Content-Range`` header, several ranges are sent as a
    ``multipart/byteranges`` body.
    """

    chunk_size = 64 * 1024

    def __init__(
        self,
        path: str | os.PathLike[str],
        status_code: int = 200,
        headers: Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
        filename: str | None = None,
        stat_result: os.stat_result | None = None,
        method: str | None = None,
        content_disposition_type: str = "attachment",
    ) -> None:
        self.path = path
        self.status_code = status_code
        self.filename = filename
        if method is not None:
            warnings.warn(
                "The 'method' parameter is not used, and it will be removed.",
                DeprecationWarning,
                stacklevel=2,
            )
        if media_type is None:
            # Guess from the download name if given, else from the path; fall back to text/plain.
            media_type = guess_type(filename or path)[0] or "text/plain"
        self.media_type = media_type
        self.background = background
        self.init_headers(headers)
        self.headers.setdefault("accept-ranges", "bytes")
        if self.filename is not None:
            content_disposition_filename = quote(self.filename)
            if content_disposition_filename != self.filename:
                # Characters needing escaping: use the extended ``filename*`` (RFC 5987) form.
                content_disposition = f"{content_disposition_type}; filename*=utf-8''{content_disposition_filename}"
            else:
                content_disposition = f'{content_disposition_type}; filename="{self.filename}"'
            self.headers.setdefault("content-disposition", content_disposition)
        self.stat_result = stat_result
        if stat_result is not None:
            self.set_stat_headers(stat_result)

    def set_stat_headers(self, stat_result: os.stat_result) -> None:
        """Default ``content-length``, ``last-modified`` and ``etag`` from a stat result.

        The ETag is an MD5 of ``"<mtime>-<size>"``, not of the file contents.
        Headers already present are left untouched.
        """
        content_length = str(stat_result.st_size)
        last_modified = formatdate(stat_result.st_mtime, usegmt=True)
        etag_base = str(stat_result.st_mtime) + "-" + str(stat_result.st_size)
        etag = f'"{hashlib.md5(etag_base.encode(), usedforsecurity=False).hexdigest()}"'

        self.headers.setdefault("content-length", content_length)
        self.headers.setdefault("last-modified", last_modified)
        self.headers.setdefault("etag", etag)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Send the file (or the requested ranges of it), then run the background task.

        Raises:
            RuntimeError: if the path does not exist or is not a regular file.
        """
        send_header_only: bool = scope["method"].upper() == "HEAD"
        send_pathsend: bool = "http.response.pathsend" in scope.get("extensions", {})

        if self.stat_result is None:
            try:
                stat_result = await anyio.to_thread.run_sync(os.stat, self.path)
                self.set_stat_headers(stat_result)
            except FileNotFoundError:
                raise RuntimeError(f"File at path {self.path} does not exist.")
            else:
                mode = stat_result.st_mode
                if not stat.S_ISREG(mode):
                    raise RuntimeError(f"File at path {self.path} is not a file.")
        else:
            stat_result = self.stat_result

        headers = Headers(scope=scope)
        http_range = headers.get("range")
        http_if_range = headers.get("if-range")

        # No Range, or an If-Range validator that no longer matches: send the full file.
        if http_range is None or (http_if_range is not None and not self._should_use_range(http_if_range)):
            await self._handle_simple(send, send_header_only, send_pathsend)
        else:
            try:
                ranges = self._parse_range_header(http_range, stat_result.st_size)
            except MalformedRangeHeader as exc:
                return await PlainTextResponse(exc.content, status_code=400)(scope, receive, send)
            except RangeNotSatisfiable as exc:
                # RFC 9110 §14.4: an unsatisfied range is reported as "bytes */<complete-length>".
                response = PlainTextResponse(status_code=416, headers={"Content-Range": f"bytes */{exc.max_size}"})
                return await response(scope, receive, send)

            if len(ranges) == 1:
                start, end = ranges[0]
                await self._handle_single_range(send, start, end, stat_result.st_size, send_header_only)
            else:
                await self._handle_multiple_ranges(send, ranges, stat_result.st_size, send_header_only)

        if self.background is not None:
            await self.background()

    async def _handle_simple(self, send: Send, send_header_only: bool, send_pathsend: bool) -> None:
        """Send the whole file: headers only for HEAD, via pathsend if offered, else in chunks."""
        await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers})
        if send_header_only:
            await send({"type": "http.response.body", "body": b"", "more_body": False})
        elif send_pathsend:
            await send({"type": "http.response.pathsend", "path": str(self.path)})
        else:
            async with await anyio.open_file(self.path, mode="rb") as file:
                more_body = True
                while more_body:
                    chunk = await file.read(self.chunk_size)
                    # A short (or empty) read means end of file.
                    more_body = len(chunk) == self.chunk_size
                    await send({"type": "http.response.body", "body": chunk, "more_body": more_body})

    async def _handle_single_range(
        self, send: Send, start: int, end: int, file_size: int, send_header_only: bool
    ) -> None:
        """Send bytes ``[start, end)`` as a 206 with a ``Content-Range`` header."""
        # Content-Range uses an inclusive last-byte position, hence ``end - 1``.
        self.headers["content-range"] = f"bytes {start}-{end - 1}/{file_size}"
        self.headers["content-length"] = str(end - start)
        await send({"type": "http.response.start", "status": 206, "headers": self.raw_headers})
        if send_header_only:
            await send({"type": "http.response.body", "body": b"", "more_body": False})
        else:
            async with await anyio.open_file(self.path, mode="rb") as file:
                await file.seek(start)
                more_body = True
                while more_body:
                    chunk = await file.read(min(self.chunk_size, end - start))
                    start += len(chunk)
                    more_body = len(chunk) == self.chunk_size and start < end
                    await send({"type": "http.response.body", "body": chunk, "more_body": more_body})

    async def _handle_multiple_ranges(
        self,
        send: Send,
        ranges: list[tuple[int, int]],
        file_size: int,
        send_header_only: bool,
    ) -> None:
        """Send several ranges as a 206 ``multipart/byteranges`` body."""
        # In firefox and chrome, they use boundary with 95-96 bits entropy (that's roughly 13 bytes).
        boundary = token_hex(13)
        # Read the file's own content type before it is replaced below; each part repeats it.
        content_length, header_generator = self.generate_multipart(
            ranges, boundary, file_size, self.headers["content-type"]
        )
        # RFC 9110 §14.6: the multipart media type and boundary belong in Content-Type, and a
        # multipart 206 response must not carry a top-level Content-Range header.
        self.headers["content-type"] = f"multipart/byteranges; boundary={boundary}"
        self.headers["content-length"] = str(content_length)
        await send({"type": "http.response.start", "status": 206, "headers": self.raw_headers})
        if send_header_only:
            await send({"type": "http.response.body", "body": b"", "more_body": False})
        else:
            async with await anyio.open_file(self.path, mode="rb") as file:
                for start, end in ranges:
                    await send({"type": "http.response.body", "body": header_generator(start, end), "more_body": True})
                    await file.seek(start)
                    while start < end:
                        chunk = await file.read(min(self.chunk_size, end - start))
                        if not chunk:
                            # File shrank since it was stat'ed; stop instead of looping forever.
                            break
                        start += len(chunk)
                        await send({"type": "http.response.body", "body": chunk, "more_body": True})
                    await send({"type": "http.response.body", "body": b"\n", "more_body": True})
            # The "\n" sent after the last part already ends its line, so the close-delimiter
            # follows directly. This must match the length counted by generate_multipart().
            await send(
                {
                    "type": "http.response.body",
                    "body": f"--{boundary}--\n".encode("latin-1"),
                    "more_body": False,
                }
            )

    def _should_use_range(self, http_if_range: str) -> bool:
        """Return True if the If-Range validator matches our last-modified or etag exactly."""
        return http_if_range == self.headers["last-modified"] or http_if_range == self.headers["etag"]

    @classmethod
    def _parse_range_header(cls, http_range: str, file_size: int) -> list[tuple[int, int]]:
        """Parse a ``Range`` header into sorted, non-overlapping half-open ``(start, end)`` ranges.

        Raises:
            MalformedRangeHeader: bad syntax, non-``bytes`` unit, no usable range, or an
                empty/inverted range.
            RangeNotSatisfiable: a range starts outside ``[0, file_size)``.
        """
        try:
            units, range_ = http_range.split("=", 1)
        except ValueError:
            raise MalformedRangeHeader() from None

        units = units.strip().lower()

        if units != "bytes":
            raise MalformedRangeHeader("Only support bytes range")

        ranges = cls._parse_ranges(range_, file_size)

        if len(ranges) == 0:
            raise MalformedRangeHeader("Range header: range must be requested")

        if any(not (0 <= start < file_size) for start, _ in ranges):
            raise RangeNotSatisfiable(file_size)

        # ``end`` is exclusive, so ``start == end`` (e.g. "bytes=5-4") is an empty, invalid range.
        if any(start >= end for start, end in ranges):
            raise MalformedRangeHeader("Range header: start must be less than end")

        # Coalesce overlapping or adjacent ranges. Sorting first lets one pass merge
        # transitively (e.g. 0-9, 20-29, 5-24 -> a single range 0-29).
        merged: list[tuple[int, int]] = []
        for start, end in sorted(ranges):
            if merged and start <= merged[-1][1]:
                prev_start, prev_end = merged[-1]
                merged[-1] = (prev_start, max(prev_end, end))
            else:
                merged.append((start, end))
        return merged

    @classmethod
    def _parse_ranges(cls, range_: str, file_size: int) -> list[tuple[int, int]]:
        """Turn the comma-separated range specs into half-open ``(start, end)`` tuples.

        Empty, dash-only, dash-less and non-numeric specs are skipped.  ``end`` is
        clamped to ``file_size``; bounds are not otherwise validated here.
        """
        ranges: list[tuple[int, int]] = []

        for part in range_.split(","):
            part = part.strip()

            # If the range is empty or a single dash, we ignore it.
            if not part or part == "-":
                continue

            # If the range is not in the format "start-end", we ignore it.
            if "-" not in part:
                continue

            start_str, end_str = part.split("-", 1)
            start_str = start_str.strip()
            end_str = end_str.strip()

            try:
                if start_str:
                    start = int(start_str)
                else:
                    # Suffix range "-N": the last N bytes. RFC 9110 §14.1.3: a suffix longer
                    # than the file selects the whole file, so clamp the start at 0.
                    start = max(file_size - int(end_str), 0)
                end = int(end_str) + 1 if start_str and end_str and int(end_str) < file_size else file_size
                ranges.append((start, end))
            except ValueError:
                # If the range is not numeric, we ignore it.
                continue

        return ranges

    def generate_multipart(
        self,
        ranges: Sequence[tuple[int, int]],
        boundary: str,
        max_size: int,
        content_type: str,
    ) -> tuple[int, Callable[[int, int], bytes]]:
        r"""
        Multipart response headers generator.

        Returns the total body length and a function producing one part header.
        The body sent by ``_handle_multiple_ranges`` has this layout:

        ```
        --{boundary}\n
        Content-Type: {content_type}\n
        Content-Range: bytes {start}-{end-1}/{max_size}\n
        \n
        ..........content...........\n
        --{boundary}\n
        Content-Type: {content_type}\n
        Content-Range: bytes {start}-{end-1}/{max_size}\n
        \n
        ..........content...........\n
        --{boundary}--\n
        ```
        """
        boundary_len = len(boundary)
        # 44 = the 43 fixed characters of one part header plus the "\n" that follows its content.
        static_header_part_len = 44 + boundary_len + len(content_type) + len(str(max_size))
        content_length = sum(
            (len(str(start)) + len(str(end - 1)) + static_header_part_len)  # Headers
            + (end - start)  # Content
            for start, end in ranges
        ) + (
            5 + boundary_len  # --boundary--\n
        )
        return (
            content_length,
            lambda start, end: (
                f"--{boundary}\nContent-Type: {content_type}\nContent-Range: bytes {start}-{end - 1}/{max_size}\n\n"
            ).encode("latin-1"),
        )