Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/operators/data_structures.py: 20%
147 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Operators specific to data structures: list append, subscripts, etc."""
17import collections
19from tensorflow.python.framework import constant_op
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import tensor_util
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import cond
25from tensorflow.python.ops import list_ops
26from tensorflow.python.ops import tensor_array_ops
29# TODO(mdan): Once control flow supports objects, repackage as a class.
def new_list(iterable=None):
  """Creates a new list-like object.

  A non-empty `iterable` yields a plain Python list of its elements. A missing
  or empty `iterable` yields a staged TensorFlow tensor list instead.

  Args:
    iterable: Optional elements to fill the list with.

  Returns:
    A list-like object. The exact return value depends on the initial elements.
  """
  elements = tuple(iterable) if iterable else ()
  if not elements:
    return tf_tensor_list_new(elements)
  # A list that starts out with elements is assumed to be a "Python" lvalue
  # list.
  return _py_list_new(elements)
def tf_tensor_array_new(elements, element_dtype=None, element_shape=None):
  """Overload of new_list that stages a TensorArray creation.

  Builds a dynamically-sized `TensorArray` and writes each element into it in
  order.

  Args:
    elements: Iterable of values convertible to tensors. All of them must share
      a single dtype and a single static shape.
    element_dtype: Optional element dtype. Required when `elements` is empty;
      otherwise it must equal the dtype inferred from the elements.
    element_shape: Optional element shape (e.g. a tuple, a list or a
      `TensorShape`). When given, it must equal the static shape inferred from
      the elements.

  Returns:
    A `TensorArray` holding `elements`.

  Raises:
    ValueError: if the elements have mixed dtypes or shapes, if the specified
      dtype or shape disagrees with the inferred one, or if `elements` is empty
      and `element_dtype` is not given.
  """
  elements = tuple(ops.convert_to_tensor(el) for el in elements)

  all_dtypes = set(el.dtype for el in elements)
  if len(all_dtypes) == 1:
    inferred_dtype, = tuple(all_dtypes)
    if element_dtype is not None and element_dtype != inferred_dtype:
      raise ValueError(
          'incompatible dtype; specified: {}, inferred from {}: {}'.format(
              element_dtype, elements, inferred_dtype))
  elif len(all_dtypes) > 1:
    raise ValueError(
        'TensorArray requires all elements to have the same dtype:'
        ' {}'.format(elements))
  else:
    if element_dtype is None:
      raise ValueError('dtype is required to create an empty TensorArray')

  all_shapes = set(tuple(el.shape.as_list()) for el in elements)
  if len(all_shapes) == 1:
    inferred_shape, = tuple(all_shapes)
    # `inferred_shape` is a tuple, and a list such as [2, 3] never compares
    # equal to the tuple (2, 3). Normalize plain Python sequences before the
    # comparison; other shape objects (e.g. TensorShape) keep their own
    # equality semantics.
    if isinstance(element_shape, (list, tuple)):
      specified_shape = tuple(element_shape)
    else:
      specified_shape = element_shape
    if specified_shape is not None and specified_shape != inferred_shape:
      raise ValueError(
          'incompatible shape; specified: {}, inferred from {}: {}'.format(
              element_shape, elements, inferred_shape))
  elif len(all_shapes) > 1:
    raise ValueError(
        'TensorArray requires all elements to have the same shape:'
        ' {}'.format(elements))
    # TODO(mdan): We may want to allow different shapes with infer_shape=False.
  else:
    inferred_shape = None

  # `inferred_dtype` is only unbound when `elements` is empty, and in that case
  # `element_dtype` was required above, so this branch cannot read it unbound.
  if element_dtype is None:
    element_dtype = inferred_dtype
  if element_shape is None:
    element_shape = inferred_shape

  l = tensor_array_ops.TensorArray(
      dtype=element_dtype,
      size=len(elements),
      dynamic_size=True,
      infer_shape=(element_shape is None),
      element_shape=element_shape)
  for i, el in enumerate(elements):
    l = l.write(i, el)
  return l
def tf_tensor_list_new(elements, element_dtype=None, element_shape=None):
  """Overload of new_list that stages a Tensor list creation.

  Args:
    elements: Either a TensorFlow tensor, which is unstacked along its first
      dimension into the list, or an iterable of values convertible to
      tensors, which are pushed onto the list in order.
    element_dtype: Optional element dtype. When all elements share a dtype it
      must equal that dtype; it must not be given for elements of mixed
      dtypes. Defaults to the inferred dtype, or tf.variant when the elements
      are empty or of mixed dtypes.
    element_shape: Optional element shape. Must not be given when `elements`
      is a tensor, or when the elements have mixed static shapes. When the
      elements share a static shape and `element_shape` is a sequence of
      dimensions, it must equal that shape. Defaults to the dynamic shape of
      the first element, or -1 (unknown shape) when no common shape exists.

  Returns:
    A tensor list holding `elements`.

  Raises:
    ValueError: if `element_dtype` or `element_shape` conflicts with what is
      inferred from `elements`.
  """
  if tensor_util.is_tf_type(elements):
    if element_shape is not None:
      raise ValueError(
          'element shape may not be specified when creating list from tensor')
    element_shape = array_ops.shape(elements)[1:]
    l = list_ops.tensor_list_from_tensor(elements, element_shape=element_shape)
    return l

  elements = tuple(ops.convert_to_tensor(el) for el in elements)

  all_dtypes = set(el.dtype for el in elements)
  if len(all_dtypes) == 1:
    inferred_dtype = tuple(all_dtypes)[0]
    if element_dtype is not None and element_dtype != inferred_dtype:
      raise ValueError(
          'incompatible dtype; specified: {}, inferred from {}: {}'.format(
              element_dtype, elements, inferred_dtype))
  elif all_dtypes:
    # Heterogeneous lists are ok.
    if element_dtype is not None:
      raise ValueError(
          'specified dtype {} is inconsistent with that of elements {}'.format(
              element_dtype, elements))
    inferred_dtype = dtypes.variant
  else:
    inferred_dtype = dtypes.variant

  all_shapes = set(tuple(el.shape.as_list()) for el in elements)
  if len(all_shapes) == 1:
    static_shape, = tuple(all_shapes)
    inferred_shape = array_ops.shape(elements[0])
    # Validate against the static shape rather than `inferred_shape`:
    # comparing a Python shape with a Tensor yields a Tensor, which is not a
    # reliable Python bool. A tensor-valued `element_shape` cannot be checked
    # statically and is passed through unchanged.
    if element_shape is not None and not tensor_util.is_tf_type(element_shape):
      try:
        specified_shape = tuple(element_shape)
      except TypeError:
        # Not a sequence of dimensions (e.g. the scalar -1 "unknown shape"
        # convention used below); nothing to check statically.
        specified_shape = None
      if specified_shape is not None and specified_shape != static_shape:
        raise ValueError(
            'incompatible shape; specified: {}, inferred from {}: {}'.format(
                element_shape, elements, static_shape))
  elif all_shapes:
    # Heterogeneous lists are ok.
    if element_shape is not None:
      raise ValueError(
          'specified shape {} is inconsistent with that of elements {}'.format(
              element_shape, elements))
    inferred_shape = constant_op.constant(-1)  # unknown shape, by convention
  else:
    inferred_shape = constant_op.constant(-1)  # unknown shape, by convention

  if element_dtype is None:
    element_dtype = inferred_dtype
  if element_shape is None:
    element_shape = inferred_shape

  element_shape = ops.convert_to_tensor(element_shape, dtype=dtypes.int32)
  l = list_ops.empty_tensor_list(
      element_shape=element_shape, element_dtype=element_dtype)
  for el in elements:
    l = list_ops.tensor_list_push_back(l, el)
  return l
def _py_list_new(elements):
  """Overload of new_list that creates a Python list.

  Returns a new list holding the items of `elements`, in iteration order.
  """
  return [*elements]
def list_append(list_, x):
  """Appends `x` to a list-like entity.

  Whether `list_` is mutated in place is unspecified. TensorFlow entities are
  typically not mutated; a new value is returned instead. Plain Python lists
  are mutated. When mutation happens, the return value refers to the original
  entity.

  Args:
    list_: An entity that supports append semantics: a TensorArray, a tensor
      list (a Tensor with dtype tf.variant), or a Python list-like object.
    x: The element to append.

  Returns:
    Same as list_, after the append was performed.

  Raises:
    ValueError: if list_ is a Tensor whose dtype is not tf.variant.
  """
  if isinstance(list_, tensor_array_ops.TensorArray):
    return _tf_tensorarray_append(list_, x)
  if not tensor_util.is_tf_type(list_):
    return _py_list_append(list_, x)
  if list_.dtype != dtypes.variant:
    raise ValueError(
        'tensor lists are expected to be Tensors with dtype=tf.variant,'
        ' instead found %s' % list_)
  return _tf_tensor_list_append(list_, x)
def _tf_tensor_list_append(list_, x):
  """Overload of list_append that stages a Tensor list write.

  If `list_` has length zero at runtime, it is replaced by a new empty tensor
  list whose element shape and dtype are taken from `x`. Otherwise `list_` is
  used as-is. `x` is then pushed to the back of the selected list.
  """

  def _fresh_list_shaped_like_x():
    x_as_tensor = ops.convert_to_tensor(x)
    return list_ops.empty_tensor_list(
        element_shape=array_ops.shape(x_as_tensor),
        element_dtype=x_as_tensor.dtype)

  is_non_empty = list_ops.tensor_list_length(list_) > 0
  target = cond.cond(is_non_empty, lambda: list_, _fresh_list_shaped_like_x)
  return list_ops.tensor_list_push_back(target, x)
def _tf_tensorarray_append(list_, x):
  """Overload of list_append that stages a TensorArray write.

  Writes `x` at index `list_.size()`, just past the current last element, and
  returns the TensorArray produced by the write.
  """
  end_index = list_.size()
  return list_.write(end_index, x)
def _py_list_append(list_, x):
  """Overload of list_append that executes a Python list append.

  `list_` is mutated in place, and the same object is returned.
  """
  list_.append(x)  # Revert to the original call.
  return list_
class ListPopOpts(
    collections.namedtuple('ListPopOpts', 'element_dtype element_shape')):
  """Options for list_pop.

  Attributes:
    element_dtype: dtype of the list elements; must be set to pop from a
      tensor list.
    element_shape: shape of the list elements; must be set to pop from a
      tensor list.
  """
def list_pop(list_, i, opts):
  """Removes and returns an element from a list-like entity.

  Whether `list_` is mutated in place is unspecified. TensorFlow entities are
  typically not mutated; a new value is returned instead. Plain Python lists
  are mutated. When mutation happens, the returned list refers to the original
  entity.

  Args:
    list_: An entity that supports pop semantics.
    i: Optional index to pop from. May be None.
    opts: A ListPopOpts.

  Returns:
    Tuple (out_list_, x):
      out_list_: same as list_, after the removal was performed.
      x: the removed element value.

  Raises:
    ValueError: if list_ is a TensorArray, or a Tensor whose dtype is not
      tf.variant.
  """
  assert isinstance(opts, ListPopOpts)

  if isinstance(list_, tensor_array_ops.TensorArray):
    raise ValueError('TensorArray does not support item removal')
  if not tensor_util.is_tf_type(list_):
    return _py_list_pop(list_, i)
  if list_.dtype != dtypes.variant:
    raise ValueError(
        'tensor lists are expected to be Tensors with dtype=tf.variant,'
        ' instead found %s' % list_)
  return _tf_tensor_list_pop(list_, i, opts)
def _tf_tensor_list_pop(list_, i, opts):
  """Overload of list_pop that stages a Tensor list pop.

  Only removal from the end is supported, and both `opts.element_dtype` and
  `opts.element_shape` must be known.

  Returns:
    Tuple (out_list_, x): the list without its last element, and that element
    with its static shape set to `opts.element_shape`.

  Raises:
    NotImplementedError: if an index `i` is given.
    ValueError: if the element dtype or shape is missing from `opts`.
  """
  if i is not None:
    raise NotImplementedError('tensor lists only support removing from the end')

  dtype, shape = opts.element_dtype, opts.element_shape
  if dtype is None:
    raise ValueError('cannot pop from a list without knowing its element '
                     'type; use set_element_type to annotate it')
  if shape is None:
    raise ValueError('cannot pop from a list without knowing its element '
                     'shape; use set_element_type to annotate it')

  shrunk_list, popped = list_ops.tensor_list_pop_back(
      list_, element_dtype=dtype)
  popped.set_shape(shape)
  return shrunk_list, popped
def _py_list_pop(list_, i):
  """Overload of list_pop that executes a Python list pop.

  Args:
    list_: the Python list to pop from; it is mutated in place.
    i: index of the element to remove, or None to remove the last one.

  Returns:
    Tuple (list_, x), where x is the removed element.
  """
  popped = list_.pop() if i is None else list_.pop(i)
  return list_, popped
# TODO(mdan): Look into reducing duplication between all these containers.
class ListStackOpts(
    collections.namedtuple('ListStackOpts', 'element_dtype original_call')):
  """Options for list_stack.

  Attributes:
    element_dtype: dtype of the list elements; must be set to stack a tensor
      list.
    original_call: callable applied to the list when it is neither a
      TensorArray nor a TensorFlow tensor.
  """
def list_stack(list_, opts):
  """Stacks the elements of a list-like entity.

  This has no direct Python counterpart; the closest idioms are tf.stack or
  np.stack. Unlike those, it accepts a tensor list (or a TensorArray) rather
  than a list of tensors. For anything that is not a TensorFlow entity, the
  call falls back to `opts.original_call`.

  Args:
    list_: A TensorArray, a tensor list (a Tensor with dtype tf.variant), any
      other Tensor (returned unchanged), or a Python list-like value.
    opts: A ListStackOpts object.

  Returns:
    The output of the stack operation, typically a Tensor.
  """
  assert isinstance(opts, ListStackOpts)

  if isinstance(list_, tensor_array_ops.TensorArray):
    return _tf_tensorarray_stack(list_)
  if not tensor_util.is_tf_type(list_):
    return _py_list_stack(list_, opts)
  if list_.dtype == dtypes.variant:
    return _tf_tensor_list_stack(list_, opts)
  # No-op for primitive Tensor arguments.
  return list_
def _tf_tensorarray_stack(list_):
  """Overload of list_stack that stages a TensorArray stack.

  Returns whatever `list_.stack()` produces.
  """
  stacked = list_.stack()
  return stacked
def _tf_tensor_list_stack(list_, opts):
  """Overload of list_stack that stages a Tensor list stack.

  Raises:
    ValueError: if `opts.element_dtype` is None.
  """
  dtype = opts.element_dtype
  if dtype is None:
    raise ValueError('cannot stack a list without knowing its element type;'
                     ' use set_element_type to annotate it')
  return list_ops.tensor_list_stack(list_, element_dtype=dtype)
def _py_list_stack(list_, opts):
  """Overload of list_stack that defers to the original, non-TF call.

  Returns:
    Whatever `opts.original_call(list_)` returns.
  """
  # Revert to the original call.
  fallback = opts.original_call
  return fallback(list_)