Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/scipy/_lib/array_api_compat/numpy/linalg.py: 64%

50 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-03 06:39 +0000

1from numpy.linalg import * # noqa: F403 

2from numpy.linalg import __all__ as linalg_all 

3import numpy as _np 

4 

5from ..common import _linalg 

6from .._internal import get_xp 

7 

8# These functions are in both the main and linalg namespaces 

9from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401 

10 

11import numpy as np 

12 

# Bind the backend-agnostic implementations from ..common._linalg to NumPy.
# get_xp(np) returns a decorator; presumably it supplies the NumPy namespace
# to each implementation (see .._internal.get_xp to confirm). It is created
# once and reused for every wrapper below.
_bind_numpy = get_xp(np)

cross = _bind_numpy(_linalg.cross)
outer = _bind_numpy(_linalg.outer)

# Result containers are re-exported unchanged from the common implementation.
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult

eigh = _bind_numpy(_linalg.eigh)
qr = _bind_numpy(_linalg.qr)
slogdet = _bind_numpy(_linalg.slogdet)
svd = _bind_numpy(_linalg.svd)
cholesky = _bind_numpy(_linalg.cholesky)
matrix_rank = _bind_numpy(_linalg.matrix_rank)
pinv = _bind_numpy(_linalg.pinv)
matrix_norm = _bind_numpy(_linalg.matrix_norm)
svdvals = _bind_numpy(_linalg.svdvals)
diagonal = _bind_numpy(_linalg.diagonal)
trace = _bind_numpy(_linalg.trace)

# Temporary helper; keep it out of the module namespace.
del _bind_numpy

30 

# Note: unlike np.linalg.solve, the array API solve() only accepts x2 as a
# vector when it is exactly 1-dimensional. All other cases treat x2 as a stack
# of matrices. The np.linalg.solve behavior of allowing stacks of both
# matrices and vectors is ambiguous, cf.
# https://github.com/numpy/numpy/issues/15349 and
# https://github.com/data-apis/array-api/issues/285.
#
# To work around this, the code below mirrors np.linalg.solve, except that it
# only calls the vector kernel (solve1) in the exactly-1-D case.
#
# This code lives here instead of in common because it is numpy specific. Also
# note that CuPy's solve() does not currently support broadcasting (see
# https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43).
def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray:
    """Solve ``x1 @ x = x2`` for ``x`` using array API semantics.

    Parameters
    ----------
    x1
        A square matrix, or a stack of square matrices.
    x2
        Treated as a single vector only when it is exactly 1-D. Any other
        rank is treated as a matrix, or a stack of matrices.

    Returns
    -------
    The solution, cast to the common result type of ``x1`` and ``x2`` and
    re-wrapped in ``x2``'s array type.

    Raises
    ------
    numpy.linalg.LinAlgError
        Raised by numpy's private helpers when ``x1`` is not a (stack of)
        square 2-D matrices, or when it is singular.
    """
    # The private helpers live in numpy.linalg._linalg on newer numpy and in
    # numpy.linalg.linalg on older releases; try the newer location first.
    try:
        from numpy.linalg._linalg import (
            _assert_stacked_2d,
            _assert_stacked_square,
            _commonType,
            _makearray,
            _raise_linalgerror_singular,
            isComplexType,
        )
    except ImportError:
        from numpy.linalg.linalg import (
            _assert_stacked_2d,
            _assert_stacked_square,
            _commonType,
            _makearray,
            _raise_linalgerror_singular,
            isComplexType,
        )
    from numpy.linalg import _umath_linalg

    lhs = _makearray(x1)[0]
    _assert_stacked_2d(lhs)
    _assert_stacked_square(lhs)
    rhs, wrap = _makearray(x2)
    work_type, out_type = _commonType(lhs, rhs)

    # The only departure from np.linalg.solve: an exactly-1-D right-hand side
    # is a vector, and anything else is a (stack of) matrices.
    kernel = _umath_linalg.solve1 if rhs.ndim == 1 else _umath_linalg.solve

    # Compute in complex128 for complex inputs and float64 otherwise. This
    # matters now that complex dtypes are part of the array API spec.
    if isComplexType(work_type):
        signature = 'DD->D'
    else:
        signature = 'dd->d'

    # Route the gufunc's "invalid" floating-point signal to numpy's singular-
    # matrix callback, and silence the other FP warnings as np.linalg.solve
    # does.
    with _np.errstate(call=_raise_linalgerror_singular, invalid='call',
                      over='ignore', divide='ignore', under='ignore'):
        solution = kernel(lhs, rhs, signature=signature)

    return wrap(solution.astype(out_type, copy=False))

77 

# vector_norm is new in the array API. Prefer the library's own implementation
# when numpy provides one (numpy 2.0+); otherwise fall back to the common
# wrapper bound to NumPy.
vector_norm = (
    np.linalg.vector_norm
    if hasattr(np.linalg, 'vector_norm')
    else get_xp(np)(_linalg.vector_norm)
)

84 

# Public names are everything numpy.linalg exports, plus the array API
# additions from ..common._linalg, plus our own solve(). These sources
# overlap: numpy.linalg.__all__ already lists 'solve', and newer numpy also
# exports several of the array API names. Duplicates are dropped while the
# first-seen order is kept.
__all__ = list(dict.fromkeys(linalg_all + _linalg.__all__ + ['solve']))

# Remove import-time helpers from the module namespace. ``_np`` is
# deliberately NOT deleted, because solve() looks it up at call time.
del get_xp
del np
del linalg_all
del _linalg