Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/aiosqlite/core.py: 33%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Copyright Amethyst Reese
2# Licensed under the MIT license
4"""
5Core implementation of aiosqlite proxies
6"""
8import asyncio
9import logging
10import sqlite3
11from collections.abc import AsyncIterator, Generator, Iterable
12from functools import partial
13from pathlib import Path
14from queue import Empty, Queue, SimpleQueue
15from threading import Thread
16from typing import Any, Callable, Literal, Optional, Union
17from warnings import warn
19from .context import contextmanager
20from .cursor import Cursor
__all__ = ["connect", "Connection", "Cursor"]

# Logger shared by this module, registered under the package name "aiosqlite".
LOG = logging.getLogger("aiosqlite")

# Shape of the callback accepted by Connection.set_authorizer():
#   (action_code, arg1, arg2, db_name, trigger_name) -> int result code.
# NOTE(review): set_authorizer's own docstring says trigger_name may be None,
# which this str-only alias does not express; left unchanged for type-checker
# compatibility with existing callers.
AuthorizerCallback = Callable[[int, str, str, str, str], int]

# Values accepted by the Connection.isolation_level setter (None is allowed too).
IsolationLevel = Optional[Literal["DEFERRED", "IMMEDIATE", "EXCLUSIVE"]]
def set_result(fut: asyncio.Future, result: Any) -> None:
    """Resolve ``fut`` with ``result``, leaving an already-finished future alone.

    The worker thread schedules this on the future's event loop. If the future
    was cancelled (or otherwise completed) in the meantime, it is left untouched
    instead of raising ``InvalidStateError``.
    """
    if fut.done():
        return
    fut.set_result(result)
def set_exception(fut: asyncio.Future, e: BaseException) -> None:
    """Fail ``fut`` with ``e``, leaving an already-finished future alone.

    Counterpart of ``set_result``: a future that was cancelled or completed
    before this runs is not touched, so no ``InvalidStateError`` is raised.
    """
    if fut.done():
        return
    fut.set_exception(e)
# Unique marker returned by the shutdown job that Connection.stop() queues; the
# worker thread leaves its loop when a job's result is this object.
_STOP_RUNNING_SENTINEL = object()

# One job for the worker thread: the future to resolve (None when nobody waits
# for the outcome) paired with the zero-argument callable to run.
_TxItem = tuple[Optional[asyncio.Future], Callable[[], Any]]
_TxQueue = SimpleQueue[_TxItem]
def _notify_loop(
    future: Optional[asyncio.Future],
    callback: Callable[[asyncio.Future, Any], None],
    value: Any,
) -> None:
    """Schedule ``callback(future, value)`` on the event loop that owns ``future``.

    Does nothing when ``future`` is None (the job had no waiter).

    If the owning loop has already been closed, ``call_soon_threadsafe`` raises
    ``RuntimeError``, and no one can await the future any more. The
    notification is dropped and logged at debug level instead of letting the
    error escape and kill the worker thread.
    """
    if future is None:
        return
    try:
        future.get_loop().call_soon_threadsafe(callback, future, value)
    except RuntimeError:
        LOG.debug("event loop closed, dropping result for %s", future)


def _connection_worker_thread(tx: _TxQueue):
    """
    Execute function calls on a separate thread.

    Each queue item is ``(future, function)``. ``function`` is called with no
    arguments. Its return value or raised exception is delivered to ``future``
    on that future's event loop.

    :meta private:
    """
    while True:
        # Jobs run strictly in FIFO order until one returns
        # _STOP_RUNNING_SENTINEL. Everything queued before that (even after the
        # sqlite connection itself was closed) still runs, so its future is
        # resolved.
        future, function = tx.get()

        try:
            LOG.debug("executing %s", function)
            result = function()
        except BaseException as e:  # noqa B036
            LOG.debug("returning exception %s", e)
            _notify_loop(future, set_exception, e)
            continue

        _notify_loop(future, set_result, result)
        LOG.debug("operation %s completed", function)

        if result is _STOP_RUNNING_SENTINEL:
            break
class Connection:
    """Asyncio proxy around a ``sqlite3.Connection`` owned by a worker thread.

    Most operations on the wrapped connection are queued, as zero-argument
    callables, to ``_connection_worker_thread``. The caller then awaits them
    through ``asyncio.Future`` objects, so sqlite work runs off the event
    loop.

    Instances start unconnected: ``await conn`` or ``async with conn`` starts
    the worker thread and opens the database.
    """

    def __init__(
        self,
        connector: Callable[[], sqlite3.Connection],
        iter_chunk_size: int,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        """
        :param connector: Zero-argument callable, run on the worker thread, that
            opens and returns the underlying ``sqlite3.Connection``.
        :param iter_chunk_size: Stored as ``self._iter_chunk_size`` (not used
            inside this class; presumably consumed by ``Cursor``).
        :param loop: Deprecated and ignored; passing it emits a warning.
        """
        self._running = True
        self._connection: Optional[sqlite3.Connection] = None
        self._connector = connector
        self._tx: _TxQueue = SimpleQueue()
        self._iter_chunk_size = iter_chunk_size
        self._thread = Thread(target=_connection_worker_thread, args=(self._tx,))

        if loop is not None:
            warn(
                "aiosqlite.Connection no longer uses the `loop` parameter",
                DeprecationWarning,
                # Attribute the warning to the caller, not to aiosqlite itself.
                stacklevel=2,
            )

    def __del__(self):
        # getattr: tolerate an instance whose __init__ did not complete.
        if getattr(self, "_connection", None) is None:
            return

        warn(
            (
                f"{self!r} was deleted before being closed. "
                "Please use 'async with' or '.close()' to close the connection properly."
            ),
            ResourceWarning,
            stacklevel=1,
        )

        # Don't try to be creative here, the event loop may have already been closed.
        # Simply stop the worker thread, and let the underlying sqlite3 connection
        # be finalized by its own __del__.
        self.stop()

    def stop(self) -> Optional[asyncio.Future]:
        """Stop the background thread. Prefer `async with` or `await close()`

        Queues a final job that closes the sqlite connection (if open) and
        returns ``_STOP_RUNNING_SENTINEL``, which makes the worker thread exit.
        Jobs queued earlier still run first.

        :returns: A future resolved once that final job has run. Returns
            ``None`` when no usable event loop is available, e.g. when called
            from ``__del__`` after the loop has been closed.
        """
        self._running = False

        def close_and_stop():
            if self._connection is not None:
                self._connection.close()
                self._connection = None
            return _STOP_RUNNING_SENTINEL

        try:
            loop = asyncio.get_event_loop()
            # A closed loop still hands out futures, but the worker could never
            # resolve them (call_soon_threadsafe raises on a closed loop).
            future = None if loop.is_closed() else loop.create_future()
        except Exception:
            future = None

        self._tx.put_nowait((future, close_and_stop))
        return future

    @property
    def _conn(self) -> sqlite3.Connection:
        """The open sqlite3 connection; raises ``ValueError`` if not connected."""
        if self._connection is None:
            raise ValueError("no active connection")

        return self._connection

    def _execute_insert(self, sql: str, parameters: Any) -> Optional[sqlite3.Row]:
        """Run ``sql``, then return the row produced by ``SELECT last_insert_rowid()``.

        Executed on the worker thread (via ``_execute``). The helper cursor is
        closed before returning.
        """
        cursor = self._conn.execute(sql, parameters)
        try:
            cursor.execute("SELECT last_insert_rowid()")
            return cursor.fetchone()
        finally:
            cursor.close()

    def _execute_fetchall(self, sql: str, parameters: Any) -> Iterable[sqlite3.Row]:
        """Run ``sql`` and return all result rows.

        Executed on the worker thread (via ``_execute``). The helper cursor is
        closed before returning.
        """
        cursor = self._conn.execute(sql, parameters)
        try:
            return cursor.fetchall()
        finally:
            cursor.close()

    async def _execute(self, fn, *args, **kwargs):
        """Queue a function with the given arguments for execution.

        :raises ValueError: If the connection is closed, stopping, or was never
            opened.
        :returns: Whatever ``fn(*args, **kwargs)`` returns on the worker thread;
            exceptions it raises are re-raised here.
        """
        if not self._running or self._connection is None:
            raise ValueError("Connection closed")

        function = partial(fn, *args, **kwargs)
        future = asyncio.get_running_loop().create_future()

        self._tx.put_nowait((future, function))

        return await future

    async def _connect(self) -> "Connection":
        """Connect to the actual sqlite database.

        Does nothing if already connected. If ``connector`` fails, or the await
        is cancelled, the worker thread is stopped and the error re-raised.
        """
        if self._connection is None:
            if not self._running:
                # stop() already queued the shutdown job, so the worker exits
                # before it could run the connector; waiting would hang forever.
                raise ValueError("Connection closed")
            try:
                future = asyncio.get_running_loop().create_future()
                self._tx.put_nowait((future, self._connector))
                self._connection = await future
            except BaseException:
                self.stop()
                self._connection = None
                raise

        return self

    def __await__(self) -> Generator[Any, None, "Connection"]:
        # Awaiting a connection that is already open, e.g.
        # ``async with await aiosqlite.connect(...)``, must not start the worker
        # thread a second time. That would raise
        # "RuntimeError: threads can only be started once".
        if self._connection is None or not self._running:
            self._thread.start()
        return self._connect().__await__()

    async def __aenter__(self) -> "Connection":
        return await self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        await self.close()

    @contextmanager
    async def cursor(self) -> Cursor:
        """Create an aiosqlite cursor wrapping a sqlite3 cursor object."""
        return Cursor(self, await self._execute(self._conn.cursor))

    async def commit(self) -> None:
        """Commit the current transaction."""
        await self._execute(self._conn.commit)

    async def rollback(self) -> None:
        """Roll back the current transaction."""
        await self._execute(self._conn.rollback)

    async def close(self) -> None:
        """Complete queued queries/cursors and close the connection.

        No-op if the connection is not open. The worker thread is stopped even
        if closing the sqlite connection raises.
        """
        if self._connection is None:
            return

        try:
            await self._execute(self._conn.close)
        except Exception:
            LOG.info("exception occurred while closing connection")
            raise
        finally:
            self._connection = None
            future = self.stop()
            if future:
                await future

    @contextmanager
    async def execute(
        self, sql: str, parameters: Optional[Iterable[Any]] = None
    ) -> Cursor:
        """Helper to create a cursor and execute the given query."""
        if parameters is None:
            parameters = []
        cursor = await self._execute(self._conn.execute, sql, parameters)
        return Cursor(self, cursor)

    @contextmanager
    async def execute_insert(
        self, sql: str, parameters: Optional[Iterable[Any]] = None
    ) -> Optional[sqlite3.Row]:
        """Helper to insert and get the last_insert_rowid."""
        if parameters is None:
            parameters = []
        return await self._execute(self._execute_insert, sql, parameters)

    @contextmanager
    async def execute_fetchall(
        self, sql: str, parameters: Optional[Iterable[Any]] = None
    ) -> Iterable[sqlite3.Row]:
        """Helper to execute a query and return all the data."""
        if parameters is None:
            parameters = []
        return await self._execute(self._execute_fetchall, sql, parameters)

    @contextmanager
    async def executemany(
        self, sql: str, parameters: Iterable[Iterable[Any]]
    ) -> Cursor:
        """Helper to create a cursor and execute the given multiquery."""
        cursor = await self._execute(self._conn.executemany, sql, parameters)
        return Cursor(self, cursor)

    @contextmanager
    async def executescript(self, sql_script: str) -> Cursor:
        """Helper to create a cursor and execute a user script."""
        cursor = await self._execute(self._conn.executescript, sql_script)
        return Cursor(self, cursor)

    async def interrupt(self) -> None:
        """Interrupt pending queries.

        Called directly rather than queued. The worker thread runs one job at a
        time, so a queued interrupt would only run after the query it is meant
        to interrupt.
        """
        return self._conn.interrupt()

    async def create_function(
        self, name: str, num_params: int, func: Callable, deterministic: bool = False
    ) -> None:
        """
        Create user-defined function that can be later used
        within SQL statements. Must be run within the same thread
        that query executions take place so instead of executing directly
        against the connection, we defer this to `run` function.

        If ``deterministic`` is true, the created function is marked as deterministic,
        which allows SQLite to perform additional optimizations. This flag is supported
        by SQLite 3.8.3 or higher, ``NotSupportedError`` will be raised if used with
        older versions.
        """
        await self._execute(
            self._conn.create_function,
            name,
            num_params,
            func,
            deterministic=deterministic,
        )

    # The attribute proxies below read/write the sqlite3 connection directly
    # from the calling thread instead of queueing through _execute.

    @property
    def in_transaction(self) -> bool:
        return self._conn.in_transaction

    @property
    def isolation_level(self) -> Optional[str]:
        return self._conn.isolation_level

    @isolation_level.setter
    def isolation_level(self, value: IsolationLevel) -> None:
        self._conn.isolation_level = value

    @property
    def row_factory(self) -> Optional[type]:
        return self._conn.row_factory

    @row_factory.setter
    def row_factory(self, factory: Optional[type]) -> None:
        self._conn.row_factory = factory

    @property
    def text_factory(self) -> Callable[[bytes], Any]:
        return self._conn.text_factory

    @text_factory.setter
    def text_factory(self, factory: Callable[[bytes], Any]) -> None:
        self._conn.text_factory = factory

    @property
    def total_changes(self) -> int:
        return self._conn.total_changes

    async def enable_load_extension(self, value: bool) -> None:
        await self._execute(self._conn.enable_load_extension, value)  # type: ignore

    async def load_extension(self, path: str) -> None:
        await self._execute(self._conn.load_extension, path)  # type: ignore

    async def set_progress_handler(
        self, handler: Callable[[], Optional[int]], n: int
    ) -> None:
        await self._execute(self._conn.set_progress_handler, handler, n)

    async def set_trace_callback(self, handler: Callable) -> None:
        await self._execute(self._conn.set_trace_callback, handler)

    async def set_authorizer(
        self, authorizer_callback: Optional[AuthorizerCallback]
    ) -> None:
        """
        Set an authorizer callback to control database access.

        The authorizer callback is invoked for each SQL statement that is prepared,
        and controls whether specific operations are permitted.

        Example::

            import sqlite3

            def restrict_drops(action_code, arg1, arg2, db_name, trigger_name):
                # Deny all DROP operations
                if action_code == sqlite3.SQLITE_DROP_TABLE:
                    return sqlite3.SQLITE_DENY
                # Allow everything else
                return sqlite3.SQLITE_OK

            await conn.set_authorizer(restrict_drops)

        See ``sqlite3`` documentation for details:
        https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.set_authorizer

        :param authorizer_callback: An optional callable that receives five arguments:

            - ``action_code`` (int): The action to be authorized (e.g., ``SQLITE_READ``)
            - ``arg1`` (str): First argument, meaning depends on ``action_code``
            - ``arg2`` (str): Second argument, meaning depends on ``action_code``
            - ``db_name`` (str): Database name (e.g., ``"main"``, ``"temp"``)
            - ``trigger_name`` (str): Name of trigger or view that is doing the access,
              or ``None``

            The callback should return:

            - ``SQLITE_OK`` (0): Allow the operation
            - ``SQLITE_DENY`` (1): Deny the operation, raise ``sqlite3.DatabaseError``
            - ``SQLITE_IGNORE`` (2): Treat operation as no-op

            Pass ``None`` to remove the authorizer.
        """
        await self._execute(self._conn.set_authorizer, authorizer_callback)

    async def iterdump(self) -> AsyncIterator[str]:
        """
        Return an async iterator to dump the database in SQL text format.

        Example::

            async for line in db.iterdump():
                ...

        """
        # The dump is produced on the worker thread and handed over through a
        # thread-safe queue. None marks the end, and is also pushed on failure.
        dump_queue: Queue = Queue()

        def dumper():
            try:
                for line in self._conn.iterdump():
                    dump_queue.put_nowait(line)
                dump_queue.put_nowait(None)

            except Exception:
                LOG.exception("exception while dumping db")
                dump_queue.put_nowait(None)
                raise

        fut = self._execute(dumper)
        task = asyncio.ensure_future(fut)

        while True:
            try:
                line: Optional[str] = dump_queue.get_nowait()
                if line is None:
                    break
                yield line

            except Empty:
                if task.done():
                    LOG.warning("iterdump completed unexpectedly")
                    break

                # Poll: yield to the event loop while the worker produces lines.
                await asyncio.sleep(0.01)

        # Surface any exception raised by dumper() (or by _execute itself).
        await task

    async def backup(
        self,
        target: Union["Connection", sqlite3.Connection],
        *,
        pages: int = 0,
        progress: Optional[Callable[[int, int, int], None]] = None,
        name: str = "main",
        sleep: float = 0.250,
    ) -> None:
        """
        Make a backup of the current database to the target database.

        Takes either a standard sqlite3 or aiosqlite Connection object as the target.
        """
        if isinstance(target, Connection):
            target = target._conn

        await self._execute(
            self._conn.backup,
            target,
            pages=pages,
            progress=progress,
            name=name,
            sleep=sleep,
        )
def connect(
    database: Union[str, bytes, Path],
    *,
    iter_chunk_size: int = 64,
    loop: Optional[asyncio.AbstractEventLoop] = None,
    **kwargs: Any,
) -> Connection:
    """Create and return a connection proxy to the sqlite database.

    The database is not opened here. ``sqlite3.connect`` runs on the worker
    thread once the returned proxy is awaited or used with ``async with``.

    :param database: Database location as ``str``, ``bytes`` or any
        ``os.PathLike``. Other objects are converted with ``str()``.
    :param iter_chunk_size: Passed to :class:`Connection`.
    :param loop: Deprecated and ignored; passing it emits a warning.
    :param kwargs: Forwarded unchanged to ``sqlite3.connect``.
    """
    import os  # local: only needed to normalise bytes / path-like locations

    if loop is not None:
        warn(
            "aiosqlite.connect() no longer uses the `loop` parameter",
            DeprecationWarning,
            # Attribute the warning to the caller, not to aiosqlite itself.
            stacklevel=2,
        )

    def connector() -> sqlite3.Connection:
        if isinstance(database, str):
            loc = database
        elif isinstance(database, (bytes, os.PathLike)):
            # os.fsdecode honours __fspath__ and the filesystem encoding (with
            # surrogateescape), so non-UTF-8 byte paths round-trip instead of
            # raising UnicodeDecodeError as bytes.decode("utf-8") did.
            loc = os.fsdecode(database)
        else:
            loc = str(database)

        return sqlite3.connect(loc, **kwargs)

    return Connection(connector, iter_chunk_size)