Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/scipy/_lib/array_api_compat/array_api_compat/common/_helpers.py: 15%
116 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-23 06:43 +0000
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-23 06:43 +0000
1"""
2Various helper functions which are not part of the spec.
4Functions which start with an underscore are for internal use only but helpers
5that are in __all__ are intended as additional helper functions for use by end
6users of the compat library.
7"""
8from __future__ import annotations
10import sys
11import math
13def _is_numpy_array(x):
14 # Avoid importing NumPy if it isn't already
15 if 'numpy' not in sys.modules:
16 return False
18 import numpy as np
20 # TODO: Should we reject ndarray subclasses?
21 return isinstance(x, (np.ndarray, np.generic))
23def _is_cupy_array(x):
24 # Avoid importing NumPy if it isn't already
25 if 'cupy' not in sys.modules:
26 return False
28 import cupy as cp
30 # TODO: Should we reject ndarray subclasses?
31 return isinstance(x, (cp.ndarray, cp.generic))
33def _is_torch_array(x):
34 # Avoid importing torch if it isn't already
35 if 'torch' not in sys.modules:
36 return False
38 import torch
40 # TODO: Should we reject ndarray subclasses?
41 return isinstance(x, torch.Tensor)
def is_array_api_obj(x):
    """
    Tell whether ``x`` is an array object usable with the array API.

    ``x`` qualifies if it is a NumPy, CuPy or torch array, or if it exposes
    an ``__array_namespace__`` attribute.
    """
    # Try the library-specific checks first, in the same order as before,
    # and fall back to the generic protocol attribute.
    for is_known_array in (_is_numpy_array, _is_cupy_array, _is_torch_array):
        if is_known_array(x):
            return True
    return hasattr(x, '__array_namespace__')
52def _check_api_version(api_version):
53 if api_version is not None and api_version != '2021.12':
54 raise ValueError("Only the 2021.12 version of the array API specification is currently supported")
def array_namespace(*xs, api_version=None, _use_compat=True):
    """
    Get the array API compatible namespace for the arrays `xs`.

    `xs` should contain one or more arrays. Tuples and lists are searched
    recursively, with the same `api_version` and `_use_compat` settings
    applied to the nested arrays.

    Typical usage is

        def your_function(x, y):
            xp = array_api_compat.array_namespace(x, y)
            # Now use xp as the array library namespace
            return xp.mean(x, axis=0) + 2*xp.std(y, axis=0)

    api_version should be the newest version of the spec that you need support
    for (currently the compat library wrapped APIs only support v2021.12).

    Parameters
    ----------
    *xs
        Arrays, or (possibly nested) tuples/lists of arrays.
    api_version : str, optional
        Requested version of the array API specification. It is forwarded to
        ``__array_namespace__`` for objects that implement it and validated
        with ``_check_api_version`` for NumPy, CuPy and torch arrays.
    _use_compat : bool
        If True (default), return the compat wrapper namespace for NumPy,
        CuPy and torch arrays; otherwise return the library module itself.

    Returns
    -------
    The single namespace shared by all of the inputs.

    Raises
    ------
    TypeError
        If an input is not a supported array type, if no arrays were given,
        or if the inputs belong to more than one namespace.
    ValueError
        If `api_version` is not supported for a NumPy, CuPy or torch array.
    """
    namespaces = set()
    for x in xs:
        if isinstance(x, (tuple, list)):
            # Forward api_version too; otherwise nested arrays would silently
            # skip the version check and be queried with api_version=None.
            namespaces.add(
                array_namespace(*x, api_version=api_version, _use_compat=_use_compat)
            )
        elif hasattr(x, '__array_namespace__'):
            # NOTE(review): this branch is tried before the library-specific
            # checks, so objects exposing the protocol never get the compat
            # wrapper -- confirm this is intended.
            namespaces.add(x.__array_namespace__(api_version=api_version))
        elif _is_numpy_array(x):
            _check_api_version(api_version)
            if _use_compat:
                from .. import numpy as numpy_namespace
                namespaces.add(numpy_namespace)
            else:
                import numpy as np
                namespaces.add(np)
        elif _is_cupy_array(x):
            _check_api_version(api_version)
            if _use_compat:
                from .. import cupy as cupy_namespace
                namespaces.add(cupy_namespace)
            else:
                import cupy as cp
                namespaces.add(cp)
        elif _is_torch_array(x):
            _check_api_version(api_version)
            if _use_compat:
                from .. import torch as torch_namespace
                namespaces.add(torch_namespace)
            else:
                import torch
                namespaces.add(torch)
        else:
            # TODO: Support Python scalars?
            raise TypeError("The input is not a supported array type")

    if not namespaces:
        raise TypeError("Unrecognized array input")

    if len(namespaces) != 1:
        raise TypeError(f"Multiple namespaces for array inputs: {namespaces}")

    xp, = namespaces

    return xp
# Backwards compatibility alias: `get_namespace` is bound to the same function
# object as `array_namespace`.
get_namespace = array_namespace
119def _check_device(xp, device):
120 if xp == sys.modules.get('numpy'):
121 if device not in ["cpu", None]:
122 raise ValueError(f"Unsupported device for NumPy: {device!r}")
124# device() is not on numpy.ndarray and to_device() is not on numpy.ndarray
125# or cupy.ndarray. They are not included in array objects of this library
126# because this library just reuses the respective ndarray classes without
127# wrapping or subclassing them. These helper functions can be used instead of
128# the wrapper functions for libraries that need to support both NumPy/CuPy and
129# other libraries that use devices.
def device(x: "Array", /) -> "Device":
    """
    Return the hardware device on which the data of ``x`` resides.

    Parameters
    ----------
    x: array
        array instance from NumPy or an array API compatible library.

    Returns
    -------
    out: device
        the string ``"cpu"`` for NumPy arrays; otherwise the array's own
        ``device`` attribute (see the "Device Support" section of the array
        API specification).
    """
    # NumPy arrays are reported as "cpu" without reading any attribute.
    return "cpu" if _is_numpy_array(x) else x.device
# Based on cupy.array_api.Array.to_device
def _cupy_to_device(x, device, /, stream=None):
    """
    Move the CuPy array ``x`` to ``device``.

    Parameters
    ----------
    x
        The array to move; its ``device``, ``get()`` and ``copy()`` members
        are used.
    device
        Either a value equal to ``x.device``, the string ``"cpu"``, or a
        ``cupy.cuda.Device`` instance.
    stream
        Optional stream to make current during the copy: an ``int`` (wrapped
        in ``cupy.cuda.ExternalStream``) or a ``cupy.cuda.Stream``.

    Returns
    -------
    ``x`` itself when ``device == x.device``; the result of ``x.get()`` when
    ``device == "cpu"``; otherwise the result of ``x.copy()`` made while
    ``device`` is the current device.

    Raises
    ------
    ValueError
        If ``device`` is none of the accepted values, or if ``stream`` is
        neither ``None``, an ``int`` nor a ``cupy.cuda.Stream``.
    """
    import cupy as cp
    from cupy.cuda import Device as _Device
    from cupy.cuda import stream as stream_module
    from cupy_backends.cuda.api import runtime

    if device == x.device:
        # Already on the requested device: return the same object, no copy.
        return x
    elif device == "cpu":
        # allowing us to use `to_device(x, "cpu")`
        # is useful for portable test swapping between
        # host and device backends
        return x.get()
    elif not isinstance(device, _Device):
        raise ValueError(f"Unsupported device {device!r}")
    else:
        # see cupy/cupy#5985 for the reason how we handle device/stream here
        # Remember the current device (and, if a stream was given, the current
        # stream) so both can be restored after the copy.
        prev_device = runtime.getDevice()
        prev_stream: stream_module.Stream = None
        if stream is not None:
            prev_stream = stream_module.get_current_stream()
            # stream can be an int as specified in __dlpack__, or a CuPy stream
            if isinstance(stream, int):
                stream = cp.cuda.ExternalStream(stream)
            elif isinstance(stream, cp.cuda.Stream):
                pass
            else:
                raise ValueError('the input stream is not recognized')
            stream.use()
        try:
            # Switch to the target device and copy there.
            runtime.setDevice(device.id)
            arr = x.copy()
        finally:
            # Always restore the previous device, and the previous stream if
            # one was replaced above, even when the copy raises.
            runtime.setDevice(prev_device)
            if stream is not None:
                prev_stream.use()
        return arr
187def _torch_to_device(x, device, /, stream=None):
188 if stream is not None:
189 raise NotImplementedError
190 return x.to(device)
def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, Any]]" = None) -> "Array":
    """
    Copy the array from the device on which it currently resides to the specified ``device``.

    Parameters
    ----------
    x: array
        array instance from NumPy or an array API compatible library.
    device: device
        a ``device`` object (see the "Device Support" section of the array API specification).
    stream: Optional[Union[int, Any]]
        stream object to use during the copy. Besides the types supported in
        ``array.__dlpack__``, an implementation may accept library-specific
        stream objects, but code relying on those is not portable.

    Returns
    -------
    out: array
        an array with the same data and data type as ``x`` and located on the specified ``device``.

    Raises
    ------
    ValueError
        for NumPy arrays, if ``stream`` is not ``None`` or ``device`` is not ``'cpu'``.

    .. note::
       If ``stream`` is given, the copy operation should be enqueued on the
       provided ``stream``; otherwise it should be enqueued on the default
       stream/queue. Whether the copy is synchronous or asynchronous is
       implementation-dependent, so a conforming library must document
       clearly when synchronization is required to guarantee data safety.
    """
    if _is_numpy_array(x):
        # NumPy arrays take no stream and are returned unchanged for 'cpu'.
        if stream is not None:
            message = "The stream argument to to_device() is not supported"
        elif device == 'cpu':
            return x
        else:
            message = f"Unsupported device {device!r}"
        raise ValueError(message)

    if _is_cupy_array(x):
        # cupy does not yet have to_device
        return _cupy_to_device(x, device, stream=stream)

    if _is_torch_array(x):
        return _torch_to_device(x, device, stream=stream)

    # Any other array type is expected to implement to_device() itself.
    return x.to_device(device, stream=stream)
def size(x):
    """
    Return the total number of elements of x.

    The result is the product of the entries of ``x.shape``, or ``None`` if
    any entry of the shape is ``None``.
    """
    dims = x.shape
    return None if None in dims else math.prod(dims)
# Names exported as the public helper API of this module; functions whose
# names start with an underscore are for internal use only.
__all__ = ['is_array_api_obj', 'array_namespace', 'get_namespace', 'device', 'to_device', 'size']