# Source code for wsqlite.core.repository

"""Main repository class for SQLite operations with connection pooling."""

import logging
import re
from typing import Any, Callable, Optional

from pydantic import BaseModel

from wsqlite.core.connection import (
    AsyncTransaction,
    Transaction,
    get_async_connection,
    get_connection,
    get_transaction,
    retry_on_lock,
)
from wsqlite.core.pool import ConnectionPool, get_pool, close_pool
from wsqlite.core.serialization import serialize_value, deserialize_value
from wsqlite.core.sync import AsyncTableSync, TableSync
from wsqlite.exceptions import DatabaseLockedError, SQLInjectionError, TransactionError

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


def validate_identifier(identifier: str) -> None:
    """Validate SQL identifier to prevent SQL injection.

    An identifier is accepted only when it consists entirely of ASCII
    letters, digits and underscores and does not start with a digit.

    Args:
        identifier: Table or column name to validate.

    Raises:
        SQLInjectionError: If identifier contains dangerous characters.
    """
    # fullmatch() instead of match() with a "$" anchor: "$" also matches just
    # before a trailing newline, so a name such as "users\n" would slip through.
    if not re.fullmatch(r"[a-zA-Z_][a-zA-Z0-9_]*", identifier):
        raise SQLInjectionError(identifier)


class WSQLite:
    """SQLite repository using Pydantic models.

    Provides a simple interface for CRUD operations on SQLite tables,
    with automatic table creation, schema synchronization, and connection
    pooling.

    Example:
        from pydantic import BaseModel
        from wsqlite import WSQLite

        class User(BaseModel):
            id: int
            name: str
            email: str

        db = WSQLite(User, "database.db")
        db.insert(User(id=1, name="John", email="john@example.com"))
    """
def __init__(
    self,
    model: type[BaseModel],
    db_path: str,
    pool_size: int = 10,
    min_pool_size: int = 2,
    use_pool: bool = True,
    table_name: Optional[str] = None,
    soft_delete: bool = False,
    deleted_at_field: str = "deleted_at",
    pool: Optional[ConnectionPool] = None,
    sync_handler: Optional[TableSync] = None,
):
    """Bind the repository to a model and a database file.

    Args:
        model: Pydantic model class describing the table's columns.
        db_path: Location of the SQLite database file.
        pool_size: Passed to ``get_pool`` as ``max_size`` when no ``pool``
            is supplied.
        min_pool_size: Passed to ``get_pool`` as ``min_size`` when no
            ``pool`` is supplied.
        use_pool: When false, no pool is kept and ``_pool`` is ``None``.
        table_name: Explicit table name; falls back to the lower-cased
            model class name.
        soft_delete: Enables soft-delete handling on this repository.
        deleted_at_field: Column used to mark soft-deleted rows.
        pool: Ready-made connection pool used instead of ``get_pool``.
        sync_handler: Ready-made ``TableSync`` used instead of building one.
    """
    self.model = model
    self.db_path = db_path
    self.table_name = table_name or model.__name__.lower()
    self.use_pool = use_pool
    self.soft_delete = soft_delete
    self.deleted_at_field = deleted_at_field

    # A caller-supplied pool wins; otherwise one is obtained from get_pool().
    self._pool = (
        (pool or get_pool(db_path, min_size=min_pool_size, max_size=pool_size))
        if use_pool
        else None
    )

    if sync_handler:
        self._sync = sync_handler
    else:
        self._sync = TableSync(model, db_path, table_name=self.table_name)
    # Run the sync handler's two setup steps before the repository is used.
    self._sync.create_if_not_exists()
    self._sync.sync_with_model()

    logger.info(
        f"WSQLite initialized for table '{self.table_name}' (pool={use_pool}, size={pool_size}, soft_delete={soft_delete})"
    )
def _call_hook(self, instance: Any, hook_name: str, *args, **kwargs) -> None:
    """Invoke ``hook_name`` on ``instance`` when it exists and is callable."""
    candidate = getattr(instance, hook_name, None)
    if not candidate or not callable(candidate):
        return
    candidate(*args, **kwargs)


def _soft_delete_condition(self, prefix: str = "") -> str:
    """Build the ``<column> IS NULL`` test used to hide soft-deleted rows.

    Returns an empty string when soft delete is disabled. A non-empty
    ``prefix`` is prepended to the column name with a dot.
    """
    if not self.soft_delete:
        return ""
    column = self.deleted_at_field
    if prefix:
        column = f"{prefix}.{column}"
    return f"{column} IS NULL"


def _add_soft_delete_filter(self, conditions: str) -> str:
    """Combine ``conditions`` with the soft-delete test, if one applies."""
    soft_clause = self._soft_delete_condition()
    if not soft_clause:
        return conditions
    if conditions:
        # Parenthesise the caller's conditions so the AND binds to all of them.
        return f"({conditions}) AND {soft_clause}"
    return soft_clause


def _dump(self, data: BaseModel) -> dict:
    """Turn a model instance into a dict of values ready for insertion.

    Keys that are fields of ``self.model`` are passed through
    ``serialize_value`` together with the field's annotation.
    """
    payload = data.model_dump(mode="json")
    known_fields = self.model.model_fields
    for name, raw in payload.items():
        if name in known_fields:
            payload[name] = serialize_value(raw, known_fields[name].annotation)
    return payload


def _load(self, row: tuple) -> BaseModel:
    """Build a model instance from a row.

    Row values are paired positionally with the model's fields. ``None``
    values are replaced by ``_default_value``; everything else goes through
    ``deserialize_value`` with the field's annotation.
    """
    known_fields = self.model.model_fields
    values = {}
    for name, cell in zip(known_fields.keys(), row):
        if cell is None:
            values[name] = self._default_value(name)
        else:
            values[name] = deserialize_value(cell, known_fields[name].annotation)
    return self.model(**values)


def _execute(self, query: str, values: tuple = (), commit: bool = True) -> Any:
    """Run one statement on a pooled or a direct connection.

    Returns the fetched rows when the cursor has a description, otherwise
    the cursor's ``rowcount``. Commits afterwards unless ``commit`` is false.
    """
    if self.use_pool and self._pool:
        scope = self._pool.connection()
    else:
        scope = get_connection(self.db_path)
    with scope as conn:
        cursor = conn.execute(query, values)
        outcome = cursor.fetchall() if cursor.description else cursor.rowcount
        if commit:
            conn.commit()
        return outcome
def insert(self, data: BaseModel) -> None:
    """Store one model instance as a new row.

    The instance's ``pre_save`` hook (if any) runs before the statement and
    its ``post_save`` hook (if any) runs after it.

    Args:
        data: Model instance to insert.
    """
    self._call_hook(data, "pre_save")

    record = self._dump(data)
    column_list = ", ".join(record)
    marks = ", ".join("?" for _ in record)
    params = tuple(record.values())
    sql = f"INSERT INTO {self.table_name} ({column_list}) VALUES ({marks})"
    self._execute(sql, params)

    self._call_hook(data, "post_save")
def get_all(self) -> list[BaseModel]:
    """Return every row of the table as model instances.

    When soft delete is enabled, rows matching the soft-delete condition
    are the only ones selected.
    """
    soft_filter = self._soft_delete_condition()
    sql = f"SELECT * FROM {self.table_name}"
    if soft_filter:
        sql += f" WHERE {soft_filter}"
    return [self._load(record) for record in self._execute(sql, commit=False)]
def get_by_field(self, **filters) -> list[BaseModel]:
    """Get records filtered by specified fields.

    Each keyword becomes a ``<column> = ?`` test; all tests are joined with
    ``AND`` and combined with the soft-delete filter when that is enabled.

    Args:
        **filters: Column name / value pairs to match.

    Returns:
        List of model instances for the matching rows.

    Raises:
        SQLInjectionError: If a filter name is not a valid identifier.
    """
    # Column names end up in the SQL text itself (only the values are bound
    # as parameters), so every name is validated before it is interpolated.
    for column in filters:
        validate_identifier(column)

    conditions = " AND ".join(f"{column} = ?" for column in filters)
    conditions = self._add_soft_delete_filter(conditions)
    where_clause = f" WHERE {conditions}" if conditions else ""
    query = f"SELECT * FROM {self.table_name}{where_clause}"
    rows = self._execute(query, tuple(filters.values()), commit=False)
    return [self._load(row) for row in rows]
def update(self, record_id: int, data: BaseModel) -> None:
    """Overwrite the row whose ``id`` equals ``record_id`` with ``data``.

    The instance's ``pre_save`` hook (if any) runs before the statement and
    its ``post_save`` hook (if any) runs after it.

    Args:
        record_id: Value matched against the ``id`` column.
        data: Model instance supplying the new column values.
    """
    self._call_hook(data, "pre_save")

    record = self._dump(data)
    assignments = ", ".join([f"{column} = ?" for column in record])
    # The id for the WHERE clause is bound last, after the SET values.
    params = (*record.values(), record_id)
    sql = f"UPDATE {self.table_name} SET {assignments} WHERE id = ?"
    self._execute(sql, params)

    self._call_hook(data, "post_save")
def delete(self, record_id: int) -> None:
    """Remove the row whose ``id`` equals ``record_id``.

    With soft delete enabled the row is kept and its deleted-at column is
    set to the current local time in ISO format; otherwise the row is
    deleted outright.

    Args:
        record_id: Value matched against the ``id`` column.
    """
    if not self.soft_delete:
        self._execute(f"DELETE FROM {self.table_name} WHERE id = ?", (record_id,))
        return

    from datetime import datetime

    stamp = datetime.now().isoformat()
    sql = f"UPDATE {self.table_name} SET {self.deleted_at_field} = ? WHERE id = ?"
    self._execute(sql, (stamp, record_id))
def restore(self, record_id: int) -> None:
    """Clear the deleted-at column of the row whose ``id`` is ``record_id``.

    Does nothing when soft delete is disabled.

    Args:
        record_id: Value matched against the ``id`` column.
    """
    if self.soft_delete:
        sql = f"UPDATE {self.table_name} SET {self.deleted_at_field} = NULL WHERE id = ?"
        self._execute(sql, (record_id,))
def _default_value(self, field: str) -> Any:
    """Pick the stand-in used when a column holds NULL.

    Fields annotated exactly as ``str``, ``int`` or ``bool`` map to ``""``,
    ``0`` and ``False`` respectively; any other annotation maps to ``None``.

    Args:
        field: Name of a field of ``self.model``.
    """
    annotation = self.model.model_fields[field].annotation
    # Identity comparison on purpose: only the bare types get a stand-in.
    for plain_type, fallback in ((str, ""), (int, 0), (bool, False)):
        if annotation is plain_type:
            return fallback
    return None
def get_paginated(
    self,
    limit: int = 10,
    offset: int = 0,
    order_by: Optional[str] = None,
    order_desc: bool = False,
) -> list[BaseModel]:
    """Fetch a window of rows using ``LIMIT``/``OFFSET``.

    Args:
        limit: Value bound to ``LIMIT``.
        offset: Value bound to ``OFFSET``.
        order_by: Optional column to sort by; validated as an identifier.
        order_desc: Sort descending instead of ascending.

    Returns:
        Model instances for the selected rows.

    Raises:
        SQLInjectionError: If the table name or ``order_by`` is not a valid
            identifier.
    """
    validate_identifier(self.table_name)
    soft_filter = self._soft_delete_condition()

    pieces = [f"SELECT * FROM {self.table_name}"]
    if soft_filter:
        pieces.append(f" WHERE {soft_filter}")
    if order_by:
        validate_identifier(order_by)
        direction = "DESC" if order_desc else "ASC"
        pieces.append(f" ORDER BY {order_by} {direction}")
    pieces.append(" LIMIT ? OFFSET ?")

    found = self._execute("".join(pieces), (limit, offset), commit=False)
    return [self._load(record) for record in found]
def get_page(self, page: int = 1, per_page: int = 10) -> list[BaseModel]:
    """Fetch one page of rows.

    A ``page`` below 1 is treated as 1 and a ``per_page`` below 1 is treated
    as 10.

    Args:
        page: Page number (1-indexed).
        per_page: Number of records per page.

    Returns:
        List of model instances for the requested page.
    """
    page_number = 1 if page < 1 else page
    page_size = 10 if per_page < 1 else per_page
    return self.get_paginated(limit=page_size, offset=(page_number - 1) * page_size)
def count(self) -> int:
    """Count the rows of the table.

    When soft delete is enabled only rows matching the soft-delete
    condition are counted.

    Raises:
        SQLInjectionError: If the table name is not a valid identifier.
    """
    validate_identifier(self.table_name)
    soft_filter = self._soft_delete_condition()
    sql = f"SELECT COUNT(*) FROM {self.table_name}"
    if soft_filter:
        sql = f"{sql} WHERE {soft_filter}"
    found = self._execute(sql, commit=False)
    if not found:
        return 0
    return found[0][0]
def insert_many(self, data_list: list[BaseModel]) -> None:
    """Insert multiple records in a single transaction.

    ``pre_save`` hooks run for every instance before any row is written and
    ``post_save`` hooks run for every instance after the commit. The column
    list is taken from the first instance.

    Args:
        data_list: List of model instances to insert. An empty list is a
            no-op.
    """
    if not data_list:
        return

    for data in data_list:
        self._call_hook(data, "pre_save")

    data_dicts = [self._dump(data) for data in data_list]
    fields = ", ".join(data_dicts[0].keys())
    placeholders = ", ".join(["?"] * len(data_dicts[0]))
    query = f"INSERT INTO {self.table_name} ({fields}) VALUES ({placeholders})"
    rows = [tuple(data_dict.values()) for data_dict in data_dicts]

    if self.use_pool and self._pool:
        with self._pool.connection() as conn:
            try:
                for values in rows:
                    conn.execute(query, values)
                conn.commit()
            except Exception:
                # Undo the rows written so far so a failed batch is not left
                # pending on the connection and committed by a later call.
                # NOTE(review): assumes the pooled connection exposes
                # rollback() like sqlite3.Connection — confirm.
                conn.rollback()
                raise
    else:
        with get_transaction(self.db_path) as txn:
            for values in rows:
                txn.execute(query, values)
            txn.commit()

    for data in data_list:
        self._call_hook(data, "post_save")
def update_many(self, updates: list[tuple[BaseModel, int]]) -> int:
    """Update multiple records.

    ``pre_save`` hooks run for every instance before any row is written and
    ``post_save`` hooks run for every instance after the commit.

    Args:
        updates: List of (model, record_id) tuples.

    Returns:
        Number of records updated.

    Raises:
        SQLInjectionError: If the table name is not a valid identifier.
    """
    if not updates:
        return 0

    for data, _ in updates:
        self._call_hook(data, "pre_save")

    validate_identifier(self.table_name)

    # Build every statement up front so the database section only executes.
    statements = []
    for data, record_id in updates:
        data_dict = self._dump(data)
        fields = ", ".join(f"{key} = ?" for key in data_dict)
        values = tuple(data_dict.values()) + (record_id,)
        statements.append((f"UPDATE {self.table_name} SET {fields} WHERE id = ?", values))

    # total_changes is a running total for the whole connection, so the
    # number of rows touched here is the difference across the batch, not a
    # sum of the counter's value after each statement.
    # NOTE(review): per the SQLite docs the counter also includes rows
    # changed by triggers / foreign-key actions, if the schema has any.
    if self.use_pool and self._pool:
        with self._pool.connection() as conn:
            baseline = conn.total_changes
            try:
                for query, values in statements:
                    conn.execute(query, values)
                total_updated = conn.total_changes - baseline
                conn.commit()
            except Exception:
                # Do not leave a half-applied batch pending on the connection.
                # NOTE(review): assumes the pooled connection exposes
                # rollback() like sqlite3.Connection — confirm.
                conn.rollback()
                raise
    else:
        with get_transaction(self.db_path) as txn:
            baseline = txn.conn.total_changes
            for query, values in statements:
                txn.execute(query, values)
            total_updated = txn.conn.total_changes - baseline
            txn.commit()

    for data, _ in updates:
        self._call_hook(data, "post_save")

    return total_updated
def delete_many(self, record_ids: list[int]) -> int:
    """Delete multiple records by their IDs (hard or soft).

    With soft delete enabled each row's deleted-at column is set to the
    current local time in ISO format; otherwise the rows are deleted.

    Args:
        record_ids: List of record IDs to delete.

    Returns:
        Number of records deleted, i.e. rows actually affected; IDs that
        match no row do not count.

    Raises:
        SQLInjectionError: If the table name is not a valid identifier.
    """
    if not record_ids:
        return 0

    validate_identifier(self.table_name)

    if self.soft_delete:
        from datetime import datetime

        now = datetime.now().isoformat()
        query = f"UPDATE {self.table_name} SET {self.deleted_at_field} = ? WHERE id = ?"
        params = [(now, rid) for rid in record_ids]
    else:
        query = f"DELETE FROM {self.table_name} WHERE id = ?"
        params = [(rid,) for rid in record_ids]

    # Count via the change in the connection's total_changes counter rather
    # than len(record_ids), which would also count IDs that matched nothing.
    # NOTE(review): per the SQLite docs the counter also includes rows
    # changed by triggers / foreign-key actions, if the schema has any.
    if self.use_pool and self._pool:
        with self._pool.connection() as conn:
            baseline = conn.total_changes
            try:
                for p in params:
                    conn.execute(query, p)
                affected = conn.total_changes - baseline
                conn.commit()
            except Exception:
                # Do not leave a half-applied batch pending on the connection.
                # NOTE(review): assumes the pooled connection exposes
                # rollback() like sqlite3.Connection — confirm.
                conn.rollback()
                raise
    else:
        with get_transaction(self.db_path) as txn:
            baseline = txn.conn.total_changes
            for p in params:
                txn.execute(query, p)
            affected = txn.conn.total_changes - baseline
            txn.commit()

    return affected
def execute_transaction(self, operations: list[tuple[str, tuple]]) -> list[Any]:
    """Execute multiple operations in a transaction.

    Args:
        operations: List of (query, params) tuples.

    Returns:
        List of results from each operation that produced one. On the pooled
        path that means statements whose cursor has a description; on the
        other path, statements for which ``txn.execute`` returned something
        other than ``None``.

    Raises:
        TransactionError: If any operation or the commit fails; the original
            exception is chained as the cause.
    """
    results = []
    try:
        if self.use_pool and self._pool:
            with self._pool.connection() as conn:
                try:
                    for query, values in operations:
                        cursor = conn.execute(query, values)
                        if cursor.description:
                            results.append(cursor.fetchall())
                    conn.commit()
                except Exception:
                    # Undo the statements already run so they are not left
                    # pending on the connection after a failure.
                    # NOTE(review): assumes the pooled connection exposes
                    # rollback() like sqlite3.Connection — confirm.
                    conn.rollback()
                    raise
        else:
            with get_transaction(self.db_path) as txn:
                for query, values in operations:
                    result = txn.execute(query, values)
                    if result is not None:
                        results.append(result)
                txn.commit()
        logger.info(f"Transaction completed with {len(operations)} operations")
    except Exception as e:
        logger.error(f"Transaction failed: {e}")
        raise TransactionError(f"Transaction failed: {e}") from e
    return results
def with_transaction(self, func: Callable[[Transaction], Any]) -> Any:
    """Execute a function within a transaction.

    Args:
        func: Function that receives Transaction and performs operations.

    Returns:
        Result of the function.

    Raises:
        TransactionError: If ``func`` or the commit raises; the original
            exception is chained as the cause.
    """
    try:
        if self.use_pool and self._pool:
            with self._pool.connection() as conn:
                # Hand func a Transaction bound to the pooled connection so
                # it sees the same interface as on the non-pooled path.
                txn = Transaction(self.db_path)
                txn.conn = conn
                try:
                    result = func(txn)
                    conn.commit()
                except Exception:
                    # Undo whatever func wrote so it is not left pending on
                    # the connection after a failure.
                    # NOTE(review): assumes the pooled connection exposes
                    # rollback() like sqlite3.Connection — confirm.
                    conn.rollback()
                    raise
        else:
            with get_transaction(self.db_path) as txn:
                result = func(txn)
                txn.commit()
        logger.info("Transaction completed successfully")
        return result
    except Exception as e:
        logger.error(f"Transaction failed: {e}")
        raise TransactionError(f"Transaction failed: {e}") from e
@retry_on_lock(max_retries=3, delay=0.1)
def insert_with_retry(self, data: BaseModel) -> None:
    """Insert with automatic retry on database lock.

    Delegates to ``insert``; the retrying is supplied by the
    ``retry_on_lock`` decorator, configured here with ``max_retries=3`` and
    ``delay=0.1``.

    Args:
        data: Model instance to insert.
    """
    self.insert(data)
async def insert_async(self, data: BaseModel) -> None:
    """Store one model instance as a new row (async).

    The instance's ``pre_save`` hook (if any) runs before the statement and
    its ``post_save`` hook (if any) runs after the connection is closed.

    Args:
        data: Model instance to insert.
    """
    self._call_hook(data, "pre_save")

    record = self._dump(data)
    column_list = ", ".join(record)
    marks = ", ".join("?" for _ in record)
    params = tuple(record.values())
    sql = f"INSERT INTO {self.table_name} ({column_list}) VALUES ({marks})"

    conn = await get_async_connection(self.db_path)
    try:
        await conn.execute(sql, params)
        await conn.commit()
    finally:
        # Always release the connection, even when the statement fails.
        await conn.close()

    self._call_hook(data, "post_save")
async def get_all_async(self) -> list[BaseModel]:
    """Return every row of the table as model instances (async).

    When soft delete is enabled, rows matching the soft-delete condition
    are the only ones selected.
    """
    soft_filter = self._soft_delete_condition()
    sql = f"SELECT * FROM {self.table_name}"
    if soft_filter:
        sql += f" WHERE {soft_filter}"

    conn = await get_async_connection(self.db_path)
    try:
        cursor = await conn.execute(sql)
        found = await cursor.fetchall()
    finally:
        # Always release the connection, even when the query fails.
        await conn.close()
    return [self._load(record) for record in found]
async def get_by_field_async(self, **filters) -> list[BaseModel]:
    """Get records filtered by specified fields (async).

    Each keyword becomes a ``<column> = ?`` test; all tests are joined with
    ``AND`` and combined with the soft-delete filter when that is enabled.

    Args:
        **filters: Column name / value pairs to match.

    Returns:
        List of model instances for the matching rows.

    Raises:
        SQLInjectionError: If a filter name is not a valid identifier.
    """
    # Column names end up in the SQL text itself (only the values are bound
    # as parameters), so every name is validated before it is interpolated.
    for column in filters:
        validate_identifier(column)

    conditions = " AND ".join(f"{column} = ?" for column in filters)
    conditions = self._add_soft_delete_filter(conditions)
    where_clause = f" WHERE {conditions}" if conditions else ""
    values = tuple(filters.values())
    query = f"SELECT * FROM {self.table_name}{where_clause}"

    conn = await get_async_connection(self.db_path)
    try:
        cursor = await conn.execute(query, values)
        rows = await cursor.fetchall()
    finally:
        # Always release the connection, even when the query fails.
        await conn.close()
    return [self._load(row) for row in rows]
async def update_async(self, record_id: int, data: BaseModel) -> None:
    """Overwrite the row whose ``id`` equals ``record_id`` (async).

    The instance's ``pre_save`` hook (if any) runs before the statement and
    its ``post_save`` hook (if any) runs after the connection is closed.

    Args:
        record_id: Value matched against the ``id`` column.
        data: Model instance supplying the new column values.
    """
    self._call_hook(data, "pre_save")

    record = self._dump(data)
    assignments = ", ".join([f"{column} = ?" for column in record])
    # The id for the WHERE clause is bound last, after the SET values.
    params = (*record.values(), record_id)
    sql = f"UPDATE {self.table_name} SET {assignments} WHERE id = ?"

    conn = await get_async_connection(self.db_path)
    try:
        await conn.execute(sql, params)
        await conn.commit()
    finally:
        await conn.close()

    self._call_hook(data, "post_save")
async def delete_async(self, record_id: int) -> None:
    """Remove the row whose ``id`` equals ``record_id`` (async).

    With soft delete enabled the row is kept and its deleted-at column is
    set to the current local time in ISO format; otherwise the row is
    deleted outright.

    Args:
        record_id: Value matched against the ``id`` column.
    """
    if not self.soft_delete:
        sql = f"DELETE FROM {self.table_name} WHERE id = ?"
        params = (record_id,)
    else:
        from datetime import datetime

        stamp = datetime.now().isoformat()
        sql = f"UPDATE {self.table_name} SET {self.deleted_at_field} = ? WHERE id = ?"
        params = (stamp, record_id)

    conn = await get_async_connection(self.db_path)
    try:
        await conn.execute(sql, params)
        await conn.commit()
    finally:
        # Always release the connection, even when the statement fails.
        await conn.close()
async def restore_async(self, record_id: int) -> None:
    """Clear the deleted-at column of one row (async).

    Does nothing when soft delete is disabled.

    Args:
        record_id: Value matched against the ``id`` column.
    """
    if self.soft_delete:
        sql = f"UPDATE {self.table_name} SET {self.deleted_at_field} = NULL WHERE id = ?"
        conn = await get_async_connection(self.db_path)
        try:
            await conn.execute(sql, (record_id,))
            await conn.commit()
        finally:
            # Always release the connection, even when the statement fails.
            await conn.close()
async def search_async(self, query: str, order_by_rank: bool = True) -> list[BaseModel]:
    """Perform a full-text search on an FTS5 table (async).

    The model must carry a ``wsqlite_config`` attribute whose ``use_fts5``
    is truthy. The search term is bound as a parameter of a ``MATCH`` test
    against the table.

    Args:
        query: The search query string.
        order_by_rank: Whether to sort results by relevance (default True).

    Returns:
        A list of matching model instances.

    Raises:
        OperationError: If the model is not configured for FTS5.
    """
    config = getattr(self.model, "wsqlite_config", None)
    if not getattr(config, "use_fts5", False):
        # OperationError is not among this module's top-level imports, so it
        # is imported here to keep this guard from failing with NameError.
        # NOTE(review): assumed to live in wsqlite.exceptions next to the
        # other error types — confirm.
        from wsqlite.exceptions import OperationError

        raise OperationError("Search method is only available for FTS5-enabled models.")

    sql = f"SELECT * FROM {self.table_name} WHERE {self.table_name} MATCH ?"
    if order_by_rank:
        sql += " ORDER BY rank"

    conn = await get_async_connection(self.db_path)
    try:
        cursor = await conn.execute(sql, (query,))
        rows = await cursor.fetchall()
    finally:
        # Always release the connection, even when the query fails.
        await conn.close()
    return [self._load(row) for row in rows]
async def get_paginated_async(
    self,
    limit: int = 10,
    offset: int = 0,
    order_by: Optional[str] = None,
    order_desc: bool = False,
) -> list[BaseModel]:
    """Fetch a window of rows using ``LIMIT``/``OFFSET`` (async).

    Args:
        limit: Value bound to ``LIMIT``.
        offset: Value bound to ``OFFSET``.
        order_by: Optional column to sort by; validated as an identifier.
        order_desc: Sort descending instead of ascending.

    Returns:
        Model instances for the selected rows.

    Raises:
        SQLInjectionError: If the table name or ``order_by`` is not a valid
            identifier.
    """
    validate_identifier(self.table_name)
    soft_filter = self._soft_delete_condition()

    pieces = [f"SELECT * FROM {self.table_name}"]
    if soft_filter:
        pieces.append(f" WHERE {soft_filter}")
    if order_by:
        validate_identifier(order_by)
        direction = "DESC" if order_desc else "ASC"
        pieces.append(f" ORDER BY {order_by} {direction}")
    pieces.append(" LIMIT ? OFFSET ?")
    sql = "".join(pieces)

    conn = await get_async_connection(self.db_path)
    try:
        cursor = await conn.execute(sql, (limit, offset))
        found = await cursor.fetchall()
    finally:
        await conn.close()
    return [self._load(record) for record in found]
async def get_page_async(self, page: int = 1, per_page: int = 10) -> list[BaseModel]:
    """Fetch one page of rows (async).

    A ``page`` below 1 is treated as 1 and a ``per_page`` below 1 is treated
    as 10.

    Args:
        page: Page number (1-indexed).
        per_page: Number of records per page.

    Returns:
        List of model instances for the requested page.
    """
    page_number = 1 if page < 1 else page
    page_size = 10 if per_page < 1 else per_page
    start = (page_number - 1) * page_size
    return await self.get_paginated_async(limit=page_size, offset=start)
async def count_async(self) -> int:
    """Count the rows of the table (async).

    When soft delete is enabled only rows matching the soft-delete
    condition are counted.

    Raises:
        SQLInjectionError: If the table name is not a valid identifier.
    """
    validate_identifier(self.table_name)
    soft_filter = self._soft_delete_condition()
    sql = f"SELECT COUNT(*) FROM {self.table_name}"
    if soft_filter:
        sql = f"{sql} WHERE {soft_filter}"

    conn = await get_async_connection(self.db_path)
    try:
        cursor = await conn.execute(sql)
        first_row = await cursor.fetchone()
    finally:
        await conn.close()

    if not first_row:
        return 0
    return first_row[0]
async def insert_many_async(self, data_list: list[BaseModel]) -> None:
    """Insert several model instances with one commit (async).

    ``pre_save`` hooks run for every instance before any row is written and
    ``post_save`` hooks run for every instance after the connection is
    closed. The column list is taken from the first instance.

    Args:
        data_list: Model instances to insert. An empty list is a no-op.
    """
    if not data_list:
        return

    for item in data_list:
        self._call_hook(item, "pre_save")

    records = [self._dump(item) for item in data_list]
    template = records[0]
    column_list = ", ".join(template.keys())
    marks = ", ".join(["?"] * len(template))
    sql = f"INSERT INTO {self.table_name} ({column_list}) VALUES ({marks})"

    conn = await get_async_connection(self.db_path)
    try:
        for record in records:
            await conn.execute(sql, tuple(record.values()))
        await conn.commit()
    finally:
        await conn.close()

    for item in data_list:
        self._call_hook(item, "post_save")
async def update_many_async(self, updates: list[tuple[BaseModel, int]]) -> int:
    """Update multiple records (async).

    ``pre_save`` hooks run for every instance before any row is written and
    ``post_save`` hooks run for every instance after the connection is
    closed.

    Args:
        updates: List of (model, record_id) tuples.

    Returns:
        Number of records updated.

    Raises:
        SQLInjectionError: If the table name is not a valid identifier.
    """
    if not updates:
        return 0

    for data, _ in updates:
        self._call_hook(data, "pre_save")

    validate_identifier(self.table_name)

    conn = await get_async_connection(self.db_path)
    try:
        # total_changes is a running total for the connection, so the number
        # of rows touched here is the difference across the batch, not a sum
        # of the counter's value after each statement.
        baseline = conn.total_changes
        for data, record_id in updates:
            data_dict = self._dump(data)
            fields = ", ".join(f"{key} = ?" for key in data_dict)
            values = tuple(data_dict.values()) + (record_id,)
            query = f"UPDATE {self.table_name} SET {fields} WHERE id = ?"
            await conn.execute(query, values)
        total_updated = conn.total_changes - baseline
        await conn.commit()
    finally:
        await conn.close()

    for data, _ in updates:
        self._call_hook(data, "post_save")

    return total_updated
async def delete_many_async(self, record_ids: list[int]) -> int:
    """Delete multiple records by their IDs (async, hard or soft).

    With soft delete enabled each row's deleted-at column is set to the
    current local time in ISO format; otherwise the rows are deleted.

    Args:
        record_ids: List of record IDs to delete.

    Returns:
        Number of records deleted, i.e. rows actually affected; IDs that
        match no row do not count.

    Raises:
        SQLInjectionError: If the table name is not a valid identifier.
    """
    if not record_ids:
        return 0

    validate_identifier(self.table_name)

    if self.soft_delete:
        from datetime import datetime

        now = datetime.now().isoformat()
        query = f"UPDATE {self.table_name} SET {self.deleted_at_field} = ? WHERE id = ?"
        params = [(now, rid) for rid in record_ids]
    else:
        query = f"DELETE FROM {self.table_name} WHERE id = ?"
        params = [(rid,) for rid in record_ids]

    conn = await get_async_connection(self.db_path)
    try:
        # Count via the change in the connection's total_changes counter
        # rather than len(record_ids), which would also count IDs that
        # matched nothing.
        baseline = conn.total_changes
        for p in params:
            await conn.execute(query, p)
        affected = conn.total_changes - baseline
        await conn.commit()
    finally:
        await conn.close()

    return affected
async def execute_transaction_async(self, operations: list[tuple[str, tuple]]) -> list[Any]:
    """Execute multiple operations in a transaction (async).

    Args:
        operations: List of (query, params) tuples.

    Returns:
        Fetched rows for every statement whose cursor has a description, in
        execution order.

    Raises:
        TransactionError: If any operation or the commit fails; the original
            exception is chained as the cause.
    """
    results = []
    conn = await get_async_connection(self.db_path)
    try:
        for query, values in operations:
            cursor = await conn.execute(query, values)
            if cursor.description:
                results.append(await cursor.fetchall())
        await conn.commit()
        logger.info(f"Async transaction completed with {len(operations)} operations")
    except Exception as e:
        logger.error(f"Async transaction failed: {e}")
        raise TransactionError(f"Async transaction failed: {e}") from e
    finally:
        # Close exactly once on success and on failure alike, without
        # inspecting the connection's private state.
        await conn.close()
    return results
async def with_transaction_async(self, func: Callable[[AsyncTransaction], Any]) -> Any:
    """Execute a function within a transaction (async).

    Args:
        func: Coroutine function that receives an ``AsyncTransaction`` bound
            to a fresh connection and performs operations on it.

    Returns:
        Result of the function.

    Raises:
        TransactionError: If opening the connection, ``func`` or the commit
            raises; the original exception is chained as the cause.
    """
    try:
        conn = await get_async_connection(self.db_path)
        # The connection is already awaited above, so it is released with an
        # explicit close (as in the other async methods) instead of being
        # entered again through ``async with``.
        # NOTE(review): with aiosqlite, entering an already-started
        # connection appears to start it a second time — confirm against
        # what get_async_connection returns.
        try:
            txn = AsyncTransaction(self.db_path)
            txn.conn = conn
            result = await func(txn)
            await txn.commit()
        finally:
            await conn.close()
        logger.info("Async transaction completed successfully")
        return result
    except Exception as e:
        logger.error(f"Async transaction failed: {e}")
        raise TransactionError(f"Async transaction failed: {e}") from e