feat: SAC+HER training on kNN-GP sim with direct bypass and scripts/
- nucon/rl.py: delta_action_scale action space, bool handling (>=0.5), direct sim read/write bypassing HTTP for ~2000fps env throughput; remove uncertainty_abort from training (use penalty-only), larger default batch sizes; fix _read_obs and step for in-process sim - nucon/model.py: optimise _lookup with einsum squared-L2, vectorised rbf kernel; forward_with_uncertainty uses pre-built normalised arrays - nucon/sim.py: _update_reactor_state writes outputs via setattr directly - scripts/train_sac.py: moved from root; full SAC+HER example with kNN-GP sim, delta actions, uncertainty penalty, init_states - scripts/collect_dataset.py: CLI tool to collect dynamics dataset from live game session (--steps, --delta, --out, --merge) - README.md: add Scripts section, reference both scripts in training loop Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
3dfe1aa673
commit
0932bb353a
23
README.md
23
README.md
@ -404,11 +404,11 @@ The recommended end-to-end workflow for training an RL operator is an iterative
|
|||||||
└─────────────────────┘
|
└─────────────────────┘
|
||||||
```
|
```
|
||||||
|
|
||||||
**Step 1 — Human dataset collection**: Run `NuconModelLearner.collect_data()` during your play session. Cover a wide range of states: startup from cold, ramping power, individual rod bank adjustments. Diversity in the dataset directly determines simulator accuracy. See [Model Learning](#model-learning) for collection details.
|
**Step 1 — Human dataset collection**: Run `scripts/collect_dataset.py` during your play session (see [Scripts](#scripts)). Cover a wide range of states: startup from cold, ramping power, individual rod bank adjustments. Diversity in the dataset directly determines simulator accuracy. See [Model Learning](#model-learning) for collection details.
|
||||||
|
|
||||||
**Step 2 — Initial model fitting**: Fit a kNN-GP model (instant) or NN (better extrapolation with larger datasets) using `fit_knn()` or `train_model()`. Prune near-duplicate samples with `drop_redundant()` before fitting. See [Model Learning](#model-learning).
|
**Step 2 — Initial model fitting**: Fit a kNN-GP model (instant) or NN (better extrapolation with larger datasets) using `fit_knn()` or `train_model()`. Prune near-duplicate samples with `drop_redundant()` before fitting. See [Model Learning](#model-learning).
|
||||||
|
|
||||||
**Step 3 — Train RL in simulator**: Load the fitted model into `NuconSimulator`, then train a `NuconGoalEnv` policy with SAC + HER. The simulator runs far faster than the real game, allowing many trajectories in reasonable time. Pass `Parameterized_Objectives['uncertainty_penalty']` and `Parameterized_Terminators['uncertainty_abort']` as additional objectives/terminators to discourage the policy from wandering into regions the model hasn't seen; `SIM_UNCERTAINTY` is automatically injected into the obs dict when a simulator is active. See [NuconGoalEnv + HER Usage](#nucongoalenv--her-usage).
|
**Step 3 — Train RL in simulator**: Load the fitted model into `NuconSimulator`, then train a `NuconGoalEnv` policy with SAC + HER. The simulator runs far faster than the real game, allowing many trajectories in reasonable time. Pass `Parameterized_Objectives['uncertainty_penalty']` and `Parameterized_Terminators['uncertainty_abort']` as additional objectives/terminators to discourage the policy from wandering into regions the model hasn't seen; `SIM_UNCERTAINTY` is automatically injected into the obs dict when a simulator is active. See [NuconGoalEnv + HER Usage](#nucongoalenv--her-usage) and `scripts/train_sac.py` for a complete example.
|
||||||
|
|
||||||
**Step 4 — Eval in game + collect new data**: Run the trained policy against the real game. This validates simulator accuracy and simultaneously collects new data from states the policy visits, which may be regions the original dataset missed. Run a second `NuconModelLearner` in a background thread to collect concurrently.
|
**Step 4 — Eval in game + collect new data**: Run the trained policy against the real game. This validates simulator accuracy and simultaneously collects new data from states the policy visits, which may be regions the original dataset missed. Run a second `NuconModelLearner` in a background thread to collect concurrently.
|
||||||
|
|
||||||
@ -416,6 +416,25 @@ The recommended end-to-end workflow for training an RL operator is an iterative
|
|||||||
|
|
||||||
Stop when the policy performs well in the real game and kNN-GP uncertainty stays low throughout an episode, indicating the policy stays within the known data distribution.
|
Stop when the policy performs well in the real game and kNN-GP uncertainty stays low throughout an episode, indicating the policy stays within the known data distribution.
|
||||||
|
|
||||||
|
## Scripts
|
||||||
|
|
||||||
|
Ready-to-run scripts in the `scripts/` directory covering the most common workflows.
|
||||||
|
|
||||||
|
**`scripts/collect_dataset.py`** — collect a dynamics dataset while playing the game:
|
||||||
|
```bash
|
||||||
|
python scripts/collect_dataset.py --steps 1000 --delta 10 --out reactor_dataset.pkl
|
||||||
|
# Ctrl-C to stop early; data is saved on exit
|
||||||
|
# Merge a previous session: --merge previous.pkl
|
||||||
|
```
|
||||||
|
|
||||||
|
**`scripts/train_sac.py`** — train a SAC + HER goal-conditioned policy on the kNN-GP simulator:
|
||||||
|
```bash
|
||||||
|
python scripts/train_sac.py
|
||||||
|
# Expects /tmp/reactor_knn.pkl and /tmp/nucon_dataset.pkl
|
||||||
|
# Saves trained policy to /tmp/sac_nucon_knn.zip
|
||||||
|
```
|
||||||
|
This script is the most elaborate end-to-end example: it loads a pre-fitted kNN-GP model, seeds episode resets from dataset states, uses delta actions and an uncertainty penalty, and configures SAC + HER for fast sim training.
|
||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
|
|
||||||
NuCon includes a test suite to verify its functionality and compatibility with the Nucleares game.
|
NuCon includes a test suite to verify its functionality and compatibility with the Nucleares game.
|
||||||
|
|||||||
@ -98,11 +98,11 @@ class ReactorKNNModel:
|
|||||||
self._std = self._raw_states.std(axis=0) + 1e-8
|
self._std = self._raw_states.std(axis=0) + 1e-8
|
||||||
self._states = (self._raw_states - self._mean) / self._std
|
self._states = (self._raw_states - self._mean) / self._std
|
||||||
|
|
||||||
def _lookup(self, state_dict: Dict):
|
def _lookup(self, s: np.ndarray):
|
||||||
"""Return (s_norm, idx, k) for the k nearest neighbours."""
|
"""Return (s_norm, idx, k) for the k nearest neighbours. s is a raw (d_in,) array."""
|
||||||
s = np.array([state_dict[p] for p in self.input_params], dtype=np.float32)
|
|
||||||
s_norm = (s - self._mean) / self._std
|
s_norm = (s - self._mean) / self._std
|
||||||
dists = np.linalg.norm(self._states - s_norm, axis=1)
|
diff = self._states - s_norm # (n, d_in) broadcast
|
||||||
|
dists = np.einsum('ij,ij->i', diff, diff) # squared L2, faster than linalg.norm
|
||||||
k = min(self.k, len(dists))
|
k = min(self.k, len(dists))
|
||||||
idx = np.argpartition(dists, k - 1)[:k]
|
idx = np.argpartition(dists, k - 1)[:k]
|
||||||
return s_norm, idx, k
|
return s_norm, idx, k
|
||||||
@ -122,22 +122,22 @@ class ReactorKNNModel:
|
|||||||
if self._states is None:
|
if self._states is None:
|
||||||
raise ValueError("Model not fitted. Call fit(dataset) first.")
|
raise ValueError("Model not fitted. Call fit(dataset) first.")
|
||||||
|
|
||||||
s_norm, idx, k = self._lookup(state_dict)
|
s = np.array([state_dict[p] for p in self.input_params], dtype=np.float32)
|
||||||
|
s_norm, idx, k = self._lookup(s)
|
||||||
X = self._states[idx] # (k, d_in)
|
X = self._states[idx] # (k, d_in)
|
||||||
Y = self._rates[idx] # (k, d_out)
|
Y = self._rates[idx] # (k, d_out)
|
||||||
|
|
||||||
# RBF kernel (vectorised): k(a,b) = exp(-0.5 ||a-b||^2)
|
# RBF kernel: k(a,b) = exp(-0.5 ||a-b||^2)
|
||||||
def rbf_matrix(A, B):
|
def rbf(A, B):
|
||||||
diff = A[:, None, :] - B[None, :, :] # (|A|, |B|, d)
|
diff = A[:, None, :] - B[None, :, :]
|
||||||
return np.exp(-0.5 * (diff ** 2).sum(axis=-1)) # (|A|, |B|)
|
return np.exp(-0.5 * np.einsum('ijk,ijk->ij', diff, diff))
|
||||||
|
|
||||||
K = rbf_matrix(X, X) + 1e-4 * np.eye(k) # (k, k)
|
K = rbf(X, X) + 1e-4 * np.eye(k)
|
||||||
k_star = rbf_matrix(s_norm[None, :], X)[0] # (k,)
|
k_star = rbf(s_norm[None, :], X)[0]
|
||||||
|
|
||||||
K_inv = np.linalg.inv(K)
|
K_inv = np.linalg.inv(K)
|
||||||
mean_rates = k_star @ K_inv @ Y # (d_out,)
|
mean_rates = k_star @ K_inv @ Y
|
||||||
|
|
||||||
# Posterior variance (scalar, shared across all output dims)
|
|
||||||
var = max(0.0, 1.0 - float(k_star @ K_inv @ k_star))
|
var = max(0.0, 1.0 - float(k_star @ K_inv @ k_star))
|
||||||
std = float(np.sqrt(var))
|
std = float(np.sqrt(var))
|
||||||
|
|
||||||
|
|||||||
110
nucon/rl.py
110
nucon/rl.py
@ -49,13 +49,15 @@ Parameterized_Terminators = {
|
|||||||
# Internal helpers
|
# Internal helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def _build_flat_action_space(nucon, obs_param_set=None):
|
def _build_flat_action_space(nucon, obs_param_set=None, delta_action_scale=None):
|
||||||
"""Return (Box, ordered_param_ids) for all writable, readable, non-cheat params.
|
"""Return (Box, ordered_param_ids, param_ranges).
|
||||||
|
|
||||||
If obs_param_set is provided, only include params in that set.
|
If delta_action_scale is set, the action space is [-1, 1]^n and actions are
|
||||||
|
treated as normalised deltas: actual_delta = action * delta_action_scale * (max - min).
|
||||||
|
Otherwise the action space spans [min_val, max_val] per param (absolute values).
|
||||||
"""
|
"""
|
||||||
params = []
|
params = []
|
||||||
lows, highs = [], []
|
lows, highs, ranges = [], [], []
|
||||||
for param_id, param in nucon.get_all_writable().items():
|
for param_id, param in nucon.get_all_writable().items():
|
||||||
if not param.is_readable or param.is_cheat:
|
if not param.is_readable or param.is_cheat:
|
||||||
continue
|
continue
|
||||||
@ -69,9 +71,15 @@ def _build_flat_action_space(nucon, obs_param_set=None):
|
|||||||
params.append(param_id)
|
params.append(param_id)
|
||||||
lows.append(sp.low[0])
|
lows.append(sp.low[0])
|
||||||
highs.append(sp.high[0])
|
highs.append(sp.high[0])
|
||||||
|
ranges.append(sp.high[0] - sp.low[0])
|
||||||
|
if delta_action_scale is not None:
|
||||||
|
n = len(params)
|
||||||
|
box = spaces.Box(low=-np.ones(n, dtype=np.float32),
|
||||||
|
high=np.ones(n, dtype=np.float32), dtype=np.float32)
|
||||||
|
else:
|
||||||
box = spaces.Box(low=np.array(lows, dtype=np.float32),
|
box = spaces.Box(low=np.array(lows, dtype=np.float32),
|
||||||
high=np.array(highs, dtype=np.float32), dtype=np.float32)
|
high=np.array(highs, dtype=np.float32), dtype=np.float32)
|
||||||
return box, params
|
return box, params, np.array(lows, dtype=np.float32), np.array(ranges, dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
def _unflatten_action(flat_action, param_ids):
|
def _unflatten_action(flat_action, param_ids):
|
||||||
@ -96,12 +104,15 @@ def _build_param_space(param):
|
|||||||
def _apply_action(nucon, action):
|
def _apply_action(nucon, action):
|
||||||
for param_id, value in action.items():
|
for param_id, value in action.items():
|
||||||
param = nucon._parameters[param_id]
|
param = nucon._parameters[param_id]
|
||||||
if issubclass(param.param_type, Enum):
|
v = float(np.asarray(value).flat[0])
|
||||||
value = param.param_type(int(np.asarray(value).flat[0]))
|
if param.param_type == bool:
|
||||||
|
value = v >= 0.5 # [0,1] space: above midpoint → True
|
||||||
|
elif issubclass(param.param_type, Enum):
|
||||||
|
value = param.param_type(int(v))
|
||||||
else:
|
else:
|
||||||
value = param.param_type(np.asarray(value).flat[0])
|
value = param.param_type(v)
|
||||||
if param.min_val is not None and param.max_val is not None:
|
if param.min_val is not None and param.max_val is not None:
|
||||||
value = np.clip(value, param.min_val, param.max_val)
|
value = param.param_type(np.clip(value, param.min_val, param.max_val))
|
||||||
nucon.set(param, value)
|
nucon.set(param, value)
|
||||||
|
|
||||||
|
|
||||||
@ -138,7 +149,8 @@ class NuconEnv(gym.Env):
|
|||||||
obs_spaces[param_id] = sp
|
obs_spaces[param_id] = sp
|
||||||
self.observation_space = spaces.Dict(obs_spaces)
|
self.observation_space = spaces.Dict(obs_spaces)
|
||||||
|
|
||||||
self.action_space, self._action_params = _build_flat_action_space(self.nucon)
|
self.action_space, self._action_params, self._action_lows, self._action_ranges = \
|
||||||
|
_build_flat_action_space(self.nucon)
|
||||||
|
|
||||||
self.objectives = []
|
self.objectives = []
|
||||||
self.terminators = []
|
self.terminators = []
|
||||||
@ -272,11 +284,14 @@ class NuconGoalEnv(gym.Env):
|
|||||||
additional_objectives=None,
|
additional_objectives=None,
|
||||||
additional_objective_weights=None,
|
additional_objective_weights=None,
|
||||||
obs_params=None,
|
obs_params=None,
|
||||||
|
init_states=None,
|
||||||
|
delta_action_scale=None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.render_mode = render_mode
|
self.render_mode = render_mode
|
||||||
self.seconds_per_step = seconds_per_step
|
self.seconds_per_step = seconds_per_step
|
||||||
|
self._delta_action_scale = delta_action_scale
|
||||||
self.terminate_above = terminate_above
|
self.terminate_above = terminate_above
|
||||||
self.simulator = simulator
|
self.simulator = simulator
|
||||||
self.goal_params = list(goal_params)
|
self.goal_params = list(goal_params)
|
||||||
@ -339,12 +354,14 @@ class NuconGoalEnv(gym.Env):
|
|||||||
})
|
})
|
||||||
|
|
||||||
# Action space: writable params within the obs param set (flat Box for SB3 compatibility).
|
# Action space: writable params within the obs param set (flat Box for SB3 compatibility).
|
||||||
self.action_space, self._action_params = _build_flat_action_space(self.nucon, set(base_params))
|
self.action_space, self._action_params, self._action_lows, self._action_ranges = \
|
||||||
|
_build_flat_action_space(self.nucon, set(base_params), delta_action_scale)
|
||||||
|
|
||||||
self._terminators = terminators or []
|
self._terminators = terminators or []
|
||||||
_objs = additional_objectives or []
|
_objs = additional_objectives or []
|
||||||
self._objectives = [Objectives[o] if isinstance(o, str) else o for o in _objs]
|
self._objectives = [Objectives[o] if isinstance(o, str) else o for o in _objs]
|
||||||
self._objective_weights = additional_objective_weights or [1.0] * len(self._objectives)
|
self._objective_weights = additional_objective_weights or [1.0] * len(self._objectives)
|
||||||
|
self._init_states = init_states # list of state dicts to sample on reset
|
||||||
self._desired_goal = np.zeros(n_goals, dtype=np.float32)
|
self._desired_goal = np.zeros(n_goals, dtype=np.float32)
|
||||||
self._total_steps = 0
|
self._total_steps = 0
|
||||||
|
|
||||||
@ -367,20 +384,37 @@ class NuconGoalEnv(gym.Env):
|
|||||||
def _read_obs(self, sim_uncertainty=None):
|
def _read_obs(self, sim_uncertainty=None):
|
||||||
"""Return (gym_obs_dict, reward_obs_dict).
|
"""Return (gym_obs_dict, reward_obs_dict).
|
||||||
|
|
||||||
gym_obs_dict — flat Box observation for the policy (no SIM_UNCERTAINTY).
|
When a simulator is attached, reads directly from sim.parameters (no HTTP).
|
||||||
reward_obs_dict — same values plus SIM_UNCERTAINTY for objectives/terminators/reward_fn.
|
Otherwise falls back to a single batch HTTP request.
|
||||||
"""
|
"""
|
||||||
|
def _to_float(v):
|
||||||
|
if v is None:
|
||||||
|
return 0.0
|
||||||
|
return float(v.value if isinstance(v, Enum) else v)
|
||||||
|
|
||||||
|
if self.simulator is not None:
|
||||||
|
# Direct in-process read — no HTTP overhead
|
||||||
|
def _get(pid):
|
||||||
|
return _to_float(self.simulator.get(pid))
|
||||||
|
else:
|
||||||
|
raw = self.nucon._batch_query(self._obs_params + self.goal_params)
|
||||||
|
all_params = self.nucon.get_all_readable()
|
||||||
|
def _get(pid):
|
||||||
|
try:
|
||||||
|
v = self.nucon._parse_value(all_params[pid], raw.get(pid, '0'))
|
||||||
|
return _to_float(v)
|
||||||
|
except Exception:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
reward_obs = {}
|
reward_obs = {}
|
||||||
if self._obs_with_uncertainty:
|
if self._obs_with_uncertainty:
|
||||||
reward_obs['SIM_UNCERTAINTY'] = float(sim_uncertainty) if sim_uncertainty is not None else 0.0
|
reward_obs['SIM_UNCERTAINTY'] = float(sim_uncertainty) if sim_uncertainty is not None else 0.0
|
||||||
for param_id in self._obs_params:
|
for pid in self._obs_params:
|
||||||
value = self.nucon.get(param_id)
|
reward_obs[pid] = _get(pid)
|
||||||
if isinstance(value, Enum):
|
|
||||||
value = value.value
|
|
||||||
reward_obs[param_id] = float(value) if value is not None else 0.0
|
|
||||||
|
|
||||||
obs_vec = np.array([reward_obs[p] for p in self._obs_params], dtype=np.float32)
|
obs_vec = np.array([reward_obs[p] for p in self._obs_params], dtype=np.float32)
|
||||||
achieved = self._read_goal_values()
|
goal_raw = np.array([_get(p) for p in self.goal_params], dtype=np.float32)
|
||||||
|
achieved = np.clip((goal_raw - self._goal_low) / self._goal_range, 0.0, 1.0)
|
||||||
gym_obs = {'observation': obs_vec, 'achieved_goal': achieved,
|
gym_obs = {'observation': obs_vec, 'achieved_goal': achieved,
|
||||||
'desired_goal': self._desired_goal.copy()}
|
'desired_goal': self._desired_goal.copy()}
|
||||||
return gym_obs, reward_obs
|
return gym_obs, reward_obs
|
||||||
@ -390,11 +424,47 @@ class NuconGoalEnv(gym.Env):
|
|||||||
self._total_steps = 0
|
self._total_steps = 0
|
||||||
rng = np.random.default_rng(seed)
|
rng = np.random.default_rng(seed)
|
||||||
self._desired_goal = rng.uniform(0.0, 1.0, size=len(self.goal_params)).astype(np.float32)
|
self._desired_goal = rng.uniform(0.0, 1.0, size=len(self.goal_params)).astype(np.float32)
|
||||||
|
if self._init_states is not None and self.simulator is not None:
|
||||||
|
state = self._init_states[rng.integers(len(self._init_states))]
|
||||||
|
for k, v in state.items():
|
||||||
|
try:
|
||||||
|
self.simulator.set(k, v, force=True)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
gym_obs, _ = self._read_obs()
|
gym_obs, _ = self._read_obs()
|
||||||
return gym_obs, {}
|
return gym_obs, {}
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
_apply_action(self.nucon, _unflatten_action(action, self._action_params))
|
flat = np.asarray(action, dtype=np.float32)
|
||||||
|
if self._delta_action_scale is not None:
|
||||||
|
# Compute absolute values from deltas, reading current state directly if possible
|
||||||
|
absolute = {}
|
||||||
|
for i, pid in enumerate(self._action_params):
|
||||||
|
param = self.nucon._parameters[pid]
|
||||||
|
if param.param_type == bool:
|
||||||
|
absolute[pid] = 1.0 if flat[i] > 0 else 0.0
|
||||||
|
else:
|
||||||
|
if self.simulator is not None:
|
||||||
|
v = self.simulator.get(pid)
|
||||||
|
current = float(v.value if isinstance(v, Enum) else v) if v is not None else 0.0
|
||||||
|
else:
|
||||||
|
current = 0.0 # fallback; batch read not worth it for actions alone
|
||||||
|
delta = float(flat[i]) * self._delta_action_scale * self._action_ranges[i]
|
||||||
|
absolute[pid] = float(np.clip(current + delta,
|
||||||
|
self._action_lows[i],
|
||||||
|
self._action_lows[i] + self._action_ranges[i]))
|
||||||
|
else:
|
||||||
|
absolute = _unflatten_action(flat, self._action_params)
|
||||||
|
|
||||||
|
if self.simulator is not None:
|
||||||
|
# Write directly to sim — skip HTTP entirely
|
||||||
|
for pid, val in absolute.items():
|
||||||
|
try:
|
||||||
|
self.simulator.set(pid, val, force=True)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
_apply_action(self.nucon, absolute)
|
||||||
|
|
||||||
if self.simulator:
|
if self.simulator:
|
||||||
uncertainty = self.simulator.update(self.seconds_per_step, return_uncertainty=True)
|
uncertainty = self.simulator.update(self.seconds_per_step, return_uncertainty=True)
|
||||||
|
|||||||
14
nucon/sim.py
14
nucon/sim.py
@ -261,14 +261,13 @@ class NuconSimulator:
|
|||||||
raise ValueError("Model not set. Please load a model using load_model() or set_model().")
|
raise ValueError("Model not set. Please load a model using load_model() or set_model().")
|
||||||
|
|
||||||
# Build state dict using only the params the model knows about
|
# Build state dict using only the params the model knows about
|
||||||
|
params = self.parameters
|
||||||
state = {}
|
state = {}
|
||||||
for param_id in self.model.input_params:
|
for param_id in self.model.input_params:
|
||||||
value = getattr(self.parameters, param_id, None)
|
value = getattr(params, param_id, None)
|
||||||
if isinstance(value, Enum):
|
if isinstance(value, Enum):
|
||||||
value = value.value
|
value = value.value
|
||||||
if value is None:
|
state[param_id] = 0.0 if value is None else value
|
||||||
value = 0.0 # fallback for params not initialised in sim state
|
|
||||||
state[param_id] = value
|
|
||||||
|
|
||||||
# Forward pass
|
# Forward pass
|
||||||
uncertainty = None
|
uncertainty = None
|
||||||
@ -280,12 +279,9 @@ class NuconSimulator:
|
|||||||
else:
|
else:
|
||||||
next_state = self.model.forward(state, time_step)
|
next_state = self.model.forward(state, time_step)
|
||||||
|
|
||||||
# Update only the output params the model predicts
|
# Write outputs directly — bypass sim.set() type-checking overhead
|
||||||
for param_id, value in next_state.items():
|
for param_id, value in next_state.items():
|
||||||
try:
|
setattr(params, param_id, value)
|
||||||
self.set(param_id, value, force=True)
|
|
||||||
except (ValueError, KeyError):
|
|
||||||
pass # ignore params that can't be set (type mismatch, unknown)
|
|
||||||
|
|
||||||
return uncertainty
|
return uncertainty
|
||||||
|
|
||||||
|
|||||||
55
scripts/collect_dataset.py
Normal file
55
scripts/collect_dataset.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
"""Collect a dynamics dataset from the running Nucleares game.
|
||||||
|
|
||||||
|
Play the game normally while this script runs in the background.
|
||||||
|
It records state transitions every `time_delta` game-seconds and
|
||||||
|
saves them incrementally so nothing is lost if you quit early.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python scripts/collect_dataset.py # default settings
|
||||||
|
python scripts/collect_dataset.py --steps 2000 --delta 5 # faster sampling
|
||||||
|
python scripts/collect_dataset.py --out my_dataset.pkl
|
||||||
|
|
||||||
|
The saved dataset is a list of (state_before, action_dict, state_after, time_delta)
|
||||||
|
tuples compatible with NuconModelLearner.fit_knn() and train_model().
|
||||||
|
|
||||||
|
Tips for good data:
|
||||||
|
- Cover a range of operating states: startup, ramp, steady-state, shutdown.
|
||||||
|
- Vary individual rod bank positions, pump speeds, and MSCV setpoints.
|
||||||
|
- Collect at least 500 samples for kNN-GP; 5000+ for the NN backend.
|
||||||
|
- Merge multiple sessions with NuconModelLearner.merge_datasets().
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import pickle
|
||||||
|
from nucon.model import NuconModelLearner
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--steps', type=int, default=1000,
|
||||||
|
help='Number of samples to collect (default: 1000)')
|
||||||
|
parser.add_argument('--delta', type=float, default=10.0,
|
||||||
|
help='Game-seconds between samples (default: 10.0)')
|
||||||
|
parser.add_argument('--out', default='reactor_dataset.pkl',
|
||||||
|
help='Output path for dataset (default: reactor_dataset.pkl)')
|
||||||
|
parser.add_argument('--merge', default=None,
|
||||||
|
help='Existing dataset to merge into before saving')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
learner = NuconModelLearner(
|
||||||
|
time_delta=args.delta,
|
||||||
|
dataset_path=args.out,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.merge:
|
||||||
|
learner.merge_datasets(args.merge)
|
||||||
|
print(f"Merged existing dataset from {args.merge} ({len(learner.dataset)} samples)")
|
||||||
|
|
||||||
|
print(f"Collecting {args.steps} samples (Δt={args.delta}s each) → {args.out}")
|
||||||
|
print("Play the game — vary rod positions, pump speeds, and operating states.")
|
||||||
|
print("Press Ctrl-C to stop early; data collected so far will be saved.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
learner.collect_data(num_steps=args.steps)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nInterrupted — saving collected data...")
|
||||||
|
|
||||||
|
learner.save_dataset(args.out)
|
||||||
|
print(f"Saved {len(learner.dataset)} samples to {args.out}")
|
||||||
83
scripts/train_sac.py
Normal file
83
scripts/train_sac.py
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
"""SAC + HER training on kNN-GP simulator.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python3.14 train_sac.py
|
||||||
|
|
||||||
|
Requirements:
|
||||||
|
- NuCon game running (for parameter metadata)
|
||||||
|
- /tmp/reactor_knn.pkl (kNN-GP model)
|
||||||
|
- /tmp/nucon_dataset.pkl (500-sample dataset for init_states)
|
||||||
|
"""
|
||||||
|
import pickle
|
||||||
|
from gymnasium.wrappers import TimeLimit
|
||||||
|
from stable_baselines3 import SAC
|
||||||
|
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
|
||||||
|
|
||||||
|
from nucon.sim import NuconSimulator
|
||||||
|
from nucon.rl import NuconGoalEnv, Parameterized_Objectives, Parameterized_Terminators
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Load model and dataset
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
with open('/tmp/reactor_knn.pkl', 'rb') as f:
|
||||||
|
knn_model = pickle.load(f)
|
||||||
|
|
||||||
|
with open('/tmp/nucon_dataset.pkl', 'rb') as f:
|
||||||
|
dataset = pickle.load(f)
|
||||||
|
|
||||||
|
# Seed resets to in-distribution states from dataset
|
||||||
|
init_states = [s for _, _, s, _ in dataset]
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Build sim + env
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
sim = NuconSimulator(port=8786)
|
||||||
|
sim.set_model(knn_model)
|
||||||
|
|
||||||
|
BATCH_SIZE = 2048
|
||||||
|
MAX_EPISODE_STEPS = 200
|
||||||
|
|
||||||
|
env = NuconGoalEnv(
|
||||||
|
goal_params=['CORE_TEMP'],
|
||||||
|
goal_range={'CORE_TEMP': (55.0, 550.0)},
|
||||||
|
tolerance=0.05,
|
||||||
|
seconds_per_step=10,
|
||||||
|
simulator=sim,
|
||||||
|
additional_objectives=[
|
||||||
|
Parameterized_Objectives['uncertainty_penalty'](start=0.3),
|
||||||
|
],
|
||||||
|
additional_objective_weights=[1.0],
|
||||||
|
init_states=init_states,
|
||||||
|
delta_action_scale=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
|
env = TimeLimit(env, max_episode_steps=MAX_EPISODE_STEPS)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# SAC + HER
|
||||||
|
# learning_starts = batch_size: wait for batch_size complete (short) episodes
|
||||||
|
# before the first gradient step. As the policy learns to stay in-dist, episodes
|
||||||
|
# will get longer and HER has more transitions to relabel.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
model = SAC(
|
||||||
|
'MultiInputPolicy',
|
||||||
|
env,
|
||||||
|
replay_buffer_class=HerReplayBuffer,
|
||||||
|
replay_buffer_kwargs={
|
||||||
|
'n_sampled_goal': 4,
|
||||||
|
'goal_selection_strategy': 'future',
|
||||||
|
},
|
||||||
|
verbose=1,
|
||||||
|
learning_rate=3e-4,
|
||||||
|
batch_size=BATCH_SIZE,
|
||||||
|
tau=0.005,
|
||||||
|
gamma=0.98,
|
||||||
|
train_freq=64,
|
||||||
|
gradient_steps=8,
|
||||||
|
learning_starts=BATCH_SIZE,
|
||||||
|
device='auto',
|
||||||
|
)
|
||||||
|
|
||||||
|
model.learn(total_timesteps=50_000)
|
||||||
|
model.save('/tmp/sac_nucon_knn')
|
||||||
|
print("Saved to /tmp/sac_nucon_knn.zip")
|
||||||
Loading…
Reference in New Issue
Block a user