Source code for wsqlite.core.sync
"""Table synchronization with Pydantic models."""
import re
from typing import Optional
from wsqlite.core.connection import get_async_connection, get_connection
from wsqlite.types.sql_types import get_sql_type
def validate_identifier(identifier: str) -> None:
"""Validate SQL identifier to prevent SQL injection."""
if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", identifier):
from wsqlite.exceptions import SQLInjectionError
raise SQLInjectionError(f"Invalid identifier: {identifier}")
[docs]
class TableSync:
"""Handles table synchronization between Pydantic models and SQLite (sync)."""
[docs]
def __init__(self, model, db_path: str):
"""Initialize table sync.
Args:
model: Pydantic BaseModel class.
db_path: Path to SQLite database file.
"""
self.model = model
self.db_path = db_path
self.table_name = model.__name__.lower()
[docs]
def create_if_not_exists(self):
"""Create the table if it doesn't exist."""
fields = ", ".join(
f"{field} {get_sql_type(typ)}" for field, typ in self.model.model_fields.items()
)
query = f"CREATE TABLE IF NOT EXISTS {self.table_name} ({fields})"
with get_connection(self.db_path) as conn:
conn.execute(query)
conn.commit()
[docs]
def sync_with_model(self):
"""Sync the table with the Pydantic model, adding new columns if necessary."""
query = f"PRAGMA table_info({self.table_name})"
with get_connection(self.db_path) as conn:
cursor = conn.execute(query)
existing_columns = {row[1] for row in cursor.fetchall()}
model_fields = set(self.model.model_fields.keys())
new_fields = model_fields - existing_columns
if new_fields:
with get_connection(self.db_path) as conn:
for field in new_fields:
field_type = get_sql_type(self.model.model_fields[field])
alter_query = f"ALTER TABLE {self.table_name} ADD COLUMN {field} {field_type}"
conn.execute(alter_query)
conn.commit()
[docs]
def table_exists(self) -> bool:
"""Check if the table exists in the database."""
query = f"SELECT name FROM sqlite_master WHERE type='table' AND name=?"
with get_connection(self.db_path) as conn:
cursor = conn.execute(query, (self.table_name,))
return cursor.fetchone() is not None
[docs]
def drop_table(self):
"""Drop the table from the database."""
query = f"DROP TABLE IF EXISTS {self.table_name}"
with get_connection(self.db_path) as conn:
conn.execute(query)
conn.commit()
[docs]
def get_columns(self) -> list[str]:
"""Get list of column names in the table."""
query = f"PRAGMA table_info({self.table_name})"
with get_connection(self.db_path) as conn:
cursor = conn.execute(query)
return [row[1] for row in cursor.fetchall()]
[docs]
def create_index(
self, columns: list[str], index_name: Optional[str] = None, unique: bool = False
):
"""Create an index on the specified columns."""
if index_name is None:
index_name = f"idx_{self.table_name}_{'_'.join(columns)}"
columns_str = ", ".join(columns)
unique_str = "UNIQUE " if unique else ""
query = f"CREATE {unique_str}INDEX IF NOT EXISTS {index_name} ON {self.table_name} ({columns_str})"
with get_connection(self.db_path) as conn:
conn.execute(query)
conn.commit()
[docs]
def drop_index(self, index_name: str):
"""Drop an index from the table."""
query = f"DROP INDEX IF EXISTS {index_name}"
with get_connection(self.db_path) as conn:
conn.execute(query)
conn.commit()
[docs]
def get_indexes(self) -> list[dict]:
"""Get list of indexes on the table."""
query = f"PRAGMA index_list({self.table_name})"
with get_connection(self.db_path) as conn:
cursor = conn.execute(query)
indexes = []
for row in cursor.fetchall():
idx_name = row[1]
idx_info = f"PRAGMA index_info({idx_name})"
idx_cursor = conn.execute(idx_info)
col_names = [col[2] for col in idx_cursor.fetchall()]
indexes.append(
{
"name": idx_name,
"unique": bool(row[2]),
"columns": col_names,
}
)
return indexes
[docs]
class AsyncTableSync:
"""Handles table synchronization between Pydantic models and SQLite (async)."""
[docs]
def __init__(self, model, db_path: str):
"""Initialize async table sync."""
self.model = model
self.db_path = db_path
self.table_name = model.__name__.lower()
[docs]
async def create_if_not_exists_async(self):
"""Create the table if it doesn't exist (async)."""
fields = ", ".join(
f"{field} {get_sql_type(typ)}" for field, typ in self.model.model_fields.items()
)
query = f"CREATE TABLE IF NOT EXISTS {self.table_name} ({fields})"
conn = await get_async_connection(self.db_path)
try:
await conn.execute(query)
await conn.commit()
finally:
await conn.close()
[docs]
async def sync_with_model_async(self):
"""Sync the table with the Pydantic model, adding new columns if necessary (async)."""
query = f"PRAGMA table_info({self.table_name})"
conn = await get_async_connection(self.db_path)
try:
cursor = await conn.execute(query)
rows = await cursor.fetchall()
existing_columns = {row[1] for row in rows}
finally:
await conn.close()
model_fields = set(self.model.model_fields.keys())
new_fields = model_fields - existing_columns
if new_fields:
conn = await get_async_connection(self.db_path)
try:
for field in new_fields:
field_type = get_sql_type(self.model.model_fields[field])
alter_query = f"ALTER TABLE {self.table_name} ADD COLUMN {field} {field_type}"
await conn.execute(alter_query)
await conn.commit()
finally:
await conn.close()
[docs]
async def table_exists_async(self) -> bool:
"""Check if the table exists in the database (async)."""
query = f"SELECT name FROM sqlite_master WHERE type='table' AND name=?"
conn = await get_async_connection(self.db_path)
try:
cursor = await conn.execute(query, (self.table_name,))
result = await cursor.fetchone()
return result is not None
finally:
await conn.close()
[docs]
async def drop_table_async(self):
"""Drop the table from the database (async)."""
query = f"DROP TABLE IF EXISTS {self.table_name}"
conn = await get_async_connection(self.db_path)
try:
await conn.execute(query)
await conn.commit()
finally:
await conn.close()
[docs]
async def get_columns_async(self) -> list[str]:
"""Get list of column names in the table (async)."""
query = f"PRAGMA table_info({self.table_name})"
conn = await get_async_connection(self.db_path)
try:
cursor = await conn.execute(query)
rows = await cursor.fetchall()
return [row[1] for row in rows]
finally:
await conn.close()
[docs]
async def create_index_async(
self, columns: list[str], index_name: Optional[str] = None, unique: bool = False
):
"""Create an index on the specified columns (async)."""
if index_name is None:
index_name = f"idx_{self.table_name}_{'_'.join(columns)}"
columns_str = ", ".join(columns)
unique_str = "UNIQUE " if unique else ""
query = f"CREATE {unique_str}INDEX IF NOT EXISTS {index_name} ON {self.table_name} ({columns_str})"
conn = await get_async_connection(self.db_path)
try:
await conn.execute(query)
await conn.commit()
finally:
await conn.close()
[docs]
async def drop_index_async(self, index_name: str):
"""Drop an index from the table (async)."""
query = f"DROP INDEX IF EXISTS {index_name}"
conn = await get_async_connection(self.db_path)
try:
await conn.execute(query)
await conn.commit()
finally:
await conn.close()
[docs]
async def get_indexes_async(self) -> list[dict]:
"""Get list of indexes on the table (async)."""
query = f"PRAGMA index_list({self.table_name})"
conn = await get_async_connection(self.db_path)
try:
cursor = await conn.execute(query)
indexes = []
for row in await cursor.fetchall():
idx_name = row[1]
idx_info = f"PRAGMA index_info({idx_name})"
idx_cursor = await conn.execute(idx_info)
col_names = [col[2] for col in await idx_cursor.fetchall()]
indexes.append(
{
"name": idx_name,
"unique": bool(row[2]),
"columns": col_names,
}
)
return indexes
finally:
await conn.close()