1""" 
    2Progress bar implementation on top of prompt_toolkit. 
    3 
    4:: 
    5 
    6    with ProgressBar(...) as pb: 
    7        for item in pb(data): 
    8            ... 
    9""" 
    10 
    11from __future__ import annotations 
    12 
    13import contextvars 
    14import datetime 
    15import functools 
    16import os 
    17import signal 
    18import threading 
    19import traceback 
    20from typing import ( 
    21    Callable, 
    22    Generic, 
    23    Iterable, 
    24    Iterator, 
    25    Sequence, 
    26    Sized, 
    27    TextIO, 
    28    TypeVar, 
    29    cast, 
    30) 
    31 
    32from prompt_toolkit.application import Application 
    33from prompt_toolkit.application.current import get_app_session 
    34from prompt_toolkit.filters import Condition, is_done, renderer_height_is_known 
    35from prompt_toolkit.formatted_text import ( 
    36    AnyFormattedText, 
    37    StyleAndTextTuples, 
    38    to_formatted_text, 
    39) 
    40from prompt_toolkit.input import Input 
    41from prompt_toolkit.key_binding import KeyBindings 
    42from prompt_toolkit.key_binding.key_processor import KeyPressEvent 
    43from prompt_toolkit.layout import ( 
    44    ConditionalContainer, 
    45    FormattedTextControl, 
    46    HSplit, 
    47    Layout, 
    48    VSplit, 
    49    Window, 
    50) 
    51from prompt_toolkit.layout.controls import UIContent, UIControl 
    52from prompt_toolkit.layout.dimension import AnyDimension, D 
    53from prompt_toolkit.output import ColorDepth, Output 
    54from prompt_toolkit.styles import BaseStyle 
    55from prompt_toolkit.utils import in_main_thread 
    56 
    57from .formatters import Formatter, create_default_formatters 
    58 
    59__all__ = ["ProgressBar"] 
    60 
    61E = KeyPressEvent 
    62 
    63_SIGWINCH = getattr(signal, "SIGWINCH", None) 
    64 
    65 
    66def create_key_bindings(cancel_callback: Callable[[], None] | None) -> KeyBindings: 
    67    """ 
    68    Key bindings handled by the progress bar. 
    69    (The main thread is not supposed to handle any key bindings.) 
    70    """ 
    71    kb = KeyBindings() 
    72 
    73    @kb.add("c-l") 
    74    def _clear(event: E) -> None: 
    75        event.app.renderer.clear() 
    76 
    77    if cancel_callback is not None: 
    78 
    79        @kb.add("c-c") 
    80        def _interrupt(event: E) -> None: 
    81            "Kill the 'body' of the progress bar, but only if we run from the main thread." 
    82            assert cancel_callback is not None 
    83            cancel_callback() 
    84 
    85    return kb 
    86 
    87 
    88_T = TypeVar("_T") 
    89 
    90 
    91class ProgressBar: 
    92    """ 
    93    Progress bar context manager. 
    94 
    95    Usage :: 
    96 
    97        with ProgressBar(...) as pb: 
    98            for item in pb(data): 
    99                ... 
    100 
    101    :param title: Text to be displayed above the progress bars. This can be a 
    102        callable or formatted text as well. 
    103    :param formatters: List of :class:`.Formatter` instances. 
    104    :param bottom_toolbar: Text to be displayed in the bottom toolbar. This 
    105        can be a callable or formatted text. 
    106    :param style: :class:`prompt_toolkit.styles.BaseStyle` instance. 
    107    :param key_bindings: :class:`.KeyBindings` instance. 
    108    :param cancel_callback: Callback function that's called when control-c is 
    109        pressed by the user. This can be used for instance to start "proper" 
    110        cancellation if the wrapped code supports it. 
    111    :param file: The file object used for rendering, by default `sys.stderr` is used. 
    112 
    113    :param color_depth: `prompt_toolkit` `ColorDepth` instance. 
    114    :param output: :class:`~prompt_toolkit.output.Output` instance. 
    115    :param input: :class:`~prompt_toolkit.input.Input` instance. 
    116    """ 
    117 
    118    def __init__( 
    119        self, 
    120        title: AnyFormattedText = None, 
    121        formatters: Sequence[Formatter] | None = None, 
    122        bottom_toolbar: AnyFormattedText = None, 
    123        style: BaseStyle | None = None, 
    124        key_bindings: KeyBindings | None = None, 
    125        cancel_callback: Callable[[], None] | None = None, 
    126        file: TextIO | None = None, 
    127        color_depth: ColorDepth | None = None, 
    128        output: Output | None = None, 
    129        input: Input | None = None, 
    130    ) -> None: 
    131        self.title = title 
    132        self.formatters = formatters or create_default_formatters() 
    133        self.bottom_toolbar = bottom_toolbar 
    134        self.counters: list[ProgressBarCounter[object]] = [] 
    135        self.style = style 
    136        self.key_bindings = key_bindings 
    137        self.cancel_callback = cancel_callback 
    138 
    139        # If no `cancel_callback` was given, and we're creating the progress 
    140        # bar from the main thread. Cancel by sending a `KeyboardInterrupt` to 
    141        # the main thread. 
    142        if self.cancel_callback is None and in_main_thread(): 
    143 
    144            def keyboard_interrupt_to_main_thread() -> None: 
    145                os.kill(os.getpid(), signal.SIGINT) 
    146 
    147            self.cancel_callback = keyboard_interrupt_to_main_thread 
    148 
    149        # Note that we use __stderr__ as default error output, because that 
    150        # works best with `patch_stdout`. 
    151        self.color_depth = color_depth 
    152        self.output = output or get_app_session().output 
    153        self.input = input or get_app_session().input 
    154 
    155        self._thread: threading.Thread | None = None 
    156 
    157        self._has_sigwinch = False 
    158        self._app_started = threading.Event() 
    159 
    160    def __enter__(self) -> ProgressBar: 
    161        # Create UI Application. 
    162        title_toolbar = ConditionalContainer( 
    163            Window( 
    164                FormattedTextControl(lambda: self.title), 
    165                height=1, 
    166                style="class:progressbar,title", 
    167            ), 
    168            filter=Condition(lambda: self.title is not None), 
    169        ) 
    170 
    171        bottom_toolbar = ConditionalContainer( 
    172            Window( 
    173                FormattedTextControl( 
    174                    lambda: self.bottom_toolbar, style="class:bottom-toolbar.text" 
    175                ), 
    176                style="class:bottom-toolbar", 
    177                height=1, 
    178            ), 
    179            filter=~is_done 
    180            & renderer_height_is_known 
    181            & Condition(lambda: self.bottom_toolbar is not None), 
    182        ) 
    183 
    184        def width_for_formatter(formatter: Formatter) -> AnyDimension: 
    185            # Needs to be passed as callable (partial) to the 'width' 
    186            # parameter, because we want to call it on every resize. 
    187            return formatter.get_width(progress_bar=self) 
    188 
    189        progress_controls = [ 
    190            Window( 
    191                content=_ProgressControl(self, f, self.cancel_callback), 
    192                width=functools.partial(width_for_formatter, f), 
    193            ) 
    194            for f in self.formatters 
    195        ] 
    196 
    197        self.app: Application[None] = Application( 
    198            min_redraw_interval=0.05, 
    199            layout=Layout( 
    200                HSplit( 
    201                    [ 
    202                        title_toolbar, 
    203                        VSplit( 
    204                            progress_controls, 
    205                            height=lambda: D( 
    206                                preferred=len(self.counters), max=len(self.counters) 
    207                            ), 
    208                        ), 
    209                        Window(), 
    210                        bottom_toolbar, 
    211                    ] 
    212                ) 
    213            ), 
    214            style=self.style, 
    215            key_bindings=self.key_bindings, 
    216            refresh_interval=0.3, 
    217            color_depth=self.color_depth, 
    218            output=self.output, 
    219            input=self.input, 
    220        ) 
    221 
    222        # Run application in different thread. 
    223        def run() -> None: 
    224            try: 
    225                self.app.run(pre_run=self._app_started.set) 
    226            except BaseException as e: 
    227                traceback.print_exc() 
    228                print(e) 
    229 
    230        ctx: contextvars.Context = contextvars.copy_context() 
    231 
    232        self._thread = threading.Thread(target=ctx.run, args=(run,)) 
    233        self._thread.start() 
    234 
    235        return self 
    236 
    237    def __exit__(self, *a: object) -> None: 
    238        # Wait for the app to be started. Make sure we don't quit earlier, 
    239        # otherwise `self.app.exit` won't terminate the app because 
    240        # `self.app.future` has not yet been set. 
    241        self._app_started.wait() 
    242 
    243        # Quit UI application. 
    244        if self.app.is_running and self.app.loop is not None: 
    245            self.app.loop.call_soon_threadsafe(self.app.exit) 
    246 
    247        if self._thread is not None: 
    248            self._thread.join() 
    249 
    250    def __call__( 
    251        self, 
    252        data: Iterable[_T] | None = None, 
    253        label: AnyFormattedText = "", 
    254        remove_when_done: bool = False, 
    255        total: int | None = None, 
    256    ) -> ProgressBarCounter[_T]: 
    257        """ 
    258        Start a new counter. 
    259 
    260        :param label: Title text or description for this progress. (This can be 
    261            formatted text as well). 
    262        :param remove_when_done: When `True`, hide this progress bar. 
    263        :param total: Specify the maximum value if it can't be calculated by 
    264            calling ``len``. 
    265        """ 
    266        counter = ProgressBarCounter( 
    267            self, data, label=label, remove_when_done=remove_when_done, total=total 
    268        ) 
    269        self.counters.append(counter) 
    270        return counter 
    271 
    272    def invalidate(self) -> None: 
    273        self.app.invalidate() 
    274 
    275 
    276class _ProgressControl(UIControl): 
    277    """ 
    278    User control for the progress bar. 
    279    """ 
    280 
    281    def __init__( 
    282        self, 
    283        progress_bar: ProgressBar, 
    284        formatter: Formatter, 
    285        cancel_callback: Callable[[], None] | None, 
    286    ) -> None: 
    287        self.progress_bar = progress_bar 
    288        self.formatter = formatter 
    289        self._key_bindings = create_key_bindings(cancel_callback) 
    290 
    291    def create_content(self, width: int, height: int) -> UIContent: 
    292        items: list[StyleAndTextTuples] = [] 
    293 
    294        for pr in self.progress_bar.counters: 
    295            try: 
    296                text = self.formatter.format(self.progress_bar, pr, width) 
    297            except BaseException: 
    298                traceback.print_exc() 
    299                text = "ERROR" 
    300 
    301            items.append(to_formatted_text(text)) 
    302 
    303        def get_line(i: int) -> StyleAndTextTuples: 
    304            return items[i] 
    305 
    306        return UIContent(get_line=get_line, line_count=len(items), show_cursor=False) 
    307 
    308    def is_focusable(self) -> bool: 
    309        return True  # Make sure that the key bindings work. 
    310 
    311    def get_key_bindings(self) -> KeyBindings: 
    312        return self._key_bindings 
    313 
    314 
    315_CounterItem = TypeVar("_CounterItem", covariant=True) 
    316 
    317 
    318class ProgressBarCounter(Generic[_CounterItem]): 
    319    """ 
    320    An individual counter (A progress bar can have multiple counters). 
    321    """ 
    322 
    323    def __init__( 
    324        self, 
    325        progress_bar: ProgressBar, 
    326        data: Iterable[_CounterItem] | None = None, 
    327        label: AnyFormattedText = "", 
    328        remove_when_done: bool = False, 
    329        total: int | None = None, 
    330    ) -> None: 
    331        self.start_time = datetime.datetime.now() 
    332        self.stop_time: datetime.datetime | None = None 
    333        self.progress_bar = progress_bar 
    334        self.data = data 
    335        self.items_completed = 0 
    336        self.label = label 
    337        self.remove_when_done = remove_when_done 
    338        self._done = False 
    339        self.total: int | None 
    340 
    341        if total is None: 
    342            try: 
    343                self.total = len(cast(Sized, data)) 
    344            except TypeError: 
    345                self.total = None  # We don't know the total length. 
    346        else: 
    347            self.total = total 
    348 
    349    def __iter__(self) -> Iterator[_CounterItem]: 
    350        if self.data is not None: 
    351            try: 
    352                for item in self.data: 
    353                    yield item 
    354                    self.item_completed() 
    355 
    356                # Only done if we iterate to the very end. 
    357                self.done = True 
    358            finally: 
    359                # Ensure counter has stopped even if we did not iterate to the 
    360                # end (e.g. break or exceptions). 
    361                self.stopped = True 
    362        else: 
    363            raise NotImplementedError("No data defined to iterate over.") 
    364 
    365    def item_completed(self) -> None: 
    366        """ 
    367        Start handling the next item. 
    368 
    369        (Can be called manually in case we don't have a collection to loop through.) 
    370        """ 
    371        self.items_completed += 1 
    372        self.progress_bar.invalidate() 
    373 
    374    @property 
    375    def done(self) -> bool: 
    376        """Whether a counter has been completed. 
    377 
    378        Done counter have been stopped (see stopped) and removed depending on 
    379        remove_when_done value. 
    380 
    381        Contrast this with stopped. A stopped counter may be terminated before 
    382        100% completion. A done counter has reached its 100% completion. 
    383        """ 
    384        return self._done 
    385 
    386    @done.setter 
    387    def done(self, value: bool) -> None: 
    388        self._done = value 
    389        self.stopped = value 
    390 
    391        if value and self.remove_when_done: 
    392            self.progress_bar.counters.remove(self) 
    393 
    394    @property 
    395    def stopped(self) -> bool: 
    396        """Whether a counter has been stopped. 
    397 
    398        Stopped counters no longer have increasing time_elapsed. This distinction is 
    399        also used to prevent the Bar formatter with unknown totals from continuing to run. 
    400 
    401        A stopped counter (but not done) can be used to signal that a given counter has 
    402        encountered an error but allows other counters to continue 
    403        (e.g. download X of Y failed). Given how only done counters are removed 
    404        (see remove_when_done) this can help aggregate failures from a large number of 
    405        successes. 
    406 
    407        Contrast this with done. A done counter has reached its 100% completion. 
    408        A stopped counter may be terminated before 100% completion. 
    409        """ 
    410        return self.stop_time is not None 
    411 
    412    @stopped.setter 
    413    def stopped(self, value: bool) -> None: 
    414        if value: 
    415            # This counter has not already been stopped. 
    416            if not self.stop_time: 
    417                self.stop_time = datetime.datetime.now() 
    418        else: 
    419            # Clearing any previously set stop_time. 
    420            self.stop_time = None 
    421 
    422    @property 
    423    def percentage(self) -> float: 
    424        if self.total is None: 
    425            return 0 
    426        else: 
    427            return self.items_completed * 100 / max(self.total, 1) 
    428 
    429    @property 
    430    def time_elapsed(self) -> datetime.timedelta: 
    431        """ 
    432        Return how much time has been elapsed since the start. 
    433        """ 
    434        if self.stop_time is None: 
    435            return datetime.datetime.now() - self.start_time 
    436        else: 
    437            return self.stop_time - self.start_time 
    438 
    439    @property 
    440    def time_left(self) -> datetime.timedelta | None: 
    441        """ 
    442        Timedelta representing the time left. 
    443        """ 
    444        if self.total is None or not self.percentage: 
    445            return None 
    446        elif self.done or self.stopped: 
    447            return datetime.timedelta(0) 
    448        else: 
    449            return self.time_elapsed * (100 - self.percentage) / self.percentage