feat: improve NN dynamics model and SAC training

- ReactorDynamicsNet: add dropout (0.3) for regularisation
- ReactorDynamicsModel: z-score normalisation of inputs/outputs, predict
  per-second rates of change, MC-dropout forward_with_uncertainty()
- rl.py: misc SAC training improvements
- sim.py: minor fixes
- train_sac.py: updated training loop

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Dominik Moritz Roth 2026-03-15 11:18:15 +01:00
parent 88f4896086
commit 646399dcc7
4 changed files with 246 additions and 61 deletions

View File

@ -18,13 +18,15 @@ Actors = {
# --- NN-based dynamics model ---
class ReactorDynamicsNet(nn.Module):
    """MLP backbone for the reactor dynamics model.

    Maps a state vector of size input_dim, concatenated with a scalar
    time delta, to output_dim values. Dropout both regularises the net
    and enables MC-dropout uncertainty sampling.
    """

    def __init__(self, input_dim, output_dim, dropout=0.3):
        super(ReactorDynamicsNet, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim + 1, 128),  # +1 for the time_delta feature
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, output_dim),
        )
@ -33,23 +35,81 @@ class ReactorDynamicsNet(nn.Module):
return self.network(x)
class ReactorDynamicsModel(nn.Module):
    """
    NN dynamics model predicting per-second rates of change (like ReactorKNNModel).

    Inputs are z-score normalised; outputs are normalised rates.
    forward() returns an absolute next-state dict: cur + predicted_rate * time_delta.
    forward_with_uncertainty() additionally returns an MC-dropout uncertainty
    estimate in [0, 1] (0 = confident, ~1 = high variance / out-of-distribution).
    """

    def __init__(self, input_params: List[str], output_params: List[str]):
        super(ReactorDynamicsModel, self).__init__()
        self.input_params = input_params
        self.output_params = output_params
        self.net = ReactorDynamicsNet(len(input_params), len(output_params))
        # Normalisation stats, filled in by fit_normalisation(); identity until then.
        self.register_buffer('_in_mean', torch.zeros(len(input_params)))
        self.register_buffer('_in_std', torch.ones(len(input_params)))
        self.register_buffer('_rate_mean', torch.zeros(len(output_params)))
        self.register_buffer('_rate_std', torch.ones(len(output_params)))

    def _state_dict_to_tensor(self, state_dict):
        # Order follows input_params; a missing key raises KeyError by design.
        return torch.tensor([state_dict[p] for p in self.input_params], dtype=torch.float32)

    def fit_normalisation(self, dataset):
        """Compute and store z-score stats from (state, action, next_state, dt) samples.

        Raises ValueError when no sample has a positive time delta, instead of
        failing later with an opaque numpy shape error.
        """
        in_vecs, rate_vecs = [], []
        for state, _action, next_state, dt in dataset:
            if dt <= 0:  # zero/negative dt gives undefined rates
                continue
            in_vecs.append([state.get(p, 0.0) for p in self.input_params])
            rate_vecs.append([(next_state.get(p, 0.0) - state.get(p, 0.0)) / dt
                              for p in self.output_params])
        if not in_vecs:
            raise ValueError("fit_normalisation: no samples with positive time delta")
        ins = np.array(in_vecs, dtype=np.float32)
        rates = np.array(rate_vecs, dtype=np.float32)
        in_std = ins.std(0)
        r_std = rates.std(0)
        # Guard near-constant dimensions: std ~ 0 would blow up the division.
        self._in_mean.copy_(torch.from_numpy(ins.mean(0)))
        self._in_std.copy_(torch.from_numpy(np.where(in_std < 1e-6, 1.0, in_std)))
        self._rate_mean.copy_(torch.from_numpy(rates.mean(0)))
        self._rate_std.copy_(torch.from_numpy(np.where(r_std < 1e-6, 1.0, r_std)))

    def _tensor_to_state_dict(self, tensor):
        return {p: tensor[i].item() for i, p in enumerate(self.output_params)}

    def _normalise_input(self, t: torch.Tensor) -> torch.Tensor:
        return (t - self._in_mean) / self._in_std

    def _denormalise_rate(self, t: torch.Tensor) -> torch.Tensor:
        return t * self._rate_std + self._rate_mean

    def forward(self, state_dict, time_delta):
        """Predict the absolute next state after time_delta seconds."""
        return self.forward_with_uncertainty(state_dict, time_delta)[0]

    def forward_with_uncertainty(self, state_dict, time_delta, mc_samples=3):
        """MC-dropout uncertainty: run mc_samples stochastic forward passes.

        Uncertainty is the mean normalised std across output dims, clipped to [0, 1].
        0 = very confident (low variance), ~1 = high variance / OOD.
        """
        s = torch.tensor([state_dict.get(p, 0.0) for p in self.input_params],
                         dtype=torch.float32).unsqueeze(0)
        s_norm = self._normalise_input(s)
        dt_t = torch.tensor([[time_delta]], dtype=torch.float32)
        # Keep dropout active for uncertainty sampling, then restore the
        # caller's train/eval mode instead of forcing eval().
        was_training = self.net.training
        self.net.train()
        with torch.no_grad():
            samples = torch.stack([self.net(s_norm, dt_t).squeeze(0)
                                   for _ in range(mc_samples)])  # (mc_samples, out_dim)
        self.net.train(was_training)
        rate_norm_mean = samples.mean(0)
        rate_norm_std = samples.std(0)
        rate = self._denormalise_rate(rate_norm_mean)
        cur = torch.tensor([state_dict.get(p, 0.0) for p in self.output_params],
                           dtype=torch.float32)
        predicted = cur + rate * time_delta
        pred_dict = {p: float(predicted[i]) for i, p in enumerate(self.output_params)}
        # Uncertainty: mean per-dim std in normalised rate space, clipped to [0, 1].
        uncertainty = float(rate_norm_std.mean().clamp(0.0, 1.0))
        return pred_dict, uncertainty
# --- kNN-based dynamics model ---
@ -150,6 +210,37 @@ class ReactorKNNModel:
pred_dict = {p: float(predicted[i]) for i, p in enumerate(self.output_params)}
return pred_dict, std
# --- Mixture model ---
class MixtureModel:
    """Blends two dynamics models according to the kNN model's uncertainty.

    The kNN prediction dominates when its uncertainty is low (it is near its
    training data); the NN prediction takes over proportionally as the kNN
    becomes out-of-distribution. Both models must implement
    forward_with_uncertainty(state_dict, time_delta).

    input_params / output_params are taken from knn_model.
    """

    def __init__(self, knn_model, nn_model):
        self.knn_model = knn_model
        self.nn_model = nn_model
        self.input_params = knn_model.input_params
        self.output_params = knn_model.output_params

    def forward(self, state_dict, time_delta):
        """Blended next-state prediction (without the uncertainty estimate)."""
        return self.forward_with_uncertainty(state_dict, time_delta)[0]

    def forward_with_uncertainty(self, state_dict, time_delta):
        """Return (blended_next_state_dict, blended_uncertainty)."""
        knn_pred, knn_u = self.knn_model.forward_with_uncertainty(state_dict, time_delta)
        nn_pred, nn_u = self.nn_model.forward_with_uncertainty(state_dict, time_delta)
        w_knn = 1.0 - knn_u  # high when kNN is confident
        w_nn = knn_u         # high when kNN is OOD
        blended = {p: w_knn * knn_pred[p] + w_nn * nn_pred[p]
                   for p in self.output_params}
        # Uncertainty is blended with the same weights as the predictions.
        uncertainty = w_knn * knn_u + w_nn * nn_u
        return blended, uncertainty
# --- Learner ---
class NuconModelLearner:
@ -266,13 +357,14 @@ class NuconModelLearner:
self.save_dataset()
print(f"Collection complete. {collected} steps, {len(self.dataset)} total samples.")
def train_model(self, batch_size=32, num_epochs=10, test_split=0.2):
def train_model(self, batch_size=32, num_epochs=10, test_split=0.2, lr=1e-3):
"""Train a neural-network dynamics model on the current dataset."""
if self.model is None:
self.model = ReactorDynamicsModel(self.readable_params, self.non_writable_params)
self.optimizer = optim.Adam(self.model.parameters())
elif not isinstance(self.model, ReactorDynamicsModel):
raise ValueError("A kNN model is already loaded. Create a new learner to train an NN.")
self.model.fit_normalisation(self.dataset)
self.optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)
random.shuffle(self.dataset)
split_idx = int(len(self.dataset) * (1 - test_split))
train_data = self.dataset[:split_idx]
@ -365,37 +457,45 @@ class NuconModelLearner:
print(f"drop_redundant: kept {len(self.dataset)}, dropped {dropped} samples.")
def _train_epoch(self, data, batch_size):
    """Run one training epoch over `data`; return the mean per-batch loss.

    Samples are (state, action, next_state, dt) tuples; targets are
    per-second rates of change, trained in normalised rate space.
    Samples with non-positive dt are skipped (rate undefined).
    """
    self.model.train()
    total_loss = 0.0
    n_batches = 0
    for i in range(0, len(data), batch_size):
        batch = [s for s in data[i:i + batch_size] if s[3] > 0]
        if not batch:
            continue
        states = torch.tensor([[s[0].get(p, 0.0) for p in self.readable_params] for s in batch], dtype=torch.float32)
        targets = torch.tensor([[(s[2].get(p, 0.0) - s[0].get(p, 0.0)) / s[3] for p in self.non_writable_params] for s in batch], dtype=torch.float32)
        dts = torch.tensor([[s[3]] for s in batch], dtype=torch.float32)
        s_norm = self.model._normalise_input(states)
        rate_norm_target = (targets - self.model._rate_mean) / self.model._rate_std
        self.optimizer.zero_grad()
        rate_norm_pred = self.model.net(s_norm, dts)
        loss = torch.nn.functional.mse_loss(rate_norm_pred, rate_norm_target)
        loss.backward()
        self.optimizer.step()
        total_loss += loss.item()
        n_batches += 1
    self.model.eval()
    return total_loss / max(1, n_batches)
def _test_epoch(self, data):
    """Evaluate mean normalised-rate MSE over `data` without gradient updates.

    Mirrors _train_epoch's target construction; samples with non-positive
    dt are skipped. Returns 0.0-safe mean loss.
    """
    self.model.eval()  # dropout off for deterministic evaluation
    total_loss = 0.0
    n = 0
    with torch.no_grad():
        for state, _, next_state, dt in data:
            if dt <= 0:
                continue
            s_t = torch.tensor([[state.get(p, 0.0) for p in self.readable_params]], dtype=torch.float32)
            s_norm = self.model._normalise_input(s_t)
            dt_t = torch.tensor([[dt]], dtype=torch.float32)
            rate_norm_pred = self.model.net(s_norm, dt_t).squeeze(0)
            target = torch.tensor([(next_state.get(p, 0.0) - state.get(p, 0.0)) / dt
                                   for p in self.non_writable_params], dtype=torch.float32)
            rate_norm_target = (target - self.model._rate_mean) / self.model._rate_std
            total_loss += torch.nn.functional.mse_loss(rate_norm_pred, rate_norm_target).item()
            n += 1
    return total_loss / max(1, n)
def save_model(self, path):
if self.model is None:
@ -439,6 +539,10 @@ class NuconModelLearner:
def merge_datasets(self, other_dataset_path):
    """Append samples from another dataset file into this learner's dataset.

    Raises ValueError when the file holds something other than a list of
    samples (e.g. a pickled model), rather than silently skipping it.
    """
    other_dataset = self.load_dataset(other_dataset_path)
    if other_dataset is None:
        return  # nothing to merge (e.g. missing file)
    if not isinstance(other_dataset, list):
        raise ValueError(
            f"'{other_dataset_path}' does not contain a dataset (got {type(other_dataset).__name__}). "
            f"Pass a dataset .pkl file, not a model file."
        )
    if other_dataset:
        self.dataset.extend(other_dataset)
        self.save_dataset()

View File

@ -12,10 +12,18 @@ from nucon import Nucon, BreakerStatus, PumpStatus, PumpDryStatus, PumpOverloadS
# Reward / objective helpers
# ---------------------------------------------------------------------------
def _alarm_penalty(obs):
    """Penalty proportional to the number of active alarms.

    Only meaningful when running against the real game (the simulator does
    not populate ALARMS_ACTIVE).
    """
    raw = obs.get('ALARMS_ACTIVE', '')
    if raw and raw.strip():
        # Alarms arrive as a comma-separated string; one penalty point each.
        return -float(len(raw.split(',')))
    return 0.0
# Simple (un-parameterised) objective functions: obs dict -> scalar reward.
Objectives = {
    "null": lambda obs: 0,
    "max_power": lambda obs: sum(obs[g] for g in ("GENERATOR_0_KW", "GENERATOR_1_KW", "GENERATOR_2_KW")),
    "episode_time": lambda obs: obs["EPISODE_TIME"],
    "alarm_penalty": _alarm_penalty,
}
def _uncertainty_penalty(start=0.3, scale=1.0, mode='l2'):
@ -35,6 +43,8 @@ Parameterized_Objectives = {
"target_gap": lambda goal_gap: lambda obs: -((obs["CORE_TEMP"] - obs["CORE_TEMP_MIN"] - goal_gap) ** 2),
"temp_below": lambda max_temp: lambda obs: -(np.clip(obs["CORE_TEMP"] - max_temp, 0, np.inf) ** 2),
"temp_above": lambda min_temp: lambda obs: -(np.clip(min_temp - obs["CORE_TEMP"], 0, np.inf) ** 2),
"temp_below_linear": lambda max_temp: lambda obs: -np.clip(obs["CORE_TEMP"] - max_temp, 0, np.inf),
"temp_above_linear": lambda min_temp: lambda obs: -np.clip(min_temp - obs["CORE_TEMP"], 0, np.inf),
"constant": lambda constant: lambda obs: constant,
"uncertainty_penalty": _uncertainty_penalty, # (start, scale, mode) -> (obs) -> float
}
@ -284,8 +294,10 @@ class NuconGoalEnv(gym.Env):
additional_objectives=None,
additional_objective_weights=None,
obs_params=None,
action_params=None,
init_states=None,
delta_action_scale=None,
goal_sampling_std=None,
):
super().__init__()
@ -353,15 +365,17 @@ class NuconGoalEnv(gym.Env):
'desired_goal': spaces.Box(low=0.0, high=1.0, shape=(n_goals,), dtype=np.float32),
})
# Action space: writable params within the obs param set (flat Box for SB3 compatibility).
# Action space: writable params within the obs param set, or an explicit override list.
action_set = set(action_params) if action_params is not None else set(base_params)
self.action_space, self._action_params, self._action_lows, self._action_ranges = \
_build_flat_action_space(self.nucon, set(base_params), delta_action_scale)
_build_flat_action_space(self.nucon, action_set, delta_action_scale)
self._terminators = terminators or []
_objs = additional_objectives or []
self._objectives = [Objectives[o] if isinstance(o, str) else o for o in _objs]
self._objective_weights = additional_objective_weights or [1.0] * len(self._objectives)
self._init_states = init_states # list of state dicts to sample on reset
self._goal_sampling_std = goal_sampling_std # Gaussian std in normalised goal space; None → uniform
self._desired_goal = np.zeros(n_goals, dtype=np.float32)
self._total_steps = 0
@ -423,7 +437,6 @@ class NuconGoalEnv(gym.Env):
super().reset(seed=seed)
self._total_steps = 0
rng = np.random.default_rng(seed)
self._desired_goal = rng.uniform(0.0, 1.0, size=len(self.goal_params)).astype(np.float32)
if self._init_states is not None and self.simulator is not None:
state = self._init_states[rng.integers(len(self._init_states))]
for k, v in state.items():
@ -431,13 +444,28 @@ class NuconGoalEnv(gym.Env):
self.simulator.set(k, v, force=True)
except Exception:
pass
if self._goal_sampling_std is not None:
# Sample goal as Gaussian delta from current state — usually a small change,
# occasionally a large one.
current = np.array([
float(self.simulator.get(p) if self.simulator else 0.0)
for p in self.goal_params
], dtype=np.float32)
current_norm = np.clip((current - self._goal_low) / self._goal_range, 0.0, 1.0)
delta = rng.normal(0.0, self._goal_sampling_std, size=len(self.goal_params))
self._desired_goal = np.clip(current_norm + delta, 0.0, 1.0).astype(np.float32)
else:
self._desired_goal = rng.uniform(0.0, 1.0, size=len(self.goal_params)).astype(np.float32)
gym_obs, _ = self._read_obs()
return gym_obs, {}
def step(self, action):
flat = np.asarray(action, dtype=np.float32)
if self._delta_action_scale is not None:
# Compute absolute values from deltas, reading current state directly if possible
# Compute absolute values from deltas, reading current state
if self.simulator is None:
raw_current = self.nucon._batch_query(self._action_params)
all_params = self.nucon.get_all_readable()
absolute = {}
for i, pid in enumerate(self._action_params):
param = self.nucon._parameters[pid]
@ -448,7 +476,11 @@ class NuconGoalEnv(gym.Env):
v = self.simulator.get(pid)
current = float(v.value if isinstance(v, Enum) else v) if v is not None else 0.0
else:
current = 0.0 # fallback; batch read not worth it for actions alone
try:
v = self.nucon._parse_value(all_params[pid], raw_current.get(pid, '0'))
current = float(v.value if isinstance(v, Enum) else v)
except Exception:
current = 0.0
delta = float(flat[i]) * self._delta_action_scale * self._action_ranges[i]
absolute[pid] = float(np.clip(current + delta,
self._action_lows[i],

View File

@ -271,10 +271,7 @@ class NuconSimulator:
# Forward pass
uncertainty = None
if isinstance(self.model, ReactorDynamicsModel):
with torch.no_grad():
next_state = self.model.forward(state, time_step)
elif return_uncertainty:
if return_uncertainty:
next_state, uncertainty = self.model.forward_with_uncertainty(state, time_step)
else:
next_state = self.model.forward(state, time_step)

View File

@ -11,26 +11,43 @@ Requirements:
"""
import argparse
import json
import os
import pickle

import torch
from gymnasium.wrappers import TimeLimit
from stable_baselines3 import SAC
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
from stable_baselines3.common.callbacks import CheckpointCallback

from nucon.sim import NuconSimulator
from nucon.model import ReactorDynamicsModel, MixtureModel
from nucon.rl import NuconGoalEnv, Parameterized_Objectives, Parameterized_Terminators
# Command-line interface: training length, output paths, and which dynamics
# model(s) drive the simulator (.pkl = pickled kNN, .pt = NN checkpoint).
parser = argparse.ArgumentParser()
parser.add_argument('--load', default=None, help='Path to existing model to hot-start from')
parser.add_argument('--steps', type=int, default=50_000, help='Total timesteps (default: 50000)')
parser.add_argument('--out', default='/tmp/sac_nucon_knn', help='Output path for saved model')
parser.add_argument('--model', default='/tmp/reactor_knn.pkl', help='Dynamics model (.pkl for kNN, .pt for NN)')
parser.add_argument('--model2', default=None, help='Second dynamics model for mixture (optional)')
parser.add_argument('--dataset', default='/tmp/nucon_dataset.pkl', help='Dataset for init states')
args = parser.parse_args()
# ---------------------------------------------------------------------------
# Load dynamics model(s) and dataset
# ---------------------------------------------------------------------------
def _load_model(path):
    """Load a dynamics model: .pt -> NN checkpoint, anything else -> pickled kNN."""
    if path.endswith('.pt'):
        # map_location='cpu' lets CUDA-saved checkpoints load on CPU-only hosts.
        ckpt = torch.load(path, map_location='cpu', weights_only=False)
        m = ReactorDynamicsModel(ckpt['input_params'], ckpt['output_params'])
        m.load_state_dict(ckpt['state_dict'])
        m.eval()
        return m
    # NOTE: pickle (and weights_only=False above) executes arbitrary code on
    # load — only point this at trusted, locally produced model files.
    with open(path, 'rb') as f:
        return pickle.load(f)
# Load the dynamics model (optionally blended into a mixture) and the
# dataset used to seed episode-reset states.
dynamics_model = _load_model(args.model)
if args.model2:
    dynamics_model = MixtureModel(dynamics_model, _load_model(args.model2))

with open(args.dataset, 'rb') as f:
    dataset = pickle.load(f)
# Seed resets to in-distribution states from dataset
@ -40,23 +57,40 @@ init_states = [s for _, _, s, _ in dataset]
# Build sim + env
# ---------------------------------------------------------------------------
sim = NuconSimulator(port=8786)
sim.set_model(dynamics_model)

BATCH_SIZE = 2048
MAX_EPISODE_STEPS = 200

GENERATORS = ['GENERATOR_0_KW', 'GENERATOR_1_KW', 'GENERATOR_2_KW']
POWER_RANGE = {g: (0.0, 100_000.0) for g in GENERATORS}  # per-generator kW; ~100 MW upper bound

# Curated obs: physically relevant features for power control (~25 dims vs ~260 full)
OBS_PARAMS = [
    'CORE_TEMP', 'CORE_PRESSURE', 'CORE_STATE_CRITICALITY', 'CORE_WEAR', 'CORE_INTEGRITY',
    'ROD_BANK_POS_0_ACTUAL', 'ROD_BANK_POS_0_ORDERED',
    'COOLANT_CORE_FLOW_SPEED', 'COOLANT_CORE_VESSEL_TEMPERATURE',
    'COOLANT_CORE_PRESSURE', 'COOLANT_CORE_QUANTITY_IN_VESSEL',
    'STEAM_TURBINE_0_RPM', 'STEAM_TURBINE_0_TEMPERATURE', 'STEAM_TURBINE_0_PRESSURE',
    'STEAM_TURBINE_1_RPM', 'STEAM_TURBINE_1_TEMPERATURE', 'STEAM_TURBINE_1_PRESSURE',
    'STEAM_TURBINE_2_RPM', 'STEAM_TURBINE_2_TEMPERATURE', 'STEAM_TURBINE_2_PRESSURE',
    'GENERATOR_0_V', 'GENERATOR_1_V', 'GENERATOR_2_V',
]

env = NuconGoalEnv(
    goal_params=GENERATORS,
    goal_range=POWER_RANGE,
    seconds_per_step=10,
    simulator=sim,
    obs_params=OBS_PARAMS,
    additional_objectives=[
        Parameterized_Objectives['uncertainty_penalty'](start=0.3),
        Parameterized_Objectives['temp_below_linear'](max_temp=420),
    ],
    additional_objective_weights=[1.0, 0.01],  # one weight per objective above
    init_states=init_states,
    delta_action_scale=0.05,
    goal_sampling_std=0.15,  # Gaussian delta in normalised space (~180 kW typical)
)
env = TimeLimit(env, max_episode_steps=MAX_EPISODE_STEPS)
@ -73,7 +107,8 @@ if args.load:
custom_objects={'learning_rate': 3e-4, 'batch_size': BATCH_SIZE,
'tau': 0.005, 'gamma': 0.98,
'train_freq': 64, 'gradient_steps': 8,
'learning_starts': 0})
'learning_starts': MAX_EPISODE_STEPS,
'ent_coef': 0.1})
else:
model = SAC(
'MultiInputPolicy',
@ -91,9 +126,26 @@ else:
train_freq=64,
gradient_steps=8,
learning_starts=BATCH_SIZE,
ent_coef=0.1, # fixed; auto-tuning diverges on this many action dims
device='auto',
)
# Periodically checkpoint during training.
checkpoint_cb = CheckpointCallback(
    save_freq=10_000,
    save_path=args.out + '_checkpoints/',
    name_prefix='sac',
)

# Persist the obs-param layout alongside the model so evaluation scripts can
# rebuild a matching env.
config = {'obs_params': OBS_PARAMS}
for save_dir in [args.out + '_checkpoints/', os.path.dirname(args.out) or '.']:
    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, 'config.json'), 'w') as f:
        json.dump(config, f)

model.learn(total_timesteps=args.steps, callback=checkpoint_cb)
model.save(args.out)
with open(args.out + '.json', 'w') as f:
    json.dump(config, f)
print(f"Saved to {args.out}.zip")