NuCon/nucon/model.py
Dominik Roth 646399dcc7 feat: improve NN dynamics model and SAC training
- ReactorDynamicsNet: add dropout (0.3) for regularisation
- ReactorDynamicsModel: z-score normalisation of inputs/outputs, predict
  per-second rates of change, forward_with_uncertainty() stub
- rl.py: misc SAC training improvements
- sim.py: minor fixes
- train_sac.py: updated training loop

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-15 11:18:15 +01:00

import numpy as np
import time
import torch
import torch.nn as nn
import torch.optim as optim
import random
from enum import Enum
from nucon import Nucon
import pickle
import os
from typing import Union, Tuple, List, Dict


Actors = {
    # Random actor: uniform value for every writable param with known bounds, else 0.
    'random': lambda nucon: lambda obs: {
        param.id: random.uniform(param.min_val, param.max_val)
        if param.min_val is not None and param.max_val is not None else 0
        for param in nucon.get_all_writable().values()
    },
    # Null actor: never writes anything (pure observation).
    'null': lambda nucon: lambda obs: {},
}
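
# An "actor" maps an observation dict to a {param_id: value} dict of writes.
# Sketch of a custom bang-bang actor (hypothetical param names, for illustration only):
#
#     def bang_bang(nucon):
#         def act(obs):
#             # assumes a writable 'RODS_POS_SETPOINT' and readable 'CORE_TEMP'
#             return {'RODS_POS_SETPOINT': 100.0 if obs.get('CORE_TEMP', 0) > 350 else 0.0}
#         return act
#
#     Actors['bang_bang'] = bang_bang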


# --- NN-based dynamics model ---
class ReactorDynamicsNet(nn.Module):
    def __init__(self, input_dim, output_dim, dropout=0.3):
        super(ReactorDynamicsNet, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim + 1, 128),  # +1 for time_delta
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, output_dim)
        )

    def forward(self, state, time_delta):
        x = torch.cat([state, time_delta], dim=-1)
        return self.network(x)


class ReactorDynamicsModel(nn.Module):
    """
    NN dynamics model predicting per-second rates of change (like ReactorKNNModel).
    Inputs are z-score normalised; outputs are normalised rates.
    forward() returns absolute next-state dict: cur + predicted_rate * time_delta.
    forward_with_uncertainty() estimates uncertainty via MC-Dropout over
    mc_samples stochastic forward passes.
    """

    def __init__(self, input_params: List[str], output_params: List[str]):
        super(ReactorDynamicsModel, self).__init__()
        self.input_params = input_params
        self.output_params = output_params
        self.net = ReactorDynamicsNet(len(input_params), len(output_params))
        # Normalisation stats set by fit_normalisation()
        self.register_buffer('_in_mean', torch.zeros(len(input_params)))
        self.register_buffer('_in_std', torch.ones(len(input_params)))
        self.register_buffer('_rate_mean', torch.zeros(len(output_params)))
        self.register_buffer('_rate_std', torch.ones(len(output_params)))

    def fit_normalisation(self, dataset):
        """Compute and store normalisation stats from a dataset."""
        in_vecs, rate_vecs = [], []
        for state, _action, next_state, dt in dataset:
            if dt <= 0:
                continue
            in_vecs.append([state.get(p, 0.0) for p in self.input_params])
            rate_vecs.append([(next_state.get(p, 0.0) - state.get(p, 0.0)) / dt
                              for p in self.output_params])
        ins = np.array(in_vecs, dtype=np.float32)
        rates = np.array(rate_vecs, dtype=np.float32)
        in_std = ins.std(0)
        r_std = rates.std(0)
        self._in_mean.copy_(torch.from_numpy(ins.mean(0)))
        self._in_std.copy_(torch.from_numpy(np.where(in_std < 1e-6, 1.0, in_std)))
        self._rate_mean.copy_(torch.from_numpy(rates.mean(0)))
        self._rate_std.copy_(torch.from_numpy(np.where(r_std < 1e-6, 1.0, r_std)))

    def _normalise_input(self, t: torch.Tensor) -> torch.Tensor:
        return (t - self._in_mean) / self._in_std

    def _denormalise_rate(self, t: torch.Tensor) -> torch.Tensor:
        return t * self._rate_std + self._rate_mean

    def forward(self, state_dict, time_delta):
        return self.forward_with_uncertainty(state_dict, time_delta)[0]

    def forward_with_uncertainty(self, state_dict, time_delta, mc_samples=3):
        """MC-Dropout uncertainty: run mc_samples stochastic forward passes.
        Uncertainty is the mean normalised std across output dims, clipped to [0, 1].
        0 = very confident (low variance), ~1 = high variance / OOD.
        """
        s = torch.tensor([state_dict.get(p, 0.0) for p in self.input_params],
                         dtype=torch.float32).unsqueeze(0)
        s_norm = self._normalise_input(s)
        dt_t = torch.tensor([[time_delta]], dtype=torch.float32)
        # Keep dropout active for uncertainty sampling
        self.net.train()
        with torch.no_grad():
            samples = torch.stack([self.net(s_norm, dt_t).squeeze(0)
                                   for _ in range(mc_samples)])  # (mc_samples, out_dim)
        self.net.eval()
        rate_norm_mean = samples.mean(0)
        rate_norm_std = samples.std(0)
        rate = self._denormalise_rate(rate_norm_mean)
        cur = torch.tensor([state_dict.get(p, 0.0) for p in self.output_params],
                           dtype=torch.float32)
        predicted = cur + rate * time_delta
        pred_dict = {p: float(predicted[i]) for i, p in enumerate(self.output_params)}
        # Uncertainty: mean std of the MC samples in normalised rate space, clipped to [0, 1]
        uncertainty = float(rate_norm_std.mean().clamp(0.0, 1.0))
        return pred_dict, uncertainty
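
# Usage sketch (hypothetical parameter names, assuming a dataset of
# (state, action, next_state, dt) tuples like the one NuconModelLearner collects):
#
#     model = ReactorDynamicsModel(['CORE_TEMP', 'CORE_PRESSURE'], ['CORE_TEMP'])
#     model.fit_normalisation(dataset)  # set z-score stats before any training
#     pred, u = model.forward_with_uncertainty(
#         {'CORE_TEMP': 310.0, 'CORE_PRESSURE': 70.0}, time_delta=2.0)
#     # pred['CORE_TEMP'] == 310.0 + predicted_rate * 2.0; u lies in [0, 1]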


# --- kNN-based dynamics model ---
class ReactorKNNModel:
    """
    Non-parametric dynamics model using k-nearest neighbours.
    For a query (state, game_delta):
    1. Find the k dataset entries whose *state* is closest (L2 in normalised space).
    2. For each neighbour compute the per-second rate-of-change:
           rate_i = (next_state_i - state_i) / game_delta_i
    3. Linearly scale to the requested game_delta:
           predicted_delta_i = rate_i * game_delta
    4. Return a kernel-weighted combination of those predicted deltas (the GP
       posterior mean over the k neighbours) added to the current output state.
    The linear-in-time assumption means two datapoints at 0.5 s and 2 s contribute
    equally once normalised by their own game_delta.
    """

    def __init__(self, input_params: List[str], output_params: List[str], k: int = 5):
        self.input_params = input_params
        self.output_params = output_params
        self.k = k
        self._states = None      # (n, d_in) normalised state matrix
        self._rates = None       # (n, d_out) (next_out - cur_out) / game_delta
        self._raw_states = None  # unnormalised, for mean/std computation
        self._mean = None
        self._std = None

    def fit(self, dataset):
        """Build lookup tables from a collected dataset."""
        raw, rates = [], []
        for state, _action, next_state, game_delta in dataset:
            if game_delta <= 0:
                continue
            s = np.array([state[p] for p in self.input_params], dtype=np.float32)
            cur = np.array([state[p] for p in self.output_params], dtype=np.float32)
            nxt = np.array([next_state[p] for p in self.output_params], dtype=np.float32)
            raw.append(s)
            rates.append((nxt - cur) / game_delta)
        self._raw_states = np.array(raw)
        self._rates = np.array(rates)
        self._mean = self._raw_states.mean(axis=0)
        raw_std = self._raw_states.std(axis=0)
        # Dimensions with zero variance in the training data carry no distance information.
        # Use inf so they contribute 0 to normalised L2 (i.e., are ignored in kNN lookup).
        self._std = np.where(raw_std < 1e-6, np.inf, raw_std)
        self._states = (self._raw_states - self._mean) / self._std

    def _lookup(self, s: np.ndarray):
        """Return (s_norm, idx, k) for the k nearest neighbours. s is a raw (d_in,) array."""
        s_norm = (s - self._mean) / self._std
        diff = self._states - s_norm               # (n, d_in) broadcast
        dists = np.einsum('ij,ij->i', diff, diff)  # squared L2, faster than linalg.norm
        k = min(self.k, len(dists))
        idx = np.argpartition(dists, k - 1)[:k]
        return s_norm, idx, k

    def forward(self, state_dict: Dict, time_delta: float) -> Dict:
        if self._states is None:
            raise ValueError("Model not fitted. Call fit(dataset) first.")
        return self.forward_with_uncertainty(state_dict, time_delta)[0]

    def forward_with_uncertainty(self, state_dict: Dict, time_delta: float):
        """Return (prediction_dict, uncertainty_scalar).
        Uncertainty is the GP posterior std in normalised input space:
            0 = query lies exactly on a training point (fully confident)
           ~1 = query is far from all neighbours (maximally uncertain)
        """
        if self._states is None:
            raise ValueError("Model not fitted. Call fit(dataset) first.")
        s = np.array([state_dict[p] for p in self.input_params], dtype=np.float32)
        s_norm, idx, k = self._lookup(s)
        X = self._states[idx]  # (k, d_in)
        Y = self._rates[idx]   # (k, d_out)

        # RBF kernel: k(a,b) = exp(-0.5 ||a-b||^2)
        def rbf(A, B):
            diff = A[:, None, :] - B[None, :, :]
            return np.exp(-0.5 * np.einsum('ijk,ijk->ij', diff, diff))

        K = rbf(X, X) + 1e-4 * np.eye(k)  # jitter for numerical stability
        k_star = rbf(s_norm[None, :], X)[0]
        K_inv = np.linalg.inv(K)
        mean_rates = k_star @ K_inv @ Y
        # GP posterior variance with unit prior variance: var = 1 - k*^T K^-1 k*
        var = max(0.0, 1.0 - float(k_star @ K_inv @ k_star))
        std = float(np.sqrt(var))
        cur_out = np.array([state_dict[p] for p in self.output_params], dtype=np.float32)
        predicted = cur_out + mean_rates * time_delta
        pred_dict = {p: float(predicted[i]) for i, p in enumerate(self.output_params)}
        return pred_dict, std
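

# ReactorKNNModel usage sketch (hypothetical names; dataset entries are
# (state, action, next_state, game_delta) tuples):
#
#     knn = ReactorKNNModel(['CORE_TEMP', 'CORE_PRESSURE'], ['CORE_TEMP'], k=5)
#     knn.fit(dataset)
#     pred, std = knn.forward_with_uncertainty(current_state, time_delta=1.0)
#     # std near 0: interpolating near training data; near 1: extrapolating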


# --- Mixture model ---
class MixtureModel:
    """Blends two dynamics models based on kNN uncertainty.
    The kNN prediction is weighted by (1 - knn_u) and the NN prediction by
    knn_u, so knn_model dominates when it is confident (near training data)
    and nn_model takes over when the kNN is OOD.
    Both models must implement forward_with_uncertainty(state_dict, time_delta).
    input_params / output_params are taken from knn_model.
    """

    def __init__(self, knn_model, nn_model):
        self.knn_model = knn_model
        self.nn_model = nn_model
        self.input_params = knn_model.input_params
        self.output_params = knn_model.output_params

    def forward(self, state_dict, time_delta):
        return self.forward_with_uncertainty(state_dict, time_delta)[0]

    def forward_with_uncertainty(self, state_dict, time_delta):
        knn_pred, knn_u = self.knn_model.forward_with_uncertainty(state_dict, time_delta)
        nn_pred, nn_u = self.nn_model.forward_with_uncertainty(state_dict, time_delta)
        w_knn = 1.0 - knn_u  # high when kNN is confident
        w_nn = knn_u         # high when kNN is OOD
        blended = {p: w_knn * knn_pred[p] + w_nn * nn_pred[p]
                   for p in self.output_params}
        uncertainty = w_knn * knn_u + w_nn * nn_u  # blend-weighted uncertainty
        return blended, uncertainty
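

# Composition sketch (assumes both models were fitted on the same param lists):
#
#     mix = MixtureModel(knn_model=knn, nn_model=nn_dyn)
#     pred, u = mix.forward_with_uncertainty(current_state, time_delta=1.0)
#     # near training data: pred tracks the kNN; far from it: pred tracks the NN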


# --- Learner ---
class NuconModelLearner:
    def __init__(self, nucon=None, actor='null', dataset_path='nucon_dataset.pkl',
                 time_delta: Union[float, Tuple[float, float]] = 1.0,
                 include_valve_states: bool = False):
        self.nucon = Nucon() if nucon is None else nucon
        self.actor = Actors[actor](self.nucon) if actor in Actors else actor
        self.dataset = self.load_dataset(dataset_path) or []
        self.dataset_path = dataset_path
        self.include_valve_states = include_valve_states
        self.model = None
        self.optimizer = None
        # Exclude params with no physics signal
        _JUNK_PARAMS = frozenset({'GAME_VERSION', 'TIME', 'TIME_STAMP', 'TIME_DAY',
                                  'ALARMS_ACTIVE', 'FUN_IS_ENABLED', 'GAME_SIM_SPEED'})
        candidate_params = {k: p for k, p in self.nucon.get_all_readable().items()
                            if k not in _JUNK_PARAMS and p.param_type != str}
        # Filter out params that return None (subsystem not installed).
        # Retry until the game is reachable.
        import requests as _requests
        while True:
            try:
                test_state = {k: self.nucon.get(k) for k in candidate_params}
                break
            except (_requests.exceptions.ConnectionError,
                    _requests.exceptions.Timeout):
                print("Waiting for game to be reachable…")
                time.sleep(5)
        self.readable_params = [k for k in candidate_params if test_state[k] is not None]
        self.non_writable_params = [k for k in self.readable_params
                                    if not self.nucon.get_all_readable()[k].is_writable]
        # Optionally include valve positions (input only; valves are externally driven)
        self.valve_keys = []
        if include_valve_states:
            valves = self.nucon.get_valves()
            self.valve_keys = [f'VALVE__{name}' for name in sorted(valves.keys())]
            # valve positions are input-only (not predicted as outputs)
            self.readable_params = self.readable_params + self.valve_keys
        if isinstance(time_delta, (int, float)):
            self.time_delta = lambda: time_delta
        elif isinstance(time_delta, tuple) and len(time_delta) == 2:
            self.time_delta = lambda: random.uniform(*time_delta)
        else:
            raise ValueError("time_delta must be a float or a tuple of two floats")

    def _get_state(self):
        state = {}
        for param_id in self.readable_params:
            if param_id in self.valve_keys:
                continue  # filled below
            value = self.nucon.get(param_id)
            if isinstance(value, Enum):
                value = value.value
            state[param_id] = value
        if self.valve_keys:
            valves = self.nucon.get_valves()
            for key in self.valve_keys:
                name = key[len('VALVE__'):]
                state[key] = valves.get(name, {}).get('Value', 0.0)
        return state

    def collect_data(self, num_steps, save_every=10):
        """
        Collect state-transition tuples from the live game.
        Sleeps wall_time = target_game_delta / sim_speed so that each stored
        game_delta is uniform regardless of the game's simulation speed setting.
        Saves the dataset every ``save_every`` steps so a crash doesn't lose
        everything. On a connection error the step is skipped and collection
        resumes once the game is reachable again (retries every 5 s).
        """
        import requests as _requests

        def get_state_with_retry():
            while True:
                try:
                    return self._get_state()
                except (_requests.exceptions.ConnectionError,
                        _requests.exceptions.Timeout) as e:
                    print(f"Connection lost ({e}). Retrying in 5 s…")
                    time.sleep(5)

        state = get_state_with_retry()
        collected = 0
        for i in range(num_steps):
            action = self.actor(state)
            for param_id, value in action.items():
                try:
                    self.nucon.set(param_id, value)
                except Exception:
                    pass
            target_game_delta = self.time_delta()
            try:
                sim_speed = self.nucon.GAME_SIM_SPEED.value or 1.0
            except Exception:
                sim_speed = 1.0
            time.sleep(target_game_delta / sim_speed)
            next_state = get_state_with_retry()
            self.dataset.append((state, action, next_state, target_game_delta))
            state = next_state
            collected += 1
            if collected % save_every == 0:
                self.save_dataset()
                print(f" {collected}/{num_steps} steps collected, dataset saved.")
        self.save_dataset()
        print(f"Collection complete. {collected} steps, {len(self.dataset)} total samples.")

    def train_model(self, batch_size=32, num_epochs=10, test_split=0.2, lr=1e-3):
        """Train a neural-network dynamics model on the current dataset."""
        if self.model is None:
            self.model = ReactorDynamicsModel(self.readable_params, self.non_writable_params)
        elif not isinstance(self.model, ReactorDynamicsModel):
            raise ValueError("A kNN model is already loaded. Create a new learner to train an NN.")
        self.model.fit_normalisation(self.dataset)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)
        random.shuffle(self.dataset)
        split_idx = int(len(self.dataset) * (1 - test_split))
        train_data = self.dataset[:split_idx]
        test_data = self.dataset[split_idx:]
        for epoch in range(num_epochs):
            train_loss = self._train_epoch(train_data, batch_size)
            test_loss = self._test_epoch(test_data)
            print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")

    def fit_knn(self, k: int = 5):
        """Fit a kNN/GP dynamics model from the current dataset (instantaneous, no gradient steps)."""
        if self.model is None:
            self.model = ReactorKNNModel(self.readable_params, self.non_writable_params, k=k)
        elif not isinstance(self.model, ReactorKNNModel):
            raise ValueError("An NN model is already loaded. Create a new learner to fit a kNN.")
        self.model.fit(self.dataset)
        print(f"kNN model fitted on {len(self.dataset)} samples.")

    def predict_with_uncertainty(self, state_dict: Dict, time_delta: float):
        """Return (prediction_dict, uncertainty_std). Only available after fit_knn()."""
        if not isinstance(self.model, ReactorKNNModel):
            raise ValueError("predict_with_uncertainty() requires a fitted kNN model (call fit_knn()).")
        return self.model.forward_with_uncertainty(state_dict, time_delta)

    def drop_well_fitted(self, error_threshold: float):
        """Drop samples the current model already predicts well (summed squared
        error over output params < threshold). Keeps only hard/surprising
        transitions. Useful for NN training to focus capacity on difficult
        regions of state space.
        """
        if self.model is None:
            raise ValueError("No model fitted yet. Call train_model() or fit_knn() first.")
        kept = []
        for state, action, next_state, time_delta in self.dataset:
            pred = self.model.forward(state, time_delta)
            error = sum((pred[p] - next_state[p]) ** 2 for p in self.non_writable_params)
            if error > error_threshold:
                kept.append((state, action, next_state, time_delta))
        dropped = len(self.dataset) - len(kept)
        self.dataset = kept
        self.save_dataset()
        print(f"drop_well_fitted: kept {len(kept)}, dropped {dropped} samples.")

    def drop_redundant(self, min_state_distance: float, min_output_distance: float = 0.0):
        """Drop near-duplicate samples, keeping only those that add coverage.
        A sample is dropped only if *both* its input state and its output
        transition are within the given distances of an already-kept sample
        (L2 in z-scored space). If two samples share the same input state but
        have different transitions they represent genuinely different dynamics
        and are both kept regardless of `min_output_distance`.
        Args:
            min_state_distance: minimum L2 distance in z-scored input space.
            min_output_distance: minimum L2 distance in z-scored output-delta
                space. Defaults to 0 (only input distance matters).
        """
        if not self.dataset:
            return
        in_params = [p for p in self.readable_params if p not in self.valve_keys]
        out_params = self.non_writable_params
        all_states = np.array([[s[p] for p in in_params] for s, *_ in self.dataset], dtype=np.float32)
        all_deltas = np.array([[ns[p] - s[p] for p in out_params]
                               for s, _, ns, gd in self.dataset], dtype=np.float32)
        s_mean, s_std = all_states.mean(0), all_states.std(0) + 1e-8
        d_mean, d_std = all_deltas.mean(0), all_deltas.std(0) + 1e-8
        s_norm = (all_states - s_mean) / s_std
        d_norm = (all_deltas - d_mean) / d_std
        kept_idx = [0]
        kept_s = [s_norm[0]]
        kept_d = [d_norm[0]]
        for i in range(1, len(self.dataset)):
            s_dists = np.linalg.norm(np.array(kept_s) - s_norm[i], axis=1)
            d_dists = np.linalg.norm(np.array(kept_d) - d_norm[i], axis=1)
            close = s_dists < min_state_distance
            if min_output_distance > 0:
                # Drop only if close in BOTH spaces
                close = close & (d_dists < min_output_distance)
            # else: only input distance matters (the documented default)
            if not np.any(close):
                kept_idx.append(i)
                kept_s.append(s_norm[i])
                kept_d.append(d_norm[i])
        dropped = len(self.dataset) - len(kept_idx)
        self.dataset = [self.dataset[i] for i in kept_idx]
        self.save_dataset()
        print(f"drop_redundant: kept {len(self.dataset)}, dropped {dropped} samples.")

    def _train_epoch(self, data, batch_size):
        self.model.train()
        total_loss = 0
        n_batches = 0
        for i in range(0, len(data), batch_size):
            batch = [s for s in data[i:i+batch_size] if s[3] > 0]
            if not batch:
                continue
            states = torch.tensor([[s[0].get(p, 0.0) for p in self.readable_params]
                                   for s in batch], dtype=torch.float32)
            targets = torch.tensor([[(s[2].get(p, 0.0) - s[0].get(p, 0.0)) / s[3]
                                     for p in self.non_writable_params]
                                    for s in batch], dtype=torch.float32)
            dts = torch.tensor([[s[3]] for s in batch], dtype=torch.float32)
            s_norm = self.model._normalise_input(states)
            rate_norm_pred = self.model.net(s_norm, dts)
            rate_norm_target = (targets - self.model._rate_mean) / self.model._rate_std
            self.optimizer.zero_grad()
            loss = torch.nn.functional.mse_loss(rate_norm_pred, rate_norm_target)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        self.model.eval()
        return total_loss / max(1, n_batches)

    def _test_epoch(self, data):
        self.model.eval()  # ensure dropout is disabled for evaluation
        total_loss = 0.0
        n = 0
        with torch.no_grad():
            for state, _, next_state, dt in data:
                if dt <= 0:
                    continue
                s_t = torch.tensor([[state.get(p, 0.0) for p in self.readable_params]],
                                   dtype=torch.float32)
                s_norm = self.model._normalise_input(s_t)
                dt_t = torch.tensor([[dt]], dtype=torch.float32)
                rate_norm_pred = self.model.net(s_norm, dt_t).squeeze(0)
                target = torch.tensor([(next_state.get(p, 0.0) - state.get(p, 0.0)) / dt
                                       for p in self.non_writable_params], dtype=torch.float32)
                rate_norm_target = (target - self.model._rate_mean) / self.model._rate_std
                total_loss += torch.nn.functional.mse_loss(rate_norm_pred, rate_norm_target).item()
                n += 1
        return total_loss / max(1, n)

    def save_model(self, path):
        if self.model is None:
            raise ValueError("No model to save. Call train_model() or fit_knn() first.")
        if isinstance(self.model, ReactorDynamicsModel):
            torch.save({
                'state_dict': self.model.state_dict(),
                'input_params': self.model.input_params,
                'output_params': self.model.output_params,
            }, path)
        else:
            with open(path, 'wb') as f:
                pickle.dump(self.model, f)

    def load_model(self, path):
        if path.endswith('.pkl'):
            with open(path, 'rb') as f:
                self.model = pickle.load(f)
        else:
            checkpoint = torch.load(path, weights_only=False)
            if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                m = ReactorDynamicsModel(checkpoint['input_params'], checkpoint['output_params'])
                m.load_state_dict(checkpoint['state_dict'])
                self.model = m
            else:
                # legacy plain state dict
                self.model = ReactorDynamicsModel(self.readable_params, self.non_writable_params)
                self.model.load_state_dict(checkpoint)

    def save_dataset(self, path=None):
        path = path or self.dataset_path
        with open(path, 'wb') as f:
            pickle.dump(self.dataset, f)

    def load_dataset(self, path=None):
        path = path or self.dataset_path
        if os.path.exists(path):
            with open(path, 'rb') as f:
                return pickle.load(f)
        return None

    def merge_datasets(self, other_dataset_path):
        other_dataset = self.load_dataset(other_dataset_path)
        if not isinstance(other_dataset, list):
            raise ValueError(
                f"'{other_dataset_path}' does not contain a dataset (got {type(other_dataset).__name__}). "
                f"Pass a dataset .pkl file, not a model file."
            )
        self.dataset.extend(other_dataset)
        self.save_dataset()
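

# End-to-end workflow sketch (hypothetical session; requires a reachable live game):
#
#     learner = NuconModelLearner(actor='random', time_delta=(0.5, 2.0))
#     learner.collect_data(num_steps=500)          # builds nucon_dataset.pkl
#     learner.drop_redundant(min_state_distance=0.1)
#     learner.fit_knn(k=5)                         # or learner.train_model()
#     pred, u = learner.predict_with_uncertainty(learner._get_state(), 1.0)
#     learner.save_model('reactor_knn.pkl')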