rl.py: - Add missing `from enum import Enum` - Skip str-typed params in obs/action space construction (was crashing) - Guard action space: exclude write-only (is_readable=False) and cheat params - Fix step() param lookup (no longer iterates Nucon, uses _parameters dict directly) - Correct sim-speed time dilation in real-game sleep - Extract _build_param_space() helper shared by NuconEnv and NuconGoalEnv - Add NuconGoalEnv: goal-conditioned env with normalised achieved/desired goal vectors, compatible with SB3 HerReplayBuffer; goals sampled per episode - Register Nucon-goal_power-v0 and Nucon-goal_temp-v0 presets - Enum obs/action space now scalar index (not one-hot) sim.py: - Store self.port and self.host on NuconSimulator - Add set_model() to accept a pre-loaded model directly - load_model() detects type by extension (.pkl → kNN, else → NN torch) and reads new checkpoint format with embedded input/output param lists - _update_reactor_state() uses model.input_params (not all readable params), calls .forward() directly for both NN and kNN, guards torch.no_grad per type - Import ReactorKNNModel and pickle model.py: - save_model() embeds input_params/output_params in NN checkpoint dict - load_model() handles new checkpoint format (state_dict key) with fallback README.md: - Update note: RODS_POS_ORDERED is no longer the only writable param; game v2.2.25.213 exposes rod banks, pumps, MSCVs, switches and more Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
404 lines
16 KiB
Python
404 lines
16 KiB
Python
import gymnasium as gym
|
|
from gymnasium import spaces
|
|
import numpy as np
|
|
import time
|
|
from typing import Dict, Any
|
|
from enum import Enum
|
|
from nucon import Nucon, BreakerStatus, PumpStatus, PumpDryStatus, PumpOverloadStatus
|
|
|
|
# Named reward/terminator functions: each maps an observation dict to a scalar.
# Referenced by name in NuconEnv(objectives=[...]) and terminators=[...].
# NOTE: kept as lambdas deliberately — values are stored and called as-is.
Objectives = {
    # Constant zero — placeholder objective/terminator.
    "null": lambda obs: 0,
    # Total electrical output summed over the three generators (kW).
    "max_power": lambda obs: obs["GENERATOR_0_KW"] + obs["GENERATOR_1_KW"] + obs["GENERATOR_2_KW"],
    # Elapsed in-episode game time (seconds) — rewards episode longevity.
    "episode_time": lambda obs: obs["EPISODE_TIME"],
}
|
# Objective factories: call the outer lambda with a parameter to obtain an
# `obs -> reward` callable suitable for NuconEnv's objectives/terminators.
Parameterized_Objectives = {
    # Negative squared error from a target core temperature.
    "target_temperature": lambda goal_temp: lambda obs: -((obs["CORE_TEMP"] - goal_temp) ** 2),
    # Negative squared error of the margin above CORE_TEMP_MIN versus goal_gap.
    "target_gap": lambda goal_gap: lambda obs: -((obs["CORE_TEMP"] - obs["CORE_TEMP_MIN"] - goal_gap) ** 2),
    # Quadratic penalty only when CORE_TEMP exceeds max_temp (0 otherwise).
    "temp_below": lambda max_temp: lambda obs: -(np.clip(obs["CORE_TEMP"] - max_temp, 0, np.inf) ** 2),
    # Quadratic penalty only when CORE_TEMP falls below min_temp (0 otherwise).
    "temp_above": lambda min_temp: lambda obs: -(np.clip(min_temp - obs["CORE_TEMP"], 0, np.inf) ** 2),
    # Constant reward regardless of observation.
    "constant": lambda constant: lambda obs: constant,
}
|
class NuconEnv(gym.Env):
    """Gymnasium environment driving a Nucon-controlled reactor (live game or simulator).

    Observation: a Dict of every readable, non-str parameter (Enum readings are
    encoded as their scalar ``.value``) plus ``EPISODE_TIME`` in game-seconds.
    Action: a Dict of every writable parameter that is also readable back and
    not flagged as a cheat (write-only pulse params are excluded).

    Reward is the weighted sum of the configured objectives; the episode
    terminates once the summed terminator scores exceed ``terminate_above``.
    """

    metadata = {'render_modes': ['human']}

    def __init__(self, nucon=None, simulator=None, render_mode=None,
                 seconds_per_step=5, objectives=['null'], terminators=['null'],
                 objective_weights=None, terminate_above=0):
        """
        Args:
            nucon: existing Nucon client; built automatically when None
                (pointed at ``simulator.port`` if a simulator is supplied).
            simulator: optional simulator; when set, time advances via
                ``simulator.update`` instead of real-time sleeping.
            render_mode: 'human' or None (rendering is currently a no-op).
            seconds_per_step: game-seconds advanced by each env step.
            objectives: names from ``Objectives`` and/or callables mapping an
                observation dict to a scalar reward term.
            terminators: same format; episode ends when their summed value
                exceeds ``terminate_above``.
            objective_weights: per-objective weights (defaults to all 1.0).
            terminate_above: termination threshold for the terminator sum.
        """
        # NOTE: the list defaults above are read-only; they are never mutated,
        # and they are kept to preserve the original public signature.
        super().__init__()

        self.render_mode = render_mode
        self.seconds_per_step = seconds_per_step
        if objective_weights is None:
            objective_weights = [1.0] * len(objectives)
        self.objective_weights = objective_weights
        self.terminate_above = terminate_above
        self.simulator = simulator

        if nucon is None:
            # Point the client at the simulator's port when one is given.
            nucon = Nucon(port=simulator.port) if simulator else Nucon()
        self.nucon = nucon

        # Observation space: every readable param with a supported type.
        obs_spaces = {'EPISODE_TIME': spaces.Box(low=0, high=np.inf, shape=(1,), dtype=np.float32)}
        for param_id, param in self.nucon.get_all_readable().items():
            sp = _build_param_space(param)
            if sp is not None:
                obs_spaces[param_id] = sp
        self.observation_space = spaces.Dict(obs_spaces)

        # Action space: only controllable, non-cheat, readable-back params.
        action_spaces = {}
        for param_id, param in self.nucon.get_all_writable().items():
            if not param.is_readable or param.is_cheat:
                continue  # write-only (VALVE_OPEN/CLOSE, SCRAM, etc.) and cheat params excluded
            sp = _build_param_space(param)
            if sp is not None:
                action_spaces[param_id] = sp
        self.action_space = spaces.Dict(action_spaces)

        self.objectives = self._resolve_callables(objectives, 'objective')
        self.terminators = self._resolve_callables(terminators, 'terminator')

    @staticmethod
    def _resolve_callables(specs, kind):
        """Map ``Objectives`` names / raw callables to a list of callables."""
        resolved = []
        for spec in specs:
            if spec in Objectives:
                resolved.append(Objectives[spec])
            elif callable(spec):
                resolved.append(spec)
            else:
                raise ValueError(f"Unsupported {kind}: {spec}")
        return resolved

    def _get_obs(self):
        """Read every observed parameter once and return the obs dict."""
        obs = {}
        for param_id, param in self.nucon.get_all_readable().items():
            # Skip str params and anything without a declared space.
            if param.param_type == str or param_id not in self.observation_space.spaces:
                continue
            value = self.nucon.get(param_id)
            if isinstance(value, Enum):
                value = value.value  # scalar index encoding, matches the space
            obs[param_id] = value
        obs["EPISODE_TIME"] = self._total_steps * self.seconds_per_step
        return obs

    def _get_info(self, obs=None):
        """Per-objective raw and weighted scores.

        Keys are '<index>:<name>' rather than bare ``__name__``: most
        objectives are lambdas whose ``__name__`` is '<lambda>', so bare-name
        keys collided and silently dropped reward terms from the weighted sum.

        Args:
            obs: optional pre-fetched observation; when None it is read once
                here (the old code re-read the reactor once per objective).
        """
        if obs is None:
            obs = self._get_obs()
        info = {'objectives': {}, 'objectives_weighted': {}}
        for i, (objective, weight) in enumerate(zip(self.objectives, self.objective_weights)):
            key = f"{i}:{getattr(objective, '__name__', 'objective')}"
            score = objective(obs)
            info['objectives'][key] = score
            info['objectives_weighted'][key] = score * weight
        return info

    def reset(self, seed=None, options=None):
        """Restart the episode clock and return the first observation.

        NOTE(review): the reactor/simulator itself is NOT reset here — the
        underlying process keeps running; only episode bookkeeping restarts.
        """
        super().reset(seed=seed)
        self._total_steps = 0
        observation = self._get_obs()
        return observation, self._get_info(observation)

    def step(self, action):
        """Write the action, advance time, and return the gymnasium 5-tuple."""
        # Apply the action: coerce each value to the param's native type
        # (Enum index -> member), clip into its declared range, then write.
        for param_id, value in action.items():
            param = self.nucon._parameters[param_id]
            if issubclass(param.param_type, Enum):
                value = param.param_type(int(np.asarray(value).flat[0]))
            else:
                value = param.param_type(np.asarray(value).flat[0])
            if param.min_val is not None and param.max_val is not None:
                value = np.clip(value, param.min_val, param.max_val)
            self.nucon.set(param, value)

        observation = self._get_obs()
        terminated = bool(np.sum([t(observation) for t in self.terminators]) > self.terminate_above)
        truncated = False
        info = self._get_info(observation)
        reward = sum(info['objectives_weighted'].values())

        self._total_steps += 1
        if self.simulator:
            self.simulator.update(self.seconds_per_step)
        else:
            # Sleep long enough real time for the game to advance
            # seconds_per_step game-seconds, compensating for the game's
            # simulation-speed multiplier (0/None guarded to 1.0).
            sim_speed = self.nucon.GAME_SIM_SPEED.value or 1.0
            time.sleep(self.seconds_per_step / sim_speed)
        return observation, reward, terminated, truncated, info

    def render(self):
        """Rendering is a no-op; observation dicts are already inspectable."""
        if self.render_mode == "human":
            pass

    def close(self):
        """Nothing to release."""
        pass

    def _flatten_action(self, action):
        """Concatenate a Dict action into a single 1-D vector."""
        return np.concatenate([np.asarray(v).flatten() for v in action.values()])

    def _unflatten_action(self, flat_action):
        """Inverse of ``_flatten_action``: slice the vector back into a Dict.

        (The previous implementation ignored ``flat_action`` entirely and
        called ``.reshape`` on the Space objects, raising AttributeError.)
        """
        return self._unflatten(flat_action, self.action_space)

    def _flatten_observation(self, observation):
        """Concatenate a Dict observation into a single 1-D vector."""
        return np.concatenate([np.asarray(v).flatten() for v in observation.values()])

    def _unflatten_observation(self, flat_observation):
        """Inverse of ``_flatten_observation``."""
        return self._unflatten(flat_observation, self.observation_space)

    @staticmethod
    def _unflatten(flat, space):
        """Split a flat vector into a dict keyed like ``space`` (insertion order)."""
        flat = np.asarray(flat).flatten()
        out, offset = {}, 0
        for key, sub in space.spaces.items():
            size = int(np.prod(sub.shape))
            out[key] = flat[offset:offset + size].reshape(sub.shape)
            offset += size
        return out
|
|
|
def _build_param_space(param):
    """Return a gymnasium Box for a single Nucon parameter, or None if unsupported.

    Encodings (all float32, shape (1,)):
      float/int -> Box bounded by min_val/max_val; a missing bound widens to +/-inf
      bool      -> Box [0, 1]
      Enum      -> Box [0, len(enum_type) - 1]  (scalar member index, not one-hot)
      str       -> None (unsupported)
    """
    if param.param_type in (float, int):
        # `x if x is not None else ...`, NOT `x or ...`: a legitimate bound of
        # 0 is falsy, and the old `or` form silently widened it to -inf.
        lo = param.min_val if param.min_val is not None else -np.inf
        hi = param.max_val if param.max_val is not None else np.inf
        return spaces.Box(low=lo, high=hi, shape=(1,), dtype=np.float32)
    if param.param_type == bool:
        return spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32)
    if param.param_type == str:
        return None
    if issubclass(param.param_type, Enum):
        return spaces.Box(low=0, high=len(param.param_type) - 1, shape=(1,), dtype=np.float32)
    return None
|
|
|
class NuconGoalEnv(gym.Env):
    """
    Goal-conditioned reactor environment compatible with SB3 HER (Hindsight Experience Replay).

    The observation is a Dict with three keys as required by GoalEnv / HER:
    - 'observation': all readable non-goal, non-str params (same encoding as NuconEnv)
    - 'achieved_goal': current values of goal_params, normalised to [0, 1] within goal_range
    - 'desired_goal': target values sampled each episode, normalised to [0, 1]

    Reward defaults to negative L2 distance in the normalised goal space (dense).
    Pass ``tolerance`` for a sparse {0, -1} reward (0 = within tolerance).

    Usage with SB3 HER::

        from stable_baselines3 import SAC
        from stable_baselines3.common.buffers import HerReplayBuffer

        env = NuconGoalEnv(
            goal_params=['GENERATOR_0_KW', 'GENERATOR_1_KW', 'GENERATOR_2_KW'],
            goal_range={'GENERATOR_0_KW': (0, 1200), 'GENERATOR_1_KW': (0, 1200), 'GENERATOR_2_KW': (0, 1200)},
            simulator=simulator,
        )
        model = SAC('MultiInputPolicy', env, replay_buffer_class=HerReplayBuffer)
        model.learn(total_timesteps=200_000)
    """

    metadata = {'render_modes': ['human']}

    def __init__(
        self,
        goal_params,
        goal_range=None,
        reward_fn=None,
        tolerance=None,
        nucon=None,
        simulator=None,
        render_mode=None,
        seconds_per_step=5,
        terminators=None,
        terminate_above=0,
    ):
        """
        Args:
            goal_params: readable param ids whose readings form the goal vector.
            goal_range: optional {param_id: (low, high)} normalisation ranges;
                falls back to each param's declared min/max, then to (0, 1).
            reward_fn: optional callable(achieved_norm, desired_norm) -> float
                overriding the built-in dense/sparse rewards.
            tolerance: when set, reward is sparse: 0 within tolerance, else -1.
            nucon / simulator / render_mode / seconds_per_step /
            terminators / terminate_above: as in NuconEnv.
        """
        super().__init__()

        self.render_mode = render_mode
        self.seconds_per_step = seconds_per_step
        self.terminate_above = terminate_above
        self.simulator = simulator
        self.goal_params = list(goal_params)
        self.tolerance = tolerance

        if nucon is None:
            nucon = Nucon(port=simulator.port) if simulator else Nucon()
        self.nucon = nucon

        all_readable = self.nucon.get_all_readable()

        # Validate goal params before building anything else.
        for pid in self.goal_params:
            if pid not in all_readable:
                raise ValueError(f"Goal param '{pid}' is not a readable parameter")

        goal_range = goal_range or {}

        def bounds(pid):
            # Explicit range wins; otherwise the param's declared limits,
            # defaulting to (0, 1) when a limit is missing/zero-falsy.
            return goal_range.get(pid, (all_readable[pid].min_val or 0.0,
                                        all_readable[pid].max_val or 1.0))

        self._goal_low = np.array([bounds(pid)[0] for pid in self.goal_params], dtype=np.float32)
        self._goal_high = np.array([bounds(pid)[1] for pid in self.goal_params], dtype=np.float32)
        self._goal_range = self._goal_high - self._goal_low
        self._goal_range[self._goal_range == 0] = 1.0  # avoid div-by-zero

        self._reward_fn = reward_fn  # callable(achieved_norm, desired_norm) -> float, or None

        # Observation subspace: all readable non-str non-goal params.
        goal_set = set(self.goal_params)
        obs_spaces = {'EPISODE_TIME': spaces.Box(low=0, high=np.inf, shape=(1,), dtype=np.float32)}
        for param_id, param in all_readable.items():
            if param_id in goal_set:
                continue
            sp = _build_param_space(param)
            if sp is not None:
                obs_spaces[param_id] = sp

        n_goals = len(self.goal_params)
        self.observation_space = spaces.Dict({
            'observation': spaces.Dict(obs_spaces),
            'achieved_goal': spaces.Box(low=0.0, high=1.0, shape=(n_goals,), dtype=np.float32),
            'desired_goal': spaces.Box(low=0.0, high=1.0, shape=(n_goals,), dtype=np.float32),
        })

        # Action space: readable-back, non-cheat writable params.
        action_spaces = {}
        for param_id, param in self.nucon.get_all_writable().items():
            if not param.is_readable or param.is_cheat:
                continue
            sp = _build_param_space(param)
            if sp is not None:
                action_spaces[param_id] = sp
        self.action_space = spaces.Dict(action_spaces)

        self._terminators = terminators or []
        self._desired_goal = np.zeros(n_goals, dtype=np.float32)
        self._total_steps = 0

    # ------------------------------------------------------------------
    # GoalEnv interface
    # ------------------------------------------------------------------

    def compute_reward(self, achieved_goal, desired_goal, info):
        """
        Dense: negative L2 in normalised goal space (each dim in [0,1]).
        Sparse when tolerance is set: 0 if within tolerance, -1 otherwise.
        Custom reward_fn overrides both. Works on batched inputs (axis=-1),
        as HER relabelling requires.
        """
        if self._reward_fn is not None:
            return self._reward_fn(achieved_goal, desired_goal)
        dist = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
        if self.tolerance is not None:
            return (dist <= self.tolerance).astype(np.float32) - 1.0
        return -dist

    def _read_goal_values(self):
        """Current goal-param readings, normalised and clipped to [0, 1]."""
        values = []
        for pid in self.goal_params:
            v = self.nucon.get(pid)
            # Explicit None check: the old `v or 0.0` also clobbered any
            # legitimately falsy reading (e.g. 0 from an int param).
            values.append(0.0 if v is None else v)
        raw = np.array(values, dtype=np.float32)
        return np.clip((raw - self._goal_low) / self._goal_range, 0.0, 1.0)

    def _get_obs_dict(self):
        """Assemble the three-key HER observation."""
        obs = {'EPISODE_TIME': float(self._total_steps * self.seconds_per_step)}
        goal_set = set(self.goal_params)
        for param_id, param in self.nucon.get_all_readable().items():
            if param_id in goal_set or param_id not in self.observation_space['observation'].spaces:
                continue
            value = self.nucon.get(param_id)
            if isinstance(value, Enum):
                value = value.value
            obs[param_id] = value
        return {
            'observation': obs,
            'achieved_goal': self._read_goal_values(),
            'desired_goal': self._desired_goal.copy(),
        }

    def reset(self, seed=None, options=None):
        """Restart the episode clock and sample a fresh desired goal.

        Goal sampling uses ``self.np_random`` (seeded by ``super().reset``)
        per gymnasium's seeding contract; the old code built a separate
        ``np.random.default_rng(seed)``, which was non-reproducible whenever
        ``seed`` was None regardless of any earlier seeding.
        """
        super().reset(seed=seed)
        self._total_steps = 0
        # Goals live in normalised [0, 1] space, one value per goal param.
        self._desired_goal = self.np_random.uniform(
            0.0, 1.0, size=len(self.goal_params)
        ).astype(np.float32)
        return self._get_obs_dict(), {}

    def step(self, action):
        """Write the action, advance time, and return the gymnasium 5-tuple."""
        # Coerce each action value to the param's native type, clip into its
        # declared range, then write it to the reactor.
        for param_id, value in action.items():
            param = self.nucon._parameters[param_id]
            if issubclass(param.param_type, Enum):
                value = param.param_type(int(np.asarray(value).flat[0]))
            else:
                value = param.param_type(np.asarray(value).flat[0])
            if param.min_val is not None and param.max_val is not None:
                value = np.clip(value, param.min_val, param.max_val)
            self.nucon.set(param, value)

        obs = self._get_obs_dict()
        reward = float(self.compute_reward(obs['achieved_goal'], obs['desired_goal'], {}))
        terminated = any(t(obs['observation']) > self.terminate_above for t in self._terminators)
        truncated = False
        info = {'achieved_goal': obs['achieved_goal'], 'desired_goal': obs['desired_goal']}

        self._total_steps += 1
        if self.simulator:
            self.simulator.update(self.seconds_per_step)
        else:
            # Real game: scale the wall-clock sleep by the sim-speed
            # multiplier so seconds_per_step game-seconds actually elapse.
            sim_speed = self.nucon.GAME_SIM_SPEED.value or 1.0
            time.sleep(self.seconds_per_step / sim_speed)

        return obs, reward, terminated, truncated, info

    def render(self):
        """Rendering is a no-op."""
        pass

    def close(self):
        """Nothing to release."""
        pass
|
|
|
def register_nucon_envs():
    """Register the preset Nucon environment ids with gymnasium."""
    # Plain single-agent presets on NuconEnv, keyed by env id.
    plain_presets = {
        'Nucon-max_power-v0': {
            'seconds_per_step': 5,
            'objectives': ['max_power'],
        },
        'Nucon-target_temperature_350-v0': {
            'seconds_per_step': 5,
            'objectives': [Parameterized_Objectives['target_temperature'](goal_temp=350)],
        },
        'Nucon-safe_max_power-v0': {
            'seconds_per_step': 5,
            'objectives': [
                Parameterized_Objectives['temp_above'](min_temp=310),
                Parameterized_Objectives['temp_below'](max_temp=365),
                'max_power',
            ],
            'objective_weights': [1, 10, 1 / 100_000],
        },
    }
    for env_id, env_kwargs in plain_presets.items():
        gym.register(id=env_id, entry_point='nucon.rl:NuconEnv', kwargs=env_kwargs)

    # Goal-conditioned presets on NuconGoalEnv (intended for HER training):
    # target total generator output, and target core temperature.
    goal_presets = {
        'Nucon-goal_power-v0': {
            'goal_params': ['GENERATOR_0_KW', 'GENERATOR_1_KW', 'GENERATOR_2_KW'],
            'goal_range': {'GENERATOR_0_KW': (0.0, 1200.0), 'GENERATOR_1_KW': (0.0, 1200.0), 'GENERATOR_2_KW': (0.0, 1200.0)},
            'seconds_per_step': 5,
        },
        'Nucon-goal_temp-v0': {
            'goal_params': ['CORE_TEMP'],
            'goal_range': {'CORE_TEMP': (280.0, 380.0)},
            'seconds_per_step': 5,
        },
    }
    for env_id, env_kwargs in goal_presets.items():
        gym.register(id=env_id, entry_point='nucon.rl:NuconGoalEnv', kwargs=env_kwargs)
|
|
register_nucon_envs()  # register presets at import time so gym.make('Nucon-...') works immediately