1"""
Utility functions to facilitate testing.
3
4"""
5import os
6import sys
7import platform
8import re
9import gc
10import operator
11import warnings
12from functools import partial, wraps
13import shutil
14import contextlib
15from tempfile import mkdtemp, mkstemp
16from unittest.case import SkipTest
17from warnings import WarningMessage
18import pprint
19
20import numpy as np
21from numpy.core import(
22 intp, float32, empty, arange, array_repr, ndarray, isnat, array)
23import numpy.linalg.lapack_lite
24
25from io import StringIO
26
# Names exported by a star-import of this module.
__all__ = [
        'assert_equal', 'assert_almost_equal', 'assert_approx_equal',
        'assert_array_equal', 'assert_array_less', 'assert_string_equal',
        'assert_array_almost_equal', 'assert_raises', 'build_err_msg',
        'decorate_methods', 'jiffies', 'memusage', 'print_assert_equal',
        'raises', 'rundocs', 'runstring', 'verbose', 'measure',
        'assert_', 'assert_array_almost_equal_nulp', 'assert_raises_regex',
        'assert_array_max_ulp', 'assert_warns', 'assert_no_warnings',
        'assert_allclose', 'IgnoreException', 'clear_and_catch_warnings',
        'SkipTest', 'KnownFailureException', 'temppath', 'tempdir', 'IS_PYPY',
        'HAS_REFCOUNT', "IS_WASM", 'suppress_warnings', 'assert_array_compare',
        'assert_no_gc_cycles', 'break_cycles', 'HAS_LAPACK64', 'IS_PYSTON',
        '_OLD_PROMOTION'
        ]
41
42
class KnownFailureException(Exception):
    """Exception a test raises to mark itself as a known failing test."""
46
47
# Alias of KnownFailureException under its older name.
KnownFailureTest = KnownFailureException  # backwards compat
# Module-level verbosity setting; starts at 0.
verbose = 0

# Interpreter / platform flags, each evaluated once when the module is
# imported.
IS_WASM = platform.machine() in ["wasm32", "wasm64"]
IS_PYPY = sys.implementation.name == 'pypy'
IS_PYSTON = hasattr(sys, "pyston_version_info")
# True only when ``sys.getrefcount`` is available and we are not on Pyston.
HAS_REFCOUNT = getattr(sys, 'getrefcount', None) is not None and not IS_PYSTON
# Mirrors lapack_lite's ``_ilp64`` attribute
# (presumably "built against 64-bit-integer LAPACK" -- TODO confirm).
HAS_LAPACK64 = numpy.linalg.lapack_lite._ilp64

# A callable, not a constant: each call re-queries numpy's promotion state
# and reports whether it is 'legacy'.
_OLD_PROMOTION = lambda: np._get_promotion_state() == 'legacy'
58
59
def import_nose():
    """
    Return the ``nose`` module, importing it only when it is needed.

    Raises
    ------
    ImportError
        If nose cannot be imported or its ``__versioninfo__`` is older
        than 1.0.0.

    """
    minimum_nose_version = (1, 0, 0)
    try:
        import nose
    except ImportError:
        nose = None

    # The error is raised outside the ``except`` block so that it does not
    # carry the original import failure as exception context.
    unusable = nose is None or nose.__versioninfo__ < minimum_nose_version
    if unusable:
        raise ImportError('Need nose >= %d.%d.%d for tests - see '
                          'https://nose.readthedocs.io' %
                          minimum_nose_version)
    return nose
80
81
def assert_(val, msg=''):
    """
    Assert that works in release mode.

    Unlike the built-in ``assert`` statement, which produces no byte-code
    under the ``-O`` flag, this check always runs.

    Parameters
    ----------
    val : object
        Value whose truthiness is tested.
    msg : str or callable, optional
        Failure message. A callable is invoked without arguments, and only
        on failure, so that building an expensive message can be deferred.
        If calling `msg` raises ``TypeError``, `msg` itself is used as the
        message.

    Raises
    ------
    AssertionError
        If `val` is falsy.

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    if val:
        return
    try:
        text = msg()
    except TypeError:
        # Not callable (e.g. a plain string): use it verbatim.
        text = msg
    raise AssertionError(text)
100
101
def gisnan(x):
    """Like ``isnan``, but raise TypeError instead of returning NotImplemented.

    Notes
    -----
    isnan and other ufuncs sometimes return the ``NotImplemented`` singleton
    instead of raising an exception. This wrapper turns that result into a
    ``TypeError`` so that callers always get an exception for unsupported
    types.

    This should be removed once this problem is solved at the Ufunc level."""
    from numpy.core import isnan
    result = isnan(x)
    # NotImplemented is a singleton, so an identity check is sufficient.
    if result is NotImplemented:
        raise TypeError("isnan not supported for this type")
    return result
118
119
def gisfinite(x):
    """Like ``isfinite``, but raise TypeError instead of returning
    NotImplemented.

    Notes
    -----
    isfinite and other ufuncs sometimes return the ``NotImplemented``
    singleton instead of raising an exception. This wrapper turns that result
    into a ``TypeError``. The ufunc is evaluated with floating-point
    "invalid" errors ignored.

    This should be removed once this problem is solved at the Ufunc level."""
    from numpy.core import isfinite, errstate
    with errstate(invalid='ignore'):
        result = isfinite(x)
    # NotImplemented is a singleton, so an identity check is sufficient.
    if result is NotImplemented:
        raise TypeError("isfinite not supported for this type")
    return result
137
138
def gisinf(x):
    """Like ``isinf``, but raise TypeError instead of returning NotImplemented.

    Notes
    -----
    isinf and other ufuncs sometimes return the ``NotImplemented`` singleton
    instead of raising an exception. This wrapper turns that result into a
    ``TypeError``. The ufunc is evaluated with floating-point "invalid"
    errors ignored.

    This should be removed once this problem is solved at the Ufunc level."""
    from numpy.core import isinf, errstate
    with errstate(invalid='ignore'):
        result = isinf(x)
    # NotImplemented is a singleton, so an identity check is sufficient.
    if result is NotImplemented:
        raise TypeError("isinf not supported for this type")
    return result
156
157
if os.name == 'nt':
    # Code "stolen" from enthought/debug/memusage.py
    def GetPerformanceAttributes(object, counter, instance=None,
                                 inum=-1, format=None, machine=None):
        """
        Read one formatted value from a Windows performance counter.

        A counter path is built from the arguments, a query is opened, one
        sample is collected and the formatted value is returned. The counter
        and the query are always released, even on error. `format` defaults
        to ``win32pdh.PDH_FMT_LONG``.

        """
        # NOTE: Many counters require 2 samples to give accurate results,
        # including "% Processor Time" (as by definition, at any instant, a
        # thread's CPU usage is either 0 or 100).  To read counters like this,
        # you should copy this function, but keep the counter open, and call
        # CollectQueryData() each time you need to know.
        # See http://msdn.microsoft.com/library/en-us/dnperfmo/html/perfmonpt2.asp (dead link)
        # My older explanation for this was that the "AddCounter" process
        # forced the CPU to 100%, but the above makes more sense :)
        import win32pdh
        if format is None:
            format = win32pdh.PDH_FMT_LONG
        path = win32pdh.MakeCounterPath((machine, object, instance, None,
                                         inum, counter))
        hq = win32pdh.OpenQuery()
        try:
            hc = win32pdh.AddCounter(hq, path)
            try:
                win32pdh.CollectQueryData(hq)
                # The first element is the counter type; only the value is
                # needed (and binding it to ``type`` would shadow the builtin).
                _, val = win32pdh.GetFormattedCounterValue(hc, format)
                return val
            finally:
                win32pdh.RemoveCounter(hc)
        finally:
            win32pdh.CloseQuery(hq)

    def memusage(processName="python", instance=0):
        """
        Return the "Virtual Bytes" counter of the "Process" object for the
        given process name and instance.

        """
        # from win32pdhutil, part of the win32all package
        import win32pdh
        return GetPerformanceAttributes("Process", "Virtual Bytes",
                                        processName, instance,
                                        win32pdh.PDH_FMT_LONG, None)
elif sys.platform[:5] == 'linux':

    def memusage(_proc_pid_stat=f'/proc/{os.getpid()}/stat'):
        """
        Return virtual memory size in bytes of the running python.

        The value is field 23 (``vsize``, see man 5 proc) of the stat file.
        Returns None if the file cannot be read or parsed.

        """
        try:
            with open(_proc_pid_stat, 'r') as f:
                stat_line = f.readline()
            # Field 2 (comm) is the parenthesised executable name and may
            # itself contain spaces or ')'. Splitting the whole line on
            # spaces would then shift every later field, so parse what
            # follows the LAST ')': that tail starts at field 3.
            _, sep, tail = stat_line.rpartition(')')
            fields = tail.split()
            # Field 23 is index 20 of the tail; without any ')' the tail is
            # the whole line and the field sits at index 22.
            return int(fields[20 if sep else 22])
        except Exception:
            # Deliberate best-effort: report "unknown" as None.
            return
else:
    def memusage():
        """
        Return memory usage of running python. [Not implemented]

        """
        raise NotImplementedError
213
214
if sys.platform[:5] == 'linux':
    def jiffies(_proc_pid_stat=f'/proc/{os.getpid()}/stat', _load_time=[]):
        """
        Return number of jiffies elapsed.

        Return number of jiffies (1/100ths of a second) that this
        process has been scheduled in user mode. See man 5 proc.

        If the stat file cannot be read or parsed, fall back to the
        wall-clock time, in 1/100ths of a second, since the first call.

        """
        import time
        # ``_load_time`` is a mutable default on purpose: it remembers the
        # time of the first call across later calls.
        if not _load_time:
            _load_time.append(time.time())
        try:
            with open(_proc_pid_stat, 'r') as f:
                stat_line = f.readline()
            # Field 2 (comm) is the parenthesised executable name and may
            # itself contain spaces or ')'. Parse what follows the LAST ')'
            # so such names cannot shift the field positions; that tail
            # starts at field 3.
            _, sep, tail = stat_line.rpartition(')')
            fields = tail.split()
            # utime is field 14: index 11 of the tail, or index 13 of the
            # whole line when no ')' is present.
            return int(fields[11 if sep else 13])
        except Exception:
            return int(100*(time.time()-_load_time[0]))
else:
    # os.getpid is not in all platforms available.
    # Using time is safe but inaccurate, especially when process
    # was suspended or sleeping.
    def jiffies(_load_time=[]):
        """
        Return number of jiffies elapsed.

        On this platform the value is only an approximation: the wall-clock
        time, in 1/100ths of a second, elapsed since the first call.

        """
        import time
        # ``_load_time`` is a mutable default on purpose: it remembers the
        # time of the first call across later calls.
        if not _load_time:
            _load_time.append(time.time())
        return int(100*(time.time()-_load_time[0]))
249
250
def build_err_msg(arrays, err_msg, header='Items are not equal:',
                  verbose=True, names=('ACTUAL', 'DESIRED'), precision=8):
    """
    Assemble the multi-line text used by the assertion helpers.

    Parameters
    ----------
    arrays : sequence
        Objects whose representations are appended when `verbose` is true.
    err_msg : str
        Caller-supplied message. It is placed on the header line when it is
        a single line short enough to fit, otherwise on a line of its own.
    header : str, optional
        First line of the message.
    verbose : bool, optional
        If True, append one ``<name>: <repr>`` entry per object.
    names : sequence of str, optional
        Labels for the entries of `arrays`, matched by position.
    precision : int, optional
        Precision passed to ``array_repr``; only used for ndarrays.

    Returns
    -------
    str
        The assembled message, starting with a newline.

    """
    lines = ['\n' + header]
    if err_msg:
        single_line = err_msg.find('\n') == -1
        if single_line and len(err_msg) < 79-len(header):
            lines = [lines[0] + ' ' + err_msg]
        else:
            lines.append(err_msg)
    if not verbose:
        return '\n'.join(lines)

    for pos, obj in enumerate(arrays):
        # Only ndarrays take the precision argument.
        if isinstance(obj, ndarray):
            to_text = partial(array_repr, precision=precision)
        else:
            to_text = repr

        try:
            text = to_text(obj)
        except Exception as exc:
            text = f'[repr failed for <{type(obj).__name__}>: {exc}]'
        # Keep long representations down to their first three lines.
        if text.count('\n') > 3:
            text = '\n'.join(text.splitlines()[:3])
            text += '...'
        lines.append(f' {names[pos]}: {text}')
    return '\n'.join(lines)
277
278
def assert_equal(actual, desired, err_msg='', verbose=True):
    """
    Raise an AssertionError if two objects are not equal.

    `actual` and `desired` may be scalars, lists, tuples, dictionaries or
    numpy arrays. Dictionaries are compared key by key and lists/tuples item
    by item; the exception is raised at the first conflicting pair. When
    either argument is an ndarray the comparison is delegated to
    `assert_array_equal`.

    NaN is treated like a "normal" number: no AssertionError is raised when
    both objects are NaN. This differs from the IEEE standard, under which
    NaN compared to anything is False. Two NaT values of the same dtype are
    likewise considered equal, and zeros with different signs are not.

    Parameters
    ----------
    actual : array_like
        The object to check.
    desired : array_like
        The expected object.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If actual and desired are not equal.

    Examples
    --------
    >>> np.testing.assert_equal([4,5], [4,6])
    Traceback (most recent call last):
        ...
    AssertionError:
    Items are not equal:
    item=1
     ACTUAL: 5
     DESIRED: 6

    The following comparison does not raise an exception.  There are NaNs
    in the inputs, but they are in the same positions.

    >>> np.testing.assert_equal(np.array([1.0, 2.0, np.nan]), [1, 2, np.nan])

    """
    __tracebackhide__ = True  # Hide traceback for py.test

    # Containers first: recurse into dicts by key ...
    if isinstance(desired, dict):
        if not isinstance(actual, dict):
            raise AssertionError(repr(type(actual)))
        assert_equal(len(actual), len(desired), err_msg, verbose)
        for key, _ in desired.items():
            if key not in actual:
                raise AssertionError(repr(key))
            assert_equal(actual[key], desired[key],
                         f'key={key!r}\n{err_msg}', verbose)
        return
    # ... and into lists/tuples by position.
    if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)):
        assert_equal(len(actual), len(desired), err_msg, verbose)
        for idx in range(len(desired)):
            assert_equal(actual[idx], desired[idx],
                         f'item={idx!r}\n{err_msg}', verbose)
        return

    from numpy.core import ndarray, isscalar, signbit
    from numpy.lib import iscomplexobj, real, imag
    if isinstance(actual, ndarray) or isinstance(desired, ndarray):
        return assert_array_equal(actual, desired, err_msg, verbose)
    msg = build_err_msg([actual, desired], err_msg, verbose=verbose)

    # Complex values are split into real and imaginary parts so that
    # nan/inf/negative zero are handled correctly in each part.
    # XXX: catch ValueError for subclasses of ndarray where iscomplex fail
    try:
        usecomplex = iscomplexobj(actual) or iscomplexobj(desired)
    except (ValueError, TypeError):
        usecomplex = False

    if usecomplex:
        if iscomplexobj(actual):
            actual_re, actual_im = real(actual), imag(actual)
        else:
            actual_re, actual_im = actual, 0
        if iscomplexobj(desired):
            desired_re, desired_im = real(desired), imag(desired)
        else:
            desired_re, desired_im = desired, 0
        try:
            assert_equal(actual_re, desired_re)
            assert_equal(actual_im, desired_im)
        except AssertionError:
            # Report the failure in terms of the full complex values.
            raise AssertionError(msg)

    # isscalar test to check cases such as [np.nan] != np.nan
    if isscalar(desired) != isscalar(actual):
        raise AssertionError(msg)

    # NaT handling; inputs for which isnat is not applicable raise and
    # simply fall through to the next stage.
    try:
        desired_is_nat = isnat(desired)
        actual_is_nat = isnat(actual)
        same_dtype = (np.asarray(desired).dtype.type ==
                      np.asarray(actual).dtype.type)
        if desired_is_nat and actual_is_nat:
            # Two NaT values are equal only if both are datetime or both
            # are timedelta.
            if not same_dtype:
                raise AssertionError(msg)
            return

    except (TypeError, ValueError, NotImplementedError):
        pass

    # Inf/nan/negative zero handling
    try:
        desired_is_nan = gisnan(desired)
        actual_is_nan = gisnan(actual)
        if desired_is_nan and actual_is_nan:
            return  # both nan, so equal

        # handle signed zero specially for floats
        array_actual = np.asarray(actual)
        array_desired = np.asarray(desired)
        if (array_actual.dtype.char in 'Mm' or
                array_desired.dtype.char in 'Mm'):
            # version 1.18
            # until this version, gisnan failed for datetime64 and timedelta64.
            # Now it succeeds but comparison to scalar with a different type
            # emits a DeprecationWarning.
            # Avoid that by skipping the next check
            raise NotImplementedError('cannot compare to a scalar '
                                      'with a different type')

        if (desired == 0 and actual == 0
                and not signbit(desired) == signbit(actual)):
            raise AssertionError(msg)

    except (TypeError, ValueError, NotImplementedError):
        pass

    try:
        # Explicitly use __eq__ for comparison, gh-2552
        if not (desired == actual):
            raise AssertionError(msg)

    except (DeprecationWarning, FutureWarning) as e:
        # this handles the case when the two types are not even comparable
        if 'elementwise == comparison' in e.args[0]:
            raise AssertionError(msg)
        else:
            raise
437
438
def print_assert_equal(test_string, actual, desired):
    """
    Check two objects for equality and raise a descriptive error on failure.

    The check performed is ``actual == desired``. On failure the
    AssertionError message consists of `test_string` followed by
    pretty-printed versions of both objects.

    Parameters
    ----------
    test_string : str
        The message supplied to AssertionError.
    actual : object
        The object to test for equality against `desired`.
    desired : object
        The expected result.

    Raises
    ------
    AssertionError
        If ``actual == desired`` is not true.

    Examples
    --------
    >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 1])
    >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 2])
    Traceback (most recent call last):
    ...
    AssertionError: Test XYZ of func xyz failed
    ACTUAL:
    [0, 1]
    DESIRED:
    [0, 2]

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import pprint

    if actual == desired:
        return

    # Build the report in a text buffer so pprint can write straight into it.
    report = StringIO()
    report.write(test_string)
    report.write(' failed\nACTUAL: \n')
    pprint.pprint(actual, report)
    report.write('DESIRED: \n')
    pprint.pprint(desired, report)
    raise AssertionError(report.getvalue())
478
479
@np._no_nep50_warning()
def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True):
    """
    Raises an AssertionError if two items are not equal up to desired
    precision.

    .. note:: It is recommended to use one of `assert_allclose`,
              `assert_array_almost_equal_nulp` or `assert_array_max_ulp`
              instead of this function for more consistent floating point
              comparisons.

    The test verifies that the elements of `actual` and `desired` satisfy.

        ``abs(desired-actual) < float64(1.5 * 10**(-decimal))``

    That is a looser test than originally documented, but agrees with what the
    actual implementation in `assert_array_almost_equal` did up to rounding
    vagaries. An exception is raised at conflicting values. For ndarrays,
    lists and tuples this delegates to `assert_array_almost_equal`, passing
    `decimal`, `err_msg` and `verbose` through.

    Parameters
    ----------
    actual : array_like
        The object to check.
    desired : array_like
        The expected object.
    decimal : int, optional
        Desired precision, default is 7.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
      If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Examples
    --------
    >>> from numpy.testing import assert_almost_equal
    >>> assert_almost_equal(2.3333333333333, 2.33333334)
    >>> assert_almost_equal(2.3333333333333, 2.33333334, decimal=10)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 10 decimals
     ACTUAL: 2.3333333333333
     DESIRED: 2.33333334

    >>> assert_almost_equal(np.array([1.0,2.3333333333333]),
    ...                     np.array([1.0,2.33333334]), decimal=9)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 9 decimals
    <BLANKLINE>
    Mismatched elements: 1 / 2 (50%)
    Max absolute difference: 6.66669964e-09
    Max relative difference: 2.85715698e-09
     x: array([1.         , 2.333333333])
     y: array([1.        , 2.33333334])

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    from numpy.core import ndarray
    from numpy.lib import iscomplexobj, real, imag

    # Handle complex numbers: separate into real/imag to handle
    # nan/inf/negative zero correctly
    # XXX: catch ValueError for subclasses of ndarray where iscomplex fail
    try:
        usecomplex = iscomplexobj(actual) or iscomplexobj(desired)
    except ValueError:
        usecomplex = False

    def _build_err_msg():
        # Built lazily: only needed once a failure has been detected.
        header = ('Arrays are not almost equal to %d decimals' % decimal)
        return build_err_msg([actual, desired], err_msg, verbose=verbose,
                             header=header)

    if usecomplex:
        if iscomplexobj(actual):
            actualr = real(actual)
            actuali = imag(actual)
        else:
            actualr = actual
            actuali = 0
        if iscomplexobj(desired):
            desiredr = real(desired)
            desiredi = imag(desired)
        else:
            desiredr = desired
            desiredi = 0
        try:
            assert_almost_equal(actualr, desiredr, decimal=decimal)
            assert_almost_equal(actuali, desiredi, decimal=decimal)
        except AssertionError:
            # Report the failure in terms of the full complex values.
            raise AssertionError(_build_err_msg())

    if isinstance(actual, (ndarray, tuple, list)) \
            or isinstance(desired, (ndarray, tuple, list)):
        # Forward ``verbose`` as well; it used to be dropped here, so
        # ``verbose=False`` had no effect for array-like inputs.
        return assert_array_almost_equal(actual, desired, decimal, err_msg,
                                         verbose)
    try:
        # If one of desired/actual is not finite, handle it specially here:
        # check that both are nan if any is a nan, and test for equality
        # otherwise
        if not (gisfinite(desired) and gisfinite(actual)):
            if gisnan(desired) or gisnan(actual):
                if not (gisnan(desired) and gisnan(actual)):
                    raise AssertionError(_build_err_msg())
            else:
                if not desired == actual:
                    raise AssertionError(_build_err_msg())
            return
    except (NotImplementedError, TypeError):
        pass
    if abs(desired - actual) >= np.float64(1.5 * 10.0**(-decimal)):
        raise AssertionError(_build_err_msg())
605
606
@np._no_nep50_warning()
def assert_approx_equal(actual, desired, significant=7, err_msg='',
                        verbose=True):
    """
    Raise an AssertionError if two numbers do not agree to the requested
    number of significant digits.

    .. note:: It is recommended to use one of `assert_allclose`,
              `assert_array_almost_equal_nulp` or `assert_array_max_ulp`
              instead of this function for more consistent floating point
              comparisons.

    Both inputs are converted to ``float`` and divided by a common power of
    ten derived from their mean magnitude; the scaled values must then
    differ by less than ``10**-(significant-1)``.

    Parameters
    ----------
    actual : scalar
        The object to check.
    desired : scalar
        The expected object.
    significant : int, optional
        Desired precision, default is 7.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
      If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Examples
    --------
    >>> np.testing.assert_approx_equal(0.12345677777777e-20, 0.1234567e-20)
    >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345671e-20,
    ...                                significant=8)
    >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345672e-20,
    ...                                significant=8)
    Traceback (most recent call last):
        ...
    AssertionError:
    Items are not equal to 8 significant digits:
     ACTUAL: 1.234567e-21
     DESIRED: 1.2345672e-21

    the evaluated condition that raises the exception is

    >>> abs(0.12345670e-20/1e-21 - 0.12345672e-20/1e-21) >= 10**-(8-1)
    True

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np

    actual, desired = float(actual), float(desired)
    if desired == actual:
        return

    # Scale both numbers by the power of ten of their mean magnitude so that
    # they land in the range (-10.0, 10.0).
    with np.errstate(invalid='ignore'):
        magnitude = 0.5*(np.abs(desired) + np.abs(actual))
        scale = np.power(10, np.floor(np.log10(magnitude)))
    try:
        sc_desired = desired/scale
    except ZeroDivisionError:
        sc_desired = 0.0
    try:
        sc_actual = actual/scale
    except ZeroDivisionError:
        sc_actual = 0.0

    msg = build_err_msg(
        [actual, desired], err_msg,
        header='Items are not equal to %d significant digits:' % significant,
        verbose=verbose)
    try:
        # Non-finite inputs are decided here: when a nan is involved both
        # must be nan, otherwise the values must compare equal.
        both_finite = gisfinite(desired) and gisfinite(actual)
        if not both_finite:
            if gisnan(desired) or gisnan(actual):
                if not (gisnan(desired) and gisnan(actual)):
                    raise AssertionError(msg)
            elif not desired == actual:
                raise AssertionError(msg)
            return
    except (TypeError, NotImplementedError):
        pass

    tolerance = np.power(10., -(significant-1))
    if np.abs(sc_desired - sc_actual) >= tolerance:
        raise AssertionError(msg)
705
706
@np._no_nep50_warning()
def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
                         precision=6, equal_nan=True, equal_inf=True,
                         *, strict=False):
    """
    Compare two array_like objects with `comparison` and raise on mismatch.

    Parameters
    ----------
    comparison : callable
        Called as ``comparison(x, y)`` on the arrays (after flagged
        positions have been removed). It must return either a Python bool
        or an object providing ``ravel()``.
    x, y : array_like
        Objects to compare; both are converted with ``np.asanyarray``.
    err_msg : str, optional
        Message included in the failure report.
    verbose : bool, optional
        Passed to `build_err_msg`.
    header : str, optional
        First line of the failure report.
    precision : int, optional
        Passed to `build_err_msg`.
    equal_nan : bool, optional
        If True, nan (or NaT for datetime/timedelta inputs of the same
        type) must be at the same positions in `x` and `y`; those
        positions are then left out of the comparison.
    equal_inf : bool, optional
        If True, the same is done for ``+inf`` and ``-inf``.
    strict : bool, optional
        If True, shapes and dtypes must both match. Otherwise shapes must
        match unless one of the inputs has shape ``()``.

    Raises
    ------
    AssertionError
        On a shape/dtype mismatch, a nan/inf/NaT location mismatch, or when
        the comparison does not hold everywhere.
    ValueError
        If a ValueError occurs while comparing; it is re-raised with the
        traceback text prepended to the report header.

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # NOTE: ``all`` and ``max`` below intentionally shadow the builtins
    # inside this function.
    from numpy.core import array, array2string, isnan, inf, bool_, errstate, all, max, object_

    x = np.asanyarray(x)
    y = np.asanyarray(y)

    # original array for output formatting
    ox, oy = x, y

    def isnumber(x):
        # True for dtypes whose type character is one of the listed
        # boolean/integer/float/complex codes.
        return x.dtype.char in '?bhilqpBHILQPefdgFDG'

    def istime(x):
        # True for datetime64 ('M') and timedelta64 ('m') dtypes.
        return x.dtype.char in "Mm"

    def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
        """Handling nan/inf.

        Combine results of running func on x and y, checking that they are True
        at the same locations.

        """
        __tracebackhide__ = True  # Hide traceback for py.test

        x_id = func(x)
        y_id = func(y)
        # We include work-arounds here to handle three types of slightly
        # pathological ndarray subclasses:
        # (1) all() on `masked` array scalars can return masked arrays, so we
        #     use != True
        # (2) __eq__ on some ndarray subclasses returns Python booleans
        #     instead of element-wise comparisons, so we cast to bool_() and
        #     use isinstance(..., bool) checks
        # (3) subclasses with bare-bones __array_function__ implementations may
        #     not implement np.all(), so favor using the .all() method
        # We are not committed to supporting such subclasses, but it's nice to
        # support them if possible.
        if bool_(x_id == y_id).all() != True:
            msg = build_err_msg([x, y],
                                err_msg + '\nx and y %s location mismatch:'
                                % (hasval), verbose=verbose, header=header,
                                names=('x', 'y'), precision=precision)
            raise AssertionError(msg)
        # If there is a scalar, then here we know the array has the same
        # flag as it everywhere, so we should return the scalar flag.
        if isinstance(x_id, bool) or x_id.ndim == 0:
            return bool_(x_id)
        elif isinstance(y_id, bool) or y_id.ndim == 0:
            return bool_(y_id)
        else:
            return y_id

    try:
        # Step 1: shape (and, in strict mode, dtype) compatibility.
        if strict:
            cond = x.shape == y.shape and x.dtype == y.dtype
        else:
            cond = (x.shape == () or y.shape == ()) or x.shape == y.shape
        if not cond:
            if x.shape != y.shape:
                reason = f'\n(shapes {x.shape}, {y.shape} mismatch)'
            else:
                reason = f'\n(dtypes {x.dtype}, {y.dtype} mismatch)'
            msg = build_err_msg([x, y],
                                err_msg
                                + reason,
                                verbose=verbose, header=header,
                                names=('x', 'y'), precision=precision)
            raise AssertionError(msg)

        # Step 2: find positions holding nan/inf/NaT in both inputs; they
        # are excluded from the comparison below.
        flagged = bool_(False)
        if isnumber(x) and isnumber(y):
            if equal_nan:
                flagged = func_assert_same_pos(x, y, func=isnan, hasval='nan')

            if equal_inf:
                flagged |= func_assert_same_pos(x, y,
                                                func=lambda xy: xy == +inf,
                                                hasval='+inf')
                flagged |= func_assert_same_pos(x, y,
                                                func=lambda xy: xy == -inf,
                                                hasval='-inf')

        elif istime(x) and istime(y):
            # If one is datetime64 and the other timedelta64 there is no point
            if equal_nan and x.dtype.type == y.dtype.type:
                flagged = func_assert_same_pos(x, y, func=isnat, hasval="NaT")

        if flagged.ndim > 0:
            x, y = x[~flagged], y[~flagged]
            # Only do the comparison if actual values are left
            if x.size == 0:
                return
        elif flagged:
            # no sense doing comparison if everything is flagged.
            return

        # Step 3: run the comparison on what is left.
        val = comparison(x, y)

        if isinstance(val, bool):
            cond = val
            reduced = array([val])
        else:
            reduced = val.ravel()
            cond = reduced.all()

        # The below comparison is a hack to ensure that fully masked
        # results, for which val.ravel().all() returns np.ma.masked,
        # do not trigger a failure (np.ma.masked != True evaluates as
        # np.ma.masked, which is falsy).
        if cond != True:
            # Step 4: build the mismatch statistics for the report.
            n_mismatch = reduced.size - reduced.sum(dtype=intp)
            # Count against the full (pre-filter) size when positions were
            # removed element-wise.
            n_elements = flagged.size if flagged.ndim != 0 else reduced.size
            percent_mismatch = 100 * n_mismatch / n_elements
            remarks = [
                'Mismatched elements: {} / {} ({:.3g}%)'.format(
                    n_mismatch, n_elements, percent_mismatch)]

            with errstate(all='ignore'):
                # ignore errors for non-numeric types
                with contextlib.suppress(TypeError):
                    error = abs(x - y)
                    if np.issubdtype(x.dtype, np.unsignedinteger):
                        # For unsigned x, x - y may wrap around; keep the
                        # smaller of the two difference directions.
                        error2 = abs(y - x)
                        np.minimum(error, error2, out=error)
                    max_abs_error = max(error)
                    if getattr(error, 'dtype', object_) == object_:
                        remarks.append('Max absolute difference: '
                                       + str(max_abs_error))
                    else:
                        remarks.append('Max absolute difference: '
                                       + array2string(max_abs_error))

                    # note: this definition of relative error matches that one
                    # used by assert_allclose (found in np.isclose)
                    # Filter values where the divisor would be zero
                    nonzero = bool_(y != 0)
                    if all(~nonzero):
                        max_rel_error = array(inf)
                    else:
                        max_rel_error = max(error[nonzero] / abs(y[nonzero]))
                    if getattr(error, 'dtype', object_) == object_:
                        remarks.append('Max relative difference: '
                                       + str(max_rel_error))
                    else:
                        remarks.append('Max relative difference: '
                                       + array2string(max_rel_error))

            err_msg += '\n' + '\n'.join(remarks)
            # The report shows the original arrays, not the filtered ones.
            msg = build_err_msg([ox, oy], err_msg,
                                verbose=verbose, header=header,
                                names=('x', 'y'), precision=precision)
            raise AssertionError(msg)
    except ValueError:
        # Re-raise with the formatted traceback and the usual report so the
        # caller sees both the cause and the inputs.
        import traceback
        efmt = traceback.format_exc()
        header = f'error during assertion:\n\n{efmt}\n\n{header}'

        msg = build_err_msg([x, y], err_msg, verbose=verbose, header=header,
                            names=('x', 'y'), precision=precision)
        raise ValueError(msg)
871
872
def assert_array_equal(x, y, err_msg='', verbose=True, *, strict=False):
    """
    Raise an AssertionError if two array_like objects are not equal.

    Both the shapes and all elements of the two objects must be equal (see
    the Notes for the special handling of a scalar). An exception is raised
    at shape mismatch or conflicting values. Unlike the usual behaviour in
    numpy, NaNs are compared like numbers: no assertion is raised if both
    objects have NaNs in the same positions.

    The usual caution for verifying equality with floating point numbers is
    advised.

    Parameters
    ----------
    x : array_like
        The actual object to check.
    y : array_like
        The desired, expected object.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.
    strict : bool, optional
        If True, raise an AssertionError when either the shape or the data
        type of the array_like objects does not match. The special
        handling for scalars mentioned in the Notes section is disabled.

        .. versionadded:: 1.24.0

    Raises
    ------
    AssertionError
        If actual and desired objects are not equal.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Notes
    -----
    When one of `x` and `y` is a scalar and the other is array_like, the
    function checks that each element of the array_like object is equal to
    the scalar. This behaviour can be disabled with the `strict` parameter.

    Examples
    --------
    The first assert does not raise an exception:

    >>> np.testing.assert_array_equal([1.0,2.33333,np.nan],
    ...                               [np.exp(0),2.33333, np.nan])

    Assert fails with numerical imprecision with floats:

    >>> np.testing.assert_array_equal([1.0,np.pi,np.nan],
    ...                               [1, np.sqrt(np.pi)**2, np.nan])
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not equal
    <BLANKLINE>
    Mismatched elements: 1 / 3 (33.3%)
    Max absolute difference: 4.4408921e-16
    Max relative difference: 1.41357986e-16
     x: array([1.      , 3.141593,      nan])
     y: array([1.      , 3.141593,      nan])

    Use `assert_allclose` or one of the nulp (number of floating point values)
    functions for these cases instead:

    >>> np.testing.assert_allclose([1.0,np.pi,np.nan],
    ...                            [1, np.sqrt(np.pi)**2, np.nan],
    ...                            rtol=1e-10, atol=0)

    As mentioned in the Notes section, `assert_array_equal` has special
    handling for scalars. Here the test checks that each value in `x` is 3:

    >>> x = np.full((2, 5), fill_value=3)
    >>> np.testing.assert_array_equal(x, 3)

    Use `strict` to raise an AssertionError when comparing a scalar with an
    array:

    >>> np.testing.assert_array_equal(x, 3, strict=True)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not equal
    <BLANKLINE>
    (shapes (2, 5), () mismatch)
     x: array([[3, 3, 3, 3, 3],
           [3, 3, 3, 3, 3]])
     y: array(3)

    The `strict` parameter also ensures that the array data types match:

    >>> x = np.array([2, 2, 2])
    >>> y = np.array([2., 2., 2.], dtype=np.float32)
    >>> np.testing.assert_array_equal(x, y, strict=True)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not equal
    <BLANKLINE>
    (dtypes int64, float32 mismatch)
     x: array([2, 2, 2])
     y: array([2., 2., 2.], dtype=float32)
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # All the work is done by the generic comparison routine, with ``==`` as
    # the element-wise test.
    assert_array_compare(operator.eq, x, y,
                         err_msg=err_msg,
                         verbose=verbose,
                         header='Arrays are not equal',
                         strict=strict)
988
989
@np._no_nep50_warning()
def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
    """
    Raise an AssertionError if two objects differ by more than the desired
    precision.

    .. note:: It is recommended to use one of `assert_allclose`,
              `assert_array_almost_equal_nulp` or `assert_array_max_ulp`
              instead of this function for more consistent floating point
              comparisons.

    The check requires identical shapes and that every pair of elements of
    ``actual`` and ``desired`` satisfies

        ``abs(desired-actual) < 1.5 * 10**(-decimal)``

    This is looser than what was originally documented, but it matches what
    the implementation has always done up to rounding vagaries. A shape
    mismatch or conflicting values raise an exception. Unlike the usual
    numpy semantics, NaNs are compared like numbers: nothing is raised when
    both objects have NaNs in the same positions.

    Parameters
    ----------
    x : array_like
        The actual object to check.
    y : array_like
        The desired, expected object.
    decimal : int, optional
        Desired precision, default is 6.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Examples
    --------
    The first assert does not raise an exception:

    >>> np.testing.assert_array_almost_equal([1.0,2.333,np.nan],
    ...                                      [1.0,2.333,np.nan])

    >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan],
    ...                                      [1.0,2.33339,np.nan], decimal=5)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 5 decimals
    <BLANKLINE>
    Mismatched elements: 1 / 3 (33.3%)
    Max absolute difference: 6.e-05
    Max relative difference: 2.57136612e-05
     x: array([1.     , 2.33333,     nan])
     y: array([1.     , 2.33339,     nan])

    >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan],
    ...                                      [1.0,2.33333, 5], decimal=5)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 5 decimals
    <BLANKLINE>
    x and y nan location mismatch:
     x: array([1.     , 2.33333,     nan])
     y: array([1.     , 2.33333, 5.     ])

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    from numpy.core import number, float_, result_type, array
    from numpy.core.numerictypes import issubdtype
    from numpy.core.fromnumeric import any as npany

    def _within_decimal(actual, desired):
        # Infinities cannot go through the subtraction below (inf - inf is
        # nan), so they are matched positionally and then stripped out.
        try:
            actual_inf = gisinf(actual)
            desired_inf = gisinf(desired)
            if npany(actual_inf) or npany(desired_inf):
                if not (actual_inf == desired_inf).all():
                    return False
                # a single item on each side: both are +-inf
                if actual.size == desired.size == 1:
                    return actual == desired
                actual = actual[~actual_inf]
                desired = desired[~desired_inf]
        except (TypeError, NotImplementedError):
            # the inf test is not supported for these operands; compare as-is
            pass

        # Force `desired` to an inexact type to avoid abs(MIN_INT); this also
        # causes `actual` to be cast during the subtraction.
        inexact = result_type(desired, 1.)
        desired = np.asanyarray(desired, inexact)
        delta = abs(actual - desired)

        if not issubdtype(delta.dtype, number):
            delta = delta.astype(float_)  # handle object arrays

        return delta < 1.5 * 10.0**(-decimal)

    header = 'Arrays are not almost equal to %d decimals' % decimal
    assert_array_compare(_within_decimal, x, y, err_msg=err_msg,
                         verbose=verbose, header=header, precision=decimal)
1102
1103
def assert_array_less(x, y, err_msg='', verbose=True):
    """
    Raise an AssertionError if two array_like objects are not ordered by
    less than.

    Given two array_like objects, check that their shapes are equal and that
    every element of the first is strictly smaller than the matching element
    of the second. An exception is raised on shape mismatch or incorrectly
    ordered values. A shape mismatch does not raise if one of the objects has
    zero dimension. Unlike the usual numpy semantics, NaNs are compared:
    nothing is raised when both objects have NaNs in the same positions.

    Parameters
    ----------
    x : array_like
        The smaller object to check.
    y : array_like
        The larger object to compare.
    err_msg : string
        The error message to be printed in case of failure.
    verbose : bool
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If `x` is not strictly smaller than `y`, element-wise.

    See Also
    --------
    assert_array_equal: tests objects for equality
    assert_array_almost_equal: test objects for equality up to precision

    Examples
    --------
    >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1.1, 2.0, np.nan])
    >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1, 2.0, np.nan])
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not less-ordered
    <BLANKLINE>
    Mismatched elements: 1 / 3 (33.3%)
    Max absolute difference: 1.
    Max relative difference: 0.5
     x: array([ 1.,  1., nan])
     y: array([ 1.,  2., nan])

    >>> np.testing.assert_array_less([1.0, 4.0], 3)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not less-ordered
    <BLANKLINE>
    Mismatched elements: 1 / 2 (50%)
    Max absolute difference: 2.
    Max relative difference: 0.66666667
     x: array([1., 4.])
     y: array(3)

    >>> np.testing.assert_array_less([1.0, 2.0, 3.0], [4])
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not less-ordered
    <BLANKLINE>
    (shapes (3,), (1,) mismatch)
     x: array([1., 2., 3.])
     y: array([4])

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # equal_inf=False: infinities are handed to the `<` comparison itself
    # rather than being treated as matching positions.
    assert_array_compare(
        operator.lt, x, y,
        err_msg=err_msg,
        verbose=verbose,
        header='Arrays are not less-ordered',
        equal_inf=False,
    )
1185
1186
def runstring(astr, dict):
    """
    Execute the source string `astr` using `dict` as the namespace.

    Parameters
    ----------
    astr : str
        Python source code passed to the builtin ``exec``.
    dict : dict
        Mapping used as the globals of the executed code; names bound by
        the code are stored in it.
    """
    # NOTE: this is a plain `exec`; never feed it untrusted input.
    exec(astr, dict)
1189
1190
def assert_string_equal(actual, desired):
    """
    Test if two strings are equal.

    If the given strings are equal, `assert_string_equal` does nothing.
    If they are not equal, an AssertionError is raised, and the diff
    between the strings is shown.

    Parameters
    ----------
    actual : str
        The string to test for equality against the expected string.
    desired : str
        The expected string.

    Raises
    ------
    AssertionError
        If either argument is not a `str`, or if the strings differ. Lines
        that only exist in one of the two strings are reported as well.

    Examples
    --------
    >>> np.testing.assert_string_equal('abc', 'abc')
    >>> np.testing.assert_string_equal('abc', 'abcd')
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    ...
    AssertionError: Differences in strings:
    - abc+ abcd?    +

    """
    # delay import of difflib to reduce startup time
    __tracebackhide__ = True  # Hide traceback for py.test
    import difflib

    if not isinstance(actual, str):
        raise AssertionError(repr(type(actual)))
    if not isinstance(desired, str):
        raise AssertionError(repr(type(desired)))
    if desired == actual:
        return

    diff = list(difflib.Differ().compare(actual.splitlines(True),
                                         desired.splitlines(True)))
    total = len(diff)
    diff_list = []
    pos = 0
    # Walk the diff by index (no O(n) pops from the front of the list).
    while pos < total:
        d1 = diff[pos]
        pos += 1
        if d1.startswith('  '):
            # line common to both strings
            continue
        if d1.startswith('+ '):
            # line that only exists in `desired`
            diff_list.append(d1)
            continue
        if not d1.startswith('- '):
            raise AssertionError(repr(d1))

        # A '- ' line may be followed by an optional '? ' hint, the matching
        # '+ ' line and another optional '? ' hint.
        group = [d1]
        nxt = pos
        if nxt < total and diff[nxt].startswith('? '):
            nxt += 1
        if nxt < total and diff[nxt].startswith('+ '):
            d2 = diff[nxt]
            group.extend(diff[pos:nxt + 1])
            pos = nxt + 1
            if pos < total and diff[pos].startswith('? '):
                group.append(diff[pos])
                pos += 1
            if d2[2:] == d1[2:]:
                continue
        else:
            # Line that only exists in `actual`: there is no '+ ' partner
            # (possibly because the diff ends here), so report the deletion
            # instead of failing on the missing partner.
            group.extend(diff[pos:nxt])
            pos = nxt
        diff_list.extend(group)

    if not diff_list:
        return
    msg = f"Differences in strings:\n{''.join(diff_list).rstrip()}"
    raise AssertionError(msg)
1260
1261
def rundocs(filename=None, raise_on_error=True):
    """
    Run doctests found in the given file.

    By default `rundocs` raises an AssertionError on failure.

    Parameters
    ----------
    filename : str
        The path to the file for which the doctests are run. If None, the
        ``__file__`` of the calling module is used.
    raise_on_error : bool
        Whether to raise an AssertionError when a doctest fails. Default is
        True.

    Raises
    ------
    AssertionError
        If a doctest fails and `raise_on_error` is True.
    ImportError
        If no module loader can be determined for `filename`.

    Notes
    -----
    The doctests can be run by the user/developer by adding the ``doctests``
    argument to the ``test()`` call. For example, to run all tests (including
    doctests) for `numpy.lib`:

    >>> np.lib.test(doctests=True)  # doctest: +SKIP
    """
    # Imported lazily to keep the start-up cost of numpy.testing low.
    import doctest
    import importlib.util

    if filename is None:
        f = sys._getframe(1)
        filename = f.f_globals['__file__']
    name = os.path.splitext(os.path.basename(filename))[0]

    # Load the file as a module with importlib directly rather than through
    # numpy.distutils, which is deprecated and cannot be imported on
    # Python >= 3.12.
    spec = importlib.util.spec_from_file_location(name, filename)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load module from {filename!r}")
    m = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(m)

    tests = doctest.DocTestFinder().find(m)
    runner = doctest.DocTestRunner(verbose=False)

    # When raising, collect the runner's report so it can be put into the
    # exception message; otherwise let doctest write to stdout as usual.
    msg = []
    out = msg.append if raise_on_error else None

    for test in tests:
        runner.run(test, out=out)

    if runner.failures > 0 and raise_on_error:
        raise AssertionError("Some doctests failed:\n%s" % "\n".join(msg))
1306
1307
def raises(*args):
    """Decorator to check for raised exceptions.

    A test function wrapped with this decorator passes only if it raises one
    of the given exceptions. To check several exception-related assertions
    inside a single test, `assert_raises` is usually the better tool.

    .. warning::
       This decorator is nose specific, do not use it if you are using a
       different test framework.

    Parameters
    ----------
    args : exceptions
        The test passes if any of the passed exceptions is raised.

    Raises
    ------
    AssertionError

    Examples
    --------

    Usage::

        @raises(TypeError, ValueError)
        def test_raises_type_error():
            raise TypeError("This test passes")

        @raises(Exception)
        def test_that_fails_by_passing():
            pass

    """
    # nose is imported on demand so that it stays an optional dependency
    return import_nose().tools.raises(*args)
1344
1345#
1346# assert_raises and assert_raises_regex are taken from unittest.
1347#
1348import unittest
1349
1350
1351class _Dummy(unittest.TestCase):
1352 def nop(self):
1353 pass
1354
1355_d = _Dummy('nop')
1356
def assert_raises(*args, **kwargs):
    """
    assert_raises(exception_class, callable, *args, **kwargs)
    assert_raises(exception_class)

    Fail unless an exception of class exception_class is thrown
    by callable when invoked with arguments args and keyword
    arguments kwargs. If a different type of exception is
    thrown, it will not be caught, and the test case will be
    deemed to have suffered an error, exactly as for an
    unexpected exception.

    Alternatively, `assert_raises` can be used as a context manager:

    >>> from numpy.testing import assert_raises
    >>> with assert_raises(ZeroDivisionError):
    ...     1 / 0

    is equivalent to

    >>> def div(x, y):
    ...     return x / y
    >>> assert_raises(ZeroDivisionError, div, 1, 0)

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # Delegate to unittest through the module-level dummy TestCase.
    check = _d.assertRaises
    return check(*args, **kwargs)
1384
1385
def assert_raises_regex(exception_class, expected_regexp, *args, **kwargs):
    """
    assert_raises_regex(exception_class, expected_regexp, callable, *args,
                        **kwargs)
    assert_raises_regex(exception_class, expected_regexp)

    Fail unless an exception of class exception_class and with message that
    matches expected_regexp is thrown by callable when invoked with arguments
    args and keyword arguments kwargs.

    Alternatively, can be used as a context manager like `assert_raises`.

    Notes
    -----
    .. versionadded:: 1.9.0

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # Delegate to unittest through the module-level dummy TestCase.
    check = _d.assertRaisesRegex
    return check(exception_class, expected_regexp, *args, **kwargs)
1405
1406
1407def decorate_methods(cls, decorator, testmatch=None):
1408 """
1409 Apply a decorator to all methods in a class matching a regular expression.
1410
1411 The given decorator is applied to all public methods of `cls` that are
1412 matched by the regular expression `testmatch`
1413 (``testmatch.search(methodname)``). Methods that are private, i.e. start
1414 with an underscore, are ignored.
1415
1416 Parameters
1417 ----------
1418 cls : class
1419 Class whose methods to decorate.
1420 decorator : function
1421 Decorator to apply to methods
1422 testmatch : compiled regexp or str, optional
1423 The regular expression. Default value is None, in which case the
1424 nose default (``re.compile(r'(?:^|[\\b_\\.%s-])[Tt]est' % os.sep)``)
1425 is used.
1426 If `testmatch` is a string, it is compiled to a regular expression
1427 first.
1428
1429 """
1430 if testmatch is None:
1431 testmatch = re.compile(r'(?:^|[\\b_\\.%s-])[Tt]est' % os.sep)
1432 else:
1433 testmatch = re.compile(testmatch)
1434 cls_attr = cls.__dict__
1435
1436 # delayed import to reduce startup time
1437 from inspect import isfunction
1438
1439 methods = [_m for _m in cls_attr.values() if isfunction(_m)]
1440 for function in methods:
1441 try:
1442 if hasattr(function, 'compat_func_name'):
1443 funcname = function.compat_func_name
1444 else:
1445 funcname = function.__name__
1446 except AttributeError:
1447 # not a function
1448 continue
1449 if testmatch.search(funcname) and not funcname.startswith('_'):
1450 setattr(cls, funcname, decorator(function))
1451 return
1452
1453
def measure(code_str, times=1, label=None):
    """
    Return elapsed time for executing code in the namespace of the caller.

    `code_str` is compiled once with the Python builtin ``compile`` and then
    executed `times` times using the caller's globals and locals. The timing
    has a precision of 10 milli-seconds, so code that runs fast on that
    timescale should be executed many times to get reasonable accuracy.

    Parameters
    ----------
    code_str : str
        The code to be timed.
    times : int, optional
        The number of times the code is executed. Default is 1. The code is
        only compiled once.
    label : str, optional
        A label to identify `code_str` with. This is passed into ``compile``
        as the second argument (for run-time error messages).

    Returns
    -------
    elapsed : float
        Total elapsed time in seconds for executing `code_str` `times` times.

    Examples
    --------
    >>> times = 10
    >>> etime = np.testing.measure('for i in range(1000): np.sqrt(i**2)', times=times)
    >>> print("Time for a single execution : ", etime / times, "s")  # doctest: +SKIP
    Time for a single execution :  0.005 s

    """
    caller = sys._getframe(1)
    caller_locals = caller.f_locals
    caller_globals = caller.f_globals

    compiled = compile(code_str, f'Test name: {label} ', 'exec')

    remaining = times
    start = jiffies()
    while remaining > 0:
        remaining -= 1
        exec(compiled, caller_globals, caller_locals)
    ticks = jiffies() - start
    # jiffies are counted in units of 1/100 s
    return 0.01*ticks
1498
1499
def _assert_valid_refcount(op):
    """
    Check that ufuncs don't mishandle refcount of object `1`.
    Used in a few regression tests.

    `op` is called 15 times as ``op(b, c)``, where both operands are the
    same (100, 100) ``arange`` array, with the garbage collector disabled.
    Afterwards the reference count of the integer ``1`` must not be lower
    than it was before the calls.

    Returns True immediately, without calling `op`, when `HAS_REFCOUNT` is
    false; otherwise returns None.
    """
    if not HAS_REFCOUNT:
        return True

    import gc
    import numpy as np

    b = np.arange(100*100).reshape(100, 100)
    c = b  # the very same array is used as both operands
    i = 1

    # Collector is switched off while the reference counts are compared.
    gc.disable()
    try:
        rc = sys.getrefcount(i)
        for j in range(15):
            d = op(b, c)
        assert_(sys.getrefcount(i) >= rc)
    finally:
        # NOTE(review): re-enables the collector unconditionally, even if it
        # was already disabled when this function was entered.
        gc.enable()
    del d  # for pyflakes
1524
1525
def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
                    err_msg='', verbose=True):
    """
    Raise an AssertionError if two objects are not equal up to the desired
    tolerance.

    Given two array_like objects, check that their shapes and all elements
    are equal (but see the Notes for the special handling of a scalar). An
    exception is raised if the shapes mismatch or any values conflict. In
    contrast to the standard usage in numpy, NaNs are compared like numbers,
    no assertion is raised if both objects have NaNs in the same positions.

    The test is equivalent to ``allclose(actual, desired, rtol, atol)`` (note
    that ``allclose`` has different default values). It compares the difference
    between `actual` and `desired` to ``atol + rtol * abs(desired)``.

    .. versionadded:: 1.5.0

    Parameters
    ----------
    actual : array_like
        Array obtained.
    desired : array_like
        Array desired.
    rtol : float, optional
        Relative tolerance.
    atol : float, optional
        Absolute tolerance.
    equal_nan : bool, optional.
        If True, NaNs will compare equal.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_array_almost_equal_nulp, assert_array_max_ulp

    Notes
    -----
    When one of `actual` and `desired` is a scalar and the other is
    array_like, the function checks that each element of the array_like
    object is equal to the scalar.

    Examples
    --------
    >>> x = [1e-5, 1e-3, 1e-1]
    >>> y = np.arccos(np.cos(x))
    >>> np.testing.assert_allclose(x, y, rtol=1e-5, atol=0)

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np

    actual = np.asanyarray(actual)
    desired = np.asanyarray(desired)

    def _is_close(a, b):
        # element-wise closeness test with the tolerances given by the caller
        return np.core.numeric.isclose(a, b, rtol=rtol, atol=atol,
                                       equal_nan=equal_nan)

    header = f'Not equal to tolerance rtol={rtol:g}, atol={atol:g}'
    assert_array_compare(_is_close, actual, desired, err_msg=str(err_msg),
                         verbose=verbose, header=header, equal_nan=equal_nan)
1594
1595
def assert_array_almost_equal_nulp(x, y, nulp=1):
    """
    Compare two arrays relatively to their spacing.

    This is a relatively robust method to compare two arrays whose amplitude
    is variable.

    Parameters
    ----------
    x, y : array_like
        Input arrays.
    nulp : int, optional
        The maximum number of unit in the last place for tolerance (see Notes).
        Default is 1.

    Returns
    -------
    None

    Raises
    ------
    AssertionError
        If the spacing between `x` and `y` for one or more elements is larger
        than `nulp`.

    See Also
    --------
    assert_array_max_ulp : Check that all items of arrays differ in at most
        N Units in the Last Place.
    spacing : Return the distance between x and the nearest adjacent number.

    Notes
    -----
    An assertion is raised if the following condition is not met::

        abs(x - y) <= nulp * spacing(maximum(abs(x), abs(y)))

    Examples
    --------
    >>> x = np.array([1., 1e-10, 1e-20])
    >>> eps = np.finfo(x.dtype).eps
    >>> np.testing.assert_array_almost_equal_nulp(x, x*eps/2 + x)

    >>> np.testing.assert_array_almost_equal_nulp(x, x*eps + x)
    Traceback (most recent call last):
      ...
    AssertionError: X and Y are not equal to 1 ULP (max is 2)

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np

    abs_x = np.abs(x)
    abs_y = np.abs(y)
    # tolerance: `nulp` spacings of the larger magnitude of each pair
    tolerance = nulp * np.spacing(np.where(abs_x > abs_y, abs_x, abs_y))
    if np.all(np.abs(x-y) <= tolerance):
        return

    if np.iscomplexobj(x) or np.iscomplexobj(y):
        # nulp_diff does not support complex input, so no maximum is shown
        msg = "X and Y are not equal to %d ULP" % nulp
    else:
        worst = np.max(nulp_diff(x, y))
        msg = "X and Y are not equal to %d ULP (max is %g)" % (nulp, worst)
    raise AssertionError(msg)
1657
1658
def assert_array_max_ulp(a, b, maxulp=1, dtype=None):
    """
    Check that all items of arrays differ in at most N Units in the Last Place.

    Parameters
    ----------
    a, b : array_like
        Input arrays to be compared.
    maxulp : int, optional
        The maximum number of units in the last place that elements of `a` and
        `b` can differ. Default is 1.
    dtype : dtype, optional
        Data-type to convert `a` and `b` to if given. Default is None.

    Returns
    -------
    ret : ndarray
        Array containing number of representable floating point numbers between
        items in `a` and `b`.

    Raises
    ------
    AssertionError
        If one or more elements differ by more than `maxulp`.

    Notes
    -----
    For computing the ULP difference, this API does not differentiate between
    various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000
    is zero).

    See Also
    --------
    assert_array_almost_equal_nulp : Compare two arrays relatively to their
        spacing.

    Examples
    --------
    >>> a = np.linspace(0., 1., 100)
    >>> res = np.testing.assert_array_max_ulp(a, np.arcsin(np.sin(a)))

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np

    ulps = nulp_diff(a, b, dtype)
    if np.all(ulps <= maxulp):
        return ulps
    raise AssertionError("Arrays are not almost equal up to %g "
                         "ULP (max difference is %g ULP)" %
                         (maxulp, np.max(ulps)))
1709
1710
def nulp_diff(x, y, dtype=None):
    """For each item in x and y, return the number of representable floating
    points between them.

    Parameters
    ----------
    x : array_like
        first input array
    y : array_like
        second input array
    dtype : dtype, optional
        Data-type to convert `x` and `y` to if given. Default is None.

    Returns
    -------
    nulp : array_like
        number of representable floating point numbers between each item in x
        and y. The inputs are wrapped in one extra leading dimension before
        the comparison, so the result carries that dimension as well.

    Raises
    ------
    NotImplementedError
        If `x` or `y` is complex.
    ValueError
        If `x` and `y` do not have the same shape.

    Notes
    -----
    For computing the ULP difference, this API does not differentiate between
    various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000
    is zero).

    Examples
    --------
    # By definition, epsilon is the smallest number such as 1 + eps != 1, so
    # there should be exactly one ULP between 1 and 1 + eps
    >>> nulp_diff(1, 1 + np.finfo(np.float64).eps)
    array([1.])
    """
    import numpy as np

    conversion = {'dtype': dtype} if dtype else {}
    x = np.asarray(x, **conversion)
    y = np.asarray(y, **conversion)

    common = np.common_type(x, y)
    if np.iscomplexobj(x) or np.iscomplexobj(y):
        raise NotImplementedError("_nulp not implemented for complex array")

    # Fresh copies in the common floating type; integer_repr may write into
    # the buffers it is given, so the caller's data must not be shared.
    x = np.array([x], dtype=common)
    y = np.array([y], dtype=common)

    # Collapse every NaN bit pattern onto the canonical one.
    x[np.isnan(x)] = np.nan
    y[np.isnan(y)] = np.nan

    if x.shape != y.shape:
        raise ValueError("x and y do not have the same shape: %s - %s" %
                         (x.shape, y.shape))

    int_x = integer_repr(x)
    int_y = integer_repr(y)
    return np.abs(np.asarray(int_x - int_y, dtype=common))
1772
1773
1774def _integer_repr(x, vdt, comp):
1775 # Reinterpret binary representation of the float as sign-magnitude:
1776 # take into account two-complement representation
1777 # See also
1778 # https://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/
1779 rx = x.view(vdt)
1780 if not (rx.size == 1):
1781 rx[rx < 0] = comp - rx[rx < 0]
1782 else:
1783 if rx < 0:
1784 rx = comp - rx
1785
1786 return rx
1787
1788
def integer_repr(x):
    """Return the signed-magnitude interpretation of the binary representation
    of x.

    `x` must be an array of dtype float16, float32 or float64; any other
    dtype raises ValueError.
    """
    import numpy as np

    # (float type, same-width integer type, width in bits)
    layouts = (
        (np.float16, np.int16, 16),
        (np.float32, np.int32, 32),
        (np.float64, np.int64, 64),
    )
    for float_type, int_type, width in layouts:
        if x.dtype == float_type:
            # the offset is the most negative value of the integer type
            return _integer_repr(x, int_type, int_type(-2**(width - 1)))
    raise ValueError(f'Unsupported dtype {x.dtype}')
1801
1802
@contextlib.contextmanager
def _assert_warns_context(warning_class, name=None):
    # Context manager behind `assert_warns`: fails with AssertionError when
    # the managed body finishes without a warning of `warning_class` having
    # been recorded. `name`, if given, is mentioned in the failure message.
    __tracebackhide__ = True  # Hide traceback for py.test
    with suppress_warnings() as sup:
        recorded = sup.record(warning_class)
        yield
        if len(recorded) == 0:
            suffix = '' if name is None else f' when calling {name}'
            raise AssertionError(f"No warning raised{suffix}")
1812
1813
def assert_warns(warning_class, *args, **kwargs):
    """
    Fail unless the given callable throws the specified warning.

    A warning of class warning_class should be thrown by the callable when
    invoked with arguments args and keyword arguments kwargs.
    If a different type of warning is thrown, it will not be caught.

    If called with all arguments other than the warning class omitted, may be
    used as a context manager:

        with assert_warns(SomeWarning):
            do_something()

    The ability to be used as a context manager is new in NumPy v1.11.0.

    .. versionadded:: 1.4.0

    Parameters
    ----------
    warning_class : class
        The class defining the warning that `func` is expected to throw.
    func : callable, optional
        Callable to test. It does not need a ``__name__`` attribute, so
        `functools.partial` objects and callable instances are accepted.
    *args : Arguments
        Arguments for `func`.
    **kwargs : Kwargs
        Keyword arguments for `func`.

    Returns
    -------
    The value returned by `func`.

    Raises
    ------
    AssertionError
        If no warning of class `warning_class` was raised.

    Examples
    --------
    >>> import warnings
    >>> def deprecated_func(num):
    ...     warnings.warn("Please upgrade", DeprecationWarning)
    ...     return num*num
    >>> with np.testing.assert_warns(DeprecationWarning):
    ...     assert deprecated_func(4) == 16
    >>> # or passing a func
    >>> ret = np.testing.assert_warns(DeprecationWarning, deprecated_func, 4)
    >>> assert ret == 16
    """
    if not args:
        return _assert_warns_context(warning_class)

    func = args[0]
    args = args[1:]
    # Not every callable has a __name__ (e.g. functools.partial objects);
    # fall back to its repr for the failure message instead of crashing.
    name = getattr(func, '__name__', repr(func))
    with _assert_warns_context(warning_class, name=name):
        return func(*args, **kwargs)
1866
1867
@contextlib.contextmanager
def _assert_no_warnings_context(name=None):
    # Context manager behind `assert_no_warnings`: records every warning
    # issued by the managed body and fails with AssertionError if there was
    # at least one. `name`, if given, is mentioned in the failure message.
    __tracebackhide__ = True  # Hide traceback for py.test
    with warnings.catch_warnings(record=True) as l:
        # 'always' makes sure no filter hides a warning from the record
        warnings.simplefilter('always')
        yield
        if len(l) > 0:
            name_str = f' when calling {name}' if name is not None else ''
            raise AssertionError(f'Got warnings{name_str}: {l}')


def assert_no_warnings(*args, **kwargs):
    """
    Fail if the given callable produces any warnings.

    If called with all arguments omitted, may be used as a context manager:

        with assert_no_warnings():
            do_something()

    The ability to be used as a context manager is new in NumPy v1.11.0.

    .. versionadded:: 1.7.0

    Parameters
    ----------
    func : callable
        The callable to test. It does not need a ``__name__`` attribute, so
        `functools.partial` objects and callable instances are accepted.
    \\*args : Arguments
        Arguments passed to `func`.
    \\*\\*kwargs : Kwargs
        Keyword arguments passed to `func`.

    Returns
    -------
    The value returned by `func`.

    Raises
    ------
    AssertionError
        If any warning was issued.

    """
    if not args:
        return _assert_no_warnings_context()

    func = args[0]
    args = args[1:]
    # Not every callable has a __name__ (e.g. functools.partial objects);
    # fall back to its repr for the failure message instead of crashing.
    name = getattr(func, '__name__', repr(func))
    with _assert_no_warnings_context(name=name):
        return func(*args, **kwargs)
1913
1914
1915def _gen_alignment_data(dtype=float32, type='binary', max_size=24):
1916 """
1917 generator producing data with different alignment and offsets
1918 to test simd vectorization
1919
1920 Parameters
1921 ----------
1922 dtype : dtype
1923 data type to produce
1924 type : string
1925 'unary': create data for unary operations, creates one input
1926 and output array
1927 'binary': create data for unary operations, creates two input
1928 and output array
1929 max_size : integer
1930 maximum size of data to produce
1931
1932 Returns
1933 -------
1934 if type is 'unary' yields one output, one input array and a message
1935 containing information on the data
1936 if type is 'binary' yields one output array, two input array and a message
1937 containing information on the data
1938
1939 """
1940 ufmt = 'unary offset=(%d, %d), size=%d, dtype=%r, %s'
1941 bfmt = 'binary offset=(%d, %d, %d), size=%d, dtype=%r, %s'
1942 for o in range(3):
1943 for s in range(o + 2, max(o + 3, max_size)):
1944 if type == 'unary':
1945 inp = lambda: arange(s, dtype=dtype)[o:]
1946 out = empty((s,), dtype=dtype)[o:]
1947 yield out, inp(), ufmt % (o, o, s, dtype, 'out of place')
1948 d = inp()
1949 yield d, d, ufmt % (o, o, s, dtype, 'in place')
1950 yield out[1:], inp()[:-1], ufmt % \
1951 (o + 1, o, s - 1, dtype, 'out of place')
1952 yield out[:-1], inp()[1:], ufmt % \
1953 (o, o + 1, s - 1, dtype, 'out of place')
1954 yield inp()[:-1], inp()[1:], ufmt % \
1955 (o, o + 1, s - 1, dtype, 'aliased')
1956 yield inp()[1:], inp()[:-1], ufmt % \
1957 (o + 1, o, s - 1, dtype, 'aliased')
1958 if type == 'binary':
1959 inp1 = lambda: arange(s, dtype=dtype)[o:]
1960 inp2 = lambda: arange(s, dtype=dtype)[o:]
1961 out = empty((s,), dtype=dtype)[o:]
1962 yield out, inp1(), inp2(), bfmt % \
1963 (o, o, o, s, dtype, 'out of place')
1964 d = inp1()
1965 yield d, d, inp2(), bfmt % \
1966 (o, o, o, s, dtype, 'in place1')
1967 d = inp2()
1968 yield d, inp1(), d, bfmt % \
1969 (o, o, o, s, dtype, 'in place2')
1970 yield out[1:], inp1()[:-1], inp2()[:-1], bfmt % \
1971 (o + 1, o, o, s - 1, dtype, 'out of place')
1972 yield out[:-1], inp1()[1:], inp2()[:-1], bfmt % \
1973 (o, o + 1, o, s - 1, dtype, 'out of place')
1974 yield out[:-1], inp1()[:-1], inp2()[1:], bfmt % \
1975 (o, o, o + 1, s - 1, dtype, 'out of place')
1976 yield inp1()[1:], inp1()[:-1], inp2()[:-1], bfmt % \
1977 (o + 1, o, o, s - 1, dtype, 'aliased')
1978 yield inp1()[:-1], inp1()[1:], inp2()[:-1], bfmt % \
1979 (o, o + 1, o, s - 1, dtype, 'aliased')
1980 yield inp1()[:-1], inp1()[:-1], inp2()[1:], bfmt % \
1981 (o, o, o + 1, s - 1, dtype, 'aliased')
1982
1983
class IgnoreException(Exception):
    """Ignoring this exception due to disabled feature"""
1987
1988
@contextlib.contextmanager
def tempdir(*args, **kwargs):
    """Context manager to provide a temporary test folder.

    All arguments are passed as is to the underlying tempfile.mkdtemp
    function. The folder and everything in it is removed when the context
    is left, also when the body raised an exception.

    """
    folder = mkdtemp(*args, **kwargs)
    try:
        yield folder
    finally:
        shutil.rmtree(folder)
2002
2003
@contextlib.contextmanager
def temppath(*args, **kwargs):
    """Context manager for temporary files.

    Yields the path to a temporary file that has already been closed. The
    parameters are the same as for tempfile.mkstemp and are passed directly
    to that function. The underlying file is removed when the context is
    exited, so it should be closed at that time.

    Windows does not allow a temporary file to be opened if it is already
    open, so the underlying file must be closed after opening before it
    can be opened again.

    """
    handle, filename = mkstemp(*args, **kwargs)
    # release the descriptor right away so that the body can reopen the file
    os.close(handle)
    try:
        yield filename
    finally:
        os.remove(filename)
2024
2025
class clear_and_catch_warnings(warnings.catch_warnings):
    """ Context manager that resets warning registry for catching warnings

    Warnings can be slippery, because, whenever a warning is triggered, Python
    adds a ``__warningregistry__`` member to the *calling* module.  This makes
    it impossible to retrigger the warning in this module, whatever you put in
    the warnings filters.  This context manager accepts a sequence of `modules`
    as a keyword argument to its constructor and:

    * stores and removes any ``__warningregistry__`` entries in given `modules`
      on entry;
    * resets ``__warningregistry__`` to its previous state on exit.

    This makes it possible to trigger any warning afresh inside the context
    manager without disturbing the state of warnings outside.

    For compatibility with Python 3.0, please consider all arguments to be
    keyword-only.

    Parameters
    ----------
    record : bool, optional
        Specifies whether warnings should be captured by a custom
        implementation of ``warnings.showwarning()`` and be appended to a list
        returned by the context manager. Otherwise None is returned by the
        context manager. The objects appended to the list are arguments whose
        attributes mirror the arguments to ``showwarning()``.
    modules : sequence, optional
        Sequence of modules for which to reset warnings registry on entry and
        restore on exit. To work correctly, all 'ignore' filters should
        filter by one of these modules.

    Examples
    --------
    >>> import warnings
    >>> with np.testing.clear_and_catch_warnings(
    ...         modules=[np.core.fromnumeric]):
    ...     warnings.simplefilter('always')
    ...     warnings.filterwarnings('ignore', module='np.core.fromnumeric')
    ...     # do something that raises a warning but ignore those in
    ...     # np.core.fromnumeric
    """
    # Modules that subclasses always want reset, in addition to `modules`.
    class_modules = ()

    def __init__(self, record=False, modules=()):
        self.modules = {*modules, *self.class_modules}
        # Maps each module to a snapshot of its registry taken on entry.
        self._warnreg_copies = {}
        super().__init__(record=record)

    def __enter__(self):
        for module in self.modules:
            try:
                registry = module.__warningregistry__
            except AttributeError:
                # Module has not triggered a warning yet: nothing to save.
                continue
            self._warnreg_copies[module] = registry.copy()
            registry.clear()
        return super().__enter__()

    def __exit__(self, *exc_info):
        super().__exit__(*exc_info)
        for module in self.modules:
            # Drop whatever was registered while inside the context ...
            if hasattr(module, '__warningregistry__'):
                module.__warningregistry__.clear()
            # ... and put back the entries that were present on entry.
            try:
                snapshot = self._warnreg_copies[module]
            except KeyError:
                continue
            module.__warningregistry__.update(snapshot)
2090
2091
class suppress_warnings:
    """
    Context manager and decorator doing much the same as
    ``warnings.catch_warnings``.

    However, it also provides a filter mechanism to work around
    https://bugs.python.org/issue4180.

    This bug causes Python before 3.4 to not reliably show warnings again
    after they have been ignored once (even within catch_warnings). It
    means that no "ignore" filter can be used easily, since following
    tests might need to see the warning. Additionally it allows easier
    specificity for testing warnings and can be nested.

    Parameters
    ----------
    forwarding_rule : str, optional
        One of "always", "once", "module", or "location". Analogous to
        the usual warnings module filter mode, it is useful to reduce
        noise mostly on the outmost level. Unsuppressed and unrecorded
        warnings will be forwarded based on this rule. Defaults to "always".
        "location" is equivalent to the warnings "default", match by exact
        location the warning originated from.

    Notes
    -----
    Filters added inside the context manager will be discarded again
    when leaving it. Upon entering all filters defined outside a
    context will be applied automatically.

    When a recording filter is added, matching warnings are stored in the
    ``log`` attribute as well as in the list returned by ``record``.

    If filters are added and the ``module`` keyword is given, the
    warning registry of this module will additionally be cleared when
    applying it, entering the context, or exiting it. This could cause
    warnings to appear a second time after leaving the context if they
    were configured to be printed once (default) and were already
    printed before the context was entered.

    Nesting this context manager will work as expected when the
    forwarding rule is "always" (default). Unfiltered and unrecorded
    warnings will be passed out and be matched by the outer level.
    On the outmost level they will be printed (or caught by another
    warnings context). The forwarding rule argument can modify this
    behaviour.

    Like ``catch_warnings`` this context manager is not threadsafe.

    Examples
    --------

    With a context manager::

        with np.testing.suppress_warnings() as sup:
            sup.filter(DeprecationWarning, "Some text")
            sup.filter(module=np.ma.core)
            log = sup.record(FutureWarning, "Does this occur?")
            command_giving_warnings()
            # The FutureWarning was given once, the filtered warnings were
            # ignored. All other warnings abide outside settings (may be
            # printed/error)
            assert_(len(log) == 1)
            assert_(len(sup.log) == 1)  # also stored in log attribute

    Or as a decorator::

        sup = np.testing.suppress_warnings()
        sup.filter(module=np.ma.core)  # module must match exactly
        @sup
        def some_function():
            # do something which causes a warning in np.ma.core
            pass
    """
    def __init__(self, forwarding_rule="always"):
        self._entered = False

        # Suppressions are either instance or defined inside one with block:
        self._suppressions = []

        if forwarding_rule not in {"always", "module", "once", "location"}:
            raise ValueError("unsupported forwarding rule.")
        self._forwarding_rule = forwarding_rule

    def _clear_registries(self):
        if hasattr(warnings, "_filters_mutated"):
            # clearing the registry should not be necessary on new pythons,
            # instead the filters should be mutated.
            warnings._filters_mutated()
            return
        # Simply clear the registry, this should normally be harmless,
        # note that on new pythons it would be invalidated anyway.
        for module in self._tmp_modules:
            if hasattr(module, "__warningregistry__"):
                module.__warningregistry__.clear()

    def _install_filter(self, category, message, module):
        # Register an "always" filter with the warnings module so matching
        # warnings reach ``_showwarning`` every time; only valid once entered.
        if module is None:
            warnings.filterwarnings(
                "always", category=category, message=message)
        else:
            module_regex = module.__name__.replace('.', r'\.') + '$'
            warnings.filterwarnings(
                "always", category=category, message=message,
                module=module_regex)
            self._tmp_modules.add(module)

    def _filter(self, category=Warning, message="", module=None, record=False):
        # ``record`` becomes the list that collects matches, or None when
        # the filter merely silences them.
        record = [] if record else None
        if self._entered:
            # Inside the context: apply right away and remember the filter
            # only until the context is left.
            self._install_filter(category, message, module)
            self._clear_registries()
            target = self._tmp_suppressions
        else:
            # Outside the context: remember it; ``__enter__`` applies it.
            target = self._suppressions
        target.append(
            (category, message, re.compile(message, re.I), module, record))

        return record

    def filter(self, category=Warning, message="", module=None):
        """
        Add a new suppressing filter or apply it if the state is entered.

        Parameters
        ----------
        category : class, optional
            Warning class to filter
        message : string, optional
            Regular expression matching the warning message.
        module : module, optional
            Module to filter for. Note that the module (and its file)
            must match exactly and cannot be a submodule. This may make
            it unreliable for external modules.

        Notes
        -----
        When added within a context, filters are only added inside
        the context and will be forgotten when the context is exited.
        """
        self._filter(category=category, message=message, module=module,
                     record=False)

    def record(self, category=Warning, message="", module=None):
        """
        Append a new recording filter or apply it if the state is entered.

        All warnings matching will be appended to the ``log`` attribute.

        Parameters
        ----------
        category : class, optional
            Warning class to filter
        message : string, optional
            Regular expression matching the warning message.
        module : module, optional
            Module to filter for. Note that the module (and its file)
            must match exactly and cannot be a submodule. This may make
            it unreliable for external modules.

        Returns
        -------
        log : list
            A list which will be filled with all matched warnings.

        Notes
        -----
        When added within a context, filters are only added inside
        the context and will be forgotten when the context is exited.
        """
        return self._filter(category=category, message=message, module=module,
                            record=True)

    def __enter__(self):
        if self._entered:
            raise RuntimeError("cannot enter suppress_warnings twice.")

        # Remember the global state so ``__exit__`` can restore it, and work
        # on a copy of the filter list.
        self._orig_show = warnings.showwarning
        self._filters = warnings.filters
        warnings.filters = self._filters[:]

        self._entered = True
        self._tmp_suppressions = []
        self._tmp_modules = set()
        self._forwarded = set()

        self.log = []  # reset global log (no need to keep same list)

        for cat, mess, _, mod, log in self._suppressions:
            if log is not None:
                del log[:]  # clear the log
            self._install_filter(cat, mess, mod)
        warnings.showwarning = self._showwarning
        self._clear_registries()

        return self

    def __exit__(self, *exc_info):
        warnings.showwarning = self._orig_show
        warnings.filters = self._filters
        self._clear_registries()
        self._entered = False
        del self._orig_show
        del self._filters

    @staticmethod
    def _message_text(message):
        # Text the suppression patterns are matched against: the first
        # argument of the warning.  Warnings built without arguments match
        # as the empty string and non-string arguments are stringified, so
        # neither case raises from inside the warning machinery.
        args = getattr(message, "args", None)
        if args is None:
            return str(message)
        return str(args[0]) if args else ""

    def _forward(self, message, category, filename, lineno, args, kwargs,
                 use_warnmsg):
        # Hand an unsuppressed warning to the handler that was active when
        # the context was entered.  ``_orig_showmsg`` is never set by this
        # class, so it is used only if something else provided it;
        # otherwise the regular ``showwarning`` replacement is called.
        show_msg = getattr(self, "_orig_showmsg", None)
        if use_warnmsg is not None and show_msg is not None:
            show_msg(use_warnmsg)
        else:
            self._orig_show(message, category, filename, lineno,
                            *args, **kwargs)

    def _showwarning(self, message, category, filename, lineno,
                     *args, use_warnmsg=None, **kwargs):
        text = self._message_text(message)
        # Later filters take precedence, hence the reversed iteration.
        for cat, _, pattern, mod, rec in (
                self._suppressions + self._tmp_suppressions)[::-1]:
            if not (issubclass(category, cat) and
                    pattern.match(text) is not None):
                continue
            # Message and category match.  A module filter additionally
            # needs the file to match; use startswith, because warnings
            # strips the c or o from .pyc/.pyo files.
            if mod is None or mod.__file__.startswith(filename):
                if rec is not None:
                    msg = WarningMessage(message, category, filename,
                                         lineno, **kwargs)
                    self.log.append(msg)
                    rec.append(msg)
                return

        # There is no filter in place, so pass to the outside handler
        # unless we should only pass it once
        rule = self._forwarding_rule
        if rule != "always":
            if rule == "once":
                signature = (message.args, category)
            elif rule == "module":
                signature = (message.args, category, filename)
            else:
                # "location": the only other rule ``__init__`` accepts.
                signature = (message.args, category, filename, lineno)

            if signature in self._forwarded:
                return
            self._forwarded.add(signature)

        self._forward(message, category, filename, lineno, args, kwargs,
                      use_warnmsg)

    def __call__(self, func):
        """
        Function decorator to apply certain suppressions to a whole
        function.
        """
        @wraps(func)
        def new_func(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return new_func
2368
2369
@contextlib.contextmanager
def _assert_no_gc_cycles_context(name=None):
    """Context manager that fails if its body leaves reference cycles.

    Automatic collection is switched off and existing garbage is collected
    before the body runs.  Afterwards one more collection is performed with
    ``gc.DEBUG_SAVEALL`` set, and an ``AssertionError`` describing the
    collected objects is raised if that collection found anything.  The
    previous debug flags are restored and collection is re-enabled in all
    cases.

    Parameters
    ----------
    name : str, optional
        Name inserted into the failure message (" when calling <name>").

    Raises
    ------
    RuntimeError
        If 100 collections in a row still find garbage before the body runs.
    AssertionError
        If the body produced objects that only the cycle collector could
        reclaim.
    """
    __tracebackhide__ = True  # Hide traceback for py.test

    # not meaningful to test if there is no refcounting
    if not HAS_REFCOUNT:
        yield
        return

    assert_(gc.isenabled())
    gc.disable()
    saved_debug_flags = gc.get_debug()
    try:
        # Start from a clean slate: collect until nothing is left.
        for _attempt in range(100):
            if gc.collect() == 0:
                break
        else:
            raise RuntimeError(
                "Unable to fully collect garbage - perhaps a __del__ method "
                "is creating more reference cycles?")

        # Keep collected objects in gc.garbage so they can be reported.
        gc.set_debug(gc.DEBUG_SAVEALL)
        yield
        # gc.collect returns the number of unreachable objects in cycles that
        # were found -- we are checking that no cycles were created in the
        # context
        n_objects_in_cycles = gc.collect()
        objects_in_cycles = gc.garbage[:]
    finally:
        del gc.garbage[:]
        gc.set_debug(saved_debug_flags)
        gc.enable()

    if not n_objects_in_cycles:
        return

    name_str = f' when calling {name}' if name is not None else ''
    descriptions = [
        "\n  {} object with id={}:\n    {}".format(
            type(o).__name__,
            id(o),
            pprint.pformat(o).replace('\n', '\n    ')
        )
        for o in objects_in_cycles
    ]
    template = ("Reference cycles were found{}: {} objects were collected, "
                "of which {} are shown below:{}")
    raise AssertionError(
        template.format(
            name_str,
            n_objects_in_cycles,
            len(objects_in_cycles),
            ''.join(descriptions),
        )
    )
2420
2421
def assert_no_gc_cycles(*args, **kwargs):
    """
    Fail if the given callable produces any reference cycles.

    If called with all arguments omitted, may be used as a context manager:

        with assert_no_gc_cycles():
            do_something()

    .. versionadded:: 1.15.0

    Parameters
    ----------
    func : callable
        The callable to test.  It does not need a ``__name__`` attribute
        (e.g. ``functools.partial`` objects are accepted); its ``repr`` is
        used in the failure message in that case.
    \\*args : Arguments
        Arguments passed to `func`.
    \\*\\*kwargs : Kwargs
        Keyword arguments passed to `func`.  They are ignored when no
        positional argument is given (context-manager form).

    Returns
    -------
    Nothing. The result is deliberately discarded to ensure that all cycles
    are found.

    """
    if not args:
        return _assert_no_gc_cycles_context()

    func = args[0]
    args = args[1:]
    # Not every callable has __name__ (partial objects, callable instances);
    # fall back to the repr so the check itself does not raise.
    name = getattr(func, '__name__', None)
    if name is None:
        name = repr(func)
    with _assert_no_gc_cycles_context(name=name):
        func(*args, **kwargs)
2455
def break_cycles():
    """
    Break reference cycles by calling gc.collect

    Objects can call other objects' methods (for instance, another object's
    __del__) inside their own __del__. On PyPy, the interpreter only runs
    between calls to gc.collect, so multiple calls are needed to completely
    release all cycles.
    """
    # One collection everywhere; on PyPy a few more, just to make sure all
    # the finalizers are called.
    n_collections = 5 if IS_PYPY else 1
    for _ in range(n_collections):
        gc.collect()
2472
2473
def requires_memory(free_bytes):
    """Decorator to skip a test if not enough memory is available.

    Parameters
    ----------
    free_bytes : number
        Amount of free memory, in bytes, the decorated test needs; passed
        to `check_free_memory` every time the test is called.

    Returns
    -------
    decorator : callable
        Decorator whose wrapper calls ``pytest.skip`` with the message from
        `check_free_memory` when memory is short, and ``pytest.xfail`` when
        the test raises ``MemoryError`` anyway.
    """
    import pytest

    def decorate(test_func):
        @wraps(test_func)
        def guarded(*args, **kwargs):
            shortage = check_free_memory(free_bytes)
            if shortage is not None:
                pytest.skip(shortage)

            try:
                return test_func(*args, **kwargs)
            except MemoryError:
                # Probably ran out of memory regardless: don't regard as
                # failure
                pytest.xfail("MemoryError raised")

        return guarded

    return decorate
2494
2495
def check_free_memory(free_bytes):
    """
    Check whether `free_bytes` amount of memory is currently free.

    The amount of free memory is taken from the ``NPY_AVAILABLE_MEM``
    environment variable when it is set (parsed with `_parse_size`), and
    from `_get_mem_available` otherwise.  If neither yields a value the
    check is treated as failed.

    Returns: None if enough memory available, otherwise error message

    Raises: ValueError if ``NPY_AVAILABLE_MEM`` is set but cannot be parsed
    """
    env_var = 'NPY_AVAILABLE_MEM'
    env_value = os.environ.get(env_var)

    if env_value is None:
        mem_free = _get_mem_available()
        if mem_free is not None:
            msg = f'{free_bytes/1e9} GB memory required, but {mem_free/1e9} GB available'
        else:
            msg = ("Could not determine available memory; set NPY_AVAILABLE_MEM "
                   "environment variable (e.g. NPY_AVAILABLE_MEM=16GB) to run "
                   "the test.")
            # Unknown amount: -1 makes the comparison below fail.
            mem_free = -1
    else:
        try:
            mem_free = _parse_size(env_value)
        except ValueError as exc:
            raise ValueError(f'Invalid environment variable {env_var}: {exc}')

        msg = (f'{free_bytes/1e9} GB memory required, but environment variable '
               f'NPY_AVAILABLE_MEM={env_value} set')

    if mem_free < free_bytes:
        return msg
    return None
2523
2524
def _parse_size(size_str):
    """Convert a memory size string ('12 GB' etc.) to a byte count.

    The string is a non-negative decimal number, optionally followed by a
    unit.  Matching is case-insensitive and surrounding whitespace is
    ignored.  Units ``k``/``kb``, ``m``/``mb``, ... are powers of 1000,
    ``kib``, ``mib``, ... are powers of 1024, and no unit or ``b`` means
    bytes.

    Returns the size as an ``int`` (truncated); raises ``ValueError`` when
    the string does not have this form.
    """
    suffixes = {'': 1, 'b': 1,
                'k': 1000, 'm': 1000**2, 'g': 1000**3, 't': 1000**4,
                'kb': 1000, 'mb': 1000**2, 'gb': 1000**3, 'tb': 1000**4,
                'kib': 1024, 'mib': 1024**2, 'gib': 1024**3, 'tib': 1024**4}

    unit_alternatives = '|'.join(suffixes.keys())
    size_re = re.compile(
        r'^\s*(\d+|\d+\.\d+)\s*({0})\s*$'.format(unit_alternatives), re.I)

    parsed = size_re.match(size_str.lower())
    if parsed is None or parsed.group(2) not in suffixes:
        raise ValueError(f'value {size_str!r} not a valid size')

    number, unit = parsed.groups()
    return int(float(number) * suffixes[unit])
2539
2540
def _get_mem_available():
    """Return available memory in bytes, or None if unknown.

    ``psutil`` is used when it can be imported.  Otherwise, on Linux,
    ``/proc/meminfo`` is read: ``MemAvailable`` is returned when present,
    else the sum of ``MemFree`` and ``Cached``.  On other platforms the
    result is None.
    """
    try:
        import psutil
        return psutil.virtual_memory().available
    except (ImportError, AttributeError):
        pass

    if not sys.platform.startswith('linux'):
        return None

    with open('/proc/meminfo', 'r') as meminfo:
        # Each line is split on whitespace: field 0 is the key (with a
        # trailing colon), field 1 the value, which is scaled by 1024.
        info = {
            fields[0].strip(':').lower(): int(fields[1]) * 1024
            for fields in map(str.split, meminfo)
        }

    if 'memavailable' not in info:
        return info['memfree'] + info['cached']
    # Linux >= 3.14
    return info['memavailable']
2563
2564
def _no_tracing(func):
    """
    Decorator to temporarily turn off tracing for the duration of a test.

    Needed in tests that check refcounting, otherwise the tracing itself
    influences the refcounts.  The trace function that was active before
    the call is reinstalled afterwards, even if the test raises.  When
    ``sys`` has no ``gettrace`` the function is returned undecorated.
    """
    if not hasattr(sys, 'gettrace'):
        return func

    @wraps(func)
    def untraced(*args, **kwargs):
        previous_tracer = sys.gettrace()
        try:
            sys.settrace(None)
            return func(*args, **kwargs)
        finally:
            sys.settrace(previous_tracer)

    return untraced
2583
2584
def _get_glibc_version():
    """Return the glibc version string, or '0.0' if it cannot be determined.

    The version is the second space-separated field of
    ``os.confstr('CS_GNU_LIBC_VERSION')`` (presumably of the form
    ``'glibc 2.31'``).
    """
    try:
        ver = os.confstr('CS_GNU_LIBC_VERSION').rsplit(' ')[1]
    except Exception:
        # Deliberately best effort: any failure to query or parse the value
        # (for example on platforms without glibc) means "unknown".
        ver = '0.0'

    return ver


def _glibc_version_key(ver):
    """Turn a dotted version string into a tuple of ints for comparison.

    Parsing stops at the first component that does not start with a digit,
    so ``'2.17'`` gives ``(2, 17)`` and an unparsable string gives ``()``.
    """
    key = []
    for component in ver.split('.'):
        digits = re.match(r'\d+', component)
        if digits is None:
            break
        key.append(int(digits.group()))
    return tuple(key)


_glibcver = _get_glibc_version()


def _glibc_older_than(x):
    """Return True if the detected glibc is older than version string `x`.

    Versions are compared numerically component by component, so that
    e.g. 2.9 counts as older than 2.17 (a plain string comparison gets this
    wrong).  Returns False when the glibc version is unknown ('0.0') or
    cannot be parsed.
    """
    if _glibcver == '0.0':
        return False
    current = _glibc_version_key(_glibcver)
    return bool(current) and current < _glibc_version_key(x)