Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/scipy/_lib/array_api_compat/array_api_compat/common/_aliases.py: 25%
211 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-23 06:43 +0000
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-23 06:43 +0000
1"""
2These are functions that are just aliases of existing functions in NumPy.
3"""
5from __future__ import annotations
7from typing import TYPE_CHECKING
8if TYPE_CHECKING:
9 from typing import Optional, Sequence, Tuple, Union, List
10 from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol
12from typing import NamedTuple
13from types import ModuleType
14import inspect
16from ._helpers import _check_device, _is_numpy_array, array_namespace
18# These functions are modified from the NumPy versions.
def arange(
    start: Union[int, float],
    /,
    stop: Optional[Union[int, float]] = None,
    step: Union[int, float] = 1,
    *,
    xp,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs
) -> ndarray:
    """Array API ``arange`` implemented with ``xp.arange``.

    ``device`` is checked with ``_check_device`` and not forwarded; the
    remaining arguments (plus any extra ``kwargs``) go to ``xp.arange``.
    """
    _check_device(xp, device)
    return xp.arange(
        start,
        stop=stop,
        step=step,
        dtype=dtype,
        **kwargs,
    )
def empty(
    shape: Union[int, Tuple[int, ...]],
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs
) -> ndarray:
    """Array API ``empty`` implemented with ``xp.empty``.

    ``device`` is only validated via ``_check_device``; ``dtype`` and extra
    keyword arguments are passed through to ``xp.empty``.
    """
    _check_device(xp, device)
    return xp.empty(
        shape,
        dtype=dtype,
        **kwargs,
    )
def empty_like(
    x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None,
    **kwargs
) -> ndarray:
    """Array API ``empty_like`` implemented with ``xp.empty_like``.

    ``device`` is only validated via ``_check_device`` and is not forwarded.
    """
    _check_device(xp, device)
    return xp.empty_like(
        x,
        dtype=dtype,
        **kwargs,
    )
def eye(
    n_rows: int,
    n_cols: Optional[int] = None,
    /,
    *,
    xp,
    k: int = 0,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    """Array API ``eye`` implemented with ``xp.eye``.

    The spec's ``n_cols`` is passed as NumPy's ``M`` keyword. ``device`` is
    only validated via ``_check_device`` and is not forwarded.
    """
    _check_device(xp, device)
    return xp.eye(
        n_rows,
        M=n_cols,
        k=k,
        dtype=dtype,
        **kwargs,
    )
def full(
    shape: Union[int, Tuple[int, ...]],
    fill_value: Union[int, float],
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    """Array API ``full`` implemented with ``xp.full``.

    ``device`` is only validated via ``_check_device`` and is not forwarded.
    """
    _check_device(xp, device)
    return xp.full(
        shape,
        fill_value,
        dtype=dtype,
        **kwargs,
    )
def full_like(
    x: ndarray,
    /,
    fill_value: Union[int, float],
    *,
    xp,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    """Array API ``full_like`` implemented with ``xp.full_like``.

    ``device`` is only validated via ``_check_device`` and is not forwarded.
    """
    _check_device(xp, device)
    return xp.full_like(
        x,
        fill_value,
        dtype=dtype,
        **kwargs,
    )
def linspace(
    start: Union[int, float],
    stop: Union[int, float],
    /,
    num: int,
    *,
    xp,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    endpoint: bool = True,
    **kwargs,
) -> ndarray:
    """Array API ``linspace`` implemented with ``xp.linspace``.

    ``device`` is only validated via ``_check_device``; ``endpoint``,
    ``dtype`` and extra keyword arguments are forwarded.
    """
    _check_device(xp, device)
    return xp.linspace(
        start,
        stop,
        num,
        dtype=dtype,
        endpoint=endpoint,
        **kwargs,
    )
def ones(
    shape: Union[int, Tuple[int, ...]],
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    """Array API ``ones`` implemented with ``xp.ones``.

    ``device`` is only validated via ``_check_device`` and is not forwarded.
    """
    _check_device(xp, device)
    return xp.ones(
        shape,
        dtype=dtype,
        **kwargs,
    )
def ones_like(
    x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    """Array API ``ones_like`` implemented with ``xp.ones_like``.

    ``device`` is only validated via ``_check_device`` and is not forwarded.
    """
    _check_device(xp, device)
    return xp.ones_like(
        x,
        dtype=dtype,
        **kwargs,
    )
def zeros(
    shape: Union[int, Tuple[int, ...]],
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    """Array API ``zeros`` implemented with ``xp.zeros``.

    ``device`` is only validated via ``_check_device`` and is not forwarded.
    """
    _check_device(xp, device)
    return xp.zeros(
        shape,
        dtype=dtype,
        **kwargs,
    )
def zeros_like(
    x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    """Array API ``zeros_like`` implemented with ``xp.zeros_like``.

    ``device`` is only validated via ``_check_device`` and is not forwarded.
    """
    _check_device(xp, device)
    return xp.zeros_like(
        x,
        dtype=dtype,
        **kwargs,
    )
# The array API splits np.unique() into four functions -- unique_all,
# unique_counts, unique_inverse and unique_values -- so that each one has a
# single, fixed return type instead of np.unique()'s flag-dependent one.
#
# The wrappers here return namedtuples, whereas np.unique() returns a plain
# tuple.
class UniqueAllResult(NamedTuple):
    """Return value of ``unique_all``."""

    values: ndarray
    indices: ndarray
    inverse_indices: ndarray
    counts: ndarray
class UniqueCountsResult(NamedTuple):
    """Return value of ``unique_counts``."""

    values: ndarray
    counts: ndarray
class UniqueInverseResult(NamedTuple):
    """Return value of ``unique_inverse``."""

    values: ndarray
    inverse_indices: ndarray
def _unique_kwargs(xp):
    """Return the extra keyword arguments to pass to ``xp.unique``.

    Older NumPy and CuPy releases lack the ``equal_nan`` parameter. Rather
    than parsing version numbers, this feature-detects it from the signature
    of ``xp.unique`` and, when present, requests ``equal_nan=False``.
    """
    params = inspect.signature(xp.unique).parameters
    return {'equal_nan': False} if 'equal_nan' in params else {}
def unique_all(x: ndarray, /, xp) -> UniqueAllResult:
    """Array API ``unique_all``: unique values plus first-occurrence indices,
    inverse indices and counts, computed with a single ``xp.unique`` call."""
    extra = _unique_kwargs(xp)
    values, first_indices, flat_inverse, counts = xp.unique(
        x,
        return_index=True,
        return_inverse=True,
        return_counts=True,
        **extra,
    )
    # np.unique() flattens inverse indices, but they need to share x's shape
    # See https://github.com/numpy/numpy/issues/20638
    return UniqueAllResult(values, first_indices, flat_inverse.reshape(x.shape), counts)
def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult:
    """Array API ``unique_counts``: unique values of ``x`` and how often each
    occurs, as reported by ``xp.unique(..., return_counts=True)``."""
    extra = _unique_kwargs(xp)
    pieces = xp.unique(
        x,
        return_index=False,
        return_inverse=False,
        return_counts=True,
        **extra,
    )
    return UniqueCountsResult(*pieces)
def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult:
    """Array API ``unique_inverse``: unique values of ``x`` plus, for each
    element of ``x``, the index of its value in the unique values."""
    extra = _unique_kwargs(xp)
    values, flat_inverse = xp.unique(
        x,
        return_index=False,
        return_inverse=True,
        return_counts=False,
        **extra,
    )
    # xp.unique() flattens inverse indices, but they need to share x's shape
    # See https://github.com/numpy/numpy/issues/20638
    return UniqueInverseResult(values, flat_inverse.reshape(x.shape))
def unique_values(x: ndarray, /, xp) -> ndarray:
    """Array API ``unique_values``: just the unique values of ``x`` as
    returned by ``xp.unique`` with every ``return_*`` flag disabled."""
    extra = _unique_kwargs(xp)
    flags = dict(return_counts=False, return_index=False, return_inverse=False)
    return xp.unique(x, **flags, **extra)
def astype(x: ndarray, dtype: Dtype, /, *, copy: bool = True) -> ndarray:
    """Array API ``astype``.

    With ``copy=False`` and a matching dtype, ``x`` itself is returned;
    otherwise the work is delegated to ``x.astype(dtype=..., copy=...)``.
    """
    if copy or dtype != x.dtype:
        return x.astype(dtype=dtype, copy=copy)
    return x
237# These functions have different keyword argument names
def std(
    x: ndarray,
    /,
    xp,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    correction: Union[int, float] = 0.0,
    keepdims: bool = False,
    **kwargs,
) -> ndarray:
    """Array API ``std``.

    The array API names NumPy's ``ddof`` argument ``correction``; it is
    forwarded to ``xp.std`` as ``ddof`` together with any extra kwargs.
    """
    return xp.std(x, ddof=correction, axis=axis, keepdims=keepdims, **kwargs)
def var(
    x: ndarray,
    /,
    xp,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    correction: Union[int, float] = 0.0,
    keepdims: bool = False,
    **kwargs,
) -> ndarray:
    """Array API ``var``.

    The array API names NumPy's ``ddof`` argument ``correction``; it is
    forwarded to ``xp.var`` as ``ddof`` together with any extra kwargs.
    """
    return xp.var(x, ddof=correction, axis=axis, keepdims=keepdims, **kwargs)
# permute_dims() differs from transpose() in that ``axes`` is mandatory.
def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray:
    """Array API ``permute_dims``: reorder the axes of ``x`` according to
    ``axes``, implemented with ``xp.transpose``."""
    permuted = xp.transpose(x, axes)
    return permuted
267# Creation functions add the device keyword (which does nothing for NumPy)
269# asarray also adds the copy keyword
def _asarray(
    obj: Union[
        ndarray,
        bool,
        int,
        float,
        NestedSequence[bool | int | float],
        SupportsBufferProtocol,
    ],
    /,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    copy: "Optional[Union[bool, np._CopyMode]]" = None,
    namespace = None,
    **kwargs,
) -> ndarray:
    """
    Array API compatibility wrapper for asarray().

    See the corresponding documentation in NumPy/CuPy and/or the array API
    specification for more details.

    ``namespace`` selects the target library: a module object, the string
    ``'numpy'`` or ``'cupy'``, or None to infer it from ``obj`` with
    ``array_namespace``. ``device`` is only validated via ``_check_device``.

    If ``obj`` is already an ``xp.ndarray`` it is returned unchanged, unless
    a copy is requested or ``dtype`` differs from ``obj.dtype``. In that case
    a copy is made with ``xp.array(obj, copy=True, dtype=dtype)``; ``kwargs``
    are not used on this path. Any other input is converted with
    ``xp.asarray(obj, dtype=dtype, **kwargs)``; ``copy`` is not forwarded
    there.

    Raises
    ------
    ValueError
        If ``namespace`` is None and no namespace can be inferred from
        ``obj``, or if ``namespace`` is not recognized.
    NotImplementedError
        If ``copy`` is False, or ``np._CopyMode.IF_NEEDED`` for NumPy array
        input on NumPy versions that define ``_CopyMode``.
    """
    if namespace is None:
        try:
            xp = array_namespace(obj, _use_compat=False)
        except ValueError as e:
            # TODO: What about lists of arrays?
            # Chain explicitly so the traceback shows the original failure as
            # the cause rather than as an error raised while handling it.
            raise ValueError("A namespace must be specified for asarray() with non-array input") from e
    elif isinstance(namespace, ModuleType):
        xp = namespace
    elif namespace == 'numpy':
        import numpy as xp
    elif namespace == 'cupy':
        import cupy as xp
    else:
        raise ValueError("Unrecognized namespace argument to asarray()")

    _check_device(xp, device)

    # Which ``copy`` values mean "never copy" / "always copy". NumPy's
    # _CopyMode enum is only honoured for NumPy array input, and only when
    # the installed NumPy provides it.
    if _is_numpy_array(obj):
        import numpy as np
        if hasattr(np, '_CopyMode'):
            # Not present in older NumPys
            COPY_FALSE = (False, np._CopyMode.IF_NEEDED)
            COPY_TRUE = (True, np._CopyMode.ALWAYS)
        else:
            COPY_FALSE = (False,)
            COPY_TRUE = (True,)
    else:
        COPY_FALSE = (False,)
        COPY_TRUE = (True,)
    if copy in COPY_FALSE:
        # copy=False is not yet implemented in xp.asarray
        raise NotImplementedError("copy=False is not yet implemented")

    if isinstance(obj, xp.ndarray):
        # A dtype change cannot be done without producing a new array.
        if dtype is not None and obj.dtype != dtype:
            copy = True
        if copy in COPY_TRUE:
            return xp.array(obj, copy=True, dtype=dtype)
        return obj

    return xp.asarray(obj, dtype=dtype, **kwargs)
# np.reshape names this argument 'newshape' rather than 'shape', so the
# wrapper always passes the target shape positionally.
def reshape(x: ndarray,
            /,
            shape: Tuple[int, ...],
            xp, copy: Optional[bool] = None,
            **kwargs) -> ndarray:
    """Array API ``reshape`` with the extra ``copy`` argument.

    * ``copy=None``: defer entirely to ``xp.reshape``.
    * ``copy=True``: reshape a fresh ``x.copy()``, so the result never
      shares memory with ``x``.
    * ``copy=False``: take ``x.view()`` and assign its ``shape`` attribute.
      This never copies; NumPy raises if the view cannot take that shape.
    """
    if copy is False:
        in_place = x.view()
        in_place.shape = shape
        return in_place
    source = x.copy() if copy is True else x
    return xp.reshape(source, shape, **kwargs)
# sort() and argsort() gain a ``descending`` keyword, and the array API's
# ``stable`` flag replaces NumPy's ``kind``.
def argsort(
    x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True,
    **kwargs,
) -> ndarray:
    """Array API ``argsort``: indices that sort ``x`` along ``axis``.

    Extra keyword arguments are forwarded to ``xp.argsort``.
    """
    # ``kind`` is only injected for stable sorts: numpy.sort defaults to
    # kind='quicksort' whereas cupy.sort uses kind=None, so the default
    # differs between libraries.
    if stable:
        kwargs['kind'] = "stable"

    if not descending:
        return xp.argsort(x, axis=axis, **kwargs)

    # No native descending sort exists. Simply flipping argsort(x) would not
    # keep equal elements in their original relative order, so argsort the
    # flipped input, flip that order back, then map every index i (which
    # refers to the flipped array) to n - 1 - i in the original array.
    order_of_flipped = xp.argsort(xp.flip(x, axis=axis), axis=axis, **kwargs)
    order = xp.flip(order_of_flipped, axis=axis)
    # flip()/argsort() above have already rejected an invalid axis.
    last_index = x.shape[axis] - 1
    return last_index - order
def sort(
    x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True,
    **kwargs,
) -> ndarray:
    """Array API ``sort``: sorted copy of ``x`` along ``axis``.

    ``descending=True`` reverses the ascending result along ``axis``.
    Extra keyword arguments are forwarded to ``xp.sort``.
    """
    # ``kind`` is only injected for stable sorts: numpy.sort defaults to
    # kind='quicksort' whereas cupy.sort uses kind=None.
    if stable:
        kwargs['kind'] = "stable"
    ascending = xp.sort(x, axis=axis, **kwargs)
    return xp.flip(ascending, axis=axis) if descending else ascending
389# sum() and prod() should always upcast when dtype=None
def sum(
    x: ndarray,
    /,
    xp,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    dtype: Optional[Dtype] = None,
    keepdims: bool = False,
    **kwargs,
) -> ndarray:
    """Array API ``sum``.

    When ``dtype`` is None, float32 input is accumulated as float64 and
    complex64 input as complex128; an explicit ``dtype`` is used as given.
    """
    if dtype is None:
        # ``xp.sum`` already upcasts integers, but not floats or complexes.
        widenings = ((xp.float32, xp.float64), (xp.complex64, xp.complex128))
        for narrow, wide in widenings:
            if x.dtype == narrow:
                dtype = wide
                break
    return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs)
def prod(
    x: ndarray,
    /,
    xp,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    dtype: Optional[Dtype] = None,
    keepdims: bool = False,
    **kwargs,
) -> ndarray:
    """Array API ``prod``.

    When ``dtype`` is None, float32 input is accumulated as float64 and
    complex64 input as complex128; an explicit ``dtype`` is used as given.
    """
    if dtype is None:
        widenings = ((xp.float32, xp.float64), (xp.complex64, xp.complex128))
        for narrow, wide in widenings:
            if x.dtype == narrow:
                dtype = wide
                break
    return xp.prod(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs)
425# ceil, floor, and trunc return integers for integer inputs
def ceil(x: ndarray, /, xp, **kwargs) -> ndarray:
    """Array API ``ceil``: integer-dtype input is returned as-is (the same
    object); anything else goes through ``xp.ceil``."""
    return x if xp.issubdtype(x.dtype, xp.integer) else xp.ceil(x, **kwargs)
def floor(x: ndarray, /, xp, **kwargs) -> ndarray:
    """Array API ``floor``: integer-dtype input is returned as-is (the same
    object); anything else goes through ``xp.floor``."""
    return x if xp.issubdtype(x.dtype, xp.integer) else xp.floor(x, **kwargs)
def trunc(x: ndarray, /, xp, **kwargs) -> ndarray:
    """Array API ``trunc``: integer-dtype input is returned as-is (the same
    object); anything else goes through ``xp.trunc``."""
    return x if xp.issubdtype(x.dtype, xp.integer) else xp.trunc(x, **kwargs)
442# linear algebra functions
def matmul(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray:
    """Array API ``matmul``: forwards both operands and any extra keyword
    arguments to ``xp.matmul``."""
    product = xp.matmul(x1, x2, **kwargs)
    return product
# matrix_transpose swaps only the last two axes, unlike transpose().
def matrix_transpose(x: ndarray, /, xp) -> ndarray:
    """Array API ``matrix_transpose``.

    Raises ValueError for inputs with fewer than two dimensions.
    """
    if x.ndim >= 2:
        return xp.swapaxes(x, -1, -2)
    raise ValueError("x must be at least 2-dimensional for matrix_transpose")
def tensordot(
    x1: ndarray,
    x2: ndarray,
    /,
    xp,
    *,
    axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2,
    **kwargs,
) -> ndarray:
    """Array API ``tensordot``: ``axes`` is keyword-only (default 2) and is
    forwarded, with any extra kwargs, to ``xp.tensordot``."""
    contracted = xp.tensordot(x1, x2, axes=axes, **kwargs)
    return contracted
def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray:
    """
    Array API ``vecdot``: dot product of ``x1`` and ``x2`` along ``axis``,
    broadcasting over all other dimensions.

    As required by the array API specification, complex ``x1`` is
    conjugated: the result is ``sum(conj(x1_i) * x2_i)``. Real and integer
    inputs are used unchanged.

    Raises ValueError if the inputs differ in size along ``axis``. Before the
    comparison, the lower-rank shape is left-padded with ones.
    """
    ndim = max(x1.ndim, x2.ndim)
    # Left-pad the shapes with 1s (as broadcasting does) so that ``axis``
    # names the same dimension in both inputs.
    x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
    x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
    if x1_shape[axis] != x2_shape[axis]:
        raise ValueError("x1 and x2 must have the same size along the given axis")

    # Prefer xp.broadcast_tensors when the namespace has it, otherwise
    # fall back to xp.broadcast_arrays.
    if hasattr(xp, 'broadcast_tensors'):
        _broadcast = xp.broadcast_tensors
    else:
        _broadcast = xp.broadcast_arrays

    x1_, x2_ = _broadcast(x1, x2)
    x1_ = xp.moveaxis(x1_, axis, -1)
    x2_ = xp.moveaxis(x2_, axis, -1)

    if xp.issubdtype(x1_.dtype, xp.complexfloating):
        # The spec defines vecdot with the complex conjugate of x1. Only
        # conjugate complex data so other dtypes pass through untouched.
        x1_ = xp.conj(x1_)

    # (..., 1, n) @ (..., n, 1) -> (..., 1, 1): one dot product per batch entry.
    res = x1_[..., None, :] @ x2_[..., None]
    return res[..., 0, 0]
482# isdtype is a new function in the 2022.12 array API specification.
def isdtype(
    dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], xp,
    *, _tuple=True, # Disallow nested tuples
) -> bool:
    """
    Report whether ``dtype`` belongs to the data type ``kind``.

    ``kind`` may be a dtype, one of the spec's kind strings ('bool',
    'signed integer', 'unsigned integer', 'integral', 'real floating',
    'complex floating', 'numeric'), or a tuple of these; a tuple matches if
    any of its entries does. Unknown kind strings raise ValueError.

    Note that outside of this function, this compat library does not yet fully
    support complex numbers.

    See
    https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
    for more details
    """
    if isinstance(kind, tuple) and _tuple:
        return any(isdtype(dtype, k, xp, _tuple=False) for k in kind)

    if not isinstance(kind, str):
        # Plain equality also accepts things the spec does not require, such
        # as isdtype(np.float64, float) or isdtype(np.int64, 'l'); the
        # numpy.array_api implementation is stricter.
        return dtype == kind

    if kind == 'bool':
        return dtype == xp.bool_

    # Kind string -> name of the abstract scalar type in ``xp``; resolved
    # with getattr only after a match so no unrelated attribute is touched.
    abstract_name = {
        'signed integer': 'signedinteger',
        'unsigned integer': 'unsignedinteger',
        'integral': 'integer',
        'real floating': 'floating',
        'complex floating': 'complexfloating',
        'numeric': 'number',
    }.get(kind)
    if abstract_name is None:
        raise ValueError(f"Unrecognized data type kind: {kind!r}")
    return xp.issubdtype(dtype, getattr(xp, abstract_name))
# Public API of this module (order preserved).
__all__ = [
    # creation functions
    'arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
    'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
    # unique family
    'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
    'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
    # renamed / adjusted keyword arguments
    'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
    'sort', 'sum', 'prod', 'ceil', 'floor', 'trunc',
    # linear algebra and dtype introspection
    'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
]