1# connectors/asyncio.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8"""generic asyncio-adapted versions of DBAPI connection and cursor"""
9
10from __future__ import annotations
11
12import asyncio
13import collections
14import sys
15from typing import Any
16from typing import AsyncIterator
17from typing import Deque
18from typing import Iterator
19from typing import NoReturn
20from typing import Optional
21from typing import Protocol
22from typing import Sequence
23from typing import Tuple
24from typing import Type
25from typing import TYPE_CHECKING
26
27from ..engine import AdaptedConnection
28from ..util import EMPTY_DICT
29from ..util.concurrency import await_
30from ..util.concurrency import in_greenlet
31
32if TYPE_CHECKING:
33 from ..engine.interfaces import _DBAPICursorDescription
34 from ..engine.interfaces import _DBAPIMultiExecuteParams
35 from ..engine.interfaces import _DBAPISingleExecuteParams
36 from ..engine.interfaces import DBAPIModule
37 from ..util.typing import Self
38
39
class AsyncIODBAPIConnection(Protocol):
    """Structural type for an asyncio-adapted :pep:`249` connection.

    ``commit()`` and ``rollback()`` are coroutines, while ``cursor()`` is a
    plain call returning an :class:`.AsyncIODBAPICursor`.  Anything else is
    reached through the permissive ``__getattr__`` / ``__setattr__`` hooks.
    """

    # close() is deliberately not declared here: async DBAPIs don't agree on
    # whether it should be awaitable, so it is picked up through the
    # __getattr__ hook below instead.

    def cursor(self, *args: Any, **kwargs: Any) -> AsyncIODBAPICursor: ...

    async def commit(self) -> None: ...

    async def rollback(self) -> None: ...

    # -- catch-all attribute access ---------------------------------------

    def __getattr__(self, key: str) -> Any: ...

    def __setattr__(self, key: str, value: Any) -> None: ...
59
60
class AsyncIODBAPICursor(Protocol):
    """Structural type for an asyncio-adapted :pep:`249` cursor.

    Execution, fetching, ``close()``, ``setinputsizes()``, ``callproc()``
    and ``nextset()`` are coroutines; the attributes, ``__aenter__()``,
    ``__aiter__()`` and ``setoutputsize()`` are declared as plain
    (non-``async``) members.
    """

    # -- attributes ---------------------------------------------------------

    arraysize: int

    lastrowid: int

    @property
    def description(
        self,
    ) -> _DBAPICursorDescription:
        """Result-column description exposed by the driver's cursor."""
        ...

    @property
    def rowcount(self) -> int: ...

    # -- async context manager / async iteration -------------------------

    def __aenter__(self) -> Any: ...

    def __aiter__(self) -> AsyncIterator[Any]: ...

    # -- statement execution ------------------------------------------------

    async def execute(
        self,
        operation: Any,
        parameters: Optional[_DBAPISingleExecuteParams] = None,
    ) -> Any: ...

    async def executemany(
        self,
        operation: Any,
        parameters: _DBAPIMultiExecuteParams,
    ) -> Any: ...

    async def callproc(
        self, procname: str, parameters: Sequence[Any] = ...
    ) -> Any: ...

    async def nextset(self) -> Optional[bool]: ...

    async def setinputsizes(self, sizes: Sequence[Any]) -> None: ...

    def setoutputsize(self, size: Any, column: Any) -> None: ...

    # -- fetching -----------------------------------------------------------

    async def fetchone(self) -> Optional[Any]: ...

    async def fetchmany(self, size: Optional[int] = ...) -> Sequence[Any]: ...

    async def fetchall(self) -> Sequence[Any]: ...

    # -- lifecycle ----------------------------------------------------------

    async def close(self) -> None: ...
115
116
class AsyncAdapt_dbapi_module:
    """Base for the module-like ``dbapi`` object of an asyncio dialect.

    Every member below is declared for static type checkers only; at
    runtime this class itself defines no DBAPI attributes.  Concrete
    subclasses (not visible here) are presumably expected to supply them.
    """

    if TYPE_CHECKING:
        Error = DBAPIModule.Error
        OperationalError = DBAPIModule.OperationalError
        InterfaceError = DBAPIModule.InterfaceError
        IntegrityError = DBAPIModule.IntegrityError

        # Typing-only: lets the type checker accept arbitrary DBAPI module
        # attributes.  If this stub existed at runtime it would return None
        # for *every* missing attribute and silently mask AttributeError
        # (e.g. ``hasattr(dbapi, "anything")`` would always be True).
        def __getattr__(self, key: str) -> Any: ...
125
126
class AsyncAdapt_dbapi_cursor:
    """Synchronous :pep:`249` cursor facade over an asyncio DBAPI cursor.

    Every driver coroutine is run to completion with ``await_()``.  For this
    client-side (``server_side = False``) cursor, the complete result set is
    fetched into the ``_rows`` deque at execute() time (while the
    connection's execute mutex is held) and at nextset() time.  The
    ``fetch*()`` methods then serve rows from that deque without further I/O.
    """

    server_side = False
    __slots__ = (
        "_adapt_connection",
        "_connection",
        "_cursor",
        "_rows",
        "_soft_closed_memoized",
    )

    # whether the driver's cursor.close() returns an awaitable; when False,
    # close() invokes it as a plain synchronous call
    _awaitable_cursor_close: bool = True

    _cursor: AsyncIODBAPICursor
    _adapt_connection: AsyncAdapt_dbapi_connection
    _connection: AsyncIODBAPIConnection
    _rows: Deque[Any]

    def __init__(self, adapt_connection: AsyncAdapt_dbapi_connection):
        self._adapt_connection = adapt_connection
        self._connection = adapt_connection._connection

        cursor = self._make_new_cursor(self._connection)
        self._cursor = self._aenter_cursor(cursor)
        # stays empty until _async_soft_close() memoizes the description
        self._soft_closed_memoized = EMPTY_DICT
        if not self.server_side:
            self._rows = collections.deque()

    def _aenter_cursor(self, cursor: AsyncIODBAPICursor) -> AsyncIODBAPICursor:
        """Enter the driver cursor's async context and return its result.

        Driver errors are routed through the adapted connection's
        ``_handle_exception()``.
        """
        try:
            return await_(cursor.__aenter__())  # type: ignore[no-any-return]
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def _make_new_cursor(
        self, connection: AsyncIODBAPIConnection
    ) -> AsyncIODBAPICursor:
        """Create the driver-level cursor.

        Kept as a separate method, presumably so subclasses can customize
        cursor creation.
        """
        return connection.cursor()

    @property
    def description(self) -> Optional[_DBAPICursorDescription]:
        # after a soft close the driver cursor is closed, so serve the
        # description captured by _async_soft_close()
        if "description" in self._soft_closed_memoized:
            return self._soft_closed_memoized["description"]  # type: ignore[no-any-return]  # noqa: E501
        return self._cursor.description

    @property
    def rowcount(self) -> int:
        return self._cursor.rowcount

    @property
    def arraysize(self) -> int:
        return self._cursor.arraysize

    @arraysize.setter
    def arraysize(self, value: int) -> None:
        self._cursor.arraysize = value

    @property
    def lastrowid(self) -> int:
        return self._cursor.lastrowid

    async def _async_soft_close(self) -> None:
        """close the cursor but keep the results pending, and memoize the
        description.

        .. versionadded:: 2.0.44

        """

        if not self._awaitable_cursor_close or self.server_side:
            return

        self._soft_closed_memoized = self._soft_closed_memoized.union(
            {
                "description": self._cursor.description,
            }
        )
        await self._cursor.close()

    def close(self) -> None:
        """Discard buffered rows and close the driver cursor if possible."""
        self._rows.clear()

        # updated as of 2.0.44
        # try to "close" the cursor based on what we know about the driver
        # and if we are able to. otherwise, hope that the asyncio
        # extension called _async_soft_close() if the cursor is going into
        # a sync context
        if self._cursor is None or bool(self._soft_closed_memoized):
            return

        if not self._awaitable_cursor_close:
            self._cursor.close()  # type: ignore[unused-coroutine]
        elif in_greenlet():
            await_(self._cursor.close())

    def execute(
        self,
        operation: Any,
        parameters: Optional[_DBAPISingleExecuteParams] = None,
    ) -> Any:
        """Execute *operation*; driver errors go through
        ``_handle_exception()``."""
        try:
            return await_(self._execute_async(operation, parameters))
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def executemany(
        self,
        operation: Any,
        seq_of_parameters: _DBAPIMultiExecuteParams,
    ) -> Any:
        """Execute *operation* once per parameter set; driver errors go
        through ``_handle_exception()``."""
        try:
            return await_(
                self._executemany_async(operation, seq_of_parameters)
            )
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    async def _execute_async(
        self, operation: Any, parameters: Optional[_DBAPISingleExecuteParams]
    ) -> Any:
        async with self._adapt_connection._execute_mutex:
            if parameters is None:
                result = await self._cursor.execute(operation)
            else:
                result = await self._cursor.execute(operation, parameters)

            if not self.server_side:
                if self._cursor.description:
                    # buffer the whole result set while the mutex is held
                    self._rows = collections.deque(
                        await self._cursor.fetchall()
                    )
                else:
                    # no result set: drop any rows left over from a
                    # previous execution instead of serving them stale
                    self._rows = collections.deque()
            return result

    async def _executemany_async(
        self,
        operation: Any,
        seq_of_parameters: _DBAPIMultiExecuteParams,
    ) -> Any:
        async with self._adapt_connection._execute_mutex:
            return await self._cursor.executemany(operation, seq_of_parameters)

    def nextset(self) -> Optional[bool]:
        """Advance to the next result set.

        Returns the driver's ``nextset()`` result, which per :pep:`249` is
        ``None`` when there are no more sets and a true value otherwise.
        For client-side cursors the new set is buffered into ``_rows``; rows
        of the previous set are discarded.  Driver errors go through
        ``_handle_exception()``.
        """
        try:
            has_next = await_(self._cursor.nextset())
            if not self.server_side:
                if self._cursor.description:
                    self._rows = collections.deque(
                        await_(self._cursor.fetchall())
                    )
                else:
                    self._rows = collections.deque()
            return has_next  # type: ignore[no-any-return]
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def setinputsizes(self, *inputsizes: Any) -> None:
        # NOTE: this is overridden in aioodbc due to
        # see https://github.com/aio-libs/aioodbc/issues/451
        # right now

        try:
            return await_(self._cursor.setinputsizes(*inputsizes))
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def __enter__(self) -> Self:
        return self

    def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
        self.close()

    def __iter__(self) -> Iterator[Any]:
        # consumes the buffer as it iterates
        while self._rows:
            yield self._rows.popleft()

    def fetchone(self) -> Optional[Any]:
        if self._rows:
            return self._rows.popleft()
        else:
            return None

    def fetchmany(self, size: Optional[int] = None) -> Sequence[Any]:
        if size is None:
            size = self.arraysize
        rr = self._rows
        return [rr.popleft() for _ in range(min(size, len(rr)))]

    def fetchall(self) -> Sequence[Any]:
        retval = list(self._rows)
        self._rows.clear()
        return retval
302
303
class AsyncAdapt_dbapi_ss_cursor(AsyncAdapt_dbapi_cursor):
    """Server-side variant of :class:`AsyncAdapt_dbapi_cursor`.

    Rows are not buffered locally; every fetch is forwarded to the
    underlying async cursor through ``await_()``.  Driver errors raised
    while fetching are routed through the adapted connection's
    ``_handle_exception()``, the same way execute() errors are.
    """

    __slots__ = ()
    server_side = True

    def close(self) -> None:
        # idempotent: the driver cursor is dropped after the first close
        if self._cursor is not None:
            await_(self._cursor.close())
            self._cursor = None  # type: ignore

    def fetchone(self) -> Optional[Any]:
        try:
            return await_(self._cursor.fetchone())
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def fetchmany(self, size: Optional[int] = None) -> Any:
        try:
            return await_(self._cursor.fetchmany(size=size))
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def fetchall(self) -> Sequence[Any]:
        try:
            return await_(self._cursor.fetchall())
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def __iter__(self) -> Iterator[Any]:
        iterator = self._cursor.__aiter__()
        while True:
            try:
                row = await_(iterator.__anext__())
            except StopAsyncIteration:
                break
            except Exception as error:
                self._adapt_connection._handle_exception(error)
            # yield outside the try so exceptions thrown into this
            # generator by the consumer are not intercepted above
            yield row
329
330
class AsyncAdapt_dbapi_connection(AdaptedConnection):
    """Synchronous :pep:`249` connection facade over an asyncio DBAPI
    connection.

    Driver coroutines are run to completion with ``await_()``; errors from
    commit() and rollback() are routed through ``_handle_exception()``.
    """

    _cursor_cls = AsyncAdapt_dbapi_cursor
    _ss_cursor_cls = AsyncAdapt_dbapi_ss_cursor

    __slots__ = ("dbapi", "_execute_mutex")

    _connection: AsyncIODBAPIConnection

    def __init__(self, dbapi: Any, connection: AsyncIODBAPIConnection):
        self.dbapi = dbapi
        self._connection = connection
        # acquired by cursors around each execute / executemany
        self._execute_mutex = asyncio.Lock()

    def cursor(self, server_side: bool = False) -> AsyncAdapt_dbapi_cursor:
        factory = self._ss_cursor_cls if server_side else self._cursor_cls
        return factory(self)

    def execute(
        self,
        operation: Any,
        parameters: Optional[_DBAPISingleExecuteParams] = None,
    ) -> Any:
        """lots of DBAPIs seem to provide this, so include it"""
        new_cursor = self.cursor()
        new_cursor.execute(operation, parameters)
        return new_cursor

    def _handle_exception(self, error: Exception) -> NoReturn:
        """Re-raise *error* with the traceback of the exception currently
        being handled."""
        _, _, active_tb = sys.exc_info()
        raise error.with_traceback(active_tb)

    def rollback(self) -> None:
        try:
            await_(self._connection.rollback())
        except Exception as exc:
            self._handle_exception(exc)

    def commit(self) -> None:
        try:
            await_(self._connection.commit())
        except Exception as exc:
            self._handle_exception(exc)

    def close(self) -> None:
        await_(self._connection.close())
379
380
class AsyncAdapt_terminate:
    """Mixin for an AsyncAdapt_dbapi_connection that adds ``terminate()``.

    Concrete classes implement ``_terminate_graceful_close()`` and
    ``_terminate_force_close()``; both raise NotImplementedError here.
    """

    __slots__ = ()

    def terminate(self) -> None:
        if not in_greenlet():
            # not running in a greenlet: this is the gc cleanup case, where
            # nothing can be awaited, so force the close immediately
            self._terminate_force_close()
            return

        # running in a greenlet: this is the case of a connection that was
        # invalidated.  Attempt a graceful close first; see #10717
        try:
            await_(asyncio.shield(self._terminate_graceful_close()))
        except self._terminate_handled_exceptions() as exc:
            # when recycling an old connection that may already have been
            # disconnected, the graceful close fails; terminate the
            # connection without any further waiting.  see issue #8419
            self._terminate_force_close()
            if isinstance(exc, asyncio.CancelledError):
                # we were cancelled: propagate the CancelledError
                raise

    def _terminate_handled_exceptions(
        self,
    ) -> Tuple[Type[BaseException], ...]:
        """Return the exception types that make terminate() fall back from
        the graceful close to the forced close.
        """
        handled: Tuple[Type[BaseException], ...] = (
            asyncio.TimeoutError,
            asyncio.CancelledError,
            OSError,
        )
        return handled

    async def _terminate_graceful_close(self) -> None:
        """Try to close the connection gracefully."""
        raise NotImplementedError

    def _terminate_force_close(self) -> None:
        """Terminate the connection without waiting."""
        raise NotImplementedError