Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/util/nest_util.py: 19%
388 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
16"""Utility methods for handling nests.
18This module encapsulates different semantics of handling nests by the public
19tf.nest APIs and internal tf.data APIs. The difference in semantics exists for
20historic reasons and reconciliation would require a non-backwards compatible
21change.
23The implementation of the different semantics uses a common utility to
24avoid / minimize further divergence between the two APIs over time.
25"""
27import collections as _collections
28import enum
30import six as _six
31import wrapt as _wrapt
33from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
34from tensorflow.python.platform import tf_logging
35from tensorflow.python.util import _pywrap_utils
36from tensorflow.python.util.compat import collections_abc as _collections_abc
# Module-level aliases for the helper functions exposed by `_pywrap_utils`.
# They are bound once at import time so the rest of this module can refer to
# them by a short local name.
_is_mapping_view = _pywrap_utils.IsMappingView
_is_attrs = _pywrap_utils.IsAttrs
_is_composite_tensor = _pywrap_utils.IsCompositeTensor
_is_type_spec = _pywrap_utils.IsTypeSpec
_is_mutable_mapping = _pywrap_utils.IsMutableMapping
_is_mapping = _pywrap_utils.IsMapping
# "Is nested" / "flatten" variants used for Modality.DATA.
_tf_data_is_nested = _pywrap_utils.IsNestedForData
_tf_data_flatten = _pywrap_utils.FlattenForData
# "Is nested" variants used for Modality.CORE.
_tf_core_is_nested = _pywrap_utils.IsNested
_is_nested_or_composite = _pywrap_utils.IsNestedOrComposite
# See the swig file (util.i) for documentation.
same_namedtuples = _pywrap_utils.SameNamedtuples
# Error-message templates, to be filled in with `str.format`.

# Placeholders: {input_type}, {shallow_type}.
STRUCTURES_HAVE_MISMATCHING_TYPES = (
    "The two structures don't have the same sequence type. "
    "Input structure has type {input_type}, "
    "while shallow structure has type {shallow_type}."
)

# Placeholders: {input_length}, {shallow_length}.
STRUCTURES_HAVE_MISMATCHING_LENGTHS = (
    "The two structures don't have the same sequence length. "
    "Input structure has length {input_length}, "
    "while shallow structure has length {shallow_length}."
)

# Placeholders: {input_size}, {shallow_size}.
INPUT_TREE_SMALLER_THAN_SHALLOW_TREE = (
    "The input_tree has fewer items than the shallow_tree. "
    "Input structure has length {input_size}, "
    "while shallow structure has length {shallow_size}."
)

# Single positional placeholder: the offending keys.
SHALLOW_TREE_HAS_INVALID_KEYS = (
    "The shallow_tree's keys are not a subset of the input_tree's keys. "
    "The shallow_tree has the following keys that are not in the input_tree: {}."
)
class Modality(enum.Enum):
  """Selects which nesting semantics the functions in this module apply.

  - Modality.CORE: the tensorflow_core / `tf.nest` semantics.

    `tf.nest` treats the following collection types as nested structures:

    * `collections.abc.Sequence` (except `string` and `bytes`).
      This includes `list`, `tuple`, and `namedtuple`.
    * `collections.abc.Mapping` (with sortable keys).
      This includes `dict` and `collections.OrderedDict`.
    * `collections.abc.MappingView` (with sortable keys).
    * [`attr.s` classes](https://www.attrs.org/).

    Everything else is an **atom**. Being a collection does not by itself
    make a value a nested structure; for example these are all atoms:

    * `set`; `{"a", "b"}` is an atom, while `["a", "b"]` is a nested structure.
    * [`dataclass` classes](https://docs.python.org/library/dataclasses.html)
    * `tf.Tensor`
    * `numpy.array`

  - Modality.DATA: the nest semantics used by tf.data.

    Compared with CORE, this modality makes two changes:

    1. Lists are no longer a level of nesting in nested structures.
    2. `SparseTensorValue` is supported as an atomic element.

    Motivation:

    1. It seems more natural for lists to be treated (e.g. in Dataset
       constructors) as tensors, rather than lists of (lists of...) tensors.
    2. `SparseTensorValue` is implemented as a `namedtuple`, so it would
       normally be flattened; keeping it atomic makes it possible to create
       sparse tensors from `SparseTensorValue`s similarly to creating tensors
       from numpy arrays.
  """

  CORE = "CORE"
  DATA = "DATA"
class _DotString:
  """Placeholder object whose `str()` and `repr()` are both a single dot."""

  __slots__ = []

  def __str__(self):
    return "."

  def __repr__(self):
    # Same text as str(), so the dot also shows up when a container holding
    # this object is printed.
    return self.__str__()


# Shared placeholder instance.
_DOT = _DotString()
def is_nested(modality, structure):
  """Tells whether `structure` counts as a nested structure.

  For Modality.CORE the definition of a nested structure is the one given in
  [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest).

  Args:
    modality: enum value of supported modality [Modality.CORE or Modality.DATA]
    structure: the value to test.

  Returns:
    True if the input is a nested structure.

  Raises:
    ValueError: if `modality` is neither Modality.CORE nor Modality.DATA.
  """
  if modality == Modality.CORE:
    return _tf_core_is_nested(structure)
  if modality == Modality.DATA:
    return _tf_data_is_nested(structure)
  raise ValueError(
      "Unknown modality used {} for nested structure".format(modality)
  )
# TODO(b/225045380): Move to a "leaf" library to use in trace_type.
def is_namedtuple(instance, strict=False):
  """Returns True iff `instance` is a `namedtuple`.

  This is a thin wrapper: both arguments are forwarded unchanged to
  `_pywrap_utils.IsNamedtuple`, which performs the actual check.

  Args:
    instance: An instance of a Python object.
    strict: If True, `instance` is considered to be a `namedtuple` only if it is
      a "plain" namedtuple. For instance, a class inheriting from a `namedtuple`
      will be considered to be a `namedtuple` iff `strict=False`.

  Returns:
    True if `instance` is a `namedtuple`.
  """
  return _pywrap_utils.IsNamedtuple(instance, strict)
def sequence_like(instance, args):
  """Rebuilds the sequence `args` as a container of the same kind as `instance`.

  Args:
    instance: an instance of `tuple`, `list`, `namedtuple`, `dict`,
      `collections.OrderedDict`, or `composite_tensor.Composite_Tensor` or
      `type_spec.TypeSpec`.
    args: items to be converted to the `instance` type.

  Returns:
    `args` with the type of `instance`.

  Raises:
    TypeError: if `instance` is a (non-mutable) mapping whose type cannot be
      constructed from an iterable of key-value pairs.
  """
  if _is_mutable_mapping(instance):
    # `args` is matched against the *sorted* keys, so packing is deterministic
    # and the original order of `OrderedDict` instances is deliberately
    # ignored. This avoids potential bugs caused by mixing ordered and plain
    # dicts (e.g., flattening a dict but using a corresponding `OrderedDict`
    # to pack it back).
    values_by_key = dict(zip(_tf_core_sorted(instance), args))
    mapping_cls = type(instance)
    if mapping_cls == _collections.defaultdict:
      rebuilt = _collections.defaultdict(instance.default_factory)
    else:
      rebuilt = mapping_cls()
    # Insert in the iteration order of `instance` itself.
    for key in instance:
      rebuilt[key] = values_by_key[key]
    return rebuilt

  if _is_mapping(instance):
    values_by_key = dict(zip(_tf_core_sorted(instance), args))
    mapping_cls = type(instance)
    if not getattr(mapping_cls, "__supported_by_tf_nest__", False):
      tf_logging.log_first_n(
          tf_logging.WARN,
          "Mapping types may not work well with tf.nest. "
          "Prefer using MutableMapping for {}".format(mapping_cls),
          1,
      )
    try:
      return mapping_cls((key, values_by_key[key]) for key in instance)
    except TypeError as err:
      # pylint: disable=raise-missing-from
      raise TypeError(
          "Error creating an object of type {} like {}. Note that "
          "it must accept a single positional argument "
          "representing an iterable of key-value pairs, in "
          "addition to self. Cause: {}".format(type(instance), instance, err)
      )

  if _is_mapping_view(instance):
    # We can't directly construct mapping views, so we create a list instead
    return list(args)

  if is_namedtuple(instance) or _is_attrs(instance):
    # Look through an object proxy to find the class to instantiate.
    if isinstance(instance, _wrapt.ObjectProxy):
      target_cls = type(instance.__wrapped__)
    else:
      target_cls = type(instance)
    return target_cls(*args)

  if _is_composite_tensor(instance):
    assert len(args) == 1
    # pylint: disable=protected-access
    return instance._type_spec._from_components(args[0])
    # pylint: enable=protected-access

  if _is_type_spec(instance):
    # Pack a CompositeTensor's components according to a TypeSpec.
    assert len(args) == 1
    return instance._from_components(args[0])  # pylint: disable=protected-access

  if isinstance(instance, _six.moves.range):
    # Treat a range like the list of its elements.
    return sequence_like(list(instance), args)

  if isinstance(instance, _wrapt.ObjectProxy):
    # For object proxies, first create the underlying type and then re-wrap it
    # in the proxy type.
    return type(instance)(sequence_like(instance.__wrapped__, args))

  # Not a namedtuple
  return type(instance)(args)
def _get_attrs_items(obj):
  """Collects (name, value) pairs for the attrs fields of `obj`.

  The field names are read from `__attrs_attrs__` on the class of `obj`, and
  the pairs are produced in the order that attribute lists them.

  TODO(b/268078256): an earlier version of this docstring said the list is
  sorted by name; the code does not sort. Check which is intended.

  Args:
    obj: an object whose class defines `__attrs_attrs__`.

  Returns:
    A list of (attr_name, attr_value) pairs.
  """
  pairs = []
  for field in getattr(obj.__class__, "__attrs_attrs__"):
    name = field.name
    pairs.append((name, getattr(obj, name)))
  return pairs
def _tf_core_sorted(dict_):
  """Sorts the keys of `dict_` into a new list.

  Args:
    dict_: an object with a `keys()` method.

  Returns:
    A list of the keys in ascending order.

  Raises:
    TypeError: if the keys cannot be compared with each other.
  """
  try:
    ordered_keys = list(dict_.keys())
    ordered_keys.sort()
  except TypeError:
    # pylint: disable=raise-missing-from
    raise TypeError("nest only supports dicts with sortable keys.")
  return ordered_keys
def _tf_data_sorted(dict_):
  """Returns a sorted list of the dict keys, with error if keys not sortable.

  Args:
    dict_: a dict (or any iterable that yields its keys when iterated).

  Returns:
    A new list holding the keys in ascending order.

  Raises:
    TypeError: if the keys cannot be compared with each other. The message
      includes the text of the underlying comparison error.
  """
  try:
    return sorted(dict_)
  except TypeError as e:
    # Python 3 exceptions have no `.message` attribute; formatting the
    # exception itself yields its text. Reading `e.message` here would raise
    # AttributeError and hide the intended TypeError from callers.
    raise TypeError(
        f"nest only supports dicts with sortable keys. Error: {e}"
    ) from e
def yield_value(modality, iterable):
  """Generates the elements of `iterable` in a deterministic order.

  Because this is a generator, an unsupported `modality` is only reported
  once iteration starts.

  Args:
    modality: enum value of supported modality [Modality.CORE or Modality.DATA]
    iterable: an iterable.

  Yields:
    The iterable elements in a deterministic order.

  Raises:
    ValueError: if `modality` is neither Modality.CORE nor Modality.DATA.
  """
  if modality == Modality.CORE:
    walk = _tf_core_yield_value
  elif modality == Modality.DATA:
    walk = _tf_data_yield_value
  else:
    raise ValueError(
        "Unknown modality used {} for nested structure".format(modality)
    )
  yield from walk(iterable)
def _tf_core_yield_value(iterable):
  """Yields only the values of the pairs from `_tf_core_yield_sorted_items`."""
  for _key, value in _tf_core_yield_sorted_items(iterable):
    yield value
def yield_sorted_items(modality, iterable):
  """Returns a generator of (key, value) pairs for `iterable`.

  Only Modality.CORE is handled here; any other modality (including
  Modality.DATA) is rejected.

  Args:
    modality: enum value of the modality; must be Modality.CORE.
    iterable: an iterable.

  Returns:
    The generator produced by `_tf_core_yield_sorted_items(iterable)`.

  Raises:
    ValueError: if `modality` is not Modality.CORE.
  """
  if modality == Modality.CORE:
    return _tf_core_yield_sorted_items(iterable)
  raise ValueError(
      "Unknown modality used {} for nested structure".format(modality)
  )
def _tf_core_yield_sorted_items(iterable):
  """Generates (key, value) pairs for `iterable` in a deterministic order.

  The key depends on the kind of `iterable`:

  * lists, plain tuples and any otherwise unrecognized iterable: the integer
    position of the value.
  * Mappings: the dictionary key, visited in sorted order.
  * attrs instances and namedtuples: the attribute / field name.
  * composite tensors and type specs: the `__name__` of the value type, paired
    with the components (or component specs) as a single value.

  Args:
    iterable: an iterable.

  Yields:
    The iterable's (key, value) pairs.
  """
  # Ordered to check common structure types (list, tuple, dict) first. Plain
  # tuples are matched by exact type so that the more expensive namedtuple
  # check further down is skipped for them.
  if isinstance(iterable, list) or type(iterable) == tuple:  # pylint: disable=unidiomatic-typecheck
    yield from enumerate(iterable)
  elif isinstance(iterable, (dict, _collections_abc.Mapping)):
    # Iterate through dictionaries in a deterministic order by sorting the
    # keys. Notice this means that we ignore the original order of `OrderedDict`
    # instances. This is intentional, to avoid potential bugs caused by mixing
    # ordered and plain dicts (e.g., flattening a dict but using a
    # corresponding `OrderedDict` to pack it back).
    for sorted_key in _tf_core_sorted(iterable):
      yield sorted_key, iterable[sorted_key]
  elif _is_attrs(iterable):
    yield from _get_attrs_items(iterable)
  elif is_namedtuple(iterable):
    for field_name in iterable._fields:
      yield field_name, getattr(iterable, field_name)
  elif _is_composite_tensor(iterable):
    # pylint: disable=protected-access
    spec = iterable._type_spec
    yield spec.value_type.__name__, spec._to_components(iterable)
    # pylint: enable=protected-access
  elif _is_type_spec(iterable):
    # Note: to allow CompositeTensors and their TypeSpecs to have matching
    # structures, we need to use the same key string here.
    yield iterable.value_type.__name__, iterable._component_specs  # pylint: disable=protected-access
  else:
    yield from enumerate(iterable)
def _tf_data_yield_value(iterable):
  """Generates the elements of `iterable` in a deterministic order.

  Mappings yield their values in sorted-key order, an object whose class is
  named "SparseTensorValue" is yielded whole as a single element, attrs
  instances yield their attribute values, and anything else is iterated
  directly.

  Args:
    iterable: an iterable.

  Yields:
    The iterable elements in a deterministic order.
  """
  # pylint: disable=protected-access
  if isinstance(iterable, _collections_abc.Mapping):
    # Iterate through dictionaries in a deterministic order by sorting the
    # keys. Notice this means that we ignore the original order of `OrderedDict`
    # instances. This is intentional, to avoid potential bugs caused by mixing
    # ordered and plain dicts (e.g., flattening a dict but using a
    # corresponding `OrderedDict` to pack it back).
    for sorted_key in _tf_data_sorted(iterable):
      yield iterable[sorted_key]
  # To avoid circular imports. sparse_tensor
  # depends on tensorflow/python/util/nest.py transitively, and if we try to
  # import sparse_tensor again, it results in a circular import. Instead, here
  # we check the class name instead of using `isinstance`.
  elif iterable.__class__.__name__ == "SparseTensorValue":
    yield iterable
  elif _is_attrs(iterable):
    for _name, attr_value in _get_attrs_items(iterable):
      yield attr_value
  else:
    for element in iterable:
      yield element
def assert_same_structure(
    modality, nest1, nest2, check_types=True, expand_composites=False
):
  """Checks that `nest1` and `nest2` are nested in the same way.

  For Modality.CORE the definition of a structure is the one given in
  [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest). Note that the
  types of the atoms inside the structures are never checked.

  Examples:

  * These atom vs. atom comparisons will pass:

    >>> tf.nest.assert_same_structure(1.5, tf.Variable(1, tf.uint32))
    >>> tf.nest.assert_same_structure("abc", np.array([1, 2]))

  * These nested structure vs. nested structure comparisons will pass:

    >>> structure1 = (((1, 2), 3), 4, (5, 6))
    >>> structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
    >>> structure3 = [(("a", "b"), "c"), "d", ["e", "f"]]
    >>> tf.nest.assert_same_structure(structure1, structure2)
    >>> tf.nest.assert_same_structure(structure1, structure3, check_types=False)

    >>> import collections
    >>> tf.nest.assert_same_structure(
    ...     collections.namedtuple("bar", "a b")(1, 2),
    ...     collections.namedtuple("foo", "a b")(2, 3),
    ...     check_types=False)

    >>> tf.nest.assert_same_structure(
    ...     collections.namedtuple("bar", "a b")(1, 2),
    ...     { "a": 1, "b": 2 },
    ...     check_types=False)

    >>> tf.nest.assert_same_structure(
    ...     { "a": 1, "b": 2, "c": 3 },
    ...     { "c": 6, "b": 5, "a": 4 })

    >>> ragged_tensor1 = tf.RaggedTensor.from_row_splits(
    ...       values=[3, 1, 4, 1, 5, 9, 2, 6],
    ...       row_splits=[0, 4, 4, 7, 8, 8])
    >>> ragged_tensor2 = tf.RaggedTensor.from_row_splits(
    ...       values=[3, 1, 4],
    ...       row_splits=[0, 3])
    >>> tf.nest.assert_same_structure(
    ...       ragged_tensor1,
    ...       ragged_tensor2,
    ...       expand_composites=True)

  * These examples will raise exceptions:

    >>> tf.nest.assert_same_structure([0, 1], np.array([0, 1]))
    Traceback (most recent call last):
    ...
    ValueError: The two structures don't have the same nested structure

    >>> tf.nest.assert_same_structure(
    ...       collections.namedtuple('bar', 'a b')(1, 2),
    ...       collections.namedtuple('foo', 'a b')(2, 3))
    Traceback (most recent call last):
    ...
    TypeError: The two structures don't have the same nested structure

  Modality.DATA treats nested structures differently from Modality.CORE; the
  docstring of class `Modality` describes the differences.

  Args:
    modality: enum value of supported modality [Modality.CORE or Modality.DATA]
    nest1: an atom or a nested structure.
    nest2: an atom or a nested structure.
    check_types:
      - For Modality.CORE: if `True` (default) types of structures are checked
        as well, including the keys of dictionaries. If set to `False`, for
        example a list and a tuple of objects will look the same if they have
        the same size. Note that namedtuples with identical name and fields
        are always considered to have the same shallow structure. Two types
        will also be considered the same if they are both list subtypes (which
        allows "list" and "_ListWrapper" from trackable dependency tracking to
        compare equal). `check_types=True` only checks type of sub-structures.
        The types of atoms are not checked.
      - For Modality.DATA: if `True` (default) types of sequences should be
        same as well. For dictionary, "type" of dictionary is considered to
        include its keys. In other words, two dictionaries with different keys
        are considered to have a different "type". If set to `False`, two
        iterables are considered same as long as they yield the elements that
        have same structures.
    expand_composites: Arg only valid for Modality.CORE. If true, then composite
      tensors such as `tf.sparse.SparseTensor` and `tf.RaggedTensor` are
      expanded into their component tensors.

  Raises:
    ValueError: If the two structures do not have the same number of atoms or
      if the two structures are not nested in the same way. Also raised when
      `modality` is neither Modality.CORE nor Modality.DATA.
    TypeError: If the two structures differ in the type of sequence in any of
      their substructures. Only possible if `check_types` is `True`.
  """
  if modality == Modality.CORE:
    _tf_core_assert_same_structure(nest1, nest2, check_types, expand_composites)
    return
  if modality == Modality.DATA:
    # `expand_composites` is not forwarded for this modality.
    _tf_data_assert_same_structure(nest1, nest2, check_types)
    return
  raise ValueError(
      "Unknown modality used {} for nested structure".format(modality)
  )
# pylint: disable=missing-function-docstring
def _tf_core_assert_same_structure(
    nest1, nest2, check_types=True, expand_composites=False
):
  # Delegates the comparison to `_pywrap_utils.AssertSameStructure`. When that
  # raises a ValueError or TypeError, an exception of the same type is raised
  # with both structures appended to the message, each rendered through
  # `_tf_core_map_structure` with every mapped element replaced by `_DOT`.
  #
  # Convert to bool explicitly as otherwise pybind will not be able to handle
  # type mismatch message correctly. See GitHub issue 42329 for details.
  strict_types = bool(check_types)
  expand = bool(expand_composites)
  try:
    _pywrap_utils.AssertSameStructure(nest1, nest2, strict_types, expand)
  except (ValueError, TypeError) as e:
    first_outline = str(_tf_core_map_structure(lambda _: _DOT, nest1))
    second_outline = str(_tf_core_map_structure(lambda _: _DOT, nest2))
    raise type(e)(
        "%s\nEntire first structure:\n%s\nEntire second structure:\n%s"
        % (str(e), first_outline, second_outline)
    )
def _tf_data_assert_same_structure(nest1, nest2, check_types=True):
  # Thin wrapper: forwards all three arguments, positionally, to
  # `_pywrap_utils.AssertSameStructureForData`. Nothing is returned.
  _pywrap_utils.AssertSameStructureForData(nest1, nest2, check_types)
def _tf_core_packed_nest_with_indices(
    structure, flat, index, is_nested_fn, sequence_fn=None
):
  """Helper function for pack_sequence_as.

  Walks the immediate children of `structure`. Atoms are taken from `flat`
  one at a time; nested children are packed recursively and rebuilt with
  `sequence_fn`.

  Args:
    structure: structure to mimic.
    flat: Flattened values to output substructure for.
    index: Index at which to start reading from flat.
    is_nested_fn: Function used to test if a value should be treated as a nested
      structure.
    sequence_fn: Function used to generate a new structure instance. Defaults
      to `sequence_like` when falsy.

  Returns:
    The tuple (new_index, packed), where:
      * new_index - the updated index into `flat` having processed `structure`.
      * packed - the subset of `flat` corresponding to `structure`,
                 having started at `index`, and packed into the same nested
                 format. This is a plain list holding one entry per child of
                 `structure`.

  Raises:
    ValueError: if `structure` contains more atoms than `flat`
      (assuming indexing starts from `index`).
  """
  sequence_fn = sequence_fn or sequence_like
  packed = []
  for element in _tf_core_yield_value(structure):
    if not is_nested_fn(element):
      # Atom: consume exactly one flat value.
      packed.append(flat[index])
      index += 1
      continue
    next_index, children = _tf_core_packed_nest_with_indices(
        element, flat, index, is_nested_fn, sequence_fn
    )
    packed.append(sequence_fn(element, children))
    index = next_index
  return index, packed
def _tf_data_packed_nest_with_indices(structure, flat, index):
  """Helper function for pack_nest_as.

  Walks the immediate children of `structure`. Atoms are taken from `flat`
  one at a time; nested children are packed recursively and rebuilt with
  `sequence_like`.

  Args:
    structure: Substructure (tuple of elements and/or tuples) to mimic
    flat: Flattened values to output substructure for.
    index: Index at which to start reading from flat.

  Returns:
    The tuple (new_index, packed), where:
      * new_index - the updated index into `flat` having processed `structure`.
      * packed - the subset of `flat` corresponding to `structure`,
                 having started at `index`, and packed into the same nested
                 format. This is a plain list holding one entry per child of
                 `structure`.

  Raises:
    ValueError: if `structure` contains more elements than `flat`
      (assuming indexing starts from `index`).
  """
  packed = []
  for element in _tf_data_yield_value(structure):
    if not _tf_data_is_nested(element):
      # Atom: consume exactly one flat value.
      packed.append(flat[index])
      index += 1
      continue
    next_index, children = _tf_data_packed_nest_with_indices(
        element, flat, index
    )
    packed.append(sequence_like(element, children))
    index = next_index
  return index, packed
def flatten(modality, structure, expand_composites=False):
  """Returns the atoms of a nested structure as a flat Python list.

  - For Modality.CORE: refer to
  [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest)
  for the definition of a structure.

  An atom is returned wrapped in a single-item list: [structure].

  `flatten` is the inverse of the `nest.pack_sequence_as` method, which takes
  a flattened list and re-packs it into the nested structure.

  The values of a dict are emitted in sorted-key order so that the result is
  deterministic. OrderedDict instances are treated the same way: their own
  ordering is ignored and the sorting order of keys is used instead.
  `nest.pack_sequence_as` follows the same convention, so dicts and
  OrderedDicts are repacked correctly after being flattened, and an
  OrderedDict may be flattened and then repacked using a corresponding plain
  dict, or vice-versa. Dictionaries with non-sortable keys cannot be
  flattened.

  Users must not modify any collections used in nest while this function is
  running.

  Examples:

  1. Python dict (ordered by key):

    >>> dict = { "key3": "value3", "key1": "value1", "key2": "value2" }
    >>> tf.nest.flatten(dict)
    ['value1', 'value2', 'value3']

  2. For a nested python tuple:

    >>> tuple = ((1.0, 2.0), (3.0, 4.0, 5.0), 6.0)
    >>> tf.nest.flatten(tuple)
    [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

  3. For a nested dictionary of dictionaries:

    >>> dict = { "key3": {"c": (1.0, 2.0), "a": (3.0)},
    ... "key1": {"m": "val1", "g": "val2"} }
    >>> tf.nest.flatten(dict)
    ['val2', 'val1', 3.0, 1.0, 2.0]

  4. Numpy array (will not flatten):

    >>> array = np.array([[1, 2], [3, 4]])
    >>> tf.nest.flatten(array)
    [array([[1, 2],
            [3, 4]])]

  5. `tf.Tensor` (will not flatten):

    >>> tensor = tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
    >>> tf.nest.flatten(tensor)
    [<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
      array([[1., 2., 3.],
             [4., 5., 6.],
             [7., 8., 9.]], dtype=float32)>]

  6. `tf.RaggedTensor`: This is a composite tensor whose representation
  consists of a flattened list of 'values' and a list of 'row_splits' which
  indicate how to chop up the flattened list into different rows. For more
  details on `tf.RaggedTensor`, please visit
  https://www.tensorflow.org/api_docs/python/tf/RaggedTensor.

  with `expand_composites=False`, we just return the RaggedTensor as is.

    >>> tensor = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2]])
    >>> tf.nest.flatten(tensor, expand_composites=False)
    [<tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2]]>]

  with `expand_composites=True`, we return the component Tensors that make up
  the RaggedTensor representation (the values and row_splits tensors)

    >>> tensor = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2]])
    >>> tf.nest.flatten(tensor, expand_composites=True)
    [<tf.Tensor: shape=(7,), dtype=int32, numpy=array([3, 1, 4, 1, 5, 9, 2],
                                                      dtype=int32)>,
     <tf.Tensor: shape=(4,), dtype=int64, numpy=array([0, 4, 4, 7])>]

  Args:
    modality: enum value of supported modality [Modality.CORE or Modality.DATA]
    structure: an atom or a nested structure. Note, numpy arrays are considered
      atoms and are not flattened.
    expand_composites: Arg valid for Modality.CORE only. If true, then composite
      tensors such as `tf.sparse.SparseTensor` and `tf.RaggedTensor` are
      expanded into their component tensors.

  Returns:
    A Python list, the flattened version of the input.

  Raises:
    TypeError: The nest is or contains a dict with non-sortable keys.
    ValueError: `modality` is neither Modality.CORE nor Modality.DATA.
  """
  if modality == Modality.CORE:
    return _tf_core_flatten(structure, expand_composites)
  if modality == Modality.DATA:
    # `expand_composites` is not forwarded for this modality.
    return _tf_data_flatten(structure)
  raise ValueError(
      "Unknown modality used {} for nested structure".format(modality)
  )
def _tf_core_flatten(structure, expand_composites=False):
  """See comments for flatten() in tensorflow/python/util/nest.py.

  `None` is answered directly with `[None]`; every other input is handed to
  `_pywrap_utils.Flatten`, with `expand_composites` coerced to a real bool.
  """
  if structure is None:
    return [None]
  return _pywrap_utils.Flatten(structure, bool(expand_composites))
def pack_sequence_as(
    modality, structure, flat_sequence, expand_composites, sequence_fn=None
):
  """Packs a flat sequence into the shape of a given structure.

  - For Modality.CORE: Refer to
  [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest)
  for the definition of a structure.

  If `structure` is an atom, `flat_sequence` must be a single-item list;
  in this case the return value is `flat_sequence[0]`.

  If `structure` is or contains a dict instance, the flat sequence is matched
  against the keys in sorted order so that packing is deterministic. This
  also holds for `OrderedDict` instances: their own ordering is ignored and
  the sorting order of keys is used instead. `flatten` follows the same
  convention, so dicts and `OrderedDict`s are repacked correctly after being
  flattened, and an `OrderedDict` may be flattened and then repacked using a
  corresponding plain dict, or vice-versa.
  Dictionaries with non-sortable keys cannot be flattened.

  Examples:

  1. Python dict:

    >>> structure = { "key3": "", "key1": "", "key2": "" }
    >>> flat_sequence = ["value1", "value2", "value3"]
    >>> tf.nest.pack_sequence_as(structure, flat_sequence)
    {'key3': 'value3', 'key1': 'value1', 'key2': 'value2'}

  2. For a nested python tuple:

    >>> structure = (('a','b'), ('c','d','e'), 'f')
    >>> flat_sequence = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
    >>> tf.nest.pack_sequence_as(structure, flat_sequence)
    ((1.0, 2.0), (3.0, 4.0, 5.0), 6.0)

  3. For a nested dictionary of dictionaries:

    >>> structure = { "key3": {"c": ('alpha', 'beta'), "a": ('gamma')},
    ...               "key1": {"e": "val1", "d": "val2"} }
    >>> flat_sequence = ['val2', 'val1', 3.0, 1.0, 2.0]
    >>> tf.nest.pack_sequence_as(structure, flat_sequence)
    {'key3': {'c': (1.0, 2.0), 'a': 3.0}, 'key1': {'e': 'val1', 'd': 'val2'}}

  4. Numpy array (considered a scalar):

    >>> structure = ['a']
    >>> flat_sequence = [np.array([[1, 2], [3, 4]])]
    >>> tf.nest.pack_sequence_as(structure, flat_sequence)
    [array([[1, 2],
            [3, 4]])]

  5. tf.Tensor (considered a scalar):

    >>> structure = ['a']
    >>> flat_sequence = [tf.constant([[1., 2., 3.], [4., 5., 6.]])]
    >>> tf.nest.pack_sequence_as(structure, flat_sequence)
    [<tf.Tensor: shape=(2, 3), dtype=float32,
     numpy= array([[1., 2., 3.], [4., 5., 6.]], dtype=float32)>]

  6. `tf.RaggedTensor`: This is a composite tensor whose representation
  consists of a flattened list of 'values' and a list of 'row_splits' which
  indicate how to chop up the flattened list into different rows. For more
  details on `tf.RaggedTensor`, please visit
  https://www.tensorflow.org/api_docs/python/tf/RaggedTensor.

  With `expand_composites=False`, we treat RaggedTensor as a scalar.

    >>> structure = { "foo": tf.ragged.constant([[1, 2], [3]]),
    ...               "bar": tf.constant([[5]]) }
    >>> flat_sequence = [ "one", "two" ]
    >>> tf.nest.pack_sequence_as(structure, flat_sequence,
    ... expand_composites=False)
    {'foo': 'two', 'bar': 'one'}

  With `expand_composites=True`, we expect that the flattened input contains
  the tensors making up the ragged tensor i.e. the values and row_splits
  tensors.

    >>> structure = { "foo": tf.ragged.constant([[1., 2.], [3.]]),
    ...               "bar": tf.constant([[5.]]) }
    >>> tensors = tf.nest.flatten(structure, expand_composites=True)
    >>> print(tensors)
    [<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[5.]],
     dtype=float32)>,
     <tf.Tensor: shape=(3,), dtype=float32, numpy=array([1., 2., 3.],
     dtype=float32)>,
     <tf.Tensor: shape=(3,), dtype=int64, numpy=array([0, 2, 3])>]
    >>> verified_tensors = [tf.debugging.check_numerics(t, 'invalid tensor: ')
    ...                     if t.dtype==tf.float32 else t
    ...                     for t in tensors]
    >>> tf.nest.pack_sequence_as(structure, verified_tensors,
    ...                          expand_composites=True)
    {'foo': <tf.RaggedTensor [[1.0, 2.0], [3.0]]>,
     'bar': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[5.]],
     dtype=float32)>}

  - For Modality.DATA: If `structure` is a scalar, `flat_sequence` must be a
  single-element list;
  in this case the return value is `flat_sequence[0]`.

  Args:
    modality: enum value of supported modality [Modality.CORE or Modality.DATA]
    structure:
      - For Modality.CORE: Nested structure, whose structure is given by
        nested lists, tuples, and dicts. Note: numpy arrays and strings are
        considered scalars.
      - For Modality.DATA: tuple or list constructed of scalars and/or other
        tuples/lists, or a scalar. Note: numpy arrays are considered scalars.
    flat_sequence: flat sequence to pack.
    expand_composites: Arg valid for Modality.CORE only. If true, then composite
      tensors such as `tf.sparse.SparseTensor` and `tf.RaggedTensor` are
      expanded into their component tensors.
    sequence_fn: Arg valid for Modality.CORE only.

  Returns:
    packed: `flat_sequence` converted to have the same recursive structure as
      `structure`.

  Raises:
    ValueError: If `flat_sequence` and `structure` have different
      atom counts, or if `modality` is neither Modality.CORE nor
      Modality.DATA.
    TypeError: For Modality.CORE only. `structure` is or contains a dict with
      non-sortable keys.
  """
  if modality == Modality.CORE:
    return _tf_core_pack_sequence_as(
        structure, flat_sequence, expand_composites, sequence_fn
    )
  if modality == Modality.DATA:
    # `expand_composites` and `sequence_fn` are not forwarded for this modality.
    return _tf_data_pack_sequence_as(structure, flat_sequence)
  raise ValueError(
      "Unknown modality used {} for nested structure".format(modality)
  )
def _tf_core_pack_sequence_as(
    structure, flat_sequence, expand_composites, sequence_fn=None
):
  """Packs `flat_sequence` into the layout of `structure` (Modality.CORE).

  Args:
    structure: Atom or nested structure whose layout the result should have.
    flat_sequence: Sequence of atoms to pack. It must itself be recognized as
      nested by the nesting predicate selected via `expand_composites`.
    expand_composites: If true, `_is_nested_or_composite` is used as the
      nesting predicate; otherwise `_tf_core_is_nested` is used.
    sequence_fn: Optional callable, invoked here as
      `sequence_fn(structure, packed)` and also handed to
      `_tf_core_packed_nest_with_indices`. Falls back to `sequence_like` when
      falsy.

  Returns:
    `flat_sequence[0]` if `structure` is an atom, otherwise
    `sequence_fn(structure, packed)`.

  Raises:
    TypeError: If `flat_sequence` is not nested.
    ValueError: If `structure` is an atom and `flat_sequence` does not hold
      exactly one item, or if `structure` and `flat_sequence` have different
      atom counts.
  """
  if expand_composites:
    is_nested_fn = _is_nested_or_composite
  else:
    is_nested_fn = _tf_core_is_nested
  if not sequence_fn:
    sequence_fn = sequence_like

  def shorten(value, limit):
    # Renders `value` as text, clipped to `limit` characters plus an ellipsis.
    text = str(value)
    return text[:limit] + "..." if len(text) > limit else text

  if not is_nested_fn(flat_sequence):
    raise TypeError(
        "Attempted to pack value:\n {}\ninto a structure, but found "
        "incompatible type `{}` instead.".format(
            shorten(flat_sequence, 100), type(flat_sequence)
        )
    )

  if not is_nested_fn(structure):
    # Atom target: only a one-item sequence can be mapped onto it.
    if len(flat_sequence) == 1:
      return flat_sequence[0]
    raise ValueError(
        "The target structure is of type `{}`\n {}\nHowever the input "
        "is a sequence ({}) of length {}.\n {}\nnest cannot "
        "guarantee that it is safe to map one to the other.".format(
            type(structure),
            shorten(structure, 100),
            type(flat_sequence),
            len(flat_sequence),
            shorten(flat_sequence, 100),
        )
    )

  try:
    consumed, packed = _tf_core_packed_nest_with_indices(
        structure, flat_sequence, 0, is_nested_fn, sequence_fn
    )
    # Leftover items are reported through the same path as running out of
    # items while packing.
    if consumed < len(flat_sequence):
      raise IndexError
  except IndexError:
    atoms = _tf_core_flatten(structure, expand_composites=expand_composites)
    if len(atoms) != len(flat_sequence):
      # pylint: disable=raise-missing-from
      raise ValueError(
          "Could not pack sequence. Structure had %d atoms, but "
          "flat_sequence had %d items. Structure: %s, flat_sequence: %s."
          % (len(atoms), len(flat_sequence), structure, flat_sequence)
      )
    # NOTE(review): if the counts match here, control falls through to the
    # return below; confirm that this path is never expected to be taken.
  return sequence_fn(structure, packed)
def _tf_data_pack_sequence_as(structure, flat_sequence):
  """Packs `flat_sequence` into the layout of `structure` (Modality.DATA).

  When `structure` is a scalar, `flat_sequence` has to contain exactly one
  element, and that element is returned.

  Args:
    structure: tuple or list constructed of scalars and/or other tuples/lists,
      or a scalar. Note: numpy arrays are considered scalars.
    flat_sequence: flat sequence to pack. Must be nested according to
      `_tf_data_is_nested` or be a `list`.

  Returns:
    packed: `flat_sequence` rearranged to have the same recursive structure as
      `structure`.

  Raises:
    TypeError: If `flat_sequence` is neither nested nor a `list`.
    ValueError: If `structure` is a scalar and `flat_sequence` does not have
      exactly one element, or if `structure` and `flat_sequence` have
      different element counts.
  """
  if not _tf_data_is_nested(flat_sequence) and not isinstance(
      flat_sequence, list
  ):
    raise TypeError(
        "Argument `flat_sequence` must be a sequence. Got "
        f"'{type(flat_sequence).__name__}'."
    )

  if not _tf_data_is_nested(structure):
    # Scalar target: hand back the single element.
    if len(flat_sequence) == 1:
      return flat_sequence[0]
    raise ValueError(
        "Argument `structure` is a scalar but "
        f"`len(flat_sequence)`={len(flat_sequence)} > 1"
    )

  # Validate the element count up front so the packing helper below never
  # sees a sequence of the wrong size.
  flat_structure = _tf_data_flatten(structure)
  if len(flat_structure) != len(flat_sequence):
    raise ValueError(
        "Could not pack sequence. Argument `structure` had "
        f"{len(flat_structure)} elements, but argument `flat_sequence` had "
        f"{len(flat_sequence)} elements. Received structure: "
        f"{structure}, flat_sequence: {flat_sequence}."
    )

  _final_index, packed = _tf_data_packed_nest_with_indices(
      structure, flat_sequence, 0
  )
  return sequence_like(structure, packed)  # pylint: disable=protected-access
def map_structure(modality, func, *structure, **kwargs):
  """Builds a new structure by applying `func` to every atom of `structure`.

  The call is routed to `_tf_core_map_structure` for Modality.CORE and to
  `_tf_data_map_structure` for Modality.DATA.

  - For Modality.CORE: Refer to
    [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest)
    for the definition of a structure.

    `func(x[0], x[1], ...)` is applied where x[i] enumerates all atoms in
    `structure[i]`. All items in `structure` must have the same arity, and the
    return value contains the results laid out like the inputs.

    Examples:

    * A single Python dict:

    >>> a = {"hello": 24, "world": 76}
    >>> tf.nest.map_structure(lambda p: p * 2, a)
    {'hello': 48, 'world': 152}

    * Multiple Python dictionaries:

    >>> d1 = {"hello": 24, "world": 76}
    >>> d2 = {"hello": 36, "world": 14}
    >>> tf.nest.map_structure(lambda p1, p2: p1 + p2, d1, d2)
    {'hello': 60, 'world': 90}

    * A single Python list:

    >>> a = [24, 76, "ab"]
    >>> tf.nest.map_structure(lambda p: p * 2, a)
    [48, 152, 'abab']

    * Scalars:

    >>> tf.nest.map_structure(lambda x, y: x + y, 3, 4)
    7

    * Empty structures:

    >>> tf.nest.map_structure(lambda x: x + 1, ())
    ()

    * Check the types of iterables:

    >>> s1 = (((1, 2), 3), 4, (5, 6))
    >>> s1_list = [[[1, 2], 3], 4, [5, 6]]
    >>> tf.nest.map_structure(lambda x, y: None, s1, s1_list)
    Traceback (most recent call last):
    ...
    TypeError: The two structures don't have the same nested structure

    * Type check is set to False:

    >>> s1 = (((1, 2), 3), 4, (5, 6))
    >>> s1_list = [[[1, 2], 3], 4, [5, 6]]
    >>> tf.nest.map_structure(lambda x, y: None, s1, s1_list, check_types=False)
    (((None, None), None), None, (None, None))

  - For Modality.DATA: `func(x[0], x[1], ...)` is applied where x[i] is an
    entry in `structure[i]`. All structures in `structure` must have the same
    arity, and the return value contains the results in the same structure.

  Args:
    modality: enum value of supported modality [Modality.CORE or Modality.DATA]
    func: A callable that accepts as many arguments as there are structures.
    *structure: - For Modality.CORE: atom or nested structure. - For
      Modality.DATA: scalar, or tuple or list of constructed scalars and/or
      other tuples/lists, or scalars. Note: numpy arrays are considered
      scalars.
    **kwargs: Valid keyword args are:
      * `check_types`: Accepted by both modalities. If `True` (default) the
        types of iterables within the structures have to be the same (e.g.
        `map_structure(func, [1], (1,))` raises a `TypeError` exception). Set
        it to `False` to allow differing types. For Modality.CORE, namedtuples
        with identical name and fields are always considered to have the same
        shallow structure. For Modality.DATA this is the only valid keyword
        argument.
      * `expand_composites`: Valid for Modality.CORE only. If `True`,
        composite tensors such as `tf.sparse.SparseTensor` and
        `tf.RaggedTensor` are expanded into their component tensors. If
        `False` (the default), composite tensors are not expanded.

  Returns:
    A new structure with the same arity as `structure[0]`, whose atoms
    correspond to `func(x[0], x[1], ...)` where `x[i]` is the atom in the
    corresponding location in `structure[i]`. If there are different structure
    types and `check_types` is `False` the structure types of the first
    structure will be used.

  Raises:
    TypeError: If `func` is not callable or if the structures do not match
      each other by depth tree.
    ValueError: If no structure is provided or if the structures do not match
      each other by type.
    ValueError: If wrong keyword arguments are provided, or if `modality` is
      neither Modality.CORE nor Modality.DATA.
  """
  if modality == Modality.CORE:
    return _tf_core_map_structure(func, *structure, **kwargs)
  if modality == Modality.DATA:
    return _tf_data_map_structure(func, *structure, **kwargs)
  raise ValueError(
      "Unknown modality used {} for nested structure".format(modality)
  )
1063# pylint: disable=missing-function-docstring
1064def _tf_core_map_structure(func, *structure, **kwargs):
1065 if not callable(func):
1066 raise TypeError("func must be callable, got: %s" % func)
1068 if not structure:
1069 raise ValueError("Must provide at least one structure")
1071 check_types = kwargs.pop("check_types", True)
1072 expand_composites = kwargs.pop("expand_composites", False)
1074 if kwargs:
1075 raise ValueError(
1076 "Only valid keyword arguments are `check_types` and "
1077 "`expand_composites`, not: `%s`"
1078 % "`, `".join(kwargs.keys())
1079 )
1081 for other in structure[1:]:
1082 _tf_core_assert_same_structure(
1083 structure[0],
1084 other,
1085 check_types=check_types,
1086 expand_composites=expand_composites,
1087 )
1089 flat_structure = (_tf_core_flatten(s, expand_composites) for s in structure)
1090 entries = zip(*flat_structure)
1092 return _tf_core_pack_sequence_as(
1093 structure[0],
1094 [func(*x) for x in entries],
1095 expand_composites=expand_composites,
1096 )
1099# pylint: disable=missing-function-docstring
1100def _tf_data_map_structure(func, *structure, **check_types_dict):
1101 if not callable(func):
1102 raise TypeError(f"Argument `func` must be callable, got: {func}")
1104 if not structure:
1105 raise ValueError("Must provide at least one structure")
1107 if check_types_dict:
1108 if "check_types" not in check_types_dict or len(check_types_dict) > 1:
1109 raise ValueError(
1110 "Only valid keyword argument for `check_types_dict` is "
1111 f"'check_types'. Got {check_types_dict}."
1112 )
1113 check_types = check_types_dict["check_types"]
1114 else:
1115 check_types = True
1117 for other in structure[1:]:
1118 _tf_data_assert_same_structure(structure[0], other, check_types=check_types)
1120 flat_structure = (_tf_data_flatten(s) for s in structure)
1121 entries = zip(*flat_structure)
1123 return _tf_data_pack_sequence_as(structure[0], [func(*x) for x in entries])
def yield_flat_up_to(modality, shallow_tree, input_tree, is_nested_fn, path=()):
  """Yields the leaves of `input_tree` flattened up to `shallow_tree`.

  Delegates to `_tf_core_yield_flat_up_to` for Modality.CORE, which yields
  `(path, value)` pairs, and to `_tf_data_yield_flat_up_to` for Modality.DATA,
  which yields the values only.

  Because this is a generator, an unsupported `modality` is only reported once
  iteration starts.

  Args:
    modality: enum value of supported modality [Modality.CORE or Modality.DATA]
    shallow_tree: Nested structure. Traverse no further than its leaf nodes.
    input_tree: Nested structure. Return the paths and values from this tree.
      Must have the same upper structure as shallow_tree.
    is_nested_fn: Arg used for Modality.CORE only. Function used to test if a
      value should be treated as a nested structure.
    path: Arg used for Modality.CORE only. Tuple. Optional argument, only used
      when recursing. The path from the root of the original shallow_tree, down
      to the root of the shallow_tree arg of this recursive call.

  Yields:
    For Modality.CORE, pairs of (path, value), where path is the tuple path of
    a leaf node in shallow_tree, and value is the value of the corresponding
    node in input_tree. For Modality.DATA, the values alone.

  Raises:
    ValueError: If `modality` is neither Modality.CORE nor Modality.DATA.
  """
  if modality == Modality.CORE:
    yield from _tf_core_yield_flat_up_to(
        shallow_tree, input_tree, is_nested_fn, path
    )
    return
  if modality == Modality.DATA:
    yield from _tf_data_yield_flat_up_to(shallow_tree, input_tree)
    return
  raise ValueError(
      "Unknown modality used {} for nested structure".format(modality)
  )
1160def _tf_core_yield_flat_up_to(shallow_tree, input_tree, is_nested_fn, path=()):
1161 """Yields (path, value) pairs of input_tree flattened up to shallow_tree.
1163 Args:
1164 shallow_tree: Nested structure. Traverse no further than its leaf nodes.
1165 input_tree: Nested structure. Return the paths and values from this tree.
1166 Must have the same upper structure as shallow_tree.
1167 is_nested_fn: Function used to test if a value should be treated as a nested
1168 structure.
1169 path: Tuple. Optional argument, only used when recursing. The path from the
1170 root of the original shallow_tree, down to the root of the shallow_tree
1171 arg of this recursive call.
1173 Yields:
1174 Pairs of (path, value), where path the tuple path of a leaf node in
1175 shallow_tree, and value is the value of the corresponding node in
1176 input_tree.
1177 """
1178 if not is_nested_fn(shallow_tree):
1179 yield (path, input_tree)
1180 else:
1181 input_tree = dict(_tf_core_yield_sorted_items(input_tree))
1182 for (
1183 shallow_key,
1184 shallow_subtree,
1185 ) in _tf_core_yield_sorted_items(shallow_tree):
1186 subpath = path + (shallow_key,)
1187 input_subtree = input_tree[shallow_key]
1188 for leaf_path, leaf_value in _tf_core_yield_flat_up_to(
1189 shallow_subtree, input_subtree, is_nested_fn, path=subpath
1190 ):
1191 yield (leaf_path, leaf_value)
def _tf_data_yield_flat_up_to(shallow_tree, input_tree):
  """Yields elements of `input_tree` partially flattened up to `shallow_tree`.

  Args:
    shallow_tree: Structure bounding the traversal depth; once a node of it is
      not nested (per `_tf_data_is_nested`), the matching part of `input_tree`
      is yielded as is.
    input_tree: Structure the yielded elements are taken from.

  Yields:
    The parts of `input_tree` located at the leaves of `shallow_tree`.
  """
  if not _tf_data_is_nested(shallow_tree):
    yield input_tree
    return

  branch_pairs = zip(
      _tf_data_yield_value(shallow_tree), _tf_data_yield_value(input_tree)
  )
  for shallow_branch, input_branch in branch_pairs:
    yield from _tf_data_yield_flat_up_to(shallow_branch, input_branch)
def assert_shallow_structure(
    modality,
    shallow_tree,
    input_tree,
    check_types=True,
    expand_composites=False,
):
  """Asserts that `shallow_tree` is a shallow structure of `input_tree`.

  In other words, checks that the `input_tree` structure can be obtained from
  the `shallow_tree` structure by replacing its leaf nodes with deeper tree
  structures. The check is delegated to `_tf_core_assert_shallow_structure`
  for Modality.CORE and to `_tf_data_assert_shallow_structure` for
  Modality.DATA.

  Examples:

  The following code will raise an exception:
  ```python
    shallow_tree = {"a": "A", "b": "B"}
    input_tree = {"a": 1, "c": 2}
    assert_shallow_structure(Modality.CORE, shallow_tree, input_tree)
  ```

  The following code will raise an exception:
  ```python
    shallow_tree = ["a", "b"]
    input_tree = ["c", ["d", "e"], "f"]
    assert_shallow_structure(Modality.CORE, shallow_tree, input_tree)
  ```

  Args:
    modality: enum value of supported modality [Modality.CORE or Modality.DATA]
    shallow_tree: an arbitrarily nested structure.
    input_tree: an arbitrarily nested structure.
    check_types: if `True` (default) the sequence types of `shallow_tree` and
      `input_tree` have to be the same. Note that even with check_types==True,
      this function will consider two different namedtuple classes with the same
      name and _fields attribute to be the same class.
    expand_composites: Used for Modality.CORE only. If true, then composite
      tensors such as `tf.sparse.SparseTensor` and `tf.RaggedTensor` are
      expanded into their component tensors.

  Raises:
    TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      `input_tree`. Only raised if `check_types` is `True`.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      `input_tree`, or if `modality` is neither Modality.CORE nor
      Modality.DATA.
  """
  if modality == Modality.CORE:
    _tf_core_assert_shallow_structure(
        shallow_tree, input_tree, check_types, expand_composites
    )
    return
  if modality == Modality.DATA:
    _tf_data_assert_shallow_structure(shallow_tree, input_tree, check_types)
    return
  raise ValueError(
      "Unknown modality used {} for nested structure".format(modality)
  )
1266# pylint: disable=missing-function-docstring
def _tf_core_assert_shallow_structure(
    shallow_tree, input_tree, check_types=True, expand_composites=False
):
  """Checks that `shallow_tree` is a shallow structure of `input_tree`.

  Nothing is checked when `shallow_tree` is not nested. Otherwise the two
  trees are compared level by level, recursing into corresponding branches.

  Args:
    shallow_tree: an arbitrarily nested structure.
    input_tree: an arbitrarily nested structure.
    check_types: If `True` (default), the container types of the two trees
      have to match. Namedtuples accepted by `same_namedtuples`, pairs of
      lists, pairs of mappings, and pairs of composite tensors / type specs
      are accepted even when their concrete classes differ.
    expand_composites: If true, `_is_nested_or_composite` is used as the
      nesting predicate; otherwise `_tf_core_is_nested` is used.

  Raises:
    TypeError: If `shallow_tree` is nested but `input_tree` is not, if the
      container types mismatch while `check_types` is true, if only one of
      the trees is a composite tensor / type spec, or if `shallow_tree` is a
      type spec but `input_tree` is not.
    ValueError: If composite tensor type specs are incompatible, if the
      lengths of the two trees differ, or if a mapping `shallow_tree` has keys
      that are absent from `input_tree`.
  """
  is_nested_fn = (
      _is_nested_or_composite if expand_composites else _tf_core_is_nested
  )
  if is_nested_fn(shallow_tree):
    if not is_nested_fn(input_tree):
      raise TypeError(
          "If shallow structure is a sequence, input must also be a sequence. "
          "Input has type: %s."
          % type(input_tree)
      )

    # Compare against the wrapped object's type when `shallow_tree` is a
    # wrapt proxy.
    if isinstance(shallow_tree, _wrapt.ObjectProxy):
      shallow_type = type(shallow_tree.__wrapped__)
    else:
      shallow_type = type(shallow_tree)

    if check_types and not isinstance(input_tree, shallow_type):
      # Duck-typing means that nest should be fine with two different
      # namedtuples with identical name and fields.
      shallow_is_namedtuple = is_namedtuple(shallow_tree, False)
      input_is_namedtuple = is_namedtuple(input_tree, False)
      if shallow_is_namedtuple and input_is_namedtuple:
        if not same_namedtuples(shallow_tree, input_tree):
          raise TypeError(
              STRUCTURES_HAVE_MISMATCHING_TYPES.format(
                  input_type=type(input_tree), shallow_type=type(shallow_tree)
              )
          )

      elif isinstance(shallow_tree, list) and isinstance(input_tree, list):
        # List subclasses are considered the same,
        # e.g. python list vs. _ListWrapper.
        pass

      elif (
          _is_composite_tensor(shallow_tree) or _is_type_spec(shallow_tree)
      ) and (_is_composite_tensor(input_tree) or _is_type_spec(input_tree)):
        pass  # Compatibility will be checked below.

      elif not (
          isinstance(shallow_tree, _collections_abc.Mapping)
          and isinstance(input_tree, _collections_abc.Mapping)
      ):
        raise TypeError(
            STRUCTURES_HAVE_MISMATCHING_TYPES.format(
                input_type=type(input_tree), shallow_type=type(shallow_tree)
            )
        )

    if _is_composite_tensor(shallow_tree) or _is_composite_tensor(input_tree):
      # Both sides have to be a composite tensor or a type spec.
      if not (
          (_is_composite_tensor(input_tree) or _is_type_spec(input_tree))
          and (
              _is_composite_tensor(shallow_tree) or _is_type_spec(shallow_tree)
          )
      ):
        raise TypeError(
            STRUCTURES_HAVE_MISMATCHING_TYPES.format(
                input_type=type(input_tree), shallow_type=type(shallow_tree)
            )
        )
      # pylint: disable=protected-access
      type_spec_1 = (
          shallow_tree
          if _is_type_spec(shallow_tree)
          else shallow_tree._type_spec
      )._without_tensor_names()
      type_spec_2 = (
          input_tree if _is_type_spec(input_tree) else input_tree._type_spec
      )._without_tensor_names()
      # TODO(b/246356867): Replace the most_specific_common_supertype below
      # with get_structure.
      if hasattr(type_spec_1, "_get_structure") and hasattr(
          type_spec_2, "_get_structure"
      ):
        # `or None` maps an unequal comparison onto the `None` sentinel that
        # is tested below.
        result = (
            type_spec_1._get_structure() == type_spec_2._get_structure() or None
        )
      else:
        result = type_spec_1.most_specific_common_supertype([type_spec_2])
      if result is None:
        raise ValueError(
            "Incompatible CompositeTensor TypeSpecs: %s vs. %s"
            % (type_spec_1, type_spec_2)
        )
      # pylint: enable=protected-access

    elif _is_type_spec(shallow_tree):
      if not _is_type_spec(input_tree):
        raise TypeError(
            "If shallow structure is a TypeSpec, input must also "
            "be a TypeSpec. Input has type: %s."
            % type(input_tree)
        )
    else:
      # Any length difference is rejected here. The former follow-up branch
      # for `len(input_tree) < len(shallow_tree)` could never run after this
      # check and has been dropped.
      if len(input_tree) != len(shallow_tree):
        raise ValueError(
            STRUCTURES_HAVE_MISMATCHING_LENGTHS.format(
                input_length=len(input_tree), shallow_length=len(shallow_tree)
            )
        )

    if isinstance(shallow_tree, _collections_abc.Mapping):
      absent_keys = set(shallow_tree) - set(input_tree)
      if absent_keys:
        raise ValueError(
            SHALLOW_TREE_HAS_INVALID_KEYS.format(sorted(absent_keys))
        )

    for shallow_branch, input_branch in zip(
        _tf_core_yield_value(shallow_tree),
        _tf_core_yield_value(input_tree),
    ):
      _tf_core_assert_shallow_structure(
          shallow_branch,
          input_branch,
          check_types=check_types,
          expand_composites=expand_composites,
      )
1397# pylint: disable=missing-function-docstring
def _tf_data_assert_shallow_structure(
    shallow_tree, input_tree, check_types=True
):
  """Checks that `shallow_tree` is a shallow structure of `input_tree`.

  Nothing is checked when `shallow_tree` is not nested according to
  `_tf_data_is_nested`.

  Args:
    shallow_tree: an arbitrarily nested structure.
    input_tree: an arbitrarily nested structure.
    check_types: If `True` (default), `input_tree` has to be an instance of
      the type of `shallow_tree` at every nested level, and mappings have to
      share the same key set.

  Raises:
    TypeError: If `shallow_tree` is nested but `input_tree` is not, or if the
      sequence types differ while `check_types` is true.
    ValueError: If the sequence lengths differ, or if mapping keys differ
      while `check_types` is true.
  """
  if not _tf_data_is_nested(shallow_tree):
    return

  if not _tf_data_is_nested(input_tree):
    raise TypeError(
        "If shallow structure is a sequence, input must also be a sequence. "
        f"Input has type: '{type(input_tree).__name__}'."
    )

  shallow_type = type(shallow_tree)
  if check_types and not isinstance(input_tree, shallow_type):
    raise TypeError(
        "The two structures don't have the same sequence type. Input "
        f"structure has type '{type(input_tree).__name__}', while shallow "
        f"structure has type '{shallow_type.__name__}'."
    )

  if len(input_tree) != len(shallow_tree):
    raise ValueError(
        "The two structures don't have the same sequence length. Input "
        f"structure has length {len(input_tree)}, while shallow structure "
        f"has length {len(shallow_tree)}."
    )

  # By default the trees are iterated directly. Only type-checked mappings
  # are replaced by their sorted (key, value) items so that entries pair up
  # by key.
  shallow_children, input_children = shallow_tree, input_tree
  if check_types and isinstance(shallow_tree, _collections_abc.Mapping):
    if set(input_tree) != set(shallow_tree):
      raise ValueError(
          "The two structures don't have the same keys. Input "
          f"structure has keys {list(input_tree)}, while shallow structure "
          f"has keys {list(shallow_tree)}."
      )
    input_children = sorted(input_tree.items())
    shallow_children = sorted(shallow_tree.items())

  for shallow_branch, input_branch in zip(shallow_children, input_children):
    _tf_data_assert_shallow_structure(
        shallow_branch, input_branch, check_types=check_types
    )
def flatten_up_to(
    modality,
    shallow_tree,
    input_tree,
    check_types=True,
    expand_composites=False,
):
  # pylint: disable=g-doc-return-or-yield,g-doc-args
  """Flattens `input_tree` up to `shallow_tree`.

  The work is delegated to `_tf_core_flatten_up_to` for Modality.CORE and to
  `_tf_data_flatten_up_to` for Modality.DATA.

  - For Modality.CORE: refer to
    [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest)
    for the definition of a structure.

  Any structure in `input_tree` that lies deeper than `shallow_tree` is kept
  as structure in the partially flattened output.

  If `shallow_tree` and `input_tree` are atoms, this returns a
  single-item list: `[input_tree]`.

  Use Case:

  Sometimes we may wish to partially flatten a structure, retaining some
  of the nested structure. We achieve this by specifying a shallow structure,
  `shallow_tree`, we wish to flatten up to.

  The input, `input_tree`, can be thought of as having the same structure layout
  as `shallow_tree`, but with leaf nodes that are themselves tree structures.

  Examples (the leading `modality` argument is omitted for brevity):

  ```python
  input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
  shallow_tree = [[True, True], [False, True]]

  flattened_input_tree = flatten_up_to(shallow_tree, input_tree)
  flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree)

  # Output is:
  # [[2, 2], [3, 3], [4, 9], [5, 5]]
  # [True, True, False, True]
  ```

  ```python
  input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]]
  shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]]

  input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree)
  input_tree_flattened = flatten(input_tree)

  # Output is:
  # [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
  # ['a', 1, 'b', 2, 'c', 3, 'd', 4]
  ```

  Edge Cases:

  ```python
  flatten_up_to(0, 0)  # Output: [0]
  flatten_up_to(0, [0, 1, 2])  # Output: [[0, 1, 2]]
  flatten_up_to([0, 1, 2], 0)  # Output: TypeError
  flatten_up_to([0, 1, 2], [0, 1, 2])  # Output: [0, 1, 2]

  ```

  Args:
    modality: enum value of supported modality [Modality.CORE or Modality.DATA]
    shallow_tree: a possibly pruned structure of input_tree.
    input_tree: an atom or a nested structure. Note, numpy arrays are considered
      atoms.
    check_types: bool. If True, check that each node in shallow_tree has the
      same type as the corresponding node in input_tree. Only forwarded for
      Modality.CORE; the Modality.DATA helper is called without it.
    expand_composites: Arg used for Modality.CORE only. If true, then composite
      tensors such as `tf.sparse.SparseTensor` and `tf.RaggedTensor` are
      expanded into their component tensors.

  Returns:
    A Python list, the partially flattened version of `input_tree` according to
    the structure of `shallow_tree`.

  Raises:
    TypeError: If `shallow_tree` is a nested structure but `input_tree` is not.
    TypeError: If the structure types of `shallow_tree` are different from
      `input_tree`.
    ValueError: If the structure lengths of `shallow_tree` are different from
      `input_tree`, or if `modality` is neither Modality.CORE nor
      Modality.DATA.
  """
  if modality == Modality.CORE:
    return _tf_core_flatten_up_to(
        shallow_tree, input_tree, check_types, expand_composites
    )
  if modality == Modality.DATA:
    return _tf_data_flatten_up_to(shallow_tree, input_tree)
  raise ValueError(
      "Unknown modality used {} for nested structure".format(modality)
  )
def _tf_core_flatten_up_to(
    shallow_tree, input_tree, check_types=True, expand_composites=False
):
  """Flattens `input_tree` up to `shallow_tree` (Modality.CORE).

  Args:
    shallow_tree: a possibly pruned structure of `input_tree`.
    input_tree: an atom or a nested structure.
    check_types: forwarded to `_tf_core_assert_shallow_structure`.
    expand_composites: If true, `_is_nested_or_composite` is used as the
      nesting predicate; otherwise `_tf_core_is_nested` is used.

  Returns:
    A list with the values of `input_tree` found at the leaves of
    `shallow_tree`.

  Raises:
    Whatever `_tf_core_assert_shallow_structure` raises when `shallow_tree`
    is not a shallow structure of `input_tree`.
  """
  if expand_composites:
    is_nested_fn = _is_nested_or_composite
  else:
    is_nested_fn = _tf_core_is_nested

  _tf_core_assert_shallow_structure(
      shallow_tree,
      input_tree,
      check_types=check_types,
      expand_composites=expand_composites,
  )

  # Keep only the values; the paths produced by the traversal are not needed.
  flattened = []
  for _, value in _tf_core_yield_flat_up_to(
      shallow_tree, input_tree, is_nested_fn
  ):
    flattened.append(value)
  return flattened
def _tf_data_flatten_up_to(shallow_tree, input_tree):
  """Flattens `input_tree` up to `shallow_tree` (Modality.DATA).

  The trees are first validated with `_tf_data_assert_shallow_structure`
  (using its default `check_types`), then the elements produced by
  `_tf_data_yield_flat_up_to` are collected.

  Args:
    shallow_tree: a possibly pruned structure of `input_tree`.
    input_tree: the structure to flatten.

  Returns:
    A list of the partially flattened elements of `input_tree`.
  """
  _tf_data_assert_shallow_structure(shallow_tree, input_tree)
  return [*_tf_data_yield_flat_up_to(shallow_tree, input_tree)]
def map_structure_up_to(modality, shallow_tree, func, *inputs, **kwargs):
  """Applies a function or op to a number of partially flattened inputs.

  The `inputs` are flattened up to `shallow_tree` before being mapped.

  Use Case:

  Sometimes we wish to apply a function to a partially flattened
  structure (for example when the function itself takes structure inputs). We
  achieve this by specifying a shallow structure, `shallow_tree` we wish to
  flatten up to.

  The `inputs`, can be thought of as having the same structure layout as
  `shallow_tree`, but with leaf nodes that are themselves tree structures.

  This function therefore will return something with the same base structure as
  `shallow_tree`.

  Calling convention of `func`:

  - For Modality.CORE the call is routed to
    `_tf_core_map_structure_with_tuple_paths_up_to`, which invokes
    `func(path, *values, **kwargs)`: the tuple path of the current leaf of
    `shallow_tree` comes first, followed by one value per input.
  - For Modality.DATA the call is routed to `_tf_data_map_structure_up_to`,
    which invokes `func(*values)`; `kwargs` are not forwarded.

  Example (Modality.CORE, note the leading `path` argument):

  ```python
  data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
  name_list = ['evens', ['odds', 'primes']]
  out = map_structure_up_to(
      Modality.CORE,
      name_list,
      lambda path, name, sec: "first_{}_{}".format(len(sec), name),
      name_list, data_list)

  # Output is: ['first_4_evens', ['first_5_odds', 'first_3_primes']]
  ```

  Args:
    modality: enum value of supported modality [Modality.CORE or Modality.DATA]
    shallow_tree: a shallow structure, common to all the inputs.
    func: callable which will be applied to each input individually; see the
      calling convention above.
    *inputs: structures that are compatible with shallow_tree. The function
      `func` is applied to corresponding structures due to partial flattening of
      each input, so the function must support arity of `len(inputs)` (plus the
      path argument for Modality.CORE).
    **kwargs: Arg used for Modality.CORE only. kwargs to feed to func().
      Special kwarg `check_types` is not passed to func, but instead determines
      whether the types of iterables within the structures have to be same (e.g.
      `map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow
      this set this argument to `False`. `expand_composites` is likewise
      consumed rather than passed to func.

  Raises:
    TypeError: If `shallow_tree` is a nested structure but `input_tree` is not.
    TypeError: If the structure types of `shallow_tree` are different from
      `input_tree`.
    ValueError: If the structure lengths of `shallow_tree` are different from
      `input_tree`, if no inputs are given, or if `modality` is neither
      Modality.CORE nor Modality.DATA.

  Returns:
    result of repeatedly applying `func`, with the same structure layout as
    `shallow_tree`.
  """
  if modality == Modality.CORE:
    return _tf_core_map_structure_with_tuple_paths_up_to(
        shallow_tree, func, *inputs, **kwargs
    )
  if modality == Modality.DATA:
    return _tf_data_map_structure_up_to(shallow_tree, func, *inputs)
  raise ValueError(
      "Unknown modality used {} for nested structure".format(modality)
  )
1649def _tf_core_map_structure_with_tuple_paths_up_to(
1650 shallow_tree, func, *inputs, **kwargs
1651):
1652 """See comments for map_structure_with_tuple_paths_up_to() in tensorflow/python/util/nest.py."""
1653 if not inputs:
1654 raise ValueError("Cannot map over no sequences")
1656 check_types = kwargs.pop("check_types", True)
1657 expand_composites = kwargs.pop("expand_composites", False)
1658 is_nested_fn = (
1659 _is_nested_or_composite if expand_composites else _tf_core_is_nested
1660 )
1662 for input_tree in inputs:
1663 _tf_core_assert_shallow_structure(
1664 shallow_tree,
1665 input_tree,
1666 check_types=check_types,
1667 expand_composites=expand_composites,
1668 )
1670 # Flatten each input separately, apply the function to corresponding items,
1671 # then repack based on the structure of the first input.
1672 flat_value_gen = (
1673 _tf_core_flatten_up_to( # pylint: disable=g-complex-comprehension
1674 shallow_tree,
1675 input_tree,
1676 check_types,
1677 expand_composites=expand_composites,
1678 )
1679 for input_tree in inputs
1680 )
1681 flat_path_gen = (
1682 path
1683 for path, _ in _tf_core_yield_flat_up_to(
1684 shallow_tree, inputs[0], is_nested_fn
1685 )
1686 )
1687 results = [
1688 func(*args, **kwargs) for args in zip(flat_path_gen, *flat_value_gen)
1689 ]
1690 return _tf_core_pack_sequence_as(
1691 structure=shallow_tree,
1692 flat_sequence=results,
1693 expand_composites=expand_composites,
1694 )
1697# pylint: disable=missing-function-docstring
1698def _tf_data_map_structure_up_to(shallow_tree, func, *inputs):
1699 if not inputs:
1700 raise ValueError(
1701 "Argument `inputs` is empty. Cannot map over no sequences."
1702 )
1703 for input_tree in inputs:
1704 _tf_data_assert_shallow_structure(shallow_tree, input_tree)
1706 # Flatten each input separately, apply the function to corresponding elements,
1707 # then repack based on the structure of the first input.
1708 all_flattened_up_to = (
1709 _tf_data_flatten_up_to(shallow_tree, input_tree) for input_tree in inputs
1710 )
1712 results = [func(*tensors) for tensors in zip(*all_flattened_up_to)]
1713 return _tf_data_pack_sequence_as(
1714 structure=shallow_tree, flat_sequence=results
1715 )