feat: abort trajectory on high kNN uncertainty in simulator
NuconSimulator now accepts uncertainty_threshold (default None = disabled).
When set and using a kNN model, _update_reactor_state() calls
forward_with_uncertainty() and raises HighUncertaintyError if the GP
posterior std exceeds the threshold.
NuconEnv and NuconGoalEnv catch HighUncertaintyError in step() and
return truncated=True, so SB3 bootstraps the value rather than treating
OOD regions as terminal states.
Usage:
simulator = NuconSimulator(uncertainty_threshold=0.3)
# episodes are cut short when the policy wanders OOD
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
e2e8db1f04
commit
6cb93ad56d
16
nucon/rl.py
16
nucon/rl.py
@ -5,6 +5,10 @@ import time
|
||||
from typing import Dict, Any
|
||||
from enum import Enum
|
||||
from nucon import Nucon, BreakerStatus, PumpStatus, PumpDryStatus, PumpOverloadStatus
|
||||
try:
|
||||
from nucon.sim import HighUncertaintyError
|
||||
except ImportError:
|
||||
HighUncertaintyError = None
|
||||
|
||||
Objectives = {
|
||||
"null": lambda obs: 0,
|
||||
@ -127,7 +131,13 @@ class NuconEnv(gym.Env):
|
||||
|
||||
self._total_steps += 1
|
||||
if self.simulator:
|
||||
try:
|
||||
self.simulator.update(self.seconds_per_step)
|
||||
except Exception as e:
|
||||
if HighUncertaintyError and isinstance(e, HighUncertaintyError):
|
||||
truncated = True
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
# Sleep to let the game advance seconds_per_step game-seconds,
|
||||
# accounting for the game's simulation speed multiplier.
|
||||
@ -350,7 +360,13 @@ class NuconGoalEnv(gym.Env):
|
||||
|
||||
self._total_steps += 1
|
||||
if self.simulator:
|
||||
try:
|
||||
self.simulator.update(self.seconds_per_step)
|
||||
except Exception as e:
|
||||
if HighUncertaintyError and isinstance(e, HighUncertaintyError):
|
||||
truncated = True
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
sim_speed = self.nucon.GAME_SIM_SPEED.value or 1.0
|
||||
time.sleep(self.seconds_per_step / sim_speed)
|
||||
|
||||
20
nucon/sim.py
20
nucon/sim.py
@ -8,6 +8,18 @@ import torch
|
||||
from nucon.model import ReactorDynamicsModel, ReactorKNNModel
|
||||
import pickle
|
||||
|
||||
class HighUncertaintyError(Exception):
|
||||
"""Raised by NuconSimulator when the dynamics model uncertainty exceeds the threshold.
|
||||
|
||||
Caught by NuconEnv/NuconGoalEnv and returned as truncated=True so the RL
|
||||
algorithm bootstraps the value rather than treating it as a terminal state.
|
||||
"""
|
||||
def __init__(self, uncertainty: float, threshold: float):
|
||||
self.uncertainty = uncertainty
|
||||
self.threshold = threshold
|
||||
super().__init__(f"Model uncertainty {uncertainty:.3f} exceeded threshold {threshold:.3f}")
|
||||
|
||||
|
||||
class OperatingState(Enum):
|
||||
# Tuple indicates a range of values, while list indicates a set of possible values
|
||||
OFFLINE = {
|
||||
@ -163,13 +175,14 @@ class NuconSimulator:
|
||||
for param_name in nucon.get_all_readable():
|
||||
setattr(self, param_name, None)
|
||||
|
||||
def __init__(self, host: str = 'localhost', port: int = 8786):
|
||||
def __init__(self, host: str = 'localhost', port: int = 8786, uncertainty_threshold: float = None):
|
||||
self._nucon = Nucon()
|
||||
self.parameters = self.Parameters(self._nucon)
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.time = 0.0
|
||||
self.allow_all_writes = False
|
||||
self.uncertainty_threshold = uncertainty_threshold
|
||||
self.set_state(OperatingState.OFFLINE)
|
||||
self.model = None
|
||||
self.readable_params = list(self._nucon.get_all_readable().keys())
|
||||
@ -267,6 +280,11 @@ class NuconSimulator:
|
||||
if isinstance(self.model, ReactorDynamicsModel):
|
||||
with torch.no_grad():
|
||||
next_state = self.model.forward(state, time_step)
|
||||
else:
|
||||
if self.uncertainty_threshold is not None:
|
||||
next_state, uncertainty = self.model.forward_with_uncertainty(state, time_step)
|
||||
if uncertainty > self.uncertainty_threshold:
|
||||
raise HighUncertaintyError(uncertainty, self.uncertainty_threshold)
|
||||
else:
|
||||
next_state = self.model.forward(state, time_step)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user