1from __future__ import annotations
2
3import http.cookies
4import json
5import os
6import stat
7import typing
8import warnings
9from datetime import datetime
10from email.utils import format_datetime, formatdate
11from functools import partial
12from mimetypes import guess_type
13from urllib.parse import quote
14
15import anyio
16import anyio.to_thread
17
18from starlette._compat import md5_hexdigest
19from starlette.background import BackgroundTask
20from starlette.concurrency import iterate_in_threadpool
21from starlette.datastructures import URL, MutableHeaders
22from starlette.types import Receive, Scope, Send
23
24
class Response:
    """An ASGI response whose entire body is rendered up front.

    ``content`` is turned into bytes by :meth:`render` during construction,
    after which :meth:`init_headers` derives ``content-length`` and
    ``content-type`` unless the caller already supplied them.
    """

    media_type = None
    charset = "utf-8"

    def __init__(
        self,
        content: typing.Any = None,
        status_code: int = 200,
        headers: typing.Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        self.status_code = status_code
        # Only shadow the class-level default when the caller overrides it.
        if media_type is not None:
            self.media_type = media_type
        self.background = background
        # The body must exist before init_headers(), which measures it.
        self.body = self.render(content)
        self.init_headers(headers)

    def render(self, content: typing.Any) -> bytes:
        """Return *content* as bytes.

        ``None`` becomes ``b""``, bytes pass through untouched, and anything
        else is encoded with :attr:`charset` (so it is expected to be a str).
        """
        if content is None:
            return b""
        return (
            content
            if isinstance(content, bytes)
            else content.encode(self.charset)  # type: ignore
        )

    def init_headers(self, headers: typing.Mapping[str, str] | None = None) -> None:
        """Build :attr:`raw_headers` from *headers* plus derived defaults.

        Header names are lower-cased, and both names and values are
        latin-1 encoded. A ``content-length`` header is added when:

        * a ``body`` attribute exists,
        * the caller did not set the header, and
        * the status is not 1xx, 204 or 304.

        A ``content-type`` header is added from :attr:`media_type` when the
        caller did not set one. ``text/*`` types get ``; charset=<charset>``
        appended unless they already name a charset.
        """
        raw_headers: list[tuple[bytes, bytes]] = []
        if headers is not None:
            raw_headers.extend(
                (name.lower().encode("latin-1"), value.encode("latin-1"))
                for name, value in headers.items()
            )
        supplied = {name for name, _ in raw_headers}

        # Streaming/file subclasses may not set ``body`` before calling this.
        body = getattr(self, "body", None)
        if (
            body is not None
            and b"content-length" not in supplied
            and self.status_code >= 200
            and self.status_code not in (204, 304)
        ):
            raw_headers.append((b"content-length", str(len(body)).encode("latin-1")))

        content_type = self.media_type
        if content_type is not None and b"content-type" not in supplied:
            if (
                content_type.startswith("text/")
                and "charset=" not in content_type.lower()
            ):
                content_type = content_type + "; charset=" + self.charset
            raw_headers.append((b"content-type", content_type.encode("latin-1")))

        self.raw_headers = raw_headers

    @property
    def headers(self) -> MutableHeaders:
        """A lazily created, cached mutable view over :attr:`raw_headers`."""
        try:
            return self._headers
        except AttributeError:
            self._headers = MutableHeaders(raw=self.raw_headers)
            return self._headers

    def set_cookie(
        self,
        key: str,
        value: str = "",
        max_age: int | None = None,
        expires: datetime | str | int | None = None,
        path: str = "/",
        domain: str | None = None,
        secure: bool = False,
        httponly: bool = False,
        samesite: typing.Literal["lax", "strict", "none"] | None = "lax",
    ) -> None:
        """Append a ``set-cookie`` header for *key* = *value*.

        ``expires`` values that are ``datetime`` objects are formatted with
        ``email.utils.format_datetime(..., usegmt=True)``, which requires a
        UTC-aware datetime. Other values are handed to ``http.cookies`` as-is.
        ``samesite`` must be ``strict``, ``lax`` or ``none``
        (case-insensitive). This is checked with ``assert``.
        """
        cookie: http.cookies.BaseCookie[str] = http.cookies.SimpleCookie()
        cookie[key] = value
        morsel = cookie[key]
        if max_age is not None:
            morsel["max-age"] = max_age
        if expires is not None:
            morsel["expires"] = (
                format_datetime(expires, usegmt=True)
                if isinstance(expires, datetime)
                else expires
            )
        if path is not None:
            morsel["path"] = path
        if domain is not None:
            morsel["domain"] = domain
        if secure:
            morsel["secure"] = True
        if httponly:
            morsel["httponly"] = True
        if samesite is not None:
            assert samesite.lower() in (
                "strict",
                "lax",
                "none",
            ), "samesite must be either 'strict', 'lax' or 'none'"
            morsel["samesite"] = samesite
        header_value = cookie.output(header="").strip()
        self.raw_headers.append((b"set-cookie", header_value.encode("latin-1")))

    def delete_cookie(
        self,
        key: str,
        path: str = "/",
        domain: str | None = None,
        secure: bool = False,
        httponly: bool = False,
        samesite: typing.Literal["lax", "strict", "none"] | None = "lax",
    ) -> None:
        """Expire cookie *key* by re-sending it with a zero lifetime.

        *path* and *domain* should match the values used when the cookie
        was set.
        """
        self.set_cookie(
            key,
            expires=0,
            max_age=0,
            path=path,
            domain=domain,
            secure=secure,
            httponly=httponly,
            samesite=samesite,
        )

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Send the start and body messages, then run any background task.

        For websocket scopes, message types get a ``websocket.`` prefix.
        """
        prefix = "websocket." if scope["type"] == "websocket" else ""
        start_message = {
            "type": prefix + "http.response.start",
            "status": self.status_code,
            "headers": self.raw_headers,
        }
        await send(start_message)
        body_message = {"type": prefix + "http.response.body", "body": self.body}
        await send(body_message)

        if self.background is not None:
            await self.background()
163
164
class HTMLResponse(Response):
    """Response sent with a ``text/html`` media type.

    Being a ``text/*`` type, the base class appends the charset to the
    ``content-type`` header.
    """

    media_type = "text/html"


class PlainTextResponse(Response):
    """Response sent with a ``text/plain`` media type.

    Being a ``text/*`` type, the base class appends the charset to the
    ``content-type`` header.
    """

    media_type = "text/plain"
171
172
class JSONResponse(Response):
    """Response whose body is *content* serialized as compact UTF-8 JSON.

    Non-ASCII characters are emitted verbatim rather than ``\\u``-escaped.
    NaN/Infinity are rejected with ``ValueError`` (``allow_nan=False``).
    """

    media_type = "application/json"

    def __init__(
        self,
        content: typing.Any,
        status_code: int = 200,
        headers: typing.Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        # Redeclared only to make ``content`` a required argument.
        super().__init__(content, status_code, headers, media_type, background)

    def render(self, content: typing.Any) -> bytes:
        """Serialize *content* to JSON with no whitespace between tokens."""
        serialized = json.dumps(
            content,
            ensure_ascii=False,
            allow_nan=False,
            indent=None,
            separators=(",", ":"),
        )
        return serialized.encode("utf-8")
194
195
class RedirectResponse(Response):
    """Empty-bodied response that sends the client to *url*.

    Defaults to ``307 Temporary Redirect``. The target is percent-quoted
    into the ``location`` header. URL delimiters and already-escaped
    ``%XX`` sequences are left intact.
    """

    def __init__(
        self,
        url: str | URL,
        status_code: int = 307,
        headers: typing.Mapping[str, str] | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        super().__init__(
            content=b"",
            status_code=status_code,
            headers=headers,
            background=background,
        )
        location = quote(str(url), safe=":/%#?=@[]!$&'()*+,;")
        self.headers["location"] = location
208
209
# A single body chunk. StreamingResponse encodes str chunks with the
# response charset before sending.
# NOTE(review): typing.Union (not ``X | Y``) is used because these aliases are
# evaluated at runtime, where the future-annotations import does not apply —
# presumably to keep pre-3.10 support; confirm before modernizing.
Content = typing.Union[str, bytes]
# A synchronous iterable of chunks. StreamingResponse consumes it via
# iterate_in_threadpool.
SyncContentStream = typing.Iterable[Content]
# An asynchronous iterable of chunks, consumed with ``async for``.
AsyncContentStream = typing.AsyncIterable[Content]
# Either kind of chunk source accepted by StreamingResponse.
ContentStream = typing.Union[AsyncContentStream, SyncContentStream]
214
215
class StreamingResponse(Response):
    """Response whose body is produced incrementally from an iterable of chunks.

    Async iterables are consumed directly. Sync iterables are wrapped with
    ``iterate_in_threadpool``. No ``body`` attribute is set, so
    ``init_headers`` does not add a ``content-length`` header.
    """

    body_iterator: AsyncContentStream

    def __init__(
        self,
        content: ContentStream,
        status_code: int = 200,
        headers: typing.Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        is_async = isinstance(content, typing.AsyncIterable)
        self.body_iterator = content if is_async else iterate_in_threadpool(content)
        self.status_code = status_code
        self.media_type = media_type if media_type is not None else self.media_type
        self.background = background
        self.init_headers(headers)

    async def listen_for_disconnect(self, receive: Receive) -> None:
        """Drain incoming messages until an ``http.disconnect`` arrives."""
        message = await receive()
        while message["type"] != "http.disconnect":
            message = await receive()

    async def stream_response(self, send: Send) -> None:
        """Send the start message, each chunk, then a final empty body message."""
        await send(
            {
                "type": "http.response.start",
                "status": self.status_code,
                "headers": self.raw_headers,
            }
        )
        async for chunk in self.body_iterator:
            payload = chunk if isinstance(chunk, bytes) else chunk.encode(self.charset)
            await send({"type": "http.response.body", "body": payload, "more_body": True})

        await send({"type": "http.response.body", "body": b"", "more_body": False})

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Stream the body while watching for disconnect; run background after."""
        async with anyio.create_task_group() as task_group:

            async def run_and_cancel(
                func: typing.Callable[[], typing.Awaitable[None]],
            ) -> None:
                # Whichever side finishes first (body fully sent, or client
                # disconnected) cancels the other via the shared scope.
                await func()
                task_group.cancel_scope.cancel()

            task_group.start_soon(run_and_cancel, partial(self.stream_response, send))
            await run_and_cancel(partial(self.listen_for_disconnect, receive))

        if self.background is not None:
            await self.background()
269
270
class FileResponse(Response):
    """Send the contents of the file at *path*, read in ``chunk_size`` pieces.

    ``content-length``, ``last-modified`` and ``etag`` come from the file's
    stat result. It is taken from *stat_result* when given; otherwise
    ``os.stat`` runs in a worker thread when the response is sent.
    """

    chunk_size = 64 * 1024

    def __init__(
        self,
        path: str | os.PathLike[str],
        status_code: int = 200,
        headers: typing.Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
        filename: str | None = None,
        stat_result: os.stat_result | None = None,
        method: str | None = None,
        content_disposition_type: str = "attachment",
    ) -> None:
        self.path = path
        self.status_code = status_code
        self.filename = filename
        if method is not None:
            # stacklevel=2 attributes the warning to the caller, not this file.
            warnings.warn(
                "The 'method' parameter is not used, and it will be removed.",
                DeprecationWarning,
                stacklevel=2,
            )
        if media_type is None:
            media_type = guess_type(filename or path)[0] or "text/plain"
        self.media_type = media_type
        self.background = background
        self.init_headers(headers)
        if self.filename is not None:
            content_disposition_filename = quote(self.filename)
            if content_disposition_filename != self.filename:
                # The name needed escaping: use the RFC 5987 ``filename*`` form.
                content_disposition = (
                    f"{content_disposition_type}; "
                    f"filename*=utf-8''{content_disposition_filename}"
                )
            else:
                content_disposition = (
                    f'{content_disposition_type}; filename="{self.filename}"'
                )
            self.headers.setdefault("content-disposition", content_disposition)
        self.stat_result = stat_result
        if stat_result is not None:
            self.set_stat_headers(stat_result)

    def set_stat_headers(self, stat_result: os.stat_result) -> None:
        """Derive length, last-modified and an ETag from *stat_result*.

        Uses ``setdefault``, so headers the caller already supplied win.
        The ETag is an MD5 of ``"<mtime>-<size>"``. It is not used for
        security.
        """
        content_length = str(stat_result.st_size)
        last_modified = formatdate(stat_result.st_mtime, usegmt=True)
        etag_base = str(stat_result.st_mtime) + "-" + str(stat_result.st_size)
        etag = f'"{md5_hexdigest(etag_base.encode(), usedforsecurity=False)}"'

        self.headers.setdefault("content-length", content_length)
        self.headers.setdefault("last-modified", last_modified)
        self.headers.setdefault("etag", etag)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Send headers, then the file body unless this is a HEAD request.

        Raises:
            RuntimeError: the path does not exist or is not a regular file
                (only checked when no *stat_result* was supplied).
        """
        if self.stat_result is None:
            try:
                stat_result = await anyio.to_thread.run_sync(os.stat, self.path)
            except FileNotFoundError as exc:
                raise RuntimeError(
                    f"File at path {self.path} does not exist."
                ) from exc
            if not stat.S_ISREG(stat_result.st_mode):
                raise RuntimeError(f"File at path {self.path} is not a file.")
            # Only advertise stat-derived headers once the path is known good.
            self.set_stat_headers(stat_result)
        await send(
            {
                "type": "http.response.start",
                "status": self.status_code,
                "headers": self.raw_headers,
            }
        )
        if scope["method"].upper() == "HEAD":
            await send({"type": "http.response.body", "body": b"", "more_body": False})
        elif "extensions" in scope and "http.response.pathsend" in scope["extensions"]:
            # The server supports sending the file by path itself.
            await send({"type": "http.response.pathsend", "path": str(self.path)})
        else:
            async with await anyio.open_file(self.path, mode="rb") as file:
                more_body = True
                while more_body:
                    chunk = await file.read(self.chunk_size)
                    # A short read means EOF. A file whose size is an exact
                    # multiple of chunk_size ends with one empty final chunk.
                    more_body = len(chunk) == self.chunk_size
                    await send(
                        {
                            "type": "http.response.body",
                            "body": chunk,
                            "more_body": more_body,
                        }
                    )
        if self.background is not None:
            await self.background()