Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/numpy/testing/_private/utils.py: 14%

876 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-23 06:43 +0000

1""" 

2Utility function to facilitate testing. 

3 

4""" 

5import os 

6import sys 

7import platform 

8import re 

9import gc 

10import operator 

11import warnings 

12from functools import partial, wraps 

13import shutil 

14import contextlib 

15from tempfile import mkdtemp, mkstemp 

16from unittest.case import SkipTest 

17from warnings import WarningMessage 

18import pprint 

19import sysconfig 

20 

21import numpy as np 

22from numpy.core import ( 

23 intp, float32, empty, arange, array_repr, ndarray, isnat, array) 

24from numpy import isfinite, isnan, isinf 

25import numpy.linalg._umath_linalg 

26 

27from io import StringIO 

28 

# Public API of this module, re-exported by ``numpy.testing``.
# Order is preserved from the historical definition.
__all__ = [
    'assert_equal',
    'assert_almost_equal',
    'assert_approx_equal',
    'assert_array_equal',
    'assert_array_less',
    'assert_string_equal',
    'assert_array_almost_equal',
    'assert_raises',
    'build_err_msg',
    'decorate_methods',
    'jiffies',
    'memusage',
    'print_assert_equal',
    'rundocs',
    'runstring',
    'verbose',
    'measure',
    'assert_',
    'assert_array_almost_equal_nulp',
    'assert_raises_regex',
    'assert_array_max_ulp',
    'assert_warns',
    'assert_no_warnings',
    'assert_allclose',
    'IgnoreException',
    'clear_and_catch_warnings',
    'SkipTest',
    'KnownFailureException',
    'temppath',
    'tempdir',
    'IS_PYPY',
    'HAS_REFCOUNT',
    'IS_WASM',
    'suppress_warnings',
    'assert_array_compare',
    'assert_no_gc_cycles',
    'break_cycles',
    'HAS_LAPACK64',
    'IS_PYSTON',
    '_OLD_PROMOTION',
    'IS_MUSL',
    '_SUPPORTS_SVE',
]

43 

44 

class KnownFailureException(Exception):
    """Exception a test raises to flag itself as a known failure."""


# Historical name, kept so that older imports keep working.
KnownFailureTest = KnownFailureException

51verbose = 0 

52 

53IS_WASM = platform.machine() in ["wasm32", "wasm64"] 

54IS_PYPY = sys.implementation.name == 'pypy' 

55IS_PYSTON = hasattr(sys, "pyston_version_info") 

56HAS_REFCOUNT = getattr(sys, 'getrefcount', None) is not None and not IS_PYSTON 

57HAS_LAPACK64 = numpy.linalg._umath_linalg._ilp64 

58 

59_OLD_PROMOTION = lambda: np._get_promotion_state() == 'legacy' 

60 

61IS_MUSL = False 

62# alternate way is 

63# from packaging.tags import sys_tags 

64# _tags = list(sys_tags()) 

65# if 'musllinux' in _tags[0].platform: 

66_v = sysconfig.get_config_var('HOST_GNU_TYPE') or '' 

67if 'musl' in _v: 

68 IS_MUSL = True 

69 

70 

def assert_(val, msg=''):
    """
    Assert that works in release mode.

    Unlike the ``assert`` statement, this check is an ordinary function call
    and is therefore not removed when Python runs in optimized mode (``-O``).

    Parameters
    ----------
    val : object
        Condition to check; any falsy value fails.
    msg : str or callable, optional
        Failure message. If calling ``msg()`` succeeds, its result is used.
        This lets callers defer building an expensive message until a failure
        actually happens. If the call raises TypeError (for example because
        `msg` is a plain string), `msg` itself is used.

    Raises
    ------
    AssertionError
        If `val` is falsy.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    if val:
        return
    try:
        message = msg()
    except TypeError:
        message = msg
    raise AssertionError(message)

89 

90 

if os.name == 'nt':
    # Code "stolen" from enthought/debug/memusage.py
    def GetPerformanceAttributes(object, counter, instance=None,
                                 inum=-1, format=None, machine=None):
        """
        Read one Windows performance-counter value through ``win32pdh``.

        Builds the counter path from ``(machine, object, instance, None,
        inum, counter)``, collects a single sample and returns the formatted
        value. `format` defaults to ``win32pdh.PDH_FMT_LONG``.
        """
        # NOTE: Many counters require 2 samples to give accurate results,
        # including "% Processor Time" (as by definition, at any instant, a
        # thread's CPU usage is either 0 or 100).  To read counters like this,
        # you should copy this function, but keep the counter open, and call
        # CollectQueryData() each time you need to know.
        # See http://msdn.microsoft.com/library/en-us/dnperfmo/html/perfmonpt2.asp (dead link)
        # My older explanation for this was that the "AddCounter" process
        # forced the CPU to 100%, but the above makes more sense :)
        import win32pdh
        if format is None:
            format = win32pdh.PDH_FMT_LONG
        path = win32pdh.MakeCounterPath((machine, object, instance, None,
                                         inum, counter))
        hq = win32pdh.OpenQuery()
        try:
            hc = win32pdh.AddCounter(hq, path)
            try:
                win32pdh.CollectQueryData(hq)
                # First element is the counter type, which is not needed
                # (previously bound to a local that shadowed builtin `type`).
                _counter_type, val = win32pdh.GetFormattedCounterValue(
                    hc, format)
                return val
            finally:
                win32pdh.RemoveCounter(hc)
        finally:
            win32pdh.CloseQuery(hq)

    def memusage(processName="python", instance=0):
        """Return the "Virtual Bytes" performance counter of `processName`."""
        # from win32pdhutil, part of the win32all package
        import win32pdh
        return GetPerformanceAttributes("Process", "Virtual Bytes",
                                        processName, instance,
                                        win32pdh.PDH_FMT_LONG, None)
elif sys.platform.startswith('linux'):

    def memusage(_proc_pid_stat=f'/proc/{os.getpid()}/stat'):
        """
        Return virtual memory size in bytes of the running python.

        Reads the ``vsize`` field (field 23, counting from 1) of
        ``/proc/<pid>/stat``. This is deliberately best-effort: if the file
        cannot be read or parsed, None is returned.
        """
        try:
            with open(_proc_pid_stat) as f:
                line = f.readline()
            # Field 2 is the executable name in parentheses and may itself
            # contain spaces or ')', so a plain split(' ') can shift every
            # later field.  Split only the text after the *last* ')'.
            fields = line[line.rindex(')') + 1:].split()
            # fields[0] is field 3 (state), so field 23 (vsize) is fields[20].
            return int(fields[20])
        except Exception:
            return None
else:
    def memusage():
        """
        Return memory usage of running python. [Not implemented]

        """
        raise NotImplementedError

146 

147 

if sys.platform.startswith('linux'):
    def jiffies(_proc_pid_stat=f'/proc/{os.getpid()}/stat', _load_time=[]):
        """
        Return number of jiffies elapsed.

        Returns the ``utime`` field (field 14, counting from 1) of
        ``/proc/<pid>/stat``: the time this process has been scheduled in user
        mode, measured in clock ticks (``sysconf(_SC_CLK_TCK)``, typically
        100 per second). See man 5 proc.

        If the file cannot be read or parsed, it falls back to hundredths of a
        second of wall-clock time since the first call.
        """
        import time
        # `_load_time` is intentionally a shared mutable default: it remembers
        # the time of the first call for the wall-clock fallback.
        if not _load_time:
            _load_time.append(time.time())
        try:
            with open(_proc_pid_stat) as f:
                line = f.readline()
            # Field 2 (the parenthesised executable name) may contain spaces
            # or ')'; only split what follows the last ')'.
            fields = line[line.rindex(')') + 1:].split()
            # fields[0] is field 3 (state), so field 14 (utime) is fields[11].
            return int(fields[11])
        except Exception:
            return int(100*(time.time()-_load_time[0]))
else:
    # os.getpid is not in all platforms available.
    # Using time is safe but inaccurate, especially when process
    # was suspended or sleeping.
    def jiffies(_load_time=[]):
        """
        Return number of jiffies elapsed.

        No /proc is used on this platform. The result is hundredths of a second
        of wall-clock time since the first call, not CPU time.
        """
        import time
        # Shared mutable default used as a "first call" timestamp store.
        if not _load_time:
            _load_time.append(time.time())
        return int(100*(time.time()-_load_time[0]))

182 

183 

def build_err_msg(arrays, err_msg, header='Items are not equal:',
                  verbose=True, names=('ACTUAL', 'DESIRED'), precision=8):
    """
    Assemble the multi-line message used by the assertion helpers.

    Parameters
    ----------
    arrays : sequence
        Objects to show in the message, labelled positionally by `names`.
    err_msg : str
        Extra text. A short single-line text is appended to the header line;
        anything else goes on its own line(s).
    header : str, optional
        First line of the message (after a leading newline).
    verbose : bool, optional
        If False, the objects in `arrays` are not shown.
    names : sequence of str, optional
        Labels for the entries of `arrays`.
    precision : int, optional
        Print precision used for ndarray entries.

    Returns
    -------
    str
        The message. Each shown repr is cut to 3 lines plus '...' when it
        has more than 3 newlines. A repr that raises is replaced by a
        placeholder naming the exception.
    """
    lines = ['\n' + header]
    if err_msg:
        fits_on_header = ('\n' not in err_msg
                          and len(err_msg) < 79 - len(header))
        if fits_on_header:
            lines[0] += ' ' + err_msg
        else:
            lines.append(err_msg)
    if not verbose:
        return '\n'.join(lines)

    for idx, obj in enumerate(arrays):
        # Only ndarray reprs honour the precision argument.
        if isinstance(obj, ndarray):
            to_text = partial(array_repr, precision=precision)
        else:
            to_text = repr
        try:
            text = to_text(obj)
        except Exception as exc:
            text = f'[repr failed for <{type(obj).__name__}>: {exc}]'
        if text.count('\n') > 3:
            text = '\n'.join(text.splitlines()[:3]) + '...'
        lines.append(f' {names[idx]}: {text}')
    return '\n'.join(lines)

210 

211 

def assert_equal(actual, desired, err_msg='', verbose=True):
    """
    Raises an AssertionError if two objects are not equal.

    Given two objects (scalars, lists, tuples, dictionaries or numpy arrays),
    check that all elements of these objects are equal. An exception is raised
    at the first conflicting values.

    Comparison rules, in the order they are applied:

    - If `desired` is a dict, `actual` must be a dict of the same length
      that contains every key of `desired`. Values are then compared
      recursively, and the key is prepended to the error message.
    - If both are lists or tuples (in any combination), they must have the
      same length and are compared item by item.
    - If either is an ndarray, the comparison is delegated to
      `assert_array_equal`.
    - Otherwise the objects are treated as scalars.

    When one of `actual` and `desired` is a scalar and the other is an
    ndarray, the function checks that each element of the array is equal to
    the scalar. Note that a list or tuple compared with a scalar is reported
    as unequal; it is not broadcast.

    This function handles NaN comparisons as if NaN was a "normal" number.
    That is, AssertionError is not raised if both objects have NaNs in the same
    positions. This is in contrast to the IEEE standard on NaNs, which says
    that NaN compared to anything must return False. Likewise, two NaT
    values of the same datetime64/timedelta64 type compare equal. For scalar
    floats the sign of zero is checked, so ``0.0`` and ``-0.0`` differ.

    Parameters
    ----------
    actual : array_like
        The object to check.
    desired : array_like
        The expected object.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If actual and desired are not equal.

    Examples
    --------
    >>> np.testing.assert_equal([4,5], [4,6])
    Traceback (most recent call last):
        ...
    AssertionError:
    Items are not equal:
    item=1
     ACTUAL: 5
     DESIRED: 6

    The following comparison does not raise an exception.  There are NaNs
    in the inputs, but they are in the same positions.

    >>> np.testing.assert_equal(np.array([1.0, 2.0, np.nan]), [1, 2, np.nan])

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # Dicts: type, length and key-presence checks, then recursive comparison
    # of the values.  (The loop's value `i` is unused; desired[k] is re-read.)
    if isinstance(desired, dict):
        if not isinstance(actual, dict):
            raise AssertionError(repr(type(actual)))
        assert_equal(len(actual), len(desired), err_msg, verbose)
        for k, i in desired.items():
            if k not in actual:
                raise AssertionError(repr(k))
            assert_equal(actual[k], desired[k], f'key={k!r}\n{err_msg}',
                         verbose)
        return
    # Lists/tuples on both sides: same length, then item-by-item recursion.
    if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)):
        assert_equal(len(actual), len(desired), err_msg, verbose)
        for k in range(len(desired)):
            assert_equal(actual[k], desired[k], f'item={k!r}\n{err_msg}',
                         verbose)
        return
    from numpy.core import ndarray, isscalar, signbit
    from numpy.lib import iscomplexobj, real, imag
    if isinstance(actual, ndarray) or isinstance(desired, ndarray):
        return assert_array_equal(actual, desired, err_msg, verbose)
    # From here on both objects are treated as scalars; the failure message
    # is built once up front and reused by every check below.
    msg = build_err_msg([actual, desired], err_msg, verbose=verbose)

    # Handle complex numbers: separate into real/imag to handle
    # nan/inf/negative zero correctly
    # XXX: catch ValueError for subclasses of ndarray where iscomplex fail
    try:
        usecomplex = iscomplexobj(actual) or iscomplexobj(desired)
    except (ValueError, TypeError):
        usecomplex = False

    if usecomplex:
        if iscomplexobj(actual):
            actualr = real(actual)
            actuali = imag(actual)
        else:
            actualr = actual
            actuali = 0
        if iscomplexobj(desired):
            desiredr = real(desired)
            desiredi = imag(desired)
        else:
            desiredr = desired
            desiredi = 0
        try:
            assert_equal(actualr, desiredr)
            assert_equal(actuali, desiredi)
        except AssertionError:
            raise AssertionError(msg)
        # No early return: matching parts fall through to the generic
        # scalar checks below.

    # isscalar test to check cases such as [np.nan] != np.nan
    if isscalar(desired) != isscalar(actual):
        raise AssertionError(msg)

    # NaT handling.  For inputs where isnat (or the dtype lookup) raises,
    # the error is ignored and the remaining checks decide.
    try:
        isdesnat = isnat(desired)
        isactnat = isnat(actual)
        dtypes_match = (np.asarray(desired).dtype.type ==
                        np.asarray(actual).dtype.type)
        if isdesnat and isactnat:
            # If both are NaT (and have the same dtype -- datetime or
            # timedelta) they are considered equal.
            if dtypes_match:
                return
            else:
                raise AssertionError(msg)

    except (TypeError, ValueError, NotImplementedError):
        pass

    # Inf/nan/negative zero handling
    try:
        isdesnan = isnan(desired)
        isactnan = isnan(actual)
        if isdesnan and isactnan:
            return  # both nan, so equal

        # handle signed zero specially for floats
        array_actual = np.asarray(actual)
        array_desired = np.asarray(desired)
        if (array_actual.dtype.char in 'Mm' or
                array_desired.dtype.char in 'Mm'):
            # version 1.18
            # until this version, isnan failed for datetime64 and timedelta64.
            # Now it succeeds but comparison to scalar with a different type
            # emits a DeprecationWarning.
            # Avoid that by skipping the next check
            raise NotImplementedError('cannot compare to a scalar '
                                      'with a different type')

        if desired == 0 and actual == 0:
            if not signbit(desired) == signbit(actual):
                raise AssertionError(msg)

    except (TypeError, ValueError, NotImplementedError):
        pass

    try:
        # Explicitly use __eq__ for comparison, gh-2552
        if not (desired == actual):
            raise AssertionError(msg)

    except (DeprecationWarning, FutureWarning) as e:
        # this handles the case when the two types are not even comparable
        if 'elementwise == comparison' in e.args[0]:
            raise AssertionError(msg)
        else:
            raise

370 

371 

def print_assert_equal(test_string, actual, desired):
    """
    Compare two objects with ``==`` and fail with pretty-printed values.

    Parameters
    ----------
    test_string : str
        Label placed at the start of the AssertionError message.
    actual : object
        The object to test for equality against `desired`.
    desired : object
        The expected result.

    Raises
    ------
    AssertionError
        If ``actual == desired`` is falsy. The message is `test_string`
        followed by " failed", then both objects formatted with `pprint`.

    Examples
    --------
    >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 1])
    >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 2])
    Traceback (most recent call last):
    ...
    AssertionError: Test XYZ of func xyz failed
    ACTUAL:
    [0, 1]
    DESIRED:
    [0, 2]

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import pprint

    if actual == desired:
        return
    # pprint.pformat(obj) + '\n' is exactly what pprint.pprint(obj, stream)
    # writes to the stream.
    report = [
        test_string,
        ' failed\nACTUAL: \n',
        pprint.pformat(actual) + '\n',
        'DESIRED: \n',
        pprint.pformat(desired) + '\n',
    ]
    raise AssertionError(''.join(report))

411 

412 

413@np._no_nep50_warning() 

414def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True): 

415 """ 

416 Raises an AssertionError if two items are not equal up to desired 

417 precision. 

418 

419 .. note:: It is recommended to use one of `assert_allclose`, 

420 `assert_array_almost_equal_nulp` or `assert_array_max_ulp` 

421 instead of this function for more consistent floating point 

422 comparisons. 

423 

424 The test verifies that the elements of `actual` and `desired` satisfy. 

425 

426 ``abs(desired-actual) < float64(1.5 * 10**(-decimal))`` 

427 

428 That is a looser test than originally documented, but agrees with what the 

429 actual implementation in `assert_array_almost_equal` did up to rounding 

430 vagaries. An exception is raised at conflicting values. For ndarrays this 

431 delegates to assert_array_almost_equal 

432 

433 Parameters 

434 ---------- 

435 actual : array_like 

436 The object to check. 

437 desired : array_like 

438 The expected object. 

439 decimal : int, optional 

440 Desired precision, default is 7. 

441 err_msg : str, optional 

442 The error message to be printed in case of failure. 

443 verbose : bool, optional 

444 If True, the conflicting values are appended to the error message. 

445 

446 Raises 

447 ------ 

448 AssertionError 

449 If actual and desired are not equal up to specified precision. 

450 

451 See Also 

452 -------- 

453 assert_allclose: Compare two array_like objects for equality with desired 

454 relative and/or absolute precision. 

455 assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal 

456 

457 Examples 

458 -------- 

459 >>> from numpy.testing import assert_almost_equal 

460 >>> assert_almost_equal(2.3333333333333, 2.33333334) 

461 >>> assert_almost_equal(2.3333333333333, 2.33333334, decimal=10) 

462 Traceback (most recent call last): 

463 ... 

464 AssertionError: 

465 Arrays are not almost equal to 10 decimals 

466 ACTUAL: 2.3333333333333 

467 DESIRED: 2.33333334 

468 

469 >>> assert_almost_equal(np.array([1.0,2.3333333333333]), 

470 ... np.array([1.0,2.33333334]), decimal=9) 

471 Traceback (most recent call last): 

472 ... 

473 AssertionError: 

474 Arrays are not almost equal to 9 decimals 

475 <BLANKLINE> 

476 Mismatched elements: 1 / 2 (50%) 

477 Max absolute difference: 6.66669964e-09 

478 Max relative difference: 2.85715698e-09 

479 x: array([1. , 2.333333333]) 

480 y: array([1. , 2.33333334]) 

481 

482 """ 

483 __tracebackhide__ = True # Hide traceback for py.test 

484 from numpy.core import ndarray 

485 from numpy.lib import iscomplexobj, real, imag 

486 

487 # Handle complex numbers: separate into real/imag to handle 

488 # nan/inf/negative zero correctly 

489 # XXX: catch ValueError for subclasses of ndarray where iscomplex fail 

490 try: 

491 usecomplex = iscomplexobj(actual) or iscomplexobj(desired) 

492 except ValueError: 

493 usecomplex = False 

494 

495 def _build_err_msg(): 

496 header = ('Arrays are not almost equal to %d decimals' % decimal) 

497 return build_err_msg([actual, desired], err_msg, verbose=verbose, 

498 header=header) 

499 

500 if usecomplex: 

501 if iscomplexobj(actual): 

502 actualr = real(actual) 

503 actuali = imag(actual) 

504 else: 

505 actualr = actual 

506 actuali = 0 

507 if iscomplexobj(desired): 

508 desiredr = real(desired) 

509 desiredi = imag(desired) 

510 else: 

511 desiredr = desired 

512 desiredi = 0 

513 try: 

514 assert_almost_equal(actualr, desiredr, decimal=decimal) 

515 assert_almost_equal(actuali, desiredi, decimal=decimal) 

516 except AssertionError: 

517 raise AssertionError(_build_err_msg()) 

518 

519 if isinstance(actual, (ndarray, tuple, list)) \ 

520 or isinstance(desired, (ndarray, tuple, list)): 

521 return assert_array_almost_equal(actual, desired, decimal, err_msg) 

522 try: 

523 # If one of desired/actual is not finite, handle it specially here: 

524 # check that both are nan if any is a nan, and test for equality 

525 # otherwise 

526 if not (isfinite(desired) and isfinite(actual)): 

527 if isnan(desired) or isnan(actual): 

528 if not (isnan(desired) and isnan(actual)): 

529 raise AssertionError(_build_err_msg()) 

530 else: 

531 if not desired == actual: 

532 raise AssertionError(_build_err_msg()) 

533 return 

534 except (NotImplementedError, TypeError): 

535 pass 

536 if abs(desired - actual) >= np.float64(1.5 * 10.0**(-decimal)): 

537 raise AssertionError(_build_err_msg()) 

538 

539 

540@np._no_nep50_warning() 

541def assert_approx_equal(actual, desired, significant=7, err_msg='', 

542 verbose=True): 

543 """ 

544 Raises an AssertionError if two items are not equal up to significant 

545 digits. 

546 

547 .. note:: It is recommended to use one of `assert_allclose`, 

548 `assert_array_almost_equal_nulp` or `assert_array_max_ulp` 

549 instead of this function for more consistent floating point 

550 comparisons. 

551 

552 Given two numbers, check that they are approximately equal. 

553 Approximately equal is defined as the number of significant digits 

554 that agree. 

555 

556 Parameters 

557 ---------- 

558 actual : scalar 

559 The object to check. 

560 desired : scalar 

561 The expected object. 

562 significant : int, optional 

563 Desired precision, default is 7. 

564 err_msg : str, optional 

565 The error message to be printed in case of failure. 

566 verbose : bool, optional 

567 If True, the conflicting values are appended to the error message. 

568 

569 Raises 

570 ------ 

571 AssertionError 

572 If actual and desired are not equal up to specified precision. 

573 

574 See Also 

575 -------- 

576 assert_allclose: Compare two array_like objects for equality with desired 

577 relative and/or absolute precision. 

578 assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal 

579 

580 Examples 

581 -------- 

582 >>> np.testing.assert_approx_equal(0.12345677777777e-20, 0.1234567e-20) 

583 >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345671e-20, 

584 ... significant=8) 

585 >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345672e-20, 

586 ... significant=8) 

587 Traceback (most recent call last): 

588 ... 

589 AssertionError: 

590 Items are not equal to 8 significant digits: 

591 ACTUAL: 1.234567e-21 

592 DESIRED: 1.2345672e-21 

593 

594 the evaluated condition that raises the exception is 

595 

596 >>> abs(0.12345670e-20/1e-21 - 0.12345672e-20/1e-21) >= 10**-(8-1) 

597 True 

598 

599 """ 

600 __tracebackhide__ = True # Hide traceback for py.test 

601 import numpy as np 

602 

603 (actual, desired) = map(float, (actual, desired)) 

604 if desired == actual: 

605 return 

606 # Normalized the numbers to be in range (-10.0,10.0) 

607 # scale = float(pow(10,math.floor(math.log10(0.5*(abs(desired)+abs(actual)))))) 

608 with np.errstate(invalid='ignore'): 

609 scale = 0.5*(np.abs(desired) + np.abs(actual)) 

610 scale = np.power(10, np.floor(np.log10(scale))) 

611 try: 

612 sc_desired = desired/scale 

613 except ZeroDivisionError: 

614 sc_desired = 0.0 

615 try: 

616 sc_actual = actual/scale 

617 except ZeroDivisionError: 

618 sc_actual = 0.0 

619 msg = build_err_msg( 

620 [actual, desired], err_msg, 

621 header='Items are not equal to %d significant digits:' % significant, 

622 verbose=verbose) 

623 try: 

624 # If one of desired/actual is not finite, handle it specially here: 

625 # check that both are nan if any is a nan, and test for equality 

626 # otherwise 

627 if not (isfinite(desired) and isfinite(actual)): 

628 if isnan(desired) or isnan(actual): 

629 if not (isnan(desired) and isnan(actual)): 

630 raise AssertionError(msg) 

631 else: 

632 if not desired == actual: 

633 raise AssertionError(msg) 

634 return 

635 except (TypeError, NotImplementedError): 

636 pass 

637 if np.abs(sc_desired - sc_actual) >= np.power(10., -(significant-1)): 

638 raise AssertionError(msg) 

639 

640 

@np._no_nep50_warning()
def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
                         precision=6, equal_nan=True, equal_inf=True,
                         *, strict=False):
    """
    Shared driver for the element-wise array assertions.

    Converts `x` and `y` with ``np.asanyarray`` and checks their shapes (and
    dtypes, if `strict`). It optionally checks that NaN/inf/NaT occur at
    the same positions and removes those positions. It then evaluates
    ``comparison(x, y)`` and raises AssertionError with a mismatch report
    unless the result is True everywhere.

    Parameters
    ----------
    comparison : callable
        Called as ``comparison(x, y)`` on the (possibly filtered) arrays. It
        must return a Python bool or an array-like supporting ``ravel()``.
    x, y : array_like
        Objects to compare. Unless `strict` is set, either may be 0-d, in
        which case its shape is not required to match the other's.
    err_msg : str, optional
        Extra text included in the failure message.
    verbose : bool, optional
        If True, the reprs of the inputs are included in the message.
    header : str, optional
        First line of the failure message.
    precision : int, optional
        Print precision used for ndarray reprs in the message.
    equal_nan : bool, optional
        For numeric inputs, NaNs must occur at the same positions in `x`
        and `y`. For datetime/timedelta inputs of the same type, NaTs must.
        Those positions are then excluded from `comparison`.
    equal_inf : bool, optional
        For numeric inputs, +inf and -inf must occur at the same positions;
        those positions are excluded as well.
    strict : bool, optional, keyword-only
        If True, shapes and dtypes must match exactly.

    Raises
    ------
    AssertionError
        On a shape/dtype mismatch or a NaN/inf/NaT position mismatch, or
        when `comparison` is not True everywhere. In the last case the
        message reports the mismatch count and, where computable, the
        maximum absolute and relative differences.
    ValueError
        If a ValueError occurs inside the check. It is re-raised with the
        original traceback text embedded in the message.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    from numpy.core import (array2string, isnan, inf, bool_, errstate,
                            all, max, object_)

    x = np.asanyarray(x)
    y = np.asanyarray(y)

    # original array for output formatting
    ox, oy = x, y

    def isnumber(x):
        # bool, integer, floating and complex type codes
        return x.dtype.char in '?bhilqpBHILQPefdgFDG'

    def istime(x):
        # datetime64 ('M') or timedelta64 ('m')
        return x.dtype.char in "Mm"

    def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
        """Handling nan/inf.

        Combine results of running func on x and y, checking that they are True
        at the same locations.

        """
        __tracebackhide__ = True  # Hide traceback for py.test

        x_id = func(x)
        y_id = func(y)
        # We include work-arounds here to handle three types of slightly
        # pathological ndarray subclasses:
        # (1) all() on `masked` array scalars can return masked arrays, so we
        #     use != True
        # (2) __eq__ on some ndarray subclasses returns Python booleans
        #     instead of element-wise comparisons, so we cast to bool_() and
        #     use isinstance(..., bool) checks
        # (3) subclasses with bare-bones __array_function__ implementations may
        #     not implement np.all(), so favor using the .all() method
        # We are not committed to supporting such subclasses, but it's nice to
        # support them if possible.
        if bool_(x_id == y_id).all() != True:
            msg = build_err_msg([x, y],
                                err_msg + '\nx and y %s location mismatch:'
                                % (hasval), verbose=verbose, header=header,
                                names=('x', 'y'), precision=precision)
            raise AssertionError(msg)
        # If there is a scalar, then here we know the array has the same
        # flag as it everywhere, so we should return the scalar flag.
        if isinstance(x_id, bool) or x_id.ndim == 0:
            return bool_(x_id)
        elif isinstance(y_id, bool) or y_id.ndim == 0:
            return bool_(y_id)
        else:
            return y_id

    try:
        # Shape (and, if strict, dtype) check.  Non-strict mode lets a 0-d
        # operand stand in for any shape.
        if strict:
            cond = x.shape == y.shape and x.dtype == y.dtype
        else:
            cond = (x.shape == () or y.shape == ()) or x.shape == y.shape
        if not cond:
            if x.shape != y.shape:
                reason = f'\n(shapes {x.shape}, {y.shape} mismatch)'
            else:
                reason = f'\n(dtypes {x.dtype}, {y.dtype} mismatch)'
            msg = build_err_msg([x, y],
                                err_msg
                                + reason,
                                verbose=verbose, header=header,
                                names=('x', 'y'), precision=precision)
            raise AssertionError(msg)

        # `flagged` marks positions holding special values (NaN, +/-inf,
        # NaT) that matched in both inputs and so are excluded from
        # `comparison`.
        flagged = bool_(False)
        if isnumber(x) and isnumber(y):
            if equal_nan:
                flagged = func_assert_same_pos(x, y, func=isnan, hasval='nan')

            if equal_inf:
                flagged |= func_assert_same_pos(x, y,
                                                func=lambda xy: xy == +inf,
                                                hasval='+inf')
                flagged |= func_assert_same_pos(x, y,
                                                func=lambda xy: xy == -inf,
                                                hasval='-inf')

        elif istime(x) and istime(y):
            # If one is datetime64 and the other timedelta64 there is no point
            if equal_nan and x.dtype.type == y.dtype.type:
                flagged = func_assert_same_pos(x, y, func=isnat, hasval="NaT")

        if flagged.ndim > 0:
            x, y = x[~flagged], y[~flagged]
            # Only do the comparison if actual values are left
            if x.size == 0:
                return
        elif flagged:
            # no sense doing comparison if everything is flagged.
            return

        val = comparison(x, y)

        if isinstance(val, bool):
            cond = val
            reduced = array([val])
        else:
            reduced = val.ravel()
            cond = reduced.all()

        # The below comparison is a hack to ensure that fully masked
        # results, for which val.ravel().all() returns np.ma.masked,
        # do not trigger a failure (np.ma.masked != True evaluates as
        # np.ma.masked, which is falsy).
        if cond != True:
            # Percentages are relative to the unfiltered element count.
            n_mismatch = reduced.size - reduced.sum(dtype=intp)
            n_elements = flagged.size if flagged.ndim != 0 else reduced.size
            percent_mismatch = 100 * n_mismatch / n_elements
            remarks = [
                'Mismatched elements: {} / {} ({:.3g}%)'.format(
                    n_mismatch, n_elements, percent_mismatch)]

            with errstate(all='ignore'):
                # ignore errors for non-numeric types
                with contextlib.suppress(TypeError):
                    error = abs(x - y)
                    if np.issubdtype(x.dtype, np.unsignedinteger):
                        # x - y wraps around for unsigned ints; take the
                        # smaller of the two wrapped differences.
                        error2 = abs(y - x)
                        np.minimum(error, error2, out=error)
                    max_abs_error = max(error)
                    if getattr(error, 'dtype', object_) == object_:
                        remarks.append('Max absolute difference: '
                                       + str(max_abs_error))
                    else:
                        remarks.append('Max absolute difference: '
                                       + array2string(max_abs_error))

                    # note: this definition of relative error matches that one
                    # used by assert_allclose (found in np.isclose)
                    # Filter values where the divisor would be zero
                    nonzero = bool_(y != 0)
                    if all(~nonzero):
                        max_rel_error = array(inf)
                    else:
                        max_rel_error = max(error[nonzero] / abs(y[nonzero]))
                    if getattr(error, 'dtype', object_) == object_:
                        remarks.append('Max relative difference: '
                                       + str(max_rel_error))
                    else:
                        remarks.append('Max relative difference: '
                                       + array2string(max_rel_error))

            err_msg += '\n' + '\n'.join(remarks)
            # Report the original (unfiltered) inputs.
            msg = build_err_msg([ox, oy], err_msg,
                                verbose=verbose, header=header,
                                names=('x', 'y'), precision=precision)
            raise AssertionError(msg)
    except ValueError:
        import traceback
        efmt = traceback.format_exc()
        header = f'error during assertion:\n\n{efmt}\n\n{header}'

        msg = build_err_msg([x, y], err_msg, verbose=verbose, header=header,
                            names=('x', 'y'), precision=precision)
        raise ValueError(msg)

806 

807 

def assert_array_equal(x, y, err_msg='', verbose=True, *, strict=False):
    """
    Raises an AssertionError if two array_like objects are not equal.

    Checks that `x` and `y` have the same shape and that every pair of
    corresponding elements compares equal with ``==``. See the Notes for
    how a scalar operand is treated. Unlike ordinary NumPy comparisons,
    NaNs are treated like regular values: NaNs at the same positions in both
    inputs do not cause a failure.

    As always when testing floating point values for exact equality, be
    careful: tiny rounding differences will make this check fail.

    Parameters
    ----------
    x : array_like
        The actual object to check.
    y : array_like
        The desired, expected object.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.
    strict : bool, optional
        If True, raise an AssertionError when either the shape or the data
        type of the array_like objects does not match. The special
        handling for scalars mentioned in the Notes section is disabled.

        .. versionadded:: 1.24.0

    Raises
    ------
    AssertionError
        If actual and desired objects are not equal.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Notes
    -----
    If one of `x` and `y` is a scalar and the other is array_like, every
    element of the array_like object is compared with that scalar. Pass
    ``strict=True`` to disable this behaviour.

    Examples
    --------
    The first assert does not raise an exception:

    >>> np.testing.assert_array_equal([1.0,2.33333,np.nan],
    ...                               [np.exp(0),2.33333, np.nan])

    Assert fails with numerical imprecision with floats:

    >>> np.testing.assert_array_equal([1.0,np.pi,np.nan],
    ...                               [1, np.sqrt(np.pi)**2, np.nan])
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not equal
    <BLANKLINE>
    Mismatched elements: 1 / 3 (33.3%)
    Max absolute difference: 4.4408921e-16
    Max relative difference: 1.41357986e-16
     x: array([1.      , 3.141593,      nan])
     y: array([1.      , 3.141593,      nan])

    Use `assert_allclose` or one of the nulp (number of floating point values)
    functions for these cases instead:

    >>> np.testing.assert_allclose([1.0,np.pi,np.nan],
    ...                            [1, np.sqrt(np.pi)**2, np.nan],
    ...                            rtol=1e-10, atol=0)

    As mentioned in the Notes section, `assert_array_equal` has special
    handling for scalars. Here the test checks that each value in `x` is 3:

    >>> x = np.full((2, 5), fill_value=3)
    >>> np.testing.assert_array_equal(x, 3)

    Use `strict` to raise an AssertionError when comparing a scalar with an
    array:

    >>> np.testing.assert_array_equal(x, 3, strict=True)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not equal
    <BLANKLINE>
    (shapes (2, 5), () mismatch)
     x: array([[3, 3, 3, 3, 3],
           [3, 3, 3, 3, 3]])
     y: array(3)

    The `strict` parameter also ensures that the array data types match:

    >>> x = np.array([2, 2, 2])
    >>> y = np.array([2., 2., 2.], dtype=np.float32)
    >>> np.testing.assert_array_equal(x, y, strict=True)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not equal
    <BLANKLINE>
    (dtypes int64, float32 mismatch)
     x: array([2, 2, 2])
     y: array([2., 2., 2.], dtype=float32)
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # operator.eq is the same function object as operator.__eq__.
    options = dict(err_msg=err_msg, verbose=verbose,
                   header='Arrays are not equal', strict=strict)
    assert_array_compare(operator.eq, x, y, **options)

923 

924 

@np._no_nep50_warning()
def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
    """
    Raise an AssertionError unless two objects agree to a given precision.

    .. note:: For more consistent floating point comparisons, prefer one of
              `assert_allclose`, `assert_array_almost_equal_nulp` or
              `assert_array_max_ulp` over this function.

    The shapes of `x` and `y` must be identical, and every pair of
    corresponding elements must satisfy

        ``abs(y - x) < 1.5 * 10**(-decimal)``

    This bound is looser than the one originally documented, but it agrees
    (up to rounding vagaries) with what the implementation has always done.
    A shape mismatch or any out-of-tolerance pair raises. Unlike the usual
    numpy semantics, NaNs are compared like numbers: no assertion is raised
    when both objects contain NaNs at the same positions.

    Parameters
    ----------
    x : array_like
        The actual object to check.
    y : array_like
        The desired, expected object.
    decimal : int, optional
        Desired precision, default is 6.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If `x` and `y` are not equal up to the requested precision.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Examples
    --------
    The first assert does not raise an exception

    >>> np.testing.assert_array_almost_equal([1.0,2.333,np.nan],
    ...                                      [1.0,2.333,np.nan])

    >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan],
    ...                                      [1.0,2.33339,np.nan], decimal=5)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 5 decimals
    <BLANKLINE>
    Mismatched elements: 1 / 3 (33.3%)
    Max absolute difference: 6.e-05
    Max relative difference: 2.57136612e-05
     x: array([1.     , 2.33333,     nan])
     y: array([1.     , 2.33339,     nan])

    >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan],
    ...                                      [1.0,2.33333, 5], decimal=5)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 5 decimals
    <BLANKLINE>
    x and y nan location mismatch:
     x: array([1.     , 2.33333,     nan])
     y: array([1.     , 2.33333, 5.     ])

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    from numpy.core import number, float_, result_type
    from numpy.core.numerictypes import issubdtype
    from numpy.core.fromnumeric import any as npany

    def _close_enough(x, y):
        try:
            if npany(isinf(x)) or npany(isinf(y)):
                x_inf = isinf(x)
                y_inf = isinf(y)
                # Infinities must sit at the same positions in both inputs.
                if not (x_inf == y_inf).all():
                    return False
                # A single element on each side: x and y are both +-inf.
                if x.size == y.size == 1:
                    return x == y
                # Compare only the finite remainder numerically.
                x = x[~x_inf]
                y = y[~y_inf]
        except (TypeError, NotImplementedError):
            pass

        # Promote y to an inexact dtype so that abs() cannot hit abs(MIN_INT);
        # the subtraction below then casts x as well.
        y = np.asanyarray(y, result_type(y, 1.))
        distance = abs(x - y)

        if not issubdtype(distance.dtype, number):
            distance = distance.astype(float_)  # object arrays

        return distance < 1.5 * 10.0**(-decimal)

    header = 'Arrays are not almost equal to %d decimals' % decimal
    assert_array_compare(_close_enough, x, y, err_msg=err_msg,
                         verbose=verbose, header=header, precision=decimal)

1037 

1038 

def assert_array_less(x, y, err_msg='', verbose=True):
    """
    Raise an AssertionError unless `x` is element-wise strictly less than `y`.

    The two array_like objects must have the same shape, and each element of
    `x` must be strictly smaller than the corresponding element of `y`. A
    shape mismatch or a wrongly ordered pair raises, except that an object
    with zero dimensions is compared against every element of the other.
    Unlike the usual numpy semantics, NaNs are compared too: matching NaN
    positions in both objects do not raise.

    Parameters
    ----------
    x : array_like
        The smaller object to check.
    y : array_like
        The larger object to compare.
    err_msg : string
        The error message to be printed in case of failure.
    verbose : bool
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If x is not strictly smaller than y, element-wise.

    See Also
    --------
    assert_array_equal: tests objects for equality
    assert_array_almost_equal: test objects for equality up to precision

    Examples
    --------
    >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1.1, 2.0, np.nan])
    >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1, 2.0, np.nan])
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not less-ordered
    <BLANKLINE>
    Mismatched elements: 1 / 3 (33.3%)
    Max absolute difference: 1.
    Max relative difference: 0.5
     x: array([ 1.,  1., nan])
     y: array([ 1.,  2., nan])

    >>> np.testing.assert_array_less([1.0, 4.0], 3)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not less-ordered
    <BLANKLINE>
    Mismatched elements: 1 / 2 (50%)
    Max absolute difference: 2.
    Max relative difference: 0.66666667
     x: array([1., 4.])
     y: array(3)

    >>> np.testing.assert_array_less([1.0, 2.0, 3.0], [4])
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not less-ordered
    <BLANKLINE>
    (shapes (3,), (1,) mismatch)
     x: array([1., 2., 3.])
     y: array([4])

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # operator.lt is the very same callable as operator.__lt__.
    assert_array_compare(
        operator.lt,
        x,
        y,
        err_msg=err_msg,
        verbose=verbose,
        header='Arrays are not less-ordered',
        equal_inf=False,
    )

1116 

1117 

def runstring(astr, dict):
    """
    Execute the Python source in `astr` with `dict` as its namespace.

    Names bound by the executed code end up in `dict`, so the caller can
    inspect them afterwards. No sandboxing is done; `astr` must be trusted.
    """
    namespace = dict  # the parameter shadows the builtin; kept for API compat
    exec(astr, namespace)

1120 

1121 

def assert_string_equal(actual, desired):
    """
    Test if two strings are equal.

    If the given strings are equal, `assert_string_equal` does nothing.
    If they are not equal, an AssertionError is raised, and the diff
    between the strings is shown.

    Parameters
    ----------
    actual : str
        The string to test for equality against the expected string.
    desired : str
        The expected string.

    Raises
    ------
    AssertionError
        If either argument is not a `str` (the message is the offending
        type), or if the strings differ. In the latter case the message
        lists every line `difflib.Differ` marks as removed (``- ``), added
        (``+ ``) or annotated (``? ``); unchanged context lines are omitted.

    Examples
    --------
    >>> np.testing.assert_string_equal('abc', 'abc')
    >>> np.testing.assert_string_equal('abc', 'abcd')
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    ...
    AssertionError: Differences in strings:
    - abc+ abcd?    +

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # delay import of difflib to reduce startup time
    import difflib

    if not isinstance(actual, str):
        raise AssertionError(repr(type(actual)))
    if not isinstance(desired, str):
        raise AssertionError(repr(type(desired)))
    if desired == actual:
        return

    diff = difflib.Differ().compare(actual.splitlines(True),
                                    desired.splitlines(True))
    # Keep every removed/added/hint line and drop the common context ('  ').
    # Filtering, rather than pairing each '- ' line with a following '+ '
    # line, also handles pure insertions and deletions (e.g. a trailing line
    # present in only one string) instead of failing with IndexError.
    diff_list = [line for line in diff if not line.startswith('  ')]
    msg = f"Differences in strings:\n{''.join(diff_list).rstrip()}"
    raise AssertionError(msg)

1191 

1192 

def rundocs(filename=None, raise_on_error=True):
    """
    Run the doctests contained in a Python source file.

    The file is executed as a module, every doctest found in it is run, and
    by default an AssertionError summarising the failures is raised.

    Parameters
    ----------
    filename : str
        Path of the file whose doctests are run. When omitted, the
        ``__file__`` of the calling module is used.
    raise_on_error : bool
        Whether to raise an AssertionError when a doctest fails. Default is
        True. When False, doctest reports go to doctest's default output.

    Notes
    -----
    The doctests can be run by the user/developer by adding the ``doctests``
    argument to the ``test()`` call. For example, to run all tests (including
    doctests) for `numpy.lib`:

    >>> np.lib.test(doctests=True)  # doctest: +SKIP
    """
    from numpy.distutils.misc_util import exec_mod_from_location
    import doctest

    if filename is None:
        # Default to the module that called rundocs().
        filename = sys._getframe(1).f_globals['__file__']
    module_name = os.path.splitext(os.path.basename(filename))[0]
    module = exec_mod_from_location(module_name, filename)

    found = doctest.DocTestFinder().find(module)
    runner = doctest.DocTestRunner(verbose=False)

    # Collect the runner's reports only when we are going to raise with them.
    report_lines = []
    sink = report_lines.append if raise_on_error else None

    for doc_test in found:
        runner.run(doc_test, out=sink)

    if raise_on_error and runner.failures > 0:
        raise AssertionError("Some doctests failed:\n%s"
                             % "\n".join(report_lines))

1237 

1238 

def check_support_sve():
    """
    Return whether ``lscpu`` mentions ``sve`` anywhere in its output.

    The check is a plain substring test on the command's stdout. If the
    command cannot be started (for example it is not installed), the
    resulting OSError is swallowed and False is returned. See gh-22982.
    """
    import subprocess

    try:
        completed = subprocess.run('lscpu', capture_output=True, text=True)
    except OSError:
        return False
    return 'sve' in completed.stdout

1251 

1252 

# Evaluated once, at import time, by running ``lscpu`` via
# check_support_sve(); False whenever that command cannot be started.
_SUPPORTS_SVE = check_support_sve()

1254 

#
# assert_raises and assert_raises_regex are taken from unittest.
#
import unittest


class _Dummy(unittest.TestCase):
    """Minimal TestCase whose instance lends its assertion methods."""

    def nop(self):
        """Placeholder test method; TestCase needs a method name to bind."""


# Shared instance whose assertRaises/assertRaisesRegex methods are reused.
_d = _Dummy(methodName='nop')

1267 

1268 

def assert_raises(*args, **kwargs):
    """
    assert_raises(exception_class, callable, *args, **kwargs)
    assert_raises(exception_class)

    Fail unless `callable`, invoked with ``*args`` and ``**kwargs``, raises
    an exception of type `exception_class`. An exception of any other type
    is not caught: it propagates, and the test errors out exactly as it
    would for any unexpected exception.

    Given only `exception_class`, `assert_raises` works as a context
    manager instead:

    >>> from numpy.testing import assert_raises
    >>> with assert_raises(ZeroDivisionError):
    ...     1 / 0

    which is equivalent to

    >>> def div(x, y):
    ...     return x / y
    >>> assert_raises(ZeroDivisionError, div, 1, 0)

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # Delegate to unittest's implementation via the shared TestCase instance.
    check = _d.assertRaises
    return check(*args, **kwargs)

1296 

1297 

def assert_raises_regex(exception_class, expected_regexp, *args, **kwargs):
    """
    assert_raises_regex(exception_class, expected_regexp, callable, *args,
                        **kwargs)
    assert_raises_regex(exception_class, expected_regexp)

    Fail unless `callable`, invoked with ``*args`` and ``**kwargs``, raises
    an exception of type `exception_class` whose message matches
    `expected_regexp`.

    Given only the first two arguments, it returns a context manager and is
    used the same way as `assert_raises`.

    Notes
    -----
    .. versionadded:: 1.9.0

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # Delegate to unittest's implementation via the shared TestCase instance.
    check = _d.assertRaisesRegex
    return check(exception_class, expected_regexp, *args, **kwargs)

1317 

1318 

1319def decorate_methods(cls, decorator, testmatch=None): 

1320 """ 

1321 Apply a decorator to all methods in a class matching a regular expression. 

1322 

1323 The given decorator is applied to all public methods of `cls` that are 

1324 matched by the regular expression `testmatch` 

1325 (``testmatch.search(methodname)``). Methods that are private, i.e. start 

1326 with an underscore, are ignored. 

1327 

1328 Parameters 

1329 ---------- 

1330 cls : class 

1331 Class whose methods to decorate. 

1332 decorator : function 

1333 Decorator to apply to methods 

1334 testmatch : compiled regexp or str, optional 

1335 The regular expression. Default value is None, in which case the 

1336 nose default (``re.compile(r'(?:^|[\\b_\\.%s-])[Tt]est' % os.sep)``) 

1337 is used. 

1338 If `testmatch` is a string, it is compiled to a regular expression 

1339 first. 

1340 

1341 """ 

1342 if testmatch is None: 

1343 testmatch = re.compile(r'(?:^|[\\b_\\.%s-])[Tt]est' % os.sep) 

1344 else: 

1345 testmatch = re.compile(testmatch) 

1346 cls_attr = cls.__dict__ 

1347 

1348 # delayed import to reduce startup time 

1349 from inspect import isfunction 

1350 

1351 methods = [_m for _m in cls_attr.values() if isfunction(_m)] 

1352 for function in methods: 

1353 try: 

1354 if hasattr(function, 'compat_func_name'): 

1355 funcname = function.compat_func_name 

1356 else: 

1357 funcname = function.__name__ 

1358 except AttributeError: 

1359 # not a function 

1360 continue 

1361 if testmatch.search(funcname) and not funcname.startswith('_'): 

1362 setattr(cls, funcname, decorator(function)) 

1363 return 

1364 

1365 

def measure(code_str, times=1, label=None):
    """
    Time the execution of a code string in the caller's namespace.

    `code_str` is compiled once with the builtin ``compile`` and then
    executed `times` times with the caller's globals and locals. Timing uses
    ``jiffies()`` and has a resolution of 10 milliseconds, so fast code
    should be repeated many times to get a meaningful figure.

    Parameters
    ----------
    code_str : str
        The code to be timed.
    times : int, optional
        The number of times the code is executed. Default is 1. The code is
        only compiled once.
    label : str, optional
        A label to identify `code_str` with. It becomes part of the
        pseudo-filename handed to ``compile`` (shown in run-time errors).

    Returns
    -------
    elapsed : float
        Total elapsed time in seconds for executing `code_str` `times` times.

    Examples
    --------
    >>> times = 10
    >>> etime = np.testing.measure('for i in range(1000): np.sqrt(i**2)', times=times)
    >>> print("Time for a single execution : ", etime / times, "s")  # doctest: +SKIP
    Time for a single execution :  0.005 s

    """
    caller = sys._getframe(1)
    globs, locs = caller.f_globals, caller.f_locals

    code = compile(code_str, f'Test name: {label} ', 'exec')
    start = jiffies()
    runs = 0
    while runs < times:
        exec(code, globs, locs)
        runs += 1
    # jiffies() counts hundredths of a second.
    return 0.01 * (jiffies() - start)

1410 

1411 

def _assert_valid_refcount(op):
    """
    Check that ufuncs don't mishandle the refcount of the object ``1``.

    `op` is called 15 times, with the garbage collector disabled, on the
    same 100x100 integer array passed as both arguments. Afterwards the
    reference count of the int ``1`` must not have dropped below its
    starting value. Used in a few regression tests. When ``HAS_REFCOUNT``
    is false, nothing is checked and True is returned.
    """
    if not HAS_REFCOUNT:
        return True

    import gc
    import numpy as np

    arr = np.arange(100*100).reshape(100, 100)
    one = 1

    gc.disable()
    try:
        baseline = sys.getrefcount(one)
        for _ in range(15):
            result = op(arr, arr)
        assert_(sys.getrefcount(one) >= baseline)
    finally:
        gc.enable()
    del result  # for pyflakes

1436 

1437 

def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
                    err_msg='', verbose=True):
    """
    Raise an AssertionError unless two objects are equal within a tolerance.

    The shapes of `actual` and `desired` must match (a scalar is special-cased,
    see Notes), and each pair of elements is compared with
    ``np.isclose(actual, desired, rtol, atol)``: the difference must not
    exceed ``atol + rtol * abs(desired)``. This is the same test as
    ``allclose(actual, desired, rtol, atol)``, though ``allclose`` has
    different default tolerances. Unlike the usual numpy semantics, with
    `equal_nan` True, NaNs at the same positions in both objects do not
    raise.

    .. versionadded:: 1.5.0

    Parameters
    ----------
    actual : array_like
        Array obtained.
    desired : array_like
        Array desired.
    rtol : float, optional
        Relative tolerance.
    atol : float, optional
        Absolute tolerance.
    equal_nan : bool, optional
        If True, NaNs will compare equal.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_array_almost_equal_nulp, assert_array_max_ulp

    Notes
    -----
    When one of `actual` and `desired` is a scalar and the other is
    array_like, the function checks that each element of the array_like
    object is equal to the scalar.

    Examples
    --------
    >>> x = [1e-5, 1e-3, 1e-1]
    >>> y = np.arccos(np.cos(x))
    >>> np.testing.assert_allclose(x, y, rtol=1e-5, atol=0)

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np

    def _within_tolerance(a, b):
        return np.core.numeric.isclose(a, b, rtol=rtol, atol=atol,
                                       equal_nan=equal_nan)

    actual = np.asanyarray(actual)
    desired = np.asanyarray(desired)
    header = f'Not equal to tolerance rtol={rtol:g}, atol={atol:g}'
    assert_array_compare(_within_tolerance, actual, desired,
                         err_msg=str(err_msg), verbose=verbose,
                         header=header, equal_nan=equal_nan)

1506 

1507 

def assert_array_almost_equal_nulp(x, y, nulp=1):
    """
    Compare two arrays relatively to their spacing.

    This is a relatively robust method to compare two arrays whose amplitude
    is variable.

    Parameters
    ----------
    x, y : array_like
        Input arrays.
    nulp : int, optional
        The maximum number of unit in the last place for tolerance (see Notes).
        Default is 1.

    Returns
    -------
    None

    Raises
    ------
    AssertionError
        If the spacing between `x` and `y` for one or more elements is larger
        than `nulp`.

    See Also
    --------
    assert_array_max_ulp : Check that all items of arrays differ in at most
        N Units in the Last Place.
    spacing : Return the distance between x and the nearest adjacent number.

    Notes
    -----
    An assertion is raised if the following condition is not met::

        abs(x - y) <= nulp * spacing(maximum(abs(x), abs(y)))

    Examples
    --------
    >>> x = np.array([1., 1e-10, 1e-20])
    >>> eps = np.finfo(x.dtype).eps
    >>> np.testing.assert_array_almost_equal_nulp(x, x*eps/2 + x)

    >>> np.testing.assert_array_almost_equal_nulp(x, x*eps + x)
    Traceback (most recent call last):
      ...
    AssertionError: X and Y are not equal to 1 ULP (max is 2)

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np
    # Convert up front so plain sequences work: ``x - y`` on two lists
    # would otherwise raise TypeError. asanyarray keeps ndarray subclasses.
    x = np.asanyarray(x)
    y = np.asanyarray(y)
    ax = np.abs(x)
    ay = np.abs(y)
    # Elementwise larger magnitude; when ax is NaN the comparison is False,
    # so ay is selected.
    ref = nulp * np.spacing(np.where(ax > ay, ax, ay))
    if not np.all(np.abs(x - y) <= ref):
        if np.iscomplexobj(x) or np.iscomplexobj(y):
            msg = "X and Y are not equal to %d ULP" % nulp
        else:
            max_nulp = np.max(nulp_diff(x, y))
            msg = "X and Y are not equal to %d ULP (max is %g)" % (nulp, max_nulp)
        raise AssertionError(msg)

1569 

1570 

def assert_array_max_ulp(a, b, maxulp=1, dtype=None):
    """
    Check that corresponding items of two arrays are at most `maxulp` ULPs
    (Units in the Last Place) apart.

    Parameters
    ----------
    a, b : array_like
        Input arrays to be compared.
    maxulp : int, optional
        The maximum number of units in the last place that elements of `a` and
        `b` can differ. Default is 1.
    dtype : dtype, optional
        Data-type to convert `a` and `b` to if given. Default is None.

    Returns
    -------
    ret : ndarray
        The per-item ULP distances, as computed by `nulp_diff`.

    Raises
    ------
    AssertionError
        If one or more elements differ by more than `maxulp`.

    Notes
    -----
    For computing the ULP difference, this API does not differentiate between
    various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000
    is zero).

    See Also
    --------
    assert_array_almost_equal_nulp : Compare two arrays relatively to their
        spacing.

    Examples
    --------
    >>> a = np.linspace(0., 1., 100)
    >>> res = np.testing.assert_array_max_ulp(a, np.arcsin(np.sin(a)))

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np
    distances = nulp_diff(a, b, dtype)
    if np.all(distances <= maxulp):
        return distances
    raise AssertionError("Arrays are not almost equal up to %g "
                         "ULP (max difference is %g ULP)" %
                         (maxulp, np.max(distances)))

1621 

1622 

def nulp_diff(x, y, dtype=None):
    """For each item in x and y, return the number of representable floating
    points between them.

    Parameters
    ----------
    x : array_like
        first input array
    y : array_like
        second input array
    dtype : dtype, optional
        Data-type to convert `x` and `y` to if given. Default is None.

    Returns
    -------
    nulp : ndarray
        number of representable floating point numbers between each item in x
        and y, in the common floating dtype of the inputs. The result has an
        extra leading axis of length one (shape ``(1,)`` for scalar inputs).

    Raises
    ------
    NotImplementedError
        If either input is complex.
    ValueError
        If `x` and `y` have different shapes, or their common dtype is not
        float16, float32 or float64.

    Notes
    -----
    For computing the ULP difference, this API does not differentiate between
    various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000
    is zero).

    Examples
    --------
    ``np.finfo(np.float64).eps`` is the gap between 1.0 and the next larger
    float64, so exactly one ULP separates the two values:

    >>> from numpy.testing._private.utils import nulp_diff
    >>> nulp_diff(1, 1 + np.finfo(np.float64).eps)
    array([1.])
    """
    import numpy as np
    # Test against None explicitly: an unstructured np.dtype instance has
    # len() == 0 and is therefore falsy, so ``if dtype:`` ignored it.
    if dtype is not None:
        x = np.asarray(x, dtype=dtype)
        y = np.asarray(y, dtype=dtype)
    else:
        x = np.asarray(x)
        y = np.asarray(y)

    t = np.common_type(x, y)
    if np.iscomplexobj(x) or np.iscomplexobj(y):
        raise NotImplementedError("_nulp not implemented for complex array")

    # Fresh copies in the common float type; the list nesting adds the
    # leading length-one axis.
    x = np.array([x], dtype=t)
    y = np.array([y], dtype=t)

    # Canonicalize all NaN bit patterns so different payloads are 0 ULP apart.
    x[np.isnan(x)] = np.nan
    y[np.isnan(y)] = np.nan

    if not x.shape == y.shape:
        raise ValueError("x and y do not have the same shape: %s - %s" %
                         (x.shape, y.shape))

    def _diff(rx, ry, vdt):
        diff = np.asarray(rx-ry, dtype=vdt)
        return np.abs(diff)

    rx = integer_repr(x)
    ry = integer_repr(y)
    return _diff(rx, ry, t)

1684 

1685 

def _integer_repr(x, vdt, comp):
    """Reinterpret the bits of float array `x` as integers of type `vdt` and
    map them to a monotonically ordered (sign-magnitude style) scale.

    Negative values ``r`` of the raw integer view are replaced by
    ``comp - r``, with `comp` the most negative value of `vdt`, so that -0.0
    maps to 0 and negative floats map to negative integers. `x` itself is
    never modified.
    """
    # Reinterpret binary representation of the float as sign-magnitude:
    # take into account two-complement representation
    # See also
    # https://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/
    rx = x.view(vdt)
    if not (rx.size == 1):
        # The view shares memory with x; copy before the masked assignment so
        # the caller's float data is not overwritten.
        rx = rx.copy()
        negative = rx < 0
        rx[negative] = comp - rx[negative]
    else:
        if rx < 0:
            rx = comp - rx

    return rx

1699 

1700 

def integer_repr(x):
    """Return the signed-magnitude interpretation of the binary
    representation of `x`.

    Only float16, float32 and float64 arrays are supported; each is
    reinterpreted through the signed integer type of the same width.
    Any other dtype raises ValueError.
    """
    import numpy as np
    for float_type, int_type, bits in ((np.float16, np.int16, 16),
                                       (np.float32, np.int32, 32),
                                       (np.float64, np.int64, 64)):
        if x.dtype == float_type:
            return _integer_repr(x, int_type, int_type(-2**(bits - 1)))
    raise ValueError(f'Unsupported dtype {x.dtype}')

1713 

1714 

@contextlib.contextmanager
def _assert_warns_context(warning_class, name=None):
    """Context manager failing unless a `warning_class` warning is recorded.

    `name`, when given, is the callable's name and is appended to the
    failure message.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    with suppress_warnings() as sup:
        # `recorded` is the list returned by sup.record(); presumably it fills
        # with matching warnings issued inside the block.
        recorded = sup.record(warning_class)
        yield
        if len(recorded) > 0:
            return
        suffix = '' if name is None else f' when calling {name}'
        raise AssertionError("No warning raised" + suffix)

1724 

1725 

def assert_warns(warning_class, *args, **kwargs):
    """
    Fail unless the given callable throws the specified warning.

    A warning of class warning_class should be thrown by the callable when
    invoked with arguments args and keyword arguments kwargs.
    If a different type of warning is thrown, it will not be caught.

    If called with all arguments other than the warning class omitted, may be
    used as a context manager:

        with assert_warns(SomeWarning):
            do_something()

    The ability to be used as a context manager is new in NumPy v1.11.0.

    .. versionadded:: 1.4.0

    Parameters
    ----------
    warning_class : class
        The class defining the warning that `func` is expected to throw.
    func : callable, optional
        Callable to test. It does not need a ``__name__`` attribute (e.g. a
        `functools.partial` object); its ``repr`` is used in the failure
        message instead.
    *args : Arguments
        Arguments for `func`.
    **kwargs : Kwargs
        Keyword arguments for `func`.

    Returns
    -------
    The value returned by `func`.

    Raises
    ------
    AssertionError
        If no `warning_class` warning was raised.

    Examples
    --------
    >>> import warnings
    >>> def deprecated_func(num):
    ...     warnings.warn("Please upgrade", DeprecationWarning)
    ...     return num*num
    >>> with np.testing.assert_warns(DeprecationWarning):
    ...     assert deprecated_func(4) == 16
    >>> # or passing a func
    >>> ret = np.testing.assert_warns(DeprecationWarning, deprecated_func, 4)
    >>> assert ret == 16
    """
    if not args:
        return _assert_warns_context(warning_class)

    func = args[0]
    args = args[1:]
    # partial objects and callable instances lack __name__; the name is only
    # used for the failure message, so fall back to repr().
    name = getattr(func, '__name__', repr(func))
    with _assert_warns_context(warning_class, name=name):
        return func(*args, **kwargs)

1778 

1779 

1780@contextlib.contextmanager 

1781def _assert_no_warnings_context(name=None): 

1782 __tracebackhide__ = True # Hide traceback for py.test 

1783 with warnings.catch_warnings(record=True) as l: 

1784 warnings.simplefilter('always') 

1785 yield 

1786 if len(l) > 0: 

1787 name_str = f' when calling {name}' if name is not None else '' 

1788 raise AssertionError(f'Got warnings{name_str}: {l}') 

1789 

1790 

def assert_no_warnings(*args, **kwargs):
    """
    Fail if the given callable produces any warnings.

    If called with all arguments omitted, may be used as a context manager:

        with assert_no_warnings():
            do_something()

    The ability to be used as a context manager is new in NumPy v1.11.0.

    .. versionadded:: 1.7.0

    Parameters
    ----------
    func : callable
        The callable to test. It does not need a ``__name__`` attribute
        (e.g. a `functools.partial` object); its ``repr`` is used in the
        failure message instead.
    \\*args : Arguments
        Arguments passed to `func`.
    \\*\\*kwargs : Kwargs
        Keyword arguments passed to `func`.

    Returns
    -------
    The value returned by `func`.

    Raises
    ------
    AssertionError
        If any warning was issued.

    """
    if not args:
        return _assert_no_warnings_context()

    func = args[0]
    args = args[1:]
    # partial objects and callable instances lack __name__; the name is only
    # used for the failure message, so fall back to repr().
    name = getattr(func, '__name__', repr(func))
    with _assert_no_warnings_context(name=name):
        return func(*args, **kwargs)

1825 

1826 

def _gen_alignment_data(dtype=float32, type='binary', max_size=24):
    """
    generator producing data with different alignment and offsets
    to test simd vectorization

    Parameters
    ----------
    dtype : dtype
        data type to produce
    type : string
        'unary': create data for unary operations, creates one input
                 and output array
        'binary': create data for binary operations, creates two input
                 and output array
        Any other value yields nothing.
    max_size : integer
        maximum size of data to produce

    Yields
    ------
    if type is 'unary' yields one output, one input array and a message
    containing information on the data
    if type is 'binary' yields one output array, two input array and a message
    containing information on the data

    Cases labelled 'in place' pass the same array as output and input; cases
    labelled 'aliased' pass offset views of one buffer as output and input.

    """
    ufmt = 'unary offset=(%d, %d), size=%d, dtype=%r, %s'
    bfmt = 'binary offset=(%d, %d, %d), size=%d, dtype=%r, %s'
    for o in range(3):
        for s in range(o + 2, max(o + 3, max_size)):
            if type == 'unary':
                # Each call builds a fresh, independent input array.
                inp = lambda: arange(s, dtype=dtype)[o:]
                out = empty((s,), dtype=dtype)[o:]
                yield out, inp(), ufmt % (o, o, s, dtype, 'out of place')
                d = inp()
                yield d, d, ufmt % (o, o, s, dtype, 'in place')
                yield out[1:], inp()[:-1], ufmt % \
                    (o + 1, o, s - 1, dtype, 'out of place')
                yield out[:-1], inp()[1:], ufmt % \
                    (o, o + 1, s - 1, dtype, 'out of place')
                # Aliased cases must slice ONE buffer; two inp() calls would
                # produce unrelated arrays.
                d = inp()
                yield d[:-1], d[1:], ufmt % \
                    (o, o + 1, s - 1, dtype, 'aliased')
                d = inp()
                yield d[1:], d[:-1], ufmt % \
                    (o + 1, o, s - 1, dtype, 'aliased')
            if type == 'binary':
                # Each call builds a fresh, independent input array.
                inp1 = lambda: arange(s, dtype=dtype)[o:]
                inp2 = lambda: arange(s, dtype=dtype)[o:]
                out = empty((s,), dtype=dtype)[o:]
                yield out, inp1(), inp2(), bfmt % \
                    (o, o, o, s, dtype, 'out of place')
                d = inp1()
                yield d, d, inp2(), bfmt % \
                    (o, o, o, s, dtype, 'in place1')
                d = inp2()
                yield d, inp1(), d, bfmt % \
                    (o, o, o, s, dtype, 'in place2')
                yield out[1:], inp1()[:-1], inp2()[:-1], bfmt % \
                    (o + 1, o, o, s - 1, dtype, 'out of place')
                yield out[:-1], inp1()[1:], inp2()[:-1], bfmt % \
                    (o, o + 1, o, s - 1, dtype, 'out of place')
                yield out[:-1], inp1()[:-1], inp2()[1:], bfmt % \
                    (o, o, o + 1, s - 1, dtype, 'out of place')
                # Aliased cases: output and first input view the same buffer.
                d = inp1()
                yield d[1:], d[:-1], inp2()[:-1], bfmt % \
                    (o + 1, o, o, s - 1, dtype, 'aliased')
                d = inp1()
                yield d[:-1], d[1:], inp2()[:-1], bfmt % \
                    (o, o + 1, o, s - 1, dtype, 'aliased')
                d = inp1()
                yield d[:-1], d[:-1], inp2()[1:], bfmt % \
                    (o, o, o + 1, s - 1, dtype, 'aliased')

1894 

1895 

class IgnoreException(Exception):
    """Signal an error that should be ignored because the feature it
    exercises is disabled."""

1899 

1900 

@contextlib.contextmanager
def tempdir(*args, **kwargs):
    """Context manager yielding the path of a fresh temporary directory.

    Every argument is forwarded unchanged to `tempfile.mkdtemp`. On exit,
    whether normal or via an exception, the directory and its contents are
    deleted with `shutil.rmtree`.

    """
    path = mkdtemp(*args, **kwargs)
    try:
        yield path
    finally:
        shutil.rmtree(path)

1914 

1915 

@contextlib.contextmanager
def temppath(*args, **kwargs):
    """Context manager yielding the path of a closed temporary file.

    The arguments are passed straight through to ``tempfile.mkstemp``.

    The descriptor that ``mkstemp`` returns is closed immediately, so only
    the path is handed out. This matters on Windows, where a temporary file
    that is still open cannot be opened a second time.

    The file is deleted when the context exits. Any handle the caller opened
    on it should therefore be closed by then.
    """
    descriptor, file_path = mkstemp(*args, **kwargs)
    os.close(descriptor)
    try:
        yield file_path
    finally:
        os.remove(file_path)

1936 

1937 

class clear_and_catch_warnings(warnings.catch_warnings):
    """ Context manager that resets warning registry for catching warnings

    Warnings can be slippery, because, whenever a warning is triggered, Python
    adds a ``__warningregistry__`` member to the *calling* module. This makes
    it impossible to retrigger the warning in this module, whatever you put in
    the warnings filters. This context manager accepts a sequence of `modules`
    as a keyword argument to its constructor and:

    * stores and removes any ``__warningregistry__`` entries in given `modules`
      on entry;
    * resets ``__warningregistry__`` to its previous state on exit.

    This makes it possible to trigger any warning afresh inside the context
    manager without disturbing the state of warnings outside.

    Please pass all arguments as keywords.

    Parameters
    ----------
    record : bool, optional
        Specifies whether warnings should be captured by a custom
        implementation of ``warnings.showwarning()`` and be appended to a list
        returned by the context manager. Otherwise None is returned by the
        context manager. The objects appended to the list are arguments whose
        attributes mirror the arguments to ``showwarning()``.
    modules : sequence, optional
        Sequence of modules for which to reset warnings registry on entry and
        restore on exit. To work correctly, all 'ignore' filters should
        filter by one of these modules.

    Raises
    ------
    RuntimeError
        When the same instance is entered a second time (raised by
        ``warnings.catch_warnings``). The modules' registries are left
        untouched in that case.

    Examples
    --------
    >>> import warnings
    >>> with np.testing.clear_and_catch_warnings(
    ...         modules=[np.core.fromnumeric]):
    ...     warnings.simplefilter('always')
    ...     warnings.filterwarnings('ignore', module='numpy.core.fromnumeric')
    ...     # do something that raises a warning but ignore those in
    ...     # np.core.fromnumeric
    """
    # Modules whose registries are always handled; subclasses may extend this.
    class_modules = ()

    def __init__(self, record=False, modules=()):
        self.modules = set(modules).union(self.class_modules)
        # Registry snapshots taken on entry, keyed by module.
        self._warnreg_copies = {}
        super().__init__(record=record)

    def __enter__(self):
        # Enter the base class first. It refuses re-entry with RuntimeError,
        # and in that case __exit__ never runs. Touching the registries only
        # after a successful entry guarantees they are always restored.
        result = super().__enter__()
        for mod in self.modules:
            if hasattr(mod, '__warningregistry__'):
                mod_reg = mod.__warningregistry__
                self._warnreg_copies[mod] = mod_reg.copy()
                mod_reg.clear()
        return result

    def __exit__(self, *exc_info):
        super().__exit__(*exc_info)
        for mod in self.modules:
            if hasattr(mod, '__warningregistry__'):
                # Drop anything recorded inside the context, then restore
                # the snapshot (if the module had a registry on entry).
                mod.__warningregistry__.clear()
            if mod in self._warnreg_copies:
                mod.__warningregistry__.update(self._warnreg_copies[mod])

2002 

2003 

class suppress_warnings:
    """
    Context manager and decorator doing much the same as
    ``warnings.catch_warnings``.

    However, it also provides a filter mechanism to work around
    https://bugs.python.org/issue4180.

    This bug causes Python before 3.4 to not reliably show warnings again
    after they have been ignored once (even within catch_warnings). It
    means that no "ignore" filter can be used easily, since following
    tests might need to see the warning. Additionally it allows easier
    specificity for testing warnings and can be nested.

    Parameters
    ----------
    forwarding_rule : str, optional
        One of "always", "once", "module", or "location". Analogous to
        the usual warnings module filter mode, it is useful to reduce
        noise mostly on the outermost level. Unsuppressed and unrecorded
        warnings will be forwarded based on this rule. Defaults to "always".
        "location" is equivalent to the warnings "default", match by exact
        location the warning originated from.

    Raises
    ------
    ValueError
        If `forwarding_rule` is not one of the supported values.
    RuntimeError
        If the same instance is entered while already entered.

    Notes
    -----
    Filters added inside the context manager will be discarded again
    when leaving it. Upon entering all filters defined outside a
    context will be applied automatically.

    When a recording filter is added, matching warnings are stored in the
    ``log`` attribute as well as in the list returned by ``record``.

    If filters are added and the ``module`` keyword is given, the
    warning registry of this module will additionally be cleared when
    applying it, entering the context, or exiting it. This could cause
    warnings to appear a second time after leaving the context if they
    were configured to be printed once (default) and were already
    printed before the context was entered.

    Suppression patterns are matched case-insensitively from the start of
    the warning's first argument when that is a string, and otherwise
    against ``str(message)``.

    Nesting this context manager will work as expected when the
    forwarding rule is "always" (default). Unfiltered and unrecorded
    warnings will be passed out and be matched by the outer level.
    On the outermost level they will be printed (or caught by another
    warnings context). The forwarding rule argument can modify this
    behaviour.

    Like ``catch_warnings`` this context manager is not threadsafe.

    Examples
    --------

    With a context manager::

        with np.testing.suppress_warnings() as sup:
            sup.filter(DeprecationWarning, "Some text")
            sup.filter(module=np.ma.core)
            log = sup.record(FutureWarning, "Does this occur?")
            command_giving_warnings()
            # The FutureWarning was given once, the filtered warnings were
            # ignored. All other warnings abide outside settings (may be
            # printed/error)
            assert_(len(log) == 1)
            assert_(len(sup.log) == 1)  # also stored in log attribute

    Or as a decorator::

        sup = np.testing.suppress_warnings()
        sup.filter(module=np.ma.core)  # module must match exactly
        @sup
        def some_function():
            # do something which causes a warning in np.ma.core
            pass
    """
    def __init__(self, forwarding_rule="always"):
        self._entered = False

        # Suppressions are either instance or defined inside one with block:
        self._suppressions = []

        if forwarding_rule not in {"always", "module", "once", "location"}:
            raise ValueError("unsupported forwarding rule.")
        self._forwarding_rule = forwarding_rule

    def _clear_registries(self):
        if hasattr(warnings, "_filters_mutated"):
            # clearing the registry should not be necessary on new pythons,
            # instead the filters should be mutated.
            warnings._filters_mutated()
            return
        # Simply clear the registry, this should normally be harmless,
        # note that on new pythons it would be invalidated anyway.
        for module in self._tmp_modules:
            if hasattr(module, "__warningregistry__"):
                module.__warningregistry__.clear()

    @staticmethod
    def _add_always_filter(category, message, module):
        # Install an "always" filter so that matching warnings always reach
        # _showwarning, where they are recorded or suppressed.
        if module is None:
            warnings.filterwarnings(
                "always", category=category, message=message)
        else:
            module_regex = module.__name__.replace('.', r'\.') + '$'
            warnings.filterwarnings(
                "always", category=category, message=message,
                module=module_regex)

    def _filter(self, category=Warning, message="", module=None, record=False):
        # The log where matching warnings are stored (None: just suppress).
        record = [] if record else None
        entry = (category, message, re.compile(message, re.I), module, record)
        if self._entered:
            # Inside the context: apply immediately, forget on exit.
            self._add_always_filter(category, message, module)
            if module is not None:
                self._tmp_modules.add(module)
            self._clear_registries()
            self._tmp_suppressions.append(entry)
        else:
            self._suppressions.append(entry)

        return record

    def filter(self, category=Warning, message="", module=None):
        """
        Add a new suppressing filter or apply it if the state is entered.

        Parameters
        ----------
        category : class, optional
            Warning class to filter
        message : string, optional
            Regular expression matching the warning message.
        module : module, optional
            Module to filter for. Note that the module (and its file)
            must match exactly and cannot be a submodule. This may make
            it unreliable for external modules.

        Notes
        -----
        When added within a context, filters are only added inside
        the context and will be forgotten when the context is exited.
        """
        self._filter(category=category, message=message, module=module,
                     record=False)

    def record(self, category=Warning, message="", module=None):
        """
        Append a new recording filter or apply it if the state is entered.

        All warnings matching will be appended to the ``log`` attribute.

        Parameters
        ----------
        category : class, optional
            Warning class to filter
        message : string, optional
            Regular expression matching the warning message.
        module : module, optional
            Module to filter for. Note that the module (and its file)
            must match exactly and cannot be a submodule. This may make
            it unreliable for external modules.

        Returns
        -------
        log : list
            A list which will be filled with all matched warnings.

        Notes
        -----
        When added within a context, filters are only added inside
        the context and will be forgotten when the context is exited.
        """
        return self._filter(category=category, message=message, module=module,
                            record=True)

    def __enter__(self):
        if self._entered:
            raise RuntimeError("cannot enter suppress_warnings twice.")

        self._orig_show = warnings.showwarning
        self._filters = warnings.filters
        warnings.filters = self._filters[:]

        self._entered = True
        self._tmp_suppressions = []
        self._tmp_modules = set()
        self._forwarded = set()

        self.log = []  # reset global log (no need to keep same list)

        for cat, mess, _, mod, log in self._suppressions:
            if log is not None:
                del log[:]  # clear the log
            self._add_always_filter(cat, mess, mod)
            if mod is not None:
                self._tmp_modules.add(mod)
        warnings.showwarning = self._showwarning
        self._clear_registries()

        return self

    def __exit__(self, *exc_info):
        warnings.showwarning = self._orig_show
        warnings.filters = self._filters
        self._clear_registries()
        self._entered = False
        del self._orig_show
        del self._filters

    @staticmethod
    def _warning_text(message):
        # Text matched against suppression patterns. Warnings created without
        # arguments, or with a non-string first argument, fall back to
        # str(message) instead of raising IndexError/TypeError.
        args = message.args
        if args and isinstance(args[0], str):
            return args[0]
        return str(message)

    def _forward(self, message, category, filename, lineno, args, kwargs,
                 use_warnmsg):
        # Hand an unhandled warning to the showwarning active on entry.
        if use_warnmsg is None:
            self._orig_show(message, category, filename, lineno,
                            *args, **kwargs)
        else:
            # A complete WarningMessage was supplied: unpack its fields.
            self._orig_show(use_warnmsg.message, use_warnmsg.category,
                            use_warnmsg.filename, use_warnmsg.lineno,
                            use_warnmsg.file, use_warnmsg.line)

    def _showwarning(self, message, category, filename, lineno,
                     *args, use_warnmsg=None, **kwargs):
        text = self._warning_text(message)
        # Later suppressions take precedence, so scan newest first.
        for cat, _, pattern, mod, rec in (
                self._suppressions + self._tmp_suppressions)[::-1]:
            if not issubclass(category, cat) or pattern.match(text) is None:
                continue
            # Use startswith, because warnings strips the c or o from
            # .pyc/.pyo files.
            if mod is not None and not mod.__file__.startswith(filename):
                continue
            # Message and category (and module, if given) match:
            # either recorded or ignored.
            if rec is not None:
                msg = WarningMessage(message, category, filename,
                                     lineno, **kwargs)
                self.log.append(msg)
                rec.append(msg)
            return

        # There is no filter in place, so pass to the outside handler
        # unless we should only pass it once
        if self._forwarding_rule != "always":
            if self._forwarding_rule == "once":
                signature = (message.args, category)
            elif self._forwarding_rule == "module":
                signature = (message.args, category, filename)
            else:  # "location" (rule was validated in __init__)
                signature = (message.args, category, filename, lineno)

            if signature in self._forwarded:
                return
            self._forwarded.add(signature)

        self._forward(message, category, filename, lineno, args, kwargs,
                      use_warnmsg)

    def __call__(self, func):
        """
        Function decorator to apply certain suppressions to a whole
        function.
        """
        @wraps(func)
        def new_func(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return new_func

2280 

2281 

@contextlib.contextmanager
def _assert_no_gc_cycles_context(name=None):
    """Context manager that fails if the body creates reference cycles.

    The body runs with the garbage collector disabled. Afterwards a single
    ``gc.collect()`` under ``gc.DEBUG_SAVEALL`` counts the unreachable
    objects. A non-zero count raises AssertionError listing the collected
    objects. `name` is only used to enrich the error message.

    Without reference counting (``HAS_REFCOUNT`` false) the body just runs.
    """
    __tracebackhide__ = True  # Hide traceback for py.test

    if not HAS_REFCOUNT:
        # Not meaningful to test when there is no refcounting.
        yield
        return

    assert_(gc.isenabled())
    gc.disable()
    saved_debug_flags = gc.get_debug()
    try:
        # Drain pre-existing garbage first (up to 100 passes), so that only
        # cycles created by the body are counted.
        if not any(gc.collect() == 0 for _ in range(100)):
            raise RuntimeError(
                "Unable to fully collect garbage - perhaps a __del__ method "
                "is creating more reference cycles?")

        gc.set_debug(gc.DEBUG_SAVEALL)
        yield
        # gc.collect() returns how many unreachable objects it found. With
        # DEBUG_SAVEALL they are also kept in gc.garbage for the report.
        found = gc.collect()
        leaked = list(gc.garbage)
    finally:
        gc.garbage.clear()
        gc.set_debug(saved_debug_flags)
        gc.enable()

    if not found:
        return

    suffix = '' if name is None else f' when calling {name}'
    details = ''.join(
        "\n {} object with id={}:\n {}".format(
            type(obj).__name__,
            id(obj),
            pprint.pformat(obj).replace('\n', '\n ')
        ) for obj in leaked
    )
    raise AssertionError(
        "Reference cycles were found{}: {} objects were collected, "
        "of which {} are shown below:{}"
        .format(suffix, found, len(leaked), details)
    )

2332 

2333 

def assert_no_gc_cycles(*args, **kwargs):
    """
    Fail if the given callable produces any reference cycles.

    If called with all arguments omitted, may be used as a context manager:

        with assert_no_gc_cycles():
            do_something()

    .. versionadded:: 1.15.0

    Parameters
    ----------
    func : callable
        The callable to test.
    \\*args : Arguments
        Arguments passed to `func`.
    \\*\\*kwargs : Kwargs
        Keyword arguments passed to `func`.

    Returns
    -------
    Nothing. The result is deliberately discarded to ensure that all cycles
    are found.

    """
    if args:
        func, *call_args = args
        with _assert_no_gc_cycles_context(name=func.__name__):
            # The return value is dropped on purpose (see docstring).
            func(*call_args, **kwargs)
        return None

    # No callable given: hand back the context manager itself.
    return _assert_no_gc_cycles_context()

2367 

2368 

def break_cycles():
    """
    Break reference cycles by calling gc.collect

    Objects may call other objects' methods (for instance another object's
    ``__del__``) from inside their own ``__del__``. On PyPy the interpreter
    only runs those between calls to gc.collect, so several calls are needed
    to release all cycles completely.
    """
    # One pass normally; on PyPy four extra passes so finalizers get to run.
    passes = 5 if IS_PYPY else 1
    for _ in range(passes):
        gc.collect()

2385 

2386 

def requires_memory(free_bytes):
    """Decorator factory that skips a test when memory is insufficient.

    Before each call, ``check_free_memory(free_bytes)`` is consulted. If it
    returns a message, the test is skipped with that message. If the test
    itself raises MemoryError, it is reported as an expected failure
    (``pytest.xfail``) rather than a failure.
    """
    import pytest

    def decorator(func):
        @wraps(func)
        def wrapper(*a, **kw):
            skip_reason = check_free_memory(free_bytes)
            if skip_reason is not None:
                pytest.skip(skip_reason)

            try:
                return func(*a, **kw)
            except MemoryError:
                # Ran out of memory anyway: don't count it as a failure.
                pytest.xfail("MemoryError raised")

        return wrapper

    return decorator

2407 

2408 

def check_free_memory(free_bytes):
    """
    Check whether `free_bytes` amount of memory is currently free.

    The free amount is taken from the ``NPY_AVAILABLE_MEM`` environment
    variable when it is set, parsed with ``_parse_size`` (e.g. ``16GB``).
    Otherwise it is queried via ``_get_mem_available()``.

    Parameters
    ----------
    free_bytes : int or float
        Number of bytes that must be available.

    Returns
    -------
    msg : str or None
        None if enough memory is available, otherwise a message explaining
        the shortfall (suitable as a test-skip reason). If the available
        memory cannot be determined, the message is always returned.

    Raises
    ------
    ValueError
        If ``NPY_AVAILABLE_MEM`` is set but is not a valid size.
    """
    env_var = 'NPY_AVAILABLE_MEM'
    env_value = os.environ.get(env_var)
    if env_value is not None:
        try:
            mem_free = _parse_size(env_value)
        except ValueError as exc:
            # Chain explicitly so the parse error is kept as the cause.
            raise ValueError(
                f'Invalid environment variable {env_var}: {exc}') from exc

        msg = (f'{free_bytes/1e9} GB memory required, but environment variable '
               f'NPY_AVAILABLE_MEM={env_value} set')
    else:
        mem_free = _get_mem_available()

        if mem_free is None:
            msg = ("Could not determine available memory; set NPY_AVAILABLE_MEM "
                   "environment variable (e.g. NPY_AVAILABLE_MEM=16GB) to run "
                   "the test.")
            # Sentinel below any request, so the message is always returned.
            mem_free = -1
        else:
            msg = f'{free_bytes/1e9} GB memory required, but {mem_free/1e9} GB available'

    return msg if mem_free < free_bytes else None

2436 

2437 

def _parse_size(size_str):
    """Convert a memory size string such as ``'12 GB'`` into a byte count.

    Accepts an integer or decimal number optionally followed by a unit.
    Supported units are none/``b``, the decimal units ``k``/``m``/``g``/``t``
    (with or without a trailing ``b``), and the binary units
    ``kib``/``mib``/``gib``/``tib``. Matching is case-insensitive.
    Surrounding whitespace is ignored.

    Returns the size as an int; raises ValueError for anything else.
    """
    multipliers = {
        '': 1, 'b': 1,
        'k': 1000, 'm': 1000**2, 'g': 1000**3, 't': 1000**4,
        'kb': 1000, 'mb': 1000**2, 'gb': 1000**3, 'tb': 1000**4,
        'kib': 1024, 'mib': 1024**2, 'gib': 1024**3, 'tib': 1024**4,
    }
    unit_alternatives = '|'.join(multipliers)
    pattern = re.compile(
        r'^\s*(\d+|\d+\.\d+)\s*(' + unit_alternatives + r')\s*$', re.I)

    match = pattern.match(size_str.lower())
    if match is None or match.group(2) not in multipliers:
        raise ValueError(f'value {size_str!r} not a valid size')

    number, unit = match.groups()
    return int(float(number) * multipliers[unit])

2452 

2453 

def _get_mem_available():
    """Return available memory in bytes, or None if unknown.

    Uses ``psutil`` when it can be imported. Otherwise, on Linux, reads
    ``/proc/meminfo``: ``MemAvailable`` if present, else ``MemFree + Cached``.
    Returns None on other platforms, when the file cannot be read, or when
    the needed fields are missing.
    """
    try:
        import psutil
        return psutil.virtual_memory().available
    except (ImportError, AttributeError):
        pass

    if sys.platform.startswith('linux'):
        info = {}
        try:
            with open('/proc/meminfo') as f:
                for line in f:
                    fields = line.split()
                    # Expected shape: "<Name>: <value> kB"; skip anything else.
                    if len(fields) < 2:
                        continue
                    try:
                        value = int(fields[1])
                    except ValueError:
                        continue
                    info[fields[0].strip(':').lower()] = value * 1024
        except OSError:
            # /proc may be unavailable (e.g. restricted sandboxes): unknown.
            return None

        if 'memavailable' in info:
            # Linux >= 3.14
            return info['memavailable']
        if 'memfree' in info and 'cached' in info:
            return info['memfree'] + info['cached']

    return None

2476 

2477 

def _no_tracing(func):
    """
    Decorator that disables ``sys.settrace`` tracing while `func` runs.

    Tests that check reference counts need this, because an active tracer
    itself affects refcounts. The previous trace function is restored
    afterwards, even if `func` raises. When the interpreter has no
    ``sys.gettrace``, `func` is returned unchanged.
    """
    if hasattr(sys, 'gettrace'):
        @wraps(func)
        def untraced(*args, **kwargs):
            saved_tracer = sys.gettrace()
            try:
                sys.settrace(None)
                return func(*args, **kwargs)
            finally:
                sys.settrace(saved_tracer)
        return untraced

    return func

2496 

2497 

def _get_glibc_version():
    """Return the running glibc version as a string such as ``'2.31'``.

    Falls back to ``'0.0'`` when it cannot be determined, e.g. when
    ``os.confstr`` is missing, fails, or returns an unexpected value.
    """
    try:
        ver = os.confstr('CS_GNU_LIBC_VERSION').rsplit(' ')[1]
    except Exception:
        # Any failure means "not glibc / unknown".
        ver = '0.0'

    return ver


def _glibc_version_tuple(ver):
    """Return the numeric parts of a dotted version string, e.g. (2, 17)."""
    return tuple(int(part) for part in re.findall(r'\d+', ver))


_glibcver = _get_glibc_version()


def _glibc_older_than(x):
    """Return True if glibc was detected and is older than version `x`.

    Versions are compared numerically, component by component, so that
    ``'2.9'`` counts as older than ``'2.17'``. A plain string comparison
    gets this wrong. Returns False when the glibc version is unknown
    (``'0.0'``).
    """
    if _glibcver == '0.0':
        return False
    return _glibc_version_tuple(_glibcver) < _glibc_version_tuple(x)

2509