feat: uncertainty-aware training with penalty and abort

sim.py:
- simulator.update(return_uncertainty=True) calls forward_with_uncertainty
  on kNN models and returns the GP std; returns None for NN or when not
  requested (no extra cost if unused)
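- No state stored on simulator; caller decides what to do with the value
  (replaces the removed uncertainty_threshold constructor argument and the
  HighUncertaintyError exception)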

rl.py (NuconEnv and NuconGoalEnv):
- uncertainty_penalty_start: above this GP std, subtract a linear penalty of
  uncertainty_penalty_scale * (std - uncertainty_penalty_start) from the
  reward (uncertainty_penalty_scale defaults to 1.0)
- uncertainty_abort: at or above this GP std, set truncated=True
- Only calls update(return_uncertainty=True) when either threshold is set
- Uncertainty handling only applies when stepping a simulator that has a kNN
  model loaded; with an NN model or without a simulator it is ignored

Example:
    simulator = NuconSimulator()
    simulator.load_model('reactor_knn.pkl')
    env = NuconGoalEnv(..., simulator=simulator,
                       uncertainty_penalty_start=0.3,
                       uncertainty_abort=0.7,
                       uncertainty_penalty_scale=2.0)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Dominik Moritz Roth 2026-03-12 18:37:09 +01:00
parent 6cb93ad56d
commit 65190dffea
2 changed files with 41 additions and 43 deletions

View File

@ -5,10 +5,6 @@ import time
from typing import Dict, Any from typing import Dict, Any
from enum import Enum from enum import Enum
from nucon import Nucon, BreakerStatus, PumpStatus, PumpDryStatus, PumpOverloadStatus from nucon import Nucon, BreakerStatus, PumpStatus, PumpDryStatus, PumpOverloadStatus
try:
from nucon.sim import HighUncertaintyError
except ImportError:
HighUncertaintyError = None
Objectives = { Objectives = {
"null": lambda obs: 0, "null": lambda obs: 0,
@ -27,7 +23,8 @@ Parameterized_Objectives = {
class NuconEnv(gym.Env): class NuconEnv(gym.Env):
metadata = {'render_modes': ['human']} metadata = {'render_modes': ['human']}
def __init__(self, nucon=None, simulator=None, render_mode=None, seconds_per_step=5, objectives=['null'], terminators=['null'], objective_weights=None, terminate_above=0): def __init__(self, nucon=None, simulator=None, render_mode=None, seconds_per_step=5, objectives=['null'], terminators=['null'], objective_weights=None, terminate_above=0,
uncertainty_penalty_start: float = None, uncertainty_abort: float = None, uncertainty_penalty_scale: float = 1.0):
super().__init__() super().__init__()
self.render_mode = render_mode self.render_mode = render_mode
@ -37,6 +34,9 @@ class NuconEnv(gym.Env):
self.objective_weights = objective_weights self.objective_weights = objective_weights
self.terminate_above = terminate_above self.terminate_above = terminate_above
self.simulator = simulator self.simulator = simulator
self.uncertainty_penalty_start = uncertainty_penalty_start
self.uncertainty_abort = uncertainty_abort
self.uncertainty_penalty_scale = uncertainty_penalty_scale
if nucon is None: if nucon is None:
if simulator: if simulator:
@ -131,16 +131,14 @@ class NuconEnv(gym.Env):
self._total_steps += 1 self._total_steps += 1
if self.simulator: if self.simulator:
try: needs_uncertainty = self.uncertainty_penalty_start is not None or self.uncertainty_abort is not None
self.simulator.update(self.seconds_per_step) uncertainty = self.simulator.update(self.seconds_per_step, return_uncertainty=needs_uncertainty)
except Exception as e: if uncertainty is not None:
if HighUncertaintyError and isinstance(e, HighUncertaintyError): if self.uncertainty_abort is not None and uncertainty >= self.uncertainty_abort:
truncated = True truncated = True
else: if self.uncertainty_penalty_start is not None and uncertainty > self.uncertainty_penalty_start:
raise reward -= self.uncertainty_penalty_scale * (uncertainty - self.uncertainty_penalty_start)
else: else:
# Sleep to let the game advance seconds_per_step game-seconds,
# accounting for the game's simulation speed multiplier.
sim_speed = self.nucon.GAME_SIM_SPEED.value or 1.0 sim_speed = self.nucon.GAME_SIM_SPEED.value or 1.0
time.sleep(self.seconds_per_step / sim_speed) time.sleep(self.seconds_per_step / sim_speed)
return observation, reward, terminated, truncated, info return observation, reward, terminated, truncated, info
@ -222,6 +220,9 @@ class NuconGoalEnv(gym.Env):
seconds_per_step=5, seconds_per_step=5,
terminators=None, terminators=None,
terminate_above=0, terminate_above=0,
uncertainty_penalty_start: float = None,
uncertainty_abort: float = None,
uncertainty_penalty_scale: float = 1.0,
): ):
super().__init__() super().__init__()
@ -286,6 +287,9 @@ class NuconGoalEnv(gym.Env):
# Terminators # Terminators
self._terminators = terminators or [] self._terminators = terminators or []
self.uncertainty_penalty_start = uncertainty_penalty_start
self.uncertainty_abort = uncertainty_abort
self.uncertainty_penalty_scale = uncertainty_penalty_scale
self._desired_goal = np.zeros(n_goals, dtype=np.float32) self._desired_goal = np.zeros(n_goals, dtype=np.float32)
self._total_steps = 0 self._total_steps = 0
@ -360,13 +364,13 @@ class NuconGoalEnv(gym.Env):
self._total_steps += 1 self._total_steps += 1
if self.simulator: if self.simulator:
try: needs_uncertainty = self.uncertainty_penalty_start is not None or self.uncertainty_abort is not None
self.simulator.update(self.seconds_per_step) uncertainty = self.simulator.update(self.seconds_per_step, return_uncertainty=needs_uncertainty)
except Exception as e: if uncertainty is not None:
if HighUncertaintyError and isinstance(e, HighUncertaintyError): if self.uncertainty_abort is not None and uncertainty >= self.uncertainty_abort:
truncated = True truncated = True
else: if self.uncertainty_penalty_start is not None and uncertainty > self.uncertainty_penalty_start:
raise reward -= self.uncertainty_penalty_scale * (uncertainty - self.uncertainty_penalty_start)
else: else:
sim_speed = self.nucon.GAME_SIM_SPEED.value or 1.0 sim_speed = self.nucon.GAME_SIM_SPEED.value or 1.0
time.sleep(self.seconds_per_step / sim_speed) time.sleep(self.seconds_per_step / sim_speed)

View File

@ -8,18 +8,6 @@ import torch
from nucon.model import ReactorDynamicsModel, ReactorKNNModel from nucon.model import ReactorDynamicsModel, ReactorKNNModel
import pickle import pickle
class HighUncertaintyError(Exception):
"""Raised by NuconSimulator when the dynamics model uncertainty exceeds the threshold.
Caught by NuconEnv/NuconGoalEnv and returned as truncated=True so the RL
algorithm bootstraps the value rather than treating it as a terminal state.
"""
def __init__(self, uncertainty: float, threshold: float):
self.uncertainty = uncertainty
self.threshold = threshold
super().__init__(f"Model uncertainty {uncertainty:.3f} exceeded threshold {threshold:.3f}")
class OperatingState(Enum): class OperatingState(Enum):
# Tuple indicates a range of values, while list indicates a set of possible values # Tuple indicates a range of values, while list indicates a set of possible values
OFFLINE = { OFFLINE = {
@ -175,14 +163,13 @@ class NuconSimulator:
for param_name in nucon.get_all_readable(): for param_name in nucon.get_all_readable():
setattr(self, param_name, None) setattr(self, param_name, None)
def __init__(self, host: str = 'localhost', port: int = 8786, uncertainty_threshold: float = None): def __init__(self, host: str = 'localhost', port: int = 8786):
self._nucon = Nucon() self._nucon = Nucon()
self.parameters = self.Parameters(self._nucon) self.parameters = self.Parameters(self._nucon)
self.host = host self.host = host
self.port = port self.port = port
self.time = 0.0 self.time = 0.0
self.allow_all_writes = False self.allow_all_writes = False
self.uncertainty_threshold = uncertainty_threshold
self.set_state(OperatingState.OFFLINE) self.set_state(OperatingState.OFFLINE)
self.model = None self.model = None
self.readable_params = list(self._nucon.get_all_readable().keys()) self.readable_params = list(self._nucon.get_all_readable().keys())
@ -228,9 +215,16 @@ class NuconSimulator:
def set_allow_all_writes(self, allow: bool) -> None: def set_allow_all_writes(self, allow: bool) -> None:
self.allow_all_writes = allow self.allow_all_writes = allow
def update(self, time_step: float) -> None: def update(self, time_step: float, return_uncertainty: bool = False):
self._update_reactor_state(time_step) """Advance the simulator by time_step game-seconds.
If return_uncertainty=True and a kNN model is loaded, returns the GP
posterior std for this step (0 = on known data, ~1 = OOD).
Always returns None when using an NN model.
"""
uncertainty = self._update_reactor_state(time_step, return_uncertainty=return_uncertainty)
self.time += time_step self.time += time_step
return uncertainty
def set_model(self, model) -> None: def set_model(self, model) -> None:
"""Set a pre-loaded ReactorDynamicsModel or ReactorKNNModel directly.""" """Set a pre-loaded ReactorDynamicsModel or ReactorKNNModel directly."""
@ -262,7 +256,7 @@ class NuconSimulator:
print(f"Error loading model: {str(e)}") print(f"Error loading model: {str(e)}")
self.model = None self.model = None
def _update_reactor_state(self, time_step: float) -> None: def _update_reactor_state(self, time_step: float, return_uncertainty: bool = False):
if not self.model: if not self.model:
raise ValueError("Model not set. Please load a model using load_model() or set_model().") raise ValueError("Model not set. Please load a model using load_model() or set_model().")
@ -276,17 +270,15 @@ class NuconSimulator:
value = 0.0 # fallback for params not initialised in sim state value = 0.0 # fallback for params not initialised in sim state
state[param_id] = value state[param_id] = value
# Forward pass — same interface for both NN and kNN # Forward pass
uncertainty = None
if isinstance(self.model, ReactorDynamicsModel): if isinstance(self.model, ReactorDynamicsModel):
with torch.no_grad(): with torch.no_grad():
next_state = self.model.forward(state, time_step) next_state = self.model.forward(state, time_step)
elif return_uncertainty:
next_state, uncertainty = self.model.forward_with_uncertainty(state, time_step)
else: else:
if self.uncertainty_threshold is not None: next_state = self.model.forward(state, time_step)
next_state, uncertainty = self.model.forward_with_uncertainty(state, time_step)
if uncertainty > self.uncertainty_threshold:
raise HighUncertaintyError(uncertainty, self.uncertainty_threshold)
else:
next_state = self.model.forward(state, time_step)
# Update only the output params the model predicts # Update only the output params the model predicts
for param_id, value in next_state.items(): for param_id, value in next_state.items():
@ -295,6 +287,8 @@ class NuconSimulator:
except (ValueError, KeyError): except (ValueError, KeyError):
pass # ignore params that can't be set (type mismatch, unknown) pass # ignore params that can't be set (type mismatch, unknown)
return uncertainty
def set_state(self, state: OperatingState) -> None: def set_state(self, state: OperatingState) -> None:
self._sample_parameters_from_state(state) self._sample_parameters_from_state(state)