1# connectors/asyncio.py 
    2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors 
    3# <see AUTHORS file> 
    4# 
    5# This module is part of SQLAlchemy and is released under 
    6# the MIT License: https://www.opensource.org/licenses/mit-license.php 
    7 
    8"""generic asyncio-adapted versions of DBAPI connection and cursor""" 
    9 
    10from __future__ import annotations 
    11 
    12import asyncio 
    13import collections 
    14import sys 
    15import types 
    16from typing import Any 
    17from typing import AsyncIterator 
    18from typing import Awaitable 
    19from typing import Deque 
    20from typing import Iterator 
    21from typing import NoReturn 
    22from typing import Optional 
    23from typing import Protocol 
    24from typing import Sequence 
    25from typing import Tuple 
    26from typing import Type 
    27from typing import TYPE_CHECKING 
    28 
    29from ..engine import AdaptedConnection 
    30from ..exc import EmulatedDBAPIException 
    31from ..util import EMPTY_DICT 
    32from ..util.concurrency import await_ 
    33from ..util.concurrency import in_greenlet 
    34 
    35if TYPE_CHECKING: 
    36    from ..engine.interfaces import _DBAPICursorDescription 
    37    from ..engine.interfaces import _DBAPIMultiExecuteParams 
    38    from ..engine.interfaces import _DBAPISingleExecuteParams 
    39    from ..engine.interfaces import DBAPIModule 
    40    from ..util.typing import Self 
    41 
    42 
    43class AsyncIODBAPIConnection(Protocol): 
    44    """protocol representing an async adapted version of a 
    45    :pep:`249` database connection. 
    46 
    47 
    48    """ 
    49 
    50    # note that async DBAPIs dont agree if close() should be awaitable, 
    51    # so it is omitted here and picked up by the __getattr__ hook below 
    52 
    53    async def commit(self) -> None: ... 
    54 
    55    def cursor(self, *args: Any, **kwargs: Any) -> AsyncIODBAPICursor: ... 
    56 
    57    async def rollback(self) -> None: ... 
    58 
    59    def __getattr__(self, key: str) -> Any: ... 
    60 
    61    def __setattr__(self, key: str, value: Any) -> None: ... 
    62 
    63 
    64class AsyncIODBAPICursor(Protocol): 
    65    """protocol representing an async adapted version 
    66    of a :pep:`249` database cursor. 
    67 
    68 
    69    """ 
    70 
    71    def __aenter__(self) -> Any: ... 
    72 
    73    @property 
    74    def description( 
    75        self, 
    76    ) -> _DBAPICursorDescription: 
    77        """The description attribute of the Cursor.""" 
    78        ... 
    79 
    80    @property 
    81    def rowcount(self) -> int: ... 
    82 
    83    arraysize: int 
    84 
    85    lastrowid: int 
    86 
    87    async def close(self) -> None: ... 
    88 
    89    async def execute( 
    90        self, 
    91        operation: Any, 
    92        parameters: Optional[_DBAPISingleExecuteParams] = None, 
    93    ) -> Any: ... 
    94 
    95    async def executemany( 
    96        self, 
    97        operation: Any, 
    98        parameters: _DBAPIMultiExecuteParams, 
    99    ) -> Any: ... 
    100 
    101    async def fetchone(self) -> Optional[Any]: ... 
    102 
    103    async def fetchmany(self, size: Optional[int] = ...) -> Sequence[Any]: ... 
    104 
    105    async def fetchall(self) -> Sequence[Any]: ... 
    106 
    107    async def setinputsizes(self, sizes: Sequence[Any]) -> None: ... 
    108 
    109    def setoutputsize(self, size: Any, column: Any) -> None: ... 
    110 
    111    async def callproc( 
    112        self, procname: str, parameters: Sequence[Any] = ... 
    113    ) -> Any: ... 
    114 
    115    async def nextset(self) -> Optional[bool]: ... 
    116 
    117    def __aiter__(self) -> AsyncIterator[Any]: ... 
    118 
    119 
    120class AsyncAdapt_dbapi_module: 
    121    if TYPE_CHECKING: 
    122        Error = DBAPIModule.Error 
    123        OperationalError = DBAPIModule.OperationalError 
    124        InterfaceError = DBAPIModule.InterfaceError 
    125        IntegrityError = DBAPIModule.IntegrityError 
    126 
    127        def __getattr__(self, key: str) -> Any: ... 
    128 
    129    def __init__( 
    130        self, 
    131        driver: types.ModuleType, 
    132        *, 
    133        dbapi_module: types.ModuleType | None = None, 
    134    ): 
    135        self.driver = driver 
    136        self.dbapi_module = dbapi_module 
    137 
    138    @property 
    139    def exceptions_module(self) -> types.ModuleType: 
    140        """Return the module which we think will have the exception hierarchy. 
    141 
    142        For an asyncio driver that wraps a plain DBAPI like aiomysql, 
    143        aioodbc, aiosqlite, etc. these exceptions will be from the 
    144        dbapi_module.  For a "pure" driver like asyncpg these will come 
    145        from the driver module. 
    146 
    147        .. versionadded:: 2.1 
    148 
    149        """ 
    150        if self.dbapi_module is not None: 
    151            return self.dbapi_module 
    152        else: 
    153            return self.driver 
    154 
    155 
    156class AsyncAdapt_dbapi_cursor: 
    157    server_side = False 
    158    __slots__ = ( 
    159        "_adapt_connection", 
    160        "_connection", 
    161        "_cursor", 
    162        "_rows", 
    163        "_soft_closed_memoized", 
    164    ) 
    165 
    166    _awaitable_cursor_close: bool = True 
    167 
    168    _cursor: AsyncIODBAPICursor 
    169    _adapt_connection: AsyncAdapt_dbapi_connection 
    170    _connection: AsyncIODBAPIConnection 
    171    _rows: Deque[Any] 
    172 
    173    def __init__(self, adapt_connection: AsyncAdapt_dbapi_connection): 
    174        self._adapt_connection = adapt_connection 
    175        self._connection = adapt_connection._connection 
    176 
    177        cursor = self._make_new_cursor(self._connection) 
    178        self._cursor = self._aenter_cursor(cursor) 
    179        self._soft_closed_memoized = EMPTY_DICT 
    180        if not self.server_side: 
    181            self._rows = collections.deque() 
    182 
    183    def _aenter_cursor(self, cursor: AsyncIODBAPICursor) -> AsyncIODBAPICursor: 
    184        try: 
    185            return await_(cursor.__aenter__())  # type: ignore[no-any-return] 
    186        except Exception as error: 
    187            self._adapt_connection._handle_exception(error) 
    188 
    189    def _make_new_cursor( 
    190        self, connection: AsyncIODBAPIConnection 
    191    ) -> AsyncIODBAPICursor: 
    192        return connection.cursor() 
    193 
    194    @property 
    195    def description(self) -> Optional[_DBAPICursorDescription]: 
    196        if "description" in self._soft_closed_memoized: 
    197            return self._soft_closed_memoized["description"]  # type: ignore[no-any-return]  # noqa: E501 
    198        return self._cursor.description 
    199 
    200    @property 
    201    def rowcount(self) -> int: 
    202        return self._cursor.rowcount 
    203 
    204    @property 
    205    def arraysize(self) -> int: 
    206        return self._cursor.arraysize 
    207 
    208    @arraysize.setter 
    209    def arraysize(self, value: int) -> None: 
    210        self._cursor.arraysize = value 
    211 
    212    @property 
    213    def lastrowid(self) -> int: 
    214        return self._cursor.lastrowid 
    215 
    216    async def _async_soft_close(self) -> None: 
    217        """close the cursor but keep the results pending, and memoize the 
    218        description. 
    219 
    220        .. versionadded:: 2.0.44 
    221 
    222        """ 
    223 
    224        if not self._awaitable_cursor_close or self.server_side: 
    225            return 
    226 
    227        self._soft_closed_memoized = self._soft_closed_memoized.union( 
    228            { 
    229                "description": self._cursor.description, 
    230            } 
    231        ) 
    232        await self._cursor.close() 
    233 
    234    def close(self) -> None: 
    235        self._rows.clear() 
    236 
    237        # updated as of 2.0.44 
    238        # try to "close" the cursor based on what we know about the driver 
    239        # and if we are able to.  otherwise, hope that the asyncio 
    240        # extension called _async_soft_close() if the cursor is going into 
    241        # a sync context 
    242        if self._cursor is None or bool(self._soft_closed_memoized): 
    243            return 
    244 
    245        if not self._awaitable_cursor_close: 
    246            self._cursor.close()  # type: ignore[unused-coroutine] 
    247        elif in_greenlet(): 
    248            await_(self._cursor.close()) 
    249 
    250    def execute( 
    251        self, 
    252        operation: Any, 
    253        parameters: Optional[_DBAPISingleExecuteParams] = None, 
    254    ) -> Any: 
    255        try: 
    256            return await_(self._execute_async(operation, parameters)) 
    257        except Exception as error: 
    258            self._adapt_connection._handle_exception(error) 
    259 
    260    def executemany( 
    261        self, 
    262        operation: Any, 
    263        seq_of_parameters: _DBAPIMultiExecuteParams, 
    264    ) -> Any: 
    265        try: 
    266            return await_( 
    267                self._executemany_async(operation, seq_of_parameters) 
    268            ) 
    269        except Exception as error: 
    270            self._adapt_connection._handle_exception(error) 
    271 
    272    async def _execute_async( 
    273        self, operation: Any, parameters: Optional[_DBAPISingleExecuteParams] 
    274    ) -> Any: 
    275        async with self._adapt_connection._execute_mutex: 
    276            if parameters is None: 
    277                result = await self._cursor.execute(operation) 
    278            else: 
    279                result = await self._cursor.execute(operation, parameters) 
    280 
    281            if self._cursor.description and not self.server_side: 
    282                self._rows = collections.deque(await self._cursor.fetchall()) 
    283            return result 
    284 
    285    async def _executemany_async( 
    286        self, 
    287        operation: Any, 
    288        seq_of_parameters: _DBAPIMultiExecuteParams, 
    289    ) -> Any: 
    290        async with self._adapt_connection._execute_mutex: 
    291            return await self._cursor.executemany(operation, seq_of_parameters) 
    292 
    293    def nextset(self) -> None: 
    294        await_(self._cursor.nextset()) 
    295        if self._cursor.description and not self.server_side: 
    296            self._rows = collections.deque(await_(self._cursor.fetchall())) 
    297 
    298    def setinputsizes(self, *inputsizes: Any) -> None: 
    299        # NOTE: this is overrridden in aioodbc due to 
    300        # see https://github.com/aio-libs/aioodbc/issues/451 
    301        # right now 
    302 
    303        return await_(self._cursor.setinputsizes(*inputsizes)) 
    304 
    305    def __enter__(self) -> Self: 
    306        return self 
    307 
    308    def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: 
    309        self.close() 
    310 
    311    def __iter__(self) -> Iterator[Any]: 
    312        while self._rows: 
    313            yield self._rows.popleft() 
    314 
    315    def fetchone(self) -> Optional[Any]: 
    316        if self._rows: 
    317            return self._rows.popleft() 
    318        else: 
    319            return None 
    320 
    321    def fetchmany(self, size: Optional[int] = None) -> Sequence[Any]: 
    322        if size is None: 
    323            size = self.arraysize 
    324        rr = self._rows 
    325        return [rr.popleft() for _ in range(min(size, len(rr)))] 
    326 
    327    def fetchall(self) -> Sequence[Any]: 
    328        retval = list(self._rows) 
    329        self._rows.clear() 
    330        return retval 
    331 
    332 
    333class AsyncAdapt_dbapi_ss_cursor(AsyncAdapt_dbapi_cursor): 
    334    __slots__ = () 
    335    server_side = True 
    336 
    337    def close(self) -> None: 
    338        if self._cursor is not None: 
    339            await_(self._cursor.close()) 
    340            self._cursor = None  # type: ignore 
    341 
    342    def fetchone(self) -> Optional[Any]: 
    343        return await_(self._cursor.fetchone()) 
    344 
    345    def fetchmany(self, size: Optional[int] = None) -> Any: 
    346        return await_(self._cursor.fetchmany(size=size)) 
    347 
    348    def fetchall(self) -> Sequence[Any]: 
    349        return await_(self._cursor.fetchall()) 
    350 
    351    def __iter__(self) -> Iterator[Any]: 
    352        iterator = self._cursor.__aiter__() 
    353        while True: 
    354            try: 
    355                yield await_(iterator.__anext__()) 
    356            except StopAsyncIteration: 
    357                break 
    358 
    359 
    360class AsyncAdapt_dbapi_connection(AdaptedConnection): 
    361    _cursor_cls = AsyncAdapt_dbapi_cursor 
    362    _ss_cursor_cls = AsyncAdapt_dbapi_ss_cursor 
    363 
    364    __slots__ = ("dbapi", "_execute_mutex") 
    365 
    366    _connection: AsyncIODBAPIConnection 
    367 
    368    @classmethod 
    369    async def create( 
    370        cls, 
    371        dbapi: Any, 
    372        connection_awaitable: Awaitable[AsyncIODBAPIConnection], 
    373        **kw: Any, 
    374    ) -> Self: 
    375        try: 
    376            connection = await connection_awaitable 
    377        except Exception as error: 
    378            cls._handle_exception_no_connection(dbapi, error) 
    379        else: 
    380            return cls(dbapi, connection, **kw) 
    381 
    382    def __init__(self, dbapi: Any, connection: AsyncIODBAPIConnection): 
    383        self.dbapi = dbapi 
    384        self._connection = connection 
    385        self._execute_mutex = asyncio.Lock() 
    386 
    387    def cursor(self, server_side: bool = False) -> AsyncAdapt_dbapi_cursor: 
    388        if server_side: 
    389            return self._ss_cursor_cls(self) 
    390        else: 
    391            return self._cursor_cls(self) 
    392 
    393    def execute( 
    394        self, 
    395        operation: Any, 
    396        parameters: Optional[_DBAPISingleExecuteParams] = None, 
    397    ) -> Any: 
    398        """lots of DBAPIs seem to provide this, so include it""" 
    399        cursor = self.cursor() 
    400        cursor.execute(operation, parameters) 
    401        return cursor 
    402 
    403    @classmethod 
    404    def _handle_exception_no_connection( 
    405        cls, dbapi: Any, error: Exception 
    406    ) -> NoReturn: 
    407        exc_info = sys.exc_info() 
    408 
    409        raise error.with_traceback(exc_info[2]) 
    410 
    411    def _handle_exception(self, error: Exception) -> NoReturn: 
    412        self._handle_exception_no_connection(self.dbapi, error) 
    413 
    414    def rollback(self) -> None: 
    415        try: 
    416            await_(self._connection.rollback()) 
    417        except Exception as error: 
    418            self._handle_exception(error) 
    419 
    420    def commit(self) -> None: 
    421        try: 
    422            await_(self._connection.commit()) 
    423        except Exception as error: 
    424            self._handle_exception(error) 
    425 
    426    def close(self) -> None: 
    427        await_(self._connection.close()) 
    428 
    429 
    430class AsyncAdapt_terminate: 
    431    """Mixin for a AsyncAdapt_dbapi_connection to add terminate support.""" 
    432 
    433    __slots__ = () 
    434 
    435    def terminate(self) -> None: 
    436        if in_greenlet(): 
    437            # in a greenlet; this is the connection was invalidated case. 
    438            try: 
    439                # try to gracefully close; see #10717 
    440                await_(asyncio.shield(self._terminate_graceful_close())) 
    441            except self._terminate_handled_exceptions() as e: 
    442                # in the case where we are recycling an old connection 
    443                # that may have already been disconnected, close() will 
    444                # fail.  In this case, terminate 
    445                # the connection without any further waiting. 
    446                # see issue #8419 
    447                self._terminate_force_close() 
    448                if isinstance(e, asyncio.CancelledError): 
    449                    # re-raise CancelledError if we were cancelled 
    450                    raise 
    451        else: 
    452            # not in a greenlet; this is the gc cleanup case 
    453            self._terminate_force_close() 
    454 
    455    def _terminate_handled_exceptions(self) -> Tuple[Type[BaseException], ...]: 
    456        """Returns the exceptions that should be handled when 
    457        calling _graceful_close. 
    458        """ 
    459        return (asyncio.TimeoutError, asyncio.CancelledError, OSError) 
    460 
    461    async def _terminate_graceful_close(self) -> None: 
    462        """Try to close connection gracefully""" 
    463        raise NotImplementedError 
    464 
    465    def _terminate_force_close(self) -> None: 
    466        """Terminate the connection""" 
    467        raise NotImplementedError 
    468 
    469 
class AsyncAdapt_Error(EmulatedDBAPIException):
    """Base for the DBAPI ``Error`` class of dialects that need to emulate
    the :pep:`249` exception hierarchy.

    .. versionadded:: 2.1

    """