Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/parallel_for/control_flow_ops.py: 18%

218 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""for_loop and pfor ops.""" 

16# pylint: disable=g-direct-tensorflow-import 

17 

18import functools 

19 

20from tensorflow.python.eager import context 

21from tensorflow.python.eager import def_function 

22from tensorflow.python.autograph.core import ag_ctx as autograph_ctx 

23from tensorflow.python.autograph.impl import api as autograph 

24from tensorflow.python.framework import composite_tensor 

25from tensorflow.python.framework import indexed_slices 

26from tensorflow.python.framework import ops 

27from tensorflow.python.framework import sparse_tensor 

28from tensorflow.python.framework import tensor_shape 

29from tensorflow.python.framework import tensor_util 

30from tensorflow.python.framework import type_spec 

31from tensorflow.python.ops import array_ops 

32from tensorflow.python.ops import cond 

33from tensorflow.python.ops import math_ops 

34from tensorflow.python.ops import tensor_array_ops 

35from tensorflow.python.ops import while_loop 

36from tensorflow.python.ops.parallel_for.pfor import PFor 

37from tensorflow.python.ops.parallel_for.pfor import PForConfig 

38from tensorflow.python.platform import tf_logging as logging 

39from tensorflow.python.util import nest 

40from tensorflow.python.util import tf_decorator 

41from tensorflow.python.util import tf_inspect 

42from tensorflow.python.util import variable_utils 

43from tensorflow.python.util.tf_export import tf_export 

44 

45 

46def for_loop(loop_fn, loop_fn_dtypes, iters, parallel_iterations=None): 

47 """Runs `loop_fn` `iters` times and stacks the outputs. 

48 

49 

50 Runs `loop_fn` `iters` times, with input values from 0 to `iters - 1`, and 

51 stacks corresponding outputs of the different runs. 

52 

53 Args: 

54 loop_fn: A function that takes an int32 scalar tf.Tensor object representing 

55 the iteration number, and returns a possibly nested structure of tensor 

56 objects. The shape of these outputs should not depend on the input. 

57 loop_fn_dtypes: dtypes for the outputs of `loop_fn`. 

58 iters: Number of iterations for which to run `loop_fn`. 

59 parallel_iterations: The number of iterations that can be dispatched in 

60 parallel. This knob can be used to control the total memory usage. 

61 

62 Returns: 

63 Returns a nested structure of stacked output tensor objects with the same 

64 nested structure as the output of `loop_fn`. 

65 """ 

66 

67 flat_loop_fn_dtypes = nest.flatten(loop_fn_dtypes) 

68 is_none_list = [] 

69 

70 def while_body(i, *ta_list): 

71 """Body of while loop.""" 

72 fn_conv = autograph.tf_convert(loop_fn, autograph_ctx.control_status_ctx()) 

73 fn_output = nest.flatten(fn_conv(i)) 

74 if len(fn_output) != len(flat_loop_fn_dtypes): 

75 raise ValueError( 

76 f"Number of expected outputs {len(flat_loop_fn_dtypes)}, does not " 

77 f"match the number of actual outputs {len(fn_output)} from loop_fn: " 

78 f"{loop_fn} with output {fn_output}.") 

79 outputs = [] 

80 del is_none_list[:] 

81 is_none_list.extend(x is None for x in fn_output) 

82 for out, ta in zip(fn_output, ta_list): 

83 # TODO(agarwal): support returning Operation objects from loop_fn. 

84 if out is not None: 

85 # out may be a ref tensor, wrap it in identity to get a non-ref tensor. 

86 ta = ta.write(i, out) 

87 outputs.append(ta) 

88 return tuple([i + 1] + outputs) 

89 

90 if parallel_iterations is not None: 

91 extra_args = {"parallel_iterations": parallel_iterations} 

92 else: 

93 extra_args = {} 

94 ta_list = while_loop.while_loop(lambda i, *ta: i < iters, while_body, [0] + [ 

95 tensor_array_ops.TensorArray(dtype.base_dtype, iters) 

96 for dtype in flat_loop_fn_dtypes 

97 ], **extra_args)[1:] 

98 

99 # TODO(rachelim): enable this for sparse tensors 

100 

101 output = [ 

102 None if is_none else ta.stack() 

103 for ta, is_none in zip(ta_list, is_none_list) 

104 ] 

105 assert len(output) in (0, len(flat_loop_fn_dtypes)) 

106 if not output: 

107 # This may happen for the case where iters == 0. 

108 # Pack a list of empty tensors with the proper ranks so the result matches pfor's output when iters == 0. 

109 loop_var = array_ops.placeholder_with_default(0, shape=[]) 

110 try: 

111 loop_fn_out = loop_fn(loop_var) 

112 out_shapes = [ 

113 [0] + ops.convert_to_tensor(x).shape 

114 for x in nest.flatten(loop_fn_out) 

115 ] 

116 output = [ 

117 array_ops.zeros(out_shapes[i], dt) 

118 for i, dt in enumerate(flat_loop_fn_dtypes) 

119 ] 

120 except Exception: 

121 output = [array_ops.zeros([0])] 

122 return nest.pack_sequence_as(loop_fn_dtypes, output) 

123 

124 
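# A minimal usage sketch (illustrative only, not part of this module; the helper
# name `_example_for_loop_squares` is hypothetical): `for_loop` runs the body once
# per iteration inside a `tf.while_loop` and stacks the per-iteration outputs.
def _example_for_loop_squares():
  from tensorflow.python.framework import dtypes  # local import for the sketch
  # Stacks i * i for i in [0, 5): an int32 tensor [0, 1, 4, 9, 16] of shape [5].
  return for_loop(lambda i: i * i, dtypes.int32, 5)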

125def _flatten_first_two_dims(x): 

126 """Flattens the first two dimensions of x into a single dimension.""" 

127 old_shape = array_ops.shape(x) 

128 new_shape = array_ops.concat([[old_shape[0] * old_shape[1]], old_shape[2:]], 

129 axis=0) 

130 return array_ops.reshape(x, new_shape) 

131 

132 
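# Sketch (illustrative, hypothetical helper name): `_flatten_first_two_dims` merges
# the two leading axes, e.g. a [2, 3, 4] tensor is reshaped to [6, 4] at runtime.
def _example_flatten_first_two_dims():
  x = array_ops.zeros([2, 3, 4])
  return _flatten_first_two_dims(x)  # values laid out as shape [6, 4]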

133PFOR_CONFIG_ARG = "pfor_config" 

134 

135 

136def _is_under_xla_context(): 

137 """Check if we are currently inside an XLA compile context.""" 

138 g = ops.get_default_graph() 

139 while g is not None: 

140 control_flow_context = g._get_control_flow_context() # pylint: disable=protected-access 

141 while control_flow_context is not None: 

142 if control_flow_context.IsXLAContext(): 

143 return True 

144 else: 

145 control_flow_context = control_flow_context.outer_context 

146 # If g is a FuncGraph, get its outer_graph. 

147 g = getattr(g, "outer_graph", None) 

148 return False 

149 

150 

151def pfor(loop_fn, 

152 iters, 

153 fallback_to_while_loop=True, 

154 parallel_iterations=None, 

155 warn=False): 

156 """Equivalent to running `loop_fn` `iters` times and stacking the outputs. 

157 

158 `pfor` has functionality similar to `for_loop`, i.e. running `loop_fn` `iters` 

159 times, with input from 0 to `iters - 1`, and stacking corresponding output of 

160 each iteration. However the implementation does not use a `tf.while_loop`. 

161 Instead it adds new operations to the graph that collectively compute the same 

162 value as what running `loop_fn` in a loop would compute. 

163 

164 

165 This is an experimental feature and currently has a lot of limitations: 

166 - There should be no data dependency between the different iterations. For 

167 example, a future iteration should not depend on a value or side-effect of 

168 a previous iteration. 

169 - Stateful kernels are mostly unsupported, since they often imply a data 

170 dependency or ordering of the iterations. A limited set of such stateful 

171 kernels is supported (like RandomFoo, Variable operations like reads, 

172 etc.). 

173 - Conversion works only on a limited set of kernels for which a converter 

174 has been registered. 

175 - `loop_fn` has limited support for control flow operations. `tf.cond` in 

176 particular is not supported. 

177 - `loop_fn` should return nested structure of Tensors or Operations. However 

178 if an Operation is returned, it should have zero outputs. 

179 - The shape and dtype of `loop_fn` outputs should not depend on the input 

180 to loop_fn. 

181 

182 Args: 

183 loop_fn: A function that takes an int32 scalar tf.Tensor object representing 

184 the iteration number, and optionally a keyword argument `pfor_config` set 

185 to a PForConfig object. It returns a possibly nested structure of Tensor 

186 or Operation objects. Note that if the `parallel_iterations` argument is 

187 set to something other than None, `loop_fn` may be called more than once 

188 during graph construction, so it may need to avoid mutating global state. 

189 iters: Number of iterations for which to run `loop_fn`. 

190 fallback_to_while_loop: If true, on failing to vectorize an operation, pfor 

191 falls back to using a `tf.while_loop` to dispatch the iterations. 

192 parallel_iterations: A knob to control how many iterations are vectorized 

193 and dispatched in parallel. The default value of None corresponds to 

194 vectorizing all the iterations. If `parallel_iterations` is smaller than 

195 `iters`, then chunks of at most that many iterations are dispatched in 

196 sequence. This knob can be used to control the total memory usage. 

197 warn: Whether or not to warn when falling back to while loops. 

198 

199 Returns: 

200 Returns a nested structure of stacked tensor objects with the same nested 

201 structure as the output of `loop_fn`. 

202 Raises: 

203 ValueError: If parallel_iterations is not None and not an integer > 1. 

204 """ 

205 def f(): 

206 return _pfor_impl( 

207 loop_fn, 

208 iters, 

209 fallback_to_while_loop=fallback_to_while_loop, 

210 parallel_iterations=parallel_iterations, 

211 warn=warn) 

212 # Note that we wrap into a tf.function if in eager execution mode or under 

213 # XLA compilation. The latter is so that we don't compile operations like 

214 # tf.placeholder that are created by the loop body. 

215 functions_run_eagerly = None 

216 if context.executing_eagerly() or _is_under_xla_context(): 

217 functions_run_eagerly = def_function.functions_run_eagerly() 

218 if functions_run_eagerly: 

219 logging.warning( 

220 "It looks like tf.function behavior was disabled, perhaps using " 

221 "tf.config.run_functions_eagerly. Vectorization " 

222 "primitives (e.g. tf.vectorized_map) require tf.function to work. " 

223 "These primitives will override the disable.") 

224 def_function.run_functions_eagerly(False) 

225 f = def_function.function(f) 

226 

227 outputs = f() 

228 if functions_run_eagerly is not None: 

229 def_function.run_functions_eagerly(functions_run_eagerly) 

230 return outputs 

231 

232 
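# A minimal usage sketch (illustrative only, hypothetical helper name): `pfor`
# builds a vectorized graph equivalent to stacking `loop_fn(i)` for i in [0, iters).
def _example_pfor_gather():
  x = array_ops.reshape(math_ops.range(12), [4, 3])
  # Equivalent to stacking x[i] * (i + 1) for i in 0..3; the result has shape [4, 3].
  return pfor(lambda i: array_ops.gather(x, i) * (i + 1), 4)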

233def _should_expand_composite(value): 

234 return (isinstance(value, composite_tensor.CompositeTensor) 

235 # Leave sparse tensors to be converted by `PFor._convert_sparse`. 

236 and not isinstance(value, sparse_tensor.SparseTensor) 

237 and not isinstance(value, indexed_slices.IndexedSlices)) 

238 

239 
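# Sketch (illustrative, hypothetical helper name): ragged tensors are expanded into
# component tensors by the helpers below, while sparse tensors and IndexedSlices
# are left for the dedicated converters in `PFor`.
def _example_should_expand_composite():
  from tensorflow.python.ops.ragged import ragged_factory_ops  # local import for the sketch
  rt = ragged_factory_ops.constant([[1, 2], [3]])
  sp = sparse_tensor.SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[2, 2])
  return _should_expand_composite(rt), _should_expand_composite(sp)  # (True, False)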

240# pylint: disable=protected-access 

241def _composite_to_tensors(value, is_batched=False): 

242 """Converts a CompositeTensor into a list of stackable tensors.""" 

243 if _should_expand_composite(value): 

244 spec = value._type_spec 

245 if not isinstance(spec, type_spec.BatchableTypeSpec): 

246 raise ValueError(f"CompositeTensor instance {value} returned from " 

247 "parallel_for or vectorized_map loop body must provide " 

248 f"a `BatchableTypeSpec` (saw: {spec}).") 

249 if is_batched: 

250 return spec._to_batched_tensor_list(value) 

251 return spec._to_tensor_list(value) 

252 return value 

253# pylint: enable=protected-access 

254 

255 

256# pylint: disable=protected-access 

257def _composite_from_tensors(stacked_tensors, 

258 preconverted_value, 

259 batch_size): 

260 """Converts a list of stacked tensors to a batch CompositeTensor.""" 

261 if _should_expand_composite(preconverted_value): 

262 batch_type_spec = preconverted_value._type_spec._batch(batch_size) 

263 return batch_type_spec._from_compatible_tensor_list(stacked_tensors) 

264 return stacked_tensors 

265# pylint: enable=protected-access 

266 

267 

268def _loop_fn_has_config(loop_fn): 

269 """Test if `loop_fn` has a `pfor_config` argument.""" 

270 if tf_inspect.isfunction(loop_fn): 

271 argspec = tf_inspect.getargspec(loop_fn) 

272 return PFOR_CONFIG_ARG in argspec.args 

273 elif isinstance(loop_fn, functools.partial): 

274 fn = loop_fn.func 

275 argspec = tf_inspect.getargspec(fn) 

276 return (PFOR_CONFIG_ARG in argspec.args and 

277 PFOR_CONFIG_ARG not in loop_fn.keywords) 

278 else: 

279 loop_class = tf_decorator.unwrap(loop_fn)[1] 

280 if not hasattr(loop_class, "__call__"): 

281 raise ValueError("`loop_fn` object did not have a __call__ method") 

282 argspec = tf_inspect.getargspec(loop_class.__call__) 

283 return PFOR_CONFIG_ARG in argspec.args 

284 

285 
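# Sketch (illustrative, hypothetical helper name): only a `loop_fn` whose signature
# includes a `pfor_config` argument receives the PForConfig object.
def _example_loop_fn_has_config():
  def plain_fn(i):
    return i

  def config_fn(i, pfor_config=None):
    del pfor_config
    return i

  return _loop_fn_has_config(plain_fn), _loop_fn_has_config(config_fn)  # (False, True)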

286def _pfor_impl(loop_fn, 

287 iters, 

288 fallback_to_while_loop, 

289 parallel_iterations=None, 

290 pfor_config=None, 

291 warn=False): 

292 """Implementation of pfor.""" 

293 assert not context.executing_eagerly() 

294 loop_fn_has_config = _loop_fn_has_config(loop_fn) 

295 existing_ops = set(ops.get_default_graph().get_operations()) 

296 iters_value = tensor_util.constant_value(iters) 

297 # Run the loop body 

298 with ops.name_scope("loop_body"): 

299 loop_var = array_ops.placeholder_with_default(0, shape=[]) 

300 if loop_fn_has_config: 

301 if pfor_config is None: 

302 pfor_config = PForConfig() 

303 pfor_config._set_iters(iters) # pylint: disable=protected-access 

304 loop_fn_outputs = loop_fn(loop_var, **{PFOR_CONFIG_ARG: pfor_config}) 

305 else: 

306 assert pfor_config is None 

307 f = autograph.tf_convert(loop_fn, autograph_ctx.control_status_ctx()) 

308 loop_fn_outputs = f(loop_var) 

309 loop_fn_output_tensors = nest.map_structure(_composite_to_tensors, 

310 loop_fn_outputs) 

311 

312 # Convert outputs to Tensor if needed. 

313 tmp_loop_fn_outputs = [] 

314 for loop_fn_output in nest.flatten(loop_fn_output_tensors): 

315 if (loop_fn_output is not None and not isinstance( 

316 loop_fn_output, 

317 (ops.Operation, ops.Tensor, sparse_tensor.SparseTensor))): 

318 if isinstance(loop_fn_output, indexed_slices.IndexedSlices): 

319 logging.warn("Converting %s to a dense representation may make it slow." 

320 " Alternatively, output the indices and values of the" 

321 " IndexedSlices separately, and handle the vectorized" 

322 " outputs directly." % loop_fn_output) 

323 loop_fn_output = ops.convert_to_tensor(loop_fn_output) 

324 else: 

325 loop_fn_output = ops.convert_to_tensor(loop_fn_output) 

326 tmp_loop_fn_outputs.append(loop_fn_output) 

327 loop_fn_output_tensors = nest.pack_sequence_as(loop_fn_output_tensors, 

328 tmp_loop_fn_outputs) 

329 

330 new_ops = set(ops.get_default_graph().get_operations()) - existing_ops 

331 iters = ops.convert_to_tensor(iters) 

332 if parallel_iterations is not None: 

333 if parallel_iterations < 1: 

334 raise ValueError( 

335 "Argument `parallel_iterations` must be None or a positive integer. " 

336 f"Received: {parallel_iterations}.") 

337 if parallel_iterations == 1: 

338 raise ValueError( 

339 "Found `parallel_iterations == 1`. Use `for_loop` instead.") 

340 if iters_value is not None and iters_value < parallel_iterations: 

341 parallel_iterations = None 

342 if parallel_iterations is None: 

343 with ops.name_scope("pfor"): 

344 converter = PFor( 

345 loop_var, 

346 iters, 

347 new_ops, 

348 fallback_to_while_loop=fallback_to_while_loop, 

349 pfor_config=pfor_config, 

350 warn=warn) 

351 flattened_output_tensors = [] 

352 for loop_fn_output in nest.flatten(loop_fn_output_tensors): 

353 output = converter.convert(loop_fn_output) 

354 flattened_output_tensors.append(output) 

355 else: 

356 if pfor_config is not None and pfor_config._has_reductions(): # pylint: disable=protected-access 

357 raise ValueError("Setting `parallel_iterations` currently unsupported if " 

358 "reductions across iterations are performed.") 

359 num_tiled_iterations = iters // parallel_iterations 

360 num_remaining_iterations = iters % parallel_iterations 

361 # TODO(agarwal): Avoid calling loop_fn twice. Generate the loop body inside 

362 # a tf.function and extract the graph from there to vectorize it. 

363 with ops.name_scope("pfor_untiled"): 

364 converter = PFor(loop_var, num_remaining_iterations, new_ops, 

365 fallback_to_while_loop=fallback_to_while_loop, 

366 pfor_config=pfor_config) 

367 remaining_output_tensors = [] 

368 flattened_output_tensors = nest.flatten(loop_fn_output_tensors) 

369 for loop_fn_output in flattened_output_tensors: 

370 output = converter.convert(loop_fn_output) 

371 remaining_output_tensors.append(output) 

372 

373 with ops.name_scope("pfor_tiled"): 

374 loop_fn_dtypes = [ops.convert_to_tensor(x).dtype 

375 for x in flattened_output_tensors] 

376 

377 def tiled_loop_body(j): 

378 offset = j * parallel_iterations + num_remaining_iterations 

379 

380 def tiled_loop_fn(i, pfor_config=None): 

381 if loop_fn_has_config: 

382 loop_fn_outputs = loop_fn(i + offset, pfor_config=pfor_config) 

383 else: 

384 loop_fn_outputs = loop_fn(i + offset) 

385 return nest.flatten( 

386 # Stacking across iterations requires explicit Tensors. 

387 nest.map_structure(_composite_to_tensors, loop_fn_outputs)) 

388 

389 return _pfor_impl( 

390 tiled_loop_fn, 

391 parallel_iterations, 

392 fallback_to_while_loop=fallback_to_while_loop, 

393 pfor_config=pfor_config) 

394 

395 tiled_output_tensors = for_loop( 

396 tiled_loop_body, loop_fn_dtypes, 

397 num_tiled_iterations, parallel_iterations=1) 

398 tiled_output_tensors = [ 

399 _flatten_first_two_dims(y) for y in tiled_output_tensors] 

400 

401 with ops.name_scope("pfor"): 

402 if iters_value is None or iters_value % parallel_iterations: 

403 output_tensors = cond.cond( 

404 math_ops.equal(num_remaining_iterations, 0), 

405 lambda: tiled_output_tensors, 

406 lambda: [array_ops.concat([x, y], axis=0) # pylint: disable=g-long-lambda 

407 for x, y in zip(remaining_output_tensors, 

408 tiled_output_tensors)]) 

409 else: 

410 output_tensors = tiled_output_tensors 

411 flattened_output_tensors = nest.flatten(output_tensors) 

412 

413 for output, original_output in zip(flattened_output_tensors, 

414 nest.flatten(loop_fn_output_tensors)): 

415 # Restore any shape information lost from tiling. 

416 # TODO(b/174254748): this may not be correct for stacked `variant`s. 

417 output.set_shape( 

418 tensor_shape.TensorShape([iters_value]).concatenate( 

419 original_output.shape)) 

420 return nest.map_structure_up_to( 

421 loop_fn_outputs, 

422 functools.partial(_composite_from_tensors, batch_size=iters_value), 

423 nest.pack_sequence_as(loop_fn_output_tensors, 

424 flattened_output_tensors), 

425 loop_fn_outputs) 

426 

427 
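# Sketch (illustrative, hypothetical helper name) of the tiling above: with
# iters=10 and parallel_iterations=4, the 10 % 4 = 2 remaining iterations run in
# one untiled pfor, and the 10 // 4 = 2 full chunks of 4 iterations each run
# inside a `for_loop` whose body is itself vectorized.
def _example_pfor_tiled():
  return pfor(lambda i: i * 2, 10, parallel_iterations=4)  # int32 tensor of shape [10]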

428def _broadcasting_gather(x, i): 

429 """Wrapper for gather that implicitly broadcasts unit dimensions.""" 

430 static_first_dim = tensor_shape.dimension_value(x.shape[0]) 

431 if static_first_dim == 1: 

432 i = 0 

433 elif static_first_dim is None: 

434 i = array_ops.where_v2(array_ops.shape(x)[0] > 1, i, 0) 

435 result = array_ops.gather(x, i) 

436 return result 

437 

438 
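# Sketch (illustrative, hypothetical helper name): a leading dimension of 1 is
# treated as shared across the batch, so the gather index is ignored for it.
def _example_broadcasting_gather():
  batched = array_ops.reshape(math_ops.range(6), [3, 2])
  shared = array_ops.reshape(math_ops.range(2), [1, 2])
  a = _broadcasting_gather(batched, 2)  # row 2 of `batched`: [4, 5]
  b = _broadcasting_gather(shared, 2)   # unit first dim, index forced to 0: [0, 1]
  return a, b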

439# pylint: disable=protected-access 

440def _gather_from_tensor_or_composite(x, i): 

441 """Wrapper for gather that handles CompositeTensors.""" 

442 if _should_expand_composite(x): 

443 spec = x._type_spec 

444 gathered_tensors = [_broadcasting_gather(t, i) 

445 for t in spec._to_batched_tensor_list(x)] 

446 return spec._unbatch()._from_compatible_tensor_list(gathered_tensors) 

447 return _broadcasting_gather(x, i) 

448# pylint: enable=protected-access 

449 

450 

451@tf_export("vectorized_map") 

452def vectorized_map(fn, elems, fallback_to_while_loop=True, warn=True): 

453 """Parallel map on the list of tensors unpacked from `elems` on dimension 0. 

454 

455 This method works similarly to `tf.map_fn` but is optimized to run much faster, 

456 possibly with a much larger memory footprint. The speedups are obtained by 

457 vectorization (see [Auto-Vectorizing TensorFlow Graphs: Jacobians, 

458 Auto-Batching and Beyond](https://arxiv.org/pdf/1903.04243.pdf)). The idea 

459 behind vectorization is to semantically launch all the invocations of `fn` in 

460 parallel and fuse corresponding operations across all these invocations. This 

461 fusion is done statically at graph generation time and the generated code is 

462 often similar in performance to a manually fused version. 

463 

464 Because `tf.vectorized_map` fully parallelizes the batch, this method will 

465 generally be significantly faster than using `tf.map_fn`, especially in eager 

466 mode. However this is an experimental feature and currently has a lot of 

467 limitations: 

468 - There should be no data dependency between the different semantic 

469 invocations of `fn`, i.e. it should be safe to map the elements of the 

470 inputs in any order. 

471 - Stateful kernels are mostly unsupported, since they often imply a data 

472 dependency. A limited set of such stateful kernels is supported (like 

473 RandomFoo, Variable operations like reads, etc.). 

474 - `fn` has limited support for control flow operations. 

475 - `fn` should return nested structure of Tensors or Operations. However 

476 if an Operation is returned, it should have zero outputs. 

477 - The shape and dtype of any intermediate or output tensors in the 

478 computation of `fn` should not depend on the input to `fn`. 

479 

480 Examples: 

481 ```python 

482 def outer_product(a): 

483 return tf.tensordot(a, a, 0) 

484 

485 batch_size = 100 

486 a = tf.ones((batch_size, 32, 32)) 

487 c = tf.vectorized_map(outer_product, a) 

488 assert c.shape == (batch_size, 32, 32, 32, 32) 

489 ``` 

490 

491 ```python 

492 # Computing per-example gradients 

493 

494 batch_size = 10 

495 num_features = 32 

496 layer = tf.keras.layers.Dense(1) 

497 

498 def model_fn(arg): 

499 with tf.GradientTape() as g: 

500 inp, label = arg 

501 inp = tf.expand_dims(inp, 0) 

502 label = tf.expand_dims(label, 0) 

503 prediction = layer(inp) 

504 loss = tf.nn.l2_loss(label - prediction) 

505 return g.gradient(loss, (layer.kernel, layer.bias)) 

506 

507 inputs = tf.random.uniform([batch_size, num_features]) 

508 labels = tf.random.uniform([batch_size, 1]) 

509 per_example_gradients = tf.vectorized_map(model_fn, (inputs, labels)) 

510 assert per_example_gradients[0].shape == (batch_size, num_features, 1) 

511 assert per_example_gradients[1].shape == (batch_size, 1) 

512 ``` 

513 

514 Args: 

515 fn: The callable to be performed. It accepts one argument, which will have 

516 the same (possibly nested) structure as `elems`, and returns a possibly 

517 nested structure of Tensors and Operations, which may be different than 

518 the structure of `elems`. 

519 elems: A tensor or (possibly nested) sequence of tensors, each of which will 

520 be unpacked along its first dimension. The nested sequence of the 

521 resulting slices will be mapped over by `fn`. The first dimensions of all 

522 elements must broadcast to a consistent value; equivalently, each 

523 element tensor must have first dimension of either `B` or `1`, for some 

524 common batch size `B >= 1`. 

525 fallback_to_while_loop: If true, on failing to vectorize an operation, 

526 the unsupported op is wrapped in a tf.while_loop to execute the map 

527 iterations. Note that this fallback only happens for unsupported ops and 

528 other parts of `fn` are still vectorized. If false, on encountering an 

529 unsupported op, a ValueError is thrown. Note that the fallbacks can result 

530 in slowdowns since vectorization often yields speedup of one to two orders 

531 of magnitude. 

532 warn: If set to `False`, this will suppress any warnings due to operation 

533 conversions in the provided `fn` falling back to while loops. 

534 

535 Returns: 

536 A tensor or (possibly nested) sequence of tensors. Each tensor packs the 

537 results of applying fn to tensors unpacked from elems along the first 

538 dimension, from first to last. 

539 

540 Although they are less common as user-visible inputs and outputs, note that 

541 tensors of type `tf.variant` which represent tensor lists (for example from 

542 `tf.raw_ops.TensorListFromTensor`) are vectorized by stacking the list 

543 contents rather than the variant itself, and so the container tensor will 

544 have a scalar shape when returned rather than the usual stacked shape. This 

545 improves the performance of control flow gradient vectorization. 

546 

547 Raises: 

548 ValueError: If vectorization fails and fallback_to_while_loop is False. 

549 """ 

550 elems = variable_utils.convert_variables_to_tensors(elems) 

551 elems = nest.map_structure(ops.convert_to_tensor, 

552 elems, 

553 expand_composites=True) 

554 

555 def loop_fn(i): 

556 gathered_elems = nest.map_structure( 

557 lambda x: _gather_from_tensor_or_composite(x, i), elems) 

558 return fn(gathered_elems) 

559 

560 # Extract batch size from the maximum first dimension of any element. 

561 flat_elems = nest.flatten( 

562 nest.map_structure( 

563 functools.partial(_composite_to_tensors, 

564 is_batched=True), 

565 elems)) 

566 def _get_shape(x): 

567 if x.shape.rank is None: 

568 return None 

569 return x.shape.as_list()[0] 

570 static_first_dims = [_get_shape(elem) for elem in flat_elems] 

571 if any(s is None for s in static_first_dims): 

572 batch_size = math_ops.reduce_max( 

573 [array_ops.shape(elem)[0] for elem in flat_elems]) 

574 else: 

575 batch_size = max(static_first_dims) 

576 

577 return pfor( 

578 loop_fn, 

579 batch_size, 

580 fallback_to_while_loop=fallback_to_while_loop, 

581 warn=warn)
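# A usage sketch of the public API (illustrative only, hypothetical helper name):
# leading dimensions of `elems` only need to broadcast, so a per-example batch can
# be paired with a shared tensor whose first dimension is 1.
def _example_vectorized_map_broadcast():
  import tensorflow as tf  # the sketch assumes TF 2.x
  xs = tf.random.uniform([8, 3])
  shared = tf.constant([[1.0, 2.0, 3.0]])  # leading dim 1 is broadcast across the batch
  # Each invocation sees (xs[i], shared[0]); the stacked result has shape [8, 3].
  return tf.vectorized_map(lambda args: args[0] * args[1], (xs, shared))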