1"""
2Key bindings registry.
3
A `KeyBindings` object is a container that holds a list of key bindings. It
caches its lookups, so checking which key bindings apply for a pressed key
stays cheap.
7
8Typical usage::
9
10 kb = KeyBindings()
11
12 @kb.add(Keys.ControlX, Keys.ControlC, filter=INSERT)
13 def handler(event):
14 # Handle ControlX-ControlC key sequence.
15 pass
16
17It is also possible to combine multiple KeyBindings objects. We do this in the
18default key bindings. There are some KeyBindings objects that contain the Emacs
19bindings, while others contain the Vi bindings. They are merged together using
20`merge_key_bindings`.
21
22We also have a `ConditionalKeyBindings` object that can enable/disable a group of
23key bindings at once.
24
25
26It is also possible to add a filter to a function, before a key binding has
27been assigned, through the `key_binding` decorator.::
28
29 # First define a key handler with the `filter`.
30 @key_binding(filter=condition)
31 def my_key_binding(event):
32 ...
33
34 # Later, add it to the key bindings.
35 kb.add(Keys.A, my_key_binding)
36"""
37
38from __future__ import annotations
39
40from abc import ABCMeta, abstractmethod, abstractproperty
41from inspect import isawaitable
42from typing import (
43 TYPE_CHECKING,
44 Any,
45 Callable,
46 Coroutine,
47 Hashable,
48 Sequence,
49 Tuple,
50 TypeVar,
51 Union,
52 cast,
53)
54
55from prompt_toolkit.cache import SimpleCache
56from prompt_toolkit.filters import FilterOrBool, Never, to_filter
57from prompt_toolkit.keys import KEY_ALIASES, Keys
58
59if TYPE_CHECKING:
60 # Avoid circular imports.
61 from .key_processor import KeyPressEvent
62
63 # The only two return values for a mouse handler (and key bindings) are
64 # `None` and `NotImplemented`. For the type checker it's best to annotate
65 # this as `object`. (The consumer never expects a more specific instance:
66 # checking for NotImplemented can be done using `is NotImplemented`.)
67 NotImplementedOrNone = object
68 # Other non-working options are:
69 # * Optional[Literal[NotImplemented]]
70 # --> Doesn't work, Literal can't take an Any.
71 # * None
72 # --> Doesn't work. We can't assign the result of a function that
73 # returns `None` to a variable.
74 # * Any
75 # --> Works, but too broad.
76
77
__all__ = [
    # Type aliases.
    "NotImplementedOrNone",
    # Individual bindings and the registry interface/implementation.
    "Binding",
    "KeyBindingsBase",
    "KeyBindings",
    # Wrappers that combine, restrict or swap other key bindings.
    "ConditionalKeyBindings",
    "merge_key_bindings",
    "DynamicKeyBindings",
    "GlobalOnlyKeyBindings",
]

# A key handler receives the `KeyPressEvent` and may be either a plain function
# or a coroutine function. When it returns `NotImplemented` (directly, or as the
# result of the awaited coroutine), the UI won't be invalidated. This is mainly
# used for mouse move events, to avoid excessive repainting while the mouse
# moves.
_KeyHandlerResult = Union[
    "NotImplementedOrNone", Coroutine[Any, Any, "NotImplementedOrNone"]
]
KeyHandlerCallable = Callable[["KeyPressEvent"], _KeyHandlerResult]
97
98
class Binding:
    """
    Key binding: (key sequence + handler + filter).
    (Immutable binding class.)

    :param keys: Tuple of keys (`Keys` members or strings) that trigger this
        binding.
    :param handler: Callable that receives the `KeyPressEvent`. When it
        returns `NotImplemented`, the UI is not invalidated. When it returns
        an awaitable, that awaitable is run as a background task.
    :param filter: Determines when this key binding is active.
    :param eager: When True, ignore potential longer matches when this key
        binding is hit.
    :param is_global: When this binding is added to a `Container` or
        `Control`, make it a global (always active) binding.
    :param save_before: Callable that takes the event and returns True if the
        current buffer should be saved before handling the event.
    :param record_in_macro: When True (the default), record this key binding
        when a macro is being recorded; when False, don't record it.
    """

    def __init__(
        self,
        keys: tuple[Keys | str, ...],
        handler: KeyHandlerCallable,
        filter: FilterOrBool = True,
        eager: FilterOrBool = False,
        is_global: FilterOrBool = False,
        save_before: Callable[[KeyPressEvent], bool] = (lambda e: True),
        record_in_macro: FilterOrBool = True,
    ) -> None:
        self.keys = keys
        self.handler = handler
        # Normalize every `FilterOrBool` argument to a filter object.
        self.filter = to_filter(filter)
        self.eager = to_filter(eager)
        self.is_global = to_filter(is_global)
        self.save_before = save_before
        self.record_in_macro = to_filter(record_in_macro)

    def call(self, event: KeyPressEvent) -> None:
        """
        Invoke the handler for `event`.

        Unless the handler's result is `NotImplemented`, the UI is invalidated
        afterwards. When the handler returns an awaitable (e.g. it is a
        coroutine function), the awaitable is awaited in a task created with
        `event.app.create_background_task`. The invalidation then happens
        once it completes.
        """
        result = self.handler(event)

        # If the handler is a coroutine, create an asyncio task.
        if isawaitable(result):
            awaitable = cast(Coroutine[Any, Any, "NotImplementedOrNone"], result)

            async def bg_task() -> None:
                result = await awaitable
                # Identity check: `NotImplemented` is a singleton, and `!=`
                # would dispatch to whatever `__ne__` the result defines.
                if result is not NotImplemented:
                    event.app.invalidate()

            event.app.create_background_task(bg_task())

        elif result is not NotImplemented:
            event.app.invalidate()

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(keys={self.keys!r}, handler={self.handler!r})"
        )
147
148
# A sequence of key presses, e.g. `(Keys.ControlX, Keys.ControlC)`. Each element
# is a `Keys` member or a plain string (`_parse_key` only produces
# single-character strings).
KeysTuple = Tuple[Union[Keys, str], ...]
151
152
class KeyBindingsBase(metaclass=ABCMeta):
    """
    Interface for a KeyBindings.

    Implementations must provide `_version`, `bindings` and the two lookup
    methods. `add` and `remove` don't have to be part of this interface.
    """

    # `abc.abstractproperty` is deprecated; stacking `@property` on top of
    # `@abstractmethod` is the documented replacement and is equivalent.
    @property
    @abstractmethod
    def _version(self) -> Hashable:
        """
        For cache invalidation. - This should change every time that
        something changes.
        """
        return 0

    @abstractmethod
    def get_bindings_for_keys(self, keys: KeysTuple) -> list[Binding]:
        """
        Return a list of key bindings that can handle these keys.
        (This also returns inactive bindings, so the `filter` still has to be
        called to check whether a binding applies.)

        :param keys: tuple of keys.
        """
        return []

    @abstractmethod
    def get_bindings_starting_with_keys(self, keys: KeysTuple) -> list[Binding]:
        """
        Return a list of key bindings that handle a key sequence starting with
        `keys`. (It only returns bindings whose sequences are longer than
        `keys`. Like `get_bindings_for_keys`, it also includes inactive
        bindings.)

        :param keys: tuple of keys.
        """
        return []

    @property
    @abstractmethod
    def bindings(self) -> list[Binding]:
        """
        List of `Binding` objects.
        (These need to be exposed, so that `KeyBindings` objects can be merged
        together.)
        """
        return []

    # `add` and `remove` don't have to be part of this interface.
197
198 # `add` and `remove` don't have to be part of this interface.
199
200
# Type of the object decorated by `KeyBindings.add`: either a plain handler or
# a prebuilt `Binding` (e.g. produced by `key_binding`). The decorator returns
# its argument unchanged, so the decorated name keeps its original type.
T = TypeVar("T", bound=Union[KeyHandlerCallable, Binding])
202
203
class KeyBindings(KeyBindingsBase):
    """
    A container for a set of key bindings.

    Example usage::

        kb = KeyBindings()

        @kb.add('c-t')
        def _(event):
            print('Control-T pressed')

        @kb.add('c-a', 'c-b')
        def _(event):
            print('Control-A pressed, followed by Control-B')

        @kb.add('c-x', filter=is_searching)
        def _(event):
            print('Control-X pressed')  # Works only if we are searching.

    The results of `get_bindings_for_keys` and
    `get_bindings_starting_with_keys` are cached. `add` and `remove` clear
    those caches and bump `_version`.

    NOTE: `bindings` exposes the internal list itself. Appending to it
    directly does not clear the caches, so only do that before any lookup
    has been made (as the proxy classes in this module do on a fresh
    instance).
    """

    def __init__(self) -> None:
        self._bindings: list[Binding] = []
        self._get_bindings_for_keys_cache: SimpleCache[KeysTuple, list[Binding]] = (
            SimpleCache(maxsize=10000)
        )
        self._get_bindings_starting_with_keys_cache: SimpleCache[
            KeysTuple, list[Binding]
        ] = SimpleCache(maxsize=1000)
        self.__version = 0  # For cache invalidation.

    def _clear_cache(self) -> None:
        """Bump the version and drop all memoized lookups."""
        self.__version += 1
        self._get_bindings_for_keys_cache.clear()
        self._get_bindings_starting_with_keys_cache.clear()

    @property
    def bindings(self) -> list[Binding]:
        return self._bindings

    @property
    def _version(self) -> Hashable:
        return self.__version

    def add(
        self,
        *keys: Keys | str,
        filter: FilterOrBool = True,
        eager: FilterOrBool = False,
        is_global: FilterOrBool = False,
        save_before: Callable[[KeyPressEvent], bool] = (lambda e: True),
        record_in_macro: FilterOrBool = True,
    ) -> Callable[[T], T]:
        """
        Decorator for adding a key binding.

        :param keys: One or more keys forming the key sequence. Each one is
            normalized with `_parse_key`.
        :param filter: :class:`~prompt_toolkit.filters.Filter` to determine
            when this key binding is active.
        :param eager: :class:`~prompt_toolkit.filters.Filter` or `bool`.
            When True, ignore potential longer matches when this key binding is
            hit. E.g. when there is an active eager key binding for Ctrl-X,
            execute the handler immediately and ignore the key binding for
            Ctrl-X Ctrl-E of which it is a prefix.
        :param is_global: When this key bindings is added to a `Container` or
            `Control`, make it a global (always active) binding.
        :param save_before: Callable that takes an `Event` and returns True if
            we should save the current buffer, before handling the event.
            (That's the default.)
        :param record_in_macro: Record these key bindings when a macro is
            being recorded. (True by default.)

        When the decorated object is already a `Binding`, its filter is
        combined with `filter` (AND), and its `eager`/`is_global` with the
        given values (OR). Its own `save_before` and `record_in_macro` are
        kept; the corresponding arguments of `add` are ignored.
        """
        assert keys

        keys = tuple(_parse_key(k) for k in keys)

        if isinstance(filter, Never):
            # When a filter is Never, it will always stay disabled, so in that
            # case don't bother putting it in the key bindings. It will slow
            # down every key press otherwise. (A plain `False` is a bool, not
            # a `Never` instance, so it is not caught here.)
            def decorator(func: T) -> T:
                return func

        else:

            def decorator(func: T) -> T:
                if isinstance(func, Binding):
                    # We're adding an existing Binding object.
                    self.bindings.append(
                        Binding(
                            keys,
                            func.handler,
                            filter=func.filter & to_filter(filter),
                            eager=to_filter(eager) | func.eager,
                            is_global=to_filter(is_global) | func.is_global,
                            save_before=func.save_before,
                            record_in_macro=func.record_in_macro,
                        )
                    )
                else:
                    self.bindings.append(
                        Binding(
                            keys,
                            cast(KeyHandlerCallable, func),
                            filter=filter,
                            eager=eager,
                            is_global=is_global,
                            save_before=save_before,
                            record_in_macro=record_in_macro,
                        )
                    )
                self._clear_cache()

                return func

        return decorator

    def remove(self, *args: Keys | str | KeyHandlerCallable) -> None:
        """
        Remove a key binding.

        This expects either a function that was given to the `add` method as
        parameter, or a sequence of keys. Every binding that matches is
        removed.

        Raises `ValueError` when no binding was found.

        Usage::

            remove(handler)  # Pass handler.
            remove('c-x', 'c-a')  # Or pass the key bindings.
        """
        assert args, "remove() expects a handler or a key sequence."

        bindings = self.bindings
        target: object

        if callable(args[0]):
            assert len(args) == 1
            function = args[0]
            target = function
            remaining = [b for b in bindings if b.handler != function]
        else:
            key_args = cast(Tuple[Union[Keys, str], ...], args)
            keys = tuple(_parse_key(k) for k in key_args)
            target = keys
            remaining = [b for b in bindings if b.keys != keys]

        if len(remaining) == len(bindings):
            # No key binding found for this handler / key sequence.
            raise ValueError(f"Binding not found: {target!r}")

        # Replace the contents in place, so the list object returned by
        # `bindings` stays the same. (Calling `list.remove` while iterating
        # over the same list would skip the element following each removed
        # one, leaving adjacent matches behind.)
        bindings[:] = remaining
        self._clear_cache()

    # For backwards-compatibility.
    add_binding = add
    remove_binding = remove

    def get_bindings_for_keys(self, keys: KeysTuple) -> list[Binding]:
        """
        Return a list of key bindings that can handle this key.
        (This also returns inactive bindings, so the `filter` still has to be
        called to check whether a binding applies.)

        A binding matches when it has the same length as `keys` and every
        position is either equal or `Keys.Any`.

        :param keys: tuple of keys.
        """

        def get() -> list[Binding]:
            result: list[tuple[int, Binding]] = []

            for b in self.bindings:
                if len(keys) == len(b.keys):
                    match = True
                    any_count = 0

                    for i, j in zip(b.keys, keys):
                        if i != j and i != Keys.Any:
                            match = False
                            break

                        if i == Keys.Any:
                            any_count += 1

                    if match:
                        result.append((any_count, b))

            # Sort by descending number of `Keys.Any` wildcards: bindings with
            # more 'Any' occurrences come first, and the most specific ones
            # end up at the end. (`sorted` is stable, so bindings with the
            # same count keep their registration order.)
            result = sorted(result, key=lambda item: -item[0])

            return [item[1] for item in result]

        return self._get_bindings_for_keys_cache.get(keys, get)

    def get_bindings_starting_with_keys(self, keys: KeysTuple) -> list[Binding]:
        """
        Return a list of key bindings that handle a key sequence starting with
        `keys`. (It only returns bindings whose sequences are longer than
        `keys`. Like `get_bindings_for_keys`, it also includes inactive
        bindings.)

        :param keys: tuple of keys.
        """

        def get() -> list[Binding]:
            result: list[Binding] = []
            for b in self.bindings:
                if len(keys) < len(b.keys):
                    match = True
                    # `zip` stops at the shorter `keys`, so only the prefix
                    # of `b.keys` is compared.
                    for i, j in zip(b.keys, keys):
                        if i != j and i != Keys.Any:
                            match = False
                            break
                    if match:
                        result.append(b)
            return result

        return self._get_bindings_starting_with_keys_cache.get(keys, get)
428
429
def _parse_key(key: Keys | str) -> str | Keys:
    """
    Normalize a single key description.

    `Keys` members are returned unchanged. Strings are first resolved through
    `KEY_ALIASES`, and "space" becomes " ". A name of a special key is
    converted to its `Keys` member; anything else must be one character.

    :raises ValueError: When `key` is neither a special key nor exactly one
        character long.
    """
    if isinstance(key, Keys):
        return key

    name = KEY_ALIASES.get(key, key)
    if name == "space":
        name = " "

    try:
        return Keys(name)
    except ValueError:
        pass  # Not one of the special keys; fall through to validation.

    # Validate outside the `except` block so no exception context is chained.
    if len(name) == 1:
        return name
    raise ValueError(f"Invalid key: {name}")
456
457
def key_binding(
    filter: FilterOrBool = True,
    eager: FilterOrBool = False,
    is_global: FilterOrBool = False,
    save_before: Callable[[KeyPressEvent], bool] | None = (lambda event: True),
    record_in_macro: FilterOrBool = True,
) -> Callable[[KeyHandlerCallable], Binding]:
    """
    Decorator that turns a function into a `Binding` object. This can be added
    to a `KeyBindings` object when a key binding is assigned.

    The produced `Binding` has an empty key sequence; `KeyBindings.add`
    supplies the keys when it is given a `Binding`.

    :param filter: Determines when the key binding is active.
    :param eager: Ignore potential longer matches when this binding is hit.
    :param is_global: Make it a global (always active) binding.
    :param save_before: Callable that takes the event and returns True if the
        current buffer should be saved before handling it. `None` is accepted
        and means the default (always True).
    :param record_in_macro: Record the binding when a macro is being recorded.
    """
    assert save_before is None or callable(save_before)

    if save_before is None:
        # `Binding.save_before` is typed as a callable, so don't store `None`
        # (it was previously accepted by the assertion and stored as-is).
        save_before = lambda event: True  # noqa: E731

    # Normalize the options once, up front.
    filter = to_filter(filter)
    eager = to_filter(eager)
    is_global = to_filter(is_global)
    record_in_macro = to_filter(record_in_macro)
    keys: tuple[Keys | str, ...] = ()

    def decorator(function: KeyHandlerCallable) -> Binding:
        return Binding(
            keys,
            function,
            filter=filter,
            eager=eager,
            is_global=is_global,
            save_before=save_before,
            record_in_macro=record_in_macro,
        )

    return decorator
490
491
class _Proxy(KeyBindingsBase):
    """
    Shared base of the wrapper classes (e.g. `ConditionalKeyBindings` and
    `_MergedKeyBindings`).

    Subclasses implement `_update_cache`, which refreshes `self._bindings2`
    and `self._last_version` when the wrapped bindings changed. Every public
    accessor refreshes first and then delegates to `self._bindings2`.
    """

    def __init__(self) -> None:
        # Derived key bindings that all queries are forwarded to.
        self._bindings2: KeyBindingsBase = KeyBindings()
        # Version of the source bindings `_bindings2` was built from.
        self._last_version: Hashable = ()

    def _update_cache(self) -> None:
        """
        Bring `self._bindings2` and `self._last_version` up to date when
        `self._last_version` is stale. Subclasses must override this.
        """
        raise NotImplementedError

    def _synced(self) -> KeyBindingsBase:
        """Refresh the cache, then return the up-to-date delegate."""
        self._update_cache()
        return self._bindings2

    # Public API: refresh, then forward to the delegate.

    @property
    def bindings(self) -> list[Binding]:
        return self._synced().bindings

    @property
    def _version(self) -> Hashable:
        self._update_cache()
        return self._last_version

    def get_bindings_for_keys(self, keys: KeysTuple) -> list[Binding]:
        return self._synced().get_bindings_for_keys(keys)

    def get_bindings_starting_with_keys(self, keys: KeysTuple) -> list[Binding]:
        return self._synced().get_bindings_starting_with_keys(keys)
528
529
class ConditionalKeyBindings(_Proxy):
    """
    Wraps around a `KeyBindings`. Disable/enable all the key bindings according to
    the given (additional) filter.::

        @Condition
        def setting_is_true():
            return True  # or False

        registry = ConditionalKeyBindings(key_bindings, setting_is_true)

    When new key bindings are added to the wrapped object, they are also
    enabled/disabled according to the given `filter`.

    :param key_bindings: The :class:`.KeyBindingsBase` object to wrap.
    :param filter: :class:`~prompt_toolkit.filters.Filter` object.
    """

    def __init__(
        self, key_bindings: KeyBindingsBase, filter: FilterOrBool = True
    ) -> None:
        super().__init__()

        self.key_bindings = key_bindings
        self.filter = to_filter(filter)

    def _update_cache(self) -> None:
        """
        If the wrapped key bindings changed, rebuild our copy of them.

        Every binding is copied with its filter replaced by
        ``self.filter & b.filter``; all other attributes are kept.
        """
        expected_version = self.key_bindings._version

        if self._last_version != expected_version:
            bindings2 = KeyBindings()

            # Copy all bindings from `self.key_bindings`, adding our condition.
            for b in self.key_bindings.bindings:
                bindings2.bindings.append(
                    Binding(
                        keys=b.keys,
                        handler=b.handler,
                        filter=self.filter & b.filter,
                        eager=b.eager,
                        is_global=b.is_global,
                        save_before=b.save_before,
                        record_in_macro=b.record_in_macro,
                    )
                )

            self._bindings2 = bindings2
            self._last_version = expected_version
579
580
class _MergedKeyBindings(_Proxy):
    """
    Merge multiple registries of key bindings into one.

    Acts as a proxy to several :class:`.KeyBindings` objects while behaving
    like one bigger :class:`.KeyBindings`. Bindings keep the order of
    `registries`, and within each registry their own order.

    :param registries: List of :class:`.KeyBindings` objects.
    """

    def __init__(self, registries: Sequence[KeyBindingsBase]) -> None:
        super().__init__()
        self.registries = registries

    def _update_cache(self) -> None:
        """
        Rebuild the merged bindings whenever any registry reports a new
        version. The combined version is the tuple of all registry versions.
        """
        current = tuple(registry._version for registry in self.registries)
        if current == self._last_version:
            return

        merged = KeyBindings()
        merged.bindings.extend(
            binding for registry in self.registries for binding in registry.bindings
        )

        self._bindings2 = merged
        self._last_version = current
610
611
def merge_key_bindings(bindings: Sequence[KeyBindingsBase]) -> _MergedKeyBindings:
    """
    Merge multiple :class:`.KeyBindings` objects together.

    The result is a live view: when one of the given objects changes, the
    merged bindings are rebuilt on the next access.

    Usage::

        bindings = merge_key_bindings([bindings1, bindings2, ...])

    :param bindings: The key bindings to merge, in order.
    """
    return _MergedKeyBindings(bindings)
621
622
class DynamicKeyBindings(_Proxy):
    """
    KeyBindings class that can dynamically return any KeyBindings.

    :param get_key_bindings: Callable that returns a :class:`.KeyBindings`
        instance, or `None` for "no key bindings".
    """

    def __init__(self, get_key_bindings: Callable[[], KeyBindingsBase | None]) -> None:
        # Initialize the shared `_Proxy` state like the other proxies do, so
        # `_bindings2` and `_last_version` exist before the first refresh.
        # (Dropped the unused `__version`/`_last_child_version` attributes.)
        _Proxy.__init__(self)
        self.get_key_bindings = get_key_bindings
        self._dummy = KeyBindings()  # Empty key bindings, used for `None`.

    def _update_cache(self) -> None:
        """
        Point `_bindings2` at whatever `get_key_bindings` returns right now.

        This runs on every access. The version combines the identity of the
        returned object with its own `_version`, so returning a different
        key bindings object also counts as a change.
        """
        key_bindings = self.get_key_bindings() or self._dummy
        assert isinstance(key_bindings, KeyBindingsBase)
        version = id(key_bindings), key_bindings._version

        self._bindings2 = key_bindings
        self._last_version = version
643
644
class GlobalOnlyKeyBindings(_Proxy):
    """
    Wrapper around a :class:`.KeyBindings` object that only exposes the global
    key bindings.

    :param key_bindings: The key bindings to filter.
    """

    def __init__(self, key_bindings: KeyBindingsBase) -> None:
        _Proxy.__init__(self)
        self.key_bindings = key_bindings

    def _update_cache(self) -> None:
        """
        If the wrapped key bindings changed, rebuild our copy so that it only
        contains the bindings whose `is_global` filter is currently true.
        """
        expected_version = self.key_bindings._version

        if self._last_version != expected_version:
            bindings2 = KeyBindings()

            # NOTE: `is_global` is only evaluated here, i.e. when the wrapped
            # bindings' version changes. If a binding's `is_global` filter
            # changes value later, that isn't picked up until the next change.
            bindings2.bindings.extend(
                b for b in self.key_bindings.bindings if b.is_global()
            )

            self._bindings2 = bindings2
            self._last_version = expected_version