- ReactorDynamicsNet: add dropout (0.3) for regularisation - ReactorDynamicsModel: z-score normalisation of inputs/outputs, predict per-second rates of change, forward_with_uncertainty() stub - rl.py: misc SAC training improvements - sim.py: minor fixes - train_sac.py: updated training loop Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
549 lines
24 KiB
Python
549 lines
24 KiB
Python
import numpy as np
|
|
import time
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
import random
|
|
from enum import Enum
|
|
from nucon import Nucon
|
|
import pickle
|
|
import os
|
|
from typing import Union, Tuple, List, Dict
|
|
|
|
def _random_actor(nucon):
    """Factory: actor that samples every writable parameter uniformly.

    Parameters missing either bound fall back to 0.
    """
    def act(obs):
        action = {}
        for param in nucon.get_all_writable().values():
            if param.min_val is not None and param.max_val is not None:
                action[param.id] = random.uniform(param.min_val, param.max_val)
            else:
                action[param.id] = 0
        return action
    return act


def _null_actor(nucon):
    """Factory: actor that never writes any parameter."""
    def act(obs):
        return {}
    return act


# Registry of built-in actor factories: name -> (nucon -> (obs -> action dict)).
Actors = {
    'random': _random_actor,
    'null': _null_actor,
}
|
|
|
|
# --- NN-based dynamics model ---
|
|
|
|
class ReactorDynamicsNet(nn.Module):
    """MLP mapping (state, time_delta) -> normalised per-second rates.

    Two hidden layers of 128 units with ReLU + dropout; dropout doubles as
    the stochastic source for MC-dropout uncertainty sampling upstream.
    """

    def __init__(self, input_dim, output_dim, dropout=0.3):
        super(ReactorDynamicsNet, self).__init__()
        hidden = 128
        # Layer order matters: it fixes the Sequential indices, so saved
        # state_dicts keyed as network.0 / network.3 / network.6 keep loading.
        layers = [
            nn.Linear(input_dim + 1, hidden),  # +1 for the time_delta feature
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, output_dim),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, state, time_delta):
        """Append time_delta to the state along the last dim and run the MLP."""
        return self.network(torch.cat([state, time_delta], dim=-1))
|
|
|
|
class ReactorDynamicsModel(nn.Module):
    """
    NN dynamics model predicting per-second rates of change (like ReactorKNNModel).

    Inputs are z-score normalised; outputs are normalised rates.
    forward() returns absolute next-state dict: cur + predicted_rate * time_delta.
    forward_with_uncertainty() additionally returns an MC-dropout uncertainty
    estimate in [0, 1] (0 = confident, ~1 = high variance / out-of-distribution).
    """

    def __init__(self, input_params: List[str], output_params: List[str]):
        super(ReactorDynamicsModel, self).__init__()
        self.input_params = input_params
        self.output_params = output_params
        self.net = ReactorDynamicsNet(len(input_params), len(output_params))
        # Normalisation stats set by fit_normalisation(); registered as buffers
        # so they travel with state_dict() on save/load.
        self.register_buffer('_in_mean', torch.zeros(len(input_params)))
        self.register_buffer('_in_std', torch.ones(len(input_params)))
        self.register_buffer('_rate_mean', torch.zeros(len(output_params)))
        self.register_buffer('_rate_std', torch.ones(len(output_params)))

    def fit_normalisation(self, dataset):
        """Compute and store normalisation stats from a dataset.

        Each entry is (state, action, next_state, time_delta); entries with a
        non-positive time_delta are skipped.

        Raises:
            ValueError: if the dataset has no usable (time_delta > 0) samples,
                which previously produced silent NaN statistics.
        """
        in_vecs, rate_vecs = [], []
        for state, _action, next_state, dt in dataset:
            if dt <= 0:
                continue
            in_vecs.append([state.get(p, 0.0) for p in self.input_params])
            rate_vecs.append([(next_state.get(p, 0.0) - state.get(p, 0.0)) / dt
                              for p in self.output_params])
        if not in_vecs:
            raise ValueError("fit_normalisation: no samples with time_delta > 0")
        ins = np.array(in_vecs, dtype=np.float32)
        rates = np.array(rate_vecs, dtype=np.float32)
        in_std = ins.std(0)
        r_std = rates.std(0)
        # Zero-variance dims get std = 1 so normalisation never divides by ~0.
        self._in_mean.copy_(torch.from_numpy(ins.mean(0)))
        self._in_std.copy_(torch.from_numpy(np.where(in_std < 1e-6, 1.0, in_std)))
        self._rate_mean.copy_(torch.from_numpy(rates.mean(0)))
        self._rate_std.copy_(torch.from_numpy(np.where(r_std < 1e-6, 1.0, r_std)))

    def _normalise_input(self, t: torch.Tensor) -> torch.Tensor:
        """Z-score an input tensor with the stored stats."""
        return (t - self._in_mean) / self._in_std

    def _denormalise_rate(self, t: torch.Tensor) -> torch.Tensor:
        """Map a normalised rate prediction back to raw units per second."""
        return t * self._rate_std + self._rate_mean

    def forward(self, state_dict, time_delta):
        """Return the predicted next-state dict (uncertainty discarded)."""
        return self.forward_with_uncertainty(state_dict, time_delta)[0]

    def forward_with_uncertainty(self, state_dict, time_delta, mc_samples=3):
        """MC-Dropout uncertainty: run mc_samples stochastic forward passes.

        Uncertainty is the mean normalised std across output dims, clipped to [0, 1].
        0 = very confident (low variance), ~1 = high variance / OOD.

        Returns:
            (pred_dict, uncertainty): absolute next-state values per output
            param, and a scalar float in [0, 1].
        """
        s = torch.tensor([state_dict.get(p, 0.0) for p in self.input_params],
                         dtype=torch.float32).unsqueeze(0)
        s_norm = self._normalise_input(s)
        dt_t = torch.tensor([[time_delta]], dtype=torch.float32)

        # Keep dropout active for the stochastic passes, then restore the
        # net's previous train/eval mode instead of unconditionally forcing
        # eval (the old behaviour clobbered an in-progress training mode).
        was_training = self.net.training
        self.net.train()
        with torch.no_grad():
            samples = torch.stack([self.net(s_norm, dt_t).squeeze(0)
                                   for _ in range(mc_samples)])  # (mc_samples, out_dim)
        self.net.train(was_training)

        rate_norm_mean = samples.mean(0)
        rate_norm_std = samples.std(0)

        rate = self._denormalise_rate(rate_norm_mean)
        cur = torch.tensor([state_dict.get(p, 0.0) for p in self.output_params],
                           dtype=torch.float32)
        predicted = cur + rate * time_delta
        pred_dict = {p: float(predicted[i]) for i, p in enumerate(self.output_params)}

        # Uncertainty: mean std in normalised space, clipped to [0, 1]
        uncertainty = float(rate_norm_std.mean().clamp(0.0, 1.0))
        return pred_dict, uncertainty
|
|
|
|
# --- kNN-based dynamics model ---
|
|
|
|
class ReactorKNNModel:
    """
    Non-parametric dynamics model using k-nearest neighbours.

    For a query (state, game_delta):
      1. Find the k dataset entries whose *state* is closest (L2 in normalised space).
      2. For each neighbour compute the per-second rate-of-change:
             rate_i = (next_state_i - state_i) / game_delta_i
      3. Linearly scale to the requested game_delta:
             predicted_delta_i = rate_i * game_delta
      4. Combine the neighbour rates with a local GP posterior mean
         (k_star @ K^-1 @ Y under a unit-variance RBF kernel) and add the
         result to the current output state.

    The linear-in-time assumption means two datapoints at 0.5 s and 2 s contribute
    equally once normalised by their own game_delta.
    """

    def __init__(self, input_params: List[str], output_params: List[str], k: int = 5):
        self.input_params = input_params
        self.output_params = output_params
        self.k = k
        self._states = None      # (n, d_in) normalised state matrix
        self._rates = None       # (n, d_out) (next_out - cur_out) / game_delta
        self._raw_states = None  # unnormalised, for mean/std computation
        self._mean = None
        self._std = None

    def fit(self, dataset):
        """Build lookup tables from a collected dataset.

        Each entry is (state, action, next_state, game_delta); entries with a
        non-positive game_delta are skipped.

        Raises:
            ValueError: if no entry has game_delta > 0 (previously this built
                empty tables that failed confusingly at query time).
        """
        raw, rates = [], []
        for state, _action, next_state, game_delta in dataset:
            if game_delta <= 0:
                continue
            s = np.array([state[p] for p in self.input_params], dtype=np.float32)
            cur = np.array([state[p] for p in self.output_params], dtype=np.float32)
            nxt = np.array([next_state[p] for p in self.output_params], dtype=np.float32)
            raw.append(s)
            rates.append((nxt - cur) / game_delta)

        if not raw:
            raise ValueError("fit: dataset contains no samples with game_delta > 0")

        self._raw_states = np.array(raw)
        self._rates = np.array(rates)
        self._mean = self._raw_states.mean(axis=0)
        raw_std = self._raw_states.std(axis=0)
        # Dimensions with zero variance in the training data carry no distance information.
        # Use inf so they contribute 0 to normalised L2 (i.e., are ignored in kNN lookup).
        self._std = np.where(raw_std < 1e-6, np.inf, raw_std)
        self._states = (self._raw_states - self._mean) / self._std

    def _lookup(self, s: np.ndarray):
        """Return (s_norm, idx, k) for the k nearest neighbours. s is a raw (d_in,) array."""
        s_norm = (s - self._mean) / self._std
        diff = self._states - s_norm               # (n, d_in) broadcast
        dists = np.einsum('ij,ij->i', diff, diff)  # squared L2, faster than linalg.norm
        k = min(self.k, len(dists))
        idx = np.argpartition(dists, k - 1)[:k]
        return s_norm, idx, k

    def forward(self, state_dict: Dict, time_delta: float) -> Dict:
        """Return the predicted next-state dict (uncertainty discarded)."""
        if self._states is None:
            raise ValueError("Model not fitted. Call fit(dataset) first.")
        return self.forward_with_uncertainty(state_dict, time_delta)[0]

    def forward_with_uncertainty(self, state_dict: Dict, time_delta: float):
        """Return (prediction_dict, uncertainty_scalar).

        Uncertainty is the GP posterior std in normalised input space:
          0  = query lies exactly on a training point (fully confident)
          ~1 = query is far from all neighbours (maximally uncertain)
        """
        if self._states is None:
            raise ValueError("Model not fitted. Call fit(dataset) first.")

        s = np.array([state_dict[p] for p in self.input_params], dtype=np.float32)
        s_norm, idx, k = self._lookup(s)
        X = self._states[idx]  # (k, d_in)
        Y = self._rates[idx]   # (k, d_out)

        # RBF kernel: k(a,b) = exp(-0.5 ||a-b||^2)
        def rbf(A, B):
            diff = A[:, None, :] - B[None, :, :]
            return np.exp(-0.5 * np.einsum('ijk,ijk->ij', diff, diff))

        K = rbf(X, X) + 1e-4 * np.eye(k)  # jitter keeps K positive definite
        k_star = rbf(s_norm[None, :], X)[0]

        # Solve linear systems instead of forming K^-1 explicitly: same
        # result, cheaper and better conditioned for the symmetric PD kernel.
        alpha = np.linalg.solve(K, Y)       # K^-1 @ Y
        mean_rates = k_star @ alpha

        v = np.linalg.solve(K, k_star)      # K^-1 @ k_star
        var = max(0.0, 1.0 - float(k_star @ v))
        std = float(np.sqrt(var))

        cur_out = np.array([state_dict[p] for p in self.output_params], dtype=np.float32)
        predicted = cur_out + mean_rates * time_delta

        pred_dict = {p: float(predicted[i]) for i, p in enumerate(self.output_params)}
        return pred_dict, std
|
|
|
|
# --- Mixture model ---
|
|
|
|
class MixtureModel:
    """Blends two dynamics models, weighted by the kNN model's uncertainty.

    The prediction is a continuous convex combination (there is no hard
    threshold/selection): weight (1 - knn_u) goes to knn_model and weight
    knn_u to nn_model, so the kNN dominates near its training data and the
    NN takes over as the query moves out-of-distribution.

    Both models must implement forward_with_uncertainty(state_dict, time_delta).
    input_params / output_params are taken from knn_model.
    """

    def __init__(self, knn_model, nn_model):
        self.knn_model = knn_model
        self.nn_model = nn_model
        self.input_params = knn_model.input_params
        self.output_params = knn_model.output_params

    def forward(self, state_dict, time_delta):
        """Return the blended next-state dict (uncertainty discarded)."""
        return self.forward_with_uncertainty(state_dict, time_delta)[0]

    def forward_with_uncertainty(self, state_dict, time_delta):
        """Return (blended_prediction_dict, blended_uncertainty).

        The blended uncertainty is the same convex combination applied to
        the two models' own uncertainties.
        """
        knn_pred, knn_u = self.knn_model.forward_with_uncertainty(state_dict, time_delta)
        nn_pred, nn_u = self.nn_model.forward_with_uncertainty(state_dict, time_delta)
        w_knn = 1.0 - knn_u  # high when kNN is confident
        w_nn = knn_u         # high when kNN is OOD
        blended = {p: w_knn * knn_pred[p] + w_nn * nn_pred[p]
                   for p in self.output_params}
        uncertainty = w_knn * knn_u + w_nn * nn_u  # weighted uncertainty
        return blended, uncertainty
|
|
|
|
|
|
# --- Learner ---
|
|
|
|
class NuconModelLearner:
    """Collects transition data from a live Nucon game and fits dynamics models.

    The dataset is a list of (state_dict, action_dict, next_state_dict,
    game_delta_seconds) tuples, persisted to ``dataset_path`` with pickle.
    Either an NN model (train_model) or a kNN/GP model (fit_knn) can be
    fitted; ``self.model`` holds whichever was created.
    """

    def __init__(self, nucon=None, actor='null', dataset_path='nucon_dataset.pkl',
                 time_delta: Union[float, Tuple[float, float]] = 1.0,
                 include_valve_states: bool = False):
        """Connect to the game and discover which parameters are usable.

        Args:
            nucon: existing Nucon client, or None to construct a fresh one.
            actor: key into the Actors registry, or a callable obs -> action.
            dataset_path: pickle path the dataset is loaded from / saved to.
            time_delta: fixed game-seconds between samples, or a (lo, hi)
                tuple sampled uniformly each step.
            include_valve_states: also record valve positions as extra inputs.

        NOTE(review): this constructor blocks (retrying every 5 s) until the
        game's endpoint is reachable, because it probes every candidate
        parameter once to drop the ones that return None.
        """
        self.nucon = Nucon() if nucon is None else nucon
        # If `actor` names a registered factory, instantiate it for this nucon;
        # otherwise it is assumed to already be a callable actor.
        self.actor = Actors[actor](self.nucon) if actor in Actors else actor
        self.dataset = self.load_dataset(dataset_path) or []
        self.dataset_path = dataset_path
        self.include_valve_states = include_valve_states
        self.model = None
        self.optimizer = None

        # Exclude params with no physics signal
        _JUNK_PARAMS = frozenset({'GAME_VERSION', 'TIME', 'TIME_STAMP', 'TIME_DAY',
                                  'ALARMS_ACTIVE', 'FUN_IS_ENABLED', 'GAME_SIM_SPEED'})
        candidate_params = {k: p for k, p in self.nucon.get_all_readable().items()
                            if k not in _JUNK_PARAMS and p.param_type != str}
        # Filter out params that return None (subsystem not installed).
        # Retry until the game is reachable.
        import requests as _requests
        while True:
            try:
                test_state = {k: self.nucon.get(k) for k in candidate_params}
                break
            except (_requests.exceptions.ConnectionError,
                    _requests.exceptions.Timeout):
                print("Waiting for game to be reachable…")
                time.sleep(5)
        self.readable_params = [k for k in candidate_params if test_state[k] is not None]
        # Non-writable params are the model's prediction targets (outputs);
        # writable ones are controlled externally and appear only as inputs.
        self.non_writable_params = [k for k in self.readable_params
                                    if not self.nucon.get_all_readable()[k].is_writable]

        # Optionally include valve positions (input only — valves are externally driven)
        self.valve_keys = []
        if include_valve_states:
            valves = self.nucon.get_valves()
            self.valve_keys = [f'VALVE__{name}' for name in sorted(valves.keys())]
            self.readable_params = self.readable_params + self.valve_keys
            # valve positions are input-only (not predicted as outputs)

        # Normalise time_delta into a zero-arg sampler callable.
        if isinstance(time_delta, (int, float)):
            self.time_delta = lambda: time_delta
        elif isinstance(time_delta, tuple) and len(time_delta) == 2:
            self.time_delta = lambda: random.uniform(*time_delta)
        else:
            raise ValueError("time_delta must be a float or a tuple of two floats")

    def _get_state(self):
        """Read the full current state from the game as a flat dict.

        Enum values are flattened to their underlying value; valve readings
        (if enabled) are fetched in one bulk get_valves() call.
        """
        state = {}
        for param_id in self.readable_params:
            if param_id in self.valve_keys:
                continue  # filled below
            value = self.nucon.get(param_id)
            if isinstance(value, Enum):
                value = value.value
            state[param_id] = value
        if self.valve_keys:
            valves = self.nucon.get_valves()
            for key in self.valve_keys:
                name = key[len('VALVE__'):]
                # NOTE(review): assumes get_valves() values are dicts with a
                # 'Value' key — confirm against the nucon client.
                state[key] = valves.get(name, {}).get('Value', 0.0)
        return state

    def collect_data(self, num_steps, save_every=10):
        """
        Collect state-transition tuples from the live game.

        Sleeps wall_time = target_game_delta / sim_speed so that each stored
        game_delta is uniform regardless of the game's simulation speed setting.

        Saves the dataset every ``save_every`` steps so a crash doesn't lose
        everything. On a connection error the step is skipped and collection
        resumes once the game is reachable again (retries every 5 s).
        """
        import requests as _requests

        def get_state_with_retry():
            # Blocks until the game answers; keeps the collection loop alive
            # across transient disconnects.
            while True:
                try:
                    return self._get_state()
                except (_requests.exceptions.ConnectionError,
                        _requests.exceptions.Timeout) as e:
                    print(f"Connection lost ({e}). Retrying in 5 s…")
                    time.sleep(5)

        state = get_state_with_retry()
        collected = 0
        for i in range(num_steps):
            action = self.actor(state)
            for param_id, value in action.items():
                # Best-effort writes: a rejected set must not abort collection.
                try:
                    self.nucon.set(param_id, value)
                except Exception:
                    pass

            target_game_delta = self.time_delta()
            try:
                sim_speed = self.nucon.GAME_SIM_SPEED.value or 1.0
            except Exception:
                sim_speed = 1.0
            # Wall-clock sleep scaled so the *game-time* delta stays constant.
            time.sleep(target_game_delta / sim_speed)

            next_state = get_state_with_retry()
            self.dataset.append((state, action, next_state, target_game_delta))
            state = next_state
            collected += 1

            if collected % save_every == 0:
                self.save_dataset()
                print(f"  {collected}/{num_steps} steps collected, dataset saved.")

        self.save_dataset()
        print(f"Collection complete. {collected} steps, {len(self.dataset)} total samples.")

    def train_model(self, batch_size=32, num_epochs=10, test_split=0.2, lr=1e-3):
        """Train a neural-network dynamics model on the current dataset."""
        if self.model is None:
            self.model = ReactorDynamicsModel(self.readable_params, self.non_writable_params)
        elif not isinstance(self.model, ReactorDynamicsModel):
            raise ValueError("A kNN model is already loaded. Create a new learner to train an NN.")
        self.model.fit_normalisation(self.dataset)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)
        # NOTE(review): shuffles self.dataset in place, so the on-disk order
        # changes the next time save_dataset() runs.
        random.shuffle(self.dataset)
        split_idx = int(len(self.dataset) * (1 - test_split))
        train_data = self.dataset[:split_idx]
        test_data = self.dataset[split_idx:]
        for epoch in range(num_epochs):
            train_loss = self._train_epoch(train_data, batch_size)
            test_loss = self._test_epoch(test_data)
            print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")

    def fit_knn(self, k: int = 5):
        """Fit a kNN/GP dynamics model from the current dataset (instantaneous, no gradient steps)."""
        if self.model is None:
            self.model = ReactorKNNModel(self.readable_params, self.non_writable_params, k=k)
        elif not isinstance(self.model, ReactorKNNModel):
            raise ValueError("An NN model is already loaded. Create a new learner to fit a kNN.")
        self.model.fit(self.dataset)
        print(f"kNN model fitted on {len(self.dataset)} samples.")

    def predict_with_uncertainty(self, state_dict: Dict, time_delta: float):
        """Return (prediction_dict, uncertainty_std). Only available after fit_knn()."""
        if not isinstance(self.model, ReactorKNNModel):
            raise ValueError("predict_with_uncertainty() requires a fitted kNN model (call fit_knn()).")
        return self.model.forward_with_uncertainty(state_dict, time_delta)

    def drop_well_fitted(self, error_threshold: float):
        """Drop samples the current model already predicts well (MSE < threshold).

        Keeps only hard/surprising transitions. Useful for NN training to focus
        capacity on difficult regions of state space.
        """
        if self.model is None:
            raise ValueError("No model fitted yet. Call train_model() or fit_knn() first.")
        kept = []
        for state, action, next_state, time_delta in self.dataset:
            pred = self.model.forward(state, time_delta)
            # Summed squared error over all predicted (non-writable) outputs.
            error = sum((pred[p] - next_state[p]) ** 2 for p in self.non_writable_params)
            if error > error_threshold:
                kept.append((state, action, next_state, time_delta))
        dropped = len(self.dataset) - len(kept)
        self.dataset = kept
        self.save_dataset()
        print(f"drop_well_fitted: kept {len(kept)}, dropped {dropped} samples.")

    def drop_redundant(self, min_state_distance: float, min_output_distance: float = 0.0):
        """Drop near-duplicate samples, keeping only those that add coverage.

        A sample is dropped only if *both* its input state and its output
        transition are within the given distances of an already-kept sample
        (L2 in z-scored space). If two samples share the same input state but
        have different transitions they represent genuinely different dynamics
        and are both kept regardless of `min_output_distance`.

        Args:
            min_state_distance: minimum L2 distance in z-scored input space.
            min_output_distance: minimum L2 distance in z-scored output-delta
                space. Defaults to 0 (only input distance matters).
        """
        if not self.dataset:
            return

        # Valve keys are excluded from the input distance: they are external
        # controls, not part of the physics state being deduplicated.
        in_params = [p for p in self.readable_params if p not in self.valve_keys]
        out_params = self.non_writable_params

        all_states = np.array([[s[p] for p in in_params] for s, *_ in self.dataset], dtype=np.float32)
        all_deltas = np.array([[ns[p] - s[p] for p in out_params]
                               for s, _, ns, gd in self.dataset], dtype=np.float32)

        s_mean, s_std = all_states.mean(0), all_states.std(0) + 1e-8
        d_mean, d_std = all_deltas.mean(0), all_deltas.std(0) + 1e-8

        s_norm = (all_states - s_mean) / s_std
        d_norm = (all_deltas - d_mean) / d_std

        # Greedy pass: the first sample is always kept, later ones only if
        # they are sufficiently far from every sample kept so far.
        kept_idx = [0]
        kept_s = [s_norm[0]]
        kept_d = [d_norm[0]]

        for i in range(1, len(self.dataset)):
            s_dists = np.linalg.norm(np.array(kept_s) - s_norm[i], axis=1)
            d_dists = np.linalg.norm(np.array(kept_d) - d_norm[i], axis=1)
            # Drop only if close in BOTH spaces
            if not np.any((s_dists < min_state_distance) & (d_dists < min_output_distance)):
                kept_idx.append(i)
                kept_s.append(s_norm[i])
                kept_d.append(d_norm[i])

        dropped = len(self.dataset) - len(kept_idx)
        self.dataset = [self.dataset[i] for i in kept_idx]
        self.save_dataset()
        print(f"drop_redundant: kept {len(self.dataset)}, dropped {dropped} samples.")

    def _train_epoch(self, data, batch_size):
        """Run one epoch of minibatch SGD on normalised rate targets.

        Returns the mean batch loss; leaves the model in eval mode.
        """
        self.model.train()
        total_loss = 0
        n_batches = 0
        for i in range(0, len(data), batch_size):
            # Skip samples with non-positive dt (rate would divide by zero).
            batch = [s for s in data[i:i+batch_size] if s[3] > 0]
            if not batch:
                continue
            states = torch.tensor([[s[0].get(p, 0.0) for p in self.readable_params] for s in batch], dtype=torch.float32)
            # Target is the per-second rate of change of each output param.
            targets = torch.tensor([[(s[2].get(p, 0.0) - s[0].get(p, 0.0)) / s[3] for p in self.non_writable_params] for s in batch], dtype=torch.float32)
            dts = torch.tensor([[s[3]] for s in batch], dtype=torch.float32)
            s_norm = self.model._normalise_input(states)
            rate_norm_pred = self.model.net(s_norm, dts)
            rate_norm_target = (targets - self.model._rate_mean) / self.model._rate_std
            self.optimizer.zero_grad()
            loss = torch.nn.functional.mse_loss(rate_norm_pred, rate_norm_target)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        self.model.eval()
        return total_loss / max(1, n_batches)

    def _test_epoch(self, data):
        """Compute the mean per-sample MSE on held-out data (no gradients)."""
        total_loss = 0.0
        n = 0
        with torch.no_grad():
            for state, _, next_state, dt in data:
                if dt <= 0:
                    continue
                s_t = torch.tensor([[state.get(p, 0.0) for p in self.readable_params]], dtype=torch.float32)
                s_norm = self.model._normalise_input(s_t)
                dt_t = torch.tensor([[dt]], dtype=torch.float32)
                rate_norm_pred = self.model.net(s_norm, dt_t).squeeze(0)
                target = torch.tensor([(next_state.get(p, 0.0) - state.get(p, 0.0)) / dt
                                       for p in self.non_writable_params], dtype=torch.float32)
                rate_norm_target = (target - self.model._rate_mean) / self.model._rate_std
                total_loss += torch.nn.functional.mse_loss(rate_norm_pred, rate_norm_target).item()
                n += 1
        return total_loss / max(1, n)

    def save_model(self, path):
        """Save the current model: NN -> torch checkpoint, kNN -> pickle."""
        if self.model is None:
            raise ValueError("No model to save. Call train_model() or fit_knn() first.")
        if isinstance(self.model, ReactorDynamicsModel):
            # Param lists are stored alongside the weights so the model can be
            # reconstructed without a live game connection.
            torch.save({
                'state_dict': self.model.state_dict(),
                'input_params': self.model.input_params,
                'output_params': self.model.output_params,
            }, path)
        else:
            with open(path, 'wb') as f:
                pickle.dump(self.model, f)

    def load_model(self, path):
        """Load a model saved by save_model(); format is inferred from the path.

        NOTE(review): both pickle.load and torch.load(weights_only=False)
        execute arbitrary code from the file — only load trusted files.
        """
        if path.endswith('.pkl'):
            with open(path, 'rb') as f:
                self.model = pickle.load(f)
        else:
            checkpoint = torch.load(path, weights_only=False)
            if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                m = ReactorDynamicsModel(checkpoint['input_params'], checkpoint['output_params'])
                m.load_state_dict(checkpoint['state_dict'])
                self.model = m
            else:
                # legacy plain state dict
                self.model = ReactorDynamicsModel(self.readable_params, self.non_writable_params)
                self.model.load_state_dict(checkpoint)

    def save_dataset(self, path=None):
        """Pickle the dataset to `path` (defaults to self.dataset_path)."""
        path = path or self.dataset_path
        with open(path, 'wb') as f:
            pickle.dump(self.dataset, f)

    def load_dataset(self, path=None):
        """Unpickle and return the dataset at `path`, or None if absent."""
        path = path or self.dataset_path
        if os.path.exists(path):
            with open(path, 'rb') as f:
                return pickle.load(f)
        return None

    def merge_datasets(self, other_dataset_path):
        """Append another dataset file's samples to this one and save.

        Raises:
            ValueError: if the file does not contain a list (e.g. a model
                pickle was passed by mistake).
        """
        other_dataset = self.load_dataset(other_dataset_path)
        if not isinstance(other_dataset, list):
            raise ValueError(
                f"'{other_dataset_path}' does not contain a dataset (got {type(other_dataset).__name__}). "
                f"Pass a dataset .pkl file, not a model file."
            )
        self.dataset.extend(other_dataset)
        self.save_dataset()