Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/scipy/linalg/_decomp_lu.py: 14%

92 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-23 06:43 +0000

1"""LU decomposition functions.""" 

2 

3from warnings import warn 

4 

5from numpy import asarray, asarray_chkfinite 

6import numpy as np 

7from itertools import product 

8 

9# Local imports 

10from ._misc import _datacopied, LinAlgWarning 

11from .lapack import get_lapack_funcs 

12from ._decomp_lu_cython import lu_dispatcher 

13 

14# deprecated imports to be removed in SciPy 1.13.0 

15from scipy.linalg._flinalg_py import get_flinalg_funcs # noqa 

16 

17 

# For every NumPy type code, the LAPACK-supported type codes ('f', 'd', 'F',
# 'D', kept in that order) it can be cast to under NumPy's default "safe"
# casting rule.  An empty string means no LAPACK type can hold the data.
lapack_cast_dict = {
    code: ''.join(target for target in 'fdFD' if np.can_cast(code, target))
    for code in np.typecodes['All']
}

20 

# Public API of this module.
__all__ = [
    'lu',
    'lu_solve',
    'lu_factor',
]

22 

23 

def lu_factor(a, overwrite_a=False, check_finite=True):
    """
    Compute pivoted LU decomposition of a matrix.

    The decomposition is::

        A = P L U

    where P is a permutation matrix, L lower triangular with unit
    diagonal elements, and U upper triangular.

    Parameters
    ----------
    a : (M, N) array_like
        Matrix to decompose
    overwrite_a : bool, optional
        Whether to overwrite data in A (may increase performance)
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    lu : (M, N) ndarray
        Matrix containing U in its upper triangle, and L in its lower triangle.
        The unit diagonal elements of L are not stored.
    piv : (K,) ndarray
        Pivot indices representing the permutation matrix P:
        row i of matrix was interchanged with row piv[i].
        Of shape ``(K,)``, with ``K = min(M, N)``.

    Raises
    ------
    ValueError
        If `check_finite` is True and `a` contains infinities or NaNs, or if
        LAPACK reports an illegal argument.

    Warns
    -----
    LinAlgWarning
        If a diagonal entry of U is exactly zero (singular matrix).

    See Also
    --------
    lu : gives lu factorization in more user-friendly format
    lu_solve : solve an equation system using the LU factorization of a matrix

    Notes
    -----
    This is a wrapper to the ``*GETRF`` routines from LAPACK. Unlike
    :func:`lu`, it outputs the L and U factors into a single array
    and returns pivot indices instead of a permutation matrix.

    While the underlying ``*GETRF`` routines return 1-based pivot indices, the
    ``piv`` array returned by ``lu_factor`` contains 0-based indices.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import lu_factor
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> lu, piv = lu_factor(A)
    >>> piv
    array([2, 2, 3, 3], dtype=int32)

    Convert LAPACK's ``piv`` array to NumPy index and test the permutation

    >>> def pivot_to_permutation(piv):
    ...     perm = np.arange(len(piv))
    ...     for i in range(len(piv)):
    ...         perm[i], perm[piv[i]] = perm[piv[i]], perm[i]
    ...     return perm
    ...
    >>> p_inv = pivot_to_permutation(piv)
    >>> p_inv
    array([2, 0, 3, 1])
    >>> L, U = np.tril(lu, k=-1) + np.eye(4), np.triu(lu)
    >>> np.allclose(A[p_inv] - L @ U, np.zeros((4, 4)))
    True

    The P matrix in P L U is defined by the inverse permutation and
    can be recovered using argsort:

    >>> p = np.argsort(p_inv)
    >>> p
    array([1, 3, 0, 2])
    >>> np.allclose(A - L[p] @ U, np.zeros((4, 4)))
    True

    or alternatively:

    >>> P = np.eye(4)[p]
    >>> np.allclose(A - P @ L @ U, np.zeros((4, 4)))
    True
    """
    a1 = asarray_chkfinite(a) if check_finite else asarray(a)
    # If the conversion above already produced a private copy, LAPACK may
    # scribble on it without touching the caller's data.
    overwrite_a = overwrite_a or _datacopied(a1, a)

    (getrf,) = get_lapack_funcs(('getrf',), (a1,))
    lu, piv, info = getrf(a1, overwrite_a=overwrite_a)

    if info == 0:
        return lu, piv
    if info < 0:
        # Negative info identifies the offending argument (1-based).
        raise ValueError('illegal value in %dth argument of '
                         'internal getrf (lu_factor)' % -info)
    # Positive info: U[info-1, info-1] is exactly zero.
    warn("Diagonal number %d is exactly zero. Singular matrix." % info,
         LinAlgWarning, stacklevel=2)
    return lu, piv

123 

124 

def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
    """Solve an equation system, ``a x = b``, given the LU factorization of ``a``.

    Parameters
    ----------
    (lu, piv)
        Factorization of the coefficient matrix a, as given by lu_factor.
        In particular piv are 0-indexed pivot indices.
    b : array
        Right-hand side
    trans : {0, 1, 2}, optional
        Type of system to solve:

        =====  =========
        trans  system
        =====  =========
        0      a x   = b
        1      a^T x = b
        2      a^H x = b
        =====  =========
    overwrite_b : bool, optional
        Whether to overwrite data in b (may increase performance)
    check_finite : bool, optional
        Whether to check that the right-hand side ``b`` contains only finite
        numbers (``lu`` itself is not checked; :func:`lu_factor` checks its
        input when called with ``check_finite=True``).
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    x : array
        Solution to the system

    Raises
    ------
    ValueError
        If `check_finite` is True and `b` contains infinities or NaNs, if the
        shapes of ``lu`` and ``b`` are incompatible, or if LAPACK reports an
        illegal argument.

    See Also
    --------
    lu_factor : LU factorize a matrix

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import lu_factor, lu_solve
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> b = np.array([1, 1, 1, 1])
    >>> lu, piv = lu_factor(A)
    >>> x = lu_solve((lu, piv), b)
    >>> np.allclose(A @ x - b, np.zeros((4,)))
    True

    """
    (lu, piv) = lu_and_piv
    # Accept any array_like factor; asarray does not copy an ndarray, so the
    # usual lu_factor output is passed through unchanged.
    lu = asarray(lu)
    b1 = asarray_chkfinite(b) if check_finite else asarray(b)
    # A private copy made by the conversion above is safe to overwrite.
    overwrite_b = overwrite_b or _datacopied(b1, b)
    # A 0-d right-hand side has no leading dimension to match against lu.
    if b1.ndim == 0 or lu.shape[0] != b1.shape[0]:
        raise ValueError("Shapes of lu {} and b {} are incompatible"
                         .format(lu.shape, b1.shape))

    getrs, = get_lapack_funcs(('getrs',), (lu, b1))
    x, info = getrs(lu, piv, b1, trans=trans, overwrite_b=overwrite_b)
    if info != 0:
        # getrs only signals argument errors, as info = -(argument index).
        raise ValueError('illegal value in %dth argument of internal getrs '
                         '(lu_solve)' % -info)
    return x

189 

190 

191def lu(a, permute_l=False, overwrite_a=False, check_finite=True, 

192 p_indices=False): 

193 """ 

194 Compute LU decomposition of a matrix with partial pivoting. 

195 

196 The decomposition satisfies:: 

197 

198 A = P @ L @ U 

199 

200 where ``P`` is a permutation matrix, ``L`` lower triangular with unit 

201 diagonal elements, and ``U`` upper triangular. If `permute_l` is set to 

202 ``True`` then ``L`` is returned already permuted and hence satisfying 

203 ``A = L @ U``. 

204 

205 Parameters 

206 ---------- 

207 a : (M, N) array_like 

208 Array to decompose 

209 permute_l : bool, optional 

210 Perform the multiplication P*L (Default: do not permute) 

211 overwrite_a : bool, optional 

212 Whether to overwrite data in a (may improve performance) 

213 check_finite : bool, optional 

214 Whether to check that the input matrix contains only finite numbers. 

215 Disabling may give a performance gain, but may result in problems 

216 (crashes, non-termination) if the inputs do contain infinities or NaNs. 

217 p_indices : bool, optional 

218 If ``True`` the permutation information is returned as row indices. 

219 The default is ``False`` for backwards-compatibility reasons. 

220 

221 Returns 

222 ------- 

223 **(If `permute_l` is ``False``)** 

224 

225 p : (..., M, M) ndarray 

226 Permutation arrays or vectors depending on `p_indices` 

227 l : (..., M, K) ndarray 

228 Lower triangular or trapezoidal array with unit diagonal. 

229 ``K = min(M, N)`` 

230 u : (..., K, N) ndarray 

231 Upper triangular or trapezoidal array 

232 

233 **(If `permute_l` is ``True``)** 

234 

235 pl : (..., M, K) ndarray 

236 Permuted L matrix. 

237 ``K = min(M, N)`` 

238 u : (..., K, N) ndarray 

239 Upper triangular or trapezoidal array 

240 

241 Notes 

242 ----- 

243 Permutation matrices are costly since they are nothing but row reorder of 

244 ``L`` and hence indices are strongly recommended to be used instead if the 

245 permutation is required. The relation in the 2D case then becomes simply 

246 ``A = L[P, :] @ U``. In higher dimensions, it is better to use `permute_l` 

247 to avoid complicated indexing tricks. 

248 

249 In 2D case, if one has the indices however, for some reason, the 

250 permutation matrix is still needed then it can be constructed by 

251 ``np.eye(M)[P, :]``. 

252 

253 Examples 

254 -------- 

255 

256 >>> import numpy as np 

257 >>> from scipy.linalg import lu 

258 >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]]) 

259 >>> p, l, u = lu(A) 

260 >>> np.allclose(A, p @ l @ u) 

261 True 

262 >>> p # Permutation matrix 

263 array([[0., 1., 0., 0.], # Row index 1 

264 [0., 0., 0., 1.], # Row index 3 

265 [1., 0., 0., 0.], # Row index 0 

266 [0., 0., 1., 0.]]) # Row index 2 

267 >>> p, _, _ = lu(A, p_indices=True) 

268 >>> p 

269 array([1, 3, 0, 2]) # as given by row indices above 

270 >>> np.allclose(A, l[p, :] @ u) 

271 True 

272 

273 We can also use nd-arrays, for example, a demonstration with 4D array: 

274 

275 >>> rng = np.random.default_rng() 

276 >>> A = rng.uniform(low=-4, high=4, size=[3, 2, 4, 8]) 

277 >>> p, l, u = lu(A) 

278 >>> p.shape, l.shape, u.shape 

279 ((3, 2, 4, 4), (3, 2, 4, 4), (3, 2, 4, 8)) 

280 >>> np.allclose(A, p @ l @ u) 

281 True 

282 >>> PL, U = lu(A, permute_l=True) 

283 >>> np.allclose(A, PL @ U) 

284 True 

285 

286 """ 

287 a1 = np.asarray_chkfinite(a) if check_finite else np.asarray(a) 

288 if a1.ndim < 2: 

289 raise ValueError('The input array must be at least two-dimensional.') 

290 

291 # Also check if dtype is LAPACK compatible 

292 if a1.dtype.char not in 'fdFD': 

293 dtype_char = lapack_cast_dict[a1.dtype.char] 

294 if not dtype_char: # No casting possible 

295 raise TypeError(f'The dtype {a1.dtype} cannot be cast ' 

296 'to float(32, 64) or complex(64, 128).') 

297 

298 a1 = a1.astype(dtype_char[0]) # makes a copy, free to scratch 

299 overwrite_a = True 

300 

301 *nd, m, n = a1.shape 

302 k = min(m, n) 

303 real_dchar = 'f' if a1.dtype.char in 'fF' else 'd' 

304 

305 # Empty input 

306 if min(*a1.shape) == 0: 

307 if permute_l: 

308 PL = np.empty(shape=[*nd, m, k], dtype=a1.dtype) 

309 U = np.empty(shape=[*nd, k, n], dtype=a1.dtype) 

310 return PL, U 

311 else: 

312 P = (np.empty([*nd, 0], dtype=np.int32) if p_indices else 

313 np.empty([*nd, 0, 0], dtype=real_dchar)) 

314 L = np.empty(shape=[*nd, m, k], dtype=a1.dtype) 

315 U = np.empty(shape=[*nd, k, n], dtype=a1.dtype) 

316 return P, L, U 

317 

318 # Scalar case 

319 if a1.shape[-2:] == (1, 1): 

320 if permute_l: 

321 return np.ones_like(a1), (a1 if overwrite_a else a1.copy()) 

322 else: 

323 P = (np.zeros(shape=[*nd, m], dtype=int) if p_indices 

324 else np.ones_like(a1)) 

325 return P, np.ones_like(a1), (a1 if overwrite_a else a1.copy()) 

326 

327 # Then check overwrite permission 

328 if not _datacopied(a1, a): # "a" still alive through "a1" 

329 if not overwrite_a: 

330 # Data belongs to "a" so make a copy 

331 a1 = a1.copy(order='C') 

332 # else: Do nothing we'll use "a" if possible 

333 # else: a1 has its own data thus free to scratch 

334 

335 # Then layout checks, might happen that overwrite is allowed but original 

336 # array was read-only or non-contiguous. 

337 

338 if not (a1.flags['C_CONTIGUOUS'] and a1.flags['WRITEABLE']): 

339 a1 = a1.copy(order='C') 

340 

341 if not nd: # 2D array 

342 

343 p = np.empty(m, dtype=np.int32) 

344 u = np.zeros([k, k], dtype=a1.dtype) 

345 lu_dispatcher(a1, u, p, permute_l) 

346 P, L, U = (p, a1, u) if m > n else (p, u, a1) 

347 

348 else: # Stacked array 

349 

350 # Prepare the contiguous data holders 

351 P = np.empty([*nd, m], dtype=np.int32) # perm vecs 

352 

353 if m > n: # Tall arrays, U will be created 

354 U = np.zeros([*nd, k, k], dtype=a1.dtype) 

355 for ind in product(*[range(x) for x in a1.shape[:-2]]): 

356 lu_dispatcher(a1[ind], U[ind], P[ind], permute_l) 

357 L = a1 

358 

359 else: # Fat arrays, L will be created 

360 L = np.zeros([*nd, k, k], dtype=a1.dtype) 

361 for ind in product(*[range(x) for x in a1.shape[:-2]]): 

362 lu_dispatcher(a1[ind], L[ind], P[ind], permute_l) 

363 U = a1 

364 

365 # Convert permutation vecs to permutation arrays 

366 # permute_l=False needed to enter here to avoid wasted efforts 

367 if (not p_indices) and (not permute_l): 

368 if nd: 

369 Pa = np.zeros([*nd, m, m], dtype=real_dchar) 

370 # An unreadable index hack - One-hot encoding for perm matrices 

371 nd_ix = np.ix_(*([np.arange(x) for x in nd]+[np.arange(m)])) 

372 Pa[(*nd_ix, P)] = 1 

373 P = Pa 

374 else: # 2D case 

375 Pa = np.zeros([m, m], dtype=real_dchar) 

376 Pa[np.arange(m), P] = 1 

377 P = Pa 

378 

379 return (L, U) if permute_l else (P, L, U)