Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/distribute/worker_training_state.py: 28% (60 statements)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training state management."""

import os

import tensorflow.compat.v2 as tf

from keras.src import backend
from keras.src.distribute import distributed_file_utils
from keras.src.utils import mode_keys

# isort: off
from keras.src.distribute.distributed_file_utils import (
    support_on_demand_checkpoint_callback,
)  # noqa: E501


MAX_CHECKPOINT_TO_KEEP = 1


class WorkerTrainingState:
    """Training state management class.

    This class provides APIs for backing up and restoring the training state.
    It allows model, epoch, and batch information to be saved periodically and
    restored for fault tolerance, also known as preemption recovery.
    """

    # Constant for `tf.keras.Model` attribute to store the epoch and batch
    # at which the most recently saved checkpoint was saved.
    CKPT_SAVED_EPOCH_UNUSED_VALUE = -1

    CKPT_SAVED_BATCH_UNUSED_VALUE = -1

    def __init__(
        self,
        model,
        checkpoint_dir,
        save_freq="epoch",
        save_before_preemption_arg=None,
    ):
        self._enable_save_before_preemption = save_before_preemption_arg and (
            support_on_demand_checkpoint_callback(model.distribute_strategy)
        )
        self._model = model

        self._save_freq = save_freq
        # The batch and epoch at which the checkpoint is saved. Used for
        # fault-tolerance. GPU device only has int64 dtype registered
        # VarHandleOp.
        self._ckpt_saved_epoch = tf.Variable(
            initial_value=tf.constant(
                self.CKPT_SAVED_EPOCH_UNUSED_VALUE, dtype=tf.int64
            ),
            name="ckpt_saved_epoch",
        )
        self._ckpt_saved_batch = tf.Variable(
            initial_value=tf.constant(
                self.CKPT_SAVED_BATCH_UNUSED_VALUE, dtype=tf.int64
            ),
            name="ckpt_saved_batch",
        )
        # Variable initialization.
        backend.set_value(
            self._ckpt_saved_epoch, self.CKPT_SAVED_EPOCH_UNUSED_VALUE
        )
        backend.set_value(
            self._ckpt_saved_batch, self.CKPT_SAVED_BATCH_UNUSED_VALUE
        )
        # _ckpt_saved_epoch and _ckpt_saved_batch are tracked and included
        # in the checkpoint file when backing up.
        checkpoint = tf.train.Checkpoint(
            model=self._model,
            ckpt_saved_epoch=self._ckpt_saved_epoch,
            ckpt_saved_batch=self._ckpt_saved_batch,
            train_counter=self._model._train_counter,
        )

        # If this is single-worker training, checkpoint_dir is the same for
        # write_checkpoint_manager and read_checkpoint_manager.
        #
        # If this is multi-worker training, and this worker should not save
        # checkpoints, we replace the write_checkpoint_manager's
        # checkpoint_dir with a temp filepath, so it writes to a file that
        # will be removed at the end of the back_up() call. This is necessary
        # because the SyncOnReadVariable needs to be synced across all the
        # workers in order to be read, and all workers need to perform
        # `save()`. But all workers should restore from the same
        # checkpoint_dir as passed in read_checkpoint_manager.
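        # For illustration (an assumption about the helper, not stated here):
        # in a typical multi-worker run the chief reads and writes under
        # `<checkpoint_dir>/chief`, while a non-chief worker writes to a
        # worker-specific temp path returned by
        # distributed_file_utils.write_dirpath (the exact directory name is
        # an implementation detail of that helper), which back_up() removes.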
        self.read_checkpoint_manager = tf.train.CheckpointManager(
            checkpoint,
            directory=os.path.join(checkpoint_dir, "chief"),
            max_to_keep=MAX_CHECKPOINT_TO_KEEP,
        )
        write_checkpoint_dir = distributed_file_utils.write_dirpath(
            checkpoint_dir, self._model.distribute_strategy
        )
        if self._model.distribute_strategy.extended.should_checkpoint:
            self.write_checkpoint_manager = self.read_checkpoint_manager
        else:
            self.write_checkpoint_manager = tf.train.CheckpointManager(
                checkpoint,
                directory=write_checkpoint_dir,
                max_to_keep=MAX_CHECKPOINT_TO_KEEP,
            )

        if self._enable_save_before_preemption:
            self.preemption_handler = (
                tf.distribute.experimental.PreemptionCheckpointHandler(
                    self._model.distribute_strategy.cluster_resolver,
                    self.write_checkpoint_manager,
                )
            )
            self.preemption_handler._read_checkpoint_manager = (
                self.read_checkpoint_manager
            )
            self._model._preemption_handler = self.preemption_handler

    def back_up(self, epoch, batch=0):
        """Back up the current state of training into a checkpoint file.

        Args:
            epoch: The current epoch information to be saved.
            batch: The current batch (step) information to be saved.
        """
        # Save the model plus the CKPT_SAVED_EPOCH and CKPT_SAVED_BATCH
        # variables.
        if self.write_checkpoint_manager.save():
            distributed_file_utils.remove_temp_dirpath(
                self.write_checkpoint_manager.directory,
                self._model.distribute_strategy,
            )

    def backup_if_preempted(self):
        if self._enable_save_before_preemption:
            self.preemption_handler._run_counter += 1
            self.preemption_handler._check_preemption_and_maybe_checkpoint()

    def restore(self):
        """Restore the training state from the backed-up checkpoint file.

        Returns:
            True if the training state is successfully restored. False if the
            training state doesn't need to be restored, or an error occurred
            so it can't be.
        """
        # When creating the PreemptionCheckpointHandler object, we have
        # already restored the checkpoint.
        if not self._enable_save_before_preemption:
            self.read_checkpoint_manager.restore_or_initialize()

    def delete_backup(self):
        """Delete the backup directories.

        Delete the backup directories which should not exist after `fit()`
        successfully finishes.
        """
        if self.write_checkpoint_manager is self.read_checkpoint_manager:
            try:
                tf.io.gfile.rmtree(self.write_checkpoint_manager.directory)
            except tf.errors.NotFoundError:
                pass

    def maybe_load_initial_counters_from_ckpt(
        self, steps_per_epoch, initial_epoch, mode
    ):
        """Maybe load the initial epoch from a checkpoint, for worker recovery.

        When the `_ckpt_saved_epoch` attribute exists and is not
        `CKPT_SAVED_EPOCH_UNUSED_VALUE`, this is a multi-worker training
        setting and the worker is recovering from a previous failure. In this
        case, infer `initial_epoch` from `self._ckpt_saved_epoch` to continue
        the previously unfinished training from that epoch.

        Args:
            steps_per_epoch: The number of steps per epoch.
            initial_epoch: The original `initial_epoch` the user passes in to
                `fit()`.
            mode: The mode for running `model.fit()`.

        Returns:
            If the training is recovering from a previous failure under a
            multi-worker training setting, return the (epoch, step) at which
            training is supposed to continue. Otherwise, return the
            `initial_epoch, initial_step` the user passes in.
        """

        initial_step = 0
        epoch = backend.eval(self._ckpt_saved_epoch)
        batch = backend.eval(self._ckpt_saved_batch)
        if mode == mode_keys.ModeKeys.TRAIN:
            # For batch-level saving
            if self._enable_save_before_preemption or isinstance(
                self._save_freq, int
            ):
                if batch >= 0:
                    # If the checkpoint was last saved at the last batch of
                    # the epoch, return the next epoch number and batch=0.
                    if batch == steps_per_epoch - 1:
                        initial_epoch = epoch + 1
                        initial_step = 0
                    else:
                        # If the checkpoint was not last saved at the last
                        # batch of the epoch, return the same epoch and the
                        # next batch number.
                        initial_epoch = epoch
                        initial_step = batch + 1
            else:
                if epoch >= 0:
                    # The most recently saved epoch is one epoch prior to the
                    # epoch it failed at, so return the value of
                    # `self._ckpt_saved_epoch` plus one.
                    initial_epoch = epoch + 1

        return (initial_epoch, initial_step)
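
For context, this class is not called directly by users; it is driven by Keras' fault-tolerance machinery (for example, the `BackupAndRestore` callback). Below is a minimal sketch of the intended lifecycle, assuming a compiled `tf.keras` model named `model` and a writable backup directory; the surrounding loop and the chosen values (100 steps per epoch, 5 epochs, `/tmp/backup`) are illustrative, not part of this module:

    from keras.src.distribute.worker_training_state import WorkerTrainingState
    from keras.src.utils import mode_keys

    state = WorkerTrainingState(model, checkpoint_dir="/tmp/backup")
    state.restore()  # No-op here if there is no checkpoint to restore from.

    # Work out where to resume after a restart.
    initial_epoch, initial_step = state.maybe_load_initial_counters_from_ckpt(
        steps_per_epoch=100, initial_epoch=0, mode=mode_keys.ModeKeys.TRAIN
    )

    for epoch in range(initial_epoch, 5):
        ...  # run the training steps for this epoch, starting at initial_step
        state.back_up(epoch)  # checkpoint the model plus epoch/batch counters

    state.delete_backup()  # remove the backup once training finishes cleanly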