Source code for wsqlite.builders.query_builder
"""Query builder for constructing SQL queries safely."""
import re
from typing import Any, Optional
from wsqlite.exceptions import SQLInjectionError
def validate_identifier(identifier: str) -> None:
"""Validate SQL identifier to prevent SQL injection."""
if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", identifier):
raise SQLInjectionError(f"Invalid identifier: {identifier}")
[docs]
class QueryBuilder:
"""Builder for constructing SQL queries safely."""
[docs]
def __init__(self, table_name: str):
"""Initialize query builder."""
validate_identifier(table_name)
self.table_name = table_name
self._where_clauses: list[str] = []
self._where_values: list[Any] = []
self._order_by: Optional[str] = None
self._order_desc: bool = False
self._limit_value: Optional[int] = None
self._offset_value: Optional[int] = None
[docs]
def where(self, field: str, operator: str, value: Any) -> "QueryBuilder":
"""Add WHERE condition."""
validate_identifier(field)
valid_operators = {
"=",
"<",
">",
"<=",
">=",
"!=",
"LIKE",
"IN",
"IS NULL",
"IS NOT NULL",
}
if operator.upper() not in valid_operators:
raise ValueError(f"Invalid operator: {operator}")
if operator.upper() == "IN" and not isinstance(value, (list, tuple)):
raise ValueError("IN operator requires a list or tuple")
if operator.upper() in ("IS NULL", "IS NOT NULL"):
self._where_clauses.append(f"{field} {operator.upper()}")
elif operator.upper() == "IN":
placeholders = ", ".join(["?"] * len(value))
self._where_clauses.append(f"{field} {operator.upper()} ({placeholders})")
self._where_values.extend(value)
else:
self._where_clauses.append(f"{field} {operator} ?")
self._where_values.append(value)
return self
[docs]
def order_by(self, field: str, descending: bool = False) -> "QueryBuilder":
"""Add ORDER BY clause."""
validate_identifier(field)
self._order_by = field
self._order_desc = descending
return self
[docs]
def limit(self, limit: int) -> "QueryBuilder":
"""Add LIMIT clause."""
if limit < 0:
raise ValueError("Limit must be non-negative")
self._limit_value = limit
return self
[docs]
def offset(self, offset: int) -> "QueryBuilder":
"""Add OFFSET clause."""
if offset < 0:
raise ValueError("Offset must be non-negative")
self._offset_value = offset
return self
[docs]
def build_select(self) -> tuple[str, tuple]:
"""Build SELECT query."""
query = f"SELECT * FROM {self.table_name}"
if self._where_clauses:
query += " WHERE " + " AND ".join(self._where_clauses)
if self._order_by:
direction = "DESC" if self._order_desc else "ASC"
query += f" ORDER BY {self._order_by} {direction}"
if self._limit_value is not None:
query += f" LIMIT {self._limit_value}"
if self._offset_value is not None:
query += f" OFFSET {self._offset_value}"
return query, tuple(self._where_values)
[docs]
def build_count(self) -> tuple[str, tuple]:
"""Build COUNT query."""
query = f"SELECT COUNT(*) FROM {self.table_name}"
if self._where_clauses:
query += " WHERE " + " AND ".join(self._where_clauses)
return query, tuple(self._where_values)
[docs]
def build_delete(self) -> tuple[str, tuple]:
"""Build DELETE query."""
if not self._where_clauses:
raise ValueError("DELETE requires WHERE clause")
query = f"DELETE FROM {self.table_name}"
query += " WHERE " + " AND ".join(self._where_clauses)
return query, tuple(self._where_values)
[docs]
def reset(self) -> "QueryBuilder":
"""Reset the builder to initial state."""
self._where_clauses = []
self._where_values = []
self._order_by = None
self._order_desc = False
self._limit_value = None
self._offset_value = None
return self