Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/parallel_for/pfor.py: 24%

2693 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Compiled parallel-for loop.""" 

16# pylint: disable=missing-docstring,g-direct-tensorflow-import 

17 

18import collections 

19from functools import partial 

20import string 

21import sys 

22import traceback 

23 

24import numpy as np 

25 

26from tensorflow.compiler.tf2xla.python import xla 

27from tensorflow.core.framework import full_type_pb2 

28from tensorflow.python.eager import context 

29from tensorflow.python.eager import def_function 

30from tensorflow.python.eager import execute 

31from tensorflow.python.framework import constant_op 

32from tensorflow.python.framework import dtypes 

33from tensorflow.python.framework import func_graph 

34from tensorflow.python.framework import ops 

35from tensorflow.python.framework import smart_cond 

36from tensorflow.python.framework import sparse_tensor 

37from tensorflow.python.framework import tensor_shape 

38from tensorflow.python.framework import tensor_spec 

39from tensorflow.python.framework import tensor_util 

40from tensorflow.python.ops import array_ops 

41from tensorflow.python.ops import array_ops_stack 

42from tensorflow.python.ops import cond as tf_cond 

43from tensorflow.python.ops import control_flow_assert 

44from tensorflow.python.ops import control_flow_ops 

45from tensorflow.python.ops import control_flow_switch_case 

46from tensorflow.python.ops import data_flow_ops 

47from tensorflow.python.ops import gen_array_ops 

48from tensorflow.python.ops import gen_image_ops 

49from tensorflow.python.ops import gen_linalg_ops 

50from tensorflow.python.ops import gen_list_ops 

51from tensorflow.python.ops import gen_math_ops 

52from tensorflow.python.ops import gen_nn_ops 

53from tensorflow.python.ops import gen_optional_ops 

54from tensorflow.python.ops import gen_parsing_ops 

55from tensorflow.python.ops import gen_random_ops 

56from tensorflow.python.ops import gen_sparse_ops 

57from tensorflow.python.ops import gen_spectral_ops 

58from tensorflow.python.ops import handle_data_util 

59from tensorflow.python.ops import linalg_ops 

60from tensorflow.python.ops import list_ops 

61from tensorflow.python.ops import manip_ops 

62from tensorflow.python.ops import map_fn 

63from tensorflow.python.ops import math_ops 

64from tensorflow.python.ops import nn_ops 

65from tensorflow.python.ops import parsing_ops 

66from tensorflow.python.ops import resource_variable_ops 

67from tensorflow.python.ops import sparse_ops 

68from tensorflow.python.ops import special_math_ops 

69from tensorflow.python.ops import tensor_array_ops 

70from tensorflow.python.ops import while_loop 

71from tensorflow.python.platform import flags 

72from tensorflow.python.platform import tf_logging as logging 

73from tensorflow.python.util import compat 

74from tensorflow.python.util import nest 

75from tensorflow.python.util import object_identity 

76 

77 

78# TODO(agarwal): remove flag. 

79flags.DEFINE_bool( 

80 "op_conversion_fallback_to_while_loop", True, 

81 "DEPRECATED: Flag is ignored.") 

82 

83 

84def _variant_handle_data(t): 

85 """Fetches handle data for a variant tensor `t`, or None if unavailable.""" 

86 handle_data = resource_variable_ops.get_eager_safe_handle_data(t) 

87 if not handle_data.is_set: 

88 return None 

89 return handle_data.shape_and_type 

90 

91 

92def _variant_type_id(t): 

93 """Returns the full_type_pb2 type of `t`, or None if it is not available.""" 

94 if t.dtype != dtypes.variant: 

95 return None 

96 shapes_and_types = _variant_handle_data(t) 

97 if shapes_and_types is None or not shapes_and_types: 

98 # TODO(b/169968286): Identify all variant tensors (e.g. maps) and we can 

99 # make this an error instead of assuming TensorLists have handle data. 

100 return None # Presumed not a TensorList/Optional 

101 return shapes_and_types[0].type.type_id 

102 

103 

104_INTERNAL_STACKING_TYPE_IDS = ( 

105 full_type_pb2.TFT_ARRAY, 

106 full_type_pb2.TFT_OPTIONAL) 

107 

108 

109def _is_variant_with_internal_stacking(t): 

110 """Identifies variant tensors which pfor always maintains as scalars. 

111 

112 For these, the pfor tensor is recorded as "stacked" if the content of the 

113 variant tensor (e.g. the elements of a TensorList) are all stacked. 

114 

115 Args: 

116 t: A tensor to identify. 

117 Returns: 

118 True if `t` is a TensorList/Optional, False not, None if unknown. 

119 """ 

120 type_id = _variant_type_id(t) 

121 return type_id in _INTERNAL_STACKING_TYPE_IDS 

122 

123 

124def _parse_variant_shapes_and_types(t): 

125 """Extracts shape and dtype information from a variant tensor `t`.""" 

126 shapes_and_types = _variant_handle_data(t) 

127 if shapes_and_types is None or not shapes_and_types: 

128 raise ValueError("Required handle data not set for {!r}".format(t)) 

129 if shapes_and_types[0].type.type_id == full_type_pb2.TFT_ARRAY: 

130 return shapes_and_types 

131 else: 

132 if shapes_and_types[0].type.type_id == full_type_pb2.TFT_UNSET: 

133 return shapes_and_types 

134 else: 

135 raise ValueError( 

136 "Attempted to stack a variant-dtype tensor with no type set ({!r})" 

137 .format(t)) 

138 

139 

140def _stack(t, length): 

141 """stacks `t` `length` times.""" 

142 # Note that this stacking may currently be triggered, for example, when a 

143 # loop invariant tensor with dtype variant is input to a while_loop which then 

144 # produces a loop dependent output. Simply stacking the variants may not be 

145 # suitable since operations on stacked handles may expect a vectorized version 

146 # of the variant. 

147 if t.dtype == dtypes.variant: 

148 shapes_and_types = _parse_variant_shapes_and_types(t) 

149 if shapes_and_types[0].type.type_id == full_type_pb2.TFT_ARRAY: 

150 if len(shapes_and_types) != 1: 

151 raise ValueError( 

152 f"Expected handle data of length 1, got {shapes_and_types!r} of " 

153 f"length {len(shapes_and_types)}.") 

154 return wrap( 

155 _stack_tensor_list(t, shapes_and_types[0].dtype, length), 

156 True) 

157 else: 

158 raise ValueError( 

159 "Attempted to stack an unhandled variant-dtype tensor of " 

160 f"type {shapes_and_types[0].type!r} ({t!r}).") 

161 ones = array_ops.ones_like(array_ops.shape(t)) 

162 ones = array_ops.reshape(ones, [-1]) 

163 length = array_ops.reshape(length, [-1]) 

164 multiples = array_ops.concat([length, ones], 0) 

165 t = array_ops.tile(array_ops.expand_dims(t, 0), multiples) 

166 return wrap(t, True) 

167 

168 

169# The following stateful ops can be safely called once, and with the same 

170# signature as the unconverted version, if their inputs are loop invariant. 

171# TODO(agarwal): implement a strategy for converting Variable reads/writes. The 

172# plan is to map each read/write in the loop_fn to a corresponding merged 

173# read/write in the converted graph. Writes need to be mergeable (e.g. 

174# AssignAdd) to be used in `pfor`. Given a certain read/write order in the 

175# loop_fn, doing a one-to-one conversion will simulate executing such 

176# instructions in lock-step across all iterations. 

177passthrough_stateful_ops = set([ 

178 "VariableV2", 

179 "VarHandleOp", 

180 "VariableShape", 

181 "ReadVariableOp", 

182 "StackV2", 

183 "TensorArrayWriteV3", 

184 "TensorArrayReadV3", 

185 "TensorArraySizeV3", 

186]) 

187 

188 

189# Ops which we will treat like stateful for the purpose of vectorization. 

190# Typically this is used to force pfor converters to run for these ops. 

191force_stateful_ops = set([ 

192 # We vectorize this since we need to change the element shape set on the 

193 # list. 

194 "TensorListReserve", 

195]) 

196 

197 

198def _is_stateful_pfor_op(op): 

199 if isinstance(op, WhileOp): 

200 return op.is_stateful 

201 if op.type == "Const": 

202 # Const didn't have an op_def. 

203 return False 

204 if op.type in passthrough_stateful_ops: 

205 return False 

206 if op.type in force_stateful_ops: 

207 return True 

208 assert hasattr(op, "op_def") and op.op_def is not None, op 

209 return op.op_def.is_stateful 

210 

211 

212# pylint: disable=protected-access 

213class WhileOp: 

214 """Object for storing state for converting the outputs of a while_loop.""" 

215 

216 def __init__(self, exit_node, pfor_ops, fallback_to_while_loop, pfor_config): 

217 """Initializer. 

218 

219 Args: 

220 exit_node: A tensor output from the while_loop. 

221 pfor_ops: list of ops inside the current pfor loop. 

222 fallback_to_while_loop: If True, fallback to while loop when conversion of 

223 an op is not supported 

224 pfor_config: PForConfig object used while constructing loop body. 

225 """ 

226 self._fallback_to_while_loop = fallback_to_while_loop 

227 self._pfor_config = pfor_config 

228 self._pfor_ops = set(pfor_ops) 

229 self._pfor_op_ids = set(x._id for x in pfor_ops) 

230 assert isinstance(exit_node, ops.Tensor) 

231 self._while_context = exit_node.op._get_control_flow_context() 

232 assert isinstance(self._while_context, control_flow_ops.WhileContext) 

233 self._context_name = self._while_context.name 

234 self._condition = self._while_context.pivot.op.inputs[0] 

235 # Parts of an external while_loop could be created inside a pfor loop. 

236 # However for the purpose here, we declare such loops to be external. Also 

237 # note that we check if the condition was created inside or outside to 

238 # determine if the while_loop was first created inside or outside. 

239 # TODO(agarwal): check that the Enter and Exit of this loop are unstacked. 

240 self._is_inside_loop = self.op_is_inside_loop(self._condition.op) 

241 if self._is_inside_loop: 

242 for e in self._while_context.loop_exits: 

243 assert self.op_is_inside_loop(e.op) 

244 

245 # Note the code below tries to reverse engineer an existing while_loop graph 

246 # by assuming the following pattern of nodes. 

247 # 

248 # NextIteration <---- Body <--- Enter 

249 # | ^ 

250 # V ___| Y 

251 # Enter -> Merge -> Switch___ 

252 # ^ | N 

253 # | V 

254 # LoopCond Exit 

255 

256 # Node that elements in the list below correspond one-to-one with each 

257 # other. i.e. these lists are the same size, and the i_th entry corresponds 

258 # to different Operations/Tensors of a single cycle as illustrated above. 

259 # List of Switch ops (ops.Operation) that feed into an Exit Node. 

260 self._exit_switches = [] 

261 # List of inputs (ops.Tensor) to NextIteration. 

262 self._body_outputs = [] 

263 # List of list of control inputs of the NextIteration nodes. 

264 self._next_iter_control_inputs = [] 

265 # List of Merge ops (ops.Operation). 

266 self._enter_merges = [] 

267 # List of output (ops.Tensor) of Exit nodes. 

268 self._outputs = [] 

269 

270 # List of Enter Tensors. 

271 # There are two types of Enter nodes: 

272 # - The Enter nodes that are used in the `loop_vars` argument to 

273 # `while_loop` (see 

274 # https://www.tensorflow.org/api_docs/python/tf/while_loop). We collect 

275 # these Enter nodes immediately below by tracing backwards from the Exit 

276 # nodes via Exit <- Switch <- Merge <- Enter. You can see this chain in the 

277 # diagram above. This allows us to have a 1:1 correspondence between the 

278 # self._outputs and the first elements in self._enters. 

279 # - The Enter nodes that are used only by the body. They don't appear in the 

280 # `loop_vars` and are not returned from the `while_loop`. In Python code, 

281 # they are usually captured by the body lambda. We collect them below by 

282 # iterating over all the ops in the graph. They are appended to the end of 

283 # self._enters or self._direct_enters, and don't correspond to any outputs 

284 # in self._outputs. Note that we keep the resource/variant Enter nodes in 

285 # self._direct_enters and the constructed while_loop's body uses them 

286 # directly as opposed to passing them as loop variables. This is done 

287 # because the while_body cannot partition the resource/variant Tensors, so 

288 # it has to leave them unchanged. 

289 self._enters = [] 

290 self._direct_enters = [] 

291 

292 for e in self._while_context.loop_exits: 

293 self._outputs.append(e.op.outputs[0]) 

294 switch = e.op.inputs[0].op 

295 assert switch.type == "Switch", switch 

296 self._exit_switches.append(switch) 

297 merge = switch.inputs[0].op 

298 assert merge.type == "Merge", merge 

299 self._enter_merges.append(merge) 

300 enter = merge.inputs[0].op 

301 assert enter.type == "Enter", enter 

302 self._enters.append(enter.outputs[0]) 

303 next_iter = merge.inputs[1].op 

304 assert next_iter.type == "NextIteration", next_iter 

305 self._body_outputs.append(next_iter.inputs[0]) 

306 self._next_iter_control_inputs.append(next_iter.control_inputs) 

307 

308 # Collect all the Enter nodes that are not part of `loop_vars`, the second 

309 # category described above. 

310 # Also track whether the loop body has any stateful ops. 

311 self._is_stateful = False 

312 for op in ops.get_default_graph().get_operations(): 

313 # TODO(agarwal): make sure this works with nested case. 

314 control_flow_context = op._get_control_flow_context() 

315 if control_flow_context is None: 

316 continue 

317 if control_flow_context.name == self._context_name: 

318 self._is_stateful |= _is_stateful_pfor_op(op) 

319 if op.type == "Enter": 

320 output = op.outputs[0] 

321 if output not in self._enters: 

322 if output.dtype in (dtypes.resource, dtypes.variant): 

323 if output not in self._direct_enters: 

324 self._direct_enters.append(output) 

325 else: 

326 self._enters.append(output) 

327 

328 def __str__(self): 

329 """String representation.""" 

330 return "while_loop(%s)" % self.name 

331 

332 @property 

333 def inputs(self): 

334 """Input to all the Enter nodes.""" 

335 return [x.op.inputs[0] for x in self._enters + self._direct_enters] 

336 

337 @property 

338 def control_inputs(self): 

339 """Control input to all the Enter nodes.""" 

340 control_inputs = [] 

341 for x in self._enters + self._direct_enters: 

342 control_inputs.extend(x.op.control_inputs) 

343 return control_inputs 

344 

345 @property 

346 def outputs(self): 

347 """Outputs of all the Exit nodes.""" 

348 return self._outputs 

349 

350 @property 

351 def name(self): 

352 """Context name for the while loop.""" 

353 return self._context_name 

354 

355 @property 

356 def is_inside_loop(self): 

357 """Returns true if the while_loop was created inside the pfor.""" 

358 return self._is_inside_loop 

359 

360 def op_is_inside_loop(self, op): 

361 """True if op was created inside the pfor loop body.""" 

362 assert isinstance(op, ops.Operation) 

363 # Note that we use self._pfor_op_ids for the check and not self._pfor_ops 

364 # since it appears there tensorflow API could return different python 

365 # objects representing the same Operation node. 

366 return op._id in self._pfor_op_ids 

367 

368 @property 

369 def is_stateful(self): 

370 return self._is_stateful 

371 

372 @property 

373 def pfor_converter(self): 

374 """Return a converter for the while loop.""" 

375 return self 

376 

377 def _init_pfor(self, parent_pfor, indices, cond_stacked, inputs, 

378 inputs_stacked): 

379 """Create a PFor object for converting parts of the while_loop. 

380 

381 Args: 

382 parent_pfor: PFor object being used for converting the while_loop. 

383 indices: int32 Tensor of ids for the iterations that are still active 

384 (i.e. did not exit the while_loop). 

385 cond_stacked: True if the while_loop condition is stacked. 

386 inputs: list of input Tensors corresponding 1-to-1 with self._enters. Note 

387 that these Tensors are a subset of the loop variables for the generated 

388 while_loop. 

389 inputs_stacked: List of booleans corresponding 1-to-1 with `inputs`, 

390 indicating if the value is stacked or not. 

391 

392 Returns: 

393 A PFor instance. The instance is initialized by adding conversion mappings 

394 of nodes that will be external to the conversion that the returned 

395 instance will be used for. e.g. Enter nodes as well as Merge and Switch 

396 outputs are mapped to converted values. 

397 """ 

398 num_outputs = len(self._outputs) 

399 assert len(inputs) == len(self._enters) 

400 assert len(inputs_stacked) == len(self._enters) 

401 loop_var = parent_pfor.loop_var 

402 loop_len = array_ops.size(indices) 

403 pfor = PFor( 

404 loop_var, 

405 loop_len, 

406 pfor_ops=self._pfor_ops, 

407 all_indices=indices, 

408 all_indices_partitioned=cond_stacked, 

409 fallback_to_while_loop=self._fallback_to_while_loop, 

410 pfor_config=self._pfor_config) 

411 # Map all inputs of Enter nodes in self._direct_enters to their converted 

412 # values. 

413 for enter in self._direct_enters: 

414 enter_input = enter.op.inputs[0] 

415 converted_enter, stacked, is_sparse_stacked = parent_pfor._convert_helper( 

416 enter_input) 

417 # Since these are resources / variants, they should be unstacked. 

418 assert not stacked and not is_sparse_stacked, (enter, converted_enter) 

419 pfor._add_conversion(enter, wrap(converted_enter, False)) 

420 

421 # Map all Enter nodes to the inputs. 

422 for enter, inp, stacked in zip(self._enters, inputs, inputs_stacked): 

423 pfor._add_conversion(enter, wrap(inp, stacked)) 

424 # Map outputs of Switch and Merge. 

425 for i in range(num_outputs): 

426 wrapped_inp = wrap(inputs[i], inputs_stacked[i]) 

427 merge = self._enter_merges[i] 

428 pfor._add_conversion(merge.outputs[0], wrapped_inp) 

429 # Note that second output of Merge is typically not used, except possibly 

430 # as a control dependency. To avoid trying to output the correct value, we 

431 # employ a hack here. We output a dummy invalid value with an incorrect 

432 # dtype. This will allow control dependency to work but if using it as an 

433 # input, it should typically lead to errors during graph construction due 

434 # to dtype mismatch. 

435 # TODO(agarwal): Check in the original graph to see if there are any 

436 # consumers of this Tensor that use it as an input. 

437 pfor._add_conversion(merge.outputs[1], 

438 wrap(constant_op.constant(-1.0), False)) 

439 switch = self._exit_switches[i] 

440 # Don't need to worry about switch.output[0] which will feed to Exit node. 

441 pfor._add_conversion(switch.outputs[1], wrapped_inp) 

442 return pfor 

443 

444 def _convert_enter(self, parent_pfor, enter): 

445 """Converts an Enter node.""" 

446 inp, stacked, _ = parent_pfor._convert_helper(enter.op.inputs[0]) 

447 control_inputs = [] 

448 for x in enter.op.control_inputs: 

449 converted = parent_pfor._convert_helper(x) 

450 if not isinstance(converted, ops.Operation): 

451 converted = converted.t 

452 control_inputs.append(converted) 

453 if control_inputs: 

454 with ops.control_dependencies(control_inputs): 

455 inp = array_ops.identity(inp) 

456 return inp, stacked 

457 

458 def _maybe_stacked(self, cache, inp): 

459 """Heuristic to figure out if the converting inp leads to a stacked value. 

460 

461 

462 Args: 

463 cache: map from Tensor to boolean indicating stacked/unstacked. 

464 inp: input Tensor. 

465 

466 Returns: 

467 True if `inp` could get stacked. If the function returns False, the 

468 converted value should be guaranteed to be unstacked. If returning True, 

469 it may or may not be stacked. 

470 """ 

471 if inp in cache: 

472 return cache[inp] 

473 if not self.op_is_inside_loop(inp.op): 

474 return False 

475 op = inp.op 

476 output = False 

477 if op.type in [ 

478 "Shape", 

479 "Rank", 

480 "ShapeN", 

481 "ZerosLike", 

482 "TensorArrayV3", 

483 "TensorArraySizeV3", 

484 ]: 

485 output = False 

486 elif _is_stateful_pfor_op(op): 

487 # This may be fairly aggressive. 

488 output = True 

489 elif op.type == "Exit": 

490 # This may be fairly aggressive. 

491 output = True 

492 else: 

493 for t in op.inputs: 

494 if self._maybe_stacked(cache, t): 

495 output = True 

496 break 

497 cache[inp] = output 

498 return output 

499 

500 def _create_init_values(self, pfor_input): 

501 """Create arguments passed to converted while_loop.""" 

502 with ops.name_scope("while_init"): 

503 loop_len_vector = pfor_input.pfor.loop_len_vector 

504 loop_len = loop_len_vector[0] 

505 num_outputs = len(self._outputs) 

506 

507 inputs = [] 

508 maybe_stacked_cache = {} 

509 # Convert all the Enters. Need to do this before checking for stacking 

510 # below. 

511 for i, enter in enumerate(self._enters): 

512 inp, stacked = self._convert_enter(pfor_input.pfor, enter) 

513 inputs.append(inp) 

514 maybe_stacked_cache[enter] = stacked 

515 # Since this enter node is part of the `loop_vars`, it corresponds to an 

516 # output and its preceding switch. We mark this switch's output the same 

517 # stackness, to act at the base case for the logic below. Below, we will 

518 # be going through the body figuring out which inputs might need to be 

519 # stacked and which inputs can safely remain unstacked. 

520 if i < num_outputs: 

521 maybe_stacked_cache[self._exit_switches[i].outputs[1]] = stacked 

522 

523 # Shape invariants for init_values corresponding to self._enters. 

524 input_shape_invariants = [] 

525 # TensorArrays for outputs of converted while loop 

526 output_tas = [] 

527 # Shape invariants for output TensorArrays. 

528 ta_shape_invariants = [] 

529 # List of booleans indicating stackness of inputs, i.e. tensors 

530 # corresponding to self._enters. 

531 inputs_stacked = [] 

532 for i, inp in enumerate(inputs): 

533 enter = self._enters[i] 

534 inp_stacked = self._maybe_stacked(maybe_stacked_cache, enter) 

535 # Note that even when an input is unstacked, the body could make it 

536 # stacked. we use a heuristic below to figure out if body may be making 

537 # it stacked. 

538 if i < num_outputs: 

539 body_output = self._body_outputs[i] 

540 if enter.op in self._pfor_ops: 

541 body_output_stacked = self._maybe_stacked(maybe_stacked_cache, 

542 body_output) 

543 else: 

544 # If constructed outside of pfor loop, then the output would not be 

545 # stacked. 

546 body_output_stacked = False 

547 if body_output_stacked and not inp_stacked: 

548 inp = _stack(inp, loop_len_vector).t 

549 inputs[i] = inp 

550 inp_stacked = True 

551 # TODO(agarwal): other attributes for the TensorArray ? 

552 output_tas.append(tensor_array_ops.TensorArray(inp.dtype, loop_len)) 

553 ta_shape_invariants.append(tensor_shape.TensorShape(None)) 

554 

555 inputs_stacked.append(inp_stacked) 

556 input_shape_invariants.append(tensor_shape.TensorShape(None)) 

557 

558 # See documentation for __call__ for the structure of init_values. 

559 init_values = [True, pfor_input.pfor.all_indices] + inputs + output_tas 

560 # TODO(agarwal): try stricter shape invariants 

561 shape_invariants = ( 

562 [tensor_shape.TensorShape(None), 

563 tensor_shape.TensorShape(None)] + input_shape_invariants + 

564 ta_shape_invariants) 

565 

566 return init_values, inputs_stacked, shape_invariants 

567 

568 def _process_cond_unstacked(self, conditions, indices, inputs, output_tas): 

569 """Handles case when condition is unstacked. 

570 

571 Note that all iterations end together. So we don't need to partition the 

572 inputs. When all iterations are done, we write the inputs to the 

573 TensorArrays. Note that we only write to index 0 of output_tas. Since all 

574 iterations end together, they can all be output together. 

575 """ 

576 not_all_done = array_ops.reshape(conditions, []) 

577 new_output_tas = [] 

578 # pylint: disable=cell-var-from-loop 

579 for i, out_ta in enumerate(output_tas): 

580 inp = inputs[i] 

581 new_output_tas.append( 

582 tf_cond.cond(not_all_done, lambda: out_ta, 

583 lambda: out_ta.write(0, inp))) 

584 # pylint: enable=cell-var-from-loop 

585 return not_all_done, indices, inputs, new_output_tas 

586 

587 def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked, 

588 output_tas): 

589 num_outputs = len(self._outputs) 

590 # Compute if all iterations are done. 

591 not_all_done = math_ops.reduce_any(conditions) 

592 conditions_int = math_ops.cast(conditions, dtypes.int32) 

593 # Partition the indices. 

594 done_indices, new_indices = data_flow_ops.dynamic_partition( 

595 indices, conditions_int, 2) 

596 

597 new_inputs = [] 

598 new_output_tas = [] 

599 for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)): 

600 # Partition the inputs. 

601 if stacked: 

602 done_inp, new_inp = data_flow_ops.dynamic_partition( 

603 inp, conditions_int, 2) 

604 else: 

605 # TODO(agarwal): avoid this stacking. See TODO earlier in 

606 # _process_cond_unstacked. 

607 done_inp = _stack(inp, [array_ops.size(done_indices)]).t 

608 new_inp = inp 

609 new_inputs.append(new_inp) 

610 # For iterations that are done, write them to TensorArrays. 

611 if i < num_outputs: 

612 out_ta = output_tas[i] 

613 # Note that done_indices can be empty. done_inp should also be empty in 

614 # that case. 

615 new_output_tas.append(out_ta.scatter(done_indices, done_inp)) 

616 return not_all_done, new_indices, new_inputs, new_output_tas 

617 

618 def _process_body(self, pfor_input, inputs_stacked, new_indices, cond_stacked, 

619 new_inputs, not_all_done): 

620 """Convert the body function.""" 

621 

622 def true_fn(control_inputs, body_pfor, body_output, stacked): 

623 """Converts the body function for all but last iteration. 

624 

625 This essentially converts body_output. Additionally, it needs to handle 

626 any control dependencies on the NextIteration node. So it creates another 

627 Identity node with the converted dependencies. 

628 """ 

629 converted_control_inp = [] 

630 for x in control_inputs: 

631 for t in x.outputs: 

632 converted_control_inp.append(body_pfor._convert_helper(t).t) 

633 if stacked: 

634 # Note convert always does the stacking. 

635 output = body_pfor.convert(body_output) 

636 else: 

637 output, convert_stacked, _ = body_pfor._convert_helper(body_output) 

638 assert convert_stacked == stacked, body_output 

639 with ops.control_dependencies(converted_control_inp): 

640 return array_ops.identity(output) 

641 

642 body_pfor = self._init_pfor(pfor_input.pfor, new_indices, cond_stacked, 

643 new_inputs, inputs_stacked) 

644 new_outputs = [] 

645 

646 for i, (body_output, 

647 stacked) in enumerate(zip(self._body_outputs, inputs_stacked)): 

648 control_inp = self._next_iter_control_inputs[i] 

649 out_dtype = body_output.dtype 

650 # Note that we want to run the body only if not all pfor iterations are 

651 # done. If all are done, we return empty tensors since these values will 

652 # not be used. Notice that the value returned by the loop is based on 

653 # TensorArrays and not directly on these returned values. 

654 # pylint: disable=cell-var-from-loop 

655 new_output = tf_cond.cond( 

656 not_all_done, 

657 lambda: true_fn(control_inp, body_pfor, body_output, stacked), 

658 lambda: constant_op.constant([], dtype=out_dtype)) 

659 # pylint: enable=cell-var-from-loop 

660 new_outputs.append(new_output) 

661 return new_outputs 

662 

663 def __call__(self, pfor_input): 

664 """Converter for the while_loop. 

665 

666 The conversion of a while_loop is another while_loop. 

667 

668 The arguments to this converted while_loop are as follows: 

669 not_all_done: Boolean scalar Tensor indicating if all the pfor iterations 

670 are done. 

671 indices: int32 1-D Tensor storing the id of the iterations that are not 

672 done. 

673 args: Remaining arguments. These can be divided into 3 categories: 

674 - First set of arguments are the tensors that correspond to the initial 

675 elements of self._enters. The elements that appear in original while 

676 loop's `loop_vars`. 

677 - The second set of arguments are the tensors that correspond to the 

678 remaining elements of self._enters. These are the tensors that directly 

679 enter the original while loop body. 

680 - Finally, the last set of arguments are TensorArrays. These TensorArrays 

681 correspond to the outputs of the original while_loop, i.e. to the 

682 elements in self._outputs. Each TensorArray has `PFor.loop_len` 

683 elements, i.e. the number of pfor iterations. At the end, the i'th 

684 element of each TensorArray will contain the output computed by the 

685 i'th iteration of pfor. Note that elements can be written into these 

686 tensors arrays in any order, depending on when the corresponding pfor 

687 iteration is done. 

688 If the original while_loop had `k` tensors in its `loop_vars` and its body 

689 directly captured `m` tensors, the `args` will contain `2 * k + m` values. 

690 

691 In each iteration, the while_loop body recomputes the condition for all 

692 active pfor iterations to see which of them are now done. It then partitions 

693 all the inputs and passes them along to the converted body. Values for all 

694 the iterations that are done are written to TensorArrays indexed by the pfor 

695 iteration number. When all iterations are done, the TensorArrays are stacked 

696 to get the final value. 

697 

698 Args: 

699 pfor_input: A PForInput object corresponding to the output of any Exit 

700 node from this while loop. 

701 

702 Returns: 

703 List of converted outputs. 

704 """ 

705 # Create init_values that will be passed to the while_loop. 

706 init_values, inputs_stacked, shape_invariants = self._create_init_values( 

707 pfor_input) 

708 # Note that we use a list as a hack since we need the nested function body 

709 # to set the value of cond_is_stacked. python2.x doesn't support nonlocal 

710 # variables. 

711 cond_is_stacked = [None] 

712 

713 def cond(not_all_done, *_): 

714 return not_all_done 

715 

716 def body(not_all_done, indices, *args): 

717 # See documentation for __call__ for the structure of *args. 

718 num_enters = len(self._enters) 

719 inputs = args[:num_enters] 

720 output_tas = args[num_enters:] 

721 # TODO(agarwal): see which outputs have consumers and only populate the 

722 # TensorArrays corresponding to those. Or do those paths get trimmed out 

723 # from inside the while_loop body? 

724 assert len(inputs) >= len(output_tas) 

725 assert len(inputs) == len(inputs_stacked) 

726 

727 # Convert condition 

728 with ops.name_scope("while_cond"): 

729 # Note that we set cond_stacked to True here. At this point we don't 

730 # know if it could be loop invariant, hence the conservative value is 

731 # to assume stacked. 

732 cond_pfor = self._init_pfor( 

733 pfor_input.pfor, 

734 indices, 

735 cond_stacked=True, 

736 inputs=inputs, 

737 inputs_stacked=inputs_stacked) 

738 conditions, cond_stacked, _ = cond_pfor._convert_helper(self._condition) 

739 cond_is_stacked[0] = cond_stacked 

740 

741 # Recompute the new condition, write outputs of done iterations, and 

742 # partition the inputs if needed. 

743 if not cond_stacked: 

744 (not_all_done, new_indices, new_inputs, 

745 new_output_tas) = self._process_cond_unstacked(conditions, indices, 

746 inputs, output_tas) 

747 else: 

748 (not_all_done, new_indices, new_inputs, 

749 new_output_tas) = self._process_cond_stacked(conditions, indices, 

750 inputs, inputs_stacked, 

751 output_tas) 

752 

753 # Convert body 

754 with ops.name_scope("while_body"): 

755 # Compute the outputs from the body. 

756 new_outputs = self._process_body(pfor_input, inputs_stacked, 

757 new_indices, cond_stacked, new_inputs, 

758 not_all_done) 

759 

760 # Note that the first num_outputs new values of inputs are computed using 

761 # the body. Rest of them were direct Enters into the condition/body and 

762 # the partitioning done earlier is sufficient to give the new value. 

763 num_outputs = len(self._outputs) 

764 new_args = ([not_all_done, new_indices] + new_outputs + 

765 list(new_inputs[num_outputs:]) + new_output_tas) 

766 return tuple(new_args) 

767 

768 while_outputs = while_loop.while_loop( 

769 cond, body, init_values, shape_invariants=shape_invariants) 

770 output_tas = while_outputs[-len(self._outputs):] 

771 outputs = [] 

772 assert cond_is_stacked[0] is not None 

773 for inp_stacked, ta in zip(inputs_stacked, output_tas): 

774 if cond_is_stacked[0]: 

775 outputs.append(wrap(ta.stack(), True)) 

776 else: 

777 # Note that if while_loop condition is unstacked, all iterations exit at 

778 # the same time and we wrote those outputs in index 0 of the tensor 

779 # array. 

780 outputs.append(wrap(ta.read(0), inp_stacked)) 

781 return outputs 

782 

783 

784class ConversionNotImplementedError(Exception): 

785 pass 

786 

787 

788class _PforInput: 

789 """Input object passed to registered pfor converters.""" 

790 

791 __slots__ = ["pfor", "_op", "_inputs"] 

792 

793 def __init__(self, pfor, op, inputs): 

794 """Creates a _PforInput object. 

795 

796 Args: 

797 pfor: PFor converter object. 

798 op: the Operation object that is being converted. 

799 inputs: list of WrappedTensor objects representing converted values of the 

800 inputs of `op`. 

801 """ 

802 self.pfor = pfor 

803 self._op = op 

804 self._inputs = inputs 

805 

806 def stack_inputs(self, stack_indices=None, tile_variants=False): 

807 """Stacks unstacked inputs at `stack_indices`. 

808 

809 Args: 

810 stack_indices: indices of inputs at which stacking is done. If None, 

811 stacking is done at all indices. 

812 tile_variants: If True, affected indices which have a variant dtype will 

813 be tiled after this operation to match the expected shape of a 

814 vectorized tensor. Variants generally need to be un-tiled when they are 

815 inputs to operations and tiled when returned. 

816 """ 

817 if stack_indices is None: 

818 stack_indices = range(len(self._inputs)) 

819 length = self.pfor.loop_len_vector 

820 for i in stack_indices: 

821 inp = self._inputs[i] 

822 is_variant = inp.t.dtype == dtypes.variant 

823 if not inp.is_stacked: 

824 self._inputs[i] = _stack(inp.t, length) 

825 if tile_variants and is_variant: 

826 self._inputs[i] = wrap( 

827 _tile_variant_with_length(self._inputs[i].t, length), True) 

828 elif not tile_variants and is_variant: 

829 self._inputs[i] = wrap(_untile_variant(self._inputs[i].t), True) 

830 

831 def expanddim_inputs_for_broadcast(self): 

832 """Reshapes stacked inputs to prepare them for broadcast. 

833 

834 Since stacked inputs have an extra leading dimension, automatic broadcasting 

835 rules could incorrectly try to expand dimensions before that leading 

836 dimension. To avoid that, we reshape these stacked inputs to the maximum 

837 rank they will need to be broadcasted to. 

838 """ 

839 if not self._inputs: 

840 return 

841 

842 # Find max rank 

843 def _get_rank(x): 

844 rank = array_ops.rank(x.t) 

845 if not x.is_stacked: 

846 rank += 1 

847 return rank 

848 

849 ranks = [_get_rank(x) for x in self._inputs] 

850 max_rank = ranks[0] 

851 for rank in ranks[1:]: 

852 max_rank = math_ops.maximum(rank, max_rank) 

853 

854 for i, inp in enumerate(self._inputs): 

855 if inp.is_stacked: 

856 shape = array_ops.shape(inp.t) 

857 rank_diff = array_ops.reshape(max_rank - ranks[i], [1]) 

858 ones = array_ops.tile([1], rank_diff) 

859 new_shape = array_ops.concat([shape[:1], ones, shape[1:]], axis=0) 

860 self._inputs[i] = wrap(array_ops.reshape(inp.t, new_shape), True) 

861 

862 @property 

863 def inputs(self): 

864 return self._inputs 

865 

866 @property 

867 def num_inputs(self): 

868 return len(self._inputs) 

869 

870 def input(self, index): 

871 assert len(self._inputs) > index, (index, self._inputs) 

872 return self._inputs[index] 

873 

874 def stacked_input(self, index): 

875 t, is_stacked, _ = self.input(index) 

876 if not is_stacked: 

877 op_type = self.op_type 

878 op_def = getattr(self._op, "op_def", None) 

879 if op_def is None: 

880 input_name = "at index %d" % index 

881 else: 

882 input_name = "\"%s\"" % op_def.input_arg[index].name 

883 raise ConversionNotImplementedError( 

884 f"Input {input_name} of op '{op_type}' expected to be not loop " 

885 "invariant.") 

886 return t 

887 

888 def unstacked_input(self, index): 

889 t, is_stacked, _ = self.input(index) 

890 if is_stacked: 

891 op_type = self.op_type 

892 op_def = getattr(self._op, "op_def", None) 

893 if op_def is None: 

894 input_name = "at index %d" % index 

895 else: 

896 input_name = "\"%s\"" % op_def.input_arg[index].name 

897 raise ConversionNotImplementedError( 

898 f"Input {input_name} of op '{op_type}' expected to be loop " 

899 "invariant.") 

900 return t 

901 

902 @property 

903 def op(self): 

904 return self._op 

905 

906 @property 

907 def op_type(self): 

908 return self._op.type 

909 

910 def get_attr(self, attr): 

911 return self._op.get_attr(attr) 

912 

913 @property 

914 def outputs(self): 

915 return self._op.outputs 

916 

917 def output(self, index): 

918 assert index < len(self._op.outputs) 

919 return self._op.outputs[index] 

920 

921 

922_pfor_converter_registry = {} 

923 

924 

925class RegisterPFor: 

926 """Utility to register converters for pfor. 

927 

928 Usage: 

929 @RegisterPFor(foo_op_type) 

930 def _foo_converter(pfor_input): 

931 ... 

932 

933 The above will register conversion function `_foo_converter` for handling 

934 conversion of `foo_op_type`. These converters are called during vectorization 

935 of a `pfor` loop body. For each operation node in this loop body, 

936 the vectorization process will call the converter corresponding to the 

937 operation type of the node. 

938 

939 During conversion, the registered function will be called with a single 

940 argument `pfor_input`, of type `PForInput`, which will contain state needed 

941 for the conversion. When the converter is called for a node, all its inputs 

942 should already have been converted and these converted values are stored in 

943 `pfor_input.inputs`. This registered function should output a list of 

944 WrappedTensor objects with the same length as the number of outputs of the 

945 node being converted. If the node had zero outputs, then it should return an 

946 ops.Operation object. These new sets of nodes should implement the 

947 functionality of running that operation for the number of iterations specified 

948 by `pfor_input.pfor.loop_len_vector[0]` where the inputs of the node for each 

949 iteration are picked from `pfor_inputs.inputs()`. 

950 

951 One tricky aspect of the conversion process is keeping track of, and 

952 leveraging loop invariance of computation. Each converted input is a 

953 WrappedTensor which indicates whether the input was loop invariant or not. If 

954 the converted value is loop invariant, its rank should match the rank of the 

955 corresponding tensor in the loop body, else its rank is larger by 1. The 

956 converter should look at the loop invariance of the inputs and generate new 

957 nodes based on that. Note that the converter will not be called if all inputs 

958 are loop invariant and the operation is not stateful. The converter should 

959 determine if its own output is loop invariant and `wrap` its output 

960 accordingly. 

961 

962 Example: 

963 

964 Here, the converter is trying to convert a Reshape node in the loop body. This 

965 node will have two inputs: the tensor to reshape, and the new shape. The 

966 example here only handles the case where the shape is loop invariant. 

967 

968 @RegisterPFor("Reshape") 

969 def _convert_reshape(pfor_input): 

970 # We assume that input is not loop invariant. Call to `stacked_input` 

971 # asserts that and returns the converted value. This value will have a rank 

972 # larger by 1 compared to the rank of the input in the loop body. 

973 t = pfor_input.stacked_input(0) 

974 

975 # We assume that shape input is loop invariant. Call to `unstacked_input` 

976 # asserts that and returns the converted value. 

977 shape = pfor_input.unstacked_input(1) 

978 

979 # We compute `new_shape` by prepending the number of iterations to the 

980 # original shape. 

981 new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], 

982 axis=0) 

983 

984 # The vectorized output involves reshaping the converted input `t` using 

985 # `new_shape`. 

986 new_output = array_ops.reshape(t, new_shape) 

987 

988 # The converted output is marked as not loop invariant using the call to 

989 # wrap. 

990 return wrap(new_output, True) 

991 """ 

992 

993 def __init__(self, op_type): 

994 """Creates an object to register a converter for op with type `op_type`.""" 

995 self.op_type = op_type 

996 

997 def __call__(self, converter): 

998 name = self.op_type 

999 assert name not in _pfor_converter_registry, "Re-registering %s " % name 

1000 _pfor_converter_registry[name] = converter 

1001 return converter 

1002 

1003 

1004class RegisterPForWithArgs(RegisterPFor): 

1005 """Utility to register converters for pfor. 

1006 

1007 Usage: 

1008 @RegisteRPFor(foo_op_type, foo=value, ....) 

1009 def _foo_converter(pfor_input, foo=None, ....): 

1010 ... 

1011 

1012 See RegisterPFor for details on the conversion function. 

1013 `RegisterPForWithArgs` allows binding extra arguments to the 

1014 conversion function at registration time. 

1015 """ 

1016 

1017 def __init__(self, op_type, *args, **kw_args): 

1018 super(RegisterPForWithArgs, self).__init__(op_type) 

1019 self._args = args 

1020 self._kw_args = kw_args 

1021 

1022 def __call__(self, converter): 

1023 

1024 def _f(pfor_input): 

1025 return converter(pfor_input, self.op_type, *self._args, **self._kw_args) 

1026 

1027 super(RegisterPForWithArgs, self).__call__(_f) 

1028 return converter 

1029 

1030 

1031# TODO(agarwal): call raw_ops instead of calling these low level routines. 

1032def _create_op(op_type, inputs, op_dtypes, attrs=None): 

1033 """Utility to create an op.""" 

1034 op = ops.get_default_graph().create_op( 

1035 op_type, inputs, op_dtypes, attrs=attrs, compute_device=True) 

1036 flat_attrs = [] 

1037 # The tape expects an alternating flat list of names and attribute values. 

1038 for a in attrs: 

1039 flat_attrs.append(str(a)) 

1040 flat_attrs.append(op.get_attr(str(a))) 

1041 execute.record_gradient(op_type, op.inputs, tuple(flat_attrs), op.outputs[:]) 

1042 return op 

1043 

1044 

1045WrappedTensor = collections.namedtuple("WrappedTensor", 

1046 ["t", "is_stacked", "is_sparse_stacked"]) 

1047"""Wrapper around the result of a Tensor conversion. 

1048 

1049The additional fields are useful for keeping track of the conversion state as 

1050data flows through the ops in the loop body. For every op whose output is a 

1051Tensor, its converter should return either a WrappedTensor or a list of 

1052WrappedTensors. 

1053 

1054Args: 

1055 t: The converted tensor 

1056 is_stacked: True if the tensor is stacked, i.e. represents the results of all 

1057 the iterations of the loop, where each row i of the tensor corresponds to 

1058 that op's output on iteration i of the loop. False if the tensor is not 

1059 stacked, i.e. represents the result of the op on of a single iteration of 

1060 the loop, where the result does not vary between iterations. 

1061 is_sparse_stacked: True if the tensor corresponds to a component tensor 

1062 (indices, values, or dense_shape) of a sparse tensor, and has been logically 

1063 stacked via a sparse conversion. 

1064""" 

1065 

1066 

1067def wrap(tensor, is_stacked=True, is_sparse_stacked=False): 

1068 """Helper to create a WrappedTensor object.""" 

1069 assert isinstance(is_stacked, bool) 

1070 assert isinstance(is_sparse_stacked, bool) 

1071 assert isinstance(tensor, ops.Tensor) 

1072 assert not is_sparse_stacked or is_stacked, ("If the wrapped tensor is " 

1073 "stacked via a sparse " 

1074 "conversion, it must also be " 

1075 "stacked.") 

1076 return WrappedTensor(tensor, is_stacked, is_sparse_stacked) 

1077 

1078 

1079def _wrap_and_tile_variants(tensor, length): 

1080 if tensor.dtype == dtypes.variant: 

1081 tensor = _tile_variant_with_length(tensor, length) 

1082 return wrap(tensor) 

1083 

1084 

1085def _fallback_converter(pfor_input, root_cause="", warn=False): 

1086 msg = ("Using a while_loop for converting " 

1087 f"{pfor_input.op_type} cause {root_cause}") 

1088 if warn: 

1089 logging.warning(msg) 

1090 else: 

1091 logging.debug(msg) 

1092 output_dtypes = [x.dtype for x in pfor_input.outputs] 

1093 iter_vec = pfor_input.pfor.loop_len_vector 

1094 # Use constant value if available, so that output shapes are static. 

1095 iter_vec_value = tensor_util.constant_value(iter_vec) 

1096 if iter_vec_value is not None: 

1097 iters = iter_vec_value[0].item() 

1098 else: 

1099 iters = iter_vec[0] 

1100 

1101 def while_body(i, *ta_list): 

1102 """Body of while loop.""" 

1103 inputs = [ 

1104 x[i, ...] if stacked else x for x, stacked, _ in pfor_input.inputs 

1105 ] 

1106 op_outputs = _create_op( 

1107 pfor_input.op_type, 

1108 inputs, 

1109 output_dtypes, 

1110 attrs=pfor_input.op.node_def.attr).outputs 

1111 

1112 outputs = [] 

1113 # TODO(agarwal): Add tf.debugging asserts to check that the shapes across 

1114 # the different iterations are the same. 

1115 for out, ta in zip(op_outputs, ta_list): 

1116 assert isinstance(out, ops.Tensor) 

1117 outputs.append(ta.write(i, out)) 

1118 return tuple([i + 1] + outputs) 

1119 

1120 ta_list = while_loop.while_loop( 

1121 lambda i, *ta: i < iters, while_body, [0] + 

1122 [tensor_array_ops.TensorArray(dtype, iters) for dtype in output_dtypes 

1123 ])[1:] 

1124 return tuple([wrap(ta.stack(), True) for ta in ta_list]) 

1125 

1126 

1127class PForConfig: 

1128 """A configuration object used to communicate with loop body function.""" 

1129 

1130 def __init__(self): 

1131 # This may be set to the number of iterations. 

1132 self._maybe_iters = None 

1133 # Map from reduction node, created by `reduce`, to the bundle of reduction 

1134 # function and arguments. 

1135 self._reduce_map = {} 

1136 

1137 def _has_reductions(self): 

1138 """True if some reductions where performed by loop body.""" 

1139 return len(self._reduce_map) 

1140 

1141 def _set_iters(self, iters): 

1142 """Set number of pfor iterations.""" 

1143 if isinstance(iters, ops.Tensor): 

1144 iters = tensor_util.constant_value(iters) 

1145 self._maybe_iters = iters 

1146 

1147 def reduce(self, fn, *args): 

1148 """Performs reduction `fn` on `args` vectorized across pfor iterations. 

1149 

1150 Note that `fn` is traced once inside the loop function context. Hence any 

1151 captures or side-effects will happen in that context. Call to the traced 

1152 version of `fn` happens during the construction of the vectorized code. 

1153 

1154 Note that this currently may not work inside a control flow construct. 

1155 Args: 

1156 fn: a reduction function. It will be called with arguments that have the 

1157 same structure as *args but with individual values whose rank may be 

1158 higher by 1 since they represent loop invariant vectorized versions of 

1159 the corresponding Tensors in *args. 

1160 *args: unvectorized Tensors. 

1161 

1162 Returns: 

1163 The result of running `fn` on the vectorized versions of `*args`. These 

1164 outputs will be available as loop invariant values to all the iterations. 

1165 """ 

1166 assert not context.executing_eagerly() 

1167 # Creates a concrete function that will be used for reduction. 

1168 tensor_specs = [] 

1169 for arg in args: 

1170 if not isinstance(arg, ops.Tensor): 

1171 raise ValueError(f"Got a non-Tensor argument {arg} in reduce.") 

1172 batched_shape = tensor_shape.TensorShape([self._maybe_iters 

1173 ]).concatenate(arg.shape) 

1174 tensor_specs.append( 

1175 tensor_spec.TensorSpec(shape=batched_shape, dtype=arg.dtype)) 

1176 concrete_function = def_function.function(fn).get_concrete_function( 

1177 *tensor_specs) 

1178 

1179 # Creates PlaceholderWithDefault and IdentityN nodes corresponding the 

1180 # reduction. 

1181 pl_outputs = [] 

1182 with ops.control_dependencies(args): 

1183 for output in concrete_function.outputs: 

1184 if not isinstance(output, ops.Tensor): 

1185 raise ValueError(f"Got a non-Tensor output {output} while running " 

1186 "reduce.") 

1187 # Note that we use placeholder_with_default just to make XLA happy since 

1188 # it does not like placeholder ops. 

1189 if output.shape.is_fully_defined(): 

1190 dummy = array_ops.zeros(output.shape.as_list(), dtype=output.dtype) 

1191 pl_outputs.append( 

1192 array_ops.placeholder_with_default(dummy, shape=output.shape)) 

1193 else: 

1194 # TODO(agarwal): support case when under XLA and output.shape is not 

1195 # fully defined. 

1196 pl_outputs.append( 

1197 array_ops.placeholder(output.dtype, shape=output.shape)) 

1198 

1199 reduction_op = array_ops.identity_n(pl_outputs)[0].op 

1200 self._reduce_map[reduction_op] = (concrete_function, args) 

1201 if len(reduction_op.outputs) == 1: 

1202 return reduction_op.outputs[0] 

1203 else: 

1204 return tuple(reduction_op.outputs) 

1205 

1206 # TODO(agarwal): handle reductions inside control flow constructs. 

1207 def reduce_concat(self, x): 

1208 """Performs a concat reduction on `x` across pfor iterations. 

1209 

1210 Note that this currently may not work inside a control flow construct. 

1211 Args: 

1212 x: an unvectorized Tensor. 

1213 

1214 Returns: 

1215 A Tensor that has rank one higher than `x`. The value is the vectorized 

1216 version of `x`, i.e. stacking the value of `x` across different pfor 

1217 iterations. 

1218 """ 

1219 return self.reduce(lambda y: y, x) 

1220 

1221 def reduce_mean(self, x): 

1222 """Performs a mean reduction on `x` across pfor iterations. 

1223 

1224 Note that this currently may not work inside a control flow construct. 

1225 Args: 

1226 x: an unvectorized Tensor. 

1227 

1228 Returns: 

1229 A Tensor that has same rank as `x`. The value is the mean of the values 

1230 of `x` across the pfor iterations. 

1231 """ 

1232 return self.reduce(lambda y: math_ops.reduce_mean(y, axis=0), x) 

1233 

1234 def reduce_sum(self, x): 

1235 """Performs a sum reduction on `x` across pfor iterations. 

1236 

1237 Note that this currently may not work inside a control flow construct. 

1238 Args: 

1239 x: an unvectorized Tensor. 

1240 

1241 Returns: 

1242 A Tensor that has same rank as `x`. The value is the sum of the values 

1243 of `x` across the pfor iterations. 

1244 """ 

1245 return self.reduce(lambda y: math_ops.reduce_sum(y, axis=0), x) 

1246 

1247 def _lookup_reduction(self, t): 

1248 """Lookups Tensor `t` in the reduction maps.""" 

1249 assert isinstance(t, ops.Tensor), t 

1250 return self._reduce_map.get(t.op) 

1251 

1252 

1253class PFor: 

1254 """Implementation of rewrite of parallel-for loops. 

1255 

1256 This class takes a DAG or a set of DAGs representing the body of a 

1257 parallel-for loop, and adds new operations to the graph that implements 

1258 functionality equivalent to running that loop body for a specified number of 

1259 iterations. This new set of nodes may or may not use a tensorflow loop 

1260 construct. 

1261 

1262 The process of conversion does not delete or change any existing operations. 

1263 It only adds operations that efficiently implement the equivalent 

1264 functionality. We refer to the added ops as "converted ops". 

1265 

1266 The conversion process uses a simple greedy heuristic. It walks the loop body 

1267 and tries to express the functionality of running each node in a loop with a 

1268 new set of nodes. When converting an op several cases are possible: 

1269 - The op is not inside the loop body. Hence it can be used as is. 

1270 - The op does not depend on the iteration number and is stateless. In this 

1271 case, it can be used as is. 

1272 - The op is not stateful, and depends on iteration number only through control 

1273 dependencies. In this case, we can create a single op with same inputs and 

1274 attributes, but with "converted" control dependencies. 

1275 - The op is not stateful, and all its inputs are loop invariant. In this 

1276 case, similar to above, we can create a single op with same inputs and 

1277 attributes, but with "converted" control dependencies. 

1278 - The op is stateful or at least one of the inputs is not loop invariant. In 

1279 this case, we run the registered converter for that op to create a set of 

1280 converted ops. All nodes in the set will have converted control dependencies 

1281 corresponding to control dependencies of the original op. If the op returned 

1282 multiple outputs, "converted outputs" could be produced by different ops in 

1283 this set. 

1284 """ 

1285 

1286 def __init__(self, 

1287 loop_var, 

1288 loop_len, 

1289 pfor_ops, 

1290 fallback_to_while_loop, 

1291 all_indices=None, 

1292 all_indices_partitioned=False, 

1293 pfor_config=None, 

1294 warn=False): 

1295 """Creates an object to rewrite a parallel-for loop. 

1296 

1297 Args: 

1298 loop_var: ops.Tensor output of a Placeholder operation. The value should 

1299 be an int32 scalar representing the loop iteration number. 

1300 loop_len: A scalar or scalar Tensor representing the number of iterations 

1301 the loop is run for. 

1302 pfor_ops: List of all ops inside the loop body. 

1303 fallback_to_while_loop: If True, on failure to vectorize an op, a while 

1304 loop is used to sequentially execute that op. 

1305 all_indices: If not None, an int32 vector with size `loop_len` 

1306 representing the iteration ids that are still active. These values 

1307 should be unique and sorted. However they may not be contiguous. This is 

1308 typically the case when inside a control flow construct which has 

1309 partitioned the indices of the iterations that are being converted. 

1310 all_indices_partitioned: If True, this object is being constructed from a 

1311 control flow construct where not all the pfor iterations are guaranteed 

1312 to be active. 

1313 pfor_config: PForConfig object used while constructing the loop body. 

1314 warn: Whether or not to warn on while loop conversions. 

1315 """ 

1316 assert isinstance(loop_var, ops.Tensor) 

1317 assert loop_var.op.type == "PlaceholderWithDefault" 

1318 self._loop_var = loop_var 

1319 loop_len_value = tensor_util.constant_value(loop_len) 

1320 if loop_len_value is not None: 

1321 loop_len = loop_len_value 

1322 self._loop_len_vector = ops.convert_to_tensor([loop_len]) 

1323 else: 

1324 self._loop_len_vector = array_ops.reshape(loop_len, [1]) 

1325 self._all_indices_partitioned = all_indices_partitioned 

1326 if all_indices_partitioned: 

1327 assert all_indices is not None 

1328 self.all_indices = ( 

1329 math_ops.range(loop_len) if all_indices is None else all_indices) 

1330 

1331 self._conversion_map = object_identity.ObjectIdentityDictionary() 

1332 self._conversion_map[loop_var] = wrap(self.all_indices, True) 

1333 self._pfor_ops = set(pfor_ops) 

1334 self._pfor_op_ids = set(x._id for x in pfor_ops) 

1335 self._fallback_to_while_loop = fallback_to_while_loop 

1336 self._warn = warn 

1337 self._pfor_config = pfor_config 

1338 

1339 def op_is_inside_loop(self, op): 

1340 """True if op was created inside the pfor loop body.""" 

1341 assert isinstance(op, ops.Operation) 

1342 # Note that we use self._pfor_op_ids for the check and not self._pfor_ops 

1343 # since it appears there tensorflow API could return different python 

1344 # objects representing the same Operation node. 

1345 return op._id in self._pfor_op_ids 

1346 

1347 def _convert_sparse(self, y): 

1348 """Returns the converted value corresponding to SparseTensor y. 

1349 

1350 For SparseTensors, instead of stacking the component tensors separately, 

1351 resulting in component tensors with shapes (N, m, rank), (N, m), and (N, 

1352 rank) respectively for indices, values, and dense_shape (where N is the loop 

1353 length and m is the number of sparse tensor values per loop iter), we want 

1354 to logically stack the SparseTensors, to create a SparseTensor whose 

1355 components are size (N * m, rank + 1), (N * m, ), and (rank + 1,) 

1356 respectively. 

1357 

1358 Here, we try to get the conversion of each component tensor. 

1359 If the tensors are stacked via a sparse conversion, return the resulting 

1360 SparseTensor composed of the converted components. Otherwise, the component 

1361 tensors are either unstacked or stacked naively. In the latter case, we 

1362 unstack the component tensors to reform loop_len SparseTensor elements, 

1363 then correctly batch them. 

1364 

1365 The unstacked tensors must have the same rank. Each dimension of each 

1366 SparseTensor will expand to be the largest among all SparseTensor elements 

1367 for that dimension. For example, if there are N SparseTensors of rank 3 

1368 being stacked, with N dense shapes, where the i_th shape is (x_i, y_i, z_i), 

1369 the new dense shape will be (N, max_i(x_i), max_i(y_i), max_i(z_i)). 

1370 

1371 Args: 

1372 y: A tf.sparse.SparseTensor. 

1373 

1374 Returns: 

1375 A tf.sparse.SparseTensor that is the converted value corresponding to y. 

1376 """ 

1377 outputs = [ 

1378 self._convert_helper(t) for t in (y.indices, y.values, y.dense_shape) 

1379 ] 

1380 assert all(isinstance(o, WrappedTensor) for o in outputs) 

1381 

1382 if all(w.is_sparse_stacked for w in outputs): 

1383 return sparse_tensor.SparseTensor(*[w.t for w in outputs]) 

1384 

1385 assert not any(w.is_sparse_stacked for w in outputs), ( 

1386 "Error converting SparseTensor. All components should be logically " 

1387 "stacked, or none.") 

1388 

1389 # If component tensors were not sparsely stacked, they are either unstacked 

1390 # or stacked without knowledge that they are components of sparse tensors. 

1391 # In this case, we have to restack them. 

1392 return self._restack_sparse_tensor_logically( 

1393 *[self._unwrap_or_tile(w) for w in outputs]) 

1394 

1395 def _restack_sparse_tensor_logically(self, indices, values, shape): 

1396 sparse_tensor_rank = indices.get_shape().dims[-1].value 

1397 if sparse_tensor_rank is not None: 

1398 sparse_tensor_rank += 1 

1399 

1400 def fn(args): 

1401 res = gen_sparse_ops.serialize_sparse( 

1402 args[0], args[1], args[2], out_type=dtypes.variant) 

1403 return res 

1404 

1405 # Applies a map function to the component tensors to serialize each 

1406 # sparse tensor element and batch them all, then deserializes the batch. 

1407 # TODO(rachelim): Try to do this without map_fn -- add the right offsets 

1408 # to shape and indices tensors instead. 

1409 result = map_fn.map_fn(fn, [indices, values, shape], dtype=dtypes.variant) 

1410 return sparse_ops.deserialize_sparse( 

1411 result, dtype=values.dtype, rank=sparse_tensor_rank) 

1412 

1413 def _unwrap_or_tile(self, wrapped_tensor): 

1414 """Given a wrapped tensor, unwrap if stacked. Otherwise, tiles it.""" 

1415 output, is_stacked = wrapped_tensor.t, wrapped_tensor.is_stacked 

1416 if is_stacked: 

1417 return output 

1418 else: 

1419 return _stack(output, self._loop_len_vector).t 

1420 

1421 def convert(self, y): 

1422 """Returns the converted value corresponding to y. 

1423 

1424 Args: 

1425 y: A ops.Tensor or a ops.Operation object. If latter, y should not have 

1426 any outputs. 

1427 

1428 Returns: 

1429 If y does not need to be converted, it returns y as is. Else it returns 

1430 the "converted value" corresponding to y. 

1431 """ 

1432 if y is None: 

1433 return None 

1434 if isinstance(y, sparse_tensor.SparseTensor): 

1435 return self._convert_sparse(y) 

1436 assert isinstance(y, (ops.Tensor, ops.Operation)), y 

1437 output = self._convert_helper(y) 

1438 if isinstance(output, WrappedTensor): 

1439 assert isinstance(y, ops.Tensor) 

1440 return self._unwrap_or_tile(output) 

1441 else: 

1442 assert isinstance(y, ops.Operation) 

1443 assert not y.outputs 

1444 assert isinstance(output, ops.Operation) 

1445 return output 

1446 

1447 def _was_converted(self, t): 

1448 """True if t is not a conversion of itself.""" 

1449 converted_t = self._conversion_map[t] 

1450 return converted_t.t is not t 

1451 

1452 def _add_conversion(self, old_output, new_output): 

1453 assert isinstance(old_output, (ops.Tensor, ops.Operation)), old_output 

1454 assert isinstance(new_output, (WrappedTensor, ops.Operation)), new_output 

1455 self._conversion_map[old_output] = new_output 

1456 

1457 def _convert_reduction(self, y): 

1458 # Handle reductions. 

1459 if self._pfor_config is None or isinstance(y, ops.Operation): 

1460 return None 

1461 reduction = self._pfor_config._lookup_reduction(y) 

1462 if reduction is None: 

1463 return None 

1464 (reduction_fn, reduction_args) = reduction 

1465 batched_args = [] 

1466 for reduction_arg in reduction_args: 

1467 assert isinstance(reduction_arg, ops.Tensor), reduction_arg 

1468 # Tensor being reduced should already be converted due to a control 

1469 # dependency on the created placeholder. 

1470 # Note that in cases where reduction_arg is in an outer context, one 

1471 # needs to locate the corresponding Enter node and use that to lookup 

1472 # the conversion. 

1473 # TODO(agarwal): handle reductions inside control flow constructs. 

1474 assert reduction_arg in self._conversion_map, ( 

1475 "Unable to handle reduction of %s, possibly as it was used " 

1476 "inside a control flow construct. Note that reductions across " 

1477 "pfor iterations are currently not supported inside control flow " 

1478 "constructs." % reduction_arg) 

1479 batched_arg = self._conversion_map[reduction_arg] 

1480 batched_args.append(self._unwrap_or_tile(batched_arg)) 

1481 outputs = reduction_fn(*batched_args) 

1482 return [wrap(output, False) for output in nest.flatten(outputs)] 

1483 

1484 def _convert_helper(self, op_or_tensor): 

1485 stack = collections.deque([op_or_tensor]) 

1486 while stack: 

1487 y = stack[0] 

1488 if y in self._conversion_map: 

1489 assert isinstance(self._conversion_map[y], 

1490 (WrappedTensor, ops.Operation)) 

1491 stack.popleft() 

1492 continue 

1493 if isinstance(y, ops.Operation): 

1494 assert not y.outputs, ( 

1495 "We only support converting Operation objects with no outputs. " 

1496 "Got %s", y) 

1497 y_op = y 

1498 else: 

1499 assert isinstance(y, ops.Tensor), y 

1500 y_op = y.op 

1501 

1502 is_while_loop = y_op.type == "Exit" 

1503 if is_while_loop: 

1504 while_op = WhileOp( 

1505 y, pfor_ops=self._pfor_ops, 

1506 fallback_to_while_loop=self.fallback_to_while_loop, 

1507 pfor_config=self._pfor_config) 

1508 is_inside_loop = while_op.is_inside_loop 

1509 # If all nodes in the while_loop graph were created inside the pfor, we 

1510 # treat the whole loop subgraph as a single op (y_op) and try to convert 

1511 # it. For while_loops that are created completely or partially outside, 

1512 # we treat them as external and should be able to simply return the Exit 

1513 # node output as is without needing any conversion. Note that for 

1514 # while_loops that are partially constructed inside, we assume they will 

1515 # be loop invariant. If that is not the case, it will create runtime 

1516 # errors since the converted graph would depend on the self._loop_var 

1517 # placeholder. 

1518 if is_inside_loop: 

1519 y_op = while_op 

1520 else: 

1521 is_inside_loop = self.op_is_inside_loop(y_op) 

1522 

1523 # If this op was not created inside the loop body, we will return as is. 

1524 # 1. Convert inputs and control inputs. 

1525 

1526 def _add_to_stack(x): 

1527 if x not in self._conversion_map: 

1528 stack.appendleft(x) 

1529 return True 

1530 else: 

1531 return False 

1532 

1533 if is_inside_loop: 

1534 added_to_stack = False 

1535 for inp in y_op.inputs: 

1536 added_to_stack |= _add_to_stack(inp) 

1537 for cinp in y_op.control_inputs: 

1538 if cinp.outputs: 

1539 for t in cinp.outputs: 

1540 added_to_stack |= _add_to_stack(t) 

1541 else: 

1542 added_to_stack |= _add_to_stack(cinp) 

1543 if added_to_stack: 

1544 continue 

1545 

1546 converted_inputs = [self._conversion_map[inp] for inp in y_op.inputs] 

1547 some_input_converted = any(self._was_converted(x) for x in y_op.inputs) 

1548 some_input_stacked = any(x.is_stacked for x in converted_inputs) 

1549 

1550 converted_control_ops = set() 

1551 some_control_input_converted = False 

1552 for cinp in y_op.control_inputs: 

1553 if cinp.outputs: 

1554 for t in cinp.outputs: 

1555 converted_t = self._conversion_map[t] 

1556 if self._was_converted(t): 

1557 some_control_input_converted = True 

1558 converted_control_ops.add(converted_t.t.op) 

1559 else: 

1560 converted_cinp = self._conversion_map[cinp] 

1561 assert isinstance(converted_cinp, ops.Operation) 

1562 if converted_cinp != cinp: 

1563 some_control_input_converted = True 

1564 converted_control_ops.add(converted_cinp) 

1565 converted_control_ops = list(converted_control_ops) 

1566 is_stateful = _is_stateful_pfor_op(y_op) 

1567 else: 

1568 converted_inputs = [] 

1569 converted_control_ops = [] 

1570 logging.vlog(3, "converting op:%s\ninputs:%s\ncontrol_inputs:%s", y_op, 

1571 converted_inputs, converted_control_ops) 

1572 

1573 # 2. Convert y_op 

1574 # If converting a while_loop, we let the while_loop convertor deal with 

1575 # putting the control dependencies appropriately. 

1576 control_dependencies = [] if is_while_loop else converted_control_ops 

1577 with ops.control_dependencies(control_dependencies), ops.name_scope( 

1578 y_op.name + "/pfor/"), ops.get_default_graph()._original_op(y_op): 

1579 # Op is a placeholder for a reduction. 

1580 reduce_output = self._convert_reduction(y) 

1581 if reduce_output is not None: 

1582 new_outputs = reduce_output 

1583 # None of the inputs and control inputs were converted. 

1584 elif ((not is_inside_loop or 

1585 (not is_stateful and not some_input_converted and 

1586 not some_control_input_converted)) and 

1587 y.graph == ops.get_default_graph()): 

1588 if y is y_op: 

1589 assert not isinstance(y_op, WhileOp) 

1590 new_outputs = y_op 

1591 else: 

1592 new_outputs = [wrap(x, False) for x in y_op.outputs] 

1593 elif not (is_stateful or is_while_loop or some_input_stacked): 

1594 # All inputs are unstacked or unconverted but some control inputs are 

1595 # converted. 

1596 # TODO(rachelim): Handle the case where some inputs are sparsely 

1597 # stacked (i.e. any(x.is_sparse_stacked for x in converted_inputs)) 

1598 new_op = _create_op(y_op.type, [x.t for x in converted_inputs], 

1599 [x.dtype for x in y_op.outputs], 

1600 y_op.node_def.attr) 

1601 if y is y_op: 

1602 new_outputs = new_op 

1603 else: 

1604 new_outputs = [] 

1605 for old_output, new_output in zip(y_op.outputs, new_op.outputs): 

1606 handle_data_util.copy_handle_data(old_output, new_output) 

1607 new_outputs.append(wrap(new_output, False)) 

1608 else: 

1609 # Either some inputs are not loop invariant or op is stateful. 

1610 if hasattr(y_op, "pfor_converter"): 

1611 converter = y_op.pfor_converter 

1612 else: 

1613 converter = _pfor_converter_registry.get(y_op.type, None) 

1614 if converter is None: 

1615 root_cause = (f"there is no registered converter for this op.") 

1616 has_variant_outputs = any(x.dtype == dtypes.variant for x in 

1617 y_op.outputs) 

1618 has_vectorized_variant_inputs = any( 

1619 _is_variant_with_internal_stacking(x) for x in 

1620 y_op.inputs) 

1621 if (self._fallback_to_while_loop and not has_variant_outputs 

1622 and not has_vectorized_variant_inputs): 

1623 converter = partial( 

1624 _fallback_converter, root_cause=root_cause, warn=self._warn) 

1625 else: 

1626 message = (f"No pfor vectorization defined for {y_op.type}\n" 

1627 f"{y_op}\n inputs: {converted_inputs}.") 

1628 if not self._fallback_to_while_loop: 

1629 message += ("Consider enabling the fallback_to_while_loop " 

1630 "option to pfor, which may run slower.") 

1631 raise ValueError(message) 

1632 # TODO(rachelim): Handle the case where some inputs are sparsely 

1633 # stacked. We should only call the converter if it supports handling 

1634 # those inputs. 

1635 pfor_inputs = _PforInput(self, y_op, converted_inputs) 

1636 try: 

1637 try: 

1638 new_outputs = converter(pfor_inputs) 

1639 except ConversionNotImplementedError as e: 

1640 has_vectorized_variant_inputs = any( 

1641 _is_variant_with_internal_stacking(x) for x in 

1642 y_op.inputs) 

1643 if (self._fallback_to_while_loop 

1644 and not has_vectorized_variant_inputs): 

1645 new_outputs = _fallback_converter( 

1646 pfor_inputs, root_cause=str(e)) 

1647 else: 

1648 raise ValueError(str(e)).with_traceback(sys.exc_info()[2]) 

1649 except Exception as e: # pylint: disable=broad-except 

1650 logging.error( 

1651 f"Got error while pfor was converting op {y_op} with inputs " 

1652 f"{y_op.inputs[:]}\n, converted inputs {pfor_inputs.inputs}\n" 

1653 f"Here are the pfor conversion stack traces: {e}") 

1654 original_op = y_op 

1655 while isinstance(original_op, ops.Operation): 

1656 logging.error( 

1657 "%s\ncreated at:\n %s", original_op, 

1658 " ".join(traceback.format_list(original_op.traceback))) 

1659 original_op = original_op._original_op 

1660 raise 

1661 

1662 if isinstance(new_outputs, WrappedTensor): 

1663 new_outputs = [new_outputs] 

1664 assert isinstance(new_outputs, 

1665 (list, tuple, ops.Operation)), new_outputs 

1666 logging.vlog(2, f"converted {y_op} {new_outputs}") 

1667 

1668 # Insert into self._conversion_map 

1669 if y is y_op: 

1670 assert isinstance(new_outputs, ops.Operation) 

1671 self._add_conversion(y_op, new_outputs) 

1672 else: 

1673 assert len(y_op.outputs) == len(new_outputs), (y_op, y_op.outputs, 

1674 new_outputs) 

1675 for old_output, new_output in zip(y_op.outputs, new_outputs): 

1676 assert isinstance(new_output, WrappedTensor), (new_output, y, y_op) 

1677 assert old_output.dtype == new_output.t.dtype, (new_output, y, y_op) 

1678 # Set shape for converted output. 

1679 output_shape = old_output.shape 

1680 if not new_output.is_sparse_stacked: 

1681 if new_output.is_stacked: 

1682 loop_len = tensor_util.constant_value(self.loop_len_vector) 

1683 if loop_len is None: 

1684 batch_dim = tensor_shape.TensorShape([None]) 

1685 else: 

1686 batch_dim = tensor_shape.TensorShape(loop_len) 

1687 output_shape = batch_dim.concatenate(output_shape) 

1688 if _is_variant_with_internal_stacking(new_output.t): 

1689 new_output.t.set_shape([]) 

1690 else: 

1691 new_output.t.set_shape(output_shape) 

1692 self._add_conversion(old_output, new_output) 

1693 stack.popleft() 

1694 

1695 return self._conversion_map[op_or_tensor] 

1696 

1697 @property 

1698 def loop_len_vector(self): 

1699 """Returns a single element vector whose value is number of iterations.""" 

1700 return self._loop_len_vector 

1701 

1702 @property 

1703 def loop_var(self): 

1704 """Returns placeholder loop variable.""" 

1705 return self._loop_var 

1706 

1707 @property 

1708 def pfor_ops(self): 

1709 return self._pfor_ops 

1710 

1711 @property 

1712 def pfor_config(self): 

1713 return self._pfor_config 

1714 

1715 @property 

1716 def all_indices_partitioned(self): 

1717 """all_indices_partitioned property. 

1718 

1719 Returns: 

1720 True if we are inside a control flow construct and not all pfor iterations 

1721 may be active. 

1722 """ 

1723 return self._all_indices_partitioned 

1724 

1725 @property 

1726 def fallback_to_while_loop(self): 

1727 return self._fallback_to_while_loop 

1728 

1729 

1730# The code below defines converters for different operations. Please see comment 

1731# for RegisterPFor to see how converters should be defined. 

1732 

1733 

1734# image_ops 

1735 

1736 

1737@RegisterPFor("AdjustContrastv2") 

1738def _convert_adjust_contrastv2(pfor_input): 

1739 images = pfor_input.stacked_input(0) 

1740 contrast_factor = pfor_input.unstacked_input(1) 

1741 return wrap(gen_image_ops.adjust_contrastv2(images, contrast_factor), True) 

1742 

1743 

1744@RegisterPFor("AdjustHue") 

1745def _convert_adjust_hue(pfor_input): 

1746 images = pfor_input.stacked_input(0) 

1747 delta = pfor_input.unstacked_input(1) 

1748 return wrap(gen_image_ops.adjust_hue(images, delta), True) 

1749 

1750 

1751@RegisterPFor("AdjustSaturation") 

1752def _convert_adjust_saturation(pfor_input): 

1753 images = pfor_input.stacked_input(0) 

1754 scale = pfor_input.unstacked_input(1) 

1755 return wrap(gen_image_ops.adjust_saturation(images, scale), True) 

1756 

1757 

1758# nn_ops 

1759 

1760 

1761def _flatten_first_two_dims(x): 

1762 """Merges first two dimensions.""" 

1763 old_shape = array_ops.shape(x) 

1764 new_shape = array_ops.concat([[-1], old_shape[2:]], axis=0) 

1765 return array_ops.reshape(x, new_shape) 

1766 

1767 

1768def _unflatten_first_dim(x, first_dim): 

1769 """Splits first dimension into [first_dim, -1].""" 

1770 old_shape = array_ops.shape(x) 

1771 new_shape = array_ops.concat([first_dim, [-1], old_shape[1:]], axis=0) 

1772 return array_ops.reshape(x, new_shape) 

1773 

1774 

1775def _inputs_with_flattening(pfor_input, input_indices): 

1776 """Stacks and flattens first dim of inputs at indices `input_indices`.""" 

1777 if input_indices is None: 

1778 input_indices = [] 

1779 pfor_input.stack_inputs(stack_indices=input_indices) 

1780 inputs = [] 

1781 for i in range(pfor_input.num_inputs): 

1782 if i in input_indices: 

1783 inp = pfor_input.stacked_input(i) 

1784 inp = _flatten_first_two_dims(inp) 

1785 else: 

1786 inp = pfor_input.unstacked_input(i) 

1787 inputs.append(inp) 

1788 return inputs 

1789 

1790 

1791@RegisterPForWithArgs("Conv2D", dims=[0]) 

1792@RegisterPForWithArgs("DepthToSpace", dims=[0]) 

1793@RegisterPForWithArgs("AvgPool", dims=[0]) 

1794@RegisterPForWithArgs("AvgPool3D", dims=[0]) 

1795@RegisterPForWithArgs("MaxPool", dims=[0]) 

1796@RegisterPForWithArgs("MaxPoolV2", dims=[0]) 

1797@RegisterPForWithArgs("MaxPool3D", dims=[0]) 

1798@RegisterPForWithArgs("MaxPool3DGrad", dims=[0, 1, 2]) 

1799@RegisterPForWithArgs("MaxPoolGrad", dims=[0, 1, 2]) 

1800@RegisterPForWithArgs("MaxPoolGradV2", dims=[0, 1, 2]) 

1801@RegisterPForWithArgs("MaxPool3DGradGrad", dims=[0, 1, 2]) 

1802@RegisterPForWithArgs("MaxPoolGradGrad", dims=[0, 1, 2]) 

1803@RegisterPForWithArgs("MaxPoolGradGradV2", dims=[0, 1, 2]) 

1804@RegisterPForWithArgs("SoftmaxCrossEntropyWithLogits", dims=[0, 1]) 

1805@RegisterPForWithArgs("SparseSoftmaxCrossEntropyWithLogits", dims=[0, 1]) 

1806@RegisterPForWithArgs("SpaceToDepth", dims=[0]) 

1807def _convert_flatten_batch(pfor_input, op_type, dims): 

1808 del op_type 

1809 inputs = _inputs_with_flattening(pfor_input, dims) 

1810 outputs = _create_op( 

1811 pfor_input.op_type, 

1812 inputs, [x.dtype for x in pfor_input.outputs], 

1813 attrs=pfor_input.op.node_def.attr).outputs 

1814 n = pfor_input.pfor.loop_len_vector 

1815 outputs = [_unflatten_first_dim(x, n) for x in outputs] 

1816 return [wrap(x, True) for x in outputs] 

1817 

1818 

1819_channel_flatten_input_cache = {} 

1820 

1821 

1822@RegisterPFor("BatchToSpaceND") 

1823def _convert_batch_to_space_nd(pfor_input): 

1824 inp = pfor_input.stacked_input(0) 

1825 block_shape = pfor_input.unstacked_input(1) 

1826 crops = pfor_input.unstacked_input(2) 

1827 

1828 inp_shape = array_ops.shape(inp) 

1829 n = pfor_input.pfor.loop_len_vector 

1830 

1831 # Reshape and transpose to move the vectorization axis inside the axes that 

1832 # will move to space. 

1833 # Reshape to 4D and transpose 

1834 block_size = math_ops.reduce_prod(block_shape) 

1835 new_shape = [n[0], block_size, inp_shape[1] // block_size, -1] 

1836 inp = array_ops.reshape(inp, new_shape) 

1837 inp = array_ops.transpose(inp, [1, 0, 2, 3]) 

1838 # Reshape back to merge the block, vectorization and batch dimension, and 

1839 # restore the other dimensions. 

1840 new_shape = array_ops.concat([n * inp_shape[1], inp_shape[2:]], axis=0) 

1841 inp = array_ops.reshape(inp, new_shape) 

1842 # Call batch_to_space and then split the new batch axis. 

1843 output = gen_array_ops.batch_to_space_nd(inp, block_shape, crops) 

1844 output = _unflatten_first_dim(output, n) 

1845 return wrap(output, True) 

1846 

1847 

1848@RegisterPFor("SpaceToBatchND") 

1849def _convert_space_to_batch_nd(pfor_input): 

1850 inp = pfor_input.stacked_input(0) 

1851 block_shape = pfor_input.unstacked_input(1) 

1852 paddings = pfor_input.unstacked_input(2) 

1853 

1854 n = pfor_input.pfor.loop_len_vector 

1855 inp_shape = array_ops.shape(inp) 

1856 inp = _flatten_first_two_dims(inp) 

1857 output = gen_array_ops.space_to_batch_nd(inp, block_shape, paddings) 

1858 output_shape = array_ops.shape(output) 

1859 block_size = math_ops.reduce_prod(block_shape) 

1860 new_shape = [block_size, n[0], -1] 

1861 output = array_ops.reshape(output, new_shape) 

1862 output = array_ops.transpose(output, [1, 0, 2]) 

1863 new_shape = array_ops.concat( 

1864 [n, block_size * inp_shape[1:2], output_shape[1:]], axis=0) 

1865 output = array_ops.reshape(output, new_shape) 

1866 return wrap(output, True) 

1867 

1868 

1869def _channel_flatten_input(x, data_format): 

1870 """Merge the stack dimension with the channel dimension. 

1871 

1872 If S is pfor's stacking dimension, then, 

1873 - for SNCHW, we transpose to NSCHW. If N dimension has size 1, the transpose 

1874 should be cheap. 

1875 - for SNHWC, we transpose to NHWSC. 

1876 We then merge the S and C dimension. 

1877 

1878 Args: 

1879 x: ops.Tensor to transform. 

1880 data_format: "NCHW" or "NHWC". 

1881 

1882 Returns: 

1883 A 3-element tuple with the transformed value, along with the shape for 

1884 reshape and order for transpose required to transform back. 

1885 """ 

1886 

1887 graph = ops.get_default_graph() 

1888 cache_key = (graph, x.ref(), data_format) 

1889 if cache_key not in _channel_flatten_input_cache: 

1890 x_shape = array_ops.shape(x) 

1891 if data_format == b"NCHW": 

1892 order = [1, 0, 2, 3, 4] 

1893 shape = array_ops.concat([x_shape[1:2], [-1], x_shape[3:]], axis=0) 

1894 reverse_order = order 

1895 else: 

1896 order = [1, 2, 3, 0, 4] 

1897 shape = array_ops.concat([x_shape[1:4], [-1]], axis=0) 

1898 reverse_order = [3, 0, 1, 2, 4] 

1899 # Move S dimension next to C dimension. 

1900 x = array_ops.transpose(x, order) 

1901 reverse_shape = array_ops.shape(x) 

1902 # Reshape to merge the S and C dimension. 

1903 x = array_ops.reshape(x, shape) 

1904 outputs = x, reverse_order, reverse_shape 

1905 _channel_flatten_input_cache[cache_key] = outputs 

1906 else: 

1907 outputs = _channel_flatten_input_cache[cache_key] 

1908 return outputs 

1909 

1910 

1911# Note that with training=True, running FusedBatchNormV3 on individual examples 

1912# is very different from running FusedBatchNormV3 on a batch of those examples. 

1913# This is because, for the latter case, the operation can be considered as first 

1914# computing the mean and variance over all the examples and then using these 

1915# to scale all those examples. This creates a data dependency between these 

1916# different "iterations" since the inputs to the scaling step depends on the 

1917# statistics coming from all these inputs. 

1918# As with other kernels, the conversion here effectively runs the kernel 

1919# independently for each iteration, and returns outputs by stacking outputs from 

1920# each of those iterations. 

1921@RegisterPFor("FusedBatchNormV3") 

1922def _convert_fused_batch_norm(pfor_input): 

1923 is_training = pfor_input.get_attr("is_training") 

1924 # When BatchNorm is used with training=False, mean and variance are provided 

1925 # externally and used as is by the op. Thus, we can merge the S and N 

1926 # dimensions as we do for regular operations. 

1927 # When BatchNorm is used with training=True, mean and variance are computed 

1928 # for each channel across the batch dimension (first one). If we merge S and N 

1929 # dimensions, mean and variances will be computed over a larger set. So, we 

1930 # merge the S and C dimensions instead. 

1931 if not is_training: 

1932 # We return zeros for batch_mean and batch_variance output. Note that CPU 

1933 # and GPU seem to have different behavior for those two outputs. CPU outputs 

1934 # zero because these values are not used during inference. GPU outputs 

1935 # something, probably real means and variances. 

1936 inputs = _inputs_with_flattening(pfor_input, [0]) 

1937 outputs = _create_op( 

1938 pfor_input.op_type, 

1939 inputs, [x.dtype for x in pfor_input.outputs], 

1940 attrs=pfor_input.op.node_def.attr).outputs 

1941 y = outputs[0] 

1942 n = pfor_input.pfor.loop_len_vector 

1943 y = _unflatten_first_dim(y, n) 

1944 mean = pfor_input.unstacked_input(3) 

1945 zeros = array_ops.zeros_like(mean) 

1946 return [wrap(y, True)] + [wrap(zeros, False)] * 5 

1947 

1948 pfor_input.stack_inputs() 

1949 data_format = pfor_input.get_attr("data_format") 

1950 # We merge the first dimension with the "C" dimension, run FusedBatchNormV3, 

1951 # and then transpose back. 

1952 x = pfor_input.stacked_input(0) 

1953 x, reverse_order, reverse_shape = _channel_flatten_input(x, data_format) 

1954 # Note that we stack all the other inputs as well so that they are the same 

1955 # size as the new size of the channel dimension. 

1956 inputs = [x] + [ 

1957 array_ops.reshape(pfor_input.stacked_input(i), [-1]) 

1958 for i in range(1, pfor_input.num_inputs) 

1959 ] 

1960 outputs = _create_op( 

1961 pfor_input.op_type, 

1962 inputs, [x.dtype for x in pfor_input.outputs], 

1963 attrs=pfor_input.op.node_def.attr).outputs 

1964 y = outputs[0] 

1965 y = array_ops.reshape(y, reverse_shape) 

1966 y = array_ops.transpose(y, reverse_order) 

1967 n = pfor_input.pfor.loop_len_vector 

1968 outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]] 

1969 outputs = [y] + outputs 

1970 return [wrap(x, True) for x in outputs] 

1971 

1972 

1973@RegisterPFor("FusedBatchNormGradV3") 

1974def _convert_fused_batch_norm_grad(pfor_input): 

1975 pfor_input.stack_inputs() 

1976 data_format = pfor_input.get_attr("data_format") 

1977 y_backprop = pfor_input.stacked_input(0) 

1978 y_backprop, _, _ = _channel_flatten_input(y_backprop, data_format) 

1979 x = pfor_input.stacked_input(1) 

1980 x, x_reverse_order, x_reverse_shape = _channel_flatten_input(x, data_format) 

1981 inputs = [y_backprop, x] + [ 

1982 array_ops.reshape(pfor_input.stacked_input(i), [-1]) 

1983 for i in range(2, pfor_input.num_inputs) 

1984 ] 

1985 outputs = _create_op( 

1986 pfor_input.op_type, 

1987 inputs, [x.dtype for x in pfor_input.outputs], 

1988 attrs=pfor_input.op.node_def.attr).outputs 

1989 x_backprop = outputs[0] 

1990 x_backprop = array_ops.reshape(x_backprop, x_reverse_shape) 

1991 x_backprop = array_ops.transpose(x_backprop, x_reverse_order) 

1992 n = pfor_input.pfor.loop_len_vector 

1993 outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]] 

1994 outputs = [x_backprop] + outputs 

1995 return [wrap(output, True) for output in outputs] 

1996 

1997 

1998@RegisterPForWithArgs("Conv2DBackpropInput", flatten_dims=[2], shape_dim=0) 

1999@RegisterPForWithArgs("AvgPoolGrad", flatten_dims=[1], shape_dim=0) 

2000@RegisterPForWithArgs("AvgPool3DGrad", flatten_dims=[1], shape_dim=0) 

2001def _convert_flatten_batch_shape_input(pfor_input, op_type, flatten_dims, 

2002 shape_dim): 

2003 del op_type 

2004 inputs = _inputs_with_flattening(pfor_input, flatten_dims) 

2005 n = pfor_input.pfor.loop_len_vector 

2006 # Adjust the `input_sizes` input. 

2007 ones = array_ops.ones([array_ops.shape(inputs[shape_dim])[0] - 1], 

2008 dtype=n.dtype) 

2009 inputs[shape_dim] *= array_ops.concat([n, ones], axis=0) 

2010 outputs = _create_op( 

2011 pfor_input.op_type, 

2012 inputs, [x.dtype for x in pfor_input.outputs], 

2013 attrs=pfor_input.op.node_def.attr).outputs 

2014 outputs = [_unflatten_first_dim(x, n) for x in outputs] 

2015 return [wrap(x, True) for x in outputs] 

2016 

2017 

2018@RegisterPFor("Conv2DBackpropFilter") 

2019def _convert_conv2d_backprop_filter(pfor_input): 

2020 pfor_input.stack_inputs(stack_indices=[2]) 

2021 inputs, inputs_stacked, _ = pfor_input.input(0) 

2022 filter_sizes = pfor_input.unstacked_input(1) 

2023 grads = pfor_input.stacked_input(2) 

2024 strides = pfor_input.get_attr("strides") 

2025 padding = pfor_input.get_attr("padding") 

2026 use_cudnn_on_gpu = pfor_input.get_attr("use_cudnn_on_gpu") 

2027 data_format = pfor_input.get_attr("data_format") 

2028 dilations = pfor_input.get_attr("dilations") 

2029 if inputs_stacked: 

2030 # TODO(agarwal): Implement this efficiently. 

2031 logging.warning("Conv2DBackpropFilter uses a while_loop. Fix that!") 

2032 

2033 def while_body(i, ta): 

2034 inp_i = inputs[i, ...] 

2035 grad_i = grads[i, ...] 

2036 output = nn_ops.conv2d_backprop_filter( 

2037 inp_i, 

2038 filter_sizes, 

2039 grad_i, 

2040 strides=strides, 

2041 padding=padding, 

2042 use_cudnn_on_gpu=use_cudnn_on_gpu, 

2043 data_format=data_format, 

2044 dilations=dilations) 

2045 return i + 1, ta.write(i, output) 

2046 

2047 n = array_ops.reshape(pfor_input.pfor.loop_len_vector, []) 

2048 _, ta = while_loop.while_loop( 

2049 lambda i, ta: i < n, while_body, 

2050 (0, tensor_array_ops.TensorArray(inputs.dtype, n))) 

2051 output = ta.stack() 

2052 return wrap(output, True) 

2053 else: 

2054 # We merge the stack dimension with the channel dimension of the gradients 

2055 # and pretend we had a larger filter (see change to filter_sizes below). 

2056 # Once the filter backprop is computed, we reshape and transpose back 

2057 # appropriately. 

2058 grads, _, _ = _channel_flatten_input(grads, data_format) 

2059 n = pfor_input.pfor.loop_len_vector 

2060 old_filter_sizes = filter_sizes 

2061 filter_sizes *= array_ops.concat([[1, 1, 1], n], axis=0) 

2062 output = nn_ops.conv2d_backprop_filter( 

2063 inputs, 

2064 filter_sizes, 

2065 grads, 

2066 strides=strides, 

2067 padding=padding, 

2068 use_cudnn_on_gpu=use_cudnn_on_gpu, 

2069 data_format=data_format, 

2070 dilations=dilations) 

2071 new_filter_shape = array_ops.concat([old_filter_sizes[:3], n, [-1]], axis=0) 

2072 output = array_ops.reshape(output, new_filter_shape) 

2073 output = array_ops.transpose(output, [3, 0, 1, 2, 4]) 

2074 return wrap(output, True) 

2075 

2076 

2077def _flatten_with_inner_dim(x, dim, x_rank): 

2078 """Merges the first dim with the specified dim.""" 

2079 shape = array_ops.shape(x) 

2080 x = array_ops.transpose(x, 

2081 list(range(1, dim)) + [0] + list(range(dim, x_rank))) 

2082 

2083 if dim < x_rank - 1: 

2084 new_shape_pieces = [shape[1:dim], [-1], shape[dim + 1:]] 

2085 else: 

2086 new_shape_pieces = [shape[1:dim], [-1]] 

2087 new_shape = array_ops.concat(new_shape_pieces, axis=0) 

2088 return array_ops.reshape(x, new_shape) 

2089 

2090 

2091def _unflatten_with_inner_dim(x, dim, x_rank, stack_size): 

2092 """Undoes _flatten_with_inner_dim.""" 

2093 shape = array_ops.shape(x) 

2094 if dim < x_rank - 1: 

2095 new_shape_pieces = [shape[:dim], [stack_size], [-1], shape[dim + 1:]] 

2096 else: 

2097 new_shape_pieces = [shape[:dim], [stack_size], [-1]] 

2098 new_shape = array_ops.concat(new_shape_pieces, axis=0) 

2099 x = array_ops.reshape(x, new_shape) 

2100 dims_permutation = [dim] + list(range(dim)) + list(range(dim + 1, x_rank + 1)) 

2101 return array_ops.transpose(x, dims_permutation) 

2102 

2103 

2104@RegisterPFor("DepthwiseConv2dNative") 

2105def _convert_depthwise_conv2d_native(pfor_input): 

2106 # Kernel can be vectorized, so folding to batch dimension does not work. We 

2107 # instead fold into the channel dimension because it is parallel. 

2108 stack_size = pfor_input.pfor.loop_len_vector[0] 

2109 data_format = pfor_input.get_attr("data_format") 

2110 c_dim = 1 if data_format == b"NCHW" else 3 

2111 t = _flatten_with_inner_dim(pfor_input.stacked_input(0), c_dim + 1, 5) 

2112 kernel = _flatten_with_inner_dim(pfor_input.stacked_input(1), 3, 5) 

2113 conv = _create_op( 

2114 "DepthwiseConv2dNative", [t, kernel], 

2115 [x.dtype for x in pfor_input.outputs], 

2116 attrs=pfor_input.op.node_def.attr).outputs[0] 

2117 return wrap(_unflatten_with_inner_dim(conv, c_dim, 4, stack_size), True) 

2118 

2119 

2120@RegisterPFor("DepthwiseConv2dNativeBackpropInput") 

2121def _convert_depthwise_conv2d_native_backprop_input(pfor_input): 

2122 stack_size = pfor_input.pfor.loop_len_vector[0] 

2123 input_sizes = pfor_input.unstacked_input(0) 

2124 data_format = pfor_input.get_attr("data_format") 

2125 c_dim = 1 if data_format == b"NCHW" else 3 

2126 input_sizes_mutipliers = [ 

2127 constant_op.constant([1] * c_dim, dtype=dtypes.int32), [stack_size] 

2128 ] 

2129 if c_dim < 3: 

2130 input_sizes_mutipliers += [ 

2131 constant_op.constant([1] * (3 - c_dim), dtype=dtypes.int32) 

2132 ] 

2133 input_sizes *= array_ops.concat(input_sizes_mutipliers, axis=0) 

2134 kernel = _flatten_with_inner_dim(pfor_input.stacked_input(1), 3, 5) 

2135 out_backprop = _flatten_with_inner_dim( 

2136 pfor_input.stacked_input(2), c_dim + 1, 5) 

2137 result = _create_op( 

2138 "DepthwiseConv2dNativeBackpropInput", [input_sizes, kernel, out_backprop], 

2139 [x.dtype for x in pfor_input.outputs], 

2140 attrs=pfor_input.op.node_def.attr).outputs[0] 

2141 return wrap(_unflatten_with_inner_dim(result, c_dim, 4, stack_size), True) 

2142 

2143 

2144@RegisterPFor("DepthwiseConv2dNativeBackpropFilter") 

2145def _convert_depthwise_conv2d_native_backprop_filter(pfor_input): 

2146 stack_size = pfor_input.pfor.loop_len_vector[0] 

2147 data_format = pfor_input.get_attr("data_format") 

2148 c_dim = 1 if data_format == b"NCHW" else 3 

2149 inputs = _flatten_with_inner_dim(pfor_input.stacked_input(0), c_dim + 1, 5) 

2150 filter_sizes = pfor_input.unstacked_input(1) 

2151 filter_sizes_multipliers = [ 

2152 constant_op.constant([1, 1], dtype=dtypes.int32), [stack_size], 

2153 constant_op.constant([1], dtype=dtypes.int32) 

2154 ] 

2155 filter_sizes *= array_ops.concat(filter_sizes_multipliers, axis=0) 

2156 out_backprop = _flatten_with_inner_dim( 

2157 pfor_input.stacked_input(2), c_dim + 1, 5) 

2158 result = _create_op( 

2159 "DepthwiseConv2dNativeBackpropFilter", 

2160 [inputs, filter_sizes, out_backprop], 

2161 [x.dtype for x in pfor_input.outputs], 

2162 attrs=pfor_input.op.node_def.attr).outputs[0] 

2163 return wrap(_unflatten_with_inner_dim(result, 2, 4, stack_size), True) 

2164 

2165 

2166@RegisterPForWithArgs("LogSoftmax", gen_nn_ops.log_softmax) 

2167@RegisterPForWithArgs("Softmax", gen_nn_ops.softmax) 

2168def _convert_softmax(pfor_input, op_type, op_func): 

2169 del op_type 

2170 return wrap(op_func(pfor_input.stacked_input(0)), True) 

2171 

2172 

2173# array_ops 

2174 

2175 

2176@RegisterPForWithArgs("Identity", array_ops.identity) 

2177@RegisterPForWithArgs("StopGradient", array_ops.stop_gradient) 

2178@RegisterPForWithArgs("MatrixDiag", array_ops.matrix_diag) 

2179@RegisterPForWithArgs("MatrixDiagPart", array_ops.matrix_diag_part) 

2180@RegisterPForWithArgs("_EagerConst", array_ops.identity) 

2181def _convert_identity(pfor_input, op_type, op_func): 

2182 del op_type 

2183 return wrap(op_func(*[x.t for x in pfor_input.inputs]), True) 

2184 

2185 

2186@RegisterPFor("IdentityN") 

2187def _convert_identity_n(pfor_input): 

2188 outputs = array_ops.identity_n([x.t for x in pfor_input.inputs]) 

2189 return [ 

2190 wrap(out, inp.is_stacked) for out, inp in zip(outputs, pfor_input.inputs) 

2191 ] 

2192 

2193 

2194@RegisterPFor("Reshape") 

2195def _convert_reshape(pfor_input): 

2196 t = pfor_input.stacked_input(0) 

2197 shape = pfor_input.unstacked_input(1) 

2198 new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0) 

2199 return wrap(array_ops.reshape(t, new_shape), True) 

2200 

2201 

2202@RegisterPFor("Fill") 

2203def _convert_fill(pfor_input): 

2204 dims = pfor_input.unstacked_input(0) 

2205 value = pfor_input.stacked_input(1) 

2206 # Expand the rank of `value` 

2207 new_shape = array_ops.concat( 

2208 [[-1], array_ops.ones([array_ops.size(dims)], dtype=dtypes.int32)], 

2209 axis=0) 

2210 value = array_ops.reshape(value, new_shape) 

2211 # Compute the new output shape 

2212 new_dims = array_ops.concat([pfor_input.pfor.loop_len_vector, dims], axis=0) 

2213 # Broadcast 

2214 return wrap(array_ops.broadcast_to(value, new_dims), True) 

2215 

2216 

2217@RegisterPFor("BroadcastTo") 

2218def _convert_broadcast_to(pfor_input): 

2219 t = pfor_input.stacked_input(0) 

2220 shape = pfor_input.unstacked_input(1) 

2221 new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0) 

2222 

2223 # Expand dims of stacked t to broadcast against the new shape. 

2224 # TODO(davmre): consider factoring out common code with 

2225 # `expanddim_inputs_for_broadcast`, which has similar logic but with 

2226 # implicit shapes (of input Tensors) rather than explicit shapes. 

2227 rank_diff = array_ops.shape(new_shape)[0] - array_ops.rank(t) 

2228 ones = array_ops.tile([1], array_ops.reshape(rank_diff, [1])) 

2229 t_shape = array_ops.shape(t) 

2230 t_expanded_shape = array_ops.concat([t_shape[:1], ones, t_shape[1:]], axis=0) 

2231 

2232 return wrap( 

2233 array_ops.broadcast_to(array_ops.reshape(t, t_expanded_shape), new_shape), 

2234 True) 

2235 

2236 

2237@RegisterPFor("ExpandDims") 

2238def _convert_expanddims(pfor_input): 

2239 t = pfor_input.stacked_input(0) 

2240 dim = pfor_input.unstacked_input(1) 

2241 dim += math_ops.cast(dim >= 0, dim.dtype) 

2242 return wrap(array_ops.expand_dims(t, axis=dim), True) 

2243 

2244 

2245@RegisterPForWithArgs("LowerBound", gen_array_ops.lower_bound) 

2246@RegisterPForWithArgs("UpperBound", gen_array_ops.upper_bound) 

2247def _convert_searchsorted(pfor_input, _, op_func): 

2248 pfor_input.stack_inputs() 

2249 sorted_inputs = _flatten_first_two_dims(pfor_input.stacked_input(0)) 

2250 values = _flatten_first_two_dims(pfor_input.stacked_input(1)) 

2251 out_type = pfor_input.get_attr("out_type") 

2252 output = op_func(sorted_inputs, values, out_type) 

2253 return wrap( 

2254 _unflatten_first_dim(output, pfor_input.pfor.loop_len_vector), True) 

2255 

2256 

2257@RegisterPFor("MatrixBandPart") 

2258def _convert_matrix_band_part(pfor_input): 

2259 t = pfor_input.stacked_input(0) 

2260 num_lower = pfor_input.unstacked_input(1) 

2261 num_upper = pfor_input.unstacked_input(2) 

2262 return wrap( 

2263 array_ops.matrix_band_part(t, num_lower=num_lower, num_upper=num_upper), 

2264 True) 

2265 

2266 

2267@RegisterPFor("MatrixSetDiag") 

2268def _convert_matrix_set_diag(pfor_input): 

2269 pfor_input.stack_inputs() 

2270 t = pfor_input.stacked_input(0) 

2271 diag = pfor_input.stacked_input(1) 

2272 return wrap(array_ops.matrix_set_diag(t, diag), True) 

2273 

2274 

2275# Registrations for Matrix{Diag,DiagPart,SetDiag}V2-3. 

2276# The input orders defined in the OpKernel and the actual python API are 

2277# different (for compatibility with V1), so we cannot use _convert_identity. 

2278# v2 is not compatible with v3 and is never exposed on the public API. 

2279@RegisterPFor("MatrixDiagV2") 

2280@RegisterPFor("MatrixDiagV3") 

2281def _convert_matrix_diag_v2(pfor_input): 

2282 params = { 

2283 "diagonal": pfor_input.stacked_input(0), 

2284 "k": pfor_input.unstacked_input(1), 

2285 "num_rows": pfor_input.unstacked_input(2), 

2286 "num_cols": pfor_input.unstacked_input(3), 

2287 "padding_value": pfor_input.unstacked_input(4) 

2288 } 

2289 if pfor_input.op_type == "MatrixDiagV2": 

2290 return wrap(array_ops.matrix_diag_v2(**params), True) 

2291 params["align"] = pfor_input.get_attr("align") 

2292 return wrap(array_ops.matrix_diag(**params), True) 

2293 

2294 

2295@RegisterPFor("Diag") 

2296def _convert_diag(pfor_input): 

2297 diag = pfor_input.stacked_input(0) 

2298 if diag.shape.ndims == 2: 

2299 # We can use matrix_diag. 

2300 return wrap(array_ops.matrix_diag(diag), True) 

2301 else: 

2302 # It is not clear if we can do better than a while loop here with existing 

2303 # kernels. 

2304 return _fallback_converter(pfor_input, warn=False) 

2305 

2306 

2307# See notes for MatrixDiagV2 

2308@RegisterPFor("MatrixDiagPartV2") 

2309@RegisterPFor("MatrixDiagPartV3") 

2310def _convert_matrix_diag_part_v2(pfor_input): 

2311 params = { 

2312 "input": pfor_input.stacked_input(0), 

2313 "k": pfor_input.unstacked_input(1), 

2314 "padding_value": pfor_input.unstacked_input(2) 

2315 } 

2316 if pfor_input.op_type == "MatrixDiagPartV2": 

2317 return wrap(array_ops.matrix_diag_part_v2(**params), True) 

2318 params["align"] = pfor_input.get_attr("align") 

2319 return wrap(array_ops.matrix_diag_part(**params), True) 

2320 

2321 

2322# See notes for MatrixDiagV2 

2323@RegisterPFor("MatrixSetDiagV2") 

2324@RegisterPFor("MatrixSetDiagV3") 

2325def _convert_matrix_set_diag_v2(pfor_input): 

2326 pfor_input.stack_inputs([0, 1]) 

2327 params = { 

2328 "input": pfor_input.stacked_input(0), 

2329 "diagonal": pfor_input.stacked_input(1), 

2330 "k": pfor_input.unstacked_input(2) 

2331 } 

2332 if pfor_input.op_type == "MatrixSetDiagV2": 

2333 return wrap(array_ops.matrix_set_diag_v2(**params), True) 

2334 params["align"] = pfor_input.get_attr("align") 

2335 return wrap(array_ops.matrix_set_diag(**params), True) 

2336 

2337 

2338@RegisterPFor("DiagPart") 

2339def _convert_diag_part(pfor_input): 

2340 inp = pfor_input.stacked_input(0) 

2341 if inp.shape.ndims == 3: 

2342 # We can use matrix_diag_part. 

2343 return wrap(array_ops.matrix_diag_part(inp), True) 

2344 else: 

2345 # It is not clear if we can do better than a while loop here with existing 

2346 # kernels. 

2347 return _fallback_converter(pfor_input, warn=False) 

2348 

2349 

2350@RegisterPFor("OneHot") 

2351def _convert_one_hot(pfor_input): 

2352 indices = pfor_input.stacked_input(0) 

2353 depth = pfor_input.unstacked_input(1) 

2354 on_value = pfor_input.unstacked_input(2) 

2355 off_value = pfor_input.unstacked_input(3) 

2356 axis = pfor_input.get_attr("axis") 

2357 if axis >= 0: 

2358 axis += 1 

2359 return wrap( 

2360 array_ops.one_hot(indices, depth, on_value, off_value, axis), True) 

2361 

2362 

2363@RegisterPFor("Slice") 

2364def _convert_slice(pfor_input): 

2365 t = pfor_input.stacked_input(0) 

2366 begin, begin_stacked, _ = pfor_input.input(1) 

2367 size = pfor_input.unstacked_input(2) 

2368 if not begin_stacked: 

2369 begin = array_ops.concat([[0], begin], axis=0) 

2370 size = array_ops.concat([[-1], size], axis=0) 

2371 return wrap(array_ops.slice(t, begin, size), True) 

2372 else: 

2373 # Handle negative sizes. 

2374 # 

2375 # If the `begin` entry corresponding to a negative `size` is loop-variant, 

2376 # the output would be ragged. This case is not supported. But `size` having 

2377 # some negative values and some loop-variant `begin`s is OK (and it's hard 

2378 # to tell the difference statically). 

2379 original_unstacked_shape = _stack( 

2380 array_ops.shape(t)[1:], pfor_input.pfor.loop_len_vector).t 

2381 broadcast_size = _stack(size, pfor_input.pfor.loop_len_vector).t 

2382 result_shape = array_ops.where( 

2383 math_ops.less(broadcast_size, 0), 

2384 original_unstacked_shape - begin + broadcast_size + 1, broadcast_size) 

2385 result_shape = math_ops.cast(math_ops.reduce_max(result_shape, axis=0), 

2386 dtypes.int64) 

2387 

2388 # Now we enumerate points in the sliced region for each pfor iteration and 

2389 # gather them. 

2390 cumsize = math_ops.cumprod(result_shape, exclusive=True, reverse=True) 

2391 result_num_elements = math_ops.reduce_prod(result_shape) 

2392 # Offsets are loop-variant. We first compute loop-invariant gather 

2393 # coordinates, then broadcast-add the loop-variant `begin` offsets. 

2394 result_base_coordinates = ( 

2395 math_ops.range(result_num_elements, dtype=dtypes.int64)[:, None] 

2396 // cumsize[None, :]) % result_shape[None, :] 

2397 result_coordinates = ( 

2398 begin[:, None, :] 

2399 + math_ops.cast(result_base_coordinates, begin.dtype)[None, :, :]) 

2400 result_flat = array_ops.gather_nd(params=t, indices=result_coordinates, 

2401 batch_dims=1) 

2402 result_stacked_shape = array_ops.concat( 

2403 [math_ops.cast(pfor_input.pfor.loop_len_vector, result_shape.dtype), 

2404 result_shape], 

2405 axis=0) 

2406 return wrap(array_ops.reshape(result_flat, result_stacked_shape), True) 

2407 

2408 

2409@RegisterPFor("Tile") 

2410def _convert_tile(pfor_input): 

2411 t = pfor_input.stacked_input(0) 

2412 multiples = pfor_input.unstacked_input(1) 

2413 multiples = array_ops.concat([[1], multiples], 0) 

2414 return wrap(array_ops.tile(t, multiples), True) 

2415 

2416 

2417@RegisterPFor("Pack") 

2418def _convert_pack(pfor_input): 

2419 pfor_input.stack_inputs() 

2420 axis = pfor_input.get_attr("axis") 

2421 if axis >= 0: 

2422 axis += 1 

2423 return wrap( 

2424 array_ops_stack.stack([x.t for x in pfor_input.inputs], axis=axis), True) 

2425 

2426 

2427@RegisterPFor("Unpack") 

2428def _convert_unpack(pfor_input): 

2429 value = pfor_input.stacked_input(0) 

2430 axis = pfor_input.get_attr("axis") 

2431 if axis >= 0: 

2432 axis += 1 

2433 num = pfor_input.get_attr("num") 

2434 return [wrap(x, True) for x 

2435 in array_ops_stack.unstack(value, axis=axis, num=num)] 

2436 

2437 

2438@RegisterPFor("Pad") 

2439def _convert_pad(pfor_input): 

2440 t = pfor_input.stacked_input(0) 

2441 paddings = pfor_input.unstacked_input(1) 

2442 paddings = array_ops.concat([[[0, 0]], paddings], 0) 

2443 return wrap(array_ops.pad(t, paddings, mode="CONSTANT"), True) 

2444 

2445 

2446@RegisterPFor("PadV2") 

2447def _convert_pad_v2(pfor_input): 

2448 t = pfor_input.stacked_input(0) 

2449 paddings = pfor_input.unstacked_input(1) 

2450 paddings = array_ops.concat([[[0, 0]], paddings], 0) 

2451 return wrap(array_ops.pad_v2(t, paddings, mode="CONSTANT"), True) 

2452 

2453 

2454@RegisterPFor("Split") 

2455def _convert_split(pfor_input): 

2456 split_dim = pfor_input.unstacked_input(0) 

2457 t = pfor_input.stacked_input(1) 

2458 num_split = pfor_input.get_attr("num_split") 

2459 split_dim += math_ops.cast(split_dim >= 0, dtypes.int32) 

2460 return [wrap(x, True) for x in array_ops.split(t, num_split, axis=split_dim)] 

2461 

2462 

2463@RegisterPFor("SplitV") 

2464def _convert_split_v(pfor_input): 

2465 t = pfor_input.stacked_input(0) 

2466 splits = pfor_input.unstacked_input(1) 

2467 split_dim = pfor_input.unstacked_input(2) 

2468 split_dim += math_ops.cast(split_dim >= 0, dtypes.int32) 

2469 return [wrap(x, True) for x in array_ops.split(t, splits, axis=split_dim)] 

2470 

2471 

2472@RegisterPFor("Squeeze") 

2473def _convert_squeeze(pfor_input): 

2474 t = pfor_input.stacked_input(0) 

2475 squeeze_dims = pfor_input.get_attr("squeeze_dims") 

2476 squeeze_dims = [i + 1 if i >= 0 else i for i in squeeze_dims] 

2477 return wrap(array_ops.squeeze(t, axis=squeeze_dims), True) 

2478 

2479 

2480@RegisterPFor("ReverseV2") 

2481def _convert_reverse(pfor_input): 

2482 value = pfor_input.stacked_input(0) 

2483 axis = pfor_input.unstacked_input(1) 

2484 new_axis = array_ops.where_v2(axis >= 0, axis + 1, axis) 

2485 return wrap(gen_array_ops.reverse_v2(value, axis=new_axis), True) 

2486 

2487 

2488@RegisterPForWithArgs("Transpose", gen_array_ops.transpose) 

2489@RegisterPForWithArgs("ConjugateTranspose", gen_array_ops.conjugate_transpose) 

2490def _convert_transpose(pfor_input, _, op_func): 

2491 t = pfor_input.stacked_input(0) 

2492 perm = pfor_input.unstacked_input(1) 

2493 new_perm = array_ops.concat([[0], perm + 1], axis=0) 

2494 return wrap(op_func(t, new_perm), True) 

2495 

2496 

2497@RegisterPFor("ZerosLike") 

2498def _convert_zeroslike(pfor_input): 

2499 t = pfor_input.stacked_input(0) 

2500 shape = array_ops.shape(t)[1:] 

2501 return wrap(array_ops.zeros(shape, dtype=t.dtype), False) 

2502 

2503 

2504@RegisterPFor("Gather") 

2505@RegisterPFor("GatherV2") 

2506def _convert_gather(pfor_input): 

2507 param, param_stacked, _ = pfor_input.input(0) 

2508 indices, indices_stacked, _ = pfor_input.input(1) 

2509 batch_dims = pfor_input.get_attr("batch_dims") 

2510 

2511 op_type = pfor_input.op_type 

2512 if op_type == "Gather": 

2513 validate_indices = pfor_input.get_attr("validate_indices") 

2514 axis = 0 

2515 else: 

2516 validate_indices = None 

2517 # Assume we will never have a Tensor with rank > 2**32. 

2518 axis = math_ops.cast(pfor_input.unstacked_input(2), dtypes.int32) 

2519 axis_value = tensor_util.constant_value(axis) 

2520 if axis_value is not None: 

2521 axis = axis_value 

2522 if indices_stacked and not param_stacked: 

2523 if indices is pfor_input.pfor.all_indices and axis == 0: 

2524 param_shape0 = tensor_shape.dimension_value(param.shape[0]) 

2525 indices_shape0 = tensor_shape.dimension_value(indices.shape[0]) 

2526 if param_shape0 is not None and indices_shape0 == param_shape0: 

2527 # Note that with loops and conditionals, indices may not be contiguous. 

2528 # However they will be sorted and unique. So if the shape matches, then 

2529 # it must be picking up all the rows of param. 

2530 return wrap(param, True) 

2531 

2532 if batch_dims != 0: 

2533 # Convert `batch_dims` to its positive equivalent if necessary. 

2534 batch_dims_pos = batch_dims 

2535 if batch_dims < 0: 

2536 batch_dims_pos += array_ops.rank(indices) 

2537 # In order to maintain 

2538 # indices.shape[:batch_dims] == params.shape[:batch_dims] 

2539 # with stacked indices, we move the first dimension of `indices` to the 

2540 # `batch_dims + 1`th position. The (non-batch) index dimensions will be 

2541 # inserted into the shape of `output` at the `axis` dimension, which is 

2542 # then transposed to the front (below). 

2543 order = array_ops.concat([ 

2544 math_ops.range(1, batch_dims_pos + 1), 

2545 [0], 

2546 math_ops.range(batch_dims_pos + 1, array_ops.rank(indices))], axis=0) 

2547 indices = array_ops.transpose(indices, order) 

2548 

2549 output = array_ops.gather( 

2550 param, indices, validate_indices=validate_indices, axis=axis, 

2551 batch_dims=batch_dims) 

2552 if axis != 0: 

2553 axis = smart_cond.smart_cond(axis < 0, 

2554 lambda: axis + array_ops.rank(param), 

2555 lambda: ops.convert_to_tensor(axis)) 

2556 order = array_ops.concat( 

2557 [[axis], 

2558 math_ops.range(axis), 

2559 math_ops.range(axis + 1, array_ops.rank(output))], 

2560 axis=0) 

2561 output = smart_cond.smart_cond( 

2562 math_ops.equal(axis, 0), lambda: output, 

2563 lambda: array_ops.transpose(output, order)) 

2564 return wrap(output, True) 

2565 if param_stacked: 

2566 pfor_input.stack_inputs(stack_indices=[1]) 

2567 indices = pfor_input.stacked_input(1) 

2568 if isinstance(axis, ops.Tensor): 

2569 axis = array_ops.where(axis >= 0, axis + 1, axis) 

2570 else: 

2571 axis = axis + 1 if axis >= 0 else axis 

2572 batch_dims = batch_dims + 1 if batch_dims >= 0 else batch_dims 

2573 output = array_ops.gather(param, indices, axis=axis, batch_dims=batch_dims) 

2574 return wrap(output, True) 

2575 

2576 

2577@RegisterPFor("GatherNd") 

2578def _convert_gather_nd(pfor_input): 

2579 # TODO(jmenick): Add support for unstacked params. 

2580 pfor_input.stack_inputs(stack_indices=[1]) 

2581 params = pfor_input.stacked_input(0) 

2582 indices = pfor_input.stacked_input(1) 

2583 stacked_result = array_ops.gather_nd(params, indices, batch_dims=1) 

2584 return wrap(stacked_result, True) 

2585 

2586 

2587@RegisterPFor("ConcatV2") 

2588def _convert_concatv2(pfor_input): 

2589 n = pfor_input.num_inputs 

2590 pfor_input.stack_inputs(stack_indices=range(n - 1)) 

2591 axis = pfor_input.unstacked_input(n - 1) 

2592 axis += math_ops.cast(axis >= 0, axis.dtype) 

2593 return wrap( 

2594 array_ops.concat([x.t for x in pfor_input.inputs[:n - 1]], axis=axis), 

2595 True) 

2596 

2597 

2598@RegisterPFor("StridedSlice") 

2599def _convert_strided_slice(pfor_input): 

2600 inp = pfor_input.stacked_input(0) 

2601 begin = pfor_input.unstacked_input(1) 

2602 end = pfor_input.unstacked_input(2) 

2603 strides = pfor_input.unstacked_input(3) 

2604 begin_mask = pfor_input.get_attr("begin_mask") 

2605 end_mask = pfor_input.get_attr("end_mask") 

2606 ellipsis_mask = pfor_input.get_attr("ellipsis_mask") 

2607 new_axis_mask = pfor_input.get_attr("new_axis_mask") 

2608 shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask") 

2609 

2610 begin = array_ops.concat([[0], begin], axis=0) 

2611 end = array_ops.concat([[0], end], axis=0) 

2612 strides = array_ops.concat([[1], strides], axis=0) 

2613 begin_mask = begin_mask << 1 | 1 

2614 end_mask = end_mask << 1 | 1 

2615 ellipsis_mask <<= 1 

2616 new_axis_mask <<= 1 

2617 shrink_axis_mask <<= 1 

2618 return wrap( 

2619 array_ops.strided_slice( 

2620 inp, 

2621 begin, 

2622 end, 

2623 strides, 

2624 begin_mask=begin_mask, 

2625 end_mask=end_mask, 

2626 ellipsis_mask=ellipsis_mask, 

2627 new_axis_mask=new_axis_mask, 

2628 shrink_axis_mask=shrink_axis_mask), True) 

2629 

2630 

2631@RegisterPFor("StridedSliceGrad") 

2632def _convert_strided_slice_grad(pfor_input): 

2633 shape = pfor_input.unstacked_input(0) 

2634 begin = pfor_input.unstacked_input(1) 

2635 end = pfor_input.unstacked_input(2) 

2636 strides = pfor_input.unstacked_input(3) 

2637 dy = pfor_input.stacked_input(4) 

2638 begin_mask = pfor_input.get_attr("begin_mask") 

2639 end_mask = pfor_input.get_attr("end_mask") 

2640 ellipsis_mask = pfor_input.get_attr("ellipsis_mask") 

2641 new_axis_mask = pfor_input.get_attr("new_axis_mask") 

2642 shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask") 

2643 

2644 shape = array_ops.concat( 

2645 [math_ops.cast(pfor_input.pfor.loop_len_vector, shape.dtype), shape], 

2646 axis=0) 

2647 begin = array_ops.concat([[0], begin], axis=0) 

2648 end = array_ops.concat([[0], end], axis=0) 

2649 strides = array_ops.concat([[1], strides], axis=0) 

2650 begin_mask = begin_mask << 1 | 1 

2651 end_mask = end_mask << 1 | 1 

2652 ellipsis_mask <<= 1 

2653 new_axis_mask <<= 1 

2654 shrink_axis_mask <<= 1 

2655 return wrap( 

2656 array_ops.strided_slice_grad( 

2657 shape, 

2658 begin, 

2659 end, 

2660 strides, 

2661 dy, 

2662 begin_mask=begin_mask, 

2663 end_mask=end_mask, 

2664 ellipsis_mask=ellipsis_mask, 

2665 new_axis_mask=new_axis_mask, 

2666 shrink_axis_mask=shrink_axis_mask), True) 

2667 

2668 

2669@RegisterPFor("CheckNumerics") 

2670def _convert_check_numerics(pfor_input): 

2671 t = pfor_input.stacked_input(0) 

2672 message = pfor_input.get_attr("message") 

2673 return wrap(gen_array_ops.check_numerics(t, message), True) 

2674 

2675 

2676@RegisterPFor("EnsureShape") 

2677def _convert_ensure_shape(pfor_input): 

2678 t = pfor_input.stacked_input(0) 

2679 shape = tensor_shape.TensorShape(pfor_input.get_attr("shape")) 

2680 return wrap(gen_array_ops.ensure_shape(t, [None] + shape), True) 

2681 

2682 

2683# manip_ops 

2684 

2685 

2686@RegisterPFor("Roll") 

2687def _convert_roll(pfor_input): 

2688 t = pfor_input.stacked_input(0) 

2689 shift, shift_stacked, _ = pfor_input.input(1) 

2690 axis = pfor_input.unstacked_input(2) 

2691 if not shift_stacked: 

2692 return wrap(manip_ops.roll(t, shift, axis + 1), True) 

2693 else: 

2694 # `axis` and `shift` may both be vectors, with repeated axes summing the 

2695 # corresponding `shift`s. We scatter shifts into a dense array of shape 

2696 # [loop_len, num_unstacked_axes] indicating the offset for each axis. 

2697 num_unstacked_axes = math_ops.cast(array_ops.rank(t), dtypes.int64) - 1 

2698 axis = math_ops.cast(array_ops.reshape(axis, [-1]), dtypes.int64) 

2699 loop_len = math_ops.cast(pfor_input.pfor.loop_len_vector[0], dtypes.int64) 

2700 shift = math_ops.cast(array_ops.reshape(shift, [loop_len, -1]), 

2701 dtypes.int64) 

2702 axis_segment_ids = ( 

2703 math_ops.range(loop_len, dtype=dtypes.int64)[:, None] 

2704 * num_unstacked_axes + axis[None, :]) 

2705 axis_offsets = array_ops.reshape( 

2706 math_ops.unsorted_segment_sum( 

2707 data=shift, segment_ids=axis_segment_ids, 

2708 num_segments=loop_len * num_unstacked_axes), 

2709 [loop_len, num_unstacked_axes]) 

2710 

2711 # Determine the coordinates in the input array of each result and gather 

2712 # them. 

2713 unstacked_shape = array_ops.shape(t, out_type=dtypes.int64)[1:] 

2714 cumsize = math_ops.cumprod(unstacked_shape, exclusive=True, reverse=True) 

2715 num_unstacked_elements = math_ops.reduce_prod(unstacked_shape) 

2716 result_coordinates = ( 

2717 (math_ops.range(num_unstacked_elements, 

2718 dtype=dtypes.int64)[None, :, None] 

2719 // cumsize[None, None, :] - axis_offsets[:, None, :]) 

2720 % unstacked_shape[None, None, :]) 

2721 result_flat = array_ops.gather_nd(params=t, indices=result_coordinates, 

2722 batch_dims=1) 

2723 return wrap(array_ops.reshape(result_flat, array_ops.shape(t)), 

2724 True) 

2725 

2726# math_ops 

2727 

2728 

2729@RegisterPFor("MatMul") 

2730def _convert_matmul(pfor_input): 

2731 # TODO(agarwal): Check if tiling is faster than two transposes. 

2732 a, a_stacked, _ = pfor_input.input(0) 

2733 b, b_stacked, _ = pfor_input.input(1) 

2734 tr_a = pfor_input.get_attr("transpose_a") 

2735 tr_b = pfor_input.get_attr("transpose_b") 

2736 if a_stacked and b_stacked: 

2737 output = wrap(math_ops.matmul(a, b, adjoint_a=tr_a, adjoint_b=tr_b), True) 

2738 return output 

2739 elif a_stacked: 

2740 if tr_a: 

2741 a = array_ops.transpose(a, [0, 2, 1]) 

2742 if a.shape.is_fully_defined(): 

2743 x, y, z = a.shape 

2744 else: 

2745 x, y, z = [ 

2746 array_ops.reshape(i, []) 

2747 for i in array_ops.split(array_ops.shape(a), 3) 

2748 ] 

2749 a = array_ops.reshape(a, [x * y, z]) 

2750 prod = math_ops.matmul(a, b, transpose_b=tr_b) 

2751 return wrap(array_ops.reshape(prod, [x, y, -1]), True) 

2752 else: 

2753 assert b_stacked 

2754 if tr_b: 

2755 perm = [2, 0, 1] 

2756 b = array_ops.transpose(b, perm) 

2757 else: 

2758 # As an optimization, if one of the first two dimensions is 1, then we can 

2759 # reshape instead of transpose. 

2760 # TODO(agarwal): This check can be done inside Transpose kernel. 

2761 b_shape = array_ops.shape(b) 

2762 min_dim = math_ops.minimum(b_shape[0], b_shape[1]) 

2763 perm = array_ops.where( 

2764 math_ops.equal(min_dim, 1), [0, 1, 2], [1, 0, 2]) 

2765 new_shape = array_ops_stack.stack([b_shape[1], b_shape[0], b_shape[2]]) 

2766 b = array_ops.transpose(b, perm) 

2767 b = array_ops.reshape(b, new_shape) 

2768 

2769 if b.shape.is_fully_defined(): 

2770 x, y, z = b.shape 

2771 else: 

2772 x, y, z = [ 

2773 array_ops.reshape(i, []) 

2774 for i in array_ops.split(array_ops.shape(b), 3) 

2775 ] 

2776 b = array_ops.reshape(b, [x, y * z]) 

2777 prod = math_ops.matmul(a, b, transpose_a=tr_a) 

2778 prod = array_ops.reshape(prod, [-1, y, z]) 

2779 prod = array_ops.transpose(prod, [1, 0, 2]) 

2780 return wrap(prod, True) 

2781 

2782 

2783# TODO(rmlarsen): Use the converter of BatchMatMulV2 once compatibility window 

2784# is met. 

2785@RegisterPFor("BatchMatMul") 

2786def _convert_batch_mat_mul(pfor_input): 

2787 # TODO(agarwal): There may be a more efficient way to do this instead of 

2788 # stacking the inputs. 

2789 pfor_input.stack_inputs() 

2790 x = pfor_input.stacked_input(0) 

2791 y = pfor_input.stacked_input(1) 

2792 adj_x = pfor_input.get_attr("adj_x") 

2793 adj_y = pfor_input.get_attr("adj_y") 

2794 

2795 x = _flatten_first_two_dims(x) 

2796 y = _flatten_first_two_dims(y) 

2797 output = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y) 

2798 output = _unflatten_first_dim(output, pfor_input.pfor.loop_len_vector) 

2799 return wrap(output, True) 

2800 

2801 

2802@RegisterPFor("BatchMatMulV2") 

2803def _convert_batch_mat_mul_v2(pfor_input): 

2804 pfor_input.expanddim_inputs_for_broadcast() 

2805 x = pfor_input.input(0)[0] 

2806 y = pfor_input.input(1)[0] 

2807 adj_x = pfor_input.get_attr("adj_x") 

2808 adj_y = pfor_input.get_attr("adj_y") 

2809 

2810 output = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y) 

2811 return wrap(output, True) 

2812 

2813 

2814@RegisterPForWithArgs("Sum", math_ops.reduce_sum) 

2815@RegisterPForWithArgs("Prod", math_ops.reduce_prod) 

2816@RegisterPForWithArgs("Max", math_ops.reduce_max) 

2817@RegisterPForWithArgs("Min", math_ops.reduce_min) 

2818@RegisterPForWithArgs("Mean", math_ops.reduce_mean) 

2819@RegisterPForWithArgs("All", math_ops.reduce_all) 

2820@RegisterPForWithArgs("Any", math_ops.reduce_any) 

2821def _convert_reduction(pfor_input, _, op_func): 

2822 t = pfor_input.stacked_input(0) 

2823 indices = pfor_input.unstacked_input(1) 

2824 # Shift positive indices by one to account for the extra dimension. 

2825 indices += math_ops.cast(indices >= 0, indices.dtype) 

2826 keep_dims = pfor_input.get_attr("keep_dims") 

2827 return wrap(op_func(t, indices, keepdims=keep_dims), True) 

2828 

2829 

2830@RegisterPForWithArgs("ArgMax", math_ops.argmax) 

2831@RegisterPForWithArgs("ArgMin", math_ops.argmin) 

2832def _convert_argmax_argmin(pfor_input, _, op_func): 

2833 t = pfor_input.stacked_input(0) 

2834 dimension = pfor_input.unstacked_input(1) 

2835 dimension += math_ops.cast(dimension >= 0, dimension.dtype) 

2836 output_type = pfor_input.get_attr("output_type") 

2837 return wrap(op_func(t, axis=dimension, output_type=output_type), True) 

2838 

2839 

2840@RegisterPFor("Bucketize") 

2841def _convert_bucketize(pfor_input): 

2842 t = pfor_input.stacked_input(0) 

2843 boundaries = pfor_input.get_attr("boundaries") 

2844 return wrap(math_ops.bucketize(t, boundaries), True) 

2845 

2846 

2847@RegisterPFor("ClipByValue") 

2848def _convert_clip_by_value(pfor_input): 

2849 t = pfor_input.stacked_input(0) 

2850 clip_value_min = pfor_input.unstacked_input(1) 

2851 clip_value_max = pfor_input.unstacked_input(2) 

2852 return wrap(gen_math_ops._clip_by_value(t, clip_value_min, clip_value_max), 

2853 True) 

2854 

2855 

2856@RegisterPForWithArgs("Cumsum", math_ops.cumsum) 

2857@RegisterPForWithArgs("Cumprod", math_ops.cumprod) 

2858def _convert_cumfoo(pfor_input, _, op_func): 

2859 t = pfor_input.stacked_input(0) 

2860 axis = pfor_input.unstacked_input(1) 

2861 # Shift positive indices by one to account for the extra dimension. 

2862 axis += math_ops.cast(axis >= 0, axis.dtype) 

2863 exclusive = pfor_input.get_attr("exclusive") 

2864 reverse = pfor_input.get_attr("reverse") 

2865 return wrap(op_func(t, axis, exclusive=exclusive, reverse=reverse), True) 

2866 

2867 

2868@RegisterPFor("BiasAdd") 

2869def _convert_biasadd(pfor_input): 

2870 t, t_stacked, _ = pfor_input.input(0) 

2871 bias, bias_stacked, _ = pfor_input.input(1) 

2872 data_format = pfor_input.get_attr("data_format").decode() 

2873 if bias_stacked: 

2874 # BiasAdd only supports 1-D biases, so cast bias to match value and use Add. 

2875 pfor_input.expanddim_inputs_for_broadcast() 

2876 t, _, _ = pfor_input.input(0) 

2877 bias = math_ops.cast(pfor_input.stacked_input(1), t.dtype) 

2878 if compat.as_bytes(data_format) == b"NCHW": 

2879 b_shape = array_ops.shape(bias) 

2880 new_b_shape = array_ops.concat( 

2881 [b_shape[:-3], b_shape[-1:], b_shape[-3:-1]], axis=0) 

2882 bias = array_ops.reshape(bias, new_b_shape) 

2883 return wrap(math_ops.add(t, bias), True) 

2884 else: 

2885 assert t_stacked, "At least one input to BiasAdd should be loop variant." 

2886 if compat.as_bytes(data_format) == b"NCHW": 

2887 shape = array_ops.shape(t) 

2888 flattened_shape = array_ops.concat([[-1], shape[2:]], axis=0) 

2889 t = array_ops.reshape(t, flattened_shape) 

2890 t = nn_ops.bias_add(t, bias, data_format="NCHW") 

2891 t = array_ops.reshape(t, shape) 

2892 return wrap(t, True) 

2893 return wrap(nn_ops.bias_add(t, bias, data_format=data_format), True) 

2894 

2895 

2896@RegisterPForWithArgs("UnsortedSegmentSum", math_ops.unsorted_segment_sum) 

2897@RegisterPForWithArgs("UnsortedSegmentMax", math_ops.unsorted_segment_max) 

2898@RegisterPForWithArgs("UnsortedSegmentMin", math_ops.unsorted_segment_min) 

2899@RegisterPForWithArgs("UnsortedSegmentProd", math_ops.unsorted_segment_prod) 

2900def _convert_unsortedsegmentsum(pfor_input, _, op_func): 

2901 pfor_input.stack_inputs([0, 1]) 

2902 data = pfor_input.stacked_input(0) 

2903 segment_ids = pfor_input.stacked_input(1) 

2904 # TODO(agarwal): handle stacked? 

2905 num_segments = pfor_input.unstacked_input(2) 

2906 if segment_ids.dtype != num_segments.dtype: 

2907 segment_ids = math_ops.cast(segment_ids, dtypes.int64) 

2908 num_segments = math_ops.cast(num_segments, dtypes.int64) 

2909 dtype = segment_ids.dtype 

2910 segment_shape = array_ops.shape(segment_ids, out_type=dtype) 

2911 n = segment_shape[0] 

2912 ones = array_ops.ones_like(segment_shape, dtype=dtype)[1:] 

2913 segment_offset = num_segments * math_ops.range(n, dtype=dtype) 

2914 segment_offset = array_ops.reshape(segment_offset, 

2915 array_ops.concat([[n], ones], axis=0)) 

2916 segment_ids += segment_offset 

2917 num_segments = math_ops.cast(num_segments, dtypes.int64) * math_ops.cast( 

2918 n, dtypes.int64) 

2919 output = op_func(data, segment_ids, num_segments) 

2920 new_output_shape = array_ops.concat( 

2921 [[n, -1], array_ops.shape(output)[1:]], axis=0) 

2922 output = array_ops.reshape(output, new_output_shape) 

2923 return wrap(output, True) 

2924 

2925 

2926def _flatten_array_with_offset(ids, offset_delta, num_rows): 

2927 """Flattens a rank 2 tensor, adding an offset to each row.""" 

2928 # Note that if `ids` is rank 1, it is broadcast to rank 2. 

2929 offset_delta = math_ops.cast(offset_delta, ids.dtype) 

2930 n = math_ops.cast(num_rows, dtype=ids.dtype) 

2931 offsets = math_ops.range( 

2932 start=0, limit=n * offset_delta, delta=offset_delta, dtype=ids.dtype) 

2933 offsets = array_ops.expand_dims(offsets, -1) 

2934 ids += offsets 

2935 return array_ops.reshape(ids, [-1]) 

2936 

2937 

2938@RegisterPForWithArgs("SparseSegmentSum", math_ops.sparse_segment_sum_v2) 

2939@RegisterPForWithArgs("SparseSegmentMean", math_ops.sparse_segment_mean_v2) 

2940@RegisterPForWithArgs("SparseSegmentSqrtN", math_ops.sparse_segment_sqrt_n_v2) 

2941@RegisterPForWithArgs("SparseSegmentSumWithNumSegments", 

2942 math_ops.sparse_segment_sum_v2) 

2943@RegisterPForWithArgs("SparseSegmentMeanWithNumSegments", 

2944 math_ops.sparse_segment_mean_v2) 

2945@RegisterPForWithArgs("SparseSegmentSqrtNWithNumSegments", 

2946 math_ops.sparse_segment_sqrt_n_v2) 

2947def _convert_sparse_segment(pfor_input, _, op_func): 

2948 _, segment_ids_stacked, _ = pfor_input.input(2) 

2949 if segment_ids_stacked: 

2950 pfor_input.stack_inputs([1]) 

2951 data, data_stacked, _ = pfor_input.input(0) 

2952 indices, _, _ = pfor_input.input(1) 

2953 num_inputs = len(pfor_input.inputs) 

2954 assert num_inputs in (3, 4) 

2955 if num_inputs == 3: 

2956 # `segment_ids` needs to be unstacked since otherwise output sizes could 

2957 # differ across pfor iterations. 

2958 segment_ids = pfor_input.unstacked_input(2) 

2959 num_segments = nn_ops.relu(math_ops.reduce_max(segment_ids) + 1) 

2960 else: 

2961 segment_ids, _, _ = pfor_input.input(2) 

2962 num_segments = pfor_input.unstacked_input(3) 

2963 

2964 n = pfor_input.pfor.loop_len_vector[0] 

2965 if data_stacked: 

2966 indices = _flatten_array_with_offset(indices, array_ops.shape(data)[1], n) 

2967 data = _flatten_first_two_dims(data) 

2968 else: 

2969 indices = array_ops.reshape(indices, [-1]) 

2970 segment_ids = _flatten_array_with_offset(segment_ids, num_segments, n) 

2971 

2972 if num_inputs == 3: 

2973 num_segments = None 

2974 else: 

2975 num_segments *= n 

2976 output = op_func(data, indices, segment_ids, num_segments=num_segments) 

2977 output = _unflatten_first_dim(output, [n]) 

2978 return wrap(output, True) 

2979 

2980 

2981@RegisterPForWithArgs("SparseSegmentSumGrad", math_ops.sparse_segment_sum_grad) 

2982@RegisterPForWithArgs("SparseSegmentMeanGrad", 

2983 math_ops.sparse_segment_mean_grad) 

2984@RegisterPForWithArgs("SparseSegmentSqrtNGrad", 

2985 math_ops.sparse_segment_sqrt_n_grad) 

2986def _convert_sparse_segment_grad(pfor_input, _, op_func): 

2987 grad = pfor_input.stacked_input(0) 

2988 indices = pfor_input.unstacked_input(1) 

2989 segment_ids = pfor_input.unstacked_input(2) 

2990 dim0 = pfor_input.unstacked_input(3) 

2991 

2992 n = pfor_input.pfor.loop_len_vector[0] 

2993 indices = _flatten_array_with_offset(indices, dim0, n) 

2994 num_segments = nn_ops.relu(math_ops.reduce_max(segment_ids) + 1) 

2995 segment_ids = _flatten_array_with_offset(segment_ids, num_segments, n) 

2996 grad = _flatten_first_two_dims(grad) 

2997 dim0 *= n 

2998 output = op_func(grad, indices, segment_ids, dim0) 

2999 output = _unflatten_first_dim(output, [n]) 

3000 return wrap(output, True) 

3001 

3002 

3003@RegisterPFor("Cast") 

3004def _convert_cast(pfor_input): 

3005 inp = pfor_input.stacked_input(0) 

3006 dtype = pfor_input.get_attr("DstT") 

3007 return wrap(math_ops.cast(inp, dtype), True) 

3008 

3009 

3010@RegisterPFor("Abs") 

3011@RegisterPFor("Acos") 

3012@RegisterPFor("Acosh") 

3013@RegisterPFor("Add") 

3014@RegisterPFor("AddV2") 

3015@RegisterPFor("Angle") 

3016@RegisterPFor("Asin") 

3017@RegisterPFor("Asinh") 

3018@RegisterPFor("Atan") 

3019@RegisterPFor("Atan2") 

3020@RegisterPFor("Atanh") 

3021@RegisterPFor("BesselI0") 

3022@RegisterPFor("BesselI1") 

3023@RegisterPFor("BesselI0e") 

3024@RegisterPFor("BesselI1e") 

3025@RegisterPFor("BesselK0") 

3026@RegisterPFor("BesselK1") 

3027@RegisterPFor("BesselK0e") 

3028@RegisterPFor("BesselK1e") 

3029@RegisterPFor("BesselJ0") 

3030@RegisterPFor("BesselJ1") 

3031@RegisterPFor("BesselY0") 

3032@RegisterPFor("BesselY1") 

3033@RegisterPFor("BitwiseAnd") 

3034@RegisterPFor("BitwiseOr") 

3035@RegisterPFor("BitwiseXor") 

3036@RegisterPFor("Ceil") 

3037@RegisterPFor("Complex") 

3038@RegisterPFor("ComplexAbs") 

3039@RegisterPFor("Conj") 

3040@RegisterPFor("Cos") 

3041@RegisterPFor("Cosh") 

3042@RegisterPFor("Dawsn") 

3043@RegisterPFor("Digamma") 

3044@RegisterPFor("Div") 

3045@RegisterPFor("DivNoNan") 

3046@RegisterPFor("Elu") 

3047@RegisterPFor("Erf") 

3048@RegisterPFor("Erfc") 

3049@RegisterPFor("Erfinv") 

3050@RegisterPFor("Exp") 

3051@RegisterPFor("Expint") 

3052@RegisterPFor("Expm1") 

3053@RegisterPFor("Floor") 

3054@RegisterPFor("FloorDiv") 

3055@RegisterPFor("FloorMod") 

3056@RegisterPFor("FresnelCos") 

3057@RegisterPFor("FresnelSin") 

3058@RegisterPFor("Greater") 

3059@RegisterPFor("GreaterEqual") 

3060@RegisterPFor("Igamma") 

3061@RegisterPFor("IgammaGradA") 

3062@RegisterPFor("Igammac") 

3063@RegisterPFor("Imag") 

3064@RegisterPFor("Inv") 

3065@RegisterPFor("Invert") 

3066@RegisterPFor("IsFinite") 

3067@RegisterPFor("IsInf") 

3068@RegisterPFor("IsNan") 

3069@RegisterPFor("LeftShift") 

3070@RegisterPFor("Less") 

3071@RegisterPFor("LessEqual") 

3072@RegisterPFor("Lgamma") 

3073@RegisterPFor("Log") 

3074@RegisterPFor("Log1p") 

3075@RegisterPFor("LogicalAnd") 

3076@RegisterPFor("LogicalNot") 

3077@RegisterPFor("LogicalOr") 

3078@RegisterPFor("LogicalXor") 

3079@RegisterPFor("Maximum") 

3080@RegisterPFor("Minimum") 

3081@RegisterPFor("Mod") 

3082@RegisterPFor("Mul") 

3083@RegisterPFor("MulNoNan") 

3084@RegisterPFor("Ndtri") 

3085@RegisterPFor("Neg") 

3086@RegisterPFor("Polygamma") 

3087@RegisterPFor("Pow") 

3088@RegisterPFor("Real") 

3089@RegisterPFor("RealDiv") 

3090@RegisterPFor("Reciprocal") 

3091@RegisterPFor("Relu") 

3092@RegisterPFor("Relu6") 

3093@RegisterPFor("RightShift") 

3094@RegisterPFor("Rint") 

3095@RegisterPFor("Round") 

3096@RegisterPFor("Rsqrt") 

3097@RegisterPFor("Selu") 

3098@RegisterPFor("Sigmoid") 

3099@RegisterPFor("Sign") 

3100@RegisterPFor("Sin") 

3101@RegisterPFor("Sinh") 

3102@RegisterPFor("Softplus") 

3103@RegisterPFor("Softsign") 

3104@RegisterPFor("Spence") 

3105@RegisterPFor("Sqrt") 

3106@RegisterPFor("Square") 

3107@RegisterPFor("SquaredDifference") 

3108@RegisterPFor("Sub") 

3109@RegisterPFor("Tan") 

3110@RegisterPFor("Tanh") 

3111@RegisterPFor("TruncateDiv") 

3112@RegisterPFor("TruncateMod") 

3113@RegisterPFor("Xdivy") 

3114@RegisterPFor("Xlogy") 

3115@RegisterPFor("Xlog1py") 

3116@RegisterPFor("Zeta") 

3117def _convert_cwise(pfor_input): 

3118 if pfor_input.num_inputs > 1: 

3119 pfor_input.expanddim_inputs_for_broadcast() 

3120 

3121 out = _create_op( 

3122 pfor_input.op_type, [x.t for x in pfor_input.inputs], 

3123 [x.dtype for x in pfor_input.outputs], 

3124 attrs=pfor_input.op.node_def.attr).outputs 

3125 assert len(out) == 1 

3126 out = out[0] 

3127 

3128 op_output = wrap(out, True) 

3129 return op_output 

3130 

3131 

3132@RegisterPFor("XlaSharding") 

3133def _convert_xla_sharding(pfor_input): 

3134 t = pfor_input.stacked_input(0) 

3135 sharding = pfor_input.get_attr("sharding") 

3136 return wrap(xla.sharding(t, sharding=sharding), True) 

3137 

3138 

3139@RegisterPFor("LeakyRelu") 

3140def _convert_leaky_relu(pfor_input): 

3141 t = pfor_input.stacked_input(0) 

3142 alpha = pfor_input.get_attr("alpha") 

3143 return wrap(gen_nn_ops.leaky_relu(t, alpha=alpha), True) 

3144 

3145 

3146@RegisterPFor("Equal") 

3147def _convert_equal(pfor_input): 

3148 pfor_input.expanddim_inputs_for_broadcast() 

3149 x = pfor_input.input(0)[0] 

3150 y = pfor_input.input(1)[0] 

3151 incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error") 

3152 return wrap(gen_math_ops.equal( 

3153 x, y, incompatible_shape_error=incompatible_shape_error), True) 

3154 

3155 

3156@RegisterPFor("NotEqual") 

3157def _convert_not_equal(pfor_input): 

3158 pfor_input.expanddim_inputs_for_broadcast() 

3159 x = pfor_input.input(0)[0] 

3160 y = pfor_input.input(1)[0] 

3161 incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error") 

3162 return wrap(gen_math_ops.not_equal( 

3163 x, y, incompatible_shape_error=incompatible_shape_error), True) 

3164 

3165 

3166@RegisterPFor("ApproximateEqual") 

3167def _convert_approximate_equal(pfor_input): 

3168 pfor_input.expanddim_inputs_for_broadcast() 

3169 x = pfor_input.input(0)[0] 

3170 y = pfor_input.input(1)[0] 

3171 tolerance = pfor_input.get_attr("tolerance") 

3172 return wrap(math_ops.approximate_equal(x, y, tolerance=tolerance), True) 

3173 

3174 

3175@RegisterPFor("Shape") 

3176def _convert_shape(pfor_input): 

3177 out_type = pfor_input.get_attr("out_type") 

3178 return wrap( 

3179 array_ops.shape(pfor_input.stacked_input(0), out_type=out_type)[1:], 

3180 False) 

3181 

3182 

3183@RegisterPFor("ShapeN") 

3184def _convert_shape_n(pfor_input): 

3185 out_type = pfor_input.get_attr("out_type") 

3186 shapes = [ 

3187 array_ops.shape(x, out_type=out_type)[1:] if stacked else array_ops.shape( 

3188 x, out_type=out_type) for x, stacked, _ in pfor_input.inputs 

3189 ] 

3190 return [wrap(x, False) for x in shapes] 

3191 

3192 

3193@RegisterPFor("Size") 

3194def _convert_size(pfor_input): 

3195 out_type = pfor_input.get_attr("out_type") 

3196 n = math_ops.cast(pfor_input.pfor.loop_len_vector[0], out_type) 

3197 return wrap( 

3198 array_ops.size(pfor_input.stacked_input(0), out_type=out_type) // n, 

3199 False) 

3200 

3201 

3202@RegisterPFor("Rank") 

3203def _convert_rank(pfor_input): 

3204 return wrap(array_ops.rank(pfor_input.stacked_input(0)) - 1, False) 

3205 

3206 

3207@RegisterPFor("AddN") 

3208def _convert_addn(pfor_input): 

3209 # AddN does not support broadcasting. 

3210 pfor_input.stack_inputs(tile_variants=False) 

3211 return _wrap_and_tile_variants( 

3212 math_ops.add_n([x.t for x in pfor_input.inputs]), 

3213 pfor_input.pfor.loop_len_vector) 

3214 

3215 

3216@RegisterPFor("Cross") 

3217def _convert_cross(pfor_input): 

3218 pfor_input.stack_inputs() 

3219 a = pfor_input.stacked_input(0) 

3220 b = pfor_input.stacked_input(1) 

3221 return wrap(math_ops.cross(a, b), True) 

3222 

3223 

3224@RegisterPFor("BiasAddGrad") 

3225def _convert_biasaddgrad(pfor_input): 

3226 grad = pfor_input.stacked_input(0) 

3227 fmt = pfor_input.get_attr("data_format") 

3228 if fmt == b"NCHW": 

3229 output = math_ops.reduce_sum(grad, axis=[1, 3, 4], keepdims=False) 

3230 else: 

3231 grad_shape = array_ops.shape(grad) 

3232 last_dim_shape = grad_shape[-1] 

3233 first_dim_shape = grad_shape[0] 

3234 output = array_ops.reshape(grad, [first_dim_shape, -1, last_dim_shape]) 

3235 output = math_ops.reduce_sum(output, axis=[1], keepdims=False) 

3236 return wrap(output, True) 

3237 

3238 

3239# Some required ops are not exposed under the tf namespace. Hence relying on 

3240# _create_op to create them. 

3241@RegisterPForWithArgs("EluGrad") 

3242@RegisterPForWithArgs("LeakyReluGrad") 

3243@RegisterPForWithArgs("ReciprocalGrad") 

3244@RegisterPForWithArgs("Relu6Grad") 

3245@RegisterPForWithArgs("ReluGrad") 

3246@RegisterPForWithArgs("RsqrtGrad") 

3247@RegisterPForWithArgs("SeluGrad") 

3248@RegisterPForWithArgs("SigmoidGrad") 

3249@RegisterPForWithArgs("SoftplusGrad") 

3250@RegisterPForWithArgs("SoftsignGrad") 

3251@RegisterPForWithArgs("SqrtGrad") 

3252@RegisterPForWithArgs("TanhGrad") 

3253def _convert_grads(pfor_input, op_type, *args, **kw_args): 

3254 del args 

3255 del kw_args 

3256 # TODO(agarwal): Looks like these ops don't support broadcasting. Hence we 

3257 # have to use tiling here. 

3258 pfor_input.stack_inputs() 

3259 outputs = _create_op( 

3260 op_type, [x.t for x in pfor_input.inputs], 

3261 [x.dtype for x in pfor_input.outputs], 

3262 attrs=pfor_input.op.node_def.attr).outputs 

3263 return [wrap(x, True) for x in outputs] 

3264 

3265 

3266@RegisterPFor("Select") 

3267def _convert_select(pfor_input): 

3268 pfor_input.stack_inputs() 

3269 cond = pfor_input.stacked_input(0) 

3270 t = pfor_input.stacked_input(1) 

3271 e = pfor_input.stacked_input(2) 

3272 cond_rank = array_ops.rank(cond) 

3273 cond, t, e = smart_cond.smart_cond( 

3274 cond_rank > 1, lambda: _inputs_with_flattening(pfor_input, [0, 1, 2]), 

3275 lambda: [cond, t, e]) 

3276 outputs = _create_op( 

3277 pfor_input.op_type, [cond, t, e], [x.dtype for x in pfor_input.outputs], 

3278 attrs=pfor_input.op.node_def.attr).outputs 

3279 n = pfor_input.pfor.loop_len_vector 

3280 out = smart_cond.smart_cond(cond_rank > 1, 

3281 lambda: _unflatten_first_dim(outputs[0], n), 

3282 lambda: outputs[0]) 

3283 return [wrap(out, True) for x in outputs] 

3284 

3285 

3286@RegisterPFor("SelectV2") 

3287def _convert_selectv2(pfor_input): 

3288 pfor_input.expanddim_inputs_for_broadcast() 

3289 cond = pfor_input.input(0)[0] 

3290 t = pfor_input.input(1)[0] 

3291 e = pfor_input.input(2)[0] 

3292 out = array_ops.where_v2(cond, t, e) 

3293 return wrap(out, True) 

3294 

3295 

3296# random_ops 

3297 

3298 

3299def _transpose_dim_to_front(x, dim): 

3300 rank = array_ops.rank(x) 

3301 return array_ops.transpose( 

3302 x, 

3303 perm=array_ops.concat( 

3304 [[dim], math_ops.range(0, dim), 

3305 math_ops.range(dim + 1, rank)], 

3306 axis=0)) 

3307 

3308 

3309@RegisterPForWithArgs("RandomUniform") 

3310@RegisterPForWithArgs("RandomUniformInt") 

3311@RegisterPForWithArgs("RandomStandardNormal") 

3312@RegisterPForWithArgs("TruncatedNormal") 

3313def _convert_random(pfor_input, op_type, *args, **kw_args): 

3314 del args 

3315 del kw_args 

3316 inputs = [pfor_input.unstacked_input(i) for i in range(pfor_input.num_inputs)] 

3317 # inputs[0] is "shape" 

3318 inputs[0] = array_ops.concat([pfor_input.pfor.loop_len_vector, inputs[0]], 

3319 axis=0) 

3320 # TODO(b/222761732): Turn this warning back on when legacy RNGs are 

3321 # deprecated. 

3322 # logging.warning( 

3323 # "Note that %s inside pfor op may not give same output as " 

3324 # "inside a sequential loop.", op_type) 

3325 outputs = _create_op( 

3326 op_type, 

3327 inputs, [x.dtype for x in pfor_input.outputs], 

3328 attrs=pfor_input.op.node_def.attr).outputs 

3329 return [wrap(x, True) for x in outputs] 

3330 

3331 

3332@RegisterPFor("RandomGamma") 

3333@RegisterPFor("RandomPoissonV2") 

3334def _convert_random_with_param(pfor_input): 

3335 shape = pfor_input.unstacked_input(0) 

3336 # param is lam (Poisson rate) or alpha (Gamma shape). 

3337 param, param_stacked, _ = pfor_input.input(1) 

3338 # TODO(b/222761732): Turn this warning back on when legacy RNGs are 

3339 # deprecated. 

3340 # logging.warning( 

3341 # "Note that %s inside pfor op may not give same output as " 

3342 # "inside a sequential loop.", pfor_input.op_type) 

3343 

3344 if param_stacked: 

3345 samples = _create_op( 

3346 pfor_input.op_type, 

3347 inputs=[shape, param], 

3348 op_dtypes=[x.dtype for x in pfor_input.outputs], 

3349 attrs=pfor_input.op.node_def.attr).outputs[0] 

3350 loop_dim = array_ops.shape(shape)[0] 

3351 stacked_samples = _transpose_dim_to_front(samples, loop_dim) 

3352 else: 

3353 shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0) 

3354 stacked_samples = _create_op( 

3355 pfor_input.op_type, 

3356 inputs=[shape, param], 

3357 op_dtypes=[x.dtype for x in pfor_input.outputs], 

3358 attrs=pfor_input.op.node_def.attr).outputs[0] 

3359 

3360 return wrap(stacked_samples, True) 

3361 

3362 

3363@RegisterPFor("Multinomial") 

3364def _convert_multinomial(pfor_input): 

3365 logits, logits_stacked, _ = pfor_input.input(0) 

3366 num_samples = pfor_input.unstacked_input(1) 

3367 seed = pfor_input.get_attr("seed") 

3368 seed2 = pfor_input.get_attr("seed2") 

3369 output_dtype = pfor_input.get_attr("output_dtype") 

3370 # TODO(b/222761732): Turn this warning back on when legacy RNGs are 

3371 # deprecated. 

3372 # logging.warning( 

3373 # "Note that Multinomial inside pfor op may not give same output as " 

3374 # "inside a sequential loop.") 

3375 

3376 n = pfor_input.pfor.loop_len_vector[0] 

3377 if logits_stacked: 

3378 flattened_logits = _flatten_first_two_dims(logits) 

3379 samples = gen_random_ops.multinomial( 

3380 flattened_logits, 

3381 num_samples, 

3382 seed=seed, 

3383 seed2=seed2, 

3384 output_dtype=output_dtype) 

3385 stacked_samples = _unflatten_first_dim(samples, [n]) 

3386 else: 

3387 samples = gen_random_ops.multinomial( 

3388 logits, 

3389 num_samples * n, 

3390 seed=seed, 

3391 seed2=seed2, 

3392 output_dtype=output_dtype) 

3393 stacked_samples = array_ops.transpose( 

3394 array_ops.reshape(samples, [-1, n, num_samples]), [1, 0, 2]) 

3395 

3396 return wrap(stacked_samples, True) 

3397 

3398 

3399@RegisterPFor("StatelessMultinomial") 

3400@RegisterPFor("StatelessParameterizedTruncatedNormal") 

3401@RegisterPFor("StatelessRandomBinomial") 

3402@RegisterPFor("StatelessRandomGammaV2") 

3403@RegisterPFor("StatelessRandomNormal") 

3404@RegisterPFor("StatelessRandomPoisson") 

3405@RegisterPFor("StatelessRandomUniform") 

3406@RegisterPFor("StatelessRandomUniformInt") 

3407@RegisterPFor("StatelessRandomUniformFullInt") 

3408@RegisterPFor("StatelessTruncatedNormal") 

3409def _convert_stateless_multinomial(pfor_input): 

3410 # Unlike stateful random ops, for stateless ones we want better 

3411 # reproducibility based on seed. Hence we don't want to use a similar strategy 

3412 # as used for stateful ones where we generate a possibly different set of 

3413 # random numbers under vectorization. 

3414 # Unfortunately, the kernels currently are not necessarily setup to do this 

3415 # efficiently and hence we fallback to a sequential loop for vectorization. 

3416 return _fallback_converter(pfor_input, warn=False) 

3417 

3418 

3419# linalg_ops 

3420 

3421 

3422@RegisterPForWithArgs("XlaEinsum") 

3423@RegisterPForWithArgs("Einsum") 

3424def _convert_einsum(pfor_input, op_type): 

3425 # Einsum may have either 1 or 2 inputs. 

3426 inputs, input_stacked, _ = zip(*[ 

3427 pfor_input.input(i) 

3428 for i in range(pfor_input.num_inputs)]) 

3429 

3430 # Parse the einsum equation. 

3431 equation = pfor_input.get_attr("equation").decode("utf-8") 

3432 input_expr, output_expr = equation.split("->") 

3433 input_exprs = input_expr.split(",") 

3434 

3435 # Pick a placeholder symbol to use for the new axis. 

3436 chosen_symbol = None 

3437 for s in string.ascii_letters: 

3438 if s in equation: 

3439 continue 

3440 else: 

3441 chosen_symbol = s 

3442 break 

3443 

3444 if chosen_symbol is None: 

3445 raise ValueError("Could not figure out what symbol to use for new axis.") 

3446 

3447 assert any(input_stacked) 

3448 for i in range(len(inputs)): 

3449 if input_stacked[i]: 

3450 input_exprs[i] = "{}{}".format(chosen_symbol, input_exprs[i]) 

3451 output_expr = "{}{}".format(chosen_symbol, output_expr) 

3452 

3453 new_equation = "{}->{}".format(",".join(input_exprs), output_expr) 

3454 

3455 if op_type == "XlaEinsum": 

3456 if len(inputs) == 1: 

3457 result = xla.einsum(equation=new_equation, a=inputs[0]) 

3458 else: 

3459 result = xla.einsum(equation=new_equation, a=inputs[0], b=inputs[1]) 

3460 else: 

3461 assert op_type == "Einsum" 

3462 result = special_math_ops.einsum(new_equation, *inputs) 

3463 

3464 return wrap(result, True) 

3465 

3466 

3467@RegisterPFor("Cholesky") 

3468def _convert_cholesky(pfor_input): 

3469 t = pfor_input.stacked_input(0) 

3470 return wrap(linalg_ops.cholesky(t), True) 

3471 

3472 

3473@RegisterPFor("LogMatrixDeterminant") 

3474def _convert_log_matrix_determinant(pfor_input): 

3475 t = pfor_input.stacked_input(0) 

3476 return [wrap(x, True) for x in linalg_ops.log_matrix_determinant(t)] 

3477 

3478 

3479@RegisterPFor("MatrixInverse") 

3480def _convert_matrix_inverse(pfor_input): 

3481 t = pfor_input.stacked_input(0) 

3482 adjoint = pfor_input.get_attr("adjoint") 

3483 return wrap(gen_linalg_ops.matrix_inverse(t, adjoint=adjoint), True) 

3484 

3485 

3486@RegisterPFor("MatrixSolve") 

3487def _convert_matrix_solve(pfor_input): 

3488 pfor_input.stack_inputs() 

3489 matrix = pfor_input.stacked_input(0) 

3490 rhs = pfor_input.stacked_input(1) 

3491 adjoint = pfor_input.get_attr("adjoint") 

3492 output = gen_linalg_ops.matrix_solve( 

3493 matrix, rhs, adjoint=adjoint) 

3494 return wrap(output, True) 

3495 

3496 

3497@RegisterPFor("MatrixTriangularSolve") 

3498def _convert_matrix_triangular_solve(pfor_input): 

3499 pfor_input.expanddim_inputs_for_broadcast() 

3500 matrix = pfor_input.input(0)[0] 

3501 rhs = pfor_input.input(1)[0] 

3502 lower = pfor_input.get_attr("lower") 

3503 adjoint = pfor_input.get_attr("adjoint") 

3504 output = linalg_ops.matrix_triangular_solve( 

3505 matrix, rhs, lower=lower, adjoint=adjoint) 

3506 return wrap(output, True) 

3507 

3508 

3509@RegisterPFor("SelfAdjointEigV2") 

3510def _convert_self_adjoint_eig(pfor_input): 

3511 t = pfor_input.stacked_input(0) 

3512 compute_v = pfor_input.get_attr("compute_v") 

3513 e, v = gen_linalg_ops.self_adjoint_eig_v2(t, compute_v=compute_v) 

3514 # If compute_v is False, v will have shape [0]. 

3515 return wrap(e, True), wrap(v, compute_v) 

3516 

3517 

3518# logging_ops 

3519 

3520 

3521@RegisterPFor("Assert") 

3522def _convert_assert(pfor_input): 

3523 cond, cond_stacked, _ = pfor_input.input(0) 

3524 if cond_stacked: 

3525 cond = math_ops.reduce_all(cond) 

3526 

3527 data_list = [x.t for x in pfor_input.inputs][1:] 

3528 return _create_op( 

3529 "Assert", [cond] + data_list, [], attrs=pfor_input.op.node_def.attr) 

3530 

3531 

3532@RegisterPFor("Print") 

3533def _convert_print(pfor_input): 

3534 # Note that we don't stack all the inputs. Hence unstacked values are printed 

3535 # once here vs multiple times in a while_loop. 

3536 pfor_input.stack_inputs([0]) 

3537 outputs = _create_op( 

3538 "Print", [x.t for x in pfor_input.inputs], 

3539 [x.dtype for x in pfor_input.outputs], 

3540 attrs=pfor_input.op.node_def.attr).outputs 

3541 return [wrap(x, True) for x in outputs] 

3542 

3543 

3544@RegisterPFor("PrintV2") 

3545def _convert_print_v2(pfor_input): 

3546 # Print the full input Tensor(s), including the batch dimension if stacked. 

3547 return _create_op( 

3548 "PrintV2", [x.t for x in pfor_input.inputs], 

3549 [x.dtype for x in pfor_input.outputs], 

3550 attrs=pfor_input.op.node_def.attr) 

3551 

3552 

3553@RegisterPFor("StringFormat") 

3554def _convert_string_format(pfor_input): 

3555 # Format using the full input Tensor(s), including the batch dimension if 

3556 # stacked. 

3557 op = _create_op( 

3558 "StringFormat", [x.t for x in pfor_input.inputs], 

3559 [x.dtype for x in pfor_input.outputs], 

3560 attrs=pfor_input.op.node_def.attr) 

3561 return [wrap(output, False) for output in op.outputs] 

3562 

3563 

3564# data_flow_ops 

3565 

3566# TensorArray conversion is tricky since we don't support arrays of 

3567# TensorArrays. For converting them, we consider two distinct cases: 

3568# 

3569# 1. The array is constructed outside the pfor call, and read/written inside the 

3570# loop. 

3571# This is an easier case since we don't need to make an array of TensorArrays. 

3572# A correctness requirement is that these parallel iterations shouldn't attempt 

3573# to write to the same location. Hence at conversion time we disallow indices to 

3574# be loop-invariant as that would guarantee a collision. Even if the indices are 

3575# not loop-invariant, they could conflict and that shall trigger runtime errors. 

3576# 

3577# 2. The array is constructed and used entirely inside each pfor iteration. 

3578# For simplicity, here we require that the indices used for write/scatter are 

3579# "unstacked". Otherwise it becomes hard to merge the TensorArrays created in 

3580# different pfor iterations. We consider two sub_cases: 

3581# 

3582# 2a Elements written to the array are "stacked" 

3583# To simulate multiple TensorArrays, we may increase the dimension of each 

3584# element of the array. i.e. the i_th row of the j_th entry of the converted 

3585# TensorArray corresponds to the j_th entry of the TensorArray in the i_th 

3586# pfor iteration. 

3587# 

3588# 2b Elements written to the array are "unstacked" 

3589# In this case we don't increase the dimensions to avoid redundant tiling. Each 

3590# iteration is trying to write the same value. So we convert that to a single 

3591# write. 

3592# 

3593# Here are some tricks used to implement the above: 

3594# - TensorArrayV3 constructor encodes the element shape as an attr. Instead of 

3595# trying to trace whether future writes are stacked or unstacked in order to set 

3596# this attr, we set it to correspond to unknown shape. 

3597# - We use the "flow" output of the different ops to track whether the array 

3598# elements are stacked or unstacked. If a stacked write/scatter is done, we make 

3599# the flow stacked as well. 

3600# - We use some heuristic traversal of the graph to track whether the 

3601# TensorArray handle was created inside or outside the pfor loop. 

3602 

3603 

3604@RegisterPFor("TensorArrayV3") 

3605def _convert_tensor_array_v3(pfor_input): 

3606 size = pfor_input.unstacked_input(0) 

3607 dtype = pfor_input.get_attr("dtype") 

3608 dynamic_size = pfor_input.get_attr("dynamic_size") 

3609 clear_after_read = pfor_input.get_attr("clear_after_read") 

3610 identical_element_shapes = pfor_input.get_attr("identical_element_shapes") 

3611 tensor_array_name = pfor_input.get_attr("tensor_array_name") 

3612 handle, flow = data_flow_ops.tensor_array_v3( 

3613 size, 

3614 dtype=dtype, 

3615 # We don't set element shape since we don't know if writes are stacked or 

3616 # not yet. 

3617 element_shape=None, 

3618 dynamic_size=dynamic_size, 

3619 clear_after_read=clear_after_read, 

3620 identical_element_shapes=identical_element_shapes, 

3621 tensor_array_name=tensor_array_name) 

3622 # Note we keep flow unstacked for now since we don't know if writes will be 

3623 # stacked or not. 

3624 return wrap(handle, False), wrap(flow, False) 

3625 

3626 

3627@RegisterPFor("TensorArraySizeV3") 

3628def _convert_tensor_array_size_v3(pfor_input): 

3629 handle = pfor_input.unstacked_input(0) 

3630 flow, flow_stacked, _ = pfor_input.input(1) 

3631 if flow_stacked: 

3632 flow = _unstack_flow(flow) 

3633 size = data_flow_ops.tensor_array_size_v3(handle, flow) 

3634 return wrap(size, False) 

3635 

3636 

3637def _handle_inside_pfor(pfor_input, handle): 

3638 """Returns True if handle was created inside the pfor loop.""" 

3639 # We use some heuristic to find the original TensorArray creation op. 

3640 # The logic should handle the common cases (except cond based subgraphs). 

3641 # In theory the user could perform different operations on the handle (like 

3642 # Reshape, stack multiple handles, etc) which could break this logic. 

3643 # TODO(agarwal): handle Switch/Merge. 

3644 while handle.op.type in ("Enter", "Identity"): 

3645 handle = handle.op.inputs[0] 

3646 if handle.op.type not in [ 

3647 "TensorArrayV3", "TensorArrayGradV3", "TensorArrayGradWithShape" 

3648 ]: 

3649 raise ValueError(f"Unable to find source for handle {handle}.") 

3650 else: 

3651 return pfor_input.pfor.op_is_inside_loop(handle.op) 

3652 

3653 

3654def _unstack_flow(value): 

3655 # TODO(agarwal): consider looking if this is a Tile op then get its input. 

3656 # This may avoid running the Tile operations. 

3657 return array_ops.gather(value, 0) 

3658 

3659 

3660@RegisterPFor("TensorArrayReadV3") 

3661def _convert_tensor_array_read_v3(pfor_input): 

3662 handle = pfor_input.unstacked_input(0) 

3663 index, index_stacked, _ = pfor_input.input(1) 

3664 dtype = pfor_input.get_attr("dtype") 

3665 flow, flow_stacked, _ = pfor_input.input(2) 

3666 if flow_stacked: 

3667 flow = _unstack_flow(flow) 

3668 

3669 is_inside_pfor = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) 

3670 if is_inside_pfor: 

3671 # Note that if we are inside a control flow construct inside the pfor, and 

3672 # only some of the iterations are doing the read (i.e. 

3673 # `all_indices_partitioned` is True), then the read operation should only 

3674 # return values for the currently active pfor iterations (`all_indices` 

3675 # below). Hence, whenever the returned value is stacked (i.e. `flow` is 

3676 # stacked), we may need to do an extra gather after reading the values. Also 

3677 # note that if `is_inside` is false, then values in the tensor array are 

3678 # unstacked. So the check is only needed in this branch. 

3679 all_indices = pfor_input.pfor.all_indices 

3680 all_indices_partitioned = pfor_input.pfor.all_indices_partitioned 

3681 # Note: flow_stacked indicates if values in the TensorArray are stacked or 

3682 # not. 

3683 if index_stacked: 

3684 if flow_stacked: 

3685 raise ValueError( 

3686 "It looks like TensorArrayReadV3 was called on a TensorArray whose" 

3687 " values are not loop-invariant, and the read indices were also" 

3688 " not loop invariant. This is currently unsupported.") 

3689 value = data_flow_ops.tensor_array_gather_v3( 

3690 handle, index, flow, dtype=dtype) 

3691 return wrap(value, True) 

3692 value = data_flow_ops.tensor_array_read_v3(handle, index, flow, dtype=dtype) 

3693 if flow_stacked and all_indices_partitioned: 

3694 value = array_ops.gather(value, all_indices) 

3695 return wrap(value, flow_stacked) 

3696 # Values in the TensorArray should be unstacked (since different iterations 

3697 # couldn't write to the same location). So whether output is stacked or not 

3698 # depends on index_stacked. 

3699 if index_stacked: 

3700 value = data_flow_ops.tensor_array_gather_v3( 

3701 handle, index, flow, dtype=dtype) 

3702 else: 

3703 value = data_flow_ops.tensor_array_read_v3(handle, index, flow, dtype=dtype) 

3704 return wrap(value, index_stacked) 

3705 

3706 

3707@RegisterPFor("TensorArrayWriteV3") 

3708def _convert_tensor_array_write_v3(pfor_input): 

3709 handle = pfor_input.unstacked_input(0) 

3710 index, index_stacked, _ = pfor_input.input(1) 

3711 value, value_stacked, _ = pfor_input.input(2) 

3712 flow, flow_stacked, _ = pfor_input.input(3) 

3713 if value_stacked and pfor_input.pfor.all_indices_partitioned: 

3714 # Looks like we are in a control flow in a pfor where not all iterations are 

3715 # active now. We don't allow that since that could lead to different indices 

3716 # having different shapes which will be hard to merge later. 

3717 raise ValueError("Writing non loop invariant values to TensorArray from " 

3718 "inside a while_loop/cond not supported.") 

3719 if flow_stacked: 

3720 flow = _unstack_flow(flow) 

3721 is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) 

3722 if is_inside: 

3723 if index_stacked: 

3724 raise ValueError(f"Need indices for {handle} to be loop invariant.") 

3725 if not flow_stacked and not value_stacked: 

3726 flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow) 

3727 return wrap(flow_out, False) 

3728 else: 

3729 if not value_stacked: 

3730 value = _stack(value, pfor_input.pfor.loop_len_vector).t 

3731 # TODO(agarwal): Note that if flow is unstacked and value is stacked, then 

3732 # this may or may not be a safe situation. flow is unstacked both for a 

3733 # freshly created TensorArray, as well as after unstacked values are 

3734 # written to it. If it is the latter, then we cannot write a stacked value 

3735 # now since that may cause runtime errors due to different shapes in the 

3736 # array. At the moment we are not able to handle this gracefully and 

3737 # distinguish between the two cases. That would require some heuristic 

3738 # traversal of the graph to figure out whether all the writes are 

3739 # unstacked or not. 

3740 flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow) 

3741 return _stack(flow_out, pfor_input.pfor.loop_len_vector) 

3742 else: 

3743 if not index_stacked: 

3744 raise ValueError(f"Need indices for {handle} to be not loop invariant.") 

3745 # Note that even when index_stacked is true, actual values in index may 

3746 # still not be unique. However that will cause runtime error when executing 

3747 # the scatter operation below. 

3748 if not value_stacked: 

3749 value = _stack(value, pfor_input.pfor.loop_len_vector).t 

3750 flow_out = data_flow_ops.tensor_array_scatter_v3(handle, index, value, flow) 

3751 return _stack(flow_out, pfor_input.pfor.loop_len_vector) 

3752 

3753 

3754def _transpose_first_two_dims(value): 

3755 # TODO(agarwal): optimize if one of the dims == 1. 

3756 value_shape = array_ops.shape(value) 

3757 v0 = value_shape[0] 

3758 v1 = value_shape[1] 

3759 value = array_ops.reshape(value, [v0, v1, -1]) 

3760 value = array_ops.transpose(value, [1, 0, 2]) 

3761 new_shape = array_ops.concat([[v1, v0], value_shape[2:]], axis=0) 

3762 return array_ops.reshape(value, new_shape) 

3763 

3764 

3765@RegisterPFor("TensorArrayGatherV3") 

3766def _convert_tensor_array_gather_v3(pfor_input): 

3767 handle = pfor_input.unstacked_input(0) 

3768 indices, indices_stacked, _ = pfor_input.input(1) 

3769 indices = array_ops.reshape(indices, [-1]) 

3770 flow, flow_stacked, _ = pfor_input.input(2) 

3771 if flow_stacked: 

3772 flow = _unstack_flow(flow) 

3773 dtype = pfor_input.get_attr("dtype") 

3774 # TODO(agarwal): support element_shape attr? 

3775 

3776 n = pfor_input.pfor.loop_len_vector 

3777 value = data_flow_ops.tensor_array_gather_v3( 

3778 handle, indices, flow, dtype=dtype) 

3779 is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) 

3780 if is_inside: 

3781 # flow_stacked indicates if values in the TensorArray are stacked or not. 

3782 if indices_stacked: 

3783 if flow_stacked: 

3784 raise ValueError( 

3785 "It looks like TensorArrayGatherV3 was called on a TensorArray " 

3786 "whose values are not loop-invariant, and the indices were also " 

3787 "not loop invariant. This is currently unsupported.") 

3788 else: 

3789 value = _unflatten_first_dim(value, n) 

3790 return wrap(value, True) 

3791 else: 

3792 if flow_stacked: 

3793 # Since elements in this array are stacked and `value` was produced by 

3794 # gather, its first two dims are "gathered elements" and "stack 

3795 # dimension". Our semantics require these two to be flipped. 

3796 value = _transpose_first_two_dims(value) 

3797 return wrap(value, flow_stacked) 

3798 else: 

3799 # Values in the TensorArray should be unstacked (since different iterations 

3800 # couldn't write to the same location). So whether output is stacked or not 

3801 # depends on indices_stacked. 

3802 if indices_stacked: 

3803 value = _unflatten_first_dim(value, n) 

3804 return wrap(value, indices_stacked) 

3805 

3806 

3807@RegisterPFor("TensorArrayScatterV3") 

3808def _convert_tensor_array_scatter_v3(pfor_input): 

3809 handle = pfor_input.unstacked_input(0) 

3810 indices, indices_stacked, _ = pfor_input.input(1) 

3811 indices = array_ops.reshape(indices, [-1]) 

3812 value, value_stacked, _ = pfor_input.input(2) 

3813 flow, flow_stacked, _ = pfor_input.input(3) 

3814 

3815 if flow_stacked: 

3816 flow = _unstack_flow(flow) 

3817 

3818 is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) 

3819 if is_inside: 

3820 if indices_stacked: 

3821 raise ValueError(f"Need indices for {handle} to be loop invariant.") 

3822 # Note that flow_stacked indicates if existing values in the array are 

3823 # stacked or not. 

3824 if not flow_stacked and not value_stacked: 

3825 flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, 

3826 flow) 

3827 return wrap(flow_out, False) 

3828 if not value_stacked: 

3829 # TODO(agarwal): tile in the second dimension directly instead of 

3830 # transposing below. 

3831 value = _stack(value, pfor_input.pfor.loop_len_vector).t 

3832 

3833 value = _transpose_first_two_dims(value) 

3834 # TODO(agarwal): Note that if a previous write was unstacked, flow will be 

3835 # unstacked, and a stacked value may be written here which may cause 

3836 # runtime error due to different elements having different shape. We do 

3837 # not try to prevent that. 

3838 flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, 

3839 flow) 

3840 return _stack(flow_out, pfor_input.pfor.loop_len_vector) 

3841 if not indices_stacked: 

3842 raise ValueError(f"Need indices for {handle} to be not loop invariant.") 

3843 if not value_stacked: 

3844 value = _stack(value, pfor_input.pfor.loop_len_vector).t 

3845 value = _flatten_first_two_dims(value) 

3846 flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, flow) 

3847 return _stack(flow_out, pfor_input.pfor.loop_len_vector) 

3848 

3849 

3850@RegisterPFor("TensorArrayGradV3") 

3851def _convert_tensor_array_grad_v3(pfor_input): 

3852 handle = pfor_input.unstacked_input(0) 

3853 flow, flow_stacked, _ = pfor_input.input(1) 

3854 if flow_stacked: 

3855 flow = _unstack_flow(flow) 

3856 source = pfor_input.get_attr("source") 

3857 # TODO(agarwal): For now, we assume that gradients are stacked if the 

3858 # TensorArrayGradV3 call is being done inside the pfor. Getting that wrong 

3859 # will give runtime error due to incorrect shape being written to the 

3860 # accumulator. It is difficult to know in advance if gradients written will be 

3861 # stacked or not. Note that flow being stacked is not indicative of the 

3862 # gradient being stacked or not. Revisit this later. 

3863 shape_to_prepend = pfor_input.pfor.loop_len_vector 

3864 grad_handle, flow_out = data_flow_ops.tensor_array_grad_with_shape( 

3865 handle=handle, 

3866 flow_in=flow, 

3867 shape_to_prepend=shape_to_prepend, 

3868 source=source) 

3869 flow_out = _stack(flow_out, pfor_input.pfor.loop_len_vector).t 

3870 return [wrap(grad_handle, False), wrap(flow_out, True)] 

3871 

3872 

3873def _stack_tensor_list_shape(shape, first_dim): 

3874 shape_value = tensor_util.constant_value(shape) 

3875 # Note that negative values in the shape are used to signify unknown shapes 

3876 # and are handled in a special way. 

3877 if shape_value is not None: 

3878 shape_value = np.asarray(shape_value) 

3879 if -1 in shape_value: 

3880 return constant_op.constant(-1) 

3881 elif not shape_value.size: 

3882 return first_dim 

3883 else: 

3884 shape = array_ops.reshape(shape, [-1]) 

3885 return tf_cond.cond( 

3886 math_ops.reduce_any(shape < 0), 

3887 lambda: constant_op.constant(-1), 

3888 lambda: array_ops.concat([first_dim, shape], axis=0)) 

3889 

3890 

3891def _tile_variant_with_length(t, length): 

3892 """stacks `t` `length` times.""" 

3893 if _is_variant_with_internal_stacking(t): 

3894 # The content of TensorLists is vectorized, not the variant itself. 

3895 return t 

3896 original_tensor = t 

3897 t.set_shape([]) 

3898 t = array_ops.reshape(t, [-1]) 

3899 with ops.device("CPU:0"): 

3900 result = array_ops.tile(t, length) 

3901 # TODO(b/169968286): Should regular shape functions do handle data 

3902 # propagation here? 

3903 handle_data_util.copy_handle_data(original_tensor, result) 

3904 return result 

3905 

3906 

3907def _tile_variant(t, pfor_input): 

3908 """stacks `t` according to its loop context.""" 

3909 return _tile_variant_with_length(t, pfor_input.pfor.loop_len_vector) 

3910 

3911 

3912def _untile_variant(t): 

3913 if _is_variant_with_internal_stacking(t): 

3914 # The content of TensorLists is vectorized, not the variant itself. 

3915 if not t.shape.is_compatible_with([]): 

3916 raise AssertionError( 

3917 ("Unexpectedly saw a vectorized variant (e.g. TensorList) with " 

3918 f"non-scalar shape: {t!r}")) 

3919 return t 

3920 return array_ops.gather(t, 0) 

3921 

3922 

3923@RegisterPFor("OptionalFromValue") 

3924def _convert_optional_from_value(pfor_input): 

3925 pfor_input.stack_inputs() 

3926 return wrap( 

3927 gen_optional_ops.optional_from_value([x.t for x in pfor_input.inputs]), 

3928 True, 

3929 ) 

3930 

3931 

3932@RegisterPFor("OptionalGetValue") 

3933def _convert_optional_get_value(pfor_input): 

3934 handle = pfor_input.stacked_input(0) 

3935 output_types = pfor_input.get_attr("output_types") 

3936 original_output_shapes = pfor_input.get_attr("output_shapes") 

3937 output_shapes = [] 

3938 for shape in original_output_shapes: 

3939 shape = tensor_shape.TensorShape(shape) 

3940 loop_len_value = tensor_util.constant_value(pfor_input.pfor.loop_len_vector) 

3941 loop_len_shape = tensor_shape.TensorShape( 

3942 [loop_len_value[0] if loop_len_value is not None else None] 

3943 ) 

3944 shape = loop_len_shape.concatenate(shape) 

3945 output_shapes.append(shape.as_proto()) 

3946 results = gen_optional_ops.optional_get_value( 

3947 handle, output_types, output_shapes 

3948 ) 

3949 return [wrap(t, True) for t in results] 

3950 

3951 

3952@RegisterPFor("TensorListReserve") 

3953def _convert_tensor_list_reserve(pfor_input): 

3954 element_shape = pfor_input.unstacked_input(0) 

3955 num_elements = pfor_input.unstacked_input(1) 

3956 element_dtype = pfor_input.get_attr("element_dtype") 

3957 

3958 # Prepend a dimension to element_shape. 

3959 element_shape = _stack_tensor_list_shape(element_shape, 

3960 pfor_input.pfor.loop_len_vector) 

3961 handle = list_ops.tensor_list_reserve( 

3962 element_shape, num_elements, element_dtype=element_dtype) 

3963 

3964 return wrap(_tile_variant(handle, pfor_input), True) 

3965 

3966 

3967@RegisterPFor("TensorListElementShape") 

3968def _convert_tensor_list_element_shape(pfor_input): 

3969 handle = _untile_variant(pfor_input.stacked_input(0)) 

3970 shape_type = pfor_input.get_attr("shape_type") 

3971 shape = list_ops.tensor_list_element_shape(handle, shape_type) 

3972 shape = array_ops.reshape(shape, [-1]) 

3973 shape = shape[1:] 

3974 return wrap(shape, False) 

3975 

3976 

3977@RegisterPFor("TensorListLength") 

3978def _convert_tensor_list_length(pfor_input): 

3979 handle = _untile_variant(pfor_input.stacked_input(0)) 

3980 return wrap(list_ops.tensor_list_length(handle), False) 

3981 

3982 

3983def _stack_tensor_list(handle, dtype, loop_len_vector, element_shape=None): 

3984 if element_shape is None: 

3985 element_shape = list_ops.tensor_list_element_shape(handle, dtypes.int32) 

3986 length = list_ops.tensor_list_length(handle) 

3987 new_handle = list_ops.tensor_list_reserve( 

3988 _stack_tensor_list_shape(element_shape, loop_len_vector), length, dtype) 

3989 

3990 def _body_fn(i, h): 

3991 elem = list_ops.tensor_list_get_item(handle, i, dtype, element_shape) 

3992 elem = _stack(elem, loop_len_vector).t 

3993 return i + 1, list_ops.tensor_list_set_item(h, i, elem) 

3994 

3995 return while_loop.while_loop(lambda i, _: i < length, _body_fn, 

3996 [0, new_handle])[1] 

3997 

3998 

3999@RegisterPFor("TensorListGetItem") 

4000def _convert_tensor_list_get_item(pfor_input): 

4001 handle, handle_stacked, _ = pfor_input.input(0) 

4002 index, index_stacked, _ = pfor_input.input(1) 

4003 element_shape = pfor_input.unstacked_input(2) 

4004 element_dtype = pfor_input.get_attr("element_dtype") 

4005 

4006 if handle_stacked: 

4007 handle = _untile_variant(handle) 

4008 element_shape = _stack_tensor_list_shape(element_shape, 

4009 pfor_input.pfor.loop_len_vector) 

4010 if index_stacked: 

4011 # We use a sequential loop since that may be more efficient than first 

4012 # gathering and concatenating all the element corresponding to `index`, 

4013 # and then doing a gather on it. 

4014 def _map_fn(i): 

4015 item_i = list_ops.tensor_list_get_item( 

4016 handle, 

4017 index[i], 

4018 element_dtype=element_dtype) 

4019 return array_ops.gather(item_i, i) 

4020 

4021 output = map_fn.map_fn(_map_fn, pfor_input.pfor.all_indices) 

4022 return wrap(output, True) 

4023 else: 

4024 output = list_ops.tensor_list_get_item( 

4025 handle, 

4026 index, 

4027 element_shape=element_shape, 

4028 element_dtype=element_dtype) 

4029 return wrap(output, True) 

4030 else: 

4031 assert index_stacked 

4032 return wrap( 

4033 list_ops.tensor_list_gather( 

4034 handle, 

4035 index, 

4036 element_shape=element_shape, 

4037 element_dtype=element_dtype), True) 

4038 

4039 

4040@RegisterPFor("TensorListSetItem") 

4041def _convert_tensor_array_set_item(pfor_input): 

4042 handle, handle_stacked, _ = pfor_input.input(0) 

4043 index, index_stacked, _ = pfor_input.input(1) 

4044 item, item_stacked, _ = pfor_input.input(2) 

4045 

4046 if not handle_stacked: 

4047 # Special case where we can statically guarantee that the indices are 

4048 # disjoint. 

4049 if index is pfor_input.pfor.all_indices: 

4050 if not item_stacked: 

4051 item = _stack(item, pfor_input.pfor.loop_len_vector).t 

4052 return wrap( 

4053 list_ops.tensor_list_scatter(item, index, input_handle=handle), False) 

4054 else: 

4055 handle = _stack_tensor_list(handle, item.dtype, 

4056 pfor_input.pfor.loop_len_vector) 

4057 else: 

4058 handle = _untile_variant(handle) 

4059 

4060 if index_stacked: 

4061 # TODO(agarwal): handle this. 

4062 raise ValueError("Vectorizing writes to a TensorList with loop " 

4063 "variant indices is currently unsupported.") 

4064 

4065 else: 

4066 if not item_stacked: 

4067 item = _stack(item, pfor_input.pfor.loop_len_vector).t 

4068 handle = list_ops.tensor_list_set_item(handle, index, item) 

4069 return wrap(_tile_variant(handle, pfor_input), True) 

4070 

4071 

4072@RegisterPFor("TensorListPushBack") 

4073def _convert_tensor_list_push_back(pfor_input): 

4074 handle, handle_stacked, _ = pfor_input.input(0) 

4075 tensor, tensor_stacked, _ = pfor_input.input(1) 

4076 if handle_stacked: 

4077 handle = _untile_variant(handle) 

4078 else: 

4079 handle = _stack_tensor_list(handle, tensor.dtype, 

4080 pfor_input.pfor.loop_len_vector) 

4081 if not tensor_stacked: 

4082 tensor = _stack(tensor, pfor_input.pfor.loop_len_vector).t 

4083 handle = list_ops.tensor_list_push_back(handle, tensor) 

4084 return wrap(_tile_variant(handle, pfor_input), True) 

4085 

4086 

4087@RegisterPFor("TensorListPopBack") 

4088def _convert_tensor_array_push_back(pfor_input): 

4089 handle = pfor_input.stacked_input(0) 

4090 element_shape = pfor_input.unstacked_input(1) 

4091 handle = _untile_variant(handle) 

4092 

4093 if element_shape.shape.ndims == 0: 

4094 # Default / unspecified 

4095 vectorized_shape = -1 

4096 else: 

4097 # PopBack has an element shape set when it's the gradient of PushBack, only 

4098 # used when the list is uninitialized. 

4099 vectorized_shape = array_ops.concat( 

4100 [pfor_input.pfor.loop_len_vector, element_shape], axis=0) 

4101 

4102 output_handle, tensor = gen_list_ops.tensor_list_pop_back( 

4103 input_handle=handle, element_dtype=pfor_input.get_attr("element_dtype"), 

4104 element_shape=vectorized_shape) 

4105 return wrap(output_handle, True), wrap(tensor, True) 

4106 

4107 

4108@RegisterPFor("TensorListConcatV2") 

4109def _convert_tensor_list_concat_v2(pfor_input): 

4110 input_handle = pfor_input.stacked_input(0) 

4111 element_shape = pfor_input.unstacked_input(1) 

4112 leading_dims = pfor_input.unstacked_input(2) 

4113 element_dtype = pfor_input.get_attr("element_dtype") 

4114 

4115 handle = _untile_variant(input_handle) 

4116 length = list_ops.tensor_list_length(handle) 

4117 # Note that element_shape attribute can have incomplete shapes. This doesn't 

4118 # seem to work well when creating another list and then doing a concat on it. 

4119 # Hence we try to find the dynamic shape here. 

4120 element_shape = tf_cond.cond( 

4121 length > 0, lambda: array_ops.shape( 

4122 list_ops.tensor_list_get_item(handle, 0, element_dtype, None)), 

4123 lambda: constant_op.constant([0, 0], dtype=dtypes.int32)) 

4124 # The code below creates a copy of the list with each elements' first two 

4125 # dimensions transposed. 

4126 new_element_shape = array_ops.concat( 

4127 [element_shape[1:2], element_shape[0:1], element_shape[2:]], axis=0) 

4128 

4129 # Create a new TensorList with elements transposed. 

4130 def _transpose_elem(i, h): 

4131 elem = list_ops.tensor_list_get_item(handle, i, element_dtype, None) 

4132 elem = _transpose_first_two_dims(elem) 

4133 return i + 1, list_ops.tensor_list_set_item(h, i, elem) 

4134 

4135 new_handle = list_ops.tensor_list_reserve(new_element_shape, length, 

4136 element_dtype) 

4137 new_handle = while_loop.while_loop(lambda i, _: i < length, _transpose_elem, 

4138 [0, new_handle])[1] 

4139 output, lengths = gen_list_ops.tensor_list_concat_v2( 

4140 input_handle=new_handle, 

4141 element_dtype=element_dtype, 

4142 element_shape=new_element_shape, 

4143 leading_dims=leading_dims) 

4144 output = _transpose_first_two_dims(output) 

4145 return wrap(output, True), wrap(lengths, False) 

4146 

4147 

4148@RegisterPFor("TensorListStack") 

4149def _convert_tensor_list_stack(pfor_input): 

4150 handle = pfor_input.stacked_input(0) 

4151 input_shape = pfor_input.unstacked_input(1) 

4152 element_dtype = pfor_input.get_attr("element_dtype") 

4153 num_elements = pfor_input.get_attr("num_elements") 

4154 

4155 handle = _untile_variant(handle) 

4156 input_shape = _stack_tensor_list_shape(input_shape, 

4157 pfor_input.pfor.loop_len_vector) 

4158 output = list_ops.tensor_list_stack( 

4159 handle, 

4160 element_dtype, 

4161 element_shape=input_shape, 

4162 num_elements=num_elements) 

4163 output = _transpose_first_two_dims(output) 

4164 return wrap(output, True) 

4165 

4166 

4167@RegisterPFor("TensorListGather") 

4168def _convert_tensor_list_gather(pfor_input): 

4169 handle, handle_stacked, _ = pfor_input.input(0) 

4170 index, index_stacked, _ = pfor_input.input(1) 

4171 element_shape = pfor_input.unstacked_input(2) 

4172 element_dtype = pfor_input.get_attr("element_dtype") 

4173 

4174 if handle_stacked: 

4175 handle = _untile_variant(handle) 

4176 element_shape = _stack_tensor_list_shape(element_shape, 

4177 pfor_input.pfor.loop_len_vector) 

4178 if index_stacked: 

4179 # We use a sequential loop since that may be more efficient than first 

4180 # gathering and concatenating all the element corresponding to `index`, 

4181 # and then doing a gather on it. 

4182 def _map_fn(i): 

4183 item_i = list_ops.tensor_list_gather( 

4184 handle, 

4185 index[i], 

4186 element_dtype=element_dtype) 

4187 axis = array_ops.rank(index) - 1 

4188 return array_ops.gather(item_i, i, axis=axis) 

4189 

4190 output = map_fn.map_fn(_map_fn, pfor_input.pfor.all_indices) 

4191 return wrap(output, True) 

4192 else: 

4193 output = list_ops.tensor_list_gather( 

4194 handle, 

4195 index, 

4196 element_shape=element_shape, 

4197 element_dtype=element_dtype) 

4198 return wrap(output, True) 

4199 else: 

4200 assert index_stacked 

4201 index_shape = array_ops.shape(index) 

4202 index = array_ops.reshape(index, [-1]) 

4203 values = list_ops.tensor_list_gather( 

4204 handle, index, element_shape=element_shape, element_dtype=element_dtype) 

4205 final_shape = array_ops.concat( 

4206 [index_shape, array_ops.shape(values)[1:]], axis=0) 

4207 return wrap(array_ops.reshape(values, final_shape), True) 

4208 

4209 

4210@RegisterPFor("TensorListScatterIntoExistingList") 

4211def _convert_tensor_list_scatter(pfor_input): 

4212 pfor_input.stack_inputs([1]) 

4213 handle, handle_stacked, _ = pfor_input.input(0) 

4214 item = pfor_input.stacked_input(1) 

4215 indices, indices_stacked, _ = pfor_input.input(2) 

4216 if handle_stacked: 

4217 handle = _untile_variant(handle) 

4218 else: 

4219 handle = _stack_tensor_list(handle, item.dtype, 

4220 pfor_input.pfor.loop_len_vector) 

4221 

4222 item = _transpose_first_two_dims(item) 

4223 if indices_stacked: 

4224 # Pretend the list is a dense tensor: 

4225 # list_as_dense: Tensor[list_len, loop_len, ...] 

4226 # And indices are a tensor with shape (before transpose): 

4227 # indices: Tensor[loop_len, num_scatters] 

4228 # The item to scatter has shape (before transpose): 

4229 # item: Tensor[loop_len, num_scatters, ...] 

4230 # 

4231 # We want list_as_dense[indices[i, j], i] = item[i, j] 

4232 # 

4233 # Since we're not just indexing along the first axis of `list_as_dense`, we 

4234 # need to first extract the relevant list entries based on `indices`, 

4235 # scatter into them according to the loop index, and re-scatter the chunks 

4236 # we updated back into the list. 

4237 indices = _transpose_first_two_dims(indices) 

4238 indices_flat = array_ops.reshape(indices, [-1]) 

4239 # In many cases `indices` will be unique across pfor iterations, but this is 

4240 # not guaranteed. If there are duplicates, we need to map multiple updates 

4241 # to a single chunk extracted from the list. The last update should win. 

4242 unique_indices = array_ops.unique(indices_flat) 

4243 gathered_items = list_ops.tensor_list_gather( 

4244 handle, unique_indices.y, element_dtype=item.dtype, 

4245 element_shape=array_ops.shape(item)[1:]) 

4246 loop_idx = math_ops.range(pfor_input.pfor.loop_len_vector[0]) 

4247 scatters_per_op = array_ops.shape(indices)[0] 

4248 

4249 unique_indices_loop_idx = array_ops.reshape(array_ops.tile( 

4250 loop_idx[None, :], [scatters_per_op, 1]), [-1]) 

4251 scatter_indices = array_ops_stack.stack( 

4252 [unique_indices.idx, unique_indices_loop_idx], 

4253 axis=1) 

4254 # This op does *not* guarantee last-update-wins on GPU, so semantics may not 

4255 # be exactly preserved for duplicate updates there. 

4256 scattered = array_ops.tensor_scatter_nd_update( 

4257 tensor=gathered_items, 

4258 indices=scatter_indices, 

4259 updates=_flatten_first_two_dims(item)) 

4260 handle = list_ops.tensor_list_scatter( 

4261 scattered, unique_indices.y, input_handle=handle) 

4262 else: 

4263 handle = list_ops.tensor_list_scatter(item, indices, input_handle=handle) 

4264 return wrap(_tile_variant(handle, pfor_input), True) 

4265 

4266 

4267@RegisterPFor("TensorListFromTensor") 

4268def _convert_tensor_list_from_tensor(pfor_input): 

4269 tensor = pfor_input.stacked_input(0) 

4270 element_shape = pfor_input.unstacked_input(1) 

4271 tensor = _transpose_first_two_dims(tensor) 

4272 element_shape = _stack_tensor_list_shape(element_shape, 

4273 pfor_input.pfor.loop_len_vector) 

4274 handle = list_ops.tensor_list_from_tensor(tensor, element_shape) 

4275 return wrap(_tile_variant(handle, pfor_input), True) 

4276 

4277 

4278@RegisterPFor("TensorScatterUpdate") 

4279def _convert_tensor_scatter_update(pfor_input): 

4280 pfor_input.stack_inputs([0, 1, 2]) 

4281 tensor = pfor_input.stacked_input(0) 

4282 indices = pfor_input.stacked_input(1) 

4283 updates = pfor_input.stacked_input(2) 

4284 

4285 indices_shape = array_ops.shape(indices) 

4286 indices_rank = array_ops.rank(indices) 

4287 loop_length = indices_shape[0] 

4288 

4289 # Create a loop count range and extend its dimensions to match `indices`. 

4290 loop_count_shape = array_ops.tensor_scatter_nd_update( 

4291 array_ops.ones([indices_rank], dtype=dtypes.int32), [[0]], [loop_length]) 

4292 loop_count = array_ops.reshape(math_ops.range(loop_length), loop_count_shape) 

4293 

4294 # Tile the loop count range for the batch dimensions (all except the first and 

4295 # last dimensions of indices). 

4296 # Rank(indices) >= 3 always for this function so we always have at least 1. 

4297 tile_multiplier = array_ops.tensor_scatter_nd_update( 

4298 indices_shape, [[0], [indices_rank - 1]], [1, 1]) 

4299 meta_index = array_ops.tile(loop_count, tile_multiplier) 

4300 

4301 # Insert the loop-identifying index. 

4302 indices = array_ops.concat([meta_index, indices], axis=-1) 

4303 

4304 result = array_ops.tensor_scatter_nd_update(tensor, indices, updates) 

4305 return wrap(result, True) 

4306 

4307# StackV2 conversion is tricky since we don't have arrays of StackV2. So similar 

4308# to TensorArrays, we convert them by changing the dimension of the elements 

4309# inside the stack. 

4310# 

4311# We consider two cases: 

4312# 

4313# 1. StackV2 is constructed and used entirely inside the pfor loop. 

4314# We keep a single Stack and perform the push/pop operations of all the 

4315# iterations in lock-step. We also assume that all the iterations perform these 

4316# operations. In case of dynamic control flow, if only some of the iterations 

4317# try to perform a push/pop, then the conversion may not work correctly and may 

4318# cause undefined behavior. 

4319# TODO(agarwal): test StackV2 with dynamic control flow. 

4320# 

4321# 2. StackV2 is constructed outside the pfor loop. 

4322# Performing stack push/pop in a parallel fashion is ill-defined. However given 

4323# that reading stacks created externally is a common operation when computing 

4324# jacobians, we provide some special semantics here as follows. 

4325# - disallow push operations to the stack 

4326# - pop operations are performed in lock step by all iterations, similar to the 

4327# case when the stack is created inside. A single value is popped during the 

4328# lock-step operation and broadcast to all the iterations. Values in the stack 

4329# are assumed to be loop-invariant. 

4330# 

4331# Some other implementation details: 

4332# We use an ugly logic to find whether values in Stack data structure are 

4333# loop invariant or not. When converting push/pop operations, we keep track of 

4334# whether the last conversion used a stacked value or not (see _stack_cache 

4335# below). As a result if an unstacked value is written first, subsequent stacked 

4336# writes are disallowed when they could have been allowed in theory. 

4337 

4338# Map from cache key based on StackV2 handle to a bool indicating whether values 

4339# are stacked or not. 

4340# TODO(agarwal): move _stack_cache inside pfor? 

4341_stack_cache = {} 

4342 

4343 

4344def _stack_cache_key(pfor_input): 

4345 """Create cache key corresponding to a stack handle.""" 

4346 op_type = pfor_input.op_type 

4347 assert op_type in ["StackPushV2", "StackPopV2"], op_type 

4348 orig_handle = pfor_input.op.inputs[0] 

4349 while orig_handle.op.type in ["Identity", "Enter"]: 

4350 orig_handle = orig_handle.op.inputs[0] 

4351 assert orig_handle.op.type == "StackV2", orig_handle.op 

4352 return ops.get_default_graph(), pfor_input.pfor, orig_handle 

4353 

4354 

4355def _stack_handle_inside_pfor(handle, pfor_input): 

4356 while handle.op.type in ["Identity", "Enter"]: 

4357 handle = handle.op.inputs[0] 

4358 assert handle.op.type == "StackV2", ("Unable to find StackV2 op. Got %s" % 

4359 handle.op) 

4360 return pfor_input.pfor.op_is_inside_loop(handle.op) 

4361 

4362 

4363@RegisterPFor("StackPushV2") 

4364def _convert_stack_push_v2(pfor_input): 

4365 handle = pfor_input.unstacked_input(0) 

4366 elem, elem_stacked, _ = pfor_input.input(1) 

4367 swap_memory = pfor_input.get_attr("swap_memory") 

4368 

4369 if not _stack_handle_inside_pfor(pfor_input.op.inputs[0], pfor_input): 

4370 raise ValueError("StackPushV2 not allowed on stacks created outside pfor.") 

4371 stack_cache_key = _stack_cache_key(pfor_input) 

4372 stacked = _stack_cache.get(stack_cache_key, None) 

4373 if stacked is None: 

4374 stacked = elem_stacked 

4375 _stack_cache[stack_cache_key] = stacked 

4376 else: 

4377 # If we previously made it unstacked then we can't revert to being stacked. 

4378 if not stacked and elem_stacked: 

4379 raise ValueError( 

4380 "It looks like the stack was previously determined to be loop " 

4381 "invariant, but we are now trying to push a loop dependent value " 

4382 "to it. This is currently unsupported.") 

4383 if stacked and not elem_stacked: 

4384 elem = _stack(elem, pfor_input.pfor.loop_len_vector).t 

4385 out = data_flow_ops.stack_push_v2(handle, elem, swap_memory=swap_memory) 

4386 return wrap(out, stacked) 

4387 

4388 

4389# Note that inputs to this convertor will be unstacked. However it should get 

4390# called since it is a stateful op. 

4391@RegisterPFor("StackPopV2") 

4392def _convert_stack_pop_v2(pfor_input): 

4393 handle = pfor_input.unstacked_input(0) 

4394 stack_cache_key = _stack_cache_key(pfor_input) 

4395 stacked = _stack_cache.get(stack_cache_key, None) 

4396 # If a StackPushV2 has not been converted yet, we default to unstacked since 

4397 # the push could be outside of pfor, or the convertor may not be called if the 

4398 # inputs are unconverted. 

4399 if stacked is None: 

4400 stacked = False 

4401 _stack_cache[stack_cache_key] = False 

4402 elem_type = pfor_input.get_attr("elem_type") 

4403 out = data_flow_ops.stack_pop_v2(handle, elem_type) 

4404 return wrap(out, stacked) 

4405 

4406 

4407# parsing_ops 

4408 

4409 

4410@RegisterPFor("DecodeCSV") 

4411def _convert_decode_csv(pfor_input): 

4412 lines = pfor_input.stacked_input(0) 

4413 record_defaults = [ 

4414 pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs) 

4415 ] 

4416 field_delim = pfor_input.get_attr("field_delim") 

4417 use_quote_delim = pfor_input.get_attr("use_quote_delim") 

4418 select_cols = pfor_input.get_attr("select_cols") 

4419 if not select_cols: 

4420 select_cols = None 

4421 return [ 

4422 wrap(t, True) for t in parsing_ops.decode_csv( 

4423 lines, 

4424 record_defaults, 

4425 field_delim=field_delim, 

4426 use_quote_delim=use_quote_delim, 

4427 select_cols=select_cols) 

4428 ] 

4429 

4430 

4431@RegisterPFor("ParseSingleExample") 

4432def _convert_parse_single_example(pfor_input): 

4433 serialized = pfor_input.stacked_input(0) 

4434 dense_defaults = [ 

4435 pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs) 

4436 ] 

4437 sparse_keys = pfor_input.get_attr("sparse_keys") 

4438 dense_keys = pfor_input.get_attr("dense_keys") 

4439 sparse_types = pfor_input.get_attr("sparse_types") 

4440 dense_shapes = pfor_input.get_attr("dense_shapes") 

4441 output = gen_parsing_ops.parse_example( 

4442 serialized=serialized, 

4443 names=[], 

4444 dense_defaults=dense_defaults, 

4445 sparse_keys=sparse_keys, 

4446 dense_keys=dense_keys, 

4447 sparse_types=sparse_types, 

4448 dense_shapes=dense_shapes) 

4449 return [wrap(t, True, True) for t in nest.flatten(output)] 

4450 

4451 

4452@RegisterPFor("ParseExampleV2") 

4453def _convert_parse_example_v2(pfor_input): 

4454 serialized = pfor_input.stacked_input(0) 

4455 sparse_keys = pfor_input.unstacked_input(2) 

4456 dense_keys = pfor_input.unstacked_input(3) 

4457 ragged_keys = pfor_input.unstacked_input(4) 

4458 dense_defaults = [ 

4459 pfor_input.unstacked_input(i) for i in range(5, pfor_input.num_inputs) 

4460 ] 

4461 num_sparse = pfor_input.get_attr("num_sparse") 

4462 sparse_types = pfor_input.get_attr("sparse_types") 

4463 ragged_value_types = pfor_input.get_attr("ragged_value_types") 

4464 ragged_split_types = pfor_input.get_attr("ragged_split_types") 

4465 dense_shapes = pfor_input.get_attr("dense_shapes") 

4466 if serialized.shape.ndims not in (None, 1): 

4467 raise ValueError("ParseExampleV2 can only be converted if `serialized` " 

4468 f"is scalar. Received shape: {serialized.shape}.") 

4469 output = gen_parsing_ops.parse_example_v2( 

4470 serialized=serialized, 

4471 names=[], 

4472 sparse_keys=sparse_keys, 

4473 dense_keys=dense_keys, 

4474 ragged_keys=ragged_keys, 

4475 dense_defaults=dense_defaults, 

4476 num_sparse=num_sparse, 

4477 sparse_types=sparse_types, 

4478 ragged_value_types=ragged_value_types, 

4479 ragged_split_types=ragged_split_types, 

4480 dense_shapes=dense_shapes) 

4481 return [wrap(t, True, True) for t in nest.flatten(output)] 

4482 

4483 

4484# functional_ops 

4485 

4486 

4487def _convert_function_call(func, converter, inputs): 

4488 assert isinstance(func.graph, func_graph.FuncGraph), func 

4489 assert isinstance(converter, PFor) 

4490 

4491 graph_outputs = func.graph.outputs[:len(func.function_type.flat_outputs)] 

4492 # TODO(agarwal): consider caching this function definition. 

4493 @def_function.function 

4494 def f(*args): 

4495 assert all(isinstance(arg, WrappedTensor) for arg in args), args 

4496 assert len(args) == len(func.graph.inputs), (args, func.graph.inputs) 

4497 # Map inputs to function arguments. 

4498 for inp, arg in zip(func.graph.inputs, args): 

4499 converter._add_conversion(inp, arg) 

4500 # Convert output tensors. 

4501 return tuple([converter._convert_helper(x).t for x in graph_outputs]) 

4502 

4503 call_outputs = f(*inputs) 

4504 assert len(call_outputs) == len(graph_outputs) 

4505 outputs = [] 

4506 for call_output, output_tensor in zip(call_outputs, graph_outputs): 

4507 func_output = converter._convert_helper(output_tensor) 

4508 outputs.append( 

4509 wrap(call_output, func_output.is_stacked, func_output.is_sparse_stacked) 

4510 ) 

4511 return outputs 

4512 

4513 

4514@RegisterPFor("StatefulPartitionedCall") 

4515@RegisterPFor("PartitionedCall") 

4516def _convert_partitioned_call(pfor_input): 

4517 func_name = pfor_input.get_attr("f").name 

4518 func = pfor_input.op.graph._get_function(compat.as_bytes(func_name)) 

4519 assert isinstance(func.graph, func_graph.FuncGraph), ( 

4520 "Could not find FuncGraph object for %s. Got func %s" % (func_name, func)) 

4521 pfor = pfor_input.pfor 

4522 converter = PFor( 

4523 loop_var=pfor.loop_var, 

4524 loop_len=pfor.loop_len_vector[0], 

4525 pfor_ops=func.graph.get_operations(), 

4526 fallback_to_while_loop=pfor.fallback_to_while_loop, 

4527 all_indices=pfor.all_indices, 

4528 all_indices_partitioned=pfor.all_indices_partitioned, 

4529 pfor_config=pfor.pfor_config) 

4530 return _convert_function_call(func, converter, pfor_input.inputs) 

4531 

4532 

4533def _partition_inputs_for_indices(inputs, indices): 

4534 new_inputs = [] 

4535 for inp in inputs: 

4536 if inp.is_stacked: 

4537 new_inputs.append(wrap(array_ops.gather(inp.t, indices), True)) 

4538 else: 

4539 new_inputs.append(inp) 

4540 return new_inputs 

4541 

4542 

4543def _outputs_for_branch(func_name, indices, pfor_input, inputs): 

4544 if indices is None: 

4545 indices = pfor_input.pfor.all_indices 

4546 partitioned = pfor_input.pfor.all_indices_partitioned 

4547 else: 

4548 partitioned = True 

4549 func = pfor_input.op.graph._get_function(func_name) 

4550 converter = PFor( 

4551 loop_var=pfor_input.pfor.loop_var, 

4552 loop_len=array_ops.size(indices), 

4553 pfor_ops=func.graph.get_operations(), 

4554 fallback_to_while_loop=pfor_input.pfor.fallback_to_while_loop, 

4555 all_indices=indices, 

4556 all_indices_partitioned=partitioned, 

4557 pfor_config=pfor_input.pfor.pfor_config) 

4558 outputs = _convert_function_call(func, converter, inputs) 

4559 stacked_outputs = [] 

4560 for out in outputs: 

4561 if not out.is_stacked: 

4562 stacked_outputs.append(_stack(out.t, [array_ops.size(indices)]).t) 

4563 else: 

4564 stacked_outputs.append(out.t) 

4565 return stacked_outputs 

4566 

4567 

4568# TODO(agarwal): Currently the converted code aggressively tiles loop variant 

4569# outputs from the then/else branches. Instead, it could do so only if at least 

4570# one of the branch outputs is loop variant. 

4571@RegisterPFor("StatelessIf") 

4572@RegisterPFor("If") 

4573def _convert_if(pfor_input): 

4574 cond, cond_stacked, _ = pfor_input.input(0) 

4575 inputs = pfor_input.inputs[1:] 

4576 then_branch = pfor_input.get_attr("then_branch") 

4577 else_branch = pfor_input.get_attr("else_branch") 

4578 

4579 if cond_stacked: 

4580 cond_int = math_ops.cast(cond, dtypes.int32) 

4581 # Compute loop indices for the different branches 

4582 false_indices, true_indices = data_flow_ops.dynamic_partition( 

4583 pfor_input.pfor.all_indices, cond_int, 2) 

4584 # Compute indices for cond being True or False. 

4585 if pfor_input.pfor.all_indices_partitioned: 

4586 else_indices, then_indices = data_flow_ops.dynamic_partition( 

4587 math_ops.range(pfor_input.pfor.loop_len_vector[0]), 

4588 cond_int, 2) 

4589 else: 

4590 else_indices, then_indices = false_indices, true_indices 

4591 # Partition inputs 

4592 then_inputs = _partition_inputs_for_indices(inputs, then_indices) 

4593 else_inputs = _partition_inputs_for_indices(inputs, else_indices) 

4594 

4595 # Convert "then" branch. 

4596 then_outputs = _outputs_for_branch(then_branch.name, true_indices, 

4597 pfor_input, then_inputs) 

4598 

4599 # Convert "else" branch. 

4600 else_outputs = _outputs_for_branch(else_branch.name, false_indices, 

4601 pfor_input, else_inputs) 

4602 

4603 assert len(then_outputs) == len(else_outputs) 

4604 # Note that if the "then" and "else" branches are updating the same state, 

4605 # and possibly reading them as well, it could lead to undefined behavior 

4606 # since the ordering of those operations is not well defined. 

4607 # One possibility is to order all the "then" branches to execute before all 

4608 # the "else" branches so that the side-effects in the former are visible to 

4609 # the latter. For now, we leave that as undefined behavior. 

4610 outputs = [] 

4611 # Merge outputs 

4612 for then_output, else_output in zip(then_outputs, else_outputs): 

4613 out = data_flow_ops.dynamic_stitch([then_indices, else_indices], 

4614 [then_output, else_output]) 

4615 outputs.append(wrap(out, True)) 

4616 return outputs 

4617 else: 

4618 outputs = tf_cond.cond( 

4619 cond, 

4620 lambda: _outputs_for_branch(then_branch.name, None, pfor_input, inputs), 

4621 lambda: _outputs_for_branch(else_branch.name, None, pfor_input, inputs)) 

4622 return [wrap(t, True) for t in outputs] 

4623 

4624 

4625@RegisterPFor("Case") 

4626@RegisterPFor("StatelessCase") 

4627def _convert_stateless_case(pfor_input): 

4628 branch_idx, is_stacked, _ = pfor_input.input(0) 

4629 branches = pfor_input.get_attr("branches") 

4630 inputs = pfor_input.inputs[1:] 

4631 

4632 if is_stacked: 

4633 logging.info("Running stacked flow") 

4634 

4635 # Compute loop indices for the different branches 

4636 switch_indices = data_flow_ops.dynamic_partition( 

4637 pfor_input.pfor.all_indices, branch_idx, len(branches)) 

4638 if pfor_input.pfor.all_indices_partitioned: 

4639 partitioned_indices = data_flow_ops.dynamic_partition( 

4640 math_ops.range(pfor_input.pfor.loop_len_vector[0]), branch_idx, 

4641 len(branches)) 

4642 else: 

4643 partitioned_indices = switch_indices 

4644 # Partition inputs 

4645 input_list = [] 

4646 for indices in partitioned_indices: 

4647 input_list.append(_partition_inputs_for_indices(inputs, indices)) 

4648 

4649 outputs = [] 

4650 for (b, indices, inputs) in zip(branches, switch_indices, input_list): 

4651 out = _outputs_for_branch(b.name, indices, pfor_input, inputs) 

4652 outputs.extend(out) 

4653 

4654 out = data_flow_ops.dynamic_stitch(partitioned_indices, outputs) 

4655 return [wrap(out, True)] 

4656 else: 

4657 new_branches = [] 

4658 for b in branches: 

4659 def new_function(func=b.name): 

4660 return _outputs_for_branch(func, None, pfor_input, 

4661 pfor_input.inputs[1:]) 

4662 

4663 new_branches.append(new_function) 

4664 

4665 outputs = [] 

4666 outputs = control_flow_switch_case.switch_case(branch_idx, new_branches) 

4667 return [wrap(t, True) for t in outputs] 

4668 

4669 

4670class WhileV2: 

4671 """Object for vectorizing V2 while_loop op.""" 

4672 

4673 def __init__(self, pfor_input): 

4674 self._pfor_input = pfor_input 

4675 self._pfor = pfor_input.pfor 

4676 cond_func_name = pfor_input.get_attr("cond").name 

4677 self._cond_func = pfor_input.op.graph._get_function(compat.as_bytes( 

4678 cond_func_name)) 

4679 body_func_name = pfor_input.get_attr("body").name 

4680 self._body_func = pfor_input.op.graph._get_function(compat.as_bytes( 

4681 body_func_name)) 

4682 if self._cond_func is None or self._body_func is None: 

4683 raise ValueError("Error extracting cond and body functions for op " 

4684 f"{self._pfor_input.op}.") 

4685 # Indices of inputs that are passed unchanged through the while loop body. 

4686 # Typically these are tensors captured from outside the body context. 

4687 self._body_pass_through_indices = set() 

4688 for i, (inp, out) in enumerate(zip(self._body_func.graph.inputs, 

4689 self._body_func.graph.outputs)): 

4690 if id(inp) == id(out): 

4691 self._body_pass_through_indices.add(i) 

4692 self._parallel_iterations = self._pfor_input.get_attr("parallel_iterations") 

4693 

4694 def _output_shapes(self): 

4695 # Calculate output shape for vectorized loop. This will be used as 

4696 # shape_invariant. Merges shape inference outputs with the `output_shapes` 

4697 # attribute of the op. 

4698 output_shapes = [out.shape for out in self._pfor_input.op.outputs] 

4699 shapes = self._pfor_input.get_attr("output_shapes") 

4700 if not shapes: 

4701 shapes = [tensor_shape.TensorShape(None) for _ in output_shapes] 

4702 else: 

4703 shapes = [tensor_shape.TensorShape(shape) for shape in shapes] 

4704 for i, shape in enumerate(shapes): 

4705 shape = shape.merge_with(output_shapes[i]) 

4706 pfor_input = self._pfor_input.input(i) 

4707 if pfor_input.is_stacked: 

4708 if _is_variant_with_internal_stacking(pfor_input.t): 

4709 shape = tensor_shape.TensorShape([]).concatenate(shape) 

4710 else: 

4711 shape = tensor_shape.TensorShape([None]).concatenate(shape) 

4712 output_shapes[i] = shape 

4713 assert len(output_shapes) == self._pfor_input.num_inputs 

4714 return output_shapes 

4715 

4716 def _init_values(self): 

4717 """Create arguments passed to converted while_loop.""" 

4718 loop_len = self._pfor.loop_len_vector[0] 

4719 inputs = [] 

4720 # TensorArrays for outputs of converted while loop 

4721 output_tas = [] 

4722 

4723 with ops.name_scope("while_init"): 

4724 for inp in self._pfor_input.inputs: 

4725 inputs.append(inp.t) 

4726 variant_type_id = _variant_type_id(inp.t) 

4727 if variant_type_id in _INTERNAL_STACKING_TYPE_IDS: 

4728 if variant_type_id != full_type_pb2.TFT_ARRAY: 

4729 raise NotImplementedError( 

4730 "While loop conversion is only supported for TensorLists. Got " 

4731 f"another variant {inp.t}, probably an optional. Please file " 

4732 "a bug.") 

4733 

4734 # For TensorLists, the input format is: 

4735 # 

4736 # List[user_list_len, Tensor[loop_len, ...]] 

4737 # 

4738 # rather than the usual 

4739 # 

4740 # Tensor[loop_len, ...] 

4741 # 

4742 # The body of the loop will take and return lists in this "internal 

4743 # vectorization" format, so we want to keep it that way as much as 

4744 # possible. We'll accumulate finished iterations (only relevant for 

4745 # pfor-loop-variant while_loop conditions) in an accumulator with 

4746 # type : 

4747 # 

4748 # List[user_list_len, List[loop_len, Tensor[...]]] 

4749 # 

4750 # This means that each while_loop iteration, we'll iterate over the 

4751 # length of the TensorList, dividing done/remaining pfor loop indices 

4752 # and scattering the done indices into the inner nested list of the 

4753 # accumulator. 

4754 element_shape = list_ops.tensor_list_element_shape( 

4755 inp.t, dtypes.int32) 

4756 if inp.is_stacked: 

4757 # Shapes may be tf.constant(-1) for fully dynamic, in which case 

4758 # slicing is an error. 

4759 element_shape = tf_cond.cond( 

4760 math_ops.equal(array_ops.rank(element_shape), 0), 

4761 lambda: element_shape, 

4762 lambda: element_shape[1:]) 

4763 dtype = _parse_variant_shapes_and_types(inp.t)[0].dtype 

4764 

4765 def _init_loop_body(index, output_ta): 

4766 output_ta = output_ta.write( 

4767 index, 

4768 list_ops.tensor_list_reserve(element_shape, loop_len, dtype)) 

4769 return index + 1, output_ta 

4770 

4771 length = list_ops.tensor_list_length(inp.t) 

4772 output_ta = tensor_array_ops.TensorArray( 

4773 inp.t.dtype, # Variant; this is a nested TensorList 

4774 size=length, 

4775 dynamic_size=True, 

4776 infer_shape=False) 

4777 _, output_ta = while_loop.while_loop(lambda index, _: index < length, 

4778 _init_loop_body, [0, output_ta]) 

4779 else: 

4780 output_ta = tensor_array_ops.TensorArray( 

4781 inp.t.dtype, 

4782 size=loop_len, 

4783 dynamic_size=False, 

4784 infer_shape=True) 

4785 output_tas.append(output_ta) 

4786 # See documentation for __call__ for the structure of init_values. 

4787 indices = ( 

4788 math_ops.range(self._pfor.loop_len_vector[0]) 

4789 if self._pfor.all_indices_partitioned else self._pfor.all_indices) 

4790 return [True, indices] + inputs + output_tas 

4791 

4792 def _process_cond_unstacked(self, conditions, indices, inputs, output_tas): 

4793 """Handles case when condition is pfor loop invariant.""" 

4794 # Note that all iterations end together. So we don't need to partition the 

4795 # inputs. 

4796 not_all_done = array_ops.reshape(conditions, []) 

4797 return not_all_done, indices, inputs, output_tas 

4798 

4799 def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked, 

4800 output_tas): 

4801 """Handles case when condition is pfor loop dependent.""" 

4802 # Compute if all iterations are done. 

4803 not_all_done = math_ops.reduce_any(conditions) 

4804 conditions_int = math_ops.cast(conditions, dtypes.int32) 

4805 # Partition the indices. 

4806 done_indices, new_indices = data_flow_ops.dynamic_partition( 

4807 indices, conditions_int, 2) 

4808 

4809 new_inputs = [] 

4810 new_output_tas = [] 

4811 for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)): 

4812 pass_through = i in self._body_pass_through_indices 

4813 if not pass_through and _variant_type_id(inp) == full_type_pb2.TFT_ARRAY: 

4814 shape_and_type = _parse_variant_shapes_and_types(inp)[0] 

4815 element_shape = list_ops.tensor_list_element_shape(inp, dtypes.int32) 

4816 user_list_len = list_ops.tensor_list_length(inp) 

4817 

4818 def _split_vectorized_ta_element(index, new_inp, new_out_ta): 

4819 elem = list_ops.tensor_list_get_item(inp, index, shape_and_type.dtype, 

4820 element_shape) 

4821 if stacked: 

4822 done_elem, new_elem = data_flow_ops.dynamic_partition( 

4823 elem, conditions_int, 2) 

4824 new_inp = list_ops.tensor_list_set_item(new_inp, index, new_elem) 

4825 else: 

4826 done_elem = _stack(elem, [array_ops.size(done_indices)]).t 

4827 done_accum = new_out_ta.read(index) 

4828 done_accum = list_ops.tensor_list_scatter( 

4829 tensor=done_elem, indices=done_indices, input_handle=done_accum) 

4830 new_out_ta = new_out_ta.write(index, done_accum) 

4831 return index + 1, new_inp, new_out_ta 

4832 

4833 length = list_ops.tensor_list_length(inp) 

4834 new_inp = list_ops.tensor_list_reserve( 

4835 tensor_shape.TensorShape([None]) 

4836 + tensor_shape.TensorShape(shape_and_type.shape)[1:], 

4837 user_list_len, shape_and_type.dtype) 

4838 _, new_inp, out_ta = while_loop.while_loop( 

4839 lambda index, unused_new_inp, unused_new_out_ta: index < length, 

4840 _split_vectorized_ta_element, [0, new_inp, output_tas[i]]) 

4841 else: 

4842 # Partition the inputs. 

4843 if stacked: 

4844 done_inp, new_inp = data_flow_ops.dynamic_partition( 

4845 inp, conditions_int, 2) 

4846 else: 

4847 if not pass_through: 

4848 done_inp = _stack(inp, [array_ops.size(done_indices)]).t 

4849 new_inp = inp 

4850 

4851 out_ta = output_tas[i] 

4852 if not pass_through: 

4853 # Note that done_indices can be empty. done_inp should also be empty 

4854 # in that case. 

4855 out_ta = out_ta.scatter(done_indices, done_inp) 

4856 new_inputs.append(new_inp) 

4857 new_output_tas.append(out_ta) 

4858 

4859 assert len(new_output_tas) == len(output_tas) 

4860 assert len(new_inputs) == len(inputs) 

4861 return not_all_done, new_indices, new_inputs, new_output_tas 

4862 

4863 def _process_body(self, inputs_stacked, new_indices, cond_stacked, 

4864 new_inputs, not_all_done): 

4865 """Convert the body function.""" 

4866 # This is used to store the indices of inputs to the while op that need to 

4867 # be stacked. This stacking may be needed in cases where the input to the 

4868 # while_loop is loop_invariant but the corresponding output is not. 

4869 mismatching_stacked_indices = [] 

4870 

4871 def true_fn(): 

4872 """Converts the body function for all but last iteration.""" 

4873 wrapped_inputs = [wrap(inp, stacked) for inp, stacked in 

4874 zip(new_inputs, inputs_stacked)] 

4875 # Note the iterative process below to figure out loop invariance. 

4876 # Here we iterate on vectorization process till a fixed point. The issue 

4877 # is that the while body can take pfor loop invariant inputs but return 

4878 # loop variant outputs. For any loop variant output, the corresponding 

4879 # input has to be then made loop variant (since subsequent while 

4880 # iterations will need to see loop variant values). 

4881 # However once we make a new input loop variant, we might make other 

4882 # outputs loop variant. Hence we need to iterate till we get fixed point. 

4883 while True: 

4884 if self._pfor.all_indices_partitioned: 

4885 indices = array_ops.gather(self._pfor.all_indices, new_indices) 

4886 else: 

4887 indices = new_indices 

4888 body_pfor = PFor( 

4889 loop_var=self._pfor.loop_var, 

4890 loop_len=array_ops.size(new_indices), 

4891 pfor_ops=self._body_func.graph.get_operations(), 

4892 fallback_to_while_loop=self._pfor.fallback_to_while_loop, 

4893 all_indices=indices, 

4894 all_indices_partitioned=(self._pfor.all_indices_partitioned or 

4895 cond_stacked), 

4896 pfor_config=self._pfor.pfor_config) 

4897 stacking_mismatch = False 

4898 outputs = _convert_function_call(self._body_func, 

4899 body_pfor, 

4900 wrapped_inputs) 

4901 for i, (out, inp) in enumerate(zip(outputs, wrapped_inputs)): 

4902 if out.is_stacked != inp.is_stacked: 

4903 stacking_mismatch = True 

4904 mismatching_stacked_indices.append(i) 

4905 stacked = _stack(inp.t, [array_ops.size(new_indices)]) 

4906 if inp.t.dtype == dtypes.variant: 

4907 stacked = wrap( 

4908 _tile_variant_with_length(stacked.t, 

4909 [array_ops.size(new_indices)])) 

4910 wrapped_inputs[i] = stacked 

4911 if not stacking_mismatch: 

4912 if mismatching_stacked_indices: 

4913 # We needed to stack some inputs. This code will be abandoned and 

4914 # should not get executed. Hence we simply return `new_inputs` to 

4915 # make sure the graph construction code completes. 

4916 with ops.control_dependencies([ 

4917 control_flow_assert.Assert( 

4918 False, ["pfor ERROR: this branch should never execute"]) 

4919 ]): 

4920 return [array_ops.identity(x) for x in new_inputs] 

4921 else: 

4922 return [out.t for out in outputs] 

4923 

4924 # If all are done, we simply return `new_inputs`. Else we need to run the 

4925 # body function. 

4926 return tf_cond.cond( 

4927 not_all_done, 

4928 true_fn, 

4929 lambda: list(new_inputs)), mismatching_stacked_indices 

4930 

4931 def __call__(self): 

4932 """Converter for the V2 while_loop. 

4933 

4934 The conversion of a while_loop is another while_loop. 

4935 

4936 The arguments to this converted while_loop are as follows: 

4937 not_all_done: Boolean scalar Tensor indicating if all the pfor iterations 

4938 are done. 

4939 indices: int32 1-D Tensor storing the id of the pfor iterations that are not 

4940 done. 

4941 args: Remaining arguments. These can be divided into 2 categories: 

4942 - The first set of arguments correspond one-to-one to the inputs to the 

4943 unvectorized while_loop. 

4944 - The second set are TensorArrays, corresponding one-to-one to each output 

4945 of the unvectorized while_loop. Each TensorArray has `PFor.loop_len` 

4946 elements, i.e. the number of pfor iterations. At the end, the i'th 

4947 element of each TensorArray will contain the output computed by the i'th 

4948 iteration of pfor. Note that elements can be written into these tensors 

4949 arrays in any order, depending on when the corresponding pfor iteration 

4950 is done. 

4951 In each iteration, the while_loop body recomputes the condition for all 

4952 active pfor iterations to see which of them are now done. It then partitions 

4953 all the inputs and passes them along to the converted body. Values for all 

4954 the iterations that are done are written to TensorArrays indexed by the pfor 

4955 iteration number. When all iterations are done, the TensorArrays are stacked 

4956 to get the final value. 

4957 

4958 Returns: 

4959 List of converted outputs. 

4960 """ 

4961 output_shapes = self._output_shapes() 

4962 # Note that we use these lists as a hack since we need the `body` to compute 

4963 # these values during construction of the while_loop graph. 

4964 cond_is_stacked = [None] 

4965 indices_to_stack = [] 

4966 

4967 def cond(not_all_done, *_): 

4968 return not_all_done 

4969 

4970 def body(not_all_done, indices, *args): 

4971 # See documentation for __call__ for the structure of *args. 

4972 num_inputs = self._pfor_input.num_inputs 

4973 inputs = args[:num_inputs] 

4974 output_tas = args[num_inputs:] 

4975 inputs_stacked = [x.is_stacked for x in self._pfor_input.inputs] 

4976 assert len(inputs) >= len(output_tas) 

4977 assert len(inputs) == len(inputs_stacked) 

4978 # Convert condition 

4979 with ops.name_scope("while_cond"): 

4980 # Note that we set all_indices_partitioned to True here. At this point 

4981 # we don't know if indices will be partitioned. Hence we use the 

4982 # conservative value. 

4983 cond_pfor = PFor( 

4984 loop_var=self._pfor.loop_var, 

4985 loop_len=array_ops.size(indices), 

4986 pfor_ops=self._cond_func.graph.get_operations(), 

4987 fallback_to_while_loop=self._pfor.fallback_to_while_loop, 

4988 all_indices=indices, 

4989 all_indices_partitioned=True, 

4990 pfor_config=self._pfor.pfor_config) 

4991 

4992 wrapped_inputs = [wrap(inp, stacked) for inp, stacked 

4993 in zip(inputs, inputs_stacked)] 

4994 conditions, cond_stacked, _ = _convert_function_call( 

4995 self._cond_func, 

4996 cond_pfor, 

4997 wrapped_inputs)[0] 

4998 cond_is_stacked[0] = cond_stacked 

4999 

5000 # Recompute the new condition, write outputs of done iterations, and 

5001 # partition the inputs if needed. 

5002 if not cond_stacked: 

5003 (not_all_done, new_indices, new_inputs, 

5004 new_output_tas) = self._process_cond_unstacked(conditions, indices, 

5005 inputs, output_tas) 

5006 else: 

5007 (not_all_done, new_indices, new_inputs, 

5008 new_output_tas) = self._process_cond_stacked(conditions, indices, 

5009 inputs, inputs_stacked, 

5010 output_tas) 

5011 # Convert body 

5012 with ops.name_scope("while_body"): 

5013 # Compute the outputs from the body. 

5014 new_outputs, mismatching_stacked_indices = self._process_body( 

5015 inputs_stacked, new_indices, cond_stacked, new_inputs, not_all_done) 

5016 

5017 indices_to_stack[:] = mismatching_stacked_indices 

5018 for i, new_output in enumerate(new_outputs): 

5019 new_output.set_shape(output_shapes[i]) 

5020 new_args = ([not_all_done, new_indices] + new_outputs + 

5021 list(new_output_tas)) 

5022 return tuple(new_args) 

5023 

5024 # Note that we run the code below in a function since we might abandon the 

5025 # generated code in cases where the conversion dictates that some inputs be 

5026 # further stacked. Hence we run the graph construction using 

5027 # `get_concrete_function` and avoid calling the constructed function if not 

5028 # needed. 

5029 @def_function.function 

5030 def while_fn(): 

5031 # Create init_values that will be passed to the while_loop. 

5032 init_values = self._init_values() 

5033 ta_shape_invariants = [tensor_shape.TensorShape([]) for _ in 

5034 self._pfor_input.outputs] 

5035 shape_invariants = ( 

5036 [tensor_shape.TensorShape([]), tensor_shape.TensorShape([None])] 

5037 + output_shapes + ta_shape_invariants) 

5038 

5039 while_outputs = while_loop.while_loop( 

5040 cond, 

5041 body, 

5042 init_values, 

5043 shape_invariants=shape_invariants, 

5044 parallel_iterations=self._parallel_iterations) 

5045 if indices_to_stack: 

5046 # This function will be abandoned. 

5047 return while_outputs 

5048 else: 

5049 num_inputs = self._pfor_input.num_inputs 

5050 new_inputs = while_outputs[2:num_inputs+2] 

5051 output_tas = while_outputs[num_inputs+2:] 

5052 assert cond_is_stacked[0] is not None 

5053 outputs = [] 

5054 for i, inp in enumerate(new_inputs): 

5055 if cond_is_stacked[0]: 

5056 if i in self._body_pass_through_indices: 

5057 outputs.append(init_values[i + 2]) 

5058 else: 

5059 ta = output_tas[i] 

5060 if _variant_type_id(inp) == full_type_pb2.TFT_ARRAY: 

5061 shape_and_type = _parse_variant_shapes_and_types(inp)[0] 

5062 length = list_ops.tensor_list_length(inp) 

5063 

5064 # We have been accumulating values in a: 

5065 # 

5066 # List[user_list_len, List[loop_len, Tensor[...]]] 

5067 # 

5068 # We want to return an output in the same format as the input: 

5069 # 

5070 # List[user_list_len, Tensor[loop_len, ...]] 

5071 # 

5072 # So we need to loop over the list and stack its contents. 

5073 def _stack_loop_body(index, output_list): 

5074 current_value = ta.read(index) 

5075 output_list = list_ops.tensor_list_set_item( 

5076 output_list, index, 

5077 list_ops.tensor_list_stack( 

5078 current_value, shape_and_type.dtype)) 

5079 return index + 1, output_list 

5080 

5081 output_list = list_ops.tensor_list_reserve( 

5082 tensor_shape.TensorShape(shape_and_type.shape), length, 

5083 shape_and_type.dtype) 

5084 _, output_list = while_loop.while_loop( 

5085 lambda index, _: index < length, _stack_loop_body, 

5086 [0, output_list]) 

5087 outputs.append(output_list) 

5088 else: 

5089 outputs.append(ta.stack()) 

5090 else: 

5091 outputs.append(inp) 

5092 return outputs 

5093 

5094 _ = while_fn.get_concrete_function() 

5095 if indices_to_stack: 

5096 # Need to abandon the current conversion, stack some inputs and restart. 

5097 self._pfor_input.stack_inputs( 

5098 stack_indices=indices_to_stack, tile_variants=True) 

5099 # Note that this call will recurse at most one time. The first call will 

5100 # do the required stacking, based on the iterative procedure in 

5101 # _process_body, and the next invocation to __call__ should not need to do 

5102 # any more stacking. 

5103 # We invoke `self()` here as a way to discard any corrupted state. 

5104 return self() 

5105 else: 

5106 outputs = while_fn() 

5107 wrapped_outputs = [] 

5108 for i, (out, inp) in enumerate(zip(outputs, self._pfor_input.inputs)): 

5109 if i not in self._body_pass_through_indices and cond_is_stacked[0]: 

5110 wrapped_outputs.append(wrap(out, True)) 

5111 else: 

5112 wrapped_outputs.append(wrap(out, inp.is_stacked)) 

5113 return wrapped_outputs 

5114 

5115 

5116@RegisterPFor("StatelessWhile") 

5117@RegisterPFor("While") 

5118def _convert_while(pfor_input): 

5119 converter = WhileV2(pfor_input) 

5120 return converter() 

5121 

5122 

5123# spectral_ops 

5124 

5125 

5126@RegisterPForWithArgs("FFT", gen_spectral_ops.fft) 

5127@RegisterPForWithArgs("FFT2D", gen_spectral_ops.fft2d) 

5128@RegisterPForWithArgs("FFT3D", gen_spectral_ops.fft3d) 

5129@RegisterPForWithArgs("IFFT", gen_spectral_ops.ifft) 

5130@RegisterPForWithArgs("IFFT2D", gen_spectral_ops.ifft2d) 

5131@RegisterPForWithArgs("IFFT3D", gen_spectral_ops.ifft3d) 

5132def _convert_fft(pfor_input, _, op_func): 

5133 return wrap(op_func(pfor_input.stacked_input(0)), True) 

5134 

5135 

5136@RegisterPForWithArgs("RFFT", gen_spectral_ops.rfft, "Tcomplex") 

5137@RegisterPForWithArgs("RFFT2D", gen_spectral_ops.rfft2d, "Tcomplex") 

5138@RegisterPForWithArgs("RFFT3D", gen_spectral_ops.rfft3d, "Tcomplex") 

5139@RegisterPForWithArgs("IRFFT", gen_spectral_ops.irfft, "Treal") 

5140@RegisterPForWithArgs("IRFFT2D", gen_spectral_ops.irfft2d, "Treal") 

5141@RegisterPForWithArgs("IRFFT3D", gen_spectral_ops.irfft3d, "Treal") 

5142def _convert_rfft(pfor_input, _, op_func, attr_name): 

5143 inp = pfor_input.stacked_input(0) 

5144 fft_length = pfor_input.unstacked_input(1) 

5145 attr = pfor_input.get_attr(attr_name) 

5146 return wrap(op_func(inp, fft_length, attr), True)