1from __future__ import annotations
2
3import http.cookies
4import json
5import os
6import stat
7import typing
8import warnings
9from datetime import datetime
10from email.utils import format_datetime, formatdate
11from functools import partial
12from mimetypes import guess_type
13from urllib.parse import quote
14
15import anyio
16import anyio.to_thread
17
18from starlette._compat import md5_hexdigest
19from starlette.background import BackgroundTask
20from starlette.concurrency import iterate_in_threadpool
21from starlette.datastructures import URL, MutableHeaders
22from starlette.types import Receive, Scope, Send
23
24
class Response:
    """Base HTTP response whose body is rendered eagerly into memory.

    Subclasses customise behaviour by overriding the class attributes
    ``media_type`` / ``charset`` or the :meth:`render` method.
    """

    media_type = None
    charset = "utf-8"

    def __init__(
        self,
        content: typing.Any = None,
        status_code: int = 200,
        headers: typing.Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        """Render *content* into ``self.body`` and build ``self.raw_headers``.

        Args:
            content: ``None``, ``bytes``/``memoryview``, or anything :meth:`render`
                accepts (plain ``str`` in this base class).
            status_code: HTTP status code to send.
            headers: Extra response headers, merged by :meth:`init_headers`.
            media_type: Overrides the class-level ``media_type`` when given.
            background: Awaited after the response has been sent.
        """
        self.status_code = status_code
        if media_type is not None:
            self.media_type = media_type
        self.background = background
        self.body = self.render(content)
        self.init_headers(headers)

    def render(self, content: typing.Any) -> bytes | memoryview:
        """Convert *content* to the bytes sent as the response body.

        ``None`` becomes an empty body, ``bytes``/``memoryview`` pass through
        untouched, and anything else is ``encode``-d with ``self.charset``.
        """
        if content is None:
            return b""
        if isinstance(content, (bytes, memoryview)):
            return content
        return content.encode(self.charset)  # type: ignore

    def init_headers(self, headers: typing.Mapping[str, str] | None = None) -> None:
        """Build ``self.raw_headers`` from *headers* plus derived defaults.

        Header names are lower-cased; names and values are latin-1 encoded.
        A ``content-length`` (from ``self.body``) and a ``content-type`` (from
        ``self.media_type``) are appended unless the caller already supplied
        them. ``content-length`` is also skipped when there is no ``body``
        attribute, and for status codes below 200 or equal to 204/304.
        """
        if headers is None:
            raw_headers: list[tuple[bytes, bytes]] = []
            populate_content_length = True
            populate_content_type = True
        else:
            raw_headers = [(k.lower().encode("latin-1"), v.encode("latin-1")) for k, v in headers.items()]
            keys = [h[0] for h in raw_headers]
            populate_content_length = b"content-length" not in keys
            populate_content_type = b"content-type" not in keys

        # Streaming/file subclasses never set ``body``, so no length is derived for them here.
        body = getattr(self, "body", None)
        if (
            body is not None
            and populate_content_length
            and not (self.status_code < 200 or self.status_code in (204, 304))
        ):
            content_length = str(len(body))
            raw_headers.append((b"content-length", content_length.encode("latin-1")))

        content_type = self.media_type
        if content_type is not None and populate_content_type:
            # Textual media types get an explicit charset unless one is already present.
            if content_type.startswith("text/") and "charset=" not in content_type.lower():
                content_type += "; charset=" + self.charset
            raw_headers.append((b"content-type", content_type.encode("latin-1")))

        self.raw_headers = raw_headers

    @property
    def headers(self) -> MutableHeaders:
        """Lazily created ``MutableHeaders`` built over ``self.raw_headers``."""
        if not hasattr(self, "_headers"):
            # NOTE(review): presumably wraps the same list, so writes through this
            # object show up in the raw_headers sent by __call__ — confirm.
            self._headers = MutableHeaders(raw=self.raw_headers)
        return self._headers

    def set_cookie(
        self,
        key: str,
        value: str = "",
        max_age: int | None = None,
        expires: datetime | str | int | None = None,
        path: str | None = "/",
        domain: str | None = None,
        secure: bool = False,
        httponly: bool = False,
        samesite: typing.Literal["lax", "strict", "none"] | None = "lax",
    ) -> None:
        """Append a ``set-cookie`` header for *key*.

        Args:
            key: Cookie name.
            value: Cookie value.
            max_age: ``Max-Age`` attribute, omitted when ``None``.
            expires: A ``datetime`` is formatted as a GMT HTTP date; ``str``/``int``
                values are handed to ``SimpleCookie`` unchanged.
            path: ``Path`` attribute, omitted when ``None``.
            domain: ``Domain`` attribute, omitted when ``None``.
            secure: Adds the ``Secure`` flag when true.
            httponly: Adds the ``HttpOnly`` flag when true.
            samesite: ``SameSite`` attribute (case-insensitive check), omitted when ``None``.

        Raises:
            AssertionError: if *samesite* is not one of strict/lax/none.
        """
        cookie: http.cookies.BaseCookie[str] = http.cookies.SimpleCookie()
        cookie[key] = value
        if max_age is not None:
            cookie[key]["max-age"] = max_age
        if expires is not None:
            if isinstance(expires, datetime):
                cookie[key]["expires"] = format_datetime(expires, usegmt=True)
            else:
                cookie[key]["expires"] = expires
        if path is not None:
            cookie[key]["path"] = path
        if domain is not None:
            cookie[key]["domain"] = domain
        if secure:
            cookie[key]["secure"] = True
        if httponly:
            cookie[key]["httponly"] = True
        if samesite is not None:
            if samesite.lower() not in ("strict", "lax", "none"):
                # Raised explicitly instead of via an `assert` statement so the
                # check survives `python -O`; AssertionError is kept so existing
                # callers that catch it keep working.
                raise AssertionError("samesite must be either 'strict', 'lax' or 'none'")
            cookie[key]["samesite"] = samesite
        cookie_val = cookie.output(header="").strip()
        self.raw_headers.append((b"set-cookie", cookie_val.encode("latin-1")))

    def delete_cookie(
        self,
        key: str,
        path: str = "/",
        domain: str | None = None,
        secure: bool = False,
        httponly: bool = False,
        samesite: typing.Literal["lax", "strict", "none"] | None = "lax",
    ) -> None:
        """Expire cookie *key* by re-setting it empty with ``max_age=0`` and ``expires=0``."""
        self.set_cookie(
            key,
            max_age=0,
            expires=0,
            path=path,
            domain=domain,
            secure=secure,
            httponly=httponly,
            samesite=samesite,
        )

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Send the start message and the whole body, then run the background task."""
        # NOTE(review): the "websocket." prefix presumably targets the ASGI
        # websocket denial-response extension — confirm.
        prefix = "websocket." if scope["type"] == "websocket" else ""
        await send(
            {
                "type": prefix + "http.response.start",
                "status": self.status_code,
                "headers": self.raw_headers,
            }
        )
        await send({"type": prefix + "http.response.body", "body": self.body})

        if self.background is not None:
            await self.background()
157
158
class HTMLResponse(Response):
    """In-memory response served with a ``text/html`` content type."""

    media_type = "text/html"
161
162
class PlainTextResponse(Response):
    """In-memory response served with a ``text/plain`` content type."""

    media_type = "text/plain"
165
166
class JSONResponse(Response):
    """Response whose body is *content* serialised as compact UTF-8 JSON.

    Non-ASCII characters are emitted verbatim (``ensure_ascii=False``), and
    NaN/Infinity are rejected (``allow_nan=False``) rather than producing
    non-standard JSON.
    """

    media_type = "application/json"

    def __init__(
        self,
        content: typing.Any,
        status_code: int = 200,
        headers: typing.Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        # Overridden only so that ``content`` is a required argument here.
        super().__init__(content, status_code, headers, media_type, background)

    def render(self, content: typing.Any) -> bytes:
        """Serialise *content* with no whitespace between tokens and encode it as UTF-8."""
        serialised = json.dumps(
            content,
            separators=(",", ":"),
            indent=None,
            allow_nan=False,
            ensure_ascii=False,
        )
        return serialised.encode("utf-8")
188
189
class RedirectResponse(Response):
    """Empty-bodied response pointing the client at *url* (307 by default)."""

    def __init__(
        self,
        url: str | URL,
        status_code: int = 307,
        headers: typing.Mapping[str, str] | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        super().__init__(
            content=b"",
            status_code=status_code,
            headers=headers,
            background=background,
        )
        # Percent-encode everything except URL delimiters and existing escapes ("%").
        location = quote(str(url), safe=":/%#?=@[]!$&'()*+,;")
        self.headers["location"] = location
200
201
# One body chunk: text (encoded with the response charset when streamed) or raw bytes.
Content = typing.Union[str, bytes, memoryview]
# Synchronous sources of chunks; StreamingResponse iterates these in a threadpool.
SyncContentStream = typing.Iterable[Content]
# Asynchronous sources of chunks; StreamingResponse consumes these directly.
AsyncContentStream = typing.AsyncIterable[Content]
# Anything StreamingResponse accepts as its ``content`` argument.
ContentStream = typing.Union[AsyncContentStream, SyncContentStream]
206
207
class StreamingResponse(Response):
    """Response whose body is produced chunk by chunk from an (async) iterable.

    No ``body`` attribute is set, so ``init_headers`` emits no
    ``content-length`` unless the caller provides one.
    """

    body_iterator: AsyncContentStream

    def __init__(
        self,
        content: ContentStream,
        status_code: int = 200,
        headers: typing.Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        self.status_code = status_code
        self.media_type = media_type if media_type is not None else self.media_type
        self.background = background
        # Synchronous iterables are adapted so they are consumed via a threadpool.
        self.body_iterator = (
            content if isinstance(content, typing.AsyncIterable) else iterate_in_threadpool(content)
        )
        self.init_headers(headers)

    async def listen_for_disconnect(self, receive: Receive) -> None:
        """Consume incoming messages until an ``http.disconnect`` arrives."""
        message = await receive()
        while message["type"] != "http.disconnect":
            message = await receive()

    async def stream_response(self, send: Send) -> None:
        """Send the start message, every chunk, and a final empty body message."""
        start_message = {
            "type": "http.response.start",
            "status": self.status_code,
            "headers": self.raw_headers,
        }
        await send(start_message)
        async for piece in self.body_iterator:
            payload = piece if isinstance(piece, (bytes, memoryview)) else piece.encode(self.charset)
            await send({"type": "http.response.body", "body": payload, "more_body": True})

        await send({"type": "http.response.body", "body": b"", "more_body": False})

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Stream the body while watching for a disconnect; run the background task after."""
        async with anyio.create_task_group() as task_group:

            async def run_then_cancel(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
                # Whichever side finishes first (streaming or disconnect) stops the other.
                await func()
                task_group.cancel_scope.cancel()

            task_group.start_soon(run_then_cancel, partial(self.stream_response, send))
            await run_then_cancel(partial(self.listen_for_disconnect, receive))

        if self.background is not None:
            await self.background()
261
262
class FileResponse(Response):
    """Response that sends the contents of the file at *path* in fixed-size chunks."""

    chunk_size = 64 * 1024

    def __init__(
        self,
        path: str | os.PathLike[str],
        status_code: int = 200,
        headers: typing.Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
        filename: str | None = None,
        stat_result: os.stat_result | None = None,
        method: str | None = None,
        content_disposition_type: str = "attachment",
    ) -> None:
        """Configure headers for serving *path*.

        Args:
            path: Filesystem path of the file to send.
            status_code: HTTP status code to send.
            headers: Extra response headers.
            media_type: Content type; guessed from *filename* (or *path*) when
                omitted, falling back to ``text/plain``.
            background: Awaited after the response has been sent.
            filename: When given, a ``content-disposition`` header is added
                (unless one was supplied in *headers*).
            stat_result: Pre-computed ``os.stat`` result; when given, the
                size/mtime headers are set now and no stat happens at send time.
            method: Deprecated and ignored.
            content_disposition_type: Disposition type, e.g. ``attachment``.
        """
        self.path = path
        self.status_code = status_code
        self.filename = filename
        if method is not None:
            # stacklevel=2 attributes the warning to the caller's line, not this one.
            warnings.warn(
                "The 'method' parameter is not used, and it will be removed.",
                DeprecationWarning,
                stacklevel=2,
            )
        if media_type is None:
            media_type = guess_type(filename or path)[0] or "text/plain"
        self.media_type = media_type
        self.background = background
        self.init_headers(headers)
        if self.filename is not None:
            content_disposition_filename = quote(self.filename)
            if content_disposition_filename != self.filename:
                # Names that need escaping use the RFC 5987 extended `filename*` form.
                content_disposition = f"{content_disposition_type}; filename*=utf-8''{content_disposition_filename}"
            else:
                content_disposition = f'{content_disposition_type}; filename="{self.filename}"'
            self.headers.setdefault("content-disposition", content_disposition)
        self.stat_result = stat_result
        if stat_result is not None:
            self.set_stat_headers(stat_result)

    def set_stat_headers(self, stat_result: os.stat_result) -> None:
        """Fill content-length, last-modified and etag from *stat_result*.

        Existing values are kept (``setdefault``). The ETag is an MD5 hex
        digest (via ``md5_hexdigest``) of ``"<mtime>-<size>"``, so it changes
        whenever the file's size or modification time does.
        """
        content_length = str(stat_result.st_size)
        last_modified = formatdate(stat_result.st_mtime, usegmt=True)
        etag_base = str(stat_result.st_mtime) + "-" + str(stat_result.st_size)
        etag = f'"{md5_hexdigest(etag_base.encode(), usedforsecurity=False)}"'

        self.headers.setdefault("content-length", content_length)
        self.headers.setdefault("last-modified", last_modified)
        self.headers.setdefault("etag", etag)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Send the file (headers only for ``HEAD``), then run the background task.

        Raises:
            RuntimeError: when no ``stat_result`` was supplied and *path* does
                not exist or is not a regular file.
        """
        if self.stat_result is None:
            try:
                stat_result = await anyio.to_thread.run_sync(os.stat, self.path)
            except FileNotFoundError as exc:
                # Chain explicitly so the original OS error is shown as the cause.
                raise RuntimeError(f"File at path {self.path} does not exist.") from exc
            self.set_stat_headers(stat_result)
            if not stat.S_ISREG(stat_result.st_mode):
                raise RuntimeError(f"File at path {self.path} is not a file.")
        await send(
            {
                "type": "http.response.start",
                "status": self.status_code,
                "headers": self.raw_headers,
            }
        )
        if scope["method"].upper() == "HEAD":
            await send({"type": "http.response.body", "body": b"", "more_body": False})
        else:
            async with await anyio.open_file(self.path, mode="rb") as file:
                more_body = True
                while more_body:
                    chunk = await file.read(self.chunk_size)
                    # A short read signals EOF; a file whose size is an exact
                    # multiple of chunk_size ends with one extra empty message.
                    more_body = len(chunk) == self.chunk_size
                    await send(
                        {
                            "type": "http.response.body",
                            "body": chunk,
                            "more_body": more_body,
                        }
                    )
        if self.background is not None:
            await self.background()