Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/composite_tensor_gradient.py: 43%

54 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Gradient support for Composite Tensors.""" 

16 

17import abc 

18import sys 

19 

20from tensorflow.python.framework import composite_tensor 

21from tensorflow.python.util import nest 

22 

23 

24# pylint:disable=g-import-not-at-top 

25if sys.version_info >= (3, 8): 

26 from typing import Protocol 

27 from typing import runtime_checkable 

28else: 

29 from typing_extensions import Protocol 

30 from typing_extensions import runtime_checkable 

31# pylint:enable=g-import-not-at-top 

32 

33 

34# TODO(xjun): Add CompositeTensorGradient support for SparseTensor, 

35# StructuredTensor, and MaskedTensor. 

class CompositeTensorGradient(object, metaclass=abc.ABCMeta):
  """Class used to help compute gradients for CompositeTensors.

  This abstract base class defines two methods: `get_gradient_components`, which
  returns the components of a value that should be included in gradients; and
  `replace_gradient_components`, which replaces the gradient components in a
  value. These methods can be used to compute the gradient of a `y` with
  respect to `x` (`grad(y, x)`) as follows:

  * If `y` is a `CompositeTensor` with `CompositeTensorGradient` `cg` =
    `y.__composite_gradient__`, then `grad(y, x)` =
    `grad(cg.get_gradient_components(y), x)`.

  * If `x` is a `CompositeTensor` with `CompositeTensorGradient` `cg` =
    `x.__composite_gradient__`, then `grad(y, x)` =
    `cg.replace_gradient_components(x, grad(y, cg.get_gradient_components(x)))`.

  Both methods are abstract, so this class cannot be instantiated directly;
  concrete subclasses must override both.
  """

  @abc.abstractmethod
  def get_gradient_components(self, value):
    """Returns the components of `value` that should be included in gradients.

    This method may not call TensorFlow ops, since any new ops added to the
    graph would not be properly tracked by the gradient mechanisms.

    Args:
      value: A `CompositeTensor` value.

    Returns:
      A nested structure of `Tensor` or `IndexedSlices`.

    Raises:
      NotImplementedError: Always, in this base implementation; subclasses
        must override this method.
    """
    raise NotImplementedError(
        f"{type(self).__name__}.get_gradient_components()")

  @abc.abstractmethod
  def replace_gradient_components(self, value, component_grads):
    """Replaces the gradient components in `value` with `component_grads`.

    Args:
      value: A value with its gradient components compatible with
        `component_grads`.
      component_grads: A nested structure of `Tensor` or `IndexedSlices` or
        `None` (for unconnected gradients).

    Returns:
      A copy of `value`, where the components that should be included in
      gradients have been replaced by `component_grads`; or `None` (if
      `component_grads` includes `None`).

    Raises:
      NotImplementedError: Always, in this base implementation; subclasses
        must override this method.
    """
    raise NotImplementedError(
        f"{type(self).__name__}.replace_gradient_components()")

87 

88 

@runtime_checkable
class CompositeTensorGradientProtocol(Protocol):
  """Structural type for values that carry a `CompositeTensorGradient`.

  The class is decorated with `runtime_checkable`, so it can be used in
  `isinstance` checks. Such a check tests only that the value exposes a
  `__composite_gradient__` attribute; the attribute's type is not verified.
  """

  # Helper used to extract and re-insert the differentiable components.
  __composite_gradient__: CompositeTensorGradient

93 

94 

class WithValuesCompositeTensorGradient(CompositeTensorGradient):
  """`CompositeTensorGradient` for types exposing `values` and `with_values`.

  The differentiable part of a value is read from its `values` attribute.
  Gradients are put back by calling `with_values` on the original value.
  """

  def get_gradient_components(self, value):
    """Returns `value.values`."""
    components = value.values
    return components

  def replace_gradient_components(self, value, component_grads):
    """Returns `value.with_values(component_grads)`."""
    rebuild = value.with_values
    return rebuild(component_grads)

103 

104 

def _get_tensors_for_gradient(x):
  """Extracts the tensors of `x` that gradients should flow through.

  Values that are not a `CompositeTensor` are returned unchanged. For a
  `CompositeTensor`, the components reported by its `__composite_gradient__`
  are expanded recursively, so nested composites are reduced to their own
  gradient components.

  Args:
    x: A `Tensor` or `CompositeTensor`.

  Returns:
    `x` itself, or a nested structure of `Tensor`.

  Raises:
    ValueError: If `x` is a `CompositeTensor` that does not implement
      `CompositeTensorGradientProtocol`.
  """
  if isinstance(x, composite_tensor.CompositeTensor):
    if isinstance(x, CompositeTensorGradientProtocol):
      components = x.__composite_gradient__.get_gradient_components(x)
      # A composite that reports itself as its own gradient component is
      # treated as a leaf. Mapping over it would call this function on the
      # same value again without end.
      if components is x:
        return x
      return nest.map_structure(_get_tensors_for_gradient, components)
    raise ValueError(
        f"Type {type(x).__name__} is not supported as a gradient source or "
        "gradient target.")
  return x

126 

127 

def _replace_tensors_for_gradient(x, grad):
  """Rebuilds `x` with its differentiable tensors swapped for `grad`.

  This is the counterpart of `_get_tensors_for_gradient`: `grad` is expected
  to have the same structure that function returns for `x`.

  Args:
    x: A `Tensor` or `CompositeTensor`.
    grad: A nested structure of `Tensor`, with the same structure as the value
      returned by `_get_tensors_for_gradient(x)`.

  Returns:
    The return value depends on `x` and the mapped gradient components:

    * If `x` is not a `CompositeTensor`, returns `grad`.
    * If the mapped gradient components are `None`, returns `None`.
    * Otherwise, returns the value built by
      `x.__composite_gradient__.replace_gradient_components`.

  Raises:
    ValueError: If `x` is a `CompositeTensor` that does not implement
      `CompositeTensorGradientProtocol`.
  """
  if isinstance(x, composite_tensor.CompositeTensor):
    if isinstance(x, CompositeTensorGradientProtocol):
      helper = x.__composite_gradient__
      components = helper.get_gradient_components(x)
      if components is x:
        mapped = grad
      else:
        # Map only down to the depth of `components`, so each (possibly
        # composite) component receives its matching sub-structure of `grad`.
        mapped = nest.map_structure_up_to(
            components, _replace_tensors_for_gradient, components, grad)
      if mapped is None:
        return None
      return helper.replace_gradient_components(x, mapped)
    raise ValueError(
        f"Type {type(x).__name__} is not supported as a gradient source.")
  return grad

157 

158 

def get_flat_tensors_for_gradients(xs):
  """Flattens `xs` into the list of tensors to differentiate.

  Each element is first passed through `_get_tensors_for_gradient`, so plain
  `Tensor`s stay as they are and `CompositeTensor`s are replaced by their
  gradient components. The resulting structures are then flattened together.

  Args:
    xs: A list of `Tensor`s or `CompositeTensor`s.

  Returns:
    A flat list of `Tensor`s.
  """
  # Build a real list: `nest.flatten` must see a sequence, not a lazy iterator.
  per_value = list(map(_get_tensors_for_gradient, xs))
  return nest.flatten(per_value)

171 

172 

def replace_flat_tensors_for_gradients(xs, flat_grads):
  """Replaces Tensors that should be differentiated in `xs` with `flat_grads`.

  This is the inverse of `get_flat_tensors_for_gradients`. `flat_grads` is
  packed into the per-element structures produced by
  `_get_tensors_for_gradient`, and each element of `xs` is then rebuilt with
  `_replace_tensors_for_gradient`.

  Args:
    xs: A list (or other iterable) of `Tensor`s or `CompositeTensor`s.
    flat_grads: A list of `Tensor` (entries may be `None` for unconnected
      gradients), ordered like `get_flat_tensors_for_gradients(xs)`.

  Returns:
    A list of `Tensor` or `CompositeTensor`.
  """
  # `xs` is traversed twice below: once to compute the structure and once to
  # pair each value with its gradient. Materialize it so that a one-shot
  # iterator is not exhausted by the first pass, which would otherwise make
  # this function silently return an empty list.
  xs = list(xs)
  xs_structure = [_get_tensors_for_gradient(x) for x in xs]
  grads = nest.pack_sequence_as(xs_structure, flat_grads)
  return [_replace_tensors_for_gradient(x, grad) for x, grad in zip(xs, grads)]