1"""
2Progress bar implementation on top of prompt_toolkit.
3
4::
5
6 with ProgressBar(...) as pb:
7 for item in pb(data):
8 ...
9"""
10
11from __future__ import annotations
12
13import contextvars
14import datetime
15import functools
16import os
17import signal
18import threading
19import traceback
20from typing import (
21 Callable,
22 Generic,
23 Iterable,
24 Iterator,
25 Sequence,
26 Sized,
27 TextIO,
28 TypeVar,
29 cast,
30)
31
32from prompt_toolkit.application import Application
33from prompt_toolkit.application.current import get_app_session
34from prompt_toolkit.filters import Condition, is_done, renderer_height_is_known
35from prompt_toolkit.formatted_text import (
36 AnyFormattedText,
37 StyleAndTextTuples,
38 to_formatted_text,
39)
40from prompt_toolkit.input import Input
41from prompt_toolkit.key_binding import KeyBindings
42from prompt_toolkit.key_binding.key_processor import KeyPressEvent
43from prompt_toolkit.layout import (
44 ConditionalContainer,
45 FormattedTextControl,
46 HSplit,
47 Layout,
48 VSplit,
49 Window,
50)
51from prompt_toolkit.layout.controls import UIContent, UIControl
52from prompt_toolkit.layout.dimension import AnyDimension, D
53from prompt_toolkit.output import ColorDepth, Output
54from prompt_toolkit.styles import BaseStyle
55from prompt_toolkit.utils import in_main_thread
56
57from .formatters import Formatter, create_default_formatters
58
# Names exported by `from <this module> import *`.
__all__ = ["ProgressBar"]

# Short alias for the key press event type, used in key binding handlers.
E = KeyPressEvent

# `signal.SIGWINCH` when the `signal` module defines it, otherwise `None`.
_SIGWINCH = getattr(signal, "SIGWINCH", None)
64
65
def create_key_bindings(cancel_callback: Callable[[], None] | None) -> KeyBindings:
    """
    Build the key bindings that the progress bar handles itself.
    (The main thread is not supposed to handle any key bindings.)

    Control-L clears the renderer. Control-C calls `cancel_callback` and is
    only bound when a callback was given.

    :param cancel_callback: Callable invoked on control-c, or `None` to leave
        control-c unbound.
    :return: A new `KeyBindings` instance.
    """
    bindings = KeyBindings()

    def _clear(event: E) -> None:
        event.app.renderer.clear()

    bindings.add("c-l")(_clear)

    # Without a callback there is nothing to bind for control-c.
    if cancel_callback is None:
        return bindings

    # Bind to a non-optional local so the handler needs no further checks.
    on_cancel: Callable[[], None] = cancel_callback

    def _interrupt(event: E) -> None:
        "Kill the 'body' of the progress bar, but only if we run from the main thread."
        on_cancel()

    bindings.add("c-c")(_interrupt)

    return bindings
86
87
# Item type of the iterable passed to `ProgressBar.__call__`.
_T = TypeVar("_T")
89
90
class ProgressBar:
    """
    Progress bar context manager.

    Usage ::

        with ProgressBar(...) as pb:
            for item in pb(data):
                ...

    :param title: Text to be displayed above the progress bars. This can be a
        callable or formatted text as well.
    :param formatters: List of :class:`.Formatter` instances.
    :param bottom_toolbar: Text to be displayed in the bottom toolbar. This
        can be a callable or formatted text.
    :param style: :class:`prompt_toolkit.styles.BaseStyle` instance.
    :param key_bindings: :class:`.KeyBindings` instance.
    :param cancel_callback: Callback function that's called when control-c is
        pressed by the user. This can be used for instance to start "proper"
        cancellation if the wrapped code supports it.
    :param file: The file object used for rendering. Only used when no
        `output` is given; an output is then created for this file. When
        neither `file` nor `output` is given, the output of the current app
        session is used.

    :param color_depth: `prompt_toolkit` `ColorDepth` instance.
    :param output: :class:`~prompt_toolkit.output.Output` instance.
    :param input: :class:`~prompt_toolkit.input.Input` instance.
    """

    def __init__(
        self,
        title: AnyFormattedText = None,
        formatters: Sequence[Formatter] | None = None,
        bottom_toolbar: AnyFormattedText = None,
        style: BaseStyle | None = None,
        key_bindings: KeyBindings | None = None,
        cancel_callback: Callable[[], None] | None = None,
        file: TextIO | None = None,
        color_depth: ColorDepth | None = None,
        output: Output | None = None,
        input: Input | None = None,
    ) -> None:
        self.title = title
        self.formatters = formatters or create_default_formatters()
        self.bottom_toolbar = bottom_toolbar
        self.counters: list[ProgressBarCounter[object]] = []
        self.style = style
        self.key_bindings = key_bindings
        self.cancel_callback = cancel_callback

        # If no `cancel_callback` was given, and we're creating the progress
        # bar from the main thread. Cancel by sending a `KeyboardInterrupt` to
        # the main thread.
        if self.cancel_callback is None and in_main_thread():

            def keyboard_interrupt_to_main_thread() -> None:
                os.kill(os.getpid(), signal.SIGINT)

            self.cancel_callback = keyboard_interrupt_to_main_thread

        # Honour the `file` argument: when the caller passed a file but no
        # explicit output, render to that file instead of silently ignoring it.
        if output is None and file is not None:
            from prompt_toolkit.output import create_output

            output = create_output(stdout=file)

        # Fall back to the input/output of the current app session.
        self.color_depth = color_depth
        self.output = output or get_app_session().output
        self.input = input or get_app_session().input

        # Thread in which the UI application runs (created in `__enter__`).
        self._thread: threading.Thread | None = None

        self._has_sigwinch = False
        # Set once the UI application has started (or has failed to start).
        self._app_started = threading.Event()

    def __enter__(self) -> ProgressBar:
        """
        Build the UI application and start running it in a separate thread.

        :return: This progress bar.
        """
        # Create UI Application.
        title_toolbar = ConditionalContainer(
            Window(
                FormattedTextControl(lambda: self.title),
                height=1,
                style="class:progressbar,title",
            ),
            filter=Condition(lambda: self.title is not None),
        )

        bottom_toolbar = ConditionalContainer(
            Window(
                FormattedTextControl(
                    lambda: self.bottom_toolbar, style="class:bottom-toolbar.text"
                ),
                style="class:bottom-toolbar",
                height=1,
            ),
            filter=~is_done
            & renderer_height_is_known
            & Condition(lambda: self.bottom_toolbar is not None),
        )

        def width_for_formatter(formatter: Formatter) -> AnyDimension:
            # Needs to be passed as callable (partial) to the 'width'
            # parameter, because we want to call it on every resize.
            return formatter.get_width(progress_bar=self)

        # One column (window) per formatter.
        progress_controls = [
            Window(
                content=_ProgressControl(self, f, self.cancel_callback),
                width=functools.partial(width_for_formatter, f),
            )
            for f in self.formatters
        ]

        self.app: Application[None] = Application(
            min_redraw_interval=0.05,
            layout=Layout(
                HSplit(
                    [
                        title_toolbar,
                        VSplit(
                            progress_controls,
                            # One row per counter.
                            height=lambda: D(
                                preferred=len(self.counters), max=len(self.counters)
                            ),
                        ),
                        Window(),
                        bottom_toolbar,
                    ]
                )
            ),
            style=self.style,
            key_bindings=self.key_bindings,
            refresh_interval=0.3,
            color_depth=self.color_depth,
            output=self.output,
            input=self.input,
        )

        # Run application in different thread.
        def run() -> None:
            try:
                self.app.run(pre_run=self._app_started.set)
            except BaseException as e:
                traceback.print_exc()
                print(e)
            finally:
                # If the application failed before `pre_run` was called, the
                # event would never be set and `__exit__` would wait forever.
                self._app_started.set()

        # Run the thread in a copy of the current context, so that it sees the
        # context variables of the thread that created the progress bar.
        ctx: contextvars.Context = contextvars.copy_context()

        self._thread = threading.Thread(target=ctx.run, args=(run,))
        self._thread.start()

        return self

    def __exit__(self, *a: object) -> None:
        """
        Stop the UI application and wait for its thread to finish.
        """
        # Wait for the app to be started. Make sure we don't quit earlier,
        # otherwise `self.app.exit` won't terminate the app because
        # `self.app.future` has not yet been set.
        self._app_started.wait()

        # Quit UI application.
        if self.app.is_running and self.app.loop is not None:
            self.app.loop.call_soon_threadsafe(self.app.exit)

        if self._thread is not None:
            self._thread.join()

    def __call__(
        self,
        data: Iterable[_T] | None = None,
        label: AnyFormattedText = "",
        remove_when_done: bool = False,
        total: int | None = None,
    ) -> ProgressBarCounter[_T]:
        """
        Start a new counter.

        :param data: Iterable to loop over, or `None` when progress is
            reported manually through the returned counter.
        :param label: Title text or description for this progress. (This can be
            formatted text as well).
        :param remove_when_done: When `True`, hide this progress bar.
        :param total: Specify the maximum value if it can't be calculated by
            calling ``len``.
        :return: The new counter, which has been added to `self.counters`.
        """
        counter = ProgressBarCounter(
            self, data, label=label, remove_when_done=remove_when_done, total=total
        )
        self.counters.append(counter)
        return counter

    def invalidate(self) -> None:
        """Invalidate the UI application (`self.app`)."""
        self.app.invalidate()
274
275
class _ProgressControl(UIControl):
    """
    User control that displays the output of one formatter, with one line
    for every counter of the progress bar.
    """

    def __init__(
        self,
        progress_bar: ProgressBar,
        formatter: Formatter,
        cancel_callback: Callable[[], None] | None,
    ) -> None:
        self.progress_bar = progress_bar
        self.formatter = formatter
        self._key_bindings = create_key_bindings(cancel_callback)

    def create_content(self, width: int, height: int) -> UIContent:
        """
        Format every counter with this control's formatter.

        A formatter that raises is reported through `traceback.print_exc()`
        and shown as the text "ERROR" for that counter.
        """

        def render(counter: ProgressBarCounter[object]) -> StyleAndTextTuples:
            try:
                rendered = self.formatter.format(self.progress_bar, counter, width)
            except BaseException:
                traceback.print_exc()
                rendered = "ERROR"
            return to_formatted_text(rendered)

        rows: list[StyleAndTextTuples] = [
            render(counter) for counter in self.progress_bar.counters
        ]

        return UIContent(
            get_line=lambda index: rows[index],
            line_count=len(rows),
            show_cursor=False,
        )

    def is_focusable(self) -> bool:
        # Focusable, to make sure that the key bindings work.
        return True

    def get_key_bindings(self) -> KeyBindings:
        return self._key_bindings
313
314
# Type of the items yielded when iterating over a counter.
_CounterItem = TypeVar("_CounterItem", covariant=True)


class ProgressBarCounter(Generic[_CounterItem]):
    """
    An individual counter (A progress bar can have multiple counters).

    :param progress_bar: The progress bar this counter belongs to. It is
        invalidated whenever an item completes, and the counter removes
        itself from its `counters` list when done and `remove_when_done`
        is set.
    :param data: Iterable to loop over, or `None` when `item_completed` is
        called manually.
    :param label: Title text or description for this counter.
    :param remove_when_done: When `True`, remove this counter from the
        progress bar once it is done.
    :param total: Number of items. When `None`, ``len(data)`` is used if
        available; otherwise the total is unknown (`None`).
    """

    def __init__(
        self,
        progress_bar: ProgressBar,
        data: Iterable[_CounterItem] | None = None,
        label: AnyFormattedText = "",
        remove_when_done: bool = False,
        total: int | None = None,
    ) -> None:
        self.start_time = datetime.datetime.now()
        self.stop_time: datetime.datetime | None = None
        self.progress_bar = progress_bar
        self.data = data
        self.items_completed = 0
        self.label = label
        self.remove_when_done = remove_when_done
        self._done = False
        self.total: int | None

        if total is None:
            try:
                self.total = len(cast(Sized, data))
            except TypeError:
                self.total = None  # We don't know the total length.
        else:
            self.total = total

    def __iter__(self) -> Iterator[_CounterItem]:
        """
        Yield the items of `data`, counting each one as completed after the
        consumer has handled it.

        :raises NotImplementedError: When no `data` was given. (Raised when
            the iteration starts, because this is a generator.)
        """
        if self.data is None:
            raise NotImplementedError("No data defined to iterate over.")

        try:
            for item in self.data:
                yield item
                self.item_completed()

            # Only done if we iterate to the very end.
            self.done = True
        finally:
            # Ensure counter has stopped even if we did not iterate to the
            # end (e.g. break or exceptions).
            self.stopped = True

    def item_completed(self) -> None:
        """
        Start handling the next item.

        (Can be called manually in case we don't have a collection to loop through.)
        """
        self.items_completed += 1
        self.progress_bar.invalidate()

    @property
    def done(self) -> bool:
        """Whether a counter has been completed.

        Done counter have been stopped (see stopped) and removed depending on
        remove_when_done value.

        Contrast this with stopped. A stopped counter may be terminated before
        100% completion. A done counter has reached its 100% completion.
        """
        return self._done

    @done.setter
    def done(self, value: bool) -> None:
        self._done = value
        self.stopped = value

        if value and self.remove_when_done:
            # The counter may already have been removed (e.g. `done` was set
            # before) or may never have been added to the list. Marking a
            # counter as done should not fail in either case.
            try:
                self.progress_bar.counters.remove(self)
            except ValueError:
                pass

    @property
    def stopped(self) -> bool:
        """Whether a counter has been stopped.

        Stopped counters no longer have increasing time_elapsed. This distinction is
        also used to prevent the Bar formatter with unknown totals from continuing to run.

        A stopped counter (but not done) can be used to signal that a given counter has
        encountered an error but allows other counters to continue
        (e.g. download X of Y failed). Given how only done counters are removed
        (see remove_when_done) this can help aggregate failures from a large number of
        successes.

        Contrast this with done. A done counter has reached its 100% completion.
        A stopped counter may be terminated before 100% completion.
        """
        return self.stop_time is not None

    @stopped.setter
    def stopped(self, value: bool) -> None:
        if value:
            # Keep the original stop time if the counter was stopped before.
            if self.stop_time is None:
                self.stop_time = datetime.datetime.now()
        else:
            # Clearing any previously set stop_time.
            self.stop_time = None

    @property
    def percentage(self) -> float:
        """
        Completed items as a percentage of `total`; 0 when the total is
        unknown. A total below 1 is treated as 1 to avoid dividing by zero.
        """
        if self.total is None:
            return 0
        else:
            return self.items_completed * 100 / max(self.total, 1)

    @property
    def time_elapsed(self) -> datetime.timedelta:
        """
        Return how much time has been elapsed since the start.

        For a stopped counter this is the time between start and stop.
        """
        if self.stop_time is None:
            return datetime.datetime.now() - self.start_time
        else:
            return self.stop_time - self.start_time

    @property
    def time_left(self) -> datetime.timedelta | None:
        """
        Timedelta representing the time left.

        `None` when the total is unknown or nothing has been completed yet,
        zero when the counter is done or stopped. Otherwise the elapsed time
        is extrapolated linearly; the result is never negative, also not when
        more items were completed than `total`.
        """
        percentage = self.percentage

        if self.total is None or not percentage:
            return None
        elif self.done or self.stopped:
            return datetime.timedelta(0)
        else:
            remaining = max(100 - percentage, 0)
            return self.time_elapsed * remaining / percentage