Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/tensor_array_ops.py: 29%
554 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TensorArray: a dynamically sized array of Tensors."""
16# Mixture of pep8 and non-pep8 names, so disable pylint bad-name
17# pylint: disable=g-bad-name
18import contextlib
20import traceback
21import weakref
23import numpy as np
25from tensorflow.core.protobuf import struct_pb2
26from tensorflow.python.eager import context
27from tensorflow.python.framework import constant_op
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import errors_impl
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import tensor_shape
32from tensorflow.python.framework import tensor_spec
33from tensorflow.python.framework import tensor_util
34from tensorflow.python.framework import type_spec
35from tensorflow.python.framework import type_spec_registry
36from tensorflow.python.ops import array_ops
37from tensorflow.python.ops import array_ops_stack
38from tensorflow.python.ops import control_flow_util
39from tensorflow.python.ops import gen_control_flow_ops
40from tensorflow.python.ops import gen_data_flow_ops
41from tensorflow.python.ops import list_ops
42from tensorflow.python.ops import math_ops
43from tensorflow.python.platform import tf_logging as logging
44from tensorflow.python.saved_model import nested_structure_coder
45from tensorflow.python.util import tf_should_use
46from tensorflow.python.util.tf_export import tf_export
49# _GraphTensorArray accesses many of the hidden generated ops, but is in
50# fact built to wrap these methods.
51# pylint: disable=protected-access
class _GraphTensorArray:
  """Graph-mode implementation of TensorArray.

  Wraps the `TensorArrayV3` family of ops. A resource `handle` identifies the
  underlying array, and a scalar float `flow` tensor threads data dependencies
  between successive operations. Every mutating method (`write`, `scatter`,
  `split`, and `unstack` via `scatter`) returns a new array object built by
  `build_ta_with_new_flow` around the op's output flow.
  """

  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Constructs a graph mode TensorArray.

    Args:
      dtype: (required) data type of the TensorArray.
      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
        Required if handle is not provided.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size. Default: False.
      clear_after_read: Boolean (optional, default: True). If True, clear
        TensorArray values after reading them. This disables read-many
        semantics, but allows early release of memory.
      tensor_array_name: (optional) Python string: the name of the TensorArray.
        This is used when creating the TensorArray handle. If this value is
        set, handle should be None.
      handle: (optional) A `Tensor` handle to an existing TensorArray. If this
        is set, tensor_array_name should be None. Only supported in graph mode.
      flow: (optional) A float `Tensor` scalar coming from an existing
        `TensorArray.flow`. Only supported in graph mode.
      infer_shape: (optional, default: True) If True, shape inference is
        enabled. In this case, all elements must have the same shape.
      element_shape: (optional, default: None) A `TensorShape` object specifying
        the shape constraints of each of the elements of the TensorArray. Need
        not be fully defined.
      colocate_with_first_write_call: If `True`, the TensorArray will be
        colocated on the same device as the Tensor used on its first write
        (write operations include `write`, `unstack`, and `split`). If `False`,
        the TensorArray will be placed on the device determined by the device
        context available during its initialization.
      name: A name for the operation (optional).

    Raises:
      ValueError: if both handle and tensor_array_name are provided; if neither
        handle nor size is provided; if handle is provided together with any
        of size, element_shape, dynamic_size or clear_after_read; or if handle
        is provided without flow.
      TypeError: if handle is provided but is not a Tensor.
    """
    # Argument validation. An existing `handle` fully determines the array,
    # so every creation-time option is rejected alongside it.
    if handle is not None and tensor_array_name:
      raise ValueError(
          "Cannot provide both `handle` and `tensor_array_name` arguments at "
          "the same time.")
    if handle is not None and not isinstance(handle, ops.Tensor):
      raise TypeError(
          f"Expected `handle` to be a Tensor, but got `{handle}` of type "
          f"`{type(handle)}` instead.")
    if handle is None and size is None:
      raise ValueError(
          "Argument `size` must be provided if handle is not provided.")
    if handle is not None and size is not None:
      raise ValueError("Cannot provide both a `handle` and `size` arguments "
                       "at the same time.")
    if handle is not None and element_shape is not None:
      raise ValueError(
          "Cannot provide both `handle` and `element_shape` arguments "
          "at the same time.")
    if handle is not None and dynamic_size is not None:
      raise ValueError(
          "Cannot provide both `handle` and `dynamic_size` arguments "
          "at the same time.")
    if handle is not None and clear_after_read is not None:
      raise ValueError(
          "Cannot provide both `handle` and `clear_after_read` arguments "
          "at the same time.")

    if clear_after_read is None:
      clear_after_read = True
    self._dynamic_size = dynamic_size or False
    self._dtype = dtypes.as_dtype(dtype).base_dtype

    # Used to keep track of what tensors the TensorArray should be
    # colocated with. We choose to colocate the TensorArray with the
    # first tensor written to it.
    self._colocate_with_first_write_call = colocate_with_first_write_call
    if colocate_with_first_write_call:
      self._colocate_with = []
    else:
      self._colocate_with = None

    # Record the current static shape for the array elements. The element
    # shape is defined either by `element_shape` or the shape of the tensor
    # of the first write. If `infer_shape` is true, all writes checks for
    # shape equality.
    # Stored as a one-element list so the same mutable cell can be shared
    # with a gradient TensorArray (see `grad`); refinements made through
    # `_check_element_shape` are then visible to both arrays.
    self._element_shape = [tensor_shape.as_shape(element_shape)]
    self._infer_shape = infer_shape
    self._size = size
    with ops.name_scope(name, "TensorArray", [handle, size, flow]) as scope:
      if handle is not None:
        # Wrap an existing array; its flow must be supplied by the caller.
        self._handle = handle
        if flow is None:
          raise ValueError("flow must not be None if handle is not None.")
        self._flow = flow
      else:
        # Construct the TensorArray with an empty device. The first
        # write into the TensorArray from a Tensor with a set device
        # will retroactively set the device value of this op.
        def create():
          """Create the TensorArray op."""
          return gen_data_flow_ops.tensor_array_v3(
              dtype=dtype,
              size=size,
              element_shape=element_shape,
              identical_element_shapes=infer_shape,
              dynamic_size=self._dynamic_size,
              clear_after_read=clear_after_read,
              tensor_array_name=tensor_array_name,
              name=scope)

        if colocate_with_first_write_call:
          # Clear any enclosing device/colocation so placement is decided by
          # the first write (see `_maybe_colocate_with`).
          with ops.device(None), ops.colocate_with(None, ignore_existing=True):
            self._handle, self._flow = create()
        else:
          self._handle, self._flow = create()

  @property
  def flow(self):
    # Scalar tensor that sequences operations on this array.
    return self._flow

  @property
  def dtype(self):
    return self._dtype

  @property
  def handle(self):
    # Resource handle of the underlying TensorArrayV3.
    return self._handle

  @property
  def element_shape(self):
    return self._element_shape[0]

  def _check_element_shape(self, shape):
    """Changes the element shape of the array given a shape to merge with.

    Args:
      shape: A `TensorShape` object to merge with.

    Raises:
      ValueError: if the provided shape is incompatible with the current
        element shape of the `TensorArray`.
    """
    if not shape.is_compatible_with(self.element_shape):
      raise ValueError("Inconsistent shapes: saw %s but expected %s " %
                       (shape, self.element_shape))
    if self._infer_shape:
      # Mutate the shared cell in place rather than rebinding the list.
      self._element_shape[0] = self.element_shape.merge_with(shape)

  @contextlib.contextmanager
  def _maybe_colocate_with(self, value):
    """Colocate operations with an internal colocation group or `value`.

    Args:
      value: `Tensor`, the tensor to try to colocate with.

    Yields:
      Does not yield anything, but the new context is a colocation context.

    If no internal colocation group is set, colocate with `value` and set
    the internal colocation group to be value.
    """
    if not self._colocate_with_first_write_call:
      yield
    else:
      # Only the first value ever recorded is used for colocation.
      if not self._colocate_with:
        self._colocate_with.append(value)
      with ops.colocate_with(self._colocate_with[0]):
        yield

  def identity(self):
    """See TensorArray."""
    # A fresh flow tensor gives callers a new dependency point on this array.
    flow = array_ops.identity(self._flow)
    return build_ta_with_new_flow(self, flow)

  def grad(self, source, flow=None, name=None):
    """See TensorArray."""
    # tensor_array_grad requires a flow input when forward
    # TensorArrays are dynamically sized. This forces the creation
    # of the grad TensorArray only once the final forward array's size
    # is fixed.
    if flow is None:
      flow = self.flow
    with ops.name_scope(name, "TensorArrayGrad", [self._handle]):
      with ops.colocate_with(self._handle):
        g_handle, unused_flow = gen_data_flow_ops.tensor_array_grad_v3(
            handle=self._handle, source=source, flow_in=flow, name=name)
        # Make the gradient flow depend on creation of the gradient handle.
        with ops.control_dependencies([g_handle]):
          flow = array_ops.identity(flow, name="gradient_flow")
        g = TensorArray(
            dtype=self._dtype,
            handle=g_handle,
            flow=flow,
            infer_shape=self._infer_shape,
            colocate_with_first_write_call=False)
        # Share (not copy) the element-shape cell with the forward array.
        # pylint: disable=protected-access
        g._implementation._element_shape = self._element_shape
        # pylint: enable=protected-access
        return g

  def read(self, index, name=None):
    """See TensorArray."""
    value = gen_data_flow_ops.tensor_array_read_v3(
        handle=self._handle,
        index=index,
        flow_in=self._flow,
        dtype=self._dtype,
        name=name)
    # `_element_shape` is a one-element list as set in __init__/grad, so this
    # condition holds whenever it has been initialized by this class.
    if self._element_shape:
      value.set_shape(self._element_shape[0].dims)
    return value

  def write(self, index, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayWrite", [self._handle, index, value]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(value, self._dtype)
      self._check_element_shape(value.shape)
      with self._maybe_colocate_with(value):
        flow_out = gen_data_flow_ops.tensor_array_write_v3(
            handle=self._handle,
            index=index,
            value=value,
            flow_in=self._flow,
            name=name)
      return build_ta_with_new_flow(self, flow_out)

  def stack(self, name=None):
    """See TensorArray."""
    with ops.colocate_with(self._handle):
      with ops.name_scope(name, "TensorArrayStack", [self._handle]):
        # Stacking is a gather of every index in [0, size).
        value = self.gather(math_ops.range(0, self.size()), name=name)
        # With a fixed size, the leading dimension can be set statically
        # (constant_value yields None when the size is not constant).
        if (self.element_shape and not self._dynamic_size and
            self._size is not None):
          value.set_shape([tensor_util.constant_value(self._size)] +
                          self.element_shape.dims)
        return value

  def gather(self, indices, name=None):
    """See TensorArray."""
    # The else branch is only a fallback in case `_element_shape` is an empty
    # list; within this class it is always a one-element list.
    if self._element_shape:
      element_shape = self._element_shape[0]
    else:
      element_shape = tensor_shape.unknown_shape(None)
    value = gen_data_flow_ops.tensor_array_gather_v3(
        handle=self._handle,
        indices=indices,
        flow_in=self._flow,
        dtype=self._dtype,
        name=name,
        element_shape=element_shape)
    if self.element_shape:
      # Leading (gathered) dimension is unknown statically.
      value.set_shape([None] + self.element_shape.dims)
    return value

  def concat(self, name=None):
    """See TensorArray."""
    value, _ = gen_data_flow_ops.tensor_array_concat_v3(
        handle=self._handle,
        flow_in=self._flow,
        dtype=self._dtype,
        name=name,
        element_shape_except0=self.element_shape[1:])
    if self.element_shape:
      dim0 = None
      # Only when all elements share one shape can the concatenated leading
      # dimension be computed as size * per-element leading dimension.
      if self._infer_shape:
        size = tensor_util.constant_value(self.size())
        if size is not None and self.element_shape[0] is not None:
          dim0 = size * self.element_shape[0]
      value.set_shape([dim0] + self.element_shape.dims[1:])
    return value

  @tf_should_use.should_use_result
  def unstack(self, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayUnstack", [self._handle, value]):
      # Scatter rows of `value` into indices 0..N-1, N = dynamic dim 0.
      num_elements = array_ops.shape(value)[0]
      return self.scatter(
          indices=math_ops.range(0, num_elements), value=value, name=name)

  @tf_should_use.should_use_result
  def scatter(self, indices, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayScatter",
                        [self._handle, value, indices]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(value, self._dtype)
      if not context.executing_eagerly():
        # Each scattered element has the shape of `value` minus dim 0.
        self._check_element_shape(value.shape[1:])
      with self._maybe_colocate_with(value):
        flow_out = gen_data_flow_ops.tensor_array_scatter_v3(
            handle=self._handle,
            indices=indices,
            value=value,
            flow_in=self._flow,
            name=name)
      return build_ta_with_new_flow(self, flow_out)

  @tf_should_use.should_use_result
  def split(self, value, lengths, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArraySplit",
                        [self._handle, value, lengths]):
      value = ops.convert_to_tensor(value, dtype=self._dtype, name="value")
      with self._maybe_colocate_with(value):
        lengths_64 = math_ops.cast(lengths, dtypes.int64)
        if not context.executing_eagerly():
          clengths = tensor_util.constant_value(lengths_64)
          # A static element shape is only implied when every length is a
          # known constant and all lengths are equal.
          if value.shape.dims is not None and clengths is not None:
            if clengths.shape and clengths.max() == clengths.min():
              self._check_element_shape(
                  tensor_shape.TensorShape([clengths[0]
                                           ]).concatenate(value.shape[1:]))
        flow_out = gen_data_flow_ops.tensor_array_split_v3(
            handle=self._handle,
            value=value,
            lengths=lengths_64,
            flow_in=self._flow,
            name=name)
      return build_ta_with_new_flow(self, flow_out)

  def size(self, name=None):
    """See TensorArray."""
    # A fixed-size array reports its construction size without running an op.
    if not self._dynamic_size and self._size is not None:
      return ops.convert_to_tensor(self._size, dtype=dtypes.int32)
    else:
      return gen_data_flow_ops.tensor_array_size_v3(
          handle=self._handle, flow_in=self.flow, name=name)

  @tf_should_use.should_use_result
  def close(self, name=None):
    """See TensorArray."""
    return gen_data_flow_ops.tensor_array_close_v3(
        handle=self._handle, name=name)
class _GraphTensorArrayV2:
  """Graph-mode implementation of TensorArray backed by TensorLists.

  The backing tensor of this TensorArray is a TensorList variant tensor which is
  stored in the `flow`. The `handle` is always none here. The reason we use the
  `flow` field and not the `handle` field is to ensure backwards compatibility
  with legacy control flow.
  """

  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Constructs a graph mode TensorArray.

    Args:
      dtype: (required) data type of the TensorArray.
      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
        Required if flow is not provided.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size. Default: False.
      clear_after_read: (optional) unused. Not supported in TensorLists.
      tensor_array_name: (optional) unused.
      handle: (optional) Must always be None.
      flow: (optional) A variant `Tensor` scalar for a TensorList.
      infer_shape: (optional, default: True) If True, shape inference is
        enabled. In this case, all elements must have the same shape.
      element_shape: (optional, default: None) A `TensorShape` object specifying
        the shape constraints of each of the elements of the TensorArray. Need
        not be fully defined.
      colocate_with_first_write_call: (optional). unused.
      name: (optional) A name for the operation.

    Raises:
      AssertionError: if `handle` is not None (only checked when Python
        assertions are enabled).
      TypeError: if `flow` is provided but is not a variant `Tensor`.
      ValueError: if neither `flow` nor `size` is provided, or if `flow` is
        provided together with `size` or `element_shape`.
    """
    assert handle is None
    del handle
    del clear_after_read
    del tensor_array_name
    del colocate_with_first_write_call

    self._dynamic_size = dynamic_size
    self._size = size

    if (flow is not None and
        (not isinstance(flow, ops.Tensor) or flow.dtype != dtypes.variant)):
      # A non-Tensor `flow` (e.g. a Python scalar) may have no `dtype`
      # attribute; describe it by its type so the intended TypeError is raised
      # instead of an AttributeError while formatting the message.
      received = getattr(flow, "dtype", type(flow))
      raise TypeError(
          f"Expected `flow` to be a variant tensor, but received `{received}` "
          f"instead.")
    if flow is None and size is None:
      raise ValueError("Argument `size` must be provided if argument `flow` "
                       "is not provided.")
    if flow is not None and size is not None:
      raise ValueError("Cannot provide both `flow` and `size` arguments "
                       "at the same time.")
    if flow is not None and element_shape is not None:
      raise ValueError(
          "Cannot provide both `flow` and `element_shape` arguments "
          "at the same time.")

    self._dtype = dtypes.as_dtype(dtype).base_dtype

    # Record the current static shape for the array elements. The element
    # shape is defined either by `element_shape` or the shape of the tensor
    # of the first write. If `infer_shape` is true, all writes checks for
    # shape equality.
    self._element_shape = [tensor_shape.as_shape(element_shape)]
    self._infer_shape = infer_shape
    with ops.name_scope(name, "TensorArrayV2", [size, flow]) as scope:
      if flow is None:
        # Fresh list with `size` reserved (uninitialized) slots.
        self._flow = list_ops.tensor_list_reserve(
            element_shape=element_shape,
            num_elements=size,
            element_dtype=dtype,
            name=scope)
      else:
        self._flow = flow

    # For backwards compatibility.
    self._colocate_with_first_write_call = None
    self._colocate_with = None

  @property
  def flow(self):
    # The TensorList variant tensor holding the array contents.
    return self._flow

  @property
  def dtype(self):
    return self._dtype

  @property
  def element_shape(self):
    return self._element_shape[0]

  @property
  def handle(self):
    # We intentionally do not raise an error so that legacy while_loop does not
    # complain.
    return None

  def _check_element_shape(self, shape):
    """Changes the element shape of the array given a shape to merge with.

    Args:
      shape: A `TensorShape` object to merge with.

    Raises:
      ValueError: if the provided shape is incompatible with the current
        element shape of the `TensorArray`.
    """
    if not shape.is_compatible_with(self.element_shape):
      raise ValueError("Inconsistent shapes: saw %s but expected %s " %
                       (shape, self.element_shape))
    if self._infer_shape:
      self._element_shape[0] = self.element_shape.merge_with(shape)

  def identity(self):
    """See TensorArray."""
    flow = array_ops.identity(self._flow)
    return build_ta_with_new_flow(self, flow)

  def grad(self, source, flow=None, name=None):
    """Not supported."""
    raise NotImplementedError()

  def read(self, index, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayV2Read", [self._flow, index]):
      value = list_ops.tensor_list_get_item(
          input_handle=self._flow,
          index=index,
          element_dtype=self._dtype,
          element_shape=self.element_shape,
          name=name)
      return value

  def write(self, index, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayV2Write", [self._flow, index, value]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(value, self._dtype)
      self._check_element_shape(value.shape)
      # Growing past the end is only permitted for dynamically sized arrays.
      flow_out = list_ops.tensor_list_set_item(
          input_handle=self._flow,
          index=index,
          item=value,
          resize_if_index_out_of_bounds=self._dynamic_size,
          name=name)
      return build_ta_with_new_flow(self, flow_out)

  def stack(self, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayV2Stack", [self._flow]):
      # TODO(b/139941163): remove constant_value after changing num_elements to regular input
      if not self._dynamic_size and self._size is not None:
        ta_size = tensor_util.constant_value(self._size)
      else:
        # -1 means the element count is not known statically.
        ta_size = -1
      value = list_ops.tensor_list_stack(
          input_handle=self._flow,
          element_dtype=self._dtype,
          num_elements=ta_size,
          element_shape=self.element_shape)
      return value

  def gather(self, indices, name=None):
    """See TensorArray."""
    value = list_ops.tensor_list_gather(
        input_handle=self._flow,
        indices=indices,
        element_dtype=self._dtype,
        element_shape=self.element_shape,
        name=name)
    return value

  def concat(self, name=None):
    """See TensorArray."""
    # Elements may differ in their leading dimension, so only dims 1.. are
    # passed on as a static constraint.
    if self.element_shape:
      element_shape = [None] + self.element_shape.dims[1:]
    else:
      element_shape = None

    value = list_ops.tensor_list_concat(
        input_handle=self._flow,
        element_dtype=self._dtype,
        element_shape=element_shape,
        name=name)
    return value

  @tf_should_use.should_use_result
  def unstack(self, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayUnstack", [self._flow, value]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(value, self._dtype)
      self._check_element_shape(value.shape[1:])
      # Builds a brand-new list from `value`; the previous flow is not used.
      flow_out = list_ops.tensor_list_from_tensor(
          tensor=value, element_shape=value.shape[1:])
      return build_ta_with_new_flow(self, flow_out)

  @tf_should_use.should_use_result
  def scatter(self, indices, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayScatter",
                        [self._flow, value, indices]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(value, self._dtype)
      self._check_element_shape(value.shape[1:])
      flow_out = list_ops.tensor_list_scatter(
          tensor=value,
          indices=indices,
          element_shape=self.element_shape,
          input_handle=self._flow)
      return build_ta_with_new_flow(self, flow_out)

  @tf_should_use.should_use_result
  def split(self, value, lengths, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArraySplit", [self._flow, value, lengths]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(value, self._dtype)
      lengths_64 = math_ops.cast(lengths, dtypes.int64)
      if not context.executing_eagerly():
        clengths = tensor_util.constant_value(lengths_64)
        # A static element shape is only implied when every length is a
        # known constant and all lengths are equal.
        if value.shape.dims is not None and clengths is not None:
          if clengths.shape and clengths.max() == clengths.min():
            self._check_element_shape(
                tensor_shape.TensorShape([clengths[0]
                                         ]).concatenate(value.shape[1:]))
      flow_out = list_ops.tensor_list_split(
          tensor=value,
          lengths=lengths_64,
          element_shape=self.element_shape,
          name=name)
      return build_ta_with_new_flow(self, flow_out)

  def size(self, name=None):
    """See TensorArray."""
    if not self._dynamic_size and self._size is not None:
      return ops.convert_to_tensor(self._size, dtype=dtypes.int32)
    else:
      return list_ops.tensor_list_length(input_handle=self._flow, name=name)

  def close(self, name=None):
    """See TensorArray."""
    # TensorLists need no explicit release; emit a no-op for API parity.
    return gen_control_flow_ops.no_op(name=name)
666# pylint: enable=protected-access
669class _EagerTensorArray:
670 """Eager-compatible implementation of TensorArray."""
672 def __init__(self,
673 dtype,
674 size=None,
675 dynamic_size=None,
676 clear_after_read=None,
677 tensor_array_name=None,
678 handle=None,
679 flow=None,
680 infer_shape=True,
681 element_shape=None,
682 colocate_with_first_write_call=True,
683 name=None):
684 """Constructs a TensorArray compatible with eager execution.
686 Args:
687 dtype: (required) data type of the TensorArray.
688 size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
689 Required if handle is not provided.
690 dynamic_size: (optional) Python bool: If true, writes to the TensorArray
691 can grow the TensorArray past its initial size. Default: False.
692 clear_after_read: Boolean (optional, default: True). If True, clear
693 TensorArray values after reading them. This disables read-many
694 semantics, but allows early release of memory.
695 tensor_array_name: unused.
696 handle: unsupported.
697 flow: unsupported.
698 infer_shape: used for error checking, same semantics as TensorArray.
699 element_shape: used for error checking, same semantics as TensorArray.
700 colocate_with_first_write_call: unsupported.
701 name: unsupported.
703 Raises:
704 ValueError: handle or flow are supplied, or if size is not supplied.
705 """
707 del (flow, tensor_array_name, name) # Unused.
709 if handle is not None:
710 raise ValueError("TensorArray handles are not supported when eager "
711 "execution is enabled.")
712 if size is None:
713 raise ValueError("Size must be declared for TensorArrays when eager "
714 "execution is enabled.")
716 # These attributes are not meaningful when eager is enabled, but some
717 # library functions (e.g., those in control_flow_ops.py) access them to
718 # create new tensor arrays; as such, we define them for the sake of
719 # compatibility.
720 self._handle = None
721 # we assign a dummy value to _flow in case other code assumes it to be
722 # a Tensor
723 self._flow = constant_op.constant(0, dtype=dtypes.int32)
724 self._infer_shape = infer_shape
725 self._element_shape = tensor_shape.as_shape(element_shape)
726 self._colocate_with_first_write_call = colocate_with_first_write_call
728 self._dtype = dtypes.as_dtype(dtype).base_dtype
729 self._dynamic_size = dynamic_size or False
730 self._clear_after_read = (True
731 if clear_after_read is None else clear_after_read)
732 self._previously_read_indices = []
734 if isinstance(size, ops.EagerTensor):
735 size = size.numpy()
736 self._tensor_array = [None for _ in range(size)]
738 @property
739 def flow(self):
740 """For compatibility; flows are not meaningful when eager is enabled."""
741 return self._flow
743 @property
744 def dtype(self):
745 return self._dtype
747 @property
748 def handle(self):
749 """For compatibility; handles are not meaningful when eager is enabled."""
750 return self._handle
752 @property
753 def element_shape(self):
754 return self._element_shape
756 def identity(self):
757 """See TensorArray."""
758 return self.parent()
760 def grad(self, source, flow=None, name=None):
761 raise NotImplementedError(
762 "TensorArray.grad is not supported when executing eagerly; eager's "
763 "gradient implementation does not use/need this function to compute "
764 "gradients of operations that use TensorArrays.")
766 def read(self, index, name=None):
767 """See TensorArray."""
768 del name # not meaningful when executing eagerly.
770 if isinstance(index, ops.EagerTensor):
771 index = index.numpy()
773 if index < 0:
774 raise errors_impl.OutOfRangeError(
775 None, None,
776 "Reading from negative indices (index %d) is not allowed." % index)
778 if index >= len(self._tensor_array):
779 raise errors_impl.OutOfRangeError(
780 None, None, "Tried to read from index %d but array size is: %d " %
781 (index, len(self._tensor_array)))
783 tensor = self._tensor_array[index]
784 if tensor is None:
785 if index in self._previously_read_indices:
786 raise errors_impl.InvalidArgumentError(
787 None, None,
788 "Could not read index %d twice because it was cleared after "
789 "a previous read (perhaps try setting clear_after_read = false?)" %
790 index)
791 else:
792 tensor = self._maybe_zero(index)
794 if self._clear_after_read:
795 self._tensor_array[index] = None
796 self._previously_read_indices.append(index)
797 return tensor
799 def _write(self, index, value):
800 """Writes `value` into index named by `index`.
802 Args:
803 index: 0-D. int32 scalar with the index to write to.
804 value: N-D. Tensor of type `dtype`. The `Tensor` to write to `index`.
806 Raises:
807 errors_impl.InvalidArgumentError: `value` dtype does not match dtype.
808 errors_impl.OutOfRangeError: `index` is out of bounds.
809 ValueError: shape of `value` is not consistent with inferred shape.
810 """
812 if isinstance(index, ops.EagerTensor):
813 index = index.numpy()
815 if index < 0:
816 raise errors_impl.OutOfRangeError(
817 None, None,
818 "Writing to negative indices (index %d) is not allowed." % index)
820 size = len(self._tensor_array)
821 if index >= size:
822 if not self._dynamic_size:
823 raise errors_impl.OutOfRangeError(
824 None, None,
825 "Tried to write to index %d but array is not resizeable and size "
826 "is: %d " % (index, size))
827 self._tensor_array.extend(None for _ in range(index - size + 1))
829 if not isinstance(value, ops.EagerTensor):
830 # TODO(b/129870929): Fix after all callers provide proper init dtype.
831 value = ops.convert_to_tensor(
832 value, preferred_dtype=self._dtype, name="value")
834 if self._dtype != value.dtype:
835 raise errors_impl.InvalidArgumentError(
836 None, None,
837 "TensorArray dtype is %s but Op is trying to write dtype %s " %
838 (self._dtype.name, value.dtype.name))
840 if not self._element_shape.is_compatible_with(value.shape):
841 raise ValueError("Incompatible shape for value (%s), expected (%s)" %
842 (value.shape, self._element_shape))
844 if self._infer_shape:
845 self._element_shape = self._element_shape.merge_with(value.shape)
847 self._tensor_array[index] = value
849 def write(self, index, value, name=None):
850 """See TensorArray."""
851 del name # not meaningful when executing eagerly.
852 self._write(index, value)
853 return self.parent()
855 def _maybe_zero(self, ix):
856 val = self._tensor_array[ix]
857 if val is None:
858 val = self._tensor_array[ix] = array_ops.zeros(
859 shape=self._element_shape, dtype=self._dtype)
860 return val
862 def stack(self, name=None):
863 """See TensorArray."""
864 if self._tensor_array:
865 for ix in range(len(self._tensor_array)):
866 self._maybe_zero(ix)
867 if not self._tensor_array and self._element_shape.is_fully_defined():
868 return ops.convert_to_tensor(
869 np.ndarray([0] + self._element_shape), name=name, dtype=self._dtype)
870 else:
871 return ops.convert_to_tensor(
872 self._tensor_array, name=name, dtype=self._dtype)
874 def gather(self, indices, name=None):
875 """See TensorArray."""
876 del name # not meaningful when executing eagerly.
877 if isinstance(indices, ops.EagerTensor):
878 indices = indices.numpy()
879 return array_ops_stack.stack([self._maybe_zero(i) for i in indices])
def concat(self, name=None):
  """Concatenates all elements along axis 0; eager `TensorArray.concat`.

  Unwritten slots are materialized as zeros before concatenating.

  Raises:
    InvalidArgumentError: when the concat fails and some element is a scalar
      (rank 0). Any other concat failure is re-raised unchanged.
  """
  try:
    pieces = [self._maybe_zero(ix) for ix in range(len(self._tensor_array))]
    return array_ops.concat(pieces, 0, name=name)
  except errors_impl.OpError:
    # Reproduce a subset of the error-handling for graph-mode TensorArrays:
    # report the first scalar element, otherwise propagate the original error.
    ranks = [element.shape.ndims for element in self._tensor_array]
    if 0 not in ranks:
      raise
    first_scalar = ranks.index(0)
    raise errors_impl.InvalidArgumentError(
        None, None, "Concat saw a scalar shape at index %d but requires "
        "at least vectors." % first_scalar)
def unstack(self, value, name=None):
  """Replaces the contents with the axis-0 slices of `value`; eager `unstack`.

  The stored list becomes exactly the list of slices, so the array's length
  afterwards equals the number of slices.

  Raises:
    ValueError: if the array is not dynamically sized and `value` has more
      slices than the array currently holds.
  """
  pieces = array_ops_stack.unstack(value, name=name)
  count, capacity = len(pieces), len(self._tensor_array)
  if not self._dynamic_size and count > capacity:
    raise ValueError(
        "Cannot unstack %d tensors into a TensorArray of static size %d " %
        (count, capacity))
  self._tensor_array = pieces
  return self.parent()
def scatter(self, indices, value, name=None):
  """Writes the axis-0 slices of `value` at `indices`; eager `scatter`.

  Each slice goes through `self._write`, so the per-write dtype and shape
  checks apply. Pairing stops at the shorter of `indices` and the slices
  (plain `zip` semantics). `name` is ignored when executing eagerly.
  """
  del name  # Op names carry no meaning in eager execution.
  if isinstance(indices, ops.EagerTensor):
    indices = indices.numpy()
  slices = array_ops_stack.unstack(value)
  for position, piece in zip(indices, slices):
    self._write(position, piece)  # pylint: disable=protected-access
  return self.parent()
def split(self, value, lengths, name=None):
  """Splits `value` along axis 0 into pieces sized by `lengths`; eager `split`.

  On success the stored list is replaced by the resulting pieces.

  Raises:
    InvalidArgumentError: if `lengths` is not a vector, `value` is a scalar,
      the lengths do not sum to `value.shape[0]`, or the array is not
      dynamically sized and the number of lengths differs from its size.
  """
  # TODO(b/129870929): Fix after all callers provide proper init dtype.
  value = ops.convert_to_tensor(
      value, preferred_dtype=self._dtype, name="value")
  _check_dtypes(value, self._dtype)  # Logs only; does not raise.
  lengths = ops.convert_to_tensor(lengths)
  total = math_ops.reduce_sum(lengths)

  if lengths.shape.ndims != 1:
    raise errors_impl.InvalidArgumentError(
        None, None, "Expected lengths to be a vector, received shape: %s " %
        lengths.shape.as_list())
  if value.shape.ndims == 0:
    raise errors_impl.InvalidArgumentError(
        None, None, "Expected value to be at least a vector, "
        "but received shape: %s " % value.shape.as_list())
  if total.numpy() != value.shape.as_list()[0]:
    raise errors_impl.InvalidArgumentError(
        None, None, "Expected sum of lengths to be equal to "
        "values.shape[0], but sum of lengths is %d and "
        "value's shape is: %s " % (total.numpy(), value.shape.as_list()))
  if not self._dynamic_size and lengths.shape[0] != len(self._tensor_array):
    raise errors_impl.InvalidArgumentError(
        None, None, "TensorArray's size is not equal to the size of "
        "lengths (%d vs. %d), and the TensorArray is not marked as "
        "dynamically resizeable." %
        (len(self._tensor_array), lengths.shape[0]))

  self._tensor_array = array_ops.split(value, lengths, name=name)
  return self.parent()
def size(self, name=None):
  """Returns the current number of slots as a scalar constant tensor."""
  del name  # Op names carry no meaning in eager execution.
  element_count = len(self._tensor_array)
  return constant_op.constant(element_count)
def close(self, name=None):
  """Empties the stored list in place; the list object itself is kept."""
  del name  # Op names carry no meaning in eager execution.
  self._tensor_array[:] = []
961# TensorArray is designed to hide an underlying implementation object
962# and as such accesses many of that object's hidden fields.
963# pylint: disable=protected-access
964# pylint:disable=line-too-long
@tf_export("TensorArray")
class TensorArray:
  """Class wrapping dynamic-sized, per-time-step, Tensor arrays.

  This class is meant to be used with dynamic iteration primitives such as
  `while_loop` and `map_fn`. It supports gradient back-propagation via special
  "flow" control flow dependencies.

  Note that although the array can be read multiple times and positions can be
  overwritten, behavior may be undefined when storing multiple references to
  the same array and clear_after_read is False. In particular, avoid using
  methods like concat() to convert an intermediate TensorArray to a Tensor,
  then further modifying the TensorArray, particularly if you need to backprop
  through it later.

  Example 1: Plain reading and writing.

  >>> ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True, clear_after_read=False)
  >>> ta = ta.write(0, 10)
  >>> ta = ta.write(1, 20)
  >>> ta = ta.write(2, 30)
  >>>
  >>> ta.read(0)
  <tf.Tensor: shape=(), dtype=float32, numpy=10.0>
  >>> ta.read(1)
  <tf.Tensor: shape=(), dtype=float32, numpy=20.0>
  >>> ta.read(2)
  <tf.Tensor: shape=(), dtype=float32, numpy=30.0>
  >>> ta.stack()
  <tf.Tensor: shape=(3,), dtype=float32, numpy=array([10., 20., 30.],
  dtype=float32)>

  Example 2: Fibonacci sequence algorithm that writes in a loop then returns.

  >>> @tf.function
  ... def fibonacci(n):
  ...   ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
  ...   ta = ta.unstack([0., 1.])
  ...
  ...   for i in range(2, n):
  ...     ta = ta.write(i, ta.read(i - 1) + ta.read(i - 2))
  ...
  ...   return ta.stack()
  >>>
  >>> fibonacci(7)
  <tf.Tensor: shape=(7,), dtype=float32,
  numpy=array([0., 1., 1., 2., 3., 5., 8.], dtype=float32)>

  Example 3: A simple loop interacting with a `tf.Variable`.

  >>> v = tf.Variable(1)
  >>> @tf.function
  ... def f(x):
  ...   ta = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
  ...   for i in tf.range(x):
  ...     v.assign_add(i)
  ...     ta = ta.write(i, v)
  ...   return ta.stack()
  >>> f(5)
  <tf.Tensor: shape=(5,), dtype=int32, numpy=array([ 1,  2,  4,  7, 11],
  dtype=int32)>
  """

  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Construct a new TensorArray or wrap an existing TensorArray handle.

    A note about the parameter `name`:

    The name of the `TensorArray` (even if passed in) is uniquified: each time
    a new `TensorArray` is created at runtime it is assigned its own name for
    the duration of the run. This avoids name collisions if a `TensorArray`
    is created within a `while_loop`.

    Args:
      dtype: (required) data type of the TensorArray.
      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
        Required if handle is not provided.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size. Default: False.
      clear_after_read: Boolean (optional, default: True). If True, clear
        TensorArray values after reading them. This disables read-many
        semantics, but allows early release of memory.
      tensor_array_name: (optional) Python string: the name of the TensorArray.
        This is used when creating the TensorArray handle. If this value is
        set, handle should be None.
      handle: (optional) A `Tensor` handle to an existing TensorArray. If this
        is set, tensor_array_name should be None. Only supported in graph mode.
      flow: (optional) A float `Tensor` scalar coming from an existing
        `TensorArray.flow`. Only supported in graph mode.
      infer_shape: (optional, default: True) If True, shape inference is
        enabled. In this case, all elements must have the same shape.
      element_shape: (optional, default: None) A `TensorShape` object specifying
        the shape constraints of each of the elements of the TensorArray. Need
        not be fully defined.
      colocate_with_first_write_call: If `True`, the TensorArray will be
        colocated on the same device as the Tensor used on its first write
        (write operations include `write`, `unstack`, and `split`). If `False`,
        the TensorArray will be placed on the device determined by the device
        context available during its initialization.
      name: A name for the operation (optional).

    Raises:
      ValueError: if both handle and tensor_array_name are provided.
      TypeError: if handle is provided but is not a Tensor.
    """
    if (context.executing_eagerly() and
        (flow is None or flow.dtype != dtypes.variant)):
      # It is possible to create a Variant-style TensorArray even in eager mode,
      # and this is fine but can have performance implications in eager.
      # An example of when this happens is if a tf.function returns a
      # TensorArray in its output; its flow variant object is returned to Eager.
      # This can be wrapped back up in a Variant-style TensorArray.
      implementation = _EagerTensorArray
    elif (flow is not None and flow.dtype == dtypes.variant or
          control_flow_util.EnableControlFlowV2(ops.get_default_graph())):
      implementation = _GraphTensorArrayV2
    else:
      implementation = _GraphTensorArray
    self._implementation = implementation(
        dtype,
        size=size,
        dynamic_size=dynamic_size,
        clear_after_read=clear_after_read,
        tensor_array_name=tensor_array_name,
        handle=handle,
        flow=flow,
        infer_shape=infer_shape,
        element_shape=element_shape,
        colocate_with_first_write_call=colocate_with_first_write_call,
        name=name)

    # Weak back-reference so implementation methods can return this wrapper
    # (via `self.parent()`) without creating a reference cycle.
    self._implementation.parent = weakref.ref(self)

  @property
  def flow(self):
    """The flow `Tensor` forcing ops leading to this TensorArray state."""
    return self._implementation._flow

  @property
  def dtype(self):
    """The data type of this TensorArray."""
    return self._implementation._dtype

  @property
  def handle(self):
    """The reference to the TensorArray."""
    return self._implementation.handle

  @property
  def element_shape(self):
    """The `tf.TensorShape` of elements in this TensorArray."""
    return self._implementation.element_shape

  @property
  def dynamic_size(self):
    """Python bool; if `True` the TensorArray can grow dynamically."""
    return self._implementation._dynamic_size

  @property
  def _infer_shape(self):
    # TODO(slebedev): consider making public or changing TensorArrayStructure
    # to access _implementation directly. Note that dynamic_size is also
    # only used by TensorArrayStructure.
    return self._implementation._infer_shape

  def identity(self):
    """Returns a TensorArray with the same content and properties.

    Returns:
      A new TensorArray object with flow that ensures the control dependencies
      from the contexts will become control dependencies for writes, reads, etc.
      Use this object for all subsequent operations.
    """
    return self._implementation.identity()

  def grad(self, source, flow=None, name=None):
    """Delegates to the underlying implementation's `grad`.

    Args:
      source: Forwarded unchanged to the implementation's `grad`.
      flow: (optional) Forwarded unchanged to the implementation's `grad`.
      name: (optional) Forwarded unchanged to the implementation's `grad`.

    Returns:
      Whatever the implementation's `grad` returns.
    """
    return self._implementation.grad(source, flow=flow, name=name)

  def read(self, index, name=None):
    """Read the value at location `index` in the TensorArray.

    Args:
      index: 0-D.  int32 tensor with the index to read from.
      name: A name for the operation (optional).

    Returns:
      The tensor at index `index`.
    """
    return self._implementation.read(index, name=name)

  @tf_should_use.should_use_result(warn_in_eager=True)
  def write(self, index, value, name=None):
    """Write `value` into index `index` of the TensorArray.

    Args:
      index: 0-D.  int32 scalar with the index to write to.
      value: N-D.  Tensor of type `dtype`.  The Tensor to write to this index.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the write occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if there are more writers than specified.
    """
    return self._implementation.write(index, value, name=name)

  def stack(self, name=None):
    """Return the values in the TensorArray as a stacked `Tensor`.

    All of the values must have been written and their shapes must all match.
    If input shapes have rank-`R`, then output shape will have rank-`(R+1)`.

    For example:

    >>> ta = tf.TensorArray(tf.int32, size=3)
    >>> ta = ta.write(0, tf.constant([1, 2]))
    >>> ta = ta.write(1, tf.constant([3, 4]))
    >>> ta = ta.write(2, tf.constant([5, 6]))
    >>> ta.stack()
    <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
    array([[1, 2],
           [3, 4],
           [5, 6]], dtype=int32)>

    Args:
      name: A name for the operation (optional).

    Returns:
      All the tensors in the TensorArray stacked into one tensor.
    """
    return self._implementation.stack(name=name)

  def gather(self, indices, name=None):
    """Return selected values in the TensorArray as a packed `Tensor`.

    All of selected values must have been written and their shapes
    must all match.

    Args:
      indices: A `1-D` `Tensor` taking values in `[0, max_value)`.  If the
        `TensorArray` is not dynamic, `max_value=size()`.
      name: A name for the operation (optional).

    Returns:
      The tensors in the `TensorArray` selected by `indices`, packed into one
      tensor.
    """
    return self._implementation.gather(indices, name=name)

  def concat(self, name=None):
    """Return the values in the TensorArray as a concatenated `Tensor`.

    All of the values must have been written, their ranks must match, and
    their shapes must all match for all dimensions except the first.

    Args:
      name: A name for the operation (optional).

    Returns:
      All the tensors in the TensorArray concatenated into one tensor.
    """
    return self._implementation.concat(name=name)

  @tf_should_use.should_use_result
  def unstack(self, value, name=None):
    """Unstack the values of a `Tensor` in the TensorArray.

    If input value shapes have rank-`R`, then the output TensorArray will
    contain elements whose shapes are rank-`(R-1)`.

    Args:
      value: (N+1)-D.  Tensor of type `dtype`.  The Tensor to unstack.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the unstack occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if the shape inference fails.
    """
    return self._implementation.unstack(value, name=name)

  @tf_should_use.should_use_result
  def scatter(self, indices, value, name=None):
    """Scatter the values of a `Tensor` in specific indices of a `TensorArray`.

    Args:
      indices: A `1-D` `Tensor` taking values in `[0, max_value)`.  If the
        `TensorArray` is not dynamic, `max_value=size()`.
      value: (N+1)-D.  Tensor of type `dtype`.  The Tensor to unpack.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the scatter occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if the shape inference fails.
    """
    return self._implementation.scatter(indices, value, name=name)

  @tf_should_use.should_use_result
  def split(self, value, lengths, name=None):
    """Split the values of a `Tensor` into the TensorArray.

    Args:
      value: (N+1)-D.  Tensor of type `dtype`.  The Tensor to split.
      lengths: 1-D.  int32 vector with the lengths to use when splitting `value`
        along its first dimension.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the split occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if the shape inference fails.
    """
    return self._implementation.split(value, lengths, name=name)

  def size(self, name=None):
    """Return the size of the TensorArray."""
    return self._implementation.size(name=name)

  @tf_should_use.should_use_result
  def close(self, name=None):
    """Close the current TensorArray."""
    return self._implementation.close(name=name)
def build_ta_with_new_flow(old_ta, flow):
  """Builds a `TensorArray` that mirrors `old_ta` but uses a new `flow` tensor.

  Args:
    old_ta: Either a `TensorArray` wrapper or a bare implementation object;
      both forms reach this function.
    flow: The flow tensor for the new array.

  Returns:
    A new `TensorArray` built from `old_ta`'s dtype, handle, `_infer_shape` and
    `_colocate_with_first_write_call`. Its implementation also takes over
    `_dynamic_size`, `_size`, `_colocate_with` and `_element_shape`; the element
    shape object is shared, not copied.

  Raises:
    NotImplementedError: when not executing eagerly, control flow v2 is
      enabled, and `old_ta`'s implementation is not a `_GraphTensorArrayV2`.
  """
  if isinstance(old_ta, TensorArray):
    source_impl = old_ta._implementation
  else:
    source_impl = old_ta

  if (not context.executing_eagerly() and
      not isinstance(source_impl, _GraphTensorArrayV2) and
      control_flow_util.EnableControlFlowV2(ops.get_default_graph())):
    raise NotImplementedError("Attempting to build a graph-mode TF2-style "
                              "TensorArray from either an eager-mode "
                              "TensorArray or a TF1-style TensorArray. "
                              "This is not currently supported. You may be "
                              "attempting to capture a TensorArray "
                              "inside a tf.function or tf.data map function. "
                              "Instead, construct a new TensorArray inside "
                              "the function.")

  result = TensorArray(
      dtype=source_impl.dtype,
      handle=source_impl.handle,
      flow=flow,
      infer_shape=source_impl._infer_shape,
      colocate_with_first_write_call=source_impl._colocate_with_first_write_call)
  target_impl = result._implementation
  # Carry over state that the constructor does not take as arguments.
  target_impl._dynamic_size = source_impl._dynamic_size
  target_impl._size = source_impl._size
  target_impl._colocate_with = source_impl._colocate_with
  target_impl._element_shape = source_impl._element_shape  # Shared, not copied.
  return result
1342# pylint: enable=protected-access
1345def _check_dtypes(value, dtype):
1346 if value.dtype != dtype:
1347 logging.error("Error: Input value {} has dtype {}, but expected dtype {}. "
1348 "This leads to undefined behavior and will be an error "
1349 "in future versions of TensorFlow. Traceback:\n{}".format(
1350 value, str(value.dtype), str(dtype),
1351 "".join(traceback.format_stack())))
@tf_export("TensorArraySpec")
@type_spec_registry.register("tf.TensorArraySpec")
class TensorArraySpec(type_spec.TypeSpec):
  """Type specification for a `tf.TensorArray`."""

  __slots__ = ["_element_shape", "_dtype", "_dynamic_size", "_infer_shape"]

  value_type = property(lambda self: TensorArray)

  def __init__(self,
               element_shape=None,
               dtype=dtypes.float32,
               dynamic_size=False,
               infer_shape=True):
    """Constructs a type specification for a `tf.TensorArray`.

    Args:
      element_shape: The shape of each element in the `TensorArray`.
      dtype: Data type of the `TensorArray`.
      dynamic_size: Whether the `TensorArray` can grow past its initial size.
      infer_shape: Whether shape inference is enabled.
    """
    self._element_shape = tensor_shape.as_shape(element_shape)
    self._dtype = dtypes.as_dtype(dtype)
    self._dynamic_size = dynamic_size
    self._infer_shape = infer_shape

  def is_subtype_of(self, other):
    """Returns whether `self` is a subtype of `other`.

    True when `other` is a `TensorArraySpec` with the same dtype and
    `dynamic_size`. Element shape and `infer_shape` are not compared here.
    """
    # pylint: disable=protected-access
    return (isinstance(other, TensorArraySpec) and
            self._dtype == other._dtype and
            self._dynamic_size == other._dynamic_size)

  def most_specific_common_supertype(self, others):
    """Returns the most specific supertype of `self` and `others`.

    Args:
      others: A Sequence of `TypeSpec`.

    Returns:
      A `TensorArraySpec`, or `None` if a supertype does not exist.
    """
    # pylint: disable=protected-access
    # Materialize once, because `others` is traversed several times below.
    others = list(others)
    if not all(isinstance(other, TensorArraySpec) for other in others):
      # The documented "no supertype" result is None. The previous `False` is
      # not None and could be mistaken for a supertype.
      return None

    # Pass a list, not a generator: `TensorShape.most_specific_common_supertype`
    # iterates its argument more than once. A generator would be exhausted
    # after the first pass, and the result would collapse to
    # `self._element_shape`.
    common_shape = self._element_shape.most_specific_common_supertype(
        [other._element_shape for other in others])
    if common_shape is None:
      return None

    if not all(self._dtype == other._dtype for other in others):
      return None

    if not all(self._dynamic_size == other._dynamic_size for other in others):
      return None

    # Shape inference stays enabled only if every spec enables it.
    infer_shape = self._infer_shape and all(
        other._infer_shape for other in others)

    return TensorArraySpec(common_shape, self._dtype, self._dynamic_size,
                           infer_shape)

  def is_compatible_with(self, other):
    """Returns whether `other` (a spec or a value) is compatible with `self`."""
    # pylint: disable=protected-access
    if not isinstance(other, type_spec.TypeSpec):
      other = type_spec.type_spec_from_value(other)

    # Note: we intentionally exclude infer_shape in this check.
    return (isinstance(other, TensorArraySpec) and
            self._dtype.is_compatible_with(other._dtype) and
            self._element_shape.is_compatible_with(other._element_shape) and
            self._dynamic_size == other._dynamic_size)

  def _serialize(self):
    return (self._element_shape, self._dtype, self._dynamic_size,
            self._infer_shape)

  @property
  def _component_specs(self):
    # A TensorArray is encoded as a single scalar variant tensor.
    return [tensor_spec.TensorSpec([], dtypes.variant)]

  def _to_components(self, value):
    """Encodes `value` as a one-element list holding a variant flow tensor.

    If `value.flow` is not already a variant, the array is stacked and
    converted into a tensor list first.

    Raises:
      TypeError: if `value` is not a `TensorArray`.
    """
    if not isinstance(value, TensorArray):
      raise TypeError("Expected value to be a TensorArray, but got: `{}`".format(
          type(value)))
    if value.flow is not None and value.flow.dtype == dtypes.variant:
      return [value.flow]
    else:
      # Convert to a TF2-style TensorArray.
      # TODO(ebrevdo): Add an "_as_variant" method to TensorArray class, or
      # "implementation / as_variant" arg to TensorArray constructor.
      with ops.name_scope("convert_tensor_array"):
        flow = list_ops.tensor_list_from_tensor(
            tensor=value.stack(), element_shape=value.element_shape)
      return [flow]

  def _from_components(self, tensor_list):
    # This will return a TF2 Graph-style TensorArray because tensor_list[0] is
    # a variant object.  size == -1 implies unknown size.
    ret = TensorArray(
        dtype=self._dtype,
        flow=tensor_list[0],
        dynamic_size=self._dynamic_size,
        infer_shape=self._infer_shape)
    # NOTE(review): the shape is stored wrapped in a one-element list,
    # presumably matching how the graph-V2 implementation keeps its element
    # shape. Confirm against `_GraphTensorArrayV2`.
    ret._implementation._element_shape = [self._element_shape]  # pylint: disable=protected-access
    return ret

  @staticmethod
  def from_value(value):
    """Builds a `TensorArraySpec` describing the `TensorArray` `value`.

    Raises:
      TypeError: if `value` is not a `TensorArray`.
    """
    if not isinstance(value, TensorArray):
      raise TypeError("Expected value to be a TensorArray, but got: `{}`".format(
          type(value)))

    return TensorArraySpec(
        dtype=value.dtype,
        element_shape=value.element_shape,
        dynamic_size=value.dynamic_size,
        infer_shape=value._infer_shape)  # pylint: disable=protected-access

  def _to_legacy_output_types(self):
    return self._dtype

  def _to_legacy_output_shapes(self):
    # Sneak the dynamic_size and infer_shape values into the legacy shape.
    return (tensor_shape.TensorShape([self._dynamic_size, self._infer_shape
                                     ]).concatenate(self._element_shape))

  def _to_legacy_output_classes(self):
    return TensorArray
# Let the nested-structure coder (de)serialize TensorArraySpec through its
# built-in TypeSpecProto enum value.
nested_structure_coder.register_codec(
    nested_structure_coder.BuiltInTypeSpecCodec(
        TensorArraySpec, struct_pb2.TypeSpecProto.TENSOR_ARRAY_SPEC))

# Register the TypeSpec for TensorArray. If TensorArray is updated to be a
# CompositeTensor, then this registration can be deleted.
type_spec.register_type_spec_from_value_converter(
    TensorArray,
    TensorArraySpec.from_value,
    allow_subclass=True)