Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_grad.py: 24%

120 statements (coverage.py v7.4.0, created at 2024-01-03 07:57 +0000)

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Gradients for operators defined in control_flow_ops.py."""

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import,undefined-variable,redefined-builtin
from tensorflow.python.ops.control_flow_ops import *
# pylint: enable=wildcard-import


def _SwitchGrad(op, *grad):
  """Gradients for a Switch op are calculated using a Merge op.

  If the switch is a loop switch, it will be visited twice. We create
  the merge on the first visit, and update the other input of the merge
  on the second visit. A next_iteration is also added on the second visit.
  """
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  op_ctxt = op._get_control_flow_context()
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if isinstance(op_ctxt, WhileContext):
    merge_grad = grad_ctxt.grad_state.switch_map.get(op)
    if merge_grad is not None:
      # This is the second time this Switch is visited. It comes from
      # the non-exit branch of the Switch, so update the second input
      # to the Merge.
      # TODO(yuanbyu): Perform shape inference with this new input.
      if grad[1] is not None:
        # pylint: disable=protected-access
        control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1],
                                             enforce_shape_invariant=False)
        # pylint: enable=protected-access
      return None, None
    elif grad[0] is not None:
      # This is the first time this Switch is visited. It comes from
      # the Exit branch, which is grad[0]. grad[1] is empty at this point.
      # Use grad[0] for both inputs to merge for now, but update the second
      # input of merge when we see this Switch the second time.
      merge_grad = merge([grad[0], grad[0]], name="b_switch")[0]
      grad_ctxt.grad_state.switch_map[op] = merge_grad
      return merge_grad, None
    else:
      # This is the first time this Switch is visited. It comes from the
      # Identity branch. Such a Switch has `None` gradient for the Exit branch,
      # meaning the output is not differentiable.
      return None, None
  elif isinstance(op_ctxt, CondContext):
    zero_grad = grad[1 - op_ctxt.branch]
    # At this point, we have created zero_grad guarded by the right switch.
    # Unfortunately, we may still get None here for non-trainable data types.
    if zero_grad is None:
      # For resource variables we always get None on the other branch, so
      # bypass this.
      if op.inputs[0].dtype == dtypes.resource:
        return merge(
            [grad[op_ctxt.branch]] * 2, name="cond_resource_grad")[0], None
      return None, None
    return merge(grad, name="cond_grad")[0], None
  else:
    false_grad = switch(grad[0], op.inputs[1])[0]
    true_grad = switch(grad[1], op.inputs[1])[1]
    return merge([false_grad, true_grad])[0], None


ops.RegisterGradient("Switch")(_SwitchGrad)
ops.RegisterGradient("RefSwitch")(_SwitchGrad)
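
# Example (illustrative; not part of this module): with v1 (graph-mode)
# control flow, `tf.cond` is lowered to Switch and Merge ops, so
# differentiating through it exercises _SwitchGrad and _MergeGrad. A minimal
# sketch assuming the TF1 compatibility API; the names `x`, `y`, `dy_dx` are
# illustrative only.
#
#   import tensorflow.compat.v1 as tf
#   tf.disable_eager_execution()
#   tf.disable_control_flow_v2()   # force Switch/Merge-based cond/while
#
#   x = tf.placeholder(tf.float32, shape=[])
#   y = tf.cond(x > 0.0, lambda: tf.square(x), lambda: 3.0 * x)
#   dy_dx, = tf.gradients(y, [x])  # backprop adds Switch/Merge gradient ops
#
#   with tf.Session() as sess:
#     print(sess.run(dy_dx, feed_dict={x: 2.0}))  # 4.0 = d(x**2)/dx at x=2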



@ops.RegisterGradient("Merge")
def _MergeGrad(op, grad, _):
  """Gradients for a Merge op are calculated using a Switch op."""
  input_op = op.inputs[0].op
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  op_ctxt = control_flow_util.GetOutputContext(input_op)
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if isinstance(op_ctxt, WhileContext):
    # pylint: disable=protected-access
    return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot)
    # pylint: enable=protected-access
  elif isinstance(op_ctxt, CondContext):
    pred = op_ctxt.pred
    if grad_ctxt and grad_ctxt.grad_state:
      # This Merge node is part of a cond within a loop.
      # The backprop needs to have the value of this predicate for every
      # iteration. So we must have its values accumulated in the forward
      # pass, and use the accumulated values as the predicate for this
      # backprop switch.
      grad_state = grad_ctxt.grad_state
      real_pred = grad_state.history_map.get(pred.name)
      if real_pred is None:
        # Remember the value of pred for every iteration.
        grad_ctxt = grad_state.grad_context
        grad_ctxt.Exit()
        history_pred = grad_state.AddForwardAccumulator(pred)
        grad_ctxt.Enter()

        # Add the stack pop op. If pred.op is in an (outer) CondContext,
        # the stack pop will be guarded with a switch.
        real_pred = grad_state.AddBackpropAccumulatedValue(history_pred, pred)
        grad_state.history_map[pred.name] = real_pred
      pred = real_pred
    # pylint: disable=protected-access
    return control_flow_ops._SwitchRefOrTensor(grad, pred, name="cond_grad")
    # pylint: enable=protected-access
  else:
    num_inputs = len(op.inputs)
    cond = [math_ops.equal(op.outputs[1], i) for i in range(num_inputs)]
    # pylint: disable=protected-access
    return [
        control_flow_ops._SwitchRefOrTensor(grad, cond[i])[1]
        for i in range(num_inputs)
    ]
    # pylint: enable=protected-access
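
# Example (illustrative; not part of this module): a cond nested inside a v1
# while loop takes the branch of _MergeGrad above that accumulates the
# predicate's per-iteration value in the forward pass and pops it during
# backprop. A minimal sketch assuming the TF1 compatibility API; all names
# are illustrative.
#
#   import tensorflow.compat.v1 as tf
#   tf.disable_eager_execution()
#   tf.disable_control_flow_v2()
#
#   x = tf.constant(2.0)
#   _, acc = tf.while_loop(
#       lambda i, a: i < 3,
#       lambda i, a: [i + 1,
#                     tf.cond(a > 4.0, lambda: a + 1.0, lambda: a * 2.0)],
#       [tf.constant(0), x])
#   dacc_dx, = tf.gradients(acc, [x])  # the predicate a > 4.0 is replayed
#                                      # from the accumulated history per step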



@ops.RegisterGradient("RefMerge")
def _RefMergeGrad(op, grad, _):
  return _MergeGrad(op, grad, _)


@ops.RegisterGradient("Exit")
def _ExitGrad(op, grad):
  """Gradients for an Exit op are calculated using an Enter op."""
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  op_ctxt = op._get_control_flow_context()
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if not grad_ctxt.back_prop:
    # The flag `back_prop` is set by users to suppress gradient
    # computation for this loop. If the attribute `back_prop` is false,
    # no gradient is computed.
    return None

  if op_ctxt.grad_state:
    raise TypeError("Second-order gradient for while loops not supported.")

  if isinstance(grad, ops.Tensor):
    grad_ctxt.AddName(grad.name)
  else:
    if not isinstance(
        grad, (indexed_slices.IndexedSlices, sparse_tensor.SparseTensor)):
      raise TypeError(f"Type {type(grad)} not supported, must be either "
                      "`indexed_slices.IndexedSlices` or `SparseTensor`.")
    grad_ctxt.AddName(grad.values.name)
    grad_ctxt.AddName(grad.indices.name)
    dense_shape = grad.dense_shape
    if dense_shape is not None:
      grad_ctxt.AddName(dense_shape.name)
  grad_ctxt.Enter()
  # pylint: disable=protected-access
  result = control_flow_ops._Enter(
      grad, grad_ctxt.name, is_constant=False,
      parallel_iterations=grad_ctxt.parallel_iterations,
      name="b_exit")
  # pylint: enable=protected-access
  grad_ctxt.loop_enters.append(result)
  grad_ctxt.Exit()
  return result


ops.RegisterGradient("RefExit")(_ExitGrad)
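
# Example (illustrative; not part of this module): differentiating a v1 while
# loop. The loop result flows through an Exit op, so _ExitGrad above creates
# the corresponding Enter of the backprop loop. A minimal sketch assuming the
# TF1 compatibility API; all names are illustrative.
#
#   import tensorflow.compat.v1 as tf
#   tf.disable_eager_execution()
#   tf.disable_control_flow_v2()
#
#   v0 = tf.constant(2.0)
#   _, v = tf.while_loop(lambda i, a: i < 3,
#                        lambda i, a: [i + 1, a * a],
#                        [tf.constant(0), v0])
#   dv_dv0, = tf.gradients(v, [v0])  # v == v0**8, so dv/dv0 == 8 * v0**7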



@ops.RegisterGradient("NextIteration")
def _NextIterationGrad(_, grad):
  """A forward next_iteration is translated into a backprop identity.

  Note that the backprop next_iteration is added in switch grad.
  """
  return grad


@ops.RegisterGradient("RefNextIteration")
def _RefNextIterationGrad(_, grad):
  return _NextIterationGrad(_, grad)


@ops.RegisterGradient("Enter")
def _EnterGrad(op, grad):
  """Gradients for an Enter are calculated using an Exit op.

  For loop variables, grad is the gradient, so just add an exit.
  For loop invariants, we need to add an accumulator loop.
  """
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if grad_ctxt is None:
    return grad
  if not grad_ctxt.back_prop:
    # Skip gradient computation if the attribute `back_prop` is false.
    return grad
  if grad_ctxt.grad_state is None:
    # Pass the gradient through if we are not in a gradient while context.
    return grad
  if op.get_attr("is_constant"):
    # Add a gradient accumulator for each loop invariant.
    if isinstance(grad, ops.Tensor):
      result = grad_ctxt.AddBackpropAccumulator(op, grad)
    elif isinstance(grad, indexed_slices.IndexedSlices):
      result = grad_ctxt.AddBackpropIndexedSlicesAccumulator(op, grad)
    else:
      # TODO(yuanbyu, lukasr): Add support for SparseTensor.
      raise TypeError(f"Type {type(grad)} not supported, "
                      "must be Tensor or IndexedSlices.")
  else:
    result = exit(grad)
    grad_ctxt.loop_exits.append(result)
    grad_ctxt.ExitResult([result])
  return result
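
# Example (illustrative; not part of this module): a tensor captured from
# outside a v1 while loop enters it through a constant Enter, so its gradient
# goes through the accumulator branch of _EnterGrad above, summing the
# per-iteration gradients. A minimal sketch assuming the TF1 compatibility
# API; all names are illustrative.
#
#   import tensorflow.compat.v1 as tf
#   tf.disable_eager_execution()
#   tf.disable_control_flow_v2()
#
#   w = tf.constant(2.0)            # loop invariant, captured by the body
#   _, s = tf.while_loop(lambda i, s: i < 4,
#                        lambda i, s: [i + 1, s + w],
#                        [tf.constant(0), tf.constant(0.0)])
#   ds_dw, = tf.gradients(s, [w])   # four per-iteration grads of 1.0 are
#                                   # accumulated, so ds/dw evaluates to 4.0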



@ops.RegisterGradient("RefEnter")
def _RefEnterGrad(op, grad):
  return _EnterGrad(op, grad)


@ops.RegisterGradient("LoopCond")
def _LoopCondGrad(_):
  """Stop backprop for the predicate of a while loop."""
  return None