feat: uncertainty-aware training with penalty and abort

sim.py:
- simulator.update(return_uncertainty=True) calls forward_with_uncertainty
  on kNN models and returns the GP posterior std; it returns None for NN
  models or when uncertainty is not requested (no extra cost if unused)
- No state stored on simulator; caller decides what to do with the value

rl.py (NuconEnv and NuconGoalEnv):
- uncertainty_penalty_start: when the GP std exceeds this value, subtract a
  linear penalty from the reward:
  uncertainty_penalty_scale * (std - uncertainty_penalty_start),
  where uncertainty_penalty_scale defaults to 1.0
- uncertainty_abort: at or above this GP std, set truncated=True
- Only calls update(return_uncertainty=True) when either threshold is set
- Uncertainty handling only applies when stepping a simulator whose model
  reports uncertainty (kNN); with an NN model, or without a simulator, both
  thresholds are ignored

Example:
    simulator = NuconSimulator()
    simulator.load_model('reactor_knn.pkl')
    env = NuconGoalEnv(..., simulator=simulator,
                       uncertainty_penalty_start=0.3,
                       uncertainty_abort=0.7,
                       uncertainty_penalty_scale=2.0)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Dominik Moritz Roth 2026-03-12 18:37:09 +01:00
parent 6cb93ad56d
commit 65190dffea
2 changed files with 41 additions and 43 deletions

View File

@ -5,10 +5,6 @@ import time
from typing import Dict, Any
from enum import Enum
from nucon import Nucon, BreakerStatus, PumpStatus, PumpDryStatus, PumpOverloadStatus
try:
from nucon.sim import HighUncertaintyError
except ImportError:
HighUncertaintyError = None
Objectives = {
"null": lambda obs: 0,
@ -27,7 +23,8 @@ Parameterized_Objectives = {
class NuconEnv(gym.Env):
metadata = {'render_modes': ['human']}
def __init__(self, nucon=None, simulator=None, render_mode=None, seconds_per_step=5, objectives=['null'], terminators=['null'], objective_weights=None, terminate_above=0):
def __init__(self, nucon=None, simulator=None, render_mode=None, seconds_per_step=5, objectives=['null'], terminators=['null'], objective_weights=None, terminate_above=0,
uncertainty_penalty_start: float = None, uncertainty_abort: float = None, uncertainty_penalty_scale: float = 1.0):
super().__init__()
self.render_mode = render_mode
@ -37,6 +34,9 @@ class NuconEnv(gym.Env):
self.objective_weights = objective_weights
self.terminate_above = terminate_above
self.simulator = simulator
self.uncertainty_penalty_start = uncertainty_penalty_start
self.uncertainty_abort = uncertainty_abort
self.uncertainty_penalty_scale = uncertainty_penalty_scale
if nucon is None:
if simulator:
@ -131,16 +131,14 @@ class NuconEnv(gym.Env):
self._total_steps += 1
if self.simulator:
try:
self.simulator.update(self.seconds_per_step)
except Exception as e:
if HighUncertaintyError and isinstance(e, HighUncertaintyError):
needs_uncertainty = self.uncertainty_penalty_start is not None or self.uncertainty_abort is not None
uncertainty = self.simulator.update(self.seconds_per_step, return_uncertainty=needs_uncertainty)
if uncertainty is not None:
if self.uncertainty_abort is not None and uncertainty >= self.uncertainty_abort:
truncated = True
if self.uncertainty_penalty_start is not None and uncertainty > self.uncertainty_penalty_start:
reward -= self.uncertainty_penalty_scale * (uncertainty - self.uncertainty_penalty_start)
else:
raise
else:
# Sleep to let the game advance seconds_per_step game-seconds,
# accounting for the game's simulation speed multiplier.
sim_speed = self.nucon.GAME_SIM_SPEED.value or 1.0
time.sleep(self.seconds_per_step / sim_speed)
return observation, reward, terminated, truncated, info
@ -222,6 +220,9 @@ class NuconGoalEnv(gym.Env):
seconds_per_step=5,
terminators=None,
terminate_above=0,
uncertainty_penalty_start: float = None,
uncertainty_abort: float = None,
uncertainty_penalty_scale: float = 1.0,
):
super().__init__()
@ -286,6 +287,9 @@ class NuconGoalEnv(gym.Env):
# Terminators
self._terminators = terminators or []
self.uncertainty_penalty_start = uncertainty_penalty_start
self.uncertainty_abort = uncertainty_abort
self.uncertainty_penalty_scale = uncertainty_penalty_scale
self._desired_goal = np.zeros(n_goals, dtype=np.float32)
self._total_steps = 0
@ -360,13 +364,13 @@ class NuconGoalEnv(gym.Env):
self._total_steps += 1
if self.simulator:
try:
self.simulator.update(self.seconds_per_step)
except Exception as e:
if HighUncertaintyError and isinstance(e, HighUncertaintyError):
needs_uncertainty = self.uncertainty_penalty_start is not None or self.uncertainty_abort is not None
uncertainty = self.simulator.update(self.seconds_per_step, return_uncertainty=needs_uncertainty)
if uncertainty is not None:
if self.uncertainty_abort is not None and uncertainty >= self.uncertainty_abort:
truncated = True
else:
raise
if self.uncertainty_penalty_start is not None and uncertainty > self.uncertainty_penalty_start:
reward -= self.uncertainty_penalty_scale * (uncertainty - self.uncertainty_penalty_start)
else:
sim_speed = self.nucon.GAME_SIM_SPEED.value or 1.0
time.sleep(self.seconds_per_step / sim_speed)

View File

@ -8,18 +8,6 @@ import torch
from nucon.model import ReactorDynamicsModel, ReactorKNNModel
import pickle
class HighUncertaintyError(Exception):
"""Raised by NuconSimulator when the dynamics model uncertainty exceeds the threshold.
Caught by NuconEnv/NuconGoalEnv and returned as truncated=True so the RL
algorithm bootstraps the value rather than treating it as a terminal state.
"""
def __init__(self, uncertainty: float, threshold: float):
self.uncertainty = uncertainty
self.threshold = threshold
super().__init__(f"Model uncertainty {uncertainty:.3f} exceeded threshold {threshold:.3f}")
class OperatingState(Enum):
# Tuple indicates a range of values, while list indicates a set of possible values
OFFLINE = {
@ -175,14 +163,13 @@ class NuconSimulator:
for param_name in nucon.get_all_readable():
setattr(self, param_name, None)
def __init__(self, host: str = 'localhost', port: int = 8786, uncertainty_threshold: float = None):
def __init__(self, host: str = 'localhost', port: int = 8786):
self._nucon = Nucon()
self.parameters = self.Parameters(self._nucon)
self.host = host
self.port = port
self.time = 0.0
self.allow_all_writes = False
self.uncertainty_threshold = uncertainty_threshold
self.set_state(OperatingState.OFFLINE)
self.model = None
self.readable_params = list(self._nucon.get_all_readable().keys())
@ -228,9 +215,16 @@ class NuconSimulator:
def set_allow_all_writes(self, allow: bool) -> None:
self.allow_all_writes = allow
def update(self, time_step: float) -> None:
self._update_reactor_state(time_step)
def update(self, time_step: float, return_uncertainty: bool = False):
"""Advance the simulator by time_step game-seconds.
If return_uncertainty=True and a kNN model is loaded, returns the GP
posterior std for this step (0 = on known data, ~1 = OOD).
Always returns None when using an NN model.
"""
uncertainty = self._update_reactor_state(time_step, return_uncertainty=return_uncertainty)
self.time += time_step
return uncertainty
def set_model(self, model) -> None:
"""Set a pre-loaded ReactorDynamicsModel or ReactorKNNModel directly."""
@ -262,7 +256,7 @@ class NuconSimulator:
print(f"Error loading model: {str(e)}")
self.model = None
def _update_reactor_state(self, time_step: float) -> None:
def _update_reactor_state(self, time_step: float, return_uncertainty: bool = False):
if not self.model:
raise ValueError("Model not set. Please load a model using load_model() or set_model().")
@ -276,15 +270,13 @@ class NuconSimulator:
value = 0.0 # fallback for params not initialised in sim state
state[param_id] = value
# Forward pass — same interface for both NN and kNN
# Forward pass
uncertainty = None
if isinstance(self.model, ReactorDynamicsModel):
with torch.no_grad():
next_state = self.model.forward(state, time_step)
else:
if self.uncertainty_threshold is not None:
elif return_uncertainty:
next_state, uncertainty = self.model.forward_with_uncertainty(state, time_step)
if uncertainty > self.uncertainty_threshold:
raise HighUncertaintyError(uncertainty, self.uncertainty_threshold)
else:
next_state = self.model.forward(state, time_step)
@ -295,6 +287,8 @@ class NuconSimulator:
except (ValueError, KeyError):
pass # ignore params that can't be set (type mismatch, unknown)
return uncertainty
def set_state(self, state: OperatingState) -> None:
self._sample_parameters_from_state(state)