# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in control_flow_ops.py."""

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import,undefined-variable,redefined-builtin
from tensorflow.python.ops.control_flow_ops import *
# pylint: enable=wildcard-import
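
# Illustrative usage (editorial comment, not part of the library): the gradient
# functions registered below are exercised only when gradients are taken
# through the v1 graph-mode control-flow ops (Switch, Merge, Enter, Exit,
# NextIteration), i.e. when control flow v2 is disabled. A minimal sketch,
# assuming the tf.compat.v1 APIs are available:
#
#   import tensorflow as tf
#   tf.compat.v1.disable_eager_execution()
#   tf.compat.v1.disable_control_flow_v2()
#   x = tf.compat.v1.placeholder(tf.float32, shape=[])
#   y = tf.cond(x > 0.0, lambda: 2.0 * x, lambda: 3.0 * x)  # emits Switch/Merge
#   dy_dx = tf.gradients(y, x)  # invokes _SwitchGrad and _MergeGrad below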


def _SwitchGrad(op, *grad):
  """Gradients for a Switch op are calculated using a Merge op.

  If the switch is a loop switch, it will be visited twice. We create
  the merge on the first visit, and update the other input of the merge
  on the second visit. A next_iteration is also added on the second visit.
  """
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  op_ctxt = op._get_control_flow_context()
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if isinstance(op_ctxt, WhileContext):
    merge_grad = grad_ctxt.grad_state.switch_map.get(op)
    if merge_grad is not None:
      # This is the second time this Switch is visited. It comes from
      # the non-exit branch of the Switch, so update the second input
      # to the Merge.
      # TODO(yuanbyu): Perform shape inference with this new input.
      if grad[1] is not None:
        # pylint: disable=protected-access
        control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1],
                                             enforce_shape_invariant=False)
        # pylint: enable=protected-access
      return None, None
    elif grad[0] is not None:
      # This is the first time this Switch is visited. It comes from
      # the Exit branch, which is grad[0]. grad[1] is empty at this point.
      # Use grad[0] for both inputs to merge for now, but update the second
      # input of merge when we see this Switch the second time.
      merge_grad = merge([grad[0], grad[0]], name="b_switch")[0]
      grad_ctxt.grad_state.switch_map[op] = merge_grad
      return merge_grad, None
    else:
      # This is the first time this Switch is visited. It comes from the
      # Identity branch. Such a Switch has `None` gradient for the Exit branch,
      # meaning the output is not differentiable.
      return None, None
  elif isinstance(op_ctxt, CondContext):
    zero_grad = grad[1 - op_ctxt.branch]
    # At this point, we have created zero_grad guarded by the right switch.
    # Unfortunately, we may still get None here for non-trainable data types.
    if zero_grad is None:
      # For resource variables we always get None on the other branch, so
      # bypass this case.
      if op.inputs[0].dtype == dtypes.resource:
        return merge(
            [grad[op_ctxt.branch]] * 2, name="cond_resource_grad")[0], None
      return None, None
    return merge(grad, name="cond_grad")[0], None
  else:
    false_grad = switch(grad[0], op.inputs[1])[0]
    true_grad = switch(grad[1], op.inputs[1])[1]
    return merge([false_grad, true_grad])[0], None


ops.RegisterGradient("Switch")(_SwitchGrad)
ops.RegisterGradient("RefSwitch")(_SwitchGrad)
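
# Backprop duality for Switch (illustrative comment, not library code): a
# forward Switch that routes `data` on `pred` is mirrored in the backward pass
# by a Merge of the two branch gradients. A minimal sketch, assuming a graph
# built with the v1 control-flow primitives imported above:
#
#   out_false, out_true = switch(data, pred)
#   # ... per-branch computation produces grad_false, grad_true ...
#   grad_data = merge([grad_false, grad_true])[0]  # what _SwitchGrad constructs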


@ops.RegisterGradient("Merge")
def _MergeGrad(op, grad, _):
  """Gradients for a Merge op are calculated using a Switch op."""
  input_op = op.inputs[0].op
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  op_ctxt = control_flow_util.GetOutputContext(input_op)
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if isinstance(op_ctxt, WhileContext):
    # pylint: disable=protected-access
    return control_flow_ops._SwitchRefOrTensor(grad, grad_ctxt.pivot)
    # pylint: enable=protected-access
  elif isinstance(op_ctxt, CondContext):
    pred = op_ctxt.pred
    if grad_ctxt and grad_ctxt.grad_state:
      # This Merge node is part of a cond within a loop.
      # The backprop needs to have the value of this predicate for every
      # iteration. So we must have its values accumulated in the forward, and
      # use the accumulated values as the predicate for this backprop switch.
      grad_state = grad_ctxt.grad_state
      real_pred = grad_state.history_map.get(pred.name)
      if real_pred is None:
        # Remember the value of pred for every iteration.
        grad_ctxt = grad_state.grad_context
        grad_ctxt.Exit()
        history_pred = grad_state.AddForwardAccumulator(pred)
        grad_ctxt.Enter()

        # Add the stack pop op. If pred.op is in an (outer) CondContext,
        # the stack pop will be guarded with a switch.
        real_pred = grad_state.AddBackpropAccumulatedValue(history_pred, pred)
        grad_state.history_map[pred.name] = real_pred
      pred = real_pred
    # pylint: disable=protected-access
    return control_flow_ops._SwitchRefOrTensor(grad, pred, name="cond_grad")
    # pylint: enable=protected-access
  else:
    num_inputs = len(op.inputs)
    cond = [math_ops.equal(op.outputs[1], i) for i in range(num_inputs)]
    # pylint: disable=protected-access
    return [
        control_flow_ops._SwitchRefOrTensor(grad, cond[i])[1]
        for i in range(num_inputs)
    ]
    # pylint: enable=protected-access


@ops.RegisterGradient("RefMerge")
def _RefMergeGrad(op, grad, _):
  return _MergeGrad(op, grad, _)
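
# Cond-case sketch (illustrative comment, assuming a cond built with the v1
# primitives): the gradient flowing into a Merge produced by the cond is routed
# back only to the branch that actually ran, by switching it on the same
# predicate:
#
#   grad_false, grad_true = control_flow_ops._SwitchRefOrTensor(grad, pred)
#   # grad_true feeds the true branch's backprop and grad_false the false
#   # branch's, which is what _MergeGrad returns above for a CondContext.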


@ops.RegisterGradient("Exit")
def _ExitGrad(op, grad):
  """Gradients for an exit op are calculated using an Enter op."""
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  op_ctxt = op._get_control_flow_context()
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if not grad_ctxt.back_prop:
    # The flag `back_prop` is set by users to suppress gradient
    # computation for this loop. If the attribute `back_prop` is false,
    # no gradient is computed.
    return None

  if op_ctxt.grad_state:
    raise TypeError("Second-order gradient for while loops not supported.")

  if isinstance(grad, ops.Tensor):
    grad_ctxt.AddName(grad.name)
  else:
    if not isinstance(
        grad, (indexed_slices.IndexedSlices, sparse_tensor.SparseTensor)):
      raise TypeError(f"Type {type(grad)} not supported, must be either "
                      "`indexed_slices.IndexedSlices` or `SparseTensor`.")
    grad_ctxt.AddName(grad.values.name)
    grad_ctxt.AddName(grad.indices.name)
    dense_shape = grad.dense_shape
    if dense_shape is not None:
      grad_ctxt.AddName(dense_shape.name)
  grad_ctxt.Enter()
  # pylint: disable=protected-access
  result = control_flow_ops._Enter(
      grad, grad_ctxt.name, is_constant=False,
      parallel_iterations=grad_ctxt.parallel_iterations,
      name="b_exit")
  # pylint: enable=protected-access
  grad_ctxt.loop_enters.append(result)
  grad_ctxt.Exit()
  return result


ops.RegisterGradient("RefExit")(_ExitGrad)
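
# Illustrative comment: each forward loop op has a dual in the backprop loop --
# Exit <-> Enter ("b_exit" above), NextIteration <-> identity, and Enter <->
# Exit or an accumulator for loop invariants. A minimal sketch, assuming graph
# mode with control flow v2 disabled (hypothetical names):
#
#   i0 = tf.constant(1.0)
#   y = tf.while_loop(lambda i: i < 10.0, lambda i: i * 2.0, [i0])
#   dy_di0 = tf.gradients(y, i0)  # the loop output's Exit grad enters the
#                                 # backprop loop via the Enter built above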


@ops.RegisterGradient("NextIteration")
def _NextIterationGrad(_, grad):
  """A forward next_iteration is translated into a backprop identity.

  Note that the backprop next_iteration is added in switch grad.
  """
  return grad


@ops.RegisterGradient("RefNextIteration")
def _RefNextIterationGrad(_, grad):
  return _NextIterationGrad(_, grad)


@ops.RegisterGradient("Enter")
def _EnterGrad(op, grad):
  """Gradients for an Enter are calculated using an Exit op.

  For loop variables, grad is the gradient, so we just add an exit.
  For loop invariants, we need to add an accumulator loop.
  """
  graph = ops.get_default_graph()
  # pylint: disable=protected-access
  grad_ctxt = graph._get_control_flow_context()
  # pylint: enable=protected-access
  if grad_ctxt is None:
    return grad
  if not grad_ctxt.back_prop:
    # Skip gradient computation if the attribute `back_prop` is false.
    return grad
  if grad_ctxt.grad_state is None:
    # Pass the gradient through if we are not in a gradient while context.
    return grad
  if op.get_attr("is_constant"):
    # Add a gradient accumulator for each loop invariant.
    if isinstance(grad, ops.Tensor):
      result = grad_ctxt.AddBackpropAccumulator(op, grad)
    elif isinstance(grad, indexed_slices.IndexedSlices):
      result = grad_ctxt.AddBackpropIndexedSlicesAccumulator(op, grad)
    else:
      # TODO(yuanbyu, lukasr): Add support for SparseTensor.
      raise TypeError(f"Type {type(grad)} not supported, "
                      "must be Tensor or IndexedSlices.")
  else:
    result = exit(grad)
    grad_ctxt.loop_exits.append(result)
    grad_ctxt.ExitResult([result])
  return result


@ops.RegisterGradient("RefEnter")
def _RefEnterGrad(op, grad):
  return _EnterGrad(op, grad)
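
# Loop-invariant sketch (illustrative comment): a tensor captured from outside
# a while loop enters it through a constant Enter op, so its gradient must be
# summed over all iterations, which is what AddBackpropAccumulator provides.
# A minimal sketch, assuming graph mode with control flow v2 disabled
# (hypothetical names):
#
#   w = tf.constant(2.0)                        # loop invariant
#   i0 = tf.constant(1.0)
#   y = tf.while_loop(lambda i: i < 10.0, lambda i: i * w, [i0])
#   dy_dw = tf.gradients(y, w)  # accumulated across every loop iteration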


@ops.RegisterGradient("LoopCond")
def _LoopCondGrad(_):
  """Stop backprop for the predicate of a while loop."""
  return None