Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/starlette/responses.py: 23%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

322 statements  

1from __future__ import annotations 

2 

3import hashlib 

4import http.cookies 

5import json 

6import os 

7import stat 

8import sys 

9import warnings 

10from collections.abc import AsyncIterable, Awaitable, Iterable, Mapping, Sequence 

11from datetime import datetime 

12from email.utils import format_datetime, formatdate 

13from functools import partial 

14from mimetypes import guess_type 

15from secrets import token_hex 

16from typing import Any, Callable, Literal, Union 

17from urllib.parse import quote 

18 

19import anyio 

20import anyio.to_thread 

21 

22from starlette._utils import collapse_excgroups 

23from starlette.background import BackgroundTask 

24from starlette.concurrency import iterate_in_threadpool 

25from starlette.datastructures import URL, Headers, MutableHeaders 

26from starlette.requests import ClientDisconnect 

27from starlette.types import Receive, Scope, Send 

28 

29 

30class Response: 

31 media_type = None 

32 charset = "utf-8" 

33 

34 def __init__( 

35 self, 

36 content: Any = None, 

37 status_code: int = 200, 

38 headers: Mapping[str, str] | None = None, 

39 media_type: str | None = None, 

40 background: BackgroundTask | None = None, 

41 ) -> None: 

42 self.status_code = status_code 

43 if media_type is not None: 

44 self.media_type = media_type 

45 self.background = background 

46 self.body = self.render(content) 

47 self.init_headers(headers) 

48 

49 def render(self, content: Any) -> bytes | memoryview: 

50 if content is None: 

51 return b"" 

52 if isinstance(content, (bytes, memoryview)): 

53 return content 

54 return content.encode(self.charset) # type: ignore 

55 

56 def init_headers(self, headers: Mapping[str, str] | None = None) -> None: 

57 if headers is None: 

58 raw_headers: list[tuple[bytes, bytes]] = [] 

59 populate_content_length = True 

60 populate_content_type = True 

61 else: 

62 raw_headers = [(k.lower().encode("latin-1"), v.encode("latin-1")) for k, v in headers.items()] 

63 keys = [h[0] for h in raw_headers] 

64 populate_content_length = b"content-length" not in keys 

65 populate_content_type = b"content-type" not in keys 

66 

67 body = getattr(self, "body", None) 

68 if ( 

69 body is not None 

70 and populate_content_length 

71 and not (self.status_code < 200 or self.status_code in (204, 304)) 

72 ): 

73 content_length = str(len(body)) 

74 raw_headers.append((b"content-length", content_length.encode("latin-1"))) 

75 

76 content_type = self.media_type 

77 if content_type is not None and populate_content_type: 

78 if content_type.startswith("text/") and "charset=" not in content_type.lower(): 

79 content_type += "; charset=" + self.charset 

80 raw_headers.append((b"content-type", content_type.encode("latin-1"))) 

81 

82 self.raw_headers = raw_headers 

83 

84 @property 

85 def headers(self) -> MutableHeaders: 

86 if not hasattr(self, "_headers"): 

87 self._headers = MutableHeaders(raw=self.raw_headers) 

88 return self._headers 

89 

90 def set_cookie( 

91 self, 

92 key: str, 

93 value: str = "", 

94 max_age: int | None = None, 

95 expires: datetime | str | int | None = None, 

96 path: str | None = "/", 

97 domain: str | None = None, 

98 secure: bool = False, 

99 httponly: bool = False, 

100 samesite: Literal["lax", "strict", "none"] | None = "lax", 

101 partitioned: bool = False, 

102 ) -> None: 

103 cookie: http.cookies.BaseCookie[str] = http.cookies.SimpleCookie() 

104 cookie[key] = value 

105 if max_age is not None: 

106 cookie[key]["max-age"] = max_age 

107 if expires is not None: 

108 if isinstance(expires, datetime): 

109 cookie[key]["expires"] = format_datetime(expires, usegmt=True) 

110 else: 

111 cookie[key]["expires"] = expires 

112 if path is not None: 

113 cookie[key]["path"] = path 

114 if domain is not None: 

115 cookie[key]["domain"] = domain 

116 if secure: 

117 cookie[key]["secure"] = True 

118 if httponly: 

119 cookie[key]["httponly"] = True 

120 if samesite is not None: 

121 assert samesite.lower() in [ 

122 "strict", 

123 "lax", 

124 "none", 

125 ], "samesite must be either 'strict', 'lax' or 'none'" 

126 cookie[key]["samesite"] = samesite 

127 if partitioned: 

128 if sys.version_info < (3, 14): 

129 raise ValueError("Partitioned cookies are only supported in Python 3.14 and above.") # pragma: no cover 

130 cookie[key]["partitioned"] = True # pragma: no cover 

131 

132 cookie_val = cookie.output(header="").strip() 

133 self.raw_headers.append((b"set-cookie", cookie_val.encode("latin-1"))) 

134 

135 def delete_cookie( 

136 self, 

137 key: str, 

138 path: str = "/", 

139 domain: str | None = None, 

140 secure: bool = False, 

141 httponly: bool = False, 

142 samesite: Literal["lax", "strict", "none"] | None = "lax", 

143 ) -> None: 

144 self.set_cookie( 

145 key, 

146 max_age=0, 

147 expires=0, 

148 path=path, 

149 domain=domain, 

150 secure=secure, 

151 httponly=httponly, 

152 samesite=samesite, 

153 ) 

154 

155 async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: 

156 prefix = "websocket." if scope["type"] == "websocket" else "" 

157 await send( 

158 { 

159 "type": prefix + "http.response.start", 

160 "status": self.status_code, 

161 "headers": self.raw_headers, 

162 } 

163 ) 

164 await send({"type": prefix + "http.response.body", "body": self.body}) 

165 

166 if self.background is not None: 

167 await self.background() 

168 

169 

class HTMLResponse(Response):
    """``Response`` whose default media type is ``text/html``."""

    media_type = "text/html"

172 

173 

class PlainTextResponse(Response):
    """``Response`` whose default media type is ``text/plain``."""

    media_type = "text/plain"

176 

177 

class JSONResponse(Response):
    """``Response`` that serialises its content to compact UTF-8 JSON.

    Unlike the base class, ``content`` is a required argument.
    """

    media_type = "application/json"

    def __init__(
        self,
        content: Any,
        status_code: int = 200,
        headers: Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        super().__init__(
            content=content,
            status_code=status_code,
            headers=headers,
            media_type=media_type,
            background=background,
        )

    def render(self, content: Any) -> bytes:
        """Serialise ``content`` with ``json.dumps`` and encode it as UTF-8.

        Non-ASCII characters are emitted as-is, separators carry no
        whitespace, and NaN/Infinity values are rejected by ``json.dumps``
        (``allow_nan=False``).
        """
        serialized = json.dumps(
            content,
            ensure_ascii=False,
            allow_nan=False,
            indent=None,
            separators=(",", ":"),
        )
        return serialized.encode("utf-8")

199 

200 

class RedirectResponse(Response):
    """Empty-bodied ``Response`` that carries a ``location`` header.

    The status code defaults to 307.
    """

    def __init__(
        self,
        url: str | URL,
        status_code: int = 307,
        headers: Mapping[str, str] | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        super().__init__(content=b"", status_code=status_code, headers=headers, background=background)
        # Percent-encode the target, leaving the characters listed in ``safe`` untouched.
        location = quote(str(url), safe=":/%#?=@[]!$&'()*+,;")
        self.headers["location"] = location

211 

212 

# A single piece of streamed body content.
Content = Union[str, bytes, memoryview]
# Body content produced by a regular (synchronous) iterable.
SyncContentStream = Iterable[Content]
# Body content produced by an asynchronous iterable.
AsyncContentStream = AsyncIterable[Content]
# Either kind of stream.
ContentStream = Union[AsyncContentStream, SyncContentStream]

217 

218 

class StreamingResponse(Response):
    """``Response`` whose body is sent chunk by chunk from an iterator.

    No ``body`` attribute is set, so ``init_headers`` adds no
    ``content-length`` header of its own.
    """

    body_iterator: AsyncContentStream

    def __init__(
        self,
        content: ContentStream,
        status_code: int = 200,
        headers: Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        # Synchronous iterables are wrapped so the body can always be consumed
        # with ``async for``.
        self.body_iterator = content if isinstance(content, AsyncIterable) else iterate_in_threadpool(content)
        self.status_code = status_code
        self.media_type = media_type if media_type is not None else self.media_type
        self.background = background
        self.init_headers(headers)

    async def listen_for_disconnect(self, receive: Receive) -> None:
        """Return once an ``http.disconnect`` message has been received."""
        message = await receive()
        while message["type"] != "http.disconnect":
            message = await receive()

    async def stream_response(self, send: Send) -> None:
        """Send the start message, one body message per chunk, then an empty
        final body message with ``more_body`` set to ``False``."""
        await send(
            {
                "type": "http.response.start",
                "status": self.status_code,
                "headers": self.raw_headers,
            }
        )
        async for chunk in self.body_iterator:
            # str chunks are encoded with the response charset; binary chunks pass through.
            body = chunk if isinstance(chunk, (bytes, memoryview)) else chunk.encode(self.charset)
            await send({"type": "http.response.body", "body": body, "more_body": True})

        await send({"type": "http.response.body", "body": b"", "more_body": False})

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Stream the response, then await the background task.

        For ASGI spec versions >= 2.4 the body is streamed directly and an
        ``OSError`` is re-raised as ``ClientDisconnect``.  For older versions
        streaming runs in a task group alongside a disconnect listener, and
        whichever finishes first cancels the other.
        """
        asgi_info = scope.get("asgi", {})
        spec_version = tuple(int(piece) for piece in asgi_info.get("spec_version", "2.0").split("."))

        if spec_version < (2, 4):
            with collapse_excgroups():
                async with anyio.create_task_group() as task_group:

                    async def wrap(func: Callable[[], Awaitable[None]]) -> None:
                        await func()
                        task_group.cancel_scope.cancel()

                    task_group.start_soon(wrap, partial(self.stream_response, send))
                    await wrap(partial(self.listen_for_disconnect, receive))
        else:
            try:
                await self.stream_response(send)
            except OSError:
                raise ClientDisconnect()

        if self.background is not None:
            await self.background()

281 

282 

class MalformedRangeHeader(Exception):
    """Raised when a ``Range`` header cannot be parsed.

    Attributes:
        content: human-readable description of the problem.
    """

    def __init__(self, content: str = "Malformed range header.") -> None:
        # Forward the message to Exception so str(exc) and tracebacks show it
        # even when the default message is used.
        super().__init__(content)
        self.content = content

286 

287 

class RangeNotSatisfiable(Exception):
    """Raised when a requested range cannot be served.

    Attributes:
        max_size: the size value supplied by the raiser.
    """

    def __init__(self, max_size: int) -> None:
        self.max_size = max_size

291 

292 

class FileResponse(Response):
    """``Response`` that sends the contents of a file, with ``Range`` support.

    The file is stat'ed when the response is called, unless a ``stat_result``
    was passed to the constructor.  Depending on the request headers the file
    is sent whole, as a single byte range (206), or as a
    ``multipart/byteranges`` body (206).
    """

    # Maximum number of bytes read from the file per body message.
    chunk_size = 64 * 1024

    def __init__(
        self,
        path: str | os.PathLike[str],
        status_code: int = 200,
        headers: Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
        filename: str | None = None,
        stat_result: os.stat_result | None = None,
        method: str | None = None,
        content_disposition_type: str = "attachment",
    ) -> None:
        """Store the file path and prepare the static headers.

        Args:
            path: file to send.
            status_code: status used for non-range responses.
            headers: extra headers; these win over the generated defaults.
            media_type: explicit media type; otherwise guessed from
                ``filename or path``, falling back to ``text/plain``.
            background: task awaited after the response has been sent.
            filename: when given, a ``content-disposition`` header is added.
            stat_result: pre-computed ``os.stat`` result; skips the stat call
                made when the response is called.
            method: unused; passing it emits a ``DeprecationWarning``.
            content_disposition_type: disposition type placed before the filename.
        """
        self.path = path
        self.status_code = status_code
        self.filename = filename
        if method is not None:
            warnings.warn(
                "The 'method' parameter is not used, and it will be removed.",
                DeprecationWarning,
            )
        if media_type is None:
            media_type = guess_type(filename or path)[0] or "text/plain"
        self.media_type = media_type
        self.background = background
        self.init_headers(headers)
        self.headers.setdefault("accept-ranges", "bytes")
        if self.filename is not None:
            content_disposition_filename = quote(self.filename)
            if content_disposition_filename != self.filename:
                # Percent-encoding changed the name, so use the extended ``filename*`` form.
                content_disposition = f"{content_disposition_type}; filename*=utf-8''{content_disposition_filename}"
            else:
                content_disposition = f'{content_disposition_type}; filename="{self.filename}"'
            self.headers.setdefault("content-disposition", content_disposition)
        self.stat_result = stat_result
        if stat_result is not None:
            self.set_stat_headers(stat_result)

    def set_stat_headers(self, stat_result: os.stat_result) -> None:
        """Set ``content-length``, ``last-modified`` and ``etag`` from
        ``stat_result`` unless those headers are already present."""
        content_length = str(stat_result.st_size)
        last_modified = formatdate(stat_result.st_mtime, usegmt=True)
        # The ETag is an MD5 digest of "<mtime>-<size>"; it is not used for security.
        etag_base = str(stat_result.st_mtime) + "-" + str(stat_result.st_size)
        etag = f'"{hashlib.md5(etag_base.encode(), usedforsecurity=False).hexdigest()}"'

        self.headers.setdefault("content-length", content_length)
        self.headers.setdefault("last-modified", last_modified)
        self.headers.setdefault("etag", etag)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Send the file (or the requested ranges), then await the background task.

        Raises:
            RuntimeError: if the file does not exist or is not a regular file
                (only checked when no ``stat_result`` was supplied).
        """
        send_header_only: bool = scope["method"].upper() == "HEAD"
        send_pathsend: bool = "http.response.pathsend" in scope.get("extensions", {})

        if self.stat_result is None:
            try:
                stat_result = await anyio.to_thread.run_sync(os.stat, self.path)
                self.set_stat_headers(stat_result)
            except FileNotFoundError:
                raise RuntimeError(f"File at path {self.path} does not exist.")
            else:
                mode = stat_result.st_mode
                if not stat.S_ISREG(mode):
                    raise RuntimeError(f"File at path {self.path} is not a file.")
        else:
            stat_result = self.stat_result

        headers = Headers(scope=scope)
        http_range = headers.get("range")
        http_if_range = headers.get("if-range")

        # Without a Range header, or with an If-Range validator that no longer
        # matches, the whole file is sent.
        if http_range is None or (http_if_range is not None and not self._should_use_range(http_if_range)):
            await self._handle_simple(send, send_header_only, send_pathsend)
        else:
            try:
                ranges = self._parse_range_header(http_range, stat_result.st_size)
            except MalformedRangeHeader as exc:
                return await PlainTextResponse(exc.content, status_code=400)(scope, receive, send)
            except RangeNotSatisfiable as exc:
                # RFC 9110: Content-Range = range-unit SP unsatisfied-range,
                # i.e. "bytes */<complete-length>"; the unit is mandatory.
                response = PlainTextResponse(status_code=416, headers={"Content-Range": f"bytes */{exc.max_size}"})
                return await response(scope, receive, send)

            if len(ranges) == 1:
                start, end = ranges[0]
                await self._handle_single_range(send, start, end, stat_result.st_size, send_header_only)
            else:
                await self._handle_multiple_ranges(send, ranges, stat_result.st_size, send_header_only)

        if self.background is not None:
            await self.background()

    async def _handle_simple(self, send: Send, send_header_only: bool, send_pathsend: bool) -> None:
        """Send the complete file: headers only, via ``http.response.pathsend``,
        or by reading it in ``chunk_size`` pieces."""
        await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers})
        if send_header_only:
            await send({"type": "http.response.body", "body": b"", "more_body": False})
        elif send_pathsend:
            await send({"type": "http.response.pathsend", "path": str(self.path)})
        else:
            async with await anyio.open_file(self.path, mode="rb") as file:
                more_body = True
                while more_body:
                    chunk = await file.read(self.chunk_size)
                    # A short read means the end of the file was reached.
                    more_body = len(chunk) == self.chunk_size
                    await send({"type": "http.response.body", "body": chunk, "more_body": more_body})

    async def _handle_single_range(
        self, send: Send, start: int, end: int, file_size: int, send_header_only: bool
    ) -> None:
        """Send bytes ``start`` (inclusive) to ``end`` (exclusive) as a 206 response."""
        self.headers["content-range"] = f"bytes {start}-{end - 1}/{file_size}"
        self.headers["content-length"] = str(end - start)
        await send({"type": "http.response.start", "status": 206, "headers": self.raw_headers})
        if send_header_only:
            await send({"type": "http.response.body", "body": b"", "more_body": False})
        else:
            async with await anyio.open_file(self.path, mode="rb") as file:
                await file.seek(start)
                more_body = True
                while more_body:
                    chunk = await file.read(min(self.chunk_size, end - start))
                    start += len(chunk)
                    more_body = len(chunk) == self.chunk_size and start < end
                    await send({"type": "http.response.body", "body": chunk, "more_body": more_body})

    async def _handle_multiple_ranges(
        self,
        send: Send,
        ranges: list[tuple[int, int]],
        file_size: int,
        send_header_only: bool,
    ) -> None:
        """Send several ``(start, end)`` ranges (end exclusive) as one
        ``multipart/byteranges`` 206 response."""
        # In firefox and chrome, they use boundary with 95-96 bits entropy (that's roughly 13 bytes).
        boundary = token_hex(13)
        # Read the file's own content type first: it labels every part, and the
        # response-level header is replaced just below.
        content_length, header_generator = self.generate_multipart(
            ranges, boundary, file_size, self.headers["content-type"]
        )
        # RFC 9110 section 15.3.7.2: a multi-part 206 response carries
        # "Content-Type: multipart/byteranges; boundary=..." and must not carry
        # a response-level Content-Range header.
        self.headers["content-type"] = f"multipart/byteranges; boundary={boundary}"
        self.headers["content-length"] = str(content_length)
        await send({"type": "http.response.start", "status": 206, "headers": self.raw_headers})
        if send_header_only:
            await send({"type": "http.response.body", "body": b"", "more_body": False})
        else:
            async with await anyio.open_file(self.path, mode="rb") as file:
                for start, end in ranges:
                    await send({"type": "http.response.body", "body": header_generator(start, end), "more_body": True})
                    await file.seek(start)
                    while start < end:
                        chunk = await file.read(min(self.chunk_size, end - start))
                        if not chunk:
                            # The file is shorter than when it was stat'ed;
                            # stop instead of looping forever at end-of-file.
                            break
                        start += len(chunk)
                        await send({"type": "http.response.body", "body": chunk, "more_body": True})
                    await send({"type": "http.response.body", "body": b"\n", "more_body": True})
                # The newline after the last part has already been sent, so the
                # closing delimiter is exactly the ``5 + len(boundary)`` bytes
                # that generate_multipart() counted.
                await send(
                    {
                        "type": "http.response.body",
                        "body": f"--{boundary}--\n".encode("latin-1"),
                        "more_body": False,
                    }
                )

    def _should_use_range(self, http_if_range: str) -> bool:
        """Return True when the ``If-Range`` value equals this response's
        ``last-modified`` or ``etag`` header."""
        return http_if_range == self.headers["last-modified"] or http_if_range == self.headers["etag"]

    @classmethod
    def _parse_range_header(cls, http_range: str, file_size: int) -> list[tuple[int, int]]:
        """Parse a ``Range`` header into sorted, non-overlapping ranges.

        Returns:
            ``(start, end)`` tuples with ``end`` exclusive, in ascending order.

        Raises:
            MalformedRangeHeader: no ``=``, a unit other than ``bytes``, no
                usable range, or a range whose start is after its end.
            RangeNotSatisfiable: a range starts outside ``[0, file_size)``.
        """
        try:
            units, range_ = http_range.split("=", 1)
        except ValueError:
            raise MalformedRangeHeader()

        units = units.strip().lower()

        if units != "bytes":
            raise MalformedRangeHeader("Only support bytes range")

        ranges = cls._parse_ranges(range_, file_size)

        if not ranges:
            raise MalformedRangeHeader("Range header: range must be requested")

        if any(not (0 <= start < file_size) for start, _ in ranges):
            raise RangeNotSatisfiable(file_size)

        if any(start > end for start, end in ranges):
            raise MalformedRangeHeader("Range header: start must be less than end")

        if len(ranges) == 1:
            return ranges

        # Merge overlapping or touching ranges.  Sorting first makes one linear
        # pass enough, which also catches chains of overlaps (a range bridging
        # two earlier ones) and avoids quadratic work on headers that list many
        # ranges.
        ranges.sort()
        merged: list[tuple[int, int]] = [ranges[0]]
        for start, end in ranges[1:]:
            last_start, last_end = merged[-1]
            if start <= last_end:
                merged[-1] = (last_start, max(last_end, end))
            else:
                merged.append((start, end))

        return merged

    @classmethod
    def _parse_ranges(cls, range_: str, file_size: int) -> list[tuple[int, int]]:
        """Parse the comma-separated range specs that follow ``bytes=``.

        Parts that are empty, a lone dash, missing a dash or non-numeric are
        ignored.  Returned tuples are ``(start, end)`` with ``end`` exclusive;
        no bounds checking beyond clamping to ``file_size`` is done here.
        """
        ranges: list[tuple[int, int]] = []

        for part in range_.split(","):
            part = part.strip()

            # If the range is empty or a single dash, we ignore it.
            if not part or part == "-":
                continue

            # If the range is not in the format "start-end", we ignore it.
            if "-" not in part:
                continue

            start_str, end_str = part.split("-", 1)
            start_str = start_str.strip()
            end_str = end_str.strip()

            try:
                if start_str:
                    start = int(start_str)
                    # "N-" or a last byte at/after the end of the file both run to the end.
                    end = int(end_str) + 1 if end_str and int(end_str) < file_size else file_size
                else:
                    # Suffix form "-N": the last N bytes.  When N exceeds the
                    # file size the whole file is selected (RFC 9110 section 14.1.2).
                    start = max(file_size - int(end_str), 0)
                    end = file_size
            except ValueError:
                # If the range is not numeric, we ignore it.
                continue

            ranges.append((start, end))

        return ranges

    def generate_multipart(
        self,
        ranges: Sequence[tuple[int, int]],
        boundary: str,
        max_size: int,
        content_type: str,
    ) -> tuple[int, Callable[[int, int], bytes]]:
        r"""
        Multipart response headers generator.

        ```
        --{boundary}\n
        Content-Type: {content_type}\n
        Content-Range: bytes {start}-{end-1}/{max_size}\n
        \n
        ..........content...........\n
        --{boundary}\n
        Content-Type: {content_type}\n
        Content-Range: bytes {start}-{end-1}/{max_size}\n
        \n
        ..........content...........\n
        --{boundary}--\n
        ```

        Returns the total body length for the layout above together with a
        function that builds the encoded header of one part.
        """
        # NOTE(review): parts are separated with bare "\n"; MIME multipart
        # specifies CRLF.  Switching would also change the length constants
        # below, so it is left as-is here.
        boundary_len = len(boundary)
        # 44 = fixed characters of one part header (43) + the newline that follows the part's content.
        static_header_part_len = 44 + boundary_len + len(content_type) + len(str(max_size))
        content_length = sum(
            (len(str(start)) + len(str(end - 1)) + static_header_part_len)  # Headers
            + (end - start)  # Content
            for start, end in ranges
        ) + (
            5 + boundary_len  # --boundary--\n
        )
        return (
            content_length,
            lambda start, end: (
                f"--{boundary}\nContent-Type: {content_type}\nContent-Range: bytes {start}-{end - 1}/{max_size}\n\n"
            ).encode("latin-1"),
        )