Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/sqlalchemy/connectors/asyncio.py: 48%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

251 statements  

1# connectors/asyncio.py 

2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors 

3# <see AUTHORS file> 

4# 

5# This module is part of SQLAlchemy and is released under 

6# the MIT License: https://www.opensource.org/licenses/mit-license.php 

7 

8"""generic asyncio-adapted versions of DBAPI connection and cursor""" 

9 

10from __future__ import annotations 

11 

12import asyncio 

13import collections 

14import sys 

15import types 

16from typing import Any 

17from typing import AsyncIterator 

18from typing import Awaitable 

19from typing import Deque 

20from typing import Iterator 

21from typing import NoReturn 

22from typing import Optional 

23from typing import Protocol 

24from typing import Sequence 

25from typing import Tuple 

26from typing import Type 

27from typing import TYPE_CHECKING 

28 

29from ..engine import AdaptedConnection 

30from ..exc import EmulatedDBAPIException 

31from ..util import EMPTY_DICT 

32from ..util.concurrency import await_ 

33from ..util.concurrency import in_greenlet 

34 

35if TYPE_CHECKING: 

36 from ..engine.interfaces import _DBAPICursorDescription 

37 from ..engine.interfaces import _DBAPIMultiExecuteParams 

38 from ..engine.interfaces import _DBAPISingleExecuteParams 

39 from ..engine.interfaces import DBAPIModule 

40 from ..util.typing import Self 

41 

42 

class AsyncIODBAPIConnection(Protocol):
    """Structural type for an asyncio-adapted :pep:`249` connection.

    Only the members the adapter layer relies on are declared.  ``close()``
    is intentionally not listed: async DBAPIs disagree about whether it has
    to be awaited, so it is reached through the ``__getattr__`` hook.
    """

    def cursor(self, *args: Any, **kwargs: Any) -> AsyncIODBAPICursor:
        """Return a new driver cursor; this call itself is not awaited."""

    async def commit(self) -> None:
        """Commit the current transaction."""

    async def rollback(self) -> None:
        """Roll back the current transaction."""

    def __getattr__(self, key: str) -> Any:
        """Permit access to any other driver attribute (e.g. ``close``)."""

    def __setattr__(self, key: str, value: Any) -> None:
        """Permit assignment of arbitrary driver attributes."""

62 

63 

class AsyncIODBAPICursor(Protocol):
    """Structural type for an asyncio-adapted :pep:`249` cursor.

    Statement execution, fetching, ``close()``, ``callproc()``,
    ``nextset()`` and ``setinputsizes()`` are coroutines.  ``description``,
    ``rowcount``, ``arraysize``, ``lastrowid`` and ``setoutputsize()`` are
    plain synchronous members.
    """

    # ----- plain attributes ------------------------------------------------
    arraysize: int

    lastrowid: int

    # ----- read-only properties --------------------------------------------
    @property
    def description(
        self,
    ) -> _DBAPICursorDescription:
        """The description attribute of the Cursor."""

    @property
    def rowcount(self) -> int:
        """The :pep:`249` ``rowcount`` attribute."""

    # ----- async protocol hooks --------------------------------------------
    def __aenter__(self) -> Any:
        """Return the awaitable used to enter the cursor's async context."""

    def __aiter__(self) -> AsyncIterator[Any]:
        """Return an async iterator over result rows."""

    # ----- statement execution ---------------------------------------------
    async def execute(
        self,
        operation: Any,
        parameters: Optional[_DBAPISingleExecuteParams] = None,
    ) -> Any:
        """Execute one statement, optionally with parameters."""

    async def executemany(
        self,
        operation: Any,
        parameters: _DBAPIMultiExecuteParams,
    ) -> Any:
        """Execute one statement against a sequence of parameter sets."""

    async def callproc(
        self, procname: str, parameters: Sequence[Any] = ...
    ) -> Any:
        """Call a stored procedure."""

    async def setinputsizes(self, sizes: Sequence[Any]) -> None:
        """Predefine memory areas for parameters."""

    def setoutputsize(self, size: Any, column: Any) -> None:
        """Set a column buffer size for large columns."""

    # ----- result retrieval ------------------------------------------------
    async def fetchone(self) -> Optional[Any]:
        """Fetch the next row, or ``None`` when exhausted."""

    async def fetchmany(self, size: Optional[int] = ...) -> Sequence[Any]:
        """Fetch up to ``size`` rows."""

    async def fetchall(self) -> Sequence[Any]:
        """Fetch all remaining rows."""

    async def nextset(self) -> Optional[bool]:
        """Advance to the next result set, if any."""

    # ----- lifecycle -------------------------------------------------------
    async def close(self) -> None:
        """Close the cursor."""

118 

119 

class AsyncAdapt_dbapi_module:
    """Base for the module-like object an async dialect exposes as its DBAPI.

    Stores the asyncio ``driver`` module and, when the driver wraps a plain
    DBAPI, that ``dbapi_module`` as well.
    """

    if TYPE_CHECKING:
        # static typing only -- none of these names exist at runtime
        Error = DBAPIModule.Error
        OperationalError = DBAPIModule.OperationalError
        InterfaceError = DBAPIModule.InterfaceError
        IntegrityError = DBAPIModule.IntegrityError

        def __getattr__(self, key: str) -> Any: ...

    def __init__(
        self,
        driver: types.ModuleType,
        *,
        dbapi_module: types.ModuleType | None = None,
    ):
        self.dbapi_module = dbapi_module
        self.driver = driver

    @property
    def exceptions_module(self) -> types.ModuleType:
        """Return the module which we think will have the exception hierarchy.

        Drivers that wrap a plain DBAPI (aiomysql, aioodbc, aiosqlite, ...)
        get their exceptions from ``dbapi_module``; "pure" drivers such as
        asyncpg, constructed without a ``dbapi_module``, fall back to
        ``driver``.

        .. versionadded:: 2.1

        """
        wrapped = self.dbapi_module
        return self.driver if wrapped is None else wrapped

154 

155 

class AsyncAdapt_dbapi_cursor:
    """Synchronous :pep:`249` facade over an asyncio driver cursor.

    Every driver coroutine is driven to completion through ``await_()``.
    When ``server_side`` is False (this class), the rows of a statement are
    read into the ``_rows`` deque right after ``execute()`` / ``nextset()``,
    so ``fetchone()``, ``fetchmany()``, ``fetchall()`` and iteration only
    consume that in-memory buffer.
    """

    server_side = False
    __slots__ = (
        "_adapt_connection",
        "_connection",
        "_cursor",
        "_rows",
        "_soft_closed_memoized",
    )

    # when False, the driver's cursor.close() is called without being
    # awaited (see close())
    _awaitable_cursor_close: bool = True

    _cursor: AsyncIODBAPICursor
    _adapt_connection: AsyncAdapt_dbapi_connection
    _connection: AsyncIODBAPIConnection
    _rows: Deque[Any]

    def __init__(self, adapt_connection: AsyncAdapt_dbapi_connection):
        self._adapt_connection = adapt_connection
        self._connection = adapt_connection._connection

        cursor = self._make_new_cursor(self._connection)
        self._cursor = self._aenter_cursor(cursor)
        # filled in only by _async_soft_close(); once non-empty, close()
        # no longer touches the driver cursor
        self._soft_closed_memoized = EMPTY_DICT
        if not self.server_side:
            self._rows = collections.deque()

    def _aenter_cursor(self, cursor: AsyncIODBAPICursor) -> AsyncIODBAPICursor:
        """Enter the driver cursor's async context and return its result.

        Errors are routed through the connection's ``_handle_exception()``.
        """
        try:
            return await_(cursor.__aenter__())  # type: ignore[no-any-return]
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def _make_new_cursor(
        self, connection: AsyncIODBAPIConnection
    ) -> AsyncIODBAPICursor:
        """Create the driver cursor; hook for subclasses."""
        return connection.cursor()

    @property
    def description(self) -> Optional[_DBAPICursorDescription]:
        """The description memoized by a soft close, else the driver's."""
        if "description" in self._soft_closed_memoized:
            return self._soft_closed_memoized["description"]  # type: ignore[no-any-return] # noqa: E501
        return self._cursor.description

    @property
    def rowcount(self) -> int:
        return self._cursor.rowcount

    @property
    def arraysize(self) -> int:
        return self._cursor.arraysize

    @arraysize.setter
    def arraysize(self, value: int) -> None:
        self._cursor.arraysize = value

    @property
    def lastrowid(self) -> int:
        return self._cursor.lastrowid

    async def _async_soft_close(self) -> None:
        """close the cursor but keep the results pending, and memoize the
        description.

        Does nothing for server-side cursors or when the driver's
        ``close()`` is not awaitable.

        .. versionadded:: 2.0.44

        """

        if not self._awaitable_cursor_close or self.server_side:
            return

        self._soft_closed_memoized = self._soft_closed_memoized.union(
            {
                "description": self._cursor.description,
            }
        )
        await self._cursor.close()

    def close(self) -> None:
        """Discard buffered rows and close the driver cursor if possible."""
        self._rows.clear()

        # updated as of 2.0.44
        # try to "close" the cursor based on what we know about the driver
        # and if we are able to. otherwise, hope that the asyncio
        # extension called _async_soft_close() if the cursor is going into
        # a sync context
        if self._cursor is None or bool(self._soft_closed_memoized):
            return

        if not self._awaitable_cursor_close:
            self._cursor.close()  # type: ignore[unused-coroutine]
        elif in_greenlet():
            await_(self._cursor.close())

    def execute(
        self,
        operation: Any,
        parameters: Optional[_DBAPISingleExecuteParams] = None,
    ) -> Any:
        """Execute one statement; errors go through ``_handle_exception()``."""
        try:
            return await_(self._execute_async(operation, parameters))
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    def executemany(
        self,
        operation: Any,
        seq_of_parameters: _DBAPIMultiExecuteParams,
    ) -> Any:
        """Execute for many parameter sets; errors as in ``execute()``."""
        try:
            return await_(
                self._executemany_async(operation, seq_of_parameters)
            )
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    async def _execute_async(
        self, operation: Any, parameters: Optional[_DBAPISingleExecuteParams]
    ) -> Any:
        # statements on one connection are serialized by its mutex
        async with self._adapt_connection._execute_mutex:
            if parameters is None:
                result = await self._cursor.execute(operation)
            else:
                result = await self._cursor.execute(operation, parameters)

            # buffered cursors read the whole result while holding the lock
            if self._cursor.description and not self.server_side:
                self._rows = collections.deque(await self._cursor.fetchall())
            return result

    async def _executemany_async(
        self,
        operation: Any,
        seq_of_parameters: _DBAPIMultiExecuteParams,
    ) -> Any:
        async with self._adapt_connection._execute_mutex:
            return await self._cursor.executemany(operation, seq_of_parameters)

    def nextset(self) -> Optional[bool]:
        """Advance to the next result set.

        Returns whatever the driver's ``nextset()`` returned; per :pep:`249`
        that is a true value when another result set is available and
        ``None`` otherwise.  (Previously the value was discarded, so callers
        looping on ``while cursor.nextset():`` stopped after the first
        set.)  For a buffered cursor, a new result set that has a
        description is read into the row buffer.
        """
        result = await_(self._cursor.nextset())
        if self._cursor.description and not self.server_side:
            self._rows = collections.deque(await_(self._cursor.fetchall()))
        return result

    def setinputsizes(self, *inputsizes: Any) -> None:
        # NOTE: this is overridden for aioodbc because of
        # https://github.com/aio-libs/aioodbc/issues/451

        return await_(self._cursor.setinputsizes(*inputsizes))

    def __enter__(self) -> Self:
        return self

    def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
        self.close()

    def __iter__(self) -> Iterator[Any]:
        # consumes the buffer: each yielded row is removed from _rows
        while self._rows:
            yield self._rows.popleft()

    def fetchone(self) -> Optional[Any]:
        """Pop the next buffered row, or return ``None`` when empty."""
        if self._rows:
            return self._rows.popleft()
        else:
            return None

    def fetchmany(self, size: Optional[int] = None) -> Sequence[Any]:
        """Pop up to ``size`` buffered rows (default: ``arraysize``)."""
        if size is None:
            size = self.arraysize
        rr = self._rows
        return [rr.popleft() for _ in range(min(size, len(rr)))]

    def fetchall(self) -> Sequence[Any]:
        """Return every remaining buffered row and empty the buffer."""
        retval = list(self._rows)
        self._rows.clear()
        return retval

331 

332 

class AsyncAdapt_dbapi_ss_cursor(AsyncAdapt_dbapi_cursor):
    """Server-side cursor variant.

    Nothing is buffered locally: every fetch and every iteration step
    awaits the driver cursor directly.
    """

    __slots__ = ()
    server_side = True

    def close(self) -> None:
        """Close the driver cursor once; later calls do nothing."""
        live_cursor = self._cursor
        if live_cursor is None:
            return
        await_(live_cursor.close())
        self._cursor = None  # type: ignore

    def fetchone(self) -> Optional[Any]:
        """Await one row from the driver."""
        row = await_(self._cursor.fetchone())
        return row

    def fetchmany(self, size: Optional[int] = None) -> Any:
        """Await up to ``size`` rows from the driver (``size`` passed as-is)."""
        rows = await_(self._cursor.fetchmany(size=size))
        return rows

    def fetchall(self) -> Sequence[Any]:
        """Await every remaining row from the driver."""
        rows = await_(self._cursor.fetchall())
        return rows

    def __iter__(self) -> Iterator[Any]:
        # drive the driver's async iterator one step at a time
        source = self._cursor.__aiter__()
        while True:
            try:
                row = await_(source.__anext__())
                yield row
            except StopAsyncIteration:
                return

358 

359 

class AsyncAdapt_dbapi_connection(AdaptedConnection):
    """Synchronous :pep:`249` facade over an asyncio driver connection.

    Driver coroutines are driven through ``await_()``.  Statement execution
    by cursors created here is serialized by ``_execute_mutex``.
    """

    _cursor_cls = AsyncAdapt_dbapi_cursor
    _ss_cursor_cls = AsyncAdapt_dbapi_ss_cursor

    __slots__ = ("dbapi", "_execute_mutex")

    _connection: AsyncIODBAPIConnection

    @classmethod
    async def create(
        cls,
        dbapi: Any,
        connection_awaitable: Awaitable[AsyncIODBAPIConnection],
        **kw: Any,
    ) -> Self:
        """Await the driver's connect awaitable and wrap the connection.

        A failure while connecting is passed to
        ``_handle_exception_no_connection()`` (declared ``NoReturn``).
        """
        try:
            raw_connection = await connection_awaitable
        except Exception as error:
            cls._handle_exception_no_connection(dbapi, error)
        else:
            return cls(dbapi, raw_connection, **kw)

    def __init__(self, dbapi: Any, connection: AsyncIODBAPIConnection):
        self._execute_mutex = asyncio.Lock()
        self._connection = connection
        self.dbapi = dbapi

    def cursor(self, server_side: bool = False) -> AsyncAdapt_dbapi_cursor:
        """Return a new buffered cursor, or a server-side one on request."""
        factory = self._ss_cursor_cls if server_side else self._cursor_cls
        return factory(self)

    def execute(
        self,
        operation: Any,
        parameters: Optional[_DBAPISingleExecuteParams] = None,
    ) -> Any:
        """lots of DBAPIs seem to provide this, so include it"""
        new_cursor = self.cursor()
        new_cursor.execute(operation, parameters)
        return new_cursor

    @classmethod
    def _handle_exception_no_connection(
        cls, dbapi: Any, error: Exception
    ) -> NoReturn:
        """Re-raise ``error`` with the traceback of the exception being
        handled; ``dbapi`` is unused here (hook for subclasses)."""
        _, _, current_tb = sys.exc_info()
        raise error.with_traceback(current_tb)

    def _handle_exception(self, error: Exception) -> NoReturn:
        """Re-raise ``error`` via ``_handle_exception_no_connection()``."""
        self._handle_exception_no_connection(self.dbapi, error)

    def rollback(self) -> None:
        try:
            await_(self._connection.rollback())
        except Exception as error:
            self._handle_exception(error)

    def commit(self) -> None:
        try:
            await_(self._connection.commit())
        except Exception as error:
            self._handle_exception(error)

    def close(self) -> None:
        # the base adapter assumes the driver's close() is awaitable
        await_(self._connection.close())

428 

429 

class AsyncAdapt_terminate:
    """Mixin for an AsyncAdapt_dbapi_connection to add terminate support.

    Subclasses supply ``_terminate_graceful_close()`` and
    ``_terminate_force_close()``.
    """

    __slots__ = ()

    def terminate(self) -> None:
        if not in_greenlet():
            # not in a greenlet; this is the gc cleanup case
            self._terminate_force_close()
            return

        # in a greenlet; this is the connection was invalidated case.
        try:
            # try to gracefully close; see #10717
            await_(asyncio.shield(self._terminate_graceful_close()))
        except self._terminate_handled_exceptions() as caught:
            # recycling an old connection that may already be disconnected
            # makes close() fail; in that case terminate the connection
            # without any further waiting.  see issue #8419
            self._terminate_force_close()
            if isinstance(caught, asyncio.CancelledError):
                # a cancellation is still propagated to the caller
                raise

    def _terminate_handled_exceptions(self) -> Tuple[Type[BaseException], ...]:
        """Exceptions from ``_terminate_graceful_close()`` that trigger a
        forced close instead of propagating (except cancellation, which
        is re-raised after the forced close).
        """
        return (asyncio.TimeoutError, asyncio.CancelledError, OSError)

    async def _terminate_graceful_close(self) -> None:
        """Try to close connection gracefully (subclass hook)."""
        raise NotImplementedError

    def _terminate_force_close(self) -> None:
        """Terminate the connection (subclass hook)."""
        raise NotImplementedError

468 

469 

class AsyncAdapt_Error(EmulatedDBAPIException):
    """Root ``Error`` class for dialects that must emulate the :pep:`249`
    exception hierarchy themselves.

    .. versionadded:: 2.1

    """