1# util/concurrency.py
2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9from __future__ import annotations
10
11import asyncio
12import sys
13from typing import Any
14from typing import Awaitable
15from typing import Callable
16from typing import Coroutine
17from typing import NoReturn
18from typing import TYPE_CHECKING
19from typing import TypeVar
20from typing import Union
21
22from .compat import py311
23from .langhelpers import memoized_property
24from .typing import Literal
25from .typing import Self
26from .typing import TypeGuard
27from .. import exc
28
29_T = TypeVar("_T")
30
31
def is_exit_exception(e: BaseException) -> bool:
    """Return True if ``e`` should be treated as an "exit" exception.

    Exit exceptions are everything outside the :class:`Exception`
    hierarchy (``KeyboardInterrupt``, ``SystemExit``, ``GeneratorExit``,
    ...) plus asyncio timeouts and cancellations.
    """
    if not isinstance(e, Exception):
        return True
    # asyncio.CancelledError derives from BaseException, so it is already
    # caught by the check above; it is listed here for clarity.
    # NOTE: on Python 3.11+ asyncio.TimeoutError is the builtin
    # TimeoutError, so any builtin TimeoutError also counts as "exit".
    return isinstance(e, (asyncio.TimeoutError, asyncio.CancelledError))
38
39
# Message used whenever greenlet-dependent functionality is requested but
# the optional 'greenlet' package could not be imported.
_ERROR_MESSAGE = (
    "The SQLAlchemy asyncio module requires that the Python "
    "'greenlet' library is installed. In order to ensure this "
    "dependency is available, use the 'sqlalchemy[asyncio]' install "
    "target: 'pip install sqlalchemy[asyncio]'"
)


def _not_implemented(*arg: Any, **kw: Any) -> NoReturn:
    """Placeholder callable that always raises ``ImportError``.

    Accepts and ignores any arguments so it can stand in for any
    greenlet API when greenlet is not installed.
    """
    message = _ERROR_MESSAGE
    raise ImportError(message)
50
51
class _concurrency_shim_cls:
    """Late import shim for greenlet.

    ``greenlet`` is not imported when this module is imported.  The first
    access to one of the attributes listed in ``__slots__`` that has not
    been set yet falls through to :meth:`__getattr__`, which calls
    :meth:`_initialize` to perform the import and populate the instance.
    """

    __slots__ = (
        # True once greenlet imported successfully, False if it failed
        "_has_greenlet",
        # the ``greenlet.greenlet`` class, or ``_not_implemented``
        "greenlet",
        # greenlet subclass instantiated by ``greenlet_spawn``, or
        # ``_not_implemented``
        "_AsyncIoGreenlet",
        # ``greenlet.getcurrent``, or ``_not_implemented``
        "getcurrent",
    )

    def _initialize(self, *, raise_: bool = True) -> None:
        """Import greenlet and initialize the class.

        On success, sets the instance slots and the module-level names
        ``getcurrent``, ``greenlet``, ``_AsyncIoGreenlet`` and
        ``_has_gr_context``.  On ``ImportError``, sets the module-level
        ``greenlet`` to ``None`` and installs ``_not_implemented``
        placeholders on the instance.  Because either outcome leaves
        ``greenlet`` in the module globals, the import is attempted only
        once; later calls return immediately.

        :param raise_: when True (the default), a missing greenlet is
          reported by raising ``ImportError(_ERROR_MESSAGE)`` after the
          placeholders are installed; when False, no exception is raised.
        """
        if "greenlet" in globals():
            return

        if not TYPE_CHECKING:
            # ``global`` is a compile-time declaration that applies to the
            # whole function even inside this ``if``; the guard only hides
            # it from type checkers.  It makes the imports and the class
            # statement below bind module-level names.
            global getcurrent, greenlet, _AsyncIoGreenlet
            global _has_gr_context

        try:
            from greenlet import getcurrent
            from greenlet import greenlet
        except ImportError as e:
            if not TYPE_CHECKING:
                # set greenlet in the global scope to prevent re-init
                greenlet = None
            self._has_greenlet = False
            self._initialize_no_greenlet()
            if raise_:
                raise ImportError(_ERROR_MESSAGE) from e
        else:
            self._has_greenlet = True
            # If greenlet.gr_context is present in current version of greenlet,
            # it will be set with the current context on creation.
            # Refs: https://github.com/python-greenlet/greenlet/pull/198
            _has_gr_context = hasattr(getcurrent(), "gr_context")

            # implementation based on snaury gist at
            # https://gist.github.com/snaury/202bf4f22c41ca34e56297bae5f33fef
            # Issue for context: https://github.com/python-greenlet/greenlet/issues/173 # noqa: E501

            class _AsyncIoGreenlet(greenlet):
                # declared for type checkers; provided by greenlet
                dead: bool

                # marker checked by ``in_greenlet()`` and ``await_()`` to
                # recognize greenlets created through this shim
                __sqlalchemy_greenlet_provider__ = True

                def __init__(self, fn: Callable[..., Any], driver: greenlet):
                    # ``fn`` is the function to run and ``driver`` becomes
                    # the parent greenlet, which is what ``await_()``
                    # switches back to via ``current.parent``
                    greenlet.__init__(self, fn, driver)
                    if _has_gr_context:
                        # share the driver's contextvars context with the
                        # new greenlet
                        self.gr_context = driver.gr_context

            self.greenlet = greenlet
            self.getcurrent = getcurrent
            self._AsyncIoGreenlet = _AsyncIoGreenlet

    def _initialize_no_greenlet(self):
        # Install placeholders that raise ImportError when called; the
        # caller is responsible for setting ``_has_greenlet``.
        self.getcurrent = _not_implemented
        self.greenlet = _not_implemented  # type: ignore[assignment]
        self._AsyncIoGreenlet = _not_implemented  # type: ignore[assignment]

    def __getattr__(self, key: str) -> Any:
        # Only reached for attributes not found normally, i.e. slots that
        # have not been assigned yet.  Trigger lazy initialization and
        # retry.  With greenlet missing, ``_initialize()`` raises
        # ImportError here (``raise_`` defaults to True).
        # NOTE(review): if ``greenlet`` is already in the module globals but
        # this instance's slots were never populated (e.g. a second
        # instance), ``_initialize`` returns early and this recurses.
        if key in self.__slots__:
            self._initialize()
            return getattr(self, key)
        else:
            raise AttributeError(key)
118
119
# Module-wide shim instance; greenlet is imported lazily the first time one
# of its unset slot attributes is accessed.
_concurrency_shim: _concurrency_shim_cls = _concurrency_shim_cls()
121
if not TYPE_CHECKING:
    # At runtime this is simply asyncio's implementation.
    iscoroutine = asyncio.iscoroutine
else:
    _T_co = TypeVar("_T_co", covariant=True)

    def iscoroutine(
        awaitable: Awaitable[_T_co],
    ) -> TypeGuard[Coroutine[Any, Any, _T_co]]:
        """Typing-only signature that narrows an awaitable to a coroutine."""
        ...
131
132
def _safe_cancel_awaitable(awaitable: Awaitable[Any]) -> None:
    """Close ``awaitable`` if it is a coroutine object.

    Used when an awaitable will never be awaited.  Awaitables that are not
    coroutine objects are left untouched.
    """
    # https://docs.python.org/3/reference/datamodel.html#coroutine.close
    if not iscoroutine(awaitable):
        return
    awaitable.close()
138
139
def in_greenlet() -> bool:
    """Report whether the currently running greenlet is an async adapter.

    Returns the ``__sqlalchemy_greenlet_provider__`` marker of the current
    greenlet, or False when it has none.  Within this module the marker is
    set by ``_AsyncIoGreenlet``, the greenlet type used by
    :func:`greenlet_spawn`.  Raises ``ImportError`` if greenlet is not
    installed.
    """
    marker_name = "__sqlalchemy_greenlet_provider__"
    running = _concurrency_shim.getcurrent()
    return getattr(running, marker_name, False)
143
144
def await_(awaitable: Awaitable[_T]) -> _T:
    """Wait for ``awaitable`` from synchronous code and return its result.

    Must be called from a sync function that is running under
    :func:`greenlet_spawn`; :func:`await_` calls cannot be nested.

    :param awaitable: The coroutine (or other awaitable) to wait for.
    :return: the value produced by ``awaitable``; an exception raised
      while awaiting it is re-raised from this call.
    :raises MissingGreenlet: when not running inside such a greenlet.
      A coroutine ``awaitable`` is closed before raising.
    """
    # this runs in the child greenlet created by greenlet_spawn
    current = _concurrency_shim.getcurrent()
    spawned = getattr(current, "__sqlalchemy_greenlet_provider__", False)

    if spawned:
        # Hand the awaitable to the driver (parent) greenlet, which awaits
        # it on the event loop and switches back here with the result, or
        # throws the exception into this greenlet.
        driver = current.parent
        assert driver
        return driver.switch(awaitable)  # type: ignore[no-any-return]

    _safe_cancel_awaitable(awaitable)
    raise exc.MissingGreenlet(
        "greenlet_spawn has not been called; can't call await_() "
        "here. Was IO attempted in an unexpected place?"
    )


await_only = await_  # old name, deprecated in 2.2
173
174
async def greenlet_spawn(
    fn: Callable[..., _T],
    *args: Any,
    _require_await: bool = False,
    **kwargs: Any,
) -> _T:
    """Runs a sync function ``fn`` in a new greenlet.

    While running, ``fn`` may call :func:`await_` to wait on async
    functions; each such awaitable is awaited here, on the event loop.

    :param fn: The sync callable to call.
    :param \\*args: Positional arguments to pass to the ``fn`` callable.
    :param _require_await: when True, raise :class:`.AwaitRequired` if
      ``fn`` finishes without ever calling :func:`await_`.
    :param \\*\\*kwargs: Keyword arguments to pass to the ``fn`` callable.
    :return: the return value of ``fn``.
    """
    child = _concurrency_shim._AsyncIoGreenlet(
        fn, _concurrency_shim.getcurrent()
    )
    awaited = False

    # Start fn.  Whenever it calls await_(), control returns here with the
    # awaitable it wants resolved and ``child`` is still alive.  Once fn
    # returns, ``child`` is dead and ``outcome`` holds its return value.
    outcome: Any = child.switch(*args, **kwargs)
    while not child.dead:
        awaited = True
        try:
            resolved = await outcome
        except BaseException:
            # Deliver the failure (including task cancellation, which is a
            # BaseException) into fn's greenlet so it surfaces at the
            # await_() call site and fn can handle it.
            outcome = child.throw(*sys.exc_info())
        else:
            outcome = child.switch(resolved)

    if _require_await and not awaited:
        raise exc.AwaitRequired(
            "The current operation required an async execution but none was "
            "detected. This will usually happen when using a non compatible "
            "DBAPI driver. Please ensure that an async DBAPI is used."
        )
    return outcome  # type: ignore[no-any-return]
223
224
class AsyncAdaptedLock:
    """Synchronous ``with``-statement lock backed by an ``asyncio.Lock``.

    Entering the block waits for the asyncio lock through :func:`await_`,
    so it must be used from code running under :func:`greenlet_spawn`.
    """

    @memoized_property
    def mutex(self) -> asyncio.Lock:
        """The underlying ``asyncio.Lock``, created on first access."""
        # Constructing the lock involves no ``await``, so two coroutines
        # cannot interleave here and end up with different locks.
        return asyncio.Lock()

    def __enter__(self) -> bool:
        # The mutex is created synchronously before the first await, so
        # every caller ends up waiting on the same lock.
        acquiring = self.mutex.acquire()
        return await_(acquiring)

    def __exit__(self, *arg: Any, **kw: Any) -> None:
        self.mutex.release()
239
240
if not TYPE_CHECKING and py311:
    _Runner = asyncio.Runner
else:

    class _Runner:
        """Runner implementation for test only.

        Minimal stand-in for :class:`asyncio.Runner` (Python 3.11+):
        lazily creates an event loop and supports ``run()``,
        ``get_loop()``, ``close()`` and use as a context manager.
        """

        # None: loop not created yet; a loop: active; False: closed
        _loop: Union[None, asyncio.AbstractEventLoop, Literal[False]]

        def __init__(self) -> None:
            self._loop = None

        def __enter__(self) -> Self:
            self._lazy_init()
            return self

        def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
            self.close()

        def _lazy_init(self) -> None:
            state = self._loop
            if state is False:
                raise RuntimeError("Runner is closed")
            if state is None:
                self._loop = asyncio.new_event_loop()

        def get_loop(self) -> asyncio.AbstractEventLoop:
            """Return embedded event loop."""
            self._lazy_init()
            loop = self._loop
            assert loop
            return loop

        def run(self, coro: Coroutine[Any, Any, _T]) -> _T:
            """Run ``coro`` on the embedded loop and return its result."""
            return self.get_loop().run_until_complete(coro)

        def close(self) -> None:
            """Shut down async generators and close the loop.

            Does nothing if the loop was never created or was already
            closed.
            """
            loop = self._loop
            if not loop:
                return
            try:
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                loop.close()
                self._loop = False
286
287
class _AsyncUtil:
    """Asyncio util for test suite/ util only"""

    def __init__(self) -> None:
        # the runner creates its event loop lazily, so it is safe to
        # construct it here
        self.runner = _Runner()

    def run(
        self,
        fn: Callable[..., Coroutine[Any, Any, _T]],
        *args: Any,
        **kwargs: Any,
    ) -> _T:
        """Call the async function ``fn`` and run the coroutine on the loop."""
        coro = fn(*args, **kwargs)
        return self.runner.run(coro)

    def run_in_greenlet(
        self, fn: Callable[..., _T], *args: Any, **kwargs: Any
    ) -> _T:
        """Run the sync function ``fn`` in a greenlet; nested calls allowed.

        Without greenlet installed, ``fn`` is simply called.  If the
        runner's loop is already running, this is a nested call from code
        already running in a greenlet, so ``fn`` is called directly.
        """
        _concurrency_shim._initialize(raise_=False)

        if not _concurrency_shim._has_greenlet:
            return fn(*args, **kwargs)

        if not self.runner.get_loop().is_running():
            return self.runner.run(greenlet_spawn(fn, *args, **kwargs))

        # allow for a wrapped test function to call another
        assert in_greenlet()
        return fn(*args, **kwargs)

    def close(self) -> None:
        self.runner.close()