# PatchedTracebackException.__init__() adapted from trio
#
# _ExceptionPrintContext and PatchedTracebackException.format() copied from the
# standard library
5from __future__ import annotations
6
7import collections.abc
8import sys
9import textwrap
10import traceback
11from functools import singledispatch
12from types import TracebackType
13from typing import Any, List, Optional
14
15from ._exceptions import BaseExceptionGroup
16
# Module-wide limits used when rendering exception groups.
# Number of members printed per group; any further members are summarised as
# "and N more exceptions".
max_group_width: int = 15
# Groups nested deeper than this are replaced by "... (max_group_depth is N)".
max_group_depth: int = 10

# Separators printed between chained exceptions: an explicit cause
# (``raise ... from ...``) versus an implicit ``__context__``.
_cause_message: str = (
    "\nThe above exception was the direct cause "
    "of the following exception:\n\n"
)
_context_message: str = (
    "\nDuring handling of the above exception, "
    "another exception occurred:\n\n"
)
26
27
def _format_final_exc_line(etype, value):
    """Return the final, newline-terminated ``etype: value`` line of a traceback.

    The ``: value`` part is left out when *value* is ``None`` or its string
    form is empty.
    """
    rendered = _safe_string(value, "exception")
    if value is not None and rendered:
        return f"{etype}: {rendered}\n"
    return f"{etype}\n"
36
37
def _safe_string(value, what, func=str):
    """Return ``func(value)``, never raising.

    If the conversion raises anything (including ``BaseException``
    subclasses), a placeholder such as ``"<exception str() failed>"`` is
    returned instead, built from *what* and the converter's ``__name__``.
    """
    try:
        text = func(value)
    except BaseException:
        text = f"<{what} {func.__name__}() failed>"
    return text
43
44
class _ExceptionPrintContext:
    """Mutable state shared by the recursive calls that format a traceback.

    It records how deeply the output is nested inside exception groups, so
    that every emitted line gets the matching indentation and margin.
    """

    def __init__(self):
        self.seen = set()  # not read by the formatting code in this module
        self.exception_group_depth = 0  # current exception-group nesting level
        self.need_close = False  # whether a closing "+----" rule is still owed

    def indent(self):
        """Return the leading whitespace for the current nesting depth."""
        depth = self.exception_group_depth
        return " " * (depth * 2)

    def emit(self, text_gen, margin_char=None):
        """Yield *text_gen* with each line indented for the current depth.

        *text_gen* is a single string or an iterable of strings. Inside a
        group (depth > 0) every line is additionally prefixed with
        *margin_char* (``"|"`` by default) and a space.
        """
        prefix = self.indent()
        if self.exception_group_depth:
            prefix += ("|" if margin_char is None else margin_char) + " "

        chunks = (text_gen,) if isinstance(text_gen, str) else text_gen
        for chunk in chunks:
            # The always-true predicate makes textwrap prefix blank lines too.
            yield textwrap.indent(chunk, prefix, lambda _line: True)
66
67
def exceptiongroup_excepthook(
    etype: type[BaseException], value: BaseException, tb: TracebackType | None
) -> None:
    """``sys.excepthook`` replacement: write the formatted traceback to stderr.

    Formatting goes through ``traceback.format_exception``, so it picks up
    whatever ``TracebackException`` implementation is active at call time.
    """
    rendered_lines = traceback.format_exception(etype, value, tb)
    sys.stderr.write("".join(rendered_lines))
72
73
class PatchedTracebackException(traceback.TracebackException):
    """``TracebackException`` that also renders exception groups and notes.

    Compared to the base class it captures the members of
    ``BaseExceptionGroup`` instances (printed as indented, numbered
    sub-tracebacks), ``__notes__``, and "Did you mean" hints for
    ``NameError``/``AttributeError``.

    Its ``__init__``, ``format`` and ``format_exception_only`` are also
    assigned onto ``traceback.TracebackException`` by the module-level
    patching code, so ``self`` may be a plain ``TracebackException``.
    """

    def __init__(
        self,
        exc_type: type[BaseException],
        exc_value: BaseException,
        exc_traceback: TracebackType | None,
        *,
        limit: int | None = None,
        lookup_lines: bool = True,
        capture_locals: bool = False,
        compact: bool = False,
        _seen: set[int] | None = None,
    ) -> None:
        """Snapshot *exc_value* and every exception chained to it.

        :param exc_type: the exception class
        :param exc_value: the exception instance (``None`` is tolerated)
        :param exc_traceback: traceback whose frames are extracted
        :param limit: forwarded to ``StackSummary.extract`` for this exception,
            its causes/contexts and its exception-group members alike
        :param lookup_lines: forwarded to ``StackSummary.extract``; when true,
            source lines of every frame are loaded eagerly
        :param capture_locals: forwarded to ``StackSummary.extract``
        :param compact: only capture ``__context__`` when it would be printed,
            i.e. there is no ``__cause__`` and the context is not suppressed
        :param _seen: internal; ids of exceptions already captured. Only
            recursive calls pass it, and it keeps reference cycles finite.
        """
        is_recursive_call = _seen is not None
        if _seen is None:
            _seen = set()
        _seen.add(id(exc_value))

        self.stack = traceback.StackSummary.extract(
            traceback.walk_tb(exc_traceback),
            limit=limit,
            lookup_lines=lookup_lines,
            capture_locals=capture_locals,
        )
        self.exc_type = exc_type
        # Capture now to permit freeing resources: only complication is in the
        # unofficial API _format_final_exc_line
        self._str = _safe_string(exc_value, "exception")
        try:
            self.__notes__ = getattr(exc_value, "__notes__", None)
        except KeyError:
            # Workaround for https://github.com/python/cpython/issues/98778 on Python
            # <= 3.9, and some 3.10 and 3.11 patch versions.
            HTTPError = getattr(sys.modules.get("urllib.error", None), "HTTPError", ())
            if sys.version_info[:2] <= (3, 11) and isinstance(exc_value, HTTPError):
                self.__notes__ = None
            else:
                raise

        if exc_type and issubclass(exc_type, SyntaxError):
            # Handle SyntaxError's specially
            self.filename = exc_value.filename
            lno = exc_value.lineno
            self.lineno = str(lno) if lno is not None else None
            self.text = exc_value.text
            self.offset = exc_value.offset
            self.msg = exc_value.msg
            if sys.version_info >= (3, 10):
                end_lno = exc_value.end_lineno
                self.end_lineno = str(end_lno) if end_lno is not None else None
                self.end_offset = exc_value.end_offset
        elif (
            exc_type
            and issubclass(exc_type, (NameError, AttributeError))
            and getattr(exc_value, "name", None) is not None
        ):
            suggestion = _compute_suggestion_error(exc_value, exc_traceback)
            if suggestion:
                self._str += f". Did you mean: '{suggestion}'?"

        if lookup_lines:
            # Force all lines in the stack to be loaded
            for frame in self.stack:
                frame.line

        self.__suppress_context__ = (
            exc_value.__suppress_context__ if exc_value is not None else False
        )

        # Convert __cause__, __context__ and group members into
        # PatchedTracebackExceptions. Only the outermost call (the one that
        # created ``_seen``) does this. It drains an explicit work list instead
        # of recursing, and the nested objects it creates skip this step.
        if not is_recursive_call:
            pending = [(self, exc_value)]
            while pending:
                te, e = pending.pop()

                if e and e.__cause__ is not None and id(e.__cause__) not in _seen:
                    cause = PatchedTracebackException(
                        type(e.__cause__),
                        e.__cause__,
                        e.__cause__.__traceback__,
                        limit=limit,
                        lookup_lines=lookup_lines,
                        capture_locals=capture_locals,
                        _seen=_seen,
                    )
                else:
                    cause = None

                if compact:
                    need_context = (
                        cause is None and e is not None and not e.__suppress_context__
                    )
                else:
                    need_context = True
                if (
                    e
                    and e.__context__ is not None
                    and need_context
                    and id(e.__context__) not in _seen
                ):
                    context = PatchedTracebackException(
                        type(e.__context__),
                        e.__context__,
                        e.__context__.__traceback__,
                        limit=limit,
                        lookup_lines=lookup_lines,
                        capture_locals=capture_locals,
                        _seen=_seen,
                    )
                else:
                    context = None

                # Capture each of the exceptions in the ExceptionGroup along with
                # each of their causes and contexts. ``limit`` applies to members
                # too, exactly as for causes and contexts.
                if e and isinstance(e, BaseExceptionGroup):
                    exceptions = []
                    for exc in e.exceptions:
                        texc = PatchedTracebackException(
                            type(exc),
                            exc,
                            exc.__traceback__,
                            limit=limit,
                            lookup_lines=lookup_lines,
                            capture_locals=capture_locals,
                            _seen=_seen,
                        )
                        exceptions.append(texc)
                else:
                    exceptions = None

                te.__cause__ = cause
                te.__context__ = context
                te.exceptions = exceptions
                if cause:
                    pending.append((te.__cause__, e.__cause__))
                if context:
                    pending.append((te.__context__, e.__context__))
                if exceptions:
                    pending.extend(zip(te.exceptions, e.exceptions))

    def format(self, *, chain=True, _ctx=None):
        """Yield the formatted exception as newline-terminated text chunks.

        :param chain: also render the ``__cause__``/``__context__`` chain,
            oldest exception first
        :param _ctx: internal; the print context shared with the recursive
            calls that render exception-group members
        """
        if _ctx is None:
            _ctx = _ExceptionPrintContext()

        output = []
        exc = self
        if chain:
            while exc:
                if exc.__cause__ is not None:
                    chained_msg = _cause_message
                    chained_exc = exc.__cause__
                elif exc.__context__ is not None and not exc.__suppress_context__:
                    chained_msg = _context_message
                    chained_exc = exc.__context__
                else:
                    chained_msg = None
                    chained_exc = None

                output.append((chained_msg, exc))
                exc = chained_exc
        else:
            output.append((None, exc))

        for msg, exc in reversed(output):
            if msg is not None:
                yield from _ctx.emit(msg)
            if exc.exceptions is None:
                if exc.stack:
                    yield from _ctx.emit("Traceback (most recent call last):\n")
                    yield from _ctx.emit(exc.stack.format())
                yield from _ctx.emit(exc.format_exception_only())
            elif _ctx.exception_group_depth > max_group_depth:
                # exception group, but depth exceeds limit
                yield from _ctx.emit(f"... (max_group_depth is {max_group_depth})\n")
            else:
                # format exception group
                is_toplevel = _ctx.exception_group_depth == 0
                if is_toplevel:
                    _ctx.exception_group_depth += 1

                if exc.stack:
                    yield from _ctx.emit(
                        "Exception Group Traceback (most recent call last):\n",
                        margin_char="+" if is_toplevel else None,
                    )
                    yield from _ctx.emit(exc.stack.format())

                yield from _ctx.emit(exc.format_exception_only())
                num_excs = len(exc.exceptions)
                if num_excs <= max_group_width:
                    n = num_excs
                else:
                    n = max_group_width + 1
                _ctx.need_close = False
                for i in range(n):
                    last_exc = i == n - 1
                    if last_exc:
                        # The closing frame may be added by a recursive call
                        _ctx.need_close = True

                    if max_group_width is not None:
                        truncated = i >= max_group_width
                    else:
                        truncated = False
                    title = f"{i + 1}" if not truncated else "..."
                    # Both prefixes are two characters wide so that the
                    # separators of all members line up.
                    yield (
                        _ctx.indent()
                        + ("+-" if i == 0 else "  ")
                        + f"+---------------- {title} ----------------\n"
                    )
                    _ctx.exception_group_depth += 1
                    if not truncated:
                        yield from exc.exceptions[i].format(chain=chain, _ctx=_ctx)
                    else:
                        remaining = num_excs - max_group_width
                        plural = "s" if remaining > 1 else ""
                        yield from _ctx.emit(
                            f"and {remaining} more exception{plural}\n"
                        )

                    if last_exc and _ctx.need_close:
                        yield _ctx.indent() + "+------------------------------------\n"
                        _ctx.need_close = False
                    _ctx.exception_group_depth -= 1

                if is_toplevel:
                    assert _ctx.exception_group_depth == 1
                    _ctx.exception_group_depth = 0

    def format_exception_only(self):
        """Format the exception part of the traceback.
        The return value is a generator of strings, each ending in a newline.
        Normally, the generator emits a single string; however, for
        SyntaxError exceptions, it emits several lines that (when
        printed) display detailed information about where the syntax
        error occurred. Any ``__notes__`` follow, one line per note line.
        The message indicating which exception occurred is the last string
        before the notes.
        """
        if self.exc_type is None:
            yield _format_final_exc_line(None, self._str)
            return

        stype = self.exc_type.__qualname__
        smod = self.exc_type.__module__
        if smod not in ("__main__", "builtins"):
            if not isinstance(smod, str):
                smod = "<unknown>"
            stype = smod + "." + stype

        if not issubclass(self.exc_type, SyntaxError):
            yield _format_final_exc_line(stype, self._str)
        elif traceback_exception_format_syntax_error is not None:
            yield from traceback_exception_format_syntax_error(self, stype)
        else:
            yield from traceback_exception_original_format_exception_only(self)

        notes = self.__notes__
        # A str/bytes __notes__ is a Sequence too, but iterating it would print
        # one character per line; treat it like any other non-list value.
        if isinstance(notes, collections.abc.Sequence) and not isinstance(
            notes, (str, bytes)
        ):
            for note in notes:
                note = _safe_string(note, "note")
                yield from [line + "\n" for line in note.split("\n")]
        elif notes is not None:
            yield _safe_string(notes, "__notes__", func=repr) + "\n"
341
342
# Keep references to the stock implementations before anything is patched.
traceback_exception_original_format = traceback.TracebackException.format
traceback_exception_original_format_exception_only = (
    traceback.TracebackException.format_exception_only
)
# ``_format_syntax_error`` is looked up with a None default because it may be
# missing on some Python versions; callers check for None before using it.
traceback_exception_format_syntax_error = getattr(
    traceback.TracebackException, "_format_syntax_error", None
)

# Only take over traceback rendering when no custom excepthook is installed.
if sys.excepthook is sys.__excepthook__:
    for _patched_name in ("__init__", "format", "format_exception_only"):
        setattr(
            traceback.TracebackException,
            _patched_name,
            getattr(PatchedTracebackException, _patched_name),
        )
    del _patched_name
    sys.excepthook = exceptiongroup_excepthook
361
# Ubuntu's system Python has a sitecustomize.py file that imports
# apport_python_hook and replaces sys.excepthook.
#
# The custom hook captures the error for crash reporting, and then calls
# sys.__excepthook__ to actually print the error.
#
# We don't mind it capturing the error for crash reporting, but we want to
# take over printing the error. So we monkeypatch the apport_python_hook
# module so that instead of calling sys.__excepthook__, it calls our custom
# hook.
#
# More details: https://github.com/python-trio/trio/issues/1065
#
# This branch is only reached when apport's hook is the active excepthook. In
# that case sys.excepthook is not sys.__excepthook__, so the default-hook
# patching above did not run, and the traceback patches are applied here.
if getattr(sys.excepthook, "__name__", None) in (
    "apport_excepthook",
    # on ubuntu 22.10 the hook was renamed to partial_apport_excepthook
    "partial_apport_excepthook",
):
    # patch traceback like above
    traceback.TracebackException.__init__ = (  # type: ignore[assignment]
        PatchedTracebackException.__init__
    )
    traceback.TracebackException.format = (  # type: ignore[assignment]
        PatchedTracebackException.format
    )
    traceback.TracebackException.format_exception_only = (  # type: ignore[assignment]
        PatchedTracebackException.format_exception_only
    )

    from types import ModuleType

    import apport_python_hook

    # monkeypatch the sys module that apport has imported
    # Every attribute of the real ``sys`` is copied, and only ``__excepthook__``
    # is overridden. apport's other uses of ``sys`` keep working, while its final
    # call to sys.__excepthook__ lands in exceptiongroup_excepthook.
    # NOTE(review): the copy is a snapshot taken at import time; attributes
    # rebound on the real sys later (e.g. sys.stderr) are not seen by apport.
    fake_sys = ModuleType("exceptiongroup_fake_sys")
    fake_sys.__dict__.update(sys.__dict__)
    fake_sys.__excepthook__ = exceptiongroup_excepthook
    apport_python_hook.sys = fake_sys
399
400
@singledispatch
def format_exception_only(__exc: BaseException) -> List[str]:
    """Return the final line(s) of a traceback for *__exc*.

    Counterpart of ``traceback.format_exception_only`` that includes any
    ``__notes__`` of the exception.
    """
    summary = PatchedTracebackException(type(__exc), __exc, None, compact=True)
    return list(summary.format_exception_only())


@format_exception_only.register(type)
def _(__exc: type, value: BaseException) -> List[str]:
    # Legacy ``(etype, value)`` calling convention; the type argument is unused.
    return format_exception_only(value)
413
414
@singledispatch
def format_exception(
    __exc: BaseException,
    limit: Optional[int] = None,
    chain: bool = True,
) -> List[str]:
    """Return the full traceback of *__exc* as a list of strings.

    Counterpart of ``traceback.format_exception`` that renders exception
    groups and ``__notes__``. The traceback is taken from
    ``__exc.__traceback__``. Passing ``None`` yields ``["NoneType: None\\n"]``,
    as the standard library does.

    :param limit: maximum number of stack entries per exception
    :param chain: also render ``__cause__``/``__context__`` exceptions
    """
    tb = __exc.__traceback__ if __exc is not None else None
    return list(
        PatchedTracebackException(
            type(__exc), __exc, tb, limit=limit, compact=True
        ).format(chain=chain)
    )


@format_exception.register
def _(
    __exc: type,
    value: BaseException,
    tb: Optional[TracebackType],
    limit: Optional[int] = None,
    chain: bool = True,
) -> List[str]:
    # Legacy ``(etype, value, tb)`` calling convention. Like the standard
    # library, honour the explicitly passed *tb* (callers may pass ``None`` or
    # a different traceback) rather than ``value.__traceback__``.
    # NOTE: annotations must stay evaluable on old Pythons (``Optional[...]``,
    # not ``X | None``), since ``register`` resolves them via get_type_hints.
    return list(
        PatchedTracebackException(
            type(value), value, tb, limit=limit, compact=True
        ).format(chain=chain)
    )
437
438
def _print_exception_impl(value, tb, limit, file, chain):
    """Format *value* with traceback *tb* and write it to *file* (default stderr)."""
    if file is None:
        file = sys.stderr

    # compact=True matches traceback.print_exception and format_exception.
    summary = PatchedTracebackException(
        type(value), value, tb, limit=limit, compact=True
    )
    for line in summary.format(chain=chain):
        print(line, file=file, end="")


@singledispatch
def print_exception(
    __exc: BaseException,
    limit: Optional[int] = None,
    file: Any = None,
    chain: bool = True,
) -> None:
    """Print the full traceback of *__exc* to *file* (``sys.stderr`` by default).

    Counterpart of ``traceback.print_exception`` that renders exception groups
    and ``__notes__``. ``None`` prints ``NoneType: None``, as the standard
    library does.
    """
    tb = __exc.__traceback__ if __exc is not None else None
    _print_exception_impl(__exc, tb, limit, file, chain)


@print_exception.register
def _(
    __exc: type,
    value: BaseException,
    tb: Optional[TracebackType],
    limit: Optional[int] = None,
    file: Any = None,
    chain: bool = True,
) -> None:
    # Legacy ``(etype, value, tb)`` calling convention. Like the standard
    # library, the explicitly passed *tb* is used, not ``value.__traceback__``.
    _print_exception_impl(value, tb, limit, file, chain)
465
466
def print_exc(
    limit: Optional[int] = None,
    file: Any | None = None,
    chain: bool = True,
) -> None:
    """Print the exception currently being handled, like ``traceback.print_exc``.

    Outside an ``except`` block there is no active exception. In that case
    ``NoneType: None`` is printed, matching the standard library, instead of
    raising ``AttributeError``.
    """
    _, value, tb = sys.exc_info()
    if file is None:
        file = sys.stderr

    # Built directly (not via print_exception) so that a missing exception
    # (value is None) is handled here.
    summary = PatchedTracebackException(
        type(value), value, tb, limit=limit, compact=True
    )
    for line in summary.format(chain=chain):
        print(line, file=file, end="")
474
475
476# Python levenshtein edit distance code for NameError/AttributeError
477# suggestions, backported from 3.12
478
# Give up on suggestions when more than this many candidate names exist.
_MAX_CANDIDATE_ITEMS: int = 750
# Names longer than this are never compared.
_MAX_STRING_SIZE: int = 40
# Cost of an insertion, a deletion, or a substitution of a different character.
_MOVE_COST: int = 2
# Cost of substituting a character that differs only in case.
_CASE_COST: int = 1
# Marker meaning "the exception carries no ``obj`` attribute".
_SENTINEL: object = object()
484
485
def _substitution_cost(ch_a, ch_b):
    """Return the cost of replacing character *ch_a* with *ch_b*.

    Identical characters cost nothing. Characters equal after ``lower()`` cost
    ``_CASE_COST``; any other pair costs ``_MOVE_COST``.
    """
    if ch_a != ch_b:
        return _CASE_COST if ch_a.lower() == ch_b.lower() else _MOVE_COST
    return 0
492
493
def _compute_suggestion_error(exc_value, tb):
    """Return the name to offer in a "Did you mean: '...'?" hint, or ``None``.

    For an ``AttributeError`` the candidates are ``dir(exc_value.obj)``. For a
    ``NameError`` they are the local, global and builtin names of the most
    recent frame in *tb*. The candidate with the smallest
    :func:`_levenshtein_distance` wins, provided no more than about a third of
    the characters involved need changing.
    """
    wrong_name = getattr(exc_value, "name", None)
    if wrong_name is None or not isinstance(wrong_name, str):
        return None
    if isinstance(exc_value, AttributeError):
        obj = getattr(exc_value, "obj", _SENTINEL)
        if obj is _SENTINEL:
            # No object recorded on the exception: nothing to search.
            return None
        try:
            candidates = dir(obj)
        except Exception:
            # dir() can raise (failing __dir__, unsortable result); never let
            # that break traceback rendering.
            return None
    else:
        assert isinstance(exc_value, NameError)
        # find most recent frame
        if tb is None:
            return None
        while tb.tb_next is not None:
            tb = tb.tb_next
        frame = tb.tb_frame

        candidates = (
            list(frame.f_locals) + list(frame.f_globals) + list(frame.f_builtins)
        )

    # A custom __dir__ or a namespace can contain non-string keys. They can
    # never match, and len()/indexing them below would raise inside traceback
    # formatting, so drop them first.
    candidates = [name for name in candidates if isinstance(name, str)]
    if len(candidates) > _MAX_CANDIDATE_ITEMS:
        return None
    wrong_name_len = len(wrong_name)
    if wrong_name_len > _MAX_STRING_SIZE:
        return None
    best_distance = wrong_name_len
    suggestion = None
    for possible_name in candidates:
        if possible_name == wrong_name:
            # A missing attribute is "found". Don't suggest it (see GH-88821).
            continue
        # No more than 1/3 of the involved characters should need changed.
        max_distance = (len(possible_name) + wrong_name_len + 3) * _MOVE_COST // 6
        # Don't take matches we've already beaten.
        max_distance = min(max_distance, best_distance - 1)
        current_distance = _levenshtein_distance(
            wrong_name, possible_name, max_distance
        )
        if current_distance > max_distance:
            continue
        if not suggestion or current_distance < best_distance:
            suggestion = possible_name
            best_distance = current_distance
    return suggestion
541
542
def _levenshtein_distance(a, b, max_cost):
    """Return the weighted edit distance between strings *a* and *b*.

    Insertions, deletions and substitutions cost ``_MOVE_COST``. Substituting
    a character that differs only in case costs ``_CASE_COST`` (see
    ``_substitution_cost``).

    When the function bails out early because the distance must exceed
    *max_cost*, it returns ``max_cost + 1`` instead of the exact distance.
    The same happens when, after common prefixes/suffixes are stripped,
    either string is longer than ``_MAX_STRING_SIZE``. Callers should
    therefore only compare the result against *max_cost*.
    """
    # A Python implementation of Python/suggestions.c:levenshtein_distance.

    # Both strings are the same
    if a == b:
        return 0

    # Trim away common affixes
    pre = 0
    while a[pre:] and b[pre:] and a[pre] == b[pre]:
        pre += 1
    a = a[pre:]
    b = b[pre:]
    # ``post`` counts backwards (0, -1, -2, ...); ``post or None`` turns the
    # initial 0 into "slice to the end".
    post = 0
    while a[: post or None] and b[: post or None] and a[post - 1] == b[post - 1]:
        post -= 1
    a = a[: post or None]
    b = b[: post or None]
    # One side is empty after trimming: every remaining character of the other
    # side must be inserted or deleted.
    if not a or not b:
        return _MOVE_COST * (len(a) + len(b))
    if len(a) > _MAX_STRING_SIZE or len(b) > _MAX_STRING_SIZE:
        return max_cost + 1

    # Prefer shorter buffer
    if len(b) < len(a):
        a, b = b, a

    # Quick fail when a match is impossible
    if (len(b) - len(a)) * _MOVE_COST > max_cost:
        return max_cost + 1

    # Instead of producing the whole traditional len(a)-by-len(b)
    # matrix, we can update just one row in place.
    # Initialize the buffer row: row[i] starts as the cost of deleting a[:i+1].
    row = list(range(_MOVE_COST, _MOVE_COST * (len(a) + 1), _MOVE_COST))

    result = 0
    for bindex in range(len(b)):
        bchar = b[bindex]
        distance = result = bindex * _MOVE_COST
        minimum = sys.maxsize
        for index in range(len(a)):
            # 1) Previous distance in this row is cost(b[:b_index], a[:index])
            substitute = distance + _substitution_cost(bchar, a[index])
            # 2) cost(b[:b_index], a[:index+1]) from previous row
            distance = row[index]
            # 3) existing result is cost(b[:b_index+1], a[index])

            insert_delete = min(result, distance) + _MOVE_COST
            result = min(insert_delete, substitute)

            # cost(b[:b_index+1], a[:index+1])
            row[index] = result
            if result < minimum:
                minimum = result
        if minimum > max_cost:
            # Everything in this row is too big, so bail early.
            return max_cost + 1
    # After the last row, ``result`` is cost(b, a) for the trimmed strings.
    return result