1# connectors/asyncio.py
2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8"""generic asyncio-adapted versions of DBAPI connection and cursor"""
9
10from __future__ import annotations
11
12import asyncio
13import collections
14import sys
15from typing import Any
16from typing import Deque
17from typing import Iterator
18from typing import NoReturn
19from typing import Optional
20from typing import Protocol
21from typing import Sequence
22
23from ..engine import AdaptedConnection
24from ..engine.interfaces import _DBAPICursorDescription
25from ..engine.interfaces import _DBAPIMultiExecuteParams
26from ..engine.interfaces import _DBAPISingleExecuteParams
27from ..util.concurrency import await_
28from ..util.typing import Self
29
30
class AsyncIODBAPIConnection(Protocol):
    """Structural type for an asyncio-native :pep:`249` connection.

    Declares the connection members the adapters in this module call:
    :meth:`.cursor` is an ordinary call returning an
    :class:`.AsyncIODBAPICursor`, while the transaction and close
    operations are coroutines.
    """

    def cursor(self) -> AsyncIODBAPICursor:
        """Return a new asyncio cursor (synchronous call)."""

    async def commit(self) -> None:
        """Coroutine counterpart of :pep:`249` ``Connection.commit()``."""

    async def rollback(self) -> None:
        """Coroutine counterpart of :pep:`249` ``Connection.rollback()``."""

    async def close(self) -> None:
        """Coroutine counterpart of :pep:`249` ``Connection.close()``."""
45
46
class AsyncIODBAPICursor(Protocol):
    """Structural type for an asyncio-native :pep:`249` cursor.

    Metadata (``description``, ``rowcount``, ``arraysize``, ``lastrowid``)
    and :meth:`.setoutputsize` are synchronous; execution, fetching and the
    remaining operations are coroutines.  ``__aenter__`` is declared as a
    plain method whose return value the adapter in this module awaits.
    """

    arraysize: int

    lastrowid: int

    def __aenter__(self) -> Any:
        """Begin the cursor's async context; the result is awaited."""

    @property
    def description(self) -> _DBAPICursorDescription:
        """The description attribute of the Cursor."""

    @property
    def rowcount(self) -> int:
        """Row count of the last operation, as in :pep:`249`."""

    async def execute(
        self,
        operation: Any,
        parameters: Optional[_DBAPISingleExecuteParams] = None,
    ) -> Any:
        """Run ``operation`` once, optionally with ``parameters``."""

    async def executemany(
        self,
        operation: Any,
        parameters: _DBAPIMultiExecuteParams,
    ) -> Any:
        """Run ``operation`` once per parameter set."""

    async def callproc(
        self, procname: str, parameters: Sequence[Any] = ...
    ) -> Any:
        """Invoke stored procedure ``procname``, as in :pep:`249`."""

    async def fetchone(self) -> Optional[Any]:
        """Return the next row, or ``None`` when none remain."""

    async def fetchmany(self, size: Optional[int] = ...) -> Sequence[Any]:
        """Return up to ``size`` further rows."""

    async def fetchall(self) -> Sequence[Any]:
        """Return every remaining row."""

    async def nextset(self) -> Optional[bool]:
        """Advance to the next result set, as in :pep:`249`."""

    async def setinputsizes(self, sizes: Sequence[Any]) -> None:
        """Predeclare parameter sizes, as in :pep:`249`."""

    def setoutputsize(self, size: Any, column: Any) -> None:
        """Set a column buffer size, as in :pep:`249` (synchronous)."""

    async def close(self) -> None:
        """Close the cursor."""
99
100
class AsyncAdapt_dbapi_cursor:
    """Blocking :pep:`249` cursor facade over an asyncio driver cursor.

    Each blocking call is bridged to the driver coroutine with ``await_()``.
    Unless :attr:`.server_side` is true, the rows of an executed statement
    (or of a subsequent result set) are fetched eagerly into a local deque,
    and the fetch methods and iteration consume that buffer.  Driver errors
    raised during cursor setup and statement operations are routed through
    the owning connection's ``_handle_exception()``.
    """

    server_side = False
    __slots__ = (
        "_adapt_connection",
        "_connection",
        "_cursor",
        "_rows",
    )

    _cursor: AsyncIODBAPICursor
    _adapt_connection: AsyncAdapt_dbapi_connection
    _connection: AsyncIODBAPIConnection
    _rows: Deque[Any]

    def __init__(self, adapt_connection: AsyncAdapt_dbapi_connection):
        self._adapt_connection = adapt_connection
        self._connection = adapt_connection._connection

        cursor = self._make_new_cursor(self._connection)
        self._cursor = self._aenter_cursor(cursor)

        self._rows = collections.deque()

    def _aenter_cursor(self, cursor: AsyncIODBAPICursor) -> AsyncIODBAPICursor:
        """Await ``cursor.__aenter__()`` and return the entered cursor."""
        try:
            return await_(cursor.__aenter__())  # type: ignore[no-any-return]
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def _make_new_cursor(
        self, connection: AsyncIODBAPIConnection
    ) -> AsyncIODBAPICursor:
        """Create the raw driver cursor via ``connection.cursor()``."""
        return connection.cursor()

    @property
    def description(self) -> Optional[_DBAPICursorDescription]:
        return self._cursor.description

    @property
    def rowcount(self) -> int:
        return self._cursor.rowcount

    @property
    def arraysize(self) -> int:
        return self._cursor.arraysize

    @arraysize.setter
    def arraysize(self, value: int) -> None:
        self._cursor.arraysize = value

    @property
    def lastrowid(self) -> int:
        return self._cursor.lastrowid

    def close(self) -> None:
        """Discard buffered rows; the driver cursor itself is left open."""
        # note we aren't actually closing the cursor here,
        # we are just letting GC do it. see notes in aiomysql dialect
        self._rows.clear()

    def execute(
        self,
        operation: Any,
        parameters: Optional[_DBAPISingleExecuteParams] = None,
    ) -> Any:
        """Execute ``operation`` and return the driver's result."""
        try:
            return await_(self._execute_async(operation, parameters))
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def executemany(
        self,
        operation: Any,
        seq_of_parameters: _DBAPIMultiExecuteParams,
    ) -> Any:
        """Execute ``operation`` for each parameter set."""
        try:
            return await_(
                self._executemany_async(operation, seq_of_parameters)
            )
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    async def _execute_async(
        self, operation: Any, parameters: Optional[_DBAPISingleExecuteParams]
    ) -> Any:
        # the connection-wide mutex serializes execution (and the buffered
        # fetch below) across this connection's cursors
        async with self._adapt_connection._execute_mutex:
            if parameters is None:
                result = await self._cursor.execute(operation)
            else:
                result = await self._cursor.execute(operation, parameters)

            if self._cursor.description and not self.server_side:
                self._rows = collections.deque(await self._cursor.fetchall())
            else:
                # drop rows buffered for a previous statement so they are
                # not returned as this statement's results
                self._rows = collections.deque()
            return result

    async def _executemany_async(
        self,
        operation: Any,
        seq_of_parameters: _DBAPIMultiExecuteParams,
    ) -> Any:
        async with self._adapt_connection._execute_mutex:
            # executemany() produces no buffered rows; discard any left
            # over from an earlier statement
            self._rows = collections.deque()
            return await self._cursor.executemany(operation, seq_of_parameters)

    def nextset(self) -> None:
        """Advance the driver cursor to its next result set.

        For buffered cursors the new result set's rows (if it has any) are
        loaded into the local buffer, replacing whatever was left there.
        Errors go through the connection's ``_handle_exception()`` exactly
        as for :meth:`.execute`.
        """
        try:
            await_(self._cursor.nextset())
            if self._cursor.description and not self.server_side:
                self._rows = collections.deque(
                    await_(self._cursor.fetchall())
                )
            else:
                self._rows = collections.deque()
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def setinputsizes(self, *inputsizes: Any) -> None:
        """Forward ``setinputsizes()`` to the driver cursor."""
        # NOTE: this is overridden in aioodbc for now;
        # see https://github.com/aio-libs/aioodbc/issues/451
        try:
            return await_(self._cursor.setinputsizes(*inputsizes))
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def __enter__(self) -> Self:
        return self

    def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
        self.close()

    def __iter__(self) -> Iterator[Any]:
        # consumes the buffer: rows yielded here are gone from fetch*()
        while self._rows:
            yield self._rows.popleft()

    def fetchone(self) -> Optional[Any]:
        """Pop and return the next buffered row, or ``None`` if empty."""
        if self._rows:
            return self._rows.popleft()
        else:
            return None

    def fetchmany(self, size: Optional[int] = None) -> Sequence[Any]:
        """Pop up to ``size`` buffered rows (default: ``arraysize``)."""
        if size is None:
            size = self.arraysize
        rr = self._rows
        return [rr.popleft() for _ in range(min(size, len(rr)))]

    def fetchall(self) -> Sequence[Any]:
        """Return all remaining buffered rows and empty the buffer."""
        retval = list(self._rows)
        self._rows.clear()
        return retval
241
242
class AsyncAdapt_dbapi_ss_cursor(AsyncAdapt_dbapi_cursor):
    """Server-side (unbuffered) variant of :class:`.AsyncAdapt_dbapi_cursor`.

    With :attr:`.server_side` set to ``True`` no rows are prefetched into
    the local buffer; every fetch is forwarded to the driver cursor, and
    :meth:`.close` actually closes the driver cursor.
    """

    __slots__ = ()
    server_side = True

    def close(self) -> None:
        """Close the driver cursor once; subsequent calls do nothing."""
        if self._cursor is not None:
            await_(self._cursor.close())
            self._cursor = None  # type: ignore

    def fetchone(self) -> Optional[Any]:
        return await_(self._cursor.fetchone())

    def fetchmany(self, size: Optional[int] = None) -> Any:
        # size=None is passed through unchanged to the driver
        return await_(self._cursor.fetchmany(size=size))

    def fetchall(self) -> Sequence[Any]:
        return await_(self._cursor.fetchall())

    def __iter__(self) -> Iterator[Any]:
        """Stream rows from the driver cursor until ``fetchone()`` is None.

        The inherited ``__iter__`` drains the local row buffer, which a
        server-side cursor never fills, so it would yield no rows at all.
        """
        return iter(self.fetchone, None)
260
261
class AsyncAdapt_dbapi_connection(AdaptedConnection):
    """Blocking :pep:`249` connection facade over an asyncio driver
    connection.

    Blocking calls are bridged to the driver's coroutines with ``await_()``,
    and driver errors are passed to :meth:`._handle_exception`, which here
    re-raises them with the active traceback.  ``_execute_mutex`` is the
    lock this connection's cursors hold while executing statements.
    """

    _cursor_cls = AsyncAdapt_dbapi_cursor
    _ss_cursor_cls = AsyncAdapt_dbapi_ss_cursor

    __slots__ = ("dbapi", "_execute_mutex")

    _connection: AsyncIODBAPIConnection

    def __init__(self, dbapi: Any, connection: AsyncIODBAPIConnection):
        self.dbapi = dbapi
        self._connection = connection
        self._execute_mutex = asyncio.Lock()

    def cursor(self, server_side: bool = False) -> AsyncAdapt_dbapi_cursor:
        """Return a new buffered cursor, or a server-side one on request."""
        if server_side:
            return self._ss_cursor_cls(self)
        else:
            return self._cursor_cls(self)

    def execute(
        self,
        operation: Any,
        parameters: Optional[_DBAPISingleExecuteParams] = None,
    ) -> Any:
        """Execute ``operation`` on a fresh cursor and return that cursor.

        Lots of DBAPIs seem to provide this shortcut, so include it.
        """
        cursor = self.cursor()
        cursor.execute(operation, parameters)
        return cursor

    def _handle_exception(self, error: Exception) -> NoReturn:
        """Re-raise ``error`` with the traceback of the exception currently
        being handled."""
        exc_info = sys.exc_info()

        raise error.with_traceback(exc_info[2])

    def rollback(self) -> None:
        try:
            await_(self._connection.rollback())
        except Exception as error:
            self._handle_exception(error)

    def commit(self) -> None:
        try:
            await_(self._connection.commit())
        except Exception as error:
            self._handle_exception(error)

    def close(self) -> None:
        """Close the driver connection.

        Errors are routed through :meth:`._handle_exception`, the same as
        for :meth:`.commit` and :meth:`.rollback`.
        """
        try:
            await_(self._connection.close())
        except Exception as error:
            self._handle_exception(error)