Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/structured/structured_tensor.py: 19%
648 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Structured Tensors."""
17import re
18from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
20import numpy as np
22from tensorflow.python.framework import constant_op
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import extension_type
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import tensor_shape
27from tensorflow.python.framework import tensor_spec
28from tensorflow.python.framework import type_spec
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import check_ops
31from tensorflow.python.ops import control_flow_ops
32from tensorflow.python.ops import math_ops
33from tensorflow.python.ops.ragged import dynamic_ragged_shape
34from tensorflow.python.ops.ragged import ragged_factory_ops
35from tensorflow.python.ops.ragged import ragged_tensor
36from tensorflow.python.ops.ragged.row_partition import RowPartition
37from tensorflow.python.util import compat
38from tensorflow.python.util import nest
39from tensorflow.python.util.tf_export import tf_export
# Each field may contain one of the following types of Tensors: a dense
# `Tensor`, a `RaggedTensor`, a nested `StructuredTensor` (named by string
# because the class is defined further down in this module), or any other
# `ExtensionType`.
_FieldValue = Union[ops.Tensor, ragged_tensor.RaggedTensor, 'StructuredTensor',
                    extension_type.ExtensionType]
# Function that takes a FieldValue as input and returns the transformed
# FieldValue. Used by `StructuredTensor.with_updates` to transform an existing
# field in place of supplying a replacement value.
_FieldFn = Callable[[_FieldValue], _FieldValue]
49@tf_export('experimental.StructuredTensor')
50class StructuredTensor(extension_type.BatchableExtensionType):
51 """A multidimensional collection of structures with the same schema.
53 A **`StructuredTensor`** is a multi-dimensional collection of ***structures***
54 with the same ***schema***, where:
56 * A ***schema*** is a collection of fields, each of which has a name and type.
57 * A ***structure*** maps each field in the schema to a tensor value (which
58 could be a nested StructuredTensor).
60 As an important special case, a 1D `StructuredTensor` encodes a 2D table,
61 where columns are heterogeneous `Tensor`s, and rows are the aligned elements
62 in each of those `Tensor`s.
64 Internally, StructuredTensors use a "field-major" encoding: for each leaf
65 field, there is a single tensor that stores the value of that field for all
66 structures in the `StructuredTensor`.
68 ### Examples
70 >>> # A scalar StructuredTensor describing a single person.
71 >>> s1 = tf.experimental.StructuredTensor.from_pyval(
72 ... {"age": 82, "nicknames": ["Bob", "Bobby"]})
73 >>> s1.shape
74 TensorShape([])
75 >>> s1["age"]
76 <tf.Tensor: shape=(), dtype=int32, numpy=82>
78 >>> # A vector StructuredTensor describing three people.
79 >>> s2 = tf.experimental.StructuredTensor.from_pyval([
80 ... {"age": 12, "nicknames": ["Josaphine"]},
81 ... {"age": 82, "nicknames": ["Bob", "Bobby"]},
82 ... {"age": 42, "nicknames": ["Elmo"]}])
83 >>> s2.shape
84 TensorShape([3])
85 >>> s2[0]["age"]
86 <tf.Tensor: shape=(), dtype=int32, numpy=12>
89 ### Field Paths
91 A *field path* is a tuple of field names, specifying the path to a nested
92 field.
93 """
94 _fields: Mapping[str, _FieldValue]
95 _ragged_shape: dynamic_ragged_shape.DynamicRaggedShape
97 __name__ = 'tf.StructuredTensor'
98 #=============================================================================
99 # Common Types
100 #=============================================================================
101 # pylint: disable=invalid-name
102 # Field names work as key, and they can be a sequence to refer to the
103 # sub-levels (embedded) StructuredTensor's.
104 FieldName = Union[str, Sequence[str]]
106 # pylint: enable=invalid-name
108 #=============================================================================
109 # Constructor & Factory Methods
110 #=============================================================================
def __init__(self, fields: Mapping[str, _FieldValue],
             ragged_shape: dynamic_ragged_shape.DynamicRaggedShape):
  """Stores the given fields and ragged shape without validating them.

  Args:
    fields: Mapping from field name to field value.
    ragged_shape: `DynamicRaggedShape` giving the shape of this
      `StructuredTensor`.
  """
  self._ragged_shape = ragged_shape
  self._fields = fields
@classmethod
def _old_init(cls, fields, shape, nrows, row_partitions, internal=False):
  """Legacy private constructor -- prefer the public factory methods.

  Type-checks the arguments with assertions, derives a ragged shape from them
  with `_dynamic_ragged_shape_init`, and wraps both in a `StructuredTensor`.
  No further validation is performed here.

  Args:
    fields: A `dict` mapping from string to `Tensor`, `RaggedTensor`, or
      `StructuredTensor`. The dict is not copied, so the caller must make sure
      it is not mutated through other references afterwards.
    shape: A `tf.TensorShape` with statically known rank.
    nrows: A scalar integer `tf.Tensor`, or `None` if `shape.rank==0`.
    row_partitions: A tuple of `RowPartition`s of length `shape.rank-1`, or
      `None`.
    internal: Ignored.

  Returns:
    A `StructuredTensor`.
  """
  assert isinstance(fields, dict), fields
  assert isinstance(shape, tensor_shape.TensorShape), shape
  assert nrows is None or isinstance(nrows, ops.Tensor), nrows
  assert (row_partitions is None or
          isinstance(row_partitions, tuple)), row_partitions
  ragged_shape = _dynamic_ragged_shape_init(fields, shape, nrows,
                                            row_partitions)
  return StructuredTensor(fields=fields, ragged_shape=ragged_shape)
@classmethod
def from_shape(
    cls, ragged_shape: dynamic_ragged_shape.DynamicRaggedShape
) -> 'StructuredTensor':
  """Builds a `StructuredTensor` that has a shape but no fields.

  Args:
    ragged_shape: The shape the new structured tensor should have.

  Returns:
    A `StructuredTensor` with an empty field dictionary and `ragged_shape`.
  """
  no_fields = {}
  return StructuredTensor(fields=no_fields, ragged_shape=ragged_shape)
@classmethod
def from_fields(cls,
                fields,
                shape=(),
                nrows=None,
                row_partitions=None,
                validate=False):
  """Builds a `StructuredTensor` out of a dictionary of field values.

  Args:
    fields: A dictionary mapping from string to `Tensor`, `RaggedTensor`, or
      `StructuredTensor`, giving the value of each field in each structure.
      If `shape.rank > 0`, every tensor in `fields` must share the same shape
      in its first `shape.rank` dimensions, that shape must be compatible
      with `shape`, and `result[i1...iN][key] = fields[key][i1...iN]`
      (where `N==shape.rank`).
    shape: A `TensorShape` holding static shape information for the result.
      Its `rank` must be known. Defaults to the scalar shape (`rank=0`).
    nrows: Scalar integer tensor with the number of rows of the result. Only
      allowed when `shape.rank > 0`. Inferred from `fields` by default; must
      be given when `fields` is empty.
    row_partitions: A list of `RowPartition`s describing the (possibly
      ragged) shape of the result. Only allowed when `shape.rank > 1`.
      Inferred from `fields` by default; must be given when `fields` is
      empty.
    validate: If true, then add runtime validation ops that check that the
      field values all have compatible shapes in the outer `shape.rank`
      dimensions. NOTE(review): this argument is not referenced anywhere in
      this method body -- confirm whether validation happens elsewhere.

  Returns:
    A `StructuredTensor`.

  Raises:
    ValueError: If the rank of `shape` is unknown, if `row_partitions` or
      `nrows` is supplied for a rank that does not admit it, if the number
      of `row_partitions` is not `shape.rank-1`, or if a field name does not
      match `_FIELD_NAME_RE`.
    TypeError: If `fields` is not a `dict` or one of its keys is not a `str`.

  Examples:

    >>> tf.experimental.StructuredTensor.from_fields({'x': 1, 'y': [1, 2, 3]})
    <StructuredTensor(
      fields={
        "x": tf.Tensor(1, shape=(), dtype=int32),
        "y": tf.Tensor([1 2 3], shape=(3,), dtype=int32)},
      shape=())>

    >>> tf.experimental.StructuredTensor.from_fields(
    ...     {'foo': [1, 2], 'bar': [3, 4]}, shape=[2])
    <StructuredTensor(
      fields={
        "bar": tf.Tensor([3 4], shape=(2,), dtype=int32),
        "foo": tf.Tensor([1 2], shape=(2,), dtype=int32)},
      shape=(2,))>
  """
  static_shape = tensor_shape.as_shape(shape)
  num_dims = static_shape.rank
  if num_dims is None:
    raise ValueError("StructuredTensor's shape must have known rank.")
  if not isinstance(fields, dict):
    raise TypeError('fields must be a dictionary, got %s' %
                    type(fields).__name__)
  if num_dims < 2 and row_partitions:
    raise ValueError('row_partitions must be None or [] if shape.rank<2')
  if num_dims == 0 and nrows is not None:
    raise ValueError('nrows must be None if shape.rank==0')

  if row_partitions is None:
    # Below rank 2 there is nothing to partition; for higher ranks the value
    # stays `None` and is handled by the helpers called below.
    if num_dims < 2:
      row_partitions = ()
  else:
    row_partitions = tuple(row_partitions)
    if len(row_partitions) != max(0, num_dims - 1):
      raise ValueError('len(row_partitions) must be shape.rank-1')

  own_fields = dict(fields)  # Private copy: the caller's dict is untouched.
  with ops.name_scope(None, 'StructuredTensor', own_fields.values()):
    # TODO(martinz): Make this have better errors.
    ragged_shape = _dynamic_ragged_shape_init(own_fields, static_shape, nrows,
                                              row_partitions)

    # TODO(martinz): This may not need to be done if all fields are dense.
    if ragged_shape.rank > 1:
      # pylint: disable=protected-access
      ragged_shape = ragged_shape._with_num_row_partitions(
          ragged_shape.rank - 1)
      # pylint: enable=protected-access

    # First pass: check every key and convert every value. Only existing keys
    # are reassigned, so updating the dict while iterating over it is safe.
    for field_key, raw_value in own_fields.items():
      if not isinstance(field_key, str):
        raise TypeError(
            f'Unexpected type for key in `fields`: {field_key}')
      if not _FIELD_NAME_RE.match(field_key):
        raise ValueError('Field name %r is not currently allowed.' %
                         field_key)
      own_fields[field_key] = _convert_to_structured_field_value(raw_value)

    # Second pass, run only once every key has been accepted.
    own_fields = {
        field_key: _replace_row_partitions(field_value, row_partitions)
        for field_key, field_value in own_fields.items()
    }
    return cls(fields=own_fields, ragged_shape=ragged_shape)
@classmethod
def from_fields_and_rank(
    cls,
    fields: Mapping[str, _FieldValue],
    rank: int,
    validate: bool = False,
    dtype: Optional[dtypes.DType] = None) -> 'StructuredTensor':
  """Builds a `StructuredTensor` from a nonempty field dictionary and a rank.

  When no shape dtype is given, it is inferred from the fields that carry a
  shape dtype. If the fields disagree, int64 wins over int32, because
  widening int32 to int64 is safer than narrowing int64 to int32.

  If there are no ragged fields, the dtype currently defaults to int64; this
  is expected to change to int32 in the future.

  Args:
    fields: A dictionary mapping from string to `Tensor`, `RaggedTensor`, or
      `StructuredTensor`, giving the value of each field in each structure.
      If `rank > 0`, every tensor in `fields` must share the same shape in
      its first `rank` dimensions. Must not be empty.
    rank: The rank of the resulting structured tensor.
    validate: If true, then add runtime validation ops that check that the
      field values all have compatible shapes in the outer `rank` dimensions.
      NOTE(review): this argument is not referenced anywhere in this method
      body -- confirm whether validation happens elsewhere.
    dtype: If specified, forces the dtype of the shape to be this.

  Returns:
    A `StructuredTensor`.

  Raises:
    ValueError: If `fields` is empty, or `rank` is not a nonnegative `int`.

  Examples:

    >>> tf.experimental.StructuredTensor.from_fields_and_rank(
    ...     {'x': 1, 'y': [1, 2, 3]}, 0)
    <StructuredTensor(
      fields={
        "x": tf.Tensor(1, shape=(), dtype=int32),
        "y": tf.Tensor([1 2 3], shape=(3,), dtype=int32)},
      shape=())>
    >>> StructuredTensor.from_fields_and_rank({'foo': [1, 2], 'bar': [3, 4]},
    ...                                       1)
    <StructuredTensor(
      fields={
        "bar": tf.Tensor([3 4], shape=(2,), dtype=int32),
        "foo": tf.Tensor([1 2], shape=(2,), dtype=int32)},
      shape=(2,))>
  """
  if not fields:
    raise ValueError('Must provide at least one field')
  if not isinstance(rank, int):
    raise ValueError('rank must be an integer')
  if rank < 0:
    raise ValueError('rank must be nonnegative')

  converted = {}
  for field_key, raw_value in fields.items():
    converted[field_key] = _convert_to_structured_field_value(raw_value)

  if dtype is None:
    dtype = _find_shape_dtype(converted, None, None)
  converted = _fields_with_dtype(converted, dtype)

  ragged_shape = _shape_from_fields(converted, rank, dtype)
  # pylint: disable=protected-access
  if rank > 1:
    ragged_shape = ragged_shape._with_num_row_partitions(rank - 1)
  shared_partitions = ragged_shape._row_partitions
  # pylint: enable=protected-access

  # Give every field the row partitions taken from the shape just built.
  aligned = {}
  for field_key, field_value in converted.items():
    aligned[field_key] = _replace_row_partitions(field_value,
                                                 shared_partitions)
  return StructuredTensor(fields=aligned, ragged_shape=ragged_shape)
def with_updates(self,
                 updates: Dict[FieldName, Union[_FieldValue, _FieldFn, None]],
                 validate: bool = False) -> 'StructuredTensor':
  """Creates a new `StructuredTensor` with the updated fields.

  If this `StructuredTensor` is a scalar, and `k` is the `FieldName` being
  updated and `v` the new value, then:

  ```
  result[k] = v              # If (k, v) is in updates and v is a FieldValue
  result[k] = f(self[k])     # If (k, f) is in updates and f is a FieldFn
  result[k] = self[k]        # If k is in self.field_names but not in updates
  ```

  If this `StructuredTensor` has rank `N` and shape `[D1...DN]`, then each
  FieldValue `v` in `updates` must have shape `[D1...DN, ...]`, that is,
  prefixed with the same shape as the `StructuredTensor`. Then the resulting
  `StructuredTensor` will have:

  ```
  result[i1...iN][k] = v[i1...iN]                        # (k, v) in updates
  result[i1...iN][k] = f(self.field_value(k))[i1...iN]   # (k, f) in updates
  result[i1...iN][k] = self[i1...iN][k]                  # k not in updates
  ```

  Note that `result.shape` is always equal to `self.shape` (but the shapes
  of nested StructuredTensors may be changed if they are updated with new
  values).

  Args:
    updates: A dictionary mapping `FieldName` to either a `FieldValue` to be
      used to update, or a `FieldFn` that will transform the value for the
      given `FieldName`. `FieldName` can be a string for a direct field, or a
      sequence of strings to refer to a nested sub-field. `FieldFn` is a
      function that takes a `FieldValue` as input and should return a
      `FieldValue`. All other fields are copied over to the new
      `StructuredTensor`. New `FieldName` can be given (to add new fields),
      but only to existing `StructuredTensor`, it won't automatically create
      new nested structures -- but one can create a whole `StructuredTensor`
      sub-structure and set that into an existing structure. If the new value
      is set to `None`, it is removed.
    validate: If true, then add runtime validation ops that check that the
      field values all have compatible shapes in the outer `shape.rank`
      dimensions.

  Returns:
    A `StructuredTensor`.

  Raises:
    `ValueError`: If any of the `FieldName` keys points to non-existent
      sub-structures, if parent and child nodes are updated (this includes
      two keys that normalize to the same path), if shapes change, if a
      delete update is given for a non-existent field, or if a `FieldFn`
      transforming function is given for a `FieldName` that doesn't yet
      exist.

  Examples:

  >>> shoes_us = tf.experimental.StructuredTensor.from_pyval([
  ...    {"age": 12, "nicknames": ["Josaphine"],
  ...       "shoes": {"sizes": [8.0, 7.5, 7.5]}},
  ...    {"age": 82, "nicknames": ["Bob", "Bobby"],
  ...        "shoes": {"sizes": [11.0, 11.5, 12.0]}},
  ...    {"age": 42, "nicknames": ["Elmo"],
  ...        "shoes": {"sizes": [9.0, 9.5, 10.0]}}])
  >>> def us_to_europe(t):
  ...   return tf.round(t * 2.54 + 17.0)  # Rough approximation.
  >>> shoe_sizes_key = ("shoes", "sizes")
  >>> shoes_eu = shoes_us.with_updates({shoe_sizes_key: us_to_europe})
  >>> shoes_eu.field_value(shoe_sizes_key)
  <tf.RaggedTensor [[37.0, 36.0, 36.0], [45.0, 46.0, 47.0],
  [40.0, 41.0, 42.0]]>
  """
  updates_items = [(_normalize_field_name_to_tuple(name), value)
                   for name, value in updates.items()]

  # Sort on the field path only. Sorting the (path, value) pairs directly
  # would fall back to comparing the values whenever two keys normalize to
  # the same path, and tensors, callables and `None` do not order reliably
  # against each other. The sort is stable, so such duplicates stay adjacent
  # and are rejected by the prefix check below.
  updates_items.sort(key=lambda item: item[0])

  # A parent path sorts immediately before its children, so comparing each
  # path with its predecessor is enough to detect parent/child conflicts.
  for (prev_name, _), (name, _) in zip(updates_items, updates_items[1:]):
    if name[:len(prev_name)] == prev_name:
      raise ValueError(
          '`StructuredTensor.with_updates` does not allow both parent and '
          'child nodes to be updated: parent={}, child={}. If needed you can '
          'update child nodes in the parent update value.'.format(
              prev_name, name))
  return self._with_updates_impl((), updates_items, validate)
def _with_updates_impl(self, error_prefix: Tuple[str, ...],
                       updates: List[Tuple[FieldName, Union[_FieldValue,
                                                            _FieldFn]]],
                       validate: bool) -> 'StructuredTensor':
  """Applies `updates` to this level and recurses into nested structures.

  Args:
    error_prefix: Path of field names leading to this structure; only used
      to build error messages.
    updates: List of `(path, value)` pairs, where `path` is a sequence of
      field names relative to this structure and `value` is a replacement
      value, a transforming function, or `None` to delete the field.
    validate: Forwarded to `StructuredTensor.from_fields`.

  Returns:
    A `StructuredTensor` rebuilt from the merged fields, with this tensor's
    shape, row partitions and number of rows.

  Raises:
    ValueError: If a path is empty, a deleted or transformed field does not
      exist, a nested path goes through a missing or non-`StructuredTensor`
      field, or rebuilding the tensor fails.
  """
  merged = dict(self._fields)

  def describe(field_name: Sequence[str]) -> str:
    # Full path of `field_name`, rendered for error messages.
    return str(error_prefix + (field_name,))

  def resolve(field_name: str, update: Union[_FieldValue,
                                             _FieldFn]) -> _FieldValue:
    # Plain values are used as they are; callables transform the current
    # value, which therefore has to exist already.
    if not callable(update):
      return update
    if field_name not in merged:
      raise ValueError(
          '`StructuredTensor.with_updates` cannot update the field {} '
          'because a transforming function was given, but that field '
          'does not already exist.'.format(describe(field_name)))
    return update(merged[field_name])

  for path, update in updates:
    if not path or not path[0]:
      raise ValueError(
          '`StructuredTensor.with_updates` does not allow empty names '
          '{}.'.format(describe(path)))

    head = path[0]
    if len(path) == 1:
      # The update targets a field owned directly by this structure.
      if update is not None:
        merged[head] = resolve(head, update)
        continue
      if head not in merged:
        raise ValueError(
            '`StructuredTensor.with_updates` cannot delete field '
            '{} because it is not present.'.format(describe(head)))
      merged.pop(head)
      continue

    # The update targets a nested field: delegate to the child structure.
    if head not in merged:
      raise ValueError(
          '`StructuredTensor.with_updates` cannot create new sub-field '
          '{} if parent field {} is not set.'.format(
              error_prefix + tuple(path), describe(head)))
    child = merged[head]
    if not isinstance(child, StructuredTensor):
      raise ValueError(
          '`StructuredTensor.with_updates` cannot create new sub-field '
          '{} if parent structure {} is not a `StructuredTensor` that '
          'can contain sub-structures -- it is a `{}`.'.format(
              error_prefix + tuple(path), describe(head), type(child)))

    # FutureWork: optimize by aggregating the recursions, instead of
    # calling one at a time.
    # pylint: disable=protected-access
    merged[head] = child._with_updates_impl(error_prefix + (head,),
                                            [(path[1:], update)], validate)
    # pylint: enable=protected-access

  # TODO(edloper): When validate=True, only validate the modified fields.
  try:
    return StructuredTensor.from_fields(
        merged,
        shape=self.shape,
        row_partitions=self.row_partitions,
        nrows=self.nrows(),
        validate=validate)
  except ValueError as e:
    msg = '`StructuredTensor.with_updates` failed'
    if error_prefix:
      msg = '{} for field {}'.format(msg, error_prefix)
    raise ValueError(msg) from e
def _promote_helper(self, source_path, new_parent_path):
  """Computes the promoted value of a field without storing it anywhere.

  Args:
    source_path: Path of the field to promote within this structured tensor.
    new_parent_path: Path of the structure that will own the promoted field.
      Must be a prefix of `source_path`.

  Returns:
    The value at `source_path`, with the dimensions between the new parent
    and the current parent merged when that is required.

  Raises:
    ValueError: If the ranks of the two parents differ and the rank of the
      field is unknown, so it cannot be decided whether to merge dimensions.
  """
  promoted = self.field_value(source_path)
  target_rank = self.field_value(new_parent_path).rank
  owner_rank = self.field_value(source_path[:-1]).rank
  if target_rank == owner_rank:
    # Both parents have the same rank, so the value is returned unchanged.
    return promoted

  value_rank = promoted.shape.rank
  if value_rank is None:
    raise ValueError('Cannot determine if dimensions should be merged.')

  last_merged_dim = min(owner_rank, value_rank - 1)
  if last_merged_dim > target_rank:
    return _merge_dims_generic(promoted, target_rank, last_merged_dim)
  return promoted
def promote(self, source_path, new_name):
  """Promotes a field, merging dimensions between grandparent and parent.

  >>> d = [
  ...  {'docs': [{'tokens':[1, 2]}, {'tokens':[3]}]},
  ...  {'docs': [{'tokens':[7]}]}]
  >>> st = tf.experimental.StructuredTensor.from_pyval(d)
  >>> st2 =st.promote(('docs','tokens'), 'docs_tokens')
  >>> st2[0]['docs_tokens']
  <tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>
  >>> st2[1]['docs_tokens']
  <tf.Tensor: shape=(1,), dtype=int32, numpy=array([7], dtype=int32)>

  Args:
    source_path: the path of the field or substructure to promote, given as
      a list or tuple of field names; must have length at least 2.
    new_name: the name of the new field (must be a string).

  Returns:
    a modified structured tensor with the new field as a child of the
    grandparent of the source_path.

  Raises:
    ValueError: if source_path is not a list or a tuple or has a length
      less than two, or new_name is not a string, or the rank
      of source_path is unknown and it is needed.
  """
  if not isinstance(new_name, str):
    raise ValueError('new_name is not a string')
  if not isinstance(source_path, (list, tuple)):
    raise ValueError('source_path must be a list or tuple')

  if len(source_path) < 2:
    raise ValueError('source_path must have length at least two')

  # Lists pass the check above, but the path is concatenated with a tuple
  # and then used as a dictionary key below. Both need a tuple: list + tuple
  # raises TypeError and lists are not hashable.
  source_path = tuple(source_path)

  grandparent_path = source_path[:-2]
  new_field = self._promote_helper(source_path, grandparent_path)
  new_path = grandparent_path + (new_name,)
  return self.with_updates({new_path: new_field})
556 #=============================================================================
557 # Properties
558 #=============================================================================
@property
def rank(self):
  """Number of dimensions of this StructuredTensor; never `None`."""
  ragged_shape = self._ragged_shape
  return ragged_shape.rank
@property
def shape(self):
  """Static shape of this StructuredTensor.

  The rank of the returned `TensorShape` is always known; the sizes of
  individual dimensions may be unknown.

  Returns:
    `tf.TensorShape`
  """
  # pylint: disable=protected-access
  static_shape = self._ragged_shape._to_tensor_shape()
  # pylint: enable=protected-access
  return static_shape
577 # TODO(martinz): for backwards compatibility
@property
def _row_partitions(self):
  """Deprecated alias that returns the public `row_partitions` property."""
  partitions = self.row_partitions
  return partitions
583 # TODO(edloper): Make this a func instead of a property? Or make nrows
584 # a property instead of a func? Seems like these should be consistent.
@property
def row_partitions(self):
  """A tuple of `RowPartition`s defining the shape of this `StructuredTensor`.

  The tuple is empty when `self.rank <= 1`.

  When `self.rank > 1`, the `RowPartition`s describe how a flat (1D) list of
  structures is repeatedly partitioned into a higher-dimensional object: the
  flat list is first split into sublists using `row_partitions[-1]`, those
  sublists are then grouped using `row_partitions[-2]`, and so on. The
  examples below show the row partitions of several `StructuredTensor`s that
  each contain 8 copies of the same structure (`x`):

  >>> x = {'a': 1, 'b': ['foo', 'bar', 'baz']}       # shape = [] (scalar)

  >>> s1 = [[x, x, x, x], [x, x, x, x]]              # shape = [2, 4]
  >>> tf.experimental.StructuredTensor.from_pyval(s1).row_partitions
  (tf.RowPartition(row_splits=[0 4 8]),)

  >>> s2 = [[x, x], [x, x], [x, x], [x, x]]          # shape = [4, 2]
  >>> tf.experimental.StructuredTensor.from_pyval(s2).row_partitions
  (tf.RowPartition(row_splits=[0 2 4 6 8]),)

  >>> s3 = [[x, x, x], [], [x, x, x, x], [x]]        # shape = [4, None]
  >>> tf.experimental.StructuredTensor.from_pyval(s3).row_partitions
  (tf.RowPartition(row_splits=[0 3 3 7 8]),)

  >>> s4 = [[[x, x], [x, x]], [[x, x], [x, x]]]      # shape = [2, 2, 2]
  >>> tf.experimental.StructuredTensor.from_pyval(s4).row_partitions
  (tf.RowPartition(row_splits=[0 2 4]),
   tf.RowPartition(row_splits=[0 2 4 6 8]))

  >>> s5 = [[[x, x], [x]], [[x, x]], [[x, x], [x]]]  # shape = [3, None, None]
  >>> tf.experimental.StructuredTensor.from_pyval(s5).row_partitions
  (tf.RowPartition(row_splits=[0 2 3 5]),
   tf.RowPartition(row_splits=[0 2 3 5 7 8]))

  Shapes of nested fields (such as `x['b']` above) are not part of the shape
  of a `StructuredTensor`, so they do not appear in `row_partitions`.

  If this `StructuredTensor` has a ragged shape (i.e., if any of the
  `row_partitions` is not uniform in size), then all fields are encoded as
  either `RaggedTensor`s or `StructuredTensor`s whose outermost `self.rank`
  dimensions are defined by these `RowPartition`s.

  Returns:
    A `tuple` of `RowPartition` objects with length `self.rank - 1`
    (or `0` if `self.rank < 2`).
  """
  if self.rank >= 2:
    return self._ragged_shape._as_row_partitions()  # pylint:disable=protected-access
  return ()
def nrows(self):
  """Returns the size of the outermost dimension (only when rank>0).

  If `self.rank > 1`, this equals the number of rows of the first row
  partition, i.e. `self.nrows() == self.row_partitions[0].nrows()`.
  Otherwise it is the first dimension of the field values.

  Returns:
    A scalar integer `Tensor` (or `None` if `self.rank == 0`).
  """
  return self._ragged_shape[0] if self.rank != 0 else None
def with_shape_dtype(self, dtype: dtypes.DType) -> 'StructuredTensor':
  """Returns this tensor with its ragged shape expressed in `dtype`.

  Args:
    dtype: The dtype the ragged shape should use.

  Returns:
    `self` when the ragged shape already uses `dtype`; otherwise a new
    `StructuredTensor` whose fields were passed through `_fields_with_dtype`
    and whose ragged shape was converted with `with_dtype`.
  """
  current_shape = self._ragged_shape
  if dtype == current_shape.dtype:
    return self
  recast_fields = _fields_with_dtype(self._fields, dtype)
  recast_shape = current_shape.with_dtype(dtype)
  return StructuredTensor(fields=recast_fields, ragged_shape=recast_shape)
def _is_eager(self):
  """True if every component tensor of this value is an eager tensor."""
  for component in nest.flatten(self, expand_composites=True):
    if not isinstance(component, ops.EagerTensor):
      return False
  return True
673 #=============================================================================
674 # Encoding
675 #=============================================================================
def field_names(self):
  """Returns a tuple with the string name of every field in this tensor."""
  names = tuple(self._fields)
  return names
def field_value(self, field_name):
  """Looks up the tensor value stored for a field name or field path.

  If `field_name` is a `string`, it names a field owned directly by this
  `StructuredTensor`. If this `StructuredTensor` has shape `[D1...DN]`, the
  returned tensor has shape `[D1...DN, V1...VM]`, where the slice
  `result[d1...dN]` holds the field value for the structure at
  `self[d1...dN]`.

  If `field_name` is a `list` or `tuple` of `string`, it is a path to a field
  owned by a nested `StructuredTensor`: `struct.field_value((f1, f2, ..., fN))`
  is equivalent to
  `struct.field_value(f1).field_value(f2)....field_value(fN)`.

  Args:
    field_name: `string`, or `list`/`tuple` of `string`: the field whose
      values should be returned.

  Returns:
    `Tensor`, `StructuredTensor`, or `RaggedTensor`.

  Raises:
    KeyError: If the given field_name is not found.
  """
  if not isinstance(field_name, (list, tuple)):
    return self._fields[field_name]

  # Walk the path one name at a time; every step before the last one must
  # land on a nested StructuredTensor.
  node = self
  for part in field_name:
    if not isinstance(node, StructuredTensor):
      raise KeyError('Field path {} not found in {}'.format(
          field_name, self))
    node = node.field_value(part)
  return node
715 #=============================================================================
716 # Operators
717 #=============================================================================
719 # TODO(edloper): Add support for ellipsis and/or newaxis?
def __getitem__(self, key):
  """Returns the piece of this StructuredTensor selected by `key`.

  * If `struct_tensor` is scalar (i.e., a single structure), then
    `struct_tensor[f]` returns the value of field `f` (where `f` must be a
    string).

  * If `struct_tensor` is non-scalar (i.e., a vector or higher-dimensional
    tensor of structures), `struct_tensor[i]` selects an element or slice of
    the tensor using standard Python semantics (e.g., negative values index
    from the end). `i` may have any of the following types:

    * `int` constant
    * `string` constant
    * scalar integer `Tensor`
    * `slice` containing integer constants and/or scalar integer
      `Tensor`s

  #### Multidimensional indexing

  `StructuredTensor` supports multidimensional indexing: `key` may be a
  `tuple` (or `list`) of values that index or slice several dimensions at
  once. For example, if `people` is a vector of structures, each with a
  vector-valued `names` field, then `people[3, 'names', 0]` is equivalent to
  `people[3]['names'][0]`, and `people[:, 'names', :]` returns a (possibly
  ragged) matrix of names with shape `[num_people, num_names_per_person]`.

  Args:
    key: Indicates which piece of the StructuredTensor to return.

  Returns:
    A `Tensor`, `StructuredTensor`, or `RaggedTensor`. An empty `key`
    returns `self`.
  """
  # Normalize the key to a tuple so both helpers can slice it uniformly.
  if not isinstance(key, (list, tuple)):
    key = (key,)
  elif isinstance(key, list):
    key = tuple(key)

  if not key:
    return self

  handler = self._scalar_getitem if self.rank == 0 else self._tensor_getitem
  return handler(key)
def _scalar_getitem(self, key):
  """Indexes a rank-0 StructuredTensor with a nonempty tuple `key`.

  Args:
    key: Tuple whose first element is either a field name or a full slice
      (`:`); the remaining elements are applied to the field value(s).

  Returns:
    For a full slice, a `StructuredTensor` with this tensor's shape whose
    fields were each indexed with the rest of the key; for a field name, that
    field's value indexed with the rest of the key.

  Raises:
    ValueError: If the first element is neither a string nor a full slice.
  """
  head, rest = key[0], key[1:]
  is_full_slice = isinstance(head, slice) and all(
      bound is None for bound in (head.start, head.stop, head.step))
  if is_full_slice:
    sliced = {
        name: value.__getitem__(rest)
        for name, value in self._fields.items()
    }
    return StructuredTensor.from_fields(sliced, self.shape)

  if not isinstance(head, compat.bytes_or_text_types):
    raise ValueError('Key for indexing a StructuredTensor must be a '
                     "string or a full slice (':')")

  return self._fields[head].__getitem__(rest)
def _tensor_getitem(self, key):
  """Indexes a StructuredTensor of rank > 0 with a nonempty tuple `key`.

  When `key` has at most `rank` elements, it is applied to every field and
  the result is a new `StructuredTensor`. Otherwise `key[rank]` must name a
  field, and the remaining elements are applied to that field's value.

  Args:
    key: Tuple of indices, slices and at most one field name.

  Returns:
    A `StructuredTensor`, or the indexed value of a single field.

  Raises:
    ValueError: If `key` contains `None`/`tf.newaxis`, an ellipsis or another
      unsupported index, or if `key[rank]` is not a string.
  """
  rank = self.rank
  if len(key) > rank:
    field_key = key[rank]
    if not isinstance(field_key, compat.bytes_or_text_types):
      # TODO(edloper): Also support full slice here?
      raise ValueError('Key for indexing a StructuredTensor must be a string')
    return self._fields[field_key].__getitem__(key[:rank] + key[rank + 1:])

  # Index every field first; any error raised by a field takes precedence
  # over the key checks below.
  new_fields = {
      name: value.__getitem__(key) for name, value in self._fields.items()
  }

  dims = self.shape.as_list()
  dropped_axes = set()  # Axes removed by an integer or tensor index.
  for axis, k in enumerate(key):
    if isinstance(k, slice):
      if any(bound is not None for bound in (k.start, k.stop, k.step)):
        # TODO(edloper): Better static shape analysis here.
        dims[axis] = None
    elif isinstance(k, (int, ops.Tensor)):
      dropped_axes.add(axis)
    elif k is None:
      raise ValueError('Slicing not supported for tf.newaxis')
    else:
      # Ellipsis, tf.newaxis:
      raise ValueError('Slicing not supported for %r' % k)

  result_shape = [
      dim for axis, dim in enumerate(dims) if axis not in dropped_axes
  ]
  return StructuredTensor.from_fields(new_fields, result_shape)
def __repr__(self):
  """Returns a multi-line string listing the sorted fields and the shape."""
  # NOTE(review): whitespace inside the string literals below may have been
  # collapsed by the tool that produced this listing -- confirm the exact
  # indentation strings against the upstream file before editing them.
  fields = sorted(self._fields.items())
  # Re-indent continuation lines of field values that span several lines.
  fields = ((k, str(v).replace('\n', '\n ')) for k, v in fields)
  fields = ('"{}": {}'.format(k, v) for k, v in fields)
  dict_repr = ',\n '.join(fields)
  return ('<StructuredTensor(\n'
          ' fields={\n'
          ' %s},\n'
          ' shape=%s)>' % (dict_repr, self.shape))
815 #=============================================================================
816 # Conversion
817 #=============================================================================
def to_pyval(self):
  """Returns this StructuredTensor as a nested Python dict or list of dicts.

  The conversion works as follows:

  * A `StructuredTensor` with `rank=0` becomes a dictionary with one entry
    per field, keyed by field name.  Each field value is converted:

    * Eager `Tensor` fields are converted through `numpy()`; array results
      become (nested) lists via `tolist()`.
    * `RaggedTensor` fields are converted with `to_list()`.
    * `StructuredTensor` fields are converted recursively with `to_pyval`.

  * A `StructuredTensor` with `rank>0` becomes a nested python `list`
    containing one such dictionary per structure.

  Requires that all fields are Eager tensors.

  >>> tf.experimental.StructuredTensor.from_fields(
  ...     {'a': [1, 2, 3]}, [3]).to_pyval()
  [{'a': 1}, {'a': 2}, {'a': 3}]

  Note that `StructuredTensor.from_pyval(pyval).to_pyval() == pyval`.

  Returns:
    A nested Python dict or list of dicts.

  Raises:
    ValueError: If this StructuredTensor is not eager.
  """
  if not self._is_eager():
    raise ValueError(
        'StructuredTensor.to_pyval() is only supported in eager mode.')

  # Build the "field-major" encoding: field name -> python value.
  pyvals = {}
  for name, field in self._fields.items():
    if isinstance(field, ops.EagerTensor):
      field = field.numpy()
    if isinstance(field, np.ndarray):
      field = field.tolist()
    elif isinstance(field, ragged_tensor.RaggedTensor):
      field = field.to_list()
    elif isinstance(field, StructuredTensor):
      field = field.to_pyval()
    # TODO(edloper): Throw an exception if value is an unexpected type.
    pyvals[name] = field

  if len(self.shape) == 0:  # pylint: disable=g-explicit-length-test
    return pyvals

  # rank>0: re-group from dict-of-list to list-of-dict.
  if not pyvals:
    # Special case: StructuredTensors with no fields.
    return _empty_dict_pylist_from_row_partitions(self.row_partitions,
                                                  self.nrows())
  return _pyval_field_major_to_node_major(
      list(pyvals.keys()), list(pyvals.values()), self.rank)
@classmethod
def from_pyval(cls, pyval, typespec=None):
  """Builds a StructuredTensor from a nested Python structure.

  >>> tf.experimental.StructuredTensor.from_pyval(
  ...     {'a': [1, 2, 3], 'b': [[4, 5], [6, 7]]})
  <StructuredTensor(
    fields={
      "a": tf.Tensor([1 2 3], shape=(3,), dtype=int32),
      "b": <tf.RaggedTensor [[4, 5], [6, 7]]>},
    shape=())>

  Note that `StructuredTensor.from_pyval(pyval).to_pyval() == pyval`.

  Args:
    pyval: The nested Python structure that should be used to create the new
      `StructuredTensor`.
    typespec: A `StructuredTensor.Spec` specifying the expected type for each
      field. If not specified, then all nested dictionaries are turned into
      StructuredTensors, and all nested lists are turned into Tensors (if
      rank<2) or RaggedTensors (if rank>=2).

  Returns:
    A `StructuredTensor`.
  """
  # Conversion starts at the root, i.e. with an empty field path.
  root_path = ()
  return cls._from_pyval(pyval, typespec, root_path)
@classmethod
def _from_pyval(cls, pyval, typespec, path_so_far):
  """Dispatches `pyval` to the converter matching its python type.

  Args:
    pyval: The nested Python structure that should be used to create the new
      `StructuredTensor`.
    typespec: A `StructuredTensor.Spec` specifying the expected type for each
      field. If not specified, then all nested dictionaries are turned into
      StructuredTensors, and all nested lists are turned into Tensors (if
      rank<2) or RaggedTensors (if rank>=2).
    path_so_far: the path of fields that led here (for error messages).

  Returns:
    The value produced by the selected converter: `_from_pydict` for dicts,
    `_from_pylist_of_dict` or `_from_pylist_of_value` for lists and tuples,
    and `_from_pyscalar` for everything else.
  """
  if isinstance(pyval, dict):
    return cls._from_pydict(pyval, typespec, path_so_far)

  if not isinstance(pyval, (list, tuple)):
    return cls._from_pyscalar(pyval, typespec, path_so_far)

  # A list/tuple: check whether it contains dictionaries, and at what depth.
  keys = set()
  rank = _pyval_find_struct_keys_and_depth(pyval, keys)
  if rank is None:
    return cls._from_pylist_of_value(pyval, typespec, path_so_far)
  return cls._from_pylist_of_dict(pyval, keys, rank, typespec, path_so_far)
@classmethod
def _from_pydict(cls, pyval, typespec, path_so_far):
  """Converts python dictionary `pyval` to a StructuredTensor with rank=0.

  Args:
    pyval: A python `dict` mapping field names to nested python values.
    typespec: Optional `StructuredTensor.Spec` describing the expected
      result.  When given, it must have rank 0 and exactly the same field
      names as `pyval`.
    path_so_far: the path of fields that led here (for error messages).

  Returns:
    A `StructuredTensor` with `shape=()`.

  Raises:
    ValueError: If `typespec` is given but is not a rank-0
      `StructuredTensor.Spec` whose field names match the keys of `pyval`.
  """
  if typespec is None:
    fields = {
        k: cls._from_pyval(v, None, path_so_far + (k,))
        for (k, v) in pyval.items()
    }
  else:
    # Check the spec type *before* touching Spec-only attributes; otherwise
    # a spec of the wrong type surfaces as an AttributeError rather than the
    # ValueError below.
    if not isinstance(typespec, StructuredTensor.Spec):
      raise ValueError('Value at %r does not match typespec: %r vs %r' %
                       (path_so_far, pyval, typespec))
    spec_shape = typespec._shape  # pylint: disable=protected-access
    field_specs = typespec._field_specs  # pylint: disable=protected-access
    if not (spec_shape.rank == 0 and set(pyval) == set(field_specs)):
      raise ValueError('Value at %r does not match typespec: %r vs %r' %
                       (path_so_far, pyval, typespec))
    fields = {
        k: cls._from_pyval(v, field_specs[k], path_so_far + (k,))
        for (k, v) in pyval.items()
    }
  return StructuredTensor.from_fields(fields=fields, shape=(), validate=False)
@classmethod
def _from_pylist_of_dict(cls, pyval, keys, rank, typespec, path_so_far):
  """Converts python list `pyval` to a StructuredTensor with rank>=1.

  Args:
    pyval: A (possibly nested) python list or tuple of dictionaries.
    keys: The field names found in the dictionaries of `pyval`.
    rank: The list depth at which the dictionaries occur.
    typespec: Optional `StructuredTensor.Spec` describing the expected
      result.  Its fields must include every key in `keys`.
    path_so_far: the path of fields that led here (for error messages).

  Returns:
    A `StructuredTensor`.

  Raises:
    ValueError: If `typespec` is not a `StructuredTensor.Spec`, if `pyval`
      has keys that `typespec` lacks, if the rank of `typespec` is smaller
      than `rank`, or if building the result fails.
  """
  # Gather the values field by field ("field-major" encoding).
  fields = dict((key, []) for key in keys)
  for child in pyval:
    _pyval_update_fields(child, fields, 1)

  if typespec is None:
    shape = tensor_shape.TensorShape([None] * rank)
    for (key, target) in fields.items():
      fields[key] = cls._from_pyval(target, None, path_so_far + (key,))
  else:
    # Check the spec type *before* touching Spec-only attributes; otherwise
    # a spec of the wrong type surfaces as an AttributeError rather than the
    # ValueError below.
    if not isinstance(typespec, StructuredTensor.Spec):
      raise ValueError('Value at %r does not match typespec: %r vs %r' %
                       (path_so_far, pyval, typespec))
    field_specs = typespec._fields  # pylint: disable=protected-access
    if set(fields) - set(field_specs):
      raise ValueError('Value at %r does not match typespec: %r vs %r' %
                       (path_so_far, pyval, typespec))
    shape = typespec._shape  # pylint: disable=protected-access
    if shape.rank < rank:
      raise ValueError('Value at %r does not match typespec (rank mismatch): '
                       '%r vs %r' % (path_so_far, pyval, typespec))
    # Fields named by the spec but absent from `pyval` are converted from an
    # empty list.
    for (key, spec) in field_specs.items():
      fields[key] = cls._from_pyval(
          fields.get(key, []), spec, path_so_far + (key,))

  try:
    if not fields and typespec is None:
      # TODO(b/183245576): handle cases where the typespec is known
      # but the dictionary is empty.
      return StructuredTensor._from_pylist_of_empty_dict(pyval, rank)  # pylint: disable=protected-access
    return StructuredTensor.from_fields(
        fields=fields, shape=shape, validate=False)
  except Exception as exc:
    raise ValueError('Error parsing path %r' % (path_so_far,)) from exc
@classmethod
def _from_pylist_of_empty_dict(cls, pyval, rank):
  """Converts a pylist of empty dictionaries to a field-less StructuredTensor.

  Args:
    pyval: A nested python list whose leaves are empty dictionaries (or a
      single dictionary when `rank` is 0).
    rank: The list depth at which the dictionaries occur.

  Returns:
    A `StructuredTensor` with no fields, or `None` if `rank` is negative.
  """
  if rank == 0:
    return StructuredTensor.from_fields(fields={}, shape=(), validate=False)

  if rank == 1:
    num_rows = len(pyval)
    return StructuredTensor.from_fields(
        fields={}, shape=(num_rows,), nrows=num_rows)

  if rank > 1:
    # With no fields to carry the ragged structure, derive the row
    # partitions from a RaggedTensor in which each dict is replaced by 0.
    zeros = ragged_factory_ops.constant(_dicts_to_zeros(pyval))
    num_rows = len(pyval)
    outer_dims = [len(pyval)]
    inner_dims = [None] * (rank - 1)
    return StructuredTensor.from_fields(
        fields={},
        shape=tensor_shape.TensorShape(outer_dims + inner_dims),
        row_partitions=zeros._nested_row_partitions,  # pylint:disable=protected-access
        nrows=num_rows)

  return None
@classmethod
def _from_pylist_of_value(cls, pyval, typespec, path_so_far):
  """Converts a python list that contains no dictionaries.

  Args:
    pyval: A (possibly nested) python list or tuple without dictionaries.
    typespec: `None`, a `TensorSpec`, a `RaggedTensorSpec`, or a
      `StructuredTensor.Spec` (the latter only for nested empty lists).
    path_so_far: the path of fields that led here (for error messages).

  Returns:
    The result of `ragged_factory_ops.constant` when `typespec` is `None` or
    a `RaggedTensorSpec`; a constant `Tensor` for a `TensorSpec`; or a
    `StructuredTensor` for a `StructuredTensor.Spec`.

  Raises:
    ValueError: If `pyval` cannot be converted or does not match `typespec`.
  """
  if typespec is None:
    try:
      return ragged_factory_ops.constant(pyval)
    except Exception as exc:
      raise ValueError('Error parsing path %r' % (path_so_far,)) from exc

  if isinstance(typespec, tensor_spec.TensorSpec):
    try:
      tensor = constant_op.constant(pyval, typespec.dtype)
    except Exception as exc:
      raise ValueError('Error parsing path %r' % (path_so_far,)) from exc
    if not typespec.shape.is_compatible_with(tensor.shape):
      raise ValueError('Value at %r does not match typespec: %r vs %r' %
                       (path_so_far, typespec, pyval))
    return tensor

  if isinstance(typespec, ragged_tensor.RaggedTensorSpec):
    # pylint: disable=protected-access
    try:
      ragged_rank = typespec._ragged_rank
      return ragged_factory_ops.constant(
          pyval,
          dtype=typespec._dtype,
          ragged_rank=ragged_rank,
          row_splits_dtype=typespec._row_splits_dtype,
          inner_shape=typespec._shape[ragged_rank + 1:])
    except Exception as exc:
      raise ValueError('Error parsing path %r' % (path_so_far,)) from exc

  if isinstance(typespec, StructuredTensor.Spec):
    # Only nested empty lists can stand in for a StructuredTensor here.
    empty_rank = _pyval_empty_list_depth(pyval)
    if empty_rank is None:
      raise ValueError('Value at %r does not match typespec: %r vs %r' %
                       (path_so_far, typespec, pyval))
    return cls._from_pylist_of_dict(pyval, set(), empty_rank, typespec,
                                    path_so_far)

  raise ValueError('Value at %r does not match typespec: %r vs %r' %
                   (path_so_far, typespec, pyval))
@classmethod
def _from_pyscalar(cls, pyval, typespec, path_so_far):
  """Converts python scalar value `pyval` to a Tensor.

  Args:
    pyval: A python value that is neither a dict, a list, nor a tuple.
    typespec: `None`, or a `TensorSpec` with rank 0 giving the dtype.
    path_so_far: the path of fields that led here (for error messages).

  Returns:
    A constant `Tensor`.

  Raises:
    ValueError: If `typespec` is given but is not a rank-0 `TensorSpec`, or
      if `typespec` is `None` and the conversion fails.
  """
  if typespec is None:
    try:
      return constant_op.constant(pyval)
    except Exception as exc:
      raise ValueError('Error parsing path %r' % (path_so_far,)) from exc

  is_scalar_spec = (
      isinstance(typespec, tensor_spec.TensorSpec) and
      typespec.shape.rank == 0)
  if not is_scalar_spec:
    raise ValueError('Value at %r does not match typespec: %r vs %r' %
                     (path_so_far, typespec, pyval))
  # TODO(edloper): Check that typespec.shape matches.
  return constant_op.constant(pyval, typespec.dtype)
1059 #=============================================================================
1060 # Transforms
1061 #=============================================================================
1063 # TODO(edloper): Add a 'validate' option here?
1064 # TODO(edloper): Unify nomenclature with RaggedTensor. Should RaggedTensor
1065 # have a partition_outer_dimension method?
def partition_outer_dimension(self, row_partition):
  """Splits the outer dimension of this StructuredTensor into two.

  The result holds the same values as `self`, but its outer dimension is
  divided into two (possibly ragged) dimensions as described by
  `row_partition`.  This StructuredTensor must have an outer dimension
  (i.e., `self.shape.rank > 0`).

  >>> st = tf.experimental.StructuredTensor.from_pyval(
  ...     [{'foo': 12}, {'foo': 33}, {'foo': 99}])
  >>> partition = RowPartition.from_row_lengths([2, 0, 1])
  >>> st.partition_outer_dimension(partition)
  <StructuredTensor(
    fields={
      "foo": <tf.RaggedTensor [[12, 33], [], [99]]>},
    shape=(3, None))>

  Args:
    row_partition: A `RowPartition`.

  Returns:
    A `StructuredTensor` with rank `self.rank + 1`.

  Raises:
    TypeError: If `row_partition` is not a `RowPartition`.
    ValueError: If this StructuredTensor has rank 0.
  """
  partition_ok = isinstance(row_partition, RowPartition)
  if not partition_ok:
    raise TypeError('row_partition must be a RowPartition.')

  own_shape = self.shape
  if own_shape.rank == 0:
    raise ValueError('Shape %s must have rank at least 1' % self.shape)

  return _partition_outer_dimension(self, row_partition)
def merge_dims(self, outer_axis, inner_axis):
  """Merges outer_axis...inner_axis into a single dimension.

  Returns a copy of this StructuredTensor with the specified range of
  dimensions flattened into a single dimension, with elements in row-major
  order.  If `outer_axis` and `inner_axis` refer to the same dimension there
  is nothing to merge, and this StructuredTensor is returned unchanged.

  >>> st = tf.experimental.StructuredTensor.from_pyval(
  ...     [[{'foo': 12}, {'foo': 33}], [], [{'foo': 99}]])
  >>> st.merge_dims(0, 1)
  <StructuredTensor(
    fields={
      "foo": tf.Tensor([12 33 99], shape=(3,), dtype=int32)},
    shape=(3,))>

  Args:
    outer_axis: `int`: The first dimension in the range of dimensions to
      merge. May be negative (to index from the last dimension).
    inner_axis: `int`: The last dimension in the range of dimensions to merge.
      May be negative (to index from the last dimension).

  Returns:
    A copy of this tensor, with the specified dimensions merged into a
    single dimension.  The shape of the returned tensor will be
    `self.shape[:outer_axis] + [N] + self.shape[inner_axis + 1:]`, where `N`
    is the total number of slices in the merged dimensions.

  Raises:
    ValueError: If `outer_axis` is greater than `inner_axis` (after
      normalizing negative values).
  """
  outer_axis = array_ops.get_positive_axis(
      outer_axis,
      self.shape.rank,
      axis_name='outer_axis',
      ndims_name='rank(self)')
  inner_axis = array_ops.get_positive_axis(
      inner_axis,
      self.shape.rank,
      axis_name='inner_axis',
      ndims_name='rank(self)')
  if not outer_axis <= inner_axis:
    raise ValueError('Expected outer_axis (%d) to be less than or equal to '
                     'inner_axis (%d)' % (outer_axis, inner_axis))
  if outer_axis == inner_axis:
    # A single-dimension range is a no-op.  The `_merge_dims` helper asserts
    # `outer_axis < inner_axis`, so this case must not be forwarded to it.
    return self
  return _merge_dims(self, outer_axis, inner_axis)
class Spec:
  """A spec for StructuredTensor."""

  def __validate__(self):
    """Checks that the spec carries a ragged shape."""
    assert self._ragged_shape is not None

  @classmethod
  def _from_fields_and_rank(cls, fields, rank):
    """Creates a spec of a StructuredTensor with fields and rank.

    Args:
      fields: A mapping from field name to the spec of that field.
      rank: The rank of the StructuredTensor being described.

    Returns:
      A `StructuredTensor.Spec` whose shape is the merge of every field's
      shape truncated to `rank`.

    Raises:
      ValueError: If a field spec cannot be converted to a shape spec, or if
        a field has a known rank smaller than `rank`.
    """
    merged_shape = None
    for name, field_spec in fields.items():
      full_shape = _dynamic_ragged_shape_spec_from_spec(field_spec)
      if full_shape is None:
        raise ValueError(f'Cannot convert spec of {name}.')

      full_rank = full_shape.rank
      if full_rank is not None and full_rank < rank:
        raise ValueError(f'Rank of field {name} is {full_rank}, '
                         f'but must be at least {rank}.')

      # Only the leading `rank` dimensions are shared with the struct.
      truncated = full_shape._truncate(rank)  # pylint: disable=protected-access
      if merged_shape is None:
        merged_shape = truncated
      else:
        merged_shape = merged_shape._merge_with(truncated)  # pylint: disable=protected-access
    return StructuredTensor.Spec(_ragged_shape=merged_shape, _fields=fields)

  @classmethod
  def _from_shape(
      cls, shape: dynamic_ragged_shape.DynamicRaggedShape
  ) -> 'StructuredTensor.Spec':
    """Creates the spec of an empty (field-less) StructuredTensor."""
    no_fields = {}
    return StructuredTensor.Spec(_ragged_shape=shape, _fields=no_fields)

  # For backwards compatibility
  @property
  def _shape(self) -> tensor_shape.TensorShape:
    """The ragged shape converted to a `TensorShape`."""
    ragged_shape = self._ragged_shape
    return ragged_shape._to_tensor_shape()  # pylint: disable=protected-access

  # For backwards compatibility
  @property
  def _field_specs(self) -> Dict[str, type_spec.TypeSpec]:
    """Alias of `_fields`."""
    return self._fields

  # For backwards compatibility
  @property
  def shape(self) -> tensor_shape.TensorShape:
    """Alias of `_shape`."""
    return self._shape

  # For backwards compatibility
  @property
  def rank(self):
    """The rank of the ragged shape."""
    ragged_shape = self._ragged_shape
    return ragged_shape.rank
# Regular expression used to determine whether a string is a valid field name:
# an ASCII letter followed by any number of ASCII letters, digits, or
# underscores.
# Note: we plan to relax (or possibly eliminate) this in the future; you
# should not rely on the fact that some field names are currently disallowed.
_FIELD_NAME_RE = re.compile('^[a-zA-Z][a-zA-Z0-9_]*$')
1194#=============================================================================
1195# Helper functions
1196#=============================================================================
1197# TODO(edloper): Move some of these helpers to row_partition.py?
def _convert_to_structured_field_value(value):
  """Converts `value` to a Tensor, RaggedTensor, or StructuredTensor.

  Args:
    value: The candidate field value.

  Returns:
    `value` itself if it is already a `Tensor`, `RaggedTensor`,
    `StructuredTensor`, or other `ExtensionType`; otherwise the result of
    converting it to a `Tensor` or `RaggedTensor`.

  Raises:
    TypeError: If `value` cannot be converted to a tensor.
  """
  if isinstance(value,
                (ops.Tensor, ragged_tensor.RaggedTensor, StructuredTensor)):
    return value
  elif ragged_tensor.is_ragged(value):
    return ragged_tensor.convert_to_tensor_or_ragged_tensor(value)
  elif isinstance(value, extension_type.ExtensionType):
    return value
  else:
    try:
      return ops.convert_to_tensor(value)
    except (ValueError, TypeError) as e:
      # Wrap `value` in a 1-tuple: a bare tuple on the right of `%` would be
      # unpacked as the format arguments and garble (or break) the message.
      raise TypeError('Unexpected type for value in `fields`: %r' %
                      (value,)) from e
def _find_shape_dtype(
    fields: Mapping[str, _FieldValue], nrows: Optional[ops.Tensor],
    row_partitions: Optional[Sequence[RowPartition]]) -> dtypes.DType:
  """Picks one dtype that is consistent with fields, nrows, & row_partitions.

  In the future, the default will switch from int64 to int32, but for now,
  we stick with int64.

  Args:
    fields: the fields of the StructuredTensor.
    nrows: the nrows of the StructuredTensor
    row_partitions: the row_partitions of the StructuredTensor.

  Returns:
    If anything requires int64, then return int64.
    If int32 is explicitly specified, return int32. Otherwise, return int64.
  """
  observed = [_field_shape_dtype(field) for field in fields.values()]
  if isinstance(nrows, ops.Tensor):
    observed.append(nrows.dtype)
  if row_partitions is not None:
    observed.extend(partition.dtype for partition in row_partitions)

  # int64 wins over int32 whenever both are present.
  for candidate in (dtypes.int64, dtypes.int32):
    if candidate in observed:
      return candidate

  # TODO(martinz): Eventually, shift this to tf.int32.
  return dtypes.int64
def _merge_nrows(nrows, static_nrows, value, dtype, validate):
  """Merges `nrows` with `nrows(value)`.

  Checks that `value` has the expected number of rows (`nrows`).  When both
  static row counts are known they are compared directly; otherwise, if
  `validate` is true, an assertion op checking the dynamic values is added.

  Args:
    nrows: scalar integer Tensor, or `None` if not yet known.
    static_nrows: tf.Dimension: static value of nrows, if known.
    value: Tensor or RaggedTensor or StructuredTensor
    dtype: dtype for `nrows`.
    validate: bool -- whether to add validation ops.

  Returns:
    A tuple `(nrows, static_nrows)`.

  Raises:
    ValueError: If the static row counts are both known and incompatible.
  """
  static_value_nrows = tensor_shape.dimension_at_index(value.shape, 0)
  value_nrows = (
      array_ops.shape(value, out_type=dtype)[0]
      if isinstance(value, ops.Tensor) else value.nrows())

  if nrows is None:
    merged_nrows = value_nrows
  elif (static_value_nrows.value is not None and
        static_nrows.value is not None):
    if not static_value_nrows.is_compatible_with(static_nrows):
      raise ValueError('fields have incompatible nrows')
    merged_nrows = value_nrows  # No need to add an assertion op.
  elif validate:
    row_count_check = check_ops.assert_equal(
        nrows, value_nrows, message='fields have incompatible nrows')
    merged_nrows = control_flow_ops.with_dependencies([row_count_check],
                                                      nrows)
  else:
    merged_nrows = nrows

  merged_static = static_nrows._merge_with(static_value_nrows)  # pylint: disable=protected-access
  return merged_nrows, merged_static
def _merge_row_partitions(row_partitions, value, rank, dtype, validate):
  """Merges `row_partitions` with `row_partitions(value)`.

  Args:
    row_partitions: The row partitions collected so far, or `None`.
    value: A `Tensor`, `RaggedTensor`, or `StructuredTensor`.
    rank: The rank of the enclosing StructuredTensor; `rank - 1` partitions
      are taken from `value`.
    dtype: The dtype used when partitions are derived from a shape.
    validate: Passed to `RowPartition._merge_precomputed_encodings`.

  Returns:
    A tuple of `rank - 1` row partitions.
  """
  if isinstance(value, ops.Tensor):
    from_value = _row_partitions_for_tensor(value, rank, dtype)
  elif isinstance(value, ragged_tensor.RaggedTensor):
    from_value = _row_partitions_for_ragged_tensor(value, rank, dtype)
  else:
    assert isinstance(value, StructuredTensor), type(value)
    from_value = value.row_partitions[:rank - 1]

  assert len(from_value) == rank - 1

  if row_partitions is None:
    return tuple(from_value)

  return tuple(
      existing._merge_precomputed_encodings(incoming, validate)  # pylint: disable=protected-access
      for existing, incoming in zip(row_partitions, from_value))
def _row_partitions_for_tensor(value, rank, dtype):
  """Returns the row partitions for a tf.Tensor.

  Args:
    value: The tensor whose shape is read.
    rank: The number of dimensions to generate row partitions for.
    dtype: The dtype requested for the shape tensor.

  Returns:
    The uniform row partitions built from the dynamic shape of `value`.
  """
  return _row_partitions_for_uniform_shape(
      array_ops.shape(value, out_type=dtype), rank)
def _row_partitions_for_ragged_tensor(value, rank, dtype):
  """Returns the row partitions for a tf.RaggedTensor.

  Args:
    value: The ragged tensor.
    rank: The number of dimensions to generate row partitions for; must be
      greater than 1.
    dtype: The dtype used for partitions derived from `value.flat_values`.

  Returns:
    `rank - 1` row partitions.
  """
  assert rank > 1
  wanted = rank - 1
  partitions = value._nested_row_partitions[:wanted]  # pylint: disable=protected-access
  if len(partitions) < wanted:
    # The ragged dimensions do not cover all requested dimensions; take the
    # remainder from the uniform dimensions of the flat values.
    remaining_rank = rank - len(partitions)
    partitions += _row_partitions_for_tensor(value.flat_values,
                                             remaining_rank, dtype)
  assert len(partitions) == wanted
  return partitions
def _row_partitions_for_uniform_shape(shape, rank):
  """Returns row partitions for the given shape Tensor.

  Args:
    shape: A vector describing a uniform shape.
    rank: The number of dimensions to generate row partitions for

  Returns:
    A tuple of (rank-1) `RowPartition`s with uniform row length.
  """
  # cumulative_sizes[i] is the product of the first i + 1 dimension sizes.
  cumulative_sizes = math_ops.cumprod(shape[:rank])
  partitions = []
  for axis in range(rank - 1):
    partitions.append(
        RowPartition.from_uniform_row_length(
            uniform_row_length=shape[axis + 1],
            nvals=cumulative_sizes[axis + 1],
            nrows=cumulative_sizes[axis]))
  return tuple(partitions)
1347def _pyval_field_major_to_node_major(keys, values, depth):
1348 """Regroup each field (k, v) from dict-of-list to list-of-dict.
1350 Given a "field-major" encoding of the StructuredTensor (which maps each key to
1351 a single nested list containing the values for all structs), return a
1352 corresponding "node-major" encoding, consisting of a nested list of dicts.
1354 Args:
1355 keys: The field names (list of string). Must not be empty.
1356 values: The field values (list of python values). Must have the same length
1357 as `keys`.
1358 depth: The list depth at which dictionaries should be created.
1360 Returns:
1361 A nested list of dict, with depth `depth`.
1362 """
1363 assert keys
1364 if depth == 0:
1365 return dict(zip(keys, values))
1366 nvals = len(values[0])
1367 assert all(nvals == len(values[i]) for i in range(1, len(values)))
1368 return [
1369 _pyval_field_major_to_node_major(keys, value_slice, depth - 1)
1370 for value_slice in zip(*values)
1371 ]
1374def _empty_dict_pylist_from_row_partitions(row_partitions, nrows):
1375 """Returns a python list of empty dicts from the given row partitions.
1377 Args:
1378 row_partitions: The row-partitions describing the ragged shape of the
1379 result.
1380 nrows: The number of rows in the outermost row-partition. (Or if
1381 `len(row_partitions)==0`, then the number of empty dicts to return.)
1383 Returns:
1384 A nested python list whose leaves (if any) are empty python dicts.
1385 """
1386 if not row_partitions:
1387 return [{} for _ in range(nrows)]
1388 else:
1389 values = _empty_dict_pylist_from_row_partitions(
1390 row_partitions[1:], row_partitions[0].row_splits()[-1])
1391 splits = row_partitions[0].row_splits()
1392 return [values[splits[i]:splits[i + 1]] for i in range(len(splits) - 1)]
1395def _pyval_find_struct_keys_and_depth(pyval, keys):
1396 """Finds the keys & depth of nested dictionaries in `pyval`.
1398 Args:
1399 pyval: A nested structure of lists, tuples, and dictionaries.
1400 keys: (output parameter) A set, which will be updated with any keys that are
1401 found in the nested dictionaries.
1403 Returns:
1404 The nesting depth of dictionaries in `pyval`, or `None` if `pyval` does
1405 not contain any dictionaries.
1406 Raises:
1407 ValueError: If dictionaries have inconsistent depth.
1408 """
1409 if isinstance(pyval, dict):
1410 keys.update(pyval.keys())
1411 return 0
1412 elif isinstance(pyval, (list, tuple)):
1413 depth = None
1414 for child in pyval:
1415 child_depth = _pyval_find_struct_keys_and_depth(child, keys)
1416 if child_depth is not None:
1417 if depth is None:
1418 depth = child_depth + 1
1419 elif depth != child_depth + 1:
1420 raise ValueError('Inconsistent depth of dictionaries')
1421 return depth
1422 else:
1423 return None
1426def _pyval_update_fields(pyval, fields, depth):
1427 """Append the field values from `pyval` to `fields`.
1429 Args:
1430 pyval: A python `dict`, or nested list/tuple of `dict`, whose value(s)
1431 should be appended to `fields`.
1432 fields: A dictionary mapping string keys to field values. Field values
1433 extracted from `pyval` are appended to this dictionary's values.
1434 depth: The depth at which `pyval` should be appended to the field values.
1435 """
1436 if not isinstance(pyval, (dict, list, tuple)):
1437 raise ValueError('Expected dict or nested list/tuple of dict')
1439 for (key, target) in fields.items():
1440 for _ in range(1, depth):
1441 target = target[-1]
1442 target.append(pyval[key] if isinstance(pyval, dict) else [])
1444 if isinstance(pyval, (list, tuple)):
1445 for child in pyval:
1446 _pyval_update_fields(child, fields, depth + 1)
1449def _pyval_empty_list_depth(pyval):
1450 """Find the max depth for nested empty lists.
1452 Args:
1453 pyval: A nested python list.
1455 Returns:
1456 The maximum depth of empty lists in `pyval`, or None if `pyval` contains
1457 anything other than nested empty lists.
1458 """
1459 if isinstance(pyval, list):
1460 if not pyval:
1461 return 1
1462 depths = [_pyval_empty_list_depth(v) for v in pyval]
1463 if any(depth is None for depth in depths):
1464 return None
1465 else:
1466 return max(depths) + 1
1467 else:
1468 return None
def _replace_row_partitions(value, new_partitions):
  """Updates `value` to use `new_partitions` as its (outer) row partitions.

  This is used to ensure that all fields in a `StructuredTensor` use identical
  `RowPartition` objects for the shared dimensions. In particular,
  `StructuredTensor.from_fields` first merges all of the row partitions from
  any fields, and then replaces the outer row partitions of all fields with
  the merged row partitions (using this function).

  Args:
    value: A `Tensor`, `RaggedTensor`, or `StructuredTensor`.
    new_partitions: A list of row-partitions that should be used by `value`.
      Must be equivalent to `value`'s current row partitions.

  Returns:
    A value that is equivalent to `value`, where outer row partitions have been
    replaced by `new_partitions`.
  """
  # Plain tensors carry no row partitions, and an empty list replaces nothing.
  if isinstance(value, ops.Tensor) or not new_partitions:
    return value

  if isinstance(value, ragged_tensor.RaggedTensor):
    outer_partition = new_partitions[0]
    inner_values = _replace_row_partitions(value.values, new_partitions[1:])
    return ragged_tensor.RaggedTensor._from_row_partition(  # pylint: disable=protected-access
        values=inner_values, row_partition=outer_partition)

  assert isinstance(value, StructuredTensor)
  replaced_fields = {
      name: _replace_row_partitions(field, new_partitions)
      for name, field in value._fields.items()  # pylint: disable=protected-access
  }
  shape = value.shape
  nrows = value.nrows()
  # Partitions beyond those being replaced are kept as they are.
  kept = value.row_partitions[len(new_partitions):]
  return StructuredTensor._old_init(  # pylint: disable=protected-access
      fields=replaced_fields,
      shape=shape,
      nrows=nrows,
      row_partitions=tuple(new_partitions) + tuple(kept))
def _partition_outer_dimension(value, row_partition):
  """Partitions the outer dimension of `value` using `row_partition`.

  Examples:

    >>> partition = RowPartition.from_row_lengths([2, 0, 1])
    >>> _partition_outer_dimension(tf.constant([1, 2, 3]), partition)
    <tf.RaggedTensor [[1, 2], [], [3]]>

    >>> struct_value = tf.experimental.StructuredTensor.from_pyval(
    ...     [{'x': 1}, {'x': 2}, {'x': 3}])
    >>> _partition_outer_dimension(struct_value, partition)
    <StructuredTensor(
      fields={
        "x": <tf.RaggedTensor [[1, 2], [], [3]]>},
      shape=(3, None))>

  Args:
    value: Tensor, RaggedTensor, or StructuredTensor
    row_partition: RowPartition

  Returns:
    A value with the same type as `value`, where
    `result.rank = value.rank + 1`.
  """
  has_uniform_rows = row_partition.uniform_row_length() is not None

  if isinstance(value, ops.Tensor) and has_uniform_rows:
    # A dense tensor with a uniform partition stays dense: reshape it.
    leading_dims = [row_partition.nrows(), row_partition.uniform_row_length()]
    trailing_dims = array_ops.shape(value, out_type=row_partition.dtype)[1:]
    new_shape = array_ops.concat([leading_dims, trailing_dims], axis=0)
    return array_ops.reshape(value, new_shape)

  if isinstance(value, (ops.Tensor, ragged_tensor.RaggedTensor)):
    return ragged_tensor.RaggedTensor._from_row_partition(  # pylint: disable=protected-access
        value, row_partition)

  assert isinstance(value, StructuredTensor)
  outer_dims = tensor_shape.TensorShape(
      [row_partition.static_nrows, row_partition.static_uniform_row_length])
  shape = outer_dims.concatenate(value.shape[1:])
  fields = {
      name: _partition_outer_dimension(field, row_partition)
      for name, field in value._fields.items()  # pylint: disable=protected-access
  }
  return StructuredTensor._old_init(  # pylint: disable=protected-access
      fields, shape, row_partition.nrows(),
      (row_partition,) + value.row_partitions)
1558def _merge_dims(value, outer_axis, inner_axis):
1559 """Merges `outer_axis...inner_axis` of `value` into a single dimension."""
1560 assert outer_axis < inner_axis
1561 if isinstance(value, (ops.Tensor, ragged_tensor.RaggedTensor)):
1562 return ragged_tensor.merge_dims(value, outer_axis, inner_axis)
1563 else:
1564 assert isinstance(value, StructuredTensor)
1565 fields = dict((k, _merge_dims(v, outer_axis, inner_axis))
1566 for (k, v) in value._fields.items())
1567 ragged_shape = value._ragged_shape._merge_dims( # pylint: disable=protected-access
1568 outer_axis, inner_axis)
1569 return StructuredTensor(fields, ragged_shape)
# A sentinel that compares equal only to itself.
# NOTE(review): presumably handed to a private constructor so that callers
# outside this module cannot invoke it directly -- no use is visible in this
# part of the file; confirm against the rest of the module.
_structured_tensor_factory_key = object()  # unique private object
def _dynamic_ragged_shape_spec_from_spec(
    spec: Union[dynamic_ragged_shape.DynamicRaggedShape.Spec,
                ragged_tensor.RaggedTensorSpec, StructuredTensor.Spec,
                tensor_spec.TensorSpec]
) -> dynamic_ragged_shape.DynamicRaggedShape.Spec:
  """Returns the `DynamicRaggedShape.Spec` describing the shape of `spec`.

  Args:
    spec: The spec of a field or of a shape.

  Returns:
    The stored ragged shape for a `StructuredTensor.Spec`; otherwise the
    result of `DynamicRaggedShape.Spec._from_spec(spec)`.
  """
  # pylint: disable=protected-access
  if not isinstance(spec, StructuredTensor.Spec):
    return dynamic_ragged_shape.DynamicRaggedShape.Spec._from_spec(spec)
  return spec._ragged_shape
1586def _normalize_field_name_to_tuple(name: 'FieldName') -> Sequence[str]:
1587 """FieldName can be given also as string, this normalizes it to a tuple."""
1588 if isinstance(name, str):
1589 return (name,)
1590 if isinstance(name, list):
1591 return tuple(name)
1592 assert isinstance(name, tuple)
1593 return name
1596def _dicts_to_zeros(pyval):
1597 """Replaces dictionaries zeros in a pylist."""
1598 if isinstance(pyval, dict):
1599 return 0
1600 return [_dicts_to_zeros(x) for x in pyval]
def _merge_dims_generic(source, outer, inner):
  """Merges outer_axis...inner_axis into a single dimension.

  If outer == inner, this is a NOOP. If inner < outer, then this fails.
  If inner >= source.shape.rank, then the behavior is undefined.

  Args:
    source: a tensor, ragged tensor, or structured tensor.
    outer: a python int, indicating the first dimension to compress (must be
      nonnegative).
    inner: a python int, indicating the first dimension to keep (of the tail)
      (must be nonnegative).

  Returns:
    source with outer_axis...inner_axis merged into a single dimension.
  """
  if not isinstance(source, StructuredTensor):
    return ragged_tensor.merge_dims(source, outer, inner)
  return source.merge_dims(outer, inner)
def _dynamic_ragged_shape_from_tensor(
    field, dtype=None) -> dynamic_ragged_shape.DynamicRaggedShape:
  """Extension of DynamicRaggedShape.from_tensor to support StructuredTensor.

  Args:
    field: A `StructuredTensor`, or a value accepted by `tf.shape`.
    dtype: The `out_type` passed to `tf.shape`.

  Returns:
    The `DynamicRaggedShape` of `field`.

  Raises:
    TypeError: If `tf.shape(field)` returns neither a `Tensor` nor a
      `DynamicRaggedShape`.
  """
  if isinstance(field, StructuredTensor):
    return field._ragged_shape  # pylint: disable=protected-access

  shape = array_ops.shape_v2(field, out_type=dtype)
  if isinstance(shape, dynamic_ragged_shape.DynamicRaggedShape):
    return shape
  if isinstance(shape, ops.Tensor):
    # A dense shape vector: wrap it as a shape with no row partitions.
    return dynamic_ragged_shape.DynamicRaggedShape(
        row_partitions=[], inner_shape=shape)

  # TODO(martinz): add a test for the following line.
  raise TypeError(f'Expected shape tf.shape({field}) to return a Tensor or a '
                  f'DynamicRaggedShape. Instead, got: {shape}.')
def _merge_with_optional(
    a: Optional[dynamic_ragged_shape.DynamicRaggedShape],
    b: Optional[dynamic_ragged_shape.DynamicRaggedShape]
) -> Optional[dynamic_ragged_shape.DynamicRaggedShape]:
  """Merges two optional shapes, treating `None` as "no information".

  Args:
    a: a DynamicRaggedShape or None.
    b: a DynamicRaggedShape or None.

  Returns:
    `b` if `a` is None, `a` if `b` is None, and `a._merge_with(b)` when both
    are present (so the result is None only when both inputs are None).
  """
  if a is not None and b is not None:
    return a._merge_with(b)  # pylint: disable=protected-access
  return b if a is None else a
def _shape_from_fields(
    fields, rank: int,
    dtype: dtypes.DType) -> Optional[dynamic_ragged_shape.DynamicRaggedShape]:
  """Given fields, rank, and dtype, create a shape.

  Args:
    fields: a mapping from field name to field value.
    rank: each field's shape is sliced with `[:rank]` before merging.
    dtype: passed to `_dynamic_ragged_shape_from_tensor` for every field.

  Returns:
    The merge of all sliced field shapes, or None if `fields` is empty.

  Raises:
    ValueError: if computing, slicing, or merging the shape of any field
      raises; the original exception is chained as the cause.
  """
  merged = None
  for name, value in fields.items():
    try:
      truncated = _dynamic_ragged_shape_from_tensor(value, dtype=dtype)[:rank]
      merged = _merge_with_optional(merged, truncated)
    except Exception as err:
      # Re-raise with the offending field name so the failure is locatable.
      raise ValueError(f'Error in shape of {name}') from err
  return merged
def _field_shape_dtype(field: _FieldValue) -> Optional[dtypes.DType]:
  """Returns the dtype stored in the shape information of `field`, if any.

  Args:
    field: a field value.

  Returns:
    The dtype of the row partition for a RaggedTensor, the dtype of the ragged
    shape for a StructuredTensor, and None for anything else.
  """
  shape_dtype = None
  if isinstance(field, ragged_tensor.RaggedTensor):
    shape_dtype = field._row_partition.dtype  # pylint: disable=protected-access
  elif isinstance(field, StructuredTensor):
    shape_dtype = field._ragged_shape.dtype  # pylint: disable=protected-access
  return shape_dtype
def _field_with_shape_dtype(field: _FieldValue,
                            dtype: dtypes.DType) -> _FieldValue:
  """Returns `field` with its shape information converted to `dtype`.

  Args:
    field: a field value.
    dtype: the dtype to hand to the field's conversion method.

  Returns:
    `field.with_row_splits_dtype(dtype)` for a RaggedTensor,
    `field.with_shape_dtype(dtype)` for a StructuredTensor, and `field`
    itself, unchanged, for anything else.
  """
  converted = field
  if isinstance(field, ragged_tensor.RaggedTensor):
    converted = field.with_row_splits_dtype(dtype)
  elif isinstance(field, StructuredTensor):
    converted = field.with_shape_dtype(dtype)
  return converted
def _fields_with_dtype(fields: Mapping[str, _FieldValue],
                       dtype: dtypes.DType) -> Mapping[str, _FieldValue]:
  """Applies `_field_with_shape_dtype` to every value of `fields`.

  Args:
    fields: a mapping from field name to field value.
    dtype: the dtype passed along for each field.

  Returns:
    A new dict with the same keys and the converted values.
  """
  return {
      name: _field_with_shape_dtype(value, dtype)
      for name, value in fields.items()
  }
1695# pylint:disable=protected-access
1696def _dynamic_ragged_shape_init(fields, shape, nrows, row_partitions):
1697 """Produce a DynamicRaggedShape for StructuredTensor."""
1698 assert isinstance(fields, dict), fields
1699 assert isinstance(shape, tensor_shape.TensorShape), shape
1700 assert nrows is None or isinstance(nrows, ops.Tensor) or isinstance(
1701 nrows, int), nrows
1702 assert row_partitions is None or isinstance(row_partitions,
1703 tuple), row_partitions
1704 rank = shape.rank
1706 if rank is None:
1707 raise TypeError("StructuredTensor's shape must have known rank.")
1709 # TODO(martinz): figure out whether to validate.
1710 dtype = _find_shape_dtype(fields, nrows, row_partitions)
1712 fields = _fields_with_dtype(fields, dtype)
1714 result = None
1715 if shape.is_fully_defined():
1716 result = dynamic_ragged_shape.DynamicRaggedShape._from_inner_shape(
1717 shape.as_list(), dtype=dtype)
1719 if rank == 0:
1720 return dynamic_ragged_shape.DynamicRaggedShape._from_inner_shape(
1721 array_ops.zeros((0,), dtype=dtype))
1723 result = _merge_with_optional(result, _shape_from_fields(fields, rank, dtype))
1724 if rank == 1:
1725 alt_value = tensor_shape.dimension_value(shape[0])
1726 if alt_value is not None:
1727 nrows = alt_value
1728 if nrows is not None:
1729 result = _merge_with_optional(
1730 result,
1731 dynamic_ragged_shape.DynamicRaggedShape._from_inner_shape(
1732 [nrows], dtype=dtype))
1733 if result is None:
1734 raise ValueError('Must specify `nrows`, a fully specified `shape`,' +
1735 ' or have `fields` if `rank=1`')
1737 return result
1739 if row_partitions:
1740 result = _merge_with_optional(
1741 result,
1742 dynamic_ragged_shape.DynamicRaggedShape.from_row_partitions(
1743 row_partitions, dtype=dtype))
1745 if result is None:
1746 raise ValueError('Must specify row_partitions, a fully specified shape, ' +
1747 'or have fields if rank > 1')
1748 return result
1751# TODO(martinz): Drop this method or rename.
def StructuredTensorSpec(shape, field_specs):  # pylint:disable=invalid-name
  """A placeholder for the old StructuredTensorSpec.

  Args:
    shape: a value accepted by `tensor_shape.as_shape`; its rank must be known.
    field_specs: a dict mapping string field names to `type_spec.TypeSpec`
      values.

  Returns:
    A `StructuredTensor.Spec` whose ragged shape is `shape` merged with the
    shape of every field spec truncated to the rank of `shape`.

  Raises:
    TypeError: if `field_specs` is not a dict with string keys and TypeSpec
      values, or if the rank of `shape` is unknown.
    ValueError: if a field spec cannot be converted to a shape spec, or if its
      known rank is smaller than the rank of `shape`.
  """
  if not isinstance(field_specs, dict):
    raise TypeError('field_specs must be a dictionary.')
  if not all(isinstance(name, str) for name in field_specs):
    raise TypeError('field_specs must be a dictionary with string keys.')
  if not all(
      isinstance(spec, type_spec.TypeSpec) for spec in field_specs.values()):
    raise TypeError('field_specs must be a dictionary with TypeSpec values.')

  ragged_spec = dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape(
      tensor_shape.as_shape(shape), 0, dtypes.int32)
  rank = ragged_spec.rank
  if rank is None:
    raise TypeError("StructuredTensor's shape must have known rank.")

  for name, spec in field_specs.items():
    full_spec = _dynamic_ragged_shape_spec_from_spec(spec)
    if full_spec is None:
      raise ValueError(f'Cannot convert spec of {name}.')
    full_rank = full_spec.rank
    if full_rank is not None and full_rank < rank:
      raise ValueError(f'Rank of field {name} is {full_rank},'
                       f' but must be at least {rank}.')
    # Only the leading `rank` dimensions of a field take part in the merge.
    ragged_spec = ragged_spec._merge_with(full_spec._truncate(rank))

  return StructuredTensor.Spec(_ragged_shape=ragged_spec, _fields=field_specs)