Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/aiosqlite/core.py: 35%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

205 statements  

1# Copyright Amethyst Reese 

2# Licensed under the MIT license 

3 

4""" 

5Core implementation of aiosqlite proxies 

6""" 

7 

8import asyncio 

9import logging 

10import sqlite3 

11from collections.abc import AsyncIterator, Generator, Iterable 

12from functools import partial 

13from pathlib import Path 

14from queue import Empty, Queue, SimpleQueue 

15from threading import Thread 

16from typing import Any, Callable, Literal, Optional, Union 

17from warnings import warn 

18 

19from .context import contextmanager 

20from .cursor import Cursor 

21 

# Names exported by ``from aiosqlite.core import *``.
__all__ = ["connect", "Connection", "Cursor"]

# Module-wide logger used for the debug/info/warning messages emitted below.
LOG = logging.getLogger("aiosqlite")


# Accepted values for ``Connection.isolation_level``; ``None`` is allowed and is
# passed straight through to the underlying sqlite3 connection.
IsolationLevel = Union[Literal["DEFERRED", "IMMEDIATE", "EXCLUSIVE"], None]

28 

29 

def set_result(fut: asyncio.Future, result: Any) -> None:
    """Resolve ``fut`` with ``result`` unless it is already resolved.

    A future that was cancelled or completed earlier is left untouched, so
    this helper can be scheduled safely without racing a cancellation.
    """
    if fut.done():
        return
    fut.set_result(result)

34 

35 

def set_exception(fut: asyncio.Future, e: BaseException) -> None:
    """Fail ``fut`` with exception ``e`` unless it is already resolved.

    Futures that were cancelled or completed earlier are skipped rather than
    raising ``InvalidStateError``.
    """
    if fut.done():
        return
    fut.set_exception(e)


# Unique marker object: when a connection's worker thread pulls this off its
# work queue, it leaves its processing loop.
_STOP_RUNNING_SENTINEL = object()

43 

44 

class Connection(Thread):
    """Asyncio proxy around a :class:`sqlite3.Connection`.

    The instance is itself a :class:`threading.Thread`. Each sqlite3 call is
    wrapped in a zero-argument callable, queued on ``self._tx`` together with
    an asyncio future, and executed in FIFO order by :meth:`run` on this
    thread; the result or exception is delivered back to the future's loop.
    """

    def __init__(
        self,
        connector: Callable[[], sqlite3.Connection],
        iter_chunk_size: int,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        """Set up the proxy without opening the database.

        :param connector: zero-argument callable that opens and returns the
            sqlite3 connection; :meth:`_connect` runs it on the worker thread.
        :param iter_chunk_size: stored as ``_iter_chunk_size``; not read in this
            class (presumably consumed by :class:`Cursor` -- confirm).
        :param loop: unused; passing a value only emits a DeprecationWarning.
        """
        super().__init__()
        self._running = True
        self._connection: Optional[sqlite3.Connection] = None
        self._connector = connector
        # Work queue of (future to resolve, zero-argument callable to run).
        self._tx: SimpleQueue[tuple[asyncio.Future, Callable[[], Any]]] = SimpleQueue()
        self._iter_chunk_size = iter_chunk_size

        if loop is not None:
            warn(
                "aiosqlite.Connection no longer uses the `loop` parameter",
                DeprecationWarning,
            )

    def _stop_running(self) -> None:
        """Refuse new work and tell the worker thread to exit.

        Items already queued ahead of the sentinel are still executed by
        :meth:`run`, so their futures get resolved.
        """
        self._running = False
        # PEP 661 is not accepted yet, so we cannot type a sentinel
        self._tx.put_nowait(_STOP_RUNNING_SENTINEL)  # type: ignore[arg-type]

    @property
    def _conn(self) -> sqlite3.Connection:
        """The open sqlite3 connection.

        :raises ValueError: if no connection is currently open.
        """
        if self._connection is None:
            raise ValueError("no active connection")

        return self._connection

    def _execute_insert(self, sql: str, parameters: Any) -> Optional[sqlite3.Row]:
        """Run ``sql`` and return the row produced by ``last_insert_rowid()``.

        Runs on the worker thread. The helper cursor is closed even on failure:
        otherwise the exception's traceback (which references this frame) would
        keep the cursor, and its active statement, alive while the exception is
        forwarded to the caller.
        """
        cursor = self._conn.execute(sql, parameters)
        try:
            cursor.execute("SELECT last_insert_rowid()")
            return cursor.fetchone()
        finally:
            cursor.close()

    def _execute_fetchall(self, sql: str, parameters: Any) -> Iterable[sqlite3.Row]:
        """Run ``sql`` and return every result row as a list.

        Runs on the worker thread; the helper cursor is closed before returning
        (or when fetching fails) so its statement does not outlive the call.
        """
        cursor = self._conn.execute(sql, parameters)
        try:
            return cursor.fetchall()
        finally:
            cursor.close()

    def run(self) -> None:
        """
        Execute function calls on a separate thread.

        :meta private:
        """
        while True:
            # Continues running until all queue items are processed,
            # even after connection is closed (so we can finalize all
            # futures)

            tx_item = self._tx.get()
            if tx_item is _STOP_RUNNING_SENTINEL:
                break

            future, function = tx_item

            try:
                LOG.debug("executing %s", function)
                result = function()
                LOG.debug("operation %s completed", function)
                # asyncio futures are not thread-safe: resolve them on the
                # future's own event loop instead of from this thread.
                future.get_loop().call_soon_threadsafe(set_result, future, result)
            except BaseException as e:  # noqa B036
                LOG.debug("returning exception %s", e)
                future.get_loop().call_soon_threadsafe(set_exception, future, e)

    async def _execute(self, fn, *args, **kwargs):
        """Queue a function with the given arguments for execution.

        ``fn(*args, **kwargs)`` runs on the worker thread; its return value is
        returned here and any exception it raises is re-raised here.

        :raises ValueError: if the connection is closed or was never opened.
        """
        if not self._running or not self._connection:
            raise ValueError("Connection closed")

        function = partial(fn, *args, **kwargs)
        # Only ever called from a coroutine, so a running loop is guaranteed.
        future = asyncio.get_running_loop().create_future()

        self._tx.put_nowait((future, function))

        return await future

    async def _connect(self) -> "Connection":
        """Connect to the actual sqlite database.

        The connector runs on the worker thread. If it fails, or this coroutine
        is cancelled while waiting, the worker thread is told to stop and the
        exception propagates to the caller.
        """
        if self._connection is None:
            try:
                future = asyncio.get_running_loop().create_future()
                self._tx.put_nowait((future, self._connector))
                self._connection = await future
            except BaseException:
                self._stop_running()
                self._connection = None
                raise

        return self

    def __await__(self) -> Generator[Any, None, "Connection"]:
        """Start the worker thread, open the database and return ``self``.

        NOTE(review): ``Thread.start`` may only be called once per thread, so
        awaiting the same instance a second time (including via ``async with``
        after an explicit ``await``) raises ``RuntimeError``.
        """
        self.start()
        return self._connect().__await__()

    async def __aenter__(self) -> "Connection":
        """Open the connection (as ``await self`` does) and return it."""
        return await self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the connection whether or not the ``async with`` body raised."""
        await self.close()

    @contextmanager
    async def cursor(self) -> Cursor:
        """Create an aiosqlite cursor wrapping a sqlite3 cursor object."""
        return Cursor(self, await self._execute(self._conn.cursor))

    async def commit(self) -> None:
        """Commit the current transaction."""
        await self._execute(self._conn.commit)

    async def rollback(self) -> None:
        """Roll back the current transaction."""
        await self._execute(self._conn.rollback)

    async def close(self) -> None:
        """Complete queued queries/cursors and close the connection.

        Does nothing if no connection is open. The worker thread is told to
        stop even if closing fails; such a failure is logged and re-raised.
        """

        if self._connection is None:
            return

        try:
            # Queued behind any pending work, so that work finishes first.
            await self._execute(self._conn.close)
        except Exception:
            LOG.info("exception occurred while closing connection")
            raise
        finally:
            self._stop_running()
            self._connection = None

    @contextmanager
    async def execute(
        self, sql: str, parameters: Optional[Iterable[Any]] = None
    ) -> Cursor:
        """Helper to create a cursor and execute the given query.

        ``parameters`` defaults to an empty list (no bound values).
        """
        if parameters is None:
            parameters = []
        cursor = await self._execute(self._conn.execute, sql, parameters)
        return Cursor(self, cursor)

    @contextmanager
    async def execute_insert(
        self, sql: str, parameters: Optional[Iterable[Any]] = None
    ) -> Optional[sqlite3.Row]:
        """Helper to insert and get the last_insert_rowid.

        ``parameters`` defaults to an empty list (no bound values).
        """
        if parameters is None:
            parameters = []
        return await self._execute(self._execute_insert, sql, parameters)

    @contextmanager
    async def execute_fetchall(
        self, sql: str, parameters: Optional[Iterable[Any]] = None
    ) -> Iterable[sqlite3.Row]:
        """Helper to execute a query and return all the data.

        ``parameters`` defaults to an empty list (no bound values).
        """
        if parameters is None:
            parameters = []
        return await self._execute(self._execute_fetchall, sql, parameters)

    @contextmanager
    async def executemany(
        self, sql: str, parameters: Iterable[Iterable[Any]]
    ) -> Cursor:
        """Helper to create a cursor and execute the given multiquery."""
        cursor = await self._execute(self._conn.executemany, sql, parameters)
        return Cursor(self, cursor)

    @contextmanager
    async def executescript(self, sql_script: str) -> Cursor:
        """Helper to create a cursor and execute a user script."""
        cursor = await self._execute(self._conn.executescript, sql_script)
        return Cursor(self, cursor)

    async def interrupt(self) -> None:
        """Interrupt pending queries.

        Calls the sqlite3 connection directly from the caller's thread instead
        of queuing, so it is not stuck behind the very query it should stop.
        """
        return self._conn.interrupt()

    async def create_function(
        self, name: str, num_params: int, func: Callable, deterministic: bool = False
    ) -> None:
        """
        Create user-defined function that can be later used
        within SQL statements. Must be run within the same thread
        that query executions take place so instead of executing directly
        against the connection, we defer this to `run` function.

        If ``deterministic`` is true, the created function is marked as deterministic,
        which allows SQLite to perform additional optimizations. This flag is supported
        by SQLite 3.8.3 or higher, ``NotSupportedError`` will be raised if used with
        older versions.
        """
        await self._execute(
            self._conn.create_function,
            name,
            num_params,
            func,
            deterministic=deterministic,
        )

    # The properties below read/write the sqlite3 connection directly from the
    # caller's thread rather than going through the worker queue.

    @property
    def in_transaction(self) -> bool:
        """Whether the underlying sqlite3 connection has an open transaction."""
        return self._conn.in_transaction

    @property
    def isolation_level(self) -> Optional[str]:
        """The underlying sqlite3 connection's ``isolation_level``."""
        return self._conn.isolation_level

    @isolation_level.setter
    def isolation_level(self, value: IsolationLevel) -> None:
        self._conn.isolation_level = value

    @property
    def row_factory(self) -> Optional[type]:
        """The row factory used by the underlying sqlite3 connection."""
        return self._conn.row_factory

    @row_factory.setter
    def row_factory(self, factory: Optional[type]) -> None:
        self._conn.row_factory = factory

    @property
    def text_factory(self) -> Callable[[bytes], Any]:
        """The text factory used by the underlying sqlite3 connection."""
        return self._conn.text_factory

    @text_factory.setter
    def text_factory(self, factory: Callable[[bytes], Any]) -> None:
        self._conn.text_factory = factory

    @property
    def total_changes(self) -> int:
        """The underlying sqlite3 connection's ``total_changes`` counter."""
        return self._conn.total_changes

    async def enable_load_extension(self, value: bool) -> None:
        """Enable or disable extension loading on the worker thread."""
        await self._execute(self._conn.enable_load_extension, value)  # type: ignore

    async def load_extension(self, path: str) -> None:
        """Load the SQLite extension at ``path`` on the worker thread."""
        await self._execute(self._conn.load_extension, path)  # type: ignore

    async def set_progress_handler(
        self, handler: Callable[[], Optional[int]], n: int
    ) -> None:
        """Register ``handler`` via ``set_progress_handler`` on the worker thread."""
        await self._execute(self._conn.set_progress_handler, handler, n)

    async def set_trace_callback(self, handler: Callable) -> None:
        """Register ``handler`` via ``set_trace_callback`` on the worker thread."""
        await self._execute(self._conn.set_trace_callback, handler)

    async def iterdump(self) -> AsyncIterator[str]:
        """
        Return an async iterator to dump the database in SQL text format.

        Example::

            async for line in db.iterdump():
                ...

        """
        # Thread-safe queue: filled by ``dumper`` on the worker thread and
        # polled here on the event loop; ``None`` marks the end of the dump.
        dump_queue: Queue = Queue()

        def dumper():
            try:
                for line in self._conn.iterdump():
                    dump_queue.put_nowait(line)
                dump_queue.put_nowait(None)

            except Exception:
                LOG.exception("exception while dumping db")
                # Unblock the consumer loop before propagating the error.
                dump_queue.put_nowait(None)
                raise

        fut = self._execute(dumper)
        task = asyncio.ensure_future(fut)

        while True:
            try:
                line: Optional[str] = dump_queue.get_nowait()
                if line is None:
                    break
                yield line

            except Empty:
                # Nothing buffered: give up if the task already finished
                # without sending the end marker (e.g. _execute raised before
                # queuing ``dumper``); otherwise poll again shortly.
                if task.done():
                    LOG.warning("iterdump completed unexpectedly")
                    break

                await asyncio.sleep(0.01)

        # Surface any exception raised by the dump task.
        await task

    async def backup(
        self,
        target: Union["Connection", sqlite3.Connection],
        *,
        pages: int = 0,
        progress: Optional[Callable[[int, int, int], None]] = None,
        name: str = "main",
        sleep: float = 0.250,
    ) -> None:
        """
        Make a backup of the current database to the target database.

        Takes either a standard sqlite3 or aiosqlite Connection object as the target.
        The keyword arguments are forwarded unchanged to the sqlite3
        ``Connection.backup`` call, which runs on this connection's worker thread.
        """
        if isinstance(target, Connection):
            # Unwrap to the underlying sqlite3 connection; raises ValueError if
            # the target proxy has no open connection.
            target = target._conn

        await self._execute(
            self._conn.backup,
            target,
            pages=pages,
            progress=progress,
            name=name,
            sleep=sleep,
        )

357 

358 

def connect(
    database: Union[str, bytes, Path],
    *,
    iter_chunk_size: int = 64,
    loop: Optional[asyncio.AbstractEventLoop] = None,
    **kwargs: Any,
) -> Connection:
    """Create and return a connection proxy to the sqlite database.

    The database is not opened here. The returned :class:`Connection` must be
    awaited (or used with ``async with``), which starts its worker thread and
    opens the database on that thread.

    :param database: path as ``str``, UTF-8 encoded ``bytes``, or
        :class:`~pathlib.Path`; any other object is converted with ``str()``.
    :param iter_chunk_size: forwarded to :class:`Connection`.
    :param loop: unused; passing a value only emits a DeprecationWarning.
    :param kwargs: passed unchanged to :func:`sqlite3.connect`.
    """

    if loop is not None:
        warn(
            "aiosqlite.connect() no longer uses the `loop` parameter",
            DeprecationWarning,
        )

    def connector() -> sqlite3.Connection:
        # Runs on the worker thread when the Connection is awaited, so a path
        # decoding error is raised from that await rather than from connect().
        if isinstance(database, str):
            loc = database
        elif isinstance(database, bytes):
            loc = database.decode("utf-8")
        else:
            loc = str(database)

        return sqlite3.connect(loc, **kwargs)

    return Connection(connector, iter_chunk_size)