refactor: remove model_type from NuconModelLearner.__init__
Model type is irrelevant during data collection. Models are now created lazily on first use: train_model() creates a ReactorDynamicsModel, fit_knn(k) creates a ReactorKNNModel. load_model() detects type by file extension as before. drop_well_fitted() now checks model exists. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
1f7ecc301f
commit
ce2019e060
12
README.md
12
README.md
@ -315,9 +315,9 @@ pip install -e '.[model]'
|
||||
```python
|
||||
from nucon.model import NuconModelLearner
|
||||
|
||||
# --- Data collection ---
|
||||
# --- Data collection (model_type not needed here) ---
|
||||
learner = NuconModelLearner(
|
||||
time_delta=10.0, # 10 game-seconds per step (wall sleep auto-scales with sim speed)
|
||||
time_delta=10.0, # 10 game-seconds per step (wall sleep auto-scales with sim speed)
|
||||
include_valve_states=False, # set True to include all 53 valve positions as model inputs
|
||||
)
|
||||
learner.collect_data(num_steps=1000)
|
||||
@ -327,19 +327,19 @@ learner.save_dataset('reactor_dataset.pkl')
|
||||
learner.merge_datasets('other_session.pkl')
|
||||
|
||||
# --- Neural network backend ---
|
||||
nn_learner = NuconModelLearner(model_type='nn', dataset_path='reactor_dataset.pkl')
|
||||
nn_learner.train_model(batch_size=32, num_epochs=50)
|
||||
nn_learner = NuconModelLearner(dataset_path='reactor_dataset.pkl')
|
||||
nn_learner.train_model(batch_size=32, num_epochs=50) # creates NN model on first call
|
||||
# Drop samples the NN already predicts well (keep hard cases for further training)
|
||||
nn_learner.drop_well_fitted(error_threshold=1.0)
|
||||
nn_learner.save_model('reactor_nn.pth')
|
||||
|
||||
# --- kNN + GP backend ---
|
||||
knn_learner = NuconModelLearner(model_type='knn', knn_k=10, dataset_path='reactor_dataset.pkl')
|
||||
knn_learner = NuconModelLearner(dataset_path='reactor_dataset.pkl')
|
||||
# Drop near-duplicate samples before fitting (keeps diverse coverage).
|
||||
# A sample is dropped only if BOTH its input state AND output transition
|
||||
# are within the given distances of an already-kept sample.
|
||||
knn_learner.drop_redundant(min_state_distance=0.1, min_output_distance=0.05)
|
||||
knn_learner.fit_knn()
|
||||
knn_learner.fit_knn(k=10) # creates kNN model on first call
|
||||
|
||||
# Point prediction
|
||||
state = knn_learner._get_state()
|
||||
|
||||
@ -152,13 +152,14 @@ class ReactorKNNModel:
|
||||
class NuconModelLearner:
|
||||
def __init__(self, nucon=None, actor='null', dataset_path='nucon_dataset.pkl',
|
||||
time_delta: Union[float, Tuple[float, float]] = 1.0,
|
||||
model_type: str = 'nn', knn_k: int = 5,
|
||||
include_valve_states: bool = False):
|
||||
self.nucon = Nucon() if nucon is None else nucon
|
||||
self.actor = Actors[actor](self.nucon) if actor in Actors else actor
|
||||
self.dataset = self.load_dataset(dataset_path) or []
|
||||
self.dataset_path = dataset_path
|
||||
self.include_valve_states = include_valve_states
|
||||
self.model = None
|
||||
self.optimizer = None
|
||||
|
||||
# Exclude params with no physics signal
|
||||
_JUNK_PARAMS = frozenset({'GAME_VERSION', 'TIME', 'TIME_STAMP', 'TIME_DAY',
|
||||
@ -179,15 +180,6 @@ class NuconModelLearner:
|
||||
self.readable_params = self.readable_params + self.valve_keys
|
||||
# valve positions are input-only (not predicted as outputs)
|
||||
|
||||
if model_type == 'nn':
|
||||
self.model = ReactorDynamicsModel(self.readable_params, self.non_writable_params)
|
||||
self.optimizer = optim.Adam(self.model.parameters())
|
||||
elif model_type == 'knn':
|
||||
self.model = ReactorKNNModel(self.readable_params, self.non_writable_params, k=knn_k)
|
||||
self.optimizer = None
|
||||
else:
|
||||
raise ValueError(f"Unknown model_type '{model_type}'. Use 'nn' or 'knn'.")
|
||||
|
||||
if isinstance(time_delta, (int, float)):
|
||||
self.time_delta = lambda: time_delta
|
||||
elif isinstance(time_delta, tuple) and len(time_delta) == 2:
|
||||
@ -235,9 +227,12 @@ class NuconModelLearner:
|
||||
self.save_dataset()
|
||||
|
||||
def train_model(self, batch_size=32, num_epochs=10, test_split=0.2):
|
||||
"""Train the NN model. For kNN, call fit_knn() instead."""
|
||||
if not isinstance(self.model, ReactorDynamicsModel):
|
||||
raise ValueError("train_model() is for the NN model. Use fit_knn() for kNN.")
|
||||
"""Train a neural-network dynamics model on the current dataset."""
|
||||
if self.model is None:
|
||||
self.model = ReactorDynamicsModel(self.readable_params, self.non_writable_params)
|
||||
self.optimizer = optim.Adam(self.model.parameters())
|
||||
elif not isinstance(self.model, ReactorDynamicsModel):
|
||||
raise ValueError("A kNN model is already loaded. Create a new learner to train an NN.")
|
||||
random.shuffle(self.dataset)
|
||||
split_idx = int(len(self.dataset) * (1 - test_split))
|
||||
train_data = self.dataset[:split_idx]
|
||||
@ -247,17 +242,19 @@ class NuconModelLearner:
|
||||
test_loss = self._test_epoch(test_data)
|
||||
print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")
|
||||
|
||||
def fit_knn(self):
|
||||
"""Fit the kNN/GP model from the current dataset (instantaneous, no gradient steps)."""
|
||||
if not isinstance(self.model, ReactorKNNModel):
|
||||
raise ValueError("fit_knn() is for the kNN model. Use train_model() for NN.")
|
||||
def fit_knn(self, k: int = 5):
|
||||
"""Fit a kNN/GP dynamics model from the current dataset (instantaneous, no gradient steps)."""
|
||||
if self.model is None:
|
||||
self.model = ReactorKNNModel(self.readable_params, self.non_writable_params, k=k)
|
||||
elif not isinstance(self.model, ReactorKNNModel):
|
||||
raise ValueError("An NN model is already loaded. Create a new learner to fit a kNN.")
|
||||
self.model.fit(self.dataset)
|
||||
print(f"kNN model fitted on {len(self.dataset)} samples.")
|
||||
|
||||
def predict_with_uncertainty(self, state_dict: Dict, time_delta: float):
|
||||
"""Return (prediction_dict, uncertainty_std). Only available for kNN model."""
|
||||
"""Return (prediction_dict, uncertainty_std). Only available after fit_knn()."""
|
||||
if not isinstance(self.model, ReactorKNNModel):
|
||||
raise ValueError("predict_with_uncertainty() requires model_type='knn'.")
|
||||
raise ValueError("predict_with_uncertainty() requires a fitted kNN model (call fit_knn()).")
|
||||
return self.model.forward_with_uncertainty(state_dict, time_delta)
|
||||
|
||||
def drop_well_fitted(self, error_threshold: float):
|
||||
@ -266,6 +263,8 @@ class NuconModelLearner:
|
||||
Keeps only hard/surprising transitions. Useful for NN training to focus
|
||||
capacity on difficult regions of state space.
|
||||
"""
|
||||
if self.model is None:
|
||||
raise ValueError("No model fitted yet. Call train_model() or fit_knn() first.")
|
||||
kept = []
|
||||
for state, action, next_state, time_delta in self.dataset:
|
||||
pred = self.model.forward(state, time_delta)
|
||||
@ -359,6 +358,8 @@ class NuconModelLearner:
|
||||
return total_loss / len(data)
|
||||
|
||||
def save_model(self, path):
|
||||
if self.model is None:
|
||||
raise ValueError("No model to save. Call train_model() or fit_knn() first.")
|
||||
if isinstance(self.model, ReactorDynamicsModel):
|
||||
torch.save({
|
||||
'state_dict': self.model.state_dict(),
|
||||
@ -370,15 +371,19 @@ class NuconModelLearner:
|
||||
pickle.dump(self.model, f)
|
||||
|
||||
def load_model(self, path):
|
||||
if isinstance(self.model, ReactorDynamicsModel):
|
||||
checkpoint = torch.load(path, weights_only=False)
|
||||
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
||||
self.model.load_state_dict(checkpoint['state_dict'])
|
||||
else:
|
||||
self.model.load_state_dict(checkpoint)
|
||||
else:
|
||||
if path.endswith('.pkl'):
|
||||
with open(path, 'rb') as f:
|
||||
self.model = pickle.load(f)
|
||||
else:
|
||||
checkpoint = torch.load(path, weights_only=False)
|
||||
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
||||
m = ReactorDynamicsModel(checkpoint['input_params'], checkpoint['output_params'])
|
||||
m.load_state_dict(checkpoint['state_dict'])
|
||||
self.model = m
|
||||
else:
|
||||
# legacy plain state dict
|
||||
self.model = ReactorDynamicsModel(self.readable_params, self.non_writable_params)
|
||||
self.model.load_state_dict(checkpoint)
|
||||
|
||||
def save_dataset(self, path=None):
|
||||
path = path or self.dataset_path
|
||||
|
||||
Loading…
Reference in New Issue
Block a user