1from __future__ import annotations
2
3from contextlib import (
4 contextmanager,
5 nullcontext,
6)
7import re
8import sys
9from typing import (
10 Generator,
11 Literal,
12 Sequence,
13 Type,
14 cast,
15)
16import warnings
17
18
@contextmanager
def assert_produces_warning(
    expected_warning: type[Warning] | bool | tuple[type[Warning], ...] | None = Warning,
    filter_level: Literal[
        "error", "ignore", "always", "default", "module", "once"
    ] = "always",
    check_stacklevel: bool = True,
    raise_on_extra_warnings: bool = True,
    match: str | None = None,
) -> Generator[list[warnings.WarningMessage], None, None]:
    """
    Record warnings emitted inside a ``with`` block and verify them on exit.

    This is a thin wrapper around ``warnings.catch_warnings(record=True)``.
    When the block exits, the recorded warnings are checked against
    ``expected_warning`` and, optionally, for unexpected extras.

    Parameters
    ----------
    expected_warning : {Warning, False, tuple[Warning, ...], None}, default Warning
        Warning class, or tuple of classes, that the block is expected to
        emit. ``Warning`` is the base class of all warnings. Pass ``False``
        or ``None`` to skip the "expected warning" check.
    filter_level : str, default "always"
        Action handed to ``warnings.simplefilter`` for the duration of the
        block. Valid values are:

        * "error" - turns matching warnings into exceptions
        * "ignore" - discard the warning
        * "always" - always emit a warning
        * "default" - print the warning the first time it is generated
          from each location
        * "module" - print the warning the first time it is generated
          from each module
        * "once" - print the warning the first time it is generated

    check_stacklevel : bool, default True
        Forwarded to the expected-warning check. When True, each warning of
        the expected type is also checked for the file it is attributed to
        (see ``_assert_raised_with_correct_stacklevel``).
    raise_on_extra_warnings : bool, default True
        Whether recorded warnings that are not of type ``expected_warning``
        should cause an ``AssertionError``.
    match : str, optional
        Regular expression searched for in the message of the expected
        warning(s). Only used when ``expected_warning`` is truthy.

    Yields
    ------
    list of warnings.WarningMessage
        The live list that ``warnings.catch_warnings`` records into.

    Notes
    -----
    This is *not* thread-safe.

    Examples
    --------
    >>> import warnings
    >>> with assert_produces_warning():
    ...     warnings.warn(UserWarning())
    ...
    >>> with assert_produces_warning(False):
    ...     warnings.warn(RuntimeWarning())
    ...
    Traceback (most recent call last):
        ...
    AssertionError: Caused unexpected warning(s): [...]
    >>> with assert_produces_warning(UserWarning):
    ...     warnings.warn(RuntimeWarning())
    Traceback (most recent call last):
        ...
    AssertionError: Did not see expected warning of class 'UserWarning'
    """
    __tracebackhide__ = True

    with warnings.catch_warnings(record=True) as recorded:
        warnings.simplefilter(filter_level)
        try:
            yield recorded
        finally:
            # The checks live in ``finally`` so they also run when the body
            # raised. They are called directly from this frame on purpose:
            # the stacklevel check counts frames back to the caller.
            if expected_warning:
                _assert_caught_expected_warning(
                    caught_warnings=recorded,
                    # ``cast`` is a runtime no-op; it only narrows the type
                    # for static checkers.
                    expected_warning=cast(Type[Warning], expected_warning),
                    match=match,
                    check_stacklevel=check_stacklevel,
                )
            if raise_on_extra_warnings:
                _assert_caught_no_extra_warnings(
                    caught_warnings=recorded,
                    expected_warning=expected_warning,
                )
106
107
def maybe_produces_warning(warning: type[Warning], condition: bool, **kwargs):
    """
    Return a warning-checking context manager only when ``condition`` holds.

    Parameters
    ----------
    warning : type[Warning]
        Warning class passed to ``assert_produces_warning``.
    condition : bool
        If truthy, the warning check is active; otherwise a no-op
        ``nullcontext`` is returned.
    **kwargs
        Forwarded unchanged to ``assert_produces_warning``.
    """
    if not condition:
        return nullcontext()
    return assert_produces_warning(warning, **kwargs)
116
117
def _assert_caught_expected_warning(
    *,
    caught_warnings: Sequence[warnings.WarningMessage],
    expected_warning: type[Warning] | tuple[type[Warning], ...],
    match: str | None,
    check_stacklevel: bool,
) -> None:
    """
    Assert that there was the expected warning among the caught warnings.

    Parameters
    ----------
    caught_warnings : sequence of warnings.WarningMessage
        Warnings recorded by ``warnings.catch_warnings(record=True)``.
    expected_warning : type[Warning] or tuple of type[Warning]
        Warning class (or classes) of which at least one instance must be
        present; subclasses count as matches.
    match : str or None
        Regular expression searched (``re.search``) in the message of every
        warning of the expected type. At least one must match.
    check_stacklevel : bool
        If True, every warning of the expected type is passed to
        ``_assert_raised_with_correct_stacklevel``.

    Raises
    ------
    AssertionError
        If no warning of the expected type was caught, or if ``match`` is
        given and none of the expected warnings' messages match it.
    """
    saw_warning = False
    matched_message = False
    unmatched_messages = []

    # A tuple of classes has no ``__name__``; build the display name up front
    # so the failure paths raise AssertionError rather than AttributeError.
    warning_name = (
        tuple(cls.__name__ for cls in expected_warning)
        if isinstance(expected_warning, tuple)
        else expected_warning.__name__
    )

    # Keep this a plain loop: the stacklevel helper inspects the call stack
    # by depth, so it must be called directly from this function's frame.
    for actual_warning in caught_warnings:
        if issubclass(actual_warning.category, expected_warning):
            saw_warning = True

            if check_stacklevel:
                _assert_raised_with_correct_stacklevel(actual_warning)

            if match is not None:
                if re.search(match, str(actual_warning.message)):
                    matched_message = True
                else:
                    unmatched_messages.append(actual_warning.message)

    if not saw_warning:
        raise AssertionError(
            f"Did not see expected warning of class {warning_name!r}"
        )

    if match and not matched_message:
        raise AssertionError(
            f"Did not see warning {warning_name!r} "
            f"matching '{match}'. The emitted warning messages are "
            f"{unmatched_messages}"
        )
155
156
def _assert_caught_no_extra_warnings(
    *,
    caught_warnings: Sequence[warnings.WarningMessage],
    expected_warning: type[Warning] | bool | tuple[type[Warning], ...] | None,
) -> None:
    """
    Assert that no extra warnings apart from the expected ones are caught.

    Parameters
    ----------
    caught_warnings : sequence of warnings.WarningMessage
        Warnings recorded inside the context manager.
    expected_warning : type[Warning], tuple of type[Warning], bool or None
        The warning class(es) that are allowed; anything else is "extra".

    Raises
    ------
    AssertionError
        If at least one unexpected warning remains after the
        ``ResourceWarning`` exemptions below. The message lists
        ``(category name, message, filename, lineno)`` for each of them.
    """
    extra_warnings = []

    for caught in caught_warnings:
        if not _is_unexpected_warning(caught, expected_warning):
            continue

        # GH#38630 pytest.filterwarnings does not suppress these.
        if caught.category == ResourceWarning and (
            # GH 44732: Don't make the CI flaky by filtering SSL-related
            # ResourceWarning from dependencies
            "unclosed <ssl.SSLSocket" in str(caught.message)
            # GH 44844: Matplotlib leaves font files open during the entire
            # process upon import. Don't make CI flaky if ResourceWarning
            # raised due to these open files.
            or any("matplotlib" in mod for mod in sys.modules)
        ):
            continue

        record = (
            caught.category.__name__,
            caught.message,
            caught.filename,
            caught.lineno,
        )
        extra_warnings.append(record)

    if extra_warnings:
        raise AssertionError(f"Caused unexpected warning(s): {repr(extra_warnings)}")
189
190
def _is_unexpected_warning(
    actual_warning: warnings.WarningMessage,
    expected_warning: type[Warning] | bool | tuple[type[Warning], ...] | None,
) -> bool:
    """
    Check if the actual warning issued is unexpected.

    Returns True when nothing was expected (``expected_warning`` is falsy),
    or when the warning's category is not a subclass of ``expected_warning``.
    """
    if not (actual_warning and not expected_warning):
        # Something was expected: compare categories. ``cast`` only informs
        # the type checker; it does nothing at runtime.
        allowed = cast(Type[Warning], expected_warning)
        is_allowed = issubclass(actual_warning.category, allowed)
        return not is_allowed
    return True
200
201
def _assert_raised_with_correct_stacklevel(
    actual_warning: warnings.WarningMessage,
) -> None:
    """
    Assert that a warning is attributed to the expected calling file.

    The file recorded on ``actual_warning`` is compared with the file of the
    frame four levels above this function.

    NOTE(review): index 4 presumably corresponds to the code that entered
    ``assert_produces_warning`` (this function ->
    ``_assert_caught_expected_warning`` -> the generator -> contextlib's
    ``__exit__`` -> caller). Confirm before changing the call chain.

    Raises
    ------
    AssertionError
        If the two filenames differ.
    """
    import inspect

    # ``stack()`` must be called from this function's own frame so that the
    # hard-coded depth below keeps pointing at the same caller.
    frame_records = inspect.stack()
    caller_frame = frame_records[4][0]
    caller = inspect.getframeinfo(caller_frame)
    msg = (
        "Warning not set with correct stacklevel. "
        f"File where warning is raised: {actual_warning.filename} != "
        f"{caller.filename}. Warning message: {actual_warning.message}"
    )
    assert actual_warning.filename == caller.filename, msg