Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/scipy/linalg/_matfuncs_sqrtm.py: 13%

85 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-12 06:31 +0000

1""" 

2Matrix square root for general matrices and for upper triangular matrices. 

3 

4This module exists to avoid cyclic imports. 

5 

6""" 

# Public API of this module: only the general-matrix square root is exported.
__all__ = ["sqrtm"]

8 

9import numpy as np 

10 

11from scipy._lib._util import _asarray_validated 

12 

13 

14# Local imports 

15from ._misc import norm 

16from .lapack import ztrsyl, dtrsyl 

17from ._decomp_schur import schur, rsf2csf 

18 

19 

class SqrtmError(np.linalg.LinAlgError):
    """Error signalling that a matrix square root could not be computed."""

22 

23 

24from ._matfuncs_sqrtm_triu import within_block_loop 

25 

26 

def _sqrtm_triu(T, blocksize=64):
    """
    Matrix square root of an upper triangular matrix.

    This is a helper function for `sqrtm` and `logm`.

    Parameters
    ----------
    T : (N, N) array_like upper triangular
        Matrix whose square root to evaluate
    blocksize : int, optional
        If the blocksize is not degenerate with respect to the
        size of the input array, then use a blocked algorithm. (Default: 64)

    Returns
    -------
    sqrtm : (N, N) ndarray
        Value of the sqrt function at `T`. The result is float64 when `T`
        is real with a non-negative diagonal, and complex128 otherwise.

    Raises
    ------
    SqrtmError
        If the within-block computation raises ``RuntimeError``.

    References
    ----------
    .. [1] Edvin Deadman, Nicholas J. Higham, Rui Ralha (2013)
           "Blocked Schur Algorithms for Computing the Matrix Square Root",
           Lecture Notes in Computer Science, 7782. pp. 171-182.

    """
    T_diag = np.diag(T)
    # Real arithmetic suffices only for a real T whose diagonal (the
    # eigenvalues of a triangular matrix) is non-negative. ``np.all`` is
    # used instead of ``np.min`` so that an empty (0, 0) input does not
    # raise; for non-empty input the two tests agree (including NaNs).
    keep_it_real = np.isrealobj(T) and bool(np.all(T_diag >= 0))

    # Cast to complex as necessary + ensure double precision
    if keep_it_real:
        T = np.asarray(T, dtype=np.float64, order="C")
        T_diag = np.asarray(T_diag, dtype=np.float64)
    else:
        T = np.asarray(T, dtype=np.complex128, order="C")
        T_diag = np.asarray(T_diag, dtype=np.complex128)

    # The diagonal of the square root is the elementwise square root of the
    # diagonal of T; the off-diagonal entries are filled in below.
    R = np.diag(np.sqrt(T_diag))

    # Compute the number of blocks to use; use at least one block.
    _, n = T.shape
    nblocks = max(n // blocksize, 1)

    # Compute the smaller of the two sizes of blocks that
    # we will actually use, and compute the number of large blocks.
    bsmall, nlarge = divmod(n, nblocks)
    blarge = bsmall + 1
    nsmall = nblocks - nlarge
    if nsmall * bsmall + nlarge * blarge != n:
        raise Exception('internal inconsistency')

    # Define the index range covered by each block: ``nsmall`` blocks of
    # size ``bsmall`` followed by ``nlarge`` blocks of size ``blarge``.
    start_stop_pairs = []
    start = 0
    for count, size in ((nsmall, bsmall), (nlarge, blarge)):
        for _ in range(count):
            start_stop_pairs.append((start, start + size))
            start += size

    # Within-block interactions (Cythonized). ``R`` is presumably updated
    # in place (the return value is unused); a RuntimeError from the helper
    # is re-raised as SqrtmError.
    try:
        within_block_loop(R, T, start_stop_pairs, nblocks)
    except RuntimeError as e:
        raise SqrtmError(*e.args) from e

    # Between-block interactions (Cython would give no significant speedup).
    # Block (i, j) of R solves the Sylvester equation
    #     R_ii @ X + X @ R_jj = T_ij - sum_{i < k < j} R_ik @ R_kj,
    # visiting i from j-1 down to 0 so that every R_kj on the right-hand
    # side was already computed in this column, and every R_ik in an
    # earlier column.
    for j in range(nblocks):
        jstart, jstop = start_stop_pairs[j]
        for i in range(j - 1, -1, -1):
            istart, istop = start_stop_pairs[i]
            S = T[istart:istop, jstart:jstop]
            if j - i > 1:
                S = S - R[istart:istop, istop:jstart].dot(
                    R[istop:jstart, jstart:jstop])

            # Invoke LAPACK.
            # For more details, see the solve_sylvester implementation
            # and the fortran dtrsyl and ztrsyl docs.
            Rii = R[istart:istop, istart:istop]
            Rjj = R[jstart:jstop, jstart:jstop]
            if keep_it_real:
                x, scale, _info = dtrsyl(Rii, Rjj, S)
            else:
                x, scale, _info = ztrsyl(Rii, Rjj, S)
            # ?trsyl solves Rii @ x + x @ Rjj = scale * S, where the output
            # factor scale <= 1 is chosen to avoid overflow in x. The
            # solution for the unscaled right-hand side S is x / scale.
            R[istart:istop, jstart:jstop] = x / scale

    # Return the matrix square root.
    return R

115 

116 

def sqrtm(A, disp=True, blocksize=64):
    """
    Matrix square root.

    Parameters
    ----------
    A : (N, N) array_like
        Matrix whose square root to evaluate
    disp : bool, optional
        If True, print a message when no square root could be found and
        return only the result. If False, return the estimated error
        together with the result instead of printing. (Default: True)
    blocksize : integer, optional
        If the blocksize is not degenerate with respect to the
        size of the input array, then use a blocked algorithm. (Default: 64)

    Returns
    -------
    sqrtm : (N, N) ndarray
        Value of the sqrt function at `A`. The dtype is float or complex.
        The precision (data size) follows the precision of the input `A`:
        a float result has the same item size as `A`, and a complex result
        has components of that same precision (e.g. float64 or complex128
        input gives a complex128 result). The size is clipped to the range
        available for each dtype kind. If no square root is found, the
        result is filled with NaN.

    errest : float
        (if disp == False)

        Error estimate ``||X @ X - A||_F**2 / ||A||_F``, where ``X`` is the
        computed square root. It is ``inf`` when the residual cannot be
        evaluated (e.g. ``X`` contains NaNs).

    Raises
    ------
    ValueError
        If `A` is not 2-D or `blocksize` is smaller than 1.

    References
    ----------
    .. [1] Edvin Deadman, Nicholas J. Higham, Rui Ralha (2013)
           "Blocked Schur Algorithms for Computing the Matrix Square Root",
           Lecture Notes in Computer Science, 7782. pp. 171-182.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import sqrtm
    >>> a = np.array([[1.0, 3.0], [1.0, 4.0]])
    >>> r = sqrtm(a)
    >>> r
    array([[ 0.75592895,  1.13389342],
           [ 0.37796447,  1.88982237]])
    >>> r.dot(r)
    array([[ 1.,  3.],
           [ 1.,  4.]])

    """
    input_dtype = np.asarray(A).dtype
    # Item size (bytes) of one real component of the input, used below to
    # give the result the same precision as ``A``. A complex item stores
    # two real components, so its size is halved; otherwise complex128
    # input would be promoted to complex256.
    real_size = input_dtype.itemsize
    if input_dtype.kind == 'c':
        real_size //= 2

    A = _asarray_validated(A, check_finite=True, as_inexact=True)
    if A.ndim != 2:
        raise ValueError("Non-matrix input to matrix function.")
    if blocksize < 1:
        raise ValueError("The blocksize should be at least 1.")

    keep_it_real = np.isrealobj(A)
    if keep_it_real:
        T, Z = schur(A)
        # The real Schur form can be quasi upper triangular; switch to the
        # complex Schur form whenever T is not truly upper triangular.
        if not np.array_equal(T, np.triu(T)):
            T, Z = rsf2csf(T, Z)
    else:
        T, Z = schur(A, output='complex')

    failflag = False
    try:
        R = _sqrtm_triu(T, blocksize=blocksize)
        ZH = np.conjugate(Z).T
        X = Z.dot(R).dot(ZH)
        if np.iscomplexobj(X):
            # complex byte size range: c8 ~ c32.
            # c32 (complex256) might not be supported in some environments.
            max_complex = 32 if hasattr(np, 'complex256') else 16
            X = X.astype(f"c{np.clip(2 * real_size, 8, max_complex)}",
                         copy=False)
        else:
            # float byte size range: f2 ~ f16
            X = X.astype(f"f{np.clip(real_size, 2, 16)}", copy=False)
    except SqrtmError:
        failflag = True
        X = np.full_like(A, np.nan)

    if disp:
        if failflag:
            print("Failed to find a square root.")
        return X

    try:
        arg2 = norm(X.dot(X) - A, 'fro')**2 / norm(A, 'fro')
    except ValueError:
        # NaNs in matrix
        arg2 = np.inf

    return X, arg2