Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/aiosqlite/core.py: 35%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Copyright Amethyst Reese
2# Licensed under the MIT license
4"""
5Core implementation of aiosqlite proxies
6"""
8import asyncio
9import logging
10import sqlite3
11from collections.abc import AsyncIterator, Generator, Iterable
12from functools import partial
13from pathlib import Path
14from queue import Empty, Queue, SimpleQueue
15from threading import Thread
16from typing import Any, Callable, Literal, Optional, Union
17from warnings import warn
19from .context import contextmanager
20from .cursor import Cursor
# Names exported by ``from aiosqlite.core import *``.
__all__ = ["connect", "Connection", "Cursor"]

# Logger shared by everything in this module, registered as "aiosqlite".
LOG = logging.getLogger(name="aiosqlite")

# Type accepted by ``Connection.isolation_level``'s setter: one of the three
# SQLite BEGIN modes, or ``None``.
IsolationLevel = Union[Literal["DEFERRED", "IMMEDIATE", "EXCLUSIVE"], None]
def set_result(fut: asyncio.Future, result: Any) -> None:
    """
    Resolve ``fut`` with ``result`` unless it is already done.

    A future counts as done once it has a result, an exception, or has been
    cancelled; in all of those cases this is a silent no-op instead of the
    ``InvalidStateError`` that ``Future.set_result`` would raise.
    """
    if fut.done():
        return
    fut.set_result(result)
def set_exception(fut: asyncio.Future, e: BaseException) -> None:
    """
    Fail ``fut`` with ``e`` unless it is already done.

    A future counts as done once it has a result, an exception, or has been
    cancelled; in those cases this is a silent no-op.

    ``StopIteration`` cannot be stored in an asyncio future: on Python 3.11
    and earlier ``Future.set_exception`` rejects it with ``TypeError``. This
    helper is run as a ``call_soon_threadsafe`` callback, so that ``TypeError``
    would only be logged by the event loop and ``fut`` would stay pending
    forever. Instead, ``StopIteration`` is wrapped in a ``RuntimeError``
    (chained to the original), matching what newer CPython versions do and
    what PEP 479 does for a ``StopIteration`` escaping a coroutine.
    """
    if fut.done():
        return

    if isinstance(e, StopIteration):
        wrapped = RuntimeError(
            "StopIteration interacts badly with generators "
            "and cannot be raised into a Future"
        )
        wrapped.__cause__ = e
        wrapped.__context__ = e
        e = wrapped

    fut.set_exception(e)
# Unique marker that Connection._stop_running() puts on the work queue; when
# Connection.run() dequeues it (compared by identity), the worker loop exits.
_STOP_RUNNING_SENTINEL: object = object()
class Connection(Thread):
    """
    Async proxy around a :class:`sqlite3.Connection`.

    The proxy is itself a :class:`threading.Thread`. Every sqlite3 call is
    wrapped as a ``(future, function)`` pair and pushed onto ``self._tx``.
    :meth:`run` executes the pairs in FIFO order on this worker thread and
    resolves each future on the event loop that created it. Awaiting the proxy
    (directly or via ``async with``) starts the thread and opens the database.
    """

    def __init__(
        self,
        connector: Callable[[], sqlite3.Connection],
        iter_chunk_size: int,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        """
        :param connector: zero-argument callable that opens the underlying
            sqlite3 connection; it is queued and invoked on the worker thread.
        :param iter_chunk_size: stored as ``_iter_chunk_size``; not read in this
            class (presumably consumed by :class:`Cursor` -- verify).
        :param loop: deprecated and ignored; a ``DeprecationWarning`` is emitted
            if it is given.
        """
        super().__init__()
        # Cleared by _stop_running(); _execute() refuses new work once False.
        self._running = True
        self._connection: Optional[sqlite3.Connection] = None
        self._connector = connector
        self._tx: SimpleQueue[tuple[asyncio.Future, Callable[[], Any]]] = SimpleQueue()
        self._iter_chunk_size = iter_chunk_size

        if loop is not None:
            warn(
                "aiosqlite.Connection no longer uses the `loop` parameter",
                DeprecationWarning,
                stacklevel=2,
            )

    def _stop_running(self):
        """Refuse further work and queue the sentinel that ends :meth:`run`.

        Items queued before the sentinel are still executed, because the queue
        is FIFO.
        """
        self._running = False
        # PEP 661 is not accepted yet, so we cannot type a sentinel
        self._tx.put_nowait(_STOP_RUNNING_SENTINEL)  # type: ignore[arg-type]

    @property
    def _conn(self) -> sqlite3.Connection:
        """The open sqlite3 connection.

        :raises ValueError: if not connected, or already closed.
        """
        if self._connection is None:
            raise ValueError("no active connection")

        return self._connection

    def _execute_insert(self, sql: str, parameters: Any) -> Optional[sqlite3.Row]:
        """Run ``sql`` and return the row from ``SELECT last_insert_rowid()``.

        Executed on the worker thread (queued through :meth:`_execute`).
        """
        cursor = self._conn.execute(sql, parameters)
        cursor.execute("SELECT last_insert_rowid()")
        return cursor.fetchone()

    def _execute_fetchall(self, sql: str, parameters: Any) -> Iterable[sqlite3.Row]:
        """Run ``sql`` and return every row.

        Executed on the worker thread (queued through :meth:`_execute`).
        """
        cursor = self._conn.execute(sql, parameters)
        return cursor.fetchall()

    def run(self) -> None:
        """
        Execute function calls on a separate thread.

        :meta private:
        """
        while True:
            # Continues running until all queue items are processed,
            # even after connection is closed (so we can finalize all
            # futures)
            tx_item = self._tx.get()
            if tx_item is _STOP_RUNNING_SENTINEL:
                break

            future, function = tx_item

            # Only the user's function is guarded here. The result is
            # delivered outside the ``try``, so a failure while delivering
            # cannot be mistaken for a failure of the function itself.
            try:
                LOG.debug("executing %s", function)
                result = function()
                LOG.debug("operation %s completed", function)
            except BaseException as e:  # noqa B036
                LOG.debug("returning exception %s", e)
                self._deliver(future, set_exception, e)
            else:
                self._deliver(future, set_result, result)

    @staticmethod
    def _deliver(
        future: asyncio.Future,
        setter: Callable[[asyncio.Future, Any], None],
        value: Any,
    ) -> None:
        """Schedule ``setter(future, value)`` on the loop that owns ``future``."""
        try:
            future.get_loop().call_soon_threadsafe(setter, future, value)
        except RuntimeError:
            # call_soon_threadsafe raises RuntimeError once the owning loop
            # is closed; nobody can await this future any more. Drop the
            # outcome but keep the worker alive, so the remaining queue items
            # (and the stop sentinel) are still processed.
            LOG.debug("dropping outcome for %s: event loop is closed", future)

    async def _execute(self, fn, *args, **kwargs):
        """Queue a function with the given arguments for execution.

        The call runs on the worker thread; this coroutine waits for its
        result (or re-raises its exception).

        :raises ValueError: if the connection is closed or not yet opened.
        """
        if not self._running or self._connection is None:
            raise ValueError("Connection closed")

        function = partial(fn, *args, **kwargs)
        future = asyncio.get_running_loop().create_future()

        self._tx.put_nowait((future, function))

        return await future

    async def _connect(self) -> "Connection":
        """Connect to the actual sqlite database."""
        if self._connection is None:
            try:
                future = asyncio.get_running_loop().create_future()
                # The connector is queued like any other call, so the sqlite3
                # connection is created on the worker thread that will use it.
                self._tx.put_nowait((future, self._connector))
                self._connection = await future
            except BaseException:
                self._stop_running()
                self._connection = None
                raise

        return self

    def __await__(self) -> Generator[Any, None, "Connection"]:
        # Thread.start() may only be called once per thread object, so a
        # Connection can only be awaited (or entered) once.
        self.start()
        return self._connect().__await__()

    async def __aenter__(self) -> "Connection":
        return await self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        await self.close()

    @contextmanager
    async def cursor(self) -> Cursor:
        """Create an aiosqlite cursor wrapping a sqlite3 cursor object."""
        return Cursor(self, await self._execute(self._conn.cursor))

    async def commit(self) -> None:
        """Commit the current transaction."""
        await self._execute(self._conn.commit)

    async def rollback(self) -> None:
        """Roll back the current transaction."""
        await self._execute(self._conn.rollback)

    async def close(self) -> None:
        """Complete queued queries/cursors and close the connection.

        A no-op if there is no open connection. The worker thread is told to
        stop even if closing the sqlite3 connection fails.
        """
        if self._connection is None:
            return

        try:
            await self._execute(self._conn.close)
        except Exception:
            LOG.info("exception occurred while closing connection")
            raise
        finally:
            self._stop_running()
            self._connection = None

    @contextmanager
    async def execute(
        self, sql: str, parameters: Optional[Iterable[Any]] = None
    ) -> Cursor:
        """Helper to create a cursor and execute the given query."""
        if parameters is None:
            parameters = []
        cursor = await self._execute(self._conn.execute, sql, parameters)
        return Cursor(self, cursor)

    @contextmanager
    async def execute_insert(
        self, sql: str, parameters: Optional[Iterable[Any]] = None
    ) -> Optional[sqlite3.Row]:
        """Helper to insert and get the last_insert_rowid."""
        if parameters is None:
            parameters = []
        return await self._execute(self._execute_insert, sql, parameters)

    @contextmanager
    async def execute_fetchall(
        self, sql: str, parameters: Optional[Iterable[Any]] = None
    ) -> Iterable[sqlite3.Row]:
        """Helper to execute a query and return all the data."""
        if parameters is None:
            parameters = []
        return await self._execute(self._execute_fetchall, sql, parameters)

    @contextmanager
    async def executemany(
        self, sql: str, parameters: Iterable[Iterable[Any]]
    ) -> Cursor:
        """Helper to create a cursor and execute the given multiquery."""
        cursor = await self._execute(self._conn.executemany, sql, parameters)
        return Cursor(self, cursor)

    @contextmanager
    async def executescript(self, sql_script: str) -> Cursor:
        """Helper to create a cursor and execute a user script."""
        cursor = await self._execute(self._conn.executescript, sql_script)
        return Cursor(self, cursor)

    async def interrupt(self) -> None:
        """Interrupt pending queries."""
        # Called directly from the caller's thread rather than queued:
        # queuing would make it wait behind the very query it should abort.
        return self._conn.interrupt()

    async def create_function(
        self, name: str, num_params: int, func: Callable, deterministic: bool = False
    ) -> None:
        """
        Create user-defined function that can be later used
        within SQL statements. Must be run within the same thread
        that query executions take place so instead of executing directly
        against the connection, we defer this to `run` function.

        If ``deterministic`` is true, the created function is marked as deterministic,
        which allows SQLite to perform additional optimizations. This flag is supported
        by SQLite 3.8.3 or higher, ``NotSupportedError`` will be raised if used with
        older versions.
        """
        await self._execute(
            self._conn.create_function,
            name,
            num_params,
            func,
            deterministic=deterministic,
        )

    @property
    def in_transaction(self) -> bool:
        return self._conn.in_transaction

    @property
    def isolation_level(self) -> Optional[str]:
        return self._conn.isolation_level

    @isolation_level.setter
    def isolation_level(self, value: IsolationLevel) -> None:
        self._conn.isolation_level = value

    @property
    def row_factory(self) -> Optional[type]:
        return self._conn.row_factory

    @row_factory.setter
    def row_factory(self, factory: Optional[type]) -> None:
        self._conn.row_factory = factory

    @property
    def text_factory(self) -> Callable[[bytes], Any]:
        return self._conn.text_factory

    @text_factory.setter
    def text_factory(self, factory: Callable[[bytes], Any]) -> None:
        self._conn.text_factory = factory

    @property
    def total_changes(self) -> int:
        return self._conn.total_changes

    async def enable_load_extension(self, value: bool) -> None:
        await self._execute(self._conn.enable_load_extension, value)  # type: ignore

    async def load_extension(self, path: str):
        await self._execute(self._conn.load_extension, path)  # type: ignore

    async def set_progress_handler(
        self, handler: Callable[[], Optional[int]], n: int
    ) -> None:
        await self._execute(self._conn.set_progress_handler, handler, n)

    async def set_trace_callback(self, handler: Callable) -> None:
        await self._execute(self._conn.set_trace_callback, handler)

    async def iterdump(self) -> AsyncIterator[str]:
        """
        Return an async iterator to dump the database in SQL text format.

        Example::

            async for line in db.iterdump():
                ...

        """
        # Thread-safe queue: filled by the worker thread, drained here.
        # ``None`` marks the end of the stream.
        dump_queue: Queue = Queue()

        def dumper():
            try:
                for line in self._conn.iterdump():
                    dump_queue.put_nowait(line)
            except Exception:
                LOG.exception("exception while dumping db")
                raise
            finally:
                # Always terminate the stream, even on BaseException, so the
                # consumer loop below stops promptly.
                dump_queue.put_nowait(None)

        fut = self._execute(dumper)
        task = asyncio.ensure_future(fut)

        while True:
            try:
                line: Optional[str] = dump_queue.get_nowait()
                if line is None:
                    break
                yield line

            except Empty:
                if task.done():
                    LOG.warning("iterdump completed unexpectedly")
                    break

                await asyncio.sleep(0.01)

        # Surfaces any exception raised by dumper() or by _execute().
        await task

    async def backup(
        self,
        target: Union["Connection", sqlite3.Connection],
        *,
        pages: int = 0,
        progress: Optional[Callable[[int, int, int], None]] = None,
        name: str = "main",
        sleep: float = 0.250,
    ) -> None:
        """
        Make a backup of the current database to the target database.

        Takes either a standard sqlite3 or aiosqlite Connection object as the target.
        """
        if isinstance(target, Connection):
            # NOTE(review): the target's sqlite3 connection is then used from
            # *this* connection's worker thread -- confirm it was opened with
            # check_same_thread=False.
            target = target._conn

        await self._execute(
            self._conn.backup,
            target,
            pages=pages,
            progress=progress,
            name=name,
            sleep=sleep,
        )
def connect(
    database: Union[str, Path, bytes],
    *,
    iter_chunk_size: int = 64,
    loop: Optional[asyncio.AbstractEventLoop] = None,
    **kwargs: Any,
) -> Connection:
    """Create and return a connection proxy to the sqlite database.

    The database is not opened here. The ``sqlite3.connect`` call happens on
    the returned :class:`Connection`'s worker thread once the proxy is
    awaited or entered with ``async with``.

    :param database: database location as ``str``, ``bytes`` or
        :class:`~pathlib.Path`; forwarded to :func:`sqlite3.connect`.
    :param iter_chunk_size: forwarded to :class:`Connection`.
    :param loop: deprecated and ignored; a ``DeprecationWarning`` is emitted
        if it is given.
    :param kwargs: extra keyword arguments forwarded to :func:`sqlite3.connect`.
    """
    import os

    if loop is not None:
        warn(
            "aiosqlite.connect() no longer uses the `loop` parameter",
            DeprecationWarning,
            stacklevel=2,
        )

    def connector() -> sqlite3.Connection:
        if isinstance(database, str):
            loc = database
        elif isinstance(database, bytes):
            # Decode with the filesystem encoding (surrogateescape on POSIX)
            # rather than strict UTF-8. Any bytes path the OS accepts then
            # round-trips unchanged through sqlite3.connect's own fsencode.
            loc = os.fsdecode(database)
        else:
            loc = str(database)

        return sqlite3.connect(loc, **kwargs)

    return Connection(connector, iter_chunk_size)