1"""
2Key bindings registry.
3
4A `KeyBindings` object is a container that holds a list of key bindings. It has a
5very efficient internal data structure for checking which key bindings apply
6for a pressed key.
7
8Typical usage::
9
10 kb = KeyBindings()
11
12 @kb.add(Keys.ControlX, Keys.ControlC, filter=INSERT)
13 def handler(event):
14 # Handle ControlX-ControlC key sequence.
15 pass
16
17It is also possible to combine multiple KeyBindings objects. We do this in the
18default key bindings. There are some KeyBindings objects that contain the Emacs
19bindings, while others contain the Vi bindings. They are merged together using
20`merge_key_bindings`.
21
22We also have a `ConditionalKeyBindings` object that can enable/disable a group of
23key bindings at once.
24
25
26It is also possible to add a filter to a function, before a key binding has
27been assigned, through the `key_binding` decorator.::
28
29 # First define a key handler with the `filter`.
30 @key_binding(filter=condition)
31 def my_key_binding(event):
32 ...
33
34 # Later, add it to the key bindings.
35 kb.add(Keys.A, my_key_binding)
36"""
37
38from __future__ import annotations
39
40from abc import ABCMeta, abstractmethod
41from collections.abc import Callable, Coroutine, Hashable, Sequence
42from inspect import isawaitable
43from typing import (
44 TYPE_CHECKING,
45 Any,
46 TypeVar,
47 Union,
48 cast,
49)
50
51from prompt_toolkit.cache import SimpleCache
52from prompt_toolkit.filters import FilterOrBool, Never, to_filter
53from prompt_toolkit.keys import KEY_ALIASES, Keys
54
if TYPE_CHECKING:
    # Avoid circular imports.
    from .key_processor import KeyPressEvent

# The only two return values for a mouse handler (and key bindings) are
# `None` and `NotImplemented`. For the type checker it's best to annotate
# this as `object`. (The consumer never expects a more specific instance:
# checking for NotImplemented can be done using `is NotImplemented`.)
#
# This alias is defined at runtime as well (not only under TYPE_CHECKING)
# because the name is listed in `__all__`: if it only existed for the type
# checker, `from ... import *` would raise AttributeError, and resolving the
# string annotations that mention it (e.g. with `typing.get_type_hints`)
# would fail with NameError.
NotImplementedOrNone = object
# Other non-working options are:
# * Optional[Literal[NotImplemented]]
#   --> Doesn't work, Literal can't take an Any.
# * None
#   --> Doesn't work. We can't assign the result of a function that
#       returns `None` to a variable.
# * Any
#   --> Works, but too broad.
72
73
# Public names of this module (what `from ... import *` exports).
__all__ = [
    "NotImplementedOrNone",
    "Binding",
    "KeyBindingsBase",
    "KeyBindings",
    "ConditionalKeyBindings",
    "merge_key_bindings",
    "DynamicKeyBindings",
    "GlobalOnlyKeyBindings",
]

# Key bindings can be regular functions or coroutines.
# In both cases, if they return `NotImplemented`, the UI won't be invalidated.
# This is mainly used in case of mouse move events, to prevent excessive
# repainting during mouse move events.
# The names are written as strings (forward references) because
# `KeyPressEvent` is only imported inside the `TYPE_CHECKING` guard above.
KeyHandlerCallable = Callable[
    ["KeyPressEvent"],
    Union["NotImplementedOrNone", Coroutine[Any, Any, "NotImplementedOrNone"]],
]
93
94
class Binding:
    """
    Key binding: (key sequence + handler + filter).
    (Immutable binding class.)

    :param keys: Tuple of keys (`Keys` members or single-character strings)
        that trigger this binding.
    :param handler: Callable that receives the key press event. When it
        returns `NotImplemented`, the UI is not invalidated. When it returns
        an awaitable, that awaitable is run as a background task.
    :param filter: Determines when this key binding is active.
    :param eager: When True, ignore potential longer key sequences of which
        this binding is a prefix.
    :param is_global: Make this a global (always active) binding.
    :param save_before: Callable that takes the event and returns True if the
        current buffer should be saved before handling the event.
    :param record_in_macro: When True, record this key binding when a
        macro is recorded.
    """

    def __init__(
        self,
        keys: tuple[Keys | str, ...],
        handler: KeyHandlerCallable,
        filter: FilterOrBool = True,
        eager: FilterOrBool = False,
        is_global: FilterOrBool = False,
        save_before: Callable[[KeyPressEvent], bool] = (lambda e: True),
        record_in_macro: FilterOrBool = True,
    ) -> None:
        self.keys = keys
        self.handler = handler
        # Plain booleans are normalized to filter objects.
        self.filter = to_filter(filter)
        self.eager = to_filter(eager)
        self.is_global = to_filter(is_global)
        self.save_before = save_before
        self.record_in_macro = to_filter(record_in_macro)

    def call(self, event: KeyPressEvent) -> None:
        """
        Run the handler for `event`, then invalidate the UI unless the handler
        returned `NotImplemented`.

        If the handler returns an awaitable (e.g. it is a coroutine function),
        the awaitable is scheduled with `event.app.create_background_task` and
        the invalidation happens once it has completed.
        """
        result = self.handler(event)

        # If the handler is a coroutine, create an asyncio task.
        if isawaitable(result):
            awaitable = cast(Coroutine[Any, Any, "NotImplementedOrNone"], result)

            async def bg_task() -> None:
                outcome = await awaitable
                # Identity check: `!=` would dispatch to the result's own
                # `__ne__`, which may raise or return a non-bool value.
                if outcome is not NotImplemented:
                    event.app.invalidate()

            event.app.create_background_task(bg_task())

        elif result is not NotImplemented:
            event.app.invalidate()

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(keys={self.keys!r}, handler={self.handler!r})"
        )
143
144
# Sequence of key presses: a tuple whose items are `Keys` members and/or
# single-character strings (the forms `_parse_key` returns).
KeysTuple = tuple[Keys | str, ...]
147
148
class KeyBindingsBase(metaclass=ABCMeta):
    """
    Interface for a KeyBindings.

    Implementations expose their bindings as a flat list and answer the two
    lookup queries (`get_bindings_for_keys` and
    `get_bindings_starting_with_keys`). Adding and removing bindings is
    intentionally not part of this interface.
    """

    @property
    @abstractmethod
    def _version(self) -> Hashable:
        """
        For cache invalidation. This value should change every time that
        something changes in the bindings.

        Within this module it is only compared for equality (and may itself
        be a tuple of child versions), so it does not have to be an
        increasing number.
        """
        return 0

    @abstractmethod
    def get_bindings_for_keys(self, keys: KeysTuple) -> list[Binding]:
        """
        Return a list of key bindings that can handle these keys.
        (This also returns inactive bindings, so the `filter` still has to be
        called, for checking it.)

        :param keys: tuple of keys.
        """
        return []

    @abstractmethod
    def get_bindings_starting_with_keys(self, keys: KeysTuple) -> list[Binding]:
        """
        Return a list of key bindings that handle a key sequence starting with
        `keys`. (It only returns bindings whose sequences are longer than
        `keys`. And like `get_bindings_for_keys`, it also includes inactive
        bindings.)

        :param keys: tuple of keys.
        """
        return []

    @property
    @abstractmethod
    def bindings(self) -> list[Binding]:
        """
        List of `Binding` objects.
        (These need to be exposed, so that `KeyBindings` objects can be merged
        together.)
        """
        return []

    # `add` and `remove` don't have to be part of this interface.
197
198
# Type of the object decorated by `KeyBindings.add`: the decorator returns its
# argument unchanged, whether it is a plain handler or a `Binding` (as created
# by `key_binding`).
T = TypeVar("T", bound=KeyHandlerCallable | Binding)
200
201
class KeyBindings(KeyBindingsBase):
    """
    A container for a set of key bindings.

    Example usage::

        kb = KeyBindings()

        @kb.add('c-t')
        def _(event):
            print('Control-T pressed')

        @kb.add('c-a', 'c-b')
        def _(event):
            print('Control-A pressed, followed by Control-B')

        @kb.add('c-x', filter=is_searching)
        def _(event):
            print('Control-X pressed')  # Works only if we are searching.

    Lookup results are memoized per key tuple. `add` and `remove` clear those
    caches and bump `_version`.
    """

    def __init__(self) -> None:
        self._bindings: list[Binding] = []
        # Memoized results of the two lookup methods, keyed by key tuple.
        self._get_bindings_for_keys_cache: SimpleCache[KeysTuple, list[Binding]] = (
            SimpleCache(maxsize=10000)
        )
        self._get_bindings_starting_with_keys_cache: SimpleCache[
            KeysTuple, list[Binding]
        ] = SimpleCache(maxsize=1000)
        self.__version = 0  # For cache invalidation.

    def _clear_cache(self) -> None:
        """Invalidate the memoized lookups after the binding list changed."""
        self.__version += 1
        self._get_bindings_for_keys_cache.clear()
        self._get_bindings_starting_with_keys_cache.clear()

    @property
    def bindings(self) -> list[Binding]:
        # NOTE: this is the live internal list. Appending to it directly does
        # not clear the lookup caches (only `add`/`remove` do).
        return self._bindings

    @property
    def _version(self) -> Hashable:
        return self.__version

    def add(
        self,
        *keys: Keys | str,
        filter: FilterOrBool = True,
        eager: FilterOrBool = False,
        is_global: FilterOrBool = False,
        save_before: Callable[[KeyPressEvent], bool] = (lambda e: True),
        record_in_macro: FilterOrBool = True,
    ) -> Callable[[T], T]:
        """
        Decorator for adding a key binding.

        :param filter: :class:`~prompt_toolkit.filters.Filter` to determine
            when this key binding is active.
        :param eager: :class:`~prompt_toolkit.filters.Filter` or `bool`.
            When True, ignore potential longer matches when this key binding is
            hit. E.g. when there is an active eager key binding for Ctrl-X,
            execute the handler immediately and ignore the key binding for
            Ctrl-X Ctrl-E of which it is a prefix.
        :param is_global: When this key binding is added to a `Container` or
            `Control`, make it a global (always active) binding.
        :param save_before: Callable that takes an `Event` and returns True if
            we should save the current buffer, before handling the event.
            (That's the default.)
        :param record_in_macro: Record these key bindings when a macro is
            being recorded. (True by default.)

        When the decorated object is already a :class:`Binding` (e.g. created
        with `key_binding`), its handler is reused: its filter is AND-ed with
        `filter`, its `eager`/`is_global` are OR-ed with the values given here,
        and its own `save_before` and `record_in_macro` are kept.

        The decorator returns its argument unchanged.
        """
        assert keys

        keys = tuple(_parse_key(k) for k in keys)

        if isinstance(filter, Never):
            # When a filter is Never, it will always stay disabled, so in that
            # case don't bother putting it in the key bindings. It will slow
            # down every key press otherwise.
            def decorator(func: T) -> T:
                return func

        else:

            def decorator(func: T) -> T:
                if isinstance(func, Binding):
                    # We're adding an existing Binding object.
                    self.bindings.append(
                        Binding(
                            keys,
                            func.handler,
                            filter=func.filter & to_filter(filter),
                            eager=to_filter(eager) | func.eager,
                            is_global=to_filter(is_global) | func.is_global,
                            save_before=func.save_before,
                            record_in_macro=func.record_in_macro,
                        )
                    )
                else:
                    self.bindings.append(
                        Binding(
                            keys,
                            cast(KeyHandlerCallable, func),
                            filter=filter,
                            eager=eager,
                            is_global=is_global,
                            save_before=save_before,
                            record_in_macro=record_in_macro,
                        )
                    )
                self._clear_cache()

                return func

        return decorator

    def remove(self, *args: Keys | str | KeyHandlerCallable) -> None:
        """
        Remove a key binding.

        This expects either a function that was given to the `add` method as
        parameter, or a sequence of keys. Every binding whose handler (or key
        sequence) matches is removed.

        Raises `ValueError` when no binding was found, or when nothing was
        passed.

        Usage::

            remove(handler)  # Pass handler.
            remove('c-x', 'c-a')  # Or pass the key bindings.
        """
        if not args:
            raise ValueError("Nothing to remove: pass a handler or a key sequence.")

        target: object

        if callable(args[0]):
            assert len(args) == 1
            function = args[0]
            target = function

            def matches(b: Binding) -> bool:
                return b.handler == function

        else:
            key_args = cast(tuple[Keys | str, ...], args)
            keys = tuple(_parse_key(k) for k in key_args)
            target = keys

            def matches(b: Binding) -> bool:
                return b.keys == keys

        # Build the surviving list in a single pass. (Calling `list.remove`
        # while iterating the same list skips the element that follows each
        # removed one, so adjacent matches would survive.)
        remaining = [b for b in self.bindings if not matches(b)]

        if len(remaining) == len(self.bindings):
            # No key binding found for this function / key sequence.
            raise ValueError(f"Binding not found: {target!r}")

        # Slice assignment keeps the same list object, so references obtained
        # earlier through `bindings` stay in sync.
        self.bindings[:] = remaining
        self._clear_cache()

    # For backwards-compatibility.
    add_binding = add
    remove_binding = remove

    def get_bindings_for_keys(self, keys: KeysTuple) -> list[Binding]:
        """
        Return a list of key bindings that can handle this key sequence.
        (This also returns inactive bindings, so the `filter` still has to be
        called, for checking it.)

        A `Keys.Any` entry in a binding's key sequence matches any single key.
        Bindings with more `Keys.Any` wildcards come first in the result;
        bindings with the same number of wildcards keep their registration
        order (`sorted` is stable). NOTE(review): presumably the consumer
        prefers the last match, so the most specific binding wins — confirm
        against the key processor.

        :param keys: tuple of keys.
        """

        def get() -> list[Binding]:
            result: list[tuple[int, Binding]] = []

            for b in self.bindings:
                if len(keys) != len(b.keys):
                    continue

                if all(i == j or i == Keys.Any for i, j in zip(b.keys, keys)):
                    any_count = sum(1 for i in b.keys if i == Keys.Any)
                    result.append((any_count, b))

            # Place bindings that have more 'Any' occurrences in them at the
            # beginning (descending count), so that the more specific bindings
            # end up at the end of the list.
            result = sorted(result, key=lambda item: -item[0])

            return [item[1] for item in result]

        return self._get_bindings_for_keys_cache.get(keys, get)

    def get_bindings_starting_with_keys(self, keys: KeysTuple) -> list[Binding]:
        """
        Return a list of key bindings that handle a key sequence starting with
        `keys`. (It only returns bindings whose sequences are longer than
        `keys`. And like `get_bindings_for_keys`, it also includes inactive
        bindings.)

        :param keys: tuple of keys.
        """

        def get() -> list[Binding]:
            # `zip` stops at the end of `keys`, so only the prefix is compared.
            return [
                b
                for b in self.bindings
                if len(keys) < len(b.keys)
                and all(i == j or i == Keys.Any for i, j in zip(b.keys, keys))
            ]

        return self._get_bindings_starting_with_keys_cache.get(keys, get)
426
427
def _parse_key(key: Keys | str) -> str | Keys:
    """
    Normalize one key: resolve its alias and validate the result.

    Returns a `Keys` member when the (aliased) name denotes a special key,
    otherwise the single-character string itself. Raises `ValueError` when
    the name is neither.
    """
    if isinstance(key, Keys):
        # Nothing to resolve for an already-parsed key.
        return key

    resolved = KEY_ALIASES.get(key, key)
    if resolved == "space":
        # 'space' is spelled out by users but stored as the character itself.
        resolved = " "

    try:
        special = Keys(resolved)
    except ValueError:
        special = None

    if special is not None:
        return special
    if len(resolved) == 1:
        return resolved
    raise ValueError(f"Invalid key: {resolved}")
454
455
def key_binding(
    filter: FilterOrBool = True,
    eager: FilterOrBool = False,
    is_global: FilterOrBool = False,
    save_before: Callable[[KeyPressEvent], bool] = (lambda event: True),
    record_in_macro: FilterOrBool = True,
) -> Callable[[KeyHandlerCallable], Binding]:
    """
    Decorator that turns a function into a `Binding` object. This can be added
    to a `KeyBindings` object when a key binding is assigned.

    The resulting `Binding` has an empty key sequence. The actual keys are
    supplied later, when it is passed to `KeyBindings.add`.

    :param filter: Determines when the binding is active.
    :param eager: When True, ignore potential longer matches of which the
        binding's key sequence is a prefix.
    :param is_global: Make it a global (always active) binding.
    :param save_before: Callable that takes the event and returns True if the
        current buffer should be saved before handling it. `None` is accepted
        and treated like the default (always True).
    :param record_in_macro: Record this binding when a macro is recorded.
    """
    assert save_before is None or callable(save_before)

    def _always_save(event: KeyPressEvent) -> bool:
        return True

    # The assert above explicitly allows `None`; store the default callable
    # instead, because `Binding.save_before` is documented as a callable.
    save_before_fn = save_before if save_before is not None else _always_save

    filter_ = to_filter(filter)
    eager_ = to_filter(eager)
    is_global_ = to_filter(is_global)
    record_in_macro_ = to_filter(record_in_macro)
    keys: KeysTuple = ()

    def decorator(function: KeyHandlerCallable) -> Binding:
        return Binding(
            keys,
            function,
            filter=filter_,
            eager=eager_,
            is_global=is_global_,
            save_before=save_before_fn,
            record_in_macro=record_in_macro_,
        )

    return decorator
488
489
class _Proxy(KeyBindingsBase):
    """
    Shared base of the wrapper classes (conditional, merged, dynamic and
    global-only key bindings).

    Every query is forwarded to `self._bindings2`. Before forwarding, the
    subclass hook `_update_cache` is called so it can rebuild `_bindings2`
    and `_last_version` when the wrapped bindings have changed.
    """

    def __init__(self) -> None:
        # Delegate that receives all forwarded queries; kept in sync by
        # `_update_cache`.
        self._bindings2: KeyBindingsBase = KeyBindings()
        self._last_version: Hashable = ()

    def _update_cache(self) -> None:
        """
        Subclass hook. If `self._last_version` is outdated, refresh both the
        version and `self._bindings2`.
        """
        raise NotImplementedError

    def _synced_delegate(self) -> KeyBindingsBase:
        """Bring the delegate up to date and return it."""
        self._update_cache()
        return self._bindings2

    @property
    def bindings(self) -> list[Binding]:
        return self._synced_delegate().bindings

    @property
    def _version(self) -> Hashable:
        self._update_cache()
        return self._last_version

    def get_bindings_for_keys(self, keys: KeysTuple) -> list[Binding]:
        return self._synced_delegate().get_bindings_for_keys(keys)

    def get_bindings_starting_with_keys(self, keys: KeysTuple) -> list[Binding]:
        return self._synced_delegate().get_bindings_starting_with_keys(keys)
526
527
class ConditionalKeyBindings(_Proxy):
    """
    Wraps around a `KeyBindings`. Disable/enable all the key bindings according to
    the given (additional) filter.::

        @Condition
        def setting_is_true():
            return True  # or False

        registry = ConditionalKeyBindings(key_bindings, setting_is_true)

    When new key bindings are added to the wrapped object, they are also
    enabled/disabled according to the given `filter`.

    :param key_bindings: The :class:`.KeyBindingsBase` object to wrap.
    :param filter: :class:`~prompt_toolkit.filters.Filter` object (or bool)
        that is combined (AND) with the filter of every wrapped binding.
    """

    def __init__(
        self, key_bindings: KeyBindingsBase, filter: FilterOrBool = True
    ) -> None:
        _Proxy.__init__(self)

        self.key_bindings = key_bindings
        self.filter = to_filter(filter)

    def _update_cache(self) -> None:
        "If the original key bindings was changed. Update our copy version."
        expected_version = self.key_bindings._version

        if self._last_version != expected_version:
            bindings2 = KeyBindings()

            # Copy all bindings from `self.key_bindings`, adding our condition.
            # The combined filter is evaluated lazily, so changes of
            # `self.filter`'s outcome don't require rebuilding this copy.
            bindings2.bindings.extend(
                Binding(
                    keys=b.keys,
                    handler=b.handler,
                    filter=self.filter & b.filter,
                    eager=b.eager,
                    is_global=b.is_global,
                    save_before=b.save_before,
                    record_in_macro=b.record_in_macro,
                )
                for b in self.key_bindings.bindings
            )

            self._bindings2 = bindings2
            self._last_version = expected_version
577
578
class _MergedKeyBindings(_Proxy):
    """
    Merge multiple registries of key bindings into one.

    Acts as a proxy for several :class:`.KeyBindings` objects while behaving
    like a single, larger :class:`.KeyBindings`. The merged list keeps the
    registries' order, so bindings of later registries come after those of
    earlier ones.

    :param registries: List of :class:`.KeyBindings` objects.
    """

    def __init__(self, registries: Sequence[KeyBindingsBase]) -> None:
        _Proxy.__init__(self)
        self.registries = registries

    def _update_cache(self) -> None:
        """
        Rebuild the merged copy when any of the wrapped registries changed.
        """
        # The combined version is the tuple of all child versions.
        current = tuple(registry._version for registry in self.registries)
        if current == self._last_version:
            return

        merged = KeyBindings()
        for registry in self.registries:
            merged.bindings.extend(registry.bindings)

        self._bindings2 = merged
        self._last_version = current
608
609
def merge_key_bindings(bindings: Sequence[KeyBindingsBase]) -> _MergedKeyBindings:
    """
    Merge multiple :class:`.KeyBindings` objects together.

    The result stays in sync with the given objects: bindings added to or
    removed from any of them later are picked up automatically. In the merged
    list, bindings of later objects come after those of earlier ones.

    Usage::

        bindings = merge_key_bindings([bindings1, bindings2, ...])
    """
    return _MergedKeyBindings(bindings)
619
620
class DynamicKeyBindings(_Proxy):
    """
    KeyBindings class that can dynamically return any KeyBindings.

    :param get_key_bindings: Callable that returns a :class:`.KeyBindings`
        instance, or `None` for "no key bindings". It is called again on every
        query, so the active bindings can change at any time.
    """

    def __init__(self, get_key_bindings: Callable[[], KeyBindingsBase | None]) -> None:
        # Initialize the shared proxy state like the other `_Proxy` subclasses.
        _Proxy.__init__(self)
        self.get_key_bindings = get_key_bindings
        self._dummy = KeyBindings()  # Empty key bindings, used for `None`.

    def _update_cache(self) -> None:
        key_bindings = self.get_key_bindings() or self._dummy
        assert isinstance(key_bindings, KeyBindingsBase)
        # Include the object identity, so that switching to another object is
        # detected even if its own version happens to be equal.
        version = id(key_bindings), key_bindings._version

        self._bindings2 = key_bindings
        self._last_version = version
641
642
class GlobalOnlyKeyBindings(_Proxy):
    """
    Wrapper around a :class:`.KeyBindings` object that only exposes the global
    key bindings.
    """

    def __init__(self, key_bindings: KeyBindingsBase) -> None:
        _Proxy.__init__(self)
        self.key_bindings = key_bindings

    def _update_cache(self) -> None:
        """
        If the wrapped key bindings changed, rebuild our copy, keeping only
        the bindings whose `is_global` filter is currently true.

        NOTE: `is_global` is evaluated only when the wrapped object's version
        changes. A binding whose `is_global` outcome flips later is not picked
        up (or dropped) until the wrapped bindings change again.
        """
        expected_version = self.key_bindings._version

        if self._last_version != expected_version:
            bindings2 = KeyBindings()
            bindings2.bindings.extend(
                b for b in self.key_bindings.bindings if b.is_global()
            )

            self._bindings2 = bindings2
            self._last_version = expected_version