Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/checkpoint/save_util.py: 21%

136 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Extracts tensors for checkpointing while updating a TrackableObjectGraph. 

16 

17The tensors are extracted from `Trackable._serialize_to_tensors`. 

18""" 

19import collections 

20 

21from typing import Any, Callable, List, Optional, Tuple, Mapping, Union, Dict 

22 

23from tensorflow.core.protobuf import trackable_object_graph_pb2 

24from tensorflow.python.checkpoint import graph_view as graph_view_lib 

25from tensorflow.python.checkpoint import save_util_v1 

26from tensorflow.python.checkpoint import saveable_compat 

27from tensorflow.python.checkpoint import util 

28from tensorflow.python.framework import constant_op 

29from tensorflow.python.framework import dtypes 

30from tensorflow.python.framework import ops 

31from tensorflow.python.saved_model import registration 

32from tensorflow.python.trackable import base 

33from tensorflow.python.trackable import python_state 

34from tensorflow.python.trackable import trackable_utils 

35from tensorflow.python.training.saving import saveable_object as saveable_object_lib 

36from tensorflow.python.training.saving import saveable_object_util 

37from tensorflow.python.types import core 

38from tensorflow.python.util import object_identity 

39 

40# Attributes for each Trackable in the checkpointed object graph. 

41_TrackableData = collections.namedtuple("_TrackableData", [ 

42 # A trackable in the root Trackable object graph. 

43 "trackable", 

44 # The index at which the Trackable appears in TrackableObjectGraph.nodes. 

45 "node_id", 

46 # The BFS-generated path from the root object / used to generate readable 

47 # checkpoint keys. 

48 "object_name", 

49 # A list of ObjectReference for each child connected to this Trackable. 

50 "children_proto", 

51 # A list of SlotVariableReference to save to the object (only valid for 

52 # Optimizer objects). 

53 "slot_variable_proto", 

54 # The object to save to checkpoint. Usually this is the same as `trackable`, 

55 # but can differ when the the caller wants to specify a different object to 

56 # save. For example, when saving checkpoints asynchronously, variables are 

57 # copied to the CPU. `object_to_save` is set as the copied variable. 

58 "object_to_save", 

59 ]) 

60 

61 

def _split_trackables(
    trackable_data: List[_TrackableData]
) -> Tuple[List[_TrackableData], List[_TrackableData],
           Dict[str, List[_TrackableData]]]:
  """Partitions `trackable_data` by how each entry gets checkpointed.

  Args:
    trackable_data: The `_TrackableData` entries to classify. Classification
      looks only at each entry's `object_to_save`.

  Returns:
    A `(tensor_trackables, pystate_trackables, registered_trackables)` tuple.
    The first two are lists; the third is a `defaultdict(list)` keyed by the
    registered saver name.
  """
  via_tensors = []
  via_pystate = []
  via_registered_saver = collections.defaultdict(list)

  for entry in trackable_data:
    target = entry.object_to_save
    # The saver name is looked up for every entry, but a PythonState object is
    # routed to the pystate bucket regardless of what the lookup returns.
    registered_name = registration.get_registered_saver_name(target)
    if isinstance(target, python_state.PythonState):
      bucket = via_pystate
    elif registered_name:
      bucket = via_registered_saver[registered_name]
    else:
      bucket = via_tensors
    bucket.append(entry)

  return via_tensors, via_pystate, via_registered_saver

81 

82 

def _gather_trackable_data(
    graph_view: graph_view_lib.ObjectGraphView,
    object_map: Mapping[base.Trackable, base.Trackable]
) -> Tuple[List[_TrackableData], Dict[base.Trackable, int]]:
  """Builds one `_TrackableData` per object reached by traversing `graph_view`.

  Args:
    graph_view: The `ObjectGraphView` whose breadth-first traversal supplies
      the trackables, their paths, and (via `list_children`) their children.
    object_map: Forwarded unchanged to `util.get_mapped_trackable` to pick each
      entry's `object_to_save`.

  Returns:
    A `(trackable_data, node_ids)` tuple: the list of `_TrackableData` in
    traversal order, and an `ObjectIdentityDictionary` mapping every traversed
    trackable to its position in that order.
  """
  ordered_trackables, paths_by_trackable = graph_view.breadth_first_traversal()

  names_by_trackable = object_identity.ObjectIdentityDictionary()
  for trackable, path in paths_by_trackable.items():
    names_by_trackable[trackable] = trackable_utils.object_path_to_string(path)

  ids_by_trackable = object_identity.ObjectIdentityDictionary()
  for position, trackable in enumerate(ordered_trackables):
    ids_by_trackable[trackable] = position

  slot_variables = util.serialize_slot_variables(
      trackable_objects=ordered_trackables,
      node_ids=ids_by_trackable,
      object_names=names_by_trackable)

  object_reference = (
      trackable_object_graph_pb2.TrackableObjectGraph.TrackableObject
      .ObjectReference)

  gathered = []
  for trackable in ordered_trackables:
    # One ObjectReference per child edge, pointing at the child's node id.
    child_refs = [
        object_reference(node_id=ids_by_trackable[child.ref],
                         local_name=child.name)
        for child in graph_view.list_children(trackable)
    ]
    gathered.append(
        _TrackableData(
            trackable=trackable,
            node_id=ids_by_trackable[trackable],
            object_name=names_by_trackable[trackable],
            children_proto=child_refs,
            slot_variable_proto=slot_variables.get(trackable, []),
            object_to_save=util.get_mapped_trackable(trackable, object_map)))

  return gathered, ids_by_trackable

116 

117 

def _fill_object_graph_proto(
    trackable_data: List[_TrackableData]
) -> trackable_object_graph_pb2.TrackableObjectGraph:
  """Creates a `TrackableObjectGraph` holding one node per `_TrackableData`.

  Each node is created with the entry's children and slot-variable references.

  Args:
    trackable_data: Entries ordered so that every `node_id` equals the entry's
      index in the list; this is asserted.

  Returns:
    The newly built `TrackableObjectGraph` proto.
  """
  proto = trackable_object_graph_pb2.TrackableObjectGraph()
  for entry in trackable_data:
    # The proto starts empty and gains exactly one node per entry, so the
    # current node count is the index the next node will occupy.
    assert entry.node_id == len(proto.nodes)
    proto.nodes.add(
        children=entry.children_proto,
        slot_variables=entry.slot_variable_proto)
  return proto

129 

130 

def _get_and_write_tensors_to_serialize(
    tensor_trackables: List[_TrackableData],
    node_ids: Dict[base.Trackable, int],
    call_with_mapped_captures: Union[Callable[..., Any], None],
    cache: Union[Dict[base.Trackable, Any], None],
    object_graph_proto: trackable_object_graph_pb2.TrackableObjectGraph
) -> Dict[base.Trackable, Any]:
  """Creates dictionary of tensors to checkpoint, and updates the proto.

  Args:
    tensor_trackables: The `_TrackableData` entries to gather tensors from.
    node_ids: Passed through to the legacy SaveableObject code path.
    call_with_mapped_captures: Passed through to the helper that gathers the
      tensors for each entry. May be None.
    cache: Optional dictionary keyed by `object_to_save`. When provided, a
      cached `(trackable, tensor_dict, attributes)` entry is reused instead of
      regathering tensors, and new results are stored in it.
    object_graph_proto: The proto whose node attributes are filled in for each
      entry.

  Returns:
    An `ObjectIdentityDictionary` mapping each trackable to a dictionary of
    tensors, which maps checkpoint key (-> slice_spec) -> tensor.
  """
  serialized_tensors = object_identity.ObjectIdentityDictionary()

  for td in tensor_trackables:
    if cache is not None and td.object_to_save in cache:
      # Cache hit: reuse the stored tensors and copy the stored attributes
      # into this graph's node.
      trackable, tensor_dict, object_proto = cache[td.object_to_save]
      serialized_tensors[trackable] = tensor_dict
      object_graph_proto.nodes[td.node_id].attributes.MergeFrom(object_proto)
      continue

    legacy_name = saveable_compat.get_saveable_name(td.object_to_save) or ""

    if (not saveable_object_util.trackable_has_serialize_to_tensor(
        td.object_to_save) or
        legacy_name):
      # Use the legacy code path for objects that are using SaveableObjects
      # or the compat saveable name decorator.
      trackable, tensor_dict = _get_tensors_from_legacy_saveable(
          td, node_ids, call_with_mapped_captures, object_graph_proto)
    else:
      tensor_dict = _get_tensors_from_trackable(
          td, call_with_mapped_captures, object_graph_proto)
      trackable = td.object_to_save
    serialized_tensors[trackable] = tensor_dict

    if cache is not None and td.object_to_save not in cache:
      cache[td.object_to_save] = (
          trackable, tensor_dict,
          object_graph_proto.nodes[td.node_id].attributes)

  return serialized_tensors

171 

172 

def _get_tensors_from_legacy_saveable(
    trackable_data: _TrackableData,
    node_ids: Dict[base.Trackable, int],
    call_with_mapped_captures: Callable[..., Any],
    object_graph_proto: trackable_object_graph_pb2.TrackableObjectGraph
) -> Tuple[base.Trackable, Dict[str, Any]]:
  """Collects the tensors to serialize for a Trackable using SaveableObjects.

  Args:
    trackable_data: The `_TrackableData` entry to gather tensors for.
    node_ids: Forwarded to `save_util_v1.generate_saveable_objects`.
    call_with_mapped_captures: Forwarded to
      `save_util_v1.generate_saveable_objects`.
    object_graph_proto: Forwarded to `save_util_v1.generate_saveable_objects`,
      which updates it.

  Returns:
    A `(converter, tensors)` tuple: the `SaveableCompatibilityConverter`
    wrapping `object_to_save`, and the result of its `_serialize_to_tensors()`.
  """
  source = trackable_data.trackable
  target = trackable_data.object_to_save

  # Build single-entry identity dictionaries describing just this trackable.
  names_by_trackable = object_identity.ObjectIdentityDictionary()
  names_by_trackable[source] = trackable_data.object_name
  targets_by_trackable = object_identity.ObjectIdentityDictionary()
  targets_by_trackable[source] = target

  # The `save_util_v1` helpers create the legacy SaveableObjects and update
  # the proto.
  factory_map, _ = save_util_v1.get_checkpoint_factories_and_keys(
      names_by_trackable, targets_by_trackable)
  saveables, _ = save_util_v1.generate_saveable_objects(
      factory_map,
      object_graph_proto,
      node_ids,
      targets_by_trackable,
      call_with_mapped_captures,
      saveables_cache=None)

  converter = saveable_object_util.SaveableCompatibilityConverter(
      target, saveables)
  tensors = converter._serialize_to_tensors()  # pylint: disable=protected-access
  return converter, tensors

201 

202 

def _get_tensors_from_trackable(
    trackable_data: _TrackableData,
    call_with_mapped_captures: Union[Callable[..., Any], None],
    object_graph_proto: Optional[
        trackable_object_graph_pb2.TrackableObjectGraph]
) -> Dict[str, Any]:
  """Gets tensors to serialize from a Trackable.

  Args:
    trackable_data: The `_TrackableData` entry whose `object_to_save` provides
      `_serialize_to_tensors`.
    call_with_mapped_captures: When truthy and `_serialize_to_tensors` is a
      `core.ConcreteFunction`, the function is invoked through this callable
      instead of being called directly.
    object_graph_proto: The proto to record one attribute per returned tensor
      in. May be None, in which case nothing is written.

  Returns:
    A dictionary mapping each generated checkpoint key to the corresponding
    value returned by `_serialize_to_tensors`.
  """
  trackable = trackable_data.object_to_save
  save_fn = trackable._serialize_to_tensors  # pylint: disable=protected-access

  if (call_with_mapped_captures and
      isinstance(save_fn, core.ConcreteFunction)):
    ret_tensor_dict = call_with_mapped_captures(save_fn, [])
  else:
    ret_tensor_dict = save_fn()

  # Create checkpoint keys for each entry in the returned tensor dict, and
  # write each entry to the object proto.
  tensor_dict = {}
  for tensor_name, maybe_tensor in ret_tensor_dict.items():
    local_name = trackable_utils.escape_local_name(tensor_name)
    checkpoint_key = trackable_utils.checkpoint_key(trackable_data.object_name,
                                                    local_name)
    tensor_dict[checkpoint_key] = maybe_tensor

    # TODO(b/261786493): Delete this when DCheckpoint is removed.
    if isinstance(maybe_tensor, saveable_object_lib.SaveSpec):
      # SaveSpec values are updated in place to carry the checkpoint key.
      maybe_tensor.name = checkpoint_key
      maybe_tensor.slice_spec = ""

    if object_graph_proto is not None:
      object_graph_proto.nodes[trackable_data.node_id].attributes.add(
          name=local_name,
          checkpoint_key=checkpoint_key,
          full_name=util.get_full_name(trackable))

  return tensor_dict

239 

240 

def _get_and_write_pystate_feed_additions(
    pystate_trackables: List[_TrackableData],
    cache: Union[Dict[base.Trackable, Any], None],
    object_graph_proto: Optional[
        trackable_object_graph_pb2.TrackableObjectGraph] = None
) -> Tuple[Dict[base.Trackable, Any], Dict[Any, Any]]:
  """Gets feed additions needed for checkpointing Python State.

  Args:
    pystate_trackables: The `_TrackableData` entries whose `object_to_save`
      provides `serialize()`.
    cache: Optional dictionary keyed by `object_to_save`. When provided, the
      string tensor created for a trackable is stored under
      `python_state.PYTHON_STATE` and reused on later calls. When None, a new
      tensor is created each time and nothing is stored.
    object_graph_proto: The proto to record one `PYTHON_STATE` attribute per
      trackable in. May be None, in which case nothing is written.

  Returns:
    A `(serialized_tensors, feed_additions)` tuple. `serialized_tensors` is an
    `ObjectIdentityDictionary` mapping each trackable to
    `{checkpoint_key: string_tensor}`; `feed_additions` maps each string tensor
    to the value returned by the trackable's `serialize()`.
  """
  serialized_tensors = object_identity.ObjectIdentityDictionary()
  # Maps tensor placeholders to python values.
  feed_additions = {}

  for td in pystate_trackables:
    trackable = td.object_to_save
    checkpoint_key = trackable_utils.checkpoint_key(td.object_name,
                                                    python_state.PYTHON_STATE)
    # `cache` is annotated as possibly None, so guard both the lookup and the
    # store instead of failing on `in None`.
    if cache is not None and trackable in cache:
      save_string = cache[trackable][python_state.PYTHON_STATE]
    else:
      with ops.device("/cpu:0"):
        save_string = constant_op.constant("", dtype=dtypes.string)
      if cache is not None:
        cache[trackable] = {python_state.PYTHON_STATE: save_string}

    with ops.init_scope():
      value = trackable.serialize()
    feed_additions[save_string] = value
    serialized_tensors[trackable] = {checkpoint_key: save_string}

    # The proto parameter defaults to None; skip the write in that case, the
    # same way `_get_tensors_from_trackable` does.
    if object_graph_proto is not None:
      object_graph_proto.nodes[td.node_id].attributes.add(
          name=python_state.PYTHON_STATE,
          checkpoint_key=checkpoint_key,
          full_name=util.get_full_name(trackable))

  return serialized_tensors, feed_additions

273 

274 

def _get_and_write_registered_savers(
    registered_trackables: Dict[str, List[_TrackableData]],
    object_graph_proto: trackable_object_graph_pb2.TrackableObjectGraph
) -> Dict[str, Dict[str, base.Trackable]]:
  """Builds the saver-name -> objects mapping and records it in the proto.

  Args:
    registered_trackables: Maps each registered saver name to the
      `_TrackableData` entries handled by that saver.
    object_graph_proto: The proto whose nodes get their `registered_saver`
      name and object name filled in.

  Returns:
    A `defaultdict(dict)` mapping saver name -> object name ->
    `object_to_save`. A saver name with no entries does not appear as a key.
  """
  savers_by_name = collections.defaultdict(dict)
  for saver_name, entries in registered_trackables.items():
    for entry in entries:
      # Index inside the loop so that a saver with an empty entry list never
      # materializes an empty dict in the result.
      savers_by_name[saver_name][entry.object_name] = entry.object_to_save

      node = object_graph_proto.nodes[entry.node_id]
      node.registered_saver.name = saver_name
      node.registered_saver.object_name = entry.object_name

  return savers_by_name

290 

291 

def serialize_graph_view(
    graph_view: graph_view_lib.ObjectGraphView,
    object_map: Optional[Mapping[base.Trackable, base.Trackable]] = None,
    call_with_mapped_captures: Optional[Callable[..., Any]] = None,
    cache: Optional[Dict[base.Trackable, Any]] = None
) -> Tuple[Dict[base.Trackable, Any],
           Optional[Dict[Any, Any]],
           Dict[str, Dict[str, base.Trackable]],
           trackable_object_graph_pb2.TrackableObjectGraph]:
  """Gathers serialization objects, and creates a TrackableObjectGraph proto.

  Args:
    graph_view: The `ObjectGraphView` to traverse and serialize.
    object_map: Optional mapping used to pick the object that is actually
      saved for each traversed trackable.
    call_with_mapped_captures: Optional callable passed through to the helpers
      that gather tensors.
    cache: Optional dictionary used to reuse results between calls. Whether it
      is None also selects how PythonState trackables are handled (see the
      comments in the body).

  Returns:
    A `(serialized_tensors, feed_additions, registered_savers,
    object_graph_proto)` tuple. `feed_additions` is None when `cache` is None.
  """
  # There are 3 types of checkpoint serialization types supported:
  # 1. Trackables that override `Trackable._serialize_to_tensors()`.
  # 2. PythonState: A special type of Trackable that serializes a Python string.
  # 3. Registered Trackable Savers: For objects that need to define advanced
  #    checkpointing operations not supported by (1) or (2).
  trackable_data, node_ids = _gather_trackable_data(graph_view, object_map)
  tensor_trackables, pystate_trackables, registered_trackables = (
      _split_trackables(trackable_data))

  object_graph_proto = _fill_object_graph_proto(trackable_data)

  serialized_tensors = _get_and_write_tensors_to_serialize(
      tensor_trackables,
      node_ids,
      call_with_mapped_captures,
      cache,
      object_graph_proto)
  registered_savers = _get_and_write_registered_savers(
      registered_trackables, object_graph_proto)

  # PythonState trackables must be treated differently depending on if the
  # checkpoint is being saved in TF1 graph mode (`cache` exists) or
  # eager mode (`cache` is None).
  if cache is None:
    # When the tensor cache is None, get the serialized tensors directly.
    feed_additions = None
    serialized_tensors.update(_get_and_write_tensors_to_serialize(
        pystate_trackables,
        node_ids,
        call_with_mapped_captures,
        cache,
        object_graph_proto))
  else:
    # Python state is not automatically updated within a TF session so these
    # values must be passed to sess.run(feed_additions=...).
    new_serialized_tensors, feed_additions = (
        _get_and_write_pystate_feed_additions(pystate_trackables,
                                              cache,
                                              object_graph_proto))
    serialized_tensors.update(new_serialized_tensors)

  # Gather all trackables that have checkpoint values or descendants with
  # checkpoint values, and add that info to the proto.
  util.add_checkpoint_values_check(object_graph_proto)
  return (serialized_tensors, feed_additions, registered_savers,
          object_graph_proto)

344