incorporated human render_mode gym spec + optimized keyword arguments
This commit is contained in:
parent
deaca46d87
commit
0b4e729a49
@ -10,7 +10,7 @@ from mushroom_rl.core import Environment
|
||||
class AirHockeyEnv(Environment):
|
||||
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 50}
|
||||
|
||||
def __init__(self, env_mode=None, interpolation_order=3, render_mode=None, **kwargs):
|
||||
def __init__(self, env_mode=None, interpolation_order=3, render_mode=None, width=1920, height=1080, **kwargs):
|
||||
"""
|
||||
Environment Constructor
|
||||
|
||||
@ -40,11 +40,10 @@ class AirHockeyEnv(Environment):
|
||||
interpolation_order = (interpolation_order, interpolation_order)
|
||||
|
||||
self.render_mode = render_mode
|
||||
self.render_human_active = False
|
||||
|
||||
# Determine headless mode based on render_mode
|
||||
headless = self.render_mode == 'rgb_array'
|
||||
width = kwargs.pop('width', 1920)
|
||||
height = kwargs.pop('height', 1080)
|
||||
|
||||
# Prepare viewer_params
|
||||
viewer_params = kwargs.get('viewer_params', {})
|
||||
@ -121,13 +120,16 @@ class AirHockeyEnv(Environment):
|
||||
if self.env_info['env_name'] == "tournament":
|
||||
obs = np.array(np.split(obs, 2))
|
||||
|
||||
if self.render_human_active:
|
||||
self.base_env.render()
|
||||
|
||||
return obs, reward, done, False, info
|
||||
|
||||
def render(self):
|
||||
if self.render_mode == 'rgb_array':
|
||||
return self.base_env.render(record = True)
|
||||
elif self.render_mode == 'human':
|
||||
self.base_env.render()
|
||||
self.render_human_active = True
|
||||
else:
|
||||
raise ValueError(f"Unsupported render mode: '{self.render_mode}'")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user