Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/aiosqlite/core.py: 35%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Copyright Amethyst Reese
2# Licensed under the MIT license
4"""
5Core implementation of aiosqlite proxies
6"""
import asyncio
import logging
import os
import sqlite3
from collections.abc import AsyncIterator, Generator, Iterable
from functools import partial
from pathlib import Path
from queue import Empty, Queue, SimpleQueue
from threading import Thread
from typing import Any, Callable, Literal, Optional, Union
from warnings import warn

from .context import contextmanager
from .cursor import Cursor
# Public names exported by ``from aiosqlite.core import *``.
__all__ = ["connect", "Connection", "Cursor"]

# Package-wide logger; the worker thread and Connection log through it.
LOG = logging.getLogger("aiosqlite")

# Shape of a callback accepted by Connection.set_authorizer:
#   (action_code, arg1, arg2, db_name, trigger_name) -> int status code
AuthorizerCallback = Callable[[int, str, str, str, str], int]

# Values accepted by the Connection.isolation_level setter; the value is
# handed straight to the underlying sqlite3 connection.
IsolationLevel = Optional[Literal["DEFERRED", "IMMEDIATE", "EXCLUSIVE"]]
def set_result(fut: asyncio.Future, result: Any) -> None:
    """
    Resolve ``fut`` with ``result`` unless it already has an outcome.

    A future that is already done (result, exception or cancelled) is left
    untouched, so a late delivery is silently dropped instead of raising
    ``InvalidStateError``.
    """
    if fut.done():
        return
    fut.set_result(result)
def set_exception(fut: asyncio.Future, e: BaseException) -> None:
    """
    Fail ``fut`` with exception ``e`` unless it already has an outcome.

    Already-finished futures (including cancelled ones) are left as they
    are, so a late delivery is silently dropped.
    """
    if fut.done():
        return
    fut.set_exception(e)
# Unique marker returned by the stop request queued to the worker thread.
# The worker compares results against it with ``is``, so no real query
# result can ever be mistaken for it.
_STOP_RUNNING_SENTINEL: object = object()
def _post_to_loop(
    future: asyncio.Future,
    callback: Callable[[asyncio.Future, Any], None],
    value: Any,
) -> None:
    """
    Schedule ``callback(future, value)`` on the event loop owning ``future``.

    ``call_soon_threadsafe`` raises ``RuntimeError`` once that loop has been
    closed. Nobody can await the future at that point, so the outcome is
    logged and dropped rather than crashing the worker thread.

    :meta private:
    """
    try:
        future.get_loop().call_soon_threadsafe(callback, future, value)
    except RuntimeError:
        LOG.debug("event loop closed; dropping outcome for %s", future)


def _connection_worker_thread(
    tx: SimpleQueue[tuple[asyncio.Future, Callable[[], Any]]],
):
    """
    Execute function calls on a separate thread.

    Takes ``(future, function)`` pairs from ``tx`` in FIFO order and runs each
    ``function()``. Its return value, or the exception it raised, is posted
    to ``future`` on that future's own event loop. The loop exits after
    running a function that returns ``_STOP_RUNNING_SENTINEL``.

    :meta private:
    """
    while True:
        # Runs until the stop sentinel is reached. Everything queued ahead of
        # it (even work queued after the sqlite3 connection was closed) is
        # still executed, so every one of those futures gets an outcome.
        future, function = tx.get()

        try:
            LOG.debug("executing %s", function)
            result = function()
        except BaseException as e:  # noqa B036
            LOG.debug("returning exception %s", e)
            _post_to_loop(future, set_exception, e)
            continue

        LOG.debug("operation %s completed", function)
        # Delivery is kept outside the try above. A closed event loop must
        # not be misreported as a failure of ``function``, and must not
        # raise out of this thread.
        _post_to_loop(future, set_result, result)

        if result is _STOP_RUNNING_SENTINEL:
            break
class Connection:
    """
    Asyncio proxy around a :class:`sqlite3.Connection`.

    Queries and other connection operations are queued as
    ``(future, function)`` pairs on ``self._tx``. A single dedicated worker
    thread (``_connection_worker_thread``) runs them in FIFO order while the
    caller awaits the future. A few cheap accessors (the properties below and
    :meth:`interrupt`) instead touch the sqlite3 connection directly from the
    calling thread.

    Awaiting the instance, or entering it with ``async with``, starts the
    worker thread and opens the database.
    """

    def __init__(
        self,
        connector: Callable[[], sqlite3.Connection],
        iter_chunk_size: int,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        """
        :param connector: zero-argument callable, run on the worker thread,
            that opens and returns the underlying sqlite3 connection.
        :param iter_chunk_size: stored as ``_iter_chunk_size``. It is not
            read in this class; presumably :class:`Cursor` uses it when
            iterating rows (verify in cursor.py).
        :param loop: deprecated and ignored; passing it emits a
            ``DeprecationWarning``.
        """
        self._running = True
        self._connection: Optional[sqlite3.Connection] = None
        self._connector = connector
        # Work queue consumed by the worker thread.
        self._tx: SimpleQueue[tuple[asyncio.Future, Callable[[], Any]]] = SimpleQueue()
        self._iter_chunk_size = iter_chunk_size
        # Created here, started on the first await (see __await__).
        self._thread = Thread(target=_connection_worker_thread, args=(self._tx,))

        if loop is not None:
            warn(
                "aiosqlite.Connection no longer uses the `loop` parameter",
                DeprecationWarning,
            )

    def _stop_running(self) -> asyncio.Future:
        """
        Mark the connection as stopped and ask the worker thread to exit.

        Queues a function returning ``_STOP_RUNNING_SENTINEL`` behind any
        pending work. The worker leaves its loop after running it. Returns
        the future resolved when the worker reaches that item. This is
        synchronous because ``__del__`` calls it too; coroutine callers
        should await the returned future.
        """
        self._running = False

        # Deliberately not get_running_loop(): __del__ may call this while
        # no event loop is running.
        future = asyncio.get_event_loop().create_future()

        self._tx.put_nowait((future, lambda: _STOP_RUNNING_SENTINEL))

        return future

    @property
    def _conn(self) -> sqlite3.Connection:
        """The open sqlite3 connection; raises ``ValueError`` if there is none."""
        if self._connection is None:
            raise ValueError("no active connection")

        return self._connection

    def _execute_insert(self, sql: str, parameters: Any) -> Optional[sqlite3.Row]:
        """Worker-thread helper: run ``sql`` and return ``last_insert_rowid()``."""
        cursor = self._conn.execute(sql, parameters)
        cursor.execute("SELECT last_insert_rowid()")
        return cursor.fetchone()

    def _execute_fetchall(self, sql: str, parameters: Any) -> Iterable[sqlite3.Row]:
        """Worker-thread helper: run ``sql`` and return every resulting row."""
        cursor = self._conn.execute(sql, parameters)
        return cursor.fetchall()

    async def _execute(self, fn, *args, **kwargs):
        """
        Queue ``fn(*args, **kwargs)`` for the worker thread and await it.

        :raises ValueError: if the connection is stopped or not open.
        Any exception raised by ``fn`` on the worker thread is re-raised here.
        """
        if not self._running or self._connection is None:
            raise ValueError("Connection closed")

        function = partial(fn, *args, **kwargs)
        future = asyncio.get_running_loop().create_future()

        self._tx.put_nowait((future, function))

        return await future

    async def _connect(self) -> "Connection":
        """
        Connect to the actual sqlite database.

        Runs the connector on the worker thread; a no-op if already open.
        If opening fails, the worker is stopped and the error re-raised.

        :raises ValueError: if the worker has already been stopped (after
            :meth:`close` or a failed connect). Queueing the connector then
            would wait forever.
        """
        if self._connection is None:
            if not self._running:
                raise ValueError("Connection closed")
            try:
                future = asyncio.get_running_loop().create_future()
                self._tx.put_nowait((future, self._connector))
                self._connection = await future
            except BaseException:
                await self._stop_running()
                self._connection = None
                raise

        return self

    def __await__(self) -> Generator[Any, None, "Connection"]:
        """
        Start the worker thread (first await only) and open the database.

        Awaiting an already-open connection again, e.g. ``async with`` on
        the result of ``await aiosqlite.connect(...)``, returns it unchanged
        instead of failing with "threads can only be started once".
        """
        # Thread.ident stays None until start() has been called.
        if self._thread.ident is None:
            self._thread.start()
        return self._connect().__await__()

    async def __aenter__(self) -> "Connection":
        """Open the connection (if needed) and return it."""
        return await self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the connection, whether or not the block raised."""
        await self.close()

    @contextmanager
    async def cursor(self) -> Cursor:
        """Create an aiosqlite cursor wrapping a sqlite3 cursor object."""
        return Cursor(self, await self._execute(self._conn.cursor))

    async def commit(self) -> None:
        """Commit the current transaction."""
        await self._execute(self._conn.commit)

    async def rollback(self) -> None:
        """Roll back the current transaction."""
        await self._execute(self._conn.rollback)

    async def close(self) -> None:
        """
        Complete queued queries/cursors and close the connection.

        The close call is queued behind pending work. The worker thread is
        then stopped, even if closing raised. Calling this on a connection
        that is not open does nothing.
        """

        if self._connection is None:
            return

        try:
            await self._execute(self._conn.close)
        except Exception:
            LOG.info("exception occurred while closing connection")
            raise
        finally:
            await self._stop_running()
            self._connection = None

    @contextmanager
    async def execute(
        self, sql: str, parameters: Optional[Iterable[Any]] = None
    ) -> Cursor:
        """Helper to create a cursor and execute the given query."""
        if parameters is None:
            parameters = []
        cursor = await self._execute(self._conn.execute, sql, parameters)
        return Cursor(self, cursor)

    @contextmanager
    async def execute_insert(
        self, sql: str, parameters: Optional[Iterable[Any]] = None
    ) -> Optional[sqlite3.Row]:
        """Helper to insert and get the last_insert_rowid."""
        if parameters is None:
            parameters = []
        return await self._execute(self._execute_insert, sql, parameters)

    @contextmanager
    async def execute_fetchall(
        self, sql: str, parameters: Optional[Iterable[Any]] = None
    ) -> Iterable[sqlite3.Row]:
        """Helper to execute a query and return all the data."""
        if parameters is None:
            parameters = []
        return await self._execute(self._execute_fetchall, sql, parameters)

    @contextmanager
    async def executemany(
        self, sql: str, parameters: Iterable[Iterable[Any]]
    ) -> Cursor:
        """Helper to create a cursor and execute the given multiquery."""
        cursor = await self._execute(self._conn.executemany, sql, parameters)
        return Cursor(self, cursor)

    @contextmanager
    async def executescript(self, sql_script: str) -> Cursor:
        """Helper to create a cursor and execute a user script."""
        cursor = await self._execute(self._conn.executescript, sql_script)
        return Cursor(self, cursor)

    async def interrupt(self) -> None:
        """
        Interrupt pending queries.

        Called directly on the calling thread rather than queued, so it does
        not wait behind the work it is meant to interrupt.
        """
        return self._conn.interrupt()

    async def create_function(
        self, name: str, num_params: int, func: Callable, deterministic: bool = False
    ) -> None:
        """
        Create user-defined function that can be later used
        within SQL statements. Must be run within the same thread
        that query executions take place so instead of executing directly
        against the connection, we defer this to `run` function.

        If ``deterministic`` is true, the created function is marked as deterministic,
        which allows SQLite to perform additional optimizations. This flag is supported
        by SQLite 3.8.3 or higher, ``NotSupportedError`` will be raised if used with
        older versions.
        """
        await self._execute(
            self._conn.create_function,
            name,
            num_params,
            func,
            deterministic=deterministic,
        )

    @property
    def in_transaction(self) -> bool:
        """``in_transaction`` of the sqlite3 connection, read on the calling thread."""
        return self._conn.in_transaction

    @property
    def isolation_level(self) -> Optional[str]:
        """``isolation_level`` of the sqlite3 connection, read on the calling thread."""
        return self._conn.isolation_level

    @isolation_level.setter
    def isolation_level(self, value: IsolationLevel) -> None:
        self._conn.isolation_level = value

    @property
    def row_factory(self) -> Optional[type]:
        """``row_factory`` of the sqlite3 connection, accessed on the calling thread."""
        return self._conn.row_factory

    @row_factory.setter
    def row_factory(self, factory: Optional[type]) -> None:
        self._conn.row_factory = factory

    @property
    def text_factory(self) -> Callable[[bytes], Any]:
        """``text_factory`` of the sqlite3 connection, accessed on the calling thread."""
        return self._conn.text_factory

    @text_factory.setter
    def text_factory(self, factory: Callable[[bytes], Any]) -> None:
        self._conn.text_factory = factory

    @property
    def total_changes(self) -> int:
        """``total_changes`` of the sqlite3 connection, read on the calling thread."""
        return self._conn.total_changes

    async def enable_load_extension(self, value: bool) -> None:
        """Enable or disable extension loading (run on the worker thread)."""
        await self._execute(self._conn.enable_load_extension, value)  # type: ignore

    async def load_extension(self, path: str):
        """Load the SQLite extension at ``path`` (run on the worker thread)."""
        await self._execute(self._conn.load_extension, path)  # type: ignore

    async def set_progress_handler(
        self, handler: Callable[[], Optional[int]], n: int
    ) -> None:
        """Register ``handler`` via sqlite3 ``set_progress_handler(handler, n)``."""
        await self._execute(self._conn.set_progress_handler, handler, n)

    async def set_trace_callback(self, handler: Callable) -> None:
        """Register ``handler`` via sqlite3 ``set_trace_callback``."""
        await self._execute(self._conn.set_trace_callback, handler)

    async def set_authorizer(
        self, authorizer_callback: Optional[AuthorizerCallback]
    ) -> None:
        """
        Set an authorizer callback to control database access.

        The authorizer callback is invoked for each SQL statement that is prepared,
        and controls whether specific operations are permitted.

        Example::

            import sqlite3

            def restrict_drops(action_code, arg1, arg2, db_name, trigger_name):
                # Deny all DROP operations
                if action_code == sqlite3.SQLITE_DROP_TABLE:
                    return sqlite3.SQLITE_DENY
                # Allow everything else
                return sqlite3.SQLITE_OK

            await conn.set_authorizer(restrict_drops)

        See ``sqlite3`` documentation for details:
        https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.set_authorizer

        :param authorizer_callback: An optional callable that receives five arguments:

            - ``action_code`` (int): The action to be authorized (e.g., ``SQLITE_READ``)
            - ``arg1`` (str): First argument, meaning depends on ``action_code``
            - ``arg2`` (str): Second argument, meaning depends on ``action_code``
            - ``db_name`` (str): Database name (e.g., ``"main"``, ``"temp"``)
            - ``trigger_name`` (str): Name of trigger or view that is doing the access,
              or ``None``

            The callback should return:

            - ``SQLITE_OK`` (0): Allow the operation
            - ``SQLITE_DENY`` (1): Deny the operation, raise ``sqlite3.DatabaseError``
            - ``SQLITE_IGNORE`` (2): Treat operation as no-op

            Pass ``None`` to remove the authorizer.
        """
        await self._execute(self._conn.set_authorizer, authorizer_callback)

    async def iterdump(self) -> AsyncIterator[str]:
        """
        Return an async iterator to dump the database in SQL text format.

        Example::

            async for line in db.iterdump():
                ...

        """
        # The worker thread pushes lines into this thread-safe queue; this
        # coroutine polls it. ``None`` marks the end (or a failure).
        dump_queue: Queue = Queue()

        def dumper():
            try:
                for line in self._conn.iterdump():
                    dump_queue.put_nowait(line)
                dump_queue.put_nowait(None)

            except Exception:
                LOG.exception("exception while dumping db")
                dump_queue.put_nowait(None)
                raise

        fut = self._execute(dumper)
        task = asyncio.ensure_future(fut)

        while True:
            try:
                line: Optional[str] = dump_queue.get_nowait()
                if line is None:
                    break
                yield line

            except Empty:
                if task.done():
                    LOG.warning("iterdump completed unexpectedly")
                    break

                await asyncio.sleep(0.01)

        # Re-raises any exception the dumper hit on the worker thread.
        await task

    async def backup(
        self,
        target: Union["Connection", sqlite3.Connection],
        *,
        pages: int = 0,
        progress: Optional[Callable[[int, int, int], None]] = None,
        name: str = "main",
        sleep: float = 0.250,
    ) -> None:
        """
        Make a backup of the current database to the target database.

        Takes either a standard sqlite3 or aiosqlite Connection object as the target.
        An aiosqlite target is unwrapped to its underlying sqlite3 connection;
        the backup itself runs on this connection's worker thread.
        """
        if isinstance(target, Connection):
            target = target._conn

        await self._execute(
            self._conn.backup,
            target,
            pages=pages,
            progress=progress,
            name=name,
            sleep=sleep,
        )

    def __del__(self):
        """Warn about an unclosed connection and stop its worker thread."""
        if self._connection is None:
            return

        warn(
            (
                f"{self!r} was deleted before being closed. "
                "Please use 'async with' or '.close()' to close the connection properly."
            ),
            ResourceWarning,
            stacklevel=1,
        )

        # Don't try to be creative here, the event loop may have already been closed.
        # Simply stop the worker thread, and let the underlying sqlite3 connection
        # be finalized by its own __del__.
        self._stop_running()
def connect(
    database: Union[str, bytes, Path],
    *,
    iter_chunk_size: int = 64,
    loop: Optional[asyncio.AbstractEventLoop] = None,
    **kwargs: Any,
) -> Connection:
    """
    Create and return a connection proxy to the sqlite database.

    Nothing is opened here. The returned :class:`Connection` must be awaited
    (``await connect(...)``) or entered with ``async with``; that starts its
    worker thread and opens the database there.

    :param database: path to the database as ``str``, ``bytes`` or a path
        object. Any other value is converted with ``str()``.
    :param iter_chunk_size: forwarded unchanged to :class:`Connection`.
    :param loop: deprecated and ignored; passing it emits a
        ``DeprecationWarning``.
    :param kwargs: passed unchanged to :func:`sqlite3.connect`.
    """

    if loop is not None:
        warn(
            "aiosqlite.connect() no longer uses the `loop` parameter",
            DeprecationWarning,
        )

    def connector() -> sqlite3.Connection:
        # Runs on the Connection's worker thread.
        if isinstance(database, str):
            loc = database
        elif isinstance(database, bytes):
            # os.fsdecode mirrors the filesystem encoding (surrogateescape on
            # POSIX), so byte paths that are not valid UTF-8 round-trip to the
            # same file instead of raising UnicodeDecodeError.
            loc = os.fsdecode(database)
        else:
            loc = str(database)

        return sqlite3.connect(loc, **kwargs)

    return Connection(connector, iter_chunk_size)