1import concurrent.futures
2import sys
3import threading
4
5import pytest
6
7import numpy as np
8from numpy._core import _rational_tests
9from numpy._core.tests.test_stringdtype import random_unicode_string_list
10from numpy.testing import IS_64BIT, IS_WASM
11from numpy.testing._private.utils import run_threaded
12
# Every test in this module spawns threads, which wasm builds cannot do.
if IS_WASM:
    pytest.skip(allow_module_level=True, reason="no threading support in wasm")

# The tests below create and manage their own threads, so mark the whole
# module thread-unsafe to keep a parallel test runner (presumably
# pytest-run-parallel, which defines this marker — confirm) from running
# them concurrently a second time.
pytestmark = pytest.mark.thread_unsafe(
    reason="tests in this module are already explicitly multi-threaded"
)
19
20def test_parallel_randomstate():
21 # if the coercion cache is enabled and not thread-safe, creating
22 # RandomState instances simultaneously leads to a data race
23 def func(seed):
24 np.random.RandomState(seed)
25
26 run_threaded(func, 500, pass_count=True)
27
28 # seeding and setting state shouldn't race with generating RNG samples
29 rng = np.random.RandomState()
30
31 def func(seed):
32 base_rng = np.random.RandomState(seed)
33 state = base_rng.get_state()
34 rng.seed(seed)
35 rng.random()
36 rng.set_state(state)
37
38 run_threaded(func, 8, pass_count=True)
39
40def test_parallel_ufunc_execution():
41 # if the loop data cache or dispatch cache are not thread-safe
42 # computing ufuncs simultaneously in multiple threads leads
43 # to a data race that causes crashes or spurious exceptions
44 for dtype in [np.float32, np.float64, np.int32]:
45 for op in [np.random.random((25,)).astype(dtype), dtype(25)]:
46 for ufunc in [np.isnan, np.sin]:
47 run_threaded(lambda: ufunc(op), 500)
48
49 # see gh-26690
50 NUM_THREADS = 50
51
52 a = np.ones(1000)
53
54 def f(b):
55 b.wait()
56 return a.sum()
57
58 run_threaded(f, NUM_THREADS, pass_barrier=True)
59
60
61def test_temp_elision_thread_safety():
62 amid = np.ones(50000)
63 bmid = np.ones(50000)
64 alarge = np.ones(1000000)
65 blarge = np.ones(1000000)
66
67 def func(count):
68 if count % 4 == 0:
69 (amid * 2) + bmid
70 elif count % 4 == 1:
71 (amid + bmid) - 2
72 elif count % 4 == 2:
73 (alarge * 2) + blarge
74 else:
75 (alarge + blarge) - 2
76
77 run_threaded(func, 100, pass_count=True)
78
79
80def test_eigvalsh_thread_safety():
81 # if lapack isn't thread safe this will randomly segfault or error
82 # see gh-24512
83 rng = np.random.RandomState(873699172)
84 matrices = (
85 rng.random((5, 10, 10, 3, 3)),
86 rng.random((5, 10, 10, 3, 3)),
87 )
88
89 run_threaded(lambda i: np.linalg.eigvalsh(matrices[i]), 2,
90 pass_count=True)
91
92
93def test_printoptions_thread_safety():
94 # until NumPy 2.1 the printoptions state was stored in globals
95 # this verifies that they are now stored in a context variable
96 b = threading.Barrier(2)
97
98 def legacy_113():
99 np.set_printoptions(legacy='1.13', precision=12)
100 b.wait()
101 po = np.get_printoptions()
102 assert po['legacy'] == '1.13'
103 assert po['precision'] == 12
104 orig_linewidth = po['linewidth']
105 with np.printoptions(linewidth=34, legacy='1.21'):
106 po = np.get_printoptions()
107 assert po['legacy'] == '1.21'
108 assert po['precision'] == 12
109 assert po['linewidth'] == 34
110 po = np.get_printoptions()
111 assert po['linewidth'] == orig_linewidth
112 assert po['legacy'] == '1.13'
113 assert po['precision'] == 12
114
115 def legacy_125():
116 np.set_printoptions(legacy='1.25', precision=7)
117 b.wait()
118 po = np.get_printoptions()
119 assert po['legacy'] == '1.25'
120 assert po['precision'] == 7
121 orig_linewidth = po['linewidth']
122 with np.printoptions(linewidth=6, legacy='1.13'):
123 po = np.get_printoptions()
124 assert po['legacy'] == '1.13'
125 assert po['precision'] == 7
126 assert po['linewidth'] == 6
127 po = np.get_printoptions()
128 assert po['linewidth'] == orig_linewidth
129 assert po['legacy'] == '1.25'
130 assert po['precision'] == 7
131
132 task1 = threading.Thread(target=legacy_113)
133 task2 = threading.Thread(target=legacy_125)
134
135 task1.start()
136 task2.start()
137 task1.join()
138 task2.join()
139
140
141def test_parallel_reduction():
142 # gh-28041
143 NUM_THREADS = 50
144
145 x = np.arange(1000)
146
147 def closure(b):
148 b.wait()
149 np.sum(x)
150
151 run_threaded(closure, NUM_THREADS, pass_barrier=True)
152
153
154def test_parallel_flat_iterator():
155 # gh-28042
156 x = np.arange(20).reshape(5, 4).T
157
158 def closure(b):
159 b.wait()
160 for _ in range(100):
161 list(x.flat)
162
163 run_threaded(closure, outer_iterations=100, pass_barrier=True)
164
165 # gh-28143
166 def prepare_args():
167 return [np.arange(10)]
168
169 def closure(x, b):
170 b.wait()
171 for _ in range(100):
172 y = np.arange(10)
173 y.flat[x] = x
174
175 run_threaded(closure, pass_barrier=True, prepare_args=prepare_args)
176
177
178def test_multithreaded_repeat():
179 x0 = np.arange(10)
180
181 def closure(b):
182 b.wait()
183 for _ in range(100):
184 x = np.repeat(x0, 2, axis=0)[::2]
185
186 run_threaded(closure, max_workers=10, pass_barrier=True)
187
188
def test_structured_advanced_indexing():
    """Integer-array indexing of nested structured dtypes from many threads.

    Tests that copyswap(n), used by integer array indexing, is thread safe
    for structured datatypes (see gh-15387). This test can behave randomly.
    """
    # Create a deeply nested dtype to make a failure more likely:
    dt = np.dtype([("", "f8")])
    dt = np.dtype([("", dt)] * 2)
    dt = np.dtype([("", dt)] * 2)
    # The array should be large enough to likely run into threading issues
    arr = np.random.uniform(size=(6000, 8)).view(dt)[:, 0]

    rng = np.random.default_rng()

    def func(arr):
        indx = rng.integers(0, len(arr), size=6000, dtype=np.intp)
        arr[indx]

    # ``with`` shuts the pool down and joins its worker threads, so they do
    # not outlive the test.
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as tpe:
        futures = [tpe.submit(func, arr) for _ in range(10)]
        for f in futures:
            f.result()

    assert arr.dtype is dt
212
213
def test_structured_threadsafety2():
    """Call ``nonzero`` on a nested structured array from many threads.

    Nonzero (and some other functions) should be thread safe for structured
    datatypes, see gh-15387. This test can behave randomly.
    """
    # Create a deeply nested dtype to make a failure more likely:
    dt = np.dtype([("", "f8")])
    dt = np.dtype([("", dt)])
    dt = np.dtype([("", dt)] * 2)
    # The array should be large enough to likely run into threading issues
    arr = np.random.uniform(size=(5000, 4)).view(dt)[:, 0]

    def func(arr):
        arr.nonzero()

    # Use the module-level ``concurrent.futures`` import. The ``with`` block
    # joins the pool's worker threads before the test returns.
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as tpe:
        futures = [tpe.submit(func, arr) for _ in range(10)]
        for f in futures:
            f.result()

    assert arr.dtype is dt
235
236
237def test_stringdtype_multithreaded_access_and_mutation():
238 # this test uses an RNG and may crash or cause deadlocks if there is a
239 # threading bug
240 rng = np.random.default_rng(0x4D3D3D3)
241
242 string_list = random_unicode_string_list()
243
244 def func(arr):
245 rnd = rng.random()
246 # either write to random locations in the array, compute a ufunc, or
247 # re-initialize the array
248 if rnd < 0.25:
249 num = np.random.randint(0, arr.size)
250 arr[num] = arr[num] + "hello"
251 elif rnd < 0.5:
252 if rnd < 0.375:
253 np.add(arr, arr)
254 else:
255 np.add(arr, arr, out=arr)
256 elif rnd < 0.75:
257 if rnd < 0.875:
258 np.multiply(arr, np.int64(2))
259 else:
260 np.multiply(arr, np.int64(2), out=arr)
261 else:
262 arr[:] = string_list
263
264 with concurrent.futures.ThreadPoolExecutor(max_workers=8) as tpe:
265 arr = np.array(string_list, dtype="T")
266 futures = [tpe.submit(func, arr) for _ in range(500)]
267
268 for f in futures:
269 f.result()
270
271
@pytest.mark.skipif(
    not IS_64BIT,
    reason="Sometimes causes failures or crashes due to OOM on 32 bit runners"
)
def test_legacy_usertype_cast_init_thread_safety():
    # Many threads, released together by a barrier, fill arrays of the legacy
    # user-defined ``rational`` dtype at the same moment. Per the test name,
    # this targets the initialization of that dtype's casts.
    rational = _rational_tests.rational

    def fill(barrier):
        barrier.wait()
        np.full((10, 10), 1, rational)

    run_threaded(fill, max_workers=250, pass_barrier=True)
282
@pytest.mark.parametrize("dtype", [bool, int, float])
def test_nonzero(dtype):
    # Regression test for gh-28361.
    #
    # np.nonzero first uses np.count_nonzero to size its output array, then
    # collects the indices of the non-zero elements in a second pass. Another
    # thread can change the data in between. np.nonzero may then raise the
    # RuntimeError below, but it must never segfault.
    #
    # The data race is intentional and is suppressed in the TSAN CI.
    x = np.random.randint(4, size=100).astype(dtype)
    expected_msg = ('number of non-zero array elements changed'
                    ' during function execution')

    def func(index):
        # Thread 0 keeps rewriting every other element; all other threads
        # keep calling np.nonzero on the same array.
        is_writer = index == 0
        for _ in range(10):
            if is_writer:
                x[::2] = np.random.randint(2)
                continue
            try:
                np.nonzero(x)
            except RuntimeError as ex:
                assert expected_msg in str(ex)

    run_threaded(func, max_workers=10, pass_count=True, outer_iterations=5)
308
309
310# These are all implemented using PySequence_Fast, which needs locking to be safe
def np_broadcast(arrs):
    """Call ``np.broadcast(arrs)`` 50 times, discarding the results."""
    for _ in range(50):
        np.broadcast(arrs)
314
def create_array(arrs):
    """Call ``np.array(arrs)`` 50 times, discarding the results."""
    for _ in range(50):
        np.array(arrs)
318
def create_nditer(arrs):
    """Construct ``np.nditer(arrs)`` 50 times, discarding the iterators."""
    for _ in range(50):
        np.nditer(arrs)
322
323
@pytest.mark.parametrize(
    "kernel, outcome",
    (
        (np_broadcast, "error"),
        (create_array, "error"),
        (create_nditer, "success"),
    ),
)
def test_arg_locking(kernel, outcome):
    """Run ``kernel`` on a list while another thread mutates that list.

    Four reader threads call ``kernel(arrs)`` while one mutator thread keeps
    changing ``arrs`` until all readers have finished. Two rounds run: the
    first replaces items in place, the second shrinks and re-grows the list.
    The run should complete without triggering races. Kernels whose
    ``outcome`` is ``"error"`` may raise the RuntimeError checked below, but
    only while the list changes length.
    """
    done = 0
    done_lock = threading.Lock()
    arrs = [np.array([1, 2, 3]) for _ in range(1000)]

    def read_arrs(b):
        nonlocal done
        b.wait()
        try:
            kernel(arrs)
        finally:
            # ``done += 1`` is a read-modify-write. Without the lock two
            # readers can lose an update, and the mutator would spin forever.
            with done_lock:
                done += 1

    def contract_and_expand_list(b):
        b.wait()
        while done < 4:
            if len(arrs) > 10:
                arrs.pop(0)
            elif len(arrs) <= 10:
                arrs.extend([np.array([1, 2, 3]) for _ in range(1000)])

    def replace_list_items(b):
        b.wait()
        rng = np.random.RandomState()
        rng.seed(0x4d3d3d3)
        while done < 4:
            data = rng.randint(0, 1000, size=4)
            arrs[data[0]] = data[1:]

    for mutation_func in (replace_list_items, contract_and_expand_list):
        # Reset per round. Otherwise the second mutator sees ``done == 4``
        # left over from the first round and exits without mutating anything.
        done = 0
        b = threading.Barrier(5)
        try:
            with concurrent.futures.ThreadPoolExecutor(max_workers=5) as tpe:
                tasks = []
                try:
                    for _ in range(4):
                        tasks.append(tpe.submit(read_arrs, b))
                    tasks.append(tpe.submit(mutation_func, b))
                except BaseException:
                    # Fewer than five parties will ever reach the barrier.
                    # Break it now, before the executor's shutdown waits on
                    # the tasks already blocked in ``b.wait()``.
                    b.abort()
                    raise
                for t in tasks:
                    t.result()
        except RuntimeError as e:
            if outcome == "success":
                raise
            assert "Inconsistent object during array creation?" in str(e)
            msg = "replace_list_items should not raise errors"
            assert mutation_func is contract_and_expand_list, msg
379
@pytest.mark.skipif(sys.version_info < (3, 12), reason="Python >= 3.12 required")
def test_array__buffer__thread_safety():
    # Many threads request buffers from one ndarray via ``__buffer__``
    # (Python 3.12+), alternating between two request flags.
    import inspect
    arr = np.arange(1000)
    alternating = [inspect.BufferFlags.STRIDED, inspect.BufferFlags.READ] * 50

    def request_buffers(barrier):
        barrier.wait()
        for flag in alternating:
            arr.__buffer__(flag)

    run_threaded(request_buffers, max_workers=8, pass_barrier=True)
392
@pytest.mark.skipif(sys.version_info < (3, 12), reason="Python >= 3.12 required")
def test_void_dtype__buffer__thread_safety():
    # Same as the ndarray case, but requests come from a structured
    # ``np.void`` scalar, alternating between two request flags.
    import inspect
    dt = np.dtype([('name', np.str_, 16), ('grades', np.float64, (2,))])
    x = np.array(('ndarray_scalar', (1.2, 3.0)), dtype=dt)[()]
    assert isinstance(x, np.void)
    alternating = [inspect.BufferFlags.STRIDES, inspect.BufferFlags.READ] * 50

    def request_buffers(barrier):
        barrier.wait()
        for flag in alternating:
            x.__buffer__(flag)

    run_threaded(request_buffers, max_workers=8, pass_barrier=True)