Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/starlette/responses.py: 24%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

305 statements  

1from __future__ import annotations 

2 

3import hashlib 

4import http.cookies 

5import json 

6import os 

7import re 

8import stat 

9import sys 

10import warnings 

11from collections.abc import AsyncIterable, Awaitable, Iterable, Mapping, Sequence 

12from datetime import datetime 

13from email.utils import format_datetime, formatdate 

14from functools import partial 

15from mimetypes import guess_type 

16from secrets import token_hex 

17from typing import Any, Callable, Literal, Union 

18from urllib.parse import quote 

19 

20import anyio 

21import anyio.to_thread 

22 

23from starlette._utils import collapse_excgroups 

24from starlette.background import BackgroundTask 

25from starlette.concurrency import iterate_in_threadpool 

26from starlette.datastructures import URL, Headers, MutableHeaders 

27from starlette.requests import ClientDisconnect 

28from starlette.types import Receive, Scope, Send 

29 

30 

class Response:
    """ASGI response whose whole body is rendered up front and sent in one message.

    Subclasses usually override the ``media_type`` class attribute and/or
    ``render`` to control how ``content`` is turned into the body bytes.
    """

    media_type = None
    charset = "utf-8"

    def __init__(
        self,
        content: Any = None,
        status_code: int = 200,
        headers: Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        """Render *content* into ``self.body`` and build ``self.raw_headers``.

        A non-``None`` *media_type* overrides the class-level default.
        *background*, if given, is awaited after the body has been sent.
        """
        self.status_code = status_code
        if media_type is not None:
            self.media_type = media_type
        self.background = background
        self.body = self.render(content)
        self.init_headers(headers)

    def render(self, content: Any) -> bytes | memoryview:
        """Convert *content* to the body.

        ``None`` becomes ``b""``. ``bytes`` and ``memoryview`` are returned
        unchanged. Anything else is ``encode``-d with ``self.charset``.
        """
        if content is None:
            return b""
        if isinstance(content, (bytes, memoryview)):
            return content
        return content.encode(self.charset)  # type: ignore

    def init_headers(self, headers: Mapping[str, str] | None = None) -> None:
        """Build ``self.raw_headers`` from *headers* (names lower-cased, latin-1 encoded).

        ``content-length`` is added when a ``body`` attribute exists, the caller
        did not supply one, and the status code permits a body (not 1xx, 204 or
        304). ``content-type`` is added from ``media_type`` unless the caller
        supplied one; ``text/*`` types get ``; charset=<charset>`` appended when
        they carry no charset yet.
        """
        if headers is None:
            raw_headers: list[tuple[bytes, bytes]] = []
            populate_content_length = True
            populate_content_type = True
        else:
            raw_headers = [(k.lower().encode("latin-1"), v.encode("latin-1")) for k, v in headers.items()]
            keys = [h[0] for h in raw_headers]
            populate_content_length = b"content-length" not in keys
            populate_content_type = b"content-type" not in keys

        # Subclasses that stream (and therefore never set ``body``) get no content-length.
        body = getattr(self, "body", None)
        if (
            body is not None
            and populate_content_length
            and not (self.status_code < 200 or self.status_code in (204, 304))
        ):
            content_length = str(len(body))
            raw_headers.append((b"content-length", content_length.encode("latin-1")))

        content_type = self.media_type
        if content_type is not None and populate_content_type:
            if content_type.startswith("text/") and "charset=" not in content_type.lower():
                content_type += "; charset=" + self.charset
            raw_headers.append((b"content-type", content_type.encode("latin-1")))

        self.raw_headers = raw_headers

    @property
    def headers(self) -> MutableHeaders:
        """Return a lazily-created, cached ``MutableHeaders`` built over ``self.raw_headers``."""
        if not hasattr(self, "_headers"):
            self._headers = MutableHeaders(raw=self.raw_headers)
        return self._headers

    def set_cookie(
        self,
        key: str,
        value: str = "",
        max_age: int | None = None,
        expires: datetime | str | int | None = None,
        path: str | None = "/",
        domain: str | None = None,
        secure: bool = False,
        httponly: bool = False,
        samesite: Literal["lax", "strict", "none"] | None = "lax",
        partitioned: bool = False,
    ) -> None:
        """Append a ``set-cookie`` header rendered by ``http.cookies.SimpleCookie``.

        A ``datetime`` *expires* is formatted as a GMT HTTP date. A ``str`` or
        ``int`` *expires* is handed to the cookie morsel unchanged.
        ``path``, ``domain`` and ``samesite`` are omitted when ``None``.

        Raises:
            AssertionError: *samesite* is not ``"strict"``, ``"lax"`` or
                ``"none"`` (case-insensitive).
            ValueError: *partitioned* is requested on Python < 3.14.
        """
        cookie: http.cookies.BaseCookie[str] = http.cookies.SimpleCookie()
        cookie[key] = value
        if max_age is not None:
            cookie[key]["max-age"] = max_age
        if expires is not None:
            if isinstance(expires, datetime):
                cookie[key]["expires"] = format_datetime(expires, usegmt=True)
            else:
                cookie[key]["expires"] = expires
        if path is not None:
            cookie[key]["path"] = path
        if domain is not None:
            cookie[key]["domain"] = domain
        if secure:
            cookie[key]["secure"] = True
        if httponly:
            cookie[key]["httponly"] = True
        if samesite is not None:
            # Explicit raise instead of ``assert`` so the validation still runs
            # under ``python -O``; AssertionError is kept for existing callers.
            if samesite.lower() not in ("strict", "lax", "none"):
                raise AssertionError("samesite must be either 'strict', 'lax' or 'none'")
            cookie[key]["samesite"] = samesite
        if partitioned:
            if sys.version_info < (3, 14):
                raise ValueError("Partitioned cookies are only supported in Python 3.14 and above.")  # pragma: no cover
            cookie[key]["partitioned"] = True  # pragma: no cover

        cookie_val = cookie.output(header="").strip()
        self.raw_headers.append((b"set-cookie", cookie_val.encode("latin-1")))

    def delete_cookie(
        self,
        key: str,
        path: str = "/",
        domain: str | None = None,
        secure: bool = False,
        httponly: bool = False,
        samesite: Literal["lax", "strict", "none"] | None = "lax",
        partitioned: bool = False,
    ) -> None:
        """Expire cookie *key* by re-setting it empty with ``max-age=0`` and ``expires=0``.

        *path* and *domain* should match the values used when the cookie was set,
        because user agents identify an existing cookie by name, domain and path
        (RFC 6265 §5.3). *partitioned* (default ``False``) is forwarded so that a
        cookie originally set with the ``Partitioned`` attribute can be expired
        too. Like ``set_cookie``, it requires Python 3.14+.
        """
        self.set_cookie(
            key,
            max_age=0,
            expires=0,
            path=path,
            domain=domain,
            secure=secure,
            httponly=httponly,
            samesite=samesite,
            partitioned=partitioned,
        )

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Send the start message and the whole body, then await ``background`` if set.

        For a ``websocket`` scope, the message types are prefixed with ``websocket.``.
        """
        prefix = "websocket." if scope["type"] == "websocket" else ""
        await send(
            {
                "type": prefix + "http.response.start",
                "status": self.status_code,
                "headers": self.raw_headers,
            }
        )
        await send({"type": prefix + "http.response.body", "body": self.body})

        if self.background is not None:
            await self.background()

169 

170 

class HTMLResponse(Response):
    """Response served with the ``text/html`` media type (a charset is appended by ``init_headers``)."""

    media_type = "text/html"

173 

174 

class PlainTextResponse(Response):
    """Response served with the ``text/plain`` media type (a charset is appended by ``init_headers``)."""

    media_type = "text/plain"

177 

178 

class JSONResponse(Response):
    """Response whose body is *content* serialised as compact UTF-8 JSON.

    Serialisation is strict (``allow_nan=False``). Non-ASCII characters are
    written as-is rather than as escape sequences (``ensure_ascii=False``).
    """

    media_type = "application/json"

    def __init__(
        self,
        content: Any,
        status_code: int = 200,
        headers: Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        # Same parameters as Response.__init__, except that ``content`` is required.
        super().__init__(
            content=content,
            status_code=status_code,
            headers=headers,
            media_type=media_type,
            background=background,
        )

    def render(self, content: Any) -> bytes:
        """Dump *content* to JSON without whitespace and encode it as UTF-8."""
        text = json.dumps(
            content,
            separators=(",", ":"),
            indent=None,
            allow_nan=False,
            ensure_ascii=False,
        )
        return text.encode("utf-8")

200 

201 

class RedirectResponse(Response):
    """Empty-bodied response pointing the client at *url* (default status 307).

    The URL is percent-quoted into the ``location`` header. URL delimiters and
    existing ``%`` escapes are left intact.
    """

    def __init__(
        self,
        url: str | URL,
        status_code: int = 307,
        headers: Mapping[str, str] | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        super().__init__(content=b"", status_code=status_code, headers=headers, background=background)
        target = str(url)
        self.headers["location"] = quote(target, safe=":/%#?=@[]!$&'()*+,;")

212 

213 

# One chunk of body content. Text chunks are encoded with the response charset
# when streamed; bytes and memoryview chunks are sent as-is.
Content = Union[str, bytes, memoryview]
# Accepted body sources for streaming: plain iterables or async iterables of chunks.
SyncContentStream = Iterable[Content]
AsyncContentStream = AsyncIterable[Content]
ContentStream = Union[AsyncContentStream, SyncContentStream]

218 

219 

class StreamingResponse(Response):
    """Response whose body is produced incrementally from a sync or async iterable.

    Synchronous iterables are wrapped with ``iterate_in_threadpool``. ``str``
    chunks are encoded with ``charset`` before being sent.
    """

    body_iterator: AsyncContentStream

    def __init__(
        self,
        content: ContentStream,
        status_code: int = 200,
        headers: Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        # Normalise the body source to an async iterable up front.
        self.body_iterator = content if isinstance(content, AsyncIterable) else iterate_in_threadpool(content)
        self.status_code = status_code
        self.media_type = media_type if media_type is not None else self.media_type
        self.background = background
        # No ``body`` attribute is set, so init_headers() adds no content-length.
        self.init_headers(headers)

    async def listen_for_disconnect(self, receive: Receive) -> None:
        """Consume incoming messages until an ``http.disconnect`` arrives."""
        message = await receive()
        while message["type"] != "http.disconnect":
            message = await receive()

    async def stream_response(self, send: Send) -> None:
        """Send the start message, one body message per chunk, then an empty final body message."""
        start_message = {"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers}
        await send(start_message)
        async for piece in self.body_iterator:
            payload = piece if isinstance(piece, (bytes, memoryview)) else piece.encode(self.charset)
            await send({"type": "http.response.body", "body": payload, "more_body": True})

        await send({"type": "http.response.body", "body": b"", "more_body": False})

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Stream the body according to the server's ASGI spec version, then run ``background``."""
        version_text = scope.get("asgi", {}).get("spec_version", "2.0")
        spec_version = tuple(int(part) for part in version_text.split("."))

        if spec_version < (2, 4):
            # Pre-2.4 servers: stream the body and watch for ``http.disconnect``
            # concurrently; whichever task finishes first cancels the task group.
            with collapse_excgroups():
                async with anyio.create_task_group() as task_group:

                    async def run_then_cancel(func: Callable[[], Awaitable[None]]) -> None:
                        await func()
                        task_group.cancel_scope.cancel()

                    task_group.start_soon(run_then_cancel, partial(self.stream_response, send))
                    await run_then_cancel(partial(self.listen_for_disconnect, receive))
        else:
            # ASGI >= 2.4: an OSError from send() is reported as ClientDisconnect.
            try:
                await self.stream_response(send)
            except OSError:
                raise ClientDisconnect()

        if self.background is not None:
            await self.background()

282 

283 

class MalformedRangeHeader(Exception):
    """Signals that a ``Range`` request header could not be parsed.

    ``content`` carries the human-readable reason, which FileResponse sends
    back as the body of a 400 response.
    """

    def __init__(self, content: str = "Malformed range header.") -> None:
        self.content = content

287 

288 

class RangeNotSatisfiable(Exception):
    """Signals that a requested byte range starts outside the file.

    ``max_size`` is the file size, which FileResponse reports in the 416
    response's ``Content-Range`` header.
    """

    def __init__(self, max_size: int) -> None:
        self.max_size = max_size

292 

293 

# One byte-range-spec: optional first position, "-", optional last position.
# Either group may be empty; the suffix form "-N" leaves the first group empty.
_RANGE_PATTERN = re.compile(r"(\d*)-(\d*)")

295 

296 

class FileResponse(Response):
    """Serve a file from disk, honouring ``HEAD``, the ASGI ``pathsend`` extension and ``Range`` requests.

    ``content-length``, ``last-modified`` and ``etag`` are derived from the
    file's ``os.stat_result``. That is either the *stat_result* passed to
    ``__init__`` or one taken with ``os.stat`` when the response is called.
    """

    chunk_size = 64 * 1024

    def __init__(
        self,
        path: str | os.PathLike[str],
        status_code: int = 200,
        headers: Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
        filename: str | None = None,
        stat_result: os.stat_result | None = None,
        method: str | None = None,
        content_disposition_type: str = "attachment",
    ) -> None:
        self.path = path
        self.status_code = status_code
        self.filename = filename
        if method is not None:
            warnings.warn(
                "The 'method' parameter is not used, and it will be removed.",
                DeprecationWarning,
            )
        if media_type is None:
            media_type = guess_type(filename or path)[0] or "text/plain"
        self.media_type = media_type
        self.background = background
        self.init_headers(headers)
        self.headers.setdefault("accept-ranges", "bytes")
        if self.filename is not None:
            content_disposition_filename = quote(self.filename)
            if content_disposition_filename != self.filename:
                # Characters needing escaping: use the extended ``filename*=utf-8''...`` form.
                content_disposition = f"{content_disposition_type}; filename*=utf-8''{content_disposition_filename}"
            else:
                content_disposition = f'{content_disposition_type}; filename="{self.filename}"'
            self.headers.setdefault("content-disposition", content_disposition)
        self.stat_result = stat_result
        if stat_result is not None:
            self.set_stat_headers(stat_result)

    def set_stat_headers(self, stat_result: os.stat_result) -> None:
        """Set ``content-length``, ``last-modified`` and ``etag`` from *stat_result* unless already present."""
        content_length = str(stat_result.st_size)
        last_modified = formatdate(stat_result.st_mtime, usegmt=True)
        # The ETag hashes mtime and size, not the file contents.
        etag_base = str(stat_result.st_mtime) + "-" + str(stat_result.st_size)
        etag = f'"{hashlib.md5(etag_base.encode(), usedforsecurity=False).hexdigest()}"'

        self.headers.setdefault("content-length", content_length)
        self.headers.setdefault("last-modified", last_modified)
        self.headers.setdefault("etag", etag)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Send the file, then await ``background`` if set.

        The file is sent whole, as a single 206 range, or as a 206
        ``multipart/byteranges`` body. A malformed ``Range`` header gets a 400;
        an unsatisfiable one gets a 416.

        Raises:
            RuntimeError: the path does not exist or is not a regular file
                (only checked when no *stat_result* was supplied).
        """
        send_header_only: bool = scope["method"].upper() == "HEAD"
        send_pathsend: bool = "http.response.pathsend" in scope.get("extensions", {})

        if self.stat_result is None:
            try:
                stat_result = await anyio.to_thread.run_sync(os.stat, self.path)
                self.set_stat_headers(stat_result)
            except FileNotFoundError:
                raise RuntimeError(f"File at path {self.path} does not exist.")
            else:
                mode = stat_result.st_mode
                if not stat.S_ISREG(mode):
                    raise RuntimeError(f"File at path {self.path} is not a file.")
        else:
            stat_result = self.stat_result

        headers = Headers(scope=scope)
        http_range = headers.get("range")
        http_if_range = headers.get("if-range")

        # Serve the whole file when no Range was requested, or when If-Range names
        # a validator that does not match this file.
        if http_range is None or (http_if_range is not None and not self._should_use_range(http_if_range)):
            await self._handle_simple(send, send_header_only, send_pathsend)
        else:
            try:
                ranges = self._parse_range_header(http_range, stat_result.st_size)
            except MalformedRangeHeader as exc:
                return await PlainTextResponse(exc.content, status_code=400)(scope, receive, send)
            except RangeNotSatisfiable as exc:
                # RFC 9110 §14.4: Content-Range = range-unit SP unsatisfied-range ("bytes */<length>").
                response = PlainTextResponse(status_code=416, headers={"Content-Range": f"bytes */{exc.max_size}"})
                return await response(scope, receive, send)

            if len(ranges) == 1:
                start, end = ranges[0]
                await self._handle_single_range(send, start, end, stat_result.st_size, send_header_only)
            else:
                await self._handle_multiple_ranges(send, ranges, stat_result.st_size, send_header_only)

        if self.background is not None:
            await self.background()

    async def _handle_simple(self, send: Send, send_header_only: bool, send_pathsend: bool) -> None:
        """Send the whole file with ``self.status_code``.

        HEAD requests get an empty body. When the server offers ``pathsend``,
        the path is handed to the server instead of being read here.
        """
        await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers})
        if send_header_only:
            await send({"type": "http.response.body", "body": b"", "more_body": False})
        elif send_pathsend:
            await send({"type": "http.response.pathsend", "path": str(self.path)})
        else:
            async with await anyio.open_file(self.path, mode="rb") as file:
                more_body = True
                while more_body:
                    chunk = await file.read(self.chunk_size)
                    # A short read means the end of the file was reached.
                    more_body = len(chunk) == self.chunk_size
                    await send({"type": "http.response.body", "body": chunk, "more_body": more_body})

    async def _handle_single_range(
        self, send: Send, start: int, end: int, file_size: int, send_header_only: bool
    ) -> None:
        """Send a 206 response for the half-open byte range ``[start, end)``."""
        self.headers["content-range"] = f"bytes {start}-{end - 1}/{file_size}"
        self.headers["content-length"] = str(end - start)
        await send({"type": "http.response.start", "status": 206, "headers": self.raw_headers})
        if send_header_only:
            await send({"type": "http.response.body", "body": b"", "more_body": False})
        else:
            async with await anyio.open_file(self.path, mode="rb") as file:
                await file.seek(start)
                more_body = True
                while more_body:
                    chunk = await file.read(min(self.chunk_size, end - start))
                    start += len(chunk)
                    more_body = len(chunk) == self.chunk_size and start < end
                    await send({"type": "http.response.body", "body": chunk, "more_body": more_body})

    async def _handle_multiple_ranges(
        self,
        send: Send,
        ranges: list[tuple[int, int]],
        file_size: int,
        send_header_only: bool,
    ) -> None:
        """Send a 206 ``multipart/byteranges`` response with one part per half-open range."""
        # In firefox and chrome, they use boundary with 95-96 bits entropy (that's roughly 13 bytes).
        boundary = token_hex(13)
        # Read the file's own content type before it is replaced below; each part carries it.
        content_length, header_generator = self.generate_multipart(
            ranges, boundary, file_size, self.headers["content-type"]
        )
        # RFC 9110 §14.6: the boundary travels in Content-Type. Each part has its own
        # Content-Range, so none is sent at the top level.
        self.headers["content-type"] = f"multipart/byteranges; boundary={boundary}"
        self.headers["content-length"] = str(content_length)
        await send({"type": "http.response.start", "status": 206, "headers": self.raw_headers})
        if send_header_only:
            await send({"type": "http.response.body", "body": b"", "more_body": False})
        else:
            async with await anyio.open_file(self.path, mode="rb") as file:
                for start, end in ranges:
                    await send({"type": "http.response.body", "body": header_generator(start, end), "more_body": True})
                    await file.seek(start)
                    while start < end:
                        chunk = await file.read(min(self.chunk_size, end - start))
                        if not chunk:
                            # The file shrank after it was stat'ed: stop instead of looping forever.
                            break
                        start += len(chunk)
                        await send({"type": "http.response.body", "body": chunk, "more_body": True})
                    await send({"type": "http.response.body", "body": b"\n", "more_body": True})
                # The "\n" sent after the last part already precedes the closing delimiter.
                # A second one would exceed the length from generate_multipart() and leak
                # into the last part's data.
                await send(
                    {
                        "type": "http.response.body",
                        "body": f"--{boundary}--\n".encode("latin-1"),
                        "more_body": False,
                    }
                )

    def _should_use_range(self, http_if_range: str) -> bool:
        """Return True when *http_if_range* equals this response's Last-Modified or ETag exactly."""
        return http_if_range == self.headers["last-modified"] or http_if_range == self.headers["etag"]

    @staticmethod
    def _parse_range_header(http_range: str, file_size: int) -> list[tuple[int, int]]:
        """Parse a ``Range`` header into sorted, non-overlapping half-open ``(start, end)`` ranges.

        A suffix range (``-N``) longer than the file selects the whole file
        (RFC 9110 §14.1.2). Overlapping or adjacent ranges are coalesced.

        Raises:
            MalformedRangeHeader: no ``=``, a unit other than ``bytes``, no range
                at all, or a range whose last position precedes its first.
            RangeNotSatisfiable: a range starts at or beyond *file_size*.
        """
        try:
            units, range_ = http_range.split("=", 1)
        except ValueError:
            raise MalformedRangeHeader()

        units = units.strip().lower()

        if units != "bytes":
            raise MalformedRangeHeader("Only support bytes range")

        ranges: list[tuple[int, int]] = []
        for first, last in _RANGE_PATTERN.findall(range_):
            if not first and not last:
                continue
            if first:
                start = int(first)
                # last-byte-pos is inclusive; a missing or too-large one runs to end of file.
                end = int(last) + 1 if last and int(last) < file_size else file_size
            else:
                # Suffix range: the final ``last`` bytes, clamped to the whole file.
                start = max(file_size - int(last), 0)
                end = file_size
            ranges.append((start, end))

        if len(ranges) == 0:
            raise MalformedRangeHeader("Range header: range must be requested")

        if any(not (0 <= start < file_size) for start, _ in ranges):
            raise RangeNotSatisfiable(file_size)

        # start == end only arises from "first-(first-1)", i.e. last < first, which is invalid.
        if any(start >= end for start, end in ranges):
            raise MalformedRangeHeader("Range header: start must be less than end")

        if len(ranges) == 1:
            return ranges

        # Sort and coalesce, so overlapping/adjacent ranges become one part and
        # no byte is sent twice.
        merged: list[tuple[int, int]] = []
        for start, end in sorted(ranges):
            if merged and start <= merged[-1][1]:
                merged[-1] = (merged[-1][0], max(merged[-1][1], end))
            else:
                merged.append((start, end))
        return merged

    def generate_multipart(
        self,
        ranges: Sequence[tuple[int, int]],
        boundary: str,
        max_size: int,
        content_type: str,
    ) -> tuple[int, Callable[[int, int], bytes]]:
        r"""
        Multipart response headers generator.

        Returns the exact body length of the layout below, and a function that
        renders one part's header block for a half-open range ``(start, end)``.

        ```
        --{boundary}\n
        Content-Type: {content_type}\n
        Content-Range: bytes {start}-{end-1}/{max_size}\n
        \n
        ..........content...........\n
        --{boundary}\n
        Content-Type: {content_type}\n
        Content-Range: bytes {start}-{end-1}/{max_size}\n
        \n
        ..........content...........\n
        --{boundary}--\n
        ```
        """
        boundary_len = len(boundary)
        # Fixed bytes per part: "--", "\n", "Content-Type: ", "\n", "Content-Range: bytes ",
        # "-", "/", "\n\n" (43 in total), plus the "\n" sent after the part's content (1).
        static_header_part_len = 44 + boundary_len + len(content_type) + len(str(max_size))
        content_length = sum(
            (len(str(start)) + len(str(end - 1)) + static_header_part_len)  # Headers
            + (end - start)  # Content
            for start, end in ranges
        ) + (
            5 + boundary_len  # --boundary--\n
        )
        return (
            content_length,
            lambda start, end: (
                f"--{boundary}\nContent-Type: {content_type}\nContent-Range: bytes {start}-{end - 1}/{max_size}\n\n"
            ).encode("latin-1"),
        )