Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/forwardprop.py: 30%
141 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for forward-mode automatic differentiation."""
17import functools
18import threading
20from tensorflow.python import pywrap_tfe
21from tensorflow.python.eager import backprop
22from tensorflow.python.eager import backprop_util
23from tensorflow.python.eager import execute
24from tensorflow.python.eager import forwardprop_util
25from tensorflow.python.eager.polymorphic_function import tracing_compiler
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import tensor_shape
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops.parallel_for import control_flow_ops
32from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
33from tensorflow.python.platform import tf_logging as logging
34from tensorflow.python.util import nest
35from tensorflow.python.util.tf_export import tf_export
def _identity_jvp(attr_tuple, inputs, outputs, tangents):
  """Hand-written JVP for `Identity`: forwards each input tangent unchanged.

  Special-cased mostly for resource handles: transposing the backward function
  on the tape would require creating ones Tensors from handle data, which is
  error-prone (even with good handle data, partially defined shapes are an
  issue).
  """
  del attr_tuple, inputs, outputs  # Only the tangents matter here.
  return [array_ops.identity(tangent) for tangent in tangents]


def _read_variable_jvp(attr_tuple, inputs, outputs, tangents):
  """Hand-written JVP for `ReadVariableOp`: forwards each tangent unchanged.

  Like the `Identity` special case, this avoids having to create
  variable-shaped Tensors from resource handles.
  """
  del attr_tuple, inputs, outputs  # Only the tangents matter here.
  return [array_ops.identity(tangent) for tangent in tangents]


# Op name -> special-cased JVP function. Ops missing from this mapping get their
# JVP by transposing their backward function on a tape (see `_jvp_helper`).
_SPECIAL_CASES = {
    "Identity": _identity_jvp,
    "ReadVariableOp": _read_variable_jvp,
}
# Guards the read-modify-write increments of `_TRACE_COUNT` so that concurrent
# traces don't clobber each other's updates.
_TRACE_COUNT_CONSISTENCY_LOCK = threading.Lock()
# Op name -> number of times `_jvp_helper` has been traced for that op. Used to
# cap the number of traces caused by shape differences while still
# specializing on shapes where possible.
_TRACE_COUNT = dict()
def _jvp_helper(op_name, attr_tuple, inputs, outputs, tangents):
  """Computes a Jacobian-vector product for an op.

  Note that this function would be wasteful if executed eagerly. It runs the
  backward gradient function and throws away the result just to record its
  operations on a GradientTape. These unused ops are pruned away when this
  function is traced.

  Args:
    op_name: A string, the type of operation being executed.
    attr_tuple: Attributes of the operation.
    inputs: A flat list of input Tensors to the operation.
    outputs: A flat list of output Tensors from the operation.
    tangents: A flat list of Tensors, same shape as `inputs`.

  Returns:
    A flat list of tangents corresponding to `outputs`. Entries for outputs
    that are not trainable are None.
  """
  with _TRACE_COUNT_CONSISTENCY_LOCK:
    # Just make sure writes don't clobber each other's increments; reads in
    # _jvp_dispatch do not lock.
    _TRACE_COUNT[op_name] = _TRACE_COUNT.get(op_name, 0) + 1

  special_case = _SPECIAL_CASES.get(op_name, None)
  if special_case is not None:
    return special_case(attr_tuple, inputs, outputs, tangents)
  if not outputs:
    # tape.gradients([], inputs) doesn't make much sense
    return []
  # Generally inner GradientTapes won't function while outer accumulators are
  # recording. We temporarily reset forwardprop state to allow GradientTapes to
  # function here.
  with forwardprop_util.push_forwardprop_state():
    # Only trainable inputs participate; their tangents are kept in lockstep.
    trainable_inputs = []
    nontrivial_tangents = []
    for input_index, tensor in enumerate(inputs):
      if backprop_util.IsTrainable(tensor):
        trainable_inputs.append(tensor)
        nontrivial_tangents.append(tangents[input_index])

    # The JVP is obtained by transposing the op's backward function:
    # `backfunc_tape` yields the VJP grads = J^T u, where u is the placeholder
    # `forwardprop_aids`. Since grads is linear in u (so u's ones values are
    # irrelevant), differentiating grads with respect to u under
    # `transpose_tape`, with output_gradients=tangents, gives J @ tangents.
    with backprop.GradientTape() as transpose_tape:
      with backprop.GradientTape() as backfunc_tape:
        backfunc_tape.watch(trainable_inputs)
        execute.record_gradient(op_name, inputs, attr_tuple, outputs)

      forwardprop_aids = []
      trainable_outputs = []
      nontrivial_output_indices = []
      for output_index, output in enumerate(outputs):
        if backprop_util.IsTrainable(output):
          forwardprop_aids.append(
              array_ops.ones_like(output, name="unused_forwardprop_aid"))
          trainable_outputs.append(output)
          nontrivial_output_indices.append(output_index)

      transpose_tape.watch(forwardprop_aids)
      grads = backfunc_tape.gradient(
          trainable_outputs,
          trainable_inputs,
          forwardprop_aids,
          unconnected_gradients=UnconnectedGradients.ZERO)
    nontrivial_output_tangents = transpose_tape.gradient(
        grads, forwardprop_aids, output_gradients=nontrivial_tangents)
    # Scatter the trainable outputs' tangents back to their positions; the
    # remaining (non-trainable) outputs keep None.
    output_tangents = [None] * len(outputs)
    for index, tangent in zip(nontrivial_output_indices,
                              nontrivial_output_tangents):
      output_tangents[index] = tangent
    return output_tangents
def _jvp_helper_wrapper(op_name, attr_tuple, inputs, outputs, tangents,
                        use_batch):
  """Computes the JVP for an op, optionally vectorized over a tangent batch.

  Args:
    op_name: A string, the type of operation being executed.
    attr_tuple: Attributes of the operation.
    inputs: A flat list of input Tensors to the operation.
    outputs: A flat list of output Tensors from the operation.
    tangents: A flat list of Tensors, compatible with shape `[None] +
      input_shape`.
    use_batch: A bool, True to vectorize over a batch of tangents of shape
      `[None] + input_shape`.

  Returns:
    A flat list of tangents compatible with `outputs`
    or `[None] + output_shape`.

  Raises:
    ValueError: if tangent shapes are not compatible with input shapes.
  """
  if not use_batch:
    return _jvp_helper(op_name, attr_tuple, inputs, outputs, tangents)

  for primal, tangent in zip(inputs, tangents):
    # Each batched tangent carries one extra leading (batch) dimension.
    expected_shape = [None] + primal.shape
    if not tangent.shape.is_compatible_with(expected_shape):
      raise ValueError("Tangent {} was expected to be of shape "
                       "{} but is instead of shape {}".format(
                           tangent, expected_shape, tangent.shape))

  # Bind the op description once and map the single-tangent helper over the
  # leading dimension of every tangent.
  per_example_jvp = functools.partial(
      _jvp_helper, op_name, attr_tuple, inputs, outputs)
  return control_flow_ops.vectorized_map(per_example_jvp, tangents)
# The maximum number of exact-shape traces to perform for a single op before
# switching to shape relaxation.
_TRACE_COUNT_LIMIT = 32

# TODO(allenl): reduce_retracing for gradients which rely on static
# shape information are underspecialized. We may want hand-written forward
# implementations, or a more satisfying story about how we re-specialize
# gradients which were traced with relaxed shapes (e.g. use conds instead of
# trace-time Python logic).
#
# A TracingCompiler is used directly, rather than def_function.function, with
# the intent of not being affected by tf.config.run_functions_eagerly(True):
# `_jvp_helper` doesn't successfully run eagerly (infinite recursion), and even
# if it did it would use extra memory and run unnecessary computation. The
# function does not create variables, so otherwise the two would be
# equivalent.

# Shape-relaxed compiler, used once an op has exhausted its exact-shape budget.
_jvp_relaxed_shapes = tracing_compiler.TracingCompiler(
    _jvp_helper_wrapper,
    name="_jvp_relaxed_shapes",
    reduce_retracing=True,
)
# Exact-shape compiler, used while an op is under `_TRACE_COUNT_LIMIT` traces.
_jvp_exact_shapes = tracing_compiler.TracingCompiler(
    _jvp_helper_wrapper,
    name="_jvp_exact_shapes",
    reduce_retracing=False,
)
def _jvp_dispatch(op_name,
                  attr_tuple,
                  inputs,
                  outputs,
                  tangents,
                  use_batch=False):
  """Determine which forwardprop function to call.

  An op is traced with exact shapes until `_TRACE_COUNT` records
  `_TRACE_COUNT_LIMIT` traces for it; after that the shape-relaxed compiler is
  used instead.

  Args:
    op_name: A string, the type of operation being executed.
    attr_tuple: Attributes of the operation.
    inputs: A flat list of input Tensors to the operation.
    outputs: A flat list of output Tensors from the operation.
    tangents: A flat list of tangent Tensors for `inputs`.
    use_batch: A bool, True to vectorize over a batch of tangents.

  Returns:
    Whatever the selected compiled `_jvp_helper_wrapper` returns.
  """
  # This read of _TRACE_COUNT deliberately takes no lock and races with the
  # increments in _jvp_helper. That's fine: at worst a few extra exact-shape
  # traces happen before moving on to relaxation.
  traces_so_far = _TRACE_COUNT.get(op_name, 0)
  if traces_so_far >= _TRACE_COUNT_LIMIT:
    compiled_jvp = _jvp_relaxed_shapes
  else:
    compiled_jvp = _jvp_exact_shapes
  return compiled_jvp(op_name, attr_tuple, inputs, outputs, tangents,
                      use_batch)


# Hand the dispatcher to the pywrap_tfe runtime (presumably invoked back from
# C++ whenever an active accumulator needs a JVP).
pywrap_tfe.TFE_Py_RegisterJVPFunction(_jvp_dispatch)
@tf_export("autodiff.ForwardAccumulator", v1=[])
class ForwardAccumulator():
  """Computes Jacobian-vector products ("JVP"s) using forward-mode autodiff.

  Compare to `tf.GradientTape` which computes vector-Jacobian products ("VJP"s)
  using reverse-mode autodiff (backprop). Reverse mode is more attractive when
  computing gradients of a scalar-valued function with respect to many inputs
  (e.g. a neural network with many parameters and a scalar loss). Forward mode
  works best on functions with many outputs and few inputs. Since it does not
  hold on to intermediate activations, it is much more memory efficient than
  backprop where it is applicable.

  Consider a simple linear regression:

  >>> x = tf.constant([[2.0, 3.0], [1.0, 4.0]])
  >>> targets = tf.constant([[1.], [-1.]])
  >>> dense = tf.keras.layers.Dense(1)
  >>> dense.build([None, 2])
  >>> with tf.autodiff.ForwardAccumulator(
  ...    primals=dense.kernel,
  ...    tangents=tf.constant([[1.], [0.]])) as acc:
  ...   loss = tf.reduce_sum((dense(x) - targets) ** 2.)
  >>> acc.jvp(loss)
  <tf.Tensor: shape=(), dtype=float32, numpy=...>

  The example has two variables containing parameters, `dense.kernel` (2
  parameters) and `dense.bias` (1 parameter). Considering the training data `x`
  as a constant, this means the Jacobian matrix for the function mapping from
  parameters to loss has one row and three columns.

  With forwardprop, we specify a length-three vector in advance which multiplies
  the Jacobian. The `primals` constructor argument is the parameter (a
  `tf.Tensor` or `tf.Variable`) we're specifying a vector for, and the
  `tangents` argument is the "vector" in Jacobian-vector product. If our goal is
  to compute the entire Jacobian matrix, forwardprop computes one column at a
  time while backprop computes one row at a time. Since the Jacobian in the
  linear regression example has only one row, backprop requires fewer
  invocations:

  >>> x = tf.constant([[2.0, 3.0], [1.0, 4.0]])
  >>> targets = tf.constant([[1.], [-1.]])
  >>> dense = tf.keras.layers.Dense(1)
  >>> dense.build([None, 2])
  >>> loss_fn = lambda: tf.reduce_sum((dense(x) - targets) ** 2.)
  >>> kernel_fprop = []
  >>> with tf.autodiff.ForwardAccumulator(
  ...     dense.kernel, tf.constant([[1.], [0.]])) as acc:
  ...   kernel_fprop.append(acc.jvp(loss_fn()))
  >>> with tf.autodiff.ForwardAccumulator(
  ...     dense.kernel, tf.constant([[0.], [1.]])) as acc:
  ...   kernel_fprop.append(acc.jvp(loss_fn()))
  >>> with tf.autodiff.ForwardAccumulator(dense.bias, tf.constant([1.])) as acc:
  ...   bias_fprop = acc.jvp(loss_fn())
  >>> with tf.GradientTape() as tape:
  ...   loss = loss_fn()
  >>> kernel_grad, bias_grad = tape.gradient(loss, (dense.kernel, dense.bias))
  >>> np.testing.assert_allclose(
  ...     kernel_grad, tf.stack(kernel_fprop)[:, tf.newaxis])
  >>> np.testing.assert_allclose(bias_grad, bias_fprop[tf.newaxis])

  Implicit in the `tape.gradient` call is a length-one vector which
  left-multiplies the Jacobian, a vector-Jacobian product.

  `ForwardAccumulator` maintains JVPs corresponding to primal tensors it is
  watching, derived from the original `primals` specified in the constructor. As
  soon as a primal tensor is deleted, `ForwardAccumulator` deletes the
  corresponding JVP.

  `acc.jvp(x)` retrieves `acc`'s JVP corresponding to the primal tensor `x`. It
  does not perform any computation. `acc.jvp` calls can be repeated as long as
  `acc` is accessible, whether the context manager is active or not. New JVPs
  are only computed while the context manager is active.

  Note that `ForwardAccumulator`s are always applied in the order their context
  managers were entered, so inner accumulators will not see JVP computation from
  outer accumulators. Take higher-order JVPs from outer accumulators:

  >>> primal = tf.constant(1.1)
  >>> with tf.autodiff.ForwardAccumulator(primal, tf.constant(1.)) as outer:
  ...   with tf.autodiff.ForwardAccumulator(primal, tf.constant(1.)) as inner:
  ...     primal_out = primal ** tf.constant(3.5)
  >>> inner_jvp = inner.jvp(primal_out)
  >>> inner_jvp  # 3.5 * 1.1 ** 2.5
  <tf.Tensor: shape=(), dtype=float32, numpy=4.4417057>
  >>> outer.jvp(inner_jvp)  # 3.5 * 2.5 * 1.1 ** 1.5
  <tf.Tensor: shape=(), dtype=float32, numpy=10.094786>

  Reversing the collection in the last line to instead retrieve
  `inner.jvp(outer.jvp(primal_out))` will not work.

  Strict nesting also applies to combinations of `ForwardAccumulator` and
  `tf.GradientTape`. More deeply nested `GradientTape` objects will ignore the
  products of outer `ForwardAccumulator` objects. This allows (for example)
  memory-efficient forward-over-backward computation of Hessian-vector products,
  where the inner `GradientTape` would otherwise hold on to all intermediate
  JVPs:

  >>> v = tf.Variable([1., 2.])
  >>> with tf.autodiff.ForwardAccumulator(
  ...     v,
  ...     # The "vector" in Hessian-vector product.
  ...     tf.constant([1., 0.])) as acc:
  ...   with tf.GradientTape() as tape:
  ...     y = tf.reduce_sum(v ** 3.)
  ...   backward = tape.gradient(y, v)
  >>> backward  # gradient from backprop
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([ 3., 12.], dtype=float32)>
  >>> acc.jvp(backward)  # forward-over-backward Hessian-vector product
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([6., 0.], dtype=float32)>
  """

  def __init__(self, primals, tangents):
    """Specify tensors to watch and their Jacobian-vector products.

    Mathematically, `tangents` is a vector right-multiplying the Jacobian matrix
    (a Jacobian-vector product) for the function computed while this accumulator
    is active. Since JVPs are computed in forward mode as the computation
    happens, this vector must be supplied in advance.

    Listing a single tensor multiple times in `primals` raises an
    exception. Excluding a tensor from `primals` is equivalent to watching it
    with a tangent tensor of zeros.

    Args:
      primals: A tensor or nested structure of tensors to watch.
      tangents: A tensor or nested structure of tensors, with the same nesting
        structure as `primals`, with each element being a vector with the same
        size as the corresponding primal element.

    Raises:
      ValueError: If the same tensor or variable is specified multiple times in
        `primals`.
    """
    self._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew(False)
    self._recording = False
    primal_ids = set()
    for primal in nest.flatten(primals):
      self._check_primal_not_seen(primal, primal_ids)
    self._watch(primals, tangents)

  @staticmethod
  def _check_primal_not_seen(primal, primal_ids):
    """Adds `id(primal)` to `primal_ids`, raising if it was already present.

    Args:
      primal: A primal tensor or variable being registered.
      primal_ids: A set of `id()`s of primals registered so far; updated in
        place.

    Raises:
      ValueError: If `primal` (compared by identity) was already registered.
    """
    if id(primal) in primal_ids:
      # Format the offending primal into the message; previously the "{}"
      # placeholder was left unformatted.
      raise ValueError(
          "Tensor {} was specified as a primal multiple times. This may "
          "indicate an error. If it was intended, please sum the "
          "corresponding tangents.".format(primal))
    primal_ids.add(id(primal))

  def __enter__(self):
    self._push_accumulator()
    return self

  def __exit__(self, typ, value, traceback):
    if self._recording:
      self._pop_accumulator()

  def _push_accumulator(self):
    """Registers this accumulator as active; raises if already recording."""
    if self._recording:
      raise ValueError("Accumulator is already recording.")
    pywrap_tfe.TFE_Py_ForwardAccumulatorSetAdd(self._accumulator)
    self._recording = True

  def _pop_accumulator(self):
    """Deregisters this accumulator; raises if it is not recording."""
    if not self._recording:
      raise ValueError("Accumulator is not recording.")
    pywrap_tfe.TFE_Py_ForwardAccumulatorSetRemove(self._accumulator)
    self._recording = False

  def _watch(self, primals, tangents):
    """Ensures that `primals` are being traced by this accumulator.

    Mathematically, `tangents` is a vector right-multiplying the Jacobian matrix
    (a Jacobian-vector product) for the function computed while this accumulator
    is active. Since JVPs are computed in forward mode as the computation
    happens, this vector must be supplied in advance.

    Watching a single tensor multiple times sums each of its `tangents`. Any
    un-watched tensor has zeros for its tangent vector.

    Args:
      primals: A Tensor or list of Tensors.
      tangents: A Tensor or list of Tensors matching `primals`.
    """

    def _watch_one(primal, tangent):
      if not primal.dtype.is_floating:
        logging.log_first_n(
            logging.WARN, "The dtype of the watched primal must be "
            "floating (e.g. tf.float32), got %r", 5, primal.dtype)
      tangent = ops.convert_to_tensor(tangent, dtype=primal.dtype)
      if hasattr(primal, "handle"):
        # Run convert_to_tensor to get the captured handle from whichever
        # function we're running if necessary.
        primal = ops.convert_to_tensor(primal.handle)
      pywrap_tfe.TFE_Py_ForwardAccumulatorWatch(self._accumulator, primal,
                                                tangent)

    nest.map_structure(_watch_one, primals, tangents)

  def jvp(self, primals, unconnected_gradients=UnconnectedGradients.NONE):
    """Fetches the Jacobian-vector product computed for `primals`.

    Note that this method performs no computation, and simply looks up a JVP
    that was already computed (unlike backprop using a `tf.GradientTape`, where
    the computation happens on the call to `tape.gradient`).

    Args:
      primals: A watched Tensor or structure of Tensors to fetch the JVPs for.
      unconnected_gradients: A value which can either hold 'none' or 'zero' and
        alters the value which will be returned if no JVP was computed for
        `primals`. The possible values and effects are detailed in
        'tf.UnconnectedGradients' and it defaults to 'none'.

    Returns:
      Tensors with the same shapes and dtypes as `primals`, or None if no JVP
      is available.
    """
    unconnected_gradients = UnconnectedGradients(unconnected_gradients)
    if self._accumulator is None:
      raise ValueError("Called jvp() without first tracing anything.")

    def _fetch_jvp(tensor):
      # Variables are looked up through their handle, mirroring `_watch`.
      if hasattr(tensor, "handle"):
        unwrapped_tensor = ops.convert_to_tensor(tensor.handle)
      else:
        unwrapped_tensor = tensor
      result = pywrap_tfe.TFE_Py_ForwardAccumulatorJVP(self._accumulator,
                                                       unwrapped_tensor)
      if result is None and unconnected_gradients == UnconnectedGradients.ZERO:
        result = array_ops.zeros_like(tensor)
      return result

    return nest.map_structure(_fetch_jvp, primals)

  @classmethod
  def _batch_accumulator(cls, primals, tangents):
    """Factory constructor to test accumulator on batches of tangents.

    Args:
      primals: A tensor or nested structure of tensors to watch.
      tangents: A tensor or nested structure of tensors, with the same nesting
        structure as `primals`, with each element being a vector with compatible
        shape `[None] + primal.shape` of the corresponding primal element.

    Returns:
      A batch accumulator object.

    Raises:
      ValueError: If the same tensor or variable is specified multiple times in
        `primals`, or if a tangent's shape is incompatible with
        `[None] + primal.shape`.
    """
    # Bypass __init__ (which would create a non-batch accumulator). No extra
    # arguments are passed to object.__new__.
    acc = super(ForwardAccumulator, cls).__new__(cls)
    acc._recording = False
    acc._accumulator = pywrap_tfe.TFE_Py_ForwardAccumulatorNew(True)
    primal_ids = set()
    for primal, tangent in zip(nest.flatten(primals), nest.flatten(tangents)):
      tangent.shape.assert_is_compatible_with(
          tensor_shape.TensorShape([None]) + primal.shape)
      cls._check_primal_not_seen(primal, primal_ids)
    acc._watch(primals, tangents)
    return acc