1from __future__ import annotations
2
3import html
4import inspect
5import sys
6import traceback
7
8from starlette._utils import is_async_callable
9from starlette.concurrency import run_in_threadpool
10from starlette.requests import Request
11from starlette.responses import HTMLResponse, PlainTextResponse, Response
12from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
13
# CSS for the debug traceback page. ServerErrorMiddleware.generate_html
# substitutes it into TEMPLATE's <style> element through the {styles}
# placeholder. The "collapsed" rule (display: none) is the class the
# collapse() script in JS adds/removes to hide or show a frame's source.
STYLES = """
p {
 color: #211c1c;
}
.traceback-container {
 border: 1px solid #038BB8;
}
.traceback-title {
 background-color: #038BB8;
 color: lemonchiffon;
 padding: 12px;
 font-size: 20px;
 margin-top: 0px;
}
.frame-line {
 padding-left: 10px;
 font-family: monospace;
}
.frame-filename {
 font-family: monospace;
}
.center-line {
 background-color: #038BB8;
 color: #f9f6e1;
 padding: 5px 0px 5px 5px;
}
.lineno {
 margin-right: 5px;
}
.frame-title {
 font-weight: unset;
 padding: 10px 10px 10px 10px;
 background-color: #E4F4FD;
 margin-right: 10px;
 color: #191f21;
 font-size: 17px;
 border: 1px solid #c7dce8;
}
.collapse-btn {
 float: right;
 padding: 0px 5px 1px 5px;
 border: solid 1px #96aebb;
 cursor: pointer;
}
.collapsed {
 display: none;
}
.source-code {
 font-family: courier;
 font-size: small;
 padding-bottom: 10px;
}
"""
67
# Inline <script> inserted into TEMPLATE through the {js} placeholder.
# collapse() is bound to every frame's collapse button (see FRAME_TEMPLATE)
# and toggles the "collapsed" class on that frame's source-code block,
# swapping the button label between "‒" (expanded) and "+" (collapsed).
#
# The source block is located structurally: it is the element immediately
# after the button's enclosing title paragraph. Looking it up by id is not
# reliable because ids are built from filename + line number, which repeat
# when a frame location recurs (e.g. recursion). getElementById returns only
# the first match, so every such button would toggle the same block. The id
# lookup is kept as a fallback for templates that use a different layout.
JS = """
<script type="text/javascript">
    function collapse(element){
        let frame = element.parentElement ? element.parentElement.nextElementSibling : null;
        if (!frame || !frame.classList.contains("source-code")) {
            const frameId = element.getAttribute("data-frame-id");
            frame = document.getElementById(frameId);
        }

        if (frame.classList.contains("collapsed")){
            element.innerHTML = "‒";
            frame.classList.remove("collapsed");
        } else {
            element.innerHTML = "+";
            frame.classList.add("collapsed");
        }
    }
</script>
"""
84
# Outer HTML page for debug responses, rendered with str.format() by
# ServerErrorMiddleware.generate_html. Placeholders:
#   {styles}   - the STYLES CSS
#   {error}    - HTML-escaped "ExceptionType: message" heading
#   {exc_html} - concatenated per-frame markup (FRAME_TEMPLATE output)
#   {js}       - the JS <script> block
# Because this goes through str.format(), any literal brace added here must
# be doubled ("{{" / "}}").
TEMPLATE = """
<html>
 <head>
 <style type='text/css'>
 {styles}
 </style>
 <title>Starlette Debugger</title>
 </head>
 <body>
 <h1>500 Server Error</h1>
 <h2>{error}</h2>
 <div class="traceback-container">
 <p class="traceback-title">Traceback</p>
 <div>{exc_html}</div>
 </div>
 {js}
 </body>
</html>
"""
104
# Markup for a single traceback frame, rendered with str.format() by
# ServerErrorMiddleware.generate_frame_html. The title paragraph shows file,
# line and function. The collapse button calls collapse() from JS. The
# following div holds the formatted source lines ({code_context}) and gets
# the "collapsed" class when {collapsed} is set.
# NOTE: the element id is "{frame_filename}-{frame_lineno}". It is not
# unique when the same file/line appears in more than one frame
# (e.g. recursion).
FRAME_TEMPLATE = """
<div>
 <p class="frame-title">File <span class="frame-filename">{frame_filename}</span>,
 line <i>{frame_lineno}</i>,
 in <b>{frame_name}</b>
 <span class="collapse-btn" data-frame-id="{frame_filename}-{frame_lineno}" onclick="collapse(this)">{collapse_button}</span>
 </p>
 <div id="{frame_filename}-{frame_lineno}" class="source-code {collapsed}">{code_context}</div>
</div>
"""  # noqa: E501
115
# Markup for an ordinary (non-highlighted) source-context line, rendered by
# ServerErrorMiddleware.format_line. {line} must already be HTML-escaped.
LINE = """
<p><span class="frame-line">
<span class="lineno">{lineno}.</span> {line}</span></p>
"""
120
# Markup for the highlighted source-context line: the one at the frame's own
# index within its code_context, i.e. the line that was executing. Rendered
# by ServerErrorMiddleware.format_line. {line} must already be HTML-escaped.
CENTER_LINE = """
<p class="center-line"><span class="frame-line center-line">
<span class="lineno">{lineno}.</span> {line}</span></p>
"""
125
126
class ServerErrorMiddleware:
    """
    Handles returning 500 responses when a server error occurs.

    If 'debug' is set, a traceback response is returned: HTML when the
    request's Accept header mentions text/html, plain text otherwise.
    If 'debug' is not set, the designated 'handler' is called. When no
    handler is installed, a plain "Internal Server Error" response is sent.

    This middleware class should generally be used to wrap *everything*
    else up, so that unhandled exceptions anywhere in the stack
    always result in an appropriate 500 response. The exception is always
    re-raised afterwards so the server can still log it.
    """

    def __init__(
        self,
        app: ASGIApp,
        handler: ExceptionHandler | None = None,
        debug: bool = False,
    ) -> None:
        """
        Args:
            app: The wrapped ASGI application.
            handler: Optional ``(request, exc) -> response`` callable used when
                ``debug`` is false. Async callables are awaited directly;
                sync ones are run in a threadpool.
            debug: Return traceback pages instead of calling ``handler``.
        """
        self.app = app
        self.handler = handler
        self.debug = debug

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Run the wrapped app and turn unhandled exceptions into 500 responses.

        Scopes whose type is not "http" are passed straight through. For HTTP
        requests, a 500 response is sent only if the app has not already
        started its own response. The exception is re-raised in every case.
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        response_started = False

        async def _send(message: Message) -> None:
            # Track whether response headers have gone out. Once they have,
            # a 500 response can no longer be sent for this request.
            nonlocal response_started

            if message["type"] == "http.response.start":
                response_started = True
            await send(message)

        try:
            await self.app(scope, receive, _send)
        except Exception as exc:
            request = Request(scope)
            if self.debug:
                # In debug mode, return traceback responses.
                response = self.debug_response(request, exc)
            elif self.handler is None:
                # Use our default 500 error handler.
                response = self.error_response(request, exc)
            elif is_async_callable(self.handler):
                # Use an installed async 500 error handler.
                response = await self.handler(request, exc)
            else:
                # Sync handlers run in a threadpool so they cannot block the
                # event loop.
                response = await run_in_threadpool(self.handler, request, exc)

            if not response_started:
                await response(scope, receive, send)

            # We always continue to raise the exception.
            # This allows servers to log the error, or allows test clients
            # to optionally raise the error within the test case.
            raise exc

    def format_line(self, index: int, line: str, frame_lineno: int, frame_index: int) -> str:
        """
        Render one line of a frame's source context as HTML.

        Args:
            index: Position of ``line`` within the frame's ``code_context``.
            line: Raw source text. It may contain ``<``, ``>`` or ``&``, so it
                is escaped.
            frame_lineno: Line number the frame is executing.
            frame_index: Position of that executing line within
                ``code_context``.

        Returns:
            ``CENTER_LINE`` markup for the executing line, ``LINE`` markup
            otherwise, with the absolute line number filled in.
        """
        values = {
            # HTML escape first (line could contain < or >), then turn spaces
            # into explicit non-breaking-space entities so the source's
            # indentation is not collapsed by HTML whitespace rules.
            "line": html.escape(line).replace(" ", "&nbsp;"),
            "lineno": (frame_lineno - frame_index) + index,
        }

        if index != frame_index:
            return LINE.format(**values)
        return CENTER_LINE.format(**values)

    def generate_frame_html(self, frame: inspect.FrameInfo, is_collapsed: bool) -> str:
        """
        Render one traceback frame (title plus source context) as HTML.

        Args:
            frame: Frame record from ``inspect.getinnerframes``.
            is_collapsed: Whether the source block starts hidden.

        NOTE: FRAME_TEMPLATE derives the element id from filename and line
        number, so the id is not unique when a frame location repeats
        (e.g. recursion).
        """
        code_context = "".join(
            self.format_line(
                index,
                line,
                frame.lineno,
                frame.index,  # type: ignore[arg-type]
            )
            # code_context is None when the source is unavailable; render no
            # lines in that case (frame.index is then unused).
            for index, line in enumerate(frame.code_context or [])
        )

        values = {
            # HTML escape - filename could contain < or >, especially if it's a virtual
            # file e.g. <stdin> in the REPL
            "frame_filename": html.escape(frame.filename),
            "frame_lineno": frame.lineno,
            # HTML escape - if you try very hard it's possible to name a function with <
            # or >
            "frame_name": html.escape(frame.function),
            "code_context": code_context,
            "collapsed": "collapsed" if is_collapsed else "",
            "collapse_button": "+" if is_collapsed else "‒",
        }
        return FRAME_TEMPLATE.format(**values)

    def generate_html(self, exc: Exception, limit: int = 7) -> str:
        """
        Build the full HTML debug page for ``exc``.

        Args:
            exc: The exception being reported.
            limit: Number of source context lines per frame. It is passed as
                the ``context`` argument of ``inspect.getinnerframes``.

        Returns:
            The rendered TEMPLATE. Frames are listed innermost first, and only
            the innermost frame starts expanded.
        """
        # Only the exception type name and message are used below, so locals
        # are deliberately not captured. capture_locals=True would repr()
        # every local of every frame: that is wasted work, and on Pythons
        # whose traceback module does not guard repr() a broken __repr__
        # would raise from inside the error handler itself.
        traceback_obj = traceback.TracebackException.from_exception(exc)

        exc_html = ""
        is_collapsed = False
        exc_traceback = exc.__traceback__
        if exc_traceback is not None:
            frames = inspect.getinnerframes(exc_traceback, limit)
            for frame in reversed(frames):
                exc_html += self.generate_frame_html(frame, is_collapsed)
                is_collapsed = True

        if sys.version_info >= (3, 13):  # pragma: no cover
            exc_type_str = traceback_obj.exc_type_str
        else:  # pragma: no cover
            exc_type_str = traceback_obj.exc_type.__name__

        # escape error class and text
        error = f"{html.escape(exc_type_str)}: {html.escape(str(traceback_obj))}"

        return TEMPLATE.format(styles=STYLES, js=JS, error=error, exc_html=exc_html)

    def generate_plain_text(self, exc: Exception) -> str:
        """Return the standard formatted traceback for ``exc`` as one string."""
        return "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))

    def debug_response(self, request: Request, exc: Exception) -> Response:
        """
        Build the debug-mode 500 response.

        The response is HTML when the request's Accept header contains
        "text/html", and a plain-text traceback otherwise.
        """
        accept = request.headers.get("accept", "")

        if "text/html" in accept:
            content = self.generate_html(exc)
            return HTMLResponse(content, status_code=500)
        content = self.generate_plain_text(exc)
        return PlainTextResponse(content, status_code=500)

    def error_response(self, request: Request, exc: Exception) -> Response:
        """Build the default non-debug 500 response; it does not reveal ``exc``."""
        return PlainTextResponse("Internal Server Error", status_code=500)