1# util/concurrency.py 
    2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors 
    3# <see AUTHORS file> 
    4# 
    5# This module is part of SQLAlchemy and is released under 
    6# the MIT License: https://www.opensource.org/licenses/mit-license.php 
    7# mypy: allow-untyped-defs, allow-untyped-calls 
    8 
    9"""asyncio-related concurrency functions.""" 
    10 
    11from __future__ import annotations 
    12 
    13import asyncio 
    14import sys 
    15from typing import Any 
    16from typing import Awaitable 
    17from typing import Callable 
    18from typing import Coroutine 
    19from typing import Literal 
    20from typing import NoReturn 
    21from typing import TYPE_CHECKING 
    22from typing import TypeGuard 
    23from typing import TypeVar 
    24from typing import Union 
    25 
    26from .compat import py311 
    27from .langhelpers import memoized_property 
    28from .typing import Self 
    29from .. import exc 
    30 
    31_T = TypeVar("_T") 
    32 
    33 
    34def is_exit_exception(e: BaseException) -> bool: 
    35    # note asyncio.CancelledError is already BaseException 
    36    # so was an exit exception in any case 
    37    return not isinstance(e, Exception) or isinstance( 
    38        e, (asyncio.TimeoutError, asyncio.CancelledError) 
    39    ) 
    40 
    41 
    42_ERROR_MESSAGE = ( 
    43    "The SQLAlchemy asyncio module requires that the Python 'greenlet' " 
    44    "library is installed.  In order to ensure this dependency is " 
    45    "available, use the 'sqlalchemy[asyncio]' install target:  " 
    46    "'pip install sqlalchemy[asyncio]'" 
    47) 
    48 
    49 
    50def _not_implemented(*arg: Any, **kw: Any) -> NoReturn: 
    51    raise ImportError(_ERROR_MESSAGE) 
    52 
    53 
class _concurrency_shim_cls:
    """Late import shim for greenlet.

    greenlet is not imported when this module is imported.  Instead, the
    first read of one of the attributes listed in ``__slots__`` falls
    through to ``__getattr__``, which runs ``_initialize()`` and then
    retries the lookup.
    """

    __slots__ = (
        "_has_greenlet",
        "greenlet",
        "_AsyncIoGreenlet",
        "getcurrent",
    )

    def _initialize(self, *, raise_: bool = True) -> None:
        """Import greenlet and initialize the class.

        On success, ``getcurrent``, ``greenlet`` and ``_AsyncIoGreenlet``
        are bound both as module globals and on this instance, and
        ``_has_greenlet`` is set to True.

        If greenlet cannot be imported, the module global ``greenlet`` is
        set to None and ``_has_greenlet`` to False.  The instance's
        ``getcurrent`` / ``greenlet`` / ``_AsyncIoGreenlet`` slots are bound
        to :func:`_not_implemented`.  ``ImportError`` (with the install
        hint) is raised when ``raise_`` is True.

        Returns immediately, without touching this instance, once the
        module-level name ``greenlet`` exists (i.e. after a previous call
        on either path).
        """
        if "greenlet" in globals():
            return

        if not TYPE_CHECKING:
            # the names imported / defined below are rebound at module
            # level; the declaration is hidden from static type checkers
            global getcurrent, greenlet, _AsyncIoGreenlet
            global _has_gr_context

        try:
            from greenlet import getcurrent
            from greenlet import greenlet
        except ImportError as e:
            if not TYPE_CHECKING:
                # set greenlet in the global scope to prevent re-init
                greenlet = None
            self._has_greenlet = False
            self._initialize_no_greenlet()
            if raise_:
                raise ImportError(_ERROR_MESSAGE) from e
        else:
            self._has_greenlet = True
            # If greenlet.gr_context is present in current version of greenlet,
            # it will be set with the current context on creation.
            # Refs: https://github.com/python-greenlet/greenlet/pull/198
            _has_gr_context = hasattr(getcurrent(), "gr_context")

            # implementation based on snaury gist at
            # https://gist.github.com/snaury/202bf4f22c41ca34e56297bae5f33fef
            # Issue for context: https://github.com/python-greenlet/greenlet/issues/173 # noqa: E501

            class _AsyncIoGreenlet(greenlet):
                """Greenlet that runs sync code for :func:`greenlet_spawn`.

                Its parent is the ``driver`` greenlet, which is the one
                :func:`await_` switches back to.
                """

                dead: bool

                # marker attribute checked by await_() and in_greenlet()
                __sqlalchemy_greenlet_provider__ = True

                def __init__(self, fn: Callable[..., Any], driver: greenlet):
                    # ``driver`` is passed as this greenlet's parent
                    greenlet.__init__(self, fn, driver)
                    if _has_gr_context:
                        # use the same gr_context as the driver greenlet
                        self.gr_context = driver.gr_context

            # expose the freshly imported objects through this shim
            self.greenlet = greenlet
            self.getcurrent = getcurrent
            self._AsyncIoGreenlet = _AsyncIoGreenlet

    def _initialize_no_greenlet(self):
        """Bind the greenlet-facing slots to :func:`_not_implemented`.

        Any later use of them raises ``ImportError`` with the install hint.
        """
        self.getcurrent = _not_implemented
        self.greenlet = _not_implemented  # type: ignore[assignment]
        self._AsyncIoGreenlet = _not_implemented  # type: ignore[assignment]

    def __getattr__(self, key: str) -> Any:
        # Only reached when normal lookup fails, i.e. a slot has not been
        # assigned yet: initialize lazily, then retry the lookup.
        # NOTE(review): _initialize() returns early without assigning this
        # instance's slots when the module global ``greenlet`` already
        # exists, so a second instance of this class would recurse here.
        # Only the single module-level instance below is created in this
        # file.
        if key in self.__slots__:
            self._initialize()
            return getattr(self, key)
        else:
            raise AttributeError(key)


# module-wide shim instance; greenlet is imported on first access to one of
# its slot attributes
_concurrency_shim = _concurrency_shim_cls()
    123 
    124if TYPE_CHECKING: 
    125    _T_co = TypeVar("_T_co", covariant=True) 
    126 
    127    def iscoroutine( 
    128        awaitable: Awaitable[_T_co], 
    129    ) -> TypeGuard[Coroutine[Any, Any, _T_co]]: ... 
    130 
    131else: 
    132    iscoroutine = asyncio.iscoroutine 
    133 
    134 
    135def _safe_cancel_awaitable(awaitable: Awaitable[Any]) -> None: 
    136    # https://docs.python.org/3/reference/datamodel.html#coroutine.close 
    137 
    138    if iscoroutine(awaitable): 
    139        awaitable.close() 
    140 
    141 
    142def in_greenlet() -> bool: 
    143    current = _concurrency_shim.getcurrent() 
    144    return getattr(current, "__sqlalchemy_greenlet_provider__", False) 
    145 
    146 
    147def await_(awaitable: Awaitable[_T]) -> _T: 
    148    """Awaits an async function in a sync method. 
    149 
    150    The sync method must be inside a :func:`greenlet_spawn` context. 
    151    :func:`await_` calls cannot be nested. 
    152 
    153    :param awaitable: The coroutine to call. 
    154 
    155    """ 
    156    # this is called in the context greenlet while running fn 
    157    current = _concurrency_shim.getcurrent() 
    158    if not getattr(current, "__sqlalchemy_greenlet_provider__", False): 
    159        _safe_cancel_awaitable(awaitable) 
    160 
    161        raise exc.MissingGreenlet( 
    162            "greenlet_spawn has not been called; can't call await_() " 
    163            "here. Was IO attempted in an unexpected place?" 
    164        ) 
    165 
    166    # returns the control to the driver greenlet passing it 
    167    # a coroutine to run. Once the awaitable is done, the driver greenlet 
    168    # switches back to this greenlet with the result of awaitable that is 
    169    # then returned to the caller (or raised as error) 
    170    assert current.parent 
    171    return current.parent.switch(awaitable)  # type: ignore[no-any-return] 
    172 
    173 
    174await_only = await_  # old name. deprecated on 2.2 
    175 
    176 
async def greenlet_spawn(
    fn: Callable[..., _T],
    *args: Any,
    _require_await: bool = False,
    **kwargs: Any,
) -> _T:
    """Runs a sync function ``fn`` in a new greenlet.

    The sync function can then use :func:`await_` to wait for async
    functions.  Each awaitable handed over by :func:`await_` is awaited
    here, in the calling coroutine.  Its result is switched back into the
    greenlet, or its exception is thrown into it, until ``fn`` returns.

    :param fn: The sync callable to call.
    :param \\*args: Positional arguments to pass to the ``fn`` callable.
    :param \\*\\*kwargs: Keyword arguments to pass to the ``fn`` callable.
    :param _require_await: when True, raise ``exc.AwaitRequired`` if ``fn``
      returned without ever calling :func:`await_`.
    :return: the value returned by ``fn``.
    """

    result: Any
    # the current greenlet (the one running this coroutine) becomes the
    # new greenlet's parent, i.e. the "driver" that await_() switches to
    context = _concurrency_shim._AsyncIoGreenlet(
        fn, _concurrency_shim.getcurrent()
    )
    # runs the function synchronously in the new greenlet. If the execution
    # is interrupted by await_, context is not dead and result is a
    # coroutine to wait. If the context is dead the function has
    # returned, and its result can be returned.
    switch_occurred = False

    result = context.switch(*args, **kwargs)
    while not context.dead:
        # fn is suspended inside await_(); ``result`` is its awaitable
        switch_occurred = True
        try:
            # wait for a coroutine from await_ and then return its
            # result back to it.
            value = await result
        except BaseException:
            # this allows an exception to be raised within
            # the moderated greenlet so that it can continue
            # its expected flow.
            result = context.throw(*sys.exc_info())
        else:
            result = context.switch(value)

    if _require_await and not switch_occurred:
        raise exc.AwaitRequired(
            "The current operation required an async execution but none was "
            "detected. This will usually happen when using a non compatible "
            "DBAPI driver. Please ensure that an async DBAPI is used."
        )
    return result  # type: ignore[no-any-return]
    225 
    226 
    227class AsyncAdaptedLock: 
    228    @memoized_property 
    229    def mutex(self) -> asyncio.Lock: 
    230        # there should not be a race here for coroutines creating the 
    231        # new lock as we are not using await, so therefore no concurrency 
    232        return asyncio.Lock() 
    233 
    234    def __enter__(self) -> bool: 
    235        # await is used to acquire the lock only after the first calling 
    236        # coroutine has created the mutex. 
    237        return await_(self.mutex.acquire()) 
    238 
    239    def __exit__(self, *arg: Any, **kw: Any) -> None: 
    240        self.mutex.release() 
    241 
    242 
    243if not TYPE_CHECKING and py311: 
    244    _Runner = asyncio.Runner 
    245else: 
    246 
    247    class _Runner: 
    248        """Runner implementation for test only""" 
    249 
    250        _loop: Union[None, asyncio.AbstractEventLoop, Literal[False]] 
    251 
    252        def __init__(self) -> None: 
    253            self._loop = None 
    254 
    255        def __enter__(self) -> Self: 
    256            self._lazy_init() 
    257            return self 
    258 
    259        def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: 
    260            self.close() 
    261 
    262        def close(self) -> None: 
    263            if self._loop: 
    264                try: 
    265                    self._loop.run_until_complete( 
    266                        self._loop.shutdown_asyncgens() 
    267                    ) 
    268                finally: 
    269                    self._loop.close() 
    270                    self._loop = False 
    271 
    272        def get_loop(self) -> asyncio.AbstractEventLoop: 
    273            """Return embedded event loop.""" 
    274            self._lazy_init() 
    275            assert self._loop 
    276            return self._loop 
    277 
    278        def run(self, coro: Coroutine[Any, Any, _T]) -> _T: 
    279            self._lazy_init() 
    280            assert self._loop 
    281            return self._loop.run_until_complete(coro) 
    282 
    283        def _lazy_init(self) -> None: 
    284            if self._loop is False: 
    285                raise RuntimeError("Runner is closed") 
    286            if self._loop is None: 
    287                self._loop = asyncio.new_event_loop() 
    288 
    289 
    290class _AsyncUtil: 
    291    """Asyncio util for test suite/ util only""" 
    292 
    293    def __init__(self) -> None: 
    294        self.runner = _Runner()  # runner it lazy so it can be created here 
    295 
    296    def run( 
    297        self, 
    298        fn: Callable[..., Coroutine[Any, Any, _T]], 
    299        *args: Any, 
    300        **kwargs: Any, 
    301    ) -> _T: 
    302        """Run coroutine on the loop""" 
    303        return self.runner.run(fn(*args, **kwargs)) 
    304 
    305    def run_in_greenlet( 
    306        self, fn: Callable[..., _T], *args: Any, **kwargs: Any 
    307    ) -> _T: 
    308        """Run sync function in greenlet. Support nested calls""" 
    309        _concurrency_shim._initialize(raise_=False) 
    310 
    311        if _concurrency_shim._has_greenlet: 
    312            if self.runner.get_loop().is_running(): 
    313                # allow for a wrapped test function to call another 
    314                assert in_greenlet() 
    315                return fn(*args, **kwargs) 
    316            else: 
    317                return self.runner.run(greenlet_spawn(fn, *args, **kwargs)) 
    318        else: 
    319            return fn(*args, **kwargs) 
    320 
    321    def close(self) -> None: 
    322        self.runner.close()