Changelog:
- nucon/model.py: constant input dimensions (zero variance in the training data) now get std=inf, so they contribute 0 to the normalised kNN distance instead of causing catastrophic OOD scores from a tiny float epsilon.
- scripts/train_sac.py: add --load, --steps, --out CLI args; --load hot-starts actor/critic weights from a previous run (learning_starts=0).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
100 lines
3.4 KiB
Python
100 lines
3.4 KiB
Python
"""SAC + HER training on kNN-GP simulator.
|
|
|
|
Usage:
|
|
python3.14 train_sac.py
|
|
python3.14 train_sac.py --load /tmp/sac_nucon_knn # hot-start from previous run
|
|
|
|
Requirements:
|
|
- NuCon game running (for parameter metadata)
|
|
- /tmp/reactor_knn.pkl (kNN-GP model)
|
|
- /tmp/nucon_dataset.pkl (500-sample dataset for init_states)
|
|
"""
|
|
import argparse
|
|
import pickle
|
|
from gymnasium.wrappers import TimeLimit
|
|
from stable_baselines3 import SAC
|
|
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
|
|
|
|
from nucon.sim import NuconSimulator
|
|
from nucon.rl import NuconGoalEnv, Parameterized_Objectives, Parameterized_Terminators
|
|
|
|
# ---------------------------------------------------------------------------
# CLI: all flags are optional; defaults reproduce the original fixed run.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
_CLI_FLAGS = (
    ('--load', dict(default=None,
                    help='Path to existing model to hot-start from')),
    ('--steps', dict(type=int, default=50_000,
                     help='Total timesteps (default: 50000)')),
    ('--out', dict(default='/tmp/sac_nucon_knn',
                   help='Output path for saved model')),
)
for _flag, _opts in _CLI_FLAGS:
    parser.add_argument(_flag, **_opts)
args = parser.parse_args()
|
|
|
|
# ---------------------------------------------------------------------------
# Load model and dataset
# ---------------------------------------------------------------------------
def _unpickle(path):
    """Deserialize one artifact from *path*.

    NOTE(review): pickle is only safe because these /tmp files are produced
    locally by our own collection scripts — never point this at untrusted data.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


knn_model = _unpickle('/tmp/reactor_knn.pkl')
dataset = _unpickle('/tmp/nucon_dataset.pkl')

# Seed resets to in-distribution states from the dataset.
# Assumes each dataset record is a 4-tuple with the state in slot 2 — the
# slot layout is defined by the collection script, not visible here; confirm.
init_states = [state for _, _, state, _ in dataset]
|
|
|
|
# ---------------------------------------------------------------------------
# Build sim + env
# ---------------------------------------------------------------------------
BATCH_SIZE = 2048        # SAC minibatch size; also the warm-up threshold below
MAX_EPISODE_STEPS = 200  # hard episode cap via the TimeLimit wrapper

# Simulator backed by the kNN-GP model; port must match the running NuCon game.
sim = NuconSimulator(port=8786)
sim.set_model(knn_model)

# Goal-conditioned env: drive CORE_TEMP toward a target sampled from
# goal_range, with an uncertainty penalty that discourages the policy from
# wandering where the kNN-GP model is extrapolating.
env = NuconGoalEnv(
    goal_params=['CORE_TEMP'],
    goal_range={'CORE_TEMP': (55.0, 550.0)},
    tolerance=0.05,
    seconds_per_step=10,
    simulator=sim,
    additional_objectives=[
        Parameterized_Objectives['uncertainty_penalty'](start=0.3),
    ],
    additional_objective_weights=[1.0],
    init_states=init_states,
    delta_action_scale=0.05,
)

env = TimeLimit(env, max_episode_steps=MAX_EPISODE_STEPS)
|
|
|
|
# ---------------------------------------------------------------------------
# SAC + HER
# learning_starts = batch_size: wait for batch_size complete (short) episodes
# before the first gradient step. As the policy learns to stay in-dist, episodes
# will get longer and HER has more transitions to relabel.
# ---------------------------------------------------------------------------

# Hyperparameters shared by both the hot-start and fresh-start paths.
_HYPERPARAMS = {
    'learning_rate': 3e-4,
    'batch_size': BATCH_SIZE,
    'tau': 0.005,
    'gamma': 0.98,
    'train_freq': 64,
    'gradient_steps': 8,
}

if args.load:
    # Hot start: reuse actor/critic weights from a previous run.
    # learning_starts=0 so gradient updates resume immediately instead of
    # repeating the warm-up phase.
    print(f"Hot-starting from {args.load}")
    model = SAC.load(
        args.load,
        env=env,
        device='auto',
        custom_objects={**_HYPERPARAMS, 'learning_starts': 0},
    )
else:
    model = SAC(
        'MultiInputPolicy',
        env,
        replay_buffer_class=HerReplayBuffer,
        replay_buffer_kwargs={
            'n_sampled_goal': 4,
            'goal_selection_strategy': 'future',
        },
        verbose=1,
        learning_starts=BATCH_SIZE,
        device='auto',
        **_HYPERPARAMS,
    )
|
|
|
|
# Train, then persist the final policy.
model.learn(total_timesteps=args.steps)
out_path = args.out
model.save(out_path)  # SB3 appends ".zip" to the given path
print(f"Saved to {out_path}.zip")
|