1# util/concurrency.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9"""asyncio-related concurrency functions."""
10
11from __future__ import annotations
12
13import asyncio
14import sys
15from typing import Any
16from typing import Awaitable
17from typing import Callable
18from typing import Coroutine
19from typing import Literal
20from typing import NoReturn
21from typing import TYPE_CHECKING
22from typing import TypeGuard
23from typing import TypeVar
24from typing import Union
25
26from .compat import py311
27from .langhelpers import memoized_property
28from .typing import Self
29from .. import exc
30
31_T = TypeVar("_T")
32
33
def is_exit_exception(e: BaseException) -> bool:
    """Return True if ``e`` should be treated as an "exit" exception.

    This is the case for anything that is not an ``Exception`` subclass
    (e.g. ``KeyboardInterrupt``, ``SystemExit``), and additionally for
    ``asyncio.TimeoutError`` and ``asyncio.CancelledError``.

    :param e: the exception instance to classify.
    """
    # asyncio.CancelledError is already a BaseException, so it would be
    # treated as an exit exception by the final check as well; it is
    # listed here explicitly alongside asyncio.TimeoutError.
    if isinstance(e, (asyncio.TimeoutError, asyncio.CancelledError)):
        return True
    return not isinstance(e, Exception)
40
41
# Text of the ImportError that is raised by this module when the
# ``greenlet`` library cannot be imported.
_ERROR_MESSAGE = (
    "The SQLAlchemy asyncio module requires that the Python 'greenlet' "
    "library is installed. In order to ensure this dependency is "
    "available, use the 'sqlalchemy[asyncio]' install target: "
    "'pip install sqlalchemy[asyncio]'"
)
48
49
def _not_implemented(*arg: Any, **kw: Any) -> NoReturn:
    """Placeholder callable used in place of the greenlet names.

    Accepts any positional and keyword arguments, ignores them, and
    always raises ``ImportError`` carrying ``_ERROR_MESSAGE``.
    """
    missing_dependency = ImportError(_ERROR_MESSAGE)
    raise missing_dependency
52
53
class _concurrency_shim_cls:
    """Late import shim for greenlet

    ``greenlet`` is not imported when this module is imported. The import
    happens inside ``_initialize()``, which is run either explicitly or
    by ``__getattr__`` the first time one of the names listed in
    ``__slots__`` is read from an instance while it is still unset.
    """

    # the only attributes an instance can carry; these are also the names
    # for which __getattr__ triggers initialization
    __slots__ = (
        "_has_greenlet",
        "greenlet",
        "_AsyncIoGreenlet",
        "getcurrent",
    )

    def _initialize(self, *, raise_: bool = True) -> None:
        """Import greenlet and initialize the class

        Returns immediately if a module-level ``greenlet`` name already
        exists; both the success path and the failure path below create
        that name, so the import is attempted only once.

        :param raise_: when True (the default), a failed import of
         ``greenlet`` is re-raised as an ``ImportError`` carrying
         ``_ERROR_MESSAGE``. When False the failure is only recorded:
         ``_has_greenlet`` is set to False and the attributes are filled
         with ``_not_implemented``.
        """
        if "greenlet" in globals():
            return

        if not TYPE_CHECKING:
            # the names imported / defined below are stored at module
            # level; this is what the globals() check above looks for
            global getcurrent, greenlet, _AsyncIoGreenlet
            global _has_gr_context

        try:
            from greenlet import getcurrent
            from greenlet import greenlet
        except ImportError as e:
            if not TYPE_CHECKING:
                # set greenlet in the global scope to prevent re-init
                greenlet = None
            self._has_greenlet = False
            self._initialize_no_greenlet()
            if raise_:
                raise ImportError(_ERROR_MESSAGE) from e
        else:
            self._has_greenlet = True
            # If greenlet.gr_context is present in current version of greenlet,
            # it will be set with the current context on creation.
            # Refs: https://github.com/python-greenlet/greenlet/pull/198
            _has_gr_context = hasattr(getcurrent(), "gr_context")

            # implementation based on snaury gist at
            # https://gist.github.com/snaury/202bf4f22c41ca34e56297bae5f33fef
            # Issue for context: https://github.com/python-greenlet/greenlet/issues/173  # noqa: E501

            class _AsyncIoGreenlet(greenlet):
                dead: bool

                # marker attribute; in_greenlet() and await_() look for it
                # on the current greenlet
                __sqlalchemy_greenlet_provider__ = True

                def __init__(self, fn: Callable[..., Any], driver: greenlet):
                    # ``driver`` becomes the parent of the new greenlet
                    greenlet.__init__(self, fn, driver)
                    if _has_gr_context:
                        # share the driver's context with the new greenlet
                        self.gr_context = driver.gr_context

            # publish the real implementations on the instance so that
            # later attribute reads no longer go through __getattr__
            self.greenlet = greenlet
            self.getcurrent = getcurrent
            self._AsyncIoGreenlet = _AsyncIoGreenlet

    def _initialize_no_greenlet(self) -> None:
        """Fill the greenlet attributes with ``_not_implemented``.

        Any later call of ``getcurrent``, ``greenlet`` or
        ``_AsyncIoGreenlet`` through this shim then raises ``ImportError``.
        """
        self.getcurrent = _not_implemented
        self.greenlet = _not_implemented  # type: ignore[assignment]
        self._AsyncIoGreenlet = _not_implemented  # type: ignore[assignment]

    def __getattr__(self, key: str) -> Any:
        """Initialize the shim on first read of one of the slot names.

        Python calls ``__getattr__`` only when normal attribute lookup
        fails, so this runs for a slot that has not been assigned yet.

        :raises AttributeError: if ``key`` is not listed in ``__slots__``.
        :raises ImportError: propagated from ``_initialize()`` when
         greenlet cannot be imported.
        """
        if key in self.__slots__:
            self._initialize()
            return getattr(self, key)
        else:
            raise AttributeError(key)
120
121
# module-level shim instance through which the functions below reach
# ``getcurrent`` and ``_AsyncIoGreenlet``
_concurrency_shim = _concurrency_shim_cls()
123
if TYPE_CHECKING:
    _T_co = TypeVar("_T_co", covariant=True)

    # typing-only declaration: a true result narrows the awaitable to a
    # Coroutine for the type checker
    def iscoroutine(
        awaitable: Awaitable[_T_co],
    ) -> TypeGuard[Coroutine[Any, Any, _T_co]]: ...

else:
    # at runtime this is asyncio.iscoroutine itself
    iscoroutine = asyncio.iscoroutine
133
134
def _safe_cancel_awaitable(awaitable: Awaitable[Any]) -> None:
    """Close ``awaitable`` if it is a coroutine object.

    Awaitables for which ``iscoroutine()`` is false are left untouched.

    :param awaitable: the awaitable that is not going to be awaited.
    """
    # see https://docs.python.org/3/reference/datamodel.html#coroutine.close
    if not iscoroutine(awaitable):
        return
    awaitable.close()
140
141
def in_greenlet() -> bool:
    """Tell whether the current greenlet carries the SQLAlchemy marker.

    Looks up ``__sqlalchemy_greenlet_provider__`` on the greenlet
    returned by ``getcurrent()``; the result is ``False`` when the
    attribute is absent.
    """
    return getattr(
        _concurrency_shim.getcurrent(),
        "__sqlalchemy_greenlet_provider__",
        False,
    )
145
146
def await_(awaitable: Awaitable[_T]) -> _T:
    """Wait for an awaitable from inside a sync function.

    The sync function must be running inside a :func:`greenlet_spawn`
    context. :func:`await_` calls cannot be nested.

    :param awaitable: The coroutine to call.
    :return: the value the parent greenlet passes back when it switches
     to this greenlet again.
    :raises MissingGreenlet: if the current greenlet does not carry the
     ``__sqlalchemy_greenlet_provider__`` marker. The awaitable is
     closed first if it is a coroutine.
    """
    # this runs in the context greenlet, while ``fn`` is executing
    this_greenlet = _concurrency_shim.getcurrent()
    is_spawned = getattr(
        this_greenlet, "__sqlalchemy_greenlet_provider__", False
    )

    if is_spawned:
        # hand control back to the driver greenlet together with the
        # awaitable to run. When the awaitable is done the driver
        # switches back here with its result, which is returned to the
        # caller (or it throws the error into this greenlet).
        driver = this_greenlet.parent
        assert driver
        return driver.switch(awaitable)  # type: ignore[no-any-return]

    _safe_cancel_awaitable(awaitable)

    raise exc.MissingGreenlet(
        "greenlet_spawn has not been called; can't call await_() "
        "here. Was IO attempted in an unexpected place?"
    )
172
173
# ``await_only`` is the old name of ``await_``, kept as an alias.
# deprecated on 2.2
await_only = await_
175
176
async def greenlet_spawn(
    fn: Callable[..., _T],
    *args: Any,
    _require_await: bool = False,
    **kwargs: Any,
) -> _T:
    """Run the sync callable ``fn`` inside a new greenlet.

    While running, ``fn`` may call :func:`await_` to wait for async
    functions; each awaitable handed over that way is awaited here and
    its result (or exception) is passed back into the greenlet.

    :param fn: The sync callable to call.
    :param \\*args: Positional arguments to pass to the ``fn`` callable.
    :param \\*\\*kwargs: Keyword arguments to pass to the ``fn`` callable.
    :param _require_await: when True, raise if ``fn`` completed without
     ever handing an awaitable back to this function.
    :return: the value returned by ``fn``.
    :raises AwaitRequired: if ``_require_await`` is True and no switch
     back from the greenlet occurred.
    """

    outcome: Any
    spawned = _concurrency_shim._AsyncIoGreenlet(
        fn, _concurrency_shim.getcurrent()
    )
    did_await = False

    # run fn synchronously inside the spawned greenlet. If await_
    # interrupts it, the greenlet is not dead yet and ``outcome`` is the
    # awaitable to wait for. Once the greenlet is dead, fn has returned
    # and ``outcome`` is its return value.
    outcome = spawned.switch(*args, **kwargs)
    while not spawned.dead:
        did_await = True
        try:
            # wait for the awaitable received from await_ ...
            awaited_value = await outcome
        except BaseException:
            # ... raising any failure inside the spawned greenlet, so
            # that fn can continue with its expected error flow
            outcome = spawned.throw(*sys.exc_info())
        else:
            # ... or sending the result back to it
            outcome = spawned.switch(awaited_value)

    if not _require_await or did_await:
        return outcome  # type: ignore[no-any-return]

    raise exc.AwaitRequired(
        "The current operation required an async execution but none was "
        "detected. This will usually happen when using a non compatible "
        "DBAPI driver. Please ensure that an async DBAPI is used."
    )
225
226
class AsyncAdaptedLock:
    """Sync-style context manager around an ``asyncio.Lock``.

    Entering the block acquires the lock by passing the acquire
    coroutine to :func:`await_`; leaving the block releases the lock.
    """

    @memoized_property
    def mutex(self) -> asyncio.Lock:
        # no await takes place while the lock is being created, so
        # there is no concurrency and coroutines should not be able to
        # race each other here
        new_lock = asyncio.Lock()
        return new_lock

    def __enter__(self) -> bool:
        # the mutex attribute is read before await_ is invoked, so
        # waiting to acquire only happens after the first calling
        # coroutine has created the mutex.
        acquiring = self.mutex.acquire()
        return await_(acquiring)

    def __exit__(self, *arg: Any, **kw: Any) -> None:
        self.mutex.release()
241
242
if TYPE_CHECKING or not py311:

    class _Runner:
        """Runner implementation for test only

        Owns a single event loop that is created on first use and
        closed by ``close()`` or on leaving the ``with`` block.
        """

        # None: loop not created yet; False: runner has been closed;
        # otherwise the event loop in use
        _loop: Union[None, asyncio.AbstractEventLoop, Literal[False]]

        def __init__(self) -> None:
            self._loop = None

        def __enter__(self) -> Self:
            self._lazy_init()
            return self

        def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
            self.close()

        def close(self) -> None:
            """Shut down async generators and close the loop, if any."""
            loop = self._loop
            if not loop:
                # never initialized, or already closed
                return
            try:
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                loop.close()
                self._loop = False

        def get_loop(self) -> asyncio.AbstractEventLoop:
            """Return embedded event loop."""
            self._lazy_init()
            assert self._loop
            return self._loop

        def run(self, coro: Coroutine[Any, Any, _T]) -> _T:
            """Run ``coro`` on the embedded loop until it completes."""
            self._lazy_init()
            assert self._loop
            return self._loop.run_until_complete(coro)

        def _lazy_init(self) -> None:
            """Create the loop on first use.

            :raises RuntimeError: if the runner has been closed.
            """
            if self._loop is False:
                raise RuntimeError("Runner is closed")
            if self._loop is None:
                self._loop = asyncio.new_event_loop()

else:
    _Runner = asyncio.Runner
288
289
class _AsyncUtil:
    """Asyncio util for test suite/ util only"""

    def __init__(self) -> None:
        # the runner is lazy, so it can be created here
        self.runner = _Runner()

    def run(
        self,
        fn: Callable[..., Coroutine[Any, Any, _T]],
        *args: Any,
        **kwargs: Any,
    ) -> _T:
        """Run coroutine on the loop"""
        coroutine = fn(*args, **kwargs)
        return self.runner.run(coroutine)

    def run_in_greenlet(
        self, fn: Callable[..., _T], *args: Any, **kwargs: Any
    ) -> _T:
        """Run sync function in greenlet. Support nested calls

        When greenlet is not available ``fn`` is simply called directly.
        """
        _concurrency_shim._initialize(raise_=False)

        if not _concurrency_shim._has_greenlet:
            return fn(*args, **kwargs)

        if self.runner.get_loop().is_running():
            # allow for a wrapped test function to call another
            assert in_greenlet()
            return fn(*args, **kwargs)

        return self.runner.run(greenlet_spawn(fn, *args, **kwargs))

    def close(self) -> None:
        """Close the underlying runner."""
        self.runner.close()