added rgb_array render mode for off-screen rendering
This commit is contained in:
parent
dbd3caebb3
commit
deaca46d87
@ -8,7 +8,7 @@ from fancy_gym.envs.mujoco.air_hockey.utils import robot_to_world
|
|||||||
from mushroom_rl.core import Environment
|
from mushroom_rl.core import Environment
|
||||||
|
|
||||||
class AirHockeyEnv(Environment):
|
class AirHockeyEnv(Environment):
|
||||||
metadata = {"render_modes": ["human"], "render_fps": 50}
|
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 50}
|
||||||
|
|
||||||
def __init__(self, env_mode=None, interpolation_order=3, render_mode=None, **kwargs):
|
def __init__(self, env_mode=None, interpolation_order=3, render_mode=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
@ -39,6 +39,18 @@ class AirHockeyEnv(Environment):
|
|||||||
if env_mode == "tournament" and type(interpolation_order) != tuple:
|
if env_mode == "tournament" and type(interpolation_order) != tuple:
|
||||||
interpolation_order = (interpolation_order, interpolation_order)
|
interpolation_order = (interpolation_order, interpolation_order)
|
||||||
|
|
||||||
|
self.render_mode = render_mode
|
||||||
|
|
||||||
|
# Determine headless mode based on render_mode
|
||||||
|
headless = self.render_mode == 'rgb_array'
|
||||||
|
width = kwargs.pop('width', 1920)
|
||||||
|
height = kwargs.pop('height', 1080)
|
||||||
|
|
||||||
|
# Prepare viewer_params
|
||||||
|
viewer_params = kwargs.get('viewer_params', {})
|
||||||
|
viewer_params.update({'headless': headless, 'width': width, 'height': height})
|
||||||
|
kwargs['viewer_params'] = viewer_params
|
||||||
|
|
||||||
self.base_env = env_dict[env_mode](interpolation_order=interpolation_order, **kwargs)
|
self.base_env = env_dict[env_mode](interpolation_order=interpolation_order, **kwargs)
|
||||||
self.env_name = env_mode
|
self.env_name = env_mode
|
||||||
self.env_info = self.base_env.env_info
|
self.env_info = self.base_env.env_info
|
||||||
@ -81,9 +93,6 @@ class AirHockeyEnv(Environment):
|
|||||||
self.env_info['constraints'] = constraint_list
|
self.env_info['constraints'] = constraint_list
|
||||||
self.env_info['env_name'] = self.env_name
|
self.env_info['env_name'] = self.env_name
|
||||||
|
|
||||||
self.render_mode = render_mode
|
|
||||||
self.render_human_active = False
|
|
||||||
|
|
||||||
super().__init__(self.base_env.info)
|
super().__init__(self.base_env.info)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
@ -112,13 +121,15 @@ class AirHockeyEnv(Environment):
|
|||||||
if self.env_info['env_name'] == "tournament":
|
if self.env_info['env_name'] == "tournament":
|
||||||
obs = np.array(np.split(obs, 2))
|
obs = np.array(np.split(obs, 2))
|
||||||
|
|
||||||
if self.render_human_active:
|
|
||||||
self.base_env.render()
|
|
||||||
|
|
||||||
return obs, reward, done, False, info
|
return obs, reward, done, False, info
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
self.render_human_active = True
|
if self.render_mode == 'rgb_array':
|
||||||
|
return self.base_env.render(record = True)
|
||||||
|
elif self.render_mode == 'human':
|
||||||
|
self.base_env.render()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported render mode: '{self.render_mode}'")
|
||||||
|
|
||||||
def reset(self, seed=None, options={}):
|
def reset(self, seed=None, options={}):
|
||||||
self.base_env.seed(seed)
|
self.base_env.seed(seed)
|
||||||
|
Loading…
Reference in New Issue
Block a user