From dc59173fe71dbb84d5d006c2dc36c43646a9c3ec Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Wed, 2 Oct 2024 19:22:23 +0200 Subject: [PATCH] Better parameterized objectives and gym bindings --- nucon/rl.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/nucon/rl.py b/nucon/rl.py index 0da1028..fa9158c 100644 --- a/nucon/rl.py +++ b/nucon/rl.py @@ -10,10 +10,13 @@ Objectives = { "coeff": lambda obj, coeff: lambda obs: obj(obs) * coeff, "max_power": lambda obs: obs["GENERATOR_0_KW"] + obs["GENERATOR_1_KW"] + obs["GENERATOR_2_KW"], - "target_temperature": lambda goal_temp: lambda obs: (obs["CORE_TEMP"] - goal_temp) ** 2, "episode_time": lambda obs: obs["EPISODE_TIME"], } +Parameterized_Objectives = { + "target_temperature": lambda goal_temp: lambda obs: -((obs["CORE_TEMP"] - goal_temp) ** 2), +} + class NuconEnv(gym.Env): metadata = {'render_modes': ['human']} @@ -140,4 +143,18 @@ class NuconEnv(gym.Env): return np.concatenate([v.flatten() for v in observation.values()]) def _unflatten_observation(self, flat_observation): - return {k: v.reshape(1, -1) for k, v in self.observation_space.items()} \ No newline at end of file + return {k: v.reshape(1, -1) for k, v in self.observation_space.items()} + +def register_nucon_envs(): + gym.register( + id='Nucon-max_power-v0', + entry_point='nucon.rl:NuconEnv', + kwargs={'seconds_per_step': 5, 'objectives': ['max_power']} + ) + gym.register( + id='Nucon-target_temperature_600-v0', + entry_point='nucon.rl:NuconEnv', + kwargs={'seconds_per_step': 5, 'objectives': [Parameterized_Objectives['target_temperature'](goal_temp=600)]} + ) + +register_nucon_envs() \ No newline at end of file