diff --git a/nucon/rl.py b/nucon/rl.py new file mode 100644 index 0000000..0da1028 --- /dev/null +++ b/nucon/rl.py @@ -0,0 +1,143 @@ +import gymnasium as gym +from gymnasium import spaces +import numpy as np +import time +from typing import Dict, Any +from .core import Nucon, BreakerStatus, PumpStatus, PumpDryStatus, PumpOverloadStatus + +Objectives = { + "null": lambda obs: 0, + "coeff": lambda obj, coeff: lambda obs: obj(obs) * coeff, + + "max_power": lambda obs: obs["GENERATOR_0_KW"] + obs["GENERATOR_1_KW"] + obs["GENERATOR_2_KW"], + "target_temperature": lambda goal_temp: lambda obs: (obs["CORE_TEMP"] - goal_temp) ** 2, + "episode_time": lambda obs: obs["EPISODE_TIME"], +} + +class NuconEnv(gym.Env): + metadata = {'render_modes': ['human']} + + def __init__(self, render_mode=None, seconds_per_step=5, objectives=['null'], terminators=['null'], terminate_above=0): + super().__init__() + + self.render_mode = render_mode + self.seconds_per_step = seconds_per_step + self.terminate_at = terminate_at + + # Define observation space + obs_spaces = {'EPISODE_TIME': spaces.Box(low=0, high=np.inf, shape=(1,), dtype=np.float32)} + for param in Nucon.get_all_readable(): + if param.param_type == float: + obs_spaces[param.id] = spaces.Box(low=param.min_val or -np.inf, high=param.max_val or np.inf, shape=(1,), dtype=np.float32) + elif param.param_type == int: + if param.min_val is not None and param.max_val is not None: + obs_spaces[param.id] = spaces.Box(low=param.min_val, high=param.max_val, shape=(1,), dtype=np.float32) + else: + obs_spaces[param.id] = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32) + elif param.param_type == bool: + obs_spaces[param.id] = spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32) + elif issubclass(param.param_type, Enum): + obs_spaces[param.id] = spaces.Box(low=0, high=1, shape=(len(param.param_type),), dtype=np.float32) + else: + raise ValueError(f"Unsupported observation parameter type: {param.param_type}") + + self.observation_space = spaces.Dict(obs_spaces) + + # Define action space + action_spaces = {} + for param in Nucon.get_all_writable(): + if param.param_type == float: + action_spaces[param.id] = spaces.Box(low=param.min_val or -np.inf, high=param.max_val or np.inf, shape=(1,), dtype=np.float32) + elif param.param_type == int: + if param.min_val is not None and param.max_val is not None: + action_spaces[param.id] = spaces.Box(low=param.min_val, high=param.max_val, shape=(1,), dtype=np.float32) + else: + action_spaces[param.id] = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32) + elif param.param_type == bool: + action_spaces[param.id] = spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32) + elif issubclass(param.param_type, Enum): + action_spaces[param.id] = spaces.Box(low=0, high=1, shape=(len(param.param_type),), dtype=np.float32) + else: + raise ValueError(f"Unsupported action parameter type: {param.param_type}") + + self.action_space = spaces.Dict(action_spaces) + + for objective in objectives: + if objective in Objectives: + self.objectives.append(Objectives[objective]) + elif callable(objective): + self.objectives.append(objective) + else: + raise ValueError(f"Unsupported objective: {objective}") + + for terminator in terminators: + if terminator in Objectives: + self.terminators.append(Objectives[terminator]) + elif callable(terminator): + self.terminators.append(terminator) + else: + raise ValueError(f"Unsupported terminator: {terminator}") + + def _get_obs(self): + obs = {} + for param in Nucon.get_all_readable(): + value = Nucon.get(param) + if isinstance(value, Enum): + value = value.value + obs[param.id] = value + obs["EPISODE_TIME"] = self._total_steps * self.seconds_per_step + return obs + + def _get_info(self): + info = {'objectives': {}} + for objective in self.objectives: + info['objectives'][objective.__name__] = objective(self._get_obs()) + return info + + def reset(self, seed=None, options=None): + super().reset(seed=seed) + + self._total_steps = 0 + observation = self._get_obs() + info = self._get_info() + + return observation, info + + def step(self, action): + # Apply the action to the Nucon system + for param_id, value in action.items(): + param = next(p for p in Nucon if p.id == param_id) + if issubclass(param.param_type, Enum): + value = param.param_type(value) + if param.min_val is not None and param.max_val is not None: + value = np.clip(value, param.min_val, param.max_val) + Nucon.set(param, value) + + observation = self._get_obs() + terminated = np.sum([terminator(observation) for terminator in self.terminators]) > self.terminate_above + truncated = False + info = self._get_info() + reward = sum(obj for obj in info['objectives'].values()) + + self._total_steps += 1 + time.sleep(self.seconds_per_step) + return observation, reward, terminated, truncated, info + + def render(self): + if self.render_mode == "human": + pass + + def close(self): + pass + + def _flatten_action(self, action): + return np.concatenate([v.flatten() for v in action.values()]) + + def _unflatten_action(self, flat_action): + return {k: v.reshape(1, -1) for k, v in self.action_space.items()} + + def _flatten_observation(self, observation): + return np.concatenate([v.flatten() for v in observation.values()]) + + def _unflatten_observation(self, flat_observation): + return {k: v.reshape(1, -1) for k, v in self.observation_space.items()} \ No newline at end of file