Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/scipy/_lib/array_api_compat/common/_helpers.py: 16%
168 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-03 06:39 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-03 06:39 +0000
1"""
2Various helper functions which are not part of the spec.
4Functions which start with an underscore are for internal use only but helpers
5that are in __all__ are intended as additional helper functions for use by end
6users of the compat library.
7"""
8from __future__ import annotations
10from typing import TYPE_CHECKING
12if TYPE_CHECKING:
13 from typing import Optional, Union, Any
14 from ._typing import Array, Device
16import sys
17import math
18import inspect
19import warnings
def is_numpy_array(x):
    """
    Return True if `x` is a NumPy array.

    Cheap to call: NumPy is only inspected when it has already been
    imported, so calling this never triggers an import of NumPy itself.

    `ndarray` subclasses and NumPy scalar objects also count as NumPy
    arrays here.

    See Also
    --------
    array_namespace
    is_array_api_obj
    is_cupy_array
    is_torch_array
    is_dask_array
    is_jax_array
    """
    if 'numpy' in sys.modules:
        import numpy as np

        # TODO: Should we reject ndarray subclasses?
        return isinstance(x, (np.ndarray, np.generic))
    # NumPy was never imported, so `x` cannot possibly be a NumPy array.
    return False
def is_cupy_array(x):
    """
    Return True if `x` is a CuPy array.

    Cheap to call: CuPy is only inspected when it has already been
    imported, so calling this never triggers an import of CuPy itself.

    `cupy.ndarray` subclasses and CuPy scalar objects also count as CuPy
    arrays here.

    See Also
    --------
    array_namespace
    is_array_api_obj
    is_numpy_array
    is_torch_array
    is_dask_array
    is_jax_array
    """
    if 'cupy' in sys.modules:
        import cupy as cp

        # TODO: Should we reject ndarray subclasses?
        return isinstance(x, (cp.ndarray, cp.generic))
    # CuPy was never imported, so `x` cannot possibly be a CuPy array.
    return False
def is_torch_array(x):
    """
    Return True if `x` is a PyTorch tensor.

    Cheap to call: torch is only inspected when it has already been
    imported, so calling this never triggers an import of torch itself.

    See Also
    --------
    array_namespace
    is_array_api_obj
    is_numpy_array
    is_cupy_array
    is_dask_array
    is_jax_array
    """
    if 'torch' in sys.modules:
        import torch

        # TODO: Should we reject Tensor subclasses?
        return isinstance(x, torch.Tensor)
    # torch was never imported, so `x` cannot possibly be a torch tensor.
    return False
def is_dask_array(x):
    """
    Return True if `x` is a dask.array Array.

    Cheap to call: dask is only inspected when it has already been
    imported, so calling this never triggers an import of dask itself.

    See Also
    --------
    array_namespace
    is_array_api_obj
    is_numpy_array
    is_cupy_array
    is_torch_array
    is_jax_array
    """
    if 'dask.array' in sys.modules:
        import dask.array

        return isinstance(x, dask.array.Array)
    # dask.array was never imported, so `x` cannot be a dask array.
    return False
def is_jax_array(x):
    """
    Return True if `x` is a JAX array.

    Cheap to call: JAX is only inspected when it has already been
    imported, so calling this never triggers an import of JAX itself.

    See Also
    --------
    array_namespace
    is_array_api_obj
    is_numpy_array
    is_cupy_array
    is_torch_array
    is_dask_array
    """
    if 'jax' in sys.modules:
        import jax

        return isinstance(x, jax.Array)
    # JAX was never imported, so `x` cannot possibly be a JAX array.
    return False
def is_array_api_obj(x):
    """
    Return True if `x` is an array API compatible array object.

    An object qualifies if it belongs to one of the array libraries this
    compat layer knows about (NumPy, CuPy, PyTorch, dask, JAX) or if it
    advertises the standard `__array_namespace__` attribute itself.

    See Also
    --------
    array_namespace
    is_numpy_array
    is_cupy_array
    is_torch_array
    is_dask_array
    is_jax_array
    """
    # Probe the known libraries first, in the same order as the original
    # or-chain; fall back to the generic protocol attribute.
    known_checks = (
        is_numpy_array,
        is_cupy_array,
        is_torch_array,
        is_dask_array,
        is_jax_array,
    )
    if any(check(x) for check in known_checks):
        return True
    return hasattr(x, '__array_namespace__')
def _check_api_version(api_version):
    """Validate a requested array API spec version string.

    `None` (no preference) and '2022.12' are accepted silently. '2021.12'
    is accepted with a warning, since the namespaces returned by this
    package actually implement 2022.12. Anything else raises ValueError.
    """
    if api_version is None:
        return
    if api_version == '2021.12':
        warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12")
    elif api_version != '2022.12':
        raise ValueError("Only the 2022.12 version of the array API specification is currently supported")
def array_namespace(*xs, api_version=None, _use_compat=True):
    """
    Get the array API compatible namespace for the arrays `xs`.

    Parameters
    ----------
    xs: arrays
        one or more arrays.

    api_version: str
        The newest version of the spec that you need support for (currently
        the compat library wrapped APIs support v2022.12).

    Returns
    -------
    out: namespace
        The array API compatible namespace corresponding to the arrays in `xs`.

    Raises
    ------
    TypeError
        If `xs` contains arrays from different array libraries or contains a
        non-array.

    Typical usage is to pass the arguments of a function to
    `array_namespace()` at the top of a function to get the corresponding
    array API namespace:

    .. code:: python

       def your_function(x, y):
           xp = array_api_compat.array_namespace(x, y)
           # Now use xp as the array library namespace
           return xp.mean(x, axis=0) + 2*xp.std(y, axis=0)

    Wrapped array namespaces can also be imported directly. For example,
    `array_namespace(np.array(...))` will return `array_api_compat.numpy`.
    This function will also work for any array library not wrapped by
    array-api-compat if it explicitly defines `__array_namespace__
    <https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__array_namespace__.html>`__
    (the wrapped namespace is always preferred if it exists).

    See Also
    --------
    is_array_api_obj
    is_numpy_array
    is_cupy_array
    is_torch_array
    is_dask_array
    is_jax_array
    """
    def _namespace_for(x):
        # Map one array to its namespace: the compat wrapper when
        # `_use_compat` is set, otherwise the bare library module.
        if is_numpy_array(x):
            _check_api_version(api_version)
            if not _use_compat:
                import numpy as np
                return np
            from .. import numpy as numpy_namespace
            return numpy_namespace
        if is_cupy_array(x):
            _check_api_version(api_version)
            if not _use_compat:
                import cupy as cp
                return cp
            from .. import cupy as cupy_namespace
            return cupy_namespace
        if is_torch_array(x):
            _check_api_version(api_version)
            if not _use_compat:
                import torch
                return torch
            from .. import torch as torch_namespace
            return torch_namespace
        if is_dask_array(x):
            _check_api_version(api_version)
            if not _use_compat:
                raise TypeError("_use_compat cannot be False if input array is a dask array!")
            from ..dask import array as dask_namespace
            return dask_namespace
        if is_jax_array(x):
            _check_api_version(api_version)
            # jax.experimental.array_api is already an array namespace. We do
            # not have a wrapper submodule for it.
            import jax.experimental.array_api as jnp
            return jnp
        if hasattr(x, '__array_namespace__'):
            return x.__array_namespace__(api_version=api_version)
        # TODO: Support Python scalars?
        raise TypeError(f"{type(x).__name__} is not a supported array type")

    # Deduplicate: all inputs must resolve to the same single namespace.
    namespaces = {_namespace_for(x) for x in xs}

    if not namespaces:
        raise TypeError("Unrecognized array input")

    if len(namespaces) != 1:
        raise TypeError(f"Multiple namespaces for array inputs: {namespaces}")

    xp, = namespaces

    return xp
# backwards compatibility alias: earlier releases exposed this function under
# the name `get_namespace`; keep the old name pointing at `array_namespace`.
get_namespace = array_namespace
def _check_device(xp, device):
    """Raise ValueError when `device` is invalid for the namespace `xp`.

    Only NumPy is validated here: its sole supported device is the CPU,
    spelled either "cpu" or None. Other namespaces pass through unchecked.
    """
    if xp == sys.modules.get('numpy') and device not in ["cpu", None]:
        raise ValueError(f"Unsupported device for NumPy: {device!r}")
# Sentinel standing in for "the device of a dask array" when the chunks are
# not plain NumPy arrays on the CPU (dask gives us no easy way to tell which
# device its data actually lives on).
class _dask_device:
    """Opaque marker type; the single instance `_DASK_DEVICE` is used as the
    device object reported for non-CPU-backed dask arrays."""

    def __repr__(self):
        return "DASK_DEVICE"


_DASK_DEVICE = _dask_device()
# device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray
# or cupy.ndarray. They are not included in array objects of this library
# because this library just reuses the respective ndarray classes without
# wrapping or subclassing them. These helper functions can be used instead of
# the wrapper functions for libraries that need to support both NumPy/CuPy and
# other libraries that use devices.
def device(x: Array, /) -> Device:
    """
    Hardware device the array data resides on.

    This is equivalent to `x.device` according to the `standard
    <https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.device.html>`__.
    This helper is included because some array libraries either do not have
    the `device` attribute or include it with an incompatible API.

    Parameters
    ----------
    x: array
        array instance from an array API compatible library.

    Returns
    -------
    out: device
        a ``device`` object (see the `Device Support <https://data-apis.org/array-api/latest/design_topics/device_support.html>`__
        section of the array API specification).

    Notes
    -----

    For NumPy the device is always `"cpu"`. For Dask, the device is always a
    special `DASK_DEVICE` object.

    See Also
    --------

    to_device : Move array data to a different device.

    """
    if is_numpy_array(x):
        # NumPy arrays only ever live on the CPU.
        return "cpu"
    elif is_dask_array(x):
        # Peek at the metadata of the dask array to determine the backing type
        # (NOTE: comment previously said "jax array" — this is the dask branch).
        try:
            import numpy as np
            if isinstance(x._meta, np.ndarray):
                # Must be on CPU since backed by numpy
                return "cpu"
        except ImportError:
            pass
        # Non-NumPy-backed (or NumPy unavailable): report the opaque sentinel.
        return _DASK_DEVICE
    elif is_jax_array(x):
        # JAX has .device() as a method, but it is being deprecated so that it
        # can become a property, in accordance with the standard. In order for
        # this function to not break when JAX makes the flip, we check for
        # both here.
        if inspect.ismethod(x.device):
            return x.device()
        else:
            return x.device
    # All other array API compatible libraries expose `.device` directly.
    return x.device
# Based on cupy.array_api.Array.to_device
def _cupy_to_device(x, device, /, stream=None):
    """Copy the CuPy array `x` to `device`, optionally on `stream`.

    `device` may be the array's current device (returned unchanged), the
    string "cpu" (returns a NumPy array via `x.get()`), or a
    `cupy.cuda.Device`; anything else raises ValueError. `stream` may be an
    int (as in `__dlpack__`) or a `cupy.cuda.Stream`.
    """
    import cupy as cp
    from cupy.cuda import Device as _Device
    from cupy.cuda import stream as stream_module
    from cupy_backends.cuda.api import runtime

    if device == x.device:
        return x
    elif device == "cpu":
        # allowing us to use `to_device(x, "cpu")`
        # is useful for portable test swapping between
        # host and device backends
        return x.get()
    elif not isinstance(device, _Device):
        raise ValueError(f"Unsupported device {device!r}")
    else:
        # see cupy/cupy#5985 for the reason how we handle device/stream here
        # Remember the current device/stream so they can be restored after
        # the copy, whatever happens.
        prev_device = runtime.getDevice()
        prev_stream: stream_module.Stream = None
        if stream is not None:
            prev_stream = stream_module.get_current_stream()
            # stream can be an int as specified in __dlpack__, or a CuPy stream
            if isinstance(stream, int):
                stream = cp.cuda.ExternalStream(stream)
            elif isinstance(stream, cp.cuda.Stream):
                pass
            else:
                raise ValueError('the input stream is not recognized')
            stream.use()
        try:
            runtime.setDevice(device.id)
            arr = x.copy()
        finally:
            # Restore device first, then the previous stream (mirrors setup).
            runtime.setDevice(prev_device)
            if stream is not None:
                prev_stream.use()
        return arr
def _torch_to_device(x, device, /, stream=None):
    """Move the torch tensor `x` onto `device` via ``Tensor.to``.

    The `stream` argument is not supported for torch; passing anything
    other than None raises NotImplementedError.
    """
    if stream is None:
        return x.to(device)
    raise NotImplementedError
def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array:
    """
    Copy the array from the device on which it currently resides to the specified ``device``.

    This is equivalent to `x.to_device(device, stream=stream)` according to
    the `standard
    <https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.to_device.html>`__.
    This helper is included because some array libraries do not have the
    `to_device` method.

    Parameters
    ----------

    x: array
        array instance from an array API compatible library.

    device: device
        a ``device`` object (see the `Device Support <https://data-apis.org/array-api/latest/design_topics/device_support.html>`__
        section of the array API specification).

    stream: Optional[Union[int, Any]]
        stream object to use during copy. In addition to the types supported
        in ``array.__dlpack__``, implementations may choose to support any
        library-specific stream object with the caveat that any code using
        such an object would not be portable.

    Returns
    -------

    out: array
        an array with the same data and data type as ``x`` and located on the
        specified ``device``.

    Notes
    -----

    For NumPy, this function effectively does nothing since the only supported
    device is the CPU. For CuPy, this method supports CuPy CUDA
    :external+cupy:class:`Device <cupy.cuda.Device>` and
    :external+cupy:class:`Stream <cupy.cuda.Stream>` objects. For PyTorch,
    this is the same as :external+torch:meth:`x.to(device) <torch.Tensor.to>`
    (the ``stream`` argument is not supported in PyTorch).

    See Also
    --------

    device : Hardware device the array data resides on.

    """
    if is_numpy_array(x):
        # NumPy is CPU-only, so the only valid move is a no-op to "cpu".
        if stream is not None:
            raise ValueError("The stream argument to to_device() is not supported")
        if device == 'cpu':
            return x
        raise ValueError(f"Unsupported device {device!r}")
    if is_cupy_array(x):
        # cupy does not yet have to_device
        return _cupy_to_device(x, device, stream=stream)
    if is_torch_array(x):
        return _torch_to_device(x, device, stream=stream)
    if is_dask_array(x):
        if stream is not None:
            raise ValueError("The stream argument to to_device() is not supported")
        # TODO: What if our array is on the GPU already?
        if device == 'cpu':
            return x
        raise ValueError(f"Unsupported device {device!r}")
    if is_jax_array(x):
        # This import adds to_device to x
        import jax.experimental.array_api # noqa: F401
        return x.to_device(device, stream=stream)
    # Anything else is assumed to implement the standard method itself.
    return x.to_device(device, stream=stream)
def size(x):
    """
    Return the total number of elements of x.

    This is equivalent to `x.size` according to the `standard
    <https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.size.html>`__.
    This helper is included because PyTorch defines `size` in an
    :external+torch:meth:`incompatible way <torch.Tensor.size>`.

    Returns None when any dimension is unknown (shape contains None).
    """
    shape = x.shape
    # An unknown dimension makes the whole size unknown.
    if None in shape:
        return None
    return math.prod(shape)
# Explicit public API of this helpers module (sorted alphabetically).
__all__ = [
    "array_namespace",
    "device",
    "get_namespace",
    "is_array_api_obj",
    "is_cupy_array",
    "is_dask_array",
    "is_jax_array",
    "is_numpy_array",
    "is_torch_array",
    "size",
    "to_device",
]

# Implementation-detail imports that re-export machinery should skip when
# star-importing from this module.
_all_ignore = ['sys', 'math', 'inspect', 'warnings']