Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/distribute/worker_training_state.py: 28%

60 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training state management."""

import os

import tensorflow.compat.v2 as tf

from keras.src import backend
from keras.src.distribute import distributed_file_utils
from keras.src.utils import mode_keys

# isort: off
from keras.src.distribute.distributed_file_utils import (
    support_on_demand_checkpoint_callback,
)  # noqa: E501


MAX_CHECKPOINT_TO_KEEP = 1


class WorkerTrainingState:
    """Training state management class.

    This class provides APIs for backing up and restoring the training
    state, which allows the model, epoch, and batch information to be saved
    periodically and restored for fault tolerance (also known as preemption
    recovery).
    """

    # Constants for the `tf.keras.Model` attributes that store the epoch and
    # batch at which the most recently saved checkpoint was saved.
    CKPT_SAVED_EPOCH_UNUSED_VALUE = -1

    CKPT_SAVED_BATCH_UNUSED_VALUE = -1

    def __init__(
        self,
        model,
        checkpoint_dir,
        save_freq="epoch",
        save_before_preemption_arg=None,
    ):
        self._enable_save_before_preemption = save_before_preemption_arg and (
            support_on_demand_checkpoint_callback(model.distribute_strategy)
        )
        self._model = model

        self._save_freq = save_freq
        # The batch and epoch at which the checkpoint is saved. Used for
        # fault tolerance. GPU devices only have an int64-dtype VarHandleOp
        # registered.
        self._ckpt_saved_epoch = tf.Variable(
            initial_value=tf.constant(
                self.CKPT_SAVED_EPOCH_UNUSED_VALUE, dtype=tf.int64
            ),
            name="ckpt_saved_epoch",
        )
        self._ckpt_saved_batch = tf.Variable(
            initial_value=tf.constant(
                self.CKPT_SAVED_BATCH_UNUSED_VALUE, dtype=tf.int64
            ),
            name="ckpt_saved_batch",
        )
        # Variable initialization.
        backend.set_value(
            self._ckpt_saved_epoch, self.CKPT_SAVED_EPOCH_UNUSED_VALUE
        )
        backend.set_value(
            self._ckpt_saved_batch, self.CKPT_SAVED_BATCH_UNUSED_VALUE
        )
        # _ckpt_saved_epoch and _ckpt_saved_batch get tracked and are included
        # in the checkpoint file when backing up.
        checkpoint = tf.train.Checkpoint(
            model=self._model,
            ckpt_saved_epoch=self._ckpt_saved_epoch,
            ckpt_saved_batch=self._ckpt_saved_batch,
            train_counter=self._model._train_counter,
        )

        # If this is single-worker training, checkpoint_dir is the same for
        # write_checkpoint_manager and read_checkpoint_manager.
        #
        # If this is multi-worker training and this worker should not save
        # checkpoints, we replace the write_checkpoint_manager's
        # checkpoint_dir with a temp filepath, so it writes to a file that
        # will be removed at the end of the back_up() call. This is necessary
        # because the SyncOnReadVariable needs to be synced across all the
        # workers in order to be read, and all workers need to perform
        # `save()`. But all workers should restore from the same
        # checkpoint_dir as passed in read_checkpoint_manager.
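        # For illustration (assuming the temp-dir naming convention used by
        # distributed_file_utils.write_dirpath): a non-chief worker may end
        # up writing to a path like "<checkpoint_dir>/workertemp_1", which
        # back_up() then removes via remove_temp_dirpath().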

        self.read_checkpoint_manager = tf.train.CheckpointManager(
            checkpoint,
            directory=os.path.join(checkpoint_dir, "chief"),
            max_to_keep=MAX_CHECKPOINT_TO_KEEP,
        )
        write_checkpoint_dir = distributed_file_utils.write_dirpath(
            checkpoint_dir, self._model.distribute_strategy
        )
        if self._model.distribute_strategy.extended.should_checkpoint:
            self.write_checkpoint_manager = self.read_checkpoint_manager
        else:
            self.write_checkpoint_manager = tf.train.CheckpointManager(
                checkpoint,
                directory=write_checkpoint_dir,
                max_to_keep=MAX_CHECKPOINT_TO_KEEP,
            )

        if self._enable_save_before_preemption:
            self.preemption_handler = (
                tf.distribute.experimental.PreemptionCheckpointHandler(
                    self._model.distribute_strategy.cluster_resolver,
                    self.write_checkpoint_manager,
                )
            )
            self.preemption_handler._read_checkpoint_manager = (
                self.read_checkpoint_manager
            )
            self._model._preemption_handler = self.preemption_handler


    def back_up(self, epoch, batch=0):
        """Back up the current state of training into a checkpoint file.

        Args:
            epoch: The current epoch information to be saved.
            batch: The current batch (step) information to be saved.
        """
        # Save the model plus the CKPT_SAVED_EPOCH and CKPT_SAVED_BATCH
        # variables.
        if self.write_checkpoint_manager.save():
            distributed_file_utils.remove_temp_dirpath(
                self.write_checkpoint_manager.directory,
                self._model.distribute_strategy,
            )


    def backup_if_preempted(self):
        if self._enable_save_before_preemption:
            self.preemption_handler._run_counter += 1
            self.preemption_handler._check_preemption_and_maybe_checkpoint()


    def restore(self):
        """Restore the training state from the backed-up checkpoint file.

        Returns:
            True if the training state is successfully restored. False if the
            training state doesn't need to be restored, or an error occurred
            so it can't be.
        """
        # When creating the PreemptionCheckpointHandler object, we have
        # already restored the checkpoint.
        if not self._enable_save_before_preemption:
            self.read_checkpoint_manager.restore_or_initialize()


    def delete_backup(self):
        """Delete the backup directories.

        Delete the backup directories which should not exist after `fit()`
        successfully finishes.
        """
        if self.write_checkpoint_manager is self.read_checkpoint_manager:
            try:
                tf.io.gfile.rmtree(self.write_checkpoint_manager.directory)
            except tf.errors.NotFoundError:
                pass


    def maybe_load_initial_counters_from_ckpt(
        self, steps_per_epoch, initial_epoch, mode
    ):
        """Maybe load 1st epoch from checkpoint, considering worker recovery.

        When the `_ckpt_saved_epoch` attribute exists and is not
        `CKPT_SAVED_EPOCH_UNUSED_VALUE`, this is a multi-worker training
        setting and the worker is recovering from a previous failure. In
        this case, infer `initial_epoch` from `self._ckpt_saved_epoch` so
        that the unfinished training can continue from the right epoch.

        Args:
            steps_per_epoch: The number of steps per epoch.
            initial_epoch: The original `initial_epoch` the user passed to
                `fit()`.
            mode: The mode in which `model.fit()` is running.

        Returns:
            If the training is recovering from a previous failure under a
            multi-worker training setting, the `(epoch, step)` at which
            training should continue. Otherwise, the `(initial_epoch,
            initial_step)` the user passed in.
        """

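        # Worked example (illustrative): with steps_per_epoch=100 and
        # batch-level saving, a checkpoint taken at (epoch=2, batch=99)
        # resumes at (3, 0), while one taken at (epoch=2, batch=57) resumes
        # at (2, 58). With epoch-level saving, a checkpoint taken at epoch=2
        # resumes at epoch 3.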

        initial_step = 0
        epoch = backend.eval(self._ckpt_saved_epoch)
        batch = backend.eval(self._ckpt_saved_batch)
        if mode == mode_keys.ModeKeys.TRAIN:
            # For batch-level saving.
            if self._enable_save_before_preemption or isinstance(
                self._save_freq, int
            ):
                if batch >= 0:
                    # If the checkpoint was last saved at the last batch of
                    # the epoch, return the next epoch number and batch=0.
                    if batch == steps_per_epoch - 1:
                        initial_epoch = epoch + 1
                        initial_step = 0
                    else:
                        # If the checkpoint was not last saved at the last
                        # batch of the epoch, return the same epoch and the
                        # next batch number.
                        initial_epoch = epoch
                        initial_step = batch + 1
            else:
                if epoch >= 0:
                    # The most recently saved epoch is one epoch prior to the
                    # epoch it failed at, so return the value of
                    # `self._ckpt_saved_epoch` plus one.
                    initial_epoch = epoch + 1

        return (initial_epoch, initial_step)
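

# --- Illustrative usage sketch (not part of the Keras source) ---------------
# A minimal sketch of how this class is typically driven, e.g. by a
# backup-and-restore style callback: construct it with a compiled model and a
# backup directory, restore any previous state, compute where to resume, back
# up during training, and delete the backup once `fit()` succeeds. The tiny
# model, the temporary directory, and steps_per_epoch=100 below are arbitrary
# example values chosen for the sketch.
if __name__ == "__main__":
    import tempfile

    demo_model = tf.keras.Sequential(
        [tf.keras.layers.Dense(1, input_shape=(4,))]
    )
    demo_model.compile(optimizer="sgd", loss="mse")

    backup_dir = tempfile.mkdtemp()
    state = WorkerTrainingState(demo_model, backup_dir, save_freq="epoch")

    # Restore previously backed-up state, if any (a no-op on a fresh run),
    # then work out where training should resume.
    state.restore()
    initial_epoch, initial_step = state.maybe_load_initial_counters_from_ckpt(
        steps_per_epoch=100, initial_epoch=0, mode=mode_keys.ModeKeys.TRAIN
    )

    # ... training would run from (initial_epoch, initial_step), calling
    # back_up() at the configured frequency ...
    state.back_up(epoch=initial_epoch, batch=initial_step)

    # Once training finishes successfully, the backup is no longer needed.
    state.delete_backup()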