Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/distribute/worker_training_state.py: 48%

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training state management."""

import os
from tensorflow.python.checkpoint import checkpoint as trackable_util
from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.keras import backend
from tensorflow.python.keras.distribute import distributed_file_utils
from tensorflow.python.keras.utils import mode_keys
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables

# Constant for `tf.keras.Model` attribute to store the epoch at which the most
# recently saved checkpoint was saved.
CKPT_SAVED_EPOCH = '_ckpt_saved_epoch'

CKPT_SAVED_EPOCH_UNUSED_VALUE = -1

class WorkerTrainingState(object):
  """Training state management class.

  This class provides APIs for backing up and restoring the training state,
  so that model and epoch information can be saved periodically and restored
  for fault tolerance (also known as preemption recovery).
  """

  def __init__(self, model, checkpoint_dir):
    self._model = model

    # The epoch at which the checkpoint is saved. Used for fault-tolerance.
    # GPU devices only have an int64 VarHandleOp kernel registered, hence the
    # int64 dtype here.
    self._ckpt_saved_epoch = variables.Variable(
        initial_value=constant_op.constant(
            CKPT_SAVED_EPOCH_UNUSED_VALUE, dtype=dtypes.int64),
        name='ckpt_saved_epoch')

    # Variable initialization.
    backend.set_value(self._ckpt_saved_epoch, CKPT_SAVED_EPOCH_UNUSED_VALUE)

    # _ckpt_saved_epoch gets tracked and is included in the checkpoint file
    # when backing up.
    checkpoint = trackable_util.Checkpoint(
        model=self._model, ckpt_saved_epoch=self._ckpt_saved_epoch)

    # If this is single-worker training, checkpoint_dir is the same for
    # write_checkpoint_manager and read_checkpoint_manager.
    #
    # If this is multi-worker training and this worker should not save
    # checkpoints, we replace the write_checkpoint_manager's checkpoint_dir
    # with a temp filepath, so it writes to a file that will be removed at
    # the end of the back_up() call. This is necessary because the
    # SyncOnReadVariable needs to be synced across all the workers in order
    # to be read, and all workers need to perform `save()`.
    # But all workers should restore from the same checkpoint_dir as passed
    # to read_checkpoint_manager.
    self.read_checkpoint_manager = checkpoint_management.CheckpointManager(
        checkpoint,
        directory=os.path.join(checkpoint_dir, 'chief'),
        max_to_keep=1)
    write_checkpoint_dir = distributed_file_utils.write_dirpath(
        checkpoint_dir, self._model.distribute_strategy)
    if self._model.distribute_strategy.extended.should_checkpoint:
      self.write_checkpoint_manager = self.read_checkpoint_manager
    else:
      self.write_checkpoint_manager = checkpoint_management.CheckpointManager(
          checkpoint, directory=write_checkpoint_dir, max_to_keep=1)

  def back_up(self, epoch):
    """Back up the current state of training into a checkpoint file.

    Args:
      epoch: The current epoch information to be saved.
    """
    backend.set_value(self._ckpt_saved_epoch, epoch)
    # Save the model plus CKPT_SAVED_EPOCH variable.
    if self.write_checkpoint_manager.save():
      distributed_file_utils.remove_temp_dirpath(
          self.write_checkpoint_manager.directory,
          self._model.distribute_strategy)

  def restore(self):
    """Restore the training state from the backed up checkpoint file.

    Restores the model and `_ckpt_saved_epoch` from the most recent checkpoint
    in the read checkpoint directory, if one exists.
    """

    self.read_checkpoint_manager.restore_or_initialize()

  def delete_backup(self):
    """Delete the backup directories.

    Delete the backup directories which should not exist after `fit()`
    successfully finishes.
    """
    if self.write_checkpoint_manager is self.read_checkpoint_manager:
      try:
        file_io.delete_recursively_v2(self.write_checkpoint_manager.directory)
      except errors.NotFoundError:
        pass

  def maybe_load_initial_epoch_from_ckpt(self, initial_epoch, mode):
    """Maybe load initial epoch from ckpt considering possible worker recovery.

    When the `_ckpt_saved_epoch` attribute exists and is not
    `CKPT_SAVED_EPOCH_UNUSED_VALUE`, this is a multi-worker training setting
    and the worker is recovering from a previous failure. In this case, infer
    `initial_epoch` from `self._ckpt_saved_epoch` to continue the previously
    unfinished training from that epoch.

    Args:
      initial_epoch: The original initial_epoch the user passes in to `fit()`.
      mode: The mode for running `model.fit()`.

    Returns:
      If the training is recovering from a previous failure under a
      multi-worker training setting, the epoch at which training should
      continue. Otherwise, the `initial_epoch` the user passed in.
    """

    epoch = backend.eval(self._ckpt_saved_epoch)
    if mode == mode_keys.ModeKeys.TRAIN and epoch >= 0:
      # The most recently saved epoch is one epoch prior to the epoch it
      # failed at, so return the value of 'self._ckpt_saved_epoch' plus one.
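      # For example, a backup taken at the end of epoch 3 means epochs 0-3 are
      # complete, so training resumes at epoch 4.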

      return epoch + 1
    return initial_epoch
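

# -----------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: a minimal example of
# how a training loop could drive `WorkerTrainingState` for fault tolerance.
# It roughly mirrors what the Keras `BackupAndRestore` callback does around
# `fit()`; the `_example_training_loop` helper and its arguments are
# hypothetical.
def _example_training_loop(model, checkpoint_dir, num_epochs,
                           train_one_epoch):
  training_state = WorkerTrainingState(model, checkpoint_dir)
  # Restore the model and `_ckpt_saved_epoch` from a previous backup, if one
  # exists under `checkpoint_dir`.
  training_state.restore()
  # If recovering from a preemption, resume from the epoch after the last
  # backed-up one; otherwise start at the requested initial epoch (0 here).
  initial_epoch = training_state.maybe_load_initial_epoch_from_ckpt(
      initial_epoch=0, mode=mode_keys.ModeKeys.TRAIN)
  for epoch in range(initial_epoch, num_epochs):
    train_one_epoch(model, epoch)
    # Back up the model weights and the epoch number after every epoch.
    training_state.back_up(epoch)
  # Training finished successfully, so the backup directory is not needed.
  training_state.delete_backup()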