fix: kNN zero-variance dims get inf std; hot-start SAC from saved model
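- nucon/model.py: constant input dimensions (zero variance in the training data, detected as std < 1e-6) now get std=inf, so they contribute 0 to the normalised kNN distance. Previously they were divided by a tiny float epsilon (std + 1e-8), which inflated any small deviation in such a dimension into a huge normalised distance and made the query look catastrophically out-of-distribution (OOD).
- scripts/train_sac.py: add --load, --steps and --out CLI args; --load hot-starts the actor/critic weights from a previous run (with learning_starts=0).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>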
This commit is contained in:
parent
f582e72151
commit
55d6e8708e
@ -95,7 +95,10 @@ class ReactorKNNModel:
|
|||||||
self._raw_states = np.array(raw)
|
self._raw_states = np.array(raw)
|
||||||
self._rates = np.array(rates)
|
self._rates = np.array(rates)
|
||||||
self._mean = self._raw_states.mean(axis=0)
|
self._mean = self._raw_states.mean(axis=0)
|
||||||
self._std = self._raw_states.std(axis=0) + 1e-8
|
raw_std = self._raw_states.std(axis=0)
|
||||||
|
# Dimensions with (near-)zero variance in the training data (std < 1e-6) carry no distance information.
|
||||||
|
# Use inf so they contribute 0 to normalised L2 (i.e., are ignored in kNN lookup).
|
||||||
|
self._std = np.where(raw_std < 1e-6, np.inf, raw_std)
|
||||||
self._states = (self._raw_states - self._mean) / self._std
|
self._states = (self._raw_states - self._mean) / self._std
|
||||||
|
|
||||||
def _lookup(self, s: np.ndarray):
|
def _lookup(self, s: np.ndarray):
|
||||||
|
|||||||
@ -2,12 +2,14 @@
|
|||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
python3.14 train_sac.py
|
python3.14 train_sac.py
|
||||||
|
python3.14 train_sac.py --load /tmp/sac_nucon_knn # hot-start from previous run
|
||||||
|
|
||||||
Requirements:
|
Requirements:
|
||||||
- NuCon game running (for parameter metadata)
|
- NuCon game running (for parameter metadata)
|
||||||
- /tmp/reactor_knn.pkl (kNN-GP model)
|
- /tmp/reactor_knn.pkl (kNN-GP model)
|
||||||
- /tmp/nucon_dataset.pkl (500-sample dataset for init_states)
|
- /tmp/nucon_dataset.pkl (500-sample dataset for init_states)
|
||||||
"""
|
"""
|
||||||
|
import argparse
|
||||||
import pickle
|
import pickle
|
||||||
from gymnasium.wrappers import TimeLimit
|
from gymnasium.wrappers import TimeLimit
|
||||||
from stable_baselines3 import SAC
|
from stable_baselines3 import SAC
|
||||||
@ -16,6 +18,12 @@ from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
|
|||||||
from nucon.sim import NuconSimulator
|
from nucon.sim import NuconSimulator
|
||||||
from nucon.rl import NuconGoalEnv, Parameterized_Objectives, Parameterized_Terminators
|
from nucon.rl import NuconGoalEnv, Parameterized_Objectives, Parameterized_Terminators
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--load', default=None, help='Path to existing model to hot-start from')
|
||||||
|
parser.add_argument('--steps', type=int, default=50_000, help='Total timesteps (default: 50000)')
|
||||||
|
parser.add_argument('--out', default='/tmp/sac_nucon_knn', help='Output path for saved model')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Load model and dataset
|
# Load model and dataset
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@ -59,6 +67,14 @@ env = TimeLimit(env, max_episode_steps=MAX_EPISODE_STEPS)
|
|||||||
# before the first gradient step. As the policy learns to stay in-dist, episodes
|
# before the first gradient step. As the policy learns to stay in-dist, episodes
|
||||||
# will get longer and HER has more transitions to relabel.
|
# will get longer and HER has more transitions to relabel.
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
if args.load:
|
||||||
|
print(f"Hot-starting from {args.load}")
|
||||||
|
model = SAC.load(args.load, env=env, device='auto',
|
||||||
|
custom_objects={'learning_rate': 3e-4, 'batch_size': BATCH_SIZE,
|
||||||
|
'tau': 0.005, 'gamma': 0.98,
|
||||||
|
'train_freq': 64, 'gradient_steps': 8,
|
||||||
|
'learning_starts': 0})
|
||||||
|
else:
|
||||||
model = SAC(
|
model = SAC(
|
||||||
'MultiInputPolicy',
|
'MultiInputPolicy',
|
||||||
env,
|
env,
|
||||||
@ -78,6 +94,6 @@ model = SAC(
|
|||||||
device='auto',
|
device='auto',
|
||||||
)
|
)
|
||||||
|
|
||||||
model.learn(total_timesteps=50_000)
|
model.learn(total_timesteps=args.steps)
|
||||||
model.save('/tmp/sac_nucon_knn')
|
model.save(args.out)
|
||||||
print("Saved to /tmp/sac_nucon_knn.zip")
|
print(f"Saved to {args.out}.zip")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user