Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/scipy/_lib/array_api_compat/common/_aliases.py: 24%
228 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-03 06:39 +0000
1"""
2These are functions that are just aliases of existing functions in NumPy.
3"""
5from __future__ import annotations
7from typing import TYPE_CHECKING
8if TYPE_CHECKING:
9 import numpy as np
10 from typing import Optional, Sequence, Tuple, Union
11 from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol
13from typing import NamedTuple
14from types import ModuleType
15import inspect
17from ._helpers import _check_device, is_numpy_array, array_namespace
19# These functions are modified from the NumPy versions.
def arange(
    start: Union[int, float],
    /,
    stop: Optional[Union[int, float]] = None,
    step: Union[int, float] = 1,
    *,
    xp,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs
) -> ndarray:
    """
    Array API ``arange`` built on ``xp.arange``.

    ``device`` is only handed to ``_check_device(xp, device)``; it is not
    forwarded to ``xp.arange``. ``start``, ``stop``, ``step``, ``dtype`` and
    any extra keyword arguments are passed straight through.
    """
    _check_device(xp, device)
    result = xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
    return result
def empty(
    shape: Union[int, Tuple[int, ...]],
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs
) -> ndarray:
    """
    Array API ``empty``: an uninitialised array of ``shape`` from ``xp``.

    ``device`` is only handed to ``_check_device``; ``dtype`` and extra
    keyword arguments go to ``xp.empty``.
    """
    _check_device(xp, device)
    result = xp.empty(shape, dtype=dtype, **kwargs)
    return result
def empty_like(
    x: ndarray,
    /,
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs
) -> ndarray:
    """
    Array API ``empty_like``: delegate to ``xp.empty_like(x, ...)``.

    ``device`` is only handed to ``_check_device`` and is not forwarded.
    """
    _check_device(xp, device)
    result = xp.empty_like(x, dtype=dtype, **kwargs)
    return result
def eye(
    n_rows: int,
    n_cols: Optional[int] = None,
    /,
    *,
    xp,
    k: int = 0,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    """
    Array API ``eye``: delegate to ``xp.eye``.

    The array API's ``n_cols`` is passed as ``xp.eye``'s ``M`` keyword.
    ``device`` is only handed to ``_check_device`` and is not forwarded.
    """
    _check_device(xp, device)
    result = xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
    return result
def full(
    shape: Union[int, Tuple[int, ...]],
    fill_value: Union[int, float],
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    """
    Array API ``full``: an array of ``shape`` filled with ``fill_value``.

    ``device`` is only handed to ``_check_device``; everything else is
    forwarded to ``xp.full``.
    """
    _check_device(xp, device)
    result = xp.full(shape, fill_value, dtype=dtype, **kwargs)
    return result
def full_like(
    x: ndarray,
    /,
    fill_value: Union[int, float],
    *,
    xp,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    """
    Array API ``full_like``: delegate to ``xp.full_like(x, fill_value, ...)``.

    ``device`` is only handed to ``_check_device`` and is not forwarded.
    """
    _check_device(xp, device)
    result = xp.full_like(x, fill_value, dtype=dtype, **kwargs)
    return result
def linspace(
    start: Union[int, float],
    stop: Union[int, float],
    /,
    num: int,
    *,
    xp,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    endpoint: bool = True,
    **kwargs,
) -> ndarray:
    """
    Array API ``linspace``: delegate to ``xp.linspace``.

    ``num``, ``dtype``, ``endpoint`` and extra keywords are forwarded;
    ``device`` is only handed to ``_check_device``.
    """
    _check_device(xp, device)
    result = xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
    return result
def ones(
    shape: Union[int, Tuple[int, ...]],
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    """
    Array API ``ones``: an array of ``shape`` filled with ones from ``xp``.

    ``device`` is only handed to ``_check_device`` and is not forwarded.
    """
    _check_device(xp, device)
    result = xp.ones(shape, dtype=dtype, **kwargs)
    return result
def ones_like(
    x: ndarray,
    /,
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    """
    Array API ``ones_like``: delegate to ``xp.ones_like(x, ...)``.

    ``device`` is only handed to ``_check_device`` and is not forwarded.
    """
    _check_device(xp, device)
    result = xp.ones_like(x, dtype=dtype, **kwargs)
    return result
def zeros(
    shape: Union[int, Tuple[int, ...]],
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    """
    Array API ``zeros``: an array of ``shape`` filled with zeros from ``xp``.

    ``device`` is only handed to ``_check_device`` and is not forwarded.
    """
    _check_device(xp, device)
    result = xp.zeros(shape, dtype=dtype, **kwargs)
    return result
def zeros_like(
    x: ndarray,
    /,
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    """
    Array API ``zeros_like``: delegate to ``xp.zeros_like(x, ...)``.

    ``device`` is only handed to ``_check_device`` and is not forwarded.
    """
    _check_device(xp, device)
    result = xp.zeros_like(x, dtype=dtype, **kwargs)
    return result
143# np.unique() is split into four functions in the array API:
144# unique_all, unique_counts, unique_inverse, and unique_values (this is done
145# to remove polymorphic return types).
147# The functions here return namedtuples (np.unique() returns a normal
148# tuple).
150# Note that these named tuples aren't actually part of the standard namespace,
151# but I don't see any issue with exporting the names here regardless.
class UniqueAllResult(NamedTuple):
    """Named result of ``unique_all``.

    Fields are, in order: the unique ``values``, then the ``indices``,
    ``inverse_indices`` and ``counts`` arrays obtained from ``xp.unique``.
    """

    values: ndarray
    indices: ndarray
    inverse_indices: ndarray
    counts: ndarray
class UniqueCountsResult(NamedTuple):
    """Named result of ``unique_counts``: unique ``values`` and their ``counts``."""

    values: ndarray
    counts: ndarray
class UniqueInverseResult(NamedTuple):
    """Named result of ``unique_inverse``: unique ``values`` and ``inverse_indices``."""

    values: ndarray
    inverse_indices: ndarray
169def _unique_kwargs(xp):
170 # Older versions of NumPy and CuPy do not have equal_nan. Rather than
171 # trying to parse version numbers, just check if equal_nan is in the
172 # signature.
173 s = inspect.signature(xp.unique)
174 if 'equal_nan' in s.parameters:
175 return {'equal_nan': False}
176 return {}
def unique_all(x: ndarray, /, xp) -> UniqueAllResult:
    """
    Array API ``unique_all``: unique values with indices, inverse and counts.

    Calls ``xp.unique`` once with every ``return_*`` flag enabled. The inverse
    indices are then reshaped to ``x.shape`` because ``np.unique()`` flattens
    them (see https://github.com/numpy/numpy/issues/20638).
    """
    extra = _unique_kwargs(xp)
    values, indices, flat_inverse, counts = xp.unique(
        x,
        return_counts=True,
        return_index=True,
        return_inverse=True,
        **extra,
    )
    return UniqueAllResult(values, indices, flat_inverse.reshape(x.shape), counts)
def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult:
    """
    Array API ``unique_counts``: the unique values of ``x`` and their counts.

    The tuple that ``xp.unique`` returns with only ``return_counts`` enabled
    is unpacked directly into the named result.
    """
    extra = _unique_kwargs(xp)
    return UniqueCountsResult(*xp.unique(
        x,
        return_counts=True,
        return_index=False,
        return_inverse=False,
        **extra
    ))
def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult:
    """
    Array API ``unique_inverse``: unique values plus inverse indices.

    The inverse indices are reshaped to ``x.shape`` because ``xp.unique()``
    flattens them (see https://github.com/numpy/numpy/issues/20638).
    """
    extra = _unique_kwargs(xp)
    values, flat_inverse = xp.unique(
        x,
        return_counts=False,
        return_index=False,
        return_inverse=True,
        **extra,
    )
    return UniqueInverseResult(values, flat_inverse.reshape(x.shape))
def unique_values(x: ndarray, /, xp) -> ndarray:
    """
    Array API ``unique_values``: only the unique values of ``x``.

    Every ``return_*`` flag is disabled, so ``xp.unique`` returns just the
    values array.
    """
    extra = _unique_kwargs(xp)
    values = xp.unique(
        x,
        return_counts=False,
        return_index=False,
        return_inverse=False,
        **extra,
    )
    return values
def astype(x: ndarray, dtype: Dtype, /, *, copy: bool = True) -> ndarray:
    """
    Cast ``x`` to ``dtype``.

    When ``copy`` is falsy and ``x`` already has ``dtype``, ``x`` itself is
    returned without calling ``astype``. Otherwise ``x.astype`` is called
    with the given ``copy`` flag.
    """
    if copy or dtype != x.dtype:
        return x.astype(dtype=dtype, copy=copy)
    return x
241# These functions have different keyword argument names
def std(
    x: ndarray,
    /,
    xp,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    correction: Union[int, float] = 0.0,  # correction instead of ddof
    keepdims: bool = False,
    **kwargs,
) -> ndarray:
    """
    Array API ``std``: standard deviation via ``xp.std``.

    The array API calls the degrees-of-freedom adjustment ``correction``;
    it is passed to ``xp.std`` under its NumPy name ``ddof``.
    """
    return xp.std(
        x,
        axis=axis,
        ddof=correction,
        keepdims=keepdims,
        **kwargs,
    )
def var(
    x: ndarray,
    /,
    xp,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    correction: Union[int, float] = 0.0,  # correction instead of ddof
    keepdims: bool = False,
    **kwargs,
) -> ndarray:
    """
    Array API ``var``: variance via ``xp.var``.

    The array API's ``correction`` is forwarded as NumPy's ``ddof``.
    """
    return xp.var(
        x,
        axis=axis,
        ddof=correction,
        keepdims=keepdims,
        **kwargs,
    )
267# Unlike transpose(), the axes argument to permute_dims() is required.
def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray:
    """
    Reorder the axes of ``x`` as given by ``axes``, via ``xp.transpose``.

    Unlike ``transpose``, the array API makes ``axes`` mandatory, and this
    signature gives it no default.
    """
    permuted = xp.transpose(x, axes)
    return permuted
271# Creation functions add the device keyword (which does nothing for NumPy)
273# asarray also adds the copy keyword
def _asarray(
    obj: Union[
        ndarray,
        bool,
        int,
        float,
        NestedSequence[bool | int | float],
        SupportsBufferProtocol,
    ],
    /,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    copy: "Optional[Union[bool, np._CopyMode]]" = None,
    namespace: "Optional[Union[str, ModuleType]]" = None,
    **kwargs,
) -> ndarray:
    """
    Array API compatibility wrapper for asarray().

    See the corresponding documentation in NumPy/CuPy and/or the array API
    specification for more details.

    Parameters
    ----------
    obj
        The input to convert: an array, a Python scalar, a nested sequence,
        or an object supporting the buffer protocol.
    dtype
        Target dtype. If ``obj`` is already an ``xp.ndarray`` whose dtype
        differs from ``dtype``, a copy is forced.
    device
        Only passed to ``_check_device(xp, device)``; it is not forwarded to
        the ``xp`` constructor calls below.
    copy
        ``True`` (plus ``np._CopyMode.ALWAYS`` when ``obj`` is a NumPy array
        on a NumPy that has ``_CopyMode``) requests a copy. ``False`` (plus
        ``np._CopyMode.IF_NEEDED`` under the same condition) raises
        NotImplementedError unless ``namespace == "dask.array"``. ``None``
        does not by itself force a copy.
    namespace
        ``None`` to infer the namespace from ``obj`` via ``array_namespace``,
        a module object to use directly, or one of the strings ``'numpy'``,
        ``'cupy'`` or ``'dask.array'``.
    **kwargs
        Forwarded to ``xp.asarray`` on the final, generic path only.

    Raises
    ------
    ValueError
        If ``namespace`` is None and ``array_namespace`` raises ValueError,
        or if ``namespace`` is not one of the recognised values.
    NotImplementedError
        If ``copy`` is a "no copy" value and ``namespace`` is not the string
        ``"dask.array"``.
    """
    # Resolve the target namespace ``xp``, either from ``namespace`` or by
    # inspecting ``obj``.
    if namespace is None:
        try:
            xp = array_namespace(obj, _use_compat=False)
        except ValueError:
            # NOTE(review): this handler assumes array_namespace reports
            # non-array input with ValueError. Confirm against _helpers: if it
            # raises another exception type, this message is never produced.
            # TODO: What about lists of arrays?
            raise ValueError("A namespace must be specified for asarray() with non-array input")
    elif isinstance(namespace, ModuleType):
        xp = namespace
    elif namespace == 'numpy':
        import numpy as xp
    elif namespace == 'cupy':
        import cupy as xp
    elif namespace == 'dask.array':
        import dask.array as xp
    else:
        raise ValueError("Unrecognized namespace argument to asarray()")

    _check_device(xp, device)
    # Build the sets of ``copy`` values that mean "never copy" and "always
    # copy". For NumPy input they also include the matching np._CopyMode
    # members, when this NumPy defines _CopyMode.
    if is_numpy_array(obj):
        import numpy as np
        if hasattr(np, '_CopyMode'):
            # Not present in older NumPys
            COPY_FALSE = (False, np._CopyMode.IF_NEEDED)
            COPY_TRUE = (True, np._CopyMode.ALWAYS)
        else:
            COPY_FALSE = (False,)
            COPY_TRUE = (True,)
    else:
        COPY_FALSE = (False,)
        COPY_TRUE = (True,)
    # NOTE(review): this check and the ``elif`` below compare ``namespace``
    # with the *string* "dask.array". Passing the dask.array module object as
    # ``namespace`` does not take the dask-specific paths.
    if copy in COPY_FALSE and namespace != "dask.array":
        # copy=False is not yet implemented in xp.asarray
        raise NotImplementedError("copy=False is not yet implemented")
    # ``obj`` is already an array of this namespace. Return it unchanged
    # unless a copy was requested or the dtype has to change.
    if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)):
        if dtype is not None and obj.dtype != dtype:
            copy = True
        if copy in COPY_TRUE:
            return xp.array(obj, copy=True, dtype=dtype)
        return obj
    elif namespace == "dask.array":
        if copy in COPY_TRUE:
            if dtype is None:
                return obj.copy()
            # Go through numpy, since dask copy is no-op by default
            import numpy as np
            obj = np.array(obj, dtype=dtype, copy=True)
            return xp.array(obj, dtype=dtype)
        else:
            # A non-dask input is converted with NumPy, then wrapped in a
            # dask array. An existing dask array is returned as-is; ``dtype``
            # is not applied to it on this path.
            import dask.array as da
            import numpy as np
            if not isinstance(obj, da.Array):
                obj = np.asarray(obj, dtype=dtype)
                return da.from_array(obj)
            return obj

    # Generic path: let the namespace's own asarray do the conversion.
    return xp.asarray(obj, dtype=dtype, **kwargs)
355# np.reshape calls the keyword argument 'newshape' instead of 'shape'
def reshape(x: ndarray,
            /,
            shape: Tuple[int, ...],
            xp, copy: Optional[bool] = None,
            **kwargs) -> ndarray:
    """
    Array API ``reshape`` (NumPy names this argument ``newshape``).

    ``copy=False``: a view of ``x`` has its ``shape`` attribute assigned in
    place, so the result always shares ``x``'s data. Extra ``kwargs`` are
    ignored on this path.
    ``copy=True``: ``x`` is copied before ``xp.reshape`` is applied.
    ``copy=None``: ``xp.reshape`` is applied to ``x`` directly.
    """
    if copy is False:
        view = x.view()
        view.shape = shape
        return view
    source = x.copy() if copy is True else x
    return xp.reshape(source, shape, **kwargs)
369# The descending keyword is new in sort and argsort, and 'kind' replaced with
370# 'stable'
def argsort(
    x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True,
    **kwargs,
) -> ndarray:
    """
    Array API ``argsort`` with ``descending`` and ``stable`` keywords.

    ``stable=True`` is translated into ``kind="stable"``. It is set in
    ``kwargs`` only when requested because NumPy's default kind is
    "quicksort" while CuPy's is None.

    NumPy has no native descending sort, so one is emulated: argsort the
    array reversed along ``axis``, reverse that result, and map each index i
    back to ``n - 1 - i``. Simply flipping an ascending argsort would reverse
    the relative order of equal elements, which a real descending sort would
    keep.
    """
    if stable:
        kwargs['kind'] = "stable"
    if not descending:
        return xp.argsort(x, axis=axis, **kwargs)

    # flip()/argsort() validate ``axis`` before it is used to index shape.
    order_of_reversed = xp.argsort(xp.flip(x, axis=axis), axis=axis, **kwargs)
    positive_axis = x.ndim + axis if axis < 0 else axis
    last_index = x.shape[positive_axis] - 1
    return last_index - xp.flip(order_of_reversed, axis=axis)
def sort(
    x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True,
    **kwargs,
) -> ndarray:
    """
    Array API ``sort`` with ``descending`` and ``stable`` keywords.

    ``stable=True`` becomes ``kind="stable"``. It is set only when requested
    because NumPy's default kind is "quicksort" and CuPy's is None. A
    descending result is the ascending sort reversed along ``axis``.
    """
    if stable:
        kwargs['kind'] = "stable"
    ascending = xp.sort(x, axis=axis, **kwargs)
    return xp.flip(ascending, axis=axis) if descending else ascending
410# nonzero should error for zero-dimensional arrays
def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]:
    """
    Array API ``nonzero``: indices of the non-zero elements of ``x``.

    Raises ValueError for zero-dimensional input, which the array API does
    not allow here. Everything else is delegated to ``xp.nonzero``.
    """
    if not x.ndim:
        raise ValueError("nonzero() does not support zero-dimensional arrays")
    return xp.nonzero(x, **kwargs)
416# sum() and prod() should always upcast when dtype=None
def sum(
    x: ndarray,
    /,
    xp,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    dtype: Optional[Dtype] = None,
    keepdims: bool = False,
    **kwargs,
) -> ndarray:
    """
    Array API ``sum``: upcasts single-precision floats when ``dtype`` is None.

    With no explicit ``dtype``, float32 input is summed as float64 and
    complex64 as complex128. ``xp.sum`` already upcasts integers by itself,
    but not floats or complexes. An explicit ``dtype`` is always used as given.
    """
    if dtype is None:
        dtype = (
            xp.float64 if x.dtype == xp.float32
            else xp.complex128 if x.dtype == xp.complex64
            else None
        )
    return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs)
def prod(
    x: ndarray,
    /,
    xp,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    dtype: Optional[Dtype] = None,
    keepdims: bool = False,
    **kwargs,
) -> ndarray:
    """
    Array API ``prod``: upcasts single-precision floats when ``dtype`` is None.

    With no explicit ``dtype``, float32 input is multiplied in float64 and
    complex64 in complex128. An explicit ``dtype`` is always used as given.
    """
    if dtype is None:
        dtype = (
            xp.float64 if x.dtype == xp.float32
            else xp.complex128 if x.dtype == xp.complex64
            else None
        )
    return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims, **kwargs)
452# ceil, floor, and trunc return integers for integer inputs
def ceil(x: ndarray, /, xp, **kwargs) -> ndarray:
    """
    Array API ``ceil``.

    Integer-dtype input is returned unchanged (the same object), so the result
    keeps an integer dtype. Other input goes through ``xp.ceil``.
    """
    integral = xp.issubdtype(x.dtype, xp.integer)
    return x if integral else xp.ceil(x, **kwargs)
def floor(x: ndarray, /, xp, **kwargs) -> ndarray:
    """
    Array API ``floor``.

    Integer-dtype input is returned unchanged (the same object). Other input
    goes through ``xp.floor``.
    """
    integral = xp.issubdtype(x.dtype, xp.integer)
    return x if integral else xp.floor(x, **kwargs)
def trunc(x: ndarray, /, xp, **kwargs) -> ndarray:
    """
    Array API ``trunc``.

    Integer-dtype input is returned unchanged (the same object). Other input
    goes through ``xp.trunc``.
    """
    integral = xp.issubdtype(x.dtype, xp.integer)
    return x if integral else xp.trunc(x, **kwargs)
469# linear algebra functions
def matmul(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray:
    """Matrix product of ``x1`` and ``x2``, delegated to ``xp.matmul``."""
    product = xp.matmul(x1, x2, **kwargs)
    return product
474# Unlike transpose, matrix_transpose only transposes the last two axes.
def matrix_transpose(x: ndarray, /, xp) -> ndarray:
    """
    Swap only the last two axes of ``x`` (unlike ``transpose``).

    Raises ValueError when ``x`` has fewer than two dimensions.
    """
    if x.ndim >= 2:
        return xp.swapaxes(x, -1, -2)
    raise ValueError("x must be at least 2-dimensional for matrix_transpose")
def tensordot(x1: ndarray,
              x2: ndarray,
              /,
              xp,
              *,
              axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2,
              **kwargs,
) -> ndarray:
    """
    Tensor contraction of ``x1`` and ``x2`` over ``axes``, via ``xp.tensordot``.

    In the array API ``axes`` is keyword-only here; it is forwarded by name.
    """
    contracted = xp.tensordot(x1, x2, axes=axes, **kwargs)
    return contracted
def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray:
    """
    Vector dot product of ``x1`` and ``x2`` along ``axis``.

    Follows the array API definition ``sum(conj(x1) * x2)`` over ``axis``.
    For complex inputs the first operand is conjugated; for real or integer
    inputs this is the ordinary dot product. The remaining dimensions are
    broadcast against each other.

    Raises ValueError if ``x1`` and ``x2`` differ in size along ``axis``.
    """
    if x1.shape[axis] != x2.shape[axis]:
        raise ValueError("x1 and x2 must have the same size along the given axis")

    # Some namespaces spell this ``broadcast_tensors`` instead of
    # ``broadcast_arrays``; use whichever this one provides.
    if hasattr(xp, 'broadcast_tensors'):
        _broadcast = xp.broadcast_tensors
    else:
        _broadcast = xp.broadcast_arrays

    x1_ = xp.moveaxis(x1, axis, -1)
    x2_ = xp.moveaxis(x2, axis, -1)
    x1_, x2_ = _broadcast(x1_, x2_)

    # The spec conjugates the first operand. Only complex dtypes are affected,
    # so real, integer and bool inputs keep exactly their previous path.
    # NumPy-style dtypes (as used by NumPy/CuPy/Dask arrays) expose ``kind``.
    if getattr(x1_.dtype, 'kind', None) == 'c':
        x1_ = xp.conj(x1_)

    # Batched (1 x n) @ (n x 1) matmul, then drop the two singleton dims.
    res = x1_[..., None, :] @ x2_[..., None]
    return res[..., 0, 0]
506# isdtype is a new function in the 2022.12 array API specification.
def isdtype(
    dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], xp,
    *, _tuple=True, # Disallow nested tuples
) -> bool:
    """
    Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.

    ``kind`` may be a dtype, one of the spec's kind strings ('bool',
    'signed integer', 'unsigned integer', 'integral', 'real floating',
    'complex floating', 'numeric'), or a tuple of those. A tuple matches if
    any element matches; tuples nested inside that tuple are not expanded.
    Unknown kind strings raise ValueError.

    Note that outside of this function, this compat library does not yet fully
    support complex numbers.

    See
    https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
    for more details
    """
    if _tuple and isinstance(kind, tuple):
        return any(isdtype(dtype, k, xp, _tuple=False) for k in kind)

    if not isinstance(kind, str):
        # Plain equality also admits things the spec does not require, e.g.
        # isdtype(np.float64, float) or isdtype(np.int64, 'l'); the
        # array_api_strict implementation is much stricter.
        return dtype == kind

    if kind == 'bool':
        return dtype == xp.bool_

    # Each remaining kind string maps to the name of an abstract scalar type
    # on ``xp``, which is checked with ``xp.issubdtype``.
    abstract_name = {
        'signed integer': 'signedinteger',
        'unsigned integer': 'unsignedinteger',
        'integral': 'integer',
        'real floating': 'floating',
        'complex floating': 'complexfloating',
        'numeric': 'number',
    }.get(kind)
    if abstract_name is None:
        raise ValueError(f"Unrecognized data type kind: {kind!r}")
    return xp.issubdtype(dtype, getattr(xp, abstract_name))
__all__ = [
    # Creation functions (accept and check a ``device`` keyword).
    'arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
    'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
    # The split-up np.unique family and its named result tuples.
    'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
    'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
    # Wrappers with renamed keywords or adjusted behavior.
    'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
    'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
    # Linear algebra and dtype inspection.
    'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
]