diff --git a/nucon/model.py b/nucon/model.py
index e579888..4669a83 100644
--- a/nucon/model.py
+++ b/nucon/model.py
@@ -95,7 +95,10 @@ class ReactorKNNModel:
         self._raw_states = np.array(raw)
         self._rates = np.array(rates)
         self._mean = self._raw_states.mean(axis=0)
-        self._std = self._raw_states.std(axis=0) + 1e-8
+        raw_std = self._raw_states.std(axis=0)
+        # Dimensions with zero variance in the training data carry no distance information.
+        # Use inf so they contribute 0 to normalised L2 (i.e., are ignored in kNN lookup).
+        self._std = np.where(raw_std < 1e-6, np.inf, raw_std)
         self._states = (self._raw_states - self._mean) / self._std

     def _lookup(self, s: np.ndarray):
diff --git a/scripts/train_sac.py b/scripts/train_sac.py
index 0ce51e9..e5f0b62 100644
--- a/scripts/train_sac.py
+++ b/scripts/train_sac.py
@@ -2,12 +2,14 @@

 Usage:
     python3.14 train_sac.py
+    python3.14 train_sac.py --load /tmp/sac_nucon_knn  # hot-start from previous run

 Requirements:
 - NuCon game running (for parameter metadata)
 - /tmp/reactor_knn.pkl (kNN-GP model)
 - /tmp/nucon_dataset.pkl (500-sample dataset for init_states)
 """
+import argparse
 import pickle
 from gymnasium.wrappers import TimeLimit
 from stable_baselines3 import SAC
@@ -16,6 +18,12 @@
 from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
 from nucon.sim import NuconSimulator
 from nucon.rl import NuconGoalEnv, Parameterized_Objectives, Parameterized_Terminators
+parser = argparse.ArgumentParser()
+parser.add_argument('--load', default=None, help='Path to existing model to hot-start from')
+parser.add_argument('--steps', type=int, default=50_000, help='Total timesteps (default: 50000)')
+parser.add_argument('--out', default='/tmp/sac_nucon_knn', help='Output path for saved model')
+args = parser.parse_args()
+
 # ---------------------------------------------------------------------------
 # Load model and dataset
 # ---------------------------------------------------------------------------
@@ -59,25 +67,33 @@ env = TimeLimit(env, max_episode_steps=MAX_EPISODE_STEPS)
 # before the first gradient step. As the policy learns to stay in-dist, episodes
 # will get longer and HER has more transitions to relabel.
 # ---------------------------------------------------------------------------
-model = SAC(
-    'MultiInputPolicy',
-    env,
-    replay_buffer_class=HerReplayBuffer,
-    replay_buffer_kwargs={
-        'n_sampled_goal': 4,
-        'goal_selection_strategy': 'future',
-    },
-    verbose=1,
-    learning_rate=3e-4,
-    batch_size=BATCH_SIZE,
-    tau=0.005,
-    gamma=0.98,
-    train_freq=64,
-    gradient_steps=8,
-    learning_starts=BATCH_SIZE,
-    device='auto',
-)
+if args.load:
+    print(f"Hot-starting from {args.load}")
+    model = SAC.load(args.load, env=env, device='auto',
+                     custom_objects={'learning_rate': 3e-4, 'batch_size': BATCH_SIZE,
+                                     'tau': 0.005, 'gamma': 0.98,
+                                     'train_freq': 64, 'gradient_steps': 8,
+                                     'learning_starts': 0})
+else:
+    model = SAC(
+        'MultiInputPolicy',
+        env,
+        replay_buffer_class=HerReplayBuffer,
+        replay_buffer_kwargs={
+            'n_sampled_goal': 4,
+            'goal_selection_strategy': 'future',
+        },
+        verbose=1,
+        learning_rate=3e-4,
+        batch_size=BATCH_SIZE,
+        tau=0.005,
+        gamma=0.98,
+        train_freq=64,
+        gradient_steps=8,
+        learning_starts=BATCH_SIZE,
+        device='auto',
+    )

-model.learn(total_timesteps=50_000)
-model.save('/tmp/sac_nucon_knn')
-print("Saved to /tmp/sac_nucon_knn.zip")
+model.learn(total_timesteps=args.steps)
+model.save(args.out)
+print(f"Saved to {args.out}.zip")