feat: SAC+HER training on kNN-GP sim with direct bypass and scripts/

- nucon/rl.py: delta_action_scale action space, bool handling (>=0.5),
  direct sim read/write bypassing HTTP for ~2000fps env throughput;
  remove uncertainty_abort from training (use penalty-only), larger
  default batch sizes; fix _read_obs and step for in-process sim
- nucon/model.py: optimise _lookup with einsum squared-L2, vectorised
  rbf kernel; forward_with_uncertainty uses pre-built normalised arrays
- nucon/sim.py: _update_reactor_state writes outputs via setattr directly
- scripts/train_sac.py: moved from root; full SAC+HER example with kNN-GP
  sim, delta actions, uncertainty penalty, init_states
- scripts/collect_dataset.py: CLI tool to collect dynamics dataset from
  live game session (--steps, --delta, --out, --merge)
- README.md: add Scripts section, reference both scripts in training loop

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Dominik Moritz Roth 2026-03-12 20:43:37 +01:00
parent 3dfe1aa673
commit 0932bb353a
6 changed files with 271 additions and 48 deletions

README.md

@@ -404,11 +404,11 @@ The recommended end-to-end workflow for training an RL operator is an iterative
 └─────────────────────┘ └─────────────────────┘
 ```
-**Step 1 — Human dataset collection**: Run `NuconModelLearner.collect_data()` during your play session. Cover a wide range of states: startup from cold, ramping power, individual rod bank adjustments. Diversity in the dataset directly determines simulator accuracy. See [Model Learning](#model-learning) for collection details.
+**Step 1 — Human dataset collection**: Run `scripts/collect_dataset.py` during your play session (see [Scripts](#scripts)). Cover a wide range of states: startup from cold, ramping power, individual rod bank adjustments. Diversity in the dataset directly determines simulator accuracy. See [Model Learning](#model-learning) for collection details.
 **Step 2 — Initial model fitting**: Fit a kNN-GP model (instant) or NN (better extrapolation with larger datasets) using `fit_knn()` or `train_model()`. Prune near-duplicate samples with `drop_redundant()` before fitting. See [Model Learning](#model-learning).
-**Step 3 — Train RL in simulator**: Load the fitted model into `NuconSimulator`, then train a `NuconGoalEnv` policy with SAC + HER. The simulator runs far faster than the real game, allowing many trajectories in reasonable time. Pass `Parameterized_Objectives['uncertainty_penalty']` and `Parameterized_Terminators['uncertainty_abort']` as additional objectives/terminators to discourage the policy from wandering into regions the model hasn't seen; `SIM_UNCERTAINTY` is automatically injected into the obs dict when a simulator is active. See [NuconGoalEnv + HER Usage](#nucongoalenv--her-usage).
+**Step 3 — Train RL in simulator**: Load the fitted model into `NuconSimulator`, then train a `NuconGoalEnv` policy with SAC + HER. The simulator runs far faster than the real game, allowing many trajectories in reasonable time. Pass `Parameterized_Objectives['uncertainty_penalty']` and `Parameterized_Terminators['uncertainty_abort']` as additional objectives/terminators to discourage the policy from wandering into regions the model hasn't seen; `SIM_UNCERTAINTY` is automatically injected into the obs dict when a simulator is active. See [NuconGoalEnv + HER Usage](#nucongoalenv--her-usage) and `scripts/train_sac.py` for a complete example.
 **Step 4 — Eval in game + collect new data**: Run the trained policy against the real game. This validates simulator accuracy and simultaneously collects new data from states the policy visits, which may be regions the original dataset missed. Run a second `NuconModelLearner` in a background thread to collect concurrently.
@@ -416,6 +416,25 @@ The recommended end-to-end workflow for training an RL operator is an iterative
 Stop when the policy performs well in the real game and kNN-GP uncertainty stays low throughout an episode, indicating the policy stays within the known data distribution.
+## Scripts
+Ready-to-run scripts in the `scripts/` directory covering the most common workflows.
+**`scripts/collect_dataset.py`** — collect a dynamics dataset while playing the game:
+```bash
+python scripts/collect_dataset.py --steps 1000 --delta 10 --out reactor_dataset.pkl
+# Ctrl-C to stop early; data is saved on exit
+# Merge a previous session: --merge previous.pkl
+```
+**`scripts/train_sac.py`** — train a SAC + HER goal-conditioned policy on the kNN-GP simulator:
+```bash
+python scripts/train_sac.py
+# Expects /tmp/reactor_knn.pkl and /tmp/nucon_dataset.pkl
+# Saves trained policy to /tmp/sac_nucon_knn.zip
+```
+This script is the most elaborate end-to-end example: it loads a pre-fitted kNN-GP model, seeds episode resets from dataset states, uses delta actions and an uncertainty penalty, and configures SAC + HER for fast sim training.
 ## Testing
 NuCon includes a test suite to verify its functionality and compatibility with the Nucleares game.
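The iterate-until-converged workflow in the README hunk above can be sketched as a toy driver loop. Everything below is a stand-in for illustration (none of these functions are part of the nucon API):

```python
# Toy driver for the workflow: collect -> fit -> train -> eval, repeated.
def collect_data(policy=None):
    # Step 1 (human play) or Step 4 (policy rollouts): record transitions
    return [({'T': 300.0}, {'rod': 1.0}, {'T': 301.0}, 10.0)]

def fit_model(dataset):
    # Step 2: refit the dynamics model on everything collected so far
    return {'n_samples': len(dataset)}

def train_policy(model):
    # Step 3: train SAC + HER in the simulator built from the model
    return f"policy@{model['n_samples']}"

def eval_uncertainty(policy):
    # Stopping proxy: max kNN-GP uncertainty seen during a real-game episode
    return 0.05

dataset, policy = [], None
for round_ in range(5):
    dataset += collect_data(policy)      # dataset accumulates across rounds
    model = fit_model(dataset)
    policy = train_policy(model)
    if eval_uncertainty(policy) < 0.1:   # stop once the policy stays in-distribution
        break
```

The point of the sketch is the control flow: the dataset only ever grows, and each round's policy steers where the next round's data comes from.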

nucon/model.py

@@ -98,11 +98,11 @@ class ReactorKNNModel:
         self._std = self._raw_states.std(axis=0) + 1e-8
         self._states = (self._raw_states - self._mean) / self._std
-    def _lookup(self, state_dict: Dict):
-        """Return (s_norm, idx, k) for the k nearest neighbours."""
-        s = np.array([state_dict[p] for p in self.input_params], dtype=np.float32)
+    def _lookup(self, s: np.ndarray):
+        """Return (s_norm, idx, k) for the k nearest neighbours. s is a raw (d_in,) array."""
         s_norm = (s - self._mean) / self._std
-        dists = np.linalg.norm(self._states - s_norm, axis=1)
+        diff = self._states - s_norm  # (n, d_in) broadcast
+        dists = np.einsum('ij,ij->i', diff, diff)  # squared L2, faster than linalg.norm
         k = min(self.k, len(dists))
         idx = np.argpartition(dists, k - 1)[:k]
         return s_norm, idx, k
@@ -122,22 +122,22 @@ class ReactorKNNModel:
         if self._states is None:
             raise ValueError("Model not fitted. Call fit(dataset) first.")
-        s_norm, idx, k = self._lookup(state_dict)
+        s = np.array([state_dict[p] for p in self.input_params], dtype=np.float32)
+        s_norm, idx, k = self._lookup(s)
         X = self._states[idx]  # (k, d_in)
         Y = self._rates[idx]   # (k, d_out)
-        # RBF kernel (vectorised): k(a,b) = exp(-0.5 ||a-b||^2)
-        def rbf_matrix(A, B):
-            diff = A[:, None, :] - B[None, :, :]  # (|A|, |B|, d)
-            return np.exp(-0.5 * (diff ** 2).sum(axis=-1))  # (|A|, |B|)
-        K = rbf_matrix(X, X) + 1e-4 * np.eye(k)  # (k, k)
-        k_star = rbf_matrix(s_norm[None, :], X)[0]  # (k,)
+        # RBF kernel: k(a,b) = exp(-0.5 ||a-b||^2)
+        def rbf(A, B):
+            diff = A[:, None, :] - B[None, :, :]
+            return np.exp(-0.5 * np.einsum('ijk,ijk->ij', diff, diff))
+        K = rbf(X, X) + 1e-4 * np.eye(k)
+        k_star = rbf(s_norm[None, :], X)[0]
         K_inv = np.linalg.inv(K)
-        mean_rates = k_star @ K_inv @ Y  # (d_out,)
-        # Posterior variance (scalar, shared across all output dims)
+        mean_rates = k_star @ K_inv @ Y
         var = max(0.0, 1.0 - float(k_star @ K_inv @ k_star))
         std = float(np.sqrt(var))
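The einsum change in `_lookup` is a pure optimisation: squaring is monotone, so `argpartition` over squared distances picks the same neighbours as the old `np.linalg.norm` version while skipping the square root. A quick standalone check (toy data, not from the repo):

```python
import numpy as np

rng = np.random.default_rng(0)
states = rng.normal(size=(500, 6))   # stand-in for the normalised dataset states
s_norm = rng.normal(size=6)          # stand-in for a normalised query point

diff = states - s_norm
sq = np.einsum('ij,ij->i', diff, diff)           # squared L2, no sqrt, no temporaries per row
norms = np.linalg.norm(states - s_norm, axis=1)  # the old formulation

k = 8
idx_sq = set(np.argpartition(sq, k - 1)[:k])
idx_norm = set(np.argpartition(norms, k - 1)[:k])
assert idx_sq == idx_norm                # same k-nearest-neighbour set
assert np.allclose(np.sqrt(sq), norms)   # same distances up to the monotone sqrt
```

The same `einsum` contraction pattern (`'ijk,ijk->ij'`) is what the new `rbf` helper uses to collapse the pairwise-difference tensor without materialising `diff ** 2`.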

nucon/rl.py

@@ -49,13 +49,15 @@ Parameterized_Terminators = {
 # Internal helpers
 # ---------------------------------------------------------------------------
-def _build_flat_action_space(nucon, obs_param_set=None):
-    """Return (Box, ordered_param_ids) for all writable, readable, non-cheat params.
-    If obs_param_set is provided, only include params in that set.
+def _build_flat_action_space(nucon, obs_param_set=None, delta_action_scale=None):
+    """Return (Box, ordered_param_ids, lows, ranges).
+    If delta_action_scale is set, the action space is [-1, 1]^n and actions are
+    treated as normalised deltas: actual_delta = action * delta_action_scale * (max - min).
+    Otherwise the action space spans [min_val, max_val] per param (absolute values).
     """
     params = []
-    lows, highs = [], []
+    lows, highs, ranges = [], [], []
     for param_id, param in nucon.get_all_writable().items():
         if not param.is_readable or param.is_cheat:
             continue
@@ -69,9 +71,15 @@ def _build_flat_action_space(nucon, obs_param_set=None):
         params.append(param_id)
         lows.append(sp.low[0])
         highs.append(sp.high[0])
+        ranges.append(sp.high[0] - sp.low[0])
-    box = spaces.Box(low=np.array(lows, dtype=np.float32),
-                     high=np.array(highs, dtype=np.float32), dtype=np.float32)
-    return box, params
+    if delta_action_scale is not None:
+        n = len(params)
+        box = spaces.Box(low=-np.ones(n, dtype=np.float32),
+                         high=np.ones(n, dtype=np.float32), dtype=np.float32)
+    else:
+        box = spaces.Box(low=np.array(lows, dtype=np.float32),
+                         high=np.array(highs, dtype=np.float32), dtype=np.float32)
+    return box, params, np.array(lows, dtype=np.float32), np.array(ranges, dtype=np.float32)
 def _unflatten_action(flat_action, param_ids):
@@ -96,12 +104,15 @@ def _build_param_space(param):
 def _apply_action(nucon, action):
     for param_id, value in action.items():
         param = nucon._parameters[param_id]
-        if issubclass(param.param_type, Enum):
-            value = param.param_type(int(np.asarray(value).flat[0]))
+        v = float(np.asarray(value).flat[0])
+        if param.param_type == bool:
+            value = v >= 0.5  # [0,1] space: above midpoint → True
+        elif issubclass(param.param_type, Enum):
+            value = param.param_type(int(v))
         else:
-            value = param.param_type(np.asarray(value).flat[0])
+            value = param.param_type(v)
         if param.min_val is not None and param.max_val is not None:
-            value = np.clip(value, param.min_val, param.max_val)
+            value = param.param_type(np.clip(value, param.min_val, param.max_val))
         nucon.set(param, value)
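The new dispatch order matters: `bool` must be tested before the numeric branch so a continuous action in [0, 1] snaps to a boolean at the midpoint. A standalone sketch of that coercion order (toy `Enum`, not a real NuCon parameter type; the helper itself is hypothetical):

```python
import numpy as np
from enum import Enum

class ValveMode(Enum):  # toy Enum standing in for a game parameter type
    CLOSED = 0
    OPEN = 1

def coerce(param_type, value, min_val=None, max_val=None):
    # Mirrors the dispatch in _apply_action: bool first, then Enum,
    # then plain numerics with an optional clip to [min_val, max_val].
    v = float(np.asarray(value).flat[0])
    if param_type is bool:
        return v >= 0.5                   # [0, 1] action -> bool at the midpoint
    if issubclass(param_type, Enum):
        return param_type(int(v))         # round-trip through the Enum
    out = param_type(v)
    if min_val is not None and max_val is not None:
        out = param_type(np.clip(out, min_val, max_val))
    return out

coerce(bool, 0.7)                 # True
coerce(ValveMode, 1.0)            # ValveMode.OPEN
coerce(float, 120.0, 0.0, 100.0)  # clipped to 100.0
```

Re-wrapping the clipped value in `param_type(...)` (the other fix in this hunk) keeps e.g. an `int` parameter an `int` instead of leaking a numpy scalar into `nucon.set()`.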
@@ -138,7 +149,8 @@ class NuconEnv(gym.Env):
             obs_spaces[param_id] = sp
         self.observation_space = spaces.Dict(obs_spaces)
-        self.action_space, self._action_params = _build_flat_action_space(self.nucon)
+        self.action_space, self._action_params, self._action_lows, self._action_ranges = \
+            _build_flat_action_space(self.nucon)
         self.objectives = []
         self.terminators = []
@@ -272,11 +284,14 @@ class NuconGoalEnv(gym.Env):
         additional_objectives=None,
         additional_objective_weights=None,
         obs_params=None,
+        init_states=None,
+        delta_action_scale=None,
     ):
         super().__init__()
         self.render_mode = render_mode
         self.seconds_per_step = seconds_per_step
+        self._delta_action_scale = delta_action_scale
         self.terminate_above = terminate_above
         self.simulator = simulator
         self.goal_params = list(goal_params)
@@ -339,12 +354,14 @@ class NuconGoalEnv(gym.Env):
         })
         # Action space: writable params within the obs param set (flat Box for SB3 compatibility).
-        self.action_space, self._action_params = _build_flat_action_space(self.nucon, set(base_params))
+        self.action_space, self._action_params, self._action_lows, self._action_ranges = \
+            _build_flat_action_space(self.nucon, set(base_params), delta_action_scale)
         self._terminators = terminators or []
         _objs = additional_objectives or []
         self._objectives = [Objectives[o] if isinstance(o, str) else o for o in _objs]
         self._objective_weights = additional_objective_weights or [1.0] * len(self._objectives)
+        self._init_states = init_states  # list of state dicts to sample on reset
         self._desired_goal = np.zeros(n_goals, dtype=np.float32)
         self._total_steps = 0
@@ -367,20 +384,37 @@ class NuconGoalEnv(gym.Env):
     def _read_obs(self, sim_uncertainty=None):
         """Return (gym_obs_dict, reward_obs_dict).
-        gym_obs_dict     flat Box observation for the policy (no SIM_UNCERTAINTY).
-        reward_obs_dict  same values plus SIM_UNCERTAINTY for objectives/terminators/reward_fn.
+        When a simulator is attached, reads directly from sim.parameters (no HTTP).
+        Otherwise falls back to a single batch HTTP request.
         """
+        def _to_float(v):
+            if v is None:
+                return 0.0
+            return float(v.value if isinstance(v, Enum) else v)
+        if self.simulator is not None:
+            # Direct in-process read — no HTTP overhead
+            def _get(pid):
+                return _to_float(self.simulator.get(pid))
+        else:
+            raw = self.nucon._batch_query(self._obs_params + self.goal_params)
+            all_params = self.nucon.get_all_readable()
+            def _get(pid):
+                try:
+                    v = self.nucon._parse_value(all_params[pid], raw.get(pid, '0'))
+                    return _to_float(v)
+                except Exception:
+                    return 0.0
         reward_obs = {}
         if self._obs_with_uncertainty:
             reward_obs['SIM_UNCERTAINTY'] = float(sim_uncertainty) if sim_uncertainty is not None else 0.0
-        for param_id in self._obs_params:
-            value = self.nucon.get(param_id)
-            if isinstance(value, Enum):
-                value = value.value
-            reward_obs[param_id] = float(value) if value is not None else 0.0
+        for pid in self._obs_params:
+            reward_obs[pid] = _get(pid)
         obs_vec = np.array([reward_obs[p] for p in self._obs_params], dtype=np.float32)
-        achieved = self._read_goal_values()
+        goal_raw = np.array([_get(p) for p in self.goal_params], dtype=np.float32)
+        achieved = np.clip((goal_raw - self._goal_low) / self._goal_range, 0.0, 1.0)
         gym_obs = {'observation': obs_vec, 'achieved_goal': achieved,
                    'desired_goal': self._desired_goal.copy()}
         return gym_obs, reward_obs
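The achieved-goal computation in this hunk is just `clip((raw - low) / range, 0, 1)`, mapping raw readings into the same normalised [0, 1] space the desired goal is sampled in. It can be sanity-checked in isolation (the goal bounds below are the `CORE_TEMP` range from the training script, used here purely as an illustration):

```python
import numpy as np

goal_low = np.array([55.0], dtype=np.float32)          # CORE_TEMP goal_range (55, 550)
goal_range = np.array([550.0 - 55.0], dtype=np.float32)

def normalise_goal(goal_raw):
    # Same shape as the hunk: map raw readings into [0, 1], saturating outside the range
    return np.clip((goal_raw - goal_low) / goal_range, 0.0, 1.0)

normalise_goal(np.array([302.5], dtype=np.float32))  # 0.5 — midpoint of the range
normalise_goal(np.array([600.0], dtype=np.float32))  # 1.0 — clipped at the top
```

Keeping achieved and desired goals in the same normalised space is what lets HER relabel goals without unit conversions.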
@@ -390,11 +424,47 @@ class NuconGoalEnv(gym.Env):
         self._total_steps = 0
         rng = np.random.default_rng(seed)
         self._desired_goal = rng.uniform(0.0, 1.0, size=len(self.goal_params)).astype(np.float32)
+        if self._init_states is not None and self.simulator is not None:
+            state = self._init_states[rng.integers(len(self._init_states))]
+            for k, v in state.items():
+                try:
+                    self.simulator.set(k, v, force=True)
+                except Exception:
+                    pass
         gym_obs, _ = self._read_obs()
         return gym_obs, {}

     def step(self, action):
-        _apply_action(self.nucon, _unflatten_action(action, self._action_params))
+        flat = np.asarray(action, dtype=np.float32)
+        if self._delta_action_scale is not None:
+            # Compute absolute values from deltas, reading current state directly if possible
+            absolute = {}
+            for i, pid in enumerate(self._action_params):
+                param = self.nucon._parameters[pid]
+                if param.param_type == bool:
+                    absolute[pid] = 1.0 if flat[i] > 0 else 0.0
+                else:
+                    if self.simulator is not None:
+                        v = self.simulator.get(pid)
+                        current = float(v.value if isinstance(v, Enum) else v) if v is not None else 0.0
+                    else:
+                        current = 0.0  # fallback; batch read not worth it for actions alone
+                    delta = float(flat[i]) * self._delta_action_scale * self._action_ranges[i]
+                    absolute[pid] = float(np.clip(current + delta,
+                                                  self._action_lows[i],
+                                                  self._action_lows[i] + self._action_ranges[i]))
+        else:
+            absolute = _unflatten_action(flat, self._action_params)
+        if self.simulator is not None:
+            # Write directly to sim — skip HTTP entirely
+            for pid, val in absolute.items():
+                try:
+                    self.simulator.set(pid, val, force=True)
+                except Exception:
+                    pass
+        else:
+            _apply_action(self.nucon, absolute)
         if self.simulator:
             uncertainty = self.simulator.update(self.seconds_per_step, return_uncertainty=True)
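Per parameter, the delta-action path above reduces to one line of arithmetic: `clip(current + action * scale * range, low, low + range)`. A standalone sketch with toy numbers (the helper is hypothetical, not repo code):

```python
import numpy as np

def delta_to_absolute(current, action, scale, low, rng):
    # actual_delta = action * delta_action_scale * (max - min); clip into [min, max]
    delta = action * scale * rng
    return float(np.clip(current + delta, low, low + rng))

# Toy parameter spanning [0, 100] with delta_action_scale=0.05 (max 5% of range per step):
delta_to_absolute(40.0, 1.0, 0.05, 0.0, 100.0)   # 45.0  (40 + 1.0 * 0.05 * 100)
delta_to_absolute(99.0, 1.0, 0.05, 0.0, 100.0)   # 100.0 (clipped at the top of the range)
delta_to_absolute(40.0, -0.5, 0.05, 0.0, 100.0)  # 37.5  (negative actions move down)
```

Bounding each step to a small fraction of the parameter's range is what makes the [-1, 1] action space well-conditioned for SAC regardless of each parameter's physical units.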

nucon/sim.py

@@ -261,14 +261,13 @@ class NuconSimulator:
             raise ValueError("Model not set. Please load a model using load_model() or set_model().")
         # Build state dict using only the params the model knows about
+        params = self.parameters
         state = {}
         for param_id in self.model.input_params:
-            value = getattr(self.parameters, param_id, None)
+            value = getattr(params, param_id, None)
             if isinstance(value, Enum):
                 value = value.value
-            if value is None:
-                value = 0.0  # fallback for params not initialised in sim state
-            state[param_id] = value
+            state[param_id] = 0.0 if value is None else value
         # Forward pass
         uncertainty = None
@@ -280,12 +279,9 @@ class NuconSimulator:
         else:
             next_state = self.model.forward(state, time_step)
-        # Update only the output params the model predicts
+        # Write outputs directly — bypass sim.set() type-checking overhead
         for param_id, value in next_state.items():
-            try:
-                self.set(param_id, value, force=True)
-            except (ValueError, KeyError):
-                pass  # ignore params that can't be set (type mismatch, unknown)
+            setattr(params, param_id, value)
         return uncertainty
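The pattern in this hunk — `getattr` with Enum/None handling on the way in, bare `setattr` on the way out — can be mimicked on a plain namespace. The parameter names and `Enum` below are toy stand-ins, not the real sim state:

```python
from enum import Enum
from types import SimpleNamespace

class PumpStatus(Enum):  # toy Enum standing in for a sim parameter type
    OFF = 0
    ON = 1

# Stand-in for sim.parameters: a plain attribute container
params = SimpleNamespace(CORE_TEMP=312.5, PUMP_STATUS=PumpStatus.ON, AUX_FLOW=None)

def read_state(params, input_params):
    # Read side: Enum -> .value, missing/uninitialised -> 0.0
    state = {}
    for pid in input_params:
        value = getattr(params, pid, None)
        if isinstance(value, Enum):
            value = value.value
        state[pid] = 0.0 if value is None else value
    return state

state = read_state(params, ['CORE_TEMP', 'PUMP_STATUS', 'AUX_FLOW'])
# {'CORE_TEMP': 312.5, 'PUMP_STATUS': 1, 'AUX_FLOW': 0.0}

# Write side: model outputs go straight onto the namespace, no validation layer
for pid, value in {'CORE_TEMP': 313.1}.items():
    setattr(params, pid, value)
```

Skipping the validating setter is safe here only because the values come from the model itself, which was fitted on already-validated parameter readings.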

scripts/collect_dataset.py (new file)

@@ -0,0 +1,55 @@
"""Collect a dynamics dataset from the running Nucleares game.
Play the game normally while this script runs in the background.
It records state transitions every `time_delta` game-seconds and
saves them incrementally so nothing is lost if you quit early.
Usage:
    python scripts/collect_dataset.py                        # default settings
    python scripts/collect_dataset.py --steps 2000 --delta 5 # faster sampling
    python scripts/collect_dataset.py --out my_dataset.pkl
The saved dataset is a list of (state_before, action_dict, state_after, time_delta)
tuples compatible with NuconModelLearner.fit_knn() and train_model().
Tips for good data:
- Cover a range of operating states: startup, ramp, steady-state, shutdown.
- Vary individual rod bank positions, pump speeds, and MSCV setpoints.
- Collect at least 500 samples for kNN-GP; 5000+ for the NN backend.
- Merge multiple sessions with NuconModelLearner.merge_datasets().
"""
import argparse
import pickle
from nucon.model import NuconModelLearner
parser = argparse.ArgumentParser()
parser.add_argument('--steps', type=int, default=1000,
                    help='Number of samples to collect (default: 1000)')
parser.add_argument('--delta', type=float, default=10.0,
                    help='Game-seconds between samples (default: 10.0)')
parser.add_argument('--out', default='reactor_dataset.pkl',
                    help='Output path for dataset (default: reactor_dataset.pkl)')
parser.add_argument('--merge', default=None,
                    help='Existing dataset to merge into before saving')
args = parser.parse_args()
learner = NuconModelLearner(
    time_delta=args.delta,
    dataset_path=args.out,
)
if args.merge:
    learner.merge_datasets(args.merge)
    print(f"Merged existing dataset from {args.merge} ({len(learner.dataset)} samples)")
print(f"Collecting {args.steps} samples (Δt={args.delta}s each) → {args.out}")
print("Play the game — vary rod positions, pump speeds, and operating states.")
print("Press Ctrl-C to stop early; data collected so far will be saved.")
try:
    learner.collect_data(num_steps=args.steps)
except KeyboardInterrupt:
    print("\nInterrupted — saving collected data...")
learner.save_dataset(args.out)
print(f"Saved {len(learner.dataset)} samples to {args.out}")
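The tuple format the docstring describes can be illustrated with a throwaway dataset; the parameter and action names below are made up for the example, and only the 4-tuple shape matters:

```python
import os
import pickle
import tempfile

# One transition: (state_before, action_dict, state_after, time_delta)
sample = (
    {'CORE_TEMP': 300.0},   # state_before
    {'RODS_POS': 55.0},     # action applied
    {'CORE_TEMP': 301.2},   # state_after, time_delta game-seconds later
    10.0,                   # time_delta
)

path = os.path.join(tempfile.mkdtemp(), 'reactor_dataset.pkl')
with open(path, 'wb') as f:
    pickle.dump([sample], f)

# Merging a previous session amounts to list concatenation before refitting
with open(path, 'rb') as f:
    merged = pickle.load(f) + [sample]
```

Because the format is a flat list of self-contained tuples, sessions collected on different days (or by different operators) concatenate without any alignment step.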

scripts/train_sac.py (new file, 83 lines)

@@ -0,0 +1,83 @@
"""SAC + HER training on kNN-GP simulator.
Usage:
    python scripts/train_sac.py

Requirements:
    - NuCon game running (for parameter metadata)
    - /tmp/reactor_knn.pkl (kNN-GP model)
    - /tmp/nucon_dataset.pkl (500-sample dataset for init_states)
"""
import pickle
from gymnasium.wrappers import TimeLimit
from stable_baselines3 import SAC
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
from nucon.sim import NuconSimulator
from nucon.rl import NuconGoalEnv, Parameterized_Objectives, Parameterized_Terminators
# ---------------------------------------------------------------------------
# Load model and dataset
# ---------------------------------------------------------------------------
with open('/tmp/reactor_knn.pkl', 'rb') as f:
knn_model = pickle.load(f)
with open('/tmp/nucon_dataset.pkl', 'rb') as f:
dataset = pickle.load(f)
# Seed resets to in-distribution states from dataset
init_states = [s for _, _, s, _ in dataset]
# ---------------------------------------------------------------------------
# Build sim + env
# ---------------------------------------------------------------------------
sim = NuconSimulator(port=8786)
sim.set_model(knn_model)
BATCH_SIZE = 2048
MAX_EPISODE_STEPS = 200
env = NuconGoalEnv(
    goal_params=['CORE_TEMP'],
    goal_range={'CORE_TEMP': (55.0, 550.0)},
    tolerance=0.05,
    seconds_per_step=10,
    simulator=sim,
    additional_objectives=[
        Parameterized_Objectives['uncertainty_penalty'](start=0.3),
    ],
    additional_objective_weights=[1.0],
    init_states=init_states,
    delta_action_scale=0.05,
)
env = TimeLimit(env, max_episode_steps=MAX_EPISODE_STEPS)
env = TimeLimit(env, max_episode_steps=MAX_EPISODE_STEPS)
# ---------------------------------------------------------------------------
# SAC + HER
# learning_starts = batch_size: wait for batch_size complete (short) episodes
# before the first gradient step. As the policy learns to stay in-dist, episodes
# will get longer and HER has more transitions to relabel.
# ---------------------------------------------------------------------------
model = SAC(
    'MultiInputPolicy',
    env,
    replay_buffer_class=HerReplayBuffer,
    replay_buffer_kwargs={
        'n_sampled_goal': 4,
        'goal_selection_strategy': 'future',
    },
    verbose=1,
    learning_rate=3e-4,
    batch_size=BATCH_SIZE,
    tau=0.005,
    gamma=0.98,
    train_freq=64,
    gradient_steps=8,
    learning_starts=BATCH_SIZE,
    device='auto',
)
model.learn(total_timesteps=50_000)
model.save('/tmp/sac_nucon_knn')
print("Saved to /tmp/sac_nucon_knn.zip")