Merge pull request #87 from kayendns/airhockey-off-screen-rendering

Implemented render_mode 'rgb_array' for AirHockeyEnv
This commit is contained in:
Onur 2023-12-22 10:09:28 +01:00 committed by GitHub
commit aa652e3610
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -8,9 +8,9 @@ from fancy_gym.envs.mujoco.air_hockey.utils import robot_to_world
from mushroom_rl.core import Environment from mushroom_rl.core import Environment
class AirHockeyEnv(Environment): class AirHockeyEnv(Environment):
metadata = {"render_modes": ["human"], "render_fps": 50} metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 50}
def __init__(self, env_mode=None, interpolation_order=3, render_mode=None, **kwargs): def __init__(self, env_mode=None, interpolation_order=3, render_mode=None, width=1920, height=1080, **kwargs):
""" """
Environment Constructor Environment Constructor
@ -42,6 +42,17 @@ class AirHockeyEnv(Environment):
if env_mode == "tournament" and type(interpolation_order) != tuple: if env_mode == "tournament" and type(interpolation_order) != tuple:
interpolation_order = (interpolation_order, interpolation_order) interpolation_order = (interpolation_order, interpolation_order)
self.render_mode = render_mode
self.render_human_active = False
# Determine headless mode based on render_mode
headless = self.render_mode == 'rgb_array'
# Prepare viewer_params
viewer_params = kwargs.get('viewer_params', {})
viewer_params.update({'headless': headless, 'width': width, 'height': height})
kwargs['viewer_params'] = viewer_params
self.base_env = env_dict[env_mode](interpolation_order=interpolation_order, **kwargs) self.base_env = env_dict[env_mode](interpolation_order=interpolation_order, **kwargs)
self.env_name = env_mode self.env_name = env_mode
self.env_info = self.base_env.env_info self.env_info = self.base_env.env_info
@ -89,9 +100,6 @@ class AirHockeyEnv(Environment):
self.env_info['constraints'] = constraint_list self.env_info['constraints'] = constraint_list
self.env_info['env_name'] = self.env_name self.env_info['env_name'] = self.env_name
self.render_mode = render_mode
self.render_human_active = False
super().__init__(self.base_env.info) super().__init__(self.base_env.info)
def step(self, action): def step(self, action):
@ -119,15 +127,21 @@ class AirHockeyEnv(Environment):
if self.env_info['env_name'] == "tournament": if self.env_info['env_name'] == "tournament":
obs = np.array(np.split(obs, 2)) obs = np.array(np.split(obs, 2))
if self.render_human_active: if self.render_human_active:
self.base_env.render() self.base_env.render()
return obs, reward, done, False, info return obs, reward, done, False, info
def render(self): def render(self):
self.render_human_active = True if self.render_mode == 'rgb_array':
return self.base_env.render(record = True)
elif self.render_mode == 'human':
self.render_human_active = True
self.base_env.render()
else:
raise ValueError(f"Unsupported render mode: '{self.render_mode}'")
def reset(self, seed=None, options={}): def reset(self, seed=None, options={}):
self.base_env.seed(seed) self.base_env.seed(seed)
obs = self.base_env.reset() obs = self.base_env.reset()
@ -185,4 +199,4 @@ if __name__ == "__main__":
J = 0. J = 0.
gamma = 1. gamma = 1.
steps = 0 steps = 0
env.reset() env.reset()