Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/util/nest_util.py: 19%

388 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2023 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16"""Utility methods for handling nests. 

17 

18This module encapsulates different semantics of handling nests by the public 

19tf.nest APIs and internal tf.data APIs. The difference in semantics exists for 

20historic reasons and reconciliation would require a non-backwards compatible 

21change. 

22 

23The implementation of the different semantics uses a common utility to

24avoid / minimize further divergence between the two APIs over time. 

25""" 

26 

27import collections as _collections 

28import enum 

29 

30import six as _six 

31import wrapt as _wrapt 

32 

33from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import 

34from tensorflow.python.platform import tf_logging 

35from tensorflow.python.util import _pywrap_utils 

36from tensorflow.python.util.compat import collections_abc as _collections_abc 

37 

38 

# Module-level aliases for the native helpers exported by `_pywrap_utils`.

# Type predicates shared by every modality.
_is_attrs = _pywrap_utils.IsAttrs
_is_composite_tensor = _pywrap_utils.IsCompositeTensor
_is_mapping = _pywrap_utils.IsMapping
_is_mapping_view = _pywrap_utils.IsMappingView
_is_mutable_mapping = _pywrap_utils.IsMutableMapping
_is_nested_or_composite = _pywrap_utils.IsNestedOrComposite
_is_type_spec = _pywrap_utils.IsTypeSpec

# Helpers whose behavior differs between Modality.CORE and Modality.DATA.
_tf_core_is_nested = _pywrap_utils.IsNested
_tf_data_flatten = _pywrap_utils.FlattenForData
_tf_data_is_nested = _pywrap_utils.IsNestedForData

# See the swig file (util.i) for documentation.
same_namedtuples = _pywrap_utils.SameNamedtuples

51 

52 

# Error-message templates for structure comparisons. They contain
# `str.format` placeholders, so the placeholder names are part of their API.
STRUCTURES_HAVE_MISMATCHING_TYPES = (
    "The two structures don't have the same sequence type. "
    "Input structure has type {input_type}, "
    "while shallow structure has type {shallow_type}."
)

STRUCTURES_HAVE_MISMATCHING_LENGTHS = (
    "The two structures don't have the same sequence length. "
    "Input structure has length {input_length}, "
    "while shallow structure has length {shallow_length}."
)

INPUT_TREE_SMALLER_THAN_SHALLOW_TREE = (
    "The input_tree has fewer items than the shallow_tree. "
    "Input structure has length {input_size}, "
    "while shallow structure has length {shallow_size}."
)

SHALLOW_TREE_HAS_INVALID_KEYS = (
    "The shallow_tree's keys are not a subset of the input_tree's keys. "
    "The shallow_tree has the following keys that are not in the "
    "input_tree: {}."
)

74 

75 

class Modality(enum.Enum):
  """Modality/semantic used for treating nested structures.

  - Modality.CORE follows tensorflow_core/tf.nest semantics.

    The following collection types are recognized by `tf.nest` as nested
    structures:

    * `collections.abc.Sequence` (except `string` and `bytes`).
      This includes `list`, `tuple`, and `namedtuple`.
    * `collections.abc.Mapping` (with sortable keys).
      This includes `dict` and `collections.OrderedDict`.
    * `collections.abc.MappingView` (with sortable keys).
    * [`attr.s` classes](https://www.attrs.org/).

    Any other values are considered **atoms**. Not all collection types are
    considered nested structures. For example, the following types are
    considered atoms:

    * `set`; `{"a", "b"}` is an atom, while `["a", "b"]` is a nested structure.
    * [`dataclass` classes](https://docs.python.org/library/dataclasses.html)
    * `tf.Tensor`
    * `numpy.array`

  - Modality.DATA follows tf.data's nest semantics.

    This modality makes two changes:

    1. It removes support for lists as a level of nesting in nested
       structures.
    2. It adds support for `SparseTensorValue` as an atomic element.

    The motivation for this change is twofold:

    1. It seems more natural for lists to be treated (e.g. in Dataset
       constructors) as tensors, rather than lists of (lists of...) tensors.
    2. This is needed because `SparseTensorValue` is implemented as a
       `namedtuple` that would normally be flattened and we want to be able to
       create sparse tensors from `SparseTensorValue`s similarly to creating
       tensors from numpy arrays.
  """

  # Member values are plain strings equal to the member names.
  CORE = "CORE"
  DATA = "DATA"

119 

120 

121class _DotString(object): 

122 __slots__ = [] 

123 

124 def __str__(self): 

125 return "." 

126 

127 def __repr__(self): 

128 return "." 

129 

130 

131_DOT = _DotString() 

132 

133 

def is_nested(modality, structure):
  """Reports whether `structure` counts as nested under the given modality.

  For Modality.CORE the definition is the one documented for
  [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest); for
  Modality.DATA the tf.data rules described on `Modality` apply.

  Args:
    modality: enum value of supported modality [Modality.CORE or Modality.DATA]
    structure: the value to test.

  Returns:
    True if the input is a nested structure.

  Raises:
    ValueError: if `modality` is neither Modality.CORE nor Modality.DATA.
  """
  if modality == Modality.CORE:
    return _tf_core_is_nested(structure)
  if modality == Modality.DATA:
    return _tf_data_is_nested(structure)
  raise ValueError(f"Unknown modality used {modality} for nested structure")

156 

157 

# TODO(b/225045380): Move to a "leaf" library to use in trace_type.
def is_namedtuple(instance, strict=False):
  """Checks whether `instance` is a `namedtuple`.

  The check is delegated entirely to the native `IsNamedtuple` helper.

  Args:
    instance: An instance of a Python object.
    strict: If True, only a "plain" namedtuple qualifies; an instance of a
      class that inherits from a `namedtuple` is accepted only when
      `strict=False`.

  Returns:
    True if `instance` is a `namedtuple`.
  """
  native_check = _pywrap_utils.IsNamedtuple
  return native_check(instance, strict)

172 

173 

def sequence_like(instance, args):
  """Converts the sequence `args` to the same type as `instance`.

  Args:
    instance: an instance of `tuple`, `list`, `namedtuple`, `dict`,
      `collections.OrderedDict`, `collections.defaultdict`, another `Mapping`,
      a mapping view, an `attrs` class, `range`, a `wrapt.ObjectProxy`
      wrapping one of these, a `composite_tensor.CompositeTensor`, or a
      `type_spec.TypeSpec`.
    args: items to be converted to the `instance` type. For mappings, the items
      are matched to the keys of `instance` in sorted-key order. For composite
      tensors and type specs, `args` must hold exactly one element: the
      components to rebuild from.

  Returns:
    `args` with the type of `instance`. Mapping views cannot be constructed
    directly, so a plain `list` is returned for them; a `range` instance also
    produces a `list`.

  Raises:
    TypeError: if the keys of a mapping `instance` are not sortable, or if a
      non-mutable `Mapping` type cannot be constructed from an iterable of
      key-value pairs.
  """
  if _is_mutable_mapping(instance):
    # Pack dictionaries in a deterministic order by sorting the keys.
    # Notice this means that we ignore the original order of `OrderedDict`
    # instances. This is intentional, to avoid potential bugs caused by mixing
    # ordered and plain dicts (e.g., flattening a dict but using a
    # corresponding `OrderedDict` to pack it back).
    result = dict(zip(_tf_core_sorted(instance), args))
    instance_type = type(instance)
    if instance_type == _collections.defaultdict:
      # Keep the default factory, which a bare `defaultdict()` would drop.
      d = _collections.defaultdict(instance.default_factory)
    else:
      d = instance_type()
    # Insert in `instance`'s own iteration order so that order is preserved.
    for key in instance:
      d[key] = result[key]
    return d

  if _is_mapping(instance):
    result = dict(zip(_tf_core_sorted(instance), args))
    instance_type = type(instance)
    if not getattr(instance_type, "__supported_by_tf_nest__", False):
      tf_logging.log_first_n(
          tf_logging.WARN,
          "Mapping types may not work well with tf.nest. "
          "Prefer using MutableMapping for {}".format(instance_type),
          1,
      )
    try:
      return instance_type((key, result[key]) for key in instance)
    except TypeError as err:
      # pylint: disable=raise-missing-from
      raise TypeError(
          "Error creating an object of type {} like {}. Note that "
          "it must accept a single positional argument "
          "representing an iterable of key-value pairs, in "
          "addition to self. Cause: {}".format(type(instance), instance, err)
      )

  if _is_mapping_view(instance):
    # We can't directly construct mapping views, so we create a list instead.
    return list(args)

  if is_namedtuple(instance) or _is_attrs(instance):
    # Rebuild from the wrapped object's type when `instance` is a proxy.
    if isinstance(instance, _wrapt.ObjectProxy):
      instance_type = type(instance.__wrapped__)
    else:
      instance_type = type(instance)
    return instance_type(*args)

  if _is_composite_tensor(instance):
    assert len(args) == 1
    spec = instance._type_spec  # pylint: disable=protected-access
    return spec._from_components(args[0])  # pylint: disable=protected-access

  if _is_type_spec(instance):
    # Pack a CompositeTensor's components according to a TypeSpec.
    assert len(args) == 1
    return instance._from_components(args[0])  # pylint: disable=protected-access

  if isinstance(instance, range):
    # `six.moves.range` is the builtin `range` on Python 3.
    return sequence_like(list(instance), args)

  if isinstance(instance, _wrapt.ObjectProxy):
    # For object proxies, first create the underlying type and then re-wrap it
    # in the proxy type.
    return type(instance)(sequence_like(instance.__wrapped__, args))

  # Plain sequences (e.g. list, tuple) are built from an iterable of items.
  return type(instance)(args)

247 

248 

249def _get_attrs_items(obj): 

250 """Returns a list of (name, value) pairs from an attrs instance. 

251 

252 TODO(b/268078256): check if this comment is valid, and if so, ensure it's 

253 handled in the function below. 

254 The list will be sorted by name. 

255 

256 Args: 

257 obj: an object. 

258 

259 Returns: 

260 A list of (attr_name, attr_value) pairs, sorted by attr_name. 

261 """ 

262 attrs = getattr(obj.__class__, "__attrs_attrs__") 

263 attr_names = (a.name for a in attrs) 

264 return [(attr_name, getattr(obj, attr_name)) for attr_name in attr_names] 

265 

266 

267def _tf_core_sorted(dict_): 

268 """Returns a sorted list of the dict keys, with error if keys not sortable.""" 

269 try: 

270 return sorted(dict_.keys()) 

271 except TypeError: 

272 # pylint: disable=raise-missing-from 

273 raise TypeError("nest only supports dicts with sortable keys.") 

274 

275 

276def _tf_data_sorted(dict_): 

277 """Returns a sorted list of the dict keys, with error if keys not sortable.""" 

278 try: 

279 return sorted(list(dict_)) 

280 except TypeError as e: 

281 # pylint: disable=raise-missing-from 

282 raise TypeError( 

283 f"nest only supports dicts with sortable keys. Error: {e.message}" 

284 ) 

285 

286 

def yield_value(modality, iterable):
  """Generates the elements of `iterable` in a deterministic order.

  This is a generator, so the modality check (and any error) happens when
  iteration starts, not when the function is called.

  Args:
    modality: enum value of supported modality [Modality.CORE or Modality.DATA]
    iterable: an iterable.

  Yields:
    The iterable elements in a deterministic order.

  Raises:
    ValueError: if `modality` is neither Modality.CORE nor Modality.DATA.
  """
  if modality == Modality.CORE:
    source = _tf_core_yield_value(iterable)
  elif modality == Modality.DATA:
    source = _tf_data_yield_value(iterable)
  else:
    raise ValueError(f"Unknown modality used {modality} for nested structure")
  yield from source

305 

306 

def _tf_core_yield_value(iterable):
  """Yields only the values from `_tf_core_yield_sorted_items(iterable)`."""
  yield from (value for _key, value in _tf_core_yield_sorted_items(iterable))

310 

311 

def yield_sorted_items(modality, iterable):
  """Returns a generator of `iterable`'s (key, value) pairs.

  Only Modality.CORE is supported. Because this function returns a generator
  rather than being one, an unsupported modality (including Modality.DATA)
  raises immediately at call time.

  Args:
    modality: enum value of supported modality; must be Modality.CORE.
    iterable: an iterable.

  Returns:
    The generator produced by `_tf_core_yield_sorted_items(iterable)`.

  Raises:
    ValueError: if `modality` is not Modality.CORE.
  """
  if modality == Modality.CORE:
    return _tf_core_yield_sorted_items(iterable)
  raise ValueError(f"Unknown modality used {modality} for nested structure")

319 

320 

def _tf_core_yield_sorted_items(iterable):
  """Yield (key, value) pairs for `iterable` in a deterministic order.

  - Lists, tuples and any otherwise-unrecognized iterable are enumerated: the
    key is the int index of the value.
  - Mappings yield (key, value) with keys visited in sorted order, so the
    insertion order of an `OrderedDict` is ignored.
  - attrs instances and namedtuples yield (attribute name, value) in the order
    given by `__attrs_attrs__` / `_fields` respectively. These are NOT sorted
    alphabetically.
  - A CompositeTensor yields a single pair: the name of its spec's
    `value_type` and its components. A TypeSpec yields a single pair: the name
    of its `value_type` and its component specs. The matching keys allow a
    CompositeTensor and its TypeSpec to have matching structures.

  Args:
    iterable: an iterable.

  Yields:
    The iterable's (key, value) pairs, in the order described above.

  Raises:
    TypeError: if `iterable` is a mapping whose keys are not sortable.
  """
  # Ordered to check common structure types (list, tuple, dict) first.
  if isinstance(iterable, list):
    yield from enumerate(iterable)
  # namedtuples handled separately to avoid expensive namedtuple check.
  elif type(iterable) == tuple:  # pylint: disable=unidiomatic-typecheck
    yield from enumerate(iterable)
  elif isinstance(iterable, (dict, _collections_abc.Mapping)):
    # Iterate through dictionaries in a deterministic order by sorting the
    # keys. Notice this means that we ignore the original order of `OrderedDict`
    # instances. This is intentional, to avoid potential bugs caused by mixing
    # ordered and plain dicts (e.g., flattening a dict but using a
    # corresponding `OrderedDict` to pack it back).
    for key in _tf_core_sorted(iterable):
      yield key, iterable[key]
  elif _is_attrs(iterable):
    yield from _get_attrs_items(iterable)
  elif is_namedtuple(iterable):
    for field in iterable._fields:
      yield field, getattr(iterable, field)
  elif _is_composite_tensor(iterable):
    type_spec = iterable._type_spec  # pylint: disable=protected-access
    yield type_spec.value_type.__name__, type_spec._to_components(iterable)  # pylint: disable=protected-access
  elif _is_type_spec(iterable):
    # Note: to allow CompositeTensors and their TypeSpecs to have matching
    # structures, we need to use the same key string here.
    yield iterable.value_type.__name__, iterable._component_specs  # pylint: disable=protected-access
  else:
    yield from enumerate(iterable)

368 

369 

def _tf_data_yield_value(iterable):
  """Generates the elements of `iterable` following tf.data nesting rules.

  - Mappings yield their values ordered by sorted key. `OrderedDict` order is
    deliberately ignored so that ordered and plain dicts pack and unpack alike.
  - A `SparseTensorValue` (recognized by class name) is yielded whole.
  - attrs instances yield their attribute values.
  - Anything else is iterated directly.

  Args:
    iterable: an iterable.

  Yields:
    The iterable elements in a deterministic order.
  """
  # pylint: disable=protected-access
  if isinstance(iterable, _collections_abc.Mapping):
    yield from (iterable[key] for key in _tf_data_sorted(iterable))
  # The class-name comparison (instead of `isinstance`) avoids a circular
  # import: sparse_tensor transitively depends on tensorflow/python/util/nest.py.
  elif iterable.__class__.__name__ == "SparseTensorValue":
    yield iterable
  elif _is_attrs(iterable):
    yield from (attr_value for _name, attr_value in _get_attrs_items(iterable))
  else:
    for element in iterable:
      yield element

400 

401 

def assert_same_structure(
    modality, nest1, nest2, check_types=True, expand_composites=False
):
  """Raises unless `nest1` and `nest2` are nested in the same way.

  Under Modality.CORE, "structure" has the meaning documented for
  [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest). Only the
  nesting is compared; the types of the atoms themselves are never checked.

  Examples:

  * Comparing two atoms always succeeds:

    >>> tf.nest.assert_same_structure(1.5, tf.Variable(1, tf.uint32))
    >>> tf.nest.assert_same_structure("abc", np.array([1, 2]))

  * Matching nested structures succeed:

    >>> structure1 = (((1, 2), 3), 4, (5, 6))
    >>> structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
    >>> structure3 = [(("a", "b"), "c"), "d", ["e", "f"]]
    >>> tf.nest.assert_same_structure(structure1, structure2)
    >>> tf.nest.assert_same_structure(structure1, structure3, check_types=False)

    >>> import collections
    >>> tf.nest.assert_same_structure(
    ...     collections.namedtuple("bar", "a b")(1, 2),
    ...     collections.namedtuple("foo", "a b")(2, 3),
    ...     check_types=False)

    >>> tf.nest.assert_same_structure(
    ...     collections.namedtuple("bar", "a b")(1, 2),
    ...     { "a": 1, "b": 2 },
    ...     check_types=False)

    >>> tf.nest.assert_same_structure(
    ...     { "a": 1, "b": 2, "c": 3 },
    ...     { "c": 6, "b": 5, "a": 4 })

    >>> ragged_tensor1 = tf.RaggedTensor.from_row_splits(
    ...       values=[3, 1, 4, 1, 5, 9, 2, 6],
    ...       row_splits=[0, 4, 4, 7, 8, 8])
    >>> ragged_tensor2 = tf.RaggedTensor.from_row_splits(
    ...       values=[3, 1, 4],
    ...       row_splits=[0, 3])
    >>> tf.nest.assert_same_structure(
    ...       ragged_tensor1,
    ...       ragged_tensor2,
    ...       expand_composites=True)

  * Mismatches raise:

    >>> tf.nest.assert_same_structure([0, 1], np.array([0, 1]))
    Traceback (most recent call last):
    ...
    ValueError: The two structures don't have the same nested structure

    >>> tf.nest.assert_same_structure(
    ...     collections.namedtuple('bar', 'a b')(1, 2),
    ...     collections.namedtuple('foo', 'a b')(2, 3))
    Traceback (most recent call last):
    ...
    TypeError: The two structures don't have the same nested structure

  Under Modality.DATA, nesting follows the tf.data rules described on the
  `Modality` class, which differ from the CORE rules.

  Args:
    modality: enum value of supported modality [Modality.CORE or Modality.DATA]
    nest1: an atom or a nested structure.
    nest2: an atom or a nested structure.
    check_types:
      - Modality.CORE: when `True` (default), the types of sub-structures
        must match too, including dictionary keys. When `False`, e.g. a list
        and a tuple of the same size compare equal. Namedtuples with identical
        name and fields always count as the same shallow structure, and any
        two list subtypes count as the same type (so "list" and
        "_ListWrapper" from trackable dependency tracking compare equal). The
        types of atoms are never checked.
      - Modality.DATA: when `True` (default), sequence types must match as
        well; for dictionaries the keys are considered part of the "type", so
        dicts with different keys differ. When `False`, two iterables match as
        long as they yield elements with the same structures.
    expand_composites: only meaningful for Modality.CORE. When true, composite
      tensors such as `tf.sparse.SparseTensor` and `tf.RaggedTensor` are
      compared via their component tensors.

  Raises:
    ValueError: if the structures have a different number of atoms or are
      nested differently, or if `modality` is not supported.
    TypeError: if the structures differ in the sequence type of any
      sub-structure (only possible when `check_types` is `True`).
  """
  if modality == Modality.CORE:
    _tf_core_assert_same_structure(nest1, nest2, check_types, expand_composites)
    return
  if modality == Modality.DATA:
    _tf_data_assert_same_structure(nest1, nest2, check_types)
    return
  raise ValueError(f"Unknown modality used {modality} for nested structure")

507 

508 

509# pylint: disable=missing-function-docstring 

def _tf_core_assert_same_structure(
    nest1, nest2, check_types=True, expand_composites=False
):
  """Modality.CORE implementation of `assert_same_structure`.

  Delegates the comparison to the native `AssertSameStructure` helper. If that
  raises ValueError or TypeError, the same exception type is re-raised. The
  new message appends both structures with every atom rendered as `.`.

  Args:
    nest1: an atom or a nested structure.
    nest2: an atom or a nested structure.
    check_types: whether sub-structure types must match.
    expand_composites: whether composite tensors are expanded into components.

  Raises:
    ValueError, TypeError: propagated, with the extended message, from the
      native check.
  """
  # Convert to bool explicitly as otherwise pybind will not be able to handle
  # type mismatch message correctly. See GitHub issue 42329 for details.
  check_types, expand_composites = bool(check_types), bool(expand_composites)
  try:
    _pywrap_utils.AssertSameStructure(
        nest1, nest2, check_types, expand_composites
    )
  except (ValueError, TypeError) as error:
    first_outline = str(_tf_core_map_structure(lambda _: _DOT, nest1))
    second_outline = str(_tf_core_map_structure(lambda _: _DOT, nest2))
    message = "%s\nEntire first structure:\n%s\nEntire second structure:\n%s" % (
        str(error),
        first_outline,
        second_outline,
    )
    raise type(error)(message)

528 

529 

def _tf_data_assert_same_structure(nest1, nest2, check_types=True):
  """Modality.DATA implementation of `assert_same_structure`.

  `check_types` is forwarded to the native helper unchanged, and any error it
  raises propagates as-is.
  """
  native_assert = _pywrap_utils.AssertSameStructureForData
  native_assert(nest1, nest2, check_types)

532 

533 

def _tf_core_packed_nest_with_indices(
    structure, flat, index, is_nested_fn, sequence_fn=None
):
  """Helper function for pack_sequence_as.

  Walks the children of `structure` in `_tf_core_yield_value` order. Each atom
  consumes one element of `flat`. Each nested child is packed recursively and
  rebuilt with `sequence_fn`.

  Args:
    structure: structure to mimic.
    flat: Flattened values to output substructure for.
    index: Index at which to start reading from flat.
    is_nested_fn: Function used to test if a value should be treated as a nested
      structure.
    sequence_fn: Function used to generate a new structure instance from a
      template and a list of packed children. Defaults to `sequence_like`.

  Returns:
    The tuple (new_index, packed), where:
      * new_index - the updated index into `flat` having processed `structure`.
      * packed - a list holding the values from `flat` for the children of
        `structure`, having started at `index`, with nested children already
        packed into their own structure types. The list itself is not
        converted to `structure`'s type; the caller does that.

  Raises:
    IndexError: (from `flat[index]`) if `structure` contains more atoms than
      `flat` holds from `index` onward. This helper does not raise ValueError
      itself; presumably the caller translates the IndexError.
  """
  packed = []
  sequence_fn = sequence_fn or sequence_like
  for child in _tf_core_yield_value(structure):
    if is_nested_fn(child):
      # Recurse into the sub-structure, then rebuild it from its packed leaves.
      index, packed_child = _tf_core_packed_nest_with_indices(
          child, flat, index, is_nested_fn, sequence_fn
      )
      packed.append(sequence_fn(child, packed_child))
    else:
      packed.append(flat[index])
      index += 1
  return index, packed

571 

572 

def _tf_data_packed_nest_with_indices(structure, flat, index):
  """Modality.DATA helper for packing a flat sequence into a structure.

  Presumably used by `pack_sequence_as` for Modality.DATA. Children are
  visited in `_tf_data_yield_value` order and nesting is decided by
  `_tf_data_is_nested`.

  Args:
    structure: Substructure (tuple of elements and/or tuples) to mimic
    flat: Flattened values to output substructure for.
    index: Index at which to start reading from flat.

  Returns:
    The tuple (new_index, packed), where:
      * new_index - the updated index into `flat` having processed `structure`.
      * packed - a list holding the values from `flat` for the children of
        `structure`, having started at `index`, with nested children already
        packed (via `sequence_like`) into their own structure types.

  Raises:
    IndexError: (from `flat[index]`) if `structure` contains more elements
      than `flat` holds from `index` onward.
  """
  packed = []
  for child in _tf_data_yield_value(structure):
    if _tf_data_is_nested(child):
      # Recurse into the sub-structure, then rebuild it from its packed leaves.
      index, packed_child = _tf_data_packed_nest_with_indices(child, flat, index)
      packed.append(sequence_like(child, packed_child))
    else:
      packed.append(flat[index])
      index += 1
  return index, packed

602 

603 

def flatten(modality, structure, expand_composites=False):
  """Flattens a nested structure.

  - For Modality.CORE: refer to
    [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest)
    for the definition of a structure.

    If the structure is an atom, then returns a single-item list: [structure].

    This is the inverse of the `nest.pack_sequence_as` method that takes in a
    flattened list and re-packs it into the nested structure.

    In the case of dict instances, the sequence consists of the values, sorted
    by key to ensure deterministic behavior. This is true also for OrderedDict
    instances: their sequence order is ignored, the sorting order of keys is
    used instead. The same convention is followed in `nest.pack_sequence_as`.
    This correctly repacks dicts and OrderedDicts after they have been
    flattened, and also allows flattening an OrderedDict and then repacking it
    back using a corresponding plain dict, or vice-versa. Dictionaries with
    non-sortable keys cannot be flattened.

    Users must not modify any collections used in nest while this function is
    running.

    Examples:

    1. Python dict (ordered by key):

      >>> nested_dict = { "key3": "value3", "key1": "value1", "key2": "value2" }
      >>> tf.nest.flatten(nested_dict)
      ['value1', 'value2', 'value3']

    2. For a nested python tuple:

      >>> nested_tuple = ((1.0, 2.0), (3.0, 4.0, 5.0), 6.0)
      >>> tf.nest.flatten(nested_tuple)
      [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

    3. For a nested dictionary of dictionaries:

      >>> nested_dict = { "key3": {"c": (1.0, 2.0), "a": (3.0)},
      ... "key1": {"m": "val1", "g": "val2"} }
      >>> tf.nest.flatten(nested_dict)
      ['val2', 'val1', 3.0, 1.0, 2.0]

    4. Numpy array (will not flatten):

      >>> array = np.array([[1, 2], [3, 4]])
      >>> tf.nest.flatten(array)
          [array([[1, 2],
                  [3, 4]])]

    5. `tf.Tensor` (will not flatten):

      >>> tensor = tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
      >>> tf.nest.flatten(tensor)
          [<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
            array([[1., 2., 3.],
                   [4., 5., 6.],
                   [7., 8., 9.]], dtype=float32)>]

    6. `tf.RaggedTensor`: This is a composite tensor whose representation
    consists of a flattened list of 'values' and a list of 'row_splits' which
    indicate how to chop up the flattened list into different rows. For more
    details on `tf.RaggedTensor`, please visit
    https://www.tensorflow.org/api_docs/python/tf/RaggedTensor.

    With `expand_composites=False`, we just return the RaggedTensor as is.

      >>> tensor = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2]])
      >>> tf.nest.flatten(tensor, expand_composites=False)
      [<tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2]]>]

    With `expand_composites=True`, we return the component Tensors that make up
    the RaggedTensor representation (the values and row_splits tensors).

      >>> tensor = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2]])
      >>> tf.nest.flatten(tensor, expand_composites=True)
      [<tf.Tensor: shape=(7,), dtype=int32, numpy=array([3, 1, 4, 1, 5, 9, 2],
                                                        dtype=int32)>,
       <tf.Tensor: shape=(4,), dtype=int64, numpy=array([0, 4, 4, 7])>]

  - For Modality.DATA: nesting follows the tf.data rules described on the
    `Modality` class (e.g. lists are not a level of nesting, and
    `SparseTensorValue` is an atom). `expand_composites` is ignored.

  Args:
    modality: enum value of supported modality [Modality.CORE or Modality.DATA]
    structure: an atom or a nested structure. Note, numpy arrays are considered
      atoms and are not flattened.
    expand_composites: Arg valid for Modality.CORE only. If true, then composite
      tensors such as `tf.sparse.SparseTensor` and `tf.RaggedTensor` are
      expanded into their component tensors.

  Returns:
    A Python list, the flattened version of the input.

  Raises:
    TypeError: The nest is or contains a dict with non-sortable keys.
    ValueError: `modality` is neither Modality.CORE nor Modality.DATA.
  """
  if modality == Modality.CORE:
    return _tf_core_flatten(structure, expand_composites)
  elif modality == Modality.DATA:
    # `expand_composites` is documented as CORE-only and is not forwarded.
    return _tf_data_flatten(structure)
  else:
    raise ValueError(
        "Unknown modality used {} for nested structure".format(modality)
    )

708 

709 

def _tf_core_flatten(structure, expand_composites=False):
  """Modality.CORE implementation of `flatten`.

  See comments for flatten() in tensorflow/python/util/nest.py.
  """
  if structure is not None:
    # Hand the native helper an actual bool rather than an arbitrary truthy
    # value.
    return _pywrap_utils.Flatten(structure, bool(expand_composites))
  # `None` is treated as a single atom.
  return [None]

716 

717 

718def pack_sequence_as( 

719 modality, structure, flat_sequence, expand_composites, sequence_fn=None 

720): 

721 """Returns a given flattened sequence packed into a given structure. 

722 

723 - For Modality.CORE: Refer to 

724 [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest) 

725 for the definition of a structure. 

726 

727 If `structure` is an atom, `flat_sequence` must be a single-item list; 

728 in this case the return value is `flat_sequence[0]`. 

729 

730 If `structure` is or contains a dict instance, the keys will be sorted to 

731 pack the flat sequence in deterministic order. This is true also for 

732 `OrderedDict` instances: their sequence order is ignored, the sorting order of 

733 keys is used instead. The same convention is followed in `flatten`. 

734 This correctly repacks dicts and `OrderedDict`s after they have been 

735 flattened, and also allows flattening an `OrderedDict` and then repacking it 

736 back using a corresponding plain dict, or vice-versa. 

737 Dictionaries with non-sortable keys cannot be flattened. 

738 

739 Examples: 

740 

741 1. Python dict: 

742 

743 >>> structure = { "key3": "", "key1": "", "key2": "" } 

744 >>> flat_sequence = ["value1", "value2", "value3"] 

745 >>> tf.nest.pack_sequence_as(structure, flat_sequence) 

746 {'key3': 'value3', 'key1': 'value1', 'key2': 'value2'} 

747 

748 2. For a nested python tuple: 

749 

750 >>> structure = (('a','b'), ('c','d','e'), 'f') 

751 >>> flat_sequence = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] 

752 >>> tf.nest.pack_sequence_as(structure, flat_sequence) 

753 ((1.0, 2.0), (3.0, 4.0, 5.0), 6.0) 

754 

755 3. For a nested dictionary of dictionaries: 

756 

757 >>> structure = { "key3": {"c": ('alpha', 'beta'), "a": ('gamma')}, 

758 ... "key1": {"e": "val1", "d": "val2"} } 

759 >>> flat_sequence = ['val2', 'val1', 3.0, 1.0, 2.0] 

760 >>> tf.nest.pack_sequence_as(structure, flat_sequence) 

761 {'key3': {'c': (1.0, 2.0), 'a': 3.0}, 'key1': {'e': 'val1', 'd': 'val2'}} 

762 

763 4. Numpy array (considered a scalar): 

764 

765 >>> structure = ['a'] 

766 >>> flat_sequence = [np.array([[1, 2], [3, 4]])] 

767 >>> tf.nest.pack_sequence_as(structure, flat_sequence) 

768 [array([[1, 2], 

769 [3, 4]])] 

770 

771 5. tf.Tensor (considered a scalar): 

772 

773 >>> structure = ['a'] 

774 >>> flat_sequence = [tf.constant([[1., 2., 3.], [4., 5., 6.]])] 

775 >>> tf.nest.pack_sequence_as(structure, flat_sequence) 

776 [<tf.Tensor: shape=(2, 3), dtype=float32, 

777 numpy= array([[1., 2., 3.], [4., 5., 6.]], dtype=float32)>] 

778 

779 6. `tf.RaggedTensor`: This is a composite tensor thats representation consists 

780 of a flattened list of 'values' and a list of 'row_splits' which indicate how 

781 to chop up the flattened list into different rows. For more details on 

782 `tf.RaggedTensor`, please visit 

783 https://www.tensorflow.org/api_docs/python/tf/RaggedTensor. 

784 

785 With `expand_composites=False`, we treat RaggedTensor as a scalar. 

786 

787 >>> structure = { "foo": tf.ragged.constant([[1, 2], [3]]), 

788 ... "bar": tf.constant([[5]]) } 

789 >>> flat_sequence = [ "one", "two" ] 

790 >>> tf.nest.pack_sequence_as(structure, flat_sequence, 

791 ... expand_composites=False) 

792 {'foo': 'two', 'bar': 'one'} 

793 

794 With `expand_composites=True`, we expect that the flattened input contains 

795 the tensors making up the ragged tensor i.e. the values and row_splits 

796 tensors. 

797 

798 >>> structure = { "foo": tf.ragged.constant([[1., 2.], [3.]]), 

799 ... "bar": tf.constant([[5.]]) } 

800 >>> tensors = tf.nest.flatten(structure, expand_composites=True) 

801 >>> print(tensors) 

802 [<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[5.]], 

803 dtype=float32)>, 

804 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([1., 2., 3.], 

805 dtype=float32)>, 

806 <tf.Tensor: shape=(3,), dtype=int64, numpy=array([0, 2, 3])>] 

807 >>> verified_tensors = [tf.debugging.check_numerics(t, 'invalid tensor: ') 

808 ... if t.dtype==tf.float32 else t 

809 ... for t in tensors] 

810 >>> tf.nest.pack_sequence_as(structure, verified_tensors, 

811 ... expand_composites=True) 

812 {'foo': <tf.RaggedTensor [[1.0, 2.0], [3.0]]>, 

813 'bar': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[5.]], 

814 dtype=float32)>} 

815 

816 - For Modality.DATA: If `structure` is a scalar, `flat_sequence` must be a 

817 single-element list; 

818 in this case the return value is `flat_sequence[0]`. 

819 

820 Args: 

821 modality: enum value of supported modality [Modality.CORE or Modality.DATA] 

822 structure: - For Modality.CORE: Nested structure, whose structure is given 

823 by nested lists, tuples, and dicts. Note: numpy arrays and strings are 

824 considered scalars. - For Modality.DATA: tuple or list constructed of 

825 scalars and/or other tuples/lists, or a scalar. Note: numpy arrays are 

826 considered scalars. 

827 flat_sequence: flat sequence to pack. 

828 expand_composites: Arg valid for Modality.CORE only. If true, then composite 

829 tensors such as `tf.sparse.SparseTensor` and `tf.RaggedTensor` are 

830 expanded into their component tensors. 

831 sequence_fn: Arg valid for Modality.CORE only. 

832 

833 Returns: 

834 packed: `flat_sequence` converted to have the same recursive structure as 

835 `structure`. 

836 

837 Raises: 

838 ValueError: If `flat_sequence` and `structure` have different 

839 atom counts. 

840 TypeError: For Modality.CORE only. `structure` is or contains a dict with 

841 non-sortable keys. 

842 """ 

843 if modality == Modality.CORE: 

844 return _tf_core_pack_sequence_as( 

845 structure, flat_sequence, expand_composites, sequence_fn 

846 ) 

847 elif modality == Modality.DATA: 

848 return _tf_data_pack_sequence_as(structure, flat_sequence) 

849 else: 

850 raise ValueError( 

851 "Unknown modality used {} for nested structure".format(modality) 

852 ) 

853 

854 

def _tf_core_pack_sequence_as(
    structure, flat_sequence, expand_composites, sequence_fn=None
):
  """Implements sequence packing, with the option to alter the structure.

  Args:
    structure: Nested structure whose layout `flat_sequence` is packed into.
    flat_sequence: Flat sequence of atoms to pack.
    expand_composites: If true, nesting is tested with
      `_is_nested_or_composite` (composite tensors count as nested);
      otherwise with `_tf_core_is_nested`.
    sequence_fn: Optional callable used to rebuild each nested level from its
      packed children. Defaults to `sequence_like`.

  Returns:
    `flat_sequence` repacked with the layout of `structure`. If `structure`
    is an atom, `flat_sequence[0]` is returned.

  Raises:
    TypeError: If `flat_sequence` is not itself a nested structure.
    ValueError: If `structure` is an atom but `flat_sequence` does not have
      exactly one item, or if the atom count of `structure` differs from
      `len(flat_sequence)`.
    IndexError: Re-raised unchanged when packing raised an `IndexError` even
      though the atom counts agree (so it did not come from a count mismatch).
  """
  is_nested_fn = (
      _is_nested_or_composite if expand_composites else _tf_core_is_nested
  )
  sequence_fn = sequence_fn or sequence_like

  def truncate(value, length):
    value_str = str(value)
    return value_str[:length] + (value_str[length:] and "...")

  def raise_if_atom_count_mismatch():
    # Only flatten on the error path: the common successful case never needs
    # the full flattened structure.
    flat_structure = _tf_core_flatten(
        structure, expand_composites=expand_composites
    )
    if len(flat_structure) != len(flat_sequence):
      # pylint: disable=raise-missing-from
      raise ValueError(
          "Could not pack sequence. Structure had %d atoms, but "
          "flat_sequence had %d items. Structure: %s, flat_sequence: %s."
          % (len(flat_structure), len(flat_sequence), structure, flat_sequence)
      )

  if not is_nested_fn(flat_sequence):
    raise TypeError(
        "Attempted to pack value:\n  {}\ninto a structure, but found "
        "incompatible type `{}` instead.".format(
            truncate(flat_sequence, 100), type(flat_sequence)
        )
    )

  if not is_nested_fn(structure):
    if len(flat_sequence) != 1:
      raise ValueError(
          "The target structure is of type `{}`\n  {}\nHowever the input "
          "is a sequence ({}) of length {}.\n  {}\nnest cannot "
          "guarantee that it is safe to map one to the other.".format(
              type(structure),
              truncate(structure, 100),
              type(flat_sequence),
              len(flat_sequence),
              truncate(flat_sequence, 100),
          )
      )
    return flat_sequence[0]

  try:
    final_index, packed = _tf_core_packed_nest_with_indices(
        structure, flat_sequence, 0, is_nested_fn, sequence_fn
    )
  except IndexError:
    # Running off the end of `flat_sequence` normally means it is too short.
    raise_if_atom_count_mismatch()
    # The counts agree, so this IndexError has another origin (e.g. a custom
    # `sequence_fn`). Surface it rather than failing later on an unbound
    # `packed`.
    raise
  if final_index < len(flat_sequence):
    # Not every item was consumed: `flat_sequence` is presumably too long.
    raise_if_atom_count_mismatch()
  return sequence_fn(structure, packed)

909 

910 

def _tf_data_pack_sequence_as(structure, flat_sequence):
  """Returns a given flattened sequence packed into a nest (tf.data semantics).

  If `structure` is a scalar, `flat_sequence` must be a single-element list;
  in this case the return value is `flat_sequence[0]`.

  Args:
    structure: tuple or list constructed of scalars and/or other tuples/lists,
      or a scalar. Note: numpy arrays are considered scalars.
    flat_sequence: flat sequence to pack. Must either be nested according to
      `_tf_data_is_nested` or be a Python `list`.

  Returns:
    packed: `flat_sequence` converted to have the same recursive structure as
    `structure`.

  Raises:
    TypeError: If `flat_sequence` is neither nested nor a `list`.
    ValueError: If `structure` is a scalar and `flat_sequence` does not have
      exactly one element, or if `structure` and `flat_sequence` have
      different element counts.
  """
  # `list` is accepted explicitly in addition to `_tf_data_is_nested`;
  # presumably tf.data nesting does not always treat lists as nested.
  if not (_tf_data_is_nested(flat_sequence) or isinstance(flat_sequence, list)):
    raise TypeError(
        "Argument `flat_sequence` must be a sequence. Got "
        f"'{type(flat_sequence).__name__}'."
    )

  if not _tf_data_is_nested(structure):
    if len(flat_sequence) != 1:
      # The length may also be 0 here, so report the actual requirement.
      raise ValueError(
          "Argument `structure` is a scalar but "
          f"`len(flat_sequence)`={len(flat_sequence)} != 1"
      )
    return flat_sequence[0]

  flat_structure = _tf_data_flatten(structure)
  if len(flat_structure) != len(flat_sequence):
    raise ValueError(
        "Could not pack sequence. Argument `structure` had "
        f"{len(flat_structure)} elements, but argument `flat_sequence` had "
        f"{len(flat_sequence)} elements. Received structure: "
        f"{structure}, flat_sequence: {flat_sequence}."
    )

  _, packed = _tf_data_packed_nest_with_indices(structure, flat_sequence, 0)
  return sequence_like(structure, packed)

954 

955 

def map_structure(modality, func, *structure, **kwargs):
  """Creates a new structure by applying `func` to each atom in `structure`.

  - For Modality.CORE: refer to
  [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest)
  for the definition of a structure.

  Applies `func(x[0], x[1], ...)` where x[i] enumerates all atoms in
  `structure[i]`.  All items in `structure` must have the same arity,
  and the return value will contain results with the same structure layout.

  Examples:

  * A single Python dict:

  >>> a = {"hello": 24, "world": 76}
  >>> tf.nest.map_structure(lambda p: p * 2, a)
  {'hello': 48, 'world': 152}

  * Multiple Python dictionaries:

  >>> d1 = {"hello": 24, "world": 76}
  >>> d2 = {"hello": 36, "world": 14}
  >>> tf.nest.map_structure(lambda p1, p2: p1 + p2, d1, d2)
  {'hello': 60, 'world': 90}

  * A single Python list:

  >>> a = [24, 76, "ab"]
  >>> tf.nest.map_structure(lambda p: p * 2, a)
  [48, 152, 'abab']

  * Scalars:

  >>> tf.nest.map_structure(lambda x, y: x + y, 3, 4)
  7

  * Empty structures:

  >>> tf.nest.map_structure(lambda x: x + 1, ())
  ()

  * Check the types of iterables:

  >>> s1 = (((1, 2), 3), 4, (5, 6))
  >>> s1_list = [[[1, 2], 3], 4, [5, 6]]
  >>> tf.nest.map_structure(lambda x, y: None, s1, s1_list)
  Traceback (most recent call last):
  ...
  TypeError: The two structures don't have the same nested structure

  * Type check is set to False:

  >>> s1 = (((1, 2), 3), 4, (5, 6))
  >>> s1_list = [[[1, 2], 3], 4, [5, 6]]
  >>> tf.nest.map_structure(lambda x, y: None, s1, s1_list, check_types=False)
  (((None, None), None), None, (None, None))

  - For Modality.DATA: Applies `func(x[0], x[1], ...)` where x[i] is an entry
  in `structure[i]`.  All structures in `structure` must have the same arity,
  and the return value will contain the results in the same structure.

  Args:
    modality: enum value of supported modality [Modality.CORE or
      Modality.DATA].
    func: A callable that accepts as many arguments as there are structures.
    *structure:
      - For Modality.CORE: atom or nested structure.
      - For Modality.DATA: scalar, or tuple or list constructed of scalars
        and/or other tuples/lists. Note: numpy arrays are considered scalars.
    **kwargs: Valid keyword args are:
      * `check_types`: If set to `True` (default) the types of iterables
        within the structures have to be the same (e.g.
        `map_structure(func, [1], (1,))` raises a `TypeError` exception). To
        allow this set this argument to `False`. For Modality.CORE, note that
        namedtuples with identical name and fields are always considered to
        have the same shallow structure. For Modality.DATA this is the only
        valid keyword argument.
      * `expand_composites`: Valid for Modality.CORE only. If set to `True`,
        then composite tensors such as `tf.sparse.SparseTensor` and
        `tf.RaggedTensor` are expanded into their component tensors. If
        `False` (the default), then composite tensors are not expanded.

  Returns:
    A new structure with the same arity as `structure[0]`, whose atoms
    correspond to `func(x[0], x[1], ...)` where `x[i]` is the atom in the
    corresponding location in `structure[i]`. If there are different structure
    types and `check_types` is `False` the structure types of the first
    structure will be used.

  Raises:
    TypeError: If `func` is not callable or if the structures do not match
      each other by depth tree.
    ValueError: If no structure is provided, if the structures do not match
      each other by type, if wrong keyword arguments are provided, or if
      `modality` is not a supported value.
  """
  # Guard-style dispatch on the requested nesting semantics.
  if modality == Modality.CORE:
    return _tf_core_map_structure(func, *structure, **kwargs)
  if modality == Modality.DATA:
    return _tf_data_map_structure(func, *structure, **kwargs)
  raise ValueError(f"Unknown modality used {modality} for nested structure")

1061 

1062 

1063# pylint: disable=missing-function-docstring 

1064def _tf_core_map_structure(func, *structure, **kwargs): 

1065 if not callable(func): 

1066 raise TypeError("func must be callable, got: %s" % func) 

1067 

1068 if not structure: 

1069 raise ValueError("Must provide at least one structure") 

1070 

1071 check_types = kwargs.pop("check_types", True) 

1072 expand_composites = kwargs.pop("expand_composites", False) 

1073 

1074 if kwargs: 

1075 raise ValueError( 

1076 "Only valid keyword arguments are `check_types` and " 

1077 "`expand_composites`, not: `%s`" 

1078 % "`, `".join(kwargs.keys()) 

1079 ) 

1080 

1081 for other in structure[1:]: 

1082 _tf_core_assert_same_structure( 

1083 structure[0], 

1084 other, 

1085 check_types=check_types, 

1086 expand_composites=expand_composites, 

1087 ) 

1088 

1089 flat_structure = (_tf_core_flatten(s, expand_composites) for s in structure) 

1090 entries = zip(*flat_structure) 

1091 

1092 return _tf_core_pack_sequence_as( 

1093 structure[0], 

1094 [func(*x) for x in entries], 

1095 expand_composites=expand_composites, 

1096 ) 

1097 

1098 

1099# pylint: disable=missing-function-docstring 

1100def _tf_data_map_structure(func, *structure, **check_types_dict): 

1101 if not callable(func): 

1102 raise TypeError(f"Argument `func` must be callable, got: {func}") 

1103 

1104 if not structure: 

1105 raise ValueError("Must provide at least one structure") 

1106 

1107 if check_types_dict: 

1108 if "check_types" not in check_types_dict or len(check_types_dict) > 1: 

1109 raise ValueError( 

1110 "Only valid keyword argument for `check_types_dict` is " 

1111 f"'check_types'. Got {check_types_dict}." 

1112 ) 

1113 check_types = check_types_dict["check_types"] 

1114 else: 

1115 check_types = True 

1116 

1117 for other in structure[1:]: 

1118 _tf_data_assert_same_structure(structure[0], other, check_types=check_types) 

1119 

1120 flat_structure = (_tf_data_flatten(s) for s in structure) 

1121 entries = zip(*flat_structure) 

1122 

1123 return _tf_data_pack_sequence_as(structure[0], [func(*x) for x in entries]) 

1124 

1125 

def yield_flat_up_to(modality, shallow_tree, input_tree, is_nested_fn, path=()):
  """Yields `input_tree` flattened up to the leaves of `shallow_tree`.

  - For Modality.CORE: delegates to `_tf_core_yield_flat_up_to` and yields
    `(path, value)` pairs.
  - For Modality.DATA: delegates to `_tf_data_yield_flat_up_to` and yields the
    values only; `is_nested_fn` and `path` are not used.

  Because this is a generator, an unsupported `modality` is only reported
  once iteration starts.

  Args:
    modality: enum value of supported modality [Modality.CORE or
      Modality.DATA].
    shallow_tree: Nested structure. Traverse no further than its leaf nodes.
    input_tree: Nested structure. Return the paths and values from this tree.
      Must have the same upper structure as shallow_tree.
    is_nested_fn: Arg valid for Modality.CORE only. Function used to test if a
      value should be treated as a nested structure.
    path: Arg valid for Modality.CORE only. Tuple. Optional argument, only used
      when recursing. The path from the root of the original shallow_tree,
      down to the root of the shallow_tree arg of this recursive call.

  Yields:
    For Modality.CORE, pairs of (path, value), where path is the tuple path of
    a leaf node in shallow_tree, and value is the value of the corresponding
    node in input_tree. For Modality.DATA, just the values.

  Raises:
    ValueError: If `modality` is not a supported value.
  """
  if modality == Modality.CORE:
    leaves = _tf_core_yield_flat_up_to(
        shallow_tree, input_tree, is_nested_fn, path
    )
  elif modality == Modality.DATA:
    leaves = _tf_data_yield_flat_up_to(shallow_tree, input_tree)
  else:
    raise ValueError(f"Unknown modality used {modality} for nested structure")
  yield from leaves

1158 

1159 

1160def _tf_core_yield_flat_up_to(shallow_tree, input_tree, is_nested_fn, path=()): 

1161 """Yields (path, value) pairs of input_tree flattened up to shallow_tree. 

1162 

1163 Args: 

1164 shallow_tree: Nested structure. Traverse no further than its leaf nodes. 

1165 input_tree: Nested structure. Return the paths and values from this tree. 

1166 Must have the same upper structure as shallow_tree. 

1167 is_nested_fn: Function used to test if a value should be treated as a nested 

1168 structure. 

1169 path: Tuple. Optional argument, only used when recursing. The path from the 

1170 root of the original shallow_tree, down to the root of the shallow_tree 

1171 arg of this recursive call. 

1172 

1173 Yields: 

1174 Pairs of (path, value), where path the tuple path of a leaf node in 

1175 shallow_tree, and value is the value of the corresponding node in 

1176 input_tree. 

1177 """ 

1178 if not is_nested_fn(shallow_tree): 

1179 yield (path, input_tree) 

1180 else: 

1181 input_tree = dict(_tf_core_yield_sorted_items(input_tree)) 

1182 for ( 

1183 shallow_key, 

1184 shallow_subtree, 

1185 ) in _tf_core_yield_sorted_items(shallow_tree): 

1186 subpath = path + (shallow_key,) 

1187 input_subtree = input_tree[shallow_key] 

1188 for leaf_path, leaf_value in _tf_core_yield_flat_up_to( 

1189 shallow_subtree, input_subtree, is_nested_fn, path=subpath 

1190 ): 

1191 yield (leaf_path, leaf_value) 

1192 

1193 

def _tf_data_yield_flat_up_to(shallow_tree, input_tree):
  """Yields elements `input_tree` partially flattened up to `shallow_tree`.

  Args:
    shallow_tree: Structure whose leaves mark where flattening stops.
    input_tree: Structure to flatten; its entries are paired with those of
      `shallow_tree` via `_tf_data_yield_value`.

  Yields:
    The sub-trees of `input_tree` that sit at the leaf positions of
    `shallow_tree`.
  """
  if not _tf_data_is_nested(shallow_tree):
    yield input_tree
    return

  paired_branches = zip(
      _tf_data_yield_value(shallow_tree), _tf_data_yield_value(input_tree)
  )
  for shallow_branch, input_branch in paired_branches:
    yield from _tf_data_yield_flat_up_to(shallow_branch, input_branch)

1204 

1205 

def assert_shallow_structure(
    modality,
    shallow_tree,
    input_tree,
    check_types=True,
    expand_composites=False,
):
  """Asserts that `shallow_tree` is a shallow structure of `input_tree`.

  This function tests if the `input_tree` structure can be created from
  the `shallow_tree` structure by replacing its leaf nodes with deeper
  tree structures.

  Examples:

  The following code will raise an exception:
  ```python
    shallow_tree = {"a": "A", "b": "B"}
    input_tree = {"a": 1, "c": 2}
    assert_shallow_structure(Modality.CORE, shallow_tree, input_tree)
  ```

  The following code will raise an exception:
  ```python
    shallow_tree = ["a", "b"]
    input_tree = ["c", ["d", "e"], "f"]
    assert_shallow_structure(Modality.CORE, shallow_tree, input_tree)
  ```

  Args:
    modality: enum value of supported modality [Modality.CORE or
      Modality.DATA].
    shallow_tree: an arbitrarily nested structure.
    input_tree: an arbitrarily nested structure.
    check_types: if `True` (default) the sequence types of `shallow_tree` and
      `input_tree` have to be the same. For Modality.CORE, note that even with
      check_types==True, this function will consider two different namedtuple
      classes with the same name and _fields attribute to be the same class.
    expand_composites: Valid for Modality.CORE only. If true, then composite
      tensors such as `tf.sparse.SparseTensor` and `tf.RaggedTensor` are
      expanded into their component tensors.

  Raises:
    TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      `input_tree`. Only raised if `check_types` is `True`.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      `input_tree`, or if `modality` is not a supported value.
  """
  if modality == Modality.CORE:
    _tf_core_assert_shallow_structure(
        shallow_tree, input_tree, check_types, expand_composites
    )
    return
  if modality == Modality.DATA:
    # `expand_composites` is not forwarded: it only applies to Modality.CORE.
    _tf_data_assert_shallow_structure(shallow_tree, input_tree, check_types)
    return
  raise ValueError(f"Unknown modality used {modality} for nested structure")

1264 

1265 

# pylint: disable=missing-function-docstring
def _tf_core_assert_shallow_structure(
    shallow_tree, input_tree, check_types=True, expand_composites=False
):
  """Implements `assert_shallow_structure` for Modality.CORE.

  Recursively checks that `input_tree` can be obtained from `shallow_tree` by
  replacing its leaves with (possibly deeper) structures.

  Args:
    shallow_tree: an arbitrarily nested structure.
    input_tree: an arbitrarily nested structure.
    check_types: if `True`, the container types must match. Namedtuples with
      the same name and fields, list subclasses, composite tensors/TypeSpecs,
      and pairs of Mappings are accepted as matching.
    expand_composites: if `True`, nesting is tested with
      `_is_nested_or_composite`, otherwise with `_tf_core_is_nested`.

  Raises:
    TypeError: If `shallow_tree` is nested but `input_tree` is not, if the
      container types mismatch (with `check_types`), or if a TypeSpec/
      composite tensor is matched against something else.
    ValueError: If lengths differ, if `shallow_tree` has mapping keys that
      are absent from `input_tree`, or if composite tensor TypeSpecs are
      incompatible.
  """
  is_nested_fn = (
      _is_nested_or_composite if expand_composites else _tf_core_is_nested
  )
  if is_nested_fn(shallow_tree):
    if not is_nested_fn(input_tree):
      raise TypeError(
          "If shallow structure is a sequence, input must also be a sequence. "
          "Input has type: %s."
          % type(input_tree)
      )

    # Use the wrapped object's type so that a wrapt proxy around the shallow
    # tree is compared by what it wraps.
    if isinstance(shallow_tree, _wrapt.ObjectProxy):
      shallow_type = type(shallow_tree.__wrapped__)
    else:
      shallow_type = type(shallow_tree)

    if check_types and not isinstance(input_tree, shallow_type):
      # Duck-typing means that nest should be fine with two different
      # namedtuples with identical name and fields.
      shallow_is_namedtuple = is_namedtuple(shallow_tree, False)
      input_is_namedtuple = is_namedtuple(input_tree, False)
      if shallow_is_namedtuple and input_is_namedtuple:
        if not same_namedtuples(shallow_tree, input_tree):
          raise TypeError(
              STRUCTURES_HAVE_MISMATCHING_TYPES.format(
                  input_type=type(input_tree), shallow_type=type(shallow_tree)
              )
          )

      elif isinstance(shallow_tree, list) and isinstance(input_tree, list):
        # List subclasses are considered the same,
        # e.g. python list vs. _ListWrapper.
        pass

      elif (
          _is_composite_tensor(shallow_tree) or _is_type_spec(shallow_tree)
      ) and (_is_composite_tensor(input_tree) or _is_type_spec(input_tree)):
        pass  # Compatibility will be checked below.

      elif not (
          isinstance(shallow_tree, _collections_abc.Mapping)
          and isinstance(input_tree, _collections_abc.Mapping)
      ):
        raise TypeError(
            STRUCTURES_HAVE_MISMATCHING_TYPES.format(
                input_type=type(input_tree), shallow_type=type(shallow_tree)
            )
        )

    if _is_composite_tensor(shallow_tree) or _is_composite_tensor(input_tree):
      if not (
          (_is_composite_tensor(input_tree) or _is_type_spec(input_tree))
          and (
              _is_composite_tensor(shallow_tree) or _is_type_spec(shallow_tree)
          )
      ):
        raise TypeError(
            STRUCTURES_HAVE_MISMATCHING_TYPES.format(
                input_type=type(input_tree), shallow_type=type(shallow_tree)
            )
        )
      # pylint: disable=protected-access
      type_spec_1 = (
          shallow_tree
          if _is_type_spec(shallow_tree)
          else shallow_tree._type_spec
      )._without_tensor_names()
      type_spec_2 = (
          input_tree if _is_type_spec(input_tree) else input_tree._type_spec
      )._without_tensor_names()
      # TODO(b/246356867): Replace the most_specific_common_supertype below
      # with get_structure.
      if hasattr(type_spec_1, "_get_structure") and hasattr(
          type_spec_2, "_get_structure"
      ):
        # `or None` maps a failed equality check onto the same "incompatible"
        # sentinel that most_specific_common_supertype returns.
        result = (
            type_spec_1._get_structure() == type_spec_2._get_structure() or None
        )
      else:
        result = type_spec_1.most_specific_common_supertype([type_spec_2])
      if result is None:
        raise ValueError(
            "Incompatible CompositeTensor TypeSpecs: %s vs. %s"
            % (type_spec_1, type_spec_2)
        )
      # pylint: enable=protected-access

    elif _is_type_spec(shallow_tree):
      if not _is_type_spec(input_tree):
        raise TypeError(
            "If shallow structure is a TypeSpec, input must also "
            "be a TypeSpec.  Input has type: %s."
            % type(input_tree)
        )
    elif len(input_tree) != len(shallow_tree):
      # Any length difference (including a shorter input) is reported here;
      # a separate "input smaller than shallow" check would be unreachable.
      raise ValueError(
          STRUCTURES_HAVE_MISMATCHING_LENGTHS.format(
              input_length=len(input_tree), shallow_length=len(shallow_tree)
          )
      )

    if isinstance(shallow_tree, _collections_abc.Mapping):
      absent_keys = set(shallow_tree) - set(input_tree)
      if absent_keys:
        raise ValueError(
            SHALLOW_TREE_HAS_INVALID_KEYS.format(sorted(absent_keys))
        )

    for shallow_branch, input_branch in zip(
        _tf_core_yield_value(shallow_tree),
        _tf_core_yield_value(input_tree),
    ):
      _tf_core_assert_shallow_structure(
          shallow_branch,
          input_branch,
          check_types=check_types,
          expand_composites=expand_composites,
      )

1395 

1396 

# pylint: disable=missing-function-docstring
def _tf_data_assert_shallow_structure(
    shallow_tree, input_tree, check_types=True
):
  """Implements `assert_shallow_structure` for Modality.DATA.

  Args:
    shallow_tree: a structure; nothing is checked when it is not nested.
    input_tree: a structure that should extend `shallow_tree`.
    check_types: if `True`, `input_tree` must be an instance of
      `type(shallow_tree)`, and Mappings must have identical key sets (they
      are then compared value-by-value in sorted key order).

  Raises:
    TypeError: If `shallow_tree` is nested but `input_tree` is not, or the
      types differ while `check_types` is set.
    ValueError: If lengths differ, or (with `check_types`) Mapping keys
      differ.
  """
  if not _tf_data_is_nested(shallow_tree):
    # A scalar shallow tree accepts any input at this position.
    return

  if not _tf_data_is_nested(input_tree):
    raise TypeError(
        "If shallow structure is a sequence, input must also be a sequence. "
        f"Input has type: '{type(input_tree).__name__}'."
    )

  if check_types and not isinstance(input_tree, type(shallow_tree)):
    raise TypeError(
        "The two structures don't have the same sequence type. Input "
        f"structure has type '{type(input_tree).__name__}', while shallow "
        f"structure has type '{type(shallow_tree).__name__}'."
    )

  if len(input_tree) != len(shallow_tree):
    raise ValueError(
        "The two structures don't have the same sequence length. Input "
        f"structure has length {len(input_tree)}, while shallow structure "
        f"has length {len(shallow_tree)}."
    )

  shallow_children, input_children = shallow_tree, input_tree
  if check_types and isinstance(shallow_tree, _collections_abc.Mapping):
    if set(input_tree) != set(shallow_tree):
      raise ValueError(
          "The two structures don't have the same keys. Input "
          f"structure has keys {list(input_tree)}, while shallow structure "
          f"has keys {list(shallow_tree)}."
      )
    # Pair mapping entries by key via sorted (key, value) items.
    input_children = sorted(input_tree.items())
    shallow_children = sorted(shallow_tree.items())

  for shallow_child, input_child in zip(shallow_children, input_children):
    _tf_data_assert_shallow_structure(
        shallow_child, input_child, check_types=check_types
    )

1436 

1437 

def flatten_up_to(
    modality,
    shallow_tree,
    input_tree,
    check_types=True,
    expand_composites=False,
):
  # pylint: disable=g-doc-return-or-yield,g-doc-args
  """Flattens `input_tree` up to `shallow_tree`.

  - For Modality.CORE: refer to
  [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest)
  for the definition of a structure.

  Any further depth in structure in `input_tree` is retained as structures in
  the partially flattened output.

  If `shallow_tree` and `input_tree` are atoms, this returns a
  single-item list: `[input_tree]`.

  Use Case:

  Sometimes we may wish to partially flatten a structure, retaining some
  of the nested structure. We achieve this by specifying a shallow structure,
  `shallow_tree`, we wish to flatten up to.

  The input, `input_tree`, can be thought of as having the same structure layout
  as `shallow_tree`, but with leaf nodes that are themselves tree structures.

  Examples (Modality.CORE):

  ```python
  input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
  shallow_tree = [[True, True], [False, True]]

  flattened_input_tree = flatten_up_to(Modality.CORE, shallow_tree, input_tree)
  flattened_shallow_tree = flatten_up_to(
      Modality.CORE, shallow_tree, shallow_tree)

  # Output is:
  # [[2, 2], [3, 3], [4, 9], [5, 5]]
  # [True, True, False, True]
  ```

  ```python
  input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]]
  shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]]

  input_tree_flattened_as_shallow_tree = flatten_up_to(
      Modality.CORE, shallow_tree, input_tree)
  input_tree_flattened = tf.nest.flatten(input_tree)

  # Output is:
  # [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
  # ['a', 1, 'b', 2, 'c', 3, 'd', 4]
  ```

  Edge Cases:

  ```python
  flatten_up_to(Modality.CORE, 0, 0)  # Output: [0]
  flatten_up_to(Modality.CORE, 0, [0, 1, 2])  # Output: [[0, 1, 2]]
  flatten_up_to(Modality.CORE, [0, 1, 2], 0)  # Output: TypeError
  flatten_up_to(Modality.CORE, [0, 1, 2], [0, 1, 2])  # Output: [0, 1, 2]
  ```

  Args:
    modality: enum value of supported modality [Modality.CORE or
      Modality.DATA].
    shallow_tree: a possibly pruned structure of input_tree.
    input_tree: an atom or a nested structure. Note, numpy arrays are
      considered atoms.
    check_types: bool. If True, check that each node in shallow_tree has the
      same type as the corresponding node in input_tree. Honored for both
      modalities.
    expand_composites: Arg valid for Modality.CORE only. If true, then
      composite tensors such as `tf.sparse.SparseTensor` and `tf.RaggedTensor`
      are expanded into their component tensors.

  Returns:
    A Python list, the partially flattened version of `input_tree` according to
    the structure of `shallow_tree`.

  Raises:
    TypeError: If `shallow_tree` is a nested structure but `input_tree` is not.
    TypeError: If the structure types of `shallow_tree` are different from
      `input_tree`.
    ValueError: If the structure lengths of `shallow_tree` are different from
      `input_tree`, or if `modality` is not a supported value.
  """
  if modality == Modality.CORE:
    return _tf_core_flatten_up_to(
        shallow_tree, input_tree, check_types, expand_composites
    )
  elif modality == Modality.DATA:
    # `expand_composites` has no tf.data counterpart; `check_types` is
    # forwarded so that the documented argument takes effect here too.
    return _tf_data_flatten_up_to(shallow_tree, input_tree, check_types)
  else:
    raise ValueError(
        "Unknown modality used {} for nested structure".format(modality)
    )


def _tf_core_flatten_up_to(
    shallow_tree, input_tree, check_types=True, expand_composites=False
):
  """Implements `flatten_up_to` for Modality.CORE.

  Validates the structures with `_tf_core_assert_shallow_structure`, then
  returns the values from `_tf_core_yield_flat_up_to` (paths discarded).
  """
  is_nested_fn = (
      _is_nested_or_composite if expand_composites else _tf_core_is_nested
  )
  _tf_core_assert_shallow_structure(
      shallow_tree,
      input_tree,
      check_types=check_types,
      expand_composites=expand_composites,
  )
  # Discard the paths; only the values are part of the result.
  return [
      value
      for _, value in _tf_core_yield_flat_up_to(
          shallow_tree, input_tree, is_nested_fn
      )
  ]


def _tf_data_flatten_up_to(shallow_tree, input_tree, check_types=True):
  """Implements `flatten_up_to` for Modality.DATA.

  Args:
    shallow_tree: a possibly pruned structure of input_tree.
    input_tree: a structure to flatten up to `shallow_tree`.
    check_types: forwarded to `_tf_data_assert_shallow_structure`. Defaults to
      `True`, matching the previous fixed behavior.

  Returns:
    A list of the sub-trees of `input_tree` at the leaves of `shallow_tree`.
  """
  _tf_data_assert_shallow_structure(
      shallow_tree, input_tree, check_types=check_types
  )
  return list(_tf_data_yield_flat_up_to(shallow_tree, input_tree))

1561 

1562 

def map_structure_up_to(modality, shallow_tree, func, *inputs, **kwargs):
  """Applies a function or op to a number of partially flattened inputs.

  The `inputs` are flattened up to `shallow_tree` before being mapped.

  Use Case:

  Sometimes we wish to apply a function to a partially flattened
  structure (for example when the function itself takes structure inputs). We
  achieve this by specifying a shallow structure, `shallow_tree` we wish to
  flatten up to.

  The `inputs`, can be thought of as having the same structure layout as
  `shallow_tree`, but with leaf nodes that are themselves tree structures.

  This function therefore will return something with the same base structure as
  `shallow_tree`.

  NOTE(review): for Modality.CORE this delegates to
  `_tf_core_map_structure_with_tuple_paths_up_to`, which (judging by its name)
  presumably passes each leaf's tuple path to `func` as an extra leading
  argument. Confirm before relying on the path-free examples below, which
  also omit the leading `modality` argument.

  Examples:

  ```python
  shallow_tree = [None, None]
  inp_val = [1, 2, 3]
  out = map_structure_up_to(shallow_tree, lambda x: 2 * x, inp_val)

  # Output is: [2, 4]
  ```

  ```python
  ab_tuple = collections.namedtuple("ab_tuple", "a, b")
  op_tuple = collections.namedtuple("op_tuple", "add, mul")
  inp_val = ab_tuple(a=2, b=3)
  inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
  out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul,
                            inp_val, inp_ops)

  # Output is: ab_tuple(a=6, b=15)
  ```

  ```python
  data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
  name_list = ['evens', ['odds', 'primes']]
  out = map_structure_up_to(
      name_list,
      lambda name, sec: "first_{}_{}".format(len(sec), name),
      name_list, data_list)

  # Output is: ['first_4_evens', ['first_5_odds', 'first_3_primes']]
  ```

  Args:
    modality: enum value of supported modality [Modality.CORE or
      Modality.DATA].
    shallow_tree: a shallow structure, common to all the inputs.
    func: callable which will be applied to each input individually.
    *inputs: structures that are compatible with shallow_tree. The function
      `func` is applied to corresponding structures due to partial flattening
      of each input, so the function must support arity of `len(inputs)`.
    **kwargs: Arg valid for Modality.CORE only (they are not forwarded for
      Modality.DATA). kwargs to feed to func(). Special kwarg `check_types` is
      not passed to func, but instead determines whether the types of
      iterables within the structures have to be same (e.g.
      `map_structure(func, [1], (1,))` raises a `TypeError` exception). To
      allow this set this argument to `False`.

  Raises:
    TypeError: If `shallow_tree` is a nested structure but `input_tree` is not.
    TypeError: If the structure types of `shallow_tree` are different from
      `input_tree`.
    ValueError: If the structure lengths of `shallow_tree` are different from
      `input_tree`, or if `modality` is not a supported value.

  Returns:
    result of repeatedly applying `func`, with the same structure layout as
    `shallow_tree`.
  """
  if modality == Modality.CORE:
    return _tf_core_map_structure_with_tuple_paths_up_to(
        shallow_tree, func, *inputs, **kwargs
    )
  if modality == Modality.DATA:
    # Keyword arguments are intentionally not forwarded for tf.data.
    return _tf_data_map_structure_up_to(shallow_tree, func, *inputs)
  raise ValueError(f"Unknown modality used {modality} for nested structure")

1647 

1648 

1649def _tf_core_map_structure_with_tuple_paths_up_to( 

1650 shallow_tree, func, *inputs, **kwargs 

1651): 

1652 """See comments for map_structure_with_tuple_paths_up_to() in tensorflow/python/util/nest.py.""" 

1653 if not inputs: 

1654 raise ValueError("Cannot map over no sequences") 

1655 

1656 check_types = kwargs.pop("check_types", True) 

1657 expand_composites = kwargs.pop("expand_composites", False) 

1658 is_nested_fn = ( 

1659 _is_nested_or_composite if expand_composites else _tf_core_is_nested 

1660 ) 

1661 

1662 for input_tree in inputs: 

1663 _tf_core_assert_shallow_structure( 

1664 shallow_tree, 

1665 input_tree, 

1666 check_types=check_types, 

1667 expand_composites=expand_composites, 

1668 ) 

1669 

1670 # Flatten each input separately, apply the function to corresponding items, 

1671 # then repack based on the structure of the first input. 

1672 flat_value_gen = ( 

1673 _tf_core_flatten_up_to( # pylint: disable=g-complex-comprehension 

1674 shallow_tree, 

1675 input_tree, 

1676 check_types, 

1677 expand_composites=expand_composites, 

1678 ) 

1679 for input_tree in inputs 

1680 ) 

1681 flat_path_gen = ( 

1682 path 

1683 for path, _ in _tf_core_yield_flat_up_to( 

1684 shallow_tree, inputs[0], is_nested_fn 

1685 ) 

1686 ) 

1687 results = [ 

1688 func(*args, **kwargs) for args in zip(flat_path_gen, *flat_value_gen) 

1689 ] 

1690 return _tf_core_pack_sequence_as( 

1691 structure=shallow_tree, 

1692 flat_sequence=results, 

1693 expand_composites=expand_composites, 

1694 ) 

1695 

1696 

1697# pylint: disable=missing-function-docstring 

1698def _tf_data_map_structure_up_to(shallow_tree, func, *inputs): 

1699 if not inputs: 

1700 raise ValueError( 

1701 "Argument `inputs` is empty. Cannot map over no sequences." 

1702 ) 

1703 for input_tree in inputs: 

1704 _tf_data_assert_shallow_structure(shallow_tree, input_tree) 

1705 

1706 # Flatten each input separately, apply the function to corresponding elements, 

1707 # then repack based on the structure of the first input. 

1708 all_flattened_up_to = ( 

1709 _tf_data_flatten_up_to(shallow_tree, input_tree) for input_tree in inputs 

1710 ) 

1711 

1712 results = [func(*tensors) for tensors in zip(*all_flattened_up_to)] 

1713 return _tf_data_pack_sequence_as( 

1714 structure=shallow_tree, flat_sequence=results 

1715 )