Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/numpy/testing/_private/utils.py: 14%
876 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-23 06:43 +0000
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-23 06:43 +0000
1"""
2Utility function to facilitate testing.
4"""
5import os
6import sys
7import platform
8import re
9import gc
10import operator
11import warnings
12from functools import partial, wraps
13import shutil
14import contextlib
15from tempfile import mkdtemp, mkstemp
16from unittest.case import SkipTest
17from warnings import WarningMessage
18import pprint
19import sysconfig
21import numpy as np
22from numpy.core import (
23 intp, float32, empty, arange, array_repr, ndarray, isnat, array)
24from numpy import isfinite, isnan, isinf
25import numpy.linalg._umath_linalg
27from io import StringIO
29__all__ = [
30 'assert_equal', 'assert_almost_equal', 'assert_approx_equal',
31 'assert_array_equal', 'assert_array_less', 'assert_string_equal',
32 'assert_array_almost_equal', 'assert_raises', 'build_err_msg',
33 'decorate_methods', 'jiffies', 'memusage', 'print_assert_equal',
34 'rundocs', 'runstring', 'verbose', 'measure',
35 'assert_', 'assert_array_almost_equal_nulp', 'assert_raises_regex',
36 'assert_array_max_ulp', 'assert_warns', 'assert_no_warnings',
37 'assert_allclose', 'IgnoreException', 'clear_and_catch_warnings',
38 'SkipTest', 'KnownFailureException', 'temppath', 'tempdir', 'IS_PYPY',
39 'HAS_REFCOUNT', "IS_WASM", 'suppress_warnings', 'assert_array_compare',
40 'assert_no_gc_cycles', 'break_cycles', 'HAS_LAPACK64', 'IS_PYSTON',
41 '_OLD_PROMOTION', 'IS_MUSL', '_SUPPORTS_SVE'
42 ]
class KnownFailureException(Exception):
    """Exception to raise when a test is known (and expected) to fail."""
# Old name kept as an alias so existing imports keep working.
KnownFailureTest = KnownFailureException
verbose = 0

# Runtime / platform feature flags, used by tests to skip or adapt.
IS_WASM = platform.machine() in ("wasm32", "wasm64")
IS_PYPY = sys.implementation.name == 'pypy'
IS_PYSTON = hasattr(sys, "pyston_version_info")
# Reference counts are only meaningful when sys.getrefcount exists and the
# interpreter is not Pyston.
HAS_REFCOUNT = not IS_PYSTON and getattr(sys, 'getrefcount', None) is not None
HAS_LAPACK64 = numpy.linalg._umath_linalg._ilp64


def _OLD_PROMOTION():
    """Return True when NumPy's dtype promotion state is 'legacy'."""
    return np._get_promotion_state() == 'legacy'


# Detect a musl-libc host from the configured GNU host triple. An
# alternative would be checking ``packaging.tags.sys_tags()`` for a
# 'musllinux' platform tag.
_v = sysconfig.get_config_var('HOST_GNU_TYPE') or ''
IS_MUSL = 'musl' in _v
def assert_(val, msg=''):
    """
    Assert that `val` is truthy, even when Python runs in optimized mode.

    Unlike the ``assert`` statement, which the ``-O`` flag strips from the
    byte-code, this function always performs the check.

    `msg` may be a callable; it is only called (with no arguments) when
    the assertion fails, which defers building expensive messages. See
    the Python documentation on ``assert`` for general usage.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    if val:
        return
    try:
        message = msg()
    except TypeError:
        # Calling msg failed with TypeError (typically msg is a plain
        # string rather than a callable), so use it verbatim.
        message = msg
    raise AssertionError(message)
if os.name == 'nt':
    # Code "stolen" from enthought/debug/memusage.py
    def GetPerformanceAttributes(object, counter, instance=None,
                                 inum=-1, format=None, machine=None):
        """Read a single formatted Windows performance counter value."""
        # NOTE: Many counters require 2 samples to give accurate results,
        # including "% Processor Time" (as by definition, at any instant, a
        # thread's CPU usage is either 0 or 100). To read counters like this,
        # you should copy this function, but keep the counter open, and call
        # CollectQueryData() each time you need to know.
        # See http://msdn.microsoft.com/library/en-us/dnperfmo/html/perfmonpt2.asp (dead link)
        # My older explanation for this was that the "AddCounter" process
        # forced the CPU to 100%, but the above makes more sense :)
        import win32pdh
        if format is None:
            format = win32pdh.PDH_FMT_LONG
        counter_path = win32pdh.MakeCounterPath(
            (machine, object, instance, None, inum, counter))
        query = win32pdh.OpenQuery()
        try:
            handle = win32pdh.AddCounter(query, counter_path)
            try:
                win32pdh.CollectQueryData(query)
                _, value = win32pdh.GetFormattedCounterValue(handle, format)
                return value
            finally:
                win32pdh.RemoveCounter(handle)
        finally:
            win32pdh.CloseQuery(query)

    def memusage(processName="python", instance=0):
        """Return the "Virtual Bytes" counter of the named process."""
        # from win32pdhutil, part of the win32all package
        import win32pdh
        return GetPerformanceAttributes("Process", "Virtual Bytes",
                                        processName, instance,
                                        win32pdh.PDH_FMT_LONG, None)
elif sys.platform.startswith('linux'):

    def memusage(_proc_pid_stat=f'/proc/{os.getpid()}/stat'):
        """
        Return virtual memory size in bytes of the running python.

        Returns None if the stat file cannot be read or parsed. The
        default path is bound once, from the PID at definition time.
        """
        try:
            with open(_proc_pid_stat) as stat_file:
                fields = stat_file.readline().split(' ')
            # 0-based index 22 of /proc/<pid>/stat (see proc(5)).
            return int(fields[22])
        except Exception:
            return None
else:
    def memusage():
        """
        Return memory usage of running python. [Not implemented]

        """
        raise NotImplementedError
if sys.platform.startswith('linux'):
    # The mutable default ``_load_time`` is intentional: it caches the
    # time of the first call across calls.
    def jiffies(_proc_pid_stat=f'/proc/{os.getpid()}/stat', _load_time=[]):
        """
        Return number of jiffies elapsed.

        Return number of jiffies (1/100ths of a second) that this
        process has been scheduled in user mode. See man 5 proc.

        """
        import time
        if not _load_time:
            _load_time.append(time.time())
        try:
            with open(_proc_pid_stat) as stat_file:
                fields = stat_file.readline().split(' ')
            # 0-based index 13 of /proc/<pid>/stat is the user-mode time.
            return int(fields[13])
        except Exception:
            # Fall back to wall-clock hundredths since the first call.
            return int(100 * (time.time() - _load_time[0]))
else:
    # os.getpid is not in all platforms available.
    # Using time is safe but inaccurate, especially when process
    # was suspended or sleeping.
    def jiffies(_load_time=[]):
        """
        Return number of jiffies elapsed.

        Return number of jiffies (1/100ths of a second) that this
        process has been scheduled in user mode. See man 5 proc.

        """
        import time
        if not _load_time:
            _load_time.append(time.time())
        return int(100 * (time.time() - _load_time[0]))
def build_err_msg(arrays, err_msg, header='Items are not equal:',
                  verbose=True, names=('ACTUAL', 'DESIRED'), precision=8):
    """
    Assemble the multi-line failure message used by the assertion helpers.

    Parameters
    ----------
    arrays : sequence
        Objects to show; ``arrays[i]`` is labelled with ``names[i]``.
    err_msg : str
        Extra message. A short single-line message is appended to the
        header line, and anything else goes on its own line(s).
    header : str, optional
        First line of the message (after a leading newline).
    verbose : bool, optional
        If True, append a ``repr`` of each object (at most 3 lines each).
    names : sequence of str, optional
        Labels for the objects in `arrays`.
    precision : int, optional
        Precision passed to ``array_repr`` for ndarray objects.

    Returns
    -------
    str
        The formatted message, starting with a newline.
    """
    first_line = '\n' + header
    body = []
    if err_msg:
        short_enough = ('\n' not in err_msg
                        and len(err_msg) < 79 - len(header))
        if short_enough:
            first_line = first_line + ' ' + err_msg
        else:
            body.append(err_msg)
    if verbose:
        for idx, item in enumerate(arrays):
            # Only ndarrays understand the precision argument.
            if isinstance(item, ndarray):
                render = partial(array_repr, precision=precision)
            else:
                render = repr

            try:
                text = render(item)
            except Exception as exc:
                text = f'[repr failed for <{type(item).__name__}>: {exc}]'
            if text.count('\n') > 3:
                text = '\n'.join(text.splitlines()[:3]) + '...'
            body.append(f' {names[idx]}: {text}')
    return '\n'.join([first_line] + body)
def assert_equal(actual, desired, err_msg='', verbose=True):
    """
    Raises an AssertionError if two objects are not equal.

    Given two objects (scalars, lists, tuples, dictionaries or numpy arrays),
    check that all elements of these objects are equal. An exception is raised
    at the first conflicting values. Dictionaries, lists and tuples are
    compared recursively; if either object is an ndarray the comparison is
    delegated to `assert_array_equal`.

    When one of `actual` and `desired` is a scalar and the other is array_like,
    the function checks that each element of the array_like object is equal to
    the scalar.

    This function handles NaN comparisons as if NaN was a "normal" number.
    That is, AssertionError is not raised if both objects have NaNs in the same
    positions.  This is in contrast to the IEEE standard on NaNs, which says
    that NaN compared to anything must return False.

    Parameters
    ----------
    actual : array_like
        The object to check.
    desired : array_like
        The expected object.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If actual and desired are not equal.

    Examples
    --------
    >>> np.testing.assert_equal([4,5], [4,6])
    Traceback (most recent call last):
        ...
    AssertionError:
    Items are not equal:
    item=1
     ACTUAL: 5
     DESIRED: 6

    The following comparison does not raise an exception.  There are NaNs
    in the inputs, but they are in the same positions.

    >>> np.testing.assert_equal(np.array([1.0, 2.0, np.nan]), [1, 2, np.nan])

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # Dictionaries: same length, every desired key present in actual, and
    # recursively equal values (the key is prefixed to err_msg).
    if isinstance(desired, dict):
        if not isinstance(actual, dict):
            raise AssertionError(repr(type(actual)))
        assert_equal(len(actual), len(desired), err_msg, verbose)
        for k, i in desired.items():
            if k not in actual:
                raise AssertionError(repr(k))
            assert_equal(actual[k], desired[k], f'key={k!r}\n{err_msg}',
                         verbose)
        return
    # Two sequences (list/tuple, in any combination): same length, then
    # element-wise recursion (the index is prefixed to err_msg).
    if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)):
        assert_equal(len(actual), len(desired), err_msg, verbose)
        for k in range(len(desired)):
            assert_equal(actual[k], desired[k], f'item={k!r}\n{err_msg}',
                         verbose)
        return
    from numpy.core import ndarray, isscalar, signbit
    from numpy.lib import iscomplexobj, real, imag
    # Any ndarray operand: delegate to the array comparison.
    if isinstance(actual, ndarray) or isinstance(desired, ndarray):
        return assert_array_equal(actual, desired, err_msg, verbose)
    msg = build_err_msg([actual, desired], err_msg, verbose=verbose)

    # Handle complex numbers: separate into real/imag to handle
    # nan/inf/negative zero correctly
    # XXX: catch ValueError for subclasses of ndarray where iscomplex fail
    try:
        usecomplex = iscomplexobj(actual) or iscomplexobj(desired)
    except (ValueError, TypeError):
        usecomplex = False

    if usecomplex:
        if iscomplexobj(actual):
            actualr = real(actual)
            actuali = imag(actual)
        else:
            actualr = actual
            actuali = 0
        if iscomplexobj(desired):
            desiredr = real(desired)
            desiredi = imag(desired)
        else:
            desiredr = desired
            desiredi = 0
        try:
            assert_equal(actualr, desiredr)
            assert_equal(actuali, desiredi)
        except AssertionError:
            raise AssertionError(msg)
        # NOTE: no return on success on purpose -- execution falls through
        # to the checks below (e.g. the isscalar check still rejects a
        # one-element list compared against a complex scalar).

    # isscalar test to check cases such as [np.nan] != np.nan
    if isscalar(desired) != isscalar(actual):
        raise AssertionError(msg)

    # NaT handling: two NaTs are equal only if their dtype types match.
    try:
        isdesnat = isnat(desired)
        isactnat = isnat(actual)
        dtypes_match = (np.asarray(desired).dtype.type ==
                        np.asarray(actual).dtype.type)
        if isdesnat and isactnat:
            # If both are NaT (and have the same dtype -- datetime or
            # timedelta) they are considered equal.
            if dtypes_match:
                return
            else:
                raise AssertionError(msg)

    except (TypeError, ValueError, NotImplementedError):
        # isnat is not applicable to these inputs; fall through.
        pass

    # Inf/nan/negative zero handling
    try:
        isdesnan = isnan(desired)
        isactnan = isnan(actual)
        if isdesnan and isactnan:
            return  # both nan, so equal

        # handle signed zero specially for floats
        array_actual = np.asarray(actual)
        array_desired = np.asarray(desired)
        if (array_actual.dtype.char in 'Mm' or
                array_desired.dtype.char in 'Mm'):
            # version 1.18
            # until this version, isnan failed for datetime64 and timedelta64.
            # Now it succeeds but comparison to scalar with a different type
            # emits a DeprecationWarning.
            # Avoid that by skipping the next check
            raise NotImplementedError('cannot compare to a scalar '
                                      'with a different type')

        # 0.0 and -0.0 compare equal, so require matching sign bits.
        if desired == 0 and actual == 0:
            if not signbit(desired) == signbit(actual):
                raise AssertionError(msg)

    except (TypeError, ValueError, NotImplementedError):
        pass

    try:
        # Explicitly use __eq__ for comparison, gh-2552
        if not (desired == actual):
            raise AssertionError(msg)

    except (DeprecationWarning, FutureWarning) as e:
        # this handles the case when the two types are not even comparable
        if 'elementwise == comparison' in e.args[0]:
            raise AssertionError(msg)
        else:
            raise
def print_assert_equal(test_string, actual, desired):
    """
    Check ``actual == desired`` and raise a pretty-printed report if not.

    Equality is evaluated with ``actual == desired``. On failure the
    message contains `test_string`, followed by pretty-printed
    representations of both objects.

    Parameters
    ----------
    test_string : str
        The message supplied to AssertionError.
    actual : object
        The object to test for equality against `desired`.
    desired : object
        The expected result.

    Examples
    --------
    >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 1])
    >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 2])
    Traceback (most recent call last):
    ...
    AssertionError: Test XYZ of func xyz failed
    ACTUAL:
    [0, 1]
    DESIRED:
    [0, 2]

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    if actual == desired:
        return
    # pprint.pformat(obj) + '\n' is exactly what pprint.pprint writes.
    report = ''.join([
        test_string,
        ' failed\nACTUAL: \n',
        pprint.pformat(actual), '\n',
        'DESIRED: \n',
        pprint.pformat(desired), '\n',
    ])
    raise AssertionError(report)
@np._no_nep50_warning()
def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True):
    """
    Raises an AssertionError if two items are not equal up to desired
    precision.

    .. note:: It is recommended to use one of `assert_allclose`,
              `assert_array_almost_equal_nulp` or `assert_array_max_ulp`
              instead of this function for more consistent floating point
              comparisons.

    The test verifies that the elements of `actual` and `desired` satisfy:

        ``abs(desired-actual) < float64(1.5 * 10**(-decimal))``

    That is a looser test than originally documented, but agrees with what the
    actual implementation in `assert_array_almost_equal` did up to rounding
    vagaries. An exception is raised at conflicting values. For ndarrays,
    lists and tuples this delegates to `assert_array_almost_equal`, forwarding
    `decimal`, `err_msg` and `verbose`.

    Parameters
    ----------
    actual : array_like
        The object to check.
    desired : array_like
        The expected object.
    decimal : int, optional
        Desired precision, default is 7.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
      If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Examples
    --------
    >>> from numpy.testing import assert_almost_equal
    >>> assert_almost_equal(2.3333333333333, 2.33333334)
    >>> assert_almost_equal(2.3333333333333, 2.33333334, decimal=10)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 10 decimals
     ACTUAL: 2.3333333333333
     DESIRED: 2.33333334

    >>> assert_almost_equal(np.array([1.0,2.3333333333333]),
    ...                     np.array([1.0,2.33333334]), decimal=9)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 9 decimals
    <BLANKLINE>
    Mismatched elements: 1 / 2 (50%)
    Max absolute difference: 6.66669964e-09
    Max relative difference: 2.85715698e-09
     x: array([1.         , 2.333333333])
     y: array([1.        , 2.33333334])

    """
    __tracebackhide__ = True  # Hide traceback for py.test

    # Handle complex numbers: separate into real/imag to handle
    # nan/inf/negative zero correctly.
    # XXX: catch ValueError/TypeError for subclasses of ndarray (or other
    # objects) where iscomplexobj fails -- same guard as assert_equal.
    try:
        usecomplex = np.iscomplexobj(actual) or np.iscomplexobj(desired)
    except (ValueError, TypeError):
        usecomplex = False

    def _build_err_msg():
        # Built lazily: formatting reprs is only needed on failure.
        header = ('Arrays are not almost equal to %d decimals' % decimal)
        return build_err_msg([actual, desired], err_msg, verbose=verbose,
                             header=header)

    if usecomplex:
        if np.iscomplexobj(actual):
            actualr = np.real(actual)
            actuali = np.imag(actual)
        else:
            actualr = actual
            actuali = 0
        if np.iscomplexobj(desired):
            desiredr = np.real(desired)
            desiredi = np.imag(desired)
        else:
            desiredr = desired
            desiredi = 0
        try:
            assert_almost_equal(actualr, desiredr, decimal=decimal)
            assert_almost_equal(actuali, desiredi, decimal=decimal)
        except AssertionError:
            raise AssertionError(_build_err_msg())

    if isinstance(actual, (np.ndarray, tuple, list)) \
            or isinstance(desired, (np.ndarray, tuple, list)):
        # Forward verbose too, so verbose=False is honoured for arrays.
        return assert_array_almost_equal(actual, desired, decimal, err_msg,
                                         verbose)
    try:
        # If one of desired/actual is not finite, handle it specially here:
        # check that both are nan if any is a nan, and test for equality
        # otherwise
        if not (isfinite(desired) and isfinite(actual)):
            if isnan(desired) or isnan(actual):
                if not (isnan(desired) and isnan(actual)):
                    raise AssertionError(_build_err_msg())
            else:
                if not desired == actual:
                    raise AssertionError(_build_err_msg())
            return
    except (NotImplementedError, TypeError):
        # isfinite/isnan do not support these inputs; fall back to the
        # plain absolute-difference test below.
        pass
    if abs(desired - actual) >= np.float64(1.5 * 10.0**(-decimal)):
        raise AssertionError(_build_err_msg())
@np._no_nep50_warning()
def assert_approx_equal(actual, desired, significant=7, err_msg='',
                        verbose=True):
    """
    Raises an AssertionError if two items are not equal up to significant
    digits.

    .. note:: It is recommended to use one of `assert_allclose`,
              `assert_array_almost_equal_nulp` or `assert_array_max_ulp`
              instead of this function for more consistent floating point
              comparisons.

    Given two numbers, check that they are approximately equal.
    Approximately equal is defined as the number of significant digits
    that agree. Both values are converted with ``float`` first. Non-finite
    values are equal only if both are NaN or if they compare equal.

    Parameters
    ----------
    actual : scalar
        The object to check.
    desired : scalar
        The expected object.
    significant : int, optional
        Desired precision, default is 7.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
      If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Examples
    --------
    >>> np.testing.assert_approx_equal(0.12345677777777e-20, 0.1234567e-20)
    >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345671e-20,
    ...                                significant=8)
    >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345672e-20,
    ...                                significant=8)
    Traceback (most recent call last):
        ...
    AssertionError:
    Items are not equal to 8 significant digits:
     ACTUAL: 1.234567e-21
     DESIRED: 1.2345672e-21

    the evaluated condition that raises the exception is

    >>> abs(0.12345670e-20/1e-21 - 0.12345672e-20/1e-21) >= 10**-(8-1)
    True

    """
    __tracebackhide__ = True  # Hide traceback for py.test

    (actual, desired) = map(float, (actual, desired))
    if desired == actual:
        return
    msg = build_err_msg(
        [actual, desired], err_msg,
        header='Items are not equal to %d significant digits:' % significant,
        verbose=verbose)

    # Handle non-finite values *before* any scaling: scaling an inf would
    # evaluate inf/inf and emit a spurious "invalid value" RuntimeWarning.
    # Check that both are nan if any is a nan, and test for equality
    # otherwise.
    try:
        if not (isfinite(desired) and isfinite(actual)):
            if isnan(desired) or isnan(actual):
                if not (isnan(desired) and isnan(actual)):
                    raise AssertionError(msg)
            else:
                if not desired == actual:
                    raise AssertionError(msg)
            return
    except (TypeError, NotImplementedError):
        pass

    # Normalize the numbers to be in range (-10.0, 10.0) by a common power
    # of ten derived from their mean magnitude:
    # scale = float(pow(10,math.floor(math.log10(0.5*(abs(desired)+abs(actual))))))
    # Both flags are silenced because tiny (subnormal) inputs can make the
    # scale underflow to 0, giving log10(0) and x/0.
    with np.errstate(invalid='ignore', divide='ignore'):
        scale = 0.5*(np.abs(desired) + np.abs(actual))
        scale = np.power(10, np.floor(np.log10(scale)))
        try:
            sc_desired = desired/scale
        except ZeroDivisionError:
            sc_desired = 0.0
        try:
            sc_actual = actual/scale
        except ZeroDivisionError:
            sc_actual = 0.0
        if np.abs(sc_desired - sc_actual) >= np.power(10., -(significant-1)):
            raise AssertionError(msg)
@np._no_nep50_warning()
def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
                         precision=6, equal_nan=True, equal_inf=True,
                         *, strict=False):
    """
    Element-wise comparison engine shared by the array assertion helpers.

    Parameters
    ----------
    comparison : callable
        Called as ``comparison(x, y)`` on the (possibly nan/inf/NaT
        filtered) arrays. Its result is either a Python bool or an object
        with ``.ravel()``, whose elements are truthy where the pair passes.
    x, y : array_like
        Objects to compare; converted with ``np.asanyarray``.
    err_msg : str, optional
        Extra text for the failure message.
    verbose : bool, optional
        If True, the values are appended to the failure message.
    header : str, optional
        First line of the failure message.
    precision : int, optional
        Precision used when formatting ndarrays in the message.
    equal_nan : bool, optional
        For numeric inputs, require NaNs at the same positions and exclude
        them from `comparison`. For datetime/timedelta inputs of the same
        dtype type, do the same for NaT.
    equal_inf : bool, optional
        For numeric inputs, require +inf and -inf at the same positions
        and exclude them from `comparison`.
    strict : bool, optional
        If True, shapes and dtypes must match exactly. Otherwise shapes
        must match unless one operand is 0-d.

    Raises
    ------
    AssertionError
        On shape/dtype mismatch, nan/inf/NaT location mismatch, or if any
        element fails `comparison`.
    ValueError
        If a ValueError is raised while comparing. It is re-raised with the
        formatted traceback prepended to the header.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    from numpy.core import (array2string, isnan, inf, bool_, errstate,
                            all, max, object_)

    x = np.asanyarray(x)
    y = np.asanyarray(y)

    # original array for output formatting
    ox, oy = x, y

    def isnumber(x):
        # bool, integer, float and complex dtype codes
        return x.dtype.char in '?bhilqpBHILQPefdgFDG'

    def istime(x):
        # datetime64 ('M') or timedelta64 ('m')
        return x.dtype.char in "Mm"

    def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
        """Handling nan/inf.

        Combine results of running func on x and y, checking that they are True
        at the same locations.

        """
        __tracebackhide__ = True  # Hide traceback for py.test

        x_id = func(x)
        y_id = func(y)
        # We include work-arounds here to handle three types of slightly
        # pathological ndarray subclasses:
        # (1) all() on `masked` array scalars can return masked arrays, so we
        #     use != True
        # (2) __eq__ on some ndarray subclasses returns Python booleans
        #     instead of element-wise comparisons, so we cast to bool_() and
        #     use isinstance(..., bool) checks
        # (3) subclasses with bare-bones __array_function__ implementations may
        #     not implement np.all(), so favor using the .all() method
        # We are not committed to supporting such subclasses, but it's nice to
        # support them if possible.
        if bool_(x_id == y_id).all() != True:
            msg = build_err_msg([x, y],
                                err_msg + '\nx and y %s location mismatch:'
                                % (hasval), verbose=verbose, header=header,
                                names=('x', 'y'), precision=precision)
            raise AssertionError(msg)
        # If there is a scalar, then here we know the array has the same
        # flag as it everywhere, so we should return the scalar flag.
        if isinstance(x_id, bool) or x_id.ndim == 0:
            return bool_(x_id)
        elif isinstance(y_id, bool) or y_id.ndim == 0:
            return bool_(y_id)
        else:
            return y_id

    try:
        # Shape (and, when strict, dtype) check.
        if strict:
            cond = x.shape == y.shape and x.dtype == y.dtype
        else:
            cond = (x.shape == () or y.shape == ()) or x.shape == y.shape
        if not cond:
            if x.shape != y.shape:
                reason = f'\n(shapes {x.shape}, {y.shape} mismatch)'
            else:
                reason = f'\n(dtypes {x.dtype}, {y.dtype} mismatch)'
            msg = build_err_msg([x, y],
                                err_msg
                                + reason,
                                verbose=verbose, header=header,
                                names=('x', 'y'), precision=precision)
            raise AssertionError(msg)

        # `flagged` marks positions (nan/inf/NaT) already verified to match;
        # they are removed before calling `comparison`.
        flagged = bool_(False)
        if isnumber(x) and isnumber(y):
            if equal_nan:
                flagged = func_assert_same_pos(x, y, func=isnan, hasval='nan')

            if equal_inf:
                flagged |= func_assert_same_pos(x, y,
                                                func=lambda xy: xy == +inf,
                                                hasval='+inf')
                flagged |= func_assert_same_pos(x, y,
                                                func=lambda xy: xy == -inf,
                                                hasval='-inf')

        elif istime(x) and istime(y):
            # If one is datetime64 and the other timedelta64 there is no point
            if equal_nan and x.dtype.type == y.dtype.type:
                flagged = func_assert_same_pos(x, y, func=isnat, hasval="NaT")

        if flagged.ndim > 0:
            x, y = x[~flagged], y[~flagged]
            # Only do the comparison if actual values are left
            if x.size == 0:
                return
        elif flagged:
            # no sense doing comparison if everything is flagged.
            return

        val = comparison(x, y)

        if isinstance(val, bool):
            cond = val
            reduced = array([val])
        else:
            reduced = val.ravel()
            cond = reduced.all()

        # The below comparison is a hack to ensure that fully masked
        # results, for which val.ravel().all() returns np.ma.masked,
        # do not trigger a failure (np.ma.masked != True evaluates as
        # np.ma.masked, which is falsy).
        if cond != True:
            n_mismatch = reduced.size - reduced.sum(dtype=intp)
            # Percentages are relative to the full (unfiltered) element
            # count when filtering happened.
            n_elements = flagged.size if flagged.ndim != 0 else reduced.size
            percent_mismatch = 100 * n_mismatch / n_elements
            remarks = [
                'Mismatched elements: {} / {} ({:.3g}%)'.format(
                    n_mismatch, n_elements, percent_mismatch)]

            with errstate(all='ignore'):
                # ignore errors for non-numeric types
                with contextlib.suppress(TypeError):
                    error = abs(x - y)
                    # Unsigned subtraction wraps around, so take the
                    # smaller of the two wrapped differences.
                    if np.issubdtype(x.dtype, np.unsignedinteger):
                        error2 = abs(y - x)
                        np.minimum(error, error2, out=error)
                    max_abs_error = max(error)
                    if getattr(error, 'dtype', object_) == object_:
                        remarks.append('Max absolute difference: '
                                       + str(max_abs_error))
                    else:
                        remarks.append('Max absolute difference: '
                                       + array2string(max_abs_error))

                    # note: this definition of relative error matches that one
                    # used by assert_allclose (found in np.isclose)
                    # Filter values where the divisor would be zero
                    nonzero = bool_(y != 0)
                    if all(~nonzero):
                        max_rel_error = array(inf)
                    else:
                        max_rel_error = max(error[nonzero] / abs(y[nonzero]))
                    if getattr(error, 'dtype', object_) == object_:
                        remarks.append('Max relative difference: '
                                       + str(max_rel_error))
                    else:
                        remarks.append('Max relative difference: '
                                       + array2string(max_rel_error))

            err_msg += '\n' + '\n'.join(remarks)
            # Report the original (unfiltered) arrays.
            msg = build_err_msg([ox, oy], err_msg,
                                verbose=verbose, header=header,
                                names=('x', 'y'), precision=precision)
            raise AssertionError(msg)
    except ValueError:
        import traceback
        efmt = traceback.format_exc()
        header = f'error during assertion:\n\n{efmt}\n\n{header}'

        msg = build_err_msg([x, y], err_msg, verbose=verbose, header=header,
                            names=('x', 'y'), precision=precision)
        raise ValueError(msg)
def assert_array_equal(x, y, err_msg='', verbose=True, *, strict=False):
    """
    Raise an AssertionError unless two array_like objects are equal.

    Shapes must agree and every pair of elements must compare equal (a
    scalar operand is treated specially, see Notes). Failure is reported on
    a shape mismatch or on conflicting values. Unlike ordinary numpy
    comparisons, NaNs are treated like numbers: matching NaN positions in
    both objects do not cause a failure.

    The usual caution for verifying equality with floating point numbers is
    advised.

    Parameters
    ----------
    x : array_like
        The actual object to check.
    y : array_like
        The desired, expected object.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.
    strict : bool, optional
        If True, raise an AssertionError when either the shape or the data
        type of the array_like objects does not match. The special
        handling for scalars mentioned in the Notes section is disabled.

        .. versionadded:: 1.24.0

    Raises
    ------
    AssertionError
        If actual and desired objects are not equal.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Notes
    -----
    When one of `x` and `y` is a scalar and the other is array_like, each
    element of the array_like object is compared against the scalar. Pass
    ``strict=True`` to disable this behaviour.

    Examples
    --------
    The first assert does not raise an exception:

    >>> np.testing.assert_array_equal([1.0,2.33333,np.nan],
    ...                               [np.exp(0),2.33333, np.nan])

    Assert fails with numerical imprecision with floats:

    >>> np.testing.assert_array_equal([1.0,np.pi,np.nan],
    ...                               [1, np.sqrt(np.pi)**2, np.nan])
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not equal
    <BLANKLINE>
    Mismatched elements: 1 / 3 (33.3%)
    Max absolute difference: 4.4408921e-16
    Max relative difference: 1.41357986e-16
     x: array([1.      , 3.141593,      nan])
     y: array([1.      , 3.141593,      nan])

    Use `assert_allclose` or one of the nulp (number of floating point values)
    functions for these cases instead:

    >>> np.testing.assert_allclose([1.0,np.pi,np.nan],
    ...                            [1, np.sqrt(np.pi)**2, np.nan],
    ...                            rtol=1e-10, atol=0)

    As mentioned in the Notes section, `assert_array_equal` has special
    handling for scalars. Here the test checks that each value in `x` is 3:

    >>> x = np.full((2, 5), fill_value=3)
    >>> np.testing.assert_array_equal(x, 3)

    Use `strict` to raise an AssertionError when comparing a scalar with an
    array:

    >>> np.testing.assert_array_equal(x, 3, strict=True)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not equal
    <BLANKLINE>
    (shapes (2, 5), () mismatch)
     x: array([[3, 3, 3, 3, 3],
           [3, 3, 3, 3, 3]])
     y: array(3)

    The `strict` parameter also ensures that the array data types match:

    >>> x = np.array([2, 2, 2])
    >>> y = np.array([2., 2., 2.], dtype=np.float32)
    >>> np.testing.assert_array_equal(x, y, strict=True)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not equal
    <BLANKLINE>
    (dtypes int64, float32 mismatch)
     x: array([2, 2, 2])
     y: array([2., 2., 2.], dtype=float32)
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # operator.eq is the same function object as operator.__eq__.
    assert_array_compare(
        operator.eq, x, y,
        header='Arrays are not equal',
        err_msg=err_msg,
        verbose=verbose,
        strict=strict,
    )
@np._no_nep50_warning()
def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
    """
    Raise an AssertionError unless two objects agree to `decimal` places.

    .. note:: It is recommended to use one of `assert_allclose`,
              `assert_array_almost_equal_nulp` or `assert_array_max_ulp`
              instead of this function for more consistent floating point
              comparisons.

    Shapes must be identical and every element pair of ``actual`` and
    ``desired`` must satisfy:

        ``abs(desired-actual) < 1.5 * 10**(-decimal)``

    This is looser than originally documented, but matches what the
    implementation has always done, up to rounding vagaries. Failure is
    reported on a shape mismatch or conflicting values. Unlike ordinary
    numpy comparisons, NaNs are compared like numbers: matching NaN
    positions do not cause a failure. Infinities must appear at the same
    positions in both objects.

    Parameters
    ----------
    x : array_like
        The actual object to check.
    y : array_like
        The desired, expected object.
    decimal : int, optional
        Desired precision, default is 6.
    err_msg : str, optional
      The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Examples
    --------
    The first assert does not raise an exception:

    >>> np.testing.assert_array_almost_equal([1.0,2.333,np.nan],
    ...                                      [1.0,2.333,np.nan])

    >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan],
    ...                                      [1.0,2.33339,np.nan], decimal=5)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 5 decimals
    <BLANKLINE>
    Mismatched elements: 1 / 3 (33.3%)
    Max absolute difference: 6.e-05
    Max relative difference: 2.57136612e-05
     x: array([1.     , 2.33333,     nan])
     y: array([1.     , 2.33339,     nan])

    >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan],
    ...                                      [1.0,2.33333, 5], decimal=5)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 5 decimals
    <BLANKLINE>
    x and y nan location mismatch:
     x: array([1.     , 2.33333,     nan])
     y: array([1.     , 2.33333, 5.     ])

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    from numpy.core import number, float_, result_type
    from numpy.core.numerictypes import issubdtype
    from numpy.core.fromnumeric import any as npany

    def compare(a, b):
        # Infinities must coincide exactly. They are then dropped so the
        # tolerance test below only sees the remaining values.
        try:
            if npany(isinf(a)) or npany(isinf(b)):
                a_inf = isinf(a)
                b_inf = isinf(b)
                if not (a_inf == b_inf).all():
                    return False
                # A single pair of +-inf values: compare them directly.
                if a.size == b.size == 1:
                    return a == b
                a = a[~a_inf]
                b = b[~b_inf]
        except (TypeError, NotImplementedError):
            # isinf is not supported for these inputs.
            pass

        # Promote b to an inexact dtype so abs() cannot hit abs(MIN_INT);
        # this also casts a during the subtraction.
        b = np.asanyarray(b, result_type(b, 1.))
        gap = abs(a - b)

        if not issubdtype(gap.dtype, number):
            gap = gap.astype(float_)  # object arrays

        return gap < 1.5 * 10.0**(-decimal)

    header = 'Arrays are not almost equal to %d decimals' % decimal
    assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
                         header=header, precision=decimal)
def assert_array_less(x, y, err_msg='', verbose=True):
    """
    Raises an AssertionError if two array_like objects are not ordered by less
    than.

    Given two array_like objects, check that the shape is equal and all
    elements of the first object are strictly smaller than those of the
    second object. An exception is raised at shape mismatch or incorrectly
    ordered values. Shape mismatch does not raise if an object has zero
    dimension. In contrast to the standard usage in numpy, NaNs are
    compared, no assertion is raised if both objects have NaNs in the same
    positions.

    Parameters
    ----------
    x : array_like
        The smaller object to check.
    y : array_like
        The larger object to compare.
    err_msg : string
        The error message to be printed in case of failure.
    verbose : bool
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If x is not strictly smaller than y, element-wise.

    See Also
    --------
    assert_array_equal: tests objects for equality
    assert_array_almost_equal: test objects for equality up to precision

    Examples
    --------
    >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1.1, 2.0, np.nan])
    >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1, 2.0, np.nan])
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not less-ordered
    <BLANKLINE>
    Mismatched elements: 1 / 3 (33.3%)
    Max absolute difference: 1.
    Max relative difference: 0.5
     x: array([ 1.,  1., nan])
     y: array([ 1.,  2., nan])

    >>> np.testing.assert_array_less([1.0, 4.0], 3)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not less-ordered
    <BLANKLINE>
    Mismatched elements: 1 / 2 (50%)
    Max absolute difference: 2.
    Max relative difference: 0.66666667
     x: array([1., 4.])
     y: array(3)

    >>> np.testing.assert_array_less([1.0, 2.0, 3.0], [4])
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not less-ordered
    <BLANKLINE>
    (shapes (3,), (1,) mismatch)
     x: array([1., 2., 3.])
     y: array([4])

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # Infinities are compared like ordinary values here (equal_inf=False):
    # inf < inf is a failure, not a match.
    options = {
        'err_msg': err_msg,
        'verbose': verbose,
        'header': 'Arrays are not less-ordered',
        'equal_inf': False,
    }
    assert_array_compare(operator.lt, x, y, **options)
def runstring(astr, dict):
    """
    Execute the Python source string `astr` with `dict` as its globals.

    Names bound by the executed code end up in `dict`. The string is run
    with the builtin ``exec``, so it must come from a trusted source.
    """
    namespace = dict
    exec(astr, namespace)
def assert_string_equal(actual, desired):
    """
    Test if two strings are equal.

    If the given strings are equal, `assert_string_equal` does nothing.
    If they are not equal, an AssertionError is raised, and the diff
    between the strings is shown.

    Parameters
    ----------
    actual : str
        The string to test for equality against the expected string.
    desired : str
        The expected string.

    Raises
    ------
    AssertionError
        If either argument is not a ``str`` (the message is the offending
        type), or if the strings differ (the message is a line diff).

    Examples
    --------
    >>> np.testing.assert_string_equal('abc', 'abc')
    >>> np.testing.assert_string_equal('abc', 'abcd')
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    ...
    AssertionError: Differences in strings:
    - abc+ abcd?    +

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # delay import of difflib to reduce startup time
    import difflib

    if not isinstance(actual, str):
        raise AssertionError(repr(type(actual)))
    if not isinstance(desired, str):
        raise AssertionError(repr(type(desired)))
    if desired == actual:
        return

    diff = list(difflib.Differ().compare(actual.splitlines(True),
                                         desired.splitlines(True)))
    n = len(diff)
    i = 0
    diff_list = []
    while i < n:
        line = diff[i]
        i += 1
        if line.startswith(' '):
            # Line common to both strings: nothing to report.
            continue
        if not line.startswith('- '):
            # A '+ ' line without a preceding '- ' (pure insertion) or a
            # stray '? ' hint line: report it unchanged.
            diff_list.append(line)
            continue
        # A '- ' line may be followed by its '? ' hint, the replacing
        # '+ ' line and that line's '? ' hint.  A pure deletion has no
        # '+ ' partner and is reported on its own.
        group = [line]
        if i < n and diff[i].startswith('? '):
            group.append(diff[i])
            i += 1
        if i < n and diff[i].startswith('+ '):
            added = diff[i]
            group.append(added)
            i += 1
            if i < n and diff[i].startswith('? '):
                group.append(diff[i])
                i += 1
            if added[2:] == line[2:]:
                # Same content on both sides: not a real difference.
                continue
        diff_list.extend(group)

    if diff_list:
        msg = f"Differences in strings:\n{''.join(diff_list).rstrip()}"
    else:
        # The strings are known to differ (checked above), so never pass
        # silently even if the line diff produced nothing to show.
        msg = f"Differences in strings:\n{actual!r} != {desired!r}"
    raise AssertionError(msg)
def rundocs(filename=None, raise_on_error=True):
    """
    Run doctests found in the given file.

    By default `rundocs` raises an AssertionError on failure.

    Parameters
    ----------
    filename : str
        The path to the file for which the doctests are run. If None (the
        default), the ``__file__`` found in the calling frame's globals is
        used.
    raise_on_error : bool
        Whether to raise an AssertionError when a doctest fails. Default is
        True. When False, doctest writes its report to stdout instead.

    Raises
    ------
    AssertionError
        If `raise_on_error` is True and at least one doctest fails; the
        collected doctest report is included in the message.
    ImportError
        If `filename` cannot be loaded as a Python module.

    Notes
    -----
    The doctests can be run by the user/developer by adding the ``doctests``
    argument to the ``test()`` call. For example, to run all tests (including
    doctests) for `numpy.lib`:

    >>> np.lib.test(doctests=True)  # doctest: +SKIP
    """
    import doctest
    import importlib.util

    if filename is None:
        f = sys._getframe(1)
        filename = f.f_globals['__file__']
    name = os.path.splitext(os.path.basename(filename))[0]

    # Load the file as a fresh module object (it is not added to
    # sys.modules).  This is the spec/exec_module sequence that
    # numpy.distutils.misc_util.exec_mod_from_location performs, done
    # directly so the deprecated numpy.distutils package (unavailable on
    # Python >= 3.12) is not needed.
    spec = importlib.util.spec_from_file_location(name, filename)
    if spec is None:
        raise ImportError(f"cannot load {filename!r} as a Python module")
    m = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(m)

    tests = doctest.DocTestFinder().find(m)
    runner = doctest.DocTestRunner(verbose=False)

    msg = []
    # Collect the report for the exception message when raising; otherwise
    # out=None lets doctest print to stdout.
    out = msg.append if raise_on_error else None

    for test in tests:
        runner.run(test, out=out)

    if runner.failures > 0 and raise_on_error:
        raise AssertionError("Some doctests failed:\n%s" % "\n".join(msg))
1239def check_support_sve():
1240 """
1241 gh-22982
1242 """
1244 import subprocess
1245 cmd = 'lscpu'
1246 try:
1247 output = subprocess.run(cmd, capture_output=True, text=True)
1248 return 'sve' in output.stdout
1249 except OSError:
1250 return False
1253_SUPPORTS_SVE = check_support_sve()
1255#
1256# assert_raises and assert_raises_regex are taken from unittest.
1257#
1258import unittest
1261class _Dummy(unittest.TestCase):
1262 def nop(self):
1263 pass
1266_d = _Dummy('nop')
def assert_raises(*args, **kwargs):
    """
    assert_raises(exception_class, callable, *args, **kwargs)
    assert_raises(exception_class)

    Fail unless an exception of class exception_class is thrown
    by callable when invoked with arguments args and keyword
    arguments kwargs. If a different type of exception is
    thrown, it will not be caught, and the test case will be
    deemed to have suffered an error, exactly as for an
    unexpected exception.

    Alternatively, `assert_raises` can be used as a context manager:

    >>> from numpy.testing import assert_raises
    >>> with assert_raises(ZeroDivisionError):
    ...     1 / 0

    is equivalent to

    >>> def div(x, y):
    ...     return x / y
    >>> assert_raises(ZeroDivisionError, div, 1, 0)

    Both forms are forwarded unchanged to ``unittest.TestCase.assertRaises``.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    checker = _d.assertRaises
    return checker(*args, **kwargs)
def assert_raises_regex(exception_class, expected_regexp, *args, **kwargs):
    """
    assert_raises_regex(exception_class, expected_regexp, callable, *args,
                        **kwargs)
    assert_raises_regex(exception_class, expected_regexp)

    Fail unless an exception of class exception_class and with message that
    matches expected_regexp is thrown by callable when invoked with arguments
    args and keyword arguments kwargs.

    Alternatively, can be used as a context manager like `assert_raises`.
    All arguments are forwarded to ``unittest.TestCase.assertRaisesRegex``.

    Notes
    -----
    .. versionadded:: 1.9.0

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    checker = _d.assertRaisesRegex
    return checker(exception_class, expected_regexp, *args, **kwargs)
1319def decorate_methods(cls, decorator, testmatch=None):
1320 """
1321 Apply a decorator to all methods in a class matching a regular expression.
1323 The given decorator is applied to all public methods of `cls` that are
1324 matched by the regular expression `testmatch`
1325 (``testmatch.search(methodname)``). Methods that are private, i.e. start
1326 with an underscore, are ignored.
1328 Parameters
1329 ----------
1330 cls : class
1331 Class whose methods to decorate.
1332 decorator : function
1333 Decorator to apply to methods
1334 testmatch : compiled regexp or str, optional
1335 The regular expression. Default value is None, in which case the
1336 nose default (``re.compile(r'(?:^|[\\b_\\.%s-])[Tt]est' % os.sep)``)
1337 is used.
1338 If `testmatch` is a string, it is compiled to a regular expression
1339 first.
1341 """
1342 if testmatch is None:
1343 testmatch = re.compile(r'(?:^|[\\b_\\.%s-])[Tt]est' % os.sep)
1344 else:
1345 testmatch = re.compile(testmatch)
1346 cls_attr = cls.__dict__
1348 # delayed import to reduce startup time
1349 from inspect import isfunction
1351 methods = [_m for _m in cls_attr.values() if isfunction(_m)]
1352 for function in methods:
1353 try:
1354 if hasattr(function, 'compat_func_name'):
1355 funcname = function.compat_func_name
1356 else:
1357 funcname = function.__name__
1358 except AttributeError:
1359 # not a function
1360 continue
1361 if testmatch.search(funcname) and not funcname.startswith('_'):
1362 setattr(cls, funcname, decorator(function))
1363 return
def measure(code_str, times=1, label=None):
    """
    Return elapsed time for executing code in the namespace of the caller.

    The supplied code string is compiled with the Python builtin ``compile``.
    The precision of the timing is 10 milli-seconds. If the code will execute
    fast on this timescale, it can be executed many times to get reasonable
    timing accuracy.

    Parameters
    ----------
    code_str : str
        The code to be timed.
    times : int, optional
        The number of times the code is executed. Default is 1. The code is
        only compiled once.
    label : str, optional
        A label to identify `code_str` with. This is passed into ``compile``
        as the second argument (for run-time error messages).

    Returns
    -------
    elapsed : float
        Total elapsed time in seconds for executing `code_str` `times` times.

    Examples
    --------
    >>> times = 10
    >>> etime = np.testing.measure('for i in range(1000): np.sqrt(i**2)', times=times)
    >>> print("Time for a single execution : ", etime / times, "s")  # doctest: +SKIP
    Time for a single execution :  0.005 s

    """
    caller = sys._getframe(1)
    # Capture the caller's namespaces once, so every run shares them.
    local_ns, global_ns = caller.f_locals, caller.f_globals

    code = compile(code_str, f'Test name: {label} ', 'exec')
    runs_done = 0
    started = jiffies()
    while runs_done < times:
        runs_done += 1
        exec(code, global_ns, local_ns)
    # jiffies() counts hundredths of a second.
    return 0.01 * (jiffies() - started)
def _assert_valid_refcount(op):
    """
    Check that ufuncs don't mishandle refcount of object `1`.
    Used in a few regression tests.

    Parameters
    ----------
    op : callable
        Called 15 times as ``op(b, c)`` where ``b`` and ``c`` are the same
        100x100 array built from ``np.arange(100*100)``. After each call the
        reference count of the int ``1`` must not have dropped below its
        value before the loop.

    Returns
    -------
    True when ``HAS_REFCOUNT`` is false and the check is skipped;
    otherwise None.
    """
    if not HAS_REFCOUNT:
        return True

    import gc
    import numpy as np

    b = np.arange(100*100).reshape(100, 100)
    c = b
    i = 1

    # Bind ``d`` up front: if ``op`` raises on its first call, the
    # ``del d`` below must not turn that error into an UnboundLocalError.
    d = None
    # Presumably disabled so a collection cannot shift refcounts mid-check.
    gc.disable()
    try:
        rc = sys.getrefcount(i)
        for _ in range(15):
            d = op(b, c)
            assert_(sys.getrefcount(i) >= rc)
    finally:
        gc.enable()
        del d  # for pyflakes
def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
                    err_msg='', verbose=True):
    """
    Raises an AssertionError if two objects are not equal up to desired
    tolerance.

    Given two array_like objects, check that their shapes and all elements
    are equal (but see the Notes for the special handling of a scalar). An
    exception is raised if the shapes mismatch or any values conflict. In
    contrast to the standard usage in numpy, NaNs are compared like numbers,
    no assertion is raised if both objects have NaNs in the same positions.

    The test is equivalent to ``allclose(actual, desired, rtol, atol)`` (note
    that ``allclose`` has different default values). It compares the difference
    between `actual` and `desired` to ``atol + rtol * abs(desired)``.

    .. versionadded:: 1.5.0

    Parameters
    ----------
    actual : array_like
        Array obtained.
    desired : array_like
        Array desired.
    rtol : float, optional
        Relative tolerance.
    atol : float, optional
        Absolute tolerance.
    equal_nan : bool, optional.
        If True, NaNs will compare equal.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_array_almost_equal_nulp, assert_array_max_ulp

    Notes
    -----
    When one of `actual` and `desired` is a scalar and the other is
    array_like, the function checks that each element of the array_like
    object is equal to the scalar.

    Examples
    --------
    >>> x = [1e-5, 1e-3, 1e-1]
    >>> y = np.arccos(np.cos(x))
    >>> np.testing.assert_allclose(x, y, rtol=1e-5, atol=0)

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np

    def compare(x, y):
        # Element-wise tolerance test, delegated to numpy's isclose.
        return np.core.numeric.isclose(x, y, rtol=rtol, atol=atol,
                                       equal_nan=equal_nan)

    actual = np.asanyarray(actual)
    desired = np.asanyarray(desired)
    header = f'Not equal to tolerance rtol={rtol:g}, atol={atol:g}'
    assert_array_compare(compare, actual, desired,
                         err_msg=str(err_msg),
                         verbose=verbose,
                         header=header,
                         equal_nan=equal_nan)
def assert_array_almost_equal_nulp(x, y, nulp=1):
    """
    Compare two arrays relatively to their spacing.

    This is a relatively robust method to compare two arrays whose amplitude
    is variable.

    Parameters
    ----------
    x, y : array_like
        Input arrays.
    nulp : int, optional
        The maximum number of unit in the last place for tolerance (see Notes).
        Default is 1.

    Returns
    -------
    None

    Raises
    ------
    AssertionError
        If the spacing between `x` and `y` for one or more elements is larger
        than `nulp`.

    See Also
    --------
    assert_array_max_ulp : Check that all items of arrays differ in at most
        N Units in the Last Place.
    spacing : Return the distance between x and the nearest adjacent number.

    Notes
    -----
    An assertion is raised if the following condition is not met::

        abs(x - y) <= nulp * spacing(maximum(abs(x), abs(y)))

    Examples
    --------
    >>> x = np.array([1., 1e-10, 1e-20])
    >>> eps = np.finfo(x.dtype).eps
    >>> np.testing.assert_array_almost_equal_nulp(x, x*eps/2 + x)

    >>> np.testing.assert_array_almost_equal_nulp(x, x*eps + x)
    Traceback (most recent call last):
      ...
    AssertionError: X and Y are not equal to 1 ULP (max is 2)

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np

    abs_x = np.abs(x)
    abs_y = np.abs(y)
    # Larger magnitude per element (np.where rather than np.maximum keeps
    # the original NaN behaviour: a NaN in x selects y's value).
    larger = np.where(abs_x > abs_y, abs_x, abs_y)
    tolerance = nulp * np.spacing(larger)
    if np.all(np.abs(x - y) <= tolerance):
        return

    if np.iscomplexobj(x) or np.iscomplexobj(y):
        # nulp_diff does not support complex input, so report no maximum.
        msg = "X and Y are not equal to %d ULP" % nulp
    else:
        max_nulp = np.max(nulp_diff(x, y))
        msg = "X and Y are not equal to %d ULP (max is %g)" % (nulp, max_nulp)
    raise AssertionError(msg)
def assert_array_max_ulp(a, b, maxulp=1, dtype=None):
    """
    Check that all items of arrays differ in at most N Units in the Last Place.

    Parameters
    ----------
    a, b : array_like
        Input arrays to be compared.
    maxulp : int, optional
        The maximum number of units in the last place that elements of `a` and
        `b` can differ. Default is 1.
    dtype : dtype, optional
        Data-type to convert `a` and `b` to if given. Default is None.

    Returns
    -------
    ret : ndarray
        Array containing number of representable floating point numbers between
        items in `a` and `b`.

    Raises
    ------
    AssertionError
        If one or more elements differ by more than `maxulp`.

    Notes
    -----
    For computing the ULP difference, this API does not differentiate between
    various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000
    is zero).

    See Also
    --------
    assert_array_almost_equal_nulp : Compare two arrays relatively to their
        spacing.

    Examples
    --------
    >>> a = np.linspace(0., 1., 100)
    >>> res = np.testing.assert_array_max_ulp(a, np.arcsin(np.sin(a)))

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np

    ulps = nulp_diff(a, b, dtype)
    if np.all(ulps <= maxulp):
        return ulps
    raise AssertionError("Arrays are not almost equal up to %g "
                         "ULP (max difference is %g ULP)" %
                         (maxulp, np.max(ulps)))
def nulp_diff(x, y, dtype=None):
    """For each item in x and y, return the number of representable floating
    points between them.

    Parameters
    ----------
    x : array_like
        first input array
    y : array_like
        second input array
    dtype : dtype, optional
        Data-type to convert `x` and `y` to if given. Default is None.

    Returns
    -------
    nulp : array_like
        number of representable floating point numbers between each item in x
        and y. The result has an extra leading axis of length 1 and the
        floating type common to `x` and `y`.

    Raises
    ------
    NotImplementedError
        If `x` or `y` is complex.
    ValueError
        If `x` and `y` do not have the same shape (no broadcasting).

    Notes
    -----
    For computing the ULP difference, this API does not differentiate between
    various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000
    is zero).

    Examples
    --------
    # By definition, epsilon is the smallest number such as 1 + eps != 1, so
    # there should be exactly one ULP between 1 and 1 + eps
    >>> from numpy.testing._private.utils import nulp_diff
    >>> nulp_diff(1, 1 + np.finfo(np.float64).eps)
    array([1.])
    """
    import numpy as np
    # Compare against None explicitly: np.dtype objects have len() == 0 and
    # are therefore falsy, so ``if dtype:`` would silently ignore them.
    if dtype is not None:
        x = np.asarray(x, dtype=dtype)
        y = np.asarray(y, dtype=dtype)
    else:
        x = np.asarray(x)
        y = np.asarray(y)

    t = np.common_type(x, y)
    if np.iscomplexobj(x) or np.iscomplexobj(y):
        raise NotImplementedError("_nulp not implemented for complex array")

    x = np.array([x], dtype=t)
    y = np.array([y], dtype=t)

    # Normalise every NaN bit pattern to the canonical NaN.
    x[np.isnan(x)] = np.nan
    y[np.isnan(y)] = np.nan

    if not x.shape == y.shape:
        raise ValueError("x and y do not have the same shape: %s - %s" %
                         (x.shape, y.shape))

    def _diff(rx, ry, vdt):
        # |rx - ry| can exceed the signed range of the integer type (e.g.
        # +inf vs -inf in float64).  Subtract the smaller from the larger in
        # the unsigned type of the same width, where the non-negative
        # difference always fits, then convert to the float type ``vdt``.
        unsigned = np.dtype(f'u{rx.dtype.itemsize}')
        hi = np.maximum(rx, ry).astype(unsigned)
        lo = np.minimum(rx, ry).astype(unsigned)
        return np.asarray(hi - lo, dtype=vdt)

    rx = integer_repr(x)
    ry = integer_repr(y)
    return _diff(rx, ry, t)
1686def _integer_repr(x, vdt, comp):
1687 # Reinterpret binary representation of the float as sign-magnitude:
1688 # take into account two-complement representation
1689 # See also
1690 # https://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/
1691 rx = x.view(vdt)
1692 if not (rx.size == 1):
1693 rx[rx < 0] = comp - rx[rx < 0]
1694 else:
1695 if rx < 0:
1696 rx = comp - rx
1698 return rx
def integer_repr(x):
    """Return the signed-magnitude interpretation of the binary representation
    of x.

    Supports float16, float32 and float64 arrays, reinterpreted as int16,
    int32 and int64 respectively; any other dtype raises ValueError.
    """
    import numpy as np
    for float_type, int_type in ((np.float16, np.int16),
                                 (np.float32, np.int32),
                                 (np.float64, np.int64)):
        if x.dtype == float_type:
            bits = 8 * np.dtype(int_type).itemsize
            # Most negative value of the integer type, e.g. -2**63.
            return _integer_repr(x, int_type, int_type(-2**(bits - 1)))
    raise ValueError(f'Unsupported dtype {x.dtype}')
@contextlib.contextmanager
def _assert_warns_context(warning_class, name=None):
    """Context manager failing unless `warning_class` is warned inside it.

    `name`, when given, is mentioned in the AssertionError message.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    with suppress_warnings() as sup:
        recorded = sup.record(warning_class)
        yield
        if not recorded:
            suffix = '' if name is None else f' when calling {name}'
            raise AssertionError("No warning raised" + suffix)
def assert_warns(warning_class, *args, **kwargs):
    """
    Fail unless the given callable throws the specified warning.

    A warning of class warning_class should be thrown by the callable when
    invoked with arguments args and keyword arguments kwargs.
    If a different type of warning is thrown, it will not be caught.

    If called with all arguments other than the warning class omitted, may be
    used as a context manager::

        with assert_warns(SomeWarning):
            do_something()

    The ability to be used as a context manager is new in NumPy v1.11.0.

    .. versionadded:: 1.4.0

    Parameters
    ----------
    warning_class : class
        The class defining the warning that `func` is expected to throw.
    func : callable, optional
        Callable to test
    *args : Arguments
        Arguments for `func`.
    **kwargs : Kwargs
        Keyword arguments for `func`.

    Returns
    -------
    The value returned by `func`.

    Raises
    ------
    AssertionError
        If no warning of `warning_class` is recorded.
    RuntimeError
        If keyword arguments are given without a callable (they would
        otherwise be silently ignored).

    Examples
    --------
    >>> import warnings
    >>> def deprecated_func(num):
    ...     warnings.warn("Please upgrade", DeprecationWarning)
    ...     return num*num
    >>> with np.testing.assert_warns(DeprecationWarning):
    ...     assert deprecated_func(4) == 16
    >>> # or passing a func
    >>> ret = np.testing.assert_warns(DeprecationWarning, deprecated_func, 4)
    >>> assert ret == 16
    """
    if not args:
        if kwargs:
            # Without a callable the keyword arguments have nowhere to go;
            # dropping them (e.g. a pytest-style ``match=``) would silently
            # weaken the check.
            if 'match' in kwargs:
                raise RuntimeError(
                    "assert_warns does not use 'match' kwarg, "
                    "use pytest.warns instead")
            raise RuntimeError(
                "assert_warns(...) needs a callable as its first positional "
                "argument when keyword arguments are given")
        return _assert_warns_context(warning_class)

    func = args[0]
    args = args[1:]
    # Not every callable has __name__ (e.g. functools.partial objects).
    name = getattr(func, '__name__', repr(func))
    with _assert_warns_context(warning_class, name=name):
        return func(*args, **kwargs)
1780@contextlib.contextmanager
1781def _assert_no_warnings_context(name=None):
1782 __tracebackhide__ = True # Hide traceback for py.test
1783 with warnings.catch_warnings(record=True) as l:
1784 warnings.simplefilter('always')
1785 yield
1786 if len(l) > 0:
1787 name_str = f' when calling {name}' if name is not None else ''
1788 raise AssertionError(f'Got warnings{name_str}: {l}')
def assert_no_warnings(*args, **kwargs):
    """
    Fail if the given callable produces any warnings.

    If called with all arguments omitted, may be used as a context manager::

        with assert_no_warnings():
            do_something()

    The ability to be used as a context manager is new in NumPy v1.11.0.

    .. versionadded:: 1.7.0

    Parameters
    ----------
    func : callable
        The callable to test.
    \\*args : Arguments
        Arguments passed to `func`.
    \\*\\*kwargs : Kwargs
        Keyword arguments passed to `func`.

    Returns
    -------
    The value returned by `func`.

    Raises
    ------
    AssertionError
        If any warning is emitted by `func` (or inside the ``with`` block).
    RuntimeError
        If keyword arguments are given without a callable (they would
        otherwise be silently ignored).

    """
    if not args:
        if kwargs:
            # Without a callable the keyword arguments have nowhere to go;
            # silently dropping them would hide a misuse.
            raise RuntimeError(
                "assert_no_warnings(...) needs a callable as its first "
                "positional argument when keyword arguments are given")
        return _assert_no_warnings_context()

    func = args[0]
    args = args[1:]
    # Not every callable has __name__ (e.g. functools.partial objects).
    name = getattr(func, '__name__', repr(func))
    with _assert_no_warnings_context(name=name):
        return func(*args, **kwargs)
1827def _gen_alignment_data(dtype=float32, type='binary', max_size=24):
1828 """
1829 generator producing data with different alignment and offsets
1830 to test simd vectorization
1832 Parameters
1833 ----------
1834 dtype : dtype
1835 data type to produce
1836 type : string
1837 'unary': create data for unary operations, creates one input
1838 and output array
1839 'binary': create data for unary operations, creates two input
1840 and output array
1841 max_size : integer
1842 maximum size of data to produce
1844 Returns
1845 -------
1846 if type is 'unary' yields one output, one input array and a message
1847 containing information on the data
1848 if type is 'binary' yields one output array, two input array and a message
1849 containing information on the data
1851 """
1852 ufmt = 'unary offset=(%d, %d), size=%d, dtype=%r, %s'
1853 bfmt = 'binary offset=(%d, %d, %d), size=%d, dtype=%r, %s'
1854 for o in range(3):
1855 for s in range(o + 2, max(o + 3, max_size)):
1856 if type == 'unary':
1857 inp = lambda: arange(s, dtype=dtype)[o:]
1858 out = empty((s,), dtype=dtype)[o:]
1859 yield out, inp(), ufmt % (o, o, s, dtype, 'out of place')
1860 d = inp()
1861 yield d, d, ufmt % (o, o, s, dtype, 'in place')
1862 yield out[1:], inp()[:-1], ufmt % \
1863 (o + 1, o, s - 1, dtype, 'out of place')
1864 yield out[:-1], inp()[1:], ufmt % \
1865 (o, o + 1, s - 1, dtype, 'out of place')
1866 yield inp()[:-1], inp()[1:], ufmt % \
1867 (o, o + 1, s - 1, dtype, 'aliased')
1868 yield inp()[1:], inp()[:-1], ufmt % \
1869 (o + 1, o, s - 1, dtype, 'aliased')
1870 if type == 'binary':
1871 inp1 = lambda: arange(s, dtype=dtype)[o:]
1872 inp2 = lambda: arange(s, dtype=dtype)[o:]
1873 out = empty((s,), dtype=dtype)[o:]
1874 yield out, inp1(), inp2(), bfmt % \
1875 (o, o, o, s, dtype, 'out of place')
1876 d = inp1()
1877 yield d, d, inp2(), bfmt % \
1878 (o, o, o, s, dtype, 'in place1')
1879 d = inp2()
1880 yield d, inp1(), d, bfmt % \
1881 (o, o, o, s, dtype, 'in place2')
1882 yield out[1:], inp1()[:-1], inp2()[:-1], bfmt % \
1883 (o + 1, o, o, s - 1, dtype, 'out of place')
1884 yield out[:-1], inp1()[1:], inp2()[:-1], bfmt % \
1885 (o, o + 1, o, s - 1, dtype, 'out of place')
1886 yield out[:-1], inp1()[:-1], inp2()[1:], bfmt % \
1887 (o, o, o + 1, s - 1, dtype, 'out of place')
1888 yield inp1()[1:], inp1()[:-1], inp2()[:-1], bfmt % \
1889 (o + 1, o, o, s - 1, dtype, 'aliased')
1890 yield inp1()[:-1], inp1()[1:], inp2()[:-1], bfmt % \
1891 (o, o + 1, o, s - 1, dtype, 'aliased')
1892 yield inp1()[:-1], inp1()[:-1], inp2()[1:], bfmt % \
1893 (o, o, o + 1, s - 1, dtype, 'aliased')
class IgnoreException(Exception):
    """Raised to signal that a result should be ignored because a feature is disabled."""
@contextlib.contextmanager
def tempdir(*args, **kwargs):
    """Context manager yielding the path of a fresh temporary directory.

    Every argument is forwarded unchanged to ``tempfile.mkdtemp``. The
    directory and everything in it are deleted with ``shutil.rmtree`` when
    the context exits, whether or not an exception occurred.

    """
    created = mkdtemp(*args, **kwargs)
    try:
        yield created
    finally:
        shutil.rmtree(created)
@contextlib.contextmanager
def temppath(*args, **kwargs):
    """Context manager yielding the path of a closed temporary file.

    All arguments are passed straight to ``tempfile.mkstemp``. The OS-level
    handle that ``mkstemp`` opens is closed immediately, so only the path is
    handed out. The file is removed with ``os.remove`` when the context
    exits, so any handle the caller opened on it should be closed by then.

    Windows does not allow a temporary file to be opened if it is already
    open, which is why the handle is closed before the path is yielded.

    """
    handle, path = mkstemp(*args, **kwargs)
    os.close(handle)
    try:
        yield path
    finally:
        os.remove(path)
class clear_and_catch_warnings(warnings.catch_warnings):
    """ Context manager that resets warning registry for catching warnings

    Warnings can be slippery, because, whenever a warning is triggered, Python
    adds a ``__warningregistry__`` member to the *calling* module. This makes
    it impossible to retrigger the warning in this module, whatever you put in
    the warnings filters. This context manager accepts a sequence of `modules`
    as a keyword argument to its constructor and:

    * stores and removes any ``__warningregistry__`` entries in given `modules`
      on entry;
    * resets ``__warningregistry__`` to its previous state on exit.

    This makes it possible to trigger any warning afresh inside the context
    manager without disturbing the state of warnings outside.

    For compatibility with Python 3.0, please consider all arguments to be
    keyword-only.

    Parameters
    ----------
    record : bool, optional
        Specifies whether warnings should be captured by a custom
        implementation of ``warnings.showwarning()`` and be appended to a list
        returned by the context manager. Otherwise None is returned by the
        context manager. The objects appended to the list are arguments whose
        attributes mirror the arguments to ``showwarning()``.
    modules : sequence, optional
        Sequence of modules for which to reset warnings registry on entry and
        restore on exit. To work correctly, all 'ignore' filters should
        filter by one of these modules. Modules listed in the class attribute
        ``class_modules`` are always included as well.

    Examples
    --------
    >>> import warnings
    >>> with np.testing.clear_and_catch_warnings(
    ...         modules=[np.core.fromnumeric]):
    ...     warnings.simplefilter('always')
    ...     warnings.filterwarnings('ignore', module='numpy.core.fromnumeric')
    ...     # do something that raises a warning but ignore those in
    ...     # np.core.fromnumeric
    """
    # Modules always handled by this class (subclasses may extend it).
    class_modules = ()

    def __init__(self, record=False, modules=()):
        self.modules = set(modules) | set(self.class_modules)
        # Saved registry contents, keyed by module, restored on exit.
        self._warnreg_copies = {}
        super().__init__(record=record)

    def __enter__(self):
        saved = self._warnreg_copies
        for mod in self.modules:
            if not hasattr(mod, '__warningregistry__'):
                continue
            registry = mod.__warningregistry__
            saved[mod] = registry.copy()
            registry.clear()
        return super().__enter__()

    def __exit__(self, *exc_info):
        super().__exit__(*exc_info)
        for mod in self.modules:
            if not hasattr(mod, '__warningregistry__'):
                continue
            # Drop whatever accumulated inside the context, then put back
            # the entries saved on entry (if the module had a registry).
            registry = mod.__warningregistry__
            registry.clear()
            if mod in self._warnreg_copies:
                registry.update(self._warnreg_copies[mod])
class suppress_warnings:
    """
    Context manager and decorator doing much the same as
    ``warnings.catch_warnings``.

    However, it also provides a filter mechanism to work around
    https://bugs.python.org/issue4180.

    This bug causes Python before 3.4 to not reliably show warnings again
    after they have been ignored once (even within catch_warnings). It
    means that no "ignore" filter can be used easily, since following
    tests might need to see the warning. Additionally it allows easier
    specificity for testing warnings and can be nested.

    Parameters
    ----------
    forwarding_rule : str, optional
        One of "always", "once", "module", or "location". Analogous to
        the usual warnings module filter mode, it is useful to reduce
        noise mostly on the outermost level. Unsuppressed and unrecorded
        warnings will be forwarded based on this rule. Defaults to "always".
        "location" is equivalent to the warnings "default", match by exact
        location the warning originated from.

    Notes
    -----
    Filters added inside the context manager will be discarded again
    when leaving it. Upon entering all filters defined outside a
    context will be applied automatically.

    When a recording filter is added, matching warnings are stored in the
    ``log`` attribute as well as in the list returned by ``record``.

    If filters are added and the ``module`` keyword is given, the
    warning registry of this module will additionally be cleared when
    applying it, entering the context, or exiting it. This could cause
    warnings to appear a second time after leaving the context if they
    were configured to be printed once (default) and were already
    printed before the context was entered.

    Nesting this context manager will work as expected when the
    forwarding rule is "always" (default). Unfiltered and unrecorded
    warnings will be passed out and be matched by the outer level.
    On the outermost level they will be printed (or caught by another
    warnings context). The forwarding rule argument can modify this
    behaviour.

    Like ``catch_warnings`` this context manager is not threadsafe.

    Examples
    --------

    With a context manager::

        with np.testing.suppress_warnings() as sup:
            sup.filter(DeprecationWarning, "Some text")
            sup.filter(module=np.ma.core)
            log = sup.record(FutureWarning, "Does this occur?")
            command_giving_warnings()
            # The FutureWarning was given once, the filtered warnings were
            # ignored. All other warnings abide outside settings (may be
            # printed/error)
            assert_(len(log) == 1)
            assert_(len(sup.log) == 1)  # also stored in log attribute

    Or as a decorator::

        sup = np.testing.suppress_warnings()
        sup.filter(module=np.ma.core)  # module must match exactly
        @sup
        def some_function():
            # do something which causes a warning in np.ma.core
            pass
    """
    def __init__(self, forwarding_rule="always"):
        self._entered = False

        # Suppressions are either instance or defined inside one with block:
        self._suppressions = []

        if forwarding_rule not in {"always", "module", "once", "location"}:
            raise ValueError("unsupported forwarding rule.")
        self._forwarding_rule = forwarding_rule

    def _clear_registries(self):
        if hasattr(warnings, "_filters_mutated"):
            # clearing the registry should not be necessary on new pythons,
            # instead the filters should be mutated.
            warnings._filters_mutated()
            return
        # Simply clear the registry, this should normally be harmless,
        # note that on new pythons it would be invalidated anyway.
        for module in self._tmp_modules:
            if hasattr(module, "__warningregistry__"):
                module.__warningregistry__.clear()

    def _add_warnings_filter(self, category, message, module):
        """Install an "always" filter in ``warnings`` for one suppression.

        Only valid while entered: modules are remembered in
        ``self._tmp_modules`` so their registries can be cleared.
        """
        if module is None:
            warnings.filterwarnings(
                "always", category=category, message=message)
        else:
            module_regex = module.__name__.replace('.', r'\.') + '$'
            warnings.filterwarnings(
                "always", category=category, message=message,
                module=module_regex)
            self._tmp_modules.add(module)

    def _filter(self, category=Warning, message="", module=None, record=False):
        # The log where matching warnings are stored (None: just ignore).
        record = [] if record else None
        if self._entered:
            self._add_warnings_filter(category, message, module)
            self._clear_registries()
            self._tmp_suppressions.append(
                (category, message, re.compile(message, re.I), module, record))
        else:
            self._suppressions.append(
                (category, message, re.compile(message, re.I), module, record))

        return record

    def filter(self, category=Warning, message="", module=None):
        """
        Add a new suppressing filter or apply it if the state is entered.

        Parameters
        ----------
        category : class, optional
            Warning class to filter
        message : string, optional
            Regular expression matching the warning message.
        module : module, optional
            Module to filter for. Note that the module (and its file)
            must match exactly and cannot be a submodule. This may make
            it unreliable for external modules.

        Notes
        -----
        When added within a context, filters are only added inside
        the context and will be forgotten when the context is exited.
        """
        self._filter(category=category, message=message, module=module,
                     record=False)

    def record(self, category=Warning, message="", module=None):
        """
        Append a new recording filter or apply it if the state is entered.

        All warnings matching will be appended to the ``log`` attribute.

        Parameters
        ----------
        category : class, optional
            Warning class to filter
        message : string, optional
            Regular expression matching the warning message.
        module : module, optional
            Module to filter for. Note that the module (and its file)
            must match exactly and cannot be a submodule. This may make
            it unreliable for external modules.

        Returns
        -------
        log : list
            A list which will be filled with all matched warnings.

        Notes
        -----
        When added within a context, filters are only added inside
        the context and will be forgotten when the context is exited.
        """
        return self._filter(category=category, message=message, module=module,
                            record=True)

    def __enter__(self):
        if self._entered:
            raise RuntimeError("cannot enter suppress_warnings twice.")

        self._orig_show = warnings.showwarning
        self._filters = warnings.filters
        warnings.filters = self._filters[:]

        self._entered = True
        self._tmp_suppressions = []
        self._tmp_modules = set()
        self._forwarded = set()

        self.log = []  # reset global log (no need to keep same list)

        for cat, mess, _, mod, log in self._suppressions:
            if log is not None:
                del log[:]  # clear the log
            self._add_warnings_filter(cat, mess, mod)
        warnings.showwarning = self._showwarning
        self._clear_registries()

        return self

    def __exit__(self, *exc_info):
        warnings.showwarning = self._orig_show
        warnings.filters = self._filters
        self._clear_registries()
        self._entered = False
        del self._orig_show
        del self._filters

    @staticmethod
    def _warning_text(message):
        """Return the text that suppression patterns are matched against.

        This is the warning's first argument when it is a string (the usual
        ``warnings.warn("text")`` case). Warnings built without arguments or
        with a non-string first argument fall back to ``str(message)``
        instead of failing with IndexError/TypeError.
        """
        args = getattr(message, "args", ())
        if args and isinstance(args[0], str):
            return args[0]
        return str(message)

    def _forward(self, message, category, filename, lineno, args, kwargs,
                 use_warnmsg):
        """Pass an unhandled warning on to the ``showwarning`` saved on entry."""
        if use_warnmsg is None:
            self._orig_show(message, category, filename, lineno,
                            *args, **kwargs)
        else:
            # Only ``showwarning`` is saved on entry, so unpack the
            # WarningMessage into the same positional form the ``warnings``
            # module uses when calling a replaced ``showwarning``.
            self._orig_show(use_warnmsg.message, use_warnmsg.category,
                            use_warnmsg.filename, use_warnmsg.lineno,
                            use_warnmsg.file, use_warnmsg.line)

    def _showwarning(self, message, category, filename, lineno,
                     *args, use_warnmsg=None, **kwargs):
        # Later filters take precedence, so search the newest first.
        for cat, _, pattern, mod, rec in (
                self._suppressions + self._tmp_suppressions)[::-1]:
            if not issubclass(category, cat):
                continue
            if pattern.match(self._warning_text(message)) is None:
                continue
            if mod is not None:
                # Use startswith, because warnings strips the c or o from
                # .pyc/.pyo files.
                mod_file = getattr(mod, "__file__", None)
                if mod_file is None or not mod_file.startswith(filename):
                    continue
            # Message, category (and module, if given) match: either record
            # the warning or silently ignore it.
            if rec is not None:
                msg = WarningMessage(message, category, filename,
                                     lineno, **kwargs)
                self.log.append(msg)
                rec.append(msg)
            return

        # There is no filter in place, so pass to the outside handler
        # unless we should only pass it once
        if self._forwarding_rule != "always":
            if self._forwarding_rule == "once":
                signature = (message.args, category)
            elif self._forwarding_rule == "module":
                signature = (message.args, category, filename)
            else:  # "location"; __init__ rejects every other rule
                signature = (message.args, category, filename, lineno)
            if signature in self._forwarded:
                return
            self._forwarded.add(signature)

        self._forward(message, category, filename, lineno, args, kwargs,
                      use_warnmsg)

    def __call__(self, func):
        """
        Function decorator to apply certain suppressions to a whole
        function.
        """
        @wraps(func)
        def new_func(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return new_func
@contextlib.contextmanager
def _assert_no_gc_cycles_context(name=None):
    """
    Context manager failing if its body leaves behind reference cycles.

    Automatic garbage collection is disabled while the body runs and
    ``gc.DEBUG_SAVEALL`` is set, so every unreachable object found by the
    final ``gc.collect()`` is kept in ``gc.garbage`` and can be listed in
    the ``AssertionError``. `name` only appears in that error message.
    Does nothing when ``HAS_REFCOUNT`` is false.
    """
    __tracebackhide__ = True  # Hide traceback for py.test

    # not meaningful to test if there is no refcounting
    if not HAS_REFCOUNT:
        yield
        return

    assert_(gc.isenabled())
    gc.disable()
    saved_debug_flags = gc.get_debug()
    try:
        # Start from a clean heap; give up if collection never settles.
        for _ in range(100):
            if gc.collect() == 0:
                break
        else:
            raise RuntimeError(
                "Unable to fully collect garbage - perhaps a __del__ method "
                "is creating more reference cycles?")

        gc.set_debug(gc.DEBUG_SAVEALL)
        yield
        # gc.collect returns the number of unreachable objects in cycles that
        # were found -- we are checking that no cycles were created in the context
        cycle_count = gc.collect()
        leaked = gc.garbage[:]
    finally:
        del gc.garbage[:]
        gc.set_debug(saved_debug_flags)
        gc.enable()

    if cycle_count:
        name_str = f' when calling {name}' if name is not None else ''
        pieces = []
        for obj in leaked:
            shown = pprint.pformat(obj).replace('\n', '\n ')
            pieces.append(
                f"\n {type(obj).__name__} object with id={id(obj)}:\n {shown}")
        details = ''.join(pieces)
        raise AssertionError(
            f"Reference cycles were found{name_str}: {cycle_count} objects "
            f"were collected, of which {len(leaked)} are shown below:{details}"
        )
def assert_no_gc_cycles(*args, **kwargs):
    """
    Fail if the given callable produces any reference cycles.

    If called with all arguments omitted, may be used as a context manager:

        with assert_no_gc_cycles():
            do_something()

    .. versionadded:: 1.15.0

    Parameters
    ----------
    func : callable
        The callable to test.
    \\*args : Arguments
        Arguments passed to `func`.
    \\*\\*kwargs : Kwargs
        Keyword arguments passed to `func`.

    Returns
    -------
    Nothing. The result is deliberately discarded to ensure that all cycles
    are found. When called without positional arguments, a context manager
    is returned instead.

    """
    if not args:
        return _assert_no_gc_cycles_context()

    func = args[0]
    args = args[1:]
    # Not every callable has ``__name__`` (e.g. functools.partial objects or
    # instances with ``__call__``); fall back to its repr for the message.
    name = getattr(func, '__name__', None)
    if name is None:
        name = repr(func)
    with _assert_no_gc_cycles_context(name=name):
        func(*args, **kwargs)
def break_cycles():
    """
    Break reference cycles by running the garbage collector.

    A finalizer may itself call methods of other objects (for instance
    another object's ``__del__``). On PyPy, finalizers only run between
    ``gc.collect`` calls, so several passes are made there to release all
    cycles; elsewhere a single pass is made.
    """
    # a few more on PyPy, just to make sure all the finalizers are called
    passes = 5 if IS_PYPY else 1
    for _ in range(passes):
        gc.collect()
def requires_memory(free_bytes):
    """
    Build a test decorator that skips the test when fewer than `free_bytes`
    bytes of memory are reported free (see ``check_free_memory``).

    A ``MemoryError`` raised by the test itself is reported through
    ``pytest.xfail`` instead of as a failure.
    """
    import pytest

    def decorator(func):
        @wraps(func)
        def wrapper(*a, **kw):
            skip_reason = check_free_memory(free_bytes)
            if skip_reason is not None:
                pytest.skip(skip_reason)

            try:
                result = func(*a, **kw)
            except MemoryError:
                # Probably ran out of memory regardless: don't regard as failure
                pytest.xfail("MemoryError raised")
            else:
                return result

        return wrapper

    return decorator
def check_free_memory(free_bytes):
    """
    Check whether `free_bytes` amount of memory is currently free.

    The free amount comes from the ``NPY_AVAILABLE_MEM`` environment
    variable when it is set (parsed by ``_parse_size``, e.g. ``"16GB"``),
    otherwise from ``_get_mem_available``. If neither yields a value, the
    check fails with a message explaining how to set the variable.

    Returns
    -------
    msg : str or None
        None if enough memory is available, otherwise an error message.

    Raises
    ------
    ValueError
        If ``NPY_AVAILABLE_MEM`` is set but is not a valid size.
    """
    env_var = 'NPY_AVAILABLE_MEM'
    env_value = os.environ.get(env_var)
    if env_value is not None:
        try:
            mem_free = _parse_size(env_value)
        except ValueError as exc:
            # Chain explicitly so the parse error is reported as the cause.
            raise ValueError(
                f'Invalid environment variable {env_var}: {exc}') from exc

        msg = (f'{free_bytes/1e9} GB memory required, but environment variable '
               f'NPY_AVAILABLE_MEM={env_value} set')
    else:
        mem_free = _get_mem_available()

        if mem_free is None:
            msg = ("Could not determine available memory; set NPY_AVAILABLE_MEM "
                   "environment variable (e.g. NPY_AVAILABLE_MEM=16GB) to run "
                   "the test.")
            # Unknown memory: force the "not enough memory" outcome.
            mem_free = -1
        else:
            msg = f'{free_bytes/1e9} GB memory required, but {mem_free/1e9} GB available'

    return msg if mem_free < free_bytes else None
2438def _parse_size(size_str):
2439 """Convert memory size strings ('12 GB' etc.) to float"""
2440 suffixes = {'': 1, 'b': 1,
2441 'k': 1000, 'm': 1000**2, 'g': 1000**3, 't': 1000**4,
2442 'kb': 1000, 'mb': 1000**2, 'gb': 1000**3, 'tb': 1000**4,
2443 'kib': 1024, 'mib': 1024**2, 'gib': 1024**3, 'tib': 1024**4}
2445 size_re = re.compile(r'^\s*(\d+|\d+\.\d+)\s*({0})\s*$'.format(
2446 '|'.join(suffixes.keys())), re.I)
2448 m = size_re.match(size_str.lower())
2449 if not m or m.group(2) not in suffixes:
2450 raise ValueError(f'value {size_str!r} not a valid size')
2451 return int(float(m.group(1)) * suffixes[m.group(2)])
2454def _get_mem_available():
2455 """Return available memory in bytes, or None if unknown."""
2456 try:
2457 import psutil
2458 return psutil.virtual_memory().available
2459 except (ImportError, AttributeError):
2460 pass
2462 if sys.platform.startswith('linux'):
2463 info = {}
2464 with open('/proc/meminfo') as f:
2465 for line in f:
2466 p = line.split()
2467 info[p[0].strip(':').lower()] = int(p[1]) * 1024
2469 if 'memavailable' in info:
2470 # Linux >= 3.14
2471 return info['memavailable']
2472 else:
2473 return info['memfree'] + info['cached']
2475 return None
2478def _no_tracing(func):
2479 """
2480 Decorator to temporarily turn off tracing for the duration of a test.
2481 Needed in tests that check refcounting, otherwise the tracing itself
2482 influences the refcounts
2483 """
2484 if not hasattr(sys, 'gettrace'):
2485 return func
2486 else:
2487 @wraps(func)
2488 def wrapper(*args, **kwargs):
2489 original_trace = sys.gettrace()
2490 try:
2491 sys.settrace(None)
2492 return func(*args, **kwargs)
2493 finally:
2494 sys.settrace(original_trace)
2495 return wrapper
2498def _get_glibc_version():
2499 try:
2500 ver = os.confstr('CS_GNU_LIBC_VERSION').rsplit(' ')[1]
2501 except Exception:
2502 ver = '0.0'
2504 return ver
_glibcver = _get_glibc_version()


def _glibc_version_tuple(version):
    """Return the numeric components of a dotted version string as ints."""
    return tuple(int(part) for part in re.findall(r'\d+', version))


def _glibc_older_than(x):
    """
    Return True if the running glibc is known to be older than version `x`.

    Versions are compared numerically component by component, so ``'2.9'``
    is older than ``'2.17'``; a plain string comparison gets this wrong.
    Returns False when the glibc version is unknown (``'0.0'``).
    """
    if _glibcver == '0.0':
        return False
    return _glibc_version_tuple(_glibcver) < _glibc_version_tuple(x)