1from __future__ import annotations
2
3import errno
4import importlib.util
5import os
6import stat
7import typing
8from email.utils import parsedate
9
10import anyio
11import anyio.to_thread
12
13from starlette._utils import get_route_path
14from starlette.datastructures import URL, Headers
15from starlette.exceptions import HTTPException
16from starlette.responses import FileResponse, RedirectResponse, Response
17from starlette.types import Receive, Scope, Send
18
# Type alias for the filesystem path arguments used by StaticFiles: either a
# plain string or an os.PathLike whose fspath is a str. The os.PathLike[str]
# member is written as a string so it is not subscripted at runtime.
PathLike = typing.Union[str, "os.PathLike[str]"]
20
21
class NotModifiedResponse(Response):
    """
    A bodiless `304 Not Modified` response.

    Only the headers named in `NOT_MODIFIED_HEADERS` are copied over from the
    headers of the full response that would otherwise have been sent.
    """

    NOT_MODIFIED_HEADERS = (
        "cache-control",
        "content-location",
        "date",
        "etag",
        "expires",
        "vary",
    )

    def __init__(self, headers: Headers):
        allowed = self.NOT_MODIFIED_HEADERS
        kept_headers = {}
        for header_name, header_value in headers.items():
            # Later duplicates overwrite earlier ones, as in a dict build.
            if header_name in allowed:
                kept_headers[header_name] = header_value
        super().__init__(status_code=304, headers=kept_headers)
37
38
class StaticFiles:
    """
    ASGI application that serves static files.

    Files are looked up in `directory` first and then in the static
    directories of each entry in `packages`, in that order.
    """

    def __init__(
        self,
        *,
        directory: PathLike | None = None,
        packages: list[str | tuple[str, str]] | None = None,
        html: bool = False,
        check_dir: bool = True,
        follow_symlink: bool = False,
    ) -> None:
        """
        Args:
            directory: Filesystem directory to serve files from.
            packages: Importable package names, or `(package, subdirectory)`
                tuples, whose static directory is also searched. A bare
                package name uses the `"statics"` subdirectory.
            html: Enable HTML mode. Directory URLs serve their `index.html`
                and missing files serve a top-level `404.html` when present.
            check_dir: Raise immediately if `directory` is not an existing
                directory.
            follow_symlink: When True, containment is checked on the
                unresolved path. Symlinks inside a served directory may then
                point outside it.

        Raises:
            RuntimeError: If `check_dir` is set and `directory` does not exist.
        """
        self.directory = directory
        self.packages = packages
        self.all_directories = self.get_directories(directory, packages)
        self.html = html
        self.config_checked = False
        self.follow_symlink = follow_symlink
        if check_dir and directory is not None and not os.path.isdir(directory):
            raise RuntimeError(f"Directory '{directory}' does not exist")

    def get_directories(
        self,
        directory: PathLike | None = None,
        packages: list[str | tuple[str, str]] | None = None,
    ) -> list[PathLike]:
        """
        Given `directory` and `packages` arguments, return a list of all the
        directories that should be used for serving static files from.

        Raises:
            AssertionError: If a package cannot be imported or its static
                directory does not exist.
        """
        directories = []
        if directory is not None:
            directories.append(directory)

        for package in packages or []:
            if isinstance(package, tuple):
                package, statics_dir = package
            else:
                statics_dir = "statics"
            spec = importlib.util.find_spec(package)
            assert spec is not None, f"Package {package!r} could not be found."
            assert spec.origin is not None, f"Package {package!r} could not be found."
            # spec.origin is the package's module file; its static dir is a sibling.
            package_directory = os.path.normpath(os.path.join(spec.origin, "..", statics_dir))
            assert os.path.isdir(
                package_directory
            ), f"Directory '{statics_dir!r}' in package {package!r} could not be found."
            directories.append(package_directory)

        return directories

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        The ASGI entry point.
        """
        assert scope["type"] == "http"

        if not self.config_checked:
            await self.check_config()
            self.config_checked = True

        path = self.get_path(scope)
        response = await self.get_response(path, scope)
        await response(scope, receive, send)

    def get_path(self, scope: Scope) -> str:
        """
        Given the ASGI scope, return the `path` string to serve up,
        with OS specific path separators, and any '..', '.' components removed.
        """
        route_path = get_route_path(scope)
        return os.path.normpath(os.path.join(*route_path.split("/")))

    async def get_response(self, path: str, scope: Scope) -> Response:
        """
        Returns an HTTP response, given the incoming path, method and request headers.

        Raises:
            HTTPException: 405 for methods other than GET/HEAD, 401 on a
                permission error, 404 when nothing can be served.
        """
        if scope["method"] not in ("GET", "HEAD"):
            raise HTTPException(status_code=405)

        try:
            # Filesystem calls block, so run the lookup in a worker thread.
            full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, path)
        except PermissionError:
            raise HTTPException(status_code=401)
        except OSError as exc:
            # Filename is too long, so it can't be a valid static file.
            if exc.errno == errno.ENAMETOOLONG:
                raise HTTPException(status_code=404)
            raise

        if stat_result and stat.S_ISREG(stat_result.st_mode):
            # We have a static file to serve.
            return self.file_response(full_path, stat_result, scope)

        elif stat_result and stat.S_ISDIR(stat_result.st_mode) and self.html:
            # We're in HTML mode, and have got a directory URL.
            # Check if we have 'index.html' file to serve.
            index_path = os.path.join(path, "index.html")
            full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, index_path)
            if stat_result is not None and stat.S_ISREG(stat_result.st_mode):
                if not scope["path"].endswith("/"):
                    # Directory URLs should redirect to always end in "/".
                    url = URL(scope=scope)
                    url = url.replace(path=url.path + "/")
                    return RedirectResponse(url=url)
                return self.file_response(full_path, stat_result, scope)

        if self.html:
            # Check for '404.html' if we're in HTML mode.
            full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, "404.html")
            if stat_result and stat.S_ISREG(stat_result.st_mode):
                return FileResponse(full_path, stat_result=stat_result, status_code=404)
        raise HTTPException(status_code=404)

    def lookup_path(self, path: str) -> tuple[str, os.stat_result | None]:
        """
        Find `path` in the first served directory that contains it.

        Returns `(full_path, stat_result)` for the first match. Returns
        `("", None)` when no directory contains the path, or when the path
        would escape every directory.
        """
        for directory in self.all_directories:
            joined_path = os.path.join(directory, path)
            if self.follow_symlink:
                # Lexical containment: symlinks below the base may point
                # elsewhere. Normalise the base the same way as the target,
                # otherwise a symlinked base directory never matches.
                full_path = os.path.abspath(joined_path)
                base_dir = os.path.abspath(directory)
            else:
                # Resolve symlinks on both sides so a link cannot escape.
                full_path = os.path.realpath(joined_path)
                base_dir = os.path.realpath(directory)
            if os.path.commonpath([full_path, base_dir]) != base_dir:
                # Don't allow misbehaving clients to break out of the static files
                # directory.
                continue
            try:
                return full_path, os.stat(full_path)
            except (FileNotFoundError, NotADirectoryError):
                continue
        return "", None

    def file_response(
        self,
        full_path: PathLike,
        stat_result: os.stat_result,
        scope: Scope,
        status_code: int = 200,
    ) -> Response:
        """
        Build a FileResponse for `full_path`.

        If the request's conditional headers show the client copy is current,
        a 304 NotModifiedResponse is returned instead.
        """
        request_headers = Headers(scope=scope)

        response = FileResponse(full_path, status_code=status_code, stat_result=stat_result)
        if self.is_not_modified(response.headers, request_headers):
            return NotModifiedResponse(response.headers)
        return response

    async def check_config(self) -> None:
        """
        Perform a one-off configuration check that StaticFiles is actually
        pointed at a directory, so that we can raise loud errors rather than
        just returning 404 responses.

        Raises:
            RuntimeError: If `directory` is missing or is not a directory.
        """
        if self.directory is None:
            return

        try:
            stat_result = await anyio.to_thread.run_sync(os.stat, self.directory)
        except FileNotFoundError:
            raise RuntimeError(f"StaticFiles directory '{self.directory}' does not exist.")
        # os.stat follows symlinks, so S_ISLNK is effectively never true here.
        if not (stat.S_ISDIR(stat_result.st_mode) or stat.S_ISLNK(stat_result.st_mode)):
            raise RuntimeError(f"StaticFiles path '{self.directory}' is not a directory.")

    def is_not_modified(self, response_headers: Headers, request_headers: Headers) -> bool:
        """
        Given the request and response headers, return `True` if an HTTP
        "Not Modified" response could be returned instead.

        Follows RFC 9110 section 13.1. If-None-Match is evaluated with weak
        comparison and, when present, If-Modified-Since is ignored entirely.
        """
        if_none_match = request_headers.get("if-none-match")
        if if_none_match:
            # "*" matches any current representation, and this file exists.
            if if_none_match.strip() == "*":
                return True
            etag = response_headers.get("etag")
            if etag is None:
                return False
            # Weak comparison: ignore the "W/" prefix on both sides.
            client_tags = [tag.strip(" W/") for tag in if_none_match.split(",")]
            return etag.strip(" W/") in client_tags

        if_modified_since_raw = request_headers.get("if-modified-since")
        last_modified_raw = response_headers.get("last-modified")
        if if_modified_since_raw is None or last_modified_raw is None:
            return False
        if_modified_since = parsedate(if_modified_since_raw)
        last_modified = parsedate(last_modified_raw)
        # Unparseable dates are treated as absent.
        return if_modified_since is not None and last_modified is not None and if_modified_since >= last_modified