1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Module for `PreemptionCheckpointHandler`.
16
17This is currently under development and the API is subject to change.
18
PreemptionCheckpointHandler reduces the loss of training progress caused by
termination (preemption or maintenance) of workers in multi-worker synchronous
training, and avoids surfacing errors that are indistinguishable from
application errors to the job scheduler or users.
23"""
24import os
25import signal
26import sys
27import threading
28import time
29
30from tensorflow.core.distributed_runtime.preemption import gen_check_preemption_op
31from tensorflow.python.checkpoint import checkpoint as checkpoint_lib
32from tensorflow.python.checkpoint import checkpoint_context
33from tensorflow.python.checkpoint import checkpoint_management
34from tensorflow.python.distribute import distribute_lib
35from tensorflow.python.distribute import multi_worker_util
36from tensorflow.python.distribute.failure_handling import failure_handling_util
37from tensorflow.python.eager import context
38from tensorflow.python.framework import constant_op
39from tensorflow.python.framework import dtypes
40from tensorflow.python.framework import errors
41from tensorflow.python.lib.io import file_io
42from tensorflow.python.ops import variables
43from tensorflow.python.platform import gfile
44from tensorflow.python.platform import tf_logging as logging
45from tensorflow.python.util import tf_contextlib
46from tensorflow.python.util.deprecation import deprecated
47from tensorflow.python.util.tf_export import tf_export
48from tensorflow.tools.docs import doc_controls
49
50
51_INITIAL_RUN_COUNT_KEY = 'RUN_TO_CHECKPOINT'
52_FINAL_RUN_COUNT_KEY = 'LAST_RUN_TO_CHECKPOINT'
# This key is used to guarantee that only one worker (the earliest one to
# receive a preemption signal) sets _received_own_sigterm, leads the step
# resolution, and controls the grace period timeline.
56_PREEMPTION_WORKER_KEY = 'TERMINATED_WORKER'
57_ACKNOWLEDGE_KEY = 'RECEIVED_SIGNAL'
58_ITERATION_VARIABLE = 'checkpointed_runs'
59_STOP_WATCHING_CLUSTER_VALUE = 'STOP_WATCHER'
60PREEMPTION_KEY = 'TF_DEFAULT_PREEMPTION_NOTICE_KEY'
61
62
63# TODO(wxinyi): add type annotations.
64def _non_chief_checkpoint_dir(checkpoint_dir, task_id):
65 """Returns a directory for non-chief worker to save checkpoint."""
66 dirpath = os.path.dirname(checkpoint_dir)
67 base = os.path.basename(checkpoint_dir)
68 base_dirpath = 'workertemp_' + str(task_id)
69 dirpath = os.path.join(dirpath, base_dirpath)
70 file_io.recursive_create_dir_v2(dirpath)
71 return os.path.join(dirpath, base)
72
73
74@tf_export('distribute.experimental.TerminationConfig', v1=[])
75class TerminationConfig(object):
76 """Customization of `PreemptionCheckpointHandler` for various platforms.
77
78 A `TerminationConfig` can be created and passed to a
79 `tf.distribute.experimental.PreemptionCheckpointHandler` to provide
80 customization based on the platform. It can deliver three pieces of
81 information:
82
83 * How to decide if there is a termination event soon
84
The form of termination notification and how to fetch it vary across
platforms. Thus `PreemptionCheckpointHandler` may take a user-defined
function, `termination_watcher_fn`, and execute it repeatedly to check for a
termination notification. `termination_watcher_fn` should return `True` if a
termination notification is available and `False` otherwise. The function
should be lightweight and non-blocking so that resources can be cleaned up
properly when training finishes without a termination signal ever being
raised.
93
94 * How to exit the program
95
A user can configure this through the `exit_fn`, which
`PreemptionCheckpointHandler` executes after saving the checkpoint to exit
the training program gracefully. For `tf.distribute.MultiWorkerMirroredStrategy`,
a restart is necessary to reset the program's state. However, a customized
`exit_fn` may facilitate the restart and smooth the training experience: the
platform may recognize a `RESTART_CODE` as a program auto-restart signal, or
the user may have a coordinating script that starts up the training and can
be configured to auto-restart the program if it ever exits with this
`RESTART_CODE`. In both cases, configuring the `exit_fn` to be
`lambda: sys.exit(RESTART_CODE)` makes the training seamless.
107
108 * How long does `PreemptionCheckpointHandler` have from receiving a
109 termination event notice till the actual termination
110
Some platforms have a gap time as long as an hour or so. In these cases,
there is the option to utilize this gap time for training as much as possible
before saving a checkpoint and exiting. This can be achieved by passing a
nonzero value to the `grace_period` argument. Note: for a user whose grace
period is not several times longer than their checkpoint writing time (e.g.,
three times or more), we advise not configuring this argument, in which case
`PreemptionCheckpointHandler` will directly save a checkpoint and exit.
118
119
120 **The default behavior**:
121
* For the Google Borg platform:
    * Automatically detects the preemption signal.
    * Exits with a platform-recognized restart code.
    * Saves a checkpoint and exits immediately.

* For the Google Cloud Platform:
    * Automatically detects the maintenance signal.
    * Exits with a code that the user may configure.
    * Automatically utilizes the extended training period before saving a
      checkpoint and exiting.

* For other platforms:
    * If `termination_watcher_fn` is `None`, `signal.SIGTERM` is treated as
      the termination signal.
    * If `exit_fn` is not configured, the program exits with an arbitrary
      code.
    * If `grace_period` is not configured, the current training step is
      wrapped up, a checkpoint is saved, and the program exits as soon as
      the termination signal is received.
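
As an illustration for such other platforms, a `TerminationConfig` could be
constructed as in the sketch below; `RESTART_CODE` and the grace period value
are hypothetical placeholders, not defaults of this API, and `sys` is assumed
to be imported:

```python
RESTART_CODE = 42  # Hypothetical restart code recognized by the scheduler.

termination_config = tf.distribute.experimental.TerminationConfig(
    exit_fn=lambda: sys.exit(RESTART_CODE),
    grace_period=600)  # Hypothetical grace period granted by the platform.
```

The resulting object is then passed as the `termination_config` argument of
`tf.distribute.experimental.PreemptionCheckpointHandler`.
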
140 """
141
142 def __init__(self,
143 termination_watcher_fn=None,
144 exit_fn=None,
145 grace_period=None,
146 save_fn=None):
147 """Creates a `TerminationConfig` object.
148
149 Args:
termination_watcher_fn: a function to execute repeatedly that returns
`True` if a preemption signal is available and `False` otherwise. The
function must not block waiting for a preemption signal, since blocking
would prevent proper cleanup of the program. A change is **NOT**
recommended for users on Google Borg or Google Cloud Platform.
155 exit_fn: a function to execute after a checkpoint is saved and before the
156 preemption happens. Usually, it should be in the form of
157 `lambda: sys.exit(RESTART_CODE)`, where `RESTART_CODE` varies by
158 platform. A change is **NOT** recommended for users on Google Borg.
159 Users on Google Cloud Platform may configure it to use a customized
160 `RESTART_CODE`.
161 grace_period: the length of time between receiving a preemption signal and
162 the actual preemption. A change is **NOT** recommended for users on
163 Google Borg, Google Cloud Platform, or users with a short grace period.
save_fn: an optional function letting you configure how to save a
checkpoint. This is useful if you'd like to pass extra arguments to
`tf.train.CheckpointManager.save` or `tf.train.Checkpoint.save`. By
default, if not configured, the API will save the checkpoint without
extra arguments.
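
For illustration, a hypothetical `save_fn` that passes a checkpoint number
to `tf.train.CheckpointManager.save` could look like the sketch below;
`checkpoint_manager` and `trained_step` are assumed to already exist in the
training program:

```python
def save_fn():
  # Both names below are assumed to be defined by the training program.
  checkpoint_manager.save(checkpoint_number=trained_step.numpy())

termination_config = tf.distribute.experimental.TerminationConfig(
    save_fn=save_fn)
```
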
169 """
170 self.termination_watcher_fn = termination_watcher_fn
171 self.exit_fn = exit_fn
172 self.grace_period = grace_period
173 self.save_fn = save_fn
174
175
176# TODO(wxinyi): add some tests for TerminationConfig.
177# TODO(wxinyi): configure the exit function based on device type (GPU or TPU).
178class GcpGpuTerminationConfig(TerminationConfig):
179 """Configurations for GCP GPU VM."""
180
181 def __init__( # pylint: disable=super-init-not-called
182 self,
183 termination_watcher_fn=None,
184 exit_fn=None,
185 grace_period=None,
186 save_fn=None,
187 ):
188 self.termination_watcher_fn = (
189 termination_watcher_fn
190 or failure_handling_util.termination_watcher_function_gce
191 )
192 self.exit_fn = exit_fn or failure_handling_util.gce_exit_fn
193 self.grace_period = (
194 grace_period if grace_period or grace_period == 0 else
195 failure_handling_util.GRACE_PERIOD_GCE)
196 self.save_fn = save_fn
197
198
199class GcpCpuTerminationConfig(TerminationConfig):
200 """Configurations for GCP CPU VM."""
201
202 def __init__( # pylint: disable=super-init-not-called
203 self,
204 termination_watcher_fn=None,
205 exit_fn=None,
206 grace_period=None,
207 save_fn=None):
208 self.termination_watcher_fn = termination_watcher_fn or failure_handling_util.termination_watcher_function_gce
209 self.exit_fn = exit_fn or failure_handling_util.gce_exit_fn
210 self.grace_period = grace_period or 0
211 self.save_fn = save_fn
212
213
214class BorgTerminationConfig(TerminationConfig):
215 """Configurations for Borg."""
216
217 def __init__( # pylint: disable=super-init-not-called
218 self,
219 termination_watcher_fn=None,
220 exit_fn=None,
221 grace_period=None,
222 save_fn=None):
223 self.termination_watcher_fn = termination_watcher_fn
224 default_exit_fn = lambda: sys.exit(42)
225 self.exit_fn = exit_fn or default_exit_fn
226 self.grace_period = grace_period or 0
227 self.save_fn = save_fn
228
229
230class BorgTPUTerminationConfig(TerminationConfig):
231 """Configurations for Borg."""
232
233 def __init__( # pylint: disable=super-init-not-called
234 self,
235 termination_watcher_fn=None,
236 exit_fn=None,
237 grace_period=None,
238 save_fn=None):
239 self.termination_watcher_fn = termination_watcher_fn
240 self.exit_fn = exit_fn or failure_handling_util.default_tpu_exit_fn
241 self.grace_period = grace_period or 0
242 self.save_fn = save_fn
243
244
245def _complete_config_for_environment(platform_device, termination_config):
246 """Complete un-filled fields of TerminationConfig based on platform."""
247 if not termination_config:
248 termination_config = TerminationConfig()
249
250 if platform_device is failure_handling_util.PlatformDevice.GCE_GPU:
251 return GcpGpuTerminationConfig(termination_config.termination_watcher_fn,
252 termination_config.exit_fn,
253 termination_config.grace_period,
254 termination_config.save_fn)
255
256 elif platform_device is failure_handling_util.PlatformDevice.GCE_CPU:
257 return GcpCpuTerminationConfig(termination_config.termination_watcher_fn,
258 termination_config.exit_fn,
259 termination_config.grace_period,
260 termination_config.save_fn)
261
262 elif platform_device is failure_handling_util.PlatformDevice.INTERNAL_TPU:
263 return BorgTPUTerminationConfig(termination_config.termination_watcher_fn,
264 termination_config.exit_fn,
265 termination_config.grace_period,
266 termination_config.save_fn)
267
268 else:
# The defaults we chose are the same as the ones used by Borg, so we just
# return this.
271 return BorgTerminationConfig(
272 termination_config.termination_watcher_fn,
273 termination_config.exit_fn, termination_config.grace_period,
274 termination_config.save_fn)
275
276
277# TODO(wxinyi): add release updates.
278# Implementation:
279# Each worker will create its own PreemptionCheckpointHandler instance, and the
280# instances communicate through coordination services. Each
# PreemptionCheckpointHandler conducts three tasks in parallel:
282# - Watches out for its own preemption signal. (_poll_termination_signal_thread)
283# - Watches out for a step key from the coordination service made available
284# by any member in the cluster (_cluster_wise_termination_watcher_thread)
285# - The main thread for training.
286#
287# The life cycle of a PreemptionCheckpointHandler is as below:
288#
# It starts the two watcher threads described above, and it starts training.
# Each time before it starts a training step, it will check if any
291# information has been made available by the two watchers: The
292# _poll_termination_signal_thread will be in charge of the _received_own_sigterm
293# event, the _cluster_wise_termination_watcher_thread will be in charge of the
294# _received_checkpoint_step event.
295#
296# If at any point the local worker receives a preemption signal,
297# _poll_termination_signal_thread will set _received_own_sigterm.
# The next time before it attempts to run a training step, it will deal with
# the event by setting its currently finished step + 1 as the step after which
# a checkpoint should be saved, and making that step available to all the
# workers through the coordination service. It will then continue training.
302#
303# This step key will be picked up by the other watcher,
304# _cluster_wise_termination_watcher_thread, both on the worker to be preempted
305# and other workers. And it will set the _received_checkpoint_step event.
306# Now, if there is a long grace period before the training
307# has to terminate (e.g., an hour), we would like to keep training and save a
308# checkpoint again right before the termination. Thus this watcher thread will
# move on to watch for a final step-to-save key. Otherwise,
# it has finished all of its tasks.
311#
# Back in the main training thread: before the next training step, the
# PreemptionCheckpointHandler finds that _received_checkpoint_step is set. If
# the local worker has not yet finished the required step after which to save
# a checkpoint, it does nothing, continues training, and revisits the check
# after another step. Once the step is reached, it saves a checkpoint,
# which requires the participation of all workers.
318#
319# After this checkpoint is saved, if there is NO long grace period, all workers
320# will just exit. If there is, all workers will enter a grace period countdown
321# phase (_final_checkpoint_countdown) and clear the _received_checkpoint_step
322# event. They will then continue training.
323#
# During this countdown period, the worker to be preempted checks before every
# step whether the grace period is almost over. If not, nothing needs to be
# done. If so, it again sets a step-to-save key and makes it available to all
# workers. This is still watched by
# _cluster_wise_termination_watcher_thread and signaled by
# _received_checkpoint_step. A similar process is repeated: all workers save
# a checkpoint at an agreed step, and after they finish saving, they recognize
# that they have finished the countdown period of the extended grace period,
# and they all exit.
333#
# When the program restarts and a PreemptionCheckpointHandler object is
# created, it will restore the checkpoint.
336@tf_export('distribute.experimental.PreemptionCheckpointHandler', v1=[])
337class PreemptionCheckpointHandler(object):
338 # pylint: disable=line-too-long
339 """Preemption and error handler for synchronous training.
340
341 Note: This API only supports use with
342 `tf.distribute.MultiWorkerMirroredStrategy` and `tf.distribute.TPUStrategy`.
343
344 A `PreemptionCheckpointHandler` coordinates all workers to save a checkpoint
345 upon receiving a preemption signal. It also helps disseminate application
346 error messages accurately among the cluster. When a
347 `PreemptionCheckpointHandler` object is created, it restores values from
348 the latest checkpoint file if any exists.
349
Right after the initialization, the object starts to watch for a termination
signal for any member in the cluster. If a signal is received, the next time
the worker executes `PreemptionCheckpointHandler.run`, the
`PreemptionCheckpointHandler` will align all workers to save a checkpoint.
Then, if an `exit_fn` is configured via
`tf.distribute.experimental.TerminationConfig`, it will be invoked. Otherwise,
the process will simply exit and later the platform should restart it.
357
358 Note: We advise users of `tf.distribute.MultiWorkerMirroredStrategy` who
359 choose to configure their
360 own `exit_fn` in `tf.distribute.experimental.TerminationConfig` to include a
361 `sys.exit(CODE_OR_MESSAGE)` in the `exit_fn` so that after the restart, all
362 workers can initialize communication services correctly. For users of
363 `tf.distribute.TPUStrategy`, if they do not wish to do a cluster restart but
would like an in-process restart (i.e., keep the coordinator alive and re-do
the steps to connect to the cluster, initialize the TPU system, and make the
`TPUStrategy` object), they could configure the `exit_fn` to be a no-op.
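
A no-op `exit_fn` for that in-process restart flow could be as simple as the
sketch below:

```python
termination_config = tf.distribute.experimental.TerminationConfig(
    exit_fn=lambda: None)
```
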
367
368 For users of `tf.distribute.MultiWorkerMirroredStrategy`, the core API is
369 `PreemptionCheckpointHandler.run`:
370
371 ```python
372 strategy = tf.distribute.MultiWorkerMirroredStrategy()
373
374 trained_epoch = tf.Variable(initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='epoch')
375 step_in_epoch = tf.Variable(initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='step_in_epoch')
376
377 with strategy.scope():
378 dataset, model, optimizer = ...
379
380 checkpoint = tf.train.Checkpoint(optimizer=optimizer,
381 model=model,
382 trained_epoch=trained_epoch,
383 step_in_epoch=step_in_epoch)
384
385 preemption_checkpoint_handler = tf.distribute.experimental.PreemptionCheckpointHandler(cluster_resolver, checkpoint, checkpoint_dir)
386
387 while trained_epoch.numpy() < NUM_EPOCH:
388
389 while step_in_epoch.numpy() < STEPS_PER_EPOCH:
390
391 # distributed_train_function contains a call to strategy.run.
392 loss += preemption_checkpoint_handler.run(distributed_train_function, args=(next(iterator),))
393 # For users of MultiWorkerMirroredStrategy, usually
394 # STEPS_PER_TRAIN_FUNCTION = 1.
395 step_in_epoch.assign_add(STEPS_PER_TRAIN_FUNCTION)
396 ...
397
trained_epoch.assign_add(1)
399 step_in_epoch.assign(0)
400 ```
401
402 For users of `tf.distribute.TPUStrategy`, the core APIs are
403 `PreemptionCheckpointHandler.run` and
404 `PreemptionCheckpointHandler.watch_preemption_scope`:
405
406 ```python
407
408 strategy = tf.distribute.TPUStrategy(tpu_cluster_resolver)
409
# Rest of TPU init omitted, see documentation for TPUStrategy.
411
412 with preemption_checkpoint_handler.watch_preemption_scope():
413 while trained_epoch.numpy() < NUM_EPOCH:
414
415 while step_in_epoch.numpy() < STEPS_PER_EPOCH:
416
417 # distributed_train_function contains a call to strategy.run.
418 loss += preemption_checkpoint_handler.run(distributed_train_function, args=(next(iterator),))
419
420 # For users of TPUStrategy, usually STEPS_PER_TRAIN_FUNCTION >> 1 since
421 # clustering multiple steps within a tf.function amortizes the overhead
422 # of launching a multi-device function on TPU Pod.
423 step_in_epoch.assign_add(STEPS_PER_TRAIN_FUNCTION)
424 ...
425
trained_epoch.assign_add(1)
427 step_in_epoch.assign(0)
428 ```
429
Not all interruptions come with an advance notice that the
`PreemptionCheckpointHandler` can handle, e.g., those caused by hardware
failure. For a user who saves checkpoints for these cases themselves outside
the `PreemptionCheckpointHandler`, if they are using a
`tf.train.CheckpointManager`, pass it as the
`checkpoint_or_checkpoint_manager` argument to the
`PreemptionCheckpointHandler`. If they do not have a
`tf.train.CheckpointManager` but are directly working with
`tf.train.Checkpoint`, we advise saving the checkpoints in the directory
that's passed as the `checkpoint_dir` argument. In this way, at the start of
the program, `PreemptionCheckpointHandler` can restore the latest checkpoint
from the directory, no matter whether it was saved by the user themselves or
by the `PreemptionCheckpointHandler` before the preemption happened.
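
For example, a user who already manages their own `tf.train.CheckpointManager`
might wire it up as in the sketch below; the strategy, model, optimizer, and
`checkpoint_dir` are assumed to be defined elsewhere in the training program:

```python
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint, directory=checkpoint_dir, max_to_keep=1)

preemption_checkpoint_handler = (
    tf.distribute.experimental.PreemptionCheckpointHandler(
        strategy.cluster_resolver, checkpoint_manager))
```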
443
444 **A note on the platform:**
445
`PreemptionCheckpointHandler` can only handle terminations that come with an
advance notice. For now, the API recognizes the termination signal for CPU,
GPU, and TPU on Google Borg and for CPU and GPU on the Google Cloud Platform.
In these cases, `PreemptionCheckpointHandler` will automatically adopt the
correct preemption/maintenance notification detection mechanism. Users of
other platforms can configure the detection behavior through
`tf.distribute.experimental.TerminationConfig`, which can also customize the
exit behavior and the grace period length.
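
For instance, on such a platform, a lightweight detection function could be
supplied as in the hypothetical sketch below; the polled file path is purely
illustrative, and the standard `os` module is assumed to be imported:

```python
def termination_watcher_fn():
  # Hypothetical check; replace with the platform's own notification query.
  return os.path.exists('/tmp/preemption_notice')

termination_config = tf.distribute.experimental.TerminationConfig(
    termination_watcher_fn=termination_watcher_fn)
```
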
454 """
455 # pylint: enable=line-too-long
456
457 def __init__(self,
458 cluster_resolver,
459 checkpoint_or_checkpoint_manager,
460 checkpoint_dir=None,
461 termination_config=None):
462 """Creates the `PreemptionCheckpointHandler`.
463
464 Args:
465 cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`
466 object. You may also obtain it through the `cluster_resolver` attribute
467 of the distribution strategy in use.
468 checkpoint_or_checkpoint_manager: a `tf.train.CheckpointManager` or a
469 `tf.train.Checkpoint`. If you are using a `tf.train.CheckpointManager`
to manage checkpoints outside the `PreemptionCheckpointHandler` for
backup purposes as well, pass it as the `checkpoint_or_checkpoint_manager`
argument. Otherwise, pass a `tf.train.Checkpoint` and the
473 `PreemptionCheckpointHandler` will create
474 a `tf.train.CheckpointManager` to manage it in the `checkpoint_dir`.
475 checkpoint_dir: a directory where the `PreemptionCheckpointHandler` saves
476 and restores checkpoints. When a `PreemptionCheckpointHandler` is
477 created, the latest checkpoint in the `checkpoint_dir` will be restored.
478 (This is not needed if a `tf.train.CheckpointManager` instead of a
479 `tf.train.Checkpoint` is passed as the
480 `checkpoint_or_checkpoint_manager` argument.)
481 termination_config: optional, a
482 `tf.distribute.experimental.TerminationConfig` object to configure for a
483 platform other than Google Borg or GCP.
484 """
485 # TODO(wxinyi): Maybe make checkpoint_or_checkpoint_manager optional if
486 # save_fn is passed. For now it's still useful for restore.
487 if isinstance(checkpoint_or_checkpoint_manager,
488 checkpoint_lib.Checkpoint) and not checkpoint_dir:
489 raise errors.InvalidArgumentError('When a checkpoint is passed, a '
490 'checkpoint_dir must be passed as well'
491 '.')
492
493 self._cluster_resolver = cluster_resolver
494 self._termination_config = termination_config
495 self._checkpoint_or_checkpoint_manager = checkpoint_or_checkpoint_manager
496 self._checkpoint_dir = checkpoint_dir
497
498 self._platform_device = failure_handling_util.detect_platform()
499
500 completed_termination_config = _complete_config_for_environment(
501 self._platform_device, self._termination_config
502 )
503 self._termination_watcher_fn = (
504 completed_termination_config.termination_watcher_fn
505 )
506 self._exit_fn = completed_termination_config.exit_fn
507 self._grace_period = completed_termination_config.grace_period
508 self._save_fn = completed_termination_config.save_fn
509
510 self._local_mode = True
511 if self._platform_device in (
512 failure_handling_util.PlatformDevice.GCE_TPU,
513 failure_handling_util.PlatformDevice.GCE_CPU,
514 ):
# While running MultiWorkerMirroredStrategy training with GPUs and CPUs is
# the same on Borg, GCE CPU VMs and GPU VMs differ in terms of live
# migration, grace period, etc. We can make it work upon request.
518 logging.warning(
519 'PreemptionCheckpointHandler does not support usage with '
520 'TPU or CPU device on GCP.'
521 )
522
523 elif (
524 self._platform_device
525 == failure_handling_util.PlatformDevice.INTERNAL_TPU
526 ):
527 self._initialize_for_tpu_strategy()
528
529 else:
530 if cluster_resolver and 'ps' in cluster_resolver.cluster_spec().as_dict():
531 raise NotImplementedError(
532 'PreemptionCheckpointHandler does not support'
533 'usage with tf.distribute.experimental.ParameterServerStrategy.'
534 )
535
536 self._initialize_for_mirrored_and_multi_worker_mirrored()
537
538 logging.info('PreemptionCheckpointHandler initialized or restored.')
539
540 def _initialize_for_tpu_strategy(self):
541 """Makes configurations for using the handler with TPUStrategy."""
542 self._is_chief = True
543 self._poll_termination_signal_thread = None
544 self._cluster_wise_termination_watcher_thread = None
545 self._maybe_create_checkpoint_manager()
546 self._read_checkpoint_manager.restore_or_initialize()
547 self._run_counter = 0
548
549 def _initialize_for_mirrored_and_multi_worker_mirrored(self):
550 """Makes configurations and start watchers for MS, MWMS, or OneDevice."""
551 if (
552 not self._cluster_resolver
553 or not self._cluster_resolver.cluster_spec().jobs
554 ):
# For MirroredStrategy, OneDeviceStrategy, and local-mode
# MultiWorkerMirroredStrategy, an empty cluster spec is passed, and
# coordination service is not enabled nor is it needed (since
# it's used for cross-worker communication). Thus we will directly set
# the worker id and is_chief properties and also skip the
# uploading/reading from coordination service logic.
561 self._local_mode = True
562 self._id_in_cluster = 'single_worker'
563 self._is_chief = True
564 else:
565 self._local_mode = False
566 self._id_in_cluster = str(
567 multi_worker_util.id_in_cluster(
568 self._cluster_resolver.cluster_spec(),
569 self._cluster_resolver.task_type,
570 self._cluster_resolver.task_id))
571 self._is_chief = multi_worker_util.is_chief(
572 cluster_spec=self._cluster_resolver.cluster_spec(),
573 task_type=self._cluster_resolver.task_type,
574 task_id=self._cluster_resolver.task_id)
575 # The number of calls to `PreemptionCheckpointHandler.run` when the latest
576 # checkpoint was saved.
577 self._checkpointed_runs = variables.Variable(
578 initial_value=constant_op.constant(0, dtype=dtypes.int64),
579 trainable=False,
580 name=_ITERATION_VARIABLE)
581
582 self._maybe_create_checkpoint_manager()
583
584 if not hasattr(self._write_checkpoint_manager._checkpoint, # pylint: disable=protected-access
585 _ITERATION_VARIABLE):
586 setattr(self._write_checkpoint_manager._checkpoint, _ITERATION_VARIABLE, # pylint: disable=protected-access
587 self._checkpointed_runs)
588
589 if not hasattr(self._read_checkpoint_manager._checkpoint, # pylint: disable=protected-access
590 _ITERATION_VARIABLE):
591 setattr(self._read_checkpoint_manager._checkpoint, _ITERATION_VARIABLE, # pylint: disable=protected-access
592 self._checkpointed_runs)
593
594 self._read_checkpoint_manager.restore_or_initialize()
595
# Grace period countdown. Set to True for all workers once they finish
# saving a checkpoint in response to the first preemption notice. Once in
# this phase, a new preemption/maintenance notice will not be handled, since
# the whole cluster goes down when the worker that initiated the grace
# period goes down.
600 self._final_checkpoint_countdown = False
601
602 self._estimated_run_time = 0
603
604 # An internal step counter that's restored to checkpointed_iterations when
605 # training is restored. It increments by one every time
606 # `PreemptionCheckpointHandler.run` is called. Note that in this case, the
607 # user must pass a single-step training function to
608 # `PreemptionCheckpointHandler.run` instead of a multiple-step one.
609 self._run_counter = self._checkpointed_runs.numpy()
610
# The worker itself has received a preemption signal.
612 self._received_own_sigterm = threading.Event()
613
614 # Some member (could be oneself) has received preemption signal, and the
615 # step number to save a checkpoint has been aligned.
616 self._received_checkpoint_step = threading.Event()
617
618 distribute_lib.distribution_strategy_input_api_counter.get_cell(
619 self._platform_device.name,
620 'PreemptionCheckpointHandler').increase_by(1)
621
622 if not self._local_mode:
623 # When training is interrupted, we explicitly call the cleanup methods for
624 # the thread watching for local worker's termination signal and the thread
625 # watching for clusterwise information before we save a checkpoint and
# exit. In the final phase of the training, where no interruption is
# encountered, we rely on __del__ to clean up. However, there is no
# guarantee when or whether __del__ is executed, so we make the threads
# daemons to avoid them preventing the program from exiting.
630 self._cluster_wise_termination_watcher_thread = threading.Thread(
631 target=self._watch_step_to_save_key,
632 name='PeerTerminationWatcher-%s' % self._id_in_cluster,
633 daemon=True)
634 logging.info('Start watcher for peer\'s signal.')
635 self._cluster_wise_termination_watcher_thread.start()
636
637 else:
638 self._cluster_wise_termination_watcher_thread = None
639
640 self._poll_termination_signal_thread = None
641
642 if self._termination_watcher_fn:
643 self._start_polling_for_termination_signal()
644 else:
645 self._start_watching_for_signal()
646
647 def _maybe_create_checkpoint_manager(self):
648 """Create CheckpointManager(s) if a checkpoint is passed else take it."""
649 if isinstance(self._checkpoint_or_checkpoint_manager,
650 checkpoint_management.CheckpointManager):
651 self._read_checkpoint_manager = self._checkpoint_or_checkpoint_manager
652 self._write_checkpoint_manager = self._checkpoint_or_checkpoint_manager
653 self._api_made_checkpoint_manager = False
654 else:
655 self._api_made_checkpoint_manager = True
656 # Make CheckpointManagers. MultiWorkerMirroredStrategy requires different
657 # setup on chief and on other workers.
658 self._read_checkpoint_manager = checkpoint_management.CheckpointManager(
659 self._checkpoint_or_checkpoint_manager,
660 directory=self._checkpoint_dir,
661 max_to_keep=1)
662
663 if self._is_chief:
664 self._write_checkpoint_manager = self._read_checkpoint_manager
665 else:
666 self._write_checkpoint_manager = (
667 checkpoint_management.CheckpointManager(
668 self._checkpoint_or_checkpoint_manager,
669 _non_chief_checkpoint_dir(self._checkpoint_dir,
670 self._cluster_resolver.task_id),
671 max_to_keep=1))
672
673 def _start_watching_for_signal(self):
674 logging.info('Start watcher for local signal.')
675 signal.signal(signal.SIGTERM, self._sigterm_handler_fn)
676
677 def _start_polling_for_termination_signal(self):
678 self._poll_termination_signal_thread_should_stop = threading.Event()
679 self._poll_termination_signal_thread = threading.Thread(
680 target=self._poll_termination_signal,
681 name='WorkerTerminationSignalWatcher-%s' % self._id_in_cluster,
682 daemon=True)
683 logging.info('Start polling for termination signal.')
684 self._poll_termination_signal_thread.start()
685
686 def _poll_termination_signal(self):
687 """Poll maintenance notice and notify peers if receiving one."""
688 while True:
689 if self._poll_termination_signal_thread_should_stop.is_set(
690 ) or self._final_checkpoint_countdown:
691 return
692 if self._termination_watcher_fn():
693 break
694 time.sleep(1)
695
696 self._maybe_set_received_own_sigterm()
697
698 def _maybe_set_received_own_sigterm(self):
699 """Claim earliest preemption if no one else has done it before."""
700 if self._local_mode:
logging.info('Member %s has received termination notice.',
             self._id_in_cluster)
703 self._received_own_sigterm_time = time.time()
704 self._received_own_sigterm.set()
705 return
706
707 try:
708 context.context().set_config_key_value(_PREEMPTION_WORKER_KEY,
709 self._id_in_cluster)
710 logging.info('Member %s has received termination notice.',
711 self._id_in_cluster)
712 self._received_own_sigterm_time = time.time()
713 self._received_own_sigterm.set()
714
715 # This is to handle the case that a worker has received termination
716 # notice but hasn't come to the next step to set the step key. Other
717 # workers might receive a termination notice too, and attempt to set the
718 # config key again, which causes this error. This can be safely ignored
719 # since checkpoint should be saved as early as the earliest call is made.
720 except errors.AlreadyExistsError:
721 logging.info(
722 (
723 'Member %s has received termination notice. But some other '
724 'worker has received it as well! Leaving'
725 ' it to them to decide when to checkpoint. '
726 ),
727 self._id_in_cluster,
728 )
729 return
730
731 def _stop_poll_termination_signal_thread(self):
732 if getattr(self, '_poll_termination_signal_thread', None):
733 self._poll_termination_signal_thread_should_stop.set()
734 self._poll_termination_signal_thread.join()
735
736 self._poll_termination_signal_thread = None
737 logging.info("Shut down watcher for one's own termination signal")
738
739 def _stop_cluster_wise_termination_watcher_thread(self):
740 """Stop the thread that is _watch_step_to_save_key."""
741 if getattr(self, '_cluster_wise_termination_watcher_thread', None):
742 try:
743 context.context().set_config_key_value(
744 _INITIAL_RUN_COUNT_KEY, _STOP_WATCHING_CLUSTER_VALUE
745 )
746 except (errors.AlreadyExistsError, errors.UnavailableError):
747 # We'll ignore any error in the process of setting this key. There
# certainly will be an AlreadyExistsError since all workers are trying to
# push this key. Or some worker might have exited already, leading to an
# errors.UnavailableError or errors.AbortedError.
751 pass
752 except Exception as e: # pylint: disable=broad-except
753 # We'll also ignore other errors since they are not important to the
754 # process.
755 logging.info('Ignoring error when shutting down '
756 '_stop_cluster_wise_termination_watcher_thread: ' + str(e))
757
758 try:
759 context.context().set_config_key_value(_FINAL_RUN_COUNT_KEY,
760 _STOP_WATCHING_CLUSTER_VALUE)
761 except (errors.AlreadyExistsError, errors.UnavailableError):
762 pass
763
764 except Exception as e: # pylint: disable=broad-except
765 logging.info('Ignoring error when shutting down '
766 '_stop_cluster_wise_termination_watcher_thread: ' + str(e))
767
768 finally:
769 self._cluster_wise_termination_watcher_thread.join()
770 self._cluster_wise_termination_watcher_thread = None
771 logging.info('Shut down watcher for peer\'s termination signal.')
772
773 def __del__(self):
774 self._stop_cluster_wise_termination_watcher_thread()
775 self._stop_poll_termination_signal_thread()
776
777 @property
778 @deprecated(None,
779 'Track steps using a tf.Variable saved in checkpoint instead.')
780 @doc_controls.do_not_generate_docs
781 def total_run_calls(self):
782 """Returns the number of times `PreemptionCheckpointHandler.run` is called.
783
DEPRECATED: users should track total steps themselves, as this API provides
785 little expressivity gain but could easily be misused and incurs extra
786 synchronization cost for TPUStrategy users.
787
788 This value tracks the number of all calls to
789 `PreemptionCheckpointHandler.run` including those before the program is
790 restarted and the training is restored, by saving and reading the value in
791 the checkpoint. A user can compute their total number of iterations
792 by `PreemptionCheckpointHandler.total_run_calls *
793 number_of_steps_in_train_function`,
where `number_of_steps_in_train_function` should be one for
`tf.distribute.MultiWorkerMirroredStrategy` users. They can also use this
value to infer the starting epoch and step after training restores, as
sketched below.
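
For instance, with a single-step train function, the starting point after a
restart could be recovered as sketched below, assuming `STEPS_PER_EPOCH` is
defined by the training program:

```python
total_calls = preemption_checkpoint_handler.total_run_calls
start_epoch = total_calls // STEPS_PER_EPOCH
start_step_in_epoch = total_calls % STEPS_PER_EPOCH
```
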
798 """
799 if (self._platform_device ==
800 failure_handling_util.PlatformDevice.INTERNAL_TPU):
801 raise NotImplementedError('Please create variables saved in checkpoint '
802 'to keep track of steps and epochs.')
803 return self._run_counter
804
805 def run(self,
806 distributed_train_function,
807 *args,
808 **kwargs):
809 """Runs a training function with error and preemption handling.
810
811 This function handles the preemption signal from any peer in the cluster by
812 saving the training progress and exiting gracefully. It will
also broadcast any program error encountered during the execution of
814 `distributed_train_function` to all workers so that they can raise the same
815 error.
816
817 The `distributed_train_function` argument should be a distributed train
818 function (i.e., containing a call to `tf.distribute.Strategy.run`). For
819 `tf.distribute.MultiWorkerMirroredStrategy` users, we recommend passing in a
820 single-step `distributed_train_function` to
821 `PreemptionCheckpointHandler.run` so that the checkpoint can be saved in
822 time in case a preemption signal or maintenance notice is sent.
823
824 Besides the preemption and error handling part,
825 `PreemptionCheckpointHandler.run(distributed_train_function, *args,
826 **kwargs)` has the same effect and output as
827 `distributed_train_function(*args, **kwargs)`. `distributed_train_function`
may or may not return a result. The following is a shortened example:
829
830 ```python
831
832 @tf.function
833 def distributed_train_step(iterator):
834 # A distributed single-step training function.
835
836 def step_fn(inputs):
837 # A per-replica single-step training function.
838 x, y = inputs
839 ...
840 return loss
841
842 per_replica_losses = strategy.run(step_fn, args=(next(iterator),))
843 return strategy.reduce(
844 tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
845
846 for epoch in range(preemption_handler.total_run_calls // STEPS_PER_EPOCH,
847 EPOCHS_TO_RUN):
848 iterator = iter(multi_worker_dataset)
849 total_loss = 0.0
850 num_batches = 0
851
852 for step in range(preemption_handler.total_run_calls % STEPS_PER_EPOCH,
853 STEPS_PER_EPOCH):
total_loss += preemption_handler.run(distributed_train_step, iterator)
855 num_batches += 1
856
857 train_loss = total_loss / num_batches
print('Epoch: %d, train_loss: %f.' % (epoch, train_loss))
859
860 train_accuracy.reset_states()
861 ```
862
863 Args:
864 distributed_train_function: A (single-step) distributed training function.
865 *args: args for `distributed_train_function`.
866 **kwargs: kwargs for `distributed_train_function`.
867
868 Raises:
869 Program error encountered by any member in the cluster while executing the
870 `distributed_train_function`, or any error from the program error
871 propagation process.
872
873 Returns:
874 Result of running the `distributed_train_function`.
875 """
876 # TODO(wxinyi): after we support use with TPUStrategy, we should expand the
877 # API doc to state that `distributed_train_function` does not need to be a
878 # single-step training function, since a multi-step host-training loop is
# the dominant use case for TPU users. Besides, passing in a multi-step
880 # `distributed_train_function` will require the user to track their own
881 # training steps.
882 if (
883 self._platform_device
884 == failure_handling_util.PlatformDevice.INTERNAL_TPU
885 ):
886 return self._run_for_tpu(distributed_train_function, *args, **kwargs)
887 elif self._platform_device in (
888 failure_handling_util.PlatformDevice.GCE_TPU,
889 failure_handling_util.PlatformDevice.GCE_CPU,
890 ):
891 return distributed_train_function(*args, **kwargs)
892 else:
893 return self._run_for_multi_worker_mirrored(
894 distributed_train_function, *args, **kwargs
895 )
896
897 def _run_for_tpu(self, distributed_train_function, *args, **kwargs):
898 """PreemptionCheckpointHandler.run implementation for TPUStrategy."""
899 gen_check_preemption_op.check_preemption(preemption_key=PREEMPTION_KEY)
900 return distributed_train_function(*args, **kwargs)
901
902 def _run_for_multi_worker_mirrored(
903 self, distributed_train_function, *args, **kwargs
904 ):
905 """PreemptionCheckpointHandler.run implementation for MWMS."""
906 try:
907 self._check_preemption_and_maybe_checkpoint()
908 run_begin_time = time.time()
909 result = distributed_train_function(*args, **kwargs)
910 new_run_time = time.time() - run_begin_time
911 self._run_counter += 1
912 # Update the average run time with the new run.
913 self._estimated_run_time = self._estimated_run_time + (
914 new_run_time - self._estimated_run_time) / self._run_counter
915
916 except errors.OpError as e:
917 if not self._local_mode:
918 logging.info('Propagating error to cluster: %r: %s', e, e)
919 try:
920 context.context().report_error_to_cluster(e.error_code, e.message)
921 except Exception as ex: # pylint: disable=broad-except
922 logging.info('Ignoring error during error propagation: %r:%s', ex, ex)
923 raise
924
925 return result
926
927 # Disabling line-too-long check since we do not want to break the line when
928 # converted to public documentation.
929 # pylint: disable=line-too-long
930 def save_checkpoint_if_preempted(self, *args, **kwargs):
931 """Saves a checkpoint if a preemption signal has been made available.
932
933 This is an alternative API for `PreemptionCheckpointHandler.run` and
934 `PreemptionCheckpointHandler.watch_preemption_scope`. This method works for
935 both `tf.distribute.MultiWorkerMirroredStrategy` and
936 `tf.distribute.TPUStrategy`. However, **for TPUStrategy, this method will
937 add a synchronization point between workers and the coordinator** and thus
may have performance implications. If this is a concern, use the combination
939 of `PreemptionCheckpointHandler.watch_preemption_scope` and
940 `PreemptionCheckpointHandler.run` instead.
941
942 ```python
943 strategy = tf.distribute.TPUStrategy(tpu_cluster_resolver)
944 # initialization omitted
945
946 with strategy.scope():
947 # Save in the checkpoint.
948 trained_step = tf.Variable(initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='trained_step', aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
949
950 checkpoint_manager = tf.train.CheckpointManager(checkpoint, directory, max_to_keep=1)
951 preemption_handler = tf.distribute.experimental.PreemptionCheckpointHandler(cluster_resolver, checkpoint_manager)
952
953 while trained_step.numpy() < NUM_STEPS:
954 # Train STEPS_IN_FUNCTION steps at once.
955 train_multi_step_function()
956 trained_step.assign_add(STEPS_IN_FUNCTION)
957 preemption_handler.save_checkpoint_if_preempted()
958 ```
959
960 Args:
961 *args: args for `tf.train.CheckpointManager.save()` to save checkpoint.
962 **kwargs: kwargs for `tf.train.CheckpointManager.save()` to save.
963 """
964 # pylint: enable=line-too-long
965 if (self._platform_device ==
966 failure_handling_util.PlatformDevice.INTERNAL_TPU):
967
968 try:
969 with context.async_scope():
970 gen_check_preemption_op.check_preemption(
971 preemption_key=PREEMPTION_KEY)
972 except errors.AbortedError as abort_error:
973 if abort_error.experimental_payloads.get(
974 b'type.googleapis.com/tensorflow.distributed_runtime.WorkerPreemption'
975 ):
976 logging.info('Clearing preemption error to save checkpoint...')
977
978 context.async_clear_error()
979 self._save_checkpoint(*args, **kwargs)
980
# For TPU training, the default behavior is to block until the workers are
# down and then return with an error.
983 self._exit_fn()
984
985 else:
986 raise
987
988 elif self._platform_device in (
989 failure_handling_util.PlatformDevice.GCE_TPU,
990 failure_handling_util.PlatformDevice.GCE_CPU,
991 ):
992 return
993
994 else:
995 self._check_preemption_and_maybe_checkpoint(*args, **kwargs)
996 self._run_counter += 1
997 self._estimated_run_time = 0
998
999 @tf_contextlib.contextmanager
1000 def watch_preemption_scope(self):
1001 """Syncs error and maybe save checkpoint for usage with TPUStrategy.
1002
1003 Note: Usage with `tf.distribute.MultiWorkerMirroredStrategy` does not need
1004 this API.
1005
1006 Example usage:
1007
1008 ```python
1009 with preemption_checkpoint_handler.watch_preemption_scope():
1010 while trained_step.numpy() < NUM_STEPS:
1011
1012 # distributed_train_function contains a call to strategy.run.
1013 loss += preemption_checkpoint_handler.run(distributed_train_function, args=(next(iterator),))
1014 trained_step.assign_add(STEPS_PER_TRAIN_FUNCTION)
1015 ```
1016
In this workflow, `PreemptionCheckpointHandler.run` will flag a received
preemption signal, and `watch_preemption_scope` will handle the preemption
signal by saving a checkpoint and then either exiting to restart or executing
a user-passed `exit_fn` in `tf.distribute.experimental.TerminationConfig`. If
no preemption signal is received during execution of ops and functions inside
the scope, `watch_preemption_scope` ensures the completion of all async op
and function execution when exiting and will raise exceptions if async
execution results in an error state.
1025
1026 Yields:
1027 None
1028 """
1029 if self._platform_device == failure_handling_util.PlatformDevice.INTERNAL_TPU:
1030 try:
1031 with context.async_scope():
1032 yield
1033 except errors.AbortedError as abort_error:
1034 if abort_error.experimental_payloads.get(
1035 b'type.googleapis.com/tensorflow.distributed_runtime.WorkerPreemption'
1036 ):
1037 logging.info('Clearing preemption error to save checkpoint...')
1038
1039 context.async_clear_error()
1040 self._save_checkpoint()
1041
1042 self._exit_fn()
1043
1044 else:
1045 raise
1046 else:
1047 try:
1048 yield
1049 except errors.OpError as e:
1050 if not self._local_mode:
1051 logging.info('Propagating error to cluster: %r: %s', e, e)
1052 try:
1053 context.context().report_error_to_cluster(e.error_code, e.message)
1054 except Exception as ex: # pylint: disable=broad-except
1055 logging.info('Ignoring error during error propagation: %r:%s', ex, ex)
1056 raise
1057
1058 def _save_checkpoint(self, *args, **kwargs):
1059 """Saves the checkpoint and exit program."""
1060 distribute_lib.distribution_strategy_input_api_counter.get_cell(
1061 self._platform_device.name,
1062 'PreemptionCheckpointHandler Saving Checkpoint').increase_by(1)
1063 logging.info('PreemptionCheckpointHandler: Starting saving a checkpoint.')
1064
1065 if self._platform_device != failure_handling_util.PlatformDevice.INTERNAL_TPU:
1066 self._checkpointed_runs.assign(self.total_run_calls)
1067
1068 start_time = time.monotonic()
1069
1070 with checkpoint_context.preemption_save_context():
1071 if self._save_fn:
1072 self._save_fn(*args, **kwargs)
1073 else:
1074 self._write_checkpoint_manager.save(*args, **kwargs)
1075
1076 end_time = time.monotonic()
1077
1078 logging.info('Checkpoint finished at path %s',
1079 self._write_checkpoint_manager.directory)
1080 self._checkpoint_time = end_time - start_time
1081
1082 def _check_preemption_and_maybe_checkpoint(self, *args, **kwargs):
1083 """Checkpoint if any worker has received a preemption signal.
1084
This function handles the preemption signal reported by any worker in the
cluster. The current implementation relies on the fact that all workers in a
MultiWorkerMirroredStrategy training cluster have a maximum step-number
difference of 1.
- If the signal comes from the worker itself (i.e., where this failure
handler sits), the worker will notify all peers to checkpoint after they
finish CURRENT_STEP+1 steps, where CURRENT_STEP is the step this worker has
just finished. The worker will then wait for all peers to acknowledge that
they have received its preemption signal and the final-step number before
it proceeds to train the final step.
- If the signal comes from another member in the cluster but NO final-step
info is available yet, the worker proceeds with training, because the info
will be available after the next step finishes.
- If the signal comes from some other member in the cluster and final-step
info is available, the worker keeps training if it has not finished that
step yet; otherwise, it saves a checkpoint and exits with a
cluster-recognized restart code.
1102
1103 Args:
1104 *args: args for `tf.train.CheckpointManager.save()` to save checkpoint.
1105 **kwargs: kwargs for `tf.train.CheckpointManager.save()` to save.
1106 """
1107 if self._platform_device == failure_handling_util.PlatformDevice.INTERNAL_TPU:
1108 gen_check_preemption_op.check_preemption(preemption_key=PREEMPTION_KEY)
1109 return
1110
1111 if self._final_checkpoint_countdown:
1112 run_count_config_key = _FINAL_RUN_COUNT_KEY
1113
1114 else:
1115 run_count_config_key = _INITIAL_RUN_COUNT_KEY
1116
1117 if self._received_checkpoint_step.is_set():
1118
1119 if self._step_to_checkpoint == str(self._run_counter):
1120 self._save_checkpoint(*args, **kwargs)
1121
1122 if self._time_to_exit():
1123 self._stop_poll_termination_signal_thread()
1124 self._stop_cluster_wise_termination_watcher_thread()
1125 if self._api_made_checkpoint_manager and not self._is_chief:
1126 gfile.DeleteRecursively(
1127 os.path.dirname(self._write_checkpoint_manager.directory))
1128 logging.info(
1129 'PreemptionCheckpointHandler: checkpoint saved. Exiting.')
1130
1131 self._exit_fn()
1132
1133 else:
1134 logging.info('Continue training for the grace period.')
1135 self._final_checkpoint_countdown = True
1136 self._received_checkpoint_step.clear()
1137
1138 elif self._received_own_sigterm.is_set():
1139 # Only the worker who gets termination signal first among the cluster
1140 # will enter this branch. The following will happen in chronological
1141 # order:
1142 # 1. The worker just receives a preemption signal and enters this branch
1143 # for the first time. It will set a step-to-checkpoint and let the cluster
1144 # know.
1145 # 2. If there is a long grace period, it will also set
1146 # _final_checkpoint_countdown, so that during this grace period, it will
1147 # re-enter this branch to check if grace period is ending.
1148 # 3. If it is, set a step-to-checkpoint key again.
1149
1150 if self._final_checkpoint_countdown:
1151 if self._target_time_for_termination < time.time():
1152 logging.info(
1153 'Grace period almost ended. Final call to save a checkpoint!')
1154 else:
1155 return
1156
1157 step_to_save_at = str(self._run_counter + 1)
1158
1159 logging.info('Termination caught in main thread on preempted worker')
1160
1161 if self._local_mode:
1162 self._step_to_checkpoint = step_to_save_at
1163 self._received_checkpoint_step.set()
1164
1165 else:
1166 context.context().set_config_key_value(run_count_config_key,
1167 step_to_save_at)
1168 logging.info('%s set to %s', run_count_config_key, step_to_save_at)
1169
1170 if not self._local_mode:
1171 worker_count = multi_worker_util.worker_count(
1172 self._cluster_resolver.cluster_spec(),
1173 self._cluster_resolver.task_type)
1174 for i in range(worker_count):
1175 context.context().get_config_key_value(
1176 f'{_ACKNOWLEDGE_KEY}_{run_count_config_key}_{i}')
1177 logging.info('Sigterm acknowledgement from replica %d received', i)
1178
1179 self._setup_countdown_if_has_grace_period_and_not_already_counting_down()
1180
1181 def _time_to_exit(self):
1182 """Return whether to exit: exit if no grace period or grace period ends."""
1183 # we should directly exit in either of the two cases:
1184 # 1. if no grace period is provided;
1185 # 2. if there is a grace period, and we're in countdown period. This,
1186 # together with the fact that _received_checkpoint_step is set (again),
1187 # means it's time to exit: when there is a grace period, a worker
# receives a preemption signal and sets the step key. Then all workers
# receive the step key, set their local _received_checkpoint_step
# event, enter this branch in _check_preemption_and_maybe_checkpoint, and
# make a checkpoint. Then they set _final_checkpoint_countdown to True, clear
# _received_checkpoint_step, and continue training. New preemption
# signals anywhere in the cluster will not be handled, because
# _PREEMPTION_WORKER_KEY is occupied. The only chance that
# _received_checkpoint_step gets set again is when the worker that
# received the preemption signal earlier decides it's time to do a final
# checkpoint (by checking whether it has already passed
# _target_time_for_termination). It will upload a final step key. All
1199 # workers receive this key and again set _received_checkpoint_step. So,
1200 # if we found out that _received_checkpoint_step is set, and also
1201 # _final_checkpoint_countdown is true, it's checkpoint and exit time.
1202 return (self._grace_period <= 0) or self._final_checkpoint_countdown
1203
1204 def _setup_countdown_if_has_grace_period_and_not_already_counting_down(self):
1205 """Set up at the beginning of a countdown period for long grace period."""
1206 if self._grace_period > 0 and not self._final_checkpoint_countdown:
1207 # A factor to provide more buffer / inaccuracy.
1208 # TODO(wxinyi): update buffer_factor as needed. Maybe deduct a constant.
1209 buffer_factor = 3
# Multiply by 2 since, while the preempted worker needs to do 1 extra step
# when time_till_final_call <= 0, other workers might need to do x steps
# where 0 < x < 2.
1213 self._target_time_for_termination = (
1214 self._received_own_sigterm_time + self._grace_period -
1215 buffer_factor * self._estimated_run_time * 2)
1216
1217 def _sigterm_handler_fn(self, signum, frame):
1218 """Upload the to-be-preempted worker's id to coordination service."""
1219 del signum, frame
1220 self._maybe_set_received_own_sigterm()
1221
1222 def _watch_step_to_save_key(self):
1223 """Watch out for step-to-save config key and acknowledge.
1224
1225 All workers, including the one to be preempted, execute this function to get
1226 step-to-save.
1227 """
1228
1229 step_value = context.context().get_config_key_value(_INITIAL_RUN_COUNT_KEY)
1230
1231 # get_config_key_value does not return until it gets some result. Thus at
1232 # the time to clean up, we upload a _STOP_WATCHING_CLUSTER_VALUE as the
1233 # value so we can join the thread executing _watch_step_to_save_key.
1234 if step_value != _STOP_WATCHING_CLUSTER_VALUE:
1235 # This must be set before we set the ack key below, otherwise its value
1236 # in _check_preemption_and_maybe_checkpoint may be outdated.
1237 self._step_to_checkpoint = step_value
1238 self._received_checkpoint_step.set()
1239
1240 ack_key = f'{_ACKNOWLEDGE_KEY}_{_INITIAL_RUN_COUNT_KEY}_{self._id_in_cluster}'
1241 context.context().set_config_key_value(ack_key, '1')
1242 logging.info(
1243 'PreemptionCheckpointHandler: %s set, '
1244 'preemption awareness acknowledged', ack_key)
1245
1246 # If a positive grace_period is not configured, we get the
1247 # _INITIAL_RUN_COUNT_KEY and then we're done.
1248 # _check_preemption_and_maybe_checkpoint
1249 # will save a checkpoint and then exit. Otherwise, we need to move on to
1250 # wait for the _FINAL_RUN_COUNT_KEY, the one that the preempted worker
1251 # will set after we utilize the extended grace period to train, so that
# a final checkpoint can be made right before the termination.
1253 if self._grace_period > 0:
1254 # Continue to wait until a final call is made.
1255 final_step_value = context.context().get_config_key_value(
1256 _FINAL_RUN_COUNT_KEY)
1257 if final_step_value != _STOP_WATCHING_CLUSTER_VALUE:
1258 ack_key = f'{_ACKNOWLEDGE_KEY}_{_FINAL_RUN_COUNT_KEY}_{self._id_in_cluster}'
1259 context.context().set_config_key_value(ack_key, '1')
1260 logging.info(
1261 'PreemptionCheckpointHandler: %s acknowledged, final '
1262 'checkpoint timing received.', ack_key)
1263 self._received_checkpoint_step.set()
1264 self._step_to_checkpoint = final_step_value
1265
1266# TODO(wxinyi): remove this line after we move the Keras callback prototype and
1267# change gce test usage.
1268WorkerPreemptionHandler = PreemptionCheckpointHandler