feat: uncertainty-aware training with penalty and abort
sim.py:
- simulator.update(time_step, return_uncertainty=True) calls
forward_with_uncertainty on kNN models and returns the GP posterior std;
it returns None for NN models or when uncertainty is not requested, so
there is no extra cost if the feature is unused
- No state stored on simulator; caller decides what to do with the value
rl.py (NuconEnv and NuconGoalEnv):
- uncertainty_penalty_start: above this GP std, subtract a linear penalty
from the reward: uncertainty_penalty_scale * (std - uncertainty_penalty_start),
with uncertainty_penalty_scale defaulting to 1.0
- uncertainty_abort: at or above this GP std, set truncated=True
- Only calls update(return_uncertainty=True) when either threshold is set
- Uncertainty handling only applies when the env is driven by a simulator with a kNN model; both thresholds are ignored for NN models and when running without a simulator
Example:
simulator = NuconSimulator()
simulator.load_model('reactor_knn.pkl')
env = NuconGoalEnv(..., simulator=simulator,
uncertainty_penalty_start=0.3,
uncertainty_abort=0.7,
uncertainty_penalty_scale=2.0)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
6cb93ad56d
commit
65190dffea
42
nucon/rl.py
42
nucon/rl.py
@ -5,10 +5,6 @@ import time
|
||||
from typing import Dict, Any
|
||||
from enum import Enum
|
||||
from nucon import Nucon, BreakerStatus, PumpStatus, PumpDryStatus, PumpOverloadStatus
|
||||
try:
|
||||
from nucon.sim import HighUncertaintyError
|
||||
except ImportError:
|
||||
HighUncertaintyError = None
|
||||
|
||||
Objectives = {
|
||||
"null": lambda obs: 0,
|
||||
@ -27,7 +23,8 @@ Parameterized_Objectives = {
|
||||
class NuconEnv(gym.Env):
|
||||
metadata = {'render_modes': ['human']}
|
||||
|
||||
def __init__(self, nucon=None, simulator=None, render_mode=None, seconds_per_step=5, objectives=['null'], terminators=['null'], objective_weights=None, terminate_above=0):
|
||||
def __init__(self, nucon=None, simulator=None, render_mode=None, seconds_per_step=5, objectives=['null'], terminators=['null'], objective_weights=None, terminate_above=0,
|
||||
uncertainty_penalty_start: float = None, uncertainty_abort: float = None, uncertainty_penalty_scale: float = 1.0):
|
||||
super().__init__()
|
||||
|
||||
self.render_mode = render_mode
|
||||
@ -37,6 +34,9 @@ class NuconEnv(gym.Env):
|
||||
self.objective_weights = objective_weights
|
||||
self.terminate_above = terminate_above
|
||||
self.simulator = simulator
|
||||
self.uncertainty_penalty_start = uncertainty_penalty_start
|
||||
self.uncertainty_abort = uncertainty_abort
|
||||
self.uncertainty_penalty_scale = uncertainty_penalty_scale
|
||||
|
||||
if nucon is None:
|
||||
if simulator:
|
||||
@ -131,16 +131,14 @@ class NuconEnv(gym.Env):
|
||||
|
||||
self._total_steps += 1
|
||||
if self.simulator:
|
||||
try:
|
||||
self.simulator.update(self.seconds_per_step)
|
||||
except Exception as e:
|
||||
if HighUncertaintyError and isinstance(e, HighUncertaintyError):
|
||||
needs_uncertainty = self.uncertainty_penalty_start is not None or self.uncertainty_abort is not None
|
||||
uncertainty = self.simulator.update(self.seconds_per_step, return_uncertainty=needs_uncertainty)
|
||||
if uncertainty is not None:
|
||||
if self.uncertainty_abort is not None and uncertainty >= self.uncertainty_abort:
|
||||
truncated = True
|
||||
if self.uncertainty_penalty_start is not None and uncertainty > self.uncertainty_penalty_start:
|
||||
reward -= self.uncertainty_penalty_scale * (uncertainty - self.uncertainty_penalty_start)
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
# Sleep to let the game advance seconds_per_step game-seconds,
|
||||
# accounting for the game's simulation speed multiplier.
|
||||
sim_speed = self.nucon.GAME_SIM_SPEED.value or 1.0
|
||||
time.sleep(self.seconds_per_step / sim_speed)
|
||||
return observation, reward, terminated, truncated, info
|
||||
@ -222,6 +220,9 @@ class NuconGoalEnv(gym.Env):
|
||||
seconds_per_step=5,
|
||||
terminators=None,
|
||||
terminate_above=0,
|
||||
uncertainty_penalty_start: float = None,
|
||||
uncertainty_abort: float = None,
|
||||
uncertainty_penalty_scale: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -286,6 +287,9 @@ class NuconGoalEnv(gym.Env):
|
||||
|
||||
# Terminators
|
||||
self._terminators = terminators or []
|
||||
self.uncertainty_penalty_start = uncertainty_penalty_start
|
||||
self.uncertainty_abort = uncertainty_abort
|
||||
self.uncertainty_penalty_scale = uncertainty_penalty_scale
|
||||
|
||||
self._desired_goal = np.zeros(n_goals, dtype=np.float32)
|
||||
self._total_steps = 0
|
||||
@ -360,13 +364,13 @@ class NuconGoalEnv(gym.Env):
|
||||
|
||||
self._total_steps += 1
|
||||
if self.simulator:
|
||||
try:
|
||||
self.simulator.update(self.seconds_per_step)
|
||||
except Exception as e:
|
||||
if HighUncertaintyError and isinstance(e, HighUncertaintyError):
|
||||
needs_uncertainty = self.uncertainty_penalty_start is not None or self.uncertainty_abort is not None
|
||||
uncertainty = self.simulator.update(self.seconds_per_step, return_uncertainty=needs_uncertainty)
|
||||
if uncertainty is not None:
|
||||
if self.uncertainty_abort is not None and uncertainty >= self.uncertainty_abort:
|
||||
truncated = True
|
||||
else:
|
||||
raise
|
||||
if self.uncertainty_penalty_start is not None and uncertainty > self.uncertainty_penalty_start:
|
||||
reward -= self.uncertainty_penalty_scale * (uncertainty - self.uncertainty_penalty_start)
|
||||
else:
|
||||
sim_speed = self.nucon.GAME_SIM_SPEED.value or 1.0
|
||||
time.sleep(self.seconds_per_step / sim_speed)
|
||||
|
||||
38
nucon/sim.py
38
nucon/sim.py
@ -8,18 +8,6 @@ import torch
|
||||
from nucon.model import ReactorDynamicsModel, ReactorKNNModel
|
||||
import pickle
|
||||
|
||||
class HighUncertaintyError(Exception):
|
||||
"""Raised by NuconSimulator when the dynamics model uncertainty exceeds the threshold.
|
||||
|
||||
Caught by NuconEnv/NuconGoalEnv and returned as truncated=True so the RL
|
||||
algorithm bootstraps the value rather than treating it as a terminal state.
|
||||
"""
|
||||
def __init__(self, uncertainty: float, threshold: float):
|
||||
self.uncertainty = uncertainty
|
||||
self.threshold = threshold
|
||||
super().__init__(f"Model uncertainty {uncertainty:.3f} exceeded threshold {threshold:.3f}")
|
||||
|
||||
|
||||
class OperatingState(Enum):
|
||||
# Tuple indicates a range of values, while list indicates a set of possible values
|
||||
OFFLINE = {
|
||||
@ -175,14 +163,13 @@ class NuconSimulator:
|
||||
for param_name in nucon.get_all_readable():
|
||||
setattr(self, param_name, None)
|
||||
|
||||
def __init__(self, host: str = 'localhost', port: int = 8786, uncertainty_threshold: float = None):
|
||||
def __init__(self, host: str = 'localhost', port: int = 8786):
|
||||
self._nucon = Nucon()
|
||||
self.parameters = self.Parameters(self._nucon)
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.time = 0.0
|
||||
self.allow_all_writes = False
|
||||
self.uncertainty_threshold = uncertainty_threshold
|
||||
self.set_state(OperatingState.OFFLINE)
|
||||
self.model = None
|
||||
self.readable_params = list(self._nucon.get_all_readable().keys())
|
||||
@ -228,9 +215,16 @@ class NuconSimulator:
|
||||
def set_allow_all_writes(self, allow: bool) -> None:
|
||||
self.allow_all_writes = allow
|
||||
|
||||
def update(self, time_step: float) -> None:
|
||||
self._update_reactor_state(time_step)
|
||||
def update(self, time_step: float, return_uncertainty: bool = False):
|
||||
"""Advance the simulator by time_step game-seconds.
|
||||
|
||||
If return_uncertainty=True and a kNN model is loaded, returns the GP
|
||||
posterior std for this step (0 = on known data, ~1 = OOD).
|
||||
Always returns None when using an NN model.
|
||||
"""
|
||||
uncertainty = self._update_reactor_state(time_step, return_uncertainty=return_uncertainty)
|
||||
self.time += time_step
|
||||
return uncertainty
|
||||
|
||||
def set_model(self, model) -> None:
|
||||
"""Set a pre-loaded ReactorDynamicsModel or ReactorKNNModel directly."""
|
||||
@ -262,7 +256,7 @@ class NuconSimulator:
|
||||
print(f"Error loading model: {str(e)}")
|
||||
self.model = None
|
||||
|
||||
def _update_reactor_state(self, time_step: float) -> None:
|
||||
def _update_reactor_state(self, time_step: float, return_uncertainty: bool = False):
|
||||
if not self.model:
|
||||
raise ValueError("Model not set. Please load a model using load_model() or set_model().")
|
||||
|
||||
@ -276,15 +270,13 @@ class NuconSimulator:
|
||||
value = 0.0 # fallback for params not initialised in sim state
|
||||
state[param_id] = value
|
||||
|
||||
# Forward pass — same interface for both NN and kNN
|
||||
# Forward pass
|
||||
uncertainty = None
|
||||
if isinstance(self.model, ReactorDynamicsModel):
|
||||
with torch.no_grad():
|
||||
next_state = self.model.forward(state, time_step)
|
||||
else:
|
||||
if self.uncertainty_threshold is not None:
|
||||
elif return_uncertainty:
|
||||
next_state, uncertainty = self.model.forward_with_uncertainty(state, time_step)
|
||||
if uncertainty > self.uncertainty_threshold:
|
||||
raise HighUncertaintyError(uncertainty, self.uncertainty_threshold)
|
||||
else:
|
||||
next_state = self.model.forward(state, time_step)
|
||||
|
||||
@ -295,6 +287,8 @@ class NuconSimulator:
|
||||
except (ValueError, KeyError):
|
||||
pass # ignore params that can't be set (type mismatch, unknown)
|
||||
|
||||
return uncertainty
|
||||
|
||||
def set_state(self, state: OperatingState) -> None:
|
||||
self._sample_parameters_from_state(state)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user