from __future__ import annotations

import errno
import importlib.util
import os
import stat
import typing
from email.utils import parsedate

import anyio
import anyio.to_thread

from starlette._utils import get_route_path
from starlette.datastructures import URL, Headers
from starlette.exceptions import HTTPException
from starlette.responses import FileResponse, RedirectResponse, Response
from starlette.types import Receive, Scope, Send
17
# Type accepted wherever this module takes a filesystem path (e.g. the
# StaticFiles `directory` argument): a plain string or any `os.PathLike[str]`.
PathLike = typing.Union[str, "os.PathLike[str]"]
19
20
class NotModifiedResponse(Response):
    """
    A bodiless `304 Not Modified` response.

    Built from the headers of the full response that would otherwise have
    been sent, keeping only the names listed in `NOT_MODIFIED_HEADERS`
    (the set RFC 7232 section 4.1 says a 304 should repeat).
    """

    NOT_MODIFIED_HEADERS = (
        "cache-control",
        "content-location",
        "date",
        "etag",
        "expires",
        "vary",
    )

    def __init__(self, headers: Headers):
        # Filter the original headers down to the allow-listed names,
        # preserving their original order.
        kept = {}
        for key, val in headers.items():
            if key not in self.NOT_MODIFIED_HEADERS:
                continue
            kept[key] = val
        super().__init__(headers=kept, status_code=304)
40
41
class StaticFiles:
    """
    ASGI application serving files from a directory and/or from static
    directories bundled inside Python packages.

    Only GET and HEAD requests are served. In HTML mode (`html=True`) a
    directory URL is answered with its `index.html`, and a missing file falls
    back to a top-level `404.html` when one exists.
    """

    def __init__(
        self,
        *,
        directory: PathLike | None = None,
        packages: list[str | tuple[str, str]] | None = None,
        html: bool = False,
        check_dir: bool = True,
        follow_symlink: bool = False,
    ) -> None:
        """
        directory: filesystem directory to serve files from.
        packages: package names (served from their "statics" subdirectory)
            or `(package, subdirectory)` tuples.
        html: enable HTML mode (`index.html` / `404.html` handling).
        check_dir: raise immediately if `directory` is not an existing directory.
        follow_symlink: when true, paths are resolved lexically (`abspath`)
            instead of via `realpath`, so symlinks inside a served directory
            may point outside of it.

        Raises RuntimeError if `check_dir` is set and `directory` is missing.
        """
        self.directory = directory
        self.packages = packages
        self.all_directories = self.get_directories(directory, packages)
        self.html = html
        self.config_checked = False
        self.follow_symlink = follow_symlink
        if check_dir and directory is not None and not os.path.isdir(directory):
            raise RuntimeError(f"Directory '{directory}' does not exist")

    def get_directories(
        self,
        directory: PathLike | None = None,
        packages: list[str | tuple[str, str]] | None = None,
    ) -> list[PathLike]:
        """
        Given `directory` and `packages` arguments, return a list of all the
        directories that should be used for serving static files from.

        `directory` (if given) comes first, followed by each package's static
        directory in the order given. Missing packages or directories raise
        AssertionError.
        """
        directories = []
        if directory is not None:
            directories.append(directory)

        for package in packages or []:
            if isinstance(package, tuple):
                package, statics_dir = package
            else:
                statics_dir = "statics"
            spec = importlib.util.find_spec(package)
            assert spec is not None, f"Package {package!r} could not be found."
            assert spec.origin is not None, f"Package {package!r} could not be found."
            # spec.origin is the package's module file; its parent directory
            # is the package directory that holds `statics_dir`.
            package_directory = os.path.normpath(
                os.path.join(spec.origin, "..", statics_dir)
            )
            assert os.path.isdir(
                package_directory
            ), f"Directory '{statics_dir!r}' in package {package!r} could not be found."
            directories.append(package_directory)

        return directories

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        The ASGI entry point.
        """
        assert scope["type"] == "http"

        # Validate the configured directory once, on the first request.
        if not self.config_checked:
            await self.check_config()
            self.config_checked = True

        path = self.get_path(scope)
        response = await self.get_response(path, scope)
        await response(scope, receive, send)

    def get_path(self, scope: Scope) -> str:
        """
        Given the ASGI scope, return the `path` string to serve up,
        with OS specific path separators, and any '..', '.' components removed.
        """
        route_path = get_route_path(scope)
        return os.path.normpath(os.path.join(*route_path.split("/")))

    async def get_response(self, path: str, scope: Scope) -> Response:
        """
        Returns an HTTP response, given the incoming path, method and request headers.

        Raises HTTPException with 405 for methods other than GET/HEAD, 401
        when the filesystem denies access, and 404 when nothing matches (and,
        in HTML mode, no `404.html` is available).
        """
        if scope["method"] not in ("GET", "HEAD"):
            raise HTTPException(status_code=405)

        try:
            full_path, stat_result = await anyio.to_thread.run_sync(
                self.lookup_path, path
            )
        except PermissionError:
            raise HTTPException(status_code=401)
        except OSError as exc:
            # A client-supplied name longer than the OS allows can never match
            # an existing file: report it as missing instead of a 500.
            if exc.errno == errno.ENAMETOOLONG:
                raise HTTPException(status_code=404) from exc
            raise

        if stat_result and stat.S_ISREG(stat_result.st_mode):
            # We have a static file to serve.
            return self.file_response(full_path, stat_result, scope)

        elif stat_result and stat.S_ISDIR(stat_result.st_mode) and self.html:
            # We're in HTML mode, and have got a directory URL.
            # Check if we have 'index.html' file to serve.
            index_path = os.path.join(path, "index.html")
            full_path, stat_result = await anyio.to_thread.run_sync(
                self.lookup_path, index_path
            )
            if stat_result is not None and stat.S_ISREG(stat_result.st_mode):
                if not scope["path"].endswith("/"):
                    # Directory URLs should redirect to always end in "/".
                    url = URL(scope=scope)
                    url = url.replace(path=url.path + "/")
                    return RedirectResponse(url=url)
                return self.file_response(full_path, stat_result, scope)

        if self.html:
            # Check for '404.html' if we're in HTML mode.
            full_path, stat_result = await anyio.to_thread.run_sync(
                self.lookup_path, "404.html"
            )
            if stat_result and stat.S_ISREG(stat_result.st_mode):
                return FileResponse(full_path, stat_result=stat_result, status_code=404)
        raise HTTPException(status_code=404)

    def lookup_path(self, path: str) -> tuple[str, os.stat_result | None]:
        """
        Find `path` in the first served directory that contains it.

        Returns `(full_path, stat_result)` for the first match, or `("", None)`
        when no directory contains it. Candidates that resolve outside their
        directory are skipped, so clients cannot escape the served roots.
        """
        for directory in self.all_directories:
            joined_path = os.path.join(directory, path)
            # The root must be resolved the same way as the candidate path;
            # otherwise a symlinked root (or a symlink anywhere above it)
            # makes the containment check compare unlike paths and fail.
            if self.follow_symlink:
                full_path = os.path.abspath(joined_path)
                root = os.path.abspath(directory)
            else:
                full_path = os.path.realpath(joined_path)
                root = os.path.realpath(directory)
            try:
                common = os.path.commonpath([full_path, root])
            except ValueError:
                # e.g. paths on different drives (Windows): cannot be inside root.
                continue
            if common != root:
                # Don't allow misbehaving clients to break out of the static files
                # directory.
                continue
            try:
                return full_path, os.stat(full_path)
            except (FileNotFoundError, NotADirectoryError):
                continue
        return "", None

    def file_response(
        self,
        full_path: PathLike,
        stat_result: os.stat_result,
        scope: Scope,
        status_code: int = 200,
    ) -> Response:
        """
        Build a FileResponse for `full_path`, downgraded to a 304
        NotModifiedResponse when the request's conditional headers allow it.
        """
        request_headers = Headers(scope=scope)

        response = FileResponse(
            full_path, status_code=status_code, stat_result=stat_result
        )
        if self.is_not_modified(response.headers, request_headers):
            return NotModifiedResponse(response.headers)
        return response

    async def check_config(self) -> None:
        """
        Perform a one-off configuration check that StaticFiles is actually
        pointed at a directory, so that we can raise loud errors rather than
        just returning 404 responses.
        """
        if self.directory is None:
            return

        try:
            stat_result = await anyio.to_thread.run_sync(os.stat, self.directory)
        except FileNotFoundError:
            raise RuntimeError(
                f"StaticFiles directory '{self.directory}' does not exist."
            )
        # os.stat follows symlinks, so a link to a directory already reports
        # S_ISDIR; the S_ISLNK test is kept only as a harmless safeguard.
        if not (stat.S_ISDIR(stat_result.st_mode) or stat.S_ISLNK(stat_result.st_mode)):
            raise RuntimeError(
                f"StaticFiles path '{self.directory}' is not a directory."
            )

    def is_not_modified(
        self, response_headers: Headers, request_headers: Headers
    ) -> bool:
        """
        Given the request and response headers, return `True` if an HTTP
        "Not Modified" response could be returned instead.

        Per RFC 7232 section 6, when the request carries `If-None-Match` that
        header alone decides and `If-Modified-Since` is ignored. ETags are
        compared weakly (a `W/` prefix is ignored); `*` matches any resource.
        """
        if_none_match = request_headers.get("if-none-match")
        if if_none_match:
            if if_none_match.strip() == "*":
                # We are serving a file, so a current representation exists.
                return True
            etag = response_headers.get("etag")
            if etag is None:
                return False
            current = etag.strip(" W/")
            return current in [tag.strip(" W/") for tag in if_none_match.split(",")]

        if_modified_since = request_headers.get("if-modified-since")
        last_modified = response_headers.get("last-modified")
        if if_modified_since is None or last_modified is None:
            return False
        since = parsedate(if_modified_since)
        modified = parsedate(last_modified)
        # parsedate returns None for unparseable dates; such headers are ignored.
        return since is not None and modified is not None and since >= modified