Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/tape.py: 44%
41 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Gradient tape utilities."""
17from tensorflow.python import pywrap_tfe
class Tape:
  """Represents a gradient propagation trace.

  Holds a tape object created through `pywrap_tfe` (see `push_new_tape`) and
  forwards queries about it back to `pywrap_tfe`.
  """

  __slots__ = ["_tape"]

  def __init__(self, tape):
    # The wrapped tape object; module-level helpers hand it back to pywrap_tfe.
    self._tape = tape

  def watched_variables(self):
    """Returns what pywrap_tfe reports as the variables watched by this tape."""
    handle = self._tape
    return pywrap_tfe.TFE_Py_TapeWatchedVariables(handle)
def push_new_tape(persistent=False, watch_accessed_variables=True):
  """Pushes a new tape onto the tape stack.

  Args:
    persistent: forwarded unchanged to `pywrap_tfe.TFE_Py_TapeSetNew`.
    watch_accessed_variables: forwarded unchanged to
      `pywrap_tfe.TFE_Py_TapeSetNew`.

  Returns:
    A `Tape` wrapping the object returned by `TFE_Py_TapeSetNew`.
  """
  return Tape(
      pywrap_tfe.TFE_Py_TapeSetNew(persistent, watch_accessed_variables))
def push_tape(tape):
  """Pushes an existing tape onto the tape stack.

  Args:
    tape: a `Tape`; its wrapped object is handed to `TFE_Py_TapeSetAdd`.
  """
  raw_tape = tape._tape  # pylint: disable=protected-access
  pywrap_tfe.TFE_Py_TapeSetAdd(raw_tape)
def watch(tape, tensor):
  """Marks this tensor to be watched by the given tape.

  Args:
    tape: a `Tape`; its wrapped object is passed to `TFE_Py_TapeWatch`.
    tensor: the tensor to watch, forwarded unchanged.
  """
  raw_tape = tape._tape  # pylint: disable=protected-access
  pywrap_tfe.TFE_Py_TapeWatch(raw_tape, tensor)
def default_get_variables(variable):
  """Default watched-variable resolver: a variable resolves to itself.

  Args:
    variable: the variable to resolve.

  Returns:
    A new single-element list containing `variable`.
  """
  resolved = [variable]
  return resolved
# Resolver mapping a variable to the list of variables that should actually be
# watched / marked as accessed. Defaults to `default_get_variables` and can be
# replaced (once) via `register_watched_variable_resolver`; an example of
# overriding is getting the variables within a distributed context.
_variables_override = default_get_variables
def register_watched_variable_resolver(resolver):
  """Registers the resolver to be used to get the list of variables to watch.

  Only one custom resolver may be registered: the call fails if the default
  resolver (`default_get_variables`) has already been replaced.

  Args:
    resolver: callable, takes a Variable and returns a list of Variables that
      shall be watched.

  Raises:
    AssertionError: if a resolver other than `default_get_variables` is
      already registered.
  """
  global _variables_override
  # Explicit check rather than a bare `assert`, so the guard is not stripped
  # under `python -O`. AssertionError is kept for backward compatibility with
  # callers that expected the previous assert to fire.
  if _variables_override is not default_get_variables:
    raise AssertionError(
        "A watched-variable resolver is already registered: %r" %
        (_variables_override,))
  _variables_override = resolver
def watch_variable(tape, variable):
  """Marks this variable to be watched by the given tape.

  The variable is first expanded via the registered resolver
  (`_variables_override`). Each resulting variable is watched on `tape` and
  reported to `TFE_Py_VariableWatcherVariableAccessed`.

  Args:
    tape: a `Tape` whose wrapped object receives the watch calls.
    variable: the variable to resolve and watch.
  """
  resolved_vars = _variables_override(variable)
  for resolved in resolved_vars:
    pywrap_tfe.TFE_Py_TapeWatchVariable(tape._tape, resolved)  # pylint: disable=protected-access
    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(resolved)
def variable_accessed(variable):
  """Notifies all tapes in the stack that a variable has been accessed.

  The variable is expanded via the registered resolver first. Each resulting
  variable is reported to both `TFE_Py_TapeVariableAccessed` and
  `TFE_Py_VariableWatcherVariableAccessed`, in that order.

  Args:
    variable: variable to be watched.
  """
  resolved_vars = _variables_override(variable)
  for resolved in resolved_vars:
    pywrap_tfe.TFE_Py_TapeVariableAccessed(resolved)
    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(resolved)
def variables_accessed(variables):
  """Notifies all tapes in the stack that variables have been accessed.

  Only trainable variables are marked as accessed. Every trainable variable is
  resolved through the registered resolver before any notification is sent.

  Args:
    variables: iterable of variables to mark as accessed.
  """
  # Eager list: all resolver calls complete before the notification loop runs.
  resolved_vars = [
      resolved
      for candidate in variables
      if candidate.trainable
      for resolved in _variables_override(candidate)
  ]
  for resolved in resolved_vars:
    pywrap_tfe.TFE_Py_TapeVariableAccessed(resolved)
    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(resolved)
def pop_tape(tape):
  """Pops the given tape in the stack.

  Args:
    tape: a `Tape`; its wrapped object is handed to `TFE_Py_TapeSetRemove`.
  """
  raw_tape = tape._tape  # pylint: disable=protected-access
  pywrap_tfe.TFE_Py_TapeSetRemove(raw_tape)