Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/aiosqlite/core.py: 33%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

227 statements  

1# Copyright Amethyst Reese 

2# Licensed under the MIT license 

3 

4""" 

5Core implementation of aiosqlite proxies 

6""" 

7 

8import asyncio 

9import logging 

10import sqlite3 

11from collections.abc import AsyncIterator, Generator, Iterable 

12from functools import partial 

13from pathlib import Path 

14from queue import Empty, Queue, SimpleQueue 

15from threading import Thread 

16from typing import Any, Callable, Literal, Optional, Union 

17from warnings import warn 

18 

19from .context import contextmanager 

20from .cursor import Cursor 

21 

# Public names exported by ``from aiosqlite.core import *``.
__all__ = [
    "connect",
    "Connection",
    "Cursor",
]

23 

# Signature of a callback accepted by ``Connection.set_authorizer``:
# ``(action_code, arg1, arg2, db_name, trigger_name) -> int``.
# sqlite3 passes ``None`` for any of the four string arguments that do not
# apply to the given action code (e.g. ``trigger_name`` for statements that
# do not come from a trigger or view), so they are typed as Optional.
AuthorizerCallback = Callable[
    [int, Optional[str], Optional[str], Optional[str], Optional[str]], int
]

# Module logger; all diagnostics from the connection proxy go here.
LOG = logging.getLogger("aiosqlite")


# Values accepted by ``Connection.isolation_level`` (forwarded unchanged to
# ``sqlite3.Connection.isolation_level``).
IsolationLevel = Optional[Literal["DEFERRED", "IMMEDIATE", "EXCLUSIVE"]]

30 

31 

def set_result(fut: asyncio.Future, result: Any) -> None:
    """
    Resolve *fut* with *result*, unless it already holds an outcome.

    Futures that are already finished (resolved, failed, or cancelled) are
    left untouched, so this never raises ``InvalidStateError``.
    """
    if fut.done():
        return
    fut.set_result(result)

36 

37 

def set_exception(fut: asyncio.Future, e: BaseException) -> None:
    """
    Fail *fut* with exception *e*, unless it already holds an outcome.

    Futures that are already finished (resolved, failed, or cancelled) are
    left untouched, so this never raises ``InvalidStateError``.
    """
    if fut.done():
        return
    fut.set_exception(e)

42 

43 

# Unique marker object: when a queued job returns exactly this object, the
# worker thread loop exits.
_STOP_RUNNING_SENTINEL = object()

# One queued job: an optional future to resolve with the outcome, plus the
# zero-argument callable to run on the worker thread.
_TxItem = tuple[Optional[asyncio.Future], Callable[[], Any]]
_TxQueue = SimpleQueue[_TxItem]

46 

47 

def _connection_worker_thread(tx: _TxQueue):
    """
    Execute function calls on a separate thread.

    Takes ``(future, function)`` jobs from *tx* one at a time, in FIFO order,
    and runs ``function()``. The outcome (return value or raised exception)
    is handed to ``future`` on the future's own event loop via
    ``call_soon_threadsafe``; ``future`` may be ``None`` when nobody awaits
    the job. The loop exits after a job returns ``_STOP_RUNNING_SENTINEL``.

    :meta private:
    """

    def _deliver(future, setter, value) -> None:
        # Schedule ``setter(future, value)`` on the loop that owns ``future``.
        # If that loop is already closed, call_soon_threadsafe raises
        # RuntimeError and nobody can observe the result any more. Letting
        # that error propagate would kill this thread and strand every later
        # job in ``tx``, so it is logged and dropped instead.
        if future is None:
            return
        try:
            future.get_loop().call_soon_threadsafe(setter, future, value)
        except RuntimeError:
            LOG.debug("dropping outcome for %s: event loop is closed", future)

    while True:
        # Continues running until all queue items are processed,
        # even after connection is closed (so we can finalize all
        # futures)

        future, function = tx.get()

        try:
            LOG.debug("executing %s", function)
            result = function()
        except BaseException as e:  # noqa B036
            # Forward everything (including KeyboardInterrupt/SystemExit)
            # to the awaiting coroutine rather than dying silently here.
            LOG.debug("returning exception %s", e)
            _deliver(future, set_exception, e)
            continue

        _deliver(future, set_result, result)
        LOG.debug("operation %s completed", function)

        if result is _STOP_RUNNING_SENTINEL:
            break

76 

77 

class Connection:
    """
    Async proxy around a :class:`sqlite3.Connection`.

    Most operations are wrapped in a callable and queued to a dedicated
    worker thread (``_connection_worker_thread``) that opens and owns the
    sqlite3 connection; the caller awaits an :class:`asyncio.Future` that the
    worker resolves. The simple attribute properties and :meth:`interrupt`
    access the sqlite3 connection directly from the calling thread instead.

    Instances are normally created by :func:`connect` and opened with
    ``await`` or ``async with``.
    """

    def __init__(
        self,
        connector: Callable[[], sqlite3.Connection],
        iter_chunk_size: int,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        """
        :param connector: Zero-argument callable that opens the sqlite3
            connection; it is executed on the worker thread by ``_connect``.
        :param iter_chunk_size: Stored as ``_iter_chunk_size``; not used by
            this class itself (presumably read by :class:`Cursor` — verify).
        :param loop: Deprecated and ignored; passing it emits a
            :class:`DeprecationWarning`.
        """
        self._running = True
        self._connection: Optional[sqlite3.Connection] = None
        self._connector = connector
        self._tx: _TxQueue = SimpleQueue()
        self._iter_chunk_size = iter_chunk_size
        # Created here but only started on the first ``await`` (see __await__).
        self._thread = Thread(target=_connection_worker_thread, args=(self._tx,))

        if loop is not None:
            warn(
                "aiosqlite.Connection no longer uses the `loop` parameter",
                DeprecationWarning,
            )

    def __del__(self):
        # Never connected, or already closed: nothing to clean up.
        if self._connection is None:
            return

        warn(
            (
                f"{self!r} was deleted before being closed. "
                "Please use 'async with' or '.close()' to close the connection properly."
            ),
            ResourceWarning,
            stacklevel=1,
        )

        # Don't try to be creative here, the event loop may have already been closed.
        # Simply stop the worker thread, and let the underlying sqlite3 connection
        # be finalized by its own __del__.
        self.stop()

    def stop(self) -> Optional[asyncio.Future]:
        """
        Stop the background thread. Prefer `async with` or `await close()`

        Marks the proxy as not running and queues a final job that closes the
        sqlite3 connection (if still open) and makes the worker thread exit.
        Jobs queued before it are still processed first.

        :returns: A future resolved by the worker once the stop job has run,
            or ``None`` if ``asyncio.get_event_loop()`` raised.
        """
        self._running = False

        def close_and_stop():
            if self._connection is not None:
                self._connection.close()
                self._connection = None
            return _STOP_RUNNING_SENTINEL

        try:
            future = asyncio.get_event_loop().create_future()
        except Exception:
            future = None

        self._tx.put_nowait((future, close_and_stop))
        return future

    @property
    def _conn(self) -> sqlite3.Connection:
        """The open sqlite3 connection; raises ValueError if there is none."""
        if self._connection is None:
            raise ValueError("no active connection")

        return self._connection

    def _execute_insert(self, sql: str, parameters: Any) -> Optional[sqlite3.Row]:
        """Worker-thread helper: run *sql*, then fetch ``last_insert_rowid()``."""
        cursor = self._conn.execute(sql, parameters)
        cursor.execute("SELECT last_insert_rowid()")
        return cursor.fetchone()

    def _execute_fetchall(self, sql: str, parameters: Any) -> Iterable[sqlite3.Row]:
        """Worker-thread helper: run *sql* and return every resulting row."""
        cursor = self._conn.execute(sql, parameters)
        return cursor.fetchall()

    async def _execute(self, fn, *args, **kwargs):
        """
        Queue ``fn(*args, **kwargs)`` for the worker thread and await it.

        :raises ValueError: if the proxy was stopped or is not connected.
        Any exception raised by ``fn`` on the worker thread is re-raised here.
        """
        if not self._running or not self._connection:
            raise ValueError("Connection closed")

        function = partial(fn, *args, **kwargs)
        future = asyncio.get_running_loop().create_future()

        self._tx.put_nowait((future, function))

        return await future

    async def _connect(self) -> "Connection":
        """Connect to the actual sqlite database (no-op if already connected)."""
        if self._connection is None:
            try:
                future = asyncio.get_running_loop().create_future()
                self._tx.put_nowait((future, self._connector))
                self._connection = await future
            except BaseException:
                # Includes cancellation: shut the worker down rather than
                # leaving a half-initialised proxy with a live thread.
                self.stop()
                self._connection = None
                raise

        return self

    def __await__(self) -> Generator[Any, None, "Connection"]:
        # A Thread can only be started once. When the worker is already
        # serving this still-running proxy (e.g. ``db = await connect(...)``
        # followed by ``async with db:``), skip the start; _connect() then
        # returns immediately. In every other case call start() as before,
        # which raises RuntimeError if the thread was already used.
        if self._thread.ident is None or not self._running:
            self._thread.start()
        return self._connect().__await__()

    async def __aenter__(self) -> "Connection":
        return await self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        await self.close()

    @contextmanager
    async def cursor(self) -> Cursor:
        """Create an aiosqlite cursor wrapping a sqlite3 cursor object."""
        return Cursor(self, await self._execute(self._conn.cursor))

    async def commit(self) -> None:
        """Commit the current transaction."""
        await self._execute(self._conn.commit)

    async def rollback(self) -> None:
        """Roll back the current transaction."""
        await self._execute(self._conn.rollback)

    async def close(self) -> None:
        """
        Complete queued queries/cursors and close the connection.

        Does nothing if the proxy is not connected. The close job is queued
        behind any pending jobs; afterwards the worker thread is stopped and,
        when possible, awaited until it has processed the stop request.
        """

        if self._connection is None:
            return

        try:
            await self._execute(self._conn.close)
        except Exception:
            LOG.info("exception occurred while closing connection")
            raise
        finally:
            self._connection = None
            # Only enqueue a stop request if nobody has done so yet. A second
            # request (concurrent close() calls, or close() after stop())
            # would sit behind a worker that has already exited, and awaiting
            # its future would hang forever.
            if self._running:
                future = self.stop()
                if future:
                    await future

    @contextmanager
    async def execute(
        self, sql: str, parameters: Optional[Iterable[Any]] = None
    ) -> Cursor:
        """Helper to create a cursor and execute the given query."""
        if parameters is None:
            parameters = []
        cursor = await self._execute(self._conn.execute, sql, parameters)
        return Cursor(self, cursor)

    @contextmanager
    async def execute_insert(
        self, sql: str, parameters: Optional[Iterable[Any]] = None
    ) -> Optional[sqlite3.Row]:
        """Helper to insert and get the last_insert_rowid."""
        if parameters is None:
            parameters = []
        return await self._execute(self._execute_insert, sql, parameters)

    @contextmanager
    async def execute_fetchall(
        self, sql: str, parameters: Optional[Iterable[Any]] = None
    ) -> Iterable[sqlite3.Row]:
        """Helper to execute a query and return all the data."""
        if parameters is None:
            parameters = []
        return await self._execute(self._execute_fetchall, sql, parameters)

    @contextmanager
    async def executemany(
        self, sql: str, parameters: Iterable[Iterable[Any]]
    ) -> Cursor:
        """Helper to create a cursor and execute the given multiquery."""
        cursor = await self._execute(self._conn.executemany, sql, parameters)
        return Cursor(self, cursor)

    @contextmanager
    async def executescript(self, sql_script: str) -> Cursor:
        """Helper to create a cursor and execute a user script."""
        cursor = await self._execute(self._conn.executescript, sql_script)
        return Cursor(self, cursor)

    async def interrupt(self) -> None:
        """
        Interrupt pending queries.

        Called directly on the sqlite3 connection instead of through the
        worker queue: a queued job would only run after the query it is meant
        to interrupt had finished.
        """
        return self._conn.interrupt()

    async def create_function(
        self, name: str, num_params: int, func: Callable, deterministic: bool = False
    ) -> None:
        """
        Create user-defined function that can be later used
        within SQL statements. Must be run within the same thread
        that query executions take place so instead of executing directly
        against the connection, we defer this to `run` function.

        If ``deterministic`` is true, the created function is marked as deterministic,
        which allows SQLite to perform additional optimizations. This flag is supported
        by SQLite 3.8.3 or higher, ``NotSupportedError`` will be raised if used with
        older versions.
        """
        await self._execute(
            self._conn.create_function,
            name,
            num_params,
            func,
            deterministic=deterministic,
        )

    @property
    def in_transaction(self) -> bool:
        """Mirror of ``sqlite3.Connection.in_transaction``."""
        return self._conn.in_transaction

    @property
    def isolation_level(self) -> Optional[str]:
        """Mirror of ``sqlite3.Connection.isolation_level`` (read/write)."""
        return self._conn.isolation_level

    @isolation_level.setter
    def isolation_level(self, value: IsolationLevel) -> None:
        self._conn.isolation_level = value

    @property
    def row_factory(self) -> Optional[type]:
        """Mirror of ``sqlite3.Connection.row_factory`` (read/write)."""
        return self._conn.row_factory

    @row_factory.setter
    def row_factory(self, factory: Optional[type]) -> None:
        self._conn.row_factory = factory

    @property
    def text_factory(self) -> Callable[[bytes], Any]:
        """Mirror of ``sqlite3.Connection.text_factory`` (read/write)."""
        return self._conn.text_factory

    @text_factory.setter
    def text_factory(self, factory: Callable[[bytes], Any]) -> None:
        self._conn.text_factory = factory

    @property
    def total_changes(self) -> int:
        """Mirror of ``sqlite3.Connection.total_changes``."""
        return self._conn.total_changes

    async def enable_load_extension(self, value: bool) -> None:
        """Enable or disable SQLite extension loading on the worker thread."""
        await self._execute(self._conn.enable_load_extension, value)  # type: ignore

    async def load_extension(self, path: str):
        """Load the SQLite extension at *path* on the worker thread."""
        await self._execute(self._conn.load_extension, path)  # type: ignore

    async def set_progress_handler(
        self, handler: Callable[[], Optional[int]], n: int
    ) -> None:
        """Register *handler* via ``sqlite3.Connection.set_progress_handler``."""
        await self._execute(self._conn.set_progress_handler, handler, n)

    async def set_trace_callback(self, handler: Callable) -> None:
        """Register *handler* via ``sqlite3.Connection.set_trace_callback``."""
        await self._execute(self._conn.set_trace_callback, handler)

    async def set_authorizer(
        self, authorizer_callback: Optional[AuthorizerCallback]
    ) -> None:
        """
        Set an authorizer callback to control database access.

        The authorizer callback is invoked for each SQL statement that is prepared,
        and controls whether specific operations are permitted.

        Example::

            import sqlite3

            def restrict_drops(action_code, arg1, arg2, db_name, trigger_name):
                # Deny all DROP operations
                if action_code == sqlite3.SQLITE_DROP_TABLE:
                    return sqlite3.SQLITE_DENY
                # Allow everything else
                return sqlite3.SQLITE_OK

            await conn.set_authorizer(restrict_drops)

        See ``sqlite3`` documentation for details:
        https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.set_authorizer

        :param authorizer_callback: An optional callable that receives five arguments:

            - ``action_code`` (int): The action to be authorized (e.g., ``SQLITE_READ``)
            - ``arg1`` (str): First argument, meaning depends on ``action_code``
            - ``arg2`` (str): Second argument, meaning depends on ``action_code``
            - ``db_name`` (str): Database name (e.g., ``"main"``, ``"temp"``)
            - ``trigger_name`` (str): Name of trigger or view that is doing the access,
              or ``None``

            The callback should return:

            - ``SQLITE_OK`` (0): Allow the operation
            - ``SQLITE_DENY`` (1): Deny the operation, raise ``sqlite3.DatabaseError``
            - ``SQLITE_IGNORE`` (2): Treat operation as no-op

            Pass ``None`` to remove the authorizer.
        """
        await self._execute(self._conn.set_authorizer, authorizer_callback)

    async def iterdump(self) -> AsyncIterator[str]:
        """
        Return an async iterator to dump the database in SQL text format.

        Example::

            async for line in db.iterdump():
                ...

        """
        # Thread-safe hand-off: the worker thread produces lines, this
        # coroutine polls for them. ``None`` marks the end of the stream.
        dump_queue: Queue = Queue()

        def dumper():
            try:
                for line in self._conn.iterdump():
                    dump_queue.put_nowait(line)
                dump_queue.put_nowait(None)

            except Exception:
                LOG.exception("exception while dumping db")
                dump_queue.put_nowait(None)
                raise

        fut = self._execute(dumper)
        task = asyncio.ensure_future(fut)

        while True:
            try:
                line: Optional[str] = dump_queue.get_nowait()
                if line is None:
                    break
                yield line

            except Empty:
                if task.done():
                    LOG.warning("iterdump completed unexpectedly")
                    break

                await asyncio.sleep(0.01)

        # Surface any exception raised by dumper() on the worker thread.
        await task

    async def backup(
        self,
        target: Union["Connection", sqlite3.Connection],
        *,
        pages: int = 0,
        progress: Optional[Callable[[int, int, int], None]] = None,
        name: str = "main",
        sleep: float = 0.250,
    ) -> None:
        """
        Make a backup of the current database to the target database.

        Takes either a standard sqlite3 or aiosqlite Connection object as the target.
        """
        if isinstance(target, Connection):
            target = target._conn

        await self._execute(
            self._conn.backup,
            target,
            pages=pages,
            progress=progress,
            name=name,
            sleep=sleep,
        )

441 

442 

def connect(
    database: Union[str, bytes, Path],
    *,
    iter_chunk_size: int = 64,
    loop: Optional[asyncio.AbstractEventLoop] = None,
    **kwargs: Any,
) -> Connection:
    """
    Create and return a connection proxy to the sqlite database.

    The database is not opened here; it is opened on the proxy's worker
    thread when the returned object is awaited or entered with ``async with``.

    :param database: Path of the database as ``str``, ``bytes`` or
        :class:`~pathlib.Path`. Any other object is converted with ``str()``.
    :param iter_chunk_size: Passed through to :class:`Connection`.
    :param loop: Deprecated and ignored; passing it emits a
        :class:`DeprecationWarning`.
    :param kwargs: Extra keyword arguments forwarded to :func:`sqlite3.connect`.
    """
    import os

    if loop is not None:
        warn(
            "aiosqlite.connect() no longer uses the `loop` parameter",
            DeprecationWarning,
        )

    def connector() -> sqlite3.Connection:
        # Executed on the Connection's worker thread by Connection._connect().
        if isinstance(database, str):
            loc = database
        elif isinstance(database, bytes):
            # sqlite3.connect re-encodes str paths with the filesystem
            # encoding, so decoding with os.fsdecode (surrogateescape)
            # round-trips to exactly the original bytes. A strict UTF-8
            # decode would fail, or point at a different file, on non-UTF-8
            # filesystems.
            loc = os.fsdecode(database)
        else:
            loc = str(database)

        return sqlite3.connect(loc, **kwargs)

    return Connection(connector, iter_chunk_size)