Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/scipy/_lib/_array_api.py: 4%
167 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-03 06:39 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-03 06:39 +0000
1"""Utility functions to use Python Array API compatible libraries.
3For the context about the Array API see:
4https://data-apis.org/array-api/latest/purpose_and_scope.html
6The SciPy use case of the Array API is described on the following page:
7https://data-apis.org/array-api/latest/use_cases.html#use-case-scipy
8"""
9from __future__ import annotations
11import os
12import warnings
14import numpy as np
16from scipy._lib import array_api_compat
17from scipy._lib.array_api_compat import (
18 is_array_api_obj,
19 size,
20 numpy as np_compat,
21)
23__all__ = ['array_namespace', '_asarray', 'size']
26# To enable array API and strict array-like input validation
27SCIPY_ARRAY_API: str | bool = os.environ.get("SCIPY_ARRAY_API", False)
28# To control the default device - for use in the test suite only
29SCIPY_DEVICE = os.environ.get("SCIPY_DEVICE", "cpu")
31_GLOBAL_CONFIG = {
32 "SCIPY_ARRAY_API": SCIPY_ARRAY_API,
33 "SCIPY_DEVICE": SCIPY_DEVICE,
34}
37def compliance_scipy(arrays):
38 """Raise exceptions on known-bad subclasses.
40 The following subclasses are not supported and raise and error:
41 - `numpy.ma.MaskedArray`
42 - `numpy.matrix`
43 - NumPy arrays which do not have a boolean or numerical dtype
44 - Any array-like which is neither array API compatible nor coercible by NumPy
45 - Any array-like which is coerced by NumPy to an unsupported dtype
46 """
47 for i in range(len(arrays)):
48 array = arrays[i]
49 if isinstance(array, np.ma.MaskedArray):
50 raise TypeError("Inputs of type `numpy.ma.MaskedArray` are not supported.")
51 elif isinstance(array, np.matrix):
52 raise TypeError("Inputs of type `numpy.matrix` are not supported.")
53 if isinstance(array, (np.ndarray, np.generic)):
54 dtype = array.dtype
55 if not (np.issubdtype(dtype, np.number) or np.issubdtype(dtype, np.bool_)):
56 raise TypeError(f"An argument has dtype `{dtype!r}`; "
57 f"only boolean and numerical dtypes are supported.")
58 elif not is_array_api_obj(array):
59 try:
60 array = np.asanyarray(array)
61 except TypeError:
62 raise TypeError("An argument is neither array API compatible nor "
63 "coercible by NumPy.")
64 dtype = array.dtype
65 if not (np.issubdtype(dtype, np.number) or np.issubdtype(dtype, np.bool_)):
66 message = (
67 f"An argument was coerced to an unsupported dtype `{dtype!r}`; "
68 f"only boolean and numerical dtypes are supported."
69 )
70 raise TypeError(message)
71 arrays[i] = array
72 return arrays
75def _check_finite(array, xp):
76 """Check for NaNs or Infs."""
77 msg = "array must not contain infs or NaNs"
78 try:
79 if not xp.all(xp.isfinite(array)):
80 raise ValueError(msg)
81 except TypeError:
82 raise ValueError(msg)
def array_namespace(*arrays):
    """Get the array API compatible namespace shared by `arrays`.

    Parameters
    ----------
    *arrays : sequence of array_like
        Arrays used to infer the common namespace. ``None`` entries are
        ignored.

    Returns
    -------
    namespace : module
        Common namespace.

    Notes
    -----
    Thin wrapper around `array_api_compat.array_namespace`.

    1. The global switch ``_GLOBAL_CONFIG['SCIPY_ARRAY_API']`` is consulted
       first; it can be changed dynamically.
    2. `compliance_scipy` raises exceptions on known-bad subclasses. See
       its definition for more details.

    When the global switch is False, the `numpy` compat namespace is returned
    without any compliance check. This is a convenience to ease the adoption.
    Otherwise, arrays must comply with the new rules.
    """
    if _GLOBAL_CONFIG["SCIPY_ARRAY_API"]:
        checked = compliance_scipy([arr for arr in arrays if arr is not None])
        return array_api_compat.array_namespace(*checked)

    # here we could wrap the namespace if needed
    return np_compat
122def _asarray(
123 array, dtype=None, order=None, copy=None, *, xp=None, check_finite=False
124):
125 """SciPy-specific replacement for `np.asarray` with `order` and `check_finite`.
127 Memory layout parameter `order` is not exposed in the Array API standard.
128 `order` is only enforced if the input array implementation
129 is NumPy based, otherwise `order` is just silently ignored.
131 `check_finite` is also not a keyword in the array API standard; included
132 here for convenience rather than that having to be a separate function
133 call inside SciPy functions.
134 """
135 if xp is None:
136 xp = array_namespace(array)
137 if xp.__name__ in {"numpy", "scipy._lib.array_api_compat.numpy"}:
138 # Use NumPy API to support order
139 if copy is True:
140 array = np.array(array, order=order, dtype=dtype)
141 else:
142 array = np.asarray(array, order=order, dtype=dtype)
144 # At this point array is a NumPy ndarray. We convert it to an array
145 # container that is consistent with the input's namespace.
146 array = xp.asarray(array)
147 else:
148 try:
149 array = xp.asarray(array, dtype=dtype, copy=copy)
150 except TypeError:
151 coerced_xp = array_namespace(xp.asarray(3))
152 array = coerced_xp.asarray(array, dtype=dtype, copy=copy)
154 if check_finite:
155 _check_finite(array, xp)
157 return array
def atleast_nd(x, *, ndim, xp=None):
    """Prepend length-1 axes to `x` until it has at least `ndim` dimensions.

    `x` is first converted with ``xp.asarray``; when `xp` is None the
    namespace is inferred from `x`.
    """
    if xp is None:
        xp = array_namespace(x)

    x = xp.asarray(x)
    while x.ndim < ndim:
        # Each pass adds one new leading axis.
        x = xp.asarray(xp.expand_dims(x, axis=0))
    return x
def copy(x, *, xp=None):
    """
    Return a copy of an array.

    Parameters
    ----------
    x : array
        Array to copy.
    xp : array_namespace
        Namespace to use; inferred from `x` when None.

    Returns
    -------
    copy : array
        Copied array

    Notes
    -----
    Unlike `np.copy`, the `subok` and `order` keywords are not available.
    """
    # Note: xp.asarray fails if xp is numpy.
    namespace = xp if xp is not None else array_namespace(x)
    return _asarray(x, copy=True, xp=namespace)
def is_numpy(xp):
    """Tell whether namespace `xp` is named as NumPy or its compat wrapper."""
    name = xp.__name__
    return name == 'numpy' or name == 'scipy._lib.array_api_compat.numpy'
def is_cupy(xp):
    """Tell whether namespace `xp` is named as CuPy or its compat wrapper."""
    name = xp.__name__
    return name == 'cupy' or name == 'scipy._lib.array_api_compat.cupy'
def is_torch(xp):
    """Tell whether namespace `xp` is named as PyTorch or its compat wrapper."""
    name = xp.__name__
    return name == 'torch' or name == 'scipy._lib.array_api_compat.torch'
210def _strict_check(actual, desired, xp,
211 check_namespace=True, check_dtype=True, check_shape=True):
212 __tracebackhide__ = True # Hide traceback for py.test
213 if check_namespace:
214 _assert_matching_namespace(actual, desired)
216 desired = xp.asarray(desired)
218 if check_dtype:
219 _msg = f"dtypes do not match.\nActual: {actual.dtype}\nDesired: {desired.dtype}"
220 assert actual.dtype == desired.dtype, _msg
222 if check_shape:
223 _msg = f"Shapes do not match.\nActual: {actual.shape}\nDesired: {desired.shape}"
224 assert actual.shape == desired.shape, _msg
225 _check_scalar(actual, desired, xp)
227 desired = xp.broadcast_to(desired, actual.shape)
228 return desired
def _assert_matching_namespace(actual, desired):
    """Assert each array in `actual` has the same namespace as `desired`.

    `actual` may be a single array or a tuple of arrays.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    candidates = actual if isinstance(actual, tuple) else (actual,)
    desired_space = array_namespace(desired)
    for candidate in candidates:
        candidate_space = array_namespace(candidate)
        _msg = "Namespaces do not match.\nActual: {}\nDesired: {}".format(
            candidate_space.__name__, desired_space.__name__
        )
        assert candidate_space == desired_space, _msg
243def _check_scalar(actual, desired, xp):
244 __tracebackhide__ = True # Hide traceback for py.test
245 # Shape check alone is sufficient unless desired.shape == (). Also,
246 # only NumPy distinguishes between scalars and arrays.
247 if desired.shape != () or not is_numpy(xp):
248 return
249 # We want to follow the conventions of the `xp` library. Libraries like
250 # NumPy, for which `np.asarray(0)[()]` returns a scalar, tend to return
251 # a scalar even when a 0D array might be more appropriate:
252 # import numpy as np
253 # np.mean([1, 2, 3]) # scalar, not 0d array
254 # np.asarray(0)*2 # scalar, not 0d array
255 # np.sin(np.asarray(0)) # scalar, not 0d array
256 # Libraries like CuPy, for which `cp.asarray(0)[()]` returns a 0D array,
257 # tend to return a 0D array in scenarios like those above.
258 # Therefore, regardless of whether the developer provides a scalar or 0D
259 # array for `desired`, we would typically want the type of `actual` to be
260 # the type of `desired[()]`. If the developer wants to override this
261 # behavior, they can set `check_shape=False`.
262 desired = desired[()]
263 _msg = f"Types do not match:\n Actual: {type(actual)}\n Desired: {type(desired)}"
264 assert (xp.isscalar(actual) and xp.isscalar(desired)
265 or (not xp.isscalar(actual) and not xp.isscalar(desired))), _msg
def xp_assert_equal(actual, desired, check_namespace=True, check_dtype=True,
                    check_shape=True, err_msg='', xp=None):
    """Assert `actual` equals `desired` exactly, using the backend's own tester.

    Namespace, dtype and shape are first validated by `_strict_check`
    according to the ``check_*`` flags. The comparison is then delegated to
    ``xp.testing`` for CuPy and PyTorch and to ``np.testing`` otherwise.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    if xp is None:
        xp = array_namespace(actual)

    desired = _strict_check(actual, desired, xp, check_namespace=check_namespace,
                            check_dtype=check_dtype, check_shape=check_shape)

    if is_cupy(xp):
        return xp.testing.assert_array_equal(actual, desired, err_msg=err_msg)

    if is_torch(xp):
        # PyTorch recommends using `rtol=0, atol=0` like this
        # to test for exact equality
        torch_msg = err_msg if err_msg != '' else None
        return xp.testing.assert_close(actual, desired, rtol=0, atol=0,
                                       equal_nan=True, check_dtype=False,
                                       msg=torch_msg)

    return np.testing.assert_array_equal(actual, desired, err_msg=err_msg)
def xp_assert_close(actual, desired, rtol=1e-07, atol=0, check_namespace=True,
                    check_dtype=True, check_shape=True, err_msg='', xp=None):
    """Assert `actual` is close to `desired` within `rtol` and `atol`.

    Namespace, dtype and shape are first validated by `_strict_check`
    according to the ``check_*`` flags. The comparison is then delegated to
    ``xp.testing`` for CuPy and PyTorch and to ``np.testing`` otherwise.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    if xp is None:
        xp = array_namespace(actual)

    desired = _strict_check(actual, desired, xp, check_namespace=check_namespace,
                            check_dtype=check_dtype, check_shape=check_shape)

    if is_cupy(xp):
        return xp.testing.assert_allclose(actual, desired, rtol=rtol,
                                          atol=atol, err_msg=err_msg)

    if is_torch(xp):
        torch_msg = err_msg if err_msg != '' else None
        return xp.testing.assert_close(actual, desired, rtol=rtol, atol=atol,
                                       equal_nan=True, check_dtype=False,
                                       msg=torch_msg)

    return np.testing.assert_allclose(actual, desired, rtol=rtol,
                                      atol=atol, err_msg=err_msg)
def xp_assert_less(actual, desired, check_namespace=True, check_dtype=True,
                   check_shape=True, err_msg='', verbose=True, xp=None):
    """Assert `actual` is elementwise strictly less than `desired`.

    Namespace, dtype and shape are first validated by `_strict_check`
    according to the ``check_*`` flags. CuPy inputs are compared with
    ``xp.testing``; PyTorch inputs are moved to the CPU and, like everything
    else, compared with ``np.testing.assert_array_less``.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    if xp is None:
        xp = array_namespace(actual)

    desired = _strict_check(actual, desired, xp, check_namespace=check_namespace,
                            check_dtype=check_dtype, check_shape=check_shape)

    if is_cupy(xp):
        return xp.testing.assert_array_less(actual, desired,
                                            err_msg=err_msg, verbose=verbose)

    if is_torch(xp):
        # Bring both tensors to the CPU before NumPy's assertion sees them.
        actual = actual if actual.device.type == 'cpu' else actual.cpu()
        desired = desired if desired.device.type == 'cpu' else desired.cpu()

    return np.testing.assert_array_less(actual, desired,
                                        err_msg=err_msg, verbose=verbose)
def cov(x, *, xp=None):
    """Compute a covariance estimate of `x` with array API functions.

    A copy of `x` is promoted with ``xp.result_type(X, xp.float64)`` and made
    at least 2-D. The mean along axis 1 is removed, and the product of the
    centered data with its (conjugate) transpose is divided by
    ``X.shape[1] - 1``. Length-1 axes of the result are squeezed out.

    A ``RuntimeWarning`` is issued when ``X.shape[1] - 1 <= 0``; the divisor
    is then set to ``0.0``.
    """
    if xp is None:
        xp = array_namespace(x)

    # Work on a copy because the centering below is done in place.
    data = copy(x, xp=xp)
    work_dtype = xp.result_type(data, xp.float64)

    data = xp.asarray(atleast_nd(data, ndim=2, xp=xp), dtype=work_dtype)

    row_means = xp.mean(data, axis=1)
    dof = data.shape[1] - 1

    if dof <= 0:
        warnings.warn("Degrees of freedom <= 0 for slice",
                      RuntimeWarning, stacklevel=2)
        dof = 0.0

    data -= row_means[:, None]
    data_t = data.T
    if xp.isdtype(data_t.dtype, 'complex floating'):
        data_t = xp.conj(data_t)
    result = data @ data_t
    result /= dof
    unit_axes = tuple(i for i, extent in enumerate(result.shape) if extent == 1)
    return xp.squeeze(result, axis=unit_axes)
def xp_unsupported_param_msg(param):
    """Build the error message for a parameter supported only with NumPy arrays."""
    return 'Providing {!r} is only supported for numpy arrays.'.format(param)
def is_complex(x, xp):
    """Return ``xp.isdtype(x.dtype, 'complex floating')`` for array `x`."""
    dtype = x.dtype
    return xp.isdtype(dtype, 'complex floating')