Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/callbacks/tqdm_progress_bar.py: 19%
120 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TQDM Progress Bar."""
17import time
18import tensorflow as tf
19from collections import defaultdict
20from typeguard import typechecked
22from tensorflow.keras.callbacks import Callback
@tf.keras.utils.register_keras_serializable(package="Addons")
class TQDMProgressBar(Callback):
    """TQDM Progress Bar for Tensorflow Keras.

    Args:
        metrics_separator: Custom separator between metrics.
            Defaults to ' - '.
        overall_bar_format: Custom bar format for overall
            (outer) progress bar, see https://github.com/tqdm/tqdm#parameters
            for more detail.
        epoch_bar_format: Custom bar format for epoch
            (inner) progress bar, see https://github.com/tqdm/tqdm#parameters
            for more detail.
        update_per_second: Maximum number of updates in the epochs bar
            per second, this is to prevent small batches from slowing down
            training. Defaults to 10.
        metrics_format: Custom format for how metrics are formatted.
            See https://github.com/tqdm/tqdm#parameters for more detail.
        leave_epoch_progress: `True` to leave epoch progress bars.
        leave_overall_progress: `True` to leave overall progress bar.
        show_epoch_progress: `False` to hide epoch progress bars.
        show_overall_progress: `False` to hide overall progress bar.

    Raises:
        ImportError: if tqdm is not installed, or if the installed tqdm
            version is older than 4.36.1.
    """

    # Oldest supported tqdm release, as a tuple so comparisons are numeric
    # (a plain string comparison would rank "4.5.0" above "4.36.1").
    _MIN_TQDM_VERSION = (4, 36, 1)

    @typechecked
    def __init__(
        self,
        metrics_separator: str = " - ",
        overall_bar_format: str = "{l_bar}{bar} {n_fmt}/{total_fmt} ETA: "
        "{remaining}s, {rate_fmt}{postfix}",
        epoch_bar_format: str = "{n_fmt}/{total_fmt}{bar} ETA: "
        "{remaining}s - {desc}",
        metrics_format: str = "{name}: {value:0.4f}",
        update_per_second: int = 10,
        leave_epoch_progress: bool = True,
        leave_overall_progress: bool = True,
        show_epoch_progress: bool = True,
        show_overall_progress: bool = True,
    ):
        try:
            # import tqdm here because tqdm is not a required package
            # for addons
            import tqdm
        except ImportError:
            raise ImportError("Please install tqdm via pip install tqdm")

        # Compare versions numerically and raise explicitly: ``assert`` is
        # stripped under ``python -O``. An unparseable version string is
        # not rejected.
        installed_version = self._version_tuple(tqdm.__version__)
        if installed_version and installed_version < self._MIN_TQDM_VERSION:
            version_message = (
                "Please update your TQDM version to >= 4.36.1, "
                "you have version {}. To update, run !pip install -U tqdm"
            )
            raise ImportError(version_message.format(tqdm.__version__))

        from tqdm.auto import tqdm as auto_tqdm

        self.tqdm = auto_tqdm

        self.metrics_separator = metrics_separator
        self.overall_bar_format = overall_bar_format
        self.epoch_bar_format = epoch_bar_format
        self.leave_epoch_progress = leave_epoch_progress
        self.leave_overall_progress = leave_overall_progress
        self.show_epoch_progress = show_epoch_progress
        self.show_overall_progress = show_overall_progress
        self.metrics_format = metrics_format
        # Kept so that get_config() can round-trip the constructor argument.
        self.update_per_second = update_per_second

        # compute update interval (inverse of update per second)
        self.update_interval = 1 / update_per_second

        self.last_update_time = time.time()
        self.overall_progress_tqdm = None
        self.epoch_progress_tqdm = None
        self.is_training = False
        self.num_epochs = None
        self.logs = None
        super().__init__()

    @staticmethod
    def _version_tuple(version):
        """Return the leading numeric release of ``version`` as an int tuple.

        For example ``"4.36.1"`` -> ``(4, 36, 1)`` and
        ``"4.64.0rc1"`` -> ``(4, 64, 0)``. Returns ``()`` when the string
        does not start with a digit.
        """
        import re

        match = re.match(r"\d+(?:\.\d+)*", str(version))
        if match is None:
            return ()
        return tuple(int(part) for part in match.group(0).split("."))

    def _initialize_progbar(self, hook, epoch, logs=None):
        """Reset per-run counters and open the bar that matches ``hook``.

        Args:
            hook: one of "train_overall", "test" or "train_epoch".
            epoch: zero-based epoch index (only used for "train_epoch").
            logs: unused; kept for signature symmetry with the callbacks.
        """
        self.num_samples_seen = 0
        self.steps_to_update = 0
        self.steps_so_far = 0
        self.logs = defaultdict(float)
        self.num_epochs = self.params["epochs"]
        self.mode = "steps"
        # May be None when Keras cannot infer the number of steps per epoch.
        self.total_steps = self.params["steps"]
        if hook == "train_overall":
            if self.show_overall_progress:
                self.overall_progress_tqdm = self.tqdm(
                    desc="Training",
                    total=self.num_epochs,
                    bar_format=self.overall_bar_format,
                    leave=self.leave_overall_progress,
                    dynamic_ncols=True,
                    unit="epochs",
                )
        elif hook == "test":
            if self.show_epoch_progress:
                self.epoch_progress_tqdm = self.tqdm(
                    total=self.total_steps,
                    desc="Evaluating",
                    bar_format=self.epoch_bar_format,
                    leave=self.leave_epoch_progress,
                    dynamic_ncols=True,
                    unit=self.mode,
                )
        elif hook == "train_epoch":
            current_epoch_description = "Epoch {epoch}/{num_epochs}".format(
                epoch=epoch + 1, num_epochs=self.num_epochs
            )
            if self.show_epoch_progress:
                print(current_epoch_description)
                self.epoch_progress_tqdm = self.tqdm(
                    total=self.total_steps,
                    bar_format=self.epoch_bar_format,
                    leave=self.leave_epoch_progress,
                    dynamic_ncols=True,
                    unit=self.mode,
                )

    def _clean_up_progbar(self, hook, logs):
        """Finalize and close the bar that matches ``hook``.

        For the epoch/test bars this writes the final metrics, advances the
        bar to its end and closes it.

        Args:
            hook: one of "train_overall", "test" or "train_epoch".
            logs: metric dictionary to display. For "test" it holds
                accumulated sums, which are divided by ``num_samples_seen``.
        """
        if hook == "train_overall":
            if self.show_overall_progress:
                self.overall_progress_tqdm.close()
            return

        if hook == "test":
            metrics = self.format_metrics(logs, self.num_samples_seen)
        else:
            metrics = self.format_metrics(logs)

        if self.show_epoch_progress:
            self.epoch_progress_tqdm.desc = metrics
            # set miniters and mininterval to 0 so last update displays
            self.epoch_progress_tqdm.miniters = 0
            self.epoch_progress_tqdm.mininterval = 0
            # Update the rest of the steps in the epoch progress bar. With an
            # unknown total, advance to the number of steps actually seen.
            if self.total_steps is None:
                target_steps = self.steps_so_far
            else:
                target_steps = self.total_steps
            self.epoch_progress_tqdm.update(
                target_steps - self.epoch_progress_tqdm.n
            )
            self.epoch_progress_tqdm.close()

    def _update_progbar(self, logs):
        """Record one finished batch and refresh the epoch bar when due.

        Args:
            logs: per-batch metric dictionary; ``None`` is treated as empty.
        """
        logs = logs or {}
        # ``mode`` is set to "steps" in _initialize_progbar, so each batch
        # normally counts as 1; the "samples" branch uses logs["size"].
        if self.mode == "samples":
            batch_size = logs["size"]
        else:
            batch_size = 1

        self.num_samples_seen += batch_size
        self.steps_to_update += 1
        self.steps_so_far += 1

        # Accumulate weighted sums; format_metrics later divides them by
        # num_samples_seen. With an unknown total, every batch is counted.
        if self.total_steps is None or self.steps_so_far <= self.total_steps:
            for metric, value in logs.items():
                self.logs[metric] += value * batch_size

        now = time.time()
        time_diff = now - self.last_update_time
        if self.show_epoch_progress and time_diff >= self.update_interval:
            # update the epoch progress bar
            metrics = self.format_metrics(self.logs, self.num_samples_seen)
            self.epoch_progress_tqdm.desc = metrics
            self.epoch_progress_tqdm.update(self.steps_to_update)

            # reset steps to update
            self.steps_to_update = 0

            # update timestamp for last update
            self.last_update_time = now

    def on_train_begin(self, logs=None):
        self.is_training = True
        self._initialize_progbar("train_overall", None, logs)

    def on_train_end(self, logs=None):
        self.is_training = False
        self._clean_up_progbar("train_overall", logs)

    def on_test_begin(self, logs=None):
        # Only standalone evaluation (not during fit) gets its own bar.
        if not self.is_training:
            self._initialize_progbar("test", None, logs)

    def on_test_end(self, logs=None):
        if not self.is_training:
            self._clean_up_progbar("test", self.logs)

    def on_epoch_begin(self, epoch, logs=None):
        self._initialize_progbar("train_epoch", epoch, logs)

    def on_epoch_end(self, epoch, logs=None):
        self._clean_up_progbar("train_epoch", logs)
        if self.show_overall_progress:
            self.overall_progress_tqdm.update(1)

    def on_test_batch_end(self, batch, logs=None):
        if not self.is_training:
            self._update_progbar(logs)

    def on_batch_end(self, batch, logs=None):
        self._update_progbar(logs)

    def format_metrics(self, logs=None, factor=1):
        """Format metrics in logs into a string.

        Args:
            logs: dictionary of metrics and their values. Defaults to
                an empty dictionary (``None`` is accepted as empty).
            factor (int): The factor we want to divide the metrics in logs
                by, useful when we are computing the logs after each batch.
                Defaults to 1.

        Returns:
            metrics_string: a string displaying metrics using the given
            formatters passed in through the constructor. The "batch" and
            "size" keys are skipped.
        """
        if logs is None:
            logs = {}
        metric_value_pairs = [
            self.metrics_format.format(name=key, value=value / factor)
            for key, value in logs.items()
            if key not in ("batch", "size")
        ]
        return self.metrics_separator.join(metric_value_pairs)

    def get_config(self):
        """Return the constructor arguments so the callback can be rebuilt."""
        config = {
            "metrics_separator": self.metrics_separator,
            "overall_bar_format": self.overall_bar_format,
            "epoch_bar_format": self.epoch_bar_format,
            "metrics_format": self.metrics_format,
            "update_per_second": self.update_per_second,
            "leave_epoch_progress": self.leave_epoch_progress,
            "leave_overall_progress": self.leave_overall_progress,
            "show_epoch_progress": self.show_epoch_progress,
            "show_overall_progress": self.show_overall_progress,
        }

        base_config = super().get_config()
        return {**base_config, **config}