1# util/_concurrency_py3k.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9from __future__ import annotations
10
11import asyncio
12from contextvars import Context
13import sys
14import typing
15from typing import Any
16from typing import Awaitable
17from typing import Callable
18from typing import Coroutine
19from typing import Optional
20from typing import TYPE_CHECKING
21from typing import TypeVar
22from typing import Union
23
24from .langhelpers import memoized_property
25from .. import exc
26from ..util import py311
27from ..util.typing import Literal
28from ..util.typing import Protocol
29from ..util.typing import Self
30from ..util.typing import TypeGuard
31
# Generic result type used throughout this module (await_only(),
# await_fallback(), greenlet_spawn(), _Runner.run()).
_T = TypeVar("_T")

if typing.TYPE_CHECKING:

    # Typing-only description of the parts of the greenlet API this module
    # relies on; at runtime the real greenlet package is imported in the
    # ``else`` branch instead.
    class greenlet(Protocol):
        dead: bool
        gr_context: Optional[Context]

        def __init__(self, fn: Callable[..., Any], driver: greenlet): ...

        def throw(self, *arg: Any) -> Any:
            return None

        def switch(self, value: Any) -> Any:
            return None

    def getcurrent() -> greenlet: ...

else:
    from greenlet import getcurrent
    from greenlet import greenlet


# If greenlet.gr_context is present in the installed version of greenlet,
# it will be set with the current context on creation.
# Refs: https://github.com/python-greenlet/greenlet/pull/198
_has_gr_context = hasattr(getcurrent(), "gr_context")
59
60
def is_exit_exception(e: BaseException) -> bool:
    """Return True when *e* should be treated as an "exit" exception.

    Exit exceptions are those that do not derive from :class:`Exception`
    (e.g. ``KeyboardInterrupt``, ``SystemExit``), together with
    :class:`asyncio.TimeoutError` and :class:`asyncio.CancelledError`.
    """
    # On Python 3.8+ asyncio.CancelledError derives only from
    # BaseException, so it would qualify anyway; listing it keeps the
    # intent explicit.
    if isinstance(e, (asyncio.TimeoutError, asyncio.CancelledError)):
        return True
    return not isinstance(e, Exception)
67
68
# implementation based on snaury's gist at
70# https://gist.github.com/snaury/202bf4f22c41ca34e56297bae5f33fef
71# Issue for context: https://github.com/python-greenlet/greenlet/issues/173
72
73
class _AsyncIoGreenlet(greenlet):
    """Greenlet that runs a sync callable on behalf of :func:`greenlet_spawn`.

    Instances carry the ``__sqlalchemy_greenlet_provider__`` marker, which
    :func:`await_only` and :func:`await_fallback` check for before switching
    back to the parent (driver) greenlet.
    """

    dead: bool

    __sqlalchemy_greenlet_provider__ = True

    def __init__(self, fn: Callable[..., Any], driver: greenlet):
        # ``driver`` is passed as the greenlet's parent; await_only()
        # switches to ``current.parent`` to hand awaitables back to it.
        greenlet.__init__(self, fn, driver)
        if not _has_gr_context:
            return
        # share the driver's contextvars context with the child greenlet
        self.gr_context = driver.gr_context
83
84
_T_co = TypeVar("_T_co", covariant=True)

if not TYPE_CHECKING:
    # At runtime this is just asyncio's own predicate.
    iscoroutine = asyncio.iscoroutine
else:

    # Typing-only signature: narrows an Awaitable to a Coroutine so that
    # callers may use coroutine-only methods such as ``close()``.
    def iscoroutine(
        awaitable: Awaitable[_T_co],
    ) -> TypeGuard[Coroutine[Any, Any, _T_co]]: ...
95
96
def _safe_cancel_awaitable(awaitable: Awaitable[Any]) -> None:
    """Close *awaitable* if it is a coroutine object; otherwise do nothing.

    Used on error paths where the awaitable will never be awaited, so the
    coroutine is finalized instead of being left un-awaited.  Awaitables
    that are not coroutine objects are left untouched.
    """
    # https://docs.python.org/3/reference/datamodel.html#coroutine.close
    if not iscoroutine(awaitable):
        return
    awaitable.close()
102
103
def in_greenlet() -> bool:
    """Report whether the current greenlet is an SQLAlchemy provider.

    True when the running greenlet carries the
    ``__sqlalchemy_greenlet_provider__`` marker, as the greenlets created by
    :func:`greenlet_spawn` do.
    """
    return getattr(getcurrent(), "__sqlalchemy_greenlet_provider__", False)
107
108
def await_only(awaitable: Awaitable[_T]) -> _T:
    """Block on *awaitable* from sync code running under :func:`greenlet_spawn`.

    Control is handed to the driver greenlet, which awaits *awaitable* and
    then resumes this greenlet with the result, or raises the awaitable's
    exception here.  :func:`await_only` calls cannot be nested.

    :param awaitable: The coroutine to call.
    :raises MissingGreenlet: when not called inside a
      :func:`greenlet_spawn` context; a coroutine *awaitable* is closed
      before raising.
    """
    # this is called in the context greenlet while running fn
    current = getcurrent()
    if getattr(current, "__sqlalchemy_greenlet_provider__", False):
        # pass the awaitable up to the driver (parent) greenlet; it switches
        # back here with the awaited result, or throws its error in
        return current.parent.switch(awaitable)  # type: ignore[no-any-return,attr-defined] # noqa: E501

    # not under greenlet_spawn(): finalize the coroutine, then fail loudly
    _safe_cancel_awaitable(awaitable)
    raise exc.MissingGreenlet(
        "greenlet_spawn has not been called; can't call await_only() "
        "here. Was IO attempted in an unexpected place?"
    )
133
134
def await_fallback(awaitable: Awaitable[_T]) -> _T:
    """Awaits an async function in a sync method.

    The sync method must be inside a :func:`greenlet_spawn` context.
    :func:`await_fallback` calls cannot be nested.

    Outside of a :func:`greenlet_spawn` context, the awaitable is instead
    run to completion on the loop returned by :func:`get_event_loop`,
    provided that loop is not already running.

    :param awaitable: The coroutine to call.
    :raises MissingGreenlet: when outside a :func:`greenlet_spawn` context
      while the event loop is already running.  Any error raised while
      looking up the event loop is propagated.  In both cases a coroutine
      *awaitable* is closed first.

    .. deprecated:: 2.0.24 The ``await_fallback()`` function will be removed
       in SQLAlchemy 2.1. Use :func:`_util.await_only` instead, running the
       function / program / etc. within a top-level greenlet that is set up
       using :func:`_util.greenlet_spawn`.

    """

    # this is called in the context greenlet while running fn
    current = getcurrent()
    if not getattr(current, "__sqlalchemy_greenlet_provider__", False):
        try:
            loop = get_event_loop()
        except BaseException:
            # No usable event loop (e.g. none set for this thread).  Close
            # the coroutine so it is not left un-awaited, matching the
            # MissingGreenlet path below, then propagate the original error.
            _safe_cancel_awaitable(awaitable)
            raise
        if loop.is_running():
            _safe_cancel_awaitable(awaitable)

            raise exc.MissingGreenlet(
                "greenlet_spawn has not been called and asyncio event "
                "loop is already running; can't call await_fallback() here. "
                "Was IO attempted in an unexpected place?"
            )
        return loop.run_until_complete(awaitable)

    # inside greenlet_spawn(): hand the awaitable to the driver greenlet
    return current.parent.switch(awaitable)  # type: ignore[no-any-return,attr-defined] # noqa: E501
165
166
async def greenlet_spawn(
    fn: Callable[..., _T],
    *args: Any,
    _require_await: bool = False,
    **kwargs: Any,
) -> _T:
    """Run the synchronous callable ``fn`` inside a new greenlet.

    While ``fn`` runs it may call :func:`await_only` to wait on async
    functions; each such call is awaited here, on behalf of ``fn``.

    :param fn: The sync callable to call.
    :param \\*args: Positional arguments to pass to the ``fn`` callable.
    :param _require_await: when True, raise ``AwaitRequired`` if ``fn``
      finished without ever handing an awaitable back to this driver.
    :param \\*\\*kwargs: Keyword arguments to pass to the ``fn`` callable.
    :return: the value returned by ``fn``.
    """

    outcome: Any
    driver = getcurrent()
    child = _AsyncIoGreenlet(fn, driver)
    awaited = False

    # Start fn in the child greenlet.  It runs synchronously until it
    # either returns (child becomes dead and ``outcome`` is its return
    # value) or calls await_only(), which switches back here with an
    # awaitable as ``outcome``.
    outcome = child.switch(*args, **kwargs)
    while not child.dead:
        awaited = True
        try:
            value = await outcome
        except BaseException:
            # re-raise the failure inside the child so that fn sees it
            # coming out of its await_only() call and can handle it
            outcome = child.throw(*sys.exc_info())
        else:
            # resume fn with the awaited value
            outcome = child.switch(value)

    if not awaited and _require_await:
        raise exc.AwaitRequired(
            "The current operation required an async execution but none was "
            "detected. This will usually happen when using a non compatible "
            "DBAPI driver. Please ensure that an async DBAPI is used."
        )
    return outcome  # type: ignore[no-any-return]
212
213
class AsyncAdaptedLock:
    """Expose an :class:`asyncio.Lock` through the sync ``with`` protocol.

    Acquisition goes through :func:`await_fallback`.  It therefore works
    from inside :func:`greenlet_spawn`, or from plain sync code when no
    event loop is already running.
    """

    @memoized_property
    def mutex(self) -> asyncio.Lock:
        # Built lazily on first access.  Nothing is awaited while creating
        # it, so concurrently running coroutines cannot race to build two
        # locks.
        return asyncio.Lock()

    def __enter__(self) -> bool:
        # Touch ``self.mutex`` first, so the lock exists before the
        # acquire coroutine is awaited.
        acquire_coro = self.mutex.acquire()
        return await_fallback(acquire_coro)

    def __exit__(self, *arg: Any, **kw: Any) -> None:
        self.mutex.release()
228
229
def get_event_loop() -> asyncio.AbstractEventLoop:
    """Return the running event loop, else the policy's current loop.

    Vendored stand-in for ``asyncio.get_event_loop()`` on Python 3.7 and
    above, since Python 3.10 deprecates calling get_event_loop() as a
    standalone function.
    """
    try:
        running = asyncio.get_running_loop()
    except RuntimeError:
        # Leave the except block before doing the policy lookup, so any
        # error it raises is not chained ("During handling of the above
        # exception, another exception...") to this one.
        running = None
    if running is not None:
        return running
    return asyncio.get_event_loop_policy().get_event_loop()
242
243
if not TYPE_CHECKING and py311:
    _Runner = asyncio.Runner
else:

    class _Runner:
        """Runner implementation for test only.

        Stand-in for :class:`asyncio.Runner` on Python versions where it is
        not available.  It owns a private event loop that is created on
        first use and that may be closed only once.
        """

        # None: loop not created yet; an event loop: in use;
        # False: the runner has been closed and cannot be reused.
        _loop: Union[None, asyncio.AbstractEventLoop, Literal[False]]

        def __init__(self) -> None:
            self._loop = None

        def __enter__(self) -> Self:
            self._lazy_init()
            return self

        def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
            self.close()

        def close(self) -> None:
            """Shut down and close the embedded loop, if one was created.

            Mirrors ``asyncio.Runner.close()``:
            - tasks still pending are cancelled;
            - async generators are finalized;
            - the default executor is shut down (where supported);
            - finally the loop is closed.

            Calling this before the loop exists, or after it has been
            closed, does nothing.
            """
            loop = self._loop
            if not loop:
                return
            try:
                self._cancel_all_tasks(loop)
                loop.run_until_complete(loop.shutdown_asyncgens())
                # shutdown_default_executor() only exists on Python 3.9+
                shutdown_executor = getattr(
                    loop, "shutdown_default_executor", None
                )
                if shutdown_executor is not None:
                    loop.run_until_complete(shutdown_executor())
            finally:
                loop.close()
                self._loop = False

        @staticmethod
        def _cancel_all_tasks(loop: asyncio.AbstractEventLoop) -> None:
            """Cancel every task still pending on *loop* and wait for them.

            Exceptions raised by tasks other than cancellation are reported
            through the loop's exception handler rather than being lost.
            """
            to_cancel = asyncio.all_tasks(loop)
            if not to_cancel:
                return
            for task in to_cancel:
                task.cancel()
            loop.run_until_complete(
                asyncio.gather(*to_cancel, return_exceptions=True)
            )
            for task in to_cancel:
                if task.cancelled():
                    continue
                if task.exception() is not None:
                    loop.call_exception_handler(
                        {
                            "message": "unhandled exception during "
                            "_Runner.close() shutdown",
                            "exception": task.exception(),
                            "task": task,
                        }
                    )

        def get_loop(self) -> asyncio.AbstractEventLoop:
            """Return embedded event loop."""
            self._lazy_init()
            assert self._loop
            return self._loop

        def run(self, coro: Coroutine[Any, Any, _T]) -> _T:
            """Run *coro* to completion on the embedded loop; return its result."""
            self._lazy_init()
            assert self._loop
            return self._loop.run_until_complete(coro)

        def _lazy_init(self) -> None:
            """Create the loop on first use; refuse to reuse a closed runner."""
            if self._loop is False:
                raise RuntimeError("Runner is closed")
            if self._loop is None:
                self._loop = asyncio.new_event_loop()