diff options
Diffstat (limited to 'db.py')
| -rw-r--r-- | db.py | 167 |
1 files changed, 167 insertions, 0 deletions
@@ -0,0 +1,167 @@ +import sqlite3 +import random +import os +from typing import Optional + +DB_PATH = os.path.join(os.path.dirname(__file__), "bot.db") +SEEDS_DIR = os.path.join(os.path.dirname(__file__), "seeds") + + +def _get_conn() -> sqlite3.Connection: + conn = sqlite3.connect(DB_PATH) + conn.row_factory = sqlite3.Row + return conn + + +def init_db(): + with _get_conn() as conn: + conn.execute(""" + CREATE TABLE IF NOT EXISTS items ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + category TEXT NOT NULL COLLATE NOCASE, + name TEXT NOT NULL COLLATE NOCASE, + note TEXT NOT NULL DEFAULT '', + added_by TEXT NOT NULL DEFAULT 'unknown', + UNIQUE(category, name) + ) + """) + conn.commit() + + +def load_seed_file(category: str) -> list[tuple[str, str, str]]: + path = os.path.join(SEEDS_DIR, f"{category}.txt") + if not os.path.exists(path): + return [] + + entries: list[tuple[str, str, str]] = [] + with open(path, encoding="utf-8") as f: + lines = [line.rstrip("\n") for line in f.readlines()] + + # Split on blank lines to get per-entry blocks + blocks: list[list[str]] = [] + current: list[str] = [] + for line in lines: + if line.strip() == "": + if current: + blocks.append(current) + current = [] + else: + current.append(line) + if current: + blocks.append(current) + + for block in blocks: + if len(block) < 2: + continue + name = block[0].strip() + if len(block) == 2: + note = "" + added_by = block[1].strip() + else: + note = block[1].strip() + added_by = block[2].strip() + if name: + entries.append((name, note, added_by)) + + return entries + + +def seed(category: str, entries: Optional[list[tuple[str, str, str]]] = None): + if entries is None: + entries = load_seed_file(category) + if not entries: + return + with _get_conn() as conn: + row = conn.execute( + "SELECT COUNT(*) AS cnt FROM items WHERE category = ?", (category,) + ).fetchone() + if row["cnt"] == 0: + conn.executemany( + "INSERT OR IGNORE INTO items (category, name, note, added_by) VALUES (?, ?, ?, ?)", + [(category, name, note, added_by) for name, note, added_by in entries], + ) + conn.commit() + + +def add_item(category: str, name: str, added_by: str, note: str = "") -> str | None: + name = name.strip() + note = note.strip() + category = category.strip() + if not name: + return "❌ Name cannot be empty." + with _get_conn() as conn: + existing = conn.execute( + "SELECT id FROM items WHERE category = ? AND name = ?", (category, name) + ).fetchone() + if existing: + return f"❌ **{name}** is already in the list." + conn.execute( + "INSERT INTO items (category, name, note, added_by) VALUES (?, ?, ?, ?)", + (category, name, note, added_by), + ) + conn.commit() + return None + + +def remove_item(category: str, name: str) -> str | None: + name = name.strip() + if not name: + return "❌ Please provide a name to remove." + with _get_conn() as conn: + existing = conn.execute( + "SELECT id, name FROM items WHERE category = ? AND name = ?", (category, name) + ).fetchone() + if not existing: + return None + actual_name = existing["name"] + conn.execute("DELETE FROM items WHERE id = ?", (existing["id"],)) + conn.commit() + return actual_name + + +def find_item(category: str, keyword: str = "") -> sqlite3.Row | None: + """Return a random item from the category, optionally filtered by keyword in note.""" + keyword = keyword.strip().lower() + with _get_conn() as conn: + if keyword: + rows = conn.execute( + "SELECT name, note, added_by FROM items WHERE category = ? AND LOWER(note) LIKE ?", + (category, f"%{keyword}%"), + ).fetchall() + else: + rows = conn.execute( + "SELECT name, note, added_by FROM items WHERE category = ?", (category,) + ).fetchall() + if not rows: + return None + return random.choice(rows) + + +def migrate_from(old_db_path: str, category: str): + if not os.path.exists(old_db_path): + return + + old_conn = sqlite3.connect(old_db_path) + old_conn.row_factory = sqlite3.Row + try: + tables = old_conn.execute( + "SELECT name FROM sqlite_master WHERE type='table'" + ).fetchall() + if not tables: + return + table_name = tables[0]["name"] + rows = old_conn.execute( + f"SELECT name, note, added_by FROM {table_name}" # noqa: S608 + ).fetchall() + finally: + old_conn.close() + + if rows: + with _get_conn() as conn: + conn.executemany( + "INSERT OR IGNORE INTO items (category, name, note, added_by) VALUES (?, ?, ?, ?)", + [(category, r["name"], r["note"], r["added_by"]) for r in rows], + ) + conn.commit() + + os.remove(old_db_path) |
