Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/scipy/_lib/_array_api.py: 19%
153 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-23 06:43 +0000
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-23 06:43 +0000
"""Utility functions to use Python Array API compatible libraries.

For the context about the Array API see:
https://data-apis.org/array-api/latest/purpose_and_scope.html

The SciPy use case of the Array API is described on the following page:
https://data-apis.org/array-api/latest/use_cases.html#use-case-scipy
"""
9from __future__ import annotations
11import os
12import warnings
14import numpy as np
15from numpy.testing import assert_
16import scipy._lib.array_api_compat.array_api_compat as array_api_compat
17from scipy._lib.array_api_compat.array_api_compat import size
18import scipy._lib.array_api_compat.array_api_compat.numpy as array_api_compat_numpy
# Names exported by ``from ... import *``.
__all__ = ['array_namespace', 'as_xparray', 'size']
# To enable array API and strict array-like input validation.
# The raw environment string is stored, so any non-empty value enables it.
SCIPY_ARRAY_API: str | bool = os.environ.get("SCIPY_ARRAY_API", False)
# To control the default device - for use in the test suite only
SCIPY_DEVICE = os.environ.get("SCIPY_DEVICE", "cpu")

# Snapshot of the settings above, read dynamically by `array_namespace`.
_GLOBAL_CONFIG = {
    "SCIPY_ARRAY_API": SCIPY_ARRAY_API,
    "SCIPY_DEVICE": SCIPY_DEVICE,
}
def compliance_scipy(arrays):
    """Raise exceptions on known-bad subclasses.

    The following subclasses are not supported and raise an error:

    - `np.ma.MaskedArray`
    - `numpy.matrix`
    - Any array-like which is not Array API compatible or coercible by numpy
    - object arrays

    Parameters
    ----------
    arrays : list
        Candidate arrays. Entries that are not Array API objects are
        replaced *in place* by the result of `np.asanyarray`.

    Returns
    -------
    arrays : list
        The same list object that was passed in, after validation.

    Raises
    ------
    TypeError
        If any entry is one of the unsupported kinds listed above.
    """
    object_dtype = np.dtype('O')
    for i, array in enumerate(arrays):
        if isinstance(array, np.ma.MaskedArray):
            raise TypeError("'numpy.ma.MaskedArray' are not supported")
        elif isinstance(array, np.matrix):
            raise TypeError("'numpy.matrix' are not supported")
        elif not array_api_compat.is_array_api_obj(array):
            try:
                array = np.asanyarray(array)
            except TypeError as exc:
                raise TypeError("Array is not Array API compatible or "
                                "coercible by numpy") from exc
            # Compare by equality, not identity: an object dtype carrying
            # metadata is a distinct instance but still an object dtype.
            if array.dtype == object_dtype:
                raise TypeError("An argument was coerced to an object array, "
                                "but object arrays are not supported.")
            arrays[i] = array
        elif array.dtype == object_dtype:
            raise TypeError('object arrays are not supported')
    return arrays
def _check_finite(array, xp):
    """Check for NaNs or Infs.

    Parameters
    ----------
    array : array
        Array to validate.
    xp : array_namespace
        Namespace providing ``isfinite`` and ``all``.

    Raises
    ------
    ValueError
        If `array` contains a non-finite value, or if the finiteness test
        itself raises `TypeError`; in the latter case the original
        `TypeError` is attached as the cause.
    """
    msg = "array must not contain infs or NaNs"
    try:
        if not xp.all(xp.isfinite(array)):
            raise ValueError(msg)
    except TypeError as exc:
        # Keep the underlying failure visible instead of discarding it.
        raise ValueError(msg) from exc
def array_namespace(*arrays):
    """Get the array API compatible namespace for the given arrays.

    Parameters
    ----------
    *arrays : sequence of array_like
        Arrays used to infer the common namespace. ``None`` entries are
        ignored.

    Returns
    -------
    namespace : module
        Common namespace.

    Notes
    -----
    Thin wrapper around `array_api_compat.array_namespace`.

    1. Check for the global switch: SCIPY_ARRAY_API. This can also be accessed
       dynamically through ``_GLOBAL_CONFIG['SCIPY_ARRAY_API']``.
    2. `compliance_scipy` raises exceptions on known-bad subclasses. See
       its definition for more details.

    When the global switch is False, it defaults to the `numpy` namespace.
    In that case, there is no compliance check. This is a convenience to
    ease the adoption. Otherwise, arrays must comply with the new rules.
    """
    if not _GLOBAL_CONFIG["SCIPY_ARRAY_API"]:
        # here we could wrap the namespace if needed
        return array_api_compat_numpy

    arrays = [array for array in arrays if array is not None]

    arrays = compliance_scipy(arrays)

    return array_api_compat.array_namespace(*arrays)
def as_xparray(
    array, dtype=None, order=None, copy=None, *, xp=None, check_finite=False
):
    """SciPy-specific replacement for `np.asarray` with `order` and `check_finite`.

    Memory layout parameter `order` is not exposed in the Array API standard.
    `order` is only enforced if the input array implementation
    is NumPy based, otherwise `order` is just silently ignored.

    `check_finite` is also not a keyword in the array API standard; included
    here for convenience rather than that having to be a separate function
    call inside SciPy functions.

    Parameters
    ----------
    array : array_like
        Input to convert.
    dtype : dtype, optional
        Requested data type, forwarded to the conversion routine.
    order : optional
        Memory layout, forwarded to NumPy only (see above).
    copy : bool, optional
        On the NumPy path, ``True`` forces a copy via `np.array` and any
        other value uses `np.asarray`. On other namespaces the value is
        forwarded to ``xp.asarray``.
    xp : array_namespace, optional
        Namespace to use; inferred from `array` when omitted.
    check_finite : bool, optional
        If true, raise `ValueError` when the result contains infs or NaNs.

    Returns
    -------
    array : array
        The converted array.
    """
    if xp is None:
        xp = array_namespace(array)
    if xp.__name__ in {"numpy", "scipy._lib.array_api_compat.array_api_compat.numpy"}:
        # Use NumPy API to support order
        if copy is True:
            array = np.array(array, order=order, dtype=dtype)
        else:
            array = np.asarray(array, order=order, dtype=dtype)

        # At this point array is a NumPy ndarray. We convert it to an array
        # container that is consistent with the input's namespace.
        array = xp.asarray(array)
    else:
        try:
            array = xp.asarray(array, dtype=dtype, copy=copy)
        except TypeError:
            # Fall back to the namespace inferred from an array that `xp`
            # itself creates, and retry the conversion there.
            coerced_xp = array_namespace(xp.asarray(3))
            array = coerced_xp.asarray(array, dtype=dtype, copy=copy)

    if check_finite:
        _check_finite(array, xp)

    return array
def atleast_nd(x, *, ndim, xp=None):
    """Expand the dimensions of `x` until it has at least `ndim` of them.

    Parameters
    ----------
    x : array_like
        Input, converted with ``xp.asarray``.
    ndim : int
        Minimum number of dimensions of the result.
    xp : array_namespace, optional
        Namespace to use; inferred from `x` when omitted.

    Returns
    -------
    x : array
        `x` with length-one axes prepended as needed. Inputs that already
        have `ndim` or more dimensions are returned after conversion only.
    """
    if xp is None:
        xp = array_namespace(x)
    x = xp.asarray(x)
    # New axes are always inserted at the front.
    while x.ndim < ndim:
        x = xp.expand_dims(x, axis=0)
    return x
def copy(x, *, xp=None):
    """
    Copies an array.

    Parameters
    ----------
    x : array

    xp : array_namespace

    Returns
    -------
    copy : array
        Copied array

    Notes
    -----
    This copy function does not offer all the semantics of `np.copy`, i.e. the
    `subok` and `order` keywords are not used.
    """
    # Delegate to `as_xparray` with ``copy=True`` rather than calling
    # ``xp.asarray`` directly.
    if xp is None:
        xp = array_namespace(x)

    return as_xparray(x, copy=True, xp=xp)
def is_numpy(xp):
    """Return True if the namespace `xp` is NumPy.

    Both the bare ``numpy`` module and the vendored compatibility wrapper
    are recognised, matching the names `as_xparray` treats as NumPy.
    """
    return xp.__name__ in (
        'numpy',
        'scipy._lib.array_api_compat.array_api_compat.numpy',
    )
def is_cupy(xp):
    """Return True if the namespace `xp` is CuPy.

    Both the bare ``cupy`` module name and the vendored compatibility
    wrapper are recognised.
    """
    return xp.__name__ in (
        'cupy',
        'scipy._lib.array_api_compat.array_api_compat.cupy',
    )
def is_torch(xp):
    """Return True if the namespace `xp` is PyTorch.

    Both the bare ``torch`` module name and the vendored compatibility
    wrapper are recognised.
    """
    return xp.__name__ in (
        'torch',
        'scipy._lib.array_api_compat.array_api_compat.torch',
    )
def _strict_check(actual, desired, xp,
                  check_namespace=True, check_dtype=True, check_shape=True):
    """Run the strict pre-comparison checks shared by the ``xp_assert_*`` helpers.

    Parameters
    ----------
    actual : array
        Result under test.
    desired : array_like
        Reference value; converted with ``xp.asarray``.
    xp : array_namespace
        Namespace used for the conversion and broadcasting.
    check_namespace, check_dtype, check_shape : bool, optional
        Toggle the namespace, dtype and shape (including scalar-vs-0D)
        checks respectively.

    Returns
    -------
    desired : array
        `desired` converted to `xp` and broadcast to ``actual.shape``.

    Raises
    ------
    AssertionError
        If an enabled check fails.
    """
    if check_namespace:
        _assert_matching_namespace(actual, desired)

    desired = xp.asarray(desired)

    if check_dtype:
        assert_(actual.dtype == desired.dtype,
                "dtypes do not match.\n"
                f"Actual: {actual.dtype}\n"
                f"Desired: {desired.dtype}")

    if check_shape:
        assert_(actual.shape == desired.shape,
                "Shapes do not match.\n"
                f"Actual: {actual.shape}\n"
                f"Desired: {desired.shape}")
        _check_scalar(actual, desired, xp)

    desired = xp.broadcast_to(desired, actual.shape)
    return desired
def _assert_matching_namespace(actual, desired):
    """Assert that `actual` lives in the same array namespace as `desired`.

    `actual` may be a single array or a tuple of arrays; every element is
    compared against the namespace of `desired`.

    Raises
    ------
    AssertionError
        If any namespace differs from that of `desired`.
    """
    actual = actual if isinstance(actual, tuple) else (actual,)
    desired_space = array_namespace(desired)
    for arr in actual:
        arr_space = array_namespace(arr)
        assert_(arr_space == desired_space,
                "Namespaces do not match.\n"
                f"Actual: {arr_space.__name__}\n"
                f"Desired: {desired_space.__name__}")
def _check_scalar(actual, desired, xp):
    """Assert that `actual` and `desired` agree on scalar vs. 0D array.

    The check only runs when ``desired.shape == ()`` and `xp` is NumPy;
    otherwise the function returns without doing anything.

    Raises
    ------
    AssertionError
        If exactly one of `actual` and ``desired[()]`` is a scalar
        according to ``xp.isscalar``.
    """
    # Shape check alone is sufficient unless desired.shape == (). Also,
    # only NumPy distinguishes between scalars and arrays.
    if desired.shape != () or not is_numpy(xp):
        return
    # We want to follow the conventions of the `xp` library. Libraries like
    # NumPy, for which `np.asarray(0)[()]` returns a scalar, tend to return
    # a scalar even when a 0D array might be more appropriate:
    #    import numpy as np
    #    np.mean([1, 2, 3])  # scalar, not 0d array
    #    np.asarray(0)*2  # scalar, not 0d array
    #    np.sin(np.asarray(0))  # scalar, not 0d array
    # Libraries like CuPy, for which `cp.asarray(0)[()]` returns a 0D array,
    # tend to return a 0D array in scenarios like those above.
    # Therefore, regardless of whether the developer provides a scalar or 0D
    # array for `desired`, we would typically want the type of `actual` to be
    # the type of `desired[()]`. If the developer wants to override this
    # behavior, they can set `check_shape=False`.
    desired = desired[()]
    actual_is_scalar = xp.isscalar(actual)
    desired_is_scalar = xp.isscalar(desired)
    assert_(actual_is_scalar == desired_is_scalar,
            "Types do not match:\n"
            f"Actual: {type(actual)}\n"
            f"Desired: {type(desired)}")
def xp_assert_equal(actual, desired, check_namespace=True, check_dtype=True,
                    check_shape=True, err_msg='', xp=None):
    """Assert that `actual` equals `desired` exactly, with strict checks.

    The namespace, dtype and shape checks of `_strict_check` run first
    (each can be disabled through its flag). The comparison is then
    delegated to the testing routine of the detected backend: CuPy's
    ``assert_array_equal``, PyTorch's ``assert_close`` with zero
    tolerances, or NumPy's ``assert_array_equal`` otherwise.

    `xp` is inferred from `actual` when omitted.
    """
    if xp is None:
        xp = array_namespace(actual)
    desired = _strict_check(actual, desired, xp, check_namespace=check_namespace,
                            check_dtype=check_dtype, check_shape=check_shape)
    if is_cupy(xp):
        return xp.testing.assert_array_equal(actual, desired, err_msg=err_msg)
    elif is_torch(xp):
        # PyTorch recommends using `rtol=0, atol=0` like this
        # to test for exact equality
        err_msg = None if err_msg == '' else err_msg
        return xp.testing.assert_close(actual, desired, rtol=0, atol=0, equal_nan=True,
                                       check_dtype=False, msg=err_msg)
    return np.testing.assert_array_equal(actual, desired, err_msg=err_msg)
def xp_assert_close(actual, desired, rtol=1e-07, atol=0, check_namespace=True,
                    check_dtype=True, check_shape=True, err_msg='', xp=None):
    """Assert that `actual` is close to `desired`, with strict checks.

    The namespace, dtype and shape checks of `_strict_check` run first
    (each can be disabled through its flag). The tolerance comparison is
    then delegated to the testing routine of the detected backend: CuPy's
    ``assert_allclose``, PyTorch's ``assert_close``, or NumPy's
    ``assert_allclose`` otherwise, using `rtol` and `atol`.

    `xp` is inferred from `actual` when omitted.
    """
    if xp is None:
        xp = array_namespace(actual)
    desired = _strict_check(actual, desired, xp, check_namespace=check_namespace,
                            check_dtype=check_dtype, check_shape=check_shape)
    if is_cupy(xp):
        return xp.testing.assert_allclose(actual, desired, rtol=rtol,
                                          atol=atol, err_msg=err_msg)
    elif is_torch(xp):
        # An empty message is mapped to None for PyTorch.
        err_msg = None if err_msg == '' else err_msg
        return xp.testing.assert_close(actual, desired, rtol=rtol, atol=atol,
                                       equal_nan=True, check_dtype=False, msg=err_msg)
    return np.testing.assert_allclose(actual, desired, rtol=rtol,
                                      atol=atol, err_msg=err_msg)
def xp_assert_less(actual, desired, check_namespace=True, check_dtype=True,
                   check_shape=True, err_msg='', verbose=True, xp=None):
    """Assert that `actual` is elementwise less than `desired`, with strict checks.

    The namespace, dtype and shape checks of `_strict_check` run first
    (each can be disabled through its flag). CuPy inputs are compared with
    CuPy's ``assert_array_less``. PyTorch tensors are moved to the CPU when
    they are not already there and are then, like all remaining inputs,
    compared with NumPy's ``assert_array_less``.

    `xp` is inferred from `actual` when omitted.
    """
    if xp is None:
        xp = array_namespace(actual)
    desired = _strict_check(actual, desired, xp, check_namespace=check_namespace,
                            check_dtype=check_dtype, check_shape=check_shape)
    if is_cupy(xp):
        return xp.testing.assert_array_less(actual, desired,
                                            err_msg=err_msg, verbose=verbose)
    elif is_torch(xp):
        # No return here: execution continues to the NumPy comparison
        # below once both tensors are on the CPU.
        if actual.device.type != 'cpu':
            actual = actual.cpu()
        if desired.device.type != 'cpu':
            desired = desired.cpu()
    return np.testing.assert_array_less(actual, desired,
                                        err_msg=err_msg, verbose=verbose)
def cov(x, *, xp=None):
    """Estimate a covariance matrix from `x`.

    `x` is copied, promoted to at least 2-D and to a dtype that is the
    result type of the input and ``float64``. The mean is taken along
    axis 1 and the product of the centred data with its (conjugate)
    transpose is divided by ``x.shape[1] - 1``.

    Parameters
    ----------
    x : array
        Input data.
    xp : array_namespace, optional
        Namespace to use; inferred from `x` when omitted.

    Returns
    -------
    c : array
        The covariance estimate with all length-one axes squeezed out.

    Warns
    -----
    RuntimeWarning
        If ``x.shape[1] - 1 <= 0``; the divisor is then set to ``0.0``.

    Notes
    -----
    NOTE(review): this looks like a port of the default behaviour of
    `np.cov` - confirm before relying on exact parity.
    """
    if xp is None:
        xp = array_namespace(x)

    X = copy(x, xp=xp)
    dtype = xp.result_type(X, xp.float64)

    X = atleast_nd(X, ndim=2, xp=xp)
    X = xp.asarray(X, dtype=dtype)

    avg = xp.mean(X, axis=1)
    fact = X.shape[1] - 1

    if fact <= 0:
        warnings.warn("Degrees of freedom <= 0 for slice",
                      RuntimeWarning, stacklevel=2)
        fact = 0.0

    # In-place on the working array, not on the caller's input.
    X -= avg[:, None]
    X_T = X.T
    if xp.isdtype(X_T.dtype, 'complex floating'):
        X_T = xp.conj(X_T)
    c = X @ X_T
    c /= fact
    axes = tuple(axis for axis, length in enumerate(c.shape) if length == 1)
    return xp.squeeze(c, axis=axes)
def xp_unsupported_param_msg(param):
    """Return the error message for a parameter that only NumPy arrays support.

    `param` is rendered with ``repr``, so strings appear quoted.
    """
    return f'Providing {param!r} is only supported for numpy arrays.'