1# connectors/asyncio.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8"""generic asyncio-adapted versions of DBAPI connection and cursor"""
9
10from __future__ import annotations
11
import asyncio
import collections
import inspect
import sys
from typing import Any
from typing import AsyncIterator
from typing import Deque
from typing import Iterator
from typing import NoReturn
from typing import Optional
from typing import Sequence
from typing import TYPE_CHECKING

from ..engine import AdaptedConnection
from ..util.concurrency import await_fallback
from ..util.concurrency import await_only
from ..util.typing import Protocol

if TYPE_CHECKING:
    from ..engine.interfaces import _DBAPICursorDescription
    from ..engine.interfaces import _DBAPIMultiExecuteParams
    from ..engine.interfaces import _DBAPISingleExecuteParams
    from ..engine.interfaces import DBAPIModule
    from ..util.typing import Self
35
36
class AsyncIODBAPIConnection(Protocol):
    """Structural type for an asyncio-based :pep:`249` connection object.

    The cursor factory and the coroutine transaction methods are declared
    explicitly; every other attribute is read and written through the
    ``__getattr__`` / ``__setattr__`` hooks.
    """

    # close() is intentionally not declared: async DBAPIs don't agree on
    # whether it should be awaitable, so it falls through to __getattr__.

    def cursor(self, *args: Any, **kwargs: Any) -> AsyncIODBAPICursor: ...

    async def commit(self) -> None: ...

    async def rollback(self) -> None: ...

    # catch-all hooks for driver-specific attributes
    def __getattr__(self, key: str) -> Any: ...

    def __setattr__(self, key: str, value: Any) -> None: ...
56
57
class AsyncIODBAPICursor(Protocol):
    """Structural type for an asyncio-based :pep:`249` cursor object.

    Statement execution and fetching are coroutines; the cursor also
    exposes ``__aenter__`` and ``__aiter__`` hooks.
    """

    # -- plain attributes -------------------------------------------------

    arraysize: int

    lastrowid: int

    # -- read-only attributes ---------------------------------------------

    @property
    def description(
        self,
    ) -> _DBAPICursorDescription:
        """The description attribute of the Cursor."""
        ...

    @property
    def rowcount(self) -> int: ...

    # -- async protocol hooks ---------------------------------------------

    def __aenter__(self) -> Any: ...

    def __aiter__(self) -> AsyncIterator[Any]: ...

    # -- statement execution ----------------------------------------------

    async def execute(
        self,
        operation: Any,
        parameters: Optional[_DBAPISingleExecuteParams] = None,
    ) -> Any: ...

    async def executemany(
        self,
        operation: Any,
        parameters: _DBAPIMultiExecuteParams,
    ) -> Any: ...

    async def callproc(
        self, procname: str, parameters: Sequence[Any] = ...
    ) -> Any: ...

    async def nextset(self) -> Optional[bool]: ...

    # -- fetching ---------------------------------------------------------

    async def fetchone(self) -> Optional[Any]: ...

    async def fetchmany(self, size: Optional[int] = ...) -> Sequence[Any]: ...

    async def fetchall(self) -> Sequence[Any]: ...

    # -- sizing hints and lifecycle ---------------------------------------

    async def setinputsizes(self, sizes: Sequence[Any]) -> None: ...

    def setoutputsize(self, size: Any, column: Any) -> None: ...

    async def close(self) -> None: ...
112
113
class AsyncAdapt_dbapi_module:
    """Common base for objects standing in for an async DBAPI module.

    The exception classes listed under ``TYPE_CHECKING`` exist for static
    type checking only; this base class does not define them at runtime
    (presumably concrete adapters assign the real driver exceptions —
    confirm in the dialect modules).
    """

    if TYPE_CHECKING:
        Error = DBAPIModule.Error
        InterfaceError = DBAPIModule.InterfaceError
        OperationalError = DBAPIModule.OperationalError
        IntegrityError = DBAPIModule.IntegrityError

    def __getattr__(self, key: str) -> Any:
        # Typing hook: lets arbitrary module attributes type-check as Any.
        # The body is empty, so at runtime a lookup that reaches this hook
        # yields None.
        ...
122
123
class AsyncAdapt_dbapi_cursor:
    """Synchronous :pep:`249`-style facade over an async DBAPI cursor.

    Each coroutine of the wrapped driver cursor is run to completion with
    the owning connection's ``await_`` callable.  Unless ``server_side`` is
    set, the complete result set of a statement is fetched while the
    statement executes and is held in a local buffer (``_rows``).  The
    ``fetch*()`` methods and iteration then consume that buffer without
    further driver calls.
    """

    server_side = False
    __slots__ = (
        "_adapt_connection",
        "_connection",
        "await_",
        "_cursor",
        "_rows",
    )

    _cursor: AsyncIODBAPICursor
    _adapt_connection: AsyncAdapt_dbapi_connection
    _connection: AsyncIODBAPIConnection
    _rows: Deque[Any]

    def __init__(self, adapt_connection: AsyncAdapt_dbapi_connection):
        self._adapt_connection = adapt_connection
        self._connection = adapt_connection._connection

        self.await_ = adapt_connection.await_

        cursor = self._make_new_cursor(self._connection)
        self._cursor = self._aenter_cursor(cursor)

        # only buffered cursors use the local row buffer; server-side
        # cursors leave the slot unset and fetch from the driver instead
        if not self.server_side:
            self._rows = collections.deque()

    def _aenter_cursor(self, cursor: AsyncIODBAPICursor) -> AsyncIODBAPICursor:
        """Enter the driver cursor's async context and return the result.

        Nothing in this class calls the matching ``__aexit__``.
        """
        return self.await_(cursor.__aenter__())  # type: ignore[no-any-return]

    def _make_new_cursor(
        self, connection: AsyncIODBAPIConnection
    ) -> AsyncIODBAPICursor:
        """Create the raw driver cursor from *connection*."""
        return connection.cursor()

    @property
    def description(self) -> Optional[_DBAPICursorDescription]:
        """The driver cursor's ``description``."""
        return self._cursor.description

    @property
    def rowcount(self) -> int:
        """The driver cursor's ``rowcount``."""
        return self._cursor.rowcount

    @property
    def arraysize(self) -> int:
        """The driver cursor's ``arraysize``; default size for fetchmany()."""
        return self._cursor.arraysize

    @arraysize.setter
    def arraysize(self, value: int) -> None:
        self._cursor.arraysize = value

    @property
    def lastrowid(self) -> int:
        """The driver cursor's ``lastrowid``."""
        return self._cursor.lastrowid

    def close(self) -> None:
        """Discard any buffered rows.

        The driver cursor itself is not closed here; it is left for garbage
        collection.  See the notes in the aiomysql dialect.
        """
        self._rows.clear()

    def execute(
        self,
        operation: Any,
        parameters: Optional[_DBAPISingleExecuteParams] = None,
    ) -> Any:
        """Execute *operation*, buffering its rows unless server-side.

        Driver errors are passed to the connection's ``_handle_exception``.
        """
        try:
            return self.await_(self._execute_async(operation, parameters))
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def executemany(
        self,
        operation: Any,
        seq_of_parameters: _DBAPIMultiExecuteParams,
    ) -> Any:
        """Execute *operation* once per parameter set.

        Driver errors are passed to the connection's ``_handle_exception``.
        """
        try:
            return self.await_(
                self._executemany_async(operation, seq_of_parameters)
            )
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    async def _buffer_result_rows(self) -> None:
        """Replace the row buffer with the driver's current result set.

        When the current statement or result set has no description (no
        rows), the buffer is emptied.  Otherwise unread rows from an earlier
        statement or result set would be returned by later fetch calls.
        """
        if self._cursor.description:
            self._rows = collections.deque(await self._cursor.fetchall())
        else:
            self._rows = collections.deque()

    async def _execute_async(
        self, operation: Any, parameters: Optional[_DBAPISingleExecuteParams]
    ) -> Any:
        async with self._adapt_connection._execute_mutex:
            if parameters is None:
                result = await self._cursor.execute(operation)
            else:
                result = await self._cursor.execute(operation, parameters)

            # fetch*() are synchronous, so a buffered cursor must pull the
            # whole result set now
            if not self.server_side:
                await self._buffer_result_rows()
            return result

    async def _executemany_async(
        self,
        operation: Any,
        seq_of_parameters: _DBAPIMultiExecuteParams,
    ) -> Any:
        async with self._adapt_connection._execute_mutex:
            return await self._cursor.executemany(operation, seq_of_parameters)

    def nextset(self) -> Optional[bool]:
        """Advance to the next result set, as in :pep:`249`.

        Returns the driver's result: per PEP 249, None when there are no
        more sets and a true value otherwise.  A buffered cursor reloads its
        row buffer from the new set.  Driver errors are passed to the
        connection's ``_handle_exception``.
        """
        try:
            has_next = self.await_(self._cursor.nextset())
            if not self.server_side:
                self.await_(self._buffer_result_rows())
            return has_next  # type: ignore[no-any-return]
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def setinputsizes(self, *inputsizes: Any) -> None:
        # NOTE: this is overridden in aioodbc due to
        # https://github.com/aio-libs/aioodbc/issues/451
        # right now

        return self.await_(self._cursor.setinputsizes(*inputsizes))

    def __enter__(self) -> Self:
        return self

    def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
        self.close()

    def __iter__(self) -> Iterator[Any]:
        # consumes the buffer: rows are removed as they are yielded
        while self._rows:
            yield self._rows.popleft()

    def fetchone(self) -> Optional[Any]:
        """Pop the next buffered row, or return None if none remain."""
        if self._rows:
            return self._rows.popleft()
        else:
            return None

    def fetchmany(self, size: Optional[int] = None) -> Sequence[Any]:
        """Pop up to *size* buffered rows; *size* defaults to ``arraysize``."""
        if size is None:
            size = self.arraysize
        rr = self._rows
        return [rr.popleft() for _ in range(min(size, len(rr)))]

    def fetchall(self) -> Sequence[Any]:
        """Return every remaining buffered row and empty the buffer."""
        retval = list(self._rows)
        self._rows.clear()
        return retval
267
268
class AsyncAdapt_dbapi_ss_cursor(AsyncAdapt_dbapi_cursor):
    """Server-side (streaming) variant of :class:`.AsyncAdapt_dbapi_cursor`.

    Rows are not buffered locally.  Each fetch call is forwarded to the
    driver cursor and awaited.
    """

    __slots__ = ()
    server_side = True

    def close(self) -> None:
        """Close the driver cursor; later calls do nothing."""
        if self._cursor is not None:
            self.await_(self._cursor.close())
            self._cursor = None  # type: ignore

    def fetchone(self) -> Optional[Any]:
        return self.await_(self._cursor.fetchone())

    def fetchmany(self, size: Optional[int] = None) -> Any:
        """Fetch up to *size* rows from the driver.

        *size* defaults to ``arraysize``, as in the buffered base class.
        PEP 249 defines the default as ``cursor.arraysize`` but does not
        specify ``None`` as an accepted argument, so ``None`` is resolved
        here rather than passed through to the driver.
        """
        if size is None:
            size = self.arraysize
        return self.await_(self._cursor.fetchmany(size=size))

    def fetchall(self) -> Sequence[Any]:
        return self.await_(self._cursor.fetchall())

    def __iter__(self) -> Iterator[Any]:
        # drive the driver's async iterator one row at a time
        iterator = self._cursor.__aiter__()
        while True:
            try:
                yield self.await_(iterator.__anext__())
            except StopAsyncIteration:
                break
294
295
class AsyncAdapt_dbapi_connection(AdaptedConnection):
    """Synchronous :pep:`249`-style facade over an async DBAPI connection.

    Driver coroutines are run to completion with ``await_`` (``await_only``
    on this class).  Driver errors from commit, rollback and close are
    passed to :meth:`_handle_exception`.
    """

    _cursor_cls = AsyncAdapt_dbapi_cursor
    _ss_cursor_cls = AsyncAdapt_dbapi_ss_cursor

    await_ = staticmethod(await_only)

    __slots__ = ("dbapi", "_execute_mutex")

    _connection: AsyncIODBAPIConnection

    def __init__(self, dbapi: Any, connection: AsyncIODBAPIConnection):
        self.dbapi = dbapi
        self._connection = connection
        # held by the cursors around execute()/executemany() on the shared
        # driver connection
        self._execute_mutex = asyncio.Lock()

    def cursor(self, server_side: bool = False) -> AsyncAdapt_dbapi_cursor:
        """Return a new buffered cursor, or a streaming one if *server_side*."""
        if server_side:
            return self._ss_cursor_cls(self)
        else:
            return self._cursor_cls(self)

    def execute(
        self,
        operation: Any,
        parameters: Optional[_DBAPISingleExecuteParams] = None,
    ) -> Any:
        """lots of DBAPIs seem to provide this, so include it

        Executes *operation* on a new buffered cursor and returns the cursor.
        """
        cursor = self.cursor()
        cursor.execute(operation, parameters)
        return cursor

    def _handle_exception(self, error: Exception) -> NoReturn:
        """Re-raise *error* with the traceback of the exception being handled."""
        exc_info = sys.exc_info()

        raise error.with_traceback(exc_info[2])

    def rollback(self) -> None:
        try:
            self.await_(self._connection.rollback())
        except Exception as error:
            self._handle_exception(error)

    def commit(self) -> None:
        try:
            self.await_(self._connection.commit())
        except Exception as error:
            self._handle_exception(error)

    def close(self) -> None:
        """Close the driver connection.

        Async DBAPIs disagree on whether ``close()`` is a coroutine.  Its
        result is therefore awaited only when it is actually awaitable, and
        a synchronous ``close()`` is simply called.  Errors are passed to
        :meth:`_handle_exception`, as for commit and rollback.
        """
        try:
            result = self._connection.close()
            if inspect.isawaitable(result):
                self.await_(result)
        except Exception as error:
            self._handle_exception(error)
346
347
class AsyncAdaptFallback_dbapi_connection(AsyncAdapt_dbapi_connection):
    """Variant of :class:`.AsyncAdapt_dbapi_connection` that runs driver
    coroutines with ``await_fallback`` instead of ``await_only``.
    """

    await_ = staticmethod(await_fallback)

    __slots__ = ()