1from __future__ import annotations
2
3import os
4import signal
5import sys
6import threading
7from collections import deque
8from typing import (
9 Callable,
10 ContextManager,
11 Dict,
12 Generator,
13 Generic,
14 TypeVar,
15 Union,
16)
17
18from wcwidth import wcwidth
19
# Public API of this module: the names exported by ``from ... import *``.
# NOTE(review): ``is_windows_vt100_supported`` is defined in this module but is
# not listed here; confirm whether that omission is intentional.
__all__ = [
    "Event",
    "DummyContext",
    "get_cwidth",
    "suspend_to_background_supported",
    "is_conemu_ansi",
    "is_windows",
    "in_main_thread",
    "get_bell_environment_variable",
    "get_term_environment_variable",
    "take_using_weights",
    "to_str",
    "to_int",
    "AnyFloat",
    "to_float",
    "is_dumb_terminal",
]

# Used to ensure sphinx autodoc does not try to import platform-specific
# stuff when documenting win32.py modules.
# True when the ``sphinx.ext.autodoc`` module has already been imported at
# the time this module is first imported.
SPHINX_AUTODOC_RUNNING = "sphinx.ext.autodoc" in sys.modules
41
_Sender = TypeVar("_Sender", covariant=True)


class Event(Generic[_Sender]):
    """
    Simple event to which event handlers can be attached. For instance::

        class Cls:
            def __init__(self):
                # Define event. The first parameter is the sender.
                self.event = Event(self)

        obj = Cls()

        def handler(sender):
            pass

        # Add event handler by using the += operator.
        obj.event += handler

        # Fire event.
        obj.event()

    Handlers are called in the order in which they were added, each one
    receiving the ``sender`` given to the constructor. A handler may add or
    remove handlers (including itself) while the event is firing; such changes
    only take effect the next time the event fires.

    :param sender: Object passed to every handler when the event fires.
    :param handler: Optional handler to register right away.
    """

    def __init__(
        self, sender: _Sender, handler: Callable[[_Sender], None] | None = None
    ) -> None:
        self.sender = sender
        self._handlers: list[Callable[[_Sender], None]] = []

        if handler is not None:
            self += handler

    def __call__(self) -> None:
        "Fire event."
        # Iterate over a snapshot. If a handler removes itself from the live
        # list, the following handler shifts into the slot the iterator has
        # already visited and would be skipped. Handlers appended during firing
        # would also run in the same pass, which can loop forever.
        for handler in list(self._handlers):
            handler(self.sender)

    def fire(self) -> None:
        "Alias for just calling the event."
        self()

    def add_handler(self, handler: Callable[[_Sender], None]) -> None:
        """
        Add another handler to this callback.
        (Handler should be a callable that takes exactly one parameter: the
        sender object.)
        """
        # Add to list of event handlers.
        self._handlers.append(handler)

    def remove_handler(self, handler: Callable[[_Sender], None]) -> None:
        """
        Remove a handler from this callback.

        Removing a handler that is not registered is a no-op. If the handler
        was added more than once, only its first registration is removed.
        """
        if handler in self._handlers:
            self._handlers.remove(handler)

    def __iadd__(self, handler: Callable[[_Sender], None]) -> Event[_Sender]:
        """
        `event += handler` notation for adding a handler.
        """
        self.add_handler(handler)
        return self

    def __isub__(self, handler: Callable[[_Sender], None]) -> Event[_Sender]:
        """
        `event -= handler` notation for removing a handler.
        """
        self.remove_handler(handler)
        return self
113
114
class DummyContext(ContextManager[None]):
    """
    Context manager that does nothing.

    ``with DummyContext() as x`` binds ``x`` to ``None``. Exceptions raised
    inside the ``with`` body are never suppressed, because ``__exit__``
    returns ``None``. It serves as a stand-in wherever a real context manager
    is optional.
    """

    def __enter__(self) -> None:
        return None

    def __exit__(self, *a: object) -> None:
        # Returning a falsy value lets any in-flight exception propagate.
        return None
125
126
127class _CharSizesCache(Dict[str, int]):
128 """
129 Cache for wcwidth sizes.
130 """
131
132 LONG_STRING_MIN_LEN = 64 # Minimum string length for considering it long.
133 MAX_LONG_STRINGS = 16 # Maximum number of long strings to remember.
134
135 def __init__(self) -> None:
136 super().__init__()
137 # Keep track of the "long" strings in this cache.
138 self._long_strings: deque[str] = deque()
139
140 def __missing__(self, string: str) -> int:
141 # Note: We use the `max(0, ...` because some non printable control
142 # characters, like e.g. Ctrl-underscore get a -1 wcwidth value.
143 # It can be possible that these characters end up in the input
144 # text.
145 result: int
146 if len(string) == 1:
147 result = max(0, wcwidth(string))
148 else:
149 result = sum(self[c] for c in string)
150
151 # Store in cache.
152 self[string] = result
153
154 # Rotate long strings.
155 # (It's hard to tell what we can consider short...)
156 if len(string) > self.LONG_STRING_MIN_LEN:
157 long_strings = self._long_strings
158 long_strings.append(string)
159
160 if len(long_strings) > self.MAX_LONG_STRINGS:
161 key_to_remove = long_strings.popleft()
162 if key_to_remove in self:
163 del self[key_to_remove]
164
165 return result
166
167
168_CHAR_SIZES_CACHE = _CharSizesCache()
169
170
def get_cwidth(string: str) -> int:
    """
    Return the number of terminal columns that ``string`` occupies.

    The width comes from ``wcwidth``, summed per character with negative
    widths clamped to 0. Results are memoized in a module-level cache.
    """
    width = _CHAR_SIZES_CACHE[string]
    return width
176
177
def suspend_to_background_supported() -> bool:
    """
    Return ``True`` when this Python implementation supports
    suspend-to-background.

    Support is detected by the presence of ``signal.SIGTSTP``. That attribute
    is typically absent on Windows, where the result is usually ``False``.
    """
    has_sigtstp = hasattr(signal, "SIGTSTP")
    return has_sigtstp
184
185
def is_windows() -> bool:
    """
    Return ``True`` when running on Windows (``sys.platform == "win32"``).

    Every other platform string, e.g. ``"darwin"`` or ``"linux"``, gives
    ``False``.
    """
    platform_name = sys.platform
    return platform_name == "win32"
191
192
def is_windows_vt100_supported() -> bool:
    """
    Return ``True`` when running on Windows and the console has VT100 escape
    sequence support enabled. Always ``False`` on other platforms.
    """
    if sys.platform != "win32":
        return False

    # Import needs to be inline. Windows libraries are not always available.
    from prompt_toolkit.output.windows10 import is_win_vt100_enabled

    return is_win_vt100_enabled()
204
205
def is_conemu_ansi() -> bool:
    """
    Return ``True`` when the ConEmu Windows console is used with ANSI enabled.

    That is: the platform is Windows and the ``ConEmuANSI`` environment
    variable equals ``"ON"``. A missing variable counts as ``"OFF"``.
    """
    if sys.platform != "win32":
        return False
    return os.environ.get("ConEmuANSI", "OFF") == "ON"
211
212
def in_main_thread() -> bool:
    """
    True when the current thread is the main thread.

    The check compares thread objects by identity with
    :func:`threading.main_thread`. The previous approach compared the class
    name against ``"_MainThread"``, which is a private implementation detail
    of :mod:`threading`. It also misreported any ``Thread`` subclass that
    happens to share that name.
    """
    return threading.current_thread() is threading.main_thread()
218
219
def get_bell_environment_variable() -> bool:
    """
    Return whether the bell is enabled via ``$PROMPT_TOOLKIT_BELL``.

    The value is compared case-insensitively. ``"1"`` and ``"true"`` (e.g.
    ``TRUE``, ``True``) mean enabled; any other value means disabled. When
    the variable is not set at all, the bell is enabled.
    """
    setting = os.environ.get("PROMPT_TOOLKIT_BELL", "true").lower()
    return setting == "1" or setting == "true"
226
227
def get_term_environment_variable() -> str:
    """
    Return the value of the ``$TERM`` environment variable, or ``""`` when it
    is not set.
    """
    term = os.environ.get("TERM")
    return "" if term is None else term
231
232
_T = TypeVar("_T")


def take_using_weights(
    items: list[_T], weights: list[int]
) -> Generator[_T, None, None]:
    """
    Generator that keeps yielding items from the items list, in proportion to
    their weight. For instance::

        # Getting the first 70 items from this generator should have yielded 10
        # times A, 20 times B and 40 times C, all distributed equally..
        take_using_weights(['A', 'B', 'C'], [5, 10, 20])

    The generator never ends. Items whose weight is zero or negative are never
    yielded. Because this is a generator, all validation happens on the first
    ``next()`` call, not when the function is called.

    :param items: List of items to take from.
    :param weights: Integers representing the weight. (Numbers have to be
        integers, not floats.)
    :raises ValueError: When no item has a positive weight.
    """
    assert len(items) == len(weights)
    assert len(items) > 0

    # Remove items with zero-weight.
    pairs = [(item, w) for item, w in zip(items, weights) if w > 0]

    # Make sure that we have some items left.
    if not pairs:
        raise ValueError("Didn't get any items with a positive weight.")

    already_taken = [0] * len(pairs)
    max_weight = max(w for _, w in pairs)

    i = 0
    while True:
        # In round ``i``, every item is topped up until it has been yielded
        # at least ``i * weight / max_weight`` times in total. The heaviest
        # item therefore gets exactly one more yield per round.
        adding = True
        while adding:
            adding = False

            for index, (item, weight) in enumerate(pairs):
                # Same test as ``taken < i * weight / max_weight``, but done
                # with exact integer arithmetic (max_weight > 0).
                if already_taken[index] * max_weight < i * weight:
                    yield item
                    already_taken[index] += 1
                    adding = True

        i += 1
288
289
def to_str(value: Callable[[], str] | str) -> str:
    """
    Turn callable or string into string.

    A callable is invoked, repeatedly if it returns another callable, until a
    plain value appears. That value is then converted with ``str()``.
    """
    while callable(value):
        value = value()
    return str(value)
296
297
def to_int(value: Callable[[], int] | int) -> int:
    """
    Turn callable or int into int.

    A callable is invoked, repeatedly if it returns another callable, until a
    plain value appears. That value is then converted with ``int()``.
    """
    while callable(value):
        value = value()
    return int(value)
304
305
# A float, or a zero-argument callable that produces one.
AnyFloat = Union[Callable[[], float], float]


def to_float(value: AnyFloat) -> float:
    """
    Turn callable or float into float.

    A callable is invoked, repeatedly if it returns another callable, until a
    plain value appears. That value is then converted with ``float()``.
    """
    while callable(value):
        value = value()
    return float(value)
315
316
def is_dumb_terminal(term: str | None = None) -> bool:
    """
    True if this terminal type is considered "dumb".

    If so, we should fall back to the simplest possible form of line editing,
    without cursor positioning and color support.

    :param term: Terminal type to check. When omitted (``None``), the value of
        ``$TERM`` is used, and an unset variable counts as ``""``. The
        comparison against ``"dumb"`` and ``"unknown"`` ignores case.
    """
    if term is None:
        term = os.environ.get("TERM", "")
    return term.lower() in ("dumb", "unknown")