Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/checkpoint/async_checkpoint_helper.py: 20%

236 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utilities for saving/loading Trackable objects asynchronously.""" 

16 

17import atexit 

18import collections 

19import copy 

20import threading 

21import time 

22import weakref 

23 

24from absl import logging 

25 

26from tensorflow.python.checkpoint import checkpoint_context 

27from tensorflow.python.distribute import device_util 

28from tensorflow.python.distribute.sharded_variable import ShardedVariable 

29from tensorflow.python.eager import context 

30from tensorflow.python.eager import def_function 

31from tensorflow.python.eager import executor 

32from tensorflow.python.framework import device as pydev 

33from tensorflow.python.framework import ops 

34from tensorflow.python.ops.resource_variable_ops import UninitializedVariable 

35from tensorflow.python.ops.variables import Variable 

36from tensorflow.python.saved_model.pywrap_saved_model import metrics 

37from tensorflow.python.util import object_identity 

38 

39# Captures the timestamp of the first Checkpoint instantiation or end of a write 

40# operation. Can be accessed by multiple Checkpoint instances. 

41_END_TIME_OF_LAST_ASYNC_WRITE = None 

42_END_TIME_OF_LAST_ASYNC_WRITE_LOCK = threading.Lock() 

43 

44# API label for cell names used in async checkpoint metrics. 

45_ASYNC_CHECKPOINT = "async_checkpoint" 

46 

47# Name of TPUEmbedding attribute. This is a temporary workaround 

48# to identify TPUEmbedding while avoiding import cycles. 

49_TPU_EMBEDDING_ATTR = "_create_copy_for_async_checkpoint" 

50 

51 

52def _get_duration_microseconds(start_time_seconds, end_time_seconds): 

53 """Calculate the duration between start and end time. 

54 

55 Args: 

56 start_time_seconds: The start time in seconds. 

57 end_time_seconds: The end time in seconds. 

58 

59 Returns: 

60 The duration between the start and the end time. Return 0 if 

61 end_time_seconds < start_time_seconds. 

62 """ 

63 if end_time_seconds < start_time_seconds: 

64 # Avoid returning negative value in case of clock skew. 

65 return 0 

66 return round((end_time_seconds - start_time_seconds) * 1000000) 

67 

68 

class AsyncCheckpointHelper:
  """Helper class for async checkpoint.

  Keeps a host-CPU copy of every checkpointed variable so the (slow) file
  write can run on a dedicated background thread while the caller continues
  training. A writer/reader semaphore pair hands the CPU copies back and
  forth between the main thread and the save thread.
  """

  def __init__(self, checkpointer_impl, root=None, **kwargs):
    """Initialize AsyncCheckpoint.

    Args:
      checkpointer_impl: The Checkpoint class to power the AsyncCheckpoint.
      root: The root object to checkpoint. `root` may be a trackable object or
        `WeakRef` of a trackable object.
      **kwargs: The keyword arguments representing the checkpointed variables.
    """
    # TODO(chienchunh): Make sure the processing for the root object is
    # consistent when integrating with the public API, e.g., adding all kwarg
    # items as the child of the root object.
    if root:
      # Dereference a weakref root; the underlying Checkpoint needs the real
      # trackable object.
      trackable_root = root() if isinstance(root, weakref.ref) else root
      kwargs["root"] = trackable_root
      trackable_root._maybe_initialize_trackable()

    self._checkpointer_impl = checkpointer_impl
    self._checkpoint_items = kwargs

    # The underlying Checkpoint instance and its items. Created lazily by
    # _ensure_initialized() so that late-created variables (e.g. optimizer
    # slots) are included.
    self._checkpoint = None
    self._checkpoint_options = None

    # The callback function that needs to be executed after checkpoint write.
    # Currently this is only applied to the scenario where CheckpointManager is
    # used, which triggers the _write() method.
    self._async_write_done_callback = None

    # The list of all nodes from the original checkpoint items.
    # TODO(chienchunh): Consider changing this to local variable.
    self._original_nodes = None
    # The mapping between the original and the copied resource variables.
    # The copied variables are used for the underlying checkpointing.
    self._object_map = None
    # A list of TPUEmbedding objects included in the checkpoint items.
    self._tpu_embedding_objects = None

    self._default_device = device_util.current() or "CPU:0"
    self._default_device = device_util.canonicalize(self._default_device)

    self._save_file_prefix = None
    self._use_checkpoint_save = False
    self._async_save_thread = None
    self._async_save_thread_shutdown = False
    # Semaphores for writing/reading the cpu-copied variables (self._var_pairs)
    # _writer_sem guards the CPU copies for the main thread; _reader_sem wakes
    # the save thread once a fresh copy is ready.
    # TODO(chienchunh): Consider Queue/Condition instead of Semaphore.
    self._writer_sem = threading.Semaphore(1)
    self._reader_sem = threading.Semaphore(0)

    # Register to join the async save thread upon exit.
    atexit.register(self._join_async_save_thread)

    # Most recent exception raised on the async save thread; surfaced to the
    # caller by _check_async_thread_error() on the next save/write/join.
    self._async_error = None

    global _END_TIME_OF_LAST_ASYNC_WRITE
    with _END_TIME_OF_LAST_ASYNC_WRITE_LOCK:
      if _END_TIME_OF_LAST_ASYNC_WRITE is None:
        _END_TIME_OF_LAST_ASYNC_WRITE = time.time()

  @def_function.function
  def _copy_from_cpu(self):
    """Copy the checkpointed variables from the host CPU to the accelerator.

    TODO(chienchunh): Get the concrete function before firstly called to avoid
    hanging the accelerators idle during function tracing.
    """
    for accelerator_var, cpu_var in self._object_map.items():
      if isinstance(accelerator_var, ShardedVariable) or hasattr(
          accelerator_var, _TPU_EMBEDDING_ATTR):
        # Skip for ShardedVariable and TPUEmbedding as their sub-variables will
        # be copied over separately through other entries in the object map.
        continue
      with ops.device(accelerator_var.device):
        accelerator_var.assign(cpu_var.read_value())

  @def_function.function
  def _copy_to_cpu(self):
    """Copy the checkpointed variables from the accelerator to the host CPU.

    TODO(chienchunh): Get the concrete function before firstly called to avoid
    hanging the accelerators idle during function tracing.
    """
    for accelerator_var, cpu_var in self._object_map.items():
      if isinstance(accelerator_var, ShardedVariable) or hasattr(
          accelerator_var, _TPU_EMBEDDING_ATTR):
        # Skip for ShardedVariable and TPUEmbedding as their sub-variables will
        # be copied over separately through other entries in the object map.
        continue
      with ops.device(cpu_var.device):
        cpu_var.assign(accelerator_var.read_value())
    for tpu_embedding in self._tpu_embedding_objects:
      tpu_embedding._retrieve_variables()  # pylint: disable=protected-access

  def _traverse_variables(self, to_traverse, visited):
    """Create the copied nodes and variables while traversing the nodes.

    This method performs a BFS to traverse the nodes while avoiding duplicated
    visits. Throughout the process, self._mapping, self._original_nodes, and
    self._var_pairs are populated.

    Args:
      to_traverse: A deque that stores the nodes to be traversed.
      visited: A list of nodes that have been visited.
    """
    # pylint: disable=protected-access
    while to_traverse:
      current_trackable = to_traverse.popleft()
      self._original_nodes.append(current_trackable)

      if isinstance(current_trackable, (Variable, ShardedVariable)):
        self._copy_trackable(current_trackable)
      if hasattr(current_trackable, _TPU_EMBEDDING_ATTR):
        self._handle_tpu_embedding(current_trackable)

      for child in current_trackable._trackable_children(
          save_type="checkpoint").values():
        if child in visited:
          continue
        visited.add(child)
        to_traverse.append(child)
    # pylint: enable=protected-access

  def _ensure_initialized(self):
    """Initialize the async checkpoint internal state."""
    # Already initialized; the underlying Checkpoint exists.
    if self._checkpoint is not None:
      return

    self._original_nodes = []
    self._object_map = object_identity.ObjectIdentityDictionary()
    self._tpu_embedding_objects = []

    # Add the top-level checkpoint items to be traversed,
    to_traverse = collections.deque([])
    visited = object_identity.ObjectIdentitySet()
    for v in self._checkpoint_items.values():
      if isinstance(v, (Variable, ShardedVariable)):
        self._copy_trackable(v)
      elif hasattr(v, _TPU_EMBEDDING_ATTR):
        self._handle_tpu_embedding(v)
      to_traverse.append(v)
      visited.add(v)
    self._traverse_variables(to_traverse, visited)

    # Copy for the slot variables.
    for current_trackable in self._original_nodes:
      # Note: dir() is used rather than hasattr() here to avoid triggering
      # custom __getattr__ code, see b/152031870 for context.
      if "get_slot_names" in dir(current_trackable):
        slot_names = current_trackable.get_slot_names()
        for slot_name in slot_names:
          for original_variable in self._original_nodes:
            if not isinstance(original_variable, Variable):
              continue
            try:
              original_slot_variable = current_trackable.get_slot(
                  original_variable, slot_name)
            except (AttributeError, KeyError):
              # Not every (variable, slot_name) pair has a slot variable.
              continue
            if isinstance(original_slot_variable, (Variable, ShardedVariable)):
              self._copy_trackable(original_slot_variable)

    # Initiate the underlying Checkpoint instance with the copied items.
    self._checkpoint = self._checkpointer_impl(**self._checkpoint_items)
    # Initiate the underlying Checkpoint instance's save_counter.
    save_counter = self._checkpoint.save_counter
    logging.info("Initializing async checkpoint's save_counter: %d",
                 save_counter)

    # Pass the object map of the copied variables to the underlying Checkpoint.
    self._checkpoint._saver._object_map = self._object_map  # pylint: disable=protected-access

    # Initiate the async thread for checkpoint saving.
    self._async_save_thread = threading.Thread(
        target=self._async_save, daemon=True)
    self._async_save_thread.start()

  def _check_async_thread_error(self):
    """Expose the most recent error from the async saving thread to the caller.
    """
    if self._async_error:
      e = self._async_error
      self._async_error = None
      logging.error("Propagating the most recent error from the async thread "
                    "before joining: %s", str(e))
      # This allows the registered at-exit method '_join_async_save_thread' to
      # acquire the semaphore instead of timing out.
      self._writer_sem.release()
      raise e

  def _join_async_save_thread(self):
    """Join the async save thread.

    The steps for terminating the async save thread:
    1). Wait until the last async save event is done.
    2). Set _async_save_thread_shutdown flag to True to indicate termination.
    3). Trigger the async save thread to check and fail the while-predicate.
    4). Join the async save thread. (The thread may finish before joining.)
    """
    # Expose the async thread error (if any) before joining the thread.
    self._check_async_thread_error()

    if self._writer_sem.acquire(timeout=300):  # Step-1.
      self._async_save_thread_shutdown = True  # Step-2.
      self._reader_sem.release()  # Step-3.
      logging.info("Joining the async save thread.")
      if self._async_save_thread is not None:
        self._async_save_thread.join()  # Step-4.
    else:
      logging.error("Timeout waiting for the async save thread; terminating the"
                    " thread instead. The last checkpoint may be incomeplete.")

  def _async_save(self):
    """The thread function for the async checkpoint save."""
    with context.executor_scope(
        executor.new_executor(
            enable_async=False, enable_streaming_enqueue=False)):
      # Loop until _join_async_save_thread flips the shutdown flag and wakes
      # us via _reader_sem.
      while self._reader_sem.acquire() and not self._async_save_thread_shutdown:
        logging.info("Starting async checkpoint save on the device: %s",
                     self._default_device)

        async_save_start_time = time.time()

        # Specify the ops placement on the worker if running with
        # coordinator-worker mode. This is required as launching a new thread
        # would clear the placement policy and make localhost the default
        # placement, while the main thread's default placement would be the
        # master worker's CPU:0.
        try:
          with ops.device(self._default_device):
            with checkpoint_context.async_metrics_context():
              if self._use_checkpoint_save:
                self._checkpoint.save(self._save_file_prefix,
                                      self._checkpoint_options)
              else:
                self._checkpoint._write(  # pylint: disable=protected-access
                    self._save_file_prefix,
                    options=self._checkpoint_options,
                    write_done_callback=self._async_write_done_callback)
        except Exception as e:  # pylint: disable=broad-except
          # Defer the error to the main thread; it is re-raised by
          # _check_async_thread_error() on the next save/write/join.
          self._async_error = e
        finally:
          # Hand the CPU copies back to the main thread even on failure.
          self._writer_sem.release()

        async_save_end_time = time.time()
        metrics.AddAsyncCheckpointWriteDuration(
            api_label=_ASYNC_CHECKPOINT,
            microseconds=_get_duration_microseconds(async_save_start_time,
                                                    async_save_end_time))

        # Measure the elapsed time since the last checkpoint.
        # Due to the nature of async checkpoint, here it actually captures the
        # duration between the start_time of the previous checkpoint and the
        # start time of this checkpoint. As a result, the duration of the final
        # async checkpoint is excluded, which is fine since it does not take
        # much time.
        global _END_TIME_OF_LAST_ASYNC_WRITE
        with _END_TIME_OF_LAST_ASYNC_WRITE_LOCK:
          metrics.AddTrainingTimeSaved(
              api_label=_ASYNC_CHECKPOINT,
              microseconds=_get_duration_microseconds(
                  _END_TIME_OF_LAST_ASYNC_WRITE, async_save_start_time))
          _END_TIME_OF_LAST_ASYNC_WRITE = async_save_start_time
    logging.info("Async save thread reached the end of the execution.")

  def _copy_for_variable(self, original_var):
    """Create a new instance for the input trackable.

    Args:
      original_var: Input Variable object to be copied.
    """
    # Place the copy on the CPU of the same host as the original variable.
    op_device = pydev.DeviceSpec.from_string(original_var.device).replace(
        device_type="CPU", device_index=0).to_string()
    with ops.device(op_device):
      # UninitializedVariable avoids eagerly allocating/initializing memory;
      # values are filled in by _copy_to_cpu().
      new_var = UninitializedVariable(
          trainable=original_var.trainable,
          shape=original_var.shape,
          dtype=original_var.dtype,
          name=original_var._shared_name)  # pylint: disable=protected-access
    self._object_map[original_var] = new_var

  def _copy_for_sharded_variable(self, original_var):
    """Create a new instance for the input ShardedVariable.

    Args:
      original_var: Input ShardedVariable object to be copied.
    """
    copied_vars = []
    for v in original_var._variables:  # pylint: disable=protected-access
      self._copy_for_variable(v)
      copied_vars.append(self._object_map[v])
    self._object_map[original_var] = ShardedVariable(
        copied_vars, name=original_var.name)

  def _copy_trackable(self, original_trackable):
    """Create a new instance for the input trackable.

    Args:
      original_trackable: The trackable instance to be copied.

    Raises:
      AttributeError: if the input trackable is not Variable or ShardedVariable.
    """
    if isinstance(original_trackable, ShardedVariable):
      self._copy_for_sharded_variable(original_trackable)
    elif isinstance(original_trackable, Variable):
      self._copy_for_variable(original_trackable)
    else:
      raise AttributeError("Only Variable or ShardedVariable can be copied.")

  def _handle_tpu_embedding(self, tpu_embedding):
    """Handle TPUEmbedding.

    Args:
      tpu_embedding: TPUEmbedding object to be handled.

    Raises:
      AttributeError: if the input trackable is not TPUEmbedding type.
    """
    if not hasattr(
        tpu_embedding, _TPU_EMBEDDING_ATTR
    ) or not callable(tpu_embedding._create_copy_for_async_checkpoint):  # pylint: disable=protected-access
      raise AttributeError(
          "Expecting TPUEmbedding type; got %s" % type(tpu_embedding)
      )

    # Create a dummy TPUEmbedding object and add it to the object_map. This is
    # to prevent the TPUEmbedding's save_callback from being triggered because
    # the embedding values have already been retrieved by AsyncCheckpoint.
    # pylint: disable=protected-access
    new_embedding = tpu_embedding._create_copy_for_async_checkpoint(
        feature_config=tpu_embedding._feature_config,
        optimizer=tpu_embedding._table_config[0]
        if tpu_embedding._table_config
        else None,
        pipeline_execution_with_tensor_core=tpu_embedding._pipeline_execution_with_tensor_core,
    )
    self._object_map[tpu_embedding] = new_embedding
    # pylint: enable=protected-access

    if tpu_embedding not in self._tpu_embedding_objects:
      self._tpu_embedding_objects.append(tpu_embedding)

  @property
  def save_counter(self):
    """An integer variable numbering the checkpoint events.

    This is maintained by the underlying tf.train.Checkpoint object employed by
    AsyncCheckpoint class. The number starts at 0 and gets incremented for each
    checkpoint event.

    Returns:
      The save counter variable.
    """
    # TODO(sagunb): Improve the solution for initializing save_counter.
    # If save_counter() is called before all the variables are created,
    # self._ensure_initialized() would construct the object_map without some
    # variables that need to be checkpointed, e.g., slot variables.
    self._ensure_initialized()
    return self._checkpoint.save_counter

  def write(self, save_path, options=None):
    """Save the checkpointed variables.

    Args:
      save_path: The file prefix of the checkpoint file.
      options: Optional CheckpointOption instance.

    Returns:
      The full path of the checkpoint file.
    """
    # NOTE(review): _write() returns save_path but the result is discarded
    # here, so this method actually returns None despite the docstring —
    # confirm whether callers rely on the return value.
    self._write(save_path, options)

  def _write(self, save_path, options=None, write_done_callback=None):
    """Save the checkpointed variables.

    This method has exactly the same logic as save(), except it does not
    increment the underlying save_counter, which is done by the caller, e.g.,
    CheckpointManager.

    Args:
      save_path: The file prefix of the checkpoint file.
      options: Optional CheckpointOption instance.
      write_done_callback: Optional callback function executed after the async
        write is done.

    Returns:
      The full path of the checkpoint file.
    """
    self._ensure_initialized()

    write_start_time = time.time()

    # Copy the variable values to the host CPU.
    if self._writer_sem.acquire():
      self._copy_to_cpu()

    # Surface the error from the async thread, if any.
    # This step should come after the sem acquisition step in the above, so
    # that it makes sure it waits until the previous async save finishes
    # storing the error.
    self._check_async_thread_error()

    # Trigger the async thread to checkpoint the cpu-copied variables.
    # Need to wait until the weight copying finishes before checkpoint save.
    context.async_wait()
    self._save_file_prefix = save_path
    self._use_checkpoint_save = False

    # Ensure that we do not request async checkpointing to the underlying
    # checkpointer as this could lead to an infinite loop.
    self._checkpoint_options = copy.copy(options) if options else None
    if self._checkpoint_options:
      self._checkpoint_options.experimental_enable_async_checkpoint = False

    self._async_write_done_callback = write_done_callback
    self._reader_sem.release()

    write_end_time = time.time()
    metrics.AddCheckpointWriteDuration(
        api_label=_ASYNC_CHECKPOINT,
        microseconds=_get_duration_microseconds(write_start_time,
                                                write_end_time))

    return save_path

  def save(self, save_path, options=None):
    """Save the checkpointed variables.

    Args:
      save_path: The file prefix of the checkpoint file.
      options: Optional CheckpointOption instance.

    Returns:
      The full path of the checkpoint file.
    """
    # If this is the first time that AsyncCheckpoint.save() is called,
    # initialize the cpu-copied variables and create the pair-wise mapping
    # between the original model variables and the cpu-copied variables.
    #
    # This is not performed in the initializer because some variables, e.g.,
    # slot variables of the optimizer, were not created until actually running
    # the train function, so we could only get the complete list of the
    # variables after some train steps were run.
    self._ensure_initialized()

    save_start_time = time.time()

    # Copy the variable values to the host CPU.
    if self._writer_sem.acquire():
      self._copy_to_cpu()

    # Surface the error from the async thread, if any.
    # This step should come after the sem acquisition step in the above, so
    # that it makes sure it waits until the previous async save finishes
    # storing the error.
    self._check_async_thread_error()

    # Retrieve the save counter from the underlying checkpoint object to
    # re-construct the full path of the checkpoint file.
    # This step has to happen before triggering the underlying checkpoint;
    # otherwise, the save_counter value may or may not have been updated.
    save_counter = self._checkpoint.save_counter.numpy() + 1
    full_path = "{}-{}".format(save_path, save_counter)

    # Trigger the async thread to checkpoint the cpu-copied variables.
    # Need to wait until the weight copying finishes before checkpoint save.
    context.async_wait()
    self._save_file_prefix = save_path
    self._use_checkpoint_save = True

    # Ensure that we do not request async checkpointing to the underlying
    # checkpointer as this could lead to an infinite loop.
    self._checkpoint_options = copy.copy(options) if options else None
    if self._checkpoint_options:
      self._checkpoint_options.experimental_enable_async_checkpoint = False

    self._reader_sem.release()

    save_end_time = time.time()
    metrics.AddCheckpointWriteDuration(
        api_label=_ASYNC_CHECKPOINT,
        microseconds=_get_duration_microseconds(save_start_time, save_end_time))

    return full_path

  def read(self, save_path, options=None):
    """Restore the checkpointed variables.

    This method has exactly the same logic as restore(). This method is
    implemented only to fulfill the duty of subclassing tf.train.Checkpoint.

    Args:
      save_path: The full name of the checkpoint file to be restored.
      options: CheckpointOption instance.

    Returns:
      A load status object, which can be used to make assertions about the
      status of a checkpoint restoration. See tf.train.Checkpoint.restore()
      for more details.
    """
    return self.restore(save_path, options)

  def restore(self, save_path, options=None):
    """Restore the checkpointed variables.

    Args:
      save_path: The full name of the checkpoint file to be restored.
      options: CheckpointOption instance.

    Returns:
      A load status object, which can be used to make assertions about the
      status of a checkpoint restoration. See tf.train.Checkpoint.restore()
      for more details.
    """
    # Ensure that we do not request async checkpointing to the underlying
    # checkpointer as this could lead to an infinite loop.
    self._checkpoint_options = (
        copy.copy(options) if options else self._checkpoint_options)
    if self._checkpoint_options:
      self._checkpoint_options.experimental_enable_async_checkpoint = False

    # Wait for any ongoing checkpoint event to finish.
    with self._writer_sem:
      # If _checkpoint has not been initialized yet, it means the restore() is
      # called right after the coordinator is restarted. We directly restore
      # the checkpointed items through tf.train.Checkpoint.restore().
      if self._checkpoint is None:
        tmp_checkpoint = self._checkpointer_impl(**self._checkpoint_items)
        return tmp_checkpoint.restore(save_path, self._checkpoint_options)

      # Restore the values of the cpu-copied variables.
      status = self._checkpoint.restore(save_path, self._checkpoint_options)

      # Restore the values of the original model.
      self._copy_from_cpu()
      return status

  def sync(self):
    """Sync on any ongoing save or restore events."""
    # Acquiring the writer semaphore blocks until the async save thread has
    # released it, i.e. the in-flight checkpoint (if any) has finished.
    with self._writer_sem:
      logging.info("Sync on ongoing save/restore.")