1# connectors/asyncio.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8"""generic asyncio-adapted versions of DBAPI connection and cursor"""
9
10from __future__ import annotations
11
import asyncio
import collections
import inspect
import sys
from typing import Any
from typing import AsyncIterator
from typing import Deque
from typing import Iterator
from typing import NoReturn
from typing import Optional
from typing import Protocol
from typing import Sequence
from typing import TYPE_CHECKING
24
25from ..engine import AdaptedConnection
26from ..util.concurrency import await_
27
28if TYPE_CHECKING:
29 from ..engine.interfaces import _DBAPICursorDescription
30 from ..engine.interfaces import _DBAPIMultiExecuteParams
31 from ..engine.interfaces import _DBAPISingleExecuteParams
32 from ..engine.interfaces import DBAPIModule
33 from ..util.typing import Self
34
35
class AsyncIODBAPIConnection(Protocol):
    """Structural type for an async-adapted :pep:`249` connection.

    Only the members this module calls directly are declared: a synchronous
    ``cursor()`` factory plus awaitable ``commit()`` and ``rollback()``.
    Every other attribute is typed as ``Any`` via the ``__getattr__`` /
    ``__setattr__`` stubs.
    """

    # close() is intentionally not declared: async DBAPIs don't agree on
    # whether it should be awaitable, so it falls through to __getattr__.

    def cursor(self, *args: Any, **kwargs: Any) -> AsyncIODBAPICursor: ...

    async def commit(self) -> None: ...

    async def rollback(self) -> None: ...

    def __getattr__(self, key: str) -> Any: ...

    def __setattr__(self, key: str, value: Any) -> None: ...
55
56
class AsyncIODBAPICursor(Protocol):
    """Structural type for an async-adapted :pep:`249` cursor.

    Follows the PEP 249 cursor surface, except that most operations
    (``execute()``, ``fetch*()``, ``close()``, ``nextset()``, ...) are
    coroutines; ``setoutputsize()`` stays synchronous.  The cursor also
    exposes ``__aenter__`` and ``__aiter__`` for async context-manager and
    async-iteration use.
    """

    # ---- plain attributes / read-only properties ----

    arraysize: int

    lastrowid: int

    @property
    def description(
        self,
    ) -> _DBAPICursorDescription:
        """The description attribute of the Cursor."""
        ...

    @property
    def rowcount(self) -> int: ...

    # ---- async protocol hooks ----

    def __aenter__(self) -> Any: ...

    def __aiter__(self) -> AsyncIterator[Any]: ...

    # ---- statement execution ----

    async def execute(
        self,
        operation: Any,
        parameters: Optional[_DBAPISingleExecuteParams] = None,
    ) -> Any: ...

    async def executemany(
        self,
        operation: Any,
        parameters: _DBAPIMultiExecuteParams,
    ) -> Any: ...

    async def callproc(
        self, procname: str, parameters: Sequence[Any] = ...
    ) -> Any: ...

    async def nextset(self) -> Optional[bool]: ...

    async def setinputsizes(self, sizes: Sequence[Any]) -> None: ...

    def setoutputsize(self, size: Any, column: Any) -> None: ...

    # ---- result fetching ----

    async def fetchone(self) -> Optional[Any]: ...

    async def fetchmany(self, size: Optional[int] = ...) -> Sequence[Any]: ...

    async def fetchall(self) -> Sequence[Any]: ...

    # ---- teardown ----

    async def close(self) -> None: ...
111
112
class AsyncAdapt_dbapi_module:
    """Base for objects standing in for an async driver's DBAPI module.

    Everything declared in this class body is for static type checking
    only; at runtime the body defines nothing.  Presumably subclasses set
    the real exception classes and other module-level attributes
    themselves.  Accessing an attribute that was never set raises
    :class:`AttributeError` as usual.
    """

    if TYPE_CHECKING:
        Error = DBAPIModule.Error
        OperationalError = DBAPIModule.OperationalError
        InterfaceError = DBAPIModule.InterfaceError
        IntegrityError = DBAPIModule.IntegrityError

        # Typing-only stub so type checkers accept arbitrary attribute
        # access on adapted DBAPI modules.  It must not exist at runtime:
        # the ``...`` body returns None, which would make every missing
        # attribute silently evaluate to None and break hasattr() /
        # getattr(obj, name, default) probing.
        def __getattr__(self, key: str) -> Any: ...
121
122
class AsyncAdapt_dbapi_cursor:
    """Synchronous, :pep:`249`-style facade over an async driver cursor.

    Blocking-style methods bridge to the driver's coroutine API through
    ``await_()``.  In this client-side variant (``server_side = False``),
    ``execute()`` and ``nextset()`` eagerly fetch the whole current result
    set into an in-memory deque; ``fetch*()`` and iteration then consume
    that buffer without further driver calls.

    Driver exceptions raised during cursor setup, ``execute()``,
    ``executemany()`` and ``nextset()`` are passed to the owning
    connection's ``_handle_exception()``, which re-raises them.
    """

    server_side = False
    __slots__ = (
        "_adapt_connection",
        "_connection",
        "_cursor",
        "_rows",
    )

    _cursor: AsyncIODBAPICursor
    _adapt_connection: AsyncAdapt_dbapi_connection
    _connection: AsyncIODBAPIConnection
    _rows: Deque[Any]

    def __init__(self, adapt_connection: AsyncAdapt_dbapi_connection):
        self._adapt_connection = adapt_connection
        self._connection = adapt_connection._connection

        cursor = self._make_new_cursor(self._connection)
        self._cursor = self._aenter_cursor(cursor)

        # only the client-side variant buffers rows; server-side cursors
        # read straight from the driver and never touch ``_rows``
        if not self.server_side:
            self._rows = collections.deque()

    def _aenter_cursor(self, cursor: AsyncIODBAPICursor) -> AsyncIODBAPICursor:
        """Await ``cursor.__aenter__()`` and return what it produces.

        Driver errors are routed through the connection's
        ``_handle_exception()``, which never returns normally.
        """
        try:
            return await_(cursor.__aenter__())  # type: ignore[no-any-return]
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def _make_new_cursor(
        self, connection: AsyncIODBAPIConnection
    ) -> AsyncIODBAPICursor:
        """Create a new driver cursor from *connection*.

        Kept as a separate method, presumably so subclasses can customize
        how the driver cursor is created.
        """
        return connection.cursor()

    @property
    def description(self) -> Optional[_DBAPICursorDescription]:
        return self._cursor.description

    @property
    def rowcount(self) -> int:
        return self._cursor.rowcount

    @property
    def arraysize(self) -> int:
        return self._cursor.arraysize

    @arraysize.setter
    def arraysize(self, value: int) -> None:
        self._cursor.arraysize = value

    @property
    def lastrowid(self) -> int:
        return self._cursor.lastrowid

    def close(self) -> None:
        """Discard any buffered rows.

        The driver cursor itself is not closed here; it is left to garbage
        collection (see notes in the aiomysql dialect).
        """
        self._rows.clear()

    def execute(
        self,
        operation: Any,
        parameters: Optional[_DBAPISingleExecuteParams] = None,
    ) -> Any:
        """Execute *operation* and return the driver's ``execute()`` result.

        For client-side cursors, the statement's result rows (if any) are
        buffered before this returns.
        """
        try:
            return await_(self._execute_async(operation, parameters))
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def executemany(
        self,
        operation: Any,
        seq_of_parameters: _DBAPIMultiExecuteParams,
    ) -> Any:
        """Execute *operation* once per parameter set via the driver."""
        try:
            return await_(
                self._executemany_async(operation, seq_of_parameters)
            )
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    async def _execute_async(
        self, operation: Any, parameters: Optional[_DBAPISingleExecuteParams]
    ) -> Any:
        async with self._adapt_connection._execute_mutex:
            # only pass parameters through when given, so drivers that
            # distinguish "no parameters" from "empty parameters" behave
            if parameters is None:
                result = await self._cursor.execute(operation)
            else:
                result = await self._cursor.execute(operation, parameters)

            await self._buffer_rows_async()
            return result

    async def _executemany_async(
        self,
        operation: Any,
        seq_of_parameters: _DBAPIMultiExecuteParams,
    ) -> Any:
        async with self._adapt_connection._execute_mutex:
            return await self._cursor.executemany(operation, seq_of_parameters)

    async def _buffer_rows_async(self) -> None:
        """Load the driver's current result set into ``_rows``.

        No-op for server-side cursors.  When the driver cursor has no
        current result set (falsy ``description``), the buffer is reset to
        empty so rows left over from a previous statement or result set
        are not returned by later ``fetch*()`` calls.
        """
        if self.server_side:
            return
        if self._cursor.description:
            self._rows = collections.deque(await self._cursor.fetchall())
        else:
            self._rows = collections.deque()

    def nextset(self) -> Optional[bool]:
        """Advance the driver cursor to its next result set.

        Returns the driver's ``nextset()`` value unchanged; per :pep:`249`
        that is a true value when another set is available and ``None``
        otherwise.  For client-side cursors the new set is buffered.
        """
        try:
            return await_(self._nextset_async())
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    async def _nextset_async(self) -> Optional[bool]:
        has_next = await self._cursor.nextset()
        await self._buffer_rows_async()
        return has_next

    def setinputsizes(self, *inputsizes: Any) -> None:
        # NOTE: this is overridden in aioodbc, see
        # https://github.com/aio-libs/aioodbc/issues/451
        return await_(self._cursor.setinputsizes(*inputsizes))

    def __enter__(self) -> Self:
        return self

    def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
        self.close()

    def __iter__(self) -> Iterator[Any]:
        # consumes the buffer: each yielded row is removed from ``_rows``
        while self._rows:
            yield self._rows.popleft()

    def fetchone(self) -> Optional[Any]:
        if self._rows:
            return self._rows.popleft()
        else:
            return None

    def fetchmany(self, size: Optional[int] = None) -> Sequence[Any]:
        if size is None:
            size = self.arraysize
        rr = self._rows
        return [rr.popleft() for _ in range(min(size, len(rr)))]

    def fetchall(self) -> Sequence[Any]:
        retval = list(self._rows)
        self._rows.clear()
        return retval
264
265
class AsyncAdapt_dbapi_ss_cursor(AsyncAdapt_dbapi_cursor):
    """Server-side (streaming) variant of the adapted cursor.

    Rows are not buffered at execute time (``server_side = True``); every
    fetch or iteration step awaits the driver cursor directly.  Driver
    exceptions raised while fetching or closing are routed through the
    connection's ``_handle_exception()``, the same way ``execute()``
    handles them, so any error translation applies to streamed results
    too.
    """

    __slots__ = ()
    server_side = True

    def close(self) -> None:
        """Close the driver cursor once; later calls are no-ops.

        The driver cursor reference is dropped even if its ``close()``
        raises, so a failed close is not retried on the next call.
        """
        if self._cursor is not None:
            try:
                await_(self._cursor.close())
            except Exception as error:
                self._adapt_connection._handle_exception(error)
            finally:
                self._cursor = None  # type: ignore

    def fetchone(self) -> Optional[Any]:
        try:
            return await_(self._cursor.fetchone())
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def fetchmany(self, size: Optional[int] = None) -> Any:
        try:
            return await_(self._cursor.fetchmany(size=size))
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def fetchall(self) -> Sequence[Any]:
        try:
            return await_(self._cursor.fetchall())
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def __iter__(self) -> Iterator[Any]:
        iterator = self._cursor.__aiter__()
        while True:
            try:
                row = await_(iterator.__anext__())
            except StopAsyncIteration:
                # normal end of the driver's async iteration
                break
            except Exception as error:
                self._adapt_connection._handle_exception(error)
            # yield outside the try so exceptions thrown into this
            # generator are not mistaken for driver errors
            yield row
291
292
class AsyncAdapt_dbapi_connection(AdaptedConnection):
    """Synchronous, :pep:`249`-style facade over an async driver connection.

    Methods bridge to the driver's coroutine API through ``await_()``.
    Cursors are built from ``_cursor_cls`` (buffered) or ``_ss_cursor_cls``
    (server-side).  ``_execute_mutex`` is an :class:`asyncio.Lock` created
    per connection; the adapted cursors' execute paths acquire it around
    driver calls.
    """

    _cursor_cls = AsyncAdapt_dbapi_cursor
    _ss_cursor_cls = AsyncAdapt_dbapi_ss_cursor

    __slots__ = ("dbapi", "_execute_mutex")

    _connection: AsyncIODBAPIConnection

    def __init__(self, dbapi: Any, connection: AsyncIODBAPIConnection):
        self.dbapi = dbapi
        self._connection = connection
        self._execute_mutex = asyncio.Lock()

    def cursor(self, server_side: bool = False) -> AsyncAdapt_dbapi_cursor:
        """Return a new adapted cursor, server-side if *server_side*."""
        if server_side:
            return self._ss_cursor_cls(self)
        else:
            return self._cursor_cls(self)

    def execute(
        self,
        operation: Any,
        parameters: Optional[_DBAPISingleExecuteParams] = None,
    ) -> Any:
        """lots of DBAPIs seem to provide this, so include it"""
        cursor = self.cursor()
        cursor.execute(operation, parameters)
        return cursor

    def _handle_exception(self, error: Exception) -> NoReturn:
        """Re-raise *error* carrying the traceback currently being handled.

        Uses ``sys.exc_info()``, so it is meant to be called from inside an
        ``except`` block.  Subclasses may override it to translate driver
        exceptions; it must never return normally.
        """
        exc_info = sys.exc_info()

        raise error.with_traceback(exc_info[2])

    def rollback(self) -> None:
        try:
            await_(self._connection.rollback())
        except Exception as error:
            self._handle_exception(error)

    def commit(self) -> None:
        try:
            await_(self._connection.commit())
        except Exception as error:
            self._handle_exception(error)

    def close(self) -> None:
        """Close the underlying driver connection.

        Async DBAPIs don't agree on whether ``close()`` is a coroutine, so
        its result is awaited only when it is awaitable.  Errors are routed
        through ``_handle_exception()`` like ``commit()`` / ``rollback()``.
        """
        try:
            result = self._connection.close()
            if inspect.isawaitable(result):
                await_(result)
        except Exception as error:
            self._handle_exception(error)