NuCon/scripts/train_sac.py
Dominik Roth 0932bb353a feat: SAC+HER training on kNN-GP sim with direct bypass and scripts/
- nucon/rl.py: delta_action_scale action space, bool handling (>=0.5),
  direct sim read/write bypassing HTTP for ~2000fps env throughput;
  remove uncertainty_abort from training (use penalty-only), larger
  default batch sizes; fix _read_obs and step for in-process sim
- nucon/model.py: optimise _lookup with einsum squared-L2, vectorised
  rbf kernel; forward_with_uncertainty uses pre-built normalised arrays
- nucon/sim.py: _update_reactor_state writes outputs via setattr directly
- scripts/train_sac.py: moved from root; full SAC+HER example with kNN-GP
  sim, delta actions, uncertainty penalty, init_states
- scripts/collect_dataset.py: CLI tool to collect dynamics dataset from
  live game session (--steps, --delta, --out, --merge)
- README.md: add Scripts section, reference both scripts in training loop

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-12 20:43:37 +01:00
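The "einsum squared-L2" optimisation the commit mentions for `nucon/model.py` can be sketched as follows. This is a minimal illustration with hypothetical function and array names (the actual `_lookup` internals are not shown in this file); it uses the standard expansion ||q - t||² = ||q||² - 2·q·t + ||t||² so no difference tensor is ever materialised:

```python
import numpy as np

def pairwise_sq_l2(queries: np.ndarray, train: np.ndarray) -> np.ndarray:
    """Squared L2 distances between each query row and each training row.

    einsum computes the per-row squared norms in one pass; the cross term
    is a single matmul, which is what makes a vectorised kNN lookup fast.
    """
    q_sq = np.einsum('ij,ij->i', queries, queries)   # (Q,)  ||q||^2 per row
    t_sq = np.einsum('ij,ij->i', train, train)       # (N,)  ||t||^2 per row
    cross = queries @ train.T                        # (Q, N) q . t
    return q_sq[:, None] - 2.0 * cross + t_sq[None, :]

def knn_indices(queries: np.ndarray, train: np.ndarray, k: int) -> np.ndarray:
    """Indices of the k nearest training rows per query (unordered within k)."""
    d2 = pairwise_sq_l2(queries, train)
    return np.argpartition(d2, k, axis=1)[:, :k]
```

`argpartition` avoids a full sort, so the lookup stays O(N) per query after the distance matmul.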

"""SAC + HER training on kNN-GP simulator.
Usage:
python3.14 train_sac.py
Requirements:
- NuCon game running (for parameter metadata)
- /tmp/reactor_knn.pkl (kNN-GP model)
- /tmp/nucon_dataset.pkl (500-sample dataset for init_states)
"""
import pickle
from gymnasium.wrappers import TimeLimit
from stable_baselines3 import SAC
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
from nucon.sim import NuconSimulator
from nucon.rl import NuconGoalEnv, Parameterized_Objectives, Parameterized_Terminators
# ---------------------------------------------------------------------------
# Load model and dataset
# ---------------------------------------------------------------------------
with open('/tmp/reactor_knn.pkl', 'rb') as f:
    knn_model = pickle.load(f)
with open('/tmp/nucon_dataset.pkl', 'rb') as f:
    dataset = pickle.load(f)
# Seed resets to in-distribution states from dataset
init_states = [s for _, _, s, _ in dataset]
# ---------------------------------------------------------------------------
# Build sim + env
# ---------------------------------------------------------------------------
sim = NuconSimulator(port=8786)
sim.set_model(knn_model)
BATCH_SIZE = 2048
MAX_EPISODE_STEPS = 200
env = NuconGoalEnv(
    goal_params=['CORE_TEMP'],
    goal_range={'CORE_TEMP': (55.0, 550.0)},
    tolerance=0.05,
    seconds_per_step=10,
    simulator=sim,
    additional_objectives=[
        Parameterized_Objectives['uncertainty_penalty'](start=0.3),
    ],
    additional_objective_weights=[1.0],
    init_states=init_states,
    delta_action_scale=0.05,
)
env = TimeLimit(env, max_episode_steps=MAX_EPISODE_STEPS)
# ---------------------------------------------------------------------------
# SAC + HER
# learning_starts = batch_size: collect BATCH_SIZE environment steps before the
# first gradient step, so the first sampled batch draws from a full buffer. As
# the policy learns to stay in-distribution, episodes get longer and HER has
# more transitions to relabel.
# ---------------------------------------------------------------------------
model = SAC(
    'MultiInputPolicy',
    env,
    replay_buffer_class=HerReplayBuffer,
    replay_buffer_kwargs={
        'n_sampled_goal': 4,
        'goal_selection_strategy': 'future',
    },
    verbose=1,
    learning_rate=3e-4,
    batch_size=BATCH_SIZE,
    tau=0.005,
    gamma=0.98,
    train_freq=64,
    gradient_steps=8,
    learning_starts=BATCH_SIZE,
    device='auto',
)
model.learn(total_timesteps=50_000)
model.save('/tmp/sac_nucon_knn')
print("Saved to /tmp/sac_nucon_knn.zip")