Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/map_fn.py: 24%

157 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================= 

15 

16"""Functional operations.""" 

17 

18 

19import re 

20 

21from tensorflow.python.autograph.core import ag_ctx as autograph_ctx 

22from tensorflow.python.autograph.impl import api as autograph 

23from tensorflow.python.eager import context 

24from tensorflow.python.framework import constant_op 

25from tensorflow.python.framework import ops 

26from tensorflow.python.framework import sparse_tensor 

27from tensorflow.python.framework import tensor_shape 

28from tensorflow.python.framework import tensor_spec 

29from tensorflow.python.framework import type_spec 

30from tensorflow.python.ops import array_ops 

31from tensorflow.python.ops import tensor_array_ops 

32from tensorflow.python.ops import variable_scope as vs 

33from tensorflow.python.ops import while_loop 

34from tensorflow.python.ops.ragged import ragged_tensor 

35from tensorflow.python.platform import tf_logging as logging 

36from tensorflow.python.util import deprecation 

37from tensorflow.python.util import nest 

38from tensorflow.python.util import variable_utils 

39from tensorflow.python.util.tf_export import tf_export 

40 

41 

@tf_export(v1=["map_fn"])
@deprecation.deprecated_args(None, "Use fn_output_signature instead", "dtype")
def map_fn(fn,
           elems,
           dtype=None,
           parallel_iterations=None,
           back_prop=True,
           swap_memory=False,
           infer_shape=True,
           name=None,
           fn_output_signature=None):
  """Transforms `elems` by applying `fn` to each element unstacked on axis 0.

  See also `tf.scan`.

  `map_fn` unstacks `elems` on axis 0 to obtain a sequence of elements;
  calls `fn` to transform each element; and then stacks the transformed
  values back together.

  #### Mapping functions with single-Tensor inputs and outputs

  If `elems` is a single tensor and `fn`'s signature is `tf.Tensor->tf.Tensor`,
  then `map_fn(fn, elems)` is equivalent to
  `tf.stack([fn(elem) for elem in tf.unstack(elems)])`. E.g.:

  >>> tf.map_fn(fn=lambda t: tf.range(t, t + 3), elems=tf.constant([3, 5, 2]))
  <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
    array([[3, 4, 5],
           [5, 6, 7],
           [2, 3, 4]], dtype=int32)>

  `map_fn(fn, elems).shape = [elems.shape[0]] + fn(elems[0]).shape`.

  #### Mapping functions with multi-arity inputs and outputs

  `map_fn` also supports functions with multi-arity inputs and outputs:

  * If `elems` is a tuple (or nested structure) of tensors, then those tensors
    must all have the same outer-dimension size (`num_elems`); and `fn` is
    used to transform each tuple (or structure) of corresponding slices from
    `elems`. E.g., if `elems` is a tuple `(t1, t2, t3)`, then `fn` is used to
    transform each tuple of slices `(t1[i], t2[i], t3[i])`
    (where `0 <= i < num_elems`).

  * If `fn` returns a tuple (or nested structure) of tensors, then the
    result is formed by stacking corresponding elements from those structures.

  #### Specifying `fn`'s output signature

  If `fn`'s input and output signatures are different, then the output
  signature must be specified using `fn_output_signature`. (The input and
  output signatures differ if their structures, dtypes, or tensor types do
  not match). E.g.:

  >>> tf.map_fn(fn=tf.strings.length,  # input & output have different dtypes
  ...           elems=tf.constant(["hello", "moon"]),
  ...           fn_output_signature=tf.int32)
  <tf.Tensor: shape=(2,), dtype=int32, numpy=array([5, 4], dtype=int32)>
  >>> tf.map_fn(fn=tf.strings.join,  # input & output have different structures
  ...           elems=[tf.constant(['The', 'A']), tf.constant(['Dog', 'Cat'])],
  ...           fn_output_signature=tf.string)
  <tf.Tensor: shape=(2,), dtype=string,
   numpy=array([b'TheDog', b'ACat'], dtype=object)>

  `fn_output_signature` can be specified using any of the following:

  * A `tf.DType` or `tf.TensorSpec` (to describe a `tf.Tensor`)
  * A `tf.RaggedTensorSpec` (to describe a `tf.RaggedTensor`)
  * A `tf.SparseTensorSpec` (to describe a `tf.sparse.SparseTensor`)
  * A (possibly nested) tuple, list, or dict containing the above types.

  #### RaggedTensors

  `map_fn` supports `tf.RaggedTensor` inputs and outputs. In particular:

  * If `elems` is a `RaggedTensor`, then `fn` will be called with each
    row of that ragged tensor.
    * If `elems` has only one ragged dimension, then the values passed to
      `fn` will be `tf.Tensor`s.
    * If `elems` has multiple ragged dimensions, then the values passed to
      `fn` will be `tf.RaggedTensor`s with one fewer ragged dimension.

  * If the result of `map_fn` should be a `RaggedTensor`, then use a
    `tf.RaggedTensorSpec` to specify `fn_output_signature`.
    * If `fn` returns `tf.Tensor`s with varying sizes, then use a
      `tf.RaggedTensorSpec` with `ragged_rank=0` to combine them into a
      single ragged tensor (which will have ragged_rank=1).
    * If `fn` returns `tf.RaggedTensor`s, then use a `tf.RaggedTensorSpec`
      with the same `ragged_rank`.

  >>> # Example: RaggedTensor input
  >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> tf.map_fn(tf.reduce_sum, rt, fn_output_signature=tf.int32)
  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([6, 0, 9, 6], dtype=int32)>

  >>> # Example: RaggedTensor output
  >>> elems = tf.constant([3, 5, 0, 2])
  >>> tf.map_fn(tf.range, elems,
  ...           fn_output_signature=tf.RaggedTensorSpec(shape=[None],
  ...                                                   dtype=tf.int32))
  <tf.RaggedTensor [[0, 1, 2], [0, 1, 2, 3, 4], [], [0, 1]]>

  Note: `map_fn` should only be used if you need to map a function over the
  *rows* of a `RaggedTensor`. If you wish to map a function over the
  individual values, then you should use:

  * `tf.ragged.map_flat_values(fn, rt)`
    (if fn is expressible as TensorFlow ops)
  * `rt.with_flat_values(map_fn(fn, rt.flat_values))`
    (otherwise)

  E.g.:

  >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> tf.ragged.map_flat_values(lambda x: x + 2, rt)
  <tf.RaggedTensor [[3, 4, 5], [], [6, 7], [8]]>

  #### SparseTensors

  `map_fn` supports `tf.sparse.SparseTensor` inputs and outputs. In particular:

  * If `elems` is a `SparseTensor`, then `fn` will be called with each row
    of that sparse tensor. In particular, the value passed to `fn` will be a
    `tf.sparse.SparseTensor` with one fewer dimension than `elems`.

  * If the result of `map_fn` should be a `SparseTensor`, then use a
    `tf.SparseTensorSpec` to specify `fn_output_signature`. The individual
    `SparseTensor`s returned by `fn` will be stacked into a single
    `SparseTensor` with one more dimension.

  >>> # Example: SparseTensor input
  >>> st = tf.sparse.SparseTensor([[0, 0], [2, 0], [2, 1]], [2, 3, 4], [4, 4])
  >>> tf.map_fn(tf.sparse.reduce_sum, st, fn_output_signature=tf.int32)
  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([2, 0, 7, 0], dtype=int32)>

  >>> # Example: SparseTensor output
  >>> tf.sparse.to_dense(
  ...     tf.map_fn(tf.sparse.eye, tf.constant([2, 3]),
  ...               fn_output_signature=tf.SparseTensorSpec(None, tf.float32)))
  <tf.Tensor: shape=(2, 3, 3), dtype=float32, numpy=
    array([[[1., 0., 0.],
            [0., 1., 0.],
            [0., 0., 0.]],
           [[1., 0., 0.],
            [0., 1., 0.],
            [0., 0., 1.]]], dtype=float32)>

  Note: `map_fn` should only be used if you need to map a function over the
  *rows* of a `SparseTensor`. If you wish to map a function over the nonzero
  values, then you should use:

  * If the function is expressible as TensorFlow ops, use:
    ```python
    tf.sparse.SparseTensor(st.indices, fn(st.values), st.dense_shape)
    ```
  * Otherwise, use:
    ```python
    tf.sparse.SparseTensor(st.indices, tf.map_fn(fn, st.values),
                           st.dense_shape)
    ```

  #### `map_fn` vs. vectorized operations

  `map_fn` will apply the operations used by `fn` to each element of `elems`,
  resulting in `O(elems.shape[0])` total operations. This is somewhat
  mitigated by the fact that `map_fn` can process elements in parallel.
  However, a transform expressed using `map_fn` is still typically less
  efficient than an equivalent transform expressed using vectorized operations.

  `map_fn` should typically only be used if one of the following is true:

  * It is difficult or expensive to express the desired transform with
    vectorized operations.
  * `fn` creates large intermediate values, so an equivalent vectorized
    transform would take too much memory.
  * Processing elements in parallel is more efficient than an equivalent
    vectorized transform.
  * Efficiency of the transform is not critical, and using `map_fn` is
    more readable.

  E.g., the example given above that maps `fn=lambda t: tf.range(t, t + 3)`
  across `elems` could be rewritten more efficiently using vectorized ops:

  >>> elems = tf.constant([3, 5, 2])
  >>> tf.range(3) + tf.expand_dims(elems, 1)
  <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
    array([[3, 4, 5],
           [5, 6, 7],
           [2, 3, 4]], dtype=int32)>

  In some cases, `tf.vectorized_map` can be used to automatically convert a
  function to a vectorized equivalent.

  #### Eager execution

  When executing eagerly, `map_fn` does not execute in parallel even if
  `parallel_iterations` is set to a value > 1. You can still get the
  performance benefits of running a function in parallel by using the
  `tf.function` decorator:

  >>> fn=lambda t: tf.range(t, t + 3)
  >>> @tf.function
  ... def func(elems):
  ...   return tf.map_fn(fn, elems, parallel_iterations=3)
  >>> func(tf.constant([3, 5, 2]))
  <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
    array([[3, 4, 5],
           [5, 6, 7],
           [2, 3, 4]], dtype=int32)>


  Note: if you use the `tf.function` decorator, any non-TensorFlow Python
  code that you may have written in your function won't get executed. See
  `tf.function` for more details. The recommendation would be to debug without
  `tf.function` but switch to it to get performance benefits of running `map_fn`
  in parallel.

  Args:
    fn: The callable to be performed. It accepts one argument, which will have
      the same (possibly nested) structure as `elems`. Its output must have the
      same structure as `fn_output_signature` if one is provided; otherwise it
      must have the same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unstacked along their first dimension. `fn` will be applied to the
      nested sequence of the resulting slices. `elems` may include ragged and
      sparse tensors. `elems` must consist of at least one tensor.
    dtype: Deprecated: Equivalent to `fn_output_signature`.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel. When graph building, the default value is 10. While executing
      eagerly, the default value is set to 1.
    back_prop: (optional) False disables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    name: (optional) Name prefix for the returned tensors.
    fn_output_signature: The output signature of `fn`. Must be specified if
      `fn`'s input and output signatures are different (i.e., if their
      structures, dtypes, or tensor types do not match).
      `fn_output_signature` can be specified using any of the following:

      * A `tf.DType` or `tf.TensorSpec` (to describe a `tf.Tensor`)
      * A `tf.RaggedTensorSpec` (to describe a `tf.RaggedTensor`)
      * A `tf.SparseTensorSpec` (to describe a `tf.sparse.SparseTensor`)
      * A (possibly nested) tuple, list, or dict containing the above types.

  Returns:
    A tensor or (possibly nested) sequence of tensors. Each tensor stacks the
    results of applying `fn` to tensors unstacked from `elems` along the first
    dimension, from first to last. The result may include ragged and sparse
    tensors.

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `fn_output_signature` do not match.
    ValueError: if the lengths of the output of `fn` and `fn_output_signature`
      do not match, or if the `elems` does not contain any tensor.

  Examples:

    >>> elems = np.array([1, 2, 3, 4, 5, 6])
    >>> tf.map_fn(lambda x: x * x, elems)
    <tf.Tensor: shape=(6,), dtype=int64, numpy=array([ 1,  4,  9, 16, 25, 36])>

    >>> elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
    >>> tf.map_fn(lambda x: x[0] * x[1], elems, fn_output_signature=tf.int64)
    <tf.Tensor: shape=(3,), dtype=int64, numpy=array([-1,  2, -3])>

    >>> elems = np.array([1, 2, 3])
    >>> tf.map_fn(lambda x: (x, -x), elems,
    ...          fn_output_signature=(tf.int64, tf.int64))
    (<tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 2, 3])>,
     <tf.Tensor: shape=(3,), dtype=int64, numpy=array([-1, -2, -3])>)
  """
  # This function uses a `while_loop` to call `fn` on each value of the input
  # tensor(s) (unstacked on dimension 0).  The following sequence of variables
  # are used to transform the input tensor(s) (`elems`) into the output
  # tensor(s) (`result`):
  #
  #   - Preparing and unstacking input values for the while_loop:
  #     - elems: The input tensor(s) to map_fn. May include composite tensors.
  #     - elems_flat: Flattened list of tensors from elems (using nest.flatten)
  #                   May include composite tensors.
  #     - elems_batchable: Concatenation of "batchable tensor lists" for each
  #                        tensor in elems_flat.  This "boxes" composite tensors
  #                        into sliceable tf.Tensor objects.  For more info see:
  #                        TensorSpec._to_batched_tensor_list
  #     - elems_batchable_ta: List of TensorArrays used to unstack each Tensor
  #                           in elems_batchable into elems_value_batchable.
  #
  #   - Calling `fn` on each unstacked value in the body of the while_loop:
  #     - elems_value_batchable: Single unstacked value from elems_batchable.
  #     - elems_value_flat: Single unstacked value from elems_flat,
  #                         constructed from elems_value_batchable (using
  #                         TensorSpec._from_tensor_list).
  #     - elems_value: Single unstacked value from elems (the input to fn).
  #     - result_value: Result of calling `fn(elems_value)`.  May contain
  #                     composite tensors.
  #     - result_value_flat: Flattened list of tensors from result_value.
  #                          May contain composite tensors.
  #     - result_value_batchable: Concatenation of batchable tensor lists for
  #                               each tensor in result_value_flat
  #                               (using TensorSpec._to_tensor_list).
  #
  #   - Collecting and stacking output values from the while_loop:
  #     - result_batchable_ta: List of TensorArrays used to stack each tensor
  #                            in result_value_batchable into result_batchable.
  #     - result_batchable: Stacked tensors from result_batchable_ta.
  #     - result_flat: Flat list of tensors for the result, constructed from
  #                    result_batchable (using TensorSpec._from_tensor_list).
  #     - result: Structured result value packed from result_flat
  #               (using nest.pack_sequence_as).

  # The deprecated `dtype` argument is only a fallback for the new one.
  if fn_output_signature is None:
    fn_output_signature = dtype

  if not callable(fn):
    # Most non-callables (ints, lists, ...) have no `__name__`; fall back to
    # repr() so that the documented TypeError is raised instead of an
    # AttributeError from building the message.
    fn_name = getattr(fn, "__name__", repr(fn))
    raise TypeError(f"The provided function {fn_name} is not callable. "
                    "fn must be callable.")

  in_graph_mode = not context.executing_eagerly()
  # Set the default number of parallel_iterations depending on graph/eager mode.
  if in_graph_mode and not parallel_iterations:
    parallel_iterations = 10
  elif not in_graph_mode and not parallel_iterations:
    parallel_iterations = 1
  elif not in_graph_mode and parallel_iterations > 1:
    logging.log_first_n(
        logging.WARN, "Setting parallel_iterations > 1 has no "
        "effect when executing eagerly. Consider calling map_fn"
        " with tf.function to execute fn in "
        "parallel.", 1)
    parallel_iterations = 1

  # Explicitly read values of ResourceVariables.
  elems = variable_utils.convert_variables_to_tensors(elems)
  # Flatten the input tensors, and get the TypeSpec for each one.
  elems_flat = nest.flatten(elems)

  # Check in case this is an empty list
  if not elems_flat:
    raise ValueError(
        "elems must be a Tensor or (possibly nested) sequence of Tensors. "
        "Got {}, which does not contain any Tensors.".format(elems))

  elems_flat_signature = [type_spec.type_spec_from_value(e) for e in elems_flat]
  elems_unflatten = lambda x: nest.pack_sequence_as(elems, x)

  # Flatten fn's output signature.
  if fn_output_signature is None:
    # If fn_output_signature was not specified, then assume that it matches the
    # input signature.
    result_flat_signature = [
        _most_general_compatible_type(s)._unbatch()  # pylint: disable=protected-access
        for s in elems_flat_signature
    ]
    result_unflatten = elems_unflatten
  else:
    result_flat_signature = [
        _dtype_to_spec(d) for d in nest.flatten(fn_output_signature)
    ]
    result_unflatten = lambda x: nest.pack_sequence_as(fn_output_signature, x)

  with ops.name_scope(name, "map", elems_flat):
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    varscope = None
    varscope_caching_device_was_none = False
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    # Everything below may raise (input validation, tracing `fn`, ...).  The
    # try/finally guarantees that a caching device installed above is removed
    # again even on failure, so an error here does not leave the enclosing
    # variable scope modified.
    try:
      elems_flat = [
          ops.convert_to_tensor_or_composite(t, name="elem") for t in elems_flat
      ]

      # Check that inputs are not scalars.
      first_elem = elems_flat[0]
      if hasattr(first_elem, "shape"):
        elems_static_shape = first_elem.shape
        if (elems_static_shape.ndims is not None and
            elems_static_shape.ndims < 1):
          raise ValueError(
              "Elements in elems must be 1+ dimensional Tensors, not scalars")

      # Box any composite tensors into tensor lists.
      elems_batchable = _elems_flat_to_batchable(elems_flat)

      # Find the number of iterations, n.  (may be known statically.)
      n_static = tensor_shape.Dimension(
          tensor_shape.dimension_value(
              elems_batchable[0].get_shape().with_rank_at_least(1)[0]))
      for tensor in elems_batchable[1:]:
        n_static.assert_is_compatible_with(
            tensor_shape.Dimension(
                tensor_shape.dimension_value(
                    tensor.get_shape().with_rank_at_least(1)[0])))
      # Fall back to the dynamic outer dimension when the static one is
      # unavailable (None) -- or 0, since `or` treats both as falsy.
      n = n_static.value or array_ops.shape(elems_batchable[0])[0]

      # Convert elems to tensor array.
      # TODO(edloper): Should we set infer_shape=False for composite tensors?
      elems_batchable_ta = [
          tensor_array_ops.TensorArray(
              dtype=t.dtype, size=n, dynamic_size=False, infer_shape=True)
          for t in elems_batchable
      ]
      # Unpack elements
      elems_batchable_ta = [
          ta.unstack(t) for (ta, t) in zip(elems_batchable_ta, elems_batchable)
      ]

      i = constant_op.constant(0)

      # Prepare result tensor array.
      # TODO(edloper): Should we set infer_shape=False for composite tensors?
      result_batchable_tensor_spec = (
          _result_flat_signature_to_batchable_tensor_spec(
              result_flat_signature))
      result_batchable_ta = []
      for spec in result_batchable_tensor_spec:
        result_batchable_ta.append(
            tensor_array_ops.TensorArray(
                dtype=spec.dtype, size=n, dynamic_size=False,
                infer_shape=infer_shape, element_shape=spec.shape))

      def compute(i, tas):
        """The loop body of map_fn.

        Args:
          i: the loop counter
          tas: the flat TensorArray accumulator list

        Returns:
          (i + 1, tas): the updated counter + updated TensorArrays

        Raises:
          TypeError: if fn_output_signature and result_value structure don't
            match
          ValueError: if fn_output_signature and result_value lengths don't
            match
        """
        elems_value_batchable = [ta.read(i) for ta in elems_batchable_ta]
        elems_value_flat = _elems_value_batchable_to_flat(
            elems_value_batchable, elems_flat_signature)
        elems_value = elems_unflatten(elems_value_flat)
        ag_ctx = autograph_ctx.control_status_ctx()
        autographed_fn = autograph.tf_convert(fn, ag_ctx)
        result_value = autographed_fn(elems_value)
        nest.assert_same_structure(fn_output_signature or elems, result_value)
        result_value_flat = nest.flatten(result_value)
        result_value_batchable = _result_value_flat_to_batchable(
            result_value_flat, result_flat_signature)
        tas = [
            ta.write(i, value)
            for (ta, value) in zip(tas, result_value_batchable)
        ]
        return (i + 1, tas)

      _, r_a = while_loop.while_loop(
          lambda i, _: i < n,
          compute, (i, result_batchable_ta),
          parallel_iterations=parallel_iterations,
          back_prop=back_prop,
          swap_memory=swap_memory,
          maximum_iterations=n)
      result_batchable = [r.stack() for r in r_a]

      # Update each output tensor w/ static shape info about the outer
      # dimension.
      for r in result_batchable:
        r.set_shape(tensor_shape.TensorShape(n_static).concatenate(
            r.get_shape()[1:]))
    finally:
      # TODO(akshayka): Remove the in_graph_mode check once caching devices are
      # supported in Eager
      if in_graph_mode and varscope_caching_device_was_none:
        varscope.set_caching_device(None)

    result_flat = _result_batchable_to_flat(result_batchable,
                                            result_flat_signature,
                                            n_static)
    result = result_unflatten(result_flat)
    return result

522 

523 

def _dtype_to_spec(d):
  """Returns `d` as a TypeSpec.

  Args:
    d: Either a `type_spec.TypeSpec` instance, or a value accepted as the
      dtype argument of `tensor_spec.TensorSpec`.

  Returns:
    `d` itself when it already is a TypeSpec; otherwise
    `tensor_spec.TensorSpec(None, d)`.
  """
  if isinstance(d, type_spec.TypeSpec):
    return d
  return tensor_spec.TensorSpec(None, d)

528 

529 

def _most_general_compatible_type(spec):
  """Returns a copy of `spec` that is rebuilt with `None` as its shape.

  Only `TensorSpec`, `RaggedTensorSpec` and `SparseTensorSpec` are rebuilt;
  the dtype (and, for ragged specs, the ragged rank and row-splits dtype) is
  carried over.  Any other kind of spec is returned unchanged.

  Args:
    spec: The TypeSpec to generalize.

  Returns:
    The generalized TypeSpec, or `spec` itself for unrecognized spec types.
  """
  # TODO(edloper): Consider adding most_general_compatible_type to TypeSpec API
  # pylint: disable=protected-access
  if isinstance(spec, tensor_spec.TensorSpec):
    return tensor_spec.TensorSpec(None, spec.dtype)

  if isinstance(spec, ragged_tensor.RaggedTensorSpec):
    return ragged_tensor.RaggedTensorSpec(
        None, spec._dtype, spec._ragged_rank, spec._row_splits_dtype)

  if isinstance(spec, sparse_tensor.SparseTensorSpec):
    return sparse_tensor.SparseTensorSpec(None, spec.dtype)

  return spec

544 

545 

546def _result_flat_signature_to_batchable_tensor_spec(result_flat_signature): 

547 """Converts result_flat_signature -> result_batchable_tensor_specs.""" 

548 tensor_specs = [] 

549 for spec in result_flat_signature: 

550 if not isinstance(spec, type_spec.BatchableTypeSpec): 

551 raise TypeError("map_fn can not generate %s outputs" % (spec,)) 

552 tensor_specs.extend(spec._flat_tensor_specs) # pylint: disable=protected-access 

553 return tensor_specs 

554 

555 

556def _elems_flat_to_batchable(elems_flat): 

557 """Converts elems_flat -> elems_batchable.""" 

558 elems_batchable = [] 

559 for elems_tensor in elems_flat: 

560 spec = type_spec.type_spec_from_value(elems_tensor) 

561 if not isinstance(spec, type_spec.BatchableTypeSpec): 

562 raise TypeError("map_fn can not consume %s inputs: got %r" % 

563 (spec, elems_tensor)) 

564 # pylint: disable=protected-access 

565 elems_batchable.extend(spec._to_batched_tensor_list(elems_tensor)) 

566 return elems_batchable 

567 

568 

569def _elems_value_batchable_to_flat(elems_value_batchable, elems_flat_signature): 

570 """Converts elems_value_batchable -> elems_value_flat.""" 

571 elems_value_flat = [] 

572 i = 0 

573 for spec in elems_flat_signature: 

574 # pylint: disable=protected-access 

575 spec = spec._unbatch() 

576 tensor_list = elems_value_batchable[i:i + len(spec._flat_tensor_specs)] 

577 elems_value_flat.append(spec._from_compatible_tensor_list(tensor_list)) 

578 i += len(tensor_list) 

579 assert i == len(elems_value_batchable) 

580 return elems_value_flat 

581 

582 

583def _result_value_flat_to_batchable(result_value_flat, result_flat_signature): 

584 """Converts result_value_flat -> result_value_batchable.""" 

585 result_value_batchable = [] 

586 for (r_value, r_spec) in zip(result_value_flat, result_flat_signature): 

587 if isinstance(r_spec, tensor_spec.TensorSpec): 

588 result_value_batchable.append(r_value) 

589 else: 

590 if not r_spec.is_compatible_with(r_value): 

591 raise ValueError( 

592 "Error in map_fn:\n Expected `fn` to return a:\n %s\n" 

593 " But it returned a:\n %s\n (value=%s)\n" 

594 " To fix, update the `fn_output_signature` (or `dtype`) " 

595 "argument to `map_fn`." % 

596 (r_spec, type_spec.type_spec_from_value(r_value), r_value)) 

597 result_value_batchable.extend(r_spec._to_tensor_list(r_value)) # pylint: disable=protected-access 

598 return result_value_batchable 

599 

600 

601def _result_batchable_to_flat(result_batchable, result_flat_signature, 

602 batch_size): 

603 """Converts result_batchable -> result_flat.""" 

604 result_flat = [] 

605 i = 0 

606 for spec in result_flat_signature: 

607 # pylint: disable=protected-access 

608 num_tensors = len(spec._flat_tensor_specs) 

609 result_flat.append( 

610 spec._batch(batch_size)._from_compatible_tensor_list( 

611 result_batchable[i:i + num_tensors])) 

612 i += num_tensors 

613 assert i == len(result_batchable) 

614 return result_flat 

615 

616 

@tf_export("map_fn", v1=[])
@deprecation.deprecated_arg_values(
    None,
    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.map_fn(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.map_fn(fn, elems))""",
    warn_once=True,
    back_prop=False)
@deprecation.deprecated_args(None, "Use fn_output_signature instead", "dtype")
def map_fn_v2(fn,
              elems,
              dtype=None,
              parallel_iterations=None,
              back_prop=True,
              swap_memory=False,
              infer_shape=True,
              name=None,
              fn_output_signature=None):
  """Transform `elems` by applying `fn` to each element unstacked on axis 0."""
  # An explicit `fn_output_signature` wins; the deprecated `dtype` is only
  # used when no signature was given.  The merged value is forwarded as
  # `fn_output_signature`, and `dtype` itself is not passed on to `map_fn`.
  signature = dtype if fn_output_signature is None else fn_output_signature
  return map_fn(
      fn=fn,
      elems=elems,
      fn_output_signature=signature,
      parallel_iterations=parallel_iterations,
      back_prop=back_prop,
      swap_memory=swap_memory,
      infer_shape=infer_shape,
      name=name)

649 

650 

# Docstring for v2 is the same as v1, except that back_prop is deprecated.
# The substitution keeps the "back_prop: (optional) " prefix of that argument's
# entry (group 1), inserts the deprecation notice, and then re-emits the rest
# of the original line (group 2).
map_fn_v2.__doc__ = re.sub(
    r"( back_prop: \(optional\) )(.*)",
    r"\1Deprecated: prefer using `tf.stop_gradient` instead. \2",
    map_fn.__doc__)
# Sanity check at import time that the notice ended up in the v2 docstring,
# i.e. that the pattern above still matches the v1 docstring.
assert "prefer using `tf.stop_gradient` instead" in map_fn_v2.__doc__