1from __future__ import annotations
2
3import html
4import inspect
5import traceback
6import typing
7
8from starlette._utils import is_async_callable
9from starlette.concurrency import run_in_threadpool
10from starlette.requests import Request
11from starlette.responses import HTMLResponse, PlainTextResponse, Response
12from starlette.types import ASGIApp, Message, Receive, Scope, Send
13
# CSS for the debug traceback page. It is substituted into TEMPLATE's
# ``{styles}`` placeholder as a str.format() *argument*, so the braces below
# are literal CSS and must not be doubled.
STYLES = """
p {
  color: #211c1c;
}
.traceback-container {
  border: 1px solid #038BB8;
}
.traceback-title {
  background-color: #038BB8;
  color: lemonchiffon;
  padding: 12px;
  font-size: 20px;
  margin-top: 0px;
}
.frame-line {
  padding-left: 10px;
  font-family: monospace;
}
.frame-filename {
  font-family: monospace;
}
.center-line {
  background-color: #038BB8;
  color: #f9f6e1;
  padding: 5px 0px 5px 5px;
}
.lineno {
  margin-right: 5px;
}
.frame-title {
  font-weight: unset;
  padding: 10px 10px 10px 10px;
  background-color: #E4F4FD;
  margin-right: 10px;
  color: #191f21;
  font-size: 17px;
  border: 1px solid #c7dce8;
}
.collapse-btn {
  float: right;
  padding: 0px 5px 1px 5px;
  border: solid 1px #96aebb;
  cursor: pointer;
}
.collapsed {
  display: none;
}
.source-code {
  font-family: courier;
  font-size: small;
  padding-bottom: 10px;
}
"""
67
# Client-side helper for the traceback page. ``collapse(element)`` reads the
# clicked button's ``data-frame-id``, finds the element with that id and
# toggles its "collapsed" class, swapping the button label between "+" and
# "‒". It is passed to TEMPLATE's ``{js}`` placeholder as a str.format()
# argument, so the JavaScript braces here are not escaped.
JS = """
<script type="text/javascript">
    function collapse(element){
        const frameId = element.getAttribute("data-frame-id");
        const frame = document.getElementById(frameId);

        if (frame.classList.contains("collapsed")){
            element.innerHTML = "‒";
            frame.classList.remove("collapsed");
        } else {
            element.innerHTML = "+";
            frame.classList.add("collapsed");
        }
    }
</script>
"""
84
# Full HTML page for the debug 500 response, filled in with str.format():
# ``{styles}`` (CSS), ``{error}`` (the "ExcType: message" heading),
# ``{exc_html}`` (the rendered frames) and ``{js}`` (the collapse script).
# Because this string itself is run through str.format(), any literal brace
# added here must be doubled ("{{" / "}}").
TEMPLATE = """
<html>
    <head>
        <style type='text/css'>
            {styles}
        </style>
        <title>Starlette Debugger</title>
    </head>
    <body>
        <h1>500 Server Error</h1>
        <h2>{error}</h2>
        <div class="traceback-container">
            <p class="traceback-title">Traceback</p>
            <div>{exc_html}</div>
        </div>
        {js}
    </body>
</html>
"""
104
# HTML for one traceback frame: a title line (file, line number, function
# name), a collapse button and the frame's source-context div. The button's
# ``data-frame-id`` and the div's ``id`` are both ``{frame_filename}-{frame_lineno}``
# so the collapse() script can find the div belonging to the button.
# NOTE(review): two frames with the same filename and line number (e.g. a
# recursive call) get duplicate ids; getElementById returns the first match,
# so the later frame's button would toggle the earlier frame instead.
FRAME_TEMPLATE = """
<div>
    <p class="frame-title">File <span class="frame-filename">{frame_filename}</span>,
    line <i>{frame_lineno}</i>,
    in <b>{frame_name}</b>
    <span class="collapse-btn" data-frame-id="{frame_filename}-{frame_lineno}" onclick="collapse(this)">{collapse_button}</span>
    </p>
    <div id="{frame_filename}-{frame_lineno}" class="source-code {collapsed}">{code_context}</div>
</div>
"""  # noqa: E501
115
# HTML for one ordinary line of a frame's source context. ``{lineno}`` is the
# line number in the file and ``{line}`` is the source text, which the caller
# is expected to have HTML-escaped already.
LINE = """
<p><span class="frame-line">
<span class="lineno">{lineno}.</span> {line}</span></p>
"""
120
# Same placeholders as LINE, but with the ``center-line`` class so the frame's
# current line (the one its traceback entry points at) is highlighted.
CENTER_LINE = """
<p class="center-line"><span class="frame-line center-line">
<span class="lineno">{lineno}.</span> {line}</span></p>
"""
125
126
class ServerErrorMiddleware:
    """
    Handles returning 500 responses when a server error occurs.

    If 'debug' is set, then traceback responses will be returned,
    otherwise the designated 'handler' will be called.

    This middleware class should generally be used to wrap *everything*
    else up, so that unhandled exceptions anywhere in the stack
    always result in an appropriate 500 response.

    The exception is always re-raised after the error response has been
    dealt with, so servers can still log it and test clients can surface it.
    """

    def __init__(
        self,
        app: ASGIApp,
        handler: typing.Callable[[Request, Exception], typing.Any] | None = None,
        debug: bool = False,
    ) -> None:
        """
        :param app: The wrapped ASGI application.
        :param handler: Optional callable invoked as ``handler(request, exc)``
            to build the error response when ``debug`` is false. Async
            callables are awaited; sync ones run in a threadpool.
        :param debug: When true, render a traceback response instead of
            calling ``handler``.
        """
        self.app = app
        self.handler = handler
        self.debug = debug

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # Only HTTP connections get a 500 response; anything else is passed
        # straight through untouched.
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        response_started = False

        async def _send(message: Message) -> None:
            # Remember whether the app already began its response: once the
            # start message has gone out we can no longer send our own 500.
            nonlocal response_started
            if message["type"] == "http.response.start":
                response_started = True
            await send(message)

        try:
            await self.app(scope, receive, _send)
        except Exception as exc:
            request = Request(scope)
            if self.debug:
                # In debug mode, return traceback responses.
                response = self.debug_response(request, exc)
            elif self.handler is None:
                # Use our default 500 error handler.
                response = self.error_response(request, exc)
            elif is_async_callable(self.handler):
                # Use an installed (async) 500 error handler.
                response = await self.handler(request, exc)
            else:
                # Use an installed (sync) 500 error handler off the event loop.
                response = await run_in_threadpool(self.handler, request, exc)

            if not response_started:
                await response(scope, receive, send)

            # We always continue to raise the exception.
            # This allows servers to log the error, or allows test clients
            # to optionally raise the error within the test case.
            # A bare ``raise`` re-raises ``exc`` without adding a duplicate
            # traceback entry for this frame.
            raise

    def format_line(self, index: int, line: str, frame_lineno: int, frame_index: int) -> str:
        """
        Render one line of a frame's source context as HTML.

        :param index: Position of ``line`` within the frame's code context.
        :param line: Raw source text of the line.
        :param frame_lineno: File line number of the frame's current line.
        :param frame_index: Position of the current line within the code context.
        :return: The line rendered with ``CENTER_LINE`` when it is the frame's
            current line, otherwise with ``LINE``.
        """
        values = {
            # HTML escape - line could contain < or >. Escape first, then turn
            # spaces into explicit non-breaking-space entities so indentation
            # survives HTML whitespace collapsing.
            "line": html.escape(line).replace(" ", "&nbsp;"),
            # The code context starts ``frame_index`` lines before the current
            # line, so offset from there.
            "lineno": (frame_lineno - frame_index) + index,
        }
        template = CENTER_LINE if index == frame_index else LINE
        return template.format(**values)

    def generate_frame_html(self, frame: inspect.FrameInfo, is_collapsed: bool) -> str:
        """
        Render one traceback frame (title, collapse button, source context).

        :param frame: Frame record as returned by ``inspect.getinnerframes``.
        :param is_collapsed: Whether the source context starts hidden.
        :return: HTML for the frame, built from ``FRAME_TEMPLATE``.
        """
        # ``frame.index`` is only None when there is no code context, in which
        # case the generator below never calls format_line.
        code_context = "".join(
            self.format_line(
                index,
                line,
                frame.lineno,
                frame.index,  # type: ignore[arg-type]
            )
            for index, line in enumerate(frame.code_context or [])
        )

        values = {
            # HTML escape - filename could contain < or >, especially if it's a virtual
            # file e.g. <stdin> in the REPL
            "frame_filename": html.escape(frame.filename),
            "frame_lineno": frame.lineno,
            # HTML escape - if you try very hard it's possible to name a function with <
            # or >
            "frame_name": html.escape(frame.function),
            "code_context": code_context,
            "collapsed": "collapsed" if is_collapsed else "",
            "collapse_button": "+" if is_collapsed else "‒",
        }
        return FRAME_TEMPLATE.format(**values)

    def generate_html(self, exc: Exception, limit: int = 7) -> str:
        """
        Render a full HTML debug page for ``exc``.

        :param exc: The exception being reported.
        :param limit: Number of source-context lines requested per frame
            (passed to ``inspect.getinnerframes``).
        :return: The page built from ``TEMPLATE``.
        """
        # Used only for the message text: str() of a TracebackException gives
        # the exception message without letting a failing __str__ propagate.
        # Locals are never rendered, so they are not captured.
        traceback_obj = traceback.TracebackException.from_exception(exc)

        frame_html = []
        exc_traceback = exc.__traceback__
        if exc_traceback is not None:
            frames = inspect.getinnerframes(exc_traceback, limit)
            # Innermost frame first; only that one starts expanded.
            for position, frame in enumerate(reversed(frames)):
                frame_html.append(self.generate_frame_html(frame, position > 0))
        exc_html = "".join(frame_html)

        # escape error class and text. ``type(exc)`` is used instead of
        # ``TracebackException.exc_type``, which is deprecated in Python 3.13.
        error = f"{html.escape(type(exc).__name__)}: {html.escape(str(traceback_obj))}"

        return TEMPLATE.format(styles=STYLES, js=JS, error=error, exc_html=exc_html)

    def generate_plain_text(self, exc: Exception) -> str:
        """Return the standard formatted traceback of ``exc`` as one string."""
        return "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))

    def debug_response(self, request: Request, exc: Exception) -> Response:
        """
        Build the debug-mode 500 response.

        Clients whose ``Accept`` header mentions ``text/html`` get the HTML
        traceback page; everyone else gets the plain-text traceback.
        """
        accept = request.headers.get("accept", "")

        if "text/html" in accept:
            content = self.generate_html(exc)
            return HTMLResponse(content, status_code=500)
        content = self.generate_plain_text(exc)
        return PlainTextResponse(content, status_code=500)

    def error_response(self, request: Request, exc: Exception) -> Response:
        """Build the default non-debug 500 response (no exception details)."""
        return PlainTextResponse("Internal Server Error", status_code=500)