Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/list_ops.py: 33%

165 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Ops to manipulate lists of tensors.""" 

16 

17# pylint: disable=g-bad-name 

18import numpy as np 

19 

20from tensorflow.core.framework import full_type_pb2 

21from tensorflow.python.framework import cpp_shape_inference_pb2 

22from tensorflow.python.framework import dtypes 

23from tensorflow.python.framework import ops 

24from tensorflow.python.framework import tensor_shape 

25from tensorflow.python.framework import tensor_util 

26from tensorflow.python.ops import array_ops 

27from tensorflow.python.ops import gen_list_ops 

28from tensorflow.python.ops import handle_data_util 

29# go/tf-wildcard-import 

30# pylint: disable=wildcard-import 

31from tensorflow.python.ops.gen_list_ops import * 

32# pylint: enable=wildcard-import 

33 

34 

# Register the following TensorList ops as having no gradient, so autodiff
# does not look for a gradient function for them.
for _op_type in ("TensorListConcatLists",
                 "TensorListElementShape",
                 "TensorListLength",
                 "TensorListPushBackBatch"):
  ops.NotDifferentiable(_op_type)
del _op_type  # Do not leave the loop variable behind in the module namespace.

39 

40 

def empty_tensor_list(element_shape,
                      element_dtype,
                      max_num_elements=None,
                      name=None):
  """Creates an empty TensorList via `gen_list_ops.empty_tensor_list`.

  Args:
    element_shape: Element shape in any form accepted by
      `_build_element_shape` (None, a Tensor, a TensorShape or a list of dims).
    element_dtype: dtype of the list elements.
    max_num_elements: Optional size limit for the list. `None` is forwarded
      to the op as -1 (presumably "no limit" -- confirm against the op docs).
    name: Optional op name.

  Returns:
    The list handle returned by `gen_list_ops.empty_tensor_list`.
  """
  limit = -1 if max_num_elements is None else max_num_elements
  return gen_list_ops.empty_tensor_list(
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      max_num_elements=limit,
      name=name)

53 

54 

def _set_handle_data(list_handle, element_shape, element_dtype):
  """Records element shape/dtype as handle data on an eager list handle.

  This is done for consistency with graphs. Only `ops.EagerTensor` handles
  are modified; any other handle is left untouched.

  Args:
    list_handle: The TensorList handle to annotate.
    element_shape: Element shape. A TF tensor is treated as an unknown shape;
      a `TensorShape` is used as-is; anything else is passed to
      `tensor_shape.TensorShape`.
    element_dtype: A `DType` for the list elements.
  """
  # TODO(b/169968286): It would be better if we had a consistent story for
  # creating handle data from eager operations (shared with VarHandleOp).
  if not isinstance(list_handle, ops.EagerTensor):
    return

  if tensor_util.is_tf_type(element_shape):
    # A shape supplied as a tensor has no static value to record.
    static_shape = tensor_shape.TensorShape(None)
  elif isinstance(element_shape, tensor_shape.TensorShape):
    static_shape = element_shape
  else:
    static_shape = tensor_shape.TensorShape(element_shape)

  result_proto = cpp_shape_inference_pb2.CppShapeInferenceResult
  # TODO(b/191472076): This duplicates type inference. Clean up.
  entry = result_proto.HandleShapeAndType(
      shape=static_shape.as_proto(),
      dtype=element_dtype.as_datatype_enum,
      type=full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_ARRAY))
  handle_data = result_proto.HandleData()
  handle_data.is_set = True
  handle_data.shape_and_type.append(entry)
  list_handle._handle_data = handle_data  # pylint: disable=protected-access

73 

74 

def tensor_list_reserve(element_shape, num_elements, element_dtype, name=None):
  """Creates a TensorList with `num_elements` reserved slots.

  Args:
    element_shape: Element shape in any form accepted by
      `_build_element_shape`. It is also recorded as handle data when the
      result is an eager tensor.
    num_elements: Number of slots to reserve.
    element_dtype: dtype of the list elements.
    name: Optional op name.

  Returns:
    The list handle produced by `gen_list_ops.tensor_list_reserve`.
  """
  shape_arg = _build_element_shape(element_shape)
  reserved = gen_list_ops.tensor_list_reserve(
      element_shape=shape_arg,
      num_elements=num_elements,
      element_dtype=element_dtype,
      name=name)
  # TODO(b/169968286): gen_ops needs to ensure the metadata is properly
  # populated for eager operations.
  _set_handle_data(reserved, element_shape, element_dtype)
  return reserved

85 

86 

def tensor_list_from_tensor(tensor, element_shape, name=None):
  """Creates a TensorList whose elements are the slices of `tensor` on axis 0.

  Args:
    tensor: Value convertible to a tensor. Its leading dimension indexes the
      list elements.
    element_shape: Element shape in any form accepted by
      `_build_element_shape`.
    name: Optional op name.

  Returns:
    The list handle produced by `gen_list_ops.tensor_list_from_tensor`. On
    eager handles, handle data is also set.
  """
  tensor = ops.convert_to_tensor(tensor)
  result = gen_list_ops.tensor_list_from_tensor(
      tensor=tensor,
      element_shape=_build_element_shape(element_shape),
      name=name)
  # `_set_handle_data` expects the shape of a single *element*. Each element
  # is one slice of `tensor` along its leading axis, so that shape is
  # `tensor.shape[1:]` (the same convention `_TensorListStackGrad` uses).
  # Passing the full `tensor.shape` would record a shape with one extra
  # leading dimension.
  _set_handle_data(result, tensor.shape[1:], tensor.dtype)
  return result

95 

96 

def tensor_list_get_item(input_handle, index, element_dtype, element_shape=None,
                         name=None):
  """Reads the element at `index` from the list `input_handle`.

  Args:
    input_handle: The TensorList handle to read from.
    index: Position of the element to read.
    element_dtype: dtype of the list elements.
    element_shape: Optional element shape in any form accepted by
      `_build_element_shape`. `None` maps to -1.
    name: Optional op name.

  Returns:
    The output of `gen_list_ops.tensor_list_get_item`.
  """
  shape_arg = _build_element_shape(element_shape)
  return gen_list_ops.tensor_list_get_item(
      input_handle=input_handle,
      index=index,
      element_shape=shape_arg,
      element_dtype=element_dtype,
      name=name)

105 

106 

def tensor_list_pop_back(input_handle, element_dtype, name=None):
  """Pops the last element off `input_handle`.

  Args:
    input_handle: The TensorList handle to pop from.
    element_dtype: dtype of the list elements.
    name: Optional op name.

  Returns:
    The output of `gen_list_ops.tensor_list_pop_back`.
  """
  # -1 is the value `_build_element_shape` uses for an unknown shape, so no
  # element-shape constraint is supplied here.
  unknown_shape = -1
  return gen_list_ops.tensor_list_pop_back(
      input_handle=input_handle,
      element_shape=unknown_shape,
      element_dtype=element_dtype,
      name=name)

113 

114 

def tensor_list_gather(input_handle,
                       indices,
                       element_dtype,
                       element_shape=None,
                       name=None):
  """Gathers the elements of `input_handle` at `indices`.

  Args:
    input_handle: The TensorList handle to gather from.
    indices: Positions of the elements to gather.
    element_dtype: dtype of the list elements.
    element_shape: Optional element shape in any form accepted by
      `_build_element_shape`.
    name: Optional op name.

  Returns:
    The output of `gen_list_ops.tensor_list_gather`.
  """
  shape_arg = _build_element_shape(element_shape)
  return gen_list_ops.tensor_list_gather(
      input_handle=input_handle,
      indices=indices,
      element_shape=shape_arg,
      element_dtype=element_dtype,
      name=name)

126 

127 

def tensor_list_scatter(tensor,
                        indices,
                        element_shape=None,
                        input_handle=None,
                        name=None):
  """Returns a TensorList created or updated by scattering `tensor`.

  Args:
    tensor: Value convertible to a tensor whose slices are scattered.
    indices: List positions that receive the slices.
    element_shape: Element shape for a newly created list. It is not used
      when `input_handle` is given.
    input_handle: Optional existing list to scatter into. When `None`, a new
      list is created.
    name: Optional op name.

  Returns:
    The resulting list handle.
  """
  tensor = ops.convert_to_tensor(tensor)

  if input_handle is None:
    new_list = gen_list_ops.tensor_list_scatter_v2(
        tensor=tensor,
        indices=indices,
        element_shape=_build_element_shape(element_shape),
        num_elements=-1,
        name=name)
    _set_handle_data(new_list, element_shape, tensor.dtype)
    return new_list

  updated_list = gen_list_ops.tensor_list_scatter_into_existing_list(
      input_handle=input_handle, tensor=tensor, indices=indices, name=name)
  # The updated list inherits the input list's handle data.
  handle_data_util.copy_handle_data(input_handle, updated_list)
  return updated_list

149 

150 

def tensor_list_stack(input_handle,
                      element_dtype,
                      num_elements=-1,
                      element_shape=None,
                      name=None):
  """Stacks the elements of `input_handle` into a single tensor.

  Args:
    input_handle: The TensorList handle to stack.
    element_dtype: dtype of the list elements.
    num_elements: Forwarded unchanged to the op (default -1).
    element_shape: Optional element shape in any form accepted by
      `_build_element_shape`.
    name: Optional op name.

  Returns:
    The output of `gen_list_ops.tensor_list_stack`.
  """
  shape_arg = _build_element_shape(element_shape)
  return gen_list_ops.tensor_list_stack(
      input_handle=input_handle,
      element_shape=shape_arg,
      element_dtype=element_dtype,
      num_elements=num_elements,
      name=name)

162 

163 

def tensor_list_concat(input_handle, element_dtype, element_shape=None,
                       name=None):
  """Concatenates the elements of `input_handle` along their first axis.

  Args:
    input_handle: The TensorList handle to concatenate.
    element_dtype: dtype of the list elements.
    element_shape: Optional element shape in any form accepted by
      `_build_element_shape`.
    name: Optional op name.

  Returns:
    The first output of `gen_list_ops.tensor_list_concat_v2`.
  """
  shape_arg = _build_element_shape(element_shape)
  no_leading_dims = ops.convert_to_tensor([], dtype=dtypes.int64)
  outputs = gen_list_ops.tensor_list_concat_v2(
      input_handle=input_handle,
      element_dtype=element_dtype,
      element_shape=shape_arg,
      leading_dims=no_leading_dims,
      name=name)
  # The second (lengths) output is only needed for gradient computation.
  return outputs[0]

174 

175 

def tensor_list_split(tensor, element_shape, lengths, name=None):
  """Splits `tensor` into a TensorList using `lengths`.

  Args:
    tensor: The tensor to split.
    element_shape: Element shape in any form accepted by
      `_build_element_shape`.
    lengths: Sizes of the pieces.
    name: Optional op name.

  Returns:
    The output of `gen_list_ops.tensor_list_split`.
  """
  shape_arg = _build_element_shape(element_shape)
  return gen_list_ops.tensor_list_split(
      tensor=tensor, element_shape=shape_arg, lengths=lengths, name=name)

182 

183 

def tensor_list_set_item(input_handle,
                         index,
                         item,
                         resize_if_index_out_of_bounds=False,
                         name=None):
  """Sets `item` at `index` in input list.

  Args:
    input_handle: The TensorList handle to update.
    index: Position to write.
    item: Value to store at `index`.
    resize_if_index_out_of_bounds: Forwarded to the op.
    name: Optional op name.

  Returns:
    The updated list handle. It carries the input's handle data.
  """
  updated_list = gen_list_ops.tensor_list_set_item(
      input_handle=input_handle,
      index=index,
      item=item,
      name=name,
      resize_if_index_out_of_bounds=resize_if_index_out_of_bounds)
  handle_data_util.copy_handle_data(input_handle, updated_list)
  return updated_list

199 

200 

@ops.RegisterGradient("TensorListPushBack")
def _PushBackGrad(op, dresult):
  """Gradient for TensorListPushBack: pops the pushed element's gradient.

  The pop returns both the remaining-list gradient and the gradient of the
  pushed item. These line up with the op's two inputs (list and item).
  """
  item_shape = array_ops.shape(op.inputs[1])
  return gen_list_ops.tensor_list_pop_back(
      dresult,
      element_shape=item_shape,
      element_dtype=op.get_attr("element_dtype"))

207 

208 

@ops.RegisterGradient("TensorListPopBack")
def _PopBackGrad(op, dlist, delement):
  """Gradient for TensorListPopBack: pushes the element gradient back on.

  A missing list gradient is replaced by an empty list with the output list's
  element shape. A missing element gradient is replaced by zeros shaped like
  the popped element. The second input (element_shape) gets no gradient.
  """
  if dlist is None:
    output_element_shape = gen_list_ops.tensor_list_element_shape(
        op.outputs[0], shape_type=dtypes.int32)
    dlist = empty_tensor_list(
        element_dtype=delement.dtype, element_shape=output_element_shape)
  if delement is None:
    delement = array_ops.zeros_like(op.outputs[1])
  list_grad = gen_list_ops.tensor_list_push_back(dlist, delement)
  return list_grad, None

219 

220 

@ops.RegisterGradient("TensorListStack")
def _TensorListStackGrad(unused_op, dtensor):
  """Gradient for TensorListStack: unstacks `dtensor` back into a list.

  Each list element's shape is `dtensor.shape[1:]`. The element_shape input
  gets no gradient.
  """
  per_element_shape = dtensor.shape[1:]
  list_grad = tensor_list_from_tensor(dtensor, element_shape=per_element_shape)
  return list_grad, None

224 

225 

@ops.RegisterGradient("TensorListConcat")
@ops.RegisterGradient("TensorListConcatV2")
def _TensorListConcatGrad(op, dtensor, unused_dlengths):
  """Gradient function for TensorListConcat.

  The gradient is `dtensor` split back into a list, using the per-element
  lengths that the forward op produced as its second output. TensorListConcatV2
  also takes `element_shape` and `leading_dims` inputs, which get no gradient.
  """
  input_element_shape = gen_list_ops.tensor_list_element_shape(
      op.inputs[0], shape_type=dtypes.int32)
  dlist = tensor_list_split(
      dtensor, element_shape=input_element_shape, lengths=op.outputs[1])
  return (dlist, None, None) if op.type == "TensorListConcatV2" else dlist

239 

240 

@ops.RegisterGradient("TensorListSplit")
def _TensorListSplitGrad(op, dlist):
  """Gradient for TensorListSplit: concatenates the list gradient back.

  The forward op's `lengths` input is reused as `leading_dims`. The element
  shape is `[-1] + shape(tensor)[1:]`. The element_shape and lengths inputs
  get no gradient.
  """
  tensor, _, lengths = op.inputs
  trailing_dims = array_ops.slice(array_ops.shape(tensor), [1], [-1])
  concat_shape = array_ops.concat([[-1], trailing_dims], axis=0)
  concat_outputs = gen_list_ops.tensor_list_concat_v2(
      dlist,
      element_shape=concat_shape,
      leading_dims=lengths,
      element_dtype=tensor.dtype)
  return concat_outputs[0], None, None

251 

252 

@ops.RegisterGradient("TensorListFromTensor")
def _TensorListFromTensorGrad(op, dlist):
  """Gradient for TensorListFromTensor.

  Stacks the list gradient back into a tensor shaped like the forward input.
  When no list gradient arrived, an empty list with the output's element shape
  is stacked instead. The element_shape input gets no gradient.
  """
  t = op.inputs[0]
  dims = t.shape.dims
  # Pass the statically known leading dimension when there is one, else None.
  num_elements = dims[0].value if dims else None

  if dlist is None:
    output_element_shape = gen_list_ops.tensor_list_element_shape(
        op.outputs[0], shape_type=dtypes.int32)
    dlist = empty_tensor_list(
        element_dtype=t.dtype, element_shape=output_element_shape)

  tensor_grad = gen_list_ops.tensor_list_stack(
      dlist,
      element_shape=array_ops.slice(array_ops.shape(t), [1], [-1]),
      element_dtype=t.dtype,
      num_elements=num_elements)
  return tensor_grad, None

273 

274 

@ops.RegisterGradient("TensorListGetItem")
def _TensorListGetItemGrad(op, ditem):
  """Gradient for TensorListGetItem.

  Builds a reserved list with the input list's length and element shape, then
  writes `ditem` at the read index. The index and element_shape inputs get no
  gradient.
  """
  input_list = op.inputs[0]
  read_index = op.inputs[1]
  num_slots = gen_list_ops.tensor_list_length(input_list)
  slot_shape = gen_list_ops.tensor_list_element_shape(
      input_list, shape_type=dtypes.int32)
  reserved = gen_list_ops.tensor_list_reserve(
      slot_shape, num_slots, element_dtype=ditem.dtype)
  list_grad = gen_list_ops.tensor_list_set_item(
      reserved, index=read_index, item=ditem)
  return list_grad, None, None

289 

290 

@ops.RegisterGradient("TensorListSetItem")
def _TensorListSetItemGrad(op, dlist):
  """Gradient function for TensorListSetItem.

  Returns a tuple of (list gradient, None for the index, item gradient).
  The list gradient is `dlist` with the slot at `index` zeroed. The item
  gradient is the gradient stored at `index` in `dlist`.
  """
  input_list, index, item = op.inputs
  list_grad = gen_list_ops.tensor_list_set_item(
      dlist, index=index, item=array_ops.zeros_like(item))
  item_grad = tensor_list_get_item(
      dlist,
      index,
      element_shape=array_ops.shape(item),
      element_dtype=item.dtype)
  if op.get_attr("resize_if_index_out_of_bounds"):
    # The forward op may have grown the list, so the list gradient is resized
    # back to the input list's length.
    list_grad = gen_list_ops.tensor_list_resize(
        list_grad, gen_list_ops.tensor_list_length(input_list))
  return list_grad, None, item_grad

311 

312 

@ops.RegisterGradient("TensorListResize")
def _TensorListResizeGrad(op, dlist):
  """Gradient for TensorListResize: resizes `dlist` to the input's length.

  The size input gets no gradient.
  """
  input_list, _ = op.inputs
  original_length = gen_list_ops.tensor_list_length(input_list)
  resized_grad = gen_list_ops.tensor_list_resize(dlist, original_length)
  return resized_grad, None

318 

319 

@ops.RegisterGradient("TensorListGather")
def _TensorListGatherGrad(op, dtensor):
  """Gradient function for TensorListGather.

  Scatters `dtensor` back to the gathered positions of a reserved list. That
  list has the input list's length and element shape. The indices and
  element_shape inputs get no gradient.
  """
  input_list, indices, _ = op.inputs
  slot_shape = gen_list_ops.tensor_list_element_shape(
      input_list, shape_type=dtypes.int32)
  slot_count = gen_list_ops.tensor_list_length(input_list)
  reserved = tensor_list_reserve(slot_shape, slot_count, dtensor.dtype)
  list_grad = tensor_list_scatter(
      tensor=dtensor, indices=indices, input_handle=reserved)
  return list_grad, None, None

331 

332 

@ops.RegisterGradient("TensorListScatter")
@ops.RegisterGradient("TensorListScatterV2")
def _TensorListScatterGrad(op, dlist):
  """Gradient function for TensorListScatter.

  The tensor gradient is `dlist` gathered at the scattered indices. All other
  inputs get no gradient. The V1 op has three inputs and V2 has four, because
  V2 adds `num_elements`.
  """
  tensor, indices = op.inputs[0], op.inputs[1]
  dtensor = gen_list_ops.tensor_list_gather(
      dlist,
      indices,
      element_shape=array_ops.slice(array_ops.shape(tensor), [1], [-1]),
      element_dtype=tensor.dtype)
  num_none_grads = 3 if op.type == "TensorListScatterV2" else 2
  return (dtensor,) + (None,) * num_none_grads

348 

349 

@ops.RegisterGradient("TensorListScatterIntoExistingList")
def _TensorListScatterIntoExistingListGrad(op, dlist):
  """Gradient function for TensorListScatterIntoExistingList.

  Args:
    op: The forward op. Its inputs are (input_handle, tensor, indices).
    dlist: Gradient w.r.t. the output list.

  Returns:
    A tuple (list gradient, tensor gradient, None):
      * tensor gradient: `dlist` gathered at `indices`.
      * list gradient: `dlist` with the positions at `indices` overwritten by
        zeros, because the forward op replaced those input elements.
      * `indices` gets no gradient.
  """
  _, tensor, indices = op.inputs
  dtensor = gen_list_ops.tensor_list_gather(
      dlist,
      indices,
      element_shape=array_ops.slice(array_ops.shape(tensor), [1], [-1]),
      element_dtype=tensor.dtype)
  zeros = array_ops.zeros_like(tensor)
  # `element_shape` is not used when scattering into an existing list, so it
  # is left out. Previously `indices` was passed positionally into that slot
  # by mistake.
  dlist = tensor_list_scatter(zeros, indices, input_handle=dlist)
  return dlist, dtensor, None

362 

363 

def _build_element_shape(shape):
  """Converts shape to a format understood by list_ops for element_shape.

  If `shape` is already a `Tensor` it is returned as-is. We do not perform a
  type check here.

  If shape is None or a TensorShape with unknown rank, -1 is returned.

  If shape is a NumPy array or NumPy scalar, or describes a scalar (an empty
  list, a rank-0 TensorShape, or any other falsy value), it is converted to an
  int32 tensor. Note we do not directly return an empty list, since
  ops.convert_to_tensor would convert it to a float32 tensor. float32 is not a
  valid type for element_shape.

  If shape is a sequence of dims, None's in the list are replaced with -1.
  `tensor_shape.Dimension` entries are replaced with their value, or -1 if it
  is unknown. Tensor entries are kept as-is. We do not check the dtype of the
  other dims.

  Args:
    shape: Could be None, Tensor, TensorShape, a NumPy value or a list of dims
      (each dim could be a None, scalar, Dimension or Tensor).

  Returns:
    A None-free shape that can be converted to a tensor.
  """
  if isinstance(shape, ops.Tensor):
    return shape
  if isinstance(shape, tensor_shape.TensorShape):
    # `TensorShape.as_list` requires rank to be known; an unknown-rank
    # TensorShape is falsy and becomes None here.
    shape = shape.as_list() if shape else None
  # Shape is unknown.
  if shape is None:
    return -1
  # Shape is numpy array or a scalar. Force int32 so that an empty list does
  # not become a float32 tensor.
  if isinstance(shape, (np.ndarray, np.generic)) or not shape:
    return ops.convert_to_tensor(shape, dtype=dtypes.int32)

  # Shape is a sequence of dimensions. Convert None dims to -1.
  def convert(val):
    if val is None:
      return -1
    if isinstance(val, ops.Tensor):
      return val
    if isinstance(val, tensor_shape.Dimension):
      return val.value if val.value is not None else -1
    return val

  return [convert(d) for d in shape]