1# dialects/mysql/aiomysql.py
2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors <see AUTHORS
3# file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7r"""
8.. dialect:: mysql+aiomysql
9 :name: aiomysql
10 :dbapi: aiomysql
11 :connectstring: mysql+aiomysql://user:password@host:port/dbname[?key=value&key=value...]
12 :url: https://github.com/aio-libs/aiomysql
13
14The aiomysql dialect is SQLAlchemy's second Python asyncio dialect.
15
16Using a special asyncio mediation layer, the aiomysql dialect is usable
17as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
18extension package.
19
20This dialect should normally be used only with the
21:func:`_asyncio.create_async_engine` engine creation function::
22
23 from sqlalchemy.ext.asyncio import create_async_engine
24 engine = create_async_engine("mysql+aiomysql://user:pass@hostname/dbname?charset=utf8mb4")
25
26
27""" # noqa
28
import collections

from .pymysql import MySQLDialect_pymysql
from ... import pool
from ... import util
from ...engine import AdaptedConnection
from ...util.concurrency import asyncio
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only
36
37
class AsyncAdapt_aiomysql_cursor:
    """Buffered, synchronous DBAPI cursor facade over an aiomysql cursor.

    ``execute()`` drains the driver's result into a local
    :class:`collections.deque`. ``fetchone()``, ``fetchmany()``,
    ``fetchall()`` and iteration then consume rows from its left end, so
    reading N rows one at a time costs O(N) rather than the O(N**2) that
    repeated ``list.pop(0)`` calls would cost.
    """

    server_side = False
    __slots__ = (
        "_adapt_connection",
        "_connection",
        "await_",
        "_cursor",
        "_rows",
    )

    def __init__(self, adapt_connection):
        self._adapt_connection = adapt_connection
        self._connection = adapt_connection._connection
        self.await_ = adapt_connection.await_

        cursor = self._connection.cursor(adapt_connection.dbapi.Cursor)

        # see https://github.com/aio-libs/aiomysql/issues/543
        self._cursor = self.await_(cursor.__aenter__())
        # locally buffered result rows; a deque so rows pop from the front
        # in O(1)
        self._rows = collections.deque()

    @property
    def description(self):
        return self._cursor.description

    @property
    def rowcount(self):
        return self._cursor.rowcount

    @property
    def arraysize(self):
        return self._cursor.arraysize

    @arraysize.setter
    def arraysize(self, value):
        self._cursor.arraysize = value

    @property
    def lastrowid(self):
        return self._cursor.lastrowid

    def close(self):
        # note we aren't actually closing the cursor here,
        # we are just letting GC do it. to allow this to be async
        # we would need the Result to change how it does "Safe close cursor".
        # MySQL "cursors" don't actually have state to be "closed" besides
        # exhausting rows, which we already have done for sync cursor.
        # another option would be to emulate aiosqlite dialect and assign
        # cursor only if we are doing server side cursor operation.
        self._rows.clear()

    def execute(self, operation, parameters=None):
        """Execute *operation* and return the driver's execute() result.

        For this (non server-side) cursor, all result rows are buffered
        locally before returning.
        """
        return self.await_(self._execute_async(operation, parameters))

    def executemany(self, operation, seq_of_parameters):
        """Execute *operation* once per parameter set; returns driver result."""
        return self.await_(
            self._executemany_async(operation, seq_of_parameters)
        )

    async def _execute_async(self, operation, parameters):
        # The connection-wide mutex keeps statements issued through the
        # same adapted connection from overlapping.
        async with self._adapt_connection._execute_mutex:
            result = await self._cursor.execute(operation, parameters)

            if not self.server_side:
                # aiomysql has a "fake" async result, so we have to pull it out
                # of that here since our default result is not async.
                # we could just as easily grab "_rows" here and be done with it
                # but this is safer.
                self._rows = collections.deque(
                    await self._cursor.fetchall()
                )
            return result

    async def _executemany_async(self, operation, seq_of_parameters):
        async with self._adapt_connection._execute_mutex:
            return await self._cursor.executemany(
                operation, seq_of_parameters
            )

    def setinputsizes(self, *inputsizes):
        # DBAPI hook; intentionally a no-op for this adapter.
        pass

    def __iter__(self):
        while self._rows:
            yield self._rows.popleft()

    def fetchone(self):
        """Return the next buffered row, or None when exhausted."""
        if self._rows:
            return self._rows.popleft()
        else:
            return None

    def fetchmany(self, size=None):
        """Return up to *size* buffered rows (default: ``arraysize``)."""
        if size is None:
            size = self.arraysize

        rows = self._rows
        return [rows.popleft() for _ in range(min(size, len(rows)))]

    def fetchall(self):
        """Return all remaining buffered rows as a list and empty the buffer."""
        retval = list(self._rows)
        self._rows.clear()
        return retval
138
139
class AsyncAdapt_aiomysql_ss_cursor(AsyncAdapt_aiomysql_cursor):
    """Server-side (unbuffered) cursor adapter over aiomysql's ``SSCursor``.

    Rows are not buffered locally. Every fetch awaits the driver cursor
    directly.
    """

    __slots__ = ()
    server_side = True

    def __init__(self, adapt_connection):
        self._adapt_connection = adapt_connection
        self._connection = adapt_connection._connection
        self.await_ = adapt_connection.await_

        cursor = self._connection.cursor(adapt_connection.dbapi.SSCursor)

        self._cursor = self.await_(cursor.__aenter__())

    def close(self):
        """Close the driver cursor once; later calls are no-ops."""
        if self._cursor is not None:
            self.await_(self._cursor.close())
            self._cursor = None

    def __iter__(self):
        # The inherited __iter__ drains the local ``_rows`` buffer. This
        # class never assigns that slot (neither __init__ nor the
        # server-side execute path fills it), so iteration must stream from
        # the driver instead. Per DB-API, fetchone() returns None once the
        # result set is exhausted.
        return iter(self.fetchone, None)

    def fetchone(self):
        return self.await_(self._cursor.fetchone())

    def fetchmany(self, size=None):
        return self.await_(self._cursor.fetchmany(size=size))

    def fetchall(self):
        return self.await_(self._cursor.fetchall())
166
167
class AsyncAdapt_aiomysql_connection(AdaptedConnection):
    """Synchronous, DBAPI-style facade over an ``aiomysql`` connection.

    Coroutine-returning driver calls are resolved through ``await_``
    (``await_only`` for this class).
    """

    await_ = staticmethod(await_only)
    __slots__ = ("dbapi", "_connection", "_execute_mutex")

    def __init__(self, dbapi, connection):
        self.dbapi = dbapi
        self._connection = connection
        # Acquired by the cursor adapters around execute()/executemany(),
        # so statements issued through this connection never overlap.
        self._execute_mutex = asyncio.Lock()

    def ping(self, reconnect=False):
        """Ping the server through the driver's ``ping()`` coroutine.

        :param reconnect: forwarded to the driver. It defaults to ``False``
         so the adapter can also be called as ``ping()`` with no argument
         (e.g. by a pre-ping routine that omits the flag) without raising
         ``TypeError``, and without ever asking the driver to reconnect
         implicitly.
        """
        return self.await_(self._connection.ping(reconnect))

    def character_set_name(self):
        return self._connection.character_set_name()

    def autocommit(self, value):
        self.await_(self._connection.autocommit(value))

    def cursor(self, server_side=False):
        """Return a streaming cursor if *server_side*, else a buffered one."""
        if server_side:
            return AsyncAdapt_aiomysql_ss_cursor(self)
        else:
            return AsyncAdapt_aiomysql_cursor(self)

    def rollback(self):
        self.await_(self._connection.rollback())

    def commit(self):
        self.await_(self._connection.commit())

    def close(self):
        # the driver's close() is not awaitable, so it is called directly
        # rather than through await_.
        self._connection.close()
201
202
class AsyncAdaptFallback_aiomysql_connection(AsyncAdapt_aiomysql_connection):
    """Connection adapter used when ``async_fallback`` is requested.

    Behaves exactly like its parent, except that driver coroutines are
    resolved with ``await_fallback`` rather than ``await_only``.
    """

    await_ = staticmethod(await_fallback)
    __slots__ = ()
207
208
class AsyncAdapt_aiomysql_dbapi:
    """DBAPI-module facade pairing ``aiomysql`` with ``pymysql``.

    The PEP 249 exception classes come from ``aiomysql``. The type objects
    and the ``Binary`` constructor come from ``pymysql``. ``connect()``
    returns a synchronous adapter around a new ``aiomysql`` connection.
    """

    def __init__(self, aiomysql, pymysql):
        self.aiomysql = aiomysql
        self.pymysql = pymysql
        self.paramstyle = "format"
        self._init_dbapi_attributes()
        self.Cursor, self.SSCursor = self._init_cursors_subclasses()

    def _init_dbapi_attributes(self):
        """Expose the DBAPI exception classes and type objects on self."""
        # the ten PEP 249 exception classes, each listed exactly once
        for name in (
            "Warning",
            "Error",
            "InterfaceError",
            "DataError",
            "DatabaseError",
            "OperationalError",
            "IntegrityError",
            "ProgrammingError",
            "InternalError",
            "NotSupportedError",
        ):
            setattr(self, name, getattr(self.aiomysql, name))

        for name in (
            "NUMBER",
            "STRING",
            "DATETIME",
            "BINARY",
            "TIMESTAMP",
            "Binary",
        ):
            setattr(self, name, getattr(self.pymysql, name))

    def connect(self, *arg, **kw):
        """Open an aiomysql connection and wrap it in a sync adapter.

        ``async_fallback`` is popped from *kw* (so it is not forwarded to
        the driver). When truthy, the fallback adapter and ``await_fallback``
        are used; otherwise ``await_only`` is used.
        """
        async_fallback = kw.pop("async_fallback", False)

        if util.asbool(async_fallback):
            return AsyncAdaptFallback_aiomysql_connection(
                self,
                await_fallback(self.aiomysql.connect(*arg, **kw)),
            )
        else:
            return AsyncAdapt_aiomysql_connection(
                self,
                await_only(self.aiomysql.connect(*arg, **kw)),
            )

    def _init_cursors_subclasses(self):
        """Return ``(Cursor, SSCursor)`` subclasses of the aiomysql cursors."""
        # suppress unconditional warning emitted by aiomysql
        class Cursor(self.aiomysql.Cursor):
            async def _show_warnings(self, conn):
                pass

        class SSCursor(self.aiomysql.SSCursor):
            async def _show_warnings(self, conn):
                pass

        return Cursor, SSCursor
268
269
class MySQLDialect_aiomysql(MySQLDialect_pymysql):
    """MySQL dialect driving the ``aiomysql`` asyncio driver.

    Extends the PyMySQL dialect. It supplies an adapted DBAPI facade,
    async-adapted connection pools, a server-side cursor class, and
    aiomysql-specific connect-argument names and disconnect detection.
    """

    driver = "aiomysql"
    supports_statement_cache = True
    is_async = True

    supports_server_side_cursors = True
    _sscursor = AsyncAdapt_aiomysql_ss_cursor

    @classmethod
    def dbapi(cls):
        """Build the DBAPI facade from the ``aiomysql`` and ``pymysql`` modules."""
        aiomysql_module = __import__("aiomysql")
        pymysql_module = __import__("pymysql")
        return AsyncAdapt_aiomysql_dbapi(aiomysql_module, pymysql_module)

    @classmethod
    def get_pool_class(cls, url):
        """Choose the fallback pool when the URL sets a truthy ``async_fallback``."""
        use_fallback = util.asbool(url.query.get("async_fallback", False))
        return (
            pool.FallbackAsyncAdaptedQueuePool
            if use_fallback
            else pool.AsyncAdaptedQueuePool
        )

    def create_connect_args(self, url):
        # aiomysql expects "user" / "db" rather than "username" / "database"
        renamed_keys = dict(username="user", database="db")
        return super().create_connect_args(url, _translate_args=renamed_keys)

    def is_disconnect(self, e, connection, cursor):
        """Extend the parent check: a "not connected" error also counts."""
        if super().is_disconnect(e, connection, cursor):
            return True
        return "not connected" in str(e).lower()

    def _found_rows_client_flag(self):
        from pymysql.constants import CLIENT as client_flags

        return client_flags.FOUND_ROWS

    def get_driver_connection(self, connection):
        """Unwrap the adapter to reach the raw aiomysql connection."""
        return connection._connection
316
317
# Module-level alias exposing this module's dialect class under the
# conventional name ``dialect`` (presumably what SQLAlchemy's dialect
# loader looks up for "mysql+aiomysql" — confirm against the registry).
dialect = MySQLDialect_aiomysql