NuCon/nucon/rl.py
2024-10-02 18:45:06 +02:00

143 lines
6.0 KiB
Python

import time
from enum import Enum
from typing import Dict, Any

import gymnasium as gym
import numpy as np
from gymnasium import spaces

from .core import Nucon, BreakerStatus, PumpStatus, PumpDryStatus, PumpOverloadStatus
# Registry of named callables that NuconEnv looks up when an objective or
# terminator is given as a string.
#
# Two kinds of entries live side by side:
#   * scorers   -- take an observation dict and return a number
#                  ("null", "max_power", "episode_time");
#   * factories -- take configuration and return a scorer
#                  ("coeff", "target_temperature"). These must be called
#                  first and the resulting scorer passed on as a callable.
#
# The entries are deliberately kept as lambdas: their `__name__` is
# observable by code that labels objectives by function name.
Objectives = {
    # Constant zero; contributes nothing to a sum of objectives.
    "null": lambda obs: 0,

    # Factory: scale the result of scorer `obj` by the factor `coeff`.
    "coeff": lambda obj, coeff: (
        lambda obs: obj(obs) * coeff
    ),

    # Sum of the readings of the three generator entries.
    "max_power": lambda obs: (
        obs["GENERATOR_0_KW"]
        + obs["GENERATOR_1_KW"]
        + obs["GENERATOR_2_KW"]
    ),

    # Factory: squared distance between CORE_TEMP and `goal_temp`.
    # NOTE(review): the result is non-negative and grows with the error, so
    # it behaves as a cost; confirm the intended sign before using it as a
    # reward to be maximised.
    "target_temperature": lambda goal_temp: (
        lambda obs: (obs["CORE_TEMP"] - goal_temp) ** 2
    ),

    # Elapsed episode time as reported in the observation.
    "episode_time": lambda obs: obs["EPISODE_TIME"],
}
class NuconEnv(gym.Env):
    """Gymnasium environment that reads from and writes to Nucon.

    Observations are a dict holding one entry per parameter returned by
    ``Nucon.get_all_readable()`` plus a synthetic ``EPISODE_TIME`` entry
    (steps taken so far times ``seconds_per_step``).  Actions are a dict
    mapping parameter ids to the values that should be written with
    ``Nucon.set``.  The reward of a step is the sum of all objective values;
    the episode terminates when the sum of all terminator values is strictly
    greater than ``terminate_above``.

    NOTE(review): ``_get_obs`` returns the raw values handed back by
    ``Nucon.get`` (enums reduced to ``.value``), not float32 arrays shaped
    like the declared ``observation_space``.  The built-in objectives rely on
    the raw scalars, so this is left unchanged -- confirm before plugging the
    environment into a library that validates observations against the space.
    """

    metadata = {'render_modes': ['human']}

    def __init__(self, render_mode=None, seconds_per_step=5, objectives=('null',),
                 terminators=('null',), terminate_above=0):
        """Build the observation/action spaces and resolve objectives.

        Args:
            render_mode: Stored on the instance; ``render`` currently does
                nothing for any mode.
            seconds_per_step: Wall-clock seconds ``step`` sleeps after
                applying an action; also the time increment per step used for
                ``EPISODE_TIME``.
            objectives: Iterable of objective specs.  Each spec is either a
                key of the module-level ``Objectives`` registry or a callable
                taking the observation dict.  A single string or callable is
                accepted as shorthand for a one-element iterable.
            terminators: Same format as ``objectives``; their summed value is
                compared against ``terminate_above``.
            terminate_above: Threshold for the summed terminator values.

        Raises:
            ValueError: If a parameter has an unsupported type, or an
                objective/terminator spec is neither a registry key nor
                callable.
        """
        super().__init__()
        self.render_mode = render_mode
        self.seconds_per_step = seconds_per_step
        self.terminate_above = terminate_above
        # Initialised here so observations can be built before reset().
        self._total_steps = 0

        # Define observation space
        obs_spaces = {'EPISODE_TIME': spaces.Box(low=0, high=np.inf, shape=(1,), dtype=np.float32)}
        for param in Nucon.get_all_readable():
            obs_spaces[param.id] = self._space_for_param(param, 'observation')
        self.observation_space = spaces.Dict(obs_spaces)

        # Define action space
        action_spaces = {
            param.id: self._space_for_param(param, 'action')
            for param in Nucon.get_all_writable()
        }
        self.action_space = spaces.Dict(action_spaces)

        self.objectives = self._resolve_callables(objectives, 'objective')
        self.terminators = self._resolve_callables(terminators, 'terminator')

    @staticmethod
    def _space_for_param(param, role):
        """Return the ``spaces.Box`` describing a single Nucon parameter.

        Args:
            param: Object exposing ``param_type``, ``min_val`` and ``max_val``.
            role: ``'observation'`` or ``'action'``; only used in the error
                message.

        Raises:
            ValueError: If ``param.param_type`` is not float, int, bool or an
                ``Enum`` subclass.
        """
        param_type = param.param_type
        if param_type == float:
            # Compare against None explicitly: a bound of 0 is a valid bound
            # and must not be replaced by +/- infinity.
            low = -np.inf if param.min_val is None else param.min_val
            high = np.inf if param.max_val is None else param.max_val
            return spaces.Box(low=low, high=high, shape=(1,), dtype=np.float32)
        if param_type == int:
            # Integers are only bounded when both limits are known.
            if param.min_val is not None and param.max_val is not None:
                return spaces.Box(low=param.min_val, high=param.max_val, shape=(1,), dtype=np.float32)
            return spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32)
        if param_type == bool:
            return spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32)
        # Guard with isinstance so a non-class param_type reaches the
        # ValueError below instead of making issubclass raise TypeError.
        if isinstance(param_type, type) and issubclass(param_type, Enum):
            # One slot per enum member.
            return spaces.Box(low=0, high=1, shape=(len(param_type),), dtype=np.float32)
        raise ValueError(f"Unsupported {role} parameter type: {param_type}")

    @staticmethod
    def _resolve_callables(specs, kind):
        """Turn objective/terminator specs into a list of callables.

        Args:
            specs: Iterable of registry keys and/or callables, or a single
                key/callable.
            kind: ``'objective'`` or ``'terminator'``; only used in the error
                message.

        Raises:
            ValueError: If a spec is neither a key of ``Objectives`` nor
                callable.
        """
        if isinstance(specs, str) or callable(specs):
            specs = [specs]
        resolved = []
        for spec in specs:
            # Only strings are looked up, so unhashable callables cannot
            # break the membership test.
            # NOTE: the "coeff" and "target_temperature" registry entries are
            # factories; call them yourself and pass the resulting scorer.
            if isinstance(spec, str) and spec in Objectives:
                resolved.append(Objectives[spec])
            elif callable(spec):
                resolved.append(spec)
            else:
                raise ValueError(f"Unsupported {kind}: {spec}")
        return resolved

    @staticmethod
    def _objective_label(objective):
        """Return a readable name for an objective callable.

        Registry entries are reported under their registry key; any other
        callable under its ``__name__`` (or its type name if it has none).
        """
        for name, registered in Objectives.items():
            if registered is objective:
                return name
        return getattr(objective, '__name__', type(objective).__name__)

    def _get_obs(self):
        """Read every readable parameter from Nucon into a dict.

        Enum values are replaced by their ``.value``; ``EPISODE_TIME`` is
        added as steps taken so far times ``seconds_per_step``.
        """
        obs = {}
        for param in Nucon.get_all_readable():
            value = Nucon.get(param)
            if isinstance(value, Enum):
                value = value.value
            obs[param.id] = value
        obs["EPISODE_TIME"] = self._total_steps * self.seconds_per_step
        return obs

    def _get_info(self, observation=None):
        """Evaluate every objective and return them under ``'objectives'``.

        Args:
            observation: Observation dict to score.  When omitted a fresh one
                is read once via ``_get_obs`` and shared by all objectives.

        Returns:
            ``{'objectives': {label: value, ...}}`` with one entry per
            objective.
        """
        if observation is None:
            observation = self._get_obs()
        values = {}
        for index, objective in enumerate(self.objectives):
            label = self._objective_label(objective)
            # Several lambdas all report "<lambda>"; keep every objective by
            # suffixing later duplicates with their position.
            if label in values:
                label = f"{label}_{index}"
            values[label] = objective(observation)
        return {'objectives': values}

    def reset(self, seed=None, options=None):
        """Reset the step counter and return the current observation/info.

        Only the environment's own bookkeeping is reset here; nothing is
        written to Nucon.
        """
        super().reset(seed=seed)
        self._total_steps = 0
        observation = self._get_obs()
        info = self._get_info(observation)
        return observation, info

    def step(self, action):
        """Apply ``action``, observe, score, then wait ``seconds_per_step``.

        Args:
            action: Dict mapping parameter ids to the value to write.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
            ``truncated`` is always ``False``.

        Raises:
            ValueError: If an action key matches no Nucon parameter id.
        """
        # Apply the action to the Nucon system
        for param_id, value in action.items():
            param = next((p for p in Nucon if p.id == param_id), None)
            if param is None:
                raise ValueError(f"Unknown action parameter: {param_id}")
            if isinstance(param.param_type, type) and issubclass(param.param_type, Enum):
                value = param.param_type(value)
            if param.min_val is not None and param.max_val is not None:
                value = np.clip(value, param.min_val, param.max_val)
            Nucon.set(param, value)

        # The observation is taken right after writing, before the sleep.
        observation = self._get_obs()
        terminator_total = np.sum([terminator(observation) for terminator in self.terminators])
        terminated = bool(terminator_total > self.terminate_above)
        truncated = False
        # Score the very observation that is returned to the caller.
        info = self._get_info(observation)
        reward = sum(info['objectives'].values())

        self._total_steps += 1
        time.sleep(self.seconds_per_step)
        return observation, reward, terminated, truncated, info

    def render(self):
        """Placeholder: no rendering is implemented for any mode."""
        if self.render_mode == "human":
            pass

    def close(self):
        """Placeholder: there is nothing to release."""
        pass

    @staticmethod
    def _flatten_values(values):
        """Concatenate the flattened values of ``values`` into one 1-D array.

        Values are taken in the dict's own iteration order; scalars count as
        one element.  An empty dict yields an empty float32 array.
        """
        parts = [np.asarray(value).flatten() for value in values.values()]
        if not parts:
            return np.array([], dtype=np.float32)
        return np.concatenate(parts)

    @staticmethod
    def _unflatten_values(flat, space):
        """Split a flat vector into a dict shaped like ``space``.

        Consecutive slices of ``flat`` are assigned to the sub-spaces in the
        iteration order of ``space`` and reshaped to each sub-space's shape.

        Raises:
            ValueError: If ``flat`` does not hold exactly as many elements as
                the sub-spaces require.
        """
        flat = np.asarray(flat).reshape(-1)
        sizes = {key: int(np.prod(subspace.shape)) for key, subspace in space.items()}
        expected = sum(sizes.values())
        if flat.size != expected:
            raise ValueError(f"Expected a flat vector of {expected} elements, got {flat.size}")
        result = {}
        offset = 0
        for key, subspace in space.items():
            size = sizes[key]
            result[key] = flat[offset:offset + size].reshape(subspace.shape)
            offset += size
        return result

    def _flatten_action(self, action):
        """Flatten an action dict into a 1-D array (dict iteration order).

        To round-trip through ``_unflatten_action`` the dict must be ordered
        like ``self.action_space``.
        """
        return self._flatten_values(action)

    def _unflatten_action(self, flat_action):
        """Rebuild an action dict from a flat vector laid out in the key
        order of ``self.action_space``."""
        return self._unflatten_values(flat_action, self.action_space)

    def _flatten_observation(self, observation):
        """Flatten an observation dict into a 1-D array (dict iteration order).

        To round-trip through ``_unflatten_observation`` the dict must be
        ordered and shaped like ``self.observation_space``.
        """
        return self._flatten_values(observation)

    def _unflatten_observation(self, flat_observation):
        """Rebuild an observation dict from a flat vector laid out in the key
        order of ``self.observation_space``."""
        return self._unflatten_values(flat_observation, self.observation_space)