"""SAC + HER training on kNN-GP simulator. Usage: python3.14 train_sac.py python3.14 train_sac.py --load /tmp/sac_nucon_knn # hot-start from previous run Requirements: - NuCon game running (for parameter metadata) - /tmp/reactor_knn.pkl (kNN-GP model) - /tmp/nucon_dataset.pkl (500-sample dataset for init_states) """ import argparse import pickle from gymnasium.wrappers import TimeLimit from stable_baselines3 import SAC from stable_baselines3.her.her_replay_buffer import HerReplayBuffer from nucon.sim import NuconSimulator from nucon.rl import NuconGoalEnv, Parameterized_Objectives, Parameterized_Terminators parser = argparse.ArgumentParser() parser.add_argument('--load', default=None, help='Path to existing model to hot-start from') parser.add_argument('--steps', type=int, default=50_000, help='Total timesteps (default: 50000)') parser.add_argument('--out', default='/tmp/sac_nucon_knn', help='Output path for saved model') args = parser.parse_args() # --------------------------------------------------------------------------- # Load model and dataset # --------------------------------------------------------------------------- with open('/tmp/reactor_knn.pkl', 'rb') as f: knn_model = pickle.load(f) with open('/tmp/nucon_dataset.pkl', 'rb') as f: dataset = pickle.load(f) # Seed resets to in-distribution states from dataset init_states = [s for _, _, s, _ in dataset] # --------------------------------------------------------------------------- # Build sim + env # --------------------------------------------------------------------------- sim = NuconSimulator(port=8786) sim.set_model(knn_model) BATCH_SIZE = 2048 MAX_EPISODE_STEPS = 200 env = NuconGoalEnv( goal_params=['CORE_TEMP'], goal_range={'CORE_TEMP': (55.0, 550.0)}, tolerance=0.05, seconds_per_step=10, simulator=sim, additional_objectives=[ Parameterized_Objectives['uncertainty_penalty'](start=0.3), ], additional_objective_weights=[1.0], init_states=init_states, delta_action_scale=0.05, ) env = TimeLimit(env, max_episode_steps=MAX_EPISODE_STEPS) # --------------------------------------------------------------------------- # SAC + HER # learning_starts = batch_size: wait for batch_size complete (short) episodes # before the first gradient step. As the policy learns to stay in-dist, episodes # will get longer and HER has more transitions to relabel. # --------------------------------------------------------------------------- if args.load: print(f"Hot-starting from {args.load}") model = SAC.load(args.load, env=env, device='auto', custom_objects={'learning_rate': 3e-4, 'batch_size': BATCH_SIZE, 'tau': 0.005, 'gamma': 0.98, 'train_freq': 64, 'gradient_steps': 8, 'learning_starts': 0}) else: model = SAC( 'MultiInputPolicy', env, replay_buffer_class=HerReplayBuffer, replay_buffer_kwargs={ 'n_sampled_goal': 4, 'goal_selection_strategy': 'future', }, verbose=1, learning_rate=3e-4, batch_size=BATCH_SIZE, tau=0.005, gamma=0.98, train_freq=64, gradient_steps=8, learning_starts=BATCH_SIZE, device='auto', ) model.learn(total_timesteps=args.steps) model.save(args.out) print(f"Saved to {args.out}.zip")