Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/map_fn.py: 24% (157 statements)
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
16"""Functional operations."""
19import re
21from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
22from tensorflow.python.autograph.impl import api as autograph
23from tensorflow.python.eager import context
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import sparse_tensor
27from tensorflow.python.framework import tensor_shape
28from tensorflow.python.framework import tensor_spec
29from tensorflow.python.framework import type_spec
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import tensor_array_ops
32from tensorflow.python.ops import variable_scope as vs
33from tensorflow.python.ops import while_loop
34from tensorflow.python.ops.ragged import ragged_tensor
35from tensorflow.python.platform import tf_logging as logging
36from tensorflow.python.util import deprecation
37from tensorflow.python.util import nest
38from tensorflow.python.util import variable_utils
39from tensorflow.python.util.tf_export import tf_export
42@tf_export(v1=["map_fn"])
43@deprecation.deprecated_args(None, "Use fn_output_signature instead", "dtype")
44def map_fn(fn,
45 elems,
46 dtype=None,
47 parallel_iterations=None,
48 back_prop=True,
49 swap_memory=False,
50 infer_shape=True,
51 name=None,
52 fn_output_signature=None):
53 """Transforms `elems` by applying `fn` to each element unstacked on axis 0.
55 See also `tf.scan`.
57 `map_fn` unstacks `elems` on axis 0 to obtain a sequence of elements;
58 calls `fn` to transform each element; and then stacks the transformed
59 values back together.
61 #### Mapping functions with single-Tensor inputs and outputs
63 If `elems` is a single tensor and `fn`'s signature is `tf.Tensor->tf.Tensor`,
64 then `map_fn(fn, elems)` is equivalent to
65 `tf.stack([fn(elem) for elem in tf.unstack(elems)])`. E.g.:
67 >>> tf.map_fn(fn=lambda t: tf.range(t, t + 3), elems=tf.constant([3, 5, 2]))
68 <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
69 array([[3, 4, 5],
70 [5, 6, 7],
71 [2, 3, 4]], dtype=int32)>
73 `map_fn(fn, elems).shape = [elems.shape[0]] + fn(elems[0]).shape`.
75 #### Mapping functions with multi-arity inputs and outputs
77 `map_fn` also supports functions with multi-arity inputs and outputs:
79 * If `elems` is a tuple (or nested structure) of tensors, then those tensors
80 must all have the same outer-dimension size (`num_elems`); and `fn` is
81 used to transform each tuple (or structure) of corresponding slices from
82 `elems`. E.g., if `elems` is a tuple `(t1, t2, t3)`, then `fn` is used to
83 transform each tuple of slices `(t1[i], t2[i], t3[i])`
84 (where `0 <= i < num_elems`).
86 * If `fn` returns a tuple (or nested structure) of tensors, then the
87 result is formed by stacking corresponding elements from those structures.
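For example, a minimal sketch combining both cases (the values below are
illustrative; the input and output structures match, so `fn_output_signature`
is not needed):

```python
x = tf.constant([1, 2, 3])
y = tf.constant([10, 20, 30])
# fn receives the tuple of slices (x[i], y[i]) and returns a tuple of results.
sums, products = tf.map_fn(lambda t: (t[0] + t[1], t[0] * t[1]), (x, y))
# sums     == [11, 22, 33]
# products == [10, 40, 90]
```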
89 #### Specifying `fn`'s output signature
91 If `fn`'s input and output signatures are different, then the output
92 signature must be specified using `fn_output_signature`. (The input and
93 output signatures differ if their structures, dtypes, or tensor types do
94 not match). E.g.:
96 >>> tf.map_fn(fn=tf.strings.length, # input & output have different dtypes
97 ... elems=tf.constant(["hello", "moon"]),
98 ... fn_output_signature=tf.int32)
99 <tf.Tensor: shape=(2,), dtype=int32, numpy=array([5, 4], dtype=int32)>
100 >>> tf.map_fn(fn=tf.strings.join, # input & output have different structures
101 ... elems=[tf.constant(['The', 'A']), tf.constant(['Dog', 'Cat'])],
102 ... fn_output_signature=tf.string)
103 <tf.Tensor: shape=(2,), dtype=string,
104 numpy=array([b'TheDog', b'ACat'], dtype=object)>
106 `fn_output_signature` can be specified using any of the following:
108 * A `tf.DType` or `tf.TensorSpec` (to describe a `tf.Tensor`)
109 * A `tf.RaggedTensorSpec` (to describe a `tf.RaggedTensor`)
110 * A `tf.SparseTensorSpec` (to describe a `tf.sparse.SparseTensor`)
111 * A (possibly nested) tuple, list, or dict containing the above types.
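For illustration, a nested signature can describe an `fn` that returns a dict
of tensors (the field names below are made up for this sketch):

```python
signature = {
    "length": tf.TensorSpec(shape=[], dtype=tf.int32),
    "upper": tf.TensorSpec(shape=[], dtype=tf.string),
}
result = tf.map_fn(
    lambda s: {"length": tf.strings.length(s), "upper": tf.strings.upper(s)},
    elems=tf.constant(["hello", "moon"]),
    fn_output_signature=signature)
# result["length"] == [5, 4]; result["upper"] == [b"HELLO", b"MOON"]
```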
113 #### RaggedTensors
115 `map_fn` supports `tf.RaggedTensor` inputs and outputs. In particular:
117 * If `elems` is a `RaggedTensor`, then `fn` will be called with each
118 row of that ragged tensor.
119 * If `elems` has only one ragged dimension, then the values passed to
120 `fn` will be `tf.Tensor`s.
121 * If `elems` has multiple ragged dimensions, then the values passed to
122 `fn` will be `tf.RaggedTensor`s with one fewer ragged dimension.
124 * If the result of `map_fn` should be a `RaggedTensor`, then use a
125 `tf.RaggedTensorSpec` to specify `fn_output_signature`.
126 * If `fn` returns `tf.Tensor`s with varying sizes, then use a
127 `tf.RaggedTensorSpec` with `ragged_rank=0` to combine them into a
128 single ragged tensor (which will have ragged_rank=1).
129 * If `fn` returns `tf.RaggedTensor`s, then use a `tf.RaggedTensorSpec`
130 with the same `ragged_rank`.
132 >>> # Example: RaggedTensor input
133 >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
134 >>> tf.map_fn(tf.reduce_sum, rt, fn_output_signature=tf.int32)
135 <tf.Tensor: shape=(4,), dtype=int32, numpy=array([6, 0, 9, 6], dtype=int32)>
137 >>> # Example: RaggedTensor output
138 >>> elems = tf.constant([3, 5, 0, 2])
139 >>> tf.map_fn(tf.range, elems,
140 ... fn_output_signature=tf.RaggedTensorSpec(shape=[None],
141 ... dtype=tf.int32))
142 <tf.RaggedTensor [[0, 1, 2], [0, 1, 2, 3, 4], [], [0, 1]]>
144 Note: `map_fn` should only be used if you need to map a function over the
145 *rows* of a `RaggedTensor`. If you wish to map a function over the
146 individual values, then you should use:
148 * `tf.ragged.map_flat_values(fn, rt)`
149 (if fn is expressible as TensorFlow ops)
150 * `rt.with_flat_values(map_fn(fn, rt.flat_values))`
151 (otherwise)
153 E.g.:
155 >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
156 >>> tf.ragged.map_flat_values(lambda x: x + 2, rt)
157 <tf.RaggedTensor [[3, 4, 5], [], [6, 7], [8]]>
159 #### SparseTensors
161 `map_fn` supports `tf.sparse.SparseTensor` inputs and outputs. In particular:
163 * If `elems` is a `SparseTensor`, then `fn` will be called with each row
164 of that sparse tensor. In particular, the value passed to `fn` will be a
165 `tf.sparse.SparseTensor` with one fewer dimension than `elems`.
167 * If the result of `map_fn` should be a `SparseTensor`, then use a
168 `tf.SparseTensorSpec` to specify `fn_output_signature`. The individual
169 `SparseTensor`s returned by `fn` will be stacked into a single
170 `SparseTensor` with one more dimension.
172 >>> # Example: SparseTensor input
173 >>> st = tf.sparse.SparseTensor([[0, 0], [2, 0], [2, 1]], [2, 3, 4], [4, 4])
174 >>> tf.map_fn(tf.sparse.reduce_sum, st, fn_output_signature=tf.int32)
175 <tf.Tensor: shape=(4,), dtype=int32, numpy=array([2, 0, 7, 0], dtype=int32)>
177 >>> # Example: SparseTensor output
178 >>> tf.sparse.to_dense(
179 ... tf.map_fn(tf.sparse.eye, tf.constant([2, 3]),
180 ... fn_output_signature=tf.SparseTensorSpec(None, tf.float32)))
181 <tf.Tensor: shape=(2, 3, 3), dtype=float32, numpy=
182 array([[[1., 0., 0.],
183 [0., 1., 0.],
184 [0., 0., 0.]],
185 [[1., 0., 0.],
186 [0., 1., 0.],
187 [0., 0., 1.]]], dtype=float32)>
189 Note: `map_fn` should only be used if you need to map a function over the
190 *rows* of a `SparseTensor`. If you wish to map a function over the nonzero
191 values, then you should use:
193 * If the function is expressible as TensorFlow ops, use:
194 ```python
195 tf.sparse.SparseTensor(st.indices, fn(st.values), st.dense_shape)
196 ```
197 * Otherwise, use:
198 ```python
199 tf.sparse.SparseTensor(st.indices, tf.map_fn(fn, st.values),
200 st.dense_shape)
201 ```
203 #### `map_fn` vs. vectorized operations
205 `map_fn` will apply the operations used by `fn` to each element of `elems`,
206 resulting in `O(elems.shape[0])` total operations. This is somewhat
207 mitigated by the fact that `map_fn` can process elements in parallel.
208 However, a transform expressed using `map_fn` is still typically less
209 efficient than an equivalent transform expressed using vectorized operations.
211 `map_fn` should typically only be used if one of the following is true:
213 * It is difficult or expensive to express the desired transform with
214 vectorized operations.
215 * `fn` creates large intermediate values, so an equivalent vectorized
216 transform would take too much memory.
217 * Processing elements in parallel is more efficient than an equivalent
218 vectorized transform.
219 * Efficiency of the transform is not critical, and using `map_fn` is
220 more readable.
222 E.g., the example given above that maps `fn=lambda t: tf.range(t, t + 3)`
223 across `elems` could be rewritten more efficiently using vectorized ops:
225 >>> elems = tf.constant([3, 5, 2])
226 >>> tf.range(3) + tf.expand_dims(elems, 1)
227 <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
228 array([[3, 4, 5],
229 [5, 6, 7],
230 [2, 3, 4]], dtype=int32)>
232 In some cases, `tf.vectorized_map` can be used to automatically convert a
233 function to a vectorized equivalent.
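For example, a minimal sketch of `tf.vectorized_map` with an elementwise
function (illustrative only):

```python
elems = tf.constant([3, 5, 2])
# Like map_fn, vectorized_map maps over axis 0, but it attempts to vectorize
# the computation instead of running a per-element loop.
squares = tf.vectorized_map(lambda t: t * t, elems)  # [9, 25, 4]
```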
235 #### Eager execution
237 When executing eagerly, `map_fn` does not execute in parallel even if
238 `parallel_iterations` is set to a value > 1. You can still get the
239 performance benefits of running a function in parallel by using the
240 `tf.function` decorator:
242 >>> fn=lambda t: tf.range(t, t + 3)
243 >>> @tf.function
244 ... def func(elems):
245 ... return tf.map_fn(fn, elems, parallel_iterations=3)
246 >>> func(tf.constant([3, 5, 2]))
247 <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
248 array([[3, 4, 5],
249 [5, 6, 7],
250 [2, 3, 4]], dtype=int32)>
253 Note: if you use the `tf.function` decorator, any non-TensorFlow Python
254 code that you may have written in your function won't get executed. See
255 `tf.function` for more details. The recommended approach is to debug without
256 `tf.function`, then switch to it to get the performance benefits of running
257 `map_fn` in parallel.
259 Args:
260 fn: The callable to be performed. It accepts one argument, which will have
261 the same (possibly nested) structure as `elems`. Its output must have the
262 same structure as `fn_output_signature` if one is provided; otherwise it
263 must have the same structure as `elems`.
264 elems: A tensor or (possibly nested) sequence of tensors, each of which will
265 be unstacked along its first dimension. `fn` will be applied to the
266 nested sequence of the resulting slices. `elems` may include ragged and
267 sparse tensors. `elems` must consist of at least one tensor.
268 dtype: Deprecated: Equivalent to `fn_output_signature`.
269 parallel_iterations: (optional) The number of iterations allowed to run in
270 parallel. When graph building, the default value is 10. While executing
271 eagerly, the default value is set to 1.
272 back_prop: (optional) False disables support for back propagation.
273 swap_memory: (optional) True enables GPU-CPU memory swapping.
274 infer_shape: (optional) False disables tests for consistent output shapes.
275 name: (optional) Name prefix for the returned tensors.
276 fn_output_signature: The output signature of `fn`. Must be specified if
277 `fn`'s input and output signatures are different (i.e., if their
278 structures, dtypes, or tensor types do not match).
279 `fn_output_signature` can be specified using any of the following:
281 * A `tf.DType` or `tf.TensorSpec` (to describe a `tf.Tensor`)
282 * A `tf.RaggedTensorSpec` (to describe a `tf.RaggedTensor`)
283 * A `tf.SparseTensorSpec` (to describe a `tf.sparse.SparseTensor`)
284 * A (possibly nested) tuple, list, or dict containing the above types.
286 Returns:
287 A tensor or (possibly nested) sequence of tensors. Each tensor stacks the
288 results of applying `fn` to tensors unstacked from `elems` along the first
289 dimension, from first to last. The result may include ragged and sparse
290 tensors.
292 Raises:
293 TypeError: if `fn` is not callable or the structure of the output of
294 `fn` and `fn_output_signature` do not match.
295 ValueError: if the lengths of the output of `fn` and `fn_output_signature`
296 do not match, or if `elems` does not contain any tensor.
298 Examples:
300 >>> elems = np.array([1, 2, 3, 4, 5, 6])
301 >>> tf.map_fn(lambda x: x * x, elems)
302 <tf.Tensor: shape=(6,), dtype=int64, numpy=array([ 1, 4, 9, 16, 25, 36])>
304 >>> elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
305 >>> tf.map_fn(lambda x: x[0] * x[1], elems, fn_output_signature=tf.int64)
306 <tf.Tensor: shape=(3,), dtype=int64, numpy=array([-1, 2, -3])>
308 >>> elems = np.array([1, 2, 3])
309 >>> tf.map_fn(lambda x: (x, -x), elems,
310 ... fn_output_signature=(tf.int64, tf.int64))
311 (<tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 2, 3])>,
312 <tf.Tensor: shape=(3,), dtype=int64, numpy=array([-1, -2, -3])>)
313 """
314 # This function uses a `while_loop` to call `fn` on each value of the input
315 # tensor(s) (unstacked on dimension 0). The following sequence of variables
316 # are used to transform the input tensor(s) (`elems`) into the output
317 # tensor(s) (`result`):
318 #
319 # - Preparing and unstacking input values for the while_loop:
320 # - elems: The input tensor(s) to map_fn. May include composite tensors.
321 # - elems_flat: Flattened list of tensors from elems (using nest.flatten)
322 # May include composite tensors.
323 # - elems_batchable: Concatenation of "batchable tensor lists" for each
324 # tensor in elems_flat. This "boxes" composite tensors
325 # into sliceable tf.Tensor objects. For more info see:
326 # TensorSpec._to_batched_tensor_list
327 # - elems_batchable_ta: List of TensorArrays used to unstack each Tensor
328 # in elems_batchable into elems_value_batchable.
329 #
330 # - Calling `fn` on each unstacked value in the body of the while_loop:
331 # - elems_value_batchable: Single unstacked value from elems_batchable.
332 # - elems_value_flat: Single unstacked value from elems_flat,
333 # constructed from elems_value_batchable (using
334 # TensorSpec._from_tensor_list).
335 # - elems_value: Single unstacked value from elems (the input to fn).
336 # - result_value: Result of calling `fn(elems_value)`. May contain
337 # composite tensors.
338 # - result_value_flat: Flattened list of tensors from result_value.
339 # May contain composite tensors.
340 # - result_value_batchable: Concatenation of batchable tensor lists for
341 # each tensor in result_value_flat
342 # (using TensorSpec._to_tensor_list).
343 #
344 # - Collecting and stacking output values from the while_loop:
345 # - result_batchable_ta: List of TensorArrays used to stack each tensor
346 # in result_value_batchable into result_batchable.
347 # - result_batchable: Stacked tensors from result_batchable_ta.
348 # - result_flat: Flat list of tensors for the result, constructed from
349 # result_batchable (using TensorSpec._from_tensor_list).
350 # - result: Structured result value packed from result_flat
351 # (using nest.pack_sequence_as).
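# As a purely illustrative walkthrough (dense tensors only; composite tensors
# are instead expanded into their batched tensor lists), consider
# elems = (tf.constant([1, 2, 3]), tf.constant([4., 5., 6.])):
# - elems_flat holds the two tensors, and elems_batchable is the same pair,
# since a dense Tensor's batched tensor list is just [tensor].
# - On iteration i == 1, elems_value_batchable holds the scalars 2 and 5.0,
# and elems_value is the tuple (2, 5.0) passed to `fn`.
# - Each value returned by `fn` is written to a result TensorArray and
# stacked into the final result once the loop finishes.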
353 if fn_output_signature is None:
354 fn_output_signature = dtype
356 if not callable(fn):
357 raise TypeError(f"The provided function {fn.__name__} is not callable. "
358 "fn must be callable.")
360 in_graph_mode = not context.executing_eagerly()
361 # Set the default number of parallel_iterations depending on graph/eager mode.
362 if in_graph_mode and not parallel_iterations:
363 parallel_iterations = 10
364 elif not in_graph_mode and not parallel_iterations:
365 parallel_iterations = 1
366 elif not in_graph_mode and parallel_iterations > 1:
367 logging.log_first_n(
368 logging.WARN, "Setting parallel_iterations > 1 has no "
369 "effect when executing eagerly. Consider calling map_fn"
370 " with tf.function to execute fn in "
371 "parallel.", 1)
372 parallel_iterations = 1
374 # Explicitly read values of ResourceVariables.
375 elems = variable_utils.convert_variables_to_tensors(elems)
376 # Flatten the input tensors, and get the TypeSpec for each one.
377 elems_flat = nest.flatten(elems)
379 # Check in case this is an empty list
380 if len(elems_flat) == 0:
381 raise ValueError(
382 "elems must be a Tensor or (possibly nested) sequence of Tensors. "
383 "Got {}, which does not contain any Tensors.".format(elems))
385 elems_flat_signature = [type_spec.type_spec_from_value(e) for e in elems_flat]
386 elems_unflatten = lambda x: nest.pack_sequence_as(elems, x)
388 # Flatten fn's output signature.
389 if fn_output_signature is None:
390 # If fn_output_signature was not specified, then assume that it matches the
391 # input signature.
392 result_flat_signature = [
393 _most_general_compatible_type(s)._unbatch() # pylint: disable=protected-access
394 for s in elems_flat_signature
395 ]
396 result_unflatten = elems_unflatten
397 else:
398 result_flat_signature = [
399 _dtype_to_spec(d) for d in nest.flatten(fn_output_signature)
400 ]
401 result_unflatten = lambda x: nest.pack_sequence_as(fn_output_signature, x)
403 with ops.name_scope(name, "map", elems_flat):
404 # TODO(akshayka): Remove the in_graph_mode check once caching devices are
405 # supported in Eager
406 if in_graph_mode:
407 # Any get_variable calls in fn will cache the first call locally
408 # and not issue repeated network I/O requests for each iteration.
409 varscope = vs.get_variable_scope()
410 varscope_caching_device_was_none = False
411 if varscope.caching_device is None:
412 # TODO(ebrevdo): Change to using colocate_with here and in other
413 # methods.
414 varscope.set_caching_device(lambda op: op.device)
415 varscope_caching_device_was_none = True
417 elems_flat = [
418 ops.convert_to_tensor_or_composite(t, name="elem") for t in elems_flat
419 ]
421 # Check that inputs are not scalars.
422 first_elem = elems_flat[0]
423 if hasattr(first_elem, "shape"):
424 elems_static_shape = first_elem.shape
425 if elems_static_shape.ndims is not None and elems_static_shape.ndims < 1:
426 raise ValueError(
427 "Elements in elems must be 1+ dimensional Tensors, not scalars")
429 # Box any composite tensors into tensor lists.
430 elems_batchable = _elems_flat_to_batchable(elems_flat)
432 # Find the number of iterations, n. (may be known statically.)
433 n_static = tensor_shape.Dimension(
434 tensor_shape.dimension_value(
435 elems_batchable[0].get_shape().with_rank_at_least(1)[0]))
436 for tensor in elems_batchable[1:]:
437 n_static.assert_is_compatible_with(
438 tensor_shape.Dimension(
439 tensor_shape.dimension_value(
440 tensor.get_shape().with_rank_at_least(1)[0])))
441 n = n_static.value or array_ops.shape(elems_batchable[0])[0]
443 # Convert elems to tensor array.
444 # TODO(edloper): Should we set infer_shape=False for composite tensors?
445 elems_batchable_ta = [
446 tensor_array_ops.TensorArray(
447 dtype=t.dtype, size=n, dynamic_size=False, infer_shape=True)
448 for t in elems_batchable
449 ]
450 # Unpack elements
451 elems_batchable_ta = [
452 ta.unstack(t) for (ta, t) in zip(elems_batchable_ta, elems_batchable)
453 ]
455 i = constant_op.constant(0)
457 # Prepare result tensor array.
458 # TODO(edloper): Should we set infer_shape=False for composite tensors?
459 result_batchable_tensor_spec = (
460 _result_flat_signature_to_batchable_tensor_spec(result_flat_signature))
461 result_batchable_ta = []
462 for spec in result_batchable_tensor_spec:
463 result_batchable_ta.append(
464 tensor_array_ops.TensorArray(
465 dtype=spec.dtype, size=n, dynamic_size=False,
466 infer_shape=infer_shape, element_shape=spec.shape))
468 def compute(i, tas):
469 """The loop body of map_fn.
471 Args:
472 i: the loop counter
473 tas: the flat TensorArray accumulator list
475 Returns:
476 (i + 1, tas): the updated counter + updated TensorArrays
478 Raises:
479 TypeError: if fn_output_signature and result_value structure don't match
480 ValueError: if fn_output_signature and result_value lengths don't match
481 """
482 elems_value_batchable = [ta.read(i) for ta in elems_batchable_ta]
483 elems_value_flat = _elems_value_batchable_to_flat(elems_value_batchable,
484 elems_flat_signature)
485 elems_value = elems_unflatten(elems_value_flat)
486 ag_ctx = autograph_ctx.control_status_ctx()
487 autographed_fn = autograph.tf_convert(fn, ag_ctx)
488 result_value = autographed_fn(elems_value)
489 nest.assert_same_structure(fn_output_signature or elems, result_value)
490 result_value_flat = nest.flatten(result_value)
491 result_value_batchable = _result_value_flat_to_batchable(
492 result_value_flat, result_flat_signature)
493 tas = [
494 ta.write(i, value) for (ta, value) in zip(tas, result_value_batchable)
495 ]
496 return (i + 1, tas)
498 _, r_a = while_loop.while_loop(
499 lambda i, _: i < n,
500 compute, (i, result_batchable_ta),
501 parallel_iterations=parallel_iterations,
502 back_prop=back_prop,
503 swap_memory=swap_memory,
504 maximum_iterations=n)
505 result_batchable = [r.stack() for r in r_a]
507 # Update each output tensor w/ static shape info about the outer dimension.
508 for r in result_batchable:
509 r.set_shape(tensor_shape.TensorShape(n_static).concatenate(
510 r.get_shape()[1:]))
512 # TODO(akshayka): Remove the in_graph_mode check once caching devices are
513 # supported in Eager
514 if in_graph_mode and varscope_caching_device_was_none:
515 varscope.set_caching_device(None)
517 result_flat = _result_batchable_to_flat(result_batchable,
518 result_flat_signature,
519 n_static)
520 result = result_unflatten(result_flat)
521 return result
524def _dtype_to_spec(d):
525 if not isinstance(d, type_spec.TypeSpec):
526 d = tensor_spec.TensorSpec(None, d)
527 return d
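# Illustrative behaviour of _dtype_to_spec (a sketch, not exhaustive):
# _dtype_to_spec(tf.int32) -> tf.TensorSpec(None, tf.int32)
# _dtype_to_spec(tf.RaggedTensorSpec([None], tf.int32)) is returned unchanged.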
530def _most_general_compatible_type(spec):
531 """Returns the most general TypeSpec compatible with `spec`."""
532 # TODO(edloper): Consider adding most_general_compatible_type to TypeSpec API
533 if isinstance(spec, tensor_spec.TensorSpec):
534 return tensor_spec.TensorSpec(None, spec.dtype)
535 elif isinstance(spec, ragged_tensor.RaggedTensorSpec):
536 # pylint: disable=protected-access
537 return ragged_tensor.RaggedTensorSpec(None, spec._dtype, spec._ragged_rank,
538 spec._row_splits_dtype)
539 elif isinstance(spec, sparse_tensor.SparseTensorSpec):
540 # pylint: disable=protected-access
541 return sparse_tensor.SparseTensorSpec(None, spec.dtype)
542 else:
543 return spec
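# Illustrative examples of the relaxation performed above (shapes are dropped,
# dtypes and ragged metadata are kept):
# TensorSpec([3, 4], tf.float32) -> TensorSpec(None, tf.float32)
# SparseTensorSpec([3, 4], tf.float32) -> SparseTensorSpec(None, tf.float32)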
546def _result_flat_signature_to_batchable_tensor_spec(result_flat_signature):
547 """Converts result_flat_signature -> result_batchable_tensor_specs."""
548 tensor_specs = []
549 for spec in result_flat_signature:
550 if not isinstance(spec, type_spec.BatchableTypeSpec):
551 raise TypeError("map_fn can not generate %s outputs" % (spec,))
552 tensor_specs.extend(spec._flat_tensor_specs) # pylint: disable=protected-access
553 return tensor_specs
556def _elems_flat_to_batchable(elems_flat):
557 """Converts elems_flat -> elems_batchable."""
558 elems_batchable = []
559 for elems_tensor in elems_flat:
560 spec = type_spec.type_spec_from_value(elems_tensor)
561 if not isinstance(spec, type_spec.BatchableTypeSpec):
562 raise TypeError("map_fn can not consume %s inputs: got %r" %
563 (spec, elems_tensor))
564 # pylint: disable=protected-access
565 elems_batchable.extend(spec._to_batched_tensor_list(elems_tensor))
566 return elems_batchable
569def _elems_value_batchable_to_flat(elems_value_batchable, elems_flat_signature):
570 """Converts elems_value_batchable -> elems_value_flat."""
571 elems_value_flat = []
572 i = 0
573 for spec in elems_flat_signature:
574 # pylint: disable=protected-access
575 spec = spec._unbatch()
576 tensor_list = elems_value_batchable[i:i + len(spec._flat_tensor_specs)]
577 elems_value_flat.append(spec._from_compatible_tensor_list(tensor_list))
578 i += len(tensor_list)
579 assert i == len(elems_value_batchable)
580 return elems_value_flat
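# Illustrative note: for a plain dense TensorSpec, spec._flat_tensor_specs has
# a single entry, so each flat element consumes exactly one tensor from
# elems_value_batchable; composite specs may consume several.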
583def _result_value_flat_to_batchable(result_value_flat, result_flat_signature):
584 """Converts result_value_flat -> result_value_batchable."""
585 result_value_batchable = []
586 for (r_value, r_spec) in zip(result_value_flat, result_flat_signature):
587 if isinstance(r_spec, tensor_spec.TensorSpec):
588 result_value_batchable.append(r_value)
589 else:
590 if not r_spec.is_compatible_with(r_value):
591 raise ValueError(
592 "Error in map_fn:\n Expected `fn` to return a:\n %s\n"
593 " But it returned a:\n %s\n (value=%s)\n"
594 " To fix, update the `fn_output_signature` (or `dtype`) "
595 "argument to `map_fn`." %
596 (r_spec, type_spec.type_spec_from_value(r_value), r_value))
597 result_value_batchable.extend(r_spec._to_tensor_list(r_value)) # pylint: disable=protected-access
598 return result_value_batchable
601def _result_batchable_to_flat(result_batchable, result_flat_signature,
602 batch_size):
603 """Converts result_batchable -> result_flat."""
604 result_flat = []
605 i = 0
606 for spec in result_flat_signature:
607 # pylint: disable=protected-access
608 num_tensors = len(spec._flat_tensor_specs)
609 result_flat.append(
610 spec._batch(batch_size)._from_compatible_tensor_list(
611 result_batchable[i:i + num_tensors]))
612 i += num_tensors
613 assert i == len(result_batchable)
614 return result_flat
617@tf_export("map_fn", v1=[])
618@deprecation.deprecated_arg_values(
619 None,
620 """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
621Instead of:
622results = tf.map_fn(fn, elems, back_prop=False)
623Use:
624results = tf.nest.map_structure(tf.stop_gradient, tf.map_fn(fn, elems))""",
625 warn_once=True,
626 back_prop=False)
627@deprecation.deprecated_args(None, "Use fn_output_signature instead", "dtype")
628def map_fn_v2(fn,
629 elems,
630 dtype=None,
631 parallel_iterations=None,
632 back_prop=True,
633 swap_memory=False,
634 infer_shape=True,
635 name=None,
636 fn_output_signature=None):
637 """Transform `elems` by applying `fn` to each element unstacked on axis 0."""
638 if fn_output_signature is None:
639 fn_output_signature = dtype
640 return map_fn(
641 fn=fn,
642 elems=elems,
643 fn_output_signature=fn_output_signature,
644 parallel_iterations=parallel_iterations,
645 back_prop=back_prop,
646 swap_memory=swap_memory,
647 infer_shape=infer_shape,
648 name=name)
651# Docstring for v2 is the same as v1, except that back_prop is deprecated.
652map_fn_v2.__doc__ = re.sub(
653 r"( back_prop: \(optional\) )(.*)",
654 r"\1Deprecated: prefer using `tf.stop_gradient` instead. \2",
655 map_fn.__doc__)
656assert "prefer using `tf.stop_gradient` instead" in map_fn_v2.__doc__
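# Illustrative effect of the substitution above on the inherited docstring:
# " back_prop: (optional) False disables support for back propagation."
# becomes
# " back_prop: (optional) Deprecated: prefer using `tf.stop_gradient` instead.
# False disables support for back propagation."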