Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/util/nest.py: 56%
18 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
16"""## Functions for working with arbitrarily nested sequences of elements.
18NOTE(mrry): This fork of the `tensorflow.python.util.nest` module
19makes two changes:
211. It removes support for lists as a level of nesting in nested structures.
222. It adds support for `SparseTensorValue` as an atomic element.
24The motivation for this change is twofold:
261. It seems more natural for lists to be treated (e.g. in Dataset constructors)
27 as tensors, rather than lists of (lists of...) tensors.
282. This is needed because `SparseTensorValue` is implemented as a `namedtuple`
29 that would normally be flattened and we want to be able to create sparse
30 tensors from `SparseTensorValue`s, similarly to creating tensors from numpy
31 arrays.
32"""
34from tensorflow.python.util import nest_util
def is_nested(structure):
  """Checks whether `structure` counts as nested under tf.data nesting rules.

  This is a thin wrapper around `nest_util.is_nested` with the modality fixed
  to `nest_util.Modality.DATA`. Per this module's docstring, lists are not a
  level of nesting and `SparseTensorValue` is treated as an atomic element.

  Args:
    structure: the object to inspect.

  Returns:
    Whatever `nest_util.is_nested` reports for `structure` under the `DATA`
    modality.
  """
  data_modality = nest_util.Modality.DATA
  return nest_util.is_nested(data_modality, structure)
def flatten(structure):
  """Flattens `structure` according to tf.data nesting rules.

  This is a thin wrapper around `nest_util.flatten` with the modality fixed
  to `nest_util.Modality.DATA`. Per this module's docstring, lists (and
  `SparseTensorValue`s) are treated as atomic elements rather than as a level
  of nesting, so they are not split into their items.

  Args:
    structure: an arbitrarily nested structure, or a single atomic element.

  Returns:
    The flattened sequence produced by `nest_util.flatten`.
  """
  data_modality = nest_util.Modality.DATA
  return nest_util.flatten(data_modality, structure)
def assert_same_structure(nest1, nest2, check_types=True):
  """Checks that `nest1` and `nest2` share the same nested layout.

  Nesting is determined by tf.data rules (`nest_util.Modality.DATA`): lists
  are atomic elements here, not a level of nesting.

  Args:
    nest1: an arbitrarily nested structure.
    nest2: an arbitrarily nested structure.
    check_types: when `True` (the default), the sequence types at every level
      must match as well. For dictionaries, the "type" includes the keys, so
      two dictionaries with different keys count as different types. When
      `False`, two iterables are treated as equivalent as long as the elements
      they yield have matching structures.

  Raises:
    ValueError: If the two structures hold a different number of elements or
      are not nested in the same way.
    TypeError: If the two structures differ in the type of sequence in any of
      their substructures. Only possible if `check_types` is `True`.
  """
  data_modality = nest_util.Modality.DATA
  nest_util.assert_same_structure(data_modality, nest1, nest2, check_types)
def pack_sequence_as(structure, flat_sequence):
  """Returns a given flattened sequence packed into a nest.

  If `structure` is an atomic element (i.e. not nested), `flat_sequence` must
  be a single-element list; in this case the return value is
  `flat_sequence[0]`.

  This module uses tf.data nesting rules: lists (and `SparseTensorValue`s)
  appearing in `structure` are atomic elements, not a level of nesting.

  Args:
    structure: a nested structure (e.g. tuples, dicts) whose leaves are atomic
      elements, or a single atomic element. Note: lists and numpy arrays are
      considered atomic elements.
    flat_sequence: flat sequence to pack.

  Returns:
    packed: `flat_sequence` converted to have the same recursive structure as
      `structure`.

  Raises:
    ValueError: If `flat_sequence` and `structure` have different element
      counts.
  """
  # NOTE(review): `expand_composites=False` is passed explicitly; presumably
  # this keeps composite values as single leaves -- confirm against nest_util.
  return nest_util.pack_sequence_as(
      nest_util.Modality.DATA, structure, flat_sequence, expand_composites=False
  )
def map_structure(func, *structure, **check_types_dict):
  """Applies `func` to each entry in `structure` and returns a new structure.

  Applies `func(x[0], x[1], ...)` where x[i] is an entry in
  `structure[i]`. All structures in `structure` must have the same arity,
  and the return value will contain the results in the same structure.

  Structures are traversed with tf.data nesting rules: lists are atomic
  elements, so a list is passed to `func` as a whole rather than mapped over.

  Args:
    func: A callable that accepts as many arguments as there are structures.
    *structure: atomic element, or nested structure (e.g. tuples, dicts) whose
      leaves are atomic elements. Note: lists and numpy arrays are considered
      atomic elements.
    **check_types_dict: only valid keyword argument is `check_types`. If set to
      `True` (default) the types of sequences within the structures have to be
      the same. To allow structures whose sequence types differ, set this
      argument to `False`.

  Returns:
    A new structure with the same arity as `structure`, whose values correspond
    to `func(x[0], x[1], ...)` where `x[i]` is a value in the corresponding
    location in `structure[i]`. If there are different sequence types and
    `check_types` is `False` the sequence types of the first structure will be
    used.

  Raises:
    TypeError: If `func` is not callable or if the structures do not match
      each other by depth tree.
    ValueError: If no structure is provided or if the structures do not match
      each other by type.
    ValueError: If wrong keyword arguments are provided.
  """
  return nest_util.map_structure(
      nest_util.Modality.DATA, func, *structure, **check_types_dict
  )
def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
  """Asserts that `shallow_tree` is a shallow structure of `input_tree`.

  That is, this function tests if the `input_tree` structure can be created from
  the `shallow_tree` structure by replacing its leaf nodes with deeper
  tree structures.

  Note that this module uses tf.data nesting rules: lists are atomic elements
  (leaves), not a level of nesting, which is why the examples below use
  tuples.

  Examples:

  The following code will raise an exception:
  ```python
  shallow_tree = ("a", "b")
  input_tree = ("c", ("d", "e"), "f")
  assert_shallow_structure(shallow_tree, input_tree)
  ```

  The following code will not raise an exception:
  ```python
  shallow_tree = ("a", "b")
  input_tree = ("c", ("d", "e"))
  assert_shallow_structure(shallow_tree, input_tree)
  ```

  Args:
    shallow_tree: an arbitrarily nested structure.
    input_tree: an arbitrarily nested structure.
    check_types: if `True` (default) the sequence types of `shallow_tree` and
      `input_tree` have to be the same.

  Raises:
    TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      `input_tree`. Only raised if `check_types` is `True`.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      `input_tree`.
  """
  nest_util.assert_shallow_structure(
      nest_util.Modality.DATA, shallow_tree, input_tree, check_types
  )
def flatten_up_to(shallow_tree, input_tree):
  """Flattens `input_tree` up to `shallow_tree`.

  Any further depth in structure in `input_tree` is retained as elements in the
  partially flattened output.

  If `shallow_tree` and `input_tree` are not sequences, this returns a
  single-element list: `[input_tree]`.

  Note that this module uses tf.data nesting rules: lists are atomic elements,
  not a level of nesting, so the examples below nest with tuples.

  Use Case:

  Sometimes we may wish to partially flatten a nested sequence, retaining some
  of the nested structure. We achieve this by specifying a shallow structure,
  `shallow_tree`, we wish to flatten up to.

  The input, `input_tree`, can be thought of as having the same structure as
  `shallow_tree`, but with leaf nodes that are themselves tree structures.

  Examples:

  ```python
  input_tree = (((2, 2), (3, 3)), ((4, 9), (5, 5)))
  shallow_tree = ((True, True), (False, True))

  flattened_input_tree = flatten_up_to(shallow_tree, input_tree)
  flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree)

  # Output is:
  # [(2, 2), (3, 3), (4, 9), (5, 5)]
  # [True, True, False, True]
  ```

  ```python
  input_tree = ((('a', 1), (('b', 2), (('c', 3), (('d', 4),)))),)
  shallow_tree = (('level_1', ('level_2', ('level_3', ('level_4',)))),)

  input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree)
  input_tree_flattened = flatten(input_tree)

  # Output is:
  # [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
  # ['a', 1, 'b', 2, 'c', 3, 'd', 4]
  ```

  Non-Sequence Edge Cases:

  ```python
  flatten_up_to(0, 0)  # Output: [0]
  flatten_up_to(0, (0, 1, 2))  # Output: [(0, 1, 2)]
  flatten_up_to((0, 1, 2), 0)  # Output: TypeError
  flatten_up_to((0, 1, 2), (0, 1, 2))  # Output: [0, 1, 2]
  ```

  Args:
    shallow_tree: a possibly pruned structure of input_tree.
    input_tree: an arbitrarily nested structure or a scalar object.
      Note, lists and numpy arrays are considered scalars.

  Returns:
    A Python list, the partially flattened version of `input_tree` according to
    the structure of `shallow_tree`.

  Raises:
    TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      `input_tree`.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      `input_tree`.
  """
  return nest_util.flatten_up_to(
      nest_util.Modality.DATA, shallow_tree, input_tree
  )
def map_structure_up_to(shallow_tree, func, *inputs):
  """Applies a function or op to a number of partially flattened inputs.

  The `inputs` are flattened up to `shallow_tree` before being mapped.

  Use Case:

  Sometimes we wish to apply a function to a partially flattened
  sequence (for example when the function itself takes sequence inputs). We
  achieve this by specifying a shallow structure, `shallow_tree`, we wish to
  flatten up to.

  The `inputs` can be thought of as having the same structure as
  `shallow_tree`, but with leaf nodes that are themselves tree structures.

  This function, therefore, will return something with the same base structure
  as `shallow_tree`.

  Note that this module uses tf.data nesting rules: lists are atomic elements,
  not a level of nesting, so lists are never split into their items.

  Examples:

  ```python
  ab_tuple = collections.namedtuple("ab_tuple", "a, b")
  op_tuple = collections.namedtuple("op_tuple", "add, mul")
  inp_val = ab_tuple(a=2, b=3)
  inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
  out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul,
                            inp_val, inp_ops)

  # Output is: ab_tuple(a=6, b=15)
  ```

  ```python
  data_tuple = ([2, 4, 6, 8], ([1, 3, 5, 7, 9], [3, 5, 7]))
  name_tuple = ('evens', ('odds', 'primes'))
  out = map_structure_up_to(
      name_tuple,
      lambda name, sec: "first_{}_{}".format(len(sec), name),
      name_tuple, data_tuple)

  # Output is: ('first_4_evens', ('first_5_odds', 'first_3_primes'))
  ```

  Args:
    shallow_tree: a shallow tree, common to all the inputs.
    func: callable which will be applied to each input individually.
    *inputs: arbitrarily nested combination of objects that are compatible with
      shallow_tree. The function `func` is applied to corresponding
      partially flattened elements of each input, so the function must support
      arity of `len(inputs)`.

  Raises:
    TypeError: If `shallow_tree` is a sequence but an element of `inputs` is
      not.
    TypeError: If the sequence types of `shallow_tree` are different from
      those of an element of `inputs`.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      those of an element of `inputs`.

  Returns:
    result of repeatedly applying `func`, with same structure as
    `shallow_tree`.
  """
  return nest_util.map_structure_up_to(
      nest_util.Modality.DATA, shallow_tree, func, *inputs
  )