Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/sparse/conjugate_gradient.py: 27%
51 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Preconditioned Conjugate Gradient."""
17import collections
19from tensorflow.python.framework import constant_op
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import ops
22from tensorflow.python.ops import array_ops
23from tensorflow.python.ops import math_ops
24from tensorflow.python.ops import while_loop
25from tensorflow.python.ops.linalg import linalg_impl as linalg
26from tensorflow.python.util import dispatch
27from tensorflow.python.util.tf_export import tf_export
30@tf_export('linalg.experimental.conjugate_gradient')
31@dispatch.add_dispatch_support
32def conjugate_gradient(operator,
33 rhs,
34 preconditioner=None,
35 x=None,
36 tol=1e-5,
37 max_iter=20,
38 name='conjugate_gradient'):
39 r"""Conjugate gradient solver.
41 Solves a linear system of equations `A*x = rhs` for self-adjoint, positive
42 definite matrix `A` and right-hand side vector `rhs`, using an iterative,
43 matrix-free algorithm where the action of the matrix A is represented by
44 `operator`. The iteration terminates when either the number of iterations
45 exceeds `max_iter` or when the residual norm has been reduced to `tol`
46 times its initial value, i.e. \\(||rhs - A x_k|| <= tol ||rhs||\\).
48 Args:
49 operator: A `LinearOperator` that is self-adjoint and positive definite.
50 rhs: A possibly batched vector of shape `[..., N]` containing the right-hand
51 size vector.
52 preconditioner: A `LinearOperator` that approximates the inverse of `A`.
53 An efficient preconditioner could dramatically improve the rate of
54 convergence. If `preconditioner` represents matrix `M`(`M` approximates
55 `A^{-1}`), the algorithm uses `preconditioner.apply(x)` to estimate
56 `A^{-1}x`. For this to be useful, the cost of applying `M` should be
57 much lower than computing `A^{-1}` directly.
58 x: A possibly batched vector of shape `[..., N]` containing the initial
59 guess for the solution.
60 tol: A float scalar convergence tolerance.
61 max_iter: An integer giving the maximum number of iterations.
62 name: A name scope for the operation.
64 Returns:
65 output: A namedtuple representing the final state with fields:
66 - i: A scalar `int32` `Tensor`. Number of iterations executed.
67 - x: A rank-1 `Tensor` of shape `[..., N]` containing the computed
68 solution.
69 - r: A rank-1 `Tensor` of shape `[.., M]` containing the residual vector.
70 - p: A rank-1 `Tensor` of shape `[..., N]`. `A`-conjugate basis vector.
71 - gamma: \\(r \dot M \dot r\\), equivalent to \\(||r||_2^2\\) when
72 `preconditioner=None`.
73 """
74 if not (operator.is_self_adjoint and operator.is_positive_definite):
75 raise ValueError('Expected a self-adjoint, positive definite operator.')
77 cg_state = collections.namedtuple('CGState', ['i', 'x', 'r', 'p', 'gamma'])
79 def stopping_criterion(i, state):
80 return math_ops.logical_and(
81 i < max_iter,
82 math_ops.reduce_any(linalg.norm(state.r, axis=-1) > tol))
84 def dot(x, y):
85 return array_ops.squeeze(
86 math_ops.matvec(
87 x[..., array_ops.newaxis],
88 y, adjoint_a=True), axis=-1)
90 def cg_step(i, state): # pylint: disable=missing-docstring
91 z = math_ops.matvec(operator, state.p)
92 alpha = state.gamma / dot(state.p, z)
93 x = state.x + alpha[..., array_ops.newaxis] * state.p
94 r = state.r - alpha[..., array_ops.newaxis] * z
95 if preconditioner is None:
96 q = r
97 else:
98 q = preconditioner.matvec(r)
99 gamma = dot(r, q)
100 beta = gamma / state.gamma
101 p = q + beta[..., array_ops.newaxis] * state.p
102 return i + 1, cg_state(i + 1, x, r, p, gamma)
104 # We now broadcast initial shapes so that we have fixed shapes per iteration.
106 with ops.name_scope(name):
107 broadcast_shape = array_ops.broadcast_dynamic_shape(
108 array_ops.shape(rhs)[:-1],
109 operator.batch_shape_tensor())
110 if preconditioner is not None:
111 broadcast_shape = array_ops.broadcast_dynamic_shape(
112 broadcast_shape,
113 preconditioner.batch_shape_tensor()
114 )
115 broadcast_rhs_shape = array_ops.concat([
116 broadcast_shape, [array_ops.shape(rhs)[-1]]], axis=-1)
117 r0 = array_ops.broadcast_to(rhs, broadcast_rhs_shape)
118 tol *= linalg.norm(r0, axis=-1)
120 if x is None:
121 x = array_ops.zeros(
122 broadcast_rhs_shape, dtype=rhs.dtype.base_dtype)
123 else:
124 r0 = rhs - math_ops.matvec(operator, x)
125 if preconditioner is None:
126 p0 = r0
127 else:
128 p0 = math_ops.matvec(preconditioner, r0)
129 gamma0 = dot(r0, p0)
130 i = constant_op.constant(0, dtype=dtypes.int32)
131 state = cg_state(i=i, x=x, r=r0, p=p0, gamma=gamma0)
132 _, state = while_loop.while_loop(stopping_criterion, cg_step, [i, state])
133 return cg_state(
134 state.i,
135 x=state.x,
136 r=state.r,
137 p=state.p,
138 gamma=state.gamma)