Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/scipy/linalg/_matfuncs_sqrtm.py: 13%
85 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-12 06:31 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-12 06:31 +0000
"""
Matrix square root for general matrices and for upper triangular matrices.

This module exists to avoid cyclic imports. The public entry point is
`sqrtm`; the private `_sqrtm_triu` helper is also used by `logm`.

"""
__all__ = ['sqrtm']
9import numpy as np
11from scipy._lib._util import _asarray_validated
14# Local imports
15from ._misc import norm
16from .lapack import ztrsyl, dtrsyl
17from ._decomp_schur import schur, rsf2csf
class SqrtmError(np.linalg.LinAlgError):
    """Error signalling that the triangular matrix square root could not be formed.

    `_sqrtm_triu` raises it, chaining the original exception, when its
    within-block loop raises a ``RuntimeError``.
    """
24from ._matfuncs_sqrtm_triu import within_block_loop
27def _sqrtm_triu(T, blocksize=64):
28 """
29 Matrix square root of an upper triangular matrix.
31 This is a helper function for `sqrtm` and `logm`.
33 Parameters
34 ----------
35 T : (N, N) array_like upper triangular
36 Matrix whose square root to evaluate
37 blocksize : int, optional
38 If the blocksize is not degenerate with respect to the
39 size of the input array, then use a blocked algorithm. (Default: 64)
41 Returns
42 -------
43 sqrtm : (N, N) ndarray
44 Value of the sqrt function at `T`
46 References
47 ----------
48 .. [1] Edvin Deadman, Nicholas J. Higham, Rui Ralha (2013)
49 "Blocked Schur Algorithms for Computing the Matrix Square Root,
50 Lecture Notes in Computer Science, 7782. pp. 171-182.
52 """
53 T_diag = np.diag(T)
54 keep_it_real = np.isrealobj(T) and np.min(T_diag) >= 0
56 # Cast to complex as necessary + ensure double precision
57 if not keep_it_real:
58 T = np.asarray(T, dtype=np.complex128, order="C")
59 T_diag = np.asarray(T_diag, dtype=np.complex128)
60 else:
61 T = np.asarray(T, dtype=np.float64, order="C")
62 T_diag = np.asarray(T_diag, dtype=np.float64)
64 R = np.diag(np.sqrt(T_diag))
66 # Compute the number of blocks to use; use at least one block.
67 n, n = T.shape
68 nblocks = max(n // blocksize, 1)
70 # Compute the smaller of the two sizes of blocks that
71 # we will actually use, and compute the number of large blocks.
72 bsmall, nlarge = divmod(n, nblocks)
73 blarge = bsmall + 1
74 nsmall = nblocks - nlarge
75 if nsmall * bsmall + nlarge * blarge != n:
76 raise Exception('internal inconsistency')
78 # Define the index range covered by each block.
79 start_stop_pairs = []
80 start = 0
81 for count, size in ((nsmall, bsmall), (nlarge, blarge)):
82 for i in range(count):
83 start_stop_pairs.append((start, start + size))
84 start += size
86 # Within-block interactions (Cythonized)
87 try:
88 within_block_loop(R, T, start_stop_pairs, nblocks)
89 except RuntimeError as e:
90 raise SqrtmError(*e.args) from e
92 # Between-block interactions (Cython would give no significant speedup)
93 for j in range(nblocks):
94 jstart, jstop = start_stop_pairs[j]
95 for i in range(j-1, -1, -1):
96 istart, istop = start_stop_pairs[i]
97 S = T[istart:istop, jstart:jstop]
98 if j - i > 1:
99 S = S - R[istart:istop, istop:jstart].dot(R[istop:jstart,
100 jstart:jstop])
102 # Invoke LAPACK.
103 # For more details, see the solve_sylvester implemention
104 # and the fortran dtrsyl and ztrsyl docs.
105 Rii = R[istart:istop, istart:istop]
106 Rjj = R[jstart:jstop, jstart:jstop]
107 if keep_it_real:
108 x, scale, info = dtrsyl(Rii, Rjj, S)
109 else:
110 x, scale, info = ztrsyl(Rii, Rjj, S)
111 R[istart:istop, jstart:jstop] = x * scale
113 # Return the matrix square root.
114 return R
def sqrtm(A, disp=True, blocksize=64):
    """
    Matrix square root.

    Parameters
    ----------
    A : (N, N) array_like
        Matrix whose square root to evaluate
    disp : bool, optional
        If True (default), return only the square root and print a message
        if the square root could not be computed. If False, print nothing
        and additionally return an error estimate.
    blocksize : integer, optional
        If the blocksize is not degenerate with respect to the
        size of the input array, then use a blocked algorithm. (Default: 64)

    Returns
    -------
    sqrtm : (N, N) ndarray
        Value of the sqrt function at `A`. The dtype is float or complex and
        its precision follows the precision of the input `A`:

        - a real result uses the float type with the same item size as `A`,
          clipped to the float16 .. float128 range;
        - a complex result uses the complex type whose real and imaginary
          parts have the precision of `A`. For example, float32 or complex64
          input gives complex64, and float64 or complex128 input gives
          complex128. It is clipped to the complex types NumPy provides.

        If the square root could not be computed, the result is filled
        with NaN.

    errest : float
        (if disp == False)

        Error estimate computed as ``||X @ X - A||_F**2 / ||A||_F``, where
        ``X`` is the returned square root, or ``np.inf`` if the residual
        norm cannot be evaluated (e.g. ``X`` contains NaNs).

    References
    ----------
    .. [1] Edvin Deadman, Nicholas J. Higham, Rui Ralha (2013)
           "Blocked Schur Algorithms for Computing the Matrix Square Root",
           Lecture Notes in Computer Science, 7782. pp. 171-182.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import sqrtm
    >>> a = np.array([[1.0, 3.0], [1.0, 4.0]])
    >>> r = sqrtm(a)
    >>> r
    array([[ 0.75592895,  1.13389342],
           [ 0.37796447,  1.88982237]])
    >>> r.dot(r)
    array([[ 1.,  3.],
           [ 1.,  4.]])

    """
    # Item size of the caller's dtype, captured before validation converts
    # integer input to floating point; it decides the output precision.
    byte_size = np.asarray(A).dtype.itemsize
    A = _asarray_validated(A, check_finite=True, as_inexact=True)
    if len(A.shape) != 2:
        raise ValueError("Non-matrix input to matrix function.")
    if blocksize < 1:
        raise ValueError("The blocksize should be at least 1.")
    keep_it_real = np.isrealobj(A)
    if keep_it_real:
        T, Z = schur(A)
        # A real Schur form that is not upper triangular (2x2 diagonal
        # blocks) is converted to the triangular complex Schur form.
        if not np.array_equal(T, np.triu(T)):
            T, Z = rsf2csf(T, Z)
    else:
        T, Z = schur(A, output='complex')

    failflag = False
    try:
        R = _sqrtm_triu(T, blocksize=blocksize)
    except SqrtmError:
        failflag = True
        X = np.empty_like(A)
        X.fill(np.nan)
    else:
        # Undo the Schur similarity transform: X = Z R Z^H.
        ZH = np.conjugate(Z).T
        X = Z.dot(R).dot(ZH)
        if np.iscomplexobj(X):
            # For a complex input `byte_size` already covers both the real
            # and imaginary parts; only a real input's item size is doubled.
            # Doubling unconditionally would upcast complex128 input to
            # complex256 without any gain in accuracy.
            complex_size = 2 * byte_size if keep_it_real else byte_size
            # complex byte size range: c8 ~ c32.
            # c32 (complex256) might not be supported in some environments.
            max_complex = 32 if hasattr(np, 'complex256') else 16
            X = X.astype(f"c{np.clip(complex_size, 8, max_complex)}",
                         copy=False)
        else:
            # float byte size range: f2 ~ f16
            X = X.astype(f"f{np.clip(byte_size, 2, 16)}", copy=False)

    if disp:
        if failflag:
            print("Failed to find a square root.")
        return X
    else:
        try:
            # NOTE(review): the residual norm is squared here, so this is not
            # the relative residual ||err||_F / ||A||_F; kept as-is for
            # backward compatibility of the returned estimate.
            arg2 = norm(X.dot(X) - A, 'fro')**2 / norm(A, 'fro')
        except ValueError:
            # NaNs in matrix
            arg2 = np.inf

        return X, arg2