1# connectors/asyncio.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8"""generic asyncio-adapted versions of DBAPI connection and cursor"""
9
10from __future__ import annotations
11
12import asyncio
13import collections
14import sys
15import types
16from typing import Any
17from typing import AsyncIterator
18from typing import Awaitable
19from typing import Deque
20from typing import Iterator
21from typing import NoReturn
22from typing import Optional
23from typing import Protocol
24from typing import Sequence
25from typing import Tuple
26from typing import Type
27from typing import TYPE_CHECKING
28
29from ..engine import AdaptedConnection
30from ..exc import EmulatedDBAPIException
31from ..util import EMPTY_DICT
32from ..util.concurrency import await_
33from ..util.concurrency import in_greenlet
34
35if TYPE_CHECKING:
36 from ..engine.interfaces import _DBAPICursorDescription
37 from ..engine.interfaces import _DBAPIMultiExecuteParams
38 from ..engine.interfaces import _DBAPISingleExecuteParams
39 from ..engine.interfaces import DBAPIModule
40 from ..util.typing import Self
41
42
class AsyncIODBAPIConnection(Protocol):
    """Structural type describing an asyncio driver's :pep:`249`-like
    connection.

    Transaction control (``commit()`` / ``rollback()``) is awaitable, while
    ``cursor()`` is a plain call that hands back an async cursor.
    """

    # Async DBAPIs don't agree on whether close() is awaitable, so it is not
    # declared here; like any other driver-specific attribute it is typed
    # through the catch-all __getattr__ hook below.

    def cursor(self, *args: Any, **kwargs: Any) -> AsyncIODBAPICursor: ...

    async def commit(self) -> None: ...

    async def rollback(self) -> None: ...

    def __getattr__(self, key: str) -> Any: ...

    def __setattr__(self, key: str, value: Any) -> None: ...
62
63
class AsyncIODBAPICursor(Protocol):
    """Structural type describing an asyncio driver's :pep:`249`-like
    cursor.

    Execution, fetching, ``callproc()``, ``nextset()``,
    ``setinputsizes()`` and ``close()`` are coroutines.  ``description``
    and ``rowcount`` are read-only properties.  The cursor also declares
    ``__aenter__()`` and ``__aiter__()`` hooks.
    """

    # plain, writable attributes
    arraysize: int

    lastrowid: int

    # read-only result metadata
    @property
    def description(
        self,
    ) -> _DBAPICursorDescription:
        """The description attribute of the Cursor."""
        ...

    @property
    def rowcount(self) -> int: ...

    # async context / iteration hooks
    def __aenter__(self) -> Any: ...

    def __aiter__(self) -> AsyncIterator[Any]: ...

    # statement execution
    async def execute(
        self,
        operation: Any,
        parameters: Optional[_DBAPISingleExecuteParams] = None,
    ) -> Any: ...

    async def executemany(
        self,
        operation: Any,
        parameters: _DBAPIMultiExecuteParams,
    ) -> Any: ...

    async def callproc(
        self, procname: str, parameters: Sequence[Any] = ...
    ) -> Any: ...

    async def nextset(self) -> Optional[bool]: ...

    # row retrieval
    async def fetchone(self) -> Optional[Any]: ...

    async def fetchmany(self, size: Optional[int] = ...) -> Sequence[Any]: ...

    async def fetchall(self) -> Sequence[Any]: ...

    # sizing hints and teardown
    async def setinputsizes(self, sizes: Sequence[Any]) -> None: ...

    def setoutputsize(self, size: Any, column: Any) -> None: ...

    async def close(self) -> None: ...
118
119
class AsyncAdapt_dbapi_module:
    """Base for objects presenting an asyncio driver as a DBAPI module.

    Holds the asyncio ``driver`` module and, optionally, the plain
    ``dbapi_module`` that the driver wraps.
    """

    if TYPE_CHECKING:
        # typing-only view: the standard DBAPI exception names, plus a
        # catch-all so any other module attribute type-checks as Any
        Error = DBAPIModule.Error
        OperationalError = DBAPIModule.OperationalError
        InterfaceError = DBAPIModule.InterfaceError
        IntegrityError = DBAPIModule.IntegrityError

        def __getattr__(self, key: str) -> Any: ...

    def __init__(
        self,
        driver: types.ModuleType,
        *,
        dbapi_module: types.ModuleType | None = None,
    ):
        self.dbapi_module = dbapi_module
        self.driver = driver

    @property
    def exceptions_module(self) -> types.ModuleType:
        """The module expected to carry the DBAPI exception classes.

        When a wrapped ``dbapi_module`` was supplied (the case for asyncio
        drivers layered over a plain DBAPI, such as aiomysql, aioodbc or
        aiosqlite) that module is returned.  Otherwise the exceptions are
        looked for on the ``driver`` module itself, as for a "pure" asyncio
        driver like asyncpg.

        .. versionadded:: 2.1

        """
        wrapped = self.dbapi_module
        return self.driver if wrapped is None else wrapped
154
155
class AsyncAdapt_dbapi_cursor:
    """Synchronous :pep:`249` cursor facade over an asyncio driver cursor.

    Driver coroutines are driven to completion with ``await_()``.  Because
    ``server_side`` is False for this class, whenever an execution or a new
    result set has a ``description``, all of its rows are fetched up front
    into an in-memory deque.  ``fetchone()`` / ``fetchmany()`` /
    ``fetchall()`` / iteration then consume that buffer without awaiting.
    """

    server_side = False
    __slots__ = (
        "_adapt_connection",
        "_connection",
        "_cursor",
        "_rows",
        "_soft_closed_memoized",
    )

    # True when the driver's cursor.close() returns an awaitable; when
    # False, close() calls it as a plain method
    _awaitable_cursor_close: bool = True

    _cursor: AsyncIODBAPICursor
    _adapt_connection: AsyncAdapt_dbapi_connection
    _connection: AsyncIODBAPIConnection
    _rows: Deque[Any]

    def __init__(self, adapt_connection: AsyncAdapt_dbapi_connection):
        self._adapt_connection = adapt_connection
        self._connection = adapt_connection._connection

        cursor = self._make_new_cursor(self._connection)
        self._cursor = self._aenter_cursor(cursor)
        # stays empty until _async_soft_close() memoizes the description
        self._soft_closed_memoized = EMPTY_DICT
        # only buffered (non server-side) cursors keep a row deque
        if not self.server_side:
            self._rows = collections.deque()

    def _aenter_cursor(self, cursor: AsyncIODBAPICursor) -> AsyncIODBAPICursor:
        """Run the driver cursor's ``__aenter__()`` and return its result.

        Errors are passed to the connection's ``_handle_exception()``.
        """
        try:
            return await_(cursor.__aenter__())  # type: ignore[no-any-return]
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def _make_new_cursor(
        self, connection: AsyncIODBAPIConnection
    ) -> AsyncIODBAPICursor:
        """Create a new driver cursor from ``connection``."""
        return connection.cursor()

    @property
    def description(self) -> Optional[_DBAPICursorDescription]:
        """The driver cursor's description, or the value memoized by
        ``_async_soft_close()`` once the driver cursor has been closed."""
        if "description" in self._soft_closed_memoized:
            return self._soft_closed_memoized["description"]  # type: ignore[no-any-return]  # noqa: E501
        return self._cursor.description

    @property
    def rowcount(self) -> int:
        return self._cursor.rowcount

    @property
    def arraysize(self) -> int:
        return self._cursor.arraysize

    @arraysize.setter
    def arraysize(self, value: int) -> None:
        self._cursor.arraysize = value

    @property
    def lastrowid(self) -> int:
        return self._cursor.lastrowid

    async def _async_soft_close(self) -> None:
        """close the cursor but keep the results pending, and memoize the
        description.

        Does nothing for server-side cursors or when the driver's
        ``close()`` is not awaitable.

        .. versionadded:: 2.0.44

        """

        if not self._awaitable_cursor_close or self.server_side:
            return

        self._soft_closed_memoized = self._soft_closed_memoized.union(
            {
                "description": self._cursor.description,
            }
        )
        await self._cursor.close()

    def close(self) -> None:
        """Discard buffered rows and close the driver cursor if possible."""
        self._rows.clear()

        # updated as of 2.0.44
        # try to "close" the cursor based on what we know about the driver
        # and if we are able to. otherwise, hope that the asyncio
        # extension called _async_soft_close() if the cursor is going into
        # a sync context
        if self._cursor is None or bool(self._soft_closed_memoized):
            return

        if not self._awaitable_cursor_close:
            self._cursor.close()  # type: ignore[unused-coroutine]
        elif in_greenlet():
            await_(self._cursor.close())

    def execute(
        self,
        operation: Any,
        parameters: Optional[_DBAPISingleExecuteParams] = None,
    ) -> Any:
        """Execute ``operation``, routing errors through the connection's
        ``_handle_exception()``."""
        try:
            return await_(self._execute_async(operation, parameters))
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def executemany(
        self,
        operation: Any,
        seq_of_parameters: _DBAPIMultiExecuteParams,
    ) -> Any:
        """Execute ``operation`` once per parameter set, routing errors
        through the connection's ``_handle_exception()``."""
        try:
            return await_(
                self._executemany_async(operation, seq_of_parameters)
            )
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    async def _execute_async(
        self, operation: Any, parameters: Optional[_DBAPISingleExecuteParams]
    ) -> Any:
        """Execute under the connection's mutex and refill the row buffer.

        For a buffered cursor, a result with a ``description`` is fetched
        entirely into ``_rows``.  A result without one empties the buffer,
        so rows left over from an earlier execution are not returned by
        later ``fetch*()`` calls.
        """
        async with self._adapt_connection._execute_mutex:
            if parameters is None:
                result = await self._cursor.execute(operation)
            else:
                result = await self._cursor.execute(operation, parameters)

            if not self.server_side:
                if self._cursor.description:
                    self._rows = collections.deque(
                        await self._cursor.fetchall()
                    )
                else:
                    self._rows = collections.deque()
            return result

    async def _executemany_async(
        self,
        operation: Any,
        seq_of_parameters: _DBAPIMultiExecuteParams,
    ) -> Any:
        """Run ``executemany`` under the connection's mutex.

        No rows are fetched here.  For a buffered cursor the row buffer is
        emptied, so rows from an earlier execution are not returned
        afterwards.
        """
        async with self._adapt_connection._execute_mutex:
            result = await self._cursor.executemany(
                operation, seq_of_parameters
            )
            if not self.server_side:
                self._rows = collections.deque()
            return result

    def nextset(self) -> Optional[bool]:
        """Advance the driver cursor to its next result set.

        Returns the driver's ``nextset()`` result.  Per :pep:`249` that is
        ``None`` when there are no more sets and a true value otherwise.
        For a buffered cursor, the rows of the new set are loaded when it
        has a ``description``.  Otherwise the buffer is emptied, so rows
        still pending from the previous set do not appear as the new set's
        rows.  Errors are routed through the connection's
        ``_handle_exception()``, as in :meth:`execute`.
        """
        try:
            has_next = await_(self._cursor.nextset())
            if not self.server_side:
                if self._cursor.description:
                    self._rows = collections.deque(
                        await_(self._cursor.fetchall())
                    )
                else:
                    self._rows = collections.deque()
            return has_next
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def setinputsizes(self, *inputsizes: Any) -> None:
        # NOTE: this is overridden in the aioodbc dialect for now; see
        # https://github.com/aio-libs/aioodbc/issues/451

        return await_(self._cursor.setinputsizes(*inputsizes))

    def __enter__(self) -> Self:
        return self

    def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
        self.close()

    def __iter__(self) -> Iterator[Any]:
        # consumes the buffer: each yielded row is removed from _rows
        while self._rows:
            yield self._rows.popleft()

    def fetchone(self) -> Optional[Any]:
        """Pop the next buffered row, or return None when empty."""
        if self._rows:
            return self._rows.popleft()
        else:
            return None

    def fetchmany(self, size: Optional[int] = None) -> Sequence[Any]:
        """Pop up to ``size`` buffered rows (default: ``arraysize``)."""
        if size is None:
            size = self.arraysize
        rr = self._rows
        return [rr.popleft() for _ in range(min(size, len(rr)))]

    def fetchall(self) -> Sequence[Any]:
        """Return all remaining buffered rows and empty the buffer."""
        retval = list(self._rows)
        self._rows.clear()
        return retval
331
332
class AsyncAdapt_dbapi_ss_cursor(AsyncAdapt_dbapi_cursor):
    """Server-side variant of :class:`AsyncAdapt_dbapi_cursor`.

    Rows are not buffered.  Every fetch call, and iteration, is forwarded
    to the driver cursor and awaited via ``await_()``.
    """

    __slots__ = ()
    server_side = True

    def close(self) -> None:
        """Await the driver cursor's ``close()`` once, then drop it."""
        driver_cursor = self._cursor
        if driver_cursor is None:
            # already closed
            return
        await_(driver_cursor.close())
        self._cursor = None  # type: ignore

    def fetchone(self) -> Optional[Any]:
        return await_(self._cursor.fetchone())

    def fetchmany(self, size: Optional[int] = None) -> Any:
        # size is forwarded as-is, including None
        return await_(self._cursor.fetchmany(size=size))

    def fetchall(self) -> Sequence[Any]:
        return await_(self._cursor.fetchall())

    def __iter__(self) -> Iterator[Any]:
        # drive the driver's async iterator one row at a time
        row_source = self._cursor.__aiter__()
        while True:
            try:
                yield await_(row_source.__anext__())
            except StopAsyncIteration:
                return
358
359
class AsyncAdapt_dbapi_connection(AdaptedConnection):
    """Synchronous :pep:`249`-style facade over an asyncio driver connection.

    Driver coroutines are driven to completion with ``await_()``.  Each
    instance owns an :class:`asyncio.Lock`, stored as ``_execute_mutex``,
    which the cursor classes acquire around statement execution.
    """

    _cursor_cls = AsyncAdapt_dbapi_cursor
    _ss_cursor_cls = AsyncAdapt_dbapi_ss_cursor

    __slots__ = ("dbapi", "_execute_mutex")

    _connection: AsyncIODBAPIConnection

    @classmethod
    async def create(
        cls,
        dbapi: Any,
        connection_awaitable: Awaitable[AsyncIODBAPIConnection],
        **kw: Any,
    ) -> Self:
        """Await the driver's connect awaitable and wrap the result.

        Failures are handed to ``_handle_exception_no_connection()``.
        """
        try:
            raw_connection = await connection_awaitable
        except Exception as err:
            cls._handle_exception_no_connection(dbapi, err)
        else:
            return cls(dbapi, raw_connection, **kw)

    def __init__(self, dbapi: Any, connection: AsyncIODBAPIConnection):
        self._connection = connection
        self.dbapi = dbapi
        self._execute_mutex = asyncio.Lock()

    def cursor(self, server_side: bool = False) -> AsyncAdapt_dbapi_cursor:
        """Return a new buffered cursor, or a server-side one if requested."""
        factory = self._ss_cursor_cls if server_side else self._cursor_cls
        return factory(self)

    def execute(
        self,
        operation: Any,
        parameters: Optional[_DBAPISingleExecuteParams] = None,
    ) -> Any:
        """lots of DBAPIs seem to provide this, so include it"""
        new_cursor = self.cursor()
        new_cursor.execute(operation, parameters)
        return new_cursor

    @classmethod
    def _handle_exception_no_connection(
        cls, dbapi: Any, error: Exception
    ) -> NoReturn:
        """Re-raise ``error`` with the traceback of the exception currently
        being handled.  ``dbapi`` is unused here."""
        _, _, current_tb = sys.exc_info()
        raise error.with_traceback(current_tb)

    def _handle_exception(self, error: Exception) -> NoReturn:
        self._handle_exception_no_connection(self.dbapi, error)

    def rollback(self) -> None:
        try:
            await_(self._connection.rollback())
        except Exception as err:
            self._handle_exception(err)

    def commit(self) -> None:
        try:
            await_(self._connection.commit())
        except Exception as err:
            self._handle_exception(err)

    def close(self) -> None:
        # errors from close() propagate as-is, without _handle_exception()
        await_(self._connection.close())
428
429
class AsyncAdapt_terminate:
    """Mixin for a AsyncAdapt_dbapi_connection to add terminate support.

    Subclasses provide ``_terminate_graceful_close()`` (a coroutine) and
    ``_terminate_force_close()``.
    """

    __slots__ = ()

    def terminate(self) -> None:
        """Terminate the underlying connection.

        Outside a greenlet (the gc cleanup case) only a force close is
        attempted.  Inside a greenlet (the invalidated-connection case) a
        graceful close wrapped in :func:`asyncio.shield` is tried first.  If
        it fails with one of ``_terminate_handled_exceptions()``, the
        connection is force-closed, and a cancellation is re-raised
        afterwards.
        """
        if not in_greenlet():
            # not in a greenlet; this is the gc cleanup case
            self._terminate_force_close()
            return

        # in a greenlet; this is the connection was invalidated case.
        try:
            # try to gracefully close; see #10717
            await_(asyncio.shield(self._terminate_graceful_close()))
        except self._terminate_handled_exceptions() as exc:
            # when recycling an old connection that may already have been
            # disconnected, close() fails; terminate the connection without
            # any further waiting.  see issue #8419
            self._terminate_force_close()
            if isinstance(exc, asyncio.CancelledError):
                # propagate the cancellation once cleanup is done
                raise

    def _terminate_handled_exceptions(self) -> Tuple[Type[BaseException], ...]:
        """Exception types that make a failed graceful close fall back
        to a force close."""
        return (asyncio.TimeoutError, asyncio.CancelledError, OSError)

    async def _terminate_graceful_close(self) -> None:
        """Try to close connection gracefully"""
        raise NotImplementedError

    def _terminate_force_close(self) -> None:
        """Terminate the connection"""
        raise NotImplementedError
468
469
class AsyncAdapt_Error(EmulatedDBAPIException):
    """Serves as the DBAPI ``Error`` base class for dialects which have to
    emulate the DBAPI exception hierarchy.

    .. versionadded:: 2.1

    """