Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/util/structure.py: 34%

204 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utilities for describing the structure of a `tf.data` type.""" 

16import collections 

17import functools 

18import itertools 

19 

20import wrapt 

21 

22from tensorflow.python.data.util import nest 

23from tensorflow.python.framework import composite_tensor 

24from tensorflow.python.framework import ops 

25from tensorflow.python.framework import sparse_tensor 

26from tensorflow.python.framework import tensor_shape 

27from tensorflow.python.framework import tensor_spec 

28from tensorflow.python.framework import type_spec 

29from tensorflow.python.framework import type_spec_registry 

30from tensorflow.python.ops import resource_variable_ops 

31from tensorflow.python.ops import tensor_array_ops 

32from tensorflow.python.ops.ragged import ragged_tensor 

33from tensorflow.python.platform import tf_logging as logging 

34from tensorflow.python.types import internal 

35from tensorflow.python.util import deprecation 

36from tensorflow.python.util.compat import collections_abc 

37from tensorflow.python.util.tf_export import tf_export 

38 

39 

40# pylint: disable=invalid-name 

@tf_export(v1=["data.experimental.TensorStructure"])
@deprecation.deprecated(None, "Use `tf.TensorSpec` instead.")
def _TensorStructure(dtype, shape):
  """Deprecated alias that builds a `tf.TensorSpec`.

  Args:
    dtype: The dtype of the described tensor, forwarded to `tf.TensorSpec`.
    shape: The shape of the described tensor, forwarded to `tf.TensorSpec`.

  Returns:
    A `tf.TensorSpec` constructed from `shape` and `dtype`.
  """
  # This alias takes (dtype, shape) while `tf.TensorSpec` takes
  # (shape, dtype); pass by keyword so the mapping is explicit.
  return tensor_spec.TensorSpec(shape=shape, dtype=dtype)

45 

46 

@tf_export(v1=["data.experimental.SparseTensorStructure"])
@deprecation.deprecated(None, "Use `tf.SparseTensorSpec` instead.")
def _SparseTensorStructure(dtype, shape):
  """Deprecated alias that builds a `tf.SparseTensorSpec`.

  Args:
    dtype: The dtype of the described sparse tensor's values.
    shape: The dense shape of the described sparse tensor.

  Returns:
    A `tf.SparseTensorSpec` constructed from `shape` and `dtype`.
  """
  # Argument order is swapped relative to `tf.SparseTensorSpec`; use keywords.
  return sparse_tensor.SparseTensorSpec(shape=shape, dtype=dtype)

51 

52 

@tf_export(v1=["data.experimental.TensorArrayStructure"])
@deprecation.deprecated(None, "Use `tf.TensorArraySpec` instead.")
def _TensorArrayStructure(dtype, element_shape, dynamic_size, infer_shape):
  """Deprecated alias that builds a `tf.TensorArraySpec`.

  Args:
    dtype: The dtype of the elements stored in the `TensorArray`.
    element_shape: The shape of each element of the `TensorArray`.
    dynamic_size: Forwarded as the spec's `dynamic_size` setting.
    infer_shape: Forwarded as the spec's `infer_shape` setting.

  Returns:
    A `tf.TensorArraySpec` built from the given arguments.
  """
  return tensor_array_ops.TensorArraySpec(
      element_shape=element_shape,
      dtype=dtype,
      dynamic_size=dynamic_size,
      infer_shape=infer_shape)

58 

59 

@tf_export(v1=["data.experimental.RaggedTensorStructure"])
@deprecation.deprecated(None, "Use `tf.RaggedTensorSpec` instead.")
def _RaggedTensorStructure(dtype, shape, ragged_rank):
  """Deprecated alias that builds a `tf.RaggedTensorSpec`.

  Args:
    dtype: The dtype of the described ragged tensor's values.
    shape: The shape of the described ragged tensor.
    ragged_rank: Forwarded as the spec's `ragged_rank`.

  Returns:
    A `tf.RaggedTensorSpec` built from the given arguments.
  """
  return ragged_tensor.RaggedTensorSpec(
      shape=shape, dtype=dtype, ragged_rank=ragged_rank)

64# pylint: enable=invalid-name 

65 

66 

# TODO(jsimsa): Remove the special-case for `TensorArray` pass-through once
# it is a subclass of `CompositeTensor`.
def normalize_element(element, element_signature=None):
  """Normalizes a nested structure of element components.

  Each flattened component is handled according to its `TypeSpec` (either
  taken from `element_signature` or computed from the component itself):

  * Components matching `DatasetSpec` or `TensorArraySpec` are passed through.
  * Components matching `SparseTensorSpec` are converted to `SparseTensor`.
  * Components matching `RaggedTensorSpec` are converted to `RaggedTensor`
    (or `Tensor`, via `convert_to_tensor_or_ragged_tensor`).
  * Components matching `NoneTensorSpec` are replaced with `NoneTensor()`.
  * Components matching `VariableSpec` are converted to `Tensor`.
  * Other `CompositeTensor` components are passed through.
  * All other components are converted to `Tensor`.

  Args:
    element: A nested structure of individual components.
    element_signature: (Optional.) A nested structure of `tf.TypeSpec` objects
      corresponding to each component of `element`. If specified, it selects
      the handling of each component and, where the spec has a `dtype`, sets
      the exact type of the output tensor when converting components which
      are not tensors themselves (e.g. numpy arrays, native python types).

  Returns:
    A nested structure (shaped like `element_signature` if given, otherwise
    like `element`) of `Tensor`, `SparseTensor`, `RaggedTensor`,
    `TensorArray`, `Dataset`, `NoneTensor`, or other `CompositeTensor`
    objects.
  """
  if element_signature is None:
    components = nest.flatten(element)
    flattened_signature = [None] * len(components)
    pack_as = element
  else:
    flattened_signature = nest.flatten(element_signature)
    components = nest.flatten_up_to(element_signature, element)
    pack_as = element_signature

  normalized_components = []
  with ops.name_scope("normalize_element"):
    for index, (component, spec) in enumerate(
        zip(components, flattened_signature)):
      component_name = "component_%d" % index

      if spec is None:
        try:
          spec = type_spec_from_value(component, use_fallback=False)
        except TypeError:
          # No `TypeSpec` could be computed for the value; fall back to a
          # plain tensor conversion.
          normalized_components.append(
              ops.convert_to_tensor(component, name=component_name))
          continue

      # `DatasetSpec` is matched by class name rather than `isinstance` to
      # avoid a circular dependency between dataset_ops and this module.
      if spec.__class__.__name__ == "DatasetSpec":
        normalized = component
      elif isinstance(spec, sparse_tensor.SparseTensorSpec):
        normalized = sparse_tensor.SparseTensor.from_value(component)
      elif isinstance(spec, ragged_tensor.RaggedTensorSpec):
        normalized = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            component, name=component_name)
      elif isinstance(spec, tensor_array_ops.TensorArraySpec):
        normalized = component
      elif isinstance(spec, NoneTensorSpec):
        normalized = NoneTensor()
      elif isinstance(spec, resource_variable_ops.VariableSpec):
        normalized = ops.convert_to_tensor(
            component, name=component_name, dtype=spec.dtype)
      elif isinstance(component, composite_tensor.CompositeTensor):
        normalized = component
      else:
        normalized = ops.convert_to_tensor(
            component,
            name=component_name,
            dtype=getattr(spec, "dtype", None))
      normalized_components.append(normalized)

  return nest.pack_sequence_as(pack_as, normalized_components)

135 

136 

def convert_legacy_structure(output_types, output_shapes, output_classes):
  """Returns a `Structure` that represents the given legacy structure.

  This method provides a way to convert from the existing `Dataset` and
  `Iterator` structure-related properties to a `Structure` object. A "legacy"
  structure is represented by the `tf.data.Dataset.output_types`,
  `tf.data.Dataset.output_shapes`, and `tf.data.Dataset.output_classes`
  properties.

  TODO(b/110122868): Remove this function once `Structure` is used throughout
  `tf.data`.

  Args:
    output_types: A nested structure of `tf.DType` objects corresponding to
      each component of a structured value.
    output_shapes: A nested structure of `tf.TensorShape` objects
      corresponding to each component of a structured value. For
      `tf.TensorArray` components, the first two dimensions carry
      `dynamic_size` and `infer_shape` and the remaining dimensions are the
      element shape.
    output_classes: A nested structure of Python `type` objects (or
      `tf.TypeSpec` objects, which are passed through unchanged)
      corresponding to each component of a structured value.

  Returns:
    A nested structure of `tf.TypeSpec` objects with the same structure as
    `output_classes`.

  Raises:
    TypeError: If a structure cannot be built from the arguments, because one of
      the component classes in `output_classes` is not supported, including
      when a component is neither a `tf.TypeSpec` nor a class.
  """
  flat_types = nest.flatten(output_types)
  flat_shapes = nest.flatten(output_shapes)
  flat_classes = nest.flatten(output_classes)
  flat_ret = []
  for flat_type, flat_shape, flat_class in zip(flat_types, flat_shapes,
                                                flat_classes):
    if isinstance(flat_class, type_spec.TypeSpec):
      flat_ret.append(flat_class)
      continue

    # `issubclass` raises a bare "arg 1 must be a class" TypeError for
    # non-class inputs (e.g. `None`); route those to the descriptive error
    # below instead.
    is_class = isinstance(flat_class, type)
    if is_class and issubclass(flat_class, sparse_tensor.SparseTensor):
      flat_ret.append(sparse_tensor.SparseTensorSpec(flat_shape, flat_type))
    elif is_class and issubclass(flat_class, ops.Tensor):
      flat_ret.append(tensor_spec.TensorSpec(flat_shape, flat_type))
    elif is_class and issubclass(flat_class, tensor_array_ops.TensorArray):
      # We sneaked the dynamic_size and infer_shape into the legacy shape.
      flat_ret.append(
          tensor_array_ops.TensorArraySpec(
              flat_shape[2:], flat_type,
              dynamic_size=tensor_shape.dimension_value(flat_shape[0]),
              infer_shape=tensor_shape.dimension_value(flat_shape[1])))
    else:
      # NOTE(mrry): Since legacy structures produced by iterators only
      # comprise Tensors, SparseTensors, and nests, we do not need to
      # support all structure types here.
      class_name = getattr(flat_class, "__name__", repr(flat_class))
      raise TypeError(
          "Could not build a structure for output class {}. Make sure any "
          "component class in `output_classes` inherits from one of the "
          "following classes: `tf.TypeSpec`, `tf.sparse.SparseTensor`, "
          "`tf.Tensor`, `tf.TensorArray`.".format(class_name))

  return nest.pack_sequence_as(output_classes, flat_ret)

194 

195 

def _from_tensor_list_helper(decode_fn, element_spec, tensor_list):
  """Rebuilds an element from its flat tensor-list encoding.

  The tensor list is split into consecutive slices, one per flattened
  component spec, each as long as that spec's `_flat_tensor_specs`. Each
  slice is decoded with `decode_fn`, and the results are packed back into the
  structure of `element_spec`.

  Args:
    decode_fn: Callable `(component_spec, tensors) -> component` that builds a
      single element component from its spec and its slice of `tensor_list`.
    element_spec: A nested structure of `tf.TypeSpec` objects describing the
      element type.
    tensor_list: A list of tensors to use for constructing the value.

  Returns:
    An element constructed from the given spec and tensor list.

  Raises:
    ValueError: If the number of tensors needed to construct an element for
      the given spec does not match the given number of tensors.
  """
  # pylint: disable=protected-access
  flat_specs = nest.flatten(element_spec)
  flat_spec_lengths = [len(spec._flat_tensor_specs) for spec in flat_specs]

  expected_count = sum(flat_spec_lengths)
  if expected_count != len(tensor_list):
    raise ValueError("Expected {} tensors but got {}.".format(
        expected_count, len(tensor_list)))

  # offsets[k]:offsets[k + 1] is the slice of `tensor_list` owned by the k-th
  # flattened component spec.
  offsets = [0] + list(itertools.accumulate(flat_spec_lengths))
  decoded_components = [
      decode_fn(component_spec, tensor_list[begin:end])
      for component_spec, begin, end in zip(flat_specs, offsets, offsets[1:])
  ]
  return nest.pack_sequence_as(element_spec, decoded_components)

229 

230 

def from_compatible_tensor_list(element_spec, tensor_list):
  """Rebuilds an element from a tensor list already known to be compatible.

  Each flattened component spec decodes its slice of `tensor_list` through
  `TypeSpec._from_compatible_tensor_list`.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects describing the
      element type.
    tensor_list: A list of tensors to use for constructing the value.

  Returns:
    An element constructed from the given spec and tensor list.

  Raises:
    ValueError: If the number of tensors needed to construct an element for
      the given spec does not match the given number of tensors.
  """

  def _decode_component(component_spec, tensors):
    return component_spec._from_compatible_tensor_list(tensors)  # pylint: disable=protected-access

  return _from_tensor_list_helper(_decode_component, element_spec,
                                  tensor_list)

252 

253 

def from_tensor_list(element_spec, tensor_list):
  """Rebuilds an element from a tensor list, checking compatibility.

  Each flattened component spec decodes its slice of `tensor_list` through
  `TypeSpec._from_tensor_list`.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects describing the
      element type.
    tensor_list: A list of tensors to use for constructing the value.

  Returns:
    An element constructed from the given spec and tensor list.

  Raises:
    ValueError: If the number of tensors needed to construct an element for
      the given spec does not match the given number of tensors or the given
      spec is not compatible with the tensor list.
  """

  def _decode_component(component_spec, tensors):
    return component_spec._from_tensor_list(tensors)  # pylint: disable=protected-access

  return _from_tensor_list_helper(_decode_component, element_spec,
                                  tensor_list)

276 

277 

def get_flat_tensor_specs(element_spec):
  """Lists the `tf.TypeSpec`s of an element's flat tensor representation.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects describing the
      element type.

  Returns:
    A list of `tf.TypeSpec`s: the `_flat_tensor_specs` of every flattened
    component spec, concatenated in flattening order.
  """
  flat_tensor_specs = []
  for component_spec in nest.flatten(element_spec):
    flat_tensor_specs.extend(component_spec._flat_tensor_specs)  # pylint: disable=protected-access
  return flat_tensor_specs

293 

294 

def get_flat_tensor_shapes(element_spec):
  """Lists the `tf.TensorShape`s of an element's flat tensor representation.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects describing the
      element type.

  Returns:
    A list with the `shape` of each spec returned by `get_flat_tensor_specs`.
  """
  flat_specs = get_flat_tensor_specs(element_spec)
  return [flat_spec.shape for flat_spec in flat_specs]

306 

307 

def get_flat_tensor_types(element_spec):
  """Lists the `tf.DType`s of an element's flat tensor representation.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects describing the
      element type.

  Returns:
    A list with the `dtype` of each spec returned by `get_flat_tensor_specs`.
  """
  flat_specs = get_flat_tensor_specs(element_spec)
  return [flat_spec.dtype for flat_spec in flat_specs]

319 

320 

def _to_tensor_list_helper(encode_fn, element_spec, element):
  """Returns a tensor list representation of the element.

  The flattened `(spec, component)` pairs are folded left-to-right with
  `encode_fn`, starting from an empty list. Components whose spec is an
  `internal.TensorSpec` are first converted to a tensor of the spec's dtype
  and checked for shape compatibility.

  Args:
    encode_fn: Callable `(state, spec, component) -> new_state` that appends
      the tensor list representation of `component` to `state`.
    element_spec: A nested structure of `tf.TypeSpec` objects describing the
      element type.
    element: The element to convert to tensor list representation.

  Returns:
    A tensor list representation of `element`.

  Raises:
    ValueError: If `element_spec` and `element` do not have the same number of
      elements, if the two structures are not nested in the same way, or if a
      component for a `TensorSpec` cannot be converted to a tensor with that
      spec's dtype or has an incompatible shape.
    TypeError: If `element_spec` and `element` differ in the type of sequence
      in any of their substructures.
  """

  nest.assert_same_structure(element_spec, element)

  def _incompatible_value_error(spec, value):
    # Single source for the error raised on dtype- or shape-incompatibility.
    return ValueError(
        f"Value {value} is not convertible to a tensor with "
        f"dtype {spec.dtype} and shape {spec.shape}."
    )

  def reduce_fn(state, value):
    spec, component = value
    if isinstance(spec, internal.TensorSpec):
      try:
        component = ops.convert_to_tensor(component, spec.dtype)
      except (TypeError, ValueError) as e:
        # Chain explicitly so the underlying conversion failure is reported
        # as the cause.
        raise _incompatible_value_error(spec, component) from e
      if not component.shape.is_compatible_with(spec.shape):
        raise _incompatible_value_error(spec, component)
    return encode_fn(state, spec, component)

  return functools.reduce(
      reduce_fn, zip(nest.flatten(element_spec), nest.flatten(element)), [])

362 

363 

def to_batched_tensor_list(element_spec, element):
  """Encodes a batched element as a flat list of tensors.

  Each component is encoded with `TypeSpec._to_batched_tensor_list`, and the
  per-component lists are concatenated in flattening order.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects describing the
      element type.
    element: The element to convert to tensor list representation.

  Returns:
    A tensor list representation of `element`.

  Raises:
    ValueError: If `element_spec` and `element` do not have the same number of
      elements or if the two structures are not nested in the same way or the
      rank of any of the tensors in the tensor list representation is 0.
    TypeError: If `element_spec` and `element` differ in the type of sequence
      in any of their substructures.
  """

  def _append_batched(tensors, component_spec, component):
    return tensors + component_spec._to_batched_tensor_list(component)  # pylint: disable=protected-access

  return _to_tensor_list_helper(_append_batched, element_spec, element)

388 

389 

def to_tensor_list(element_spec, element):
  """Encodes an element as a flat list of tensors.

  Each component is encoded with `TypeSpec._to_tensor_list`, and the
  per-component lists are concatenated in flattening order.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects describing the
      element type.
    element: The element to convert to tensor list representation.

  Returns:
    A tensor list representation of `element`.

  Raises:
    ValueError: If `element_spec` and `element` do not have the same number of
      elements or if the two structures are not nested in the same way.
    TypeError: If `element_spec` and `element` differ in the type of sequence
      in any of their substructures.
  """

  def _append_encoded(tensors, component_spec, component):
    return tensors + component_spec._to_tensor_list(component)  # pylint: disable=protected-access

  return _to_tensor_list_helper(_append_encoded, element_spec, element)

413 

414 

def are_compatible(spec1, spec2):
  """Indicates whether two type specifications are compatible.

  Two type specifications are compatible if they have the same nested
  structure and their individual components are pair-wise compatible in both
  directions.

  Args:
    spec1: A (nested structure of) `tf.TypeSpec` object(s) to compare.
    spec2: A (nested structure of) `tf.TypeSpec` object(s) to compare.

  Returns:
    `True` if the two type specifications are compatible and `False` otherwise.
  """
  try:
    nest.assert_same_structure(spec1, spec2)
  except (TypeError, ValueError):
    # Structural mismatch of any kind means the specs are incompatible.
    return False

  return all(
      left.is_compatible_with(right) and right.is_compatible_with(left)
      for left, right in zip(nest.flatten(spec1), nest.flatten(spec2)))

440 

441 

def type_spec_from_value(element, use_fallback=True):
  """Creates a type specification for the given value.

  Mappings, tuples (including namedtuples) and `attr.s`-decorated objects are
  handled recursively, producing a container of the same kind whose leaves are
  `TypeSpec`s.

  Args:
    element: The element to create the type specification for.
    use_fallback: Whether to fall back to converting the element to a tensor
      in order to compute its `TypeSpec`. The setting applies to nested
      components as well as to `element` itself.

  Returns:
    A nested structure of `TypeSpec`s that represents the type specification
    of `element`.

  Raises:
    TypeError: If a `TypeSpec` cannot be built for `element` (or one of its
      nested components), because its type is not supported.
  """
  spec = type_spec._type_spec_from_value(element)  # pylint: disable=protected-access
  if spec is not None:
    return spec

  def _nested_spec(value):
    # Propagate `use_fallback` so the flag is honored for every component,
    # not only for the top-level value.
    return type_spec_from_value(value, use_fallback=use_fallback)

  if isinstance(element, collections_abc.Mapping):
    # We create a shallow copy in an attempt to preserve the key order.
    #
    # Note that we do not guarantee that the key order is preserved, which is
    # a limitation inherited from `copy()`. As a consequence, callers of
    # `type_spec_from_value` should not assume that the key order of a `dict`
    # in the returned nested structure matches the key order of the
    # corresponding `dict` in the input value.
    if isinstance(element, collections.defaultdict):
      ctor = lambda items: type(element)(element.default_factory, items)
    else:
      ctor = type(element)
    return ctor([(k, _nested_spec(v)) for k, v in element.items()])

  if isinstance(element, tuple):
    if hasattr(element, "_fields") and isinstance(
        element._fields, collections_abc.Sequence) and all(
            isinstance(f, str) for f in element._fields):
      if isinstance(element, wrapt.ObjectProxy):
        element_type = type(element.__wrapped__)
      else:
        element_type = type(element)
      # `element` is a namedtuple
      return element_type(*[_nested_spec(v) for v in element])
    # `element` is not a namedtuple
    return tuple([_nested_spec(v) for v in element])

  if hasattr(element.__class__, "__attrs_attrs__"):
    # `element` is an `attr.s` decorated class
    attrs = getattr(element.__class__, "__attrs_attrs__")
    return type(element)(*[
        _nested_spec(getattr(element, a.name)) for a in attrs
    ])

  if use_fallback:
    # As a fallback try converting the element to a tensor.
    try:
      tensor = ops.convert_to_tensor(element)
      spec = type_spec_from_value(tensor)
      if spec is not None:
        return spec
    except (ValueError, TypeError) as e:
      # Let the logger format lazily; the message is only needed at vlog 3.
      logging.vlog(3, "Failed to convert %r to tensor: %s",
                   type(element).__name__, e)

  raise TypeError("Could not build a `TypeSpec` for {} with type {}".format(
      element,
      type(element).__name__))

510 

511 

# TODO(b/149584798): Move this to framework and add tests for non-tf.data
# functionality.
class NoneTensor(composite_tensor.CompositeTensor):
  """Composite tensor that stands in for a Python `None` component.

  It carries no tensors; its type spec is always a fresh `NoneTensorSpec`.
  """

  @property
  def _type_spec(self):
    # Every `NoneTensor` is described by the same (stateless) spec type.
    return NoneTensorSpec()

520 

521 

# TODO(b/149584798): Move this to framework and add tests for non-tf.data
# functionality.
@type_spec_registry.register("tf.NoneTensorSpec")
class NoneTensorSpec(type_spec.BatchableTypeSpec):
  """Type specification for `None` value.

  The spec has no components and encodes to an empty tensor list. Batching
  and unbatching yield another `NoneTensorSpec`.
  """

  @property
  def value_type(self):
    return NoneTensor

  @staticmethod
  def from_value(value):
    """Returns a `NoneTensorSpec`; `value` is ignored."""
    return NoneTensorSpec()

  def _serialize(self):
    # No state to serialize.
    return ()

  # --- Component (de)composition: there are no components. ---

  @property
  def _component_specs(self):
    return []

  def _to_components(self, value):
    return []

  def _from_components(self, components):
    return None

  # --- Tensor-list encodings: always empty. ---

  def _to_tensor_list(self, value):
    return []

  def _to_batched_tensor_list(self, value):
    return []

  # --- Batching: the spec is unchanged by (un)batching. ---

  def _batch(self, batch_size):
    return NoneTensorSpec()

  def _unbatch(self):
    return NoneTensorSpec()

  # --- Legacy structure accessors all return the spec itself. ---

  def _to_legacy_output_types(self):
    return self

  def _to_legacy_output_shapes(self):
    return self

  def _to_legacy_output_classes(self):
    return self

  def most_specific_compatible_shape(self, other):
    """Returns `self` if `other` has exactly the same type, else raises."""
    if type(self) is type(other):
      return self
    raise ValueError("No `TypeSpec` is compatible with both {} and {}".format(
        self, other))

575 

576 

# Map Python `None` values to `NoneTensorSpec` when computing type specs.
type_spec.register_type_spec_from_value_converter(
    type(None), NoneTensorSpec.from_value)