Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/operators/slices.py: 40%
52 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Operators specific to slicing operations."""
17import collections
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import tensor_util
21from tensorflow.python.ops import gen_array_ops
22from tensorflow.python.ops import gen_string_ops
23from tensorflow.python.ops import list_ops
24from tensorflow.python.ops import tensor_array_ops
27# TODO(mdan): Support extended slices.
class GetItemOpts(collections.namedtuple('GetItemOpts', ('element_dtype',))):
    """Options accepted by `get_item`.

    Attributes:
      element_dtype: Element type forwarded as `element_dtype` when reading
        from a tensor list (a tensor of variant dtype). A value of None makes
        that read raise ValueError.
    """
def get_item(target, i, opts):
    """The slice read operator (i.e. __getitem__).

    Dispatches on the kind of `target`:
      * TensorArray: staged `read`.
      * Tensor of variant dtype: treated as a tensor list.
      * Scalar (rank 0) string Tensor: staged substring read.
      * Any other Tensor: regular tensor indexing.
      * Anything else: plain Python `target[i]`.

    Args:
      target: An entity that supports getitem semantics.
      i: Index to read from.
      opts: A GetItemOpts object.

    Returns:
      The read element.

    Raises:
      ValueError: if target is a tensor list and opts.element_dtype is None.
    """
    assert isinstance(opts, GetItemOpts)

    # TensorArray is checked first so it never reaches the generic tensor path.
    if isinstance(target, tensor_array_ops.TensorArray):
        return _tf_tensorarray_get_item(target, i)

    if not tensor_util.is_tf_type(target):
        return _py_get_item(target, i)

    if target.dtype == dtypes.variant:
        return _tf_tensor_list_get_item(target, i, opts)
    if target.dtype == dtypes.string and target.shape.ndims == 0:
        return _tf_tensor_string_get_item(target, i)
    return _tf_tensor_get_item(target, i)
def _tf_tensorarray_get_item(target, i):
    """Overload of get_item that stages a TensorArray read.

    Args:
      target: Object exposing a `read(index)` method (a TensorArray).
      i: Index to read from.

    Returns:
      Whatever `target.read(i)` returns.
    """
    return target.read(i)
def _tf_tensor_list_get_item(target, i, opts):
    """Overload of get_item that stages a Tensor list read.

    Args:
      target: The tensor list to read from.
      i: Index to read from.
      opts: Object with an `element_dtype` attribute (a GetItemOpts).

    Returns:
      The result of `list_ops.tensor_list_get_item`.

    Raises:
      ValueError: if opts.element_dtype is None.
    """
    # The read is staged with an explicit element dtype, so refuse early when
    # the caller has not supplied one.
    if opts.element_dtype is None:
        raise ValueError('cannot retrieve from a list without knowing its '
                         'element type; use set_element_type to annotate it')
    return list_ops.tensor_list_get_item(
        target, i, element_dtype=opts.element_dtype)
def _tf_tensor_get_item(target, i):
    """Overload of get_item that stages a Tensor (not Tensor list) read.

    Args:
      target: The tensor to index.
      i: Index to read from.

    Returns:
      The result of `target[i]`.
    """
    return target[i]
def _tf_tensor_string_get_item(target, i):
    """Overload of get_item that stages a Tensor string read.

    Args:
      target: The string tensor to read from.
      i: Start position of the element to read.

    Returns:
      The result of `gen_string_ops.substr(target, i, 1)`, i.e. a substring
      of length 1 starting at position i.
    """
    return gen_string_ops.substr(target, i, 1)
def _py_get_item(target, i):
    """Overload of get_item that executes a plain Python subscript read.

    Args:
      target: Any object supporting `target[i]`.
      i: Index or key to read.

    Returns:
      The result of `target[i]`.
    """
    return target[i]
def set_item(target, i, x):
    """The slice write operator (i.e. __setitem__).

    Dispatches on the kind of `target`:
      * TensorArray: staged `write`.
      * Tensor of variant dtype: treated as a tensor list.
      * Any other Tensor: staged scatter update.
      * Anything else: plain Python `target[i] = x`.

    Note: it is unspecified whether target will be mutated or not. In general,
    if target is mutable (like Python lists), it will be mutated.

    Args:
      target: An entity that supports setitem semantics.
      i: Index to modify.
      x: The new element value.

    Returns:
      The updated target. For plain Python targets this is the same object,
      modified in place; for the TensorFlow cases it is the value returned by
      the staged write/update, which callers should use in place of target.
    """
    # TensorArray is checked first so it never reaches the generic tensor path.
    if isinstance(target, tensor_array_ops.TensorArray):
        return _tf_tensorarray_set_item(target, i, x)

    if not tensor_util.is_tf_type(target):
        return _py_set_item(target, i, x)

    if target.dtype == dtypes.variant:
        return _tf_tensor_list_set_item(target, i, x)
    return _tf_tensor_set_item(target, i, x)
def _tf_tensorarray_set_item(target, i, x):
    """Overload of set_item that stages a TensorArray write.

    Args:
      target: Object exposing a `write(index, value)` method (a TensorArray).
      i: Index to write to.
      x: The value to write.

    Returns:
      Whatever `target.write(i, x)` returns.
    """
    return target.write(i, x)
def _tf_tensor_list_set_item(target, i, x):
    """Overload of set_item that stages a Tensor list update.

    Args:
      target: The tensor list to update.
      i: Index to modify.
      x: The new element value.

    Returns:
      The result of `list_ops.tensor_list_set_item(target, i, x)`.
    """
    return list_ops.tensor_list_set_item(target, i, x)
def _tf_tensor_set_item(target, i, x):
    """Overload of set_item that stages a Tensor scatter update.

    Args:
      target: The tensor to update.
      i: Index to modify.
      x: The new element value.

    Returns:
      The result of `gen_array_ops.tensor_scatter_update`.
    """
    # A single update is expressed as a one-entry batch: one index vector
    # `(i,)` paired with one update value `x`.
    indices = ((i,),)
    updates = (x,)
    return gen_array_ops.tensor_scatter_update(target, indices, updates)
def _py_set_item(target, i, x):
    """Overload of set_item that executes a Python list modification.

    Args:
      target: Any object supporting `target[i] = x`.
      i: Index or key to modify.
      x: The new element value.

    Returns:
      The same `target` object, after it has been modified in place.
    """
    target[i] = x
    return target