feat: abort trajectory on high kNN uncertainty in simulator
NuconSimulator now accepts uncertainty_threshold (default None = disabled).
When set and using a kNN model, _update_reactor_state() calls
forward_with_uncertainty() and raises HighUncertaintyError if the GP
posterior std exceeds the threshold.
NuconEnv and NuconGoalEnv catch HighUncertaintyError in step() and
return truncated=True, so Stable-Baselines3 (SB3) bootstraps the value rather
than treating out-of-distribution (OOD) regions as terminal states.
Usage:
simulator = NuconSimulator(uncertainty_threshold=0.3)
# episodes are cut short when the policy wanders OOD
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
e2e8db1f04
commit
6cb93ad56d
16
nucon/rl.py
16
nucon/rl.py
@ -5,6 +5,10 @@ import time
|
|||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from nucon import Nucon, BreakerStatus, PumpStatus, PumpDryStatus, PumpOverloadStatus
|
from nucon import Nucon, BreakerStatus, PumpStatus, PumpDryStatus, PumpOverloadStatus
|
||||||
|
try:
|
||||||
|
from nucon.sim import HighUncertaintyError
|
||||||
|
except ImportError:
|
||||||
|
HighUncertaintyError = None
|
||||||
|
|
||||||
Objectives = {
|
Objectives = {
|
||||||
"null": lambda obs: 0,
|
"null": lambda obs: 0,
|
||||||
@ -127,7 +131,13 @@ class NuconEnv(gym.Env):
|
|||||||
|
|
||||||
self._total_steps += 1
|
self._total_steps += 1
|
||||||
if self.simulator:
|
if self.simulator:
|
||||||
|
try:
|
||||||
self.simulator.update(self.seconds_per_step)
|
self.simulator.update(self.seconds_per_step)
|
||||||
|
except Exception as e:
|
||||||
|
if HighUncertaintyError and isinstance(e, HighUncertaintyError):
|
||||||
|
truncated = True
|
||||||
|
else:
|
||||||
|
raise
|
||||||
else:
|
else:
|
||||||
# Sleep to let the game advance seconds_per_step game-seconds,
|
# Sleep to let the game advance seconds_per_step game-seconds,
|
||||||
# accounting for the game's simulation speed multiplier.
|
# accounting for the game's simulation speed multiplier.
|
||||||
@ -350,7 +360,13 @@ class NuconGoalEnv(gym.Env):
|
|||||||
|
|
||||||
self._total_steps += 1
|
self._total_steps += 1
|
||||||
if self.simulator:
|
if self.simulator:
|
||||||
|
try:
|
||||||
self.simulator.update(self.seconds_per_step)
|
self.simulator.update(self.seconds_per_step)
|
||||||
|
except Exception as e:
|
||||||
|
if HighUncertaintyError and isinstance(e, HighUncertaintyError):
|
||||||
|
truncated = True
|
||||||
|
else:
|
||||||
|
raise
|
||||||
else:
|
else:
|
||||||
sim_speed = self.nucon.GAME_SIM_SPEED.value or 1.0
|
sim_speed = self.nucon.GAME_SIM_SPEED.value or 1.0
|
||||||
time.sleep(self.seconds_per_step / sim_speed)
|
time.sleep(self.seconds_per_step / sim_speed)
|
||||||
|
|||||||
20
nucon/sim.py
20
nucon/sim.py
@ -8,6 +8,18 @@ import torch
|
|||||||
from nucon.model import ReactorDynamicsModel, ReactorKNNModel
|
from nucon.model import ReactorDynamicsModel, ReactorKNNModel
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
|
class HighUncertaintyError(Exception):
    """Raised by NuconSimulator when the dynamics model uncertainty exceeds the threshold.

    Caught by NuconEnv/NuconGoalEnv and returned as truncated=True so the RL
    algorithm bootstraps the value rather than treating it as a terminal state.

    Attributes:
        uncertainty: The uncertainty value that triggered the error.
        threshold: The configured limit that ``uncertainty`` exceeded.
    """

    def __init__(self, uncertainty: float, threshold: float):
        """Store both values and build the human-readable message.

        Args:
            uncertainty: Uncertainty reported by the dynamics model.
            threshold: Maximum uncertainty the simulator tolerates.
        """
        self.uncertainty = uncertainty
        self.threshold = threshold
        super().__init__(f"Model uncertainty {uncertainty:.3f} exceeded threshold {threshold:.3f}")

    def __reduce__(self):
        """Rebuild the exception from its two constructor arguments.

        The default exception reduction replays ``self.args``, which here is
        only the formatted message; unpickling or copying would then call
        ``HighUncertaintyError(message)`` and fail with a TypeError because
        ``threshold`` is missing. Returning the real arguments keeps the
        error usable when it crosses a process boundary.
        """
        return (type(self), (self.uncertainty, self.threshold))
|
||||||
|
|
||||||
|
|
||||||
class OperatingState(Enum):
|
class OperatingState(Enum):
|
||||||
# Tuple indicates a range of values, while list indicates a set of possible values
|
# Tuple indicates a range of values, while list indicates a set of possible values
|
||||||
OFFLINE = {
|
OFFLINE = {
|
||||||
@ -163,13 +175,14 @@ class NuconSimulator:
|
|||||||
for param_name in nucon.get_all_readable():
|
for param_name in nucon.get_all_readable():
|
||||||
setattr(self, param_name, None)
|
setattr(self, param_name, None)
|
||||||
|
|
||||||
def __init__(self, host: str = 'localhost', port: int = 8786):
|
def __init__(self, host: str = 'localhost', port: int = 8786, uncertainty_threshold: float = None):
|
||||||
self._nucon = Nucon()
|
self._nucon = Nucon()
|
||||||
self.parameters = self.Parameters(self._nucon)
|
self.parameters = self.Parameters(self._nucon)
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
self.time = 0.0
|
self.time = 0.0
|
||||||
self.allow_all_writes = False
|
self.allow_all_writes = False
|
||||||
|
self.uncertainty_threshold = uncertainty_threshold
|
||||||
self.set_state(OperatingState.OFFLINE)
|
self.set_state(OperatingState.OFFLINE)
|
||||||
self.model = None
|
self.model = None
|
||||||
self.readable_params = list(self._nucon.get_all_readable().keys())
|
self.readable_params = list(self._nucon.get_all_readable().keys())
|
||||||
@ -267,6 +280,11 @@ class NuconSimulator:
|
|||||||
if isinstance(self.model, ReactorDynamicsModel):
|
if isinstance(self.model, ReactorDynamicsModel):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
next_state = self.model.forward(state, time_step)
|
next_state = self.model.forward(state, time_step)
|
||||||
|
else:
|
||||||
|
if self.uncertainty_threshold is not None:
|
||||||
|
next_state, uncertainty = self.model.forward_with_uncertainty(state, time_step)
|
||||||
|
if uncertainty > self.uncertainty_threshold:
|
||||||
|
raise HighUncertaintyError(uncertainty, self.uncertainty_threshold)
|
||||||
else:
|
else:
|
||||||
next_state = self.model.forward(state, time_step)
|
next_state = self.model.forward(state, time_step)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user