fix: kNN zero-variance dims get inf std; hot-start SAC from saved model

- nucon/model.py: constant input dimensions (zero variance in training
  data) now get std=inf so they contribute 0 to normalised kNN distance
  instead of causing catastrophic OOD from tiny float epsilon
- scripts/train_sac.py: add --load, --steps, --out CLI args; --load
  hot-starts actor/critic weights from a previous run (learning_starts=0)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Dominik Moritz Roth 2026-03-13 12:44:26 +01:00
parent f582e72151
commit 55d6e8708e
2 changed files with 41 additions and 22 deletions

View File

@@ -95,7 +95,10 @@ class ReactorKNNModel:
         self._raw_states = np.array(raw)
         self._rates = np.array(rates)
         self._mean = self._raw_states.mean(axis=0)
-        self._std = self._raw_states.std(axis=0) + 1e-8
+        raw_std = self._raw_states.std(axis=0)
+        # Dimensions with zero variance in the training data carry no distance information.
+        # Use inf so they contribute 0 to normalised L2 (i.e., are ignored in kNN lookup).
+        self._std = np.where(raw_std < 1e-6, np.inf, raw_std)
         self._states = (self._raw_states - self._mean) / self._std

     def _lookup(self, s: np.ndarray):

View File

@@ -2,12 +2,14 @@
 Usage:
     python3.14 train_sac.py
+    python3.14 train_sac.py --load /tmp/sac_nucon_knn   # hot-start from previous run

 Requirements:
   - NuCon game running (for parameter metadata)
   - /tmp/reactor_knn.pkl (kNN-GP model)
   - /tmp/nucon_dataset.pkl (500-sample dataset for init_states)
 """
+import argparse
 import pickle

 from gymnasium.wrappers import TimeLimit
 from stable_baselines3 import SAC
@@ -16,6 +18,12 @@ from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
 from nucon.sim import NuconSimulator
 from nucon.rl import NuconGoalEnv, Parameterized_Objectives, Parameterized_Terminators

+parser = argparse.ArgumentParser()
+parser.add_argument('--load', default=None, help='Path to existing model to hot-start from')
+parser.add_argument('--steps', type=int, default=50_000, help='Total timesteps (default: 50000)')
+parser.add_argument('--out', default='/tmp/sac_nucon_knn', help='Output path for saved model')
+args = parser.parse_args()
+
 # ---------------------------------------------------------------------------
 # Load model and dataset
 # ---------------------------------------------------------------------------
@@ -59,25 +67,33 @@ env = TimeLimit(env, max_episode_steps=MAX_EPISODE_STEPS)
 # before the first gradient step. As the policy learns to stay in-dist, episodes
 # will get longer and HER has more transitions to relabel.
 # ---------------------------------------------------------------------------
-model = SAC(
-    'MultiInputPolicy',
-    env,
-    replay_buffer_class=HerReplayBuffer,
-    replay_buffer_kwargs={
-        'n_sampled_goal': 4,
-        'goal_selection_strategy': 'future',
-    },
-    verbose=1,
-    learning_rate=3e-4,
-    batch_size=BATCH_SIZE,
-    tau=0.005,
-    gamma=0.98,
-    train_freq=64,
-    gradient_steps=8,
-    learning_starts=BATCH_SIZE,
-    device='auto',
-)
+if args.load:
+    print(f"Hot-starting from {args.load}")
+    model = SAC.load(args.load, env=env, device='auto',
+                     custom_objects={'learning_rate': 3e-4, 'batch_size': BATCH_SIZE,
+                                     'tau': 0.005, 'gamma': 0.98,
+                                     'train_freq': 64, 'gradient_steps': 8,
+                                     'learning_starts': 0})
+else:
+    model = SAC(
+        'MultiInputPolicy',
+        env,
+        replay_buffer_class=HerReplayBuffer,
+        replay_buffer_kwargs={
+            'n_sampled_goal': 4,
+            'goal_selection_strategy': 'future',
+        },
+        verbose=1,
+        learning_rate=3e-4,
+        batch_size=BATCH_SIZE,
+        tau=0.005,
+        gamma=0.98,
+        train_freq=64,
+        gradient_steps=8,
+        learning_starts=BATCH_SIZE,
+        device='auto',
+    )

-model.learn(total_timesteps=50_000)
-model.save('/tmp/sac_nucon_knn')
-print("Saved to /tmp/sac_nucon_knn.zip")
+model.learn(total_timesteps=args.steps)
+model.save(args.out)
+print(f"Saved to {args.out}.zip")