"""Table synchronization with Pydantic models."""
import re
from typing import Optional
from wsqlite.core.connection import get_async_connection, get_connection
from wsqlite.types.sql_types import get_sql_type
def validate_identifier(identifier: str) -> None:
    """Validate a SQL identifier to prevent SQL injection.

    An identifier is accepted only when it is a string made up entirely of
    ASCII letters, digits and underscores and does not start with a digit.

    Args:
        identifier: Candidate table, column or index name.

    Raises:
        SQLInjectionError: If ``identifier`` is not a string or contains any
            character outside the allowed set.
    """
    # ``fullmatch`` rather than ``match`` with a trailing "$": "$" also matches
    # just before a final newline, which would let "users\n" slip through.
    if isinstance(identifier, str) and re.fullmatch(r"[a-zA-Z_][a-zA-Z0-9_]*", identifier):
        return

    # Imported only on the failure path, as the original code did.
    from wsqlite.exceptions import SQLInjectionError

    raise SQLInjectionError(f"Invalid identifier: {identifier}")
class TableSync:
    """Handles table synchronization between Pydantic models and SQLite (sync).

    Table, column and index names are interpolated into the SQL text as-is,
    so they must come from trusted code rather than from end-user input.

    Field descriptions are scanned (case-insensitively) for these markers:

    * ``unique:<group>`` - the field joins the composite UNIQUE constraint
      named ``<group>``.
    * ``references:<table>.<column>`` - the field gets a FOREIGN KEY clause.
    * ``index`` - a single-column index is created for the field; it is a
      UNIQUE index when the description also contains ``unique`` but not
      ``unique:``.
    """

    # Compiled once; both are applied to the lower-cased field description.
    _UNIQUE_GROUP_RE = re.compile(r"unique:([a-zA-Z0-9_]+)")
    _REFERENCE_RE = re.compile(r"references:([a-zA-Z0-9_]+)\.([a-zA-Z0-9_]+)")

    def __init__(self, model, db_path: str, table_name: Optional[str] = None):
        """Initialize table sync.

        Args:
            model: Pydantic BaseModel class.
            db_path: Path to SQLite database file.
            table_name: Optional custom table name. Defaults to the
                lower-cased name of ``model``.
        """
        self.model = model
        self.db_path = db_path
        self.table_name = table_name or model.__name__.lower()

    def _uses_fts5(self) -> bool:
        """Return True when the model's ``wsqlite_config.use_fts5`` is truthy."""
        config = getattr(self.model, "wsqlite_config", None)
        return bool(getattr(config, "use_fts5", False))

    def _fts5_table_sql(self) -> str:
        """Build the CREATE VIRTUAL TABLE statement from the model's TEXT fields."""
        fts_columns = [
            field_name
            for field_name, field in self.model.model_fields.items()
            if get_sql_type(field) == "TEXT"
        ]
        if not fts_columns:
            # The original raised TableSyncError without importing it, which
            # surfaced as a NameError.
            # NOTE(review): assumes TableSyncError is defined in
            # wsqlite.exceptions next to SQLInjectionError - confirm.
            from wsqlite.exceptions import TableSyncError

            raise TableSyncError("FTS5 table requires at least one TEXT field.")
        columns_clause = ", ".join(fts_columns)
        return f"CREATE VIRTUAL TABLE IF NOT EXISTS {self.table_name} USING fts5({columns_clause})"

    def _standard_table_sql(self) -> str:
        """Build the CREATE TABLE statement for a regular (non-FTS) table."""
        column_defs = []
        composite_uniques = {}
        foreign_keys = []
        for field_name, field in self.model.model_fields.items():
            column_defs.append(f"{field_name} {get_sql_type(field)}")
            description = (field.description or "").lower()
            unique_match = self._UNIQUE_GROUP_RE.search(description)
            if unique_match:
                composite_uniques.setdefault(unique_match.group(1), []).append(field_name)
            reference_match = self._REFERENCE_RE.search(description)
            if reference_match:
                foreign_keys.append((field_name, reference_match.group(1), reference_match.group(2)))
        # Table-level constraints follow the column definitions.
        for fields in composite_uniques.values():
            column_defs.append(f"UNIQUE({', '.join(fields)})")
        for local_col, ref_table, ref_col in foreign_keys:
            column_defs.append(f"FOREIGN KEY({local_col}) REFERENCES {ref_table}({ref_col})")
        fields_clause = ", ".join(column_defs)
        return f"CREATE TABLE IF NOT EXISTS {self.table_name} ({fields_clause})"

    def _run_ddl(self, query: str) -> None:
        """Execute a single statement and commit it."""
        with get_connection(self.db_path) as conn:
            conn.execute(query)
            conn.commit()

    def create_if_not_exists(self):
        """Create the table if it doesn't exist, handling FTS5 virtual tables.

        FTS5 models get a virtual table built from their TEXT fields. All
        other models get a regular table, followed by one index per field
        whose description contains ``index``.

        Raises:
            TableSyncError: If FTS5 is requested but the model has no TEXT field.
        """
        use_fts = self._uses_fts5()
        query = self._fts5_table_sql() if use_fts else self._standard_table_sql()
        self._run_ddl(query)

        # Auto-create indexes for non-FTS tables
        if use_fts:
            return
        for field_name, field in self.model.model_fields.items():
            description = (field.description or "").lower()
            if "index" in description:
                unique = "unique" in description and "unique:" not in description
                self.create_index([field_name], unique=unique)

    def sync_with_model(self):
        """Sync the table with the Pydantic model, adding new columns if necessary.

        Does nothing for FTS5 models. Columns that exist in the table but not
        in the model are left untouched.
        """
        # FTS5 tables cannot be altered
        if self._uses_fts5():
            return

        existing_columns = set(self.get_columns())
        # Walk the model in declaration order so missing columns are appended
        # in a stable order; iterating a set difference is not ordered.
        new_fields = [name for name in self.model.model_fields if name not in existing_columns]
        if not new_fields:
            return

        with get_connection(self.db_path) as conn:
            for name in new_fields:
                field_type = get_sql_type(self.model.model_fields[name])
                conn.execute(f"ALTER TABLE {self.table_name} ADD COLUMN {name} {field_type}")
            conn.commit()

    def table_exists(self) -> bool:
        """Check if the table exists in the database."""
        query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
        with get_connection(self.db_path) as conn:
            cursor = conn.execute(query, (self.table_name,))
            return cursor.fetchone() is not None

    def drop_table(self):
        """Drop the table from the database."""
        self._run_ddl(f"DROP TABLE IF EXISTS {self.table_name}")

    def get_columns(self) -> list[str]:
        """Get list of column names in the table, as reported by ``PRAGMA table_info``."""
        query = f"PRAGMA table_info({self.table_name})"
        with get_connection(self.db_path) as conn:
            cursor = conn.execute(query)
            return [row[1] for row in cursor.fetchall()]

    def create_index(
        self, columns: list[str], index_name: Optional[str] = None, unique: bool = False
    ):
        """Create an index on the specified columns.

        Args:
            columns: Column names to index, in index order.
            index_name: Name of the index. Defaults to
                ``idx_<table>_<col1>_<col2>...``.
            unique: Create a UNIQUE index when True.

        Raises:
            ValueError: If ``columns`` is empty.
        """
        if not columns:
            # An empty column list would only produce invalid SQL.
            raise ValueError("create_index requires at least one column")
        if index_name is None:
            index_name = f"idx_{self.table_name}_{'_'.join(columns)}"
        columns_str = ", ".join(columns)
        unique_str = "UNIQUE " if unique else ""
        self._run_ddl(
            f"CREATE {unique_str}INDEX IF NOT EXISTS {index_name} ON {self.table_name} ({columns_str})"
        )

    def drop_index(self, index_name: str):
        """Drop an index from the table."""
        self._run_ddl(f"DROP INDEX IF EXISTS {index_name}")

    def get_indexes(self) -> list[dict]:
        """Get list of indexes on the table.

        Returns:
            One dict per index with the keys ``name``, ``unique`` (bool) and
            ``columns`` (list of column names).
        """
        query = f"PRAGMA index_list({self.table_name})"
        with get_connection(self.db_path) as conn:
            cursor = conn.execute(query)
            indexes = []
            for row in cursor.fetchall():
                idx_name = row[1]
                idx_cursor = conn.execute(f"PRAGMA index_info({idx_name})")
                col_names = [col[2] for col in idx_cursor.fetchall()]
                indexes.append(
                    {
                        "name": idx_name,
                        "unique": bool(row[2]),
                        "columns": col_names,
                    }
                )
            return indexes
class AsyncTableSync:
    """Handles table synchronization between Pydantic models and SQLite (async).

    Table, column and index names are interpolated into the SQL text as-is,
    so they must come from trusted code rather than from end-user input.

    Field descriptions are scanned (case-insensitively) for these markers:

    * ``unique:<group>`` - the field joins the composite UNIQUE constraint
      named ``<group>``.
    * ``references:<table>.<column>`` - the field gets a FOREIGN KEY clause.
    * ``index`` - a single-column index is created for the field; it is a
      UNIQUE index when the description also contains ``unique`` but not
      ``unique:``.
    """

    # Compiled once; both are applied to the lower-cased field description.
    _UNIQUE_GROUP_RE = re.compile(r"unique:([a-zA-Z0-9_]+)")
    _REFERENCE_RE = re.compile(r"references:([a-zA-Z0-9_]+)\.([a-zA-Z0-9_]+)")

    def __init__(self, model, db_path: str, table_name: Optional[str] = None):
        """Initialize async table sync.

        Args:
            model: Pydantic BaseModel class.
            db_path: Path to SQLite database file.
            table_name: Optional custom table name. Defaults to the
                lower-cased name of ``model``.
        """
        self.model = model
        self.db_path = db_path
        self.table_name = table_name or model.__name__.lower()

    def _uses_fts5(self) -> bool:
        """Return True when the model's ``wsqlite_config.use_fts5`` is truthy."""
        config = getattr(self.model, "wsqlite_config", None)
        return bool(getattr(config, "use_fts5", False))

    def _fts5_table_sql(self) -> str:
        """Build the CREATE VIRTUAL TABLE statement from the model's TEXT fields."""
        fts_columns = [
            field_name
            for field_name, field in self.model.model_fields.items()
            if get_sql_type(field) == "TEXT"
        ]
        if not fts_columns:
            # NOTE(review): assumes TableSyncError is defined in
            # wsqlite.exceptions next to SQLInjectionError - confirm.
            from wsqlite.exceptions import TableSyncError

            raise TableSyncError("FTS5 table requires at least one TEXT field.")
        columns_clause = ", ".join(fts_columns)
        return f"CREATE VIRTUAL TABLE IF NOT EXISTS {self.table_name} USING fts5({columns_clause})"

    def _standard_table_sql(self) -> str:
        """Build the CREATE TABLE statement for a regular (non-FTS) table."""
        column_defs = []
        composite_uniques = {}
        foreign_keys = []
        for field_name, field in self.model.model_fields.items():
            column_defs.append(f"{field_name} {get_sql_type(field)}")
            description = (field.description or "").lower()
            unique_match = self._UNIQUE_GROUP_RE.search(description)
            if unique_match:
                composite_uniques.setdefault(unique_match.group(1), []).append(field_name)
            reference_match = self._REFERENCE_RE.search(description)
            if reference_match:
                foreign_keys.append((field_name, reference_match.group(1), reference_match.group(2)))
        # Table-level constraints follow the column definitions.
        for fields in composite_uniques.values():
            column_defs.append(f"UNIQUE({', '.join(fields)})")
        for local_col, ref_table, ref_col in foreign_keys:
            column_defs.append(f"FOREIGN KEY({local_col}) REFERENCES {ref_table}({ref_col})")
        fields_clause = ", ".join(column_defs)
        return f"CREATE TABLE IF NOT EXISTS {self.table_name} ({fields_clause})"

    async def _run_ddl(self, query: str) -> None:
        """Execute a single statement, commit it, and always close the connection."""
        conn = await get_async_connection(self.db_path)
        try:
            await conn.execute(query)
            await conn.commit()
        finally:
            await conn.close()

    async def create_if_not_exists_async(self):
        """Create the table if it doesn't exist, handling FTS5 virtual tables (async).

        FTS5 models get a virtual table built from their TEXT fields, matching
        the synchronous ``TableSync``. All other models get a regular table,
        followed by one index per field whose description contains ``index``.

        Raises:
            TableSyncError: If FTS5 is requested but the model has no TEXT field.
        """
        use_fts = self._uses_fts5()
        query = self._fts5_table_sql() if use_fts else self._standard_table_sql()
        await self._run_ddl(query)

        # Auto-create indexes for non-FTS tables
        if use_fts:
            return
        for field_name, field in self.model.model_fields.items():
            description = (field.description or "").lower()
            if "index" in description:
                unique = "unique" in description and "unique:" not in description
                await self.create_index_async([field_name], unique=unique)

    async def sync_with_model_async(self):
        """Sync the table with the Pydantic model, adding new columns if necessary (async).

        Does nothing for FTS5 models. Columns that exist in the table but not
        in the model are left untouched.
        """
        # FTS5 tables cannot be altered
        if self._uses_fts5():
            return

        existing_columns = set(await self.get_columns_async())
        # Walk the model in declaration order so missing columns are appended
        # in a stable order; iterating a set difference is not ordered.
        new_fields = [name for name in self.model.model_fields if name not in existing_columns]
        if not new_fields:
            return

        conn = await get_async_connection(self.db_path)
        try:
            for name in new_fields:
                field_type = get_sql_type(self.model.model_fields[name])
                await conn.execute(f"ALTER TABLE {self.table_name} ADD COLUMN {name} {field_type}")
            await conn.commit()
        finally:
            await conn.close()

    async def table_exists_async(self) -> bool:
        """Check if the table exists in the database (async)."""
        query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
        conn = await get_async_connection(self.db_path)
        try:
            cursor = await conn.execute(query, (self.table_name,))
            result = await cursor.fetchone()
            return result is not None
        finally:
            await conn.close()

    async def drop_table_async(self):
        """Drop the table from the database (async)."""
        await self._run_ddl(f"DROP TABLE IF EXISTS {self.table_name}")

    async def get_columns_async(self) -> list[str]:
        """Get list of column names in the table, as reported by ``PRAGMA table_info`` (async)."""
        query = f"PRAGMA table_info({self.table_name})"
        conn = await get_async_connection(self.db_path)
        try:
            cursor = await conn.execute(query)
            rows = await cursor.fetchall()
            return [row[1] for row in rows]
        finally:
            await conn.close()

    async def create_index_async(
        self, columns: list[str], index_name: Optional[str] = None, unique: bool = False
    ):
        """Create an index on the specified columns (async).

        Args:
            columns: Column names to index, in index order.
            index_name: Name of the index. Defaults to
                ``idx_<table>_<col1>_<col2>...``.
            unique: Create a UNIQUE index when True.

        Raises:
            ValueError: If ``columns`` is empty.
        """
        if not columns:
            # An empty column list would only produce invalid SQL.
            raise ValueError("create_index_async requires at least one column")
        if index_name is None:
            index_name = f"idx_{self.table_name}_{'_'.join(columns)}"
        columns_str = ", ".join(columns)
        unique_str = "UNIQUE " if unique else ""
        await self._run_ddl(
            f"CREATE {unique_str}INDEX IF NOT EXISTS {index_name} ON {self.table_name} ({columns_str})"
        )

    async def drop_index_async(self, index_name: str):
        """Drop an index from the table (async)."""
        await self._run_ddl(f"DROP INDEX IF EXISTS {index_name}")

    async def get_indexes_async(self) -> list[dict]:
        """Get list of indexes on the table (async).

        Returns:
            One dict per index with the keys ``name``, ``unique`` (bool) and
            ``columns`` (list of column names).
        """
        query = f"PRAGMA index_list({self.table_name})"
        conn = await get_async_connection(self.db_path)
        try:
            cursor = await conn.execute(query)
            indexes = []
            for row in await cursor.fetchall():
                idx_name = row[1]
                idx_cursor = await conn.execute(f"PRAGMA index_info({idx_name})")
                col_names = [col[2] for col in await idx_cursor.fetchall()]
                indexes.append(
                    {
                        "name": idx_name,
                        "unique": bool(row[2]),
                        "columns": col_names,
                    }
                )
            return indexes
        finally:
            await conn.close()