Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/cond.py: 22%
133 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Cond function for Control Flow Operations."""
17from tensorflow.python.eager import context
18from tensorflow.python.eager.polymorphic_function import eager_function_run
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import indexed_slices
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import tensor_util
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import control_flow_util as util
25from tensorflow.python.ops import math_ops
26from tensorflow.python.platform import tf_logging as logging
27from tensorflow.python.types import core
28from tensorflow.python.util import deprecation
29from tensorflow.python.util import dispatch
30from tensorflow.python.util import nest
31from tensorflow.python.util.lazy_loader import LazyLoader
32from tensorflow.python.util.tf_export import tf_export
# TODO(b/269483538): below lazy loads
# needed for references while refactors are in progress
# `control_flow_ops` provides the v1 graph-building pieces used by `cond`
# (switch, merge, CondContext). It is bound through LazyLoader rather than a
# plain import; presumably the module is only imported on first attribute
# access — confirm against LazyLoader.
control_flow_ops = LazyLoader(
    "control_flow_ops", globals(),
    "tensorflow.python.ops.control_flow_ops")
# This is to avoid a circular dependency:
# cond_v2 -> gradients_util -> control_flow_ops
cond_v2 = LazyLoader("cond_v2", globals(),
                     "tensorflow.python.ops.cond_v2")
# pylint: disable=redefined-outer-name
# pylint: disable=g-doc-args
@tf_export(v1=["cond"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(
    None, "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.",
    "fn1", "fn2")
def cond(pred,
         true_fn=None,
         false_fn=None,
         strict=False,
         name=None,
         fn1=None,
         fn2=None):
  """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.

  `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
  `false_fn` must have the same non-zero number and type of outputs.

  **WARNING**: Any Tensors or Operations created outside of `true_fn` and
  `false_fn` will be executed regardless of which branch is selected at runtime.

  Although this behavior is consistent with the dataflow model of TensorFlow,
  it has frequently surprised users who expected a lazier semantics.
  Consider the following simple program:

  ```python
  z = tf.multiply(a, b)
  result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
  ```

  If `x < y`, the `tf.add` operation will be executed and `tf.square`
  operation will not be executed. Since `z` is needed for at least one
  branch of the `cond`, the `tf.multiply` operation is always executed,
  unconditionally.

  When building a graph, `cond` calls `true_fn` and `false_fn` *exactly once*
  (inside the call to `cond`, and not at all during `Session.run()`). `cond`
  stitches together the graph fragments created during the `true_fn` and
  `false_fn` calls with some additional graph nodes to ensure that the right
  branch gets executed depending on the value of `pred`. When executing
  eagerly with a `pred` whose value is known, only the selected branch is
  called.

  `tf.cond` supports nested structures as implemented in
  `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
  same (possibly nested) value structure of lists, tuples, and/or named tuples.
  Singleton lists and tuples form the only exceptions to this: when returned by
  `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
  This behavior is disabled by passing `strict=True`.

  Args:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if pred is true.
    false_fn: The callable to be performed if pred is false.
    strict: A boolean that enables/disables 'strict' mode; see above.
    name: Optional name prefix for the returned tensors.
    fn1: Deprecated alias for `true_fn`; may not be combined with it.
    fn2: Deprecated alias for `false_fn`; may not be combined with it.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`. If the
    callables return a singleton list and `strict` is False, the element is
    extracted from the list.

  Raises:
    TypeError: if `true_fn` or `false_fn` is missing or not callable, if both
      `true_fn` and `fn1` (or both `false_fn` and `fn2`) are given, if `pred`
      is a Python bool while building v1 graph control flow, or if the two
      branches return incompatible types.
    ValueError: if either branch returns `None`, if `true_fn` and `false_fn`
      do not return the same number of tensors, or return tensors of
      different types.

  Example:

  ```python
  x = tf.constant(2)
  y = tf.constant(5)
  def f1(): return tf.multiply(x, 17)
  def f2(): return tf.add(y, 23)
  r = tf.cond(tf.less(x, y), f1, f2)
  # r is set to f1().
  # Operations in f2 (e.g., tf.add) are not executed.
  ```

  """
  # We needed to make true_fn/false_fn keyword arguments for
  # backwards-compatibility. This check exists so that we can convert back to
  # having them be positional arguments.
  # TODO(josh11b): Make `true_fn` and `false_fn` positional arguments after
  # `fn1` and `fn2` are deleted.
  if fn1 is not None:
    if true_fn is not None:
      raise TypeError(
          "cond(): 'true_fn' and 'fn1' may not be set simultaneously.")
    true_fn = fn1
  elif true_fn is None:
    raise TypeError("cond(): 'true_fn' argument required")
  if fn2 is not None:
    if false_fn is not None:
      raise TypeError(
          "cond(): 'false_fn' and 'fn2' may not be set simultaneously.")
    false_fn = fn2
  elif false_fn is None:
    raise TypeError("cond(): 'false_fn' argument required")

  if not callable(true_fn):
    raise TypeError("'true_fn' must be callable.")
  if not callable(false_fn):
    raise TypeError("'false_fn' must be callable.")

  if context.executing_eagerly():
    return _eager_cond_implementation(pred, true_fn, false_fn, strict, name)

  # Always enable control flow v2 if building a function, regardless of toggle.
  if util.EnableControlFlowV2(ops.get_default_graph()):
    return cond_v2.cond_v2(pred, true_fn, false_fn, name)

  with ops.name_scope(name, "cond", [pred]):
    # Add the Switch to the graph.
    if isinstance(pred, bool):
      raise TypeError("'pred' must not be a Python bool.")
    # The switch outputs are used as (false, true) pivots: p_2 feeds
    # "switch_f" and p_1 feeds "switch_t".
    p_2, p_1 = control_flow_ops.switch(pred, pred)
    pivot_1 = array_ops.identity(p_1, name="switch_t")
    pivot_2 = array_ops.identity(p_2, name="switch_f")
    pred = array_ops.identity(pred, name="pred_id")
    # Disable the fetching of tensors that are only on one branch of cond.
    for tensor in [p_1, p_2, pivot_1, pivot_2, pred]:
      tensor.op.graph.prevent_fetching(tensor.op)

    # Build the graph for the true branch in a new context.
    context_t = control_flow_ops.CondContext(pred, pivot_1, branch=1)
    try:
      context_t.Enter()
      orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
      if orig_res_t is None:
        raise ValueError("'true_fn' must have a return value.")
      context_t.ExitResult(res_t)
    finally:
      context_t.Exit()

    # Build the graph for the false branch in a new context.
    context_f = control_flow_ops.CondContext(pred, pivot_2, branch=0)
    try:
      context_f.Enter()
      orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
      if orig_res_f is None:
        raise ValueError("'false_fn' must have a return value.")
      context_f.ExitResult(res_f)
    finally:
      context_f.Exit()

    if not strict:
      orig_res_t = _UnpackIfSingleton(orig_res_t)
      orig_res_f = _UnpackIfSingleton(orig_res_f)

    # Check that the return values of the two branches have the same structure.
    try:
      nest.assert_same_structure(orig_res_t, orig_res_f, expand_composites=True)
    except (TypeError, ValueError):
      # The mismatch may only be IndexedSlices whose `indices` dtypes differ;
      # `_cast_indexed_slice_indices` casts such pairs to int64 in place, after
      # which the structures are compared again. The casting calls live inside
      # this `try` because `nest.map_structure` itself raises on a genuine
      # structure mismatch, and that error should get the descriptive message
      # below rather than escaping unwrapped.
      try:
        nest.map_structure(_cast_indexed_slice_indices, orig_res_t, orig_res_f)
        nest.map_structure(_cast_indexed_slice_indices, res_t, res_f)
        nest.assert_same_structure(orig_res_t, orig_res_f,
                                   expand_composites=True)
      except TypeError as e:
        raise TypeError(
            f"Incompatible return types of 'true_fn' and 'false_fn': {e}"
        ) from e
      except ValueError as e:
        raise ValueError(
            f"Incompatible return values of 'true_fn' and 'false_fn': {e}"
        ) from e

    # Add the final merge to the graph.
    if not res_t:
      raise ValueError(
          "'true_fn' and 'false_fn' must return at least one result.")

    res_t_flat = nest.flatten(res_t, expand_composites=True)
    res_f_flat = nest.flatten(res_f, expand_composites=True)

    for (x, y) in zip(res_t_flat, res_f_flat):
      assert isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)
      if x.dtype.base_dtype != y.dtype.base_dtype:
        raise ValueError(
            "Outputs of 'true_fn' and 'false_fn' must have the same type(s). "
            f"Received {x.dtype.name} from 'true_fn' "
            f"and {y.dtype.name} from 'false_fn'.")

    merges = [
        control_flow_ops.merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)]
    merges = nest.map_structure(
        control_flow_ops._convert_flow_to_tensorarray,  # pylint: disable=protected-access
        nest.flatten(orig_res_t, expand_composites=True),
        merges)

    # Only add non-nested conds to the collection. Any nested control flow will
    # be encapsulated in the root context.
    assert context_t.outer_context == context_f.outer_context
    if context_t.outer_context is None:
      ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t)
      ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f)

    merges = nest.pack_sequence_as(
        structure=orig_res_t, flat_sequence=merges, expand_composites=True)

    # Singleton lists and tuples are automatically unpacked if strict == False.
    if not strict:
      merges = _UnpackIfSingleton(merges)
    return merges
@tf_export("cond", v1=[])
@dispatch.add_dispatch_support
def cond_for_tf_v2(pred, true_fn=None, false_fn=None, name=None):
  """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.

  Note: This op is automatically used in a `tf.function` to convert Python
  if-statements when the predicate is a `tf.Tensor`, unless `autograph=False` is
  explicitly specified in `tf.function` args. For example, the following are
  equivalent:

  >>> @tf.function
  ... def fun1(x,y):
  ...   if x > 0:  # AutoGraph converts if-statement to tf.cond().
  ...     z = y+1
  ...   else:
  ...     z = y-1
  ...   return z
  >>> fun1(tf.constant(7), tf.constant(3)).numpy()
  4

  >>> @tf.function
  ... def fun2(x,y):
  ...   pred = x > 0
  ...   true_fn =  lambda: y+1
  ...   false_fn = lambda: y-1
  ...   return tf.cond(pred, true_fn, false_fn)  # Use tf.cond() explicitly.
  >>> fun2(tf.constant(7), tf.constant(3)).numpy()
  4

  For more information, see [tf.function and AutoGraph guide](
  https://www.tensorflow.org/guide/function#autograph_transformations).

  `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
  `false_fn` must have the same non-zero number and type of outputs.

  **WARNING**: Any Tensors or Operations created outside of `true_fn` and
  `false_fn` will be executed regardless of which branch is selected at runtime.

  Although this behavior is consistent with the dataflow model of TensorFlow,
  it has frequently surprised users who expected a lazier semantics.
  Consider the following simple program:

  >>> x, y = tf.constant(2, dtype=tf.int32), tf.constant(4, dtype=tf.int32)
  >>> z = tf.multiply(x, y)
  >>> r = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
  >>> r.numpy()
  10

  If `x < y`, the `tf.add` operation will be executed and `tf.square`
  operation will not be executed. Since `z` is needed for at least one
  branch of the `cond`, the `tf.multiply` operation is always executed,
  unconditionally.

  When building a graph, `cond` calls `true_fn` and `false_fn` *exactly once*
  (inside the call to `cond`, and not at all during `Session.run()`). `cond`
  stitches together the graph fragments created during the `true_fn` and
  `false_fn` calls with some additional graph nodes to ensure that the right
  branch gets executed depending on the value of `pred`. When executing
  eagerly with a `pred` whose value is known, only the selected branch is
  called.

  `tf.cond` supports nested structures as implemented in
  `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
  same (possibly nested) value structure of lists, tuples, and/or named tuples.
  Unlike `tf.compat.v1.cond`, singleton lists and tuples returned by `true_fn`
  and `false_fn` are not implicitly unpacked to single values (this function
  always runs in strict mode).

  Note: It is illegal to "directly" use tensors created inside a cond branch
  outside it, e.g. by storing a reference to a branch tensor in the python
  state. If you need to use a tensor created in a branch function you should
  return it as an output of the branch function and use the output from
  `tf.cond` instead.

  Args:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if pred is true.
    false_fn: The callable to be performed if pred is false.
    name: Optional name prefix for the returned tensors.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`, with the
    same (possibly nested) structure the branch returned.

  Raises:
    TypeError: if `true_fn` or `false_fn` is not callable.
    ValueError: if `true_fn` and `false_fn` do not return the same number of
      tensors, or return tensors of different types.

  Example:

  >>> x = tf.constant(2)
  >>> y = tf.constant(5)
  >>> def f1(): return tf.multiply(x, 7)
  >>> def f2(): return tf.add(y, 3)
  >>> r = tf.cond(tf.less(x, y), f1, f2)
  >>> # r is set to f1().
  >>> # Operations in f2 (e.g., tf.add) are not executed.
  >>> r.numpy()
  14

  """
  # TF2 semantics: always strict, so branch outputs keep their exact structure.
  return cond(pred, true_fn=true_fn, false_fn=false_fn, strict=True, name=name)
352def _UnpackIfSingleton(res):
353 if isinstance(res, (list, tuple)) and len(res) == 1:
354 return res[0]
355 else:
356 return res
def _eager_cond_implementation(pred, true_fn, false_fn, strict, name):
  """Evaluates `cond` while executing eagerly.

  When `pred` has a constant value, only the selected branch function is
  called (inside a "cond" name scope) and its result is returned, with
  singleton lists/tuples unpacked unless `strict` is set. Otherwise both
  branches must be `tf.function`s, and the cond is built via `cond_v2`,
  temporarily turning `run_functions_eagerly` off if it was enabled.

  Args:
    pred: Value convertible to a tensor that selects the branch.
    true_fn: Callable run when `pred` is true.
    false_fn: Callable run when `pred` is false.
    strict: If False, singleton list/tuple results are unpacked (constant
      `pred` path only).
    name: Optional name for the name scope / cond op.

  Returns:
    The selected branch's result, or the output of `cond_v2.cond_v2`.

  Raises:
    TypeError: if `pred` has no constant value and either branch is not a
      `core.GenericFunction`.
  """
  pred = ops.convert_to_tensor(pred)
  known_pred = tensor_util.constant_value(pred)

  if known_pred is not None:
    # Common case: the predicate's value is available, so only the relevant
    # branch function is called and executed eagerly.
    with ops.name_scope(name, "cond", [pred]):
      chosen_fn = true_fn if known_pred else false_fn
      outcome = chosen_fn()
      if not strict:
        outcome = _UnpackIfSingleton(outcome)
      return outcome

  # Eager tensors from a parallel device may not have a constant value. The
  # cond op could run, but building it requires wrapping the branches in
  # functions first.
  branches_are_functions = (isinstance(true_fn, core.GenericFunction) and
                            isinstance(false_fn, core.GenericFunction))
  if not branches_are_functions:
    raise TypeError("When running tf.cond on a parallel device, 'true_fn' "
                    "and 'false_fn' must be decorated with `tf.function`.")

  prior_run_eagerly = eager_function_run.functions_run_eagerly()
  if prior_run_eagerly:
    # tf.function is needed to deal with variable creation inside the cond;
    # honoring run_functions_eagerly here would just crash, so override it.
    logging.warning(
        "It looks like tf.function behavior was disabled, perhaps using "
        "tf.config.run_functions_eagerly. Parallelized tf.cond requires "
        "tf.function to work. This primitive will override the disable.")
    eager_function_run.run_functions_eagerly(False)
  try:
    return cond_v2.cond_v2(pred, true_fn, false_fn, name)
  finally:
    # Restore the caller's setting.
    if prior_run_eagerly is not None:
      eager_function_run.run_functions_eagerly(prior_run_eagerly)
def _cast_indexed_slice_indices(a, b):
  """Unifies mismatched IndexedSlices index dtypes by casting both to int64.

  Does nothing unless `a` and `b` are both `IndexedSlices` whose `indices`
  have different dtypes. In that case both objects are modified in place so
  that their `indices` become `int64`.

  Args:
    a: A value, which may be an IndexedSlices.
    b: A value, which may be an IndexedSlices.
  """
  both_are_slices = (isinstance(a, indexed_slices.IndexedSlices) and
                     isinstance(b, indexed_slices.IndexedSlices))
  if not both_are_slices or not a.indices.dtype != b.indices.dtype:
    return
  for slices in (a, b):
    # pylint: disable=protected-access
    slices._indices = math_ops.cast(slices.indices, dtypes.int64)