1from __future__ import annotations
2
3import sys
4from collections.abc import Awaitable, Callable, Generator
5from concurrent.futures import Future
6from contextlib import (
7 AbstractAsyncContextManager,
8 AbstractContextManager,
9 contextmanager,
10)
11from dataclasses import dataclass, field
12from inspect import isawaitable
13from threading import Lock, Thread, get_ident
14from types import TracebackType
15from typing import (
16 Any,
17 Generic,
18 TypeVar,
19 cast,
20 overload,
21)
22
23from ._core import _eventloop
24from ._core._eventloop import get_async_backend, get_cancelled_exc_class, threadlocals
25from ._core._synchronization import Event
26from ._core._tasks import CancelScope, create_task_group
27from .abc import AsyncBackend
28from .abc._tasks import TaskStatus
29
30if sys.version_info >= (3, 11):
31 from typing import TypeVarTuple, Unpack
32else:
33 from typing_extensions import TypeVarTuple, Unpack
34
# Return type of the callables run through the helpers and portal methods below
T_Retval = TypeVar("T_Retval")
# Value type produced by an async context manager wrapped via a blocking portal
T_co = TypeVar("T_co", covariant=True)
# Positional argument types forwarded unchanged to the target callable
PosArgsT = TypeVarTuple("PosArgsT")
38
39
def run(
    func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], *args: Unpack[PosArgsT]
) -> T_Retval:
    """
    Call a coroutine function from a worker thread.

    The backend and event loop token are looked up from the thread-local state of the
    current AnyIO worker thread, and the coroutine function is handed over to that
    backend for execution.

    :param func: a coroutine function
    :param args: positional arguments for the callable
    :return: the return value of the coroutine function
    :raises RuntimeError: if the current thread is not an AnyIO worker thread

    """
    try:
        backend = threadlocals.current_async_backend
        loop_token = threadlocals.current_token
    except AttributeError:
        # Missing thread-local state means this is not an AnyIO worker thread
        raise RuntimeError(
            "This function can only be run from an AnyIO worker thread"
        ) from None
    else:
        return backend.run_async_from_thread(func, args, token=loop_token)
60
61
def run_sync(
    func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
) -> T_Retval:
    """
    Call a function in the event loop thread from a worker thread.

    The backend and event loop token are looked up from the thread-local state of the
    current AnyIO worker thread, and the call is handed over to that backend.

    :param func: a callable
    :param args: positional arguments for the callable
    :return: the return value of the callable
    :raises RuntimeError: if the current thread is not an AnyIO worker thread

    """
    try:
        backend = threadlocals.current_async_backend
        loop_token = threadlocals.current_token
    except AttributeError:
        # Missing thread-local state means this is not an AnyIO worker thread
        raise RuntimeError(
            "This function can only be run from an AnyIO worker thread"
        ) from None
    else:
        return backend.run_sync_from_thread(func, args, token=loop_token)
82
83
class _BlockingAsyncContextManager(Generic[T_co], AbstractContextManager):
    """
    Synchronous context manager that drives an async context manager via a portal.

    ``__enter__()`` spawns :meth:`run_async_cm` as a task through the portal and
    blocks until the async context manager has been entered; ``__exit__()`` wakes
    that task up so it proceeds to ``__aexit__()``, and blocks until it has finished.
    """

    _enter_future: Future[T_co]
    _exit_future: Future[bool | None]
    _exit_event: Event
    # Exception info passed to __aexit__(); remains (None, None, None) unless
    # __exit__() recorded something before the task reached __aexit__()
    _exit_exc_info: tuple[
        type[BaseException] | None, BaseException | None, TracebackType | None
    ] = (None, None, None)

    def __init__(
        self, async_cm: AbstractAsyncContextManager[T_co], portal: BlockingPortal
    ):
        self._async_cm = async_cm
        self._portal = portal

    async def run_async_cm(self) -> bool | None:
        """
        Enter the async context manager, wait for the sync side to exit, then exit it.

        The outcome of ``__aenter__()`` is reported through ``_enter_future``. The
        return value of ``__aexit__()`` becomes this coroutine's return value.
        """
        try:
            self._exit_event = Event()
            value = await self._async_cm.__aenter__()
        except BaseException as exc:
            self._enter_future.set_exception(exc)
            raise
        else:
            self._enter_future.set_result(value)

        try:
            # Wait for the sync context manager to exit.
            # This next statement can raise `get_cancelled_exc_class()` if
            # something went wrong in a task group in this async context
            # manager.
            await self._exit_event.wait()
        except BaseException:
            # In case of cancellation we may get here before
            # `_BlockingAsyncContextManager.__exit__` has been called and has
            # recorded `_exit_exc_info`; the class-level default is used then.
            # Returning from this handler deliberately discards the in-flight
            # exception, just like the previous `return` inside `finally` did, but
            # without the SyntaxWarning Python 3.14 emits for that (PEP 765).
            return await self._async_cm.__aexit__(*self._exit_exc_info)

        return await self._async_cm.__aexit__(*self._exit_exc_info)

    def __enter__(self) -> T_co:
        # Spawn the task driving the async context manager, then block until its
        # __aenter__() has produced a value (or raised)
        self._enter_future = Future()
        self._exit_future = self._portal.start_task_soon(self.run_async_cm)
        return self._enter_future.result()

    def __exit__(
        self,
        __exc_type: type[BaseException] | None,
        __exc_value: BaseException | None,
        __traceback: TracebackType | None,
    ) -> bool | None:
        # Record the exception info for __aexit__(), wake up run_async_cm() in the
        # event loop, then wait for the return value of __aexit__()
        self._exit_exc_info = __exc_type, __exc_value, __traceback
        self._portal.call(self._exit_event.set)
        return self._exit_future.result()
135
136
class _BlockingPortalTaskStatus(TaskStatus):
    """
    Task status handed to tasks launched with :meth:`BlockingPortal.start_task`.

    Reporting readiness resolves the wrapped future, which is the future the
    calling thread waits on in ``start_task()``.
    """

    def __init__(self, future: Future):
        # Future that receives the value passed to started()
        self._future = future

    def started(self, value: object = None) -> None:
        """Mark the task as started, resolving the wrapped future with ``value``."""
        status_future = self._future
        status_future.set_result(value)
143
144
class BlockingPortal:
    """An object that lets external threads run code in an asynchronous event loop."""

    def __new__(cls) -> BlockingPortal:
        # Instantiation is delegated to the current async backend, which returns its
        # own portal implementation
        return get_async_backend().create_blocking_portal()

    def __init__(self) -> None:
        # ID of the thread running the event loop; reset to None by stop()
        self._event_loop_thread_id: int | None = get_ident()
        self._stop_event = Event()
        self._task_group = create_task_group()
        self._cancelled_exc_class = get_cancelled_exc_class()

    async def __aenter__(self) -> BlockingPortal:
        await self._task_group.__aenter__()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        # Stop accepting new calls first, then exit the task group hosting the
        # spawned tasks
        await self.stop()
        return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)

    def _check_running(self) -> None:
        """
        Ensure that the portal can accept a call from the current thread.

        :raises RuntimeError: if the portal has been stopped, or if called from the
            event loop thread itself (blocking there would stall the very loop that
            has to carry out the call)

        """
        if self._event_loop_thread_id is None:
            raise RuntimeError("This portal is not running")
        if self._event_loop_thread_id == get_ident():
            raise RuntimeError(
                "This method cannot be called from the event loop thread"
            )

    async def sleep_until_stopped(self) -> None:
        """Sleep until :meth:`stop` is called."""
        await self._stop_event.wait()

    async def stop(self, cancel_remaining: bool = False) -> None:
        """
        Signal the portal to shut down.

        This marks the portal as no longer accepting new calls and exits from
        :meth:`sleep_until_stopped`.

        :param cancel_remaining: ``True`` to cancel all the remaining tasks, ``False``
            to let them finish before returning

        """
        self._event_loop_thread_id = None
        self._stop_event.set()
        if cancel_remaining:
            self._task_group.cancel_scope.cancel()

    async def _call_func(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        future: Future[T_Retval],
    ) -> None:
        """
        Run ``func`` in the event loop and report its outcome through ``future``.

        If ``func`` returns an awaitable, it is awaited inside a cancel scope that is
        cancelled when ``future`` is cancelled. A cancellation exception resolves
        ``future`` as cancelled; any other exception is stored in ``future`` (unless
        it was already cancelled), and exceptions that are not ``Exception``
        subclasses are re-raised as well.
        """

        def callback(f: Future[T_Retval]) -> None:
            # Runs in whichever thread completes or cancels ``f``. A cancellation
            # coming from a foreign thread is forwarded to the event loop; in the
            # event loop thread itself the ``future.cancelled()`` re-check below
            # takes care of it.
            cancel_scope = scope  # snapshot: the finally clause resets it to None
            if (
                cancel_scope is not None
                and f.cancelled()
                and self._event_loop_thread_id not in (None, get_ident())
            ):
                self.call(cancel_scope.cancel)

        try:
            retval_or_awaitable = func(*args, **kwargs)
            if isawaitable(retval_or_awaitable):
                with CancelScope() as scope:
                    # Register the callback *before* checking for cancellation:
                    # checking first would miss a cancel() issued by another thread
                    # between the check and the registration.
                    future.add_done_callback(callback)
                    if future.cancelled():
                        scope.cancel()

                    retval = await retval_or_awaitable
            else:
                retval = retval_or_awaitable
        except self._cancelled_exc_class:
            future.cancel()
            future.set_running_or_notify_cancel()
        except BaseException as exc:
            if not future.cancelled():
                future.set_exception(exc)

            # Let base exceptions fall through
            if not isinstance(exc, Exception):
                raise
        else:
            if not future.cancelled():
                future.set_result(retval)
        finally:
            # Drop the cancel scope reference captured by the callback closure
            scope = None  # type: ignore[assignment]

    def _spawn_task_from_thread(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        name: object,
        future: Future[T_Retval],
    ) -> None:
        """
        Spawn a new task using the given callable.

        Implementers must ensure that the future is resolved when the task finishes.

        :param func: a callable
        :param args: positional arguments to be passed to the callable
        :param kwargs: keyword arguments to be passed to the callable
        :param name: name of the task (will be coerced to a string if not ``None``)
        :param future: a future that will resolve to the return value of the callable,
            or the exception raised during its execution

        """
        raise NotImplementedError

    @overload
    def call(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        *args: Unpack[PosArgsT],
    ) -> T_Retval: ...

    @overload
    def call(
        self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
    ) -> T_Retval: ...

    def call(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        *args: Unpack[PosArgsT],
    ) -> T_Retval:
        """
        Call the given function in the event loop thread.

        If the callable returns a coroutine object, it is awaited on. The calling
        thread blocks until the result is available; an exception raised by the
        callable is re-raised here.

        :param func: any callable
        :param args: positional arguments passed to ``func``
        :return: the return value of ``func`` (or of the awaitable it returned)
        :raises RuntimeError: if the portal is not running or if this method is called
            from within the event loop thread

        """
        return cast(T_Retval, self.start_task_soon(func, *args).result())

    @overload
    def start_task_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> Future[T_Retval]: ...

    @overload
    def start_task_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], T_Retval],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> Future[T_Retval]: ...

    def start_task_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> Future[T_Retval]:
        """
        Start a task in the portal's task group.

        The task will be run inside a cancel scope which can be cancelled by cancelling
        the returned future.

        :param func: the target function
        :param args: positional arguments passed to ``func``
        :param name: name of the task (will be coerced to a string if not ``None``)
        :return: a future that resolves with the return value of the callable if the
            task completes successfully, or with the exception raised in the task
        :raises RuntimeError: if the portal is not running or if this method is called
            from within the event loop thread
        :rtype: concurrent.futures.Future[T_Retval]

        .. versionadded:: 3.0

        """
        self._check_running()
        f: Future[T_Retval] = Future()
        self._spawn_task_from_thread(func, args, {}, name, f)
        return f

    def start_task(
        self,
        func: Callable[..., Awaitable[T_Retval]],
        *args: object,
        name: object = None,
    ) -> tuple[Future[T_Retval], Any]:
        """
        Start a task in the portal's task group and wait until it signals for readiness.

        This method works the same way as :meth:`.abc.TaskGroup.start`.

        :param func: the target function
        :param args: positional arguments passed to ``func``
        :param name: name of the task (will be coerced to a string if not ``None``)
        :return: a tuple of (future, task_status_value) where the ``task_status_value``
            is the value passed to ``task_status.started()`` from within the target
            function
        :rtype: tuple[concurrent.futures.Future[T_Retval], Any]

        .. versionadded:: 3.0

        """

        def task_done(future: Future[T_Retval]) -> None:
            # If the task ends before calling task_status.started(), resolve the
            # status future anyway so the waiting thread does not block forever
            if task_status_future.done():
                return

            if future.cancelled():
                task_status_future.cancel()
                return

            # Compare against None: an exception object is not guaranteed to be
            # truthy, and exception() only needs to be called once
            exc = future.exception()
            if exc is None:
                exc = RuntimeError(
                    "Task exited without calling task_status.started()"
                )

            task_status_future.set_exception(exc)

        self._check_running()
        task_status_future: Future = Future()
        task_status = _BlockingPortalTaskStatus(task_status_future)
        f: Future = Future()
        f.add_done_callback(task_done)
        self._spawn_task_from_thread(func, args, {"task_status": task_status}, name, f)
        return f, task_status_future.result()

    def wrap_async_context_manager(
        self, cm: AbstractAsyncContextManager[T_co]
    ) -> AbstractContextManager[T_co]:
        """
        Wrap an async context manager as a synchronous context manager via this portal.

        Spawns a task that will call both ``__aenter__()`` and ``__aexit__()``, stopping
        in the middle until the synchronous context manager exits.

        :param cm: an asynchronous context manager
        :return: a synchronous context manager

        .. versionadded:: 2.1

        """
        return _BlockingAsyncContextManager(cm, self)
396
397
@dataclass
class BlockingPortalProvider:
    """
    A manager for a blocking portal. Used as a context manager. The first thread to
    enter this context manager causes a blocking portal to be started with the specific
    parameters, and the last thread to exit causes the portal to be shut down. Thus,
    there will be exactly one blocking portal running in this context as long as at
    least one thread has entered this context manager.

    The parameters are the same as for :func:`~anyio.run`.

    :param backend: name of the backend
    :param backend_options: backend options

    .. versionadded:: 4.4
    """

    backend: str = "asyncio"
    backend_options: dict[str, Any] | None = None
    # Guards _leases, _portal and _portal_cm across threads
    _lock: Lock = field(init=False, default_factory=Lock)
    # Number of threads currently inside the context
    _leases: int = field(init=False, default=0)
    _portal: BlockingPortal = field(init=False)
    # Context manager of the running portal; None while no portal is running
    _portal_cm: AbstractContextManager[BlockingPortal] | None = field(
        init=False, default=None
    )

    def __enter__(self) -> BlockingPortal:
        with self._lock:
            if self._portal_cm is None:
                portal_cm = start_blocking_portal(self.backend, self.backend_options)
                # Enter before publishing: if starting the portal fails, the
                # provider must remain in its "no portal" state so that a later
                # __enter__() retries instead of returning a missing portal.
                self._portal = portal_cm.__enter__()
                self._portal_cm = portal_cm

            self._leases += 1
            return self._portal

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        portal_cm: AbstractContextManager[BlockingPortal] | None = None
        with self._lock:
            # Explicit check instead of assert, so misuse is still caught under -O
            if self._portal_cm is None or self._leases <= 0:
                raise RuntimeError(
                    "BlockingPortalProvider exited more times than it was entered"
                )

            self._leases -= 1
            if not self._leases:
                portal_cm = self._portal_cm
                self._portal_cm = None
                del self._portal

        # Shut the portal down outside the lock, as that can block for a while
        if portal_cm:
            portal_cm.__exit__(None, None, None)
453
454
@contextmanager
def start_blocking_portal(
    backend: str = "asyncio", backend_options: dict[str, Any] | None = None
) -> Generator[BlockingPortal, Any, None]:
    """
    Start a new event loop in a new thread and run a blocking portal in its main task.

    The parameters are the same as for :func:`~anyio.run`.

    On exit, the portal is stopped (cancelling its remaining tasks if the ``with``
    block raised an exception) and the event loop thread is joined.

    :param backend: name of the backend
    :param backend_options: backend options
    :return: a context manager that yields a blocking portal

    .. versionchanged:: 3.0
        Usage as a context manager is now required.

    """

    async def serve_portal() -> None:
        # Main task of the new event loop: publish the portal, then idle until
        # the portal is stopped
        async with BlockingPortal() as running_portal:
            portal_future.set_result(running_portal)
            await running_portal.sleep_until_stopped()

    def loop_thread_main() -> None:
        # Nothing to do if the future was cancelled before this thread got going
        if not portal_future.set_running_or_notify_cancel():
            return

        try:
            _eventloop.run(
                serve_portal, backend=backend, backend_options=backend_options
            )
        except BaseException as exc:
            # Only failures before the portal was handed out are reported; later
            # errors from the event loop are not propagated from here
            if not portal_future.done():
                portal_future.set_exception(exc)

    portal_future: Future[BlockingPortal] = Future()
    loop_thread = Thread(target=loop_thread_main, daemon=True)
    loop_thread.start()
    try:
        cancel_remaining = False
        portal = portal_future.result()
        try:
            yield portal
        except BaseException:
            # The with-block failed: cancel the leftover tasks instead of waiting
            cancel_remaining = True
            raise
        finally:
            try:
                portal.call(portal.stop, cancel_remaining)
            except RuntimeError:
                # Presumably the portal is no longer running (e.g. already stopped)
                pass
    finally:
        loop_thread.join()
506
507
def check_cancelled() -> None:
    """
    Check whether the host task running the current worker thread has been cancelled.

    If the cancel scope of that host task has been cancelled, the backend-specific
    cancellation exception is raised; otherwise this returns normally.

    :raises RuntimeError: if the current thread was not spawned by
        :func:`.to_thread.run_sync`

    """
    try:
        backend: AsyncBackend = threadlocals.current_async_backend
    except AttributeError:
        # Missing thread-local state means this is not an AnyIO worker thread
        raise RuntimeError(
            "This function can only be run from an AnyIO worker thread"
        ) from None
    else:
        backend.check_cancelled()