Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/while_loop.py: 26%
81 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""While loop for Control Flow Operations."""

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import control_flow_util as util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import variable_utils
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export

# TODO(b/269483538): below lazy loads
# needed for references while refactors are in progress
control_flow_ops = LazyLoader(
    "control_flow_ops", globals(),
    "tensorflow.python.ops.control_flow_ops")
# This is to avoid circular dependencies:
# while_v2 -> control_flow_ops
# while_v2 -> gradients_util -> control_flow_ops
while_v2 = LazyLoader("while_v2", globals(),
                      "tensorflow.python.ops.while_v2")


# @TODO(b/133606651) Replace "shape_invariants" with "loop_vars_signature".
# pylint: disable=redefined-outer-name
@tf_export("while_loop", v1=[])
@deprecation.deprecated_arg_values(
    None,
    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.while_loop(c, b, vars, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.while_loop(c, b, vars))""",
    warn_once=True,
    back_prop=False)
def while_loop_v2(cond,
                  body,
                  loop_vars,
                  shape_invariants=None,
                  parallel_iterations=10,
                  back_prop=True,
                  swap_memory=False,
                  maximum_iterations=None,
                  name=None):
65 """Repeat `body` while the condition `cond` is true.
67 Note: This op is automatically used in a `tf.function` to convert Python for-
68 and while- loops when the loop variable is a `tf.Tensor`, unless
69 `autograph=False` is explicitly specified in `tf.function` args. For example,
70 the following are equivalent:
72 >>> @tf.function
73 ... def sumSquare(n):
74 ... i, result = tf.constant(0), tf.constant(0)
75 ... while i < n: # AutoGraph converts while-loop to tf.while_loop().
76 ... result += i * i
77 ... i += 1
78 ... return result
79 >>> sumSquare(10).numpy()
80 285
82 >>> @tf.function
83 ... def sumSquare2(n):
84 ... i, result = tf.constant(0), tf.constant(0)
85 ... c = lambda i, _: tf.less(i, n)
86 ... b = lambda i, result: (i + 1, result + i * i)
87 ... return tf.while_loop(c, b, [i, result])[1]
88 >>> sumSquare2(10).numpy()
89 285
91 For more information, see [tf.function and AutoGraph guide
92 ](https://www.tensorflow.org/guide/function#autograph_transformations).
94 `cond` is a callable returning a boolean scalar tensor. `body` is a callable
95 returning a (possibly nested) tuple, namedtuple or list of tensors of the same
96 arity (length and structure) and types as `loop_vars`. `loop_vars` is a
97 (possibly nested) tuple, namedtuple or list of tensors that is passed to both
98 `cond` and `body`. `cond` and `body` both take as many arguments as there are
99 `loop_vars`.
101 In addition to regular Tensors or IndexedSlices, the body may accept and
102 return TensorArray objects. The flows of the TensorArray objects will
103 be appropriately forwarded between loops and during gradient calculations.
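
  For example, a minimal sketch that accumulates one value per iteration in a
  `tf.TensorArray`:

  >>> ta = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
  >>> c = lambda i, ta: i < 5
  >>> b = lambda i, ta: (i + 1, ta.write(i, i * i))
  >>> _, ta_final = tf.while_loop(c, b, [tf.constant(0), ta])
  >>> ta_final.stack().numpy()
  array([ 0,  1,  4,  9, 16], dtype=int32)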

  Note that `while_loop` calls `cond` and `body` *exactly once* (inside the
  call to `while_loop`, and not at all during `Session.run()`). `while_loop`
  stitches together the graph fragments created during the `cond` and `body`
  calls with some additional graph nodes to create the graph flow that
  repeats `body` until `cond` returns false.

  For correctness, `tf.while_loop()` strictly enforces shape invariants for
  the loop variables. A shape invariant is a (possibly partial) shape that
  is unchanged across the iterations of the loop. An error will be raised
  if the shape of a loop variable after an iteration is determined to be more
  general than or incompatible with its shape invariant. For example, a shape
  of `[11, None]` is more general than a shape of `[11, 17]`, and `[11, 21]` is
  not compatible with `[11, 17]`. By default (if the argument `shape_invariants`
  is not specified), it is assumed that the initial shape of each tensor in
  `loop_vars` is the same in every iteration. The `shape_invariants` argument
  allows the caller to specify a less specific shape invariant for each loop
  variable, which is needed if the shape varies between iterations. The
  `tf.Tensor.set_shape`
  function may also be used in the `body` function to indicate that
  the output loop variable has a particular shape. The shape invariants for
  SparseTensor and IndexedSlices are treated specially as follows:

  a) If a loop variable is a SparseTensor, the shape invariant must be
  `TensorShape([r])` where `r` is the rank of the dense tensor represented
  by the sparse tensor. It means the shapes of the three tensors of the
  SparseTensor are `([None], [None, r], [r])`. NOTE: The shape invariant here
  is the shape of the SparseTensor.dense_shape property. It must be the shape of
  a vector.

  b) If a loop variable is an IndexedSlices, the shape invariant must be
  a shape invariant of the values tensor of the IndexedSlices. It means
  the shapes of the three tensors of the IndexedSlices are `(shape, [shape[0]],
  [shape.ndims])`.

  `while_loop` implements non-strict semantics, enabling multiple iterations
  to run in parallel. The maximum number of parallel iterations can be
  controlled by `parallel_iterations`, which gives users some control over
  memory consumption and execution order. For correct programs, `while_loop`
  should return the same result for any `parallel_iterations > 0`.

  For training, TensorFlow stores the tensors that are produced in the
  forward inference and are needed in back propagation. These tensors are a
  main source of memory consumption and often cause OOM errors when training
  on GPUs. When the flag swap_memory is true, we swap out these tensors from
  GPU to CPU. This for example allows us to train RNN models with very long
  sequences and large batches.

  Args:
    cond: A callable that represents the termination condition of the loop.
    body: A callable that represents the loop body.
    loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array,
      `Tensor`, and `TensorArray` objects.
    shape_invariants: The shape invariants for the loop variables.
    parallel_iterations: The number of iterations allowed to run in parallel. It
      must be a positive integer.
    back_prop: (optional) Deprecated. False disables support for back
      propagation. Prefer using `tf.stop_gradient` instead.
    swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
    maximum_iterations: Optional maximum number of iterations of the while loop
      to run. If provided, the `cond` output is AND-ed with an additional
      condition ensuring the number of iterations executed is no greater than
      `maximum_iterations`.
    name: Optional name prefix for the returned tensors.

  Returns:
    The output tensors for the loop variables after the loop. The return value
    has the same structure as `loop_vars`.

  Raises:
    TypeError: if `cond` or `body` is not callable.
    ValueError: if `loop_vars` is empty.

  Example:

  >>> i = tf.constant(0)
  >>> c = lambda i: tf.less(i, 10)
  >>> b = lambda i: (tf.add(i, 1), )
  >>> r = tf.while_loop(c, b, [i])[0]
  >>> r.numpy()
  10
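
  The same loop capped with `maximum_iterations` (a minimal sketch; the extra
  bound is AND-ed with `c`, so the loop stops after five iterations):

  >>> r = tf.while_loop(c, b, [i], maximum_iterations=5)[0]
  >>> r.numpy()
  5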

  Example with nesting and a namedtuple:

  >>> import collections
  >>> Pair = collections.namedtuple('Pair', 'j, k')
  >>> ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
  >>> c = lambda i, p: i < 10
  >>> b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
  >>> ijk_final = tf.while_loop(c, b, ijk_0)[1]
  >>> ijk_final[0].numpy(), ijk_final[1].numpy()
  (32, 64)

  Example using shape_invariants:

  >>> i0 = tf.constant(0)
  >>> m0 = tf.ones([2, 2])
  >>> c = lambda i, m: i < 10
  >>> b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
  >>> tf.while_loop(
  ...     c, b, loop_vars=[i0, m0],
  ...     shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])[1]
  <tf.Tensor: shape=(2048, 2), dtype=float32, numpy=...>

  Example which demonstrates non-strict semantics: In the following
  example, the final value of `counter` does not depend on `x`. So
  the `while_loop` can increment the counter parallel to updates of `x`.
  However, because the loop counter at one loop iteration depends
  on the value at the previous iteration, the loop counter itself cannot
  be incremented in parallel. Hence if we just want the final value of the
  counter (which we would print with `print(sess.run(counter_out))`), then
  `x` will never be incremented, but the counter will be updated on a
  single thread. Conversely, if we want the value of the output (which we
  print on the line `print(sess.run(x_out).shape)`), then the counter may be
  incremented on its own thread, while `x` can be incremented in
  parallel on a separate thread. In the extreme case, it is conceivable
  that the thread incrementing the counter runs until completion before
  `x` is incremented even a single time. The only thing that can never
  happen is the thread updating `x` getting ahead of the counter thread,
  because the thread incrementing `x` depends on the value of the counter.

  >>> with tf.compat.v1.Session() as sess:
  ...   n = 10
  ...   c = lambda i, x: i < n
  ...   b = lambda i, x: (
  ...       tf.compat.v1.Print(i + 1, [i], "Updating i based on i == "),
  ...       # Let x depend on i
  ...       tf.compat.v1.Print(x + i, [i], "Updating x based on i == "))
  ...
  ...   # Make x to be a big matrix so its updating thread would run slowly
  ...   x = tf.zeros([1000, 100], dtype=tf.int32)
  ...   counter = tf.constant(0)
  ...   counter_out, x_out = tf.while_loop(c, b, (counter, x))
  ...
  ...   # The following line may increment the counter and x in parallel.
  ...   # The counter thread may get ahead of the x thread, but not the
  ...   # other way around. For example, the log may contain these messages:
  ...   # ```
  ...   # Updating i based on i == [9]
  ...   # Updating x based on i == [3]
  ...   # ```
  ...   # meaning that the counter(i) thread is on iteration 9,
  ...   # while the x thread is on iteration 3.
  ...   print(sess.run(x_out).shape)
  (1000, 100)

  """
  return while_loop(
      cond=cond,
      body=body,
      loop_vars=loop_vars,
      shape_invariants=shape_invariants,
      parallel_iterations=parallel_iterations,
      back_prop=back_prop,
      swap_memory=swap_memory,
      name=name,
      maximum_iterations=maximum_iterations,
      return_same_structure=True)


# pylint: disable=redefined-outer-name
@tf_export(v1=["while_loop"])
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               back_prop=True,
               swap_memory=False,
               name=None,
               maximum_iterations=None,
               return_same_structure=False):
277 """Repeat `body` while the condition `cond` is true.
279 `cond` is a callable returning a boolean scalar tensor. `body` is a callable
280 returning a (possibly nested) tuple, namedtuple or list of tensors of the same
281 arity (length and structure) and types as `loop_vars`. `loop_vars` is a
282 (possibly nested) tuple, namedtuple or list of tensors that is passed to both
283 `cond` and `body`. `cond` and `body` both take as many arguments as there are
284 `loop_vars`.
286 In addition to regular Tensors or IndexedSlices, the body may accept and
287 return TensorArray objects. The flows of the TensorArray objects will
288 be appropriately forwarded between loops and during gradient calculations.
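
  For example, a minimal sketch of collecting one value per iteration in a
  `TensorArray`, whose flow tensor is forwarded through the loop:

  ```python
  ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
  c = lambda i, ta: i < 5
  b = lambda i, ta: (i + 1, ta.write(i, tf.cast(i, tf.float32)))
  _, ta_final = tf.while_loop(c, b, [tf.constant(0), ta])
  result = ta_final.stack()  # a float32 Tensor of shape [5]
  ```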

  Note that `while_loop` calls `cond` and `body` *exactly once* (inside the
  call to `while_loop`, and not at all during `Session.run()`). `while_loop`
  stitches together the graph fragments created during the `cond` and `body`
  calls with some additional graph nodes to create the graph flow that
  repeats `body` until `cond` returns false.

  For correctness, `tf.while_loop()` strictly enforces shape invariants for
  the loop variables. A shape invariant is a (possibly partial) shape that
  is unchanged across the iterations of the loop. An error will be raised
  if the shape of a loop variable after an iteration is determined to be more
  general than or incompatible with its shape invariant. For example, a shape
  of [11, None] is more general than a shape of [11, 17], and [11, 21] is not
  compatible with [11, 17]. By default (if the argument `shape_invariants` is
  not specified), it is assumed that the initial shape of each tensor in
  `loop_vars` is the same in every iteration. The `shape_invariants` argument
  allows the caller to specify a less specific shape invariant for each loop
  variable, which is needed if the shape varies between iterations. The
  `tf.Tensor.set_shape`
  function may also be used in the `body` function to indicate that
  the output loop variable has a particular shape. The shape invariants for
  SparseTensor and IndexedSlices are treated specially as follows:

  a) If a loop variable is a SparseTensor, the shape invariant must be
  TensorShape([r]) where r is the rank of the dense tensor represented
  by the sparse tensor. It means the shapes of the three tensors of the
  SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here
  is the shape of the SparseTensor.dense_shape property. It must be the shape of
  a vector.

  b) If a loop variable is an IndexedSlices, the shape invariant must be
  a shape invariant of the values tensor of the IndexedSlices. It means
  the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]],
  [shape.ndims]).

  `while_loop` implements non-strict semantics, enabling multiple iterations
  to run in parallel. The maximum number of parallel iterations can be
  controlled by `parallel_iterations`, which gives users some control over
  memory consumption and execution order. For correct programs, `while_loop`
  should return the same result for any parallel_iterations > 0.

  For training, TensorFlow stores the tensors that are produced in the
  forward inference and are needed in back propagation. These tensors are a
  main source of memory consumption and often cause OOM errors when training
  on GPUs. When the flag swap_memory is true, we swap out these tensors from
  GPU to CPU. This for example allows us to train RNN models with very long
  sequences and large batches.

  Args:
    cond: A callable that represents the termination condition of the loop.
    body: A callable that represents the loop body.
    loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array,
      `Tensor`, and `TensorArray` objects.
    shape_invariants: The shape invariants for the loop variables.
    parallel_iterations: The number of iterations allowed to run in parallel. It
      must be a positive integer.
    back_prop: Whether backprop is enabled for this while loop.
    swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
    name: Optional name prefix for the returned tensors.
    maximum_iterations: Optional maximum number of iterations of the while loop
      to run. If provided, the `cond` output is AND-ed with an additional
      condition ensuring the number of iterations executed is no greater than
      `maximum_iterations`.
    return_same_structure: If True, output has same structure as `loop_vars`. If
      eager execution is enabled, this is ignored (and always treated as True).

  Returns:
    The output tensors for the loop variables after the loop.
    If `return_same_structure` is True, the return value has the same
    structure as `loop_vars`.
    If `return_same_structure` is False, the return value is a Tensor,
    TensorArray or IndexedSlices if the length of `loop_vars` is 1, or a list
    otherwise.

  Raises:
    TypeError: if `cond` or `body` is not callable.
    ValueError: if `loop_vars` is empty.

  Example:

  ```python
  i = tf.constant(0)
  c = lambda i: tf.less(i, 10)
  b = lambda i: tf.add(i, 1)
  r = tf.while_loop(c, b, [i])
  ```
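
  With a single loop variable and the default `return_same_structure=False`,
  `r` above is a single `Tensor` rather than a list. A minimal sketch of the
  same call preserving the input structure:

  ```python
  r_list = tf.while_loop(c, b, [i], return_same_structure=True)  # [Tensor]
  ```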

  Example with nesting and a namedtuple:

  ```python
  import collections
  Pair = collections.namedtuple('Pair', 'j, k')
  ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
  c = lambda i, p: i < 10
  b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
  ijk_final = tf.while_loop(c, b, ijk_0)
  ```

  Example using shape_invariants:

  ```python
  i0 = tf.constant(0)
  m0 = tf.ones([2, 2])
  c = lambda i, m: i < 10
  b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
  tf.while_loop(
      c, b, loop_vars=[i0, m0],
      shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
  ```

  Example which demonstrates non-strict semantics: In the following
  example, the final value of the counter `i` does not depend on `x`. So
  the `while_loop` can increment the counter parallel to updates of `x`.
  However, because the loop counter at one loop iteration depends
  on the value at the previous iteration, the loop counter itself cannot
  be incremented in parallel. Hence if we just want the final value of the
  counter (which we print on the line `print(sess.run(i))`), then
  `x` will never be incremented, but the counter will be updated on a
  single thread. Conversely, if we want the value of the output (which we
  print on the line `print(sess.run(out).shape)`), then the counter may be
  incremented on its own thread, while `x` can be incremented in
  parallel on a separate thread. In the extreme case, it is conceivable
  that the thread incrementing the counter runs until completion before
  `x` is incremented even a single time. The only thing that can never
  happen is the thread updating `x` getting ahead of the counter thread,
  because the thread incrementing `x` depends on the value of the counter.

  ```python
  import tensorflow as tf

  n = 10000
  x = tf.constant(list(range(n)))
  c = lambda i, x: i < n
  b = lambda i, x: (tf.compat.v1.Print(i + 1, [i]),
                    tf.compat.v1.Print(x + 1, [i], "x:"))
  i, out = tf.while_loop(c, b, (0, x))
  with tf.compat.v1.Session() as sess:
    print(sess.run(i))  # prints [0] ... [9999]

    # The following line may increment the counter and x in parallel.
    # The counter thread may get ahead of the other thread, but not the
    # other way around. So you may see things like
    # [9996] x:[9987]
    # meaning that the counter thread is on iteration 9996,
    # while the other thread is on iteration 9987
    print(sess.run(out).shape)
  ```
  """
  if not callable(cond):
    raise TypeError("'cond' must be callable.")
  if not callable(body):
    raise TypeError("'body' must be callable.")
  if parallel_iterations < 1:
    raise TypeError("'parallel_iterations' must be a positive integer.")

  loop_vars = variable_utils.convert_variables_to_tensors(loop_vars)

  # Always enable control flow v2 if building a function, regardless of toggle.
  executing_eagerly = context.executing_eagerly()
  if (util.EnableControlFlowV2(ops.get_default_graph()) and
      not executing_eagerly):
    return while_v2.while_loop(
        cond,
        body,
        loop_vars,
        shape_invariants=shape_invariants,
        parallel_iterations=parallel_iterations,
        maximum_iterations=maximum_iterations,
        name=name,
        return_same_structure=return_same_structure,
        back_prop=back_prop)

  with ops.name_scope(name, "while", loop_vars):
    if not loop_vars:
      raise ValueError("'loop_vars' must be provided.")
    try_to_pack = (len(loop_vars) == 1 and not return_same_structure)
    if maximum_iterations is not None:
      maximum_iterations = ops.convert_to_tensor(
          maximum_iterations, name="maximum_iterations")
      if maximum_iterations.shape.ndims != 0:
        raise ValueError("'maximum_iterations' must be a scalar. "
                         f"Received shape: {maximum_iterations.shape}")

      if executing_eagerly:
        counter = 0
        maximum_iterations = int(maximum_iterations.numpy())
      else:
        counter = constant_op.constant(
            0, dtype=maximum_iterations.dtype, name="iteration_counter")
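
      # Thread an iteration counter through the loop next to the user's loop
      # variables: `cond` is AND-ed with `i < maximum_iterations`, and the
      # counter is stripped from the result before returning.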
      orig_cond = cond
      orig_body = body
      if try_to_pack:
        loop_vars = (counter, loop_vars[0])
        cond = lambda i, lv: (  # pylint: disable=g-long-lambda
            math_ops.logical_and(i < maximum_iterations, orig_cond(lv)))
        body = lambda i, lv: (i + 1, orig_body(lv))
      else:
        loop_vars = (counter, loop_vars)
        cond = lambda i, lv: (  # pylint: disable=g-long-lambda
            math_ops.logical_and(i < maximum_iterations, orig_cond(*lv)))
        body = lambda i, lv: (i + 1, orig_body(*lv))
      try_to_pack = False
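
    # With eager execution there is no graph to build: just run `body` in a
    # Python while loop until `cond` returns false.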
    if executing_eagerly:
      packed = False  # whether the body result was packed into a 1-item tuple

      loop_var_structure = nest.map_structure(type_spec.type_spec_from_value,
                                              list(loop_vars))
      while cond(*loop_vars):
        loop_vars = body(*loop_vars)
        if try_to_pack and not isinstance(loop_vars, (list, tuple)):
          packed = True
          loop_vars = (loop_vars,)
        nest.assert_same_structure(loop_var_structure, list(loop_vars))

      def convert(x):
        if isinstance(x, tensor_array_ops.TensorArray):
          return x
        return ops.convert_to_tensor(x)

      loop_vars = nest.map_structure(convert, loop_vars, expand_composites=True)
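
      # Drop the prepended iteration counter when `maximum_iterations` was
      # given; otherwise unpack single-variable results if they were packed.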
      if maximum_iterations is not None:
        return loop_vars[1]
      else:
        return loop_vars[0] if packed else loop_vars

    if shape_invariants is not None:
      if maximum_iterations is not None:
        shape_invariants = (tensor_shape.TensorShape([]), shape_invariants)
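
    # Graph mode (control flow v1): build the loop by stitching the `cond` and
    # `body` graph fragments into a WhileContext.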
    loop_context = control_flow_ops.WhileContext(
        maximum_iterations=maximum_iterations,
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory)
    # Only add non-nested loops to the collection. Any nested control flow will
    # be encapsulated in the root context.
    if loop_context.outer_context is None:
      ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context)
    result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants,
                                    return_same_structure)
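    # As in the eager path, strip the prepended iteration counter when
    # `maximum_iterations` was set.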
    if maximum_iterations is not None:
      return result[1]
    else:
      return result