diff --git a/nucon/rl.py b/nucon/rl.py index 0da1028..fa9158c 100644 --- a/nucon/rl.py +++ b/nucon/rl.py @@ -10,10 +10,13 @@ Objectives = { "coeff": lambda obj, coeff: lambda obs: obj(obs) * coeff, "max_power": lambda obs: obs["GENERATOR_0_KW"] + obs["GENERATOR_1_KW"] + obs["GENERATOR_2_KW"], - "target_temperature": lambda goal_temp: lambda obs: (obs["CORE_TEMP"] - goal_temp) ** 2, "episode_time": lambda obs: obs["EPISODE_TIME"], } +Parameterized_Objectives = { + "target_temperature": lambda goal_temp: lambda obs: -((obs["CORE_TEMP"] - goal_temp) ** 2), +} + class NuconEnv(gym.Env): metadata = {'render_modes': ['human']} @@ -140,4 +143,18 @@ class NuconEnv(gym.Env): return np.concatenate([v.flatten() for v in observation.values()]) def _unflatten_observation(self, flat_observation): - return {k: v.reshape(1, -1) for k, v in self.observation_space.items()} \ No newline at end of file + return {k: v.reshape(1, -1) for k, v in self.observation_space.items()} + +def register_nucon_envs(): + gym.register( + id='Nucon-max_power-v0', + entry_point='nucon.rl:NuconEnv', + kwargs={'seconds_per_step': 5, 'objectives': ['max_power']} + ) + gym.register( + id='Nucon-target_temperature_600-v0', + entry_point='nucon.rl:NuconEnv', + kwargs={'seconds_per_step': 5, 'objectives': [Parameterized_Objectives['target_temperature'](goal_temp=600)]} + ) + +register_nucon_envs() \ No newline at end of file