Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/composite_tensor.py: 59%
29 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tensor-like objects that are composed from tf.Tensors."""
17import abc
19from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
20from tensorflow.python.util import _pywrap_utils
21from tensorflow.python.util import nest
22from tensorflow.python.util.tf_export import tf_export
@tf_export("__internal__.CompositeTensor", v1=[])
class CompositeTensor(metaclass=abc.ABCMeta):
  """Abstract base class for Tensor-like objects that are composed from Tensors.

  Each `CompositeTensor` can be decomposed into a structured collection of
  component `tf.Tensor`s, and reconstructed from those components.

  The `tensorflow.python.util.nest` module has support for treating composite
  tensors as structure, which makes it easy to flatten and reconstruct
  composite tensors (or larger structures that contain composite tensors).
  E.g.:

  ```python
  ct = ...  # Create a composite tensor.
  flat_list_of_tensors = nest.flatten(ct, expand_composites=True)
  transformed_list_of_tensors = ...  # do something with the flat tensors.
  result = nest.pack_sequence_as(ct, transformed_list_of_tensors,
                                 expand_composites=True)
  ```
  """

  # `abc.abstractproperty` has been deprecated since Python 3.3; stacking
  # `property` on top of `abc.abstractmethod` keeps the attribute abstract.
  @property
  @abc.abstractmethod
  def _type_spec(self):
    """A `TypeSpec` describing the type of this value."""
    raise NotImplementedError(f"{type(self).__name__}._type_spec()")

  def _shape_invariant_to_type_spec(self, shape):
    """Returns a TypeSpec given a shape invariant (used by `tf.while_loop`).

    The base implementation always raises `NotImplementedError`.

    Args:
      shape: A `tf.TensorShape` object.  The shape invariant for this
        `CompositeTensor`, or `None` if a default shape invariant should be used
        (based on the value of this `CompositeTensor`).

    Returns:
      A nested structure whose values are `tf.TensorShape` objects, specifying
      the shape invariants for the tensors that comprise this `CompositeTensor`.

    Raises:
      NotImplementedError: Always, unless overridden by a subclass.
    """
    # NOTE(review): the summary line says a TypeSpec is returned while the
    # "Returns" section describes nested TensorShapes -- confirm which one
    # overriding subclasses are expected to produce.
    #
    # New TypeSpec subclasses generally do not need to implement this --
    # this method is used for backwards compatibility.  Users of tf.while_loop
    # can specify a type by passing in TypeSpec instead.
    raise NotImplementedError(
        f"{type(self).__name__}._shape_invariant_to_type_spec")

  def _consumers(self):
    """Returns a list of `Operation`s that consume this `CompositeTensor`.

    Each consumer appears once, in the order in which it is first encountered
    while walking the flattened components.

    Returns:
      A list of `Operation`s.

    Raises:
      RuntimeError: If this method is called while executing eagerly.
    """
    consumers = nest.flatten([
        component.consumers()
        for component in nest.flatten(self, expand_composites=True)
        # Skip components with no `graph` attribute or whose `graph` is None.
        if getattr(component, "graph", None) is not None
    ])
    # `dict.fromkeys` de-duplicates like `set` but keeps first-seen order, so
    # the returned list no longer depends on hash ordering.
    return list(dict.fromkeys(consumers))

  def __tf_tracing_type__(self, context):
    """Delegates to `self._type_spec.__tf_tracing_type__(context)`."""
    return self._type_spec.__tf_tracing_type__(context)

  def _convert_variables_to_tensors(self):
    """Converts ResourceVariable components to Tensors.

    Override this method to explicitly convert ResourceVariables embedded in the
    CompositeTensor to Tensors. By default, it returns the CompositeTensor
    unchanged.

    Returns:
      A CompositeTensor with all its ResourceVariable components converted to
      Tensors.
    """
    return self
# Register the `CompositeTensor` class with `_pywrap_utils` under the name
# "CompositeTensor".
_pywrap_utils.RegisterType("CompositeTensor", CompositeTensor)
def replace_composites_with_components(structure):
  """Recursively replaces CompositeTensors with their components.

  Args:
    structure: A `nest`-compatible structure, possibly containing composite
      tensors.

  Returns:
    A copy of `structure`, where each composite tensor has been replaced by
    its components.  The result will contain no composite tensors.
    Note that `nest.flatten(replace_composites_with_components(structure))`
    returns the same value as
    `nest.flatten(structure, expand_composites=True)`.
  """
  if isinstance(structure, CompositeTensor):
    # pylint: disable=protected-access
    components = structure._type_spec._to_components(structure)
    # pylint: enable=protected-access
    # Recurse so that composites nested inside the components are replaced too.
    return replace_composites_with_components(components)
  if not nest.is_nested(structure):
    # A non-composite, non-nested value is returned unchanged.
    return structure
  return nest.map_structure(
      replace_composites_with_components, structure, expand_composites=False)
def convert_variables_to_tensors(composite_tensor):
  """Converts ResourceVariable components of `composite_tensor` to Tensors.

  Thin module-level wrapper around the protected
  `_convert_variables_to_tensors` method of the argument.

  Args:
    composite_tensor: An object providing a `_convert_variables_to_tensors()`
      method, e.g. a `CompositeTensor`.

  Returns:
    Whatever `composite_tensor._convert_variables_to_tensors()` returns.
  """
  return composite_tensor._convert_variables_to_tensors()  # pylint: disable=protected-access
# TODO(edloper): Can we replace convert_to_tensor_or_xyz with just
# convert_to_tensor_or_composite?  Alternatively, should composite tensors
# register a dispatch override for tf.convert_to_tensor?

# Note about the internal encoding of composite tensors when they are "lowered"
# from Python objects to tensors. The usual encoding is "component encoding",
# which uses the dense tensors that represent a composite tensor.
# A second encoding, "batchable tensor list encoding", is used by datasets
# and map_fn; in addition to supporting batching, it can also use ops
# for encoding and decoding, e.g. for encoding/decoding to/from a
# single variant that represents a composite tensor. Some internal properties
# of type specs for composite tensors use `flat` as a nickname for
# "batchable tensor list encoding" (e.g. `flat_tensor_specs`).