diff --git a/db/db.py b/db/db.py index d021483..4b66a1b 100644 --- a/db/db.py +++ b/db/db.py @@ -2,6 +2,7 @@ from datetime import datetime from typing import Optional from sqlalchemy import create_engine, func, select, text, update +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session import config @@ -59,8 +60,12 @@ def insert_filing(filing: dict) -> bool: post_tx_shares=filing.get("post_tx_shares"), ) session.add(row) - session.commit() - return True + try: + session.commit() + return True + except IntegrityError: + session.rollback() + return False def accession_exists(accession_number: str) -> bool: @@ -70,6 +75,40 @@ def accession_exists(accession_number: str) -> bool: ) is not None +def mark_accession_seen(accession_number: str) -> None: + """Store a placeholder row so derivative-only/empty filings aren't re-fetched.""" + with _session() as session: + exists = session.scalar( + select(Filing.id).where(Filing.accession_number == accession_number) + ) + if exists is not None: + return + session.add(Filing(accession_number=accession_number)) + try: + session.commit() + except IntegrityError: + session.rollback() + + +def filter_new_accessions(accessions: list[str]) -> set[str]: + """Return the subset of accessions not already in the DB.""" + if not accessions: + return set() + existing: set[str] = set() + chunk_size = 900 # SQLite SQLITE_MAX_VARIABLE_NUMBER limit + with _session() as session: + for i in range(0, len(accessions), chunk_size): + chunk = accessions[i:i + chunk_size] + existing.update( + session.execute( + select(Filing.accession_number).where( + Filing.accession_number.in_(chunk) + ) + ).scalars().all() + ) + return set(accessions) - existing + + def get_latest_filed_date() -> Optional[str]: with _session() as session: return session.scalar(select(func.max(Filing.filed_date))) @@ -135,10 +174,14 @@ def get_executed_unclosed_signals() -> list[dict]: return [_signal_to_dict(r) for r in rows] -def get_recent_buys_for_ticker(ticker: str, window_days: int) -> list[dict]: +def get_recent_buys_for_ticker( + ticker: str, window_days: int, as_of_date: Optional[str] = None +) -> list[dict]: from datetime import timedelta - cutoff = (datetime.utcnow() - timedelta(days=window_days)).strftime("%Y-%m-%d") + ref = datetime.strptime(as_of_date, "%Y-%m-%d") if as_of_date else datetime.utcnow() + cutoff = (ref - timedelta(days=window_days)).strftime("%Y-%m-%d") + ref_str = ref.strftime("%Y-%m-%d") with _session() as session: rows = session.scalars( select(Filing) @@ -147,12 +190,24 @@ def get_recent_buys_for_ticker(ticker: str, window_days: int) -> list[dict]: Filing.flag == "A", Filing.is_10b51 == False, Filing.transaction_date >= cutoff, + Filing.transaction_date <= ref_str, ) .order_by(Filing.transaction_date.desc()) ).all() return [_filing_to_dict(r) for r in rows] +def get_all_buys_for_reprocess() -> list[dict]: + """Return all buy filings for signal reprocessing (used after bulk ingest).""" + with _session() as session: + rows = session.scalars( + select(Filing) + .where(Filing.flag == "A") + .order_by(Filing.transaction_date) + ).all() + return [_filing_to_dict(r) for r in rows] + + def get_signals_for_backtest(min_score: float, min_cluster_size: int) -> list[dict]: with _session() as session: rows = session.scalars(