Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/scipy/_lib/_util.py: 3%
357 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-03 06:39 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-03 06:39 +0000
1import re
2from contextlib import contextmanager
3import functools
4import operator
5import warnings
6import numbers
7from collections import namedtuple
8import inspect
9import math
10from typing import (
11 Optional,
12 Union,
13 TYPE_CHECKING,
14 TypeVar,
15)
17import numpy as np
18from scipy._lib._array_api import array_namespace
# NumPy exception/warning classes used throughout SciPy. On NumPy >= 1.25
# they are imported from ``numpy.exceptions``; older versions expose them in
# the top-level ``numpy`` namespace, so both spellings are handled here and
# the rest of SciPy can import them from this module.
AxisError: type[Exception]
ComplexWarning: type[Warning]
VisibleDeprecationWarning: type[Warning]

if np.lib.NumpyVersion(np.__version__) >= '1.25.0':
    from numpy.exceptions import (
        AxisError, ComplexWarning, VisibleDeprecationWarning,
        DTypePromotionError
    )
else:
    from numpy import (
        AxisError, ComplexWarning, VisibleDeprecationWarning  # noqa: F401
    )
    # No ``DTypePromotionError`` is imported on this path; fall back to
    # TypeError so ``except DTypePromotionError`` clauses still name a valid
    # exception class.
    DTypePromotionError = TypeError  # type: ignore
# Portable aliases for the ``np.long`` / ``np.ulong`` scalar types. On
# NumPy >= 2.0 (dev builds included) the real attributes are used when they
# exist; otherwise ``np.int_`` / ``np.uint`` are used instead.
np_long: type
np_ulong: type

if np.lib.NumpyVersion(np.__version__) >= "2.0.0.dev0":
    try:
        with warnings.catch_warnings():
            # Accessing these attributes may emit a FutureWarning about the
            # future meaning of ``np.long``; silence only that message.
            warnings.filterwarnings(
                "ignore",
                r".*In the future `np\.long` will be defined as.*",
                FutureWarning,
            )
            np_long = np.long  # type: ignore[attr-defined]
            np_ulong = np.ulong  # type: ignore[attr-defined]
    except AttributeError:
        # NOTE(review): presumably a 2.0 dev build that predates
        # ``np.long``/``np.ulong``; fall back to the older spellings.
        np_long = np.int_
        np_ulong = np.uint
else:
    np_long = np.int_
    np_ulong = np.uint
# Annotation aliases accepting both Python builtins and NumPy scalar types.
IntNumber = Union[int, np.integer]
DecimalNumber = Union[float, np.floating, np.integer]

# Value to pass as ``copy=`` when a copy should happen only if needed (as the
# name suggests): ``None`` on NumPy 2.x, ``False`` on NumPy 1.x.
copy_if_needed: Optional[bool]

if np.lib.NumpyVersion(np.__version__) >= "2.0.0":
    copy_if_needed = None
elif np.lib.NumpyVersion(np.__version__) < "1.28.0":
    copy_if_needed = False
else:
    # 2.0.0 dev versions, handle cases where copy may or may not exist:
    # probe whether ``__array__`` accepts ``copy=None`` and pick accordingly.
    try:
        np.array([1]).__array__(copy=None)  # type: ignore[call-overload]
        copy_if_needed = None
    except TypeError:
        copy_if_needed = False
# Since Generator was introduced in numpy 1.17, the following condition is needed for
# backward compatibility
if TYPE_CHECKING:
    # Aliases used only by static type checkers; TYPE_CHECKING is False at
    # runtime, so these names do not exist when the module is executed.
    SeedType = Optional[Union[IntNumber, np.random.Generator,
                              np.random.RandomState]]
    GeneratorType = TypeVar("GeneratorType", bound=Union[np.random.Generator,
                                                         np.random.RandomState])

try:
    from numpy.random import Generator as Generator
except ImportError:
    # Placeholder so ``isinstance(x, Generator)`` checks elsewhere remain
    # valid; no real NumPy random-number generator is an instance of it.
    class Generator:  # type: ignore[no-redef]
        pass
88def _lazywhere(cond, arrays, f, fillvalue=None, f2=None):
89 """Return elements chosen from two possibilities depending on a condition
91 Equivalent to ``f(*arrays) if cond else fillvalue`` performed elementwise.
93 Parameters
94 ----------
95 cond : array
96 The condition (expressed as a boolean array).
97 arrays : tuple of array
98 Arguments to `f` (and `f2`). Must be broadcastable with `cond`.
99 f : callable
100 Where `cond` is True, output will be ``f(arr1[cond], arr2[cond], ...)``
101 fillvalue : object
102 If provided, value with which to fill output array where `cond` is
103 not True.
104 f2 : callable
105 If provided, output will be ``f2(arr1[cond], arr2[cond], ...)`` where
106 `cond` is not True.
108 Returns
109 -------
110 out : array
111 An array with elements from the output of `f` where `cond` is True
112 and `fillvalue` (or elements from the output of `f2`) elsewhere. The
113 returned array has data type determined by Type Promotion Rules
114 with the output of `f` and `fillvalue` (or the output of `f2`).
116 Notes
117 -----
118 ``xp.where(cond, x, fillvalue)`` requires explicitly forming `x` even where
119 `cond` is False. This function evaluates ``f(arr1[cond], arr2[cond], ...)``
120 onle where `cond` ``is True.
122 Examples
123 --------
124 >>> import numpy as np
125 >>> a, b = np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8])
126 >>> def f(a, b):
127 ... return a*b
128 >>> _lazywhere(a > 2, (a, b), f, np.nan)
129 array([ nan, nan, 21., 32.])
131 """
132 xp = array_namespace(cond, *arrays)
134 if (f2 is fillvalue is None) or (f2 is not None and fillvalue is not None):
135 raise ValueError("Exactly one of `fillvalue` or `f2` must be given.")
137 args = xp.broadcast_arrays(cond, *arrays)
138 bool_dtype = xp.asarray([True]).dtype # numpy 1.xx doesn't have `bool`
139 cond, arrays = xp.astype(args[0], bool_dtype, copy=False), args[1:]
141 temp1 = xp.asarray(f(*(arr[cond] for arr in arrays)))
143 if f2 is None:
144 fillvalue = xp.asarray(fillvalue)
145 dtype = xp.result_type(temp1.dtype, fillvalue.dtype)
146 out = xp.full(cond.shape, fill_value=fillvalue, dtype=dtype)
147 else:
148 ncond = ~cond
149 temp2 = xp.asarray(f2(*(arr[ncond] for arr in arrays)))
150 dtype = xp.result_type(temp1, temp2)
151 out = xp.empty(cond.shape, dtype=dtype)
152 out[ncond] = temp2
154 out[cond] = temp1
156 return out
159def _lazyselect(condlist, choicelist, arrays, default=0):
160 """
161 Mimic `np.select(condlist, choicelist)`.
163 Notice, it assumes that all `arrays` are of the same shape or can be
164 broadcasted together.
166 All functions in `choicelist` must accept array arguments in the order
167 given in `arrays` and must return an array of the same shape as broadcasted
168 `arrays`.
170 Examples
171 --------
172 >>> import numpy as np
173 >>> x = np.arange(6)
174 >>> np.select([x <3, x > 3], [x**2, x**3], default=0)
175 array([ 0, 1, 4, 0, 64, 125])
177 >>> _lazyselect([x < 3, x > 3], [lambda x: x**2, lambda x: x**3], (x,))
178 array([ 0., 1., 4., 0., 64., 125.])
180 >>> a = -np.ones_like(x)
181 >>> _lazyselect([x < 3, x > 3],
182 ... [lambda x, a: x**2, lambda x, a: a * x**3],
183 ... (x, a), default=np.nan)
184 array([ 0., 1., 4., nan, -64., -125.])
186 """
187 arrays = np.broadcast_arrays(*arrays)
188 tcode = np.mintypecode([a.dtype.char for a in arrays])
189 out = np.full(np.shape(arrays[0]), fill_value=default, dtype=tcode)
190 for func, cond in zip(choicelist, condlist):
191 if np.all(cond is False):
192 continue
193 cond, _ = np.broadcast_arrays(cond, arrays[0])
194 temp = tuple(np.extract(cond, arr) for arr in arrays)
195 np.place(out, cond, func(*temp))
196 return out
199def _aligned_zeros(shape, dtype=float, order="C", align=None):
200 """Allocate a new ndarray with aligned memory.
202 Primary use case for this currently is working around a f2py issue
203 in NumPy 1.9.1, where dtype.alignment is such that np.zeros() does
204 not necessarily create arrays aligned up to it.
206 """
207 dtype = np.dtype(dtype)
208 if align is None:
209 align = dtype.alignment
210 if not hasattr(shape, '__len__'):
211 shape = (shape,)
212 size = functools.reduce(operator.mul, shape) * dtype.itemsize
213 buf = np.empty(size + align + 1, np.uint8)
214 offset = buf.__array_interface__['data'][0] % align
215 if offset != 0:
216 offset = align - offset
217 # Note: slices producing 0-size arrays do not necessarily change
218 # data pointer --- so we use and allocate size+1
219 buf = buf[offset:offset+size+1][:-1]
220 data = np.ndarray(shape, dtype, buf, order=order)
221 data.fill(0)
222 return data
225def _prune_array(array):
226 """Return an array equivalent to the input array. If the input
227 array is a view of a much larger array, copy its contents to a
228 newly allocated array. Otherwise, return the input unchanged.
229 """
230 if array.base is not None and array.size < array.base.size // 2:
231 return array.copy()
232 return array
def float_factorial(n: int) -> float:
    """Return ``n!`` as a float.

    For ``n >= 171`` the exact factorial no longer fits in a double, so
    ``np.inf`` is returned instead of computing it.
    """
    if n < 171:
        return float(math.factorial(n))
    return np.inf
# copy-pasted from scikit-learn utils/validation.py
# change this to scipy.stats._qmc.check_random_state once numpy 1.16 is dropped
def check_random_state(seed):
    """Turn `seed` into a `np.random.RandomState` instance.

    Parameters
    ----------
    seed : {None, int, `numpy.random.Generator`, `numpy.random.RandomState`}, optional
        If `seed` is None (or `np.random`), the `numpy.random.RandomState`
        singleton is used.
        If `seed` is an int, a new ``RandomState`` instance is used,
        seeded with `seed`.
        If `seed` is already a ``Generator`` or ``RandomState`` instance then
        that instance is used.

    Returns
    -------
    seed : {`numpy.random.Generator`, `numpy.random.RandomState`}
        Random number generator.

    Raises
    ------
    ValueError
        If `seed` is none of the accepted types.

    """
    if seed is None or seed is np.random:
        return np.random.mtrand._rand
    if isinstance(seed, (numbers.Integral, np.integer)):
        return np.random.RandomState(seed)
    if isinstance(seed, (np.random.RandomState, np.random.Generator)):
        return seed

    # Use an f-string rather than ``%``: with ``%`` a tuple `seed` would be
    # unpacked as several format arguments and raise TypeError instead of
    # this ValueError.
    raise ValueError(f'{seed!r} cannot be used to seed a numpy.random.RandomState'
                     ' instance')
275def _asarray_validated(a, check_finite=True,
276 sparse_ok=False, objects_ok=False, mask_ok=False,
277 as_inexact=False):
278 """
279 Helper function for SciPy argument validation.
281 Many SciPy linear algebra functions do support arbitrary array-like
282 input arguments. Examples of commonly unsupported inputs include
283 matrices containing inf/nan, sparse matrix representations, and
284 matrices with complicated elements.
286 Parameters
287 ----------
288 a : array_like
289 The array-like input.
290 check_finite : bool, optional
291 Whether to check that the input matrices contain only finite numbers.
292 Disabling may give a performance gain, but may result in problems
293 (crashes, non-termination) if the inputs do contain infinities or NaNs.
294 Default: True
295 sparse_ok : bool, optional
296 True if scipy sparse matrices are allowed.
297 objects_ok : bool, optional
298 True if arrays with dype('O') are allowed.
299 mask_ok : bool, optional
300 True if masked arrays are allowed.
301 as_inexact : bool, optional
302 True to convert the input array to a np.inexact dtype.
304 Returns
305 -------
306 ret : ndarray
307 The converted validated array.
309 """
310 if not sparse_ok:
311 import scipy.sparse
312 if scipy.sparse.issparse(a):
313 msg = ('Sparse matrices are not supported by this function. '
314 'Perhaps one of the scipy.sparse.linalg functions '
315 'would work instead.')
316 raise ValueError(msg)
317 if not mask_ok:
318 if np.ma.isMaskedArray(a):
319 raise ValueError('masked arrays are not supported')
320 toarray = np.asarray_chkfinite if check_finite else np.asarray
321 a = toarray(a)
322 if not objects_ok:
323 if a.dtype is np.dtype('O'):
324 raise ValueError('object arrays are not supported')
325 if as_inexact:
326 if not np.issubdtype(a.dtype, np.inexact):
327 a = toarray(a, dtype=np.float64)
328 return a
331def _validate_int(k, name, minimum=None):
332 """
333 Validate a scalar integer.
335 This function can be used to validate an argument to a function
336 that expects the value to be an integer. It uses `operator.index`
337 to validate the value (so, for example, k=2.0 results in a
338 TypeError).
340 Parameters
341 ----------
342 k : int
343 The value to be validated.
344 name : str
345 The name of the parameter.
346 minimum : int, optional
347 An optional lower bound.
348 """
349 try:
350 k = operator.index(k)
351 except TypeError:
352 raise TypeError(f'{name} must be an integer.') from None
353 if minimum is not None and k < minimum:
354 raise ValueError(f'{name} must be an integer not less '
355 f'than {minimum}') from None
356 return k
# Add a replacement for inspect.getfullargspec().
# The version below is borrowed from Django,
# https://github.com/django/django/pull/4846.

# Note an inconsistency between inspect.getfullargspec(func) and
# inspect.signature(func). If `func` is a bound method, the latter does *not*
# list `self` as a first argument, while the former *does*.
# Hence, cook up a common ground replacement: `getfullargspec_no_self` which
# mimics `inspect.getfullargspec` but does not list `self`.
#
# This way, the caller code does not need to know whether it uses a legacy
# .getfullargspec or a bright and shiny .signature.

FullArgSpec = namedtuple('FullArgSpec',
                         ['args', 'varargs', 'varkw', 'defaults',
                          'kwonlyargs', 'kwonlydefaults', 'annotations'])


def getfullargspec_no_self(func):
    """inspect.getfullargspec replacement using inspect.signature.

    If func is a bound method, do not list the 'self' parameter.

    Parameters
    ----------
    func : callable
        A callable to inspect

    Returns
    -------
    fullargspec : FullArgSpec(args, varargs, varkw, defaults, kwonlyargs,
                              kwonlydefaults, annotations)

        NOTE: if the first argument of `func` is self, it is *not*, I repeat
        *not*, included in fullargspec.args.
        This is done for consistency between inspect.getargspec() under
        Python 2.x, and inspect.signature() under Python 3.x.

        `args` lists positional-only and positional-or-keyword parameters.
        `defaults` holds the default values of the trailing parameters of
        `args` (positional-only ones included), as in
        `inspect.getfullargspec`, or None if there are none.
        `annotations` covers parameters only (no ``'return'`` entry).

    """
    sig = inspect.signature(func)
    params = list(sig.parameters.values())
    positional_kinds = (inspect.Parameter.POSITIONAL_ONLY,
                        inspect.Parameter.POSITIONAL_OR_KEYWORD)
    args = [p.name for p in params if p.kind in positional_kinds]
    varargs = [
        p.name for p in params
        if p.kind == inspect.Parameter.VAR_POSITIONAL
    ]
    varargs = varargs[0] if varargs else None
    varkw = [
        p.name for p in params
        if p.kind == inspect.Parameter.VAR_KEYWORD
    ]
    varkw = varkw[0] if varkw else None
    # Include positional-only defaults too, so `defaults` lines up with the
    # tail of `args` exactly as in inspect.getfullargspec.
    defaults = tuple(
        p.default for p in params
        if p.kind in positional_kinds and p.default is not p.empty
    ) or None
    kwonlyargs = [
        p.name for p in params
        if p.kind == inspect.Parameter.KEYWORD_ONLY
    ]
    kwdefaults = {p.name: p.default for p in params
                  if p.kind == inspect.Parameter.KEYWORD_ONLY and
                  p.default is not p.empty}
    annotations = {p.name: p.annotation for p in params
                   if p.annotation is not p.empty}
    return FullArgSpec(args, varargs, varkw, defaults, kwonlyargs,
                       kwdefaults or None, annotations)
432class _FunctionWrapper:
433 """
434 Object to wrap user's function, allowing picklability
435 """
436 def __init__(self, f, args):
437 self.f = f
438 self.args = [] if args is None else args
440 def __call__(self, x):
441 return self.f(x, *self.args)
class MapWrapper:
    """
    Parallelisation wrapper for working with map-like callables, such as
    `multiprocessing.Pool.map`.

    Parameters
    ----------
    pool : int or map-like callable
        If `pool` is an integer, it selects how work is distributed:
        ``int(pool) == 1`` uses the builtin ``map`` (no parallelism),
        ``-1`` creates a `multiprocessing.Pool` with its default number of
        processes, and any value above 1 creates a pool with that many
        processes. Other integers raise RuntimeError.
        If `pool` is a callable following the calling sequence of the
        builtin ``map``, it is used for the mapping as-is.

    Notes
    -----
    Only pools created by this wrapper are closed, joined or terminated by
    it; a user-supplied callable is never shut down here.
    """
    def __init__(self, pool=1):
        # Serial defaults: builtin ``map`` and no pool of our own.
        self.pool = None
        self._mapfunc = map
        self._own_pool = False

        if callable(pool):
            # Caller-supplied map-like callable; we do not own its lifetime.
            self.pool = pool
            self._mapfunc = self.pool
            return

        from multiprocessing import Pool
        # user supplies a number
        n_workers = int(pool)
        if n_workers == 1:
            return
        if n_workers == -1:
            # use as many processors as possible
            self.pool = Pool()
        elif n_workers > 1:
            # use the number of processors requested
            self.pool = Pool(processes=n_workers)
        else:
            raise RuntimeError("Number of workers specified must be -1,"
                               " an int >= 1, or an object with a 'map' "
                               "method")
        self._mapfunc = self.pool.map
        self._own_pool = True

    def __enter__(self):
        return self

    def terminate(self):
        if not self._own_pool:
            return
        self.pool.terminate()

    def join(self):
        if not self._own_pool:
            return
        self.pool.join()

    def close(self):
        if not self._own_pool:
            return
        self.pool.close()

    def __exit__(self, exc_type, exc_value, traceback):
        # Shut down only a pool that this wrapper created itself.
        if not self._own_pool:
            return
        self.pool.close()
        self.pool.terminate()

    def __call__(self, func, iterable):
        # only accept one iterable because that's all Pool.map accepts
        try:
            return self._mapfunc(func, iterable)
        except TypeError as err:
            # wrong number of arguments
            raise TypeError("The map-like callable must be of the"
                            " form f(func, iterable)") from err
def rng_integers(gen, low, high=None, size=None, dtype='int64',
                 endpoint=False):
    """
    Draw random integers with a uniform API for Generator and RandomState.

    Samples come from the "discrete uniform" distribution of the requested
    dtype over ``[low, high)``, or ``[low, high]`` when `endpoint` is True.
    If `high` is None, the range is taken from 0 up to `low` instead. This
    stands in for `RandomState.randint` (``endpoint=False``) and
    `RandomState.random_integers` (``endpoint=True``).

    Parameters
    ----------
    gen : {None, np.random.RandomState, np.random.Generator}
        Random number generator. None selects the np.random.RandomState
        singleton.
    low : int or array-like of ints
        Lowest (signed) integer to be drawn, or, when `high` is None, the
        upper bound of a range starting at 0.
    high : int or array-like of ints, optional
        One above the largest (signed) integer to be drawn (or the largest
        itself when `endpoint` is True). Array-likes must hold integers.
    size : array-like of ints, optional
        Output shape; ``(m, n, k)`` draws ``m * n * k`` samples. None
        returns a single value.
    dtype : {str, dtype}, optional
        Desired result dtype, identified by name (e.g. 'int64', 'int'), so
        byteorder cannot be chosen and a given precision may correspond to
        different C types on different platforms. Defaults to 'int64'.
    endpoint : bool, optional
        If True, the upper bound is inclusive. Defaults to False.

    Returns
    -------
    out: int or ndarray of ints
        A `size`-shaped array of random integers, or a single integer if
        `size` is not provided.
    """
    if isinstance(gen, Generator):
        # Generator.integers understands `endpoint` natively.
        return gen.integers(low, high=high, size=size, dtype=dtype,
                            endpoint=endpoint)

    # RandomState path; None selects the singleton used by np.random.
    if gen is None:
        gen = np.random.mtrand._rand

    if not endpoint:
        # exclusive upper bound, exactly what randint provides
        return gen.randint(low, high=high, size=size, dtype=dtype)

    # randint excludes its upper bound, so shift that bound up by one.
    # `low`/`high` may be arrays, hence new objects rather than in-place ops.
    if high is None:
        return gen.randint(low + 1, size=size, dtype=dtype)
    return gen.randint(low, high=high + 1, size=size, dtype=dtype)
582@contextmanager
583def _fixed_default_rng(seed=1638083107694713882823079058616272161):
584 """Context with a fixed np.random.default_rng seed."""
585 orig_fun = np.random.default_rng
586 np.random.default_rng = lambda seed=seed: orig_fun(seed)
587 try:
588 yield
589 finally:
590 np.random.default_rng = orig_fun
593def _rng_html_rewrite(func):
594 """Rewrite the HTML rendering of ``np.random.default_rng``.
596 This is intended to decorate
597 ``numpydoc.docscrape_sphinx.SphinxDocString._str_examples``.
599 Examples are only run by Sphinx when there are plot involved. Even so,
600 it does not change the result values getting printed.
601 """
602 # hexadecimal or number seed, case-insensitive
603 pattern = re.compile(r'np.random.default_rng\((0x[0-9A-F]+|\d+)\)', re.I)
605 def _wrapped(*args, **kwargs):
606 res = func(*args, **kwargs)
607 lines = [
608 re.sub(pattern, 'np.random.default_rng()', line)
609 for line in res
610 ]
611 return lines
613 return _wrapped
616def _argmin(a, keepdims=False, axis=None):
617 """
618 argmin with a `keepdims` parameter.
620 See https://github.com/numpy/numpy/issues/8710
622 If axis is not None, a.shape[axis] must be greater than 0.
623 """
624 res = np.argmin(a, axis=axis)
625 if keepdims and axis is not None:
626 res = np.expand_dims(res, axis=axis)
627 return res
def _first_nonnan(a, axis):
    """
    Return the first non-nan value along the given axis.

    If a slice is all nan, nan is returned for that slice.

    The shape of the return value corresponds to ``keepdims=True``.

    Examples
    --------
    >>> import numpy as np
    >>> nan = np.nan
    >>> a = np.array([[ 3.,  3., nan,  3.],
    ...               [ 1., nan,  2.,  4.],
    ...               [nan, nan,  9., -1.],
    ...               [nan,  5.,  4.,  3.],
    ...               [ 2.,  2.,  2.,  2.],
    ...               [nan, nan, nan, nan]])
    >>> _first_nonnan(a, axis=0)
    array([[3., 3., 2., 3.]])
    >>> _first_nonnan(a, axis=1)
    array([[ 3.],
           [ 1.],
           [ 9.],
           [ 5.],
           [ 2.],
           [nan]])
    """
    # argmin of the boolean "is nan" mask locates the first False, i.e. the
    # first non-nan entry; for an all-nan slice it is index 0, which is nan.
    k = _argmin(np.isnan(a), axis=axis, keepdims=True)
    return np.take_along_axis(a, k, axis=axis)
def _nan_allsame(a, axis, keepdims=False):
    """
    Check whether all non-nan values along `axis` are equal.

    `a` must be a numpy array, and `axis` is assumed to be normalized
    (``0 <= axis < a.ndim``). ``axis=None`` checks the flattened array; if
    that array is empty, the plain Python value ``True`` is returned.

    nan values are ignored. By convention ``allsame([])`` is True (no two
    values differ), so a zero-length axis yields True, and so does a slice
    consisting only of nans, which is equivalent to an empty slice.

    Examples
    --------
    >>> from numpy import nan, array
    >>> a = array([[ 3.,  3., nan,  3.],
    ...            [ 1., nan,  2.,  4.],
    ...            [nan, nan,  9., -1.],
    ...            [nan,  5.,  4.,  3.],
    ...            [ 2.,  2.,  2.,  2.],
    ...            [nan, nan, nan, nan]])
    >>> _nan_allsame(a, axis=1, keepdims=True)
    array([[ True],
           [False],
           [False],
           [False],
           [ True],
           [ True]])
    """
    if axis is not None:
        shape = a.shape
        if shape[axis] == 0:
            # Zero-length axis: every slice is trivially "all the same".
            out_shape = shape[:axis] + (1,) * keepdims + shape[axis + 1:]
            return np.full(out_shape, fill_value=True, dtype=bool)
    elif a.size == 0:
        return True
    else:
        a = a.ravel()
        axis = 0

    first = _first_nonnan(a, axis=axis)
    # A value "matches" if it equals the slice's first non-nan or is nan.
    matches = (first == a) | np.isnan(a)
    return matches.all(axis=axis, keepdims=keepdims)
710def _contains_nan(a, nan_policy='propagate', use_summation=True,
711 policies=None):
712 if not isinstance(a, np.ndarray):
713 use_summation = False # some array_likes ignore nans (e.g. pandas)
714 if policies is None:
715 policies = ['propagate', 'raise', 'omit']
716 if nan_policy not in policies:
717 raise ValueError("nan_policy must be one of {%s}" %
718 ', '.join("'%s'" % s for s in policies))
720 if np.issubdtype(a.dtype, np.inexact):
721 # The summation method avoids creating a (potentially huge) array.
722 if use_summation:
723 with np.errstate(invalid='ignore', over='ignore'):
724 contains_nan = np.isnan(np.sum(a))
725 else:
726 contains_nan = np.isnan(a).any()
727 elif np.issubdtype(a.dtype, object):
728 contains_nan = False
729 for el in a.ravel():
730 # isnan doesn't work on non-numeric elements
731 if np.issubdtype(type(el), np.number) and np.isnan(el):
732 contains_nan = True
733 break
734 else:
735 # Only `object` and `inexact` arrays can have NaNs
736 contains_nan = False
738 if contains_nan and nan_policy == 'raise':
739 raise ValueError("The input contains nan values")
741 return contains_nan, nan_policy
744def _rename_parameter(old_name, new_name, dep_version=None):
745 """
746 Generate decorator for backward-compatible keyword renaming.
748 Apply the decorator generated by `_rename_parameter` to functions with a
749 recently renamed parameter to maintain backward-compatibility.
751 After decoration, the function behaves as follows:
752 If only the new parameter is passed into the function, behave as usual.
753 If only the old parameter is passed into the function (as a keyword), raise
754 a DeprecationWarning if `dep_version` is provided, and behave as usual
755 otherwise.
756 If both old and new parameters are passed into the function, raise a
757 DeprecationWarning if `dep_version` is provided, and raise the appropriate
758 TypeError (function got multiple values for argument).
760 Parameters
761 ----------
762 old_name : str
763 Old name of parameter
764 new_name : str
765 New name of parameter
766 dep_version : str, optional
767 Version of SciPy in which old parameter was deprecated in the format
768 'X.Y.Z'. If supplied, the deprecation message will indicate that
769 support for the old parameter will be removed in version 'X.Y+2.Z'
771 Notes
772 -----
773 Untested with functions that accept *args. Probably won't work as written.
775 """
776 def decorator(fun):
777 @functools.wraps(fun)
778 def wrapper(*args, **kwargs):
779 if old_name in kwargs:
780 if dep_version:
781 end_version = dep_version.split('.')
782 end_version[1] = str(int(end_version[1]) + 2)
783 end_version = '.'.join(end_version)
784 message = (f"Use of keyword argument `{old_name}` is "
785 f"deprecated and replaced by `{new_name}`. "
786 f"Support for `{old_name}` will be removed "
787 f"in SciPy {end_version}.")
788 warnings.warn(message, DeprecationWarning, stacklevel=2)
789 if new_name in kwargs:
790 message = (f"{fun.__name__}() got multiple values for "
791 f"argument now known as `{new_name}`")
792 raise TypeError(message)
793 kwargs[new_name] = kwargs.pop(old_name)
794 return fun(*args, **kwargs)
795 return wrapper
796 return decorator
799def _rng_spawn(rng, n_children):
800 # spawns independent RNGs from a parent RNG
801 bg = rng._bit_generator
802 ss = bg._seed_seq
803 child_rngs = [np.random.Generator(type(bg)(child_ss))
804 for child_ss in ss.spawn(n_children)]
805 return child_rngs
808def _get_nan(*data):
809 # Get NaN of appropriate dtype for data
810 data = [np.asarray(item) for item in data]
811 try:
812 dtype = np.result_type(*data, np.half) # must be a float16 at least
813 except DTypePromotionError:
814 # fallback to float64
815 return np.array(np.nan, dtype=np.float64)[()]
816 return np.array(np.nan, dtype=dtype)[()]
819def normalize_axis_index(axis, ndim):
820 # Check if `axis` is in the correct range and normalize it
821 if axis < -ndim or axis >= ndim:
822 msg = f"axis {axis} is out of bounds for array of dimension {ndim}"
823 raise AxisError(msg)
825 if axis < 0:
826 axis = axis + ndim
827 return axis
830def _call_callback_maybe_halt(callback, res):
831 """Call wrapped callback; return True if algorithm should stop.
833 Parameters
834 ----------
835 callback : callable or None
836 A user-provided callback wrapped with `_wrap_callback`
837 res : OptimizeResult
838 Information about the current iterate
840 Returns
841 -------
842 halt : bool
843 True if minimization should stop
845 """
846 if callback is None:
847 return False
848 try:
849 callback(res)
850 return False
851 except StopIteration:
852 callback.stop_iteration = True
853 return True
856class _RichResult(dict):
857 """ Container for multiple outputs with pretty-printing """
858 def __getattr__(self, name):
859 try:
860 return self[name]
861 except KeyError as e:
862 raise AttributeError(name) from e
864 __setattr__ = dict.__setitem__
865 __delattr__ = dict.__delitem__
867 def __repr__(self):
868 order_keys = ['message', 'success', 'status', 'fun', 'funl', 'x', 'xl',
869 'col_ind', 'nit', 'lower', 'upper', 'eqlin', 'ineqlin',
870 'converged', 'flag', 'function_calls', 'iterations',
871 'root']
872 order_keys = getattr(self, '_order_keys', order_keys)
873 # 'slack', 'con' are redundant with residuals
874 # 'crossover_nit' is probably not interesting to most users
875 omit_keys = {'slack', 'con', 'crossover_nit', '_order_keys'}
877 def key(item):
878 try:
879 return order_keys.index(item[0].lower())
880 except ValueError: # item not in list
881 return np.inf
883 def omit_redundant(items):
884 for item in items:
885 if item[0] in omit_keys:
886 continue
887 yield item
889 def item_sorter(d):
890 return sorted(omit_redundant(d.items()), key=key)
892 if self.keys():
893 return _dict_formatter(self, sorter=item_sorter)
894 else:
895 return self.__class__.__name__ + "()"
897 def __dir__(self):
898 return list(self.keys())
901def _indenter(s, n=0):
902 """
903 Ensures that lines after the first are indented by the specified amount
904 """
905 split = s.split("\n")
906 indent = " "*n
907 return ("\n" + indent).join(split)
910def _float_formatter_10(x):
911 """
912 Returns a string representation of a float with exactly ten characters
913 """
914 if np.isposinf(x):
915 return " inf"
916 elif np.isneginf(x):
917 return " -inf"
918 elif np.isnan(x):
919 return " nan"
920 return np.format_float_scientific(x, precision=3, pad_left=2, unique=False)
def _dict_formatter(d, n=0, mplus=1, sorter=None):
    """
    Pretty printer for dictionaries

    `n` keeps track of the starting indentation;
    lines are indented by this much after a line break.
    `mplus` is additional left padding applied to keys
    `sorter` is called with each dict (nested ones included) and must
    return its ``(key, value)`` pairs in display order. If None, items are
    shown in insertion order.

    Keys must be strings. Nested dicts are formatted recursively; an empty
    dict renders as an empty string. Any other value is rendered with
    ``str`` under temporary NumPy print options.
    """
    if isinstance(d, dict):
        if sorter is None:
            # Default to insertion order instead of failing on `None(d)`.
            sorter = dict.items
        # `default=0` keeps an empty (possibly nested) dict from crashing
        # `max`; it then renders as an empty string.
        m = max(map(len, d.keys()), default=0) + mplus  # width to print keys
        s = '\n'.join([k.rjust(m) + ': ' +  # right justified, width m
                       _indenter(_dict_formatter(v, m+n+2, 0, sorter), m+2)
                       for k, v in sorter(d)])  # +2 for ': '
    else:
        # By default, NumPy arrays print with linewidth=76. `n` is
        # the indent at which a line begins printing, so it is subtracted
        # from the default to avoid exceeding 76 characters total.
        # `edgeitems` is the number of elements to include before and after
        # ellipses when arrays are not shown in full.
        # `threshold` is the maximum number of elements for which an
        # array is shown in full.
        # These values tend to work well for use with OptimizeResult.
        with np.printoptions(linewidth=76-n, edgeitems=2, threshold=12,
                             formatter={'float_kind': _float_formatter_10}):
            s = str(d)
    return s
948 return s