Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/composite_tensor.py: 59%

29 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Tensor-like objects that are composed from tf.Tensors.""" 

16 

17import abc 

18 

19from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import 

20from tensorflow.python.util import _pywrap_utils 

21from tensorflow.python.util import nest 

22from tensorflow.python.util.tf_export import tf_export 

23 

24 

@tf_export("__internal__.CompositeTensor", v1=[])
class CompositeTensor(metaclass=abc.ABCMeta):
  """Abstract base class for Tensor-like objects that are composed from Tensors.

  Each `CompositeTensor` can be decomposed into a structured collection of
  component `tf.Tensor`s, and reconstructed from those components.

  The `tensorflow.python.util.nest` module has support for treating composite
  tensors as structure, which makes it easy to flatten and reconstruct
  composite tensors (or larger structures that contain composite tensors).
  E.g.:

  ```python
  ct = ...  # Create a composite tensor.
  flat_list_of_tensors = nest.flatten(ct, expand_composites=True)
  transformed_list_of_tensors = ...  # do something with the flat tensors.
  result = nest.pack_sequence_as(ct, transformed_list_of_tensors,
                                 expand_composites=True)
  ```
  """

  # `abc.abstractproperty` is deprecated (since Python 3.3); stacking
  # `property` on top of `abc.abstractmethod` is the documented replacement
  # and still prevents instantiating subclasses that do not define it.
  @property
  @abc.abstractmethod
  def _type_spec(self):
    """A `TypeSpec` describing the type of this value."""
    raise NotImplementedError(f"{type(self).__name__}._type_spec()")

  def _shape_invariant_to_type_spec(self, shape):
    """Returns a TypeSpec given a shape invariant (used by `tf.while_loop`).

    Args:
      shape: A `tf.TensorShape` object.  The shape invariant for this
        `CompositeTensor`, or `None` if a default shape invariant should be used
        (based on the value of this `CompositeTensor`).

    Returns:
      A nested structure whose values are `tf.TensorShape` objects, specifying
      the shape invariants for the tensors that comprise this `CompositeTensor`.

    Raises:
      NotImplementedError: Always, in this base implementation.
    """
    # New TypeSpec subclasses generally do not need to implement this --
    # this method is used for backwards compatibility.  Users of tf.while_loop
    # can specify a type by passing in TypeSpec instead.
    raise NotImplementedError(
        f"{type(self).__name__}._shape_invariant_to_type_spec")

  def _consumers(self):
    """Returns a list of `Operation`s that consume this `CompositeTensor`.

    The composite is fully flattened (`expand_composites=True`); every
    component that has a non-`None` `graph` attribute contributes its
    `consumers()`.  Components without a graph are skipped.

    Returns:
      A list of `Operation`s, without duplicates, in first-seen order.

    Raises:
      RuntimeError: If this method is called while executing eagerly.
        NOTE(review): components whose `graph` is missing or `None` are
        filtered out below, so confirm which path actually raises this.
    """
    consumers = nest.flatten([
        component.consumers()
        for component in nest.flatten(self, expand_composites=True)
        if getattr(component, "graph", None) is not None
    ])
    # De-duplicate while keeping first-seen order.  The previous
    # `list(set(...))` returned operations in arbitrary set-iteration order.
    return list(dict.fromkeys(consumers))

  def __tf_tracing_type__(self, context):
    """Returns the tracing type of this value.

    Delegates to `self._type_spec.__tf_tracing_type__(context)`.
    """
    return self._type_spec.__tf_tracing_type__(context)

  def _convert_variables_to_tensors(self):
    """Converts ResourceVariable components to Tensors.

    Override this method to explicitly convert ResourceVariables embedded in the
    CompositeTensor to Tensors. By default, it returns the CompositeTensor
    unchanged.

    Returns:
      A CompositeTensor with all its ResourceVariable components converted to
      Tensors.
    """
    return self


# Register the class with the native utilities under the name
# "CompositeTensor".
_pywrap_utils.RegisterType("CompositeTensor", CompositeTensor)

103 

104 

def replace_composites_with_components(structure):
  """Recursively replaces CompositeTensors with their components.

  Args:
    structure: A `nest`-compatible structure, possibly containing composite
      tensors.

  Returns:
    A copy of `structure`, where each composite tensor has been replaced by
    its components. The result will contain no composite tensors.
    Note that `nest.flatten(replace_composites_with_components(structure))`
    returns the same value as
    `nest.flatten(structure, expand_composites=True)`.  (Plain
    `nest.flatten(structure)` would keep composite tensors as leaves.)
  """
  if isinstance(structure, CompositeTensor):
    # The decomposed value may itself contain composite tensors, so recurse
    # on it rather than returning it directly.
    components = structure._type_spec._to_components(structure)  # pylint: disable=protected-access
    return replace_composites_with_components(components)
  if not nest.is_nested(structure):
    # A leaf that is not a composite tensor: nothing to expand.
    return structure
  # expand_composites=False makes map_structure pass composite tensors to this
  # function intact, where the first branch decomposes them.
  return nest.map_structure(
      replace_composites_with_components, structure, expand_composites=False)

126 

127 

def convert_variables_to_tensors(composite_tensor):
  """Converts ResourceVariable components of `composite_tensor` to Tensors.

  Public wrapper around the object's `_convert_variables_to_tensors` hook.

  Args:
    composite_tensor: An object exposing `_convert_variables_to_tensors`,
      presumably a `CompositeTensor`.

  Returns:
    Whatever the hook returns.  For the base `CompositeTensor` implementation
    that is the object itself.
  """
  convert_fn = composite_tensor._convert_variables_to_tensors  # pylint: disable=protected-access
  return convert_fn()

130 

131 

# TODO(edloper): Can we replace convert_to_tensor_or_xyz with just
# convert_to_tensor_or_composite? Alternatively, should composite tensors
# register a dispatch override for tf.convert_to_tensor?

135 

# Note about the internal encoding of composite tensors when they are
# "lowered" from Python objects to tensors. The usual encoding is "component
# encoding", which uses the dense tensors that represent a composite tensor.
# A second encoding, "batchable tensor list encoding", is used by datasets and
# map_fn. In addition to supporting batching, it can use ops for encoding and
# decoding, e.g. for encoding/decoding to/from a single variant that
# represents a composite tensor. Some internal properties of type specs for
# composite tensors use `flat` as a nickname for "batchable tensor list
# encoding" (e.g. `flat_tensor_specs`).