"""SQLAlchemy persistence helpers for the SQLite database at ``config.DB_PATH``.

Every public function opens its own short-lived ``Session`` and returns plain
Python values (bools, ints, strings, dicts) rather than ORM rows, so callers
never hold objects bound to a closed session.
"""
from datetime import datetime, timedelta, timezone
from typing import Optional

from sqlalchemy import create_engine, event, func, select, text, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

import config
from db.models import Base, Filing, PriceCache, Signal

# Keep IN (...) lists below SQLite's SQLITE_MAX_VARIABLE_NUMBER limit.
_SQLITE_CHUNK_SIZE = 900

# Lazily created, process-wide engine (see _get_engine).
_ENGINE = None


def _enable_foreign_keys(dbapi_connection, connection_record) -> None:
    """Turn on foreign-key enforcement for a freshly opened DBAPI connection.

    SQLite applies ``PRAGMA foreign_keys`` per connection, so it has to be
    issued every time the pool opens one rather than once at init time.
    """
    cursor = dbapi_connection.cursor()
    try:
        cursor.execute("PRAGMA foreign_keys=ON")
    finally:
        cursor.close()


def _engine():
    """Create a new engine for the SQLite file at ``config.DB_PATH``."""
    url = f"sqlite:///{config.DB_PATH}"
    engine = create_engine(url, connect_args={"check_same_thread": False})
    event.listen(engine, "connect", _enable_foreign_keys)
    return engine


def _get_engine():
    """Return the shared engine, creating it on first use."""
    global _ENGINE
    if _ENGINE is None:
        _ENGINE = _engine()
    return _ENGINE


def init_db() -> None:
    """Switch the database to WAL journaling and create any missing tables."""
    engine = _get_engine()
    with engine.connect() as conn:
        # journal_mode=WAL is stored in the database file, so once is enough;
        # foreign_keys is handled per connection by _enable_foreign_keys.
        conn.execute(text("PRAGMA journal_mode=WAL"))
    Base.metadata.create_all(engine)


def _session() -> Session:
    """Open a new session bound to the shared engine."""
    return Session(_get_engine())


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Same value ``datetime.utcnow()`` produced, without calling the API that
    is deprecated since Python 3.12.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _accession_in_db(session: Session, accession_number: str) -> bool:
    """Return True if a Filing row with this accession number exists."""
    found = session.scalar(
        select(Filing.id).where(Filing.accession_number == accession_number)
    )
    return found is not None


def insert_filing(filing: dict) -> bool:
    """Insert one filing; return True if a row was written.

    ``filing`` must contain ``"accession_number"``; every other field is
    optional and defaults to ``None`` (``is_10b51`` defaults to ``False``).
    Returns False when the accession number is already stored, including when
    a concurrent writer wins the race and the commit raises ``IntegrityError``.
    """
    with _session() as session:
        if _accession_in_db(session, filing["accession_number"]):
            return False
        row = Filing(
            accession_number=filing["accession_number"],
            ticker=filing.get("ticker"),
            cik=filing.get("cik"),
            insider_name=filing.get("insider_name"),
            role=filing.get("role"),
            transaction_date=filing.get("transaction_date"),
            filed_date=filing.get("filed_date"),
            shares=filing.get("shares"),
            price=filing.get("price"),
            total_value=filing.get("total_value"),
            flag=filing.get("flag"),
            is_10b51=bool(filing.get("is_10b51", False)),
            post_tx_shares=filing.get("post_tx_shares"),
        )
        session.add(row)
        try:
            session.commit()
        except IntegrityError:
            session.rollback()
            return False
        return True


def accession_exists(accession_number: str) -> bool:
    """Return True if a filing with this accession number is already stored."""
    with _session() as session:
        return _accession_in_db(session, accession_number)


def mark_accession_seen(accession_number: str) -> None:
    """Store a placeholder row so derivative-only/empty filings aren't re-fetched."""
    with _session() as session:
        if _accession_in_db(session, accession_number):
            return
        session.add(Filing(accession_number=accession_number))
        try:
            session.commit()
        except IntegrityError:
            # Another writer inserted the same accession first; nothing to do.
            session.rollback()


def filter_new_accessions(accessions: list[str]) -> set[str]:
    """Return the subset of accessions not already in the DB."""
    if not accessions:
        return set()
    existing: set[str] = set()
    with _session() as session:
        for start in range(0, len(accessions), _SQLITE_CHUNK_SIZE):
            chunk = accessions[start:start + _SQLITE_CHUNK_SIZE]
            existing.update(
                session.scalars(
                    select(Filing.accession_number).where(
                        Filing.accession_number.in_(chunk)
                    )
                )
            )
    return set(accessions) - existing


def get_latest_filed_date() -> Optional[str]:
    """Return the maximum ``filed_date`` across all filings, or None if empty."""
    with _session() as session:
        return session.scalar(select(func.max(Filing.filed_date)))


def insert_signal(signal: dict) -> int:
    """Insert a signal row and return its primary key.

    Requires ``ticker``, ``trigger_date``, ``cluster_size`` and ``score``;
    ``total_cluster_value`` defaults to ``0.0``.
    """
    with _session() as session:
        row = Signal(
            ticker=signal["ticker"],
            trigger_date=signal["trigger_date"],
            cluster_size=signal["cluster_size"],
            total_cluster_value=signal.get("total_cluster_value", 0.0),
            score=signal["score"],
        )
        session.add(row)
        session.commit()
        return row.id


def _update_signal(signal_id: int, **values) -> None:
    """Apply ``values`` to the signal with the given id and commit."""
    with _session() as session:
        session.execute(
            update(Signal).where(Signal.id == signal_id).values(**values)
        )
        session.commit()


def mark_signal_alerted(signal_id: int) -> None:
    """Set ``alerted`` on the given signal."""
    _update_signal(signal_id, alerted=True)


def mark_signal_executed(signal_id: int) -> None:
    """Set ``executed`` and stamp ``executed_at`` with the current UTC time."""
    _update_signal(signal_id, executed=True, executed_at=_utcnow())


def mark_signal_closed(signal_id: int) -> None:
    """Set ``closed`` on the given signal."""
    _update_signal(signal_id, closed=True)


def get_unalerted_signals() -> list[dict]:
    """Return signals with ``alerted`` false, ordered by ``created_at``."""
    with _session() as session:
        rows = session.scalars(
            select(Signal)
            .where(Signal.alerted == False)  # noqa: E712 - SQL expression, not a Python bool test
            .order_by(Signal.created_at)
        ).all()
        return [_signal_to_dict(r) for r in rows]


def get_executed_unclosed_signals() -> list[dict]:
    """Return signals that are executed, not closed, and have ``executed_at`` set."""
    with _session() as session:
        rows = session.scalars(
            select(Signal).where(
                Signal.executed == True,  # noqa: E712 - SQL expression
                Signal.closed == False,  # noqa: E712 - SQL expression
                Signal.executed_at.is_not(None),
            )
        ).all()
        return [_signal_to_dict(r) for r in rows]


def get_recent_buys_for_ticker(
    ticker: str, window_days: int, as_of_date: Optional[str] = None
) -> list[dict]:
    """Return buy filings for ``ticker`` inside a trailing date window.

    The window is ``[as_of_date - window_days, as_of_date]`` inclusive, where
    ``as_of_date`` is a ``YYYY-MM-DD`` string and defaults to today in UTC.
    Only rows with ``flag == "A"`` and ``is_10b51`` false are returned, newest
    ``transaction_date`` first. Dates are compared as ``YYYY-MM-DD`` strings.
    """
    ref = datetime.strptime(as_of_date, "%Y-%m-%d") if as_of_date else _utcnow()
    cutoff = (ref - timedelta(days=window_days)).strftime("%Y-%m-%d")
    ref_str = ref.strftime("%Y-%m-%d")
    with _session() as session:
        rows = session.scalars(
            select(Filing)
            .where(
                Filing.ticker == ticker,
                Filing.flag == "A",
                Filing.is_10b51 == False,  # noqa: E712 - SQL expression
                Filing.transaction_date >= cutoff,
                Filing.transaction_date <= ref_str,
            )
            .order_by(Filing.transaction_date.desc())
        ).all()
        return [_filing_to_dict(r) for r in rows]


def get_all_buys_for_reprocess() -> list[dict]:
    """Return all buy filings for signal reprocessing (used after bulk ingest)."""
    with _session() as session:
        rows = session.scalars(
            select(Filing)
            .where(Filing.flag == "A")
            .order_by(Filing.transaction_date)
        ).all()
        return [_filing_to_dict(r) for r in rows]


def get_signals_for_backtest(min_score: float, min_cluster_size: int) -> list[dict]:
    """Return signals whose score and cluster size meet both minimums."""
    with _session() as session:
        rows = session.scalars(
            select(Signal).where(
                Signal.score >= min_score,
                Signal.cluster_size >= min_cluster_size,
            )
        ).all()
        return [_signal_to_dict(r) for r in rows]


def get_cached_prices(ticker: str, start_date: str, end_date: str) -> dict[str, float]:
    """Return ``{date: close}`` for ``ticker`` between the two dates, inclusive."""
    with _session() as session:
        rows = session.scalars(
            select(PriceCache).where(
                PriceCache.ticker == ticker,
                PriceCache.date >= start_date,
                PriceCache.date <= end_date,
            )
        ).all()
        return {r.date: r.close for r in rows}


def upsert_prices(ticker: str, prices: dict[str, float]) -> None:
    """Cache closing prices for ``ticker``, skipping dates already stored.

    ``prices`` maps date strings to closes. Rows that already exist for a
    (ticker, date) pair are left untouched; only missing dates are inserted.
    Existing dates are looked up in chunked ``IN`` queries instead of one
    SELECT per price.
    """
    if not prices:
        return
    dates = list(prices)
    with _session() as session:
        cached: set[str] = set()
        for start in range(0, len(dates), _SQLITE_CHUNK_SIZE):
            chunk = dates[start:start + _SQLITE_CHUNK_SIZE]
            cached.update(
                session.scalars(
                    select(PriceCache.date).where(
                        PriceCache.ticker == ticker,
                        PriceCache.date.in_(chunk),
                    )
                )
            )
        session.add_all(
            [
                PriceCache(ticker=ticker, date=date_str, close=close)
                for date_str, close in prices.items()
                if date_str not in cached
            ]
        )
        session.commit()


def _filing_to_dict(row: Filing) -> dict:
    """Convert a Filing row to a plain dict; ``created_at`` becomes ISO-8601 or None."""
    return {
        "id": row.id,
        "accession_number": row.accession_number,
        "ticker": row.ticker,
        "cik": row.cik,
        "insider_name": row.insider_name,
        "role": row.role,
        "transaction_date": row.transaction_date,
        "filed_date": row.filed_date,
        "shares": row.shares,
        "price": row.price,
        "total_value": row.total_value,
        "flag": row.flag,
        "is_10b51": row.is_10b51,
        "post_tx_shares": row.post_tx_shares,
        "created_at": row.created_at.isoformat() if row.created_at else None,
    }


def _signal_to_dict(row: Signal) -> dict:
    """Convert a Signal row to a plain dict.

    ``executed_at`` is formatted as ``YYYY-MM-DDTHH:MM:SSZ`` and ``created_at``
    as ISO-8601; both are None when unset.
    """
    return {
        "id": row.id,
        "ticker": row.ticker,
        "trigger_date": row.trigger_date,
        "cluster_size": row.cluster_size,
        "total_cluster_value": row.total_cluster_value,
        "score": row.score,
        "alerted": row.alerted,
        "executed": row.executed,
        "executed_at": row.executed_at.strftime("%Y-%m-%dT%H:%M:%SZ") if row.executed_at else None,
        "closed": row.closed,
        "created_at": row.created_at.isoformat() if row.created_at else None,
    }