Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/operators/data_structures.py: 20%

147 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Operators specific to data structures: list append, subscripts, etc.""" 

16 

17import collections 

18 

19from tensorflow.python.framework import constant_op 

20from tensorflow.python.framework import dtypes 

21from tensorflow.python.framework import ops 

22from tensorflow.python.framework import tensor_util 

23from tensorflow.python.ops import array_ops 

24from tensorflow.python.ops import cond 

25from tensorflow.python.ops import list_ops 

26from tensorflow.python.ops import tensor_array_ops 

27 

28 

29# TODO(mdan): Once control flow supports objects, repackage as a class. 

30 

31 

def new_list(iterable=None):
  """Creates a new list, picking its representation from the initial contents.

  Args:
    iterable: Optional elements to fill the list with.

  Returns:
    A Python list holding the elements when at least one element is supplied;
    otherwise the (empty) tensor list produced by tf_tensor_list_new.
  """
  elements = tuple(iterable) if iterable else ()

  if not elements:
    # Nothing to initialize with: stage an empty tensor list instead.
    return tf_tensor_list_new(elements)

  # A non-empty initializer is treated as a plain Python ("lvalue") list.
  return _py_list_new(elements)

51 

52 

def tf_tensor_array_new(elements, element_dtype=None, element_shape=None):
  """Overload of new_list that stages a TensorArray creation.

  The element dtype and static element shape are inferred from `elements`
  when not passed explicitly; explicit values must agree with the inferred
  ones.

  Args:
    elements: iterable of values; each one is converted with
      ops.convert_to_tensor and written, in order, into the new TensorArray.
    element_dtype: optional element dtype; required when `elements` is empty.
    element_shape: optional static element shape; when None, the single
      static shape shared by all elements (if any) is used.

  Returns:
    A TensorArray created with dynamic_size=True that holds the converted
    elements.

  Raises:
    ValueError: if the elements have mixed dtypes or mixed static shapes, if
      an explicit dtype/shape disagrees with the inferred one, or if
      `elements` is empty and no dtype was given.
  """
  tensors = tuple(ops.convert_to_tensor(value) for value in elements)

  # Infer and validate the common element dtype.
  dtype_set = {t.dtype for t in tensors}
  if not dtype_set:
    if element_dtype is None:
      raise ValueError('dtype is required to create an empty TensorArray')
  elif len(dtype_set) > 1:
    raise ValueError(
        'TensorArray requires all elements to have the same dtype:'
        ' {}'.format(tensors))
  else:
    (inferred_dtype,) = dtype_set
    if element_dtype is not None and element_dtype != inferred_dtype:
      raise ValueError(
          'incompatible dtype; specified: {}, inferred from {}: {}'.format(
              element_dtype, tensors, inferred_dtype))

  # Infer and validate the common static element shape.
  # TODO(mdan): We may want to allow different shapes with infer_shape=False.
  shape_set = {tuple(t.shape.as_list()) for t in tensors}
  if len(shape_set) > 1:
    raise ValueError(
        'TensorArray requires all elements to have the same shape:'
        ' {}'.format(tensors))
  if shape_set:
    (inferred_shape,) = shape_set
    if element_shape is not None and element_shape != inferred_shape:
      raise ValueError(
          'incompatible shape; specified: {}, inferred from {}: {}'.format(
              element_shape, tensors, inferred_shape))
  else:
    inferred_shape = None

  # inferred_dtype is only read when element_dtype is None, which (per the
  # checks above) implies at least one element was supplied.
  final_dtype = inferred_dtype if element_dtype is None else element_dtype
  final_shape = inferred_shape if element_shape is None else element_shape

  array = tensor_array_ops.TensorArray(
      dtype=final_dtype,
      size=len(tensors),
      dynamic_size=True,
      infer_shape=(final_shape is None),
      element_shape=final_shape)
  for index, value in enumerate(tensors):
    array = array.write(index, value)
  return array

101 

102 

def tf_tensor_list_new(elements, element_dtype=None, element_shape=None):
  """Overload of new_list that stages a Tensor list creation.

  Two input forms are handled:
    * `elements` is itself a TF tensor: the list is built from it with
      list_ops.tensor_list_from_tensor, using the dynamic shape of `elements`
      without its first dimension as the element shape. `element_dtype` is not
      used in this case, and `element_shape` must be None.
    * otherwise `elements` is iterated, each item is converted with
      ops.convert_to_tensor, and the items are pushed one by one onto a newly
      created empty tensor list.

  In the second form, the element dtype and shape are inferred from the items
  when not given. A single common dtype is used as-is; mixed dtypes, or no
  items at all, fall back to dtypes.variant. A single common static shape
  yields the dynamic shape of the first item; mixed shapes, or no items, fall
  back to the constant -1 (unknown shape).

  Args:
    elements: a TF tensor, or an iterable of values convertible to tensors.
    element_dtype: optional element dtype. It must equal the inferred dtype
      when the items share one, and must be None when their dtypes differ.
    element_shape: optional element shape. It must be None for tensor input
      and when the items have differing static shapes.

  Returns:
    The tensor list produced by list_ops (from_tensor, or empty list plus
    push_back of every item).

  Raises:
    ValueError: when an explicit dtype or shape conflicts with the inputs as
      described above.
  """
  if tensor_util.is_tf_type(elements):
    if element_shape is not None:
      raise ValueError(
          'element shape may not be specified when creating list from tensor')
    # Each list element is one slice along the first axis of `elements`.
    element_shape = array_ops.shape(elements)[1:]
    l = list_ops.tensor_list_from_tensor(elements, element_shape=element_shape)
    return l

  elements = tuple(ops.convert_to_tensor(el) for el in elements)

  all_dtypes = set(el.dtype for el in elements)
  if len(all_dtypes) == 1:
    inferred_dtype = tuple(all_dtypes)[0]
    if element_dtype is not None and element_dtype != inferred_dtype:
      raise ValueError(
          'incompatible dtype; specified: {}, inferred from {}: {}'.format(
              element_dtype, elements, inferred_dtype))
  elif all_dtypes:
    # Heterogeneous lists are ok.
    if element_dtype is not None:
      raise ValueError(
          'specified dtype {} is inconsistent with that of elements {}'.format(
              element_dtype, elements))
    inferred_dtype = dtypes.variant
  else:
    # No items: the dtype defaults to variant unless one was specified.
    inferred_dtype = dtypes.variant

  all_shapes = set(tuple(el.shape.as_list()) for el in elements)
  if len(all_shapes) == 1:
    inferred_shape = array_ops.shape(elements[0])
    # NOTE(review): inferred_shape is a Tensor here (the result of
    # array_ops.shape), so `element_shape != inferred_shape` is a tensor
    # comparison whose result is then used as a Python bool. That may be
    # ambiguous or disallowed for non-scalar shapes; confirm the intended
    # semantics (probably a comparison against the static shape).
    if element_shape is not None and element_shape != inferred_shape:
      raise ValueError(
          'incompatible shape; specified: {}, inferred from {}: {}'.format(
              element_shape, elements, inferred_shape))
  elif all_shapes:
    # Heterogeneous lists are ok.
    if element_shape is not None:
      raise ValueError(
          'specified shape {} is inconsistent with that of elements {}'.format(
              element_shape, elements))
    inferred_shape = constant_op.constant(-1)  # unknown shape, by convention
  else:
    inferred_shape = constant_op.constant(-1)  # unknown shape, by convention

  if element_dtype is None:
    element_dtype = inferred_dtype
  if element_shape is None:
    element_shape = inferred_shape

  # list_ops receives the element shape as an int32 tensor.
  element_shape = ops.convert_to_tensor(element_shape, dtype=dtypes.int32)
  l = list_ops.empty_tensor_list(
      element_shape=element_shape, element_dtype=element_dtype)
  for el in elements:
    l = list_ops.tensor_list_push_back(l, el)
  return l

160 

161 

162def _py_list_new(elements): 

163 """Overload of new_list that creates a Python list.""" 

164 return list(elements) 

165 

166 

def list_append(list_, x):
  """The list append function.

  Note: it is unspecified whether list_ will be mutated or not. If list_ is
  a TensorFlow entity, it will typically not be mutated. If list_ is a plain
  list, it will be. In general, if the list is mutated then the return value
  should point to the original entity.

  Dispatch:
    * TensorArray: x is written at index list_.size().
    * Tensor of dtype tf.variant (a Tensor list): x is pushed onto the list.
    * any other Tensor: rejected with ValueError.
    * anything else: list_.append(x) is called and list_ is returned.

  Args:
    list_: An entity that supports append semantics.
    x: The element to append.

  Returns:
    Same as list_, after the append was performed.

  Raises:
    ValueError: if list_ is a Tensor whose dtype is not tf.variant. Objects
      that are neither TensorArrays nor Tensors are handed to the Python
      overload, so any error there comes from their own `append`.
  """
  if isinstance(list_, tensor_array_ops.TensorArray):
    return _tf_tensorarray_append(list_, x)
  elif tensor_util.is_tf_type(list_):
    if list_.dtype == dtypes.variant:
      return _tf_tensor_list_append(list_, x)
    else:
      raise ValueError(
          'tensor lists are expected to be Tensors with dtype=tf.variant,'
          ' instead found %s' % list_)
  else:
    return _py_list_append(list_, x)

196 

197 

def _tf_tensor_list_append(list_, x):
  """Stages a push_back of x onto a Tensor list (overload of list_append).

  When list_ currently holds zero elements, a cond replaces it with a fresh
  empty Tensor list whose element dtype and shape are taken from x. Otherwise
  list_ is kept unchanged. x is then pushed onto whichever list was selected.
  """

  def _fresh_list_shaped_like_x():
    # Empty list whose element spec (shape and dtype) is derived from x.
    x_as_tensor = ops.convert_to_tensor(x)
    return list_ops.empty_tensor_list(
        element_shape=array_ops.shape(x_as_tensor),
        element_dtype=x_as_tensor.dtype)

  def _keep_existing_list():
    return list_

  has_elements = list_ops.tensor_list_length(list_) > 0
  target = cond.cond(
      has_elements,
      _keep_existing_list,
      _fresh_list_shaped_like_x,
  )
  return list_ops.tensor_list_push_back(target, x)

212 

213 

214def _tf_tensorarray_append(list_, x): 

215 """Overload of list_append that stages a TensorArray write.""" 

216 return list_.write(list_.size(), x) 

217 

218 

219def _py_list_append(list_, x): 

220 """Overload of list_append that executes a Python list append.""" 

221 # Revert to the original call. 

222 list_.append(x) 

223 return list_ 

224 

225 

class ListPopOpts(
    collections.namedtuple('ListPopOpts', ('element_dtype', 'element_shape'))):
  """Static options passed to list_pop.

  Fields:
    element_dtype: dtype of the list elements. Popping from a Tensor list
      fails with ValueError when this is None.
    element_shape: shape of the list elements, applied to the popped value.
      Popping from a Tensor list fails with ValueError when this is None.
  """
  # Empty __slots__ keeps instances as lean as the base namedtuple (no
  # per-instance __dict__) and rejects accidental attribute assignment.
  __slots__ = ()

229 

230 

def list_pop(list_, i, opts):
  """The list pop function.

  Note: it is unspecified whether list_ will be mutated or not. If list_ is
  a TensorFlow entity, it will typically not be mutated. If list_ is a plain
  list, it will be. In general, if the list is mutated then the return value
  should point to the original entity.

  Args:
    list_: An entity that supports pop semantics.
    i: Optional index to pop from. May be None.
    opts: A ListPopOpts.

  Returns:
    Tuple (out_list_, x):
      out_list_: same as list_, after the removal was performed.
      x: the removed element value.

  Raises:
    ValueError: if list_ is a TensorArray (item removal is not supported), a
      Tensor whose dtype is not tf.variant, or a Tensor list while
      opts.element_dtype or opts.element_shape is None.
    NotImplementedError: if list_ is a Tensor list and i is not None (only
      removal from the end is supported).
    Any exception raised by list_.pop for non-Tensor inputs (for example
    IndexError for an empty Python list) propagates unchanged.
  """
  assert isinstance(opts, ListPopOpts)

  if isinstance(list_, tensor_array_ops.TensorArray):
    raise ValueError('TensorArray does not support item removal')
  elif tensor_util.is_tf_type(list_):
    if list_.dtype == dtypes.variant:
      return _tf_tensor_list_pop(list_, i, opts)
    else:
      raise ValueError(
          'tensor lists are expected to be Tensors with dtype=tf.variant,'
          ' instead found %s' % list_)
  else:
    return _py_list_pop(list_, i)

266 

267 

268def _tf_tensor_list_pop(list_, i, opts): 

269 """Overload of list_pop that stages a Tensor list pop.""" 

270 if i is not None: 

271 raise NotImplementedError('tensor lists only support removing from the end') 

272 

273 if opts.element_dtype is None: 

274 raise ValueError('cannot pop from a list without knowing its element ' 

275 'type; use set_element_type to annotate it') 

276 if opts.element_shape is None: 

277 raise ValueError('cannot pop from a list without knowing its element ' 

278 'shape; use set_element_type to annotate it') 

279 list_out, x = list_ops.tensor_list_pop_back( 

280 list_, element_dtype=opts.element_dtype) 

281 x.set_shape(opts.element_shape) 

282 return list_out, x 

283 

284 

285def _py_list_pop(list_, i): 

286 """Overload of list_pop that executes a Python list append.""" 

287 if i is None: 

288 x = list_.pop() 

289 else: 

290 x = list_.pop(i) 

291 return list_, x 

292 

293 

294# TODO(mdan): Look into reducing duplication between all these containers. 

class ListStackOpts(
    collections.namedtuple('ListStackOpts',
                           ('element_dtype', 'original_call'))):
  """Static options passed to list_stack.

  Fields:
    element_dtype: dtype of the list elements. Stacking a Tensor list fails
      with ValueError when this is None.
    original_call: callable that list_stack falls back to (called with the
      list as its only argument) for inputs that are neither TensorArrays nor
      Tensors.
  """
  # Empty __slots__ keeps instances as lean as the base namedtuple (no
  # per-instance __dict__) and rejects accidental attribute assignment.
  __slots__ = ()

299 

300 

def list_stack(list_, opts):
  """The list stack function.

  This does not have a direct correspondent in Python. The closest idioms are
  tf.stack or np.stack. It differs from those in that it accepts a Tensor
  list rather than a list of tensors. It can also accept a TensorArray. When
  the target is anything else, the dispatcher falls back to
  opts.original_call.

  Dispatch:
    * TensorArray: list_.stack() is returned.
    * Tensor of dtype tf.variant (a Tensor list): staged tensor_list_stack.
    * any other Tensor: returned unchanged (no-op).
    * anything else: opts.original_call(list_) is returned.

  Args:
    list_: The entity to stack: a TensorArray, a Tensor list, a regular
      Tensor, or any other value accepted by opts.original_call.
    opts: A ListStackOpts object.

  Returns:
    The output of the stack operation, typically a Tensor.

  Raises:
    ValueError: if list_ is a Tensor list and opts.element_dtype is None.
  """
  assert isinstance(opts, ListStackOpts)

  if isinstance(list_, tensor_array_ops.TensorArray):
    return _tf_tensorarray_stack(list_)
  elif tensor_util.is_tf_type(list_):
    if list_.dtype == dtypes.variant:
      return _tf_tensor_list_stack(list_, opts)
    else:
      # No-op for primitive Tensor arguments.
      return list_
  else:
    return _py_list_stack(list_, opts)

329 

330 

331def _tf_tensorarray_stack(list_): 

332 """Overload of list_stack that stages a TensorArray stack.""" 

333 return list_.stack() 

334 

335 

336def _tf_tensor_list_stack(list_, opts): 

337 """Overload of list_stack that stages a Tensor list write.""" 

338 if opts.element_dtype is None: 

339 raise ValueError('cannot stack a list without knowing its element type;' 

340 ' use set_element_type to annotate it') 

341 return list_ops.tensor_list_stack(list_, element_dtype=opts.element_dtype) 

342 

343 

344def _py_list_stack(list_, opts): 

345 """Overload of list_stack that executes a Python list append.""" 

346 # Revert to the original call. 

347 return opts.original_call(list_)