1from __future__ import annotations
2
3import inspect
4import sys
5from collections.abc import Callable, Iterable, Mapping
6from contextlib import AbstractContextManager
7from types import TracebackType
8from typing import TYPE_CHECKING, Any
9
# BaseExceptionGroup is only a builtin from Python 3.11 on; older versions use
# this package's own implementation.
if sys.version_info < (3, 11):
    from ._exceptions import BaseExceptionGroup

# Type-checking-only alias for a handler callback: it is called with the
# exception group of matching exceptions. Its return value is only inspected to
# reject coroutines (async handlers). The alias never exists at runtime; that
# is safe because the module uses ``from __future__ import annotations``, so
# annotations mentioning it are never evaluated.
if TYPE_CHECKING:
    _Handler = Callable[[BaseExceptionGroup[Any]], Any]
15
16
17class _Catcher:
18 def __init__(self, handler_map: Mapping[tuple[type[BaseException], ...], _Handler]):
19 self._handler_map = handler_map
20
21 def __enter__(self) -> None:
22 pass
23
24 def __exit__(
25 self,
26 etype: type[BaseException] | None,
27 exc: BaseException | None,
28 tb: TracebackType | None,
29 ) -> bool:
30 if exc is not None:
31 unhandled = self.handle_exception(exc)
32 if unhandled is exc:
33 return False
34 elif unhandled is None:
35 return True
36 else:
37 if isinstance(exc, BaseExceptionGroup):
38 try:
39 raise unhandled from exc.__cause__
40 except BaseExceptionGroup:
41 # Change __context__ to __cause__ because Python 3.11 does this
42 # too
43 unhandled.__context__ = exc.__cause__
44 raise
45
46 raise unhandled from exc
47
48 return False
49
50 def handle_exception(self, exc: BaseException) -> BaseException | None:
51 excgroup: BaseExceptionGroup | None
52 if isinstance(exc, BaseExceptionGroup):
53 excgroup = exc
54 else:
55 excgroup = BaseExceptionGroup("", [exc])
56
57 new_exceptions: list[BaseException] = []
58 for exc_types, handler in self._handler_map.items():
59 matched, excgroup = excgroup.split(exc_types)
60 if matched:
61 try:
62 try:
63 raise matched
64 except BaseExceptionGroup:
65 result = handler(matched)
66 except BaseExceptionGroup as new_exc:
67 if new_exc is matched:
68 new_exceptions.append(new_exc)
69 else:
70 new_exceptions.extend(new_exc.exceptions)
71 except BaseException as new_exc:
72 new_exceptions.append(new_exc)
73 else:
74 if inspect.iscoroutine(result):
75 raise TypeError(
76 f"Error trying to handle {matched!r} with {handler!r}. "
77 "Exception handler must be a sync function."
78 ) from exc
79
80 if not excgroup:
81 break
82
83 if new_exceptions:
84 if len(new_exceptions) == 1:
85 return new_exceptions[0]
86
87 return BaseExceptionGroup("", new_exceptions)
88 elif (
89 excgroup and len(excgroup.exceptions) == 1 and excgroup.exceptions[0] is exc
90 ):
91 return exc
92 else:
93 return excgroup
94
95
96def catch(
97 __handlers: Mapping[type[BaseException] | Iterable[type[BaseException]], _Handler],
98) -> AbstractContextManager[None]:
99 if not isinstance(__handlers, Mapping):
100 raise TypeError("the argument must be a mapping")
101
102 handler_map: dict[
103 tuple[type[BaseException], ...], Callable[[BaseExceptionGroup]]
104 ] = {}
105 for type_or_iterable, handler in __handlers.items():
106 iterable: tuple[type[BaseException]]
107 if isinstance(type_or_iterable, type) and issubclass(
108 type_or_iterable, BaseException
109 ):
110 iterable = (type_or_iterable,)
111 elif isinstance(type_or_iterable, Iterable):
112 iterable = tuple(type_or_iterable)
113 else:
114 raise TypeError(
115 "each key must be either an exception classes or an iterable thereof"
116 )
117
118 if not callable(handler):
119 raise TypeError("handlers must be callable")
120
121 for exc_type in iterable:
122 if not isinstance(exc_type, type) or not issubclass(
123 exc_type, BaseException
124 ):
125 raise TypeError(
126 "each key must be either an exception classes or an iterable "
127 "thereof"
128 )
129
130 if issubclass(exc_type, BaseExceptionGroup):
131 raise TypeError(
132 "catching ExceptionGroup with catch() is not allowed. "
133 "Use except instead."
134 )
135
136 handler_map[iterable] = handler
137
138 return _Catcher(handler_map)