1"""
2Formatter classes for the progress bar.
3Each progress bar consists of a list of these formatters.
4"""
5
6from __future__ import annotations
7
8import datetime
9import time
10from abc import ABCMeta, abstractmethod
11from typing import TYPE_CHECKING
12
13from prompt_toolkit.formatted_text import (
14 HTML,
15 AnyFormattedText,
16 StyleAndTextTuples,
17 to_formatted_text,
18)
19from prompt_toolkit.formatted_text.utils import fragment_list_width
20from prompt_toolkit.layout.dimension import AnyDimension, D
21from prompt_toolkit.layout.utils import explode_text_fragments
22from prompt_toolkit.utils import get_cwidth
23
24if TYPE_CHECKING:
25 from .base import ProgressBar, ProgressBarCounter
26
# Public names exported by this module: the formatter classes and the
# factory that builds the default formatter list.
__all__ = [
    "Formatter",
    "Text",
    "Label",
    "Percentage",
    "Bar",
    "Progress",
    "TimeElapsed",
    "TimeLeft",
    "IterationsPerSecond",
    "SpinningWheel",
    "Rainbow",
    "create_default_formatters",
]
41
42
class Formatter(metaclass=ABCMeta):
    """
    Abstract base class for every progress-bar formatter.

    A progress bar is made of a list of formatters; each one renders its own
    piece of a counter's row.
    """

    @abstractmethod
    def format(
        self,
        progress_bar: ProgressBar,
        progress: ProgressBarCounter[object],
        width: int,
    ) -> AnyFormattedText:
        """
        Return the formatted text for one counter.

        :param progress_bar: The progress bar being rendered.
        :param progress: The counter whose row is being rendered.
        :param width: Number of columns available to this formatter.
        """

    def get_width(self, progress_bar: ProgressBar) -> AnyDimension:
        """
        Return the desired width of this formatter.

        The default returns ``D()`` with no explicit min/max/preferred value.
        """
        return D()
59
60
class Text(Formatter):
    """
    Display a fixed piece of formatted text.

    :param text: Any formatted text; it is converted once, at construction.
    :param style: Style applied to the whole text.
    """

    def __init__(self, text: AnyFormattedText, style: str = "") -> None:
        converted = to_formatted_text(text, style=style)
        self.text = converted

    def format(
        self,
        progress_bar: ProgressBar,
        progress: ProgressBarCounter[object],
        width: int,
    ) -> AnyFormattedText:
        # The output does not depend on the counter or the available width.
        return self.text

    def get_width(self, progress_bar: ProgressBar) -> AnyDimension:
        # Exactly as wide as the stored fragments.
        return fragment_list_width(self.text)
79
80
class Label(Formatter):
    """
    Display the name of the current task.

    :param width: If a `width` is given, use this width. Scroll the text if it
        doesn't fit in this width.
    :param suffix: String suffix to be added after the task name, e.g. ': '.
        If no task name was given, no suffix will be added.
    """

    def __init__(self, width: AnyDimension = None, suffix: str = "") -> None:
        self.width = width
        self.suffix = suffix

    def _add_suffix(self, label: AnyFormattedText) -> StyleAndTextTuples:
        """
        Convert `label` to fragments styled ``class:label`` and append the
        suffix. If the label contains no text, the suffix is omitted, as the
        class docstring promises.
        """
        fragments = to_formatted_text(label, style="class:label")
        if not any(fragment[1] for fragment in fragments):
            return fragments
        return fragments + [("", self.suffix)]

    def format(
        self,
        progress_bar: ProgressBar,
        progress: ProgressBarCounter[object],
        width: int,
    ) -> AnyFormattedText:
        label = self._add_suffix(progress.label)
        cwidth = fragment_list_width(label)

        if cwidth > width:
            # It doesn't fit -> scroll the task name. The label is exploded
            # (presumably one fragment per character) so it can be sliced.
            label = explode_text_fragments(label)
            max_scroll = cwidth - width
            # Cycle through offsets 0..max_scroll *inclusive*, three steps per
            # second. With `% max_scroll` the last offset was unreachable, so
            # the tail of the label was never shown (and a label one column
            # too wide never scrolled at all).
            current_scroll = int(time.time() * 3 % (max_scroll + 1))
            label = label[current_scroll:]

        return label

    def get_width(self, progress_bar: ProgressBar) -> AnyDimension:
        # An explicit width always wins.
        if self.width:
            return self.width

        # Otherwise size to the widest (label + suffix) among all counters.
        all_labels = [self._add_suffix(c.label) for c in progress_bar.counters]
        if all_labels:
            max_widths = max(fragment_list_width(label) for label in all_labels)
            return D(preferred=max_widths, max=max_widths)
        else:
            return D()
127
128
class Percentage(Formatter):
    """
    Display the progress as a percentage: rounded to one decimal,
    right-aligned in five columns and followed by a ``%`` sign.
    """

    template = HTML("<percentage>{percentage:>5}%</percentage>")

    def format(
        self,
        progress_bar: ProgressBar,
        progress: ProgressBarCounter[object],
        width: int,
    ) -> AnyFormattedText:
        rounded = round(progress.percentage, 1)
        return self.template.format(percentage=rounded)

    def get_width(self, progress_bar: ProgressBar) -> AnyDimension:
        # Five columns for the number, one for the "%" sign.
        return D.exact(6)
146
147
class Bar(Formatter):
    """
    Display the progress bar itself.

    :param start: Text drawn before the bar.
    :param end: Text drawn after the bar.
    :param sym_a: Single-column character filling the completed part.
    :param sym_b: Text drawn at the current position.
    :param sym_c: Single-column character filling the remaining part.
    :param unknown: Text that sweeps across the bar while the total is
        unknown and the task is still running.
    """

    template = HTML(
        "<bar>{start}<bar-a>{bar_a}</bar-a><bar-b>{bar_b}</bar-b><bar-c>{bar_c}</bar-c>{end}</bar>"
    )

    def __init__(
        self,
        start: str = "[",
        end: str = "]",
        sym_a: str = "=",
        sym_b: str = ">",
        sym_c: str = " ",
        unknown: str = "#",
    ) -> None:
        # sym_a / sym_c are repeated to fill columns, so each must be exactly
        # one character occupying exactly one column.
        assert len(sym_a) == 1 and get_cwidth(sym_a) == 1
        assert len(sym_c) == 1 and get_cwidth(sym_c) == 1

        self.start = start
        self.end = end
        self.sym_a = sym_a
        self.sym_b = sym_b
        self.sym_c = sym_c
        self.unknown = unknown

    def format(
        self,
        progress_bar: ProgressBar,
        progress: ProgressBarCounter[object],
        width: int,
    ) -> AnyFormattedText:
        if progress.done or progress.total or progress.stopped:
            sym_a, sym_b, sym_c = self.sym_a, self.sym_b, self.sym_c

            if progress.done:
                # 100% completed irrelevant of how much was actually marked as completed.
                percent = 1.0
            else:
                # Show percentage completed.
                percent = progress.percentage / 100
        else:
            # Total is unknown and bar is still running: move the `unknown`
            # marker across an empty bar, one full pass every 5 seconds.
            sym_a, sym_b, sym_c = self.sym_c, self.unknown, self.sym_c
            percent = time.time() * 20 % 100 / 100

        # Clamp: if more items were reported than `total`, the percentage
        # exceeds 100 and the filled part would overflow the allotted width.
        percent = min(max(percent, 0.0), 1.0)

        # Columns left for the filled/empty parts after start, head and end.
        width -= get_cwidth(self.start + sym_b + self.end)

        # Scale percent by width.
        pb_a = int(percent * width)
        bar_a = sym_a * pb_a
        bar_b = sym_b
        bar_c = sym_c * (width - pb_a)

        return self.template.format(
            start=self.start, end=self.end, bar_a=bar_a, bar_b=bar_b, bar_c=bar_c
        )

    def get_width(self, progress_bar: ProgressBar) -> AnyDimension:
        return D(min=9)
214
215
class Progress(Formatter):
    """
    Display the progress as ``current/total`` text, e.g. "8/20".

    ``?`` stands in for the total when it is not known.
    """

    template = HTML("<current>{current:>3}</current>/<total>{total:>3}</total>")

    def format(
        self,
        progress_bar: ProgressBar,
        progress: ProgressBarCounter[object],
        width: int,
    ) -> AnyFormattedText:
        total = progress.total or "?"
        return self.template.format(current=progress.items_completed, total=total)

    def get_width(self, progress_bar: ProgressBar) -> AnyDimension:
        # Both sides are sized to the widest right-aligned total, plus one
        # column for the "/" separator.
        widest_total = max(
            (len(f"{c.total or '?':>3}") for c in progress_bar.counters),
            default=1,
        )
        return D.exact(2 * widest_total + 1)
239
240
def _format_timedelta(timedelta: datetime.timedelta) -> str:
    """
    Render `timedelta` without fractional seconds.

    Produces ``h:mm:ss``, or ``mm:ss`` when the hour part is zero. A day
    prefix such as ``1 day, 0:00:05`` is left untouched.
    """
    head, _, _ = str(timedelta).partition(".")
    return head[2:] if head.startswith("0:") else head
249
250
class TimeElapsed(Formatter):
    """
    Display ``progress.time_elapsed``, right-aligned to the available width.
    """

    template = HTML("<time-elapsed>{time_elapsed}</time-elapsed>")

    def format(
        self,
        progress_bar: ProgressBar,
        progress: ProgressBarCounter[object],
        width: int,
    ) -> AnyFormattedText:
        elapsed = _format_timedelta(progress.time_elapsed)
        return self.template.format(time_elapsed=elapsed.rjust(width))

    def get_width(self, progress_bar: ProgressBar) -> AnyDimension:
        # Widest rendered elapsed time over all counters; 0 with no counters.
        return max(
            (len(_format_timedelta(c.time_elapsed)) for c in progress_bar.counters),
            default=0,
        )
274
275
class TimeLeft(Formatter):
    """
    Display ``progress.time_left``, right-aligned to the available width.

    When the counter provides no estimate (``time_left is None``), the
    ``unknown`` placeholder is shown instead.
    """

    template = HTML("<time-left>{time_left}</time-left>")
    unknown = "?:??:??"

    def _format_time_left(self, progress: ProgressBarCounter[object]) -> str:
        """
        Return the time left for `progress` as text, or ``self.unknown``.

        `time_left` is read only once because it is a time-dependent value.
        """
        time_left = progress.time_left
        if time_left is None:
            return self.unknown
        return _format_timedelta(time_left)

    def format(
        self,
        progress_bar: ProgressBar,
        progress: ProgressBarCounter[object],
        width: int,
    ) -> AnyFormattedText:
        text = self._format_time_left(progress)
        return self.template.format(time_left=text.rjust(width))

    def get_width(self, progress_bar: ProgressBar) -> AnyDimension:
        # Derive the placeholder width from `self.unknown` instead of a
        # hard-coded 7, so overriding `unknown` keeps the column sized right.
        return max(
            (len(self._format_time_left(c)) for c in progress_bar.counters),
            default=0,
        )
306
307
class IterationsPerSecond(Formatter):
    """
    Display the average number of completed items per second, with two
    decimals. Shows ``0.00`` while no time has elapsed yet.
    """

    template = HTML(
        "<iterations-per-second>{iterations_per_second:.2f}</iterations-per-second>"
    )

    @staticmethod
    def _iterations_per_second(progress: ProgressBarCounter[object]) -> float:
        """
        Return ``items_completed / elapsed seconds`` for `progress`.

        Guards against a zero (or non-positive) elapsed time, which would
        otherwise raise ZeroDivisionError, e.g. on the very first render.
        """
        seconds = progress.time_elapsed.total_seconds()
        if seconds <= 0:
            return 0.0
        return progress.items_completed / seconds

    def format(
        self,
        progress_bar: ProgressBar,
        progress: ProgressBarCounter[object],
        width: int,
    ) -> AnyFormattedText:
        value = self._iterations_per_second(progress)
        return self.template.format(iterations_per_second=value)

    def get_width(self, progress_bar: ProgressBar) -> AnyDimension:
        # Widest "%.2f" rendering over all counters; 0 with no counters.
        return max(
            (
                len(f"{self._iterations_per_second(c):.2f}")
                for c in progress_bar.counters
            ),
            default=0,
        )
334
335
class SpinningWheel(Formatter):
    """
    Display a one-column spinner that advances three frames per second.
    """

    template = HTML("<spinning-wheel>{0}</spinning-wheel>")
    characters = r"/-\|"

    def format(
        self,
        progress_bar: ProgressBar,
        progress: ProgressBarCounter[object],
        width: int,
    ) -> AnyFormattedText:
        frames = self.characters
        frame = frames[int(time.time() * 3) % len(frames)]
        return self.template.format(frame)

    def get_width(self, progress_bar: ProgressBar) -> AnyDimension:
        return D.exact(1)
355
356
def _hue_to_rgb(hue: float) -> tuple[int, int, int]:
    """
    Convert a hue in [0, 1] to an (r, g, b) tuple of ints.

    The hue circle is split into six sectors. In each sector one channel is
    255, one is 0, and the third ramps up or down linearly.
    """
    scaled = hue * 6.0
    sector = int(scaled)
    fraction = scaled - sector

    falling = int(255 * (1.0 - fraction))
    rising = int(255 * (1.0 - (1.0 - fraction)))

    palette = (
        (255, rising, 0),
        (falling, 255, 0),
        (0, 255, rising),
        (0, falling, 255),
        (rising, 0, 255),
        (255, 0, falling),
    )
    return palette[sector % 6]
377
378
class Rainbow(Formatter):
    """
    For the fun: wrap another formatter and paint its output in rainbow colors.

    Each character of the wrapped output gets a color from a 100-step hue
    palette. The palette shifts three steps per second, so the colors move.
    Width negotiation is delegated to the wrapped formatter.

    :param formatter: The :class:`Formatter` whose output is colored.
    """

    colors = ["#%.2x%.2x%.2x" % _hue_to_rgb(h / 100.0) for h in range(0, 100)]

    def __init__(self, formatter: Formatter) -> None:
        self.formatter = formatter

    def format(
        self,
        progress_bar: ProgressBar,
        progress: ProgressBarCounter[object],
        width: int,
    ) -> AnyFormattedText:
        # Get formatted text from nested formatter, and explode it in
        # text/style tuples.
        result = self.formatter.format(progress_bar, progress, width)
        fragments = explode_text_fragments(to_formatted_text(result))

        # Insert colors.
        colors = self.colors
        shift = int(time.time() * 3) % len(colors)

        result2: StyleAndTextTuples = []
        for i, (style, text, *rest) in enumerate(fragments):
            color = colors[(i + shift) % len(colors)]
            # Carry over any extra tuple element (e.g. a mouse handler)
            # instead of silently dropping it.
            result2.append((style + " " + color, text, *rest))  # type: ignore
        return result2

    def get_width(self, progress_bar: ProgressBar) -> AnyDimension:
        return self.formatter.get_width(progress_bar)
412
413
def create_default_formatters() -> list[Formatter]:
    """
    Build a fresh list of the default formatters.

    The row reads: label, percentage, bar, ``current/total``, then the
    estimated time left wrapped in ``eta [...]``, separated by single spaces.
    """
    formatters: list[Formatter] = [
        Label(),
        Text(" "),
        Percentage(),
        Text(" "),
        Bar(),
        Text(" "),
        Progress(),
        Text(" "),
    ]
    formatters += [
        Text("eta [", style="class:time-left"),
        TimeLeft(),
        Text("]", style="class:time-left"),
        Text(" "),
    ]
    return formatters