Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/util/structure.py: 34%
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for describing the structure of a `tf.data` type."""
import collections
import functools
import itertools

import wrapt

from tensorflow.python.data.util import nest
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.framework import type_spec_registry
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.types import internal
from tensorflow.python.util import deprecation
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export


# pylint: disable=invalid-name
@tf_export(v1=["data.experimental.TensorStructure"])
@deprecation.deprecated(None, "Use `tf.TensorSpec` instead.")
def _TensorStructure(dtype, shape):
  return tensor_spec.TensorSpec(shape, dtype)


@tf_export(v1=["data.experimental.SparseTensorStructure"])
@deprecation.deprecated(None, "Use `tf.SparseTensorSpec` instead.")
def _SparseTensorStructure(dtype, shape):
  return sparse_tensor.SparseTensorSpec(shape, dtype)


@tf_export(v1=["data.experimental.TensorArrayStructure"])
@deprecation.deprecated(None, "Use `tf.TensorArraySpec` instead.")
def _TensorArrayStructure(dtype, element_shape, dynamic_size, infer_shape):
  return tensor_array_ops.TensorArraySpec(element_shape, dtype,
                                          dynamic_size, infer_shape)


@tf_export(v1=["data.experimental.RaggedTensorStructure"])
@deprecation.deprecated(None, "Use `tf.RaggedTensorSpec` instead.")
def _RaggedTensorStructure(dtype, shape, ragged_rank):
  return ragged_tensor.RaggedTensorSpec(shape, dtype, ragged_rank)
# pylint: enable=invalid-name
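
# Illustrative usage (a minimal sketch, assuming the v1 aliases above are
# exported as `tf.compat.v1.data.experimental.*`): the deprecated constructors
# are thin wrappers around the public `tf.TypeSpec` subclasses, so the two
# forms below are expected to be equivalent. Note the swapped argument order
# (dtype, shape) versus (shape, dtype).
#
#   import tensorflow as tf
#
#   spec_v1 = tf.compat.v1.data.experimental.TensorStructure(tf.int32, [3])
#   spec_v2 = tf.TensorSpec([3], tf.int32)
#   assert spec_v1 == spec_v2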


# TODO(jsimsa): Remove the special-case for `TensorArray` pass-through once
# it is a subclass of `CompositeTensor`.
def normalize_element(element, element_signature=None):
  """Normalizes a nested structure of element components.

  * Components matching `SparseTensorSpec` are converted to `SparseTensor`.
  * Components matching `RaggedTensorSpec` are converted to `RaggedTensor`.
  * Components matching `VariableSpec` are converted to `Tensor`.
  * Components matching `DatasetSpec` or `TensorArraySpec` are passed through.
  * `CompositeTensor` components are passed through.
  * All other components are converted to `Tensor`.

  Args:
    element: A nested structure of individual components.
    element_signature: (Optional.) A nested structure of `tf.DType` objects
      corresponding to each component of `element`. If specified, it will be
      used to set the exact type of the output tensor when converting input
      components which are not tensors themselves (e.g. numpy arrays, native
      python types, etc.)

  Returns:
    A nested structure of `Tensor`, `Variable`, `Dataset`, `SparseTensor`,
    `RaggedTensor`, or `TensorArray` objects.
  """
  normalized_components = []
  if element_signature is None:
    components = nest.flatten(element)
    flattened_signature = [None] * len(components)
    pack_as = element
  else:
    flattened_signature = nest.flatten(element_signature)
    components = nest.flatten_up_to(element_signature, element)
    pack_as = element_signature
  with ops.name_scope("normalize_element"):
    for i, (t, spec) in enumerate(zip(components, flattened_signature)):
      try:
        if spec is None:
          spec = type_spec_from_value(t, use_fallback=False)
      except TypeError:
        # TypeError indicates it was not possible to compute a `TypeSpec` for
        # the value. As a fallback try converting the value to a tensor.
        normalized_components.append(
            ops.convert_to_tensor(t, name="component_%d" % i))
      else:
        # To avoid a circular dependency between dataset_ops and structure,
        # we check the class name instead of using `isinstance`.
        if spec.__class__.__name__ == "DatasetSpec":
          normalized_components.append(t)
        elif isinstance(spec, sparse_tensor.SparseTensorSpec):
          normalized_components.append(sparse_tensor.SparseTensor.from_value(t))
        elif isinstance(spec, ragged_tensor.RaggedTensorSpec):
          normalized_components.append(
              ragged_tensor.convert_to_tensor_or_ragged_tensor(
                  t, name="component_%d" % i))
        elif isinstance(spec, tensor_array_ops.TensorArraySpec):
          normalized_components.append(t)
        elif isinstance(spec, NoneTensorSpec):
          normalized_components.append(NoneTensor())
        elif isinstance(spec, resource_variable_ops.VariableSpec):
          normalized_components.append(
              ops.convert_to_tensor(t, name=f"component_{i}", dtype=spec.dtype))
        elif isinstance(t, composite_tensor.CompositeTensor):
          normalized_components.append(t)
        else:
          dtype = getattr(spec, "dtype", None)
          normalized_components.append(
              ops.convert_to_tensor(t, name="component_%d" % i, dtype=dtype))
  return nest.pack_sequence_as(pack_as, normalized_components)
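
# Illustrative usage (a minimal sketch; the element values are made up for the
# example): plain Python and NumPy components become `Tensor`s, while
# composite values such as ragged tensors pass through with their own types.
#
#   from tensorflow.python.data.util import structure
#   import numpy as np
#   import tensorflow as tf
#
#   element = {"a": 1, "b": np.array([1.0, 2.0]),
#              "c": tf.ragged.constant([[1], [2, 3]])}
#   normalized = structure.normalize_element(element)
#   # Expected: "a" and "b" become tf.Tensor, "c" stays a tf.RaggedTensor.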


def convert_legacy_structure(output_types, output_shapes, output_classes):
  """Returns a `Structure` that represents the given legacy structure.

  This method provides a way to convert from the existing `Dataset` and
  `Iterator` structure-related properties to a `Structure` object. A "legacy"
  structure is represented by the `tf.data.Dataset.output_types`,
  `tf.data.Dataset.output_shapes`, and `tf.data.Dataset.output_classes`
  properties.

  TODO(b/110122868): Remove this function once `Structure` is used throughout
  `tf.data`.

  Args:
    output_types: A nested structure of `tf.DType` objects corresponding to
      each component of a structured value.
    output_shapes: A nested structure of `tf.TensorShape` objects
      corresponding to each component of a structured value.
    output_classes: A nested structure of Python `type` objects corresponding
      to each component of a structured value.

  Returns:
    A `Structure`.

  Raises:
    TypeError: If a structure cannot be built from the arguments, because one
      of the component classes in `output_classes` is not supported.
  """
  flat_types = nest.flatten(output_types)
  flat_shapes = nest.flatten(output_shapes)
  flat_classes = nest.flatten(output_classes)
  flat_ret = []
  for flat_type, flat_shape, flat_class in zip(flat_types, flat_shapes,
                                               flat_classes):
    if isinstance(flat_class, type_spec.TypeSpec):
      flat_ret.append(flat_class)
    elif issubclass(flat_class, sparse_tensor.SparseTensor):
      flat_ret.append(sparse_tensor.SparseTensorSpec(flat_shape, flat_type))
    elif issubclass(flat_class, ops.Tensor):
      flat_ret.append(tensor_spec.TensorSpec(flat_shape, flat_type))
    elif issubclass(flat_class, tensor_array_ops.TensorArray):
      # We sneaked the dynamic_size and infer_shape into the legacy shape.
      flat_ret.append(
          tensor_array_ops.TensorArraySpec(
              flat_shape[2:], flat_type,
              dynamic_size=tensor_shape.dimension_value(flat_shape[0]),
              infer_shape=tensor_shape.dimension_value(flat_shape[1])))
    else:
      # NOTE(mrry): Since legacy structures produced by iterators only
      # comprise Tensors, SparseTensors, and nests, we do not need to
      # support all structure types here.
      raise TypeError(
          "Could not build a structure for output class {}. Make sure any "
          "component class in `output_classes` inherits from one of the "
          "following classes: `tf.TypeSpec`, `tf.sparse.SparseTensor`, "
          "`tf.Tensor`, `tf.TensorArray`.".format(flat_class.__name__))

  return nest.pack_sequence_as(output_classes, flat_ret)
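
# Illustrative usage (a minimal sketch; the nested values below are made up
# for the example): converting a legacy (types, shapes, classes) triple for a
# two-component element into a nested structure of `tf.TypeSpec`s.
#
#   from tensorflow.python.data.util import structure
#   import tensorflow as tf
#
#   spec = structure.convert_legacy_structure(
#       output_types={"a": tf.float32, "b": tf.int64},
#       output_shapes={"a": tf.TensorShape([None, 3]),
#                      "b": tf.TensorShape([])},
#       output_classes={"a": tf.Tensor, "b": tf.sparse.SparseTensor})
#   # spec["a"] is expected to be tf.TensorSpec([None, 3], tf.float32) and
#   # spec["b"] a tf.SparseTensorSpec([], tf.int64).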


def _from_tensor_list_helper(decode_fn, element_spec, tensor_list):
  """Returns an element constructed from the given spec and tensor list.

  Args:
    decode_fn: Method that constructs an element component from the element
      spec component and a tensor list.
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.
    tensor_list: A list of tensors to use for constructing the value.

  Returns:
    An element constructed from the given spec and tensor list.

  Raises:
    ValueError: If the number of tensors needed to construct an element for
      the given spec does not match the given number of tensors.
  """

  # pylint: disable=protected-access

  flat_specs = nest.flatten(element_spec)
  flat_spec_lengths = [len(spec._flat_tensor_specs) for spec in flat_specs]
  if sum(flat_spec_lengths) != len(tensor_list):
    raise ValueError("Expected {} tensors but got {}.".format(
        sum(flat_spec_lengths), len(tensor_list)))

  i = 0
  flat_ret = []
  for (component_spec, num_flat_values) in zip(flat_specs, flat_spec_lengths):
    value = tensor_list[i:i + num_flat_values]
    flat_ret.append(decode_fn(component_spec, value))
    i += num_flat_values
  return nest.pack_sequence_as(element_spec, flat_ret)


def from_compatible_tensor_list(element_spec, tensor_list):
  """Returns an element constructed from the given spec and tensor list.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.
    tensor_list: A list of tensors to use for constructing the value.

  Returns:
    An element constructed from the given spec and tensor list.

  Raises:
    ValueError: If the number of tensors needed to construct an element for
      the given spec does not match the given number of tensors.
  """

  # pylint: disable=protected-access
  # pylint: disable=g-long-lambda
  return _from_tensor_list_helper(
      lambda spec, value: spec._from_compatible_tensor_list(value),
      element_spec, tensor_list)


def from_tensor_list(element_spec, tensor_list):
  """Returns an element constructed from the given spec and tensor list.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.
    tensor_list: A list of tensors to use for constructing the value.

  Returns:
    An element constructed from the given spec and tensor list.

  Raises:
    ValueError: If the number of tensors needed to construct an element for
      the given spec does not match the given number of tensors or the given
      spec is not compatible with the tensor list.
  """

  # pylint: disable=protected-access
  # pylint: disable=g-long-lambda
  return _from_tensor_list_helper(
      lambda spec, value: spec._from_tensor_list(value), element_spec,
      tensor_list)
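
# Illustrative usage (a minimal sketch; the spec and tensors are made up for
# the example): a `tf.SparseTensorSpec` is expected to flatten to a single
# variant tensor, so a (dense, sparse) pair is rebuilt from a two-tensor list.
# `from_compatible_tensor_list` is the unchecked variant that assumes the
# tensor list already matches the spec.
#
#   from tensorflow.python.data.util import structure
#   import tensorflow as tf
#
#   spec = (tf.TensorSpec([2], tf.int32), tf.SparseTensorSpec([3], tf.int64))
#   element = (tf.constant([1, 2]),
#              tf.sparse.SparseTensor([[0]], tf.constant([7], tf.int64), [3]))
#   flat = structure.to_tensor_list(spec, element)   # two tensors
#   rebuilt = structure.from_tensor_list(spec, flat)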


def get_flat_tensor_specs(element_spec):
  """Returns a list of `tf.TypeSpec`s for the element tensor representation.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.

  Returns:
    A list of `tf.TypeSpec`s for the element tensor representation.
  """

  # pylint: disable=protected-access
  return list(
      itertools.chain.from_iterable(
          spec._flat_tensor_specs for spec in nest.flatten(element_spec)))


def get_flat_tensor_shapes(element_spec):
  """Returns a list of `tf.TensorShape`s for the element tensor representation.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.

  Returns:
    A list of `tf.TensorShape`s for the element tensor representation.
  """
  return [spec.shape for spec in get_flat_tensor_specs(element_spec)]


def get_flat_tensor_types(element_spec):
  """Returns a list of `tf.DType`s for the element tensor representation.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.

  Returns:
    A list of `tf.DType`s for the element tensor representation.
  """
  return [spec.dtype for spec in get_flat_tensor_specs(element_spec)]
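
# Illustrative usage (a minimal sketch; the dataset below is made up for the
# example): the flat tensor representation of a dataset's `element_spec`.
#
#   from tensorflow.python.data.util import structure
#   import tensorflow as tf
#
#   ds = tf.data.Dataset.from_tensor_slices(
#       {"x": [[1.0, 2.0], [3.0, 4.0]], "y": [0, 1]})
#   flat_specs = structure.get_flat_tensor_specs(ds.element_spec)
#   # Expected: [tf.TensorSpec([2], tf.float32), tf.TensorSpec([], tf.int32)]
#   flat_types = structure.get_flat_tensor_types(ds.element_spec)    # dtypes
#   flat_shapes = structure.get_flat_tensor_shapes(ds.element_spec)  # shapes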


def _to_tensor_list_helper(encode_fn, element_spec, element):
  """Returns a tensor list representation of the element.

  Args:
    encode_fn: Method that constructs a tensor list representation from the
      given element spec and element.
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.
    element: The element to convert to tensor list representation.

  Returns:
    A tensor list representation of `element`.

  Raises:
    ValueError: If `element_spec` and `element` do not have the same number of
      elements or if the two structures are not nested in the same way.
    TypeError: If `element_spec` and `element` differ in the type of sequence
      in any of their substructures.
  """

  nest.assert_same_structure(element_spec, element)

  def reduce_fn(state, value):
    spec, component = value
    if isinstance(spec, internal.TensorSpec):
      try:
        component = ops.convert_to_tensor(component, spec.dtype)
      except (TypeError, ValueError):
        raise ValueError(
            f"Value {component} is not convertible to a tensor with "
            f"dtype {spec.dtype} and shape {spec.shape}."
        )
      if not component.shape.is_compatible_with(spec.shape):
        raise ValueError(
            f"Value {component} is not convertible to a tensor with "
            f"dtype {spec.dtype} and shape {spec.shape}."
        )
    return encode_fn(state, spec, component)

  return functools.reduce(
      reduce_fn, zip(nest.flatten(element_spec), nest.flatten(element)), [])


def to_batched_tensor_list(element_spec, element):
  """Returns a tensor list representation of the element.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.
    element: The element to convert to tensor list representation.

  Returns:
    A tensor list representation of `element`.

  Raises:
    ValueError: If `element_spec` and `element` do not have the same number of
      elements or if the two structures are not nested in the same way or the
      rank of any of the tensors in the tensor list representation is 0.
    TypeError: If `element_spec` and `element` differ in the type of sequence
      in any of their substructures.
  """

  # pylint: disable=protected-access
  # pylint: disable=g-long-lambda
  return _to_tensor_list_helper(
      lambda state, spec, component: state + spec._to_batched_tensor_list(
          component), element_spec, element)
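
# Illustrative usage (a minimal sketch; the values are made up for the
# example): the batched variant expects every component to carry a leading
# batch dimension, so a rank-0 component would be rejected.
#
#   from tensorflow.python.data.util import structure
#   import tensorflow as tf
#
#   spec = tf.TensorSpec([None, 2], tf.float32)
#   batch = tf.constant([[1.0, 2.0], [3.0, 4.0]])
#   flat = structure.to_batched_tensor_list(spec, batch)  # [<2x2 float32>]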


def to_tensor_list(element_spec, element):
  """Returns a tensor list representation of the element.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.
    element: The element to convert to tensor list representation.

  Returns:
    A tensor list representation of `element`.

  Raises:
    ValueError: If `element_spec` and `element` do not have the same number of
      elements or if the two structures are not nested in the same way.
    TypeError: If `element_spec` and `element` differ in the type of sequence
      in any of their substructures.
  """

  # pylint: disable=protected-access
  # pylint: disable=g-long-lambda
  return _to_tensor_list_helper(
      lambda state, spec, component: state + spec._to_tensor_list(component),
      element_spec, element)
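
# Illustrative usage (a minimal sketch; the element is made up for the
# example): `to_tensor_list` and `from_tensor_list` are expected to round-trip
# an element through its flat tensor representation.
#
#   from tensorflow.python.data.util import structure
#   import tensorflow as tf
#
#   element = {"x": tf.constant([1.0, 2.0]),
#              "y": tf.ragged.constant([[1], []])}
#   spec = structure.type_spec_from_value(element)
#   flat = structure.to_tensor_list(spec, element)
#   rebuilt = structure.from_tensor_list(spec, flat)  # same nested structure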


def are_compatible(spec1, spec2):
  """Indicates whether two type specifications are compatible.

  Two type specifications are compatible if they have the same nested
  structure and their individual components are pair-wise compatible.

  Args:
    spec1: A `tf.TypeSpec` object to compare.
    spec2: A `tf.TypeSpec` object to compare.

  Returns:
    `True` if the two type specifications are compatible and `False` otherwise.
  """

  try:
    nest.assert_same_structure(spec1, spec2)
  except TypeError:
    return False
  except ValueError:
    return False

  for s1, s2 in zip(nest.flatten(spec1), nest.flatten(spec2)):
    if not s1.is_compatible_with(s2) or not s2.is_compatible_with(s1):
      return False
  return True
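
# Illustrative usage (a minimal sketch): compatibility is checked in both
# directions, so a fully-defined shape is compatible with a partially-defined
# one, while mismatched dtypes are not.
#
#   from tensorflow.python.data.util import structure
#   import tensorflow as tf
#
#   structure.are_compatible(tf.TensorSpec([3], tf.float32),
#                            tf.TensorSpec([None], tf.float32))  # True
#   structure.are_compatible(tf.TensorSpec([3], tf.float32),
#                            tf.TensorSpec([3], tf.int32))       # False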


def type_spec_from_value(element, use_fallback=True):
  """Creates a type specification for the given value.

  Args:
    element: The element to create the type specification for.
    use_fallback: Whether to fall back to converting the element to a tensor
      in order to compute its `TypeSpec`.

  Returns:
    A nested structure of `TypeSpec`s that represents the type specification
    of `element`.

  Raises:
    TypeError: If a `TypeSpec` cannot be built for `element`, because its type
      is not supported.
  """
  spec = type_spec._type_spec_from_value(element)  # pylint: disable=protected-access
  if spec is not None:
    return spec

  if isinstance(element, collections_abc.Mapping):
    # We create a shallow copy in an attempt to preserve the key order.
    #
    # Note that we do not guarantee that the key order is preserved, which is
    # a limitation inherited from `copy()`. As a consequence, callers of
    # `type_spec_from_value` should not assume that the key order of a `dict`
    # in the returned nested structure matches the key order of the
    # corresponding `dict` in the input value.
    if isinstance(element, collections.defaultdict):
      ctor = lambda items: type(element)(element.default_factory, items)
    else:
      ctor = type(element)
    return ctor([(k, type_spec_from_value(v)) for k, v in element.items()])

  if isinstance(element, tuple):
    if hasattr(element, "_fields") and isinstance(
        element._fields, collections_abc.Sequence) and all(
            isinstance(f, str) for f in element._fields):
      if isinstance(element, wrapt.ObjectProxy):
        element_type = type(element.__wrapped__)
      else:
        element_type = type(element)
      # `element` is a namedtuple
      return element_type(*[type_spec_from_value(v) for v in element])
    # `element` is not a namedtuple
    return tuple([type_spec_from_value(v) for v in element])

  if hasattr(element.__class__, "__attrs_attrs__"):
    # `element` is an `attr.s` decorated class
    attrs = getattr(element.__class__, "__attrs_attrs__")
    return type(element)(*[
        type_spec_from_value(getattr(element, a.name)) for a in attrs
    ])

  if use_fallback:
    # As a fallback try converting the element to a tensor.
    try:
      tensor = ops.convert_to_tensor(element)
      spec = type_spec_from_value(tensor)
      if spec is not None:
        return spec
    except (ValueError, TypeError) as e:
      logging.vlog(
          3, "Failed to convert %r to tensor: %s" % (type(element).__name__, e))

  raise TypeError("Could not build a `TypeSpec` for {} with type {}".format(
      element,
      type(element).__name__))
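
# Illustrative usage (a minimal sketch; the values are made up for the
# example): nested Python containers are mirrored in the returned structure,
# and plain Python/NumPy values go through the tensor-conversion fallback.
#
#   from tensorflow.python.data.util import structure
#   import tensorflow as tf
#
#   structure.type_spec_from_value(
#       {"id": 3, "features": tf.ragged.constant([[1, 2], [3]])})
#   # Expected: {"id": tf.TensorSpec([], tf.int32),
#   #            "features": tf.RaggedTensorSpec(...)}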


# TODO(b/149584798): Move this to framework and add tests for non-tf.data
# functionality.
class NoneTensor(composite_tensor.CompositeTensor):
  """Composite tensor representation for `None` value."""

  @property
  def _type_spec(self):
    return NoneTensorSpec()


# TODO(b/149584798): Move this to framework and add tests for non-tf.data
# functionality.
@type_spec_registry.register("tf.NoneTensorSpec")
class NoneTensorSpec(type_spec.BatchableTypeSpec):
  """Type specification for `None` value."""

  @property
  def value_type(self):
    return NoneTensor

  def _serialize(self):
    return ()

  @property
  def _component_specs(self):
    return []

  def _to_components(self, value):
    return []

  def _from_components(self, components):
    return

  def _to_tensor_list(self, value):
    return []

  @staticmethod
  def from_value(value):
    return NoneTensorSpec()

  def _batch(self, batch_size):
    return NoneTensorSpec()

  def _unbatch(self):
    return NoneTensorSpec()

  def _to_batched_tensor_list(self, value):
    return []

  def _to_legacy_output_types(self):
    return self

  def _to_legacy_output_shapes(self):
    return self

  def _to_legacy_output_classes(self):
    return self

  def most_specific_compatible_shape(self, other):
    if type(self) is not type(other):
      raise ValueError("No `TypeSpec` is compatible with both {} and {}".format(
          self, other))
    return self


type_spec.register_type_spec_from_value_converter(type(None),
                                                  NoneTensorSpec.from_value)
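
# Illustrative usage (a minimal sketch): with the converter registered above,
# a `None` component inside a `tf.data` element is described by
# `NoneTensorSpec` and normalized to a `NoneTensor` placeholder.
#
#   from tensorflow.python.data.util import structure
#
#   spec = structure.type_spec_from_value(None)        # NoneTensorSpec()
#   normalized = structure.normalize_element((None,))  # (NoneTensor(),)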