RL oh yeah

This commit is contained in:
Dominik Moritz Roth 2024-10-02 18:45:06 +02:00
parent 1f3a71ba96
commit 08828e2dec

143
nucon/rl.py Normal file
View File

@ -0,0 +1,143 @@
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import time
from typing import Dict, Any
from .core import Nucon, BreakerStatus, PumpStatus, PumpDryStatus, PumpOverloadStatus
Objectives = {
"null": lambda obs: 0,
"coeff": lambda obj, coeff: lambda obs: obj(obs) * coeff,
"max_power": lambda obs: obs["GENERATOR_0_KW"] + obs["GENERATOR_1_KW"] + obs["GENERATOR_2_KW"],
"target_temperature": lambda goal_temp: lambda obs: (obs["CORE_TEMP"] - goal_temp) ** 2,
"episode_time": lambda obs: obs["EPISODE_TIME"],
}
class NuconEnv(gym.Env):
metadata = {'render_modes': ['human']}
def __init__(self, render_mode=None, seconds_per_step=5, objectives=['null'], terminators=['null'], terminate_above=0):
super().__init__()
self.render_mode = render_mode
self.seconds_per_step = seconds_per_step
self.terminate_at = terminate_at
# Define observation space
obs_spaces = {'EPISODE_TIME': spaces.Box(low=0, high=np.inf, shape=(1,), dtype=np.float32)}
for param in Nucon.get_all_readable():
if param.param_type == float:
obs_spaces[param.id] = spaces.Box(low=param.min_val or -np.inf, high=param.max_val or np.inf, shape=(1,), dtype=np.float32)
elif param.param_type == int:
if param.min_val is not None and param.max_val is not None:
obs_spaces[param.id] = spaces.Box(low=param.min_val, high=param.max_val, shape=(1,), dtype=np.float32)
else:
obs_spaces[param.id] = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32)
elif param.param_type == bool:
obs_spaces[param.id] = spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32)
elif issubclass(param.param_type, Enum):
obs_spaces[param.id] = spaces.Box(low=0, high=1, shape=(len(param.param_type),), dtype=np.float32)
else:
raise ValueError(f"Unsupported observation parameter type: {param.param_type}")
self.observation_space = spaces.Dict(obs_spaces)
# Define action space
action_spaces = {}
for param in Nucon.get_all_writable():
if param.param_type == float:
action_spaces[param.id] = spaces.Box(low=param.min_val or -np.inf, high=param.max_val or np.inf, shape=(1,), dtype=np.float32)
elif param.param_type == int:
if param.min_val is not None and param.max_val is not None:
action_spaces[param.id] = spaces.Box(low=param.min_val, high=param.max_val, shape=(1,), dtype=np.float32)
else:
action_spaces[param.id] = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32)
elif param.param_type == bool:
action_spaces[param.id] = spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32)
elif issubclass(param.param_type, Enum):
action_spaces[param.id] = spaces.Box(low=0, high=1, shape=(len(param.param_type),), dtype=np.float32)
else:
raise ValueError(f"Unsupported action parameter type: {param.param_type}")
self.action_space = spaces.Dict(action_spaces)
for objective in objectives:
if objective in Objectives:
self.objectives.append(Objectives[objective])
elif callable(objective):
self.objectives.append(objective)
else:
raise ValueError(f"Unsupported objective: {objective}")
for terminator in terminators:
if terminator in Objectives:
self.terminators.append(Objectives[terminator])
elif callable(terminator):
self.terminators.append(terminator)
else:
raise ValueError(f"Unsupported terminator: {terminator}")
def _get_obs(self):
obs = {}
for param in Nucon.get_all_readable():
value = Nucon.get(param)
if isinstance(value, Enum):
value = value.value
obs[param.id] = value
obs["EPISODE_TIME"] = self._total_steps * self.seconds_per_step
return obs
def _get_info(self):
info = {'objectives': {}}
for objective in self.objectives:
info['objectives'][objective.__name__] = objective(self._get_obs())
return info
def reset(self, seed=None, options=None):
super().reset(seed=seed)
self._total_steps = 0
observation = self._get_obs()
info = self._get_info()
return observation, info
def step(self, action):
# Apply the action to the Nucon system
for param_id, value in action.items():
param = next(p for p in Nucon if p.id == param_id)
if issubclass(param.param_type, Enum):
value = param.param_type(value)
if param.min_val is not None and param.max_val is not None:
value = np.clip(value, param.min_val, param.max_val)
Nucon.set(param, value)
observation = self._get_obs()
terminated = np.sum([terminator(observation) for terminator in self.terminators]) > self.terminate_above
truncated = False
info = self._get_info()
reward = sum(obj for obj in info['objectives'].values())
self._total_steps += 1
time.sleep(self.seconds_per_step)
return observation, reward, terminated, truncated, info
def render(self):
if self.render_mode == "human":
pass
def close(self):
pass
def _flatten_action(self, action):
return np.concatenate([v.flatten() for v in action.values()])
def _unflatten_action(self, flat_action):
return {k: v.reshape(1, -1) for k, v in self.action_space.items()}
def _flatten_observation(self, observation):
return np.concatenate([v.flatten() for v in observation.values()])
def _unflatten_observation(self, flat_observation):
return {k: v.reshape(1, -1) for k, v in self.observation_space.items()}