Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/structured/structured_array_ops.py: 25%
228 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""StructuredTensor array ops."""
17from typing import Sequence
19from tensorflow.python.framework import constant_op
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import ops
22from tensorflow.python.ops import array_ops
23from tensorflow.python.ops import math_ops
24from tensorflow.python.ops import random_ops
25from tensorflow.python.ops.ragged import dynamic_ragged_shape
26from tensorflow.python.ops.ragged import ragged_tensor
27from tensorflow.python.ops.ragged.row_partition import RowPartition
28from tensorflow.python.ops.structured.structured_tensor import StructuredTensor
29from tensorflow.python.util import deprecation
30from tensorflow.python.util import dispatch
@dispatch.dispatch_for_api(array_ops.shape_v2)
def shape_v2(input: StructuredTensor, out_type=dtypes.int32,  # pylint: disable=redefined-builtin
             name=None) -> dynamic_ragged_shape.DynamicRaggedShape:
  """Computes the shape of a StructuredTensor as a DynamicRaggedShape.

  Args:
    input: the StructuredTensor whose shape is requested.
    out_type: the dtype the returned shape is converted to via `with_dtype`.
    name: unused; accepted only to match the dispatched API signature.

  Returns:
    The ragged shape stored on `input`, converted to `out_type`.
  """
  del name
  ragged_shape = input._ragged_shape  # pylint: disable=protected-access
  return ragged_shape.with_dtype(out_type)
@dispatch.dispatch_for_api(array_ops.shape)
def shape_v1(input: StructuredTensor, name=None,  # pylint: disable=redefined-builtin
             out_type=dtypes.int32) -> dynamic_ragged_shape.DynamicRaggedShape:
  """Computes the shape of a StructuredTensor as a DynamicRaggedShape.

  Same behavior as `shape_v2`; only the order of `name` and `out_type` in
  the signature differs.

  Args:
    input: the StructuredTensor whose shape is requested.
    name: unused; accepted only to match the dispatched API signature.
    out_type: the dtype the returned shape is converted to via `with_dtype`.

  Returns:
    The ragged shape stored on `input`, converted to `out_type`.
  """
  del name
  ragged_shape = input._ragged_shape  # pylint: disable=protected-access
  return ragged_shape.with_dtype(out_type)
@dispatch.dispatch_for_types(array_ops.expand_dims, StructuredTensor)
@deprecation.deprecated_args(None, 'Use the `axis` argument instead', 'dim')
def expand_dims(input, axis=None, name=None, dim=None):  # pylint: disable=redefined-builtin
  """Inserts a length-1 axis into a StructuredTensor at index `axis`.

  StructuredTensor implementation of tf.expand_dims (v1 signature, which
  still accepts the deprecated `dim` argument). `axis` may not exceed the
  rank of `input`.

  >>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]])
  >>> tf.expand_dims(st, 0).to_pyval()
  [[[{'x': 1}, {'x': 2}], [{'x': 3}]]]
  >>> tf.expand_dims(st, 1).to_pyval()
  [[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, 2).to_pyval()
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, -1).to_pyval()  # -1 is the same as 2
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]

  Args:
    input: the original StructuredTensor.
    axis: where to insert the dimension: `-(rank + 1) <= axis <= rank`
    name: the name of the op.
    dim: deprecated alias of `axis`.

  Returns:
    a new structured tensor whose rank is one larger.

  Raises:
    an error if `axis < -(rank + 1)` or `rank < axis`.
  """
  # Collapse the deprecated `dim` spelling and `axis` into a single value.
  resolved_axis = deprecation.deprecated_argument_lookup(
      'axis', axis, 'dim', dim)
  return _expand_dims_impl(input, resolved_axis, name=name)
@dispatch.dispatch_for_types(array_ops.expand_dims_v2, StructuredTensor)
def expand_dims_v2(input, axis, name=None):  # pylint: disable=redefined-builtin
  """Inserts a length-1 axis into a StructuredTensor at index `axis`.

  StructuredTensor implementation of tf.expand_dims (v2 signature). `axis`
  may not exceed the rank of `input`.

  >>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]])
  >>> tf.expand_dims(st, 0).to_pyval()
  [[[{'x': 1}, {'x': 2}], [{'x': 3}]]]
  >>> tf.expand_dims(st, 1).to_pyval()
  [[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, 2).to_pyval()
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, -1).to_pyval()  # -1 is the same as 2
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]

  Args:
    input: the original StructuredTensor.
    axis: where to insert the dimension: `-(rank + 1) <= axis <= rank`
    name: the name of the op.

  Returns:
    a new structured tensor whose rank is one larger.

  Raises:
    an error if `axis < -(rank + 1)` or `rank < axis`.
  """
  expanded = _expand_dims_impl(input, axis, name=name)
  return expanded
@dispatch.dispatch_for_types(array_ops.gather, StructuredTensor)
def gather(params,
           indices,
           validate_indices=None,
           name=None,
           axis=None,
           batch_dims=0):
  """tf.gather for structured tensors.

  Checks on illegal axis values, et cetera, are not (yet) supported.

  `indices` must be a ragged or dense tensor.

  Args:
    params: a structured tensor to be gathered
    indices: a ragged tensor or tensor to gather by.
    validate_indices: whether to validate the indices
    name: the name of the op(s); defaults to 'gather'.
    axis: the axis in params to gather on; defaults to `batch_dims`.
    batch_dims: the number of batch dimensions.

  Returns:
    the params reorganized according to indices.
  """
  scope_name = 'gather' if name is None else name
  with ops.name_scope(scope_name):
    requested_axis = batch_dims if axis is None else axis
    gather_axis = array_ops.get_positive_axis(
        requested_axis, params.shape.rank, ndims_name='params.shape.rank')
    gather_indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')

    def gather_leaf(leaf):
      # Applied to every non-structured field reached by the traversal.
      return array_ops.gather(
          leaf,
          gather_indices,
          validate_indices=validate_indices,
          axis=gather_axis,
          batch_dims=batch_dims,
          name=None)

    return _extend_op_single(params, gather_leaf)
@dispatch.dispatch_for_types(array_ops.concat, StructuredTensor)
def concat(values, axis, name: str = 'concat'):
  """tf.concat for structured tensors.

  Does not support (yet) checks on illegal axis values, et cetera.

  Args:
    values: a sequence of StructuredTensors. They are validated with
      `_assert_concat_compatible_structured_tensors` before any op is built.
    axis: an axis to concatenate upon. Negative values are normalized against
      the rank of the first element of `values`.
    name: the name of the op(s); `None` is treated as 'concat'.

  Returns:
    the result of concatenating every field of `values` along `axis`.

  Raises:
    ValueError: if `values` fails the static compatibility checks.
  """
  if name is None:
    name = 'concat'
  _assert_concat_compatible_structured_tensors(values)
  # TODO(martinz): handle axis when it is a tensor.
  # Normalize the axis before building the closure so the closure captures a
  # value that is never rebound afterwards.
  positive_axis = array_ops.get_positive_axis(axis, values[0].rank)

  def leaf_op(leaf_values):
    # Applied to each list of corresponding non-structured fields.
    return array_ops.concat(leaf_values, positive_axis)

  with ops.name_scope(name, 'StructuredConcat', values):
    return _extend_op(values, leaf_op)
@dispatch.dispatch_for_types(random_ops.random_shuffle, StructuredTensor)
def random_shuffle(value, seed=None, name=None):
  """Randomly permutes a structured tensor along its zeroth axis.

  Args:
    value: a structured tensor of rank at least one.
    seed: the seed for shuffling.
    name: the name for shuffle.

  Returns:
    The shuffled structured tensor.

  Raises:
    ValueError: if `value` has rank 0.
  """
  with ops.name_scope(name, 'shuffle', [value, seed]):
    if value.rank == 0:
      raise ValueError('Cannot shuffle a scalar StructuredTensor')
    # Shuffle the row ids, then gather the rows in that order.
    row_ids = math_ops.range(value.nrows())
    permutation = random_ops.random_shuffle(row_ids, seed=seed)
    return gather(value, permutation, axis=0)
@dispatch.dispatch_for_types(array_ops.size_v2, StructuredTensor)
def size_v2(input, out_type=dtypes.int32, name=None):
  # pylint: disable=redefined-builtin
  """Returns the size of a StructuredTensor (v2 argument order).

  Delegates to `size`, which takes the same arguments in v1 order.
  """
  return size(input, out_type=out_type, name=name)
# pylint: disable=protected-access
@dispatch.dispatch_for_types(array_ops.size, StructuredTensor)
def size(input, name=None, out_type=dtypes.int32):
  # pylint: disable=redefined-builtin
  """Returns the size of a StructuredTensor.

  Args:
    input: a structured tensor.
    name: a name for the op scope.
    out_type: the dtype the result is cast to.

  Returns:
    For an input with row partitions, `nvals()` of the last row partition
    (cast to `out_type` unless either is None). Without row partitions,
    `nrows()` cast to `out_type`, or 1 when `nrows()` is None.
  """
  with ops.name_scope(name, 'size', [input]) as name:
    partitions = input.row_partitions
    if partitions:
      # 2D and up.
      element_count = partitions[-1].nvals()
      if element_count is None or out_type is None:
        return element_count
      return math_ops.cast(element_count, dtype=out_type)
    if input.nrows() is None:
      return math_ops.cast(1, out_type)  # scalar.
    return math_ops.cast(input.nrows(), out_type)  # vector.
# pylint: disable=protected-access
@dispatch.dispatch_for_types(array_ops.zeros_like, StructuredTensor)
def zeros_like(tensor, dtype=None, name=None, optimize=True):
  """TF v1 zeros_like for StructuredTensor.

  `optimize` is accepted for signature compatibility and discarded; the
  work is done by `zeros_like_v2`.
  """
  del optimize
  zeros = zeros_like_v2(tensor, dtype=dtype, name=name)
  return zeros
# pylint: disable=protected-access
@dispatch.dispatch_for_types(array_ops.zeros_like_v2, StructuredTensor)
def zeros_like_v2(input, dtype=None, name=None):  # pylint: disable=redefined-builtin
  """Replace every object with a zero.

  Example:
  >>> st = StructuredTensor.from_pyval([{"x":[3]}, {"x":[4,5]}])
  >>> tf.zeros_like(st)
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([0., 0.], dtype=float32)>
  >>> st = StructuredTensor.from_pyval([[{"x":[3]}], [{"x":[4,5]}, {"x":[]}]])
  >>> tf.zeros_like(st, dtype=tf.int32)
  <tf.RaggedTensor [[0], [0, 0]]>

  Args:
    input: a structured tensor.
    dtype: the dtype of the resulting zeros. (default is tf.float32)
    name: a name for the op.

  Returns:
    a tensor of zeros of the same shape: a dense tensor when `input` has no
    row partitions, otherwise a RaggedTensor built from them.
  """
  if dtype is None:
    dtype = dtypes.float32
  with ops.name_scope(name, 'zeros_like', [input]) as name:
    if not input.row_partitions:
      if input.nrows() is not None:
        return array_ops.zeros([input.nrows()], dtype)  # vector.
      else:
        return array_ops.zeros([], dtype)  # scalar.
    # 2D and up.
    last_row_partition = input.row_partitions[-1]

    # One zero per value of the innermost partition, wrapped in the input's
    # own row partitions.
    result = ragged_tensor.RaggedTensor._from_nested_row_partitions(
        array_ops.zeros(last_row_partition.nvals(), dtype=dtype),
        input.row_partitions)
    return result
# pylint: disable=protected-access
@dispatch.dispatch_for_types(array_ops.ones_like, StructuredTensor)
def ones_like(tensor, dtype=None, name=None, optimize=True):
  """Implementation of ones_like for StructuredTensor for TF v1.

  Args:
    tensor: a structured tensor.
    dtype: the dtype of the resulting ones, forwarded to `ones_like_v2`.
    name: a name for the op.
    optimize: unused; accepted only for signature compatibility.

  Returns:
    The result of `ones_like_v2(tensor, dtype=dtype, name=name)`.
  """
  del optimize
  return ones_like_v2(tensor, dtype=dtype, name=name)
# pylint: disable=protected-access
@dispatch.dispatch_for_types(array_ops.ones_like_v2, StructuredTensor)
def ones_like_v2(input, dtype=None, name=None):  # pylint: disable=redefined-builtin
  """Replace every object with a one.

  Example:
  >>> st = StructuredTensor.from_pyval([{"x":[3]}, {"x":[4,5]}])
  >>> tf.ones_like(st)
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([1., 1.], dtype=float32)>
  >>> st = StructuredTensor.from_pyval([[{"x":[3]}], [{"x":[4,5]}, {"x":[]}]])
  >>> tf.ones_like(st, dtype=tf.int32)
  <tf.RaggedTensor [[1], [1, 1]]>

  Args:
    input: a structured tensor.
    dtype: the dtype of the resulting ones. (default is tf.float32)
    name: a name for the op.

  Returns:
    a tensor of ones of the same shape: a dense tensor when `input` has no
    row partitions, otherwise a RaggedTensor built from them.
  """
  if dtype is None:
    dtype = dtypes.float32
  with ops.name_scope(name, 'ones_like', [input]) as name:
    if not input.row_partitions:
      if input.nrows() is not None:
        return array_ops.ones([input.nrows()], dtype)  # vector.
      else:
        return array_ops.ones([], dtype)  # scalar.
    # 2D and up.
    last_row_partition = input.row_partitions[-1]

    # One `1` per value of the innermost partition, wrapped in the input's
    # own row partitions.
    result = ragged_tensor.RaggedTensor._from_nested_row_partitions(
        array_ops.ones(last_row_partition.nvals(), dtype=dtype),
        input.row_partitions)
    return result
@dispatch.dispatch_for_types(array_ops.rank, StructuredTensor)
def rank(input, name=None):
  # pylint: disable=redefined-builtin
  """Returns the rank of a StructuredTensor as an int32 constant."""
  with ops.name_scope(name, 'rank', [input]) as name:
    static_rank = input.rank
    return constant_op.constant(static_rank, dtype=dtypes.int32)
def _expand_dims_impl(st, axis, name=None):  # pylint: disable=redefined-builtin
  """Inserts a length-1 axis into a StructuredTensor at index `axis`.

  Shared implementation behind both tf.expand_dims dispatchers for
  StructuredTensor. `axis` may not exceed the rank of `st`.

  >>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]])
  >>> tf.expand_dims(st, 0).to_pyval()
  [[[{'x': 1}, {'x': 2}], [{'x': 3}]]]
  >>> tf.expand_dims(st, 1).to_pyval()
  [[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, 2).to_pyval()
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, -1).to_pyval()  # -1 is the same as 2
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]

  Args:
    st: the original StructuredTensor.
    axis: where to insert the dimension: `-(rank + 1) <= axis <= rank`
    name: the name of the op.

  Returns:
    a new structured tensor whose rank is one larger.

  Raises:
    an error if `axis < -(rank + 1)` or `rank < axis`.
  """
  # Normalize against rank + 1 because the result has one more dimension.
  insert_at = array_ops.get_positive_axis(
      axis, st.rank + 1, axis_name='axis', ndims_name='rank(st)')
  with ops.name_scope(name, 'ExpandDims', [st, insert_at]):
    expanded_fields = {}
    for field_name, field_value in st._fields.items():
      expanded_fields[field_name] = array_ops.expand_dims(
          field_value, insert_at)
    expanded_shape = st.shape[:insert_at] + (1,) + st.shape[insert_at:]
    expanded_partitions = _expand_st_row_partitions(st, insert_at)
    if insert_at > 0:
      expanded_nrows = st.nrows()
    else:
      # The new outermost dimension always has exactly one row.
      expanded_nrows = 1
    return StructuredTensor.from_fields(
        expanded_fields,
        shape=expanded_shape,
        row_partitions=expanded_partitions,
        nrows=expanded_nrows)
def _expand_st_row_partitions(st, axis):
  """Builds the row partitions of `st` after a length-1 axis is inserted.

  Args:
    st: the StructuredTensor being expanded.
    axis: the non-negative index at which the new axis is inserted.

  Returns:
    A tuple of RowPartitions for the expanded tensor; empty when a scalar is
    expanded at axis 0.
  """
  partitions = st.row_partitions

  if axis == 0:
    if st.shape.rank == 0:
      return ()
    # A single outer row that spans all of the existing rows.
    row_count = st.nrows()
    outer = RowPartition.from_uniform_row_length(
        row_count, row_count, nrows=1, validate=False)
    return (outer,) + partitions

  if axis == st.rank:
    # New innermost axis: every existing innermost value gets its own row.
    if axis - 2 >= 0:
      value_count = partitions[axis - 2].nvals()
    else:
      value_count = st.nrows()
    inner = RowPartition.from_uniform_row_length(
        1, value_count, nrows=value_count, validate=False)
    return partitions + (inner,)

  # New axis strictly inside: splice a uniform length-1 partition in.
  if axis - 1 >= 0:
    row_count = partitions[axis - 1].nrows()
  else:
    row_count = st.nrows()
  middle = RowPartition.from_uniform_row_length(
      1, row_count, nrows=row_count, validate=False)
  return partitions[:axis - 1] + (middle,) + partitions[axis - 1:]
392# TODO(martinz): consider allowing values to be nested.
393def _extend_op(values, leaf_op, empty_st_op=None):
394 """Extend an op from RaggedTensor and Tensor to StructuredTensor.
396 Visits all children of the structured tensor, and children of children,
397 applying leaf_op whenever it reaches a leaf, and empty_st_op whenever
398 it reaches an internal node without children.
400 Args:
401 values: a list of structured tensors, ragged tensors, or tensors. All must
402 have the same type. If they are structured tensors, they must have the
403 same paths.
404 leaf_op: an op for handling non-structured tensor.
405 empty_st_op: op to create a structured tensor without fields.
407 Returns:
408 the result of the extended op (a StructuredTensor, RaggedTensor, or Tensor)
410 Raises:
411 ValueError:
412 If values is not a Sequence or is empty.
413 """
414 if not isinstance(values, Sequence):
415 raise ValueError('Expected a list')
417 if not values:
418 raise ValueError('List cannot be empty')
420 if empty_st_op is None:
421 empty_st_op = empty_st_op_like_zeros(leaf_op)
422 # Use the structure of the first StructuredTensor. They are all assumed to
423 # be the same.
424 value = values[0]
426 if isinstance(value, StructuredTensor):
427 # TODO(martinz): Calling empty_st_op may add unnecessary ops. Revisit later.
428 empty_result = empty_st_op(values)
429 if not value.field_names():
430 return empty_result
431 new_fields = {}
432 for k in value.field_names():
433 new_fields[k] = _extend_op([v.field_value(k) for v in values], leaf_op,
434 empty_st_op)
435 return StructuredTensor.from_fields(new_fields, shape=empty_result.shape)
436 else:
437 return leaf_op(values)
def _extend_op_single(value, leaf_op, empty_st_op=None):
  """Applies `_extend_op` to one value rather than a list of values.

  `leaf_op` and `empty_st_op` take a single element here; each is wrapped
  so that it accepts the one-element list `_extend_op` passes around.
  """

  def lift(single_op):
    # Keep None as None so _extend_op can still apply its default.
    if single_op is None:
      return None

    def on_singleton(items):
      (only_item,) = items
      return single_op(only_item)

    return on_singleton

  lifted_leaf_op = lift(leaf_op)
  lifted_empty_st_op = lift(empty_st_op)
  return _extend_op([value], lifted_leaf_op, lifted_empty_st_op)
def empty_st_op_like_zeros(leaf_op):
  """Derives an op for field-less structured tensors from `leaf_op`.

  The returned callable replaces each value with int32 zeros of the same
  shape, runs `leaf_op` on those, and returns a field-less StructuredTensor
  shaped like the result.

  Args:
    leaf_op: an op taking a list of tensors or ragged tensors.

  Returns:
    A callable taking a list of values and returning a StructuredTensor.
  """

  def shape_only_op(values):
    zero_stand_ins = []
    for item in values:
      zero_stand_ins.append(zeros_like_v2(item, dtype=dtypes.int32))
    shaped = leaf_op(zero_stand_ins)
    return _structured_tensor_like(shaped)

  return shape_only_op
def _structured_tensor_from_dense_tensor(t):
  """Builds a field-less StructuredTensor shaped like the dense tensor `t`.

  Raises:
    ValueError: if the shape of `t` is not fully defined and its rank is
      unknown.
  """
  static_shape = t.shape
  # Note: If a tensor will have rank 0,
  # it either has a fully defined shape or has unknown rank.
  if static_shape.is_fully_defined():
    return StructuredTensor.from_fields({}, shape=static_shape)
  if static_shape.rank is None:
    raise ValueError("Can't build StructuredTensor w/ unknown rank")
  if static_shape.rank == 1:
    # The only dimension is not known statically; read it at runtime.
    runtime_nrows = array_ops.shape(t)[0]
    return StructuredTensor.from_fields(
        {}, shape=static_shape, nrows=runtime_nrows)
  # Rank >= 2: take the row partitions from a ragged view of the tensor.
  as_ragged = ragged_tensor.RaggedTensor.from_tensor(t)
  return _structured_tensor_from_row_partitions(
      static_shape, as_ragged._nested_row_partitions)  # pylint: disable=protected-access
def _structured_tensor_from_row_partitions(shape, row_partitions):
  """Builds a field-less StructuredTensor from a shape and row partitions."""
  no_fields = {}
  return StructuredTensor.from_fields(
      no_fields, shape=shape, row_partitions=row_partitions)
# pylint: disable=protected-access
def _all_nested_row_partitions(rt):
  """Returns all nested row partitions in rt, including for dense dimensions.

  Args:
    rt: a Tensor or a RaggedTensor.

  Returns:
    A tuple of row partitions. For a Tensor of rank <= 1 it is empty. For a
    higher-rank Tensor it comes from converting the tensor to a RaggedTensor.
    For a RaggedTensor it is its own nested partitions followed by those of
    its `flat_values`.
  """
  if isinstance(rt, ops.Tensor):
    if rt.shape.rank <= 1:
      return ()
    else:
      rt2 = ragged_tensor.RaggedTensor.from_tensor(rt)
      return rt2._nested_row_partitions
  else:
    # Partitions hidden in the dense inner dimensions of flat_values.
    tail_partitions = _all_nested_row_partitions(rt.flat_values)
    head_partitions = rt._nested_row_partitions  # pylint: disable=protected-access
    return head_partitions + tail_partitions
def _structured_tensor_like(t):
  """Builds a field-less StructuredTensor shaped like `t`.

  Args:
    t: a Tensor, a RaggedTensor, or (by elimination) a StructuredTensor.

  Returns:
    A StructuredTensor with no fields and the shape of `t`.
  """
  if isinstance(t, ops.Tensor):
    return _structured_tensor_from_dense_tensor(t)

  no_fields = {}
  if ragged_tensor.is_ragged(t):
    return StructuredTensor.from_fields(
        no_fields,
        shape=t.get_shape(),
        row_partitions=_all_nested_row_partitions(t))

  # here, it is a StructuredTensor
  return StructuredTensor.from_fields(
      no_fields,
      shape=t.shape,
      row_partitions=t.row_partitions,
      nrows=t.nrows())
520def _get_all_paths(st):
521 """Get all the paths from a StructuredTensor."""
522 fields = st.field_names()
523 all_paths = {()}
524 for k in fields:
525 v = st.field_value(k)
526 if isinstance(v, StructuredTensor):
527 all_paths = all_paths.union([(k,) + p for p in _get_all_paths(v)])
528 else:
529 all_paths.add((k,))
530 return all_paths
533def _get_all_ranks(st):
534 """Get ranks of all submessages of a StructuredTensor."""
535 fields = st.field_names()
536 all_ranks = {(): st.rank}
537 for k in fields:
538 v = st.field_value(k)
539 if isinstance(v, StructuredTensor):
540 for (k2, v2) in _get_all_ranks(v).items():
541 all_ranks[(k,) + k2] = v2
542 return all_ranks
545def _assert_all_paths_match(values):
546 """Raises an error if the paths are not identical."""
547 paths = [_get_all_paths(st) for st in values]
548 path_diff = set()
549 for other_paths in paths[1:]:
550 path_diff = path_diff.union(paths[0].symmetric_difference(other_paths))
551 if path_diff:
552 raise ValueError(
553 'Some paths are present in some, but not all, structured tensors: %r' %
554 (path_diff,))
557def _assert_all_ranks_match(values):
558 """Raises an error if the ranks of submessages are not identical."""
559 ranks = [_get_all_ranks(st) for st in values]
560 for other_ranks in ranks[1:]:
561 if other_ranks != ranks[0]:
562 # TODO(martinz): If this becomes common, we can provide more detail.
563 # e.g.: which path is inconsistent.
564 raise ValueError('Ranks of sub-message do not match')
567def _assert_concat_compatible_structured_tensors(values):
568 """Sometimes raises an error if concat doesn't make sense statically on values.
570 values must be a sequence, and each element in values must be a structured
571 tensor, and must have the same paths. Additionally, each path that is a
572 submessage must have the same rank.
574 These constraints are sufficient for concat on the fields to be the same
575 as concat on structured tensors. This is meant to capture scenarios like
576 paths that are not in the first structured tensor, but are in later
577 structured tensors, which will just be ignored by the recursive algorithm.
579 If the rank of a submessage was different for two structured tensors,
580 then that is also a non-sensical merge.
582 Note that all of these checks are static, as paths and submessage ranks
583 are known.
585 Args:
586 values: a Sequence of StructuredTensors.
588 Raises:
589 ValueError: if there is any inconsistency as described above.
590 """
591 if not isinstance(values, Sequence):
592 raise ValueError('values must be a list of StructuredTensors (not a list)')
593 if not values:
594 raise ValueError('values must not be an empty list')
595 for st in values:
596 if not isinstance(st, StructuredTensor):
597 raise ValueError('values must be a list of StructuredTensors')
598 _assert_all_paths_match(values)
599 _assert_all_ranks_match(values)