Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/scipy/_lib/array_api_compat/common/_helpers.py: 16%

168 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-03 06:39 +0000

1""" 

2Various helper functions which are not part of the spec. 

3 

4Functions which start with an underscore are for internal use only but helpers 

5that are in __all__ are intended as additional helper functions for use by end 

6users of the compat library. 

7""" 

8from __future__ import annotations 

9 

10from typing import TYPE_CHECKING 

11 

12if TYPE_CHECKING: 

13 from typing import Optional, Union, Any 

14 from ._typing import Array, Device 

15 

16import sys 

17import math 

18import inspect 

19import warnings 

20 

def is_numpy_array(x):
    """
    Return True if `x` is a NumPy array or NumPy scalar.

    NumPy is never imported by this check: when it is not already present in
    ``sys.modules``, `x` cannot be a NumPy object and False is returned right
    away, which keeps the function cheap to call.

    Subclasses of `ndarray` and NumPy scalar objects (``numpy.generic``) are
    reported as NumPy arrays as well.

    See Also
    --------

    array_namespace
    is_array_api_obj
    is_cupy_array
    is_torch_array
    is_dask_array
    is_jax_array
    """
    # Only look at the type once NumPy has already been loaded by someone else.
    if 'numpy' in sys.modules:
        import numpy as np

        # TODO: Should we reject ndarray subclasses?
        return isinstance(x, (np.ndarray, np.generic))
    return False

48 

def is_cupy_array(x):
    """
    Return True if `x` is a CuPy array.

    This function does not import CuPy if it has not already been imported
    and is therefore cheap to use.

    This also returns True for `cupy.ndarray` subclasses and CuPy scalar objects.

    See Also
    --------

    array_namespace
    is_array_api_obj
    is_numpy_array
    is_torch_array
    is_dask_array
    is_jax_array
    """
    # Avoid importing CuPy if it isn't already: if it was never imported,
    # `x` cannot be a CuPy object.
    if 'cupy' not in sys.modules:
        return False

    import cupy as cp

    # TODO: Should we reject ndarray subclasses?
    return isinstance(x, (cp.ndarray, cp.generic))

76 

def is_torch_array(x):
    """
    Return True if `x` is a PyTorch tensor.

    PyTorch is never imported by this check: when ``torch`` is not already in
    ``sys.modules``, `x` cannot be a tensor and False is returned immediately,
    so the call stays cheap.

    See Also
    --------

    array_namespace
    is_array_api_obj
    is_numpy_array
    is_cupy_array
    is_dask_array
    is_jax_array
    """
    if 'torch' in sys.modules:
        import torch

        # TODO: Should we reject Tensor subclasses?
        return isinstance(x, torch.Tensor)
    return False

102 

def is_dask_array(x):
    """
    Return True if `x` is a ``dask.array.Array``.

    ``dask.array`` is never imported by this check: when it is not already in
    ``sys.modules``, `x` cannot be a Dask array and False is returned at once,
    so the call stays cheap.

    See Also
    --------

    array_namespace
    is_array_api_obj
    is_numpy_array
    is_cupy_array
    is_torch_array
    is_jax_array
    """
    if 'dask.array' in sys.modules:
        import dask.array

        return isinstance(x, dask.array.Array)
    return False

127 

def is_jax_array(x):
    """
    Return True if `x` is a JAX array.

    JAX is never imported by this check: when ``jax`` is not already in
    ``sys.modules``, `x` cannot be a JAX array and False is returned at once,
    so the call stays cheap.

    See Also
    --------

    array_namespace
    is_array_api_obj
    is_numpy_array
    is_cupy_array
    is_torch_array
    is_dask_array
    """
    if 'jax' in sys.modules:
        import jax

        return isinstance(x, jax.Array)
    return False

153 

def is_array_api_obj(x):
    """
    Return True if `x` is an array API compatible array object.

    An object qualifies if it is an array from one of the libraries wrapped
    by this package (NumPy, CuPy, PyTorch, Dask, JAX), or if it defines the
    standard ``__array_namespace__`` method.

    See Also
    --------

    array_namespace
    is_numpy_array
    is_cupy_array
    is_torch_array
    is_dask_array
    is_jax_array
    """
    # Checked in this order; any() stops at the first library that matches.
    library_checks = (
        is_numpy_array,
        is_cupy_array,
        is_torch_array,
        is_dask_array,
        is_jax_array,
    )
    if any(check(x) for check in library_checks):
        return True
    return hasattr(x, '__array_namespace__')

174 

def _check_api_version(api_version):
    """
    Validate a requested array API standard version.

    ``None`` and ``'2022.12'`` are accepted silently. ``'2021.12'`` is
    accepted with a warning, because the namespaces returned are 2022.12.

    Raises
    ------
    ValueError
        For any other requested version.
    """
    if api_version is None or api_version == '2022.12':
        return
    if api_version == '2021.12':
        warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12")
        return
    raise ValueError("Only the 2022.12 version of the array API specification is currently supported")

180 

def array_namespace(*xs, api_version=None, _use_compat=True):
    """
    Get the array API compatible namespace for the arrays `xs`.

    Parameters
    ----------
    xs: arrays
        one or more arrays.

    api_version: str, optional
        The newest version of the spec that you need support for (currently
        the compat library wrapped APIs support v2022.12). For the libraries
        wrapped here (NumPy, CuPy, PyTorch, Dask, JAX), requesting
        ``'2021.12'`` emits a warning, and any value other than ``None``,
        ``'2021.12'`` or ``'2022.12'`` raises ``ValueError``. For any other
        array type, the value is forwarded to ``x.__array_namespace__``.

    Returns
    -------

    out: namespace
        The array API compatible namespace corresponding to the arrays in `xs`.

    Raises
    ------
    TypeError
        If `xs` is empty, contains arrays from different array libraries, or
        contains a non-array.
    ValueError
        If `api_version` is not supported (see above).


    Typical usage is to pass the arguments of a function to
    `array_namespace()` at the top of a function to get the corresponding
    array API namespace:

    .. code:: python

       def your_function(x, y):
           xp = array_api_compat.array_namespace(x, y)
           # Now use xp as the array library namespace
           return xp.mean(x, axis=0) + 2*xp.std(y, axis=0)


    Wrapped array namespaces can also be imported directly. For example,
    `array_namespace(np.array(...))` will return `array_api_compat.numpy`.
    This function will also work for any array library not wrapped by
    array-api-compat if it explicitly defines `__array_namespace__
    <https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__array_namespace__.html>`__
    (the wrapped namespace is always preferred if it exists).

    See Also
    --------

    is_array_api_obj
    is_numpy_array
    is_cupy_array
    is_torch_array
    is_dask_array
    is_jax_array

    """
    namespaces = set()
    for x in xs:
        # The wrapped libraries are tested before `__array_namespace__` so
        # their wrapped namespace wins over whatever the library itself exposes.
        if is_numpy_array(x):
            _check_api_version(api_version)
            if _use_compat:
                from .. import numpy as ns
            else:
                import numpy as ns
        elif is_cupy_array(x):
            _check_api_version(api_version)
            if _use_compat:
                from .. import cupy as ns
            else:
                import cupy as ns
        elif is_torch_array(x):
            _check_api_version(api_version)
            if _use_compat:
                from .. import torch as ns
            else:
                import torch as ns
        elif is_dask_array(x):
            _check_api_version(api_version)
            if not _use_compat:
                raise TypeError("_use_compat cannot be False if input array is a dask array!")
            from ..dask import array as ns
        elif is_jax_array(x):
            _check_api_version(api_version)
            # jax.experimental.array_api is already an array namespace, so
            # there is no wrapper submodule and _use_compat does not apply.
            import jax.experimental.array_api as ns
        elif hasattr(x, '__array_namespace__'):
            ns = x.__array_namespace__(api_version=api_version)
        else:
            # TODO: Support Python scalars?
            raise TypeError(f"{type(x).__name__} is not a supported array type")
        namespaces.add(ns)

    if not namespaces:
        raise TypeError("Unrecognized array input")
    if len(namespaces) != 1:
        raise TypeError(f"Multiple namespaces for array inputs: {namespaces}")
    # Exactly one namespace remains at this point.
    return next(iter(namespaces))

# Backwards-compatibility alias: `get_namespace` is the older name for
# `array_namespace` and is bound to the very same function object.
get_namespace = array_namespace

294 

def _check_device(xp, device):
    """
    Reject devices that the NumPy namespace cannot handle.

    This is a no-op unless `xp` is the ``numpy`` module found in
    ``sys.modules``.

    Raises
    ------
    ValueError
        If `xp` is NumPy and `device` is neither ``"cpu"`` nor ``None``.
    """
    targets_numpy = xp == sys.modules.get('numpy')
    if not targets_numpy:
        return
    if device in ["cpu", None]:
        return
    raise ValueError(f"Unsupported device for NumPy: {device!r}")

299 

class _dask_device:
    """
    Placeholder device for Dask arrays whose backend is not the CPU.

    It is not easy to tell which device a Dask array lives on, so the single
    instance ``_DASK_DEVICE`` stands in for it.
    """

    def __repr__(self) -> str:
        return "DASK_DEVICE"


# Module-wide sentinel instance. No __eq__ is defined, so equality falls back
# to identity.
_DASK_DEVICE = _dask_device()

308 

# NumPy and Dask arrays have no `device` attribute, and NumPy and CuPy arrays
# have no `to_device()` method. They are not added to array objects by this
# library because it reuses the respective ndarray classes without wrapping or
# subclassing them. The helper functions `device()` and `to_device()` can be
# used instead of those attributes in code that needs to support both
# NumPy/CuPy and other libraries that use devices.
def device(x: Array, /) -> Device:
    """
    Hardware device the array data resides on.

    This is equivalent to `x.device` according to the `standard
    <https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.device.html>`__.
    This helper is included because some array libraries either do not have
    the `device` attribute or include it with an incompatible API.

    Parameters
    ----------
    x: array
        array instance from an array API compatible library.

    Returns
    -------
    out: device
        a ``device`` object (see the `Device Support <https://data-apis.org/array-api/latest/design_topics/device_support.html>`__
        section of the array API specification).

    Notes
    -----

    For NumPy the device is always `"cpu"`. For Dask, the device is `"cpu"`
    when the array's metadata object (``x._meta``) is a NumPy array, and
    otherwise the special `DASK_DEVICE` placeholder object. For JAX,
    ``x.device`` is called if it is a method and returned as-is otherwise.
    For every other array, ``x.device`` is returned unchanged.

    See Also
    --------

    to_device : Move array data to a different device.

    """
    if is_numpy_array(x):
        return "cpu"
    elif is_dask_array(x):
        # Peek at the metadata object of the dask array (`x._meta`): if it is
        # a NumPy array, the dask array is NumPy-backed and therefore on CPU.
        try:
            import numpy as np
        except ImportError:
            pass
        else:
            if isinstance(x._meta, np.ndarray):
                return "cpu"
        return _DASK_DEVICE
    elif is_jax_array(x):
        # JAX has .device() as a method, but it is being deprecated so that it
        # can become a property, in accordance with the standard. In order for
        # this function to not break when JAX makes the flip, we check for
        # both here.
        if inspect.ismethod(x.device):
            return x.device()
        else:
            return x.device
    return x.device

369 

# Based on cupy.array_api.Array.to_device
def _cupy_to_device(x, device, /, stream=None):
    """
    Copy the CuPy array `x` to `device`, optionally on a given CUDA stream.

    - If `device` equals ``x.device``, `x` itself is returned.
    - If `device` is ``"cpu"``, ``x.get()`` is returned. This lets
      ``to_device(x, "cpu")`` be used for portable tests that swap between
      host and device backends. `stream` is not used on this path.
    - If `device` is a ``cupy.cuda.Device``, `x` is copied while that device
      (and, if given, `stream`) is temporarily made current. Both are
      restored afterwards, even if the copy fails.

    `stream` may be an ``int`` (wrapped in ``cupy.cuda.ExternalStream``, as
    for ``__dlpack__``) or a ``cupy.cuda.Stream``.

    Raises
    ------
    ValueError
        If `device` is none of the above, or `stream` has another type.
    """
    import cupy as cp
    from cupy.cuda import Device as CudaDevice
    from cupy.cuda import stream as stream_module
    from cupy_backends.cuda.api import runtime

    if device == x.device:
        return x
    if device == "cpu":
        return x.get()
    if not isinstance(device, CudaDevice):
        raise ValueError(f"Unsupported device {device!r}")

    # see cupy/cupy#5985 for the reason how we handle device/stream here
    saved_device = runtime.getDevice()
    saved_stream: stream_module.Stream = None
    if stream is not None:
        saved_stream = stream_module.get_current_stream()
        if isinstance(stream, int):
            stream = cp.cuda.ExternalStream(stream)
        elif not isinstance(stream, cp.cuda.Stream):
            raise ValueError('the input stream is not recognized')
        stream.use()
    try:
        runtime.setDevice(device.id)
        copied = x.copy()
    finally:
        runtime.setDevice(saved_device)
        if stream is not None:
            saved_stream.use()
    return copied

408 

def _torch_to_device(x, device, /, stream=None):
    """
    Move a PyTorch tensor to `device` using ``x.to(device)``.

    Raises
    ------
    NotImplementedError
        If `stream` is not None; streams are not supported by this helper.
    """
    if stream is not None:
        # Keep the exception type callers may rely on, but say what failed.
        raise NotImplementedError(
            "The stream argument to to_device() is not supported for PyTorch tensors"
        )
    return x.to(device)

413 

def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array:
    """
    Copy the array from the device on which it currently resides to the specified ``device``.

    This is equivalent to `x.to_device(device, stream=stream)` according to
    the `standard
    <https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.to_device.html>`__.
    This helper is included because some array libraries do not have the
    `to_device` method.

    Parameters
    ----------

    x: array
        array instance from an array API compatible library.

    device: device
        a ``device`` object (see the `Device Support <https://data-apis.org/array-api/latest/design_topics/device_support.html>`__
        section of the array API specification).

    stream: Optional[Union[int, Any]]
        stream object to use during copy. In addition to the types supported
        in ``array.__dlpack__``, implementations may choose to support any
        library-specific stream object with the caveat that any code using
        such an object would not be portable.

    Returns
    -------

    out: array
        an array with the same data and data type as ``x`` and located on the
        specified ``device``.

    Raises
    ------
    ValueError
        For NumPy and Dask arrays, if `stream` is given or `device` is not a
        supported target.

    Notes
    -----

    For NumPy, this function effectively does nothing since the only supported
    device is the CPU. For Dask, the accepted targets are ``"cpu"`` and the
    `DASK_DEVICE` placeholder returned by ``device(x)``. In both cases the
    array is returned unchanged, so ``to_device(x, device(x))`` works for
    every Dask array. For CuPy, this method supports CuPy CUDA
    :external+cupy:class:`Device <cupy.cuda.Device>` and
    :external+cupy:class:`Stream <cupy.cuda.Stream>` objects. For PyTorch,
    this is the same as :external+torch:meth:`x.to(device) <torch.Tensor.to>`
    (the ``stream`` argument is not supported in PyTorch).

    See Also
    --------

    device : Hardware device the array data resides on.

    """
    if is_numpy_array(x):
        if stream is not None:
            raise ValueError("The stream argument to to_device() is not supported")
        if device == 'cpu':
            return x
        raise ValueError(f"Unsupported device {device!r}")
    elif is_cupy_array(x):
        # cupy does not yet have to_device
        return _cupy_to_device(x, device, stream=stream)
    elif is_torch_array(x):
        return _torch_to_device(x, device, stream=stream)
    elif is_dask_array(x):
        if stream is not None:
            raise ValueError("The stream argument to to_device() is not supported")
        # device() reports _DASK_DEVICE for dask arrays that are not
        # NumPy-backed. Accept it so that passing device(x) back in is a
        # no-op instead of an "Unsupported device" error.
        # TODO: What if our array is on the GPU already? ('cpu' does not move it)
        if device == 'cpu' or device is _DASK_DEVICE:
            return x
        raise ValueError(f"Unsupported device {device!r}")
    elif is_jax_array(x):
        # This import adds to_device to x
        import jax.experimental.array_api # noqa: F401
        return x.to_device(device, stream=stream)
    return x.to_device(device, stream=stream)

486 

def size(x):
    """
    Return the total number of elements of x.

    This is equivalent to `x.size` according to the `standard
    <https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.size.html>`__.
    This helper is included because PyTorch defines `size` in an
    :external+torch:meth:`incompatible way <torch.Tensor.size>`.

    Returns ``None`` when the size is unknown, i.e. when any entry of
    ``x.shape`` is ``None`` or a float NaN. Dask, for example, reports NaN
    for dimensions whose chunk sizes are not known yet.
    """
    shape = x.shape
    if None in shape:
        return None
    # Without this check, math.prod would propagate NaN rather than
    # signalling "unknown size".
    if any(isinstance(dim, float) and math.isnan(dim) for dim in shape):
        return None
    return math.prod(shape)

500 

# Public helpers exported by this module, in alphabetical order.
__all__ = [
    # namespace lookup (get_namespace is the backwards-compatible alias)
    "array_namespace",
    # device handling
    "device",
    "get_namespace",
    # array-type predicates
    "is_array_api_obj",
    "is_cupy_array",
    "is_dask_array",
    "is_jax_array",
    "is_numpy_array",
    "is_torch_array",
    # element count and device transfer
    "size",
    "to_device",
]

# Module-level names that are imported here but are not part of the API.
_all_ignore = ['sys', 'math', 'inspect', 'warnings']