# Patch summary (free-form preamble placed before the first "diff --git" header).
#
# nucon/sim.py
#   - Adds HighUncertaintyError(Exception); it stores `uncertainty` and `threshold`
#     on the instance and formats both into the exception message.
#   - NuconSimulator.__init__ gains an `uncertainty_threshold` parameter (default
#     None) and stores it as `self.uncertainty_threshold`.
#   - In the `else:` branch that follows the `with torch.no_grad():` block, a
#     non-None threshold switches the call from `self.model.forward(...)` to
#     `self.model.forward_with_uncertainty(...)`, and HighUncertaintyError is
#     raised when the returned uncertainty is greater than the threshold.
#
# nucon/rl.py
#   - Imports HighUncertaintyError from nucon.sim, binding the name to None when
#     the import raises ImportError.
#   - NuconEnv and NuconGoalEnv wrap `self.simulator.update(self.seconds_per_step)`
#     in try/except: a HighUncertaintyError assigns `truncated = True`; every other
#     exception is re-raised.
#
# NOTE(review): the `with torch.no_grad():` branch is unchanged and performs no
#   uncertainty check - confirm that this is intended.
# NOTE(review): `truncated = True` is assigned inside the except handler; the rest
#   of each step method lies outside these hunks - verify that `truncated` is not
#   reassigned before it is returned.
# NOTE(review): because the import fallback is None, the handlers catch a broad
#   `Exception` and then test with isinstance; a fallback exception class would
#   presumably allow a plain `except HighUncertaintyError:` - TODO confirm.
# NOTE(review): `uncertainty_threshold: float = None` annotates a None default with
#   a non-optional type.
# NOTE(review): this copy of the patch has its line breaks and indentation
#   collapsed; the hunks below are kept exactly as received.
diff --git a/nucon/rl.py b/nucon/rl.py index a44f4f4..3485e89 100644 --- a/nucon/rl.py +++ b/nucon/rl.py @@ -5,6 +5,10 @@ import time from typing import Dict, Any from enum import Enum from nucon import Nucon, BreakerStatus, PumpStatus, PumpDryStatus, PumpOverloadStatus +try: + from nucon.sim import HighUncertaintyError +except ImportError: + HighUncertaintyError = None Objectives = { "null": lambda obs: 0, @@ -127,7 +131,13 @@ class NuconEnv(gym.Env): self._total_steps += 1 if self.simulator: - self.simulator.update(self.seconds_per_step) + try: + self.simulator.update(self.seconds_per_step) + except Exception as e: + if HighUncertaintyError and isinstance(e, HighUncertaintyError): + truncated = True + else: + raise else: # Sleep to let the game advance seconds_per_step game-seconds, # accounting for the game's simulation speed multiplier. @@ -350,7 +360,13 @@ class NuconGoalEnv(gym.Env): self._total_steps += 1 if self.simulator: - self.simulator.update(self.seconds_per_step) + try: + self.simulator.update(self.seconds_per_step) + except Exception as e: + if HighUncertaintyError and isinstance(e, HighUncertaintyError): + truncated = True + else: + raise else: sim_speed = self.nucon.GAME_SIM_SPEED.value or 1.0 time.sleep(self.seconds_per_step / sim_speed) diff --git a/nucon/sim.py b/nucon/sim.py index 784b94b..1fcc8cf 100644 --- a/nucon/sim.py +++ b/nucon/sim.py @@ -8,6 +8,18 @@ import torch from nucon.model import ReactorDynamicsModel, ReactorKNNModel import pickle +class HighUncertaintyError(Exception): + """Raised by NuconSimulator when the dynamics model uncertainty exceeds the threshold. + + Caught by NuconEnv/NuconGoalEnv and returned as truncated=True so the RL + algorithm bootstraps the value rather than treating it as a terminal state. 
+ """ + def __init__(self, uncertainty: float, threshold: float): + self.uncertainty = uncertainty + self.threshold = threshold + super().__init__(f"Model uncertainty {uncertainty:.3f} exceeded threshold {threshold:.3f}") + + class OperatingState(Enum): # Tuple indicates a range of values, while list indicates a set of possible values OFFLINE = { @@ -163,13 +175,14 @@ class NuconSimulator: for param_name in nucon.get_all_readable(): setattr(self, param_name, None) - def __init__(self, host: str = 'localhost', port: int = 8786): + def __init__(self, host: str = 'localhost', port: int = 8786, uncertainty_threshold: float = None): self._nucon = Nucon() self.parameters = self.Parameters(self._nucon) self.host = host self.port = port self.time = 0.0 self.allow_all_writes = False + self.uncertainty_threshold = uncertainty_threshold self.set_state(OperatingState.OFFLINE) self.model = None self.readable_params = list(self._nucon.get_all_readable().keys()) @@ -268,7 +281,12 @@ class NuconSimulator: with torch.no_grad(): next_state = self.model.forward(state, time_step) else: - next_state = self.model.forward(state, time_step) + if self.uncertainty_threshold is not None: + next_state, uncertainty = self.model.forward_with_uncertainty(state, time_step) + if uncertainty > self.uncertainty_threshold: + raise HighUncertaintyError(uncertainty, self.uncertainty_threshold) + else: + next_state = self.model.forward(state, time_step) # Update only the output params the model predicts for param_id, value in next_state.items():