Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/numpy/testing/_private/utils.py: 19%

897 statements  

« prev     ^ index     » next       coverage.py v7.2.2, created at 2023-03-23 06:06 +0000

1""" 

2Utility function to facilitate testing. 

3 

4""" 

5import os 

6import sys 

7import platform 

8import re 

9import gc 

10import operator 

11import warnings 

12from functools import partial, wraps 

13import shutil 

14import contextlib 

15from tempfile import mkdtemp, mkstemp 

16from unittest.case import SkipTest 

17from warnings import WarningMessage 

18import pprint 

19 

20import numpy as np 

21from numpy.core import( 

22 intp, float32, empty, arange, array_repr, ndarray, isnat, array) 

23import numpy.linalg.lapack_lite 

24 

25from io import StringIO 

26 

27__all__ = [ 

28 'assert_equal', 'assert_almost_equal', 'assert_approx_equal', 

29 'assert_array_equal', 'assert_array_less', 'assert_string_equal', 

30 'assert_array_almost_equal', 'assert_raises', 'build_err_msg', 

31 'decorate_methods', 'jiffies', 'memusage', 'print_assert_equal', 

32 'raises', 'rundocs', 'runstring', 'verbose', 'measure', 

33 'assert_', 'assert_array_almost_equal_nulp', 'assert_raises_regex', 

34 'assert_array_max_ulp', 'assert_warns', 'assert_no_warnings', 

35 'assert_allclose', 'IgnoreException', 'clear_and_catch_warnings', 

36 'SkipTest', 'KnownFailureException', 'temppath', 'tempdir', 'IS_PYPY', 

37 'HAS_REFCOUNT', "IS_WASM", 'suppress_warnings', 'assert_array_compare', 

38 'assert_no_gc_cycles', 'break_cycles', 'HAS_LAPACK64', 'IS_PYSTON', 

39 '_OLD_PROMOTION' 

40 ] 

41 

42 

class KnownFailureException(Exception):
    """Exception a test raises to flag itself as a known failure."""

46 

47 

# Older name for the same exception class, kept for backwards compatibility.
KnownFailureTest = KnownFailureException
# Module-level ``verbose`` setting (exported via ``__all__``); 0 by default.
verbose = 0

# Interpreter / platform flags, evaluated once at import time.
IS_WASM = platform.machine() in ("wasm32", "wasm64")
IS_PYPY = sys.implementation.name == 'pypy'
IS_PYSTON = hasattr(sys, "pyston_version_info")
# Reference counts are only considered usable when ``sys.getrefcount``
# exists and the interpreter is not Pyston.
HAS_REFCOUNT = (not IS_PYSTON) and getattr(sys, 'getrefcount', None) is not None
# Mirrors the ``_ilp64`` flag exposed by the bundled lapack_lite module.
HAS_LAPACK64 = numpy.linalg.lapack_lite._ilp64

# Callable (not a constant): reports at call time whether NumPy's promotion
# state is currently 'legacy'.
_OLD_PROMOTION = lambda: np._get_promotion_state() == 'legacy'

58 

59 

def import_nose():
    """Import ``nose`` lazily and return the module.

    Returns
    -------
    module
        The imported ``nose`` package.

    Raises
    ------
    ImportError
        If ``nose`` cannot be imported, or if its ``__versioninfo__`` is
        lower than the required minimum of (1, 0, 0).
    """
    required = (1, 0, 0)
    try:
        import nose
    except ImportError:
        usable = False
    else:
        usable = not (nose.__versioninfo__ < required)

    if usable:
        return nose

    raise ImportError('Need nose >= %d.%d.%d for tests - see '
                      'https://nose.readthedocs.io' %
                      required)

80 

81 

def assert_(val, msg=''):
    """
    Assertion that keeps working when Python runs in optimized mode.

    The built-in ``assert`` statement produces no byte-code under the ``-O``
    flag; this function performs the same check with an ordinary ``raise``.

    Parameters
    ----------
    val : object
        Value whose truthiness is checked.
    msg : str or callable, optional
        Message for the ``AssertionError``. A callable is invoked only on
        failure, so building an expensive message can be deferred.

    Raises
    ------
    AssertionError
        If `val` is falsy.

    Notes
    -----
    `msg` is first called with no arguments; if that raises ``TypeError``
    (for instance because `msg` is a plain string), `msg` itself is used as
    the message.

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    if val:
        return
    try:
        text = msg()
    except TypeError:
        text = msg
    raise AssertionError(text)

100 

101 

def gisnan(x):
    """Like ``isnan``, but raise ``TypeError`` when the type is unsupported.

    Notes
    -----
    ``isnan`` and other ufuncs sometimes hand back the ``NotImplemented``
    singleton instead of raising. This wrapper converts that result into a
    ``TypeError`` so that callers always see an exception.

    This should be removed once the problem is solved at the ufunc level.
    """
    from numpy.core import isnan
    result = isnan(x)
    if result is NotImplemented:
        raise TypeError("isnan not supported for this type")
    return result

118 

119 

def gisfinite(x):
    """Like ``isfinite``, but raise ``TypeError`` when the type is unsupported.

    Notes
    -----
    ``isfinite`` and other ufuncs sometimes hand back the ``NotImplemented``
    singleton instead of raising. This wrapper converts that result into a
    ``TypeError`` so that callers always see an exception. Floating point
    "invalid" errors are ignored while ``isfinite`` runs.

    This should be removed once the problem is solved at the ufunc level.
    """
    from numpy.core import isfinite, errstate
    with errstate(invalid='ignore'):
        result = isfinite(x)
        if result is NotImplemented:
            raise TypeError("isfinite not supported for this type")
    return result

137 

138 

def gisinf(x):
    """Like ``isinf``, but raise ``TypeError`` when the type is unsupported.

    Notes
    -----
    ``isinf`` and other ufuncs sometimes hand back the ``NotImplemented``
    singleton instead of raising. This wrapper converts that result into a
    ``TypeError`` so that callers always see an exception. Floating point
    "invalid" errors are ignored while ``isinf`` runs.

    This should be removed once the problem is solved at the ufunc level.
    """
    from numpy.core import isinf, errstate
    with errstate(invalid='ignore'):
        result = isinf(x)
        if result is NotImplemented:
            raise TypeError("isinf not supported for this type")
    return result

156 

157 

if os.name == 'nt':
    # Code "stolen" from enthought/debug/memusage.py
    def GetPerformanceAttributes(object, counter, instance=None,
                                 inum=-1, format=None, machine=None):
        """Read one Windows performance counter through ``win32pdh``.

        A counter path is built from ``(machine, object, instance, None,
        inum, counter)``, a query is opened, the counter is sampled once and
        its formatted value returned. The counter and the query are always
        released, even if sampling fails.

        `format` defaults to ``win32pdh.PDH_FMT_LONG`` when None.
        """
        # NOTE: Many counters require 2 samples to give accurate results,
        # including "% Processor Time" (as by definition, at any instant, a
        # thread's CPU usage is either 0 or 100).  To read counters like this,
        # you should copy this function, but keep the counter open, and call
        # CollectQueryData() each time you need to know.
        # See http://msdn.microsoft.com/library/en-us/dnperfmo/html/perfmonpt2.asp (dead link)
        # My older explanation for this was that the "AddCounter" process
        # forced the CPU to 100%, but the above makes more sense :)
        import win32pdh
        if format is None:
            format = win32pdh.PDH_FMT_LONG
        path = win32pdh.MakeCounterPath((machine, object, instance, None,
                                         inum, counter))
        hq = win32pdh.OpenQuery()
        try:
            hc = win32pdh.AddCounter(hq, path)
            try:
                win32pdh.CollectQueryData(hq)
                # The first element of the result is not needed; avoid
                # shadowing the ``type`` builtin with it.
                _, val = win32pdh.GetFormattedCounterValue(hc, format)
                return val
            finally:
                win32pdh.RemoveCounter(hc)
        finally:
            win32pdh.CloseQuery(hq)

    def memusage(processName="python", instance=0):
        """Return the "Virtual Bytes" counter of the named process."""
        # from win32pdhutil, part of the win32all package
        import win32pdh
        return GetPerformanceAttributes("Process", "Virtual Bytes",
                                        processName, instance,
                                        win32pdh.PDH_FMT_LONG, None)
elif sys.platform[:5] == 'linux':

    def memusage(_proc_pid_stat=f'/proc/{os.getpid()}/stat'):
        """
        Return virtual memory size in bytes of the running python.

        The value is the ``vsize`` entry (field 23) of the stat file.
        ``None`` is returned if the file cannot be read or parsed.

        """
        try:
            with open(_proc_pid_stat, 'r') as f:
                stat = f.read()
            # Field 2 (the command name) is wrapped in parentheses and may
            # itself contain spaces, so count fields from the last ')'
            # rather than from the start of the line.
            _, closing, tail = stat.rpartition(')')
            fields = tail.split()
            # After the command name, field 3 sits at index 0, so field 23
            # is index 20.  Without any ')' fall back to plain positions.
            return int(fields[20] if closing else fields[22])
        except Exception:
            # Best effort: callers get None when the value is unavailable.
            return None
else:
    def memusage():
        """
        Return memory usage of running python. [Not implemented]

        """
        raise NotImplementedError

213 

214 

if sys.platform[:5] == 'linux':
    def jiffies(_proc_pid_stat=f'/proc/{os.getpid()}/stat', _load_time=[]):
        """
        Return number of jiffies elapsed.

        Return number of jiffies (1/100ths of a second) that this
        process has been scheduled in user mode. See man 5 proc.

        The value is the ``utime`` entry (field 14) of the stat file. If the
        file cannot be read or parsed, the wall-clock time since the first
        call, in 1/100ths of a second, is returned instead.

        """
        import time
        # ``_load_time`` is an intentionally shared mutable default: it
        # remembers the time of the first call for the fallback path.
        if not _load_time:
            _load_time.append(time.time())
        try:
            with open(_proc_pid_stat, 'r') as f:
                stat = f.read()
            # Field 2 (the command name) is wrapped in parentheses and may
            # itself contain spaces, so count fields from the last ')'
            # rather than from the start of the line.
            _, closing, tail = stat.rpartition(')')
            fields = tail.split()
            # After the command name, field 3 sits at index 0, so field 14
            # is index 11.  Without any ')' fall back to plain positions.
            return int(fields[11] if closing else fields[13])
        except Exception:
            return int(100*(time.time()-_load_time[0]))
else:
    # os.getpid is not in all platforms available.
    # Using time is safe but inaccurate, especially when process
    # was suspended or sleeping.
    def jiffies(_load_time=[]):
        """
        Return number of jiffies elapsed.

        Approximated here by the wall-clock time, in 1/100ths of a second,
        since the first call of this function.

        """
        import time
        # Shared mutable default on purpose: stores the first-call time.
        if not _load_time:
            _load_time.append(time.time())
        return int(100*(time.time()-_load_time[0]))

249 

250 

def build_err_msg(arrays, err_msg, header='Items are not equal:',
                  verbose=True, names=('ACTUAL', 'DESIRED'), precision=8):
    """Assemble the multi-line failure message used by the assert helpers.

    Parameters
    ----------
    arrays : sequence
        Objects to display; entry ``i`` is labelled with ``names[i]``.
    err_msg : str
        Caller-supplied message. A short single-line message is placed on
        the header line, anything else on a line of its own.
    header : str, optional
        First line of the message.
    verbose : bool, optional
        If True, a representation of every entry of `arrays` is appended.
    names : sequence of str, optional
        Labels for the entries of `arrays`.
    precision : int, optional
        Precision passed to ``array_repr``; only used for ndarray entries.

    Returns
    -------
    str
        The message, starting with a newline.
    """
    lines = ['\n' + header]
    if err_msg:
        single_line = err_msg.find('\n') == -1
        if single_line and len(err_msg) < 79-len(header):
            lines = [lines[0] + ' ' + err_msg]
        else:
            lines.append(err_msg)
    if not verbose:
        return '\n'.join(lines)

    # precision argument is only needed if the objects are ndarrays
    ndarray_repr = partial(array_repr, precision=precision)
    for idx, obj in enumerate(arrays):
        render = ndarray_repr if isinstance(obj, ndarray) else repr
        try:
            text = render(obj)
        except Exception as exc:
            text = f'[repr failed for <{type(obj).__name__}>: {exc}]'
        # Long representations are cut down to their first three lines.
        if text.count('\n') > 3:
            text = '\n'.join(text.splitlines()[:3]) + '...'
        lines.append(f' {names[idx]}: {text}')
    return '\n'.join(lines)

277 

278 

def assert_equal(actual, desired, err_msg='', verbose=True):
    """
    Raise an AssertionError if two objects are not equal.

    Scalars, lists, tuples, dictionaries and numpy arrays are supported.
    Containers are compared element by element and the first conflicting
    pair triggers the exception.

    If one of `actual` and `desired` is a scalar and the other is
    array_like, every element of the array_like object must equal the
    scalar.

    NaNs are treated like ordinary numbers: no AssertionError is raised when
    both objects hold NaN in the same positions. This deliberately differs
    from the IEEE rule that NaN never compares equal to anything.

    Parameters
    ----------
    actual : array_like
        The object to check.
    desired : array_like
        The expected object.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If actual and desired are not equal.

    Examples
    --------
    >>> np.testing.assert_equal([4,5], [4,6])
    Traceback (most recent call last):
        ...
    AssertionError:
    Items are not equal:
    item=1
     ACTUAL: 5
     DESIRED: 6

    The following comparison does not raise an exception.  There are NaNs
    in the inputs, but they are in the same positions.

    >>> np.testing.assert_equal(np.array([1.0, 2.0, np.nan]), [1, 2, np.nan])

    """
    __tracebackhide__ = True  # Hide traceback for py.test

    # Dictionaries: same length, same keys, recursively equal values.
    if isinstance(desired, dict):
        if not isinstance(actual, dict):
            raise AssertionError(repr(type(actual)))
        assert_equal(len(actual), len(desired), err_msg, verbose)
        for key, wanted in desired.items():
            if key not in actual:
                raise AssertionError(repr(key))
            assert_equal(actual[key], wanted, f'key={key!r}\n{err_msg}',
                         verbose)
        return

    # Lists / tuples: same length, recursively equal items.
    sequence_types = (list, tuple)
    if isinstance(desired, sequence_types) and isinstance(actual, sequence_types):
        assert_equal(len(actual), len(desired), err_msg, verbose)
        for pos, wanted in enumerate(desired):
            assert_equal(actual[pos], wanted, f'item={pos!r}\n{err_msg}',
                         verbose)
        return

    from numpy.core import ndarray, isscalar, signbit
    from numpy.lib import iscomplexobj, real, imag

    # Anything involving an ndarray is delegated to the array comparison.
    if isinstance(actual, ndarray) or isinstance(desired, ndarray):
        return assert_array_equal(actual, desired, err_msg, verbose)
    msg = build_err_msg([actual, desired], err_msg, verbose=verbose)

    # Handle complex numbers: separate into real/imag to handle
    # nan/inf/negative zero correctly
    # XXX: catch ValueError for subclasses of ndarray where iscomplex fail
    try:
        usecomplex = iscomplexobj(actual) or iscomplexobj(desired)
    except (ValueError, TypeError):
        usecomplex = False

    if usecomplex:
        if iscomplexobj(actual):
            actualr, actuali = real(actual), imag(actual)
        else:
            actualr, actuali = actual, 0
        if iscomplexobj(desired):
            desiredr, desiredi = real(desired), imag(desired)
        else:
            desiredr, desiredi = desired, 0
        try:
            assert_equal(actualr, desiredr)
            assert_equal(actuali, desiredi)
        except AssertionError:
            # Report the failure in terms of the original complex values.
            raise AssertionError(msg)

    # isscalar test to check cases such as [np.nan] != np.nan
    if isscalar(desired) != isscalar(actual):
        raise AssertionError(msg)

    try:
        desired_is_nat = isnat(desired)
        actual_is_nat = isnat(actual)
        same_dtype = (np.asarray(desired).dtype.type ==
                      np.asarray(actual).dtype.type)
        if desired_is_nat and actual_is_nat:
            # If both are NaT (and have the same dtype -- datetime or
            # timedelta) they are considered equal.
            if same_dtype:
                return
            raise AssertionError(msg)
    except (TypeError, ValueError, NotImplementedError):
        # Types without NaT support fall through to the generic checks.
        pass

    # Inf/nan/negative zero handling
    try:
        desired_is_nan = gisnan(desired)
        actual_is_nan = gisnan(actual)
        if desired_is_nan and actual_is_nan:
            return  # both nan, so equal

        # handle signed zero specially for floats
        array_actual = np.asarray(actual)
        array_desired = np.asarray(desired)
        if (array_actual.dtype.char in 'Mm' or
                array_desired.dtype.char in 'Mm'):
            # version 1.18
            # until this version, gisnan failed for datetime64 and timedelta64.
            # Now it succeeds but comparison to scalar with a different type
            # emits a DeprecationWarning.
            # Avoid that by skipping the next check
            raise NotImplementedError('cannot compare to a scalar '
                                      'with a different type')

        if (desired == 0 and actual == 0
                and not signbit(desired) == signbit(actual)):
            raise AssertionError(msg)

    except (TypeError, ValueError, NotImplementedError):
        pass

    try:
        # Explicitly use __eq__ for comparison, gh-2552
        if not (desired == actual):
            raise AssertionError(msg)

    except (DeprecationWarning, FutureWarning) as e:
        # this handles the case when the two types are not even comparable
        if 'elementwise == comparison' in e.args[0]:
            raise AssertionError(msg)
        raise

437 

438 

def print_assert_equal(test_string, actual, desired):
    """
    Check two objects for equality and report both on failure.

    Equality is decided with ``actual == desired``. On failure the raised
    AssertionError carries `test_string` followed by pretty-printed
    representations of both objects.

    Parameters
    ----------
    test_string : str
        The message supplied to AssertionError.
    actual : object
        The object to test for equality against `desired`.
    desired : object
        The expected result.

    Raises
    ------
    AssertionError
        If ``actual == desired`` is not truthy.

    Examples
    --------
    >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 1])
    >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 2])
    Traceback (most recent call last):
    ...
    AssertionError: Test XYZ of func xyz failed
    ACTUAL:
    [0, 1]
    DESIRED:
    [0, 2]

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import pprint

    if actual == desired:
        return

    report = StringIO()
    report.write(test_string)
    sections = ((' failed\nACTUAL: \n', actual), ('DESIRED: \n', desired))
    for heading, obj in sections:
        report.write(heading)
        pprint.pprint(obj, report)
    raise AssertionError(report.getvalue())

478 

479 

@np._no_nep50_warning()
def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True):
    """
    Raises an AssertionError if two items are not equal up to desired
    precision.

    .. note:: It is recommended to use one of `assert_allclose`,
              `assert_array_almost_equal_nulp` or `assert_array_max_ulp`
              instead of this function for more consistent floating point
              comparisons.

    The test verifies that the elements of `actual` and `desired` satisfy.

        ``abs(desired-actual) < float64(1.5 * 10**(-decimal))``

    That is a looser test than originally documented, but agrees with what the
    actual implementation in `assert_array_almost_equal` did up to rounding
    vagaries. An exception is raised at conflicting values. For ndarrays,
    lists and tuples this delegates to assert_array_almost_equal, forwarding
    `decimal`, `err_msg` and `verbose`.

    Parameters
    ----------
    actual : array_like
        The object to check.
    desired : array_like
        The expected object.
    decimal : int, optional
        Desired precision, default is 7.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
      If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Examples
    --------
    >>> from numpy.testing import assert_almost_equal
    >>> assert_almost_equal(2.3333333333333, 2.33333334)
    >>> assert_almost_equal(2.3333333333333, 2.33333334, decimal=10)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 10 decimals
     ACTUAL: 2.3333333333333
     DESIRED: 2.33333334

    >>> assert_almost_equal(np.array([1.0,2.3333333333333]),
    ...                     np.array([1.0,2.33333334]), decimal=9)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 9 decimals
    <BLANKLINE>
    Mismatched elements: 1 / 2 (50%)
    Max absolute difference: 6.66669964e-09
    Max relative difference: 2.85715698e-09
     x: array([1.         , 2.333333333])
     y: array([1.        , 2.33333334])

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    from numpy.core import ndarray
    from numpy.lib import iscomplexobj, real, imag

    # Handle complex numbers: separate into real/imag to handle
    # nan/inf/negative zero correctly
    # XXX: catch ValueError for subclasses of ndarray where iscomplex fail
    try:
        usecomplex = iscomplexobj(actual) or iscomplexobj(desired)
    except ValueError:
        usecomplex = False

    def _build_err_msg():
        # Built lazily so that the reprs are only computed on failure.
        header = ('Arrays are not almost equal to %d decimals' % decimal)
        return build_err_msg([actual, desired], err_msg, verbose=verbose,
                             header=header)

    if usecomplex:
        if iscomplexobj(actual):
            actualr = real(actual)
            actuali = imag(actual)
        else:
            actualr = actual
            actuali = 0
        if iscomplexobj(desired):
            desiredr = real(desired)
            desiredi = imag(desired)
        else:
            desiredr = desired
            desiredi = 0
        try:
            assert_almost_equal(actualr, desiredr, decimal=decimal)
            assert_almost_equal(actuali, desiredi, decimal=decimal)
        except AssertionError:
            # Report the failure in terms of the original complex values.
            raise AssertionError(_build_err_msg())

    if isinstance(actual, (ndarray, tuple, list)) \
            or isinstance(desired, (ndarray, tuple, list)):
        # Forward ``verbose`` as well; it used to be dropped here, so
        # ``verbose=False`` had no effect for array-like inputs.
        return assert_array_almost_equal(actual, desired, decimal, err_msg,
                                         verbose=verbose)
    try:
        # If one of desired/actual is not finite, handle it specially here:
        # check that both are nan if any is a nan, and test for equality
        # otherwise
        if not (gisfinite(desired) and gisfinite(actual)):
            if gisnan(desired) or gisnan(actual):
                if not (gisnan(desired) and gisnan(actual)):
                    raise AssertionError(_build_err_msg())
            else:
                if not desired == actual:
                    raise AssertionError(_build_err_msg())
            return
    except (NotImplementedError, TypeError):
        # Types the finiteness helpers cannot handle use the generic check.
        pass
    if abs(desired - actual) >= np.float64(1.5 * 10.0**(-decimal)):
        raise AssertionError(_build_err_msg())

605 

606 

@np._no_nep50_warning()
def assert_approx_equal(actual, desired, significant=7, err_msg='',
                        verbose=True):
    """
    Raise an AssertionError if two items differ within the requested number
    of significant digits.

    .. note:: It is recommended to use one of `assert_allclose`,
              `assert_array_almost_equal_nulp` or `assert_array_max_ulp`
              instead of this function for more consistent floating point
              comparisons.

    Both inputs are converted with ``float`` and scaled by a common power of
    ten derived from their mean magnitude; the scaled values must then
    differ by less than ``10**-(significant-1)``.

    Parameters
    ----------
    actual : scalar
        The object to check.
    desired : scalar
        The expected object.
    significant : int, optional
        Desired precision, default is 7.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
      If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Examples
    --------
    >>> np.testing.assert_approx_equal(0.12345677777777e-20, 0.1234567e-20)
    >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345671e-20,
    ...                                significant=8)
    >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345672e-20,
    ...                                significant=8)
    Traceback (most recent call last):
        ...
    AssertionError:
    Items are not equal to 8 significant digits:
     ACTUAL: 1.234567e-21
     DESIRED: 1.2345672e-21

    the evaluated condition that raises the exception is

    >>> abs(0.12345670e-20/1e-21 - 0.12345672e-20/1e-21) >= 10**-(8-1)
    True

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np

    actual = float(actual)
    desired = float(desired)
    if desired == actual:
        return

    # Normalized the numbers to be in range (-10.0,10.0)
    # scale = float(pow(10,math.floor(math.log10(0.5*(abs(desired)+abs(actual))))))
    with np.errstate(invalid='ignore'):
        mean_magnitude = 0.5*(np.abs(desired) + np.abs(actual))
        scale = np.power(10, np.floor(np.log10(mean_magnitude)))

    def _rescale(value):
        # A ZeroDivisionError is mapped to 0.0.
        try:
            return value/scale
        except ZeroDivisionError:
            return 0.0

    sc_desired = _rescale(desired)
    sc_actual = _rescale(actual)

    msg = build_err_msg(
        [actual, desired], err_msg,
        header='Items are not equal to %d significant digits:' % significant,
        verbose=verbose)
    try:
        # If one of desired/actual is not finite, handle it specially here:
        # check that both are nan if any is a nan, and test for equality
        # otherwise
        if not (gisfinite(desired) and gisfinite(actual)):
            if gisnan(desired) or gisnan(actual):
                mismatch = not (gisnan(desired) and gisnan(actual))
            else:
                mismatch = not desired == actual
            if mismatch:
                raise AssertionError(msg)
            return
    except (TypeError, NotImplementedError):
        pass
    if np.abs(sc_desired - sc_actual) >= np.power(10., -(significant-1)):
        raise AssertionError(msg)

705 

706 

@np._no_nep50_warning()
def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
                         precision=6, equal_nan=True, equal_inf=True,
                         *, strict=False):
    """Compare two array_like objects with `comparison` and raise on failure.

    Parameters
    ----------
    comparison : callable
        Called as ``comparison(x, y)`` on the (possibly filtered) arrays.
        It must return either a Python bool or an object offering
        ``ravel()``; every element of the result has to be true.
    x, y : array_like
        Objects to compare; both are converted with ``np.asanyarray``.
    err_msg : str, optional
        Extra text for the failure message.
    verbose : bool, optional
        Passed on to `build_err_msg`.
    header : str, optional
        First line of the failure message.
    precision : int, optional
        Passed on to `build_err_msg`.
    equal_nan : bool, optional
        If True, NaN (or NaT, for matching time dtypes) entries must sit at
        the same positions in both inputs and are then left out of the
        comparison.
    equal_inf : bool, optional
        If True, ``+inf`` and ``-inf`` entries are treated the same way.
    strict : bool, optional
        If True, shapes and dtypes must match exactly. Otherwise shapes
        must match unless one of the inputs has shape ``()``.

    Raises
    ------
    AssertionError
        On shape/dtype mismatch, flag position mismatch or failed comparison.
    ValueError
        If a ValueError occurs while comparing; it is re-raised with the
        formatted traceback and the inputs added to the message.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    from numpy.core import (array, array2string, isnan, inf, bool_, errstate,
                            all, max, object_)

    x = np.asanyarray(x)
    y = np.asanyarray(y)

    # original array for output formatting
    ox, oy = x, y

    def isnumber(arr):
        return arr.dtype.char in '?bhilqpBHILQPefdgFDG'

    def istime(arr):
        return arr.dtype.char in "Mm"

    def func_assert_same_pos(lhs, rhs, func=isnan, hasval='nan'):
        """Handling nan/inf.

        Combine results of running func on both inputs, checking that they
        are True at the same locations.

        """
        __tracebackhide__ = True  # Hide traceback for py.test

        lhs_flags = func(lhs)
        rhs_flags = func(rhs)
        # We include work-arounds here to handle three types of slightly
        # pathological ndarray subclasses:
        # (1) all() on `masked` array scalars can return masked arrays, so we
        #     use != True
        # (2) __eq__ on some ndarray subclasses returns Python booleans
        #     instead of element-wise comparisons, so we cast to bool_() and
        #     use isinstance(..., bool) checks
        # (3) subclasses with bare-bones __array_function__ implementations may
        #     not implement np.all(), so favor using the .all() method
        # We are not committed to supporting such subclasses, but it's nice to
        # support them if possible.
        if bool_(lhs_flags == rhs_flags).all() != True:
            raise AssertionError(build_err_msg(
                [lhs, rhs],
                err_msg + '\nx and y %s location mismatch:' % (hasval),
                verbose=verbose, header=header,
                names=('x', 'y'), precision=precision))
        # If there is a scalar, then here we know the array has the same
        # flag as it everywhere, so we should return the scalar flag.
        if isinstance(lhs_flags, bool) or lhs_flags.ndim == 0:
            return bool_(lhs_flags)
        if isinstance(rhs_flags, bool) or rhs_flags.ndim == 0:
            return bool_(rhs_flags)
        return rhs_flags

    try:
        if strict:
            compatible = x.shape == y.shape and x.dtype == y.dtype
        else:
            compatible = (x.shape == () or y.shape == ()) or x.shape == y.shape
        if not compatible:
            if x.shape != y.shape:
                reason = f'\n(shapes {x.shape}, {y.shape} mismatch)'
            else:
                reason = f'\n(dtypes {x.dtype}, {y.dtype} mismatch)'
            raise AssertionError(build_err_msg(
                [x, y], err_msg + reason,
                verbose=verbose, header=header,
                names=('x', 'y'), precision=precision))

        flagged = bool_(False)
        if isnumber(x) and isnumber(y):
            if equal_nan:
                flagged = func_assert_same_pos(x, y, func=isnan, hasval='nan')

            if equal_inf:
                flagged |= func_assert_same_pos(x, y,
                                                func=lambda xy: xy == +inf,
                                                hasval='+inf')
                flagged |= func_assert_same_pos(x, y,
                                                func=lambda xy: xy == -inf,
                                                hasval='-inf')

        elif istime(x) and istime(y):
            # If one is datetime64 and the other timedelta64 there is no point
            if equal_nan and x.dtype.type == y.dtype.type:
                flagged = func_assert_same_pos(x, y, func=isnat, hasval="NaT")

        if flagged.ndim > 0:
            x, y = x[~flagged], y[~flagged]
            # Only do the comparison if actual values are left
            if x.size == 0:
                return
        elif flagged:
            # no sense doing comparison if everything is flagged.
            return

        val = comparison(x, y)

        if isinstance(val, bool):
            passed = val
            reduced = array([val])
        else:
            reduced = val.ravel()
            passed = reduced.all()

        # The below comparison is a hack to ensure that fully masked
        # results, for which val.ravel().all() returns np.ma.masked,
        # do not trigger a failure (np.ma.masked != True evaluates as
        # np.ma.masked, which is falsy).
        if passed != True:
            n_mismatch = reduced.size - reduced.sum(dtype=intp)
            n_elements = flagged.size if flagged.ndim != 0 else reduced.size
            percent_mismatch = 100 * n_mismatch / n_elements
            remarks = [
                f'Mismatched elements: {n_mismatch} / {n_elements} '
                f'({percent_mismatch:.3g}%)']

            with errstate(all='ignore'):
                # ignore errors for non-numeric types
                with contextlib.suppress(TypeError):
                    error = abs(x - y)
                    if np.issubdtype(x.dtype, np.unsignedinteger):
                        # Unsigned subtraction wraps around, so take the
                        # smaller of the two directions.
                        error2 = abs(y - x)
                        np.minimum(error, error2, out=error)
                    max_abs_error = max(error)
                    # Object results are rendered with str(), everything
                    # else with array2string().
                    if getattr(error, 'dtype', object_) == object_:
                        render = str
                    else:
                        render = array2string
                    remarks.append('Max absolute difference: '
                                   + render(max_abs_error))

                    # note: this definition of relative error matches that one
                    # used by assert_allclose (found in np.isclose)
                    # Filter values where the divisor would be zero
                    nonzero = bool_(y != 0)
                    if all(~nonzero):
                        max_rel_error = array(inf)
                    else:
                        max_rel_error = max(error[nonzero] / abs(y[nonzero]))
                    remarks.append('Max relative difference: '
                                   + render(max_rel_error))

            err_msg += '\n' + '\n'.join(remarks)
            raise AssertionError(build_err_msg(
                [ox, oy], err_msg,
                verbose=verbose, header=header,
                names=('x', 'y'), precision=precision))
    except ValueError:
        import traceback
        efmt = traceback.format_exc()
        header = f'error during assertion:\n\n{efmt}\n\n{header}'

        raise ValueError(build_err_msg(
            [x, y], err_msg, verbose=verbose, header=header,
            names=('x', 'y'), precision=precision))

871 

872 

def assert_array_equal(x, y, err_msg='', verbose=True, *, strict=False):
    """
    Raise an AssertionError if two array_like objects are not equal.

    Both the shapes and every pair of elements of `x` and `y` must match
    (the Notes describe the special case of a scalar operand). A shape
    mismatch or any conflicting value raises. Unlike ordinary numpy
    comparisons, NaNs are treated like numbers: NaNs found at the same
    positions in both objects do not trigger a failure.

    The usual caution about testing floating point numbers for exact
    equality applies.

    Parameters
    ----------
    x : array_like
        The actual object to check.
    y : array_like
        The desired, expected object.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.
    strict : bool, optional
        If True, raise an AssertionError when either the shape or the data
        type of the array_like objects does not match. The special
        handling for scalars mentioned in the Notes section is disabled.

        .. versionadded:: 1.24.0

    Raises
    ------
    AssertionError
        If actual and desired objects are not equal.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Notes
    -----
    When one of `x` and `y` is a scalar and the other is array_like, the
    function checks that each element of the array_like object is equal to
    the scalar. This behaviour can be disabled with the `strict` parameter.

    Examples
    --------
    The first assert does not raise an exception:

    >>> np.testing.assert_array_equal([1.0,2.33333,np.nan],
    ...                               [np.exp(0),2.33333, np.nan])

    Assert fails with numerical imprecision with floats:

    >>> np.testing.assert_array_equal([1.0,np.pi,np.nan],
    ...                               [1, np.sqrt(np.pi)**2, np.nan])
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not equal
    <BLANKLINE>
    Mismatched elements: 1 / 3 (33.3%)
    Max absolute difference: 4.4408921e-16
    Max relative difference: 1.41357986e-16
     x: array([1.      , 3.141593,      nan])
     y: array([1.      , 3.141593,      nan])

    Use `assert_allclose` or one of the nulp (number of floating point values)
    functions for these cases instead:

    >>> np.testing.assert_allclose([1.0,np.pi,np.nan],
    ...                            [1, np.sqrt(np.pi)**2, np.nan],
    ...                            rtol=1e-10, atol=0)

    As mentioned in the Notes section, `assert_array_equal` has special
    handling for scalars. Here the test checks that each value in `x` is 3:

    >>> x = np.full((2, 5), fill_value=3)
    >>> np.testing.assert_array_equal(x, 3)

    Use `strict` to raise an AssertionError when comparing a scalar with an
    array:

    >>> np.testing.assert_array_equal(x, 3, strict=True)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not equal
    <BLANKLINE>
    (shapes (2, 5), () mismatch)
     x: array([[3, 3, 3, 3, 3],
           [3, 3, 3, 3, 3]])
     y: array(3)

    The `strict` parameter also ensures that the array data types match:

    >>> x = np.array([2, 2, 2])
    >>> y = np.array([2., 2., 2.], dtype=np.float32)
    >>> np.testing.assert_array_equal(x, y, strict=True)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not equal
    <BLANKLINE>
    (dtypes int64, float32 mismatch)
     x: array([2, 2, 2])
     y: array([2., 2., 2.], dtype=float32)
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # All the real work (shape, dtype, NaN and value checks) is delegated;
    # this wrapper only selects the comparison operator and the header.
    assert_array_compare(
        operator.eq, x, y,
        header='Arrays are not equal',
        err_msg=err_msg,
        verbose=verbose,
        strict=strict,
    )

988 

989 

@np._no_nep50_warning()
def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
    """
    Raises an AssertionError if two objects are not equal up to desired
    precision.

    .. note:: It is recommended to use one of `assert_allclose`,
              `assert_array_almost_equal_nulp` or `assert_array_max_ulp`
              instead of this function for more consistent floating point
              comparisons.

    The test verifies identical shapes and that the elements of ``actual`` and
    ``desired`` satisfy:

        ``abs(desired-actual) < 1.5 * 10**(-decimal)``

    That is a looser test than originally documented, but agrees with what the
    actual implementation did up to rounding vagaries. An exception is raised
    at shape mismatch or conflicting values. In contrast to the standard usage
    in numpy, NaNs are compared like numbers, no assertion is raised if both
    objects have NaNs in the same positions.

    Parameters
    ----------
    x : array_like
        The actual object to check.
    y : array_like
        The desired, expected object.
    decimal : int, optional
        Desired precision, default is 6.
    err_msg : str, optional
      The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Examples
    --------
    The first assert does not raise an exception:

    >>> np.testing.assert_array_almost_equal([1.0,2.333,np.nan],
    ...                                      [1.0,2.333,np.nan])

    >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan],
    ...                                      [1.0,2.33339,np.nan], decimal=5)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 5 decimals
    <BLANKLINE>
    Mismatched elements: 1 / 3 (33.3%)
    Max absolute difference: 6.e-05
    Max relative difference: 2.57136612e-05
     x: array([1.     , 2.33333,     nan])
     y: array([1.     , 2.33339,     nan])

    >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan],
    ...                                      [1.0,2.33333, 5], decimal=5)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 5 decimals
    <BLANKLINE>
    x and y nan location mismatch:
     x: array([1.     , 2.33333,     nan])
     y: array([1.     , 2.33333, 5.     ])

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # Imported locally; NOTE(review): `array` is not used in this function.
    from numpy.core import number, float_, result_type, array
    from numpy.core.numerictypes import issubdtype
    from numpy.core.fromnumeric import any as npany

    def compare(x, y):
        # Element-wise predicate handed to assert_array_compare below.
        try:
            # `gisinf` is defined elsewhere in this module; presumably an
            # isinf variant -- TODO confirm its exact contract.
            if npany(gisinf(x)) or npany( gisinf(y)):
                xinfid = gisinf(x)
                yinfid = gisinf(y)
                # Infinities must sit at the same positions in both inputs.
                if not (xinfid == yinfid).all():
                    return False
                # if one item, x and y is +- inf
                if x.size == y.size == 1:
                    return x == y
                # Drop the infinite entries so that only the remaining
                # values go through the tolerance test below.
                x = x[~xinfid]
                y = y[~yinfid]
        except (TypeError, NotImplementedError):
            # The infinity screening could not be applied to these inputs;
            # fall through and compare x and y as they currently are.
            pass

        # make sure y is an inexact type to avoid abs(MIN_INT); will cause
        # casting of x later.
        dtype = result_type(y, 1.)
        y = np.asanyarray(y, dtype)
        z = abs(x - y)

        if not issubdtype(z.dtype, number):
            z = z.astype(float_)  # handle object arrays

        # Same bound as stated in the docstring.
        return z < 1.5 * 10.0**(-decimal)

    assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
             header=('Arrays are not almost equal to %d decimals' % decimal),
             precision=decimal)

1102 

1103 

def assert_array_less(x, y, err_msg='', verbose=True):
    """
    Raise an AssertionError if two array_like objects are not ordered by
    less than.

    The shapes of `x` and `y` must agree and every element of `x` must be
    strictly smaller than the matching element of `y`. A shape mismatch or
    any incorrectly ordered pair raises; a shape mismatch is tolerated when
    one of the objects has zero dimension. Unlike ordinary numpy
    comparisons, NaNs found at the same positions in both objects do not
    trigger a failure.

    Parameters
    ----------
    x : array_like
        The smaller object to check.
    y : array_like
        The larger object to compare.
    err_msg : string
        The error message to be printed in case of failure.
    verbose : bool
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If the elements of `x` are not strictly smaller than those of `y`.

    See Also
    --------
    assert_array_equal: tests objects for equality
    assert_array_almost_equal: test objects for equality up to precision

    Examples
    --------
    >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1.1, 2.0, np.nan])
    >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1, 2.0, np.nan])
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not less-ordered
    <BLANKLINE>
    Mismatched elements: 1 / 3 (33.3%)
    Max absolute difference: 1.
    Max relative difference: 0.5
     x: array([ 1.,  1., nan])
     y: array([ 1.,  2., nan])

    >>> np.testing.assert_array_less([1.0, 4.0], 3)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not less-ordered
    <BLANKLINE>
    Mismatched elements: 1 / 2 (50%)
    Max absolute difference: 2.
    Max relative difference: 0.66666667
     x: array([1., 4.])
     y: array(3)

    >>> np.testing.assert_array_less([1.0, 2.0, 3.0], [4])
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not less-ordered
    <BLANKLINE>
    (shapes (3,), (1,) mismatch)
     x: array([1., 2., 3.])
     y: array([4])

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # Delegate to the shared comparison driver with strict "<" as predicate.
    assert_array_compare(
        operator.lt, x, y,
        header='Arrays are not less-ordered',
        err_msg=err_msg,
        verbose=verbose,
        equal_inf=False,
    )

1185 

1186 

def runstring(astr, dict):
    """Execute the source string `astr` with `dict` as its namespace.

    Names bound by the executed code end up in `dict`; nothing is returned.
    """
    namespace = dict
    exec(astr, namespace)

1189 

1190 

def assert_string_equal(actual, desired):
    """
    Test if two strings are equal.

    If the given strings are equal, `assert_string_equal` does nothing.
    If they are not equal, an AssertionError is raised, and the diff
    between the strings is shown.

    Parameters
    ----------
    actual : str
        The string to test for equality against the expected string.
    desired : str
        The expected string.

    Raises
    ------
    AssertionError
        If either argument is not a ``str`` (the message is the repr of the
        offending type), or if the two strings differ (the message lists
        every differing line reported by ``difflib.Differ``).

    Examples
    --------
    >>> np.testing.assert_string_equal('abc', 'abc')
    >>> np.testing.assert_string_equal('abc', 'abcd')
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    ...
    AssertionError: Differences in strings:
    - abc+ abcd?    +

    """
    # delay import of difflib to reduce startup time
    __tracebackhide__ = True  # Hide traceback for py.test
    import difflib

    if not isinstance(actual, str):
        raise AssertionError(repr(type(actual)))
    if not isinstance(desired, str):
        raise AssertionError(repr(type(desired)))
    if desired == actual:
        return

    diff = difflib.Differ().compare(actual.splitlines(True),
                                    desired.splitlines(True))
    # Keep every line that is not common to both inputs ('- ', '+ ' and the
    # '? ' hint lines), in the order Differ emits them. This does not assume
    # any particular '-'/'?'/'+' sequence, so pure deletions, pure
    # insertions and multi-line replacements are all reported as a diff
    # instead of failing while the diff is being walked.
    diff_list = [line for line in diff if not line.startswith('  ')]
    msg = f"Differences in strings:\n{''.join(diff_list).rstrip()}"
    # The strings were found unequal above, so this is always a failure.
    raise AssertionError(msg)

1260 

1261 

def rundocs(filename=None, raise_on_error=True):
    """
    Run doctests found in the given file.

    By default `rundocs` raises an AssertionError on failure.

    Parameters
    ----------
    filename : str
        The path to the file for which the doctests are run. If None, the
        ``__file__`` of the calling module is used.
    raise_on_error : bool
        Whether to raise an AssertionError when a doctest fails. Default is
        True.

    Raises
    ------
    AssertionError
        If `raise_on_error` is True and at least one doctest failed; the
        message contains the collected doctest failure reports.

    Notes
    -----
    The doctests can be run by the user/developer by adding the ``doctests``
    argument to the ``test()`` call. For example, to run all tests (including
    doctests) for `numpy.lib`:

    >>> np.lib.test(doctests=True)  # doctest: +SKIP
    """
    import doctest
    # Load the module with the standard importlib machinery directly rather
    # than through numpy.distutils, which is not available on every
    # supported Python version.
    import importlib.util

    if filename is None:
        f = sys._getframe(1)
        filename = f.f_globals['__file__']
    name = os.path.splitext(os.path.basename(filename))[0]
    spec = importlib.util.spec_from_file_location(name, filename)
    m = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(m)

    tests = doctest.DocTestFinder().find(m)
    runner = doctest.DocTestRunner(verbose=False)

    # When failures must be raised, capture the runner's reports so they can
    # be included in the exception; otherwise let the runner use its default
    # output.
    msg = []
    out = msg.append if raise_on_error else None

    for test in tests:
        runner.run(test, out=out)

    if runner.failures > 0 and raise_on_error:
        raise AssertionError("Some doctests failed:\n%s" % "\n".join(msg))

1306 

1307 

def raises(*args):
    """Decorator to check for raised exceptions.

    A decorated test function passes only if it raises one of the given
    exceptions. To check several exception-related assertions inside one
    test, `assert_raises` is usually the better tool.

    .. warning::
       This decorator is nose specific, do not use it if you are using a
       different test framework.

    Parameters
    ----------
    args : exceptions
        The test passes if any of the passed exceptions is raised.

    Raises
    ------
    AssertionError

    Examples
    --------

    Usage::

        @raises(TypeError, ValueError)
        def test_raises_type_error():
            raise TypeError("This test passes")

        @raises(Exception)
        def test_that_fails_by_passing():
            pass

    """
    # nose is resolved on demand by the module's own helper.
    nose_raises = import_nose().tools.raises
    return nose_raises(*args)

1344 

1345# 

1346# assert_raises and assert_raises_regex are taken from unittest. 

1347# 

1348import unittest 

1349 

1350 

1351class _Dummy(unittest.TestCase): 

1352 def nop(self): 

1353 pass 

1354 

1355_d = _Dummy('nop') 

1356 

def assert_raises(*args, **kwargs):
    """
    assert_raises(exception_class, callable, *args, **kwargs)
    assert_raises(exception_class)

    Fail unless an exception of class exception_class is thrown
    by callable when invoked with arguments args and keyword
    arguments kwargs. An exception of a different type is not
    caught, so the test case is treated as having errored,
    exactly as for any unexpected exception.

    Alternatively, `assert_raises` can be used as a context manager:

    >>> from numpy.testing import assert_raises
    >>> with assert_raises(ZeroDivisionError):
    ...     1 / 0

    is equivalent to

    >>> def div(x, y):
    ...     return x / y
    >>> assert_raises(ZeroDivisionError, div, 1, 0)

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # Forward everything to unittest's implementation on the shared instance.
    unittest_assert_raises = _d.assertRaises
    return unittest_assert_raises(*args, **kwargs)

1384 

1385 

def assert_raises_regex(exception_class, expected_regexp, *args, **kwargs):
    """
    assert_raises_regex(exception_class, expected_regexp, callable, *args,
                        **kwargs)
    assert_raises_regex(exception_class, expected_regexp)

    Fail unless callable, invoked with arguments args and keyword arguments
    kwargs, throws an exception of class exception_class whose message
    matches expected_regexp.

    Alternatively, can be used as a context manager like `assert_raises`.

    Notes
    -----
    .. versionadded:: 1.9.0

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # Forward everything to unittest's implementation on the shared instance.
    unittest_assert_raises_regex = _d.assertRaisesRegex
    return unittest_assert_raises_regex(
        exception_class, expected_regexp, *args, **kwargs)

1405 

1406 

def decorate_methods(cls, decorator, testmatch=None):
    """
    Apply a decorator to all methods in a class matching a regular expression.

    Every public function defined directly on `cls` whose name is matched by
    `testmatch` (``testmatch.search(methodname)``) is replaced by
    ``decorator(function)``. Private methods, i.e. those whose name starts
    with an underscore, are left untouched.

    Parameters
    ----------
    cls : class
        Class whose methods to decorate.
    decorator : function
        Decorator to apply to methods
    testmatch : compiled regexp or str, optional
        The regular expression. Default value is None, in which case the
        nose default (``re.compile(r'(?:^|[\\b_\\.%s-])[Tt]est' % os.sep)``)
        is used.
        If `testmatch` is a string, it is compiled to a regular expression
        first.

    """
    pattern = (r'(?:^|[\\b_\\.%s-])[Tt]est' % os.sep
               if testmatch is None else testmatch)
    testmatch = re.compile(pattern)
    namespace = cls.__dict__

    # delayed import to reduce startup time
    from inspect import isfunction

    # Snapshot the functions first: the loop below rebinds class attributes.
    candidates = [member for member in namespace.values()
                  if isfunction(member)]
    for function in candidates:
        try:
            funcname = (function.compat_func_name
                        if hasattr(function, 'compat_func_name')
                        else function.__name__)
        except AttributeError:
            # not a function
            continue
        if not testmatch.search(funcname) or funcname.startswith('_'):
            continue
        setattr(cls, funcname, decorator(function))

1452 

1453 

def measure(code_str, times=1, label=None):
    """
    Return elapsed time for executing code in the namespace of the caller.

    The supplied code string is compiled with the Python builtin ``compile``.
    The precision of the timing is 10 milli-seconds. Code that runs quickly
    on this timescale can be executed many times to obtain a reasonable
    timing accuracy.

    Parameters
    ----------
    code_str : str
        The code to be timed.
    times : int, optional
        The number of times the code is executed. Default is 1. The code is
        only compiled once.
    label : str, optional
        A label to identify `code_str` with. This is passed into ``compile``
        as the second argument (for run-time error messages).

    Returns
    -------
    elapsed : float
        Total elapsed time in seconds for executing `code_str` `times` times.

    Examples
    --------
    >>> times = 10
    >>> etime = np.testing.measure('for i in range(1000): np.sqrt(i**2)', times=times)
    >>> print("Time for a single execution : ", etime / times, "s")  # doctest: +SKIP
    Time for a single execution :  0.005 s

    """
    # The code runs in the caller's namespaces, so the frame lookup has to
    # stay directly inside this function.
    caller = sys._getframe(1)
    locs = caller.f_locals
    globs = caller.f_globals

    code = compile(code_str, f'Test name: {label} ', 'exec')
    runs_done = 0
    started = jiffies()
    while runs_done < times:
        runs_done += 1
        exec(code, globs, locs)
    ticks = jiffies() - started
    # The tick difference is scaled by 0.01 to give seconds.
    return 0.01*ticks

1498 

1499 

def _assert_valid_refcount(op):
    """
    Check that ufuncs don't mishandle refcount of object `1`.
    Used in a few regression tests.

    `op` is called as ``op(b, c)`` fifteen times on a 100x100 integer
    array (both operands are the same array object); afterwards the
    reference count of the integer ``1`` must not be lower than before.

    Returns True without running the check when ``HAS_REFCOUNT`` is
    false; otherwise returns None.
    """
    if not HAS_REFCOUNT:
        return True

    import gc
    import numpy as np

    b = np.arange(100*100).reshape(100, 100)
    c = b
    i = 1

    # The garbage collector is switched off around the measurement and is
    # always re-enabled afterwards, even if the check fails.
    gc.disable()
    try:
        rc = sys.getrefcount(i)
        for j in range(15):
            d = op(b, c)
        assert_(sys.getrefcount(i) >= rc)
    finally:
        gc.enable()
    del d  # for pyflakes

1524 

1525 

def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
                    err_msg='', verbose=True):
    """
    Raise an AssertionError if two objects are not equal up to desired
    tolerance.

    Given two array_like objects, check that their shapes and all elements
    are equal (but see the Notes for the special handling of a scalar). An
    exception is raised if the shapes mismatch or any values conflict. In
    contrast to the standard usage in numpy, NaNs are compared like numbers,
    no assertion is raised if both objects have NaNs in the same positions.

    The test is equivalent to ``allclose(actual, desired, rtol, atol)`` (note
    that ``allclose`` has different default values). It compares the difference
    between `actual` and `desired` to ``atol + rtol * abs(desired)``.

    .. versionadded:: 1.5.0

    Parameters
    ----------
    actual : array_like
        Array obtained.
    desired : array_like
        Array desired.
    rtol : float, optional
        Relative tolerance.
    atol : float, optional
        Absolute tolerance.
    equal_nan : bool, optional.
        If True, NaNs will compare equal.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_array_almost_equal_nulp, assert_array_max_ulp

    Notes
    -----
    When one of `actual` and `desired` is a scalar and the other is
    array_like, the function checks that each element of the array_like
    object is equal to the scalar.

    Examples
    --------
    >>> x = [1e-5, 1e-3, 1e-1]
    >>> y = np.arccos(np.cos(x))
    >>> np.testing.assert_allclose(x, y, rtol=1e-5, atol=0)

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np

    def _within_tolerance(x, y):
        # Element-wise closeness test using the tolerances of this call.
        return np.core.numeric.isclose(
            x, y, rtol=rtol, atol=atol, equal_nan=equal_nan)

    actual = np.asanyarray(actual)
    desired = np.asanyarray(desired)
    header = f'Not equal to tolerance rtol={rtol:g}, atol={atol:g}'
    assert_array_compare(
        _within_tolerance, actual, desired,
        err_msg=str(err_msg),
        verbose=verbose,
        header=header,
        equal_nan=equal_nan,
    )

1594 

1595 

def assert_array_almost_equal_nulp(x, y, nulp=1):
    """
    Compare two arrays relatively to their spacing.

    This is a relatively robust method to compare two arrays whose amplitude
    is variable.

    Parameters
    ----------
    x, y : array_like
        Input arrays. Sequences such as lists are converted to arrays
        before the comparison.
    nulp : int, optional
        The maximum number of unit in the last place for tolerance (see Notes).
        Default is 1.

    Returns
    -------
    None

    Raises
    ------
    AssertionError
        If the spacing between `x` and `y` for one or more elements is larger
        than `nulp`.

    See Also
    --------
    assert_array_max_ulp : Check that all items of arrays differ in at most
        N Units in the Last Place.
    spacing : Return the distance between x and the nearest adjacent number.

    Notes
    -----
    An assertion is raised if the following condition is not met::

        abs(x - y) <= nulp * spacing(maximum(abs(x), abs(y)))

    Examples
    --------
    >>> x = np.array([1., 1e-10, 1e-20])
    >>> eps = np.finfo(x.dtype).eps
    >>> np.testing.assert_array_almost_equal_nulp(x, x*eps/2 + x)

    >>> np.testing.assert_array_almost_equal_nulp(x, x*eps + x)
    Traceback (most recent call last):
      ...
    AssertionError: X and Y are not equal to 1 ULP (max is 2)

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np
    # Coerce up front so that plain sequences support the ``x - y`` below;
    # existing arrays (including subclasses) pass through unchanged.
    x = np.asanyarray(x)
    y = np.asanyarray(y)
    ax = np.abs(x)
    ay = np.abs(y)
    # Tolerance: `nulp` spacings of the larger magnitude of each pair.
    ref = nulp * np.spacing(np.where(ax > ay, ax, ay))
    if not np.all(np.abs(x - y) <= ref):
        if np.iscomplexobj(x) or np.iscomplexobj(y):
            # No maximum is reported for complex input.
            msg = "X and Y are not equal to %d ULP" % nulp
        else:
            max_nulp = np.max(nulp_diff(x, y))
            msg = "X and Y are not equal to %d ULP (max is %g)" % (nulp, max_nulp)
        raise AssertionError(msg)

1657 

1658 

def assert_array_max_ulp(a, b, maxulp=1, dtype=None):
    """
    Check that all items of arrays differ in at most N Units in the Last Place.

    Parameters
    ----------
    a, b : array_like
        Input arrays to be compared.
    maxulp : int, optional
        The maximum number of units in the last place that elements of `a` and
        `b` can differ. Default is 1.
    dtype : dtype, optional
        Data-type to convert `a` and `b` to if given. Default is None.

    Returns
    -------
    ret : ndarray
        Array containing number of representable floating point numbers between
        items in `a` and `b`.

    Raises
    ------
    AssertionError
        If one or more elements differ by more than `maxulp`.

    Notes
    -----
    For computing the ULP difference, this API does not differentiate between
    various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000
    is zero).

    See Also
    --------
    assert_array_almost_equal_nulp : Compare two arrays relatively to their
        spacing.

    Examples
    --------
    >>> a = np.linspace(0., 1., 100)
    >>> res = np.testing.assert_array_max_ulp(a, np.arcsin(np.sin(a)))

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np
    ulp_distance = nulp_diff(a, b, dtype)
    if np.all(ulp_distance <= maxulp):
        return ulp_distance
    worst = np.max(ulp_distance)
    raise AssertionError("Arrays are not almost equal up to %g "
                         "ULP (max difference is %g ULP)" %
                         (maxulp, worst))

1709 

1710 

def nulp_diff(x, y, dtype=None):
    """For each item in x and y, return the number of representable floating
    points between them.

    Parameters
    ----------
    x : array_like
        first input array
    y : array_like
        second input array
    dtype : dtype, optional
        Data-type to convert `x` and `y` to if given. Default is None.

    Returns
    -------
    nulp : array_like
        number of representable floating point numbers between each item in x
        and y.

    Raises
    ------
    NotImplementedError
        If either input is complex.
    ValueError
        If the inputs do not have the same shape.

    Notes
    -----
    For computing the ULP difference, this API does not differentiate between
    various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000
    is zero).

    Examples
    --------
    # By definition, epsilon is the smallest number such as 1 + eps != 1, so
    # there should be exactly one ULP between 1 and 1 + eps
    >>> nulp_diff(1, 1 + np.finfo(np.float64).eps)
    array([1.])
    """
    import numpy as np
    conversion = {'dtype': dtype} if dtype else {}
    x = np.asarray(x, **conversion)
    y = np.asarray(y, **conversion)

    t = np.common_type(x, y)
    if np.iscomplexobj(x) or np.iscomplexobj(y):
        raise NotImplementedError("_nulp not implemented for complex array")

    # Fresh copies in the common type, wrapped in one extra leading axis.
    x = np.array([x], dtype=t)
    y = np.array([y], dtype=t)

    # Rewrite every NaN with the same np.nan value so that all NaNs share
    # one bit pattern before the integer reinterpretation.
    x[np.isnan(x)] = np.nan
    y[np.isnan(y)] = np.nan

    if x.shape != y.shape:
        raise ValueError("x and y do not have the same shape: %s - %s" %
                         (x.shape, y.shape))

    rx = integer_repr(x)
    ry = integer_repr(y)
    # Distance between the integer representations, returned in type `t`.
    return np.abs(np.asarray(rx - ry, dtype=t))

1772 

1773 

1774def _integer_repr(x, vdt, comp): 

1775 # Reinterpret binary representation of the float as sign-magnitude: 

1776 # take into account two-complement representation 

1777 # See also 

1778 # https://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/ 

1779 rx = x.view(vdt) 

1780 if not (rx.size == 1): 

1781 rx[rx < 0] = comp - rx[rx < 0] 

1782 else: 

1783 if rx < 0: 

1784 rx = comp - rx 

1785 

1786 return rx 

1787 

1788 

def integer_repr(x):
    """Return the signed-magnitude interpretation of the binary representation
    of x.

    Only float16, float32 and float64 inputs are handled; any other dtype
    raises ValueError.
    """
    import numpy as np
    # (float type, same-width integer type, exponent of the offset -2**n)
    layouts = (
        (np.float16, np.int16, 15),
        (np.float32, np.int32, 31),
        (np.float64, np.int64, 63),
    )
    for float_type, int_type, exponent in layouts:
        if x.dtype == float_type:
            return _integer_repr(x, int_type, int_type(-2**exponent))
    raise ValueError(f'Unsupported dtype {x.dtype}')

1801 

1802 

@contextlib.contextmanager
def _assert_warns_context(warning_class, name=None):
    """Context manager that fails if no `warning_class` warning is recorded.

    `name`, when given, is mentioned in the failure message as the callable
    that was expected to warn.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    with suppress_warnings() as sup:
        recorded = sup.record(warning_class)
        yield
        if len(recorded) <= 0:
            suffix = '' if name is None else f' when calling {name}'
            raise AssertionError("No warning raised" + suffix)

1812 

1813 

def assert_warns(warning_class, *args, **kwargs):
    """
    Fail unless the given callable throws the specified warning.

    A warning of class warning_class should be thrown by the callable when
    invoked with arguments args and keyword arguments kwargs.
    If a different type of warning is thrown, it will not be caught.

    If called with all arguments other than the warning class omitted, may be
    used as a context manager::

        with assert_warns(SomeWarning):
            do_something()

    The ability to be used as a context manager is new in NumPy v1.11.0.

    .. versionadded:: 1.4.0

    Parameters
    ----------
    warning_class : class
        The class defining the warning that `func` is expected to throw.
    func : callable, optional
        Callable to test
    *args : Arguments
        Arguments for `func`.
    **kwargs : Kwargs
        Keyword arguments for `func`.

    Returns
    -------
    The value returned by `func`.

    Examples
    --------
    >>> import warnings
    >>> def deprecated_func(num):
    ...     warnings.warn("Please upgrade", DeprecationWarning)
    ...     return num*num
    >>> with np.testing.assert_warns(DeprecationWarning):
    ...     assert deprecated_func(4) == 16
    >>> # or passing a func
    >>> ret = np.testing.assert_warns(DeprecationWarning, deprecated_func, 4)
    >>> assert ret == 16
    """
    # Without a callable, hand back the context manager itself.
    if not args:
        return _assert_warns_context(warning_class)

    # The first positional argument is the callable, the rest are its args.
    func, *call_args = args
    with _assert_warns_context(warning_class, name=func.__name__):
        return func(*call_args, **kwargs)

1866 

1867 

1868@contextlib.contextmanager 

1869def _assert_no_warnings_context(name=None): 

1870 __tracebackhide__ = True # Hide traceback for py.test 

1871 with warnings.catch_warnings(record=True) as l: 

1872 warnings.simplefilter('always') 

1873 yield 

1874 if len(l) > 0: 

1875 name_str = f' when calling {name}' if name is not None else '' 

1876 raise AssertionError(f'Got warnings{name_str}: {l}') 

1877 

1878 

def assert_no_warnings(*args, **kwargs):
    """
    Fail if the given callable produces any warnings.

    If called with all arguments omitted, may be used as a context manager::

        with assert_no_warnings():
            do_something()

    The ability to be used as a context manager is new in NumPy v1.11.0.

    .. versionadded:: 1.7.0

    Parameters
    ----------
    func : callable
        The callable to test.
    \\*args : Arguments
        Arguments passed to `func`.
    \\*\\*kwargs : Kwargs
        Keyword arguments passed to `func`.

    Returns
    -------
    The value returned by `func`.

    """
    # Without a callable, hand back the context manager itself.
    if not args:
        return _assert_no_warnings_context()

    # The first positional argument is the callable, the rest are its args.
    func, *call_args = args
    with _assert_no_warnings_context(name=func.__name__):
        return func(*call_args, **kwargs)

1913 

1914 

def _gen_alignment_data(dtype=float32, type='binary', max_size=24):
    """
    generator producing data with different alignment and offsets
    to test simd vectorization

    Parameters
    ----------
    dtype : dtype
        data type to produce
    type : string
        'unary': create data for unary operations, creates one input
                 and output array
        'binary': create data for binary operations, creates two input
                 and output array
    max_size : integer
        maximum size of data to produce

    Returns
    -------
    if type is 'unary' yields one output, one input array and a message
    containing information on the data
    if type is 'binary' yields one output array, two input array and a message
    containing information on the data

    """
    unary_fmt = 'unary offset=(%d, %d), size=%d, dtype=%r, %s'
    binary_fmt = 'binary offset=(%d, %d, %d), size=%d, dtype=%r, %s'

    for off in range(3):
        for size in range(off + 2, max(off + 3, max_size)):

            def fresh():
                # Every call allocates a new array, so two results never
                # share memory (this includes the cases labelled 'aliased').
                return arange(size, dtype=dtype)[off:]

            def unary_msg(out_off, in_off, n, note):
                return unary_fmt % (out_off, in_off, n, dtype, note)

            def binary_msg(out_off, in1_off, in2_off, n, note):
                return binary_fmt % (out_off, in1_off, in2_off, n, dtype,
                                     note)

            if type == 'unary':
                out = empty((size,), dtype=dtype)[off:]
                yield out, fresh(), unary_msg(off, off, size, 'out of place')
                shared = fresh()
                yield shared, shared, unary_msg(off, off, size, 'in place')
                yield (out[1:], fresh()[:-1],
                       unary_msg(off + 1, off, size - 1, 'out of place'))
                yield (out[:-1], fresh()[1:],
                       unary_msg(off, off + 1, size - 1, 'out of place'))
                yield (fresh()[:-1], fresh()[1:],
                       unary_msg(off, off + 1, size - 1, 'aliased'))
                yield (fresh()[1:], fresh()[:-1],
                       unary_msg(off + 1, off, size - 1, 'aliased'))

            if type == 'binary':
                out = empty((size,), dtype=dtype)[off:]
                yield (out, fresh(), fresh(),
                       binary_msg(off, off, off, size, 'out of place'))
                shared = fresh()
                yield (shared, shared, fresh(),
                       binary_msg(off, off, off, size, 'in place1'))
                shared = fresh()
                yield (shared, fresh(), shared,
                       binary_msg(off, off, off, size, 'in place2'))
                yield (out[1:], fresh()[:-1], fresh()[:-1],
                       binary_msg(off + 1, off, off, size - 1,
                                  'out of place'))
                yield (out[:-1], fresh()[1:], fresh()[:-1],
                       binary_msg(off, off + 1, off, size - 1,
                                  'out of place'))
                yield (out[:-1], fresh()[:-1], fresh()[1:],
                       binary_msg(off, off, off + 1, size - 1,
                                  'out of place'))
                yield (fresh()[1:], fresh()[:-1], fresh()[:-1],
                       binary_msg(off + 1, off, off, size - 1, 'aliased'))
                yield (fresh()[:-1], fresh()[1:], fresh()[:-1],
                       binary_msg(off, off + 1, off, size - 1, 'aliased'))
                yield (fresh()[:-1], fresh()[:-1], fresh()[1:],
                       binary_msg(off, off, off + 1, size - 1, 'aliased'))

1982 

1983 

class IgnoreException(Exception):
    """Ignoring this exception due to disabled feature"""

1987 

1988 

@contextlib.contextmanager
def tempdir(*args, **kwargs):
    """Context manager to provide a temporary test folder.

    All arguments are passed as-is to the underlying tempfile.mkdtemp
    function. The folder and everything in it is removed when the context
    is exited, whether or not the body raised.

    """
    folder = mkdtemp(*args, **kwargs)
    try:
        yield folder
    finally:
        shutil.rmtree(folder)

2002 

2003 

@contextlib.contextmanager
def temppath(*args, **kwargs):
    """Context manager for temporary files.

    Context manager that returns the path to a closed temporary file. Its
    parameters are the same as for tempfile.mkstemp and are passed directly
    to that function. The underlying file is removed when the context is
    exited, so it should be closed at that time.

    Windows does not allow a temporary file to be opened if it is already
    open, so the underlying file must be closed after opening before it
    can be opened again.

    """
    handle, filename = mkstemp(*args, **kwargs)
    # Hand out a path only: the descriptor mkstemp opened is closed first.
    os.close(handle)
    try:
        yield filename
    finally:
        os.remove(filename)

2024 

2025 

class clear_and_catch_warnings(warnings.catch_warnings):
    """ Context manager that resets warning registry for catching warnings

    Warnings can be slippery, because, whenever a warning is triggered, Python
    adds a ``__warningregistry__`` member to the *calling* module.  This makes
    it impossible to retrigger the warning in this module, whatever you put in
    the warnings filters.  This context manager accepts a sequence of `modules`
    as a keyword argument to its constructor and:

    * stores and removes any ``__warningregistry__`` entries in given `modules`
      on entry;
    * resets ``__warningregistry__`` to its previous state on exit.

    This makes it possible to trigger any warning afresh inside the context
    manager without disturbing the state of warnings outside.

    For compatibility with Python 3.0, please consider all arguments to be
    keyword-only.

    Parameters
    ----------
    record : bool, optional
        Specifies whether warnings should be captured by a custom
        implementation of ``warnings.showwarning()`` and be appended to a list
        returned by the context manager. Otherwise None is returned by the
        context manager. The objects appended to the list are arguments whose
        attributes mirror the arguments to ``showwarning()``.
    modules : sequence, optional
        Sequence of modules for which to reset warnings registry on entry and
        restore on exit. To work correctly, all 'ignore' filters should
        filter by one of these modules.

    Examples
    --------
    >>> import warnings
    >>> with np.testing.clear_and_catch_warnings(
    ...         modules=[np.core.fromnumeric]):
    ...     warnings.simplefilter('always')
    ...     warnings.filterwarnings('ignore', module='np.core.fromnumeric')
    ...     # do something that raises a warning but ignore those in
    ...     # np.core.fromnumeric
    """
    # Subclasses may list modules here; they are always handled in addition
    # to the ones passed to the constructor.
    class_modules = ()

    def __init__(self, record=False, modules=()):
        self.modules = set(modules) | set(self.class_modules)
        # module -> snapshot of its registry, taken in __enter__
        self._warnreg_copies = {}
        super().__init__(record=record)

    def __enter__(self):
        for module in self.modules:
            if not hasattr(module, '__warningregistry__'):
                continue
            registry = module.__warningregistry__
            self._warnreg_copies[module] = registry.copy()
            # Empty the dict in place so the module keeps the same object.
            registry.clear()
        return super().__enter__()

    def __exit__(self, *exc_info):
        super().__exit__(*exc_info)
        for module in self.modules:
            # Drop whatever was registered inside the context ...
            if hasattr(module, '__warningregistry__'):
                module.__warningregistry__.clear()
            # ... and put the entries saved on entry back.
            if module in self._warnreg_copies:
                snapshot = self._warnreg_copies[module]
                module.__warningregistry__.update(snapshot)

2090 

2091 

class suppress_warnings:
    """
    Context manager and decorator doing much the same as
    ``warnings.catch_warnings``.

    However, it also provides a filter mechanism to work around
    https://bugs.python.org/issue4180.

    This bug causes Python before 3.4 to not reliably show warnings again
    after they have been ignored once (even within catch_warnings). It
    means that no "ignore" filter can be used easily, since following
    tests might need to see the warning. Additionally it allows easier
    specificity for testing warnings and can be nested.

    Parameters
    ----------
    forwarding_rule : str, optional
        One of "always", "once", "module", or "location". Analogous to
        the usual warnings module filter mode, it is useful to reduce
        noise mostly on the outmost level. Unsuppressed and unrecorded
        warnings will be forwarded based on this rule. Defaults to "always".
        "location" is equivalent to the warnings "default", match by exact
        location the warning originated from.

    Notes
    -----
    Filters added inside the context manager will be discarded again
    when leaving it. Upon entering all filters defined outside a
    context will be applied automatically.

    When a recording filter is added, matching warnings are stored in the
    ``log`` attribute as well as in the list returned by ``record``.

    If filters are added and the ``module`` keyword is given, the
    warning registry of this module will additionally be cleared when
    applying it, entering the context, or exiting it. This could cause
    warnings to appear a second time after leaving the context if they
    were configured to be printed once (default) and were already
    printed before the context was entered.

    Nesting this context manager will work as expected when the
    forwarding rule is "always" (default). Unfiltered and unrecorded
    warnings will be passed out and be matched by the outer level.
    On the outmost level they will be printed (or caught by another
    warnings context). The forwarding rule argument can modify this
    behaviour.

    Like ``catch_warnings`` this context manager is not threadsafe.

    Examples
    --------

    With a context manager::

        with np.testing.suppress_warnings() as sup:
            sup.filter(DeprecationWarning, "Some text")
            sup.filter(module=np.ma.core)
            log = sup.record(FutureWarning, "Does this occur?")
            command_giving_warnings()
            # The FutureWarning was given once, the filtered warnings were
            # ignored. All other warnings abide outside settings (may be
            # printed/error)
            assert_(len(log) == 1)
            assert_(len(sup.log) == 1)  # also stored in log attribute

    Or as a decorator::

        sup = np.testing.suppress_warnings()
        sup.filter(module=np.ma.core)  # module must match exactly
        @sup
        def some_function():
            # do something which causes a warning in np.ma.core
            pass
    """
    def __init__(self, forwarding_rule="always"):
        self._entered = False

        # Suppressions are either instance or defined inside one with block:
        self._suppressions = []

        if forwarding_rule not in {"always", "module", "once", "location"}:
            raise ValueError("unsupported forwarding rule.")
        self._forwarding_rule = forwarding_rule

    def _clear_registries(self):
        if hasattr(warnings, "_filters_mutated"):
            # clearing the registry should not be necessary on new pythons,
            # instead the filters should be mutated.
            warnings._filters_mutated()
            return
        # Simply clear the registry, this should normally be harmless,
        # note that on new pythons it would be invalidated anyway.
        for module in self._tmp_modules:
            if hasattr(module, "__warningregistry__"):
                module.__warningregistry__.clear()

    def _install_filter(self, category, message, module):
        """Install an "always" filter so matching warnings reach the hook."""
        if module is None:
            warnings.filterwarnings(
                "always", category=category, message=message)
            return
        # The module name is matched exactly: dots escaped, end anchored.
        module_regex = module.__name__.replace('.', r'\.') + '$'
        warnings.filterwarnings(
            "always", category=category, message=message,
            module=module_regex)
        self._tmp_modules.add(module)

    def _filter(self, category=Warning, message="", module=None, record=False):
        """Shared implementation of `filter` and `record`.

        Returns the per-filter log list when `record` is true, else None.
        """
        if record:
            record = []  # The log where to store warnings
        else:
            record = None
        if self._entered:
            self._install_filter(category, message, module)
            if module is not None:
                self._clear_registries()
            # Filters added while entered live only until __exit__.
            target = self._tmp_suppressions
        else:
            target = self._suppressions
        target.append(
            (category, message, re.compile(message, re.I), module, record))

        return record

    def filter(self, category=Warning, message="", module=None):
        """
        Add a new suppressing filter or apply it if the state is entered.

        Parameters
        ----------
        category : class, optional
            Warning class to filter
        message : string, optional
            Regular expression matching the warning message.
        module : module, optional
            Module to filter for. Note that the module (and its file)
            must match exactly and cannot be a submodule. This may make
            it unreliable for external modules.

        Notes
        -----
        When added within a context, filters are only added inside
        the context and will be forgotten when the context is exited.
        """
        self._filter(category=category, message=message, module=module,
                     record=False)

    def record(self, category=Warning, message="", module=None):
        """
        Append a new recording filter or apply it if the state is entered.

        All warnings matching will be appended to the ``log`` attribute.

        Parameters
        ----------
        category : class, optional
            Warning class to filter
        message : string, optional
            Regular expression matching the warning message.
        module : module, optional
            Module to filter for. Note that the module (and its file)
            must match exactly and cannot be a submodule. This may make
            it unreliable for external modules.

        Returns
        -------
        log : list
            A list which will be filled with all matched warnings.

        Notes
        -----
        When added within a context, filters are only added inside
        the context and will be forgotten when the context is exited.
        """
        return self._filter(category=category, message=message, module=module,
                            record=True)

    def __enter__(self):
        if self._entered:
            raise RuntimeError("cannot enter suppress_warnings twice.")

        # Remember the outside state and work on a copy of the filter list.
        self._orig_show = warnings.showwarning
        self._filters = warnings.filters
        warnings.filters = self._filters[:]

        self._entered = True
        self._tmp_suppressions = []
        self._tmp_modules = set()
        self._forwarded = set()

        self.log = []  # reset global log (no need to keep same list)

        for cat, mess, _, mod, log in self._suppressions:
            if log is not None:
                del log[:]  # clear the log
            self._install_filter(cat, mess, mod)
        warnings.showwarning = self._showwarning
        self._clear_registries()

        return self

    def __exit__(self, *exc_info):
        warnings.showwarning = self._orig_show
        warnings.filters = self._filters
        self._clear_registries()
        self._entered = False
        del self._orig_show
        del self._filters

    @staticmethod
    def _message_text(message):
        """Return the text of a warning that message patterns are matched on.

        The first argument of the warning is used when it is a string.
        Warnings created without arguments, or with a non-string first
        argument, fall back to ``str(message)`` instead of raising inside
        the warning hook.
        """
        args = getattr(message, 'args', ())
        if args and isinstance(args[0], str):
            return args[0]
        return str(message)

    def _forward(self, message, category, filename, lineno, args, kwargs,
                 use_warnmsg):
        """Hand a warning that no filter claimed to the outside handler."""
        # ``_orig_showmsg`` is not assigned anywhere in this class; it is
        # only honoured if something else provided it. Otherwise the saved
        # ``showwarning`` is used, so passing `use_warnmsg` cannot raise
        # AttributeError.
        showmsg = getattr(self, '_orig_showmsg', None)
        if use_warnmsg is not None and showmsg is not None:
            showmsg(use_warnmsg)
        else:
            self._orig_show(message, category, filename, lineno,
                            *args, **kwargs)

    def _showwarning(self, message, category, filename, lineno,
                     *args, use_warnmsg=None, **kwargs):
        text = self._message_text(message)
        # Newest filters are consulted first.
        for cat, _, pattern, mod, rec in (
                self._suppressions + self._tmp_suppressions)[::-1]:
            if not issubclass(category, cat) or pattern.match(text) is None:
                continue
            # Use startswith, because warnings strips the c or o from
            # .pyc/.pyo files.
            if mod is None or mod.__file__.startswith(filename):
                # Message, category (and module) match: record or ignore.
                if rec is not None:
                    msg = WarningMessage(message, category, filename,
                                         lineno, **kwargs)
                    self.log.append(msg)
                    rec.append(msg)
                return

        # There is no filter in place, so pass to the outside handler
        # unless we should only pass it once
        if self._forwarding_rule != "always":
            if self._forwarding_rule == "once":
                signature = (message.args, category)
            elif self._forwarding_rule == "module":
                signature = (message.args, category, filename)
            elif self._forwarding_rule == "location":
                signature = (message.args, category, filename, lineno)

            if signature in self._forwarded:
                return
            self._forwarded.add(signature)

        self._forward(message, category, filename, lineno, args, kwargs,
                      use_warnmsg)

    def __call__(self, func):
        """
        Function decorator to apply certain suppressions to a whole
        function.
        """
        @wraps(func)
        def new_func(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return new_func

2368 

2369 

@contextlib.contextmanager
def _assert_no_gc_cycles_context(name=None):
    """
    Context manager that fails if its body leaves reference cycles behind.

    Parameters
    ----------
    name : str, optional
        Name of the callable under test. When given, it is included in the
        failure message.

    Raises
    ------
    RuntimeError
        If 100 collection passes before the body still find garbage.
    AssertionError
        If the collection run after the body reports unreachable objects.
    """
    __tracebackhide__ = True  # Hide traceback for py.test

    # not meaningful to test if there is no refcounting
    if not HAS_REFCOUNT:
        yield
        return

    assert_(gc.isenabled())
    # Automatic collection is switched off for the duration of the check;
    # the ``finally`` clause below switches it back on.
    gc.disable()
    gc_debug = gc.get_debug()
    try:
        # Start from a clean slate: collect until a pass finds nothing.
        for i in range(100):
            if gc.collect() == 0:
                break
        else:
            raise RuntimeError(
                "Unable to fully collect garbage - perhaps a __del__ method "
                "is creating more reference cycles?")

        # With DEBUG_SAVEALL the collector keeps what it finds in
        # ``gc.garbage``, so the objects can be listed in the error message.
        gc.set_debug(gc.DEBUG_SAVEALL)
        yield
        # gc.collect returns the number of unreachable objects in cycles that
        # were found -- we are checking that no cycles were created in the context
        n_objects_in_cycles = gc.collect()
        objects_in_cycles = gc.garbage[:]
    finally:
        # Restore the collector state whether or not the body raised.
        del gc.garbage[:]
        gc.set_debug(gc_debug)
        gc.enable()

    if n_objects_in_cycles:
        name_str = f' when calling {name}' if name is not None else ''
        raise AssertionError(
            "Reference cycles were found{}: {} objects were collected, "
            "of which {} are shown below:{}"
            .format(
                name_str,
                n_objects_in_cycles,
                len(objects_in_cycles),
                ''.join(
                    "\n {} object with id={}:\n {}".format(
                        type(o).__name__,
                        id(o),
                        pprint.pformat(o).replace('\n', '\n ')
                    ) for o in objects_in_cycles
                )
            )
        )

2420 

2421 

def assert_no_gc_cycles(*args, **kwargs):
    """
    Fail if the given callable produces any reference cycles.

    If called with all arguments omitted, may be used as a context manager::

        with assert_no_gc_cycles():
            do_something()

    .. versionadded:: 1.15.0

    Parameters
    ----------
    func : callable
        The callable to test.
    \\*args : Arguments
        Arguments passed to `func`.
    \\*\\*kwargs : Kwargs
        Keyword arguments passed to `func`.

    Returns
    -------
    Nothing. The result is deliberately discarded to ensure that all cycles
    are found.

    """
    if not args:
        return _assert_no_gc_cycles_context()

    func, *call_args = args
    try:
        name = func.__name__
    except AttributeError:
        # Not every callable has ``__name__`` (e.g. functools.partial
        # objects, instances defining __call__); fall back to the repr so
        # such callables can still be checked.
        name = repr(func)
    with _assert_no_gc_cycles_context(name=name):
        # The return value is dropped on purpose, see "Returns" above.
        func(*call_args, **kwargs)

2455 

def break_cycles():
    """
    Break reference cycles by calling gc.collect

    Objects can call other objects' methods (for instance, another object's
    __del__) inside their own __del__. On PyPy, the interpreter only runs
    between calls to gc.collect, so multiple calls are needed to completely
    release all cycles.
    """
    # One pass always; on PyPy a few more, just to make sure all the
    # finalizers are called.
    passes = 5 if IS_PYPY else 1
    for _ in range(passes):
        gc.collect()

2472 

2473 

def requires_memory(free_bytes):
    """Decorator to skip a test if not enough memory is available"""
    # Imported here, so pytest is only needed once the decorator is used.
    import pytest

    def decorator(func):
        @wraps(func)
        def wrapper(*a, **kw):
            skip_reason = check_free_memory(free_bytes)
            if skip_reason is not None:
                pytest.skip(skip_reason)

            try:
                result = func(*a, **kw)
            except MemoryError:
                # Probably ran out of memory regardless: don't regard as failure
                pytest.xfail("MemoryError raised")
            else:
                return result

        return wrapper

    return decorator

2494 

2495 

def check_free_memory(free_bytes):
    """
    Check whether `free_bytes` amount of memory is currently free.

    The ``NPY_AVAILABLE_MEM`` environment variable, when set, takes the
    place of the measured amount of available memory.

    Returns: None if enough memory available, otherwise error message

    Raises ValueError if ``NPY_AVAILABLE_MEM`` is set to a value that
    cannot be parsed as a size.
    """
    env_var = 'NPY_AVAILABLE_MEM'
    env_value = os.environ.get(env_var)

    if env_value is None:
        mem_free = _get_mem_available()
        if mem_free is None:
            msg = ("Could not determine available memory; set NPY_AVAILABLE_MEM "
                   "environment variable (e.g. NPY_AVAILABLE_MEM=16GB) to run "
                   "the test.")
            # Unknown amount: -1 makes the comparison below report `msg`
            # for any non-negative request.
            mem_free = -1
        else:
            msg = f'{free_bytes/1e9} GB memory required, but {mem_free/1e9} GB available'
    else:
        try:
            mem_free = _parse_size(env_value)
        except ValueError as exc:
            raise ValueError(f'Invalid environment variable {env_var}: {exc}')

        msg = (f'{free_bytes/1e9} GB memory required, but environment variable '
               f'NPY_AVAILABLE_MEM={env_value} set')

    if mem_free < free_bytes:
        return msg
    return None

2523 

2524 

def _parse_size(size_str):
    """Convert memory size strings ('12 GB' etc.) to an int number of bytes.

    A non-negative integer or decimal number may be followed by a unit:
    decimal units ('k', 'kb', 'm', 'mb', ...) are powers of 1000, binary
    units ('kib', 'mib', ...) are powers of 1024, and 'b' or no unit means
    bytes. Matching ignores case and surrounding whitespace.

    Raises ValueError if `size_str` does not have this form.
    """
    suffixes = {'': 1, 'b': 1,
                'k': 1000, 'm': 1000**2, 'g': 1000**3, 't': 1000**4,
                'kb': 1000, 'mb': 1000**2, 'gb': 1000**3, 'tb': 1000**4,
                'kib': 1024, 'mib': 1024**2, 'gib': 1024**3, 'tib': 1024**4}

    unit_alternatives = '|'.join(suffixes.keys())
    size_re = re.compile(
        r'^\s*(\d+|\d+\.\d+)\s*({0})\s*$'.format(unit_alternatives), re.I)

    match = size_re.match(size_str.lower())
    unit = match.group(2) if match else None
    if unit not in suffixes:
        raise ValueError(f'value {size_str!r} not a valid size')

    number = float(match.group(1))
    return int(number * suffixes[unit])

2539 

2540 

def _get_mem_available():
    """Return available memory in bytes, or None if unknown."""
    try:
        import psutil
        return psutil.virtual_memory().available
    except (ImportError, AttributeError):
        # psutil is not importable or lacks the attributes used above;
        # fall through to the platform-specific lookup.
        pass

    if not sys.platform.startswith('linux'):
        return None

    # Each line is read as "<Name>: <number> ..."; the number is scaled by
    # 1024 (presumably because the file reports kB -- confirm).
    with open('/proc/meminfo', 'r') as meminfo:
        table = {
            fields[0].strip(':').lower(): int(fields[1]) * 1024
            for fields in (row.split() for row in meminfo)
        }

    if 'memavailable' in table:
        # Linux >= 3.14
        return table['memavailable']
    return table['memfree'] + table['cached']

2563 

2564 

def _no_tracing(func):
    """
    Decorator to temporarily turn off tracing for the duration of a test.
    Needed in tests that check refcounting, otherwise the tracing itself
    influences the refcounts
    """
    # Nothing to switch off on interpreters without sys.gettrace.
    if not hasattr(sys, 'gettrace'):
        return func

    @wraps(func)
    def untraced(*args, **kwargs):
        saved_trace = sys.gettrace()
        try:
            sys.settrace(None)
            return func(*args, **kwargs)
        finally:
            # Reinstall whatever trace function was active before the call.
            sys.settrace(saved_trace)

    return untraced

2583 

2584 

def _get_glibc_version():
    """Return the glibc version as a string such as '2.31'.

    '0.0' is returned when the version cannot be determined.
    """
    try:
        ver = os.confstr('CS_GNU_LIBC_VERSION').rsplit(' ')[1]
    except Exception:
        # Deliberately broad, best-effort lookup: any failure of the call
        # or of parsing its result means "unknown".
        ver = '0.0'

    return ver


_glibcver = _get_glibc_version()


def _glibc_version_key(ver):
    """Turn a version string such as '2.17' into a tuple of ints."""
    return tuple(int(number) for number in re.findall(r'\d+', ver))


def _glibc_older_than(x):
    """Return True if the detected glibc is older than version `x`.

    Always False when the glibc version is unknown ('0.0'). Versions are
    compared component by component as numbers; comparing the strings
    themselves would order '2.9' after '2.17'.
    """
    return (_glibcver != '0.0' and
            _glibc_version_key(_glibcver) < _glibc_version_key(x))