1# connectors/asyncio.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8"""generic asyncio-adapted versions of DBAPI connection and cursor"""
9
10from __future__ import annotations
11
12import asyncio
13import collections
14import sys
15from typing import Any
16from typing import AsyncIterator
17from typing import Deque
18from typing import Iterator
19from typing import NoReturn
20from typing import Optional
21from typing import Sequence
22from typing import Tuple
23from typing import Type
24from typing import TYPE_CHECKING
25
26from ..engine import AdaptedConnection
27from ..util import EMPTY_DICT
28from ..util.concurrency import await_fallback
29from ..util.concurrency import await_only
30from ..util.concurrency import in_greenlet
31from ..util.typing import Protocol
32
33if TYPE_CHECKING:
34 from ..engine.interfaces import _DBAPICursorDescription
35 from ..engine.interfaces import _DBAPIMultiExecuteParams
36 from ..engine.interfaces import _DBAPISingleExecuteParams
37 from ..engine.interfaces import DBAPIModule
38 from ..util.typing import Self
39
40
class AsyncIODBAPIConnection(Protocol):
    """Structural type describing an asyncio-adapted :pep:`249` connection.

    The commonly used methods are declared explicitly; ``close()`` is
    intentionally absent (see the comment below) and any other attribute
    resolves through the ``__getattr__`` hook.
    """

    # async DBAPIs don't agree on whether close() should be awaitable, so it
    # is left off this protocol and picked up by __getattr__ instead

    def cursor(self, *args: Any, **kwargs: Any) -> AsyncIODBAPICursor: ...

    async def commit(self) -> None: ...

    async def rollback(self) -> None: ...

    def __getattr__(self, key: str) -> Any: ...

    def __setattr__(self, key: str, value: Any) -> None: ...
60
61
class AsyncIODBAPICursor(Protocol):
    """Structural type describing an asyncio-adapted :pep:`249` cursor.

    It mirrors the synchronous DBAPI cursor surface; the members declared
    with ``async def`` here are the ones the adapter classes await.
    """

    # -- plain attributes -------------------------------------------------

    arraysize: int

    lastrowid: int

    # -- read-only properties ---------------------------------------------

    @property
    def description(self) -> _DBAPICursorDescription:
        """The :pep:`249` ``description`` sequence of this cursor."""
        ...

    @property
    def rowcount(self) -> int: ...

    # -- async context manager / iteration hooks --------------------------

    def __aenter__(self) -> Any: ...

    def __aiter__(self) -> AsyncIterator[Any]: ...

    # -- statement execution ----------------------------------------------

    async def execute(
        self,
        operation: Any,
        parameters: Optional[_DBAPISingleExecuteParams] = None,
    ) -> Any: ...

    async def executemany(
        self,
        operation: Any,
        parameters: _DBAPIMultiExecuteParams,
    ) -> Any: ...

    async def callproc(
        self, procname: str, parameters: Sequence[Any] = ...
    ) -> Any: ...

    async def nextset(self) -> Optional[bool]: ...

    # -- fetching ---------------------------------------------------------

    async def fetchone(self) -> Optional[Any]: ...

    async def fetchmany(self, size: Optional[int] = ...) -> Sequence[Any]: ...

    async def fetchall(self) -> Sequence[Any]: ...

    # -- miscellaneous ----------------------------------------------------

    async def setinputsizes(self, sizes: Sequence[Any]) -> None: ...

    def setoutputsize(self, size: Any, column: Any) -> None: ...

    async def close(self) -> None: ...
116
117
class AsyncAdapt_dbapi_module:
    """Module-like stand-in for an adapted async DBAPI module.

    The exception-class aliases exist only for type checkers; at runtime
    they are never assigned.  ``__getattr__`` lets type checkers accept any
    other attribute access on instances.
    """

    if TYPE_CHECKING:
        # typing-only aliases to the DBAPIModule exception classes
        Error = DBAPIModule.Error
        OperationalError = DBAPIModule.OperationalError
        InterfaceError = DBAPIModule.InterfaceError
        IntegrityError = DBAPIModule.IntegrityError

    def __getattr__(self, key: str) -> Any:
        """Typing stub; at runtime this returns ``None`` for missing names."""
126
127
class AsyncAdapt_dbapi_cursor:
    """Synchronous :pep:`249` cursor facade over an asyncio DBAPI cursor.

    Every driver coroutine is run to completion through the owning
    connection's ``await_`` callable.  When ``server_side`` is false (the
    default for this class), each result set is fetched eagerly into
    ``_rows`` while the connection's execute mutex is held, and the
    ``fetch*()`` methods then serve rows from that in-memory deque.
    """

    server_side = False
    __slots__ = (
        "_adapt_connection",
        "_connection",
        "await_",
        "_cursor",
        "_rows",
        "_soft_closed_memoized",
    )

    # whether the driver cursor's close() returns an awaitable; when False,
    # close() is called without awaiting it.  Presumably overridden by
    # driver-specific subclasses whose cursor close() is synchronous.
    _awaitable_cursor_close: bool = True

    _cursor: AsyncIODBAPICursor
    _adapt_connection: AsyncAdapt_dbapi_connection
    _connection: AsyncIODBAPIConnection
    _rows: Deque[Any]

    def __init__(self, adapt_connection: AsyncAdapt_dbapi_connection):
        self._adapt_connection = adapt_connection
        self._connection = adapt_connection._connection

        self.await_ = adapt_connection.await_

        cursor = self._make_new_cursor(self._connection)
        self._cursor = self._aenter_cursor(cursor)
        # filled in by _async_soft_close(); non-empty means "soft closed"
        self._soft_closed_memoized = EMPTY_DICT
        if not self.server_side:
            self._rows = collections.deque()

    def _aenter_cursor(self, cursor: AsyncIODBAPICursor) -> AsyncIODBAPICursor:
        """Enter the driver cursor's async context and return the result."""
        return self.await_(cursor.__aenter__())  # type: ignore[no-any-return]

    def _make_new_cursor(
        self, connection: AsyncIODBAPIConnection
    ) -> AsyncIODBAPICursor:
        """Create the driver-level cursor from *connection*."""
        return connection.cursor()

    @property
    def description(self) -> Optional[_DBAPICursorDescription]:
        """The memoized description if soft-closed, else the live one."""
        if "description" in self._soft_closed_memoized:
            return self._soft_closed_memoized["description"]  # type: ignore[no-any-return] # noqa: E501
        return self._cursor.description

    @property
    def rowcount(self) -> int:
        """Pass-through of the driver cursor's ``rowcount``."""
        return self._cursor.rowcount

    @property
    def arraysize(self) -> int:
        """Pass-through of the driver cursor's ``arraysize``."""
        return self._cursor.arraysize

    @arraysize.setter
    def arraysize(self, value: int) -> None:
        self._cursor.arraysize = value

    @property
    def lastrowid(self) -> int:
        """Pass-through of the driver cursor's ``lastrowid``."""
        return self._cursor.lastrowid

    async def _async_soft_close(self) -> None:
        """close the cursor but keep the results pending, and memoize the
        description.

        Does nothing for server-side cursors or when the driver's cursor
        close() is not awaitable.

        .. versionadded:: 2.0.44

        """

        if not self._awaitable_cursor_close or self.server_side:
            return

        self._soft_closed_memoized = self._soft_closed_memoized.union(
            {
                "description": self._cursor.description,
            }
        )
        await self._cursor.close()

    def close(self) -> None:
        """Discard buffered rows and close the driver cursor if possible."""
        self._rows.clear()

        # updated as of 2.0.44
        # try to "close" the cursor based on what we know about the driver
        # and if we are able to.  otherwise, hope that the asyncio
        # extension called _async_soft_close() if the cursor is going into
        # a sync context
        if self._cursor is None or bool(self._soft_closed_memoized):
            return

        if not self._awaitable_cursor_close:
            self._cursor.close()  # type: ignore[unused-coroutine]
        elif in_greenlet():
            # an awaitable close() can only be driven from inside a
            # greenlet; otherwise it is skipped
            self.await_(self._cursor.close())

    def execute(
        self,
        operation: Any,
        parameters: Optional[_DBAPISingleExecuteParams] = None,
    ) -> Any:
        """Execute *operation*, buffering rows for client-side cursors.

        Errors are routed through the connection's ``_handle_exception()``,
        which always raises.
        """
        try:
            return self.await_(self._execute_async(operation, parameters))
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def executemany(
        self,
        operation: Any,
        seq_of_parameters: _DBAPIMultiExecuteParams,
    ) -> Any:
        """Execute *operation* once per parameter set.

        Errors are routed through the connection's ``_handle_exception()``,
        which always raises.
        """
        try:
            return self.await_(
                self._executemany_async(operation, seq_of_parameters)
            )
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    async def _execute_async(
        self, operation: Any, parameters: Optional[_DBAPISingleExecuteParams]
    ) -> Any:
        async with self._adapt_connection._execute_mutex:
            if parameters is None:
                result = await self._cursor.execute(operation)
            else:
                result = await self._cursor.execute(operation, parameters)

            # client-side cursors drain the whole result while the mutex
            # is still held, so later fetches need no I/O
            if self._cursor.description and not self.server_side:
                self._rows = collections.deque(await self._cursor.fetchall())
            return result

    async def _executemany_async(
        self,
        operation: Any,
        seq_of_parameters: _DBAPIMultiExecuteParams,
    ) -> Any:
        async with self._adapt_connection._execute_mutex:
            return await self._cursor.executemany(operation, seq_of_parameters)

    def nextset(self) -> Optional[bool]:
        """Advance to the next result set.

        Returns the driver's ``nextset()`` value unchanged.  Per :pep:`249`
        that is ``None`` when there are no more sets and a true value
        otherwise, so ``while cursor.nextset(): ...`` works.  For
        client-side cursors the rows of the new set are buffered.
        """
        has_next = self.await_(self._cursor.nextset())
        if self._cursor.description and not self.server_side:
            self._rows = collections.deque(
                self.await_(self._cursor.fetchall())
            )
        return has_next

    def setinputsizes(self, *inputsizes: Any) -> None:
        # NOTE: this is overridden in aioodbc due to
        # https://github.com/aio-libs/aioodbc/issues/451

        return self.await_(self._cursor.setinputsizes(*inputsizes))

    def __enter__(self) -> Self:
        return self

    def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
        self.close()

    def __iter__(self) -> Iterator[Any]:
        """Consume and yield the buffered rows in order."""
        while self._rows:
            yield self._rows.popleft()

    def fetchone(self) -> Optional[Any]:
        """Pop the next buffered row, or return ``None`` if exhausted."""
        if self._rows:
            return self._rows.popleft()
        else:
            return None

    def fetchmany(self, size: Optional[int] = None) -> Sequence[Any]:
        """Pop up to *size* buffered rows (default: ``arraysize``)."""
        if size is None:
            size = self.arraysize
        rr = self._rows
        return [rr.popleft() for _ in range(min(size, len(rr)))]

    def fetchall(self) -> Sequence[Any]:
        """Return all remaining buffered rows and empty the buffer."""
        retval = list(self._rows)
        self._rows.clear()
        return retval
305
306
class AsyncAdapt_dbapi_ss_cursor(AsyncAdapt_dbapi_cursor):
    """Server-side (streaming) variant of the adapted cursor.

    With ``server_side`` set, the base class neither creates nor fills a
    local row buffer; every fetch below is forwarded to the driver cursor
    and awaited on demand.
    """

    __slots__ = ()
    server_side = True

    def close(self) -> None:
        """Await the driver cursor's close(), then drop the reference.

        A second call finds ``_cursor`` already ``None`` and does nothing.
        """
        driver_cursor = self._cursor
        if driver_cursor is None:
            return
        self.await_(driver_cursor.close())
        self._cursor = None  # type: ignore

    def fetchone(self) -> Optional[Any]:
        """Await and return the driver cursor's next row."""
        return self.await_(self._cursor.fetchone())

    def fetchmany(self, size: Optional[int] = None) -> Any:
        """Await the driver's fetchmany(), passing *size* through as-is."""
        return self.await_(self._cursor.fetchmany(size=size))

    def fetchall(self) -> Sequence[Any]:
        """Await and return all remaining rows from the driver cursor."""
        return self.await_(self._cursor.fetchall())

    def __iter__(self) -> Iterator[Any]:
        """Yield rows by stepping the driver cursor's async iterator."""
        row_source = self._cursor.__aiter__()
        while True:
            try:
                yield self.await_(row_source.__anext__())
            except StopAsyncIteration:
                return
332
333
class AsyncAdapt_dbapi_connection(AdaptedConnection):
    """Synchronous :pep:`249` connection facade over an asyncio connection.

    Driver coroutines are run to completion through ``await_``
    (``await_only`` for this class).  ``_execute_mutex`` is an
    ``asyncio.Lock`` that the adapted cursors hold around their
    execute()/executemany() calls, serializing statement execution on
    this connection.
    """

    _cursor_cls = AsyncAdapt_dbapi_cursor
    _ss_cursor_cls = AsyncAdapt_dbapi_ss_cursor

    await_ = staticmethod(await_only)

    __slots__ = ("dbapi", "_execute_mutex")

    _connection: AsyncIODBAPIConnection

    def __init__(self, dbapi: Any, connection: AsyncIODBAPIConnection):
        self.dbapi = dbapi
        self._connection = connection
        self._execute_mutex = asyncio.Lock()

    def cursor(self, server_side: bool = False) -> AsyncAdapt_dbapi_cursor:
        """Return a new adapted cursor.

        :param server_side: when true, return an instance of
          ``_ss_cursor_cls`` (streaming, unbuffered) rather than the
          buffered ``_cursor_cls``.
        """
        if server_side:
            return self._ss_cursor_cls(self)
        else:
            return self._cursor_cls(self)

    def execute(
        self,
        operation: Any,
        parameters: Optional[_DBAPISingleExecuteParams] = None,
    ) -> Any:
        """Execute *operation* on a fresh cursor and return that cursor.

        Many DBAPIs offer this connection-level shortcut, so it is provided
        here as well.  If execution raises, the new cursor is closed before
        the error propagates so that it is not leaked.
        """
        cursor = self.cursor()
        try:
            cursor.execute(operation, parameters)
        except Exception:
            # best-effort cleanup: a secondary failure from close() must
            # not mask the original execute() error re-raised below
            try:
                cursor.close()
            except Exception:
                pass
            raise
        return cursor

    def _handle_exception(self, error: Exception) -> NoReturn:
        """Re-raise *error* with the traceback currently being handled.

        Called from ``except`` blocks; presumably overridden by
        driver-specific subclasses to translate driver errors.
        """
        exc_info = sys.exc_info()

        raise error.with_traceback(exc_info[2])

    def rollback(self) -> None:
        """Await the driver connection's rollback()."""
        try:
            self.await_(self._connection.rollback())
        except Exception as error:
            self._handle_exception(error)

    def commit(self) -> None:
        """Await the driver connection's commit()."""
        try:
            self.await_(self._connection.commit())
        except Exception as error:
            self._handle_exception(error)

    def close(self) -> None:
        """Await the driver connection's close()."""
        self.await_(self._connection.close())
384
385
class AsyncAdaptFallback_dbapi_connection(AsyncAdapt_dbapi_connection):
    """Adapted connection that drives awaitables with ``await_fallback``.

    Identical to its base class except for the ``await_`` callable, which
    is ``await_fallback`` here instead of ``await_only``.
    """

    await_ = staticmethod(await_fallback)

    __slots__ = ()
390
391
class AsyncAdapt_terminate:
    """Mixin for an AsyncAdapt_dbapi_connection to add terminate support.

    Subclasses implement ``_terminate_graceful_close()`` and
    ``_terminate_force_close()``; the class using this mixin must also
    provide ``await_``.
    """

    __slots__ = ()

    def terminate(self) -> None:
        """Close the connection, gracefully if possible.

        Inside a greenlet, a shielded graceful close is attempted first.
        If it raises one of ``_terminate_handled_exceptions()``, the
        connection is force-closed instead, and a ``CancelledError`` is
        re-raised afterwards.  Outside a greenlet, the connection is
        force-closed directly.
        """
        if in_greenlet():
            # running inside a greenlet: this is the case where the
            # connection was invalidated
            try:
                # try to gracefully close; see #10717
                self.await_(asyncio.shield(self._terminate_graceful_close()))  # type: ignore[attr-defined] # noqa: E501
            except self._terminate_handled_exceptions() as e:
                # in the case where we are recycling an old connection
                # that may have already been disconnected, close() will
                # fail.  In this case, terminate
                # the connection without any further waiting.
                # see issue #8419
                self._terminate_force_close()
                if isinstance(e, asyncio.CancelledError):
                    # re-raise CancelledError if we were cancelled
                    raise
        else:
            # not in a greenlet; this is the gc cleanup case
            self._terminate_force_close()

    def _terminate_handled_exceptions(self) -> Tuple[Type[BaseException], ...]:
        """Return the exception types that, when raised by
        ``_terminate_graceful_close()``, cause ``terminate()`` to fall back
        to ``_terminate_force_close()``.
        """
        return (asyncio.TimeoutError, asyncio.CancelledError, OSError)

    async def _terminate_graceful_close(self) -> None:
        """Try to close the connection gracefully; must be overridden."""
        raise NotImplementedError

    def _terminate_force_close(self) -> None:
        """Terminate the connection without waiting; must be overridden."""
        raise NotImplementedError