incorporated human render_mode gym spec + optimized keyword arguments
This commit is contained in:
		
							parent
							
								
									deaca46d87
								
							
						
					
					
						commit
						0b4e729a49
					
				| @ -10,7 +10,7 @@ from mushroom_rl.core import Environment | ||||
| class AirHockeyEnv(Environment): | ||||
|     metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 50} | ||||
| 
 | ||||
|     def __init__(self, env_mode=None, interpolation_order=3, render_mode=None, **kwargs): | ||||
|     def __init__(self, env_mode=None, interpolation_order=3, render_mode=None, width=1920, height=1080, **kwargs): | ||||
|         """ | ||||
|         Environment Constructor | ||||
| 
 | ||||
| @ -40,11 +40,10 @@ class AirHockeyEnv(Environment): | ||||
|             interpolation_order = (interpolation_order, interpolation_order) | ||||
| 
 | ||||
|         self.render_mode = render_mode | ||||
|         self.render_human_active = False | ||||
| 
 | ||||
|         # Determine headless mode based on render_mode | ||||
|         headless = self.render_mode == 'rgb_array' | ||||
|         width = kwargs.pop('width', 1920) | ||||
|         height = kwargs.pop('height', 1080) | ||||
|          | ||||
|         # Prepare viewer_params | ||||
|         viewer_params = kwargs.get('viewer_params', {}) | ||||
| @ -121,13 +120,16 @@ class AirHockeyEnv(Environment): | ||||
|         if self.env_info['env_name'] == "tournament": | ||||
|             obs = np.array(np.split(obs, 2)) | ||||
| 
 | ||||
|         if self.render_human_active: | ||||
|             self.base_env.render() | ||||
| 
 | ||||
|         return obs, reward, done, False, info | ||||
| 
 | ||||
|     def render(self): | ||||
|         if self.render_mode == 'rgb_array': | ||||
|             return self.base_env.render(record = True) | ||||
|         elif self.render_mode == 'human': | ||||
|             self.base_env.render() | ||||
|             self.render_human_active = True | ||||
|         else: | ||||
|             raise ValueError(f"Unsupported render mode: '{self.render_mode}'") | ||||
|              | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user