refactor: move UncertaintyPenalty/Abort into Parameterized_Objectives/Terminators dicts
- uncertainty_penalty -> Parameterized_Objectives['uncertainty_penalty']
- uncertainty_abort -> Parameterized_Terminators['uncertainty_abort']
- Add Parameterized_Terminators dict (same pattern as Parameterized_Objectives)
- Keep UncertaintyPenalty / UncertaintyAbort as convenience aliases

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
041e0ec1bd
commit
2c1bbc1a31
65
nucon/rl.py
65
nucon/rl.py
@@ -18,35 +18,7 @@ Objectives = {
|
||||
"episode_time": lambda obs: obs["EPISODE_TIME"],
|
||||
}
|
||||
|
||||
Parameterized_Objectives = {
|
||||
"target_temperature": lambda goal_temp: lambda obs: -((obs["CORE_TEMP"] - goal_temp) ** 2),
|
||||
"target_gap": lambda goal_gap: lambda obs: -((obs["CORE_TEMP"] - obs["CORE_TEMP_MIN"] - goal_gap) ** 2),
|
||||
"temp_below": lambda max_temp: lambda obs: -(np.clip(obs["CORE_TEMP"] - max_temp, 0, np.inf) ** 2),
|
||||
"temp_above": lambda min_temp: lambda obs: -(np.clip(min_temp - obs["CORE_TEMP"], 0, np.inf) ** 2),
|
||||
"constant": lambda constant: lambda obs: constant,
|
||||
}
|
||||
|
||||
|
||||
def UncertaintyPenalty(start: float = 0.3, scale: float = 1.0, mode: str = 'l2') -> Callable:
|
||||
"""Objective that penalises high simulator uncertainty.
|
||||
|
||||
Returns a callable ``(obs) -> float`` suitable for use as an objective or
|
||||
terminator in NuconEnv / NuconGoalEnv. Works because ``SIM_UNCERTAINTY``
|
||||
is injected into the obs dict whenever a simulator is active.
|
||||
|
||||
Args:
|
||||
start: uncertainty level at which the penalty starts (default 0.3).
|
||||
scale: penalty coefficient.
|
||||
mode: ``'l2'`` (quadratic, default) or ``'linear'``.
|
||||
|
||||
Example::
|
||||
|
||||
env = NuconEnv(
|
||||
objectives=['max_power', UncertaintyPenalty(start=0.3, scale=2.0)],
|
||||
objective_weights=[1.0, 1.0],
|
||||
simulator=simulator,
|
||||
)
|
||||
"""
|
||||
def _uncertainty_penalty(start=0.3, scale=1.0, mode='l2'):
|
||||
excess = lambda obs: max(0.0, obs.get('SIM_UNCERTAINTY', 0.0) - start)
|
||||
if mode == 'l2':
|
||||
return lambda obs: -scale * excess(obs) ** 2
|
||||
@@ -55,25 +27,26 @@ def UncertaintyPenalty(start: float = 0.3, scale: float = 1.0, mode: str = 'l2')
|
||||
else:
|
||||
raise ValueError(f"Unknown mode '{mode}'. Use 'l2' or 'linear'.")
|
||||
|
||||
|
||||
def UncertaintyAbort(threshold: float = 0.7) -> Callable:
|
||||
"""Terminator that aborts the episode when simulator uncertainty is too high.
|
||||
|
||||
Returns a callable ``(obs) -> float`` for use as a *terminator*. When
|
||||
the GP posterior std reaches or exceeds ``threshold`` the episode is terminated
|
||||
(``terminated=True``).
|
||||
|
||||
Example::
|
||||
|
||||
env = NuconEnv(
|
||||
objectives=['max_power'],
|
||||
terminators=[UncertaintyAbort(threshold=0.7)],
|
||||
terminate_above=0,
|
||||
simulator=simulator,
|
||||
)
|
||||
"""
|
||||
def _uncertainty_abort(threshold=0.7):
|
||||
return lambda obs: 1.0 if obs.get('SIM_UNCERTAINTY', 0.0) >= threshold else 0.0
|
||||
|
||||
Parameterized_Objectives = {
|
||||
"target_temperature": lambda goal_temp: lambda obs: -((obs["CORE_TEMP"] - goal_temp) ** 2),
|
||||
"target_gap": lambda goal_gap: lambda obs: -((obs["CORE_TEMP"] - obs["CORE_TEMP_MIN"] - goal_gap) ** 2),
|
||||
"temp_below": lambda max_temp: lambda obs: -(np.clip(obs["CORE_TEMP"] - max_temp, 0, np.inf) ** 2),
|
||||
"temp_above": lambda min_temp: lambda obs: -(np.clip(min_temp - obs["CORE_TEMP"], 0, np.inf) ** 2),
|
||||
"constant": lambda constant: lambda obs: constant,
|
||||
"uncertainty_penalty": _uncertainty_penalty, # (start, scale, mode) -> (obs) -> float
|
||||
}
|
||||
|
||||
Parameterized_Terminators = {
|
||||
"uncertainty_abort": _uncertainty_abort, # (threshold,) -> (obs) -> float
|
||||
}
|
||||
|
||||
# Convenience aliases
|
||||
UncertaintyPenalty = _uncertainty_penalty
|
||||
UncertaintyAbort = _uncertainty_abort
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
|
||||
Loading…
Reference in New Issue
Block a user