Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/parallel_for/pfor.py: 24%
2693 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Compiled parallel-for loop."""
16# pylint: disable=missing-docstring,g-direct-tensorflow-import
18import collections
19from functools import partial
20import string
21import sys
22import traceback
24import numpy as np
26from tensorflow.compiler.tf2xla.python import xla
27from tensorflow.core.framework import full_type_pb2
28from tensorflow.python.eager import context
29from tensorflow.python.eager import def_function
30from tensorflow.python.eager import execute
31from tensorflow.python.framework import constant_op
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import func_graph
34from tensorflow.python.framework import ops
35from tensorflow.python.framework import smart_cond
36from tensorflow.python.framework import sparse_tensor
37from tensorflow.python.framework import tensor_shape
38from tensorflow.python.framework import tensor_spec
39from tensorflow.python.framework import tensor_util
40from tensorflow.python.ops import array_ops
41from tensorflow.python.ops import array_ops_stack
42from tensorflow.python.ops import cond as tf_cond
43from tensorflow.python.ops import control_flow_assert
44from tensorflow.python.ops import control_flow_ops
45from tensorflow.python.ops import control_flow_switch_case
46from tensorflow.python.ops import data_flow_ops
47from tensorflow.python.ops import gen_array_ops
48from tensorflow.python.ops import gen_image_ops
49from tensorflow.python.ops import gen_linalg_ops
50from tensorflow.python.ops import gen_list_ops
51from tensorflow.python.ops import gen_math_ops
52from tensorflow.python.ops import gen_nn_ops
53from tensorflow.python.ops import gen_optional_ops
54from tensorflow.python.ops import gen_parsing_ops
55from tensorflow.python.ops import gen_random_ops
56from tensorflow.python.ops import gen_sparse_ops
57from tensorflow.python.ops import gen_spectral_ops
58from tensorflow.python.ops import handle_data_util
59from tensorflow.python.ops import linalg_ops
60from tensorflow.python.ops import list_ops
61from tensorflow.python.ops import manip_ops
62from tensorflow.python.ops import map_fn
63from tensorflow.python.ops import math_ops
64from tensorflow.python.ops import nn_ops
65from tensorflow.python.ops import parsing_ops
66from tensorflow.python.ops import resource_variable_ops
67from tensorflow.python.ops import sparse_ops
68from tensorflow.python.ops import special_math_ops
69from tensorflow.python.ops import tensor_array_ops
70from tensorflow.python.ops import while_loop
71from tensorflow.python.platform import flags
72from tensorflow.python.platform import tf_logging as logging
73from tensorflow.python.util import compat
74from tensorflow.python.util import nest
75from tensorflow.python.util import object_identity
78# TODO(agarwal): remove flag.
79flags.DEFINE_bool(
80 "op_conversion_fallback_to_while_loop", True,
81 "DEPRECATED: Flag is ignored.")
84def _variant_handle_data(t):
85 """Fetches handle data for a variant tensor `t`, or None if unavailable."""
86 handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
87 if not handle_data.is_set:
88 return None
89 return handle_data.shape_and_type
92def _variant_type_id(t):
93 """Returns the full_type_pb2 type of `t`, or None if it is not available."""
94 if t.dtype != dtypes.variant:
95 return None
96 shapes_and_types = _variant_handle_data(t)
97 if shapes_and_types is None or not shapes_and_types:
98 # TODO(b/169968286): Identify all variant tensors (e.g. maps) so that we
99 # can make this an error instead of assuming TensorLists have handle data.
100 return None # Presumed not a TensorList/Optional
101 return shapes_and_types[0].type.type_id
104_INTERNAL_STACKING_TYPE_IDS = (
105 full_type_pb2.TFT_ARRAY,
106 full_type_pb2.TFT_OPTIONAL)
109def _is_variant_with_internal_stacking(t):
110 """Identifies variant tensors which pfor always maintains as scalars.
112 For these, the pfor tensor is recorded as "stacked" if the contents of the
113 variant tensor (e.g. the elements of a TensorList) are all stacked.
115 Args:
116 t: A tensor to identify.
117 Returns:
118 True if `t` is a TensorList/Optional, False otherwise.
119 """
120 type_id = _variant_type_id(t)
121 return type_id in _INTERNAL_STACKING_TYPE_IDS
124def _parse_variant_shapes_and_types(t):
125 """Extracts shape and dtype information from a variant tensor `t`."""
126 shapes_and_types = _variant_handle_data(t)
127 if shapes_and_types is None or not shapes_and_types:
128 raise ValueError("Required handle data not set for {!r}".format(t))
129 if shapes_and_types[0].type.type_id == full_type_pb2.TFT_ARRAY:
130 return shapes_and_types
131 else:
132 if shapes_and_types[0].type.type_id == full_type_pb2.TFT_UNSET:
133 return shapes_and_types
134 else:
135 raise ValueError(
136 "Attempted to stack a variant-dtype tensor with no type set ({!r})"
137 .format(t))
140def _stack(t, length):
141 """stacks `t` `length` times."""
142 # Note that this stacking may currently be triggered, for example, when a
143 # loop invariant tensor with dtype variant is input to a while_loop which then
144 # produces a loop dependent output. Simply stacking the variants may not be
145 # suitable since operations on stacked handles may expect a vectorized version
146 # of the variant.
147 if t.dtype == dtypes.variant:
148 shapes_and_types = _parse_variant_shapes_and_types(t)
149 if shapes_and_types[0].type.type_id == full_type_pb2.TFT_ARRAY:
150 if len(shapes_and_types) != 1:
151 raise ValueError(
152 f"Expected handle data of length 1, got {shapes_and_types!r} of "
153 f"length {len(shapes_and_types)}.")
154 return wrap(
155 _stack_tensor_list(t, shapes_and_types[0].dtype, length),
156 True)
157 else:
158 raise ValueError(
159 "Attempted to stack an unhandled variant-dtype tensor of "
160 f"type {shapes_and_types[0].type!r} ({t!r}).")
161 ones = array_ops.ones_like(array_ops.shape(t))
162 ones = array_ops.reshape(ones, [-1])
163 length = array_ops.reshape(length, [-1])
164 multiples = array_ops.concat([length, ones], 0)
165 t = array_ops.tile(array_ops.expand_dims(t, 0), multiples)
166 return wrap(t, True)
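# Illustrative sketch (added for clarity; not part of the original module): for
# a non-variant, loop-invariant tensor, `_stack` tiles an expanded copy so that
# row i holds the value seen by pfor iteration i. Roughly:
#
#   t = constant_op.constant([1., 2., 3.])           # shape [3]
#   stacked = _stack(t, constant_op.constant([4]))   # WrappedTensor
#   # stacked.t has shape [4, 3] and stacked.is_stacked is True.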
169# The following stateful ops can be safely called once, and with the same
170# signature as the unconverted version, if their inputs are loop invariant.
171# TODO(agarwal): implement a strategy for converting Variable reads/writes. The
172# plan is to map each read/write in the loop_fn to a corresponding merged
173# read/write in the converted graph. Writes need to be mergeable (e.g.
174# AssignAdd) to be used in `pfor`. Given a certain read/write order in the
175# loop_fn, doing a one-to-one conversion will simulate executing such
176# instructions in lock-step across all iterations.
177passthrough_stateful_ops = set([
178 "VariableV2",
179 "VarHandleOp",
180 "VariableShape",
181 "ReadVariableOp",
182 "StackV2",
183 "TensorArrayWriteV3",
184 "TensorArrayReadV3",
185 "TensorArraySizeV3",
186])
189# Ops which we will treat like stateful for the purpose of vectorization.
190# Typically this is used to force pfor converters to run for these ops.
191force_stateful_ops = set([
192 # We vectorize this since we need to change the element shape set on the
193 # list.
194 "TensorListReserve",
195])
198def _is_stateful_pfor_op(op):
199 if isinstance(op, WhileOp):
200 return op.is_stateful
201 if op.type == "Const":
202 # Const ops don't have an op_def.
203 return False
204 if op.type in passthrough_stateful_ops:
205 return False
206 if op.type in force_stateful_ops:
207 return True
208 assert hasattr(op, "op_def") and op.op_def is not None, op
209 return op.op_def.is_stateful
212# pylint: disable=protected-access
213class WhileOp:
214 """Object for storing state for converting the outputs of a while_loop."""
216 def __init__(self, exit_node, pfor_ops, fallback_to_while_loop, pfor_config):
217 """Initializer.
219 Args:
220 exit_node: A tensor output from the while_loop.
221 pfor_ops: list of ops inside the current pfor loop.
222 fallback_to_while_loop: If True, fall back to a while loop when conversion
223 of an op is not supported.
224 pfor_config: PForConfig object used while constructing loop body.
225 """
226 self._fallback_to_while_loop = fallback_to_while_loop
227 self._pfor_config = pfor_config
228 self._pfor_ops = set(pfor_ops)
229 self._pfor_op_ids = set(x._id for x in pfor_ops)
230 assert isinstance(exit_node, ops.Tensor)
231 self._while_context = exit_node.op._get_control_flow_context()
232 assert isinstance(self._while_context, control_flow_ops.WhileContext)
233 self._context_name = self._while_context.name
234 self._condition = self._while_context.pivot.op.inputs[0]
235 # Parts of an external while_loop could be created inside a pfor loop.
236 # However for the purpose here, we declare such loops to be external. Also
237 # note that we check if the condition was created inside or outside to
238 # determine if the while_loop was first created inside or outside.
239 # TODO(agarwal): check that the Enter and Exit of this loop are unstacked.
240 self._is_inside_loop = self.op_is_inside_loop(self._condition.op)
241 if self._is_inside_loop:
242 for e in self._while_context.loop_exits:
243 assert self.op_is_inside_loop(e.op)
245 # Note the code below tries to reverse engineer an existing while_loop graph
246 # by assuming the following pattern of nodes.
247 #
248 #          NextIteration <---- Body <--- Enter
249 #              |                ^
250 #              V             ___| Y
251 #    Enter -> Merge -> Switch___
252 #                       ^       | N
253 #                       |       V
254 #                   LoopCond    Exit
256 # Note that elements in the lists below correspond one-to-one with each
257 # other, i.e. these lists are the same size, and the i_th entry corresponds
258 # to different Operations/Tensors of a single cycle as illustrated above.
259 # List of Switch ops (ops.Operation) that feed into an Exit Node.
260 self._exit_switches = []
261 # List of inputs (ops.Tensor) to NextIteration.
262 self._body_outputs = []
263 # List of list of control inputs of the NextIteration nodes.
264 self._next_iter_control_inputs = []
265 # List of Merge ops (ops.Operation).
266 self._enter_merges = []
267 # List of output (ops.Tensor) of Exit nodes.
268 self._outputs = []
270 # List of Enter Tensors.
271 # There are two types of Enter nodes:
272 # - The Enter nodes that are used in the `loop_vars` argument to
273 # `while_loop` (see
274 # https://www.tensorflow.org/api_docs/python/tf/while_loop). We collect
275 # these Enter nodes immediately below by tracing backwards from the Exit
276 # nodes via Exit <- Switch <- Merge <- Enter. You can see this chain in the
277 # diagram above. This allows us to have a 1:1 correspondence between the
278 # self._outputs and the first elements in self._enters.
279 # - The Enter nodes that are used only by the body. They don't appear in the
280 # `loop_vars` and are not returned from the `while_loop`. In Python code,
281 # they are usually captured by the body lambda. We collect them below by
282 # iterating over all the ops in the graph. They are appended to the end of
283 # self._enters or self._direct_enters, and don't correspond to any outputs
284 # in self._outputs. Note that we keep the resource/variant Enter nodes in
285 # self._direct_enters and the constructed while_loop's body uses them
286 # directly as opposed to passing them as loop variables. This is done
287 # because the while_body cannot partition the resource/variant Tensors, so
288 # it has to leave them unchanged.
289 self._enters = []
290 self._direct_enters = []
292 for e in self._while_context.loop_exits:
293 self._outputs.append(e.op.outputs[0])
294 switch = e.op.inputs[0].op
295 assert switch.type == "Switch", switch
296 self._exit_switches.append(switch)
297 merge = switch.inputs[0].op
298 assert merge.type == "Merge", merge
299 self._enter_merges.append(merge)
300 enter = merge.inputs[0].op
301 assert enter.type == "Enter", enter
302 self._enters.append(enter.outputs[0])
303 next_iter = merge.inputs[1].op
304 assert next_iter.type == "NextIteration", next_iter
305 self._body_outputs.append(next_iter.inputs[0])
306 self._next_iter_control_inputs.append(next_iter.control_inputs)
308 # Collect all the Enter nodes that are not part of `loop_vars`, the second
309 # category described above.
310 # Also track whether the loop body has any stateful ops.
311 self._is_stateful = False
312 for op in ops.get_default_graph().get_operations():
313 # TODO(agarwal): make sure this works with nested case.
314 control_flow_context = op._get_control_flow_context()
315 if control_flow_context is None:
316 continue
317 if control_flow_context.name == self._context_name:
318 self._is_stateful |= _is_stateful_pfor_op(op)
319 if op.type == "Enter":
320 output = op.outputs[0]
321 if output not in self._enters:
322 if output.dtype in (dtypes.resource, dtypes.variant):
323 if output not in self._direct_enters:
324 self._direct_enters.append(output)
325 else:
326 self._enters.append(output)
328 def __str__(self):
329 """String representation."""
330 return "while_loop(%s)" % self.name
332 @property
333 def inputs(self):
334 """Input to all the Enter nodes."""
335 return [x.op.inputs[0] for x in self._enters + self._direct_enters]
337 @property
338 def control_inputs(self):
339 """Control input to all the Enter nodes."""
340 control_inputs = []
341 for x in self._enters + self._direct_enters:
342 control_inputs.extend(x.op.control_inputs)
343 return control_inputs
345 @property
346 def outputs(self):
347 """Outputs of all the Exit nodes."""
348 return self._outputs
350 @property
351 def name(self):
352 """Context name for the while loop."""
353 return self._context_name
355 @property
356 def is_inside_loop(self):
357 """Returns true if the while_loop was created inside the pfor."""
358 return self._is_inside_loop
360 def op_is_inside_loop(self, op):
361 """True if op was created inside the pfor loop body."""
362 assert isinstance(op, ops.Operation)
363 # Note that we use self._pfor_op_ids for the check and not self._pfor_ops
364 # since it appears the TensorFlow API could return different Python
365 # objects representing the same Operation node.
366 return op._id in self._pfor_op_ids
368 @property
369 def is_stateful(self):
370 return self._is_stateful
372 @property
373 def pfor_converter(self):
374 """Return a converter for the while loop."""
375 return self
377 def _init_pfor(self, parent_pfor, indices, cond_stacked, inputs,
378 inputs_stacked):
379 """Create a PFor object for converting parts of the while_loop.
381 Args:
382 parent_pfor: PFor object being used for converting the while_loop.
383 indices: int32 Tensor of ids for the iterations that are still active
384 (i.e. did not exit the while_loop).
385 cond_stacked: True if the while_loop condition is stacked.
386 inputs: list of input Tensors corresponding 1-to-1 with self._enters. Note
387 that these Tensors are a subset of the loop variables for the generated
388 while_loop.
389 inputs_stacked: List of booleans corresponding 1-to-1 with `inputs`,
390 indicating if the value is stacked or not.
392 Returns:
393 A PFor instance. The instance is initialized by adding conversion mappings
394 of nodes that will be external to the conversion that the returned
395 instance will be used for. e.g. Enter nodes as well as Merge and Switch
396 outputs are mapped to converted values.
397 """
398 num_outputs = len(self._outputs)
399 assert len(inputs) == len(self._enters)
400 assert len(inputs_stacked) == len(self._enters)
401 loop_var = parent_pfor.loop_var
402 loop_len = array_ops.size(indices)
403 pfor = PFor(
404 loop_var,
405 loop_len,
406 pfor_ops=self._pfor_ops,
407 all_indices=indices,
408 all_indices_partitioned=cond_stacked,
409 fallback_to_while_loop=self._fallback_to_while_loop,
410 pfor_config=self._pfor_config)
411 # Map all inputs of Enter nodes in self._direct_enters to their converted
412 # values.
413 for enter in self._direct_enters:
414 enter_input = enter.op.inputs[0]
415 converted_enter, stacked, is_sparse_stacked = parent_pfor._convert_helper(
416 enter_input)
417 # Since these are resources / variants, they should be unstacked.
418 assert not stacked and not is_sparse_stacked, (enter, converted_enter)
419 pfor._add_conversion(enter, wrap(converted_enter, False))
421 # Map all Enter nodes to the inputs.
422 for enter, inp, stacked in zip(self._enters, inputs, inputs_stacked):
423 pfor._add_conversion(enter, wrap(inp, stacked))
424 # Map outputs of Switch and Merge.
425 for i in range(num_outputs):
426 wrapped_inp = wrap(inputs[i], inputs_stacked[i])
427 merge = self._enter_merges[i]
428 pfor._add_conversion(merge.outputs[0], wrapped_inp)
429 # Note that second output of Merge is typically not used, except possibly
430 # as a control dependency. To avoid trying to output the correct value, we
431 # employ a hack here. We output a dummy invalid value with an incorrect
432 # dtype. This will allow control dependency to work but if using it as an
433 # input, it should typically lead to errors during graph construction due
434 # to dtype mismatch.
435 # TODO(agarwal): Check in the original graph to see if there are any
436 # consumers of this Tensor that use it as an input.
437 pfor._add_conversion(merge.outputs[1],
438 wrap(constant_op.constant(-1.0), False))
439 switch = self._exit_switches[i]
440 # Don't need to worry about switch.output[0] which will feed to Exit node.
441 pfor._add_conversion(switch.outputs[1], wrapped_inp)
442 return pfor
444 def _convert_enter(self, parent_pfor, enter):
445 """Converts an Enter node."""
446 inp, stacked, _ = parent_pfor._convert_helper(enter.op.inputs[0])
447 control_inputs = []
448 for x in enter.op.control_inputs:
449 converted = parent_pfor._convert_helper(x)
450 if not isinstance(converted, ops.Operation):
451 converted = converted.t
452 control_inputs.append(converted)
453 if control_inputs:
454 with ops.control_dependencies(control_inputs):
455 inp = array_ops.identity(inp)
456 return inp, stacked
458 def _maybe_stacked(self, cache, inp):
459 """Heuristic to figure out if the converting inp leads to a stacked value.
462 Args:
463 cache: map from Tensor to boolean indicating stacked/unstacked.
464 inp: input Tensor.
466 Returns:
467 True if `inp` could get stacked. If the function returns False, the
468 converted value should be guaranteed to be unstacked. If returning True,
469 it may or may not be stacked.
470 """
471 if inp in cache:
472 return cache[inp]
473 if not self.op_is_inside_loop(inp.op):
474 return False
475 op = inp.op
476 output = False
477 if op.type in [
478 "Shape",
479 "Rank",
480 "ShapeN",
481 "ZerosLike",
482 "TensorArrayV3",
483 "TensorArraySizeV3",
484 ]:
485 output = False
486 elif _is_stateful_pfor_op(op):
487 # This may be fairly aggressive.
488 output = True
489 elif op.type == "Exit":
490 # This may be fairly aggressive.
491 output = True
492 else:
493 for t in op.inputs:
494 if self._maybe_stacked(cache, t):
495 output = True
496 break
497 cache[inp] = output
498 return output
500 def _create_init_values(self, pfor_input):
501 """Create arguments passed to converted while_loop."""
502 with ops.name_scope("while_init"):
503 loop_len_vector = pfor_input.pfor.loop_len_vector
504 loop_len = loop_len_vector[0]
505 num_outputs = len(self._outputs)
507 inputs = []
508 maybe_stacked_cache = {}
509 # Convert all the Enters. Need to do this before checking for stacking
510 # below.
511 for i, enter in enumerate(self._enters):
512 inp, stacked = self._convert_enter(pfor_input.pfor, enter)
513 inputs.append(inp)
514 maybe_stacked_cache[enter] = stacked
515 # Since this enter node is part of the `loop_vars`, it corresponds to an
516 # output and its preceding switch. We mark this switch's output with the same
517 # stackedness, to act as the base case for the logic below. Below, we will
518 # be going through the body figuring out which inputs might need to be
519 # stacked and which inputs can safely remain unstacked.
520 if i < num_outputs:
521 maybe_stacked_cache[self._exit_switches[i].outputs[1]] = stacked
523 # Shape invariants for init_values corresponding to self._enters.
524 input_shape_invariants = []
525 # TensorArrays for outputs of converted while loop
526 output_tas = []
527 # Shape invariants for output TensorArrays.
528 ta_shape_invariants = []
529 # List of booleans indicating stackedness of inputs, i.e. tensors
530 # corresponding to self._enters.
531 inputs_stacked = []
532 for i, inp in enumerate(inputs):
533 enter = self._enters[i]
534 inp_stacked = self._maybe_stacked(maybe_stacked_cache, enter)
535 # Note that even when an input is unstacked, the body could make it
536 # stacked. We use a heuristic below to figure out if the body may be making
537 # it stacked.
538 if i < num_outputs:
539 body_output = self._body_outputs[i]
540 if enter.op in self._pfor_ops:
541 body_output_stacked = self._maybe_stacked(maybe_stacked_cache,
542 body_output)
543 else:
544 # If constructed outside of pfor loop, then the output would not be
545 # stacked.
546 body_output_stacked = False
547 if body_output_stacked and not inp_stacked:
548 inp = _stack(inp, loop_len_vector).t
549 inputs[i] = inp
550 inp_stacked = True
551 # TODO(agarwal): other attributes for the TensorArray ?
552 output_tas.append(tensor_array_ops.TensorArray(inp.dtype, loop_len))
553 ta_shape_invariants.append(tensor_shape.TensorShape(None))
555 inputs_stacked.append(inp_stacked)
556 input_shape_invariants.append(tensor_shape.TensorShape(None))
558 # See documentation for __call__ for the structure of init_values.
559 init_values = [True, pfor_input.pfor.all_indices] + inputs + output_tas
560 # TODO(agarwal): try stricter shape invariants
561 shape_invariants = (
562 [tensor_shape.TensorShape(None),
563 tensor_shape.TensorShape(None)] + input_shape_invariants +
564 ta_shape_invariants)
566 return init_values, inputs_stacked, shape_invariants
568 def _process_cond_unstacked(self, conditions, indices, inputs, output_tas):
569 """Handles case when condition is unstacked.
571 Note that all iterations end together. So we don't need to partition the
572 inputs. When all iterations are done, we write the inputs to the
573 TensorArrays. Note that we only write to index 0 of output_tas. Since all
574 iterations end together, they can all be output together.
575 """
576 not_all_done = array_ops.reshape(conditions, [])
577 new_output_tas = []
578 # pylint: disable=cell-var-from-loop
579 for i, out_ta in enumerate(output_tas):
580 inp = inputs[i]
581 new_output_tas.append(
582 tf_cond.cond(not_all_done, lambda: out_ta,
583 lambda: out_ta.write(0, inp)))
584 # pylint: enable=cell-var-from-loop
585 return not_all_done, indices, inputs, new_output_tas
587 def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked,
588 output_tas):
589 num_outputs = len(self._outputs)
590 # Compute if all iterations are done.
591 not_all_done = math_ops.reduce_any(conditions)
592 conditions_int = math_ops.cast(conditions, dtypes.int32)
593 # Partition the indices.
594 done_indices, new_indices = data_flow_ops.dynamic_partition(
595 indices, conditions_int, 2)
597 new_inputs = []
598 new_output_tas = []
599 for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)):
600 # Partition the inputs.
601 if stacked:
602 done_inp, new_inp = data_flow_ops.dynamic_partition(
603 inp, conditions_int, 2)
604 else:
605 # TODO(agarwal): avoid this stacking. See TODO earlier in
606 # _process_cond_unstacked.
607 done_inp = _stack(inp, [array_ops.size(done_indices)]).t
608 new_inp = inp
609 new_inputs.append(new_inp)
610 # For iterations that are done, write them to TensorArrays.
611 if i < num_outputs:
612 out_ta = output_tas[i]
613 # Note that done_indices can be empty. done_inp should also be empty in
614 # that case.
615 new_output_tas.append(out_ta.scatter(done_indices, done_inp))
616 return not_all_done, new_indices, new_inputs, new_output_tas
618 def _process_body(self, pfor_input, inputs_stacked, new_indices, cond_stacked,
619 new_inputs, not_all_done):
620 """Convert the body function."""
622 def true_fn(control_inputs, body_pfor, body_output, stacked):
623 """Converts the body function for all but last iteration.
625 This essentially converts body_output. Additionally, it needs to handle
626 any control dependencies on the NextIteration node. So it creates another
627 Identity node with the converted dependencies.
628 """
629 converted_control_inp = []
630 for x in control_inputs:
631 for t in x.outputs:
632 converted_control_inp.append(body_pfor._convert_helper(t).t)
633 if stacked:
634 # Note convert always does the stacking.
635 output = body_pfor.convert(body_output)
636 else:
637 output, convert_stacked, _ = body_pfor._convert_helper(body_output)
638 assert convert_stacked == stacked, body_output
639 with ops.control_dependencies(converted_control_inp):
640 return array_ops.identity(output)
642 body_pfor = self._init_pfor(pfor_input.pfor, new_indices, cond_stacked,
643 new_inputs, inputs_stacked)
644 new_outputs = []
646 for i, (body_output,
647 stacked) in enumerate(zip(self._body_outputs, inputs_stacked)):
648 control_inp = self._next_iter_control_inputs[i]
649 out_dtype = body_output.dtype
650 # Note that we want to run the body only if not all pfor iterations are
651 # done. If all are done, we return empty tensors since these values will
652 # not be used. Notice that the value returned by the loop is based on
653 # TensorArrays and not directly on these returned values.
654 # pylint: disable=cell-var-from-loop
655 new_output = tf_cond.cond(
656 not_all_done,
657 lambda: true_fn(control_inp, body_pfor, body_output, stacked),
658 lambda: constant_op.constant([], dtype=out_dtype))
659 # pylint: enable=cell-var-from-loop
660 new_outputs.append(new_output)
661 return new_outputs
663 def __call__(self, pfor_input):
664 """Converter for the while_loop.
666 The conversion of a while_loop is another while_loop.
668 The arguments to this converted while_loop are as follows:
669 not_all_done: Boolean scalar Tensor indicating if all the pfor iterations
670 are done.
671 indices: int32 1-D Tensor storing the id of the iterations that are not
672 done.
673 args: Remaining arguments. These can be divided into 3 categories:
674 - The first set of arguments are the tensors that correspond to the initial
675 elements of self._enters, i.e. the elements that appear in the original
676 while loop's `loop_vars`.
677 - The second set of arguments are the tensors that correspond to the
678 remaining elements of self._enters. These are the tensors that directly
679 enter the original while loop body.
680 - Finally, the last set of arguments are TensorArrays. These TensorArrays
681 correspond to the outputs of the original while_loop, i.e. to the
682 elements in self._outputs. Each TensorArray has `PFor.loop_len`
683 elements, i.e. the number of pfor iterations. At the end, the i'th
684 element of each TensorArray will contain the output computed by the
685 i'th iteration of pfor. Note that elements can be written into these
686 tensors arrays in any order, depending on when the corresponding pfor
687 iteration is done.
688 If the original while_loop had `k` tensors in its `loop_vars` and its body
689 directly captured `m` tensors, the `args` will contain `2 * k + m` values.
691 In each iteration, the while_loop body recomputes the condition for all
692 active pfor iterations to see which of them are now done. It then partitions
693 all the inputs and passes them along to the converted body. Values for all
694 the iterations that are done are written to TensorArrays indexed by the pfor
695 iteration number. When all iterations are done, the TensorArrays are stacked
696 to get the final value.
698 Args:
699 pfor_input: A PForInput object corresponding to the output of any Exit
700 node from this while loop.
702 Returns:
703 List of converted outputs.
704 """
705 # Create init_values that will be passed to the while_loop.
706 init_values, inputs_stacked, shape_invariants = self._create_init_values(
707 pfor_input)
708 # Note that we use a list as a hack since we need the nested function body
709 # to set the value of cond_is_stacked. python2.x doesn't support nonlocal
710 # variables.
711 cond_is_stacked = [None]
713 def cond(not_all_done, *_):
714 return not_all_done
716 def body(not_all_done, indices, *args):
717 # See documentation for __call__ for the structure of *args.
718 num_enters = len(self._enters)
719 inputs = args[:num_enters]
720 output_tas = args[num_enters:]
721 # TODO(agarwal): see which outputs have consumers and only populate the
722 # TensorArrays corresponding to those. Or do those paths get trimmed out
723 # from inside the while_loop body?
724 assert len(inputs) >= len(output_tas)
725 assert len(inputs) == len(inputs_stacked)
727 # Convert condition
728 with ops.name_scope("while_cond"):
729 # Note that we set cond_stacked to True here. At this point we don't
730 # know if it could be loop invariant, hence the conservative value is
731 # to assume stacked.
732 cond_pfor = self._init_pfor(
733 pfor_input.pfor,
734 indices,
735 cond_stacked=True,
736 inputs=inputs,
737 inputs_stacked=inputs_stacked)
738 conditions, cond_stacked, _ = cond_pfor._convert_helper(self._condition)
739 cond_is_stacked[0] = cond_stacked
741 # Recompute the new condition, write outputs of done iterations, and
742 # partition the inputs if needed.
743 if not cond_stacked:
744 (not_all_done, new_indices, new_inputs,
745 new_output_tas) = self._process_cond_unstacked(conditions, indices,
746 inputs, output_tas)
747 else:
748 (not_all_done, new_indices, new_inputs,
749 new_output_tas) = self._process_cond_stacked(conditions, indices,
750 inputs, inputs_stacked,
751 output_tas)
753 # Convert body
754 with ops.name_scope("while_body"):
755 # Compute the outputs from the body.
756 new_outputs = self._process_body(pfor_input, inputs_stacked,
757 new_indices, cond_stacked, new_inputs,
758 not_all_done)
760 # Note that the first num_outputs new values of inputs are computed using
761 # the body. The rest of them were direct Enters into the condition/body and
762 # the partitioning done earlier is sufficient to give the new value.
763 num_outputs = len(self._outputs)
764 new_args = ([not_all_done, new_indices] + new_outputs +
765 list(new_inputs[num_outputs:]) + new_output_tas)
766 return tuple(new_args)
768 while_outputs = while_loop.while_loop(
769 cond, body, init_values, shape_invariants=shape_invariants)
770 output_tas = while_outputs[-len(self._outputs):]
771 outputs = []
772 assert cond_is_stacked[0] is not None
773 for inp_stacked, ta in zip(inputs_stacked, output_tas):
774 if cond_is_stacked[0]:
775 outputs.append(wrap(ta.stack(), True))
776 else:
777 # Note that if while_loop condition is unstacked, all iterations exit at
778 # the same time and we wrote those outputs in index 0 of the tensor
779 # array.
780 outputs.append(wrap(ta.read(0), inp_stacked))
781 return outputs
784class ConversionNotImplementedError(Exception):
785 pass
788class _PforInput:
789 """Input object passed to registered pfor converters."""
791 __slots__ = ["pfor", "_op", "_inputs"]
793 def __init__(self, pfor, op, inputs):
794 """Creates a _PforInput object.
796 Args:
797 pfor: PFor converter object.
798 op: the Operation object that is being converted.
799 inputs: list of WrappedTensor objects representing converted values of the
800 inputs of `op`.
801 """
802 self.pfor = pfor
803 self._op = op
804 self._inputs = inputs
806 def stack_inputs(self, stack_indices=None, tile_variants=False):
807 """Stacks unstacked inputs at `stack_indices`.
809 Args:
810 stack_indices: indices of inputs at which stacking is done. If None,
811 stacking is done at all indices.
812 tile_variants: If True, affected indices which have a variant dtype will
813 be tiled after this operation to match the expected shape of a
814 vectorized tensor. Variants generally need to be un-tiled when they are
815 inputs to operations and tiled when returned.
816 """
817 if stack_indices is None:
818 stack_indices = range(len(self._inputs))
819 length = self.pfor.loop_len_vector
820 for i in stack_indices:
821 inp = self._inputs[i]
822 is_variant = inp.t.dtype == dtypes.variant
823 if not inp.is_stacked:
824 self._inputs[i] = _stack(inp.t, length)
825 if tile_variants and is_variant:
826 self._inputs[i] = wrap(
827 _tile_variant_with_length(self._inputs[i].t, length), True)
828 elif not tile_variants and is_variant:
829 self._inputs[i] = wrap(_untile_variant(self._inputs[i].t), True)
831 def expanddim_inputs_for_broadcast(self):
832 """Reshapes stacked inputs to prepare them for broadcast.
834 Since stacked inputs have an extra leading dimension, automatic broadcasting
835 rules could incorrectly try to expand dimensions before that leading
836 dimension. To avoid that, we reshape these stacked inputs to the maximum
837 rank they will need to be broadcasted to.
838 """
839 if not self._inputs:
840 return
842 # Find max rank
843 def _get_rank(x):
844 rank = array_ops.rank(x.t)
845 if not x.is_stacked:
846 rank += 1
847 return rank
849 ranks = [_get_rank(x) for x in self._inputs]
850 max_rank = ranks[0]
851 for rank in ranks[1:]:
852 max_rank = math_ops.maximum(rank, max_rank)
854 for i, inp in enumerate(self._inputs):
855 if inp.is_stacked:
856 shape = array_ops.shape(inp.t)
857 rank_diff = array_ops.reshape(max_rank - ranks[i], [1])
858 ones = array_ops.tile([1], rank_diff)
859 new_shape = array_ops.concat([shape[:1], ones, shape[1:]], axis=0)
860 self._inputs[i] = wrap(array_ops.reshape(inp.t, new_shape), True)
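# Worked example (added, hedged): with loop length N, a stacked input of shape
# [N, 3] has rank 2 and another of shape [N, 2, 3] has rank 3, so max_rank is 3.
# The first input is reshaped to [N, 1, 3], letting later elementwise ops
# broadcast across the per-iteration dimensions while leaving the leading loop
# dimension untouched.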
862 @property
863 def inputs(self):
864 return self._inputs
866 @property
867 def num_inputs(self):
868 return len(self._inputs)
870 def input(self, index):
871 assert len(self._inputs) > index, (index, self._inputs)
872 return self._inputs[index]
874 def stacked_input(self, index):
875 t, is_stacked, _ = self.input(index)
876 if not is_stacked:
877 op_type = self.op_type
878 op_def = getattr(self._op, "op_def", None)
879 if op_def is None:
880 input_name = "at index %d" % index
881 else:
882 input_name = "\"%s\"" % op_def.input_arg[index].name
883 raise ConversionNotImplementedError(
884 f"Input {input_name} of op '{op_type}' expected to be not loop "
885 "invariant.")
886 return t
888 def unstacked_input(self, index):
889 t, is_stacked, _ = self.input(index)
890 if is_stacked:
891 op_type = self.op_type
892 op_def = getattr(self._op, "op_def", None)
893 if op_def is None:
894 input_name = "at index %d" % index
895 else:
896 input_name = "\"%s\"" % op_def.input_arg[index].name
897 raise ConversionNotImplementedError(
898 f"Input {input_name} of op '{op_type}' expected to be loop "
899 "invariant.")
900 return t
902 @property
903 def op(self):
904 return self._op
906 @property
907 def op_type(self):
908 return self._op.type
910 def get_attr(self, attr):
911 return self._op.get_attr(attr)
913 @property
914 def outputs(self):
915 return self._op.outputs
917 def output(self, index):
918 assert index < len(self._op.outputs)
919 return self._op.outputs[index]
922_pfor_converter_registry = {}
925class RegisterPFor:
926 """Utility to register converters for pfor.
928 Usage:
929 @RegisterPFor(foo_op_type)
930 def _foo_converter(pfor_input):
931 ...
933 The above will register conversion function `_foo_converter` for handling
934 conversion of `foo_op_type`. These converters are called during vectorization
935 of a `pfor` loop body. For each operation node in this loop body,
936 the vectorization process will call the converter corresponding to the
937 operation type of the node.
939 During conversion, the registered function will be called with a single
940 argument `pfor_input`, of type `PForInput`, which will contain state needed
941 for the conversion. When the converter is called for a node, all its inputs
942 should already have been converted and these converted values are stored in
943 `pfor_input.inputs`. This registered function should output a list of
944 WrappedTensor objects with the same length as the number of outputs of the
945 node being converted. If the node had zero outputs, then it should return an
946 ops.Operation object. These new sets of nodes should implement the
947 functionality of running that operation for the number of iterations specified
948 by `pfor_input.pfor.loop_len_vector[0]` where the inputs of the node for each
949 iteration are picked from `pfor_input.inputs`.
951 One tricky aspect of the conversion process is keeping track of, and
952 leveraging, loop invariance of computation. Each converted input is a
953 WrappedTensor which indicates whether the input was loop invariant or not. If
954 the converted value is loop invariant, its rank should match the rank of the
955 corresponding tensor in the loop body, else its rank is larger by 1. The
956 converter should look at the loop invariance of the inputs and generate new
957 nodes based on that. Note that the converter will not be called if all inputs
958 are loop invariant and the operation is not stateful. The converter should
959 determine if its own output is loop invariant and `wrap` its output
960 accordingly.
962 Example:
964 Here, the converter is trying to convert a Reshape node in the loop body. This
965 node will have two inputs: the tensor to reshape, and the new shape. The
966 example here only handles the case where the shape is loop invariant.
968 @RegisterPFor("Reshape")
969 def _convert_reshape(pfor_input):
970 # We assume that input is not loop invariant. Call to `stacked_input`
971 # asserts that and returns the converted value. This value will have a rank
972 # larger by 1 compared to the rank of the input in the loop body.
973 t = pfor_input.stacked_input(0)
975 # We assume that shape input is loop invariant. Call to `unstacked_input`
976 # asserts that and returns the converted value.
977 shape = pfor_input.unstacked_input(1)
979 # We compute `new_shape` by prepending the number of iterations to the
980 # original shape.
981 new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape],
982 axis=0)
984 # The vectorized output involves reshaping the converted input `t` using
985 # `new_shape`.
986 new_output = array_ops.reshape(t, new_shape)
988 # The converted output is marked as not loop invariant using the call to
989 # wrap.
990 return wrap(new_output, True)
991 """
993 def __init__(self, op_type):
994 """Creates an object to register a converter for op with type `op_type`."""
995 self.op_type = op_type
997 def __call__(self, converter):
998 name = self.op_type
999 assert name not in _pfor_converter_registry, "Re-registering %s " % name
1000 _pfor_converter_registry[name] = converter
1001 return converter
1004class RegisterPForWithArgs(RegisterPFor):
1005 """Utility to register converters for pfor.
1007 Usage:
1008 @RegisterPForWithArgs(foo_op_type, foo=value, ....)
1009 def _foo_converter(pfor_input, foo=None, ....):
1010 ...
1012 See RegisterPFor for details on the conversion function.
1013 `RegisterPForWithArgs` allows binding extra arguments to the
1014 conversion function at registration time.
1015 """
1017 def __init__(self, op_type, *args, **kw_args):
1018 super(RegisterPForWithArgs, self).__init__(op_type)
1019 self._args = args
1020 self._kw_args = kw_args
1022 def __call__(self, converter):
1024 def _f(pfor_input):
1025 return converter(pfor_input, self.op_type, *self._args, **self._kw_args)
1027 super(RegisterPForWithArgs, self).__call__(_f)
1028 return converter
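# Hedged illustration (added; the registrations and converter below are
# hypothetical, not taken from this file): RegisterPForWithArgs lets one
# conversion function serve several op types, with the op type and any bound
# arguments forwarded at call time, e.g.
#
#   @RegisterPForWithArgs("ZerosLike", array_ops.zeros_like)
#   @RegisterPForWithArgs("OnesLike", array_ops.ones_like)
#   def _convert_like(pfor_input, op_type, fn):
#     # Stack any loop-invariant input, apply `fn` to the whole stacked tensor,
#     # and wrap the per-iteration result as stacked.
#     pfor_input.stack_inputs()
#     return wrap(fn(pfor_input.stacked_input(0)), True)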
1031# TODO(agarwal): call raw_ops instead of calling these low level routines.
1032def _create_op(op_type, inputs, op_dtypes, attrs=None):
1033 """Utility to create an op."""
1034 op = ops.get_default_graph().create_op(
1035 op_type, inputs, op_dtypes, attrs=attrs, compute_device=True)
1036 flat_attrs = []
1037 # The tape expects an alternating flat list of names and attribute values.
1038 for a in attrs:
1039 flat_attrs.append(str(a))
1040 flat_attrs.append(op.get_attr(str(a)))
1041 execute.record_gradient(op_type, op.inputs, tuple(flat_attrs), op.outputs[:])
1042 return op
1045WrappedTensor = collections.namedtuple("WrappedTensor",
1046 ["t", "is_stacked", "is_sparse_stacked"])
1047"""Wrapper around the result of a Tensor conversion.
1049The additional fields are useful for keeping track of the conversion state as
1050data flows through the ops in the loop body. For every op whose output is a
1051Tensor, its converter should return either a WrappedTensor or a list of
1052WrappedTensors.
1054Args:
1055 t: The converted tensor
1056 is_stacked: True if the tensor is stacked, i.e. represents the results of all
1057 the iterations of the loop, where each row i of the tensor corresponds to
1058 that op's output on iteration i of the loop. False if the tensor is not
1059 stacked, i.e. represents the result of the op for a single iteration of
1060 the loop, where the result does not vary between iterations.
1061 is_sparse_stacked: True if the tensor corresponds to a component tensor
1062 (indices, values, or dense_shape) of a sparse tensor, and has been logically
1063 stacked via a sparse conversion.
1064"""
1067def wrap(tensor, is_stacked=True, is_sparse_stacked=False):
1068 """Helper to create a WrappedTensor object."""
1069 assert isinstance(is_stacked, bool)
1070 assert isinstance(is_sparse_stacked, bool)
1071 assert isinstance(tensor, ops.Tensor)
1072 assert not is_sparse_stacked or is_stacked, ("If the wrapped tensor is "
1073 "stacked via a sparse "
1074 "conversion, it must also be "
1075 "stacked.")
1076 return WrappedTensor(tensor, is_stacked, is_sparse_stacked)
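# Hedged usage sketch (added for illustration): converters signal loop
# invariance through `wrap`. A per-iteration result carries an extra leading
# dimension of size loop_len and is wrapped as stacked; a loop-invariant result
# keeps its original rank and is wrapped as unstacked, e.g.
#
#   return wrap(vectorized_output, True)      # rank is one higher than original
#   return wrap(loop_invariant_output, False)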
1079def _wrap_and_tile_variants(tensor, length):
1080 if tensor.dtype == dtypes.variant:
1081 tensor = _tile_variant_with_length(tensor, length)
1082 return wrap(tensor)
1085def _fallback_converter(pfor_input, root_cause="", warn=False):
1086 msg = ("Using a while_loop for converting "
1087 f"{pfor_input.op_type} cause {root_cause}")
1088 if warn:
1089 logging.warning(msg)
1090 else:
1091 logging.debug(msg)
1092 output_dtypes = [x.dtype for x in pfor_input.outputs]
1093 iter_vec = pfor_input.pfor.loop_len_vector
1094 # Use constant value if available, so that output shapes are static.
1095 iter_vec_value = tensor_util.constant_value(iter_vec)
1096 if iter_vec_value is not None:
1097 iters = iter_vec_value[0].item()
1098 else:
1099 iters = iter_vec[0]
1101 def while_body(i, *ta_list):
1102 """Body of while loop."""
1103 inputs = [
1104 x[i, ...] if stacked else x for x, stacked, _ in pfor_input.inputs
1105 ]
1106 op_outputs = _create_op(
1107 pfor_input.op_type,
1108 inputs,
1109 output_dtypes,
1110 attrs=pfor_input.op.node_def.attr).outputs
1112 outputs = []
1113 # TODO(agarwal): Add tf.debugging asserts to check that the shapes across
1114 # the different iterations are the same.
1115 for out, ta in zip(op_outputs, ta_list):
1116 assert isinstance(out, ops.Tensor)
1117 outputs.append(ta.write(i, out))
1118 return tuple([i + 1] + outputs)
1120 ta_list = while_loop.while_loop(
1121 lambda i, *ta: i < iters, while_body, [0] +
1122 [tensor_array_ops.TensorArray(dtype, iters) for dtype in output_dtypes
1123 ])[1:]
1124 return tuple([wrap(ta.stack(), True) for ta in ta_list])
1127class PForConfig:
1128 """A configuration object used to communicate with loop body function."""
1130 def __init__(self):
1131 # This may be set to the number of iterations.
1132 self._maybe_iters = None
1133 # Map from reduction node, created by `reduce`, to the bundle of reduction
1134 # function and arguments.
1135 self._reduce_map = {}
1137 def _has_reductions(self):
1138 """True if some reductions where performed by loop body."""
1139 return len(self._reduce_map)
1141 def _set_iters(self, iters):
1142 """Set number of pfor iterations."""
1143 if isinstance(iters, ops.Tensor):
1144 iters = tensor_util.constant_value(iters)
1145 self._maybe_iters = iters
1147 def reduce(self, fn, *args):
1148 """Performs reduction `fn` on `args` vectorized across pfor iterations.
1150 Note that `fn` is traced once inside the loop function context. Hence any
1151 captures or side-effects will happen in that context. The call to the traced
1152 version of `fn` happens during the construction of the vectorized code.
1154 Note that this currently may not work inside a control flow construct.
1155 Args:
1156 fn: a reduction function. It will be called with arguments that have the
1157 same structure as *args but with individual values whose rank may be
1158 higher by 1 since they represent loop invariant vectorized versions of
1159 the corresponding Tensors in *args.
1160 *args: unvectorized Tensors.
1162 Returns:
1163 The result of running `fn` on the vectorized versions of `*args`. These
1164 outputs will be available as loop invariant values to all the iterations.
1165 """
1166 assert not context.executing_eagerly()
1167 # Creates a concrete function that will be used for reduction.
1168 tensor_specs = []
1169 for arg in args:
1170 if not isinstance(arg, ops.Tensor):
1171 raise ValueError(f"Got a non-Tensor argument {arg} in reduce.")
1172 batched_shape = tensor_shape.TensorShape([self._maybe_iters
1173 ]).concatenate(arg.shape)
1174 tensor_specs.append(
1175 tensor_spec.TensorSpec(shape=batched_shape, dtype=arg.dtype))
1176 concrete_function = def_function.function(fn).get_concrete_function(
1177 *tensor_specs)
1179 # Creates PlaceholderWithDefault and IdentityN nodes corresponding to the
1180 # reduction.
1181 pl_outputs = []
1182 with ops.control_dependencies(args):
1183 for output in concrete_function.outputs:
1184 if not isinstance(output, ops.Tensor):
1185 raise ValueError(f"Got a non-Tensor output {output} while running "
1186 "reduce.")
1187 # Note that we use placeholder_with_default just to make XLA happy since
1188 # it does not like placeholder ops.
1189 if output.shape.is_fully_defined():
1190 dummy = array_ops.zeros(output.shape.as_list(), dtype=output.dtype)
1191 pl_outputs.append(
1192 array_ops.placeholder_with_default(dummy, shape=output.shape))
1193 else:
1194 # TODO(agarwal): support case when under XLA and output.shape is not
1195 # fully defined.
1196 pl_outputs.append(
1197 array_ops.placeholder(output.dtype, shape=output.shape))
1199 reduction_op = array_ops.identity_n(pl_outputs)[0].op
1200 self._reduce_map[reduction_op] = (concrete_function, args)
1201 if len(reduction_op.outputs) == 1:
1202 return reduction_op.outputs[0]
1203 else:
1204 return tuple(reduction_op.outputs)
1206 # TODO(agarwal): handle reductions inside control flow constructs.
1207 def reduce_concat(self, x):
1208 """Performs a concat reduction on `x` across pfor iterations.
1210 Note that this currently may not work inside a control flow construct.
1211 Args:
1212 x: an unvectorized Tensor.
1214 Returns:
1215 A Tensor that has rank one higher than `x`. The value is the vectorized
1216 version of `x`, i.e. stacking the value of `x` across different pfor
1217 iterations.
1218 """
1219 return self.reduce(lambda y: y, x)
1221 def reduce_mean(self, x):
1222 """Performs a mean reduction on `x` across pfor iterations.
1224 Note that this currently may not work inside a control flow construct.
1225 Args:
1226 x: an unvectorized Tensor.
1228 Returns:
1229 A Tensor that has same rank as `x`. The value is the mean of the values
1230 of `x` across the pfor iterations.
1231 """
1232 return self.reduce(lambda y: math_ops.reduce_mean(y, axis=0), x)
1234 def reduce_sum(self, x):
1235 """Performs a sum reduction on `x` across pfor iterations.
1237 Note that this currently may not work inside a control flow construct.
1238 Args:
1239 x: an unvectorized Tensor.
1241 Returns:
1242 A Tensor that has same rank as `x`. The value is the sum of the values
1243 of `x` across the pfor iterations.
1244 """
1245 return self.reduce(lambda y: math_ops.reduce_sum(y, axis=0), x)
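  # Hedged example (added; `data` and `loop_fn` are illustrative names): when
  # the loop body function accepts a `pfor_config` argument, these reductions
  # combine values across all pfor iterations, e.g. centering each row of a
  # batch by the batch mean inside the vectorized loop:
  #
  #   def loop_fn(i, pfor_config):
  #     x = array_ops.gather(data, i)
  #     return x - pfor_config.reduce_mean(x)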
1247 def _lookup_reduction(self, t):
1248 """Lookups Tensor `t` in the reduction maps."""
1249 assert isinstance(t, ops.Tensor), t
1250 return self._reduce_map.get(t.op)
1253class PFor:
1254 """Implementation of rewrite of parallel-for loops.
1256 This class takes a DAG or a set of DAGs representing the body of a
1257 parallel-for loop, and adds new operations to the graph that implements
1258 functionality equivalent to running that loop body for a specified number of
1259 iterations. This new set of nodes may or may not use a tensorflow loop
1260 construct.
1262 The process of conversion does not delete or change any existing operations.
1263 It only adds operations that efficiently implement the equivalent
1264 functionality. We refer to the added ops as "converted ops".
1266 The conversion process uses a simple greedy heuristic. It walks the loop body
1267 and tries to express the functionality of running each node in a loop with a
1268 new set of nodes. When converting an op several cases are possible:
1269 - The op is not inside the loop body. Hence it can be used as is.
1270 - The op does not depend on the iteration number and is stateless. In this
1271 case, it can be used as is.
1272 - The op is not stateful, and depends on iteration number only through control
1273 dependencies. In this case, we can create a single op with same inputs and
1274 attributes, but with "converted" control dependencies.
1275 - The op is not stateful, and all its inputs are loop invariant. In this
1276 case, similar to above, we can create a single op with same inputs and
1277 attributes, but with "converted" control dependencies.
1278 - The op is stateful or at least one of the inputs is not loop invariant. In
1279 this case, we run the registered converter for that op to create a set of
1280 converted ops. All nodes in the set will have converted control dependencies
1281 corresponding to control dependencies of the original op. If the op returned
1282 multiple outputs, "converted outputs" could be produced by different ops in
1283 this set.
1284 """
1286 def __init__(self,
1287 loop_var,
1288 loop_len,
1289 pfor_ops,
1290 fallback_to_while_loop,
1291 all_indices=None,
1292 all_indices_partitioned=False,
1293 pfor_config=None,
1294 warn=False):
1295 """Creates an object to rewrite a parallel-for loop.
1297 Args:
1298 loop_var: ops.Tensor output of a PlaceholderWithDefault operation. The
1299 value should be an int32 scalar representing the loop iteration number.
1300 loop_len: A scalar or scalar Tensor representing the number of iterations
1301 the loop is run for.
1302 pfor_ops: List of all ops inside the loop body.
1303 fallback_to_while_loop: If True, on failure to vectorize an op, a while
1304 loop is used to sequentially execute that op.
1305 all_indices: If not None, an int32 vector with size `loop_len`
1306 representing the iteration ids that are still active. These values
1307 should be unique and sorted. However they may not be contiguous. This is
1308 typically the case when inside a control flow construct which has
1309 partitioned the indices of the iterations that are being converted.
1310 all_indices_partitioned: If True, this object is being constructed from a
1311 control flow construct where not all the pfor iterations are guaranteed
1312 to be active.
1313 pfor_config: PForConfig object used while constructing the loop body.
1314 warn: Whether or not to warn on while loop conversions.
1315 """
1316 assert isinstance(loop_var, ops.Tensor)
1317 assert loop_var.op.type == "PlaceholderWithDefault"
1318 self._loop_var = loop_var
1319 loop_len_value = tensor_util.constant_value(loop_len)
1320 if loop_len_value is not None:
1321 loop_len = loop_len_value
1322 self._loop_len_vector = ops.convert_to_tensor([loop_len])
1323 else:
1324 self._loop_len_vector = array_ops.reshape(loop_len, [1])
1325 self._all_indices_partitioned = all_indices_partitioned
1326 if all_indices_partitioned:
1327 assert all_indices is not None
1328 self.all_indices = (
1329 math_ops.range(loop_len) if all_indices is None else all_indices)
1331 self._conversion_map = object_identity.ObjectIdentityDictionary()
1332 self._conversion_map[loop_var] = wrap(self.all_indices, True)
1333 self._pfor_ops = set(pfor_ops)
1334 self._pfor_op_ids = set(x._id for x in pfor_ops)
1335 self._fallback_to_while_loop = fallback_to_while_loop
1336 self._warn = warn
1337 self._pfor_config = pfor_config
1339 def op_is_inside_loop(self, op):
1340 """True if op was created inside the pfor loop body."""
1341 assert isinstance(op, ops.Operation)
1342 # Note that we use self._pfor_op_ids for the check and not self._pfor_ops
1343 # since it appears the TensorFlow API could return different Python
1344 # objects representing the same Operation node.
1345 return op._id in self._pfor_op_ids
1347 def _convert_sparse(self, y):
1348 """Returns the converted value corresponding to SparseTensor y.
1350 For SparseTensors, instead of stacking the component tensors separately,
1351 resulting in component tensors with shapes (N, m, rank), (N, m), and (N,
1352 rank) respectively for indices, values, and dense_shape (where N is the loop
1353 length and m is the number of sparse tensor values per loop iter), we want
1354 to logically stack the SparseTensors, to create a SparseTensor whose
1355 components are size (N * m, rank + 1), (N * m, ), and (rank + 1,)
1356 respectively.
1358 Here, we try to get the conversion of each component tensor.
1359 If the tensors are stacked via a sparse conversion, return the resulting
1360 SparseTensor composed of the converted components. Otherwise, the component
1361 tensors are either unstacked or stacked naively. In the latter case, we
1362 unstack the component tensors to reform loop_len SparseTensor elements,
1363 then correctly batch them.
1365 The unstacked tensors must have the same rank. Each dimension of each
1366 SparseTensor will expand to be the largest among all SparseTensor elements
1367 for that dimension. For example, if there are N SparseTensors of rank 3
1368 being stacked, with N dense shapes, where the i_th shape is (x_i, y_i, z_i),
1369 the new dense shape will be (N, max_i(x_i), max_i(y_i), max_i(z_i)).
1371 Args:
1372 y: A tf.sparse.SparseTensor.
1374 Returns:
1375 A tf.sparse.SparseTensor that is the converted value corresponding to y.
1376 """
1377 outputs = [
1378 self._convert_helper(t) for t in (y.indices, y.values, y.dense_shape)
1379 ]
1380 assert all(isinstance(o, WrappedTensor) for o in outputs)
1382 if all(w.is_sparse_stacked for w in outputs):
1383 return sparse_tensor.SparseTensor(*[w.t for w in outputs])
1385 assert not any(w.is_sparse_stacked for w in outputs), (
1386 "Error converting SparseTensor. All components should be logically "
1387 "stacked, or none.")
1389 # If component tensors were not sparsely stacked, they are either unstacked
1390 # or stacked without knowledge that they are components of sparse tensors.
1391 # In this case, we have to restack them.
1392 return self._restack_sparse_tensor_logically(
1393 *[self._unwrap_or_tile(w) for w in outputs])
1395 def _restack_sparse_tensor_logically(self, indices, values, shape):
1396 sparse_tensor_rank = indices.get_shape().dims[-1].value
1397 if sparse_tensor_rank is not None:
1398 sparse_tensor_rank += 1
1400 def fn(args):
1401 res = gen_sparse_ops.serialize_sparse(
1402 args[0], args[1], args[2], out_type=dtypes.variant)
1403 return res
1405 # Applies a map function to the component tensors to serialize each
1406 # sparse tensor element and batch them all, then deserializes the batch.
1407 # TODO(rachelim): Try to do this without map_fn -- add the right offsets
1408 # to shape and indices tensors instead.
1409 result = map_fn.map_fn(fn, [indices, values, shape], dtype=dtypes.variant)
1410 return sparse_ops.deserialize_sparse(
1411 result, dtype=values.dtype, rank=sparse_tensor_rank)
1413 def _unwrap_or_tile(self, wrapped_tensor):
1414 """Given a wrapped tensor, unwrap if stacked. Otherwise, tiles it."""
1415 output, is_stacked = wrapped_tensor.t, wrapped_tensor.is_stacked
1416 if is_stacked:
1417 return output
1418 else:
1419 return _stack(output, self._loop_len_vector).t
1421 def convert(self, y):
1422 """Returns the converted value corresponding to y.
1424 Args:
1425 y: An ops.Tensor or an ops.Operation object. If the latter, y should not
1426 have any outputs.
1428 Returns:
1429 If y does not need to be converted, it returns y as is. Else it returns
1430 the "converted value" corresponding to y.
1431 """
1432 if y is None:
1433 return None
1434 if isinstance(y, sparse_tensor.SparseTensor):
1435 return self._convert_sparse(y)
1436 assert isinstance(y, (ops.Tensor, ops.Operation)), y
1437 output = self._convert_helper(y)
1438 if isinstance(output, WrappedTensor):
1439 assert isinstance(y, ops.Tensor)
1440 return self._unwrap_or_tile(output)
1441 else:
1442 assert isinstance(y, ops.Operation)
1443 assert not y.outputs
1444 assert isinstance(output, ops.Operation)
1445 return output
1447 def _was_converted(self, t):
1448 """True if t is not a conversion of itself."""
1449 converted_t = self._conversion_map[t]
1450 return converted_t.t is not t
1452 def _add_conversion(self, old_output, new_output):
1453 assert isinstance(old_output, (ops.Tensor, ops.Operation)), old_output
1454 assert isinstance(new_output, (WrappedTensor, ops.Operation)), new_output
1455 self._conversion_map[old_output] = new_output
1457 def _convert_reduction(self, y):
1458 # Handle reductions.
1459 if self._pfor_config is None or isinstance(y, ops.Operation):
1460 return None
1461 reduction = self._pfor_config._lookup_reduction(y)
1462 if reduction is None:
1463 return None
1464 (reduction_fn, reduction_args) = reduction
1465 batched_args = []
1466 for reduction_arg in reduction_args:
1467 assert isinstance(reduction_arg, ops.Tensor), reduction_arg
1468 # Tensor being reduced should already be converted due to a control
1469 # dependency on the created placeholder.
1470 # Note that in cases where reduction_arg is in an outer context, one
1471 # needs to locate the corresponding Enter node and use that to lookup
1472 # the conversion.
1473 # TODO(agarwal): handle reductions inside control flow constructs.
1474 assert reduction_arg in self._conversion_map, (
1475 "Unable to handle reduction of %s, possibly as it was used "
1476 "inside a control flow construct. Note that reductions across "
1477 "pfor iterations are currently not supported inside control flow "
1478 "constructs." % reduction_arg)
1479 batched_arg = self._conversion_map[reduction_arg]
1480 batched_args.append(self._unwrap_or_tile(batched_arg))
1481 outputs = reduction_fn(*batched_args)
1482 return [wrap(output, False) for output in nest.flatten(outputs)]
1484 def _convert_helper(self, op_or_tensor):
1485 stack = collections.deque([op_or_tensor])
1486 while stack:
1487 y = stack[0]
1488 if y in self._conversion_map:
1489 assert isinstance(self._conversion_map[y],
1490 (WrappedTensor, ops.Operation))
1491 stack.popleft()
1492 continue
1493 if isinstance(y, ops.Operation):
1494 assert not y.outputs, (
1495 "We only support converting Operation objects with no outputs. "
1496 "Got %s", y)
1497 y_op = y
1498 else:
1499 assert isinstance(y, ops.Tensor), y
1500 y_op = y.op
1502 is_while_loop = y_op.type == "Exit"
1503 if is_while_loop:
1504 while_op = WhileOp(
1505 y, pfor_ops=self._pfor_ops,
1506 fallback_to_while_loop=self.fallback_to_while_loop,
1507 pfor_config=self._pfor_config)
1508 is_inside_loop = while_op.is_inside_loop
1509 # If all nodes in the while_loop graph were created inside the pfor, we
1510 # treat the whole loop subgraph as a single op (y_op) and try to convert
1511 # it. For while_loops that are created completely or partially outside,
1512 # we treat them as external and should be able to simply return the Exit
1513 # node output as is without needing any conversion. Note that for
1514 # while_loops that are partially constructed inside, we assume they will
1515 # be loop invariant. If that is not the case, it will create runtime
1516 # errors since the converted graph would depend on the self._loop_var
1517 # placeholder.
1518 if is_inside_loop:
1519 y_op = while_op
1520 else:
1521 is_inside_loop = self.op_is_inside_loop(y_op)
1523 # If this op was not created inside the loop body, we will return as is.
1524 # 1. Convert inputs and control inputs.
1526 def _add_to_stack(x):
1527 if x not in self._conversion_map:
1528 stack.appendleft(x)
1529 return True
1530 else:
1531 return False
1533 if is_inside_loop:
1534 added_to_stack = False
1535 for inp in y_op.inputs:
1536 added_to_stack |= _add_to_stack(inp)
1537 for cinp in y_op.control_inputs:
1538 if cinp.outputs:
1539 for t in cinp.outputs:
1540 added_to_stack |= _add_to_stack(t)
1541 else:
1542 added_to_stack |= _add_to_stack(cinp)
1543 if added_to_stack:
1544 continue
1546 converted_inputs = [self._conversion_map[inp] for inp in y_op.inputs]
1547 some_input_converted = any(self._was_converted(x) for x in y_op.inputs)
1548 some_input_stacked = any(x.is_stacked for x in converted_inputs)
1550 converted_control_ops = set()
1551 some_control_input_converted = False
1552 for cinp in y_op.control_inputs:
1553 if cinp.outputs:
1554 for t in cinp.outputs:
1555 converted_t = self._conversion_map[t]
1556 if self._was_converted(t):
1557 some_control_input_converted = True
1558 converted_control_ops.add(converted_t.t.op)
1559 else:
1560 converted_cinp = self._conversion_map[cinp]
1561 assert isinstance(converted_cinp, ops.Operation)
1562 if converted_cinp != cinp:
1563 some_control_input_converted = True
1564 converted_control_ops.add(converted_cinp)
1565 converted_control_ops = list(converted_control_ops)
1566 is_stateful = _is_stateful_pfor_op(y_op)
1567 else:
1568 converted_inputs = []
1569 converted_control_ops = []
1570 logging.vlog(3, "converting op:%s\ninputs:%s\ncontrol_inputs:%s", y_op,
1571 converted_inputs, converted_control_ops)
1573 # 2. Convert y_op
1574 # If converting a while_loop, we let the while_loop converter deal with
1575 # placing the control dependencies appropriately.
1576 control_dependencies = [] if is_while_loop else converted_control_ops
1577 with ops.control_dependencies(control_dependencies), ops.name_scope(
1578 y_op.name + "/pfor/"), ops.get_default_graph()._original_op(y_op):
1579 # Op is a placeholder for a reduction.
1580 reduce_output = self._convert_reduction(y)
1581 if reduce_output is not None:
1582 new_outputs = reduce_output
1583 # None of the inputs or control inputs were converted.
1584 elif ((not is_inside_loop or
1585 (not is_stateful and not some_input_converted and
1586 not some_control_input_converted)) and
1587 y.graph == ops.get_default_graph()):
1588 if y is y_op:
1589 assert not isinstance(y_op, WhileOp)
1590 new_outputs = y_op
1591 else:
1592 new_outputs = [wrap(x, False) for x in y_op.outputs]
1593 elif not (is_stateful or is_while_loop or some_input_stacked):
1594 # All inputs are unstacked or unconverted but some control inputs are
1595 # converted.
1596 # TODO(rachelim): Handle the case where some inputs are sparsely
1597 # stacked (i.e. any(x.is_sparse_stacked for x in converted_inputs))
1598 new_op = _create_op(y_op.type, [x.t for x in converted_inputs],
1599 [x.dtype for x in y_op.outputs],
1600 y_op.node_def.attr)
1601 if y is y_op:
1602 new_outputs = new_op
1603 else:
1604 new_outputs = []
1605 for old_output, new_output in zip(y_op.outputs, new_op.outputs):
1606 handle_data_util.copy_handle_data(old_output, new_output)
1607 new_outputs.append(wrap(new_output, False))
1608 else:
1609 # Either some inputs are not loop invariant or op is stateful.
1610 if hasattr(y_op, "pfor_converter"):
1611 converter = y_op.pfor_converter
1612 else:
1613 converter = _pfor_converter_registry.get(y_op.type, None)
1614 if converter is None:
1615 root_cause = (f"there is no registered converter for this op.")
1616 has_variant_outputs = any(x.dtype == dtypes.variant for x in
1617 y_op.outputs)
1618 has_vectorized_variant_inputs = any(
1619 _is_variant_with_internal_stacking(x) for x in
1620 y_op.inputs)
1621 if (self._fallback_to_while_loop and not has_variant_outputs
1622 and not has_vectorized_variant_inputs):
1623 converter = partial(
1624 _fallback_converter, root_cause=root_cause, warn=self._warn)
1625 else:
1626 message = (f"No pfor vectorization defined for {y_op.type}\n"
1627 f"{y_op}\n inputs: {converted_inputs}.")
1628 if not self._fallback_to_while_loop:
1629 message += ("Consider enabling the fallback_to_while_loop "
1630 "option to pfor, which may run slower.")
1631 raise ValueError(message)
1632 # TODO(rachelim): Handle the case where some inputs are sparsely
1633 # stacked. We should only call the converter if it supports handling
1634 # those inputs.
1635 pfor_inputs = _PforInput(self, y_op, converted_inputs)
1636 try:
1637 try:
1638 new_outputs = converter(pfor_inputs)
1639 except ConversionNotImplementedError as e:
1640 has_vectorized_variant_inputs = any(
1641 _is_variant_with_internal_stacking(x) for x in
1642 y_op.inputs)
1643 if (self._fallback_to_while_loop
1644 and not has_vectorized_variant_inputs):
1645 new_outputs = _fallback_converter(
1646 pfor_inputs, root_cause=str(e))
1647 else:
1648 raise ValueError(str(e)).with_traceback(sys.exc_info()[2])
1649 except Exception as e: # pylint: disable=broad-except
1650 logging.error(
1651 f"Got error while pfor was converting op {y_op} with inputs "
1652 f"{y_op.inputs[:]}\n, converted inputs {pfor_inputs.inputs}\n"
1653 f"Here are the pfor conversion stack traces: {e}")
1654 original_op = y_op
1655 while isinstance(original_op, ops.Operation):
1656 logging.error(
1657 "%s\ncreated at:\n %s", original_op,
1658 " ".join(traceback.format_list(original_op.traceback)))
1659 original_op = original_op._original_op
1660 raise
1662 if isinstance(new_outputs, WrappedTensor):
1663 new_outputs = [new_outputs]
1664 assert isinstance(new_outputs,
1665 (list, tuple, ops.Operation)), new_outputs
1666 logging.vlog(2, f"converted {y_op} {new_outputs}")
1668 # Insert into self._conversion_map
1669 if y is y_op:
1670 assert isinstance(new_outputs, ops.Operation)
1671 self._add_conversion(y_op, new_outputs)
1672 else:
1673 assert len(y_op.outputs) == len(new_outputs), (y_op, y_op.outputs,
1674 new_outputs)
1675 for old_output, new_output in zip(y_op.outputs, new_outputs):
1676 assert isinstance(new_output, WrappedTensor), (new_output, y, y_op)
1677 assert old_output.dtype == new_output.t.dtype, (new_output, y, y_op)
1678 # Set shape for converted output.
1679 output_shape = old_output.shape
1680 if not new_output.is_sparse_stacked:
1681 if new_output.is_stacked:
1682 loop_len = tensor_util.constant_value(self.loop_len_vector)
1683 if loop_len is None:
1684 batch_dim = tensor_shape.TensorShape([None])
1685 else:
1686 batch_dim = tensor_shape.TensorShape(loop_len)
1687 output_shape = batch_dim.concatenate(output_shape)
1688 if _is_variant_with_internal_stacking(new_output.t):
1689 new_output.t.set_shape([])
1690 else:
1691 new_output.t.set_shape(output_shape)
1692 self._add_conversion(old_output, new_output)
1693 stack.popleft()
1695 return self._conversion_map[op_or_tensor]
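# --- Illustrative sketch (not part of this module) ---------------------------
# A simplified, hypothetical version of the dependency-first traversal used by
# _convert_helper above: a node is converted only after all of its inputs have
# been converted, using an explicit deque instead of recursion. Assumes an
# acyclic graph; `inputs_of` and `convert_one` are placeholder callables.
import collections

def convert_all(root, inputs_of, convert_one):
  """Converts `root` after converting everything it depends on."""
  conversions = {}
  stack = collections.deque([root])
  while stack:
    node = stack[0]
    if node in conversions:
      stack.popleft()
      continue
    pending = [x for x in inputs_of(node) if x not in conversions]
    if pending:
      # Convert the unconverted inputs first; revisit `node` afterwards.
      stack.extendleft(pending)
      continue
    conversions[node] = convert_one(node, [conversions[x] for x in inputs_of(node)])
    stack.popleft()
  return conversions[root]
# ------------------------------------------------------------------------------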
1697 @property
1698 def loop_len_vector(self):
1699 """Returns a single element vector whose value is number of iterations."""
1700 return self._loop_len_vector
1702 @property
1703 def loop_var(self):
1704 """Returns placeholder loop variable."""
1705 return self._loop_var
1707 @property
1708 def pfor_ops(self):
1709 return self._pfor_ops
1711 @property
1712 def pfor_config(self):
1713 return self._pfor_config
1715 @property
1716 def all_indices_partitioned(self):
1717 """all_indices_partitioned property.
1719 Returns:
1720 True if we are inside a control flow construct and not all pfor iterations
1721 may be active.
1722 """
1723 return self._all_indices_partitioned
1725 @property
1726 def fallback_to_while_loop(self):
1727 return self._fallback_to_while_loop
1730# The code below defines converters for different operations. Please see the
1731# comment on RegisterPFor for how converters should be defined.
1734# image_ops
1737@RegisterPFor("AdjustContrastv2")
1738def _convert_adjust_contrastv2(pfor_input):
1739 images = pfor_input.stacked_input(0)
1740 contrast_factor = pfor_input.unstacked_input(1)
1741 return wrap(gen_image_ops.adjust_contrastv2(images, contrast_factor), True)
1744@RegisterPFor("AdjustHue")
1745def _convert_adjust_hue(pfor_input):
1746 images = pfor_input.stacked_input(0)
1747 delta = pfor_input.unstacked_input(1)
1748 return wrap(gen_image_ops.adjust_hue(images, delta), True)
1751@RegisterPFor("AdjustSaturation")
1752def _convert_adjust_saturation(pfor_input):
1753 images = pfor_input.stacked_input(0)
1754 scale = pfor_input.unstacked_input(1)
1755 return wrap(gen_image_ops.adjust_saturation(images, scale), True)
1758# nn_ops
1761def _flatten_first_two_dims(x):
1762 """Merges first two dimensions."""
1763 old_shape = array_ops.shape(x)
1764 new_shape = array_ops.concat([[-1], old_shape[2:]], axis=0)
1765 return array_ops.reshape(x, new_shape)
1768def _unflatten_first_dim(x, first_dim):
1769 """Splits first dimension into [first_dim, -1]."""
1770 old_shape = array_ops.shape(x)
1771 new_shape = array_ops.concat([first_dim, [-1], old_shape[1:]], axis=0)
1772 return array_ops.reshape(x, new_shape)
1775def _inputs_with_flattening(pfor_input, input_indices):
1776 """Stacks and flattens first dim of inputs at indices `input_indices`."""
1777 if input_indices is None:
1778 input_indices = []
1779 pfor_input.stack_inputs(stack_indices=input_indices)
1780 inputs = []
1781 for i in range(pfor_input.num_inputs):
1782 if i in input_indices:
1783 inp = pfor_input.stacked_input(i)
1784 inp = _flatten_first_two_dims(inp)
1785 else:
1786 inp = pfor_input.unstacked_input(i)
1787 inputs.append(inp)
1788 return inputs
1791@RegisterPForWithArgs("Conv2D", dims=[0])
1792@RegisterPForWithArgs("DepthToSpace", dims=[0])
1793@RegisterPForWithArgs("AvgPool", dims=[0])
1794@RegisterPForWithArgs("AvgPool3D", dims=[0])
1795@RegisterPForWithArgs("MaxPool", dims=[0])
1796@RegisterPForWithArgs("MaxPoolV2", dims=[0])
1797@RegisterPForWithArgs("MaxPool3D", dims=[0])
1798@RegisterPForWithArgs("MaxPool3DGrad", dims=[0, 1, 2])
1799@RegisterPForWithArgs("MaxPoolGrad", dims=[0, 1, 2])
1800@RegisterPForWithArgs("MaxPoolGradV2", dims=[0, 1, 2])
1801@RegisterPForWithArgs("MaxPool3DGradGrad", dims=[0, 1, 2])
1802@RegisterPForWithArgs("MaxPoolGradGrad", dims=[0, 1, 2])
1803@RegisterPForWithArgs("MaxPoolGradGradV2", dims=[0, 1, 2])
1804@RegisterPForWithArgs("SoftmaxCrossEntropyWithLogits", dims=[0, 1])
1805@RegisterPForWithArgs("SparseSoftmaxCrossEntropyWithLogits", dims=[0, 1])
1806@RegisterPForWithArgs("SpaceToDepth", dims=[0])
1807def _convert_flatten_batch(pfor_input, op_type, dims):
1808 del op_type
1809 inputs = _inputs_with_flattening(pfor_input, dims)
1810 outputs = _create_op(
1811 pfor_input.op_type,
1812 inputs, [x.dtype for x in pfor_input.outputs],
1813 attrs=pfor_input.op.node_def.attr).outputs
1814 n = pfor_input.pfor.loop_len_vector
1815 outputs = [_unflatten_first_dim(x, n) for x in outputs]
1816 return [wrap(x, True) for x in outputs]
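# --- Illustrative sketch (not part of this module) ---------------------------
# The flatten/unflatten pattern above, written against the public tf API under
# eager execution: merge the pfor dimension N into the batch dimension, run the
# per-example kernel once on the bigger batch, then split N back out. For a
# per-example op like average pooling this equals running the op separately for
# each of the N iterations.
import tensorflow as tf

x = tf.random.normal([3, 2, 8, 8, 4])                  # [N, B, H, W, C]
flat = tf.reshape(x, [-1, 8, 8, 4])                    # [N * B, H, W, C]
pooled = tf.nn.avg_pool2d(flat, ksize=2, strides=2, padding="VALID")
stacked = tf.reshape(pooled, [3, 2, 4, 4, 4])          # split N back out
per_iteration = tf.stack(
    [tf.nn.avg_pool2d(x[i], ksize=2, strides=2, padding="VALID") for i in range(3)])
tf.debugging.assert_near(stacked, per_iteration)
# ------------------------------------------------------------------------------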
1819_channel_flatten_input_cache = {}
1822@RegisterPFor("BatchToSpaceND")
1823def _convert_batch_to_space_nd(pfor_input):
1824 inp = pfor_input.stacked_input(0)
1825 block_shape = pfor_input.unstacked_input(1)
1826 crops = pfor_input.unstacked_input(2)
1828 inp_shape = array_ops.shape(inp)
1829 n = pfor_input.pfor.loop_len_vector
1831 # Reshape and transpose to move the vectorization axis inside the axes that
1832 # will move to space.
1833 # Reshape to 4D and transpose
1834 block_size = math_ops.reduce_prod(block_shape)
1835 new_shape = [n[0], block_size, inp_shape[1] // block_size, -1]
1836 inp = array_ops.reshape(inp, new_shape)
1837 inp = array_ops.transpose(inp, [1, 0, 2, 3])
1838 # Reshape back to merge the block, vectorization and batch dimension, and
1839 # restore the other dimensions.
1840 new_shape = array_ops.concat([n * inp_shape[1], inp_shape[2:]], axis=0)
1841 inp = array_ops.reshape(inp, new_shape)
1842 # Call batch_to_space and then split the new batch axis.
1843 output = gen_array_ops.batch_to_space_nd(inp, block_shape, crops)
1844 output = _unflatten_first_dim(output, n)
1845 return wrap(output, True)
1848@RegisterPFor("SpaceToBatchND")
1849def _convert_space_to_batch_nd(pfor_input):
1850 inp = pfor_input.stacked_input(0)
1851 block_shape = pfor_input.unstacked_input(1)
1852 paddings = pfor_input.unstacked_input(2)
1854 n = pfor_input.pfor.loop_len_vector
1855 inp_shape = array_ops.shape(inp)
1856 inp = _flatten_first_two_dims(inp)
1857 output = gen_array_ops.space_to_batch_nd(inp, block_shape, paddings)
1858 output_shape = array_ops.shape(output)
1859 block_size = math_ops.reduce_prod(block_shape)
1860 new_shape = [block_size, n[0], -1]
1861 output = array_ops.reshape(output, new_shape)
1862 output = array_ops.transpose(output, [1, 0, 2])
1863 new_shape = array_ops.concat(
1864 [n, block_size * inp_shape[1:2], output_shape[1:]], axis=0)
1865 output = array_ops.reshape(output, new_shape)
1866 return wrap(output, True)
1869def _channel_flatten_input(x, data_format):
1870 """Merge the stack dimension with the channel dimension.
1872 If S is pfor's stacking dimension, then,
1873 - for SNCHW, we transpose to NSCHW. If N dimension has size 1, the transpose
1874 should be cheap.
1875 - for SNHWC, we transpose to NHWSC.
1876 We then merge the S and C dimension.
1878 Args:
1879 x: ops.Tensor to transform.
1880 data_format: "NCHW" or "NHWC".
1882 Returns:
1883 A 3-element tuple of the transformed value, the transpose order, and the
1884 reshape shape required to transform back.
1885 """
1887 graph = ops.get_default_graph()
1888 cache_key = (graph, x.ref(), data_format)
1889 if cache_key not in _channel_flatten_input_cache:
1890 x_shape = array_ops.shape(x)
1891 if data_format == b"NCHW":
1892 order = [1, 0, 2, 3, 4]
1893 shape = array_ops.concat([x_shape[1:2], [-1], x_shape[3:]], axis=0)
1894 reverse_order = order
1895 else:
1896 order = [1, 2, 3, 0, 4]
1897 shape = array_ops.concat([x_shape[1:4], [-1]], axis=0)
1898 reverse_order = [3, 0, 1, 2, 4]
1899 # Move S dimension next to C dimension.
1900 x = array_ops.transpose(x, order)
1901 reverse_shape = array_ops.shape(x)
1902 # Reshape to merge the S and C dimension.
1903 x = array_ops.reshape(x, shape)
1904 outputs = x, reverse_order, reverse_shape
1905 _channel_flatten_input_cache[cache_key] = outputs
1906 else:
1907 outputs = _channel_flatten_input_cache[cache_key]
1908 return outputs
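# --- Illustrative sketch (not part of this module) ---------------------------
# The SNHWC round trip performed by _channel_flatten_input, written with the
# public tf API on a concrete shape: move the pfor stack dimension S next to C,
# merge S and C, and later undo both steps.
import tensorflow as tf

x = tf.zeros([5, 2, 8, 8, 3])                 # [S, N, H, W, C]
y = tf.transpose(x, [1, 2, 3, 0, 4])          # SNHWC -> NHWSC
y = tf.reshape(y, [2, 8, 8, -1])              # merge S and C: [N, H, W, S * C]
# Transform back: reshape to NHWSC, then move S to the front again.
z = tf.transpose(tf.reshape(y, [2, 8, 8, 5, 3]), [3, 0, 1, 2, 4])
print(z.shape)                                # (5, 2, 8, 8, 3)
# ------------------------------------------------------------------------------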
1911# Note that with training=True, running FusedBatchNormV3 on individual examples
1912# is very different from running FusedBatchNormV3 on a batch of those examples.
1913# This is because, for the latter case, the operation can be considered as first
1914# computing the mean and variance over all the examples and then using these
1915# to scale all those examples. This creates a data dependency between these
1916# different "iterations" since the inputs to the scaling step depends on the
1917# statistics coming from all these inputs.
1918# As with other kernels, the conversion here effectively runs the kernel
1919# independently for each iteration, and returns outputs by stacking outputs from
1920# each of those iterations.
1921@RegisterPFor("FusedBatchNormV3")
1922def _convert_fused_batch_norm(pfor_input):
1923 is_training = pfor_input.get_attr("is_training")
1924 # When BatchNorm is used with training=False, mean and variance are provided
1925 # externally and used as is by the op. Thus, we can merge the S and N
1926 # dimensions as we do for regular operations.
1927 # When BatchNorm is used with training=True, mean and variance are computed
1928 # for each channel across the batch dimension (first one). If we merge S and N
1929 # dimensions, mean and variances will be computed over a larger set. So, we
1930 # merge the S and C dimensions instead.
1931 if not is_training:
1932 # We return zeros for batch_mean and batch_variance output. Note that CPU
1933 # and GPU seem to have different behavior for those two outputs. CPU outputs
1934 # zero because these values are not used during inference. GPU outputs
1935 # something, probably real means and variances.
1936 inputs = _inputs_with_flattening(pfor_input, [0])
1937 outputs = _create_op(
1938 pfor_input.op_type,
1939 inputs, [x.dtype for x in pfor_input.outputs],
1940 attrs=pfor_input.op.node_def.attr).outputs
1941 y = outputs[0]
1942 n = pfor_input.pfor.loop_len_vector
1943 y = _unflatten_first_dim(y, n)
1944 mean = pfor_input.unstacked_input(3)
1945 zeros = array_ops.zeros_like(mean)
1946 return [wrap(y, True)] + [wrap(zeros, False)] * 5
1948 pfor_input.stack_inputs()
1949 data_format = pfor_input.get_attr("data_format")
1950 # We merge the first dimension with the "C" dimension, run FusedBatchNormV3,
1951 # and then transpose back.
1952 x = pfor_input.stacked_input(0)
1953 x, reverse_order, reverse_shape = _channel_flatten_input(x, data_format)
1954 # Note that we stack all the other inputs as well so that they are the same
1955 # size as the new size of the channel dimension.
1956 inputs = [x] + [
1957 array_ops.reshape(pfor_input.stacked_input(i), [-1])
1958 for i in range(1, pfor_input.num_inputs)
1959 ]
1960 outputs = _create_op(
1961 pfor_input.op_type,
1962 inputs, [x.dtype for x in pfor_input.outputs],
1963 attrs=pfor_input.op.node_def.attr).outputs
1964 y = outputs[0]
1965 y = array_ops.reshape(y, reverse_shape)
1966 y = array_ops.transpose(y, reverse_order)
1967 n = pfor_input.pfor.loop_len_vector
1968 outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]]
1969 outputs = [y] + outputs
1970 return [wrap(x, True) for x in outputs]
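# --- Illustrative sketch (not part of this module) ---------------------------
# A numeric illustration of why training-mode batch norm cannot simply merge
# the pfor and batch dimensions: per-iteration statistics differ from the
# statistics of the merged batch. Uses tf.nn.moments for brevity.
import tensorflow as tf

x = tf.constant([[1.0, 2.0], [3.0, 7.0]])     # two pfor "iterations"
mean_per_iter, _ = tf.nn.moments(x, axes=[1])                   # [1.5, 5.0]
mean_merged, _ = tf.nn.moments(tf.reshape(x, [-1]), axes=[0])   # 3.25
# Vectorization must reproduce the per-iteration values, so the converter above
# merges the pfor dimension with channels (which are normalized independently)
# rather than with the batch dimension.
# ------------------------------------------------------------------------------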
1973@RegisterPFor("FusedBatchNormGradV3")
1974def _convert_fused_batch_norm_grad(pfor_input):
1975 pfor_input.stack_inputs()
1976 data_format = pfor_input.get_attr("data_format")
1977 y_backprop = pfor_input.stacked_input(0)
1978 y_backprop, _, _ = _channel_flatten_input(y_backprop, data_format)
1979 x = pfor_input.stacked_input(1)
1980 x, x_reverse_order, x_reverse_shape = _channel_flatten_input(x, data_format)
1981 inputs = [y_backprop, x] + [
1982 array_ops.reshape(pfor_input.stacked_input(i), [-1])
1983 for i in range(2, pfor_input.num_inputs)
1984 ]
1985 outputs = _create_op(
1986 pfor_input.op_type,
1987 inputs, [x.dtype for x in pfor_input.outputs],
1988 attrs=pfor_input.op.node_def.attr).outputs
1989 x_backprop = outputs[0]
1990 x_backprop = array_ops.reshape(x_backprop, x_reverse_shape)
1991 x_backprop = array_ops.transpose(x_backprop, x_reverse_order)
1992 n = pfor_input.pfor.loop_len_vector
1993 outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]]
1994 outputs = [x_backprop] + outputs
1995 return [wrap(output, True) for output in outputs]
1998@RegisterPForWithArgs("Conv2DBackpropInput", flatten_dims=[2], shape_dim=0)
1999@RegisterPForWithArgs("AvgPoolGrad", flatten_dims=[1], shape_dim=0)
2000@RegisterPForWithArgs("AvgPool3DGrad", flatten_dims=[1], shape_dim=0)
2001def _convert_flatten_batch_shape_input(pfor_input, op_type, flatten_dims,
2002 shape_dim):
2003 del op_type
2004 inputs = _inputs_with_flattening(pfor_input, flatten_dims)
2005 n = pfor_input.pfor.loop_len_vector
2006 # Adjust the `input_sizes` input.
2007 ones = array_ops.ones([array_ops.shape(inputs[shape_dim])[0] - 1],
2008 dtype=n.dtype)
2009 inputs[shape_dim] *= array_ops.concat([n, ones], axis=0)
2010 outputs = _create_op(
2011 pfor_input.op_type,
2012 inputs, [x.dtype for x in pfor_input.outputs],
2013 attrs=pfor_input.op.node_def.attr).outputs
2014 outputs = [_unflatten_first_dim(x, n) for x in outputs]
2015 return [wrap(x, True) for x in outputs]
2018@RegisterPFor("Conv2DBackpropFilter")
2019def _convert_conv2d_backprop_filter(pfor_input):
2020 pfor_input.stack_inputs(stack_indices=[2])
2021 inputs, inputs_stacked, _ = pfor_input.input(0)
2022 filter_sizes = pfor_input.unstacked_input(1)
2023 grads = pfor_input.stacked_input(2)
2024 strides = pfor_input.get_attr("strides")
2025 padding = pfor_input.get_attr("padding")
2026 use_cudnn_on_gpu = pfor_input.get_attr("use_cudnn_on_gpu")
2027 data_format = pfor_input.get_attr("data_format")
2028 dilations = pfor_input.get_attr("dilations")
2029 if inputs_stacked:
2030 # TODO(agarwal): Implement this efficiently.
2031 logging.warning("Conv2DBackpropFilter uses a while_loop. Fix that!")
2033 def while_body(i, ta):
2034 inp_i = inputs[i, ...]
2035 grad_i = grads[i, ...]
2036 output = nn_ops.conv2d_backprop_filter(
2037 inp_i,
2038 filter_sizes,
2039 grad_i,
2040 strides=strides,
2041 padding=padding,
2042 use_cudnn_on_gpu=use_cudnn_on_gpu,
2043 data_format=data_format,
2044 dilations=dilations)
2045 return i + 1, ta.write(i, output)
2047 n = array_ops.reshape(pfor_input.pfor.loop_len_vector, [])
2048 _, ta = while_loop.while_loop(
2049 lambda i, ta: i < n, while_body,
2050 (0, tensor_array_ops.TensorArray(inputs.dtype, n)))
2051 output = ta.stack()
2052 return wrap(output, True)
2053 else:
2054 # We merge the stack dimension with the channel dimension of the gradients
2055 # and pretend we had a larger filter (see change to filter_sizes below).
2056 # Once the filter backprop is computed, we reshape and transpose back
2057 # appropriately.
2058 grads, _, _ = _channel_flatten_input(grads, data_format)
2059 n = pfor_input.pfor.loop_len_vector
2060 old_filter_sizes = filter_sizes
2061 filter_sizes *= array_ops.concat([[1, 1, 1], n], axis=0)
2062 output = nn_ops.conv2d_backprop_filter(
2063 inputs,
2064 filter_sizes,
2065 grads,
2066 strides=strides,
2067 padding=padding,
2068 use_cudnn_on_gpu=use_cudnn_on_gpu,
2069 data_format=data_format,
2070 dilations=dilations)
2071 new_filter_shape = array_ops.concat([old_filter_sizes[:3], n, [-1]], axis=0)
2072 output = array_ops.reshape(output, new_filter_shape)
2073 output = array_ops.transpose(output, [3, 0, 1, 2, 4])
2074 return wrap(output, True)
2077def _flatten_with_inner_dim(x, dim, x_rank):
2078 """Merges the first dim with the specified dim."""
2079 shape = array_ops.shape(x)
2080 x = array_ops.transpose(x,
2081 list(range(1, dim)) + [0] + list(range(dim, x_rank)))
2083 if dim < x_rank - 1:
2084 new_shape_pieces = [shape[1:dim], [-1], shape[dim + 1:]]
2085 else:
2086 new_shape_pieces = [shape[1:dim], [-1]]
2087 new_shape = array_ops.concat(new_shape_pieces, axis=0)
2088 return array_ops.reshape(x, new_shape)
2091def _unflatten_with_inner_dim(x, dim, x_rank, stack_size):
2092 """Undoes _flatten_with_inner_dim."""
2093 shape = array_ops.shape(x)
2094 if dim < x_rank - 1:
2095 new_shape_pieces = [shape[:dim], [stack_size], [-1], shape[dim + 1:]]
2096 else:
2097 new_shape_pieces = [shape[:dim], [stack_size], [-1]]
2098 new_shape = array_ops.concat(new_shape_pieces, axis=0)
2099 x = array_ops.reshape(x, new_shape)
2100 dims_permutation = [dim] + list(range(dim)) + list(range(dim + 1, x_rank + 1))
2101 return array_ops.transpose(x, dims_permutation)
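# --- Illustrative sketch (not part of this module) ---------------------------
# Shape round trip for the two helpers above, on a 5-D NHWC input whose leading
# dimension is the pfor stack dimension (N is the stack size, B the batch):
#   _flatten_with_inner_dim(x, dim=4, x_rank=5):   [N, B, H, W, C] -> [B, H, W, N * C]
#   _unflatten_with_inner_dim(y, dim=3, x_rank=4, stack_size=N): back to [N, B, H, W, C]
# Written against the public tf API for a concrete shape.
import tensorflow as tf

x = tf.zeros([3, 2, 8, 8, 4])                                          # [N, B, H, W, C]
merged = tf.reshape(tf.transpose(x, [1, 2, 3, 0, 4]), [2, 8, 8, -1])   # [B, H, W, N * C]
restored = tf.transpose(tf.reshape(merged, [2, 8, 8, 3, -1]), [3, 0, 1, 2, 4])
print(restored.shape)                                                  # (3, 2, 8, 8, 4)
# ------------------------------------------------------------------------------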
2104@RegisterPFor("DepthwiseConv2dNative")
2105def _convert_depthwise_conv2d_native(pfor_input):
2106 # The kernel can be vectorized, so folding into the batch dimension does not
2107 # work. We instead fold into the channel dimension, which depthwise conv treats independently.
2108 stack_size = pfor_input.pfor.loop_len_vector[0]
2109 data_format = pfor_input.get_attr("data_format")
2110 c_dim = 1 if data_format == b"NCHW" else 3
2111 t = _flatten_with_inner_dim(pfor_input.stacked_input(0), c_dim + 1, 5)
2112 kernel = _flatten_with_inner_dim(pfor_input.stacked_input(1), 3, 5)
2113 conv = _create_op(
2114 "DepthwiseConv2dNative", [t, kernel],
2115 [x.dtype for x in pfor_input.outputs],
2116 attrs=pfor_input.op.node_def.attr).outputs[0]
2117 return wrap(_unflatten_with_inner_dim(conv, c_dim, 4, stack_size), True)
2120@RegisterPFor("DepthwiseConv2dNativeBackpropInput")
2121def _convert_depthwise_conv2d_native_backprop_input(pfor_input):
2122 stack_size = pfor_input.pfor.loop_len_vector[0]
2123 input_sizes = pfor_input.unstacked_input(0)
2124 data_format = pfor_input.get_attr("data_format")
2125 c_dim = 1 if data_format == b"NCHW" else 3
2126 input_sizes_multipliers = [
2127 constant_op.constant([1] * c_dim, dtype=dtypes.int32), [stack_size]
2128 ]
2129 if c_dim < 3:
2130 input_sizes_multipliers += [
2131 constant_op.constant([1] * (3 - c_dim), dtype=dtypes.int32)
2132 ]
2133 input_sizes *= array_ops.concat(input_sizes_multipliers, axis=0)
2134 kernel = _flatten_with_inner_dim(pfor_input.stacked_input(1), 3, 5)
2135 out_backprop = _flatten_with_inner_dim(
2136 pfor_input.stacked_input(2), c_dim + 1, 5)
2137 result = _create_op(
2138 "DepthwiseConv2dNativeBackpropInput", [input_sizes, kernel, out_backprop],
2139 [x.dtype for x in pfor_input.outputs],
2140 attrs=pfor_input.op.node_def.attr).outputs[0]
2141 return wrap(_unflatten_with_inner_dim(result, c_dim, 4, stack_size), True)
2144@RegisterPFor("DepthwiseConv2dNativeBackpropFilter")
2145def _convert_depthwise_conv2d_native_backprop_filter(pfor_input):
2146 stack_size = pfor_input.pfor.loop_len_vector[0]
2147 data_format = pfor_input.get_attr("data_format")
2148 c_dim = 1 if data_format == b"NCHW" else 3
2149 inputs = _flatten_with_inner_dim(pfor_input.stacked_input(0), c_dim + 1, 5)
2150 filter_sizes = pfor_input.unstacked_input(1)
2151 filter_sizes_multipliers = [
2152 constant_op.constant([1, 1], dtype=dtypes.int32), [stack_size],
2153 constant_op.constant([1], dtype=dtypes.int32)
2154 ]
2155 filter_sizes *= array_ops.concat(filter_sizes_multipliers, axis=0)
2156 out_backprop = _flatten_with_inner_dim(
2157 pfor_input.stacked_input(2), c_dim + 1, 5)
2158 result = _create_op(
2159 "DepthwiseConv2dNativeBackpropFilter",
2160 [inputs, filter_sizes, out_backprop],
2161 [x.dtype for x in pfor_input.outputs],
2162 attrs=pfor_input.op.node_def.attr).outputs[0]
2163 return wrap(_unflatten_with_inner_dim(result, 2, 4, stack_size), True)
2166@RegisterPForWithArgs("LogSoftmax", gen_nn_ops.log_softmax)
2167@RegisterPForWithArgs("Softmax", gen_nn_ops.softmax)
2168def _convert_softmax(pfor_input, op_type, op_func):
2169 del op_type
2170 return wrap(op_func(pfor_input.stacked_input(0)), True)
2173# array_ops
2176@RegisterPForWithArgs("Identity", array_ops.identity)
2177@RegisterPForWithArgs("StopGradient", array_ops.stop_gradient)
2178@RegisterPForWithArgs("MatrixDiag", array_ops.matrix_diag)
2179@RegisterPForWithArgs("MatrixDiagPart", array_ops.matrix_diag_part)
2180@RegisterPForWithArgs("_EagerConst", array_ops.identity)
2181def _convert_identity(pfor_input, op_type, op_func):
2182 del op_type
2183 return wrap(op_func(*[x.t for x in pfor_input.inputs]), True)
2186@RegisterPFor("IdentityN")
2187def _convert_identity_n(pfor_input):
2188 outputs = array_ops.identity_n([x.t for x in pfor_input.inputs])
2189 return [
2190 wrap(out, inp.is_stacked) for out, inp in zip(outputs, pfor_input.inputs)
2191 ]
2194@RegisterPFor("Reshape")
2195def _convert_reshape(pfor_input):
2196 t = pfor_input.stacked_input(0)
2197 shape = pfor_input.unstacked_input(1)
2198 new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0)
2199 return wrap(array_ops.reshape(t, new_shape), True)
2202@RegisterPFor("Fill")
2203def _convert_fill(pfor_input):
2204 dims = pfor_input.unstacked_input(0)
2205 value = pfor_input.stacked_input(1)
2206 # Expand the rank of `value`
2207 new_shape = array_ops.concat(
2208 [[-1], array_ops.ones([array_ops.size(dims)], dtype=dtypes.int32)],
2209 axis=0)
2210 value = array_ops.reshape(value, new_shape)
2211 # Compute the new output shape
2212 new_dims = array_ops.concat([pfor_input.pfor.loop_len_vector, dims], axis=0)
2213 # Broadcast
2214 return wrap(array_ops.broadcast_to(value, new_dims), True)
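# --- Illustrative sketch (not part of this module) ---------------------------
# The rank-expand-then-broadcast trick used for Fill above, with the public tf
# API: one fill value per pfor iteration becomes a stacked [N, *dims] result.
import tensorflow as tf

values = tf.constant([1.0, 2.0, 3.0])         # stacked scalar `value`, N = 3
dims = tf.constant([2, 2])
value = tf.reshape(values, [-1, 1, 1])        # expand rank to match `dims`
out = tf.broadcast_to(value, tf.concat([[3], dims], axis=0))   # shape [3, 2, 2]
# out[i] == tf.fill(dims, values[i]) for each iteration i.
# ------------------------------------------------------------------------------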
2217@RegisterPFor("BroadcastTo")
2218def _convert_broadcast_to(pfor_input):
2219 t = pfor_input.stacked_input(0)
2220 shape = pfor_input.unstacked_input(1)
2221 new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0)
2223 # Expand dims of stacked t to broadcast against the new shape.
2224 # TODO(davmre): consider factoring out common code with
2225 # `expanddim_inputs_for_broadcast`, which has similar logic but with
2226 # implicit shapes (of input Tensors) rather than explicit shapes.
2227 rank_diff = array_ops.shape(new_shape)[0] - array_ops.rank(t)
2228 ones = array_ops.tile([1], array_ops.reshape(rank_diff, [1]))
2229 t_shape = array_ops.shape(t)
2230 t_expanded_shape = array_ops.concat([t_shape[:1], ones, t_shape[1:]], axis=0)
2232 return wrap(
2233 array_ops.broadcast_to(array_ops.reshape(t, t_expanded_shape), new_shape),
2234 True)
2237@RegisterPFor("ExpandDims")
2238def _convert_expanddims(pfor_input):
2239 t = pfor_input.stacked_input(0)
2240 dim = pfor_input.unstacked_input(1)
2241 dim += math_ops.cast(dim >= 0, dim.dtype)
2242 return wrap(array_ops.expand_dims(t, axis=dim), True)
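# --- Illustrative sketch (not part of this module) ---------------------------
# The axis-shift idiom used above (and by several converters below): after
# stacking, a new leading loop dimension exists, so non-negative axes must be
# shifted by one while negative axes, which count from the end, stay unchanged.
import tensorflow as tf

stacked = tf.zeros([4, 2, 3])                 # N = 4 copies of a [2, 3] tensor
axis = tf.constant(1)                         # per-iteration axis
axis += tf.cast(axis >= 0, axis.dtype)        # becomes 2 in the stacked tensor
out = tf.expand_dims(stacked, axis=axis)      # shape [4, 2, 1, 3]
# A negative axis such as -1 would be left as is and still give [4, 2, 3, 1].
# ------------------------------------------------------------------------------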
2245@RegisterPForWithArgs("LowerBound", gen_array_ops.lower_bound)
2246@RegisterPForWithArgs("UpperBound", gen_array_ops.upper_bound)
2247def _convert_searchsorted(pfor_input, _, op_func):
2248 pfor_input.stack_inputs()
2249 sorted_inputs = _flatten_first_two_dims(pfor_input.stacked_input(0))
2250 values = _flatten_first_two_dims(pfor_input.stacked_input(1))
2251 out_type = pfor_input.get_attr("out_type")
2252 output = op_func(sorted_inputs, values, out_type)
2253 return wrap(
2254 _unflatten_first_dim(output, pfor_input.pfor.loop_len_vector), True)
2257@RegisterPFor("MatrixBandPart")
2258def _convert_matrix_band_part(pfor_input):
2259 t = pfor_input.stacked_input(0)
2260 num_lower = pfor_input.unstacked_input(1)
2261 num_upper = pfor_input.unstacked_input(2)
2262 return wrap(
2263 array_ops.matrix_band_part(t, num_lower=num_lower, num_upper=num_upper),
2264 True)
2267@RegisterPFor("MatrixSetDiag")
2268def _convert_matrix_set_diag(pfor_input):
2269 pfor_input.stack_inputs()
2270 t = pfor_input.stacked_input(0)
2271 diag = pfor_input.stacked_input(1)
2272 return wrap(array_ops.matrix_set_diag(t, diag), True)
2275# Registrations for Matrix{Diag,DiagPart,SetDiag}V2-3.
2276# The input orders defined in the OpKernel and the actual python API are
2277# different (for compatibility with V1), so we cannot use _convert_identity.
2278# v2 is not compatible with v3 and is never exposed on the public API.
2279@RegisterPFor("MatrixDiagV2")
2280@RegisterPFor("MatrixDiagV3")
2281def _convert_matrix_diag_v2(pfor_input):
2282 params = {
2283 "diagonal": pfor_input.stacked_input(0),
2284 "k": pfor_input.unstacked_input(1),
2285 "num_rows": pfor_input.unstacked_input(2),
2286 "num_cols": pfor_input.unstacked_input(3),
2287 "padding_value": pfor_input.unstacked_input(4)
2288 }
2289 if pfor_input.op_type == "MatrixDiagV2":
2290 return wrap(array_ops.matrix_diag_v2(**params), True)
2291 params["align"] = pfor_input.get_attr("align")
2292 return wrap(array_ops.matrix_diag(**params), True)
2295@RegisterPFor("Diag")
2296def _convert_diag(pfor_input):
2297 diag = pfor_input.stacked_input(0)
2298 if diag.shape.ndims == 2:
2299 # We can use matrix_diag.
2300 return wrap(array_ops.matrix_diag(diag), True)
2301 else:
2302 # It is not clear if we can do better than a while loop here with existing
2303 # kernels.
2304 return _fallback_converter(pfor_input, warn=False)
2307# See notes for MatrixDiagV2
2308@RegisterPFor("MatrixDiagPartV2")
2309@RegisterPFor("MatrixDiagPartV3")
2310def _convert_matrix_diag_part_v2(pfor_input):
2311 params = {
2312 "input": pfor_input.stacked_input(0),
2313 "k": pfor_input.unstacked_input(1),
2314 "padding_value": pfor_input.unstacked_input(2)
2315 }
2316 if pfor_input.op_type == "MatrixDiagPartV2":
2317 return wrap(array_ops.matrix_diag_part_v2(**params), True)
2318 params["align"] = pfor_input.get_attr("align")
2319 return wrap(array_ops.matrix_diag_part(**params), True)
2322# See notes for MatrixDiagV2
2323@RegisterPFor("MatrixSetDiagV2")
2324@RegisterPFor("MatrixSetDiagV3")
2325def _convert_matrix_set_diag_v2(pfor_input):
2326 pfor_input.stack_inputs([0, 1])
2327 params = {
2328 "input": pfor_input.stacked_input(0),
2329 "diagonal": pfor_input.stacked_input(1),
2330 "k": pfor_input.unstacked_input(2)
2331 }
2332 if pfor_input.op_type == "MatrixSetDiagV2":
2333 return wrap(array_ops.matrix_set_diag_v2(**params), True)
2334 params["align"] = pfor_input.get_attr("align")
2335 return wrap(array_ops.matrix_set_diag(**params), True)
2338@RegisterPFor("DiagPart")
2339def _convert_diag_part(pfor_input):
2340 inp = pfor_input.stacked_input(0)
2341 if inp.shape.ndims == 3:
2342 # We can use matrix_diag_part.
2343 return wrap(array_ops.matrix_diag_part(inp), True)
2344 else:
2345 # It is not clear if we can do better than a while loop here with existing
2346 # kernels.
2347 return _fallback_converter(pfor_input, warn=False)
2350@RegisterPFor("OneHot")
2351def _convert_one_hot(pfor_input):
2352 indices = pfor_input.stacked_input(0)
2353 depth = pfor_input.unstacked_input(1)
2354 on_value = pfor_input.unstacked_input(2)
2355 off_value = pfor_input.unstacked_input(3)
2356 axis = pfor_input.get_attr("axis")
2357 if axis >= 0:
2358 axis += 1
2359 return wrap(
2360 array_ops.one_hot(indices, depth, on_value, off_value, axis), True)
2363@RegisterPFor("Slice")
2364def _convert_slice(pfor_input):
2365 t = pfor_input.stacked_input(0)
2366 begin, begin_stacked, _ = pfor_input.input(1)
2367 size = pfor_input.unstacked_input(2)
2368 if not begin_stacked:
2369 begin = array_ops.concat([[0], begin], axis=0)
2370 size = array_ops.concat([[-1], size], axis=0)
2371 return wrap(array_ops.slice(t, begin, size), True)
2372 else:
2373 # Handle negative sizes.
2374 #
2375 # If the `begin` entry corresponding to a negative `size` is loop-variant,
2376 # the output would be ragged. This case is not supported. But `size` having
2377 # some negative values and some loop-variant `begin`s is OK (and it's hard
2378 # to tell the difference statically).
2379 original_unstacked_shape = _stack(
2380 array_ops.shape(t)[1:], pfor_input.pfor.loop_len_vector).t
2381 broadcast_size = _stack(size, pfor_input.pfor.loop_len_vector).t
2382 result_shape = array_ops.where(
2383 math_ops.less(broadcast_size, 0),
2384 original_unstacked_shape - begin + broadcast_size + 1, broadcast_size)
2385 result_shape = math_ops.cast(math_ops.reduce_max(result_shape, axis=0),
2386 dtypes.int64)
2388 # Now we enumerate points in the sliced region for each pfor iteration and
2389 # gather them.
2390 cumsize = math_ops.cumprod(result_shape, exclusive=True, reverse=True)
2391 result_num_elements = math_ops.reduce_prod(result_shape)
2392 # Offsets are loop-variant. We first compute loop-invariant gather
2393 # coordinates, then broadcast-add the loop-variant `begin` offsets.
2394 result_base_coordinates = (
2395 math_ops.range(result_num_elements, dtype=dtypes.int64)[:, None]
2396 // cumsize[None, :]) % result_shape[None, :]
2397 result_coordinates = (
2398 begin[:, None, :]
2399 + math_ops.cast(result_base_coordinates, begin.dtype)[None, :, :])
2400 result_flat = array_ops.gather_nd(params=t, indices=result_coordinates,
2401 batch_dims=1)
2402 result_stacked_shape = array_ops.concat(
2403 [math_ops.cast(pfor_input.pfor.loop_len_vector, result_shape.dtype),
2404 result_shape],
2405 axis=0)
2406 return wrap(array_ops.reshape(result_flat, result_stacked_shape), True)
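# --- Illustrative sketch (not part of this module) ---------------------------
# The coordinate-enumeration trick used above for loop-variant `begin`: the
# cumulative product of the result shape (exclusive, reversed) gives each
# dimension's stride, so integer division and modulo turn flat indices into
# per-dimension coordinates, which can then be offset per iteration and
# gathered. Public tf API:
import tensorflow as tf

result_shape = tf.constant([2, 3], dtype=tf.int64)
cumsize = tf.math.cumprod(result_shape, exclusive=True, reverse=True)     # [3, 1]
num_elements = tf.reduce_prod(result_shape)                               # 6
coords = (tf.range(num_elements, dtype=tf.int64)[:, None]
          // cumsize[None, :]) % result_shape[None, :]
# coords == [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]]; adding each
# iteration's `begin` and calling gather_nd(..., batch_dims=1) then extracts
# every iteration's (possibly different) slice in one op.
# ------------------------------------------------------------------------------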
2409@RegisterPFor("Tile")
2410def _convert_tile(pfor_input):
2411 t = pfor_input.stacked_input(0)
2412 multiples = pfor_input.unstacked_input(1)
2413 multiples = array_ops.concat([[1], multiples], 0)
2414 return wrap(array_ops.tile(t, multiples), True)
2417@RegisterPFor("Pack")
2418def _convert_pack(pfor_input):
2419 pfor_input.stack_inputs()
2420 axis = pfor_input.get_attr("axis")
2421 if axis >= 0:
2422 axis += 1
2423 return wrap(
2424 array_ops_stack.stack([x.t for x in pfor_input.inputs], axis=axis), True)
2427@RegisterPFor("Unpack")
2428def _convert_unpack(pfor_input):
2429 value = pfor_input.stacked_input(0)
2430 axis = pfor_input.get_attr("axis")
2431 if axis >= 0:
2432 axis += 1
2433 num = pfor_input.get_attr("num")
2434 return [wrap(x, True) for x
2435 in array_ops_stack.unstack(value, axis=axis, num=num)]
2438@RegisterPFor("Pad")
2439def _convert_pad(pfor_input):
2440 t = pfor_input.stacked_input(0)
2441 paddings = pfor_input.unstacked_input(1)
2442 paddings = array_ops.concat([[[0, 0]], paddings], 0)
2443 return wrap(array_ops.pad(t, paddings, mode="CONSTANT"), True)
2446@RegisterPFor("PadV2")
2447def _convert_pad_v2(pfor_input):
2448 t = pfor_input.stacked_input(0)
2449 paddings = pfor_input.unstacked_input(1)
2450 paddings = array_ops.concat([[[0, 0]], paddings], 0)
2451 return wrap(array_ops.pad_v2(t, paddings, mode="CONSTANT"), True)
2454@RegisterPFor("Split")
2455def _convert_split(pfor_input):
2456 split_dim = pfor_input.unstacked_input(0)
2457 t = pfor_input.stacked_input(1)
2458 num_split = pfor_input.get_attr("num_split")
2459 split_dim += math_ops.cast(split_dim >= 0, dtypes.int32)
2460 return [wrap(x, True) for x in array_ops.split(t, num_split, axis=split_dim)]
2463@RegisterPFor("SplitV")
2464def _convert_split_v(pfor_input):
2465 t = pfor_input.stacked_input(0)
2466 splits = pfor_input.unstacked_input(1)
2467 split_dim = pfor_input.unstacked_input(2)
2468 split_dim += math_ops.cast(split_dim >= 0, dtypes.int32)
2469 return [wrap(x, True) for x in array_ops.split(t, splits, axis=split_dim)]
2472@RegisterPFor("Squeeze")
2473def _convert_squeeze(pfor_input):
2474 t = pfor_input.stacked_input(0)
2475 squeeze_dims = pfor_input.get_attr("squeeze_dims")
2476 squeeze_dims = [i + 1 if i >= 0 else i for i in squeeze_dims]
2477 return wrap(array_ops.squeeze(t, axis=squeeze_dims), True)
2480@RegisterPFor("ReverseV2")
2481def _convert_reverse(pfor_input):
2482 value = pfor_input.stacked_input(0)
2483 axis = pfor_input.unstacked_input(1)
2484 new_axis = array_ops.where_v2(axis >= 0, axis + 1, axis)
2485 return wrap(gen_array_ops.reverse_v2(value, axis=new_axis), True)
2488@RegisterPForWithArgs("Transpose", gen_array_ops.transpose)
2489@RegisterPForWithArgs("ConjugateTranspose", gen_array_ops.conjugate_transpose)
2490def _convert_transpose(pfor_input, _, op_func):
2491 t = pfor_input.stacked_input(0)
2492 perm = pfor_input.unstacked_input(1)
2493 new_perm = array_ops.concat([[0], perm + 1], axis=0)
2494 return wrap(op_func(t, new_perm), True)
2497@RegisterPFor("ZerosLike")
2498def _convert_zeroslike(pfor_input):
2499 t = pfor_input.stacked_input(0)
2500 shape = array_ops.shape(t)[1:]
2501 return wrap(array_ops.zeros(shape, dtype=t.dtype), False)
2504@RegisterPFor("Gather")
2505@RegisterPFor("GatherV2")
2506def _convert_gather(pfor_input):
2507 param, param_stacked, _ = pfor_input.input(0)
2508 indices, indices_stacked, _ = pfor_input.input(1)
2509 batch_dims = pfor_input.get_attr("batch_dims")
2511 op_type = pfor_input.op_type
2512 if op_type == "Gather":
2513 validate_indices = pfor_input.get_attr("validate_indices")
2514 axis = 0
2515 else:
2516 validate_indices = None
2517 # Assume we will never have a Tensor with rank > 2**32.
2518 axis = math_ops.cast(pfor_input.unstacked_input(2), dtypes.int32)
2519 axis_value = tensor_util.constant_value(axis)
2520 if axis_value is not None:
2521 axis = axis_value
2522 if indices_stacked and not param_stacked:
2523 if indices is pfor_input.pfor.all_indices and axis == 0:
2524 param_shape0 = tensor_shape.dimension_value(param.shape[0])
2525 indices_shape0 = tensor_shape.dimension_value(indices.shape[0])
2526 if param_shape0 is not None and indices_shape0 == param_shape0:
2527 # Note that with loops and conditionals, indices may not be contiguous.
2528 # However they will be sorted and unique. So if the shape matches, then
2529 # it must be picking up all the rows of param.
2530 return wrap(param, True)
2532 if batch_dims != 0:
2533 # Convert `batch_dims` to its positive equivalent if necessary.
2534 batch_dims_pos = batch_dims
2535 if batch_dims < 0:
2536 batch_dims_pos += array_ops.rank(indices)
2537 # In order to maintain
2538 # indices.shape[:batch_dims] == params.shape[:batch_dims]
2539 # with stacked indices, we move the first dimension of `indices` to the
2540 # `batch_dims + 1`th position. The (non-batch) index dimensions will be
2541 # inserted into the shape of `output` at the `axis` dimension, which is
2542 # then transposed to the front (below).
2543 order = array_ops.concat([
2544 math_ops.range(1, batch_dims_pos + 1),
2545 [0],
2546 math_ops.range(batch_dims_pos + 1, array_ops.rank(indices))], axis=0)
2547 indices = array_ops.transpose(indices, order)
2549 output = array_ops.gather(
2550 param, indices, validate_indices=validate_indices, axis=axis,
2551 batch_dims=batch_dims)
2552 if axis != 0:
2553 axis = smart_cond.smart_cond(axis < 0,
2554 lambda: axis + array_ops.rank(param),
2555 lambda: ops.convert_to_tensor(axis))
2556 order = array_ops.concat(
2557 [[axis],
2558 math_ops.range(axis),
2559 math_ops.range(axis + 1, array_ops.rank(output))],
2560 axis=0)
2561 output = smart_cond.smart_cond(
2562 math_ops.equal(axis, 0), lambda: output,
2563 lambda: array_ops.transpose(output, order))
2564 return wrap(output, True)
2565 if param_stacked:
2566 pfor_input.stack_inputs(stack_indices=[1])
2567 indices = pfor_input.stacked_input(1)
2568 if isinstance(axis, ops.Tensor):
2569 axis = array_ops.where(axis >= 0, axis + 1, axis)
2570 else:
2571 axis = axis + 1 if axis >= 0 else axis
2572 batch_dims = batch_dims + 1 if batch_dims >= 0 else batch_dims
2573 output = array_ops.gather(param, indices, axis=axis, batch_dims=batch_dims)
2574 return wrap(output, True)
2577@RegisterPFor("GatherNd")
2578def _convert_gather_nd(pfor_input):
2579 # TODO(jmenick): Add support for unstacked params.
2580 pfor_input.stack_inputs(stack_indices=[1])
2581 params = pfor_input.stacked_input(0)
2582 indices = pfor_input.stacked_input(1)
2583 stacked_result = array_ops.gather_nd(params, indices, batch_dims=1)
2584 return wrap(stacked_result, True)
2587@RegisterPFor("ConcatV2")
2588def _convert_concatv2(pfor_input):
2589 n = pfor_input.num_inputs
2590 pfor_input.stack_inputs(stack_indices=range(n - 1))
2591 axis = pfor_input.unstacked_input(n - 1)
2592 axis += math_ops.cast(axis >= 0, axis.dtype)
2593 return wrap(
2594 array_ops.concat([x.t for x in pfor_input.inputs[:n - 1]], axis=axis),
2595 True)
2598@RegisterPFor("StridedSlice")
2599def _convert_strided_slice(pfor_input):
2600 inp = pfor_input.stacked_input(0)
2601 begin = pfor_input.unstacked_input(1)
2602 end = pfor_input.unstacked_input(2)
2603 strides = pfor_input.unstacked_input(3)
2604 begin_mask = pfor_input.get_attr("begin_mask")
2605 end_mask = pfor_input.get_attr("end_mask")
2606 ellipsis_mask = pfor_input.get_attr("ellipsis_mask")
2607 new_axis_mask = pfor_input.get_attr("new_axis_mask")
2608 shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask")
2610 begin = array_ops.concat([[0], begin], axis=0)
2611 end = array_ops.concat([[0], end], axis=0)
2612 strides = array_ops.concat([[1], strides], axis=0)
2613 begin_mask = begin_mask << 1 | 1
2614 end_mask = end_mask << 1 | 1
2615 ellipsis_mask <<= 1
2616 new_axis_mask <<= 1
2617 shrink_axis_mask <<= 1
2618 return wrap(
2619 array_ops.strided_slice(
2620 inp,
2621 begin,
2622 end,
2623 strides,
2624 begin_mask=begin_mask,
2625 end_mask=end_mask,
2626 ellipsis_mask=ellipsis_mask,
2627 new_axis_mask=new_axis_mask,
2628 shrink_axis_mask=shrink_axis_mask), True)
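# --- Illustrative sketch (not part of this module) ---------------------------
# Effect of prepending the loop dimension in the StridedSlice conversion above:
# begin/end/strides gain a leading 0/0/1 and every mask is shifted left by one,
# with bit 0 of begin_mask/end_mask set so the new leading dimension is taken
# in full. Public tf API, concrete shapes:
import tensorflow as tf

x = tf.reshape(tf.range(24), [4, 2, 3])       # [N, 2, 3]; per-iteration slice y[0:1, 0:3:2]
out = tf.strided_slice(x, begin=[0, 0, 0], end=[0, 1, 3], strides=[1, 1, 2],
                       begin_mask=0b001, end_mask=0b001)
print(out.shape)                              # (4, 1, 2): the slice, for all N iterations
# ------------------------------------------------------------------------------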
2631@RegisterPFor("StridedSliceGrad")
2632def _convert_strided_slice_grad(pfor_input):
2633 shape = pfor_input.unstacked_input(0)
2634 begin = pfor_input.unstacked_input(1)
2635 end = pfor_input.unstacked_input(2)
2636 strides = pfor_input.unstacked_input(3)
2637 dy = pfor_input.stacked_input(4)
2638 begin_mask = pfor_input.get_attr("begin_mask")
2639 end_mask = pfor_input.get_attr("end_mask")
2640 ellipsis_mask = pfor_input.get_attr("ellipsis_mask")
2641 new_axis_mask = pfor_input.get_attr("new_axis_mask")
2642 shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask")
2644 shape = array_ops.concat(
2645 [math_ops.cast(pfor_input.pfor.loop_len_vector, shape.dtype), shape],
2646 axis=0)
2647 begin = array_ops.concat([[0], begin], axis=0)
2648 end = array_ops.concat([[0], end], axis=0)
2649 strides = array_ops.concat([[1], strides], axis=0)
2650 begin_mask = begin_mask << 1 | 1
2651 end_mask = end_mask << 1 | 1
2652 ellipsis_mask <<= 1
2653 new_axis_mask <<= 1
2654 shrink_axis_mask <<= 1
2655 return wrap(
2656 array_ops.strided_slice_grad(
2657 shape,
2658 begin,
2659 end,
2660 strides,
2661 dy,
2662 begin_mask=begin_mask,
2663 end_mask=end_mask,
2664 ellipsis_mask=ellipsis_mask,
2665 new_axis_mask=new_axis_mask,
2666 shrink_axis_mask=shrink_axis_mask), True)
2669@RegisterPFor("CheckNumerics")
2670def _convert_check_numerics(pfor_input):
2671 t = pfor_input.stacked_input(0)
2672 message = pfor_input.get_attr("message")
2673 return wrap(gen_array_ops.check_numerics(t, message), True)
2676@RegisterPFor("EnsureShape")
2677def _convert_ensure_shape(pfor_input):
2678 t = pfor_input.stacked_input(0)
2679 shape = tensor_shape.TensorShape(pfor_input.get_attr("shape"))
2680 return wrap(gen_array_ops.ensure_shape(t, [None] + shape), True)
2683# manip_ops
2686@RegisterPFor("Roll")
2687def _convert_roll(pfor_input):
2688 t = pfor_input.stacked_input(0)
2689 shift, shift_stacked, _ = pfor_input.input(1)
2690 axis = pfor_input.unstacked_input(2)
2691 if not shift_stacked:
2692 return wrap(manip_ops.roll(t, shift, axis + 1), True)
2693 else:
2694 # `axis` and `shift` may both be vectors, with repeated axes summing the
2695 # corresponding `shift`s. We scatter shifts into a dense array of shape
2696 # [loop_len, num_unstacked_axes] indicating the offset for each axis.
2697 num_unstacked_axes = math_ops.cast(array_ops.rank(t), dtypes.int64) - 1
2698 axis = math_ops.cast(array_ops.reshape(axis, [-1]), dtypes.int64)
2699 loop_len = math_ops.cast(pfor_input.pfor.loop_len_vector[0], dtypes.int64)
2700 shift = math_ops.cast(array_ops.reshape(shift, [loop_len, -1]),
2701 dtypes.int64)
2702 axis_segment_ids = (
2703 math_ops.range(loop_len, dtype=dtypes.int64)[:, None]
2704 * num_unstacked_axes + axis[None, :])
2705 axis_offsets = array_ops.reshape(
2706 math_ops.unsorted_segment_sum(
2707 data=shift, segment_ids=axis_segment_ids,
2708 num_segments=loop_len * num_unstacked_axes),
2709 [loop_len, num_unstacked_axes])
2711 # Determine the coordinates in the input array of each result and gather
2712 # them.
2713 unstacked_shape = array_ops.shape(t, out_type=dtypes.int64)[1:]
2714 cumsize = math_ops.cumprod(unstacked_shape, exclusive=True, reverse=True)
2715 num_unstacked_elements = math_ops.reduce_prod(unstacked_shape)
2716 result_coordinates = (
2717 (math_ops.range(num_unstacked_elements,
2718 dtype=dtypes.int64)[None, :, None]
2719 // cumsize[None, None, :] - axis_offsets[:, None, :])
2720 % unstacked_shape[None, None, :])
2721 result_flat = array_ops.gather_nd(params=t, indices=result_coordinates,
2722 batch_dims=1)
2723 return wrap(array_ops.reshape(result_flat, array_ops.shape(t)),
2724 True)
2726# math_ops
2729@RegisterPFor("MatMul")
2730def _convert_matmul(pfor_input):
2731 # TODO(agarwal): Check if tiling is faster than two transposes.
2732 a, a_stacked, _ = pfor_input.input(0)
2733 b, b_stacked, _ = pfor_input.input(1)
2734 tr_a = pfor_input.get_attr("transpose_a")
2735 tr_b = pfor_input.get_attr("transpose_b")
2736 if a_stacked and b_stacked:
2737 output = wrap(math_ops.matmul(a, b, adjoint_a=tr_a, adjoint_b=tr_b), True)
2738 return output
2739 elif a_stacked:
2740 if tr_a:
2741 a = array_ops.transpose(a, [0, 2, 1])
2742 if a.shape.is_fully_defined():
2743 x, y, z = a.shape
2744 else:
2745 x, y, z = [
2746 array_ops.reshape(i, [])
2747 for i in array_ops.split(array_ops.shape(a), 3)
2748 ]
2749 a = array_ops.reshape(a, [x * y, z])
2750 prod = math_ops.matmul(a, b, transpose_b=tr_b)
2751 return wrap(array_ops.reshape(prod, [x, y, -1]), True)
2752 else:
2753 assert b_stacked
2754 if tr_b:
2755 perm = [2, 0, 1]
2756 b = array_ops.transpose(b, perm)
2757 else:
2758 # As an optimization, if one of the first two dimensions is 1, then we can
2759 # reshape instead of transpose.
2760 # TODO(agarwal): This check can be done inside Transpose kernel.
2761 b_shape = array_ops.shape(b)
2762 min_dim = math_ops.minimum(b_shape[0], b_shape[1])
2763 perm = array_ops.where(
2764 math_ops.equal(min_dim, 1), [0, 1, 2], [1, 0, 2])
2765 new_shape = array_ops_stack.stack([b_shape[1], b_shape[0], b_shape[2]])
2766 b = array_ops.transpose(b, perm)
2767 b = array_ops.reshape(b, new_shape)
2769 if b.shape.is_fully_defined():
2770 x, y, z = b.shape
2771 else:
2772 x, y, z = [
2773 array_ops.reshape(i, [])
2774 for i in array_ops.split(array_ops.shape(b), 3)
2775 ]
2776 b = array_ops.reshape(b, [x, y * z])
2777 prod = math_ops.matmul(a, b, transpose_a=tr_a)
2778 prod = array_ops.reshape(prod, [-1, y, z])
2779 prod = array_ops.transpose(prod, [1, 0, 2])
2780 return wrap(prod, True)
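# --- Illustrative sketch (not part of this module) ---------------------------
# The a_stacked branch above in isolation: a stacked [N, x, y] operand times a
# loop-invariant [y, z] operand can be computed with a single 2-D matmul by
# merging N and x, then splitting them back out. Public tf API:
import tensorflow as tf

a = tf.random.normal([4, 2, 3])               # stacked: one [2, 3] matrix per iteration
b = tf.random.normal([3, 5])                  # loop invariant
prod = tf.reshape(tf.matmul(tf.reshape(a, [4 * 2, 3]), b), [4, 2, 5])
tf.debugging.assert_near(prod, tf.matmul(a, b))   # matches broadcasting batch matmul
# ------------------------------------------------------------------------------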
2783# TODO(rmlarsen): Use the converter of BatchMatMulV2 once compatibility window
2784# is met.
2785@RegisterPFor("BatchMatMul")
2786def _convert_batch_mat_mul(pfor_input):
2787 # TODO(agarwal): There may be a more efficient way to do this instead of
2788 # stacking the inputs.
2789 pfor_input.stack_inputs()
2790 x = pfor_input.stacked_input(0)
2791 y = pfor_input.stacked_input(1)
2792 adj_x = pfor_input.get_attr("adj_x")
2793 adj_y = pfor_input.get_attr("adj_y")
2795 x = _flatten_first_two_dims(x)
2796 y = _flatten_first_two_dims(y)
2797 output = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
2798 output = _unflatten_first_dim(output, pfor_input.pfor.loop_len_vector)
2799 return wrap(output, True)
2802@RegisterPFor("BatchMatMulV2")
2803def _convert_batch_mat_mul_v2(pfor_input):
2804 pfor_input.expanddim_inputs_for_broadcast()
2805 x = pfor_input.input(0)[0]
2806 y = pfor_input.input(1)[0]
2807 adj_x = pfor_input.get_attr("adj_x")
2808 adj_y = pfor_input.get_attr("adj_y")
2810 output = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
2811 return wrap(output, True)
2814@RegisterPForWithArgs("Sum", math_ops.reduce_sum)
2815@RegisterPForWithArgs("Prod", math_ops.reduce_prod)
2816@RegisterPForWithArgs("Max", math_ops.reduce_max)
2817@RegisterPForWithArgs("Min", math_ops.reduce_min)
2818@RegisterPForWithArgs("Mean", math_ops.reduce_mean)
2819@RegisterPForWithArgs("All", math_ops.reduce_all)
2820@RegisterPForWithArgs("Any", math_ops.reduce_any)
2821def _convert_reduction(pfor_input, _, op_func):
2822 t = pfor_input.stacked_input(0)
2823 indices = pfor_input.unstacked_input(1)
2824 # Shift positive indices by one to account for the extra dimension.
2825 indices += math_ops.cast(indices >= 0, indices.dtype)
2826 keep_dims = pfor_input.get_attr("keep_dims")
2827 return wrap(op_func(t, indices, keepdims=keep_dims), True)
2830@RegisterPForWithArgs("ArgMax", math_ops.argmax)
2831@RegisterPForWithArgs("ArgMin", math_ops.argmin)
2832def _convert_argmax_argmin(pfor_input, _, op_func):
2833 t = pfor_input.stacked_input(0)
2834 dimension = pfor_input.unstacked_input(1)
2835 dimension += math_ops.cast(dimension >= 0, dimension.dtype)
2836 output_type = pfor_input.get_attr("output_type")
2837 return wrap(op_func(t, axis=dimension, output_type=output_type), True)
2840@RegisterPFor("Bucketize")
2841def _convert_bucketize(pfor_input):
2842 t = pfor_input.stacked_input(0)
2843 boundaries = pfor_input.get_attr("boundaries")
2844 return wrap(math_ops.bucketize(t, boundaries), True)
2847@RegisterPFor("ClipByValue")
2848def _convert_clip_by_value(pfor_input):
2849 t = pfor_input.stacked_input(0)
2850 clip_value_min = pfor_input.unstacked_input(1)
2851 clip_value_max = pfor_input.unstacked_input(2)
2852 return wrap(gen_math_ops._clip_by_value(t, clip_value_min, clip_value_max),
2853 True)
2856@RegisterPForWithArgs("Cumsum", math_ops.cumsum)
2857@RegisterPForWithArgs("Cumprod", math_ops.cumprod)
2858def _convert_cumfoo(pfor_input, _, op_func):
2859 t = pfor_input.stacked_input(0)
2860 axis = pfor_input.unstacked_input(1)
2861 # Shift positive indices by one to account for the extra dimension.
2862 axis += math_ops.cast(axis >= 0, axis.dtype)
2863 exclusive = pfor_input.get_attr("exclusive")
2864 reverse = pfor_input.get_attr("reverse")
2865 return wrap(op_func(t, axis, exclusive=exclusive, reverse=reverse), True)
2868@RegisterPFor("BiasAdd")
2869def _convert_biasadd(pfor_input):
2870 t, t_stacked, _ = pfor_input.input(0)
2871 bias, bias_stacked, _ = pfor_input.input(1)
2872 data_format = pfor_input.get_attr("data_format").decode()
2873 if bias_stacked:
2874 # BiasAdd only supports 1-D biases, so cast bias to match value and use Add.
2875 pfor_input.expanddim_inputs_for_broadcast()
2876 t, _, _ = pfor_input.input(0)
2877 bias = math_ops.cast(pfor_input.stacked_input(1), t.dtype)
2878 if compat.as_bytes(data_format) == b"NCHW":
2879 b_shape = array_ops.shape(bias)
2880 new_b_shape = array_ops.concat(
2881 [b_shape[:-3], b_shape[-1:], b_shape[-3:-1]], axis=0)
2882 bias = array_ops.reshape(bias, new_b_shape)
2883 return wrap(math_ops.add(t, bias), True)
2884 else:
2885 assert t_stacked, "At least one input to BiasAdd should be loop variant."
2886 if compat.as_bytes(data_format) == b"NCHW":
2887 shape = array_ops.shape(t)
2888 flattened_shape = array_ops.concat([[-1], shape[2:]], axis=0)
2889 t = array_ops.reshape(t, flattened_shape)
2890 t = nn_ops.bias_add(t, bias, data_format="NCHW")
2891 t = array_ops.reshape(t, shape)
2892 return wrap(t, True)
2893 return wrap(nn_ops.bias_add(t, bias, data_format=data_format), True)
2896@RegisterPForWithArgs("UnsortedSegmentSum", math_ops.unsorted_segment_sum)
2897@RegisterPForWithArgs("UnsortedSegmentMax", math_ops.unsorted_segment_max)
2898@RegisterPForWithArgs("UnsortedSegmentMin", math_ops.unsorted_segment_min)
2899@RegisterPForWithArgs("UnsortedSegmentProd", math_ops.unsorted_segment_prod)
2900def _convert_unsortedsegmentsum(pfor_input, _, op_func):
2901 pfor_input.stack_inputs([0, 1])
2902 data = pfor_input.stacked_input(0)
2903 segment_ids = pfor_input.stacked_input(1)
2904 # TODO(agarwal): handle stacked?
2905 num_segments = pfor_input.unstacked_input(2)
2906 if segment_ids.dtype != num_segments.dtype:
2907 segment_ids = math_ops.cast(segment_ids, dtypes.int64)
2908 num_segments = math_ops.cast(num_segments, dtypes.int64)
2909 dtype = segment_ids.dtype
2910 segment_shape = array_ops.shape(segment_ids, out_type=dtype)
2911 n = segment_shape[0]
2912 ones = array_ops.ones_like(segment_shape, dtype=dtype)[1:]
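  # Offset the segment ids of each pfor iteration by (iteration index *
  # num_segments) so that segments from different iterations never collide.
  # A single unsorted segment op can then process all iterations at once, and
  # the flat result is reshaped back to [n, num_segments, ...] below.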
2913 segment_offset = num_segments * math_ops.range(n, dtype=dtype)
2914 segment_offset = array_ops.reshape(segment_offset,
2915 array_ops.concat([[n], ones], axis=0))
2916 segment_ids += segment_offset
2917 num_segments = math_ops.cast(num_segments, dtypes.int64) * math_ops.cast(
2918 n, dtypes.int64)
2919 output = op_func(data, segment_ids, num_segments)
2920 new_output_shape = array_ops.concat(
2921 [[n, -1], array_ops.shape(output)[1:]], axis=0)
2922 output = array_ops.reshape(output, new_output_shape)
2923 return wrap(output, True)
2926def _flatten_array_with_offset(ids, offset_delta, num_rows):
2927 """Flattens a rank 2 tensor, adding an offset to each row."""
2928 # Note that if `ids` is rank 1, it is broadcast to rank 2.
2929 offset_delta = math_ops.cast(offset_delta, ids.dtype)
2930 n = math_ops.cast(num_rows, dtype=ids.dtype)
2931 offsets = math_ops.range(
2932 start=0, limit=n * offset_delta, delta=offset_delta, dtype=ids.dtype)
2933 offsets = array_ops.expand_dims(offsets, -1)
2934 ids += offsets
2935 return array_ops.reshape(ids, [-1])
2938@RegisterPForWithArgs("SparseSegmentSum", math_ops.sparse_segment_sum_v2)
2939@RegisterPForWithArgs("SparseSegmentMean", math_ops.sparse_segment_mean_v2)
2940@RegisterPForWithArgs("SparseSegmentSqrtN", math_ops.sparse_segment_sqrt_n_v2)
2941@RegisterPForWithArgs("SparseSegmentSumWithNumSegments",
2942 math_ops.sparse_segment_sum_v2)
2943@RegisterPForWithArgs("SparseSegmentMeanWithNumSegments",
2944 math_ops.sparse_segment_mean_v2)
2945@RegisterPForWithArgs("SparseSegmentSqrtNWithNumSegments",
2946 math_ops.sparse_segment_sqrt_n_v2)
2947def _convert_sparse_segment(pfor_input, _, op_func):
2948 _, segment_ids_stacked, _ = pfor_input.input(2)
2949 if segment_ids_stacked:
2950 pfor_input.stack_inputs([1])
2951 data, data_stacked, _ = pfor_input.input(0)
2952 indices, _, _ = pfor_input.input(1)
2953 num_inputs = len(pfor_input.inputs)
2954 assert num_inputs in (3, 4)
2955 if num_inputs == 3:
2956 # `segment_ids` needs to be unstacked since otherwise output sizes could
2957 # differ across pfor iterations.
2958 segment_ids = pfor_input.unstacked_input(2)
2959 num_segments = nn_ops.relu(math_ops.reduce_max(segment_ids) + 1)
2960 else:
2961 segment_ids, _, _ = pfor_input.input(2)
2962 num_segments = pfor_input.unstacked_input(3)
2964 n = pfor_input.pfor.loop_len_vector[0]
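  # When `data` is stacked, its first two dims are flattened and `indices` are
  # offset by the per-iteration row count so that they address the right rows
  # of the flattened data. Similarly, `segment_ids` are offset by
  # `num_segments` per iteration so that segments from different iterations do
  # not collide; the output is unflattened back to [n, ...] at the end.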
2965 if data_stacked:
2966 indices = _flatten_array_with_offset(indices, array_ops.shape(data)[1], n)
2967 data = _flatten_first_two_dims(data)
2968 else:
2969 indices = array_ops.reshape(indices, [-1])
2970 segment_ids = _flatten_array_with_offset(segment_ids, num_segments, n)
2972 if num_inputs == 3:
2973 num_segments = None
2974 else:
2975 num_segments *= n
2976 output = op_func(data, indices, segment_ids, num_segments=num_segments)
2977 output = _unflatten_first_dim(output, [n])
2978 return wrap(output, True)
2981@RegisterPForWithArgs("SparseSegmentSumGrad", math_ops.sparse_segment_sum_grad)
2982@RegisterPForWithArgs("SparseSegmentMeanGrad",
2983 math_ops.sparse_segment_mean_grad)
2984@RegisterPForWithArgs("SparseSegmentSqrtNGrad",
2985 math_ops.sparse_segment_sqrt_n_grad)
2986def _convert_sparse_segment_grad(pfor_input, _, op_func):
2987 grad = pfor_input.stacked_input(0)
2988 indices = pfor_input.unstacked_input(1)
2989 segment_ids = pfor_input.unstacked_input(2)
2990 dim0 = pfor_input.unstacked_input(3)
2992 n = pfor_input.pfor.loop_len_vector[0]
2993 indices = _flatten_array_with_offset(indices, dim0, n)
2994 num_segments = nn_ops.relu(math_ops.reduce_max(segment_ids) + 1)
2995 segment_ids = _flatten_array_with_offset(segment_ids, num_segments, n)
2996 grad = _flatten_first_two_dims(grad)
2997 dim0 *= n
2998 output = op_func(grad, indices, segment_ids, dim0)
2999 output = _unflatten_first_dim(output, [n])
3000 return wrap(output, True)
3003@RegisterPFor("Cast")
3004def _convert_cast(pfor_input):
3005 inp = pfor_input.stacked_input(0)
3006 dtype = pfor_input.get_attr("DstT")
3007 return wrap(math_ops.cast(inp, dtype), True)
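# The ops registered below are all elementwise (component-wise), so applying
# them directly to the stacked inputs (after aligning shapes for broadcasting)
# yields the vectorized result.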
3010@RegisterPFor("Abs")
3011@RegisterPFor("Acos")
3012@RegisterPFor("Acosh")
3013@RegisterPFor("Add")
3014@RegisterPFor("AddV2")
3015@RegisterPFor("Angle")
3016@RegisterPFor("Asin")
3017@RegisterPFor("Asinh")
3018@RegisterPFor("Atan")
3019@RegisterPFor("Atan2")
3020@RegisterPFor("Atanh")
3021@RegisterPFor("BesselI0")
3022@RegisterPFor("BesselI1")
3023@RegisterPFor("BesselI0e")
3024@RegisterPFor("BesselI1e")
3025@RegisterPFor("BesselK0")
3026@RegisterPFor("BesselK1")
3027@RegisterPFor("BesselK0e")
3028@RegisterPFor("BesselK1e")
3029@RegisterPFor("BesselJ0")
3030@RegisterPFor("BesselJ1")
3031@RegisterPFor("BesselY0")
3032@RegisterPFor("BesselY1")
3033@RegisterPFor("BitwiseAnd")
3034@RegisterPFor("BitwiseOr")
3035@RegisterPFor("BitwiseXor")
3036@RegisterPFor("Ceil")
3037@RegisterPFor("Complex")
3038@RegisterPFor("ComplexAbs")
3039@RegisterPFor("Conj")
3040@RegisterPFor("Cos")
3041@RegisterPFor("Cosh")
3042@RegisterPFor("Dawsn")
3043@RegisterPFor("Digamma")
3044@RegisterPFor("Div")
3045@RegisterPFor("DivNoNan")
3046@RegisterPFor("Elu")
3047@RegisterPFor("Erf")
3048@RegisterPFor("Erfc")
3049@RegisterPFor("Erfinv")
3050@RegisterPFor("Exp")
3051@RegisterPFor("Expint")
3052@RegisterPFor("Expm1")
3053@RegisterPFor("Floor")
3054@RegisterPFor("FloorDiv")
3055@RegisterPFor("FloorMod")
3056@RegisterPFor("FresnelCos")
3057@RegisterPFor("FresnelSin")
3058@RegisterPFor("Greater")
3059@RegisterPFor("GreaterEqual")
3060@RegisterPFor("Igamma")
3061@RegisterPFor("IgammaGradA")
3062@RegisterPFor("Igammac")
3063@RegisterPFor("Imag")
3064@RegisterPFor("Inv")
3065@RegisterPFor("Invert")
3066@RegisterPFor("IsFinite")
3067@RegisterPFor("IsInf")
3068@RegisterPFor("IsNan")
3069@RegisterPFor("LeftShift")
3070@RegisterPFor("Less")
3071@RegisterPFor("LessEqual")
3072@RegisterPFor("Lgamma")
3073@RegisterPFor("Log")
3074@RegisterPFor("Log1p")
3075@RegisterPFor("LogicalAnd")
3076@RegisterPFor("LogicalNot")
3077@RegisterPFor("LogicalOr")
3078@RegisterPFor("LogicalXor")
3079@RegisterPFor("Maximum")
3080@RegisterPFor("Minimum")
3081@RegisterPFor("Mod")
3082@RegisterPFor("Mul")
3083@RegisterPFor("MulNoNan")
3084@RegisterPFor("Ndtri")
3085@RegisterPFor("Neg")
3086@RegisterPFor("Polygamma")
3087@RegisterPFor("Pow")
3088@RegisterPFor("Real")
3089@RegisterPFor("RealDiv")
3090@RegisterPFor("Reciprocal")
3091@RegisterPFor("Relu")
3092@RegisterPFor("Relu6")
3093@RegisterPFor("RightShift")
3094@RegisterPFor("Rint")
3095@RegisterPFor("Round")
3096@RegisterPFor("Rsqrt")
3097@RegisterPFor("Selu")
3098@RegisterPFor("Sigmoid")
3099@RegisterPFor("Sign")
3100@RegisterPFor("Sin")
3101@RegisterPFor("Sinh")
3102@RegisterPFor("Softplus")
3103@RegisterPFor("Softsign")
3104@RegisterPFor("Spence")
3105@RegisterPFor("Sqrt")
3106@RegisterPFor("Square")
3107@RegisterPFor("SquaredDifference")
3108@RegisterPFor("Sub")
3109@RegisterPFor("Tan")
3110@RegisterPFor("Tanh")
3111@RegisterPFor("TruncateDiv")
3112@RegisterPFor("TruncateMod")
3113@RegisterPFor("Xdivy")
3114@RegisterPFor("Xlogy")
3115@RegisterPFor("Xlog1py")
3116@RegisterPFor("Zeta")
3117def _convert_cwise(pfor_input):
3118 if pfor_input.num_inputs > 1:
3119 pfor_input.expanddim_inputs_for_broadcast()
3121 out = _create_op(
3122 pfor_input.op_type, [x.t for x in pfor_input.inputs],
3123 [x.dtype for x in pfor_input.outputs],
3124 attrs=pfor_input.op.node_def.attr).outputs
3125 assert len(out) == 1
3126 out = out[0]
3128 op_output = wrap(out, True)
3129 return op_output
3132@RegisterPFor("XlaSharding")
3133def _convert_xla_sharding(pfor_input):
3134 t = pfor_input.stacked_input(0)
3135 sharding = pfor_input.get_attr("sharding")
3136 return wrap(xla.sharding(t, sharding=sharding), True)
3139@RegisterPFor("LeakyRelu")
3140def _convert_leaky_relu(pfor_input):
3141 t = pfor_input.stacked_input(0)
3142 alpha = pfor_input.get_attr("alpha")
3143 return wrap(gen_nn_ops.leaky_relu(t, alpha=alpha), True)
3146@RegisterPFor("Equal")
3147def _convert_equal(pfor_input):
3148 pfor_input.expanddim_inputs_for_broadcast()
3149 x = pfor_input.input(0)[0]
3150 y = pfor_input.input(1)[0]
3151 incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error")
3152 return wrap(gen_math_ops.equal(
3153 x, y, incompatible_shape_error=incompatible_shape_error), True)
3156@RegisterPFor("NotEqual")
3157def _convert_not_equal(pfor_input):
3158 pfor_input.expanddim_inputs_for_broadcast()
3159 x = pfor_input.input(0)[0]
3160 y = pfor_input.input(1)[0]
3161 incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error")
3162 return wrap(gen_math_ops.not_equal(
3163 x, y, incompatible_shape_error=incompatible_shape_error), True)
3166@RegisterPFor("ApproximateEqual")
3167def _convert_approximate_equal(pfor_input):
3168 pfor_input.expanddim_inputs_for_broadcast()
3169 x = pfor_input.input(0)[0]
3170 y = pfor_input.input(1)[0]
3171 tolerance = pfor_input.get_attr("tolerance")
3172 return wrap(math_ops.approximate_equal(x, y, tolerance=tolerance), True)
3175@RegisterPFor("Shape")
3176def _convert_shape(pfor_input):
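  # The leading loop dimension is stripped from the shape so that the result
  # matches the per-iteration shape; the output is therefore loop invariant.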
3177 out_type = pfor_input.get_attr("out_type")
3178 return wrap(
3179 array_ops.shape(pfor_input.stacked_input(0), out_type=out_type)[1:],
3180 False)
3183@RegisterPFor("ShapeN")
3184def _convert_shape_n(pfor_input):
3185 out_type = pfor_input.get_attr("out_type")
3186 shapes = [
3187 array_ops.shape(x, out_type=out_type)[1:] if stacked else array_ops.shape(
3188 x, out_type=out_type) for x, stacked, _ in pfor_input.inputs
3189 ]
3190 return [wrap(x, False) for x in shapes]
3193@RegisterPFor("Size")
3194def _convert_size(pfor_input):
3195 out_type = pfor_input.get_attr("out_type")
3196 n = math_ops.cast(pfor_input.pfor.loop_len_vector[0], out_type)
3197 return wrap(
3198 array_ops.size(pfor_input.stacked_input(0), out_type=out_type) // n,
3199 False)
3202@RegisterPFor("Rank")
3203def _convert_rank(pfor_input):
3204 return wrap(array_ops.rank(pfor_input.stacked_input(0)) - 1, False)
3207@RegisterPFor("AddN")
3208def _convert_addn(pfor_input):
3209 # AddN does not support broadcasting.
3210 pfor_input.stack_inputs(tile_variants=False)
3211 return _wrap_and_tile_variants(
3212 math_ops.add_n([x.t for x in pfor_input.inputs]),
3213 pfor_input.pfor.loop_len_vector)
3216@RegisterPFor("Cross")
3217def _convert_cross(pfor_input):
3218 pfor_input.stack_inputs()
3219 a = pfor_input.stacked_input(0)
3220 b = pfor_input.stacked_input(1)
3221 return wrap(math_ops.cross(a, b), True)
3224@RegisterPFor("BiasAddGrad")
3225def _convert_biasaddgrad(pfor_input):
3226 grad = pfor_input.stacked_input(0)
3227 fmt = pfor_input.get_attr("data_format")
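  # With the loop dimension stacked in front, an NCHW gradient has shape
  # [n, N, C, H, W], so the bias gradient reduces over axes [1, 3, 4]. For
  # other formats, the gradient is flattened to [n, -1, C] and reduced over
  # the middle axis.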
3228 if fmt == b"NCHW":
3229 output = math_ops.reduce_sum(grad, axis=[1, 3, 4], keepdims=False)
3230 else:
3231 grad_shape = array_ops.shape(grad)
3232 last_dim_shape = grad_shape[-1]
3233 first_dim_shape = grad_shape[0]
3234 output = array_ops.reshape(grad, [first_dim_shape, -1, last_dim_shape])
3235 output = math_ops.reduce_sum(output, axis=[1], keepdims=False)
3236 return wrap(output, True)
3239# Some required ops are not exposed under the tf namespace. Hence we rely on
3240# _create_op to create them.
3241@RegisterPForWithArgs("EluGrad")
3242@RegisterPForWithArgs("LeakyReluGrad")
3243@RegisterPForWithArgs("ReciprocalGrad")
3244@RegisterPForWithArgs("Relu6Grad")
3245@RegisterPForWithArgs("ReluGrad")
3246@RegisterPForWithArgs("RsqrtGrad")
3247@RegisterPForWithArgs("SeluGrad")
3248@RegisterPForWithArgs("SigmoidGrad")
3249@RegisterPForWithArgs("SoftplusGrad")
3250@RegisterPForWithArgs("SoftsignGrad")
3251@RegisterPForWithArgs("SqrtGrad")
3252@RegisterPForWithArgs("TanhGrad")
3253def _convert_grads(pfor_input, op_type, *args, **kw_args):
3254 del args
3255 del kw_args
3256 # TODO(agarwal): Looks like these ops don't support broadcasting. Hence we
3257 # have to use tiling here.
3258 pfor_input.stack_inputs()
3259 outputs = _create_op(
3260 op_type, [x.t for x in pfor_input.inputs],
3261 [x.dtype for x in pfor_input.outputs],
3262 attrs=pfor_input.op.node_def.attr).outputs
3263 return [wrap(x, True) for x in outputs]
3266@RegisterPFor("Select")
3267def _convert_select(pfor_input):
3268 pfor_input.stack_inputs()
3269 cond = pfor_input.stacked_input(0)
3270 t = pfor_input.stacked_input(1)
3271 e = pfor_input.stacked_input(2)
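  # Select broadcasts a vector `cond` along the first axis only. If the
  # stacked `cond` has rank > 1 (i.e. the per-iteration `cond` is not a
  # scalar), the loop and first data dimensions of all three inputs are
  # flattened together so that Select's broadcasting rules still line up; the
  # result is unflattened below.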
3272 cond_rank = array_ops.rank(cond)
3273 cond, t, e = smart_cond.smart_cond(
3274 cond_rank > 1, lambda: _inputs_with_flattening(pfor_input, [0, 1, 2]),
3275 lambda: [cond, t, e])
3276 outputs = _create_op(
3277 pfor_input.op_type, [cond, t, e], [x.dtype for x in pfor_input.outputs],
3278 attrs=pfor_input.op.node_def.attr).outputs
3279 n = pfor_input.pfor.loop_len_vector
3280 out = smart_cond.smart_cond(cond_rank > 1,
3281 lambda: _unflatten_first_dim(outputs[0], n),
3282 lambda: outputs[0])
3283 return [wrap(out, True) for _ in outputs]
3286@RegisterPFor("SelectV2")
3287def _convert_selectv2(pfor_input):
3288 pfor_input.expanddim_inputs_for_broadcast()
3289 cond = pfor_input.input(0)[0]
3290 t = pfor_input.input(1)[0]
3291 e = pfor_input.input(2)[0]
3292 out = array_ops.where_v2(cond, t, e)
3293 return wrap(out, True)
3296# random_ops
3299def _transpose_dim_to_front(x, dim):
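  """Transposes `x` so that dimension `dim` becomes the leading dimension."""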
3300 rank = array_ops.rank(x)
3301 return array_ops.transpose(
3302 x,
3303 perm=array_ops.concat(
3304 [[dim], math_ops.range(0, dim),
3305 math_ops.range(dim + 1, rank)],
3306 axis=0))
3309@RegisterPForWithArgs("RandomUniform")
3310@RegisterPForWithArgs("RandomUniformInt")
3311@RegisterPForWithArgs("RandomStandardNormal")
3312@RegisterPForWithArgs("TruncatedNormal")
3313def _convert_random(pfor_input, op_type, *args, **kw_args):
3314 del args
3315 del kw_args
3316 inputs = [pfor_input.unstacked_input(i) for i in range(pfor_input.num_inputs)]
3317 # inputs[0] is "shape"
3318 inputs[0] = array_ops.concat([pfor_input.pfor.loop_len_vector, inputs[0]],
3319 axis=0)
3320 # TODO(b/222761732): Turn this warning back on when legacy RNGs are
3321 # deprecated.
3322 # logging.warning(
3323 # "Note that %s inside pfor op may not give same output as "
3324 # "inside a sequential loop.", op_type)
3325 outputs = _create_op(
3326 op_type,
3327 inputs, [x.dtype for x in pfor_input.outputs],
3328 attrs=pfor_input.op.node_def.attr).outputs
3329 return [wrap(x, True) for x in outputs]
3332@RegisterPFor("RandomGamma")
3333@RegisterPFor("RandomPoissonV2")
3334def _convert_random_with_param(pfor_input):
3335 shape = pfor_input.unstacked_input(0)
3336 # param is lam (Poisson rate) or alpha (Gamma shape).
3337 param, param_stacked, _ = pfor_input.input(1)
3338 # TODO(b/222761732): Turn this warning back on when legacy RNGs are
3339 # deprecated.
3340 # logging.warning(
3341 # "Note that %s inside pfor op may not give same output as "
3342 # "inside a sequential loop.", pfor_input.op_type)
3344 if param_stacked:
3345 samples = _create_op(
3346 pfor_input.op_type,
3347 inputs=[shape, param],
3348 op_dtypes=[x.dtype for x in pfor_input.outputs],
3349 attrs=pfor_input.op.node_def.attr).outputs[0]
3350 loop_dim = array_ops.shape(shape)[0]
3351 stacked_samples = _transpose_dim_to_front(samples, loop_dim)
3352 else:
3353 shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0)
3354 stacked_samples = _create_op(
3355 pfor_input.op_type,
3356 inputs=[shape, param],
3357 op_dtypes=[x.dtype for x in pfor_input.outputs],
3358 attrs=pfor_input.op.node_def.attr).outputs[0]
3360 return wrap(stacked_samples, True)
3363@RegisterPFor("Multinomial")
3364def _convert_multinomial(pfor_input):
3365 logits, logits_stacked, _ = pfor_input.input(0)
3366 num_samples = pfor_input.unstacked_input(1)
3367 seed = pfor_input.get_attr("seed")
3368 seed2 = pfor_input.get_attr("seed2")
3369 output_dtype = pfor_input.get_attr("output_dtype")
3370 # TODO(b/222761732): Turn this warning back on when legacy RNGs are
3371 # deprecated.
3372 # logging.warning(
3373 # "Note that Multinomial inside pfor op may not give same output as "
3374 # "inside a sequential loop.")
3376 n = pfor_input.pfor.loop_len_vector[0]
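  # If logits are stacked, the [n, batch, num_classes] logits are flattened to
  # [n * batch, num_classes] before sampling and the result is unflattened
  # afterwards. If logits are loop invariant, n * num_samples draws are taken
  # from the same distribution and reshaped/transposed into
  # [n, batch, num_samples].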
3377 if logits_stacked:
3378 flattened_logits = _flatten_first_two_dims(logits)
3379 samples = gen_random_ops.multinomial(
3380 flattened_logits,
3381 num_samples,
3382 seed=seed,
3383 seed2=seed2,
3384 output_dtype=output_dtype)
3385 stacked_samples = _unflatten_first_dim(samples, [n])
3386 else:
3387 samples = gen_random_ops.multinomial(
3388 logits,
3389 num_samples * n,
3390 seed=seed,
3391 seed2=seed2,
3392 output_dtype=output_dtype)
3393 stacked_samples = array_ops.transpose(
3394 array_ops.reshape(samples, [-1, n, num_samples]), [1, 0, 2])
3396 return wrap(stacked_samples, True)
3399@RegisterPFor("StatelessMultinomial")
3400@RegisterPFor("StatelessParameterizedTruncatedNormal")
3401@RegisterPFor("StatelessRandomBinomial")
3402@RegisterPFor("StatelessRandomGammaV2")
3403@RegisterPFor("StatelessRandomNormal")
3404@RegisterPFor("StatelessRandomPoisson")
3405@RegisterPFor("StatelessRandomUniform")
3406@RegisterPFor("StatelessRandomUniformInt")
3407@RegisterPFor("StatelessRandomUniformFullInt")
3408@RegisterPFor("StatelessTruncatedNormal")
3409def _convert_stateless_multinomial(pfor_input):
3410 # Unlike stateful random ops, for stateless ones we want better
3411 # reproducibility based on the seed. Hence we don't want to use the strategy
3412 # used for stateful ones, where we generate a possibly different set of
3413 # random numbers under vectorization.
3414 # Unfortunately, the kernels are currently not necessarily set up to do this
3415 # efficiently, and hence we fall back to a sequential loop for vectorization.
3416 return _fallback_converter(pfor_input, warn=False)
3419# linalg_ops
3422@RegisterPForWithArgs("XlaEinsum")
3423@RegisterPForWithArgs("Einsum")
3424def _convert_einsum(pfor_input, op_type):
3425 # Einsum may have either 1 or 2 inputs.
3426 inputs, input_stacked, _ = zip(*[
3427 pfor_input.input(i)
3428 for i in range(pfor_input.num_inputs)])
3430 # Parse the einsum equation.
3431 equation = pfor_input.get_attr("equation").decode("utf-8")
3432 input_expr, output_expr = equation.split("->")
3433 input_exprs = input_expr.split(",")
3435 # Pick a placeholder symbol to use for the new axis.
3436 chosen_symbol = None
3437 for s in string.ascii_letters:
3438 if s in equation:
3439 continue
3440 else:
3441 chosen_symbol = s
3442 break
3444 if chosen_symbol is None:
3445 raise ValueError("Could not figure out what symbol to use for new axis.")
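  # Prefix the unused symbol to the subscripts of every stacked input and to
  # the output subscripts, so that the einsum carries the loop dimension
  # through unchanged.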
3447 assert any(input_stacked)
3448 for i in range(len(inputs)):
3449 if input_stacked[i]:
3450 input_exprs[i] = "{}{}".format(chosen_symbol, input_exprs[i])
3451 output_expr = "{}{}".format(chosen_symbol, output_expr)
3453 new_equation = "{}->{}".format(",".join(input_exprs), output_expr)
3455 if op_type == "XlaEinsum":
3456 if len(inputs) == 1:
3457 result = xla.einsum(equation=new_equation, a=inputs[0])
3458 else:
3459 result = xla.einsum(equation=new_equation, a=inputs[0], b=inputs[1])
3460 else:
3461 assert op_type == "Einsum"
3462 result = special_math_ops.einsum(new_equation, *inputs)
3464 return wrap(result, True)
3467@RegisterPFor("Cholesky")
3468def _convert_cholesky(pfor_input):
3469 t = pfor_input.stacked_input(0)
3470 return wrap(linalg_ops.cholesky(t), True)
3473@RegisterPFor("LogMatrixDeterminant")
3474def _convert_log_matrix_determinant(pfor_input):
3475 t = pfor_input.stacked_input(0)
3476 return [wrap(x, True) for x in linalg_ops.log_matrix_determinant(t)]
3479@RegisterPFor("MatrixInverse")
3480def _convert_matrix_inverse(pfor_input):
3481 t = pfor_input.stacked_input(0)
3482 adjoint = pfor_input.get_attr("adjoint")
3483 return wrap(gen_linalg_ops.matrix_inverse(t, adjoint=adjoint), True)
3486@RegisterPFor("MatrixSolve")
3487def _convert_matrix_solve(pfor_input):
3488 pfor_input.stack_inputs()
3489 matrix = pfor_input.stacked_input(0)
3490 rhs = pfor_input.stacked_input(1)
3491 adjoint = pfor_input.get_attr("adjoint")
3492 output = gen_linalg_ops.matrix_solve(
3493 matrix, rhs, adjoint=adjoint)
3494 return wrap(output, True)
3497@RegisterPFor("MatrixTriangularSolve")
3498def _convert_matrix_triangular_solve(pfor_input):
3499 pfor_input.expanddim_inputs_for_broadcast()
3500 matrix = pfor_input.input(0)[0]
3501 rhs = pfor_input.input(1)[0]
3502 lower = pfor_input.get_attr("lower")
3503 adjoint = pfor_input.get_attr("adjoint")
3504 output = linalg_ops.matrix_triangular_solve(
3505 matrix, rhs, lower=lower, adjoint=adjoint)
3506 return wrap(output, True)
3509@RegisterPFor("SelfAdjointEigV2")
3510def _convert_self_adjoint_eig(pfor_input):
3511 t = pfor_input.stacked_input(0)
3512 compute_v = pfor_input.get_attr("compute_v")
3513 e, v = gen_linalg_ops.self_adjoint_eig_v2(t, compute_v=compute_v)
3514 # If compute_v is False, v will have shape [0].
3515 return wrap(e, True), wrap(v, compute_v)
3518# logging_ops
3521@RegisterPFor("Assert")
3522def _convert_assert(pfor_input):
3523 cond, cond_stacked, _ = pfor_input.input(0)
3524 if cond_stacked:
3525 cond = math_ops.reduce_all(cond)
3527 data_list = [x.t for x in pfor_input.inputs][1:]
3528 return _create_op(
3529 "Assert", [cond] + data_list, [], attrs=pfor_input.op.node_def.attr)
3532@RegisterPFor("Print")
3533def _convert_print(pfor_input):
3534 # Note that we don't stack all the inputs. Hence unstacked values are printed
3535 # once here vs multiple times in a while_loop.
3536 pfor_input.stack_inputs([0])
3537 outputs = _create_op(
3538 "Print", [x.t for x in pfor_input.inputs],
3539 [x.dtype for x in pfor_input.outputs],
3540 attrs=pfor_input.op.node_def.attr).outputs
3541 return [wrap(x, True) for x in outputs]
3544@RegisterPFor("PrintV2")
3545def _convert_print_v2(pfor_input):
3546 # Print the full input Tensor(s), including the batch dimension if stacked.
3547 return _create_op(
3548 "PrintV2", [x.t for x in pfor_input.inputs],
3549 [x.dtype for x in pfor_input.outputs],
3550 attrs=pfor_input.op.node_def.attr)
3553@RegisterPFor("StringFormat")
3554def _convert_string_format(pfor_input):
3555 # Format using the full input Tensor(s), including the batch dimension if
3556 # stacked.
3557 op = _create_op(
3558 "StringFormat", [x.t for x in pfor_input.inputs],
3559 [x.dtype for x in pfor_input.outputs],
3560 attrs=pfor_input.op.node_def.attr)
3561 return [wrap(output, False) for output in op.outputs]
3564# data_flow_ops
3566# TensorArray conversion is tricky since we don't support arrays of
3567# TensorArrays. For converting them, we consider two distinct cases:
3568#
3569# 1. The array is constructed outside the pfor call, and read/written inside the
3570# loop.
3571# This is an easier case since we don't need to make an array of TensorArrays.
3572# A correctness requirement is that these parallel iterations shouldn't attempt
3573# to write to the same location. Hence at conversion time we disallow indices to
3574# be loop-invariant as that would guarantee a collision. Even if the indices are
3575# not loop-invariant, they could still conflict, which would trigger runtime errors.
3576#
3577# 2. The array is constructed and used entirely inside each pfor iteration.
3578# For simplicity, here we require that the indices used for write/scatter are
3579# "unstacked". Otherwise it becomes hard to merge the TensorArrays created in
3580# different pfor iterations. We consider two sub-cases:
3581#
3582# 2a Elements written to the array are "stacked"
3583# To simulate multiple TensorArrays, we may increase the dimension of each
3584# element of the array. i.e. the i_th row of the j_th entry of the converted
3585# TensorArray corresponds to the j_th entry of the TensorArray in the i_th
3586# pfor iteration.
3587#
3588# 2b Elements written to the array are "unstacked"
3589# In this case we don't increase the dimensions to avoid redundant tiling. Each
3590# iteration is trying to write the same value. So we convert that to a single
3591# write.
3592#
3593# Here are some tricks used to implement the above:
3594# - TensorArrayV3 constructor encodes the element shape as an attr. Instead of
3595# trying to trace whether future writes are stacked or unstacked in order to set
3596# this attr, we set it to correspond to unknown shape.
3597# - We use the "flow" output of the different ops to track whether the array
3598# elements are stacked or unstacked. If a stacked write/scatter is done, we make
3599# the flow stacked as well.
3600# - We use some heuristic traversal of the graph to track whether the
3601# TensorArray handle was created inside or outside the pfor loop.
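#
# As a rough illustration of the two cases (this is not part of the conversion
# logic; `per_iteration_value` is just a stand-in for any loop-dependent
# computation and `num_iters` denotes the pfor loop length):
#
#   # Case 1: TensorArray created outside the vectorized loop body.
#   ta = tensor_array_ops.TensorArray(dtypes.float32, size=num_iters)
#   def loop_fn(i):
#     return ta.write(i, per_iteration_value(i)).read(i)
#
#   # Case 2: TensorArray created and consumed inside the loop body itself.
#   def loop_fn(i):
#     ta = tensor_array_ops.TensorArray(dtypes.float32, size=3)
#     ta = ta.write(0, per_iteration_value(i))
#     return ta.read(0)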
3604@RegisterPFor("TensorArrayV3")
3605def _convert_tensor_array_v3(pfor_input):
3606 size = pfor_input.unstacked_input(0)
3607 dtype = pfor_input.get_attr("dtype")
3608 dynamic_size = pfor_input.get_attr("dynamic_size")
3609 clear_after_read = pfor_input.get_attr("clear_after_read")
3610 identical_element_shapes = pfor_input.get_attr("identical_element_shapes")
3611 tensor_array_name = pfor_input.get_attr("tensor_array_name")
3612 handle, flow = data_flow_ops.tensor_array_v3(
3613 size,
3614 dtype=dtype,
3615 # We don't set element shape since we don't know if writes are stacked or
3616 # not yet.
3617 element_shape=None,
3618 dynamic_size=dynamic_size,
3619 clear_after_read=clear_after_read,
3620 identical_element_shapes=identical_element_shapes,
3621 tensor_array_name=tensor_array_name)
3622 # Note we keep flow unstacked for now since we don't know if writes will be
3623 # stacked or not.
3624 return wrap(handle, False), wrap(flow, False)
3627@RegisterPFor("TensorArraySizeV3")
3628def _convert_tensor_array_size_v3(pfor_input):
3629 handle = pfor_input.unstacked_input(0)
3630 flow, flow_stacked, _ = pfor_input.input(1)
3631 if flow_stacked:
3632 flow = _unstack_flow(flow)
3633 size = data_flow_ops.tensor_array_size_v3(handle, flow)
3634 return wrap(size, False)
3637def _handle_inside_pfor(pfor_input, handle):
3638 """Returns True if handle was created inside the pfor loop."""
3639 # We use some heuristic to find the original TensorArray creation op.
3640# The logic should handle the common cases (except cond-based subgraphs).
3641 # In theory the user could perform different operations on the handle (like
3642 # Reshape, stack multiple handles, etc) which could break this logic.
3643 # TODO(agarwal): handle Switch/Merge.
3644 while handle.op.type in ("Enter", "Identity"):
3645 handle = handle.op.inputs[0]
3646 if handle.op.type not in [
3647 "TensorArrayV3", "TensorArrayGradV3", "TensorArrayGradWithShape"
3648 ]:
3649 raise ValueError(f"Unable to find source for handle {handle}.")
3650 else:
3651 return pfor_input.pfor.op_is_inside_loop(handle.op)
3654def _unstack_flow(value):
3655 # TODO(agarwal): consider looking if this is a Tile op then get its input.
3656 # This may avoid running the Tile operations.
3657 return array_ops.gather(value, 0)
3660@RegisterPFor("TensorArrayReadV3")
3661def _convert_tensor_array_read_v3(pfor_input):
3662 handle = pfor_input.unstacked_input(0)
3663 index, index_stacked, _ = pfor_input.input(1)
3664 dtype = pfor_input.get_attr("dtype")
3665 flow, flow_stacked, _ = pfor_input.input(2)
3666 if flow_stacked:
3667 flow = _unstack_flow(flow)
3669 is_inside_pfor = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
3670 if is_inside_pfor:
3671 # Note that if we are inside a control flow construct inside the pfor, and
3672 # only some of the iterations are doing the read (i.e.
3673 # `all_indices_partitioned` is True), then the read operation should only
3674 # return values for the currently active pfor iterations (`all_indices`
3675 # below). Hence, whenever the returned value is stacked (i.e. `flow` is
3676 # stacked), we may need to do an extra gather after reading the values. Also
3677 # note that if `is_inside` is false, then values in the tensor array are
3678 # unstacked. So the check is only needed in this branch.
3679 all_indices = pfor_input.pfor.all_indices
3680 all_indices_partitioned = pfor_input.pfor.all_indices_partitioned
3681 # Note: flow_stacked indicates if values in the TensorArray are stacked or
3682 # not.
3683 if index_stacked:
3684 if flow_stacked:
3685 raise ValueError(
3686 "It looks like TensorArrayReadV3 was called on a TensorArray whose"
3687 " values are not loop-invariant, and the read indices were also"
3688 " not loop invariant. This is currently unsupported.")
3689 value = data_flow_ops.tensor_array_gather_v3(
3690 handle, index, flow, dtype=dtype)
3691 return wrap(value, True)
3692 value = data_flow_ops.tensor_array_read_v3(handle, index, flow, dtype=dtype)
3693 if flow_stacked and all_indices_partitioned:
3694 value = array_ops.gather(value, all_indices)
3695 return wrap(value, flow_stacked)
3696 # Values in the TensorArray should be unstacked (since different iterations
3697 # couldn't write to the same location). So whether output is stacked or not
3698 # depends on index_stacked.
3699 if index_stacked:
3700 value = data_flow_ops.tensor_array_gather_v3(
3701 handle, index, flow, dtype=dtype)
3702 else:
3703 value = data_flow_ops.tensor_array_read_v3(handle, index, flow, dtype=dtype)
3704 return wrap(value, index_stacked)
3707@RegisterPFor("TensorArrayWriteV3")
3708def _convert_tensor_array_write_v3(pfor_input):
3709 handle = pfor_input.unstacked_input(0)
3710 index, index_stacked, _ = pfor_input.input(1)
3711 value, value_stacked, _ = pfor_input.input(2)
3712 flow, flow_stacked, _ = pfor_input.input(3)
3713 if value_stacked and pfor_input.pfor.all_indices_partitioned:
3714 # Looks like we are in a control flow in a pfor where not all iterations are
3715 # active now. We don't allow that since that could lead to different indices
3716 # having different shapes which will be hard to merge later.
3717 raise ValueError("Writing non loop invariant values to TensorArray from "
3718 "inside a while_loop/cond not supported.")
3719 if flow_stacked:
3720 flow = _unstack_flow(flow)
3721 is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
3722 if is_inside:
3723 if index_stacked:
3724 raise ValueError(f"Need indices for {handle} to be loop invariant.")
3725 if not flow_stacked and not value_stacked:
3726 flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow)
3727 return wrap(flow_out, False)
3728 else:
3729 if not value_stacked:
3730 value = _stack(value, pfor_input.pfor.loop_len_vector).t
3731 # TODO(agarwal): Note that if flow is unstacked and value is stacked, then
3732 # this may or may not be a safe situation. flow is unstacked both for a
3733 # freshly created TensorArray, as well as after unstacked values are
3734 # written to it. If it is the latter, then we cannot write a stacked value
3735 # now since that may cause runtime errors due to different shapes in the
3736 # array. At the moment we are not able to handle this gracefully and
3737 # distinguish between the two cases. That would require some heuristic
3738 # traversal of the graph to figure out whether all the writes are
3739 # unstacked or not.
3740 flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow)
3741 return _stack(flow_out, pfor_input.pfor.loop_len_vector)
3742 else:
3743 if not index_stacked:
3744 raise ValueError(f"Need indices for {handle} to be not loop invariant.")
3745 # Note that even when index_stacked is true, actual values in index may
3746 # still not be unique. However that will cause runtime error when executing
3747 # the scatter operation below.
3748 if not value_stacked:
3749 value = _stack(value, pfor_input.pfor.loop_len_vector).t
3750 flow_out = data_flow_ops.tensor_array_scatter_v3(handle, index, value, flow)
3751 return _stack(flow_out, pfor_input.pfor.loop_len_vector)
3754def _transpose_first_two_dims(value):
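  """Transposes the first two dimensions of `value`, keeping the rest as is."""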
3755 # TODO(agarwal): optimize if one of the dims == 1.
3756 value_shape = array_ops.shape(value)
3757 v0 = value_shape[0]
3758 v1 = value_shape[1]
3759 value = array_ops.reshape(value, [v0, v1, -1])
3760 value = array_ops.transpose(value, [1, 0, 2])
3761 new_shape = array_ops.concat([[v1, v0], value_shape[2:]], axis=0)
3762 return array_ops.reshape(value, new_shape)
3765@RegisterPFor("TensorArrayGatherV3")
3766def _convert_tensor_array_gather_v3(pfor_input):
3767 handle = pfor_input.unstacked_input(0)
3768 indices, indices_stacked, _ = pfor_input.input(1)
3769 indices = array_ops.reshape(indices, [-1])
3770 flow, flow_stacked, _ = pfor_input.input(2)
3771 if flow_stacked:
3772 flow = _unstack_flow(flow)
3773 dtype = pfor_input.get_attr("dtype")
3774 # TODO(agarwal): support element_shape attr?
3776 n = pfor_input.pfor.loop_len_vector
3777 value = data_flow_ops.tensor_array_gather_v3(
3778 handle, indices, flow, dtype=dtype)
3779 is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
3780 if is_inside:
3781 # flow_stacked indicates if values in the TensorArray are stacked or not.
3782 if indices_stacked:
3783 if flow_stacked:
3784 raise ValueError(
3785 "It looks like TensorArrayGatherV3 was called on a TensorArray "
3786 "whose values are not loop-invariant, and the indices were also "
3787 "not loop invariant. This is currently unsupported.")
3788 else:
3789 value = _unflatten_first_dim(value, n)
3790 return wrap(value, True)
3791 else:
3792 if flow_stacked:
3793 # Since elements in this array are stacked and `value` was produced by
3794 # gather, its first two dims are "gathered elements" and "stack
3795 # dimension". Our semantics require these two to be flipped.
3796 value = _transpose_first_two_dims(value)
3797 return wrap(value, flow_stacked)
3798 else:
3799 # Values in the TensorArray should be unstacked (since different iterations
3800 # couldn't write to the same location). So whether output is stacked or not
3801 # depends on indices_stacked.
3802 if indices_stacked:
3803 value = _unflatten_first_dim(value, n)
3804 return wrap(value, indices_stacked)
3807@RegisterPFor("TensorArrayScatterV3")
3808def _convert_tensor_array_scatter_v3(pfor_input):
3809 handle = pfor_input.unstacked_input(0)
3810 indices, indices_stacked, _ = pfor_input.input(1)
3811 indices = array_ops.reshape(indices, [-1])
3812 value, value_stacked, _ = pfor_input.input(2)
3813 flow, flow_stacked, _ = pfor_input.input(3)
3815 if flow_stacked:
3816 flow = _unstack_flow(flow)
3818 is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
3819 if is_inside:
3820 if indices_stacked:
3821 raise ValueError(f"Need indices for {handle} to be loop invariant.")
3822 # Note that flow_stacked indicates if existing values in the array are
3823 # stacked or not.
3824 if not flow_stacked and not value_stacked:
3825 flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value,
3826 flow)
3827 return wrap(flow_out, False)
3828 if not value_stacked:
3829 # TODO(agarwal): tile in the second dimension directly instead of
3830 # transposing below.
3831 value = _stack(value, pfor_input.pfor.loop_len_vector).t
3833 value = _transpose_first_two_dims(value)
3834 # TODO(agarwal): Note that if a previous write was unstacked, flow will be
3835 # unstacked, and a stacked value may be written here which may cause
3836 # runtime error due to different elements having different shape. We do
3837 # not try to prevent that.
3838 flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value,
3839 flow)
3840 return _stack(flow_out, pfor_input.pfor.loop_len_vector)
3841 if not indices_stacked:
3842 raise ValueError(f"Need indices for {handle} to be not loop invariant.")
3843 if not value_stacked:
3844 value = _stack(value, pfor_input.pfor.loop_len_vector).t
3845 value = _flatten_first_two_dims(value)
3846 flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, flow)
3847 return _stack(flow_out, pfor_input.pfor.loop_len_vector)
3850@RegisterPFor("TensorArrayGradV3")
3851def _convert_tensor_array_grad_v3(pfor_input):
3852 handle = pfor_input.unstacked_input(0)
3853 flow, flow_stacked, _ = pfor_input.input(1)
3854 if flow_stacked:
3855 flow = _unstack_flow(flow)
3856 source = pfor_input.get_attr("source")
3857 # TODO(agarwal): For now, we assume that gradients are stacked if the
3858 # TensorArrayGradV3 call is being done inside the pfor. Getting that wrong
3859 # will give runtime error due to incorrect shape being written to the
3860 # accumulator. It is difficult to know in advance if gradients written will be
3861 # stacked or not. Note that flow being stacked is not indicative of the
3862 # gradient being stacked or not. Revisit this later.
3863 shape_to_prepend = pfor_input.pfor.loop_len_vector
3864 grad_handle, flow_out = data_flow_ops.tensor_array_grad_with_shape(
3865 handle=handle,
3866 flow_in=flow,
3867 shape_to_prepend=shape_to_prepend,
3868 source=source)
3869 flow_out = _stack(flow_out, pfor_input.pfor.loop_len_vector).t
3870 return [wrap(grad_handle, False), wrap(flow_out, True)]
3873def _stack_tensor_list_shape(shape, first_dim):
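  """Prepends `first_dim` to a TensorList element `shape`.

  Collapses to -1 (fully unknown) if any dimension of `shape` is unknown.
  """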
3874 shape_value = tensor_util.constant_value(shape)
3875 # Note that negative values in the shape are used to signify unknown shapes
3876 # and are handled in a special way.
3877 if shape_value is not None:
3878 shape_value = np.asarray(shape_value)
3879 if -1 in shape_value:
3880 return constant_op.constant(-1)
3881 elif not shape_value.size:
3882 return first_dim
3883 else:
3884 shape = array_ops.reshape(shape, [-1])
3885 return tf_cond.cond(
3886 math_ops.reduce_any(shape < 0),
3887 lambda: constant_op.constant(-1),
3888 lambda: array_ops.concat([first_dim, shape], axis=0))
3891def _tile_variant_with_length(t, length):
3892 """stacks `t` `length` times."""
3893 if _is_variant_with_internal_stacking(t):
3894 # The content of TensorLists is vectorized, not the variant itself.
3895 return t
3896 original_tensor = t
3897 t.set_shape([])
3898 t = array_ops.reshape(t, [-1])
3899 with ops.device("CPU:0"):
3900 result = array_ops.tile(t, length)
3901 # TODO(b/169968286): Should regular shape functions do handle data
3902 # propagation here?
3903 handle_data_util.copy_handle_data(original_tensor, result)
3904 return result
3907def _tile_variant(t, pfor_input):
3908 """stacks `t` according to its loop context."""
3909 return _tile_variant_with_length(t, pfor_input.pfor.loop_len_vector)
3912def _untile_variant(t):
3913 if _is_variant_with_internal_stacking(t):
3914 # The content of TensorLists is vectorized, not the variant itself.
3915 if not t.shape.is_compatible_with([]):
3916 raise AssertionError(
3917 ("Unexpectedly saw a vectorized variant (e.g. TensorList) with "
3918 f"non-scalar shape: {t!r}"))
3919 return t
3920 return array_ops.gather(t, 0)
3923@RegisterPFor("OptionalFromValue")
3924def _convert_optional_from_value(pfor_input):
3925 pfor_input.stack_inputs()
3926 return wrap(
3927 gen_optional_ops.optional_from_value([x.t for x in pfor_input.inputs]),
3928 True,
3929 )
3932@RegisterPFor("OptionalGetValue")
3933def _convert_optional_get_value(pfor_input):
3934 handle = pfor_input.stacked_input(0)
3935 output_types = pfor_input.get_attr("output_types")
3936 original_output_shapes = pfor_input.get_attr("output_shapes")
3937 output_shapes = []
3938 for shape in original_output_shapes:
3939 shape = tensor_shape.TensorShape(shape)
3940 loop_len_value = tensor_util.constant_value(pfor_input.pfor.loop_len_vector)
3941 loop_len_shape = tensor_shape.TensorShape(
3942 [loop_len_value[0] if loop_len_value is not None else None]
3943 )
3944 shape = loop_len_shape.concatenate(shape)
3945 output_shapes.append(shape.as_proto())
3946 results = gen_optional_ops.optional_get_value(
3947 handle, output_types, output_shapes
3948 )
3949 return [wrap(t, True) for t in results]
3952@RegisterPFor("TensorListReserve")
3953def _convert_tensor_list_reserve(pfor_input):
3954 element_shape = pfor_input.unstacked_input(0)
3955 num_elements = pfor_input.unstacked_input(1)
3956 element_dtype = pfor_input.get_attr("element_dtype")
3958 # Prepend a dimension to element_shape.
3959 element_shape = _stack_tensor_list_shape(element_shape,
3960 pfor_input.pfor.loop_len_vector)
3961 handle = list_ops.tensor_list_reserve(
3962 element_shape, num_elements, element_dtype=element_dtype)
3964 return wrap(_tile_variant(handle, pfor_input), True)
3967@RegisterPFor("TensorListElementShape")
3968def _convert_tensor_list_element_shape(pfor_input):
3969 handle = _untile_variant(pfor_input.stacked_input(0))
3970 shape_type = pfor_input.get_attr("shape_type")
3971 shape = list_ops.tensor_list_element_shape(handle, shape_type)
3972 shape = array_ops.reshape(shape, [-1])
3973 shape = shape[1:]
3974 return wrap(shape, False)
3977@RegisterPFor("TensorListLength")
3978def _convert_tensor_list_length(pfor_input):
3979 handle = _untile_variant(pfor_input.stacked_input(0))
3980 return wrap(list_ops.tensor_list_length(handle), False)
3983def _stack_tensor_list(handle, dtype, loop_len_vector, element_shape=None):
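  """Converts a loop-invariant TensorList into one with stacked elements.

  Each element of `handle` is tiled `loop_len_vector[0]` times via a
  sequential while_loop, so the resulting list can hold stacked values.
  """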
3984 if element_shape is None:
3985 element_shape = list_ops.tensor_list_element_shape(handle, dtypes.int32)
3986 length = list_ops.tensor_list_length(handle)
3987 new_handle = list_ops.tensor_list_reserve(
3988 _stack_tensor_list_shape(element_shape, loop_len_vector), length, dtype)
3990 def _body_fn(i, h):
3991 elem = list_ops.tensor_list_get_item(handle, i, dtype, element_shape)
3992 elem = _stack(elem, loop_len_vector).t
3993 return i + 1, list_ops.tensor_list_set_item(h, i, elem)
3995 return while_loop.while_loop(lambda i, _: i < length, _body_fn,
3996 [0, new_handle])[1]
3999@RegisterPFor("TensorListGetItem")
4000def _convert_tensor_list_get_item(pfor_input):
4001 handle, handle_stacked, _ = pfor_input.input(0)
4002 index, index_stacked, _ = pfor_input.input(1)
4003 element_shape = pfor_input.unstacked_input(2)
4004 element_dtype = pfor_input.get_attr("element_dtype")
4006 if handle_stacked:
4007 handle = _untile_variant(handle)
4008 element_shape = _stack_tensor_list_shape(element_shape,
4009 pfor_input.pfor.loop_len_vector)
4010 if index_stacked:
4011 # We use a sequential loop since that may be more efficient than first
4012 # gathering and concatenating all the elements corresponding to `index`,
4013 # and then doing a gather on it.
4014 def _map_fn(i):
4015 item_i = list_ops.tensor_list_get_item(
4016 handle,
4017 index[i],
4018 element_dtype=element_dtype)
4019 return array_ops.gather(item_i, i)
4021 output = map_fn.map_fn(_map_fn, pfor_input.pfor.all_indices)
4022 return wrap(output, True)
4023 else:
4024 output = list_ops.tensor_list_get_item(
4025 handle,
4026 index,
4027 element_shape=element_shape,
4028 element_dtype=element_dtype)
4029 return wrap(output, True)
4030 else:
4031 assert index_stacked
4032 return wrap(
4033 list_ops.tensor_list_gather(
4034 handle,
4035 index,
4036 element_shape=element_shape,
4037 element_dtype=element_dtype), True)
4040@RegisterPFor("TensorListSetItem")
4041def _convert_tensor_array_set_item(pfor_input):
4042 handle, handle_stacked, _ = pfor_input.input(0)
4043 index, index_stacked, _ = pfor_input.input(1)
4044 item, item_stacked, _ = pfor_input.input(2)
4046 if not handle_stacked:
4047 # Special case where we can statically guarantee that the indices are
4048 # disjoint.
4049 if index is pfor_input.pfor.all_indices:
4050 if not item_stacked:
4051 item = _stack(item, pfor_input.pfor.loop_len_vector).t
4052 return wrap(
4053 list_ops.tensor_list_scatter(item, index, input_handle=handle), False)
4054 else:
4055 handle = _stack_tensor_list(handle, item.dtype,
4056 pfor_input.pfor.loop_len_vector)
4057 else:
4058 handle = _untile_variant(handle)
4060 if index_stacked:
4061 # TODO(agarwal): handle this.
4062 raise ValueError("Vectorizing writes to a TensorList with loop "
4063 "variant indices is currently unsupported.")
4065 else:
4066 if not item_stacked:
4067 item = _stack(item, pfor_input.pfor.loop_len_vector).t
4068 handle = list_ops.tensor_list_set_item(handle, index, item)
4069 return wrap(_tile_variant(handle, pfor_input), True)
4072@RegisterPFor("TensorListPushBack")
4073def _convert_tensor_list_push_back(pfor_input):
4074 handle, handle_stacked, _ = pfor_input.input(0)
4075 tensor, tensor_stacked, _ = pfor_input.input(1)
4076 if handle_stacked:
4077 handle = _untile_variant(handle)
4078 else:
4079 handle = _stack_tensor_list(handle, tensor.dtype,
4080 pfor_input.pfor.loop_len_vector)
4081 if not tensor_stacked:
4082 tensor = _stack(tensor, pfor_input.pfor.loop_len_vector).t
4083 handle = list_ops.tensor_list_push_back(handle, tensor)
4084 return wrap(_tile_variant(handle, pfor_input), True)
4087@RegisterPFor("TensorListPopBack")
4088def _convert_tensor_array_push_back(pfor_input):
4089 handle = pfor_input.stacked_input(0)
4090 element_shape = pfor_input.unstacked_input(1)
4091 handle = _untile_variant(handle)
4093 if element_shape.shape.ndims == 0:
4094 # Default / unspecified
4095 vectorized_shape = -1
4096 else:
4097 # PopBack has an element shape set when it's the gradient of PushBack, only
4098 # used when the list is uninitialized.
4099 vectorized_shape = array_ops.concat(
4100 [pfor_input.pfor.loop_len_vector, element_shape], axis=0)
4102 output_handle, tensor = gen_list_ops.tensor_list_pop_back(
4103 input_handle=handle, element_dtype=pfor_input.get_attr("element_dtype"),
4104 element_shape=vectorized_shape)
4105 return wrap(output_handle, True), wrap(tensor, True)
4108@RegisterPFor("TensorListConcatV2")
4109def _convert_tensor_list_concat_v2(pfor_input):
4110 input_handle = pfor_input.stacked_input(0)
4111 element_shape = pfor_input.unstacked_input(1)
4112 leading_dims = pfor_input.unstacked_input(2)
4113 element_dtype = pfor_input.get_attr("element_dtype")
4115 handle = _untile_variant(input_handle)
4116 length = list_ops.tensor_list_length(handle)
4117 # Note that element_shape attribute can have incomplete shapes. This doesn't
4118 # seem to work well when creating another list and then doing a concat on it.
4119 # Hence we try to find the dynamic shape here.
4120 element_shape = tf_cond.cond(
4121 length > 0, lambda: array_ops.shape(
4122 list_ops.tensor_list_get_item(handle, 0, element_dtype, None)),
4123 lambda: constant_op.constant([0, 0], dtype=dtypes.int32))
4124 # The code below creates a copy of the list with each element's first two
4125 # dimensions transposed.
4126 new_element_shape = array_ops.concat(
4127 [element_shape[1:2], element_shape[0:1], element_shape[2:]], axis=0)
4129 # Create a new TensorList with elements transposed.
4130 def _transpose_elem(i, h):
4131 elem = list_ops.tensor_list_get_item(handle, i, element_dtype, None)
4132 elem = _transpose_first_two_dims(elem)
4133 return i + 1, list_ops.tensor_list_set_item(h, i, elem)
4135 new_handle = list_ops.tensor_list_reserve(new_element_shape, length,
4136 element_dtype)
4137 new_handle = while_loop.while_loop(lambda i, _: i < length, _transpose_elem,
4138 [0, new_handle])[1]
4139 output, lengths = gen_list_ops.tensor_list_concat_v2(
4140 input_handle=new_handle,
4141 element_dtype=element_dtype,
4142 element_shape=new_element_shape,
4143 leading_dims=leading_dims)
4144 output = _transpose_first_two_dims(output)
4145 return wrap(output, True), wrap(lengths, False)
4148@RegisterPFor("TensorListStack")
4149def _convert_tensor_list_stack(pfor_input):
4150 handle = pfor_input.stacked_input(0)
4151 input_shape = pfor_input.unstacked_input(1)
4152 element_dtype = pfor_input.get_attr("element_dtype")
4153 num_elements = pfor_input.get_attr("num_elements")
4155 handle = _untile_variant(handle)
4156 input_shape = _stack_tensor_list_shape(input_shape,
4157 pfor_input.pfor.loop_len_vector)
4158 output = list_ops.tensor_list_stack(
4159 handle,
4160 element_dtype,
4161 element_shape=input_shape,
4162 num_elements=num_elements)
4163 output = _transpose_first_two_dims(output)
4164 return wrap(output, True)
4167@RegisterPFor("TensorListGather")
4168def _convert_tensor_list_gather(pfor_input):
4169 handle, handle_stacked, _ = pfor_input.input(0)
4170 index, index_stacked, _ = pfor_input.input(1)
4171 element_shape = pfor_input.unstacked_input(2)
4172 element_dtype = pfor_input.get_attr("element_dtype")
4174 if handle_stacked:
4175 handle = _untile_variant(handle)
4176 element_shape = _stack_tensor_list_shape(element_shape,
4177 pfor_input.pfor.loop_len_vector)
4178 if index_stacked:
4179 # We use a sequential loop since that may be more efficient than first
4180 # gathering and concatenating all the elements corresponding to `index`,
4181 # and then doing a gather on it.
4182 def _map_fn(i):
4183 item_i = list_ops.tensor_list_gather(
4184 handle,
4185 index[i],
4186 element_dtype=element_dtype)
4187 axis = array_ops.rank(index) - 1
4188 return array_ops.gather(item_i, i, axis=axis)
4190 output = map_fn.map_fn(_map_fn, pfor_input.pfor.all_indices)
4191 return wrap(output, True)
4192 else:
4193 output = list_ops.tensor_list_gather(
4194 handle,
4195 index,
4196 element_shape=element_shape,
4197 element_dtype=element_dtype)
4198 return wrap(output, True)
4199 else:
4200 assert index_stacked
4201 index_shape = array_ops.shape(index)
4202 index = array_ops.reshape(index, [-1])
4203 values = list_ops.tensor_list_gather(
4204 handle, index, element_shape=element_shape, element_dtype=element_dtype)
4205 final_shape = array_ops.concat(
4206 [index_shape, array_ops.shape(values)[1:]], axis=0)
4207 return wrap(array_ops.reshape(values, final_shape), True)
4210@RegisterPFor("TensorListScatterIntoExistingList")
4211def _convert_tensor_list_scatter(pfor_input):
4212 pfor_input.stack_inputs([1])
4213 handle, handle_stacked, _ = pfor_input.input(0)
4214 item = pfor_input.stacked_input(1)
4215 indices, indices_stacked, _ = pfor_input.input(2)
4216 if handle_stacked:
4217 handle = _untile_variant(handle)
4218 else:
4219 handle = _stack_tensor_list(handle, item.dtype,
4220 pfor_input.pfor.loop_len_vector)
4222 item = _transpose_first_two_dims(item)
4223 if indices_stacked:
4224 # Pretend the list is a dense tensor:
4225 # list_as_dense: Tensor[list_len, loop_len, ...]
4226 # And indices are a tensor with shape (before transpose):
4227 # indices: Tensor[loop_len, num_scatters]
4228 # The item to scatter has shape (before transpose):
4229 # item: Tensor[loop_len, num_scatters, ...]
4230 #
4231 # We want list_as_dense[indices[i, j], i] = item[i, j]
4232 #
4233 # Since we're not just indexing along the first axis of `list_as_dense`, we
4234 # need to first extract the relevant list entries based on `indices`,
4235 # scatter into them according to the loop index, and re-scatter the chunks
4236 # we updated back into the list.
4237 indices = _transpose_first_two_dims(indices)
4238 indices_flat = array_ops.reshape(indices, [-1])
4239 # In many cases `indices` will be unique across pfor iterations, but this is
4240 # not guaranteed. If there are duplicates, we need to map multiple updates
4241 # to a single chunk extracted from the list. The last update should win.
4242 unique_indices = array_ops.unique(indices_flat)
4243 gathered_items = list_ops.tensor_list_gather(
4244 handle, unique_indices.y, element_dtype=item.dtype,
4245 element_shape=array_ops.shape(item)[1:])
4246 loop_idx = math_ops.range(pfor_input.pfor.loop_len_vector[0])
4247 scatters_per_op = array_ops.shape(indices)[0]
4249 unique_indices_loop_idx = array_ops.reshape(array_ops.tile(
4250 loop_idx[None, :], [scatters_per_op, 1]), [-1])
4251 scatter_indices = array_ops_stack.stack(
4252 [unique_indices.idx, unique_indices_loop_idx],
4253 axis=1)
4254 # This op does *not* guarantee last-update-wins on GPU, so semantics may not
4255 # be exactly preserved for duplicate updates there.
4256 scattered = array_ops.tensor_scatter_nd_update(
4257 tensor=gathered_items,
4258 indices=scatter_indices,
4259 updates=_flatten_first_two_dims(item))
4260 handle = list_ops.tensor_list_scatter(
4261 scattered, unique_indices.y, input_handle=handle)
4262 else:
4263 handle = list_ops.tensor_list_scatter(item, indices, input_handle=handle)
4264 return wrap(_tile_variant(handle, pfor_input), True)
4267@RegisterPFor("TensorListFromTensor")
4268def _convert_tensor_list_from_tensor(pfor_input):
4269 tensor = pfor_input.stacked_input(0)
4270 element_shape = pfor_input.unstacked_input(1)
4271 tensor = _transpose_first_two_dims(tensor)
4272 element_shape = _stack_tensor_list_shape(element_shape,
4273 pfor_input.pfor.loop_len_vector)
4274 handle = list_ops.tensor_list_from_tensor(tensor, element_shape)
4275 return wrap(_tile_variant(handle, pfor_input), True)
4278@RegisterPFor("TensorScatterUpdate")
4279def _convert_tensor_scatter_update(pfor_input):
4280 pfor_input.stack_inputs([0, 1, 2])
4281 tensor = pfor_input.stacked_input(0)
4282 indices = pfor_input.stacked_input(1)
4283 updates = pfor_input.stacked_input(2)
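  # A per-iteration index column is prepended to `indices` below so that a
  # single scatter over the stacked tensor updates each pfor iteration's slice
  # independently.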
4285 indices_shape = array_ops.shape(indices)
4286 indices_rank = array_ops.rank(indices)
4287 loop_length = indices_shape[0]
4289 # Create a loop count range and extend its dimensions to match `indices`.
4290 loop_count_shape = array_ops.tensor_scatter_nd_update(
4291 array_ops.ones([indices_rank], dtype=dtypes.int32), [[0]], [loop_length])
4292 loop_count = array_ops.reshape(math_ops.range(loop_length), loop_count_shape)
4294 # Tile the loop count range for the batch dimensions (all except the first and
4295 # last dimensions of indices).
4296 # Rank(indices) >= 3 always for this function, so there is at least one such batch dimension to tile over.
4297 tile_multiplier = array_ops.tensor_scatter_nd_update(
4298 indices_shape, [[0], [indices_rank - 1]], [1, 1])
4299 meta_index = array_ops.tile(loop_count, tile_multiplier)
4301 # Insert the loop-identifying index.
4302 indices = array_ops.concat([meta_index, indices], axis=-1)
4304 result = array_ops.tensor_scatter_nd_update(tensor, indices, updates)
4305 return wrap(result, True)
4307# StackV2 conversion is tricky since we don't have arrays of StackV2. So similar
4308# to TensorArrays, we convert them by changing the dimension of the elements
4309# inside the stack.
4310#
4311# We consider two cases:
4312#
4313# 1. StackV2 is constructed and used entirely inside the pfor loop.
4314# We keep a single Stack and perform the push/pop operations of all the
4315# iterations in lock-step. We also assume that all the iterations perform these
4316# operations. In case of dynamic control flow, if only some of the iterations
4317# try to perform a push/pop, then the conversion may not work correctly and may
4318# cause undefined behavior.
4319# TODO(agarwal): test StackV2 with dynamic control flow.
4320#
4321# 2. StackV2 is constructed outside the pfor loop.
4322# Performing stack push/pop in a parallel fashion is ill-defined. However given
4323# that reading stacks created externally is a common operation when computing
4324# Jacobians, we provide some special semantics here as follows.
4325# - disallow push operations to the stack
4326# - pop operations are performed in lock step by all iterations, similar to the
4327# case when the stack is created inside. A single value is popped during the
4328# lock-step operation and broadcast to all the iterations. Values in the stack
4329# are assumed to be loop-invariant.
4330#
4331# Some other implementation details:
4332# We use a crude heuristic to determine whether values in the Stack data structure
4333# are loop invariant or not. When converting push/pop operations, we keep track of
4334# whether the last conversion used a stacked value or not (see _stack_cache
4335# below). As a result, if an unstacked value is written first, subsequent stacked
4336# writes are disallowed when they could have been allowed in theory.
4338# Map from cache key based on StackV2 handle to a bool indicating whether values
4339# are stacked or not.
4340# TODO(agarwal): move _stack_cache inside pfor?
4341_stack_cache = {}
4344def _stack_cache_key(pfor_input):
4345 """Create cache key corresponding to a stack handle."""
4346 op_type = pfor_input.op_type
4347 assert op_type in ["StackPushV2", "StackPopV2"], op_type
4348 orig_handle = pfor_input.op.inputs[0]
4349 while orig_handle.op.type in ["Identity", "Enter"]:
4350 orig_handle = orig_handle.op.inputs[0]
4351 assert orig_handle.op.type == "StackV2", orig_handle.op
4352 return ops.get_default_graph(), pfor_input.pfor, orig_handle
4355def _stack_handle_inside_pfor(handle, pfor_input):
4356 while handle.op.type in ["Identity", "Enter"]:
4357 handle = handle.op.inputs[0]
4358 assert handle.op.type == "StackV2", ("Unable to find StackV2 op. Got %s" %
4359 handle.op)
4360 return pfor_input.pfor.op_is_inside_loop(handle.op)
4363@RegisterPFor("StackPushV2")
4364def _convert_stack_push_v2(pfor_input):
4365 handle = pfor_input.unstacked_input(0)
4366 elem, elem_stacked, _ = pfor_input.input(1)
4367 swap_memory = pfor_input.get_attr("swap_memory")
4369 if not _stack_handle_inside_pfor(pfor_input.op.inputs[0], pfor_input):
4370 raise ValueError("StackPushV2 not allowed on stacks created outside pfor.")
4371 stack_cache_key = _stack_cache_key(pfor_input)
4372 stacked = _stack_cache.get(stack_cache_key, None)
4373 if stacked is None:
4374 stacked = elem_stacked
4375 _stack_cache[stack_cache_key] = stacked
4376 else:
4377 # If we previously made it unstacked then we can't revert to being stacked.
4378 if not stacked and elem_stacked:
4379 raise ValueError(
4380 "It looks like the stack was previously determined to be loop "
4381 "invariant, but we are now trying to push a loop dependent value "
4382 "to it. This is currently unsupported.")
4383 if stacked and not elem_stacked:
4384 elem = _stack(elem, pfor_input.pfor.loop_len_vector).t
4385 out = data_flow_ops.stack_push_v2(handle, elem, swap_memory=swap_memory)
4386 return wrap(out, stacked)
4389# Note that inputs to this converter will be unstacked. However, it should
4390# still get called since it is a stateful op.
4391@RegisterPFor("StackPopV2")
4392def _convert_stack_pop_v2(pfor_input):
4393 handle = pfor_input.unstacked_input(0)
4394 stack_cache_key = _stack_cache_key(pfor_input)
4395 stacked = _stack_cache.get(stack_cache_key, None)
4396 # If a StackPushV2 has not been converted yet, we default to unstacked since
4397 # the push could be outside of pfor, or the converter may not be called if the
4398 # inputs are unconverted.
4399 if stacked is None:
4400 stacked = False
4401 _stack_cache[stack_cache_key] = False
4402 elem_type = pfor_input.get_attr("elem_type")
4403 out = data_flow_ops.stack_pop_v2(handle, elem_type)
4404 return wrap(out, stacked)
4407# parsing_ops
4410@RegisterPFor("DecodeCSV")
4411def _convert_decode_csv(pfor_input):
4412 lines = pfor_input.stacked_input(0)
4413 record_defaults = [
4414 pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs)
4415 ]
4416 field_delim = pfor_input.get_attr("field_delim")
4417 use_quote_delim = pfor_input.get_attr("use_quote_delim")
4418 select_cols = pfor_input.get_attr("select_cols")
4419 if not select_cols:
4420 select_cols = None
4421 return [
4422 wrap(t, True) for t in parsing_ops.decode_csv(
4423 lines,
4424 record_defaults,
4425 field_delim=field_delim,
4426 use_quote_delim=use_quote_delim,
4427 select_cols=select_cols)
4428 ]
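# Editorial sketch (not part of the original TensorFlow sources): the converter
# above is what allows a per-example tf.io.decode_csv call inside
# tf.vectorized_map. The helper name and data below are illustrative only and
# assume TensorFlow 2.x with eager execution.
def _editorial_decode_csv_sketch():
  import tensorflow as tf

  lines = tf.constant(["1,2.5,a", "3,4.5,b"])
  # The record defaults are captured from outside the loop body, so they reach
  # the converter as unstacked (loop-invariant) inputs, as expected above.
  record_defaults = [tf.constant([0]), tf.constant([0.0]), tf.constant([""])]

  def parse_one(line):
    return tf.io.decode_csv(line, record_defaults)

  ints, floats, strings = tf.vectorized_map(parse_one, lines)
  return ints, floats, strings  # [1, 3], [2.5, 4.5], [b"a", b"b"]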
4431@RegisterPFor("ParseSingleExample")
4432def _convert_parse_single_example(pfor_input):
4433 serialized = pfor_input.stacked_input(0)
4434 dense_defaults = [
4435 pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs)
4436 ]
4437 sparse_keys = pfor_input.get_attr("sparse_keys")
4438 dense_keys = pfor_input.get_attr("dense_keys")
4439 sparse_types = pfor_input.get_attr("sparse_types")
4440 dense_shapes = pfor_input.get_attr("dense_shapes")
4441 output = gen_parsing_ops.parse_example(
4442 serialized=serialized,
4443 names=[],
4444 dense_defaults=dense_defaults,
4445 sparse_keys=sparse_keys,
4446 dense_keys=dense_keys,
4447 sparse_types=sparse_types,
4448 dense_shapes=dense_shapes)
4449 return [wrap(t, True, True) for t in nest.flatten(output)]
4452@RegisterPFor("ParseExampleV2")
4453def _convert_parse_example_v2(pfor_input):
4454 serialized = pfor_input.stacked_input(0)
4455 sparse_keys = pfor_input.unstacked_input(2)
4456 dense_keys = pfor_input.unstacked_input(3)
4457 ragged_keys = pfor_input.unstacked_input(4)
4458 dense_defaults = [
4459 pfor_input.unstacked_input(i) for i in range(5, pfor_input.num_inputs)
4460 ]
4461 num_sparse = pfor_input.get_attr("num_sparse")
4462 sparse_types = pfor_input.get_attr("sparse_types")
4463 ragged_value_types = pfor_input.get_attr("ragged_value_types")
4464 ragged_split_types = pfor_input.get_attr("ragged_split_types")
4465 dense_shapes = pfor_input.get_attr("dense_shapes")
4466 if serialized.shape.ndims not in (None, 1):
4467 raise ValueError("ParseExampleV2 can only be converted if `serialized` "
4468 f"is scalar. Received shape: {serialized.shape}.")
4469 output = gen_parsing_ops.parse_example_v2(
4470 serialized=serialized,
4471 names=[],
4472 sparse_keys=sparse_keys,
4473 dense_keys=dense_keys,
4474 ragged_keys=ragged_keys,
4475 dense_defaults=dense_defaults,
4476 num_sparse=num_sparse,
4477 sparse_types=sparse_types,
4478 ragged_value_types=ragged_value_types,
4479 ragged_split_types=ragged_split_types,
4480 dense_shapes=dense_shapes)
4481 return [wrap(t, True, True) for t in nest.flatten(output)]
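# Editorial sketch (not part of the original TensorFlow sources): the parsing
# converters above are exercised when tf.io.parse_single_example is called on a
# scalar serialized proto inside tf.vectorized_map. The helper name, feature
# spec and data below are illustrative only; assumes TensorFlow 2.x eager mode.
def _editorial_parse_example_sketch():
  import tensorflow as tf

  def make_example(value):
    return tf.train.Example(features=tf.train.Features(feature={
        "x": tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
    })).SerializeToString()

  serialized = tf.constant([make_example(1.0), make_example(2.0)])
  spec = {"x": tf.io.FixedLenFeature([1], tf.float32)}

  # Each pfor iteration parses one scalar serialized proto, matching the
  # "`serialized` is scalar" requirement checked above.
  parsed = tf.vectorized_map(
      lambda s: tf.io.parse_single_example(s, spec), serialized)
  return parsed["x"]  # Shape [2, 1].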
4484# functional_ops
4487def _convert_function_call(func, converter, inputs):
4488 assert isinstance(func.graph, func_graph.FuncGraph), func
4489 assert isinstance(converter, PFor)
4491 graph_outputs = func.graph.outputs[:len(func.function_type.flat_outputs)]
4492 # TODO(agarwal): consider caching this function definition.
4493 @def_function.function
4494 def f(*args):
4495 assert all(isinstance(arg, WrappedTensor) for arg in args), args
4496 assert len(args) == len(func.graph.inputs), (args, func.graph.inputs)
4497 # Map inputs to function arguments.
4498 for inp, arg in zip(func.graph.inputs, args):
4499 converter._add_conversion(inp, arg)
4500 # Convert output tensors.
4501 return tuple([converter._convert_helper(x).t for x in graph_outputs])
4503 call_outputs = f(*inputs)
4504 assert len(call_outputs) == len(graph_outputs)
4505 outputs = []
4506 for call_output, output_tensor in zip(call_outputs, graph_outputs):
4507 func_output = converter._convert_helper(output_tensor)
4508 outputs.append(
4509 wrap(call_output, func_output.is_stacked, func_output.is_sparse_stacked)
4510 )
4511 return outputs
4514@RegisterPFor("StatefulPartitionedCall")
4515@RegisterPFor("PartitionedCall")
4516def _convert_partitioned_call(pfor_input):
4517 func_name = pfor_input.get_attr("f").name
4518 func = pfor_input.op.graph._get_function(compat.as_bytes(func_name))
4519 assert isinstance(func.graph, func_graph.FuncGraph), (
4520 "Could not find FuncGraph object for %s. Got func %s" % (func_name, func))
4521 pfor = pfor_input.pfor
4522 converter = PFor(
4523 loop_var=pfor.loop_var,
4524 loop_len=pfor.loop_len_vector[0],
4525 pfor_ops=func.graph.get_operations(),
4526 fallback_to_while_loop=pfor.fallback_to_while_loop,
4527 all_indices=pfor.all_indices,
4528 all_indices_partitioned=pfor.all_indices_partitioned,
4529 pfor_config=pfor.pfor_config)
4530 return _convert_function_call(func, converter, pfor_input.inputs)
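# Editorial sketch (not part of the original TensorFlow sources): calling a
# tf.function from inside tf.vectorized_map should trace to a
# (Stateful)PartitionedCall op, which the converter above handles by converting
# the callee's FuncGraph with a nested PFor instance. Names are illustrative
# only; assumes TensorFlow 2.x.
def _editorial_partitioned_call_sketch():
  import tensorflow as tf

  @tf.function
  def inner(x):
    return tf.reduce_sum(x * x)

  xs = tf.reshape(tf.range(6, dtype=tf.float32), [3, 2])
  return tf.vectorized_map(inner, xs)  # -> [1., 13., 41.]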
4533def _partition_inputs_for_indices(inputs, indices):
4534 new_inputs = []
4535 for inp in inputs:
4536 if inp.is_stacked:
4537 new_inputs.append(wrap(array_ops.gather(inp.t, indices), True))
4538 else:
4539 new_inputs.append(inp)
4540 return new_inputs
4543def _outputs_for_branch(func_name, indices, pfor_input, inputs):
4544 if indices is None:
4545 indices = pfor_input.pfor.all_indices
4546 partitioned = pfor_input.pfor.all_indices_partitioned
4547 else:
4548 partitioned = True
4549 func = pfor_input.op.graph._get_function(func_name)
4550 converter = PFor(
4551 loop_var=pfor_input.pfor.loop_var,
4552 loop_len=array_ops.size(indices),
4553 pfor_ops=func.graph.get_operations(),
4554 fallback_to_while_loop=pfor_input.pfor.fallback_to_while_loop,
4555 all_indices=indices,
4556 all_indices_partitioned=partitioned,
4557 pfor_config=pfor_input.pfor.pfor_config)
4558 outputs = _convert_function_call(func, converter, inputs)
4559 stacked_outputs = []
4560 for out in outputs:
4561 if not out.is_stacked:
4562 stacked_outputs.append(_stack(out.t, [array_ops.size(indices)]).t)
4563 else:
4564 stacked_outputs.append(out.t)
4565 return stacked_outputs
4568# TODO(agarwal): Currently the converted code aggressively tiles loop variant
4569# outputs from the then/else branches. Instead, it could do so only if at least
4570# one of the branch outputs is loop variant.
4571@RegisterPFor("StatelessIf")
4572@RegisterPFor("If")
4573def _convert_if(pfor_input):
4574 cond, cond_stacked, _ = pfor_input.input(0)
4575 inputs = pfor_input.inputs[1:]
4576 then_branch = pfor_input.get_attr("then_branch")
4577 else_branch = pfor_input.get_attr("else_branch")
4579 if cond_stacked:
4580 cond_int = math_ops.cast(cond, dtypes.int32)
4581 # Compute loop indices for the different branches
4582 false_indices, true_indices = data_flow_ops.dynamic_partition(
4583 pfor_input.pfor.all_indices, cond_int, 2)
4584 # Compute indices for cond being True or False.
4585 if pfor_input.pfor.all_indices_partitioned:
4586 else_indices, then_indices = data_flow_ops.dynamic_partition(
4587 math_ops.range(pfor_input.pfor.loop_len_vector[0]),
4588 cond_int, 2)
4589 else:
4590 else_indices, then_indices = false_indices, true_indices
4591 # Partition inputs
4592 then_inputs = _partition_inputs_for_indices(inputs, then_indices)
4593 else_inputs = _partition_inputs_for_indices(inputs, else_indices)
4595 # Convert "then" branch.
4596 then_outputs = _outputs_for_branch(then_branch.name, true_indices,
4597 pfor_input, then_inputs)
4599 # Convert "else" branch.
4600 else_outputs = _outputs_for_branch(else_branch.name, false_indices,
4601 pfor_input, else_inputs)
4603 assert len(then_outputs) == len(else_outputs)
4604 # Note that if the "then" and "else" branches are updating the same state,
4605 # and possibly reading them as well, it could lead to undefined behavior
4606 # since the ordering of those operations is not well defined.
4607 # One possibility is to order all the "then" branches to execute before all
4608 # the "else" branches so that the side-effects in the former are visible to
4609 # the latter. For now, we leave that as undefined behavior.
4610 outputs = []
4611 # Merge outputs
4612 for then_output, else_output in zip(then_outputs, else_outputs):
4613 out = data_flow_ops.dynamic_stitch([then_indices, else_indices],
4614 [then_output, else_output])
4615 outputs.append(wrap(out, True))
4616 return outputs
4617 else:
4618 outputs = tf_cond.cond(
4619 cond,
4620 lambda: _outputs_for_branch(then_branch.name, None, pfor_input, inputs),
4621 lambda: _outputs_for_branch(else_branch.name, None, pfor_input, inputs))
4622 return [wrap(t, True) for t in outputs]
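# Editorial sketch (not part of the original TensorFlow sources): a tf.cond
# whose predicate depends on the loop variable lowers to an If/StatelessIf op
# with a stacked condition, which exercises the partition/stitch path above.
# Names are illustrative only; assumes TensorFlow 2.x.
def _editorial_cond_sketch():
  import tensorflow as tf

  def branchy(x):
    # Loop indices are partitioned between the two branches and the branch
    # outputs are stitched back together in loop order.
    return tf.cond(x > 0.0, lambda: x * 2.0, lambda: -x)

  xs = tf.constant([-1.0, 2.0, -3.0, 4.0])
  return tf.vectorized_map(branchy, xs)  # -> [1., 4., 3., 8.]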
4625@RegisterPFor("Case")
4626@RegisterPFor("StatelessCase")
4627def _convert_stateless_case(pfor_input):
4628 branch_idx, is_stacked, _ = pfor_input.input(0)
4629 branches = pfor_input.get_attr("branches")
4630 inputs = pfor_input.inputs[1:]
4632 if is_stacked:
4633 logging.info("Running stacked flow")
4635 # Compute loop indices for the different branches
4636 switch_indices = data_flow_ops.dynamic_partition(
4637 pfor_input.pfor.all_indices, branch_idx, len(branches))
4638 if pfor_input.pfor.all_indices_partitioned:
4639 partitioned_indices = data_flow_ops.dynamic_partition(
4640 math_ops.range(pfor_input.pfor.loop_len_vector[0]), branch_idx,
4641 len(branches))
4642 else:
4643 partitioned_indices = switch_indices
4644 # Partition inputs
4645 input_list = []
4646 for indices in partitioned_indices:
4647 input_list.append(_partition_inputs_for_indices(inputs, indices))
4649 outputs = []
4650 for (b, indices, inputs) in zip(branches, switch_indices, input_list):
4651 out = _outputs_for_branch(b.name, indices, pfor_input, inputs)
4652 outputs.extend(out)
4654 out = data_flow_ops.dynamic_stitch(partitioned_indices, outputs)
4655 return [wrap(out, True)]
4656 else:
4657 new_branches = []
4658 for b in branches:
4659 def new_function(func=b.name):
4660 return _outputs_for_branch(func, None, pfor_input,
4661 pfor_input.inputs[1:])
4663 new_branches.append(new_function)
4665 outputs = []
4666 outputs = control_flow_switch_case.switch_case(branch_idx, new_branches)
4667 return [wrap(t, True) for t in outputs]
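# Editorial sketch (not part of the original TensorFlow sources): a
# loop-dependent branch index passed to tf.switch_case lowers to a stacked Case
# op; each branch is converted on its own partition of loop indices and the
# results are stitched back together. Names are illustrative only; assumes
# TensorFlow 2.x.
def _editorial_case_sketch():
  import tensorflow as tf

  def pick(i):
    return tf.switch_case(i, branch_fns=[lambda: tf.constant(10.0),
                                         lambda: tf.constant(20.0),
                                         lambda: tf.constant(30.0)])

  idx = tf.constant([2, 0, 1, 0])
  return tf.vectorized_map(pick, idx)  # -> [30., 10., 20., 10.]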
4670class WhileV2:
4671 """Object for vectorizing V2 while_loop op."""
4673 def __init__(self, pfor_input):
4674 self._pfor_input = pfor_input
4675 self._pfor = pfor_input.pfor
4676 cond_func_name = pfor_input.get_attr("cond").name
4677 self._cond_func = pfor_input.op.graph._get_function(compat.as_bytes(
4678 cond_func_name))
4679 body_func_name = pfor_input.get_attr("body").name
4680 self._body_func = pfor_input.op.graph._get_function(compat.as_bytes(
4681 body_func_name))
4682 if self._cond_func is None or self._body_func is None:
4683 raise ValueError("Error extracting cond and body functions for op "
4684 f"{self._pfor_input.op}.")
4685 # Indices of inputs that are passed unchanged through the while loop body.
4686 # Typically these are tensors captured from outside the body context.
4687 self._body_pass_through_indices = set()
4688 for i, (inp, out) in enumerate(zip(self._body_func.graph.inputs,
4689 self._body_func.graph.outputs)):
4690 if id(inp) == id(out):
4691 self._body_pass_through_indices.add(i)
4692 self._parallel_iterations = self._pfor_input.get_attr("parallel_iterations")
4694 def _output_shapes(self):
4695 # Calculate output shape for vectorized loop. This will be used as
4696 # shape_invariant. Merges shape inference outputs with the `output_shapes`
4697 # attribute of the op.
4698 output_shapes = [out.shape for out in self._pfor_input.op.outputs]
4699 shapes = self._pfor_input.get_attr("output_shapes")
4700 if not shapes:
4701 shapes = [tensor_shape.TensorShape(None) for _ in output_shapes]
4702 else:
4703 shapes = [tensor_shape.TensorShape(shape) for shape in shapes]
4704 for i, shape in enumerate(shapes):
4705 shape = shape.merge_with(output_shapes[i])
4706 pfor_input = self._pfor_input.input(i)
4707 if pfor_input.is_stacked:
4708 if _is_variant_with_internal_stacking(pfor_input.t):
4709 shape = tensor_shape.TensorShape([]).concatenate(shape)
4710 else:
4711 shape = tensor_shape.TensorShape([None]).concatenate(shape)
4712 output_shapes[i] = shape
4713 assert len(output_shapes) == self._pfor_input.num_inputs
4714 return output_shapes
4716 def _init_values(self):
4717 """Create arguments passed to converted while_loop."""
4718 loop_len = self._pfor.loop_len_vector[0]
4719 inputs = []
4720 # TensorArrays for outputs of converted while loop
4721 output_tas = []
4723 with ops.name_scope("while_init"):
4724 for inp in self._pfor_input.inputs:
4725 inputs.append(inp.t)
4726 variant_type_id = _variant_type_id(inp.t)
4727 if variant_type_id in _INTERNAL_STACKING_TYPE_IDS:
4728 if variant_type_id != full_type_pb2.TFT_ARRAY:
4729 raise NotImplementedError(
4730 "While loop conversion is only supported for TensorLists. Got "
4731 f"another variant {inp.t}, probably an optional. Please file "
4732 "a bug.")
4734 # For TensorLists, the input format is:
4735 #
4736 # List[user_list_len, Tensor[loop_len, ...]]
4737 #
4738 # rather than the usual
4739 #
4740 # Tensor[loop_len, ...]
4741 #
4742 # The body of the loop will take and return lists in this "internal
4743 # vectorization" format, so we want to keep it that way as much as
4744 # possible. We'll accumulate finished iterations (only relevant for
4745 # pfor-loop-variant while_loop conditions) in an accumulator with
4746 # type:
4747 #
4748 # List[user_list_len, List[loop_len, Tensor[...]]]
4749 #
4750 # This means that each while_loop iteration, we'll iterate over the
4751 # length of the TensorList, dividing done/remaining pfor loop indices
4752 # and scattering the done indices into the inner nested list of the
4753 # accumulator.
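# (Editorial illustration, not in the original: concretely, with loop_len = 4
# and a per-iteration TensorList holding user_list_len = 2 tensors of shape
# [3], the loop input has type List[2, Tensor[4, 3]] and the accumulator has
# type List[2, List[4, Tensor[3]]].)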
4754 element_shape = list_ops.tensor_list_element_shape(
4755 inp.t, dtypes.int32)
4756 if inp.is_stacked:
4757 # Shapes may be tf.constant(-1) for fully dynamic, in which case
4758 # slicing is an error.
4759 element_shape = tf_cond.cond(
4760 math_ops.equal(array_ops.rank(element_shape), 0),
4761 lambda: element_shape,
4762 lambda: element_shape[1:])
4763 dtype = _parse_variant_shapes_and_types(inp.t)[0].dtype
4765 def _init_loop_body(index, output_ta):
4766 output_ta = output_ta.write(
4767 index,
4768 list_ops.tensor_list_reserve(element_shape, loop_len, dtype))
4769 return index + 1, output_ta
4771 length = list_ops.tensor_list_length(inp.t)
4772 output_ta = tensor_array_ops.TensorArray(
4773 inp.t.dtype, # Variant; this is a nested TensorList
4774 size=length,
4775 dynamic_size=True,
4776 infer_shape=False)
4777 _, output_ta = while_loop.while_loop(lambda index, _: index < length,
4778 _init_loop_body, [0, output_ta])
4779 else:
4780 output_ta = tensor_array_ops.TensorArray(
4781 inp.t.dtype,
4782 size=loop_len,
4783 dynamic_size=False,
4784 infer_shape=True)
4785 output_tas.append(output_ta)
4786 # See documentation for __call__ for the structure of init_values.
4787 indices = (
4788 math_ops.range(self._pfor.loop_len_vector[0])
4789 if self._pfor.all_indices_partitioned else self._pfor.all_indices)
4790 return [True, indices] + inputs + output_tas
4792 def _process_cond_unstacked(self, conditions, indices, inputs, output_tas):
4793 """Handles case when condition is pfor loop invariant."""
4794 # Note that all iterations end together. So we don't need to partition the
4795 # inputs.
4796 not_all_done = array_ops.reshape(conditions, [])
4797 return not_all_done, indices, inputs, output_tas
4799 def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked,
4800 output_tas):
4801 """Handles case when condition is pfor loop dependent."""
4802 # Compute if all iterations are done.
4803 not_all_done = math_ops.reduce_any(conditions)
4804 conditions_int = math_ops.cast(conditions, dtypes.int32)
4805 # Partition the indices.
4806 done_indices, new_indices = data_flow_ops.dynamic_partition(
4807 indices, conditions_int, 2)
4809 new_inputs = []
4810 new_output_tas = []
4811 for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)):
4812 pass_through = i in self._body_pass_through_indices
4813 if not pass_through and _variant_type_id(inp) == full_type_pb2.TFT_ARRAY:
4814 shape_and_type = _parse_variant_shapes_and_types(inp)[0]
4815 element_shape = list_ops.tensor_list_element_shape(inp, dtypes.int32)
4816 user_list_len = list_ops.tensor_list_length(inp)
4818 def _split_vectorized_ta_element(index, new_inp, new_out_ta):
4819 elem = list_ops.tensor_list_get_item(inp, index, shape_and_type.dtype,
4820 element_shape)
4821 if stacked:
4822 done_elem, new_elem = data_flow_ops.dynamic_partition(
4823 elem, conditions_int, 2)
4824 new_inp = list_ops.tensor_list_set_item(new_inp, index, new_elem)
4825 else:
4826 done_elem = _stack(elem, [array_ops.size(done_indices)]).t
4827 done_accum = new_out_ta.read(index)
4828 done_accum = list_ops.tensor_list_scatter(
4829 tensor=done_elem, indices=done_indices, input_handle=done_accum)
4830 new_out_ta = new_out_ta.write(index, done_accum)
4831 return index + 1, new_inp, new_out_ta
4833 length = list_ops.tensor_list_length(inp)
4834 new_inp = list_ops.tensor_list_reserve(
4835 tensor_shape.TensorShape([None])
4836 + tensor_shape.TensorShape(shape_and_type.shape)[1:],
4837 user_list_len, shape_and_type.dtype)
4838 _, new_inp, out_ta = while_loop.while_loop(
4839 lambda index, unused_new_inp, unused_new_out_ta: index < length,
4840 _split_vectorized_ta_element, [0, new_inp, output_tas[i]])
4841 else:
4842 # Partition the inputs.
4843 if stacked:
4844 done_inp, new_inp = data_flow_ops.dynamic_partition(
4845 inp, conditions_int, 2)
4846 else:
4847 if not pass_through:
4848 done_inp = _stack(inp, [array_ops.size(done_indices)]).t
4849 new_inp = inp
4851 out_ta = output_tas[i]
4852 if not pass_through:
4853 # Note that done_indices can be empty. done_inp should also be empty
4854 # in that case.
4855 out_ta = out_ta.scatter(done_indices, done_inp)
4856 new_inputs.append(new_inp)
4857 new_output_tas.append(out_ta)
4859 assert len(new_output_tas) == len(output_tas)
4860 assert len(new_inputs) == len(inputs)
4861 return not_all_done, new_indices, new_inputs, new_output_tas
4863 def _process_body(self, inputs_stacked, new_indices, cond_stacked,
4864 new_inputs, not_all_done):
4865 """Convert the body function."""
4866 # This is used to store the indices of inputs to the while op that need to
4867 # be stacked. This stacking may be needed in cases where the input to the
4868 # while_loop is loop_invariant but the corresponding output is not.
4869 mismatching_stacked_indices = []
4871 def true_fn():
4872 """Converts the body function for all but last iteration."""
4873 wrapped_inputs = [wrap(inp, stacked) for inp, stacked in
4874 zip(new_inputs, inputs_stacked)]
4875 # Note the iterative process below used to figure out loop invariance.
4876 # Here we iterate on the vectorization process till a fixed point is reached.
4877 # The issue is that the while body can take pfor loop invariant inputs but
4878 # return loop variant outputs. For any loop variant output, the corresponding
4879 # input then has to be made loop variant (since subsequent while iterations
4880 # will need to see loop variant values).
4881 # However, once we make a new input loop variant, we might make other outputs
4882 # loop variant as well. Hence we need to iterate till we reach a fixed point.
4883 while True:
4884 if self._pfor.all_indices_partitioned:
4885 indices = array_ops.gather(self._pfor.all_indices, new_indices)
4886 else:
4887 indices = new_indices
4888 body_pfor = PFor(
4889 loop_var=self._pfor.loop_var,
4890 loop_len=array_ops.size(new_indices),
4891 pfor_ops=self._body_func.graph.get_operations(),
4892 fallback_to_while_loop=self._pfor.fallback_to_while_loop,
4893 all_indices=indices,
4894 all_indices_partitioned=(self._pfor.all_indices_partitioned or
4895 cond_stacked),
4896 pfor_config=self._pfor.pfor_config)
4897 stacking_mismatch = False
4898 outputs = _convert_function_call(self._body_func,
4899 body_pfor,
4900 wrapped_inputs)
4901 for i, (out, inp) in enumerate(zip(outputs, wrapped_inputs)):
4902 if out.is_stacked != inp.is_stacked:
4903 stacking_mismatch = True
4904 mismatching_stacked_indices.append(i)
4905 stacked = _stack(inp.t, [array_ops.size(new_indices)])
4906 if inp.t.dtype == dtypes.variant:
4907 stacked = wrap(
4908 _tile_variant_with_length(stacked.t,
4909 [array_ops.size(new_indices)]))
4910 wrapped_inputs[i] = stacked
4911 if not stacking_mismatch:
4912 if mismatching_stacked_indices:
4913 # We needed to stack some inputs. This code will be abandoned and
4914 # should not get executed. Hence we simply return `new_inputs` to
4915 # make sure the graph construction code completes.
4916 with ops.control_dependencies([
4917 control_flow_assert.Assert(
4918 False, ["pfor ERROR: this branch should never execute"])
4919 ]):
4920 return [array_ops.identity(x) for x in new_inputs]
4921 else:
4922 return [out.t for out in outputs]
4924 # If all are done, we simply return `new_inputs`. Else we need to run the
4925 # body function.
4926 return tf_cond.cond(
4927 not_all_done,
4928 true_fn,
4929 lambda: list(new_inputs)), mismatching_stacked_indices
4931 def __call__(self):
4932 """Converter for the V2 while_loop.
4934 The conversion of a while_loop is another while_loop.
4936 The arguments to this converted while_loop are as follows:
4937 not_all_done: Boolean scalar Tensor indicating if all the pfor iterations
4938 are done.
4939 indices: int32 1-D Tensor storing the ids of the pfor iterations that are not
4940 done.
4941 args: Remaining arguments. These can be divided into 2 categories:
4942 - The first set of arguments corresponds one-to-one to the inputs to the
4943 unvectorized while_loop.
4944 - The second set are TensorArrays, corresponding one-to-one to each output
4945 of the unvectorized while_loop. Each TensorArray has `PFor.loop_len`
4946 elements, i.e. the number of pfor iterations. At the end, the i'th
4947 element of each TensorArray will contain the output computed by the i'th
4948 iteration of pfor. Note that elements can be written into these
4949 TensorArrays in any order, depending on when the corresponding pfor iteration
4950 is done.
4951 In each iteration, the while_loop body recomputes the condition for all
4952 active pfor iterations to see which of them are now done. It then partitions
4953 all the inputs and passes them along to the converted body. Values for all
4954 the iterations that are done are written to TensorArrays indexed by the pfor
4955 iteration number. When all iterations are done, the TensorArrays are stacked
4956 to get the final value.
4958 Returns:
4959 List of converted outputs.
4960 """
4961 output_shapes = self._output_shapes()
4962 # Note that we use these lists as a hack since we need the `body` to compute
4963 # these values during construction of the while_loop graph.
4964 cond_is_stacked = [None]
4965 indices_to_stack = []
4967 def cond(not_all_done, *_):
4968 return not_all_done
4970 def body(not_all_done, indices, *args):
4971 # See documentation for __call__ for the structure of *args.
4972 num_inputs = self._pfor_input.num_inputs
4973 inputs = args[:num_inputs]
4974 output_tas = args[num_inputs:]
4975 inputs_stacked = [x.is_stacked for x in self._pfor_input.inputs]
4976 assert len(inputs) >= len(output_tas)
4977 assert len(inputs) == len(inputs_stacked)
4978 # Convert condition
4979 with ops.name_scope("while_cond"):
4980 # Note that we set all_indices_partitioned to True here. At this point
4981 # we don't know if indices will be partitioned. Hence we use the
4982 # conservative value.
4983 cond_pfor = PFor(
4984 loop_var=self._pfor.loop_var,
4985 loop_len=array_ops.size(indices),
4986 pfor_ops=self._cond_func.graph.get_operations(),
4987 fallback_to_while_loop=self._pfor.fallback_to_while_loop,
4988 all_indices=indices,
4989 all_indices_partitioned=True,
4990 pfor_config=self._pfor.pfor_config)
4992 wrapped_inputs = [wrap(inp, stacked) for inp, stacked
4993 in zip(inputs, inputs_stacked)]
4994 conditions, cond_stacked, _ = _convert_function_call(
4995 self._cond_func,
4996 cond_pfor,
4997 wrapped_inputs)[0]
4998 cond_is_stacked[0] = cond_stacked
5000 # Recompute the new condition, write outputs of done iterations, and
5001 # partition the inputs if needed.
5002 if not cond_stacked:
5003 (not_all_done, new_indices, new_inputs,
5004 new_output_tas) = self._process_cond_unstacked(conditions, indices,
5005 inputs, output_tas)
5006 else:
5007 (not_all_done, new_indices, new_inputs,
5008 new_output_tas) = self._process_cond_stacked(conditions, indices,
5009 inputs, inputs_stacked,
5010 output_tas)
5011 # Convert body
5012 with ops.name_scope("while_body"):
5013 # Compute the outputs from the body.
5014 new_outputs, mismatching_stacked_indices = self._process_body(
5015 inputs_stacked, new_indices, cond_stacked, new_inputs, not_all_done)
5017 indices_to_stack[:] = mismatching_stacked_indices
5018 for i, new_output in enumerate(new_outputs):
5019 new_output.set_shape(output_shapes[i])
5020 new_args = ([not_all_done, new_indices] + new_outputs +
5021 list(new_output_tas))
5022 return tuple(new_args)
5024 # Note that we run the code below in a function since we might abandon the
5025 # generated code in cases where the conversion dictates that some inputs be
5026 # further stacked. Hence we run the graph construction using
5027 # `get_concrete_function` and avoid calling the constructed function if not
5028 # needed.
5029 @def_function.function
5030 def while_fn():
5031 # Create init_values that will be passed to the while_loop.
5032 init_values = self._init_values()
5033 ta_shape_invariants = [tensor_shape.TensorShape([]) for _ in
5034 self._pfor_input.outputs]
5035 shape_invariants = (
5036 [tensor_shape.TensorShape([]), tensor_shape.TensorShape([None])]
5037 + output_shapes + ta_shape_invariants)
5039 while_outputs = while_loop.while_loop(
5040 cond,
5041 body,
5042 init_values,
5043 shape_invariants=shape_invariants,
5044 parallel_iterations=self._parallel_iterations)
5045 if indices_to_stack:
5046 # This function will be abandoned.
5047 return while_outputs
5048 else:
5049 num_inputs = self._pfor_input.num_inputs
5050 new_inputs = while_outputs[2:num_inputs+2]
5051 output_tas = while_outputs[num_inputs+2:]
5052 assert cond_is_stacked[0] is not None
5053 outputs = []
5054 for i, inp in enumerate(new_inputs):
5055 if cond_is_stacked[0]:
5056 if i in self._body_pass_through_indices:
5057 outputs.append(init_values[i + 2])
5058 else:
5059 ta = output_tas[i]
5060 if _variant_type_id(inp) == full_type_pb2.TFT_ARRAY:
5061 shape_and_type = _parse_variant_shapes_and_types(inp)[0]
5062 length = list_ops.tensor_list_length(inp)
5064 # We have been accumulating values in a:
5065 #
5066 # List[user_list_len, List[loop_len, Tensor[...]]]
5067 #
5068 # We want to return an output in the same format as the input:
5069 #
5070 # List[user_list_len, Tensor[loop_len, ...]]
5071 #
5072 # So we need to loop over the list and stack its contents.
5073 def _stack_loop_body(index, output_list):
5074 current_value = ta.read(index)
5075 output_list = list_ops.tensor_list_set_item(
5076 output_list, index,
5077 list_ops.tensor_list_stack(
5078 current_value, shape_and_type.dtype))
5079 return index + 1, output_list
5081 output_list = list_ops.tensor_list_reserve(
5082 tensor_shape.TensorShape(shape_and_type.shape), length,
5083 shape_and_type.dtype)
5084 _, output_list = while_loop.while_loop(
5085 lambda index, _: index < length, _stack_loop_body,
5086 [0, output_list])
5087 outputs.append(output_list)
5088 else:
5089 outputs.append(ta.stack())
5090 else:
5091 outputs.append(inp)
5092 return outputs
5094 _ = while_fn.get_concrete_function()
5095 if indices_to_stack:
5096 # Need to abandon the current conversion, stack some inputs and restart.
5097 self._pfor_input.stack_inputs(
5098 stack_indices=indices_to_stack, tile_variants=True)
5099 # Note that this call will recurse at most one time. The first call will
5100 # do the required stacking, based on the iterative procedure in
5101 # _process_body, and the next invocation to __call__ should not need to do
5102 # any more stacking.
5103 # We invoke `self()` here as a way to discard any corrupted state.
5104 return self()
5105 else:
5106 outputs = while_fn()
5107 wrapped_outputs = []
5108 for i, (out, inp) in enumerate(zip(outputs, self._pfor_input.inputs)):
5109 if i not in self._body_pass_through_indices and cond_is_stacked[0]:
5110 wrapped_outputs.append(wrap(out, True))
5111 else:
5112 wrapped_outputs.append(wrap(out, inp.is_stacked))
5113 return wrapped_outputs
5116@RegisterPFor("StatelessWhile")
5117@RegisterPFor("While")
5118def _convert_while(pfor_input):
5119 converter = WhileV2(pfor_input)
5120 return converter()
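# Editorial sketch (not part of the original TensorFlow sources): a while_loop
# whose trip count depends on the loop variable gives a stacked condition, so
# iterations finish at different times and their results are accumulated in the
# per-output TensorArrays described in WhileV2 above. Names are illustrative
# only; assumes TensorFlow 2.x.
def _editorial_while_sketch():
  import tensorflow as tf

  def sum_below(n):
    return tf.while_loop(lambda i, acc: i < n,
                         lambda i, acc: (i + 1, acc + i),
                         (tf.constant(0), tf.constant(0)))[1]

  ns = tf.constant([1, 3, 5])
  return tf.vectorized_map(sum_below, ns)  # -> [0, 3, 10]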
5123# spectral_ops
5126@RegisterPForWithArgs("FFT", gen_spectral_ops.fft)
5127@RegisterPForWithArgs("FFT2D", gen_spectral_ops.fft2d)
5128@RegisterPForWithArgs("FFT3D", gen_spectral_ops.fft3d)
5129@RegisterPForWithArgs("IFFT", gen_spectral_ops.ifft)
5130@RegisterPForWithArgs("IFFT2D", gen_spectral_ops.ifft2d)
5131@RegisterPForWithArgs("IFFT3D", gen_spectral_ops.ifft3d)
5132def _convert_fft(pfor_input, _, op_func):
5133 return wrap(op_func(pfor_input.stacked_input(0)), True)
5136@RegisterPForWithArgs("RFFT", gen_spectral_ops.rfft, "Tcomplex")
5137@RegisterPForWithArgs("RFFT2D", gen_spectral_ops.rfft2d, "Tcomplex")
5138@RegisterPForWithArgs("RFFT3D", gen_spectral_ops.rfft3d, "Tcomplex")
5139@RegisterPForWithArgs("IRFFT", gen_spectral_ops.irfft, "Treal")
5140@RegisterPForWithArgs("IRFFT2D", gen_spectral_ops.irfft2d, "Treal")
5141@RegisterPForWithArgs("IRFFT3D", gen_spectral_ops.irfft3d, "Treal")
5142def _convert_rfft(pfor_input, _, op_func, attr_name):
5143 inp = pfor_input.stacked_input(0)
5144 fft_length = pfor_input.unstacked_input(1)
5145 attr = pfor_input.get_attr(attr_name)
5146 return wrap(op_func(inp, fft_length, attr), True)
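# Editorial sketch (not part of the original TensorFlow sources): the FFT family
# already operates on the innermost axis, which is why the converters above can
# simply apply the op to the stacked input. Names are illustrative only; assumes
# TensorFlow 2.x with eager execution.
def _editorial_fft_sketch():
  import tensorflow as tf

  xs = tf.complex(tf.random.normal([4, 8]), tf.random.normal([4, 8]))
  per_row = tf.vectorized_map(tf.signal.fft, xs)
  batched = tf.signal.fft(xs)
  # Both paths should agree up to numerical noise.
  tf.debugging.assert_less(tf.reduce_max(tf.abs(per_row - batched)), 1e-4)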