# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Preconditioned Conjugate Gradient."""

import collections

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import while_loop
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


@tf_export('linalg.experimental.conjugate_gradient')
@dispatch.add_dispatch_support
def conjugate_gradient(operator,
                       rhs,
                       preconditioner=None,
                       x=None,
                       tol=1e-5,
                       max_iter=20,
                       name='conjugate_gradient'):
39 r"""Conjugate gradient solver. 

40 

41 Solves a linear system of equations `A*x = rhs` for self-adjoint, positive 

42 definite matrix `A` and right-hand side vector `rhs`, using an iterative, 

43 matrix-free algorithm where the action of the matrix A is represented by 

44 `operator`. The iteration terminates when either the number of iterations 

45 exceeds `max_iter` or when the residual norm has been reduced to `tol` 

46 times its initial value, i.e. \\(||rhs - A x_k|| <= tol ||rhs||\\). 

47 

48 Args: 

49 operator: A `LinearOperator` that is self-adjoint and positive definite. 

50 rhs: A possibly batched vector of shape `[..., N]` containing the right-hand 

51 size vector. 

52 preconditioner: A `LinearOperator` that approximates the inverse of `A`. 

53 An efficient preconditioner could dramatically improve the rate of 

54 convergence. If `preconditioner` represents matrix `M`(`M` approximates 

55 `A^{-1}`), the algorithm uses `preconditioner.apply(x)` to estimate 

56 `A^{-1}x`. For this to be useful, the cost of applying `M` should be 

57 much lower than computing `A^{-1}` directly. 

58 x: A possibly batched vector of shape `[..., N]` containing the initial 

59 guess for the solution. 

60 tol: A float scalar convergence tolerance. 

61 max_iter: An integer giving the maximum number of iterations. 

62 name: A name scope for the operation. 

63 

64 Returns: 

65 output: A namedtuple representing the final state with fields: 

66 - i: A scalar `int32` `Tensor`. Number of iterations executed. 

67 - x: A rank-1 `Tensor` of shape `[..., N]` containing the computed 

68 solution. 

69 - r: A rank-1 `Tensor` of shape `[.., M]` containing the residual vector. 

70 - p: A rank-1 `Tensor` of shape `[..., N]`. `A`-conjugate basis vector. 

71 - gamma: \\(r \dot M \dot r\\), equivalent to \\(||r||_2^2\\) when 

72 `preconditioner=None`. 

73 """ 

  if not (operator.is_self_adjoint and operator.is_positive_definite):
    raise ValueError('Expected a self-adjoint, positive definite operator.')

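  # Loop state: iteration count `i`, current iterate `x`, residual `r`,
  # search direction `p`, and `gamma`, the (preconditioned) squared residual
  # norm `r^T M r` (just `||r||^2` when no preconditioner is given).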
  cg_state = collections.namedtuple('CGState', ['i', 'x', 'r', 'p', 'gamma'])

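  # Continue while the iteration cap has not been reached and some batch
  # member's residual norm still exceeds its tolerance (`tol` is scaled by
  # the initial residual norm below).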
  def stopping_criterion(i, state):
    return math_ops.logical_and(
        i < max_iter,
        math_ops.reduce_any(linalg.norm(state.r, axis=-1) > tol))

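  # Batched inner product along the last axis: `x` becomes a column matrix so
  # the adjoint matvec contracts with `y` and broadcasts over batch shapes.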
  def dot(x, y):
    return array_ops.squeeze(
        math_ops.matvec(
            x[..., array_ops.newaxis],
            y, adjoint_a=True), axis=-1)

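  # One CG iteration: `alpha` is the exact minimizing step along `p`, the
  # residual is updated incrementally, optionally preconditioned into `q`,
  # and `beta` makes the next direction `A`-conjugate to the previous one.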
  def cg_step(i, state):  # pylint: disable=missing-docstring
    z = math_ops.matvec(operator, state.p)
    alpha = state.gamma / dot(state.p, z)
    x = state.x + alpha[..., array_ops.newaxis] * state.p
    r = state.r - alpha[..., array_ops.newaxis] * z
    if preconditioner is None:
      q = r
    else:
      q = preconditioner.matvec(r)
    gamma = dot(r, q)
    beta = gamma / state.gamma
    p = q + beta[..., array_ops.newaxis] * state.p
    return i + 1, cg_state(i + 1, x, r, p, gamma)

  # We now broadcast initial shapes so that we have fixed shapes per iteration.

  with ops.name_scope(name):
    broadcast_shape = array_ops.broadcast_dynamic_shape(
        array_ops.shape(rhs)[:-1],
        operator.batch_shape_tensor())
    if preconditioner is not None:
      broadcast_shape = array_ops.broadcast_dynamic_shape(
          broadcast_shape,
          preconditioner.batch_shape_tensor())
    broadcast_rhs_shape = array_ops.concat([
        broadcast_shape, [array_ops.shape(rhs)[-1]]], axis=-1)
    r0 = array_ops.broadcast_to(rhs, broadcast_rhs_shape)
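    # Scale the tolerance by the initial residual norm so the stopping rule
    # checks the relative reduction ||r_k|| <= tol * ||rhs||.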
    tol *= linalg.norm(r0, axis=-1)

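    # With no initial guess, start from x = 0 so the initial residual is just
    # `rhs`; otherwise recompute the true residual r0 = rhs - A x.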
    if x is None:
      x = array_ops.zeros(
          broadcast_rhs_shape, dtype=rhs.dtype.base_dtype)
    else:
      r0 = rhs - math_ops.matvec(operator, x)
    if preconditioner is None:
      p0 = r0
    else:
      p0 = math_ops.matvec(preconditioner, r0)
    gamma0 = dot(r0, p0)
    i = constant_op.constant(0, dtype=dtypes.int32)
    state = cg_state(i=i, x=x, r=r0, p=p0, gamma=gamma0)
    _, state = while_loop.while_loop(stopping_criterion, cg_step, [i, state])
    return cg_state(
        state.i,
        x=state.x,
        r=state.r,
        p=state.p,
        gamma=state.gamma)
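

# Usage sketch (illustrative; not part of the library source). Assuming the
# public export `tf.linalg.experimental.conjugate_gradient` defined above, a
# small SPD system can be solved as follows. Kept as a comment so the module
# has no import-time side effects; the 2x2 matrix is an assumed example value.
#
#   import tensorflow as tf
#
#   a = tf.constant([[4., 1.], [1., 3.]])
#   operator = tf.linalg.LinearOperatorFullMatrix(
#       a, is_self_adjoint=True, is_positive_definite=True)
#   rhs = tf.constant([1., 2.])
#   out = tf.linalg.experimental.conjugate_gradient(operator, rhs, tol=1e-6)
#   # out.x converges to the exact solution [1/11, 7/11] ~ [0.0909, 0.6364],
#   # and out.i reports the number of iterations executed.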