diff options
Diffstat (limited to 'persistance.py')
| -rw-r--r-- | persistance.py | 134 |
1 files changed, 134 insertions, 0 deletions
diff --git a/persistance.py b/persistance.py new file mode 100644 index 0000000..69c6bb9 --- /dev/null +++ b/persistance.py @@ -0,0 +1,134 @@ +from pathlib import Path +import sqlite3 +from typing import Any + +DB_PATH = Path(__file__).resolve().parent / "data" / "spot_prices.db" + + +def _connect() -> sqlite3.Connection: + DB_PATH.parent.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(DB_PATH) + conn.row_factory = sqlite3.Row + return conn + + +def init_storage() -> None: + with _connect() as conn: + conn.execute( + """ + CREATE TABLE IF NOT EXISTS spot_price_points ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + polled_at TEXT NOT NULL, + source_timestamp TEXT, + asg_name TEXT, + instance_id TEXT, + instance_type TEXT, + az TEXT, + spot_price REAL NOT NULL + ) + """ + ) + conn.execute( + """ + CREATE UNIQUE INDEX IF NOT EXISTS ux_spot_point_unique + ON spot_price_points(instance_id, source_timestamp, spot_price) + """ + ) + conn.commit() + + +def save_spot_datapoint(point: dict[str, Any]) -> bool: + init_storage() + + with _connect() as conn: + cursor = conn.execute( + """ + INSERT OR IGNORE INTO spot_price_points( + polled_at, + source_timestamp, + asg_name, + instance_id, + instance_type, + az, + spot_price + ) VALUES (?, ?, ?, ?, ?, ?, ?) + """, + ( + point.get("polled_at"), + point.get("source_timestamp"), + point.get("asg_name"), + point.get("instance_id"), + point.get("instance_type"), + point.get("az"), + point.get("spot_price"), + ), + ) + conn.commit() + + return cursor.rowcount > 0 + + +def load_spot_datapoints(limit: int = 500) -> list[dict[str, Any]]: + init_storage() + limit = max(1, min(limit, 5000)) + + with _connect() as conn: + rows = conn.execute( + """ + SELECT + polled_at, + source_timestamp, + asg_name, + instance_id, + instance_type, + az, + spot_price + FROM spot_price_points + ORDER BY polled_at DESC + LIMIT ? + """, + (limit,), + ).fetchall() + + # Return ascending for chart rendering + return [dict(row) for row in reversed(rows)] + + +def get_peak_spot_price() -> float | None: + init_storage() + + with _connect() as conn: + row = conn.execute( + """ + SELECT MAX(spot_price) AS peak_spot_price + FROM spot_price_points + """ + ).fetchone() + + if not row: + return None + + peak = row["peak_spot_price"] + return float(peak) if peak is not None else None + + +def get_first_breach(max_pay: float) -> dict[str, Any] | None: + init_storage() + + with _connect() as conn: + row = conn.execute( + """ + SELECT + polled_at, + spot_price, + instance_type, + az + FROM spot_price_points + WHERE spot_price > ? + ORDER BY polled_at ASC + LIMIT 1 + """, + (max_pay,), + ).fetchone() + + return dict(row) if row else None |
