1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Provides a utility class for preemption detection and recovery."""
16
17import threading
18
19from absl import logging
20
21from tensorflow.python.distribute.failure_handling.failure_handling_util import detect_platform
22from tensorflow.python.distribute.failure_handling.failure_handling_util import PlatformDevice
23from tensorflow.python.eager import context
24from tensorflow.python.eager import monitoring
25from tensorflow.python.framework.errors import AbortedError
26from tensorflow.python.framework.errors import CancelledError
27from tensorflow.python.framework.errors import UnavailableError
28from tensorflow.python.util.tf_export import tf_export
29
30
# Incremented once for each PreemptionWatcher created on a supported platform.
_preemption_watcher_initialization_counter = monitoring.Counter(
    "/tensorflow/api/distribution_strategy/preemption_watcher_initialized",
    "Counter for usages of PreemptionWatcher",
)
# Incremented each time a watcher receives a preemption message.
_preemption_handling_counter = monitoring.Counter(
    "/tensorflow/api/distribution_strategy/preemption_watcher_handled",
    "Counter for number of preemptions caught and handled by "
    "PreemptionWatcher",
)

# Coordination-service config key whose value is the preemption notice.
_PREEMPTION_KEY = "TF_DEFAULT_PREEMPTION_NOTICE_KEY"
41
42
@tf_export("distribute.experimental.PreemptionWatcher", v1=[])
class PreemptionWatcher:
  """Watch preemption signal and store it.

  Notice: Currently only supports the Borg TPU environment with
  TPUClusterResolver.

  This class provides a way to monitor the preemption signal during training on
  TPU. It starts a background thread to watch the training process, trying to
  fetch the preemption message from the coordination service. When preemption
  happens, the preempted worker writes the preemption message to the
  coordination service. Thus getting a non-empty preemption message means a
  preemption has happened.

  Users can use the preemption message as a reliable preemption indicator, and
  then set the coordinator to reconnect to the TPU worker instead of a full
  restart triggered by Borg. For example, a training process with preemption
  recovery will look like:

  ```python
  keep_running = True
  preemption_watcher = None
  while keep_running:
    try:
      # Initialize TPU cluster and strategy.
      resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
      tf.config.experimental_connect_to_cluster(resolver)
      tf.tpu.experimental.initialize_tpu_system(resolver)
      strategy = tf.distribute.TPUStrategy(resolver)

      # PreemptionWatcher must be created after connecting to the cluster.
      preemption_watcher = tf.distribute.experimental.PreemptionWatcher()
      train_model(strategy)
      keep_running = False
    except Exception as e:
      if preemption_watcher and preemption_watcher.preemption_message:
        preemption_watcher.block_until_worker_exit()
        keep_running = True
      else:
        raise e
  ```

  Attributes:
    preemption_message: Read-only. The preemption message fetched from the
      coordination service, or None if no message has been received yet. If
      it is not None, a preemption has happened.
    platform: Read-only. A PlatformDevice indicating the current job's
      platform. Refer to failure_handling_util.py for the definition of enum
      class PlatformDevice.
  """

  def __init__(self):
    """Detects the platform and, on supported platforms, starts watching.

    On platforms other than `PlatformDevice.INTERNAL_TPU` only a warning is
    logged and no background thread is started, so `preemption_message`
    stays None.
    """
    # TODO(b/254321514): Integrate with GPU and cloud environment.
    self._preemption_message = None
    self._platform = detect_platform()
    if self._platform != PlatformDevice.INTERNAL_TPU:
      logging.warning(
          "Preemption watcher does not support environment: %s", self._platform
      )
    else:
      _preemption_watcher_initialization_counter.get_cell().increase_by(1)
      # Daemon thread so the blocking key lookup never keeps the process alive
      # at interpreter shutdown.
      threading.Thread(
          target=self._watch_preemption_key,
          name="PreemptionWatcher",
          daemon=True,
      ).start()

  @property
  def preemption_message(self):
    """Returns the preemption message, or None if none has been received."""
    return self._preemption_message

  @property
  def platform(self):
    """Returns the PlatformDevice detected when this watcher was created."""
    return self._platform

  def _watch_preemption_key(self):
    """Waits for the preemption key and records its value.

    Runs on the background thread started by `__init__`. If the coordination
    service call fails with an aborted/cancelled/unavailable error (the same
    errors `block_until_worker_exit` treats as expected), the error is logged
    and `preemption_message` is left unchanged, instead of the thread dying
    with an unhandled exception.
    """
    logging.info("Watching preemption signal.")
    try:
      message = context.context().get_config_key_value(_PREEMPTION_KEY)
    except (AbortedError, CancelledError, UnavailableError) as e:
      logging.warning(
          "Preemption watcher stopped watching for preemption signal: %s", e
      )
      return
    _preemption_handling_counter.get_cell().increase_by(1)
    logging.info("Preemption signal received.")
    self._preemption_message = message

  def block_until_worker_exit(self):
    """Block coordinator until workers exit.

    In some rare cases, another error could be raised during the preemption
    grace period. This would cause the coordinator to reconnect to the same
    TPU workers, which will be killed later. That prevents the coordinator
    from reconnecting to new TPU workers, and falls back to a hard restart. To
    avoid this situation, this method blocks the coordinator from reconnecting
    until workers exit. This method is a no-op on non-TPU platforms.
    """
    if self._platform != PlatformDevice.INTERNAL_TPU:
      return
    try:
      # NOTE(review): "BLOCK_TILL_EXIT" is presumably never set; the lookup is
      # used only to wait until the coordination service errors out once the
      # workers are gone — confirm.
      context.context().get_config_key_value("BLOCK_TILL_EXIT")
    except (AbortedError, CancelledError, UnavailableError):
      logging.info("Workers exited.")