Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/failure_handling/preemption_watcher.py: 54%

39 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides a utility class for preemption detection and recovery."""

import threading

from absl import logging

from tensorflow.python.distribute.failure_handling.failure_handling_util import detect_platform
from tensorflow.python.distribute.failure_handling.failure_handling_util import PlatformDevice
from tensorflow.python.eager import context
from tensorflow.python.eager import monitoring
from tensorflow.python.framework.errors import AbortedError
from tensorflow.python.framework.errors import CancelledError
from tensorflow.python.framework.errors import UnavailableError
from tensorflow.python.util.tf_export import tf_export

_preemption_watcher_initialization_counter = monitoring.Counter(
    "/tensorflow/api/distribution_strategy/preemption_watcher_initialized",
    "Counter for usages of PreemptionWatcher",
)
_preemption_handling_counter = monitoring.Counter(
    "/tensorflow/api/distribution_strategy/preemption_watcher_handled",
    "Counter for number of preemptions caught and handled by PreemptionWatcher",
)

_PREEMPTION_KEY = "TF_DEFAULT_PREEMPTION_NOTICE_KEY"

@tf_export("distribute.experimental.PreemptionWatcher", v1=[])
class PreemptionWatcher:
  """Watches for the preemption signal and stores it.

  Notice: Currently only supports the Borg TPU environment with
  TPUClusterResolver.

  This class provides a way to monitor the preemption signal during training on
  TPU. It starts a background thread to watch the training process, trying to
  fetch a preemption message from the coordination service. When a preemption
  happens, the preempted worker writes the preemption message to the
  coordination service, so a non-empty preemption message means a preemption
  has happened.

  Users can treat the preemption message as a reliable preemption indicator and
  have the coordinator reconnect to the TPU workers instead of relying on a
  full restart triggered by Borg. For example, a training process with
  preemption recovery looks like this:

  ```python
  keep_running = True
  preemption_watcher = None
  while keep_running:
    try:
      # Initialize the TPU cluster and strategy.
      resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
      tf.config.experimental_connect_to_cluster(resolver)
      tf.tpu.experimental.initialize_tpu_system(resolver)
      strategy = tf.distribute.TPUStrategy(resolver)

      # PreemptionWatcher must be created after connecting to the cluster.
      preemption_watcher = tf.distribute.experimental.PreemptionWatcher()
      train_model(strategy)
      keep_running = False
    except Exception as e:
      if preemption_watcher and preemption_watcher.preemption_message:
        preemption_watcher.block_until_worker_exit()
        keep_running = True
      else:
        raise e
  ```

  Attributes:
    preemption_message: A variable that stores the preemption message fetched
      from the coordination service. If it is not None, a preemption has
      happened.
    platform: A PlatformDevice indicating the current job's platform. Refer to
      failure_handling_util.py for the definition of the enum class
      PlatformDevice.
  """

  def __init__(self):
    # TODO(b/254321514): Integrate with GPU and cloud environments.
    self._preemption_message = None
    self._platform = detect_platform()
    if self._platform != PlatformDevice.INTERNAL_TPU:
      logging.warning(
          "Preemption watcher does not support environment: %s", self._platform
      )
    else:
      _preemption_watcher_initialization_counter.get_cell().increase_by(1)
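      # Run the watcher in a daemon thread so that waiting on the coordination
      # service never blocks interpreter shutdown.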

      threading.Thread(target=self._watch_preemption_key, daemon=True).start()

  @property
  def preemption_message(self):
    """Returns the preemption message."""
    return self._preemption_message

  def _watch_preemption_key(self):
    logging.info("Watching preemption signal.")
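    # The get_config_key_value call below blocks until a preempted worker
    # publishes the preemption key to the coordination service.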

    message = context.context().get_config_key_value(_PREEMPTION_KEY)
    _preemption_handling_counter.get_cell().increase_by(1)
    logging.info("Preemption signal received.")
    self._preemption_message = message

  def block_until_worker_exit(self):
    """Blocks the coordinator until the workers exit.

    In some rare cases, another error could be raised during the preemption
    grace period. That would cause the coordinator to reconnect to the same TPU
    workers, which will be killed later, preventing it from reconnecting to new
    TPU workers and forcing a fall back to a hard restart. To avoid this
    situation, this method blocks the coordinator from reconnecting until the
    workers have exited. It is a no-op on non-TPU platforms.
    """

    if self._platform != PlatformDevice.INTERNAL_TPU:
      return
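    # The "BLOCK_TILL_EXIT" key is not expected to be set, so this blocking
    # read returns only once the workers exit and the coordination service
    # connection surfaces one of the errors caught below.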

    try:
      context.context().get_config_key_value("BLOCK_TILL_EXIT")
    except (AbortedError, CancelledError, UnavailableError):
      logging.info("Workers exited.")