Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/checkpoint/checkpoint.py: 21%

814 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1"""Utilities for saving/loading Trackable objects.""" 

2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

3# 

4# Licensed under the Apache License, Version 2.0 (the "License"); 

5# you may not use this file except in compliance with the License. 

6# You may obtain a copy of the License at 

7# 

8# http://www.apache.org/licenses/LICENSE-2.0 

9# 

10# Unless required by applicable law or agreed to in writing, software 

11# distributed under the License is distributed on an "AS IS" BASIS, 

12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

13# See the License for the specific language governing permissions and 

14# limitations under the License. 

15# ============================================================================== 

16 

17import abc 

18import collections 

19import functools 

20import glob 

21import os 

22import threading 

23import time 

24import weakref 

25 

26from tensorflow.core.protobuf import trackable_object_graph_pb2 

27from tensorflow.python.checkpoint import async_checkpoint_helper 

28from tensorflow.python.checkpoint import checkpoint_context 

29from tensorflow.python.checkpoint import checkpoint_management 

30from tensorflow.python.checkpoint import checkpoint_options 

31from tensorflow.python.checkpoint import functional_saver 

32from tensorflow.python.checkpoint import graph_view as graph_view_lib 

33from tensorflow.python.checkpoint import restore as restore_lib 

34from tensorflow.python.checkpoint import save_util 

35from tensorflow.python.checkpoint import save_util_v1 

36from tensorflow.python.checkpoint import util 

37from tensorflow.python.client import session as session_lib 

38from tensorflow.python.eager import context 

39from tensorflow.python.eager import def_function 

40from tensorflow.python.eager import monitoring 

41from tensorflow.python.framework import constant_op 

42from tensorflow.python.framework import dtypes 

43from tensorflow.python.framework import errors_impl 

44from tensorflow.python.framework import ops 

45from tensorflow.python.framework import tensor_shape 

46from tensorflow.python.framework import tensor_util 

47from tensorflow.python.lib.io import file_io 

48from tensorflow.python.ops import array_ops 

49from tensorflow.python.ops import gen_io_ops as io_ops 

50from tensorflow.python.ops import init_ops 

51from tensorflow.python.ops import variable_scope 

52from tensorflow.python.ops import variable_v1 

53from tensorflow.python.platform import gfile 

54from tensorflow.python.platform import tf_logging as logging 

55from tensorflow.python.saved_model import path_helpers 

56from tensorflow.python.saved_model.pywrap_saved_model import metrics 

57from tensorflow.python.trackable import autotrackable 

58from tensorflow.python.trackable import base 

59from tensorflow.python.trackable import data_structures 

60from tensorflow.python.training import py_checkpoint_reader 

61from tensorflow.python.training import saver as v1_saver_lib 

62from tensorflow.python.training.saving import saveable_object as saveable_object_lib 

63from tensorflow.python.training.saving import saveable_object_util 

64from tensorflow.python.util import compat 

65from tensorflow.python.util import deprecation 

66from tensorflow.python.util import object_identity 

67from tensorflow.python.util import tf_contextlib 

68from tensorflow.python.util import tf_inspect 

69from tensorflow.python.util.tf_export import tf_export 

70 

71 

72# The callable that provides the Keras default session needed for saving. 

73_SESSION_PROVIDER = None 

74 

75# Captures the timestamp of the first Checkpoint instantiation or the end of 

76# the most recent write operation. Shared across Checkpoint instances. 

77_END_TIME_OF_LAST_WRITE = None 

78_END_TIME_OF_LAST_WRITE_LOCK = threading.Lock() 

79 

80# API labels for cell names used in checkpoint metrics. 

81_CHECKPOINT_V1 = "checkpoint_v1" 

82_CHECKPOINT_V2 = "checkpoint_v2" 

83 

84# Async thread used for asynchronous checkpointing. 

85_ASYNC_CHECKPOINT_THREAD = None 

86 

87 

88def _get_duration_microseconds(start_time_seconds, end_time_seconds): 

89 if end_time_seconds < start_time_seconds: 

90 # Avoid returning negative value in case of clock skew. 

91 return 0 

92 return round((end_time_seconds - start_time_seconds) * 1000000) 
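
A minimal standalone sketch of the behaviour above (the names below are illustrative, not part of this module): durations are rounded to whole microseconds, and a negative interval caused by clock skew is reported as zero.

```python
def duration_us(start_s, end_s):
    # Same guard as _get_duration_microseconds: never report a negative
    # duration when the end timestamp precedes the start (clock skew).
    if end_s < start_s:
        return 0
    return round((end_s - start_s) * 1_000_000)

assert duration_us(2.0, 2.5) == 500_000
assert duration_us(3.0, 2.0) == 0
```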

93 

94 

95@tf_export("__internal__.tracking.register_session_provider", v1=[]) 

96def register_session_provider(session_provider): 

97 global _SESSION_PROVIDER 

98 # TODO(scottzhu): Change it back to only allow one time setting for session 

99 # provider once we finished the keras repo split. 

100 # if _SESSION_PROVIDER is None: 

101 _SESSION_PROVIDER = session_provider 

102 

103 

104def get_session(): 

105 # Prefer TF's default session since get_session from Keras has side-effects. 

106 session = ops.get_default_session() 

107 if session is None: 

108 global _SESSION_PROVIDER 

109 if _SESSION_PROVIDER is not None: 

110 session = _SESSION_PROVIDER() # pylint: disable=not-callable 

111 return session 
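
A standalone mimic of the provider fallback implemented by `register_session_provider`/`get_session` (all names below are illustrative, not TensorFlow APIs): an explicitly available session wins, otherwise the registered provider, e.g. Keras, is asked for one.

```python
_PROVIDER = None  # stand-in for _SESSION_PROVIDER

def register_provider(provider):
    global _PROVIDER
    _PROVIDER = provider

def current_session(default=None):
    # Mirrors get_session(): prefer the active default session, otherwise
    # fall back to whatever provider has been registered.
    if default is not None:
        return default
    if _PROVIDER is not None:
        return _PROVIDER()
    return None

register_provider(lambda: "keras-provided-session")
assert current_session() == "keras-provided-session"
assert current_session(default="tf-default-session") == "tf-default-session"
```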

112 

113 

114def _get_checkpoint_size(prefix): 

115 """Calculates filesize of checkpoint based on prefix.""" 

116 size = 0 

117 # Gather all files beginning with prefix (.index plus sharded data files). 

118 files = glob.glob("{}*".format(prefix)) 

119 for file in files: 

120 # Use TensorFlow's C++ FileSystem API. 

121 size += metrics.CalculateFileSize(file) 

122 return size 
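
A rough standard-library equivalent of `_get_checkpoint_size`, assuming local filesystem paths (the real helper goes through TensorFlow's C++ FileSystem API and therefore also handles remote filesystems):

```python
import glob
import os

def checkpoint_size_bytes(prefix):
    # Sum the sizes of every file sharing the checkpoint prefix: the .index
    # file plus the sharded .data-XXXXX-of-YYYYY files.
    return sum(os.path.getsize(path) for path in glob.glob(prefix + "*"))

# Hypothetical usage with a prefix returned by tf.train.Checkpoint.save():
# checkpoint_size_bytes("/tmp/ckpt/model-1")
```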

123 

124 

125class ObjectGraphProtoPrettyPrinter: 

126 """Lazily traverses an object graph proto to pretty print names. 

127 

128 If no calls to `node_names` are made this object has no performance 

129 overhead. On the other hand, it will only traverse the object graph once, so 

130 repeated naming is cheap after the first. 

131 """ 

132 

133 __slots__ = ["_object_graph_proto", "_node_name_cache"] 

134 

135 def __init__(self, object_graph_proto): 

136 self._object_graph_proto = object_graph_proto 

137 self._node_name_cache = None 

138 

139 @property 

140 def node_names(self): 

141 """Lazily creates a mapping from node id to ("path", "to", "root").""" 

142 if self._node_name_cache is not None: 

143 return self._node_name_cache 

144 path_to_root = {} 

145 path_to_root[0] = ("(root)",) 

146 to_visit = collections.deque([0]) 

147 while to_visit: 

148 node_id = to_visit.popleft() 

149 obj = self._object_graph_proto.nodes[node_id] 

150 for child in obj.children: 

151 if child.node_id not in path_to_root: 

152 path_to_root[child.node_id] = ( 

153 path_to_root[node_id] + (child.local_name,)) 

154 to_visit.append(child.node_id) 

155 

156 node_names = {} 

157 for node_id, path_to_root in path_to_root.items(): 

158 node_names[node_id] = ".".join(path_to_root) 

159 

160 for node_id, node in enumerate(self._object_graph_proto.nodes): 

161 for slot_reference in node.slot_variables: 

162 node_names[slot_reference.slot_variable_node_id] = ( 

163 f"{node_names[node_id]}'s state '{slot_reference.slot_name}' for " 

164 f"{node_names[slot_reference.original_variable_node_id]}") 

165 self._node_name_cache = node_names 

166 return node_names 
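
The breadth-first naming in `node_names` can be reproduced with plain Python objects; this self-contained sketch uses namedtuples as stand-ins for the proto nodes (only the fields the traversal reads):

```python
from collections import deque, namedtuple

Child = namedtuple("Child", ["node_id", "local_name"])
Node = namedtuple("Node", ["children"])

nodes = [
    Node(children=[Child(1, "model"), Child(2, "optimizer")]),  # 0: the root
    Node(children=[Child(3, "dense")]),                         # 1: root.model
    Node(children=[]),                                          # 2: root.optimizer
    Node(children=[]),                                          # 3: root.model.dense
]

# Breadth-first walk from the root, recording the dotted path to each node.
path_to_root = {0: ("(root)",)}
to_visit = deque([0])
while to_visit:
    node_id = to_visit.popleft()
    for child in nodes[node_id].children:
        if child.node_id not in path_to_root:
            path_to_root[child.node_id] = path_to_root[node_id] + (child.local_name,)
            to_visit.append(child.node_id)

node_names = {node_id: ".".join(path) for node_id, path in path_to_root.items()}
assert node_names[3] == "(root).model.dense"
```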

167 

168 

169class _CheckpointRestoreCoordinatorDeleter: 

170 """Deleter to avoid overriding _CheckpointRestoreCoordinator.__del__().""" 

171 

172 __slots__ = [ 

173 "expect_partial", "object_graph_proto", "matched_proto_ids", 

174 "unused_attributes" 

175 ] 

176 

177 def __init__(self, expect_partial, object_graph_proto, matched_proto_ids, 

178 unused_attributes): 

179 self.expect_partial = expect_partial 

180 self.object_graph_proto = object_graph_proto 

181 self.matched_proto_ids = matched_proto_ids 

182 self.unused_attributes = unused_attributes 

183 

184 def set_expect_partial(self, expect_partial): 

185 self.expect_partial = expect_partial 

186 

187 def __del__(self): 

188 if self.expect_partial: 

189 return 

190 if logging is None: 

191 # The logging module may have been unloaded when __del__ is called. 

192 log_fn = print 

193 else: 

194 log_fn = logging.warning 

195 unused_nodes_in_checkpoint = [] 

196 unrestored_attributes_in_object = [] 

197 pretty_printer = ObjectGraphProtoPrettyPrinter(self.object_graph_proto) 

198 for node_id, node in enumerate(self.object_graph_proto.nodes): 

199 if not node.attributes: 

200 continue 

201 if node_id not in self.matched_proto_ids: 

202 unused_nodes_in_checkpoint.append(pretty_printer.node_names[node_id]) 

203 for node_id, attribute_name in self.unused_attributes.items(): 

204 unrestored_attributes_in_object.append(( 

205 pretty_printer.node_names[node_id], attribute_name)) 

206 if unused_nodes_in_checkpoint or unrestored_attributes_in_object: 

207 # pylint:disable=line-too-long 

208 log_fn("Detecting that an object or model or tf.train.Checkpoint is being" 

209 " deleted with unrestored values. See the following logs for the " 

210 "specific values in question. To silence these warnings, use " 

211 "`status.expect_partial()`. See " 

212 "https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restore" 

213 "for details about the status object returned by the restore " 

214 "function.") 

215 # pylint:enable=line-too-long 

216 for node_path in unused_nodes_in_checkpoint: 

217 log_fn("Value in checkpoint could not be found in the restored object: " 

218 f"{node_path}") 

219 for node_path, attr in unrestored_attributes_in_object: 

220 log_fn("An attribute in the restored object could not be found in the " 

221 f"checkpoint. Object: {node_path}, attribute: {attr}") 

222 

223 

224class _CheckpointRestoreCoordinator: 

225 """Holds the status of an object-based checkpoint load.""" 

226 

227 def __init__(self, object_graph_proto, save_path, save_path_tensor, reader, 

228 restore_op_cache, graph_view, options, saveables_cache): 

229 """Specify the checkpoint being loaded. 

230 

231 Args: 

232 object_graph_proto: The TrackableObjectGraph protocol buffer associated 

233 with this checkpoint. 

234 save_path: A string, the path to the checkpoint, as returned by 

235 `tf.train.latest_checkpoint`. 

236 save_path_tensor: A string `Tensor` which contains or will be fed the save 

237 path. 

238 reader: A `CheckpointReader` for `save_path`. If None, 

239 `_CheckpointRestoreCoordinator` will initialize one itself. 

240 restore_op_cache: A dictionary shared between 

241 `_CheckpointRestoreCoordinator`s for the same Python objects, used to 

242 look up restore ops by name to avoid re-creating them across multiple 

243 `restore()` calls. 

244 graph_view: A graph_view_lib.ObjectGraphView object for the restored 

245 objects. 

246 options: A CheckpointOptions object. 

247 saveables_cache: An optional cache storing previously created 

248 SaveableObjects created for each Trackable. Maps Trackables to a 

249 dictionary of attribute names to Trackable. 

250 """ 

251 self.options = options 

252 self.object_graph_proto = object_graph_proto 

253 self.restore_uid = ops.uid() 

254 # Maps from proto ids to lists of attributes which were in the checkpoint 

255 # but not loaded into any object, for error checking. 

256 self.unused_attributes = {} 

257 # Dictionary mapping from an id in the protocol buffer flat array to 

258 # Trackable Python objects. This mapping may be deferred if a 

259 # checkpoint is restored before all dependencies have been tracked. Uses 

260 # weak references so that partial restorations don't create reference cycles 

261 # (as objects with deferred dependencies will generally have references to 

262 # this object). 

263 self.object_by_proto_id = weakref.WeakValueDictionary() 

264 self.matched_proto_ids = set() 

265 # A set of all Python objects we've seen as dependencies, even if we didn't 

266 # use them (for example because of inconsistent references when 

267 # loading). Used to make status assertions fail when loading checkpoints 

268 # that don't quite match. 

269 self.all_python_objects = object_identity.ObjectIdentityWeakSet() 

270 self.save_path_tensor = save_path_tensor 

271 self.save_path_string = save_path 

272 self.dtype_map = reader.get_variable_to_dtype_map() 

273 self.shape_map = reader.get_variable_to_shape_map() 

274 # A NewCheckpointReader for the most recent checkpoint, for streaming Python 

275 # state restoration. 

276 # When graph building, contains a list of ops to run to restore objects from 

277 # this checkpoint. 

278 self.restore_ops = [] 

279 self.restore_ops_by_name = restore_op_cache 

280 self.graph_view = graph_view 

281 self.new_restore_ops_callback = None 

282 # A mapping from optimizer proto ids to lists of slot variables to be 

283 # restored when the optimizer is tracked. Only includes slot variables whose 

284 # regular variables have already been created, and only for optimizer 

285 # objects which have not yet been created/tracked. 

286 self.deferred_slot_restorations = {} 

287 # A mapping from variable proto ids to lists of slot variables to be 

288 # restored when the variable is created/tracked. These get shifted over to 

289 # deferred_slot_restorations if the optimizer hasn't been created when that 

290 # happens. 

291 self.slot_restorations = {} 

292 # Controls whether errors are printed in __del__ if some objects did not 

293 # match. 

294 self.expect_partial_attr = False 

295 for node_index, node in enumerate(self.object_graph_proto.nodes): 

296 for slot_reference in node.slot_variables: 

297 # `node` refers to an `Optimizer`, since only these have slot variables. 

298 self.slot_restorations.setdefault( 

299 slot_reference.original_variable_node_id, []).append( 

300 base._SlotVariableRestoration( # pylint: disable=protected-access 

301 optimizer_id=node_index, 

302 slot_variable_id=slot_reference.slot_variable_node_id, 

303 slot_name=slot_reference.slot_name)) 

304 

305 self._deleter = _CheckpointRestoreCoordinatorDeleter( 

306 self.expect_partial_attr, 

307 self.object_graph_proto, 

308 self.matched_proto_ids, 

309 self.unused_attributes) 

310 

311 self.saveables_cache = saveables_cache 

312 

313 @property 

314 def expect_partial(self): 

315 return self.expect_partial_attr 

316 

317 @expect_partial.setter 

318 def expect_partial(self, expect_partial): 

319 self.expect_partial_attr = expect_partial 

320 self._deleter.set_expect_partial(expect_partial) 

321 

322 def new_restore_ops(self, new_ops): 

323 self.restore_ops.extend(new_ops) 

324 if self.new_restore_ops_callback: 

325 self.new_restore_ops_callback(new_ops) # pylint: disable=not-callable 

326 

327 def restore_saveables( 

328 self, 

329 tensor_saveables, 

330 python_positions, 

331 registered_savers=None, 

332 reader=None, 

333 ): 

334 """Run or build restore operations for SaveableObjects. 

335 

336 Args: 

337 tensor_saveables: `SaveableObject`s which correspond to Tensors. 

338 python_positions: List of CheckpointPositions bound to `PythonState` 

339 objects which must be restored eagerly. 

340 registered_savers: A dict mapping saver names -> object names -> Trackable objects. 

341 reader: A `CheckpointReader`. If None, a new instance will be created. 

342 

343 Returns: 

344 When graph building, a list of restore operations, either cached or newly 

345 created, to restore `tensor_saveables`. 

346 """ 

347 if reader is None: 

348 reader = py_checkpoint_reader.NewCheckpointReader(self.save_path_string) 

349 

350 restore_ops = [] 

351 # Eagerly run restorations for Python state. 

352 for position in python_positions: 

353 key = position.object_proto.attributes[0].checkpoint_key 

354 position.trackable.deserialize(reader.get_tensor(key)) 

355 

356 # If we have new SaveableObjects, extract and cache restore ops. 

357 if tensor_saveables or registered_savers: 

358 flat_saveables = saveable_object_util.validate_and_slice_inputs( 

359 tensor_saveables) 

360 new_restore_ops = functional_saver.MultiDeviceSaver.from_saveables( 

361 flat_saveables, 

362 registered_savers).restore(self.save_path_tensor, self.options) 

363 if not context.executing_eagerly(): 

364 for name, restore_op in sorted(new_restore_ops.items()): 

365 restore_ops.append(restore_op) 

366 assert name not in self.restore_ops_by_name 

367 self.restore_ops_by_name[name] = restore_op 

368 return restore_ops 

369 

370 

371class _NameBasedRestoreCoordinator: 

372 """Keeps the status of a name-based checkpoint restore.""" 

373 

374 def __init__(self, save_path, dtype_map=None): 

375 self.save_path = save_path 

376 self.dtype_map = dtype_map 

377 # A map from trackable objects to unused attribute names. We don't have 

378 # proto IDs when doing a name-based restore, so the map keys differ from 

379 # those in _CheckpointRestoreCoordinator. 

380 self.unused_attributes = object_identity.ObjectIdentityWeakKeyDictionary() 

381 self.restore_uid = ops.uid() 

382 

383 def globally_named_object_attributes(self, trackable): 

384 """Create globally named SaveableObjects from attributes. 

385 

386 If an object's attribute has no global name specified (default construction 

387 for the SaveableObject factory), records the failure in 

388 `self.unused_attributes` (which can then be used to make status assertions 

389 fail; see `NameBasedSaverStatus`). 

390 

391 Args: 

392 trackable: An object to save. 

393 

394 Yields: 

395 SaveableObjects for `trackable`'s attributes. 

396 """ 

397 for ( 

398 attribute_name, 

399 saveable_factory, 

400 ) in saveable_object_util.saveable_objects_from_trackable( 

401 trackable, tf1_saver=True, 

402 ).items(): 

403 if callable(saveable_factory): 

404 try: 

405 # This saveable object factory does not have a default name= argument, 

406 # which means there's no way to save/restore it using a name-based 

407 # checkpoint. Ignore the error now and make sure assert_consumed() 

408 # fails. 

409 saveable = saveable_factory() 

410 except TypeError: 

411 self.unused_attributes.setdefault(trackable, 

412 []).append(attribute_name) 

413 continue 

414 else: 

415 saveable = saveable_factory 

416 names_to_saveables = saveable_object_util.op_list_to_dict( 

417 [saveable], convert_variable_to_tensor=False) 

418 for name, op in names_to_saveables.items(): 

419 for saveable_object in saveable_object_util.saveable_objects_for_op( 

420 op=op, name=name): 

421 yield saveable_object 

422 

423 def eager_restore(self, trackable): 

424 """Runs restore ops for `trackable`'s attributes.""" 

425 # When graph building, we don't add any restore ops to the graph until 

426 # run_restore_ops/initialize_or_restore on the status object for name-based 

427 # checkpoints. 

428 assert context.executing_eagerly() 

429 for saveable in self.globally_named_object_attributes(trackable): 

430 restored_tensors = [] 

431 tensor_missing = False 

432 for spec in saveable.specs: 

433 if spec.name in self.dtype_map: 

434 with ops.device("cpu:0"): 

435 restored, = io_ops.restore_v2( 

436 prefix=self.save_path, 

437 tensor_names=[spec.name], 

438 shape_and_slices=[""], 

439 dtypes=[self.dtype_map[spec.name]], 

440 name="%s_checkpoint_read" % (spec.name,)) 

441 restored_tensors.append(array_ops.identity(restored)) 

442 else: 

443 tensor_missing = True 

444 

445 if tensor_missing: 

446 # Record that this variable didn't match so assertions will fail. 

447 self.unused_attributes.setdefault(trackable, []).append(saveable.name) 

448 else: 

449 # Ignores values missing from the checkpoint, as with object-based 

450 # restore. Status assertions can be used to check exact matches, 

451 # although it's unlikely to ever happen for name-based checkpoints. 

452 saveable.restore( 

453 restored_tensors=restored_tensors, restored_shapes=None) 

454 

455 

456# TODO(allenl): If this ends up in a public API, consider adding LINT.IfChange 

457# or consolidating the implementation with get_variable. 

458def _default_getter(name, 

459 shape, 

460 dtype, 

461 initializer=None, 

462 partition_info=None, 

463 **kwargs): 

464 """A pared-down version of get_variable which does not reuse variables.""" 

465 dtype = dtypes.as_dtype(dtype) 

466 shape_object = tensor_shape.as_shape(shape) 

467 with ops.init_scope(): 

468 if initializer is None: 

469 initializer, initializing_from_value = ( 

470 variable_scope._get_default_variable_store()._get_default_initializer( # pylint: disable=protected-access 

471 name=name, 

472 shape=shape_object, 

473 dtype=dtype)) 

474 else: 

475 initializing_from_value = not callable(initializer) 

476 # Same logic as get_variable 

477 variable_dtype = dtype.base_dtype 

478 if initializing_from_value: 

479 if shape is not None: 

480 raise ValueError("If initializer is a constant, do not specify shape.") 

481 initial_value = initializer 

482 else: 

483 # Instantiate initializer if provided initializer is a type object. 

484 if isinstance(initializer, type(init_ops.Initializer)): 

485 initializer = initializer(dtype=dtype) 

486 shape_list = None if shape is None else shape_object.as_list() 

487 if "partition_info" in tf_inspect.getargspec(initializer).args: 

488 initial_value = functools.partial(initializer, 

489 shape_list, 

490 dtype=dtype, 

491 partition_info=partition_info) 

492 else: 

493 initial_value = functools.partial(initializer, 

494 shape_list, 

495 dtype=dtype) 

496 

497 return variable_v1.VariableV1( 

498 initial_value=initial_value, 

499 name=name, 

500 dtype=variable_dtype, 

501 use_resource=True, 

502 **kwargs) 

503 

504 

505def add_variable(trackable, 

506 name, 

507 shape=None, 

508 dtype=dtypes.float32, 

509 initializer=None, 

510 trainable=True): 

511 """Add a variable to a Trackable with no scope influence.""" 

512 return trackable._add_variable_with_custom_getter( # pylint: disable=protected-access 

513 name=name, 

514 shape=shape, 

515 dtype=dtype, 

516 initializer=initializer, 

517 getter=_default_getter, 

518 trainable=trainable) 
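
A hedged usage sketch for `add_variable`, importing this (non-public) module directly; the checkpoint path is hypothetical. The helper tracks the new variable under the given local name without consulting the active variable scope:

```python
import tensorflow as tf
from tensorflow.python.checkpoint import checkpoint as checkpoint_lib  # this module

root = tf.Module()
bias = checkpoint_lib.add_variable(
    root, name="bias", shape=[3], initializer=tf.zeros_initializer())

# The variable is now a tracked dependency of `root`, so it is serialized under
# the local name "bias" when the module is checkpointed.
tf.train.Checkpoint(module=root).write("/tmp/add_variable_demo/ckpt")
```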

519 

520 

521def object_metadata(save_path): 

522 """Retrieves information about the objects in a checkpoint. 

523 

524 Example usage: 

525 

526 ```python 

527 object_graph = tf.contrib.checkpoint.object_metadata( 

528 tf.train.latest_checkpoint(checkpoint_directory)) 

529 ckpt_variable_names = set() 

530 for node in object_graph.nodes: 

531 for attribute in node.attributes: 

532 ckpt_variable_names.add(attribute.full_name) 

533 ``` 

534 

535 Args: 

536 save_path: The path to the checkpoint, as returned by `save` or 

537 `tf.train.latest_checkpoint`. 

538 

539 Returns: 

540 A parsed `tf.contrib.checkpoint.TrackableObjectGraph` protocol buffer. 

541 Raises: 

542 ValueError: If an object graph was not found in the checkpoint. 

543 """ 

544 reader = py_checkpoint_reader.NewCheckpointReader(save_path) 

545 try: 

546 object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY) 

547 except errors_impl.NotFoundError: 

548 raise ValueError( 

549 f"The specified checkpoint \"{save_path}\" does not appear to be " 

550 "object-based (saved with TF2) since it is missing the key " 

551 f"\"{base.OBJECT_GRAPH_PROTO_KEY}\". Likely it was created with the " 

552 "TF1 name-based saver and does not contain an object dependency graph.") 

553 object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph()) 

554 object_graph_proto.ParseFromString(object_graph_string) 

555 return object_graph_proto 
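
The docstring example above still references the removed `tf.contrib` namespace; a hedged TF2-era sketch, importing this (non-public) module directly and using a hypothetical path, looks like this:

```python
import tensorflow as tf
from tensorflow.python.checkpoint import checkpoint as checkpoint_lib  # this module

path = tf.train.Checkpoint(v=tf.Variable([1.0, 2.0])).write("/tmp/meta_demo/ckpt")
graph_proto = checkpoint_lib.object_metadata(path)

# Collect the checkpoint keys recorded for each node's serialized attributes;
# for the single tracked variable this includes "v/.ATTRIBUTES/VARIABLE_VALUE".
ckpt_keys = {attribute.checkpoint_key
             for node in graph_proto.nodes
             for attribute in node.attributes}
```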

556 

557 

558def list_objects(root_trackable): 

559 """Traverse the object graph and list all accessible objects. 

560 

561 Looks for `Trackable` objects which are dependencies of 

562 `root_trackable`. Includes slot variables only if the variable they are 

563 slotting for and the optimizer are dependencies of `root_trackable` 

564 (i.e. if they would be saved with a checkpoint). 

565 

566 Args: 

567 root_trackable: A `Trackable` object whose dependencies should be flattened. 

568 

569 Returns: 

570 A flat list of objects. 

571 """ 

572 return util.list_objects(graph_view_lib.ObjectGraphView(root_trackable)) 
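
A small usage sketch (again importing this non-public module directly): `list_objects` flattens the dependency graph rooted at the given object, so nested trackables and their variables all appear in the returned list.

```python
import tensorflow as tf
from tensorflow.python.checkpoint import checkpoint as checkpoint_lib  # this module

inner = tf.train.Checkpoint(w=tf.Variable(2.0))
root = tf.train.Checkpoint(v=tf.Variable(1.0), inner=inner)

objs = checkpoint_lib.list_objects(root)
assert any(o is inner for o in objs)
assert any(o is root.v for o in objs) and any(o is inner.w for o in objs)
```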

573 

574 

575def gather_initializers(root_trackable): 

576 """Traverse the object graph and find initialization ops. 

577 

578 Looks for `Trackable` objects which are dependencies of 

579 `root_trackable` and which have an `initializer` property. Includes 

580 initializers for slot variables only if the variable they are slotting for and 

581 the optimizer are dependencies of `root_trackable` (i.e. if they would be 

582 saved with a checkpoint). 

583 

584 Args: 

585 root_trackable: A `Trackable` object to gather initializers for. 

586 

587 Returns: 

588 A list of initialization ops. 

589 """ 

590 trackable_objects = list_objects(root_trackable) 

591 return [ 

592 c.initializer 

593 for c in trackable_objects 

594 if hasattr(c, "initializer") and c.initializer is not None 

595 ] 
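
`gather_initializers` is mainly useful when building a TF1-style graph, where variables are not initialized on creation. A hedged standalone sketch (eager execution disabled for the whole program, module imported directly):

```python
import tensorflow.compat.v1 as tf1
from tensorflow.python.checkpoint import checkpoint as checkpoint_lib  # this module

tf1.disable_eager_execution()
with tf1.Graph().as_default():
    root = tf1.train.Checkpoint(v=tf1.Variable(3.0))
    init_ops_list = checkpoint_lib.gather_initializers(root)  # e.g. [v.initializer]
    with tf1.Session() as sess:
        sess.run(init_ops_list)
        assert sess.run(root.v) == 3.0
```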

596 

597 

598@tf_contextlib.contextmanager 

599def capture_dependencies(template): 

600 """Capture variables created within this scope as `Template` dependencies. 

601 

602 Requires that `template.variable_scope` is active. 

603 

604 This scope is intended as a compatibility measure, allowing a trackable 

605 object to add dependencies on variables created in a block of code which is 

606 not aware of object-based saving (and instead uses variable names 

607 heavily). This is how `Template` objects add dependencies on variables and 

608 sub-`Template`s. Where possible, use `tf.compat.v1.make_template` directly. 

609 

610 Args: 

611 template: The `Template` object to register dependencies with. 

612 

613 Yields: 

614 None (when used as a context manager). 

615 """ 

616 name_prefix = template.variable_scope.name 

617 

618 def _trackable_custom_creator(next_creator, 

619 name, 

620 initial_value, 

621 trackable_parent=None, 

622 **kwargs): 

623 """A variable creation hook which adds Trackable dependencies. 

624 

625 Set for example during a `Template`'s first wrapped function 

626 execution. Ensures that (a) `template` depends on any trackable 

627 objects using their own `capture_dependencies` scope inside this scope which 

628 create variables, and (b) that any variables not in a more deeply nested 

629 scope are added as dependencies directly. 

630 

631 The `trackable_parent` argument is passed between custom creators but 

632 ignored when the variable object itself is created. This argument indicates 

633 (if not `None`) that a more deeply nested scope has already added the 

634 variable as a dependency, and that parent scopes should add a dependency on 

635 that object rather than on the variable directly. 

636 

637 Args: 

638 next_creator: See `variable_scope.variable_creator_scope`; the next 

639 creator in the chain. 

640 name: The (full, scope-influenced) name of the variable. The `name_prefix` 

641 itself is stripped for the purposes of object-based dependency tracking, 

642 but scopes opened within this scope are respected. 

643 initial_value: See `variable_scope.variable_creator_scope`. Taken 

644 explicitly so the argument can be re-named and used with 

645 `Trackable._add_variable_with_custom_getter`. 

646 trackable_parent: If not None, a more deeply nested trackable object and 

647 its name prefix which were passed to `capture_dependencies` to add a 

648 dependency on (rather than depending on the variable directly). 

649 **kwargs: Passed through to the next creator. 

650 

651 Returns: 

652 The output of `next_creator`: the fetched/created variable object. 

653 """ 

654 

655 def _call_next_creator_renaming_initializer(initializer, **inner_kwargs): 

656 inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which 

657 # we don't want to propagate. 

658 return next_creator(initial_value=initializer, name=name, **inner_kwargs) 

659 

660 if name is not None and name.startswith(name_prefix): 

661 scope_stripped_name = name[len(name_prefix) + 1:] 

662 if not trackable_parent: 

663 return template._add_variable_with_custom_getter( # pylint: disable=protected-access 

664 initializer=initial_value, 

665 name=scope_stripped_name, 

666 getter=_call_next_creator_renaming_initializer, 

667 # Disable error checking for Trackable. Exceptions are instead 

668 # raised if necessary when the object-based saver tries to 

669 # save/restore the object. 

670 overwrite=True, 

671 trackable_parent=(template, name_prefix), 

672 **kwargs) 

673 else: 

674 parent_object, parent_name_prefix = trackable_parent 

675 template._track_trackable( # pylint: disable=protected-access 

676 parent_object, 

677 name=parent_name_prefix[len(name_prefix) + 1:], 

678 overwrite=True) 

679 return next_creator( 

680 name=name, 

681 initial_value=initial_value, 

682 trackable_parent=(template, name_prefix), 

683 **kwargs) 

684 

685 with variable_scope.variable_creator_scope(_trackable_custom_creator): 

686 yield 

687 

688 

689class _LoadStatus: 

690 """Abstract base for load status callbacks.""" 

691 

692 @abc.abstractmethod 

693 def assert_consumed(self): 

694 """Raises an exception unless a non-trivial restoration has completed.""" 

695 pass 

696 

697 @abc.abstractmethod 

698 def assert_existing_objects_matched(self): 

699 """Raises an exception unless existing Python objects have been matched.""" 

700 pass 

701 

702 @abc.abstractmethod 

703 def assert_nontrivial_match(self): 

704 """Raises an exception if only the root object matched.""" 

705 pass 

706 

707 @abc.abstractmethod 

708 def run_restore_ops(self, session=None): 

709 """Runs restore ops from the checkpoint. Requires a valid checkpoint.""" 

710 pass 

711 

712 @abc.abstractmethod 

713 def initialize_or_restore(self, session=None): 

714 """Runs restore ops from the checkpoint, or initializes variables.""" 

715 pass 

716 

717 def expect_partial(self): 

718 """Silence warnings about incomplete checkpoint restores.""" 

719 return self 

720 

721 

722@tf_export("__internal__.tracking.streaming_restore", v1=[]) 

723def streaming_restore(status, session=None): 

724 """When graph building, runs restore ops as soon as they come in. 

725 

726 Args: 

727 status: A _LoadStatus object from an object-based saver's restore(). 

728 Streaming restore from name-based checkpoints is not currently supported. 

729 session: A session to run new restore ops in. 

730 """ 

731 if context.executing_eagerly(): 

732 # Streaming restore is the default/only behavior when executing eagerly. 

733 return 

734 if session is None: 

735 session = get_session() 

736 if isinstance(status, NameBasedSaverStatus): 

737 raise NotImplementedError( 

738 "Streaming restore not supported from name-based checkpoints when " 

739 "graph building. File a feature request if this limitation bothers " 

740 "you. As a workaround, consider either using tf.train.Checkpoint to " 

741 "load name-based checkpoints or enabling eager execution.") 

742 status.run_restore_ops(session=session) 

743 # pylint: disable=protected-access 

744 status._checkpoint.new_restore_ops_callback = ( 

745 lambda ops: session.run(ops, feed_dict=status._feed_dict)) 

746 # pylint: enable=protected-access 

747 

748 

749def _objects_with_attributes(full_list): 

750 """Filters out objects with no direct variable dependencies for assertions.""" 

751 return [ 

752 o for o in full_list 

753 if saveable_object_util.saveable_objects_from_trackable(o) 

754 ] 

755 

756 

757class CheckpointLoadStatus(_LoadStatus): 

758 """Checks the status of checkpoint loading and manages restore ops. 

759 

760 Returned from `Saver.restore`. Since `restore` may defer the loading of values 

761 in the checkpoint which don't yet have corresponding Python objects, 

762 `CheckpointLoadStatus` provides a callback to verify that checkpoint loading 

763 is complete (`assert_consumed`). 

764 

765 When graph building, `restore` does not run restore ops itself since their 

766 creation may be deferred. The `run_restore_ops` method must be called once all 

767 Python objects with values to restore have been created and added to the 

768 dependency graph (this does not necessarily have to be the whole checkpoint; 

769 calling `run_restore_ops` while `assert_consumed` fails is supported and will 

770 partially restore the checkpoint). 

771 

772 See `Saver.restore` for usage examples. 

773 """ 

774 

775 def __init__(self, checkpoint, feed_dict, graph_view): 

776 self._checkpoint = checkpoint 

777 self._feed_dict = feed_dict 

778 self._object_graph_view = graph_view 

779 # Keep a reference to the root, since object_graph_view might only have a 

780 # weakref. 

781 self._root = graph_view.root 

782 

783 def assert_consumed(self): 

784 """Asserts that all objects in the checkpoint have been created/matched. 

785 

786 Returns: 

787 `self` for chaining. 

788 Raises: 

789 AssertionError: If there are any Python objects in the dependency graph 

790 which have not been restored from this checkpoint or a later `restore`, 

791 or if there are any checkpointed values which have not been matched to 

792 Python objects. 

793 """ 

794 pretty_printer = ObjectGraphProtoPrettyPrinter( 

795 self._checkpoint.object_graph_proto) 

796 self.assert_existing_objects_matched() 

797 for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes): 

798 if not node.attributes: 

799 # Only raise exceptions for the nodes with attributes themselves. Either 

800 # they're ultimately not important, or they have a child with an 

801 # attribute. 

802 continue 

803 trackable = self._checkpoint.object_by_proto_id.get(node_id, None) 

804 if trackable is None: 

805 raise AssertionError( 

806 "Unresolved object in checkpoint " 

807 f"{pretty_printer.node_names[node_id]}: {node}") 

808 if self._checkpoint.slot_restorations: 

809 # Sanity check; this collection should be clear if everything has been 

810 # restored. 

811 raise AssertionError( 

812 f"Unresolved slot restorations: {self._checkpoint.slot_restorations}") 

813 if self._checkpoint.unused_attributes: 

814 unused_attribute_messages = [] 

815 for node_id, attribute in self._checkpoint.unused_attributes.items(): 

816 obj = self._checkpoint.object_by_proto_id[node_id] 

817 unused_attribute_messages.append( 

818 f"{pretty_printer.node_names[node_id]} ({obj}): {attribute}") 

819 joined_attribute_messages = "\n".join(unused_attribute_messages) 

820 raise AssertionError( 

821 "Unused attributes in these objects (the attributes exist in the " 

822 f"checkpoint but were not restored):\n{joined_attribute_messages}") 

823 return self 

824 

825 def assert_existing_objects_matched(self): 

826 """Asserts that trackable Python objects have been matched. 

827 

828 Note that this is a weaker assertion than `assert_consumed`. It will only 

829 fail for existing Python objects which are (transitive) dependencies of the 

830 root object and which do not have an entry in the checkpoint. 

831 

832 It will not fail, for example, if a `tf.keras.Layer` object has not yet been 

833 built and so has not created any `tf.Variable` objects. 

834 

835 Returns: 

836 `self` for chaining. 

837 

838 Raises: 

839 AssertionError: If a Python object exists in the transitive dependencies 

840 of the root object but does not have a value in the checkpoint. 

841 """ 

842 for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes): 

843 trackable = self._checkpoint.object_by_proto_id.get(node_id, None) 

844 if (trackable is not None and 

845 trackable._update_uid < self._checkpoint.restore_uid): # pylint: disable=protected-access 

846 raise AssertionError( 

847 f"Object {node} not assigned a value from checkpoint.") 

848 for trackable_object in util.list_objects(self._object_graph_view): 

849 # Remove data structures that do not contain any variables from 

850 # restoration checks. 

851 if (isinstance(trackable_object, 

852 data_structures.TrackableDataStructure) and 

853 not trackable_object._trackable_children( # pylint: disable=protected-access 

854 save_type=base.SaveType.CHECKPOINT)): 

855 continue 

856 self._checkpoint.all_python_objects.add(trackable_object) 

857 unused_python_objects = ( 

858 object_identity.ObjectIdentitySet( 

859 _objects_with_attributes( 

860 self._checkpoint.all_python_objects)) - 

861 object_identity.ObjectIdentitySet( 

862 self._checkpoint.object_by_proto_id.values())) 

863 if unused_python_objects: 

864 num_unused_python_objects = len(list(unused_python_objects)) 

865 # Display at most 10 unmatched objects in the error message. 

866 num_variables_to_show = min(10, num_unused_python_objects) 

867 raise AssertionError( 

868 f"Found {num_unused_python_objects} Python objects that were " 

869 "not bound to checkpointed values, likely due to changes in the " 

870 f"Python program. Showing {num_variables_to_show} of " 

871 f"{num_unused_python_objects} unmatched objects: " 

872 f"{list(unused_python_objects)[:num_variables_to_show]}") 

873 return self 

874 

875 def assert_nontrivial_match(self): 

876 """Raises an exception if only the root object matched.""" 

877 for trackable_object in util.list_objects(self._object_graph_view): 

878 self._checkpoint.all_python_objects.add(trackable_object) 

879 if len(self._checkpoint.object_by_proto_id) <= 1: 

880 unused_python_objects = ( 

881 object_identity.ObjectIdentitySet( 

882 _objects_with_attributes(self._checkpoint.all_python_objects)) - 

883 object_identity.ObjectIdentitySet( 

884 self._checkpoint.object_by_proto_id.values())) 

885 if unused_python_objects: 

886 raise AssertionError( 

887 "Nothing except the root object matched a checkpointed value. " 

888 "Typically this means that the checkpoint does not match the " 

889 "Python program. The following objects have no matching " 

890 f"checkpointed value: {list(unused_python_objects)}") 

891 else: 

892 raise AssertionError( 

893 "Nothing to load. No dependencies have been added to " 

894 f"{self._object_graph_view.root} yet.") 

895 return self 

896 

897 def run_restore_ops(self, session=None): 

898 """Run operations to restore objects in the dependency graph.""" 

899 if context.executing_eagerly(): 

900 return # Run eagerly 

901 if session is None: 

902 session = get_session() 

903 session.run(self._checkpoint.restore_ops, feed_dict=self._feed_dict) 

904 

905 def initialize_or_restore(self, session=None): 

906 """Run operations to initialize or restore objects in the dependency graph. 

907 

908 Any objects in the dependency graph which have initializers but are not in 

909 the checkpoint will have those initializers run, unless those variables are 

910 being restored by a later call to `tf.train.Checkpoint.restore()`. 

911 

912 This method has a sibling in `InitializationOnlyStatus` which instead 

913 initializes variables. That type is returned if no checkpoint is specified 

914 in `Saver.restore`. 

915 

916 Args: 

917 session: The session to run init/restore ops in. If `None`, uses the 

918 default session. 

919 """ 

920 if context.executing_eagerly(): 

921 return # Initialization and restoration ops are run eagerly 

922 if session is None: 

923 session = get_session() 

924 all_objects = util.list_objects(self._object_graph_view) 

925 already_initialized_objects = object_identity.ObjectIdentitySet( 

926 self._checkpoint.object_by_proto_id.values()) 

927 initializers_for_non_restored_variables = [ 

928 c.initializer for c in all_objects 

929 if hasattr(c, "initializer") 

930 and c not in already_initialized_objects 

931 and (getattr(c, "_update_uid", self._checkpoint.restore_uid - 1) 

932 < self._checkpoint.restore_uid) 

933 ] 

934 self.run_restore_ops(session=session) 

935 session.run(initializers_for_non_restored_variables) 

936 

937 def expect_partial(self): 

938 """Silence warnings about incomplete checkpoint restores.""" 

939 self._checkpoint.expect_partial = True 

940 return self 
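
An eager-mode sketch of how the status object returned by `tf.train.Checkpoint.restore` (a `CheckpointLoadStatus`) is typically used; the paths are hypothetical:

```python
import tensorflow as tf

src = tf.train.Checkpoint(v=tf.Variable(1.0))
path = src.write("/tmp/status_demo/ckpt")

dst = tf.train.Checkpoint(v=tf.Variable(0.0))
status = dst.restore(path)
status.assert_existing_objects_matched()  # every created object received a value
assert dst.v.numpy() == 1.0

# A deliberately partial restore: expect_partial() silences the warnings that
# would otherwise be logged when the unmatched status object is deleted.
tf.train.Checkpoint().restore(path).expect_partial()
```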

941 

942 

943class InitializationOnlyStatus(_LoadStatus): 

944 """Returned from `Saver.restore` when no checkpoint has been specified. 

945 

946 Objects of this type have the same `assert_consumed` method as 

947 `CheckpointLoadStatus`, but it always fails. However, 

948 `initialize_or_restore` works on objects of both types, and will 

949 initialize variables in `InitializationOnlyStatus` objects or restore them 

950 otherwise. 

951 """ 

952 

953 def __init__(self, object_graph_view, restore_uid): 

954 self._restore_uid = restore_uid 

955 self._object_graph_view = object_graph_view 

956 # Keep a reference to the root, since graph_view might only have a weakref. 

957 self._root = object_graph_view.root 

958 

959 def assert_consumed(self): 

960 """Assertion for consistency with `CheckpointLoadStatus`. Always fails.""" 

961 raise AssertionError( 

962 "No checkpoint specified (save_path=None); nothing is being restored.") 

963 

964 def assert_existing_objects_matched(self): 

965 """Assertion for consistency with `CheckpointLoadStatus`. Always fails.""" 

966 raise AssertionError( 

967 "No checkpoint specified (save_path=None); nothing is being restored.") 

968 

969 def assert_nontrivial_match(self): 

970 """Assertion for consistency with `CheckpointLoadStatus`. Always fails.""" 

971 raise AssertionError( 

972 "No checkpoint specified (save_path=None); nothing is being restored.") 

973 

974 def run_restore_ops(self, session=None): 

975 """For consistency with `CheckpointLoadStatus`. 

976 

977 Use `initialize_or_restore` for initializing if no checkpoint was passed 

978 to `Saver.restore` and restoring otherwise. 

979 

980 Args: 

981 session: Not used. 

982 """ 

983 raise AssertionError( 

984 "No checkpoint specified, so no restore ops are available " 

985 "(save_path=None to Saver.restore).") 

986 

987 def initialize_or_restore(self, session=None): 

988 """Runs initialization ops for variables. 

989 

990 Objects which would be saved by `Saver.save` will be initialized, unless 

991 those variables are being restored by a later call to 

992 `tf.train.Checkpoint.restore()`. 

993 

994 This method does nothing when executing eagerly (initializers get run 

995 eagerly). 

996 

997 Args: 

998 session: The session to run initialization ops in. If `None`, uses the 

999 default session. 

1000 """ 

1001 if context.executing_eagerly(): 

1002 return # run eagerly 

1003 if session is None: 

1004 session = get_session() 

1005 trackable_objects = util.list_objects(self._object_graph_view) 

1006 initializers = [ 

1007 c.initializer for c in trackable_objects 

1008 if hasattr(c, "initializer") and c.initializer is not None 

1009 and (getattr(c, "_update_uid", self._restore_uid - 1) 

1010 < self._restore_uid) 

1011 ] 

1012 session.run(initializers) 

1013 

1014 

1015_DEPRECATED_RESTORE_INSTRUCTIONS = ( 

1016 "Restoring a name-based tf.train.Saver checkpoint using the object-based " 

1017 "restore API. This mode uses global names to match variables, and so is " 

1018 "somewhat fragile. It also adds new restore ops to the graph each time it " 

1019 "is called when graph building. Prefer re-encoding training checkpoints in " 

1020 "the object-based format: run save() on the object-based saver (the same " 

1021 "one this message is coming from) and use that checkpoint in the future.") 

1022 

1023 

1024class NameBasedSaverStatus(_LoadStatus): 

1025 """Status for loading a name-based training checkpoint.""" 

1026 

1027 # Ideally this deprecation decorator would be on the class, but that 

1028 # interferes with isinstance checks. 

1029 @deprecation.deprecated( 

1030 date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS) 

1031 def __init__(self, checkpoint, object_graph_view): 

1032 self._checkpoint = checkpoint 

1033 self._object_graph_view = object_graph_view 

1034 self._optionally_restored = [] 

1035 # Keep a reference to the root, since graph_view might only have a weakref. 

1036 self._root = object_graph_view.root 

1037 

1038 def add_to_optionally_restored(self, var): 

1039 """Add a variable to the list of optionally restored variables. 

1040 

1041 There are situations where certain variables should be ignored in assertions 

1042 such as assert_existing_objects_matched(). One example is that of a 

1043 checkpoint saved with train.Saver(), and restored with train.Checkpoint(): 

1044 it is possible for the train.Saver() checkpoint to be missing the internal 

1045 `save_counter` variable, which we want to ignore on restore. 

1046 

1047 Args: 

1048 var: The variable to treat as optionally restored. 

1049 """ 

1050 self._optionally_restored.append(var) 

1051 

1052 def assert_consumed(self): 

1053 """Raises an exception if any variables are unmatched.""" 

1054 unused_attributes = list(self._checkpoint.unused_attributes.items()) 

1055 unused_attributes = [ 

1056 a for a in unused_attributes 

1057 if all(a[0] is not x for x in self._optionally_restored) 

1058 ] 

1059 if unused_attributes: 

1060 unused_attribute_string = "".join( 

1061 f"\n {obj}: {attributes}" for obj, attributes in unused_attributes) 

1062 raise AssertionError( 

1063 "Some objects had attributes which were not restored: " 

1064 f"{unused_attribute_string}") 

1065 for trackable in util.list_objects(self._object_graph_view): 

1066 # pylint: disable=protected-access 

1067 trackable._maybe_initialize_trackable() 

1068 if trackable._update_uid < self._checkpoint.restore_uid: 

1069 raise AssertionError(f"Object not restored: {trackable}") 

1070 # pylint: enable=protected-access 

1071 return self 

1072 

1073 def assert_existing_objects_matched(self): 

1074 """Raises an exception if currently created objects are unmatched.""" 

1075 # For name-based checkpoints there's no object information in the 

1076 # checkpoint, so there's no distinction between 

1077 # assert_existing_objects_matched and assert_consumed (and both are less 

1078 # useful since we don't touch Python objects or Python state). 

1079 return self.assert_consumed() 

1080 

1081 def assert_nontrivial_match(self): 

1082 """Raises an exception if currently created objects are unmatched.""" 

1083 # For name-based checkpoints there's no object information in the 

1084 # checkpoint, so there's no distinction between 

1085 # assert_nontrivial_match and assert_consumed (and both are less 

1086 # useful since we don't touch Python objects or Python state). 

1087 return self.assert_consumed() 

1088 

1089 def _gather_saveable_objects(self): 

1090 """Walk the object graph, using global names for SaveableObjects.""" 

1091 objects = util.list_objects(self._object_graph_view) 

1092 saveable_objects = [] 

1093 for trackable in objects: 

1094 # pylint: disable=protected-access 

1095 trackable._maybe_initialize_trackable() 

1096 if trackable._update_uid < self._checkpoint.restore_uid: 

1097 trackable._update_uid = self._checkpoint.restore_uid 

1098 else: 

1099 continue 

1100 # pylint: enable=protected-access 

1101 saveable_objects.extend( 

1102 self._checkpoint.globally_named_object_attributes(trackable)) 

1103 return saveable_objects 

1104 

1105 def run_restore_ops(self, session=None): 

1106 """Load the name-based checkpoint using a new `tf.compat.v1.train.Saver`.""" 

1107 if context.executing_eagerly(): 

1108 return # Nothing to do, variables are restored on creation. 

1109 if session is None: 

1110 session = get_session() 

1111 with ops.device("/cpu:0"): 

1112 saveables = self._gather_saveable_objects() 

1113 v1_saver_lib.Saver(saveables).restore( 

1114 sess=session, save_path=self._checkpoint.save_path) 

1115 

1116 def initialize_or_restore(self, session=None): 

1117 """Alias for `run_restore_ops`.""" 

1118 self.run_restore_ops(session=session) 

1119 

1120 

1121class _SessionWithFeedDictAdditions(session_lib.SessionInterface): 

1122 """Pretends to be a session, inserts extra feeds on run().""" 

1123 

1124 def __init__(self, session, feed_additions): 

1125 self._wrapped_session = session 

1126 self._feed_additions = feed_additions 

1127 

1128 def run(self, fetches, feed_dict=None, **kwargs): 

1129 if feed_dict is None: 

1130 feed_dict = {} 

1131 else: 

1132 feed_dict = feed_dict.copy() 

1133 feed_dict.update(self._feed_additions) 

1134 return self._wrapped_session.run( 

1135 fetches=fetches, feed_dict=feed_dict, **kwargs) 
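
A standalone mimic of the wrapper above (fake classes, not TensorFlow APIs): the caller's feed dict is copied and the extra feeds are merged in before delegating to the wrapped session.

```python
class _FakeSession:
    def run(self, fetches, feed_dict=None):
        return fetches, dict(feed_dict or {})

class _WithExtraFeeds:
    def __init__(self, session, feed_additions):
        self._wrapped = session
        self._feed_additions = feed_additions

    def run(self, fetches, feed_dict=None):
        merged = dict(feed_dict or {})        # never mutate the caller's dict
        merged.update(self._feed_additions)   # inject the extra feeds
        return self._wrapped.run(fetches, feed_dict=merged)

session = _WithExtraFeeds(_FakeSession(), {"object_graph": b"proto-bytes"})
_, feeds = session.run("save_op", feed_dict={"file_prefix": "/tmp/ckpt"})
assert feeds == {"file_prefix": "/tmp/ckpt", "object_graph": b"proto-bytes"}
```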

1136 

1137 

1138class TrackableSaver: 

1139 """Saves and restores a `Trackable` object and its dependencies. 

1140 

1141 See `Trackable` for details of dependency management. `Saver` wraps 

1142 `tf.compat.v1.train.Saver` for saving, including extra information about the 

1143 graph of 

1144 dependencies between Python objects. When restoring, it uses this information 

1145 about the save-time dependency graph to more robustly match objects with their 

1146 checkpointed values. When executing eagerly, it supports restoring variables 

1147 on object creation (see `Saver.restore`). 

1148 

1149 Values in a checkpoint are mapped to `Trackable` Python objects 

1150 (`Variable`s, `Optimizer`s, `Layer`s) based on the names provided when the 

1151 checkpoint was written. To avoid breaking existing checkpoints when modifying 

1152 a class, dependency names (the names of attributes to which `Trackable` 

1153 objects are assigned) may not change. These names are local to objects, in 

1154 contrast to the `Variable.name`-based save/restore from 

1155 `tf.compat.v1.train.Saver`, and 

1156 so allow additional program transformations. 

1157 """ 

1158 

1159 def __init__(self, graph_view): 

1160 """Configure saving. 

1161 

1162 Args: 

1163 graph_view: An `ObjectGraphView` object containing a description of the 

1164 object graph to save. 

1165 """ 

1166 self._graph_view = graph_view 

1167 

1168 # The following attributes are used when graph building. 

1169 

1170 # self._cache: A more generic cache used to cache the serialized tensors and 

1171 # TrackableObjectGraph proto attributes. 

1172 # self._saveables_cache: A dictionary mapping `Trackable` objects -> 

1173 # attribute names -> SaveableObjects, used to avoid re-creating 

1174 # SaveableObjects when graph building. 

1175 if context.executing_eagerly(): 

1176 self._cache = None 

1177 self._saveables_cache = None 

1178 else: 

1179 self._cache = object_identity.ObjectIdentityWeakKeyDictionary() 

1180 self._saveables_cache = object_identity.ObjectIdentityWeakKeyDictionary() 

1181 

1182 # The file prefix placeholder is created lazily when graph building (and not 

1183 # at all when executing eagerly) to avoid creating ops in the constructor 

1184 # (when they may never be necessary). 

1185 self._file_prefix_placeholder = None 

1186 

1187 # Op caching for save 

1188 self._object_graph_feed_tensor = None 

1189 self._last_save_object_graph = None 

1190 self._file_prefix_feed_tensor = None 

1191 self._cached_save_operation = None 

1192 

1193 # Op caching for restore, shared between _CheckpointRestoreCoordinators 

1194 self._restore_op_cache = {} 

1195 

1196 # Object map used for checkpoint. This attribute is to be overridden by a 

1197 # Checkpoint subclass, e.g., AsyncCheckpoint, to replace the trackable 

1198 # objects for checkpoint saving. 

1199 self._object_map = None 

1200 

1201 def _gather_serialized_tensors(self, object_graph_tensor=None): 

1202 """Gathers tensors to save to ckpt and includes the object graph proto.""" 

1203 serialized_tensors, feed_additions, registered_savers, graph_proto = ( 

1204 save_util.serialize_graph_view(self._graph_view, 

1205 self._object_map, 

1206 cache=self._cache)) 

1207 

1208 if self._saveables_cache is not None: 

1209 # Store saveables cache for restoration purposes. 

1210 self._saveables_cache = ( 

1211 saveable_object_util.serialized_tensors_to_saveable_cache( 

1212 serialized_tensors)) 

1213 

1214 if object_graph_tensor is None: 

1215 with ops.device("/cpu:0"): 

1216 object_graph_tensor = constant_op.constant( 

1217 graph_proto.SerializeToString(), dtype=dtypes.string) 

1218 else: 

1219 feed_additions.update( 

1220 {object_graph_tensor: graph_proto.SerializeToString()}) 

1221 assert base.OBJECT_GRAPH_PROTO_KEY not in serialized_tensors.get(None, {}) 

1222 serialized_tensors.setdefault(None, {})[base.OBJECT_GRAPH_PROTO_KEY] = ( 

1223 object_graph_tensor) 

1224 return serialized_tensors, feed_additions, registered_savers, graph_proto 

1225 

1226 def _save_cached_when_graph_building(self, file_prefix, object_graph_tensor, 

1227 options): 

1228 """Create or retrieve save ops. 

1229 

1230 Args: 

1231 file_prefix: The prefix for saved checkpoint files. 

1232 object_graph_tensor: A `Tensor` to which the current object graph will be 

1233 fed. 

1234 options: `CheckpointOptions` object. 

1235 

1236 Returns: 

1237 A two-element tuple with a filename tensor and a feed_dict of tensors to 

1238 feed when running it (if graph building). The feed dict contains the 

1239 current object graph and any Python state to be saved in the 

1240 checkpoint. When executing eagerly only the first argument is meaningful. 

1241 """ 

1242 serialized_tensors, feed_additions, registered_savers, graph_proto = ( 

1243 self._gather_serialized_tensors(object_graph_tensor)) 

1244 

1245 if (self._last_save_object_graph != graph_proto 

1246 # When executing eagerly, we need to re-create SaveableObjects each 

1247 # time save() is called so they pick up new Tensors passed to their 

1248 # constructors. That means the Saver needs to be copied with a new 

1249 # var_list. 

1250 or context.executing_eagerly() or ops.inside_function()): 

1251 saver = functional_saver.MultiDeviceSaver(serialized_tensors, 

1252 registered_savers) 

1253 save_op = saver.save(file_prefix, options=options) 

1254 with ops.device("/cpu:0"): 

1255 with ops.control_dependencies([save_op]): 

1256 self._cached_save_operation = array_ops.identity(file_prefix) 

1257 self._last_save_object_graph = graph_proto 

1258 return self._cached_save_operation, feed_additions 

1259 

1260 def save(self, 

1261 file_prefix, 

1262 checkpoint_number=None, 

1263 session=None, 

1264 options=None): 

1265 """Save a training checkpoint. 

1266 

1267 The saved checkpoint includes variables created by this object and any 

1268 Trackable objects it depends on at the time `Saver.save()` is called. 

1269 

1270 Args: 

1271 file_prefix: A prefix to use for the checkpoint filenames 

1272 (/path/to/directory/and_a_prefix). Names are generated based on this 

1273 prefix and `checkpoint_number`, if provided. 

1274 checkpoint_number: An integer variable or Tensor, used to number 

1275 checkpoints. Typically this value is saved along with other variables in 

1276 training checkpoints, which will happen automatically if it was created 

1277 by `root_trackable` or one of its dependencies (via 

1278 `Trackable._add_variable`). 

1279 session: The session to evaluate variables in. Ignored when executing 

1280 eagerly. If not provided when graph building, the default session is 

1281 used. 

1282 options: Optional `tf.train.CheckpointOptions` object. 

1283 

1284 Returns: 

1285 The full path to the checkpoint. 

1286 

1287 Raises: 

1288 RuntimeError: if called in V1 Graph mode without a default session. 

1289 """ 

1290 options = options or checkpoint_options.CheckpointOptions() 

1291 feed_dict = {} 

1292 use_session = (not context.executing_eagerly() and 

1293 not ops.inside_function()) 

1294 if checkpoint_number: 

1295 file_prefix = "%s-%d" % (file_prefix, checkpoint_number) 

1296 if use_session: 

1297 if self._object_graph_feed_tensor is None: 

1298 with ops.device("/cpu:0"): 

1299 self._object_graph_feed_tensor = constant_op.constant( 

1300 "", dtype=dtypes.string) 

1301 self._file_prefix_feed_tensor = constant_op.constant( 

1302 "", dtype=dtypes.string) 

1303 object_graph_tensor = self._object_graph_feed_tensor 

1304 file_prefix_tensor = self._file_prefix_feed_tensor 

1305 feed_dict[file_prefix_tensor] = file_prefix 

1306 else: 

1307 with ops.device("/cpu:0"): 

1308 file_prefix_tensor = ops.convert_to_tensor( 

1309 file_prefix, dtype=dtypes.string) 

1310 object_graph_tensor = None 

1311 

1312 if not tensor_util.is_tensor(file_prefix): 

1313 file_io.recursive_create_dir(os.path.dirname(file_prefix)) 

1314 

1315 save_path, new_feed_additions = self._save_cached_when_graph_building( 

1316 file_prefix_tensor, object_graph_tensor, options) 

1317 

1318 if new_feed_additions: 

1319 feed_dict.update(new_feed_additions) 

1320 if not use_session: 

1321 session = None 

1322 elif session is None: 

1323 session = get_session() 

1324 

1325 if session: 

1326 return session.run(save_path, feed_dict=feed_dict) 

1327 elif use_session: 

1328 raise RuntimeError(f"Unable to save checkpoint to \"{file_prefix}\" " 

1329 "in graph mode without a default session. Please use " 

1330 "`with tf.Session():` to create a session.") 

1331 else: 

1332 return save_path 

1333 

1334 def restore(self, save_path, options=None): 

1335 """Restore a training checkpoint. 

1336 

1337 Restores `root_trackable` and any objects that it tracks 

1338 (transitively). Either assigns values immediately if variables to restore have 

1339 been created already, or defers restoration until the variables are 

1340 created. Dependencies added to the `root_trackable` passed to the 

1341 constructor after this call will be matched if they have a corresponding 

1342 object in the checkpoint. 

1343 

1344 When building a graph, restorations are added to the graph but not run. 

1345 

1346 ```python 

1347 saver = Saver(root) 

1348 saver.restore(path) 

1349 ``` 

1350 

1351 To ensure that loading is complete and no more deferred restorations will 

1352 take place, you can use the `assert_consumed()` method of the status object 

1353 returned by the `restore` call. 

1354 

1355 The assert will raise an exception unless every object was matched and all 

1356 checkpointed values have a matching variable object. 

1357 

1358 ```python 

1359 saver = Saver(root) 

1360 saver.restore(path).assert_consumed() 

1361 ``` 

1362 

1363 When graph building, `assert_consumed()` indicates that all of the restore 

1364 ops which will be created for this checkpoint have been created. They can be 

1365 run via the `run_restore_ops()` function of the status object: 

1366 

1367 ```python 

1368 saver.restore(path).assert_consumed().run_restore_ops() 

1369 ``` 

1370 

1371 If the checkpoint has not been consumed completely, then the list of restore 

1372 ops will grow as more objects are added to the dependency graph. 

1373 

1374 Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this 

1375 method. There is no deferred loading, and names are used to match 

1376 variables. No restore ops are created/run until `run_restore_ops()` or 

1377 `initialize_or_restore()` are called on the returned status object, even 

1378 when executing eagerly. Re-encode name-based checkpoints using this 

1379 object-based `Saver.save` as soon as possible. 
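For illustration, a hypothetical re-encoding flow (the paths are placeholders
and `saver` is a `TrackableSaver` built for the same object graph, with eager
execution assumed):

```python
status = saver.restore("/tmp/name_based_ckpt")  # name-based checkpoint
status.run_restore_ops()                        # or initialize_or_restore()
saver.save("/tmp/object_based_ckpt")            # re-encode as object-based
```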

1380 

1381 Args: 

1382 save_path: The path to the checkpoint, as returned by `save` or 

1383 `tf.train.latest_checkpoint`. If None (as when there is no latest 

1384 checkpoint for `tf.train.latest_checkpoint` to return), returns an 

1385 object which may run initializers for objects in the dependency graph. 

1386 If the checkpoint was written by the name-based 

1387 `tf.compat.v1.train.Saver`, names are used to match variables. 

1388 options: Optional `tf.train.CheckpointOptions` object. 

1389 

1390 Returns: 

1391 A load status object, which can be used to make assertions about the 

1392 status of checkpoint restoration and run initialization/restore ops 

1393 (of type `CheckpointLoadStatus`, or `InitializationOnlyStatus` if 

1394 `save_path` is `None`). 

1395 

1396 If `save_path` points to a name-based checkpoint, a `NameBasedSaverStatus` 

1397 object is returned which runs restore ops from a name-based saver. 

1398 

1399 Raises: 

1400 RuntimeError: When a checkpoint file saved by async checkpoint is not 

1401 available upon restore(). 

1402 """ 

1403 options = options or checkpoint_options.CheckpointOptions() 

1404 if save_path is None: 

1405 return InitializationOnlyStatus(self._graph_view, ops.uid()) 

1406 

1407 # Wait for the ongoing checkpoint to finish. 

1408 # TODO(chienchunh): Allow loading the file while other checkpoint events 

1409 # are still ongoing. Need to add a timeout mechanism along 

1410 # with conditional variables to notify when the checkpoint 

1411 # file is ready. 

1412 global _ASYNC_CHECKPOINT_THREAD 

1413 if _ASYNC_CHECKPOINT_THREAD is not None: 

1414 _ASYNC_CHECKPOINT_THREAD.join() 

1415 reader = py_checkpoint_reader.NewCheckpointReader(save_path) 

1416 graph_building = not context.executing_eagerly() 

1417 if graph_building: 

1418 dtype_map = None 

1419 else: 

1420 dtype_map = reader.get_variable_to_dtype_map() 

1421 try: 

1422 object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY) 

1423 except errors_impl.NotFoundError: 

1424 # The object graph proto does not exist in this checkpoint. Try the 

1425 # name-based compatibility mode. 

1426 restore_coordinator = _NameBasedRestoreCoordinator( 

1427 save_path=save_path, 

1428 dtype_map=dtype_map) 

1429 if not graph_building: 

1430 for existing_trackable in util.list_objects(self._graph_view): 

1431 # pylint: disable=protected-access 

1432 existing_trackable._maybe_initialize_trackable() 

1433 existing_trackable._name_based_restores.add(restore_coordinator) 

1434 existing_trackable._name_based_attribute_restore(restore_coordinator) 

1435 # pylint: enable=protected-access 

1436 return NameBasedSaverStatus( 

1437 restore_coordinator, 

1438 object_graph_view=self._graph_view) 

1439 

1440 if graph_building: 

1441 if self._file_prefix_placeholder is None: 

1442 with ops.device("/cpu:0"): 

1443 self._file_prefix_placeholder = constant_op.constant("model") 

1444 file_prefix_tensor = self._file_prefix_placeholder 

1445 file_prefix_feed_dict = {self._file_prefix_placeholder: save_path} 

1446 else: 

1447 with ops.device("/cpu:0"): 

1448 file_prefix_tensor = constant_op.constant(save_path) 

1449 file_prefix_feed_dict = None 

1450 object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph()) 

1451 object_graph_proto.ParseFromString(object_graph_string) 

1452 checkpoint = _CheckpointRestoreCoordinator( 

1453 object_graph_proto=object_graph_proto, 

1454 save_path=save_path, 

1455 save_path_tensor=file_prefix_tensor, 

1456 reader=reader, 

1457 restore_op_cache=self._restore_op_cache, 

1458 graph_view=self._graph_view, 

1459 options=options, 

1460 saveables_cache=self._saveables_cache) 

1461 restore_lib.CheckpointPosition( 

1462 checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root, 

1463 reader) 

1464 

1465 # Attached dependencies are not attached to the root, so should be restored 

1466 # separately. 

1467 if self._graph_view.attached_dependencies: 

1468 for ref in self._graph_view.attached_dependencies: 

1469 if ref.name == "root": 

1470 # Root dependency is automatically added to attached dependencies -- 

1471 # this can be ignored since it maps back to the root object. 

1472 continue 

1473 proto_id = None 

1474 # Find proto ID of attached dependency (if it is in the proto). 

1475 for proto_ref in object_graph_proto.nodes[0].children: 

1476 if proto_ref.local_name == ref.name: 

1477 proto_id = proto_ref.node_id 

1478 break 

1479 

1480 if proto_id in checkpoint.object_by_proto_id: 

1481 # Object has already been restored. This can happen when there's an 

1482 # indirect connection from the attached object to the root. 

1483 continue 

1484 

1485 if proto_id is None: 

1486 # Could not find attached dependency in proto. 

1487 continue 

1488 

1489 restore_lib.CheckpointPosition( 

1490 checkpoint=checkpoint, 

1491 proto_id=proto_id).restore(ref.ref, reader) 

1492 

1493 load_status = CheckpointLoadStatus( 

1494 checkpoint, 

1495 graph_view=self._graph_view, 

1496 feed_dict=file_prefix_feed_dict) 

1497 return load_status 

1498 

1499 

1500def frozen_saver(root_trackable): 

1501 """Creates a static `tf.compat.v1.train.Saver` from a trackable object. 

1502 

1503 The returned `Saver` saves object-based checkpoints, but these checkpoints 

1504 will no longer reflect structural changes to the object graph, only changes to 

1505 the values of `Variable`s added as dependencies of the root object before 

1506 `frozen_saver` was called. 

1507 

1508 `restore` works on the returned `Saver`, but requires that the object graph of 

1509 the checkpoint being loaded exactly matches the object graph when `frozen_saver` 

1510 was called. This is in contrast to the object-based restore performed by 

1511 `tf.train.Checkpoint` which attempts a fuzzy matching between a checkpoint's 

1512 object graph and the current Python object graph. 
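A minimal sketch (hypothetical; assumes eager execution and a trackable `root`
whose variables already exist -- the path is a placeholder):

```python
saver = frozen_saver(root)
saver.save("/tmp/frozen_ckpt")     # object-based checkpoint of the frozen graph
saver.restore("/tmp/frozen_ckpt")  # requires the same (frozen) object graph
```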

1513 

1514 Args: 

1515 root_trackable: A trackable object to save. 

1516 

1517 Returns: 

1518 A saver which saves object-based checkpoints for the object graph frozen at 

1519 the time `frozen_saver` was called. 

1520 """ 

1521 named_saveable_objects, registered_savers = ( 

1522 save_util_v1.frozen_saveables_and_savers( 

1523 graph_view_lib.ObjectGraphView(root_trackable))) 

1524 return functional_saver.MultiDeviceSaver.from_saveables( 

1525 named_saveable_objects, registered_savers) 

1526 

1527 

1528def _assert_trackable(obj, name): 

1529 if not isinstance( 

1530 obj, (base.Trackable, def_function.Function)): 

1531 raise ValueError( 

1532 f"`Checkpoint` was expecting {name} to be a trackable object (an " 

1533 f"object derived from `Trackable`), got {obj}. If you believe this " 

1534 "object should be trackable (i.e. it is part of the " 

1535 "TensorFlow Python API and manages state), please open an issue.") 

1536 

1537 

1538def _update_checkpoint_state_internal(file_path): 

1539 """Update internal checkpoint state.""" 

1540 checkpoint_management.update_checkpoint_state_internal( 

1541 save_dir=os.path.dirname(file_path), 

1542 model_checkpoint_path=file_path, 

1543 all_model_checkpoint_paths=[file_path], 

1544 save_relative_paths=True) 

1545 

1546 

1547def _convert_file_name_tensor_to_string(tensor): 

1548 """Convert file name tensor to string.""" 

1549 output = tensor 

1550 if tensor_util.is_tf_type(output): 

1551 # Convert to numpy if not `tf.function` building. 

1552 if context.executing_eagerly(): 

1553 output = compat.as_str(output.numpy()) 

1554 else: 

1555 # Graph + Session, so we already ran it via session.run(). 

1556 output = compat.as_str(output) 

1557 return output 

1558 

1559 

1560def _copy_single_tensor(tensor): 

1561 """Copies a single Tensor / SaveSpec onto the CPU device.""" 

1562 device = tensor.device 

1563 if isinstance(tensor, saveable_object_lib.SaveSpec): 

1564 # Pin the device according to the tensor's device location to 

1565 # avoid unnecessary data copies when reading the variables. This is 

1566 # aligned with the behavior in MultiDeviceSaver.save(). 

1567 with ops.device(device): 

1568 tensor = tensor.tensor 

1569 

1570 if tensor is not None: 

1571 with ops.device(saveable_object_util.set_cpu0(device)): 

1572 tensor = array_ops.identity(tensor) # pylint: disable=protected-access 

1573 return tensor 

1574 

1575 

1576# Mentions graph building / Sessions. The v2 version is below. 

1577@tf_export(v1=["train.Checkpoint"]) 

1578class CheckpointV1(autotrackable.AutoTrackable): 

1579 """Groups trackable objects, saving and restoring them. 

1580 

1581 `Checkpoint`'s constructor accepts keyword arguments whose values are types 

1582 that contain trackable state, such as `tf.compat.v1.train.Optimizer` 

1583 implementations, `tf.Variable`, `tf.keras.Layer` implementations, or 

1584 `tf.keras.Model` implementations. It saves these values with a checkpoint, and 

1585 maintains a `save_counter` for numbering checkpoints. 

1586 

1587 Example usage when graph building: 

1588 

1589 ```python 

1590 import tensorflow as tf 

1591 import os 

1592 

1593 checkpoint_directory = "/tmp/training_checkpoints" 

1594 checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") 

1595 

1596 checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) 

1597 status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory)) 

1598 train_op = optimizer.minimize( ... ) 

1599 status.assert_consumed() # Optional sanity checks. 

1600 with tf.compat.v1.Session() as session: 

1601 # Use the Session to restore variables, or initialize them if 

1602 # tf.train.latest_checkpoint returned None. 

1603 status.initialize_or_restore(session) 

1604 for _ in range(num_training_steps): 

1605 session.run(train_op) 

1606 checkpoint.save(file_prefix=checkpoint_prefix) 

1607 ``` 

1608 

1609 Example usage with eager execution enabled: 

1610 

1611 ```python 

1612 import tensorflow as tf 

1613 import os 

1614 

1615 tf.compat.v1.enable_eager_execution() 

1616 

1617 checkpoint_directory = "/tmp/training_checkpoints" 

1618 checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") 

1619 

1620 checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) 

1621 status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory)) 

1622 for _ in range(num_training_steps): 

1623 optimizer.minimize( ... ) # Variables will be restored on creation. 

1624 status.assert_consumed() # Optional sanity checks. 

1625 checkpoint.save(file_prefix=checkpoint_prefix) 

1626 ``` 

1627 

1628 `Checkpoint.save` and `Checkpoint.restore` write and read object-based 

1629 checkpoints, in contrast to `tf.compat.v1.train.Saver` which writes and reads 

1630 `variable.name` based checkpoints. Object-based checkpointing saves a graph of 

1631 dependencies between Python objects (`Layer`s, `Optimizer`s, `Variable`s, 

1632 etc.) with named edges, and this graph is used to match variables when 

1633 restoring a checkpoint. It can be more robust to changes in the Python 

1634 program, and helps to support restore-on-create for variables when executing 

1635 eagerly. Prefer `tf.train.Checkpoint` over `tf.compat.v1.train.Saver` for new 

1636 code. 

1637 

1638 `Checkpoint` objects have dependencies on the objects passed as keyword 

1639 arguments to their constructors, and each dependency is given a name that is 

1640 identical to the name of the keyword argument for which it was created. 

1641 TensorFlow classes like `Layer`s and `Optimizer`s will automatically add 

1642 dependencies on their variables (e.g. "kernel" and "bias" for 

1643 `tf.keras.layers.Dense`). Inheriting from `tf.keras.Model` makes managing 

1644 dependencies easy in user-defined classes, since `Model` hooks into attribute 

1645 assignment. For example: 

1646 

1647 ```python 

1648 class Regress(tf.keras.Model): 

1649 

1650 def __init__(self): 

1651 super().__init__() 

1652 self.input_transform = tf.keras.layers.Dense(10) 

1653 # ... 

1654 

1655 def call(self, inputs): 

1656 x = self.input_transform(inputs) 

1657 # ... 

1658 ``` 

1659 

1660 This `Model` has a dependency named "input_transform" on its `Dense` layer, 

1661 which in turn depends on its variables. As a result, saving an instance of 

1662 `Regress` using `tf.train.Checkpoint` will also save all the variables created 

1663 by the `Dense` layer. 

1664 

1665 When variables are assigned to multiple workers, each worker writes its own 

1666 section of the checkpoint. These sections are then merged/re-indexed to behave 

1667 as a single checkpoint. This avoids copying all variables to one worker, but 

1668 does require that all workers see a common filesystem. 

1669 

1670 While `tf.keras.Model.save_weights` and `tf.train.Checkpoint.save` save in the 

1671 same format, note that the root of the resulting checkpoint is the object the 

1672 save method is attached to. This means saving a `tf.keras.Model` using 

1673 `save_weights` and loading into a `tf.train.Checkpoint` with a `Model` 

1674 attached (or vice versa) will not match the `Model`'s variables. See the 

1675 [guide to training 

1676 checkpoints](https://www.tensorflow.org/guide/checkpoint) for 

1677 details. Prefer `tf.train.Checkpoint` over `tf.keras.Model.save_weights` for 

1678 training checkpoints. 

1679 

1680 Attributes: 

1681 save_counter: Incremented when `save()` is called. Used to number 

1682 checkpoints. 

1683 """ 

1684 

1685 def __init__(self, **kwargs): 

1686 """Group objects into a training checkpoint. 

1687 

1688 Args: 

1689 **kwargs: Keyword arguments are set as attributes of this object, and are 

1690 saved with the checkpoint. Values must be trackable objects. 

1691 

1692 Raises: 

1693 ValueError: If objects in `kwargs` are not trackable. 

1694 """ 

1695 super().__init__() 

1696 global _END_TIME_OF_LAST_WRITE 

1697 with _END_TIME_OF_LAST_WRITE_LOCK: 

1698 if _END_TIME_OF_LAST_WRITE is None: 

1699 _END_TIME_OF_LAST_WRITE = time.time() 

1700 

1701 for k, v in sorted(kwargs.items(), key=lambda item: item[0]): 

1702 setattr(self, k, v) 

1703 if not isinstance( 

1704 getattr(self, k), (base.Trackable, def_function.Function)): 

1705 raise ValueError( 

1706 "`Checkpoint` was expecting a trackable object (an object " 

1707 f"derived from `Trackable`), got {v}. If you believe this " 

1708 "object should be trackable (i.e. it is part of the " 

1709 "TensorFlow Python API and manages state), please open an issue.") 

1710 self._save_counter = None # Created lazily for restore-on-create. 

1711 self._save_assign_op = None 

1712 self._saver = TrackableSaver(graph_view_lib.ObjectGraphView(self)) 

1713 

1714 def _maybe_create_save_counter(self): 

1715 """Create a save counter if it does not yet exist.""" 

1716 if self._save_counter is None: 

1717 # Initialized to 0 and incremented before saving. 

1718 with ops.device("/cpu:0"): 

1719 # add_variable creates a dependency named "save_counter"; NoDependency 

1720 # prevents creating a second dependency named "_save_counter". 

1721 self._save_counter = data_structures.NoDependency( 

1722 add_variable( 

1723 self, 

1724 name="save_counter", 

1725 initializer=0, 

1726 dtype=dtypes.int64, 

1727 trainable=False)) 

1728 

1729 def write(self, file_prefix, session=None): 

1730 """Writes a training checkpoint. 

1731 

1732 The checkpoint includes variables created by this object and any 

1733 trackable objects it depends on at the time `Checkpoint.write()` is 

1734 called. 

1735 

1736 `write` does not number checkpoints, increment `save_counter`, or update the 

1737 metadata used by `tf.train.latest_checkpoint`. It is primarily intended for 

1738 use by higher level checkpoint management utilities. `save` provides a very 

1739 basic implementation of these features. 
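For illustration (a hypothetical sketch; assumes eager execution has been
enabled and `model` is a trackable object):

```python
checkpoint = tf.train.Checkpoint(model=model)
path = checkpoint.write("/tmp/model_ckpt")
# `path` is "/tmp/model_ckpt"; the checkpoint is not numbered and
# tf.train.latest_checkpoint("/tmp") will not report it.
```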

1740 

1741 Args: 

1742 file_prefix: A prefix to use for the checkpoint filenames 

1743 (/path/to/directory/and_a_prefix). 

1744 session: The session to evaluate variables in. Ignored when executing 

1745 eagerly. If not provided when graph building, the default session is 

1746 used. 

1747 

1748 Returns: 

1749 The full path to the checkpoint (i.e. `file_prefix`). 

1750 """ 

1751 return self._write(file_prefix, session) 

1752 

1753 def _write(self, file_prefix, session=None, write_done_callback=None): 

1754 """Writes a training checkpoint. 

1755 

1756 The checkpoint includes variables created by this object and any 

1757 trackable objects it depends on at the time `Checkpoint.write()` is 

1758 called. 

1759 

1760 `write` does not number checkpoints, increment `save_counter`, or update the 

1761 metadata used by `tf.train.latest_checkpoint`. It is primarily intended for 

1762 use by higher level checkpoint management utilities. `save` provides a very 

1763 basic implementation of these features. 

1764 

1765 Args: 

1766 file_prefix: A prefix to use for the checkpoint filenames 

1767 (/path/to/directory/and_a_prefix). 

1768 session: The session to evaluate variables in. Ignored when executing 

1769 eagerly. If not provided when graph building, the default session is 

1770 used. 

1771 write_done_callback: Optional callback function to be executed once 

1772 the underlying checkpoint saving is finished. Example usage includes 

1773 updating the checkpoint internal state. 

1774 

1775 Returns: 

1776 The full path to the checkpoint (i.e. `file_prefix`). 

1777 """ 

1778 start_time = time.time() 

1779 output = self._saver.save(file_prefix=file_prefix, session=session) 

1780 end_time = time.time() 

1781 

1782 metrics.AddCheckpointWriteDuration( 

1783 api_label=_CHECKPOINT_V1, 

1784 microseconds=_get_duration_microseconds(start_time, end_time)) 

1785 

1786 global _END_TIME_OF_LAST_WRITE 

1787 with _END_TIME_OF_LAST_WRITE_LOCK: 

1788 metrics.AddTrainingTimeSaved( 

1789 api_label=_CHECKPOINT_V1, 

1790 microseconds=_get_duration_microseconds(_END_TIME_OF_LAST_WRITE, 

1791 end_time)) 

1792 

1793 if checkpoint_context.in_preemption_save_context(): 

1794 _preemption_checkpoint_saved_time_usecs.get_cell().increase_by( 

1795 _get_duration_microseconds(_END_TIME_OF_LAST_WRITE, end_time) 

1796 ) 

1797 

1798 _END_TIME_OF_LAST_WRITE = end_time 

1799 

1800 if tensor_util.is_tf_type(output): 

1801 # Convert to numpy if not `tf.function` building. 

1802 if context.executing_eagerly(): 

1803 output = compat.as_str(output.numpy()) 

1804 else: 

1805 # Graph + Session, so we already ran it via session.run(). 

1806 output = compat.as_str(output) 

1807 

1808 if write_done_callback: 

1809 write_done_callback(output) 

1810 

1811 metrics.RecordCheckpointSize( 

1812 api_label=_CHECKPOINT_V1, filesize=_get_checkpoint_size(output)) 

1813 return output 

1814 

1815 @property 

1816 def save_counter(self): 

1817 """An integer variable which starts at zero and is incremented on save. 

1818 

1819 Used to number checkpoints. 

1820 

1821 Returns: 

1822 The save counter variable. 

1823 """ 

1824 self._maybe_create_save_counter() 

1825 return self._save_counter 

1826 

1827 def save(self, file_prefix, session=None): 

1828 """Saves a training checkpoint and provides basic checkpoint management. 

1829 

1830 The saved checkpoint includes variables created by this object and any 

1831 trackable objects it depends on at the time `Checkpoint.save()` is 

1832 called. 

1833 

1834 `save` is a basic convenience wrapper around the `write` method, 

1835 sequentially numbering checkpoints using `save_counter` and updating the 

1836 metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint 

1837 management, for example garbage collection and custom numbering, may be 

1838 provided by other utilities which also wrap `write` 

1839 (`tf.train.CheckpointManager` for example). 
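A minimal sketch (hypothetical; assumes eager execution has been enabled and
`model` is a trackable object):

```python
checkpoint = tf.train.Checkpoint(model=model)
path_1 = checkpoint.save("/tmp/ckpt")  # e.g. "/tmp/ckpt-1", save_counter == 1
path_2 = checkpoint.save("/tmp/ckpt")  # e.g. "/tmp/ckpt-2", save_counter == 2
```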

1840 

1841 Args: 

1842 file_prefix: A prefix to use for the checkpoint filenames 

1843 (/path/to/directory/and_a_prefix). Names are generated based on this 

1844 prefix and `Checkpoint.save_counter`. 

1845 session: The session to evaluate variables in. Ignored when executing 

1846 eagerly. If not provided when graph building, the default session is 

1847 used. 

1848 

1849 Returns: 

1850 The full path to the checkpoint. 

1851 """ 

1852 graph_building = not context.executing_eagerly() 

1853 if graph_building: 

1854 if ops.inside_function(): 

1855 raise NotImplementedError( 

1856 "Calling tf.train.Checkpoint.save() from a function is not " 

1857 "supported, as save() modifies saving metadata in ways not " 

1858 "supported by TensorFlow Operations. Consider using " 

1859 "tf.train.Checkpoint.write(), a lower-level API which does not " 

1860 "update metadata. tf.train.latest_checkpoint and related APIs will " 

1861 "not see this checkpoint.") 

1862 if session is None: 

1863 session = get_session() 

1864 if self._save_counter is None: 

1865 # When graph building, if this is a new save counter variable then it 

1866 # needs to be initialized before assign_add. This is only an issue if 

1867 # restore() has not been called first. 

1868 session.run(self.save_counter.initializer) 

1869 if not graph_building or self._save_assign_op is None: 

1870 with ops.colocate_with(self.save_counter): 

1871 assign_op = self.save_counter.assign_add(1, read_value=True) 

1872 if graph_building: 

1873 self._save_assign_op = data_structures.NoDependency(assign_op) 

1874 if graph_building: 

1875 checkpoint_number = session.run(self._save_assign_op) 

1876 else: 

1877 checkpoint_number = assign_op.numpy() 

1878 file_path = self.write( 

1879 "%s-%d" % (file_prefix, checkpoint_number), session=session) 

1880 checkpoint_management.update_checkpoint_state_internal( 

1881 save_dir=os.path.dirname(file_prefix), 

1882 model_checkpoint_path=file_path, 

1883 all_model_checkpoint_paths=[file_path], 

1884 save_relative_paths=True) 

1885 return file_path 

1886 

1887 def restore(self, save_path): 

1888 """Restore a training checkpoint. 

1889 

1890 Restores this `Checkpoint` and any objects it depends on. 

1891 

1892 When executing eagerly, either assigns values immediately if variables to 

1893 restore have been created already, or defers restoration until the variables 

1894 are created. Dependencies added after this call will be matched if they have 

1895 a corresponding object in the checkpoint (the restore request will queue in 

1896 any trackable object waiting for the expected dependency to be added). 

1897 

1898 When graph building, restoration ops are added to the graph but not run 

1899 immediately. 

1900 

1901 ```python 

1902 checkpoint = tf.train.Checkpoint( ... ) 

1903 checkpoint.restore(path) 

1904 ``` 

1905 

1906 To ensure that loading is complete and no more deferred restorations will 

1907 take place, you can use the `assert_consumed()` method of the status object 

1908 returned by `restore`. 

1909 The assert will raise an exception if any Python objects in the dependency 

1910 graph were not found in the checkpoint, or if any checkpointed values do not 

1911 have a matching Python object: 

1912 

1913 ```python 

1914 checkpoint = tf.train.Checkpoint( ... ) 

1915 checkpoint.restore(path).assert_consumed() 

1916 ``` 

1917 

1918 When graph building, `assert_consumed()` indicates that all of the restore 

1919 ops that will be created for this checkpoint have been created. They can be 

1920 run via the `run_restore_ops()` method of the status object: 

1921 

1922 ```python 

1923 checkpoint.restore(path).assert_consumed().run_restore_ops() 

1924 ``` 

1925 

1926 If the checkpoint has not been consumed completely, then the list of restore 

1927 ops will grow as more objects are added to the dependency graph. 

1928 

1929 To check that all variables in the Python object have restored values from 

1930 checkpoint, use `assert_existing_objects_matched()`. This assertion is 

1931 useful when called after the variables in your graph have been created. 
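For example (a hypothetical sketch; `model`, `input_shape`, and
`checkpoint_directory` are placeholders):

```python
checkpoint = tf.train.Checkpoint(model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
model.build(input_shape)  # Create the model's variables.
status.assert_existing_objects_matched()
```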

1932 

1933 Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this 

1934 method. Names are used to match variables. No restore ops are created/run 

1935 until `run_restore_ops()` or `initialize_or_restore()` are called on the 

1936 returned status object when graph building, but there is restore-on-creation 

1937 when executing eagerly. Re-encode name-based checkpoints using 

1938 `tf.train.Checkpoint.save` as soon as possible. 

1939 

1940 Args: 

1941 save_path: The path to the checkpoint, as returned by `save` or 

1942 `tf.train.latest_checkpoint`. If None (as when there is no latest 

1943 checkpoint for `tf.train.latest_checkpoint` to return), returns an 

1944 object which may run initializers for objects in the dependency graph. 

1945 If the checkpoint was written by the name-based 

1946 `tf.compat.v1.train.Saver`, names are used to match variables. 

1947 

1948 Returns: 

1949 A load status object, which can be used to make assertions about the 

1950 status of a checkpoint restoration and run initialization/restore ops. 

1951 

1952 The returned status object has the following methods: 

1953 

1954 * `assert_consumed()`: 

1955 Raises an exception if any variables are unmatched: either 

1956 checkpointed values which don't have a matching Python object or 

1957 Python objects in the dependency graph with no values in the 

1958 checkpoint. This method returns the status object, and so may be 

1959 chained with `initialize_or_restore` or `run_restore_ops`. 

1960 

1961 * `assert_existing_objects_matched()`: 

1962 Raises an exception if any existing Python objects in the dependency 

1963 graph are unmatched. Unlike `assert_consumed`, this assertion will 

1964 pass if values in the checkpoint have no corresponding Python 

1965 objects. For example a `tf.keras.Layer` object which has not yet been 

1966 built, and so has not created any variables, will pass this assertion 

1967 but will fail `assert_consumed`. Useful when loading part of a larger 

1968 checkpoint into a new Python program, e.g. a training checkpoint with 

1969 a `tf.compat.v1.train.Optimizer` was saved but only the state required 

1970 for inference is being loaded. This method returns the status object, 

1971 and so may be chained with `initialize_or_restore` or 

1972 `run_restore_ops`. 

1973 

1974 * `assert_nontrivial_match()`: Asserts that something aside from the root 

1975 object was matched. This is a very weak assertion, but is useful for 

1976 sanity checking in library code where objects may exist in the 

1977 checkpoint which haven't been created in Python and some Python 

1978 objects may not have a checkpointed value. 

1979 

1980 * `expect_partial()`: Silence warnings about incomplete checkpoint 

1981 restores. Warnings are otherwise printed for unused parts of the 

1982 checkpoint file or object when the `Checkpoint` object is deleted 

1983 (often at program shutdown). 

1984 

1985 * `initialize_or_restore(session=None)`: 

1986 When graph building, runs variable initializers if `save_path` is 

1987 `None`, but otherwise runs restore operations. If no `session` is 

1988 explicitly specified, the default session is used. No effect when 

1989 executing eagerly (variables are initialized or restored eagerly). 

1990 

1991 * `run_restore_ops(session=None)`: 

1992 When graph building, runs restore operations. If no `session` is 

1993 explicitly specified, the default session is used. No effect when 

1994 executing eagerly (restore operations are run eagerly). May only be 

1995 called when `save_path` is not `None`. 

1996 """ 

1997 start_time = time.time() 

1998 status = self._saver.restore(save_path=save_path) 

1999 # Create the save counter now so it gets initialized with other variables 

2000 # when graph building. Creating it earlier would lead to errors when using, 

2001 # say, train.Saver() to save the model before initializing it. 

2002 self._maybe_create_save_counter() 

2003 if isinstance(status, NameBasedSaverStatus): 

2004 status.add_to_optionally_restored(self.save_counter) 

2005 

2006 metrics.AddCheckpointReadDuration( 

2007 api_label=_CHECKPOINT_V1, 

2008 microseconds=_get_duration_microseconds(start_time, time.time())) 

2009 return status 

2010 

2011 

2012@tf_export("train.Checkpoint", v1=[]) 

2013class Checkpoint(autotrackable.AutoTrackable): 

2014 """Manages saving/restoring trackable values to disk. 

2015 

2016 TensorFlow objects may contain trackable state, such as `tf.Variable`s, 

2017 `tf.keras.optimizers.Optimizer` implementations, `tf.data.Dataset` iterators, 

2018 `tf.keras.Layer` implementations, or `tf.keras.Model` implementations. 

2019 These are called **trackable objects**. 

2020 

2021 A `Checkpoint` object can be constructed to save either a single or group of 

2022 trackable objects to a checkpoint file. It maintains a `save_counter` for 

2023 numbering checkpoints. 

2024 

2025 Example: 

2026 

2027 ```python 

2028 model = tf.keras.Model(...) 

2029 checkpoint = tf.train.Checkpoint(model) 

2030 

2031 # Save a checkpoint to /tmp/training_checkpoints-{save_counter}. Every time 

2032 # checkpoint.save is called, the save counter is increased. 

2033 save_path = checkpoint.save('/tmp/training_checkpoints') 

2034 

2035 # Restore the checkpointed values to the `model` object. 

2036 checkpoint.restore(save_path) 

2037 ``` 

2038 

2039 Example 2: 

2040 

2041 ```python 

2042 import tensorflow as tf 

2043 import os 

2044 

2045 checkpoint_directory = "/tmp/training_checkpoints" 

2046 checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") 

2047 

2048 # Create a Checkpoint that will manage two objects with trackable state, 

2049 # one we name "optimizer" and the other we name "model". 

2050 checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) 

2051 status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory)) 

2052 for _ in range(num_training_steps): 

2053 optimizer.minimize( ... ) # Variables will be restored on creation. 

2054 status.assert_consumed() # Optional sanity checks. 

2055 checkpoint.save(file_prefix=checkpoint_prefix) 

2056 ``` 

2057 

2058 `Checkpoint.save()` and `Checkpoint.restore()` write and read object-based 

2059 checkpoints, in contrast to TensorFlow 1.x's `tf.compat.v1.train.Saver` which 

2060 writes and 

2061 reads `variable.name` based checkpoints. Object-based checkpointing saves a 

2062 graph of dependencies between Python objects (`Layer`s, `Optimizer`s, 

2063 `Variable`s, etc.) with named edges, and this graph is used to match variables 

2064 when restoring a checkpoint. It can be more robust to changes in the Python 

2065 program, and helps to support restore-on-create for variables. 

2066 

2067 `Checkpoint` objects have dependencies on the objects passed as keyword 

2068 arguments to their constructors, and each dependency is given a name that is 

2069 identical to the name of the keyword argument for which it was created. 

2070 TensorFlow classes like `Layer`s and `Optimizer`s will automatically add 

2071 dependencies on their own variables (e.g. "kernel" and "bias" for 

2072 `tf.keras.layers.Dense`). Inheriting from `tf.keras.Model` makes managing 

2073 dependencies easy in user-defined classes, since `Model` hooks into attribute 

2074 assignment. For example: 

2075 

2076 ```python 

2077 class Regress(tf.keras.Model): 

2078 

2079 def __init__(self): 

2080 super().__init__() 

2081 self.input_transform = tf.keras.layers.Dense(10) 

2082 # ... 

2083 

2084 def call(self, inputs): 

2085 x = self.input_transform(inputs) 

2086 # ... 

2087 ``` 

2088 

2089 This `Model` has a dependency named "input_transform" on its `Dense` layer, 

2090 which in turn depends on its variables. As a result, saving an instance of 

2091 `Regress` using `tf.train.Checkpoint` will also save all the variables created 

2092 by the `Dense` layer. 

2093 

2094 When variables are assigned to multiple workers, each worker writes its own 

2095 section of the checkpoint. These sections are then merged/re-indexed to behave 

2096 as a single checkpoint. This avoids copying all variables to one worker, but 

2097 does require that all workers see a common filesystem. 

2098 

2099 This function differs slightly from the Keras Model `save_weights` function. 

2100 `tf.keras.Model.save_weights` creates a checkpoint file with the name 

2101 specified in `filepath`, while `tf.train.Checkpoint` numbers the checkpoints, 

2102 using `filepath` as the prefix for the checkpoint file names. Aside from this, 

2103 `model.save_weights()` and `tf.train.Checkpoint(model).save()` are equivalent. 
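A small illustration of the naming difference (hypothetical; the paths and
file names are indicative only):

```python
model.save_weights("/tmp/weights")         # writes "/tmp/weights.index", ...
tf.train.Checkpoint(model).save("/tmp/w")  # writes "/tmp/w-1.index", ...
```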

2104 

2105 See the [guide to training 

2106 checkpoints](https://www.tensorflow.org/guide/checkpoint) for 

2107 details. 

2108 

2109 Attributes: 

2110 save_counter: Incremented when `save()` is called. Used to number 

2111 checkpoints. 

2112 """ 

2113 

2114 def __init__(self, root=None, **kwargs): 

2115 """Creates a training checkpoint for a single or group of objects. 
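
For example, keyword arguments are attached to the `root` object's own
dependency graph (a hypothetical sketch; `model` and `opt` are illustrative
trackables):

```python
checkpoint = tf.train.Checkpoint(model, opt=opt)
checkpoint.save("/tmp/root_ckpt")
# In the saved object graph, `opt` is attached under the root (`model`),
# alongside the dependencies `model` already tracks.
```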

2116 

2117 Args: 

2118 root: The root object to checkpoint. `root` may be a trackable object or 

2119 `WeakRef` of a trackable object. 

2120 **kwargs: Keyword arguments are set as attributes of this object, and are 

2121 saved with the checkpoint. All `kwargs` must be trackable objects, or a 

2122 nested structure of trackable objects (`list`, `dict`, or `tuple`). 

2123 

2124 Raises: 

2125 ValueError: If `root` or the objects in `kwargs` are not trackable. A 

2126 `ValueError` is also raised if the `root` object tracks different 

2127 objects from the ones listed in attributes in kwargs (e.g. 

2128 `root.child = A` and `tf.train.Checkpoint(root, child=B)` are 

2129 incompatible). 

2130 

2131 """ 

2132 super().__init__() 

2133 global _END_TIME_OF_LAST_WRITE 

2134 with _END_TIME_OF_LAST_WRITE_LOCK: 

2135 if _END_TIME_OF_LAST_WRITE is None: 

2136 _END_TIME_OF_LAST_WRITE = time.time() 

2137 

2138 # Store a reference to root and kwargs if we need to instantiate an 

2139 # AsyncCheckpointer later. 

2140 self._root = root 

2141 self._kwargs = kwargs 

2142 self._delete_tracking("_kwargs") 

2143 

2144 # Don't instantiate the AsyncCheckpointer unless required. 

2145 self._async_checkpointer_impl = None 

2146 

2147 # Store checkpoint options during the save/write calls so that subsequent 

2148 # read/restore calls are done properly. This is only populated when 

2149 # async read/write is enabled. 

2150 self._checkpoint_options = None 

2151 

2152 attached_dependencies = None 

2153 self._save_counter = None # Created lazily for restore-on-create. 

2154 self._save_assign_op = None 

2155 

2156 if root: 

2157 trackable_root = root() if isinstance(root, weakref.ref) else root 

2158 _assert_trackable(trackable_root, "root") 

2159 attached_dependencies = [] 

2160 

2161 # All keyword arguments (including root itself) are set as children 

2162 # of root. 

2163 kwargs["root"] = root 

2164 trackable_root._maybe_initialize_trackable() 

2165 

2166 self._save_counter = data_structures.NoDependency( 

2167 trackable_root._lookup_dependency("save_counter")) 

2168 

2169 for k, v in sorted(kwargs.items(), key=lambda item: item[0]): 

2170 setattr(self, k, v) 

2171 

2172 # Call getattr instead of directly using v because setattr converts 

2173 # v to a Trackable data structure when v is a list/dict/tuple. 

2174 converted_v = getattr(self, k) 

2175 if isinstance(converted_v, weakref.ref): 

2176 converted_v = converted_v() 

2177 _assert_trackable(converted_v, k) 

2178 

2179 if root: 

2180 # Make sure that root doesn't already have dependencies with these names 

2181 child = trackable_root._lookup_dependency(k) 

2182 if child is None: 

2183 attached_dependencies.append( 

2184 base.WeakTrackableReference(k, converted_v)) 

2185 elif child != converted_v: 

2186 raise ValueError( 

2187 f"Cannot create a Checkpoint with keyword argument {k} if " 

2188 f"root.{k} already exists.") 

2189 

2190 self._saver = TrackableSaver( 

2191 graph_view_lib.ObjectGraphView( 

2192 root if root else self, 

2193 attached_dependencies=attached_dependencies)) 

2194 self._attached_dependencies = data_structures.NoDependency( 

2195 attached_dependencies) 

2196 

2197 def _maybe_create_save_counter(self): 

2198 """Create a save counter if it does not yet exist.""" 

2199 if self._save_counter is None: 

2200 # Initialized to 0 and incremented before saving. 

2201 with ops.device("/cpu:0"): 

2202 # add_variable creates a dependency named "save_counter"; NoDependency 

2203 # prevents creating a second dependency named "_save_counter". 

2204 self._save_counter = data_structures.NoDependency( 

2205 add_variable( 

2206 self, 

2207 name="save_counter", 

2208 initializer=0, 

2209 dtype=dtypes.int64, 

2210 trainable=False)) 

2211 if self._attached_dependencies is not None: 

2212 self._attached_dependencies.append( 

2213 # Store a strong reference to the `save_counter`, so that if the 

2214 # `Checkpoint` object is deleted, the `save_counter` does not get 

2215 # deleted immediately. (The LoadStatus object needs to indirectly 

2216 # reference the counter through the ObjectGraphView). 

2217 base.TrackableReference("save_counter", self._save_counter)) 

2218 # When loading a checkpoint, the save counter is created after 

2219 # the checkpoint has been loaded, so it must be handled in a deferred 

2220 # manner. 

2221 if isinstance(self.root, weakref.ref): 

2222 root = self.root() 

2223 else: 

2224 root = self.root 

2225 restore = root._deferred_dependencies.pop("save_counter", ()) # pylint: disable=protected-access 

2226 if restore: 

2227 restore[0].restore(self._save_counter) 

2228 

2229 def write(self, file_prefix, options=None): 

2230 """Writes a training checkpoint. 

2231 

2232 The checkpoint includes variables created by this object and any 

2233 trackable objects it depends on at the time `Checkpoint.write()` is 

2234 called. 

2235 

2236 `write` does not number checkpoints, increment `save_counter`, or update the 

2237 metadata used by `tf.train.latest_checkpoint`. It is primarily intended for 

2238 use by higher level checkpoint management utilities. `save` provides a very 

2239 basic implementation of these features. 

2240 

2241 Checkpoints written with `write` must be read with `read`. 

2242 

2243 Example usage: 

2244 

2245 ``` 

2246 step = tf.Variable(0, name="step") 

2247 checkpoint = tf.train.Checkpoint(step=step) 

2248 checkpoint.write("/tmp/ckpt") 

2249 

2250 # Later, read the checkpoint with read() 

2251 checkpoint.read("/tmp/ckpt") 

2252 

2253 # You can also pass options to write() and read(). For example this 

2254 # runs the IO ops on the localhost: 

2255 options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost") 

2256 checkpoint.write("/tmp/ckpt", options=options) 

2257 

2258 # Later, read the checkpoint with read() 

2259 checkpoint.read("/tmp/ckpt", options=options) 

2260 ``` 

2261 

2262 Args: 

2263 file_prefix: A prefix to use for the checkpoint filenames 

2264 (/path/to/directory/and_a_prefix). 

2265 options: Optional `tf.train.CheckpointOptions` object. 

2266 

2267 Returns: 

2268 The full path to the checkpoint (i.e. `file_prefix`). 

2269 """ 

2270 if isinstance(file_prefix, os.PathLike): 

2271 file_prefix = os.fspath(file_prefix) 

2272 return self._write(file_prefix, options) 

2273 

2274 def _async_checkpointer(self): 

2275 """Returns an instantiated AsyncCheckpointHelper.""" 

2276 if self._async_checkpointer_impl is None: 

2277 self._async_checkpointer_impl = ( 

2278 async_checkpoint_helper.AsyncCheckpointHelper( 

2279 Checkpoint, 

2280 **self._kwargs)) 

2281 

2282 return self._async_checkpointer_impl 

2283 

2284 def _write(self, file_prefix, options=None, write_done_callback=None): 

2285 """Internal method that implements Checkpoint.write(). 

2286 

2287 Args: 

2288 file_prefix: A prefix to use for the checkpoint filenames 

2289 (/path/to/directory/and_a_prefix). 

2290 options: Optional `tf.train.CheckpointOptions` object. 

2291 write_done_callback: Optional callback function to be executed once 

2292 the underlying checkpoint saving is finished. Example usage includes 

2293 updating the checkpoint internal state. 

2294 

2295 Returns: 

2296 The full path to the checkpoint (i.e. `file_prefix`). 

2297 """ 

2298 # Triggers TF2 async checkpoint handling if: 

2299 # 1. async checkpoint is enabled in CheckpointOptions 

2300 # 2. running in eager mode 

2301 if options and options.experimental_enable_async_checkpoint: 

2302 self._checkpoint_options = options 

2303 if checkpoint_context.in_preemption_save_context(): 

2304 # Make sure all in-progress writes have completed before saving the 

2305 # final preemption checkpoint. 

2306 if self._async_checkpointer_impl is not None: 

2307 self._async_checkpointer_impl.sync() 

2308 # Additional work done will not be saved in a future checkpoint, so 

2309 # we use regular sync checkpoint to avoid overhead of dispatching 

2310 # checkpoint write to a new thread. 

2311 logging.warning( 

2312 "Switching to regular sync checkpoint for preemption checkpoint." 

2313 ) 

2314 elif context.executing_eagerly(): 

2315 return self._async_checkpointer()._write( # pylint: disable=protected-access 

2316 file_prefix, options, write_done_callback) 

2317 else: 

2318 logging.warning( 

2319 "Saving async checkpoint in graph mode is currently not supported;" 

2320 " switching to regular sync checkpoint instead.") 

2321 

2322 start_time = time.time() 

2323 options = options or checkpoint_options.CheckpointOptions() 

2324 output = self._saver.save(file_prefix=file_prefix, options=options) 

2325 output = _convert_file_name_tensor_to_string(output) 

2326 

2327 if write_done_callback: 

2328 write_done_callback(output) 

2329 

2330 # Ensure save operations have completed when running in eager runtime. 

2331 if context.executing_eagerly(): 

2332 context.async_wait() 

2333 

2334 end_time = time.time() 

2335 

2336 if not checkpoint_context.in_async_metrics_context(): 

2337 # This records the time checkpoint._write() blocks on the main thread. 

2338 metrics.AddCheckpointWriteDuration( 

2339 api_label=_CHECKPOINT_V2, 

2340 microseconds=_get_duration_microseconds(start_time, end_time), 

2341 ) 

2342 

2343 global _END_TIME_OF_LAST_WRITE 

2344 with _END_TIME_OF_LAST_WRITE_LOCK: 

2345 if not checkpoint_context.in_async_metrics_context(): 

2346 metrics.AddTrainingTimeSaved( 

2347 api_label=_CHECKPOINT_V2, 

2348 microseconds=_get_duration_microseconds( 

2349 _END_TIME_OF_LAST_WRITE, end_time) 

2350 ) 

2351 if checkpoint_context.in_preemption_save_context(): 

2352 _preemption_checkpoint_saved_time_usecs.get_cell().increase_by( 

2353 _get_duration_microseconds(_END_TIME_OF_LAST_WRITE, end_time) 

2354 ) 

2355 _END_TIME_OF_LAST_WRITE = end_time 

2356 

2357 metrics.RecordCheckpointSize( 

2358 api_label=_CHECKPOINT_V2, filesize=_get_checkpoint_size(output) 

2359 ) 

2360 return output 

2361 

2362 @property 

2363 def save_counter(self): 

2364 """An integer variable which starts at zero and is incremented on save. 

2365 

2366 Used to number checkpoints. 

2367 

2368 Returns: 

2369 The save counter variable. 

2370 """ 

2371 self._maybe_create_save_counter() 

2372 return self._save_counter 

2373 

2374 def sync(self): 

2375 """Wait for any outstanding save or restore operations.""" 

2376 # Subclasses of Checkpoint may not have `_async_checkpointer_impl` so use 

2377 # `getattr` for safer check. 

2378 if getattr(self, "_async_checkpointer_impl", None) is not None: 

2379 self._async_checkpointer_impl.sync() 

2380 

2381 def save(self, file_prefix, options=None): 

2382 # pylint:disable=line-too-long 

2383 """Saves a training checkpoint and provides basic checkpoint management. 

2384 

2385 The saved checkpoint includes variables created by this object and any 

2386 trackable objects it depends on at the time `Checkpoint.save()` is 

2387 called. 

2388 

2389 `save` is a basic convenience wrapper around the `write` method, 

2390 sequentially numbering checkpoints using `save_counter` and updating the 

2391 metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint 

2392 management, for example garbage collection and custom numbering, may be 

2393 provided by other utilities which also wrap `write` and `read` 

2394 (`tf.train.CheckpointManager`, for example). 

2395 

2396 ``` 

2397 step = tf.Variable(0, name="step") 

2398 checkpoint = tf.train.Checkpoint(step=step) 

2399 checkpoint.save("/tmp/ckpt") 

2400 

2401 # Later, read the checkpoint with restore() 

2402 checkpoint.restore("/tmp/ckpt-1") 

2403 

2404 # You can also pass options to save() and restore(). For example this 

2405 # runs the IO ops on the localhost: 

2406 options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost") 

2407 checkpoint.save("/tmp/ckpt", options=options) 

2408 

2409 # Later, read the checkpoint with restore() 

2410 checkpoint.restore("/tmp/ckpt-1", options=options) 

2411 ``` 

2412 

2413 Args: 

2414 file_prefix: A prefix to use for the checkpoint filenames 

2415 (/path/to/directory/and_a_prefix). Names are generated based on this 

2416 prefix and `Checkpoint.save_counter`. 

2417 options: Optional `tf.train.CheckpointOptions` object. 

2418 

2419 Returns: 

2420 The full path to the checkpoint. 

2421 """ 

2422 # Triggers TF2 async checkpoint handling if: 

2423 # 1. async checkpoint is enabled in CheckpointOptions 

2424 # 2. running in eager mode 

2425 if options and options.experimental_enable_async_checkpoint: 

2426 self._checkpoint_options = options 

2427 if checkpoint_context.in_preemption_save_context(): 

2428 # Make sure all in-progress writes have completed before saving the 

2429 # final preemption checkpoint. 

2430 if self._async_checkpointer_impl is not None: 

2431 self._async_checkpointer_impl.sync() 

2432 # Additional work done will not be saved in a future checkpoint, so 

2433 # we use regular sync checkpoint to avoid overhead of dispatching 

2434 # checkpoint write to a new thread. 

2435 logging.warning( 

2436 "Switching to regular sync checkpoint for preemption checkpoint." 

2437 ) 

2438 elif context.executing_eagerly(): 

2439 return self._async_checkpointer().save(file_prefix, options) 

2440 else: 

2441 logging.warning( 

2442 "Saving async checkpoint in graph mode is currently not supported;" 

2443 " switching to regular sync checkpoint instead.") 

2444 

2445 if isinstance(file_prefix, os.PathLike): 

2446 file_prefix = os.fspath(file_prefix) 

2447 # pylint:enable=line-too-long 

2448 options = options or checkpoint_options.CheckpointOptions() 

2449 graph_building = not context.executing_eagerly() 

2450 if graph_building: 

2451 if ops.inside_function(): 

2452 raise NotImplementedError( 

2453 "Calling tf.train.Checkpoint.save() from a function is not " 

2454 "supported, as save() modifies saving metadata in ways not " 

2455 "supported by TensorFlow Operations. Consider using " 

2456 "tf.train.Checkpoint.write(), a lower-level API which does not " 

2457 "update metadata. tf.train.latest_checkpoint and related APIs will " 

2458 "not see this checkpoint.") 

2459 session = get_session() 

2460 if self._save_counter is None: 

2461 # When graph building, if this is a new save counter variable then it 

2462 # needs to be initialized before assign_add. This is only an issue if 

2463 # restore() has not been called first. 

2464 session.run(self.save_counter.initializer) 

2465 

2466 if not graph_building or self._save_assign_op is None: 

2467 with ops.colocate_with(self.save_counter): 

2468 assign_op = self.save_counter.assign_add(1, read_value=True) 

2469 if graph_building: 

2470 self._save_assign_op = data_structures.NoDependency(assign_op) 

2471 

2472 if graph_building: 

2473 checkpoint_number = session.run(self._save_assign_op) 

2474 else: 

2475 checkpoint_number = assign_op.numpy() 

2476 

2477 return self._write( 

2478 "%s-%d" % (file_prefix, checkpoint_number), 

2479 options=options, 

2480 write_done_callback=_update_checkpoint_state_internal) 

2481 

2482 def read(self, save_path, options=None): 

2483 """Reads a training checkpoint written with `write`. 

2484 

2485 Reads this `Checkpoint` and any objects it depends on. 

2486 

2487 This method is just like `restore()` but does not expect the `save_counter` 

2488 variable in the checkpoint. It only restores the objects that the checkpoint 

2489 already depends on. 

2490 

2491 The method is primarily intended for use by higher level checkpoint 

2492 management utilities that use `write()` instead of `save()` and have their 

2493 own mechanisms to number and track checkpoints. 

2494 

2495 Example usage: 

2496 

2497 ```python 

2498 # Create a checkpoint with write() 

2499 ckpt = tf.train.Checkpoint(v=tf.Variable(1.)) 

2500 path = ckpt.write('/tmp/my_checkpoint') 

2501 

2502 # Later, load the checkpoint with read() 

2503 # With restore(), assert_consumed() would have failed. 

2504 ckpt.read(path).assert_consumed() 

2505 

2506 # You can also pass options to read(). For example this 

2507 # runs the IO ops on the localhost: 

2508 options = tf.train.CheckpointOptions( 

2509 experimental_io_device="/job:localhost") 

2510 ckpt.read(path, options=options) 

2511 ``` 

2512 

2513 Args: 

2514 save_path: The path to the checkpoint as returned by `write`. 

2515 options: Optional `tf.train.CheckpointOptions` object. 

2516 

2517 Returns: 

2518 A load status object, which can be used to make assertions about the 

2519 status of a checkpoint restoration. See `restore` for details. 

2520 """ 

2521 if options and options.experimental_enable_async_checkpoint: 

2522 self._checkpoint_options = options 

2523 # Triggers TF2 async checkpoint handling if: 

2524 # 1. async checkpoint is enabled in CheckpointOptions 

2525 # 2. there's a preceding async save/write 

2526 # 3. running in eager mode 

2527 if (self._checkpoint_options and 

2528 self._checkpoint_options.experimental_enable_async_checkpoint): 

2529 if context.executing_eagerly(): 

2530 return self._async_checkpointer().read(save_path, options) 

2531 else: 

2532 logging.warning( 

2533 "Saving async checkpoint in graph mode is currently not supported;" 

2534 " switching to regular sync checkpoint instead.") 

2535 

2536 start_time = time.time() 

2537 if isinstance(save_path, os.PathLike): 

2538 save_path = os.fspath(save_path) 

2539 options = options or checkpoint_options.CheckpointOptions() 

2540 result = self._saver.restore(save_path=save_path, options=options) 

2541 metrics.AddCheckpointReadDuration( 

2542 api_label=_CHECKPOINT_V2, 

2543 microseconds=_get_duration_microseconds(start_time, time.time())) 

2544 return result 

2545 

2546 def restore(self, save_path, options=None): 

2547 """Restores a training checkpoint. 

2548 

2549 Restores this `Checkpoint` and any objects it depends on. 

2550 

2551 This method is intended to be used to load checkpoints created by `save()`. 

2552 For checkpoints created by `write()` use the `read()` method which does not 

2553 expect the `save_counter` variable added by `save()`. 

2554 

2555 `restore()` either assigns values immediately if variables to restore have 

2556 been created already, or defers restoration until the variables are 

2557 created. Dependencies added after this call will be matched if they have a 

2558 corresponding object in the checkpoint (the restore request will queue in 

2559 any trackable object waiting for the expected dependency to be added). 

2560 

2561 ```python 

2562 checkpoint = tf.train.Checkpoint( ... ) 

2563 checkpoint.restore(path) 

2564 

2565 # You can additionally pass options to restore(): 

2566 options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost") 

2567 checkpoint.restore(path, options=options) 

2568 ``` 

2569 

2570 To ensure that loading is complete and no more deferred restorations will 

2571 take place, use the `assert_consumed()` method of the status object returned 

2572 by `restore()`: 

2573 

2574 ```python 

2575 checkpoint.restore(path, options=options).assert_consumed() 

2576 ``` 

2577 

2578 The assert will raise an error if any Python objects in the dependency graph 

2579 were not found in the checkpoint, or if any checkpointed values do not have 

2580 a matching Python object. 

2581 

2582 Name-based `tf.compat.v1.train.Saver` checkpoints from TensorFlow 1.x can be 

2583 loaded using this method. Names are used to match variables. Re-encode 

2584 name-based checkpoints using `tf.train.Checkpoint.save` as soon as possible. 

2585 

2586 **Loading from SavedModel checkpoints** 

2587 

2588 To load values from a SavedModel, just pass the SavedModel directory 

2589 to checkpoint.restore: 

2590 

2591 ```python 

2592 model = tf.keras.Model(...) 

2593 tf.saved_model.save(model, path) # or model.save(path, save_format='tf') 

2594 

2595 checkpoint = tf.train.Checkpoint(model) 

2596 checkpoint.restore(path).expect_partial() 

2597 ``` 

2598 

2599 This example calls `expect_partial()` on the loaded status, since 

2600 SavedModels saved from Keras often generate extra keys in the checkpoint. 

2601 Otherwise, the program prints a lot of warnings about unused keys at exit 

2602 time. 

2603 

2604 Args: 

2605 save_path: The path to the checkpoint, as returned by `save` or 

2606 `tf.train.latest_checkpoint`. If the checkpoint was written by the 

2607 name-based `tf.compat.v1.train.Saver`, names are used to match 

2608 variables. This path may also be a SavedModel directory. 

2609 options: Optional `tf.train.CheckpointOptions` object. 

2610 

2611 Returns: 

2612 A load status object, which can be used to make assertions about the 

2613 status of a checkpoint restoration. 

2614 

2615 The returned status object has the following methods: 

2616 

2617 * `assert_consumed()`: 

2618 Raises an exception if any variables are unmatched: either 

2619 checkpointed values which don't have a matching Python object or 

2620 Python objects in the dependency graph with no values in the 

2621 checkpoint. This method returns the status object, and so may be 

2622 chained with other assertions. 

2623 

2624 * `assert_existing_objects_matched()`:

2625 Raises an exception if any existing Python objects in the dependency

2626 graph are unmatched. Unlike `assert_consumed`, this assertion will

2627 pass if values in the checkpoint have no corresponding Python

2628 objects. For example, a `tf.keras.Layer` object which has not yet been

2629 built, and so has not created any variables, will pass this assertion

2630 but fail `assert_consumed`. Useful when loading part of a larger

2631 checkpoint into a new Python program, e.g. when a training checkpoint

2632 with a `tf.compat.v1.train.Optimizer` was saved but only the state

2633 required for inference is being loaded (see the sketch after this

2634 list). This method returns the status object, and so may be chained

2635 with other assertions.

2636 

2637 * `assert_nontrivial_match()`: Asserts that something aside from the root 

2638 object was matched. This is a very weak assertion, but is useful for 

2639 sanity checking in library code where objects may exist in the 

2640 checkpoint which haven't been created in Python and some Python 

2641 objects may not have a checkpointed value. 

2642 

2643 * `expect_partial()`: Silence warnings about incomplete checkpoint 

2644 restores. Warnings are otherwise printed for unused parts of the 

2645 checkpoint file or object when the `Checkpoint` object is deleted 

2646 (often at program shutdown). 

2647 
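A sketch of the partial-load case mentioned under
`assert_existing_objects_matched()`; `optimizer`, `model`, and the path are
placeholders:

```python
train_ckpt = tf.train.Checkpoint(optimizer=optimizer, model=model)
path = train_ckpt.save("/tmp/train/ckpt")

infer_ckpt = tf.train.Checkpoint(model=model)  # optimizer state is not needed
status = infer_ckpt.restore(path)
status.assert_existing_objects_matched()  # passes: every existing object was matched
status.expect_partial()                   # the unused optimizer values are expected
# status.assert_consumed() would raise, since the optimizer values go unmatched.
```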

2648 Raises: 

2649 NotFoundError: if a checkpoint or SavedModel cannot be found at

2650 `save_path`.

2651 """ 

2652 if options and options.experimental_enable_async_checkpoint: 

2653 self._checkpoint_options = options 

2654 # Triggers TF2 async checkpoint handling if: 

2655 # 1. async checkpoint is enabled in CheckpointOptions 

2656 # 2. there's a preceding async save/write

2657 # 3. running in eager mode 

2658 if (self._checkpoint_options and 

2659 self._checkpoint_options.experimental_enable_async_checkpoint): 

2660 if context.executing_eagerly(): 

2661 return self._async_checkpointer().restore(save_path, options) 

2662 else: 

2663 logging.warning(

2664 "Async checkpoint restore in graph mode is currently not supported;"

2665 " falling back to a regular sync checkpoint restore instead.")

2666 

2667 orig_save_path = save_path 

2668 if isinstance(save_path, os.PathLike): 

2669 save_path = os.fspath(save_path) 

2670 

2671 if save_path is not None and gfile.IsDirectory(save_path) and ( 

2672 (gfile.Exists(path_helpers.get_saved_model_pb_path(save_path)) or 

2673 gfile.Exists(path_helpers.get_saved_model_pbtxt_path(save_path)))): 

2674 save_path = path_helpers.get_variables_path(save_path) 

2675 

2676 try: 

2677 status = self.read(save_path, options=options) 

2678 if context.executing_eagerly(): 

2679 context.async_wait() # Ensure restore operations have completed. 

2680 except errors_impl.NotFoundError as e: 

2681 raise errors_impl.NotFoundError( 

2682 None, None, 

2683 f"Error when restoring from checkpoint or SavedModel at " 

2684 f"{orig_save_path}: {e.message}" 

2685 f"\nPlease double-check that the path is correct. You may be missing " 

2686 "the checkpoint suffix (e.g. the '-1' in 'path/to/ckpt-1').") 

2687 # Create the save counter now so it gets initialized with other variables 

2688 # when graph building. Creating it earlier would lead to errors when using, 

2689 # say, train.Saver() to save the model before initializing it. 

2690 self._maybe_create_save_counter() 

2691 if isinstance(status, NameBasedSaverStatus): 

2692 status.add_to_optionally_restored(self.save_counter) 

2693 return status 

2694 

2695 

2696_preemption_checkpoint_saved_time_usecs = monitoring.Counter( 

2697 "/tensorflow/api/distribution_strategy/preemption_checkpoint_saved_time_usecs", 

2698 "Training time saved by PreemptionCheckpointHandler (us).")