fix: kNN zero-variance dims get inf std; hot-start SAC from saved model
- nucon/model.py: constant input dimensions (zero variance in training data) now get std=inf so they contribute 0 to normalised kNN distance instead of causing catastrophic OOD from tiny float epsilon - scripts/train_sac.py: add --load, --steps, --out CLI args; --load hot-starts actor/critic weights from a previous run (learning_starts=0) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
f582e72151
commit
55d6e8708e
@ -95,7 +95,10 @@ class ReactorKNNModel:
|
||||
self._raw_states = np.array(raw)
|
||||
self._rates = np.array(rates)
|
||||
self._mean = self._raw_states.mean(axis=0)
|
||||
self._std = self._raw_states.std(axis=0) + 1e-8
|
||||
raw_std = self._raw_states.std(axis=0)
|
||||
# Dimensions with zero variance in the training data carry no distance information.
|
||||
# Use inf so they contribute 0 to normalised L2 (i.e., are ignored in kNN lookup).
|
||||
self._std = np.where(raw_std < 1e-6, np.inf, raw_std)
|
||||
self._states = (self._raw_states - self._mean) / self._std
|
||||
|
||||
def _lookup(self, s: np.ndarray):
|
||||
|
||||
@ -2,12 +2,14 @@
|
||||
|
||||
Usage:
|
||||
python3.14 train_sac.py
|
||||
python3.14 train_sac.py --load /tmp/sac_nucon_knn # hot-start from previous run
|
||||
|
||||
Requirements:
|
||||
- NuCon game running (for parameter metadata)
|
||||
- /tmp/reactor_knn.pkl (kNN-GP model)
|
||||
- /tmp/nucon_dataset.pkl (500-sample dataset for init_states)
|
||||
"""
|
||||
import argparse
|
||||
import pickle
|
||||
from gymnasium.wrappers import TimeLimit
|
||||
from stable_baselines3 import SAC
|
||||
@ -16,6 +18,12 @@ from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
|
||||
from nucon.sim import NuconSimulator
|
||||
from nucon.rl import NuconGoalEnv, Parameterized_Objectives, Parameterized_Terminators
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--load', default=None, help='Path to existing model to hot-start from')
|
||||
parser.add_argument('--steps', type=int, default=50_000, help='Total timesteps (default: 50000)')
|
||||
parser.add_argument('--out', default='/tmp/sac_nucon_knn', help='Output path for saved model')
|
||||
args = parser.parse_args()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Load model and dataset
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -59,25 +67,33 @@ env = TimeLimit(env, max_episode_steps=MAX_EPISODE_STEPS)
|
||||
# before the first gradient step. As the policy learns to stay in-dist, episodes
|
||||
# will get longer and HER has more transitions to relabel.
|
||||
# ---------------------------------------------------------------------------
|
||||
model = SAC(
|
||||
'MultiInputPolicy',
|
||||
env,
|
||||
replay_buffer_class=HerReplayBuffer,
|
||||
replay_buffer_kwargs={
|
||||
'n_sampled_goal': 4,
|
||||
'goal_selection_strategy': 'future',
|
||||
},
|
||||
verbose=1,
|
||||
learning_rate=3e-4,
|
||||
batch_size=BATCH_SIZE,
|
||||
tau=0.005,
|
||||
gamma=0.98,
|
||||
train_freq=64,
|
||||
gradient_steps=8,
|
||||
learning_starts=BATCH_SIZE,
|
||||
device='auto',
|
||||
)
|
||||
if args.load:
|
||||
print(f"Hot-starting from {args.load}")
|
||||
model = SAC.load(args.load, env=env, device='auto',
|
||||
custom_objects={'learning_rate': 3e-4, 'batch_size': BATCH_SIZE,
|
||||
'tau': 0.005, 'gamma': 0.98,
|
||||
'train_freq': 64, 'gradient_steps': 8,
|
||||
'learning_starts': 0})
|
||||
else:
|
||||
model = SAC(
|
||||
'MultiInputPolicy',
|
||||
env,
|
||||
replay_buffer_class=HerReplayBuffer,
|
||||
replay_buffer_kwargs={
|
||||
'n_sampled_goal': 4,
|
||||
'goal_selection_strategy': 'future',
|
||||
},
|
||||
verbose=1,
|
||||
learning_rate=3e-4,
|
||||
batch_size=BATCH_SIZE,
|
||||
tau=0.005,
|
||||
gamma=0.98,
|
||||
train_freq=64,
|
||||
gradient_steps=8,
|
||||
learning_starts=BATCH_SIZE,
|
||||
device='auto',
|
||||
)
|
||||
|
||||
model.learn(total_timesteps=50_000)
|
||||
model.save('/tmp/sac_nucon_knn')
|
||||
print("Saved to /tmp/sac_nucon_knn.zip")
|
||||
model.learn(total_timesteps=args.steps)
|
||||
model.save(args.out)
|
||||
print(f"Saved to {args.out}.zip")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user