Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/scipy/_lib/array_api_compat/common/_linalg.py: 41%

91 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-03 06:39 +0000

1from __future__ import annotations 

2 

3from typing import TYPE_CHECKING, NamedTuple 

4if TYPE_CHECKING: 

5 from typing import Literal, Optional, Tuple, Union 

6 from ._typing import ndarray 

7 

8import math 

9 

10import numpy as np 

11if np.__version__[0] == "2": 

12 from numpy.lib.array_utils import normalize_axis_tuple 

13else: 

14 from numpy.core.numeric import normalize_axis_tuple 

15 

16from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype 

17from .._internal import get_xp 

18 

19# These are in the main NumPy namespace but not in numpy.linalg 

def cross(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1, **kwargs) -> ndarray:
    """Cross product of ``x1`` and ``x2`` along ``axis`` (default: last axis).

    Thin wrapper over ``xp.cross`` that exposes ``axis`` as a keyword-only
    argument; any other keyword arguments are forwarded unchanged.
    """
    forwarded = dict(kwargs, axis=axis)
    return xp.cross(x1, x2, **forwarded)

22 

def outer(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray:
    """Outer product of ``x1`` and ``x2``, delegated to ``xp.outer``.

    Keyword arguments are passed through untouched.
    """
    product = xp.outer(x1, x2, **kwargs)
    return product

25 

class EighResult(NamedTuple):
    """Named pair built by ``eigh`` from the ``xp.linalg.eigh`` result."""
    # First element of the xp.linalg.eigh result.
    eigenvalues: ndarray
    # Second element of the xp.linalg.eigh result.
    eigenvectors: ndarray

29 

class QRResult(NamedTuple):
    """Named pair built by ``qr`` from the ``xp.linalg.qr`` result."""
    # First element of the xp.linalg.qr result (the Q factor).
    Q: ndarray
    # Second element of the xp.linalg.qr result (the R factor).
    R: ndarray

33 

class SlogdetResult(NamedTuple):
    """Named pair built by ``slogdet`` from the ``xp.linalg.slogdet`` result."""
    # First element of the xp.linalg.slogdet result: the sign of the determinant.
    sign: ndarray
    # Second element: the log of the absolute value of the determinant.
    logabsdet: ndarray

37 

class SVDResult(NamedTuple):
    """Named triple built by ``svd`` from the ``xp.linalg.svd`` result."""
    # Fields are filled positionally from the (U, S, Vh) tuple that
    # xp.linalg.svd returns.
    U: ndarray
    S: ndarray
    Vh: ndarray

42 

43# These functions are the same as their NumPy counterparts except they return 

44# a namedtuple. 

def eigh(x: ndarray, /, xp, **kwargs) -> EighResult:
    """Eigendecomposition of ``x`` via ``xp.linalg.eigh``, as an ``EighResult``.

    Extra keyword arguments are forwarded to ``xp.linalg.eigh``.
    """
    # _make builds the named tuple from any iterable of the right length.
    return EighResult._make(xp.linalg.eigh(x, **kwargs))

47 

def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced',
       **kwargs) -> QRResult:
    """QR factorization of ``x`` via ``xp.linalg.qr``, as a ``QRResult``.

    ``mode`` is keyword-only and defaults to ``'reduced'``; remaining keyword
    arguments are forwarded to ``xp.linalg.qr``.
    """
    factors = xp.linalg.qr(x, mode=mode, **kwargs)
    return QRResult._make(factors)

51 

def slogdet(x: ndarray, /, xp, **kwargs) -> SlogdetResult:
    """Sign and log-abs-determinant of ``x``, as a ``SlogdetResult``.

    Delegates to ``xp.linalg.slogdet`` and forwards keyword arguments.
    """
    sign_and_log = xp.linalg.slogdet(x, **kwargs)
    return SlogdetResult._make(sign_and_log)

54 

def svd(x: ndarray, /, xp, *, full_matrices: bool = True, **kwargs) -> SVDResult:
    """Singular value decomposition of ``x``, as an ``SVDResult``.

    ``full_matrices`` (keyword-only, default True) and any extra keyword
    arguments are forwarded to ``xp.linalg.svd``.
    """
    decomposition = xp.linalg.svd(x, full_matrices=full_matrices, **kwargs)
    return SVDResult._make(decomposition)

57 

58# These functions have additional keyword arguments 

59 

60# The upper keyword argument is new from NumPy 

def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray:
    """Cholesky factor of ``x``.

    With ``upper=False`` (default) the factor from ``xp.linalg.cholesky`` is
    returned as is (presumably the lower factor, as in NumPy). With
    ``upper=True`` that factor is transposed over its matrix axes via
    ``matrix_transpose`` and, for complex dtypes, conjugated.

    Keyword arguments are forwarded to ``xp.linalg.cholesky``.
    """
    lower_factor = xp.linalg.cholesky(x, **kwargs)
    if not upper:
        return lower_factor

    upper_factor = get_xp(xp)(matrix_transpose)(lower_factor)
    # Conjugation is only applied for complex dtypes; real results are
    # returned without an extra xp.conj call.
    if get_xp(xp)(isdtype)(upper_factor.dtype, 'complex floating'):
        upper_factor = xp.conj(upper_factor)
    return upper_factor

69 

70# The rtol keyword argument of matrix_rank() and pinv() is new from NumPy. 

71# Note that it has a different semantic meaning from tol and rcond. 

def matrix_rank(x: ndarray,
                /,
                xp,
                *,
                rtol: Optional[Union[float, ndarray]] = None,
                **kwargs) -> ndarray:
    """Rank of each matrix in ``x``, counted from its singular values.

    Parameters
    ----------
    x : ndarray
        Input with at least two dimensions; the singular values are taken
        from ``svdvals`` and reduced over their last axis.
    xp : module
        Array namespace used for the computation.
    rtol : float or ndarray, optional
        Relative tolerance. Singular values not greater than
        ``S.max(axis=-1) * rtol`` are treated as zero. When None, the
        tolerance is ``S.max(axis=-1) * max(x.shape[-2:]) * eps``.
    **kwargs
        Passed through to ``svdvals``.

    Returns
    -------
    ndarray
        Count of singular values above the tolerance along the last axis.

    Raises
    ------
    xp.linalg.LinAlgError
        If ``x`` has fewer than two dimensions.
    """
    # Unlike xp.linalg.matrix_rank, 1-D input is not accepted here.
    if x.ndim < 2:
        # Report the real dimensionality so 0-D input is not described as
        # 1-D; the text for 1-D input is unchanged.
        raise xp.linalg.LinAlgError(
            f"{x.ndim}-dimensional array given. "
            "Array must be at least two-dimensional")
    S = get_xp(xp)(svdvals)(x, **kwargs)
    S_max = S.max(axis=-1, keepdims=True)
    if rtol is None:
        tol = S_max * max(x.shape[-2:]) * xp.finfo(S.dtype).eps
    else:
        # Unlike xp.linalg.matrix_rank, rtol is relative: it is scaled by the
        # largest singular value. The trailing new axis lets an array-valued
        # rtol broadcast against S_max per matrix.
        tol = S_max * xp.asarray(rtol)[..., xp.newaxis]
    return xp.count_nonzero(S > tol, axis=-1)

90 

def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **kwargs) -> ndarray:
    """Moore-Penrose pseudo-inverse of ``x`` via ``xp.linalg.pinv``.

    ``rtol`` is forwarded as ``rcond``. When it is None, the cutoff is
    ``max(x.shape[-2:]) * eps`` for ``x``'s dtype. Extra keyword arguments
    go to ``xp.linalg.pinv``.
    """
    if rtol is not None:
        rcond = rtol
    else:
        # Unlike xp.linalg.pinv's own default, eps is scaled by the larger of
        # the two trailing (matrix) dimensions.
        rcond = max(x.shape[-2:]) * xp.finfo(x.dtype).eps
    return xp.linalg.pinv(x, rcond=rcond, **kwargs)

97 

98# These functions are new in the array API spec 

99 

def matrix_norm(
    x: ndarray,
    /,
    xp,
    *,
    keepdims: bool = False,
    ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro',
) -> ndarray:
    """Matrix norm over the last two axes of ``x`` (Frobenius by default).

    Leading axes are treated as a batch; ``keepdims`` keeps the two reduced
    axes as size-1 dimensions.
    """
    matrix_axes = (-2, -1)
    return xp.linalg.norm(x, ord=ord, axis=matrix_axes, keepdims=keepdims)

102 

103# svdvals is not in NumPy (but it is in SciPy). It is equivalent to 

104# xp.linalg.svd(compute_uv=False). 

def svdvals(x: ndarray, /, xp) -> Union[ndarray, Tuple[ndarray, ...]]:
    """Singular values of ``x``.

    Equivalent to ``xp.linalg.svd(x, compute_uv=False)``, so the U and Vh
    factors are not requested.
    """
    singular_values = xp.linalg.svd(x, compute_uv=False)
    return singular_values

107 

def vector_norm(
    x: ndarray,
    /,
    xp,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    keepdims: bool = False,
    ord: Optional[Union[int, float]] = 2,
) -> ndarray:
    """Vector norm of ``x`` over ``axis`` (all axes when None).

    xp.linalg.norm switches to a matrix norm when ``axis`` is a 2-tuple, or
    when ``axis`` is None and the input is 2-D. To always get a vector norm,
    the input is reshaped so the reduction happens over one axis:

    * ``axis=None``: ``x`` is flattened (this also lets 0-D input through,
      which xp.linalg.norm does not handle).
    * tuple ``axis``: the reduced axes are moved to the front and merged into
      a single leading axis.
    * int ``axis``: passed through unchanged.

    With ``keepdims=True`` every reduced axis is restored as a size-1
    dimension by reshaping the result.
    """
    if axis is None:
        flat = x.ravel()
        reduce_axis = 0
    elif isinstance(axis, tuple):
        # xp.linalg.norm only accepts a single axis for a vector norm, so the
        # requested axes are moved to the front and collapsed into one.
        reduced = normalize_axis_tuple(axis, x.ndim)
        kept = tuple(d for d in range(x.ndim) if d not in reduced)
        merged_len = math.prod([x.shape[d] for d in axis])
        kept_shape = [x.shape[d] for d in kept]
        flat = xp.transpose(x, axis + kept).reshape((merged_len, *kept_shape))
        reduce_axis = 0
    else:
        flat = x
        reduce_axis = axis

    result = xp.linalg.norm(flat, axis=reduce_axis, ord=ord)
    if not keepdims:
        return result

    # xp.linalg.norm(keepdims=True) cannot be used because of the reshapes
    # above, so the size-1 axes are reinserted by hand.
    out_shape = list(x.shape)
    for d in normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim):
        out_shape[d] = 1
    return xp.reshape(result, tuple(out_shape))

142 

# xp.diagonal and xp.trace operate on the first two axes, whereas these
# functions operate on the last two.

145 

def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray:
    """Diagonal of each matrix in ``x``, taken over its last two axes.

    ``offset`` selects which diagonal is returned; it and any extra keyword
    arguments are forwarded to ``xp.diagonal``.
    """
    return xp.diagonal(x, axis1=-2, axis2=-1, offset=offset, **kwargs)

148 

def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray:
    """Sum along a diagonal of the last two axes of ``x``.

    When ``dtype`` is None, float32 input is summed as float64 and complex64
    input as complex128; any other input dtype is left for ``xp.trace`` to
    decide. ``offset`` and extra keyword arguments are forwarded to
    ``xp.trace``.
    """
    # Both conditions re-check `dtype is None`, so an explicit dtype is never
    # overridden. Short-circuiting means xp.complex64 is only looked up when
    # the input is not float32.
    if dtype is None and x.dtype == xp.float32:
        dtype = xp.float64
    elif dtype is None and x.dtype == xp.complex64:
        dtype = xp.complex128
    diag_sum = xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)
    # Wrap in an array, since xp.trace may hand back a scalar for 2-D input.
    return xp.asarray(diag_sum)

156 

# Public API of this module. It also re-exports matmul, matrix_transpose,
# tensordot and vecdot, which are imported from ._aliases.
__all__ = [
    'cross', 'matmul', 'outer', 'tensordot',
    'EighResult', 'QRResult', 'SlogdetResult', 'SVDResult',
    'eigh', 'qr', 'slogdet', 'svd', 'cholesky', 'matrix_rank', 'pinv',
    'matrix_norm', 'matrix_transpose', 'svdvals', 'vecdot',
    'vector_norm', 'diagonal', 'trace',
]