Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/aiosqlite/core.py: 35%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

216 statements  

1# Copyright Amethyst Reese 

2# Licensed under the MIT license 

3 

4""" 

5Core implementation of aiosqlite proxies 

6""" 

7 

8import asyncio 

9import logging 

10import sqlite3 

11from collections.abc import AsyncIterator, Generator, Iterable 

12from functools import partial 

13from pathlib import Path 

14from queue import Empty, Queue, SimpleQueue 

15from threading import Thread 

16from typing import Any, Callable, Literal, Optional, Union 

17from warnings import warn 

18 

19from .context import contextmanager 

20from .cursor import Cursor 

21 

__all__ = ["connect", "Connection", "Cursor"]

# Type of the callback accepted by Connection.set_authorizer:
# (action_code, arg1, arg2, db_name, trigger_name) -> int, where the result is
# one of sqlite3.SQLITE_OK / SQLITE_DENY / SQLITE_IGNORE.
# NOTE(review): set_authorizer's docstring says trigger_name may be None, so
# these ``str`` parameters look narrower than what sqlite3 actually passes;
# widening them would change callers' type checking — confirm before editing.
AuthorizerCallback = Callable[[int, str, str, str, str], int]

# Module logger; the worker thread logs every queued call at DEBUG level.
LOG = logging.getLogger("aiosqlite")


# Values accepted by the Connection.isolation_level setter, which forwards
# them unchanged to the underlying sqlite3 connection.
IsolationLevel = Optional[Literal["DEFERRED", "IMMEDIATE", "EXCLUSIVE"]]

30 

31 

def set_result(fut: asyncio.Future, result: Any) -> None:
    """Resolve ``fut`` with ``result``; a future that already finished
    (resolved, failed or cancelled) is left untouched."""
    if fut.done():
        return
    fut.set_result(result)

36 

37 

def set_exception(fut: asyncio.Future, e: BaseException) -> None:
    """Fail ``fut`` with ``e``; a future that already finished (resolved,
    failed or cancelled) is left untouched."""
    if fut.done():
        return
    fut.set_exception(e)

42 

43 

# Unique marker returned by the stop request that Connection queues on the
# worker thread; when a call's result ``is`` this object the worker exits.
_STOP_RUNNING_SENTINEL = object()

45 

46 

def _deliver(
    future: asyncio.Future,
    setter: Callable[[asyncio.Future, Any], None],
    value: Any,
) -> None:
    """
    Schedule ``setter(future, value)`` on ``future``'s event loop, thread-safely.

    If that loop has already been closed, ``call_soon_threadsafe`` raises
    ``RuntimeError`` and the loop can no longer run the callback, so the
    outcome is dropped instead of letting the error escape and kill the
    worker thread.

    :meta private:
    """
    try:
        future.get_loop().call_soon_threadsafe(setter, future, value)
    except RuntimeError:
        LOG.debug("event loop closed, dropping outcome for %r", future)


def _connection_worker_thread(
    tx: SimpleQueue[tuple[asyncio.Future, Callable[[], Any]]],
):
    """
    Execute function calls on a separate thread.

    Takes ``(future, function)`` pairs from ``tx`` in FIFO order, runs
    ``function()`` and hands its result (or the exception it raised) back to
    ``future`` on that future's event loop. Exits after running a call whose
    result is ``_STOP_RUNNING_SENTINEL``.

    :meta private:
    """
    while True:
        # Continues running until all queue items are processed,
        # even after connection is closed (so we can finalize all
        # futures)
        future, function = tx.get()

        try:
            LOG.debug("executing %s", function)
            result = function()
            LOG.debug("operation %s completed", function)
        except BaseException as e:  # noqa B036
            LOG.debug("returning exception %s", e)
            _deliver(future, set_exception, e)
            continue

        _deliver(future, set_result, result)

        # Checked after delivery so the stop request's own future resolves,
        # and regardless of whether delivery succeeded, so the thread always
        # exits when asked to.
        if result is _STOP_RUNNING_SENTINEL:
            break

75 

76 

class _DetachedFuture:
    """
    Loop-free stand-in for the ``asyncio.Future`` paired with a queued call.

    The worker thread reports every outcome through
    ``future.get_loop().call_soon_threadsafe(setter, future, value)``. This
    object answers both calls and discards the outcome, which lets
    ``Connection.__del__`` queue a stop request without touching an event
    loop that may be closed or unset.

    :meta private:
    """

    def get_loop(self) -> "_DetachedFuture":
        return self

    def call_soon_threadsafe(self, *args: Any, **kwargs: Any) -> None:
        # Intentionally a no-op: nobody awaits the stop request from __del__.
        return None


class Connection:
    """
    Asyncio proxy for a :class:`sqlite3.Connection`.

    Every call that touches the underlying sqlite3 connection is queued to a
    single worker thread (``_connection_worker_thread``), and its result is
    delivered back through an ``asyncio`` future. Awaiting the instance (or
    entering it with ``async with``) starts the worker and opens the database.
    :meth:`close` (or leaving the ``async with`` block) closes it.
    """

    def __init__(
        self,
        connector: Callable[[], sqlite3.Connection],
        iter_chunk_size: int,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        # Set to False once a stop request has been queued; _execute then
        # refuses new work.
        self._running = True
        # Assigned by _connect() once the worker thread has run ``connector``.
        self._connection: Optional[sqlite3.Connection] = None
        self._connector = connector
        # FIFO of (future, zero-argument callable) pairs consumed by the worker.
        self._tx: SimpleQueue[tuple[asyncio.Future, Callable[[], Any]]] = SimpleQueue()
        # Not read in this class. NOTE(review): presumably used by Cursor to
        # fetch rows in batches while iterating — confirm in cursor.py.
        self._iter_chunk_size = iter_chunk_size
        # Started lazily by __await__; note it is not created with daemon=True.
        self._thread = Thread(target=_connection_worker_thread, args=(self._tx,))

        if loop is not None:
            warn(
                "aiosqlite.Connection no longer uses the `loop` parameter",
                DeprecationWarning,
            )

    def _stop_running(self) -> asyncio.Future:
        """
        Refuse new work and queue a request for the worker thread to exit.

        The returned future resolves once the worker has run the stop request,
        i.e. after every call that was queued before it.
        """
        self._running = False

        future = asyncio.get_event_loop().create_future()
        self._tx.put_nowait((future, lambda: _STOP_RUNNING_SENTINEL))

        return future

    @property
    def _conn(self) -> sqlite3.Connection:
        """
        The open sqlite3 connection.

        :raises ValueError: if the connection is not (or no longer) open.
        """
        if self._connection is None:
            raise ValueError("no active connection")

        return self._connection

    def _execute_insert(self, sql: str, parameters: Any) -> Optional[sqlite3.Row]:
        """Run ``sql``, then return the ``last_insert_rowid()`` row.

        Executed on the worker thread via :meth:`_execute`.
        """
        cursor = self._conn.execute(sql, parameters)
        cursor.execute("SELECT last_insert_rowid()")
        return cursor.fetchone()

    def _execute_fetchall(self, sql: str, parameters: Any) -> Iterable[sqlite3.Row]:
        """Run ``sql`` and return every result row.

        Executed on the worker thread via :meth:`_execute`.
        """
        cursor = self._conn.execute(sql, parameters)
        return cursor.fetchall()

    async def _execute(self, fn, *args, **kwargs):
        """
        Queue ``fn(*args, **kwargs)`` on the worker thread and await its result.

        :raises ValueError: if the connection is closed or was never opened.
        :raises: whatever ``fn`` raises, re-raised in the awaiting task.
        """
        if not self._running or not self._connection:
            raise ValueError("Connection closed")

        function = partial(fn, *args, **kwargs)
        future = asyncio.get_running_loop().create_future()

        self._tx.put_nowait((future, function))

        return await future

    async def _connect(self) -> "Connection":
        """
        Connect to the actual sqlite database (no-op if already connected).

        On failure the worker thread is stopped and the error is re-raised.

        :raises ValueError: if the connection has already been stopped; its
            worker thread is gone, so queuing the connector would never finish.
        """
        if self._connection is None:
            if not self._running:
                raise ValueError("Connection closed")

            try:
                future = asyncio.get_running_loop().create_future()
                self._tx.put_nowait((future, self._connector))
                self._connection = await future
            except BaseException:
                await self._stop_running()
                self._connection = None
                raise

        return self

    def __await__(self) -> Generator[Any, None, "Connection"]:
        # Thread.start() may only be called once, but the same Connection can
        # legitimately be awaited more than once, e.g.
        # ``async with await aiosqlite.connect(...)``. ``ident`` stays None
        # until the thread has been started.
        if self._thread.ident is None:
            self._thread.start()
        return self._connect().__await__()

    async def __aenter__(self) -> "Connection":
        return await self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        await self.close()

    @contextmanager
    async def cursor(self) -> Cursor:
        """Create an aiosqlite cursor wrapping a sqlite3 cursor object."""
        return Cursor(self, await self._execute(self._conn.cursor))

    async def commit(self) -> None:
        """Commit the current transaction."""
        await self._execute(self._conn.commit)

    async def rollback(self) -> None:
        """Roll back the current transaction."""
        await self._execute(self._conn.rollback)

    async def close(self) -> None:
        """
        Complete queued queries/cursors and close the connection.

        Does nothing if the connection was never opened or is already closed.
        The worker thread is told to stop even when closing the sqlite3
        connection fails; that error is logged and re-raised.
        """

        if self._connection is None:
            return

        try:
            await self._execute(self._conn.close)
        except Exception:
            LOG.info("exception occurred while closing connection")
            raise
        finally:
            await self._stop_running()
            self._connection = None

    @contextmanager
    async def execute(
        self, sql: str, parameters: Optional[Iterable[Any]] = None
    ) -> Cursor:
        """Helper to create a cursor and execute the given query."""
        if parameters is None:
            parameters = []
        cursor = await self._execute(self._conn.execute, sql, parameters)
        return Cursor(self, cursor)

    @contextmanager
    async def execute_insert(
        self, sql: str, parameters: Optional[Iterable[Any]] = None
    ) -> Optional[sqlite3.Row]:
        """Helper to insert and get the last_insert_rowid."""
        if parameters is None:
            parameters = []
        return await self._execute(self._execute_insert, sql, parameters)

    @contextmanager
    async def execute_fetchall(
        self, sql: str, parameters: Optional[Iterable[Any]] = None
    ) -> Iterable[sqlite3.Row]:
        """Helper to execute a query and return all the data."""
        if parameters is None:
            parameters = []
        return await self._execute(self._execute_fetchall, sql, parameters)

    @contextmanager
    async def executemany(
        self, sql: str, parameters: Iterable[Iterable[Any]]
    ) -> Cursor:
        """Helper to create a cursor and execute the given multiquery."""
        cursor = await self._execute(self._conn.executemany, sql, parameters)
        return Cursor(self, cursor)

    @contextmanager
    async def executescript(self, sql_script: str) -> Cursor:
        """Helper to create a cursor and execute a user script."""
        cursor = await self._execute(self._conn.executescript, sql_script)
        return Cursor(self, cursor)

    async def interrupt(self) -> None:
        """Interrupt pending queries."""
        # Called directly on the caller's thread rather than queued: the
        # worker runs calls one at a time, so a queued interrupt would only
        # run after the query it is meant to abort.
        return self._conn.interrupt()

    async def create_function(
        self, name: str, num_params: int, func: Callable, deterministic: bool = False
    ) -> None:
        """
        Create user-defined function that can be later used
        within SQL statements. Must be run within the same thread
        that query executions take place so instead of executing directly
        against the connection, we queue it on the worker thread via `_execute`.

        If ``deterministic`` is true, the created function is marked as deterministic,
        which allows SQLite to perform additional optimizations. This flag is supported
        by SQLite 3.8.3 or higher, ``NotSupportedError`` will be raised if used with
        older versions.
        """
        await self._execute(
            self._conn.create_function,
            name,
            num_params,
            func,
            deterministic=deterministic,
        )

    @property
    def in_transaction(self) -> bool:
        """Whether the underlying sqlite3 connection has an open transaction."""
        return self._conn.in_transaction

    @property
    def isolation_level(self) -> Optional[str]:
        """Isolation level of the underlying sqlite3 connection."""
        return self._conn.isolation_level

    @isolation_level.setter
    def isolation_level(self, value: IsolationLevel) -> None:
        self._conn.isolation_level = value

    @property
    def row_factory(self) -> Optional[type]:
        """Row factory of the underlying sqlite3 connection."""
        return self._conn.row_factory

    @row_factory.setter
    def row_factory(self, factory: Optional[type]) -> None:
        self._conn.row_factory = factory

    @property
    def text_factory(self) -> Callable[[bytes], Any]:
        """Text factory of the underlying sqlite3 connection."""
        return self._conn.text_factory

    @text_factory.setter
    def text_factory(self, factory: Callable[[bytes], Any]) -> None:
        self._conn.text_factory = factory

    @property
    def total_changes(self) -> int:
        """``total_changes`` of the underlying sqlite3 connection."""
        return self._conn.total_changes

    async def enable_load_extension(self, value: bool) -> None:
        """Enable or disable SQLite extension loading on the worker thread."""
        await self._execute(self._conn.enable_load_extension, value)  # type: ignore

    async def load_extension(self, path: str):
        """Load the SQLite extension at ``path`` on the worker thread."""
        await self._execute(self._conn.load_extension, path)  # type: ignore

    async def set_progress_handler(
        self, handler: Callable[[], Optional[int]], n: int
    ) -> None:
        """Register ``handler`` with sqlite3's ``set_progress_handler`` (every ``n`` VM ops)."""
        await self._execute(self._conn.set_progress_handler, handler, n)

    async def set_trace_callback(self, handler: Callable) -> None:
        """Register ``handler`` with sqlite3's ``set_trace_callback``."""
        await self._execute(self._conn.set_trace_callback, handler)

    async def set_authorizer(
        self, authorizer_callback: Optional[AuthorizerCallback]
    ) -> None:
        """
        Set an authorizer callback to control database access.

        The authorizer callback is invoked for each SQL statement that is prepared,
        and controls whether specific operations are permitted.

        Example::

            import sqlite3

            def restrict_drops(action_code, arg1, arg2, db_name, trigger_name):
                # Deny all DROP operations
                if action_code == sqlite3.SQLITE_DROP_TABLE:
                    return sqlite3.SQLITE_DENY
                # Allow everything else
                return sqlite3.SQLITE_OK

            await conn.set_authorizer(restrict_drops)

        See ``sqlite3`` documentation for details:
        https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.set_authorizer

        :param authorizer_callback: An optional callable that receives five arguments:

            - ``action_code`` (int): The action to be authorized (e.g., ``SQLITE_READ``)
            - ``arg1`` (str): First argument, meaning depends on ``action_code``
            - ``arg2`` (str): Second argument, meaning depends on ``action_code``
            - ``db_name`` (str): Database name (e.g., ``"main"``, ``"temp"``)
            - ``trigger_name`` (str): Name of trigger or view that is doing the access,
              or ``None``

            The callback should return:

            - ``SQLITE_OK`` (0): Allow the operation
            - ``SQLITE_DENY`` (1): Deny the operation, raise ``sqlite3.DatabaseError``
            - ``SQLITE_IGNORE`` (2): Treat operation as no-op

            Pass ``None`` to remove the authorizer.
        """
        await self._execute(self._conn.set_authorizer, authorizer_callback)

    async def iterdump(self) -> AsyncIterator[str]:
        """
        Return an async iterator to dump the database in SQL text format.

        Example::

            async for line in db.iterdump():
                ...

        """
        # Thread-safe queue: the worker thread produces, this coroutine consumes.
        dump_queue: Queue = Queue()

        def dumper():
            # Runs on the worker thread; None marks the end of the dump.
            try:
                for line in self._conn.iterdump():
                    dump_queue.put_nowait(line)
                dump_queue.put_nowait(None)

            except Exception:
                LOG.exception("exception while dumping db")
                dump_queue.put_nowait(None)
                raise

        fut = self._execute(dumper)
        task = asyncio.ensure_future(fut)

        while True:
            try:
                line: Optional[str] = dump_queue.get_nowait()
                if line is None:
                    break
                yield line

            except Empty:
                if task.done():
                    LOG.warning("iterdump completed unexpectedly")
                    break

                # Poll instead of blocking so the event loop stays responsive.
                await asyncio.sleep(0.01)

        # Re-raises any exception from dumper() (or from _execute itself).
        await task

    async def backup(
        self,
        target: Union["Connection", sqlite3.Connection],
        *,
        pages: int = 0,
        progress: Optional[Callable[[int, int, int], None]] = None,
        name: str = "main",
        sleep: float = 0.250,
    ) -> None:
        """
        Make a backup of the current database to the target database.

        Takes either a standard sqlite3 or aiosqlite Connection object as the target.
        The remaining keyword arguments are forwarded to
        :meth:`sqlite3.Connection.backup`, which runs on this connection's
        worker thread.
        """
        if isinstance(target, Connection):
            target = target._conn

        await self._execute(
            self._conn.backup,
            target,
            pages=pages,
            progress=progress,
            name=name,
            sleep=sleep,
        )

    def __del__(self):
        # getattr: __del__ may run on an instance whose __init__ raised before
        # ``_connection`` was assigned.
        if getattr(self, "_connection", None) is None:
            return

        warn(
            (
                f"{self!r} was deleted before being closed. "
                "Please use 'async with' or '.close()' to close the connection properly."
            ),
            ResourceWarning,
            stacklevel=1,
        )

        # Don't try to be creative here, the event loop may have already been
        # closed -- or unset entirely (e.g. after asyncio.run() returns), in
        # which case asyncio.get_event_loop() can raise and no stop request
        # would be queued, leaving the worker thread blocked forever. Queue the
        # stop request with a loop-free placeholder future instead, and let
        # the underlying sqlite3 connection be finalized by its own __del__.
        self._running = False
        self._tx.put_nowait(
            (_DetachedFuture(), lambda: _STOP_RUNNING_SENTINEL)  # type: ignore[arg-type]
        )

430 

431 

def connect(
    database: Union[str, bytes, Path],
    *,
    iter_chunk_size: int = 64,
    loop: Optional[asyncio.AbstractEventLoop] = None,
    **kwargs: Any,
) -> Connection:
    """
    Create and return a connection proxy to the sqlite database.

    The database is not opened here: the returned ``Connection`` must be
    awaited (``await connect(...)``) or entered (``async with connect(...)``),
    which starts its worker thread and runs :func:`sqlite3.connect` there.

    :param database: location of the database, as ``str``, UTF-8 encoded
        ``bytes`` or :class:`pathlib.Path`.
    :param iter_chunk_size: passed through to ``Connection``
        (NOTE(review): presumably the batch size for cursor iteration —
        confirm in cursor.py).
    :param loop: deprecated and ignored; passing it emits ``DeprecationWarning``.
    :param kwargs: forwarded verbatim to :func:`sqlite3.connect`.
    """

    if loop is not None:
        warn(
            "aiosqlite.connect() no longer uses the `loop` parameter",
            DeprecationWarning,
        )

    def connector() -> sqlite3.Connection:
        # Runs on the worker thread. Normalise the location to ``str`` first;
        # bytes are accepted (and decoded as UTF-8) even though the old
        # annotation omitted them.
        if isinstance(database, str):
            loc = database
        elif isinstance(database, bytes):
            loc = database.decode("utf-8")
        else:
            loc = str(database)

        return sqlite3.connect(loc, **kwargs)

    return Connection(connector, iter_chunk_size)