smaug/db/db.py
Dominik Roth d0e98b9cb7 feat: cap-tier filtering, Alpaca cost model, README cleanup
- simulate.py: --cap-tier large|mid|small|micro; yfinance market cap fetch
  with DB cache (ticker_meta table); argv fix for main.py dispatch
- plot.py: equity curves now show cap tiers with Alpaca costs (zero commission);
  HP sweep uses Alpaca cost decomposition; SPY line clamped to last strategy date
- db/models.py: TickerMeta table
- db/db.py: get_cached_market_caps, upsert_market_caps
- README: add --cap-tier to simulate docs; backfill note (~3 days for 2 years
  at SEC 10 req/s limit); remove duplicate setup block; remove em-dashes in prose;
  results table tilde estimates to be updated once cap-tier sims complete

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-26 18:10:09 +02:00

304 lines
9.4 KiB
Python

from datetime import datetime, timedelta, timezone
from typing import Optional

from sqlalchemy import create_engine, func, select, text, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

import config
from db.models import Base, Filing, PriceCache, Signal, TickerMeta
def _engine():
    """Build a SQLAlchemy engine for the SQLite file at ``config.DB_PATH``.

    ``check_same_thread=False`` is forwarded to the SQLite driver.

    A ``connect`` listener issues ``PRAGMA foreign_keys=ON`` on every new DBAPI
    connection. SQLite scopes that pragma to the connection it is executed on,
    so running it once on a single connection (as ``init_db`` does) does not
    cover connections that are opened later.

    Returns:
        The newly created engine. Callers normally go through ``_get_engine``,
        which caches the result.
    """
    # Imported locally so the listener setup is self-contained in this helper.
    from sqlalchemy import event

    url = f"sqlite:///{config.DB_PATH}"
    engine = create_engine(url, connect_args={"check_same_thread": False})

    @event.listens_for(engine, "connect")
    def _enable_foreign_keys(dbapi_connection, _connection_record):
        # Runs on the raw DBAPI connection before it is handed to the pool.
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute("PRAGMA foreign_keys=ON")
        finally:
            cursor.close()

    return engine
# Process-wide engine, created on first use by ``_get_engine``.
_ENGINE = None


def _get_engine():
    """Return the shared engine, building it with ``_engine()`` on first call."""
    global _ENGINE
    if _ENGINE is not None:
        return _ENGINE
    _ENGINE = _engine()
    return _ENGINE
def init_db():
    """Apply the SQLite pragmas and create every table declared on ``Base``.

    NOTE(review): SQLite applies ``foreign_keys`` per connection, so the
    pragma issued here presumably only affects the connection opened in this
    function. Confirm whether other connections need it as well.
    """
    engine = _get_engine()
    startup_pragmas = ("PRAGMA journal_mode=WAL", "PRAGMA foreign_keys=ON")
    with engine.connect() as conn:
        for pragma in startup_pragmas:
            conn.execute(text(pragma))
    # Creates any tables that are missing.
    Base.metadata.create_all(engine)
def _session() -> Session:
    """Open a new ORM session bound to the shared module-level engine."""
    engine = _get_engine()
    return Session(engine)
def insert_filing(filing: dict) -> bool:
    """Insert one filing row unless its accession number is already stored.

    Args:
        filing: Mapping that must contain ``"accession_number"``. Every other
            column is read with ``dict.get`` and is ``None`` when absent;
            ``"is_10b51"`` is coerced to ``bool`` and defaults to ``False``.

    Returns:
        ``True`` when a new row was committed. ``False`` when the accession
        number was already present or the commit raised ``IntegrityError``.
    """
    nullable_columns = (
        "ticker",
        "cik",
        "insider_name",
        "role",
        "transaction_date",
        "filed_date",
        "shares",
        "price",
        "total_value",
        "flag",
        "post_tx_shares",
    )
    with _session() as session:
        accession = filing["accession_number"]
        already_stored = session.scalar(
            select(Filing.id).where(Filing.accession_number == accession)
        )
        if already_stored is not None:
            return False

        values = {column: filing.get(column) for column in nullable_columns}
        values["is_10b51"] = bool(filing.get("is_10b51", False))
        session.add(Filing(accession_number=accession, **values))
        try:
            session.commit()
        except IntegrityError:
            # A constraint violation at commit time is reported as "not inserted".
            session.rollback()
            return False
        return True
def accession_exists(accession_number: str) -> bool:
    """Return whether a filing row with this accession number is stored."""
    lookup = select(Filing.id).where(Filing.accession_number == accession_number)
    with _session() as session:
        found_id = session.scalar(lookup)
        return found_id is not None
def mark_accession_seen(accession_number: str) -> None:
    """Record a placeholder filing so derivative-only/empty filings aren't re-fetched.

    The placeholder carries only the accession number. Nothing is written when
    a row with that accession number already exists, and an ``IntegrityError``
    on commit is rolled back silently.
    """
    lookup = select(Filing.id).where(Filing.accession_number == accession_number)
    with _session() as session:
        if session.scalar(lookup) is not None:
            return
        placeholder = Filing(accession_number=accession_number)
        session.add(placeholder)
        try:
            session.commit()
        except IntegrityError:
            session.rollback()
def filter_new_accessions(accessions: list[str]) -> set[str]:
    """Return the accession numbers from ``accessions`` that are not in the DB yet.

    An empty input returns an empty set without opening a session.
    """
    if not accessions:
        return set()

    # Keep each IN (...) list under SQLite's SQLITE_MAX_VARIABLE_NUMBER limit.
    batch_size = 900
    batches = [
        accessions[start:start + batch_size]
        for start in range(0, len(accessions), batch_size)
    ]

    known: set[str] = set()
    with _session() as session:
        for batch in batches:
            stmt = select(Filing.accession_number).where(
                Filing.accession_number.in_(batch)
            )
            known.update(session.execute(stmt).scalars().all())
    return set(accessions).difference(known)
def get_latest_filed_date() -> Optional[str]:
    """Return the maximum ``filed_date`` over all stored filings, if any."""
    newest_filed = select(func.max(Filing.filed_date))
    with _session() as session:
        return session.scalar(newest_filed)
def insert_signal(signal: dict) -> int:
    """Persist a new signal and return its primary key.

    Args:
        signal: Mapping with required keys ``"ticker"``, ``"trigger_date"``,
            ``"cluster_size"`` and ``"score"``. ``"total_cluster_value"`` is
            optional and defaults to ``0.0``.

    Returns:
        The ``id`` of the committed row.
    """
    with _session() as session:
        values = {
            "ticker": signal["ticker"],
            "trigger_date": signal["trigger_date"],
            "cluster_size": signal["cluster_size"],
            "total_cluster_value": signal.get("total_cluster_value", 0.0),
            "score": signal["score"],
        }
        new_signal = Signal(**values)
        session.add(new_signal)
        session.commit()
        # Read the id while the session is still open.
        signal_id = new_signal.id
        return signal_id
def mark_signal_alerted(signal_id: int):
    """Set ``alerted=True`` on the signal with the given id and commit."""
    flag_alerted = update(Signal).where(Signal.id == signal_id).values(alerted=True)
    with _session() as session:
        session.execute(flag_alerted)
        session.commit()
def mark_signal_executed(signal_id: int):
    """Set ``executed=True`` and stamp ``executed_at`` on the given signal.

    ``executed_at`` is written as a naive UTC datetime, the same value the
    deprecated ``datetime.utcnow()`` produced. ``_signal_to_dict`` appends a
    literal ``Z`` when serialising this column, so it must stay in UTC.

    Args:
        signal_id: Primary key of the signal to update.
    """
    # utcnow() is deprecated since Python 3.12; dropping tzinfo keeps the stored value naive.
    executed_at = datetime.now(timezone.utc).replace(tzinfo=None)
    with _session() as session:
        session.execute(
            update(Signal)
            .where(Signal.id == signal_id)
            .values(executed=True, executed_at=executed_at)
        )
        session.commit()
def mark_signal_closed(signal_id: int):
    """Set ``closed=True`` on the signal with the given id and commit."""
    flag_closed = update(Signal).where(Signal.id == signal_id).values(closed=True)
    with _session() as session:
        session.execute(flag_closed)
        session.commit()
def get_unalerted_signals() -> list[dict]:
    """Return signals whose ``alerted`` flag is false, ordered by ``created_at``."""
    pending = (
        select(Signal)
        .where(Signal.alerted == False)  # noqa: E712 - builds a SQL comparison
        .order_by(Signal.created_at)
    )
    with _session() as session:
        signals = session.scalars(pending).all()
        return [_signal_to_dict(signal) for signal in signals]
def get_executed_unclosed_signals() -> list[dict]:
    """Return signals that are executed, not closed, and have ``executed_at`` set."""
    open_positions = select(Signal).where(
        Signal.executed == True,  # noqa: E712 - builds a SQL comparison
        Signal.closed == False,  # noqa: E712
        Signal.executed_at.is_not(None),
    )
    with _session() as session:
        signals = session.scalars(open_positions).all()
        return [_signal_to_dict(signal) for signal in signals]
def get_recent_buys_for_ticker(
    ticker: str, window_days: int, as_of_date: Optional[str] = None
) -> list[dict]:
    """Return filings for ``ticker`` flagged ``"A"`` inside a trailing date window.

    Filings with ``is_10b51`` set are excluded. The window is inclusive on both
    ends and is compared against ``transaction_date`` as ``YYYY-MM-DD`` strings.

    Args:
        ticker: Ticker symbol to match exactly.
        window_days: Number of days before the reference date to include.
        as_of_date: Reference date as ``YYYY-MM-DD``. When omitted (or empty),
            the current UTC date is used.

    Returns:
        Matching filings as dicts, newest ``transaction_date`` first.

    Raises:
        ValueError: If ``as_of_date`` is given but is not ``YYYY-MM-DD``.
    """
    if as_of_date:
        ref = datetime.strptime(as_of_date, "%Y-%m-%d")
    else:
        # Timezone-aware replacement for the deprecated datetime.utcnow();
        # only the formatted date is used below.
        ref = datetime.now(timezone.utc)
    cutoff = (ref - timedelta(days=window_days)).strftime("%Y-%m-%d")
    ref_str = ref.strftime("%Y-%m-%d")

    with _session() as session:
        rows = session.scalars(
            select(Filing)
            .where(
                Filing.ticker == ticker,
                Filing.flag == "A",
                Filing.is_10b51 == False,  # noqa: E712 - builds a SQL comparison
                Filing.transaction_date >= cutoff,
                Filing.transaction_date <= ref_str,
            )
            .order_by(Filing.transaction_date.desc())
        ).all()
        return [_filing_to_dict(r) for r in rows]
def get_all_buys_for_reprocess() -> list[dict]:
    """Return every filing flagged ``"A"``, ordered by ``transaction_date``.

    Used for signal reprocessing after a bulk ingest.
    """
    all_buys = (
        select(Filing)
        .where(Filing.flag == "A")
        .order_by(Filing.transaction_date)
    )
    with _session() as session:
        filings = session.scalars(all_buys).all()
        return [_filing_to_dict(filing) for filing in filings]
def get_signals_for_backtest(min_score: float, min_cluster_size: int) -> list[dict]:
    """Return signals with ``score >= min_score`` and ``cluster_size >= min_cluster_size``."""
    qualifying = select(Signal).where(
        Signal.score >= min_score,
        Signal.cluster_size >= min_cluster_size,
    )
    with _session() as session:
        signals = session.scalars(qualifying).all()
        return [_signal_to_dict(signal) for signal in signals]
def get_cached_market_caps(tickers: list[str]) -> dict[str, float]:
    """Return cached market caps for the given tickers.

    The lookup is issued in chunks so that a large ticker universe cannot
    exceed SQLite's bound-parameter limit. ``filter_new_accessions`` uses the
    same chunk size for the same reason.

    Args:
        tickers: Ticker symbols to look up. Duplicates are collapsed before
            querying. An empty input returns ``{}`` without touching the DB.

    Returns:
        Mapping of ticker to market cap. Tickers with no cached row, or whose
        cached ``market_cap`` is ``None``, are omitted.
    """
    if not tickers:
        return {}

    # De-duplicate (order-preserving) and materialise so any iterable can be sliced.
    unique_tickers = list(dict.fromkeys(tickers))
    chunk_size = 900  # SQLite SQLITE_MAX_VARIABLE_NUMBER limit

    caps: dict[str, float] = {}
    with _session() as session:
        for start in range(0, len(unique_tickers), chunk_size):
            chunk = unique_tickers[start:start + chunk_size]
            rows = session.scalars(
                select(TickerMeta).where(TickerMeta.ticker.in_(chunk))
            ).all()
            caps.update(
                (row.ticker, row.market_cap)
                for row in rows
                if row.market_cap is not None
            )
    return caps
def upsert_market_caps(caps: dict[str, float]) -> None:
    """Insert or update the cached market cap for each ticker in ``caps``.

    Existing rows get the new ``market_cap`` and a refreshed ``fetched_at``.
    New rows are created with only ``ticker`` and ``market_cap``; presumably
    the model supplies a default for ``fetched_at`` (confirm in db/models.py).
    All changes are committed together.

    Args:
        caps: Mapping of ticker to market cap. An empty mapping is a no-op and
            does not open a session.
    """
    if not caps:
        return

    # One timestamp for the whole batch. Naive UTC, i.e. the same value the
    # deprecated datetime.utcnow() would have produced.
    fetched_at = datetime.now(timezone.utc).replace(tzinfo=None)

    with _session() as session:
        for ticker, cap in caps.items():
            existing = session.get(TickerMeta, ticker)
            if existing is not None:
                existing.market_cap = cap
                existing.fetched_at = fetched_at
            else:
                session.add(TickerMeta(ticker=ticker, market_cap=cap))
        session.commit()
def get_cached_prices(ticker: str, start_date: str, end_date: str) -> dict[str, float]:
    """Return cached closes for ``ticker`` with ``start_date <= date <= end_date``.

    Returns:
        Mapping of each cached row's ``date`` to its ``close``.
    """
    in_range = select(PriceCache).where(
        PriceCache.ticker == ticker,
        PriceCache.date >= start_date,
        PriceCache.date <= end_date,
    )
    with _session() as session:
        cached_rows = session.scalars(in_range).all()
        return {entry.date: entry.close for entry in cached_rows}
def upsert_prices(ticker: str, prices: dict[str, float]) -> None:
    """Add cache rows for the dates in ``prices`` that are not cached yet.

    Despite the name this is insert-only: a date that already has a row for
    ``ticker`` keeps its stored close. Existing dates are looked up in chunked
    bulk queries instead of one SELECT per date.

    Args:
        ticker: Ticker symbol the prices belong to.
        prices: Mapping of date string to closing price. An empty mapping is a
            no-op and does not open a session.
    """
    if not prices:
        return

    dates = list(prices)
    # Keep each IN (...) list (plus the ticker parameter) under SQLite's
    # SQLITE_MAX_VARIABLE_NUMBER limit.
    chunk_size = 900

    with _session() as session:
        already_cached: set[str] = set()
        for start in range(0, len(dates), chunk_size):
            chunk = dates[start:start + chunk_size]
            already_cached.update(
                session.scalars(
                    select(PriceCache.date).where(
                        PriceCache.ticker == ticker,
                        PriceCache.date.in_(chunk),
                    )
                ).all()
            )

        for date_str, close in prices.items():
            if date_str not in already_cached:
                session.add(PriceCache(ticker=ticker, date=date_str, close=close))
        session.commit()
def _filing_to_dict(row: Filing) -> dict:
    """Convert a ``Filing`` row to a plain dict.

    Every column is copied as-is except ``created_at``, which becomes an ISO
    8601 string (or ``None`` when unset).
    """
    copied_columns = (
        "id",
        "accession_number",
        "ticker",
        "cik",
        "insider_name",
        "role",
        "transaction_date",
        "filed_date",
        "shares",
        "price",
        "total_value",
        "flag",
        "is_10b51",
        "post_tx_shares",
    )
    result = {column: getattr(row, column) for column in copied_columns}
    created = row.created_at
    result["created_at"] = created.isoformat() if created else None
    return result
def _signal_to_dict(row: Signal) -> dict:
    """Convert a ``Signal`` row to a plain dict.

    ``executed_at`` is rendered as ``YYYY-MM-DDTHH:MM:SSZ`` and ``created_at``
    as an ISO 8601 string; either is ``None`` when unset. All other columns
    are copied as-is.
    """
    executed_at = row.executed_at
    if executed_at:
        executed_text = executed_at.strftime("%Y-%m-%dT%H:%M:%SZ")
    else:
        executed_text = None

    created_at = row.created_at
    if created_at:
        created_text = created_at.isoformat()
    else:
        created_text = None

    result = {
        column: getattr(row, column)
        for column in (
            "id",
            "ticker",
            "trigger_date",
            "cluster_size",
            "total_cluster_value",
            "score",
            "alerted",
            "executed",
        )
    }
    result["executed_at"] = executed_text
    result["closed"] = row.closed
    result["created_at"] = created_text
    return result