Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/failure_handling/failure_handling.py: 21%

357 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Module for `PreemptionCheckpointHandler`. 

16 

17This is currently under development and the API is subject to change. 

18 

19PreemptionCheckpointHandler reduces loss of training progress caused by 

20termination (preemption or maintenance) of workers in multi-worker synchronous 

21training and avoids surfacing an error indistinguishable from application errors

22to the job scheduler or users. 

23""" 

24import os 

25import signal 

26import sys 

27import threading 

28import time 

29 

30from tensorflow.core.distributed_runtime.preemption import gen_check_preemption_op 

31from tensorflow.python.checkpoint import checkpoint as checkpoint_lib 

32from tensorflow.python.checkpoint import checkpoint_context 

33from tensorflow.python.checkpoint import checkpoint_management 

34from tensorflow.python.distribute import distribute_lib 

35from tensorflow.python.distribute import multi_worker_util 

36from tensorflow.python.distribute.failure_handling import failure_handling_util 

37from tensorflow.python.eager import context 

38from tensorflow.python.framework import constant_op 

39from tensorflow.python.framework import dtypes 

40from tensorflow.python.framework import errors 

41from tensorflow.python.lib.io import file_io 

42from tensorflow.python.ops import variables 

43from tensorflow.python.platform import gfile 

44from tensorflow.python.platform import tf_logging as logging 

45from tensorflow.python.util import tf_contextlib 

46from tensorflow.python.util.deprecation import deprecated 

47from tensorflow.python.util.tf_export import tf_export 

48from tensorflow.tools.docs import doc_controls 

49 

50 

51_INITIAL_RUN_COUNT_KEY = 'RUN_TO_CHECKPOINT' 

52_FINAL_RUN_COUNT_KEY = 'LAST_RUN_TO_CHECKPOINT' 

53# This key is used to guarantee that only one worker (and it's the earliest 

54# one that receives a preemption signal) sets _received_own_sigterm, 

55# leads the step resolution, and controls the grace period timeline. 

56_PREEMPTION_WORKER_KEY = 'TERMINATED_WORKER' 

57_ACKNOWLEDGE_KEY = 'RECEIVED_SIGNAL' 

58_ITERATION_VARIABLE = 'checkpointed_runs' 

59_STOP_WATCHING_CLUSTER_VALUE = 'STOP_WATCHER' 

60PREEMPTION_KEY = 'TF_DEFAULT_PREEMPTION_NOTICE_KEY' 

61 

62 

63# TODO(wxinyi): add type annotations. 

64def _non_chief_checkpoint_dir(checkpoint_dir, task_id): 

65 """Returns a directory for a non-chief worker to save checkpoints."""

66 dirpath = os.path.dirname(checkpoint_dir) 

67 base = os.path.basename(checkpoint_dir) 

68 base_dirpath = 'workertemp_' + str(task_id) 

69 dirpath = os.path.join(dirpath, base_dirpath) 

70 file_io.recursive_create_dir_v2(dirpath) 

71 return os.path.join(dirpath, base) 

72 

73 

74@tf_export('distribute.experimental.TerminationConfig', v1=[]) 

75class TerminationConfig(object): 

76 """Customization of `PreemptionCheckpointHandler` for various platforms. 

77 

78 A `TerminationConfig` can be created and passed to a 

79 `tf.distribute.experimental.PreemptionCheckpointHandler` to provide 

80 customization based on the platform. It can deliver three pieces of 

81 information: 

82 

83 * How to decide if there is a termination event soon 

84 

85 The form of termination notification and how to fetch it vary across 

86 platforms. Thus `PreemptionCheckpointHandler` may take a user-defined 

87 function, `termination_watcher_fn`, and execute it repeatedly to check for 

88 termination notification. `termination_watcher_fn` should be a function 

89 that returns `True` if a termination notification is available and 

90 `False` otherwise. The function should be lightweight and non-blocking so that 

91 resources can be cleaned up properly if no termination signal is ever raised 

92 before training finishes.

93 

94 * How to exit the program 

95 

96 A user can configure this through the `exit_fn`, which 

97 `PreemptionCheckpointHandler` executes after saving the checkpoint to exit the 

98 training program gracefully. For `tf.distribute.MultiWorkerMirroredStrategy`, 

99 a restart is necessary to reset the program's state. However, having a 

100 customized `exit_fn` may facilitate the restart and smooth the training

101 experience. How so? Maybe the platform recognizes a `RESTART_CODE` as a

102 program auto-restart signal, or maybe the user has a coordinating script

103 that starts up the training, in which they can configure the program to

104 auto-restart if it ever exits with this `RESTART_CODE`. In both

105 cases, configuring the `exit_fn` to be `sys.exit(RESTART_CODE)` makes the 

106 training seamless. 

107 

108 * How long does `PreemptionCheckpointHandler` have from receiving a 

109 termination event notice till the actual termination 

110 

111 Some platforms have a gap time as long as one hour or so. In these cases, 

112 there is the option to utilize this gap time for training as much as possible 

113 before saving a checkpoint and exiting. This can be achieved by passing the 

114 `grace_period` argument a nonzero value. Note: for a user whose grace period

115 is not several times longer than their checkpoint writing time (e.g., at

116 least three times longer), we advise not configuring this argument, in which

117 case `PreemptionCheckpointHandler` will directly save a checkpoint and exit.
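
Below is an illustrative sketch of how these pieces can be configured
together; `my_watcher_fn`, the one-hour grace period, and the
`checkpoint_manager` used in `save_fn` are example placeholders rather than
platform defaults:

```python
config = tf.distribute.experimental.TerminationConfig(
    termination_watcher_fn=my_watcher_fn,  # non-blocking; returns True/False
    grace_period=3600,  # seconds between the notice and the actual termination
    save_fn=lambda: checkpoint_manager.save(check_interval=False))
```

The resulting `config` is then passed to
`tf.distribute.experimental.PreemptionCheckpointHandler` as its
`termination_config` argument.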

118 

119 

120 **The default behavior**: 

121 

122 * For Google Borg Platform: 

123 * Automatically know how to detect preemption signal 

124 * Exit with a platform-recognized restart code 

125 * Save a checkpoint and exit immediately 

126 

127 * For Google Cloud Platform: 

128 * Automatically know how to detect maintenance signal. 

129 * Exit with a code (User may configure this) 

130 * Automatically utilize the extended training period before saving and exiting

131 

132 * For other platforms:

133 * If `termination_watcher_fn` is `None`, we will treat `signal.SIGTERM` as 

134 a termination signal. 

135 * If `exit_fn` is not configured, we exit the program with an arbitrary 

136 code. 

137 * If `grace_period` is not configured, we will wrap up the current 

138 training step, save a checkpoint, and exit the program as soon as we 

139 receive the termination signal. 

140 """ 

141 

142 def __init__(self, 

143 termination_watcher_fn=None, 

144 exit_fn=None, 

145 grace_period=None, 

146 save_fn=None): 

147 """Creates a `TerminationConfig` object. 

148 

149 Args: 

150 termination_watcher_fn: a function to execute repeatedly that returns 

151 `True` if a preemption signal is available and False otherwise. The 

152 function must not block waiting for a preemption signal, since that

153 would prevent proper cleanup of the program. A change is **NOT** recommended

154 for users on Google Borg or Google Cloud Platform. 

155 exit_fn: a function to execute after a checkpoint is saved and before the 

156 preemption happens. Usually, it should be in the form of 

157 `lambda: sys.exit(RESTART_CODE)`, where `RESTART_CODE` varies by 

158 platform. A change is **NOT** recommended for users on Google Borg. 

159 Users on Google Cloud Platform may configure it to use a customized 

160 `RESTART_CODE`. 

161 grace_period: the length of time between receiving a preemption signal and 

162 the actual preemption. A change is **NOT** recommended for users on 

163 Google Borg, Google Cloud Platform, or users with a short grace period. 

164 save_fn: an optional function letting you configure how to save a 

165 checkpoint. This is useful if you'd like to pass extra arguments to

166 `tf.train.CheckpointManager.save` or `tf.train.Checkpoint.save`. By

167 default, if not configured, the API will save the checkpoint without extra

168 arguments.

169 """ 

170 self.termination_watcher_fn = termination_watcher_fn 

171 self.exit_fn = exit_fn 

172 self.grace_period = grace_period 

173 self.save_fn = save_fn 

174 

175 

176# TODO(wxinyi): add some tests for TerminationConfig. 

177# TODO(wxinyi): configure the exit function based on device type (GPU or TPU). 

178class GcpGpuTerminationConfig(TerminationConfig): 

179 """Configurations for GCP GPU VM.""" 

180 

181 def __init__( # pylint: disable=super-init-not-called 

182 self, 

183 termination_watcher_fn=None, 

184 exit_fn=None, 

185 grace_period=None, 

186 save_fn=None, 

187 ): 

188 self.termination_watcher_fn = ( 

189 termination_watcher_fn 

190 or failure_handling_util.termination_watcher_function_gce 

191 ) 

192 self.exit_fn = exit_fn or failure_handling_util.gce_exit_fn 

193 self.grace_period = ( 

194 grace_period if grace_period or grace_period == 0 else 

195 failure_handling_util.GRACE_PERIOD_GCE) 

196 self.save_fn = save_fn 

197 

198 

199class GcpCpuTerminationConfig(TerminationConfig): 

200 """Configurations for GCP CPU VM.""" 

201 

202 def __init__( # pylint: disable=super-init-not-called 

203 self, 

204 termination_watcher_fn=None, 

205 exit_fn=None, 

206 grace_period=None, 

207 save_fn=None): 

208 self.termination_watcher_fn = termination_watcher_fn or failure_handling_util.termination_watcher_function_gce 

209 self.exit_fn = exit_fn or failure_handling_util.gce_exit_fn 

210 self.grace_period = grace_period or 0 

211 self.save_fn = save_fn 

212 

213 

214class BorgTerminationConfig(TerminationConfig): 

215 """Configurations for Borg.""" 

216 

217 def __init__( # pylint: disable=super-init-not-called 

218 self, 

219 termination_watcher_fn=None, 

220 exit_fn=None, 

221 grace_period=None, 

222 save_fn=None): 

223 self.termination_watcher_fn = termination_watcher_fn 

224 default_exit_fn = lambda: sys.exit(42) 

225 self.exit_fn = exit_fn or default_exit_fn 

226 self.grace_period = grace_period or 0 

227 self.save_fn = save_fn 

228 

229 

230class BorgTPUTerminationConfig(TerminationConfig): 

231 """Configurations for Borg TPU."""

232 

233 def __init__( # pylint: disable=super-init-not-called 

234 self, 

235 termination_watcher_fn=None, 

236 exit_fn=None, 

237 grace_period=None, 

238 save_fn=None): 

239 self.termination_watcher_fn = termination_watcher_fn 

240 self.exit_fn = exit_fn or failure_handling_util.default_tpu_exit_fn 

241 self.grace_period = grace_period or 0 

242 self.save_fn = save_fn 

243 

244 

245def _complete_config_for_environment(platform_device, termination_config): 

246 """Complete un-filled fields of TerminationConfig based on platform.""" 

247 if not termination_config: 

248 termination_config = TerminationConfig() 

249 

250 if platform_device is failure_handling_util.PlatformDevice.GCE_GPU: 

251 return GcpGpuTerminationConfig(termination_config.termination_watcher_fn, 

252 termination_config.exit_fn, 

253 termination_config.grace_period, 

254 termination_config.save_fn) 

255 

256 elif platform_device is failure_handling_util.PlatformDevice.GCE_CPU: 

257 return GcpCpuTerminationConfig(termination_config.termination_watcher_fn, 

258 termination_config.exit_fn, 

259 termination_config.grace_period, 

260 termination_config.save_fn) 

261 

262 elif platform_device is failure_handling_util.PlatformDevice.INTERNAL_TPU: 

263 return BorgTPUTerminationConfig(termination_config.termination_watcher_fn, 

264 termination_config.exit_fn, 

265 termination_config.grace_period, 

266 termination_config.save_fn) 

267 

268 else: 

269 # The defaults we chose are the same as the ones used by Borg, so we just

270 # return this.

271 return BorgTerminationConfig( 

272 termination_config.termination_watcher_fn, 

273 termination_config.exit_fn, termination_config.grace_period, 

274 termination_config.save_fn) 

275 

276 

277# TODO(wxinyi): add release updates. 

278# Implementation: 

279# Each worker will create its own PreemptionCheckpointHandler instance, and the 

280# instances communicate through coordination services. Each 

281# PreemptionCheckpointHandler conducts three tasks in parallel:

282# - Watches out for its own preemption signal. (_poll_termination_signal_thread) 

283# - Watches out for a step key from the coordination service made available 

284# by any member in the cluster (_cluster_wise_termination_watcher_thread) 

285# - The main thread for training. 

286# 

287# The life cycle of a PreemptionCheckpointHandler is as below: 

288# 

289# It starts two threads as the two watchers described above, and it starts

290# training. Each time before it starts a training step, it will check if any 

291# information has been made available by the two watchers: The 

292# _poll_termination_signal_thread will be in charge of the _received_own_sigterm 

293# event, the _cluster_wise_termination_watcher_thread will be in charge of the 

294# _received_checkpoint_step event. 

295# 

296# If at any point the local worker receives a preemption signal, 

297# _poll_termination_signal_thread will set _received_own_sigterm. 

298# Next time before it attempts to run a training step, it will deal with the 

299# event by setting its current finished step + 1 as the step after which a

300# checkpoint should be saved and making it available to all the workers through

301# the coordination service. It will then continue training. 

302# 

303# This step key will be picked up by the other watcher, 

304# _cluster_wise_termination_watcher_thread, both on the worker to be preempted 

305# and other workers. And it will set the _received_checkpoint_step event. 

306# Now, if there is a long grace period before the training 

307# has to terminate (e.g., an hour), we would like to keep training and save a 

308# checkpoint again right before the termination. Thus this watcher thread will 

309# move on to watch out for a final step-to-save key. Otherwise, 

310# it has finished all the tasks it needs to do.

311# 

312# Back to the main training thread. Again, before the next training step, the 

313# PreemptionCheckpointHandler finds that _received_checkpoint_step is set. If

314# the local worker has not finished the required step after which to save a

315# checkpoint, it will not do anything; it continues training and will revisit

316# this after another step. If the step is met, then it will save a checkpoint,

317# which requires participation of all workers. 

318# 

319# After this checkpoint is saved, if there is NO long grace period, all workers 

320# will just exit. If there is, all workers will enter a grace period countdown 

321# phase (_final_checkpoint_countdown) and clear the _received_checkpoint_step 

322# event. They will then continue training. 

323# 

324# For the worker to be preempted, during this countdown period, it will check

325# whether the grace period is almost ending before each of its steps. If not,

326# nothing needs to be done. If so, it will again set a step-to-save key and make

327# it available to all workers. This is still watched by

328# _cluster_wise_termination_watcher_thread and signaled by

329# _received_checkpoint_step. A similar process is repeated: all workers save 

330# a checkpoint at an agreed step. And after they finish saving, they recognize 

331# that they have finished a countdown period for an extended grace period, and 

332# they all exit. 

333# 

334# When the program restarts and a PreemptionCheckpointHandler object is created,

335# it will restore the checkpoint. 
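# A condensed, illustrative sketch of the config-key handshake described
# above (not executed here; it mirrors the real logic in
# _check_preemption_and_maybe_checkpoint and _watch_step_to_save_key below):
#
#   # On the worker that received the signal, before its next step:
#   context.context().set_config_key_value(
#       _INITIAL_RUN_COUNT_KEY, str(self._run_counter + 1))
#
#   # On every worker, inside _cluster_wise_termination_watcher_thread:
#   step_value = context.context().get_config_key_value(_INITIAL_RUN_COUNT_KEY)
#   self._step_to_checkpoint = step_value
#   self._received_checkpoint_step.set()
#
#   # Back in the training loop: once self._run_counter reaches that step,
#   # every worker saves a checkpoint and then either exits or starts the
#   # grace-period countdown.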

336@tf_export('distribute.experimental.PreemptionCheckpointHandler', v1=[]) 

337class PreemptionCheckpointHandler(object): 

338 # pylint: disable=line-too-long 

339 """Preemption and error handler for synchronous training. 

340 

341 Note: This API only supports use with 

342 `tf.distribute.MultiWorkerMirroredStrategy` and `tf.distribute.TPUStrategy`. 

343 

344 A `PreemptionCheckpointHandler` coordinates all workers to save a checkpoint 

345 upon receiving a preemption signal. It also helps disseminate application 

346 error messages accurately among the cluster. When a 

347 `PreemptionCheckpointHandler` object is created, it restores values from 

348 the latest checkpoint file if any exists. 

349 

350 Right after the initialization, the object starts to watch out for termination 

351 signals for any member in the cluster. If a signal is received, the next time the

352 worker executes `PreemptionCheckpointHandler.run`, the 

353 `PreemptionCheckpointHandler` will align all workers to save a checkpoint. 

354 Then, if an `exit_fn` is configured via 

355 `tf.distribute.experimental.TerminationConfig`, it will be invoked. Otherwise, 

356 the process will simply exit and later the platform should restart it. 

357 

358 Note: We advise users of `tf.distribute.MultiWorkerMirroredStrategy` who 

359 choose to configure their 

360 own `exit_fn` in `tf.distribute.experimental.TerminationConfig` to include a 

361 `sys.exit(CODE_OR_MESSAGE)` in the `exit_fn` so that after the restart, all 

362 workers can initialize communication services correctly. For users of 

363 `tf.distribute.TPUStrategy`, if they do not wish to do a cluster restart but 

364 would like an in-process restart (i.e., keep the coordinator alive and re-do 

365 the steps to connect to cluster, initialize TPU system, and make the 

366 `TPUStrategy` object), they could configure the `exit_fn` to a no-op. 
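
A minimal sketch of the two choices described in this note; `RESTART_CODE` is
a placeholder for whatever value your platform treats as a restart request:

```python
# MultiWorkerMirroredStrategy: include sys.exit so that, after the restart,
# workers can initialize communication services correctly.
mwms_termination_config = tf.distribute.experimental.TerminationConfig(
    exit_fn=lambda: sys.exit(RESTART_CODE))

# TPUStrategy with an in-process restart: make exit_fn a no-op and redo the
# cluster connection, TPU initialization, and TPUStrategy creation yourself.
tpu_termination_config = tf.distribute.experimental.TerminationConfig(
    exit_fn=lambda: None)
```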

367 

368 For users of `tf.distribute.MultiWorkerMirroredStrategy`, the core API is 

369 `PreemptionCheckpointHandler.run`: 

370 

371 ```python 

372 strategy = tf.distribute.MultiWorkerMirroredStrategy() 

373 

374 trained_epoch = tf.Variable(initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='epoch') 

375 step_in_epoch = tf.Variable(initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='step_in_epoch') 

376 

377 with strategy.scope(): 

378 dataset, model, optimizer = ... 

379 

380 checkpoint = tf.train.Checkpoint(optimizer=optimizer, 

381 model=model, 

382 trained_epoch=trained_epoch, 

383 step_in_epoch=step_in_epoch) 

384 

385 preemption_checkpoint_handler = tf.distribute.experimental.PreemptionCheckpointHandler(cluster_resolver, checkpoint, checkpoint_dir) 

386 

387 while trained_epoch.numpy() < NUM_EPOCH: 

388 

389 while step_in_epoch.numpy() < STEPS_PER_EPOCH: 

390 

391 # distributed_train_function contains a call to strategy.run. 

392 loss += preemption_checkpoint_handler.run(distributed_train_function, args=(next(iterator),)) 

393 # For users of MultiWorkerMirroredStrategy, usually 

394 # STEPS_PER_TRAIN_FUNCTION = 1. 

395 step_in_epoch.assign_add(STEPS_PER_TRAIN_FUNCTION) 

396 ... 

397 

398 trained_epoch.assign_add(1)

399 step_in_epoch.assign(0) 

400 ``` 

401 

402 For users of `tf.distribute.TPUStrategy`, the core APIs are 

403 `PreemptionCheckpointHandler.run` and 

404 `PreemptionCheckpointHandler.watch_preemption_scope`: 

405 

406 ```python 

407 

408 strategy = tf.distribute.TPUStrategy(tpu_cluster_resolver) 

409 

410 # Rest of TPU init omitted, see documentation for TPUStrategy.

411 

412 with preemption_checkpoint_handler.watch_preemption_scope(): 

413 while trained_epoch.numpy() < NUM_EPOCH: 

414 

415 while step_in_epoch.numpy() < STEPS_PER_EPOCH: 

416 

417 # distributed_train_function contains a call to strategy.run. 

418 loss += preemption_checkpoint_handler.run(distributed_train_function, args=(next(iterator),)) 

419 

420 # For users of TPUStrategy, usually STEPS_PER_TRAIN_FUNCTION >> 1 since 

421 # clustering multiple steps within a tf.function amortizes the overhead 

422 # of launching a multi-device function on TPU Pod. 

423 step_in_epoch.assign_add(STEPS_PER_TRAIN_FUNCTION) 

424 ... 

425 

426 trained_epoch.assign_add(1)

427 step_in_epoch.assign(0) 

428 ``` 

429 

430 Not all interruptions come with an advance notice that allows the

431 `PreemptionCheckpointHandler` to handle them, e.g., those caused by hardware

432 failure. For a user who saves checkpoints for these cases themselves outside 

433 the `PreemptionCheckpointHandler`, if they are using a 

434 `tf.train.CheckpointManager`, pass it as the 

435 `checkpoint_or_checkpoint_manager` argument to the 

436 `PreemptionCheckpointHandler`. If they do not have a 

437 `tf.train.CheckpointManager` but are directly working with 

438 `tf.train.Checkpoint`, we advise saving the checkpoints in the directory 

439 that's passed as the `checkpoint_dir` argument. In this way, at the program 

440 beginning, `PreemptionCheckpointHandler` can restore the latest checkpoint 

441 from the directory, no matter whether it was saved by the user themselves or by

442 the `PreemptionCheckpointHandler` before preemption happens. 

443 

444 **A note on the platform:** 

445 

446 `PreemptionCheckpointHandler` can only handle the kind of termination with 

447 advance notice. For now, the API recognizes the termination signal for CPU, 

448 GPU, and TPU on Google Borg and CPU and GPU on the Google Cloud Platform. In 

449 these cases, `PreemptionCheckpointHandler` will automatically adopt the 

450 correct preemption/maintenance notification detection mechanism. Users of 

451 other platforms can configure a detection monitoring behavior through the 

452 `tf.distribute.experimental.TerminationConfig`. Customization for the exit 

453 behavior and grace period length could also be done here. 

454 """ 

455 # pylint: enable=line-too-long 

456 

457 def __init__(self, 

458 cluster_resolver, 

459 checkpoint_or_checkpoint_manager, 

460 checkpoint_dir=None, 

461 termination_config=None): 

462 """Creates the `PreemptionCheckpointHandler`. 

463 

464 Args: 

465 cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver` 

466 object. You may also obtain it through the `cluster_resolver` attribute 

467 of the distribution strategy in use. 

468 checkpoint_or_checkpoint_manager: a `tf.train.CheckpointManager` or a 

469 `tf.train.Checkpoint`. If you are using a `tf.train.CheckpointManager` 

470 to manage checkpoints outside the `PreemptionCheckpointHandler` for 

471 backup purpose as well, pass it as `checkpoint_or_checkpoint_manager` 

472 argument. Otherwise, pass a `tf.train.Checkpoint` and the 

473 `PreemptionCheckpointHandler` will create 

474 a `tf.train.CheckpointManager` to manage it in the `checkpoint_dir`. 

475 checkpoint_dir: a directory where the `PreemptionCheckpointHandler` saves 

476 and restores checkpoints. When a `PreemptionCheckpointHandler` is 

477 created, the latest checkpoint in the `checkpoint_dir` will be restored. 

478 (This is not needed if a `tf.train.CheckpointManager` instead of a 

479 `tf.train.Checkpoint` is passed as the 

480 `checkpoint_or_checkpoint_manager` argument.) 

481 termination_config: optional, a 

482 `tf.distribute.experimental.TerminationConfig` object to configure for a 

483 platform other than Google Borg or GCP. 
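
An illustrative sketch of the two ways to pass the second argument;
`cluster_resolver`, `checkpoint`, and `checkpoint_dir` are placeholders:

```python
# Option 1: a tf.train.Checkpoint plus an explicit checkpoint_dir.
handler = tf.distribute.experimental.PreemptionCheckpointHandler(
    cluster_resolver, checkpoint, checkpoint_dir)

# Option 2: a tf.train.CheckpointManager that already owns a directory, in
# which case checkpoint_dir is not needed.
manager = tf.train.CheckpointManager(
    checkpoint, directory=checkpoint_dir, max_to_keep=1)
handler = tf.distribute.experimental.PreemptionCheckpointHandler(
    cluster_resolver, manager)
```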

484 """ 

485 # TODO(wxinyi): Maybe make checkpoint_or_checkpoint_manager optional if 

486 # save_fn is passed. For now it's still useful for restore. 

487 if isinstance(checkpoint_or_checkpoint_manager, 

488 checkpoint_lib.Checkpoint) and not checkpoint_dir: 

489 raise errors.InvalidArgumentError('When a checkpoint is passed, a ' 

490 'checkpoint_dir must be passed as well' 

491 '.') 

492 

493 self._cluster_resolver = cluster_resolver 

494 self._termination_config = termination_config 

495 self._checkpoint_or_checkpoint_manager = checkpoint_or_checkpoint_manager 

496 self._checkpoint_dir = checkpoint_dir 

497 

498 self._platform_device = failure_handling_util.detect_platform() 

499 

500 completed_termination_config = _complete_config_for_environment( 

501 self._platform_device, self._termination_config 

502 ) 

503 self._termination_watcher_fn = ( 

504 completed_termination_config.termination_watcher_fn 

505 ) 

506 self._exit_fn = completed_termination_config.exit_fn 

507 self._grace_period = completed_termination_config.grace_period 

508 self._save_fn = completed_termination_config.save_fn 

509 

510 self._local_mode = True 

511 if self._platform_device in ( 

512 failure_handling_util.PlatformDevice.GCE_TPU, 

513 failure_handling_util.PlatformDevice.GCE_CPU, 

514 ): 

515 # While running MultiWorkerMirroredStrategy training with GPUs and with

516 # CPUs is the same on Borg, GCE CPU VMs and GPU VMs differ in terms

517 # of live migration, grace period, etc. We can make it work upon request.

518 logging.warning( 

519 'PreemptionCheckpointHandler does not support usage with ' 

520 'TPU or CPU device on GCP.' 

521 ) 

522 

523 elif ( 

524 self._platform_device 

525 == failure_handling_util.PlatformDevice.INTERNAL_TPU 

526 ): 

527 self._initialize_for_tpu_strategy() 

528 

529 else: 

530 if cluster_resolver and 'ps' in cluster_resolver.cluster_spec().as_dict(): 

531 raise NotImplementedError( 

532 'PreemptionCheckpointHandler does not support '

533 'usage with tf.distribute.experimental.ParameterServerStrategy.' 

534 ) 

535 

536 self._initialize_for_mirrored_and_multi_worker_mirrored() 

537 

538 logging.info('PreemptionCheckpointHandler initialized or restored.') 

539 

540 def _initialize_for_tpu_strategy(self): 

541 """Makes configurations for using the handler with TPUStrategy.""" 

542 self._is_chief = True 

543 self._poll_termination_signal_thread = None 

544 self._cluster_wise_termination_watcher_thread = None 

545 self._maybe_create_checkpoint_manager() 

546 self._read_checkpoint_manager.restore_or_initialize() 

547 self._run_counter = 0 

548 

549 def _initialize_for_mirrored_and_multi_worker_mirrored(self): 

550 """Makes configurations and starts watchers for MS, MWMS, or OneDevice."""

551 if ( 

552 not self._cluster_resolver 

553 or not self._cluster_resolver.cluster_spec().jobs 

554 ): 

555 # For MirroredStrategy, OneDeviceStrategy, and local-mode 

556 # MultiWorkerMirroredStrategy, an empty cluster spec is passed, and 

557 # coordination service is not enabled nor is it needed (since 

558 # it's used for cross-worker communication). Thus we will directly set

559 # the worker id and is_chief properties and also skip the 

560 # uploading/reading from coordination service logic. 

561 self._local_mode = True 

562 self._id_in_cluster = 'single_worker' 

563 self._is_chief = True 

564 else: 

565 self._local_mode = False 

566 self._id_in_cluster = str( 

567 multi_worker_util.id_in_cluster( 

568 self._cluster_resolver.cluster_spec(), 

569 self._cluster_resolver.task_type, 

570 self._cluster_resolver.task_id)) 

571 self._is_chief = multi_worker_util.is_chief( 

572 cluster_spec=self._cluster_resolver.cluster_spec(), 

573 task_type=self._cluster_resolver.task_type, 

574 task_id=self._cluster_resolver.task_id) 

575 # The number of calls to `PreemptionCheckpointHandler.run` when the latest 

576 # checkpoint was saved. 

577 self._checkpointed_runs = variables.Variable( 

578 initial_value=constant_op.constant(0, dtype=dtypes.int64), 

579 trainable=False, 

580 name=_ITERATION_VARIABLE) 

581 

582 self._maybe_create_checkpoint_manager() 

583 

584 if not hasattr(self._write_checkpoint_manager._checkpoint, # pylint: disable=protected-access 

585 _ITERATION_VARIABLE): 

586 setattr(self._write_checkpoint_manager._checkpoint, _ITERATION_VARIABLE, # pylint: disable=protected-access 

587 self._checkpointed_runs) 

588 

589 if not hasattr(self._read_checkpoint_manager._checkpoint, # pylint: disable=protected-access 

590 _ITERATION_VARIABLE): 

591 setattr(self._read_checkpoint_manager._checkpoint, _ITERATION_VARIABLE, # pylint: disable=protected-access 

592 self._checkpointed_runs) 

593 

594 self._read_checkpoint_manager.restore_or_initialize() 

595 

596 # grace period countdown. Set to True for all workers once they finish 

597 # saving and timing a checkpoint. Once entering this phase, new

598 # preemption/maintenance notices will not be handled, since the whole cluster

599 # goes down as the worker who first initiates the grace period goes down. 

600 self._final_checkpoint_countdown = False 

601 

602 self._estimated_run_time = 0 

603 

604 # An internal step counter that's restored to the checkpointed_runs value when

605 # training is restored. It increments by one every time 

606 # `PreemptionCheckpointHandler.run` is called. Note that in this case, the 

607 # user must pass a single-step training function to 

608 # `PreemptionCheckpointHandler.run` instead of a multiple-step one. 

609 self._run_counter = self._checkpointed_runs.numpy() 

610 

611 # The worker itself has received a preemption signal.

612 self._received_own_sigterm = threading.Event() 

613 

614 # Some member (could be oneself) has received a preemption signal, and the

615 # step number to save a checkpoint has been aligned. 

616 self._received_checkpoint_step = threading.Event() 

617 

618 distribute_lib.distribution_strategy_input_api_counter.get_cell( 

619 self._platform_device.name, 

620 'PreemptionCheckpointHandler').increase_by(1) 

621 

622 if not self._local_mode: 

623 # When training is interrupted, we explicitly call the cleanup methods for 

624 # the thread watching for local worker's termination signal and the thread 

625 # watching for clusterwise information before we save a checkpoint and 

626 # exit. In the final phase of the training where no interruption is

627 # encountered, we rely on __del__ to clean up. However, there is no

628 # guarantee when or whether __del__ is executed, thus we make the threads

629 # daemon threads to avoid them preventing the program from exiting.

630 self._cluster_wise_termination_watcher_thread = threading.Thread( 

631 target=self._watch_step_to_save_key, 

632 name='PeerTerminationWatcher-%s' % self._id_in_cluster, 

633 daemon=True) 

634 logging.info('Start watcher for peer\'s signal.') 

635 self._cluster_wise_termination_watcher_thread.start() 

636 

637 else: 

638 self._cluster_wise_termination_watcher_thread = None 

639 

640 self._poll_termination_signal_thread = None 

641 

642 if self._termination_watcher_fn: 

643 self._start_polling_for_termination_signal() 

644 else: 

645 self._start_watching_for_signal() 

646 

647 def _maybe_create_checkpoint_manager(self): 

648 """Creates CheckpointManager(s) if a checkpoint is passed, else uses the given one."""

649 if isinstance(self._checkpoint_or_checkpoint_manager, 

650 checkpoint_management.CheckpointManager): 

651 self._read_checkpoint_manager = self._checkpoint_or_checkpoint_manager 

652 self._write_checkpoint_manager = self._checkpoint_or_checkpoint_manager 

653 self._api_made_checkpoint_manager = False 

654 else: 

655 self._api_made_checkpoint_manager = True 

656 # Make CheckpointManagers. MultiWorkerMirroredStrategy requires different 

657 # setup on chief and on other workers. 

658 self._read_checkpoint_manager = checkpoint_management.CheckpointManager( 

659 self._checkpoint_or_checkpoint_manager, 

660 directory=self._checkpoint_dir, 

661 max_to_keep=1) 

662 

663 if self._is_chief: 

664 self._write_checkpoint_manager = self._read_checkpoint_manager 

665 else: 

666 self._write_checkpoint_manager = ( 

667 checkpoint_management.CheckpointManager( 

668 self._checkpoint_or_checkpoint_manager, 

669 _non_chief_checkpoint_dir(self._checkpoint_dir, 

670 self._cluster_resolver.task_id), 

671 max_to_keep=1)) 

672 

673 def _start_watching_for_signal(self): 

674 logging.info('Start watcher for local signal.') 

675 signal.signal(signal.SIGTERM, self._sigterm_handler_fn) 

676 

677 def _start_polling_for_termination_signal(self): 

678 self._poll_termination_signal_thread_should_stop = threading.Event() 

679 self._poll_termination_signal_thread = threading.Thread( 

680 target=self._poll_termination_signal, 

681 name='WorkerTerminationSignalWatcher-%s' % self._id_in_cluster, 

682 daemon=True) 

683 logging.info('Start polling for termination signal.') 

684 self._poll_termination_signal_thread.start() 

685 

686 def _poll_termination_signal(self): 

687 """Poll maintenance notice and notify peers if receiving one.""" 

688 while True: 

689 if self._poll_termination_signal_thread_should_stop.is_set( 

690 ) or self._final_checkpoint_countdown: 

691 return 

692 if self._termination_watcher_fn(): 

693 break 

694 time.sleep(1) 

695 

696 self._maybe_set_received_own_sigterm() 

697 

698 def _maybe_set_received_own_sigterm(self): 

699 """Claim earliest preemption if no one else has done it before.""" 

700 if self._local_mode: 

701 logging.info('Member %s has received termination notice.',

702 self._id_in_cluster) 

703 self._received_own_sigterm_time = time.time() 

704 self._received_own_sigterm.set() 

705 return 

706 

707 try: 

708 context.context().set_config_key_value(_PREEMPTION_WORKER_KEY, 

709 self._id_in_cluster) 

710 logging.info('Member %s has received termination notice.', 

711 self._id_in_cluster) 

712 self._received_own_sigterm_time = time.time() 

713 self._received_own_sigterm.set() 

714 

715 # This is to handle the case that a worker has received termination 

716 # notice but hasn't come to the next step to set the step key. Other 

717 # workers might receive a termination notice too, and attempt to set the 

718 # config key again, which causes this error. This can be safely ignored 

719 # since the checkpoint will be saved according to the earliest call made.

720 except errors.AlreadyExistsError: 

721 logging.info( 

722 ( 

723 'Member %s has received termination notice. But some other ' 

724 'worker has received it as well! Leaving' 

725 ' it to them to decide when to checkpoint. ' 

726 ), 

727 self._id_in_cluster, 

728 ) 

729 return 

730 

731 def _stop_poll_termination_signal_thread(self): 

732 if getattr(self, '_poll_termination_signal_thread', None): 

733 self._poll_termination_signal_thread_should_stop.set() 

734 self._poll_termination_signal_thread.join() 

735 

736 self._poll_termination_signal_thread = None 

737 logging.info("Shut down watcher for one's own termination signal") 

738 

739 def _stop_cluster_wise_termination_watcher_thread(self): 

740 """Stop the thread that runs _watch_step_to_save_key."""

741 if getattr(self, '_cluster_wise_termination_watcher_thread', None): 

742 try: 

743 context.context().set_config_key_value( 

744 _INITIAL_RUN_COUNT_KEY, _STOP_WATCHING_CLUSTER_VALUE 

745 ) 

746 except (errors.AlreadyExistsError, errors.UnavailableError): 

747 # We'll ignore any error in the process of setting this key. There 

748 # certainly will be an AlreadyExistsError since all workers are trying to

749 # push this key. Or some worker might have exited already, leading to an

750 # errors.UnavailableError or errors.AbortedError.

751 pass 

752 except Exception as e: # pylint: disable=broad-except 

753 # We'll also ignore other errors since they are not important to the 

754 # process. 

755 logging.info('Ignoring error when shutting down ' 

756 '_stop_cluster_wise_termination_watcher_thread: ' + str(e)) 

757 

758 try: 

759 context.context().set_config_key_value(_FINAL_RUN_COUNT_KEY, 

760 _STOP_WATCHING_CLUSTER_VALUE) 

761 except (errors.AlreadyExistsError, errors.UnavailableError): 

762 pass 

763 

764 except Exception as e: # pylint: disable=broad-except 

765 logging.info('Ignoring error when shutting down ' 

766 '_stop_cluster_wise_termination_watcher_thread: ' + str(e)) 

767 

768 finally: 

769 self._cluster_wise_termination_watcher_thread.join() 

770 self._cluster_wise_termination_watcher_thread = None 

771 logging.info('Shut down watcher for peer\'s termination signal.') 

772 

773 def __del__(self): 

774 self._stop_cluster_wise_termination_watcher_thread() 

775 self._stop_poll_termination_signal_thread() 

776 

777 @property 

778 @deprecated(None, 

779 'Track steps using a tf.Variable saved in checkpoint instead.') 

780 @doc_controls.do_not_generate_docs 

781 def total_run_calls(self): 

782 """Returns the number of times `PreemptionCheckpointHandler.run` is called. 

783 

784 DEPRECATED: user should track total steps themselves, as this API provides 

785 little expressivity gain but could easily be misused and incurs extra 

786 synchronization cost for TPUStrategy users. 

787 

788 This value tracks the number of all calls to 

789 `PreemptionCheckpointHandler.run` including those before the program is 

790 restarted and the training is restored, by saving and reading the value in 

791 the checkpoint. A user can compute their total number of iterations 

792 by `PreemptionCheckpointHandler.total_run_calls * 

793 number_of_steps_in_train_function`, 

794 where `number_of_steps_in_train_function` should be one for

795 `tf.distribute.MultiWorkerMirroredStrategy` users. They can also use this 

796 value to infer the starting epoch and step after training restores, as shown 

797 in the example above. 
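
An illustrative sketch of that computation; `handler` and `STEPS_PER_EPOCH`
are user-defined placeholders, and the factor of 1 reflects the single-step
recommendation for `tf.distribute.MultiWorkerMirroredStrategy`:

```python
total_steps = handler.total_run_calls * 1  # steps per train function
start_epoch = total_steps // STEPS_PER_EPOCH
start_step_in_epoch = total_steps % STEPS_PER_EPOCH
```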

798 """ 

799 if (self._platform_device == 

800 failure_handling_util.PlatformDevice.INTERNAL_TPU): 

801 raise NotImplementedError('Please create variables saved in checkpoint ' 

802 'to keep track of steps and epochs.') 

803 return self._run_counter 

804 

805 def run(self, 

806 distributed_train_function, 

807 *args, 

808 **kwargs): 

809 """Runs a training function with error and preemption handling. 

810 

811 This function handles the preemption signal from any peer in the cluster by 

812 saving the training progress and exiting gracefully. It will 

813 also broadcast any program error encountered during the execution of

814 `distributed_train_function` to all workers so that they can raise the same 

815 error. 

816 

817 The `distributed_train_function` argument should be a distributed train 

818 function (i.e., containing a call to `tf.distribute.Strategy.run`). For 

819 `tf.distribute.MultiWorkerMirroredStrategy` users, we recommend passing in a 

820 single-step `distributed_train_function` to 

821 `PreemptionCheckpointHandler.run` so that the checkpoint can be saved in 

822 time in case a preemption signal or maintenance notice is sent. 

823 

824 Besides the preemption and error handling part, 

825 `PreemptionCheckpointHandler.run(distributed_train_function, *args, 

826 **kwargs)` has the same effect and output as 

827 `distributed_train_function(*args, **kwargs)`. `distributed_train_function` 

828 can return either some or no result. The following is a shortened example: 

829 

830 ```python 

831 

832 @tf.function 

833 def distributed_train_step(iterator): 

834 # A distributed single-step training function. 

835 

836 def step_fn(inputs): 

837 # A per-replica single-step training function. 

838 x, y = inputs 

839 ... 

840 return loss 

841 

842 per_replica_losses = strategy.run(step_fn, args=(next(iterator),)) 

843 return strategy.reduce( 

844 tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None) 

845 

846 for epoch in range(preemption_handler.total_run_calls // STEPS_PER_EPOCH, 

847 EPOCHS_TO_RUN): 

848 iterator = iter(multi_worker_dataset) 

849 total_loss = 0.0 

850 num_batches = 0 

851 

852 for step in range(preemption_handler.total_run_calls % STEPS_PER_EPOCH, 

853 STEPS_PER_EPOCH): 

854 total_loss += preemption_handler.run(distributed_train_step) 

855 num_batches += 1 

856 

857 train_loss = total_loss / num_batches 

858 print('Epoch: %d, train_loss: %f.' % (epoch, train_loss))

859 

860 train_accuracy.reset_states() 

861 ``` 

862 

863 Args: 

864 distributed_train_function: A (single-step) distributed training function. 

865 *args: args for `distributed_train_function`. 

866 **kwargs: kwargs for `distributed_train_function`. 

867 

868 Raises: 

869 Program error encountered by any member in the cluster while executing the 

870 `distributed_train_function`, or any error from the program error 

871 propagation process. 

872 

873 Returns: 

874 Result of running the `distributed_train_function`. 

875 """ 

876 # TODO(wxinyi): after we support use with TPUStrategy, we should expand the 

877 # API doc to state that `distributed_train_function` does not need to be a 

878 # single-step training function, since a multi-step host-training loop is 

879 # the dominant use case for TPU user. Besides, passing in a multi-step 

880 # `distributed_train_function` will require the user to track their own 

881 # training steps. 

882 if ( 

883 self._platform_device 

884 == failure_handling_util.PlatformDevice.INTERNAL_TPU 

885 ): 

886 return self._run_for_tpu(distributed_train_function, *args, **kwargs) 

887 elif self._platform_device in ( 

888 failure_handling_util.PlatformDevice.GCE_TPU, 

889 failure_handling_util.PlatformDevice.GCE_CPU, 

890 ): 

891 return distributed_train_function(*args, **kwargs) 

892 else: 

893 return self._run_for_multi_worker_mirrored( 

894 distributed_train_function, *args, **kwargs 

895 ) 

896 

897 def _run_for_tpu(self, distributed_train_function, *args, **kwargs): 

898 """PreemptionCheckpointHandler.run implementation for TPUStrategy.""" 

899 gen_check_preemption_op.check_preemption(preemption_key=PREEMPTION_KEY) 

900 return distributed_train_function(*args, **kwargs) 

901 

902 def _run_for_multi_worker_mirrored( 

903 self, distributed_train_function, *args, **kwargs 

904 ): 

905 """PreemptionCheckpointHandler.run implementation for MWMS.""" 

906 try: 

907 self._check_preemption_and_maybe_checkpoint() 

908 run_begin_time = time.time() 

909 result = distributed_train_function(*args, **kwargs) 

910 new_run_time = time.time() - run_begin_time 

911 self._run_counter += 1 

912 # Update the average run time with the new run. 

913 self._estimated_run_time = self._estimated_run_time + ( 

914 new_run_time - self._estimated_run_time) / self._run_counter 

915 

916 except errors.OpError as e: 

917 if not self._local_mode: 

918 logging.info('Propagating error to cluster: %r: %s', e, e) 

919 try: 

920 context.context().report_error_to_cluster(e.error_code, e.message) 

921 except Exception as ex: # pylint: disable=broad-except 

922 logging.info('Ignoring error during error propagation: %r:%s', ex, ex) 

923 raise 

924 

925 return result 

926 

927 # Disabling line-too-long check since we do not want to break the line when 

928 # converted to public documentation. 

929 # pylint: disable=line-too-long 

930 def save_checkpoint_if_preempted(self, *args, **kwargs): 

931 """Saves a checkpoint if a preemption signal has been made available. 

932 

933 This is an alternative API for `PreemptionCheckpointHandler.run` and 

934 `PreemptionCheckpointHandler.watch_preemption_scope`. This method works for 

935 both `tf.distribute.MultiWorkerMirroredStrategy` and 

936 `tf.distribute.TPUStrategy`. However, **for TPUStrategy, this method will 

937 add a synchronization point between workers and the coordinator** and thus 

938 may have performance implications. If this is a concern, use the combination

939 of `PreemptionCheckpointHandler.watch_preemption_scope` and 

940 `PreemptionCheckpointHandler.run` instead. 

941 

942 ```python 

943 strategy = tf.distribute.TPUStrategy(tpu_cluster_resolver) 

944 # initialization omitted 

945 

946 with strategy.scope(): 

947 # Save in the checkpoint. 

948 trained_step = tf.Variable(initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='trained_step', aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA) 

949 

950 checkpoint_manager = tf.train.CheckpointManager(checkpoint, directory, max_to_keep=1) 

951 preemption_handler = tf.distribute.experimental.PreemptionCheckpointHandler(cluster_resolver, checkpoint_manager) 

952 

953 while trained_step.numpy() < NUM_STEPS: 

954 # Train STEPS_IN_FUNCTION steps at once. 

955 train_multi_step_function() 

956 trained_step.assign_add(STEPS_IN_FUNCTION) 

957 preemption_handler.save_checkpoint_if_preempted() 

958 ``` 

959 

960 Args: 

961 *args: args for `tf.train.CheckpointManager.save()` to save checkpoint. 

962 **kwargs: kwargs for `tf.train.CheckpointManager.save()` to save. 

963 """ 

964 # pylint: enable=line-too-long 

965 if (self._platform_device == 

966 failure_handling_util.PlatformDevice.INTERNAL_TPU): 

967 

968 try: 

969 with context.async_scope(): 

970 gen_check_preemption_op.check_preemption( 

971 preemption_key=PREEMPTION_KEY) 

972 except errors.AbortedError as abort_error: 

973 if abort_error.experimental_payloads.get( 

974 b'type.googleapis.com/tensorflow.distributed_runtime.WorkerPreemption' 

975 ): 

976 logging.info('Clearing preemption error to save checkpoint...') 

977 

978 context.async_clear_error() 

979 self._save_checkpoint(*args, **kwargs) 

980 

981 # For TPU training, the default behavior is that it will block until 

982 # workers are down and returns with an error.

983 self._exit_fn() 

984 

985 else: 

986 raise 

987 

988 elif self._platform_device in ( 

989 failure_handling_util.PlatformDevice.GCE_TPU, 

990 failure_handling_util.PlatformDevice.GCE_CPU, 

991 ): 

992 return 

993 

994 else: 

995 self._check_preemption_and_maybe_checkpoint(*args, **kwargs) 

996 self._run_counter += 1 

997 self._estimated_run_time = 0 

998 

999 @tf_contextlib.contextmanager 

1000 def watch_preemption_scope(self): 

1001 """Syncs errors and maybe saves a checkpoint for usage with TPUStrategy.

1002 

1003 Note: Usage with `tf.distribute.MultiWorkerMirroredStrategy` does not need 

1004 this API. 

1005 

1006 Example usage: 

1007 

1008 ```python 

1009 with preemption_checkpoint_handler.watch_preemption_scope(): 

1010 while trained_step.numpy() < NUM_STEPS: 

1011 

1012 # distributed_train_function contains a call to strategy.run. 

1013 loss += preemption_checkpoint_handler.run(distributed_train_function, args=(next(iterator),)) 

1014 trained_step.assign_add(STEPS_PER_TRAIN_FUNCTION) 

1015 ``` 

1016 

1017 In this workflow, `PreemptionCheckpointHandler.run` will flag a received

1018 preemption signal, and `watch_preemption_scope` will handle the preemption

1019 signal by saving a checkpoint and then either exit to restart or execute a

1020 user-passed `exit_fn` in `tf.distribute.experimental.TerminationConfig`. If

1021 no preemption signal is received during execution of ops and functions inside

1022 the scope, `watch_preemption_scope` ensures the completion of all async op

1023 and function execution when exiting and will raise exceptions if async

1024 execution results in an error state. 

1025 

1026 Yields: 

1027 None 

1028 """ 

1029 if self._platform_device == failure_handling_util.PlatformDevice.INTERNAL_TPU: 

1030 try: 

1031 with context.async_scope(): 

1032 yield 

1033 except errors.AbortedError as abort_error: 

1034 if abort_error.experimental_payloads.get( 

1035 b'type.googleapis.com/tensorflow.distributed_runtime.WorkerPreemption' 

1036 ): 

1037 logging.info('Clearing preemption error to save checkpoint...') 

1038 

1039 context.async_clear_error() 

1040 self._save_checkpoint() 

1041 

1042 self._exit_fn() 

1043 

1044 else: 

1045 raise 

1046 else: 

1047 try: 

1048 yield 

1049 except errors.OpError as e: 

1050 if not self._local_mode: 

1051 logging.info('Propagating error to cluster: %r: %s', e, e) 

1052 try: 

1053 context.context().report_error_to_cluster(e.error_code, e.message) 

1054 except Exception as ex: # pylint: disable=broad-except 

1055 logging.info('Ignoring error during error propagation: %r:%s', ex, ex) 

1056 raise 

1057 

1058 def _save_checkpoint(self, *args, **kwargs): 

1059 """Saves the checkpoint; the caller exits the program afterwards if needed."""

1060 distribute_lib.distribution_strategy_input_api_counter.get_cell( 

1061 self._platform_device.name, 

1062 'PreemptionCheckpointHandler Saving Checkpoint').increase_by(1) 

1063 logging.info('PreemptionCheckpointHandler: Starting saving a checkpoint.') 

1064 

1065 if self._platform_device != failure_handling_util.PlatformDevice.INTERNAL_TPU: 

1066 self._checkpointed_runs.assign(self.total_run_calls) 

1067 

1068 start_time = time.monotonic() 

1069 

1070 with checkpoint_context.preemption_save_context(): 

1071 if self._save_fn: 

1072 self._save_fn(*args, **kwargs) 

1073 else: 

1074 self._write_checkpoint_manager.save(*args, **kwargs) 

1075 

1076 end_time = time.monotonic() 

1077 

1078 logging.info('Checkpoint finished at path %s', 

1079 self._write_checkpoint_manager.directory) 

1080 self._checkpoint_time = end_time - start_time 

1081 

1082 def _check_preemption_and_maybe_checkpoint(self, *args, **kwargs): 

1083 """Checkpoint if any worker has received a preemption signal. 

1084 

1085 This function handles a preemption signal reported by any worker in the

1086 cluster. The current implementation relies on the fact that all workers in a

1087 MultiWorkerMirroredStrategy training cluster have a step number difference of

1088 at most 1.

1089 - If the signal comes from the worker itself (i.e., where this failure 

1090 handler sits), the worker will notify all peers to checkpoint after they 

1091 finish CURRENT_STEP+1 steps, where CURRENT_STEP is the step this worker has 

1092 just finished. And the worker will wait for all peers to acknowledge that 

1093 they have received its preemption signal and the final-step number before 

1094 the worker proceeds on training the final step. 

1095 - If the signal comes from another member in the cluster but NO final-step 

1096 info is available, proceed with training, because it will be available after

1097 finishing the next step. 

1098 - If the signal comes from some other member in the cluster, and final-step 

1099 info is available, if the worker has not finished these steps yet, keep 

1100 training; otherwise, checkpoint and exit with a cluster-recognized restart 

1101 code. 

1102 

1103 Args: 

1104 *args: args for `tf.train.CheckpointManager.save()` to save checkpoint. 

1105 **kwargs: kwargs for `tf.train.CheckpointManager.save()` to save. 

1106 """ 

1107 if self._platform_device == failure_handling_util.PlatformDevice.INTERNAL_TPU: 

1108 gen_check_preemption_op.check_preemption(preemption_key=PREEMPTION_KEY) 

1109 return 

1110 

1111 if self._final_checkpoint_countdown: 

1112 run_count_config_key = _FINAL_RUN_COUNT_KEY 

1113 

1114 else: 

1115 run_count_config_key = _INITIAL_RUN_COUNT_KEY 

1116 

1117 if self._received_checkpoint_step.is_set(): 

1118 

1119 if self._step_to_checkpoint == str(self._run_counter): 

1120 self._save_checkpoint(*args, **kwargs) 

1121 

1122 if self._time_to_exit(): 

1123 self._stop_poll_termination_signal_thread() 

1124 self._stop_cluster_wise_termination_watcher_thread() 

1125 if self._api_made_checkpoint_manager and not self._is_chief: 

1126 gfile.DeleteRecursively( 

1127 os.path.dirname(self._write_checkpoint_manager.directory)) 

1128 logging.info( 

1129 'PreemptionCheckpointHandler: checkpoint saved. Exiting.') 

1130 

1131 self._exit_fn() 

1132 

1133 else: 

1134 logging.info('Continue training for the grace period.') 

1135 self._final_checkpoint_countdown = True 

1136 self._received_checkpoint_step.clear() 

1137 

1138 elif self._received_own_sigterm.is_set(): 

1139 # Only the worker who gets termination signal first among the cluster 

1140 # will enter this branch. The following will happen in chronological 

1141 # order: 

1142 # 1. The worker just receives a preemption signal and enters this branch 

1143 # for the first time. It will set a step-to-checkpoint and let the cluster 

1144 # know. 

1145 # 2. If there is a long grace period, it will also set 

1146 # _final_checkpoint_countdown, so that during this grace period, it will 

1147 # re-enter this branch to check if grace period is ending. 

1148 # 3. If it is, set a step-to-checkpoint key again. 

1149 

1150 if self._final_checkpoint_countdown: 

1151 if self._target_time_for_termination < time.time(): 

1152 logging.info( 

1153 'Grace period almost ended. Final call to save a checkpoint!') 

1154 else: 

1155 return 

1156 

1157 step_to_save_at = str(self._run_counter + 1) 

1158 

1159 logging.info('Termination caught in main thread on preempted worker') 

1160 

1161 if self._local_mode: 

1162 self._step_to_checkpoint = step_to_save_at 

1163 self._received_checkpoint_step.set() 

1164 

1165 else: 

1166 context.context().set_config_key_value(run_count_config_key, 

1167 step_to_save_at) 

1168 logging.info('%s set to %s', run_count_config_key, step_to_save_at) 

1169 

1170 if not self._local_mode: 

1171 worker_count = multi_worker_util.worker_count( 

1172 self._cluster_resolver.cluster_spec(), 

1173 self._cluster_resolver.task_type) 

1174 for i in range(worker_count): 

1175 context.context().get_config_key_value( 

1176 f'{_ACKNOWLEDGE_KEY}_{run_count_config_key}_{i}') 

1177 logging.info('Sigterm acknowledgement from replica %d received', i) 

1178 

1179 self._setup_countdown_if_has_grace_period_and_not_already_counting_down() 

1180 

1181 def _time_to_exit(self): 

1182 """Return whether to exit: exit if no grace period or grace period ends.""" 

1183 # we should directly exit in either of the two cases: 

1184 # 1. if no grace period is provided; 

1185 # 2. if there is a grace period, and we're in countdown period. This, 

1186 # together with the fact that _received_checkpoint_step is set (again), 

1187 # means it's time to exit: when there is a grace period, a worker 

1188 # receives preemption signal and sets the step key. Then all workers 

1189 # receive the step key and set their local _received_checkpoint_step 

1190 # event, enter this branch in _check_preemption_and_maybe_checkpoint, and make

1191 # a checkpoint. Then they set _final_checkpoint_countdown to True, clear 

1192 # _received_checkpoint_step, and continue training. New preemption 

1193 # signals anywhere in the cluster will not be handled, because 

1194 # _PREEMPTION_WORKER_KEY is occupied. The only chance that 

1195 # _received_checkpoint_step gets set again is when the worker who has 

1196 # received the preemption signal earlier decides it's time to do a final

1197 # checkpoint (by checking if it already passes 

1198 # _target_time_for_termination). It will upload a final step key. All 

1199 # workers receive this key and again set _received_checkpoint_step. So, 

1200 # if we find that _received_checkpoint_step is set, and also

1201 # _final_checkpoint_countdown is true, it's checkpoint and exit time. 

1202 return (self._grace_period <= 0) or self._final_checkpoint_countdown 

1203 

1204 def _setup_countdown_if_has_grace_period_and_not_already_counting_down(self): 

1205 """Set up at the beginning of a countdown period for long grace period.""" 

1206 if self._grace_period > 0 and not self._final_checkpoint_countdown: 

1207 # A factor to provide extra buffer against timing inaccuracy.

1208 # TODO(wxinyi): update buffer_factor as needed. Maybe deduct a constant. 

1209 buffer_factor = 3 

1210 # Multiplying by 2 since, while the preempted worker needs to do 1 extra step

1211 # when time_till_final_call <= 0, other workers might need to do x steps

1212 # where 0 < x < 2.

1213 self._target_time_for_termination = ( 

1214 self._received_own_sigterm_time + self._grace_period - 

1215 buffer_factor * self._estimated_run_time * 2) 

1216 

1217 def _sigterm_handler_fn(self, signum, frame): 

1218 """Upload the to-be-preempted worker's id to coordination service.""" 

1219 del signum, frame 

1220 self._maybe_set_received_own_sigterm() 

1221 

1222 def _watch_step_to_save_key(self): 

1223 """Watch out for step-to-save config key and acknowledge. 

1224 

1225 All workers, including the one to be preempted, execute this function to get 

1226 step-to-save. 

1227 """ 

1228 

1229 step_value = context.context().get_config_key_value(_INITIAL_RUN_COUNT_KEY) 

1230 

1231 # get_config_key_value does not return until it gets some result. Thus at 

1232 # the time to clean up, we upload a _STOP_WATCHING_CLUSTER_VALUE as the 

1233 # value so we can join the thread executing _watch_step_to_save_key. 

1234 if step_value != _STOP_WATCHING_CLUSTER_VALUE: 

1235 # This must be set before we set the ack key below, otherwise its value 

1236 # in _check_preemption_and_maybe_checkpoint may be outdated. 

1237 self._step_to_checkpoint = step_value 

1238 self._received_checkpoint_step.set() 

1239 

1240 ack_key = f'{_ACKNOWLEDGE_KEY}_{_INITIAL_RUN_COUNT_KEY}_{self._id_in_cluster}' 

1241 context.context().set_config_key_value(ack_key, '1') 

1242 logging.info( 

1243 'PreemptionCheckpointHandler: %s set, ' 

1244 'preemption awareness acknowledged', ack_key) 

1245 

1246 # If a positive grace_period is not configured, we get the 

1247 # _INITIAL_RUN_COUNT_KEY and then we're done. 

1248 # _check_preemption_and_maybe_checkpoint 

1249 # will save a checkpoint and then exit. Otherwise, we need to move on to 

1250 # wait for the _FINAL_RUN_COUNT_KEY, the one that the preempted worker 

1251 # will set after we utilize the extended grace period to train, so that 

1252 # a final checkpoint can be made right before the termination.

1253 if self._grace_period > 0: 

1254 # Continue to wait until a final call is made. 

1255 final_step_value = context.context().get_config_key_value( 

1256 _FINAL_RUN_COUNT_KEY) 

1257 if final_step_value != _STOP_WATCHING_CLUSTER_VALUE: 

1258 ack_key = f'{_ACKNOWLEDGE_KEY}_{_FINAL_RUN_COUNT_KEY}_{self._id_in_cluster}' 

1259 context.context().set_config_key_value(ack_key, '1') 

1260 logging.info( 

1261 'PreemptionCheckpointHandler: %s acknowledged, final ' 

1262 'checkpoint timing received.', ack_key) 

1263 self._received_checkpoint_step.set() 

1264 self._step_to_checkpoint = final_step_value 

1265 

1266# TODO(wxinyi): remove this line after we move the Keras callback prototype and 

1267# change gce test usage. 

1268WorkerPreemptionHandler = PreemptionCheckpointHandler