1from __future__ import annotations
2
3import sys
4from collections.abc import Awaitable, Callable, Generator
5from concurrent.futures import Future
6from contextlib import (
7 AbstractAsyncContextManager,
8 AbstractContextManager,
9 contextmanager,
10)
11from dataclasses import dataclass, field
12from inspect import isawaitable
13from threading import Lock, Thread, current_thread, get_ident
14from types import TracebackType
15from typing import (
16 Any,
17 Generic,
18 TypeVar,
19 cast,
20 overload,
21)
22
23from ._core._eventloop import (
24 get_async_backend,
25 get_cancelled_exc_class,
26 threadlocals,
27)
28from ._core._eventloop import run as run_eventloop
29from ._core._exceptions import NoEventLoopError
30from ._core._synchronization import Event
31from ._core._tasks import CancelScope, create_task_group
32from .abc._tasks import TaskStatus
33from .lowlevel import EventLoopToken
34
35if sys.version_info >= (3, 11):
36 from typing import TypeVarTuple, Unpack
37else:
38 from typing_extensions import TypeVarTuple, Unpack
39
# Return type of the callable that is run in (or called from) the event loop thread
T_Retval = TypeVar("T_Retval")
# Value type produced by a wrapped (async) context manager
T_co = TypeVar("T_co", covariant=True)
# Positional argument types forwarded to the target callable
PosArgsT = TypeVarTuple("PosArgsT")
43
44
45def _token_or_error(token: EventLoopToken | None) -> EventLoopToken:
46 if token is not None:
47 return token
48
49 try:
50 return threadlocals.current_token
51 except AttributeError:
52 raise NoEventLoopError(
53 "Not running inside an AnyIO worker thread, and no event loop token was "
54 "provided"
55 ) from None
56
57
def run(
    func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
    *args: Unpack[PosArgsT],
    token: EventLoopToken | None = None,
) -> T_Retval:
    """
    Call a coroutine function from a worker thread.

    :param func: a coroutine function
    :param args: positional arguments for the callable
    :param token: an event loop token to use to get back to the event loop thread
        (required if calling this function from outside an AnyIO worker thread)
    :return: the return value of the coroutine function
    :raises NoEventLoopError: if no token was provided and called from outside an
        AnyIO worker thread
    :raises RunFinishedError: if the event loop tied to ``token`` is no longer running

    .. versionchanged:: 4.11.0
        Added the ``token`` parameter.

    """
    # Only forward the backend-native token when the caller supplied one explicitly;
    # otherwise ``None`` is passed (presumably so the backend uses the worker
    # thread's own thread-local state).
    explicit_token = token is not None
    token = _token_or_error(token)
    return token.backend_class.run_async_from_thread(
        func, args, token=token.native_token if explicit_token else None
    )
84
85
def run_sync(
    func: Callable[[Unpack[PosArgsT]], T_Retval],
    *args: Unpack[PosArgsT],
    token: EventLoopToken | None = None,
) -> T_Retval:
    """
    Call a function in the event loop thread from a worker thread.

    :param func: a callable
    :param args: positional arguments for the callable
    :param token: an event loop token to use to get back to the event loop thread
        (required if calling this function from outside an AnyIO worker thread)
    :return: the return value of the callable
    :raises NoEventLoopError: if no token was provided and called from outside an
        AnyIO worker thread
    :raises RunFinishedError: if the event loop tied to ``token`` is no longer running

    .. versionchanged:: 4.11.0
        Added the ``token`` parameter.

    """
    # Only forward the backend-native token when the caller supplied one explicitly;
    # otherwise ``None`` is passed (presumably so the backend uses the worker
    # thread's own thread-local state).
    explicit_token = token is not None
    token = _token_or_error(token)
    return token.backend_class.run_sync_from_thread(
        func, args, token=token.native_token if explicit_token else None
    )
112
113
class _BlockingAsyncContextManager(Generic[T_co], AbstractContextManager):
    """
    Synchronous context manager that drives an asynchronous one through a portal.

    A single portal task (:meth:`run_async_cm`) calls ``__aenter__()``, then waits
    until the synchronous :meth:`__exit__` signals it, and finally calls
    ``__aexit__()``. The calling thread blocks on futures at both ends.
    """

    # Resolved by run_async_cm() with __aenter__()'s return value (or its exception)
    _enter_future: Future[T_co]
    # Future of the run_async_cm() portal task; resolves with __aexit__()'s result
    _exit_future: Future[bool | None]
    # Set in the event loop by __exit__() to let run_async_cm() proceed to __aexit__()
    _exit_event: Event
    # Exception info handed to __aexit__(); remains (None, None, None) if __aexit__()
    # runs before __exit__() has recorded anything (e.g. on cancellation)
    _exit_exc_info: tuple[
        type[BaseException] | None, BaseException | None, TracebackType | None
    ] = (None, None, None)

    def __init__(
        self, async_cm: AbstractAsyncContextManager[T_co], portal: BlockingPortal
    ):
        """
        :param async_cm: the asynchronous context manager to wrap
        :param portal: the portal whose event loop runs the async context manager
        """
        self._async_cm = async_cm
        self._portal = portal

    async def run_async_cm(self) -> bool | None:
        """
        Enter the wrapped async context manager, wait for the sync exit, then exit it.

        Runs as a task in the portal's event loop (spawned by :meth:`__enter__`).

        :return: the return value of the wrapped ``__aexit__()``
        """
        try:
            self._exit_event = Event()
            value = await self._async_cm.__aenter__()
        except BaseException as exc:
            # Hand the failure to the thread blocked in __enter__(), then let the
            # task itself fail too
            self._enter_future.set_exception(exc)
            raise
        else:
            self._enter_future.set_result(value)

        try:
            # Wait for the sync context manager to exit.
            # This next statement can raise `get_cancelled_exc_class()` if
            # something went wrong in a task group in this async context
            # manager.
            await self._exit_event.wait()
        finally:
            # In case of cancellation, it could be that we end up here before
            # `_BlockingAsyncContextManager.__exit__` is called, and an
            # `_exit_exc_info` has been set.
            result = await self._async_cm.__aexit__(*self._exit_exc_info)

        return result

    def __enter__(self) -> T_co:
        """
        Start :meth:`run_async_cm` in the portal and block until ``__aenter__()`` is done.

        :return: the value returned by the wrapped ``__aenter__()``; if that raised,
            the exception is re-raised here via the future
        """
        self._enter_future = Future()
        self._exit_future = self._portal.start_task_soon(self.run_async_cm)
        return self._enter_future.result()

    def __exit__(
        self,
        # Double-underscore parameter names are the pre-PEP 570 convention for
        # positional-only parameters (they are name-mangled inside this class)
        __exc_type: type[BaseException] | None,
        __exc_value: BaseException | None,
        __traceback: TracebackType | None,
    ) -> bool | None:
        """
        Record the exception info, release the portal task, and wait for ``__aexit__()``.

        :return: the return value of the wrapped ``__aexit__()``
        """
        self._exit_exc_info = __exc_type, __exc_value, __traceback
        # Event.set() must run in the event loop thread, hence the portal call
        self._portal.call(self._exit_event.set)
        return self._exit_future.result()
166
167
class _BlockingPortalTaskStatus(TaskStatus):
    """
    Task status handed to tasks launched through :meth:`BlockingPortal.start_task`.

    Calling :meth:`started` resolves the wrapped future, which unblocks the thread
    waiting for it in :meth:`BlockingPortal.start_task`.
    """

    def __init__(self, future: Future):
        self._future = future

    def started(self, value: object = None) -> None:
        """Resolve the wrapped future with ``value``."""
        readiness = self._future
        readiness.set_result(value)
174
175
class BlockingPortal:
    """
    An object that lets external threads run code in an asynchronous event loop.

    Constructing ``BlockingPortal()`` returns the current async backend's portal
    implementation (see :meth:`__new__`). The portal is used as an async context
    manager: entering it enters the portal's task group, and exiting it stops the
    portal (without cancelling remaining tasks) before exiting the task group. Other
    threads submit work through :meth:`call`, :meth:`start_task_soon` and
    :meth:`start_task`.
    """

    def __new__(cls) -> BlockingPortal:
        # Delegate construction to the current async backend, which presumably
        # returns a subclass implementing _spawn_task_from_thread()
        return get_async_backend().create_blocking_portal()

    def __init__(self) -> None:
        # The thread creating the portal is taken to be the event loop thread.
        # stop() resets this to None so that _check_running() rejects new calls.
        self._event_loop_thread_id: int | None = get_ident()
        # Set by stop(); awaited by sleep_until_stopped()
        self._stop_event = Event()
        # The portal's task group, which tasks started through the portal run in
        self._task_group = create_task_group()
        # Backend-specific cancellation exception, recognized by _call_func()
        self._cancelled_exc_class = get_cancelled_exc_class()

    async def __aenter__(self) -> BlockingPortal:
        await self._task_group.__aenter__()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool:
        # Stop accepting new calls (without cancelling remaining tasks), then exit
        # the task group
        await self.stop()
        return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)

    def _check_running(self) -> None:
        """
        Ensure the portal can accept a call from the current thread.

        :raises RuntimeError: if the portal has been stopped, or if called from the
            event loop thread itself (blocking calls such as :meth:`call` wait on a
            future that only the event loop can resolve)

        """
        if self._event_loop_thread_id is None:
            raise RuntimeError("This portal is not running")
        if self._event_loop_thread_id == get_ident():
            raise RuntimeError(
                "This method cannot be called from the event loop thread"
            )

    async def sleep_until_stopped(self) -> None:
        """Sleep until :meth:`stop` is called."""
        await self._stop_event.wait()

    async def stop(self, cancel_remaining: bool = False) -> None:
        """
        Signal the portal to shut down.

        This marks the portal as no longer accepting new calls and exits from
        :meth:`sleep_until_stopped`.

        :param cancel_remaining: ``True`` to cancel all the remaining tasks, ``False``
            to let them finish before returning

        """
        # Clearing the thread id makes _check_running() reject further calls
        self._event_loop_thread_id = None
        self._stop_event.set()
        if cancel_remaining:
            self._task_group.cancel_scope.cancel("the blocking portal is shutting down")

    async def _call_func(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        future: Future[T_Retval],
    ) -> None:
        """
        Run ``func`` in the event loop and report its outcome to ``future``.

        If ``func`` returns an awaitable, it is awaited inside a cancel scope that is
        cancelled when ``future`` is (or becomes) cancelled. The return value or the
        raised exception is stored in ``future`` unless the future was cancelled. If
        the backend's cancellation exception is raised, ``future`` is cancelled and
        the exception is not re-raised. Exceptions that are not :class:`Exception`
        subclasses are re-raised after being recorded.

        :param func: the callable to run
        :param args: positional arguments for ``func``
        :param kwargs: keyword arguments for ``func``
        :param future: the future to resolve with the outcome

        """

        def callback(f: Future[T_Retval]) -> None:
            # Done-callback on ``future``: if it was cancelled from a thread other
            # than the event loop thread while the portal is running, cancel the
            # scope (bound below) inside the event loop via call().
            # NOTE(review): concurrent.futures runs a done-callback immediately in
            # the registering thread if the future is already done. A cancellation
            # landing between the cancelled() check and add_done_callback() below
            # would therefore run this in the event loop thread, where the condition
            # is false, and the scope would not be cancelled — confirm if relevant.
            if f.cancelled() and self._event_loop_thread_id not in (
                None,
                get_ident(),
            ):
                self.call(scope.cancel, "the future was cancelled")

        try:
            retval_or_awaitable = func(*args, **kwargs)
            if isawaitable(retval_or_awaitable):
                with CancelScope() as scope:
                    # Honor a cancellation requested before the task got to run;
                    # otherwise let a later cancellation reach the scope
                    if future.cancelled():
                        scope.cancel("the future was cancelled")
                    else:
                        future.add_done_callback(callback)

                    retval = await retval_or_awaitable
            else:
                retval = retval_or_awaitable
        except self._cancelled_exc_class:
            # Mark the future cancelled, then move it to the "notified" state so
            # concurrent.futures wait()/as_completed() waiters wake up
            future.cancel()
            future.set_running_or_notify_cancel()
        except BaseException as exc:
            if not future.cancelled():
                future.set_exception(exc)

            # Let base exceptions fall through
            if not isinstance(exc, Exception):
                raise
        else:
            if not future.cancelled():
                future.set_result(retval)
        finally:
            # Drop the cancel scope reference captured by ``callback`` (presumably
            # to avoid keeping the scope alive through the future's callback list)
            scope = None  # type: ignore[assignment]

    def _spawn_task_from_thread(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        name: object,
        future: Future[T_Retval],
    ) -> None:
        """
        Spawn a new task using the given callable.

        Implementers must ensure that the future is resolved when the task finishes.
        This base implementation is abstract; backend-specific portal classes are
        expected to override it.

        :param func: a callable
        :param args: positional arguments to be passed to the callable
        :param kwargs: keyword arguments to be passed to the callable
        :param name: name of the task (will be coerced to a string if not ``None``)
        :param future: a future that will resolve to the return value of the callable,
            or the exception raised during its execution

        """
        raise NotImplementedError

    @overload
    def call(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        *args: Unpack[PosArgsT],
    ) -> T_Retval: ...

    @overload
    def call(
        self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
    ) -> T_Retval: ...

    def call(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        *args: Unpack[PosArgsT],
    ) -> T_Retval:
        """
        Call the given function in the event loop thread.

        If the callable returns a coroutine object, it is awaited on. This method
        blocks the calling thread until the result is available.

        :param func: any callable
        :param args: positional arguments passed to ``func``
        :return: the return value of ``func`` (or of the awaitable it returned)
        :raises RuntimeError: if the portal is not running or if this method is called
            from within the event loop thread

        """
        return cast(T_Retval, self.start_task_soon(func, *args).result())

    @overload
    def start_task_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> Future[T_Retval]: ...

    @overload
    def start_task_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], T_Retval],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> Future[T_Retval]: ...

    def start_task_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> Future[T_Retval]:
        """
        Start a task in the portal's task group.

        The task will be run inside a cancel scope which can be cancelled by cancelling
        the returned future.

        :param func: the target function
        :param args: positional arguments passed to ``func``
        :param name: name of the task (will be coerced to a string if not ``None``)
        :return: a future that resolves with the return value of the callable if the
            task completes successfully, or with the exception raised in the task
        :raises RuntimeError: if the portal is not running or if this method is called
            from within the event loop thread
        :rtype: concurrent.futures.Future[T_Retval]

        .. versionadded:: 3.0

        """
        self._check_running()
        f: Future[T_Retval] = Future()
        self._spawn_task_from_thread(func, args, {}, name, f)
        return f

    def start_task(
        self,
        func: Callable[..., Awaitable[T_Retval]],
        *args: object,
        name: object = None,
    ) -> tuple[Future[T_Retval], Any]:
        """
        Start a task in the portal's task group and wait until it signals for readiness.

        This method works the same way as :meth:`.abc.TaskGroup.start`.

        :param func: the target function
        :param args: positional arguments passed to ``func``
        :param name: name of the task (will be coerced to a string if not ``None``)
        :return: a tuple of (future, task_status_value) where the ``task_status_value``
            is the value passed to ``task_status.started()`` from within the target
            function
        :raises RuntimeError: if the portal is not running, if this method is called
            from within the event loop thread, or if the task exits without calling
            ``task_status.started()``; if the task raises before signalling
            readiness, that exception is re-raised here
        :rtype: tuple[concurrent.futures.Future[T_Retval], Any]

        .. versionadded:: 3.0

        """

        def task_done(future: Future[T_Retval]) -> None:
            # If the task finishes before signalling readiness, propagate its
            # outcome so the thread waiting on task_status_future is released
            if not task_status_future.done():
                if future.cancelled():
                    task_status_future.cancel()
                elif future.exception():
                    task_status_future.set_exception(future.exception())
                else:
                    exc = RuntimeError(
                        "Task exited without calling task_status.started()"
                    )
                    task_status_future.set_exception(exc)

        self._check_running()
        task_status_future: Future = Future()
        task_status = _BlockingPortalTaskStatus(task_status_future)
        f: Future = Future()
        f.add_done_callback(task_done)
        self._spawn_task_from_thread(func, args, {"task_status": task_status}, name, f)
        # Block until the task calls task_status.started() (or fails beforehand)
        return f, task_status_future.result()

    def wrap_async_context_manager(
        self, cm: AbstractAsyncContextManager[T_co]
    ) -> AbstractContextManager[T_co]:
        """
        Wrap an async context manager as a synchronous context manager via this portal.

        Spawns a task that will call both ``__aenter__()`` and ``__aexit__()``, stopping
        in the middle until the synchronous context manager exits.

        :param cm: an asynchronous context manager
        :return: a synchronous context manager

        .. versionadded:: 2.1

        """
        return _BlockingAsyncContextManager(cm, self)
427
428
@dataclass
class BlockingPortalProvider:
    """
    A manager for a blocking portal. Used as a context manager. The first thread to
    enter this context manager causes a blocking portal to be started with the specific
    parameters, and the last thread to exit causes the portal to be shut down. Thus,
    there will be exactly one blocking portal running in this context as long as at
    least one thread has entered this context manager.

    The parameters are the same as for :func:`~anyio.run`.

    :param backend: name of the backend
    :param backend_options: backend options

    .. versionadded:: 4.4
    """

    backend: str = "asyncio"
    backend_options: dict[str, Any] | None = None
    # Guards the lease counter and the shared portal state below
    _lock: Lock = field(init=False, default_factory=Lock)
    # Number of threads currently inside this context manager
    _leases: int = field(init=False, default=0)
    # The shared portal; only present while _portal_cm is not None
    _portal: BlockingPortal = field(init=False)
    # Context manager from start_blocking_portal() while the portal is running
    _portal_cm: AbstractContextManager[BlockingPortal] | None = field(
        init=False, default=None
    )

    def __enter__(self) -> BlockingPortal:
        """
        Take a lease on the shared portal, starting it if no lease is held.

        :return: the shared blocking portal

        """
        with self._lock:
            if self._portal_cm is None:
                # Publish the context manager only after the portal has started.
                # If startup raises, leaving a half-initialized _portal_cm behind
                # would make every later entry skip startup and then fail on the
                # missing _portal attribute.
                portal_cm = start_blocking_portal(self.backend, self.backend_options)
                self._portal = portal_cm.__enter__()
                self._portal_cm = portal_cm

            self._leases += 1
            return self._portal

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """
        Release a lease; the last thread to leave shuts the portal down.

        The shutdown itself happens after the lock is released, so other threads
        are not blocked on the lock while the portal's thread is being stopped.

        """
        portal_cm: AbstractContextManager[BlockingPortal] | None = None
        with self._lock:
            assert self._portal_cm
            assert self._leases > 0
            self._leases -= 1
            if not self._leases:
                portal_cm = self._portal_cm
                self._portal_cm = None
                del self._portal

        if portal_cm is not None:
            portal_cm.__exit__(None, None, None)
484
485
@contextmanager
def start_blocking_portal(
    backend: str = "asyncio",
    backend_options: dict[str, Any] | None = None,
    *,
    name: str | None = None,
) -> Generator[BlockingPortal, Any, None]:
    """
    Start a new event loop in a new thread and run a blocking portal in its main task.

    The parameters are the same as for :func:`~anyio.run`.

    :param backend: name of the backend
    :param backend_options: backend options
    :param name: name of the thread
    :return: a context manager that yields a blocking portal

    .. versionchanged:: 3.0
        Usage as a context manager is now required.

    """
    # Resolved with the running portal, or with the error that prevented its startup
    portal_future: Future[BlockingPortal] = Future()

    async def serve_portal() -> None:
        # Main task of the new event loop: publish the portal, then idle until it
        # is told to stop
        async with BlockingPortal() as running_portal:
            if name is None:
                current_thread().name = f"{backend}-portal-{id(running_portal):x}"

            portal_future.set_result(running_portal)
            await running_portal.sleep_until_stopped()

    def thread_main() -> None:
        # Nothing to do if the future was cancelled before the thread got going
        if not portal_future.set_running_or_notify_cancel():
            return

        try:
            run_eventloop(
                serve_portal, backend=backend, backend_options=backend_options
            )
        except BaseException as exc:
            # Only report errors that happen before the portal was handed out
            if not portal_future.done():
                portal_future.set_exception(exc)

    loop_thread = Thread(target=thread_main, daemon=True, name=name)
    loop_thread.start()
    try:
        cancel_remaining_tasks = False
        portal = portal_future.result()
        try:
            yield portal
        except BaseException:
            # The with-block failed: cancel whatever is still running in the portal
            cancel_remaining_tasks = True
            raise
        finally:
            try:
                portal.call(portal.stop, cancel_remaining_tasks)
            except RuntimeError:
                # The portal has already stopped accepting calls
                pass
    finally:
        loop_thread.join()
544
545
def check_cancelled() -> None:
    """
    Check if the cancel scope of the host task running the current worker thread has
    been cancelled.

    If the host task's current cancel scope has indeed been cancelled, the
    backend-specific cancellation exception will be raised.

    :raises NoEventLoopError: if the current thread was not spawned by
        :func:`.to_thread.run_sync`

    """
    # Worker threads carry their event loop token in thread-local state; its absence
    # means this is not an AnyIO worker thread
    try:
        token: EventLoopToken = threadlocals.current_token
    except AttributeError:
        raise NoEventLoopError(
            "This function can only be called inside an AnyIO worker thread"
        ) from None

    # Delegate the actual cancellation check to the backend owning the token
    token.backend_class.check_cancelled()