Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/scipy/linalg/_decomp_lu.py: 20%

45 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-12 06:31 +0000

1"""LU decomposition functions.""" 

2 

3from warnings import warn 

4 

5from numpy import asarray, asarray_chkfinite 

6 

7# Local imports 

8from ._misc import _datacopied, LinAlgWarning 

9from .lapack import get_lapack_funcs 

10from ._flinalg_py import get_flinalg_funcs 

11 

12__all__ = ['lu', 'lu_solve', 'lu_factor'] 

13 

14 

def lu_factor(a, overwrite_a=False, check_finite=True):
    """
    Compute pivoted LU decomposition of a matrix.

    The decomposition is::

        A = P L U

    where P is a permutation matrix, L lower triangular with unit
    diagonal elements, and U upper triangular.

    Parameters
    ----------
    a : (M, N) array_like
        Matrix to decompose
    overwrite_a : bool, optional
        Whether to overwrite data in A (may increase performance)
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    lu : (M, N) ndarray
        Matrix containing U in its upper triangle, and L in its lower triangle.
        The unit diagonal elements of L are not stored.
    piv : (K,) ndarray
        Pivot indices representing the permutation matrix P:
        row i of matrix was interchanged with row piv[i].
        K = min(M, N).

    Raises
    ------
    ValueError
        If `check_finite` is True and `a` contains infinities or NaNs, or
        if the underlying ``getrf`` routine reports an illegal argument.

    Warns
    -----
    LinAlgWarning
        If a diagonal element of U is exactly zero, i.e. the matrix is
        singular. The factorization is still returned.

    See Also
    --------
    lu : gives lu factorization in more user-friendly format
    lu_solve : solve an equation system using the LU factorization of a matrix

    Notes
    -----
    This is a wrapper to the ``*GETRF`` routines from LAPACK. Unlike
    :func:`lu`, it outputs the L and U factors into a single array
    and returns pivot indices instead of a permutation matrix.

    A two-dimensional input with a zero-length axis is not passed to LAPACK;
    an empty factorization (a copy of the input, with the input's dtype, and
    an empty ``int32`` pivot array) is returned instead.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import lu_factor
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> lu, piv = lu_factor(A)
    >>> piv
    array([2, 2, 3, 3], dtype=int32)

    Convert LAPACK's ``piv`` array to NumPy index and test the permutation

    >>> piv_py = [2, 0, 3, 1]
    >>> L, U = np.tril(lu, k=-1) + np.eye(4), np.triu(lu)
    >>> np.allclose(A[piv_py] - L @ U, np.zeros((4, 4)))
    True
    """
    a1 = asarray_chkfinite(a) if check_finite else asarray(a)

    # LAPACK requires a leading dimension of at least one, so an empty matrix
    # cannot be handed to getrf. There is nothing to factor anyway: return an
    # empty result with K = min(M, N) = 0 pivots.
    if a1.ndim == 2 and a1.size == 0:
        return a1.copy(), asarray([], dtype='int32')

    # If the conversion above already produced a private copy, LAPACK may
    # safely work in place regardless of what the caller requested.
    overwrite_a = overwrite_a or _datacopied(a1, a)
    getrf, = get_lapack_funcs(('getrf',), (a1,))
    lu, piv, info = getrf(a1, overwrite_a=overwrite_a)
    if info < 0:
        raise ValueError('illegal value in %dth argument of '
                         'internal getrf (lu_factor)' % -info)
    if info > 0:
        # A positive info is not fatal: the factors are complete, but U has
        # an exact zero on its diagonal.
        warn("Diagonal number %d is exactly zero. Singular matrix." % info,
             LinAlgWarning, stacklevel=2)
    return lu, piv

87 

88 

def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
    """Solve an equation system, a x = b, given the LU factorization of a.

    Parameters
    ----------
    lu_and_piv : tuple
        The pair ``(lu, piv)``: factorization of the coefficient matrix a,
        as given by lu_factor
    b : array
        Right-hand side
    trans : {0, 1, 2}, optional
        Type of system to solve:

        =====  =========
        trans  system
        =====  =========
        0      a x   = b
        1      a^T x = b
        2      a^H x = b
        =====  =========
    overwrite_b : bool, optional
        Whether to overwrite data in b (may increase performance)
    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    x : array
        Solution to the system

    Raises
    ------
    ValueError
        If the first dimensions of ``lu`` and `b` differ, if `check_finite`
        is True and `b` contains infinities or NaNs, or if the underlying
        ``getrs`` routine reports an illegal argument.

    See Also
    --------
    lu_factor : LU factorize a matrix

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import lu_factor, lu_solve
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> b = np.array([1, 1, 1, 1])
    >>> lu, piv = lu_factor(A)
    >>> x = lu_solve((lu, piv), b)
    >>> np.allclose(A @ x - b, np.zeros((4,)))
    True

    """
    (lu, piv) = lu_and_piv
    b1 = asarray_chkfinite(b) if check_finite else asarray(b)

    # If the conversion above already produced a private copy, LAPACK may
    # safely work in place regardless of what the caller requested.
    overwrite_b = overwrite_b or _datacopied(b1, b)
    if lu.shape[0] != b1.shape[0]:
        raise ValueError("Shapes of lu {} and b {} are incompatible"
                         .format(lu.shape, b1.shape))

    getrs, = get_lapack_funcs(('getrs',), (lu, b1))
    x, info = getrs(lu, piv, b1, trans=trans, overwrite_b=overwrite_b)
    if info != 0:
        # The routine invoked here is getrs, so name it in the message.
        raise ValueError('illegal value in %dth argument of '
                         'internal getrs (lu_solve)' % -info)
    return x

152 

153 

def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
    """
    Compute pivoted LU decomposition of a matrix.

    The decomposition is::

        A = P L U

    where P is a permutation matrix, L lower triangular with unit
    diagonal elements, and U upper triangular.

    Parameters
    ----------
    a : (M, N) array_like
        Array to decompose
    permute_l : bool, optional
        Perform the multiplication P*L (Default: do not permute)
    overwrite_a : bool, optional
        Whether to overwrite data in a (may improve performance)
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    **(If permute_l == False)**

    p : (M, M) ndarray
        Permutation matrix
    l : (M, K) ndarray
        Lower triangular or trapezoidal matrix with unit diagonal.
        K = min(M, N)
    u : (K, N) ndarray
        Upper triangular or trapezoidal matrix

    **(If permute_l == True)**

    pl : (M, K) ndarray
        Permuted L matrix.
        K = min(M, N)
    u : (K, N) ndarray
        Upper triangular or trapezoidal matrix

    Raises
    ------
    ValueError
        If `a` is not two-dimensional, or if the underlying routine reports
        an illegal argument.

    Notes
    -----
    This is a LU factorization routine written for SciPy.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import lu
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> p, l, u = lu(A)
    >>> np.allclose(A - p @ l @ u, np.zeros((4, 4)))
    True

    """
    to_array = asarray_chkfinite if check_finite else asarray
    matrix = to_array(a)
    if matrix.ndim != 2:
        raise ValueError('expected matrix')

    # A freshly made copy can always be overwritten, whatever the caller asked.
    overwrite_a = overwrite_a or _datacopied(matrix, a)

    (lu_kernel,) = get_flinalg_funcs(('lu',), (matrix,))
    perm, lower, upper, info = lu_kernel(matrix, permute_l=permute_l,
                                         overwrite_a=overwrite_a)
    if info < 0:
        raise ValueError('illegal value in %dth argument of '
                         'internal lu.getrf' % -info)

    # With permute_l the permutation is already folded into the lower factor,
    # so only two arrays are handed back.
    return (lower, upper) if permute_l else (perm, lower, upper)