smaug/db/db.py
claude b119b9abae feat: SQLAlchemy ORM models, filing cache incremental fetch, yfinance price cache
- Replace db/schema.sql + raw sqlite3 with SQLAlchemy ORM (db/models.py)
  - Filing, Signal, PriceCache models with proper indexes
  - db/db.py uses SQLAlchemy sessions throughout; no raw SQL strings
- Add PriceCache table: stores daily close prices per ticker
  - backtest._fetch_prices checks DB first; skips yfinance for completed ranges
  - New data persisted via upsert_prices()
  - get_cached_prices() / upsert_prices() added to db.py
- EDGAR poller incremental fetch: get_latest_filed_date() returns newest
  filed_date in DB; fetch_and_store_new_filings skips entries older than
  that cutoff before even checking accession_exists
- Add get_signals_for_backtest() to db.py; backtest no longer opens its
  own sqlite3 connection
- requirements.txt: add sqlalchemy>=2.0.0

Co-authored-by: dodox <dodox@users.noreply.local>
2026-05-04 17:21:23 +00:00

227 lines
6.8 KiB
Python

from datetime import datetime
from typing import Optional
from sqlalchemy import create_engine, func, select, text, update
from sqlalchemy.orm import Session
import config
from db.models import Base, Filing, PriceCache, Signal
def _engine():
    """Create a SQLAlchemy engine for the SQLite file at ``config.DB_PATH``.

    ``check_same_thread=False`` is forwarded to the sqlite3 driver so a
    connection may be used from a thread other than the one that opened it.

    SQLite scopes ``PRAGMA foreign_keys`` to a single connection, so the
    pragma is issued from a ``connect`` event hook: every DBAPI connection
    the engine opens gets foreign-key enforcement, not only the first one.

    Returns:
        A new ``Engine``; callers are expected to cache it.
    """
    from sqlalchemy import event

    url = f"sqlite:///{config.DB_PATH}"
    engine = create_engine(url, connect_args={"check_same_thread": False})

    @event.listens_for(engine, "connect")
    def _enable_foreign_keys(dbapi_connection, connection_record):
        # Runs once per freshly opened DBAPI connection.
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute("PRAGMA foreign_keys=ON")
        finally:
            cursor.close()

    return engine
# Module-wide engine, built lazily on first use.
_ENGINE = None


def _get_engine():
    """Return the shared engine, creating it via ``_engine()`` on first call."""
    global _ENGINE
    if _ENGINE is not None:
        return _ENGINE
    _ENGINE = _engine()
    return _ENGINE
def init_db():
    """Issue the SQLite pragmas and create every table declared on ``Base``.

    The pragmas are executed on a single short-lived connection obtained
    from the shared engine; ``create_all`` then emits the DDL for any table
    in ``Base.metadata`` that does not exist yet.
    """
    engine = _get_engine()
    pragmas = (
        "PRAGMA journal_mode=WAL",
        "PRAGMA foreign_keys=ON",
    )
    with engine.connect() as conn:
        for statement in pragmas:
            conn.execute(text(statement))
    Base.metadata.create_all(engine)
def _session() -> Session:
    """Open a new ORM session bound to the shared engine.

    Callers use the result as a context manager so it is closed on exit.
    """
    engine = _get_engine()
    return Session(bind=engine)
def insert_filing(filing: dict) -> bool:
    """Persist one filing unless its accession number is already stored.

    Args:
        filing: Mapping with a required ``"accession_number"`` key; all
            other columns are read with ``dict.get`` and default to ``None``
            (``is_10b51`` defaults to ``False`` and is coerced to ``bool``).

    Returns:
        ``True`` if a row was inserted, ``False`` if a filing with the same
        accession number already existed.

    Raises:
        KeyError: If ``filing`` has no ``"accession_number"`` key.
    """
    accession = filing["accession_number"]
    # Columns copied straight from the mapping, missing keys become None.
    passthrough = (
        "ticker",
        "cik",
        "insider_name",
        "role",
        "transaction_date",
        "filed_date",
        "shares",
        "price",
        "total_value",
        "flag",
        "post_tx_shares",
    )
    with _session() as session:
        lookup = select(Filing.id).where(Filing.accession_number == accession)
        if session.scalar(lookup) is not None:
            return False
        columns = {name: filing.get(name) for name in passthrough}
        columns["is_10b51"] = bool(filing.get("is_10b51", False))
        session.add(Filing(accession_number=accession, **columns))
        session.commit()
    return True
def accession_exists(accession_number: str) -> bool:
    """Report whether a filing with this accession number is already stored."""
    lookup = select(Filing.id).where(Filing.accession_number == accession_number)
    with _session() as session:
        found_id = session.scalar(lookup)
    return found_id is not None
def get_latest_filed_date() -> Optional[str]:
    """Return the maximum ``filed_date`` across all filings.

    The result is whatever ``MAX(filed_date)`` yields, which is ``None``
    when the table holds no non-null ``filed_date``.
    """
    newest = select(func.max(Filing.filed_date))
    with _session() as session:
        latest = session.scalar(newest)
    return latest
def insert_signal(signal: dict) -> int:
    """Insert a new signal row and return its primary key.

    Args:
        signal: Mapping with required keys ``"ticker"``, ``"trigger_date"``,
            ``"cluster_size"`` and ``"score"``; ``"total_cluster_value"`` is
            optional and defaults to ``0.0``.

    Returns:
        The ``id`` of the inserted row.

    Raises:
        KeyError: If a required key is missing from ``signal``.
    """
    with _session() as session:
        record = Signal(
            ticker=signal["ticker"],
            trigger_date=signal["trigger_date"],
            cluster_size=signal["cluster_size"],
            total_cluster_value=signal.get("total_cluster_value", 0.0),
            score=signal["score"],
        )
        session.add(record)
        session.commit()
        # Read the generated key while the session is still open.
        new_id = record.id
        return new_id
def mark_signal_alerted(signal_id: int):
    """Set ``alerted=True`` on the signal with the given id and commit."""
    flag_alerted = (
        update(Signal)
        .where(Signal.id == signal_id)
        .values(alerted=True)
    )
    with _session() as session:
        session.execute(flag_alerted)
        session.commit()
def mark_signal_executed(signal_id: int):
    """Flag a signal as executed and stamp ``executed_at`` with the current UTC time.

    Args:
        signal_id: Primary key of the signal to update.
    """
    from datetime import timezone

    # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC
    # "now" and drop tzinfo so the stored value stays a naive UTC datetime,
    # exactly as before.
    executed_at = datetime.now(timezone.utc).replace(tzinfo=None)
    with _session() as session:
        session.execute(
            update(Signal)
            .where(Signal.id == signal_id)
            .values(executed=True, executed_at=executed_at)
        )
        session.commit()
def mark_signal_closed(signal_id: int):
    """Set ``closed=True`` on the signal with the given id and commit."""
    flag_closed = (
        update(Signal)
        .where(Signal.id == signal_id)
        .values(closed=True)
    )
    with _session() as session:
        session.execute(flag_closed)
        session.commit()
def get_unalerted_signals() -> list[dict]:
    """Return every signal with ``alerted`` false, ordered by ``created_at``.

    Each row is converted to a plain dict by ``_signal_to_dict``.
    """
    # ``== False`` is intentional: it builds a SQL comparison, not a Python bool.
    pending = (
        select(Signal)
        .where(Signal.alerted == False)  # noqa: E712
        .order_by(Signal.created_at)
    )
    with _session() as session:
        signals = session.scalars(pending).all()
        return list(map(_signal_to_dict, signals))
def get_executed_unclosed_signals() -> list[dict]:
    """Return signals that are executed, not closed, and have ``executed_at`` set.

    Each row is converted to a plain dict by ``_signal_to_dict``.
    """
    # ``== True`` / ``== False`` build SQL comparisons, not Python booleans.
    open_positions = (
        select(Signal)
        .where(Signal.executed == True)  # noqa: E712
        .where(Signal.closed == False)  # noqa: E712
        .where(Signal.executed_at.is_not(None))
    )
    with _session() as session:
        signals = session.scalars(open_positions).all()
        return list(map(_signal_to_dict, signals))
def get_recent_buys_for_ticker(ticker: str, window_days: int) -> list[dict]:
    """Return recent flag-``"A"``, non-10b5-1 filings for a ticker, newest first.

    Args:
        ticker: Ticker symbol to match exactly.
        window_days: Look-back window; only filings whose
            ``transaction_date`` is on or after (UTC today - window_days)
            are returned.

    Returns:
        Filing dicts (see ``_filing_to_dict``) ordered by
        ``transaction_date`` descending.
    """
    from datetime import timedelta, timezone

    # datetime.utcnow() is deprecated since Python 3.12; an aware UTC "now"
    # formats to the same YYYY-MM-DD cutoff string.
    now_utc = datetime.now(timezone.utc)
    cutoff = (now_utc - timedelta(days=window_days)).strftime("%Y-%m-%d")
    # NOTE(review): the >= comparison is against a YYYY-MM-DD string, so it
    # presumably relies on transaction_date being stored in that format.
    with _session() as session:
        rows = session.scalars(
            select(Filing)
            .where(
                Filing.ticker == ticker,
                Filing.flag == "A",
                Filing.is_10b51 == False,  # noqa: E712 - SQL expression
                Filing.transaction_date >= cutoff,
            )
            .order_by(Filing.transaction_date.desc())
        ).all()
        return [_filing_to_dict(r) for r in rows]
def get_signals_for_backtest(min_score: float, min_cluster_size: int) -> list[dict]:
    """Return signals meeting both the score and cluster-size thresholds.

    Args:
        min_score: Inclusive lower bound on ``Signal.score``.
        min_cluster_size: Inclusive lower bound on ``Signal.cluster_size``.

    Returns:
        Matching signals as dicts (see ``_signal_to_dict``).
    """
    qualifying = (
        select(Signal)
        .where(Signal.score >= min_score)
        .where(Signal.cluster_size >= min_cluster_size)
    )
    with _session() as session:
        signals = session.scalars(qualifying).all()
        return list(map(_signal_to_dict, signals))
def get_cached_prices(ticker: str, start_date: str, end_date: str) -> dict[str, float]:
    """Look up cached closes for a ticker within an inclusive date range.

    Args:
        ticker: Ticker symbol to match exactly.
        start_date: Inclusive lower bound compared against ``PriceCache.date``.
        end_date: Inclusive upper bound compared against ``PriceCache.date``.

    Returns:
        Mapping of ``date`` to ``close`` for every matching cache row.
    """
    in_range = (
        select(PriceCache)
        .where(PriceCache.ticker == ticker)
        .where(PriceCache.date >= start_date)
        .where(PriceCache.date <= end_date)
    )
    closes = {}
    with _session() as session:
        for entry in session.scalars(in_range).all():
            closes[entry.date] = entry.close
    return closes
def upsert_prices(ticker: str, prices: dict[str, float]):
    """Cache closes for a ticker, inserting only dates not stored yet.

    Rows that already exist for a ``(ticker, date)`` pair are left untouched;
    their ``close`` is not overwritten.

    Args:
        ticker: Ticker symbol the prices belong to.
        prices: Mapping of date to close price.
    """
    if not prices:
        # Nothing to write; skip opening a session altogether.
        return

    with _session() as session:
        # One range query replaces a SELECT per date. Every key of ``prices``
        # lies within [min, max], so any row that would collide is returned
        # here; unrelated dates in the range are simply ignored below.
        already_cached = set(
            session.scalars(
                select(PriceCache.date).where(
                    PriceCache.ticker == ticker,
                    PriceCache.date >= min(prices),
                    PriceCache.date <= max(prices),
                )
            )
        )
        session.add_all(
            [
                PriceCache(ticker=ticker, date=date_str, close=close)
                for date_str, close in prices.items()
                if date_str not in already_cached
            ]
        )
        session.commit()
def _filing_to_dict(row: Filing) -> dict:
    """Convert a ``Filing`` ORM row into a plain dict.

    Column values are copied as-is, except ``created_at`` which becomes an
    ISO-8601 string (or ``None`` when unset).
    """
    copied_columns = (
        "id",
        "accession_number",
        "ticker",
        "cik",
        "insider_name",
        "role",
        "transaction_date",
        "filed_date",
        "shares",
        "price",
        "total_value",
        "flag",
        "is_10b51",
        "post_tx_shares",
    )
    result = {name: getattr(row, name) for name in copied_columns}
    created = row.created_at
    result["created_at"] = created.isoformat() if created else None
    return result
def _signal_to_dict(row: Signal) -> dict:
    """Convert a ``Signal`` ORM row into a plain dict.

    Column values are copied as-is, with two exceptions: ``executed_at`` is
    formatted as ``YYYY-MM-DDTHH:MM:SSZ`` and ``created_at`` as an ISO-8601
    string; either is ``None`` when the column is unset.
    """
    leading_columns = (
        "id",
        "ticker",
        "trigger_date",
        "cluster_size",
        "total_cluster_value",
        "score",
        "alerted",
        "executed",
    )
    result = {name: getattr(row, name) for name in leading_columns}

    executed = row.executed_at
    if executed:
        result["executed_at"] = executed.strftime("%Y-%m-%dT%H:%M:%SZ")
    else:
        result["executed_at"] = None

    result["closed"] = row.closed

    created = row.created_at
    result["created_at"] = created.isoformat() if created else None
    return result