From 0b4e729a49691230c443531ffa767aaebc373733 Mon Sep 17 00:00:00 2001 From: Kayen Date: Sun, 26 Nov 2023 21:49:52 +0100 Subject: [PATCH] incorporated human render_mode gym spec + optimized keyword arguments --- .../envs/mujoco/air_hockey/air_hockey_env_wrapper.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/fancy_gym/envs/mujoco/air_hockey/air_hockey_env_wrapper.py b/fancy_gym/envs/mujoco/air_hockey/air_hockey_env_wrapper.py index 5f7bb38..b5ef8c3 100644 --- a/fancy_gym/envs/mujoco/air_hockey/air_hockey_env_wrapper.py +++ b/fancy_gym/envs/mujoco/air_hockey/air_hockey_env_wrapper.py @@ -10,7 +10,7 @@ from mushroom_rl.core import Environment class AirHockeyEnv(Environment): metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 50} - def __init__(self, env_mode=None, interpolation_order=3, render_mode=None, **kwargs): + def __init__(self, env_mode=None, interpolation_order=3, render_mode=None, width=1920, height=1080, **kwargs): """ Environment Constructor @@ -40,11 +40,10 @@ class AirHockeyEnv(Environment): interpolation_order = (interpolation_order, interpolation_order) self.render_mode = render_mode + self.render_human_active = False # Determine headless mode based on render_mode headless = self.render_mode == 'rgb_array' - width = kwargs.pop('width', 1920) - height = kwargs.pop('height', 1080) # Prepare viewer_params viewer_params = kwargs.get('viewer_params', {}) @@ -121,13 +120,16 @@ class AirHockeyEnv(Environment): if self.env_info['env_name'] == "tournament": obs = np.array(np.split(obs, 2)) + if self.render_human_active: + self.base_env.render() + return obs, reward, done, False, info def render(self): if self.render_mode == 'rgb_array': return self.base_env.render(record = True) elif self.render_mode == 'human': - self.base_env.render() + self.render_human_active = True else: raise ValueError(f"Unsupported render mode: '{self.render_mode}'")