feat: abort trajectory on high kNN uncertainty in simulator

NuconSimulator now accepts uncertainty_threshold (default None = disabled).
When set and using a kNN model, _update_reactor_state() calls
forward_with_uncertainty() and raises HighUncertaintyError if the GP
posterior std exceeds the threshold.

NuconEnv and NuconGoalEnv catch HighUncertaintyError in step() and
return truncated=True, so Stable-Baselines3 (SB3) bootstraps the value
rather than treating out-of-distribution (OOD) regions as terminal states.

Usage:
    simulator = NuconSimulator(uncertainty_threshold=0.3)
    # episodes are cut short when the policy wanders OOD

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Dominik Moritz Roth 2026-03-12 18:29:54 +01:00
parent e2e8db1f04
commit 6cb93ad56d
2 changed files with 38 additions and 4 deletions

View File

@ -5,6 +5,10 @@ import time
from typing import Dict, Any
from enum import Enum
from nucon import Nucon, BreakerStatus, PumpStatus, PumpDryStatus, PumpOverloadStatus
try:
from nucon.sim import HighUncertaintyError
except ImportError:
HighUncertaintyError = None
Objectives = {
"null": lambda obs: 0,
@ -127,7 +131,13 @@ class NuconEnv(gym.Env):
self._total_steps += 1
if self.simulator:
try:
self.simulator.update(self.seconds_per_step)
except Exception as e:
if HighUncertaintyError and isinstance(e, HighUncertaintyError):
truncated = True
else:
raise
else:
# Sleep to let the game advance seconds_per_step game-seconds,
# accounting for the game's simulation speed multiplier.
@ -350,7 +360,13 @@ class NuconGoalEnv(gym.Env):
self._total_steps += 1
if self.simulator:
try:
self.simulator.update(self.seconds_per_step)
except Exception as e:
if HighUncertaintyError and isinstance(e, HighUncertaintyError):
truncated = True
else:
raise
else:
sim_speed = self.nucon.GAME_SIM_SPEED.value or 1.0
time.sleep(self.seconds_per_step / sim_speed)

View File

@ -8,6 +8,18 @@ import torch
from nucon.model import ReactorDynamicsModel, ReactorKNNModel
import pickle
class HighUncertaintyError(Exception):
    """Raised by NuconSimulator when the dynamics model uncertainty exceeds the threshold.

    Caught by NuconEnv/NuconGoalEnv and returned as truncated=True so the RL
    algorithm bootstraps the value rather than treating it as a terminal state.

    Attributes:
        uncertainty: The uncertainty value that triggered the error.
        threshold: The configured limit that ``uncertainty`` exceeded.
    """

    def __init__(self, uncertainty: float, threshold: float):
        self.uncertainty = uncertainty
        self.threshold = threshold
        super().__init__(f"Model uncertainty {uncertainty:.3f} exceeded threshold {threshold:.3f}")

    def __reduce__(self):
        """Return the recipe used by ``pickle`` and ``copy`` to rebuild this error."""
        # The default BaseException.__reduce__ rebuilds via cls(*self.args), but
        # args only holds the formatted message, so that call would fail with a
        # TypeError (missing ``threshold``). Rebuild from the real constructor
        # arguments and carry over any extra instance attributes as state.
        return (type(self), (self.uncertainty, self.threshold), self.__dict__)
class OperatingState(Enum):
# Tuple indicates a range of values, while list indicates a set of possible values
OFFLINE = {
@ -163,13 +175,14 @@ class NuconSimulator:
for param_name in nucon.get_all_readable():
setattr(self, param_name, None)
def __init__(self, host: str = 'localhost', port: int = 8786):
def __init__(self, host: str = 'localhost', port: int = 8786, uncertainty_threshold: float = None):
self._nucon = Nucon()
self.parameters = self.Parameters(self._nucon)
self.host = host
self.port = port
self.time = 0.0
self.allow_all_writes = False
self.uncertainty_threshold = uncertainty_threshold
self.set_state(OperatingState.OFFLINE)
self.model = None
self.readable_params = list(self._nucon.get_all_readable().keys())
@ -267,6 +280,11 @@ class NuconSimulator:
if isinstance(self.model, ReactorDynamicsModel):
with torch.no_grad():
next_state = self.model.forward(state, time_step)
else:
if self.uncertainty_threshold is not None:
next_state, uncertainty = self.model.forward_with_uncertainty(state, time_step)
if uncertainty > self.uncertainty_threshold:
raise HighUncertaintyError(uncertainty, self.uncertainty_threshold)
else:
next_state = self.model.forward(state, time_step)