Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/checkpoint/graph_view.py: 38%

53 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1"""Manages a graph of Trackable objects.""" 

2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

3# 

4# Licensed under the Apache License, Version 2.0 (the "License"); 

5# you may not use this file except in compliance with the License. 

6# You may obtain a copy of the License at 

7# 

8# http://www.apache.org/licenses/LICENSE-2.0 

9# 

10# Unless required by applicable law or agreed to in writing, software 

11# distributed under the License is distributed on an "AS IS" BASIS, 

12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

13# See the License for the specific language governing permissions and 

14# limitations under the License. 

15# ============================================================================== 

16import copy 

17import weakref 

18 

19from tensorflow.python.checkpoint import save_util_v1 

20from tensorflow.python.checkpoint import trackable_view 

21from tensorflow.python.trackable import base 

22from tensorflow.python.util.tf_export import tf_export 

23 

24 

@tf_export("__internal__.tracking.ObjectGraphView", v1=[])
class ObjectGraphView(trackable_view.TrackableView):
  """Gathers and serializes an object graph rooted at a `Trackable`.

  Only a weak reference to the root object is stored on the view.
  """

  def __init__(self, root, attached_dependencies=None):
    """Builds a view over `root` and its (recursive) dependencies.

    Args:
      root: The `Trackable` whose variables, including those of its
        dependencies (recursively), should be saved. A `weakref.ref` to the
        `Trackable` is accepted as well.
      attached_dependencies: Optional list of extra dependencies that are
        reported as children of the root object without being tracked by it.
        Used when saving a Checkpoint with a defined root object. To avoid
        reference cycles, these should use the WeakTrackableReference class.
    """
    trackable_view.TrackableView.__init__(self, root)
    # Never hold the root strongly; doing so could close a reference cycle:
    #   root -> deferred dependencies -> CheckpointPosition
    #   -> CheckpointRestoreCoordinator -> ObjectGraphView -> root
    if not isinstance(root, weakref.ref):
      root = weakref.ref(root)
    self._root_ref = root
    self._attached_dependencies = attached_dependencies

  def __deepcopy__(self, memo):
    """Deep-copies the view so that its weak root points at the copied root.

    Weak references are not copied by default, so a plain deep copy would
    still refer to the original root. To avoid that, the live root (if any) is
    deep-copied first. A weak reference to that copy is then registered in
    `memo` under the id of the original weak reference, so that deep-copying
    `_root_ref` below resolves to it.
    """
    live_root = self._root_ref()
    if live_root is not None:
      root_copy = copy.deepcopy(live_root, memo)
      memo[id(self._root_ref)] = weakref.ref(root_copy)
    # The parent class has no __deepcopy__, so rebuild the instance by hand.
    # Allocate it without running __init__ and register it in `memo` before
    # copying any attribute. Then deep-copy every instance attribute onto it.
    duplicate = super().__new__(type(self))
    memo[id(self)] = duplicate
    for attr_name, attr_value in vars(self).items():
      setattr(duplicate, attr_name, copy.deepcopy(attr_value, memo))
    return duplicate

  def list_children(self, obj, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Lists the child trackables of `obj` as `TrackableReference`s.

    Args:
      obj: A `Trackable` object.
      save_type: The kind of save (`base.SaveType`), e.g. checkpoint or
        savedmodel; passed through to the parent's `children`.
      **kwargs: Extra keyword arguments used when retrieving the children.

    Returns:
      A list with one `base.TrackableReference` per child of `obj`. When `obj`
      is the root of this view, the attached dependencies (if any) are
      appended after them.
    """
    # Use the parent's `children`, not the override on this class, which is
    # itself implemented on top of this method.
    named_refs = super().children(obj, save_type, **kwargs)
    references = [
        base.TrackableReference(name, ref) for name, ref in named_refs.items()
    ]
    # A view may expose children of the root that the root does not actually
    # track, e.g. a Checkpoint object's save_counter.
    if obj is self.root and self._attached_dependencies:
      references.extend(self._attached_dependencies)
    return references

  def children(self, obj, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Maps each child name of `obj` to the corresponding child trackable.

    Args:
      obj: A `Trackable` object.
      save_type: Currently not used by this method; see the note in the body.
      **kwargs: Extra keyword arguments forwarded to `list_children`.

    Returns:
      A dict from child name to child trackable, built from the pairs yielded
      by `list_children` (later duplicates of a name win).
    """
    # NOTE(review): `save_type` is not forwarded, so `list_children` runs with
    # its own default. Forwarding it would change behavior for callers and
    # possibly for subclasses that override `list_children` with a narrower
    # signature -- confirm before changing.
    return {name: ref for name, ref in self.list_children(obj, **kwargs)}

  @property
  def attached_dependencies(self):
    """Dependencies that are saved in the checkpoint but not tracked by root.

    These are set when the user creates a Checkpoint with both a root and
    keyword arguments.

    Returns:
      A list of TrackableReferences, or whatever was passed to `__init__`
      (None by default).
    """
    return self._attached_dependencies

  @property
  def root(self):
    """The root object of this view.

    Raises:
      AssertionError: If the weakly-referenced root no longer exists (only
        while assertions are enabled).
    """
    if not isinstance(self._root_ref, weakref.ref):
      return self._root_ref
    derefed = self._root_ref()
    assert derefed is not None
    return derefed

  def breadth_first_traversal(self):
    """Public traversal entry point.

    Dispatches through `self._breadth_first_traversal`, so an override of
    that method takes effect here.
    """
    return self._breadth_first_traversal()

  def _breadth_first_traversal(self):
    """Find shortest paths to all dependencies of self.root.

    Returns the parent's `_descendants_with_paths()` result unchanged.
    """
    return super()._descendants_with_paths()

  def serialize_object_graph(self, saveables_cache=None):
    """Determine checkpoint keys for variables and build a serialized graph.

    Non-slot variables are keyed by a shortest path from the root saveable to
    the object that owns the variable (i.e. the one that called
    `Trackable._add_variable` to create it).

    Slot variables are keyed by a shortest path to the variable being slotted
    for, a shortest path to their optimizer, and the slot name.

    Args:
      saveables_cache: An optional cache storing previously created
        SaveableObjects created for each Trackable. Maps Trackables to a
        dictionary of attribute names to Trackable.

    Returns:
      A tuple of (named_variables, object_graph_proto, feed_additions):
        named_variables: A dictionary mapping names to variable objects.
        object_graph_proto: A TrackableObjectGraph protocol buffer
          containing the serialized object graph and variable references.
        feed_additions: A dictionary mapping from Tensors to values which should
          be fed when saving.

    Raises:
      ValueError: If there are invalid characters in an optimizer's slot names.
    """
    # The helper yields four values; the fourth one is not part of this API.
    (named_saveables, graph_proto, extra_feeds,
     _) = save_util_v1.serialize_object_graph_with_registered_savers(
         self, saveables_cache)
    return named_saveables, graph_proto, extra_feeds

  def frozen_saveable_objects(self,
                              object_map=None,
                              to_graph=None,
                              call_with_mapped_captures=None):
    """Creates SaveableObjects with the current object graph frozen.

    Returns only the first element of
    `save_util_v1.frozen_saveables_and_savers`'s result.
    """
    saveables_and_savers = save_util_v1.frozen_saveables_and_savers(
        self, object_map, to_graph, call_with_mapped_captures)
    return saveables_and_savers[0]