From ce2019e060001892f3b7190c514a54a2679d18b2 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Thu, 12 Mar 2026 17:50:31 +0100 Subject: [PATCH] refactor: remove model_type from NuconModelLearner.__init__ Model type is irrelevant during data collection. Models are now created lazily on first use: train_model() creates a ReactorDynamicsModel, fit_knn(k) creates a ReactorKNNModel. load_model() now detects the model type from the file extension ('.pkl' loads a pickled kNN model, anything else a torch NN checkpoint) instead of from the learner's existing model. drop_well_fitted() and save_model() now check that a model exists. Co-Authored-By: Claude Sonnet 4.6 --- README.md | 12 +++++------ nucon/model.py | 57 +++++++++++++++++++++++++++----------------------- 2 files changed, 37 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 6a796b6..e627fcf 100644 --- a/README.md +++ b/README.md @@ -315,9 +315,9 @@ pip install -e '.[model]' ```python from nucon.model import NuconModelLearner -# --- Data collection --- +# --- Data collection (model_type not needed here) --- learner = NuconModelLearner( - time_delta=10.0, # 10 game-seconds per step (wall sleep auto-scales with sim speed) + time_delta=10.0, # 10 game-seconds per step (wall sleep auto-scales with sim speed) include_valve_states=False, # set True to include all 53 valve positions as model inputs ) learner.collect_data(num_steps=1000) @@ -327,19 +327,19 @@ learner.save_dataset('reactor_dataset.pkl') learner.merge_datasets('other_session.pkl') # --- Neural network backend --- -nn_learner = NuconModelLearner(model_type='nn', dataset_path='reactor_dataset.pkl') -nn_learner.train_model(batch_size=32, num_epochs=50) +nn_learner = NuconModelLearner(dataset_path='reactor_dataset.pkl') +nn_learner.train_model(batch_size=32, num_epochs=50) # creates NN model on first call # Drop samples the NN already predicts well (keep hard cases for further training) nn_learner.drop_well_fitted(error_threshold=1.0) nn_learner.save_model('reactor_nn.pth') # --- kNN + GP backend --- -knn_learner = NuconModelLearner(model_type='knn', knn_k=10, 
dataset_path='reactor_dataset.pkl') +knn_learner = NuconModelLearner(dataset_path='reactor_dataset.pkl') # Drop near-duplicate samples before fitting (keeps diverse coverage). # A sample is dropped only if BOTH its input state AND output transition # are within the given distances of an already-kept sample. knn_learner.drop_redundant(min_state_distance=0.1, min_output_distance=0.05) -knn_learner.fit_knn() +knn_learner.fit_knn(k=10) # creates kNN model on first call # Point prediction state = knn_learner._get_state() diff --git a/nucon/model.py b/nucon/model.py index 43f060e..5b87ded 100644 --- a/nucon/model.py +++ b/nucon/model.py @@ -152,13 +152,14 @@ class ReactorKNNModel: class NuconModelLearner: def __init__(self, nucon=None, actor='null', dataset_path='nucon_dataset.pkl', time_delta: Union[float, Tuple[float, float]] = 1.0, - model_type: str = 'nn', knn_k: int = 5, include_valve_states: bool = False): self.nucon = Nucon() if nucon is None else nucon self.actor = Actors[actor](self.nucon) if actor in Actors else actor self.dataset = self.load_dataset(dataset_path) or [] self.dataset_path = dataset_path self.include_valve_states = include_valve_states + self.model = None + self.optimizer = None # Exclude params with no physics signal _JUNK_PARAMS = frozenset({'GAME_VERSION', 'TIME', 'TIME_STAMP', 'TIME_DAY', @@ -179,15 +180,6 @@ class NuconModelLearner: self.readable_params = self.readable_params + self.valve_keys # valve positions are input-only (not predicted as outputs) - if model_type == 'nn': - self.model = ReactorDynamicsModel(self.readable_params, self.non_writable_params) - self.optimizer = optim.Adam(self.model.parameters()) - elif model_type == 'knn': - self.model = ReactorKNNModel(self.readable_params, self.non_writable_params, k=knn_k) - self.optimizer = None - else: - raise ValueError(f"Unknown model_type '{model_type}'. 
Use 'nn' or 'knn'.") - if isinstance(time_delta, (int, float)): self.time_delta = lambda: time_delta elif isinstance(time_delta, tuple) and len(time_delta) == 2: @@ -235,9 +227,12 @@ class NuconModelLearner: self.save_dataset() def train_model(self, batch_size=32, num_epochs=10, test_split=0.2): - """Train the NN model. For kNN, call fit_knn() instead.""" - if not isinstance(self.model, ReactorDynamicsModel): - raise ValueError("train_model() is for the NN model. Use fit_knn() for kNN.") + """Train a neural-network dynamics model on the current dataset.""" + if self.model is None: + self.model = ReactorDynamicsModel(self.readable_params, self.non_writable_params) + self.optimizer = optim.Adam(self.model.parameters()) + elif not isinstance(self.model, ReactorDynamicsModel): + raise ValueError("A kNN model is already loaded. Create a new learner to train an NN.") random.shuffle(self.dataset) split_idx = int(len(self.dataset) * (1 - test_split)) train_data = self.dataset[:split_idx] @@ -247,17 +242,19 @@ class NuconModelLearner: test_loss = self._test_epoch(test_data) print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}") - def fit_knn(self): - """Fit the kNN/GP model from the current dataset (instantaneous, no gradient steps).""" - if not isinstance(self.model, ReactorKNNModel): - raise ValueError("fit_knn() is for the kNN model. Use train_model() for NN.") + def fit_knn(self, k: int = 5): + """Fit a kNN/GP dynamics model from the current dataset (instantaneous, no gradient steps).""" + if self.model is None: + self.model = ReactorKNNModel(self.readable_params, self.non_writable_params, k=k) + elif not isinstance(self.model, ReactorKNNModel): + raise ValueError("An NN model is already loaded. 
Create a new learner to fit a kNN.") self.model.fit(self.dataset) print(f"kNN model fitted on {len(self.dataset)} samples.") def predict_with_uncertainty(self, state_dict: Dict, time_delta: float): - """Return (prediction_dict, uncertainty_std). Only available for kNN model.""" + """Return (prediction_dict, uncertainty_std). Only available after fit_knn().""" if not isinstance(self.model, ReactorKNNModel): - raise ValueError("predict_with_uncertainty() requires model_type='knn'.") + raise ValueError("predict_with_uncertainty() requires a fitted kNN model (call fit_knn()).") return self.model.forward_with_uncertainty(state_dict, time_delta) def drop_well_fitted(self, error_threshold: float): @@ -266,6 +263,8 @@ class NuconModelLearner: Keeps only hard/surprising transitions. Useful for NN training to focus capacity on difficult regions of state space. """ + if self.model is None: + raise ValueError("No model fitted yet. Call train_model() or fit_knn() first.") kept = [] for state, action, next_state, time_delta in self.dataset: pred = self.model.forward(state, time_delta) @@ -359,6 +358,8 @@ class NuconModelLearner: return total_loss / len(data) def save_model(self, path): + if self.model is None: + raise ValueError("No model to save. 
Call train_model() or fit_knn() first.") if isinstance(self.model, ReactorDynamicsModel): torch.save({ 'state_dict': self.model.state_dict(), @@ -370,15 +371,19 @@ class NuconModelLearner: pickle.dump(self.model, f) def load_model(self, path): - if isinstance(self.model, ReactorDynamicsModel): - checkpoint = torch.load(path, weights_only=False) - if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: - self.model.load_state_dict(checkpoint['state_dict']) - else: - self.model.load_state_dict(checkpoint) - else: + if path.endswith('.pkl'): with open(path, 'rb') as f: self.model = pickle.load(f) + else: + checkpoint = torch.load(path, weights_only=False) + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + m = ReactorDynamicsModel(checkpoint['input_params'], checkpoint['output_params']) + m.load_state_dict(checkpoint['state_dict']) + self.model = m + else: + # legacy plain state dict + self.model = ReactorDynamicsModel(self.readable_params, self.non_writable_params) + self.model.load_state_dict(checkpoint) def save_dataset(self, path=None): path = path or self.dataset_path