Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/scipy/_lib/array_api_compat/numpy/linalg.py: 64%
50 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-03 06:39 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-03 06:39 +0000
1from numpy.linalg import * # noqa: F403
2from numpy.linalg import __all__ as linalg_all
3import numpy as _np
5from ..common import _linalg
6from .._internal import get_xp
8# These functions are in both the main and linalg namespaces
9from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401
11import numpy as np
# Each generic helper in ``_linalg`` takes an ``xp`` namespace argument;
# ``get_xp(np)`` yields a decorator that binds NumPy into it. Build that
# decorator once and apply it to every helper re-exported below.
_bind_numpy = get_xp(np)

# Result types are re-exported unchanged from the common implementation.
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult

# Generic implementations specialised to NumPy.
cross = _bind_numpy(_linalg.cross)
outer = _bind_numpy(_linalg.outer)
eigh = _bind_numpy(_linalg.eigh)
qr = _bind_numpy(_linalg.qr)
slogdet = _bind_numpy(_linalg.slogdet)
svd = _bind_numpy(_linalg.svd)
cholesky = _bind_numpy(_linalg.cholesky)
matrix_rank = _bind_numpy(_linalg.matrix_rank)
pinv = _bind_numpy(_linalg.pinv)
matrix_norm = _bind_numpy(_linalg.matrix_norm)
svdvals = _bind_numpy(_linalg.svdvals)
diagonal = _bind_numpy(_linalg.diagonal)
trace = _bind_numpy(_linalg.trace)

# The decorator was only needed during module setup; keep it out of the
# module namespace.
del _bind_numpy
# Note: unlike np.linalg.solve, the array API solve() only accepts x2 as a
# vector when it is exactly 1-dimensional. All other cases treat x2 as a stack
# of matrices. The np.linalg.solve behavior of allowing stacks of both
# matrices and vectors is ambiguous; cf.
# https://github.com/numpy/numpy/issues/15349 and
# https://github.com/data-apis/array-api/issues/285.
#
# To work around this, the code below is the code from np.linalg.solve, except
# that it only calls solve1 in the exactly 1-D case.
#
# This code is here instead of in common because it is NumPy specific. Also
# note that CuPy's solve() does not currently support broadcasting (see
# https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43).
def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray:
    """Solve ``x1 @ x = x2`` for ``x`` using array API semantics.

    Parameters
    ----------
    x1 : ndarray
        A square matrix or a stack of square matrices, shape ``(..., M, M)``.
    x2 : ndarray
        Right-hand side. It is treated as a single vector only when it is
        exactly 1-dimensional. Otherwise it is treated as a (stack of)
        matrices that broadcasts against ``x1``.

    Returns
    -------
    ndarray
        The solution. It is cast to the common result dtype of ``x1`` and
        ``x2`` and passed through the wrapper that ``_makearray`` returned
        for ``x2``.

    Raises
    ------
    numpy.linalg.LinAlgError
        If ``x1`` is not (a stack of) 2-D square matrices, or is singular.
    """
    # Newer NumPy releases keep these private helpers in
    # numpy.linalg._linalg; older ones keep them in numpy.linalg.linalg.
    try:
        from numpy.linalg._linalg import (
            _makearray, _assert_stacked_2d, _assert_stacked_square,
            _commonType, isComplexType, _raise_linalgerror_singular
        )
    except ImportError:
        from numpy.linalg.linalg import (
            _makearray, _assert_stacked_2d, _assert_stacked_square,
            _commonType, isComplexType, _raise_linalgerror_singular
        )
    from numpy.linalg import _umath_linalg

    x1, _ = _makearray(x1)
    _assert_stacked_2d(x1)
    _assert_stacked_square(x1)
    x2, wrap = _makearray(x2)
    # t: dtype to compute in; result_t: dtype of the returned array.
    t, result_t = _commonType(x1, x2)

    # This part differs from np.linalg.solve: only an exactly 1-D x2 uses the
    # matrix-vector gufunc; any other x2 is a (stack of) matrices.
    gufunc = _umath_linalg.solve1 if x2.ndim == 1 else _umath_linalg.solve

    # Pick the complex or real double-precision kernel based on the common
    # type. Complex inputs take the 'DD->D' path.
    signature = 'DD->D' if isComplexType(t) else 'dd->d'
    # The gufunc reports a singular x1 through the floating-point "invalid"
    # flag. errstate routes that flag to _raise_linalgerror_singular and
    # silences the other floating-point warnings.
    with _np.errstate(call=_raise_linalgerror_singular, invalid='call',
                      over='ignore', divide='ignore', under='ignore'):
        r = gufunc(x1, x2, signature=signature)

    # Cast back from the computation dtype to the result dtype (no copy when
    # they already match).
    return wrap(r.astype(result_t, copy=False))
# These functions are completely new here. If the library already has them
# (i.e., NumPy 2.0), use the library version instead of our wrapper.
if hasattr(np.linalg, 'vector_norm'):
    vector_norm = np.linalg.vector_norm
else:
    vector_norm = get_xp(np)(_linalg.vector_norm)

# Public API: everything numpy.linalg exports, plus the array API helpers from
# common, plus this module's array-API-compliant solve(). The source lists
# overlap (for example, numpy.linalg has its own 'solve'). De-duplicate them
# while keeping first-seen order so that __all__ lists each name once.
__all__ = list(dict.fromkeys(linalg_all + _linalg.__all__ + ['solve']))

# Remove setup-only names so they do not leak into this module's namespace.
del get_xp
del np
del linalg_all
del _linalg