1from __future__ import annotations
2
3from collections.abc import Callable, Sequence
4from functools import partial
5from inspect import getmro, isclass
6from typing import TYPE_CHECKING, Generic, Type, TypeVar, cast, overload
7
# Invariant type variables, used in method signatures such as subgroup(),
# split() and derive().
_BaseExceptionT = TypeVar("_BaseExceptionT", bound=BaseException)
_ExceptionT = TypeVar("_ExceptionT", bound=Exception)

# Covariant type variables, used to parametrize the group classes themselves.
_BaseExceptionT_co = TypeVar(
    "_BaseExceptionT_co", covariant=True, bound=BaseException
)
_ExceptionT_co = TypeVar("_ExceptionT_co", covariant=True, bound=Exception)

# Stand-ins for the "self" type of each group class; the bounds are string
# forward references because the classes are defined further down.
# (using typing.Self would require a typing_extensions dependency on py<3.11)
_BaseExceptionGroupSelf = TypeVar(
    "_BaseExceptionGroupSelf", bound="BaseExceptionGroup"
)
_ExceptionGroupSelf = TypeVar("_ExceptionGroupSelf", bound="ExceptionGroup")
15
16
def check_direct_subclass(
    exc: BaseException, parents: tuple[type[BaseException], ...]
) -> bool:
    """Return True if any class in *exc*'s MRO is one of *parents*.

    Membership is decided by walking the real MRO of the exception's class
    (via :func:`inspect.getmro`) instead of calling :func:`isinstance`, so
    ABC virtual-subclass registrations are not taken into account. The
    trailing ``object`` entry of the MRO is skipped.

    Args:
        exc: the exception instance to test.
        parents: any number of exception classes (possibly empty, in which
            case nothing matches).

    Returns:
        Whether *exc*'s class or one of its bases is in *parents*.
    """
    return any(cls in parents for cls in getmro(exc.__class__)[:-1])
25
26
def get_condition_filter(
    condition: type[_BaseExceptionT]
    | tuple[type[_BaseExceptionT], ...]
    | Callable[[_BaseExceptionT_co], bool],
) -> Callable[[_BaseExceptionT_co], bool]:
    """Normalize a ``subgroup()``/``split()`` condition into a predicate.

    Args:
        condition: an exception type, a tuple of exception types, or a
            callable (other than a class) that takes an exception and returns
            whether it matches.

    Returns:
        A one-argument predicate. Exception types are matched through
        ``check_direct_subclass``; callables are returned unchanged.

    Raises:
        TypeError: for anything else, including classes that are not
            exception types and tuples containing non-exception-types.
    """
    if isclass(condition) and issubclass(
        cast(Type[BaseException], condition), BaseException
    ):
        return partial(check_direct_subclass, parents=(condition,))
    elif isinstance(condition, tuple):
        if all(isclass(x) and issubclass(x, BaseException) for x in condition):
            return partial(check_direct_subclass, parents=condition)
    elif callable(condition) and not isclass(condition):
        # Classes are callable too, but something like ``str`` or ``int``
        # used as a "predicate" would silently match (or crash on) arbitrary
        # exceptions, so only non-class callables are accepted here.
        return cast("Callable[[BaseException], bool]", condition)

    raise TypeError("expected a function, exception type or tuple of exception types")
43
44
def _derive_and_copy_attributes(self, excs):
    """Create a new group from *excs* and copy *self*'s metadata onto it.

    The new group is built by ``self.derive(excs)``, so subclasses control
    its type. The traceback, cause and context of *self* are then copied
    over, as are its notes when ``__notes__`` is a sequence.

    Args:
        self: the exception group being split or filtered.
        excs: the member exceptions for the new group.

    Returns:
        The newly derived group.
    """
    eg = self.derive(excs)
    eg.__cause__ = self.__cause__
    eg.__context__ = self.__context__
    eg.__traceback__ = self.__traceback__
    notes = getattr(self, "__notes__", None)
    # Copy into a fresh list so that add_note() only affects one group.
    # A non-sequence __notes__ (e.g. an int set by user code) is left
    # uncopied instead of making list() blow up in the middle of a split.
    if isinstance(notes, Sequence):
        eg.__notes__ = list(notes)
    return eg
54
55
class BaseExceptionGroup(BaseException, Generic[_BaseExceptionT_co]):
    """A combination of multiple unrelated exceptions.

    Instantiating ``BaseExceptionGroup`` directly returns an
    :class:`ExceptionGroup` when every member is an :class:`Exception`.
    The members are snapshotted into a tuple when the group is created, so
    later changes to the sequence passed in do not affect the group.
    """

    def __new__(
        cls: type[_BaseExceptionGroupSelf],
        __message: str,
        __exceptions: Sequence[_BaseExceptionT_co],
    ) -> _BaseExceptionGroupSelf:
        """Validate the arguments and create the group instance.

        Args:
            __message: the group's message; must be a ``str``.
            __exceptions: a non-empty sequence of exception instances.

        Raises:
            TypeError: if the message is not a string, the exceptions are
                not a sequence, or a non-``Exception`` member is given to a
                group class that derives from ``Exception``.
            ValueError: if the sequence is empty or contains a non-exception.
        """
        if not isinstance(__message, str):
            raise TypeError(f"argument 1 must be str, not {type(__message)}")
        if not isinstance(__exceptions, Sequence):
            raise TypeError("second argument (exceptions) must be a sequence")
        if not __exceptions:
            raise ValueError(
                "second argument (exceptions) must be a non-empty sequence"
            )

        for i, exc in enumerate(__exceptions):
            if not isinstance(exc, BaseException):
                raise ValueError(
                    f"Item {i} of second argument (exceptions) is not an exception"
                )

        # A plain BaseExceptionGroup holding only Exceptions is promoted to
        # ExceptionGroup.
        if cls is BaseExceptionGroup:
            if all(isinstance(exc, Exception) for exc in __exceptions):
                cls = ExceptionGroup

        # Group classes deriving from Exception may not wrap bare
        # BaseExceptions.
        if issubclass(cls, Exception):
            for exc in __exceptions:
                if not isinstance(exc, Exception):
                    if cls is ExceptionGroup:
                        raise TypeError(
                            "Cannot nest BaseExceptions in an ExceptionGroup"
                        )
                    else:
                        raise TypeError(
                            f"Cannot nest BaseExceptions in {cls.__name__!r}"
                        )

        instance = super().__new__(cls, __message, __exceptions)
        instance._message = __message
        # Snapshot the validated members so that mutating the caller's
        # sequence afterwards cannot empty the group or smuggle in objects
        # that bypassed the checks above.
        instance._exceptions = tuple(__exceptions)
        # Keep the sequence exactly as passed; __repr__ echoes it.
        instance._exceptions_arg = __exceptions
        return instance

    def add_note(self, note: str) -> None:
        """Append *note* to this exception's ``__notes__`` list.

        The list is created on first use.

        Raises:
            TypeError: if *note* is not a string, or if ``__notes__`` already
                exists but is not a list.
        """
        if not isinstance(note, str):
            raise TypeError(
                f"Expected a string, got note={note!r} (type {type(note).__name__})"
            )

        if not hasattr(self, "__notes__"):
            self.__notes__: list[str] = []
        elif not isinstance(self.__notes__, list):
            raise TypeError("Cannot add note: __notes__ is not a list")

        self.__notes__.append(note)

    @property
    def message(self) -> str:
        """The message passed to the constructor."""
        return self._message

    @property
    def exceptions(
        self,
    ) -> tuple[_BaseExceptionT_co | BaseExceptionGroup[_BaseExceptionT_co], ...]:
        """The member exceptions (possibly nested groups), as a tuple."""
        return tuple(self._exceptions)

    @overload
    def subgroup(
        self, __condition: type[_ExceptionT] | tuple[type[_ExceptionT], ...]
    ) -> ExceptionGroup[_ExceptionT] | None: ...

    @overload
    def subgroup(
        self, __condition: type[_BaseExceptionT] | tuple[type[_BaseExceptionT], ...]
    ) -> BaseExceptionGroup[_BaseExceptionT] | None: ...

    @overload
    def subgroup(
        self,
        __condition: Callable[[_BaseExceptionT_co | _BaseExceptionGroupSelf], bool],
    ) -> BaseExceptionGroup[_BaseExceptionT_co] | None: ...

    def subgroup(
        self,
        __condition: type[_BaseExceptionT]
        | tuple[type[_BaseExceptionT], ...]
        | Callable[[_BaseExceptionT_co | _BaseExceptionGroupSelf], bool],
    ) -> BaseExceptionGroup[_BaseExceptionT] | None:
        """Return a group containing only the exceptions matching *condition*.

        The condition is tried on this group first, then recursively on
        nested groups and leaf exceptions.

        Returns:
            ``self`` if the group itself matches or nothing was filtered out,
            ``None`` if nothing matched, otherwise a new group built through
            ``_derive_and_copy_attributes``.
        """
        condition = get_condition_filter(__condition)
        modified = False
        if condition(self):
            return self

        exceptions: list[BaseException] = []
        for exc in self.exceptions:
            if isinstance(exc, BaseExceptionGroup):
                subgroup = exc.subgroup(__condition)
                if subgroup is not None:
                    exceptions.append(subgroup)

                # A nested group that was trimmed or dropped means this
                # level has to be rebuilt as well.
                if subgroup is not exc:
                    modified = True
            elif condition(exc):
                exceptions.append(exc)
            else:
                modified = True

        if not modified:
            return self
        elif exceptions:
            group = _derive_and_copy_attributes(self, exceptions)
            return group
        else:
            return None

    @overload
    def split(
        self, __condition: type[_ExceptionT] | tuple[type[_ExceptionT], ...]
    ) -> tuple[
        ExceptionGroup[_ExceptionT] | None,
        BaseExceptionGroup[_BaseExceptionT_co] | None,
    ]: ...

    @overload
    def split(
        self, __condition: type[_BaseExceptionT] | tuple[type[_BaseExceptionT], ...]
    ) -> tuple[
        BaseExceptionGroup[_BaseExceptionT] | None,
        BaseExceptionGroup[_BaseExceptionT_co] | None,
    ]: ...

    @overload
    def split(
        self,
        __condition: Callable[[_BaseExceptionT_co | _BaseExceptionGroupSelf], bool],
    ) -> tuple[
        BaseExceptionGroup[_BaseExceptionT_co] | None,
        BaseExceptionGroup[_BaseExceptionT_co] | None,
    ]: ...

    def split(
        self,
        __condition: type[_BaseExceptionT]
        | tuple[type[_BaseExceptionT], ...]
        | Callable[[_BaseExceptionT_co], bool],
    ) -> (
        tuple[
            ExceptionGroup[_ExceptionT] | None,
            BaseExceptionGroup[_BaseExceptionT_co] | None,
        ]
        | tuple[
            BaseExceptionGroup[_BaseExceptionT] | None,
            BaseExceptionGroup[_BaseExceptionT_co] | None,
        ]
        | tuple[
            BaseExceptionGroup[_BaseExceptionT_co] | None,
            BaseExceptionGroup[_BaseExceptionT_co] | None,
        ]
    ):
        """Partition this group into ``(matching, rest)`` by *condition*.

        Nested groups are split recursively. Either half is ``None`` when it
        would be empty; if the group itself matches, ``(self, None)`` is
        returned. Non-empty halves are new groups built through
        ``_derive_and_copy_attributes``.
        """
        condition = get_condition_filter(__condition)
        if condition(self):
            return self, None

        matching_exceptions: list[BaseException] = []
        nonmatching_exceptions: list[BaseException] = []
        for exc in self.exceptions:
            if isinstance(exc, BaseExceptionGroup):
                matching, nonmatching = exc.split(condition)
                if matching is not None:
                    matching_exceptions.append(matching)

                if nonmatching is not None:
                    nonmatching_exceptions.append(nonmatching)
            elif condition(exc):
                matching_exceptions.append(exc)
            else:
                nonmatching_exceptions.append(exc)

        matching_group: _BaseExceptionGroupSelf | None = None
        if matching_exceptions:
            matching_group = _derive_and_copy_attributes(self, matching_exceptions)

        nonmatching_group: _BaseExceptionGroupSelf | None = None
        if nonmatching_exceptions:
            nonmatching_group = _derive_and_copy_attributes(
                self, nonmatching_exceptions
            )

        return matching_group, nonmatching_group

    @overload
    def derive(self, __excs: Sequence[_ExceptionT]) -> ExceptionGroup[_ExceptionT]: ...

    @overload
    def derive(
        self, __excs: Sequence[_BaseExceptionT]
    ) -> BaseExceptionGroup[_BaseExceptionT]: ...

    def derive(
        self, __excs: Sequence[_BaseExceptionT]
    ) -> BaseExceptionGroup[_BaseExceptionT]:
        """Return a new group with this group's message and *__excs*.

        Used when ``subgroup()``/``split()`` build new groups. The result is
        constructed through ``BaseExceptionGroup`` (and thus may be an
        ``ExceptionGroup``); subclasses override this to keep their own type.
        """
        return BaseExceptionGroup(self.message, __excs)

    def __str__(self) -> str:
        suffix = "" if len(self._exceptions) == 1 else "s"
        return f"{self.message} ({len(self._exceptions)} sub-exception{suffix})"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.message!r}, {self._exceptions_arg!r})"
264
265
class ExceptionGroup(BaseExceptionGroup[_ExceptionT_co], Exception):
    """An exception group whose members are all :class:`Exception` instances.

    Because it also derives from :class:`Exception`, ``except Exception``
    catches it. Argument validation is done by ``BaseExceptionGroup.__new__``,
    which rejects non-``Exception`` members for this class.
    """

    def __new__(
        cls: type[_ExceptionGroupSelf],
        __message: str,
        __exceptions: Sequence[_ExceptionT_co],
    ) -> _ExceptionGroupSelf:
        # All checks and attribute setup live in the base class.
        new_group = super().__new__(cls, __message, __exceptions)
        return new_group

    if TYPE_CHECKING:
        # Static-typing-only overrides that narrow return types to
        # ExceptionGroup. TYPE_CHECKING is False at runtime, so none of these
        # definitions exist then and the inherited implementations are used.

        @property
        def exceptions(
            self,
        ) -> tuple[_ExceptionT_co | ExceptionGroup[_ExceptionT_co], ...]: ...

        @overload  # type: ignore[override]
        def subgroup(
            self,
            __condition: type[_ExceptionT] | tuple[type[_ExceptionT], ...],
        ) -> ExceptionGroup[_ExceptionT] | None: ...

        @overload
        def subgroup(
            self,
            __condition: Callable[[_ExceptionT_co | _ExceptionGroupSelf], bool],
        ) -> ExceptionGroup[_ExceptionT_co] | None: ...

        def subgroup(
            self,
            __condition: (
                type[_ExceptionT]
                | tuple[type[_ExceptionT], ...]
                | Callable[[_ExceptionT_co], bool]
            ),
        ) -> ExceptionGroup[_ExceptionT] | None:
            return super().subgroup(__condition)

        @overload
        def split(
            self,
            __condition: type[_ExceptionT] | tuple[type[_ExceptionT], ...],
        ) -> tuple[
            ExceptionGroup[_ExceptionT] | None,
            ExceptionGroup[_ExceptionT_co] | None,
        ]: ...

        @overload
        def split(
            self,
            __condition: Callable[[_ExceptionT_co | _ExceptionGroupSelf], bool],
        ) -> tuple[
            ExceptionGroup[_ExceptionT_co] | None,
            ExceptionGroup[_ExceptionT_co] | None,
        ]: ...

        def split(
            self: _ExceptionGroupSelf,
            __condition: (
                type[_ExceptionT]
                | tuple[type[_ExceptionT], ...]
                | Callable[[_ExceptionT_co], bool]
            ),
        ) -> tuple[
            ExceptionGroup[_ExceptionT_co] | None,
            ExceptionGroup[_ExceptionT_co] | None,
        ]:
            return super().split(__condition)