Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/scipy/linalg/_decomp_lu.py: 14%
92 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-23 06:43 +0000
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-23 06:43 +0000
1"""LU decomposition functions."""
3from warnings import warn
5from numpy import asarray, asarray_chkfinite
6import numpy as np
7from itertools import product
9# Local imports
10from ._misc import _datacopied, LinAlgWarning
11from .lapack import get_lapack_funcs
12from ._decomp_lu_cython import lu_dispatcher
14# deprecated imports to be removed in SciPy 1.13.0
15from scipy.linalg._flinalg_py import get_flinalg_funcs # noqa
# For every NumPy typecode, the subset of the four LAPACK typecodes
# (float32 'f', float64 'd', complex64 'F', complex128 'D') that the type
# can be cast to according to ``np.can_cast``, kept in 'fdFD' order. An
# empty string means no LAPACK-compatible cast exists for that typecode.
lapack_cast_dict = {
    source: ''.join(target for target in 'fdFD'
                    if np.can_cast(source, target))
    for source in np.typecodes['All']
}
# Names exported by a star-import of this module.
__all__ = ['lu', 'lu_solve', 'lu_factor']
def lu_factor(a, overwrite_a=False, check_finite=True):
    """
    Compute pivoted LU decomposition of a matrix.

    The decomposition is::

        A = P L U

    where P is a permutation matrix, L lower triangular with unit
    diagonal elements, and U upper triangular.

    Parameters
    ----------
    a : (M, N) array_like
        Matrix to decompose
    overwrite_a : bool, optional
        Whether to overwrite data in A (may increase performance)
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    lu : (M, N) ndarray
        Matrix containing U in its upper triangle, and L in its lower triangle.
        The unit diagonal elements of L are not stored.
    piv : (K,) ndarray
        Pivot indices representing the permutation matrix P:
        row i of matrix was interchanged with row piv[i].
        Of shape ``(K,)``, with ``K = min(M, N)``.

    Raises
    ------
    ValueError
        If the internal ``getrf`` call reports an illegal argument
        (negative ``info``).

    Warns
    -----
    LinAlgWarning
        If ``getrf`` reports a positive ``info``, i.e. a diagonal entry of
        U is exactly zero and the matrix is singular.

    See Also
    --------
    lu : gives lu factorization in more user-friendly format
    lu_solve : solve an equation system using the LU factorization of a matrix

    Notes
    -----
    This is a wrapper to the ``*GETRF`` routines from LAPACK. Unlike
    :func:`lu`, it outputs the L and U factors into a single array
    and returns pivot indices instead of a permutation matrix.

    While the underlying ``*GETRF`` routines return 1-based pivot indices, the
    ``piv`` array returned by ``lu_factor`` contains 0-based indices.

    A zero-sized input is not passed to LAPACK: an empty array with the
    shape of `a` and the dtype of the selected ``getrf`` routine is returned
    together with an empty ``int32`` pivot array.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import lu_factor
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> lu, piv = lu_factor(A)
    >>> piv
    array([2, 2, 3, 3], dtype=int32)

    Convert LAPACK's ``piv`` array to NumPy index and test the permutation

    >>> def pivot_to_permutation(piv):
    ...     perm = np.arange(len(piv))
    ...     for i in range(len(piv)):
    ...         perm[i], perm[piv[i]] = perm[piv[i]], perm[i]
    ...     return perm
    ...
    >>> p_inv = pivot_to_permutation(piv)
    >>> p_inv
    array([2, 0, 3, 1])
    >>> L, U = np.tril(lu, k=-1) + np.eye(4), np.triu(lu)
    >>> np.allclose(A[p_inv] - L @ U, np.zeros((4, 4)))
    True

    The P matrix in P L U is defined by the inverse permutation and
    can be recovered using argsort:

    >>> p = np.argsort(p_inv)
    >>> p
    array([1, 3, 0, 2])
    >>> np.allclose(A - L[p] @ U, np.zeros((4, 4)))
    True

    or alternatively:

    >>> P = np.eye(4)[p]
    >>> np.allclose(A - P @ L @ U, np.zeros((4, 4)))
    True
    """
    a1 = asarray_chkfinite(a) if check_finite else asarray(a)
    getrf, = get_lapack_funcs(('getrf',), (a1,))

    # Nothing to factorize for a zero-sized input; skip the LAPACK wrapper
    # and hand back empty results with the dtype the routine would have used.
    if a1.size == 0:
        return (np.empty(a1.shape, dtype=getrf.dtype),
                np.empty(0, dtype=np.int32))

    # If the conversion above already made a copy, it is safe to scratch it.
    overwrite_a = overwrite_a or _datacopied(a1, a)
    lu_mat, piv, info = getrf(a1, overwrite_a=overwrite_a)
    if info < 0:
        raise ValueError('illegal value in %dth argument of '
                         'internal getrf (lu_factor)' % -info)
    if info > 0:
        warn("Diagonal number %d is exactly zero. Singular matrix." % info,
             LinAlgWarning, stacklevel=2)
    return lu_mat, piv
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
    """Solve an equation system, a x = b, given the LU factorization of a

    Parameters
    ----------
    (lu, piv)
        Factorization of the coefficient matrix a, as given by lu_factor.
        In particular piv are 0-indexed pivot indices.
    b : array
        Right-hand side
    trans : {0, 1, 2}, optional
        Type of system to solve:

        =====  =========
        trans  system
        =====  =========
        0      a x   = b
        1      a^T x = b
        2      a^H x = b
        =====  =========
    overwrite_b : bool, optional
        Whether to overwrite data in b (may increase performance)
    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    x : array
        Solution to the system

    Raises
    ------
    ValueError
        If the leading dimensions of ``lu`` and `b` differ, or if the
        internal ``getrs`` call reports an illegal argument.

    See Also
    --------
    lu_factor : LU factorize a matrix

    Notes
    -----
    A zero-sized right-hand side is not passed to LAPACK: an empty array
    with the shape of `b` and the dtype of the selected ``getrs`` routine
    is returned.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import lu_factor, lu_solve
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> b = np.array([1, 1, 1, 1])
    >>> lu, piv = lu_factor(A)
    >>> x = lu_solve((lu, piv), b)
    >>> np.allclose(A @ x - b, np.zeros((4,)))
    True

    """
    lu_mat, piv = lu_and_piv
    # Accept any array_like factor; this is a no-op for an ndarray.
    lu_mat = asarray(lu_mat)
    b1 = asarray_chkfinite(b) if check_finite else asarray(b)

    # If the conversion above already made a copy, it is safe to scratch it.
    overwrite_b = overwrite_b or _datacopied(b1, b)
    if lu_mat.shape[0] != b1.shape[0]:
        raise ValueError(f"Shapes of lu {lu_mat.shape} and b {b1.shape} "
                         "are incompatible")

    getrs, = get_lapack_funcs(('getrs',), (lu_mat, b1))

    # Nothing to solve for a zero-sized right-hand side; skip the LAPACK
    # wrapper and return an empty result in the routine's working dtype.
    if b1.size == 0:
        return np.empty(b1.shape, dtype=getrs.dtype)

    x, info = getrs(lu_mat, piv, b1, trans=trans, overwrite_b=overwrite_b)
    if info == 0:
        return x
    # The routine invoked here is getrs, so name it in the message.
    raise ValueError('illegal value in %dth argument of internal getrs '
                     '(lu_solve)' % -info)
191def lu(a, permute_l=False, overwrite_a=False, check_finite=True,
192 p_indices=False):
193 """
194 Compute LU decomposition of a matrix with partial pivoting.
196 The decomposition satisfies::
198 A = P @ L @ U
200 where ``P`` is a permutation matrix, ``L`` lower triangular with unit
201 diagonal elements, and ``U`` upper triangular. If `permute_l` is set to
202 ``True`` then ``L`` is returned already permuted and hence satisfying
203 ``A = L @ U``.
205 Parameters
206 ----------
207 a : (M, N) array_like
208 Array to decompose
209 permute_l : bool, optional
210 Perform the multiplication P*L (Default: do not permute)
211 overwrite_a : bool, optional
212 Whether to overwrite data in a (may improve performance)
213 check_finite : bool, optional
214 Whether to check that the input matrix contains only finite numbers.
215 Disabling may give a performance gain, but may result in problems
216 (crashes, non-termination) if the inputs do contain infinities or NaNs.
217 p_indices : bool, optional
218 If ``True`` the permutation information is returned as row indices.
219 The default is ``False`` for backwards-compatibility reasons.
221 Returns
222 -------
223 **(If `permute_l` is ``False``)**
225 p : (..., M, M) ndarray
226 Permutation arrays or vectors depending on `p_indices`
227 l : (..., M, K) ndarray
228 Lower triangular or trapezoidal array with unit diagonal.
229 ``K = min(M, N)``
230 u : (..., K, N) ndarray
231 Upper triangular or trapezoidal array
233 **(If `permute_l` is ``True``)**
235 pl : (..., M, K) ndarray
236 Permuted L matrix.
237 ``K = min(M, N)``
238 u : (..., K, N) ndarray
239 Upper triangular or trapezoidal array
241 Notes
242 -----
243 Permutation matrices are costly since they are nothing but row reorder of
244 ``L`` and hence indices are strongly recommended to be used instead if the
245 permutation is required. The relation in the 2D case then becomes simply
246 ``A = L[P, :] @ U``. In higher dimensions, it is better to use `permute_l`
247 to avoid complicated indexing tricks.
249 In 2D case, if one has the indices however, for some reason, the
250 permutation matrix is still needed then it can be constructed by
251 ``np.eye(M)[P, :]``.
253 Examples
254 --------
256 >>> import numpy as np
257 >>> from scipy.linalg import lu
258 >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
259 >>> p, l, u = lu(A)
260 >>> np.allclose(A, p @ l @ u)
261 True
262 >>> p # Permutation matrix
263 array([[0., 1., 0., 0.], # Row index 1
264 [0., 0., 0., 1.], # Row index 3
265 [1., 0., 0., 0.], # Row index 0
266 [0., 0., 1., 0.]]) # Row index 2
267 >>> p, _, _ = lu(A, p_indices=True)
268 >>> p
269 array([1, 3, 0, 2]) # as given by row indices above
270 >>> np.allclose(A, l[p, :] @ u)
271 True
273 We can also use nd-arrays, for example, a demonstration with 4D array:
275 >>> rng = np.random.default_rng()
276 >>> A = rng.uniform(low=-4, high=4, size=[3, 2, 4, 8])
277 >>> p, l, u = lu(A)
278 >>> p.shape, l.shape, u.shape
279 ((3, 2, 4, 4), (3, 2, 4, 4), (3, 2, 4, 8))
280 >>> np.allclose(A, p @ l @ u)
281 True
282 >>> PL, U = lu(A, permute_l=True)
283 >>> np.allclose(A, PL @ U)
284 True
286 """
287 a1 = np.asarray_chkfinite(a) if check_finite else np.asarray(a)
288 if a1.ndim < 2:
289 raise ValueError('The input array must be at least two-dimensional.')
291 # Also check if dtype is LAPACK compatible
292 if a1.dtype.char not in 'fdFD':
293 dtype_char = lapack_cast_dict[a1.dtype.char]
294 if not dtype_char: # No casting possible
295 raise TypeError(f'The dtype {a1.dtype} cannot be cast '
296 'to float(32, 64) or complex(64, 128).')
298 a1 = a1.astype(dtype_char[0]) # makes a copy, free to scratch
299 overwrite_a = True
301 *nd, m, n = a1.shape
302 k = min(m, n)
303 real_dchar = 'f' if a1.dtype.char in 'fF' else 'd'
305 # Empty input
306 if min(*a1.shape) == 0:
307 if permute_l:
308 PL = np.empty(shape=[*nd, m, k], dtype=a1.dtype)
309 U = np.empty(shape=[*nd, k, n], dtype=a1.dtype)
310 return PL, U
311 else:
312 P = (np.empty([*nd, 0], dtype=np.int32) if p_indices else
313 np.empty([*nd, 0, 0], dtype=real_dchar))
314 L = np.empty(shape=[*nd, m, k], dtype=a1.dtype)
315 U = np.empty(shape=[*nd, k, n], dtype=a1.dtype)
316 return P, L, U
318 # Scalar case
319 if a1.shape[-2:] == (1, 1):
320 if permute_l:
321 return np.ones_like(a1), (a1 if overwrite_a else a1.copy())
322 else:
323 P = (np.zeros(shape=[*nd, m], dtype=int) if p_indices
324 else np.ones_like(a1))
325 return P, np.ones_like(a1), (a1 if overwrite_a else a1.copy())
327 # Then check overwrite permission
328 if not _datacopied(a1, a): # "a" still alive through "a1"
329 if not overwrite_a:
330 # Data belongs to "a" so make a copy
331 a1 = a1.copy(order='C')
332 # else: Do nothing we'll use "a" if possible
333 # else: a1 has its own data thus free to scratch
335 # Then layout checks, might happen that overwrite is allowed but original
336 # array was read-only or non-contiguous.
338 if not (a1.flags['C_CONTIGUOUS'] and a1.flags['WRITEABLE']):
339 a1 = a1.copy(order='C')
341 if not nd: # 2D array
343 p = np.empty(m, dtype=np.int32)
344 u = np.zeros([k, k], dtype=a1.dtype)
345 lu_dispatcher(a1, u, p, permute_l)
346 P, L, U = (p, a1, u) if m > n else (p, u, a1)
348 else: # Stacked array
350 # Prepare the contiguous data holders
351 P = np.empty([*nd, m], dtype=np.int32) # perm vecs
353 if m > n: # Tall arrays, U will be created
354 U = np.zeros([*nd, k, k], dtype=a1.dtype)
355 for ind in product(*[range(x) for x in a1.shape[:-2]]):
356 lu_dispatcher(a1[ind], U[ind], P[ind], permute_l)
357 L = a1
359 else: # Fat arrays, L will be created
360 L = np.zeros([*nd, k, k], dtype=a1.dtype)
361 for ind in product(*[range(x) for x in a1.shape[:-2]]):
362 lu_dispatcher(a1[ind], L[ind], P[ind], permute_l)
363 U = a1
365 # Convert permutation vecs to permutation arrays
366 # permute_l=False needed to enter here to avoid wasted efforts
367 if (not p_indices) and (not permute_l):
368 if nd:
369 Pa = np.zeros([*nd, m, m], dtype=real_dchar)
370 # An unreadable index hack - One-hot encoding for perm matrices
371 nd_ix = np.ix_(*([np.arange(x) for x in nd]+[np.arange(m)]))
372 Pa[(*nd_ix, P)] = 1
373 P = Pa
374 else: # 2D case
375 Pa = np.zeros([m, m], dtype=real_dchar)
376 Pa[np.arange(m), P] = 1
377 P = Pa
379 return (L, U) if permute_l else (P, L, U)