Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/scipy/_lib/_util.py: 3%

357 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-03 06:39 +0000

1import re 

2from contextlib import contextmanager 

3import functools 

4import operator 

5import warnings 

6import numbers 

7from collections import namedtuple 

8import inspect 

9import math 

10from typing import ( 

11 Optional, 

12 Union, 

13 TYPE_CHECKING, 

14 TypeVar, 

15) 

16 

17import numpy as np 

18from scipy._lib._array_api import array_namespace 

19 

20 

# NumPy 1.25 moved its exception and warning classes into
# ``numpy.exceptions``. Older releases expose them at the top level and have
# no dedicated ``DTypePromotionError``, so ``TypeError`` stands in for it.
AxisError: type[Exception]
ComplexWarning: type[Warning]
VisibleDeprecationWarning: type[Warning]

if np.lib.NumpyVersion(np.__version__) < '1.25.0':
    from numpy import (
        AxisError, ComplexWarning, VisibleDeprecationWarning  # noqa: F401
    )
    DTypePromotionError = TypeError  # type: ignore
else:
    from numpy.exceptions import (
        AxisError, ComplexWarning, VisibleDeprecationWarning,
        DTypePromotionError
    )

35 

# Scalar types matching the C ``long`` / ``unsigned long``.
np_long: type
np_ulong: type

if np.lib.NumpyVersion(np.__version__) < "2.0.0.dev0":
    # NumPy 1.x: ``np.int_`` / ``np.uint`` play this role.
    np_long = np.int_
    np_ulong = np.uint
else:
    try:
        # Silence the FutureWarning that (presumably some development
        # builds of) NumPy 2 emit when ``np.long`` is accessed.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                r".*In the future `np\.long` will be defined as.*",
                FutureWarning,
            )
            np_long = np.long  # type: ignore[attr-defined]
            np_ulong = np.ulong  # type: ignore[attr-defined]
    except AttributeError:
        # Builds without ``np.long``/``np.ulong``: use the 1.x spelling.
        np_long = np.int_
        np_ulong = np.uint

# Annotation aliases for scalars accepted as either Python or NumPy numbers.
IntNumber = Union[int, np.integer]
DecimalNumber = Union[float, np.floating, np.integer]

58 

# Value to pass as ``copy=`` to mean "copy only when necessary": NumPy 2
# spells this ``None`` while NumPy 1.x spells it ``False``.
copy_if_needed: Optional[bool]

if np.lib.NumpyVersion(np.__version__) < "1.28.0":
    copy_if_needed = False
elif np.lib.NumpyVersion(np.__version__) >= "2.0.0":
    copy_if_needed = None
else:
    # 2.0.0 development builds: probe whether ``copy=None`` is understood.
    try:
        np.array([1]).__array__(copy=None)  # type: ignore[call-overload]
    except TypeError:
        copy_if_needed = False
    else:
        copy_if_needed = None

72 

# Type aliases used only by static type checkers (never built at runtime).
if TYPE_CHECKING:
    SeedType = Optional[Union[IntNumber, np.random.Generator,
                              np.random.RandomState]]
    GeneratorType = TypeVar(
        "GeneratorType",
        bound=Union[np.random.Generator, np.random.RandomState],
    )

# ``numpy.random.Generator`` first appeared in NumPy 1.17; on older releases
# an empty placeholder class keeps ``isinstance(x, Generator)`` checks valid.
try:
    from numpy.random import Generator as Generator
except ImportError:
    class Generator:  # type: ignore[no-redef]
        pass

86 

87 

88def _lazywhere(cond, arrays, f, fillvalue=None, f2=None): 

89 """Return elements chosen from two possibilities depending on a condition 

90 

91 Equivalent to ``f(*arrays) if cond else fillvalue`` performed elementwise. 

92 

93 Parameters 

94 ---------- 

95 cond : array 

96 The condition (expressed as a boolean array). 

97 arrays : tuple of array 

98 Arguments to `f` (and `f2`). Must be broadcastable with `cond`. 

99 f : callable 

100 Where `cond` is True, output will be ``f(arr1[cond], arr2[cond], ...)`` 

101 fillvalue : object 

102 If provided, value with which to fill output array where `cond` is 

103 not True. 

104 f2 : callable 

105 If provided, output will be ``f2(arr1[cond], arr2[cond], ...)`` where 

106 `cond` is not True. 

107 

108 Returns 

109 ------- 

110 out : array 

111 An array with elements from the output of `f` where `cond` is True 

112 and `fillvalue` (or elements from the output of `f2`) elsewhere. The 

113 returned array has data type determined by Type Promotion Rules 

114 with the output of `f` and `fillvalue` (or the output of `f2`). 

115 

116 Notes 

117 ----- 

118 ``xp.where(cond, x, fillvalue)`` requires explicitly forming `x` even where 

119 `cond` is False. This function evaluates ``f(arr1[cond], arr2[cond], ...)`` 

120 onle where `cond` ``is True. 

121 

122 Examples 

123 -------- 

124 >>> import numpy as np 

125 >>> a, b = np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]) 

126 >>> def f(a, b): 

127 ... return a*b 

128 >>> _lazywhere(a > 2, (a, b), f, np.nan) 

129 array([ nan, nan, 21., 32.]) 

130 

131 """ 

132 xp = array_namespace(cond, *arrays) 

133 

134 if (f2 is fillvalue is None) or (f2 is not None and fillvalue is not None): 

135 raise ValueError("Exactly one of `fillvalue` or `f2` must be given.") 

136 

137 args = xp.broadcast_arrays(cond, *arrays) 

138 bool_dtype = xp.asarray([True]).dtype # numpy 1.xx doesn't have `bool` 

139 cond, arrays = xp.astype(args[0], bool_dtype, copy=False), args[1:] 

140 

141 temp1 = xp.asarray(f(*(arr[cond] for arr in arrays))) 

142 

143 if f2 is None: 

144 fillvalue = xp.asarray(fillvalue) 

145 dtype = xp.result_type(temp1.dtype, fillvalue.dtype) 

146 out = xp.full(cond.shape, fill_value=fillvalue, dtype=dtype) 

147 else: 

148 ncond = ~cond 

149 temp2 = xp.asarray(f2(*(arr[ncond] for arr in arrays))) 

150 dtype = xp.result_type(temp1, temp2) 

151 out = xp.empty(cond.shape, dtype=dtype) 

152 out[ncond] = temp2 

153 

154 out[cond] = temp1 

155 

156 return out 

157 

158 

159def _lazyselect(condlist, choicelist, arrays, default=0): 

160 """ 

161 Mimic `np.select(condlist, choicelist)`. 

162 

163 Notice, it assumes that all `arrays` are of the same shape or can be 

164 broadcasted together. 

165 

166 All functions in `choicelist` must accept array arguments in the order 

167 given in `arrays` and must return an array of the same shape as broadcasted 

168 `arrays`. 

169 

170 Examples 

171 -------- 

172 >>> import numpy as np 

173 >>> x = np.arange(6) 

174 >>> np.select([x <3, x > 3], [x**2, x**3], default=0) 

175 array([ 0, 1, 4, 0, 64, 125]) 

176 

177 >>> _lazyselect([x < 3, x > 3], [lambda x: x**2, lambda x: x**3], (x,)) 

178 array([ 0., 1., 4., 0., 64., 125.]) 

179 

180 >>> a = -np.ones_like(x) 

181 >>> _lazyselect([x < 3, x > 3], 

182 ... [lambda x, a: x**2, lambda x, a: a * x**3], 

183 ... (x, a), default=np.nan) 

184 array([ 0., 1., 4., nan, -64., -125.]) 

185 

186 """ 

187 arrays = np.broadcast_arrays(*arrays) 

188 tcode = np.mintypecode([a.dtype.char for a in arrays]) 

189 out = np.full(np.shape(arrays[0]), fill_value=default, dtype=tcode) 

190 for func, cond in zip(choicelist, condlist): 

191 if np.all(cond is False): 

192 continue 

193 cond, _ = np.broadcast_arrays(cond, arrays[0]) 

194 temp = tuple(np.extract(cond, arr) for arr in arrays) 

195 np.place(out, cond, func(*temp)) 

196 return out 

197 

198 

199def _aligned_zeros(shape, dtype=float, order="C", align=None): 

200 """Allocate a new ndarray with aligned memory. 

201 

202 Primary use case for this currently is working around a f2py issue 

203 in NumPy 1.9.1, where dtype.alignment is such that np.zeros() does 

204 not necessarily create arrays aligned up to it. 

205 

206 """ 

207 dtype = np.dtype(dtype) 

208 if align is None: 

209 align = dtype.alignment 

210 if not hasattr(shape, '__len__'): 

211 shape = (shape,) 

212 size = functools.reduce(operator.mul, shape) * dtype.itemsize 

213 buf = np.empty(size + align + 1, np.uint8) 

214 offset = buf.__array_interface__['data'][0] % align 

215 if offset != 0: 

216 offset = align - offset 

217 # Note: slices producing 0-size arrays do not necessarily change 

218 # data pointer --- so we use and allocate size+1 

219 buf = buf[offset:offset+size+1][:-1] 

220 data = np.ndarray(shape, dtype, buf, order=order) 

221 data.fill(0) 

222 return data 

223 

224 

225def _prune_array(array): 

226 """Return an array equivalent to the input array. If the input 

227 array is a view of a much larger array, copy its contents to a 

228 newly allocated array. Otherwise, return the input unchanged. 

229 """ 

230 if array.base is not None and array.size < array.base.size // 2: 

231 return array.copy() 

232 return array 

233 

234 

def float_factorial(n: int) -> float:
    """Return ``n!`` as a float.

    For ``n >= 171`` the factorial no longer fits in a double, so
    infinity is returned instead.
    """
    if n < 171:
        return float(math.factorial(n))
    return np.inf

241 

242 

# copy-pasted from scikit-learn utils/validation.py
# change this to scipy.stats._qmc.check_random_state once numpy 1.16 is dropped
def check_random_state(seed):
    """Turn `seed` into a random number generator instance.

    Parameters
    ----------
    seed : {None, int, `numpy.random.Generator`, `numpy.random.RandomState`}, optional
        If `seed` is None (or `np.random`), the `numpy.random.RandomState`
        singleton is used.
        If `seed` is an int, a new ``RandomState`` instance is used,
        seeded with `seed`.
        If `seed` is already a ``Generator`` or ``RandomState`` instance then
        that instance is used.

    Returns
    -------
    seed : {`numpy.random.Generator`, `numpy.random.RandomState`}
        Random number generator.

    Raises
    ------
    ValueError
        If `seed` is not one of the types listed above.

    """
    if seed is None or seed is np.random:
        return np.random.mtrand._rand
    if isinstance(seed, (numbers.Integral, np.integer)):
        return np.random.RandomState(seed)
    if isinstance(seed, (np.random.RandomState, np.random.Generator)):
        return seed

    # Use an f-string: with ``'%r' % seed`` a tuple `seed` is unpacked as the
    # format arguments (TypeError for most tuples, wrong text for 1-tuples).
    raise ValueError(f'{seed!r} cannot be used to seed a numpy.random.RandomState'
                     ' instance')

273 

274 

275def _asarray_validated(a, check_finite=True, 

276 sparse_ok=False, objects_ok=False, mask_ok=False, 

277 as_inexact=False): 

278 """ 

279 Helper function for SciPy argument validation. 

280 

281 Many SciPy linear algebra functions do support arbitrary array-like 

282 input arguments. Examples of commonly unsupported inputs include 

283 matrices containing inf/nan, sparse matrix representations, and 

284 matrices with complicated elements. 

285 

286 Parameters 

287 ---------- 

288 a : array_like 

289 The array-like input. 

290 check_finite : bool, optional 

291 Whether to check that the input matrices contain only finite numbers. 

292 Disabling may give a performance gain, but may result in problems 

293 (crashes, non-termination) if the inputs do contain infinities or NaNs. 

294 Default: True 

295 sparse_ok : bool, optional 

296 True if scipy sparse matrices are allowed. 

297 objects_ok : bool, optional 

298 True if arrays with dype('O') are allowed. 

299 mask_ok : bool, optional 

300 True if masked arrays are allowed. 

301 as_inexact : bool, optional 

302 True to convert the input array to a np.inexact dtype. 

303 

304 Returns 

305 ------- 

306 ret : ndarray 

307 The converted validated array. 

308 

309 """ 

310 if not sparse_ok: 

311 import scipy.sparse 

312 if scipy.sparse.issparse(a): 

313 msg = ('Sparse matrices are not supported by this function. ' 

314 'Perhaps one of the scipy.sparse.linalg functions ' 

315 'would work instead.') 

316 raise ValueError(msg) 

317 if not mask_ok: 

318 if np.ma.isMaskedArray(a): 

319 raise ValueError('masked arrays are not supported') 

320 toarray = np.asarray_chkfinite if check_finite else np.asarray 

321 a = toarray(a) 

322 if not objects_ok: 

323 if a.dtype is np.dtype('O'): 

324 raise ValueError('object arrays are not supported') 

325 if as_inexact: 

326 if not np.issubdtype(a.dtype, np.inexact): 

327 a = toarray(a, dtype=np.float64) 

328 return a 

329 

330 

331def _validate_int(k, name, minimum=None): 

332 """ 

333 Validate a scalar integer. 

334 

335 This function can be used to validate an argument to a function 

336 that expects the value to be an integer. It uses `operator.index` 

337 to validate the value (so, for example, k=2.0 results in a 

338 TypeError). 

339 

340 Parameters 

341 ---------- 

342 k : int 

343 The value to be validated. 

344 name : str 

345 The name of the parameter. 

346 minimum : int, optional 

347 An optional lower bound. 

348 """ 

349 try: 

350 k = operator.index(k) 

351 except TypeError: 

352 raise TypeError(f'{name} must be an integer.') from None 

353 if minimum is not None and k < minimum: 

354 raise ValueError(f'{name} must be an integer not less ' 

355 f'than {minimum}') from None 

356 return k 

357 

358 

# Add a replacement for inspect.getfullargspec().
# The version below is borrowed from Django,
# https://github.com/django/django/pull/4846.

# Note an inconsistency between inspect.getfullargspec(func) and
# inspect.signature(func). If `func` is a bound method, the latter does *not*
# list `self` as a first argument, while the former *does*.
# Hence, cook up a common ground replacement: `getfullargspec_no_self` which
# mimics `inspect.getfullargspec` but does not list `self`.
#
# This way, the caller code does not need to know whether it uses a legacy
# .getfullargspec or a bright and shiny .signature.

FullArgSpec = namedtuple('FullArgSpec',
                         ['args', 'varargs', 'varkw', 'defaults',
                          'kwonlyargs', 'kwonlydefaults', 'annotations'])


def getfullargspec_no_self(func):
    """inspect.getfullargspec replacement using inspect.signature.

    If func is a bound method, do not list the 'self' parameter.

    Parameters
    ----------
    func : callable
        A callable to inspect

    Returns
    -------
    fullargspec : FullArgSpec(args, varargs, varkw, defaults, kwonlyargs,
                              kwonlydefaults, annotations)

        NOTE: if `func` is a bound method, its first argument (``self``) is
        *not* included in fullargspec.args, because `inspect.signature`
        omits it. As with `inspect.getfullargspec`, `args` lists both
        positional-only and positional-or-keyword parameters, and
        `defaults` holds the default values of the trailing ones among
        them. Unlike `inspect.getfullargspec`, a return annotation is not
        included in `annotations`.

    """
    sig = inspect.signature(func)
    params = list(sig.parameters.values())
    positional_kinds = (inspect.Parameter.POSITIONAL_ONLY,
                        inspect.Parameter.POSITIONAL_OR_KEYWORD)

    args = [p.name for p in params if p.kind in positional_kinds]
    varargs = [
        p.name for p in params
        if p.kind == inspect.Parameter.VAR_POSITIONAL
    ]
    varargs = varargs[0] if varargs else None
    varkw = [
        p.name for p in params
        if p.kind == inspect.Parameter.VAR_KEYWORD
    ]
    varkw = varkw[0] if varkw else None
    # Defaults must line up with the tail of `args`, so positional-only
    # parameters with defaults are included as well.
    defaults = tuple(
        p.default for p in params
        if p.kind in positional_kinds and p.default is not p.empty
    ) or None
    kwonlyargs = [
        p.name for p in params
        if p.kind == inspect.Parameter.KEYWORD_ONLY
    ]
    kwdefaults = {p.name: p.default for p in params
                  if p.kind == inspect.Parameter.KEYWORD_ONLY and
                  p.default is not p.empty}
    annotations = {p.name: p.annotation for p in params
                   if p.annotation is not p.empty}
    return FullArgSpec(args, varargs, varkw, defaults, kwonlyargs,
                       kwdefaults or None, annotations)

430 

431 

432class _FunctionWrapper: 

433 """ 

434 Object to wrap user's function, allowing picklability 

435 """ 

436 def __init__(self, f, args): 

437 self.f = f 

438 self.args = [] if args is None else args 

439 

440 def __call__(self, x): 

441 return self.f(x, *self.args) 

442 

443 

class MapWrapper:
    """
    Parallelisation wrapper for working with map-like callables, such as
    `multiprocessing.Pool.map`.

    Parameters
    ----------
    pool : int or map-like callable
        If `pool` is an integer, then it specifies the number of worker
        processes (a `multiprocessing.Pool`) to use for parallelization.
        If ``int(pool) == 1``, then no parallel processing is used and the
        map builtin is used.
        If ``pool == -1``, then the pool will utilize all available CPUs.
        If `pool` is a map-like callable that follows the same
        calling sequence as the built-in map function, then this callable is
        used for parallelization.

    Raises
    ------
    RuntimeError
        If `pool` is not callable and ``int(pool)`` is neither -1 nor >= 1.

    Notes
    -----
    Only a pool created by this wrapper is owned by it: `close`, `join`,
    `terminate` and leaving a ``with`` block act on that pool and are
    no-ops for the builtin map or a user-supplied callable.
    """
    def __init__(self, pool=1):
        self.pool = None
        self._mapfunc = map
        self._own_pool = False

        if callable(pool):
            self.pool = pool
            self._mapfunc = self.pool
        else:
            from multiprocessing import Pool
            # user supplies a number
            if int(pool) == -1:
                # use as many processors as possible
                self.pool = Pool()
                self._mapfunc = self.pool.map
                self._own_pool = True
            elif int(pool) == 1:
                pass
            elif int(pool) > 1:
                # use the number of processors requested
                self.pool = Pool(processes=int(pool))
                self._mapfunc = self.pool.map
                self._own_pool = True
            else:
                raise RuntimeError("Number of workers specified must be -1,"
                                   " an int >= 1, or an object with a 'map' "
                                   "method")

    def __enter__(self):
        return self

    def terminate(self):
        if self._own_pool:
            self.pool.terminate()

    def join(self):
        if self._own_pool:
            self.pool.join()

    def close(self):
        if self._own_pool:
            self.pool.close()

    def __exit__(self, exc_type, exc_value, traceback):
        if self._own_pool:
            self.pool.close()
            self.pool.terminate()

    def __call__(self, func, iterable):
        # only accept one iterable because that's all Pool.map accepts
        try:
            return self._mapfunc(func, iterable)
        except TypeError as e:
            # wrong number of arguments (note: any TypeError raised by the
            # map-like callable is re-labelled here; the cause is chained)
            raise TypeError("The map-like callable must be of the"
                            " form f(func, iterable)") from e

516 

517 

def rng_integers(gen, low, high=None, size=None, dtype='int64',
                 endpoint=False):
    """
    Draw random integers in ``[low, high)``, or ``[low, high]`` when
    `endpoint` is True, from either RNG API.

    Stands in for `RandomState.randint` (``endpoint=False``) and
    `RandomState.random_integers` (``endpoint=True``) while also accepting a
    `numpy.random.Generator`. When `high` is None the range starts at 0 and
    `low` is used as the upper bound.

    Parameters
    ----------
    gen : {None, np.random.RandomState, np.random.Generator}
        Random number generator; None selects the global
        `numpy.random.RandomState` singleton.
    low : int or array-like of ints
        Lower (signed) bound, or the upper bound if `high` is None.
    high : int or array-like of ints, optional
        One above the largest value that can be drawn (or the largest value
        itself when `endpoint` is True).
    size : array-like of ints, optional
        Output shape; None (default) returns a single value.
    dtype : {str, dtype}, optional
        Result dtype, given by name such as 'int64' (the default).
    endpoint : bool, optional
        If True, the upper bound is inclusive. Defaults to False.

    Returns
    -------
    out: int or ndarray of ints
        `size`-shaped array of random integers, or a single integer when
        `size` is not given.
    """
    if isinstance(gen, Generator):
        return gen.integers(low, high=high, size=size, dtype=dtype,
                            endpoint=endpoint)

    if gen is None:
        # default is RandomState singleton used by np.random.
        gen = np.random.mtrand._rand

    if not endpoint:
        # exclusive upper bound: randint's native behaviour
        return gen.randint(low, high=high, size=size, dtype=dtype)

    # RandomState.randint has no `endpoint` option, so bump the upper bound
    # by one. `low`/`high` may be arrays, hence ``x + 1`` (never in place).
    if high is None:
        return gen.randint(low + 1, size=size, dtype=dtype)
    return gen.randint(low, high=high + 1, size=size, dtype=dtype)

580 

581 

582@contextmanager 

583def _fixed_default_rng(seed=1638083107694713882823079058616272161): 

584 """Context with a fixed np.random.default_rng seed.""" 

585 orig_fun = np.random.default_rng 

586 np.random.default_rng = lambda seed=seed: orig_fun(seed) 

587 try: 

588 yield 

589 finally: 

590 np.random.default_rng = orig_fun 

591 

592 

593def _rng_html_rewrite(func): 

594 """Rewrite the HTML rendering of ``np.random.default_rng``. 

595 

596 This is intended to decorate 

597 ``numpydoc.docscrape_sphinx.SphinxDocString._str_examples``. 

598 

599 Examples are only run by Sphinx when there are plot involved. Even so, 

600 it does not change the result values getting printed. 

601 """ 

602 # hexadecimal or number seed, case-insensitive 

603 pattern = re.compile(r'np.random.default_rng\((0x[0-9A-F]+|\d+)\)', re.I) 

604 

605 def _wrapped(*args, **kwargs): 

606 res = func(*args, **kwargs) 

607 lines = [ 

608 re.sub(pattern, 'np.random.default_rng()', line) 

609 for line in res 

610 ] 

611 return lines 

612 

613 return _wrapped 

614 

615 

616def _argmin(a, keepdims=False, axis=None): 

617 """ 

618 argmin with a `keepdims` parameter. 

619 

620 See https://github.com/numpy/numpy/issues/8710 

621 

622 If axis is not None, a.shape[axis] must be greater than 0. 

623 """ 

624 res = np.argmin(a, axis=axis) 

625 if keepdims and axis is not None: 

626 res = np.expand_dims(res, axis=axis) 

627 return res 

628 

629 

def _first_nonnan(a, axis):
    """
    Return the first non-nan value along the given axis.

    If a slice is all nan, nan is returned for that slice.

    The shape of the return value corresponds to ``keepdims=True``.

    Examples
    --------
    >>> import numpy as np
    >>> nan = np.nan
    >>> a = np.array([[ 3.,  3., nan,  3.],
    ...               [ 1., nan,  2.,  4.],
    ...               [nan, nan,  9., -1.],
    ...               [nan,  5.,  4.,  3.],
    ...               [ 2.,  2.,  2.,  2.],
    ...               [nan, nan, nan, nan]])
    >>> _first_nonnan(a, axis=0)
    array([[3., 3., 2., 3.]])
    >>> _first_nonnan(a, axis=1)
    array([[ 3.],
           [ 1.],
           [ 9.],
           [ 5.],
           [ 2.],
           [nan]])
    """
    # argmin over the boolean nan-mask finds the first False (non-nan)
    # entry; in an all-nan slice every entry is True, so index 0 (a nan)
    # is picked, which yields the documented nan result.
    k = _argmin(np.isnan(a), axis=axis, keepdims=True)
    return np.take_along_axis(a, k, axis=axis)

660 

661 

662def _nan_allsame(a, axis, keepdims=False): 

663 """ 

664 Determine if the values along an axis are all the same. 

665 

666 nan values are ignored. 

667 

668 `a` must be a numpy array. 

669 

670 `axis` is assumed to be normalized; that is, 0 <= axis < a.ndim. 

671 

672 For an axis of length 0, the result is True. That is, we adopt the 

673 convention that ``allsame([])`` is True. (There are no values in the 

674 input that are different.) 

675 

676 `True` is returned for slices that are all nan--not because all the 

677 values are the same, but because this is equivalent to ``allsame([])``. 

678 

679 Examples 

680 -------- 

681 >>> from numpy import nan, array 

682 >>> a = array([[ 3., 3., nan, 3.], 

683 ... [ 1., nan, 2., 4.], 

684 ... [nan, nan, 9., -1.], 

685 ... [nan, 5., 4., 3.], 

686 ... [ 2., 2., 2., 2.], 

687 ... [nan, nan, nan, nan]]) 

688 >>> _nan_allsame(a, axis=1, keepdims=True) 

689 array([[ True], 

690 [False], 

691 [False], 

692 [False], 

693 [ True], 

694 [ True]]) 

695 """ 

696 if axis is None: 

697 if a.size == 0: 

698 return True 

699 a = a.ravel() 

700 axis = 0 

701 else: 

702 shp = a.shape 

703 if shp[axis] == 0: 

704 shp = shp[:axis] + (1,)*keepdims + shp[axis + 1:] 

705 return np.full(shp, fill_value=True, dtype=bool) 

706 a0 = _first_nonnan(a, axis=axis) 

707 return ((a0 == a) | np.isnan(a)).all(axis=axis, keepdims=keepdims) 

708 

709 

710def _contains_nan(a, nan_policy='propagate', use_summation=True, 

711 policies=None): 

712 if not isinstance(a, np.ndarray): 

713 use_summation = False # some array_likes ignore nans (e.g. pandas) 

714 if policies is None: 

715 policies = ['propagate', 'raise', 'omit'] 

716 if nan_policy not in policies: 

717 raise ValueError("nan_policy must be one of {%s}" % 

718 ', '.join("'%s'" % s for s in policies)) 

719 

720 if np.issubdtype(a.dtype, np.inexact): 

721 # The summation method avoids creating a (potentially huge) array. 

722 if use_summation: 

723 with np.errstate(invalid='ignore', over='ignore'): 

724 contains_nan = np.isnan(np.sum(a)) 

725 else: 

726 contains_nan = np.isnan(a).any() 

727 elif np.issubdtype(a.dtype, object): 

728 contains_nan = False 

729 for el in a.ravel(): 

730 # isnan doesn't work on non-numeric elements 

731 if np.issubdtype(type(el), np.number) and np.isnan(el): 

732 contains_nan = True 

733 break 

734 else: 

735 # Only `object` and `inexact` arrays can have NaNs 

736 contains_nan = False 

737 

738 if contains_nan and nan_policy == 'raise': 

739 raise ValueError("The input contains nan values") 

740 

741 return contains_nan, nan_policy 

742 

743 

744def _rename_parameter(old_name, new_name, dep_version=None): 

745 """ 

746 Generate decorator for backward-compatible keyword renaming. 

747 

748 Apply the decorator generated by `_rename_parameter` to functions with a 

749 recently renamed parameter to maintain backward-compatibility. 

750 

751 After decoration, the function behaves as follows: 

752 If only the new parameter is passed into the function, behave as usual. 

753 If only the old parameter is passed into the function (as a keyword), raise 

754 a DeprecationWarning if `dep_version` is provided, and behave as usual 

755 otherwise. 

756 If both old and new parameters are passed into the function, raise a 

757 DeprecationWarning if `dep_version` is provided, and raise the appropriate 

758 TypeError (function got multiple values for argument). 

759 

760 Parameters 

761 ---------- 

762 old_name : str 

763 Old name of parameter 

764 new_name : str 

765 New name of parameter 

766 dep_version : str, optional 

767 Version of SciPy in which old parameter was deprecated in the format 

768 'X.Y.Z'. If supplied, the deprecation message will indicate that 

769 support for the old parameter will be removed in version 'X.Y+2.Z' 

770 

771 Notes 

772 ----- 

773 Untested with functions that accept *args. Probably won't work as written. 

774 

775 """ 

776 def decorator(fun): 

777 @functools.wraps(fun) 

778 def wrapper(*args, **kwargs): 

779 if old_name in kwargs: 

780 if dep_version: 

781 end_version = dep_version.split('.') 

782 end_version[1] = str(int(end_version[1]) + 2) 

783 end_version = '.'.join(end_version) 

784 message = (f"Use of keyword argument `{old_name}` is " 

785 f"deprecated and replaced by `{new_name}`. " 

786 f"Support for `{old_name}` will be removed " 

787 f"in SciPy {end_version}.") 

788 warnings.warn(message, DeprecationWarning, stacklevel=2) 

789 if new_name in kwargs: 

790 message = (f"{fun.__name__}() got multiple values for " 

791 f"argument now known as `{new_name}`") 

792 raise TypeError(message) 

793 kwargs[new_name] = kwargs.pop(old_name) 

794 return fun(*args, **kwargs) 

795 return wrapper 

796 return decorator 

797 

798 

799def _rng_spawn(rng, n_children): 

800 # spawns independent RNGs from a parent RNG 

801 bg = rng._bit_generator 

802 ss = bg._seed_seq 

803 child_rngs = [np.random.Generator(type(bg)(child_ss)) 

804 for child_ss in ss.spawn(n_children)] 

805 return child_rngs 

806 

807 

808def _get_nan(*data): 

809 # Get NaN of appropriate dtype for data 

810 data = [np.asarray(item) for item in data] 

811 try: 

812 dtype = np.result_type(*data, np.half) # must be a float16 at least 

813 except DTypePromotionError: 

814 # fallback to float64 

815 return np.array(np.nan, dtype=np.float64)[()] 

816 return np.array(np.nan, dtype=dtype)[()] 

817 

818 

def normalize_axis_index(axis, ndim):
    """Map `axis` into ``range(ndim)``, resolving negative values.

    Raises AxisError when `axis` lies outside ``[-ndim, ndim)``.
    """
    if not -ndim <= axis < ndim:
        raise AxisError(
            f"axis {axis} is out of bounds for array of dimension {ndim}")
    return axis + ndim if axis < 0 else axis

828 

829 

830def _call_callback_maybe_halt(callback, res): 

831 """Call wrapped callback; return True if algorithm should stop. 

832 

833 Parameters 

834 ---------- 

835 callback : callable or None 

836 A user-provided callback wrapped with `_wrap_callback` 

837 res : OptimizeResult 

838 Information about the current iterate 

839 

840 Returns 

841 ------- 

842 halt : bool 

843 True if minimization should stop 

844 

845 """ 

846 if callback is None: 

847 return False 

848 try: 

849 callback(res) 

850 return False 

851 except StopIteration: 

852 callback.stop_iteration = True 

853 return True 

854 

855 

856class _RichResult(dict): 

857 """ Container for multiple outputs with pretty-printing """ 

858 def __getattr__(self, name): 

859 try: 

860 return self[name] 

861 except KeyError as e: 

862 raise AttributeError(name) from e 

863 

864 __setattr__ = dict.__setitem__ 

865 __delattr__ = dict.__delitem__ 

866 

867 def __repr__(self): 

868 order_keys = ['message', 'success', 'status', 'fun', 'funl', 'x', 'xl', 

869 'col_ind', 'nit', 'lower', 'upper', 'eqlin', 'ineqlin', 

870 'converged', 'flag', 'function_calls', 'iterations', 

871 'root'] 

872 order_keys = getattr(self, '_order_keys', order_keys) 

873 # 'slack', 'con' are redundant with residuals 

874 # 'crossover_nit' is probably not interesting to most users 

875 omit_keys = {'slack', 'con', 'crossover_nit', '_order_keys'} 

876 

877 def key(item): 

878 try: 

879 return order_keys.index(item[0].lower()) 

880 except ValueError: # item not in list 

881 return np.inf 

882 

883 def omit_redundant(items): 

884 for item in items: 

885 if item[0] in omit_keys: 

886 continue 

887 yield item 

888 

889 def item_sorter(d): 

890 return sorted(omit_redundant(d.items()), key=key) 

891 

892 if self.keys(): 

893 return _dict_formatter(self, sorter=item_sorter) 

894 else: 

895 return self.__class__.__name__ + "()" 

896 

897 def __dir__(self): 

898 return list(self.keys()) 

899 

900 

901def _indenter(s, n=0): 

902 """ 

903 Ensures that lines after the first are indented by the specified amount 

904 """ 

905 split = s.split("\n") 

906 indent = " "*n 

907 return ("\n" + indent).join(split) 

908 

909 

910def _float_formatter_10(x): 

911 """ 

912 Returns a string representation of a float with exactly ten characters 

913 """ 

914 if np.isposinf(x): 

915 return " inf" 

916 elif np.isneginf(x): 

917 return " -inf" 

918 elif np.isnan(x): 

919 return " nan" 

920 return np.format_float_scientific(x, precision=3, pad_left=2, unique=False) 

921 

922 

923def _dict_formatter(d, n=0, mplus=1, sorter=None): 

924 """ 

925 Pretty printer for dictionaries 

926 

927 `n` keeps track of the starting indentation; 

928 lines are indented by this much after a line break. 

929 `mplus` is additional left padding applied to keys 

930 """ 

931 if isinstance(d, dict): 

932 m = max(map(len, list(d.keys()))) + mplus # width to print keys 

933 s = '\n'.join([k.rjust(m) + ': ' + # right justified, width m 

934 _indenter(_dict_formatter(v, m+n+2, 0, sorter), m+2) 

935 for k, v in sorter(d)]) # +2 for ': ' 

936 else: 

937 # By default, NumPy arrays print with linewidth=76. `n` is 

938 # the indent at which a line begins printing, so it is subtracted 

939 # from the default to avoid exceeding 76 characters total. 

940 # `edgeitems` is the number of elements to include before and after 

941 # ellipses when arrays are not shown in full. 

942 # `threshold` is the maximum number of elements for which an 

943 # array is shown in full. 

944 # These values tend to work well for use with OptimizeResult. 

945 with np.printoptions(linewidth=76-n, edgeitems=2, threshold=12, 

946 formatter={'float_kind': _float_formatter_10}): 

947 s = str(d) 

948 return s