1from __future__ import annotations
2
3import sys
4from collections.abc import Callable, Sequence
5from functools import partial
6from inspect import getmro, isclass
7from typing import TYPE_CHECKING, Generic, Type, TypeVar, cast, overload
8
9if sys.version_info < (3, 13):
10 from typing_extensions import TypeVar
11
# Type variables used to parametrize the exception group classes below.
# The ``*_co`` variants are covariant and carry a ``default=``; TypeVar is
# re-imported from typing_extensions on Python < 3.13 above, presumably because
# the stdlib TypeVar does not accept ``default=`` there.
_BaseExceptionT_co = TypeVar(
    "_BaseExceptionT_co", bound=BaseException, covariant=True, default=BaseException
)
_BaseExceptionT = TypeVar("_BaseExceptionT", bound=BaseException)
_ExceptionT_co = TypeVar(
    "_ExceptionT_co", bound=Exception, covariant=True, default=Exception
)
_ExceptionT = TypeVar("_ExceptionT", bound=Exception)
# using typing.Self would require a typing_extensions dependency on py<3.11
_ExceptionGroupSelf = TypeVar("_ExceptionGroupSelf", bound="ExceptionGroup")
_BaseExceptionGroupSelf = TypeVar("_BaseExceptionGroupSelf", bound="BaseExceptionGroup")
23
24
def check_direct_subclass(
    exc: BaseException, parents: tuple[type[BaseException], ...]
) -> bool:
    """Return whether the type of *exc* has any of *parents* in its MRO.

    Despite the name, the check is not limited to direct bases: every class in
    the method resolution order of ``exc.__class__`` is tested for membership
    in *parents*.

    Args:
        exc: The exception instance to test.
        parents: Candidate classes; may hold any number of entries, including
            none (in which case the result is always ``False``).

    Returns:
        ``True`` if at least one class from the MRO is contained in *parents*,
        ``False`` otherwise.
    """
    # The final MRO entry (``object``) is dropped, so it can never match.
    ancestry = getmro(exc.__class__)[:-1]
    return any(cls in parents for cls in ancestry)
33
34
def get_condition_filter(
    condition: type[_BaseExceptionT]
    | tuple[type[_BaseExceptionT], ...]
    | Callable[[_BaseExceptionT_co], bool],
) -> Callable[[_BaseExceptionT_co], bool]:
    """Turn a ``subgroup()``/``split()`` condition into a one-argument predicate.

    Args:
        condition: One of

            * an exception class,
            * a tuple consisting solely of exception classes, or
            * a callable that is not a class, taking an exception and returning
              a truth value.

    Returns:
        For a class or tuple of classes, a ``partial`` of
        ``check_direct_subclass`` bound to those classes; for a callable, the
        callable itself.

    Raises:
        TypeError: if *condition* is none of the accepted forms. This includes
            classes that do not derive from ``BaseException``.
    """
    if isclass(condition):
        # A class is only meaningful here when it is an exception type. Any
        # other class must not be treated as a predicate: calling e.g.
        # ``str(exc)`` yields a truthy value, so every exception would appear
        # to match. Such classes fall through to the TypeError below.
        if issubclass(cast(Type[BaseException], condition), BaseException):
            return partial(check_direct_subclass, parents=(condition,))
    elif isinstance(condition, tuple):
        if all(isclass(x) and issubclass(x, BaseException) for x in condition):
            return partial(check_direct_subclass, parents=condition)
    elif callable(condition):
        return cast("Callable[[BaseException], bool]", condition)

    raise TypeError("expected a function, exception type or tuple of exception types")
51
52
53def _derive_and_copy_attributes(self, excs):
54 eg = self.derive(excs)
55 eg.__cause__ = self.__cause__
56 eg.__context__ = self.__context__
57 eg.__traceback__ = self.__traceback__
58 if hasattr(self, "__notes__"):
59 # Create a new list so that add_note() only affects one exceptiongroup
60 eg.__notes__ = list(self.__notes__)
61 return eg
62
63
class BaseExceptionGroup(BaseException, Generic[_BaseExceptionT_co]):
    """A combination of multiple unrelated exceptions."""

    def __new__(
        cls: type[_BaseExceptionGroupSelf],
        __message: str,
        __exceptions: Sequence[_BaseExceptionT_co],
    ) -> _BaseExceptionGroupSelf:
        """Validate the arguments and create the group instance.

        When ``BaseExceptionGroup`` itself is instantiated and every member is
        an ``Exception``, an ``ExceptionGroup`` is created instead.

        Raises:
            TypeError: if the message is not a ``str``, the exceptions are not
                a ``Sequence``, or a group class deriving from ``Exception``
                is given a member that is not an ``Exception``.
            ValueError: if the sequence is empty or holds an item that is not
                an exception instance.
        """
        if not isinstance(__message, str):
            raise TypeError(f"argument 1 must be str, not {type(__message)}")
        if not isinstance(__exceptions, Sequence):
            raise TypeError("second argument (exceptions) must be a sequence")
        if not __exceptions:
            raise ValueError(
                "second argument (exceptions) must be a non-empty sequence"
            )

        for index, member in enumerate(__exceptions):
            if not isinstance(member, BaseException):
                raise ValueError(
                    f"Item {index} of second argument (exceptions) is not an exception"
                )

        only_exceptions = all(
            isinstance(member, Exception) for member in __exceptions
        )

        # Direct instantiation of the base class is narrowed to ExceptionGroup
        # when nothing in the group requires the BaseException variant.
        if cls is BaseExceptionGroup and only_exceptions:
            cls = ExceptionGroup

        # Exception-derived group classes may not wrap bare BaseExceptions.
        if issubclass(cls, Exception) and not only_exceptions:
            if cls is ExceptionGroup:
                raise TypeError("Cannot nest BaseExceptions in an ExceptionGroup")

            raise TypeError(f"Cannot nest BaseExceptions in {cls.__name__!r}")

        instance = super().__new__(cls, __message, __exceptions)
        instance._exceptions = tuple(__exceptions)
        return instance

    def __init__(
        self,
        __message: str,
        __exceptions: Sequence[_BaseExceptionT_co],
        *args: object,
    ) -> None:
        """Store the message, the exceptions and any extra values in ``args``."""
        BaseException.__init__(self, __message, __exceptions, *args)

    def add_note(self, note: str) -> None:
        """Append *note* to ``__notes__``, creating the list on first use.

        Raises:
            TypeError: if *note* is not a ``str``.
        """
        if not isinstance(note, str):
            raise TypeError(
                f"Expected a string, got note={note!r} (type {type(note).__name__})"
            )

        if hasattr(self, "__notes__"):
            self.__notes__.append(note)
        else:
            self.__notes__: list[str] = [note]

    @property
    def message(self) -> str:
        """The message the group was created with (``args[0]``)."""
        return self.args[0]

    @property
    def exceptions(
        self,
    ) -> tuple[_BaseExceptionT_co | BaseExceptionGroup[_BaseExceptionT_co], ...]:
        """The exceptions contained in this group, as a tuple."""
        return tuple(self._exceptions)

    @overload
    def subgroup(
        self, __condition: type[_ExceptionT] | tuple[type[_ExceptionT], ...]
    ) -> ExceptionGroup[_ExceptionT] | None: ...

    @overload
    def subgroup(
        self, __condition: type[_BaseExceptionT] | tuple[type[_BaseExceptionT], ...]
    ) -> BaseExceptionGroup[_BaseExceptionT] | None: ...

    @overload
    def subgroup(
        self,
        __condition: Callable[[_BaseExceptionT_co | _BaseExceptionGroupSelf], bool],
    ) -> BaseExceptionGroup[_BaseExceptionT_co] | None: ...

    def subgroup(
        self,
        __condition: type[_BaseExceptionT]
        | tuple[type[_BaseExceptionT], ...]
        | Callable[[_BaseExceptionT_co | _BaseExceptionGroupSelf], bool],
    ) -> BaseExceptionGroup[_BaseExceptionT] | None:
        """Return the part of this group that satisfies the condition.

        Returns:
            ``self`` when the group itself matches or when no member had to be
            dropped; ``None`` when nothing matches; otherwise a derived group
            holding only the matching members, with nested groups filtered
            recursively.
        """
        matcher = get_condition_filter(__condition)
        if matcher(self):
            return self

        kept: list[BaseException] = []
        changed = False
        for member in self.exceptions:
            if isinstance(member, BaseExceptionGroup):
                nested = member.subgroup(__condition)
                if nested is not None:
                    kept.append(nested)

                # Anything other than the nested group coming back untouched
                # means this group needs to be rebuilt.
                if nested is not member:
                    changed = True
            elif matcher(member):
                kept.append(member)
            else:
                changed = True

        if not changed:
            return self

        if not kept:
            return None

        return _derive_and_copy_attributes(self, kept)

    @overload
    def split(
        self, __condition: type[_ExceptionT] | tuple[type[_ExceptionT], ...]
    ) -> tuple[
        ExceptionGroup[_ExceptionT] | None,
        BaseExceptionGroup[_BaseExceptionT_co] | None,
    ]: ...

    @overload
    def split(
        self, __condition: type[_BaseExceptionT] | tuple[type[_BaseExceptionT], ...]
    ) -> tuple[
        BaseExceptionGroup[_BaseExceptionT] | None,
        BaseExceptionGroup[_BaseExceptionT_co] | None,
    ]: ...

    @overload
    def split(
        self,
        __condition: Callable[[_BaseExceptionT_co | _BaseExceptionGroupSelf], bool],
    ) -> tuple[
        BaseExceptionGroup[_BaseExceptionT_co] | None,
        BaseExceptionGroup[_BaseExceptionT_co] | None,
    ]: ...

    def split(
        self,
        __condition: type[_BaseExceptionT]
        | tuple[type[_BaseExceptionT], ...]
        | Callable[[_BaseExceptionT_co], bool],
    ) -> (
        tuple[
            ExceptionGroup[_ExceptionT] | None,
            BaseExceptionGroup[_BaseExceptionT_co] | None,
        ]
        | tuple[
            BaseExceptionGroup[_BaseExceptionT] | None,
            BaseExceptionGroup[_BaseExceptionT_co] | None,
        ]
        | tuple[
            BaseExceptionGroup[_BaseExceptionT_co] | None,
            BaseExceptionGroup[_BaseExceptionT_co] | None,
        ]
    ):
        """Partition this group into a matching and a non-matching part.

        Returns:
            A ``(matching, nonmatching)`` pair. ``(self, None)`` when the group
            itself matches; otherwise each side is a derived group, or ``None``
            when that side would be empty. Nested groups are split recursively.
        """
        matcher = get_condition_filter(__condition)
        if matcher(self):
            return self, None

        hits: list[BaseException] = []
        misses: list[BaseException] = []
        for member in self.exceptions:
            if isinstance(member, BaseExceptionGroup):
                inner_hit, inner_miss = member.split(matcher)
                if inner_hit is not None:
                    hits.append(inner_hit)

                if inner_miss is not None:
                    misses.append(inner_miss)
            else:
                bucket = hits if matcher(member) else misses
                bucket.append(member)

        matching_group: _BaseExceptionGroupSelf | None = (
            _derive_and_copy_attributes(self, hits) if hits else None
        )
        nonmatching_group: _BaseExceptionGroupSelf | None = (
            _derive_and_copy_attributes(self, misses) if misses else None
        )
        return matching_group, nonmatching_group

    @overload
    def derive(self, __excs: Sequence[_ExceptionT]) -> ExceptionGroup[_ExceptionT]: ...

    @overload
    def derive(
        self, __excs: Sequence[_BaseExceptionT]
    ) -> BaseExceptionGroup[_BaseExceptionT]: ...

    def derive(
        self, __excs: Sequence[_BaseExceptionT]
    ) -> BaseExceptionGroup[_BaseExceptionT]:
        """Create a new group with this group's message and the given exceptions.

        ``subgroup()`` and ``split()`` obtain their result groups through this
        method.
        """
        return BaseExceptionGroup(self.message, __excs)

    def __str__(self) -> str:
        """Return the message followed by the number of sub-exceptions."""
        count = len(self._exceptions)
        plural = "s" if count != 1 else ""
        return f"{self.message} ({count} sub-exception{plural})"

    def __repr__(self) -> str:
        """Return ``ClassName(message, exceptions)`` built from ``args``."""
        class_name = self.__class__.__name__
        message = self.args[0]
        members = self.args[1]
        return f"{class_name}({message!r}, {members!r})"
279
280
class ExceptionGroup(BaseExceptionGroup[_ExceptionT_co], Exception):
    """An exception group that is itself an ``Exception``.

    Runtime behaviour comes from ``BaseExceptionGroup``; this class only adds
    ``Exception`` as a base and narrows the static types of the inherited API.
    """

    def __new__(
        cls: type[_ExceptionGroupSelf],
        __message: str,
        __exceptions: Sequence[_ExceptionT_co],
    ) -> _ExceptionGroupSelf:
        """Create the group; all validation is delegated to the parent class."""
        return super().__new__(cls, __message, __exceptions)

    # Everything below is only seen by static type checkers: it re-declares the
    # inherited members with ExceptionGroup-specific types and is skipped at
    # runtime because TYPE_CHECKING is false then.
    if TYPE_CHECKING:

        @property
        def exceptions(
            self,
        ) -> tuple[_ExceptionT_co | ExceptionGroup[_ExceptionT_co], ...]: ...

        @overload  # type: ignore[override]
        def subgroup(
            self, __condition: type[_ExceptionT] | tuple[type[_ExceptionT], ...]
        ) -> ExceptionGroup[_ExceptionT] | None: ...

        @overload
        def subgroup(
            self, __condition: Callable[[_ExceptionT_co | _ExceptionGroupSelf], bool]
        ) -> ExceptionGroup[_ExceptionT_co] | None: ...

        def subgroup(
            self,
            __condition: type[_ExceptionT]
            | tuple[type[_ExceptionT], ...]
            | Callable[[_ExceptionT_co], bool],
        ) -> ExceptionGroup[_ExceptionT] | None:
            return super().subgroup(__condition)

        @overload
        def split(
            self, __condition: type[_ExceptionT] | tuple[type[_ExceptionT], ...]
        ) -> tuple[
            ExceptionGroup[_ExceptionT] | None, ExceptionGroup[_ExceptionT_co] | None
        ]: ...

        @overload
        def split(
            self, __condition: Callable[[_ExceptionT_co | _ExceptionGroupSelf], bool]
        ) -> tuple[
            ExceptionGroup[_ExceptionT_co] | None, ExceptionGroup[_ExceptionT_co] | None
        ]: ...

        def split(
            self: _ExceptionGroupSelf,
            __condition: type[_ExceptionT]
            | tuple[type[_ExceptionT], ...]
            | Callable[[_ExceptionT_co], bool],
        ) -> tuple[
            ExceptionGroup[_ExceptionT_co] | None, ExceptionGroup[_ExceptionT_co] | None
        ]:
            return super().split(__condition)