1"""
2Progress bar implementation on top of prompt_toolkit.
3
4::
5
6 with ProgressBar(...) as pb:
7 for item in pb(data):
8 ...
9"""
10
11from __future__ import annotations
12
13import contextvars
14import datetime
15import functools
16import os
17import signal
18import threading
19import traceback
20from collections.abc import Callable, Iterable, Iterator, Sequence, Sized
21from typing import (
22 Generic,
23 TextIO,
24 TypeVar,
25 cast,
26)
27
28from prompt_toolkit.application import Application
29from prompt_toolkit.application.current import get_app_session
30from prompt_toolkit.filters import Condition, is_done, renderer_height_is_known
31from prompt_toolkit.formatted_text import (
32 AnyFormattedText,
33 StyleAndTextTuples,
34 to_formatted_text,
35)
36from prompt_toolkit.input import Input
37from prompt_toolkit.key_binding import KeyBindings
38from prompt_toolkit.key_binding.key_processor import KeyPressEvent
39from prompt_toolkit.layout import (
40 ConditionalContainer,
41 FormattedTextControl,
42 HSplit,
43 Layout,
44 VSplit,
45 Window,
46)
47from prompt_toolkit.layout.controls import UIContent, UIControl
48from prompt_toolkit.layout.dimension import AnyDimension, D
49from prompt_toolkit.output import ColorDepth, Output
50from prompt_toolkit.styles import BaseStyle
51from prompt_toolkit.utils import in_main_thread
52
53from .formatters import Formatter, create_default_formatters
54
__all__ = ["ProgressBar"]

# Short alias used to annotate the key-binding handlers in this module.
E = KeyPressEvent

# `signal.SIGWINCH` is not defined on every platform (it is Unix-only), hence
# the `getattr` fallback to None.
# NOTE(review): not referenced anywhere in the visible code — possibly dead.
_SIGWINCH = getattr(signal, "SIGWINCH", None)
60
61
def create_key_bindings(cancel_callback: Callable[[], None] | None) -> KeyBindings:
    """
    Return the key bindings that the progress bar's UI reacts to.

    * ``c-l`` clears the renderer so that the screen gets redrawn.
    * ``c-c`` invokes `cancel_callback`; it is only bound when a callback
      is given.

    (The main thread is not supposed to handle any key bindings.)
    """
    bindings = KeyBindings()

    def clear_screen(event: E) -> None:
        event.app.renderer.clear()

    bindings.add("c-l")(clear_screen)

    # Without a callback there is nothing to do on control-c.
    if cancel_callback is None:
        return bindings

    def interrupt(event: E) -> None:
        "Kill the 'body' of the progress bar, but only if we run from the main thread."
        # Narrows the captured Optional for type checkers.
        assert cancel_callback is not None
        cancel_callback()

    bindings.add("c-c")(interrupt)
    return bindings
82
83
# Item type of the iterable handed to `ProgressBar.__call__`; the returned
# counter yields items of the same type.
_T = TypeVar("_T")
85
86
class ProgressBar:
    """
    Progress bar context manager.

    Usage ::

        with ProgressBar(...) as pb:
            for item in pb(data):
                ...

    The UI runs as a prompt_toolkit :class:`Application` in a separate thread.
    ``__enter__`` starts that thread; ``__exit__`` asks the application to
    exit and joins the thread.

    :param title: Text to be displayed above the progress bars. This can be a
        callable or formatted text as well.
    :param formatters: List of :class:`.Formatter` instances. Defaults to
        ``create_default_formatters()``.
    :param bottom_toolbar: Text to be displayed in the bottom toolbar. This
        can be a callable or formatted text.
    :param style: :class:`prompt_toolkit.styles.BaseStyle` instance.
    :param key_bindings: :class:`.KeyBindings` instance.
    :param cancel_callback: Callback function that's called when control-c is
        pressed by the user. This can be used for instance to start "proper"
        cancellation if the wrapped code supports it. When omitted and the
        progress bar is created from the main thread, control-c sends SIGINT
        to the current process instead.
    :param file: Accepted for backwards compatibility, but currently not used:
        rendering always goes to `output`. Pass an
        :class:`~prompt_toolkit.output.Output` to redirect the progress bar.
    :param color_depth: `prompt_toolkit` `ColorDepth` instance.
    :param output: :class:`~prompt_toolkit.output.Output` instance. Defaults
        to the output of the current app session.
    :param input: :class:`~prompt_toolkit.input.Input` instance. Defaults to
        the input of the current app session.
    """

    def __init__(
        self,
        title: AnyFormattedText = None,
        formatters: Sequence[Formatter] | None = None,
        bottom_toolbar: AnyFormattedText = None,
        style: BaseStyle | None = None,
        key_bindings: KeyBindings | None = None,
        cancel_callback: Callable[[], None] | None = None,
        file: TextIO | None = None,
        color_depth: ColorDepth | None = None,
        output: Output | None = None,
        input: Input | None = None,
    ) -> None:
        self.title = title
        self.formatters = formatters or create_default_formatters()
        self.bottom_toolbar = bottom_toolbar
        self.counters: list[ProgressBarCounter[object]] = []
        self.style = style
        self.key_bindings = key_bindings
        self.cancel_callback = cancel_callback

        # If no `cancel_callback` was given and we're creating the progress
        # bar from the main thread, cancel by sending SIGINT to this process,
        # which surfaces as a `KeyboardInterrupt` in the main thread.
        if self.cancel_callback is None and in_main_thread():

            def keyboard_interrupt_to_main_thread() -> None:
                os.kill(os.getpid(), signal.SIGINT)

            self.cancel_callback = keyboard_interrupt_to_main_thread

        # Render to the given `output`, or else to the current app session's
        # output. (The `file` argument is not consulted here.)
        self.color_depth = color_depth
        self.output = output or get_app_session().output
        self.input = input or get_app_session().input

        self._thread: threading.Thread | None = None

        self._has_sigwinch = False
        # Set once the UI application has started, or once its thread gave up
        # (see `run` in `__enter__`). `__exit__` waits on it.
        self._app_started = threading.Event()

    def __enter__(self) -> ProgressBar:
        """Build the UI application and start running it in a new thread."""
        # Create UI Application.
        title_toolbar = ConditionalContainer(
            Window(
                FormattedTextControl(lambda: self.title),
                height=1,
                style="class:progressbar,title",
            ),
            filter=Condition(lambda: self.title is not None),
        )

        bottom_toolbar = ConditionalContainer(
            Window(
                FormattedTextControl(
                    lambda: self.bottom_toolbar, style="class:bottom-toolbar.text"
                ),
                style="class:bottom-toolbar",
                height=1,
            ),
            filter=~is_done
            & renderer_height_is_known
            & Condition(lambda: self.bottom_toolbar is not None),
        )

        def width_for_formatter(formatter: Formatter) -> AnyDimension:
            # Needs to be passed as callable (partial) to the 'width'
            # parameter, because we want to call it on every resize.
            return formatter.get_width(progress_bar=self)

        # One column per formatter; each column renders one line per counter.
        progress_controls = [
            Window(
                content=_ProgressControl(self, f, self.cancel_callback),
                width=functools.partial(width_for_formatter, f),
            )
            for f in self.formatters
        ]

        self.app: Application[None] = Application(
            min_redraw_interval=0.05,
            layout=Layout(
                HSplit(
                    [
                        title_toolbar,
                        VSplit(
                            progress_controls,
                            height=lambda: D(
                                preferred=len(self.counters), max=len(self.counters)
                            ),
                        ),
                        Window(),
                        bottom_toolbar,
                    ]
                )
            ),
            style=self.style,
            key_bindings=self.key_bindings,
            refresh_interval=0.3,
            color_depth=self.color_depth,
            output=self.output,
            input=self.input,
        )

        # Run application in different thread.
        def run() -> None:
            try:
                self.app.run(pre_run=self._app_started.set)
            except BaseException as e:
                traceback.print_exc()
                print(e)
            finally:
                # If `app.run` failed before `pre_run` fired, the event would
                # never be set and `__exit__` would block forever on it.
                # Setting an already-set event is harmless.
                self._app_started.set()

        # Run the thread inside a copy of the current context so context
        # variables (e.g. the current app session) carry over.
        ctx: contextvars.Context = contextvars.copy_context()

        self._thread = threading.Thread(target=ctx.run, args=(run,))
        self._thread.start()

        return self

    def __exit__(self, *a: object) -> None:
        """Ask the UI application to exit and wait for its thread to finish."""
        # Wait for the app to be started. Make sure we don't quit earlier,
        # otherwise `self.app.exit` won't terminate the app because
        # `self.app.future` has not yet been set.
        self._app_started.wait()

        # Quit UI application. `exit` must run on the app's own event loop,
        # hence `call_soon_threadsafe`.
        if self.app.is_running and self.app.loop is not None:
            self.app.loop.call_soon_threadsafe(self.app.exit)

        if self._thread is not None:
            self._thread.join()

    def __call__(
        self,
        data: Iterable[_T] | None = None,
        label: AnyFormattedText = "",
        remove_when_done: bool = False,
        total: int | None = None,
    ) -> ProgressBarCounter[_T]:
        """
        Start a new counter.

        :param data: Iterable to wrap; iterating the returned counter yields
            its items while counting them.
        :param label: Title text or description for this progress. (This can be
            formatted text as well).
        :param remove_when_done: When `True`, hide this progress bar.
        :param total: Specify the maximum value if it can't be calculated by
            calling ``len``.
        :return: The new counter, which is also appended to `self.counters`.
        """
        counter = ProgressBarCounter(
            self, data, label=label, remove_when_done=remove_when_done, total=total
        )
        self.counters.append(counter)
        return counter

    def invalidate(self) -> None:
        """Request a redraw of the UI application."""
        self.app.invalidate()
270
271
class _ProgressControl(UIControl):
    """
    User control for the progress bar.

    Renders a single formatter column: one line per counter of the owning
    progress bar. It is focusable so that its key bindings (control-L and,
    when a cancel callback exists, control-C) are active.
    """

    def __init__(
        self,
        progress_bar: ProgressBar,
        formatter: Formatter,
        cancel_callback: Callable[[], None] | None,
    ) -> None:
        self.progress_bar = progress_bar
        self.formatter = formatter
        self._key_bindings = create_key_bindings(cancel_callback)

    def create_content(self, width: int, height: int) -> UIContent:
        def render(counter: ProgressBarCounter[object]) -> StyleAndTextTuples:
            # A failing formatter must not take down the whole UI: print the
            # traceback and show a placeholder instead.
            try:
                formatted = self.formatter.format(self.progress_bar, counter, width)
            except BaseException:
                traceback.print_exc()
                formatted = "ERROR"
            return to_formatted_text(formatted)

        lines = [render(counter) for counter in self.progress_bar.counters]

        return UIContent(
            get_line=lines.__getitem__, line_count=len(lines), show_cursor=False
        )

    def is_focusable(self) -> bool:
        # Focus is required for the key bindings below to be active.
        return True

    def get_key_bindings(self) -> KeyBindings:
        return self._key_bindings
309
310
# Item type yielded by a counter (covariant: a counter of a subtype can be
# used where a counter of the supertype is expected).
_CounterItem = TypeVar("_CounterItem", covariant=True)


class ProgressBarCounter(Generic[_CounterItem]):
    """
    An individual counter (A progress bar can have multiple counters).

    :param progress_bar: The owning progress bar. Its `invalidate()` is called
        for every completed item, and a counter with `remove_when_done` removes
        itself from the bar's `counters` list once it is done.
    :param data: Optional iterable to wrap. Iterating the counter yields its
        items and counts each one as completed.
    :param label: Title text or description (may be formatted text).
    :param remove_when_done: Remove this counter from the bar once done.
    :param total: Expected number of items. When omitted, ``len(data)`` is
        used if available; otherwise the total stays unknown (None).
    """

    def __init__(
        self,
        progress_bar: ProgressBar,
        data: Iterable[_CounterItem] | None = None,
        label: AnyFormattedText = "",
        remove_when_done: bool = False,
        total: int | None = None,
    ) -> None:
        self.start_time = datetime.datetime.now()
        self.stop_time: datetime.datetime | None = None
        self.progress_bar = progress_bar
        self.data = data
        self.items_completed = 0
        self.label = label
        self.remove_when_done = remove_when_done
        self._done = False
        self.total: int | None

        if total is None:
            try:
                self.total = len(cast(Sized, data))
            except TypeError:
                self.total = None  # We don't know the total length.
        else:
            self.total = total

    def __iter__(self) -> Iterator[_CounterItem]:
        """
        Yield the items of `data`, counting each as completed after it has
        been processed.

        :raises NotImplementedError: (on first iteration) when no `data` was
            given.
        """
        if self.data is not None:
            try:
                for item in self.data:
                    yield item
                    self.item_completed()

                # Only done if we iterate to the very end.
                self.done = True
            finally:
                # Ensure counter has stopped even if we did not iterate to the
                # end (e.g. break or exceptions).
                self.stopped = True
        else:
            raise NotImplementedError("No data defined to iterate over.")

    def item_completed(self) -> None:
        """
        Start handling the next item.

        (Can be called manually in case we don't have a collection to loop through.)
        """
        self.items_completed += 1
        self.progress_bar.invalidate()

    @property
    def done(self) -> bool:
        """Whether a counter has been completed.

        Done counters have been stopped (see stopped) and are removed when
        remove_when_done is set.

        Contrast this with stopped. A stopped counter may be terminated before
        100% completion. A done counter has reached its 100% completion.
        """
        return self._done

    @done.setter
    def done(self, value: bool) -> None:
        self._done = value
        self.stopped = value

        if value and self.remove_when_done:
            # Tolerate setting `done = True` more than once (e.g. manually
            # after the iteration already finished): the counter may already
            # have been removed.
            try:
                self.progress_bar.counters.remove(self)
            except ValueError:
                pass

    @property
    def stopped(self) -> bool:
        """Whether a counter has been stopped.

        Stopped counters no longer have increasing time_elapsed. This distinction is
        also used to prevent the Bar formatter with unknown totals from continuing to run.

        A stopped counter (but not done) can be used to signal that a given counter has
        encountered an error but allows other counters to continue
        (e.g. download X of Y failed). Given how only done counters are removed
        (see remove_when_done) this can help aggregate failures from a large number of
        successes.

        Contrast this with done. A done counter has reached its 100% completion.
        A stopped counter may be terminated before 100% completion.
        """
        return self.stop_time is not None

    @stopped.setter
    def stopped(self, value: bool) -> None:
        if value:
            # Keep the first stop time if this counter was already stopped.
            if self.stop_time is None:
                self.stop_time = datetime.datetime.now()
        else:
            # Clearing any previously set stop_time.
            self.stop_time = None

    @property
    def percentage(self) -> float:
        """Completion in percent, or 0 when the total is unknown."""
        if self.total is None:
            return 0
        else:
            # `max(..., 1)` avoids division by zero for `total == 0`.
            return self.items_completed * 100 / max(self.total, 1)

    @property
    def time_elapsed(self) -> datetime.timedelta:
        """
        Return how much time has been elapsed since the start.
        """
        if self.stop_time is None:
            return datetime.datetime.now() - self.start_time
        else:
            return self.stop_time - self.start_time

    @property
    def time_left(self) -> datetime.timedelta | None:
        """
        Timedelta representing the estimated time left.

        None when it can't be estimated (unknown total or no progress yet),
        zero once the counter is done or stopped.
        """
        if self.total is None or not self.percentage:
            return None
        elif self.done or self.stopped:
            return datetime.timedelta(0)
        else:
            remaining = self.time_elapsed * (100 - self.percentage) / self.percentage
            # More items than `total` may have been completed; never report a
            # negative remaining time.
            return max(remaining, datetime.timedelta(0))