diff --git a/fancy_gym/envs/mujoco/air_hockey/air_hockey_env_wrapper.py b/fancy_gym/envs/mujoco/air_hockey/air_hockey_env_wrapper.py index e943c53..64972d8 100644 --- a/fancy_gym/envs/mujoco/air_hockey/air_hockey_env_wrapper.py +++ b/fancy_gym/envs/mujoco/air_hockey/air_hockey_env_wrapper.py @@ -8,9 +8,9 @@ from fancy_gym.envs.mujoco.air_hockey.utils import robot_to_world from mushroom_rl.core import Environment class AirHockeyEnv(Environment): - metadata = {"render_modes": ["human"], "render_fps": 50} + metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 50} - def __init__(self, env_mode=None, interpolation_order=3, render_mode=None, **kwargs): + def __init__(self, env_mode=None, interpolation_order=3, render_mode=None, width=1920, height=1080, **kwargs): """ Environment Constructor @@ -42,6 +42,17 @@ class AirHockeyEnv(Environment): if env_mode == "tournament" and type(interpolation_order) != tuple: interpolation_order = (interpolation_order, interpolation_order) + self.render_mode = render_mode + self.render_human_active = False + + # Determine headless mode based on render_mode + headless = self.render_mode == 'rgb_array' + + # Prepare viewer_params + viewer_params = kwargs.get('viewer_params', {}) + viewer_params.update({'headless': headless, 'width': width, 'height': height}) + kwargs['viewer_params'] = viewer_params + self.base_env = env_dict[env_mode](interpolation_order=interpolation_order, **kwargs) self.env_name = env_mode self.env_info = self.base_env.env_info @@ -89,9 +100,6 @@ class AirHockeyEnv(Environment): self.env_info['constraints'] = constraint_list self.env_info['env_name'] = self.env_name - self.render_mode = render_mode - self.render_human_active = False - super().__init__(self.base_env.info) def step(self, action): @@ -119,15 +127,21 @@ class AirHockeyEnv(Environment): if self.env_info['env_name'] == "tournament": obs = np.array(np.split(obs, 2)) - + if self.render_human_active: self.base_env.render() return obs, reward, done, False, info def render(self): - self.render_human_active = True - + if self.render_mode == 'rgb_array': + return self.base_env.render(record = True) + elif self.render_mode == 'human': + self.render_human_active = True + self.base_env.render() + else: + raise ValueError(f"Unsupported render mode: '{self.render_mode}'") + def reset(self, seed=None, options={}): self.base_env.seed(seed) obs = self.base_env.reset() @@ -185,4 +199,4 @@ if __name__ == "__main__": J = 0. gamma = 1. steps = 0 - env.reset() \ No newline at end of file + env.reset()