feat: uncertainty-aware training with penalty and abort
sim.py:
- simulator.update(return_uncertainty=True) calls forward_with_uncertainty
on kNN models and returns the GP std; returns None for NN or when not
requested (no extra cost if unused)
- No state stored on simulator; caller decides what to do with the value
rl.py (NuconEnv and NuconGoalEnv):
- uncertainty_penalty_start: above this GP std, subtract a linear penalty
from the reward (scaled by uncertainty_penalty_scale, default 1.0)
- uncertainty_abort: at or above this GP std, set truncated=True
- Only calls update(return_uncertainty=True) when either threshold is set
- Uncertainty only applies when using a simulator (kNN model); ignored otherwise
Example:
simulator = NuconSimulator()
simulator.load_model('reactor_knn.pkl')
env = NuconGoalEnv(..., simulator=simulator,
uncertainty_penalty_start=0.3,
uncertainty_abort=0.7,
uncertainty_penalty_scale=2.0)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
6cb93ad56d
commit
65190dffea
42
nucon/rl.py
42
nucon/rl.py
@ -5,10 +5,6 @@ import time
|
|||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from nucon import Nucon, BreakerStatus, PumpStatus, PumpDryStatus, PumpOverloadStatus
|
from nucon import Nucon, BreakerStatus, PumpStatus, PumpDryStatus, PumpOverloadStatus
|
||||||
try:
|
|
||||||
from nucon.sim import HighUncertaintyError
|
|
||||||
except ImportError:
|
|
||||||
HighUncertaintyError = None
|
|
||||||
|
|
||||||
Objectives = {
|
Objectives = {
|
||||||
"null": lambda obs: 0,
|
"null": lambda obs: 0,
|
||||||
@ -27,7 +23,8 @@ Parameterized_Objectives = {
|
|||||||
class NuconEnv(gym.Env):
|
class NuconEnv(gym.Env):
|
||||||
metadata = {'render_modes': ['human']}
|
metadata = {'render_modes': ['human']}
|
||||||
|
|
||||||
def __init__(self, nucon=None, simulator=None, render_mode=None, seconds_per_step=5, objectives=['null'], terminators=['null'], objective_weights=None, terminate_above=0):
|
def __init__(self, nucon=None, simulator=None, render_mode=None, seconds_per_step=5, objectives=['null'], terminators=['null'], objective_weights=None, terminate_above=0,
|
||||||
|
uncertainty_penalty_start: float = None, uncertainty_abort: float = None, uncertainty_penalty_scale: float = 1.0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.render_mode = render_mode
|
self.render_mode = render_mode
|
||||||
@ -37,6 +34,9 @@ class NuconEnv(gym.Env):
|
|||||||
self.objective_weights = objective_weights
|
self.objective_weights = objective_weights
|
||||||
self.terminate_above = terminate_above
|
self.terminate_above = terminate_above
|
||||||
self.simulator = simulator
|
self.simulator = simulator
|
||||||
|
self.uncertainty_penalty_start = uncertainty_penalty_start
|
||||||
|
self.uncertainty_abort = uncertainty_abort
|
||||||
|
self.uncertainty_penalty_scale = uncertainty_penalty_scale
|
||||||
|
|
||||||
if nucon is None:
|
if nucon is None:
|
||||||
if simulator:
|
if simulator:
|
||||||
@ -131,16 +131,14 @@ class NuconEnv(gym.Env):
|
|||||||
|
|
||||||
self._total_steps += 1
|
self._total_steps += 1
|
||||||
if self.simulator:
|
if self.simulator:
|
||||||
try:
|
needs_uncertainty = self.uncertainty_penalty_start is not None or self.uncertainty_abort is not None
|
||||||
self.simulator.update(self.seconds_per_step)
|
uncertainty = self.simulator.update(self.seconds_per_step, return_uncertainty=needs_uncertainty)
|
||||||
except Exception as e:
|
if uncertainty is not None:
|
||||||
if HighUncertaintyError and isinstance(e, HighUncertaintyError):
|
if self.uncertainty_abort is not None and uncertainty >= self.uncertainty_abort:
|
||||||
truncated = True
|
truncated = True
|
||||||
|
if self.uncertainty_penalty_start is not None and uncertainty > self.uncertainty_penalty_start:
|
||||||
|
reward -= self.uncertainty_penalty_scale * (uncertainty - self.uncertainty_penalty_start)
|
||||||
else:
|
else:
|
||||||
raise
|
|
||||||
else:
|
|
||||||
# Sleep to let the game advance seconds_per_step game-seconds,
|
|
||||||
# accounting for the game's simulation speed multiplier.
|
|
||||||
sim_speed = self.nucon.GAME_SIM_SPEED.value or 1.0
|
sim_speed = self.nucon.GAME_SIM_SPEED.value or 1.0
|
||||||
time.sleep(self.seconds_per_step / sim_speed)
|
time.sleep(self.seconds_per_step / sim_speed)
|
||||||
return observation, reward, terminated, truncated, info
|
return observation, reward, terminated, truncated, info
|
||||||
@ -222,6 +220,9 @@ class NuconGoalEnv(gym.Env):
|
|||||||
seconds_per_step=5,
|
seconds_per_step=5,
|
||||||
terminators=None,
|
terminators=None,
|
||||||
terminate_above=0,
|
terminate_above=0,
|
||||||
|
uncertainty_penalty_start: float = None,
|
||||||
|
uncertainty_abort: float = None,
|
||||||
|
uncertainty_penalty_scale: float = 1.0,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -286,6 +287,9 @@ class NuconGoalEnv(gym.Env):
|
|||||||
|
|
||||||
# Terminators
|
# Terminators
|
||||||
self._terminators = terminators or []
|
self._terminators = terminators or []
|
||||||
|
self.uncertainty_penalty_start = uncertainty_penalty_start
|
||||||
|
self.uncertainty_abort = uncertainty_abort
|
||||||
|
self.uncertainty_penalty_scale = uncertainty_penalty_scale
|
||||||
|
|
||||||
self._desired_goal = np.zeros(n_goals, dtype=np.float32)
|
self._desired_goal = np.zeros(n_goals, dtype=np.float32)
|
||||||
self._total_steps = 0
|
self._total_steps = 0
|
||||||
@ -360,13 +364,13 @@ class NuconGoalEnv(gym.Env):
|
|||||||
|
|
||||||
self._total_steps += 1
|
self._total_steps += 1
|
||||||
if self.simulator:
|
if self.simulator:
|
||||||
try:
|
needs_uncertainty = self.uncertainty_penalty_start is not None or self.uncertainty_abort is not None
|
||||||
self.simulator.update(self.seconds_per_step)
|
uncertainty = self.simulator.update(self.seconds_per_step, return_uncertainty=needs_uncertainty)
|
||||||
except Exception as e:
|
if uncertainty is not None:
|
||||||
if HighUncertaintyError and isinstance(e, HighUncertaintyError):
|
if self.uncertainty_abort is not None and uncertainty >= self.uncertainty_abort:
|
||||||
truncated = True
|
truncated = True
|
||||||
else:
|
if self.uncertainty_penalty_start is not None and uncertainty > self.uncertainty_penalty_start:
|
||||||
raise
|
reward -= self.uncertainty_penalty_scale * (uncertainty - self.uncertainty_penalty_start)
|
||||||
else:
|
else:
|
||||||
sim_speed = self.nucon.GAME_SIM_SPEED.value or 1.0
|
sim_speed = self.nucon.GAME_SIM_SPEED.value or 1.0
|
||||||
time.sleep(self.seconds_per_step / sim_speed)
|
time.sleep(self.seconds_per_step / sim_speed)
|
||||||
|
|||||||
38
nucon/sim.py
38
nucon/sim.py
@ -8,18 +8,6 @@ import torch
|
|||||||
from nucon.model import ReactorDynamicsModel, ReactorKNNModel
|
from nucon.model import ReactorDynamicsModel, ReactorKNNModel
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
class HighUncertaintyError(Exception):
|
|
||||||
"""Raised by NuconSimulator when the dynamics model uncertainty exceeds the threshold.
|
|
||||||
|
|
||||||
Caught by NuconEnv/NuconGoalEnv and returned as truncated=True so the RL
|
|
||||||
algorithm bootstraps the value rather than treating it as a terminal state.
|
|
||||||
"""
|
|
||||||
def __init__(self, uncertainty: float, threshold: float):
|
|
||||||
self.uncertainty = uncertainty
|
|
||||||
self.threshold = threshold
|
|
||||||
super().__init__(f"Model uncertainty {uncertainty:.3f} exceeded threshold {threshold:.3f}")
|
|
||||||
|
|
||||||
|
|
||||||
class OperatingState(Enum):
|
class OperatingState(Enum):
|
||||||
# Tuple indicates a range of values, while list indicates a set of possible values
|
# Tuple indicates a range of values, while list indicates a set of possible values
|
||||||
OFFLINE = {
|
OFFLINE = {
|
||||||
@ -175,14 +163,13 @@ class NuconSimulator:
|
|||||||
for param_name in nucon.get_all_readable():
|
for param_name in nucon.get_all_readable():
|
||||||
setattr(self, param_name, None)
|
setattr(self, param_name, None)
|
||||||
|
|
||||||
def __init__(self, host: str = 'localhost', port: int = 8786, uncertainty_threshold: float = None):
|
def __init__(self, host: str = 'localhost', port: int = 8786):
|
||||||
self._nucon = Nucon()
|
self._nucon = Nucon()
|
||||||
self.parameters = self.Parameters(self._nucon)
|
self.parameters = self.Parameters(self._nucon)
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
self.time = 0.0
|
self.time = 0.0
|
||||||
self.allow_all_writes = False
|
self.allow_all_writes = False
|
||||||
self.uncertainty_threshold = uncertainty_threshold
|
|
||||||
self.set_state(OperatingState.OFFLINE)
|
self.set_state(OperatingState.OFFLINE)
|
||||||
self.model = None
|
self.model = None
|
||||||
self.readable_params = list(self._nucon.get_all_readable().keys())
|
self.readable_params = list(self._nucon.get_all_readable().keys())
|
||||||
@ -228,9 +215,16 @@ class NuconSimulator:
|
|||||||
def set_allow_all_writes(self, allow: bool) -> None:
|
def set_allow_all_writes(self, allow: bool) -> None:
|
||||||
self.allow_all_writes = allow
|
self.allow_all_writes = allow
|
||||||
|
|
||||||
def update(self, time_step: float) -> None:
|
def update(self, time_step: float, return_uncertainty: bool = False):
|
||||||
self._update_reactor_state(time_step)
|
"""Advance the simulator by time_step game-seconds.
|
||||||
|
|
||||||
|
If return_uncertainty=True and a kNN model is loaded, returns the GP
|
||||||
|
posterior std for this step (0 = on known data, ~1 = OOD).
|
||||||
|
Always returns None when using an NN model.
|
||||||
|
"""
|
||||||
|
uncertainty = self._update_reactor_state(time_step, return_uncertainty=return_uncertainty)
|
||||||
self.time += time_step
|
self.time += time_step
|
||||||
|
return uncertainty
|
||||||
|
|
||||||
def set_model(self, model) -> None:
|
def set_model(self, model) -> None:
|
||||||
"""Set a pre-loaded ReactorDynamicsModel or ReactorKNNModel directly."""
|
"""Set a pre-loaded ReactorDynamicsModel or ReactorKNNModel directly."""
|
||||||
@ -262,7 +256,7 @@ class NuconSimulator:
|
|||||||
print(f"Error loading model: {str(e)}")
|
print(f"Error loading model: {str(e)}")
|
||||||
self.model = None
|
self.model = None
|
||||||
|
|
||||||
def _update_reactor_state(self, time_step: float) -> None:
|
def _update_reactor_state(self, time_step: float, return_uncertainty: bool = False):
|
||||||
if not self.model:
|
if not self.model:
|
||||||
raise ValueError("Model not set. Please load a model using load_model() or set_model().")
|
raise ValueError("Model not set. Please load a model using load_model() or set_model().")
|
||||||
|
|
||||||
@ -276,15 +270,13 @@ class NuconSimulator:
|
|||||||
value = 0.0 # fallback for params not initialised in sim state
|
value = 0.0 # fallback for params not initialised in sim state
|
||||||
state[param_id] = value
|
state[param_id] = value
|
||||||
|
|
||||||
# Forward pass — same interface for both NN and kNN
|
# Forward pass
|
||||||
|
uncertainty = None
|
||||||
if isinstance(self.model, ReactorDynamicsModel):
|
if isinstance(self.model, ReactorDynamicsModel):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
next_state = self.model.forward(state, time_step)
|
next_state = self.model.forward(state, time_step)
|
||||||
else:
|
elif return_uncertainty:
|
||||||
if self.uncertainty_threshold is not None:
|
|
||||||
next_state, uncertainty = self.model.forward_with_uncertainty(state, time_step)
|
next_state, uncertainty = self.model.forward_with_uncertainty(state, time_step)
|
||||||
if uncertainty > self.uncertainty_threshold:
|
|
||||||
raise HighUncertaintyError(uncertainty, self.uncertainty_threshold)
|
|
||||||
else:
|
else:
|
||||||
next_state = self.model.forward(state, time_step)
|
next_state = self.model.forward(state, time_step)
|
||||||
|
|
||||||
@ -295,6 +287,8 @@ class NuconSimulator:
|
|||||||
except (ValueError, KeyError):
|
except (ValueError, KeyError):
|
||||||
pass # ignore params that can't be set (type mismatch, unknown)
|
pass # ignore params that can't be set (type mismatch, unknown)
|
||||||
|
|
||||||
|
return uncertainty
|
||||||
|
|
||||||
def set_state(self, state: OperatingState) -> None:
|
def set_state(self, state: OperatingState) -> None:
|
||||||
self._sample_parameters_from_state(state)
|
self._sample_parameters_from_state(state)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user