# Source code for wsqlite.core.connection

"""Connection management for SQLite with thread-safe operations."""

import logging
import threading
from contextlib import asynccontextmanager, contextmanager
from typing import Any, AsyncIterator, Iterator, Optional

logger = logging.getLogger(__name__)

# A single sqlite3 connection is shared module-wide; this lock serialises
# creating, replacing and closing it.
_global_connection_lock = threading.Lock()
_global_connection: Optional[Any] = None  # the shared connection, if open
_db_path: Optional[str] = None  # path the shared connection was opened with


def _get_connection(db_path: str) -> Any:
    """Return the shared connection, (re)opening it when needed.

    A new ``sqlite3`` connection (``check_same_thread=False``, rows returned
    as ``sqlite3.Row``) is opened when none exists yet, or when *db_path*
    differs from the path the current one was opened with. In the latter case
    the previous connection object is simply replaced, not explicitly closed.
    """
    global _global_connection, _db_path

    with _global_connection_lock:
        can_reuse = _global_connection is not None and _db_path == db_path
        if not can_reuse:
            import sqlite3

            opened = sqlite3.connect(db_path, check_same_thread=False)
            opened.row_factory = sqlite3.Row
            _global_connection = opened
            _db_path = db_path
            logger.info(f"Created global connection to {db_path}")
        return _global_connection


def close_global_connection():
    """Close the shared connection, if any, and forget its path.

    Calling this when no connection is open does nothing.
    """
    global _global_connection, _db_path

    with _global_connection_lock:
        if not _global_connection:
            return
        _global_connection.close()
        _global_connection = None
        _db_path = None
        logger.info("Global SQLite connection closed")
class _SQLiteConnection:
    """Wrapper for the shared SQLite connection handed out by ``get_connection``.

    Attribute access is forwarded to the wrapped connection. Used as a
    context manager, it yields the raw connection. On exit it commits the
    pending transaction, or rolls it back if the ``with`` body raised, by
    delegating to the wrapped connection's own ``__exit__``. The connection
    is NOT closed, because it is shared module-wide.
    """

    def __init__(self, conn: Any):
        # Expected to be a sqlite3.Connection (that is what _get_connection
        # creates); commit/rollback on exit relies on its __exit__.
        self._conn = conn

    def __enter__(self):
        return self._conn

    def __exit__(self, *args):
        # A no-op here would leave DML from the ``with`` body inside an open
        # implicit transaction on the shared connection. The writes would be
        # invisible to other connections and the write lock held, surfacing
        # as "database is locked" elsewhere. sqlite3.Connection.__exit__
        # commits (or rolls back on error) but never closes the connection.
        return self._conn.__exit__(*args)

    def __getattr__(self, name):
        return getattr(self._conn, name)
class Transaction:
    """Context manager for database transactions (sync).

    Entering opens a dedicated ``sqlite3`` connection to ``db_path``. It
    enables WAL journaling and a 5000 ms busy timeout, switches the
    connection to autocommit mode (``isolation_level = None``) and issues an
    explicit ``BEGIN``.

    Exiting rolls back if the ``with`` body raised. Otherwise it commits,
    unless :meth:`commit` was already called. The connection is always
    closed, even if the rollback or commit itself fails.
    """

    def __init__(self, db_path: str):
        self.db_path = db_path
        self.conn: Optional[Any] = None  # sqlite3 connection, set by __enter__
        self._committed = False  # True once commit() has been called

    def __enter__(self):
        import sqlite3

        self.conn = sqlite3.connect(self.db_path)
        try:
            self.conn.execute("PRAGMA journal_mode=WAL")
            self.conn.execute("PRAGMA busy_timeout=5000")
            # Autocommit mode, so the sqlite3 module never issues its own
            # implicit BEGIN; the explicit BEGIN below starts the transaction.
            self.conn.isolation_level = None
            self.conn.execute("BEGIN")
        except BaseException:
            # __exit__ is not called when __enter__ raises, so release the
            # connection here instead of leaking it.
            self.conn.close()
            raise
        logger.debug("Transaction started")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            if exc_type is not None:
                self.conn.rollback()
                logger.debug("Transaction rolled back")
            elif not self._committed:
                self.conn.commit()
                logger.debug("Transaction committed")
        finally:
            # Close even if rollback()/commit() raised.
            self.conn.close()
        return False

    def commit(self):
        """Commit now; ``__exit__`` will then skip its own commit.

        No new BEGIN is issued afterwards, so later statements on this
        connection run in autocommit mode.
        """
        self.conn.commit()
        self._committed = True

    def rollback(self):
        """Roll back the current transaction.

        No new BEGIN is issued afterwards, so later statements on this
        connection run in autocommit mode.
        """
        self.conn.rollback()

    def execute(self, query: str, values: Optional[tuple] = None) -> Any:
        """Run *query* with *values* (no parameters when omitted).

        Returns all fetched rows when the statement produces a result set,
        otherwise the cursor's ``rowcount``.
        """
        cursor = self.conn.cursor()
        cursor.execute(query, values or ())
        if cursor.description:
            return cursor.fetchall()
        return cursor.rowcount
class AsyncTransaction:
    """Context manager for database transactions (async).

    ``__aenter__`` opens a dedicated ``aiosqlite`` connection to ``db_path``.
    It enables WAL journaling and a 5000 ms busy timeout, switches to
    autocommit mode (``isolation_level = None``) and issues an explicit
    ``BEGIN``.

    ``__aexit__`` rolls back if the body raised. Otherwise it commits,
    unless :meth:`commit` was already called. The connection is always
    closed, even if the rollback or commit itself fails.
    """

    def __init__(self, db_path: str):
        self.db_path = db_path
        self.conn: Optional[Any] = None  # aiosqlite connection, set by __aenter__
        self._committed = False  # True once commit() has been called

    async def __aenter__(self):
        import aiosqlite

        self.conn = await aiosqlite.connect(self.db_path)
        try:
            await self.conn.execute("PRAGMA journal_mode=WAL")
            await self.conn.execute("PRAGMA busy_timeout=5000")
            # Autocommit mode: the explicit BEGIN below starts the
            # transaction instead of the driver's implicit BEGIN.
            self.conn.isolation_level = None
            await self.conn.execute("BEGIN")
        except BaseException:
            # __aexit__ is not invoked when __aenter__ raises; close here so
            # the connection is not leaked.
            await self.conn.close()
            raise
        logger.debug("Async transaction started")
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        try:
            if exc_type is not None:
                await self.conn.rollback()
                logger.debug("Async transaction rolled back")
            elif not self._committed:
                await self.conn.commit()
                logger.debug("Async transaction committed")
        finally:
            # Close even if rollback()/commit() raised.
            await self.conn.close()
        return False

    async def commit(self):
        """Commit now; ``__aexit__`` will then skip its own commit.

        No new BEGIN is issued afterwards, so later statements on this
        connection run in autocommit mode.
        """
        await self.conn.commit()
        self._committed = True

    async def rollback(self):
        """Roll back the current transaction.

        No new BEGIN is issued afterwards, so later statements on this
        connection run in autocommit mode.
        """
        await self.conn.rollback()

    async def execute(self, query: str, values: Optional[tuple] = None) -> Any:
        """Run *query* with *values* (no parameters when omitted).

        Returns all fetched rows when the statement produces a result set,
        otherwise the cursor's ``rowcount``.
        """
        cursor = await self.conn.cursor()
        await cursor.execute(query, values or ())
        if cursor.description:
            return await cursor.fetchall()
        return cursor.rowcount
@contextmanager
def get_transaction(db_path: str) -> Iterator[Transaction]:
    """Yield an entered :class:`Transaction` for *db_path* (sync).

    The transaction has begun before the ``with`` body runs. When the block
    exits, it is committed, or rolled back if the body raised.

    Usage:
        with get_transaction("database.db") as tx:
            tx.execute(...)
    """
    # Enter the Transaction so its connection is open and BEGIN has run;
    # yielding an un-entered one would leave ``tx.conn`` as None.
    with Transaction(db_path) as transaction:
        yield transaction
@asynccontextmanager
async def get_async_transaction(db_path: str) -> AsyncIterator[AsyncTransaction]:
    """Yield an entered :class:`AsyncTransaction` for *db_path*.

    Must be used with ``async with``. ``contextlib.contextmanager`` cannot
    wrap an async generator, hence ``asynccontextmanager``.

    Usage:
        async with get_async_transaction("database.db") as tx:
            await tx.execute(...)
    """
    async with AsyncTransaction(db_path) as transaction:
        yield transaction
def get_connection(db_path: str) -> _SQLiteConnection:
    """Return the module's shared sync connection, wrapped.

    This is a single shared ``sqlite3`` connection (created lazily by
    ``_get_connection``), not a pool. Repeated calls with the same *db_path*
    wrap the same underlying connection, as long as no other path was
    requested and ``close_global_connection`` was not called in between.

    Usage:
        with get_connection("database.db") as conn:
            conn.execute(...)
    """
    shared = _get_connection(db_path)
    return _SQLiteConnection(shared)
async def get_async_connection(db_path: str) -> Any:
    """Open and configure a new ``aiosqlite`` connection.

    The connection returns rows as ``aiosqlite.Row`` and uses WAL journaling
    with a 5000 ms busy timeout. Every call opens a fresh connection that is
    not tracked anywhere else, so the caller must close it.

    This is a coroutine function, so await it. ``async with
    get_async_connection(...)`` does not work, because a coroutine object is
    not an async context manager.

    Usage:
        conn = await get_async_connection("database.db")
        try:
            await conn.execute(...)
        finally:
            await conn.close()
    """
    import aiosqlite

    conn = await aiosqlite.connect(db_path)
    try:
        conn.row_factory = aiosqlite.Row
        await conn.executescript("PRAGMA journal_mode=WAL; PRAGMA busy_timeout=5000;")
    except BaseException:
        # The caller never receives the connection if setup fails, so close
        # it here rather than leak it.
        await conn.close()
        raise
    return conn
def retry_on_lock(max_retries: int = 3, delay: float = 0.1):
    """Decorator to retry operations on database lock.

    The wrapped callable is attempted up to ``max_retries`` times, and at
    least once even when ``max_retries <= 0``.

    - An exception whose message contains "locked" or "busy"
      (case-insensitive) triggers another attempt after an exponential
      back-off of ``delay * 2**attempt`` seconds.
    - Any other exception propagates immediately.
    - If the final attempt also hits a lock error, that error is re-raised
      at once, with no trailing sleep.

    Usage:
        @retry_on_lock(max_retries=5, delay=0.2)
        def insert_data(data):
            db.insert(data)
    """
    import functools
    import time

    # Previously max_retries <= 0 never called the function and ended in
    # ``raise None`` (TypeError); always make at least one attempt.
    attempts = max(1, max_retries)

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(attempts):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    error_msg = str(e).lower()
                    if "locked" not in error_msg and "busy" not in error_msg:
                        raise
                    if attempt == attempts - 1:
                        # Out of attempts: sleeping first would only delay
                        # the inevitable failure.
                        raise
                    wait_time = delay * (2**attempt)
                    logger.debug(
                        "Database locked, retry %d/%d after %.2fs",
                        attempt + 1,
                        max_retries,
                        wait_time,
                    )
                    time.sleep(wait_time)

        return wrapper

    return decorator