1# traceback_exception_init() adapted from trio
2#
3# _ExceptionPrintContext and traceback_exception_format() copied from the standard
4# library
5from __future__ import annotations
6
7import collections.abc
8import sys
9import textwrap
10import traceback
11from functools import singledispatch
12from types import TracebackType
13from typing import Any, List, Optional
14
15from ._exceptions import BaseExceptionGroup
16
# Maximum number of sub-exceptions rendered for a single exception group by
# PatchedTracebackException.format(); any excess is summarised in one extra
# "and N more exception(s)" entry.
max_group_width = 15
# Deepest exception-group nesting level that PatchedTracebackException.format()
# renders in full; groups nested deeper are replaced by a
# "... (max_group_depth is N)" line.
max_group_depth = 10
# Separator emitted between an exception and the one it was explicitly chained
# from (``raise ... from ...``, i.e. ``__cause__``).
_cause_message = (
    "\nThe above exception was the direct cause of the following exception:\n\n"
)

# Separator emitted between an exception and the one that was being handled
# when it was raised (implicit chaining, i.e. ``__context__``).
_context_message = (
    "\nDuring handling of the above exception, another exception occurred:\n\n"
)
26
27
def _format_final_exc_line(etype, value):
    """Build the closing ``"<type>: <message>"`` line of a traceback.

    ``etype`` is interpolated as-is; ``value`` is stringified defensively via
    ``_safe_string``.  When ``value`` is ``None`` or renders as an empty
    string, only the type is shown.  The returned line always ends in a
    newline.
    """
    message = _safe_string(value, "exception")
    suffix = "" if value is None or not message else f": {message}"
    return f"{etype}{suffix}\n"
36
37
def _safe_string(value, what, func=str):
    """Return ``func(value)``, or a placeholder if the conversion itself raises.

    ``what`` names the kind of object being converted and ``func`` is the
    converter (``str`` by default); both appear in the placeholder text
    ``"<{what} {func.__name__}() failed>"``.  Every exception, including
    non-``Exception`` ones, is swallowed on purpose so that rendering a
    traceback cannot fail because of a broken ``__str__``/``__repr__``.
    """
    try:
        rendered = func(value)
    except BaseException:
        rendered = f"<{what} {func.__name__}() failed>"
    return rendered
43
44
class _ExceptionPrintContext:
    """Mutable state shared while rendering (possibly nested) exception groups."""

    def __init__(self):
        # Initialised empty; not read by any method of this class.
        self.seen = set()
        # Current nesting level of exception groups; 0 means "not in a group".
        self.exception_group_depth = 0
        # Flag toggled by the formatter to track whether a closing frame line
        # still has to be written.
        self.need_close = False

    def indent(self):
        """Return the leading whitespace for the current depth (2 columns per level)."""
        width = 2 * self.exception_group_depth
        return " " * width

    def emit(self, text_gen, margin_char=None):
        """Yield the given text with every line prefixed for the current depth.

        ``text_gen`` is either a single string or an iterable of strings; one
        output string is produced per input string.  Inside a group (depth > 0)
        the prefix is the indentation followed by ``margin_char`` (``"|"``
        when not given) and a space; at depth 0 the prefix is empty.  All
        lines, including blank ones, receive the prefix.
        """
        prefix = self.indent()
        if self.exception_group_depth:
            margin = "|" if margin_char is None else margin_char
            prefix += margin + " "

        pieces = (text_gen,) if isinstance(text_gen, str) else text_gen
        for piece in pieces:
            yield textwrap.indent(piece, prefix, lambda _line: True)
66
67
def exceptiongroup_excepthook(
    etype: type[BaseException], value: BaseException, tb: TracebackType | None
) -> None:
    """``sys.excepthook`` replacement that writes the formatted traceback to stderr.

    The text comes from ``traceback.format_exception``, so it reflects any
    patches applied to ``traceback.TracebackException``.
    """
    rendered = traceback.format_exception(etype, value, tb)
    sys.stderr.write("".join(rendered))
72
73
class PatchedTracebackException(traceback.TracebackException):
    """``traceback.TracebackException`` variant that understands exception groups.

    The three methods defined here fully replace their base-class counterparts
    (the base ``__init__`` is intentionally never called).  They capture and
    render ``BaseExceptionGroup`` members, ``__notes__`` and "Did you mean"
    suggestions for ``NameError``/``AttributeError``.
    """

    def __init__(
        self,
        exc_type: type[BaseException],
        exc_value: BaseException,
        exc_traceback: TracebackType | None,
        *,
        limit: int | None = None,
        lookup_lines: bool = True,
        capture_locals: bool = False,
        compact: bool = False,
        _seen: set[int] | None = None,
    ) -> None:
        """Snapshot an exception, its chain and any nested group members.

        :param exc_type: class of the exception (may be falsy, in which case
            the SyntaxError/suggestion handling is skipped)
        :param exc_value: the exception instance
        :param exc_traceback: traceback to extract the stack summary from
        :param limit: forwarded to ``StackSummary.extract`` for this exception
            and for every chained or grouped exception
        :param lookup_lines: if true, source lines are loaded eagerly
        :param capture_locals: forwarded to ``StackSummary.extract``
        :param compact: if true, ``__context__`` is only captured when it
            would actually be displayed (no ``__cause__`` and context not
            suppressed)
        :param _seen: internal; ids of exceptions already captured.  Passing
            a set marks the call as a nested one, which then does not walk
            the chain itself.
        """
        is_recursive_call = _seen is not None
        if _seen is None:
            _seen = set()
        _seen.add(id(exc_value))

        self.stack = traceback.StackSummary.extract(
            traceback.walk_tb(exc_traceback),
            limit=limit,
            lookup_lines=lookup_lines,
            capture_locals=capture_locals,
        )
        self.exc_type = exc_type
        # Capture now to permit freeing resources: only complication is in the
        # unofficial API _format_final_exc_line
        self._str = _safe_string(exc_value, "exception")
        try:
            self.__notes__ = getattr(exc_value, "__notes__", None)
        except KeyError:
            # Workaround for https://github.com/python/cpython/issues/98778 on Python
            # <= 3.9, and some 3.10 and 3.11 patch versions.
            HTTPError = getattr(sys.modules.get("urllib.error", None), "HTTPError", ())
            if sys.version_info[:2] <= (3, 11) and isinstance(exc_value, HTTPError):
                self.__notes__ = None
            else:
                raise

        if exc_type and issubclass(exc_type, SyntaxError):
            # Handle SyntaxError's specially
            self.filename = exc_value.filename
            lno = exc_value.lineno
            self.lineno = str(lno) if lno is not None else None
            self.text = exc_value.text
            self.offset = exc_value.offset
            self.msg = exc_value.msg
            if sys.version_info >= (3, 10):
                end_lno = exc_value.end_lineno
                self.end_lineno = str(end_lno) if end_lno is not None else None
                self.end_offset = exc_value.end_offset
        elif (
            exc_type
            and issubclass(exc_type, (NameError, AttributeError))
            and getattr(exc_value, "name", None) is not None
        ):
            suggestion = _compute_suggestion_error(exc_value, exc_traceback)
            if suggestion:
                self._str += f". Did you mean: '{suggestion}'?"

        if lookup_lines:
            # Force all lines in the stack to be loaded
            for frame in self.stack:
                frame.line

        self.__suppress_context__ = (
            exc_value.__suppress_context__ if exc_value is not None else False
        )

        # Convert __cause__ and __context__ to `TracebackException`s, use a
        # queue to avoid recursion (only the top-level call gets _seen == None).
        # Nested instances get their __cause__/__context__/exceptions attributes
        # assigned here, by the top-level call, when they are popped off the
        # queue.
        if not is_recursive_call:
            queue = [(self, exc_value)]
            while queue:
                te, e = queue.pop()

                if e and e.__cause__ is not None and id(e.__cause__) not in _seen:
                    cause = PatchedTracebackException(
                        type(e.__cause__),
                        e.__cause__,
                        e.__cause__.__traceback__,
                        limit=limit,
                        lookup_lines=lookup_lines,
                        capture_locals=capture_locals,
                        _seen=_seen,
                    )
                else:
                    cause = None

                if compact:
                    need_context = (
                        cause is None and e is not None and not e.__suppress_context__
                    )
                else:
                    need_context = True
                if (
                    e
                    and e.__context__ is not None
                    and need_context
                    and id(e.__context__) not in _seen
                ):
                    context = PatchedTracebackException(
                        type(e.__context__),
                        e.__context__,
                        e.__context__.__traceback__,
                        limit=limit,
                        lookup_lines=lookup_lines,
                        capture_locals=capture_locals,
                        _seen=_seen,
                    )
                else:
                    context = None

                # Capture each of the exceptions in the ExceptionGroup along with each
                # of their causes and contexts
                if e and isinstance(e, BaseExceptionGroup):
                    exceptions = []
                    for exc in e.exceptions:
                        texc = PatchedTracebackException(
                            type(exc),
                            exc,
                            exc.__traceback__,
                            # Honour the caller's frame limit for group members
                            # too, exactly as for causes and contexts above.
                            limit=limit,
                            lookup_lines=lookup_lines,
                            capture_locals=capture_locals,
                            _seen=_seen,
                        )
                        exceptions.append(texc)
                else:
                    exceptions = None

                te.__cause__ = cause
                te.__context__ = context
                te.exceptions = exceptions
                if cause:
                    queue.append((te.__cause__, e.__cause__))
                if context:
                    queue.append((te.__context__, e.__context__))
                if exceptions:
                    queue.extend(zip(te.exceptions, e.exceptions))

    def format(self, *, chain=True, _ctx=None, **kwargs):
        """Yield the formatted traceback as a series of strings.

        :param chain: if true, chained exceptions (``__cause__`` or a
            non-suppressed ``__context__``) are rendered first, oldest first
        :param _ctx: internal; the ``_ExceptionPrintContext`` shared between
            nested calls while rendering exception groups
        :param kwargs: accepted and ignored, so that callers passing newer
            standard-library options (e.g. ``colorize``) do not fail
        """
        if _ctx is None:
            _ctx = _ExceptionPrintContext()

        # Collect the chain newest-first as (separator message, exception)
        # pairs; it is emitted in reverse below.
        output = []
        exc = self
        if chain:
            while exc:
                if exc.__cause__ is not None:
                    chained_msg = _cause_message
                    chained_exc = exc.__cause__
                elif exc.__context__ is not None and not exc.__suppress_context__:
                    chained_msg = _context_message
                    chained_exc = exc.__context__
                else:
                    chained_msg = None
                    chained_exc = None

                output.append((chained_msg, exc))
                exc = chained_exc
        else:
            output.append((None, exc))

        for msg, exc in reversed(output):
            if msg is not None:
                yield from _ctx.emit(msg)
            if exc.exceptions is None:
                # plain (non-group) exception
                if exc.stack:
                    yield from _ctx.emit("Traceback (most recent call last):\n")
                    yield from _ctx.emit(exc.stack.format())
                yield from _ctx.emit(exc.format_exception_only())
            elif _ctx.exception_group_depth > max_group_depth:
                # exception group, but depth exceeds limit
                yield from _ctx.emit(f"... (max_group_depth is {max_group_depth})\n")
            else:
                # format exception group
                is_toplevel = _ctx.exception_group_depth == 0
                if is_toplevel:
                    _ctx.exception_group_depth += 1

                if exc.stack:
                    yield from _ctx.emit(
                        "Exception Group Traceback (most recent call last):\n",
                        margin_char="+" if is_toplevel else None,
                    )
                    yield from _ctx.emit(exc.stack.format())

                yield from _ctx.emit(exc.format_exception_only())
                num_excs = len(exc.exceptions)
                if num_excs <= max_group_width:
                    n = num_excs
                else:
                    # one extra slot for the "and N more exceptions" entry
                    n = max_group_width + 1
                _ctx.need_close = False
                for i in range(n):
                    last_exc = i == n - 1
                    if last_exc:
                        # The closing frame may be added by a recursive call
                        _ctx.need_close = True

                    if max_group_width is not None:
                        truncated = i >= max_group_width
                    else:
                        truncated = False
                    title = f"{i + 1}" if not truncated else "..."
                    yield (
                        _ctx.indent()
                        # Two columns wide either way so that every separator
                        # lines up under the first one's "+-" corner.
                        + ("+-" if i == 0 else "  ")
                        + f"+---------------- {title} ----------------\n"
                    )
                    _ctx.exception_group_depth += 1
                    if not truncated:
                        yield from exc.exceptions[i].format(chain=chain, _ctx=_ctx)
                    else:
                        remaining = num_excs - max_group_width
                        plural = "s" if remaining > 1 else ""
                        yield from _ctx.emit(
                            f"and {remaining} more exception{plural}\n"
                        )

                    if last_exc and _ctx.need_close:
                        yield _ctx.indent() + "+------------------------------------\n"
                        _ctx.need_close = False
                    _ctx.exception_group_depth -= 1

                if is_toplevel:
                    assert _ctx.exception_group_depth == 1
                    _ctx.exception_group_depth = 0

    def format_exception_only(self, **kwargs):
        """Format the exception part of the traceback.

        The return value is a generator of strings, each ending in a newline.
        Normally, the generator emits a single string; however, for
        SyntaxError exceptions, it emits several lines that (when
        printed) display detailed information about where the syntax
        error occurred.  Any ``__notes__`` captured from the exception are
        emitted after the exception message.

        Extra keyword arguments are accepted and ignored, so that callers
        passing newer standard-library options (e.g. ``show_group`` or
        ``colorize``) do not fail.
        """
        if self.exc_type is None:
            yield _format_final_exc_line(None, self._str)
            return

        stype = self.exc_type.__qualname__
        smod = self.exc_type.__module__
        if smod not in ("__main__", "builtins"):
            if not isinstance(smod, str):
                smod = "<unknown>"
            stype = smod + "." + stype

        if not issubclass(self.exc_type, SyntaxError):
            yield _format_final_exc_line(stype, self._str)
        elif traceback_exception_format_syntax_error is not None:
            yield from traceback_exception_format_syntax_error(self, stype)
        else:
            yield from traceback_exception_original_format_exception_only(self)

        notes = self.__notes__
        # A str/bytes value is itself a Sequence; iterating it would print one
        # character per line, so it is rendered via repr() like any other
        # malformed __notes__ value.
        if isinstance(notes, collections.abc.Sequence) and not isinstance(
            notes, (str, bytes)
        ):
            for note in notes:
                note = _safe_string(note, "note")
                yield from [line + "\n" for line in note.split("\n")]
        elif notes is not None:
            # Terminate with a newline like every other emitted string.
            yield _safe_string(notes, "__notes__", func=repr) + "\n"
341
342
# Keep references to the standard library's own implementations before any
# patching happens below.  PatchedTracebackException.format_exception_only()
# falls back to the original format_exception_only for SyntaxError when the
# private ``_format_syntax_error`` helper is not available (it is looked up
# with getattr and is None in that case).
traceback_exception_original_format = traceback.TracebackException.format
traceback_exception_original_format_exception_only = (
    traceback.TracebackException.format_exception_only
)
traceback_exception_format_syntax_error = getattr(
    traceback.TracebackException, "_format_syntax_error", None
)
# Only take over when nobody else has installed a custom excepthook: replace
# the three TracebackException methods with the patched versions (process-wide,
# at import time) and install our own hook for uncaught exceptions.
if sys.excepthook is sys.__excepthook__:
    traceback.TracebackException.__init__ = (  # type: ignore[assignment]
        PatchedTracebackException.__init__
    )
    traceback.TracebackException.format = (  # type: ignore[assignment]
        PatchedTracebackException.format
    )
    traceback.TracebackException.format_exception_only = (  # type: ignore[assignment]
        PatchedTracebackException.format_exception_only
    )
    sys.excepthook = exceptiongroup_excepthook
361
362# Ubuntu's system Python has a sitecustomize.py file that imports
363# apport_python_hook and replaces sys.excepthook.
364#
365# The custom hook captures the error for crash reporting, and then calls
366# sys.__excepthook__ to actually print the error.
367#
368# We don't mind it capturing the error for crash reporting, but we want to
369# take over printing the error. So we monkeypatch the apport_python_hook
370# module so that instead of calling sys.__excepthook__, it calls our custom
371# hook.
372#
373# More details: https://github.com/python-trio/trio/issues/1065
if getattr(sys.excepthook, "__name__", None) in (
    "apport_excepthook",
    # on ubuntu 22.10 the hook was renamed to partial_apport_excepthook
    "partial_apport_excepthook",
):
    # patch traceback like above
    traceback.TracebackException.__init__ = (  # type: ignore[assignment]
        PatchedTracebackException.__init__
    )
    traceback.TracebackException.format = (  # type: ignore[assignment]
        PatchedTracebackException.format
    )
    traceback.TracebackException.format_exception_only = (  # type: ignore[assignment]
        PatchedTracebackException.format_exception_only
    )

    from types import ModuleType

    import apport_python_hook

    # NOTE: the installed hook is deliberately not asserted to be
    # ``apport_python_hook.apport_excepthook``: when the hook is the renamed
    # "partial_apport_excepthook" accepted above, it is a different object and
    # such an assertion would make importing this module fail.

    # monkeypatch the sys module that apport has imported: hand it a copy of
    # ``sys`` whose ``__excepthook__`` is our hook, so that apport's call to
    # ``sys.__excepthook__`` ends up printing through exceptiongroup_excepthook
    fake_sys = ModuleType("exceptiongroup_fake_sys")
    fake_sys.__dict__.update(sys.__dict__)
    fake_sys.__excepthook__ = exceptiongroup_excepthook
    apport_python_hook.sys = fake_sys
401
402
@singledispatch
def format_exception_only(__exc: BaseException) -> List[str]:
    """Return the "exception only" lines (no stack) for ``__exc`` as a list.

    Also callable in the legacy two-argument form ``(exc_type, exc_value)``,
    which is handled by the overload registered below.
    """
    te = PatchedTracebackException(type(__exc), __exc, None, compact=True)
    return [*te.format_exception_only()]


@format_exception_only.register
def _(__exc: type, value: BaseException) -> List[str]:
    # Legacy call style: the first argument is the exception class and is
    # ignored; the work is delegated to the single-argument form.
    return format_exception_only(value)
415
416
@singledispatch
def format_exception(
    __exc: BaseException,
    limit: Optional[int] = None,
    chain: bool = True,
) -> List[str]:
    """Return the full formatted traceback of ``__exc`` as a list of strings.

    ``limit`` is passed to the traceback capture and ``chain`` controls
    whether chained exceptions are included.  Also callable in the legacy
    form ``(exc_type, exc_value, tb, ...)`` via the overload registered below.
    """
    te = PatchedTracebackException(
        type(__exc), __exc, __exc.__traceback__, limit=limit, compact=True
    )
    return [*te.format(chain=chain)]


@format_exception.register
def _(
    __exc: type,
    value: BaseException,
    tb: TracebackType,
    limit: Optional[int] = None,
    chain: bool = True,
) -> List[str]:
    # Legacy call style: the class and ``tb`` arguments are ignored; the
    # traceback is taken from ``value.__traceback__`` by the main function.
    return format_exception(value, limit, chain)
439
440
@singledispatch
def print_exception(
    __exc: BaseException,
    limit: Optional[int] = None,
    file: Any = None,
    chain: bool = True,
) -> None:
    """Print the formatted traceback of ``__exc`` to ``file``.

    ``file`` defaults to ``sys.stderr`` (looked up at call time).  Also
    callable in the legacy form ``(exc_type, exc_value, tb, ...)`` via the
    overload registered below.
    """
    target = sys.stderr if file is None else file
    te = PatchedTracebackException(
        type(__exc), __exc, __exc.__traceback__, limit=limit
    )
    for chunk in te.format(chain=chain):
        print(chunk, file=target, end="")


@print_exception.register
def _(
    __exc: type,
    value: BaseException,
    tb: TracebackType,
    limit: Optional[int] = None,
    file: Any = None,
    chain: bool = True,
) -> None:
    # Legacy call style: the class and ``tb`` arguments are ignored; the
    # traceback is taken from ``value.__traceback__`` by the main function.
    print_exception(value, limit, file, chain)
467
468
def print_exc(
    limit: Optional[int] = None,
    file: Any | None = None,
    chain: bool = True,
) -> None:
    """Print the exception currently being handled, like ``traceback.print_exc``.

    :param limit: maximum number of stack frames to show (``None`` = all)
    :param file: destination stream; defaults to ``sys.stderr``
    :param chain: whether to include chained exceptions

    When no exception is being handled, ``NoneType: None`` is printed (the
    same text the standard library produces) instead of failing on the
    missing exception object.
    """
    value = sys.exc_info()[1]
    if value is None:
        print("NoneType: None", file=sys.stderr if file is None else file)
        return
    print_exception(value, limit, file, chain)
476
477
478# Python levenshtein edit distance code for NameError/AttributeError
479# suggestions, backported from 3.12
480
# _compute_suggestion_error() gives up when there are more candidate names
# than this.
_MAX_CANDIDATE_ITEMS = 750
# Names longer than this are never compared / never get a suggestion.
_MAX_STRING_SIZE = 40
# Cost of inserting, deleting or substituting one character.
_MOVE_COST = 2
# Cost of substituting a character by the same letter in a different case.
_CASE_COST = 1
# Unique marker used as a getattr() default to detect a missing attribute.
_SENTINEL = object()
486
487
def _substitution_cost(ch_a, ch_b):
    """Return the cost of replacing ``ch_a`` with ``ch_b``.

    Identical characters cost nothing, characters differing only by case
    cost ``_CASE_COST`` and anything else costs ``_MOVE_COST``.
    """
    if ch_a != ch_b:
        same_letter = ch_a.lower() == ch_b.lower()
        return _CASE_COST if same_letter else _MOVE_COST
    return 0
494
495
def _compute_suggestion_error(exc_value, tb):
    """Return the most similar known name for a NameError/AttributeError.

    :param exc_value: the exception; its ``name`` attribute is the name that
        could not be resolved and, for ``AttributeError``, its ``obj``
        attribute is the object the lookup was performed on
    :param tb: traceback of the exception; for ``NameError`` the candidates
        are the locals, globals and builtins of its innermost frame
    :return: the suggested name, or ``None`` when no suggestion can be made
        (no usable name, missing ``obj``/traceback, too many candidates, the
        name is too long, or nothing is close enough)
    """
    wrong_name = getattr(exc_value, "name", None)
    if wrong_name is None or not isinstance(wrong_name, str):
        return None
    if isinstance(exc_value, AttributeError):
        # Fetch ``obj`` exactly once; the sentinel tells "attribute missing"
        # apart from a legitimate ``None`` value.
        obj = getattr(exc_value, "obj", _SENTINEL)
        if obj is _SENTINEL:
            return None
        try:
            d = dir(obj)
        except Exception:
            return None
    else:
        assert isinstance(exc_value, NameError)
        # find most recent frame
        if tb is None:
            return None
        while tb.tb_next is not None:
            tb = tb.tb_next
        frame = tb.tb_frame

        d = list(frame.f_locals) + list(frame.f_globals) + list(frame.f_builtins)
    if len(d) > _MAX_CANDIDATE_ITEMS:
        return None
    wrong_name_len = len(wrong_name)
    if wrong_name_len > _MAX_STRING_SIZE:
        return None
    best_distance = wrong_name_len
    suggestion = None
    for possible_name in d:
        if possible_name == wrong_name:
            # A missing attribute is "found". Don't suggest it (see GH-88821).
            continue
        # No more than 1/3 of the involved characters should need changed.
        max_distance = (len(possible_name) + wrong_name_len + 3) * _MOVE_COST // 6
        # Don't take matches we've already beaten.
        max_distance = min(max_distance, best_distance - 1)
        current_distance = _levenshtein_distance(
            wrong_name, possible_name, max_distance
        )
        if current_distance > max_distance:
            continue
        if not suggestion or current_distance < best_distance:
            suggestion = possible_name
            best_distance = current_distance
    return suggestion
543
544
def _levenshtein_distance(a, b, max_cost):
    """Return the weighted edit distance between ``a`` and ``b``.

    A Python implementation of Python/suggestions.c:levenshtein_distance.
    Insertions and deletions cost ``_MOVE_COST``; substitutions are priced by
    ``_substitution_cost``.  ``max_cost + 1`` is returned as soon as the
    distance is known to exceed ``max_cost`` (or a trimmed string is longer
    than ``_MAX_STRING_SIZE``).
    """
    if a == b:
        return 0

    # Drop the shared prefix ...
    shared = 0
    bound = min(len(a), len(b))
    while shared < bound and a[shared] == b[shared]:
        shared += 1
    a, b = a[shared:], b[shared:]

    # ... and the shared suffix; neither contributes to the distance.
    shared = 0
    bound = min(len(a), len(b))
    while shared < bound and a[-1 - shared] == b[-1 - shared]:
        shared += 1
    if shared:
        a, b = a[:-shared], b[:-shared]

    if not a or not b:
        # One side is exhausted: everything left is pure insertion/deletion.
        return _MOVE_COST * (len(a) + len(b))
    if max(len(a), len(b)) > _MAX_STRING_SIZE:
        return max_cost + 1

    # Keep the shorter string as the buffer dimension.
    short, long_ = (b, a) if len(b) < len(a) else (a, b)

    # The length difference alone is a lower bound on the cost.
    if (len(long_) - len(short)) * _MOVE_COST > max_cost:
        return max_cost + 1

    # Single-row dynamic programming: ``row[k]`` holds the cost of turning the
    # processed part of ``long_`` into ``short[:k + 1]``.
    row = [_MOVE_COST * (k + 1) for k in range(len(short))]

    result = 0
    for outer, long_char in enumerate(long_):
        # Left border of the current row; also the diagonal for column 0.
        diagonal = result = outer * _MOVE_COST
        row_minimum = sys.maxsize
        for inner, short_char in enumerate(short):
            substitute = diagonal + _substitution_cost(long_char, short_char)
            # The previous row's value becomes the next column's diagonal.
            diagonal = row[inner]
            insert_delete = min(result, diagonal) + _MOVE_COST
            result = min(insert_delete, substitute)
            row[inner] = result
            row_minimum = min(row_minimum, result)
        if row_minimum > max_cost:
            # Everything in this row is too big, so bail early.
            return max_cost + 1
    return result