Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_util.py: 21%
153 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
16"""Utility functions for control flow.
18This file is necessary to avoid cyclic dependencies between ops.py and
19control_flow_ops.py.
20"""
22import os
23import traceback
25from tensorflow.python import tf2
26from tensorflow.python.platform import tf_logging as logging
# Control flow v2 is on by default under TF2 unless TF_ENABLE_CONTROL_FLOW_V2
# is explicitly "0". Independently of TF2, setting any of the listed
# environment flags to a value other than "0" also turns it on.
ENABLE_CONTROL_FLOW_V2 = (
    (tf2.enabled() and os.getenv("TF_ENABLE_CONTROL_FLOW_V2") != "0") or
    any(os.getenv(flag_name, "0") != "0"
        for flag_name in ("TF_ENABLE_CONTROL_FLOW_V2",
                          "TF_ENABLE_COND_V2",
                          "TF_ENABLE_WHILE_V2",
                          "TF_ENABLE_TENSOR_ARRAY_V2")))
# TODO(b/137793122): Remove this.
def enable_control_flow_v2():  # pylint: disable=invalid-name
  """Use control flow v2.

  Do not use this symbol. This will be removed.

  Sets the module-level ENABLE_CONTROL_FLOW_V2 flag to True, overriding the
  value computed from tf2 and the environment variables at import time.
  """
  global ENABLE_CONTROL_FLOW_V2
  ENABLE_CONTROL_FLOW_V2 = True
def EnableControlFlowV2(graph):
  """Returns whether control flow v2 should be used in `graph`.

  Always truthy when the module-level ENABLE_CONTROL_FLOW_V2 flag is set.
  Otherwise v2 is used only while building a function in a new-style
  FuncGraph; legacy _FuncGraphs are recognized by their `_captured` attribute.
  """
  if ENABLE_CONTROL_FLOW_V2:
    return ENABLE_CONTROL_FLOW_V2
  building = graph.building_function
  if not building:
    return building
  # TODO(skyewm): do something better than hasattr without messing up imports.
  return not hasattr(graph, "_captured")
def IsInXLAContext(op):
  """Returns True if `op` is marked for XLA compilation or is in an XLAContext.

  A truthy `_XlaCompile` attribute is sufficient. If `get_attr` raises
  ValueError, or the attribute is falsy, the op's control flow context chain
  is searched for an XLAContext instead.
  """
  try:
    marked_for_xla = bool(op.get_attr("_XlaCompile"))
  except ValueError:
    marked_for_xla = False
  if marked_for_xla:
    return True
  flow_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  return GetContainingXLAContext(flow_ctxt) is not None
def InXlaContext(graph):
  """Returns True if `graph`'s current control flow context is in an XLAContext."""
  flow_ctxt = graph._get_control_flow_context()  # pylint: disable=protected-access
  xla_ctxt = GetContainingXLAContext(flow_ctxt)
  return xla_ctxt is not None
def GraphOrParentsInXlaContext(graph):
  """Returns True if `graph` or any graph it is nested in is in an XLA context.

  Walks outward via `outer_graph`. The walk stops, returning False, at the
  first graph that has no `outer_graph` attribute.
  """
  current = graph
  while not InXlaContext(current):
    try:
      current = current.outer_graph
    except AttributeError:
      return False
  return True
def IsInWhileLoop(op):
  """Returns True if `op`'s control flow context is inside a while loop."""
  flow_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  enclosing_loop = GetContainingWhileContext(flow_ctxt)
  return enclosing_loop is not None
def IsInCond(op):
  """Returns True if `op`'s control flow context is inside a cond."""
  flow_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  enclosing_cond = GetContainingCondContext(flow_ctxt)
  return enclosing_cond is not None
def IsSwitch(op):
  """Returns True if `op` is a Switch or RefSwitch op."""
  return op.type in ("Switch", "RefSwitch")
def IsMerge(op):
  """Returns True if `op` is a Merge or RefMerge op."""
  return op.type in ("Merge", "RefMerge")
def IsLoopEnter(op):
  """Returns True if `op` is an Enter or RefEnter op."""
  return op.type in ("Enter", "RefEnter")
def IsLoopExit(op):
  """Returns True if `op` is an Exit or RefExit op."""
  return op.type in ("Exit", "RefExit")
def IsCondSwitch(op):
  """Returns True if `op` is the Switch for a conditional.

  Switch nodes are not part of the cond control flow context that they
  represent, so the consumers of the Switch's outputs decide. A Switch is a
  cond switch iff every consumer is in a CondContext. For an Enter consumer,
  the context enclosing the consumer's own context is examined instead.

  Args:
    op: Operation to classify.

  Returns:
    False if `op` is not a Switch, has no outputs, or has any consumer whose
    (adjusted) context is not a CondContext. This includes consumers with no
    control flow context at all, as in graphs imported via import_graph_def.
    True otherwise.
  """
  if not IsSwitch(op) or not op.outputs:
    return False
  for output in op.outputs:
    for consumer in output.consumers():
      ctxt = consumer._get_control_flow_context()  # pylint: disable=protected-access
      # Only step outward when a context exists: ops imported via
      # import_graph_def have no control flow contexts, and reading
      # `outer_context` from None would raise AttributeError.
      if ctxt is not None and IsLoopEnter(consumer):
        ctxt = ctxt.outer_context
      if ctxt is None or not ctxt.IsCondContext():
        return False
  return True
def IsCondMerge(op):
  """Returns True if `op` is the Merge for a conditional.

  Merge nodes are not part of the cond control flow context that they
  represent, so the ops producing the Merge's inputs decide. A Merge is a
  cond merge iff the output context of every input's producer is a
  CondContext.
  """
  if not IsMerge(op):
    return False
  if not op.inputs:
    return False
  for inp in op.inputs:
    producer_ctxt = GetOutputContext(inp.op)
    if producer_ctxt is None or not producer_ctxt.IsCondContext():
      return False
  return True
def IsLoopSwitch(op):
  """Returns True if `op` is the Switch for a while loop.

  This means a Switch whose own context is a WhileContext and which
  IsCondSwitch does not classify as a cond switch.
  """
  if not IsSwitch(op):
    return False
  own_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  return (own_ctxt is not None and own_ctxt.IsWhileContext() and
          not IsCondSwitch(op))
def IsLoopMerge(op):
  """Returns True if `op` is the Merge for a while loop.

  This means a Merge whose own context is a WhileContext and which
  IsCondMerge does not classify as a cond merge.
  """
  if not IsMerge(op):
    return False
  own_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  return (own_ctxt is not None and own_ctxt.IsWhileContext() and
          not IsCondMerge(op))
def IsLoopConstantEnter(op):
  """Returns whether `op` is an Enter op carrying a loop-invariant value.

  Non-Enter ops yield False. For Enter ops, the op's `is_constant`
  attribute is returned.
  """
  if not IsLoopEnter(op):
    return False
  return op.get_attr("is_constant")
def GetLoopConstantEnter(value):
  """Returns the Enter op if `value` can be inferred to be a loop invariant.

  Walks backwards from `value.op` through Switch/Identity ops (and their Ref
  variants), always following the first input. Returns the op reached if it
  is a constant Enter, otherwise None.
  """
  passthrough_types = ("Switch", "RefSwitch", "Identity", "RefIdentity")
  source = value.op
  while source.type in passthrough_types:
    source = source.inputs[0].op
  if IsLoopConstantEnter(source):
    return source
  return None
def GetOutputContext(op):
  """Returns the control flow context that `op`'s outputs belong to.

  This is the op's own context, except for Exit ops, whose outputs belong to
  the context enclosing the loop being exited.
  """
  own_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  if own_ctxt is None:
    # Exit nodes usually have a context, but ops imported via import_graph_def
    # have no control flow contexts at all.
    return None
  return own_ctxt.outer_context if IsLoopExit(op) else own_ctxt
def GetContainingWhileContext(ctxt, stop_ctxt=None):
  """Returns the first ancestor WhileContext of `ctxt`.

  Args:
    ctxt: ControlFlowContext, or None.
    stop_ctxt: ControlFlowContext, optional. If provided, the search ends at
      the first context in the chain that compares equal to it.

  Returns:
    `ctxt` itself if it is a WhileContext. Otherwise the nearest ancestor,
    following `outer_context`, that is a WhileContext or equals `stop_ctxt`.
    None if no such context exists.
  """
  node = ctxt
  while node and not (node.IsWhileContext() or node == stop_ctxt):
    node = node.outer_context
  return node if node else None
def GetContainingXLAContext(ctxt):
  """Returns the first ancestor XLAContext of `ctxt`.

  Returns `ctxt` if `ctxt` is an XLAContext, or None if `ctxt` is not inside
  an XLAContext.

  Args:
    ctxt: ControlFlowContext, or None.

  Returns:
    `ctxt` if `ctxt` is an XLAContext, the most nested XLAContext containing
    `ctxt` (following `outer_context`), or None if `ctxt` is not inside any
    XLAContext.
  """
  while ctxt:
    if ctxt.IsXLAContext():
      return ctxt
    ctxt = ctxt.outer_context
  return None
def GetContainingCondContext(ctxt):
  """Returns the first ancestor CondContext of `ctxt`.

  Args:
    ctxt: ControlFlowContext, or None.

  Returns:
    `ctxt` itself if it is a CondContext. Otherwise the nearest ancestor,
    following `outer_context`, that is a CondContext. None if there is no
    such context.
  """
  node = ctxt
  while node and not node.IsCondContext():
    node = node.outer_context
  return node if node else None
def IsContainingContext(ctxt, maybe_containing_ctxt):
  """Returns True if `maybe_containing_ctxt` is or contains `ctxt`.

  Containment is checked by identity while following `outer_context` from
  `ctxt`. Passing None as `maybe_containing_ctxt` matches the end of the
  chain, so the result is then True.
  """
  node = ctxt
  while True:
    if node is maybe_containing_ctxt:
      return True
    if node is None:
      return False
    node = node.outer_context
def OpInContext(op, ctxt):
  """Returns True if `op`'s control flow context is `ctxt` or nested in it."""
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  return IsContainingContext(op_ctxt, ctxt)
def TensorInContext(tensor, ctxt):
  """Returns True if the op producing `tensor` is in `ctxt` or nested in it."""
  producer = tensor.op
  return OpInContext(producer, ctxt)
def CheckInputFromValidContext(op, input_op):
  """Checks that `input_op` may be used as an input to `op`.

  Conceptually, only inputs from op's while context or any ancestor while
  context (including outside of any context) are valid. In practice, there are
  many other edge cases as well, enumerated in the branches below.

  Args:
    op: Operation
    input_op: Operation

  Raises:
    ValueError: if input_op is from an invalid context. Before raising, the
      message plus both ops' while contexts and tracebacks are written to the
      info log.
  """
  op_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  input_ctxt = GetOutputContext(input_op)
  valid = False

  if not input_ctxt:
    # input_op isn't in a control flow context.
    valid = True
  elif op_ctxt is input_ctxt:
    # input_op is in the same context as op.
    valid = True
  else:
    while_ctxt = GetContainingWhileContext(op_ctxt)
    input_while_ctxt = GetContainingWhileContext(input_ctxt)

    if while_ctxt is None:
      if input_while_ctxt is None:
        # Neither op nor input_op is in a while loop, but one or both are in
        # conds. We allow this, although execution will fail if the branch
        # corresponding to input_op's cond context isn't taken.
        valid = True
      # Invalid if op isn't in a while loop and input_op is. Unless...
      if IsLoopEnter(op):
        # WhileContext._BuildLoop clears context for Enter nodes.
        valid = True
      if IsSwitch(op):
        # CondContext.AddValue clears context for Switch nodes.
        valid = True
    elif IsContainingContext(while_ctxt, input_while_ctxt):
      # input_op is in a while loop which contains op's while loop (or not in a
      # while loop at all).
      valid = True
    elif (while_ctxt.grad_state and
          IsContainingContext(while_ctxt.grad_state.forward_context,
                              input_while_ctxt)):
      # op is in a gradient context and input_op is in the associated forward
      # pass context or an ancestor thereof. This case is needed to build
      # while loop gradients.
      # NOTE(skyewm): we theoretically also need this case for custom gradient
      # functions that close over tensors from ancestor contexts, but I haven't
      # verified this.
      valid = True
    elif (while_ctxt.grad_state and
          while_ctxt.grad_state.forward_context is
          input_while_ctxt._outer_context):  # pylint: disable=protected-access
      # op is in a gradient context and input_op is in a child of the associated
      # forward pass context. This case is needed for the gradients of while
      # loops with conds.
      valid = True
    elif input_while_ctxt.grad_state:
      # Both remaining cases read grad_state from input_while_ctxt, the object
      # whose grad_state was just checked, so the chained `.forward_context`
      # accesses never start from an unchecked grad_state.
      input_forward_ctxt = input_while_ctxt.grad_state.forward_context
      if input_forward_ctxt is while_ctxt:
        # input_op is in the gradient context of op's context. This case is
        # needed when the gradient of a while loop gradient is requested (this
        # will eventually fail unless there is a stop_gradient() or similar).
        valid = True
      elif (input_forward_ctxt.grad_state and
            input_forward_ctxt.grad_state.forward_context is while_ctxt):
        # input_op is in the grad grad context of op's context. This case is
        # needed when the gradient of a while loop gradient is requested (this
        # will eventually fail unless there is a stop_gradient() or similar).
        valid = True

  if not valid:
    # Reaching here implies the `else` branch above ran, so while_ctxt and
    # input_while_ctxt are bound.
    if while_ctxt:
      error_msg = (
          f"Cannot use '{input_op.name}' as input to '{op.name}' because they "
          "are in different while loops.")
    else:
      error_msg = (
          f"Cannot use '{input_op.name}' as input to '{op.name}' because "
          f"'{input_op.name}' is in a while loop.")

    # Log the error message plus the relevant stack traces. The stacks may be
    # useful for debugging this error, but we don't want to raise an
    # unreadable exception.
    log_msg = error_msg
    log_msg += "\n\n%s while context: %s" % (op.name, while_ctxt)
    log_msg += "\n%s while context: %s" % (input_op.name, input_while_ctxt)
    log_msg += "\n\nTraceback for %s:\n%s\nTraceback for %s:\n%s\n" % (
        op.name, "".join(traceback.format_list(op.traceback)),
        input_op.name, "".join(traceback.format_list(input_op.traceback)))
    logging.info(log_msg)
    raise ValueError(error_msg + " See info log for more details.")
def GetWhileContext(op):
  """Returns the WhileContext that `op` belongs to.

  If the op has no (i.e. a falsy) control flow context, that value is
  returned unchanged. Otherwise the result of the context's
  `GetWhileContext()` is returned.
  """
  flow_ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
  return flow_ctxt.GetWhileContext() if flow_ctxt else flow_ctxt