Source code for wsqlite.core.repository

"""Main repository class for SQLite operations with connection pooling."""

import logging
import re
from typing import Any, Callable, Optional

from pydantic import BaseModel

from wsqlite.core.connection import (
    AsyncTransaction,
    Transaction,
    get_async_connection,
    get_connection,
    get_transaction,
    retry_on_lock,
)
from wsqlite.core.pool import ConnectionPool, get_pool, close_pool
from wsqlite.core.sync import AsyncTableSync, TableSync
from wsqlite.exceptions import DatabaseLockedError, SQLInjectionError, TransactionError

logger = logging.getLogger(__name__)


def validate_identifier(identifier: str) -> None:
    """Validate a SQL identifier to prevent SQL injection.

    An identifier is accepted only when the whole string consists of ASCII
    letters, digits and underscores and does not start with a digit.

    Args:
        identifier: Table or column name to validate.

    Raises:
        SQLInjectionError: If identifier contains dangerous characters.
    """
    # fullmatch rather than match + "$": "$" also matches just before a
    # trailing newline, so "users\n" used to pass validation.
    if not re.fullmatch(r"[a-zA-Z_][a-zA-Z0-9_]*", identifier):
        raise SQLInjectionError(identifier)


class WSQLite:
    """SQLite repository using Pydantic models.

    Provides a simple interface for CRUD operations on SQLite tables,
    with automatic table creation, schema synchronization, and connection
    pooling.

    Example:
        from pydantic import BaseModel
        from wsqlite import WSQLite

        class User(BaseModel):
            id: int
            name: str
            email: str

        db = WSQLite(User, "database.db")
        db.insert(User(id=1, name="John", email="john@example.com"))
    """
def __init__(
    self,
    model: type[BaseModel],
    db_path: str,
    pool_size: int = 10,
    min_pool_size: int = 2,
    use_pool: bool = True,
):
    """Bind the repository to one Pydantic model and one database file.

    The table name is the lower-cased model class name. A pool is obtained
    through ``get_pool`` only when ``use_pool`` is true; afterwards the table
    is created if missing and synchronized with the model via ``TableSync``.

    Args:
        model: Pydantic BaseModel class defining the table schema.
        db_path: Path to SQLite database file.
        pool_size: Maximum number of connections in pool.
        min_pool_size: Minimum number of connections in pool.
        use_pool: Whether to use connection pooling (recommended).
    """
    self.model = model
    self.db_path = db_path
    self.table_name = model.__name__.lower()
    self.use_pool = use_pool

    self._pool = (
        get_pool(db_path, min_size=min_pool_size, max_size=pool_size)
        if use_pool
        else None
    )

    table_sync = TableSync(model, db_path)
    self._sync = table_sync
    table_sync.create_if_not_exists()
    table_sync.sync_with_model()

    logger.info(
        f"WSQLite initialized for table '{self.table_name}' (pool={use_pool}, size={pool_size})"
    )
def _execute(self, query: str, values: tuple = (), commit: bool = True) -> Any:
    """Execute a single statement using the pool or a direct connection.

    Args:
        query: SQL text with ``?`` placeholders.
        values: Parameters bound to the placeholders.
        commit: Whether to commit after the statement ran.

    Returns:
        The fetched rows when the statement produces a result set
        (``cursor.description`` is set), otherwise ``cursor.rowcount``.

    Raises:
        Exception: Whatever the connection raises; the connection is rolled
            back first.
    """
    if self.use_pool and self._pool:
        connection_ctx = self._pool.connection()
    else:
        connection_ctx = get_connection(self.db_path)

    with connection_ctx as conn:
        try:
            cursor = conn.execute(query, values)
            result = cursor.fetchall() if cursor.description else cursor.rowcount
            if commit:
                conn.commit()
        except Exception:
            # Do not hand the connection back with a half-finished
            # transaction that a later commit could accidentally persist.
            conn.rollback()
            raise
        return result
def insert(self, data: BaseModel) -> None:
    """Insert one model instance as a new row.

    Column names and values both come from ``data.model_dump()``; the values
    are bound through ``?`` placeholders.

    Args:
        data: Model instance to store.
    """
    record = data.model_dump()
    column_list = ", ".join(record)
    marks = ", ".join("?" for _ in record)
    params = tuple(record.values())
    self._execute(
        f"INSERT INTO {self.table_name} ({column_list}) VALUES ({marks})",
        params,
    )
def get_all(self) -> list[BaseModel]:
    """Return every row of the table as a model instance.

    Row values are paired positionally with the model's field names; a NULL
    value is replaced by ``self._default_value(field)``.
    """
    rows = self._execute(f"SELECT * FROM {self.table_name}", commit=False)

    records = []
    for row in rows:
        kwargs = {}
        for name, value in zip(self.model.model_fields.keys(), row):
            kwargs[name] = self._default_value(name) if value is None else value
        records.append(self.model(**kwargs))
    return records
def get_by_field(self, **filters) -> list[BaseModel]:
    """Get records filtered by equality on the given fields.

    With no filters this delegates to ``get_all``. Filter values are bound as
    parameters; filter names become column names in the SQL text and are
    therefore validated first.

    Args:
        **filters: ``column=value`` pairs combined with ``AND``.

    Returns:
        List of model instances matching every filter.

    Raises:
        SQLInjectionError: If a filter name is not a plain identifier.
    """
    if not filters:
        return self.get_all()

    # Keyword names can be arbitrary strings when callers unpack a dict, so
    # they must be checked before being interpolated into the query.
    for column in filters:
        validate_identifier(column)

    conditions = " AND ".join(f"{column} = ?" for column in filters)
    values = tuple(filters.values())
    query = f"SELECT * FROM {self.table_name} WHERE {conditions}"
    rows = self._execute(query, values, commit=False)
    return [
        self.model(
            **{
                key: (value if value is not None else self._default_value(key))
                for key, value in zip(self.model.model_fields.keys(), row)
            }
        )
        for row in rows
    ]
def update(self, record_id: int, data: BaseModel) -> None:
    """Overwrite every dumped field of the row whose ``id`` is ``record_id``.

    Args:
        record_id: Value matched against the ``id`` column.
        data: Model instance providing the new column values.
    """
    payload = data.model_dump()
    assignments = ", ".join([f"{column} = ?" for column in payload])
    params = (*payload.values(), record_id)
    self._execute(
        f"UPDATE {self.table_name} SET {assignments} WHERE id = ?",
        params,
    )
def delete(self, record_id: int) -> None:
    """Remove the row whose ``id`` column equals ``record_id``.

    Args:
        record_id: Value matched against the ``id`` column.
    """
    params = (record_id,)
    self._execute(f"DELETE FROM {self.table_name} WHERE id = ?", params)
def _default_value(self, field: str) -> Any:
    """Return the stand-in used when a database value is NULL.

    The field's annotation is compared by identity: ``str`` gives ``""``,
    ``int`` gives ``0``, ``bool`` gives ``False``; any other annotation
    gives ``None``.

    Args:
        field: Name of a field in ``self.model.model_fields``.
    """
    annotation = self.model.model_fields[field].annotation
    for candidate, fallback in ((str, ""), (int, 0), (bool, False)):
        if annotation is candidate:
            return fallback
    return None
def get_paginated(
    self,
    limit: int = 10,
    offset: int = 0,
    order_by: Optional[str] = None,
    order_desc: bool = False,
) -> list[BaseModel]:
    """Get a window of records, optionally ordered by one column.

    The table name and ``order_by`` are validated as identifiers because
    they are placed directly in the SQL text; ``limit`` and ``offset`` are
    bound as parameters.

    Args:
        limit: Maximum number of records to return.
        offset: Number of records to skip.
        order_by: Column to order by.
        order_desc: If True, order descending.

    Returns:
        List of model instances.
    """
    validate_identifier(self.table_name)

    order_clause = ""
    if order_by:
        validate_identifier(order_by)
        direction = "DESC" if order_desc else "ASC"
        order_clause = f" ORDER BY {order_by} {direction}"

    rows = self._execute(
        f"SELECT * FROM {self.table_name}{order_clause} LIMIT ? OFFSET ?",
        (limit, offset),
        commit=False,
    )

    records = []
    for row in rows:
        kwargs = {}
        for name, value in zip(self.model.model_fields.keys(), row):
            kwargs[name] = self._default_value(name) if value is None else value
        records.append(self.model(**kwargs))
    return records
def get_page(self, page: int = 1, per_page: int = 10) -> list[BaseModel]:
    """Get records by page number.

    A ``page`` below 1 is treated as 1 and a ``per_page`` below 1 as 10; the
    request is then forwarded to ``get_paginated``.

    Args:
        page: Page number (1-indexed).
        per_page: Number of records per page.

    Returns:
        List of model instances for the requested page.
    """
    page = 1 if page < 1 else page
    per_page = 10 if per_page < 1 else per_page
    return self.get_paginated(limit=per_page, offset=(page - 1) * per_page)
def count(self) -> int:
    """Return the number of rows in the table, or 0 for an empty result."""
    validate_identifier(self.table_name)
    rows = self._execute(f"SELECT COUNT(*) FROM {self.table_name}", commit=False)
    if not rows:
        return 0
    return rows[0][0]
def insert_many(self, data_list: list[BaseModel]) -> None:
    """Insert multiple records in a single transaction.

    The column list is taken from the first record's ``model_dump()`` and
    reused for every record. On the pooled path a failing statement rolls the
    connection back, so no partial batch stays pending on it.

    Args:
        data_list: List of model instances to insert. An empty list is a
            no-op.
    """
    if not data_list:
        return

    data_dicts = [data.model_dump() for data in data_list]
    first = data_dicts[0]
    fields = ", ".join(first)
    placeholders = ", ".join("?" for _ in first)
    query = f"INSERT INTO {self.table_name} ({fields}) VALUES ({placeholders})"

    if self.use_pool and self._pool:
        with self._pool.connection() as conn:
            try:
                for data_dict in data_dicts:
                    conn.execute(query, tuple(data_dict.values()))
                conn.commit()
            except Exception:
                # Without this the rows inserted before the failure stay in
                # an open transaction on the connection.
                conn.rollback()
                raise
    else:
        with get_transaction(self.db_path) as txn:
            for data_dict in data_dicts:
                txn.execute(query, tuple(data_dict.values()))
            txn.commit()
def update_many(self, updates: list[tuple[BaseModel, int]]) -> int:
    """Update multiple records in one transaction.

    Args:
        updates: List of (model, record_id) tuples.

    Returns:
        Number of rows changed by this call, measured as the growth of the
        connection's ``total_changes`` counter over the batch.
        NOTE(review): that counter presumably also includes rows changed by
        triggers; confirm if the schema uses any.
    """
    if not updates:
        return 0

    validate_identifier(self.table_name)

    if self.use_pool and self._pool:
        with self._pool.connection() as conn:
            # total_changes is cumulative for the connection, so only the
            # difference across the batch belongs to this call.
            changes_before = conn.total_changes
            try:
                for data, record_id in updates:
                    data_dict = data.model_dump()
                    fields = ", ".join(f"{key} = ?" for key in data_dict)
                    values = tuple(data_dict.values()) + (record_id,)
                    query = f"UPDATE {self.table_name} SET {fields} WHERE id = ?"
                    conn.execute(query, values)
                conn.commit()
            except Exception:
                conn.rollback()
                raise
            total_updated = conn.total_changes - changes_before
    else:
        with get_transaction(self.db_path) as txn:
            changes_before = txn.conn.total_changes
            for data, record_id in updates:
                data_dict = data.model_dump()
                fields = ", ".join(f"{key} = ?" for key in data_dict)
                values = tuple(data_dict.values()) + (record_id,)
                query = f"UPDATE {self.table_name} SET {fields} WHERE id = ?"
                txn.execute(query, values)
            txn.commit()
            total_updated = txn.conn.total_changes - changes_before

    return total_updated
def delete_many(self, record_ids: list[int]) -> int:
    """Delete multiple records by their IDs in one transaction.

    Args:
        record_ids: List of record IDs to delete.

    Returns:
        Number of rows actually deleted, measured as the growth of the
        connection's ``total_changes`` counter. IDs that match no row do not
        count.
        NOTE(review): the counter presumably also includes rows removed by
        triggers or cascades; confirm if the schema uses any.
    """
    if not record_ids:
        return 0

    validate_identifier(self.table_name)
    query = f"DELETE FROM {self.table_name} WHERE id = ?"

    if self.use_pool and self._pool:
        with self._pool.connection() as conn:
            changes_before = conn.total_changes
            try:
                for record_id in record_ids:
                    conn.execute(query, (record_id,))
                conn.commit()
            except Exception:
                conn.rollback()
                raise
            deleted = conn.total_changes - changes_before
    else:
        with get_transaction(self.db_path) as txn:
            changes_before = txn.conn.total_changes
            for record_id in record_ids:
                txn.execute(query, (record_id,))
            txn.commit()
            deleted = txn.conn.total_changes - changes_before

    return deleted
def execute_transaction(self, operations: list[tuple[str, tuple]]) -> list[Any]:
    """Execute multiple operations in a transaction.

    On the pooled path only statements that produce a result set contribute
    an entry (their fetched rows). On the direct path every non-``None``
    return value of ``txn.execute`` is collected.

    Args:
        operations: List of (query, params) tuples.

    Returns:
        List of results from each operation.

    Raises:
        TransactionError: If any operation fails; the original exception is
            chained as the cause.
    """
    results = []
    try:
        if self.use_pool and self._pool:
            with self._pool.connection() as conn:
                try:
                    for query, values in operations:
                        cursor = conn.execute(query, values)
                        if cursor.description:
                            results.append(cursor.fetchall())
                    conn.commit()
                except Exception:
                    # Undo the statements that already ran so the connection
                    # does not keep a partial transaction open.
                    conn.rollback()
                    raise
        else:
            with get_transaction(self.db_path) as txn:
                for query, values in operations:
                    result = txn.execute(query, values)
                    if result is not None:
                        results.append(result)
                txn.commit()
        logger.info(f"Transaction completed with {len(operations)} operations")
    except Exception as e:
        logger.error(f"Transaction failed: {e}")
        raise TransactionError(f"Transaction failed: {e}") from e
    return results
def with_transaction(self, func: Callable[[Transaction], Any]) -> Any:
    """Execute a function within a transaction.

    On the pooled path a ``Transaction`` object is created and its ``conn``
    attribute is pointed at the pooled connection before ``func`` runs.

    Args:
        func: Function that receives Transaction and performs operations.

    Returns:
        Result of the function.

    Raises:
        TransactionError: If ``func`` or the commit fails; the original
            exception is chained as the cause.
    """
    try:
        if self.use_pool and self._pool:
            with self._pool.connection() as conn:
                txn = Transaction(self.db_path)
                txn.conn = conn
                try:
                    result = func(txn)
                    conn.commit()
                except Exception:
                    # Discard whatever func did before it failed so nothing
                    # stays pending on the connection.
                    conn.rollback()
                    raise
        else:
            with get_transaction(self.db_path) as txn:
                result = func(txn)
                txn.commit()
        logger.info("Transaction completed successfully")
        return result
    except Exception as e:
        logger.error(f"Transaction failed: {e}")
        raise TransactionError(f"Transaction failed: {e}") from e
@retry_on_lock(max_retries=3, delay=0.1)
def insert_with_retry(self, data: BaseModel) -> None:
    """Insert with automatic retry on database lock.

    Delegates to ``insert``; retrying is supplied entirely by the
    ``retry_on_lock`` decorator, configured for 3 retries and a delay of 0.1.

    Args:
        data: Model instance to insert.
    """
    self.insert(data)
async def insert_async(self, data: BaseModel) -> None:
    """Insert one model instance as a new row (async).

    Opens a connection through ``get_async_connection``, commits after the
    insert and always closes the connection afterwards.

    Args:
        data: Model instance to store.
    """
    record = data.model_dump()
    column_list = ", ".join(record)
    marks = ", ".join("?" for _ in record)
    params = tuple(record.values())
    query = f"INSERT INTO {self.table_name} ({column_list}) VALUES ({marks})"

    connection = await get_async_connection(self.db_path)
    try:
        await connection.execute(query, params)
        await connection.commit()
    finally:
        await connection.close()
async def get_all_async(self) -> list[BaseModel]:
    """Return every row of the table as a model instance (async).

    Row values are paired positionally with the model's field names; a NULL
    value is replaced by ``self._default_value(field)``.
    """
    connection = await get_async_connection(self.db_path)
    try:
        cursor = await connection.execute(f"SELECT * FROM {self.table_name}")
        rows = await cursor.fetchall()
    finally:
        await connection.close()

    records = []
    for row in rows:
        kwargs = {}
        for name, value in zip(self.model.model_fields.keys(), row):
            kwargs[name] = self._default_value(name) if value is None else value
        records.append(self.model(**kwargs))
    return records
async def get_by_field_async(self, **filters) -> list[BaseModel]:
    """Get records filtered by equality on the given fields (async).

    With no filters this delegates to ``get_all_async``. Filter values are
    bound as parameters; filter names become column names in the SQL text and
    are therefore validated first.

    Args:
        **filters: ``column=value`` pairs combined with ``AND``.

    Returns:
        List of model instances matching every filter.

    Raises:
        SQLInjectionError: If a filter name is not a plain identifier.
    """
    if not filters:
        return await self.get_all_async()

    # Keyword names can be arbitrary strings when callers unpack a dict, so
    # they must be checked before being interpolated into the query.
    for column in filters:
        validate_identifier(column)

    conditions = " AND ".join(f"{column} = ?" for column in filters)
    values = tuple(filters.values())
    query = f"SELECT * FROM {self.table_name} WHERE {conditions}"

    conn = await get_async_connection(self.db_path)
    try:
        cursor = await conn.execute(query, values)
        rows = await cursor.fetchall()
    finally:
        await conn.close()

    return [
        self.model(
            **{
                key: (value if value is not None else self._default_value(key))
                for key, value in zip(self.model.model_fields.keys(), row)
            }
        )
        for row in rows
    ]
async def update_async(self, record_id: int, data: BaseModel) -> None:
    """Overwrite every dumped field of the row with the given ``id`` (async).

    Args:
        record_id: Value matched against the ``id`` column.
        data: Model instance providing the new column values.
    """
    payload = data.model_dump()
    assignments = ", ".join([f"{column} = ?" for column in payload])
    params = (*payload.values(), record_id)
    query = f"UPDATE {self.table_name} SET {assignments} WHERE id = ?"

    connection = await get_async_connection(self.db_path)
    try:
        await connection.execute(query, params)
        await connection.commit()
    finally:
        await connection.close()
async def delete_async(self, record_id: int) -> None:
    """Remove the row whose ``id`` column equals ``record_id`` (async).

    Args:
        record_id: Value matched against the ``id`` column.
    """
    params = (record_id,)
    connection = await get_async_connection(self.db_path)
    try:
        await connection.execute(
            f"DELETE FROM {self.table_name} WHERE id = ?", params
        )
        await connection.commit()
    finally:
        await connection.close()
async def get_paginated_async(
    self,
    limit: int = 10,
    offset: int = 0,
    order_by: Optional[str] = None,
    order_desc: bool = False,
) -> list[BaseModel]:
    """Get a window of records, optionally ordered by one column (async).

    The table name and ``order_by`` are validated as identifiers because
    they are placed directly in the SQL text; ``limit`` and ``offset`` are
    bound as parameters.

    Args:
        limit: Maximum number of records to return.
        offset: Number of records to skip.
        order_by: Column to order by.
        order_desc: If True, order descending.

    Returns:
        List of model instances.
    """
    validate_identifier(self.table_name)

    order_clause = ""
    if order_by:
        validate_identifier(order_by)
        direction = "DESC" if order_desc else "ASC"
        order_clause = f" ORDER BY {order_by} {direction}"

    query = f"SELECT * FROM {self.table_name}{order_clause} LIMIT ? OFFSET ?"

    connection = await get_async_connection(self.db_path)
    try:
        cursor = await connection.execute(query, (limit, offset))
        rows = await cursor.fetchall()
    finally:
        await connection.close()

    records = []
    for row in rows:
        kwargs = {}
        for name, value in zip(self.model.model_fields.keys(), row):
            kwargs[name] = self._default_value(name) if value is None else value
        records.append(self.model(**kwargs))
    return records
async def get_page_async(self, page: int = 1, per_page: int = 10) -> list[BaseModel]:
    """Get records by page number (async).

    A ``page`` below 1 is treated as 1 and a ``per_page`` below 1 as 10; the
    request is then forwarded to ``get_paginated_async``.

    Args:
        page: Page number (1-indexed).
        per_page: Number of records per page.

    Returns:
        List of model instances for the requested page.
    """
    page = 1 if page < 1 else page
    per_page = 10 if per_page < 1 else per_page
    return await self.get_paginated_async(
        limit=per_page, offset=(page - 1) * per_page
    )
async def count_async(self) -> int:
    """Return the number of rows in the table, or 0 if no row came back (async)."""
    validate_identifier(self.table_name)

    connection = await get_async_connection(self.db_path)
    try:
        cursor = await connection.execute(
            f"SELECT COUNT(*) FROM {self.table_name}"
        )
        row = await cursor.fetchone()
    finally:
        await connection.close()

    if row:
        return row[0]
    return 0
async def insert_many_async(self, data_list: list[BaseModel]) -> None:
    """Insert multiple records with a single commit (async).

    The column list is taken from the first record's ``model_dump()`` and
    reused for every record. The connection is closed whether or not the
    batch succeeds.

    Args:
        data_list: List of model instances to insert. An empty list is a
            no-op.
    """
    if not data_list:
        return

    records = [item.model_dump() for item in data_list]
    template = records[0]
    column_list = ", ".join(template)
    marks = ", ".join("?" for _ in template)
    query = f"INSERT INTO {self.table_name} ({column_list}) VALUES ({marks})"

    connection = await get_async_connection(self.db_path)
    try:
        for record in records:
            await connection.execute(query, tuple(record.values()))
        await connection.commit()
    finally:
        await connection.close()
async def update_many_async(self, updates: list[tuple[BaseModel, int]]) -> int:
    """Update multiple records with a single commit (async).

    Args:
        updates: List of (model, record_id) tuples.

    Returns:
        Number of rows changed by this call, measured as the growth of the
        connection's ``total_changes`` counter over the batch.
        NOTE(review): that counter presumably also includes rows changed by
        triggers; confirm if the schema uses any.
    """
    if not updates:
        return 0

    validate_identifier(self.table_name)

    conn = await get_async_connection(self.db_path)
    try:
        # total_changes is cumulative for the connection; summing it after
        # every statement would count earlier updates again and again.
        changes_before = conn.total_changes
        for data, record_id in updates:
            data_dict = data.model_dump()
            fields = ", ".join(f"{key} = ?" for key in data_dict)
            values = tuple(data_dict.values()) + (record_id,)
            query = f"UPDATE {self.table_name} SET {fields} WHERE id = ?"
            await conn.execute(query, values)
        await conn.commit()
        total_updated = conn.total_changes - changes_before
    finally:
        await conn.close()

    return total_updated
async def delete_many_async(self, record_ids: list[int]) -> int:
    """Delete multiple records by their IDs with a single commit (async).

    Args:
        record_ids: List of record IDs to delete.

    Returns:
        Number of rows actually deleted, measured as the growth of the
        connection's ``total_changes`` counter. IDs that match no row do not
        count.
        NOTE(review): the counter presumably also includes rows removed by
        triggers or cascades; confirm if the schema uses any.
    """
    if not record_ids:
        return 0

    validate_identifier(self.table_name)
    query = f"DELETE FROM {self.table_name} WHERE id = ?"

    conn = await get_async_connection(self.db_path)
    try:
        changes_before = conn.total_changes
        for record_id in record_ids:
            await conn.execute(query, (record_id,))
        await conn.commit()
        deleted = conn.total_changes - changes_before
    finally:
        await conn.close()

    return deleted
async def execute_transaction_async(self, operations: list[tuple[str, tuple]]) -> list[Any]:
    """Execute multiple operations with a single commit (async).

    Only statements that produce a result set (``cursor.description`` is
    set) contribute an entry to the returned list.

    Args:
        operations: List of (query, params) tuples.

    Returns:
        List with the fetched rows of each row-returning operation.

    Raises:
        TransactionError: If any operation or the commit fails; the original
            exception is chained as the cause.
    """
    results = []
    conn = await get_async_connection(self.db_path)
    try:
        for query, values in operations:
            cursor = await conn.execute(query, values)
            if cursor.description:
                results.append(await cursor.fetchall())
        await conn.commit()
        logger.info(f"Async transaction completed with {len(operations)} operations")
    except Exception as e:
        logger.error(f"Async transaction failed: {e}")
        raise TransactionError(f"Async transaction failed: {e}") from e
    finally:
        # Close exactly once on every path. The previous private-attribute
        # check was inverted, leaving the connection open after a success.
        await conn.close()
    return results
async def with_transaction_async(self, func: Callable[[AsyncTransaction], Any]) -> Any:
    """Run an async callback inside a transaction and commit afterwards.

    An ``AsyncTransaction`` is created, its ``conn`` attribute is pointed at
    a connection from ``get_async_connection``, and the callback is awaited
    with it inside ``async with`` on that connection.

    Args:
        func: Async function that receives the AsyncTransaction.

    Returns:
        Whatever the callback returned.

    Raises:
        TransactionError: If opening the connection, the callback or the
            commit fails; the original exception is chained as the cause.
    """
    try:
        connection = await get_async_connection(self.db_path)
        async with connection:
            transaction = AsyncTransaction(self.db_path)
            transaction.conn = connection
            outcome = await func(transaction)
            await transaction.commit()
        logger.info("Async transaction completed successfully")
        return outcome
    except Exception as exc:
        logger.error(f"Async transaction failed: {exc}")
        raise TransactionError(f"Async transaction failed: {exc}") from exc