1# connectors/asyncio.py
2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8"""generic asyncio-adapted versions of DBAPI connection and cursor"""
9
10from __future__ import annotations
11
12import asyncio
13import collections
14import sys
15from typing import Any
16from typing import AsyncIterator
17from typing import Deque
18from typing import Iterator
19from typing import NoReturn
20from typing import Optional
21from typing import Protocol
22from typing import Sequence
23
24from ..engine import AdaptedConnection
25from ..engine.interfaces import _DBAPICursorDescription
26from ..engine.interfaces import _DBAPIMultiExecuteParams
27from ..engine.interfaces import _DBAPISingleExecuteParams
28from ..util.concurrency import await_
29from ..util.typing import Self
30
31
class AsyncIODBAPIConnection(Protocol):
    """Structural type for an asyncio-native :pep:`249` style connection.

    Obtaining a cursor is a plain method call; committing, rolling back
    and closing are coroutines that must be awaited.

    """

    def cursor(self) -> AsyncIODBAPICursor: ...

    async def commit(self) -> None: ...

    async def rollback(self) -> None: ...

    async def close(self) -> None: ...
46
47
class AsyncIODBAPICursor(Protocol):
    """Structural type for an asyncio-native :pep:`249` style cursor.

    Attribute-like members (``description``, ``rowcount``, ``arraysize``,
    ``lastrowid``) and ``setoutputsize()`` are synchronous.  Statement
    execution, fetching, result-set navigation, ``setinputsizes()`` and
    ``close()`` are coroutines.

    """

    arraysize: int

    lastrowid: int

    def __aenter__(self) -> Any: ...

    def __aiter__(self) -> AsyncIterator[Any]: ...

    @property
    def description(self) -> _DBAPICursorDescription:
        """The description attribute of the Cursor."""
        ...

    @property
    def rowcount(self) -> int: ...

    async def execute(
        self,
        operation: Any,
        parameters: Optional[_DBAPISingleExecuteParams] = None,
    ) -> Any: ...

    async def executemany(
        self, operation: Any, parameters: _DBAPIMultiExecuteParams
    ) -> Any: ...

    async def callproc(
        self, procname: str, parameters: Sequence[Any] = ...
    ) -> Any: ...

    async def fetchone(self) -> Optional[Any]: ...

    async def fetchmany(self, size: Optional[int] = ...) -> Sequence[Any]: ...

    async def fetchall(self) -> Sequence[Any]: ...

    async def nextset(self) -> Optional[bool]: ...

    async def setinputsizes(self, sizes: Sequence[Any]) -> None: ...

    def setoutputsize(self, size: Any, column: Any) -> None: ...

    async def close(self) -> None: ...
102
103
class AsyncAdapt_dbapi_cursor:
    """Blocking :pep:`249` cursor facade over an asyncio driver cursor.

    Each driver coroutine is driven to completion with ``await_``.  When
    ``server_side`` is False (as it is for this class), the result set
    produced by ``execute()`` or ``nextset()`` is fetched eagerly into the
    ``_rows`` deque, and ``fetchone()`` / ``fetchmany()`` / ``fetchall()``
    and iteration consume that local buffer.

    """

    server_side = False
    __slots__ = (
        "_adapt_connection",
        "_connection",
        "_cursor",
        "_rows",
    )

    _cursor: AsyncIODBAPICursor
    _adapt_connection: AsyncAdapt_dbapi_connection
    _connection: AsyncIODBAPIConnection
    _rows: Deque[Any]

    def __init__(self, adapt_connection: AsyncAdapt_dbapi_connection):
        self._adapt_connection = adapt_connection
        self._connection = adapt_connection._connection

        cursor = self._make_new_cursor(self._connection)
        self._cursor = self._aenter_cursor(cursor)

        # only buffered (non server-side) cursors keep a local row buffer
        if not self.server_side:
            self._rows = collections.deque()

    def _aenter_cursor(self, cursor: AsyncIODBAPICursor) -> AsyncIODBAPICursor:
        """Enter the driver cursor's async context; return what it yields.

        Errors are routed through the adapted connection's
        ``_handle_exception``.
        """
        try:
            return await_(cursor.__aenter__())  # type: ignore[no-any-return]
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def _make_new_cursor(
        self, connection: AsyncIODBAPIConnection
    ) -> AsyncIODBAPICursor:
        """Create a new driver-level cursor from ``connection``."""
        return connection.cursor()

    @property
    def description(self) -> Optional[_DBAPICursorDescription]:
        return self._cursor.description

    @property
    def rowcount(self) -> int:
        return self._cursor.rowcount

    @property
    def arraysize(self) -> int:
        return self._cursor.arraysize

    @arraysize.setter
    def arraysize(self, value: int) -> None:
        self._cursor.arraysize = value

    @property
    def lastrowid(self) -> int:
        return self._cursor.lastrowid

    def close(self) -> None:
        """Discard any buffered rows (the driver cursor is left open)."""
        # note we aren't actually closing the cursor here,
        # we are just letting GC do it. see notes in aiomysql dialect
        self._rows.clear()

    def execute(
        self,
        operation: Any,
        parameters: Optional[_DBAPISingleExecuteParams] = None,
    ) -> Any:
        """Execute one statement; return the driver's ``execute()`` result."""
        try:
            return await_(self._execute_async(operation, parameters))
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def executemany(
        self,
        operation: Any,
        seq_of_parameters: _DBAPIMultiExecuteParams,
    ) -> Any:
        """Execute a statement per parameter set; return the driver result."""
        try:
            return await_(
                self._executemany_async(operation, seq_of_parameters)
            )
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    async def _execute_async(
        self, operation: Any, parameters: Optional[_DBAPISingleExecuteParams]
    ) -> Any:
        """Run one statement while holding the connection's execute mutex.

        For buffered cursors the statement's result set (or an empty
        buffer, if it has none) replaces ``_rows``.
        """
        async with self._adapt_connection._execute_mutex:
            if parameters is None:
                result = await self._cursor.execute(operation)
            else:
                result = await self._cursor.execute(operation, parameters)

            if not self.server_side:
                await self._buffer_result_rows()
            return result

    async def _executemany_async(
        self,
        operation: Any,
        seq_of_parameters: _DBAPIMultiExecuteParams,
    ) -> Any:
        """Run ``executemany`` while holding the connection's execute mutex."""
        async with self._adapt_connection._execute_mutex:
            result = await self._cursor.executemany(
                operation, seq_of_parameters
            )
            if not self.server_side:
                # executemany() results are never buffered; drop rows left
                # over from an earlier statement so fetch*() can't return
                # them as if they belonged to this one
                self._rows = collections.deque()
            return result

    async def _buffer_result_rows(self) -> None:
        """Replace ``_rows`` with the driver cursor's current result set.

        When the current statement / set has no result set (falsy
        ``description``) the buffer is emptied, so rows remaining from a
        previous statement or set are discarded rather than returned.
        """
        if self._cursor.description:
            self._rows = collections.deque(await self._cursor.fetchall())
        else:
            self._rows = collections.deque()

    def nextset(self) -> Optional[bool]:
        """Advance the driver cursor to its next result set.

        Remaining rows of the current set are discarded; for buffered
        cursors the new set (if it has a description) is fetched into
        ``_rows``.

        :return: the value returned by the driver's ``nextset()``.  Under
         :pep:`249` that is ``None`` when there are no more sets and a
         true value otherwise.
        """
        try:
            has_next = await_(self._cursor.nextset())
            if not self.server_side:
                await_(self._buffer_result_rows())
            return has_next  # type: ignore[no-any-return]
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def setinputsizes(self, *inputsizes: Any) -> None:
        # NOTE: this is overridden in aioodbc; see
        # https://github.com/aio-libs/aioodbc/issues/451

        return await_(self._cursor.setinputsizes(*inputsizes))

    def __enter__(self) -> Self:
        return self

    def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
        self.close()

    def __iter__(self) -> Iterator[Any]:
        # consumes the buffer: rows yielded here are gone for fetch*()
        while self._rows:
            yield self._rows.popleft()

    def fetchone(self) -> Optional[Any]:
        """Pop and return the next buffered row, or None when exhausted."""
        if self._rows:
            return self._rows.popleft()
        else:
            return None

    def fetchmany(self, size: Optional[int] = None) -> Sequence[Any]:
        """Pop up to ``size`` buffered rows (default: ``arraysize``)."""
        if size is None:
            size = self.arraysize
        rr = self._rows
        return [rr.popleft() for _ in range(min(size, len(rr)))]

    def fetchall(self) -> Sequence[Any]:
        """Return all remaining buffered rows and empty the buffer."""
        retval = list(self._rows)
        self._rows.clear()
        return retval
245
246
class AsyncAdapt_dbapi_ss_cursor(AsyncAdapt_dbapi_cursor):
    """Server-side variant of the adapted cursor.

    No rows are buffered locally (``server_side = True``); every fetch is
    forwarded to the driver cursor and driven with ``await_``.
    """

    __slots__ = ()
    server_side = True

    def close(self) -> None:
        """Close the driver cursor once; later calls are no-ops."""
        driver_cursor = self._cursor
        if driver_cursor is None:
            return
        await_(driver_cursor.close())
        self._cursor = None  # type: ignore

    def fetchone(self) -> Optional[Any]:
        pending = self._cursor.fetchone()
        return await_(pending)

    def fetchmany(self, size: Optional[int] = None) -> Any:
        # size is forwarded as-is (including None) to the driver
        pending = self._cursor.fetchmany(size=size)
        return await_(pending)

    def fetchall(self) -> Sequence[Any]:
        pending = self._cursor.fetchall()
        return await_(pending)

    def __iter__(self) -> Iterator[Any]:
        # drive the driver's async iterator one row at a time until it
        # signals exhaustion with StopAsyncIteration
        driver_iter = self._cursor.__aiter__()
        try:
            while True:
                yield await_(driver_iter.__anext__())
        except StopAsyncIteration:
            pass
272
273
class AsyncAdapt_dbapi_connection(AdaptedConnection):
    """Blocking :pep:`249` connection facade over an asyncio driver
    connection.

    Driver coroutines are driven to completion with ``await_``.
    ``_execute_mutex`` is an :class:`asyncio.Lock`; the adapted cursor's
    execute paths acquire it around each statement.

    """

    _cursor_cls = AsyncAdapt_dbapi_cursor
    _ss_cursor_cls = AsyncAdapt_dbapi_ss_cursor

    __slots__ = ("dbapi", "_execute_mutex")

    _connection: AsyncIODBAPIConnection

    def __init__(self, dbapi: Any, connection: AsyncIODBAPIConnection):
        self.dbapi = dbapi
        self._connection = connection
        self._execute_mutex = asyncio.Lock()

    def cursor(self, server_side: bool = False) -> AsyncAdapt_dbapi_cursor:
        """Return a new cursor: server-side if ``server_side``, else buffered."""
        if server_side:
            return self._ss_cursor_cls(self)
        else:
            return self._cursor_cls(self)

    def execute(
        self,
        operation: Any,
        parameters: Optional[_DBAPISingleExecuteParams] = None,
    ) -> Any:
        """lots of DBAPIs seem to provide this, so include it

        Executes ``operation`` on a fresh cursor and returns that cursor.
        """
        cursor = self.cursor()
        cursor.execute(operation, parameters)
        return cursor

    def _handle_exception(self, error: Exception) -> NoReturn:
        """Re-raise ``error`` raised by an adapted driver call.

        Called from ``commit()``, ``rollback()`` and ``close()`` here and
        from the adapted cursors' execute paths.  The error is re-raised
        with the traceback of the exception currently being handled
        (from :func:`sys.exc_info`).  NOTE(review): presumably overridden
        by dialect-specific connections to translate driver errors.
        """
        exc_info = sys.exc_info()

        raise error.with_traceback(exc_info[2])

    def rollback(self) -> None:
        try:
            await_(self._connection.rollback())
        except Exception as error:
            self._handle_exception(error)

    def commit(self) -> None:
        try:
            await_(self._connection.commit())
        except Exception as error:
            self._handle_exception(error)

    def close(self) -> None:
        # route driver errors through _handle_exception, as commit() and
        # rollback() do, so any error translation applies here too
        try:
            await_(self._connection.close())
        except Exception as error:
            self._handle_exception(error)