1import importlib.util
2import os
3import re
4import shutil
5import textwrap
6from collections import defaultdict
7from typing import TYPE_CHECKING
8
9import pytest
10
# Only trigger a full `mypy` run if this environment variable is set to
# something other than an empty string, `0` or `false` (case-insensitive).
# Note that these tests tend to take over a minute even on a macOS M1 CPU,
# and more than that in CI.
#
# The *value* of the variable is inspected here; a plain membership test
# (`name in os.environ`) only yields a bool and would treat
# `NPY_RUN_MYPY_IN_TESTSUITE=0` as "enabled".
RUN_MYPY = (
    os.environ.get("NPY_RUN_MYPY_IN_TESTSUITE", "").strip().lower()
    not in ("0", "", "false")
)
17
# Module-level mark: every test function in this file is skipped unless
# `RUN_MYPY` is truthy.
pytestmark = pytest.mark.skipif(not RUN_MYPY, reason="`NPY_RUN_MYPY_IN_TESTSUITE` not set")
23
24
# `NO_MYPY` records whether `mypy.api` could be imported; start from the
# optimistic value and flip it only when the import fails.
NO_MYPY = False
try:
    from mypy import api
except ImportError:
    NO_MYPY = True
31
32if TYPE_CHECKING:
33 from collections.abc import Iterator
34
35 # We need this as annotation, but it's located in a private namespace.
36 # As a compromise, do *not* import it during runtime
37 from _pytest.mark.structures import ParameterSet
38
# All test data lives in the `data` directory next to this file.
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")

# Sub-directories holding the individual categories of test files.
PASS_DIR, FAIL_DIR, REVEAL_DIR, MISC_DIR = (
    os.path.join(DATA_DIR, _sub) for _sub in ("pass", "fail", "reveal", "misc")
)

# mypy configuration file and the cache directory handed to mypy.
MYPY_INI, CACHE_DIR = (
    os.path.join(DATA_DIR, _entry) for _entry in ("mypy.ini", ".mypy_cache")
)

#: A dictionary with file names as keys and lists of the mypy stdout as values.
#: To-be populated by `run_mypy`.
OUTPUT_MYPY: defaultdict[str, list[str]] = defaultdict(list)
50
51
52def _key_func(key: str) -> str:
53 """Split at the first occurrence of the ``:`` character.
54
55 Windows drive-letters (*e.g.* ``C:``) are ignored herein.
56 """
57 drive, tail = os.path.splitdrive(key)
58 return os.path.join(drive, tail.split(":", 1)[0])
59
60
61def _strip_filename(msg: str) -> tuple[int, str]:
62 """Strip the filename and line number from a mypy message."""
63 _, tail = os.path.splitdrive(msg)
64 _, lineno, msg = tail.split(":", 2)
65 return int(lineno), msg.strip()
66
67
def strip_func(match: re.Match[str]) -> str:
    """`re.sub` helper function that keeps only the second capture group."""
    return match.group(2)
71
72
@pytest.fixture(scope="module", autouse=True)
def run_mypy() -> None:
    """Clear the cache and run mypy before running any of the typing tests.

    mypy is invoked once for each of the pass, reveal, fail and misc data
    directories. Its stdout is grouped per file name and stored in
    `OUTPUT_MYPY` for further use. The fixture fails if mypy writes anything
    to stderr or exits with a code other than 0 or 1.

    The cache refresh can be skipped using

        NUMPY_TYPING_TEST_CLEAR_CACHE=0 pytest numpy/typing/tests

    An empty value or ``false`` (case-insensitive) skips the refresh as well.
    """
    # Inspect the *value* of the variable: `bool("0")` is `True`, so a plain
    # truthiness test would never honour `NUMPY_TYPING_TEST_CLEAR_CACHE=0`.
    clear_cache = os.environ.get("NUMPY_TYPING_TEST_CLEAR_CACHE", "1")
    if (
        os.path.isdir(CACHE_DIR)
        and clear_cache.strip().lower() not in ("", "0", "false")
    ):
        shutil.rmtree(CACHE_DIR)

    # Matches a line made of optional whitespace, a `^` and optional `~`s;
    # such a line terminates one message (presumably mypy's caret marker
    # printed underneath the offending source line).
    split_pattern = re.compile(r"(\s+)?\^(\~+)?")
    for directory in (PASS_DIR, REVEAL_DIR, FAIL_DIR, MISC_DIR):
        # Run mypy
        stdout, stderr, exit_code = api.run([
            "--config-file",
            MYPY_INI,
            "--cache-dir",
            CACHE_DIR,
            directory,
        ])
        if stderr:
            pytest.fail(
                f"Unexpected mypy standard error\n\n{stderr}", pytrace=False
            )
        elif exit_code not in {0, 1}:
            pytest.fail(
                f"Unexpected mypy exit code: {exit_code}\n\n{stdout}",
                pytrace=False,
            )

        # Accumulate consecutive output lines until the marker line is seen,
        # then file the accumulated text under the name taken from the first
        # line of that message.
        str_concat = ""
        filename: str | None = None
        for line in stdout.split("\n"):
            if "note:" in line:
                continue
            if filename is None:
                filename = _key_func(line)

            str_concat += f"{line}\n"
            if split_pattern.match(line) is not None:
                OUTPUT_MYPY[filename].append(str_concat)
                str_concat = ""
                filename = None
117
118
def get_test_cases(*directories: str) -> "Iterator[ParameterSet]":
    """Yield one `pytest.param` per ``.py``/``.pyi`` file below `directories`.

    Every directory is walked recursively. The parameter value is the full
    path of the file and its id is the file name without the extension.
    """
    wanted_suffixes = (".pyi", ".py")
    for top in directories:
        for dirpath, _subdirs, filenames in os.walk(top):
            for name in filenames:
                stem, suffix = os.path.splitext(name)
                if suffix in wanted_suffixes:
                    yield pytest.param(os.path.join(dirpath, name), id=stem)
129
130
131_FAIL_INDENT = " " * 4
132_FAIL_SEP = "\n" + "_" * 79 + "\n\n"
133
134_FAIL_MSG_REVEAL = """{}:{} - reveal mismatch:
135
136{}"""
137
138
@pytest.mark.slow
@pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed")
@pytest.mark.parametrize("path", get_test_cases(PASS_DIR, FAIL_DIR))
def test_pass(path: str) -> None:
    """Validate that mypy recorded no messages for the file at `path`.

    Every message stored in `OUTPUT_MYPY` for `path` is reported as
    ``<relative path>:<lineno> - <message>`` and the test fails if there is
    at least one of them.
    """
    # Alias `OUTPUT_MYPY` so that it appears in the local namespace
    output_mypy = OUTPUT_MYPY

    # Membership test first: indexing the defaultdict would insert an entry
    if path not in output_mypy:
        return

    relpath = os.path.relpath(path)

    # collect any reported errors, and clean up the output
    messages = []
    for message in output_mypy[path]:
        lineno, content = _strip_filename(message)
        content = content.removeprefix("error:").lstrip()
        messages.append(f"{relpath}:{lineno} - {content}")

    if messages:
        pytest.fail("\n".join(messages), pytrace=False)
160
161
@pytest.mark.slow
@pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed")
@pytest.mark.parametrize("path", get_test_cases(REVEAL_DIR))
def test_reveal(path: str) -> None:
    """Check that mypy recorded no messages for the reveal file at `path`.

    Each message stored in `OUTPUT_MYPY` for `path` is turned into a
    "reveal mismatch" report; the test fails with all reports joined
    together if there is at least one.
    """
    __tracebackhide__ = True

    output_mypy = OUTPUT_MYPY
    if path not in output_mypy:
        return

    relpath = os.path.relpath(path)

    def _report(raw_message: str) -> str:
        # Format one recorded message as an indented failure report
        line_number, text = _strip_filename(raw_message)
        indented = textwrap.indent(text, _FAIL_INDENT)
        return _FAIL_MSG_REVEAL.format(relpath, line_number, indented)

    failures = [_report(raw_message) for raw_message in output_mypy[path]]
    if not failures:
        return

    pytest.fail(_FAIL_SEP.join(failures), pytrace=False)
188
189
@pytest.mark.slow
@pytest.mark.skipif(NO_MYPY, reason="Mypy is not installed")
@pytest.mark.parametrize("path", get_test_cases(PASS_DIR))
def test_code_runs(path: str) -> None:
    """Validate that the code in `path` properly runs during runtime.

    The file is loaded as a module straight from its location and executed;
    any exception raised while executing it fails the test.
    """
    # Build the module name from the last two path components *without* the
    # file extension, e.g. ``<...>/pass/foo.py`` becomes ``pass.foo``.
    path_without_extension, _ = os.path.splitext(path)
    dirname, filename = path_without_extension.split(os.sep)[-2:]

    spec = importlib.util.spec_from_file_location(
        f"{dirname}.{filename}", path
    )
    assert spec is not None
    assert spec.loader is not None

    test_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(test_module)