- insert_filing: catch IntegrityError on duplicate accession instead of crashing
- filter_new_accessions: bulk pre-filter entire quarter against DB in chunked IN queries (avoids 30min per-row accession_exists loop during resume)
- mark_accession_seen: store placeholder row for derivative-only/empty filings so they aren't re-fetched on every resume
- get_recent_buys_for_ticker: accept as_of_date to clamp queries for historical signal gen
- get_all_buys_for_reprocess: return all buy filings ordered by transaction_date for backfill

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
282 lines
8.7 KiB
Python
282 lines
8.7 KiB
Python
from datetime import datetime
|
|
from typing import Optional
|
|
|
|
from sqlalchemy import create_engine, func, select, text, update
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.orm import Session
|
|
|
|
import config
|
|
from db.models import Base, Filing, PriceCache, Signal
|
|
|
|
|
|
def _engine():
    """Build a new SQLAlchemy engine for the SQLite file at ``config.DB_PATH``.

    ``check_same_thread=False`` is forwarded to the SQLite driver through
    ``connect_args``.
    """
    db_url = f"sqlite:///{config.DB_PATH}"
    driver_options = {"check_same_thread": False}
    return create_engine(db_url, connect_args=driver_options)
|
|
|
|
|
|
# Module-level engine cache; stays None until _get_engine() first builds it.
_ENGINE = None
|
|
|
|
|
|
def _get_engine():
    """Return the shared module-level engine, building it on first use."""
    global _ENGINE
    if _ENGINE is not None:
        return _ENGINE
    # First call: build the engine once and keep it for later callers.
    _ENGINE = _engine()
    return _ENGINE
|
|
|
|
|
|
def init_db():
    """Issue the SQLite pragmas and create any tables missing from the DB.

    NOTE(review): SQLite documents ``foreign_keys`` as a per-connection
    setting, so the pragma below may only affect the connection used here.
    Confirm whether enforcement is expected on other connections.
    """
    db_engine = _get_engine()
    startup_pragmas = ("PRAGMA journal_mode=WAL", "PRAGMA foreign_keys=ON")
    with db_engine.connect() as connection:
        for pragma_sql in startup_pragmas:
            connection.execute(text(pragma_sql))
    # Create every table declared on Base's metadata.
    Base.metadata.create_all(db_engine)
|
|
|
|
|
|
def _session() -> Session:
    """Open a new ORM session on the shared engine."""
    shared_engine = _get_engine()
    return Session(shared_engine)
|
|
|
|
|
|
def insert_filing(filing: dict) -> bool:
    """Insert one filing row unless its accession number is already stored.

    Args:
        filing: Mapping of filing fields. ``accession_number`` is required;
            every other field is read with ``.get`` and may be absent
            (``is_10b51`` is coerced to ``bool`` and defaults to ``False``).

    Returns:
        ``True`` when a new row was committed; ``False`` when a row with the
        same accession number already existed or the commit raised
        ``IntegrityError``.

    Raises:
        KeyError: If ``filing`` has no ``accession_number`` key.
    """
    passthrough_fields = (
        "ticker",
        "cik",
        "insider_name",
        "role",
        "transaction_date",
        "filed_date",
        "shares",
        "price",
        "total_value",
        "flag",
    )
    with _session() as db:
        accession = filing["accession_number"]
        stored_id = db.scalar(
            select(Filing.id).where(Filing.accession_number == accession)
        )
        if stored_id is not None:
            return False

        column_values = {name: filing.get(name) for name in passthrough_fields}
        column_values["is_10b51"] = bool(filing.get("is_10b51", False))
        column_values["post_tx_shares"] = filing.get("post_tx_shares")
        db.add(Filing(accession_number=accession, **column_values))

        try:
            db.commit()
        except IntegrityError:
            # The insert violated a constraint despite the pre-check above;
            # undo it and report "not inserted".
            db.rollback()
            return False
        return True
|
|
|
|
|
|
def accession_exists(accession_number: str) -> bool:
    """Report whether a filing row with this accession number is stored."""
    lookup = select(Filing.id).where(Filing.accession_number == accession_number)
    with _session() as db:
        found_id = db.scalar(lookup)
        return found_id is not None
|
|
|
|
|
|
def mark_accession_seen(accession_number: str) -> None:
    """Store a placeholder row so derivative-only/empty filings aren't re-fetched.

    Only ``accession_number`` is set on the new row. Nothing is written when
    a row with that accession number already exists, and an ``IntegrityError``
    raised by the commit is rolled back and ignored.
    """
    with _session() as db:
        known_id = db.scalar(
            select(Filing.id).where(Filing.accession_number == accession_number)
        )
        if known_id is None:
            db.add(Filing(accession_number=accession_number))
            try:
                db.commit()
            except IntegrityError:
                db.rollback()
|
|
|
|
|
|
def filter_new_accessions(accessions: list[str]) -> set[str]:
    """Return the subset of accessions not already in the DB.

    Stored accession numbers are looked up in batches of at most 900 values
    per ``IN`` clause, then subtracted from the input.

    Args:
        accessions: Accession numbers to check.

    Returns:
        The accession numbers from ``accessions`` with no stored filing row;
        an empty set when ``accessions`` is empty.
    """
    if not accessions:
        return set()

    batch_size = 900  # SQLite SQLITE_MAX_VARIABLE_NUMBER limit
    already_stored: set[str] = set()
    with _session() as db:
        for start in range(0, len(accessions), batch_size):
            batch = accessions[start:start + batch_size]
            stored_in_batch = select(Filing.accession_number).where(
                Filing.accession_number.in_(batch)
            )
            already_stored.update(db.execute(stored_in_batch).scalars().all())
    return set(accessions).difference(already_stored)
|
|
|
|
|
|
def get_latest_filed_date() -> Optional[str]:
    """Return the maximum ``filed_date`` over all stored filings.

    The value is whatever ``MAX(filed_date)`` yields.
    """
    newest_filed = select(func.max(Filing.filed_date))
    with _session() as db:
        return db.scalar(newest_filed)
|
|
|
|
|
|
def insert_signal(signal: dict) -> int:
    """Persist a new signal row and return its primary key.

    Args:
        signal: Must contain ``ticker``, ``trigger_date``, ``cluster_size``
            and ``score``; ``total_cluster_value`` is optional and defaults
            to ``0.0``.

    Returns:
        The ``id`` of the committed row.

    Raises:
        KeyError: If one of the required keys is missing.
    """
    with _session() as db:
        column_values = {
            "ticker": signal["ticker"],
            "trigger_date": signal["trigger_date"],
            "cluster_size": signal["cluster_size"],
            "total_cluster_value": signal.get("total_cluster_value", 0.0),
            "score": signal["score"],
        }
        new_signal = Signal(**column_values)
        db.add(new_signal)
        db.commit()
        # Read the id while the session is still open.
        return new_signal.id
|
|
|
|
|
|
def mark_signal_alerted(signal_id: int):
    """Set ``alerted=True`` on the signal with the given id and commit."""
    flag_alerted = update(Signal).where(Signal.id == signal_id).values(alerted=True)
    with _session() as db:
        db.execute(flag_alerted)
        db.commit()
|
|
|
|
|
|
def mark_signal_executed(signal_id: int):
    """Mark a signal as executed and stamp the execution time.

    Sets ``executed=True`` and ``executed_at`` on the signal with the given
    id, then commits.

    ``executed_at`` is stored as a naive datetime holding the current UTC
    time, the same shape of value ``datetime.utcnow()`` produced. It is
    derived from a timezone-aware "now" because ``utcnow()`` is deprecated
    as of Python 3.12.
    """
    from datetime import timezone

    # Aware UTC "now" with tzinfo dropped: a naive UTC timestamp.
    executed_at = datetime.now(timezone.utc).replace(tzinfo=None)
    with _session() as session:
        session.execute(
            update(Signal)
            .where(Signal.id == signal_id)
            .values(executed=True, executed_at=executed_at)
        )
        session.commit()
|
|
|
|
|
|
def mark_signal_closed(signal_id: int):
    """Set ``closed=True`` on the signal with the given id and commit."""
    flag_closed = update(Signal).where(Signal.id == signal_id).values(closed=True)
    with _session() as db:
        db.execute(flag_closed)
        db.commit()
|
|
|
|
|
|
def get_unalerted_signals() -> list[dict]:
    """Return signals whose ``alerted`` flag is false, ordered by ``created_at``.

    Each row is converted with ``_signal_to_dict``.
    """
    pending = (
        select(Signal)
        .where(Signal.alerted == False)  # noqa: E712 - builds a SQL comparison
        .order_by(Signal.created_at)
    )
    with _session() as db:
        signals = db.scalars(pending).all()
        return [_signal_to_dict(signal) for signal in signals]
|
|
|
|
|
|
def get_executed_unclosed_signals() -> list[dict]:
    """Return signals that are executed, not closed, and have ``executed_at`` set.

    Each row is converted with ``_signal_to_dict``.
    """
    open_positions = (
        select(Signal)
        .where(Signal.executed == True)  # noqa: E712 - builds a SQL comparison
        .where(Signal.closed == False)  # noqa: E712
        .where(Signal.executed_at.is_not(None))
    )
    with _session() as db:
        signals = db.scalars(open_positions).all()
        return [_signal_to_dict(signal) for signal in signals]
|
|
|
|
|
|
def get_recent_buys_for_ticker(
    ticker: str, window_days: int, as_of_date: Optional[str] = None
) -> list[dict]:
    """Return a ticker's non-10b5-1 filings flagged ``"A"`` inside a look-back window.

    The window runs from ``window_days`` before the reference date up to and
    including the reference date, compared against ``transaction_date`` as
    ``YYYY-MM-DD`` strings.

    Args:
        ticker: Ticker to match exactly.
        window_days: Length of the look-back window in days.
        as_of_date: Optional ``YYYY-MM-DD`` reference date that clamps the
            window. When omitted (or empty), the current UTC date is used.

    Returns:
        Matching filings as dicts (see ``_filing_to_dict``), newest
        ``transaction_date`` first.

    Raises:
        ValueError: If ``as_of_date`` is given but is not ``YYYY-MM-DD``.
    """
    from datetime import timedelta, timezone

    if as_of_date:
        ref = datetime.strptime(as_of_date, "%Y-%m-%d")
    else:
        # datetime.utcnow() is deprecated since Python 3.12; take an aware
        # UTC "now" and drop tzinfo to get the same naive UTC value.
        ref = datetime.now(timezone.utc).replace(tzinfo=None)

    cutoff = (ref - timedelta(days=window_days)).strftime("%Y-%m-%d")
    ref_str = ref.strftime("%Y-%m-%d")

    # NOTE(review): the bounds are compared as strings, which presumes
    # transaction_date is stored as YYYY-MM-DD text - confirm against the model.
    with _session() as session:
        rows = session.scalars(
            select(Filing)
            .where(
                Filing.ticker == ticker,
                Filing.flag == "A",
                Filing.is_10b51 == False,  # noqa: E712 - builds a SQL comparison
                Filing.transaction_date >= cutoff,
                Filing.transaction_date <= ref_str,
            )
            .order_by(Filing.transaction_date.desc())
        ).all()
        return [_filing_to_dict(r) for r in rows]
|
|
|
|
|
|
def get_all_buys_for_reprocess() -> list[dict]:
    """Return all buy filings for signal reprocessing (used after bulk ingest).

    Selects every filing whose ``flag`` is ``"A"``, ordered by ascending
    ``transaction_date``, and converts each row with ``_filing_to_dict``.
    """
    all_buys = (
        select(Filing)
        .where(Filing.flag == "A")
        .order_by(Filing.transaction_date)
    )
    with _session() as db:
        filings = db.scalars(all_buys).all()
        return [_filing_to_dict(filing) for filing in filings]
|
|
|
|
|
|
def get_signals_for_backtest(min_score: float, min_cluster_size: int) -> list[dict]:
    """Return signals meeting both the score and the cluster-size minimums.

    Args:
        min_score: Inclusive lower bound on ``Signal.score``.
        min_cluster_size: Inclusive lower bound on ``Signal.cluster_size``.

    Returns:
        Matching signals as dicts (see ``_signal_to_dict``).
    """
    qualifying = (
        select(Signal)
        .where(Signal.score >= min_score)
        .where(Signal.cluster_size >= min_cluster_size)
    )
    with _session() as db:
        signals = db.scalars(qualifying).all()
        return [_signal_to_dict(signal) for signal in signals]
|
|
|
|
|
|
def get_cached_prices(ticker: str, start_date: str, end_date: str) -> dict[str, float]:
    """Return cached closes for a ticker, keyed by date.

    Args:
        ticker: Ticker to match exactly.
        start_date: Inclusive lower bound on ``PriceCache.date``.
        end_date: Inclusive upper bound on ``PriceCache.date``.

    Returns:
        Mapping of each cached row's ``date`` to its ``close``.
    """
    in_range = (
        select(PriceCache)
        .where(PriceCache.ticker == ticker)
        .where(PriceCache.date >= start_date)
        .where(PriceCache.date <= end_date)
    )
    with _session() as db:
        closes_by_date = {}
        for cached in db.scalars(in_range).all():
            closes_by_date[cached.date] = cached.close
        return closes_by_date
|
|
|
|
|
|
def upsert_prices(ticker: str, prices: dict[str, float]):
    """Cache closing prices for a ticker, inserting only dates not yet stored.

    Existing ``(ticker, date)`` rows are left untouched; despite the name,
    a stored close is never overwritten. All new rows are committed together.

    The already-cached dates are found with chunked ``IN`` queries rather
    than one SELECT per date, the same pre-filter approach used by
    ``filter_new_accessions``.

    Args:
        ticker: Ticker the prices belong to.
        prices: Mapping of date string to closing price.
    """
    dates = list(prices)
    # Stay below SQLite's bound-variable limit (the ticker takes one slot too).
    chunk_size = 900
    with _session() as session:
        cached_dates: set[str] = set()
        for start in range(0, len(dates), chunk_size):
            chunk = dates[start:start + chunk_size]
            cached_dates.update(
                session.scalars(
                    select(PriceCache.date).where(
                        PriceCache.ticker == ticker,
                        PriceCache.date.in_(chunk),
                    )
                ).all()
            )

        session.add_all(
            PriceCache(ticker=ticker, date=date_str, close=close)
            for date_str, close in prices.items()
            if date_str not in cached_dates
        )
        session.commit()
|
|
|
|
|
|
def _filing_to_dict(row: Filing) -> dict:
    """Convert a ``Filing`` row into a plain dict.

    Column values are copied as-is, except ``created_at``, which is rendered
    with ``isoformat()`` (or ``None`` when unset).
    """
    copied_columns = (
        "id",
        "accession_number",
        "ticker",
        "cik",
        "insider_name",
        "role",
        "transaction_date",
        "filed_date",
        "shares",
        "price",
        "total_value",
        "flag",
        "is_10b51",
        "post_tx_shares",
    )
    payload = {column: getattr(row, column) for column in copied_columns}
    created_at = row.created_at
    payload["created_at"] = created_at.isoformat() if created_at else None
    return payload
|
|
|
|
|
|
def _signal_to_dict(row: Signal) -> dict:
    """Convert a ``Signal`` row into a plain dict.

    Column values are copied as-is, except the two timestamps:
    ``executed_at`` is formatted as ``%Y-%m-%dT%H:%M:%SZ`` and ``created_at``
    with ``isoformat()``; either is ``None`` when unset.
    """
    leading_columns = (
        "id",
        "ticker",
        "trigger_date",
        "cluster_size",
        "total_cluster_value",
        "score",
        "alerted",
        "executed",
    )
    payload = {column: getattr(row, column) for column in leading_columns}

    executed_at = row.executed_at
    payload["executed_at"] = (
        executed_at.strftime("%Y-%m-%dT%H:%M:%SZ") if executed_at else None
    )
    payload["closed"] = row.closed

    created_at = row.created_at
    payload["created_at"] = created_at.isoformat() if created_at else None
    return payload
|