Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/callbacks/tqdm_progress_bar.py: 19% (120 statements)

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TQDM Progress Bar."""

import time
import tensorflow as tf
from collections import defaultdict
from typeguard import typechecked

from tensorflow.keras.callbacks import Callback


@tf.keras.utils.register_keras_serializable(package="Addons")
class TQDMProgressBar(Callback):
27 """TQDM Progress Bar for Tensorflow Keras. 

28 

29 Args: 

30 metrics_separator: Custom separator between metrics. 

31 Defaults to ' - '. 

32 overall_bar_format: Custom bar format for overall 

33 (outer) progress bar, see https://github.com/tqdm/tqdm#parameters 

34 for more detail. 

35 epoch_bar_format: Custom bar format for epoch 

36 (inner) progress bar, see https://github.com/tqdm/tqdm#parameters 

37 for more detail. 

38 update_per_second: Maximum number of updates in the epochs bar 

39 per second, this is to prevent small batches from slowing down 

40 training. Defaults to 10. 

41 metrics_format: Custom format for how metrics are formatted. 

42 See https://github.com/tqdm/tqdm#parameters for more detail. 

43 leave_epoch_progress: `True` to leave epoch progress bars. 

44 leave_overall_progress: `True` to leave overall progress bar. 

45 show_epoch_progress: `False` to hide epoch progress bars. 

46 show_overall_progress: `False` to hide overall progress bar. 

47 """ 

48 

49 @typechecked 

50 def __init__( 

51 self, 

52 metrics_separator: str = " - ", 

53 overall_bar_format: str = "{l_bar}{bar} {n_fmt}/{total_fmt} ETA: " 

54 "{remaining}s, {rate_fmt}{postfix}", 

55 epoch_bar_format: str = "{n_fmt}/{total_fmt}{bar} ETA: " 

56 "{remaining}s - {desc}", 

57 metrics_format: str = "{name}: {value:0.4f}", 

58 update_per_second: int = 10, 

59 leave_epoch_progress: bool = True, 

60 leave_overall_progress: bool = True, 

61 show_epoch_progress: bool = True, 

62 show_overall_progress: bool = True, 

63 ): 

64 

65 try: 

66 # import tqdm here because tqdm is not a required package 

67 # for addons 

68 import tqdm 

69 

70 version_message = "Please update your TQDM version to >= 4.36.1, " 

71 "you have version {}. To update, run !pip install -U tqdm" 

72 assert tqdm.__version__ >= "4.36.1", version_message.format( 

73 tqdm.__version__ 

74 ) 

75 from tqdm.auto import tqdm 

76 

77 self.tqdm = tqdm 

78 except ImportError: 

79 raise ImportError("Please install tqdm via pip install tqdm") 

80 

81 self.metrics_separator = metrics_separator 

82 self.overall_bar_format = overall_bar_format 

83 self.epoch_bar_format = epoch_bar_format 

84 self.leave_epoch_progress = leave_epoch_progress 

85 self.leave_overall_progress = leave_overall_progress 

86 self.show_epoch_progress = show_epoch_progress 

87 self.show_overall_progress = show_overall_progress 

88 self.metrics_format = metrics_format 

89 

90 # compute update interval (inverse of update per second) 

91 self.update_interval = 1 / update_per_second 

92 

93 self.last_update_time = time.time() 

94 self.overall_progress_tqdm = None 

95 self.epoch_progress_tqdm = None 

96 self.is_training = False 

97 self.num_epochs = None 

98 self.logs = None 

99 super().__init__() 

100 

    def _initialize_progbar(self, hook, epoch, logs=None):
        self.num_samples_seen = 0
        self.steps_to_update = 0
        self.steps_so_far = 0
        self.logs = defaultdict(float)
        self.num_epochs = self.params["epochs"]
        self.mode = "steps"
        self.total_steps = self.params["steps"]
        if hook == "train_overall":
            if self.show_overall_progress:
                self.overall_progress_tqdm = self.tqdm(
                    desc="Training",
                    total=self.num_epochs,
                    bar_format=self.overall_bar_format,
                    leave=self.leave_overall_progress,
                    dynamic_ncols=True,
                    unit="epochs",
                )
        elif hook == "test":
            if self.show_epoch_progress:
                self.epoch_progress_tqdm = self.tqdm(
                    total=self.total_steps,
                    desc="Evaluating",
                    bar_format=self.epoch_bar_format,
                    leave=self.leave_epoch_progress,
                    dynamic_ncols=True,
                    unit=self.mode,
                )
        elif hook == "train_epoch":
            current_epoch_description = "Epoch {epoch}/{num_epochs}".format(
                epoch=epoch + 1, num_epochs=self.num_epochs
            )
            if self.show_epoch_progress:
                print(current_epoch_description)
                self.epoch_progress_tqdm = self.tqdm(
                    total=self.total_steps,
                    bar_format=self.epoch_bar_format,
                    leave=self.leave_epoch_progress,
                    dynamic_ncols=True,
                    unit=self.mode,
                )

    def _clean_up_progbar(self, hook, logs):
        if hook == "train_overall":
            if self.show_overall_progress:
                self.overall_progress_tqdm.close()
        else:
            if hook == "test":
                metrics = self.format_metrics(logs, self.num_samples_seen)
            else:
                metrics = self.format_metrics(logs)
            if self.show_epoch_progress:
                self.epoch_progress_tqdm.desc = metrics
                # set miniters and mininterval to 0 so last update displays
                self.epoch_progress_tqdm.miniters = 0
                self.epoch_progress_tqdm.mininterval = 0
                # update the rest of the steps in epoch progress bar
                self.epoch_progress_tqdm.update(
                    self.total_steps - self.epoch_progress_tqdm.n
                )
                self.epoch_progress_tqdm.close()

    def _update_progbar(self, logs):
        if self.mode == "samples":
            batch_size = logs["size"]
        else:
            batch_size = 1

        self.num_samples_seen += batch_size
        self.steps_to_update += 1
        self.steps_so_far += 1

        if self.steps_so_far <= self.total_steps:
            for metric, value in logs.items():
                self.logs[metric] += value * batch_size

            now = time.time()
            time_diff = now - self.last_update_time
            if self.show_epoch_progress and time_diff >= self.update_interval:
                # update the epoch progress bar
                metrics = self.format_metrics(self.logs, self.num_samples_seen)
                self.epoch_progress_tqdm.desc = metrics
                self.epoch_progress_tqdm.update(self.steps_to_update)

                # reset steps to update
                self.steps_to_update = 0

                # update timestamp for last update
                self.last_update_time = now

    def on_train_begin(self, logs=None):
        self.is_training = True
        self._initialize_progbar("train_overall", None, logs)

    def on_train_end(self, logs={}):
        self.is_training = False
        self._clean_up_progbar("train_overall", logs)

    def on_test_begin(self, logs={}):
        if not self.is_training:
            self._initialize_progbar("test", None, logs)

    def on_test_end(self, logs={}):
        if not self.is_training:
            self._clean_up_progbar("test", self.logs)

    def on_epoch_begin(self, epoch, logs={}):
        self._initialize_progbar("train_epoch", epoch, logs)

    def on_epoch_end(self, epoch, logs={}):
        self._clean_up_progbar("train_epoch", logs)
        if self.show_overall_progress:
            self.overall_progress_tqdm.update(1)

    def on_test_batch_end(self, batch, logs={}):
        if not self.is_training:
            self._update_progbar(logs)

    def on_batch_end(self, batch, logs={}):
        self._update_progbar(logs)

    def format_metrics(self, logs={}, factor=1):
        """Format metrics in logs into a string.

        Args:
            logs: dictionary of metrics and their values. Defaults to
                empty dictionary.
            factor (int): The factor to divide the metric values by;
                useful when the logs are accumulated over batches.
                Defaults to 1.

        Returns:
            metrics_string: a string displaying metrics using the given
                formatters passed in through the constructor.
        """

        metric_value_pairs = []
        for key, value in logs.items():
            if key in ["batch", "size"]:
                continue
            pair = self.metrics_format.format(name=key, value=value / factor)
            metric_value_pairs.append(pair)
        metrics_string = self.metrics_separator.join(metric_value_pairs)
        return metrics_string

    def get_config(self):
        config = {
            "metrics_separator": self.metrics_separator,
            "overall_bar_format": self.overall_bar_format,
            "epoch_bar_format": self.epoch_bar_format,
            "leave_epoch_progress": self.leave_epoch_progress,
            "leave_overall_progress": self.leave_overall_progress,
            "show_epoch_progress": self.show_epoch_progress,
            "show_overall_progress": self.show_overall_progress,
        }

        base_config = super().get_config()
        return {**base_config, **config}
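

# --- Usage sketch (not part of the original module) ---
# A minimal, hedged example of how TQDMProgressBar is typically attached to a
# Keras model: pass it via `callbacks` and set `verbose=0` so Keras's built-in
# progress bar does not print alongside the tqdm bars. The toy model and random
# data below are illustrative assumptions, not part of tensorflow_addons.
if __name__ == "__main__":
    import numpy as np

    model = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(8, activation="relu", input_shape=(4,)),
            tf.keras.layers.Dense(1),
        ]
    )
    model.compile(optimizer="adam", loss="mse")

    x = np.random.rand(64, 4).astype("float32")
    y = np.random.rand(64, 1).astype("float32")

    # tqdm must be installed for the callback's import check to pass.
    model.fit(
        x, y, epochs=2, batch_size=8, verbose=0, callbacks=[TQDMProgressBar()]
    )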