Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/tape.py: 44%

41 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Gradient tape utilities.""" 

16 

17from tensorflow.python import pywrap_tfe 

18 

19 

class Tape:
  """Represents a gradient propagation trace.

  Instances only hold an opaque tape handle (as produced by
  `pywrap_tfe.TFE_Py_TapeSetNew` in `push_new_tape`). The module-level helpers
  in this file hand that handle to the corresponding `pywrap_tfe` entry points.
  """

  # A single slot keeps instances free of a per-instance `__dict__`.
  __slots__ = ["_tape"]

  def __init__(self, tape):
    # Stored untouched; this class never inspects the handle itself.
    self._tape = tape

  def watched_variables(self):
    """Returns what `pywrap_tfe.TFE_Py_TapeWatchedVariables` reports for this tape."""
    handle = self._tape
    return pywrap_tfe.TFE_Py_TapeWatchedVariables(handle)

30 

31 

def push_new_tape(persistent=False, watch_accessed_variables=True):
  """Pushes a new tape onto the tape stack and returns it wrapped in a `Tape`.

  Args:
    persistent: forwarded unchanged as the first argument of
      `pywrap_tfe.TFE_Py_TapeSetNew`.
    watch_accessed_variables: forwarded unchanged as the second argument of
      `pywrap_tfe.TFE_Py_TapeSetNew`.

  Returns:
    A `Tape` wrapping the handle returned by `pywrap_tfe.TFE_Py_TapeSetNew`.
  """
  return Tape(pywrap_tfe.TFE_Py_TapeSetNew(persistent, watch_accessed_variables))

36 

37 

def push_tape(tape):
  """Pushes an already-existing tape back onto the tape stack.

  Args:
    tape: a `Tape`; its underlying handle is passed to
      `pywrap_tfe.TFE_Py_TapeSetAdd`.
  """
  handle = tape._tape  # pylint: disable=protected-access
  pywrap_tfe.TFE_Py_TapeSetAdd(handle)

41 

42 

def watch(tape, tensor):
  """Asks `tape` to watch `tensor`.

  Args:
    tape: the `Tape` whose handle is passed to `pywrap_tfe.TFE_Py_TapeWatch`.
    tensor: forwarded unchanged to `pywrap_tfe.TFE_Py_TapeWatch`.
  """
  tape_handle = tape._tape  # pylint: disable=protected-access
  pywrap_tfe.TFE_Py_TapeWatch(tape_handle, tensor)

46 

47 

def default_get_variables(variable):
  """Default variable resolver: resolves `variable` to just itself.

  Args:
    variable: the variable to resolve.

  Returns:
    A new single-element list containing `variable`.
  """
  resolved = [variable]
  return resolved

50 

# Resolver that expands a single variable into the list of variables that
# should actually be watched / marked as accessed. Defaults to
# `default_get_variables` (the variable itself). It can be replaced once via
# `register_watched_variable_resolver`, for example to get the variables within
# a distributed context.
_variables_override = default_get_variables

55 

56 

def register_watched_variable_resolver(resolver):
  """Registers the resolver to be used to get the list of variables to watch.

  Only one resolver may be registered. The call fails if the current resolver
  is anything other than `default_get_variables`.

  Args:
    resolver: callable, takes a Variable and returns a list of Variables that
      shall be watched.

  Raises:
    AssertionError: if a non-default resolver has already been registered.
  """
  global _variables_override
  # Explicit check instead of a bare `assert`, so the one-time-registration
  # guard is not stripped under `python -O`. AssertionError is kept so callers
  # that relied on the original `assert` see the same exception type.
  if _variables_override is not default_get_variables:
    raise AssertionError(
        "A watched-variable resolver has already been registered: "
        f"{_variables_override!r}")
  _variables_override = resolver

67 

68 

def watch_variable(tape, variable):
  """Marks `variable` to be watched by the given tape.

  `variable` is first expanded with the currently registered resolver
  (`_variables_override`). Each resolved variable is then passed to
  `pywrap_tfe.TFE_Py_TapeWatchVariable` together with the tape's handle, and
  to `pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed`.

  Args:
    tape: the `Tape` that should watch the variable(s).
    variable: the variable to resolve and watch.
  """
  for resolved in _variables_override(variable):
    # The handle is read per resolved variable, after resolution, exactly as
    # before; an empty resolution never touches `tape`.
    pywrap_tfe.TFE_Py_TapeWatchVariable(tape._tape, resolved)  # pylint: disable=protected-access
    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(resolved)

75 

76 

def variable_accessed(variable):
  """Notifies all tapes in the stack that a variable has been accessed.

  The variable is expanded with the currently registered resolver. Every
  resolved variable is passed to `pywrap_tfe.TFE_Py_TapeVariableAccessed` and
  `pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed`. No `trainable` filtering
  is done here, unlike `variables_accessed`.

  Args:
    variable: variable to be watched.
  """
  for resolved in _variables_override(variable):
    pywrap_tfe.TFE_Py_TapeVariableAccessed(resolved)
    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(resolved)

87 

88 

def variables_accessed(variables):
  """Notifies all tapes in the stack that variables have been accessed.

  Only variables whose `trainable` attribute is truthy are considered. Each of
  them is expanded with the registered resolver. Resolution of every variable
  finishes before any notification is sent.

  Args:
    variables: iterable of variables to mark as accessed.
  """
  resolved_vars = [
      resolved
      for candidate in variables
      if candidate.trainable
      for resolved in _variables_override(candidate)
  ]

  for resolved in resolved_vars:
    pywrap_tfe.TFE_Py_TapeVariableAccessed(resolved)
    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(resolved)

105 

106 

def pop_tape(tape):
  """Pops the given tape from the tape stack.

  Args:
    tape: a `Tape`; its underlying handle is passed to
      `pywrap_tfe.TFE_Py_TapeSetRemove`.
  """
  handle = tape._tape  # pylint: disable=protected-access
  pywrap_tfe.TFE_Py_TapeSetRemove(handle)