1""" 
    2Formatter classes for the progress bar. 
    3Each progress bar consists of a list of these formatters. 
    4""" 
    5 
    6from __future__ import annotations 
    7 
    8import datetime 
    9import time 
    10from abc import ABCMeta, abstractmethod 
    11from typing import TYPE_CHECKING 
    12 
    13from prompt_toolkit.formatted_text import ( 
    14    HTML, 
    15    AnyFormattedText, 
    16    StyleAndTextTuples, 
    17    to_formatted_text, 
    18) 
    19from prompt_toolkit.formatted_text.utils import fragment_list_width 
    20from prompt_toolkit.layout.dimension import AnyDimension, D 
    21from prompt_toolkit.layout.utils import explode_text_fragments 
    22from prompt_toolkit.utils import get_cwidth 
    23 
    24if TYPE_CHECKING: 
    25    from .base import ProgressBar, ProgressBarCounter 
    26 
# Names exported by a star-import of this module: the formatter base class,
# the concrete formatters defined below, and the default-formatter factory.
__all__ = [
    "Formatter",
    "Text",
    "Label",
    "Percentage",
    "Bar",
    "Progress",
    "TimeElapsed",
    "TimeLeft",
    "IterationsPerSecond",
    "SpinningWheel",
    "Rainbow",
    "create_default_formatters",
]
    41 
    42 
    43class Formatter(metaclass=ABCMeta): 
    44    """ 
    45    Base class for any formatter. 
    46    """ 
    47 
    48    @abstractmethod 
    49    def format( 
    50        self, 
    51        progress_bar: ProgressBar, 
    52        progress: ProgressBarCounter[object], 
    53        width: int, 
    54    ) -> AnyFormattedText: 
    55        pass 
    56 
    57    def get_width(self, progress_bar: ProgressBar) -> AnyDimension: 
    58        return D() 
    59 
    60 
    61class Text(Formatter): 
    62    """ 
    63    Display plain text. 
    64    """ 
    65 
    66    def __init__(self, text: AnyFormattedText, style: str = "") -> None: 
    67        self.text = to_formatted_text(text, style=style) 
    68 
    69    def format( 
    70        self, 
    71        progress_bar: ProgressBar, 
    72        progress: ProgressBarCounter[object], 
    73        width: int, 
    74    ) -> AnyFormattedText: 
    75        return self.text 
    76 
    77    def get_width(self, progress_bar: ProgressBar) -> AnyDimension: 
    78        return fragment_list_width(self.text) 
    79 
    80 
    81class Label(Formatter): 
    82    """ 
    83    Display the name of the current task. 
    84 
    85    :param width: If a `width` is given, use this width. Scroll the text if it 
    86        doesn't fit in this width. 
    87    :param suffix: String suffix to be added after the task name, e.g. ': '. 
    88        If no task name was given, no suffix will be added. 
    89    """ 
    90 
    91    def __init__(self, width: AnyDimension = None, suffix: str = "") -> None: 
    92        self.width = width 
    93        self.suffix = suffix 
    94 
    95    def _add_suffix(self, label: AnyFormattedText) -> StyleAndTextTuples: 
    96        label = to_formatted_text(label, style="class:label") 
    97        return label + [("", self.suffix)] 
    98 
    99    def format( 
    100        self, 
    101        progress_bar: ProgressBar, 
    102        progress: ProgressBarCounter[object], 
    103        width: int, 
    104    ) -> AnyFormattedText: 
    105        label = self._add_suffix(progress.label) 
    106        cwidth = fragment_list_width(label) 
    107 
    108        if cwidth > width: 
    109            # It doesn't fit -> scroll task name. 
    110            label = explode_text_fragments(label) 
    111            max_scroll = cwidth - width 
    112            current_scroll = int(time.time() * 3 % max_scroll) 
    113            label = label[current_scroll:] 
    114 
    115        return label 
    116 
    117    def get_width(self, progress_bar: ProgressBar) -> AnyDimension: 
    118        if self.width: 
    119            return self.width 
    120 
    121        all_labels = [self._add_suffix(c.label) for c in progress_bar.counters] 
    122        if all_labels: 
    123            max_widths = max(fragment_list_width(l) for l in all_labels) 
    124            return D(preferred=max_widths, max=max_widths) 
    125        else: 
    126            return D() 
    127 
    128 
    129class Percentage(Formatter): 
    130    """ 
    131    Display the progress as a percentage. 
    132    """ 
    133 
    134    template = HTML("<percentage>{percentage:>5}%</percentage>") 
    135 
    136    def format( 
    137        self, 
    138        progress_bar: ProgressBar, 
    139        progress: ProgressBarCounter[object], 
    140        width: int, 
    141    ) -> AnyFormattedText: 
    142        return self.template.format(percentage=round(progress.percentage, 1)) 
    143 
    144    def get_width(self, progress_bar: ProgressBar) -> AnyDimension: 
    145        return D.exact(6) 
    146 
    147 
    148class Bar(Formatter): 
    149    """ 
    150    Display the progress bar itself. 
    151    """ 
    152 
    153    template = HTML( 
    154        "<bar>{start}<bar-a>{bar_a}</bar-a><bar-b>{bar_b}</bar-b><bar-c>{bar_c}</bar-c>{end}</bar>" 
    155    ) 
    156 
    157    def __init__( 
    158        self, 
    159        start: str = "[", 
    160        end: str = "]", 
    161        sym_a: str = "=", 
    162        sym_b: str = ">", 
    163        sym_c: str = " ", 
    164        unknown: str = "#", 
    165    ) -> None: 
    166        assert len(sym_a) == 1 and get_cwidth(sym_a) == 1 
    167        assert len(sym_c) == 1 and get_cwidth(sym_c) == 1 
    168 
    169        self.start = start 
    170        self.end = end 
    171        self.sym_a = sym_a 
    172        self.sym_b = sym_b 
    173        self.sym_c = sym_c 
    174        self.unknown = unknown 
    175 
    176    def format( 
    177        self, 
    178        progress_bar: ProgressBar, 
    179        progress: ProgressBarCounter[object], 
    180        width: int, 
    181    ) -> AnyFormattedText: 
    182        if progress.done or progress.total or progress.stopped: 
    183            sym_a, sym_b, sym_c = self.sym_a, self.sym_b, self.sym_c 
    184 
    185            # Compute pb_a based on done, total, or stopped states. 
    186            if progress.done: 
    187                # 100% completed irrelevant of how much was actually marked as completed. 
    188                percent = 1.0 
    189            else: 
    190                # Show percentage completed. 
    191                percent = progress.percentage / 100 
    192        else: 
    193            # Total is unknown and bar is still running. 
    194            sym_a, sym_b, sym_c = self.sym_c, self.unknown, self.sym_c 
    195 
    196            # Compute percent based on the time. 
    197            percent = time.time() * 20 % 100 / 100 
    198 
    199        # Subtract left, sym_b, and right. 
    200        width -= get_cwidth(self.start + sym_b + self.end) 
    201 
    202        # Scale percent by width 
    203        pb_a = int(percent * width) 
    204        bar_a = sym_a * pb_a 
    205        bar_b = sym_b 
    206        bar_c = sym_c * (width - pb_a) 
    207 
    208        return self.template.format( 
    209            start=self.start, end=self.end, bar_a=bar_a, bar_b=bar_b, bar_c=bar_c 
    210        ) 
    211 
    212    def get_width(self, progress_bar: ProgressBar) -> AnyDimension: 
    213        return D(min=9) 
    214 
    215 
    216class Progress(Formatter): 
    217    """ 
    218    Display the progress as text.  E.g. "8/20" 
    219    """ 
    220 
    221    template = HTML("<current>{current:>3}</current>/<total>{total:>3}</total>") 
    222 
    223    def format( 
    224        self, 
    225        progress_bar: ProgressBar, 
    226        progress: ProgressBarCounter[object], 
    227        width: int, 
    228    ) -> AnyFormattedText: 
    229        return self.template.format( 
    230            current=progress.items_completed, total=progress.total or "?" 
    231        ) 
    232 
    233    def get_width(self, progress_bar: ProgressBar) -> AnyDimension: 
    234        all_lengths = [ 
    235            len("{:>3}".format(c.total or "?")) for c in progress_bar.counters 
    236        ] 
    237        all_lengths.append(1) 
    238        return D.exact(max(all_lengths) * 2 + 1) 
    239 
    240 
    241def _format_timedelta(timedelta: datetime.timedelta) -> str: 
    242    """ 
    243    Return hh:mm:ss, or mm:ss if the amount of hours is zero. 
    244    """ 
    245    result = f"{timedelta}".split(".")[0] 
    246    if result.startswith("0:"): 
    247        result = result[2:] 
    248    return result 
    249 
    250 
    251class TimeElapsed(Formatter): 
    252    """ 
    253    Display the elapsed time. 
    254    """ 
    255 
    256    template = HTML("<time-elapsed>{time_elapsed}</time-elapsed>") 
    257 
    258    def format( 
    259        self, 
    260        progress_bar: ProgressBar, 
    261        progress: ProgressBarCounter[object], 
    262        width: int, 
    263    ) -> AnyFormattedText: 
    264        text = _format_timedelta(progress.time_elapsed).rjust(width) 
    265        return self.template.format(time_elapsed=text) 
    266 
    267    def get_width(self, progress_bar: ProgressBar) -> AnyDimension: 
    268        all_values = [ 
    269            len(_format_timedelta(c.time_elapsed)) for c in progress_bar.counters 
    270        ] 
    271        if all_values: 
    272            return max(all_values) 
    273        return 0 
    274 
    275 
    276class TimeLeft(Formatter): 
    277    """ 
    278    Display the time left. 
    279    """ 
    280 
    281    template = HTML("<time-left>{time_left}</time-left>") 
    282    unknown = "?:??:??" 
    283 
    284    def format( 
    285        self, 
    286        progress_bar: ProgressBar, 
    287        progress: ProgressBarCounter[object], 
    288        width: int, 
    289    ) -> AnyFormattedText: 
    290        time_left = progress.time_left 
    291        if time_left is not None: 
    292            formatted_time_left = _format_timedelta(time_left) 
    293        else: 
    294            formatted_time_left = self.unknown 
    295 
    296        return self.template.format(time_left=formatted_time_left.rjust(width)) 
    297 
    298    def get_width(self, progress_bar: ProgressBar) -> AnyDimension: 
    299        all_values = [ 
    300            len(_format_timedelta(c.time_left)) if c.time_left is not None else 7 
    301            for c in progress_bar.counters 
    302        ] 
    303        if all_values: 
    304            return max(all_values) 
    305        return 0 
    306 
    307 
    308class IterationsPerSecond(Formatter): 
    309    """ 
    310    Display the iterations per second. 
    311    """ 
    312 
    313    template = HTML( 
    314        "<iterations-per-second>{iterations_per_second:.2f}</iterations-per-second>" 
    315    ) 
    316 
    317    def format( 
    318        self, 
    319        progress_bar: ProgressBar, 
    320        progress: ProgressBarCounter[object], 
    321        width: int, 
    322    ) -> AnyFormattedText: 
    323        value = progress.items_completed / progress.time_elapsed.total_seconds() 
    324        return self.template.format(iterations_per_second=value) 
    325 
    326    def get_width(self, progress_bar: ProgressBar) -> AnyDimension: 
    327        all_values = [ 
    328            len(f"{c.items_completed / c.time_elapsed.total_seconds():.2f}") 
    329            for c in progress_bar.counters 
    330        ] 
    331        if all_values: 
    332            return max(all_values) 
    333        return 0 
    334 
    335 
    336class SpinningWheel(Formatter): 
    337    """ 
    338    Display a spinning wheel. 
    339    """ 
    340 
    341    template = HTML("<spinning-wheel>{0}</spinning-wheel>") 
    342    characters = r"/-\|" 
    343 
    344    def format( 
    345        self, 
    346        progress_bar: ProgressBar, 
    347        progress: ProgressBarCounter[object], 
    348        width: int, 
    349    ) -> AnyFormattedText: 
    350        index = int(time.time() * 3) % len(self.characters) 
    351        return self.template.format(self.characters[index]) 
    352 
    353    def get_width(self, progress_bar: ProgressBar) -> AnyDimension: 
    354        return D.exact(1) 
    355 
    356 
    357def _hue_to_rgb(hue: float) -> tuple[int, int, int]: 
    358    """ 
    359    Take hue between 0 and 1, return (r, g, b). 
    360    """ 
    361    i = int(hue * 6.0) 
    362    f = (hue * 6.0) - i 
    363 
    364    q = int(255 * (1.0 - f)) 
    365    t = int(255 * (1.0 - (1.0 - f))) 
    366 
    367    i %= 6 
    368 
    369    return [ 
    370        (255, t, 0), 
    371        (q, 255, 0), 
    372        (0, 255, t), 
    373        (0, q, 255), 
    374        (t, 0, 255), 
    375        (255, 0, q), 
    376    ][i] 
    377 
    378 
    379class Rainbow(Formatter): 
    380    """ 
    381    For the fun. Add rainbow colors to any of the other formatters. 
    382    """ 
    383 
    384    colors = ["#%.2x%.2x%.2x" % _hue_to_rgb(h / 100.0) for h in range(0, 100)] 
    385 
    386    def __init__(self, formatter: Formatter) -> None: 
    387        self.formatter = formatter 
    388 
    389    def format( 
    390        self, 
    391        progress_bar: ProgressBar, 
    392        progress: ProgressBarCounter[object], 
    393        width: int, 
    394    ) -> AnyFormattedText: 
    395        # Get formatted text from nested formatter, and explode it in 
    396        # text/style tuples. 
    397        result = self.formatter.format(progress_bar, progress, width) 
    398        result = explode_text_fragments(to_formatted_text(result)) 
    399 
    400        # Insert colors. 
    401        result2: StyleAndTextTuples = [] 
    402        shift = int(time.time() * 3) % len(self.colors) 
    403 
    404        for i, (style, text, *_) in enumerate(result): 
    405            result2.append( 
    406                (style + " " + self.colors[(i + shift) % len(self.colors)], text) 
    407            ) 
    408        return result2 
    409 
    410    def get_width(self, progress_bar: ProgressBar) -> AnyDimension: 
    411        return self.formatter.get_width(progress_bar) 
    412 
    413 
    414def create_default_formatters() -> list[Formatter]: 
    415    """ 
    416    Return the list of default formatters. 
    417    """ 
    418    return [ 
    419        Label(), 
    420        Text(" "), 
    421        Percentage(), 
    422        Text(" "), 
    423        Bar(), 
    424        Text(" "), 
    425        Progress(), 
    426        Text(" "), 
    427        Text("eta [", style="class:time-left"), 
    428        TimeLeft(), 
    429        Text("]", style="class:time-left"), 
    430        Text(" "), 
    431    ]