Source code for esbonio.sphinx_agent.database
from __future__ import annotations
import pathlib
import sqlite3
from dataclasses import dataclass
from dataclasses import field
from typing import Any
from typing import Literal
[docs]
class Database:
    """A thin wrapper around a SQLite connection.

    Tables are described with :class:`Database.Table` / :class:`Database.Column`
    and created on demand by :meth:`ensure_table`.
    """

    @dataclass
    class Column:
        """Description of a single table column."""

        name: str
        dtype: str
        notnull: bool = field(default=False)
        default: Any | None = field(default=None)
        pk: int = field(default=0)

        @property
        def definition(self) -> str:
            """Return this column's definition for use in a ``CREATE TABLE``.

            Only ``name`` and ``dtype`` are emitted; ``notnull``, ``default`` and
            ``pk`` are not included in the definition.
            """
            # SQLite cannot bind identifiers as '?' parameters, so the name is
            # interpolated directly and must come from trusted code.
            return f"{self.name} {self.dtype}"

    @dataclass
    class Table:
        """Description of a table: its name and ordered list of columns."""

        name: str
        columns: list[Database.Column]

        @property
        def create_statement(self) -> str:
            """Return the SQL statement required to create this table."""
            # Identifiers cannot be bound as '?' parameters, hence the f-string.
            columns = ",".join(c.definition for c in self.columns)
            return f"CREATE TABLE {self.name} ({columns});"

    def __init__(self, dbpath: pathlib.Path | Literal[":memory:"]):
        """Open the database at ``dbpath``.

        Parameters
        ----------
        dbpath
            Path to the database file, or ``":memory:"`` for an in-memory
            database. Missing parent directories of a file path are created.
        """
        self.path = dbpath
        if isinstance(self.path, pathlib.Path):
            # exist_ok avoids a race between an existence check and mkdir.
            self.path.parent.mkdir(parents=True, exist_ok=True)

        self.db = sqlite3.connect(self.path)

        # Ensure that Write Ahead Logging is enabled.
        self.db.execute("PRAGMA journal_mode(WAL)")

        # Names of tables already verified by ensure_table on this instance.
        self._checked_tables: set[str] = set()

    def _get_table(self, name: str) -> Database.Table | None:
        """Get the table with the given name, if it exists.

        Returns ``None`` when no table called ``name`` exists.
        """
        # Unlike the PRAGMA statement, the table-valued pragma_table_info()
        # function accepts a bound parameter. ORDER BY cid keeps the columns in
        # declaration order, which ensure_table relies on.
        cursor = self.db.execute(
            "SELECT * FROM pragma_table_info(?) ORDER BY cid", (name,)
        )
        rows = cursor.fetchall()
        if not rows:
            # Table does not exist
            return None

        columns = [
            self.Column(
                name=col_name,
                dtype=col_type,
                notnull=bool(notnull),
                default=default,
                pk=pk,
            )
            for (_, col_name, col_type, notnull, default, pk) in rows
        ]
        return self.Table(name=name, columns=columns)

    def _create_table(self, table: Database.Table):
        """Drop any existing table with the same name, then create ``table``."""
        cursor = self.db.cursor()
        # Identifiers cannot be bound as '?' parameters, hence the f-string.
        cursor.execute(f"DROP TABLE IF EXISTS {table.name}")
        cursor.execute(table.create_statement)
        self.db.commit()

    def clear_table(self, table: Database.Table, **kwargs):
        """Clear the given table.

        Parameters
        ----------
        table
            The table to delete rows from.
        kwargs
            Constraints to limit the rows that get cleared. Each
            ``column=value`` pair becomes ``column = ?``, except that a value of
            ``None`` becomes ``column is null``. Constraints are combined with
            ``AND``; with no constraints every row is deleted.
        """
        # Identifiers (table and column names) cannot be bound as '?' parameters.
        query = f"DELETE FROM {table.name}"  # noqa: S608
        where: list[str] = []
        parameters: list[Any] = []

        for param, value in kwargs.items():
            if value is None:
                # '= NULL' never matches in SQL, so NULL needs 'is null'.
                where.append(f"{param} is null")
            else:
                where.append(f"{param} = ?")
                parameters.append(value)

        if where:
            query = " ".join([query, "WHERE", " AND ".join(where)])

        cursor = self.db.cursor()
        cursor.execute(query, tuple(parameters))
        self.db.commit()

    def ensure_table(self, table: Database.Table):
        """Ensure that the given table exists in the database.

        If the table *does* exist, but has the wrong shape (a different number
        of columns, or a column whose name or type differs at the same
        position), it will be dropped and recreated. Each table name is only
        checked once per ``Database`` instance.
        """
        # If we've already checked the table, then there's nothing to do
        if table.name in self._checked_tables:
            return

        existing = self._get_table(table.name)
        if existing is None or not self._is_compatible(existing, table):
            self._create_table(table)

        # Record the table whether it was freshly created or already compatible,
        # so later calls skip the schema query.
        self._checked_tables.add(table.name)

    @staticmethod
    def _is_compatible(existing: Database.Table, table: Database.Table) -> bool:
        """Return ``True`` if both tables have the same column names and types, in order."""
        if len(existing.columns) != len(table.columns):
            return False
        return all(
            old.name == new.name and old.dtype == new.dtype
            for old, new in zip(existing.columns, table.columns, strict=True)
        )

    def insert_values(self, table: Database.Table, values: list[tuple]):
        """Insert the given values into the given table.

        The number of ``?`` placeholders is taken from the length of the first
        tuple in ``values``. Does nothing if ``values`` is empty.
        """
        if not values:
            return

        placeholder = "(" + ",".join("?" * len(values[0])) + ")"
        cursor = self.db.cursor()
        cursor.executemany(f"INSERT INTO {table.name} VALUES {placeholder}", values)  # noqa: S608
        self.db.commit()