Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/checkpoint/util.py: 15%

81 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utilities for extracting and writing checkpoint info.""" 

16 

17from tensorflow.core.protobuf import trackable_object_graph_pb2 

18from tensorflow.python.ops import resource_variable_ops 

19from tensorflow.python.ops import variables 

20from tensorflow.python.trackable import trackable_utils 

21from tensorflow.python.util import object_identity 

22 

23 

def _get_slot_or_none(optimizer, variable, slot_name):
  """Returns `optimizer.get_slot(variable, slot_name)`, or None on lookup failure.

  Only `AttributeError` and `KeyError` are treated as "no such slot"; any other
  exception propagates.
  """
  try:
    return optimizer.get_slot(variable, slot_name)
  except (AttributeError, KeyError):
    return None


def serialize_slot_variables(trackable_objects, node_ids, object_names):
  """Gather and name slot variables.

  Every object in `trackable_objects` that exposes `get_slot_names` (detected
  via `dir()`) is treated as a slot owner. For each of its slot names, the slot
  is looked up for every object of the original `trackable_objects` list. Each
  slot variable found is given a checkpoint name and the next free node id, and
  is appended to `trackable_objects`.

  Args:
    trackable_objects: List of trackable objects. Mutated in place: discovered
      slot variables are appended.
    node_ids: Identity-keyed mapping of object -> node id. Mutated in place.
    object_names: Identity-keyed mapping of object -> checkpoint name. Mutated
      in place.

  Returns:
    An `ObjectIdentityDictionary` mapping each slot owner to a list of
    `SlotVariableReference` protos.

  Raises:
    NotImplementedError: If a slot variable has trackable children, or if it is
      already present in `node_ids`.
  """
  # Snapshot the incoming objects so the slot variables appended below are
  # never scanned as slot owners or as slotted-for variables.
  original_objects = list(trackable_objects)
  slot_variables = object_identity.ObjectIdentityDictionary()
  for owner in original_objects:
    # TODO(b/110718070): Fix Keras imports.
    # Note: dir() is used rather than hasattr() here to avoid triggering
    # custom __getattr__ code, see b/152031870 for context.
    if "get_slot_names" not in dir(owner):
      continue
    for slot_name in owner.get_slot_names():
      for variable_node_id, variable in enumerate(original_objects):
        slot = _get_slot_or_none(owner, variable, slot_name)
        if slot is None:
          continue
        slot._maybe_initialize_trackable()  # pylint: disable=protected-access
        if slot._trackable_children():  # pylint: disable=protected-access
          # TODO(allenl): Gather dependencies of slot variables.
          raise NotImplementedError(
              "Currently only variables with no dependencies can be saved as "
              "slot variables. File a feature request if this limitation "
              "bothers you.")
        if slot in node_ids:
          raise NotImplementedError(
              "A slot variable was re-used as a dependency of a Trackable "
              f"object: {slot}. This is not currently allowed. "
              "File a feature request if this limitation bothers you.")
        object_names[slot] = trackable_utils.slot_variable_key(
            variable_path=object_names[variable],
            optimizer_path=object_names[owner],
            slot_name=slot_name)
        # The slot variable takes the next node id, i.e. its position once
        # appended to `trackable_objects`.
        new_node_id = len(trackable_objects)
        node_ids[slot] = new_node_id
        trackable_objects.append(slot)
        reference = (
            trackable_object_graph_pb2.TrackableObjectGraph.TrackableObject
            .SlotVariableReference(
                slot_name=slot_name,
                original_variable_node_id=variable_node_id,
                slot_variable_node_id=new_node_id))
        slot_variables.setdefault(owner, []).append(reference)
  return slot_variables

71 

72 

def get_mapped_trackable(trackable, object_map):
  """Looks `trackable` up in `object_map`, falling back to `trackable` itself.

  Args:
    trackable: The object to look up.
    object_map: A mapping supporting `.get`, or None.

  Returns:
    `object_map.get(trackable, trackable)`, or `trackable` unchanged when
    `object_map` is None.
  """
  return (trackable if object_map is None
          else object_map.get(trackable, trackable))

79 

80 

def get_full_name(var):
  """Returns the full name of `var` for name-based checkpoint compatibility.

  Args:
    var: Any object.

  Returns:
    "" if `var` is neither a `variables.Variable` nor accepted by
    `resource_variable_ops.is_resource_variable`. Otherwise the `full_name` of
    `var._save_slice_info` when that attribute is present and not None, else
    `var._shared_name`.
  """
  # pylint: disable=protected-access
  acts_as_variable = (
      isinstance(var, variables.Variable)
      # Some objects do not subclass Variable but still act as one.
      or resource_variable_ops.is_resource_variable(var))
  if not acts_as_variable:
    return ""

  # getattr is used because `var._save_slice_info` may be set as `None`.
  slice_info = getattr(var, "_save_slice_info", None)
  if slice_info is None:
    return var._shared_name
  return slice_info.full_name
  # pylint: enable=protected-access

95 

96 

def add_checkpoint_values_check(object_graph_proto):
  """Records which nodes of the proto have checkpoint values.

  A node has checkpoint values if it has attributes, slot variables, or a
  registered saver of its own, or if any node reachable through its children
  does. The result is written to `has_checkpoint_values.value` on every node.

  Args:
    object_graph_proto: A `TrackableObjectGraph` proto. Modified in place.
  """
  # Node id -> set of ids of the nodes that list it as a child (its "parents").
  # If a node has checkpoint values, then all of its parents can be marked as
  # having checkpoint values too.
  parents = {}
  # Keys are integer node ids, so a plain set is sufficient.
  checkpointed_nodes = set()

  # First pass: build the parent map and the initial set of nodes that hold
  # checkpoint values directly.
  for node_id, object_proto in enumerate(object_graph_proto.nodes):
    if (object_proto.attributes or object_proto.slot_variables or
        object_proto.HasField("registered_saver")):
      checkpointed_nodes.add(node_id)
    for child_proto in object_proto.children:
      parents.setdefault(child_proto.node_id, set()).add(node_id)

  # Second pass: propagate upward to every ancestor. A node's parent entry is
  # popped when it is processed, so each node's parents are expanded at most
  # once (shared ancestors and cycles terminate).
  to_visit = set(checkpointed_nodes)
  while to_visit:
    node_id = to_visit.pop()
    current_parents = parents.pop(node_id, None)
    if current_parents is None:
      # Some trackables may not have parents (e.g. slot variables), or their
      # parents were already processed.
      continue
    checkpointed_nodes.update(current_parents)
    to_visit.update(p for p in current_parents if p in parents)

  for node_id, object_proto in enumerate(object_graph_proto.nodes):
    object_proto.has_checkpoint_values.value = node_id in checkpointed_nodes

140 

141 

def objects_ids_and_slot_variables_and_paths(graph_view):
  """Traverses the object graph and lists all accessible objects.

  Collects the `Trackable` objects returned by
  `graph_view.breadth_first_traversal()`. Slot variables are added only when
  both the variable they slot for and the optimizer owning them are among
  those objects (i.e. if they would be saved with a checkpoint).

  Args:
    graph_view: A GraphView object.

  Returns:
    A tuple of (trackable objects, paths from root for each object,
    object -> node id, slot variables, object_names). The trackable objects
    list and the two identity-keyed mappings include any slot variables that
    were found.
  """
  trackable_objects, node_paths = graph_view.breadth_first_traversal()

  object_names = object_identity.ObjectIdentityDictionary()
  for trackable, path in node_paths.items():
    object_names[trackable] = trackable_utils.object_path_to_string(path)

  # A node's id is its position in the traversal order.
  node_ids = object_identity.ObjectIdentityDictionary()
  for position, trackable in enumerate(trackable_objects):
    node_ids[trackable] = position

  slot_variables = serialize_slot_variables(
      trackable_objects=trackable_objects,
      node_ids=node_ids,
      object_names=object_names)
  return trackable_objects, node_paths, node_ids, slot_variables, object_names

169 

170 

def list_objects(graph_view):
  """Traverses the object graph and lists all accessible objects.

  Args:
    graph_view: A GraphView object.

  Returns:
    The trackable objects list (including any slot variables that were found)
    from `objects_ids_and_slot_variables_and_paths`.
  """
  objects, _, _, _, _ = objects_ids_and_slot_variables_and_paths(graph_view)
  return objects