Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/scipy/sparse/linalg/_isolve/iterative.py: 5%
429 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-22 06:44 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-22 06:44 +0000
1import warnings
2import numpy as np
3from scipy.sparse.linalg._interface import LinearOperator
4from .utils import make_system
5from scipy.linalg import get_lapack_funcs
6from scipy._lib.deprecation import _NoValue, _deprecate_positional_args
8__all__ = ['bicg', 'bicgstab', 'cg', 'cgs', 'gmres', 'qmr']
11def _get_atol_rtol(name, b_norm, tol=_NoValue, atol=0., rtol=1e-5):
12 """
13 A helper function to handle tolerance deprecations and normalization
14 """
15 if tol is not _NoValue:
16 msg = (f"'scipy.sparse.linalg.{name}' keyword argument `tol` is "
17 "deprecated in favor of `rtol` and will be removed in SciPy "
18 "v1.14.0. Until then, if set, it will override `rtol`.")
19 warnings.warn(msg, category=DeprecationWarning, stacklevel=4)
20 rtol = float(tol) if tol is not None else rtol
22 if atol == 'legacy':
23 msg = (f"'scipy.sparse.linalg.{name}' called with `atol='legacy'`. "
24 "This behavior is deprecated and will result in an error in "
25 "SciPy v1.14.0. To preserve current behaviour, set `atol=0.0`.")
26 warnings.warn(msg, category=DeprecationWarning, stacklevel=4)
27 atol = 0
29 # this branch is only hit from gcrotmk/lgmres/tfqmr
30 if atol is None:
31 msg = (f"'scipy.sparse.linalg.{name}' called without specifying "
32 "`atol`. This behavior is deprecated and will result in an "
33 "error in SciPy v1.14.0. To preserve current behaviour, set "
34 "`atol=rtol`, or, to adopt the future default, set `atol=0.0`.")
35 warnings.warn(msg, category=DeprecationWarning, stacklevel=4)
36 atol = rtol
38 atol = max(float(atol), float(rtol) * float(b_norm))
40 return atol, rtol
@_deprecate_positional_args(version="1.14")
def bicg(A, b, x0=None, *, tol=_NoValue, maxiter=None, M=None, callback=None,
         atol=0., rtol=1e-5):
    """Use BIConjugate Gradient iteration to solve ``Ax = b``.

    Parameters
    ----------
    A : {sparse matrix, ndarray, LinearOperator}
        The real or complex N-by-N matrix of the linear system.
        Alternatively, ``A`` can be a linear operator which can
        produce ``Ax`` and ``A^T x`` using, e.g.,
        ``scipy.sparse.linalg.LinearOperator``.
    b : ndarray
        Right hand side of the linear system. Has shape (N,) or (N,1).
    x0 : ndarray
        Starting guess for the solution.
    rtol, atol : float, optional
        Parameters for the convergence test. For convergence,
        ``norm(b - A @ x) <= max(rtol*norm(b), atol)`` should be satisfied.
        The default is ``atol=0.`` and ``rtol=1e-5``.
    maxiter : integer
        Maximum number of iterations.  Iteration will stop after maxiter
        steps even if the specified tolerance has not been achieved.
        Defaults to ``10 * N``.
    M : {sparse matrix, ndarray, LinearOperator}
        Preconditioner for A.  The preconditioner should approximate the
        inverse of A.  Effective preconditioning dramatically improves the
        rate of convergence, which implies that fewer iterations are needed
        to reach a given error tolerance.
    callback : function
        User-supplied function to call after each iteration.  It is called
        as callback(xk), where xk is the current solution vector.
    tol : float, optional, deprecated

        .. deprecated:: 1.12.0
           `bicg` keyword argument ``tol`` is deprecated in favor of ``rtol``
           and will be removed in SciPy 1.14.0.

    Returns
    -------
    x : ndarray
        The converged solution (or the last iterate when the iteration
        stopped early).
    info : integer
        Provides convergence information:
            0  : successful exit
            >0 : convergence to tolerance not achieved, number of iterations
            <0 : parameter breakdown (``-10``: ``rho`` fell below its
                 threshold, ``-11``: ``<ptilde, A p>`` is exactly zero)

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.sparse import csc_matrix
    >>> from scipy.sparse.linalg import bicg
    >>> A = csc_matrix([[3, 2, 0], [1, -1, 0], [0, 5, 1.]])
    >>> b = np.array([2., 4., -1.])
    >>> x, exitCode = bicg(A, b, atol=1e-5)
    >>> print(exitCode)  # 0 indicates successful convergence
    0
    >>> np.allclose(A.dot(x), b)
    True

    """
    A, M, x, b, postprocess = make_system(A, M, x0, b)
    bnrm2 = np.linalg.norm(b)

    atol, _ = _get_atol_rtol('bicg', bnrm2, tol, atol, rtol)

    # A zero right-hand side is answered with ``b`` itself.
    if bnrm2 == 0:
        return postprocess(b), 0

    n = len(b)
    # vdot conjugates its first argument, which the complex case needs.
    dotprod = np.vdot if np.iscomplexobj(x) else np.dot

    if maxiter is None:
        maxiter = n*10

    matvec, rmatvec = A.matvec, A.rmatvec
    psolve, rpsolve = M.matvec, M.rmatvec

    rhotol = np.finfo(x.dtype.char).eps**2

    # Dummy values to initialize vars, silence linter warnings
    rho_prev, p, ptilde = None, None, None

    # Skip the matvec when the starting guess is identically zero.
    r = b - matvec(x) if x.any() else b.copy()
    rtilde = r.copy()

    for iteration in range(maxiter):
        if np.linalg.norm(r) < atol:  # Are we done?
            return postprocess(x), 0

        z = psolve(r)
        ztilde = rpsolve(rtilde)
        # order matters in this dot product
        rho_cur = dotprod(rtilde, z)

        if np.abs(rho_cur) < rhotol:  # Breakdown case
            # Return the current iterate, like every other exit path.
            return postprocess(x), -10

        if iteration > 0:
            beta = rho_cur / rho_prev
            p *= beta
            p += z
            ptilde *= beta.conj()
            ptilde += ztilde
        else:  # First spin
            p = z.copy()
            ptilde = ztilde.copy()

        q = matvec(p)
        qtilde = rmatvec(ptilde)
        rv = dotprod(ptilde, q)

        if rv == 0:
            return postprocess(x), -11

        alpha = rho_cur / rv
        x += alpha*p
        r -= alpha*q
        rtilde -= alpha.conj()*qtilde
        rho_prev = rho_cur

        if callback:
            callback(x)

    # Iteration budget exhausted: return incomplete progress.
    return postprocess(x), maxiter
@_deprecate_positional_args(version="1.14")
def bicgstab(A, b, x0=None, *, tol=_NoValue, maxiter=None, M=None,
             callback=None, atol=0., rtol=1e-5):
    """Use BIConjugate Gradient STABilized iteration to solve ``Ax = b``.

    Parameters
    ----------
    A : {sparse matrix, ndarray, LinearOperator}
        The real or complex N-by-N matrix of the linear system.
        Alternatively, ``A`` can be a linear operator which can
        produce ``Ax`` using, e.g.,
        ``scipy.sparse.linalg.LinearOperator``.
    b : ndarray
        Right hand side of the linear system. Has shape (N,) or (N,1).
    x0 : ndarray
        Starting guess for the solution.
    rtol, atol : float, optional
        Parameters for the convergence test. For convergence,
        ``norm(b - A @ x) <= max(rtol*norm(b), atol)`` should be satisfied.
        The default is ``atol=0.`` and ``rtol=1e-5``.
    maxiter : integer
        Maximum number of iterations.  Iteration will stop after maxiter
        steps even if the specified tolerance has not been achieved.
        Defaults to ``10 * N``.
    M : {sparse matrix, ndarray, LinearOperator}
        Preconditioner for A.  The preconditioner should approximate the
        inverse of A.  Effective preconditioning dramatically improves the
        rate of convergence, which implies that fewer iterations are needed
        to reach a given error tolerance.
    callback : function
        User-supplied function to call after each iteration.  It is called
        as callback(xk), where xk is the current solution vector.
    tol : float, optional, deprecated

        .. deprecated:: 1.12.0
           `bicgstab` keyword argument ``tol`` is deprecated in favor of
           ``rtol`` and will be removed in SciPy 1.14.0.

    Returns
    -------
    x : ndarray
        The converged solution (or the last iterate when the iteration
        stopped early).
    info : integer
        Provides convergence information:
            0  : successful exit
            >0 : convergence to tolerance not achieved, number of iterations
            <0 : parameter breakdown (``-10``: ``rho`` fell below its
                 threshold, ``-11``: ``omega`` fell below its threshold or
                 ``<rtilde, v>`` is exactly zero)

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.sparse import csc_matrix
    >>> from scipy.sparse.linalg import bicgstab
    >>> R = np.array([[4, 2, 0, 1],
    ...               [3, 0, 0, 2],
    ...               [0, 1, 1, 1],
    ...               [0, 2, 1, 0]])
    >>> A = csc_matrix(R)
    >>> b = np.array([-1, -0.5, -1, 2])
    >>> x, exit_code = bicgstab(A, b, atol=1e-5)
    >>> print(exit_code)  # 0 indicates successful convergence
    0
    >>> np.allclose(A.dot(x), b)
    True

    """
    A, M, x, b, postprocess = make_system(A, M, x0, b)
    bnrm2 = np.linalg.norm(b)

    atol, _ = _get_atol_rtol('bicgstab', bnrm2, tol, atol, rtol)

    # A zero right-hand side is answered with ``b`` itself.
    if bnrm2 == 0:
        return postprocess(b), 0

    n = len(b)

    # vdot conjugates its first argument, which the complex case needs.
    dotprod = np.vdot if np.iscomplexobj(x) else np.dot

    if maxiter is None:
        maxiter = n*10

    matvec = A.matvec
    psolve = M.matvec

    # These values make no sense but coming from original Fortran code
    # sqrt might have been meant instead.
    rhotol = np.finfo(x.dtype.char).eps**2
    omegatol = rhotol

    # Dummy values to initialize vars, silence linter warnings
    rho_prev, omega, alpha, p, v = None, None, None, None, None

    # Skip the matvec when the starting guess is identically zero.
    r = b - matvec(x) if x.any() else b.copy()
    rtilde = r.copy()

    for iteration in range(maxiter):
        if np.linalg.norm(r) < atol:  # Are we done?
            return postprocess(x), 0

        rho = dotprod(rtilde, r)
        if np.abs(rho) < rhotol:  # rho breakdown
            return postprocess(x), -10

        if iteration > 0:
            if np.abs(omega) < omegatol:  # omega breakdown
                return postprocess(x), -11

            beta = (rho / rho_prev) * (alpha / omega)
            p -= omega*v
            p *= beta
            p += r
        else:  # First spin
            s = np.empty_like(r)
            p = r.copy()

        phat = psolve(p)
        v = matvec(phat)
        rv = dotprod(rtilde, v)
        if rv == 0:
            return postprocess(x), -11
        alpha = rho / rv
        r -= alpha*v
        s[:] = r[:]

        # Half-step convergence: only the alpha update is needed.
        if np.linalg.norm(s) < atol:
            x += alpha*phat
            return postprocess(x), 0

        shat = psolve(s)
        t = matvec(shat)
        omega = dotprod(t, s) / dotprod(t, t)
        x += alpha*phat
        x += omega*shat
        r -= omega*t
        rho_prev = rho

        if callback:
            callback(x)

    # Iteration budget exhausted: return incomplete progress.
    return postprocess(x), maxiter
@_deprecate_positional_args(version="1.14")
def cg(A, b, x0=None, *, tol=_NoValue, maxiter=None, M=None, callback=None,
       atol=0., rtol=1e-5):
    """Solve ``Ax = b`` with the (preconditioned) Conjugate Gradient method.

    Parameters
    ----------
    A : {sparse matrix, ndarray, LinearOperator}
        The real or complex N-by-N matrix of the linear system.
        ``A`` must represent a hermitian, positive definite matrix.
        A linear operator able to produce ``Ax`` (e.g. a
        ``scipy.sparse.linalg.LinearOperator``) is accepted as well.
    b : ndarray
        Right hand side of the linear system. Has shape (N,) or (N,1).
    x0 : ndarray
        Starting guess for the solution.
    rtol, atol : float, optional
        Convergence parameters. The iteration stops successfully once
        ``norm(b - A @ x) <= max(rtol*norm(b), atol)``.
        The default is ``atol=0.`` and ``rtol=1e-5``.
    maxiter : integer
        Upper bound on the number of iterations; the iteration stops after
        ``maxiter`` steps even when the tolerance has not been reached.
        Defaults to ``10 * N``.
    M : {sparse matrix, ndarray, LinearOperator}
        Preconditioner for A; it should approximate the inverse of A.
        A good preconditioner reduces the number of iterations needed to
        reach a given tolerance.
    callback : function
        Called as ``callback(xk)`` after each iteration, where ``xk`` is
        the current solution vector.
    tol : float, optional, deprecated

        .. deprecated:: 1.12.0
           `cg` keyword argument ``tol`` is deprecated in favor of ``rtol``
           and will be removed in SciPy 1.14.0.

    Returns
    -------
    x : ndarray
        The converged solution.
    info : integer
        Provides convergence information:
            0  : successful exit
            >0 : convergence to tolerance not achieved, number of iterations

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.sparse import csc_matrix
    >>> from scipy.sparse.linalg import cg
    >>> P = np.array([[4, 0, 1, 0],
    ...               [0, 5, 0, 0],
    ...               [1, 0, 3, 2],
    ...               [0, 0, 2, 4]])
    >>> A = csc_matrix(P)
    >>> b = np.array([-1, -0.5, -1, 2])
    >>> x, exit_code = cg(A, b, atol=1e-5)
    >>> print(exit_code)    # 0 indicates successful convergence
    0
    >>> np.allclose(A.dot(x), b)
    True

    """
    A, M, x, b, postprocess = make_system(A, M, x0, b)
    rhs_norm = np.linalg.norm(b)

    atol, _ = _get_atol_rtol('cg', rhs_norm, tol, atol, rtol)

    # Trivial system: a zero right-hand side is answered with ``b``.
    if rhs_norm == 0:
        return postprocess(b), 0

    if maxiter is None:
        maxiter = len(b)*10

    if np.iscomplexobj(x):
        inner = np.vdot
    else:
        inner = np.dot

    apply_A = A.matvec
    apply_M = M.matvec

    # Initial residual; the matvec is skipped for an all-zero guess.
    if x.any():
        resid = b - apply_A(x)
    else:
        resid = b.copy()

    # Placeholders, assigned during the first pass through the loop.
    rho_old, direction = None, None

    for step in range(maxiter):
        if np.linalg.norm(resid) < atol:  # converged
            return postprocess(x), 0

        precond_resid = apply_M(resid)
        rho_new = inner(resid, precond_resid)

        if step == 0:
            # First search direction is the preconditioned residual.
            direction = np.empty_like(resid)
            direction[:] = precond_resid[:]
        else:
            beta = rho_new / rho_old
            direction *= beta
            direction += precond_resid

        A_direction = apply_A(direction)
        step_len = rho_new / inner(direction, A_direction)
        x += step_len*direction
        resid -= step_len*A_direction
        rho_old = rho_new

        if callback:
            callback(x)

    # All iterations used up: hand back the partial result.
    return postprocess(x), maxiter
@_deprecate_positional_args(version="1.14")
def cgs(A, b, x0=None, *, tol=_NoValue, maxiter=None, M=None, callback=None,
        atol=0., rtol=1e-5):
    """Use Conjugate Gradient Squared iteration to solve ``Ax = b``.

    Parameters
    ----------
    A : {sparse matrix, ndarray, LinearOperator}
        The real-valued N-by-N matrix of the linear system.
        Alternatively, ``A`` can be a linear operator which can
        produce ``Ax`` using, e.g.,
        ``scipy.sparse.linalg.LinearOperator``.
    b : ndarray
        Right hand side of the linear system. Has shape (N,) or (N,1).
    x0 : ndarray
        Starting guess for the solution.
    rtol, atol : float, optional
        Parameters for the convergence test. For convergence,
        ``norm(b - A @ x) <= max(rtol*norm(b), atol)`` should be satisfied.
        The default is ``atol=0.`` and ``rtol=1e-5``.
    maxiter : integer
        Maximum number of iterations.  Iteration will stop after maxiter
        steps even if the specified tolerance has not been achieved.
        Defaults to ``10 * N``.
    M : {sparse matrix, ndarray, LinearOperator}
        Preconditioner for A.  The preconditioner should approximate the
        inverse of A.  Effective preconditioning dramatically improves the
        rate of convergence, which implies that fewer iterations are needed
        to reach a given error tolerance.
    callback : function
        User-supplied function to call after each iteration.  It is called
        as callback(xk), where xk is the current solution vector.
    tol : float, optional, deprecated

        .. deprecated:: 1.12.0
           `cgs` keyword argument ``tol`` is deprecated in favor of ``rtol``
           and will be removed in SciPy 1.14.0.

    Returns
    -------
    x : ndarray
        The converged solution (or the last iterate when the iteration
        stopped early).
    info : integer
        Provides convergence information:
            0  : successful exit
            >0 : convergence to tolerance not achieved, number of iterations
            <0 : parameter breakdown (``-10``: ``rho`` fell below its
                 threshold, ``-11``: ``<rtilde, vhat>`` is exactly zero)

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.sparse import csc_matrix
    >>> from scipy.sparse.linalg import cgs
    >>> R = np.array([[4, 2, 0, 1],
    ...               [3, 0, 0, 2],
    ...               [0, 1, 1, 1],
    ...               [0, 2, 1, 0]])
    >>> A = csc_matrix(R)
    >>> b = np.array([-1, -0.5, -1, 2])
    >>> x, exit_code = cgs(A, b)
    >>> print(exit_code)  # 0 indicates successful convergence
    0
    >>> np.allclose(A.dot(x), b)
    True

    """
    A, M, x, b, postprocess = make_system(A, M, x0, b)
    bnrm2 = np.linalg.norm(b)

    atol, _ = _get_atol_rtol('cgs', bnrm2, tol, atol, rtol)

    # A zero right-hand side is answered with ``b`` itself.
    if bnrm2 == 0:
        return postprocess(b), 0

    n = len(b)

    # vdot conjugates its first argument, which the complex case needs.
    dotprod = np.vdot if np.iscomplexobj(x) else np.dot

    if maxiter is None:
        maxiter = n*10

    matvec = A.matvec
    psolve = M.matvec

    rhotol = np.finfo(x.dtype.char).eps**2

    # Skip the matvec when the starting guess is identically zero.
    r = b - matvec(x) if x.any() else b.copy()

    rtilde = r.copy()

    # Dummy values to initialize vars, silence linter warnings
    rho_prev, p, u, q = None, None, None, None

    for iteration in range(maxiter):
        if np.linalg.norm(r) < atol:  # Are we done?
            return postprocess(x), 0

        rho_cur = dotprod(rtilde, r)
        if np.abs(rho_cur) < rhotol:  # Breakdown case
            # Return the current iterate, like every other exit path.
            return postprocess(x), -10

        if iteration > 0:
            beta = rho_cur / rho_prev

            # u = r + beta * q
            # p = u + beta * (q + beta * p);
            u[:] = r[:]
            u += beta*q

            p *= beta
            p += q
            p *= beta
            p += u

        else:  # First spin
            p = r.copy()
            u = r.copy()
            q = np.empty_like(r)

        phat = psolve(p)
        vhat = matvec(phat)
        rv = dotprod(rtilde, vhat)

        if rv == 0:  # Dot product breakdown
            return postprocess(x), -11

        alpha = rho_cur / rv
        q[:] = u[:]
        q -= alpha*vhat
        uhat = psolve(u + q)
        x += alpha*uhat

        # Due to numerical error build-up the actual residual is computed
        # instead of the following two lines that were in the original
        # FORTRAN templates, still using a single matvec.

        # qhat = matvec(uhat)
        # r -= alpha*qhat
        r = b - matvec(x)

        rho_prev = rho_cur

        if callback:
            callback(x)

    # Iteration budget exhausted: return incomplete progress.
    return postprocess(x), maxiter
@_deprecate_positional_args(version="1.14")
def gmres(A, b, x0=None, *, tol=_NoValue, restart=None, maxiter=None, M=None,
          callback=None, restrt=_NoValue, atol=0., callback_type=None,
          rtol=1e-5):
    """
    Use Generalized Minimal RESidual iteration to solve ``Ax = b``.

    Parameters
    ----------
    A : {sparse matrix, ndarray, LinearOperator}
        The real or complex N-by-N matrix of the linear system.
        Alternatively, ``A`` can be a linear operator which can
        produce ``Ax`` using, e.g.,
        ``scipy.sparse.linalg.LinearOperator``.
    b : ndarray
        Right hand side of the linear system. Has shape (N,) or (N,1).
    x0 : ndarray
        Starting guess for the solution (a vector of zeros by default).
    atol, rtol : float
        Parameters for the convergence test. For convergence,
        ``norm(b - A @ x) <= max(rtol*norm(b), atol)`` should be satisfied.
        The default is ``atol=0.`` and ``rtol=1e-5``.
    restart : int, optional
        Number of iterations between restarts. Larger values increase
        iteration cost, but may be necessary for convergence.
        If omitted, ``min(20, n)`` is used.
    maxiter : int, optional
        Maximum number of iterations (restart cycles).  Iteration will stop
        after maxiter steps even if the specified tolerance has not been
        achieved. See `callback_type`. Defaults to ``10 * n``.
    M : {sparse matrix, ndarray, LinearOperator}
        Inverse of the preconditioner of A.  M should approximate the
        inverse of A and be easy to solve for (see Notes).  Effective
        preconditioning dramatically improves the rate of convergence,
        which implies that fewer iterations are needed to reach a given
        error tolerance.  By default, no preconditioner is used.
        In this implementation, left preconditioning is used,
        and the preconditioned residual is minimized. However, the final
        convergence is tested with respect to the ``b - A @ x`` residual.
    callback : function
        User-supplied function to call after each iteration.  It is called
        as `callback(args)`, where `args` are selected by `callback_type`.
    callback_type : {'x', 'pr_norm', 'legacy'}, optional
        Callback function argument requested:
          - ``x``: current iterate (ndarray), called on every restart
          - ``pr_norm``: relative (preconditioned) residual norm (float),
            called on every inner iteration
          - ``legacy`` (default): same as ``pr_norm``, but also changes the
            meaning of `maxiter` to count inner iterations instead of restart
            cycles.

        This keyword has no effect if `callback` is not set.
    restrt : int, optional, deprecated

        .. deprecated:: 0.11.0
           `gmres` keyword argument ``restrt`` is deprecated in favor of
           ``restart`` and will be removed in SciPy 1.14.0.
    tol : float, optional, deprecated

        .. deprecated:: 1.12.0
           `gmres` keyword argument ``tol`` is deprecated in favor of ``rtol``
           and will be removed in SciPy 1.14.0

    Returns
    -------
    x : ndarray
        The converged solution.
    info : int
        Provides convergence information:
            0  : successful exit
            >0 : convergence to tolerance not achieved, number of iterations

    Raises
    ------
    ValueError
        If both ``restart`` and ``restrt`` are given, or if `callback_type`
        is not one of the documented values.

    See Also
    --------
    LinearOperator

    Notes
    -----
    A preconditioner, P, is chosen such that P is close to A but easy to solve
    for. The preconditioner parameter required by this routine is
    ``M = P^-1``. The inverse should preferably not be calculated
    explicitly.  Rather, use the following template to produce M::

      # Construct a linear operator that computes P^-1 @ x.
      import scipy.sparse.linalg as spla
      M_x = lambda x: spla.spsolve(P, x)
      M = spla.LinearOperator((n, n), M_x)

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.sparse import csc_matrix
    >>> from scipy.sparse.linalg import gmres
    >>> A = csc_matrix([[3, 2, 0], [1, -1, 0], [0, 5, 1]], dtype=float)
    >>> b = np.array([2, 4, -1], dtype=float)
    >>> x, exitCode = gmres(A, b, atol=1e-5)
    >>> print(exitCode)            # 0 indicates successful convergence
    0
    >>> np.allclose(A.dot(x), b)
    True
    """

    # Handle the deprecation frenzy
    if restrt not in (None, _NoValue) and restart:
        raise ValueError("Cannot specify both 'restart' and 'restrt'"
                         " keywords. Also 'restrt' is deprecated"
                         " and will be removed in SciPy 1.14.0. Use "
                         "'restart' instead.")
    if restrt is not _NoValue:
        msg = ("'gmres' keyword argument 'restrt' is deprecated "
               "in favor of 'restart' and will be removed in SciPy"
               " 1.14.0. Until then, if set, 'restrt' will override 'restart'."
               )
        warnings.warn(msg, DeprecationWarning, stacklevel=3)
        # An explicit ``restrt=None`` means "not set" (see the check above),
        # so it must not wipe out a user-supplied ``restart``.
        if restrt is not None:
            restart = restrt

    if callback is not None and callback_type is None:
        # Warn about 'callback_type' semantic changes.
        # Probably should be removed only in far future, Scipy 2.0 or so.
        msg = ("scipy.sparse.linalg.gmres called without specifying "
               "`callback_type`. The default value will be changed in"
               " a future release. For compatibility, specify a value "
               "for `callback_type` explicitly, e.g., "
               "``gmres(..., callback_type='pr_norm')``, or to retain the "
               "old behavior ``gmres(..., callback_type='legacy')``"
               )
        warnings.warn(msg, category=DeprecationWarning, stacklevel=3)

    if callback_type is None:
        callback_type = 'legacy'

    if callback_type not in ('x', 'pr_norm', 'legacy'):
        raise ValueError(f"Unknown callback_type: {callback_type!r}")

    # Without a callback none of the callback_type branches may fire.
    if callback is None:
        callback_type = None

    A, M, x, b, postprocess = make_system(A, M, x0, b)
    matvec = A.matvec
    psolve = M.matvec
    n = len(b)
    bnrm2 = np.linalg.norm(b)

    atol, _ = _get_atol_rtol('gmres', bnrm2, tol, atol, rtol)

    # A zero right-hand side is answered with ``b`` itself.
    if bnrm2 == 0:
        return postprocess(b), 0

    eps = np.finfo(x.dtype.char).eps

    # vdot conjugates its first argument, which the complex case needs.
    dotprod = np.vdot if np.iscomplexobj(x) else np.dot

    if maxiter is None:
        maxiter = n*10

    if restart is None:
        restart = 20
    restart = min(restart, n)

    Mb_nrm2 = np.linalg.norm(psolve(b))

    # ====================================================
    # =========== Tolerance control from gh-8400 =========
    # ====================================================
    # Tolerance passed to GMRESREVCOM applies to the inner
    # iteration and deals with the left-preconditioned
    # residual.
    ptol_max_factor = 1.
    ptol = Mb_nrm2 * min(ptol_max_factor, atol / bnrm2)
    presid = 0.
    # ====================================================
    lartg = get_lapack_funcs('lartg', dtype=x.dtype)

    # allocate internal variables
    v = np.empty([restart+1, n], dtype=x.dtype)
    h = np.zeros([restart, restart+1], dtype=x.dtype)
    givens = np.zeros([restart, 2], dtype=x.dtype)

    # legacy iteration count
    inner_iter = 0

    # Initial residual.  It is evaluated before the loop so that ``rnorm``
    # is defined even when ``maxiter`` is 0 and the loop body never runs.
    r = b - matvec(x) if x.any() else b.copy()
    rnorm = np.linalg.norm(r)
    if rnorm < atol:  # Are we done?
        return postprocess(x), 0

    for _ in range(maxiter):
        v[0, :] = psolve(r)
        tmp = np.linalg.norm(v[0, :])
        v[0, :] *= (1 / tmp)
        # RHS of the Hessenberg problem
        S = np.zeros(restart+1, dtype=x.dtype)
        S[0] = tmp

        breakdown = False
        for col in range(restart):
            av = matvec(v[col, :])
            w = psolve(av)

            # Modified Gram-Schmidt
            h0 = np.linalg.norm(w)
            for k in range(col+1):
                tmp = dotprod(v[k, :], w)
                h[col, k] = tmp
                w -= tmp*v[k, :]

            h1 = np.linalg.norm(w)
            h[col, col + 1] = h1
            v[col + 1, :] = w[:]

            # Exact solution indicator
            if h1 <= eps*h0:
                h[col, col + 1] = 0
                breakdown = True
            else:
                v[col + 1, :] *= (1 / h1)

            # apply past Givens rotations to current h column
            for k in range(col):
                c, s = givens[k, 0], givens[k, 1]
                n0, n1 = h[col, [k, k+1]]
                h[col, [k, k + 1]] = [c*n0 + s*n1, -s.conj()*n0 + c*n1]

            # get and apply current rotation to h and S
            c, s, mag = lartg(h[col, col], h[col, col+1])
            givens[col, :] = [c, s]
            h[col, [col, col+1]] = mag, 0

            # S[col+1] component is always 0
            tmp = -np.conjugate(s)*S[col]
            S[[col, col + 1]] = [c*S[col], tmp]
            presid = np.abs(tmp)
            inner_iter += 1

            if callback_type in ('legacy', 'pr_norm'):
                callback(presid / bnrm2)
            # Legacy behavior
            if callback_type == 'legacy' and inner_iter == maxiter:
                break
            if presid <= ptol or breakdown:
                break

        # Solve h(col, col) upper triangular system and allow pseudo-solve
        # singular cases as in (but without the f2py copies):
        # y = trsv(h[:col+1, :col+1].T, S[:col+1])

        if h[col, col] == 0:
            S[col] = 0

        y = np.zeros([col+1], dtype=x.dtype)
        y[:] = S[:col+1]
        for k in range(col, 0, -1):
            if y[k] != 0:
                y[k] /= h[k, k]
                tmp = y[k]
                y[:k] -= tmp*h[k, :k]
        if y[0] != 0:
            y[0] /= h[0, 0]

        x += y @ v[:col+1, :]

        r = b - matvec(x)
        rnorm = np.linalg.norm(r)

        # Legacy exit
        if callback_type == 'legacy' and inner_iter == maxiter:
            return postprocess(x), 0 if rnorm <= atol else maxiter

        if callback_type == 'x':
            callback(x)

        if rnorm <= atol:
            break
        elif breakdown:
            # Reached breakdown (= exact solution), but the external
            # tolerance check failed. Bail out with failure.
            break
        elif presid <= ptol:
            # Inner loop passed but outer didn't
            ptol_max_factor = max(eps, 0.25 * ptol_max_factor)
        else:
            ptol_max_factor = min(1.0, 1.5 * ptol_max_factor)

        ptol = presid * min(ptol_max_factor, atol / rnorm)

    info = 0 if (rnorm <= atol) else maxiter
    return postprocess(x), info
@_deprecate_positional_args(version="1.14")
def qmr(A, b, x0=None, *, tol=_NoValue, maxiter=None, M1=None, M2=None,
        callback=None, atol=0., rtol=1e-5):
    """Use Quasi-Minimal Residual iteration to solve ``Ax = b``.

    Parameters
    ----------
    A : {sparse matrix, ndarray, LinearOperator}
        The real-valued N-by-N matrix of the linear system.
        Alternatively, ``A`` can be a linear operator which can
        produce ``Ax`` and ``A^T x`` using, e.g.,
        ``scipy.sparse.linalg.LinearOperator``.
    b : ndarray
        Right hand side of the linear system. Has shape (N,) or (N,1).
    x0 : ndarray
        Starting guess for the solution.
    atol, rtol : float, optional
        Parameters for the convergence test. For convergence,
        ``norm(b - A @ x) <= max(rtol*norm(b), atol)`` should be satisfied.
        The default is ``atol=0.`` and ``rtol=1e-5``.
    maxiter : integer
        Maximum number of iterations.  Iteration will stop after maxiter
        steps even if the specified tolerance has not been achieved.
        Defaults to ``10 * N``.
    M1 : {sparse matrix, ndarray, LinearOperator}
        Left preconditioner for A. If only ``M2`` is given, the identity
        is used for ``M1``.
    M2 : {sparse matrix, ndarray, LinearOperator}
        Right preconditioner for A. Used together with the left
        preconditioner M1.  The matrix M1@A@M2 should be better
        conditioned than A alone. If only ``M1`` is given, the identity
        is used for ``M2``.
    callback : function
        User-supplied function to call after each iteration.  It is called
        as callback(xk), where xk is the current solution vector.
    tol : float, optional, deprecated

        .. deprecated:: 1.12.0
           `qmr` keyword argument ``tol`` is deprecated in favor of ``rtol``
           and will be removed in SciPy 1.14.0.

    Returns
    -------
    x : ndarray
        The converged solution (or the last iterate when the iteration
        stopped early).
    info : integer
        Provides convergence information:
            0  : successful exit
            >0 : convergence to tolerance not achieved, number of iterations
            <0 : parameter breakdown (``-10`` rho, ``-11`` beta,
                 ``-12`` gamma, ``-13`` delta, ``-14`` epsilon, ``-15`` xi)

    See Also
    --------
    LinearOperator

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.sparse import csc_matrix
    >>> from scipy.sparse.linalg import qmr
    >>> A = csc_matrix([[3., 2., 0.], [1., -1., 0.], [0., 5., 1.]])
    >>> b = np.array([2., 4., -1.])
    >>> x, exitCode = qmr(A, b, atol=1e-5)
    >>> print(exitCode)            # 0 indicates successful convergence
    0
    >>> np.allclose(A.dot(x), b)
    True
    """
    A_ = A
    # The preconditioner slot of make_system is unused here: QMR takes a
    # separate left/right pair, which is handled below.
    A, _, x, b, postprocess = make_system(A, None, x0, b)
    bnrm2 = np.linalg.norm(b)

    atol, _ = _get_atol_rtol('qmr', bnrm2, tol, atol, rtol)

    # A zero right-hand side is answered with ``b`` itself.
    if bnrm2 == 0:
        return postprocess(b), 0

    def identity(vec):
        return vec

    def as_preconditioner(P):
        # Normalize a user-supplied preconditioner to something that has
        # both ``matvec`` and ``rmatvec``.
        if P is None:
            return LinearOperator(A.shape, matvec=identity, rmatvec=identity)
        if hasattr(P, 'matvec') and hasattr(P, 'rmatvec'):
            # Already operator-like: use it untouched.
            return P
        # ndarray / sparse matrix, as the docstring allows.
        from scipy.sparse.linalg._interface import aslinearoperator
        return aslinearoperator(P)

    if M1 is None and M2 is None and hasattr(A_, 'psolve'):
        # ``A`` brings its own left/right preconditioner solves.
        def left_psolve(b):
            return A_.psolve(b, 'left')

        def right_psolve(b):
            return A_.psolve(b, 'right')

        def left_rpsolve(b):
            return A_.rpsolve(b, 'left')

        def right_rpsolve(b):
            return A_.rpsolve(b, 'right')
        M1 = LinearOperator(A.shape,
                            matvec=left_psolve,
                            rmatvec=left_rpsolve)
        M2 = LinearOperator(A.shape,
                            matvec=right_psolve,
                            rmatvec=right_rpsolve)
    else:
        # Each missing side defaults to the identity independently, so a
        # single-sided preconditioner no longer fails on ``None``.
        M1 = as_preconditioner(M1)
        M2 = as_preconditioner(M2)

    n = len(b)
    if maxiter is None:
        maxiter = n*10

    # vdot conjugates its first argument, which the complex case needs.
    dotprod = np.vdot if np.iscomplexobj(x) else np.dot

    # One shared threshold for every breakdown test.
    rhotol = np.finfo(x.dtype.char).eps
    betatol = rhotol
    gammatol = rhotol
    deltatol = rhotol
    epsilontol = rhotol
    xitol = rhotol

    # Skip the matvec when the starting guess is identically zero.
    r = b - A.matvec(x) if x.any() else b.copy()

    vtilde = r.copy()
    y = M1.matvec(vtilde)
    rho = np.linalg.norm(y)
    wtilde = r.copy()
    z = M2.rmatvec(wtilde)
    xi = np.linalg.norm(z)
    gamma, eta, theta = 1, -1, 0
    v = np.empty_like(vtilde)
    w = np.empty_like(wtilde)

    # Dummy values to initialize vars, silence linter warnings
    epsilon, q, d, p, s = None, None, None, None, None

    for iteration in range(maxiter):
        if np.linalg.norm(r) < atol:  # Are we done?
            return postprocess(x), 0
        if np.abs(rho) < rhotol:  # rho breakdown
            return postprocess(x), -10
        if np.abs(xi) < xitol:  # xi breakdown
            return postprocess(x), -15

        v[:] = vtilde[:]
        v *= (1 / rho)
        y *= (1 / rho)
        w[:] = wtilde[:]
        w *= (1 / xi)
        z *= (1 / xi)
        delta = dotprod(z, y)

        if np.abs(delta) < deltatol:  # delta breakdown
            return postprocess(x), -13

        ytilde = M2.matvec(y)
        ztilde = M1.rmatvec(z)

        if iteration > 0:
            ytilde -= (xi * delta / epsilon) * p
            p[:] = ytilde[:]
            ztilde -= (rho * (delta / epsilon).conj()) * q
            q[:] = ztilde[:]
        else:  # First spin
            p = ytilde.copy()
            q = ztilde.copy()

        ptilde = A.matvec(p)
        epsilon = dotprod(q, ptilde)
        if np.abs(epsilon) < epsilontol:  # epsilon breakdown
            return postprocess(x), -14

        beta = epsilon / delta
        if np.abs(beta) < betatol:  # beta breakdown
            return postprocess(x), -11

        vtilde[:] = ptilde[:]
        vtilde -= beta*v
        y = M1.matvec(vtilde)

        rho_prev = rho
        rho = np.linalg.norm(y)
        wtilde[:] = w[:]
        wtilde *= - beta.conj()
        wtilde += A.rmatvec(q)
        z = M2.rmatvec(wtilde)
        xi = np.linalg.norm(z)
        gamma_prev = gamma
        theta_prev = theta
        theta = rho / (gamma_prev * np.abs(beta))
        gamma = 1 / np.sqrt(1 + theta**2)

        if np.abs(gamma) < gammatol:  # gamma breakdown
            return postprocess(x), -12

        eta *= -(rho_prev / beta) * (gamma / gamma_prev)**2

        if iteration > 0:
            d *= (theta_prev * gamma) ** 2
            d += eta*p
            s *= (theta_prev * gamma) ** 2
            s += eta*ptilde
        else:
            d = p.copy()
            d *= eta
            s = ptilde.copy()
            s *= eta

        x += d
        r -= s

        if callback:
            callback(x)

    # Iteration budget exhausted: return incomplete progress.
    return postprocess(x), maxiter