Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/training/saver.py: 20%

511 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16# pylint: disable=invalid-name 

17"""Save and restore variables. 

18 

19Symbols in this file are deprecated. See replacements in 

20tensorflow/python/training/trackable and tensorflow/python/training/saving. 

21""" 

22import collections 

23import glob 

24import os.path 

25import threading 

26import time 

27 

28import numpy as np 

29from tensorflow.core.protobuf import meta_graph_pb2 

30from tensorflow.core.protobuf import saver_pb2 

31from tensorflow.core.protobuf import trackable_object_graph_pb2 

32from tensorflow.python.checkpoint import checkpoint_management 

33from tensorflow.python.client import session 

34from tensorflow.python.eager import context 

35from tensorflow.python.framework import constant_op 

36from tensorflow.python.framework import device as pydev 

37from tensorflow.python.framework import errors 

38from tensorflow.python.framework import meta_graph 

39from tensorflow.python.framework import ops 

40from tensorflow.python.ops import array_ops 

41from tensorflow.python.ops import control_flow_ops 

42from tensorflow.python.ops import gen_io_ops 

43from tensorflow.python.ops import io_ops 

44from tensorflow.python.ops import string_ops 

45from tensorflow.python.ops import variables 

46from tensorflow.python.platform import gfile 

47from tensorflow.python.platform import tf_logging as logging 

48from tensorflow.python.saved_model.pywrap_saved_model import metrics 

49from tensorflow.python.trackable import base as trackable 

50from tensorflow.python.training import py_checkpoint_reader 

51from tensorflow.python.training import training_util 

52from tensorflow.python.training.saving import saveable_object 

53from tensorflow.python.training.saving import saveable_object_util 

54from tensorflow.python.util import compat 

55from tensorflow.python.util.tf_export import tf_export 

56 

# TODO(allenl): Remove these aliases once all users are migrated off.
# Backwards-compatibility re-exports: these symbols historically lived in this
# module and are now implemented in checkpoint_management.
get_checkpoint_state = checkpoint_management.get_checkpoint_state
update_checkpoint_state = checkpoint_management.update_checkpoint_state
generate_checkpoint_state_proto = (
    checkpoint_management.generate_checkpoint_state_proto)
latest_checkpoint = checkpoint_management.latest_checkpoint
checkpoint_exists = checkpoint_management.checkpoint_exists
get_checkpoint_mtimes = checkpoint_management.get_checkpoint_mtimes
remove_checkpoint = checkpoint_management.remove_checkpoint

66 

# Captures the timestamp of the first Saver object instantiation or end of a
# save operation. Can be accessed by multiple Saver instances.
_END_TIME_OF_LAST_WRITE = None
# Guards reads/writes of _END_TIME_OF_LAST_WRITE across Saver instances.
_END_TIME_OF_LAST_WRITE_LOCK = threading.Lock()

# API label for cell name used in checkpoint metrics.
_SAVER_LABEL = "saver_v1"

74 

75 

76def _get_duration_microseconds(start_time_seconds, end_time_seconds): 

77 if end_time_seconds < start_time_seconds: 

78 # Avoid returning negative value in case of clock skew. 

79 return 0 

80 return round((end_time_seconds - start_time_seconds) * 1000000) 

81 

82 

83def _get_checkpoint_size(prefix): 

84 """Calculates filesize of checkpoint based on prefix.""" 

85 size = 0 

86 # Gather all files beginning with prefix (.index plus sharded data files). 

87 files = glob.glob("{}*".format(prefix)) 

88 for file in files: 

89 # Use TensorFlow's C++ FileSystem API. 

90 size += metrics.CalculateFileSize(file) 

91 return size 

92 

93 

class BaseSaverBuilder:
  """Base class for Savers.

  Can be extended to create different Ops.
  """

  SaveSpec = saveable_object.SaveSpec
  SaveableObject = saveable_object.SaveableObject

  # Aliases for code which was moved but still has lots of users.
  VariableSaveable = saveable_object_util.ReferenceVariableSaveable
  ResourceVariableSaveable = saveable_object_util.ResourceVariableSaveable

  def __init__(self, write_version=saver_pb2.SaverDef.V2):
    """Creates a builder.

    Args:
      write_version: A `saver_pb2.SaverDef` checkpoint format version
        (V1 or V2) used by `save_op`.
    """
    self._write_version = write_version

  def save_op(self, filename_tensor, saveables):
    """Create an Op to save 'saveables'.

    This is intended to be overridden by subclasses that want to generate
    different Ops.

    Args:
      filename_tensor: String Tensor.
      saveables: A list of BaseSaverBuilder.SaveableObject objects.

    Returns:
      An Operation that save the variables.

    Raises:
      RuntimeError: (implementation detail) if "self._write_version" is an
        unexpected value.
    """
    # pylint: disable=protected-access
    # Flatten every saveable's specs into parallel lists of names, tensors,
    # and slice specs, as expected by the save ops.
    tensor_names = []
    tensors = []
    tensor_slices = []
    for saveable in saveables:
      for spec in saveable.specs:
        tensor_names.append(spec.name)
        tensors.append(spec.tensor)
        tensor_slices.append(spec.slice_spec)
    if self._write_version == saver_pb2.SaverDef.V1:
      return io_ops._save(
          filename=filename_tensor,
          tensor_names=tensor_names,
          tensors=tensors,
          tensor_slices=tensor_slices)
    elif self._write_version == saver_pb2.SaverDef.V2:
      # "filename_tensor" is interpreted *NOT AS A FILENAME*, but as a prefix
      # of a V2 checkpoint: e.g. "/fs/train/ckpt-<step>/tmp/worker<i>-<step>".
      return io_ops.save_v2(filename_tensor, tensor_names, tensor_slices,
                            tensors)
    else:
      # NOTE(review): self._write_version is an enum int here, so "+" with a
      # str would itself raise TypeError if this branch is ever reached —
      # confirm whether str() should be applied.
      raise RuntimeError("Unexpected write_version: " + self._write_version)

  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                   restore_sequentially):
    """Restore all tensors contained in saveables.

    By default, this issues separate calls to `restore_op` for each saveable.
    Subclasses may override to load multiple saveables in a single call.

    Args:
      filename_tensor: String Tensor.
      saveables: List of BaseSaverBuilder.SaveableObject objects.
      preferred_shard: Int. Shard to open first when loading a sharded file.
      restore_sequentially: Unused. Bool. If true, each restore is sequential.

    Returns:
      A list of Tensors resulting from reading 'saveable' from
        'filename'.
    """
    del restore_sequentially
    all_tensors = []
    for saveable in saveables:
      # Pin each restore to the CPU of the saveable's device, if it has one;
      # otherwise leave device placement to the surrounding context.
      if saveable.device:
        device = saveable_object_util.set_cpu0(saveable.device)
      else:
        device = None
      with ops.device(device):
        all_tensors.extend(
            self.restore_op(filename_tensor, saveable, preferred_shard))
    return all_tensors

  # pylint: disable=unused-argument
  def restore_op(self, filename_tensor, saveable, preferred_shard):
    """Create ops to restore 'saveable'.

    This is intended to be overridden by subclasses that want to generate
    different Ops.

    Args:
      filename_tensor: String Tensor.
      saveable: A BaseSaverBuilder.SaveableObject object.
      preferred_shard: Int. Shard to open first when loading a sharded file.

    Returns:
      A list of Tensors resulting from reading 'saveable' from
        'filename'.
    """
    # pylint: disable=protected-access
    # One RestoreV2 op per spec; [0] unwraps the single-element result list.
    tensors = []
    for spec in saveable.specs:
      tensors.append(
          io_ops.restore_v2(filename_tensor, [spec.name], [spec.slice_spec],
                            [spec.dtype])[0])

    return tensors

  # pylint: enable=unused-argument

  def sharded_filename(self, filename_tensor, shard, num_shards):
    """Append sharding information to a filename.

    Args:
      filename_tensor: A string tensor.
      shard: Integer.  The shard for the filename.
      num_shards: An int Tensor for the number of shards.

    Returns:
      A string tensor.
    """
    return gen_io_ops.sharded_filename(filename_tensor, shard, num_shards)

  def _AddSaveOps(self, filename_tensor, saveables):
    """Add ops to save variables that are on the same shard.

    Args:
      filename_tensor: String Tensor.
      saveables: A list of SaveableObject objects.

    Returns:
      A tensor with the filename used to save.
    """
    save = self.save_op(filename_tensor, saveables)
    # Returns filename_tensor, gated on the save op having run.
    return control_flow_ops.with_dependencies([save], filename_tensor)

  def _AddShardedSaveOpsForV2(self, checkpoint_prefix, per_device):
    """Add ops to save the params per shard, for the V2 format.

    Note that the sharded save procedure for the V2 format is different from
    V1: there is a special "merge" step that merges the small metadata produced
    from each device.

    Args:
      checkpoint_prefix: scalar String Tensor.  Interpreted *NOT AS A
        FILENAME*, but as a prefix of a V2 checkpoint;
      per_device: A list of (device, BaseSaverBuilder.VarToSave) pairs, as
        returned by _GroupByDevices().

    Returns:
      An op to save the variables, which, when evaluated, returns the prefix
        "<user-fed prefix>" only and does not include the sharded spec suffix.
    """
    # IMPLEMENTATION DETAILS: most clients should skip.
    #
    # Suffix for any well-formed "checkpoint_prefix", when sharded.
    # Transformations:
    # * Users pass in "save_path" in save() and restore().  Say "myckpt".
    # * checkpoint_prefix gets fed <save_path><_SHARDED_SUFFIX>.
    # * If checkpoint_prefix is a S3 bucket path ".part" is appended to it
    # * Otherwise _temp/part is appended which is normalized relative to the OS
    # Example:
    #   During runtime, a temporary directory is first created, which contains
    #   files
    #
    #     <train dir>/myckpt_temp/
    #        part-?????-of-?????{.index, .data-00000-of-00001}
    #
    #   Before .save() finishes, they will be (hopefully, atomically) renamed
    #   to
    #
    #     <train dir>/
    #        myckpt{.index, .data-?????-of-?????}
    #
    #   Filesystems with eventual consistency (such as S3), don't need a
    #   temporary location. Using a temporary directory in those cases might
    #   cause situations where files are not available during copy.
    #
    # Users only need to interact with the user-specified prefix, which is
    # "<train dir>/myckpt" in this case.  Save() and Restore() work with the
    # prefix directly, instead of any physical pathname.  (On failure and
    # subsequent restore, an outdated and orphaned temporary directory can be
    # safely removed.)
    with ops.device("CPU"):
      _SHARDED_SUFFIX = array_ops.where(
          string_ops.regex_full_match(checkpoint_prefix, "^s3://.*"),
          constant_op.constant(".part"),
          constant_op.constant(os.path.normpath("_temp/part")))
      tmp_checkpoint_prefix = string_ops.string_join(
          [checkpoint_prefix, _SHARDED_SUFFIX])

    num_shards = len(per_device)
    sharded_saves = []
    sharded_prefixes = []
    num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
    last_device = None
    for shard, (device, saveables) in enumerate(per_device):
      last_device = device
      with ops.device(saveable_object_util.set_cpu0(device)):
        sharded_filename = self.sharded_filename(tmp_checkpoint_prefix, shard,
                                                 num_shards_tensor)
        sharded_prefixes.append(sharded_filename)
        sharded_saves.append(self._AddSaveOps(sharded_filename, saveables))

    with ops.control_dependencies([x.op for x in sharded_saves]):
      # Co-locates the merge step with the last device.
      with ops.device(saveable_object_util.set_cpu0(last_device)):
        # V2 format write path consists of a metadata merge step.  Once merged,
        # attempts to delete the temporary directory, "<user-fed prefix>_temp".
        merge_step = gen_io_ops.merge_v2_checkpoints(
            sharded_prefixes, checkpoint_prefix, delete_old_dirs=True)
        with ops.control_dependencies([merge_step]):
          # Returns the prefix "<user-fed prefix>" only.  DOES NOT include the
          # sharded spec suffix.
          return array_ops.identity(checkpoint_prefix)

  def _AddShardedSaveOps(self, filename_tensor, per_device):
    """Add ops to save the params per shard.

    Args:
      filename_tensor: a scalar String Tensor.
      per_device: A list of (device, BaseSaverBuilder.SaveableObject) pairs, as
        returned by _GroupByDevices().

    Returns:
      An op to save the variables.
    """
    # The V2 format has its own sharded-save path with a metadata merge step.
    if self._write_version == saver_pb2.SaverDef.V2:
      return self._AddShardedSaveOpsForV2(filename_tensor, per_device)

    num_shards = len(per_device)
    sharded_saves = []
    num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
    for shard, (device, saveables) in enumerate(per_device):
      with ops.device(device):
        sharded_filename = self.sharded_filename(filename_tensor, shard,
                                                 num_shards_tensor)
        sharded_saves.append(self._AddSaveOps(sharded_filename, saveables))
    # Return the sharded name for the save path.
    with ops.control_dependencies([x.op for x in sharded_saves]):
      return gen_io_ops.sharded_filespec(filename_tensor, num_shards_tensor)

  def _AddRestoreOps(self,
                     filename_tensor,
                     saveables,
                     restore_sequentially,
                     reshape,
                     preferred_shard=-1,
                     name="restore_all"):
    """Add operations to restore saveables.

    Args:
      filename_tensor: Tensor for the path of the file to load.
      saveables: A list of SaveableObject objects.
      restore_sequentially: True if we want to restore variables sequentially
        within a shard.
      reshape: True if we want to reshape loaded tensors to the shape of the
        corresponding variable.
      preferred_shard: Shard to open first when loading a sharded file.
      name: Name for the returned op.

    Returns:
      An Operation that restores the variables.
    """
    all_tensors = self.bulk_restore(filename_tensor, saveables, preferred_shard,
                                    restore_sequentially)

    assign_ops = []
    idx = 0
    # Load and optionally reshape on the CPU, as string tensors are not
    # available on the GPU.
    # TODO(touts): Re-enable restore on GPU when we can support annotating
    # string tensors as "HostMemory" inputs.
    for saveable in saveables:
      shapes = None
      if reshape:
        # Compute the shapes, let the restore op decide if and how to do
        # the reshape.
        shapes = []
        for spec in saveable.specs:
          v = spec.tensor
          shape = v.get_shape()
          if not shape.is_fully_defined():
            shape = array_ops.shape(v)
          shapes.append(shape)
      # all_tensors is flat across saveables; slice out this saveable's
      # restored tensors in order.
      saveable_tensors = all_tensors[idx:idx + len(saveable.specs)]
      idx += len(saveable.specs)
      assign_ops.append(saveable.restore(saveable_tensors, shapes))

    # Create a Noop that has control dependencies from all the updates.
    return control_flow_ops.group(*assign_ops, name=name)

  def _AddShardedRestoreOps(self, filename_tensor, per_device,
                            restore_sequentially, reshape):
    """Add Ops to restore variables from multiple devices.

    Args:
      filename_tensor: Tensor for the path of the file to load.
      per_device: A list of (device, SaveableObject) pairs, as returned by
        _GroupByDevices().
      restore_sequentially: True if we want to restore variables sequentially
        within a shard.
      reshape: True if we want to reshape loaded tensors to the shape of the
        corresponding variable.

    Returns:
      An Operation that restores the variables.
    """
    sharded_restores = []
    for shard, (device, saveables) in enumerate(per_device):
      with ops.device(device):
        sharded_restores.append(
            self._AddRestoreOps(
                filename_tensor,
                saveables,
                restore_sequentially,
                reshape,
                preferred_shard=shard,
                name="restore_shard"))
    return control_flow_ops.group(*sharded_restores, name="restore_all")

  def _GroupByDevices(self, saveables):
    """Group Variable tensor slices per device.

    TODO(touts): Make sure that all the devices found are on different
    job/replica/task/cpu|gpu.  It would be bad if 2 were on the same device.
    It can happen if the devices are unspecified.

    Args:
      saveables: A list of BaseSaverBuilder.SaveableObject objects.

    Returns:
      A list of tuples: (device_name, BaseSaverBuilder.SaveableObject) tuples.
      The list is sorted by ascending device_name.

    Raises:
      ValueError: If the tensors of a saveable are on different devices.
    """
    per_device = collections.defaultdict(lambda: [])
    for saveable in saveables:
      canonical_device = set(
          pydev.canonical_name(spec.device) for spec in saveable.specs)
      if len(canonical_device) != 1:
        raise ValueError("All tensors of a saveable object must be "
                         "on the same device: %s" % saveable.name)
      per_device[canonical_device.pop()].append(saveable)
    return sorted(per_device.items(), key=lambda t: t[0])

  def build(self,
            names_to_saveables,
            reshape=False,
            sharded=False,
            max_to_keep=5,
            keep_checkpoint_every_n_hours=10000.0,
            name=None,
            restore_sequentially=False,
            filename="model"):
    """Builds save/restore graph nodes or runs save/restore in eager mode.

    Args:
      names_to_saveables: A dictionary mapping name to a Variable or
        SaveableObject. Each name will be associated with the corresponding
        variable in the checkpoint.
      reshape: If True, allow restoring parameters from a checkpoint that where
        the parameters have a different shape.  This is only needed when you
        try to restore from a Dist-Belief checkpoint, and only some times.
      sharded: If True, shard the checkpoints, one per device that has Variable
        nodes.
      max_to_keep: Maximum number of checkpoints to keep.  As new checkpoints
        are created, old ones are deleted.  If None or 0, no checkpoints are
        deleted from the filesystem but only the last one is kept in the
        `checkpoint` file.  Presently the number is only roughly enforced.  For
        example in case of restarts more than max_to_keep checkpoints may be
        kept.
      keep_checkpoint_every_n_hours: How often checkpoints should be kept.
        Defaults to 10,000 hours.
      name: String.  Optional name to use as a prefix when adding operations.
      restore_sequentially: A Bool, which if true, causes restore of different
        variables to happen sequentially within each device.
      filename: If known at graph construction time, filename used for variable
        loading/saving. If None, then the default name "model" will be used.

    Returns:
      A SaverDef proto.

    Raises:
      TypeError: If 'names_to_saveables' is not a dictionary mapping string
        keys to variable Tensors.
      ValueError: If any of the keys or values in 'names_to_saveables' is not
        unique.
    """
    return self._build_internal(
        names_to_saveables=names_to_saveables,
        reshape=reshape,
        sharded=sharded,
        max_to_keep=max_to_keep,
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
        name=name,
        restore_sequentially=restore_sequentially,
        filename=filename)

  def _build_internal(self,
                      names_to_saveables,
                      reshape=False,
                      sharded=False,
                      max_to_keep=5,
                      keep_checkpoint_every_n_hours=10000.0,
                      name=None,
                      restore_sequentially=False,
                      filename="model",
                      build_save=True,
                      build_restore=True):
    """build() with option to only perform save and restore."""
    # Graph mode needs both halves built together so the returned SaverDef
    # references valid node names.
    if not context.executing_eagerly() and (not build_save or
                                            not build_restore):
      raise ValueError("save and restore operations need to be built together "
                       " when eager execution is not enabled.")

    if not isinstance(names_to_saveables, dict):
      names_to_saveables = saveable_object_util.op_list_to_dict(
          names_to_saveables)
    saveables = saveable_object_util.validate_and_slice_inputs(
        names_to_saveables)
    # max_to_keep=None means "keep everything"; the proto encodes that as 0.
    if max_to_keep is None:
      max_to_keep = 0

    with ops.name_scope(name, "save",
                        [saveable.op for saveable in saveables]) as name:
      # Add a placeholder string tensor for the filename.
      filename_tensor = array_ops.placeholder_with_default(
          filename or "model", shape=(), name="filename")
      # Keep the name "Const" for backwards compatibility.
      filename_tensor = array_ops.placeholder_with_default(
          filename_tensor, shape=(), name="Const")

      # Add the save ops.
      if sharded:
        per_device = self._GroupByDevices(saveables)
        if build_save:
          save_tensor = self._AddShardedSaveOps(filename_tensor, per_device)
        if build_restore:
          restore_op = self._AddShardedRestoreOps(filename_tensor, per_device,
                                                  restore_sequentially, reshape)
      else:
        if build_save:
          save_tensor = self._AddSaveOps(filename_tensor, saveables)
        if build_restore:
          restore_op = self._AddRestoreOps(filename_tensor, saveables,
                                           restore_sequentially, reshape)

      # In the following use case, it's possible to have restore_ops be called
      # something else:
      # - Build inference graph and export a meta_graph.
      # - Import the inference meta_graph
      # - Extend the inference graph to a train graph.
      # - Export a new meta_graph.
      # Now the second restore_op will be called "restore_all_1".
      # As such, comment out the assert for now until we know whether supporting
      # such usage model makes sense.
      #
      # assert restore_op.name.endswith("restore_all"), restore_op.name
      if context.executing_eagerly():
        # Store the tensor values to the tensor_names.
        save_tensor_name = save_tensor.numpy() if build_save else ""
        return saver_pb2.SaverDef(
            filename_tensor_name=filename_tensor.numpy(),
            save_tensor_name=save_tensor_name,
            restore_op_name="",
            max_to_keep=max_to_keep,
            sharded=sharded,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
            version=self._write_version)
      else:
        graph = ops.get_default_graph()
        # Do some sanity checking on collections containing
        # PartitionedVariables. If a saved collection has a PartitionedVariable,
        # the GraphDef needs to include concat ops to get the value (or there'll
        # be a lookup error on load).
        check_collection_list = graph.get_all_collection_keys()
        for collection_type in check_collection_list:
          for element in graph.get_collection(collection_type):
            if isinstance(element, variables.PartitionedVariable):
              try:
                graph.get_operation_by_name(element.name)
              except KeyError:
                # Create a concat op for this PartitionedVariable. The user may
                # not need it, but we'll try looking it up on MetaGraph restore
                # since it's in a collection.
                element.as_tensor()
        return saver_pb2.SaverDef(
            filename_tensor_name=filename_tensor.name,
            save_tensor_name=save_tensor.name,
            restore_op_name=restore_op.name,
            max_to_keep=max_to_keep,
            sharded=sharded,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
            version=self._write_version)

593 

594 

class BulkSaverBuilder(BaseSaverBuilder):
  """SaverBuilder with support for bulk restoring multiple saveables."""

  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                   restore_sequentially):
    """Restores every spec of every saveable with one RestoreV2 call."""
    # Ignored: bulk restore is internally sequential.
    del restore_sequentially
    # Flatten all (name, slice_spec, dtype) triples across the saveables.
    restore_specs = [(spec.name, spec.slice_spec, spec.dtype)
                     for saveable in saveables
                     for spec in saveable.specs]

    names, slices, dtypes = zip(*restore_specs)
    # Load all tensors onto CPU 0 for compatibility with existing code.
    with ops.device("cpu:0"):
      return io_ops.restore_v2(filename_tensor, names, slices, dtypes)

612 

613 

def _get_saver_or_default():
  """Returns the saver from SAVERS collection, or creates a default one.

  This method is used by other members of the training module, such as
  `Scaffold`, or `CheckpointSaverHook`.

  Returns:
    `Saver`.

  Raises:
    RuntimeError: If the SAVERS collection already has more than one items.
  """
  collection_key = ops.GraphKeys.SAVERS
  savers = ops.get_collection(collection_key)
  if savers:
    if len(savers) > 1:
      raise RuntimeError(
          "More than one item in collection {}. "
          "Please indicate which one to use by passing it to the constructor."
          .format(collection_key))
    return savers[0]
  # No saver registered yet: create a default sharded saver and add it to the
  # collection so subsequent callers reuse the same instance.  (The previous
  # `if saver is not None:` guard was vacuous -- a constructor call can never
  # return None -- so the saver is registered unconditionally.)
  saver = Saver(sharded=True, allow_empty=True)
  ops.add_to_collection(collection_key, saver)
  return saver

639 

640 

641@tf_export(v1=["train.Saver"]) 

642class Saver: 

643 # pylint: disable=line-too-long 

644 """Saves and restores variables. 

645 

646 @compatibility(TF2) 

647 `tf.compat.v1.train.Saver` is not supported for saving and restoring 

648 checkpoints in TF2. Please switch to `tf.train.Checkpoint` or 

649 `tf.keras.Model.save_weights`, which perform a more robust [object-based 

650 saving](https://www.tensorflow.org/guide/checkpoint#loading_mechanics). 

651 

652 ### How to Rewrite Checkpoints 

653 

654 Please rewrite your checkpoints immediately using the object-based checkpoint 

655 APIs. 

656 

657 You can load a name-based checkpoint written by `tf.compat.v1.train.Saver` 

658 using `tf.train.Checkpoint.restore` or `tf.keras.Model.load_weights`. However, 

659 you may have to change the names of the variables in your model to match the 

660 variable names in the name-based checkpoint, which can be viewed with 

661 `tf.train.list_variables(path)`. 

662 

663 Another option is to create an `assignment_map` that maps the name of the 

664 variables in the name-based checkpoint to the variables in your model, eg: 

665 ``` 

666 { 

667 'sequential/dense/bias': model.variables[0], 

668 'sequential/dense/kernel': model.variables[1] 

669 } 

670 ``` 

671 and use `tf.compat.v1.train.init_from_checkpoint(path, assignment_map)` to 

672 restore the name-based checkpoint. 

673 

674 After restoring, re-encode your checkpoint 

675 using `tf.train.Checkpoint.save` or `tf.keras.Model.save_weights`. 

676 

677 See the [Checkpoint compatibility]( 

678 https://www.tensorflow.org/guide/migrate#checkpoint_compatibility) 

679 section of the migration guide for more details. 

680 

681 

682 ### Checkpoint Management in TF2 

683 

684 Use `tf.train.CheckpointManager` to manage checkpoints in TF2. 

685 `tf.train.CheckpointManager` offers equivalent `keep_checkpoint_every_n_hours` 

686 and `max_to_keep` parameters. 

687 

688 To recover the latest checkpoint, 

689 

690 ``` 

691 checkpoint = tf.train.Checkpoint(model) 

692 manager = tf.train.CheckpointManager(checkpoint) 

693 status = checkpoint.restore(manager.latest_checkpoint) 

694 ``` 

695 

696 `tf.train.CheckpointManager` also writes a [`CheckpointState` proto] 

697 (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/checkpoint_state.proto) 

698 which contains the timestamp when each checkpoint was created. 

699 

700 ### Writing `MetaGraphDef`s in TF2 

701 

702 To replace, `tf.compat.v1.train.Saver.save(write_meta_graph=True)`, use 

703 `tf.saved_model.save` to write the `MetaGraphDef` (which is contained in 

704 `saved_model.pb`). 

705 

706 @end_compatibility 

707 

708 See [Variables](https://tensorflow.org/guide/variables) 

709 for an overview of variables, saving and restoring. 

710 

711 The `Saver` class adds ops to save and restore variables to and from 

712 *checkpoints*. It also provides convenience methods to run these ops. 

713 

714 Checkpoints are binary files in a proprietary format which map variable names 

715 to tensor values. The best way to examine the contents of a checkpoint is to 

716 load it using a `Saver`. 

717 

718 Savers can automatically number checkpoint filenames with a provided counter. 

719 This lets you keep multiple checkpoints at different steps while training a 

720 model. For example you can number the checkpoint filenames with the training 

721 step number. To avoid filling up disks, savers manage checkpoint files 

722 automatically. For example, they can keep only the N most recent files, or 

723 one checkpoint for every N hours of training. 

724 

725 You number checkpoint filenames by passing a value to the optional 

726 `global_step` argument to `save()`: 

727 

728 ```python 

729 saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0' 

730 ... 

731 saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000' 

732 ``` 

733 

734 Additionally, optional arguments to the `Saver()` constructor let you control 

735 the proliferation of checkpoint files on disk: 

736 

737 * `max_to_keep` indicates the maximum number of recent checkpoint files to 

738 keep. As new files are created, older files are deleted. If None or 0, 

739 no checkpoints are deleted from the filesystem but only the last one is 

740 kept in the `checkpoint` file. Defaults to 5 (that is, the 5 most recent 

741 checkpoint files are kept.) 

742 

743 * `keep_checkpoint_every_n_hours`: In addition to keeping the most recent 

744 `max_to_keep` checkpoint files, you might want to keep one checkpoint file 

745 for every N hours of training. This can be useful if you want to later 

746 analyze how a model progressed during a long training session. For 

747 example, passing `keep_checkpoint_every_n_hours=2` ensures that you keep 

748 one checkpoint file for every 2 hours of training. The default value of 

749 10,000 hours effectively disables the feature. 

750 

751 Note that you still have to call the `save()` method to save the model. 

752 Passing these arguments to the constructor will not save variables 

753 automatically for you. 

754 

755 A training program that saves regularly looks like: 

756 

757 ```python 

758 ... 

759 # Create a saver. 

760 saver = tf.compat.v1.train.Saver(...variables...) 

761 # Launch the graph and train, saving the model every 1,000 steps. 

762 sess = tf.compat.v1.Session() 

763 for step in range(1000000): 

764 sess.run(..training_op..) 

765 if step % 1000 == 0: 

766 # Append the step number to the checkpoint name: 

767 saver.save(sess, 'my-model', global_step=step) 

768 ``` 

769 

770 In addition to checkpoint files, savers keep a protocol buffer on disk with 

771 the list of recent checkpoints. This is used to manage numbered checkpoint 

772 files and by `latest_checkpoint()`, which makes it easy to discover the path 

773 to the most recent checkpoint. That protocol buffer is stored in a file named 

774 'checkpoint' next to the checkpoint files. 

775 

776 If you create several savers, you can specify a different filename for the 

777 protocol buffer file in the call to `save()`. 

778 """ 

779 

780 # pylint: enable=line-too-long 

781 

782 def __init__(self, 

783 var_list=None, 

784 reshape=False, 

785 sharded=False, 

786 max_to_keep=5, 

787 keep_checkpoint_every_n_hours=10000.0, 

788 name=None, 

789 restore_sequentially=False, 

790 saver_def=None, 

791 builder=None, 

792 defer_build=False, 

793 allow_empty=False, 

794 write_version=saver_pb2.SaverDef.V2, 

795 pad_step_number=False, 

796 save_relative_paths=False, 

797 filename=None): 

798 """Creates a `Saver`. 

799 

800 The constructor adds ops to save and restore variables. 

801 

802 `var_list` specifies the variables that will be saved and restored. It can 

803 be passed as a `dict` or a list: 

804 

805 * A `dict` of names to variables: The keys are the names that will be 

806 used to save or restore the variables in the checkpoint files. 

807 * A list of variables: The variables will be keyed with their op name in 

808 the checkpoint files. 

809 

810 For example: 

811 

812 ```python 

813 v1 = tf.Variable(..., name='v1') 

814 v2 = tf.Variable(..., name='v2') 

815 

816 # Pass the variables as a dict: 

817 saver = tf.compat.v1.train.Saver({'v1': v1, 'v2': v2}) 

818 

819 # Or pass them as a list. 

820 saver = tf.compat.v1.train.Saver([v1, v2]) 

821 # Passing a list is equivalent to passing a dict with the variable op names 

822 # as keys: 

823 saver = tf.compat.v1.train.Saver({v.op.name: v for v in [v1, v2]}) 

824 ``` 

825 

826 Note: the newer `AutoTrackable` API is not supported by `Saver`. In this 

827 case, the `tf.train.Checkpoint` class should be used. 

828 

829 The optional `reshape` argument, if `True`, allows restoring a variable from 

830 a save file where the variable had a different shape, but the same number 

831 of elements and type. This is useful if you have reshaped a variable and 

832 want to reload it from an older checkpoint. 

833 

834 The optional `sharded` argument, if `True`, instructs the saver to shard 

835 checkpoints per device. 

836 

837 Args: 

838 var_list: A list of `Variable`/`SaveableObject`, or a dictionary mapping 

839 names to `SaveableObject`s. If `None`, defaults to the list of all 

840 saveable objects. 

841 reshape: If `True`, allows restoring parameters from a checkpoint where 

842 the variables have a different shape. 

843 sharded: If `True`, shard the checkpoints, one per device. 

844 max_to_keep: Maximum number of recent checkpoints to keep. Defaults to 5. 

845 keep_checkpoint_every_n_hours: How often to keep checkpoints. Defaults to 

846 10,000 hours. 

847 name: String. Optional name to use as a prefix when adding operations. 

848 restore_sequentially: A `Bool`, which if true, causes restore of different 

849 variables to happen sequentially within each device. This can lower 

850 memory usage when restoring very large models. 

851 saver_def: Optional `SaverDef` proto to use instead of running the 

852 builder. This is only useful for specialty code that wants to recreate a 

853 `Saver` object for a previously built `Graph` that had a `Saver`. The 

854 `saver_def` proto should be the one returned by the `as_saver_def()` 

855 call of the `Saver` that was created for that `Graph`. 

856 builder: Optional `SaverBuilder` to use if a `saver_def` was not provided. 

857 Defaults to `BulkSaverBuilder()`. 

858 defer_build: If `True`, defer adding the save and restore ops to the 

859 `build()` call. In that case `build()` should be called before 

860 finalizing the graph or using the saver. 

861 allow_empty: If `False` (default) raise an error if there are no variables 

862 in the graph. Otherwise, construct the saver anyway and make it a no-op. 

863 write_version: controls what format to use when saving checkpoints. It 

864 also affects certain filepath matching logic. The V2 format is the 

865 recommended choice: it is much more optimized than V1 in terms of memory 

866 required and latency incurred during restore. Regardless of this flag, 

867 the Saver is able to restore from both V2 and V1 checkpoints. 

868 pad_step_number: if True, pads the global step number in the checkpoint 

869 filepaths to some fixed width (8 by default). This is turned off by 

870 default. 

871 save_relative_paths: If `True`, will write relative paths to the 

872 checkpoint state file. This is needed if the user wants to copy the 

873 checkpoint directory and reload from the copied directory. 

874 filename: If known at graph construction time, filename used for variable 

875 loading/saving. 

876 

877 Raises: 

878 TypeError: If `var_list` is invalid. 

879 ValueError: If any of the keys or values in `var_list` are not unique. 

880 RuntimeError: If eager execution is enabled and`var_list` does not specify 

881 a list of variables to save. 

882 

883 @compatibility(eager) 

884 When eager execution is enabled, `var_list` must specify a `list` or `dict` 

885 of variables to save. Otherwise, a `RuntimeError` will be raised. 

886 

887 Although Saver works in some cases when executing eagerly, it is 

888 fragile. Please switch to `tf.train.Checkpoint` or 

889 `tf.keras.Model.save_weights`, which perform a more robust object-based 

890 saving. These APIs will load checkpoints written by `Saver`. 

891 @end_compatibility 

892 """ 

893 global _END_TIME_OF_LAST_WRITE 

894 with _END_TIME_OF_LAST_WRITE_LOCK: 

895 if _END_TIME_OF_LAST_WRITE is None: 

896 _END_TIME_OF_LAST_WRITE = time.time() 

897 

898 if defer_build and var_list: 

899 raise ValueError( 

900 "If `var_list` is provided then build cannot be deferred. " 

901 "Either set defer_build=False or var_list=None.") 

902 if context.executing_eagerly(): 

903 logging.warning( 

904 "Saver is deprecated, please switch to tf.train.Checkpoint or " 

905 "tf.keras.Model.save_weights for training checkpoints. When " 

906 "executing eagerly variables do not necessarily have unique names, " 

907 "and so the variable.name-based lookups Saver performs are " 

908 "error-prone.") 

909 if var_list is None: 

910 raise RuntimeError( 

911 "When eager execution is enabled, `var_list` must specify a list " 

912 "or dict of variables to save") 

913 self._var_list = var_list 

914 self._reshape = reshape 

915 self._sharded = sharded 

916 self._max_to_keep = max_to_keep 

917 self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours 

918 self._name = name 

919 self._restore_sequentially = restore_sequentially 

920 self.saver_def = saver_def 

921 self._builder = builder 

922 self._is_built = False 

923 self._allow_empty = allow_empty 

924 self._is_empty = None 

925 self._write_version = write_version 

926 self._pad_step_number = pad_step_number 

927 self._filename = filename 

928 self._last_checkpoints = [] 

929 self._checkpoints_to_be_deleted = [] 

930 if context.executing_eagerly(): 

931 self._next_checkpoint_time = ( 

932 time.time() + self._keep_checkpoint_every_n_hours * 3600) 

933 elif not defer_build: 

934 self.build() 

935 if self.saver_def: 

936 self._check_saver_def() 

937 self._write_version = self.saver_def.version 

938 self._save_relative_paths = save_relative_paths 

939 # For compatibility with object-based checkpoints, we may build a second 

940 # Saver to read the renamed keys. 

941 self._object_restore_saver = None 

942 

943 def build(self): 

944 if context.executing_eagerly(): 

945 raise RuntimeError("Use save/restore instead of build in eager mode.") 

946 self._build(self._filename, build_save=True, build_restore=True) 

947 

948 def _build_eager(self, checkpoint_path, build_save, build_restore): 

949 self._build( 

950 checkpoint_path, build_save=build_save, build_restore=build_restore) 

951 

  def _build(self, checkpoint_path, build_save, build_restore):
    """Builds saver_def.

    Args:
      checkpoint_path: Filename baked into the save/restore ops, or `None`.
      build_save: If `True`, build the save ops.
      build_restore: If `True`, build the restore ops.

    Raises:
      ValueError: If there are no variables to save and `allow_empty` is
        `False`.
    """
    if not context.executing_eagerly():
      # Graph mode builds at most once; eager mode rebuilds on every call.
      if self._is_built:
        return
      self._is_built = True

    if not self.saver_def or context.executing_eagerly():
      if self._builder is None:
        self._builder = BulkSaverBuilder(self._write_version)

      if self._var_list is None:
        # Default to every saveable object in the graph.
        # pylint: disable=protected-access
        self._var_list = variables._all_saveable_objects()
      if not self._var_list:
        if self._allow_empty:
          # Nothing to save: mark the saver as an explicit no-op.
          self._is_empty = True
          return
        else:
          raise ValueError("No variables to save")
      self._is_empty = False

      self.saver_def = self._builder._build_internal(  # pylint: disable=protected-access
          self._var_list,
          reshape=self._reshape,
          sharded=self._sharded,
          max_to_keep=self._max_to_keep,
          keep_checkpoint_every_n_hours=self._keep_checkpoint_every_n_hours,
          name=self._name,
          restore_sequentially=self._restore_sequentially,
          filename=checkpoint_path,
          build_save=build_save,
          build_restore=build_restore)
    elif self.saver_def and self._name:
      # Since self._name is used as a name_scope by builder(), we are
      # overloading the use of this field to represent the "import_scope" as
      # well.
      self.saver_def.filename_tensor_name = ops.prepend_name_scope(
          self.saver_def.filename_tensor_name, self._name)
      self.saver_def.save_tensor_name = ops.prepend_name_scope(
          self.saver_def.save_tensor_name, self._name)
      self.saver_def.restore_op_name = ops.prepend_name_scope(
          self.saver_def.restore_op_name, self._name)

    self._check_saver_def()
    if not context.executing_eagerly():
      # Updates next checkpoint time.
      # Set in __init__ when executing eagerly.
      self._next_checkpoint_time = (
          time.time() + self.saver_def.keep_checkpoint_every_n_hours * 3600)

1002 

1003 def _check_saver_def(self): 

1004 if not isinstance(self.saver_def, saver_pb2.SaverDef): 

1005 raise ValueError("saver_def must be a saver_pb2.SaverDef: %s" % 

1006 self.saver_def) 

1007 if not context.executing_eagerly(): 

1008 if not self.saver_def.save_tensor_name: 

1009 raise ValueError("saver_def must specify the save_tensor_name: %s" % 

1010 str(self.saver_def)) 

1011 if not self.saver_def.restore_op_name: 

1012 raise ValueError("saver_def must specify the restore_op_name: %s" % 

1013 str(self.saver_def)) 

1014 

1015 def _CheckpointFilename(self, p): 

1016 """Returns the checkpoint filename given a `(filename, time)` pair. 

1017 

1018 Args: 

1019 p: (filename, time) pair. 

1020 

1021 Returns: 

1022 Checkpoint file name. 

1023 """ 

1024 name, _ = p 

1025 return name 

1026 

1027 def _RecordLastCheckpoint(self, latest_save_path): 

1028 """Manages the list of the latest checkpoints.""" 

1029 if not self.saver_def.max_to_keep: 

1030 return 

1031 # Remove first from list if the same name was used before. 

1032 for p in self._last_checkpoints: 

1033 if latest_save_path == self._CheckpointFilename(p): 

1034 self._last_checkpoints.remove(p) 

1035 # Append new path to list 

1036 self._last_checkpoints.append((latest_save_path, time.time())) 

1037 

1038 # If more than max_to_keep, remove oldest. 

1039 if len(self._last_checkpoints) > self.saver_def.max_to_keep: 

1040 self._checkpoints_to_be_deleted.append(self._last_checkpoints.pop(0)) 

1041 

1042 def _MaybeDeleteOldCheckpoints(self, meta_graph_suffix="meta"): 

1043 """Deletes old checkpoints if necessary. 

1044 

1045 `self._checkpoints_to_be_deleted` is going to contain checkpoints that are 

1046 over `max_to_keep`. They are going to be deleted. If 

1047 `keep_checkpoint_every_n_hours` was specified, keep an additional checkpoint 

1048 every `N` hours. For example, if `N` is 0.5, an additional checkpoint is 

1049 kept for every 0.5 hours of training; if `N` is 10, an additional 

1050 checkpoint is kept for every 10 hours of training. 

1051 

1052 Args: 

1053 meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. 

1054 """ 

1055 if self._checkpoints_to_be_deleted: 

1056 p = self._checkpoints_to_be_deleted.pop(0) 

1057 # Do not delete the file if we keep_checkpoint_every_n_hours is set and we 

1058 # have reached N hours of training. 

1059 should_keep = p[1] > self._next_checkpoint_time 

1060 if should_keep: 

1061 self._next_checkpoint_time += ( 

1062 self.saver_def.keep_checkpoint_every_n_hours * 3600) 

1063 return 

1064 

1065 # Otherwise delete the files. 

1066 try: 

1067 checkpoint_management.remove_checkpoint( 

1068 self._CheckpointFilename(p), self.saver_def.version, 

1069 meta_graph_suffix) 

1070 except Exception as e: # pylint: disable=broad-except 

1071 logging.warning("Ignoring: %s", str(e)) 

1072 

1073 def as_saver_def(self): 

1074 """Generates a `SaverDef` representation of this saver. 

1075 

1076 Returns: 

1077 A `SaverDef` proto. 

1078 """ 

1079 return self.saver_def 

1080 

1081 def to_proto(self, export_scope=None): 

1082 """Converts this `Saver` to a `SaverDef` protocol buffer. 

1083 

1084 Args: 

1085 export_scope: Optional `string`. Name scope to remove. 

1086 

1087 Returns: 

1088 A `SaverDef` protocol buffer. 

1089 """ 

1090 if export_scope is None: 

1091 return self.saver_def 

1092 

1093 if not (self.saver_def.filename_tensor_name.startswith(export_scope) and 

1094 self.saver_def.save_tensor_name.startswith(export_scope) and 

1095 self.saver_def.restore_op_name.startswith(export_scope)): 

1096 return None 

1097 

1098 saver_def = saver_pb2.SaverDef() 

1099 saver_def.CopyFrom(self.saver_def) 

1100 saver_def.filename_tensor_name = ops.strip_name_scope( 

1101 saver_def.filename_tensor_name, export_scope) 

1102 saver_def.save_tensor_name = ops.strip_name_scope( 

1103 saver_def.save_tensor_name, export_scope) 

1104 saver_def.restore_op_name = ops.strip_name_scope(saver_def.restore_op_name, 

1105 export_scope) 

1106 return saver_def 

1107 

1108 @staticmethod 

1109 def from_proto(saver_def, import_scope=None): 

1110 """Returns a `Saver` object created from `saver_def`. 

1111 

1112 Args: 

1113 saver_def: a `SaverDef` protocol buffer. 

1114 import_scope: Optional `string`. Name scope to use. 

1115 

1116 Returns: 

1117 A `Saver` built from saver_def. 

1118 """ 

1119 return Saver(saver_def=saver_def, name=import_scope) 

1120 

1121 @property 

1122 def last_checkpoints(self): 

1123 """List of not-yet-deleted checkpoint filenames. 

1124 

1125 You can pass any of the returned values to `restore()`. 

1126 

1127 Returns: 

1128 A list of checkpoint filenames, sorted from oldest to newest. 

1129 """ 

1130 return list(self._CheckpointFilename(p) for p in self._last_checkpoints) 

1131 

1132 def set_last_checkpoints(self, last_checkpoints): 

1133 """DEPRECATED: Use set_last_checkpoints_with_time. 

1134 

1135 Sets the list of old checkpoint filenames. 

1136 

1137 Args: 

1138 last_checkpoints: A list of checkpoint filenames. 

1139 

1140 Raises: 

1141 AssertionError: If last_checkpoints is not a list. 

1142 """ 

1143 assert isinstance(last_checkpoints, list) 

1144 # We use a timestamp of +inf so that this checkpoint will never be 

1145 # deleted. This is both safe and backwards compatible to a previous 

1146 # version of the code which used s[1] as the "timestamp". 

1147 self._last_checkpoints = [(s, np.inf) for s in last_checkpoints] 

1148 

1149 def set_last_checkpoints_with_time(self, last_checkpoints_with_time): 

1150 """Sets the list of old checkpoint filenames and timestamps. 

1151 

1152 Args: 

1153 last_checkpoints_with_time: A list of tuples of checkpoint filenames and 

1154 timestamps. 

1155 

1156 Raises: 

1157 AssertionError: If last_checkpoints_with_time is not a list. 

1158 """ 

1159 assert isinstance(last_checkpoints_with_time, list) 

1160 self._last_checkpoints = last_checkpoints_with_time 

1161 

1162 def recover_last_checkpoints(self, checkpoint_paths): 

1163 """Recovers the internal saver state after a crash. 

1164 

1165 This method is useful for recovering the "self._last_checkpoints" state. 

1166 

1167 Globs for the checkpoints pointed to by `checkpoint_paths`. If the files 

1168 exist, use their mtime as the checkpoint timestamp. 

1169 

1170 Args: 

1171 checkpoint_paths: a list of checkpoint paths. 

1172 """ 

1173 checkpoints_with_mtimes = [] 

1174 for checkpoint_path in checkpoint_paths: 

1175 try: 

1176 mtime = checkpoint_management.get_checkpoint_mtimes([checkpoint_path]) 

1177 except errors.NotFoundError: 

1178 # It's fine if some other thread/process is deleting some older 

1179 # checkpoint concurrently. 

1180 continue 

1181 if mtime: 

1182 checkpoints_with_mtimes.append((checkpoint_path, mtime[0])) 

1183 self.set_last_checkpoints_with_time(checkpoints_with_mtimes) 

1184 

  def save(self,
           sess,
           save_path,
           global_step=None,
           latest_filename=None,
           meta_graph_suffix="meta",
           write_meta_graph=True,
           write_state=True,
           strip_default_attrs=False,
           save_debug_info=False):
    # pylint: disable=line-too-long
    """Saves variables.

    This method runs the ops added by the constructor for saving variables.
    It requires a session in which the graph was launched.  The variables to
    save must also have been initialized.

    The method returns the path prefix of the newly created checkpoint files.
    This string can be passed directly to a call to `restore()`.

    Args:
      sess: A Session to use to save the variables.
      save_path: String.  Prefix of filenames created for the checkpoint.
      global_step: If provided the global step number is appended to
        `save_path` to create the checkpoint filenames. The optional argument
        can be a `Tensor`, a `Tensor` name or an integer.
      latest_filename: Optional name for the protocol buffer file that will
        contains the list of most recent checkpoints.  That file, kept in the
        same directory as the checkpoint files, is automatically managed by
        the saver to keep track of recent checkpoints.  Defaults to
        'checkpoint'.
      meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
      write_meta_graph: `Boolean` indicating whether or not to write the meta
        graph file.
      write_state: `Boolean` indicating whether or not to write the
        `CheckpointStateProto`.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will
        be removed from the NodeDefs. For a detailed guide, see [Stripping
        Default-Valued
        Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
      save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
        which in the same directory of save_path and with `_debug` added
        before the file extension. This is only enabled when
        `write_meta_graph` is `True`

    Returns:
      A string: path prefix used for the checkpoint files.  If the saver is
        sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
        is the number of shards created.
      If the saver is empty, returns None.

    Raises:
      TypeError: If `sess` is not a `Session`.
      ValueError: If `latest_filename` contains path components, or if it
        collides with `save_path`.
      RuntimeError: If save and restore ops weren't built.
    """
    # pylint: enable=line-too-long
    start_time = time.time()
    if not self._is_built and not context.executing_eagerly():
      raise RuntimeError(
          "`build()` should be called before save if defer_build==True")
    if latest_filename is None:
      latest_filename = "checkpoint"
    if self._write_version != saver_pb2.SaverDef.V2:
      logging.warning("*******************************************************")
      logging.warning("TensorFlow's V1 checkpoint format has been deprecated.")
      logging.warning("Consider switching to the more efficient V2 format:")
      logging.warning("   `tf.train.Saver(write_version=tf.train.SaverDef.V2)`")
      logging.warning("now on by default.")
      logging.warning("*******************************************************")

    # `latest_filename` must be a bare filename: it is always written next to
    # the checkpoint files in the save_path directory.
    if os.path.split(latest_filename)[0]:
      raise ValueError("'latest_filename' must not contain path components")

    save_path = compat.as_str(save_path)
    if global_step is not None:
      # Resolve a Tensor / tensor-name global_step to a concrete integer.
      if not isinstance(global_step, compat.integral_types):
        global_step = training_util.global_step(sess, global_step)
      checkpoint_file = "%s-%d" % (save_path, global_step)
      if self._pad_step_number:
        # Zero-pads the step numbers, so that they are sorted when listed.
        checkpoint_file = "%s-%s" % (save_path, "{:08d}".format(global_step))
    else:
      checkpoint_file = save_path
      if os.path.basename(save_path) == latest_filename and not self._sharded:
        # Guard against collision between data file and checkpoint state file.
        raise ValueError(
            "'latest_filename' collides with 'save_path': '%s' and '%s'" %
            (latest_filename, save_path))

    if (not context.executing_eagerly() and
        not isinstance(sess, session.SessionInterface)):
      raise TypeError("'sess' must be a Session; %s" % sess)

    save_path_parent = os.path.dirname(save_path)
    if not self._is_empty:
      try:
        if context.executing_eagerly():
          # Eager mode: (re)build the save ops against this checkpoint file.
          self._build_eager(
              checkpoint_file, build_save=True, build_restore=False)
          model_checkpoint_path = self.saver_def.save_tensor_name
        else:
          # Graph mode: run the save op, feeding the checkpoint filename.
          model_checkpoint_path = sess.run(
              self.saver_def.save_tensor_name,
              {self.saver_def.filename_tensor_name: checkpoint_file})

        model_checkpoint_path = compat.as_str(model_checkpoint_path)
        if write_state:
          # Update bookkeeping, the 'checkpoint' state file, and apply the
          # retention policy (may delete one old checkpoint).
          self._RecordLastCheckpoint(model_checkpoint_path)
          checkpoint_management.update_checkpoint_state_internal(
              save_dir=save_path_parent,
              model_checkpoint_path=model_checkpoint_path,
              all_model_checkpoint_paths=self.last_checkpoints,
              latest_filename=latest_filename,
              save_relative_paths=self._save_relative_paths)
          self._MaybeDeleteOldCheckpoints(meta_graph_suffix=meta_graph_suffix)
      except (errors.FailedPreconditionError, errors.NotFoundError) as exc:
        # Re-wrap the error with a clearer message when the target directory
        # simply doesn't exist.
        if not gfile.IsDirectory(save_path_parent):
          exc = ValueError(
              "Parent directory of {} doesn't exist, can't save.".format(
                  save_path))
        raise exc

    end_time = time.time()
    # Record write duration and the accumulated training time saved since the
    # previous checkpoint write (process-wide, hence the lock).
    metrics.AddCheckpointWriteDuration(
        api_label=_SAVER_LABEL,
        microseconds=_get_duration_microseconds(start_time, end_time))
    global _END_TIME_OF_LAST_WRITE
    with _END_TIME_OF_LAST_WRITE_LOCK:
      metrics.AddTrainingTimeSaved(
          api_label=_SAVER_LABEL,
          microseconds=_get_duration_microseconds(_END_TIME_OF_LAST_WRITE,
                                                  end_time))
      _END_TIME_OF_LAST_WRITE = end_time

    if write_meta_graph:
      meta_graph_filename = checkpoint_management.meta_graph_filename(
          checkpoint_file, meta_graph_suffix=meta_graph_suffix)
      if not context.executing_eagerly():
        with sess.graph.as_default():
          self.export_meta_graph(
              meta_graph_filename,
              strip_default_attrs=strip_default_attrs,
              save_debug_info=save_debug_info)

    if self._is_empty:
      return None
    else:
      metrics.RecordCheckpointSize(
          api_label=_SAVER_LABEL,
          filesize=_get_checkpoint_size(model_checkpoint_path))
      return model_checkpoint_path

1337 

1338 def export_meta_graph(self, 

1339 filename=None, 

1340 collection_list=None, 

1341 as_text=False, 

1342 export_scope=None, 

1343 clear_devices=False, 

1344 clear_extraneous_savers=False, 

1345 strip_default_attrs=False, 

1346 save_debug_info=False): 

1347 # pylint: disable=line-too-long 

1348 """Writes `MetaGraphDef` to save_path/filename. 

1349 

1350 Args: 

1351 filename: Optional meta_graph filename including the path. 

1352 collection_list: List of string keys to collect. 

1353 as_text: If `True`, writes the meta_graph as an ASCII proto. 

1354 export_scope: Optional `string`. Name scope to remove. 

1355 clear_devices: Whether or not to clear the device field for an `Operation` 

1356 or `Tensor` during export. 

1357 clear_extraneous_savers: Remove any Saver-related information from the 

1358 graph (both Save/Restore ops and SaverDefs) that are not associated with 

1359 this Saver. 

1360 strip_default_attrs: Boolean. If `True`, default-valued attributes will be 

1361 removed from the NodeDefs. For a detailed guide, see [Stripping 

1362 Default-Valued 

1363 Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). 

1364 save_debug_info: If `True`, save the GraphDebugInfo to a separate file, 

1365 which in the same directory of filename and with `_debug` added before 

1366 the file extension. 

1367 

1368 Returns: 

1369 A `MetaGraphDef` proto. 

1370 """ 

1371 # pylint: enable=line-too-long 

1372 return export_meta_graph( 

1373 filename=filename, 

1374 graph_def=ops.get_default_graph().as_graph_def(add_shapes=True), 

1375 saver_def=self.saver_def, 

1376 collection_list=collection_list, 

1377 as_text=as_text, 

1378 export_scope=export_scope, 

1379 clear_devices=clear_devices, 

1380 clear_extraneous_savers=clear_extraneous_savers, 

1381 strip_default_attrs=strip_default_attrs, 

1382 save_debug_info=save_debug_info) 

1383 

  def restore(self, sess, save_path):
    """Restores previously saved variables.

    This method runs the ops added by the constructor for restoring
    variables. It requires a session in which the graph was launched.  The
    variables to restore do not have to have been initialized, as restoring
    is itself a way to initialize variables.

    The `save_path` argument is typically a value previously returned from a
    `save()` call, or a call to `latest_checkpoint()`.

    Args:
      sess: A `Session` to use to restore the parameters. None in eager mode.
      save_path: Path where parameters were previously saved.

    Raises:
      ValueError: If save_path is None or not a valid checkpoint.
    """
    start_time = time.time()
    # An empty (no-op) saver has nothing to restore.
    if self._is_empty:
      return
    if save_path is None:
      raise ValueError("Can't load save_path when it is None.")

    checkpoint_prefix = compat.as_text(save_path)
    if not checkpoint_management.checkpoint_exists_internal(checkpoint_prefix):
      raise ValueError("The passed save_path is not a valid checkpoint: " +
                       checkpoint_prefix)

    logging.info("Restoring parameters from %s", checkpoint_prefix)
    try:
      if context.executing_eagerly():
        self._build_eager(save_path, build_save=False, build_restore=True)
      else:
        sess.run(self.saver_def.restore_op_name,
                 {self.saver_def.filename_tensor_name: save_path})
    except errors.NotFoundError as err:
      # There are three common conditions that might cause this error:
      # 0. The file is missing. We ignore here, as this is checked above.
      # 1. This is an object-based checkpoint trying name-based loading.
      # 2. The graph has been altered and a variable or other name is missing.

      # 1. The checkpoint would not be loaded successfully as is. Try to parse
      # it as an object-based checkpoint.
      try:
        names_to_keys = object_graph_key_mapping(save_path)
      except errors.NotFoundError:
        # 2. This is not an object-based checkpoint, which likely means there
        # is a graph mismatch. Re-raise the original error with
        # a helpful message (b/110263146)
        raise _wrap_restore_error_with_msg(
            err, "a Variable name or other graph key that is missing")

      # This is an object-based checkpoint. We'll print a warning and then do
      # the restore.
      logging.warning(
          "Restoring an object-based checkpoint using a name-based saver. This "
          "may be somewhat fragile, and will re-build the Saver. Instead, "
          "consider loading object-based checkpoints using "
          "tf.train.Checkpoint().")
      # The secondary saver is cached on the instance so repeated restores
      # don't rebuild it.
      self._object_restore_saver = saver_from_object_based_checkpoint(
          checkpoint_path=save_path,
          var_list=self._var_list,
          builder=self._builder,
          names_to_keys=names_to_keys,
          cached_saver=self._object_restore_saver)
      self._object_restore_saver.restore(sess=sess, save_path=save_path)
    except errors.InvalidArgumentError as err:
      # There is a mismatch between the graph and the checkpoint being loaded.
      # We add a more reasonable error message here to help users (b/110263146)
      raise _wrap_restore_error_with_msg(
          err, "a mismatch between the current graph and the graph")
    metrics.AddCheckpointReadDuration(
        api_label=_SAVER_LABEL,
        microseconds=_get_duration_microseconds(start_time, time.time()))

1459 

1460 @staticmethod 

1461 def _add_collection_def(meta_graph_def, key, export_scope=None): 

1462 """Adds a collection to MetaGraphDef protocol buffer. 

1463 

1464 Args: 

1465 meta_graph_def: MetaGraphDef protocol buffer. 

1466 key: One of the GraphKeys or user-defined string. 

1467 export_scope: Optional `string`. Name scope to remove. 

1468 """ 

1469 meta_graph.add_collection_def( 

1470 meta_graph_def, key, export_scope=export_scope) 

1471 

1472 

1473@tf_export(v1=["train.import_meta_graph"]) 

1474def import_meta_graph(meta_graph_or_file, 

1475 clear_devices=False, 

1476 import_scope=None, 

1477 **kwargs): 

1478 """Recreates a Graph saved in a `MetaGraphDef` proto. 

1479 

1480 This function takes a `MetaGraphDef` protocol buffer as input. If 

1481 the argument is a file containing a `MetaGraphDef` protocol buffer , 

1482 it constructs a protocol buffer from the file content. The function 

1483 then adds all the nodes from the `graph_def` field to the 

1484 current graph, recreates all the collections, and returns a saver 

1485 constructed from the `saver_def` field. 

1486 

1487 In combination with `export_meta_graph()`, this function can be used to 

1488 

1489 * Serialize a graph along with other Python objects such as `QueueRunner`, 

1490 `Variable` into a `MetaGraphDef`. 

1491 

1492 * Restart training from a saved graph and checkpoints. 

1493 

1494 * Run inference from a saved graph and checkpoints. 

1495 

1496 ```Python 

1497 ... 

1498 # Create a saver. 

1499 saver = tf.compat.v1.train.Saver(...variables...) 

1500 # Remember the training_op we want to run by adding it to a collection. 

1501 tf.compat.v1.add_to_collection('train_op', train_op) 

1502 sess = tf.compat.v1.Session() 

1503 for step in range(1000000): 

1504 sess.run(train_op) 

1505 if step % 1000 == 0: 

1506 # Saves checkpoint, which by default also exports a meta_graph 

1507 # named 'my-model-global_step.meta'. 

1508 saver.save(sess, 'my-model', global_step=step) 

1509 ``` 

1510 

1511 Later we can continue training from this saved `meta_graph` without building 

1512 the model from scratch. 

1513 

1514 ```Python 

1515 with tf.Session() as sess: 

1516 new_saver = 

1517 tf.train.import_meta_graph('my-save-dir/my-model-10000.meta') 

1518 new_saver.restore(sess, 'my-save-dir/my-model-10000') 

1519 # tf.get_collection() returns a list. In this example we only want 

1520 # the first one. 

1521 train_op = tf.get_collection('train_op')[0] 

1522 for step in range(1000000): 

1523 sess.run(train_op) 

1524 ``` 

1525 

1526 NOTE: Restarting training from saved `meta_graph` only works if the 

1527 device assignments have not changed. 

1528 

1529 Example: 

1530 Variables, placeholders, and independent operations can also be stored, as 

1531 shown in the following example. 

1532 

1533 ```Python 

1534 # Saving contents and operations. 

1535 v1 = tf.placeholder(tf.float32, name="v1") 

1536 v2 = tf.placeholder(tf.float32, name="v2") 

1537 v3 = tf.math.multiply(v1, v2) 

1538 vx = tf.Variable(10.0, name="vx") 

1539 v4 = tf.add(v3, vx, name="v4") 

1540 saver = tf.train.Saver([vx]) 

1541 sess = tf.Session() 

1542 sess.run(tf.global_variables_initializer()) 

1543 sess.run(vx.assign(tf.add(vx, vx))) 

1544 result = sess.run(v4, feed_dict={v1:12.0, v2:3.3}) 

1545 print(result) 

1546 saver.save(sess, "./model_ex1") 

1547 ``` 

1548 

1549 Later this model can be restored and contents loaded. 

1550 

1551 ```Python 

1552 # Restoring variables and running operations. 

1553 saver = tf.train.import_meta_graph("./model_ex1.meta") 

1554 sess = tf.Session() 

1555 saver.restore(sess, "./model_ex1") 

1556 result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3}) 

1557 print(result) 

1558 ``` 

1559 

1560 Args: 

1561 meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including 

1562 the path) containing a `MetaGraphDef`. 

1563 clear_devices: Whether or not to clear the device field for an `Operation` 

1564 or `Tensor` during import. 

1565 import_scope: Optional `string`. Name scope to add. Only used when 

1566 initializing from protocol buffer. 

1567 **kwargs: Optional keyed arguments. 

1568 

1569 Returns: 

1570 A saver constructed from `saver_def` in `MetaGraphDef` or None. 

1571 

1572 A None value is returned if no variables exist in the `MetaGraphDef` 

1573 (i.e., there are no variables to restore). 

1574 

1575 Raises: 

1576 RuntimeError: If called with eager execution enabled. 

1577 

1578 @compatibility(eager) 

1579 Exporting/importing meta graphs is not supported. No graph exists when eager 

1580 execution is enabled. 

1581 @end_compatibility 

1582 """ # pylint: disable=g-doc-exception 

1583 return _import_meta_graph_with_return_elements(meta_graph_or_file, 

1584 clear_devices, import_scope, 

1585 **kwargs)[0] 

1586 

1587 

def _import_meta_graph_with_return_elements(meta_graph_or_file,
                                            clear_devices=False,
                                            import_scope=None,
                                            return_elements=None,
                                            **kwargs):
  """Import a MetaGraph and return both a saver and the requested elements.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Whether or not to clear the device field for an `Operation`
      or `Tensor` during import.
    import_scope: Optional `string`. Name scope to add.
    return_elements: Optional list of graph element names to look up after the
      import and hand back to the caller.
    **kwargs: Optional keyed arguments forwarded to the meta-graph importer.

  Returns:
    A `(saver, return_elements)` pair. `saver` is `None` when the imported
    meta graph contains no variables to restore.

  Raises:
    RuntimeError: If called with eager execution enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError("Exporting/importing meta graphs is not supported when "
                       "eager execution is enabled. No graph exists when eager "
                       "execution is enabled.")
  # Accept either an in-memory proto or a path to a serialized one.
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file)

  imported_vars, imported_return_elements = (
      meta_graph.import_scoped_meta_graph_with_return_elements(
          meta_graph_def,
          clear_devices=clear_devices,
          import_scope=import_scope,
          return_elements=return_elements,
          **kwargs))

  return (_create_saver_from_imported_meta_graph(meta_graph_def, import_scope,
                                                 imported_vars),
          imported_return_elements)

1614 

1615 

def _create_saver_from_imported_meta_graph(meta_graph_def, import_scope,
                                           imported_vars):
  """Return a saver for restoring variable values to an imported MetaGraph.

  Args:
    meta_graph_def: The imported `MetaGraphDef` protocol buffer.
    import_scope: Optional name scope the graph was imported under.
    imported_vars: Dict mapping original variable names to the imported
      variables.

  Returns:
    A `Saver`, or `None` when the graph holds no variables to restore.
  """
  if not meta_graph_def.HasField("saver_def"):
    if variables._all_saveable_objects(scope=import_scope):  # pylint: disable=protected-access
      # Return the default saver instance for all graph variables.
      return Saver()
    # If no graph variables exist, then a Saver cannot be constructed.
    logging.info("Saver not created because there are no variables in the"
                 " graph to restore")
    return None

  # Infer the scope that is prepended by `import_scoped_meta_graph`: any
  # single entry suffices, since every imported name is `<scope><orig name>`.
  scope = import_scope
  if imported_vars:
    sample_key, sample_var = next(iter(imported_vars.items()))
    scope = sample_var.name[:-len(sample_key)]
  return Saver(saver_def=meta_graph_def.saver_def, name=scope)

1638 

1639 

@tf_export(v1=["train.export_meta_graph"])
def export_meta_graph(filename=None,
                      meta_info_def=None,
                      graph_def=None,
                      saver_def=None,
                      collection_list=None,
                      as_text=False,
                      graph=None,
                      export_scope=None,
                      clear_devices=False,
                      clear_extraneous_savers=False,
                      strip_default_attrs=False,
                      save_debug_info=False,
                      **kwargs):
  # pylint: disable=line-too-long
  """Returns `MetaGraphDef` proto.

  Optionally writes it to filename.

  This function exports the graph, saver, and collection objects into
  `MetaGraphDef` protocol buffer with the intention of it being imported
  at a later time or location to restart training, run inference, or be
  a subgraph.

  Args:
    filename: Optional filename including the path for writing the generated
      `MetaGraphDef` protocol buffer.
    meta_info_def: `MetaInfoDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer.
    saver_def: `SaverDef` protocol buffer.
    collection_list: List of string keys to collect.
    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
    graph: The `Graph` to export. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract the
      subgraph. The scope name will be stripped from the node definitions for
      easy import later into new name scopes. If `None`, the whole graph is
      exported. graph_def and export_scope cannot both be specified.
    clear_devices: Whether or not to clear the device field for an `Operation`
      or `Tensor` during export.
    clear_extraneous_savers: Remove any Saver-related information from the graph
      (both Save/Restore ops and SaverDefs) that are not associated with the
      provided SaverDef.
    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
      removed from the NodeDefs. For a detailed guide, see [Stripping
      Default-Valued
      Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
    save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
      which is in the same directory as filename and with `_debug` added before
      the file extension.
    **kwargs: Optional keyed arguments.

  Returns:
    A `MetaGraphDef` proto.

  Raises:
    ValueError: When the `GraphDef` is larger than 2GB.
    RuntimeError: If called with eager execution enabled.

  @compatibility(eager)
  Exporting/importing meta graphs is not supported unless both `graph_def` and
  `graph` are provided. No graph exists when eager execution is enabled.
  @end_compatibility
  """
  # pylint: enable=line-too-long
  # Eager mode has no default graph, so exporting is only possible when the
  # caller supplies both an explicit `graph_def` and its owning `graph`.
  if context.executing_eagerly() and (graph_def is None or graph is None):
    raise RuntimeError("Exporting/importing meta graphs is not supported when "
                       "eager execution is enabled. No graph exists when eager "
                       "execution is enabled.")
  exported_meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
      filename=filename,
      meta_info_def=meta_info_def,
      graph_def=graph_def,
      saver_def=saver_def,
      collection_list=collection_list,
      as_text=as_text,
      graph=graph,
      export_scope=export_scope,
      clear_devices=clear_devices,
      clear_extraneous_savers=clear_extraneous_savers,
      strip_default_attrs=strip_default_attrs,
      save_debug_info=save_debug_info,
      **kwargs)
  return exported_meta_graph_def

1724 

1725 

1726def _wrap_restore_error_with_msg(err, extra_verbiage): 

1727 err_msg = ("Restoring from checkpoint failed. This is most likely " 

1728 "due to {} from the checkpoint. Please ensure that you " 

1729 "have not altered the graph expected based on the checkpoint. " 

1730 "Original error:\n\n{}").format(extra_verbiage, err.message) 

1731 return err.__class__(err.node_def, err.op, err_msg) 

1732 

1733 

# Register (de)serialization hooks so `Saver` objects stored in the SAVERS
# graph collection round-trip through MetaGraphDef export/import as
# `SaverDef` protos.
ops.register_proto_function(
    ops.GraphKeys.SAVERS,
    proto_type=saver_pb2.SaverDef,
    to_proto=Saver.to_proto,
    from_proto=Saver.from_proto)

1739 

1740 

def object_graph_key_mapping(checkpoint_path):
  """Return name to key mappings from the checkpoint.

  Args:
    checkpoint_path: string, path to object-based checkpoint

  Returns:
    Dictionary mapping tensor names to checkpoint keys.
  """
  reader = py_checkpoint_reader.NewCheckpointReader(checkpoint_path)
  serialized_graph = reader.get_tensor(trackable.OBJECT_GRAPH_PROTO_KEY)
  object_graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
  object_graph_proto.ParseFromString(serialized_graph)
  # Flatten every (node, attribute) pair into a single name -> key lookup.
  return {
      attribute.full_name: attribute.checkpoint_key
      for node in object_graph_proto.nodes
      for attribute in node.attributes
  }

1759 

1760 

def saver_from_object_based_checkpoint(checkpoint_path,
                                       var_list=None,
                                       builder=None,
                                       names_to_keys=None,
                                       cached_saver=None):
  """Return a `Saver` which reads from an object-based checkpoint.

  This function validates that all variables in the variables list are remapped
  in the object-based checkpoint (or `names_to_keys` dict if provided). A
  saver will be created with the list of remapped variables.

  The `cached_saver` argument allows the user to pass in a previously created
  saver, so multiple `saver.restore()` calls don't pollute the graph when graph
  building. This assumes that keys are consistent, meaning that the
    1) `checkpoint_path` checkpoint, and
    2) checkpoint used to create the `cached_saver`
  are the same type of object-based checkpoint. If this argument is set, this
  function will simply validate that all variables have been remapped by the
  checkpoint at `checkpoint_path`.

  Note that in general, `tf.train.Checkpoint` should be used to restore/save an
  object-based checkpoint.

  Args:
    checkpoint_path: string, path to object-based checkpoint
    var_list: list of `Variables` that appear in the checkpoint. If `None`,
      `var_list` will be set to all saveable objects.
    builder: a `BaseSaverBuilder` instance. If `None`, a new `BulkSaverBuilder`
      will be created.
    names_to_keys: dict mapping string tensor names to checkpoint keys. If
      `None`, this dict will be generated from the checkpoint file.
    cached_saver: Cached `Saver` object with remapped variables.

  Returns:
    `Saver` with remapped variables for reading from an object-based checkpoint.

  Raises:
    ValueError: if the checkpoint provided is not an object-based checkpoint.
    NotFoundError: If one of the variables in `var_list` can not be found in
      the checkpoint. This could mean the checkpoint or `names_to_keys` mapping
      is missing the variable.
  """
  if names_to_keys is None:
    try:
      names_to_keys = object_graph_key_mapping(checkpoint_path)
    except errors.NotFoundError:
      raise ValueError("Checkpoint in %s not an object-based checkpoint." %
                       checkpoint_path)
  if var_list is None:
    var_list = variables._all_saveable_objects()  # pylint: disable=protected-access
  if builder is None:
    # NOTE(review): `builder` is normalized here but not consumed below in
    # this function body — presumably kept for interface compatibility;
    # confirm before removing.
    builder = BulkSaverBuilder()

  if not isinstance(var_list, dict):
    var_list = saveable_object_util.op_list_to_dict(var_list)
  saveables = saveable_object_util.validate_and_slice_inputs(var_list)

  # Every spec name in the current graph must appear in the checkpoint's
  # name -> key mapping, otherwise restoring would silently skip variables.
  current_names = {
      spec.name for saveable in saveables for spec in saveable.specs
  }
  previous_names = set(names_to_keys)
  missing_names = current_names - previous_names
  if missing_names:
    extra_names = previous_names - current_names
    intersecting_names = previous_names & current_names
    raise errors.NotFoundError(
        None,
        None,
        message=(
            "\n\nExisting variables not in the checkpoint: %s\n\n"
            "Variables names when this checkpoint was written which don't "
            "exist now: %s\n\n"
            "(%d variable name(s) did match)\n\n"
            "Could not find some variables in the checkpoint (see names "
            "above). Saver was attempting to load an object-based checkpoint "
            "(saved using tf.train.Checkpoint or tf.keras.Model.save_weights) "
            "using variable names. If the checkpoint was written with eager "
            "execution enabled, it's possible that variable names have "
            "changed (for example missing a '_1' suffix). It's also "
            "possible that there are new variables which did not exist "
            "when the checkpoint was written. You can construct a "
            "Saver(var_list=...) with only the variables which previously "
            "existed, and if variable names have changed you may need to "
            "make this a dictionary with the old names as keys. If you're "
            "using an Estimator, you'll need to return a tf.train.Saver "
            "inside a tf.train.Scaffold from your model_fn.") %
        (", ".join(sorted(missing_names)), ", ".join(
            sorted(extra_names)), len(intersecting_names)))

  # Remap each spec name to its checkpoint key so the Saver reads the
  # object-based checkpoint entries directly.
  for saveable in saveables:
    for spec in saveable.specs:
      spec.name = names_to_keys[spec.name]

  return Saver(saveables) if cached_saver is None else cached_saver