import collections
import queue
import sys
import threading
from concurrent.futures import Executor, Future
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union
6
7if sys.version_info >= (3, 10):
8 from typing import ParamSpec
9else:
10 from typing_extensions import ParamSpec
11
# Type variables used by the generic annotations in this module.
_T = TypeVar("_T")  # NOTE(review): not referenced in the code visible here; possibly unused
_P = ParamSpec("_P")  # parameter spec of a callable handed to an executor (fn, *args, **kwargs)
_R = TypeVar("_R")  # return type of that callable, i.e. the result type of its Future
15
16
class _WorkItem:
    """
    A single queued call: a callable, its arguments, and the future that
    receives the call's outcome.

    Mirrors the work item used internally by ThreadPoolExecutor; that class is
    private, so it is re-implemented here rather than imported.
    """

    def __init__(
        self,
        future: "Future[_R]",
        fn: Callable[_P, _R],
        *args: _P.args,
        **kwargs: _P.kwargs,
    ):
        self.future, self.fn = future, fn
        self.args, self.kwargs = args, kwargs

    def run(self) -> None:
        """
        Call ``fn(*args, **kwargs)`` and store its result or exception on the
        future.

        Returns without calling anything if the future was already cancelled.
        """
        # Frame-hiding marker, presumably honoured by traceback/debug tooling
        # that looks for this local name.
        __traceback_hide__ = True  # noqa: F841
        if self.future.set_running_or_notify_cancel():
            try:
                value = self.fn(*self.args, **self.kwargs)
            except BaseException as error:
                self.future.set_exception(error)
                # Break a reference cycle: the exception's traceback references
                # this frame, whose `self` references the future now holding
                # that exception. No other local here may reference the future.
                self = None  # type: ignore[assignment]
            else:
                self.future.set_result(value)
47
48
class CurrentThreadExecutor(Executor):
    """
    An Executor that actually runs code in the thread it is instantiated in.
    Passed to other threads running async code, so they can run sync code in
    the thread they came from.

    Other threads call ``submit``. The owning thread calls ``run_until_future``,
    which executes submitted work until the given future is done and every
    already-accepted work item has run. After that the executor refuses new
    work with ``RuntimeError``.
    """

    def __init__(self) -> None:
        # The only thread allowed to run work (and the one forbidden to submit).
        self._work_thread = threading.current_thread()
        # Guards _work_items and _broken. Notified whenever work is added or
        # the executor is told to stop, so the waiting owner thread wakes up.
        self._work_ready = threading.Condition(threading.Lock())
        # Pending work, appended by _submit and consumed by run_until_future.
        self._work_items: "collections.deque[_WorkItem]" = collections.deque()
        # True once no further work will be accepted.
        self._broken = False

    def run_until_future(self, future: "Future[Any]") -> None:
        """
        Runs the code in the work queue until a result is available from the future.
        Should be run from the thread the executor is initialised in.

        Every work item accepted before the executor stops is run, so no
        submitter is left waiting on a future that would never complete.
        On return (normal or exceptional) the executor is marked broken.

        Raises:
            RuntimeError: if called from a thread other than the one the
                executor was created in.
        """
        # Check we're in the right thread
        if threading.current_thread() != self._work_thread:
            raise RuntimeError(
                "You cannot run CurrentThreadExecutor from a different thread"
            )

        def wake(_: "Future[Any]") -> None:
            # Stop accepting new work and wake the loop below. Holding the lock
            # makes this atomic with respect to _submit's check-and-append.
            with self._work_ready:
                self._broken = True
                self._work_ready.notify()

        # If the future is already done this runs `wake` immediately, in this
        # thread, while the lock is not held.
        future.add_done_callback(wake)
        try:
            while True:
                with self._work_ready:
                    while not self._work_items and not future.done():
                        self._work_ready.wait()
                    if not self._work_items:
                        # The future is done and the queue is drained. Refuse
                        # further work inside this same critical section so no
                        # submission can slip in after the decision to stop.
                        self._broken = True
                        return
                    work_item = self._work_items.popleft()
                # Run outside the lock so submitters are never blocked by work.
                work_item.run()
                del work_item
        finally:
            # Also covers abnormal exits (e.g. KeyboardInterrupt while waiting).
            with self._work_ready:
                self._broken = True

    def _submit(
        self,
        fn: Callable[_P, _R],
        *args: _P.args,
        **kwargs: _P.kwargs,
    ) -> "Future[_R]":
        """
        Queue ``fn(*args, **kwargs)`` to run on the executor's own thread.

        Returns:
            A Future that receives the call's result or exception.

        Raises:
            RuntimeError: if called from the executor's own thread, or after the
                executor has stopped accepting work.
        """
        # Check they're not submitting from the same thread
        if threading.current_thread() == self._work_thread:
            raise RuntimeError(
                "You cannot submit onto CurrentThreadExecutor from its own thread"
            )
        f: "Future[_R]" = Future()
        work_item = _WorkItem(f, fn, *args, **kwargs)
        with self._work_ready:
            # Check they're not too late or the executor errored. This is done
            # under the lock so the item can never be queued after the owner
            # thread has decided to stop (which would leave `f` pending forever).
            if self._broken:
                raise RuntimeError("CurrentThreadExecutor already quit or is broken")
            # Add to work queue and wake the owner thread
            self._work_items.append(work_item)
            self._work_ready.notify()
        # Return the future
        return f

    # Python 3.9+ has a new signature for submit with a "/" after `fn`, to enforce
    # it to be a positional argument. If we ignore[override], mypy on 3.9+ will be
    # happy but 3.8 will say that the ignore comment is unused, even when
    # defining them differently based on sys.version_info.
    # We should be able to remove this when we drop support for 3.8.
    if not TYPE_CHECKING:

        def submit(self, fn, *args, **kwargs):
            """Queue ``fn(*args, **kwargs)`` on the owning thread; return its Future."""
            return self._submit(fn, *args, **kwargs)