1import sys
2import threading
3from collections import deque
4from concurrent.futures import Executor, Future
5from typing import Any, Callable, TypeVar
6
7if sys.version_info >= (3, 10):
8 from typing import ParamSpec
9else:
10 from typing_extensions import ParamSpec
11
12_T = TypeVar("_T")
13_P = ParamSpec("_P")
14_R = TypeVar("_R")
15
16
class _WorkItem:
    """
    One queued call: a callable, its arguments, and the Future that receives
    its outcome.

    Copied from ThreadPoolExecutor (but it's private, so we're not going to
    rely on importing it)
    """

    def __init__(
        self,
        future: "Future[_R]",
        fn: Callable[_P, _R],
        *args: _P.args,
        **kwargs: _P.kwargs,
    ):
        self.future, self.fn = future, fn
        self.args, self.kwargs = args, kwargs

    def run(self) -> None:
        """
        Invoke ``fn(*args, **kwargs)`` unless the future was already cancelled.

        The return value is stored with ``set_result``; any raised exception
        (including ``BaseException`` subclasses) is stored with
        ``set_exception`` rather than propagated to the caller.
        """
        # Marker local; presumably read by traceback tooling to hide this frame.
        __traceback_hide__ = True  # noqa: F841
        if self.future.set_running_or_notify_cancel():
            try:
                outcome = self.fn(*self.args, **self.kwargs)
            except BaseException as exc:
                self.future.set_exception(exc)
                # Break a reference cycle with the exception 'exc': its
                # traceback holds this frame, whose locals hold `self`.
                self = None  # type: ignore[assignment]
            else:
                self.future.set_result(outcome)
47
48
class CurrentThreadExecutor(Executor):
    """
    Executor whose submitted work runs on the thread that constructed it.

    It is handed to other threads (those running async code) so they can get
    sync code executed back on the thread they came from. The constructing
    thread services the queue by calling ``run_until_future``. Once the
    awaited future completes, this executor is marked broken, and later
    submissions are forwarded to ``old_executor`` (or the nearest non-broken
    executor further up that chain).
    """

    def __init__(self, old_executor: "CurrentThreadExecutor | None") -> None:
        self._old_executor = old_executor
        self._work_thread = threading.current_thread()
        # The condition's lock guards both _work_items and _broken.
        self._work_ready = threading.Condition(threading.Lock())
        self._work_items = deque[_WorkItem]()  # synchronized by _work_ready
        self._broken = False  # synchronized by _work_ready

    def run_until_future(self, future: "Future[Any]") -> None:
        """
        Execute queued work on this thread until ``future`` has completed.

        Must be called from the thread the executor was created in. Items that
        were queued before ``future`` finished are still drained before this
        method returns. After it returns, the executor stays marked broken.

        Raises:
            RuntimeError: if called from any other thread.
        """
        if threading.current_thread() != self._work_thread:
            raise RuntimeError(
                "You cannot run CurrentThreadExecutor from a different thread"
            )

        cond = self._work_ready

        def _mark_broken(_finished: "Future[Any]") -> None:
            # Invoked once `future` completes: stop accepting new work and
            # wake the loop below so it can exit after draining the queue.
            with cond:
                self._broken = True
                cond.notify()

        future.add_done_callback(_mark_broken)

        while True:
            with cond:
                cond.wait_for(lambda: self._work_items or self._broken)
                if not self._work_items:
                    # Broken and nothing left to do.
                    return
                item = self._work_items.popleft()
            # Run outside the lock so submitters never block on the work itself.
            item.run()
            del item

    def submit(
        self,
        fn: Callable[_P, _R],
        /,
        *args: _P.args,
        **kwargs: _P.kwargs,
    ) -> "Future[_R]":
        """
        Queue ``fn(*args, **kwargs)`` to run on an executor's work thread.

        If this executor is already broken, the item goes to the closest
        executor in the ``old_executor`` chain that is not broken.

        Returns:
            A Future that will hold the call's result or exception.

        Raises:
            RuntimeError: if called from this executor's own thread, or if
                every executor in the chain is broken.
        """
        if threading.current_thread() == self._work_thread:
            raise RuntimeError(
                "You cannot submit onto CurrentThreadExecutor from its own thread"
            )
        result_future: "Future[_R]" = Future()
        item = _WorkItem(result_future, fn, *args, **kwargs)

        target: "CurrentThreadExecutor | None" = self
        while target is not None:
            with target._work_ready:
                if not target._broken:
                    target._work_items.append(item)
                    target._work_ready.notify()
                    return result_future
                parent = target._old_executor
            target = parent
        raise RuntimeError("CurrentThreadExecutor already quit or is broken")