"""SAC + HER training on a learned-dynamics NuCon simulator (kNN-GP by default).

Usage:
    python3.14 train_sac.py
    python3.14 train_sac.py --load /tmp/sac_nucon_knn   # hot-start from previous run
    python3.14 train_sac.py --model a.pkl --model2 b.pt  # mixture of two dynamics models

Requirements:
    - NuCon game running (for parameter metadata)
    - /tmp/reactor_knn.pkl (kNN-GP model; a .pt NN checkpoint also works via --model)
    - /tmp/nucon_dataset.pkl (500-sample dataset for init_states)
"""
import argparse
import pickle

import torch
from gymnasium.wrappers import TimeLimit
from stable_baselines3 import SAC
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
from stable_baselines3.common.callbacks import CheckpointCallback

from nucon.sim import NuconSimulator
from nucon.model import ReactorDynamicsModel, MixtureModel
from nucon.rl import NuconGoalEnv, Parameterized_Objectives, Parameterized_Terminators

parser = argparse.ArgumentParser()
parser.add_argument('--load', default=None, help='Path to existing model to hot-start from')
parser.add_argument('--steps', type=int, default=50_000, help='Total timesteps (default: 50000)')
parser.add_argument('--out', default='/tmp/sac_nucon_knn', help='Output path for saved model')
parser.add_argument('--model', default='/tmp/reactor_knn.pkl', help='Dynamics model (.pkl for kNN, .pt for NN)')
parser.add_argument('--model2', default=None, help='Second dynamics model for mixture (optional)')
parser.add_argument('--dataset', default='/tmp/nucon_dataset.pkl', help='Dataset for init states')
args = parser.parse_args()


# ---------------------------------------------------------------------------
# Load dynamics model(s) and dataset
# ---------------------------------------------------------------------------
def _load_model(path):
    """Load a dynamics model from *path*.

    Files ending in ``.pt`` are treated as ReactorDynamicsModel checkpoints:
    a dict with 'input_params', 'output_params' and 'state_dict'. The model is
    rebuilt from those keys and put in eval mode. Any other file is unpickled
    and returned as-is (e.g. the kNN-GP model).

    NOTE: both branches unpickle arbitrary objects (``weights_only=False`` and
    ``pickle.load``), which can execute code. Only load files you trust.
    """
    if path.endswith('.pt'):
        # map_location='cpu' lets checkpoints saved on a GPU be deserialized on
        # CPU-only hosts. load_state_dict copies the values into the module's
        # own parameters, so this does not change the model's final device.
        ckpt = torch.load(path, map_location='cpu', weights_only=False)
        model = ReactorDynamicsModel(ckpt['input_params'], ckpt['output_params'])
        model.load_state_dict(ckpt['state_dict'])
        model.eval()
        return model
    with open(path, 'rb') as f:
        return pickle.load(f)


dynamics_model = _load_model(args.model)
if args.model2:
    dynamics_model = MixtureModel(dynamics_model, _load_model(args.model2))

with open(args.dataset, 'rb') as f:
    dataset = pickle.load(f)  # pickle: trusted local files only

# Seed resets to in-distribution states from the dataset. Each record is a
# 4-tuple, and its third field is used as a reset state.
init_states = [s for _, _, s, _ in dataset]

# ---------------------------------------------------------------------------
# Build sim + env
# ---------------------------------------------------------------------------
sim = NuconSimulator(port=8786)
sim.set_model(dynamics_model)

BATCH_SIZE = 2048
MAX_EPISODE_STEPS = 200

GENERATORS = ['GENERATOR_0_KW', 'GENERATOR_1_KW', 'GENERATOR_2_KW']
POWER_RANGE = {g: (0.0, 100_000.0) for g in GENERATORS}  # per-generator kW; ~100 MW upper bound

# Curated obs: physically relevant features for power control
# (23 params vs ~260 in the full observation).
OBS_PARAMS = [
    'CORE_TEMP', 'CORE_PRESSURE', 'CORE_STATE_CRITICALITY', 'CORE_WEAR', 'CORE_INTEGRITY',
    'ROD_BANK_POS_0_ACTUAL', 'ROD_BANK_POS_0_ORDERED',
    'COOLANT_CORE_FLOW_SPEED', 'COOLANT_CORE_VESSEL_TEMPERATURE',
    'COOLANT_CORE_PRESSURE', 'COOLANT_CORE_QUANTITY_IN_VESSEL',
    'STEAM_TURBINE_0_RPM', 'STEAM_TURBINE_0_TEMPERATURE', 'STEAM_TURBINE_0_PRESSURE',
    'STEAM_TURBINE_1_RPM', 'STEAM_TURBINE_1_TEMPERATURE', 'STEAM_TURBINE_1_PRESSURE',
    'STEAM_TURBINE_2_RPM', 'STEAM_TURBINE_2_TEMPERATURE', 'STEAM_TURBINE_2_PRESSURE',
    'GENERATOR_0_V', 'GENERATOR_1_V', 'GENERATOR_2_V',
]

env = NuconGoalEnv(
    goal_params=GENERATORS,
    goal_range=POWER_RANGE,
    seconds_per_step=10,
    simulator=sim,
    obs_params=OBS_PARAMS,
    additional_objectives=[
        Parameterized_Objectives['uncertainty_penalty'](start=0.3),
        Parameterized_Objectives['temp_below_linear'](max_temp=420),
    ],
    # One weight per additional objective, in the same order as above.
    additional_objective_weights=[1.0, 0.01],
    init_states=init_states,
    delta_action_scale=0.05,
    # Gaussian goal delta in normalised goal space.
    # NOTE(review): if normalisation is linear over POWER_RANGE, 0.15 is about
    # 15 MW per generator, not the ~180 kW previously noted here. Confirm.
    goal_sampling_std=0.15,
)
env = TimeLimit(env, max_episode_steps=MAX_EPISODE_STEPS)
# ---------------------------------------------------------------------------
# SAC + HER
# learning_starts counts environment *steps*, not episodes. A fresh run
# collects BATCH_SIZE transitions before the first gradient step: at least
# BATCH_SIZE / MAX_EPISODE_STEPS (~10) episodes, and more if episodes end early.
# As the policy learns to stay in-distribution, episodes get longer and HER has
# more transitions to relabel.
# ---------------------------------------------------------------------------
import json
import os

# Hyperparameters shared by the fresh and hot-start paths. They are defined
# once so the two branches cannot drift apart. learning_starts differs per
# path and is set separately below.
SAC_HPARAMS = {
    'learning_rate': 3e-4,
    'batch_size': BATCH_SIZE,
    'tau': 0.005,
    'gamma': 0.98,
    'train_freq': 64,
    'gradient_steps': 8,
    'ent_coef': 0.1,  # fixed; auto-tuning diverges on this many action dims
}

if args.load:
    print(f"Hot-starting from {args.load}")
    # custom_objects overrides the hyperparameters stored in the saved model.
    # SB3's save() excludes the replay buffer contents, so HER starts from an
    # empty buffer here. learning_starts is shortened to one max-length
    # episode, presumably because the networks are already trained.
    model = SAC.load(
        args.load,
        env=env,
        device='auto',
        custom_objects={**SAC_HPARAMS, 'learning_starts': MAX_EPISODE_STEPS},
    )
else:
    model = SAC(
        'MultiInputPolicy',
        env,
        replay_buffer_class=HerReplayBuffer,
        replay_buffer_kwargs={
            'n_sampled_goal': 4,
            'goal_selection_strategy': 'future',
        },
        verbose=1,
        learning_starts=BATCH_SIZE,
        device='auto',
        **SAC_HPARAMS,
    )

# Periodic snapshots. save_freq counts callback calls, which equals env steps
# for this single (non-vectorised) environment.
checkpoint_cb = CheckpointCallback(
    save_freq=10_000,
    save_path=args.out + '_checkpoints/',
    name_prefix='sac',
)

# Record which observation params the policy was trained on. This goes next to
# the checkpoints, into the output's parent directory, and (after training)
# into <out>.json.
# NOTE(review): the parent-directory copy is a generic 'config.json' (e.g.
# /tmp/config.json by default) and can be overwritten by other runs. Confirm
# whether anything still reads it.
config = {'obs_params': OBS_PARAMS}
for save_dir in [args.out + '_checkpoints/', os.path.dirname(args.out) or '.']:
    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, 'config.json'), 'w', encoding='utf-8') as f:
        json.dump(config, f)

try:
    model.learn(total_timesteps=args.steps, callback=checkpoint_cb)
except KeyboardInterrupt:
    # Keep the partially trained policy instead of losing everything since the
    # last checkpoint.
    print("Interrupted; saving current model")

model.save(args.out)
with open(args.out + '.json', 'w', encoding='utf-8') as f:
    json.dump(config, f)
print(f"Saved to {args.out}.zip")