Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/composite_tensor_gradient.py: 43%
54 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Gradient support for Composite Tensors."""
17import abc
18import sys
20from tensorflow.python.framework import composite_tensor
21from tensorflow.python.util import nest
24# pylint:disable=g-import-not-at-top
25if sys.version_info >= (3, 8):
26 from typing import Protocol
27 from typing import runtime_checkable
28else:
29 from typing_extensions import Protocol
30 from typing_extensions import runtime_checkable
31# pylint:enable=g-import-not-at-top
# TODO(xjun): Add CompositeTensorGradient support for SparseTensor,
# StructuredTensor, and MaskedTensor.
class CompositeTensorGradient(object, metaclass=abc.ABCMeta):
  """Class used to help compute gradients for CompositeTensors.

  This abstract base class defines two methods: `get_gradient_components`,
  which returns the components of a value that should be included in
  gradients; and `replace_gradient_components`, which replaces the gradient
  components in a value. These methods can be used to compute the gradient of
  a `y` with respect to `x` (`grad(y, x)`) as follows:

  * If `y` is a `CompositeTensor` with `CompositeTensorGradient` `cg` =
    `y.__composite_gradient__`, then `grad(y, x)` =
    `grad(cg.get_gradient_components(y), x)`.

  * If `x` is a `CompositeTensor` with `CompositeTensorGradient` `cg` =
    `x.__composite_gradient__`, then `grad(y, x)` =
    `cg.replace_gradient_components(x, grad(y, cg.get_gradient_components(x)))`.

  Both methods are abstract, so a subclass must override both of them before
  it can be instantiated.
  """

  @abc.abstractmethod
  def get_gradient_components(self, value):
    """Returns the components of `value` that should be included in gradients.

    This method may not call TensorFlow ops, since any new ops added to the
    graph would not be properly tracked by the gradient mechanisms.

    Args:
      value: A `CompositeTensor` value.

    Returns:
      A nested structure of `Tensor` or `IndexedSlices`.

    Raises:
      NotImplementedError: If a subclass delegates to this base
        implementation.
    """
    raise NotImplementedError(
        f"{type(self).__name__}.get_gradient_components()")

  @abc.abstractmethod
  def replace_gradient_components(self, value, component_grads):
    """Replaces the gradient components in `value` with `component_grads`.

    Args:
      value: A value with its gradient components compatible with
        `component_grads`.
      component_grads: A nested structure of `Tensor` or `IndexedSlices` or
        `None` (for unconnected gradients).

    Returns:
      A copy of `value`, where the components that should be included in
      gradients have been replaced by `component_grads`; or `None` (if
      `component_grads` includes `None`).

    Raises:
      NotImplementedError: If a subclass delegates to this base
        implementation.
    """
    raise NotImplementedError(
        f"{type(self).__name__}.replace_gradient_components()")
@runtime_checkable
class CompositeTensorGradientProtocol(Protocol):
  """Structural type for values that carry gradient support.

  A value matches this protocol when it exposes a `__composite_gradient__`
  member. Because the protocol is `runtime_checkable`, an `isinstance` check
  only verifies that the member is present; the member's type is not checked
  at runtime.
  """

  # Helper object that extracts/replaces this value's gradient components.
  __composite_gradient__: CompositeTensorGradient
class WithValuesCompositeTensorGradient(CompositeTensorGradient):
  """Gradient helper for composites exposing `values` and `with_values`.

  The gradient components of a value are its `values` attribute; rebuilding a
  value from new components is delegated to its `with_values` method.
  """

  def get_gradient_components(self, value):
    components = value.values
    return components

  def replace_gradient_components(self, value, component_grads):
    rebuilt = value.with_values(component_grads)
    return rebuilt
def _get_tensors_for_gradient(x):
  """Extracts the Tensors of `x` that gradients should be taken through.

  Non-composite values are returned unchanged. For a `CompositeTensor`, the
  components reported by its `__composite_gradient__` helper are expanded
  recursively, so nested composites are reduced to their own components too.

  Args:
    x: A `Tensor` or `CompositeTensor`.

  Returns:
    A `Tensor` or a nested structure of `Tensor`.

  Raises:
    ValueError: If `x` is a `CompositeTensor` that does not satisfy
      `CompositeTensorGradientProtocol`.
  """
  if isinstance(x, composite_tensor.CompositeTensor):
    if not isinstance(x, CompositeTensorGradientProtocol):
      raise ValueError(
          f"Type {type(x).__name__} is not supported as a gradient source or "
          "gradient target.")
    components = x.__composite_gradient__.get_gradient_components(x)
    # A composite may report itself as its own gradient component; in that
    # case it is returned as-is rather than recursed into.
    if components is not x:
      return nest.map_structure(_get_tensors_for_gradient, components)
  return x
def _replace_tensors_for_gradient(x, grad):
  """Rebuilds `x` with its differentiable tensors swapped for `grad`.

  This is the inverse of `_get_tensors_for_gradient`: a non-composite `x` is
  simply replaced by `grad`, while a `CompositeTensor` is reconstructed via its
  `__composite_gradient__` helper after recursively rebuilding any nested
  components.

  Args:
    x: A `Tensor` or `CompositeTensor`.
    grad: A nested structure of `Tensor`, with the same structure as the value
      returned by `_get_tensors_for_gradient(x)`.

  Returns:
    A `Tensor` or `CompositeTensor`, or `None` when the rebuilt gradient
    components are `None`.

  Raises:
    ValueError: If `x` is a `CompositeTensor` that does not satisfy
      `CompositeTensorGradientProtocol`.
  """
  if not isinstance(x, composite_tensor.CompositeTensor):
    return grad
  if not isinstance(x, CompositeTensorGradientProtocol):
    raise ValueError(
        f"Type {type(x).__name__} is not supported as a gradient source.")

  helper = x.__composite_gradient__
  own_components = helper.get_gradient_components(x)
  # When `x` is its own gradient component, `grad` already has the right
  # shape; otherwise rebuild each nested component, following the structure
  # of `own_components`.
  rebuilt_grads = (
      grad if own_components is x else nest.map_structure_up_to(
          own_components, _replace_tensors_for_gradient, own_components, grad))
  return (None if rebuilt_grads is None else
          helper.replace_gradient_components(x, rebuilt_grads))
def get_flat_tensors_for_gradients(xs):
  """Flattens `xs` into the list of Tensors that should be differentiated.

  Args:
    xs: A list of `Tensor`s or `CompositeTensor`s.

  Returns:
    A flat list of `Tensor`s. `Tensor` entries of `xs` pass through
    unchanged; each `CompositeTensor` entry contributes the flattened result
    of `_get_tensors_for_gradient` on it.
  """
  per_value_tensors = list(map(_get_tensors_for_gradient, xs))
  return nest.flatten(per_value_tensors)
def replace_flat_tensors_for_gradients(xs, flat_grads):
  """Rebuilds each value of `xs` using gradients taken from `flat_grads`.

  `flat_grads` is first regrouped to mirror the per-value structure produced
  by `_get_tensors_for_gradient`, then each group is substituted back into the
  corresponding entry of `xs`.

  Args:
    xs: A list of `Tensor`s or `CompositeTensor`s.
    flat_grads: A list of `Tensor`.

  Returns:
    A list of `Tensor` or `CompositeTensor`.
  """
  structures = list(map(_get_tensors_for_gradient, xs))
  nested_grads = nest.pack_sequence_as(structures, flat_grads)
  return list(map(_replace_tensors_for_gradient, xs, nested_grads))