# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Functional operations."""

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import while_loop
# pylint: disable=unused-import
from tensorflow.python.ops.gen_functional_ops import remote_call
# pylint: enable=unused-import
from tensorflow.python.ops.gen_functional_ops import symbolic_gradient
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


# TODO(yuanbyu, mrry): Handle stride to support sliding windows.
@tf_export(v1=["foldl"])
@dispatch.add_dispatch_support
def foldl(fn,
          elems,
          initializer=None,
          parallel_iterations=10,
          back_prop=True,
          swap_memory=False,
          name=None):
  """foldl on the list of tensors unpacked from `elems` on dimension 0.

  This foldl operator repeatedly applies the callable `fn` to a sequence
  of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn, and the second is the value at the current
  position of `elems`. If `initializer` is None, `elems` must contain at least
  one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension. The second argument passed to
  `fn` then matches the structure of `elems`; that is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, `fn` receives a second argument with that same
  structure.

  Args:
    fn: The callable to be performed.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be the second argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      as the initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors, resulting from applying
    `fn` consecutively to the list of tensors unpacked from `elems`, from first
    to last.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = tf.constant([1, 2, 3, 4, 5, 6])
    sum = foldl(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  if not callable(fn):
    raise TypeError(
        f"{fn.__name__} is not callable. Please provide a callable function.")

  def create_ta(elem):
    return tensor_array_ops.TensorArray(
        dtype=elem.dtype, size=n, dynamic_size=False,
        infer_shape=True).unstack(elem)

  in_graph_mode = not context.executing_eagerly()
  with ops.name_scope(name, "foldl", [elems]):
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      varscope_caching_device_was_none = False
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    # Convert elems to tensor array. n may be known statically.
    elems_flat = [
        ops.convert_to_tensor(elem, name="elem") for elem in nest.flatten(elems)
    ]
    n = (
        tensor_shape.dimension_value(elems_flat[0].shape[0]) or
        array_ops.shape(elems_flat[0])[0])

    elems_ta = nest.map_structure(create_ta, elems)

    if initializer is None:
      a = nest.map_structure(lambda elem: elem.read(0), elems_ta)
      i = constant_op.constant(1)
    else:
      a = initializer
      i = constant_op.constant(0)

    def compute(i, a):
      elem_i = nest.map_structure(lambda elem: elem.read(i), elems_ta)
      a = fn(a, elem_i)
      return [i + 1, a]

    _, r_a = while_loop.while_loop(
        lambda i, a: i < n,
        compute, [i, a],
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory,
        maximum_iterations=n)

    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode and varscope_caching_device_was_none:
      varscope.set_caching_device(None)

    return r_a
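

# Editor's note: a hedged usage sketch (not part of the original module)
# showing `foldl` with a multi-arity `elems`, as described in the docstring
# above. `fn` receives the accumulator first and a tuple matching the
# structure of `elems` second. Uses the public `tf.compat.v1.foldl` alias.
#
#   elems = (tf.constant([1, 2, 3]), tf.constant([10, 20, 30]))
#   total = tf.compat.v1.foldl(
#       lambda acc, xy: acc + xy[0] * xy[1],
#       elems,
#       initializer=tf.constant(0))
#   # total == 1*10 + 2*20 + 3*30 == 140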


@tf_export("foldl", v1=[])
@dispatch.add_dispatch_support
@deprecation.deprecated_arg_values(
    None,
    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldl(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldl(fn, elems))""",
    warn_once=True,
    back_prop=False)
def foldl_v2(fn,
             elems,
             initializer=None,
             parallel_iterations=10,
             back_prop=True,
             swap_memory=False,
             name=None):
  """foldl on the list of tensors unpacked from `elems` on dimension 0.

  This foldl operator repeatedly applies the callable `fn` to a sequence
  of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn, and the second is the value at the current
  position of `elems`. If `initializer` is None, `elems` must contain at least
  one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension. The second argument passed to
  `fn` then matches the structure of `elems`; that is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, `fn` receives a second argument with that same
  structure.

  Args:
    fn: The callable to be performed.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be the second argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      as the initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) Deprecated. False disables support for back
      propagation. Prefer using `tf.stop_gradient` instead.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors, resulting from applying
    `fn` consecutively to the list of tensors unpacked from `elems`, from first
    to last.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = tf.constant([1, 2, 3, 4, 5, 6])
    sum = tf.foldl(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  return foldl(
      fn=fn,
      elems=elems,
      initializer=initializer,
      parallel_iterations=parallel_iterations,
      back_prop=back_prop,
      swap_memory=swap_memory,
      name=name)


@tf_export(v1=["foldr"])
@dispatch.add_dispatch_support
def foldr(fn,
          elems,
          initializer=None,
          parallel_iterations=10,
          back_prop=True,
          swap_memory=False,
          name=None):
  """foldr on the list of tensors unpacked from `elems` on dimension 0.

  This foldr operator repeatedly applies the callable `fn` to a sequence
  of elements from last to first. The elements are made of the tensors
  unpacked from `elems`. The callable fn takes two tensors as arguments.
  The first argument is the accumulated value computed from the preceding
  invocation of fn, and the second is the value at the current position of
  `elems`. If `initializer` is None, `elems` must contain at least one element,
  and its last element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension. The second argument passed to
  `fn` then matches the structure of `elems`; that is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, `fn` receives a second argument with that same
  structure.

  Args:
    fn: The callable to be performed.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be the second argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      as the initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors, resulting from applying
    `fn` consecutively to the list of tensors unpacked from `elems`, from last
    to first.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = [1, 2, 3, 4, 5, 6]
    sum = foldr(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  if not callable(fn):
    raise TypeError(
        f"{fn.__name__} is not callable. Please provide a callable function.")

  def create_ta(elem):
    return tensor_array_ops.TensorArray(
        dtype=elem.dtype, size=n, dynamic_size=False,
        infer_shape=True).unstack(elem)

  in_graph_mode = not context.executing_eagerly()
  with ops.name_scope(name, "foldr", [elems]):
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally and not
      # issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      varscope_caching_device_was_none = False
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    # Convert elems to tensor array. n may be known statically.
    elems_flat = [
        ops.convert_to_tensor(elem, name="elem") for elem in nest.flatten(elems)
    ]
    n = (
        tensor_shape.dimension_value(elems_flat[0].shape[0]) or
        array_ops.shape(elems_flat[0])[0])

    elems_ta = nest.map_structure(create_ta, elems)

    if initializer is None:
      i = n - 1
      a = nest.map_structure(lambda elem: elem.read(i), elems_ta)
    else:
      i = n
      a = initializer

    def compute(i, a):
      i -= 1
      elem = nest.map_structure(lambda elem: elem.read(i), elems_ta)
      a_out = fn(a, elem)
      return [i, a_out]

    _, r_a = while_loop.while_loop(
        lambda i, a: i > 0,
        compute, [i, a],
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory,
        maximum_iterations=n)

    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode and varscope_caching_device_was_none:
      varscope.set_caching_device(None)

    return r_a


@tf_export("foldr", v1=[])
@dispatch.add_dispatch_support
@deprecation.deprecated_arg_values(
    None,
    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))""",
    warn_once=True,
    back_prop=False)
def foldr_v2(fn,
             elems,
             initializer=None,
             parallel_iterations=10,
             back_prop=True,
             swap_memory=False,
             name=None):
  """foldr on the list of tensors unpacked from `elems` on dimension 0.

  This foldr operator repeatedly applies the callable `fn` to a sequence
  of elements from last to first. The elements are made of the tensors
  unpacked from `elems`. The callable fn takes two tensors as arguments.
  The first argument is the accumulated value computed from the preceding
  invocation of fn, and the second is the value at the current position of
  `elems`. If `initializer` is None, `elems` must contain at least one element,
  and its last element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension. The second argument passed to
  `fn` then matches the structure of `elems`; that is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, `fn` receives a second argument with that same
  structure.

  Args:
    fn: The callable to be performed.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be the second argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      as the initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) Deprecated. False disables support for back
      propagation. Prefer using `tf.stop_gradient` instead.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors, resulting from applying
    `fn` consecutively to the list of tensors unpacked from `elems`, from last
    to first.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = [1, 2, 3, 4, 5, 6]
    sum = tf.foldr(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  return foldr(
      fn=fn,
      elems=elems,
      initializer=initializer,
      parallel_iterations=parallel_iterations,
      back_prop=back_prop,
      swap_memory=swap_memory,
      name=name)


@tf_export(v1=["scan"])
@dispatch.add_dispatch_support
def scan(fn,
         elems,
         initializer=None,
         parallel_iterations=10,
         back_prop=True,
         swap_memory=False,
         infer_shape=True,
         reverse=False,
         name=None):
  """scan on the list of tensors unpacked from `elems` on dimension 0.

  See also `tf.map_fn`.

  The simplest version of `scan` repeatedly applies the callable `fn` to a
  sequence of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn, and the second is the value at the current
  position of `elems`. If `initializer` is None, `elems` must contain at least
  one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
  If `reverse=True`, it's `[len(values)] + fn(initializer, values[-1]).shape`.

  This method also allows multi-arity `elems` and accumulator. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension. The second argument of
  `fn` must match the structure of `elems`.

  If no `initializer` is provided, the output structure and dtypes of `fn`
  are assumed to be the same as its input; and in this case, the first
  argument of `fn` must match the structure of `elems`.

  If an `initializer` is provided, then the output of `fn` must have the same
  structure as `initializer`; and the first argument of `fn` must match
  this structure.

  For example, if `elems` is `(t1, [t2, t3])` and `initializer` is
  `[i1, i2]` then an appropriate signature for `fn` in `python2` is:
  `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list,
  `[acc_n1, acc_n2]`. An alternative correct signature for `fn`, and the
  one that works in `python3`, is:
  `fn = lambda a, t:`, where `a` and `t` correspond to the input tuples.

  Args:
    fn: The callable to be performed. It accepts two arguments. The first will
      have the same structure as `initializer` if one is provided, otherwise it
      will have the same structure as `elems`. The second will have the same
      (possibly nested) structure as `elems`. Its output must have the same
      structure as `initializer` if one is provided, otherwise it must have the
      same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be the second argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      initial value for the accumulator, and the expected output type of `fn`.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    reverse: (optional) True scans the tensor last to first (instead of first to
      last).
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors. Each tensor packs the
    results of applying `fn` to tensors unpacked from `elems` along the first
    dimension, and the previous accumulator value(s), from first to last (or
    last to first, if `reverse=True`).

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `initializer` do not match.
    ValueError: if the lengths of the output of `fn` and `initializer`
      do not match.

  Examples:
    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    sum = scan(lambda a, x: a + x, elems)
    # sum == [1, 3, 6, 10, 15, 21]
    sum = scan(lambda a, x: a + x, elems, reverse=True)
    # sum == [21, 20, 18, 15, 11, 6]
    ```

    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    initializer = np.array(0)
    sum_one = scan(
        lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer)
    # sum_one == [1, 2, 3, 4, 5, 6]
    ```

    ```python
    elems = np.array([1, 0, 0, 0, 0, 0])
    initializer = (np.array(0), np.array(1))
    fibonaccis = scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer)
    # fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])
    ```
  """
  if not callable(fn):
    raise TypeError(
        f"{fn.__name__} is not callable. Please provide a callable function.")

  input_is_sequence = nest.is_nested(elems)
  input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]

  def input_pack(x):
    return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]

  if initializer is None:
    output_is_sequence = input_is_sequence
    output_flatten = input_flatten
    output_pack = input_pack
  else:
    output_is_sequence = nest.is_nested(initializer)
    output_flatten = lambda x: nest.flatten(x) if output_is_sequence else [x]

    def output_pack(x):
      return (nest.pack_sequence_as(initializer, x)
              if output_is_sequence else x[0])

  elems_flat = input_flatten(elems)

  in_graph_mode = not context.executing_eagerly()
  with ops.name_scope(name, "scan", elems_flat):
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      varscope_caching_device_was_none = False
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    # Convert elems to tensor array.
    elems_flat = [
        ops.convert_to_tensor(elem, name="elem") for elem in elems_flat
    ]

    # Convert elems to tensor array. n may be known statically.
    n = tensor_shape.dimension_value(elems_flat[0].shape[0])
    if n is None:
      n = array_ops.shape(elems_flat[0])[0]

    # TensorArrays are always flat
    elems_ta = [
        tensor_array_ops.TensorArray(
            dtype=elem.dtype,
            size=n,
            dynamic_size=False,
            element_shape=elem.shape[1:],
            infer_shape=True) for elem in elems_flat
    ]
    # Unpack elements
    elems_ta = [
        elem_ta.unstack(elem) for elem_ta, elem in zip(elems_ta, elems_flat)
    ]

    if initializer is None:
      a_flat = [elem.read(n - 1 if reverse else 0) for elem in elems_ta]
      i = 1
    else:
      initializer_flat = output_flatten(initializer)
      a_flat = [ops.convert_to_tensor(init) for init in initializer_flat]
      i = 0

    # Create a tensor array to store the intermediate values.
    accs_ta = [
        tensor_array_ops.TensorArray(
            dtype=init.dtype,
            size=n,
            element_shape=init.shape if infer_shape else None,
            dynamic_size=False,
            infer_shape=infer_shape) for init in a_flat
    ]

    if initializer is None:
      accs_ta = [
          acc_ta.write(n - 1 if reverse else 0, a)
          for (acc_ta, a) in zip(accs_ta, a_flat)
      ]

    def compute(i, a_flat, tas):
      """The loop body of scan.

      Args:
        i: the loop counter.
        a_flat: the accumulator value(s), flattened.
        tas: the output accumulator TensorArray(s), flattened.

      Returns:
        [i + 1, a_flat, tas]: the updated counter + new accumulator values +
          updated TensorArrays

      Raises:
        TypeError: if initializer and fn() output structure do not match
        ValueError: if initializer and fn() output lengths do not match
      """
      packed_elems = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
      packed_a = output_pack(a_flat)
      a_out = fn(packed_a, packed_elems)
      nest.assert_same_structure(elems if initializer is None else initializer,
                                 a_out)
      flat_a_out = output_flatten(a_out)
      tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_a_out)]
      if reverse:
        next_i = i - 1
      else:
        next_i = i + 1
      return (next_i, flat_a_out, tas)

    if reverse:
      initial_i = n - 1 - i
      condition = lambda i, _1, _2: i >= 0
    else:
      initial_i = i
      condition = lambda i, _1, _2: i < n
    _, _, r_a = while_loop.while_loop(
        condition,
        compute, (initial_i, a_flat, accs_ta),
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory,
        maximum_iterations=n)

    results_flat = [r.stack() for r in r_a]

    n_static = tensor_shape.Dimension(
        tensor_shape.dimension_value(
            elems_flat[0].get_shape().with_rank_at_least(1)[0]))
    for elem in elems_flat[1:]:
      n_static.assert_is_compatible_with(
          tensor_shape.Dimension(
              tensor_shape.dimension_value(
                  elem.get_shape().with_rank_at_least(1)[0])))
    for r in results_flat:
      r.set_shape(
          tensor_shape.TensorShape(n_static).concatenate(r.get_shape()[1:]))

    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode and varscope_caching_device_was_none:
      varscope.set_caching_device(None)

    return output_pack(results_flat)


@tf_export("scan", v1=[])
@dispatch.add_dispatch_support
@deprecation.deprecated_arg_values(
    None,
    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.scan(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.scan(fn, elems))""",
    warn_once=True,
    back_prop=False)
def scan_v2(fn,
            elems,
            initializer=None,
            parallel_iterations=10,
            back_prop=True,
            swap_memory=False,
            infer_shape=True,
            reverse=False,
            name=None):
  """scan on the list of tensors unpacked from `elems` on dimension 0.

  The simplest version of `scan` repeatedly applies the callable `fn` to a
  sequence of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn, and the second is the value at the current
  position of `elems`. If `initializer` is None, `elems` must contain at least
  one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
  If `reverse=True`, it's `[len(values)] + fn(initializer, values[-1]).shape`.

  This method also allows multi-arity `elems` and accumulator. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension. The second argument of
  `fn` must match the structure of `elems`.

  If no `initializer` is provided, the output structure and dtypes of `fn`
  are assumed to be the same as its input; and in this case, the first
  argument of `fn` must match the structure of `elems`.

  If an `initializer` is provided, then the output of `fn` must have the same
  structure as `initializer`; and the first argument of `fn` must match
  this structure.

  For example, if `elems` is `(t1, [t2, t3])` and `initializer` is
  `[i1, i2]` then an appropriate signature for `fn` in `python2` is:
  `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list,
  `[acc_n1, acc_n2]`. An alternative correct signature for `fn`, and the
  one that works in `python3`, is:
  `fn = lambda a, t:`, where `a` and `t` correspond to the input tuples.

  Args:
    fn: The callable to be performed. It accepts two arguments. The first will
      have the same structure as `initializer` if one is provided, otherwise it
      will have the same structure as `elems`. The second will have the same
      (possibly nested) structure as `elems`. Its output must have the same
      structure as `initializer` if one is provided, otherwise it must have the
      same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be the second argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      initial value for the accumulator, and the expected output type of `fn`.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) Deprecated. False disables support for back
      propagation. Prefer using `tf.stop_gradient` instead.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    reverse: (optional) True scans the tensor last to first (instead of first to
      last).
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors. Each tensor packs the
    results of applying `fn` to tensors unpacked from `elems` along the first
    dimension, and the previous accumulator value(s), from first to last (or
    last to first, if `reverse=True`).

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `initializer` do not match.
    ValueError: if the lengths of the output of `fn` and `initializer`
      do not match.

  Examples:
    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    sum = scan(lambda a, x: a + x, elems)
    # sum == [1, 3, 6, 10, 15, 21]
    sum = scan(lambda a, x: a + x, elems, reverse=True)
    # sum == [21, 20, 18, 15, 11, 6]
    ```

    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    initializer = np.array(0)
    sum_one = scan(
        lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer)
    # sum_one == [1, 2, 3, 4, 5, 6]
    ```

    ```python
    elems = np.array([1, 0, 0, 0, 0, 0])
    initializer = (np.array(0), np.array(1))
    fibonaccis = scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer)
    # fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])
    ```
  """
  return scan(
      fn=fn,
      elems=elems,
      initializer=initializer,
      parallel_iterations=parallel_iterations,
      back_prop=back_prop,
      swap_memory=swap_memory,
      infer_shape=infer_shape,
      reverse=reverse,
      name=name)


# pylint: disable=invalid-name
def If(cond, inputs, then_branch, else_branch, name=None):
  r"""output = Cond(inputs) ? then_branch(inputs) : else_branch(inputs).

  Args:
    cond: A `Tensor`. A scalar. If the scalar is not a boolean, the scalar is
      converted to a boolean according to the following rule: if the scalar is
      a numerical value, non-zero means True and zero means False; if the
      scalar is a string, non-empty means True and empty means False.
    inputs: A list of input tensors.
    then_branch: A function that takes 'inputs' and returns a list of tensors
      whose types are the same as what else_branch returns.
    else_branch: A function that takes 'inputs' and returns a list of tensors
      whose types are the same as what then_branch returns.
    name: A name for the operation (optional).

  Returns:
    A list of tensors returned by either then_branch(inputs)
    or else_branch(inputs).
  """
  # pylint: disable=protected-access
  # Handle the Defun case until users have transitioned to tf.function. Note
  # that composites may need to be re-packed by the caller.
  if isinstance(then_branch, function._DefinedFunction):
    tlist = [_.type for _ in then_branch.definition.signature.output_arg]
    return gen_functional_ops._if(
        cond, inputs, tlist, then_branch, else_branch, name=name)

  # We assume that `then_branch` is a ConcreteFunction here.
  then_out = then_branch.structured_outputs
  else_out = else_branch.structured_outputs

  # Ensure then/else are the same type of composites to avoid an invalid call
  # to pack_sequence_as later on.
  nest.assert_same_structure(then_out, else_out, expand_composites=True)

  tlist = nest.flatten(then_branch.output_dtypes)
  ret = gen_functional_ops._if(
      cond, inputs, tlist, then_branch, else_branch, name=name)

  # Re-pack the outputs to restore any CompositeTensors
  return nest.pack_sequence_as(then_out, ret, expand_composites=True)
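

# Editor's note: a hedged usage sketch (not part of the original module)
# showing `If` with TF1-style `function.Defun` branches, which exercises the
# `_DefinedFunction` path above. Assumes graph mode; both branches must
# return tensors of the same types.
#
#   @function.Defun(dtypes.float32)
#   def TwoTimes(x):
#     return x * 2.0
#
#   @function.Defun(dtypes.float32)
#   def ThreeTimes(x):
#     return x * 3.0
#
#   pred = constant_op.constant(True)
#   x = constant_op.constant(1.0)
#   # out is [2.0] when pred is True, [3.0] otherwise.
#   out = If(pred, [x], TwoTimes, ThreeTimes)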


def Gradient(inputs, f, name=None):
  r"""Computes the gradient function for function f via backpropagation.

  Args:
    inputs: A list of tensors of size N + M.
    f: The function we want to compute the gradient for. The function 'f' must
      be a numerical function which takes N inputs and produces M outputs. Its
      gradient function 'g' is a function that takes N + M inputs and produces
      N outputs. I.e., if we have (y1, y2, ..., yM) = f(x1, x2, ..., xN), then
      g is (dL/dx1, dL/dx2, ..., dL/dxN) = g(x1, x2, ..., xN, dL/dy1,
      dL/dy2, ..., dL/dyM), where L is a scalar-valued function of
      (x1, x2, ..., xN) (e.g., the loss function). dL/dxi is the partial
      derivative of L with respect to xi.
    name: A name for the operation (optional).

  Returns:
    A list of tensors of size N.
  """
  # TODO(zhifengc): Pretty-print the above spec in latex.
  # TODO(zhfiengc): Needs some math expert to say the comment above better.
  tlist = [_.type for _ in f.definition.signature.input_arg]
  return symbolic_gradient(input=inputs, Tout=tlist, f=f, name=name)
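

# Editor's note: a hedged usage sketch (not part of the original module). For
# a Defun `f` with N inputs and M outputs, `Gradient` expects the N inputs
# followed by the M upstream gradients dL/dy. Assumes graph mode.
#
#   @function.Defun(dtypes.float32)
#   def Square(x):
#     return x * x
#
#   x = constant_op.constant(3.0)
#   dy = constant_op.constant(1.0)
#   # dx == [6.0]: d(x*x)/dx at x = 3, scaled by dL/dy = 1.
#   dx = Gradient([x, dy], Square)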


def _GetInputDtypes(func):
  """Returns the input dtypes of func, excluding dtypes for captured inputs."""
  if isinstance(func, function._DefinedFunction):  # pylint: disable=protected-access
    return func.declared_input_types

  # We assume that `func` is a ConcreteFunction here, but we are not able to
  # verify since importing eager function library will cause cyclic dependence.
  #
  # ConcreteFunction.inputs includes captured inputs.
  num_non_captured_inputs = len(func.inputs) - len(func.captured_inputs)
  inputs_without_captured = func.inputs[:num_non_captured_inputs]
  return [t.dtype for t in inputs_without_captured]


def _LoopBodyCaptureWrapper(func):
  """Returns a wrapper for `func` that handles loop-carried captured inputs."""

  @function.Defun(*_GetInputDtypes(func), func_name="%s_Wrapper" % func.name)
  def Wrapper(*args):
    """A wrapper that handles loop-carried captured inputs."""
    result = func(*args)
    extra_args = tuple(function.get_extra_args())
    # Nullary functions return an Operation. Normal functions can't do this
    # because their return values are converted to Tensors.
    if isinstance(result, ops.Operation):
      return extra_args
    # Unary functions return a single Tensor value.
    elif not isinstance(result, (list, tuple)):
      return (result,) + extra_args
    # N-ary functions return a tuple of Tensors.
    else:
      return result + type(result)(extra_args)

  return Wrapper


# pylint: disable=invalid-name,protected-access
def While(input_, cond, body, name=None, hostmem=None):
  r"""output = input; While (Cond(output)) { output = Body(output) }.

  Args:
    input_: A list of `Tensor` objects. A list of input tensors whose types
      are T.
    cond: A function that takes 'input' and returns a tensor. If the tensor is
      a scalar of non-boolean type, the scalar is converted to a boolean
      according to the following rule: if the scalar is a numerical value,
      non-zero means True and zero means False; if the scalar is a string,
      non-empty means True and empty means False. If the tensor is not a
      scalar, being non-empty means True and being empty means False.
    body: A function that takes a list of tensors and returns another list of
      tensors. Both lists have the same types as specified by T.
    name: A name for the operation (optional).
    hostmem: A list of integers. If i is in the list, input[i] is a host memory
      tensor.

  Raises:
    ValueError: if `cond` has implicitly captured inputs or if `cond` and `body`
      have different signatures.

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
    A list of output tensors whose types are T.
  """
  if cond.captured_inputs:
    raise ValueError(
        "The 'cond' argument can not have implicitly captured inputs. Received "
        f"captured_inputs: {cond.captured_inputs}")

  cond_input_types = _GetInputDtypes(cond)
  body_input_types = _GetInputDtypes(body)

  if cond_input_types != body_input_types:
    raise ValueError(
        "The 'cond' and 'body' signatures do not match. Received: "
        f"cond_input_types={cond_input_types}, body_input_types="
        f"{body_input_types}")

  if body.captured_inputs:
    cond_dtypes = list(body_input_types) + [
        t.dtype for t in body.captured_inputs
    ]

    @function.Defun(*cond_dtypes, func_name="%s_Wrapper" % cond.name)
    def CondWrapper(*args):
      """A wrapper that handles loop-carried captured inputs."""
      return cond(*args[:len(body_input_types)])

    ret = gen_functional_ops._while(
        input_ + body.captured_inputs,
        CondWrapper,
        _LoopBodyCaptureWrapper(body),
        name=name)
    # Slice off the loop-carried captured inputs.
    ret = ret[:-len(body.captured_inputs)]
  else:
    ret = gen_functional_ops._while(input_, cond, body, name=name)
  if hostmem:
    input_attr = attr_value_pb2.AttrValue()
    input_attr.list.i.extend(hostmem)
    ret[0].op._set_attr("_input_hostmem", input_attr)  # pylint: disable=protected-access

    output_attr = attr_value_pb2.AttrValue()
    output_attr.list.i.extend(hostmem)
    ret[0].op._set_attr("_output_hostmem", output_attr)  # pylint: disable=protected-access
  return ret
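

# Editor's note: a hedged usage sketch (not part of the original module).
# `While` requires `cond` and `body` Defuns with identical signatures; here
# both take (int32, float32). Assumes graph mode.
#
#   @function.Defun(dtypes.int32, dtypes.float32)
#   def Cond(i, _):
#     return i < 10
#
#   @function.Defun(dtypes.int32, dtypes.float32)
#   def Body(i, acc):
#     return i + 1, acc + math_ops.cast(i, dtypes.float32)
#
#   # Returns [10, 45.0]: the final counter and the sum 0 + 1 + ... + 9.
#   out = While([0, 0.0], Cond, Body)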


# b/36459430
#
# Ideally, we would not need to rewrite a For loop into a While loop.
# However, today, if a While runs on GPU and the condition returns a
# boolean, the While kernel crashes. Even if we fix the crash, the
# bool needs to be copied between GPU and CPU. So, a For loop is much
# preferred when running on GPU.
#
# On the other hand, the For op has no direct XLA kernel. So, when we run
# a for loop, we need to rewrite it using a While op.
#
# It should be possible and probably better to write an XLA C++ kernel
# implementing the logic in _ForUsingWhile.
def _ForUsingWhile(start,
                   limit,
                   delta,
                   inputs,
                   forbody,
                   name=None,
                   hostmem=None):
  """Helper to implement a For loop using a While."""
  # To support negative delta (e.g., range(100, 0, -3)), we iterate
  # over the range(n) and use iter * delta + start as the real
  # iteration index. (e.g., for i in range(34): iter = i * (-3) +
  # 100).
  d = math_ops.abs(delta)
  # XLA on TPUs doesn't support integer division
  n = math_ops.cast(
      math_ops.cast((math_ops.abs(limit - start) + d - 1), dtypes.float32) /
      math_ops.cast(d, dtypes.float32), dtypes.int32)

  # Carried loop variables ("extra_args") are implicitly added to the input list
  # of the WhileBody function. WhileCond does not call forbody, and so does not
  # depend on any of forbody's extra_args. Since WhileCond and WhileBody
  # must have identical inputs, we have to augment the cond signature to take
  # the same types as the carried loop variables.
  body_sig = [dtypes.int32] * 4 + list(forbody.declared_input_types)[1:]

  cond_name = "%s_Cond" % forbody.name

  @function.Defun(*body_sig, func_name=cond_name)
  def WhileCond(i, n, *args):
    del args
    return i < n

  body_name = "%s_Body" % forbody.name

  @function.Defun(*body_sig, func_name=body_name)
  def WhileBody(i, n, start, delta, *args):
    """A While wrapper for forbody that handles loop-carried captured inputs."""
    for_result = forbody(start + i * delta, *args)
    # Nullary functions return an Operation. Normal functions can't do this
    # because their return values are converted to Tensors.
    if isinstance(for_result, ops.Operation):
      for_result = ()
    # Unary functions return a single Tensor value.
    elif isinstance(for_result, ops.Tensor):
      for_result = (for_result,)
    return (i + 1, n, start, delta) + tuple(for_result)

  if hostmem is not None:
    hostmem = [0, 1, 2, 3] + [(4 + _) for _ in hostmem]
  else:
    hostmem = [0, 1, 2, 3]

  results = While(
      input_=[0, n, start, delta] + inputs,
      cond=WhileCond,
      body=WhileBody,
      name=name,
      hostmem=hostmem)
  # Slice off the loop-carried captured inputs.
  return list(results[4:len(results)])


def For(start,
        limit,
        delta,
        inputs,
        body,
        name=None,
        hostmem=None,
        rewrite_with_while=None):
  r"""out = input; for i in range(start, limit, delta) out = body(i, out).

  Args:
    start: A `Tensor` of type `int32`.
    limit: A `Tensor` of type `int32`.
    delta: A `Tensor` of type `int32`.
    inputs: A list of `Tensor` objects. A list of input tensors whose types
      are T.
    body: A function that takes a list of tensors and returns another list of
      tensors. Both lists have the same types as (int32, T...).
    name: A name for the operation (optional).
    hostmem: A list of integers. If i is in the list, inputs[i] is a host
      memory tensor. In other words, the (i+1)-th argument of the body function
      expects a host memory tensor.
    rewrite_with_while: If True, use a While op to implement the For.

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
    A list of output tensors whose types are T.
  """
  if rewrite_with_while:
    return _ForUsingWhile(start, limit, delta, inputs, body, name, hostmem)
  if body.captured_inputs:
    ret = gen_functional_ops._for(
        start,
        limit,
        delta,
        inputs + body.captured_inputs,
        _LoopBodyCaptureWrapper(body),
        name=name)
    # Slice off the loop-carried captured inputs.
    ret = ret[:-len(body.captured_inputs)]
  else:
    ret = gen_functional_ops._for(start, limit, delta, inputs, body, name=name)
  if hostmem:
    num_for_params = 3  # start/limit/delta

    input_attr = attr_value_pb2.AttrValue()
    input_attr.list.i.extend([num_for_params + i for i in hostmem])
    ret[0].op._set_attr("_input_hostmem", input_attr)  # pylint: disable=protected-access

    output_attr = attr_value_pb2.AttrValue()
    output_attr.list.i.extend(hostmem)
    ret[0].op._set_attr("_output_hostmem", output_attr)  # pylint: disable=protected-access
  return ret
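

# Editor's note: a hedged usage sketch (not part of the original module).
# The For body is a Defun whose signature is (int32, T...); here T is a
# single float32 accumulator. Assumes graph mode.
#
#   @function.Defun(dtypes.int32, dtypes.float32)
#   def Accumulate(i, acc):
#     return acc + math_ops.cast(i, dtypes.float32)
#
#   # Runs Accumulate for i in range(0, 10, 1); result is [45.0].
#   out = For(start=0, limit=10, delta=1, inputs=[0.0], body=Accumulate)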