NuCon/scripts/train_sac.py
Dominik Roth 646399dcc7 feat: improve NN dynamics model and SAC training
- ReactorDynamicsNet: add dropout (0.3) for regularisation
- ReactorDynamicsModel: z-score normalisation of inputs/outputs, predict
  per-second rates of change, forward_with_uncertainty() stub
- rl.py: misc SAC training improvements
- sim.py: minor fixes
- train_sac.py: updated training loop

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-15 11:18:15 +01:00

152 lines
5.9 KiB
Python

"""SAC + HER training on kNN-GP simulator.
Usage:
python3.14 train_sac.py
python3.14 train_sac.py --load /tmp/sac_nucon_knn # hot-start from previous run
Requirements:
- NuCon game running (for parameter metadata)
- /tmp/reactor_knn.pkl (kNN-GP model)
- /tmp/nucon_dataset.pkl (500-sample dataset for init_states)
"""
import argparse
import json
import os
import pickle

import torch
from gymnasium.wrappers import TimeLimit
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer

from nucon.model import MixtureModel, ReactorDynamicsModel
from nucon.rl import NuconGoalEnv, Parameterized_Objectives, Parameterized_Terminators
from nucon.sim import NuconSimulator
# Command-line interface: (flag, add_argument kwargs) pairs, registered in order.
_CLI_FLAGS = [
    ('--load', dict(default=None, help='Path to existing model to hot-start from')),
    ('--steps', dict(type=int, default=50_000, help='Total timesteps (default: 50000)')),
    ('--out', dict(default='/tmp/sac_nucon_knn', help='Output path for saved model')),
    ('--model', dict(default='/tmp/reactor_knn.pkl', help='Dynamics model (.pkl for kNN, .pt for NN)')),
    ('--model2', dict(default=None, help='Second dynamics model for mixture (optional)')),
    ('--dataset', dict(default='/tmp/nucon_dataset.pkl', help='Dataset for init states')),
]

parser = argparse.ArgumentParser()
for _flag, _opts in _CLI_FLAGS:
    parser.add_argument(_flag, **_opts)
args = parser.parse_args()
# ---------------------------------------------------------------------------
# Load dynamics model(s) and dataset
# ---------------------------------------------------------------------------
def _load_model(path):
if path.endswith('.pt'):
ckpt = torch.load(path, weights_only=False)
m = ReactorDynamicsModel(ckpt['input_params'], ckpt['output_params'])
m.load_state_dict(ckpt['state_dict'])
m.eval()
return m
with open(path, 'rb') as f:
return pickle.load(f)
# Primary dynamics model, optionally blended with a second one.
dynamics_model = _load_model(args.model)
if args.model2:
    # Mixture of the two models (see nucon.model.MixtureModel).
    dynamics_model = MixtureModel(dynamics_model, _load_model(args.model2))

with open(args.dataset, 'rb') as f:
    dataset = pickle.load(f)
# Seed resets to in-distribution states from dataset.
# Each dataset entry is a 4-tuple; the third element is the state snapshot.
init_states = [s for _, _, s, _ in dataset]
# ---------------------------------------------------------------------------
# Build sim + env
# ---------------------------------------------------------------------------
# Simulator server backed by the loaded dynamics model.
sim = NuconSimulator(port=8786)  # local port for the sim; must be free
sim.set_model(dynamics_model)

BATCH_SIZE = 2048        # SAC minibatch size; also used as learning_starts below
MAX_EPISODE_STEPS = 200  # TimeLimit truncation horizon

# Goal space: target electrical output per generator.
GENERATORS = ['GENERATOR_0_KW', 'GENERATOR_1_KW', 'GENERATOR_2_KW']
POWER_RANGE = {g: (0.0, 100_000.0) for g in GENERATORS}  # per-generator kW; ~100 MW upper bound
# Curated obs: physically relevant features for power control (~25 dims vs ~260 full)
OBS_PARAMS = [
    'CORE_TEMP', 'CORE_PRESSURE', 'CORE_STATE_CRITICALITY', 'CORE_WEAR', 'CORE_INTEGRITY',
    'ROD_BANK_POS_0_ACTUAL', 'ROD_BANK_POS_0_ORDERED',
    'COOLANT_CORE_FLOW_SPEED', 'COOLANT_CORE_VESSEL_TEMPERATURE',
    'COOLANT_CORE_PRESSURE', 'COOLANT_CORE_QUANTITY_IN_VESSEL',
    'STEAM_TURBINE_0_RPM', 'STEAM_TURBINE_0_TEMPERATURE', 'STEAM_TURBINE_0_PRESSURE',
    'STEAM_TURBINE_1_RPM', 'STEAM_TURBINE_1_TEMPERATURE', 'STEAM_TURBINE_1_PRESSURE',
    'STEAM_TURBINE_2_RPM', 'STEAM_TURBINE_2_TEMPERATURE', 'STEAM_TURBINE_2_PRESSURE',
    'GENERATOR_0_V', 'GENERATOR_1_V', 'GENERATOR_2_V',
]
# Goal-conditioned env over the simulator; goals are per-generator kW targets.
env = NuconGoalEnv(
    goal_params=GENERATORS,
    goal_range=POWER_RANGE,
    seconds_per_step=10,   # each RL step advances the sim by 10 simulated seconds
    simulator=sim,
    obs_params=OBS_PARAMS,
    additional_objectives=[
        # Penalize states where the dynamics model is uncertain (out-of-dist).
        Parameterized_Objectives['uncertainty_penalty'](start=0.3),
        # Soft safety shaping: keep core temperature below 420.
        Parameterized_Objectives['temp_below_linear'](max_temp=420),
    ],
    additional_objective_weights=[1.0, 0.01],
    init_states=init_states,   # reset into in-distribution dataset states
    delta_action_scale=0.05,   # actions are small deltas on the controls
    goal_sampling_std=0.15,  # Gaussian delta in normalised space (~180 kW typical)
)
env = TimeLimit(env, max_episode_steps=MAX_EPISODE_STEPS)
# ---------------------------------------------------------------------------
# SAC + HER
# learning_starts = batch_size: wait for batch_size complete (short) episodes
# before the first gradient step. As the policy learns to stay in-dist, episodes
# will get longer and HER has more transitions to relabel.
# ---------------------------------------------------------------------------
if args.load:
    # Hot start: reuse weights/buffer machinery from a previous run while
    # overriding hyperparameters via custom_objects.
    print(f"Hot-starting from {args.load}")
    model = SAC.load(args.load, env=env, device='auto',
                     custom_objects={'learning_rate': 3e-4, 'batch_size': BATCH_SIZE,
                                     'tau': 0.005, 'gamma': 0.98,
                                     'train_freq': 64, 'gradient_steps': 8,
                                     # NOTE(review): fresh runs warm up with
                                     # learning_starts=BATCH_SIZE (2048) but hot starts
                                     # use MAX_EPISODE_STEPS (200) — presumably a faster
                                     # warm-up for a resumed policy; confirm intended.
                                     'learning_starts': MAX_EPISODE_STEPS,
                                     'ent_coef': 0.1})
else:
    # Fresh SAC with HER relabelling on the goal-conditioned env.
    model = SAC(
        'MultiInputPolicy',
        env,
        replay_buffer_class=HerReplayBuffer,
        replay_buffer_kwargs={
            'n_sampled_goal': 4,  # relabel 4 HER goals per real transition
            'goal_selection_strategy': 'future',
        },
        verbose=1,
        learning_rate=3e-4,
        batch_size=BATCH_SIZE,
        tau=0.005,
        gamma=0.98,
        train_freq=64,      # collect 64 env steps ...
        gradient_steps=8,   # ... then take 8 gradient steps
        learning_starts=BATCH_SIZE,
        ent_coef=0.1,  # fixed; auto-tuning diverges on this many action dims
        device='auto',
    )
# Periodically snapshot the policy so long runs can be resumed/inspected.
checkpoint_cb = CheckpointCallback(
    save_freq=10_000,
    save_path=args.out + '_checkpoints/',
    name_prefix='sac',
)

# Persist the curated observation layout next to every artifact so inference
# code can rebuild the exact observation vector.
# (json/os are imported at the top of the file; the mid-file
# `import json, os` was removed per PEP 8.)
config = {'obs_params': OBS_PARAMS}
for save_dir in [args.out + '_checkpoints/', os.path.dirname(args.out) or '.']:
    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, 'config.json'), 'w') as f:
        json.dump(config, f)

model.learn(total_timesteps=args.steps, callback=checkpoint_cb)

model.save(args.out)
# Companion config written beside the saved model zip.
with open(args.out + '.json', 'w') as f:
    json.dump(config, f)
print(f"Saved to {args.out}.zip")