Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/trackable/autotrackable.py: 27%

64 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Dependency tracking for trackable objects.""" 

16 

17import warnings 

18 

19from absl import logging 

20 

21from tensorflow.python.eager import def_function 

22from tensorflow.python.eager import function as defun 

23from tensorflow.python.trackable import base 

24from tensorflow.python.trackable import data_structures 

25from tensorflow.python.types import core as core_types 

26from tensorflow.python.util.tf_export import tf_export 

27 

28 

@tf_export("__internal__.tracking.AutoTrackable", v1=[])
class AutoTrackable(base.Trackable):
  """Manages dependencies on other objects.

  `Trackable` objects may have dependencies: other `Trackable` objects
  which should be saved if the object declaring the dependency is saved. A
  correctly saveable program has a dependency graph such that if changing a
  global variable affects an object (e.g. changes the behavior of any of its
  methods) then there is a chain of dependencies from the influenced object to
  the variable.

  Dependency edges have names, and are created implicitly when a
  `Trackable` object is assigned to an attribute of another
  `Trackable` object. For example:

  ```
  obj = Trackable()
  obj.v = ResourceVariable(0.)
  ```

  The `Trackable` object `obj` now has a dependency named "v" on a
  variable.

  `Trackable` objects may specify `Tensor`s to be saved and restored
  directly (e.g. a `Variable` indicating how to save itself) rather than through
  dependencies on other objects. See
  `Trackable._gather_saveables_for_checkpoint` for details.
  """

  def __setattr__(self, name, value):
    """Support self.foo = trackable syntax.

    Re-assigning the object already stored under `name` is a no-op. Otherwise,
    unless the instance's `_self_setattr_tracking` attribute is falsy (it is
    treated as `True` when absent), `value` is first passed through
    `data_structures.sticky_attribute_assignment` and the object returned by
    that call is what gets stored.

    Args:
      name: Name of the attribute being assigned.
      value: Object to assign to the attribute.
    """
    try:
      if getattr(self, name) is value:
        # Short circuit for `self.$x = self.$x`.
        return
    except AttributeError:
      # The attribute does not exist yet; fall through to a normal assignment.
      pass

    if getattr(self, "_self_setattr_tracking", True):
      value = data_structures.sticky_attribute_assignment(
          trackable=self, value=value, name=name)
    super(AutoTrackable, self).__setattr__(name, value)

  def __delattr__(self, name):
    """Stops tracking `name`, then deletes the attribute itself.

    Args:
      name: Name of the attribute to delete.
    """
    self._delete_tracking(name)
    super(AutoTrackable, self).__delattr__(name)

  def _no_dependency(self, value):
    """Override to allow TrackableBase to disable dependency tracking.

    Args:
      value: Object to wrap.

    Returns:
      `value` wrapped in `data_structures.NoDependency`.
    """
    return data_structures.NoDependency(value)

  def _trackable_children(self, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns all children of a trackable, including functions.

    For every save type other than `SAVEDMODEL` this defers to the parent
    class. For `SAVEDMODEL` it additionally scans the attributes reported by
    `dir(self)` for `def_function.Function` / `defun.ConcreteFunction`
    instances and merges them with the non-function entries of
    `self._checkpoint_dependencies`.

    Args:
      save_type: A `base.SaveType` value selecting which children to return.
      **kwargs: Forwarded unchanged to the parent implementation when
        `save_type` is not `SAVEDMODEL`; otherwise unused here.

    Returns:
      A dict mapping child names to child objects.

    Raises:
      ValueError: If a non-function checkpoint dependency has the same name as
        a function attribute but is a different object.
    """
    if save_type != base.SaveType.SAVEDMODEL:
      return super(AutoTrackable, self)._trackable_children(
          save_type, **kwargs)

    functions = {}
    # Capture the current verbosity *before* entering `try`, so the `finally`
    # clause can always restore it and never reads an unbound local (which
    # would replace the real error with a NameError).
    logging_verbosity = logging.get_verbosity()
    try:
      # We get the attributes, suppressing warnings and exceptions.
      logging.set_verbosity(logging.FATAL)
      for attribute_name in dir(self):
        try:
          with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            attribute_value = getattr(self, attribute_name, None)
        except Exception:  # pylint: disable=broad-except
          # NOTE: If we make the exception catching here less broad, we might
          # need to revisit `finally` block below.
          # We really don't want to throw an exception just because some
          # object's attribute accessor is broken.
          attribute_value = None
        if isinstance(attribute_value, (def_function.Function,
                                        defun.ConcreteFunction)):
          functions[attribute_name] = attribute_value
    finally:
      logging.set_verbosity(logging_verbosity)

    # Trace concrete functions to force side-effects:
    #   1. populate the cache for functions that have an input_signature
    #      and have not been called
    #   2. force side effects of creation of concrete functions, e.g. create
    #      variables on first run.
    for fn in functions.values():
      if isinstance(fn, core_types.GenericFunction):
        fn._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access

    # Additional dependencies may have been generated during function tracing
    # (e.g. captured variables). Make sure we return those too.
    children = {}
    for name, child in self._checkpoint_dependencies:
      if isinstance(child, (core_types.GenericFunction,
                            core_types.ConcreteFunction)):
        # Skip "tracked" functions for now since there may be objects that
        # automatically track functions that should not be saved.
        # TODO(kathywu): remove once `_list_functions_for_serialization` has
        # been fully deprecated.
        continue

      if name in functions and child is not functions[name]:
        raise ValueError(
            "Can't save object because it has multiple children with the same "
            f"name. Object: {self}, attribute name: {name}, child 1: "
            f"{child}, child 2: {functions[name]}")

      children[name] = child

    # Function attributes are merged last; any name clash with a different
    # object has already been rejected above.
    children.update(functions)
    return children

  def _delete_tracking(self, name):
    """Removes the tracking of name.

    Does nothing beyond initialization when `name` is not present in
    `self._unconditional_dependency_names`.

    Args:
      name: Name of the dependency to stop tracking.
    """
    self._maybe_initialize_trackable()
    if name in self._unconditional_dependency_names:
      del self._unconditional_dependency_names[name]
      for index, (dep_name, _) in enumerate(
          self._unconditional_checkpoint_dependencies):
        if dep_name == name:
          # Only the first matching entry is removed; the loop stops right
          # after the deletion, so the list is not iterated while mutated.
          del self._unconditional_checkpoint_dependencies[index]
          break

  def _add_trackable_child(self, name, value):
    """Adds `value` as a child named `name` by assigning it as an attribute.

    Args:
      name: Attribute name for the child.
      value: Object to assign.
    """
    self.__setattr__(name, value)