Merge pull request #87 from kayendns/airhockey-off-screen-rendering
Implemented render_mode 'rgb_array' for AirHockeyEnv
This commit is contained in:
		
						commit
						aa652e3610
					
				| @ -8,9 +8,9 @@ from fancy_gym.envs.mujoco.air_hockey.utils import robot_to_world | ||||
| from mushroom_rl.core import Environment | ||||
| 
 | ||||
| class AirHockeyEnv(Environment): | ||||
|     metadata = {"render_modes": ["human"], "render_fps": 50} | ||||
|     metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 50} | ||||
| 
 | ||||
|     def __init__(self, env_mode=None, interpolation_order=3, render_mode=None, **kwargs): | ||||
|     def __init__(self, env_mode=None, interpolation_order=3, render_mode=None, width=1920, height=1080, **kwargs): | ||||
|         """ | ||||
|         Environment Constructor | ||||
| 
 | ||||
| @ -42,6 +42,17 @@ class AirHockeyEnv(Environment): | ||||
|         if env_mode == "tournament" and type(interpolation_order) != tuple: | ||||
|             interpolation_order = (interpolation_order, interpolation_order) | ||||
| 
 | ||||
|         self.render_mode = render_mode | ||||
|         self.render_human_active = False | ||||
| 
 | ||||
|         # Determine headless mode based on render_mode | ||||
|         headless = self.render_mode == 'rgb_array' | ||||
|          | ||||
|         # Prepare viewer_params | ||||
|         viewer_params = kwargs.get('viewer_params', {}) | ||||
|         viewer_params.update({'headless': headless, 'width': width, 'height': height}) | ||||
|         kwargs['viewer_params'] = viewer_params | ||||
| 
 | ||||
|         self.base_env = env_dict[env_mode](interpolation_order=interpolation_order, **kwargs) | ||||
|         self.env_name = env_mode | ||||
|         self.env_info = self.base_env.env_info | ||||
| @ -89,9 +100,6 @@ class AirHockeyEnv(Environment): | ||||
|         self.env_info['constraints'] = constraint_list | ||||
|         self.env_info['env_name'] = self.env_name | ||||
| 
 | ||||
|         self.render_mode = render_mode | ||||
|         self.render_human_active = False | ||||
| 
 | ||||
|         super().__init__(self.base_env.info) | ||||
| 
 | ||||
|     def step(self, action): | ||||
| @ -119,15 +127,21 @@ class AirHockeyEnv(Environment): | ||||
| 
 | ||||
|         if self.env_info['env_name'] == "tournament": | ||||
|             obs = np.array(np.split(obs, 2)) | ||||
|          | ||||
| 
 | ||||
|         if self.render_human_active: | ||||
|             self.base_env.render() | ||||
| 
 | ||||
|         return obs, reward, done, False, info | ||||
| 
 | ||||
|     def render(self): | ||||
|         self.render_human_active = True | ||||
| 
 | ||||
|         if self.render_mode == 'rgb_array': | ||||
|             return self.base_env.render(record = True) | ||||
|         elif self.render_mode == 'human': | ||||
|             self.render_human_active = True | ||||
|             self.base_env.render() | ||||
|         else: | ||||
|             raise ValueError(f"Unsupported render mode: '{self.render_mode}'") | ||||
|              | ||||
|     def reset(self, seed=None, options={}): | ||||
|         self.base_env.seed(seed) | ||||
|         obs = self.base_env.reset() | ||||
| @ -185,4 +199,4 @@ if __name__ == "__main__": | ||||
|             J = 0. | ||||
|             gamma = 1. | ||||
|             steps = 0 | ||||
|             env.reset() | ||||
|             env.reset() | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user