1"""
2Key bindings registry.
3
4A `KeyBindings` object is a container that holds a list of key bindings. It has a
5very efficient internal data structure for checking which key bindings apply
6for a pressed key.
7
8Typical usage::
9
10 kb = KeyBindings()
11
12 @kb.add(Keys.ControlX, Keys.ControlC, filter=INSERT)
13 def handler(event):
14 # Handle ControlX-ControlC key sequence.
15 pass
16
17It is also possible to combine multiple KeyBindings objects. We do this in the
18default key bindings. There are some KeyBindings objects that contain the Emacs
19bindings, while others contain the Vi bindings. They are merged together using
20`merge_key_bindings`.
21
22We also have a `ConditionalKeyBindings` object that can enable/disable a group of
23key bindings at once.
24
25
It is also possible to add a filter to a function before a key binding has
been assigned, using the `key_binding` decorator::
28
29 # First define a key handler with the `filter`.
30 @key_binding(filter=condition)
31 def my_key_binding(event):
32 ...
33
34 # Later, add it to the key bindings.
    kb.add('a')(my_key_binding)
36"""
37
38from __future__ import annotations
39
40from abc import ABCMeta, abstractmethod
41from inspect import isawaitable
42from typing import (
43 TYPE_CHECKING,
44 Any,
45 Callable,
46 Coroutine,
47 Hashable,
48 Sequence,
49 Tuple,
50 TypeVar,
51 Union,
52 cast,
53)
54
55from prompt_toolkit.cache import SimpleCache
56from prompt_toolkit.filters import FilterOrBool, Never, to_filter
57from prompt_toolkit.keys import KEY_ALIASES, Keys
58
if TYPE_CHECKING:
    # Avoid circular imports.
    from .key_processor import KeyPressEvent

# The only two return values for a mouse handler (and key bindings) are
# `None` and `NotImplemented`. For the type checker it's best to annotate
# this as `object`. (The consumer never expects a more specific instance:
# checking for NotImplemented can be done using `is NotImplemented`.)
#
# The alias is defined at runtime, not only under `TYPE_CHECKING`, because it
# is listed in `__all__` below; otherwise `from ... import *` on this module
# would fail with AttributeError.
NotImplementedOrNone = object
# Other non-working options are:
# * Optional[Literal[NotImplemented]]
#   --> Doesn't work, Literal can't take an Any.
# * None
#   --> Doesn't work. We can't assign the result of a function that
#       returns `None` to a variable.
# * Any
#   --> Works, but too broad.


__all__ = [
    "NotImplementedOrNone",
    "Binding",
    "KeyBindingsBase",
    "KeyBindings",
    "ConditionalKeyBindings",
    "merge_key_bindings",
    "DynamicKeyBindings",
    "GlobalOnlyKeyBindings",
]

# Key bindings can be regular functions or coroutines.
# In both cases, if they return `NotImplemented`, the UI won't be invalidated.
# This is mainly used in case of mouse move events, to prevent excessive
# repainting during mouse move events.
KeyHandlerCallable = Callable[
    ["KeyPressEvent"],
    Union["NotImplementedOrNone", Coroutine[Any, Any, "NotImplementedOrNone"]],
]
97
98
class Binding:
    """
    Key binding: (key sequence + handler + filter).
    (Immutable binding class.)

    :param keys: Tuple of keys (`Keys` members or strings) that trigger this
        binding.
    :param handler: Callable receiving the `KeyPressEvent`. It may return
        `NotImplemented` (directly, or as the result of a coroutine) to signal
        that the UI does not need to be invalidated.
    :param filter: Filter (or bool) that determines when this binding is active.
    :param eager: Filter (or bool). When active, potential longer matches are
        ignored when this binding is hit.
    :param is_global: Filter (or bool). Makes the binding global when attached
        to a `Container` or `Control`.
    :param save_before: Callable receiving the event; returns True when the
        current buffer should be saved before handling the event.
    :param record_in_macro: When True (the default), this key binding is
        recorded when a macro is being recorded.
    """

    def __init__(
        self,
        keys: tuple[Keys | str, ...],
        handler: KeyHandlerCallable,
        filter: FilterOrBool = True,
        eager: FilterOrBool = False,
        is_global: FilterOrBool = False,
        save_before: Callable[[KeyPressEvent], bool] = (lambda e: True),
        record_in_macro: FilterOrBool = True,
    ) -> None:
        self.keys = keys
        self.handler = handler
        self.filter = to_filter(filter)
        self.eager = to_filter(eager)
        self.is_global = to_filter(is_global)
        self.save_before = save_before
        self.record_in_macro = to_filter(record_in_macro)

    def call(self, event: KeyPressEvent) -> None:
        """
        Run the handler for `event`, then invalidate the UI unless the handler
        returned `NotImplemented`.

        When the handler returns an awaitable, it is awaited inside a
        background task created with `event.app.create_background_task`; the
        invalidation check happens after that awaitable completes.
        """
        result = self.handler(event)

        # If the handler is a coroutine, run it as a background task.
        if isawaitable(result):
            awaitable = cast(Coroutine[Any, Any, "NotImplementedOrNone"], result)

            async def bg_task() -> None:
                outcome = await awaitable
                # `NotImplemented` is a singleton: compare by identity.
                if outcome is not NotImplemented:
                    event.app.invalidate()

            event.app.create_background_task(bg_task())

        elif result is not NotImplemented:
            event.app.invalidate()

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(keys={self.keys!r}, handler={self.handler!r})"
        )
147
148
# Sequence of key presses. Each element is a `Keys` member or a plain string
# (`_parse_key` normalizes user-supplied keys into one of these two forms).
KeysTuple = Tuple[Union[Keys, str], ...]
151
152
class KeyBindingsBase(metaclass=ABCMeta):
    """
    Abstract interface implemented by every key bindings container.

    A container exposes its `Binding` objects, answers lookups by key
    sequence, and provides a `_version` token used for cache invalidation.
    Mutators such as `add` and `remove` are intentionally not part of it.
    """

    @property
    @abstractmethod
    def bindings(self) -> list[Binding]:
        """
        All `Binding` objects held by this container.

        They have to be exposed so that several containers can be merged
        together.
        """
        return list()

    @property
    @abstractmethod
    def _version(self) -> Hashable:
        """
        Cache-invalidation token. Implementations must make it change every
        time the contents of the container change.
        """
        return 0

    @abstractmethod
    def get_bindings_for_keys(self, keys: KeysTuple) -> list[Binding]:
        """
        Bindings whose key sequence matches `keys` exactly.

        Inactive bindings are included as well, so callers still need to
        evaluate each binding's `filter`.

        :param keys: tuple of keys.
        """
        return list()

    @abstractmethod
    def get_bindings_starting_with_keys(self, keys: KeysTuple) -> list[Binding]:
        """
        Bindings whose key sequence is strictly longer than `keys` and starts
        with it.

        As with `get_bindings_for_keys`, inactive bindings are included.

        :param keys: tuple of keys.
        """
        return list()
201
202
# Type of the object decorated by `KeyBindings.add`: either a plain handler
# callable or a pre-built `Binding` (such as one returned by `key_binding`).
# The decorator returns its argument unchanged, so this type is preserved.
T = TypeVar("T", bound=Union[KeyHandlerCallable, Binding])
204
205
class KeyBindings(KeyBindingsBase):
    """
    A container for a set of key bindings.

    Example usage::

        kb = KeyBindings()

        @kb.add('c-t')
        def _(event):
            print('Control-T pressed')

        @kb.add('c-a', 'c-b')
        def _(event):
            print('Control-A pressed, followed by Control-B')

        @kb.add('c-x', filter=is_searching)
        def _(event):
            print('Control-X pressed')  # Works only if we are searching.

    """

    def __init__(self) -> None:
        self._bindings: list[Binding] = []
        # Memoized lookup results keyed by key sequence. Both caches are
        # emptied by `_clear_cache` whenever the binding list changes.
        self._get_bindings_for_keys_cache: SimpleCache[KeysTuple, list[Binding]] = (
            SimpleCache(maxsize=10000)
        )
        self._get_bindings_starting_with_keys_cache: SimpleCache[
            KeysTuple, list[Binding]
        ] = SimpleCache(maxsize=1000)
        self.__version = 0  # For cache invalidation.

    def _clear_cache(self) -> None:
        """
        Bump the version (so wrappers comparing `_version` notice the change)
        and drop all memoized lookups.
        """
        self.__version += 1
        self._get_bindings_for_keys_cache.clear()
        self._get_bindings_starting_with_keys_cache.clear()

    @property
    def bindings(self) -> list[Binding]:
        return self._bindings

    @property
    def _version(self) -> Hashable:
        return self.__version

    def add(
        self,
        *keys: Keys | str,
        filter: FilterOrBool = True,
        eager: FilterOrBool = False,
        is_global: FilterOrBool = False,
        save_before: Callable[[KeyPressEvent], bool] = (lambda e: True),
        record_in_macro: FilterOrBool = True,
    ) -> Callable[[T], T]:
        """
        Decorator for adding a key binding.

        :param filter: :class:`~prompt_toolkit.filters.Filter` to determine
            when this key binding is active.
        :param eager: :class:`~prompt_toolkit.filters.Filter` or `bool`.
            When True, ignore potential longer matches when this key binding is
            hit. E.g. when there is an active eager key binding for Ctrl-X,
            execute the handler immediately and ignore the key binding for
            Ctrl-X Ctrl-E of which it is a prefix.
        :param is_global: When this key bindings is added to a `Container` or
            `Control`, make it a global (always active) binding.
        :param save_before: Callable that takes an `Event` and returns True if
            we should save the current buffer, before handling the event.
            (That's the default.)
        :param record_in_macro: Record these key bindings when a macro is
            being recorded. (True by default.)

        When the decorated object is already a `Binding`, a new `Binding` is
        stored with these `keys`; its filter is AND-ed with `filter`, and its
        `eager`/`is_global` flags are OR-ed with the given ones.
        """
        assert keys

        keys = tuple(_parse_key(k) for k in keys)

        if isinstance(filter, Never):
            # When a filter is Never, it will always stay disabled, so in that
            # case don't bother putting it in the key bindings. It will slow
            # down every key press otherwise.
            def decorator(func: T) -> T:
                return func

        else:

            def decorator(func: T) -> T:
                if isinstance(func, Binding):
                    # We're adding an existing Binding object.
                    self.bindings.append(
                        Binding(
                            keys,
                            func.handler,
                            filter=func.filter & to_filter(filter),
                            eager=to_filter(eager) | func.eager,
                            is_global=to_filter(is_global) | func.is_global,
                            save_before=func.save_before,
                            record_in_macro=func.record_in_macro,
                        )
                    )
                else:
                    self.bindings.append(
                        Binding(
                            keys,
                            cast(KeyHandlerCallable, func),
                            filter=filter,
                            eager=eager,
                            is_global=is_global,
                            save_before=save_before,
                            record_in_macro=record_in_macro,
                        )
                    )
                self._clear_cache()

                return func

        return decorator

    def remove(self, *args: Keys | str | KeyHandlerCallable) -> None:
        """
        Remove a key binding.

        This expects either a function that was given to `add` method as
        parameter or a sequence of key bindings. Every matching binding is
        removed, not only the first one.

        Raises `ValueError` when no binding was found, or when called without
        arguments.

        Usage::

            remove(handler)  # Pass handler.
            remove('c-x', 'c-a')  # Or pass the key bindings.
        """
        if not args:
            raise ValueError("remove() expects a handler or a key sequence.")

        target: object
        matches: Callable[[Binding], bool]

        if callable(args[0]):
            assert len(args) == 1
            function = args[0]
            target = function
            matches = lambda b: b.handler == function  # noqa: E731
        else:
            key_args = cast(Tuple[Union[Keys, str], ...], args)
            keys = tuple(_parse_key(k) for k in key_args)
            target = keys
            matches = lambda b: b.keys == keys  # noqa: E731

        # Build the surviving list first and assign it back in place. Calling
        # `list.remove` while iterating the same list skips the element right
        # after each removed one, so adjacent matches would be left behind.
        remaining = [b for b in self.bindings if not matches(b)]

        if len(remaining) == len(self.bindings):
            # No key binding found for this function / key sequence.
            raise ValueError(f"Binding not found: {target!r}")

        self.bindings[:] = remaining
        self._clear_cache()

    # For backwards-compatibility.
    add_binding = add
    remove_binding = remove

    def get_bindings_for_keys(self, keys: KeysTuple) -> list[Binding]:
        """
        Return a list of key bindings that can handle this key.
        (This return also inactive bindings, so the `filter` still has to be
        called, for checking it.)

        A position holding `Keys.Any` in a binding matches any key. The result
        is ordered by descending number of `Keys.Any` wildcards, so the most
        specific bindings come last; ties keep insertion order.

        :param keys: tuple of keys.
        """

        def get() -> list[Binding]:
            result: list[tuple[int, Binding]] = []

            for b in self.bindings:
                if len(keys) == len(b.keys):
                    match = True
                    any_count = 0

                    for i, j in zip(b.keys, keys):
                        if i != j and i != Keys.Any:
                            match = False
                            break

                        if i == Keys.Any:
                            any_count += 1

                    if match:
                        result.append((any_count, b))

            # Stable sort by descending 'Any' count: bindings with more 'Any'
            # occurrences go to the *front*, the most specific ones to the end.
            result = sorted(result, key=lambda item: -item[0])

            return [item[1] for item in result]

        return self._get_bindings_for_keys_cache.get(keys, get)

    def get_bindings_starting_with_keys(self, keys: KeysTuple) -> list[Binding]:
        """
        Return a list of key bindings that handle a key sequence starting with
        `keys`. (It does only return bindings for which the sequences are
        longer than `keys`. And like `get_bindings_for_keys`, it also includes
        inactive bindings.)

        :param keys: tuple of keys.
        """

        def get() -> list[Binding]:
            result = []
            for b in self.bindings:
                if len(keys) < len(b.keys):
                    match = True
                    for i, j in zip(b.keys, keys):
                        if i != j and i != Keys.Any:
                            match = False
                            break
                    if match:
                        result.append(b)
            return result

        return self._get_bindings_starting_with_keys_cache.get(keys, get)
430
431
def _parse_key(key: Keys | str) -> str | Keys:
    """
    Normalize a user-supplied key into its canonical form.

    `Keys` members are returned as-is. Strings are first resolved through
    `KEY_ALIASES` ("space" becomes " "), then converted to a `Keys` member
    when they name a special key; otherwise they must be one character long.

    :raises ValueError: for strings that are neither a special key nor a
        single character.
    """
    if isinstance(key, Keys):
        return key

    canonical = KEY_ALIASES.get(key, key)
    if canonical == "space":
        canonical = " "

    try:
        return Keys(canonical)
    except ValueError:
        pass  # Not a special key: must be a single literal character.

    if len(canonical) == 1:
        return canonical
    raise ValueError(f"Invalid key: {canonical}")
458
459
def key_binding(
    filter: FilterOrBool = True,
    eager: FilterOrBool = False,
    is_global: FilterOrBool = False,
    save_before: Callable[[KeyPressEvent], bool] = (lambda event: True),
    record_in_macro: FilterOrBool = True,
) -> Callable[[KeyHandlerCallable], Binding]:
    """
    Decorator that turns a function into a `Binding` object with an empty key
    sequence. Pass that `Binding` to `KeyBindings.add(...)` later to assign
    the actual keys.
    """
    assert save_before is None or callable(save_before)

    # Convert the options once; every Binding created by the returned
    # decorator shares these filter objects.
    active = to_filter(filter)
    run_eagerly = to_filter(eager)
    global_ = to_filter(is_global)
    recordable = to_filter(record_in_macro)

    def decorator(function: KeyHandlerCallable) -> Binding:
        return Binding(
            (),
            function,
            filter=active,
            eager=run_eagerly,
            is_global=global_,
            save_before=save_before,
            record_in_macro=recordable,
        )

    return decorator
492
493
class _Proxy(KeyBindingsBase):
    """
    Shared base for wrappers that forward every query to `self._bindings2`.

    Subclasses implement `_update_cache`. It runs before each forwarded call
    and is responsible for refreshing `_bindings2` and `_last_version`
    whenever the wrapped key bindings changed.
    """

    def __init__(self) -> None:
        # Delegate that all queries are forwarded to; starts out empty.
        self._bindings2: KeyBindingsBase = KeyBindings()
        self._last_version: Hashable = ()

    def _update_cache(self) -> None:
        """
        Refresh `self._bindings2` (and `self._last_version`) if the wrapped
        key bindings changed since the last call.
        """
        raise NotImplementedError

    def _synced(self) -> KeyBindingsBase:
        """Bring the delegate up to date and return it."""
        self._update_cache()
        return self._bindings2

    @property
    def bindings(self) -> list[Binding]:
        return self._synced().bindings

    @property
    def _version(self) -> Hashable:
        self._update_cache()
        return self._last_version

    def get_bindings_for_keys(self, keys: KeysTuple) -> list[Binding]:
        return self._synced().get_bindings_for_keys(keys)

    def get_bindings_starting_with_keys(self, keys: KeysTuple) -> list[Binding]:
        return self._synced().get_bindings_starting_with_keys(keys)
530
531
class ConditionalKeyBindings(_Proxy):
    """
    Wraps around a `KeyBindings`. Disable/enable all the key bindings according to
    the given (additional) filter::

        @Condition
        def setting_is_true():
            return True  # or False

        registry = ConditionalKeyBindings(key_bindings, setting_is_true)

    When new key bindings are added to the wrapped object, they are also
    enabled/disabled according to the given `filter` (the copy is rebuilt
    whenever the wrapped object's `_version` changes).

    :param key_bindings: The :class:`.KeyBindingsBase` object to wrap.
    :param filter: :class:`~prompt_toolkit.filters.Filter` object (or bool).
    """

    def __init__(
        self, key_bindings: KeyBindingsBase, filter: FilterOrBool = True
    ) -> None:
        _Proxy.__init__(self)

        self.key_bindings = key_bindings
        self.filter = to_filter(filter)

    def _update_cache(self) -> None:
        "If the original key bindings was changed. Update our copy version."
        expected_version = self.key_bindings._version

        if self._last_version != expected_version:
            bindings2 = KeyBindings()

            # Copy all bindings from `self.key_bindings`, adding our condition.
            for b in self.key_bindings.bindings:
                bindings2.bindings.append(
                    Binding(
                        keys=b.keys,
                        handler=b.handler,
                        filter=self.filter & b.filter,
                        eager=b.eager,
                        is_global=b.is_global,
                        save_before=b.save_before,
                        record_in_macro=b.record_in_macro,
                    )
                )

            self._bindings2 = bindings2
            self._last_version = expected_version
581
582
class _MergedKeyBindings(_Proxy):
    """
    Present several key bindings registries as a single one.

    Queries go to one combined :class:`.KeyBindings` that holds the bindings
    of every registry, in registry order. It is rebuilt whenever any of the
    registries reports a new `_version`.

    :param registries: List of :class:`.KeyBindings` objects.
    """

    def __init__(self, registries: Sequence[KeyBindingsBase]) -> None:
        _Proxy.__init__(self)
        self.registries = registries

    def _update_cache(self) -> None:
        """
        Rebuild the combined copy when the version of any registry changed.
        """
        current = tuple(registry._version for registry in self.registries)

        if self._last_version == current:
            return

        merged = KeyBindings()
        merged.bindings.extend(
            binding for registry in self.registries for binding in registry.bindings
        )

        self._bindings2 = merged
        self._last_version = current
612
613
def merge_key_bindings(bindings: Sequence[KeyBindingsBase]) -> _MergedKeyBindings:
    """
    Combine several :class:`.KeyBindings` objects into one.

    The merged object stays in sync with its inputs: changes made to any of
    them later are picked up automatically.

    Usage::

        bindings = merge_key_bindings([bindings1, bindings2, ...])
    """
    merged = _MergedKeyBindings(bindings)
    return merged
623
624
class DynamicKeyBindings(_Proxy):
    """
    Key bindings whose actual content is looked up on every access.

    :param get_key_bindings: Callable that returns a :class:`.KeyBindings`
        instance, or `None` to fall back to an empty set of bindings.
    """

    def __init__(self, get_key_bindings: Callable[[], KeyBindingsBase | None]) -> None:
        # `_Proxy.__init__` is not called here: `_update_cache` assigns
        # `_bindings2` and `_last_version` before every proxied access.
        self.get_key_bindings = get_key_bindings
        # Not read anywhere in this class; presumably kept for compatibility.
        self.__version = 0
        self._last_child_version = None
        self._dummy = KeyBindings()  # Empty key bindings.

    def _update_cache(self) -> None:
        """
        Point the delegate at whatever `get_key_bindings` currently returns.

        The version combines the identity and the version of that object, so
        it changes when a different object is returned or when it is modified.
        """
        candidate = self.get_key_bindings()
        current = candidate if candidate else self._dummy
        assert isinstance(current, KeyBindingsBase)
        stamp = (id(current), current._version)

        self._bindings2 = current
        self._last_version = stamp
645
646
class GlobalOnlyKeyBindings(_Proxy):
    """
    Wrapper around a :class:`.KeyBindings` object that only exposes the global
    key bindings.

    Each binding's `is_global` filter is evaluated when the filtered copy is
    rebuilt, i.e. when the wrapped object's `_version` changes, and not on
    every query.
    """

    def __init__(self, key_bindings: KeyBindingsBase) -> None:
        _Proxy.__init__(self)
        self.key_bindings = key_bindings

    def _update_cache(self) -> None:
        """
        Rebuild the filtered copy if the wrapped key bindings changed.
        """
        current = self.key_bindings._version

        if self._last_version == current:
            return

        global_only = KeyBindings()
        global_only.bindings.extend(
            binding for binding in self.key_bindings.bindings if binding.is_global()
        )

        self._bindings2 = global_only
        self._last_version = current