1# connectors/asyncio.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8"""generic asyncio-adapted versions of DBAPI connection and cursor"""
9
10from __future__ import annotations
11
12import asyncio
13import collections
14import sys
15from typing import Any
16from typing import AsyncIterator
17from typing import Deque
18from typing import Iterator
19from typing import NoReturn
20from typing import Optional
21from typing import Protocol
22from typing import Sequence
23from typing import TYPE_CHECKING
24
25from ..engine import AdaptedConnection
26from ..util import EMPTY_DICT
27from ..util.concurrency import await_
28from ..util.concurrency import in_greenlet
29
30if TYPE_CHECKING:
31 from ..engine.interfaces import _DBAPICursorDescription
32 from ..engine.interfaces import _DBAPIMultiExecuteParams
33 from ..engine.interfaces import _DBAPISingleExecuteParams
34 from ..engine.interfaces import DBAPIModule
35 from ..util.typing import Self
36
37
class AsyncIODBAPIConnection(Protocol):
    """Structural type describing an asyncio driver's :pep:`249`-style
    connection object.

    Only the members the adapter layer calls directly are declared; any
    other attribute read or write is typed as ``Any`` through the
    ``__getattr__`` / ``__setattr__`` declarations.
    """

    # close() is deliberately not declared: async DBAPIs don't agree on
    # whether it should be awaitable, so it is picked up by __getattr__.

    def cursor(self, *args: Any, **kwargs: Any) -> AsyncIODBAPICursor: ...

    async def commit(self) -> None: ...

    async def rollback(self) -> None: ...

    def __getattr__(self, key: str) -> Any: ...

    def __setattr__(self, key: str, value: Any) -> None: ...
57
58
class AsyncIODBAPICursor(Protocol):
    """Structural type describing an asyncio driver's :pep:`249`-style
    cursor object.

    Statement execution, fetching, ``close`` and friends are declared as
    coroutines; ``setoutputsize`` and the async-context / async-iteration
    entry points are declared as plain methods.
    """

    # plain attributes
    arraysize: int
    lastrowid: int

    # read-only attributes
    @property
    def description(
        self,
    ) -> _DBAPICursorDescription:
        """The description attribute of the Cursor."""
        ...

    @property
    def rowcount(self) -> int: ...

    # non-coroutine entry points
    def __aenter__(self) -> Any: ...

    def __aiter__(self) -> AsyncIterator[Any]: ...

    def setoutputsize(self, size: Any, column: Any) -> None: ...

    # coroutine methods
    async def close(self) -> None: ...

    async def execute(
        self,
        operation: Any,
        parameters: Optional[_DBAPISingleExecuteParams] = None,
    ) -> Any: ...

    async def executemany(
        self,
        operation: Any,
        parameters: _DBAPIMultiExecuteParams,
    ) -> Any: ...

    async def callproc(
        self, procname: str, parameters: Sequence[Any] = ...
    ) -> Any: ...

    async def setinputsizes(self, sizes: Sequence[Any]) -> None: ...

    async def fetchone(self) -> Optional[Any]: ...

    async def fetchmany(self, size: Optional[int] = ...) -> Sequence[Any]: ...

    async def fetchall(self) -> Sequence[Any]: ...

    async def nextset(self) -> Optional[bool]: ...
113
114
class AsyncAdapt_dbapi_module:
    """Base class for adapter objects that stand in for a :pep:`249`
    DBAPI module (presumably subclassed per asyncio driver -- confirm).

    Everything in the body is visible to static type checkers only: the
    standard exception classes are typed after ``DBAPIModule`` and any
    other attribute is typed as ``Any``.  At runtime nothing is defined
    here, so an attribute a subclass does not provide raises
    ``AttributeError`` instead of silently evaluating to ``None``.
    """

    if TYPE_CHECKING:
        Error = DBAPIModule.Error
        OperationalError = DBAPIModule.OperationalError
        InterfaceError = DBAPIModule.InterfaceError
        IntegrityError = DBAPIModule.IntegrityError

        # typing-only fallback; a runtime __getattr__ with an empty body
        # would return None for every missing attribute and make
        # hasattr() checks against the adapted module always succeed
        def __getattr__(self, key: str) -> Any: ...
123
124
class AsyncAdapt_dbapi_cursor:
    """Synchronous :pep:`249` cursor facade over an asyncio driver cursor.

    Each driver coroutine is run to completion with ``await_()``.  Unless
    ``server_side`` is set, all rows of a row-returning statement are
    fetched into an in-memory deque as part of ``execute()`` /
    ``nextset()``; ``fetchone()``, ``fetchmany()``, ``fetchall()`` and
    iteration then consume that buffer without calling the driver.
    """

    server_side = False
    __slots__ = (
        "_adapt_connection",
        "_connection",
        "_cursor",
        "_rows",
        "_soft_closed_memoized",
    )

    # True when the driver cursor's close() returns an awaitable; when
    # False, close() below calls it synchronously.
    _awaitable_cursor_close: bool = True

    _cursor: AsyncIODBAPICursor
    _adapt_connection: AsyncAdapt_dbapi_connection
    _connection: AsyncIODBAPIConnection
    _rows: Deque[Any]

    def __init__(self, adapt_connection: AsyncAdapt_dbapi_connection):
        self._adapt_connection = adapt_connection
        self._connection = adapt_connection._connection

        cursor = self._make_new_cursor(self._connection)
        self._cursor = self._aenter_cursor(cursor)
        # attributes captured by _async_soft_close(); empty until then
        self._soft_closed_memoized = EMPTY_DICT
        # server-side cursors read from the driver directly; no buffer
        if not self.server_side:
            self._rows = collections.deque()

    def _aenter_cursor(self, cursor: AsyncIODBAPICursor) -> AsyncIODBAPICursor:
        """Enter *cursor*'s async context and return what it yields.

        Errors are routed through the connection's ``_handle_exception()``.
        """
        try:
            return await_(cursor.__aenter__())  # type: ignore[no-any-return]
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def _make_new_cursor(
        self, connection: AsyncIODBAPIConnection
    ) -> AsyncIODBAPICursor:
        """Return a new driver cursor obtained from *connection*."""
        return connection.cursor()

    @property
    def description(self) -> Optional[_DBAPICursorDescription]:
        """The cursor description.

        After ``_async_soft_close()`` this is the value memoized at that
        point rather than the live driver attribute.
        """
        if "description" in self._soft_closed_memoized:
            return self._soft_closed_memoized["description"]  # type: ignore[no-any-return] # noqa: E501
        return self._cursor.description

    @property
    def rowcount(self) -> int:
        return self._cursor.rowcount

    @property
    def arraysize(self) -> int:
        return self._cursor.arraysize

    @arraysize.setter
    def arraysize(self, value: int) -> None:
        self._cursor.arraysize = value

    @property
    def lastrowid(self) -> int:
        return self._cursor.lastrowid

    async def _async_soft_close(self) -> None:
        """close the cursor but keep the results pending, and memoize the
        description.

        Does nothing for server-side cursors or when the driver's close()
        is not awaitable.

        .. versionadded:: 2.0.44

        """

        if not self._awaitable_cursor_close or self.server_side:
            return

        self._soft_closed_memoized = self._soft_closed_memoized.union(
            {
                "description": self._cursor.description,
            }
        )
        await self._cursor.close()

    def close(self) -> None:
        """Discard buffered rows and close the driver cursor when possible.

        The driver cursor is closed synchronously when its close() is not
        awaitable, awaited when running inside a greenlet, and otherwise
        left open.  Nothing further is done if it was already soft-closed.
        """
        self._rows.clear()

        # updated as of 2.0.44
        # try to "close" the cursor based on what we know about the driver
        # and if we are able to. otherwise, hope that the asyncio
        # extension called _async_soft_close() if the cursor is going into
        # a sync context
        if self._cursor is None or bool(self._soft_closed_memoized):
            return

        if not self._awaitable_cursor_close:
            self._cursor.close()  # type: ignore[unused-coroutine]
        elif in_greenlet():
            await_(self._cursor.close())

    def execute(
        self,
        operation: Any,
        parameters: Optional[_DBAPISingleExecuteParams] = None,
    ) -> Any:
        """Execute *operation* and return the driver's result.

        Errors are routed through the connection's ``_handle_exception()``.
        """
        try:
            return await_(self._execute_async(operation, parameters))
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def executemany(
        self,
        operation: Any,
        seq_of_parameters: _DBAPIMultiExecuteParams,
    ) -> Any:
        """Execute *operation* once per parameter set; return the driver's
        result.

        Errors are routed through the connection's ``_handle_exception()``.
        """
        try:
            return await_(
                self._executemany_async(operation, seq_of_parameters)
            )
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    async def _execute_async(
        self, operation: Any, parameters: Optional[_DBAPISingleExecuteParams]
    ) -> Any:
        async with self._adapt_connection._execute_mutex:
            if parameters is None:
                result = await self._cursor.execute(operation)
            else:
                result = await self._cursor.execute(operation, parameters)

            # buffer every row while still holding the connection's mutex
            if self._cursor.description and not self.server_side:
                self._rows = collections.deque(await self._cursor.fetchall())
            return result

    async def _executemany_async(
        self,
        operation: Any,
        seq_of_parameters: _DBAPIMultiExecuteParams,
    ) -> Any:
        async with self._adapt_connection._execute_mutex:
            return await self._cursor.executemany(operation, seq_of_parameters)

    def nextset(self) -> Optional[bool]:
        """Advance the driver cursor to its next result set.

        For buffered (non server-side) cursors, the rows of the new result
        set are fetched into the buffer.  Rows left over from the previous
        result set are always discarded, including when the new set
        returns no rows, as :pep:`249` specifies for ``nextset()``.

        :return: the driver's ``nextset()`` value; per :pep:`249` a true
         value when another result set is available, else ``None``.

        Errors are routed through the connection's ``_handle_exception()``.
        """
        try:
            has_next = await_(self._cursor.nextset())
            if not self.server_side:
                if self._cursor.description:
                    self._rows = collections.deque(
                        await_(self._cursor.fetchall())
                    )
                else:
                    # no row-returning result here; don't keep serving the
                    # previous result set's unconsumed rows
                    self._rows.clear()
            return has_next
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def setinputsizes(self, *inputsizes: Any) -> None:
        # NOTE: this is overridden in aioodbc right now; see
        # https://github.com/aio-libs/aioodbc/issues/451

        return await_(self._cursor.setinputsizes(*inputsizes))

    def __enter__(self) -> Self:
        return self

    def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
        self.close()

    def __iter__(self) -> Iterator[Any]:
        """Yield and remove buffered rows until the buffer is empty."""
        while self._rows:
            yield self._rows.popleft()

    def fetchone(self) -> Optional[Any]:
        """Remove and return the next buffered row, or ``None`` if empty."""
        if self._rows:
            return self._rows.popleft()
        else:
            return None

    def fetchmany(self, size: Optional[int] = None) -> Sequence[Any]:
        """Remove and return up to *size* buffered rows (default:
        ``arraysize``)."""
        if size is None:
            size = self.arraysize
        rr = self._rows
        return [rr.popleft() for _ in range(min(size, len(rr)))]

    def fetchall(self) -> Sequence[Any]:
        """Remove and return every remaining buffered row."""
        retval = list(self._rows)
        self._rows.clear()
        return retval
300
301
class AsyncAdapt_dbapi_ss_cursor(AsyncAdapt_dbapi_cursor):
    """Server-side variant of the adapted cursor.

    With ``server_side`` set, no row buffer is created or filled; every
    fetch awaits the driver cursor directly.
    """

    __slots__ = ()
    server_side = True

    def close(self) -> None:
        """Await the driver cursor's close() once, then drop the reference.

        Calling this again after a successful close does nothing.
        """
        # NOTE(review): unlike the buffered cursor, this always awaits
        # close() regardless of _awaitable_cursor_close -- confirm intended.
        if self._cursor is None:
            return
        await_(self._cursor.close())
        self._cursor = None  # type: ignore

    def fetchone(self) -> Optional[Any]:
        next_row = await_(self._cursor.fetchone())
        return next_row

    def fetchmany(self, size: Optional[int] = None) -> Any:
        batch = await_(self._cursor.fetchmany(size=size))
        return batch

    def fetchall(self) -> Sequence[Any]:
        remaining = await_(self._cursor.fetchall())
        return remaining

    def __iter__(self) -> Iterator[Any]:
        """Stream rows by driving the driver cursor's async iterator."""
        async_rows = self._cursor.__aiter__()
        while True:
            try:
                yield await_(async_rows.__anext__())
            except StopAsyncIteration:
                return
327
328
class AsyncAdapt_dbapi_connection(AdaptedConnection):
    """Synchronous :pep:`249` connection facade over an asyncio driver
    connection.

    Driver coroutines are run to completion with ``await_()``.  Cursors
    created here receive this object, and the buffered cursor's execute
    path serializes on ``_execute_mutex``.
    """

    _cursor_cls = AsyncAdapt_dbapi_cursor
    _ss_cursor_cls = AsyncAdapt_dbapi_ss_cursor

    __slots__ = ("dbapi", "_execute_mutex")

    _connection: AsyncIODBAPIConnection

    def __init__(self, dbapi: Any, connection: AsyncIODBAPIConnection):
        self.dbapi = dbapi
        self._connection = connection
        self._execute_mutex = asyncio.Lock()

    def cursor(self, server_side: bool = False) -> AsyncAdapt_dbapi_cursor:
        """Return a new adapted cursor; server-side when *server_side*."""
        factory = self._ss_cursor_cls if server_side else self._cursor_cls
        return factory(self)

    def execute(
        self,
        operation: Any,
        parameters: Optional[_DBAPISingleExecuteParams] = None,
    ) -> Any:
        """lots of DBAPIs seem to provide this, so include it"""
        new_cursor = self.cursor()
        new_cursor.execute(operation, parameters)
        return new_cursor

    def _handle_exception(self, error: Exception) -> NoReturn:
        """Re-raise *error* carrying the traceback of the exception
        currently being handled."""
        _, _, active_tb = sys.exc_info()
        raise error.with_traceback(active_tb)

    def rollback(self) -> None:
        try:
            await_(self._connection.rollback())
        except Exception as err:
            self._handle_exception(err)

    def commit(self) -> None:
        try:
            await_(self._connection.commit())
        except Exception as err:
            self._handle_exception(err)

    def close(self) -> None:
        await_(self._connection.close())