NuCon/scripts/train_sac.py
Dominik Roth 55d6e8708e fix: kNN zero-variance dims get inf std; hot-start SAC from saved model
- nucon/model.py: constant input dimensions (zero variance in training
  data) now get std=inf so they contribute 0 to normalised kNN distance
  instead of causing catastrophic OOD from tiny float epsilon
- scripts/train_sac.py: add --load, --steps, --out CLI args; --load
  hot-starts actor/critic weights from a previous run (learning_starts=0)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-13 12:44:26 +01:00

100 lines
3.4 KiB
Python

"""SAC + HER training on kNN-GP simulator.
Usage:
python3.14 train_sac.py
python3.14 train_sac.py --load /tmp/sac_nucon_knn # hot-start from previous run
Requirements:
- NuCon game running (for parameter metadata)
- /tmp/reactor_knn.pkl (kNN-GP model)
- /tmp/nucon_dataset.pkl (500-sample dataset for init_states)
"""
import argparse
import pickle
from gymnasium.wrappers import TimeLimit
from stable_baselines3 import SAC
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
from nucon.sim import NuconSimulator
from nucon.rl import NuconGoalEnv, Parameterized_Objectives, Parameterized_Terminators
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--load', default=None, help='Path to existing model to hot-start from')
parser.add_argument('--steps', type=int, default=50_000, help='Total timesteps (default: 50000)')
parser.add_argument('--out', default='/tmp/sac_nucon_knn', help='Output path for saved model')
# Data/model paths were previously hard-coded; expose them with the old
# values as defaults so existing invocations keep working unchanged.
parser.add_argument('--knn-model', default='/tmp/reactor_knn.pkl',
                    help='Path to pickled kNN-GP model (default: /tmp/reactor_knn.pkl)')
parser.add_argument('--dataset', default='/tmp/nucon_dataset.pkl',
                    help='Path to pickled dataset for init_states (default: /tmp/nucon_dataset.pkl)')
args = parser.parse_args()
# ---------------------------------------------------------------------------
# Load model and dataset
# ---------------------------------------------------------------------------
# SECURITY NOTE: pickle.load executes arbitrary code from the file -- only
# point these paths at files produced by this project's own tooling.
with open(args.knn_model, 'rb') as f:
    knn_model = pickle.load(f)
with open(args.dataset, 'rb') as f:
    dataset = pickle.load(f)
# Seed resets to in-distribution states from the dataset. Each record is a
# 4-tuple; index 2 is the state (the other fields are unused here).
init_states = [s for _, _, s, _ in dataset]
# ---------------------------------------------------------------------------
# Build sim + env
# ---------------------------------------------------------------------------
sim = NuconSimulator(port=8786)
sim.set_model(knn_model)

BATCH_SIZE = 2048        # SAC minibatch size; also used as learning_starts below
MAX_EPISODE_STEPS = 200  # episode cap enforced by the TimeLimit wrapper

# Extra reward term: penalise stepping into regions where the model is
# uncertain (parameterised objective from the project's RL toolbox).
_uncertainty_obj = Parameterized_Objectives['uncertainty_penalty'](start=0.3)

_goal_env = NuconGoalEnv(
    simulator=sim,
    goal_params=['CORE_TEMP'],
    goal_range={'CORE_TEMP': (55.0, 550.0)},
    tolerance=0.05,
    seconds_per_step=10,
    additional_objectives=[_uncertainty_obj],
    additional_objective_weights=[1.0],
    init_states=init_states,
    delta_action_scale=0.05,
)
env = TimeLimit(_goal_env, max_episode_steps=MAX_EPISODE_STEPS)
# ---------------------------------------------------------------------------
# SAC + HER
# learning_starts = batch_size: wait for batch_size complete (short) episodes
# before the first gradient step. As the policy learns to stay in-dist, episodes
# will get longer and HER has more transitions to relabel.
# ---------------------------------------------------------------------------
# Hyperparameters shared by the fresh-start constructor and the hot-start
# custom_objects override -- kept in one place so the two paths cannot drift.
_sac_hparams = {
    'learning_rate': 3e-4,
    'batch_size': BATCH_SIZE,
    'tau': 0.005,
    'gamma': 0.98,
    'train_freq': 64,
    'gradient_steps': 8,
}
if args.load:
    # Hot start: reuse actor/critic weights from a previous run, force our
    # current hyperparameters over the serialized ones, and skip the warm-up
    # phase entirely (learning_starts=0).
    print(f"Hot-starting from {args.load}")
    model = SAC.load(args.load, env=env, device='auto',
                     custom_objects={**_sac_hparams, 'learning_starts': 0})
else:
    model = SAC(
        'MultiInputPolicy',
        env,
        replay_buffer_class=HerReplayBuffer,
        replay_buffer_kwargs={
            'n_sampled_goal': 4,
            'goal_selection_strategy': 'future',
        },
        verbose=1,
        learning_starts=BATCH_SIZE,
        device='auto',
        **_sac_hparams,
    )
model.learn(total_timesteps=args.steps)
model.save(args.out)
print(f"Saved to {args.out}.zip")