Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/scipy/_lib/array_api_compat/array_api_compat/common/_linalg.py: 41%
83 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-23 06:43 +0000
1from __future__ import annotations
3from typing import TYPE_CHECKING, NamedTuple
4if TYPE_CHECKING:
5 from typing import Literal, Optional, Sequence, Tuple, Union
6 from ._typing import ndarray
8from numpy.core.numeric import normalize_axis_tuple
10from ._aliases import matmul, matrix_transpose, tensordot, vecdot
11from .._internal import get_xp
13# These are in the main NumPy namespace but not in numpy.linalg
def cross(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1, **kwargs) -> ndarray:
    """Cross product of ``x1`` and ``x2`` along ``axis`` (last axis by default).

    NumPy exposes ``cross`` in its main namespace rather than in
    ``numpy.linalg``, so this simply forwards to ``xp.cross``; any extra
    keyword arguments are passed through unchanged.
    """
    # ``axis`` is a named keyword-only parameter, so it can never also be in
    # ``kwargs``; merging it in here is equivalent to passing it separately.
    options = dict(kwargs, axis=axis)
    return xp.cross(x1, x2, **options)
def outer(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray:
    """Outer product of two vectors, forwarded to ``xp.outer``.

    Like ``cross``, this lives in NumPy's main namespace rather than in
    ``numpy.linalg``; extra keyword arguments are passed through unchanged.
    """
    product = xp.outer(x1, x2, **kwargs)
    return product
class EighResult(NamedTuple):
    """Named result of ``eigh``: ``(eigenvalues, eigenvectors)``."""

    # eigenvalues as returned by the backend's ``linalg.eigh``
    eigenvalues: ndarray
    # eigenvectors as returned by the backend's ``linalg.eigh``
    eigenvectors: ndarray
class QRResult(NamedTuple):
    """Named result of ``qr``: ``(Q, R)``."""

    # first factor returned by the backend's ``linalg.qr``
    Q: ndarray
    # second factor returned by the backend's ``linalg.qr``
    R: ndarray
class SlogdetResult(NamedTuple):
    """Named result of ``slogdet``: ``(sign, logabsdet)``."""

    # sign of the determinant, as returned by the backend
    sign: ndarray
    # log of the absolute value of the determinant, as returned by the backend
    logabsdet: ndarray
class SVDResult(NamedTuple):
    """Named result of ``svd``: ``(U, S, Vh)``."""

    # left factor returned by the backend's ``linalg.svd``
    U: ndarray
    # singular values returned by the backend's ``linalg.svd``
    S: ndarray
    # right factor returned by the backend's ``linalg.svd``
    Vh: ndarray
37# These functions are the same as their NumPy counterparts except they return
38# a namedtuple.
def eigh(x: ndarray, /, xp, **kwargs) -> EighResult:
    """Forward to ``xp.linalg.eigh`` and wrap its pair in an ``EighResult``."""
    return EighResult._make(xp.linalg.eigh(x, **kwargs))
def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced',
       **kwargs) -> QRResult:
    """Forward to ``xp.linalg.qr`` with ``mode`` and wrap the pair in a ``QRResult``."""
    factors = xp.linalg.qr(x, mode=mode, **kwargs)
    return QRResult._make(factors)
def slogdet(x: ndarray, /, xp, **kwargs) -> SlogdetResult:
    """Forward to ``xp.linalg.slogdet`` and wrap the pair in a ``SlogdetResult``."""
    sign_and_logdet = xp.linalg.slogdet(x, **kwargs)
    return SlogdetResult._make(sign_and_logdet)
def svd(x: ndarray, /, xp, *, full_matrices: bool = True, **kwargs) -> SVDResult:
    """Forward to ``xp.linalg.svd`` and wrap the triple in an ``SVDResult``.

    ``full_matrices`` is passed through explicitly (default ``True``).
    """
    factors = xp.linalg.svd(x, full_matrices=full_matrices, **kwargs)
    return SVDResult._make(factors)
52# These functions have additional keyword arguments
# The upper keyword argument is not available in NumPy's cholesky
def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray:
    """Cholesky factor of ``x``; lower-triangular unless ``upper=True``.

    ``xp.linalg.cholesky`` returns the lower factor ``L`` (``x = L @ L^H``).
    The array API ``upper`` keyword, which NumPy lacks, requests the upper
    factor ``U`` with ``x = U^H @ U``, i.e. the *conjugate* transpose of
    ``L``. A plain transpose is only correct for real-valued input.

    Extra keyword arguments are forwarded to ``xp.linalg.cholesky``.
    """
    L = xp.linalg.cholesky(x, **kwargs)
    if not upper:
        return L
    U = get_xp(xp)(matrix_transpose)(L)
    # Conjugation leaves real values unchanged, so it is applied
    # unconditionally. For complex input it turns L^T into the required L^H.
    return xp.conj(U)
# The rtol keyword argument of matrix_rank() and pinv() does not exist in NumPy.
# Note that it has a different semantic meaning from NumPy's tol and rcond.
def matrix_rank(x: ndarray,
                /,
                xp,
                *,
                rtol: Optional[Union[float, ndarray]] = None,
                **kwargs) -> ndarray:
    """Count the singular values of each matrix in ``x`` above a tolerance.

    The tolerance is always relative to the largest singular value of each
    matrix:

    * ``rtol is None``: ``largest * max(M, N) * eps(dtype)``;
    * otherwise:        ``largest * rtol`` (``rtol`` may be per-matrix).

    Unlike ``xp.linalg.matrix_rank``, 1-D input is rejected, and an explicit
    tolerance is scaled by the largest singular value.

    Raises ``xp.linalg.LinAlgError`` if ``x`` has fewer than two dimensions.
    """
    if x.ndim < 2:
        raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
    singular_values = xp.linalg.svd(x, compute_uv=False, **kwargs)
    largest = singular_values.max(axis=-1, keepdims=True)
    if rtol is not None:
        # the trailing new axis lines rtol up with ``largest``'s keepdims axis
        threshold = largest*xp.asarray(rtol)[..., xp.newaxis]
    else:
        threshold = largest * max(x.shape[-2:]) * xp.finfo(singular_values.dtype).eps
    return xp.count_nonzero(singular_values > threshold, axis=-1)
def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **kwargs) -> ndarray:
    """Moore-Penrose pseudo-inverse of ``x`` via ``xp.linalg.pinv``.

    ``rtol`` is forwarded as ``rcond``. When omitted it defaults to
    ``max(M, N) * eps(x.dtype)``. This differs from ``xp.linalg.pinv``'s own
    default, which is not scaled by ``max(M, N)``.
    """
    rcond = rtol if rtol is not None else max(x.shape[-2:]) * xp.finfo(x.dtype).eps
    return xp.linalg.pinv(x, rcond=rcond, **kwargs)
89# These functions are new in the array API spec
def matrix_norm(x: ndarray, /, xp, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> ndarray:
    """Matrix norm over the last two axes of ``x`` (Frobenius by default).

    Always reduces over axes ``(-2, -1)``, so stacked input yields one norm
    per matrix.
    """
    return xp.linalg.norm(x, ord=ord, axis=(-2, -1), keepdims=keepdims)
94# svdvals is not in NumPy (but it is in SciPy). It is equivalent to
95# xp.linalg.svd(compute_uv=False).
def svdvals(x: ndarray, /, xp) -> Union[ndarray, Tuple[ndarray, ...]]:
    """Singular values of ``x`` (no U/Vh), via ``xp.linalg.svd(compute_uv=False)``."""
    singular_values = xp.linalg.svd(x, compute_uv=False)
    return singular_values
def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> ndarray:
    """Vector norm of ``x`` over ``axis`` (all elements when ``axis`` is None).

    ``axis`` may be an int or a tuple of any number of axes. With
    ``keepdims=True`` each reduced axis is kept with size 1, relative to the
    ORIGINAL shape of ``x``.
    """
    # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
    # when axis=None and the input is 2-D, so to force a vector norm, we make
    # it so the input is 1-D (for axis=None), or reshape so that norm is done
    # on a single dimension. The reshaped array is kept in ``_x`` so that
    # ``x.shape``/``x.ndim`` still describe the input in the keepdims step.
    if axis is None:
        # Note: xp.linalg.norm() doesn't handle 0-D arrays
        _x = x.ravel()
        _axis = 0
    elif isinstance(axis, tuple):
        # Note: The axis argument supports any number of axes, whereas
        # xp.linalg.norm() only supports a single axis for vector norm.
        # Move the reduced axes to the front and flatten them into one axis.
        normalized_axis = normalize_axis_tuple(axis, x.ndim)
        rest = tuple(i for i in range(x.ndim) if i not in normalized_axis)
        newshape = normalized_axis + rest
        _x = xp.transpose(x, newshape).reshape(
            (xp.prod([x.shape[i] for i in normalized_axis], dtype=int),
             *[x.shape[i] for i in rest]))
        _axis = 0
    else:
        _x = x
        _axis = axis

    res = xp.linalg.norm(_x, axis=_axis, ord=ord)

    if keepdims:
        # We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks
        # above to avoid matrix norm logic. Rebuild the kept shape from the
        # original (un-reshaped) input.
        shape = list(x.shape)
        reduced = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim)
        for i in reduced:
            shape[i] = 1
        res = xp.reshape(res, tuple(shape))

    return res
# xp.diagonal and xp.trace operate on the first two axes, whereas these
# operate on the last two.
def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray:
    """Diagonal (shifted by ``offset``) taken over the last two axes of ``x``."""
    return xp.diagonal(x, axis1=-2, axis2=-1, offset=offset, **kwargs)
def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray:
    """Sum along the ``offset`` diagonal of the last two axes of ``x``.

    When ``dtype`` is not given, single-precision input is accumulated in
    double precision (float32 -> float64, complex64 -> complex128). For any
    other input dtype, ``dtype=None`` is passed through to ``xp.trace``.
    The result is wrapped with ``xp.asarray`` so it is always an array.
    """
    accumulate_as = dtype
    if accumulate_as is None and x.dtype == xp.float32:
        accumulate_as = xp.float64
    elif accumulate_as is None and x.dtype == xp.complex64:
        accumulate_as = xp.complex128
    total = xp.trace(x, offset=offset, dtype=accumulate_as, axis1=-2, axis2=-1, **kwargs)
    return xp.asarray(total)
# Public API of this module. matmul, tensordot, matrix_transpose and vecdot
# are re-exported from ._aliases; everything else is defined above.
__all__ = [
    'cross',
    'matmul',
    'outer',
    'tensordot',
    'EighResult',
    'QRResult',
    'SlogdetResult',
    'SVDResult',
    'eigh',
    'qr',
    'slogdet',
    'svd',
    'cholesky',
    'matrix_rank',
    'pinv',
    'matrix_norm',
    'matrix_transpose',
    'svdvals',
    'vecdot',
    'vector_norm',
    'diagonal',
    'trace',
]