# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logic for restoring checkpointed values for Trackables."""

import collections

from tensorflow.python.checkpoint import checkpoint_view
from tensorflow.python.checkpoint import functional_saver
from tensorflow.python.checkpoint import save_util_v1
from tensorflow.python.checkpoint import saveable_compat
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import registration
from tensorflow.python.trackable import base
from tensorflow.python.trackable import constants
from tensorflow.python.trackable import python_state
from tensorflow.python.trackable import trackable_utils
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util import object_identity


class CheckpointPosition(object):
  """Indicates a position within a `_CheckpointRestoreCoordinator`."""

  __slots__ = ["_checkpoint", "_proto_id", "skip_restore"]

  def __init__(self, checkpoint, proto_id):
    """Specify an object within a checkpoint.

    Args:
      checkpoint: A _CheckpointRestoreCoordinator object.
      proto_id: The index of this object in TrackableObjectGraph.nodes.
    """
    self._checkpoint = checkpoint
    self._proto_id = proto_id
    # This may be set to True if the registered saver cannot be used with this
    # object.
    self.skip_restore = False

  def restore(self, trackable, reader=None):
    """Restore this value into `trackable`."""
    with ops.init_scope():
      if self.bind_object(trackable):
        # This object's correspondence with a checkpointed object is new, so
        # process deferred restorations for it and its dependencies.
        restore_ops = self._restore_descendants(reader)
        if restore_ops:
          self._checkpoint.new_restore_ops(restore_ops)

  def bind_object(self, trackable):
    """Set a checkpoint<->object correspondence.

    Args:
      trackable: The object to record a correspondence for.

    Returns:
      True if this is a new assignment, False if this object has already been
      mapped to a checkpointed `Object` proto.
    Raises:
      AssertionError: If another object is already bound to the `Object` proto.
    """
    checkpoint = self.checkpoint
    checkpoint.all_python_objects.add(trackable)
    current_assignment = checkpoint.object_by_proto_id.get(self._proto_id, None)
    checkpoint.matched_proto_ids.add(self._proto_id)
    if current_assignment is None:
      checkpoint.object_by_proto_id[self._proto_id] = trackable
      return True  # New assignment
    else:
      # The object was already mapped for this checkpoint load, which means
      # we don't need to do anything besides check that the mapping is
      # consistent (if the dependency DAG is not a tree then there are
      # multiple paths to the same object).
      if current_assignment is not trackable:
        logging.warning(
            "Inconsistent references when loading the checkpoint into this "
            "object graph. For example, in the saved checkpoint object, "
            "`model.layer.weight` and `model.layer_copy.weight` reference the "
            "same variable, while in the current object these are two different"
            " variables. The referenced variables are:"
            f"({current_assignment} and {trackable}).")
      return False  # Not a new assignment

  def is_simple_variable(self):
    """Determine whether this value is restorable with a Tensor initializer."""
    attributes = self.object_proto.attributes
    return (len(attributes) == 1 and
            attributes[0].name == constants.VARIABLE_VALUE_KEY and
            not self.object_proto.children)
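
  # Note: a plain `tf.Variable` saved on its own is the canonical "simple"
  # case; its checkpointed node carries a single attribute keyed by
  # `constants.VARIABLE_VALUE_KEY` and has no child nodes.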

  def value_tensors(self, shape_and_slices=None):
    """Create value `Tensor`s for this object's attributes.

    Does not require that the Python object has been created. Used for
    restore-on-create when executing eagerly.

    Args:
      shape_and_slices: A dict mapping from object attribute names to a shape
        and slice string that will be passed to a RestoreV2 op. If the dict is
        None or if an object attribute is not in the dict, the full tensor will
        be restored.

    Returns:
      A dictionary mapping from object attribute names to `Tensor`s.
    """
    value_tensors = {}
    for serialized_tensor in self.object_proto.attributes:
      checkpoint_key = serialized_tensor.checkpoint_key
      dtype = self._checkpoint.dtype_map[checkpoint_key]
      base_type = dtype.base_dtype
      io_device = self._checkpoint.options.experimental_io_device or "cpu:0"
      with ops.init_scope():
        with ops.device(io_device):
          # Run the restore itself on the io_device (CPU, or the device
          # specified in the checkpoint options).
          if (shape_and_slices is not None and
              serialized_tensor.name in shape_and_slices):
            shape_and_slice = shape_and_slices[serialized_tensor.name]
          else:
            shape_and_slice = ""
          value, = io_ops.restore_v2(
              prefix=self._checkpoint.save_path_tensor,
              tensor_names=[checkpoint_key],
              shape_and_slices=[shape_and_slice],
              dtypes=[base_type],
              name="%s_checkpoint_read" % (serialized_tensor.name,))
        # Copy the value to the current device if necessary.
        value_tensors[serialized_tensor.name] = array_ops.identity(value)
    return value_tensors
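
  # A hedged sketch of the `shape_and_slices` format accepted above, following
  # the RestoreV2 op's "full_shape dim_slices" string convention (the
  # attribute name and shapes here are hypothetical). This would restore the
  # first five rows of a [10, 20] variable:
  #
  #   position.value_tensors({"VARIABLE_VALUE": "10 20 0,5:0,20"})
  #
  # An empty string (or omitting the entry) restores the full tensor.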

  def gather_ops_or_named_saveables(self):
    """Looks up or creates SaveableObjects which don't have cached ops.

    Returns:
      A tuple of (
          existing_restore_ops: list,
          named_saveables: dict,
          python_positions: list,
          registered_savers: dict)
    """

    recorded_registered_saver = self.get_registered_saver_name()
    if not (self.object_proto.attributes or recorded_registered_saver):
      return [], {}, [], {}

    existing_restore_ops = []
    named_saveables = {}
    python_positions = []
    registered_savers = collections.defaultdict(dict)

    saveable_factories = saveable_object_util.saveable_objects_from_trackable(
        self.trackable)
    saver_name = registration.get_registered_saver_name(self.trackable)

    if recorded_registered_saver:
      if not self.skip_restore:
        name = self.object_proto.registered_saver.object_name
        registered_savers[recorded_registered_saver][name] = self.trackable
      # Else: Skip restoration of this Trackable. This skip only happens if the
      # registered saver has enabled `option_restore`. Otherwise, an error would
      # have been raised at `self.get_registered_saver_name()`.
    elif saver_name:
      # In this case, the checkpoint has a recorded serialized tensor but no
      # registered saver, while the Trackable loading the checkpoint has
      # migrated to the registered checkpoint functionality (TPUEmbedding is an
      # example of this).

      # Set the Trackable's object name to the first checkpoint key that is
      # stored in checkpoint. If there is a use case that requires the other
      # keys, then we can take another look at this.
      registered_savers[saver_name] = {
          self.object_proto.attributes[0].checkpoint_key: self.trackable
      }
    elif isinstance(self.trackable, python_state.PythonState):
      python_positions.append(self)
    elif saveable_factories.keys() == {
        trackable_utils.SERIALIZE_TO_TENSORS_NAME
    }:
      existing_restore_ops, named_saveables = (
          self._create_serialize_to_tensor_saveable(saveable_factories))
    elif saveable_factories:
      existing_restore_ops, named_saveables = (
          self._create_saveables_by_attribute_name(saveable_factories))
    else:
      # If no registered savers were found, then it means that one or more
      # serialized tensors were never used.
      for serialized_tensor in self.object_proto.attributes:
        self._checkpoint.unused_attributes.setdefault(
            self._proto_id, []).append(serialized_tensor.name)
    return (existing_restore_ops, named_saveables, python_positions,
            registered_savers)

  def _create_serialize_to_tensor_saveable(self, saveable_factories):
    """Creates a saveable using the _serialize_to_tensor method."""
    # Extract the saveable name from the checkpoint key. This will be used as
    # the cache key or the name to pass to the saveable factory.
    suffix = saveable_compat.get_saveable_name(self.trackable) or ""
    saveable_name = _extract_saveable_name(
        self.object_proto.attributes[0].checkpoint_key) + suffix

    # Try to find the cached saveable (only in graph mode).
    if not context.executing_eagerly():
      existing_op = self._checkpoint.restore_ops_by_name.get(
          saveable_name, None)
      if existing_op is not None:
        return [existing_op], {}

      saveables_cache = self._checkpoint.saveables_cache.setdefault(
          self.trackable, {})
      if saveable_name in saveables_cache:
        return [], {saveable_name: saveables_cache[saveable_name]}

    saveable = saveable_factories[trackable_utils.SERIALIZE_TO_TENSORS_NAME](
        name=saveable_name)
    if not context.executing_eagerly():
      saveables_cache[saveable_name] = saveable
    return [], {saveable_name: saveable}

  def _create_saveables_by_attribute_name(self, saveable_factories):
    """Creates or caches SaveableObjects by matching the attribute names.

    The attribute name keys in the `saveable_factories` are used to find the
    corresponding attribute in the object proto. Attributes contain checkpoint
    keys which are passed to the factory function to generate the
    SaveableObject.

    Args:
      saveable_factories: a dict mapping attribute name to a callable factory
        function that produces a SaveableObject.

    Returns:
      A tuple of (
          existing_restore_ops: list,
          named_saveables: dict)
    """
    # Name saveables based on the name this object had when it was checkpointed.
    named_saveables = {}
    existing_restore_ops = []

    # Forward compatibility code: when loading a future checkpoint, there may
    # be multiple SerializedTensors mapped to a single saveable.
    created_compat_names = set()

    for serialized_tensor in self.object_proto.attributes:
      if context.executing_eagerly():
        existing_op = None
      else:
        existing_op = self._checkpoint.restore_ops_by_name.get(
            serialized_tensor.checkpoint_key, None)
      if existing_op is not None:
        existing_restore_ops.append(existing_op)
        continue

      if any(serialized_tensor.name.startswith(name)
             for name in created_compat_names):
        continue  # Saveable has already been created for this tensor.

      # Only if we don't have cached ops for this SaveableObject, we'll see if
      # the SaveableObject itself has been cached. If not, we'll make it, and
      # either way we'll extract new ops from it (or if it has Python state to
      # restore, we'll run that).
      saveables_cache = self._checkpoint.saveables_cache
      if saveables_cache is None:
        # No SaveableObject caching when executing eagerly.
        saveable = None
      else:
        # If we've already created and cached a SaveableObject for this
        # attribute, we can re-use it to avoid re-creating some ops when graph
        # building.
        saveable_list = saveables_cache.get(self.trackable,
                                            {}).get(serialized_tensor.name,
                                                    (None,))
        if len(saveable_list) == 1:
          # Almost every attribute will have exactly one SaveableObject.
          saveable, = saveable_list
        else:
          # Don't use cached SaveableObjects for partitioned variables, which is
          # the only case where we'd have a list of SaveableObjects. Op caching
          # will catch them.
          saveable = None
      if saveable is not None:
        # The name of this attribute has changed, so we need to re-generate
        # the SaveableObject.
        if serialized_tensor.checkpoint_key not in saveable.name:
          saveable = None
          del saveables_cache[self.trackable]
      if saveable is None:
        # If there was no cached SaveableObject, create one.
        # Use the name to check if the Python object has the same attribute.
        saveable = _get_saveable_from_factory(saveable_factories,
                                              serialized_tensor,
                                              created_compat_names)
        if saveable is None:
          # Purposefully does not throw an exception if attributes have been
          # added or deleted. Stores unused attributes so an exception can be
          # raised if the user decides to check that everything in the
          # checkpoint was loaded.
          self._checkpoint.unused_attributes.setdefault(
              self._proto_id, []).append(serialized_tensor.name)
          continue
        if saveables_cache is not None:
          saveables_cache.setdefault(self.trackable,
                                     {})[serialized_tensor.name] = [saveable]
      named_saveables[serialized_tensor.checkpoint_key] = saveable

    return existing_restore_ops, named_saveables

  def restore_ops(self, reader=None):
    """Create or fetch restore ops for this object's attributes.

    Requires that the `Trackable` Python object has been bound to an object
    ID in the checkpoint.

    Args:
      reader: A `CheckpointReader`. If None, a new instance will be created.

    Returns:
      A list of operations when graph building, or an empty list when executing
      eagerly.
    """
    if self._has_registered_saver():
      raise ValueError("Unable to run individual checkpoint restore for objects"
                       " with registered savers.")
    (restore_ops, tensor_saveables, python_positions,
     _) = self.gather_ops_or_named_saveables()
    restore_ops.extend(
        self._checkpoint.restore_saveables(
            tensor_saveables, python_positions, reader=reader))
    return restore_ops

  @property
  def checkpoint(self):
    return self._checkpoint

  @property
  def trackable(self):
    return self._checkpoint.object_by_proto_id[self._proto_id]

  @property
  def object_proto(self):
    return self._checkpoint.object_graph_proto.nodes[self._proto_id]

  @property
  def proto_id(self):
    return self._proto_id

  @property
  def restore_uid(self):
    return self._checkpoint.restore_uid

  def __repr__(self):
    return repr(self.object_proto)

  def value_shape(self):
    """The shape of the VARIABLE_VALUE tensor.

    Returns:
      A TensorShape object if found, otherwise None.
    """
    for serialized_tensor in self.object_proto.attributes:
      if serialized_tensor.name == constants.VARIABLE_VALUE_KEY:
        return self._checkpoint.shape_map[serialized_tensor.checkpoint_key]
    return None

  def _has_registered_saver(self):
    return bool(self.object_proto.registered_saver.name)

  def get_registered_saver_name(self):
    """Returns the registered saver name defined in the Checkpoint."""
    if self._has_registered_saver():
      saver_name = self.object_proto.registered_saver.name
      try:
        registration.validate_restore_function(self.trackable, saver_name)
      except ValueError as e:
        if registration.get_strict_predicate_restore(saver_name):
          raise e
        self.skip_restore = True
      return saver_name
    return None

  def create_slot_variable_position(self, optimizer_object, variable,
                                    slot_variable_id, slot_name):
    """Generates CheckpointPosition for a slot variable.

    Args:
      optimizer_object: Optimizer that owns the slot variable.
      variable: Variable associated with the slot variable.
      slot_variable_id: ID of the slot variable.
      slot_name: Name of the slot variable.

    Returns:
      If there is a slot variable in the `optimizer_object` that has not been
      bound to the checkpoint, this function returns a tuple of (
          new `CheckpointPosition` for the slot variable,
          the slot variable itself).
    """
    slot_variable_position = CheckpointPosition(
        checkpoint=self.checkpoint, proto_id=slot_variable_id)
    # pylint: disable=protected-access
    slot_variable = optimizer_object._create_or_restore_slot_variable(
        slot_variable_position=slot_variable_position,
        variable=variable,
        slot_name=slot_name)
    # pylint: enable=protected-access
    if (slot_variable is not None and
        slot_variable_position.bind_object(slot_variable)):
      return slot_variable_position, slot_variable
    else:
      return None, None

  def create_child_position(self, node_id):
    return CheckpointPosition(checkpoint=self.checkpoint, proto_id=node_id)

  def _restore_descendants(self, reader=None):
    """Restore the bound Trackable and dependencies (may be deferred)."""
    # Attempt a breadth-first traversal, since presumably the user has more
    # control over shorter paths. If we don't have all of the dependencies at
    # this point, the end result is not breadth-first (since other deferred
    # traversals will happen later).

    # You may be wondering why elements in the `visit_queue` are tuples that
    # contain both a CheckpointPosition and its Trackable. The reason is that
    # Optimizers will not keep a strong reference to slot vars for
    # ShardedVariables. The slot variable must be kept in memory until the
    # restore saveables have been created.
    visit_queue = collections.deque([(self, self.trackable)])
    restore_ops = []
    tensor_saveables = {}
    python_positions = []
    registered_savers = collections.defaultdict(dict)
    while visit_queue:
      current_position, _ = visit_queue.popleft()

      # Restore using the ops defined in a Saveable or registered function.
      (new_restore_ops, new_tensor_saveables, new_python_positions,
       new_registered_savers) = current_position._single_restore()  # pylint: disable=protected-access
      restore_ops.extend(new_restore_ops)
      tensor_saveables.update(new_tensor_saveables)
      python_positions.extend(new_python_positions)
      for saver_name, trackable_map in new_registered_savers.items():
        registered_savers[saver_name].update(trackable_map)

      # Pass the restoration to the dependencies.
      _queue_children_for_restoration(current_position, visit_queue)
      _queue_slot_variables(current_position, visit_queue)

    restore_ops.extend(
        current_position.checkpoint.restore_saveables(
            tensor_saveables,
            python_positions,
            registered_savers,
            reader=reader))
    return restore_ops

  def _single_restore(self):
    """Restores the trackable."""
    trackable = self.trackable
    trackable._maybe_initialize_trackable()  # pylint: disable=protected-access
    checkpoint = self.checkpoint
    # If the UID of this restore is lower than our current update UID, we don't
    # need to actually restore the object.
    if checkpoint.restore_uid > trackable._update_uid:  # pylint: disable=protected-access
      restore_ops, tensor_saveables, python_positions, registered_savers = (
          self.gather_ops_or_named_saveables())
      trackable._update_uid = checkpoint.restore_uid  # pylint: disable=protected-access
    else:
      restore_ops = ()
      tensor_saveables = {}
      python_positions = ()
      registered_savers = {}
    return restore_ops, tensor_saveables, python_positions, registered_savers


def restore_nodes(save_path, nodes_to_restore):
  """Restores nodes from a dict.

  Requires that the `Trackable` Python object has been bound to an object
  ID in the checkpoint.

  Args:
    save_path: a string that represents the path to the checkpoint.
    nodes_to_restore: a dict mapping `node_id` to the `trackable` to be
      restored.
  """

  if save_path is None:
    raise ValueError("save_path cannot be None.")
  if not isinstance(nodes_to_restore, dict):
    raise ValueError(
        "Expecting a dictionary of node_id to Trackable for nodes_to_restore.")

  ckpt_view = checkpoint_view.CheckpointView(save_path)
  ckpt_view_descendants = ckpt_view.descendants()
  for node_id, trackable in nodes_to_restore.items():
    # node_id does not have a corresponding Checkpoint value.
    if (node_id not in ckpt_view_descendants or
        ckpt_view._object_graph_proto.nodes[  # pylint: disable=protected-access
            node_id] is None):
      raise ValueError(
          f"The expected node_id: {node_id} to Trackable {trackable} to "
          "restore does not exist in the checkpoint.")
    # Trackable mapped to node_id to restore is empty.
    if trackable is None or not isinstance(trackable, base.Trackable):
      raise ValueError(
          f"Expecting a valid Trackable to node_id: {node_id} but got "
          f"trackable: {trackable}."
      )

  serialized_tensors = object_identity.ObjectIdentityDictionary()
  for node_id, current_trackable in nodes_to_restore.items():
    ckpt_contains_serialized_tensors = ckpt_view._object_graph_proto.nodes[  # pylint: disable=protected-access
        node_id].attributes
    node = ckpt_view._object_graph_proto.nodes[node_id]  # pylint: disable=protected-access
    trackable_has_serialize_to_tensor = saveable_object_util.trackable_has_serialize_to_tensor(
        current_trackable)
    if not trackable_has_serialize_to_tensor:
      if not node.attributes:
        if saveable_object_util.saveable_objects_from_trackable(
            current_trackable):
          raise ValueError(
              f"Trackable {current_trackable} expects checkpointed values but "
              "checkpoint does not contain serialized tensors for node_id: "
              f"{node_id}.")
        else:
          continue
      object_names = object_identity.ObjectIdentityDictionary()
      object_names[current_trackable] = trackable_utils.extract_object_name(
          node.attributes[0].checkpoint_key)
      checkpoint_factory_map, _ = save_util_v1.get_checkpoint_factories_and_keys(
          object_names, None)
      saveable_objects = save_util_v1.generate_saveable_objects(
          checkpoint_factory_map)[0]
      if len(node.attributes) != len(saveable_objects):
        raise ValueError("Size for saveable_objects for Trackable: "
                         f"{len(saveable_objects)} did not match the size for "
                         "serialized_tensors for checkpoint: "
                         f"{len(node.attributes)}.")
      current_trackable = saveable_object_util.SaveableCompatibilityConverter(
          current_trackable, saveable_objects)

    serialized_tensors[
        current_trackable] = current_trackable._serialize_to_tensors()  # pylint: disable=protected-access
    trackable_expects_ckpted_value = bool(serialized_tensors[current_trackable])

    if trackable_expects_ckpted_value and not ckpt_contains_serialized_tensors:
      raise ValueError(
          f"Trackable {current_trackable} expects checkpointed values but "
          "checkpoint does not contain serialized tensors for node_id: "
          f"{node_id}.")

    if not trackable_expects_ckpted_value and ckpt_contains_serialized_tensors:
      raise ValueError(
          f"Trackable {current_trackable} does not expect checkpointed "
          "values but checkpoint contains serialized tensors: "
          f"{ckpt_contains_serialized_tensors} for node_id: {node_id}.")

    if len(node.attributes) != len(serialized_tensors[current_trackable]):
      raise ValueError("Size for serialized_tensors for Trackable: "
                       f"{len(serialized_tensors[current_trackable])} did not "
                       "match size for serialized_tensors for checkpoint: "
                       f"{len(node.attributes)}.")

    if not trackable_has_serialize_to_tensor:
      functional_saver.MultiDeviceSaver(serialized_tensors).restore(save_path)
    else:
      # Converts attribute.name to attribute.checkpoint_key since that's what
      # the restore method expects, i.e., converts "a" to "/.ATTRIBUTES/a".
      serialized_tensors_renamed = object_identity.ObjectIdentityDictionary()
      serialized_tensors_renamed[current_trackable] = {}
      for attribute in node.attributes:
        name = attribute.name
        checkpoint_key = attribute.checkpoint_key
        serialized_tensors_renamed[current_trackable][
            checkpoint_key] = serialized_tensors[current_trackable][name]
      functional_saver.MultiDeviceSaver(serialized_tensors_renamed).restore(
          save_path)


def _queue_children_for_restoration(checkpoint_position, visit_queue):
  """Queues the restoration of trackable's children or defers them."""
  # pylint: disable=protected-access
  trackable = checkpoint_position.trackable
  for child in checkpoint_position.object_proto.children:
    # trackable._lookup_dependency can be expensive, so first check if this
    # node already has an object correspondence. If so, we skip this node.
    correspondence = checkpoint_position.checkpoint.object_by_proto_id.get(
        child.node_id, None
    )
    if correspondence is not None:
      continue
    child_position = checkpoint_position.create_child_position(child.node_id)
    local_object = trackable._lookup_dependency(child.local_name)
    child_proto = child_position.object_proto
    if local_object is None:
      # We don't yet have a dependency registered with this name. Save it
      # in case we do.
      if child_proto.HasField("has_checkpoint_values"):
        has_value = child_proto.has_checkpoint_values.value
      else:
        # If the field is not set, do a simple check to see if the dependency
        # has children and/or checkpointed values.
        has_value = bool(
            child_proto.children or child_proto.attributes or
            child_proto.slot_variables or
            child_proto.HasField("registered_saver"))
      if has_value:
        trackable._deferred_dependencies.setdefault(child.local_name,
                                                    []).append(child_position)
    else:
      if child_position.bind_object(trackable=local_object):
        # This object's correspondence is new, so dependencies need to be
        # visited. Delay doing it so that we get a breadth-first dependency
        # resolution order (shallowest paths first). The caller is responsible
        # for emptying visit_queue.
        visit_queue.append((child_position, local_object))


_DeferredSlotVariableRestoration = collections.namedtuple(
    "_DeferredSlotVariableRestoration", [
        "original_variable",
        "slot_variable_id",
        "slot_name",
    ])


def _queue_slot_variables(checkpoint_position, visit_queue):
  """Queues slot variables for restoration."""
  trackable = checkpoint_position.trackable
  checkpoint = checkpoint_position.checkpoint
  for deferred_slot_restoration in (checkpoint.deferred_slot_restorations.pop(
      checkpoint_position.proto_id, ())):
    slot_variable_position, slot_variable = (
        checkpoint_position.create_slot_variable_position(
            trackable, deferred_slot_restoration.original_variable,
            deferred_slot_restoration.slot_variable_id,
            deferred_slot_restoration.slot_name))
    if slot_variable_position is not None:
      visit_queue.append((slot_variable_position, slot_variable))
  for slot_restoration in checkpoint.slot_restorations.pop(
      checkpoint_position.proto_id, ()):
    optimizer_object = checkpoint.object_by_proto_id.get(
        slot_restoration.optimizer_id, None)
    if optimizer_object is None:
      # The optimizer has not yet been created or tracked. Record in the
      # checkpoint that the slot variables need to be restored when it is.
      checkpoint.deferred_slot_restorations.setdefault(
          slot_restoration.optimizer_id, []).append(
              _DeferredSlotVariableRestoration(
                  original_variable=trackable,
                  slot_variable_id=slot_restoration.slot_variable_id,
                  slot_name=slot_restoration.slot_name))

    # `optimizer_object` can be a `Checkpoint` when the user only needs the
    # attributes the optimizer holds, such as `iterations`. In those cases,
    # it would not have the optimizer's `_create_or_restore_slot_variable`
    # method.
    elif hasattr(optimizer_object, "_create_or_restore_slot_variable"):
      slot_variable_position, slot_variable = (
          checkpoint_position.create_slot_variable_position(
              optimizer_object, trackable, slot_restoration.slot_variable_id,
              slot_restoration.slot_name))
      if slot_variable_position is not None:
        visit_queue.append((slot_variable_position, slot_variable))


def _extract_saveable_name(checkpoint_key):
  # Truncate the checkpoint key to the end of the "{...}.ATTRIBUTES/"
  # substring.
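  # For example (illustrative key, assuming the standard attribute layout):
  #   _extract_saveable_name("model/dense/.ATTRIBUTES/VARIABLE_VALUE")
  #   would return "model/dense/.ATTRIBUTES/".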

  search_key = trackable_utils.OBJECT_ATTRIBUTES_NAME + "/"
  return checkpoint_key[:checkpoint_key.index(search_key) + len(search_key)]


def _get_saveable_from_factory(saveable_factories, serialized_tensor,
                               created_compat_names):
  """Returns the saveable generated from the factory method."""
  matched_factory = None

  # The `expected_factory_name` is used to find the right saveable factory,
  # while the `factory_input_name` is the value that is passed to the factory
  # method to instantiate the SaveableObject.
  expected_factory_name = serialized_tensor.name
  factory_input_name = serialized_tensor.checkpoint_key

  # Case 1: the name already exactly matches a key in saveable_factories.
  if expected_factory_name in saveable_factories:
    matched_factory = saveable_factories[expected_factory_name]

  # Case 2: (Forward compat) The serialized name is composed of
  # "factory_name" + "SUFFIX". Get the matching factory name.

  if matched_factory is None:
    for factory_name, factory in saveable_factories.items():
      if expected_factory_name.startswith(factory_name):
        if matched_factory is not None:
          # This condition is met in the extreme edge case where the object
          # returns two saveable factories with similar names. This is very
          # unlikely because there are zero objects inside TensorFlow that use
          # more than one saveable factory.
          raise ValueError("Forward compatibility load error: Unable to load "
                           "checkpoint saved in future version of TensorFlow. "
                           "Please update your version of TensorFlow to the "
                           "version in which the checkpoint was saved.")

        matched_factory = factory
        factory_input_name = _extract_saveable_name(
            serialized_tensor.checkpoint_key) + factory_name
        created_compat_names.add(factory_name)

  if callable(matched_factory):
    return matched_factory(name=factory_input_name)
  return matched_factory