Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/ragged_factory_ops.py: 19%
124 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Operations for constructing RaggedTensors."""
17import numpy as np
19from tensorflow.python.framework import constant_op
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import tensor_shape
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops.ragged import ragged_tensor
25from tensorflow.python.ops.ragged import ragged_tensor_value
26from tensorflow.python.util import dispatch
27from tensorflow.python.util.tf_export import tf_export
30#===============================================================================
31# Op to construct a constant RaggedTensor from a nested Python list.
32#===============================================================================
@tf_export("ragged.constant")
@dispatch.add_dispatch_support
def constant(pylist, dtype=None, ragged_rank=None, inner_shape=None,
             name=None, row_splits_dtype=dtypes.int64):
  """Builds a constant `RaggedTensor` out of a nested Python list.

  Example:

  >>> tf.ragged.constant([[1, 2], [3], [4, 5, 6]])
  <tf.RaggedTensor [[1, 2], [3], [4, 5, 6]]>

  Every scalar value in `pylist` must sit at the same nesting depth `K`; the
  returned `RaggedTensor` then has rank `K`.  When `pylist` holds no scalar
  values at all, `K` is one greater than the maximum depth of empty lists in
  `pylist`.  All scalar values in `pylist` must be compatible with `dtype`.

  Args:
    pylist: A nested `list`, `tuple` or `np.ndarray`.  Any nested element that
      is not a `list`, `tuple` or `np.ndarray` must be a scalar value
      compatible with `dtype`.
    dtype: The type of elements for the returned `RaggedTensor`.  If not
      specified, then a default is chosen based on the scalar values in
      `pylist`.
    ragged_rank: An integer specifying the ragged rank of the returned
      `RaggedTensor`.  Must be nonnegative and less than `K`.  Defaults to
      `max(0, K - 1)` if `inner_shape` is not specified.  Defaults to
      `max(0, K - 1 - len(inner_shape))` if `inner_shape` is specified.
    inner_shape: A tuple of integers specifying the shape for individual inner
      values in the returned `RaggedTensor`.  Defaults to `()` if
      `ragged_rank` is not specified.  If `ragged_rank` is specified, then a
      default is chosen based on the contents of `pylist`.
    name: A name prefix for the returned tensor (optional).
    row_splits_dtype: data type for the constructed `RaggedTensor`'s
      row_splits.  One of `tf.int32` or `tf.int64`.

  Returns:
    A potentially ragged tensor with rank `K` and the specified `ragged_rank`,
    containing the values from `pylist`.

  Raises:
    ValueError: If the scalar values in `pylist` have inconsistent nesting
      depth; or if ragged_rank or inner_shape are incompatible with `pylist`.
  """

  def ragged_factory(values, row_splits):
    # Turn the Python list of split offsets into a constant tensor of the
    # requested dtype before wrapping `values` with it.
    splits_tensor = constant_op.constant(row_splits, dtype=row_splits_dtype)
    return ragged_tensor.RaggedTensor.from_row_splits(
        values, splits_tensor, validate=False)

  with ops.name_scope(name, "RaggedConstant"):
    result = _constant_value(ragged_factory, constant_op.constant, pylist,
                             dtype, ragged_rank, inner_shape)
  return result
@tf_export(v1=["ragged.constant_value"])
@dispatch.add_dispatch_support
def constant_value(pylist, dtype=None, ragged_rank=None, inner_shape=None,
                   row_splits_dtype="int64"):
  """Builds a `RaggedTensorValue` out of a nested Python list.

  Warning: This function returns a `RaggedTensorValue`, not a `RaggedTensor`.
  If you wish to construct a constant `RaggedTensor`, use
  [`ragged.constant(...)`](constant.md) instead.

  Example:

  >>> tf.compat.v1.ragged.constant_value([[1, 2], [3], [4, 5, 6]])
  tf.RaggedTensorValue(values=array([1, 2, 3, 4, 5, 6]),
                       row_splits=array([0, 2, 3, 6]))

  Every scalar value in `pylist` must sit at the same nesting depth `K`; the
  returned `RaggedTensorValue` then has rank `K`.  When `pylist` holds no
  scalar values at all, `K` is one greater than the maximum depth of empty
  lists in `pylist`.  All scalar values in `pylist` must be compatible with
  `dtype`.

  Args:
    pylist: A nested `list`, `tuple` or `np.ndarray`.  Any nested element that
      is not a `list`, `tuple` or `np.ndarray` must be a scalar value
      compatible with `dtype`.
    dtype: `numpy.dtype`.  The type of elements for the returned
      `RaggedTensorValue`.  A `DType` instance is converted with
      `as_numpy_dtype`.  If not specified, then a default is chosen based on
      the scalar values in `pylist`.
    ragged_rank: An integer specifying the ragged rank of the returned
      `RaggedTensorValue`.  Must be nonnegative and less than `K`.  Defaults
      to `max(0, K - 1)` if `inner_shape` is not specified.  Defaults to
      `max(0, K - 1 - len(inner_shape))` if `inner_shape` is specified.
    inner_shape: A tuple of integers specifying the shape for individual inner
      values in the returned `RaggedTensorValue`.  Defaults to `()` if
      `ragged_rank` is not specified.  If `ragged_rank` is specified, then a
      default is chosen based on the contents of `pylist`.
    row_splits_dtype: data type for the constructed `RaggedTensorValue`'s
      row_splits.  One of `numpy.int32` or `numpy.int64`.

  Returns:
    A `tf.RaggedTensorValue` or `numpy.array` with rank `K` and the specified
    `ragged_rank`, containing the values from `pylist`.

  Raises:
    ValueError: If the scalar values in `pylist` have inconsistent nesting
      depth; or if ragged_rank or inner_shape are incompatible with `pylist`.
  """
  # isinstance() is already False for None, so no separate None check is
  # needed before converting a DType to its numpy equivalent.
  if isinstance(dtype, dtypes.DType):
    dtype = dtype.as_numpy_dtype
  splits_np_dtype = dtypes.as_dtype(row_splits_dtype).as_numpy_dtype

  def _ragged_factory(values, row_splits):
    splits_array = np.array(row_splits, dtype=splits_np_dtype)
    return ragged_tensor_value.RaggedTensorValue(values, splits_array)

  def _inner_factory(pylist, dtype, shape, name=None):  # pylint: disable=unused-argument
    as_array = np.array(pylist, dtype=dtype)
    return np.reshape(as_array, shape)

  return _constant_value(_ragged_factory, _inner_factory, pylist, dtype,
                         ragged_rank, inner_shape)
def _constant_value(ragged_factory, inner_factory, pylist, dtype, ragged_rank,
                    inner_shape):
  """Constructs a constant RaggedTensor or RaggedTensorValue.

  Args:
    ragged_factory: A factory function with the signature:
      `ragged_factory(values, row_splits)`
    inner_factory: A factory function with the signature: `inner_factory(pylist,
      dtype, shape, name)`
    pylist: A nested `list`, `tuple` or `np.ndarray`.
    dtype: Data type for returned value.
    ragged_rank: Ragged rank for returned value.
    inner_shape: Inner value shape for returned value.

  Returns:
    A value returned by `ragged_factory` or `inner_factory`.

  Raises:
    TypeError: If `ragged_tensor.is_ragged(pylist)` is true, i.e. `pylist` is
      already a RaggedTensor or RaggedTensorValue.
    ValueError: If the scalar values in `pylist` have inconsistent nesting
      depth; or if ragged_rank or inner_shape are incompatible with `pylist`.
  """
  if ragged_tensor.is_ragged(pylist):
    raise TypeError("pylist may not be a RaggedTensor or RaggedTensorValue.")
  # np.ndim builds an array, so we short-circuit lists and tuples.
  if not isinstance(pylist, (list, tuple)) and np.ndim(pylist) == 0:
    # Scalar value: only a ragged_rank of 0 and an empty inner_shape make sense.
    if ragged_rank is not None and ragged_rank != 0:
      raise ValueError("Invalid pylist=%r: incompatible with ragged_rank=%d" %
                       (pylist, ragged_rank))
    if inner_shape is not None and inner_shape:
      raise ValueError(
          "Invalid pylist=%r: incompatible with dim(inner_shape)=%d" %
          (pylist, len(inner_shape)))
    return inner_factory(pylist, dtype, ())

  if ragged_rank is not None and ragged_rank < 0:
    raise ValueError(
        "Invalid ragged_rank=%r: must be nonnegative" % ragged_rank)

  # Find the depth of scalar values in `pylist`.
  scalar_depth, max_depth = _find_scalar_and_max_depth(pylist)
  if scalar_depth is not None:
    if max_depth > scalar_depth:
      # `pylist` is wrapped in a 1-tuple: a bare tuple on the right of `%`
      # would be unpacked as the argument list and raise TypeError instead of
      # the intended ValueError.
      raise ValueError("Invalid pylist=%r: empty list nesting is greater "
                       "than scalar value nesting" % (pylist,))
    if ragged_rank is not None and max_depth < ragged_rank:
      raise ValueError(f"Invalid pylist={pylist}, max depth smaller than "
                       f"ragged_rank={ragged_rank}")

  # If both inner_shape and ragged_rank were specified, then check that
  # they are compatible with pylist.
  if inner_shape is not None and ragged_rank is not None:
    expected_depth = ragged_rank + len(inner_shape) + 1
    if ((scalar_depth is not None and expected_depth != scalar_depth) or
        (scalar_depth is None and expected_depth < max_depth)):
      raise ValueError(
          "Invalid pylist=%r: incompatible with ragged_rank=%d "
          "and dim(inner_shape)=%d" % (pylist, ragged_rank, len(inner_shape)))

  # Check if the result is a `Tensor`: either ragged_rank is explicitly 0, or
  # it is unspecified and there are fewer than two levels of nesting left
  # once the inner_shape dimensions are accounted for.
  if (ragged_rank == 0 or
      (ragged_rank is None and
       ((max_depth < 2) or
        (inner_shape is not None and max_depth - len(inner_shape) < 2)))):
    return inner_factory(pylist, dtype, inner_shape)

  # Compute default value for inner_shape.
  if inner_shape is None:
    if ragged_rank is None:
      inner_shape = ()
    else:
      inner_shape = _default_inner_shape_for_pylist(pylist, ragged_rank)

  # Compute default value for ragged_rank.
  if ragged_rank is None:
    if scalar_depth is None:
      ragged_rank = max(1, max_depth - 1)
    else:
      ragged_rank = max(1, scalar_depth - 1 - len(inner_shape))

  # Build the splits for each ragged rank, and concatenate the inner values
  # into a single list.
  nested_splits = []
  values = pylist
  for dim in range(ragged_rank):
    nested_splits.append([0])
    concatenated_values = []
    for row in values:
      # Each split is the running total of row lengths seen so far.
      nested_splits[dim].append(nested_splits[dim][-1] + len(row))
      concatenated_values.extend(row)
    values = concatenated_values

  values = inner_factory(
      values, dtype=dtype, shape=(len(values),) + inner_shape, name="values")
  # Wrap from the innermost ragged dimension outwards.
  for row_splits in reversed(nested_splits):
    values = ragged_factory(values, row_splits)
  return values
245def _find_scalar_and_max_depth(pylist):
246 """Finds nesting depth of scalar values in pylist.
248 Args:
249 pylist: A nested python `list` or `tuple`.
251 Returns:
252 A tuple `(scalar_depth, max_depth)`. `scalar_depth` is the nesting
253 depth of scalar values in `pylist`, or `None` if `pylist` contains no
254 scalars. `max_depth` is the maximum depth of `pylist` (including
255 empty lists).
257 Raises:
258 ValueError: If pylist has inconsistent nesting depths for scalars.
259 """
260 # Check if pylist is not scalar. np.ndim builds an array, so we
261 # short-circuit lists and tuples.
262 if isinstance(pylist, (list, tuple)) or np.ndim(pylist) != 0:
263 scalar_depth = None
264 max_depth = 1
265 for child in pylist:
266 child_scalar_depth, child_max_depth = _find_scalar_and_max_depth(child)
267 if child_scalar_depth is not None:
268 if scalar_depth is not None and scalar_depth != child_scalar_depth + 1:
269 raise ValueError("all scalar values must have the same nesting depth")
270 scalar_depth = child_scalar_depth + 1
271 max_depth = max(max_depth, child_max_depth + 1)
272 return (scalar_depth, max_depth)
273 return (0, 0)
276def _default_inner_shape_for_pylist(pylist, ragged_rank):
277 """Computes a default inner shape for the given python list."""
279 def get_inner_shape(item):
280 """Returns the inner shape for a python list `item`."""
281 if not isinstance(item, (list, tuple)) and np.ndim(item) == 0:
282 return ()
283 # Note that we need this check here in case `item` is not a Python list but
284 # fakes as being one (pylist). For a scenario of this, see test added in
285 # https://github.com/tensorflow/tensorflow/pull/48945
286 elif len(item) > 0: # pylint: disable=g-explicit-length-test
287 return (len(item),) + get_inner_shape(item[0])
288 return (0,)
290 def check_inner_shape(item, shape):
291 """Checks that `item` has a consistent shape matching `shape`."""
292 is_nested = isinstance(item, (list, tuple)) or np.ndim(item) != 0
293 if is_nested != bool(shape):
294 raise ValueError("inner values have inconsistent shape")
295 if is_nested:
296 if shape[0] != len(item):
297 raise ValueError("inner values have inconsistent shape")
298 for child in item:
299 check_inner_shape(child, shape[1:])
301 # Collapse the ragged layers to get the list of inner values.
302 flat_values = pylist
303 for dim in range(ragged_rank):
304 if not all(
305 isinstance(v, (list, tuple)) or np.ndim(v) != 0 for v in flat_values):
306 raise ValueError("pylist has scalar values depth %d, but ragged_rank=%d "
307 "requires scalar value depth greater than %d" %
308 (dim + 1, ragged_rank, ragged_rank))
309 flat_values = sum((list(v) for v in flat_values), [])
311 # Compute the inner shape looking only at the leftmost elements; and then
312 # use check_inner_shape to verify that other elements have the same shape.
313 inner_shape = get_inner_shape(flat_values)
314 check_inner_shape(flat_values, inner_shape)
315 return inner_shape[1:]
@tf_export(v1=["ragged.placeholder"])
@dispatch.add_dispatch_support
def placeholder(dtype, ragged_rank, value_shape=None, name=None):
  """Creates a placeholder for a `tf.RaggedTensor` that will always be fed.

  **Important**: This ragged tensor will produce an error if evaluated.
  Its value must be fed using the `feed_dict` optional argument to
  `Session.run()`, `Tensor.eval()`, or `Operation.run()`.


  Args:
    dtype: The data type for the `RaggedTensor`.
    ragged_rank: The ragged rank for the `RaggedTensor`.  When it is `0`, a
      plain (non-ragged) placeholder is returned.
    value_shape: The shape for individual flat values in the `RaggedTensor`.
    name: A name for the operation (optional).

  Returns:
    A `RaggedTensor` that may be used as a handle for feeding a value, but
    not evaluated directly.

  Raises:
    RuntimeError: if eager execution is enabled

  @compatibility(TF2)
  This API is not compatible with eager execution and `tf.function`. To migrate
  to TF2, rewrite the code to be compatible with eager execution. Check the
  [migration
  guide](https://www.tensorflow.org/guide/migrate#1_replace_v1sessionrun_calls)
  on replacing `Session.run` calls. In TF2, you can just pass tensors directly
  into ops and layers. If you want to explicitly set up your inputs, also see
  [Keras functional API](https://www.tensorflow.org/guide/keras/functional) on
  how to use `tf.keras.Input` to replace `tf.compat.v1.ragged.placeholder`.
  `tf.function` arguments also do the job of `tf.compat.v1.ragged.placeholder`.
  For more details please read [Better
  performance with tf.function](https://www.tensorflow.org/guide/function).
  @end_compatibility
  """
  if ragged_rank == 0:
    return array_ops.placeholder(dtype, shape=value_shape, name=name)

  with ops.name_scope(name, "RaggedPlaceholder", []):
    # The flat values get an unknown leading dimension followed by the
    # per-value shape.
    leading_dim = tensor_shape.TensorShape([None])
    ragged = array_ops.placeholder(
        dtype, shape=leading_dim.concatenate(value_shape), name="flat_values")
    # Wrap from the innermost ragged dimension outwards, one row_splits
    # placeholder per ragged dimension.
    for dim_index in range(ragged_rank - 1, -1, -1):
      splits = array_ops.placeholder(
          dtypes.int64, shape=[None], name="row_splits_%d" % dim_index)
      ragged = ragged_tensor.RaggedTensor.from_row_splits(
          ragged, splits, validate=False)
    return ragged