Morer objectives and fixes
This commit is contained in:
parent
33b5db2f57
commit
4c3ad983fc
68
nucon/rl.py
68
nucon/rl.py
@ -3,7 +3,7 @@ from gymnasium import spaces
|
||||
import numpy as np
|
||||
import time
|
||||
from typing import Dict, Any
|
||||
from .core import Nucon, BreakerStatus, PumpStatus, PumpDryStatus, PumpOverloadStatus
|
||||
from nucon import Nucon, BreakerStatus, PumpStatus, PumpDryStatus, PumpOverloadStatus
|
||||
|
||||
Objectives = {
|
||||
"null": lambda obs: 0,
|
||||
@ -13,12 +13,16 @@ Objectives = {
|
||||
|
||||
Parameterized_Objectives = {
|
||||
"target_temperature": lambda goal_temp: lambda obs: -((obs["CORE_TEMP"] - goal_temp) ** 2),
|
||||
"target_gap": lambda goal_gap: lambda obs: -((obs["CORE_TEMP"] - obs["CORE_TEMP_MIN"] - goal_gap) ** 2),
|
||||
"temp_below": lambda max_temp: lambda obs: -(np.clip(obs["CORE_TEMP"] - max_temp, 0, np.inf) ** 2),
|
||||
"temp_above": lambda min_temp: lambda obs: -(np.clip(min_temp - obs["CORE_TEMP"], 0, np.inf) ** 2),
|
||||
"constant": lambda constant: lambda obs: constant,
|
||||
}
|
||||
|
||||
class NuconEnv(gym.Env):
|
||||
metadata = {'render_modes': ['human']}
|
||||
|
||||
def __init__(self, render_mode=None, seconds_per_step=5, objectives=['null'], terminators=['null'], objective_weights=None, terminate_above=0):
|
||||
def __init__(self, nucon=None, simulator=None, render_mode=None, seconds_per_step=5, objectives=['null'], terminators=['null'], objective_weights=None, terminate_above=0):
|
||||
super().__init__()
|
||||
|
||||
self.render_mode = render_mode
|
||||
@ -26,22 +30,30 @@ class NuconEnv(gym.Env):
|
||||
if objective_weights is None:
|
||||
objective_weights = [1.0 for objective in objectives]
|
||||
self.objective_weights = objective_weights
|
||||
self.terminate_at = terminate_at
|
||||
self.terminate_above = terminate_above
|
||||
self.simulator = simulator
|
||||
|
||||
if nucon is None:
|
||||
if simulator:
|
||||
nucon = Nucon(port=simulator.port)
|
||||
else:
|
||||
nucon = Nucon()
|
||||
self.nucon = nucon
|
||||
|
||||
# Define observation space
|
||||
obs_spaces = {'EPISODE_TIME': spaces.Box(low=0, high=np.inf, shape=(1,), dtype=np.float32)}
|
||||
for param in Nucon.get_all_readable():
|
||||
for param_id, param in self.nucon.get_all_readable().items():
|
||||
if param.param_type == float:
|
||||
obs_spaces[param.id] = spaces.Box(low=param.min_val or -np.inf, high=param.max_val or np.inf, shape=(1,), dtype=np.float32)
|
||||
obs_spaces[param_id] = spaces.Box(low=param.min_val or -np.inf, high=param.max_val or np.inf, shape=(1,), dtype=np.float32)
|
||||
elif param.param_type == int:
|
||||
if param.min_val is not None and param.max_val is not None:
|
||||
obs_spaces[param.id] = spaces.Box(low=param.min_val, high=param.max_val, shape=(1,), dtype=np.float32)
|
||||
obs_spaces[param_id] = spaces.Box(low=param.min_val, high=param.max_val, shape=(1,), dtype=np.float32)
|
||||
else:
|
||||
obs_spaces[param.id] = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32)
|
||||
obs_spaces[param_id] = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32)
|
||||
elif param.param_type == bool:
|
||||
obs_spaces[param.id] = spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32)
|
||||
obs_spaces[param_id] = spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32)
|
||||
elif issubclass(param.param_type, Enum):
|
||||
obs_spaces[param.id] = spaces.Box(low=0, high=1, shape=(len(param.param_type),), dtype=np.float32)
|
||||
obs_spaces[param_id] = spaces.Box(low=0, high=1, shape=(len(param.param_type),), dtype=np.float32)
|
||||
else:
|
||||
raise ValueError(f"Unsupported observation parameter type: {param.param_type}")
|
||||
|
||||
@ -49,23 +61,26 @@ class NuconEnv(gym.Env):
|
||||
|
||||
# Define action space
|
||||
action_spaces = {}
|
||||
for param in Nucon.get_all_writable():
|
||||
for param_id, param in self.nucon.get_all_writable().items():
|
||||
if param.param_type == float:
|
||||
action_spaces[param.id] = spaces.Box(low=param.min_val or -np.inf, high=param.max_val or np.inf, shape=(1,), dtype=np.float32)
|
||||
action_spaces[param_id] = spaces.Box(low=param.min_val or -np.inf, high=param.max_val or np.inf, shape=(1,), dtype=np.float32)
|
||||
elif param.param_type == int:
|
||||
if param.min_val is not None and param.max_val is not None:
|
||||
action_spaces[param.id] = spaces.Box(low=param.min_val, high=param.max_val, shape=(1,), dtype=np.float32)
|
||||
action_spaces[param_id] = spaces.Box(low=param.min_val, high=param.max_val, shape=(1,), dtype=np.float32)
|
||||
else:
|
||||
action_spaces[param.id] = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32)
|
||||
action_spaces[param_id] = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32)
|
||||
elif param.param_type == bool:
|
||||
action_spaces[param.id] = spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32)
|
||||
action_spaces[param_id] = spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32)
|
||||
elif issubclass(param.param_type, Enum):
|
||||
action_spaces[param.id] = spaces.Box(low=0, high=1, shape=(len(param.param_type),), dtype=np.float32)
|
||||
action_spaces[param_id] = spaces.Box(low=0, high=1, shape=(len(param.param_type),), dtype=np.float32)
|
||||
else:
|
||||
raise ValueError(f"Unsupported action parameter type: {param.param_type}")
|
||||
|
||||
self.action_space = spaces.Dict(action_spaces)
|
||||
|
||||
self.objectives = []
|
||||
self.terminators = []
|
||||
|
||||
for objective in objectives:
|
||||
if objective in Objectives:
|
||||
self.objectives.append(Objectives[objective])
|
||||
@ -84,11 +99,11 @@ class NuconEnv(gym.Env):
|
||||
|
||||
def _get_obs(self):
|
||||
obs = {}
|
||||
for param in Nucon.get_all_readable():
|
||||
value = Nucon.get(param)
|
||||
for param_id, param in self.nucon.get_all_readable().items():
|
||||
value = self.nucon.get(param_id)
|
||||
if isinstance(value, Enum):
|
||||
value = value.value
|
||||
obs[param.id] = value
|
||||
obs[param_id] = value
|
||||
obs["EPISODE_TIME"] = self._total_steps * self.seconds_per_step
|
||||
return obs
|
||||
|
||||
@ -112,12 +127,12 @@ class NuconEnv(gym.Env):
|
||||
def step(self, action):
|
||||
# Apply the action to the Nucon system
|
||||
for param_id, value in action.items():
|
||||
param = next(p for p in Nucon if p.id == param_id)
|
||||
param = next(p for p in self.nucon if p.id == param_id)
|
||||
if issubclass(param.param_type, Enum):
|
||||
value = param.param_type(value)
|
||||
if param.min_val is not None and param.max_val is not None:
|
||||
value = np.clip(value, param.min_val, param.max_val)
|
||||
Nucon.set(param, value)
|
||||
self.nucon.set(param, value)
|
||||
|
||||
observation = self._get_obs()
|
||||
terminated = np.sum([terminator(observation) for terminator in self.terminators]) > self.terminate_above
|
||||
@ -126,6 +141,9 @@ class NuconEnv(gym.Env):
|
||||
reward = sum(obj for obj in info['objectives_weighted'].values())
|
||||
|
||||
self._total_steps += 1
|
||||
if self.simulator:
|
||||
self.simulator.update(self.seconds_per_step)
|
||||
else:
|
||||
time.sleep(self.seconds_per_step)
|
||||
return observation, reward, terminated, truncated, info
|
||||
|
||||
@ -148,6 +166,7 @@ class NuconEnv(gym.Env):
|
||||
def _unflatten_observation(self, flat_observation):
|
||||
return {k: v.reshape(1, -1) for k, v in self.observation_space.items()}
|
||||
|
||||
|
||||
def register_nucon_envs():
|
||||
gym.register(
|
||||
id='Nucon-max_power-v0',
|
||||
@ -155,9 +174,14 @@ def register_nucon_envs():
|
||||
kwargs={'seconds_per_step': 5, 'objectives': ['max_power']}
|
||||
)
|
||||
gym.register(
|
||||
id='Nucon-target_temperature_600-v0',
|
||||
id='Nucon-target_temperature_350-v0',
|
||||
entry_point='nucon.rl:NuconEnv',
|
||||
kwargs={'seconds_per_step': 5, 'objectives': [Parameterized_Objectives['target_temperature'](goal_temp=600)]}
|
||||
kwargs={'seconds_per_step': 5, 'objectives': [Parameterized_Objectives['target_temperature'](goal_temp=350)]}
|
||||
)
|
||||
gym.register(
|
||||
id='Nucon-safe_max_power-v0',
|
||||
entry_point='nucon.rl:NuconEnv',
|
||||
kwargs={'seconds_per_step': 5, 'objectives': [Parameterized_Objectives['temp_above'](min_temp=310), Parameterized_Objectives['temp_below'](max_temp=365), 'max_power'], 'objective_weights': [1, 10, 1/100_000]}
|
||||
)
|
||||
|
||||
register_nucon_envs()
|
Loading…
Reference in New Issue
Block a user