feat: SAC+HER training on kNN-GP sim with direct bypass and scripts/

- nucon/rl.py: delta_action_scale action space, bool handling (>=0.5),
  direct sim read/write bypassing HTTP for ~2000fps env throughput;
  remove uncertainty_abort from training (use penalty-only), larger
  default batch sizes; fix _read_obs and step for in-process sim
- nucon/model.py: optimise _lookup with einsum squared-L2, vectorised
  rbf kernel; forward_with_uncertainty uses pre-built normalised arrays
- nucon/sim.py: _update_reactor_state writes outputs via setattr directly
- scripts/train_sac.py: moved from root; full SAC+HER example with kNN-GP
  sim, delta actions, uncertainty penalty, init_states
- scripts/collect_dataset.py: CLI tool to collect dynamics dataset from
  live game session (--steps, --delta, --out, --merge)
- README.md: add Scripts section, reference both scripts in training loop

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Author: Dominik Moritz Roth
Date: 2026-03-12 20:43:37 +01:00
commit 0932bb353a (parent 3dfe1aa673)
6 changed files with 271 additions and 48 deletions


@@ -404,11 +404,11 @@ The recommended end-to-end workflow for training an RL operator is an iterative
└─────────────────────┘
```
**Step 1 — Human dataset collection**: Run `NuconModelLearner.collect_data()` during your play session. Cover a wide range of states: startup from cold, ramping power, individual rod bank adjustments. Diversity in the dataset directly determines simulator accuracy. See [Model Learning](#model-learning) for collection details.
**Step 1 — Human dataset collection**: Run `scripts/collect_dataset.py` during your play session (see [Scripts](#scripts)). Cover a wide range of states: startup from cold, ramping power, individual rod bank adjustments. Diversity in the dataset directly determines simulator accuracy. See [Model Learning](#model-learning) for collection details.
**Step 2 — Initial model fitting**: Fit a kNN-GP model (instant) or NN (better extrapolation with larger datasets) using `fit_knn()` or `train_model()`. Prune near-duplicate samples with `drop_redundant()` before fitting. See [Model Learning](#model-learning).
**Step 3 — Train RL in simulator**: Load the fitted model into `NuconSimulator`, then train a `NuconGoalEnv` policy with SAC + HER. The simulator runs far faster than the real game, allowing many trajectories in reasonable time. Pass `Parameterized_Objectives['uncertainty_penalty']` and `Parameterized_Terminators['uncertainty_abort']` as additional objectives/terminators to discourage the policy from wandering into regions the model hasn't seen; `SIM_UNCERTAINTY` is automatically injected into the obs dict when a simulator is active. See [NuconGoalEnv + HER Usage](#nucongoalenv--her-usage).
**Step 3 — Train RL in simulator**: Load the fitted model into `NuconSimulator`, then train a `NuconGoalEnv` policy with SAC + HER. The simulator runs far faster than the real game, allowing many trajectories in reasonable time. Pass `Parameterized_Objectives['uncertainty_penalty']` and `Parameterized_Terminators['uncertainty_abort']` as additional objectives/terminators to discourage the policy from wandering into regions the model hasn't seen; `SIM_UNCERTAINTY` is automatically injected into the obs dict when a simulator is active. See [NuconGoalEnv + HER Usage](#nucongoalenv--her-usage) and `scripts/train_sac.py` for a complete example.
**Step 4 — Eval in game + collect new data**: Run the trained policy against the real game. This validates simulator accuracy and simultaneously collects new data from states the policy visits, which may be regions the original dataset missed. Run a second `NuconModelLearner` in a background thread to collect concurrently.
@@ -416,6 +416,25 @@ The recommended end-to-end workflow for training an RL operator is an iterative
Stop when the policy performs well in the real game and kNN-GP uncertainty stays low throughout an episode, indicating the policy stays within the known data distribution.
## Scripts
Ready-to-run scripts in the `scripts/` directory cover the most common workflows.
**`scripts/collect_dataset.py`** — collect a dynamics dataset while playing the game:
```bash
python scripts/collect_dataset.py --steps 1000 --delta 10 --out reactor_dataset.pkl
# Ctrl-C to stop early; data is saved on exit
# Merge a previous session: --merge previous.pkl
```
**`scripts/train_sac.py`** — train a SAC + HER goal-conditioned policy on the kNN-GP simulator:
```bash
python scripts/train_sac.py
# Expects /tmp/reactor_knn.pkl and /tmp/nucon_dataset.pkl
# Saves trained policy to /tmp/sac_nucon_knn.zip
```
This script is the most elaborate end-to-end example: it loads a pre-fitted kNN-GP model, seeds episode resets from dataset states, uses delta actions and an uncertainty penalty, and configures SAC + HER for fast sim training.
## Testing
NuCon includes a test suite to verify its functionality and compatibility with the Nucleares game.


@@ -98,11 +98,11 @@ class ReactorKNNModel:
self._std = self._raw_states.std(axis=0) + 1e-8
self._states = (self._raw_states - self._mean) / self._std
def _lookup(self, state_dict: Dict):
"""Return (s_norm, idx, k) for the k nearest neighbours."""
s = np.array([state_dict[p] for p in self.input_params], dtype=np.float32)
def _lookup(self, s: np.ndarray):
"""Return (s_norm, idx, k) for the k nearest neighbours. s is a raw (d_in,) array."""
s_norm = (s - self._mean) / self._std
dists = np.linalg.norm(self._states - s_norm, axis=1)
diff = self._states - s_norm # (n, d_in) broadcast
dists = np.einsum('ij,ij->i', diff, diff) # squared L2, faster than linalg.norm
k = min(self.k, len(dists))
idx = np.argpartition(dists, k - 1)[:k]
return s_norm, idx, k
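Reviewer note on the einsum change: dropping the square root is safe for neighbour selection because squaring is monotone on non-negative values, so `argpartition` picks the same indices either way. A self-contained numpy sketch (shapes and variable names here are illustrative, not from the codebase):

```python
import numpy as np

rng = np.random.default_rng(0)
states = rng.normal(size=(1000, 8)).astype(np.float32)  # (n, d_in) normalised states
s_norm = rng.normal(size=8).astype(np.float32)          # query point

diff = states - s_norm                      # (n, d_in) broadcast
sq = np.einsum('ij,ij->i', diff, diff)      # squared L2, no sqrt
full = np.linalg.norm(states - s_norm, axis=1)

k = 10
idx_sq = np.argpartition(sq, k - 1)[:k]
idx_full = np.argpartition(full, k - 1)[:k]
# Same k nearest neighbours either way (argpartition order may differ)
assert set(idx_sq) == set(idx_full)
assert np.allclose(np.sqrt(sq), full, atol=1e-4)
```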
@@ -122,22 +122,22 @@ class ReactorKNNModel:
if self._states is None:
raise ValueError("Model not fitted. Call fit(dataset) first.")
s_norm, idx, k = self._lookup(state_dict)
s = np.array([state_dict[p] for p in self.input_params], dtype=np.float32)
s_norm, idx, k = self._lookup(s)
X = self._states[idx] # (k, d_in)
Y = self._rates[idx] # (k, d_out)
# RBF kernel (vectorised): k(a,b) = exp(-0.5 ||a-b||^2)
def rbf_matrix(A, B):
diff = A[:, None, :] - B[None, :, :] # (|A|, |B|, d)
return np.exp(-0.5 * (diff ** 2).sum(axis=-1)) # (|A|, |B|)
# RBF kernel: k(a,b) = exp(-0.5 ||a-b||^2)
def rbf(A, B):
diff = A[:, None, :] - B[None, :, :]
return np.exp(-0.5 * np.einsum('ijk,ijk->ij', diff, diff))
K = rbf_matrix(X, X) + 1e-4 * np.eye(k) # (k, k)
k_star = rbf_matrix(s_norm[None, :], X)[0] # (k,)
K = rbf(X, X) + 1e-4 * np.eye(k)
k_star = rbf(s_norm[None, :], X)[0]
K_inv = np.linalg.inv(K)
mean_rates = k_star @ K_inv @ Y # (d_out,)
K_inv = np.linalg.inv(K)
mean_rates = k_star @ K_inv @ Y
# Posterior variance (scalar, shared across all output dims)
var = max(0.0, 1.0 - float(k_star @ K_inv @ k_star))
std = float(np.sqrt(var))
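For review context, this is standard GP regression with a unit-variance RBF prior: posterior mean `k_star @ K_inv @ Y` and posterior variance `1 - k_star @ K_inv @ k_star`, clamped at zero against round-off. A standalone numpy sketch with made-up shapes (not the module's actual data):

```python
import numpy as np

def rbf(A, B):
    diff = A[:, None, :] - B[None, :, :]                 # (|A|, |B|, d)
    return np.exp(-0.5 * np.einsum('ijk,ijk->ij', diff, diff))

rng = np.random.default_rng(1)
X = rng.normal(size=(16, 4))        # k neighbour states (normalised)
Y = rng.normal(size=(16, 3))        # their observed rates
s = rng.normal(size=4)              # query state (normalised)

K = rbf(X, X) + 1e-4 * np.eye(len(X))   # jitter keeps K invertible
k_star = rbf(s[None, :], X)[0]          # (16,)
K_inv = np.linalg.inv(K)
mean_rates = k_star @ K_inv @ Y          # (3,) posterior mean per output dim
# Since K_inv is positive definite, k_star @ K_inv @ k_star >= 0,
# so var always lands in [0, 1] after the clamp.
var = max(0.0, 1.0 - float(k_star @ K_inv @ k_star))
```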


@@ -49,13 +49,15 @@ Parameterized_Terminators = {
# Internal helpers
# ---------------------------------------------------------------------------
def _build_flat_action_space(nucon, obs_param_set=None):
"""Return (Box, ordered_param_ids) for all writable, readable, non-cheat params.
def _build_flat_action_space(nucon, obs_param_set=None, delta_action_scale=None):
"""Return (Box, ordered_param_ids, param_ranges).
If obs_param_set is provided, only include params in that set.
If delta_action_scale is set, the action space is [-1, 1]^n and actions are
treated as normalised deltas: actual_delta = action * delta_action_scale * (max - min).
Otherwise the action space spans [min_val, max_val] per param (absolute values).
"""
params = []
lows, highs = [], []
lows, highs, ranges = [], [], []
for param_id, param in nucon.get_all_writable().items():
if not param.is_readable or param.is_cheat:
continue
@@ -69,9 +71,15 @@ def _build_flat_action_space(nucon, obs_param_set=None):
params.append(param_id)
lows.append(sp.low[0])
highs.append(sp.high[0])
box = spaces.Box(low=np.array(lows, dtype=np.float32),
high=np.array(highs, dtype=np.float32), dtype=np.float32)
return box, params
ranges.append(sp.high[0] - sp.low[0])
if delta_action_scale is not None:
n = len(params)
box = spaces.Box(low=-np.ones(n, dtype=np.float32),
high=np.ones(n, dtype=np.float32), dtype=np.float32)
else:
box = spaces.Box(low=np.array(lows, dtype=np.float32),
high=np.array(highs, dtype=np.float32), dtype=np.float32)
return box, params, np.array(lows, dtype=np.float32), np.array(ranges, dtype=np.float32)
def _unflatten_action(flat_action, param_ids):
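The delta convention introduced above is easy to verify in isolation: a normalised action `a` in [-1, 1] moves a parameter by `a * delta_action_scale * (max - min)`, then clips to [min, max]. A sketch with illustrative numbers (not taken from any real parameter):

```python
import numpy as np

lows = np.array([0.0, 10.0], dtype=np.float32)        # per-param min
ranges = np.array([100.0, 40.0], dtype=np.float32)    # max - min per param
scale = 0.05                                          # delta_action_scale
current = np.array([50.0, 48.0], dtype=np.float32)    # current values
action = np.array([1.0, 1.0], dtype=np.float32)       # full positive step

delta = action * scale * ranges                       # [5.0, 2.0]
target = np.clip(current + delta, lows, lows + ranges)
# param 0: 50 + 5 = 55; param 1: 48 + 2 = 50, clipped at its upper bound
```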
@@ -96,12 +104,15 @@ def _build_param_space(param):
def _apply_action(nucon, action):
for param_id, value in action.items():
param = nucon._parameters[param_id]
if issubclass(param.param_type, Enum):
value = param.param_type(int(np.asarray(value).flat[0]))
v = float(np.asarray(value).flat[0])
if param.param_type == bool:
value = v >= 0.5 # [0,1] space: above midpoint → True
elif issubclass(param.param_type, Enum):
value = param.param_type(int(v))
else:
value = param.param_type(np.asarray(value).flat[0])
if param.min_val is not None and param.max_val is not None:
value = np.clip(value, param.min_val, param.max_val)
value = param.param_type(v)
if param.min_val is not None and param.max_val is not None:
value = param.param_type(np.clip(value, param.min_val, param.max_val))
nucon.set(param, value)
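Note the two bool conventions in play: the absolute [0, 1] action space thresholds at 0.5 here, while the delta [-1, 1] path in `step()` thresholds at 0. A minimal sketch of the rule, with a hypothetical helper name not present in the codebase:

```python
def coerce_bool(v: float, delta_mode: bool = False) -> bool:
    """Illustrative only: map a continuous action component to a bool
    under the two conventions used in this diff."""
    # Absolute [0, 1] space: at or above the midpoint means True.
    # Delta [-1, 1] space: any positive value means True.
    return bool(v > 0.0) if delta_mode else bool(v >= 0.5)

assert coerce_bool(0.5) is True          # absolute: midpoint rounds up
assert coerce_bool(0.49) is False
assert coerce_bool(-0.2, delta_mode=True) is False
assert coerce_bool(0.2, delta_mode=True) is True
```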
@@ -138,7 +149,8 @@ class NuconEnv(gym.Env):
obs_spaces[param_id] = sp
self.observation_space = spaces.Dict(obs_spaces)
self.action_space, self._action_params = _build_flat_action_space(self.nucon)
self.action_space, self._action_params, self._action_lows, self._action_ranges = \
_build_flat_action_space(self.nucon)
self.objectives = []
self.terminators = []
@@ -272,11 +284,14 @@ class NuconGoalEnv(gym.Env):
additional_objectives=None,
additional_objective_weights=None,
obs_params=None,
init_states=None,
delta_action_scale=None,
):
super().__init__()
self.render_mode = render_mode
self.seconds_per_step = seconds_per_step
self._delta_action_scale = delta_action_scale
self.terminate_above = terminate_above
self.simulator = simulator
self.goal_params = list(goal_params)
@@ -339,12 +354,14 @@ class NuconGoalEnv(gym.Env):
})
# Action space: writable params within the obs param set (flat Box for SB3 compatibility).
self.action_space, self._action_params = _build_flat_action_space(self.nucon, set(base_params))
self.action_space, self._action_params, self._action_lows, self._action_ranges = \
_build_flat_action_space(self.nucon, set(base_params), delta_action_scale)
self._terminators = terminators or []
_objs = additional_objectives or []
self._objectives = [Objectives[o] if isinstance(o, str) else o for o in _objs]
self._objective_weights = additional_objective_weights or [1.0] * len(self._objectives)
self._init_states = init_states # list of state dicts to sample on reset
self._desired_goal = np.zeros(n_goals, dtype=np.float32)
self._total_steps = 0
@@ -367,20 +384,37 @@ class NuconGoalEnv(gym.Env):
def _read_obs(self, sim_uncertainty=None):
"""Return (gym_obs_dict, reward_obs_dict).
gym_obs_dict flat Box observation for the policy (no SIM_UNCERTAINTY).
reward_obs_dict same values plus SIM_UNCERTAINTY for objectives/terminators/reward_fn.
When a simulator is attached, reads directly from sim.parameters (no HTTP).
Otherwise falls back to a single batch HTTP request.
"""
def _to_float(v):
if v is None:
return 0.0
return float(v.value if isinstance(v, Enum) else v)
if self.simulator is not None:
# Direct in-process read — no HTTP overhead
def _get(pid):
return _to_float(self.simulator.get(pid))
else:
raw = self.nucon._batch_query(self._obs_params + self.goal_params)
all_params = self.nucon.get_all_readable()
def _get(pid):
try:
v = self.nucon._parse_value(all_params[pid], raw.get(pid, '0'))
return _to_float(v)
except Exception:
return 0.0
reward_obs = {}
if self._obs_with_uncertainty:
reward_obs['SIM_UNCERTAINTY'] = float(sim_uncertainty) if sim_uncertainty is not None else 0.0
for param_id in self._obs_params:
value = self.nucon.get(param_id)
if isinstance(value, Enum):
value = value.value
reward_obs[param_id] = float(value) if value is not None else 0.0
for pid in self._obs_params:
reward_obs[pid] = _get(pid)
obs_vec = np.array([reward_obs[p] for p in self._obs_params], dtype=np.float32)
achieved = self._read_goal_values()
goal_raw = np.array([_get(p) for p in self.goal_params], dtype=np.float32)
achieved = np.clip((goal_raw - self._goal_low) / self._goal_range, 0.0, 1.0)
gym_obs = {'observation': obs_vec, 'achieved_goal': achieved,
'desired_goal': self._desired_goal.copy()}
return gym_obs, reward_obs
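The goal normalisation replacing `_read_goal_values()` maps each raw goal reading into [0, 1] against its configured range, matching the desired goal sampled uniformly in [0, 1] on reset. Using the CORE_TEMP range (55, 550) from `scripts/train_sac.py` as the worked example:

```python
import numpy as np

goal_low = np.array([55.0], dtype=np.float32)
goal_range = np.array([550.0 - 55.0], dtype=np.float32)   # 495

goal_raw = np.array([302.5], dtype=np.float32)            # midpoint reading
achieved = np.clip((goal_raw - goal_low) / goal_range, 0.0, 1.0)
# (302.5 - 55) / 495 = 0.5

too_cold = np.clip((np.array([20.0]) - goal_low) / goal_range, 0.0, 1.0)
# readings below the range clip to 0.0
```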
@@ -390,11 +424,47 @@ class NuconGoalEnv(gym.Env):
self._total_steps = 0
rng = np.random.default_rng(seed)
self._desired_goal = rng.uniform(0.0, 1.0, size=len(self.goal_params)).astype(np.float32)
if self._init_states is not None and self.simulator is not None:
state = self._init_states[rng.integers(len(self._init_states))]
for k, v in state.items():
try:
self.simulator.set(k, v, force=True)
except Exception:
pass
gym_obs, _ = self._read_obs()
return gym_obs, {}
def step(self, action):
_apply_action(self.nucon, _unflatten_action(action, self._action_params))
flat = np.asarray(action, dtype=np.float32)
if self._delta_action_scale is not None:
# Compute absolute values from deltas, reading current state directly if possible
absolute = {}
for i, pid in enumerate(self._action_params):
param = self.nucon._parameters[pid]
if param.param_type == bool:
absolute[pid] = 1.0 if flat[i] > 0 else 0.0
else:
if self.simulator is not None:
v = self.simulator.get(pid)
current = float(v.value if isinstance(v, Enum) else v) if v is not None else 0.0
else:
current = 0.0 # fallback; batch read not worth it for actions alone
delta = float(flat[i]) * self._delta_action_scale * self._action_ranges[i]
absolute[pid] = float(np.clip(current + delta,
self._action_lows[i],
self._action_lows[i] + self._action_ranges[i]))
else:
absolute = _unflatten_action(flat, self._action_params)
if self.simulator is not None:
# Write directly to sim — skip HTTP entirely
for pid, val in absolute.items():
try:
self.simulator.set(pid, val, force=True)
except Exception:
pass
else:
_apply_action(self.nucon, absolute)
if self.simulator:
uncertainty = self.simulator.update(self.seconds_per_step, return_uncertainty=True)


@@ -261,14 +261,13 @@ class NuconSimulator:
raise ValueError("Model not set. Please load a model using load_model() or set_model().")
# Build state dict using only the params the model knows about
params = self.parameters
state = {}
for param_id in self.model.input_params:
value = getattr(self.parameters, param_id, None)
value = getattr(params, param_id, None)
if isinstance(value, Enum):
value = value.value
if value is None:
value = 0.0 # fallback for params not initialised in sim state
state[param_id] = value
state[param_id] = 0.0 if value is None else value
# Forward pass
uncertainty = None
@@ -280,12 +279,9 @@ class NuconSimulator:
else:
next_state = self.model.forward(state, time_step)
# Update only the output params the model predicts
# Write outputs directly — bypass sim.set() type-checking overhead
for param_id, value in next_state.items():
try:
self.set(param_id, value, force=True)
except (ValueError, KeyError):
pass # ignore params that can't be set (type mismatch, unknown)
setattr(params, param_id, value)
return uncertainty


@@ -0,0 +1,55 @@
"""Collect a dynamics dataset from the running Nucleares game.
Play the game normally while this script runs in the background.
It records state transitions every `time_delta` game-seconds and
saves them incrementally so nothing is lost if you quit early.
Usage:
python scripts/collect_dataset.py # default settings
python scripts/collect_dataset.py --steps 2000 --delta 5 # faster sampling
python scripts/collect_dataset.py --out my_dataset.pkl
The saved dataset is a list of (state_before, action_dict, state_after, time_delta)
tuples compatible with NuconModelLearner.fit_knn() and train_model().
Tips for good data:
- Cover a range of operating states: startup, ramp, steady-state, shutdown.
- Vary individual rod bank positions, pump speeds, and MSCV setpoints.
- Collect at least 500 samples for kNN-GP; 5000+ for the NN backend.
- Merge multiple sessions with NuconModelLearner.merge_datasets().
"""
import argparse
import pickle
from nucon.model import NuconModelLearner
parser = argparse.ArgumentParser()
parser.add_argument('--steps', type=int, default=1000,
help='Number of samples to collect (default: 1000)')
parser.add_argument('--delta', type=float, default=10.0,
help='Game-seconds between samples (default: 10.0)')
parser.add_argument('--out', default='reactor_dataset.pkl',
help='Output path for dataset (default: reactor_dataset.pkl)')
parser.add_argument('--merge', default=None,
help='Existing dataset to merge into before saving')
args = parser.parse_args()
learner = NuconModelLearner(
time_delta=args.delta,
dataset_path=args.out,
)
if args.merge:
learner.merge_datasets(args.merge)
print(f"Merged existing dataset from {args.merge} ({len(learner.dataset)} samples)")
print(f"Collecting {args.steps} samples (Δt={args.delta}s each) → {args.out}")
print("Play the game — vary rod positions, pump speeds, and operating states.")
print("Press Ctrl-C to stop early; data collected so far will be saved.")
try:
learner.collect_data(num_steps=args.steps)
except KeyboardInterrupt:
print("\nInterrupted — saving collected data...")
learner.save_dataset(args.out)
print(f"Saved {len(learner.dataset)} samples to {args.out}")

scripts/train_sac.py (new file, 83 lines)

@@ -0,0 +1,83 @@
"""SAC + HER training on kNN-GP simulator.
Usage:
python3.14 scripts/train_sac.py
Requirements:
- NuCon game running (for parameter metadata)
- /tmp/reactor_knn.pkl (kNN-GP model)
- /tmp/nucon_dataset.pkl (500-sample dataset for init_states)
"""
import pickle
from gymnasium.wrappers import TimeLimit
from stable_baselines3 import SAC
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
from nucon.sim import NuconSimulator
from nucon.rl import NuconGoalEnv, Parameterized_Objectives, Parameterized_Terminators
# ---------------------------------------------------------------------------
# Load model and dataset
# ---------------------------------------------------------------------------
with open('/tmp/reactor_knn.pkl', 'rb') as f:
knn_model = pickle.load(f)
with open('/tmp/nucon_dataset.pkl', 'rb') as f:
dataset = pickle.load(f)
# Seed resets to in-distribution states from dataset
init_states = [s for _, _, s, _ in dataset]
# ---------------------------------------------------------------------------
# Build sim + env
# ---------------------------------------------------------------------------
sim = NuconSimulator(port=8786)
sim.set_model(knn_model)
BATCH_SIZE = 2048
MAX_EPISODE_STEPS = 200
env = NuconGoalEnv(
goal_params=['CORE_TEMP'],
goal_range={'CORE_TEMP': (55.0, 550.0)},
tolerance=0.05,
seconds_per_step=10,
simulator=sim,
additional_objectives=[
Parameterized_Objectives['uncertainty_penalty'](start=0.3),
],
additional_objective_weights=[1.0],
init_states=init_states,
delta_action_scale=0.05,
)
env = TimeLimit(env, max_episode_steps=MAX_EPISODE_STEPS)
# ---------------------------------------------------------------------------
# SAC + HER
# learning_starts = BATCH_SIZE: collect BATCH_SIZE environment steps before the
# first gradient step. As the policy learns to stay in-distribution, episodes
# get longer and HER has more transitions to relabel.
# ---------------------------------------------------------------------------
model = SAC(
'MultiInputPolicy',
env,
replay_buffer_class=HerReplayBuffer,
replay_buffer_kwargs={
'n_sampled_goal': 4,
'goal_selection_strategy': 'future',
},
verbose=1,
learning_rate=3e-4,
batch_size=BATCH_SIZE,
tau=0.005,
gamma=0.98,
train_freq=64,
gradient_steps=8,
learning_starts=BATCH_SIZE,
device='auto',
)
model.learn(total_timesteps=50_000)
model.save('/tmp/sac_nucon_knn')
print("Saved to /tmp/sac_nucon_knn.zip")