From 0fa36a33900f381f954924d3c3a63f0f4ee59996 Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Tue, 26 May 2026 17:48:33 +0200
Subject: [PATCH] feat(db): dedup-safe inserts, filter_new_accessions, mark_accession_seen, as-of-date queries

- insert_filing: catch IntegrityError on duplicate accession instead of crashing
- filter_new_accessions: bulk pre-filter entire quarter against DB in chunked IN queries (avoids 30min per-row accession_exists loop during resume)
- mark_accession_seen: store placeholder row for derivative-only/empty filings so they aren't re-fetched on every resume
- get_recent_buys_for_ticker: accept as_of_date to clamp queries for historical signal gen
- get_all_buys_for_reprocess: return all buy filings ordered by transaction_date for backfill

Co-Authored-By: Claude Sonnet 4.6
---
 db/db.py | 63 ++++++++++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 59 insertions(+), 4 deletions(-)

diff --git a/db/db.py b/db/db.py
index d021483..4b66a1b 100644
--- a/db/db.py
+++ b/db/db.py
@@ -2,6 +2,7 @@ from datetime import datetime
 from typing import Optional
 
 from sqlalchemy import create_engine, func, select, text, update
+from sqlalchemy.exc import IntegrityError
 from sqlalchemy.orm import Session
 
 import config
@@ -59,8 +60,12 @@ def insert_filing(filing: dict) -> bool:
             post_tx_shares=filing.get("post_tx_shares"),
         )
         session.add(row)
-        session.commit()
-        return True
+        try:
+            session.commit()
+            return True
+        except IntegrityError:
+            session.rollback()
+            return False
 
 
 def accession_exists(accession_number: str) -> bool:
@@ -70,6 +75,40 @@ def accession_exists(accession_number: str) -> bool:
         ) is not None
 
 
+def mark_accession_seen(accession_number: str) -> None:
+    """Store a placeholder row so derivative-only/empty filings aren't re-fetched."""
+    with _session() as session:
+        exists = session.scalar(
+            select(Filing.id).where(Filing.accession_number == accession_number)
+        )
+        if exists is not None:
+            return
+        session.add(Filing(accession_number=accession_number))
+        try:
+            session.commit()
+        except IntegrityError:
+            session.rollback()
+
+
+def filter_new_accessions(accessions: list[str]) -> set[str]:
+    """Return the subset of accessions not already in the DB."""
+    if not accessions:
+        return set()
+    existing: set[str] = set()
+    chunk_size = 900  # SQLite SQLITE_MAX_VARIABLE_NUMBER limit
+    with _session() as session:
+        for i in range(0, len(accessions), chunk_size):
+            chunk = accessions[i:i + chunk_size]
+            existing.update(
+                session.execute(
+                    select(Filing.accession_number).where(
+                        Filing.accession_number.in_(chunk)
+                    )
+                ).scalars().all()
+            )
+    return set(accessions) - existing
+
+
 def get_latest_filed_date() -> Optional[str]:
     with _session() as session:
         return session.scalar(select(func.max(Filing.filed_date)))
@@ -135,10 +174,14 @@ def get_executed_unclosed_signals() -> list[dict]:
     return [_signal_to_dict(r) for r in rows]
 
 
-def get_recent_buys_for_ticker(ticker: str, window_days: int) -> list[dict]:
+def get_recent_buys_for_ticker(
+    ticker: str, window_days: int, as_of_date: Optional[str] = None
+) -> list[dict]:
     from datetime import timedelta
 
-    cutoff = (datetime.utcnow() - timedelta(days=window_days)).strftime("%Y-%m-%d")
+    ref = datetime.strptime(as_of_date, "%Y-%m-%d") if as_of_date else datetime.utcnow()
+    cutoff = (ref - timedelta(days=window_days)).strftime("%Y-%m-%d")
+    ref_str = ref.strftime("%Y-%m-%d")
     with _session() as session:
         rows = session.scalars(
             select(Filing)
@@ -147,12 +190,24 @@ def get_recent_buys_for_ticker(ticker: str, window_days: int) -> list[dict]:
                 Filing.flag == "A",
                 Filing.is_10b51 == False,
                 Filing.transaction_date >= cutoff,
+                Filing.transaction_date <= ref_str,
             )
             .order_by(Filing.transaction_date.desc())
         ).all()
     return [_filing_to_dict(r) for r in rows]
 
 
+def get_all_buys_for_reprocess() -> list[dict]:
+    """Return all buy filings for signal reprocessing (used after bulk ingest)."""
+    with _session() as session:
+        rows = session.scalars(
+            select(Filing)
+            .where(Filing.flag == "A")
+            .order_by(Filing.transaction_date)
+        ).all()
+    return [_filing_to_dict(r) for r in rows]
+
+
 def get_signals_for_backtest(min_score: float, min_cluster_size: int) -> list[dict]:
     with _session() as session:
         rows = session.scalars(