Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/scipy/_lib/array_api_compat/common/_helpers.py: 15%
114 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-02-14 06:37 +0000
« prev ^ index » next coverage.py v7.4.1, created at 2024-02-14 06:37 +0000
1"""
2Various helper functions which are not part of the spec.
4Functions which start with an underscore are for internal use only but helpers
5that are in __all__ are intended as additional helper functions for use by end
6users of the compat library.
7"""
8from __future__ import annotations
10import sys
11import math
13def _is_numpy_array(x):
14 # Avoid importing NumPy if it isn't already
15 if 'numpy' not in sys.modules:
16 return False
18 import numpy as np
20 # TODO: Should we reject ndarray subclasses?
21 return isinstance(x, (np.ndarray, np.generic))
23def _is_cupy_array(x):
24 # Avoid importing NumPy if it isn't already
25 if 'cupy' not in sys.modules:
26 return False
28 import cupy as cp
30 # TODO: Should we reject ndarray subclasses?
31 return isinstance(x, (cp.ndarray, cp.generic))
33def _is_torch_array(x):
34 # Avoid importing torch if it isn't already
35 if 'torch' not in sys.modules:
36 return False
38 import torch
40 # TODO: Should we reject ndarray subclasses?
41 return isinstance(x, torch.Tensor)
def is_array_api_obj(x):
    """
    Check if x is an array API compatible array object.

    Returns True for NumPy, CuPy and PyTorch arrays (each checked only if
    that library is already imported) and for any object that exposes an
    ``__array_namespace__`` attribute; False otherwise.
    """
    # Library-specific checks run first, in the same order as before, and
    # the generic protocol check runs only if none of them matched.
    library_checks = (_is_numpy_array, _is_cupy_array, _is_torch_array)
    if any(check(x) for check in library_checks):
        return True
    return hasattr(x, '__array_namespace__')
52def _check_api_version(api_version):
53 if api_version is not None and api_version != '2021.12':
54 raise ValueError("Only the 2021.12 version of the array API specification is currently supported")
def _namespace_of(x, api_version, use_compat):
    """Return the array namespace for the single array ``x`` (helper for array_namespace)."""
    if _is_numpy_array(x):
        _check_api_version(api_version)
        if use_compat:
            from .. import numpy as namespace
        else:
            import numpy as namespace
        return namespace

    if _is_cupy_array(x):
        _check_api_version(api_version)
        if use_compat:
            from .. import cupy as namespace
        else:
            import cupy as namespace
        return namespace

    if _is_torch_array(x):
        _check_api_version(api_version)
        if use_compat:
            from .. import torch as namespace
        else:
            import torch as namespace
        return namespace

    if hasattr(x, '__array_namespace__'):
        # Arrays that implement the protocol report their own namespace.
        return x.__array_namespace__(api_version=api_version)

    # TODO: Support Python scalars?
    raise TypeError(f"{type(x).__name__} is not a supported array type")


def array_namespace(*xs, api_version=None, _use_compat=True):
    """
    Get the array API compatible namespace for the arrays `xs`.

    `xs` should contain one or more arrays.

    Typical usage is

        def your_function(x, y):
            xp = array_api_compat.array_namespace(x, y)
            # Now use xp as the array library namespace
            return xp.mean(x, axis=0) + 2*xp.std(y, axis=0)

    api_version should be the newest version of the spec that you need support
    for (currently the compat library wrapped APIs only support v2021.12).

    For NumPy, CuPy and PyTorch inputs the compat wrapper namespace is
    returned when ``_use_compat`` is true, otherwise the bare library module.

    Raises
    ------
    TypeError
        If an input is not a supported array type, if no inputs are given,
        or if the inputs belong to more than one namespace.
    ValueError
        If ``api_version`` is unsupported for a NumPy, CuPy or PyTorch input.
    """
    found = set()
    for x in xs:
        found.add(_namespace_of(x, api_version, _use_compat))

    if not found:
        raise TypeError("Unrecognized array input")

    if len(found) != 1:
        raise TypeError(f"Multiple namespaces for array inputs: {found}")

    return next(iter(found))
# backwards compatibility alias: ``get_namespace`` is kept as another name for
# ``array_namespace`` so callers that use it keep working.
get_namespace = array_namespace
117def _check_device(xp, device):
118 if xp == sys.modules.get('numpy'):
119 if device not in ["cpu", None]:
120 raise ValueError(f"Unsupported device for NumPy: {device!r}")
# The array API ``device`` attribute is not available on numpy.ndarray, and the
# ``to_device()`` method is not available on numpy.ndarray or cupy.ndarray.
# They are not added to the array objects of this library because it reuses
# the respective ndarray classes directly, without wrapping or subclassing
# them. Libraries that need to support both NumPy/CuPy and other array
# libraries that use devices can call the helper functions below instead of
# the ``x.device`` / ``x.to_device()`` array methods.
def device(x: "Array", /) -> "Device":
    """
    Hardware device the array data resides on.

    Parameters
    ----------
    x: array
        array instance from NumPy or an array API compatible library.

    Returns
    -------
    out: device
        a ``device`` object (see the "Device Support" section of the array API specification).
        NumPy arrays and scalars always report ``"cpu"``; for any other input
        the value of ``x.device`` is returned as-is.
    """
    # NumPy arrays have no ``.device`` to query, so they are answered here.
    return "cpu" if _is_numpy_array(x) else x.device
# Based on cupy.array_api.Array.to_device
def _cupy_to_device(x, device, /, stream=None):
    """
    Copy the CuPy array ``x`` to ``device``.

    Parameters
    ----------
    x : cupy.ndarray
        The array to copy.
    device : "cpu" or cupy.cuda.Device
        Target device. If it equals ``x.device``, ``x`` is returned unchanged.
        ``"cpu"`` returns ``x.get()``.
    stream : int, cupy.cuda.Stream or None, optional
        Stream made current while copying to another CuPy device. An ``int``
        is wrapped in ``cupy.cuda.ExternalStream``. It is ignored when
        ``device`` equals ``x.device`` or is ``"cpu"``.

    Returns
    -------
    out : array
        ``x`` itself, ``x.get()``, or the result of ``x.copy()`` made while
        ``device`` is the current CUDA device.

    Raises
    ------
    ValueError
        If ``device`` is not equal to ``x.device``, not ``"cpu"`` and not a
        ``cupy.cuda.Device``, or if ``stream`` is neither an ``int`` nor a
        ``cupy.cuda.Stream``.
    """
    import cupy as cp
    from cupy.cuda import Device as _Device
    from cupy.cuda import stream as stream_module
    from cupy_backends.cuda.api import runtime

    if device == x.device:
        # Already on the requested device: no copy is made.
        return x
    elif device == "cpu":
        # allowing us to use `to_device(x, "cpu")`
        # is useful for portable test swapping between
        # host and device backends
        return x.get()
    elif not isinstance(device, _Device):
        raise ValueError(f"Unsupported device {device!r}")
    else:
        # see cupy/cupy#5985 for the reason how we handle device/stream here
        # Remember the current device (and, when a stream is given, the
        # current stream) so both can be restored after the copy.
        prev_device = runtime.getDevice()
        prev_stream: stream_module.Stream = None
        if stream is not None:
            prev_stream = stream_module.get_current_stream()
            # stream can be an int as specified in __dlpack__, or a CuPy stream
            if isinstance(stream, int):
                stream = cp.cuda.ExternalStream(stream)
            elif isinstance(stream, cp.cuda.Stream):
                pass
            else:
                raise ValueError('the input stream is not recognized')
            stream.use()
        try:
            # Make the target device current for the copy (presumably so the
            # new array is allocated there; confirm against CuPy semantics).
            runtime.setDevice(device.id)
            arr = x.copy()
        finally:
            # Always restore the caller's device/stream, even if the copy fails.
            runtime.setDevice(prev_device)
            if stream is not None:
                prev_stream.use()
        return arr
185def _torch_to_device(x, device, /, stream=None):
186 if stream is not None:
187 raise NotImplementedError
188 return x.to(device)
def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, Any]]" = None) -> "Array":
    """
    Copy the array from the device on which it currently resides to the specified ``device``.

    Parameters
    ----------
    x: array
        array instance from NumPy or an array API compatible library.
    device: device
        a ``device`` object (see the "Device Support" section of the array API specification).
    stream: Optional[Union[int, Any]]
        stream object to use during copy. In addition to the types supported in ``array.__dlpack__``, implementations may choose to support any library-specific stream object with the caveat that any code using such an object would not be portable.

    Returns
    -------
    out: array
        an array with the same data and data type as ``x`` and located on the specified ``device``.

    Raises
    ------
    ValueError
        For NumPy inputs, if ``stream`` is given or ``device`` is not ``'cpu'``.

    .. note::
        If ``stream`` is given, the copy operation should be enqueued on the provided ``stream``; otherwise, the copy operation should be enqueued on the default stream/queue. Whether the copy is performed synchronously or asynchronously is implementation-dependent. Accordingly, if synchronization is required to guarantee data safety, this must be clearly explained in a conforming library's documentation.
    """
    if _is_numpy_array(x):
        # NumPy only supports the CPU and has no stream argument.
        if stream is not None:
            raise ValueError("The stream argument to to_device() is not supported")
        if device == 'cpu':
            return x
        raise ValueError(f"Unsupported device {device!r}")

    # cupy does not yet have to_device, and torch tensors are moved with
    # Tensor.to, so both go through library-specific helpers.
    library_movers = (
        (_is_cupy_array, _cupy_to_device),
        (_is_torch_array, _torch_to_device),
    )
    for matches, move in library_movers:
        if matches(x):
            return move(x, device, stream=stream)

    # Any other array is expected to implement to_device itself.
    return x.to_device(device, stream=stream)
def size(x):
    """
    Return the total number of elements of x.

    Returns None if any entry of ``x.shape`` is None (an unknown dimension).
    A 0-d array (empty shape) has size 1.
    """
    shape = x.shape
    return None if None in shape else math.prod(shape)
# Public helpers intended for end users of the compat library; functions
# whose names start with an underscore are internal.
__all__ = ['is_array_api_obj', 'array_namespace', 'get_namespace', 'device', 'to_device', 'size']