- nucon/rl.py: delta_action_scale action space, bool handling (>=0.5), direct sim read/write bypassing HTTP for ~2000fps env throughput; remove uncertainty_abort from training (use penalty-only), larger default batch sizes; fix _read_obs and step for in-process sim - nucon/model.py: optimise _lookup with einsum squared-L2, vectorised rbf kernel; forward_with_uncertainty uses pre-built normalised arrays - nucon/sim.py: _update_reactor_state writes outputs via setattr directly - scripts/train_sac.py: moved from root; full SAC+HER example with kNN-GP sim, delta actions, uncertainty penalty, init_states - scripts/collect_dataset.py: CLI tool to collect dynamics dataset from live game session (--steps, --delta, --out, --merge) - README.md: add Scripts section, reference both scripts in training loop Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
84 lines
2.5 KiB
Python
84 lines
2.5 KiB
Python
"""SAC + HER training on kNN-GP simulator.

Usage:
    python3.14 train_sac.py

Requirements:
    - NuCon game running (for parameter metadata)
    - /tmp/reactor_knn.pkl (kNN-GP model)
    - /tmp/nucon_dataset.pkl (500-sample dataset for init_states)
"""
|
|
import pickle
|
|
from gymnasium.wrappers import TimeLimit
|
|
from stable_baselines3 import SAC
|
|
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
|
|
|
|
from nucon.sim import NuconSimulator
|
|
from nucon.rl import NuconGoalEnv, Parameterized_Objectives, Parameterized_Terminators
|
|
|
|
# ---------------------------------------------------------------------------
# Load the pickled kNN-GP model and the dataset used for reset states
# ---------------------------------------------------------------------------
# NOTE(review): pickle.load can execute arbitrary code embedded in the file.
# Both inputs are read from /tmp, so make sure they come from a trusted source.
with open('/tmp/reactor_knn.pkl', 'rb') as model_file:
    knn_model = pickle.load(model_file)

with open('/tmp/nucon_dataset.pkl', 'rb') as dataset_file:
    dataset = pickle.load(dataset_file)

# Every dataset record is unpacked into four fields; only the third one is
# kept. These values seed environment resets with in-distribution states.
init_states = [state for _, _, state, _ in dataset]
|
|
|
|
# ---------------------------------------------------------------------------
# Build the simulator and the goal environment
# ---------------------------------------------------------------------------
# Shared training constants; BATCH_SIZE is reused by the SAC setup further down.
BATCH_SIZE = 2048
MAX_EPISODE_STEPS = 200

# Simulator instance driven by the loaded kNN-GP model.
sim = NuconSimulator(port=8786)
sim.set_model(knn_model)

# The goal env is created directly inside the TimeLimit wrapper, which
# truncates every episode after MAX_EPISODE_STEPS steps.
env = TimeLimit(
    NuconGoalEnv(
        goal_params=['CORE_TEMP'],
        goal_range={'CORE_TEMP': (55.0, 550.0)},
        tolerance=0.05,
        seconds_per_step=10,
        simulator=sim,
        # A single extra objective (the uncertainty penalty), weighted 1.0.
        additional_objectives=[
            Parameterized_Objectives['uncertainty_penalty'](start=0.3),
        ],
        additional_objective_weights=[1.0],
        # Resets draw from the states extracted from the dataset.
        init_states=init_states,
        delta_action_scale=0.05,
    ),
    max_episode_steps=MAX_EPISODE_STEPS,
)
|
|
|
|
# ---------------------------------------------------------------------------
# SAC + HER
# learning_starts is set to BATCH_SIZE. In stable-baselines3 this value counts
# environment steps, so no gradient step is taken until that many transitions
# have been collected. Author's expectation: as the policy learns to stay
# in-distribution, episodes get longer and HER has more transitions to relabel.
# ---------------------------------------------------------------------------
model = SAC(
    policy='MultiInputPolicy',
    env=env,
    # Optimisation hyperparameters.
    learning_rate=3e-4,
    gamma=0.98,
    tau=0.005,
    # Update schedule: 8 gradient steps after every 64 collected steps.
    batch_size=BATCH_SIZE,
    learning_starts=BATCH_SIZE,
    train_freq=64,
    gradient_steps=8,
    # Hindsight relabelling: 4 goals per transition, 'future' strategy.
    replay_buffer_class=HerReplayBuffer,
    replay_buffer_kwargs={
        'n_sampled_goal': 4,
        'goal_selection_strategy': 'future',
    },
    device='auto',
    verbose=1,
)

model.learn(total_timesteps=50_000)

# The save path has no extension; the message below reports the resulting
# archive with the .zip suffix.
model.save('/tmp/sac_nucon_knn')
print("Saved to /tmp/sac_nucon_knn.zip")
|