Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/checkpoint/checkpoint_management.py: 23%

278 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16# pylint: disable=invalid-name 

17"""Checkpoint Manager and other utilities for managing checkpoints.""" 

18import collections 

19import os.path 

20import re 

21import time 

22 

23from google.protobuf import text_format 

24 

25from tensorflow.core.protobuf import saver_pb2 

26from tensorflow.python.eager import context 

27from tensorflow.python.framework import errors 

28from tensorflow.python.framework import ops 

29from tensorflow.python.lib.io import file_io 

30from tensorflow.python.ops import variable_scope 

31from tensorflow.python.platform import tf_logging as logging 

32from tensorflow.python.training import training_util 

33from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState 

34from tensorflow.python.util import compat 

35from tensorflow.python.util import deprecation 

36from tensorflow.python.util.tf_export import tf_export 

37 

38 

39def _evaluate(tensor): 

40 """Returns the numpy value of a tensor.""" 

41 if context.executing_eagerly(): 

42 return tensor.numpy() 

43 return ops.get_default_session().run(tensor) 

44 

45 

46def _GetCheckpointFilename(save_dir, latest_filename): 

47 """Returns a filename for storing the CheckpointState. 

48 

49 Args: 

50 save_dir: The directory for saving and restoring checkpoints. 

51 latest_filename: Name of the file in 'save_dir' that is used 

52 to store the CheckpointState. 

53 

54 Returns: 

55 The path of the file that contains the CheckpointState proto. 

56 """ 

57 if latest_filename is None: 

58 latest_filename = "checkpoint" 

59 return os.path.join(save_dir, latest_filename) 

60 

61 

62@tf_export(v1=["train.generate_checkpoint_state_proto"]) 

63def generate_checkpoint_state_proto(save_dir, 

64 model_checkpoint_path, 

65 all_model_checkpoint_paths=None, 

66 all_model_checkpoint_timestamps=None, 

67 last_preserved_timestamp=None): 

68 """Generates a checkpoint state proto. 

69 

70 Args: 

71 save_dir: Directory where the model was saved. 

72 model_checkpoint_path: The checkpoint file. 

73 all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted 

74 checkpoints, sorted from oldest to newest. If this is a non-empty list, 

75 the last element must be equal to model_checkpoint_path. These paths 

76 are also saved in the CheckpointState proto. 

77 all_model_checkpoint_timestamps: A list of floats, indicating the number of 

78 seconds since the Epoch when each checkpoint was generated. 

79 last_preserved_timestamp: A float, indicating the number of seconds since 

80 the Epoch when the last preserved checkpoint was written, e.g. due to a 

81 `keep_checkpoint_every_n_hours` parameter (see 

82 `tf.train.CheckpointManager` for an implementation). 

83 Returns: 

84 CheckpointState proto with model_checkpoint_path and 

85 all_model_checkpoint_paths updated to either absolute paths or 

86 relative paths to the current save_dir. 

87 

88 Raises: 

89 ValueError: If `all_model_checkpoint_timestamps` was provided but its length 

90 does not match `all_model_checkpoint_paths`. 

91 """ 

92 if all_model_checkpoint_paths is None: 

93 all_model_checkpoint_paths = [] 

94 

95 if (not all_model_checkpoint_paths or 

96 all_model_checkpoint_paths[-1] != model_checkpoint_path): 

97 logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.", 

98 model_checkpoint_path) 

99 all_model_checkpoint_paths.append(model_checkpoint_path) 

100 

101 if (all_model_checkpoint_timestamps 

102 and (len(all_model_checkpoint_timestamps) 

103 != len(all_model_checkpoint_paths))): 

104 raise ValueError( 

105 ("Checkpoint timestamps, if provided, must match checkpoint paths (got " 

106 "paths %s and timestamps %s)") 

107 % (all_model_checkpoint_paths, all_model_checkpoint_timestamps)) 

108 

109 # Relative paths need to be rewritten to be relative to the "save_dir" 

110 # if model_checkpoint_path already contains "save_dir". 

111 if not os.path.isabs(save_dir): 

112 if not os.path.isabs(model_checkpoint_path): 

113 model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir) 

114 for i, p in enumerate(all_model_checkpoint_paths): 

115 if not os.path.isabs(p): 

116 all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir) 

117 

118 coord_checkpoint_proto = CheckpointState( 

119 model_checkpoint_path=model_checkpoint_path, 

120 all_model_checkpoint_paths=all_model_checkpoint_paths, 

121 all_model_checkpoint_timestamps=all_model_checkpoint_timestamps, 

122 last_preserved_timestamp=last_preserved_timestamp) 

123 

124 return coord_checkpoint_proto 

125 
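# --- Editor's example (illustrative sketch; not part of the original file) ---
# Building a CheckpointState proto in memory with the v1-exported helper above;
# the paths and timestamps are made up. The last path must match
# model_checkpoint_path, and timestamps, if given, must match the paths.
import tensorflow.compat.v1 as tf1

state = tf1.train.generate_checkpoint_state_proto(
    save_dir="/tmp/model",
    model_checkpoint_path="/tmp/model/ckpt-3",
    all_model_checkpoint_paths=["/tmp/model/ckpt-1", "/tmp/model/ckpt-3"],
    all_model_checkpoint_timestamps=[1700000000.0, 1700003600.0])
print(state.model_checkpoint_path)             # /tmp/model/ckpt-3
print(list(state.all_model_checkpoint_paths))  # oldest to newest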

126 

127@deprecation.deprecated( 

128 date=None, 

129 instructions=("Use `tf.train.CheckpointManager` to manage checkpoints " 

130 "rather than manually editing the Checkpoint proto.")) 

131@tf_export(v1=["train.update_checkpoint_state"]) 

132def update_checkpoint_state(save_dir, 

133 model_checkpoint_path, 

134 all_model_checkpoint_paths=None, 

135 latest_filename=None, 

136 all_model_checkpoint_timestamps=None, 

137 last_preserved_timestamp=None): 

138 """Updates the content of the 'checkpoint' file. 

139 

140 This updates the checkpoint file containing a CheckpointState 

141 proto. 

142 

143 Args: 

144 save_dir: Directory where the model was saved. 

145 model_checkpoint_path: The checkpoint file. 

146 all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted 

147 checkpoints, sorted from oldest to newest. If this is a non-empty list, 

148 the last element must be equal to model_checkpoint_path. These paths 

149 are also saved in the CheckpointState proto. 

150 latest_filename: Optional name of the checkpoint file. Defaults to 

151 'checkpoint'. 

152 all_model_checkpoint_timestamps: Optional list of timestamps (floats, 

153 seconds since the Epoch) indicating when the checkpoints in 

154 `all_model_checkpoint_paths` were created. 

155 last_preserved_timestamp: A float, indicating the number of seconds since 

156 the Epoch when the last preserved checkpoint was written, e.g. due to a 

157 `keep_checkpoint_every_n_hours` parameter (see 

158 `tf.train.CheckpointManager` for an implementation). 

159 Raises: 

160 RuntimeError: If any of the model checkpoint paths conflict with the file 

161 containing CheckpointState. 

162 """ 

163 update_checkpoint_state_internal( 

164 save_dir=save_dir, 

165 model_checkpoint_path=model_checkpoint_path, 

166 all_model_checkpoint_paths=all_model_checkpoint_paths, 

167 latest_filename=latest_filename, 

168 save_relative_paths=False, 

169 all_model_checkpoint_timestamps=all_model_checkpoint_timestamps, 

170 last_preserved_timestamp=last_preserved_timestamp) 

171 

172 

173@tf_export("__internal__.train.update_checkpoint_state", v1=[]) 

174def update_checkpoint_state_internal(save_dir, 

175 model_checkpoint_path, 

176 all_model_checkpoint_paths=None, 

177 latest_filename=None, 

178 save_relative_paths=False, 

179 all_model_checkpoint_timestamps=None, 

180 last_preserved_timestamp=None): 

181 """Updates the content of the 'checkpoint' file. 

182 

183 This updates the checkpoint file containing a CheckpointState 

184 proto. 

185 

186 Args: 

187 save_dir: Directory where the model was saved. 

188 model_checkpoint_path: The checkpoint file. 

189 all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted 

190 checkpoints, sorted from oldest to newest. If this is a non-empty list, 

191 the last element must be equal to model_checkpoint_path. These paths 

192 are also saved in the CheckpointState proto. 

193 latest_filename: Optional name of the checkpoint file. Defaults to 

194 'checkpoint'. 

195 save_relative_paths: If `True`, will write relative paths to the checkpoint 

196 state file. 

197 all_model_checkpoint_timestamps: Optional list of timestamps (floats, 

198 seconds since the Epoch) indicating when the checkpoints in 

199 `all_model_checkpoint_paths` were created. 

200 last_preserved_timestamp: A float, indicating the number of seconds since 

201 the Epoch when the last preserved checkpoint was written, e.g. due to a 

202 `keep_checkpoint_every_n_hours` parameter (see 

203 `tf.train.CheckpointManager` for an implementation). 

204 

205 Raises: 

206 RuntimeError: If any of the model checkpoint paths conflict with the file 

207 containing CheckpointState. 

208 """ 

209 # Writes the "checkpoint" file for the coordinator for later restoration. 

210 coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename) 

211 if save_relative_paths: 

212 if os.path.isabs(model_checkpoint_path): 

213 rel_model_checkpoint_path = os.path.relpath( 

214 model_checkpoint_path, save_dir) 

215 else: 

216 rel_model_checkpoint_path = model_checkpoint_path 

217 rel_all_model_checkpoint_paths = [] 

218 for p in all_model_checkpoint_paths: 

219 if os.path.isabs(p): 

220 rel_all_model_checkpoint_paths.append(os.path.relpath(p, save_dir)) 

221 else: 

222 rel_all_model_checkpoint_paths.append(p) 

223 ckpt = generate_checkpoint_state_proto( 

224 save_dir, 

225 rel_model_checkpoint_path, 

226 all_model_checkpoint_paths=rel_all_model_checkpoint_paths, 

227 all_model_checkpoint_timestamps=all_model_checkpoint_timestamps, 

228 last_preserved_timestamp=last_preserved_timestamp) 

229 else: 

230 ckpt = generate_checkpoint_state_proto( 

231 save_dir, 

232 model_checkpoint_path, 

233 all_model_checkpoint_paths=all_model_checkpoint_paths, 

234 all_model_checkpoint_timestamps=all_model_checkpoint_timestamps, 

235 last_preserved_timestamp=last_preserved_timestamp) 

236 

237 if coord_checkpoint_filename == ckpt.model_checkpoint_path: 

238 raise RuntimeError("Save path '%s' conflicts with path used for " 

239 "checkpoint state. Please use a different save path." % 

240 model_checkpoint_path) 

241 

242 # Preventing potential read/write race condition by *atomically* writing to a 

243 # file. 

244 file_io.atomic_write_string_to_file(coord_checkpoint_filename, 

245 text_format.MessageToString(ckpt)) 

246 
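# --- Editor's example (illustrative sketch; not part of the original file) ---
# Writing the "checkpoint" state file directly with the deprecated v1 API.
# The directory is a scratch dir and the prefixes are made up; nothing checks
# that the checkpoint files themselves exist. tf.train.CheckpointManager is
# the recommended replacement.
import os
import tempfile
import tensorflow.compat.v1 as tf1

save_dir = tempfile.mkdtemp()
tf1.train.update_checkpoint_state(
    save_dir=save_dir,
    model_checkpoint_path=os.path.join(save_dir, "ckpt-2"),
    all_model_checkpoint_paths=[os.path.join(save_dir, "ckpt-1"),
                                os.path.join(save_dir, "ckpt-2")])
# <save_dir>/checkpoint now holds a text-format CheckpointState proto.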

247 

248@tf_export("train.get_checkpoint_state") 

249def get_checkpoint_state(checkpoint_dir, latest_filename=None): 

250 """Returns CheckpointState proto from the "checkpoint" file. 

251 

252 If the "checkpoint" file contains a valid CheckpointState 

253 proto, returns it. 

254 

255 Args: 

256 checkpoint_dir: The directory of checkpoints. 

257 latest_filename: Optional name of the checkpoint file. Defaults to 

258 'checkpoint'. 

259 

260 Returns: 

261 A CheckpointState if the state was available, None 

262 otherwise. 

263 

264 Raises: 

265 ValueError: if the checkpoint read doesn't have model_checkpoint_path set. 

266 """ 

267 if isinstance(checkpoint_dir, os.PathLike): 

268 checkpoint_dir = os.fspath(checkpoint_dir) 

269 ckpt = None 

270 coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir, 

271 latest_filename) 

272 f = None 

273 try: 

274 # Check that the file exists before opening it to avoid 

275 # many lines of errors from colossus in the logs. 

276 if file_io.file_exists(coord_checkpoint_filename): 

277 file_content = file_io.read_file_to_string( 

278 coord_checkpoint_filename) 

279 ckpt = CheckpointState() 

280 text_format.Merge(file_content, ckpt) 

281 if not ckpt.model_checkpoint_path: 

282 raise ValueError("Invalid checkpoint state loaded from " 

283 + checkpoint_dir) 

284 # For relative model_checkpoint_path and all_model_checkpoint_paths, 

285 # prepend checkpoint_dir. 

286 if not os.path.isabs(ckpt.model_checkpoint_path): 

287 ckpt.model_checkpoint_path = os.path.join(checkpoint_dir, 

288 ckpt.model_checkpoint_path) 

289 for i, p in enumerate(ckpt.all_model_checkpoint_paths): 

290 if not os.path.isabs(p): 

291 ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p) 

292 except errors.OpError as e: 

293 # It's ok if the file cannot be read 

294 logging.warning("%s: %s", type(e).__name__, e) 

295 logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename) 

296 return None 

297 except text_format.ParseError as e: 

298 logging.warning("%s: %s", type(e).__name__, e) 

299 logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename) 

300 return None 

301 finally: 

302 if f: 

303 f.close() 

304 return ckpt 

305 
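# --- Editor's example (illustrative sketch; not part of the original file) ---
# Reading the CheckpointState proto back from a directory; the call returns
# None when no readable "checkpoint" file is present. The path is made up.
import tensorflow as tf

state = tf.train.get_checkpoint_state("/tmp/model")
if state is not None:
  print(state.model_checkpoint_path)
  print(list(state.all_model_checkpoint_paths))  # oldest to newest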

306 

307def _prefix_to_checkpoint_path(prefix, format_version): 

308 """Returns the pathname of a checkpoint file, given the checkpoint prefix. 

309 

310 For V1 checkpoint, simply returns the prefix itself (the data file). For V2, 

311 returns the pathname to the index file. 

312 

313 Args: 

314 prefix: a string, the prefix of a checkpoint. 

315 format_version: the checkpoint format version that corresponds to the 

316 prefix. 

317 Returns: 

318 The pathname of a checkpoint file, taking into account the checkpoint 

319 format version. 

320 """ 

321 if format_version == saver_pb2.SaverDef.V2: 

322 return prefix + ".index" # The index file identifies a checkpoint. 

323 return prefix # Just the data file. 

324 
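# --- Editor's example (illustrative sketch; not part of the original file) ---
# How the private helper above maps a checkpoint prefix to a concrete file:
# a V2 prefix is identified by its ".index" file, a V1 prefix is the data file
# itself. The internal module is used here only to illustrate that mapping,
# and the prefix is made up.
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.checkpoint import checkpoint_management

prefix = "/tmp/model/ckpt-3"
print(checkpoint_management._prefix_to_checkpoint_path(
    prefix, saver_pb2.SaverDef.V2))  # /tmp/model/ckpt-3.index
print(checkpoint_management._prefix_to_checkpoint_path(
    prefix, saver_pb2.SaverDef.V1))  # /tmp/model/ckpt-3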

325 

326@tf_export("train.latest_checkpoint") 

327def latest_checkpoint(checkpoint_dir, latest_filename=None): 

328 """Finds the filename of latest saved checkpoint file. 

329 

330 Gets the checkpoint state given the provided checkpoint_dir and looks for a 

331 corresponding TensorFlow 2 (preferred) or TensorFlow 1.x checkpoint path. 

332 The latest_filename argument is only applicable if you are saving checkpoints 

333 using `v1.train.Saver.save`. 

334 

335 

336 See the [Training Checkpoints 

337 Guide](https://www.tensorflow.org/guide/checkpoint) for more details and 

338 examples. 

339 

340 Args: 

341 checkpoint_dir: Directory where the variables were saved. 

342 latest_filename: Optional name for the protocol buffer file that 

343 contains the list of most recent checkpoint filenames. 

344 See the corresponding argument to `v1.train.Saver.save`. 

345 

346 Returns: 

347 The full path to the latest checkpoint or `None` if no checkpoint was found. 

348 """ 

349 # Pick the latest checkpoint based on checkpoint state. 

350 ckpt = get_checkpoint_state(checkpoint_dir, latest_filename) 

351 if ckpt and ckpt.model_checkpoint_path: 

352 # Look for either a V2 path or a V1 path, with priority for V2. 

353 v2_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path, 

354 saver_pb2.SaverDef.V2) 

355 v1_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path, 

356 saver_pb2.SaverDef.V1) 

357 if file_io.get_matching_files(v2_path) or file_io.get_matching_files( 

358 v1_path): 

359 return ckpt.model_checkpoint_path 

360 else: 

361 logging.error("Couldn't match files for checkpoint %s", 

362 ckpt.model_checkpoint_path) 

363 return None 

364 
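# --- Editor's example (illustrative sketch; not part of the original file) ---
# Resuming from whatever checkpoint is newest in a directory; the directory is
# made up, and restore() is simply skipped when nothing has been saved yet.
import tensorflow as tf

ckpt = tf.train.Checkpoint(step=tf.Variable(0))
path = tf.train.latest_checkpoint("/tmp/model")
if path is not None:
  ckpt.restore(path)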

365 

366def checkpoint_exists_internal(checkpoint_prefix): 

367 """Checks whether a V1 or V2 checkpoint exists with the specified prefix. 

368 

369 This is an internal function to check if a checkpoint exists, 

370 since it takes into account the naming difference between V1 and V2 formats. 

371 

372 Args: 

373 checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking 

374 priority. Typically the result of `Saver.save()` or that of 

375 `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or 

376 V1/V2. 

377 Returns: 

378 A bool, true if a checkpoint referred to by `checkpoint_prefix` exists. 

379 """ 

380 pathname = _prefix_to_checkpoint_path(checkpoint_prefix, 

381 saver_pb2.SaverDef.V2) 

382 if file_io.get_matching_files(pathname): 

383 return True 

384 elif file_io.get_matching_files(checkpoint_prefix): 

385 return True 

386 else: 

387 return False 

388 

389 

390@deprecation.deprecated( 

391 date=None, 

392 instructions="Use standard file APIs to check for files with this prefix.") 

393@tf_export(v1=["train.checkpoint_exists"]) 

394def checkpoint_exists(checkpoint_prefix): 

395 """Checks whether a V1 or V2 checkpoint exists with the specified prefix. 

396 

397 This is the recommended way to check if a checkpoint exists, since it takes 

398 into account the naming difference between V1 and V2 formats. 

399 

400 Args: 

401 checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking 

402 priority. Typically the result of `Saver.save()` or that of 

403 `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or 

404 V1/V2. 

405 

406 Returns: 

407 A bool, true if a checkpoint referred to by `checkpoint_prefix` exists. 

408 """ 

409 return checkpoint_exists_internal(checkpoint_prefix) 

410 
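# --- Editor's example (illustrative sketch; not part of the original file) ---
# The deprecated v1 existence check next to the plain file-API check that the
# deprecation message points to; the prefix is made up.
import tensorflow as tf
import tensorflow.compat.v1 as tf1

prefix = "/tmp/model/ckpt-3"
print(tf1.train.checkpoint_exists(prefix))        # handles both V1 and V2 naming
print(bool(tf.io.gfile.glob(prefix + ".index")))  # V2-only equivalent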

411 

412@deprecation.deprecated( 

413 date=None, 

414 instructions="Use standard file utilities to get mtimes.") 

415@tf_export(v1=["train.get_checkpoint_mtimes"]) 

416def get_checkpoint_mtimes(checkpoint_prefixes): 

417 """Returns the mtimes (modification timestamps) of the checkpoints. 

418 

419 Globs for the checkpoints pointed to by `checkpoint_prefixes`. If the files 

420 exist, collect their mtime. Both V2 and V1 checkpoints are considered, in 

421 that priority. 

422 

423 This is the recommended way to get the mtimes, since it takes into account 

424 the naming difference between V1 and V2 formats. 

425 

426 Note: If not all checkpoints exist, the length of the returned mtimes list 

427 will be smaller than the length of `checkpoint_prefixes` list, so mapping 

428 checkpoints to corresponding mtimes will not be possible. 

429 

430 Args: 

431 checkpoint_prefixes: a list of checkpoint paths, typically the results of 

432 `Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of 

433 sharded/non-sharded or V1/V2. 

434 Returns: 

435 A list of mtimes (in seconds since the Epoch) of the found checkpoints. 

436 """ 

437 mtimes = [] 

438 

439 def match_maybe_append(pathname): 

440 fnames = file_io.get_matching_files(pathname) 

441 if fnames: 

442 mtimes.append(file_io.stat(fnames[0]).mtime_nsec / 1e9) 

443 return True 

444 return False 

445 

446 for checkpoint_prefix in checkpoint_prefixes: 

447 # Tries V2's metadata file first. 

448 pathname = _prefix_to_checkpoint_path(checkpoint_prefix, 

449 saver_pb2.SaverDef.V2) 

450 if match_maybe_append(pathname): 

451 continue 

452 # Otherwise, tries V1, where the prefix is the complete pathname. 

453 match_maybe_append(checkpoint_prefix) 

454 

455 return mtimes 

456 
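# --- Editor's example (illustrative sketch; not part of the original file) ---
# Collecting modification times for a list of (made up) checkpoint prefixes;
# prefixes that match no files are skipped, so the result can be shorter than
# the input list.
import tensorflow.compat.v1 as tf1

mtimes = tf1.train.get_checkpoint_mtimes(
    ["/tmp/model/ckpt-1", "/tmp/model/ckpt-2"])
print(mtimes)  # e.g. [1700000000.1, 1700003600.2], seconds since the Epoch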

457 

458@deprecation.deprecated( 

459 date=None, 

460 instructions="Use standard file APIs to delete files with this prefix.") 

461@tf_export(v1=["train.remove_checkpoint"]) 

462def remove_checkpoint(checkpoint_prefix, 

463 checkpoint_format_version=saver_pb2.SaverDef.V2, 

464 meta_graph_suffix="meta"): 

465 """Removes a checkpoint given by `checkpoint_prefix`. 

466 

467 Args: 

468 checkpoint_prefix: The prefix of a V1 or V2 checkpoint. Typically the result 

469 of `Saver.save()` or that of `tf.train.latest_checkpoint()`, regardless of 

470 sharded/non-sharded or V1/V2. 

471 checkpoint_format_version: `SaverDef.CheckpointFormatVersion`, defaults to 

472 `SaverDef.V2`. 

473 meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. 

474 """ 

475 _delete_file_if_exists( 

476 meta_graph_filename(checkpoint_prefix, meta_graph_suffix)) 

477 if checkpoint_format_version == saver_pb2.SaverDef.V2: 

478 # V2 has a metadata file and some data files. 

479 _delete_file_if_exists(checkpoint_prefix + ".index") 

480 _delete_file_if_exists(checkpoint_prefix + ".data-?????-of-?????") 

481 else: 

482 # V1, Legacy. Exact match on the data file. 

483 _delete_file_if_exists(checkpoint_prefix) 

484 
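# --- Editor's example (illustrative sketch; not part of the original file) ---
# Removing one (made up) V2 checkpoint: the ".index" file, any data shards and
# the ".meta" file, if present, are deleted; missing files are ignored.
import tensorflow.compat.v1 as tf1

tf1.train.remove_checkpoint("/tmp/model/ckpt-1")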

485 

486def _delete_file_if_exists(filespec): 

487 """Deletes files matching `filespec`.""" 

488 for pathname in file_io.get_matching_files(filespec): 

489 try: 

490 file_io.delete_file(pathname) 

491 except errors.NotFoundError: 

492 logging.warning( 

493 "Hit NotFoundError when deleting '%s', possibly because another " 

494 "process/thread is also deleting/moving the same file", pathname) 

495 

496 

497def meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"): 

498 """Returns the meta graph filename. 

499 

500 Args: 

501 checkpoint_filename: Name of the checkpoint file. 

502 meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. 

503 

504 Returns: 

505 MetaGraph file name. 

506 """ 

507 # If the checkpoint_filename is sharded, the checkpoint_filename could 

508 # be of format model.ckpt-step#-?????-of-shard#. For example, 

509 # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002. 

510 basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename) 

511 suffixed_filename = ".".join([basename, meta_graph_suffix]) 

512 return suffixed_filename 

513 
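# --- Editor's example (illustrative sketch; not part of the original file) ---
# The sharded suffix ("-?????-of-00005" style) is stripped before the meta
# graph suffix is appended; the filenames below are made up, and the internal
# module is used only to illustrate the regex above.
from tensorflow.python.checkpoint import checkpoint_management

print(checkpoint_management.meta_graph_filename(
    "model.ckpt-123456-?????-of-00005"))   # model.ckpt-123456.meta
print(checkpoint_management.meta_graph_filename(
    "model.ckpt-123456"))                  # model.ckpt-123456.meta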

514 

515# TODO(allenl): Allow tf.keras.Model instances in the constructor directly? 

516@tf_export("train.CheckpointManager") 

517class CheckpointManager(object): 

518 """Manages multiple checkpoints by keeping some and deleting unneeded ones. 

519 

520 Example usage: 

521 

522 ```python 

523 import tensorflow as tf 

524 checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) 

525 manager = tf.train.CheckpointManager( 

526 checkpoint, directory="/tmp/model", max_to_keep=5) 

527 status = checkpoint.restore(manager.latest_checkpoint) 

528 while True: 

529 # train 

530 manager.save() 

531 ``` 

532 

533 `CheckpointManager` preserves its own state across instantiations (see the 

534 `__init__` documentation for details). Only one should be active in a 

535 particular directory at a time. 

536 """ 

537 

538 def __init__(self, 

539 checkpoint, 

540 directory, 

541 max_to_keep, 

542 keep_checkpoint_every_n_hours=None, 

543 checkpoint_name="ckpt", 

544 step_counter=None, 

545 checkpoint_interval=None, 

546 init_fn=None): 

547 """Configure a `CheckpointManager` for use in `directory`. 

548 

549 If a `CheckpointManager` was previously used in `directory`, its 

550 state will be restored. This includes the list of managed checkpoints and 

551 the timestamp bookkeeping necessary to support 

552 `keep_checkpoint_every_n_hours`. The behavior of the new `CheckpointManager` 

553 will be the same as the previous `CheckpointManager`, including cleaning up 

554 existing checkpoints if appropriate. 

555 

556 Checkpoints are only considered for deletion just after a new checkpoint has 

557 been added. At that point, `max_to_keep` checkpoints will remain in an 

558 "active set". Once a checkpoint is preserved by 

559 `keep_checkpoint_every_n_hours` it will not be deleted by this 

560 `CheckpointManager` or any future `CheckpointManager` instantiated in 

561 `directory` (regardless of the new setting of 

562 `keep_checkpoint_every_n_hours`). The `max_to_keep` checkpoints in the 

563 active set may be deleted by this `CheckpointManager` or a future 

564 `CheckpointManager` instantiated in `directory` (subject to its 

565 `max_to_keep` and `keep_checkpoint_every_n_hours` settings). 

566 

567 `CheckpointManager` can also be used for initializing the model if 

568 there are no checkpoints to restore from in `directory`. An example usage is: 

569 

570 >>> import tempfile 

571 

572 >>> tmp_dir = tempfile.mkdtemp() 

573 >>> checkpoint = tf.train.Checkpoint() 

574 >>> init_path = checkpoint.save(os.path.join(tmp_dir, 'init')) 

575 

576 >>> def init_fn(): 

577 ... # Partially restore the checkpoint from `init_path`. 

578 ... checkpoint.restore(init_path) 

579 

580 >>> manager = tf.train.CheckpointManager( 

581 ... checkpoint, 

582 ... directory=os.path.join(tmp_dir, 'ckpt'), 

583 ... max_to_keep=None, 

584 ... init_fn=init_fn) 

585 >>> # `restore_or_initialize` will call `init_fn` if there is no existing 

586 >>> # checkpoint in `directory`. 

587 >>> manager.restore_or_initialize() 

588 

589 Args: 

590 checkpoint: The `tf.train.Checkpoint` instance to save and manage 

591 checkpoints for. 

592 directory: The path to a directory in which to write checkpoints. A 

593 special file named "checkpoint" is also written to this directory (in a 

594 human-readable text format) which contains the state of the 

595 `CheckpointManager`. 

596 max_to_keep: An integer, the number of checkpoints to keep. Unless 

597 preserved by `keep_checkpoint_every_n_hours`, checkpoints will be 

598 deleted from the active set, oldest first, until only `max_to_keep` 

599 checkpoints remain. If `None`, no checkpoints are deleted and everything 

600 stays in the active set. Note that `max_to_keep=None` will keep all 

601 checkpoint paths in memory and in the checkpoint state protocol buffer 

602 on disk. 

603 keep_checkpoint_every_n_hours: Upon removal from the active set, a 

604 checkpoint will be preserved if it has been at least 

605 `keep_checkpoint_every_n_hours` hours since the last preserved checkpoint. The 

606 default setting of `None` does not preserve any checkpoints in this way. 

607 checkpoint_name: Custom name for the checkpoint file. 

608 step_counter: A `tf.Variable` instance for checking the current step 

609 counter value, in case users want to save checkpoints every N steps. 

610 checkpoint_interval: An integer, indicates the minimum step interval 

611 between two checkpoints. 

612 init_fn: Callable. A function to do customized initialization if no 

613 checkpoints are in the directory. 

614 

615 Raises: 

616 ValueError: If `max_to_keep` is not a positive integer. 

617 """ 

618 self._checkpoint = checkpoint 

619 self._save_counter_assign = None 

620 if max_to_keep is not None and max_to_keep <= 0: 

621 raise ValueError( 

622 ("Expected a positive integer or `None` for `max_to_keep`, " 

623 "got %d.") 

624 % (max_to_keep,)) 

625 self._max_to_keep = max_to_keep 

626 self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours 

627 if isinstance(directory, os.PathLike): 

628 directory = os.fspath(directory) 

629 self._directory = directory 

630 self._checkpoint_prefix = os.path.join(directory, checkpoint_name) 

631 self._init_fn = init_fn 

632 

633 if checkpoint_interval is not None: 

634 if step_counter is None: 

635 raise ValueError("`step_counter` should be passed if " 

636 "`checkpoint_interval` is not None.") 

637 self._last_checkpoint_step = None 

638 self._step_counter = step_counter 

639 self._checkpoint_interval = checkpoint_interval 

640 

641 recovered_state = get_checkpoint_state(directory) 

642 current_clock = time.time() 

643 self._maybe_delete = collections.OrderedDict() 

644 if recovered_state is None: 

645 self._latest_checkpoint = None 

646 # Set the clock back slightly to avoid race conditions when quickly 

647 # re-creating a CheckpointManager. 

648 self._last_preserved_timestamp = current_clock - 1. 

649 else: 

650 self._latest_checkpoint = recovered_state.model_checkpoint_path 

651 self._last_preserved_timestamp = recovered_state.last_preserved_timestamp 

652 if current_clock < self._last_preserved_timestamp: 

653 # Time seems to have reversed itself. In addition to this warning, we'll 

654 # min() saved checkpoint timestamps with the current time to ensure that 

655 # old checkpoints don't get deleted accidentally. 

656 logging.warning( 

657 ("time.time() returned a value %f seconds behind the last " 

658 "preserved checkpoint timestamp.") 

659 % (self._last_preserved_timestamp - current_clock,)) 

660 self._last_preserved_timestamp = current_clock 

661 all_timestamps = recovered_state.all_model_checkpoint_timestamps 

662 all_paths = recovered_state.all_model_checkpoint_paths 

663 del recovered_state # Uses modified values from now on 

664 if not all_timestamps: 

665 all_timestamps = [self._last_preserved_timestamp] * len(all_paths) 

666 

667 for filename, timestamp in zip(all_paths, all_timestamps): 

668 timestamp = min(timestamp, current_clock) 

669 if timestamp > self._last_preserved_timestamp: 

670 self._maybe_delete[filename] = timestamp 

671 

672 @property 

673 def directory(self): 

674 return self._directory 

675 

676 @property 

677 def checkpoint_interval(self): 

678 return self._checkpoint_interval 

679 

680 @property 

681 def latest_checkpoint(self): 

682 """The prefix of the most recent checkpoint in `directory`. 

683 

684 Equivalent to `tf.train.latest_checkpoint(directory)` where `directory` is 

685 the constructor argument to `CheckpointManager`. 

686 

687 Suitable for passing to `tf.train.Checkpoint.restore` to resume training. 

688 

689 Returns: 

690 The checkpoint prefix. If there are no checkpoints, returns `None`. 

691 """ 

692 return self._latest_checkpoint 

693 

694 @property 

695 def checkpoints(self): 

696 """A list of managed checkpoints. 

697 

698 Note that checkpoints saved due to `keep_checkpoint_every_n_hours` will not 

699 show up in this list (to avoid ever-growing filename lists). 

700 

701 Returns: 

702 A list of filenames, sorted from oldest to newest. 

703 """ 

704 return list(self._maybe_delete.keys()) 

705 

706 def _sweep(self): 

707 """Deletes or preserves managed checkpoints.""" 

708 if not self._max_to_keep: 

709 # Does not update self._last_preserved_timestamp, since everything is kept 

710 # in the active set. 

711 return 

712 while len(self._maybe_delete) > self._max_to_keep: 

713 filename, timestamp = self._maybe_delete.popitem(last=False) 

714 # Even if we're keeping this checkpoint due to 

715 # keep_checkpoint_every_n_hours, we won't reference it to avoid 

716 # infinitely-growing CheckpointState protos. 

717 if (self._keep_checkpoint_every_n_hours 

718 and (timestamp - self._keep_checkpoint_every_n_hours * 3600. 

719 >= self._last_preserved_timestamp)): 

720 self._last_preserved_timestamp = timestamp 

721 continue 

722 _delete_file_if_exists(filename + ".index") 

723 _delete_file_if_exists(filename + ".data-?????-of-?????") 

724 

725 def _record_state(self): 

726 """Saves the `CheckpointManager`'s state in `directory`.""" 

727 filenames, timestamps = zip(*self._maybe_delete.items()) 

728 update_checkpoint_state_internal( 

729 self._directory, 

730 model_checkpoint_path=self.latest_checkpoint, 

731 all_model_checkpoint_paths=filenames, 

732 all_model_checkpoint_timestamps=timestamps, 

733 last_preserved_timestamp=self._last_preserved_timestamp, 

734 save_relative_paths=True) 

735 

736 @property 

737 def _prefix(self): 

738 """A common prefix for all checkpoints saved with this manager. 

739 

740 For example, if `directory` (a constructor argument) were `"/tmp/tf-model"`, 

741 `prefix` would be `"/tmp/tf-model/ckpt"` and checkpoints would generally be 

742 numbered `"/tmp/tf-model/ckpt-1"`, `"/tmp/tf-model/ckpt-2"`, and so on. Each 

743 checkpoint has several associated files 

744 (e.g. `"/tmp/tf-model/ckpt-2.index"`). 

745 

746 Returns: 

747 A string prefix. 

748 """ 

749 return self._checkpoint_prefix 

750 

751 @property 

752 def checkpoint(self): 

753 """Returns the `tf.train.Checkpoint` object.""" 

754 return self._checkpoint 

755 

756 def save(self, checkpoint_number=None, check_interval=True, options=None): 

757 """Creates a new checkpoint and manages it. 

758 

759 Args: 

760 checkpoint_number: An optional integer, or an integer-dtype `Variable` or 

761 `Tensor`, used to number the checkpoint. If `None` (default), 

762 checkpoints are numbered using `checkpoint.save_counter`. Even if 

763 `checkpoint_number` is provided, `save_counter` is still incremented. A 

764 user-provided `checkpoint_number` is not incremented even if it is a 

765 `Variable`. 

766 check_interval: An optional boolean. The argument is only effective when 

767 `checkpoint_interval` is passed into the manager. If `True`, the manager 

768 will only save the checkpoint if the interval between checkpoints is 

769 larger than `checkpoint_interval`. Otherwise it will always save the 

770 checkpoint unless a checkpoint has already been saved for the current 

771 step. 

772 options: Optional `tf.train.CheckpointOptions` object. This argument only 

773 works with TF2 checkpoint objects. For example, options = 

773 tf.train.CheckpointOptions(experimental_io_device='/job:localhost') 

775 

776 Returns: 

777 The path to the new checkpoint. It is also recorded in the `checkpoints` 

778 and `latest_checkpoint` properties. `None` if no checkpoint is saved. 

779 """ 

780 if self._checkpoint_interval is not None: 

781 current_step = _evaluate(self._step_counter) 

782 if self._last_checkpoint_step is not None: 

783 if current_step == self._last_checkpoint_step: 

784 return None 

785 if check_interval and current_step < ( 

786 self._last_checkpoint_step + self._checkpoint_interval): 

787 return None 

788 self._last_checkpoint_step = current_step 

789 

790 # Save counter logic duplicated from tf.train.Checkpoint, soon to diverge 

791 # slightly with a custom numbering option. 

792 if context.executing_eagerly(): 

793 save_counter = self._checkpoint.save_counter 

794 save_counter.assign_add(1) 

795 session = None 

796 else: 

797 session = ops.get_default_session() 

798 

799 def _initializing_creator(next_creator, **kwargs): 

800 """Initialize the save counter if it has been newly created.""" 

801 v = next_creator(**kwargs) 

802 session.run(v.initializer) 

803 return v 

804 

805 with variable_scope.variable_creator_scope(_initializing_creator): 

806 save_counter = self._checkpoint.save_counter 

807 if self._save_counter_assign is None: 

808 self._save_counter_assign = save_counter.assign_add(1, read_value=False) 

809 session.run(self._save_counter_assign) 

810 if checkpoint_number is None: 

811 checkpoint_number = save_counter 

812 if not isinstance(checkpoint_number, compat.integral_types): 

813 checkpoint_number = training_util.global_step( 

814 sess=session, global_step_tensor=checkpoint_number) 

815 prefix = "%s-%d" % (self._prefix, checkpoint_number) 

816 

817 def _record_and_sweep_state(save_path): 

818 timestamp = time.time() 

819 # If this is an overwritten checkpoint we were previously tracking, delete 

820 # and reinsert it to make sure it goes to the end of the queue. 

821 if save_path in self._maybe_delete: 

822 del self._maybe_delete[save_path] 

823 self._maybe_delete[save_path] = timestamp 

824 self._latest_checkpoint = save_path 

825 # Before deleting anything we update the Checkpoint proto with the new 

826 # checkpoint. We'll go back and correct it after cleaning up old files, 

827 # but a preemption while deleting will be more likely to see the new 

828 # checkpoint this way. 

829 self._record_state() 

830 self._sweep() 

831 # Write out the Checkpoint proto a second time, now without the deleted 

832 # checkpoints. 

833 self._record_state() 

834 

835 if options is None: 

836 save_path = self._checkpoint._write( # pylint: disable=protected-access 

837 prefix, write_done_callback=_record_and_sweep_state) 

838 else: 

839 save_path = self._checkpoint._write( # pylint: disable=protected-access 

840 prefix, options=options, write_done_callback=_record_and_sweep_state) 

841 

842 return save_path 

843 
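# --- Editor's example (illustrative sketch; not part of the original file) ---
# Using step_counter and checkpoint_interval so that manager.save() only
# writes when at least `checkpoint_interval` steps have passed since the last
# saved step; the scratch directory and step counts are made up.
import tempfile
import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()
step = tf.Variable(0, dtype=tf.int64)
ckpt = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(
    ckpt, directory=ckpt_dir, max_to_keep=3,
    step_counter=step, checkpoint_interval=100)

for _ in range(250):
  step.assign_add(1)
  manager.save(check_interval=True)  # writes near steps 1, 101 and 201 only
print(manager.checkpoints)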

844 def restore_or_initialize(self): 

845 """Restore items in `checkpoint` from the latest checkpoint file. 

846 

847 This method will first try to restore from the most recent checkpoint in 

848 `directory`. If no checkpoints exist in `directory`, and `init_fn` is 

849 specified, this method will call `init_fn` to do customized 

850 initialization. This can be used to support initialization from pretrained 

851 models. 

852 

853 Note that unlike `tf.train.Checkpoint.restore()`, this method doesn't return 

854 a load status object that users can run assertions on 

855 (e.g. assert_consumed()). Thus to run assertions, users should directly use 

856 the `tf.train.Checkpoint.restore()` method. 

857 

858 Returns: 

859 The restored checkpoint path if the latest checkpoint is found and 

860 restored. Otherwise None. 

861 """ 

862 # TODO(chienchunh): When AsyncCheckpoint is used, we may need to force to 

863 # sync until any ongoing async save is done. Otherwise, if this is the first 

864 # checkpoint and _latest_checkpoint has not been updated due to async write, 

865 this would resort to init_fn instead of restoring from the checkpoint file. 

866 # This should be fixed once AsyncCheckpoint is integrated with the public 

867 # API so that we can rely on CheckpointOptions to tell whether we should 

868 # sync for AsyncCheckpoint. 

869 if self._latest_checkpoint is not None: 

870 self._checkpoint.restore(self._latest_checkpoint) 

871 if self._checkpoint_interval is not None: 

872 self._last_checkpoint_step = _evaluate(self._step_counter) 

873 return self._latest_checkpoint 

874 

875 if self._init_fn is not None: 

876 self._init_fn() 

877 logging.info( 

878 "Customized initialization is done through the passed `init_fn`.") 

879 return None 

880 

881 def sync(self): 

882 """Wait for any outstanding save or restore operations.""" 

883 if self._checkpoint: 

884 self._checkpoint.sync()