Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/scipy/linalg/_decomp_lu.py: 13%
91 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-22 06:44 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-22 06:44 +0000
1"""LU decomposition functions."""
3from warnings import warn
5from numpy import asarray, asarray_chkfinite
6import numpy as np
7from itertools import product
9# Local imports
10from ._misc import _datacopied, LinAlgWarning
11from .lapack import get_lapack_funcs
12from ._decomp_lu_cython import lu_dispatcher
# For every NumPy type character, the string of LAPACK-compatible type
# characters it can be safely cast to, listed in 'fdFD' order.  An empty
# string means no safe cast to a LAPACK type exists.
lapack_cast_dict = {
    code: ''.join(target for target in 'fdFD' if np.can_cast(code, target))
    for code in np.typecodes['All']
}
# Public API of this module (the names exported by ``import *``).
__all__ = [
    'lu',
    'lu_solve',
    'lu_factor',
]
def lu_factor(a, overwrite_a=False, check_finite=True):
    """
    Compute pivoted LU decomposition of a matrix.

    The decomposition is::

        A = P L U

    where P is a permutation matrix, L lower triangular with unit
    diagonal elements, and U upper triangular.

    Parameters
    ----------
    a : (M, N) array_like
        Matrix to decompose
    overwrite_a : bool, optional
        Whether to overwrite data in A (may increase performance)
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    lu : (M, N) ndarray
        Matrix containing U in its upper triangle, and L in its lower triangle.
        The unit diagonal elements of L are not stored.
    piv : (K,) ndarray
        Pivot indices representing the permutation matrix P:
        row i of matrix was interchanged with row piv[i].
        Of shape ``(K,)``, with ``K = min(M, N)``.

    Raises
    ------
    ValueError
        If the underlying LAPACK routine reports an illegal argument value.

    Warns
    -----
    LinAlgWarning
        If a diagonal element of U is exactly zero (singular matrix).

    See Also
    --------
    lu : gives lu factorization in more user-friendly format
    lu_solve : solve an equation system using the LU factorization of a matrix

    Notes
    -----
    This is a wrapper to the ``*GETRF`` routines from LAPACK. Unlike
    :func:`lu`, it outputs the L and U factors into a single array
    and returns pivot indices instead of a permutation matrix.

    While the underlying ``*GETRF`` routines return 1-based pivot indices, the
    ``piv`` array returned by ``lu_factor`` contains 0-based indices.

    Empty inputs (a dimension of size zero) are not passed to LAPACK; an
    empty ``lu`` of the dtype LAPACK would have used and an empty ``piv``
    are returned instead.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import lu_factor
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> lu, piv = lu_factor(A)
    >>> piv
    array([2, 2, 3, 3], dtype=int32)

    Convert LAPACK's ``piv`` array to NumPy index and test the permutation

    >>> def pivot_to_permutation(piv):
    ...     perm = np.arange(len(piv))
    ...     for i in range(len(piv)):
    ...         perm[i], perm[piv[i]] = perm[piv[i]], perm[i]
    ...     return perm
    ...
    >>> p_inv = pivot_to_permutation(piv)
    >>> p_inv
    array([2, 0, 3, 1])
    >>> L, U = np.tril(lu, k=-1) + np.eye(4), np.triu(lu)
    >>> np.allclose(A[p_inv] - L @ U, np.zeros((4, 4)))
    True

    The P matrix in P L U is defined by the inverse permutation and
    can be recovered using argsort:

    >>> p = np.argsort(p_inv)
    >>> p
    array([1, 3, 0, 2])
    >>> np.allclose(A - L[p] @ U, np.zeros((4, 4)))
    True

    or alternatively:

    >>> P = np.eye(4)[p]
    >>> np.allclose(A - P @ L @ U, np.zeros((4, 4)))
    True
    """
    a1 = asarray_chkfinite(a) if check_finite else asarray(a)
    # If the conversion already produced a private copy, it may be scratched.
    overwrite_a = overwrite_a or _datacopied(a1, a)
    getrf, = get_lapack_funcs(('getrf',), (a1,))

    # LAPACK requires leading dimensions of at least 1, so empty inputs must
    # not reach getrf; answer them directly with correctly typed empties.
    if a1.size == 0:
        lu = np.empty_like(a1, dtype=getrf.dtype)
        piv = np.empty(0, dtype=np.int32)
        return lu, piv

    lu, piv, info = getrf(a1, overwrite_a=overwrite_a)
    if info < 0:
        raise ValueError('illegal value in %dth argument of '
                         'internal getrf (lu_factor)' % -info)
    if info > 0:
        # A positive info is the 1-based index of the zero pivot in U.
        warn("Diagonal number %d is exactly zero. Singular matrix." % info,
             LinAlgWarning, stacklevel=2)
    return lu, piv
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
    """Solve an equation system, a x = b, given the LU factorization of a

    Parameters
    ----------
    (lu, piv)
        Factorization of the coefficient matrix a, as given by lu_factor.
        In particular piv are 0-indexed pivot indices.
    b : array
        Right-hand side
    trans : {0, 1, 2}, optional
        Type of system to solve:

        =====  =========
        trans  system
        =====  =========
        0      a x   = b
        1      a^T x = b
        2      a^H x = b
        =====  =========
    overwrite_b : bool, optional
        Whether to overwrite data in b (may increase performance)
    check_finite : bool, optional
        Whether to check that ``b`` contains only finite numbers (the
        factorization ``lu`` is not checked).
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    x : array
        Solution to the system

    Raises
    ------
    ValueError
        If ``b`` is 0-dimensional or its first dimension does not match that
        of ``lu``, or if the underlying LAPACK routine reports an illegal
        argument value.

    See Also
    --------
    lu_factor : LU factorize a matrix

    Notes
    -----
    Empty right-hand sides (or an empty factorization) are not passed to
    LAPACK; an empty solution of the dtype LAPACK would have used is
    returned instead.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import lu_factor, lu_solve
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> b = np.array([1, 1, 1, 1])
    >>> lu, piv = lu_factor(A)
    >>> x = lu_solve((lu, piv), b)
    >>> np.allclose(A @ x - b, np.zeros((4,)))
    True

    """
    (lu, piv) = lu_and_piv
    b1 = asarray_chkfinite(b) if check_finite else asarray(b)
    # If the conversion already produced a private copy, it may be scratched.
    overwrite_b = overwrite_b or _datacopied(b1, b)
    # A 0-d ``b`` has no first dimension; report it as a shape mismatch
    # instead of letting ``b1.shape[0]`` raise an IndexError.
    if b1.ndim == 0 or lu.shape[0] != b1.shape[0]:
        raise ValueError(f"Shapes of lu {lu.shape} and b {b1.shape} are incompatible")

    getrs, = get_lapack_funcs(('getrs',), (lu, b1))

    # LAPACK requires leading dimensions of at least 1, so empty systems
    # must not reach getrs; answer them directly with a correctly typed empty.
    if b1.size == 0:
        return np.empty_like(b1, dtype=getrs.dtype)

    x, info = getrs(lu, piv, b1, trans=trans, overwrite_b=overwrite_b)
    if info == 0:
        return x
    raise ValueError('illegal value in %dth argument of internal getrs '
                     '(lu_solve)' % -info)
186def lu(a, permute_l=False, overwrite_a=False, check_finite=True,
187 p_indices=False):
188 """
189 Compute LU decomposition of a matrix with partial pivoting.
191 The decomposition satisfies::
193 A = P @ L @ U
195 where ``P`` is a permutation matrix, ``L`` lower triangular with unit
196 diagonal elements, and ``U`` upper triangular. If `permute_l` is set to
197 ``True`` then ``L`` is returned already permuted and hence satisfying
198 ``A = L @ U``.
200 Parameters
201 ----------
202 a : (M, N) array_like
203 Array to decompose
204 permute_l : bool, optional
205 Perform the multiplication P*L (Default: do not permute)
206 overwrite_a : bool, optional
207 Whether to overwrite data in a (may improve performance)
208 check_finite : bool, optional
209 Whether to check that the input matrix contains only finite numbers.
210 Disabling may give a performance gain, but may result in problems
211 (crashes, non-termination) if the inputs do contain infinities or NaNs.
212 p_indices : bool, optional
213 If ``True`` the permutation information is returned as row indices.
214 The default is ``False`` for backwards-compatibility reasons.
216 Returns
217 -------
218 **(If `permute_l` is ``False``)**
220 p : (..., M, M) ndarray
221 Permutation arrays or vectors depending on `p_indices`
222 l : (..., M, K) ndarray
223 Lower triangular or trapezoidal array with unit diagonal.
224 ``K = min(M, N)``
225 u : (..., K, N) ndarray
226 Upper triangular or trapezoidal array
228 **(If `permute_l` is ``True``)**
230 pl : (..., M, K) ndarray
231 Permuted L matrix.
232 ``K = min(M, N)``
233 u : (..., K, N) ndarray
234 Upper triangular or trapezoidal array
236 Notes
237 -----
238 Permutation matrices are costly since they are nothing but row reorder of
239 ``L`` and hence indices are strongly recommended to be used instead if the
240 permutation is required. The relation in the 2D case then becomes simply
241 ``A = L[P, :] @ U``. In higher dimensions, it is better to use `permute_l`
242 to avoid complicated indexing tricks.
244 In 2D case, if one has the indices however, for some reason, the
245 permutation matrix is still needed then it can be constructed by
246 ``np.eye(M)[P, :]``.
248 Examples
249 --------
251 >>> import numpy as np
252 >>> from scipy.linalg import lu
253 >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
254 >>> p, l, u = lu(A)
255 >>> np.allclose(A, p @ l @ u)
256 True
257 >>> p # Permutation matrix
258 array([[0., 1., 0., 0.], # Row index 1
259 [0., 0., 0., 1.], # Row index 3
260 [1., 0., 0., 0.], # Row index 0
261 [0., 0., 1., 0.]]) # Row index 2
262 >>> p, _, _ = lu(A, p_indices=True)
263 >>> p
264 array([1, 3, 0, 2]) # as given by row indices above
265 >>> np.allclose(A, l[p, :] @ u)
266 True
268 We can also use nd-arrays, for example, a demonstration with 4D array:
270 >>> rng = np.random.default_rng()
271 >>> A = rng.uniform(low=-4, high=4, size=[3, 2, 4, 8])
272 >>> p, l, u = lu(A)
273 >>> p.shape, l.shape, u.shape
274 ((3, 2, 4, 4), (3, 2, 4, 4), (3, 2, 4, 8))
275 >>> np.allclose(A, p @ l @ u)
276 True
277 >>> PL, U = lu(A, permute_l=True)
278 >>> np.allclose(A, PL @ U)
279 True
281 """
282 a1 = np.asarray_chkfinite(a) if check_finite else np.asarray(a)
283 if a1.ndim < 2:
284 raise ValueError('The input array must be at least two-dimensional.')
286 # Also check if dtype is LAPACK compatible
287 if a1.dtype.char not in 'fdFD':
288 dtype_char = lapack_cast_dict[a1.dtype.char]
289 if not dtype_char: # No casting possible
290 raise TypeError(f'The dtype {a1.dtype} cannot be cast '
291 'to float(32, 64) or complex(64, 128).')
293 a1 = a1.astype(dtype_char[0]) # makes a copy, free to scratch
294 overwrite_a = True
296 *nd, m, n = a1.shape
297 k = min(m, n)
298 real_dchar = 'f' if a1.dtype.char in 'fF' else 'd'
300 # Empty input
301 if min(*a1.shape) == 0:
302 if permute_l:
303 PL = np.empty(shape=[*nd, m, k], dtype=a1.dtype)
304 U = np.empty(shape=[*nd, k, n], dtype=a1.dtype)
305 return PL, U
306 else:
307 P = (np.empty([*nd, 0], dtype=np.int32) if p_indices else
308 np.empty([*nd, 0, 0], dtype=real_dchar))
309 L = np.empty(shape=[*nd, m, k], dtype=a1.dtype)
310 U = np.empty(shape=[*nd, k, n], dtype=a1.dtype)
311 return P, L, U
313 # Scalar case
314 if a1.shape[-2:] == (1, 1):
315 if permute_l:
316 return np.ones_like(a1), (a1 if overwrite_a else a1.copy())
317 else:
318 P = (np.zeros(shape=[*nd, m], dtype=int) if p_indices
319 else np.ones_like(a1))
320 return P, np.ones_like(a1), (a1 if overwrite_a else a1.copy())
322 # Then check overwrite permission
323 if not _datacopied(a1, a): # "a" still alive through "a1"
324 if not overwrite_a:
325 # Data belongs to "a" so make a copy
326 a1 = a1.copy(order='C')
327 # else: Do nothing we'll use "a" if possible
328 # else: a1 has its own data thus free to scratch
330 # Then layout checks, might happen that overwrite is allowed but original
331 # array was read-only or non-contiguous.
333 if not (a1.flags['C_CONTIGUOUS'] and a1.flags['WRITEABLE']):
334 a1 = a1.copy(order='C')
336 if not nd: # 2D array
338 p = np.empty(m, dtype=np.int32)
339 u = np.zeros([k, k], dtype=a1.dtype)
340 lu_dispatcher(a1, u, p, permute_l)
341 P, L, U = (p, a1, u) if m > n else (p, u, a1)
343 else: # Stacked array
345 # Prepare the contiguous data holders
346 P = np.empty([*nd, m], dtype=np.int32) # perm vecs
348 if m > n: # Tall arrays, U will be created
349 U = np.zeros([*nd, k, k], dtype=a1.dtype)
350 for ind in product(*[range(x) for x in a1.shape[:-2]]):
351 lu_dispatcher(a1[ind], U[ind], P[ind], permute_l)
352 L = a1
354 else: # Fat arrays, L will be created
355 L = np.zeros([*nd, k, k], dtype=a1.dtype)
356 for ind in product(*[range(x) for x in a1.shape[:-2]]):
357 lu_dispatcher(a1[ind], L[ind], P[ind], permute_l)
358 U = a1
360 # Convert permutation vecs to permutation arrays
361 # permute_l=False needed to enter here to avoid wasted efforts
362 if (not p_indices) and (not permute_l):
363 if nd:
364 Pa = np.zeros([*nd, m, m], dtype=real_dchar)
365 # An unreadable index hack - One-hot encoding for perm matrices
366 nd_ix = np.ix_(*([np.arange(x) for x in nd]+[np.arange(m)]))
367 Pa[(*nd_ix, P)] = 1
368 P = Pa
369 else: # 2D case
370 Pa = np.zeros([m, m], dtype=real_dchar)
371 Pa[np.arange(m), P] = 1
372 P = Pa
374 return (L, U) if permute_l else (P, L, U)