1from __future__ import annotations
2
3import hashlib
4import http.cookies
5import json
6import os
7import re
8import stat
9import sys
10import warnings
11from collections.abc import AsyncIterable, Awaitable, Iterable, Mapping, Sequence
12from datetime import datetime
13from email.utils import format_datetime, formatdate
14from functools import partial
15from mimetypes import guess_type
16from secrets import token_hex
17from typing import Any, Callable, Literal, Union
18from urllib.parse import quote
19
20import anyio
21import anyio.to_thread
22
23from starlette._utils import collapse_excgroups
24from starlette.background import BackgroundTask
25from starlette.concurrency import iterate_in_threadpool
26from starlette.datastructures import URL, Headers, MutableHeaders
27from starlette.requests import ClientDisconnect
28from starlette.types import Receive, Scope, Send
29
30
class Response:
    """A complete, in-memory HTTP response delivered as a single ASGI body message.

    Subclasses customise ``media_type`` and/or ``render`` to control how the
    ``content`` argument is converted to bytes.
    """

    media_type = None
    charset = "utf-8"

    def __init__(
        self,
        content: Any = None,
        status_code: int = 200,
        headers: Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        self.status_code = status_code
        # Only shadow the class-level default when the caller supplied a media type.
        if media_type is not None:
            self.media_type = media_type
        self.background = background
        # The body must exist before init_headers so Content-Length can be derived from it.
        self.body = self.render(content)
        self.init_headers(headers)

    def render(self, content: Any) -> bytes | memoryview:
        """Return the raw body: ``b""`` for ``None``, bytes-like values unchanged, else ``content.encode(charset)``."""
        if content is None:
            return b""
        if not isinstance(content, (bytes, memoryview)):
            return content.encode(self.charset)  # type: ignore
        return content

    def init_headers(self, headers: Mapping[str, str] | None = None) -> None:
        """Populate ``raw_headers`` from *headers*, adding Content-Length/Content-Type if absent.

        Header names are lower-cased; names and values are latin-1 encoded.
        """
        raw: list[tuple[bytes, bytes]] = []
        if headers is not None:
            raw = [(name.lower().encode("latin-1"), value.encode("latin-1")) for name, value in headers.items()]
        supplied = {name for name, _ in raw}
        want_length = b"content-length" not in supplied
        want_type = b"content-type" not in supplied

        body = getattr(self, "body", None)
        if body is not None and want_length:
            code = self.status_code
            # No Content-Length for informational (<200), 204 or 304 statuses.
            if code >= 200 and code not in (204, 304):
                raw.append((b"content-length", str(len(body)).encode("latin-1")))

        media = self.media_type
        if media is not None and want_type:
            # text/* types get the response charset unless they already declare one.
            if media.startswith("text/") and "charset=" not in media.lower():
                media += "; charset=" + self.charset
            raw.append((b"content-type", media.encode("latin-1")))

        self.raw_headers = raw

    @property
    def headers(self) -> MutableHeaders:
        """Lazily created mutable view backed by (and writing through to) ``raw_headers``."""
        try:
            return self._headers
        except AttributeError:
            self._headers = MutableHeaders(raw=self.raw_headers)
            return self._headers

    def set_cookie(
        self,
        key: str,
        value: str = "",
        max_age: int | None = None,
        expires: datetime | str | int | None = None,
        path: str | None = "/",
        domain: str | None = None,
        secure: bool = False,
        httponly: bool = False,
        samesite: Literal["lax", "strict", "none"] | None = "lax",
        partitioned: bool = False,
    ) -> None:
        """Append a ``set-cookie`` header for *key*=*value* with the given attributes.

        A ``datetime`` *expires* is formatted with ``email.utils.format_datetime(usegmt=True)``;
        other *expires* values are passed to the cookie unchanged.

        Raises:
            AssertionError: *samesite* is not 'strict', 'lax' or 'none' (case-insensitive).
            ValueError: *partitioned* is requested on Python < 3.14.
        """
        jar: http.cookies.BaseCookie[str] = http.cookies.SimpleCookie()
        jar[key] = value
        morsel = jar[key]
        if max_age is not None:
            morsel["max-age"] = max_age
        if expires is not None:
            morsel["expires"] = format_datetime(expires, usegmt=True) if isinstance(expires, datetime) else expires
        if path is not None:
            morsel["path"] = path
        if domain is not None:
            morsel["domain"] = domain
        if secure:
            morsel["secure"] = True
        if httponly:
            morsel["httponly"] = True
        if samesite is not None:
            assert samesite.lower() in ("strict", "lax", "none"), "samesite must be either 'strict', 'lax' or 'none'"
            morsel["samesite"] = samesite
        if partitioned:
            if sys.version_info < (3, 14):
                raise ValueError("Partitioned cookies are only supported in Python 3.14 and above.")  # pragma: no cover
            morsel["partitioned"] = True  # pragma: no cover

        header_value = jar.output(header="").strip()
        self.raw_headers.append((b"set-cookie", header_value.encode("latin-1")))

    def delete_cookie(
        self,
        key: str,
        path: str = "/",
        domain: str | None = None,
        secure: bool = False,
        httponly: bool = False,
        samesite: Literal["lax", "strict", "none"] | None = "lax",
    ) -> None:
        """Expire cookie *key* immediately (``max_age=0``, ``expires=0``) with matching scope attributes."""
        self.set_cookie(
            key,
            max_age=0,
            expires=0,
            path=path,
            domain=domain,
            secure=secure,
            httponly=httponly,
            samesite=samesite,
        )

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Send the start message and the entire body, then run any background task.

        Message types are prefixed with ``websocket.`` for websocket scopes.
        """
        kind = "websocket." if scope["type"] == "websocket" else ""
        start_message = {
            "type": kind + "http.response.start",
            "status": self.status_code,
            "headers": self.raw_headers,
        }
        await send(start_message)
        await send({"type": kind + "http.response.body", "body": self.body})

        if self.background is not None:
            await self.background()
169
170
class HTMLResponse(Response):
    """Response served as ``text/html`` (Response.init_headers adds the charset parameter)."""

    media_type = "text/html"


class PlainTextResponse(Response):
    """Response served as ``text/plain`` (Response.init_headers adds the charset parameter)."""

    media_type = "text/plain"
177
178
class JSONResponse(Response):
    """Response whose body is *content* serialized as compact, UTF-8 encoded JSON.

    Non-ASCII characters are emitted verbatim; NaN/Infinity are rejected by
    ``json.dumps(allow_nan=False)`` with a ``ValueError``.
    """

    media_type = "application/json"

    def __init__(
        self,
        content: Any,
        status_code: int = 200,
        headers: Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        # Overridden so that ``content`` is required here (the base class defaults it to None).
        super().__init__(content, status_code, headers, media_type, background)

    def render(self, content: Any) -> bytes:
        """Serialize *content* without whitespace between tokens and encode it as UTF-8."""
        serialized = json.dumps(
            content,
            ensure_ascii=False,
            allow_nan=False,
            indent=None,
            separators=(",", ":"),
        )
        return serialized.encode("utf-8")
200
201
class RedirectResponse(Response):
    """Empty-bodied response pointing the client at *url* via the ``location`` header.

    The status code defaults to 307.
    """

    def __init__(
        self,
        url: str | URL,
        status_code: int = 307,
        headers: Mapping[str, str] | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        super().__init__(content=b"", status_code=status_code, headers=headers, background=background)
        # Percent-encode everything except the URL punctuation listed in ``safe``,
        # so the header value is plain ASCII.
        target = str(url)
        self.headers["location"] = quote(target, safe=":/%#?=@[]!$&'()*+,;")
212
213
# A single body chunk: text (encoded with the response charset) or raw bytes.
Content = Union[str, bytes, memoryview]
# Chunk sources accepted by StreamingResponse; synchronous ones are iterated in a threadpool.
SyncContentStream = Iterable[Content]
AsyncContentStream = AsyncIterable[Content]
ContentStream = Union[AsyncIterable[Content], Iterable[Content]]
218
219
class StreamingResponse(Response):
    """Response whose body is produced incrementally from an iterable of chunks.

    Async iterables are consumed directly; sync iterables are wrapped with
    ``iterate_in_threadpool``. ``str`` chunks are encoded with ``charset``.
    No ``body`` attribute is rendered, so ``init_headers`` adds no Content-Length.
    """

    body_iterator: AsyncContentStream

    def __init__(
        self,
        content: ContentStream,
        status_code: int = 200,
        headers: Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        if isinstance(content, AsyncIterable):
            self.body_iterator = content
        else:
            self.body_iterator = iterate_in_threadpool(content)
        self.status_code = status_code
        self.media_type = self.media_type if media_type is None else media_type
        self.background = background
        self.init_headers(headers)

    async def listen_for_disconnect(self, receive: Receive) -> None:
        """Consume incoming messages until one of type ``http.disconnect`` arrives."""
        while True:
            message = await receive()
            if message["type"] == "http.disconnect":
                break

    async def stream_response(self, send: Send) -> None:
        """Send the start message, one body message per chunk, then an empty final message."""
        await send(
            {
                "type": "http.response.start",
                "status": self.status_code,
                "headers": self.raw_headers,
            }
        )
        async for chunk in self.body_iterator:
            if not isinstance(chunk, (bytes, memoryview)):
                chunk = chunk.encode(self.charset)
            await send({"type": "http.response.body", "body": chunk, "more_body": True})

        await send({"type": "http.response.body", "body": b"", "more_body": False})

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Stream the body (stopping early on client disconnect), then run any background task.

        Raises:
            ClientDisconnect: on ASGI spec >= 2.4, when ``send`` raises ``OSError``.
        """
        spec_version = tuple(map(int, scope.get("asgi", {}).get("spec_version", "2.0").split(".")))

        if spec_version >= (2, 4):
            # From ASGI spec 2.4 a send() on a closed connection raises OSError,
            # so no receive() watcher is needed; surface it as ClientDisconnect
            # while keeping the original error as the explicit cause.
            try:
                await self.stream_response(send)
            except OSError as exc:
                raise ClientDisconnect() from exc
        else:
            # Older servers: race the stream against a disconnect listener. Whichever
            # finishes first cancels the task group, stopping the other.
            with collapse_excgroups():
                async with anyio.create_task_group() as task_group:

                    async def wrap(func: Callable[[], Awaitable[None]]) -> None:
                        await func()
                        task_group.cancel_scope.cancel()

                    task_group.start_soon(wrap, partial(self.stream_response, send))
                    await wrap(partial(self.listen_for_disconnect, receive))

        if self.background is not None:
            await self.background()
282
283
class MalformedRangeHeader(Exception):
    """Raised when a ``Range`` header cannot be parsed; ``content`` is the 400 response text."""

    def __init__(self, content: str = "Malformed range header.") -> None:
        # Forward the message so str(exc) and exc.args are populated even when the
        # default message is used (BaseException only records explicitly passed args).
        super().__init__(content)
        self.content = content
287
288
class RangeNotSatisfiable(Exception):
    """Raised when a requested byte range does not start inside the file.

    ``max_size`` is the file size, reported back in the 416 response's Content-Range.
    """

    def __init__(self, max_size: int) -> None:
        # Forward to Exception so args/str()/pickling work even for keyword construction.
        super().__init__(max_size)
        self.max_size = max_size
292
293
294_RANGE_PATTERN = re.compile(r"(\d*)-(\d*)")
295
296
class FileResponse(Response):
    """Serve a file from disk.

    Supports HEAD requests, the ``http.response.pathsend`` extension, ``If-Range``
    validation, and single or multiple byte ``Range`` requests. Ranges are handled
    internally as half-open ``(start, end)`` pairs.
    """

    chunk_size = 64 * 1024

    def __init__(
        self,
        path: str | os.PathLike[str],
        status_code: int = 200,
        headers: Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
        filename: str | None = None,
        stat_result: os.stat_result | None = None,
        method: str | None = None,
        content_disposition_type: str = "attachment",
    ) -> None:
        self.path = path
        self.status_code = status_code
        self.filename = filename
        if method is not None:
            warnings.warn(
                "The 'method' parameter is not used, and it will be removed.",
                DeprecationWarning,
            )
        if media_type is None:
            # Guess from the download name if given, else from the path; default text/plain.
            media_type = guess_type(filename or path)[0] or "text/plain"
        self.media_type = media_type
        self.background = background
        self.init_headers(headers)
        self.headers.setdefault("accept-ranges", "bytes")
        if self.filename is not None:
            content_disposition_filename = quote(self.filename)
            # Names that percent-quoting would alter use the extended ``filename*=utf-8''...``
            # form; names that survive unchanged go in a plain quoted ``filename`` parameter.
            if content_disposition_filename != self.filename:
                content_disposition = f"{content_disposition_type}; filename*=utf-8''{content_disposition_filename}"
            else:
                content_disposition = f'{content_disposition_type}; filename="{self.filename}"'
            self.headers.setdefault("content-disposition", content_disposition)
        self.stat_result = stat_result
        if stat_result is not None:
            self.set_stat_headers(stat_result)

    def set_stat_headers(self, stat_result: os.stat_result) -> None:
        """Set Content-Length, Last-Modified and an mtime/size-derived ETag unless already present."""
        content_length = str(stat_result.st_size)
        last_modified = formatdate(stat_result.st_mtime, usegmt=True)
        etag_base = str(stat_result.st_mtime) + "-" + str(stat_result.st_size)
        etag = f'"{hashlib.md5(etag_base.encode(), usedforsecurity=False).hexdigest()}"'

        self.headers.setdefault("content-length", content_length)
        self.headers.setdefault("last-modified", last_modified)
        self.headers.setdefault("etag", etag)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Stat the file if needed, then send it whole or as the requested byte range(s).

        Raises:
            RuntimeError: the path does not exist or is not a regular file.
        """
        send_header_only: bool = scope["method"].upper() == "HEAD"
        send_pathsend: bool = "http.response.pathsend" in scope.get("extensions", {})

        if self.stat_result is None:
            try:
                stat_result = await anyio.to_thread.run_sync(os.stat, self.path)
                self.set_stat_headers(stat_result)
            except FileNotFoundError:
                raise RuntimeError(f"File at path {self.path} does not exist.")
            else:
                mode = stat_result.st_mode
                if not stat.S_ISREG(mode):
                    raise RuntimeError(f"File at path {self.path} is not a file.")
        else:
            stat_result = self.stat_result

        headers = Headers(scope=scope)
        http_range = headers.get("range")
        http_if_range = headers.get("if-range")

        # Send the whole file when no Range was requested, or when If-Range names a
        # validator that does not match this file's Last-Modified/ETag.
        if http_range is None or (http_if_range is not None and not self._should_use_range(http_if_range)):
            await self._handle_simple(send, send_header_only, send_pathsend)
        else:
            try:
                ranges = self._parse_range_header(http_range, stat_result.st_size)
            except MalformedRangeHeader as exc:
                return await PlainTextResponse(exc.content, status_code=400)(scope, receive, send)
            except RangeNotSatisfiable as exc:
                # RFC 9110 Content-Range: the unsatisfied form still carries the unit,
                # i.e. "bytes */<complete-length>".
                response = PlainTextResponse(status_code=416, headers={"Content-Range": f"bytes */{exc.max_size}"})
                return await response(scope, receive, send)

            if len(ranges) == 1:
                start, end = ranges[0]
                await self._handle_single_range(send, start, end, stat_result.st_size, send_header_only)
            else:
                await self._handle_multiple_ranges(send, ranges, stat_result.st_size, send_header_only)

        if self.background is not None:
            await self.background()

    async def _handle_simple(self, send: Send, send_header_only: bool, send_pathsend: bool) -> None:
        """Send the whole file: headers only for HEAD, via pathsend if offered, else in chunks."""
        await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers})
        if send_header_only:
            await send({"type": "http.response.body", "body": b"", "more_body": False})
        elif send_pathsend:
            await send({"type": "http.response.pathsend", "path": str(self.path)})
        else:
            async with await anyio.open_file(self.path, mode="rb") as file:
                more_body = True
                while more_body:
                    chunk = await file.read(self.chunk_size)
                    # A short read means end of file.
                    more_body = len(chunk) == self.chunk_size
                    await send({"type": "http.response.body", "body": chunk, "more_body": more_body})

    async def _handle_single_range(
        self, send: Send, start: int, end: int, file_size: int, send_header_only: bool
    ) -> None:
        """Send a 206 containing bytes ``[start, end)`` of the file."""
        self.headers["content-range"] = f"bytes {start}-{end - 1}/{file_size}"
        self.headers["content-length"] = str(end - start)
        await send({"type": "http.response.start", "status": 206, "headers": self.raw_headers})
        if send_header_only:
            await send({"type": "http.response.body", "body": b"", "more_body": False})
        else:
            async with await anyio.open_file(self.path, mode="rb") as file:
                await file.seek(start)
                more_body = True
                while more_body:
                    chunk = await file.read(min(self.chunk_size, end - start))
                    start += len(chunk)
                    more_body = len(chunk) == self.chunk_size and start < end
                    await send({"type": "http.response.body", "body": chunk, "more_body": more_body})

    async def _handle_multiple_ranges(
        self,
        send: Send,
        ranges: list[tuple[int, int]],
        file_size: int,
        send_header_only: bool,
    ) -> None:
        """Send a 206 ``multipart/byteranges`` body with one part per half-open range.

        The byte layout is exactly the one documented and sized by ``generate_multipart``.
        """
        # In firefox and chrome, they use boundary with 95-96 bits entropy (that's roughly 13 bytes).
        boundary = token_hex(13)
        # Each part repeats the file's own media type, so read it before the
        # top-level Content-Type is replaced below.
        content_length, header_generator = self.generate_multipart(
            ranges, boundary, file_size, self.headers["content-type"]
        )
        # RFC 9110 (multipart/byteranges, 206 Partial Content): a multipart 206 declares
        # itself through Content-Type; Content-Range belongs inside each part only.
        self.headers["content-type"] = f"multipart/byteranges; boundary={boundary}"
        self.headers["content-length"] = str(content_length)
        await send({"type": "http.response.start", "status": 206, "headers": self.raw_headers})
        if send_header_only:
            await send({"type": "http.response.body", "body": b"", "more_body": False})
        else:
            async with await anyio.open_file(self.path, mode="rb") as file:
                for start, end in ranges:
                    await send({"type": "http.response.body", "body": header_generator(start, end), "more_body": True})
                    await file.seek(start)
                    while start < end:
                        chunk = await file.read(min(self.chunk_size, end - start))
                        if not chunk:
                            # The file shrank after it was stat'ed; stop instead of
                            # looping forever on empty reads.
                            break
                        start += len(chunk)
                        await send({"type": "http.response.body", "body": chunk, "more_body": True})
                    await send({"type": "http.response.body", "body": b"\n", "more_body": True})
            # The "\n" ending the last part already precedes the closing delimiter;
            # Content-Length counts exactly "--{boundary}--\n" here.
            await send(
                {
                    "type": "http.response.body",
                    "body": f"--{boundary}--\n".encode("latin-1"),
                    "more_body": False,
                }
            )

    def _should_use_range(self, http_if_range: str) -> bool:
        """Return True when If-Range equals this file's Last-Modified or ETag header."""
        return http_if_range == self.headers["last-modified"] or http_if_range == self.headers["etag"]

    @staticmethod
    def _parse_range_header(http_range: str, file_size: int) -> list[tuple[int, int]]:
        """Parse a ``bytes=`` Range header into half-open ``(start, end)`` ranges.

        Multiple ranges are returned sorted, with overlapping or touching ranges merged.

        Raises:
            MalformedRangeHeader: unparsable header, non-bytes unit, no range, or last < first.
            RangeNotSatisfiable: some range does not start inside the file.
        """
        try:
            units, range_ = http_range.split("=", 1)
        except ValueError:
            raise MalformedRangeHeader()

        units = units.strip().lower()

        if units != "bytes":
            raise MalformedRangeHeader("Only support bytes range")

        ranges: list[tuple[int, int]] = []
        for first, last in _RANGE_PATTERN.findall(range_):
            if first:
                start = int(first)
                # "first-last" is inclusive; convert to an exclusive end clamped to the file.
                end = int(last) + 1 if last and int(last) < file_size else file_size
            elif last:
                # Suffix range "-N": the final N bytes. A suffix longer than the file
                # selects the whole file (RFC 9110 range specifiers) rather than a 416.
                start = max(file_size - int(last), 0)
                end = file_size
            else:
                continue  # a bare "-" names no range
            ranges.append((start, end))

        if len(ranges) == 0:
            raise MalformedRangeHeader("Range header: range must be requested")

        if any(not (0 <= start < file_size) for start, _ in ranges):
            raise RangeNotSatisfiable(file_size)

        # start == end only arises from "first-last" with last == first - 1, which is
        # as invalid as any other last < first.
        if any(start >= end for start, end in ranges):
            raise MalformedRangeHeader("Range header: start must be less than end")

        if len(ranges) == 1:
            return ranges

        # Merge ranges. Sorting first lets one pass coalesce ranges that only come to
        # overlap after an earlier merge has grown the previous range.
        result: list[tuple[int, int]] = []
        for start, end in sorted(ranges):
            if result and start <= result[-1][1]:
                prev_start, prev_end = result[-1]
                result[-1] = (prev_start, max(prev_end, end))
            else:
                result.append((start, end))

        return result

    def generate_multipart(
        self,
        ranges: Sequence[tuple[int, int]],
        boundary: str,
        max_size: int,
        content_type: str,
    ) -> tuple[int, Callable[[int, int], bytes]]:
        r"""
        Multipart response headers generator.

        Returns the exact body length of the layout below (part headers, part
        content, the newline ending each part and the closing delimiter) and a
        function that renders the header block of one ``(start, end)`` part.

        ```
        --{boundary}\n
        Content-Type: {content_type}\n
        Content-Range: bytes {start}-{end-1}/{max_size}\n
        \n
        ..........content...........\n
        --{boundary}\n
        Content-Type: {content_type}\n
        Content-Range: bytes {start}-{end-1}/{max_size}\n
        \n
        ..........content...........\n
        --{boundary}--\n
        ```
        """
        boundary_len = len(boundary)
        # 44 fixed characters per part: "--" + "\n" + "Content-Type: " + "\n"
        # + "Content-Range: bytes " + "-" + "/" + "\n\n" + the "\n" after the content.
        static_header_part_len = 44 + boundary_len + len(content_type) + len(str(max_size))
        content_length = sum(
            (len(str(start)) + len(str(end - 1)) + static_header_part_len)  # Headers
            + (end - start)  # Content
            for start, end in ranges
        ) + (
            5 + boundary_len  # --boundary--\n
        )
        return (
            content_length,
            lambda start, end: (
                f"--{boundary}\nContent-Type: {content_type}\nContent-Range: bytes {start}-{end - 1}/{max_size}\n\n"
            ).encode("latin-1"),
        )