Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/scipy/linalg/_decomp_lu.py: 13%

91 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-02-14 06:37 +0000

1"""LU decomposition functions.""" 

2 

3from warnings import warn 

4 

5from numpy import asarray, asarray_chkfinite 

6import numpy as np 

7from itertools import product 

8 

9# Local imports 

10from ._misc import _datacopied, LinAlgWarning 

11from .lapack import get_lapack_funcs 

12from ._decomp_lu_cython import lu_dispatcher 

13 

# For every NumPy type character, the LAPACK type characters among 'f', 'd',
# 'F', 'D' (kept in that order) to which it can be safely cast. An empty
# string means no LAPACK-compatible type is reachable by a safe cast.
lapack_cast_dict = {
    typechar: ''.join(target for target in 'fdFD'
                      if np.can_cast(typechar, target))
    for typechar in np.typecodes['All']
}

16 

# Public API exported by ``from ... import *``.
__all__ = [
    'lu',
    'lu_solve',
    'lu_factor',
]

18 

19 

def lu_factor(a, overwrite_a=False, check_finite=True):
    """
    Compute pivoted LU decomposition of a matrix.

    The decomposition is::

        A = P L U

    where P is a permutation matrix, L lower triangular with unit
    diagonal elements, and U upper triangular.

    Parameters
    ----------
    a : (M, N) array_like
        Matrix to decompose
    overwrite_a : bool, optional
        Whether to overwrite data in A (may increase performance)
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    lu : (M, N) ndarray
        Matrix containing U in its upper triangle, and L in its lower triangle.
        The unit diagonal elements of L are not stored.
    piv : (K,) ndarray
        Pivot indices representing the permutation matrix P:
        row i of matrix was interchanged with row piv[i].
        Of shape ``(K,)``, with ``K = min(M, N)``.

    Raises
    ------
    ValueError
        If the LAPACK routine reports an illegal argument value.

    Warns
    -----
    LinAlgWarning
        If a diagonal element of U is exactly zero (singular matrix).

    See Also
    --------
    lu : gives lu factorization in more user-friendly format
    lu_solve : solve an equation system using the LU factorization of a matrix

    Notes
    -----
    This is a wrapper to the ``*GETRF`` routines from LAPACK. Unlike
    :func:`lu`, it outputs the L and U factors into a single array
    and returns pivot indices instead of a permutation matrix.

    While the underlying ``*GETRF`` routines return 1-based pivot indices, the
    ``piv`` array returned by ``lu_factor`` contains 0-based indices.

    For empty input (a dimension of length zero), LAPACK is not called:
    ``lu`` is an empty array of the input's shape in the dtype the LAPACK
    routine would have used, and ``piv`` is an empty ``int32`` array.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import lu_factor
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> lu, piv = lu_factor(A)
    >>> piv
    array([2, 2, 3, 3], dtype=int32)

    Convert LAPACK's ``piv`` array to NumPy index and test the permutation

    >>> def pivot_to_permutation(piv):
    ...     perm = np.arange(len(piv))
    ...     for i in range(len(piv)):
    ...         perm[i], perm[piv[i]] = perm[piv[i]], perm[i]
    ...     return perm
    ...
    >>> p_inv = pivot_to_permutation(piv)
    >>> p_inv
    array([2, 0, 3, 1])
    >>> L, U = np.tril(lu, k=-1) + np.eye(4), np.triu(lu)
    >>> np.allclose(A[p_inv] - L @ U, np.zeros((4, 4)))
    True

    The P matrix in P L U is defined by the inverse permutation and
    can be recovered using argsort:

    >>> p = np.argsort(p_inv)
    >>> p
    array([1, 3, 0, 2])
    >>> np.allclose(A - L[p] @ U, np.zeros((4, 4)))
    True

    or alternatively:

    >>> P = np.eye(4)[p]
    >>> np.allclose(A - P @ L @ U, np.zeros((4, 4)))
    True
    """
    if check_finite:
        a1 = asarray_chkfinite(a)
    else:
        a1 = asarray(a)
    getrf, = get_lapack_funcs(('getrf',), (a1,))

    # LAPACK's ?getrf requires LDA >= max(1, M), so a zero-row matrix would be
    # reported as an illegal argument. Nothing needs factorizing anyway:
    # answer directly, in the dtype getrf would have produced.
    if a1.size == 0:
        return np.empty(a1.shape, dtype=getrf.dtype), np.empty(0, dtype=np.int32)

    # If asarray already produced a private copy, scribbling on it is free.
    overwrite_a = overwrite_a or _datacopied(a1, a)
    lu, piv, info = getrf(a1, overwrite_a=overwrite_a)
    if info < 0:
        raise ValueError('illegal value in %dth argument of '
                         'internal getrf (lu_factor)' % -info)
    if info > 0:
        warn("Diagonal number %d is exactly zero. Singular matrix." % info,
             LinAlgWarning, stacklevel=2)
    return lu, piv

119 

120 

def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
    """Solve an equation system, a x = b, given the LU factorization of a.

    Parameters
    ----------
    (lu, piv)
        Factorization of the coefficient matrix a, as given by lu_factor.
        In particular piv are 0-indexed pivot indices.
    b : array
        Right-hand side
    trans : {0, 1, 2}, optional
        Type of system to solve:

        =====  =========
        trans  system
        =====  =========
        0      a x   = b
        1      a^T x = b
        2      a^H x = b
        =====  =========
    overwrite_b : bool, optional
        Whether to overwrite data in b (may increase performance)
    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    x : array
        Solution to the system. If `b` (or the system) is empty, an empty
        array of `b`'s shape is returned without calling LAPACK.

    Raises
    ------
    ValueError
        If the first dimensions of ``lu`` and `b` differ, or if the LAPACK
        routine reports an illegal argument value.

    See Also
    --------
    lu_factor : LU factorize a matrix

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import lu_factor, lu_solve
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> b = np.array([1, 1, 1, 1])
    >>> lu, piv = lu_factor(A)
    >>> x = lu_solve((lu, piv), b)
    >>> np.allclose(A @ x - b, np.zeros((4,)))
    True

    """
    (lu, piv) = lu_and_piv
    if check_finite:
        b1 = asarray_chkfinite(b)
    else:
        b1 = asarray(b)
    # If asarray already produced a private copy, scribbling on it is free.
    overwrite_b = overwrite_b or _datacopied(b1, b)
    if lu.shape[0] != b1.shape[0]:
        raise ValueError(f"Shapes of lu {lu.shape} and b {b1.shape} are incompatible")

    getrs, = get_lapack_funcs(('getrs',), (lu, b1))

    # Nothing to solve for an empty system or empty right-hand side; LAPACK
    # would reject a zero leading dimension (LDA >= max(1, N)). Return an
    # empty result in the dtype getrs would have produced.
    if b1.size == 0:
        return np.empty_like(b1, dtype=getrs.dtype)

    x, info = getrs(lu, piv, b1, trans=trans, overwrite_b=overwrite_b)
    if info == 0:
        return x
    raise ValueError('illegal value in %dth argument of internal getrs '
                     '(lu_solve)' % -info)

184 

185 

186def lu(a, permute_l=False, overwrite_a=False, check_finite=True, 

187 p_indices=False): 

188 """ 

189 Compute LU decomposition of a matrix with partial pivoting. 

190 

191 The decomposition satisfies:: 

192 

193 A = P @ L @ U 

194 

195 where ``P`` is a permutation matrix, ``L`` lower triangular with unit 

196 diagonal elements, and ``U`` upper triangular. If `permute_l` is set to 

197 ``True`` then ``L`` is returned already permuted and hence satisfying 

198 ``A = L @ U``. 

199 

200 Parameters 

201 ---------- 

202 a : (M, N) array_like 

203 Array to decompose 

204 permute_l : bool, optional 

205 Perform the multiplication P*L (Default: do not permute) 

206 overwrite_a : bool, optional 

207 Whether to overwrite data in a (may improve performance) 

208 check_finite : bool, optional 

209 Whether to check that the input matrix contains only finite numbers. 

210 Disabling may give a performance gain, but may result in problems 

211 (crashes, non-termination) if the inputs do contain infinities or NaNs. 

212 p_indices : bool, optional 

213 If ``True`` the permutation information is returned as row indices. 

214 The default is ``False`` for backwards-compatibility reasons. 

215 

216 Returns 

217 ------- 

218 **(If `permute_l` is ``False``)** 

219 

220 p : (..., M, M) ndarray 

221 Permutation arrays or vectors depending on `p_indices` 

222 l : (..., M, K) ndarray 

223 Lower triangular or trapezoidal array with unit diagonal. 

224 ``K = min(M, N)`` 

225 u : (..., K, N) ndarray 

226 Upper triangular or trapezoidal array 

227 

228 **(If `permute_l` is ``True``)** 

229 

230 pl : (..., M, K) ndarray 

231 Permuted L matrix. 

232 ``K = min(M, N)`` 

233 u : (..., K, N) ndarray 

234 Upper triangular or trapezoidal array 

235 

236 Notes 

237 ----- 

238 Permutation matrices are costly since they are nothing but row reorder of 

239 ``L`` and hence indices are strongly recommended to be used instead if the 

240 permutation is required. The relation in the 2D case then becomes simply 

241 ``A = L[P, :] @ U``. In higher dimensions, it is better to use `permute_l` 

242 to avoid complicated indexing tricks. 

243 

244 In 2D case, if one has the indices however, for some reason, the 

245 permutation matrix is still needed then it can be constructed by 

246 ``np.eye(M)[P, :]``. 

247 

248 Examples 

249 -------- 

250 

251 >>> import numpy as np 

252 >>> from scipy.linalg import lu 

253 >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]]) 

254 >>> p, l, u = lu(A) 

255 >>> np.allclose(A, p @ l @ u) 

256 True 

257 >>> p # Permutation matrix 

258 array([[0., 1., 0., 0.], # Row index 1 

259 [0., 0., 0., 1.], # Row index 3 

260 [1., 0., 0., 0.], # Row index 0 

261 [0., 0., 1., 0.]]) # Row index 2 

262 >>> p, _, _ = lu(A, p_indices=True) 

263 >>> p 

264 array([1, 3, 0, 2]) # as given by row indices above 

265 >>> np.allclose(A, l[p, :] @ u) 

266 True 

267 

268 We can also use nd-arrays, for example, a demonstration with 4D array: 

269 

270 >>> rng = np.random.default_rng() 

271 >>> A = rng.uniform(low=-4, high=4, size=[3, 2, 4, 8]) 

272 >>> p, l, u = lu(A) 

273 >>> p.shape, l.shape, u.shape 

274 ((3, 2, 4, 4), (3, 2, 4, 4), (3, 2, 4, 8)) 

275 >>> np.allclose(A, p @ l @ u) 

276 True 

277 >>> PL, U = lu(A, permute_l=True) 

278 >>> np.allclose(A, PL @ U) 

279 True 

280 

281 """ 

282 a1 = np.asarray_chkfinite(a) if check_finite else np.asarray(a) 

283 if a1.ndim < 2: 

284 raise ValueError('The input array must be at least two-dimensional.') 

285 

286 # Also check if dtype is LAPACK compatible 

287 if a1.dtype.char not in 'fdFD': 

288 dtype_char = lapack_cast_dict[a1.dtype.char] 

289 if not dtype_char: # No casting possible 

290 raise TypeError(f'The dtype {a1.dtype} cannot be cast ' 

291 'to float(32, 64) or complex(64, 128).') 

292 

293 a1 = a1.astype(dtype_char[0]) # makes a copy, free to scratch 

294 overwrite_a = True 

295 

296 *nd, m, n = a1.shape 

297 k = min(m, n) 

298 real_dchar = 'f' if a1.dtype.char in 'fF' else 'd' 

299 

300 # Empty input 

301 if min(*a1.shape) == 0: 

302 if permute_l: 

303 PL = np.empty(shape=[*nd, m, k], dtype=a1.dtype) 

304 U = np.empty(shape=[*nd, k, n], dtype=a1.dtype) 

305 return PL, U 

306 else: 

307 P = (np.empty([*nd, 0], dtype=np.int32) if p_indices else 

308 np.empty([*nd, 0, 0], dtype=real_dchar)) 

309 L = np.empty(shape=[*nd, m, k], dtype=a1.dtype) 

310 U = np.empty(shape=[*nd, k, n], dtype=a1.dtype) 

311 return P, L, U 

312 

313 # Scalar case 

314 if a1.shape[-2:] == (1, 1): 

315 if permute_l: 

316 return np.ones_like(a1), (a1 if overwrite_a else a1.copy()) 

317 else: 

318 P = (np.zeros(shape=[*nd, m], dtype=int) if p_indices 

319 else np.ones_like(a1)) 

320 return P, np.ones_like(a1), (a1 if overwrite_a else a1.copy()) 

321 

322 # Then check overwrite permission 

323 if not _datacopied(a1, a): # "a" still alive through "a1" 

324 if not overwrite_a: 

325 # Data belongs to "a" so make a copy 

326 a1 = a1.copy(order='C') 

327 # else: Do nothing we'll use "a" if possible 

328 # else: a1 has its own data thus free to scratch 

329 

330 # Then layout checks, might happen that overwrite is allowed but original 

331 # array was read-only or non-contiguous. 

332 

333 if not (a1.flags['C_CONTIGUOUS'] and a1.flags['WRITEABLE']): 

334 a1 = a1.copy(order='C') 

335 

336 if not nd: # 2D array 

337 

338 p = np.empty(m, dtype=np.int32) 

339 u = np.zeros([k, k], dtype=a1.dtype) 

340 lu_dispatcher(a1, u, p, permute_l) 

341 P, L, U = (p, a1, u) if m > n else (p, u, a1) 

342 

343 else: # Stacked array 

344 

345 # Prepare the contiguous data holders 

346 P = np.empty([*nd, m], dtype=np.int32) # perm vecs 

347 

348 if m > n: # Tall arrays, U will be created 

349 U = np.zeros([*nd, k, k], dtype=a1.dtype) 

350 for ind in product(*[range(x) for x in a1.shape[:-2]]): 

351 lu_dispatcher(a1[ind], U[ind], P[ind], permute_l) 

352 L = a1 

353 

354 else: # Fat arrays, L will be created 

355 L = np.zeros([*nd, k, k], dtype=a1.dtype) 

356 for ind in product(*[range(x) for x in a1.shape[:-2]]): 

357 lu_dispatcher(a1[ind], L[ind], P[ind], permute_l) 

358 U = a1 

359 

360 # Convert permutation vecs to permutation arrays 

361 # permute_l=False needed to enter here to avoid wasted efforts 

362 if (not p_indices) and (not permute_l): 

363 if nd: 

364 Pa = np.zeros([*nd, m, m], dtype=real_dchar) 

365 # An unreadable index hack - One-hot encoding for perm matrices 

366 nd_ix = np.ix_(*([np.arange(x) for x in nd]+[np.arange(m)])) 

367 Pa[(*nd_ix, P)] = 1 

368 P = Pa 

369 else: # 2D case 

370 Pa = np.zeros([m, m], dtype=real_dchar) 

371 Pa[np.arange(m), P] = 1 

372 P = Pa 

373 

374 return (L, U) if permute_l else (P, L, U)