Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/numpy/_core/tests/test_multithreading.py: 1%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

260 statements  

1import concurrent.futures 

2import sys 

3import threading 

4 

5import pytest 

6 

7import numpy as np 

8from numpy._core import _rational_tests 

9from numpy._core.tests.test_stringdtype import random_unicode_string_list 

10from numpy.testing import IS_64BIT, IS_WASM 

11from numpy.testing._private.utils import run_threaded 

12 

# Threads are unavailable on WebAssembly builds, so skip the whole module.
if IS_WASM:
    pytest.skip(
        reason="no threading support in wasm",
        allow_module_level=True,
    )

# Applied to every test in this module: each test already creates and
# manages its own threads.
pytestmark = pytest.mark.thread_unsafe(
    reason="tests in this module are already explicitly multi-threaded",
)

19 

20def test_parallel_randomstate(): 

21 # if the coercion cache is enabled and not thread-safe, creating 

22 # RandomState instances simultaneously leads to a data race 

23 def func(seed): 

24 np.random.RandomState(seed) 

25 

26 run_threaded(func, 500, pass_count=True) 

27 

28 # seeding and setting state shouldn't race with generating RNG samples 

29 rng = np.random.RandomState() 

30 

31 def func(seed): 

32 base_rng = np.random.RandomState(seed) 

33 state = base_rng.get_state() 

34 rng.seed(seed) 

35 rng.random() 

36 rng.set_state(state) 

37 

38 run_threaded(func, 8, pass_count=True) 

39 

40def test_parallel_ufunc_execution(): 

41 # if the loop data cache or dispatch cache are not thread-safe 

42 # computing ufuncs simultaneously in multiple threads leads 

43 # to a data race that causes crashes or spurious exceptions 

44 for dtype in [np.float32, np.float64, np.int32]: 

45 for op in [np.random.random((25,)).astype(dtype), dtype(25)]: 

46 for ufunc in [np.isnan, np.sin]: 

47 run_threaded(lambda: ufunc(op), 500) 

48 

49 # see gh-26690 

50 NUM_THREADS = 50 

51 

52 a = np.ones(1000) 

53 

54 def f(b): 

55 b.wait() 

56 return a.sum() 

57 

58 run_threaded(f, NUM_THREADS, pass_barrier=True) 

59 

60 

61def test_temp_elision_thread_safety(): 

62 amid = np.ones(50000) 

63 bmid = np.ones(50000) 

64 alarge = np.ones(1000000) 

65 blarge = np.ones(1000000) 

66 

67 def func(count): 

68 if count % 4 == 0: 

69 (amid * 2) + bmid 

70 elif count % 4 == 1: 

71 (amid + bmid) - 2 

72 elif count % 4 == 2: 

73 (alarge * 2) + blarge 

74 else: 

75 (alarge + blarge) - 2 

76 

77 run_threaded(func, 100, pass_count=True) 

78 

79 

80def test_eigvalsh_thread_safety(): 

81 # if lapack isn't thread safe this will randomly segfault or error 

82 # see gh-24512 

83 rng = np.random.RandomState(873699172) 

84 matrices = ( 

85 rng.random((5, 10, 10, 3, 3)), 

86 rng.random((5, 10, 10, 3, 3)), 

87 ) 

88 

89 run_threaded(lambda i: np.linalg.eigvalsh(matrices[i]), 2, 

90 pass_count=True) 

91 

92 

93def test_printoptions_thread_safety(): 

94 # until NumPy 2.1 the printoptions state was stored in globals 

95 # this verifies that they are now stored in a context variable 

96 b = threading.Barrier(2) 

97 

98 def legacy_113(): 

99 np.set_printoptions(legacy='1.13', precision=12) 

100 b.wait() 

101 po = np.get_printoptions() 

102 assert po['legacy'] == '1.13' 

103 assert po['precision'] == 12 

104 orig_linewidth = po['linewidth'] 

105 with np.printoptions(linewidth=34, legacy='1.21'): 

106 po = np.get_printoptions() 

107 assert po['legacy'] == '1.21' 

108 assert po['precision'] == 12 

109 assert po['linewidth'] == 34 

110 po = np.get_printoptions() 

111 assert po['linewidth'] == orig_linewidth 

112 assert po['legacy'] == '1.13' 

113 assert po['precision'] == 12 

114 

115 def legacy_125(): 

116 np.set_printoptions(legacy='1.25', precision=7) 

117 b.wait() 

118 po = np.get_printoptions() 

119 assert po['legacy'] == '1.25' 

120 assert po['precision'] == 7 

121 orig_linewidth = po['linewidth'] 

122 with np.printoptions(linewidth=6, legacy='1.13'): 

123 po = np.get_printoptions() 

124 assert po['legacy'] == '1.13' 

125 assert po['precision'] == 7 

126 assert po['linewidth'] == 6 

127 po = np.get_printoptions() 

128 assert po['linewidth'] == orig_linewidth 

129 assert po['legacy'] == '1.25' 

130 assert po['precision'] == 7 

131 

132 task1 = threading.Thread(target=legacy_113) 

133 task2 = threading.Thread(target=legacy_125) 

134 

135 task1.start() 

136 task2.start() 

137 task1.join() 

138 task2.join() 

139 

140 

141def test_parallel_reduction(): 

142 # gh-28041 

143 NUM_THREADS = 50 

144 

145 x = np.arange(1000) 

146 

147 def closure(b): 

148 b.wait() 

149 np.sum(x) 

150 

151 run_threaded(closure, NUM_THREADS, pass_barrier=True) 

152 

153 

154def test_parallel_flat_iterator(): 

155 # gh-28042 

156 x = np.arange(20).reshape(5, 4).T 

157 

158 def closure(b): 

159 b.wait() 

160 for _ in range(100): 

161 list(x.flat) 

162 

163 run_threaded(closure, outer_iterations=100, pass_barrier=True) 

164 

165 # gh-28143 

166 def prepare_args(): 

167 return [np.arange(10)] 

168 

169 def closure(x, b): 

170 b.wait() 

171 for _ in range(100): 

172 y = np.arange(10) 

173 y.flat[x] = x 

174 

175 run_threaded(closure, pass_barrier=True, prepare_args=prepare_args) 

176 

177 

178def test_multithreaded_repeat(): 

179 x0 = np.arange(10) 

180 

181 def closure(b): 

182 b.wait() 

183 for _ in range(100): 

184 x = np.repeat(x0, 2, axis=0)[::2] 

185 

186 run_threaded(closure, max_workers=10, pass_barrier=True) 

187 

188 

def test_structured_advanced_indexing():
    # Test that copyswap(n) used by integer array indexing is threadsafe
    # for structured datatypes, see gh-15387. This test can behave randomly.

    # Create a deeply nested dtype to make a failure more likely:
    dt = np.dtype([("", "f8")])
    dt = np.dtype([("", dt)] * 2)
    dt = np.dtype([("", dt)] * 2)
    # The array should be large enough to likely run into threading issues
    arr = np.random.uniform(size=(6000, 8)).view(dt)[:, 0]

    rng = np.random.default_rng()

    def func(arr):
        indx = rng.integers(0, len(arr), size=6000, dtype=np.intp)
        arr[indx]

    # Use the executor as a context manager so its worker threads are shut
    # down and joined when the test ends, instead of outliving it.
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as tpe:
        futures = [tpe.submit(func, arr) for _ in range(10)]
        for f in futures:
            # re-raises any exception from the worker
            f.result()

    assert arr.dtype is dt

212 

213 

def test_structured_threadsafety2():
    # Nonzero (and some other functions) should be threadsafe for
    # structured datatypes, see gh-15387. This test can behave randomly.

    # Create a deeply nested dtype to make a failure more likely:
    dt = np.dtype([("", "f8")])
    dt = np.dtype([("", dt)])
    dt = np.dtype([("", dt)] * 2)
    # The array should be large enough to likely run into threading issues
    arr = np.random.uniform(size=(5000, 4)).view(dt)[:, 0]

    def func(arr):
        arr.nonzero()

    # Use the executor as a context manager so its worker threads are shut
    # down and joined when the test ends, instead of outliving it.
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as tpe:
        futures = [tpe.submit(func, arr) for _ in range(10)]
        for f in futures:
            # re-raises any exception from the worker
            f.result()

    assert arr.dtype is dt

235 

236 

def test_stringdtype_multithreaded_access_and_mutation():
    # this test uses an RNG and may crash or cause deadlocks if there is a
    # threading bug
    rng = np.random.default_rng(0x4D3D3D3)

    string_list = random_unicode_string_list()

    def func(arr):
        rnd = rng.random()
        # Either write to a random location in the array, compute a ufunc,
        # or re-initialize the array. Each of the four actions has
        # probability 1/4. Each ufunc action is split evenly between its
        # out-of-place form and its in-place (``out=arr``) form.
        if rnd < 0.25:
            num = np.random.randint(0, arr.size)
            arr[num] = arr[num] + "hello"
        elif rnd < 0.5:
            if rnd < 0.375:
                np.add(arr, arr)
            else:
                np.add(arr, arr, out=arr)
        elif rnd < 0.75:
            # 0.625 is the midpoint of [0.5, 0.75). The previous threshold
            # (0.875) was always true here, so the in-place multiply branch
            # was unreachable.
            if rnd < 0.625:
                np.multiply(arr, np.int64(2))
            else:
                np.multiply(arr, np.int64(2), out=arr)
        else:
            arr[:] = string_list

    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as tpe:
        arr = np.array(string_list, dtype="T")
        futures = [tpe.submit(func, arr) for _ in range(500)]

        for f in futures:
            f.result()

270 

271 

@pytest.mark.skipif(
    not IS_64BIT,
    reason="Sometimes causes failures or crashes due to OOM on 32 bit runners"
)
def test_legacy_usertype_cast_init_thread_safety():
    # 250 threads, released together, each fill an array of the legacy
    # user-defined ``rational`` dtype
    def fill_rational(barrier):
        barrier.wait()
        np.full((10, 10), 1, _rational_tests.rational)

    run_threaded(fill_rational, max_workers=250, pass_barrier=True)

282 

@pytest.mark.parametrize("dtype", [bool, int, float])
def test_nonzero(dtype):
    # See: gh-28361
    #
    # np.nonzero first uses np.count_nonzero to size its output array, then
    # collects the indices of the non-zero elements in a second pass; the
    # data can change between the two passes.
    #
    # This test triggers a data race which is suppressed in the TSAN CI.
    # The test is to ensure np.nonzero does not generate a segmentation fault
    x = np.random.randint(4, size=100).astype(dtype)
    expected_warning = ('number of non-zero array elements changed'
                        ' during function execution')

    def func(index):
        # thread 0 keeps rewriting every other element; all others read
        is_writer = index == 0
        for _ in range(10):
            if is_writer:
                x[::2] = np.random.randint(2)
                continue
            try:
                np.nonzero(x)
            except RuntimeError as ex:
                assert expected_warning in str(ex)

    run_threaded(func, max_workers=10, pass_count=True, outer_iterations=5)

308 

309 

310# These are all implemented using PySequence_Fast, which needs locking to be safe 

def np_broadcast(arrs):
    # call np.broadcast on the shared argument 50 times; results discarded
    remaining = 50
    while remaining:
        np.broadcast(arrs)
        remaining -= 1

314 

def create_array(arrs):
    # build an array from the shared argument 50 times; results discarded
    remaining = 50
    while remaining:
        np.array(arrs)
        remaining -= 1

318 

def create_nditer(arrs):
    # construct an nditer over the shared argument 50 times; results discarded
    remaining = 50
    while remaining:
        np.nditer(arrs)
        remaining -= 1

322 

323 

@pytest.mark.parametrize(
    "kernel, outcome",
    (
        (np_broadcast, "error"),
        (create_array, "error"),
        (create_nditer, "success"),
    ),
)
def test_arg_locking(kernel, outcome):
    # should complete without triggering races but may error
    #
    # NUM_READERS threads call ``kernel`` on the shared list ``arrs`` while
    # one extra thread mutates that list until every reader has finished.
    NUM_READERS = 4

    done = 0
    done_lock = threading.Lock()
    arrs = [np.array([1, 2, 3]) for _ in range(1000)]

    def read_arrs(b):
        nonlocal done
        b.wait()
        try:
            kernel(arrs)
        finally:
            # ``done += 1`` is a read-modify-write. Guard it so concurrent
            # readers cannot lose an increment and leave the mutator
            # spinning forever.
            with done_lock:
                done += 1

    def contract_and_expand_list(b):
        b.wait()
        while done < NUM_READERS:
            if len(arrs) > 10:
                arrs.pop(0)
            else:
                arrs.extend([np.array([1, 2, 3]) for _ in range(1000)])

    def replace_list_items(b):
        b.wait()
        rng = np.random.RandomState()
        rng.seed(0x4d3d3d3)
        while done < NUM_READERS:
            data = rng.randint(0, 1000, size=4)
            arrs[data[0]] = data[1:]

    for mutation_func in (replace_list_items, contract_and_expand_list):
        # Reset the counter for this round. Otherwise the second mutator
        # sees the first round's final count and exits without mutating.
        done = 0
        b = threading.Barrier(NUM_READERS + 1)
        tasks = []
        try:
            with concurrent.futures.ThreadPoolExecutor(
                    max_workers=NUM_READERS + 1) as tpe:
                try:
                    tasks = [tpe.submit(read_arrs, b)
                             for _ in range(NUM_READERS)]
                    tasks.append(tpe.submit(mutation_func, b))
                finally:
                    # If not every task was submitted, the submitted ones
                    # would wait on the barrier forever and the executor's
                    # shutdown would hang. Break the barrier to release them.
                    if len(tasks) < NUM_READERS + 1:
                        b.abort()
                for t in tasks:
                    t.result()
        except RuntimeError as e:
            if outcome == "success":
                raise
            assert "Inconsistent object during array creation?" in str(e)
            msg = "replace_list_items should not raise errors"
            assert mutation_func is contract_and_expand_list, msg

379 

@pytest.mark.skipif(sys.version_info < (3, 12), reason="Python >= 3.12 required")
def test_array__buffer__thread_safety():
    import inspect
    arr = np.arange(1000)
    # requests alternate between these two flag values, starting with STRIDED
    flag_cycle = (inspect.BufferFlags.STRIDED, inspect.BufferFlags.READ)

    def request_buffers(barrier):
        barrier.wait()
        for flag in flag_cycle * 50:
            arr.__buffer__(flag)

    run_threaded(request_buffers, max_workers=8, pass_barrier=True)

392 

@pytest.mark.skipif(sys.version_info < (3, 12), reason="Python >= 3.12 required")
def test_void_dtype__buffer__thread_safety():
    import inspect
    record_dtype = np.dtype(
        [('name', np.str_, 16), ('grades', np.float64, (2,))]
    )
    # indexing the 0-d structured array with () yields an np.void scalar
    x = np.array(('ndarray_scalar', (1.2, 3.0)), dtype=record_dtype)[()]
    assert isinstance(x, np.void)
    # requests alternate between these two flag values, starting with STRIDES
    flag_cycle = (inspect.BufferFlags.STRIDES, inspect.BufferFlags.READ)

    def request_buffers(barrier):
        barrier.wait()
        for flag in flag_cycle * 50:
            x.__buffer__(flag)

    run_threaded(request_buffers, max_workers=8, pass_barrier=True)