Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/callbacks.py: 20%

1227 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16 

17"""Callbacks: utilities called at certain points during model training.""" 

18 

19import collections 

20import copy 

21import csv 

22import json 

23import os 

24import re 

25import sys 

26import time 

27 

28import numpy as np 

29import tensorflow.compat.v2 as tf 

30 

31from keras.src import backend 

32from keras.src.distribute import distributed_file_utils 

33from keras.src.distribute import worker_training_state 

34from keras.src.optimizers import optimizer 

35from keras.src.optimizers.schedules import learning_rate_schedule 

36from keras.src.utils import generic_utils 

37from keras.src.utils import io_utils 

38from keras.src.utils import tf_utils 

39from keras.src.utils import version_utils 

40from keras.src.utils.data_utils import Sequence 

41from keras.src.utils.generic_utils import Progbar 

42from keras.src.utils.mode_keys import ModeKeys 

43 

44# isort: off 

45from tensorflow.python.platform import tf_logging as logging 

46from tensorflow.python.util import deprecation 

47from tensorflow.python.util.tf_export import keras_export 

48from tensorflow.tools.docs import doc_controls 

49 

# `requests` is an optional dependency; callbacks that need HTTP support
# check this module-level name for `None` before use — presumably the
# RemoteMonitor callback further down this file; TODO confirm at use site.
try:
    import requests
except ImportError:
    requests = None

54 

55 

56# Note: `configure_callbacks` is only used in TF1. 

# Note: `configure_callbacks` is only used in TF1.
def configure_callbacks(
    callbacks,
    model,
    do_validation=False,
    batch_size=None,
    epochs=None,
    steps_per_epoch=None,
    samples=None,
    verbose=1,
    count_mode="steps",
    mode=ModeKeys.TRAIN,
):
    """Configures callbacks for use in various training loops.

    Args:
        callbacks: List of Callbacks.
        model: Model being trained.
        do_validation: Whether or not validation loop will be run.
        batch_size: Number of samples per batch.
        epochs: Number of epoch to train.
        steps_per_epoch: Number of batches to run per training epoch.
        samples: Number of training samples.
        verbose: int, 0 or 1. Keras logging verbosity to pass to ProgbarLogger.
        count_mode: One of 'steps' or 'samples'. Per-batch or per-sample count.
        mode: String. One of ModeKeys.TRAIN, ModeKeys.TEST, or ModeKeys.PREDICT.
            Which loop mode to configure callbacks for.

    Returns:
        Instance of CallbackList used to control all Callbacks.
    """
    # Already wrapped: nothing to configure.
    if isinstance(callbacks, CallbackList):
        return callbacks

    user_callbacks = list(callbacks) if callbacks else []

    # Training mode gets the implicit bookkeeping callbacks: a BaseLogger in
    # front, the model's History at the back, and optionally a progbar.
    if mode == ModeKeys.TRAIN:
        model.history = History()
        user_callbacks = [BaseLogger()] + user_callbacks + [model.history]
        if verbose:
            user_callbacks.append(ProgbarLogger(count_mode))

    cb_list = CallbackList(user_callbacks)

    # Attach the (possibly wrapped) callback model to every callback.
    cb_list.set_model(model._get_callback_model())

    set_callback_parameters(
        cb_list,
        model,
        do_validation=do_validation,
        batch_size=batch_size,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        samples=samples,
        verbose=verbose,
        mode=mode,
    )

    cb_list.model.stop_training = False
    return cb_list

120 

121 

def set_callback_parameters(
    callback_list,
    model,
    do_validation=False,
    batch_size=None,
    epochs=None,
    steps_per_epoch=None,
    samples=None,
    verbose=1,
    mode=ModeKeys.TRAIN,
):
    """Sets callback parameters.

    Args:
        callback_list: CallbackList instance.
        model: Model being trained.
        do_validation: Whether or not validation loop will be run.
        batch_size: Number of samples per batch.
        epochs: Number of epoch to train.
        steps_per_epoch: Number of batches to run per training epoch.
        samples: Number of training samples.
        verbose: int, 0 or 1. Keras logging verbosity to pass to ProgbarLogger.
        mode: String. One of ModeKeys.TRAIN, ModeKeys.TEST, or ModeKeys.PREDICT.
            Which loop mode to configure callbacks for.
    """
    # `model.metrics_names` is fetched lazily (at most once) because in a
    # deferred-build scenario it may not be populated until first use.
    metric_names = None
    for callback in callback_list:
        if not isinstance(callback, (BaseLogger, ProgbarLogger)):
            continue
        if not metric_names:
            metric_names = model.metrics_names
        # First entry is `loss`; stateful metrics exclude it.
        callback.stateful_metrics = metric_names[1:]

    # Metrics reported to callbacks. When we have deferred build scenario
    # with iterator input, we will compile when we standardize first batch
    # of data. PREDICT mode has no metrics.
    callback_metrics = []
    if mode != ModeKeys.PREDICT:
        if not metric_names:
            metric_names = model.metrics_names
        callback_metrics = copy.copy(metric_names)
        if do_validation:
            callback_metrics += ["val_" + n for n in metric_names]

    callback_list.set_params(
        {
            "batch_size": batch_size,
            "epochs": epochs,
            "steps": steps_per_epoch,
            "samples": samples,
            "verbose": verbose,
            "do_validation": do_validation,
            "metrics": callback_metrics,
        }
    )

174 

175 

def _is_generator_like(data):
    """Checks if data is a generator, Sequence, or Iterator."""
    # Anything exposing the iterator protocol (py3 `__next__` or legacy
    # `next`) counts as generator-like.
    if hasattr(data, "__next__") or hasattr(data, "next"):
        return True
    iterator_types = (Sequence, tf.compat.v1.data.Iterator, tf.data.Iterator)
    return isinstance(data, iterator_types)

185 

186 

def make_logs(model, logs, outputs, mode, prefix=""):
    """Computes logs for sending to `on_batch_end` methods."""
    metric_names = model.metrics_names
    # TRAIN/TEST outputs line up positionally with the model's metric names;
    # PREDICT (or a metric-less model) just forwards the raw outputs.
    if metric_names and mode in (ModeKeys.TRAIN, ModeKeys.TEST):
        for name, value in zip(metric_names, outputs):
            logs[prefix + name] = value
    else:
        logs["outputs"] = outputs
    return logs

196 

197 

@keras_export("keras.callbacks.CallbackList")
class CallbackList:
    """Container abstracting a list of callbacks."""

    def __init__(
        self,
        callbacks=None,
        add_history=False,
        add_progbar=False,
        model=None,
        **params,
    ):
        """Container for `Callback` instances.

        This object wraps a list of `Callback` instances, making it possible
        to call them all at once via a single endpoint
        (e.g. `callback_list.on_epoch_end(...)`).

        Args:
            callbacks: List of `Callback` instances.
            add_history: Whether a `History` callback should be added, if one
                does not already exist in the `callbacks` list.
            add_progbar: Whether a `ProgbarLogger` callback should be added,
                if one does not already exist in the `callbacks` list.
            model: The `Model` these callbacks are used with.
            **params: If provided, parameters will be passed to each `Callback`
                via `Callback.set_params`.
        """
        self.callbacks = tf.nest.flatten(callbacks) if callbacks else []
        self._add_default_callbacks(add_history, add_progbar)

        if model:
            self.set_model(model)
        if params:
            self.set_params(params)

        # Performance optimization: determines if batch hooks need to be
        # called.

        # If every callback can consume symbolic tf logs, epoch-level hooks
        # can skip the numpy conversion in `_process_logs`.
        self._supports_tf_logs = all(
            getattr(cb, "_supports_tf_logs", False) for cb in self.callbacks
        )
        # Same question restricted to callbacks that actually implement
        # batch-level hooks (so batch hooks can skip conversion even when
        # some epoch-only callback cannot).
        self._batch_hooks_support_tf_logs = all(
            getattr(cb, "_supports_tf_logs", False)
            for cb in self.callbacks
            if cb._implements_train_batch_hooks()
            or cb._implements_test_batch_hooks()
            or cb._implements_predict_batch_hooks()
        )

        self._should_call_train_batch_hooks = any(
            cb._implements_train_batch_hooks() for cb in self.callbacks
        )
        self._should_call_test_batch_hooks = any(
            cb._implements_test_batch_hooks() for cb in self.callbacks
        )
        self._should_call_predict_batch_hooks = any(
            cb._implements_predict_batch_hooks() for cb in self.callbacks
        )

        self._disallow_batch_hooks_in_ps_strategy()

        # Performance check: Check batch hooks for slowness compared to batch
        # time. Only run check for custom callbacks (i.e. not present in this
        # file).
        self._check_timing = any(
            cbk.__class__.__name__ not in globals() for cbk in self.callbacks
        )
        self._num_batches_for_timing_check = 5
        self._hook_times = {}
        self._batch_start_time = None
        self._batch_times = []

    def _add_default_callbacks(self, add_history, add_progbar):
        """Adds `Callback`s that are always present."""
        self._progbar = None
        self._history = None

        for cb in self.callbacks:
            if isinstance(cb, ProgbarLogger):
                self._progbar = cb
            elif isinstance(cb, History):
                self._history = cb

        if self._history is None and add_history:
            self._history = History()
            self.callbacks.append(self._history)

        if self._progbar is None and add_progbar:
            self._progbar = ProgbarLogger(count_mode="steps")
            self.callbacks.append(self._progbar)

    def _process_logs(self, logs, is_batch_hook=False):
        """Turns tensors into numpy arrays or Python scalars if necessary."""
        if logs is None:
            return {}
        if self._supports_tf_logs:
            return logs
        if is_batch_hook and self._batch_hooks_support_tf_logs:
            return logs
        return tf_utils.sync_to_numpy_or_python_type(logs)

    def append(self, callback):
        # Note: appending after construction bypasses the cached
        # `_should_call_*`/`_supports_tf_logs` flags computed in `__init__`.
        self.callbacks.append(callback)

    def set_params(self, params):
        self.params = params
        for callback in self.callbacks:
            callback.set_params(params)

    def set_model(self, model):
        self.model = model
        if self._history:
            model.history = self._history
        for callback in self.callbacks:
            callback.set_model(model)

    def _call_batch_hook(self, mode, hook, batch, logs=None):
        """Helper function for all batch_{begin | end} methods."""
        if not self.callbacks:
            return

        if hook == "begin":
            self._call_batch_begin_hook(mode, batch, logs)
        elif hook == "end":
            self._call_batch_end_hook(mode, batch, logs)
        else:
            raise ValueError(
                f"Unrecognized hook: {hook}. "
                'Expected values are ["begin", "end"]'
            )

    def _call_batch_begin_hook(self, mode, batch, logs):
        """Helper function for `on_*_batch_begin` methods."""
        hook_name = f"on_{mode}_batch_begin"
        self._call_batch_hook_helper(hook_name, batch, logs)

        if self._check_timing:
            self._batch_start_time = time.time()

    def _call_batch_end_hook(self, mode, batch, logs):
        """Helper function for `on_*_batch_end` methods."""
        hook_name = f"on_{mode}_batch_end"

        # Batch 0 is excluded from the timing sample (`batch >= 1`) —
        # presumably because the first batch can include one-off overhead;
        # TODO confirm.
        if self._check_timing and batch >= 1:
            batch_time = time.time() - self._batch_start_time
            self._batch_times.append(batch_time)

        self._call_batch_hook_helper(hook_name, batch, logs)

        # Once enough batches are sampled, warn if either hook averages
        # slower than the batch itself, then disable further timing.
        if len(self._batch_times) >= self._num_batches_for_timing_check:
            end_hook_name = hook_name
            begin_hook_name = f"on_{mode}_batch_begin"
            avg_batch_time = sum(self._batch_times) / len(self._batch_times)
            avg_end_hook_time = sum(self._hook_times[end_hook_name]) / len(
                self._hook_times[end_hook_name]
            )
            avg_begin_hook_time = sum(self._hook_times[begin_hook_name]) / len(
                self._hook_times[begin_hook_name]
            )

            # Threshold: a hook is "slow" if it takes longer than one batch.
            threshold_time = 1.0 * avg_batch_time
            warning_msg = (
                "Callback method `{hook}` is slow compared to "
                "the batch time (batch time: {batch_time:.4f}s vs "
                "`{hook}` time: {hook_time:.4f}s). Check your callbacks."
            )
            if avg_begin_hook_time > threshold_time:
                logging.warning(
                    warning_msg.format(
                        hook=begin_hook_name,
                        batch_time=avg_batch_time,
                        hook_time=avg_begin_hook_time,
                    )
                )
            if avg_end_hook_time > threshold_time:
                logging.warning(
                    warning_msg.format(
                        hook=end_hook_name,
                        batch_time=avg_batch_time,
                        hook_time=avg_end_hook_time,
                    )
                )
            self._check_timing = False
            self._batch_start_time = None
            self._batch_times = []
            self._hook_times = {}

    def _call_batch_hook_helper(self, hook_name, batch, logs):
        """Helper function for `on_*_batch_*` methods."""
        if self._check_timing:
            start_time = time.time()

        logs = self._process_logs(logs, is_batch_hook=True)
        for callback in self.callbacks:
            hook = getattr(callback, hook_name)
            hook(batch, logs)

        if self._check_timing:
            if hook_name not in self._hook_times:
                self._hook_times[hook_name] = []
            self._hook_times[hook_name].append(time.time() - start_time)

    def _call_begin_hook(self, mode):
        """Helper function for on_{train|test|predict}_begin methods."""
        if mode == ModeKeys.TRAIN:
            self.on_train_begin()
        elif mode == ModeKeys.TEST:
            self.on_test_begin()
        else:
            self.on_predict_begin()

    def _call_end_hook(self, mode):
        """Helper function for on_{train|test|predict}_end methods."""
        if mode == ModeKeys.TRAIN:
            self.on_train_end()
        elif mode == ModeKeys.TEST:
            self.on_test_end()
        else:
            self.on_predict_end()

    def on_batch_begin(self, batch, logs=None):
        # Legacy alias: dispatches to the train-batch hook.
        if self._should_call_train_batch_hooks:
            self._call_batch_hook(ModeKeys.TRAIN, "begin", batch, logs=logs)

    def on_batch_end(self, batch, logs=None):
        # Legacy alias: dispatches to the train-batch hook.
        if self._should_call_train_batch_hooks:
            self._call_batch_hook(ModeKeys.TRAIN, "end", batch, logs=logs)

    def on_epoch_begin(self, epoch, logs=None):
        """Calls the `on_epoch_begin` methods of its callbacks.

        This function should only be called during TRAIN mode.

        Args:
            epoch: Integer, index of epoch.
            logs: Dict. Currently no data is passed to this argument for this
                method but that may change in the future.
        """
        logs = self._process_logs(logs)
        for callback in self.callbacks:
            callback.on_epoch_begin(epoch, logs)

    def on_epoch_end(self, epoch, logs=None):
        """Calls the `on_epoch_end` methods of its callbacks.

        This function should only be called during TRAIN mode.

        Args:
            epoch: Integer, index of epoch.
            logs: Dict, metric results for this training epoch, and for the
                validation epoch if validation is performed. Validation result
                keys are prefixed with `val_`.
        """
        logs = self._process_logs(logs)
        for callback in self.callbacks:
            callback.on_epoch_end(epoch, logs)

    def on_train_batch_begin(self, batch, logs=None):
        """Calls the `on_train_batch_begin` methods of its callbacks.

        Args:
            batch: Integer, index of batch within the current epoch.
            logs: Dict, contains the return value of `model.train_step`.
                Typically, the values of the `Model`'s metrics are returned.
                Example: `{'loss': 0.2, 'accuracy': 0.7}`.
        """
        if self._should_call_train_batch_hooks:
            self._call_batch_hook(ModeKeys.TRAIN, "begin", batch, logs=logs)

    def on_train_batch_end(self, batch, logs=None):
        """Calls the `on_train_batch_end` methods of its callbacks.

        Args:
            batch: Integer, index of batch within the current epoch.
            logs: Dict. Aggregated metric results up until this batch.
        """
        if self._should_call_train_batch_hooks:
            self._call_batch_hook(ModeKeys.TRAIN, "end", batch, logs=logs)

    def on_test_batch_begin(self, batch, logs=None):
        """Calls the `on_test_batch_begin` methods of its callbacks.

        Args:
            batch: Integer, index of batch within the current epoch.
            logs: Dict, contains the return value of `model.test_step`.
                Typically, the values of the `Model`'s metrics are returned.
                Example: `{'loss': 0.2, 'accuracy': 0.7}`.
        """
        if self._should_call_test_batch_hooks:
            self._call_batch_hook(ModeKeys.TEST, "begin", batch, logs=logs)

    def on_test_batch_end(self, batch, logs=None):
        """Calls the `on_test_batch_end` methods of its callbacks.

        Args:
            batch: Integer, index of batch within the current epoch.
            logs: Dict. Aggregated metric results up until this batch.
        """
        if self._should_call_test_batch_hooks:
            self._call_batch_hook(ModeKeys.TEST, "end", batch, logs=logs)

    def on_predict_batch_begin(self, batch, logs=None):
        """Calls the `on_predict_batch_begin` methods of its callbacks.

        Args:
            batch: Integer, index of batch within the current epoch.
            logs: Dict, contains the return value of `model.predict_step`,
                it typically returns a dict with a key 'outputs' containing
                the model's outputs.
        """
        if self._should_call_predict_batch_hooks:
            self._call_batch_hook(ModeKeys.PREDICT, "begin", batch, logs=logs)

    def on_predict_batch_end(self, batch, logs=None):
        """Calls the `on_predict_batch_end` methods of its callbacks.

        Args:
            batch: Integer, index of batch within the current epoch.
            logs: Dict. Aggregated metric results up until this batch.
        """
        if self._should_call_predict_batch_hooks:
            self._call_batch_hook(ModeKeys.PREDICT, "end", batch, logs=logs)

    def on_train_begin(self, logs=None):
        """Calls the `on_train_begin` methods of its callbacks.

        Args:
            logs: Dict. Currently, no data is passed via this argument
                for this method, but that may change in the future.
        """
        logs = self._process_logs(logs)
        for callback in self.callbacks:
            callback.on_train_begin(logs)

    def on_train_end(self, logs=None):
        """Calls the `on_train_end` methods of its callbacks.

        Args:
            logs: Dict. Currently, no data is passed via this argument
                for this method, but that may change in the future.
        """
        logs = self._process_logs(logs)
        for callback in self.callbacks:
            callback.on_train_end(logs)

    def on_test_begin(self, logs=None):
        """Calls the `on_test_begin` methods of its callbacks.

        Args:
            logs: Dict. Currently no data is passed to this argument for this
                method but that may change in the future.
        """
        logs = self._process_logs(logs)
        for callback in self.callbacks:
            callback.on_test_begin(logs)

    def on_test_end(self, logs=None):
        """Calls the `on_test_end` methods of its callbacks.

        Args:
            logs: Dict. Currently, no data is passed via this argument
                for this method, but that may change in the future.
        """
        logs = self._process_logs(logs)
        for callback in self.callbacks:
            callback.on_test_end(logs)

    def on_predict_begin(self, logs=None):
        """Calls the `on_predict_begin` methods of its callbacks.

        Args:
            logs: Dict. Currently no data is passed to this argument for this
                method but that may change in the future.
        """
        logs = self._process_logs(logs)
        for callback in self.callbacks:
            callback.on_predict_begin(logs)

    def on_predict_end(self, logs=None):
        """Calls the `on_predict_end` methods of its callbacks.

        Args:
            logs: Dict. Currently, no data is passed via this argument
                for this method, but that may change in the future.
        """
        logs = self._process_logs(logs)
        for callback in self.callbacks:
            callback.on_predict_end(logs)

    def __iter__(self):
        return iter(self.callbacks)

    def _disallow_batch_hooks_in_ps_strategy(self):
        """Error out if batch-level callbacks are passed with PSStrategy."""

        strategy = tf.distribute.get_strategy()
        if strategy._should_use_with_coordinator:
            unsupported_callbacks = []
            for cb in self.callbacks:
                # These Callbacks can accept RemoteValues directly.
                if getattr(cb, "_supports_tf_logs", False):
                    continue
                if (
                    cb._implements_train_batch_hooks()
                    or cb._implements_test_batch_hooks()
                    or cb._implements_predict_batch_hooks()
                ):
                    unsupported_callbacks.append(cb)
            if unsupported_callbacks:
                raise ValueError(
                    "Batch-level `Callback`s are not supported with "
                    "`ParameterServerStrategy`. Found unsupported "
                    f"callbacks: {unsupported_callbacks}"
                )

    def make_logs(self, model, logs, outputs, mode, prefix=""):
        """Computes logs for sending to `on_batch_end` methods."""
        # Short-circuit: with no callbacks there is nobody to consume the
        # computed logs, so return them unmodified.
        if not self.callbacks:
            return logs

        return make_logs(model, logs, outputs, mode, prefix=prefix)

620 

@keras_export("keras.callbacks.Callback")
class Callback:
    """Abstract base class used to build new callbacks.

    Callbacks can be passed to keras methods such as `fit`, `evaluate`, and
    `predict` in order to hook into the various stages of the model training
    and inference lifecycle.

    To create a custom callback, subclass `keras.callbacks.Callback` and
    override the method associated with the stage of interest. See
    https://www.tensorflow.org/guide/keras/custom_callback for more
    information.

    Example:

    >>> training_finished = False
    >>> class MyCallback(tf.keras.callbacks.Callback):
    ...   def on_train_end(self, logs=None):
    ...     global training_finished
    ...     training_finished = True
    >>> model = tf.keras.Sequential([
    ...     tf.keras.layers.Dense(1, input_shape=(1,))])
    >>> model.compile(loss='mean_squared_error')
    >>> model.fit(tf.constant([[1.0]]), tf.constant([[1.0]]),
    ...           callbacks=[MyCallback()])
    >>> assert training_finished == True

    If you want to use `Callback` objects in a custom training loop:

    1. You should pack all your callbacks into a single `callbacks.CallbackList`
       so they can all be called together.
    2. You will need to manually call all the `on_*` methods at the appropriate
       locations in your loop. Like this:

    Example:
    ```python
    callbacks =  tf.keras.callbacks.CallbackList([...])
    callbacks.append(...)
    callbacks.on_train_begin(...)
    for epoch in range(EPOCHS):
      callbacks.on_epoch_begin(epoch)
      for i, data in dataset.enumerate():
        callbacks.on_train_batch_begin(i)
        batch_logs = model.train_step(data)
        callbacks.on_train_batch_end(i, batch_logs)
      epoch_logs = ...
      callbacks.on_epoch_end(epoch, epoch_logs)
    final_logs=...
    callbacks.on_train_end(final_logs)
    ```

    Attributes:
        params: Dict. Training parameters
            (eg. verbosity, batch size, number of epochs...).
        model: Instance of `keras.models.Model`.
            Reference of the model being trained.

    The `logs` dictionary that callback methods
    take as argument will contain keys for quantities relevant to
    the current batch or epoch (see method-specific docstrings).
    """

    def __init__(self):
        self.validation_data = None
        self.model = None
        # Whether this Callback should only run on the chief worker in a
        # Multi-Worker setting.
        # TODO(omalleyt): Make this attr public once solution is stable.
        self._chief_worker_only = None
        # When True, `CallbackList` skips converting symbolic tf logs to
        # numpy before invoking this callback's hooks.
        self._supports_tf_logs = False

    def set_params(self, params):
        self.params = params

    def set_model(self, model):
        self.model = model

    # NOTE: the `@generic_utils.default` decorator marks a method as "not
    # overridden"; `_implements_*_batch_hooks` below uses that marker so
    # `CallbackList` can skip batch hooks entirely when no subclass
    # overrides them.
    @doc_controls.for_subclass_implementers
    @generic_utils.default
    def on_batch_begin(self, batch, logs=None):
        """A backwards compatibility alias for `on_train_batch_begin`."""

    @doc_controls.for_subclass_implementers
    @generic_utils.default
    def on_batch_end(self, batch, logs=None):
        """A backwards compatibility alias for `on_train_batch_end`."""

    @doc_controls.for_subclass_implementers
    def on_epoch_begin(self, epoch, logs=None):
        """Called at the start of an epoch.

        Subclasses should override for any actions to run. This function should
        only be called during TRAIN mode.

        Args:
            epoch: Integer, index of epoch.
            logs: Dict. Currently no data is passed to this argument for this
                method but that may change in the future.
        """

    @doc_controls.for_subclass_implementers
    def on_epoch_end(self, epoch, logs=None):
        """Called at the end of an epoch.

        Subclasses should override for any actions to run. This function should
        only be called during TRAIN mode.

        Args:
            epoch: Integer, index of epoch.
            logs: Dict, metric results for this training epoch, and for the
                validation epoch if validation is performed. Validation result
                keys are prefixed with `val_`. For training epoch, the values of
                the `Model`'s metrics are returned. Example:
                `{'loss': 0.2, 'accuracy': 0.7}`.
        """

    @doc_controls.for_subclass_implementers
    @generic_utils.default
    def on_train_batch_begin(self, batch, logs=None):
        """Called at the beginning of a training batch in `fit` methods.

        Subclasses should override for any actions to run.

        Note that if the `steps_per_execution` argument to `compile` in
        `tf.keras.Model` is set to `N`, this method will only be called every
        `N` batches.

        Args:
            batch: Integer, index of batch within the current epoch.
            logs: Dict. Currently no data is passed to this argument for this
                method but that may change in the future.
        """
        # For backwards compatibility.
        self.on_batch_begin(batch, logs=logs)

    @doc_controls.for_subclass_implementers
    @generic_utils.default
    def on_train_batch_end(self, batch, logs=None):
        """Called at the end of a training batch in `fit` methods.

        Subclasses should override for any actions to run.

        Note that if the `steps_per_execution` argument to `compile` in
        `tf.keras.Model` is set to `N`, this method will only be called every
        `N` batches.

        Args:
            batch: Integer, index of batch within the current epoch.
            logs: Dict. Aggregated metric results up until this batch.
        """
        # For backwards compatibility.
        self.on_batch_end(batch, logs=logs)

    @doc_controls.for_subclass_implementers
    @generic_utils.default
    def on_test_batch_begin(self, batch, logs=None):
        """Called at the beginning of a batch in `evaluate` methods.

        Also called at the beginning of a validation batch in the `fit`
        methods, if validation data is provided.

        Subclasses should override for any actions to run.

        Note that if the `steps_per_execution` argument to `compile` in
        `tf.keras.Model` is set to `N`, this method will only be called every
        `N` batches.

        Args:
            batch: Integer, index of batch within the current epoch.
            logs: Dict. Currently no data is passed to this argument for this
                method but that may change in the future.
        """

    @doc_controls.for_subclass_implementers
    @generic_utils.default
    def on_test_batch_end(self, batch, logs=None):
        """Called at the end of a batch in `evaluate` methods.

        Also called at the end of a validation batch in the `fit`
        methods, if validation data is provided.

        Subclasses should override for any actions to run.

        Note that if the `steps_per_execution` argument to `compile` in
        `tf.keras.Model` is set to `N`, this method will only be called every
        `N` batches.

        Args:
            batch: Integer, index of batch within the current epoch.
            logs: Dict. Aggregated metric results up until this batch.
        """

    @doc_controls.for_subclass_implementers
    @generic_utils.default
    def on_predict_batch_begin(self, batch, logs=None):
        """Called at the beginning of a batch in `predict` methods.

        Subclasses should override for any actions to run.

        Note that if the `steps_per_execution` argument to `compile` in
        `tf.keras.Model` is set to `N`, this method will only be called every
        `N` batches.

        Args:
            batch: Integer, index of batch within the current epoch.
            logs: Dict. Currently no data is passed to this argument for this
                method but that may change in the future.
        """

    @doc_controls.for_subclass_implementers
    @generic_utils.default
    def on_predict_batch_end(self, batch, logs=None):
        """Called at the end of a batch in `predict` methods.

        Subclasses should override for any actions to run.

        Note that if the `steps_per_execution` argument to `compile` in
        `tf.keras.Model` is set to `N`, this method will only be called every
        `N` batches.

        Args:
            batch: Integer, index of batch within the current epoch.
            logs: Dict. Aggregated metric results up until this batch.
        """

    @doc_controls.for_subclass_implementers
    def on_train_begin(self, logs=None):
        """Called at the beginning of training.

        Subclasses should override for any actions to run.

        Args:
            logs: Dict. Currently no data is passed to this argument for this
                method but that may change in the future.
        """

    @doc_controls.for_subclass_implementers
    def on_train_end(self, logs=None):
        """Called at the end of training.

        Subclasses should override for any actions to run.

        Args:
            logs: Dict. Currently the output of the last call to
                `on_epoch_end()` is passed to this argument for this method but
                that may change in the future.
        """

    @doc_controls.for_subclass_implementers
    def on_test_begin(self, logs=None):
        """Called at the beginning of evaluation or validation.

        Subclasses should override for any actions to run.

        Args:
            logs: Dict. Currently no data is passed to this argument for this
                method but that may change in the future.
        """

    @doc_controls.for_subclass_implementers
    def on_test_end(self, logs=None):
        """Called at the end of evaluation or validation.

        Subclasses should override for any actions to run.

        Args:
            logs: Dict. Currently the output of the last call to
                `on_test_batch_end()` is passed to this argument for this method
                but that may change in the future.
        """

    @doc_controls.for_subclass_implementers
    def on_predict_begin(self, logs=None):
        """Called at the beginning of prediction.

        Subclasses should override for any actions to run.

        Args:
            logs: Dict. Currently no data is passed to this argument for this
                method but that may change in the future.
        """

    @doc_controls.for_subclass_implementers
    def on_predict_end(self, logs=None):
        """Called at the end of prediction.

        Subclasses should override for any actions to run.

        Args:
            logs: Dict. Currently no data is passed to this argument for this
                method but that may change in the future.
        """

    def _implements_train_batch_hooks(self):
        """Determines if this Callback should be called for each train batch."""
        # The legacy `on_batch_*` aliases count as train-batch hooks too.
        return (
            not generic_utils.is_default(self.on_batch_begin)
            or not generic_utils.is_default(self.on_batch_end)
            or not generic_utils.is_default(self.on_train_batch_begin)
            or not generic_utils.is_default(self.on_train_batch_end)
        )

    def _implements_test_batch_hooks(self):
        """Determines if this Callback should be called for each test batch."""
        return not generic_utils.is_default(
            self.on_test_batch_begin
        ) or not generic_utils.is_default(self.on_test_batch_end)

    def _implements_predict_batch_hooks(self):
        """Determines if this Callback should be called for each predict
        batch."""
        return not generic_utils.is_default(
            self.on_predict_batch_begin
        ) or not generic_utils.is_default(self.on_predict_batch_end)

935 

@keras_export("keras.callbacks.BaseLogger")
class BaseLogger(Callback):
    """Callback that accumulates epoch averages of metrics.

    This callback is automatically applied to every Keras model.

    Args:
        stateful_metrics: Iterable of string names of metrics that
            should *not* be averaged over an epoch.
            Metrics in this list will be logged as-is in `on_epoch_end`.
            All others will be averaged in `on_epoch_end`.
    """

    def __init__(self, stateful_metrics=None):
        super().__init__()
        # Names in this set are reported as-is; everything else is averaged.
        self.stateful_metrics = set(stateful_metrics) if stateful_metrics else set()

    def on_epoch_begin(self, epoch, logs=None):
        # Reset the running accumulators at the start of every epoch.
        self.seen = 0
        self.totals = {}

    def on_batch_end(self, batch, logs=None):
        if logs is None:
            logs = {}
        batch_size = logs.get("size", 0)
        # Under a distribution strategy multiple steps may run at once;
        # weight the sample count by the number of steps executed.
        num_steps = logs.get("num_steps", 1)
        self.seen += batch_size * num_steps

        for name, value in logs.items():
            if name in self.stateful_metrics:
                # Stateful metrics carry their own running state: keep the
                # most recent value instead of accumulating.
                self.totals[name] = value
            elif name in self.totals:
                self.totals[name] += value * batch_size
            else:
                self.totals[name] = value * batch_size

    def on_epoch_end(self, epoch, logs=None):
        if logs is None:
            return
        for name in self.params["metrics"]:
            if name not in self.totals:
                continue
            # Make value available to next callbacks.
            if name in self.stateful_metrics:
                logs[name] = self.totals[name]
            else:
                logs[name] = self.totals[name] / self.seen

984 

985 

@keras_export("keras.callbacks.TerminateOnNaN")
class TerminateOnNaN(Callback):
    """Callback that terminates training when a NaN loss is encountered."""

    def __init__(self):
        super().__init__()
        # This callback can consume raw tf tensors in `logs` directly.
        self._supports_tf_logs = True

    def on_batch_end(self, batch, logs=None):
        logs = {} if logs is None else logs
        loss = logs.get("loss")
        if loss is None:
            return
        # Resolve a possibly-async tensor into a concrete numpy/python value
        # before checking it.
        loss = tf_utils.sync_to_numpy_or_python_type(loss)
        if np.isnan(loss) or np.isinf(loss):
            io_utils.print_msg(
                f"Batch {batch}: Invalid loss, terminating training"
            )
            self.model.stop_training = True

1004 

1005 

@keras_export("keras.callbacks.ProgbarLogger")
class ProgbarLogger(Callback):
    """Callback that prints metrics to stdout.

    Args:
        count_mode: One of `"steps"` or `"samples"`.
            Whether the progress bar should
            count samples seen or steps (batches) seen.
        stateful_metrics: Iterable of string names of metrics that
            should *not* be averaged over an epoch.
            Metrics in this list will be logged as-is.
            All others will be averaged over time (e.g. loss, etc).
            If not provided, defaults to the `Model`'s metrics.

    Raises:
        ValueError: In case of invalid `count_mode`.
    """

    def __init__(self, count_mode: str = "samples", stateful_metrics=None):
        super().__init__()
        self._supports_tf_logs = True
        if count_mode == "samples":
            self.use_steps = False
        elif count_mode == "steps":
            self.use_steps = True
        else:
            raise ValueError(
                f"Unknown `count_mode`: {count_mode}. "
                'Expected values are ["samples", "steps"]'
            )
        # Defaults to all Model's metrics except for loss.
        self.stateful_metrics = (
            set(stateful_metrics) if stateful_metrics else set()
        )

        # Progress-bar bookkeeping; (re)initialized per epoch / per run.
        self.seen = 0
        self.progbar = None
        self.target = None
        self.verbose = 1
        self.epochs = 1

        # Step counters grabbed from the model in `set_params`; used to
        # infer the target when `params` carries no steps/samples count.
        self._train_step, self._test_step, self._predict_step = None, None, None
        self._call_batch_hooks = True

        self._called_in_fit = False

    def set_params(self, params):
        """Configures verbosity, epoch count and progbar target from `fit`."""
        self.verbose = params["verbose"]
        self.epochs = params["epochs"]
        if self.use_steps and "steps" in params:
            self.target = params["steps"]
        elif not self.use_steps and "samples" in params:
            self.target = params["samples"]
        else:
            self.target = (
                None  # Will be inferred at the end of the first epoch.
            )

        # Per-batch updates are only needed for the interactive progbar.
        self._call_batch_hooks = self.verbose == 1
        if self.target is None:
            try:
                self._train_step = self.model._train_counter
                self._test_step = self.model._test_counter
                self._predict_step = self.model._predict_counter
            except AttributeError:
                # Model counters unavailable; fall back to counting via the
                # per-batch hooks.
                self._call_batch_hooks = True

    def on_train_begin(self, logs=None):
        # When this logger is called inside `fit`, validation is silent.
        self._called_in_fit = True

    def on_test_begin(self, logs=None):
        # Only show a bar for standalone `evaluate`, not for validation
        # inside `fit` (which manages its own display).
        if not self._called_in_fit:
            self._reset_progbar()
            self._maybe_init_progbar()

    def on_predict_begin(self, logs=None):
        self._reset_progbar()
        self._maybe_init_progbar()

    def on_epoch_begin(self, epoch, logs=None):
        self._reset_progbar()
        self._maybe_init_progbar()
        if self.verbose and self.epochs > 1:
            io_utils.print_msg(f"Epoch {epoch + 1}/{self.epochs}")

    def on_train_batch_end(self, batch, logs=None):
        self._batch_update_progbar(batch, logs)

    def on_test_batch_end(self, batch, logs=None):
        if not self._called_in_fit:
            self._batch_update_progbar(batch, logs)

    def on_predict_batch_end(self, batch, logs=None):
        # Don't pass prediction results.
        self._batch_update_progbar(batch, None)

    def on_epoch_end(self, epoch, logs=None):
        self._finalize_progbar(logs, self._train_step)

    def on_test_end(self, logs=None):
        if not self._called_in_fit:
            self._finalize_progbar(logs, self._test_step)

    def on_predict_end(self, logs=None):
        self._finalize_progbar(logs, self._predict_step)

    def _reset_progbar(self):
        # Drop the current bar; a new one is built lazily on next update.
        self.seen = 0
        self.progbar = None

    def _maybe_init_progbar(self):
        """Instantiate a `Progbar` if not yet, and update the stateful
        metrics."""
        # TODO(rchao): Legacy TF1 code path may use list for
        # `self.stateful_metrics`. Remove "cast to set" when TF1 support is
        # dropped.
        self.stateful_metrics = set(self.stateful_metrics)

        if self.model:
            # Update the existing stateful metrics as `self.model.metrics` may
            # contain updated metrics after `MetricsContainer` is built in the
            # first train step.
            self.stateful_metrics = self.stateful_metrics.union(
                set(m.name for m in self.model.metrics)
            )

        if self.progbar is None:
            self.progbar = Progbar(
                target=self.target,
                verbose=self.verbose,
                stateful_metrics=self.stateful_metrics,
                unit_name="step" if self.use_steps else "sample",
            )

        self.progbar._update_stateful_metrics(self.stateful_metrics)

    def _implements_train_batch_hooks(self):
        return self._call_batch_hooks

    def _implements_test_batch_hooks(self):
        return self._call_batch_hooks

    def _implements_predict_batch_hooks(self):
        return self._call_batch_hooks

    def _batch_update_progbar(self, batch, logs=None):
        """Updates the progbar."""
        logs = logs or {}
        self._maybe_init_progbar()
        if self.use_steps:
            self.seen = batch + 1  # One-indexed.
        else:
            # v1 path only.
            logs = copy.copy(logs)
            batch_size = logs.pop("size", 0)
            num_steps = logs.pop("num_steps", 1)
            logs.pop("batch", None)
            add_seen = num_steps * batch_size
            self.seen += add_seen

        if self.verbose == 1:
            # Only block async when verbose = 1.
            logs = tf_utils.sync_to_numpy_or_python_type(logs)
            self.progbar.update(self.seen, list(logs.items()), finalize=False)

    def _finalize_progbar(self, logs, counter):
        """Draws the final progbar state, inferring `target` if unknown."""
        logs = tf_utils.sync_to_numpy_or_python_type(logs or {})
        if self.target is None:
            if counter is not None:
                counter = counter.numpy()
                if not self.use_steps:
                    # Convert a step count into a sample count.
                    counter *= logs.get("size", 1)
            self.target = counter or self.seen
            self.progbar.target = self.target
        self.progbar.update(self.target, list(logs.items()), finalize=True)

1182 

1183 

@keras_export("keras.callbacks.History")
class History(Callback):
    """Callback that records events into a `History` object.

    Applied automatically to every Keras model; the resulting `History`
    object is what the model's `fit` method returns.

    Example:

    >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
    >>> model.compile(tf.keras.optimizers.SGD(), loss='mse')
    >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
    ...                     epochs=10, verbose=1)
    >>> print(history.params)
    {'verbose': 1, 'epochs': 10, 'steps': 1}
    >>> # check the keys of history object
    >>> print(history.history.keys())
    dict_keys(['loss'])

    """

    def __init__(self):
        super().__init__()
        # Maps metric name -> list of per-epoch values.
        self.history = {}

    def on_train_begin(self, logs=None):
        # Start a fresh record of epoch indices for this `fit` run.
        self.epoch = []

    def on_epoch_end(self, epoch, logs=None):
        if logs is None:
            logs = {}
        self.epoch.append(epoch)
        for metric, value in logs.items():
            self.history.setdefault(metric, []).append(value)

        # Re-attach this object to the model after the epoch ends, so that
        # `model.history` always reflects the latest state.
        self.model.history = self

1222 

1223 

@keras_export("keras.callbacks.ModelCheckpoint")
class ModelCheckpoint(Callback):
    """Callback to save the Keras model or model weights at some frequency.

    `ModelCheckpoint` callback is used in conjunction with training using
    `model.fit()` to save a model or weights (in a checkpoint file) at some
    interval, so the model or weights can be loaded later to continue the
    training from the state saved.

    A few options this callback provides include:

    - Whether to only keep the model that has achieved the "best performance" so
      far, or whether to save the model at the end of every epoch regardless of
      performance.
    - Definition of 'best'; which quantity to monitor and whether it should be
      maximized or minimized.
    - The frequency it should save at. Currently, the callback supports saving
      at the end of every epoch, or after a fixed number of training batches.
    - Whether only weights are saved, or the whole model is saved.

    Note: If you get `WARNING:tensorflow:Can save best model only with <name>
    available, skipping` see the description of the `monitor` argument for
    details on how to get this right.

    Example:

    ```python
    model.compile(loss=..., optimizer=...,
                  metrics=['accuracy'])

    EPOCHS = 10
    checkpoint_filepath = '/tmp/checkpoint'
    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath,
        save_weights_only=True,
        monitor='val_accuracy',
        mode='max',
        save_best_only=True)

    # Model weights are saved at the end of every epoch, if it's the best seen
    # so far.
    model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])

    # The model weights (that are considered the best) are loaded into the
    # model.
    model.load_weights(checkpoint_filepath)
    ```

    Args:
        filepath: string or `PathLike`, path to save the model file. e.g.
            filepath = os.path.join(working_dir, 'ckpt', file_name). `filepath`
            can contain named formatting options, which will be filled the value
            of `epoch` and keys in `logs` (passed in `on_epoch_end`). For
            example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`,
            then the model checkpoints will be saved with the epoch number and
            the validation loss in the filename. The directory of the filepath
            should not be reused by any other callbacks to avoid conflicts.
        monitor: The metric name to monitor. Typically the metrics are set by
            the `Model.compile` method. Note:

            * Prefix the name with `"val_`" to monitor validation metrics.
            * Use `"loss"` or "`val_loss`" to monitor the model's total loss.
            * If you specify metrics as strings, like `"accuracy"`, pass the
              same string (with or without the `"val_"` prefix).
            * If you pass `metrics.Metric` objects, `monitor` should be set to
              `metric.name`
            * If you're not sure about the metric names you can check the
              contents of the `history.history` dictionary returned by
              `history = model.fit()`
            * Multi-output models set additional prefixes on the metric names.

        verbose: Verbosity mode, 0 or 1. Mode 0 is silent, and mode 1
            displays messages when the callback takes an action.
        save_best_only: if `save_best_only=True`, it only saves when the model
            is considered the "best" and the latest best model according to the
            quantity monitored will not be overwritten. If `filepath` doesn't
            contain formatting options like `{epoch}` then `filepath` will be
            overwritten by each new better model.
        mode: one of {'auto', 'min', 'max'}. If `save_best_only=True`, the
            decision to overwrite the current save file is made based on either
            the maximization or the minimization of the monitored quantity.
            For `val_acc`, this should be `max`, for `val_loss` this should be
            `min`, etc. In `auto` mode, the mode is set to `max` if the
            quantities monitored are 'acc' or start with 'fmeasure' and are set
            to `min` for the rest of the quantities.
        save_weights_only: if True, then only the model's weights will be saved
            (`model.save_weights(filepath)`), else the full model is saved
            (`model.save(filepath)`).
        save_freq: `'epoch'` or integer. When using `'epoch'`, the callback
            saves the model after each epoch. When using integer, the callback
            saves the model at end of this many batches. If the `Model` is
            compiled with `steps_per_execution=N`, then the saving criteria
            will be checked every Nth batch. Note that if the saving isn't
            aligned to epochs, the monitored metric may potentially be less
            reliable (it could reflect as little as 1 batch, since the metrics
            get reset every epoch). Defaults to `'epoch'`.
        options: Optional `tf.train.CheckpointOptions` object if
            `save_weights_only` is true or optional `tf.saved_model.SaveOptions`
            object if `save_weights_only` is false.
        initial_value_threshold: Floating point initial "best" value of the
            metric to be monitored. Only applies if `save_best_only=True`. Only
            overwrites the model weights already saved if the performance of
            current model is better than this value.
        **kwargs: Additional arguments for backwards compatibility. Possible
            key is `period`.
    """

    def __init__(
        self,
        filepath,
        monitor: str = "val_loss",
        verbose: int = 0,
        save_best_only: bool = False,
        save_weights_only: bool = False,
        mode: str = "auto",
        save_freq="epoch",
        options=None,
        initial_value_threshold=None,
        **kwargs,
    ):
        super().__init__()
        self._supports_tf_logs = True
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = io_utils.path_to_string(filepath)
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.save_freq = save_freq
        self.epochs_since_last_save = 0
        self._batches_seen_since_last_saving = 0
        self._last_batch_seen = 0
        # `best` starts at the user-supplied threshold (may be None; filled in
        # below once the monitor direction is known).
        self.best = initial_value_threshold

        if save_weights_only:
            if options is None or isinstance(
                options, tf.train.CheckpointOptions
            ):
                self._options = options or tf.train.CheckpointOptions()
            else:
                raise TypeError(
                    "If save_weights_only is True, then `options` must be "
                    "either None or a tf.train.CheckpointOptions. "
                    f"Got {options}."
                )
        else:
            if filepath and filepath.endswith(".keras") and options is not None:
                raise ValueError(
                    "The native Keras format does not support "
                    "the `options` argument. Please remove "
                    "the `options` argument, or use the SavedModel "
                    "format by removing the `.keras` extension from "
                    "the model filepath."
                )
            if options is None or isinstance(
                options, tf.saved_model.SaveOptions
            ):
                self._options = options or tf.saved_model.SaveOptions()
            else:
                raise TypeError(
                    "If save_weights_only is False, then `options` must be "
                    "either None or a tf.saved_model.SaveOptions. "
                    f"Got {options}."
                )

        # Deprecated field `load_weights_on_restart` is for loading the
        # checkpoint file from `filepath` at the start of `model.fit()`
        # TODO(rchao): Remove the arg during next breaking release.
        if "load_weights_on_restart" in kwargs:
            self.load_weights_on_restart = kwargs["load_weights_on_restart"]
            logging.warning(
                "`load_weights_on_restart` argument is deprecated. "
                "Please use `model.load_weights()` for loading weights "
                "before the start of `model.fit()`."
            )
        else:
            self.load_weights_on_restart = False

        # Deprecated field `period` is for the number of epochs between which
        # the model is saved.
        if "period" in kwargs:
            self.period = kwargs["period"]
            logging.warning(
                "`period` argument is deprecated. Please use `save_freq` "
                "to specify the frequency in number of batches seen."
            )
        else:
            self.period = 1

        if mode not in ["auto", "min", "max"]:
            logging.warning(
                "ModelCheckpoint mode %s is unknown, fallback to auto mode.",
                mode,
            )
            mode = "auto"

        # Use `np.inf` (not the `np.Inf` alias, which was removed in
        # NumPy 2.0) as the sentinel "worst" value for each direction.
        if mode == "min":
            self.monitor_op = np.less
            if self.best is None:
                self.best = np.inf
        elif mode == "max":
            self.monitor_op = np.greater
            if self.best is None:
                self.best = -np.inf
        else:
            # "auto": accuracy-like metrics are maximized, everything else
            # (losses, errors) is minimized.
            if "acc" in self.monitor or self.monitor.startswith("fmeasure"):
                self.monitor_op = np.greater
                if self.best is None:
                    self.best = -np.inf
            else:
                self.monitor_op = np.less
                if self.best is None:
                    self.best = np.inf

        if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
            raise ValueError(
                f"Unrecognized save_freq: {self.save_freq}. "
                'Expected save_freq are "epoch" or integer'
            )

        # Only the chief worker writes model checkpoints, but all workers
        # restore checkpoint at on_train_begin().
        self._chief_worker_only = False

    def on_train_begin(self, logs=None):
        """Optionally restores weights from the most recent checkpoint."""
        if self.load_weights_on_restart:
            filepath_to_load = (
                self._get_most_recently_modified_file_matching_pattern(
                    self.filepath
                )
            )
            if filepath_to_load is not None and self._checkpoint_exists(
                filepath_to_load
            ):
                try:
                    # `filepath` may contain placeholders such as `{epoch:02d}`,
                    # and thus it attempts to load the most recently modified
                    # file with file name matching the pattern.
                    self.model.load_weights(filepath_to_load)
                except (IOError, ValueError) as e:
                    raise ValueError(
                        f"Error loading file from {filepath_to_load}. "
                        f"Reason: {e}"
                    ) from e

    def _implements_train_batch_hooks(self):
        # Only call batch hooks when saving on batch
        return self.save_freq != "epoch"

    def on_train_batch_end(self, batch, logs=None):
        if self._should_save_on_batch(batch):
            self._save_model(epoch=self._current_epoch, batch=batch, logs=logs)

    def on_epoch_begin(self, epoch, logs=None):
        # Remember the current epoch for batch-level saves.
        self._current_epoch = epoch

    def on_epoch_end(self, epoch, logs=None):
        self.epochs_since_last_save += 1

        if self.save_freq == "epoch":
            self._save_model(epoch=epoch, batch=None, logs=logs)

    def _should_save_on_batch(self, batch):
        """Handles batch-level saving logic, supports steps_per_execution."""
        if self.save_freq == "epoch":
            return False

        if batch <= self._last_batch_seen:  # New epoch.
            add_batches = batch + 1  # batches are zero-indexed.
        else:
            add_batches = batch - self._last_batch_seen
        self._batches_seen_since_last_saving += add_batches
        self._last_batch_seen = batch

        if self._batches_seen_since_last_saving >= self.save_freq:
            self._batches_seen_since_last_saving = 0
            return True
        return False

    def _save_model(self, epoch, batch, logs):
        """Saves the model.

        Args:
            epoch: the epoch this iteration is in.
            batch: the batch this iteration is in. `None` if the `save_freq`
                is set to `epoch`.
            logs: the `logs` dict passed in to `on_batch_end` or
                `on_epoch_end`.
        """
        logs = logs or {}

        if (
            isinstance(self.save_freq, int)
            or self.epochs_since_last_save >= self.period
        ):
            # Block only when saving interval is reached.
            logs = tf_utils.sync_to_numpy_or_python_type(logs)
            self.epochs_since_last_save = 0
            filepath = self._get_file_path(epoch, batch, logs)

            # Create host directory if it doesn't exist.
            dirname = os.path.dirname(filepath)
            if dirname and not tf.io.gfile.exists(dirname):
                tf.io.gfile.makedirs(dirname)

            try:
                if self.save_best_only:
                    current = logs.get(self.monitor)
                    if current is None:
                        logging.warning(
                            "Can save best model only with %s available, "
                            "skipping.",
                            self.monitor,
                        )
                    else:
                        if self.monitor_op(current, self.best):
                            if self.verbose > 0:
                                io_utils.print_msg(
                                    f"\nEpoch {epoch + 1}: {self.monitor} "
                                    "improved "
                                    f"from {self.best:.5f} to {current:.5f}, "
                                    f"saving model to {filepath}"
                                )
                            self.best = current
                            if self.save_weights_only:
                                self.model.save_weights(
                                    filepath,
                                    overwrite=True,
                                    options=self._options,
                                )
                            else:
                                self.model.save(
                                    filepath,
                                    overwrite=True,
                                    options=self._options,
                                )
                        else:
                            if self.verbose > 0:
                                io_utils.print_msg(
                                    f"\nEpoch {epoch + 1}: "
                                    f"{self.monitor} did not improve "
                                    f"from {self.best:.5f}"
                                )
                else:
                    if self.verbose > 0:
                        io_utils.print_msg(
                            f"\nEpoch {epoch + 1}: saving model to {filepath}"
                        )
                    if self.save_weights_only:
                        self.model.save_weights(
                            filepath, overwrite=True, options=self._options
                        )
                    elif filepath.endswith(".keras"):
                        # The native Keras format does not accept `options`.
                        self.model.save(filepath, overwrite=True)
                    else:
                        self.model.save(
                            filepath, overwrite=True, options=self._options
                        )

                self._maybe_remove_file()
            except IsADirectoryError as e:  # h5py 3.x
                raise IOError(
                    "Please specify a non-directory filepath for "
                    "ModelCheckpoint. Filepath used is an existing "
                    f"directory: {filepath}"
                ) from e
            except IOError as e:  # h5py 2.x
                # `e.errno` appears to be `None` so checking the content of
                # `e.args[0]`.
                if "is a directory" in str(e.args[0]).lower():
                    # Bug fix: the original message contained a stray "f"
                    # ("directory: f{filepath}").
                    raise IOError(
                        "Please specify a non-directory filepath for "
                        "ModelCheckpoint. Filepath used is an existing "
                        f"directory: {filepath}"
                    ) from e
                # Re-throw the error for any other causes.
                raise e

    def _get_file_path(self, epoch, batch, logs):
        """Returns the file path for checkpoint."""

        try:
            # `filepath` may contain placeholders such as
            # `{epoch:02d}`,`{batch:02d}` and `{mape:.2f}`. A mismatch between
            # logged metrics and the path's placeholders can cause formatting
            # to fail.
            if batch is None or "batch" in logs:
                file_path = self.filepath.format(epoch=epoch + 1, **logs)
            else:
                file_path = self.filepath.format(
                    epoch=epoch + 1, batch=batch + 1, **logs
                )
        except KeyError as e:
            raise KeyError(
                f'Failed to format this callback filepath: "{self.filepath}". '
                f"Reason: {e}"
            ) from e
        self._write_filepath = distributed_file_utils.write_filepath(
            file_path, self.model.distribute_strategy
        )
        return self._write_filepath

    def _maybe_remove_file(self):
        # Remove the checkpoint directory in multi-worker training where this
        # worker should not checkpoint. It is a dummy directory previously
        # saved for sync distributed training.
        distributed_file_utils.remove_temp_dir_with_filepath(
            self._write_filepath, self.model.distribute_strategy
        )

    def _checkpoint_exists(self, filepath):
        """Returns whether the checkpoint `filepath` refers to exists."""
        if filepath.endswith(".h5"):
            return tf.io.gfile.exists(filepath)
        tf_saved_model_exists = tf.io.gfile.exists(filepath)
        # TF-format weight checkpoints are stored as `<filepath>.index` plus
        # data shards; the `.index` file marks the checkpoint's existence.
        tf_weights_only_checkpoint_exists = tf.io.gfile.exists(
            filepath + ".index"
        )
        return tf_saved_model_exists or tf_weights_only_checkpoint_exists

    def _get_most_recently_modified_file_matching_pattern(self, pattern):
        """Returns the most recently modified filepath matching pattern.

        Pattern may contain python formatting placeholder. If
        `tf.train.latest_checkpoint()` does not return None, use that;
        otherwise, check for most recently modified one that matches the
        pattern.

        In the rare case where there are more than one pattern-matching file
        having the same modified time that is most recent among all, return
        the filepath that is largest (by `>` operator, lexicographically using
        the numeric equivalents). This provides a tie-breaker when multiple
        files are most recent. Note that a larger `filepath` can sometimes
        indicate a later time of modification (for instance, when epoch/batch
        is used as formatting option), but not necessarily (when accuracy or
        loss is used). The tie-breaker is put in the logic as best effort to
        return the most recent, and to avoid undeterministic result.

        Modified time of a file is obtained with `os.path.getmtime()`.

        This utility function is best demonstrated via an example:

        ```python
        file_pattern = 'f.batch{batch:02d}epoch{epoch:02d}.h5'
        test_dir = self.get_temp_dir()
        path_pattern = os.path.join(test_dir, file_pattern)
        file_paths = [
            os.path.join(test_dir, file_name) for file_name in
            ['f.batch03epoch02.h5',
             'f.batch02epoch02.h5', 'f.batch01epoch01.h5']
        ]
        for file_path in file_paths:
            # Write something to each of the files
        self.assertEqual(
            _get_most_recently_modified_file_matching_pattern(path_pattern),
            file_paths[-1])
        ```

        Args:
            pattern: The file pattern that may optionally contain python
                placeholder such as `{epoch:02d}`.

        Returns:
            The most recently modified file's full filepath matching
            `pattern`. If `pattern` does not contain any placeholder, this
            returns the filepath that exactly matches `pattern`. Returns
            `None` if no match is found.
        """
        dir_name = os.path.dirname(pattern)
        base_name = os.path.basename(pattern)
        # Turn format placeholders (`{epoch:02d}` etc.) into regex wildcards.
        base_name_regex = "^" + re.sub(r"{.*}", r".*", base_name) + "$"

        # If tf.train.latest_checkpoint tells us there exists a latest
        # checkpoint, use that as it is more robust than `os.path.getmtime()`.
        latest_tf_checkpoint = tf.train.latest_checkpoint(dir_name)
        if latest_tf_checkpoint is not None and re.match(
            base_name_regex, os.path.basename(latest_tf_checkpoint)
        ):
            return latest_tf_checkpoint

        latest_mod_time = 0
        file_path_with_latest_mod_time = None
        n_file_with_latest_mod_time = 0
        file_path_with_largest_file_name = None

        if tf.io.gfile.exists(dir_name):
            for file_name in os.listdir(dir_name):
                # Only consider if `file_name` matches the pattern.
                if re.match(base_name_regex, file_name):
                    file_path = os.path.join(dir_name, file_name)
                    mod_time = os.path.getmtime(file_path)
                    if (
                        file_path_with_largest_file_name is None
                        or file_path > file_path_with_largest_file_name
                    ):
                        file_path_with_largest_file_name = file_path
                    if mod_time > latest_mod_time:
                        latest_mod_time = mod_time
                        file_path_with_latest_mod_time = file_path
                        # In the case a file with later modified time is found,
                        # reset the counter for the number of files with latest
                        # modified time.
                        n_file_with_latest_mod_time = 1
                    elif mod_time == latest_mod_time:
                        # In the case a file has modified time tied with the
                        # most recent, increment the counter for the number of
                        # files with latest modified time by 1.
                        n_file_with_latest_mod_time += 1

        if n_file_with_latest_mod_time == 1:
            # Return the sole file that has most recent modified time.
            return file_path_with_latest_mod_time
        else:
            # If there are more than one file having latest modified time,
            # return the file path with the largest file name.
            return file_path_with_largest_file_name

1738 

1739 

@keras_export("keras.callbacks.BackupAndRestore", v1=[])
class BackupAndRestore(Callback):
    """Callback to back up and restore the training state.

    `BackupAndRestore` callback is intended to recover training from an
    interruption that has happened in the middle of a `Model.fit` execution, by
    backing up the training states in a temporary checkpoint file (with the
    help of a `tf.train.CheckpointManager`), at the end of each epoch. Each
    backup overwrites the previously written checkpoint file, so at any given
    time there is at most one such checkpoint file for backup/restoring
    purpose.

    If training restarts before completion, the training state (which includes
    the `Model` weights and epoch number) is restored to the most recently
    saved state at the beginning of a new `Model.fit` run. At the completion of
    a `Model.fit` run, the temporary checkpoint file is deleted.

    Note that the user is responsible to bring jobs back after the
    interruption. This callback is important for the backup and restore
    mechanism for fault tolerance purpose, and the model to be restored from a
    previous checkpoint is expected to be the same as the one used to back up.
    If user changes arguments passed to compile or fit, the checkpoint saved
    for fault tolerance can become invalid.

    Note:

    1. This callback is not compatible with eager execution disabled.
    2. A checkpoint is saved at the end of each epoch. After restoring,
    `Model.fit` redoes any partial work during the unfinished epoch in which
    the training got restarted (so the work done before the interruption
    doesn't affect the final model state).
    3. This works for both single worker and multi-worker modes. When
    `Model.fit` is used with `tf.distribute`, it supports
    `tf.distribute.MirroredStrategy`,
    `tf.distribute.MultiWorkerMirroredStrategy`, `tf.distribute.TPUStrategy`,
    and `tf.distribute.experimental.ParameterServerStrategy`.

    Example:

    >>> class InterruptingCallback(tf.keras.callbacks.Callback):
    ...   def on_epoch_begin(self, epoch, logs=None):
    ...     if epoch == 4:
    ...       raise RuntimeError('Interrupting!')
    >>> callback = tf.keras.callbacks.BackupAndRestore(backup_dir="/tmp/backup")
    >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
    >>> model.compile(tf.keras.optimizers.SGD(), loss='mse')
    >>> try:
    ...   model.fit(np.arange(100).reshape(5, 20), np.zeros(5), epochs=10,
    ...             batch_size=1, callbacks=[callback, InterruptingCallback()],
    ...             verbose=0)
    ... except:
    ...   pass
    >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
    ...                     epochs=10, batch_size=1, callbacks=[callback],
    ...                     verbose=0)
    >>> # Only 6 more epochs are run, since first training got interrupted at
    >>> # zero-indexed epoch 4, second training will continue from 4 to 9.
    >>> len(history.history['loss'])
    6

    Besides the option to save at the end of every epoch or every N steps, if
    you are doing distributed training with
    `tf.distribute.MultiWorkerMirroredStrategy` on Google Cloud Platform or
    Google Borg, you can also use the `save_before_preemption` argument
    to enable saving a checkpoint right before a worker gets preempted
    by other jobs and training gets interrupted. See
    `tf.distribute.experimental.PreemptionCheckpointHandler` for more details.

    Args:
        backup_dir: String, path to store the checkpoint.
          e.g. `backup_dir = os.path.join(working_dir, 'backup')`.
          This is the directory in which the system stores temporary files to
          recover the model from jobs terminated unexpectedly. The directory
          cannot be reused elsewhere to store other files, e.g. by the
          `BackupAndRestore` callback of another training run,
          or by another callback
          (e.g. `ModelCheckpoint`) of the same training.
        save_freq: `'epoch'`, integer, or `False`. When set to `'epoch'`
          the callback saves the checkpoint at the end of each epoch.
          When set to an integer, the callback saves the checkpoint every
          `save_freq` batches. Set `save_freq` to `False` if only using
          preemption checkpointing (with `save_before_preemption=True`).
        delete_checkpoint: Boolean, default to True. This `BackupAndRestore`
          callback works by saving a checkpoint to back up the training state.
          If `delete_checkpoint=True`, the checkpoint will be deleted after
          training is finished. Use `False` if you'd like to keep the
          checkpoint for future usage.
        save_before_preemption: A boolean value instructing whether to turn on
          the automatic checkpoint saving for preemption/maintenance events.
          This only supports
          `tf.distribute.MultiWorkerMirroredStrategy` on Google Cloud Platform
          or Google Borg for now.
    """

    def __init__(
        self,
        backup_dir,
        save_freq="epoch",
        delete_checkpoint=True,
        save_before_preemption=False,
    ):
        super().__init__()
        self.backup_dir = backup_dir
        self._supports_tf_logs = True
        # Distribution strategies under which restoring/backing up is known to
        # work; anything else is rejected in `on_train_begin`.
        self._supported_strategies = (
            tf.distribute.MirroredStrategy,
            tf.distribute.MultiWorkerMirroredStrategy,
            tf.distribute.experimental.TPUStrategy,
            tf.distribute.TPUStrategy,
            tf.distribute.experimental.ParameterServerStrategy,
        )
        self.save_freq = save_freq
        self.delete_checkpoint = delete_checkpoint
        self.save_before_preemption = save_before_preemption
        # Counts batches since the last step-based backup (only used when
        # `save_freq` is an integer).
        self._batches_count = 0
        self._current_epoch = 0

        if not tf.executing_eagerly():
            if tf.inside_function():
                raise ValueError(
                    "This Callback's method contains Python state and "
                    "should be called outside of `tf.function`s."
                )
            else:  # Legacy graph mode:
                raise ValueError(
                    "BackupAndRestore only supports eager mode. In graph "
                    "mode, consider using ModelCheckpoint to manually save "
                    "and restore weights with `model.load_weights()` and by "
                    "providing `initial_epoch` in `model.fit()` for fault "
                    "tolerance."
                )
        if (not save_freq) and (not save_before_preemption):
            raise ValueError(
                "Either `save_freq` or `save_before_preemption` " "must be set."
            )

        # Only the chief worker writes model checkpoints, but all workers
        # restore checkpoint at on_train_begin().
        self._chief_worker_only = False

    def on_train_begin(self, logs=None):
        # TrainingState is used to manage the training state needed for
        # failure-recovery of a worker in training.

        # `_distribution_strategy` is falsy when the model is not running
        # under a `tf.distribute` strategy; in that case no strategy check is
        # needed.
        if self.model._distribution_strategy and not isinstance(
            self.model.distribute_strategy, self._supported_strategies
        ):
            raise NotImplementedError(
                f"{type(self.model.distribute_strategy)} is not supported yet. "
                "Currently BackupAndRestore callback "
                "only supports empty strategy, "
                "MirroredStrategy, MultiWorkerMirroredStrategy, TPUStrategy "
                "and ParameterServerStrategy."
            )
        self.model._training_state = worker_training_state.WorkerTrainingState(
            self.model,
            self.backup_dir,
            self.save_freq,
            self.save_before_preemption,
        )
        self._training_state = self.model._training_state
        # Restore weights/epoch from a previous interrupted run, if any.
        self._training_state.restore()

    def on_train_batch_begin(self, batch, logs=None):
        # Skip batch update for PSS Strategy
        if isinstance(
            self.model.distribute_strategy,
            tf.distribute.ParameterServerStrategy,
        ):
            return
        # Record the in-progress batch so a preemption backup can resume here.
        self._training_state._ckpt_saved_batch.assign(batch)

    def on_train_batch_end(self, batch, logs=None):
        # Skip batch update for PSS Strategy
        if isinstance(
            self.model.distribute_strategy,
            tf.distribute.ParameterServerStrategy,
        ):
            return
        self._training_state.backup_if_preempted()
        if self.save_freq and self.save_freq != "epoch":
            self._batches_count += 1
            if self._batches_count >= self.save_freq:
                self._batches_count = 0
                self._backup(epoch=self._current_epoch, batch=batch)

    def _implements_train_batch_hooks(self):
        # Batch-level hooks are only needed for step-frequency saving.
        return self.save_freq != "epoch"

    def on_train_end(self, logs=None):
        if self.delete_checkpoint:
            # On exit of training, delete the training state backup file saved
            # for the purpose of worker recovery unless the user opts out.
            self._training_state.delete_backup()
        # Clean up the training state.
        del self._training_state
        del self.model._training_state

    def on_epoch_begin(self, epoch, logs=None):
        self._training_state._ckpt_saved_epoch.assign(epoch)
        self._current_epoch = epoch

    def on_epoch_end(self, epoch, logs=None):
        # Back up the model and current epoch for possible future recovery.
        if self.save_freq == "epoch":
            self._backup(epoch=epoch)

    def _backup(self, epoch, batch=0):
        # Delegate the actual checkpoint write to the training-state helper.
        self._training_state.back_up(epoch=epoch, batch=batch)

1947 

1948 

@keras_export("keras.callbacks.experimental.BackupAndRestore", v1=[])
@deprecation.deprecated_endpoints(
    "keras.callbacks.experimental.BackupAndRestore"
)
class BackupAndRestoreExperimental(BackupAndRestore):
    """Deprecated. Please use `tf.keras.callbacks.BackupAndRestore` instead.

    Caution: `tf.keras.callbacks.experimental.BackupAndRestore` endpoint is
    deprecated and will be removed in a future release. Please use
    `tf.keras.callbacks.BackupAndRestore`.
    """

    def __init__(self, *args, **kwargs):
        # Emit the deprecation notice, then defer everything to the
        # non-experimental implementation.
        logging.warning(
            "`tf.keras.callbacks.experimental.BackupAndRestore` endpoint is "
            "deprecated and will be removed in a future release. Please use "
            "`tf.keras.callbacks.BackupAndRestore`."
        )
        super().__init__(*args, **kwargs)

1968 

1969 

@keras_export("keras.callbacks.EarlyStopping")
class EarlyStopping(Callback):
    """Stop training when a monitored metric has stopped improving.

    Assuming the goal of a training is to minimize the loss. With this, the
    metric to be monitored would be `'loss'`, and mode would be `'min'`. A
    `model.fit()` training loop will check at end of every epoch whether
    the loss is no longer decreasing, considering the `min_delta` and
    `patience` if applicable. Once it's found no longer decreasing,
    `model.stop_training` is marked True and the training terminates.

    The quantity to be monitored needs to be available in `logs` dict.
    To make it so, pass the loss or metrics at `model.compile()`.

    Args:
      monitor: Quantity to be monitored.
      min_delta: Minimum change in the monitored quantity
          to qualify as an improvement, i.e. an absolute
          change of less than min_delta, will count as no
          improvement.
      patience: Number of epochs with no improvement
          after which training will be stopped.
      verbose: Verbosity mode, 0 or 1. Mode 0 is silent, and mode 1
          displays messages when the callback takes an action.
      mode: One of `{"auto", "min", "max"}`. In `min` mode,
          training will stop when the quantity
          monitored has stopped decreasing; in `"max"`
          mode it will stop when the quantity
          monitored has stopped increasing; in `"auto"`
          mode, the direction is automatically inferred
          from the name of the monitored quantity.
      baseline: Baseline value for the monitored quantity.
          Training will stop if the model doesn't show improvement over the
          baseline.
      restore_best_weights: Whether to restore model weights from
          the epoch with the best value of the monitored quantity.
          If False, the model weights obtained at the last step of
          training are used. An epoch will be restored regardless
          of the performance relative to the `baseline`. If no epoch
          improves on `baseline`, training will run for `patience`
          epochs and restore weights from the best epoch in that set.
      start_from_epoch: Number of epochs to wait before starting
          to monitor improvement. This allows for a warm-up period in which
          no improvement is expected and thus training will not be stopped.


    Example:

    >>> callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
    >>> # This callback will stop the training when there is no improvement in
    >>> # the loss for three consecutive epochs.
    >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
    >>> model.compile(tf.keras.optimizers.SGD(), loss='mse')
    >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
    ...                     epochs=10, batch_size=1, callbacks=[callback],
    ...                     verbose=0)
    >>> len(history.history['loss'])  # Only 4 epochs are run.
    4
    """

    def __init__(
        self,
        monitor="val_loss",
        min_delta=0,
        patience=0,
        verbose=0,
        mode="auto",
        baseline=None,
        restore_best_weights=False,
        start_from_epoch=0,
    ):
        super().__init__()

        self.monitor = monitor
        self.patience = patience
        self.verbose = verbose
        self.baseline = baseline
        self.min_delta = abs(min_delta)
        self.wait = 0
        self.stopped_epoch = 0
        self.restore_best_weights = restore_best_weights
        self.best_weights = None
        self.start_from_epoch = start_from_epoch

        if mode not in ["auto", "min", "max"]:
            logging.warning(
                "EarlyStopping mode %s is unknown, fallback to auto mode.",
                mode,
            )
            mode = "auto"

        if mode == "min":
            self.monitor_op = np.less
        elif mode == "max":
            self.monitor_op = np.greater
        else:
            # In "auto" mode, infer the direction from the metric name:
            # accuracy-like metrics are maximized, everything else minimized.
            if (
                self.monitor.endswith("acc")
                or self.monitor.endswith("accuracy")
                or self.monitor.endswith("auc")
            ):
                self.monitor_op = np.greater
            else:
                self.monitor_op = np.less

        # Fold the comparison direction into the sign of `min_delta` so that
        # `_is_improvement` can always subtract it from the candidate value.
        if self.monitor_op == np.greater:
            self.min_delta *= 1
        else:
            self.min_delta *= -1

    def on_train_begin(self, logs=None):
        # Allow instances to be re-used
        self.wait = 0
        self.stopped_epoch = 0
        # `np.inf` (not the `np.Inf` alias, which was removed in NumPy 2.0).
        self.best = np.inf if self.monitor_op == np.less else -np.inf
        self.best_weights = None
        self.best_epoch = 0

    def on_epoch_end(self, epoch, logs=None):
        current = self.get_monitor_value(logs)
        if current is None or epoch < self.start_from_epoch:
            # If no monitor value exists or still in initial warm-up stage.
            return
        if self.restore_best_weights and self.best_weights is None:
            # Restore the weights after first epoch if no progress is ever made.
            self.best_weights = self.model.get_weights()

        self.wait += 1
        if self._is_improvement(current, self.best):
            self.best = current
            self.best_epoch = epoch
            if self.restore_best_weights:
                self.best_weights = self.model.get_weights()
            # Only restart wait if we beat both the baseline and our previous
            # best.
            if self.baseline is None or self._is_improvement(
                current, self.baseline
            ):
                self.wait = 0
            return

        # Only check after the first epoch.
        if self.wait >= self.patience and epoch > 0:
            self.stopped_epoch = epoch
            self.model.stop_training = True
            if self.restore_best_weights and self.best_weights is not None:
                if self.verbose > 0:
                    io_utils.print_msg(
                        "Restoring model weights from "
                        "the end of the best epoch: "
                        f"{self.best_epoch + 1}."
                    )
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0 and self.verbose > 0:
            io_utils.print_msg(
                f"Epoch {self.stopped_epoch + 1}: early stopping"
            )

    def get_monitor_value(self, logs):
        """Return the monitored metric from `logs`, or None with a warning."""
        logs = logs or {}
        monitor_value = logs.get(self.monitor)
        if monitor_value is None:
            logging.warning(
                "Early stopping conditioned on metric `%s` "
                "which is not available. Available metrics are: %s",
                self.monitor,
                ",".join(list(logs.keys())),
            )
        return monitor_value

    def _is_improvement(self, monitor_value, reference_value):
        # `min_delta` already carries the sign matching `monitor_op`, so this
        # single expression handles both min and max modes.
        return self.monitor_op(monitor_value - self.min_delta, reference_value)

2144 

2145 

@keras_export("keras.callbacks.RemoteMonitor")
class RemoteMonitor(Callback):
    """Callback used to stream events to a server.

    Requires the `requests` library.
    Events are sent to `root + '/publish/epoch/end/'` by default. Calls are
    HTTP POST, with a `data` argument which is a
    JSON-encoded dictionary of event data.
    If `send_as_json=True`, the content type of the request will be
    `"application/json"`.
    Otherwise the serialized JSON will be sent within a form.

    Args:
        root: String; root url of the target server.
        path: String; path relative to `root` to which the events will be sent.
        field: String; JSON field under which the data will be stored.
            The field is used only if the payload is sent within a form
            (i.e. send_as_json is set to False).
        headers: Dictionary; optional custom HTTP headers.
        send_as_json: Boolean; whether the request should be
            sent as `"application/json"`.
    """

    def __init__(
        self,
        root="http://localhost:9000",
        path="/publish/epoch/end/",
        field="data",
        headers=None,
        send_as_json=False,
    ):
        super().__init__()

        self.root = root
        self.path = path
        self.field = field
        self.headers = headers
        self.send_as_json = send_as_json

    def on_epoch_end(self, epoch, logs=None):
        if requests is None:
            raise ImportError("RemoteMonitor requires the `requests` library.")
        logs = logs or {}
        # Build a JSON-serializable payload: numpy scalars/arrays are not
        # directly serializable, so unwrap them via `.item()`.
        send = {"epoch": epoch}
        for key, value in logs.items():
            if isinstance(value, (np.ndarray, np.generic)):
                send[key] = value.item()
            else:
                send[key] = value
        target_url = self.root + self.path
        try:
            if self.send_as_json:
                requests.post(target_url, json=send, headers=self.headers)
            else:
                # Form-encoded: the serialized JSON goes under `self.field`.
                requests.post(
                    target_url,
                    {self.field: json.dumps(send)},
                    headers=self.headers,
                )
        except requests.exceptions.RequestException:
            # Best-effort delivery: log and continue training on failure.
            logging.warning(
                "Warning: could not reach RemoteMonitor root server at "
                + str(self.root)
            )

2215 

2216 

@keras_export("keras.callbacks.LearningRateScheduler")
class LearningRateScheduler(Callback):
    """Learning rate scheduler.

    At the beginning of every epoch, this callback gets the updated learning
    rate value from `schedule` function provided at `__init__`, with the
    current epoch and current learning rate, and applies the updated learning
    rate on the optimizer.

    Args:
        schedule: a function that takes an epoch index (integer, indexed from
            0) and current learning rate (float) as inputs and returns a new
            learning rate as output (float).
        verbose: int. 0: quiet, 1: update messages.

    Example:

    >>> # This function keeps the initial learning rate for the first ten epochs
    >>> # and decreases it exponentially after that.
    >>> def scheduler(epoch, lr):
    ...   if epoch < 10:
    ...     return lr
    ...   else:
    ...     return lr * tf.math.exp(-0.1)
    >>>
    >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
    >>> model.compile(tf.keras.optimizers.SGD(), loss='mse')
    >>> round(model.optimizer.lr.numpy(), 5)
    0.01

    >>> callback = tf.keras.callbacks.LearningRateScheduler(scheduler)
    >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
    ...                     epochs=15, callbacks=[callback], verbose=0)
    >>> round(model.optimizer.lr.numpy(), 5)
    0.00607

    """

    def __init__(self, schedule, verbose=0):
        super().__init__()
        self.schedule = schedule
        self.verbose = verbose

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, "lr"):
            raise ValueError('Optimizer must have a "lr" attribute.')
        try:  # new API
            # Two-argument schedule: (epoch, current_lr) -> new_lr.
            current_lr = float(backend.get_value(self.model.optimizer.lr))
            lr = self.schedule(epoch, current_lr)
        except TypeError:  # Support for old API for backward compatibility
            # Legacy one-argument schedule: (epoch) -> new_lr.
            lr = self.schedule(epoch)
        valid_types = (tf.Tensor, float, np.float32, np.float64)
        if not isinstance(lr, valid_types):
            raise ValueError(
                'The output of the "schedule" function '
                f"should be float. Got: {lr}"
            )
        if isinstance(lr, tf.Tensor) and not lr.dtype.is_floating:
            raise ValueError(
                f"The dtype of `lr` Tensor should be float. Got: {lr.dtype}"
            )
        backend.set_value(self.model.optimizer.lr, backend.get_value(lr))
        if self.verbose > 0:
            io_utils.print_msg(
                f"\nEpoch {epoch + 1}: LearningRateScheduler setting learning "
                f"rate to {lr}."
            )

    def on_epoch_end(self, epoch, logs=None):
        # Record the effective learning rate so it appears in History/logs.
        logs = logs or {}
        logs["lr"] = backend.get_value(self.model.optimizer.lr)

2287 

2288 

def keras_model_summary(name, data, step=None):
    """Writes a Keras model as JSON to a Summary.

    Writing the Keras model configuration allows the TensorBoard graph plugin
    to render a conceptual graph, as opposed to graph of ops. In case the model
    fails to serialize as JSON, it ignores and returns False.

    Args:
      name: A name for this summary. The summary tag used for TensorBoard will
        be this name prefixed by any active name scopes.
      data: A Keras Model to write.
      step: Explicit `int64`-castable monotonic step value for this summary.
        If omitted, this defaults to `tf.summary.experimental.get_step()`,
        which must not be None.

    Returns:
      True on success, or False if no summary was written because no default
      summary writer was available.

    Raises:
      ValueError: if a default writer exists, but no step was provided and
        `tf.summary.experimental.get_step()` is None.
    """
    summary_metadata = tf.compat.v1.SummaryMetadata()
    # Hard coding a plugin name. Please refer to go/tb-plugin-name-hardcode for
    # the rationale.
    summary_metadata.plugin_data.plugin_name = "graph_keras_model"
    # version number = 1
    summary_metadata.plugin_data.content = b"1"

    try:
        json_string = data.to_json()
    except Exception as exc:
        # An exception should not break a model code.
        logging.warning(
            "Model failed to serialize as JSON. Ignoring... %s", exc
        )
        return False

    with tf.summary.experimental.summary_scope(
        name, "graph_keras_model", [data, step]
    ) as (tag, _):
        # Write on CPU: the JSON blob is a string tensor and does not belong
        # on an accelerator device.
        with tf.device("cpu:0"):
            tensor = tf.constant(json_string, dtype=tf.string)
        return tf.summary.write(
            tag=tag, tensor=tensor, step=step, metadata=summary_metadata
        )

2336 

2337 

2338@keras_export("keras.callbacks.TensorBoard", v1=[]) 

2339class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): 

2340 

2341 """Enable visualizations for TensorBoard. 

2342 

2343 TensorBoard is a visualization tool provided with TensorFlow. 

2344 

2345 This callback logs events for TensorBoard, including: 

2346 

2347 * Metrics summary plots 

2348 * Training graph visualization 

2349 * Weight histograms 

2350 * Sampled profiling 

2351 

2352 When used in `Model.evaluate` or regular validation 

2353 ([on_test_end](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback#on_test_end)), 

2354 in addition to epoch summaries, there will be a summary that records 

2355 evaluation metrics vs `Model.optimizer.iterations` written. The metric names 

2356 will be prepended with `evaluation`, with `Model.optimizer.iterations` being 

2357 the step in the visualized TensorBoard. 

2358 

2359 If you have installed TensorFlow with pip, you should be able 

2360 to launch TensorBoard from the command line: 

2361 

2362 ``` 

2363 tensorboard --logdir=path_to_your_logs 

2364 ``` 

2365 

2366 You can find more information about TensorBoard 

2367 [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard). 

2368 

2369 Args: 

2370 log_dir: the path of the directory where to save the log files to be 

2371 parsed by TensorBoard. e.g. log_dir = os.path.join(working_dir, 

2372 'logs') This directory should not be reused by any other callbacks. 

2373 histogram_freq: frequency (in epochs) at which to compute 

2374 weight histograms for the layers of the model. If set to 0, histograms 

2375 won't be computed. Validation data (or split) must be specified for 

2376 histogram visualizations. 

2377 write_graph: whether to visualize the graph in TensorBoard. The log file 

2378 can become quite large when write_graph is set to True. 

2379 write_images: whether to write model weights to visualize as image in 

2380 TensorBoard. 

2381 write_steps_per_second: whether to log the training steps per second 

2382 into TensorBoard. This supports both epoch and batch frequency 

2383 logging. 

2384 update_freq: `'batch'` or `'epoch'` or integer. When using `'epoch'`, 

2385 writes the losses and metrics to TensorBoard after every epoch. 

2386 If using an integer, let's say `1000`, all metrics and losses 

2387 (including custom ones added by `Model.compile`) will be logged to 

2388 TensorBoard every 1000 batches. `'batch'` is a synonym for `1`, 

2389 meaning that they will be written every batch. 

2390 Note however that writing too frequently to TensorBoard can slow down 

2391 your training, especially when used with `tf.distribute.Strategy` as 

2392 it will incur additional synchronization overhead. 

2393 Use with `ParameterServerStrategy` is not supported. 

2394 Batch-level summary writing is also available via `train_step` 

2395 override. Please see 

2396 [TensorBoard Scalars tutorial](https://www.tensorflow.org/tensorboard/scalars_and_keras#batch-level_logging) # noqa: E501 

2397 for more details. 

2398 profile_batch: Profile the batch(es) to sample compute characteristics. 

2399 profile_batch must be a non-negative integer or a tuple of integers. 

2400 A pair of positive integers signify a range of batches to profile. 

2401 By default, profiling is disabled. 

2402 embeddings_freq: frequency (in epochs) at which embedding layers will be 

2403 visualized. If set to 0, embeddings won't be visualized. 

2404 embeddings_metadata: Dictionary which maps embedding layer names to the 

2405 filename of a file in which to save metadata for the embedding layer. 

2406 In case the same metadata file is to be 

2407 used for all embedding layers, a single filename can be passed. 

2408 

2409 Examples: 

2410 

2411 Basic usage: 

2412 

2413 ```python 

2414 tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs") 

2415 model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback]) 

2416 # Then run the tensorboard command to view the visualizations. 

2417 ``` 

2418 

2419 Custom batch-level summaries in a subclassed Model: 

2420 

2421 ```python 

2422 class MyModel(tf.keras.Model): 

2423 

2424 def build(self, _): 

2425 self.dense = tf.keras.layers.Dense(10) 

2426 

2427 def call(self, x): 

2428 outputs = self.dense(x) 

2429 tf.summary.histogram('outputs', outputs) 

2430 return outputs 

2431 

2432 model = MyModel() 

2433 model.compile('sgd', 'mse') 

2434 

2435 # Make sure to set `update_freq=N` to log a batch-level summary every N 

2436 # batches. In addition to any `tf.summary` contained in `Model.call`, 

2437 # metrics added in `Model.compile` will be logged every N batches. 

2438 tb_callback = tf.keras.callbacks.TensorBoard('./logs', update_freq=1) 

2439 model.fit(x_train, y_train, callbacks=[tb_callback]) 

2440 ``` 

2441 

2442 Custom batch-level summaries in a Functional API Model: 

2443 

2444 ```python 

2445 def my_summary(x): 

2446 tf.summary.histogram('x', x) 

2447 return x 

2448 

2449 inputs = tf.keras.Input(10) 

2450 x = tf.keras.layers.Dense(10)(inputs) 

2451 outputs = tf.keras.layers.Lambda(my_summary)(x) 

2452 model = tf.keras.Model(inputs, outputs) 

2453 model.compile('sgd', 'mse') 

2454 

2455 # Make sure to set `update_freq=N` to log a batch-level summary every N 

2456 # batches. In addition to any `tf.summary` contained in `Model.call`, 

2457 # metrics added in `Model.compile` will be logged every N batches. 

2458 tb_callback = tf.keras.callbacks.TensorBoard('./logs', update_freq=1) 

2459 model.fit(x_train, y_train, callbacks=[tb_callback]) 

2460 ``` 

2461 

2462 Profiling: 

2463 

2464 ```python 

2465 # Profile a single batch, e.g. the 5th batch. 

2466 tensorboard_callback = tf.keras.callbacks.TensorBoard( 

2467 log_dir='./logs', profile_batch=5) 

2468 model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback]) 

2469 

2470 # Profile a range of batches, e.g. from 10 to 20. 

2471 tensorboard_callback = tf.keras.callbacks.TensorBoard( 

2472 log_dir='./logs', profile_batch=(10,20)) 

2473 model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback]) 

2474 ``` 

2475 """ 

2476 

    def __init__(
        self,
        log_dir="logs",
        histogram_freq=0,
        write_graph=True,
        write_images=False,
        write_steps_per_second=False,
        update_freq="epoch",
        profile_batch=0,
        embeddings_freq=0,
        embeddings_metadata=None,
        **kwargs,
    ):
        """Initializes the TensorBoard callback; see class docstring for args.

        Any extra `**kwargs` must be V1-era arguments; anything else raises
        ValueError in `_validate_kwargs`.
        """
        super().__init__()
        self._supports_tf_logs = True
        # Reject unknown kwargs early; warn about deprecated V1-only ones.
        self._validate_kwargs(kwargs)

        self.log_dir = io_utils.path_to_string(log_dir)
        self.histogram_freq = histogram_freq
        self.write_graph = write_graph
        self.write_images = write_images
        self.write_steps_per_second = write_steps_per_second
        # "batch" is a legacy alias for an update frequency of every 1 batch.
        self.update_freq = 1 if update_freq == "batch" else update_freq
        self.embeddings_freq = embeddings_freq
        self.embeddings_metadata = embeddings_metadata
        # Parses/validates `profile_batch` (int or (start, stop) tuple).
        self._init_profile_batch(profile_batch)
        # Counters used for steps-per-second logging across batches/epochs.
        self._global_train_batch = 0
        self._previous_epoch_iterations = 0
        self._train_accumulated_time = 0
        self._batch_start_time = 0

        # Lazily initialized in order to avoid creating event files when
        # not needed.
        self._writers = {}

        # Used to restore any existing `SummaryWriter` after training ends.
        self._prev_summary_state = []

2514 

    def _validate_kwargs(self, kwargs):
        """Handle arguments that were supported in V1 but are now ignored.

        Warns for each known-but-obsolete V1 keyword and raises a
        `ValueError` for anything unrecognized.
        """
        if kwargs.get("write_grads", False):
            logging.warning(
                "`write_grads` will be ignored in TensorFlow 2.0 "
                "for the `TensorBoard` Callback."
            )
        if kwargs.get("batch_size", False):
            logging.warning(
                "`batch_size` is no longer needed in the "
                "`TensorBoard` Callback and will be ignored "
                "in TensorFlow 2.0."
            )
        if kwargs.get("embeddings_layer_names", False):
            logging.warning(
                "`embeddings_layer_names` is not supported in "
                "TensorFlow 2.0. Instead, all `Embedding` layers "
                "will be visualized."
            )
        if kwargs.get("embeddings_data", False):
            logging.warning(
                "`embeddings_data` is not supported in TensorFlow "
                "2.0. Instead, all `Embedding` variables will be "
                "visualized."
            )

        # Keywords that existed in the V1 callback; they warn above but do
        # not raise.
        supported_kwargs = {
            "write_grads",
            "embeddings_layer_names",
            "embeddings_data",
            "batch_size",
        }
        unrecognized_kwargs = set(kwargs.keys()) - supported_kwargs

        # Only allow kwargs that were supported in V1.
        if unrecognized_kwargs:
            raise ValueError(
                "Unrecognized arguments in `TensorBoard` Callback: "
                f"{unrecognized_kwargs}. "
                f"Supported kwargs are: {supported_kwargs}"
            )

2556 

    def set_model(self, model):
        """Sets Keras model and writes graph if specified."""
        self.model = model
        self._log_write_dir = self._get_log_write_dir()

        # Train and validation summaries go to separate sub-directories so
        # TensorBoard displays them as distinct runs.
        self._train_dir = os.path.join(self._log_write_dir, "train")
        self._train_step = self.model._train_counter

        self._val_dir = os.path.join(self._log_write_dir, "validation")
        self._val_step = self.model._test_counter

        self._writers = {}  # Resets writers.

        self._should_write_train_graph = False
        if self.write_graph:
            # The model summary is written immediately; the train-function
            # graph is deferred until the first batch ends (see
            # `on_train_batch_end`), after the function has been traced.
            self._write_keras_model_summary()
            self._should_write_train_graph = True
        if self.embeddings_freq:
            self._configure_embeddings()

2576 

2577 @property 

2578 def _train_writer(self): 

2579 if "train" not in self._writers: 

2580 self._writers["train"] = tf.summary.create_file_writer( 

2581 self._train_dir 

2582 ) 

2583 return self._writers["train"] 

2584 

2585 @property 

2586 def _val_writer(self): 

2587 if "val" not in self._writers: 

2588 self._writers["val"] = tf.summary.create_file_writer(self._val_dir) 

2589 return self._writers["val"] 

2590 

2591 def _get_log_write_dir(self): 

2592 """For multi-worker, only chief should write, others write to '/tmp'.""" 

2593 return distributed_file_utils.write_dirpath( 

2594 self.log_dir, self.model.distribute_strategy 

2595 ) 

2596 

2597 def _delete_tmp_write_dir(self): 

2598 """Deletes tmp write directories for multi-worker.""" 

2599 distributed_file_utils.remove_temp_dirpath( 

2600 self.log_dir, self.model.distribute_strategy 

2601 ) 

2602 

    def _write_keras_model_train_graph(self):
        """Writes Keras model train_function graph to TensorBoard."""
        with self._train_writer.as_default():
            with tf.summary.record_if(True):
                train_fn = self.model.train_tf_function
                # If the train_function is a `tf.function`, we can write out a
                # graph
                if hasattr(train_fn, "function_spec"):
                    # TODO(b/243822285): Use _variable_creation_fn directly.
                    if hasattr(train_fn, "_concrete_stateful_fn"):
                        tf.summary.graph(train_fn._concrete_stateful_fn.graph)
                    else:
                        # Newer tf.function implementations expose the
                        # traced concrete function under this private name.
                        tf.summary.graph(
                            train_fn._concrete_variable_creation_fn.graph
                        )

2618 

2619 def _write_keras_model_summary(self): 

2620 """Writes Keras graph network summary to TensorBoard.""" 

2621 with self._train_writer.as_default(): 

2622 with tf.summary.record_if(True): 

2623 summary_writable = ( 

2624 self.model._is_graph_network 

2625 or self.model.__class__.__name__ == "Sequential" 

2626 ) 

2627 if summary_writable: 

2628 keras_model_summary("keras", self.model, step=0) 

2629 

    def _configure_embeddings(self):
        """Configure the Projector for embeddings."""
        # TODO(omalleyt): Add integration tests.
        from keras.src.layers import core
        from keras.protobuf import projector_config_pb2

        # isort: off
        from google.protobuf import text_format

        config = projector_config_pb2.ProjectorConfig()
        for layer in self.model.layers:
            if isinstance(layer, core.Embedding):
                embedding = config.embeddings.add()
                # Embeddings are always the first layer, so this naming should
                # be consistent in any keras models checkpoints.
                name = (
                    "layer_with_weights-0/embeddings/.ATTRIBUTES/VARIABLE_VALUE"
                )
                embedding.tensor_name = name

                if self.embeddings_metadata is not None:
                    if isinstance(self.embeddings_metadata, str):
                        # A single path applies to every embedding layer.
                        embedding.metadata_path = self.embeddings_metadata
                    else:
                        # Dict of layer name -> metadata path. Matched
                        # entries are popped so any leftovers can be
                        # reported as errors below.
                        if layer.name in self.embeddings_metadata.keys():
                            embedding.metadata_path = (
                                self.embeddings_metadata.pop(layer.name)
                            )

        # Any dict keys not consumed above did not match an `Embedding`
        # layer name.
        if self.embeddings_metadata and not isinstance(
            self.embeddings_metadata, str
        ):
            raise ValueError(
                "Unrecognized `Embedding` layer names passed to "
                "`keras.callbacks.TensorBoard` `embeddings_metadata` "
                f"argument: {self.embeddings_metadata.keys()}"
            )

        config_pbtxt = text_format.MessageToString(config)
        path = os.path.join(self._log_write_dir, "projector_config.pbtxt")
        with tf.io.gfile.GFile(path, "w") as f:
            f.write(config_pbtxt)

2672 

    def _push_writer(self, writer, step):
        """Sets the default writer for custom batch-level summaries."""
        if self.update_freq == "epoch":
            # Batch-level summaries are disabled; nothing to install.
            return

        # Record only every `update_freq` batches.
        should_record = lambda: tf.equal(step % self.update_freq, 0)
        # TODO(b/151339474): Fix deadlock when not using .value() here.
        summary_context = (
            writer.as_default(step.value()),
            tf.summary.record_if(should_record),
        )
        # Save the pair so `_pop_writer` can exit them in reverse order;
        # the writer context is entered before the record_if context.
        self._prev_summary_state.append(summary_context)
        summary_context[0].__enter__()
        summary_context[1].__enter__()

2687 

    def _pop_writer(self):
        """Pops the current writer."""
        if self.update_freq == "epoch":
            # Nothing was pushed in `_push_writer` for epoch mode.
            return

        # See _push_writer for the content of the previous_context, which is
        # pair of context. Exit in reverse order of entry: record_if first,
        # then the writer's as_default context.
        previous_context = self._prev_summary_state.pop()
        previous_context[1].__exit__(*sys.exc_info())
        previous_context[0].__exit__(*sys.exc_info())

2698 

2699 def _close_writers(self): 

2700 for writer in self._writers.values(): 

2701 writer.close() 

2702 

    def _init_profile_batch(self, profile_batch):
        """Validate profile_batch value and set the range of batches to profile.

        Sets values of _start_batch and _stop_batch attributes,
        specifying the start and stop batch to profile.
        Setting `profile_batch=0` disables profiling.

        Args:
            profile_batch: The range of batches to profile. Should be a
            non-negative integer or a comma separated string of pair of positive
            integers. A pair of positive integers signify a range of batches to
            profile.

        Raises:
            ValueError: If profile_batch is not an integer or a comma separated
            pair of positive integers.

        """
        profile_batch_error_message = (
            "profile_batch must be a non-negative integer or "
            "2-tuple of positive "
            "integers. A pair of positive integers "
            "signifies a range of batches "
            f"to profile. Found: {profile_batch}"
        )

        # Support legacy way of specifying "start,stop" or "start" as str.
        if isinstance(profile_batch, str):
            profile_batch = str(profile_batch).split(",")
            profile_batch = tf.nest.map_structure(int, profile_batch)

        # A single integer means "profile exactly that batch".
        if isinstance(profile_batch, int):
            self._start_batch = profile_batch
            self._stop_batch = profile_batch
        elif (
            isinstance(profile_batch, (tuple, list)) and len(profile_batch) == 2
        ):
            self._start_batch, self._stop_batch = profile_batch
        else:
            raise ValueError(profile_batch_error_message)

        # Reject negative starts and inverted ranges.
        if self._start_batch < 0 or self._stop_batch < self._start_batch:
            raise ValueError(profile_batch_error_message)

        # True when the profiler was successfully started by this callback.
        # We track the status here to make sure callbacks do not interfere with
        # each other. The callback will only stop the profiler it started.
        self._profiler_started = False
        if self._start_batch > 0:
            # Warm up and improve the profiling accuracy.
            self._start_profiler(logdir="")
            self._stop_profiler(save=False)
        # True when a trace is running.
        self._is_tracing = False

        # Setting `profile_batch=0` disables profiling.
        self._should_trace = not (
            self._start_batch == 0 and self._stop_batch == 0
        )

2762 

    def on_train_begin(self, logs=None):
        """Resets per-run counters and installs the train summary writer."""
        self._global_train_batch = 0
        self._previous_epoch_iterations = 0
        # No-op when `update_freq == "epoch"` (see `_push_writer`).
        self._push_writer(self._train_writer, self._train_step)

2767 

    def on_train_end(self, logs=None):
        """Restores summary state, stops any open trace, closes writers."""
        self._pop_writer()

        # A profile range extending past the final batch is closed here.
        if self._is_tracing:
            self._stop_trace()

        self._close_writers()
        self._delete_tmp_write_dir()

2776 

    def on_test_begin(self, logs=None):
        """Installs the validation writer for batch-level summaries."""
        self._push_writer(self._val_writer, self._val_step)

2779 

    def on_test_end(self, logs=None):
        """Writes evaluation metrics against optimizer iterations, if any."""
        if self.model.optimizer and hasattr(self.model.optimizer, "iterations"):
            with tf.summary.record_if(True), self._val_writer.as_default():
                # NOTE(review): assumes `logs` is a dict here — confirm
                # callers never pass None to this hook.
                for name, value in logs.items():
                    tf.summary.scalar(
                        "evaluation_" + name + "_vs_iterations",
                        value,
                        step=self.model.optimizer.iterations.read_value(),
                    )
        self._pop_writer()

2790 

2791 def _implements_train_batch_hooks(self): 

2792 # Only call batch hooks when tracing or write_steps_per_second are 

2793 # enabled 

2794 return self._should_trace or self.write_steps_per_second 

2795 

2796 def on_train_batch_begin(self, batch, logs=None): 

2797 self._global_train_batch += 1 

2798 if self.write_steps_per_second: 

2799 self._batch_start_time = time.time() 

2800 if not self._should_trace: 

2801 return 

2802 

2803 if self._global_train_batch == self._start_batch: 

2804 self._start_trace() 

2805 

    def on_train_batch_end(self, batch, logs=None):
        """Writes batch-level summaries and stops tracing when due."""
        # The train-function graph can only be written after the first
        # batch, once the tf.function has been traced (see `set_model`).
        if self._should_write_train_graph:
            self._write_keras_model_train_graph()
            self._should_write_train_graph = False
        if self.write_steps_per_second:
            batch_run_time = time.time() - self._batch_start_time
            tf.summary.scalar(
                "batch_steps_per_second",
                1.0 / batch_run_time,
                step=self._train_step,
            )

        # `logs` isn't necessarily always a dict. For example, when using
        # `tf.distribute.experimental.ParameterServerStrategy`, a
        # `tf.distribute.experimental.coordinator.RemoteValue` will be passed.
        # For now, we just disable `update_freq` in those cases.
        if isinstance(logs, dict):
            for name, value in logs.items():
                tf.summary.scalar("batch_" + name, value, step=self._train_step)

        if not self._should_trace:
            return

        if self._is_tracing and self._global_train_batch >= self._stop_batch:
            self._stop_trace()

2831 

    def on_epoch_begin(self, epoch, logs=None):
        # Keeps track of epoch for profiling.
        if self.write_steps_per_second:
            # Snapshot the optimizer's iteration counter and the wall clock
            # so `_compute_steps_per_second` can measure per-epoch progress.
            self._previous_epoch_iterations = (
                self.model.optimizer.iterations.numpy()
            )
            self._epoch_start_time = time.time()

2839 

2840 def on_epoch_end(self, epoch, logs=None): 

2841 """Runs metrics and histogram summaries at epoch end.""" 

2842 self._log_epoch_metrics(epoch, logs) 

2843 

2844 if self.histogram_freq and epoch % self.histogram_freq == 0: 

2845 self._log_weights(epoch) 

2846 

2847 if self.embeddings_freq and epoch % self.embeddings_freq == 0: 

2848 self._log_embeddings(epoch) 

2849 

    def _start_trace(self):
        """Enables graph tracing and starts the profiler."""
        # Graph tracing is handled by tf.summary; profiling is started
        # separately so `_stop_trace` can export both.
        tf.summary.trace_on(graph=True, profiler=False)
        self._start_profiler(logdir=self.log_dir)
        self._is_tracing = True

2854 

    def _stop_trace(self, batch=None):
        """Logs the trace graph to TensorBoard."""
        if batch is None:
            # Called without a batch (e.g. from `on_train_end`): attribute
            # the export to the configured stop batch.
            batch = self._stop_batch
        with self._train_writer.as_default():
            with tf.summary.record_if(True):
                # TODO(b/126388999): Remove step info in the summary name.
                tf.summary.trace_export(name="batch_%d" % batch, step=batch)
        self._stop_profiler()
        self._is_tracing = False

2865 

2866 def _collect_learning_rate(self, logs): 

2867 if isinstance(self.model.optimizer, optimizer.Optimizer): 

2868 lr_schedule = getattr(self.model.optimizer, "_learning_rate", None) 

2869 else: 

2870 lr_schedule = getattr(self.model.optimizer, "lr", None) 

2871 if isinstance(lr_schedule, learning_rate_schedule.LearningRateSchedule): 

2872 logs["learning_rate"] = lr_schedule(self.model.optimizer.iterations) 

2873 return logs 

2874 

2875 def _compute_steps_per_second(self): 

2876 current_iteration = self.model.optimizer.iterations.numpy() 

2877 time_since_epoch_begin = time.time() - self._epoch_start_time 

2878 steps_per_second = ( 

2879 current_iteration - self._previous_epoch_iterations 

2880 ) / time_since_epoch_begin 

2881 return steps_per_second 

2882 

    def _log_epoch_metrics(self, epoch, logs):
        """Writes epoch metrics out as scalar summaries.

        Args:
            epoch: Int. The global step to use for TensorBoard.
            logs: Dict. Keys are scalar summary names, values are scalars.
        """
        if not logs:
            return

        # Split metrics by the `val_` naming convention so they land in the
        # train and validation writers respectively.
        train_logs = {k: v for k, v in logs.items() if not k.startswith("val_")}
        val_logs = {k: v for k, v in logs.items() if k.startswith("val_")}
        train_logs = self._collect_learning_rate(train_logs)
        if self.write_steps_per_second:
            train_logs["steps_per_second"] = self._compute_steps_per_second()

        with tf.summary.record_if(True):
            if train_logs:
                with self._train_writer.as_default():
                    for name, value in train_logs.items():
                        tf.summary.scalar("epoch_" + name, value, step=epoch)
            if val_logs:
                with self._val_writer.as_default():
                    for name, value in val_logs.items():
                        name = name[4:]  # Remove 'val_' prefix.
                        tf.summary.scalar("epoch_" + name, value, step=epoch)

2909 

    def _log_weights(self, epoch):
        """Logs the weights of the Model to TensorBoard."""
        with self._train_writer.as_default():
            with tf.summary.record_if(True):
                for layer in self.model.layers:
                    for weight in layer.weights:
                        # ':' is not a valid summary-tag character.
                        weight_name = weight.name.replace(":", "_")
                        # Add a suffix to prevent summary tag name collision.
                        histogram_weight_name = weight_name + "/histogram"
                        tf.summary.histogram(
                            histogram_weight_name, weight, step=epoch
                        )
                        if self.write_images:
                            # Add a suffix to prevent summary tag name
                            # collision.
                            image_weight_name = weight_name + "/image"
                            self._log_weight_as_image(
                                weight, image_weight_name, epoch
                            )
                self._train_writer.flush()

2930 

    def _log_weight_as_image(self, weight, weight_name, epoch):
        """Logs a weight as a TensorBoard image.

        The weight is reshaped to the 4-D `(batch, height, width, channels)`
        layout that `tf.summary.image` requires; shapes that cannot be
        coerced (e.g. 3-D convolution kernels) are silently skipped.
        """
        w_img = tf.squeeze(weight)
        shape = backend.int_shape(w_img)
        if len(shape) == 1:  # Bias case
            w_img = tf.reshape(w_img, [1, shape[0], 1, 1])
        elif len(shape) == 2:  # Dense layer kernel case
            # Orient so the longer axis is the image width.
            if shape[0] > shape[1]:
                w_img = tf.transpose(w_img)
                shape = backend.int_shape(w_img)
            w_img = tf.reshape(w_img, [1, shape[0], shape[1], 1])
        elif len(shape) == 3:  # ConvNet case
            if backend.image_data_format() == "channels_last":
                # Switch to channels_first to display every kernel as a separate
                # image.
                w_img = tf.transpose(w_img, perm=[2, 0, 1])
                shape = backend.int_shape(w_img)
            w_img = tf.reshape(w_img, [shape[0], shape[1], shape[2], 1])

        shape = backend.int_shape(w_img)
        # Not possible to handle 3D convnets etc.
        if len(shape) == 4 and shape[-1] in [1, 3, 4]:
            tf.summary.image(weight_name, w_img, step=epoch)

2954 

2955 def _log_embeddings(self, epoch): 

2956 embeddings_ckpt = os.path.join( 

2957 self._log_write_dir, 

2958 "train", 

2959 f"keras_embedding.ckpt-{epoch}", 

2960 ) 

2961 self.model.save_weights(embeddings_ckpt) 

2962 

    def _start_profiler(self, logdir):
        """Starts the profiler if currently inactive.

        Args:
            logdir: Directory where profiler results will be saved.
        """
        if self._profiler_started:
            return
        try:
            tf.profiler.experimental.start(logdir=logdir)
            self._profiler_started = True
        except tf.errors.AlreadyExistsError as e:
            # Profiler errors should not be fatal.
            logging.error("Failed to start profiler: %s", e.message)

2977 

    def _stop_profiler(self, save=True):
        """Stops the profiler if currently active.

        Args:
            save: Whether to save the profiler results to TensorBoard.
        """
        if not self._profiler_started:
            return
        try:
            tf.profiler.experimental.stop(save=save)
        except tf.errors.UnavailableError as e:
            # Profiler errors should not be fatal.
            logging.error("Failed to stop profiler: %s", e.message)
        finally:
            # Always clear the flag so a later `_start_profiler` can retry.
            self._profiler_started = False

2993 

2994 

@keras_export("keras.callbacks.ReduceLROnPlateau")
class ReduceLROnPlateau(Callback):
    """Reduce learning rate when a metric has stopped improving.

    Models often benefit from reducing the learning rate by a factor
    of 2-10 once learning stagnates. This callback monitors a
    quantity and if no improvement is seen for a 'patience' number
    of epochs, the learning rate is reduced.

    Example:

    ```python
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,
                                  patience=5, min_lr=0.001)
    model.fit(X_train, Y_train, callbacks=[reduce_lr])
    ```

    Args:
        monitor: quantity to be monitored.
        factor: factor by which the learning rate will be reduced.
            `new_lr = lr * factor`.
        patience: number of epochs with no improvement after which learning rate
            will be reduced.
        verbose: int. 0: quiet, 1: update messages.
        mode: one of `{'auto', 'min', 'max'}`. In `'min'` mode,
            the learning rate will be reduced when the
            quantity monitored has stopped decreasing; in `'max'` mode it will be
            reduced when the quantity monitored has stopped increasing; in
            `'auto'` mode, the direction is automatically inferred from the name
            of the monitored quantity.
        min_delta: threshold for measuring the new optimum, to only focus on
            significant changes.
        cooldown: number of epochs to wait before resuming normal operation
            after lr has been reduced.
        min_lr: lower bound on the learning rate.
    """

    def __init__(
        self,
        monitor="val_loss",
        factor=0.1,
        patience=10,
        verbose=0,
        mode="auto",
        min_delta=1e-4,
        cooldown=0,
        min_lr=0,
        **kwargs,
    ):
        super().__init__()

        self.monitor = monitor
        # A factor >= 1 would never decrease the learning rate.
        if factor >= 1.0:
            raise ValueError(
                "ReduceLROnPlateau does not support "
                f"a factor >= 1.0. Got {factor}"
            )
        if "epsilon" in kwargs:
            # Legacy alias for `min_delta`.
            min_delta = kwargs.pop("epsilon")
            logging.warning(
                "`epsilon` argument is deprecated and "
                "will be removed, use `min_delta` instead."
            )
        self.factor = factor
        self.min_lr = min_lr
        self.min_delta = min_delta
        self.patience = patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.cooldown_counter = 0  # Cooldown counter.
        self.wait = 0
        self.best = 0
        self.mode = mode
        self.monitor_op = None
        self._reset()

    def _reset(self):
        """Resets wait counter and cooldown counter."""
        if self.mode not in ["auto", "min", "max"]:
            logging.warning(
                "Learning rate reduction mode %s is unknown, "
                "fallback to auto mode.",
                self.mode,
            )
            self.mode = "auto"
        if self.mode == "min" or (
            self.mode == "auto" and "acc" not in self.monitor
        ):
            # "Improvement" means a significant decrease.
            # Note: `np.inf` (not the removed-in-NumPy-2.0 `np.Inf` alias).
            self.monitor_op = lambda a, b: np.less(a, b - self.min_delta)
            self.best = np.inf
        else:
            # "Improvement" means a significant increase.
            self.monitor_op = lambda a, b: np.greater(a, b + self.min_delta)
            self.best = -np.inf
        self.cooldown_counter = 0
        self.wait = 0

    def on_train_begin(self, logs=None):
        self._reset()

    def on_epoch_end(self, epoch, logs=None):
        """Checks the monitored metric and reduces the lr when stagnant."""
        logs = logs or {}
        # Expose the current lr in the logs (picked up e.g. by CSVLogger).
        logs["lr"] = backend.get_value(self.model.optimizer.lr)
        current = logs.get(self.monitor)
        if current is None:
            logging.warning(
                "Learning rate reduction is conditioned on metric `%s` "
                "which is not available. Available metrics are: %s",
                self.monitor,
                ",".join(list(logs.keys())),
            )

        else:
            if self.in_cooldown():
                self.cooldown_counter -= 1
                self.wait = 0

            if self.monitor_op(current, self.best):
                self.best = current
                self.wait = 0
            elif not self.in_cooldown():
                self.wait += 1
                if self.wait >= self.patience:
                    old_lr = backend.get_value(self.model.optimizer.lr)
                    if old_lr > np.float32(self.min_lr):
                        new_lr = old_lr * self.factor
                        new_lr = max(new_lr, self.min_lr)
                        backend.set_value(self.model.optimizer.lr, new_lr)
                        if self.verbose > 0:
                            io_utils.print_msg(
                                f"\nEpoch {epoch + 1}: "
                                "ReduceLROnPlateau reducing "
                                f"learning rate to {new_lr}."
                            )
                        self.cooldown_counter = self.cooldown
                        self.wait = 0

    def in_cooldown(self):
        """Returns True while the post-reduction cooldown is in effect."""
        return self.cooldown_counter > 0

3133 

3134 

@keras_export("keras.callbacks.CSVLogger")
class CSVLogger(Callback):
    """Callback that streams epoch results to a CSV file.

    Supports all values that can be represented as a string,
    including 1D iterables such as `np.ndarray`.

    Example:

    ```python
    csv_logger = CSVLogger('training.log')
    model.fit(X_train, Y_train, callbacks=[csv_logger])
    ```

    Args:
        filename: Filename of the CSV file, e.g. `'run/log.csv'`.
        separator: String used to separate elements in the CSV file.
        append: Boolean. True: append if file exists (useful for continuing
            training). False: overwrite existing file.
    """

    def __init__(self, filename, separator=",", append=False):
        self.sep = separator
        self.filename = io_utils.path_to_string(filename)
        self.append = append
        # `writer` and `keys` are created lazily on the first epoch end.
        self.writer = None
        self.keys = None
        self.append_header = True
        super().__init__()

    def on_train_begin(self, logs=None):
        """Opens the CSV file (append or overwrite mode)."""
        if self.append:
            if tf.io.gfile.exists(self.filename):
                # Only write a header if the existing file is empty.
                with tf.io.gfile.GFile(self.filename, "r") as f:
                    self.append_header = not bool(len(f.readline()))
            mode = "a"
        else:
            mode = "w"
        self.csv_file = tf.io.gfile.GFile(self.filename, mode)

    def on_epoch_end(self, epoch, logs=None):
        """Appends one CSV row with this epoch's metrics."""
        logs = logs or {}

        def handle_value(k):
            """Renders a log value as a CSV-safe cell."""
            is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
            if isinstance(k, str):
                return k
            elif (
                isinstance(k, collections.abc.Iterable)
                and not is_zero_dim_ndarray
            ):
                # 1-D iterables are flattened into a quoted "[a, b, ...]".
                return f"\"[{', '.join(map(str, k))}]\""
            else:
                return k

        if self.keys is None:
            self.keys = sorted(logs.keys())
            # When validation_freq > 1, `val_` keys are not in first epoch logs
            # Add the `val_` keys so that its part of the fieldnames of writer.
            val_keys_found = False
            for key in self.keys:
                if key.startswith("val_"):
                    val_keys_found = True
                    break
            if not val_keys_found:
                self.keys.extend(["val_" + k for k in self.keys])

        if not self.writer:

            class CustomDialect(csv.excel):
                # Excel dialect, but with the user-chosen delimiter.
                delimiter = self.sep

            fieldnames = ["epoch"] + self.keys

            self.writer = csv.DictWriter(
                self.csv_file, fieldnames=fieldnames, dialect=CustomDialect
            )
            if self.append_header:
                self.writer.writeheader()

        # Missing keys (e.g. skipped validation epochs) are written as "NA".
        row_dict = collections.OrderedDict({"epoch": epoch})
        row_dict.update(
            (key, handle_value(logs.get(key, "NA"))) for key in self.keys
        )
        self.writer.writerow(row_dict)
        self.csv_file.flush()

    def on_train_end(self, logs=None):
        """Closes the CSV file and drops the writer."""
        self.csv_file.close()
        self.writer = None

3225 

3226 

@keras_export("keras.callbacks.LambdaCallback")
class LambdaCallback(Callback):
    r"""Callback for creating simple, custom callbacks on-the-fly.

    This callback is constructed with anonymous functions that will be called
    at the appropriate time (during `Model.{fit | evaluate | predict}`).
    Note that the callbacks expects positional arguments, as:

    - `on_epoch_begin` and `on_epoch_end` expect two positional arguments:
      `epoch`, `logs`
    - `on_batch_begin` and `on_batch_end` expect two positional arguments:
      `batch`, `logs`
    - `on_train_begin` and `on_train_end` expect one positional argument:
      `logs`

    Args:
        on_epoch_begin: called at the beginning of every epoch.
        on_epoch_end: called at the end of every epoch.
        on_batch_begin: called at the beginning of every batch.
        on_batch_end: called at the end of every batch.
        on_train_begin: called at the beginning of model training.
        on_train_end: called at the end of model training.

    Example:

    ```python
    # Print the batch number at the beginning of every batch.
    batch_print_callback = LambdaCallback(
        on_batch_begin=lambda batch,logs: print(batch))

    # Stream the epoch loss to a file in JSON format. The file content
    # is not well-formed JSON but rather has a JSON object per line.
    import json
    json_log = open('loss_log.json', mode='wt', buffering=1)
    json_logging_callback = LambdaCallback(
        on_epoch_end=lambda epoch, logs: json_log.write(
            json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'),
        on_train_end=lambda logs: json_log.close()
    )

    # Terminate some processes after having finished model training.
    processes = ...
    cleanup_callback = LambdaCallback(
        on_train_end=lambda logs: [
            p.terminate() for p in processes if p.is_alive()])

    model.fit(...,
              callbacks=[batch_print_callback,
                         json_logging_callback,
                         cleanup_callback])
    ```
    """

    def __init__(
        self,
        on_epoch_begin=None,
        on_epoch_end=None,
        on_batch_begin=None,
        on_batch_end=None,
        on_train_begin=None,
        on_train_end=None,
        **kwargs,
    ):
        super().__init__()
        # Any extra keyword arguments become instance attributes verbatim.
        self.__dict__.update(kwargs)
        # Install each supplied hook as an instance attribute, shadowing
        # the no-op implementation inherited from `Callback`.
        hook_overrides = {
            "on_epoch_begin": on_epoch_begin,
            "on_epoch_end": on_epoch_end,
            "on_batch_begin": on_batch_begin,
            "on_batch_end": on_batch_end,
            "on_train_begin": on_train_begin,
            "on_train_end": on_train_end,
        }
        for hook_name, hook_fn in hook_overrides.items():
            if hook_fn is not None:
                setattr(self, hook_name, hook_fn)