1""" 
    2Key bindings registry. 
    3 
    4A `KeyBindings` object is a container that holds a list of key bindings. It has a 
    5very efficient internal data structure for checking which key bindings apply 
    6for a pressed key. 
    7 
    8Typical usage:: 
    9 
    10    kb = KeyBindings() 
    11 
    12    @kb.add(Keys.ControlX, Keys.ControlC, filter=INSERT) 
    13    def handler(event): 
    14        # Handle ControlX-ControlC key sequence. 
    15        pass 
    16 
    17It is also possible to combine multiple KeyBindings objects. We do this in the 
    18default key bindings. There are some KeyBindings objects that contain the Emacs 
    19bindings, while others contain the Vi bindings. They are merged together using 
    20`merge_key_bindings`. 
    21 
    22We also have a `ConditionalKeyBindings` object that can enable/disable a group of 
    23key bindings at once. 
    24 
    25 
It is also possible to add a filter to a function, before a key binding has
been assigned, through the `key_binding` decorator::
    28 
    29    # First define a key handler with the `filter`. 
    30    @key_binding(filter=condition) 
    31    def my_key_binding(event): 
    32        ... 
    33 
    # Later, add it to the key bindings. (`add` returns a decorator.)
    kb.add(Keys.A)(my_key_binding)
    36""" 
    37 
    38from __future__ import annotations 
    39 
    40from abc import ABCMeta, abstractmethod 
    41from inspect import isawaitable 
    42from typing import ( 
    43    TYPE_CHECKING, 
    44    Any, 
    45    Callable, 
    46    Coroutine, 
    47    Hashable, 
    48    Sequence, 
    49    Tuple, 
    50    TypeVar, 
    51    Union, 
    52    cast, 
    53) 
    54 
    55from prompt_toolkit.cache import SimpleCache 
    56from prompt_toolkit.filters import FilterOrBool, Never, to_filter 
    57from prompt_toolkit.keys import KEY_ALIASES, Keys 
    58 
    59if TYPE_CHECKING: 
    60    # Avoid circular imports. 
    61    from .key_processor import KeyPressEvent 
    62 
    63    # The only two return values for a mouse handler (and key bindings) are 
    64    # `None` and `NotImplemented`. For the type checker it's best to annotate 
    65    # this as `object`. (The consumer never expects a more specific instance: 
    66    # checking for NotImplemented can be done using `is NotImplemented`.) 
    67    NotImplementedOrNone = object 
    68    # Other non-working options are: 
    69    # * Optional[Literal[NotImplemented]] 
    70    #      --> Doesn't work, Literal can't take an Any. 
    71    # * None 
    72    #      --> Doesn't work. We can't assign the result of a function that 
    73    #          returns `None` to a variable. 
    74    # * Any 
    75    #      --> Works, but too broad. 
    76 
    77 
    78__all__ = [ 
    79    "NotImplementedOrNone", 
    80    "Binding", 
    81    "KeyBindingsBase", 
    82    "KeyBindings", 
    83    "ConditionalKeyBindings", 
    84    "merge_key_bindings", 
    85    "DynamicKeyBindings", 
    86    "GlobalOnlyKeyBindings", 
    87] 
    88 
    89# Key bindings can be regular functions or coroutines. 
    90# In both cases, if they return `NotImplemented`, the UI won't be invalidated. 
    91# This is mainly used in case of mouse move events, to prevent excessive 
    92# repainting during mouse move events. 
    93KeyHandlerCallable = Callable[ 
    94    ["KeyPressEvent"], 
    95    Union["NotImplementedOrNone", Coroutine[Any, Any, "NotImplementedOrNone"]], 
    96] 
    97 
    98 
    99class Binding: 
    100    """ 
    101    Key binding: (key sequence + handler + filter). 
    102    (Immutable binding class.) 
    103 
    104    :param record_in_macro: When True, don't record this key binding when a 
    105        macro is recorded. 
    106    """ 
    107 
    108    def __init__( 
    109        self, 
    110        keys: tuple[Keys | str, ...], 
    111        handler: KeyHandlerCallable, 
    112        filter: FilterOrBool = True, 
    113        eager: FilterOrBool = False, 
    114        is_global: FilterOrBool = False, 
    115        save_before: Callable[[KeyPressEvent], bool] = (lambda e: True), 
    116        record_in_macro: FilterOrBool = True, 
    117    ) -> None: 
    118        self.keys = keys 
    119        self.handler = handler 
    120        self.filter = to_filter(filter) 
    121        self.eager = to_filter(eager) 
    122        self.is_global = to_filter(is_global) 
    123        self.save_before = save_before 
    124        self.record_in_macro = to_filter(record_in_macro) 
    125 
    126    def call(self, event: KeyPressEvent) -> None: 
    127        result = self.handler(event) 
    128 
    129        # If the handler is a coroutine, create an asyncio task. 
    130        if isawaitable(result): 
    131            awaitable = cast(Coroutine[Any, Any, "NotImplementedOrNone"], result) 
    132 
    133            async def bg_task() -> None: 
    134                result = await awaitable 
    135                if result != NotImplemented: 
    136                    event.app.invalidate() 
    137 
    138            event.app.create_background_task(bg_task()) 
    139 
    140        elif result != NotImplemented: 
    141            event.app.invalidate() 
    142 
    143    def __repr__(self) -> str: 
    144        return ( 
    145            f"{self.__class__.__name__}(keys={self.keys!r}, handler={self.handler!r})" 
    146        ) 
    147 
    148 
# Sequence of key presses, as passed to the `KeyBindingsBase` lookup methods.
KeysTuple = Tuple[Union[Keys, str], ...]
    151 
    152 
    153class KeyBindingsBase(metaclass=ABCMeta): 
    154    """ 
    155    Interface for a KeyBindings. 
    156    """ 
    157 
    158    @property 
    159    @abstractmethod 
    160    def _version(self) -> Hashable: 
    161        """ 
    162        For cache invalidation. - This should increase every time that 
    163        something changes. 
    164        """ 
    165        return 0 
    166 
    167    @abstractmethod 
    168    def get_bindings_for_keys(self, keys: KeysTuple) -> list[Binding]: 
    169        """ 
    170        Return a list of key bindings that can handle these keys. 
    171        (This return also inactive bindings, so the `filter` still has to be 
    172        called, for checking it.) 
    173 
    174        :param keys: tuple of keys. 
    175        """ 
    176        return [] 
    177 
    178    @abstractmethod 
    179    def get_bindings_starting_with_keys(self, keys: KeysTuple) -> list[Binding]: 
    180        """ 
    181        Return a list of key bindings that handle a key sequence starting with 
    182        `keys`. (It does only return bindings for which the sequences are 
    183        longer than `keys`. And like `get_bindings_for_keys`, it also includes 
    184        inactive bindings.) 
    185 
    186        :param keys: tuple of keys. 
    187        """ 
    188        return [] 
    189 
    190    @property 
    191    @abstractmethod 
    192    def bindings(self) -> list[Binding]: 
    193        """ 
    194        List of `Binding` objects. 
    195        (These need to be exposed, so that `KeyBindings` objects can be merged 
    196        together.) 
    197        """ 
    198        return [] 
    199 
    200    # `add` and `remove` don't have to be part of this interface. 
    201 
    202 
# Type of whatever `KeyBindings.add` decorates: a plain handler or an existing
# `Binding`. The decorator returns its argument unchanged, hence the TypeVar.
T = TypeVar("T", bound=Union[KeyHandlerCallable, Binding])
    204 
    205 
    206class KeyBindings(KeyBindingsBase): 
    207    """ 
    208    A container for a set of key bindings. 
    209 
    210    Example usage:: 
    211 
    212        kb = KeyBindings() 
    213 
    214        @kb.add('c-t') 
    215        def _(event): 
    216            print('Control-T pressed') 
    217 
    218        @kb.add('c-a', 'c-b') 
    219        def _(event): 
    220            print('Control-A pressed, followed by Control-B') 
    221 
    222        @kb.add('c-x', filter=is_searching) 
    223        def _(event): 
    224            print('Control-X pressed')  # Works only if we are searching. 
    225 
    226    """ 
    227 
    228    def __init__(self) -> None: 
    229        self._bindings: list[Binding] = [] 
    230        self._get_bindings_for_keys_cache: SimpleCache[KeysTuple, list[Binding]] = ( 
    231            SimpleCache(maxsize=10000) 
    232        ) 
    233        self._get_bindings_starting_with_keys_cache: SimpleCache[ 
    234            KeysTuple, list[Binding] 
    235        ] = SimpleCache(maxsize=1000) 
    236        self.__version = 0  # For cache invalidation. 
    237 
    238    def _clear_cache(self) -> None: 
    239        self.__version += 1 
    240        self._get_bindings_for_keys_cache.clear() 
    241        self._get_bindings_starting_with_keys_cache.clear() 
    242 
    243    @property 
    244    def bindings(self) -> list[Binding]: 
    245        return self._bindings 
    246 
    247    @property 
    248    def _version(self) -> Hashable: 
    249        return self.__version 
    250 
    251    def add( 
    252        self, 
    253        *keys: Keys | str, 
    254        filter: FilterOrBool = True, 
    255        eager: FilterOrBool = False, 
    256        is_global: FilterOrBool = False, 
    257        save_before: Callable[[KeyPressEvent], bool] = (lambda e: True), 
    258        record_in_macro: FilterOrBool = True, 
    259    ) -> Callable[[T], T]: 
    260        """ 
    261        Decorator for adding a key bindings. 
    262 
    263        :param filter: :class:`~prompt_toolkit.filters.Filter` to determine 
    264            when this key binding is active. 
    265        :param eager: :class:`~prompt_toolkit.filters.Filter` or `bool`. 
    266            When True, ignore potential longer matches when this key binding is 
    267            hit. E.g. when there is an active eager key binding for Ctrl-X, 
    268            execute the handler immediately and ignore the key binding for 
    269            Ctrl-X Ctrl-E of which it is a prefix. 
    270        :param is_global: When this key bindings is added to a `Container` or 
    271            `Control`, make it a global (always active) binding. 
    272        :param save_before: Callable that takes an `Event` and returns True if 
    273            we should save the current buffer, before handling the event. 
    274            (That's the default.) 
    275        :param record_in_macro: Record these key bindings when a macro is 
    276            being recorded. (True by default.) 
    277        """ 
    278        assert keys 
    279 
    280        keys = tuple(_parse_key(k) for k in keys) 
    281 
    282        if isinstance(filter, Never): 
    283            # When a filter is Never, it will always stay disabled, so in that 
    284            # case don't bother putting it in the key bindings. It will slow 
    285            # down every key press otherwise. 
    286            def decorator(func: T) -> T: 
    287                return func 
    288 
    289        else: 
    290 
    291            def decorator(func: T) -> T: 
    292                if isinstance(func, Binding): 
    293                    # We're adding an existing Binding object. 
    294                    self.bindings.append( 
    295                        Binding( 
    296                            keys, 
    297                            func.handler, 
    298                            filter=func.filter & to_filter(filter), 
    299                            eager=to_filter(eager) | func.eager, 
    300                            is_global=to_filter(is_global) | func.is_global, 
    301                            save_before=func.save_before, 
    302                            record_in_macro=func.record_in_macro, 
    303                        ) 
    304                    ) 
    305                else: 
    306                    self.bindings.append( 
    307                        Binding( 
    308                            keys, 
    309                            cast(KeyHandlerCallable, func), 
    310                            filter=filter, 
    311                            eager=eager, 
    312                            is_global=is_global, 
    313                            save_before=save_before, 
    314                            record_in_macro=record_in_macro, 
    315                        ) 
    316                    ) 
    317                self._clear_cache() 
    318 
    319                return func 
    320 
    321        return decorator 
    322 
    323    def remove(self, *args: Keys | str | KeyHandlerCallable) -> None: 
    324        """ 
    325        Remove a key binding. 
    326 
    327        This expects either a function that was given to `add` method as 
    328        parameter or a sequence of key bindings. 
    329 
    330        Raises `ValueError` when no bindings was found. 
    331 
    332        Usage:: 
    333 
    334            remove(handler)  # Pass handler. 
    335            remove('c-x', 'c-a')  # Or pass the key bindings. 
    336        """ 
    337        found = False 
    338 
    339        if callable(args[0]): 
    340            assert len(args) == 1 
    341            function = args[0] 
    342 
    343            # Remove the given function. 
    344            for b in self.bindings: 
    345                if b.handler == function: 
    346                    self.bindings.remove(b) 
    347                    found = True 
    348 
    349        else: 
    350            assert len(args) > 0 
    351            args = cast(Tuple[Union[Keys, str]], args) 
    352 
    353            # Remove this sequence of key bindings. 
    354            keys = tuple(_parse_key(k) for k in args) 
    355 
    356            for b in self.bindings: 
    357                if b.keys == keys: 
    358                    self.bindings.remove(b) 
    359                    found = True 
    360 
    361        if found: 
    362            self._clear_cache() 
    363        else: 
    364            # No key binding found for this function. Raise ValueError. 
    365            raise ValueError(f"Binding not found: {function!r}") 
    366 
    367    # For backwards-compatibility. 
    368    add_binding = add 
    369    remove_binding = remove 
    370 
    371    def get_bindings_for_keys(self, keys: KeysTuple) -> list[Binding]: 
    372        """ 
    373        Return a list of key bindings that can handle this key. 
    374        (This return also inactive bindings, so the `filter` still has to be 
    375        called, for checking it.) 
    376 
    377        :param keys: tuple of keys. 
    378        """ 
    379 
    380        def get() -> list[Binding]: 
    381            result: list[tuple[int, Binding]] = [] 
    382 
    383            for b in self.bindings: 
    384                if len(keys) == len(b.keys): 
    385                    match = True 
    386                    any_count = 0 
    387 
    388                    for i, j in zip(b.keys, keys): 
    389                        if i != j and i != Keys.Any: 
    390                            match = False 
    391                            break 
    392 
    393                        if i == Keys.Any: 
    394                            any_count += 1 
    395 
    396                    if match: 
    397                        result.append((any_count, b)) 
    398 
    399            # Place bindings that have more 'Any' occurrences in them at the end. 
    400            result = sorted(result, key=lambda item: -item[0]) 
    401 
    402            return [item[1] for item in result] 
    403 
    404        return self._get_bindings_for_keys_cache.get(keys, get) 
    405 
    406    def get_bindings_starting_with_keys(self, keys: KeysTuple) -> list[Binding]: 
    407        """ 
    408        Return a list of key bindings that handle a key sequence starting with 
    409        `keys`. (It does only return bindings for which the sequences are 
    410        longer than `keys`. And like `get_bindings_for_keys`, it also includes 
    411        inactive bindings.) 
    412 
    413        :param keys: tuple of keys. 
    414        """ 
    415 
    416        def get() -> list[Binding]: 
    417            result = [] 
    418            for b in self.bindings: 
    419                if len(keys) < len(b.keys): 
    420                    match = True 
    421                    for i, j in zip(b.keys, keys): 
    422                        if i != j and i != Keys.Any: 
    423                            match = False 
    424                            break 
    425                    if match: 
    426                        result.append(b) 
    427            return result 
    428 
    429        return self._get_bindings_starting_with_keys_cache.get(keys, get) 
    430 
    431 
    432def _parse_key(key: Keys | str) -> str | Keys: 
    433    """ 
    434    Replace key by alias and verify whether it's a valid one. 
    435    """ 
    436    # Already a parse key? -> Return it. 
    437    if isinstance(key, Keys): 
    438        return key 
    439 
    440    # Lookup aliases. 
    441    key = KEY_ALIASES.get(key, key) 
    442 
    443    # Replace 'space' by ' ' 
    444    if key == "space": 
    445        key = " " 
    446 
    447    # Return as `Key` object when it's a special key. 
    448    try: 
    449        return Keys(key) 
    450    except ValueError: 
    451        pass 
    452 
    453    # Final validation. 
    454    if len(key) != 1: 
    455        raise ValueError(f"Invalid key: {key}") 
    456 
    457    return key 
    458 
    459 
    460def key_binding( 
    461    filter: FilterOrBool = True, 
    462    eager: FilterOrBool = False, 
    463    is_global: FilterOrBool = False, 
    464    save_before: Callable[[KeyPressEvent], bool] = (lambda event: True), 
    465    record_in_macro: FilterOrBool = True, 
    466) -> Callable[[KeyHandlerCallable], Binding]: 
    467    """ 
    468    Decorator that turn a function into a `Binding` object. This can be added 
    469    to a `KeyBindings` object when a key binding is assigned. 
    470    """ 
    471    assert save_before is None or callable(save_before) 
    472 
    473    filter = to_filter(filter) 
    474    eager = to_filter(eager) 
    475    is_global = to_filter(is_global) 
    476    save_before = save_before 
    477    record_in_macro = to_filter(record_in_macro) 
    478    keys = () 
    479 
    480    def decorator(function: KeyHandlerCallable) -> Binding: 
    481        return Binding( 
    482            keys, 
    483            function, 
    484            filter=filter, 
    485            eager=eager, 
    486            is_global=is_global, 
    487            save_before=save_before, 
    488            record_in_macro=record_in_macro, 
    489        ) 
    490 
    491    return decorator 
    492 
    493 
    494class _Proxy(KeyBindingsBase): 
    495    """ 
    496    Common part for ConditionalKeyBindings and _MergedKeyBindings. 
    497    """ 
    498 
    499    def __init__(self) -> None: 
    500        # `KeyBindings` to be synchronized with all the others. 
    501        self._bindings2: KeyBindingsBase = KeyBindings() 
    502        self._last_version: Hashable = () 
    503 
    504    def _update_cache(self) -> None: 
    505        """ 
    506        If `self._last_version` is outdated, then this should update 
    507        the version and `self._bindings2`. 
    508        """ 
    509        raise NotImplementedError 
    510 
    511    # Proxy methods to self._bindings2. 
    512 
    513    @property 
    514    def bindings(self) -> list[Binding]: 
    515        self._update_cache() 
    516        return self._bindings2.bindings 
    517 
    518    @property 
    519    def _version(self) -> Hashable: 
    520        self._update_cache() 
    521        return self._last_version 
    522 
    523    def get_bindings_for_keys(self, keys: KeysTuple) -> list[Binding]: 
    524        self._update_cache() 
    525        return self._bindings2.get_bindings_for_keys(keys) 
    526 
    527    def get_bindings_starting_with_keys(self, keys: KeysTuple) -> list[Binding]: 
    528        self._update_cache() 
    529        return self._bindings2.get_bindings_starting_with_keys(keys) 
    530 
    531 
    532class ConditionalKeyBindings(_Proxy): 
    533    """ 
    534    Wraps around a `KeyBindings`. Disable/enable all the key bindings according to 
    535    the given (additional) filter.:: 
    536 
    537        @Condition 
    538        def setting_is_true(): 
    539            return True  # or False 
    540 
    541        registry = ConditionalKeyBindings(key_bindings, setting_is_true) 
    542 
    543    When new key bindings are added to this object. They are also 
    544    enable/disabled according to the given `filter`. 
    545 
    546    :param registries: List of :class:`.KeyBindings` objects. 
    547    :param filter: :class:`~prompt_toolkit.filters.Filter` object. 
    548    """ 
    549 
    550    def __init__( 
    551        self, key_bindings: KeyBindingsBase, filter: FilterOrBool = True 
    552    ) -> None: 
    553        _Proxy.__init__(self) 
    554 
    555        self.key_bindings = key_bindings 
    556        self.filter = to_filter(filter) 
    557 
    558    def _update_cache(self) -> None: 
    559        "If the original key bindings was changed. Update our copy version." 
    560        expected_version = self.key_bindings._version 
    561 
    562        if self._last_version != expected_version: 
    563            bindings2 = KeyBindings() 
    564 
    565            # Copy all bindings from `self.key_bindings`, adding our condition. 
    566            for b in self.key_bindings.bindings: 
    567                bindings2.bindings.append( 
    568                    Binding( 
    569                        keys=b.keys, 
    570                        handler=b.handler, 
    571                        filter=self.filter & b.filter, 
    572                        eager=b.eager, 
    573                        is_global=b.is_global, 
    574                        save_before=b.save_before, 
    575                        record_in_macro=b.record_in_macro, 
    576                    ) 
    577                ) 
    578 
    579            self._bindings2 = bindings2 
    580            self._last_version = expected_version 
    581 
    582 
    583class _MergedKeyBindings(_Proxy): 
    584    """ 
    585    Merge multiple registries of key bindings into one. 
    586 
    587    This class acts as a proxy to multiple :class:`.KeyBindings` objects, but 
    588    behaves as if this is just one bigger :class:`.KeyBindings`. 
    589 
    590    :param registries: List of :class:`.KeyBindings` objects. 
    591    """ 
    592 
    593    def __init__(self, registries: Sequence[KeyBindingsBase]) -> None: 
    594        _Proxy.__init__(self) 
    595        self.registries = registries 
    596 
    597    def _update_cache(self) -> None: 
    598        """ 
    599        If one of the original registries was changed. Update our merged 
    600        version. 
    601        """ 
    602        expected_version = tuple(r._version for r in self.registries) 
    603 
    604        if self._last_version != expected_version: 
    605            bindings2 = KeyBindings() 
    606 
    607            for reg in self.registries: 
    608                bindings2.bindings.extend(reg.bindings) 
    609 
    610            self._bindings2 = bindings2 
    611            self._last_version = expected_version 
    612 
    613 
    614def merge_key_bindings(bindings: Sequence[KeyBindingsBase]) -> _MergedKeyBindings: 
    615    """ 
    616    Merge multiple :class:`.Keybinding` objects together. 
    617 
    618    Usage:: 
    619 
    620        bindings = merge_key_bindings([bindings1, bindings2, ...]) 
    621    """ 
    622    return _MergedKeyBindings(bindings) 
    623 
    624 
    625class DynamicKeyBindings(_Proxy): 
    626    """ 
    627    KeyBindings class that can dynamically returns any KeyBindings. 
    628 
    629    :param get_key_bindings: Callable that returns a :class:`.KeyBindings` instance. 
    630    """ 
    631 
    632    def __init__(self, get_key_bindings: Callable[[], KeyBindingsBase | None]) -> None: 
    633        self.get_key_bindings = get_key_bindings 
    634        self.__version = 0 
    635        self._last_child_version = None 
    636        self._dummy = KeyBindings()  # Empty key bindings. 
    637 
    638    def _update_cache(self) -> None: 
    639        key_bindings = self.get_key_bindings() or self._dummy 
    640        assert isinstance(key_bindings, KeyBindingsBase) 
    641        version = id(key_bindings), key_bindings._version 
    642 
    643        self._bindings2 = key_bindings 
    644        self._last_version = version 
    645 
    646 
    647class GlobalOnlyKeyBindings(_Proxy): 
    648    """ 
    649    Wrapper around a :class:`.KeyBindings` object that only exposes the global 
    650    key bindings. 
    651    """ 
    652 
    653    def __init__(self, key_bindings: KeyBindingsBase) -> None: 
    654        _Proxy.__init__(self) 
    655        self.key_bindings = key_bindings 
    656 
    657    def _update_cache(self) -> None: 
    658        """ 
    659        If one of the original registries was changed. Update our merged 
    660        version. 
    661        """ 
    662        expected_version = self.key_bindings._version 
    663 
    664        if self._last_version != expected_version: 
    665            bindings2 = KeyBindings() 
    666 
    667            for b in self.key_bindings.bindings: 
    668                if b.is_global(): 
    669                    bindings2.bindings.append(b) 
    670 
    671            self._bindings2 = bindings2 
    672            self._last_version = expected_version