Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/callbacks/time_stopping.py: 48%
29 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Callback that stops training when a specified amount of time has passed."""
17import datetime
18import time
19from typeguard import typechecked
21import tensorflow as tf
22from tensorflow.keras.callbacks import Callback
@tf.keras.utils.register_keras_serializable(package="Addons")
class TimeStopping(Callback):
    """Stop training when a specified amount of time has passed.

    The deadline is checked only at the end of each epoch, so training can
    run past the budget by up to the duration of one epoch.

    Args:
        seconds: maximum amount of time before stopping.
            Defaults to 86400 (1 day).
        verbose: verbosity mode. Defaults to 0. When greater than 0, a
            message is printed at the end of training if this callback
            caused the stop.
    """

    @typechecked
    def __init__(self, seconds: int = 86400, verbose: int = 0):
        super().__init__()

        self.seconds = seconds
        self.verbose = verbose
        # 0-based index of the epoch at which the time limit triggered, or
        # None if this callback has not stopped the current training run.
        self.stopped_epoch = None

    def on_train_begin(self, logs=None):
        """Compute the deadline and clear state from any previous run."""
        self.stopping_time = time.time() + self.seconds
        # Reset so that reusing this callback instance across several `fit`
        # calls does not report a stop that happened in an earlier run.
        self.stopped_epoch = None

    def on_epoch_end(self, epoch, logs=None):
        """Ask the model to stop once the deadline has been reached.

        `logs` defaults to None (not a shared mutable `{}`); it is unused here.
        """
        if time.time() >= self.stopping_time:
            self.model.stop_training = True
            self.stopped_epoch = epoch

    def on_train_end(self, logs=None):
        """Print a summary if this callback stopped training and verbose > 0."""
        if self.stopped_epoch is not None and self.verbose > 0:
            # Reports the configured time budget, not the measured elapsed time.
            formatted_time = datetime.timedelta(seconds=self.seconds)
            # Epoch is shown 1-based for the user.
            msg = "Timed stopping at epoch {} after training for {}".format(
                self.stopped_epoch + 1, formatted_time
            )
            print(msg)

    def get_config(self):
        """Return the parent config merged with `seconds` and `verbose`."""
        config = {
            "seconds": self.seconds,
            "verbose": self.verbose,
        }

        base_config = super().get_config()
        return {**base_config, **config}