1from __future__ import annotations
2
3import sys
4from collections.abc import Awaitable, Callable, Generator
5from concurrent.futures import Future
6from contextlib import (
7 AbstractAsyncContextManager,
8 AbstractContextManager,
9 contextmanager,
10)
11from dataclasses import dataclass, field
12from inspect import isawaitable
13from threading import Lock, Thread, current_thread, get_ident
14from types import TracebackType
15from typing import (
16 Any,
17 Generic,
18 TypeVar,
19 cast,
20 overload,
21)
22
23from ._core import _eventloop
24from ._core._eventloop import get_async_backend, get_cancelled_exc_class, threadlocals
25from ._core._synchronization import Event
26from ._core._tasks import CancelScope, create_task_group
27from .abc import AsyncBackend
28from .abc._tasks import TaskStatus
29
30if sys.version_info >= (3, 11):
31 from typing import TypeVarTuple, Unpack
32else:
33 from typing_extensions import TypeVarTuple, Unpack
34
# Return type of the callables dispatched by run(), run_sync() and BlockingPortal.
T_Retval = TypeVar("T_Retval")
# Value type yielded by the wrapped (async) context managers.
T_co = TypeVar("T_co", covariant=True)
# Positional-argument types forwarded unchanged to the dispatched callables.
PosArgsT = TypeVarTuple("PosArgsT")
38
39
def run(
    func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], *args: Unpack[PosArgsT]
) -> T_Retval:
    """
    Run a coroutine function in the host event loop from an AnyIO worker thread.

    :param func: a coroutine function
    :param args: positional arguments for the callable
    :return: the return value of the coroutine function
    :raises RuntimeError: if the calling thread is not an AnyIO worker thread

    """
    # Both attributes are thread-local and only present in AnyIO worker threads.
    try:
        backend, portal_token = (
            threadlocals.current_async_backend,
            threadlocals.current_token,
        )
    except AttributeError:
        raise RuntimeError(
            "This function can only be run from an AnyIO worker thread"
        ) from None

    return backend.run_async_from_thread(func, args, token=portal_token)
60
61
def run_sync(
    func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
) -> T_Retval:
    """
    Run a plain (synchronous) callable in the event loop thread from a worker thread.

    :param func: a callable
    :param args: positional arguments for the callable
    :return: the return value of the callable
    :raises RuntimeError: if the calling thread is not an AnyIO worker thread

    """
    try:
        backend = threadlocals.current_async_backend
        portal_token = threadlocals.current_token
    except AttributeError:
        raise RuntimeError(
            "This function can only be run from an AnyIO worker thread"
        ) from None
    else:
        # Outside the try body so errors from the call itself are not masked.
        return backend.run_sync_from_thread(func, args, token=portal_token)
82
83
class _BlockingAsyncContextManager(Generic[T_co], AbstractContextManager):
    """
    Synchronous context manager that drives an asynchronous one through a
    :class:`BlockingPortal`.

    ``__enter__()`` spawns :meth:`run_async_cm` as a portal task. That task calls
    ``__aenter__()`` and then waits until ``__exit__()`` signals it to call
    ``__aexit__()``.
    """

    # Resolved by run_async_cm() with the value of __aenter__() (or its exception).
    _enter_future: Future[T_co]
    # Future of the run_async_cm() task itself; resolves to __aexit__()'s result.
    _exit_future: Future[bool | None]
    # Set from __exit__() (through the portal) to let run_async_cm() proceed.
    _exit_event: Event
    # Exception info handed to __aexit__(); stays (None, None, None) if
    # __aexit__() runs before __exit__() has recorded anything.
    _exit_exc_info: tuple[
        type[BaseException] | None, BaseException | None, TracebackType | None
    ] = (None, None, None)

    def __init__(
        self, async_cm: AbstractAsyncContextManager[T_co], portal: BlockingPortal
    ):
        """
        :param async_cm: the asynchronous context manager to drive
        :param portal: the portal whose event loop runs ``async_cm``
        """
        self._async_cm = async_cm
        self._portal = portal

    async def run_async_cm(self) -> bool | None:
        """
        Portal task: enter ``async_cm``, wait for the sync exit, then exit it.

        :return: the return value of ``async_cm.__aexit__()``
        """
        try:
            # Created here (in the event loop) before _enter_future is resolved,
            # so it exists by the time __exit__() can reference it.
            self._exit_event = Event()
            value = await self._async_cm.__aenter__()
        except BaseException as exc:
            # Unblock __enter__() with the failure, then fail the task as well.
            self._enter_future.set_exception(exc)
            raise
        else:
            self._enter_future.set_result(value)

        try:
            # Wait for the sync context manager to exit.
            # This next statement can raise `get_cancelled_exc_class()` if
            # something went wrong in a task group in this async context
            # manager.
            await self._exit_event.wait()
        finally:
            # In case of cancellation, it could be that we end up here before
            # `_BlockingAsyncContextManager.__exit__` is called, and an
            # `_exit_exc_info` has been set.
            result = await self._async_cm.__aexit__(*self._exit_exc_info)

        return result

    def __enter__(self) -> T_co:
        """Start the portal task and block until ``__aenter__()`` has finished."""
        self._enter_future = Future()
        self._exit_future = self._portal.start_task_soon(self.run_async_cm)
        return self._enter_future.result()

    def __exit__(
        self,
        __exc_type: type[BaseException] | None,
        __exc_value: BaseException | None,
        __traceback: TracebackType | None,
    ) -> bool | None:
        """
        Record the exception info, wake the portal task, and block until
        ``__aexit__()`` has returned; its return value decides suppression.
        """
        # Must be stored before the event is set so __aexit__() sees it.
        self._exit_exc_info = __exc_type, __exc_value, __traceback
        self._portal.call(self._exit_event.set)
        return self._exit_future.result()
136
137
class _BlockingPortalTaskStatus(TaskStatus):
    """Task status that reports readiness by resolving a concurrent future."""

    def __init__(self, future: Future):
        # Future the blocked caller of BlockingPortal.start_task() waits on.
        self._future = future

    def started(self, value: object = None) -> None:
        """Signal readiness, handing ``value`` over to the waiting caller."""
        readiness = self._future
        readiness.set_result(value)
144
145
class BlockingPortal:
    """An object that lets external threads run code in an asynchronous event loop."""

    def __new__(cls) -> BlockingPortal:
        # Construction is delegated to the current async backend; presumably it
        # returns a backend-specific subclass implementing
        # _spawn_task_from_thread() -- TODO confirm in the backend modules.
        return get_async_backend().create_blocking_portal()

    def __init__(self) -> None:
        # Thread ID of the event loop this portal lives in; stop() resets it to
        # None so that _check_running() rejects further calls.
        self._event_loop_thread_id: int | None = get_ident()
        self._stop_event = Event()
        self._task_group = create_task_group()
        self._cancelled_exc_class = get_cancelled_exc_class()

    async def __aenter__(self) -> BlockingPortal:
        """Enter the portal's task group and return the portal."""
        await self._task_group.__aenter__()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool:
        """Stop the portal (letting remaining tasks finish) and exit the task group."""
        await self.stop()
        return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)

    def _check_running(self) -> None:
        """
        Ensure the portal can accept calls from the current thread.

        :raises RuntimeError: if the portal is not running or if called from the
            event loop thread (where blocking on the result would deadlock)

        """
        if self._event_loop_thread_id is None:
            raise RuntimeError("This portal is not running")
        if self._event_loop_thread_id == get_ident():
            raise RuntimeError(
                "This method cannot be called from the event loop thread"
            )

    async def sleep_until_stopped(self) -> None:
        """Sleep until :meth:`stop` is called."""
        await self._stop_event.wait()

    async def stop(self, cancel_remaining: bool = False) -> None:
        """
        Signal the portal to shut down.

        This marks the portal as no longer accepting new calls and exits from
        :meth:`sleep_until_stopped`.

        :param cancel_remaining: ``True`` to cancel all the remaining tasks, ``False``
            to let them finish before returning

        """
        self._event_loop_thread_id = None
        self._stop_event.set()
        if cancel_remaining:
            self._task_group.cancel_scope.cancel()

    async def _call_func(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        future: Future[T_Retval],
    ) -> None:
        """
        Run ``func`` in the event loop and resolve ``future`` with its outcome.

        Awaitable results are awaited inside a cancel scope that is cancelled when
        ``future`` is cancelled. Non-``Exception`` base exceptions are recorded
        on ``future`` and then re-raised.
        """

        def callback(f: Future[T_Retval]) -> None:
            # Done-callback on ``future``: when the caller cancels it, cancel the
            # task's scope. Skipped if the portal is stopped or if this already
            # runs in the event loop thread (where self.call() would raise).
            if f.cancelled() and self._event_loop_thread_id not in (
                None,
                get_ident(),
            ):
                self.call(scope.cancel)

        try:
            retval_or_awaitable = func(*args, **kwargs)
            if isawaitable(retval_or_awaitable):
                with CancelScope() as scope:
                    # The caller may have cancelled before the task got to run.
                    if future.cancelled():
                        scope.cancel()
                    else:
                        future.add_done_callback(callback)

                    retval = await retval_or_awaitable
            else:
                retval = retval_or_awaitable
        except self._cancelled_exc_class:
            # Report cancellation through the future instead of propagating it.
            future.cancel()
            future.set_running_or_notify_cancel()
        except BaseException as exc:
            if not future.cancelled():
                future.set_exception(exc)

            # Let base exceptions fall through
            if not isinstance(exc, Exception):
                raise
        else:
            if not future.cancelled():
                future.set_result(retval)
        finally:
            # Drop the reference to the cancel scope that ``callback`` closes over.
            scope = None  # type: ignore[assignment]

    def _spawn_task_from_thread(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        name: object,
        future: Future[T_Retval],
    ) -> None:
        """
        Spawn a new task using the given callable.

        Implementers must ensure that the future is resolved when the task finishes.

        :param func: a callable
        :param args: positional arguments to be passed to the callable
        :param kwargs: keyword arguments to be passed to the callable
        :param name: name of the task (will be coerced to a string if not ``None``)
        :param future: a future that will resolve to the return value of the callable,
            or the exception raised during its execution

        """
        raise NotImplementedError

    @overload
    def call(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        *args: Unpack[PosArgsT],
    ) -> T_Retval: ...

    @overload
    def call(
        self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
    ) -> T_Retval: ...

    def call(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        *args: Unpack[PosArgsT],
    ) -> T_Retval:
        """
        Call the given function in the event loop thread.

        If the callable returns a coroutine object, it is awaited on.

        :param func: any callable
        :param args: positional arguments passed to ``func``
        :return: the return value of ``func`` (or of the awaitable it returned);
            an exception raised by ``func`` is re-raised here
        :raises RuntimeError: if the portal is not running or if this method is called
            from within the event loop thread

        """
        return cast(T_Retval, self.start_task_soon(func, *args).result())

    @overload
    def start_task_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> Future[T_Retval]: ...

    @overload
    def start_task_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], T_Retval],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> Future[T_Retval]: ...

    def start_task_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> Future[T_Retval]:
        """
        Start a task in the portal's task group.

        The task will be run inside a cancel scope which can be cancelled by cancelling
        the returned future.

        :param func: the target function
        :param args: positional arguments passed to ``func``
        :param name: name of the task (will be coerced to a string if not ``None``)
        :return: a future that resolves with the return value of the callable if the
            task completes successfully, or with the exception raised in the task
        :raises RuntimeError: if the portal is not running or if this method is called
            from within the event loop thread
        :rtype: concurrent.futures.Future[T_Retval]

        .. versionadded:: 3.0

        """
        self._check_running()
        f: Future[T_Retval] = Future()
        self._spawn_task_from_thread(func, args, {}, name, f)
        return f

    def start_task(
        self,
        func: Callable[..., Awaitable[T_Retval]],
        *args: object,
        name: object = None,
    ) -> tuple[Future[T_Retval], Any]:
        """
        Start a task in the portal's task group and wait until it signals for readiness.

        This method works the same way as :meth:`.abc.TaskGroup.start`.

        :param func: the target function
        :param args: positional arguments passed to ``func``
        :param name: name of the task (will be coerced to a string if not ``None``)
        :return: a tuple of (future, task_status_value) where the ``task_status_value``
            is the value passed to ``task_status.started()`` from within the target
            function
        :raises RuntimeError: if the portal is not running, if this method is called
            from within the event loop thread, or if the task exits without calling
            ``task_status.started()``
        :rtype: tuple[concurrent.futures.Future[T_Retval], Any]

        .. versionadded:: 3.0

        """

        def task_done(future: Future[T_Retval]) -> None:
            self._settle_task_status_future(future, task_status_future)

        self._check_running()
        task_status_future: Future = Future()
        task_status = _BlockingPortalTaskStatus(task_status_future)
        f: Future = Future()
        f.add_done_callback(task_done)
        self._spawn_task_from_thread(func, args, {"task_status": task_status}, name, f)
        return f, task_status_future.result()

    @staticmethod
    def _settle_task_status_future(
        future: Future[Any], task_status_future: Future[Any]
    ) -> None:
        """
        Resolve ``task_status_future`` for a task that finished before (or without)
        calling ``task_status.started()``.

        Does nothing if readiness was already signalled. Otherwise the readiness
        future mirrors the task's cancellation or exception. A task that returned
        normally without signalling readiness gets a :exc:`RuntimeError`.

        :param future: the (finished) future of the task
        :param task_status_future: the future the ``start_task()`` caller waits on

        """
        if task_status_future.done():
            return

        if future.cancelled():
            task_status_future.cancel()
            return

        # Compare against None explicitly: an exception instance may be falsy.
        exc = future.exception()
        if exc is None:
            exc = RuntimeError("Task exited without calling task_status.started()")

        task_status_future.set_exception(exc)

    def wrap_async_context_manager(
        self, cm: AbstractAsyncContextManager[T_co]
    ) -> AbstractContextManager[T_co]:
        """
        Wrap an async context manager as a synchronous context manager via this portal.

        Spawns a task that will call both ``__aenter__()`` and ``__aexit__()``, stopping
        in the middle until the synchronous context manager exits.

        :param cm: an asynchronous context manager
        :return: a synchronous context manager

        .. versionadded:: 2.1

        """
        return _BlockingAsyncContextManager(cm, self)
397
398
@dataclass
class BlockingPortalProvider:
    """
    A manager for a blocking portal. Used as a context manager. The first thread to
    enter this context manager causes a blocking portal to be started with the specific
    parameters, and the last thread to exit causes the portal to be shut down. Thus,
    there will be exactly one blocking portal running in this context as long as at
    least one thread has entered this context manager.

    The parameters are the same as for :func:`~anyio.run`.

    :param backend: name of the backend
    :param backend_options: backend options

    .. versionadded:: 4.4
    """

    backend: str = "asyncio"
    backend_options: dict[str, Any] | None = None
    # Guards _leases, _portal and _portal_cm across threads.
    _lock: Lock = field(init=False, default_factory=Lock)
    # Number of threads currently inside the context manager.
    _leases: int = field(init=False, default=0)
    # Only set while a portal is running. Excluded from repr/eq because the
    # attribute does not exist otherwise, and reading it would raise AttributeError.
    _portal: BlockingPortal = field(init=False, repr=False, compare=False)
    _portal_cm: AbstractContextManager[BlockingPortal] | None = field(
        init=False, default=None
    )

    def __enter__(self) -> BlockingPortal:
        """Take a lease on the shared portal, starting it if this is the first one."""
        with self._lock:
            if self._portal_cm is None:
                portal_cm = start_blocking_portal(self.backend, self.backend_options)
                # Record the context manager only after the portal has started.
                # If __enter__() raises, the provider stays pristine and a later
                # __enter__() retries instead of failing on a missing ``_portal``.
                self._portal = portal_cm.__enter__()
                self._portal_cm = portal_cm

            self._leases += 1
            return self._portal

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Release a lease; the last one to leave shuts the portal down."""
        portal_cm: AbstractContextManager[BlockingPortal] | None = None
        with self._lock:
            assert self._portal_cm
            assert self._leases > 0
            self._leases -= 1
            if not self._leases:
                portal_cm = self._portal_cm
                self._portal_cm = None
                del self._portal

        # Shut down outside the lock so other threads are not blocked while the
        # portal stops. No exception info is passed on, so remaining portal tasks
        # are not cancelled.
        if portal_cm:
            portal_cm.__exit__(None, None, None)
454
455
@contextmanager
def start_blocking_portal(
    backend: str = "asyncio",
    backend_options: dict[str, Any] | None = None,
    *,
    name: str | None = None,
) -> Generator[BlockingPortal, Any, None]:
    """
    Start a new event loop in a new thread and run a blocking portal in its main task.

    The parameters are the same as for :func:`~anyio.run`.

    :param backend: name of the backend
    :param backend_options: backend options
    :param name: name of the thread
    :return: a context manager that yields a blocking portal

    .. versionchanged:: 3.0
        Usage as a context manager is now required.

    """

    async def run_portal() -> None:
        # Main task of the new event loop: create the portal, publish it to the
        # waiting caller through ``future`` and stay alive until it is stopped.
        async with BlockingPortal() as portal_:
            if name is None:
                current_thread().name = f"{backend}-portal-{id(portal_):x}"

            future.set_result(portal_)
            await portal_.sleep_until_stopped()

    def run_blocking_portal() -> None:
        # Thread target. set_running_or_notify_cancel() returns False if the
        # future was already cancelled, in which case no event loop is started.
        if future.set_running_or_notify_cancel():
            try:
                _eventloop.run(
                    run_portal, backend=backend, backend_options=backend_options
                )
            except BaseException as exc:
                # Forward startup failures to the caller blocked on future.result();
                # failures after the portal was published are not reported here.
                if not future.done():
                    future.set_exception(exc)

    future: Future[BlockingPortal] = Future()
    thread = Thread(target=run_blocking_portal, daemon=True, name=name)
    thread.start()
    try:
        cancel_remaining_tasks = False
        portal = future.result()
        try:
            yield portal
        except BaseException:
            # The with-block failed: cancel outstanding portal tasks on shutdown.
            cancel_remaining_tasks = True
            raise
        finally:
            try:
                portal.call(portal.stop, cancel_remaining_tasks)
            except RuntimeError:
                # portal.call() raises RuntimeError when the portal is no longer
                # running; there is nothing left to stop then.
                pass
    finally:
        thread.join()
514
515
def check_cancelled() -> None:
    """
    Raise the backend's cancellation exception if the host task was cancelled.

    The host task is the one running the current worker thread. If that task's
    current cancel scope has been cancelled, the backend-specific cancellation
    exception is raised here.

    :raises RuntimeError: if the current thread was not spawned by
        :func:`.to_thread.run_sync`

    """
    try:
        backend: AsyncBackend = threadlocals.current_async_backend
    except AttributeError:
        raise RuntimeError(
            "This function can only be run from an AnyIO worker thread"
        ) from None
    else:
        backend.check_cancelled()