import time
from enum import Enum
from typing import Dict, Any

import gymnasium as gym
from gymnasium import spaces
import numpy as np

from .core import Nucon, BreakerStatus, PumpStatus, PumpDryStatus, PumpOverloadStatus

# Registry of built-in objectives, looked up by name when a string is passed
# as an objective or terminator to NuconEnv.
#
# "null", "max_power" and "episode_time" map an observation dict directly to a
# number. "coeff" and "target_temperature" are factories: they must be called
# first (with an objective and a coefficient, or with a goal temperature) and
# the function they return is what should be handed to NuconEnv.
Objectives = {
    "null": lambda obs: 0,
    "coeff": lambda obj, coeff: lambda obs: obj(obs) * coeff,
    "max_power": lambda obs: obs["GENERATOR_0_KW"] + obs["GENERATOR_1_KW"] + obs["GENERATOR_2_KW"],
    "target_temperature": lambda goal_temp: lambda obs: (obs["CORE_TEMP"] - goal_temp) ** 2,
    "episode_time": lambda obs: obs["EPISODE_TIME"],
}


class NuconEnv(gym.Env):
    """Gymnasium environment whose observations and actions are Nucon parameters.

    Observations are built from ``Nucon.get_all_readable()`` plus a synthetic
    ``EPISODE_TIME`` entry; actions are written to the parameters returned by
    ``Nucon.get_all_writable()``. The reward of a step is the sum of all
    configured objectives, and an episode terminates once the summed
    terminator values exceed ``terminate_above``.
    """

    metadata = {'render_modes': ['human']}

    def __init__(self, render_mode=None, seconds_per_step=5, objectives=('null',), terminators=('null',), terminate_above=0):
        """Build the observation/action spaces and resolve objectives.

        Args:
            render_mode: Stored on the instance; ``render`` currently does
                nothing for any mode.
            seconds_per_step: Wall-clock seconds ``step`` sleeps after applying
                an action; also the increment used for ``EPISODE_TIME``.
            objectives: Iterable of objective specs. Each spec is either a key
                of ``Objectives`` or a callable taking the observation dict.
            terminators: Iterable of terminator specs, resolved the same way
                as ``objectives``.
            terminate_above: The episode terminates when the sum of all
                terminator values is strictly greater than this threshold.

        Raises:
            ValueError: For an unsupported parameter type or an objective /
                terminator that is neither a known name nor a callable.
        """
        super().__init__()

        self.render_mode = render_mode
        self.seconds_per_step = seconds_per_step
        # step() compares against this attribute, so it must carry this name.
        self.terminate_above = terminate_above
        # Initialised here so _get_obs() works even before the first reset().
        self._total_steps = 0

        # Define observation space
        obs_spaces = {'EPISODE_TIME': spaces.Box(low=0, high=np.inf, shape=(1,), dtype=np.float32)}
        for param in Nucon.get_all_readable():
            obs_spaces[param.id] = self._space_for_param(param, "observation")
        self.observation_space = spaces.Dict(obs_spaces)

        # Define action space
        action_spaces = {}
        for param in Nucon.get_all_writable():
            action_spaces[param.id] = self._space_for_param(param, "action")
        self.action_space = spaces.Dict(action_spaces)

        self._objective_names, self.objectives = self._resolve_callables(objectives, "objective")
        self._terminator_names, self.terminators = self._resolve_callables(terminators, "terminator")

    @staticmethod
    def _space_for_param(param, role):
        """Return the Box space describing ``param``; ``role`` only labels the error."""
        ptype = param.param_type
        if ptype == float:
            # Explicit None checks: a bound of 0 is valid and must not be
            # replaced by infinity.
            low = -np.inf if param.min_val is None else param.min_val
            high = np.inf if param.max_val is None else param.max_val
            return spaces.Box(low=low, high=high, shape=(1,), dtype=np.float32)
        if ptype == int:
            # Integers are only bounded when both limits are known.
            if param.min_val is not None and param.max_val is not None:
                return spaces.Box(low=param.min_val, high=param.max_val, shape=(1,), dtype=np.float32)
            return spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32)
        if ptype == bool:
            return spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32)
        if isinstance(ptype, type) and issubclass(ptype, Enum):
            # NOTE(review): the space has one slot per enum member, but
            # _get_obs() and step() exchange the scalar enum value, not a
            # vector of that length - confirm which encoding is intended.
            return spaces.Box(low=0, high=1, shape=(len(ptype),), dtype=np.float32)
        raise ValueError(f"Unsupported {role} parameter type: {ptype}")

    @staticmethod
    def _resolve_callables(specs, kind):
        """Turn objective/terminator specs into parallel lists of unique names and callables."""
        # Accept a single name or callable as shorthand for a one-element list.
        if isinstance(specs, str) or callable(specs):
            specs = [specs]

        names = []
        funcs = []
        for spec in specs:
            if isinstance(spec, str) and spec in Objectives:
                base = spec
                func = Objectives[spec]
            elif callable(spec):
                base = getattr(spec, '__name__', type(spec).__name__)
                func = spec
            else:
                raise ValueError(f"Unsupported {kind}: {spec}")

            # Lambdas all share the name "<lambda>"; add a suffix so entries
            # never overwrite each other when used as dict keys.
            name = base
            suffix = 1
            while name in names:
                name = f"{base}_{suffix}"
                suffix += 1

            names.append(name)
            funcs.append(func)
        return names, funcs

    def _get_obs(self) -> Dict[str, Any]:
        """Read every readable parameter and return the values keyed by parameter id.

        Enum values are replaced by their ``.value``. ``EPISODE_TIME`` is the
        number of completed steps multiplied by ``seconds_per_step``.
        """
        obs = {}
        for param in Nucon.get_all_readable():
            value = Nucon.get(param)
            if isinstance(value, Enum):
                value = value.value
            obs[param.id] = value
        obs["EPISODE_TIME"] = self._total_steps * self.seconds_per_step
        return obs

    def _get_info(self, obs=None) -> Dict[str, Any]:
        """Evaluate every objective and return ``{'objectives': {name: value}}``.

        Args:
            obs: Observation to evaluate the objectives on. When omitted a
                fresh observation is read once and shared by all objectives.
        """
        if obs is None:
            obs = self._get_obs()
        values = {
            name: objective(obs)
            for name, objective in zip(self._objective_names, self.objectives)
        }
        return {'objectives': values}

    def reset(self, seed=None, options=None):
        """Reset the step counter and return the current ``(observation, info)``."""
        super().reset(seed=seed)

        self._total_steps = 0
        observation = self._get_obs()
        info = self._get_info(observation)

        return observation, info

    def step(self, action):
        """Apply ``action``, evaluate objectives, then wait ``seconds_per_step``.

        Args:
            action: Mapping from parameter id to the value to write. Only the
                parameters present in the mapping are written.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where
            ``reward`` is the sum of all objective values and ``truncated`` is
            always ``False``.

        Raises:
            ValueError: If ``action`` names a parameter id that does not exist.
        """
        # Apply the action to the Nucon system
        for param_id, value in action.items():
            param = next((p for p in Nucon if p.id == param_id), None)
            if param is None:
                raise ValueError(f"Unknown action parameter: {param_id}")
            if issubclass(param.param_type, Enum):
                value = param.param_type(value)
            elif param.min_val is not None and param.max_val is not None:
                # Clipping only makes sense for numeric values, not enum members.
                value = np.clip(value, param.min_val, param.max_val)
            Nucon.set(param, value)

        # NOTE(review): the observation is read right after writing the action
        # and before sleeping, so it may not yet reflect the action's effect.
        observation = self._get_obs()
        terminator_total = np.sum([terminator(observation) for terminator in self.terminators])
        terminated = bool(terminator_total > self.terminate_above)
        truncated = False
        # Reuse the observation so the reward matches what is returned.
        info = self._get_info(observation)
        reward = sum(info['objectives'].values())

        self._total_steps += 1
        time.sleep(self.seconds_per_step)
        return observation, reward, terminated, truncated, info

    def render(self):
        """Rendering hook; currently a no-op for every render mode."""
        if self.render_mode == "human":
            pass

    def close(self):
        """Release resources; nothing to do for this environment."""
        pass

    @staticmethod
    def _flatten_dict(space, values):
        """Concatenate ``values`` into one 1-D array, ordered by the keys of ``space``."""
        parts = [np.asarray(values[key]).ravel() for key in space.spaces]
        if not parts:
            return np.zeros(0, dtype=np.float32)
        return np.concatenate(parts)

    @staticmethod
    def _unflatten_dict(space, flat):
        """Split a 1-D array into per-key arrays shaped like the sub-spaces of ``space``."""
        flat = np.asarray(flat).ravel()
        sizes = {key: int(np.prod(sub.shape)) for key, sub in space.spaces.items()}
        expected = sum(sizes.values())
        if flat.size != expected:
            raise ValueError(f"Expected {expected} values, got {flat.size}")

        result = {}
        offset = 0
        for key, sub in space.spaces.items():
            size = sizes[key]
            result[key] = flat[offset:offset + size].reshape(sub.shape)
            offset += size
        return result

    def _flatten_action(self, action):
        """Flatten an action dict into a 1-D array in ``action_space`` key order."""
        return self._flatten_dict(self.action_space, action)

    def _unflatten_action(self, flat_action):
        """Inverse of ``_flatten_action``: rebuild the action dict from a 1-D array."""
        return self._unflatten_dict(self.action_space, flat_action)

    def _flatten_observation(self, observation):
        """Flatten an observation dict into a 1-D array in ``observation_space`` key order."""
        return self._flatten_dict(self.observation_space, observation)

    def _unflatten_observation(self, flat_observation):
        """Inverse of ``_flatten_observation``: rebuild the observation dict from a 1-D array."""
        return self._unflatten_dict(self.observation_space, flat_observation)