feat: abort trajectory on high kNN uncertainty in simulator

NuconSimulator now accepts uncertainty_threshold (default None = disabled).
When set and using a kNN model, _update_reactor_state() calls
forward_with_uncertainty() and raises HighUncertaintyError if the GP
posterior std exceeds the threshold.

NuconEnv and NuconGoalEnv catch HighUncertaintyError in step() and
return truncated=True, so SB3 bootstraps the value rather than treating
OOD regions as terminal states.

Usage:
    simulator = NuconSimulator(uncertainty_threshold=0.3)
    # episodes are cut short when the policy wanders OOD

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Dominik Moritz Roth 2026-03-12 18:29:54 +01:00
parent e2e8db1f04
commit 6cb93ad56d
2 changed files with 38 additions and 4 deletions

View File

@ -5,6 +5,10 @@ import time
from typing import Dict, Any from typing import Dict, Any
from enum import Enum from enum import Enum
from nucon import Nucon, BreakerStatus, PumpStatus, PumpDryStatus, PumpOverloadStatus from nucon import Nucon, BreakerStatus, PumpStatus, PumpDryStatus, PumpOverloadStatus
try:
from nucon.sim import HighUncertaintyError
except ImportError:
HighUncertaintyError = None
Objectives = { Objectives = {
"null": lambda obs: 0, "null": lambda obs: 0,
@ -127,7 +131,13 @@ class NuconEnv(gym.Env):
self._total_steps += 1 self._total_steps += 1
if self.simulator: if self.simulator:
try:
self.simulator.update(self.seconds_per_step) self.simulator.update(self.seconds_per_step)
except Exception as e:
if HighUncertaintyError and isinstance(e, HighUncertaintyError):
truncated = True
else:
raise
else: else:
# Sleep to let the game advance seconds_per_step game-seconds, # Sleep to let the game advance seconds_per_step game-seconds,
# accounting for the game's simulation speed multiplier. # accounting for the game's simulation speed multiplier.
@ -350,7 +360,13 @@ class NuconGoalEnv(gym.Env):
self._total_steps += 1 self._total_steps += 1
if self.simulator: if self.simulator:
try:
self.simulator.update(self.seconds_per_step) self.simulator.update(self.seconds_per_step)
except Exception as e:
if HighUncertaintyError and isinstance(e, HighUncertaintyError):
truncated = True
else:
raise
else: else:
sim_speed = self.nucon.GAME_SIM_SPEED.value or 1.0 sim_speed = self.nucon.GAME_SIM_SPEED.value or 1.0
time.sleep(self.seconds_per_step / sim_speed) time.sleep(self.seconds_per_step / sim_speed)

View File

@ -8,6 +8,18 @@ import torch
from nucon.model import ReactorDynamicsModel, ReactorKNNModel from nucon.model import ReactorDynamicsModel, ReactorKNNModel
import pickle import pickle
class HighUncertaintyError(Exception):
"""Raised by NuconSimulator when the dynamics model uncertainty exceeds the threshold.
Caught by NuconEnv/NuconGoalEnv and returned as truncated=True so the RL
algorithm bootstraps the value rather than treating it as a terminal state.
"""
def __init__(self, uncertainty: float, threshold: float):
self.uncertainty = uncertainty
self.threshold = threshold
super().__init__(f"Model uncertainty {uncertainty:.3f} exceeded threshold {threshold:.3f}")
class OperatingState(Enum): class OperatingState(Enum):
# Tuple indicates a range of values, while list indicates a set of possible values # Tuple indicates a range of values, while list indicates a set of possible values
OFFLINE = { OFFLINE = {
@ -163,13 +175,14 @@ class NuconSimulator:
for param_name in nucon.get_all_readable(): for param_name in nucon.get_all_readable():
setattr(self, param_name, None) setattr(self, param_name, None)
def __init__(self, host: str = 'localhost', port: int = 8786): def __init__(self, host: str = 'localhost', port: int = 8786, uncertainty_threshold: float = None):
self._nucon = Nucon() self._nucon = Nucon()
self.parameters = self.Parameters(self._nucon) self.parameters = self.Parameters(self._nucon)
self.host = host self.host = host
self.port = port self.port = port
self.time = 0.0 self.time = 0.0
self.allow_all_writes = False self.allow_all_writes = False
self.uncertainty_threshold = uncertainty_threshold
self.set_state(OperatingState.OFFLINE) self.set_state(OperatingState.OFFLINE)
self.model = None self.model = None
self.readable_params = list(self._nucon.get_all_readable().keys()) self.readable_params = list(self._nucon.get_all_readable().keys())
@ -267,6 +280,11 @@ class NuconSimulator:
if isinstance(self.model, ReactorDynamicsModel): if isinstance(self.model, ReactorDynamicsModel):
with torch.no_grad(): with torch.no_grad():
next_state = self.model.forward(state, time_step) next_state = self.model.forward(state, time_step)
else:
if self.uncertainty_threshold is not None:
next_state, uncertainty = self.model.forward_with_uncertainty(state, time_step)
if uncertainty > self.uncertainty_threshold:
raise HighUncertaintyError(uncertainty, self.uncertainty_threshold)
else: else:
next_state = self.model.forward(state, time_step) next_state = self.model.forward(state, time_step)