Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/util/nest.py: 56%

18 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16"""## Functions for working with arbitrarily nested sequences of elements. 

17 

18NOTE(mrry): This fork of the `tensorflow.python.util.nest` module 

19makes two changes: 

20 

211. It removes support for lists as a level of nesting in nested structures. 

222. It adds support for `SparseTensorValue` as an atomic element. 

23 

24The motivation for this change is twofold: 

25 

261. It seems more natural for lists to be treated (e.g. in Dataset constructors) 

27 as tensors, rather than lists of (lists of...) tensors. 

282. This is needed because `SparseTensorValue` is implemented as a `namedtuple` 

29 that would normally be flattened and we want to be able to create sparse 

30 tensors from `SparseTensorValue`s similarly to creating tensors from numpy 

31 arrays. 

32""" 

33 

34from tensorflow.python.util import nest_util 

35 

36 

def is_nested(structure):
  """Returns whether `structure` is a nested structure under tf.data rules.

  This delegates to `nest_util` with `Modality.DATA`. Unlike the general
  `tensorflow.python.util.nest` module, lists are therefore not treated as a
  level of nesting, and `SparseTensorValue` is treated as an atomic element
  (see the module docstring).

  Args:
    structure: an arbitrary Python object.

  Returns:
    `True` if `structure` is nested, `False` otherwise.
  """
  return nest_util.is_nested(nest_util.Modality.DATA, structure)

39 

40 

def flatten(structure):
  """Returns a flat list of the atomic elements of `structure`.

  Nesting follows the tf.data rules of this module: lists and
  `SparseTensorValue`s are atomic. They are returned as single elements
  rather than being flattened further.

  Args:
    structure: an arbitrarily nested structure or a scalar object.

  Returns:
    A Python list of the leaves of `structure`. If `structure` is not nested,
    the result is the single-element list `[structure]`.
  """
  return nest_util.flatten(nest_util.Modality.DATA, structure)

43 

44 

def assert_same_structure(nest1, nest2, check_types=True):
  """Checks that `nest1` and `nest2` share the same nested structure.

  Nesting is decided by this module's tf.data rules, under which lists are
  atomic elements rather than a level of nesting.

  Args:
    nest1: an arbitrarily nested structure.
    nest2: an arbitrarily nested structure.
    check_types: when `True` (the default), the sequence types at every level
      must match too. For dictionaries the keys count as part of the "type",
      so two dictionaries with different keys are considered to differ. When
      `False`, two iterables match as long as the elements they yield have the
      same structures.

  Raises:
    ValueError: If the two structures do not have the same number of elements
      or are not nested in the same way.
    TypeError: If the two structures differ in the type of sequence in any of
      their substructures. Only possible when `check_types` is `True`.
  """
  data_modality = nest_util.Modality.DATA
  nest_util.assert_same_structure(data_modality, nest1, nest2, check_types)

67 

68 

def pack_sequence_as(structure, flat_sequence):
  """Returns a given flattened sequence packed into a nest.

  If `structure` is not nested (under this module's tf.data rules), then
  `flat_sequence` must be a single-element list. In this case the return
  value is `flat_sequence[0]`.

  Args:
    structure: a nested structure (e.g. tuples, namedtuples or dicts) of
      scalars and/or other nested structures, or a scalar. Note: lists and
      numpy arrays are considered scalars, so a list in `structure` is a
      single leaf rather than a level of nesting.
    flat_sequence: flat sequence to pack.

  Returns:
    packed: `flat_sequence` converted to have the same recursive structure as
      `structure`.

  Raises:
    ValueError: If `flat_sequence` and `structure` have different element
      counts.
  """
  # Unlike the other wrappers in this module, `expand_composites` is passed
  # explicitly here.
  return nest_util.pack_sequence_as(
      nest_util.Modality.DATA, structure, flat_sequence, expand_composites=False
  )

90 

91 

def map_structure(func, *structure, **check_types_dict):
  """Applies `func` to each entry in `structure` and returns a new structure.

  Applies `func(x[0], x[1], ...)` where x[i] is an entry in
  `structure[i]`. All structures in `structure` must have the same arity,
  and the return value will contain the results in the same structure.

  Args:
    func: A callable that accepts as many arguments as there are structures.
    *structure: scalars, or nested structures (e.g. tuples, namedtuples or
      dicts) of scalars and/or other nested structures. Note: lists and numpy
      arrays are considered scalars.
    **check_types_dict: only valid keyword argument is `check_types`. If set to
      `True` (default) the types of iterables within the structures have to be
      the same (e.g. passing a `tuple` and a `namedtuple` with the same
      elements raises a `TypeError`). To allow this, set this argument to
      `False`.

  Returns:
    A new structure with the same arity as `structure`, whose values correspond
    to `func(x[0], x[1], ...)` where `x[i]` is a value in the corresponding
    location in `structure[i]`. If there are different sequence types and
    `check_types` is `False` the sequence types of the first structure will be
    used.

  Raises:
    TypeError: If `func` is not callable or if the structures do not match
      each other by depth tree.
    ValueError: If no structure is provided or if the structures do not match
      each other by type.
    ValueError: If wrong keyword arguments are provided.
  """
  return nest_util.map_structure(
      nest_util.Modality.DATA, func, *structure, **check_types_dict
  )

125 

126 

def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
  """Asserts that `shallow_tree` is a shallow structure of `input_tree`.

  That is, this function tests if the `input_tree` structure can be created from
  the `shallow_tree` structure by replacing its leaf nodes with deeper
  tree structures.

  Nesting follows this module's tf.data rules: lists are not a level of
  nesting. A list in `shallow_tree` such as `["a", "b"]` is therefore a leaf
  and does not constrain the corresponding part of `input_tree`.

  Examples:

  The following code will raise an exception:
  ```python
    shallow_tree = ("a", "b")
    input_tree = ("c", ("d", "e"), "f")
    assert_shallow_structure(shallow_tree, input_tree)
  ```

  The following code will not raise an exception:
  ```python
    shallow_tree = ("a", "b")
    input_tree = ("c", ("d", "e"))
    assert_shallow_structure(shallow_tree, input_tree)
  ```

  Args:
    shallow_tree: an arbitrarily nested structure.
    input_tree: an arbitrarily nested structure.
    check_types: if `True` (default) the sequence types of `shallow_tree` and
      `input_tree` have to be the same.

  Raises:
    TypeError: If `shallow_tree` is nested but `input_tree` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      `input_tree`. Only raised if `check_types` is `True`.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      `input_tree`.
  """
  nest_util.assert_shallow_structure(
      nest_util.Modality.DATA, shallow_tree, input_tree, check_types
  )

166 

167 

def flatten_up_to(shallow_tree, input_tree):
  """Flattens `input_tree` up to `shallow_tree`.

  Any further depth in structure in `input_tree` is retained as elements in the
  partially flattened output.

  If `shallow_tree` is not nested, this returns a single-element list:
  `[input_tree]`. Under this module's tf.data rules, lists are not a level of
  nesting, so a list `shallow_tree` also yields `[input_tree]`.

  Use Case:

  Sometimes we may wish to partially flatten a nested sequence, retaining some
  of the nested structure. We achieve this by specifying a shallow structure,
  `shallow_tree`, we wish to flatten up to.

  The input, `input_tree`, can be thought of as having the same structure as
  `shallow_tree`, but with leaf nodes that are themselves tree structures.

  Examples:

  ```python
  input_tree = (((2, 2), (3, 3)), ((4, 9), (5, 5)))
  shallow_tree = ((True, True), (False, True))

  flattened_input_tree = flatten_up_to(shallow_tree, input_tree)
  flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree)

  # Output is:
  # [(2, 2), (3, 3), (4, 9), (5, 5)]
  # [True, True, False, True]
  ```

  ```python
  input_tree = ((('a', 1), (('b', 2), (('c', 3), (('d', 4),)))),)
  shallow_tree = (('level_1', ('level_2', ('level_3', ('level_4',)))),)

  input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree)
  input_tree_flattened = flatten(input_tree)

  # Output is:
  # [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
  # ['a', 1, 'b', 2, 'c', 3, 'd', 4]
  ```

  Non-Sequence Edge Cases:

  ```python
  flatten_up_to(0, 0)  # Output: [0]
  flatten_up_to(0, (0, 1, 2))  # Output: [(0, 1, 2)]
  flatten_up_to((0, 1, 2), 0)  # Output: TypeError
  flatten_up_to((0, 1, 2), (0, 1, 2))  # Output: [0, 1, 2]
  flatten_up_to([0, 1, 2], [0, 1, 2])  # Output: [[0, 1, 2]] (lists are leaves)
  ```

  Args:
    shallow_tree: a possibly pruned structure of input_tree.
    input_tree: an arbitrarily nested structure or a scalar object.
      Note, lists and numpy arrays are considered scalars.

  Returns:
    A Python list, the partially flattened version of `input_tree` according to
    the structure of `shallow_tree`.

  Raises:
    TypeError: If `shallow_tree` is nested but `input_tree` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      `input_tree`.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      `input_tree`.
  """
  return nest_util.flatten_up_to(
      nest_util.Modality.DATA, shallow_tree, input_tree
  )

240 

241 

def map_structure_up_to(shallow_tree, func, *inputs):
  """Applies a function or op to a number of partially flattened inputs.

  The `inputs` are flattened up to `shallow_tree` before being mapped.

  Use Case:

  Sometimes we wish to apply a function to a partially flattened
  sequence (for example when the function itself takes sequence inputs). We
  achieve this by specifying a shallow structure, `shallow_tree` we wish to
  flatten up to.

  The `inputs`, can be thought of as having the same structure as
  `shallow_tree`, but with leaf nodes that are themselves tree structures.

  This function, therefore, will return something with the same base structure
  as `shallow_tree`. Under this module's tf.data rules, lists are leaves, not a
  level of nesting, so use tuples (or other nested types) for `shallow_tree`.

  Examples:

  ```python
  ab_tuple = collections.namedtuple("ab_tuple", "a, b")
  op_tuple = collections.namedtuple("op_tuple", "add, mul")
  inp_val = ab_tuple(a=2, b=3)
  inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
  out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul,
                            inp_val, inp_ops)

  # Output is: ab_tuple(a=6, b=15)
  ```

  ```python
  data = ((2, 4, 6, 8), ((1, 3, 5, 7, 9), (3, 5, 7)))
  names = ('evens', ('odds', 'primes'))
  out = map_structure_up_to(
      names,
      lambda name, sec: "first_{}_{}".format(len(sec), name),
      names, data)

  # Output is: ('first_4_evens', ('first_5_odds', 'first_3_primes'))
  ```

  Args:
    shallow_tree: a shallow tree, common to all the inputs.
    func: callable which will be applied to each input individually.
    *inputs: arbitrarily nested combination of objects that are compatible with
      shallow_tree. The function `func` is applied to corresponding
      partially flattened elements of each input, so the function must support
      arity of `len(inputs)`.

  Raises:
    TypeError: If `shallow_tree` is nested but one of `inputs` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      those of one of `inputs`.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      those of one of `inputs`.

  Returns:
    result of repeatedly applying `func`, with same structure as
    `shallow_tree`.
  """
  return nest_util.map_structure_up_to(
      nest_util.Modality.DATA, shallow_tree, func, *inputs
  )