Overhaul model learning: kNN+GP, uncertainty, dataset pruning, sim-speed fix
Data collection:
- time_delta is now target game-time; wall sleep = game_delta / sim_speed
so stored deltas are uniform regardless of GAME_SIM_SPEED setting
- Auto-exclude junk params (GAME_VERSION, TIME, ALARMS_ACTIVE, …) and
params returning None (uninstalled subsystems)
- Optional include_valve_states=True adds all 53 valve positions as inputs
Model backends (model_type='nn' or 'knn'):
- ReactorKNNModel: k-nearest neighbours with GP interpolation
- Finds k nearest states, computes per-second transition rates,
linearly scales to requested game_delta (linear-in-time assumption)
- forward_with_uncertainty() returns (prediction_dict, gp_std)
where std≈0 = on known data, std≈1 = out of distribution
- NN training fixed: loss computed in tensor space, mse_loss per batch
Dataset management:
- drop_well_fitted(error_threshold): drop samples model predicts well,
keep hard cases (useful for NN curriculum)
- drop_redundant(min_state_distance, min_output_distance): drop samples
that are close in BOTH input state AND output transition space, keeping
genuinely different dynamics even at the same input state
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
c78106dffc
commit
31cb6862e1
315
nucon/model.py
315
nucon/model.py
@ -8,13 +8,15 @@ from enum import Enum
|
|||||||
from nucon import Nucon
|
from nucon import Nucon
|
||||||
import pickle
|
import pickle
|
||||||
import os
|
import os
|
||||||
from typing import Union, Tuple, List
|
from typing import Union, Tuple, List, Dict
|
||||||
|
|
||||||
Actors = {
|
Actors = {
|
||||||
'random': lambda nucon: lambda obs: {param.id: random.uniform(param.min_val, param.max_val) if param.min_val is not None and param.max_val is not None else 0 for param in nucon.get_all_writable().values()},
|
'random': lambda nucon: lambda obs: {param.id: random.uniform(param.min_val, param.max_val) if param.min_val is not None and param.max_val is not None else 0 for param in nucon.get_all_writable().values()},
|
||||||
'null': lambda nucon: lambda obs: {},
|
'null': lambda nucon: lambda obs: {},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# --- NN-based dynamics model ---
|
||||||
|
|
||||||
class ReactorDynamicsNet(nn.Module):
|
class ReactorDynamicsNet(nn.Module):
|
||||||
def __init__(self, input_dim, output_dim):
|
def __init__(self, input_dim, output_dim):
|
||||||
super(ReactorDynamicsNet, self).__init__()
|
super(ReactorDynamicsNet, self).__init__()
|
||||||
@ -35,10 +37,7 @@ class ReactorDynamicsModel(nn.Module):
|
|||||||
super(ReactorDynamicsModel, self).__init__()
|
super(ReactorDynamicsModel, self).__init__()
|
||||||
self.input_params = input_params
|
self.input_params = input_params
|
||||||
self.output_params = output_params
|
self.output_params = output_params
|
||||||
|
self.net = ReactorDynamicsNet(len(input_params), len(output_params))
|
||||||
input_dim = len(input_params)
|
|
||||||
output_dim = len(output_params)
|
|
||||||
self.net = ReactorDynamicsNet(input_dim, output_dim)
|
|
||||||
|
|
||||||
def _state_dict_to_tensor(self, state_dict):
|
def _state_dict_to_tensor(self, state_dict):
|
||||||
return torch.tensor([state_dict[p] for p in self.input_params], dtype=torch.float32)
|
return torch.tensor([state_dict[p] for p in self.input_params], dtype=torch.float32)
|
||||||
@ -52,17 +51,142 @@ class ReactorDynamicsModel(nn.Module):
|
|||||||
predicted_tensor = self.net(state_tensor, time_delta_tensor)
|
predicted_tensor = self.net(state_tensor, time_delta_tensor)
|
||||||
return self._tensor_to_state_dict(predicted_tensor.squeeze(0))
|
return self._tensor_to_state_dict(predicted_tensor.squeeze(0))
|
||||||
|
|
||||||
|
# --- kNN-based dynamics model ---


class ReactorKNNModel:
    """
    Non-parametric dynamics model using k-nearest neighbours.

    For a query (state, game_delta):
    1. Find the k dataset entries whose *state* is closest (L2 in normalised space).
    2. For each neighbour compute the per-second rate-of-change:
           rate_i = (next_state_i - state_i) / game_delta_i
    3. Linearly scale to the requested game_delta:
           predicted_delta_i = rate_i * game_delta
    4. Return the GP-weighted combination of those neighbour rates added to
       the current output state.

    The linear-in-time assumption means two datapoints at 0.5 s and 2 s contribute
    equally once normalised by their own game_delta.
    """

    def __init__(self, input_params: List[str], output_params: List[str], k: int = 5):
        self.input_params = input_params
        self.output_params = output_params
        self.k = k
        self._states = None      # (n, d_in) normalised state matrix
        self._rates = None       # (n, d_out) (next_out - cur_out) / game_delta
        self._raw_states = None  # unnormalised, for mean/std computation
        self._mean = None
        self._std = None

    def fit(self, dataset):
        """Build lookup tables from a collected dataset.

        Args:
            dataset: iterable of (state, action, next_state, game_delta) tuples.
                Samples with game_delta <= 0 are skipped (no rate is defined).

        Raises:
            ValueError: if no usable samples remain after filtering — without
                this guard the mean/std below would silently produce NaNs and
                fail later with a cryptic numpy error.
        """
        raw, rates = [], []
        for state, _action, next_state, game_delta in dataset:
            if game_delta <= 0:
                continue
            s = np.array([state[p] for p in self.input_params], dtype=np.float32)
            cur = np.array([state[p] for p in self.output_params], dtype=np.float32)
            nxt = np.array([next_state[p] for p in self.output_params], dtype=np.float32)
            raw.append(s)
            rates.append((nxt - cur) / game_delta)

        if not raw:
            raise ValueError("No usable samples: dataset empty or all game_delta <= 0.")

        self._raw_states = np.array(raw)
        self._rates = np.array(rates)
        self._mean = self._raw_states.mean(axis=0)
        self._std = self._raw_states.std(axis=0) + 1e-8  # avoid div-by-zero on constant dims
        self._states = (self._raw_states - self._mean) / self._std

    def _lookup(self, state_dict: Dict):
        """Return (s_norm, idx, k) for the k nearest neighbours of *state_dict*."""
        s = np.array([state_dict[p] for p in self.input_params], dtype=np.float32)
        s_norm = (s - self._mean) / self._std
        dists = np.linalg.norm(self._states - s_norm, axis=1)
        k = min(self.k, len(dists))
        # argpartition: O(n) selection of the k smallest, unordered — order
        # is irrelevant because the GP weights each neighbour individually.
        idx = np.argpartition(dists, k - 1)[:k]
        return s_norm, idx, k

    def forward(self, state_dict: Dict, time_delta: float) -> Dict:
        """Predict the output params after *time_delta* seconds of game time."""
        if self._states is None:
            raise ValueError("Model not fitted. Call fit(dataset) first.")
        return self.forward_with_uncertainty(state_dict, time_delta)[0]

    def forward_with_uncertainty(self, state_dict: Dict, time_delta: float):
        """Return (prediction_dict, uncertainty_scalar).

        Uncertainty is the GP posterior std in normalised input space:
            0  = query lies exactly on a training point (fully confident)
            ~1 = query is far from all neighbours (maximally uncertain)
        """
        if self._states is None:
            raise ValueError("Model not fitted. Call fit(dataset) first.")

        s_norm, idx, k = self._lookup(state_dict)
        X = self._states[idx]  # (k, d_in)
        Y = self._rates[idx]   # (k, d_out)

        # RBF kernel (vectorised): k(a,b) = exp(-0.5 ||a-b||^2)
        def rbf_matrix(A, B):
            diff = A[:, None, :] - B[None, :, :]            # (|A|, |B|, d)
            return np.exp(-0.5 * (diff ** 2).sum(axis=-1))  # (|A|, |B|)

        K = rbf_matrix(X, X) + 1e-4 * np.eye(k)     # (k, k), jitter for conditioning
        k_star = rbf_matrix(s_norm[None, :], X)[0]  # (k,)

        # Solve K alpha = k_star instead of forming K^-1 explicitly:
        # cheaper and numerically more stable when K is near-singular
        # (duplicate neighbours make it exactly so up to the jitter).
        alpha = np.linalg.solve(K, k_star)  # = K^-1 @ k_star, shape (k,)
        mean_rates = alpha @ Y              # (d_out,)

        # Posterior variance (scalar, shared across all output dims)
        var = max(0.0, 1.0 - float(k_star @ alpha))
        std = float(np.sqrt(var))

        cur_out = np.array([state_dict[p] for p in self.output_params], dtype=np.float32)
        predicted = cur_out + mean_rates * time_delta

        pred_dict = {p: float(predicted[i]) for i, p in enumerate(self.output_params)}
        return pred_dict, std
|
|
||||||
|
# --- Learner ---
|
||||||
|
|
||||||
class NuconModelLearner:
|
class NuconModelLearner:
|
||||||
def __init__(self, nucon=None, actor='null', dataset_path='nucon_dataset.pkl', time_delta: Union[float, Tuple[float, float]] = 0.1):
|
def __init__(self, nucon=None, actor='null', dataset_path='nucon_dataset.pkl',
|
||||||
|
time_delta: Union[float, Tuple[float, float]] = 1.0,
|
||||||
|
model_type: str = 'nn', knn_k: int = 5,
|
||||||
|
include_valve_states: bool = False):
|
||||||
self.nucon = Nucon() if nucon is None else nucon
|
self.nucon = Nucon() if nucon is None else nucon
|
||||||
self.actor = Actors[actor](self.nucon) if actor in Actors else actor
|
self.actor = Actors[actor](self.nucon) if actor in Actors else actor
|
||||||
self.dataset = self.load_dataset(dataset_path) or []
|
self.dataset = self.load_dataset(dataset_path) or []
|
||||||
self.dataset_path = dataset_path
|
self.dataset_path = dataset_path
|
||||||
|
self.include_valve_states = include_valve_states
|
||||||
|
|
||||||
self.readable_params = list(self.nucon.get_all_readable().keys())
|
# Exclude params with no physics signal
|
||||||
self.non_writable_params = [param.id for param in self.nucon.get_all_readable().values() if not param.is_writable]
|
_JUNK_PARAMS = frozenset({'GAME_VERSION', 'TIME', 'TIME_STAMP', 'TIME_DAY',
|
||||||
self.model = ReactorDynamicsModel(self.readable_params, self.non_writable_params)
|
'ALARMS_ACTIVE', 'FUN_IS_ENABLED', 'GAME_SIM_SPEED'})
|
||||||
self.optimizer = optim.Adam(self.model.parameters())
|
candidate_params = {k: p for k, p in self.nucon.get_all_readable().items()
|
||||||
|
if k not in _JUNK_PARAMS and p.param_type != str}
|
||||||
|
# Filter out params that return None (subsystem not installed)
|
||||||
|
test_state = {k: self.nucon.get(k) for k in candidate_params}
|
||||||
|
self.readable_params = [k for k in candidate_params if test_state[k] is not None]
|
||||||
|
self.non_writable_params = [k for k in self.readable_params
|
||||||
|
if not self.nucon.get_all_readable()[k].is_writable]
|
||||||
|
|
||||||
|
# Optionally include valve positions (input only — valves are externally driven)
|
||||||
|
self.valve_keys = []
|
||||||
|
if include_valve_states:
|
||||||
|
valves = self.nucon.get_valves()
|
||||||
|
self.valve_keys = [f'VALVE__{name}' for name in sorted(valves.keys())]
|
||||||
|
self.readable_params = self.readable_params + self.valve_keys
|
||||||
|
# valve positions are input-only (not predicted as outputs)
|
||||||
|
|
||||||
|
if model_type == 'nn':
|
||||||
|
self.model = ReactorDynamicsModel(self.readable_params, self.non_writable_params)
|
||||||
|
self.optimizer = optim.Adam(self.model.parameters())
|
||||||
|
elif model_type == 'knn':
|
||||||
|
self.model = ReactorKNNModel(self.readable_params, self.non_writable_params, k=knn_k)
|
||||||
|
self.optimizer = None
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown model_type '{model_type}'. Use 'nn' or 'knn'.")
|
||||||
|
|
||||||
if isinstance(time_delta, (int, float)):
|
if isinstance(time_delta, (int, float)):
|
||||||
self.time_delta = lambda: time_delta
|
self.time_delta = lambda: time_delta
|
||||||
@ -73,87 +197,180 @@ class NuconModelLearner:
|
|||||||
|
|
||||||
def _get_state(self):
    """Snapshot every tracked readable param (plus valve positions) into a flat dict."""
    snapshot = {}
    # Valve keys are synthetic — they are filled from get_valves() below,
    # not from the regular parameter API.
    for param_id in (p for p in self.readable_params if p not in self.valve_keys):
        raw = self.nucon.get(param_id)
        # Store enums by their underlying value so the state stays numeric.
        snapshot[param_id] = raw.value if isinstance(raw, Enum) else raw
    if self.valve_keys:
        valve_data = self.nucon.get_valves()
        prefix_len = len('VALVE__')
        for key in self.valve_keys:
            valve_name = key[prefix_len:]
            # Missing valve -> 0.0 (assumes get_valves() returns name -> dict
            # with a 'Value' entry — TODO confirm against Nucon API)
            snapshot[key] = valve_data.get(valve_name, {}).get('Value', 0.0)
    return snapshot
|
||||||
|
|
||||||
def collect_data(self, num_steps):
    """
    Collect state-transition tuples from the live game.

    Sleeps wall_time = target_game_delta / sim_speed so that each stored
    game_delta is uniform regardless of the game's simulation speed setting.
    """
    current = self._get_state()
    for _ in range(num_steps):
        chosen = self.actor(current)
        for pid, target_value in chosen.items():
            self.nucon.set(pid, target_value)

        # Convert the desired *game*-time delta into *wall* time via the
        # current simulation speed (falling back to 1.0 if unset/zero).
        game_dt = self.time_delta()
        speed = self.nucon.GAME_SIM_SPEED.value or 1.0
        time.sleep(game_dt / speed)

        observed = self._get_state()
        self.dataset.append((current, chosen, observed, game_dt))
        current = observed

    self.save_dataset()
|
||||||
|
|
||||||
def refine_dataset(self, error_threshold):
|
|
||||||
refined_data = []
|
|
||||||
for state, action, next_state, time_delta in self.dataset:
|
|
||||||
predicted_next_state = self.model(state, time_delta)
|
|
||||||
|
|
||||||
error = sum((predicted_next_state[p] - next_state[p])**2 for p in self.non_writable_params)
|
|
||||||
if error > error_threshold:
|
|
||||||
refined_data.append((state, action, next_state, time_delta))
|
|
||||||
|
|
||||||
self.dataset = refined_data
|
|
||||||
self.save_dataset()
|
|
||||||
|
|
||||||
def train_model(self, batch_size=32, num_epochs=10, test_split=0.2):
    """Train the NN model. For kNN, call fit_knn() instead.

    Shuffles the dataset in place, holds out `test_split` of it for
    evaluation, and prints train/test loss once per epoch.
    """
    if not isinstance(self.model, ReactorDynamicsModel):
        raise ValueError("train_model() is for the NN model. Use fit_knn() for kNN.")
    random.shuffle(self.dataset)
    cutoff = int(len(self.dataset) * (1 - test_split))
    train_data, test_data = self.dataset[:cutoff], self.dataset[cutoff:]

    for epoch in range(num_epochs):
        train_loss = self._train_epoch(train_data, batch_size)
        test_loss = self._test_epoch(test_data)
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")
|
||||||
|
|
||||||
|
def fit_knn(self):
    """Fit the kNN/GP model from the current dataset (instantaneous, no gradient steps)."""
    if not isinstance(self.model, ReactorKNNModel):
        raise ValueError("fit_knn() is for the kNN model. Use train_model() for NN.")
    sample_count = len(self.dataset)
    self.model.fit(self.dataset)
    print(f"kNN model fitted on {sample_count} samples.")
|
||||||
|
|
||||||
|
def predict_with_uncertainty(self, state_dict: Dict, time_delta: float):
    """Return (prediction_dict, uncertainty_std). Only available for kNN model."""
    if isinstance(self.model, ReactorKNNModel):
        return self.model.forward_with_uncertainty(state_dict, time_delta)
    raise ValueError("predict_with_uncertainty() requires model_type='knn'.")
|
||||||
|
|
||||||
|
def drop_well_fitted(self, error_threshold: float):
    """Drop samples the current model already predicts well (MSE < threshold).

    Keeps only hard/surprising transitions. Useful for NN training to focus
    capacity on difficult regions of state space.
    """
    surviving = []
    for sample in self.dataset:
        state, _action, next_state, dt = sample
        pred = self.model.forward(state, dt)
        # Summed squared error over the predicted (non-writable) outputs.
        sq_err = sum((pred[p] - next_state[p]) ** 2 for p in self.non_writable_params)
        if sq_err > error_threshold:
            surviving.append(sample)
    n_dropped = len(self.dataset) - len(surviving)
    self.dataset = surviving
    self.save_dataset()
    print(f"drop_well_fitted: kept {len(surviving)}, dropped {n_dropped} samples.")
|
||||||
|
|
||||||
|
def drop_redundant(self, min_state_distance: float, min_output_distance: float = 0.0):
    """Drop near-duplicate samples, keeping only those that add coverage.

    A sample is dropped only if its input state is within
    `min_state_distance` of an already-kept sample AND — when
    `min_output_distance` > 0 — its output transition is also within
    `min_output_distance` (L2 in z-scored space). If two samples share the
    same input state but have sufficiently different transitions they
    represent genuinely different dynamics and are both kept.

    Args:
        min_state_distance: minimum L2 distance in z-scored input space.
        min_output_distance: minimum L2 distance in z-scored output-delta
            space. Defaults to 0, meaning only input distance matters.
    """
    if not self.dataset:
        return

    # Valve keys are input-only extras; redundancy is judged on the core state.
    in_params = [p for p in self.readable_params if p not in self.valve_keys]
    out_params = self.non_writable_params

    all_states = np.array([[s[p] for p in in_params] for s, *_ in self.dataset], dtype=np.float32)
    all_deltas = np.array([[ns[p] - s[p] for p in out_params]
                           for s, _, ns, _gd in self.dataset], dtype=np.float32)

    s_mean, s_std = all_states.mean(0), all_states.std(0) + 1e-8
    d_mean, d_std = all_deltas.mean(0), all_deltas.std(0) + 1e-8

    s_norm = (all_states - s_mean) / s_std
    d_norm = (all_deltas - d_mean) / d_std

    kept_idx = [0]
    kept_s = [s_norm[0]]
    kept_d = [d_norm[0]]

    for i in range(1, len(self.dataset)):
        s_dists = np.linalg.norm(np.array(kept_s) - s_norm[i], axis=1)
        close_in_state = s_dists < min_state_distance
        if min_output_distance > 0.0:
            # Drop only if close in BOTH spaces.
            d_dists = np.linalg.norm(np.array(kept_d) - d_norm[i], axis=1)
            redundant = close_in_state & (d_dists < min_output_distance)
        else:
            # BUG FIX: the previous strict comparison `d_dists < 0.0` was
            # never true, so the default threshold of 0 dropped NOTHING.
            # A zero threshold now means "ignore output distance", matching
            # the documented behaviour.
            redundant = close_in_state
        if not np.any(redundant):
            kept_idx.append(i)
            kept_s.append(s_norm[i])
            kept_d.append(d_norm[i])

    dropped = len(self.dataset) - len(kept_idx)
    self.dataset = [self.dataset[i] for i in kept_idx]
    self.save_dataset()
    print(f"drop_redundant: kept {len(self.dataset)}, dropped {dropped} samples.")
|
||||||
|
|
||||||
def _train_epoch(self, data, batch_size):
    """Run one optimisation epoch over *data*; return the mean per-batch loss.

    Loss is computed in tensor space (per-sample MSE against the target
    output vector, averaged over the batch) so gradients flow through the
    network. Returns 0.0 for empty *data*.
    """
    # BUG FIX: removed the dead `out_indices` table (computed, never used).
    if not data:
        return 0.0
    total_loss = 0.0
    num_batches = 0
    for i in range(0, len(data), batch_size):
        batch = data[i:i + batch_size]
        self.optimizer.zero_grad()
        loss = torch.tensor(0.0)
        for state, _, next_state, time_delta in batch:
            state_t = self.model._state_dict_to_tensor(state).unsqueeze(0)
            td_t = torch.tensor([[time_delta]], dtype=torch.float32)
            pred = self.model.net(state_t, td_t).squeeze(0)
            target = torch.tensor([next_state[p] for p in self.non_writable_params],
                                  dtype=torch.float32)
            loss = loss + torch.nn.functional.mse_loss(pred, target)
        loss = loss / len(batch)
        loss.backward()
        self.optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    # BUG FIX: the old denominator `len(data) // batch_size` ignored the
    # final partial batch (and was clamped to 1), skewing the reported loss.
    return total_loss / num_batches
|
|
||||||
|
|
||||||
def _test_epoch(self, data):
    """Evaluate mean per-sample MSE on *data* without gradient tracking."""
    # BUG FIX: a small dataset combined with `test_split` can leave the
    # held-out set empty, which previously raised ZeroDivisionError.
    if not data:
        return 0.0
    total_loss = 0.0
    with torch.no_grad():
        for state, _, next_state, time_delta in data:
            state_t = self.model._state_dict_to_tensor(state).unsqueeze(0)
            td_t = torch.tensor([[time_delta]], dtype=torch.float32)
            pred = self.model.net(state_t, td_t).squeeze(0)
            target = torch.tensor([next_state[p] for p in self.non_writable_params],
                                  dtype=torch.float32)
            total_loss += torch.nn.functional.mse_loss(pred, target).item()
    return total_loss / len(data)
|
||||||
|
|
||||||
def save_model(self, path):
    """Persist the model to *path*: torch state_dict for NN, pickle for kNN."""
    if not isinstance(self.model, ReactorDynamicsModel):
        # kNN model is plain Python/numpy — pickle the whole object.
        with open(path, 'wb') as f:
            pickle.dump(self.model, f)
    else:
        torch.save(self.model.state_dict(), path)
|
||||||
|
|
||||||
def load_model(self, path):
    """Load a model saved by save_model(), matching the current model type."""
    if not isinstance(self.model, ReactorDynamicsModel):
        # NOTE(review): pickle.load executes arbitrary code — only load
        # files this tool produced itself.
        with open(path, 'rb') as f:
            self.model = pickle.load(f)
    else:
        self.model.load_state_dict(torch.load(path))
|
||||||
|
|
||||||
def save_dataset(self, path=None):
|
def save_dataset(self, path=None):
|
||||||
path = path or self.dataset_path
|
path = path or self.dataset_path
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user