Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/callbacks.py: 22%
1177 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=g-import-not-at-top
16# pylint: disable=g-classes-have-attributes
17"""Callbacks: utilities called at certain points during model training."""
19import collections
20import copy
21import csv
22import json
23import os
24import re
25import sys
26import time
28import numpy as np
30from tensorflow.core.framework import summary_pb2
31from tensorflow.python.checkpoint import checkpoint_management
32from tensorflow.python.checkpoint import checkpoint_options as checkpoint_options_lib
33from tensorflow.python.data.ops import iterator_ops
34from tensorflow.python.distribute import collective_all_reduce_strategy
35from tensorflow.python.distribute import distribute_lib
36from tensorflow.python.distribute import mirrored_strategy
37from tensorflow.python.distribute import parameter_server_strategy_v2
38from tensorflow.python.distribute import tpu_strategy
39from tensorflow.python.eager import context
40from tensorflow.python.framework import constant_op
41from tensorflow.python.framework import dtypes
42from tensorflow.python.framework import errors
43from tensorflow.python.framework import ops
44from tensorflow.python.keras import backend
45from tensorflow.python.keras.distribute import distributed_file_utils
46from tensorflow.python.keras.distribute import worker_training_state
47from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
48from tensorflow.python.keras.utils import generic_utils
49from tensorflow.python.keras.utils import tf_utils
50from tensorflow.python.keras.utils import version_utils
51from tensorflow.python.keras.utils.data_utils import Sequence
52from tensorflow.python.keras.utils.generic_utils import Progbar
53from tensorflow.python.keras.utils.io_utils import path_to_string
54from tensorflow.python.keras.utils.mode_keys import ModeKeys
55from tensorflow.python.lib.io import file_io
56from tensorflow.python.ops import array_ops
57from tensorflow.python.ops import math_ops
58from tensorflow.python.ops import summary_ops_v2
59from tensorflow.python.platform import gfile
60from tensorflow.python.platform import tf_logging as logging
61from tensorflow.python.profiler import profiler_v2 as profiler
62from tensorflow.python.saved_model import save_options as save_options_lib
63from tensorflow.python.util import nest
64from tensorflow.python.util.tf_export import keras_export
65from tensorflow.tools.docs import doc_controls
67try:
68 import requests
69except ImportError:
70 requests = None
73# Note: `configure_callbacks` is only used in TF1.
74def configure_callbacks(callbacks,
75 model,
76 do_validation=False,
77 batch_size=None,
78 epochs=None,
79 steps_per_epoch=None,
80 samples=None,
81 verbose=1,
82 count_mode='steps',
83 mode=ModeKeys.TRAIN):
84 """Configures callbacks for use in various training loops.
86 Args:
87 callbacks: List of Callbacks.
88 model: Model being trained.
89 do_validation: Whether or not validation loop will be run.
90 batch_size: Number of samples per batch.
91 epochs: Number of epochs to train.
92 steps_per_epoch: Number of batches to run per training epoch.
93 samples: Number of training samples.
94 verbose: int, 0 or 1. Keras logging verbosity to pass to ProgbarLogger.
95 count_mode: One of 'steps' or 'samples'. Per-batch or per-sample count.
96 mode: String. One of ModeKeys.TRAIN, ModeKeys.TEST, or ModeKeys.PREDICT.
97 Which loop mode to configure callbacks for.
99 Returns:
100 Instance of CallbackList used to control all Callbacks.
101 """
102 # Check if callbacks have already been configured.
103 if isinstance(callbacks, CallbackList):
104 return callbacks
106 if not callbacks:
107 callbacks = []
109 # Add additional callbacks during training.
110 if mode == ModeKeys.TRAIN:
111 model.history = History()
112 callbacks = [BaseLogger()] + (callbacks or []) + [model.history]
113 if verbose:
114 callbacks.append(ProgbarLogger(count_mode))
115 callback_list = CallbackList(callbacks)
117 # Set callback model
118 callback_model = model._get_callback_model() # pylint: disable=protected-access
119 callback_list.set_model(callback_model)
121 set_callback_parameters(
122 callback_list,
123 model,
124 do_validation=do_validation,
125 batch_size=batch_size,
126 epochs=epochs,
127 steps_per_epoch=steps_per_epoch,
128 samples=samples,
129 verbose=verbose,
130 mode=mode)
132 callback_list.model.stop_training = False
133 return callback_list
136def set_callback_parameters(callback_list,
137 model,
138 do_validation=False,
139 batch_size=None,
140 epochs=None,
141 steps_per_epoch=None,
142 samples=None,
143 verbose=1,
144 mode=ModeKeys.TRAIN):
145 """Sets callback parameters.
147 Args:
148 callback_list: CallbackList instance.
149 model: Model being trained.
150 do_validation: Whether or not validation loop will be run.
151 batch_size: Number of samples per batch.
152 epochs: Number of epochs to train.
153 steps_per_epoch: Number of batches to run per training epoch.
154 samples: Number of training samples.
155 verbose: int, 0 or 1. Keras logging verbosity to pass to ProgbarLogger.
156 mode: String. One of ModeKeys.TRAIN, ModeKeys.TEST, or ModeKeys.PREDICT.
157 Which loop mode to configure callbacks for.
158 """
159 metric_names = model.metrics_names
160 for cbk in callback_list:
161 if isinstance(cbk, (BaseLogger, ProgbarLogger)):
162 cbk.stateful_metrics = metric_names[1:] # Exclude `loss`
164 # Set callback parameters
165 callback_metrics = []
166 # In a deferred-build scenario with iterator input, we will compile
167 # when we standardize the first batch of data.
168 if mode != ModeKeys.PREDICT:
169 callback_metrics = copy.copy(metric_names)
170 if do_validation:
171 callback_metrics += ['val_' + n for n in metric_names]
172 callback_params = {
173 'batch_size': batch_size,
174 'epochs': epochs,
175 'steps': steps_per_epoch,
176 'samples': samples,
177 'verbose': verbose,
178 'do_validation': do_validation,
179 'metrics': callback_metrics,
180 }
181 callback_list.set_params(callback_params)
184def _is_generator_like(data):
185 """Checks if data is a generator, Sequence, or Iterator."""
186 return (hasattr(data, '__next__') or hasattr(data, 'next') or isinstance(
187 data, (Sequence, iterator_ops.Iterator, iterator_ops.IteratorBase)))
190def make_logs(model, logs, outputs, mode, prefix=''):
191 """Computes logs for sending to `on_batch_end` methods."""
192 metric_names = model.metrics_names
193 if mode in {ModeKeys.TRAIN, ModeKeys.TEST} and metric_names:
194 for label, output in zip(metric_names, outputs):
195 logs[prefix + label] = output
196 else:
197 logs['outputs'] = outputs
198 return logs
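# A minimal sketch (not part of this module) of how `make_logs` pairs metric
# names with outputs; the model, outputs and metric names are hypothetical.
#
#   logs = make_logs(model, {}, outputs=[0.25, 0.9], mode=ModeKeys.TRAIN)
#   # With model.metrics_names == ['loss', 'accuracy'], this yields
#   # {'loss': 0.25, 'accuracy': 0.9}. In PREDICT mode (or with no metric
#   # names) the raw outputs are stored under the single key 'outputs'.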
201@keras_export('keras.callbacks.CallbackList')
202class CallbackList:
203 """Container abstracting a list of callbacks."""
205 def __init__(self,
206 callbacks=None,
207 add_history=False,
208 add_progbar=False,
209 model=None,
210 **params):
211 """Container for `Callback` instances.
213 This object wraps a list of `Callback` instances, making it possible
214 to call them all at once via a single endpoint
215 (e.g. `callback_list.on_epoch_end(...)`).
217 Args:
218 callbacks: List of `Callback` instances.
219 add_history: Whether a `History` callback should be added, if one does not
220 already exist in the `callbacks` list.
221 add_progbar: Whether a `ProgbarLogger` callback should be added, if one
222 does not already exist in the `callbacks` list.
223 model: The `Model` these callbacks are used with.
224 **params: If provided, parameters will be passed to each `Callback` via
225 `Callback.set_params`.
226 """
227 self.callbacks = nest.flatten(callbacks) if callbacks else []
228 self._add_default_callbacks(add_history, add_progbar)
230 if model:
231 self.set_model(model)
232 if params:
233 self.set_params(params)
235 # Performance optimization: determines if batch hooks need to be called.
236 # pylint: disable=protected-access
237 self._supports_tf_logs = all(
238 getattr(cb, '_supports_tf_logs', False) for cb in self.callbacks)
239 self._batch_hooks_support_tf_logs = all(
240 getattr(cb, '_supports_tf_logs', False)
241 for cb in self.callbacks
242 if cb._implements_train_batch_hooks() or cb
243 ._implements_test_batch_hooks() or cb._implements_predict_batch_hooks())
245 self._should_call_train_batch_hooks = any(
246 cb._implements_train_batch_hooks() for cb in self.callbacks)
247 self._should_call_test_batch_hooks = any(
248 cb._implements_test_batch_hooks() for cb in self.callbacks)
249 self._should_call_predict_batch_hooks = any(
250 cb._implements_predict_batch_hooks() for cb in self.callbacks)
251 # pylint: enable=protected-access
253 self._disallow_batch_hooks_in_ps_strategy()
255 # Performance check: Check batch hooks for slowness compared to batch time.
256 # Only run check for custom callbacks (i.e. not present in this file).
257 self._check_timing = any(
258 cbk.__class__.__name__ not in globals() for cbk in self.callbacks)
259 self._num_batches_for_timing_check = 5
260 self._hook_times = {}
261 self._batch_start_time = None
262 self._batch_times = []
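  # A minimal usage sketch (assumed, not part of this module): `fit` builds one
  # of these so a single call fans out to every callback; `MyCallback` and
  # `model` are placeholders for a user-defined Callback subclass and a model.
  #
  #   cb_list = CallbackList([MyCallback()], add_history=True, add_progbar=True,
  #                          model=model, verbose=1, epochs=5, steps=100)
  #   cb_list.on_train_begin()
  #   cb_list.on_epoch_begin(0)
  #   cb_list.on_train_batch_end(0, {'loss': 0.3})
  #   cb_list.on_epoch_end(0, {'loss': 0.3})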
264 def _add_default_callbacks(self, add_history, add_progbar):
265 """Adds `Callback`s that are always present."""
266 self._progbar = None
267 self._history = None
269 for cb in self.callbacks:
270 if isinstance(cb, ProgbarLogger):
271 self._progbar = cb
272 elif isinstance(cb, History):
273 self._history = cb
275 if self._progbar is None and add_progbar:
276 self._progbar = ProgbarLogger(count_mode='steps')
277 self.callbacks.insert(0, self._progbar)
279 if self._history is None and add_history:
280 self._history = History()
281 self.callbacks.append(self._history)
283 def _process_logs(self, logs, is_batch_hook=False):
284 """Turns tensors into numpy arrays or Python scalars if necessary."""
285 if logs is None:
286 return {}
287 if self._supports_tf_logs:
288 return logs
289 if is_batch_hook and self._batch_hooks_support_tf_logs:
290 return logs
291 return tf_utils.sync_to_numpy_or_python_type(logs)
293 def append(self, callback):
294 self.callbacks.append(callback)
296 def set_params(self, params):
297 self.params = params
298 for callback in self.callbacks:
299 callback.set_params(params)
301 def set_model(self, model):
302 self.model = model
303 if self._history:
304 model.history = self._history
305 for callback in self.callbacks:
306 callback.set_model(model)
308 def _call_batch_hook(self, mode, hook, batch, logs=None):
309 """Helper function for all batch_{begin | end} methods."""
310 if not self.callbacks:
311 return
313 if hook == 'begin':
314 self._call_batch_begin_hook(mode, batch, logs)
315 elif hook == 'end':
316 self._call_batch_end_hook(mode, batch, logs)
317 else:
318 raise ValueError('Unrecognized hook: {}'.format(hook))
320 def _call_batch_begin_hook(self, mode, batch, logs):
321 """Helper function for `on_*_batch_begin` methods."""
322 hook_name = 'on_{mode}_batch_begin'.format(mode=mode)
323 self._call_batch_hook_helper(hook_name, batch, logs)
325 if self._check_timing:
326 self._batch_start_time = time.time()
328 def _call_batch_end_hook(self, mode, batch, logs):
329 """Helper function for `on_*_batch_end` methods."""
330 hook_name = 'on_{mode}_batch_end'.format(mode=mode)
332 if self._check_timing and batch >= 1:
333 batch_time = time.time() - self._batch_start_time
334 self._batch_times.append(batch_time)
336 self._call_batch_hook_helper(hook_name, batch, logs)
338 if len(self._batch_times) >= self._num_batches_for_timing_check:
339 end_hook_name = hook_name
340 begin_hook_name = 'on_{mode}_batch_begin'.format(mode=mode)
341 avg_batch_time = sum(self._batch_times) / len(self._batch_times)
342 avg_end_hook_time = sum(self._hook_times[end_hook_name]) / len(
343 self._hook_times[end_hook_name])
344 avg_begin_hook_time = sum(self._hook_times[begin_hook_name]) / len(
345 self._hook_times[begin_hook_name])
347 threshold_time = 1.0 * avg_batch_time
348 warning_msg = ('Callback method `{hook}` is slow compared to '
349 'the batch time (batch time: {batch_time:.4f}s vs '
350 '`{hook}` time: {hook_time:.4f}s). Check your callbacks.')
351 if avg_begin_hook_time > threshold_time:
352 logging.warning(warning_msg.format(
353 hook=begin_hook_name,
354 batch_time=avg_batch_time,
355 hook_time=avg_begin_hook_time))
356 if avg_end_hook_time > threshold_time:
357 logging.warning(warning_msg.format(
358 hook=end_hook_name,
359 batch_time=avg_batch_time,
360 hook_time=avg_end_hook_time))
361 self._check_timing = False
362 self._batch_start_time = None
363 self._batch_times = []
364 self._hook_times = {}
366 def _call_batch_hook_helper(self, hook_name, batch, logs):
367 """Helper function for `on_*_batch_*` methods."""
368 if self._check_timing:
369 start_time = time.time()
371 logs = self._process_logs(logs, is_batch_hook=True)
372 for callback in self.callbacks:
373 hook = getattr(callback, hook_name)
374 hook(batch, logs)
376 if self._check_timing:
377 if hook_name not in self._hook_times:
378 self._hook_times[hook_name] = []
379 self._hook_times[hook_name].append(time.time() - start_time)
381 def _call_begin_hook(self, mode):
382 """Helper function for on_{train|test|predict}_begin methods."""
383 if mode == ModeKeys.TRAIN:
384 self.on_train_begin()
385 elif mode == ModeKeys.TEST:
386 self.on_test_begin()
387 else:
388 self.on_predict_begin()
390 def _call_end_hook(self, mode):
391 """Helper function for on_{train|test|predict}_end methods."""
392 if mode == ModeKeys.TRAIN:
393 self.on_train_end()
394 elif mode == ModeKeys.TEST:
395 self.on_test_end()
396 else:
397 self.on_predict_end()
399 def on_batch_begin(self, batch, logs=None):
400 if self._should_call_train_batch_hooks:
401 self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs)
403 def on_batch_end(self, batch, logs=None):
404 if self._should_call_train_batch_hooks:
405 self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)
407 def on_epoch_begin(self, epoch, logs=None):
408 """Calls the `on_epoch_begin` methods of its callbacks.
410 This function should only be called during TRAIN mode.
412 Args:
413 epoch: Integer, index of epoch.
414 logs: Dict. Currently no data is passed to this argument for this method
415 but that may change in the future.
416 """
417 logs = self._process_logs(logs)
418 for callback in self.callbacks:
419 callback.on_epoch_begin(epoch, logs)
421 def on_epoch_end(self, epoch, logs=None):
422 """Calls the `on_epoch_end` methods of its callbacks.
424 This function should only be called during TRAIN mode.
426 Args:
427 epoch: Integer, index of epoch.
428 logs: Dict, metric results for this training epoch, and for the
429 validation epoch if validation is performed. Validation result keys
430 are prefixed with `val_`.
431 """
432 logs = self._process_logs(logs)
433 for callback in self.callbacks:
434 callback.on_epoch_end(epoch, logs)
436 def on_train_batch_begin(self, batch, logs=None):
437 """Calls the `on_train_batch_begin` methods of its callbacks.
439 Args:
440 batch: Integer, index of batch within the current epoch.
441 logs: Dict, contains the return value of `model.train_step`. Typically,
442 the values of the `Model`'s metrics are returned. Example:
443 `{'loss': 0.2, 'accuracy': 0.7}`.
444 """
445 if self._should_call_train_batch_hooks:
446 self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs)
448 def on_train_batch_end(self, batch, logs=None):
449 """Calls the `on_train_batch_end` methods of its callbacks.
451 Args:
452 batch: Integer, index of batch within the current epoch.
453 logs: Dict. Aggregated metric results up until this batch.
454 """
455 if self._should_call_train_batch_hooks:
456 self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)
458 def on_test_batch_begin(self, batch, logs=None):
459 """Calls the `on_test_batch_begin` methods of its callbacks.
461 Args:
462 batch: Integer, index of batch within the current epoch.
463 logs: Dict, contains the return value of `model.test_step`. Typically,
464 the values of the `Model`'s metrics are returned. Example:
465 `{'loss': 0.2, 'accuracy': 0.7}`.
466 """
467 if self._should_call_test_batch_hooks:
468 self._call_batch_hook(ModeKeys.TEST, 'begin', batch, logs=logs)
470 def on_test_batch_end(self, batch, logs=None):
471 """Calls the `on_test_batch_end` methods of its callbacks.
473 Args:
474 batch: Integer, index of batch within the current epoch.
475 logs: Dict. Aggregated metric results up until this batch.
476 """
477 if self._should_call_test_batch_hooks:
478 self._call_batch_hook(ModeKeys.TEST, 'end', batch, logs=logs)
480 def on_predict_batch_begin(self, batch, logs=None):
481 """Calls the `on_predict_batch_begin` methods of its callbacks.
483 Args:
484 batch: Integer, index of batch within the current epoch.
485 logs: Dict, contains the return value of `model.predict_step`;
486 it typically returns a dict with the key 'outputs' containing
487 the model's outputs.
488 """
489 if self._should_call_predict_batch_hooks:
490 self._call_batch_hook(ModeKeys.PREDICT, 'begin', batch, logs=logs)
492 def on_predict_batch_end(self, batch, logs=None):
493 """Calls the `on_predict_batch_end` methods of its callbacks.
495 Args:
496 batch: Integer, index of batch within the current epoch.
497 logs: Dict. Aggregated metric results up until this batch.
498 """
499 if self._should_call_predict_batch_hooks:
500 self._call_batch_hook(ModeKeys.PREDICT, 'end', batch, logs=logs)
502 def on_train_begin(self, logs=None):
503 """Calls the `on_train_begin` methods of its callbacks.
505 Args:
506 logs: Dict. Currently no data is passed to this argument for this method
507 but that may change in the future.
508 """
509 logs = self._process_logs(logs)
510 for callback in self.callbacks:
511 callback.on_train_begin(logs)
513 def on_train_end(self, logs=None):
514 """Calls the `on_train_end` methods of its callbacks.
516 Args:
517 logs: Dict. Currently no data is passed to this argument for this method
518 but that may change in the future.
519 """
520 logs = self._process_logs(logs)
521 for callback in self.callbacks:
522 callback.on_train_end(logs)
524 def on_test_begin(self, logs=None):
525 """Calls the `on_test_begin` methods of its callbacks.
527 Args:
528 logs: Dict. Currently no data is passed to this argument for this method
529 but that may change in the future.
530 """
531 logs = self._process_logs(logs)
532 for callback in self.callbacks:
533 callback.on_test_begin(logs)
535 def on_test_end(self, logs=None):
536 """Calls the `on_test_end` methods of its callbacks.
538 Args:
539 logs: Dict. Currently no data is passed to this argument for this method
540 but that may change in the future.
541 """
542 logs = self._process_logs(logs)
543 for callback in self.callbacks:
544 callback.on_test_end(logs)
546 def on_predict_begin(self, logs=None):
547 """Calls the 'on_predict_begin` methods of its callbacks.
549 Args:
550 logs: Dict. Currently no data is passed to this argument for this method
551 but that may change in the future.
552 """
553 logs = self._process_logs(logs)
554 for callback in self.callbacks:
555 callback.on_predict_begin(logs)
557 def on_predict_end(self, logs=None):
558 """Calls the `on_predict_end` methods of its callbacks.
560 Args:
561 logs: Dict. Currently no data is passed to this argument for this method
562 but that may change in the future.
563 """
564 logs = self._process_logs(logs)
565 for callback in self.callbacks:
566 callback.on_predict_end(logs)
568 def __iter__(self):
569 return iter(self.callbacks)
571 def _disallow_batch_hooks_in_ps_strategy(self):
572 """Error out if batch-level callbacks are passed with PSStrategy."""
573 # pylint: disable=protected-access
574 strategy = distribute_lib.get_strategy()
575 if strategy._should_use_with_coordinator:
576 unsupported_callbacks = []
577 for cb in self.callbacks:
578 # These Callbacks can accept RemoteValues directly.
579 if getattr(cb, '_supports_tf_logs', False):
580 continue
581 if (cb._implements_train_batch_hooks() or
582 cb._implements_test_batch_hooks() or
583 cb._implements_predict_batch_hooks()):
584 unsupported_callbacks.append(cb)
585 if unsupported_callbacks:
586 raise ValueError('Batch-level `Callback`s are not supported with '
587 '`ParameterServerStrategy`. Found unsupported '
588 'callbacks: {}'.format(unsupported_callbacks))
589 # pylint: enable=protected-access
592@keras_export('keras.callbacks.Callback')
593class Callback:
594 """Abstract base class used to build new callbacks.
596 Callbacks can be passed to keras methods such as `fit`, `evaluate`, and
597 `predict` in order to hook into the various stages of the model training and
598 inference lifecycle.
600 To create a custom callback, subclass `keras.callbacks.Callback` and override
601 the method associated with the stage of interest. See
602 https://www.tensorflow.org/guide/keras/custom_callback for more information.
604 Example:
606 >>> training_finished = False
607 >>> class MyCallback(tf.keras.callbacks.Callback):
608 ... def on_train_end(self, logs=None):
609 ... global training_finished
610 ... training_finished = True
611 >>> model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
612 >>> model.compile(loss='mean_squared_error')
613 >>> model.fit(tf.constant([[1.0]]), tf.constant([[1.0]]),
614 ... callbacks=[MyCallback()])
615 >>> assert training_finished == True
617 If you want to use `Callback` objects in a custom training loop:
619 1. You should pack all your callbacks into a single `callbacks.CallbackList`
620 so they can all be called together.
621 2. You will need to manually call all the `on_*` methods at the appropriate
622 locations in your loop, like this:
624 ```
625 callbacks = tf.keras.callbacks.CallbackList([...])
626 callbacks.append(...)
628 callbacks.on_train_begin(...)
629 for epoch in range(EPOCHS):
630 callbacks.on_epoch_begin(epoch)
631 for i, data in dataset.enumerate():
632 callbacks.on_train_batch_begin(i)
633 batch_logs = model.train_step(data)
634 callbacks.on_train_batch_end(i, batch_logs)
635 epoch_logs = ...
636 callbacks.on_epoch_end(epoch, epoch_logs)
637 final_logs = ...
638 callbacks.on_train_end(final_logs)
639 ```
641 Attributes:
642 params: Dict. Training parameters
643 (e.g. verbosity, batch size, number of epochs...).
644 model: Instance of `keras.models.Model`.
645 Reference of the model being trained.
647 The `logs` dictionary that callback methods
648 take as argument will contain keys for quantities relevant to
649 the current batch or epoch (see method-specific docstrings).
650 """
652 def __init__(self):
653 self.validation_data = None # pylint: disable=g-missing-from-attributes
654 self.model = None
655 # Whether this Callback should only run on the chief worker in a
656 # Multi-Worker setting.
657 # TODO(omalleyt): Make this attr public once solution is stable.
658 self._chief_worker_only = None
659 self._supports_tf_logs = False
661 def set_params(self, params):
662 self.params = params
664 def set_model(self, model):
665 self.model = model
667 @doc_controls.for_subclass_implementers
668 @generic_utils.default
669 def on_batch_begin(self, batch, logs=None):
670 """A backwards compatibility alias for `on_train_batch_begin`."""
672 @doc_controls.for_subclass_implementers
673 @generic_utils.default
674 def on_batch_end(self, batch, logs=None):
675 """A backwards compatibility alias for `on_train_batch_end`."""
677 @doc_controls.for_subclass_implementers
678 def on_epoch_begin(self, epoch, logs=None):
679 """Called at the start of an epoch.
681 Subclasses should override for any actions to run. This function should only
682 be called during TRAIN mode.
684 Args:
685 epoch: Integer, index of epoch.
686 logs: Dict. Currently no data is passed to this argument for this method
687 but that may change in the future.
688 """
690 @doc_controls.for_subclass_implementers
691 def on_epoch_end(self, epoch, logs=None):
692 """Called at the end of an epoch.
694 Subclasses should override for any actions to run. This function should only
695 be called during TRAIN mode.
697 Args:
698 epoch: Integer, index of epoch.
699 logs: Dict, metric results for this training epoch, and for the
700 validation epoch if validation is performed. Validation result keys
701 are prefixed with `val_`. For training epoch, the values of the
702 `Model`'s metrics are returned. Example : `{'loss': 0.2, 'accuracy':
703 0.7}`.
704 """
706 @doc_controls.for_subclass_implementers
707 @generic_utils.default
708 def on_train_batch_begin(self, batch, logs=None):
709 """Called at the beginning of a training batch in `fit` methods.
711 Subclasses should override for any actions to run.
713 Note that if the `steps_per_execution` argument to `compile` in
714 `tf.keras.Model` is set to `N`, this method will only be called every `N`
715 batches.
717 Args:
718 batch: Integer, index of batch within the current epoch.
719 logs: Dict, contains the return value of `model.train_step`. Typically,
720 the values of the `Model`'s metrics are returned. Example:
721 `{'loss': 0.2, 'accuracy': 0.7}`.
722 """
723 # For backwards compatibility.
724 self.on_batch_begin(batch, logs=logs)
726 @doc_controls.for_subclass_implementers
727 @generic_utils.default
728 def on_train_batch_end(self, batch, logs=None):
729 """Called at the end of a training batch in `fit` methods.
731 Subclasses should override for any actions to run.
733 Note that if the `steps_per_execution` argument to `compile` in
734 `tf.keras.Model` is set to `N`, this method will only be called every `N`
735 batches.
737 Args:
738 batch: Integer, index of batch within the current epoch.
739 logs: Dict. Aggregated metric results up until this batch.
740 """
741 # For backwards compatibility.
742 self.on_batch_end(batch, logs=logs)
744 @doc_controls.for_subclass_implementers
745 @generic_utils.default
746 def on_test_batch_begin(self, batch, logs=None):
747 """Called at the beginning of a batch in `evaluate` methods.
749 Also called at the beginning of a validation batch in the `fit`
750 methods, if validation data is provided.
752 Subclasses should override for any actions to run.
754 Note that if the `steps_per_execution` argument to `compile` in
755 `tf.keras.Model` is set to `N`, this method will only be called every `N`
756 batches.
758 Args:
759 batch: Integer, index of batch within the current epoch.
760 logs: Dict, contains the return value of `model.test_step`. Typically,
761 the values of the `Model`'s metrics are returned. Example:
762 `{'loss': 0.2, 'accuracy': 0.7}`.
763 """
765 @doc_controls.for_subclass_implementers
766 @generic_utils.default
767 def on_test_batch_end(self, batch, logs=None):
768 """Called at the end of a batch in `evaluate` methods.
770 Also called at the end of a validation batch in the `fit`
771 methods, if validation data is provided.
773 Subclasses should override for any actions to run.
775 Note that if the `steps_per_execution` argument to `compile` in
776 `tf.keras.Model` is set to `N`, this method will only be called every `N`
777 batches.
779 Args:
780 batch: Integer, index of batch within the current epoch.
781 logs: Dict. Aggregated metric results up until this batch.
782 """
784 @doc_controls.for_subclass_implementers
785 @generic_utils.default
786 def on_predict_batch_begin(self, batch, logs=None):
787 """Called at the beginning of a batch in `predict` methods.
789 Subclasses should override for any actions to run.
791 Note that if the `steps_per_execution` argument to `compile` in
792 `tf.keras.Model` is set to `N`, this method will only be called every `N`
793 batches.
795 Args:
796 batch: Integer, index of batch within the current epoch.
797 logs: Dict, contains the return value of `model.predict_step`;
798 it typically returns a dict with the key 'outputs' containing
799 the model's outputs.
800 """
802 @doc_controls.for_subclass_implementers
803 @generic_utils.default
804 def on_predict_batch_end(self, batch, logs=None):
805 """Called at the end of a batch in `predict` methods.
807 Subclasses should override for any actions to run.
809 Note that if the `steps_per_execution` argument to `compile` in
810 `tf.keras.Model` is set to `N`, this method will only be called every `N`
811 batches.
813 Args:
814 batch: Integer, index of batch within the current epoch.
815 logs: Dict. Aggregated metric results up until this batch.
816 """
818 @doc_controls.for_subclass_implementers
819 def on_train_begin(self, logs=None):
820 """Called at the beginning of training.
822 Subclasses should override for any actions to run.
824 Args:
825 logs: Dict. Currently no data is passed to this argument for this method
826 but that may change in the future.
827 """
829 @doc_controls.for_subclass_implementers
830 def on_train_end(self, logs=None):
831 """Called at the end of training.
833 Subclasses should override for any actions to run.
835 Args:
836 logs: Dict. Currently the output of the last call to `on_epoch_end()`
837 is passed to this argument for this method but that may change in
838 the future.
839 """
841 @doc_controls.for_subclass_implementers
842 def on_test_begin(self, logs=None):
843 """Called at the beginning of evaluation or validation.
845 Subclasses should override for any actions to run.
847 Args:
848 logs: Dict. Currently no data is passed to this argument for this method
849 but that may change in the future.
850 """
852 @doc_controls.for_subclass_implementers
853 def on_test_end(self, logs=None):
854 """Called at the end of evaluation or validation.
856 Subclasses should override for any actions to run.
858 Args:
859 logs: Dict. Currently the output of the last call to
860 `on_test_batch_end()` is passed to this argument for this method
861 but that may change in the future.
862 """
864 @doc_controls.for_subclass_implementers
865 def on_predict_begin(self, logs=None):
866 """Called at the beginning of prediction.
868 Subclasses should override for any actions to run.
870 Args:
871 logs: Dict. Currently no data is passed to this argument for this method
872 but that may change in the future.
873 """
875 @doc_controls.for_subclass_implementers
876 def on_predict_end(self, logs=None):
877 """Called at the end of prediction.
879 Subclasses should override for any actions to run.
881 Args:
882 logs: Dict. Currently no data is passed to this argument for this method
883 but that may change in the future.
884 """
886 def _implements_train_batch_hooks(self):
887 """Determines if this Callback should be called for each train batch."""
888 return (not generic_utils.is_default(self.on_batch_begin) or
889 not generic_utils.is_default(self.on_batch_end) or
890 not generic_utils.is_default(self.on_train_batch_begin) or
891 not generic_utils.is_default(self.on_train_batch_end))
893 def _implements_test_batch_hooks(self):
894 """Determines if this Callback should be called for each test batch."""
895 return (not generic_utils.is_default(self.on_test_batch_begin) or
896 not generic_utils.is_default(self.on_test_batch_end))
898 def _implements_predict_batch_hooks(self):
899 """Determines if this Callback should be called for each predict batch."""
900 return (not generic_utils.is_default(self.on_predict_batch_begin) or
901 not generic_utils.is_default(self.on_predict_batch_end))
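# A short illustration of the `generic_utils.default` mechanism used above
# (the subclass below is hypothetical): the `_implements_*_batch_hooks` checks
# only report True when a batch-level method is actually overridden, so
#
#   class EpochOnlyCallback(Callback):
#     def on_epoch_end(self, epoch, logs=None):
#       print(epoch, logs)
#
# leaves all three checks False and lets `CallbackList` skip per-batch
# dispatch for it entirely.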
904@keras_export('keras.callbacks.BaseLogger')
905class BaseLogger(Callback):
906 """Callback that accumulates epoch averages of metrics.
908 This callback is automatically applied to every Keras model.
910 Args:
911 stateful_metrics: Iterable of string names of metrics that
912 should *not* be averaged over an epoch.
913 Metrics in this list will be logged as-is in `on_epoch_end`.
914 All others will be averaged in `on_epoch_end`.
915 """
917 def __init__(self, stateful_metrics=None):
918 super(BaseLogger, self).__init__()
919 self.stateful_metrics = set(stateful_metrics or [])
921 def on_epoch_begin(self, epoch, logs=None):
922 self.seen = 0
923 self.totals = {}
925 def on_batch_end(self, batch, logs=None):
926 logs = logs or {}
927 batch_size = logs.get('size', 0)
928 # In the case of a distribution strategy we can potentially run multiple
929 # steps at the same time; account for that in the `seen` calculation.
930 num_steps = logs.get('num_steps', 1)
931 self.seen += batch_size * num_steps
933 for k, v in logs.items():
934 if k in self.stateful_metrics:
935 self.totals[k] = v
936 else:
937 if k in self.totals:
938 self.totals[k] += v * batch_size
939 else:
940 self.totals[k] = v * batch_size
942 def on_epoch_end(self, epoch, logs=None):
943 if logs is not None:
944 for k in self.params['metrics']:
945 if k in self.totals:
946 # Make value available to next callbacks.
947 if k in self.stateful_metrics:
948 logs[k] = self.totals[k]
949 else:
950 logs[k] = self.totals[k] / self.seen
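# A hedged sketch of the averaging above (metric names and numbers are
# illustrative): with params['metrics'] == ['loss', 'auc'] and
# stateful_metrics={'auc'}, `loss` is accumulated weighted by batch size and
# divided by `seen` at epoch end, while `auc` is passed through as the value
# logged for the last batch. This logger is only wired up automatically by the
# TF1 `configure_callbacks` path above.
#
#   # e.g. two batches of size 32 with loss 0.5 and 0.3:
#   # totals['loss'] = 0.5*32 + 0.3*32 = 25.6; seen = 64; logs['loss'] = 0.4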
953@keras_export('keras.callbacks.TerminateOnNaN')
954class TerminateOnNaN(Callback):
955 """Callback that terminates training when a NaN loss is encountered.
956 """
958 def __init__(self):
959 super(TerminateOnNaN, self).__init__()
960 self._supports_tf_logs = True
962 def on_batch_end(self, batch, logs=None):
963 logs = logs or {}
964 loss = logs.get('loss')
965 if loss is not None:
966 loss = tf_utils.sync_to_numpy_or_python_type(loss)
967 if np.isnan(loss) or np.isinf(loss):
968 print('Batch %d: Invalid loss, terminating training' % (batch))
969 self.model.stop_training = True
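# Typical usage (a minimal sketch; `model`, `x` and `y` are assumed to exist):
# stop `fit` as soon as the loss becomes NaN or Inf.
#
#   model.fit(x, y, epochs=10, callbacks=[tf.keras.callbacks.TerminateOnNaN()])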
972@keras_export('keras.callbacks.ProgbarLogger')
973class ProgbarLogger(Callback):
974 """Callback that prints metrics to stdout.
976 Args:
977 count_mode: One of `"steps"` or `"samples"`.
978 Whether the progress bar should
979 count samples seen or steps (batches) seen.
980 stateful_metrics: Iterable of string names of metrics that
981 should *not* be averaged over an epoch.
982 Metrics in this list will be logged as-is.
983 All others will be averaged over time (e.g. loss, etc).
984 If not provided, defaults to the `Model`'s metrics.
986 Raises:
987 ValueError: In case of invalid `count_mode`.
988 """
990 def __init__(self, count_mode='samples', stateful_metrics=None):
991 super(ProgbarLogger, self).__init__()
992 self._supports_tf_logs = True
993 if count_mode == 'samples':
994 self.use_steps = False
995 elif count_mode == 'steps':
996 self.use_steps = True
997 else:
998 raise ValueError('Unknown `count_mode`: ' + str(count_mode))
999 # Defaults to all Model's metrics except for loss.
1000 self.stateful_metrics = set(stateful_metrics) if stateful_metrics else set()
1002 self.seen = 0
1003 self.progbar = None
1004 self.target = None
1005 self.verbose = 1
1006 self.epochs = 1
1008 self._train_step, self._test_step, self._predict_step = None, None, None
1009 self._call_batch_hooks = True
1011 self._called_in_fit = False
1013 def set_params(self, params):
1014 self.verbose = params['verbose']
1015 self.epochs = params['epochs']
1016 if self.use_steps and 'steps' in params:
1017 self.target = params['steps']
1018 elif not self.use_steps and 'samples' in params:
1019 self.target = params['samples']
1020 else:
1021 self.target = None # Will be inferred at the end of the first epoch.
1023 self._call_batch_hooks = self.verbose == 1
1024 if self.target is None:
1025 try:
1026 self._train_step = self.model._train_counter # pylint: disable=protected-access
1027 self._test_step = self.model._test_counter # pylint: disable=protected-access
1028 self._predict_step = self.model._predict_counter # pylint: disable=protected-access
1029 except AttributeError:
1030 self._call_batch_hooks = True
1032 def on_train_begin(self, logs=None):
1033 # When this logger is called inside `fit`, validation is silent.
1034 self._called_in_fit = True
1036 def on_test_begin(self, logs=None):
1037 if not self._called_in_fit:
1038 self._reset_progbar()
1039 self._maybe_init_progbar()
1041 def on_predict_begin(self, logs=None):
1042 self._reset_progbar()
1043 self._maybe_init_progbar()
1045 def on_epoch_begin(self, epoch, logs=None):
1046 self._reset_progbar()
1047 self._maybe_init_progbar()
1048 if self.verbose and self.epochs > 1:
1049 print('Epoch %d/%d' % (epoch + 1, self.epochs))
1051 def on_train_batch_end(self, batch, logs=None):
1052 self._batch_update_progbar(batch, logs)
1054 def on_test_batch_end(self, batch, logs=None):
1055 if not self._called_in_fit:
1056 self._batch_update_progbar(batch, logs)
1058 def on_predict_batch_end(self, batch, logs=None):
1059 # Don't pass prediction results.
1060 self._batch_update_progbar(batch, None)
1062 def on_epoch_end(self, epoch, logs=None):
1063 self._finalize_progbar(logs, self._train_step)
1065 def on_test_end(self, logs=None):
1066 if not self._called_in_fit:
1067 self._finalize_progbar(logs, self._test_step)
1069 def on_predict_end(self, logs=None):
1070 self._finalize_progbar(logs, self._predict_step)
1072 def _reset_progbar(self):
1073 self.seen = 0
1074 self.progbar = None
1076 def _maybe_init_progbar(self):
1077 """Instantiate a `Progbar` if not yet, and update the stateful metrics."""
1078 # TODO(rchao): Legacy TF1 code path may use list for
1079 # `self.stateful_metrics`. Remove "cast to set" when TF1 support is dropped.
1080 self.stateful_metrics = set(self.stateful_metrics)
1082 if self.model:
1083 # Update the existing stateful metrics as `self.model.metrics` may contain
1084 # updated metrics after `MetricsContainer` is built in the first train
1085 # step.
1086 self.stateful_metrics = self.stateful_metrics.union(
1087 set(m.name for m in self.model.metrics))
1089 if self.progbar is None:
1090 self.progbar = Progbar(
1091 target=self.target,
1092 verbose=self.verbose,
1093 stateful_metrics=self.stateful_metrics,
1094 unit_name='step' if self.use_steps else 'sample')
1096 self.progbar._update_stateful_metrics(self.stateful_metrics) # pylint: disable=protected-access
1098 def _implements_train_batch_hooks(self):
1099 return self._call_batch_hooks
1101 def _implements_test_batch_hooks(self):
1102 return self._call_batch_hooks
1104 def _implements_predict_batch_hooks(self):
1105 return self._call_batch_hooks
1107 def _batch_update_progbar(self, batch, logs=None):
1108 """Updates the progbar."""
1109 logs = logs or {}
1110 self._maybe_init_progbar()
1111 if self.use_steps:
1112 self.seen = batch + 1 # One-indexed.
1113 else:
1114 # v1 path only.
1115 logs = copy.copy(logs)
1116 batch_size = logs.pop('size', 0)
1117 num_steps = logs.pop('num_steps', 1)
1118 logs.pop('batch', None)
1119 add_seen = num_steps * batch_size
1120 self.seen += add_seen
1122 if self.verbose == 1:
1123 # Only block async when verbose = 1.
1124 logs = tf_utils.sync_to_numpy_or_python_type(logs)
1125 self.progbar.update(self.seen, list(logs.items()), finalize=False)
1127 def _finalize_progbar(self, logs, counter):
1128 logs = tf_utils.sync_to_numpy_or_python_type(logs or {})
1129 if self.target is None:
1130 if counter is not None:
1131 counter = counter.numpy()
1132 if not self.use_steps:
1133 counter *= logs.get('size', 1)
1134 self.target = counter or self.seen
1135 self.progbar.target = self.target
1136 self.progbar.update(self.target, list(logs.items()), finalize=True)
1139@keras_export('keras.callbacks.History')
1140class History(Callback):
1141 """Callback that records events into a `History` object.
1143 This callback is automatically applied to
1144 every Keras model. The `History` object
1145 gets returned by the `fit` method of models.
1147 Example:
1149 >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
1150 >>> model.compile(tf.keras.optimizers.SGD(), loss='mse')
1151 >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
1152 ... epochs=10, verbose=1)
1153 >>> print(history.params)
1154 {'verbose': 1, 'epochs': 10, 'steps': 1}
1155 >>> # check the keys of history object
1156 >>> print(history.history.keys())
1157 dict_keys(['loss'])
1159 """
1161 def __init__(self):
1162 super(History, self).__init__()
1163 self.history = {}
1165 def on_train_begin(self, logs=None):
1166 self.epoch = []
1168 def on_epoch_end(self, epoch, logs=None):
1169 logs = logs or {}
1170 self.epoch.append(epoch)
1171 for k, v in logs.items():
1172 self.history.setdefault(k, []).append(v)
1174 # Set the history attribute on the model after the epoch ends. This will
1175 # make sure that the state which is set is the latest one.
1176 self.model.history = self
1179@keras_export('keras.callbacks.ModelCheckpoint')
1180class ModelCheckpoint(Callback):
1181 """Callback to save the Keras model or model weights at some frequency.
1183 `ModelCheckpoint` callback is used in conjunction with training using
1184 `model.fit()` to save a model or weights (in a checkpoint file) at some
1185 interval, so the model or weights can be loaded later to continue the training
1186 from the state saved.
1188 A few options this callback provides include:
1190 - Whether to only keep the model that has achieved the "best performance" so
1191 far, or whether to save the model at the end of every epoch regardless of
1192 performance.
1193 - Definition of 'best'; which quantity to monitor and whether it should be
1194 maximized or minimized.
1195 - The frequency it should save at. Currently, the callback supports saving at
1196 the end of every epoch, or after a fixed number of training batches.
1197 - Whether only weights are saved, or the whole model is saved.
1199 Note: If you get `WARNING:tensorflow:Can save best model only with <name>
1200 available, skipping`, see the description of the `monitor` argument for
1201 details on how to get this right.
1203 Example:
1205 ```python
1206 model.compile(loss=..., optimizer=...,
1207 metrics=['accuracy'])
1209 EPOCHS = 10
1210 checkpoint_filepath = '/tmp/checkpoint'
1211 model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
1212 filepath=checkpoint_filepath,
1213 save_weights_only=True,
1214 monitor='val_accuracy',
1215 mode='max',
1216 save_best_only=True)
1218 # Model weights are saved at the end of every epoch, if it's the best seen
1219 # so far.
1220 model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])
1222 # The model weights (that are considered the best) are loaded into the model.
1223 model.load_weights(checkpoint_filepath)
1224 ```
1226 Args:
1227 filepath: string or `PathLike`, path to save the model file. e.g.
1228 filepath = os.path.join(working_dir, 'ckpt', file_name). `filepath`
1229 can contain named formatting options, which will be filled with the value of
1230 `epoch` and keys in `logs` (passed in `on_epoch_end`). For example: if
1231 `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`, then the model
1232 checkpoints will be saved with the epoch number and the validation loss
1233 in the filename. The directory of the filepath should not be reused by
1234 any other callbacks to avoid conflicts.
1235 monitor: The metric name to monitor. Typically the metrics are set by the
1236 `Model.compile` method. Note:
1238 * Prefix the name with `"val_"` to monitor validation metrics.
1239 * Use `"loss"` or `"val_loss"` to monitor the model's total loss.
1240 * If you specify metrics as strings, like `"accuracy"`, pass the same
1241 string (with or without the `"val_"` prefix).
1242 * If you pass `metrics.Metric` objects, `monitor` should be set to
1243 `metric.name`
1244 * If you're not sure about the metric names you can check the contents
1245 of the `history.history` dictionary returned by
1246 `history = model.fit()`
1247 * Multi-output models set additional prefixes on the metric names.
1249 verbose: verbosity mode, 0 or 1.
1250 save_best_only: if `save_best_only=True`, the model is only saved when it
1251 is considered the "best", and the latest best model according to the
1252 monitored quantity will not be overwritten. If `filepath` doesn't
1253 contain formatting options like `{epoch}`, then `filepath` will be
1254 overwritten by each new better model.
1255 mode: one of {'auto', 'min', 'max'}. If `save_best_only=True`, the
1256 decision to overwrite the current save file is made based on either
1257 the maximization or the minimization of the monitored quantity.
1258 For `val_acc`, this should be `max`, for `val_loss` this should be
1259 `min`, etc. In `auto` mode, the mode is set to `max` if the monitored
1260 quantity is 'acc' or starts with 'fmeasure', and to `min` for the rest
1261 of the quantities.
1262 save_weights_only: if True, then only the model's weights will be saved
1263 (`model.save_weights(filepath)`), else the full model is saved
1264 (`model.save(filepath)`).
1265 save_freq: `'epoch'` or integer. When using `'epoch'`, the callback saves
1266 the model after each epoch. When using an integer, the callback saves the
1267 model at the end of this many batches. If the `Model` is compiled with
1268 `steps_per_execution=N`, then the saving criteria will be
1269 checked every Nth batch. Note that if the saving isn't aligned to
1270 epochs, the monitored metric may potentially be less reliable (it
1271 could reflect as little as 1 batch, since the metrics get reset every
1272 epoch). Defaults to `'epoch'`.
1273 options: Optional `tf.train.CheckpointOptions` object if
1274 `save_weights_only` is true or optional `tf.saved_model.SaveOptions`
1275 object if `save_weights_only` is false.
1276 **kwargs: Additional arguments for backwards compatibility. Possible key
1277 is `period`.
1278 """
1280 def __init__(self,
1281 filepath,
1282 monitor='val_loss',
1283 verbose=0,
1284 save_best_only=False,
1285 save_weights_only=False,
1286 mode='auto',
1287 save_freq='epoch',
1288 options=None,
1289 **kwargs):
1290 super(ModelCheckpoint, self).__init__()
1291 self._supports_tf_logs = True
1292 self.monitor = monitor
1293 self.verbose = verbose
1294 self.filepath = path_to_string(filepath)
1295 self.save_best_only = save_best_only
1296 self.save_weights_only = save_weights_only
1297 self.save_freq = save_freq
1298 self.epochs_since_last_save = 0
1299 self._batches_seen_since_last_saving = 0
1300 self._last_batch_seen = 0
1302 if save_weights_only:
1303 if options is None or isinstance(
1304 options, checkpoint_options_lib.CheckpointOptions):
1305 self._options = options or checkpoint_options_lib.CheckpointOptions()
1306 else:
1307 raise TypeError('If save_weights_only is True, then `options` must be '
1308 'either None or a tf.train.CheckpointOptions')
1309 else:
1310 if options is None or isinstance(options, save_options_lib.SaveOptions):
1311 self._options = options or save_options_lib.SaveOptions()
1312 else:
1313 raise TypeError('If save_weights_only is False, then `options` must be '
1314 'either None or a tf.saved_model.SaveOptions')
1316 # Deprecated field `load_weights_on_restart` is for loading the checkpoint
1317 # file from `filepath` at the start of `model.fit()`
1318 # TODO(rchao): Remove the arg during next breaking release.
1319 if 'load_weights_on_restart' in kwargs:
1320 self.load_weights_on_restart = kwargs['load_weights_on_restart']
1321 logging.warning('`load_weights_on_restart` argument is deprecated. '
1322 'Please use `model.load_weights()` for loading weights '
1323 'before the start of `model.fit()`.')
1324 else:
1325 self.load_weights_on_restart = False
1327 # Deprecated field `period` is for the number of epochs between which
1328 # the model is saved.
1329 if 'period' in kwargs:
1330 self.period = kwargs['period']
1331 logging.warning('`period` argument is deprecated. Please use `save_freq` '
1332 'to specify the frequency in number of batches seen.')
1333 else:
1334 self.period = 1
1336 if mode not in ['auto', 'min', 'max']:
1337 logging.warning('ModelCheckpoint mode %s is unknown, '
1338 'fallback to auto mode.', mode)
1339 mode = 'auto'
1341 if mode == 'min':
1342 self.monitor_op = np.less
1343 self.best = np.Inf
1344 elif mode == 'max':
1345 self.monitor_op = np.greater
1346 self.best = -np.Inf
1347 else:
1348 if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
1349 self.monitor_op = np.greater
1350 self.best = -np.Inf
1351 else:
1352 self.monitor_op = np.less
1353 self.best = np.Inf
1355 if self.save_freq != 'epoch' and not isinstance(self.save_freq, int):
1356 raise ValueError('Unrecognized save_freq: {}'.format(self.save_freq))
1358 # Only the chief worker writes model checkpoints, but all workers
1359 # restore the checkpoint at `on_train_begin()`.
1360 self._chief_worker_only = False
1362 def on_train_begin(self, logs=None):
1363 if self.load_weights_on_restart:
1364 filepath_to_load = (
1365 self._get_most_recently_modified_file_matching_pattern(self.filepath))
1366 if (filepath_to_load is not None and
1367 self._checkpoint_exists(filepath_to_load)):
1368 try:
1369 # `filepath` may contain placeholders such as `{epoch:02d}`, and
1370 # thus it attempts to load the most recently modified file with file
1371 # name matching the pattern.
1372 self.model.load_weights(filepath_to_load)
1373 except (IOError, ValueError) as e:
1374 raise ValueError('Error loading file from {}. Reason: {}'.format(
1375 filepath_to_load, e))
1377 def _implements_train_batch_hooks(self):
1378 # Only call batch hooks when saving on batch
1379 return self.save_freq != 'epoch'
1381 def on_train_batch_end(self, batch, logs=None):
1382 if self._should_save_on_batch(batch):
1383 self._save_model(epoch=self._current_epoch, logs=logs)
1385 def on_epoch_begin(self, epoch, logs=None):
1386 self._current_epoch = epoch
1388 def on_epoch_end(self, epoch, logs=None):
1389 self.epochs_since_last_save += 1
1390 # pylint: disable=protected-access
1391 if self.save_freq == 'epoch':
1392 self._save_model(epoch=epoch, logs=logs)
1394 def _should_save_on_batch(self, batch):
1395 """Handles batch-level saving logic, supports steps_per_execution."""
1396 if self.save_freq == 'epoch':
1397 return False
1399 if batch <= self._last_batch_seen: # New epoch.
1400 add_batches = batch + 1 # batches are zero-indexed.
1401 else:
1402 add_batches = batch - self._last_batch_seen
1403 self._batches_seen_since_last_saving += add_batches
1404 self._last_batch_seen = batch
1406 if self._batches_seen_since_last_saving >= self.save_freq:
1407 self._batches_seen_since_last_saving = 0
1408 return True
1409 return False
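  # Worked example of the counting above (numbers are illustrative): with
  # `save_freq=100` and `steps_per_execution=25`, this hook fires once per 25
  # batches, `_batches_seen_since_last_saving` grows by about 25 per call, and
  # a checkpoint is written on the first call where the counter reaches 100,
  # i.e. roughly every 100 batches rather than every 25.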
1411 def _save_model(self, epoch, logs):
1412 """Saves the model.
1414 Args:
1415 epoch: the epoch this iteration is in.
1416 logs: the `logs` dict passed in to `on_batch_end` or `on_epoch_end`.
1417 """
1418 logs = logs or {}
1420 if isinstance(self.save_freq,
1421 int) or self.epochs_since_last_save >= self.period:
1422 # Block only when saving interval is reached.
1423 logs = tf_utils.sync_to_numpy_or_python_type(logs)
1424 self.epochs_since_last_save = 0
1425 filepath = self._get_file_path(epoch, logs)
1427 try:
1428 if self.save_best_only:
1429 current = logs.get(self.monitor)
1430 if current is None:
1431 logging.warning('Can save best model only with %s available, '
1432 'skipping.', self.monitor)
1433 else:
1434 if self.monitor_op(current, self.best):
1435 if self.verbose > 0:
1436 print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
1437 ' saving model to %s' % (epoch + 1, self.monitor,
1438 self.best, current, filepath))
1439 self.best = current
1440 if self.save_weights_only:
1441 self.model.save_weights(
1442 filepath, overwrite=True, options=self._options)
1443 else:
1444 self.model.save(filepath, overwrite=True, options=self._options)
1445 else:
1446 if self.verbose > 0:
1447 print('\nEpoch %05d: %s did not improve from %0.5f' %
1448 (epoch + 1, self.monitor, self.best))
1449 else:
1450 if self.verbose > 0:
1451 print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
1452 if self.save_weights_only:
1453 self.model.save_weights(
1454 filepath, overwrite=True, options=self._options)
1455 else:
1456 self.model.save(filepath, overwrite=True, options=self._options)
1458 self._maybe_remove_file()
1459 except IsADirectoryError as e: # h5py 3.x
1460 raise IOError('Please specify a non-directory filepath for '
1461 'ModelCheckpoint. Filepath used is an existing '
1462 'directory: {}'.format(filepath))
1463 except IOError as e: # h5py 2.x
1464 # `e.errno` appears to be `None` so checking the content of `e.args[0]`.
1465 if 'is a directory' in str(e.args[0]).lower():
1466 raise IOError('Please specify a non-directory filepath for '
1467 'ModelCheckpoint. Filepath used is an existing '
1468 'directory: {}'.format(filepath))
1469 # Re-throw the error for any other causes.
1470 raise e
1472 def _get_file_path(self, epoch, logs):
1473 """Returns the file path for checkpoint."""
1474 # pylint: disable=protected-access
1475 try:
1476 # `filepath` may contain placeholders such as `{epoch:02d}` and
1477 # `{mape:.2f}`. A mismatch between logged metrics and the path's
1478 # placeholders can cause formatting to fail.
1479 file_path = self.filepath.format(epoch=epoch + 1, **logs)
1480 except KeyError as e:
1481 raise KeyError('Failed to format this callback filepath: "{}". '
1482 'Reason: {}'.format(self.filepath, e))
1483 self._write_filepath = distributed_file_utils.write_filepath(
1484 file_path, self.model.distribute_strategy)
1485 return self._write_filepath
1487 def _maybe_remove_file(self):
1488 # Remove the checkpoint directory in multi-worker training where this worker
1489 # should not checkpoint. It is a dummy directory previously saved for sync
1490 # distributed training.
1491 distributed_file_utils.remove_temp_dir_with_filepath(
1492 self._write_filepath, self.model.distribute_strategy)
1494 def _checkpoint_exists(self, filepath):
1495 """Returns whether the checkpoint `filepath` refers to exists."""
1496 if filepath.endswith('.h5'):
1497 return file_io.file_exists_v2(filepath)
1498 tf_saved_model_exists = file_io.file_exists_v2(filepath)
1499 tf_weights_only_checkpoint_exists = file_io.file_exists_v2(
1500 filepath + '.index')
1501 return tf_saved_model_exists or tf_weights_only_checkpoint_exists
1503 def _get_most_recently_modified_file_matching_pattern(self, pattern):
1504 """Returns the most recently modified filepath matching pattern.
1506 The pattern may contain a python formatting placeholder. If
1507 `tf.train.latest_checkpoint()` does not return None, use that; otherwise,
1508 check for the most recently modified one that matches the pattern.
1510 In the rare case where more than one pattern-matching file shares the same
1511 most recent modified time, return the filepath that is largest (by the `>`
1512 operator, lexicographically using the numeric equivalents). This provides a
1513 tie-breaker when multiple files are most recent. Note that a larger
1514 `filepath` can sometimes indicate a later time of modification (for
1515 instance, when epoch/batch is used as a formatting option), but not
1516 necessarily (when accuracy or loss is used). The tie-breaker is put in the
1517 logic as a best effort to return the most recent file and to avoid a
1518 nondeterministic result.
1520 Modified time of a file is obtained with `os.path.getmtime()`.
1522 This utility function is best demonstrated via an example:
1524 ```python
1525 file_pattern = 'f.batch{batch:02d}epoch{epoch:02d}.h5'
1526 test_dir = self.get_temp_dir()
1527 path_pattern = os.path.join(test_dir, file_pattern)
1528 file_paths = [
1529 os.path.join(test_dir, file_name) for file_name in
1530 ['f.batch03epoch02.h5', 'f.batch02epoch02.h5', 'f.batch01epoch01.h5']
1531 ]
1532 for file_path in file_paths:
1533   open(file_path, 'w').close()  # Write something to each of the files
1534 self.assertEqual(
1535 _get_most_recently_modified_file_matching_pattern(path_pattern),
1536 file_paths[-1])
1537 ```
1539 Args:
1540 pattern: The file pattern that may optionally contain python placeholder
1541 such as `{epoch:02d}`.
1543 Returns:
1544 The most recently modified file's full filepath matching `pattern`. If
1545 `pattern` does not contain any placeholder, this returns the filepath that
1546 exactly matches `pattern`.
1547 Returns `None` if no match is found.
1548 """
1549 dir_name = os.path.dirname(pattern)
1550 base_name = os.path.basename(pattern)
1551 base_name_regex = '^' + re.sub(r'{.*}', r'.*', base_name) + '$'
1553 # If tf.train.latest_checkpoint tells us there exists a latest checkpoint,
1554 # use that as it is more robust than `os.path.getmtime()`.
1555 latest_tf_checkpoint = checkpoint_management.latest_checkpoint(dir_name)
1556 if latest_tf_checkpoint is not None and re.match(
1557 base_name_regex, os.path.basename(latest_tf_checkpoint)):
1558 return latest_tf_checkpoint
1560 latest_mod_time = 0
1561 file_path_with_latest_mod_time = None
1562 n_file_with_latest_mod_time = 0
1563 file_path_with_largest_file_name = None
1565 if file_io.file_exists_v2(dir_name):
1566 for file_name in os.listdir(dir_name):
1567 # Only consider if `file_name` matches the pattern.
1568 if re.match(base_name_regex, file_name):
1569 file_path = os.path.join(dir_name, file_name)
1570 mod_time = os.path.getmtime(file_path)
1571 if (file_path_with_largest_file_name is None or
1572 file_path > file_path_with_largest_file_name):
1573 file_path_with_largest_file_name = file_path
1574 if mod_time > latest_mod_time:
1575 latest_mod_time = mod_time
1576 file_path_with_latest_mod_time = file_path
1577 # In the case a file with later modified time is found, reset
1578 # the counter for the number of files with latest modified time.
1579 n_file_with_latest_mod_time = 1
1580 elif mod_time == latest_mod_time:
1581 # In the case a file has modified time tied with the most recent,
1582 # increment the counter for the number of files with latest modified
1583 # time by 1.
1584 n_file_with_latest_mod_time += 1
1586 if n_file_with_latest_mod_time == 1:
1587 # Return the sole file that has most recent modified time.
1588 return file_path_with_latest_mod_time
1589 else:
1590 # If there are more than one file having latest modified time, return
1591 # the file path with the largest file name.
1592 return file_path_with_largest_file_name
1595@keras_export('keras.callbacks.experimental.BackupAndRestore', v1=[])
1596class BackupAndRestore(Callback):
1597 """Callback to back up and restore the training state.
1599 The `BackupAndRestore` callback is intended to recover from interruptions
1600 that happen in the middle of a `model.fit` execution, by backing up the
1601 training state in a temporary checkpoint file (based on TF CheckpointManager)
1602 at the end of each epoch. If training is restarted before completion, the
1603 training state and model are restored to the most recently saved state at the
1604 beginning of a new `model.fit()` run.
1605 Note that the user is responsible for bringing jobs back up.
1606 This callback is important for the backup and restore mechanism used for
1607 fault tolerance. The model to be restored from a previous checkpoint is
1608 expected to be the same as the one used to back it up. If the user changes
1609 arguments passed to compile or fit, the checkpoint saved for fault tolerance
1610 can become invalid.
1612 Note:
1613 1. This callback is not compatible with disabling eager execution.
1614 2. A checkpoint is saved at the end of each epoch. When restoring, we redo
1615 any partial work from the unfinished epoch in which training was interrupted
1616 (so the work done before an interruption doesn't affect the final model state).
1617 3. This works for both single-worker and multi-worker modes; only
1618 MirroredStrategy and MultiWorkerMirroredStrategy are supported for now.
1620 Example:
1622 >>> class InterruptingCallback(tf.keras.callbacks.Callback):
1623 ... def on_epoch_begin(self, epoch, logs=None):
1624 ... if epoch == 4:
1625 ... raise RuntimeError('Interrupting!')
1626 >>> callback = tf.keras.callbacks.experimental.BackupAndRestore(
1627 ... backup_dir="/tmp/backup")
1628 >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
1629 >>> model.compile(tf.keras.optimizers.SGD(), loss='mse')
1630 >>> try:
1631 ... model.fit(np.arange(100).reshape(5, 20), np.zeros(5), epochs=10,
1632 ... batch_size=1, callbacks=[callback, InterruptingCallback()],
1633 ... verbose=0)
1634 ... except:
1635 ... pass
1636 >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5), epochs=10,
1637 ... batch_size=1, callbacks=[callback], verbose=0)
1638 >>> # Only 6 more epochs are run, since the first training got interrupted at
1639 >>> # zero-indexed epoch 4; the second training will continue from 4 to 9.
1640 >>> len(history.history['loss'])
1641 6
1643 Args:
1644 backup_dir: String, path to store the checkpoint.
1645 e.g. backup_dir = os.path.join(working_dir, 'backup')
1646 This is the directory in which the system stores temporary files to
1647 recover the model from jobs terminated unexpectedly. The directory
1648 cannot be reused elsewhere to store other files, e.g. by the
1649 BackupAndRestore callback of another training run, or by another
1650 callback (such as ModelCheckpoint) of the same training run.
1651 """
1653 def __init__(self, backup_dir):
1654 super(BackupAndRestore, self).__init__()
1655 self.backup_dir = backup_dir
1656 self._supports_tf_logs = True
1657 self._supported_strategies = (
1658 mirrored_strategy.MirroredStrategy,
1659 collective_all_reduce_strategy.CollectiveAllReduceStrategy,
1660 tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV2,
1661 parameter_server_strategy_v2.ParameterServerStrategyV2)
1663 if not context.executing_eagerly():
1664 if ops.inside_function():
1665 raise ValueError('This Callback\'s method contains Python state and '
1666 'should be called outside of `tf.function`s.')
1667 else: # Legacy graph mode:
1668 raise ValueError(
1669 'BackupAndRestore only supports eager mode. In graph '
1670 'mode, consider using ModelCheckpoint to manually save '
1671 'and restore weights with `model.load_weights()` and by '
1672 'providing `initial_epoch` in `model.fit()` for fault tolerance.')
1674 # Only the chief worker writes model checkpoints, but all workers
1675 # restore checkpoint at on_train_begin().
1676 self._chief_worker_only = False
1678 def on_train_begin(self, logs=None):
1679 # TrainingState is used to manage the training state needed for
1680 # failure-recovery of a worker in training.
1681 # pylint: disable=protected-access
1683 if self.model._distribution_strategy and not isinstance(
1684 self.model.distribute_strategy, self._supported_strategies):
1685 raise NotImplementedError(
1686 '%s is not supported yet. '
1687 'Currently BackupAndRestore callback only supports empty strategy, '
1688 'MirroredStrategy, MultiWorkerMirroredStrategy and TPUStrategy.' %
1689 type(self.model.distribute_strategy).__name__)
1690 self.model._training_state = (
1691 worker_training_state.WorkerTrainingState(self.model, self.backup_dir))
1692 self._training_state = self.model._training_state
1693 self._training_state.restore()
1695 def on_train_end(self, logs=None):
1696 # pylint: disable=protected-access
1697 # On exit of training, delete the training state backup file that was saved
1698 # for the purpose of worker recovery.
1699 self._training_state.delete_backup()
1701 # Clean up the training state.
1702 del self._training_state
1703 del self.model._training_state
1705 def on_epoch_end(self, epoch, logs=None):
1706 # Back up the model and current epoch for possible future recovery.
1707 self._training_state.back_up(epoch)
1710@keras_export('keras.callbacks.EarlyStopping')
1711class EarlyStopping(Callback):
1712 """Stop training when a monitored metric has stopped improving.
1714 Assume the goal of training is to minimize the loss. With this, the
1715 metric to be monitored would be `'loss'`, and the mode would be `'min'`. A
1716 `model.fit()` training loop will check at the end of every epoch whether
1717 the loss is no longer decreasing, considering the `min_delta` and
1718 `patience` if applicable. Once the loss is found to be no longer decreasing,
1719 `model.stop_training` is marked True and training terminates.
1721 The quantity to be monitored needs to be available in `logs` dict.
1722 To make it so, pass the loss or metrics at `model.compile()`.
1724 Args:
1725 monitor: Quantity to be monitored.
1726 min_delta: Minimum change in the monitored quantity
1727 to qualify as an improvement, i.e. an absolute
1728 change of less than min_delta will count as no
1729 improvement.
1730 patience: Number of epochs with no improvement
1731 after which training will be stopped.
1732 verbose: verbosity mode.
1733 mode: One of `{"auto", "min", "max"}`. In `min` mode,
1734 training will stop when the quantity
1735 monitored has stopped decreasing; in `"max"`
1736 mode it will stop when the quantity
1737 monitored has stopped increasing; in `"auto"`
1738 mode, the direction is automatically inferred
1739 from the name of the monitored quantity.
1740 baseline: Baseline value for the monitored quantity.
1741 Training will stop if the model doesn't show improvement over the
1742 baseline.
1743 restore_best_weights: Whether to restore model weights from
1744 the epoch with the best value of the monitored quantity.
1745 If False, the model weights obtained at the last step of
1746 training are used. An epoch will be restored regardless
1747 of the performance relative to the `baseline`. If no epoch
1748 improves on `baseline`, training will run for `patience`
1749 epochs and restore weights from the best epoch in that set.
1751 Example:
1753 >>> callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
1754 >>> # This callback will stop the training when there is no improvement in
1755 >>> # the loss for three consecutive epochs.
1756 >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
1757 >>> model.compile(tf.keras.optimizers.SGD(), loss='mse')
1758 >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
1759 ... epochs=10, batch_size=1, callbacks=[callback],
1760 ... verbose=0)
1761 >>> len(history.history['loss']) # Only 4 epochs are run.
1762 4
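
  A minimal sketch of `restore_best_weights` (assuming training and validation
  arrays `x_train`/`y_train` and `x_val`/`y_val` exist); when training stops,
  the weights from the epoch with the best `val_loss` are restored:

  ```python
  callback = tf.keras.callbacks.EarlyStopping(
      monitor='val_loss', patience=2, restore_best_weights=True)
  model.fit(x_train, y_train, validation_data=(x_val, y_val),
            epochs=20, callbacks=[callback])
  ```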
1763 """
1765 def __init__(self,
1766 monitor='val_loss',
1767 min_delta=0,
1768 patience=0,
1769 verbose=0,
1770 mode='auto',
1771 baseline=None,
1772 restore_best_weights=False):
1773 super(EarlyStopping, self).__init__()
1775 self.monitor = monitor
1776 self.patience = patience
1777 self.verbose = verbose
1778 self.baseline = baseline
1779 self.min_delta = abs(min_delta)
1780 self.wait = 0
1781 self.stopped_epoch = 0
1782 self.restore_best_weights = restore_best_weights
1783 self.best_weights = None
1785 if mode not in ['auto', 'min', 'max']:
1786 logging.warning('EarlyStopping mode %s is unknown, '
1787 'fallback to auto mode.', mode)
1788 mode = 'auto'
1790 if mode == 'min':
1791 self.monitor_op = np.less
1792 elif mode == 'max':
1793 self.monitor_op = np.greater
1794 else:
1795 if 'acc' in self.monitor:
1796 self.monitor_op = np.greater
1797 else:
1798 self.monitor_op = np.less
1800 if self.monitor_op == np.greater:
1801 self.min_delta *= 1
1802 else:
1803 self.min_delta *= -1
1805 def on_train_begin(self, logs=None):
1806 # Allow instances to be re-used
1807 self.wait = 0
1808 self.stopped_epoch = 0
1809 self.best = np.Inf if self.monitor_op == np.less else -np.Inf
1810 self.best_weights = None
1812 def on_epoch_end(self, epoch, logs=None):
1813 current = self.get_monitor_value(logs)
1814 if current is None:
1815 return
1816 if self.restore_best_weights and self.best_weights is None:
1817 # Restore the weights after first epoch if no progress is ever made.
1818 self.best_weights = self.model.get_weights()
1820 self.wait += 1
1821 if self._is_improvement(current, self.best):
1822 self.best = current
1823 if self.restore_best_weights:
1824 self.best_weights = self.model.get_weights()
1825 # Only restart wait if we beat both the baseline and our previous best.
1826 if self.baseline is None or self._is_improvement(current, self.baseline):
1827 self.wait = 0
1829 if self.wait >= self.patience:
1830 self.stopped_epoch = epoch
1831 self.model.stop_training = True
1832 if self.restore_best_weights and self.best_weights is not None:
1833 if self.verbose > 0:
1834 print('Restoring model weights from the end of the best epoch.')
1835 self.model.set_weights(self.best_weights)
1837 def on_train_end(self, logs=None):
1838 if self.stopped_epoch > 0 and self.verbose > 0:
1839 print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))
1841 def get_monitor_value(self, logs):
1842 logs = logs or {}
1843 monitor_value = logs.get(self.monitor)
1844 if monitor_value is None:
1845 logging.warning('Early stopping conditioned on metric `%s` '
1846 'which is not available. Available metrics are: %s',
1847 self.monitor, ','.join(list(logs.keys())))
1848 return monitor_value
1850 def _is_improvement(self, monitor_value, reference_value):
1851 return self.monitor_op(monitor_value - self.min_delta, reference_value)
1854@keras_export('keras.callbacks.RemoteMonitor')
1855class RemoteMonitor(Callback):
1856 """Callback used to stream events to a server.
1858 Requires the `requests` library.
1859 Events are sent to `root + '/publish/epoch/end/'` by default. Calls are
1860 HTTP POST, with a `data` argument which is a
1861 JSON-encoded dictionary of event data.
1862 If `send_as_json=True`, the content type of the request will be
1863 `"application/json"`.
1864 Otherwise the serialized JSON will be sent within a form.
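
  Example (a minimal usage sketch; the default `root` URL shown above is a
  placeholder rather than a real server, and `x_train`/`y_train` are assumed
  to exist):

  ```python
  monitor = tf.keras.callbacks.RemoteMonitor(
      root='http://localhost:9000', send_as_json=True)
  model.fit(x_train, y_train, epochs=2, callbacks=[monitor])
  ```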
1866 Args:
1867 root: String; root URL of the target server.
1868 path: String; path relative to `root` to which the events will be sent.
1869 field: String; JSON field under which the data will be stored.
1870 The field is used only if the payload is sent within a form
1871 (i.e. send_as_json is set to False).
1872 headers: Dictionary; optional custom HTTP headers.
1873 send_as_json: Boolean; whether the request should be
1874 sent as `"application/json"`.
1875 """
1877 def __init__(self,
1878 root='http://localhost:9000',
1879 path='/publish/epoch/end/',
1880 field='data',
1881 headers=None,
1882 send_as_json=False):
1883 super(RemoteMonitor, self).__init__()
1885 self.root = root
1886 self.path = path
1887 self.field = field
1888 self.headers = headers
1889 self.send_as_json = send_as_json
1891 def on_epoch_end(self, epoch, logs=None):
1892 if requests is None:
1893 raise ImportError('RemoteMonitor requires the `requests` library.')
1894 logs = logs or {}
1895 send = {}
1896 send['epoch'] = epoch
1897 for k, v in logs.items():
1898 # np.ndarray and np.generic are not scalar types
1899 # therefore we must unwrap their scalar values and
1900 # pass to the json-serializable dict 'send'
1901 if isinstance(v, (np.ndarray, np.generic)):
1902 send[k] = v.item()
1903 else:
1904 send[k] = v
1905 try:
1906 if self.send_as_json:
1907 requests.post(self.root + self.path, json=send, headers=self.headers)
1908 else:
1909 requests.post(
1910 self.root + self.path, {self.field: json.dumps(send)},
1911 headers=self.headers)
1912 except requests.exceptions.RequestException:
1913 logging.warning('Warning: could not reach RemoteMonitor '
1914 'root server at ' + str(self.root))
1917@keras_export('keras.callbacks.LearningRateScheduler')
1918class LearningRateScheduler(Callback):
1919 """Learning rate scheduler.
1921 At the beginning of every epoch, this callback gets the updated learning rate
1922 value from the `schedule` function provided at `__init__`, with the current
1923 epoch and current learning rate as inputs, and applies the updated learning
1924 rate to the optimizer.
1926 Args:
1927 schedule: a function that takes an epoch index (integer, indexed from 0)
1928 and current learning rate (float) as inputs and returns a new
1929 learning rate as output (float).
1930 verbose: int. 0: quiet, 1: update messages.
1932 Example:
1934 >>> # This function keeps the initial learning rate for the first ten epochs
1935 >>> # and decreases it exponentially after that.
1936 >>> def scheduler(epoch, lr):
1937 ... if epoch < 10:
1938 ... return lr
1939 ... else:
1940 ... return lr * tf.math.exp(-0.1)
1941 >>>
1942 >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
1943 >>> model.compile(tf.keras.optimizers.SGD(), loss='mse')
1944 >>> round(model.optimizer.lr.numpy(), 5)
1945 0.01
1947 >>> callback = tf.keras.callbacks.LearningRateScheduler(scheduler)
1948 >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
1949 ... epochs=15, callbacks=[callback], verbose=0)
1950 >>> round(model.optimizer.lr.numpy(), 5)
1951 0.00607
1953 """
1955 def __init__(self, schedule, verbose=0):
1956 super(LearningRateScheduler, self).__init__()
1957 self.schedule = schedule
1958 self.verbose = verbose
1960 def on_epoch_begin(self, epoch, logs=None):
1961 if not hasattr(self.model.optimizer, 'lr'):
1962 raise ValueError('Optimizer must have a "lr" attribute.')
1963 try: # new API
1964 lr = float(backend.get_value(self.model.optimizer.lr))
1965 lr = self.schedule(epoch, lr)
1966 except TypeError: # Support for old API for backward compatibility
1967 lr = self.schedule(epoch)
1968 if not isinstance(lr, (ops.Tensor, float, np.float32, np.float64)):
1969 raise ValueError('The output of the "schedule" function '
1970 'should be float.')
1971 if isinstance(lr, ops.Tensor) and not lr.dtype.is_floating:
1972 raise ValueError('The dtype of Tensor should be float')
1973 backend.set_value(self.model.optimizer.lr, backend.get_value(lr))
1974 if self.verbose > 0:
1975 print('\nEpoch %05d: LearningRateScheduler setting learning '
1976 'rate to %s.' % (epoch + 1, lr))
1978 def on_epoch_end(self, epoch, logs=None):
1979 logs = logs or {}
1980 logs['lr'] = backend.get_value(self.model.optimizer.lr)
1983def keras_model_summary(name, data, step=None):
1984 """Writes a Keras model as JSON to as a Summary.
1986 Writing the Keras model configuration allows the TensorBoard graph plugin to
1987 render a conceptual graph, as opposed to graph of ops. In case the model fails
1988 to serialize as JSON, it ignores and returns False.
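
  A minimal usage sketch (assuming a default writer created via
  `tf.summary.create_file_writer`; the log directory path is a placeholder):

  ```python
  writer = tf.summary.create_file_writer('/tmp/logs')
  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
  with writer.as_default():
    keras_model_summary('my_keras_model', model, step=0)
  ```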
1990 Args:
1991 name: A name for this summary. The summary tag used for TensorBoard will be
1992 this name prefixed by any active name scopes.
1993 data: A Keras Model to write.
1994 step: Explicit `int64`-castable monotonic step value for this summary. If
1995 omitted, this defaults to `tf.summary.experimental.get_step()`, which must
1996 not be None.
1998 Returns:
1999 True on success, or False if no summary was written because no default
2000 summary writer was available.
2002 Raises:
2003 ValueError: if a default writer exists, but no step was provided and
2004 `tf.summary.experimental.get_step()` is None.
2005 """
2006 summary_metadata = summary_pb2.SummaryMetadata()
2007 # Hard coding a plugin name. Please refer to go/tb-plugin-name-hardcode for
2008 # the rationale.
2009 summary_metadata.plugin_data.plugin_name = 'graph_keras_model'
2010 # version number = 1
2011 summary_metadata.plugin_data.content = b'1'
2013 try:
2014 json_string = data.to_json()
2015 except Exception as exc: # pylint: disable=broad-except
2016 # An exception should not break the user's model code.
2017 logging.warning('Model failed to serialize as JSON. Ignoring... %s', exc)
2018 return False
2020 with summary_ops_v2.summary_scope(name, 'graph_keras_model',
2021 [data, step]) as (tag, _):
2022 with ops.device('cpu:0'):
2023 tensor = constant_op.constant(json_string, dtype=dtypes.string)
2024 return summary_ops_v2.write(
2025 tag=tag, tensor=tensor, step=step, metadata=summary_metadata)
2028@keras_export('keras.callbacks.TensorBoard', v1=[])
2029class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
2030 # pylint: disable=line-too-long
2031 """Enable visualizations for TensorBoard.
2033 TensorBoard is a visualization tool provided with TensorFlow.
2035 This callback logs events for TensorBoard, including:
2037 * Metrics summary plots
2038 * Training graph visualization
2039 * Activation histograms
2040 * Sampled profiling
2042 When used in `Model.evaluate`, in addition to epoch summaries, a summary
2043 that records evaluation metrics vs. `Model.optimizer.iterations` is
2044 written. The metric names are prepended with `evaluation`, with
2045 `Model.optimizer.iterations` being the step in the visualized TensorBoard.
2047 If you have installed TensorFlow with pip, you should be able
2048 to launch TensorBoard from the command line:
2050 ```
2051 tensorboard --logdir=path_to_your_logs
2052 ```
2054 You can find more information about TensorBoard
2055 [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard).
2057 Args:
2058 log_dir: the path of the directory where to save the log files to be
2059 parsed by TensorBoard. e.g. log_dir = os.path.join(working_dir, 'logs')
2060 This directory should not be reused by any other callbacks.
2061 histogram_freq: frequency (in epochs) at which to compute activation and
2062 weight histograms for the layers of the model. If set to 0, histograms
2063 won't be computed. Validation data (or split) must be specified for
2064 histogram visualizations.
2065 write_graph: whether to visualize the graph in TensorBoard. The log file
2066 can become quite large when write_graph is set to True.
2067 write_images: whether to write model weights to visualize as image in
2068 TensorBoard.
2069 write_steps_per_second: whether to log the training steps per second into
2070 TensorBoard. This supports both epoch and batch frequency logging.
2071 update_freq: `'batch'` or `'epoch'` or integer. When using `'batch'`,
2072 writes the losses and metrics to TensorBoard after each batch. The same
2073 applies for `'epoch'`. If using an integer, let's say `1000`, the
2074 callback will write the metrics and losses to TensorBoard every 1000
2075 batches. Note that writing too frequently to TensorBoard can slow down
2076 your training.
2077 profile_batch: Profile the batch(es) to sample compute characteristics.
2078 profile_batch must be a non-negative integer or a tuple of integers.
2079 A pair of positive integers signifies a range of batches to profile.
2080 By default, it will profile the second batch. Set profile_batch=0
2081 to disable profiling.
2082 embeddings_freq: frequency (in epochs) at which embedding layers will be
2083 visualized. If set to 0, embeddings won't be visualized.
2084 embeddings_metadata: Dictionary which maps embedding layer names to the
2085 filename of a file in which to save metadata for the embedding layer.
2086 In case the same metadata file is to be
2087 used for all embedding layers, a single filename can be passed.
2089 Examples:
2091 Basic usage:
2093 ```python
2094 tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs")
2095 model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
2096 # Then run the tensorboard command to view the visualizations.
2097 ```
2099 Custom batch-level summaries in a subclassed Model:
2101 ```python
2102 class MyModel(tf.keras.Model):
2104 def build(self, _):
2105 self.dense = tf.keras.layers.Dense(10)
2107 def call(self, x):
2108 outputs = self.dense(x)
2109 tf.summary.histogram('outputs', outputs)
2110 return outputs
2112 model = MyModel()
2113 model.compile('sgd', 'mse')
2115 # Make sure to set `update_freq=N` to log a batch-level summary every N batches.
2116 # In addition to any `tf.summary` contained in `Model.call`, metrics added in
2117 # `Model.compile` will be logged every N batches.
2118 tb_callback = tf.keras.callbacks.TensorBoard('./logs', update_freq=1)
2119 model.fit(x_train, y_train, callbacks=[tb_callback])
2120 ```
2122 Custom batch-level summaries in a Functional API Model:
2124 ```python
2125 def my_summary(x):
2126 tf.summary.histogram('x', x)
2127 return x
2129 inputs = tf.keras.Input(10)
2130 x = tf.keras.layers.Dense(10)(inputs)
2131 outputs = tf.keras.layers.Lambda(my_summary)(x)
2132 model = tf.keras.Model(inputs, outputs)
2133 model.compile('sgd', 'mse')
2135 # Make sure to set `update_freq=N` to log a batch-level summary every N batches.
2136 # In addition to any `tf.summary` contained in `Model.call`, metrics added in
2137 # `Model.compile` will be logged every N batches.
2138 tb_callback = tf.keras.callbacks.TensorBoard('./logs', update_freq=1)
2139 model.fit(x_train, y_train, callbacks=[tb_callback])
2140 ```
2142 Profiling:
2144 ```python
2145 # Profile a single batch, e.g. the 5th batch.
2146 tensorboard_callback = tf.keras.callbacks.TensorBoard(
2147 log_dir='./logs', profile_batch=5)
2148 model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
2150 # Profile a range of batches, e.g. from 10 to 20.
2151 tensorboard_callback = tf.keras.callbacks.TensorBoard(
2152 log_dir='./logs', profile_batch=(10,20))
2153 model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
2154 ```
2155 """
2157 # pylint: enable=line-too-long
2159 def __init__(self,
2160 log_dir='logs',
2161 histogram_freq=0,
2162 write_graph=True,
2163 write_images=False,
2164 write_steps_per_second=False,
2165 update_freq='epoch',
2166 profile_batch=2,
2167 embeddings_freq=0,
2168 embeddings_metadata=None,
2169 **kwargs):
2170 super(TensorBoard, self).__init__()
2171 self._supports_tf_logs = True
2172 self._validate_kwargs(kwargs)
2174 self.log_dir = path_to_string(log_dir)
2175 self.histogram_freq = histogram_freq
2176 self.write_graph = write_graph
2177 self.write_images = write_images
2178 self.write_steps_per_second = write_steps_per_second
2179 self.update_freq = 1 if update_freq == 'batch' else update_freq
2180 self.embeddings_freq = embeddings_freq
2181 self.embeddings_metadata = embeddings_metadata
2182 self._init_profile_batch(profile_batch)
2183 self._global_train_batch = 0
2184 self._previous_epoch_iterations = 0
2185 self._train_accumulated_time = 0
2186 self._batch_start_time = 0
2188 # Lazily initialized in order to avoid creating event files when
2189 # not needed.
2190 self._writers = {}
2192 # Used to restore any existing `SummaryWriter` after training ends.
2193 self._prev_summary_state = []
2195 def _validate_kwargs(self, kwargs):
2196 """Handle arguments were supported in V1."""
2197 if kwargs.get('write_grads', False):
2198 logging.warning('`write_grads` will be ignored in TensorFlow 2.0 '
2199 'for the `TensorBoard` Callback.')
2200 if kwargs.get('batch_size', False):
2201 logging.warning('`batch_size` is no longer needed in the '
2202 '`TensorBoard` Callback and will be ignored '
2203 'in TensorFlow 2.0.')
2204 if kwargs.get('embeddings_layer_names', False):
2205 logging.warning('`embeddings_layer_names` is not supported in '
2206 'TensorFlow 2.0. Instead, all `Embedding` layers '
2207 'will be visualized.')
2208 if kwargs.get('embeddings_data', False):
2209 logging.warning('`embeddings_data` is not supported in TensorFlow '
2210 '2.0. Instead, all `Embedding` variables will be '
2211 'visualized.')
2213 unrecognized_kwargs = set(kwargs.keys()) - {
2214 'write_grads', 'embeddings_layer_names', 'embeddings_data', 'batch_size'
2215 }
2217 # Only allow kwargs that were supported in V1.
2218 if unrecognized_kwargs:
2219 raise ValueError('Unrecognized arguments in `TensorBoard` '
2220 'Callback: ' + str(unrecognized_kwargs))
2222 def set_model(self, model):
2223 """Sets Keras model and writes graph if specified."""
2224 self.model = model
2225 self._log_write_dir = self._get_log_write_dir()
2227 self._train_dir = os.path.join(self._log_write_dir, 'train')
2228 self._train_step = self.model._train_counter # pylint: disable=protected-access
2230 self._val_dir = os.path.join(self._log_write_dir, 'validation')
2231 self._val_step = self.model._test_counter # pylint: disable=protected-access
2233 self._writers = {} # Resets writers.
2235 self._should_write_train_graph = False
2236 if self.write_graph:
2237 self._write_keras_model_summary()
2238 self._should_write_train_graph = True
2239 if self.embeddings_freq:
2240 self._configure_embeddings()
2242 @property
2243 def _train_writer(self):
2244 if 'train' not in self._writers:
2245 self._writers['train'] = summary_ops_v2.create_file_writer_v2(
2246 self._train_dir)
2247 return self._writers['train']
2249 @property
2250 def _val_writer(self):
2251 if 'val' not in self._writers:
2252 self._writers['val'] = summary_ops_v2.create_file_writer_v2(self._val_dir)
2253 return self._writers['val']
2255 def _get_log_write_dir(self):
2256 """For multi-worker, only chief should write, others write to '/tmp'."""
2257 return distributed_file_utils.write_dirpath(self.log_dir,
2258 self.model.distribute_strategy)
2260 def _delete_tmp_write_dir(self):
2261 """Deletes tmp write directories for multi-worker."""
2262 distributed_file_utils.remove_temp_dirpath(self.log_dir,
2263 self.model.distribute_strategy)
2265 def _write_keras_model_train_graph(self):
2266 """Writes Keras model train_function graph to TensorBoard."""
2267 with self._train_writer.as_default():
2268 with summary_ops_v2.record_if(True):
2269 train_fn = self.model.train_tf_function
2270 # If the train_function is a `tf.function`, we can write out a graph
2271 if hasattr(train_fn, 'function_spec'):
2272 summary_ops_v2.graph(train_fn._concrete_stateful_fn.graph) # pylint: disable=protected-access
2274 def _write_keras_model_summary(self):
2275 """Writes Keras graph network summary to TensorBoard."""
2276 with self._train_writer.as_default():
2277 with summary_ops_v2.record_if(True):
2278 summary_writable = (
2279 self.model._is_graph_network or # pylint: disable=protected-access
2280 self.model.__class__.__name__ == 'Sequential') # pylint: disable=protected-access
2281 if summary_writable:
2282 keras_model_summary('keras', self.model, step=0)
2284 def _configure_embeddings(self):
2285 """Configure the Projector for embeddings."""
2286 # TODO(omalleyt): Add integration tests.
2287 from google.protobuf import text_format
2288 from tensorflow.python.keras.layers import embeddings
2289 from tensorflow.python.keras.protobuf import projector_config_pb2
2291 config = projector_config_pb2.ProjectorConfig()
2292 for layer in self.model.layers:
2293 if isinstance(layer, embeddings.Embedding):
2294 embedding = config.embeddings.add()
2295 # Embeddings are always the first layer, so this naming should be
2296 # consistent in any Keras model's checkpoints.
2297 name = 'layer_with_weights-0/embeddings/.ATTRIBUTES/VARIABLE_VALUE'
2298 embedding.tensor_name = name
2300 if self.embeddings_metadata is not None:
2301 if isinstance(self.embeddings_metadata, str):
2302 embedding.metadata_path = self.embeddings_metadata
2303 else:
2304 if layer.name in self.embeddings_metadata.keys():
2305 embedding.metadata_path = self.embeddings_metadata.pop(layer.name)
2307 if self.embeddings_metadata and not isinstance(self.embeddings_metadata,
2308 str):
2309 raise ValueError('Unrecognized `Embedding` layer names passed to '
2310 '`keras.callbacks.TensorBoard` `embeddings_metadata` '
2311 'argument: ' + str(self.embeddings_metadata.keys()))
2313 config_pbtxt = text_format.MessageToString(config)
2314 path = os.path.join(self._log_write_dir, 'projector_config.pbtxt')
2315 with gfile.Open(path, 'w') as f:
2316 f.write(config_pbtxt)
2318 def _push_writer(self, writer, step):
2319 """Sets the default writer for custom batch-level summaries."""
2320 if self.update_freq == 'epoch':
2321 return
2323 should_record = lambda: math_ops.equal(step % self.update_freq, 0)
2324 # TODO(b/151339474): Fix deadlock when not using .value() here.
2325 summary_context = (writer.as_default(step.value()),
2326 summary_ops_v2.record_if(should_record))
2327 self._prev_summary_state.append(summary_context)
2328 summary_context[0].__enter__()
2329 summary_context[1].__enter__()
2331 def _pop_writer(self):
2332 """Pops the current writer."""
2333 if self.update_freq == 'epoch':
2334 return
2336 # See _push_writer for the content of previous_context, which is a pair
2337 # of contexts.
2338 previous_context = self._prev_summary_state.pop()
2339 previous_context[1].__exit__(*sys.exc_info())
2340 previous_context[0].__exit__(*sys.exc_info())
2342 def _close_writers(self):
2343 for writer in self._writers.values():
2344 writer.close()
2346 def _init_profile_batch(self, profile_batch):
2347 """Validate profile_batch value and set the range of batches to profile.
2348 Sets values of _start_batch and _stop_batch attributes,
2349 specifying the start and stop batch to profile.
2350 Setting `profile_batch=0` disables profiling.
2352 Args:
2353 profile_batch: The range of batches to profile. Should be a non-negative
2354 integer or a comma-separated string of a pair of positive integers. A
2355 pair of positive integers signifies a range of batches to profile.
2357 Raises:
2358 ValueError: If profile_batch is not an integer or a comma separated pair
2359 of positive integers.
2361 """
2362 profile_batch_error_message = (
2363 'profile_batch must be a non-negative integer or 2-tuple of positive '
2364 'integers. A pair of positive integers signifies a range of batches '
2365 'to profile. Found: {}'.format(profile_batch))
2367 # Support legacy way of specifying "start,stop" or "start" as str.
2368 if isinstance(profile_batch, str):
2369 profile_batch = str(profile_batch).split(',')
2370 profile_batch = nest.map_structure(int, profile_batch)
2372 if isinstance(profile_batch, int):
2373 self._start_batch = profile_batch
2374 self._stop_batch = profile_batch
2375 elif isinstance(profile_batch, (tuple, list)) and len(profile_batch) == 2:
2376 self._start_batch, self._stop_batch = profile_batch
2377 else:
2378 raise ValueError(profile_batch_error_message)
2380 if self._start_batch < 0 or self._stop_batch < self._start_batch:
2381 raise ValueError(profile_batch_error_message)
2383 # True when the profiler was successfully started by this callback.
2384 # We track the status here to make sure callbacks do not interfere with
2385 # each other. The callback will only stop the profiler it started.
2386 self._profiler_started = False
2387 if self._start_batch > 0:
2388 # Warm up the profiler to improve profiling accuracy.
2389 self._start_profiler(logdir='')
2390 self._stop_profiler(save=False)
2391 # True when a trace is running.
2392 self._is_tracing = False
2394 # Setting `profile_batch=0` disables profiling.
2395 self._should_trace = not (self._start_batch == 0 and self._stop_batch == 0)
2397 def on_train_begin(self, logs=None):
2398 self._global_train_batch = 0
2399 self._previous_epoch_iterations = 0
2400 self._train_accumulated_time = 0
2401 self._push_writer(self._train_writer, self._train_step)
2403 def on_train_end(self, logs=None):
2404 self._pop_writer()
2406 if self._is_tracing:
2407 self._stop_trace()
2409 self._close_writers()
2410 self._delete_tmp_write_dir()
2412 def on_test_begin(self, logs=None):
2413 self._push_writer(self._val_writer, self._val_step)
2415 def on_test_end(self, logs=None):
2416 if self.model.optimizer and hasattr(self.model.optimizer, 'iterations'):
2417 with summary_ops_v2.record_if(True), self._val_writer.as_default():
2418 for name, value in logs.items():
2419 summary_ops_v2.scalar(
2420 'evaluation_' + name + '_vs_iterations',
2421 value,
2422 step=self.model.optimizer.iterations.read_value())
2423 self._pop_writer()
2425 def _implements_train_batch_hooks(self):
2426 # Only call batch hooks when tracing or write_steps_per_second is enabled.
2427 return self._should_trace or self.write_steps_per_second
2429 def on_train_batch_begin(self, batch, logs=None):
2430 self._global_train_batch += 1
2431 if self.write_steps_per_second:
2432 self._batch_start_time = time.time()
2433 if not self._should_trace:
2434 return
2436 if self._global_train_batch == self._start_batch:
2437 self._start_trace()
2439 def on_train_batch_end(self, batch, logs=None):
2440 if self._should_write_train_graph:
2441 self._write_keras_model_train_graph()
2442 self._should_write_train_graph = False
2443 if self.write_steps_per_second:
2444 batch_run_time = time.time() - self._batch_start_time
2445 self._train_accumulated_time += batch_run_time
2446 summary_ops_v2.scalar(
2447 'batch_steps_per_second', 1. / batch_run_time, step=self._train_step)
2448 if not self._should_trace:
2449 return
2451 if self._is_tracing and self._global_train_batch >= self._stop_batch:
2452 self._stop_trace()
2454 def on_epoch_begin(self, epoch, logs=None):
2455 # Keeps track of epoch for profiling.
2456 if self.write_steps_per_second:
2457 self._previous_epoch_iterations = self.model.optimizer.iterations.numpy()
2458 self._train_accumulated_time = 0
2460 def on_epoch_end(self, epoch, logs=None):
2461 """Runs metrics and histogram summaries at epoch end."""
2462 self._log_epoch_metrics(epoch, logs)
2464 if self.histogram_freq and epoch % self.histogram_freq == 0:
2465 self._log_weights(epoch)
2467 if self.embeddings_freq and epoch % self.embeddings_freq == 0:
2468 self._log_embeddings(epoch)
2470 def _start_trace(self):
2471 summary_ops_v2.trace_on(graph=True, profiler=False)
2472 self._start_profiler(logdir=self._train_dir)
2473 self._is_tracing = True
2475 def _stop_trace(self, batch=None):
2476 """Logs the trace graph to TensorBoard."""
2477 if batch is None:
2478 batch = self._stop_batch
2479 with self._train_writer.as_default():
2480 with summary_ops_v2.record_if(True):
2481 # TODO(b/126388999): Remove step info in the summary name.
2482 summary_ops_v2.trace_export(name='batch_%d' % batch, step=batch)
2483 self._stop_profiler()
2484 self._is_tracing = False
2486 def _collect_learning_rate(self, logs):
2487 lr_schedule = getattr(self.model.optimizer, 'lr', None)
2488 if isinstance(lr_schedule, learning_rate_schedule.LearningRateSchedule):
2489 logs['learning_rate'] = lr_schedule(self.model.optimizer.iterations)
2490 return logs
2492 def _compute_steps_per_second(self):
2493 current_iteration = self.model.optimizer.iterations.numpy()
2494 steps_per_second = ((current_iteration - self._previous_epoch_iterations) /
2495 (self._train_accumulated_time))
2496 return steps_per_second
2498 def _log_epoch_metrics(self, epoch, logs):
2499 """Writes epoch metrics out as scalar summaries.
2501 Args:
2502 epoch: Int. The global step to use for TensorBoard.
2503 logs: Dict. Keys are scalar summary names, values are scalars.
2504 """
2505 if not logs:
2506 return
2508 train_logs = {k: v for k, v in logs.items() if not k.startswith('val_')}
2509 val_logs = {k: v for k, v in logs.items() if k.startswith('val_')}
2510 train_logs = self._collect_learning_rate(train_logs)
2511 if self.write_steps_per_second:
2512 train_logs['steps_per_second'] = self._compute_steps_per_second()
2514 with summary_ops_v2.record_if(True):
2515 if train_logs:
2516 with self._train_writer.as_default():
2517 for name, value in train_logs.items():
2518 summary_ops_v2.scalar('epoch_' + name, value, step=epoch)
2519 if val_logs:
2520 with self._val_writer.as_default():
2521 for name, value in val_logs.items():
2522 name = name[4:] # Remove 'val_' prefix.
2523 summary_ops_v2.scalar('epoch_' + name, value, step=epoch)
2525 def _log_weights(self, epoch):
2526 """Logs the weights of the Model to TensorBoard."""
2527 with self._train_writer.as_default():
2528 with summary_ops_v2.record_if(True):
2529 for layer in self.model.layers:
2530 for weight in layer.weights:
2531 weight_name = weight.name.replace(':', '_')
2532 summary_ops_v2.histogram(weight_name, weight, step=epoch)
2533 if self.write_images:
2534 self._log_weight_as_image(weight, weight_name, epoch)
2535 self._train_writer.flush()
2537 def _log_weight_as_image(self, weight, weight_name, epoch):
2538 """Logs a weight as a TensorBoard image."""
2539 w_img = array_ops.squeeze(weight)
2540 shape = backend.int_shape(w_img)
2541 if len(shape) == 1: # Bias case
2542 w_img = array_ops.reshape(w_img, [1, shape[0], 1, 1])
2543 elif len(shape) == 2: # Dense layer kernel case
2544 if shape[0] > shape[1]:
2545 w_img = array_ops.transpose(w_img)
2546 shape = backend.int_shape(w_img)
2547 w_img = array_ops.reshape(w_img, [1, shape[0], shape[1], 1])
2548 elif len(shape) == 3: # ConvNet case
2549 if backend.image_data_format() == 'channels_last':
2550 # Switch to channels_first to display every kernel as a separate
2551 # image.
2552 w_img = array_ops.transpose(w_img, perm=[2, 0, 1])
2553 shape = backend.int_shape(w_img)
2554 w_img = array_ops.reshape(w_img, [shape[0], shape[1], shape[2], 1])
2556 shape = backend.int_shape(w_img)
2557 # Not possible to handle 3D convnets etc.
2558 if len(shape) == 4 and shape[-1] in [1, 3, 4]:
2559 summary_ops_v2.image(weight_name, w_img, step=epoch)
2561 def _log_embeddings(self, epoch):
2562 embeddings_ckpt = os.path.join(self._log_write_dir, 'train',
2563 'keras_embedding.ckpt-{}'.format(epoch))
2564 self.model.save_weights(embeddings_ckpt)
2566 def _start_profiler(self, logdir):
2567 """Starts the profiler if currently inactive.
2569 Args:
2570 logdir: Directory where profiler results will be saved.
2571 """
2572 if self._profiler_started:
2573 return
2574 try:
2575 profiler.start(logdir=logdir)
2576 self._profiler_started = True
2577 except errors.AlreadyExistsError as e:
2578 # Profiler errors should not be fatal.
2579 logging.error('Failed to start profiler: %s', e.message)
2581 def _stop_profiler(self, save=True):
2582 """Stops the profiler if currently active.
2584 Args:
2585 save: Whether to save the profiler results to TensorBoard.
2586 """
2587 if not self._profiler_started:
2588 return
2589 try:
2590 profiler.stop(save=save)
2591 except errors.UnavailableError as e:
2592 # Profiler errors should not be fatal.
2593 logging.error('Failed to stop profiler: %s', e.message)
2594 finally:
2595 self._profiler_started = False
2598@keras_export('keras.callbacks.ReduceLROnPlateau')
2599class ReduceLROnPlateau(Callback):
2600 """Reduce learning rate when a metric has stopped improving.
2602 Models often benefit from reducing the learning rate by a factor
2603 of 2-10 once learning stagnates. This callback monitors a
2604 quantity and if no improvement is seen for a 'patience' number
2605 of epochs, the learning rate is reduced.
2607 Example:
2609 ```python
2610 reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,
2611 patience=5, min_lr=0.001)
2612 model.fit(X_train, Y_train, callbacks=[reduce_lr])
2613 ```
2615 Args:
2616 monitor: quantity to be monitored.
2617 factor: factor by which the learning rate will be reduced.
2618 `new_lr = lr * factor`.
2619 patience: number of epochs with no improvement after which learning rate
2620 will be reduced.
2621 verbose: int. 0: quiet, 1: update messages.
2622 mode: one of `{'auto', 'min', 'max'}`. In `'min'` mode,
2623 the learning rate will be reduced when the
2624 quantity monitored has stopped decreasing; in `'max'` mode it will be
2625 reduced when the quantity monitored has stopped increasing; in `'auto'`
2626 mode, the direction is automatically inferred from the name of the
2627 monitored quantity.
2628 min_delta: threshold for measuring the new optimum, to only focus on
2629 significant changes.
2630 cooldown: number of epochs to wait before resuming normal operation after
2631 lr has been reduced.
2632 min_lr: lower bound on the learning rate.
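
  For instance, with the sketch above (`factor=0.2`, `min_lr=0.001`) and an
  initial learning rate of 0.01, the learning rate would be reduced to 0.002
  after `patience` epochs without improvement, and a further plateau would
  reduce it to the `min_lr` floor of 0.001 rather than 0.0004.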
2633 """
2635 def __init__(self,
2636 monitor='val_loss',
2637 factor=0.1,
2638 patience=10,
2639 verbose=0,
2640 mode='auto',
2641 min_delta=1e-4,
2642 cooldown=0,
2643 min_lr=0,
2644 **kwargs):
2645 super(ReduceLROnPlateau, self).__init__()
2647 self.monitor = monitor
2648 if factor >= 1.0:
2649 raise ValueError('ReduceLROnPlateau does not support a factor >= 1.0.')
2650 if 'epsilon' in kwargs:
2651 min_delta = kwargs.pop('epsilon')
2652 logging.warning('`epsilon` argument is deprecated and '
2653 'will be removed, use `min_delta` instead.')
2654 self.factor = factor
2655 self.min_lr = min_lr
2656 self.min_delta = min_delta
2657 self.patience = patience
2658 self.verbose = verbose
2659 self.cooldown = cooldown
2660 self.cooldown_counter = 0 # Cooldown counter.
2661 self.wait = 0
2662 self.best = 0
2663 self.mode = mode
2664 self.monitor_op = None
2665 self._reset()
2667 def _reset(self):
2668 """Resets wait counter and cooldown counter.
2669 """
2670 if self.mode not in ['auto', 'min', 'max']:
2671 logging.warning('Learning rate reduction mode %s is unknown, '
2672 'fallback to auto mode.', self.mode)
2673 self.mode = 'auto'
2674 if (self.mode == 'min' or
2675 (self.mode == 'auto' and 'acc' not in self.monitor)):
2676 self.monitor_op = lambda a, b: np.less(a, b - self.min_delta)
2677 self.best = np.Inf
2678 else:
2679 self.monitor_op = lambda a, b: np.greater(a, b + self.min_delta)
2680 self.best = -np.Inf
2681 self.cooldown_counter = 0
2682 self.wait = 0
2684 def on_train_begin(self, logs=None):
2685 self._reset()
2687 def on_epoch_end(self, epoch, logs=None):
2688 logs = logs or {}
2689 logs['lr'] = backend.get_value(self.model.optimizer.lr)
2690 current = logs.get(self.monitor)
2691 if current is None:
2692 logging.warning('Learning rate reduction is conditioned on metric `%s` '
2693 'which is not available. Available metrics are: %s',
2694 self.monitor, ','.join(list(logs.keys())))
2696 else:
2697 if self.in_cooldown():
2698 self.cooldown_counter -= 1
2699 self.wait = 0
2701 if self.monitor_op(current, self.best):
2702 self.best = current
2703 self.wait = 0
2704 elif not self.in_cooldown():
2705 self.wait += 1
2706 if self.wait >= self.patience:
2707 old_lr = backend.get_value(self.model.optimizer.lr)
2708 if old_lr > np.float32(self.min_lr):
2709 new_lr = old_lr * self.factor
2710 new_lr = max(new_lr, self.min_lr)
2711 backend.set_value(self.model.optimizer.lr, new_lr)
2712 if self.verbose > 0:
2713 print('\nEpoch %05d: ReduceLROnPlateau reducing learning '
2714 'rate to %s.' % (epoch + 1, new_lr))
2715 self.cooldown_counter = self.cooldown
2716 self.wait = 0
2718 def in_cooldown(self):
2719 return self.cooldown_counter > 0
2722@keras_export('keras.callbacks.CSVLogger')
2723class CSVLogger(Callback):
2724 """Callback that streams epoch results to a CSV file.
2726 Supports all values that can be represented as a string,
2727 including 1D iterables such as `np.ndarray`.
2729 Example:
2731 ```python
2732 csv_logger = CSVLogger('training.log')
2733 model.fit(X_train, Y_train, callbacks=[csv_logger])
2734 ```
2736 Args:
2737 filename: Filename of the CSV file, e.g. `'run/log.csv'`.
2738 separator: String used to separate elements in the CSV file.
2739 append: Boolean. True: append if file exists (useful for continuing
2740 training). False: overwrite existing file.
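
  A minimal sketch of resuming logging into an existing file (the filename is
  a placeholder):

  ```python
  csv_logger = CSVLogger('training.log', append=True)
  model.fit(X_train, Y_train, callbacks=[csv_logger])
  ```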
2741 """
2743 def __init__(self, filename, separator=',', append=False):
2744 self.sep = separator
2745 self.filename = path_to_string(filename)
2746 self.append = append
2747 self.writer = None
2748 self.keys = None
2749 self.append_header = True
2750 super(CSVLogger, self).__init__()
2752 def on_train_begin(self, logs=None):
2753 if self.append:
2754 if file_io.file_exists_v2(self.filename):
2755 with gfile.GFile(self.filename, 'r') as f:
2756 self.append_header = not bool(len(f.readline()))
2757 mode = 'a'
2758 else:
2759 mode = 'w'
2760 self.csv_file = gfile.GFile(self.filename, mode)
2762 def on_epoch_end(self, epoch, logs=None):
2763 logs = logs or {}
2765 def handle_value(k):
2766 is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
2767 if isinstance(k, str):
2768 return k
2769 elif isinstance(k, collections.abc.Iterable) and not is_zero_dim_ndarray:
2770 return '"[%s]"' % (', '.join(map(str, k)))
2771 else:
2772 return k
2774 if self.keys is None:
2775 self.keys = sorted(logs.keys())
2777 if self.model.stop_training:
2778 # We set NA so that csv parsers do not fail for this last epoch.
2779 logs = dict((k, logs[k]) if k in logs else (k, 'NA') for k in self.keys)
2781 if not self.writer:
2783 class CustomDialect(csv.excel):
2784 delimiter = self.sep
2786 fieldnames = ['epoch'] + self.keys
2788 self.writer = csv.DictWriter(
2789 self.csv_file,
2790 fieldnames=fieldnames,
2791 dialect=CustomDialect)
2792 if self.append_header:
2793 self.writer.writeheader()
2795 row_dict = collections.OrderedDict({'epoch': epoch})
2796 row_dict.update((key, handle_value(logs[key])) for key in self.keys)
2797 self.writer.writerow(row_dict)
2798 self.csv_file.flush()
2800 def on_train_end(self, logs=None):
2801 self.csv_file.close()
2802 self.writer = None
2805@keras_export('keras.callbacks.LambdaCallback')
2806class LambdaCallback(Callback):
2807 r"""Callback for creating simple, custom callbacks on-the-fly.
2809 This callback is constructed with anonymous functions that will be called
2810 at the appropriate time (during `Model.{fit | evaluate | predict}`).
2811 Note that the callbacks expect positional arguments, as follows:
2813 - `on_epoch_begin` and `on_epoch_end` expect two positional arguments:
2814 `epoch`, `logs`
2815 - `on_batch_begin` and `on_batch_end` expect two positional arguments:
2816 `batch`, `logs`
2817 - `on_train_begin` and `on_train_end` expect one positional argument:
2818 `logs`
2820 Args:
2821 on_epoch_begin: called at the beginning of every epoch.
2822 on_epoch_end: called at the end of every epoch.
2823 on_batch_begin: called at the beginning of every batch.
2824 on_batch_end: called at the end of every batch.
2825 on_train_begin: called at the beginning of model training.
2826 on_train_end: called at the end of model training.
2828 Example:
2830 ```python
2831 # Print the batch number at the beginning of every batch.
2832 batch_print_callback = LambdaCallback(
2833 on_batch_begin=lambda batch,logs: print(batch))
2835 # Stream the epoch loss to a file in JSON format. The file content
2836 # is not well-formed JSON but rather has a JSON object per line.
2837 import json
2838 json_log = open('loss_log.json', mode='wt', buffering=1)
2839 json_logging_callback = LambdaCallback(
2840 on_epoch_end=lambda epoch, logs: json_log.write(
2841 json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'),
2842 on_train_end=lambda logs: json_log.close()
2843 )
2845 # Terminate some processes after having finished model training.
2846 processes = ...
2847 cleanup_callback = LambdaCallback(
2848 on_train_end=lambda logs: [
2849 p.terminate() for p in processes if p.is_alive()])
2851 model.fit(...,
2852 callbacks=[batch_print_callback,
2853 json_logging_callback,
2854 cleanup_callback])
2855 ```
2856 """
2858 def __init__(self,
2859 on_epoch_begin=None,
2860 on_epoch_end=None,
2861 on_batch_begin=None,
2862 on_batch_end=None,
2863 on_train_begin=None,
2864 on_train_end=None,
2865 **kwargs):
2866 super(LambdaCallback, self).__init__()
2867 self.__dict__.update(kwargs)
2868 if on_epoch_begin is not None:
2869 self.on_epoch_begin = on_epoch_begin
2870 else:
2871 self.on_epoch_begin = lambda epoch, logs: None
2872 if on_epoch_end is not None:
2873 self.on_epoch_end = on_epoch_end
2874 else:
2875 self.on_epoch_end = lambda epoch, logs: None
2876 if on_batch_begin is not None:
2877 self.on_batch_begin = on_batch_begin
2878 else:
2879 self.on_batch_begin = lambda batch, logs: None
2880 if on_batch_end is not None:
2881 self.on_batch_end = on_batch_end
2882 else:
2883 self.on_batch_end = lambda batch, logs: None
2884 if on_train_begin is not None:
2885 self.on_train_begin = on_train_begin
2886 else:
2887 self.on_train_begin = lambda logs: None
2888 if on_train_end is not None:
2889 self.on_train_end = on_train_end
2890 else:
2891 self.on_train_end = lambda logs: None