# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops to manipulate lists of tensors."""

# pylint: disable=g-bad-name
import numpy as np

from tensorflow.core.framework import full_type_pb2
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_list_ops
from tensorflow.python.ops import handle_data_util
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_list_ops import *
# pylint: enable=wildcard-import


ops.NotDifferentiable("TensorListConcatLists")
ops.NotDifferentiable("TensorListElementShape")
ops.NotDifferentiable("TensorListLength")
ops.NotDifferentiable("TensorListPushBackBatch")


def empty_tensor_list(element_shape,
                      element_dtype,
                      max_num_elements=None,
                      name=None):
  """Creates an empty `TensorList`; `max_num_elements=None` means unbounded."""
  if max_num_elements is None:
    max_num_elements = -1

  return gen_list_ops.empty_tensor_list(
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      max_num_elements=max_num_elements,
      name=name)
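
# Illustrative usage of `empty_tensor_list` (a sketch added here, not part of
# the original module; assumes eager execution):
#
#   l = empty_tensor_list(element_shape=[2], element_dtype=dtypes.float32,
#                         max_num_elements=4)
#   l = tensor_list_push_back(l, [1.0, 2.0])  # via the gen_list_ops wildcard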


def _set_handle_data(list_handle, element_shape, element_dtype):
  """Sets type information on `list_handle` for consistency with graphs."""
  # TODO(b/169968286): It would be better if we had a consistent story for
  # creating handle data from eager operations (shared with VarHandleOp).
  if isinstance(list_handle, ops.EagerTensor):
    if tensor_util.is_tf_type(element_shape):
      element_shape = tensor_shape.TensorShape(None)
    elif not isinstance(element_shape, tensor_shape.TensorShape):
      element_shape = tensor_shape.TensorShape(element_shape)
    handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData()
    handle_data.is_set = True
    # TODO(b/191472076): This duplicates type inference. Clean up.
    handle_data.shape_and_type.append(
        cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType(
            shape=element_shape.as_proto(),
            dtype=element_dtype.as_datatype_enum,
            type=full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_ARRAY)))
    list_handle._handle_data = handle_data  # pylint: disable=protected-access


def tensor_list_reserve(element_shape, num_elements, element_dtype, name=None):
  """Creates a `TensorList` holding `num_elements` elements of `element_shape`."""
  result = gen_list_ops.tensor_list_reserve(
      element_shape=_build_element_shape(element_shape),
      num_elements=num_elements,
      element_dtype=element_dtype,
      name=name)
  # TODO(b/169968286): gen_ops needs to ensure the metadata is properly
  # populated for eager operations.
  _set_handle_data(result, element_shape, element_dtype)
  return result
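
# Illustrative usage (a sketch, not in the original source): reserve three
# scalar slots and fill one.
#
#   l = tensor_list_reserve(element_shape=[], num_elements=3,
#                           element_dtype=dtypes.int32)
#   l = tensor_list_set_item(l, 0, 7)
#   tensor_list_get_item(l, 0, element_dtype=dtypes.int32)  # -> 7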


def tensor_list_from_tensor(tensor, element_shape, name=None):
  """Creates a `TensorList` by unstacking `tensor` along its first dimension."""
  tensor = ops.convert_to_tensor(tensor)
  result = gen_list_ops.tensor_list_from_tensor(
      tensor=tensor,
      element_shape=_build_element_shape(element_shape),
      name=name)
  _set_handle_data(result, tensor.shape, tensor.dtype)
  return result
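
# Illustrative round trip (a sketch, not in the original source):
#
#   t = ops.convert_to_tensor([[1., 2.], [3., 4.]])
#   l = tensor_list_from_tensor(t, element_shape=[2])  # two elements of shape [2]
#   tensor_list_stack(l, element_dtype=dtypes.float32)  # recovers `t`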


def tensor_list_get_item(input_handle, index, element_dtype, element_shape=None,
                         name=None):
  """Returns the element of the list at `index`."""
  return gen_list_ops.tensor_list_get_item(
      input_handle=input_handle,
      index=index,
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      name=name)


def tensor_list_pop_back(input_handle, element_dtype, name=None):
  """Removes the last element; returns the shortened list and that element."""
  return gen_list_ops.tensor_list_pop_back(
      input_handle=input_handle,
      element_shape=-1,
      element_dtype=element_dtype,
      name=name)
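
# Illustrative usage of the pair above (a sketch, not in the original source):
#
#   l = tensor_list_from_tensor([[1., 2.], [3., 4.]], element_shape=[2])
#   tensor_list_get_item(l, 0, element_dtype=dtypes.float32)  # -> [1., 2.]
#   l, last = tensor_list_pop_back(l, element_dtype=dtypes.float32)
#   # last -> [3., 4.]; `l` now holds a single element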


def tensor_list_gather(input_handle,
                       indices,
                       element_dtype,
                       element_shape=None,
                       name=None):
  """Stacks the list elements at `indices` into a single tensor."""
  return gen_list_ops.tensor_list_gather(
      input_handle=input_handle,
      indices=indices,
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      name=name)
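
# Illustrative usage (a sketch, not in the original source):
#
#   l = tensor_list_from_tensor([[1., 2.], [3., 4.], [5., 6.]],
#                               element_shape=[2])
#   tensor_list_gather(l, [2, 0], element_dtype=dtypes.float32)
#   # -> [[5., 6.], [1., 2.]]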


def tensor_list_scatter(tensor,
                        indices,
                        element_shape=None,
                        input_handle=None,
                        name=None):
  """Returns a TensorList created or updated by scattering `tensor`."""
  tensor = ops.convert_to_tensor(tensor)
  if input_handle is not None:
    output_handle = gen_list_ops.tensor_list_scatter_into_existing_list(
        input_handle=input_handle, tensor=tensor, indices=indices, name=name)
    handle_data_util.copy_handle_data(input_handle, output_handle)
    return output_handle
  else:
    output_handle = gen_list_ops.tensor_list_scatter_v2(
        tensor=tensor,
        indices=indices,
        element_shape=_build_element_shape(element_shape),
        num_elements=-1,
        name=name)
    _set_handle_data(output_handle, element_shape, tensor.dtype)
    return output_handle
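
# Illustrative usage (a sketch, not in the original source): scattering without
# `input_handle` creates a fresh list; passing one updates an existing list.
#
#   l = tensor_list_scatter([[1., 2.], [3., 4.]], indices=[1, 0],
#                           element_shape=[2])
#   tensor_list_stack(l, element_dtype=dtypes.float32)
#   # -> [[3., 4.], [1., 2.]]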


def tensor_list_stack(input_handle,
                      element_dtype,
                      num_elements=-1,
                      element_shape=None,
                      name=None):
  """Stacks all list elements into a tensor with an added leading dimension."""
  return gen_list_ops.tensor_list_stack(
      input_handle=input_handle,
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      num_elements=num_elements,
      name=name)
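
# Illustrative usage (a sketch, not in the original source): `num_elements`
# can pin the static leading dimension when it is known.
#
#   l = tensor_list_from_tensor([[1.], [2.]], element_shape=[1])
#   tensor_list_stack(l, element_dtype=dtypes.float32, num_elements=2)
#   # -> [[1.], [2.]], with a static leading dim of 2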


def tensor_list_concat(input_handle, element_dtype, element_shape=None,
                       name=None):
  """Concatenates the list elements along their leading dimension."""
  # Ignore the lengths output of TensorListConcat. It is only used during
  # gradient computation.
  return gen_list_ops.tensor_list_concat_v2(
      input_handle=input_handle,
      element_dtype=element_dtype,
      element_shape=_build_element_shape(element_shape),
      leading_dims=ops.convert_to_tensor([], dtype=dtypes.int64),
      name=name)[0]
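
# Illustrative contrast with `tensor_list_stack` (a sketch, not in the
# original source): stack adds a leading dimension, concat joins along the
# elements' existing first dimension.
#
#   l = tensor_list_from_tensor([[1., 2.], [3., 4.]], element_shape=[2])
#   tensor_list_concat(l, element_dtype=dtypes.float32)  # -> [1., 2., 3., 4.]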


def tensor_list_split(tensor, element_shape, lengths, name=None):
  """Splits `tensor` along dimension 0 into a list of pieces of sizes `lengths`."""
  return gen_list_ops.tensor_list_split(
      tensor=tensor,
      element_shape=_build_element_shape(element_shape),
      lengths=lengths,
      name=name)
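
# Illustrative usage (a sketch, not in the original source): the inverse of
# `tensor_list_concat` when the piece lengths are known.
#
#   l = tensor_list_split([1., 2., 3.], element_shape=[-1], lengths=[1, 2])
#   tensor_list_get_item(l, 1, element_dtype=dtypes.float32)  # -> [2., 3.]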


def tensor_list_set_item(input_handle,
                         index,
                         item,
                         resize_if_index_out_of_bounds=False,
                         name=None):
  """Sets `item` at `index` in the input list."""
  output_handle = gen_list_ops.tensor_list_set_item(
      input_handle=input_handle,
      index=index,
      item=item,
      name=name,
      resize_if_index_out_of_bounds=resize_if_index_out_of_bounds,
  )
  handle_data_util.copy_handle_data(input_handle, output_handle)
  return output_handle
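
# Illustrative usage (a sketch, not in the original source): with
# `resize_if_index_out_of_bounds=True`, writing past the end grows the list.
#
#   l = tensor_list_reserve(element_shape=[], num_elements=1,
#                           element_dtype=dtypes.float32)
#   l = tensor_list_set_item(l, 4, 5.0, resize_if_index_out_of_bounds=True)
#   tensor_list_get_item(l, 4, element_dtype=dtypes.float32)  # -> 5.0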


@ops.RegisterGradient("TensorListPushBack")
def _PushBackGrad(op, dresult):
  return gen_list_ops.tensor_list_pop_back(
      dresult,
      element_shape=array_ops.shape(op.inputs[1]),
      element_dtype=op.get_attr("element_dtype"))


@ops.RegisterGradient("TensorListPopBack")
def _PopBackGrad(op, dlist, delement):
  if dlist is None:
    dlist = empty_tensor_list(
        element_dtype=delement.dtype,
        element_shape=gen_list_ops.tensor_list_element_shape(
            op.outputs[0], shape_type=dtypes.int32))
  if delement is None:
    delement = array_ops.zeros_like(op.outputs[1])
  return gen_list_ops.tensor_list_push_back(dlist, delement), None


@ops.RegisterGradient("TensorListStack")
def _TensorListStackGrad(unused_op, dtensor):
  return tensor_list_from_tensor(dtensor, element_shape=dtensor.shape[1:]), None
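
# Illustrative gradient flow through the registrations above (a sketch, not in
# the original source; assumes the public `tensorflow` API is importable):
#
#   import tensorflow as tf
#   x = tf.constant([[1.], [2.]])
#   with tf.GradientTape() as tape:
#     tape.watch(x)
#     l = tensor_list_from_tensor(x, element_shape=[1])
#     y = tensor_list_stack(l, element_dtype=tf.float32)
#   tape.gradient(y, x)  # -> ones_like(x), via _TensorListStackGrad and
#                        #    _TensorListFromTensorGrad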


@ops.RegisterGradient("TensorListConcat")
@ops.RegisterGradient("TensorListConcatV2")
def _TensorListConcatGrad(op, dtensor, unused_dlengths):
  """Gradient function for TensorListConcat."""
  dlist = tensor_list_split(
      dtensor,
      element_shape=gen_list_ops.tensor_list_element_shape(
          op.inputs[0], shape_type=dtypes.int32),
      lengths=op.outputs[1])
  if op.type == "TensorListConcatV2":
    return dlist, None, None
  else:
    return dlist


@ops.RegisterGradient("TensorListSplit")
def _TensorListSplitGrad(op, dlist):
  tensor, _, lengths = op.inputs
  element_shape = array_ops.slice(array_ops.shape(tensor), [1], [-1])
  element_shape = array_ops.concat([[-1], element_shape], axis=0)
  return gen_list_ops.tensor_list_concat_v2(
      dlist,
      element_shape=element_shape,
      leading_dims=lengths,
      element_dtype=op.inputs[0].dtype)[0], None, None


@ops.RegisterGradient("TensorListFromTensor")
def _TensorListFromTensorGrad(op, dlist):
  """Gradient for TensorListFromTensor."""
  t = op.inputs[0]
  if t.shape.dims and t.shape.dims[0].value is not None:
    num_elements = t.shape.dims[0].value
  else:
    num_elements = None
  if dlist is None:
    dlist = empty_tensor_list(
        element_dtype=t.dtype,
        element_shape=gen_list_ops.tensor_list_element_shape(
            op.outputs[0], shape_type=dtypes.int32))
  tensor_grad = gen_list_ops.tensor_list_stack(
      dlist,
      element_shape=array_ops.slice(array_ops.shape(t), [1], [-1]),
      element_dtype=t.dtype,
      num_elements=num_elements)
  shape_grad = None
  return tensor_grad, shape_grad


@ops.RegisterGradient("TensorListGetItem")
def _TensorListGetItemGrad(op, ditem):
  """Gradient for TensorListGetItem."""
  list_size = gen_list_ops.tensor_list_length(op.inputs[0])
  list_grad = gen_list_ops.tensor_list_set_item(
      gen_list_ops.tensor_list_reserve(
          gen_list_ops.tensor_list_element_shape(op.inputs[0],
                                                 shape_type=dtypes.int32),
          list_size, element_dtype=ditem.dtype),
      index=op.inputs[1],
      item=ditem)
  index_grad = None
  element_shape_grad = None
  return list_grad, index_grad, element_shape_grad


@ops.RegisterGradient("TensorListSetItem")
def _TensorListSetItemGrad(op, dlist):
  """Gradient function for TensorListSetItem."""
  input_list, index, item = op.inputs
  list_grad = gen_list_ops.tensor_list_set_item(
      dlist, index=index, item=array_ops.zeros_like(item))
  index_grad = None
  element_grad = tensor_list_get_item(
      dlist,
      index,
      element_shape=array_ops.shape(item),
      element_dtype=item.dtype)
  if op.get_attr("resize_if_index_out_of_bounds"):
    input_list_size = gen_list_ops.tensor_list_length(input_list)
    list_grad = gen_list_ops.tensor_list_resize(list_grad, input_list_size)
  return list_grad, index_grad, element_grad


@ops.RegisterGradient("TensorListResize")
def _TensorListResizeGrad(op, dlist):
  input_list, _ = op.inputs
  input_list_size = gen_list_ops.tensor_list_length(input_list)
  return gen_list_ops.tensor_list_resize(dlist, input_list_size), None


@ops.RegisterGradient("TensorListGather")
def _TensorListGatherGrad(op, dtensor):
  """Gradient function for TensorListGather."""
  input_list, indices, _ = op.inputs
  element_shape = gen_list_ops.tensor_list_element_shape(
      input_list, shape_type=dtypes.int32)
  num_elements = gen_list_ops.tensor_list_length(input_list)
  dlist = tensor_list_reserve(element_shape, num_elements, dtensor.dtype)
  dlist = tensor_list_scatter(
      tensor=dtensor, indices=indices, input_handle=dlist)
  return dlist, None, None


@ops.RegisterGradient("TensorListScatter")
@ops.RegisterGradient("TensorListScatterV2")
def _TensorListScatterGrad(op, dlist):
  """Gradient function for TensorListScatter."""
  tensor = op.inputs[0]
  indices = op.inputs[1]
  dtensor = gen_list_ops.tensor_list_gather(
      dlist,
      indices,
      element_shape=array_ops.slice(array_ops.shape(tensor), [1], [-1]),
      element_dtype=tensor.dtype)
  if op.type == "TensorListScatterV2":
    return dtensor, None, None, None
  else:
    return dtensor, None, None


@ops.RegisterGradient("TensorListScatterIntoExistingList")
def _TensorListScatterIntoExistingListGrad(op, dlist):
  """Gradient function for TensorListScatterIntoExistingList."""
  _, tensor, indices = op.inputs
  dtensor = gen_list_ops.tensor_list_gather(
      dlist,
      indices,
      element_shape=array_ops.slice(array_ops.shape(tensor), [1], [-1]),
      element_dtype=tensor.dtype)
  # Zero out the scattered slots: the forward op overwrote them, so the
  # original input list receives no gradient at those indices.
  zeros = array_ops.zeros_like(tensor)
  dlist = tensor_list_scatter(zeros, indices, input_handle=dlist)
  return dlist, dtensor, None


def _build_element_shape(shape):
  """Converts shape to a format understood by list_ops for element_shape.

  If `shape` is already a `Tensor` it is returned as-is. We do not perform a
  type check here.

  If shape is None or a TensorShape with unknown rank, -1 is returned.

  If shape represents a scalar element (i.e. is an empty list), an empty int32
  tensor is returned. Note we do not directly return the empty list, since
  ops.convert_to_tensor would convert it to a float32 tensor, which is not a
  valid type for element_shape.

  If shape is a sequence of dims, None's in the list are replaced with -1. We
  do not check the dtype of the other dims.

  Args:
    shape: Could be None, Tensor, TensorShape or a list of dims (each dim could
      be a None, scalar or Tensor).

  Returns:
    A None-free shape that can be converted to a tensor.
  """
  if isinstance(shape, ops.Tensor):
    return shape
  if isinstance(shape, tensor_shape.TensorShape):
    # `TensorShape.as_list` requires rank to be known.
    shape = shape.as_list() if shape else None
  # Shape is unknown.
  if shape is None:
    return -1
  # Shape is a numpy array or a scalar.
  if isinstance(shape, (np.ndarray, np.generic)) or not shape:
    return ops.convert_to_tensor(shape, dtype=dtypes.int32)
  # Shape is a sequence of dimensions. Convert None dims to -1.
  def convert(val):
    if val is None:
      return -1
    if isinstance(val, ops.Tensor):
      return val
    if isinstance(val, tensor_shape.Dimension):
      return val.value if val.value is not None else -1
    return val

  return [convert(d) for d in shape]
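
# Illustrative inputs and outputs (a sketch, not in the original source):
#
#   _build_element_shape(None)                             # -> -1
#   _build_element_shape(tensor_shape.TensorShape(None))   # -> -1
#   _build_element_shape([])         # -> empty int32 tensor (scalar elements)
#   _build_element_shape([None, 3])  # -> [-1, 3]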