Pass noise as parameter
This commit is contained in:
		
							parent
							
								
									fde5e33be6
								
							
						
					
					
						commit
						86abbe0d97
					
				| @ -5,20 +5,19 @@ from fancy_gym.envs.mujoco.air_hockey.seven_dof.env_single import AirHockeySingl | ||||
| from fancy_gym.envs.mujoco.air_hockey.utils import inverse_kinematics, forward_kinematics, jacobian | ||||
| 
 | ||||
| class AirhocKIT2023BaseEnv(AirHockeySingle): | ||||
|     def __init__(self, **kwargs): | ||||
|     def __init__(self, noise=False, **kwargs): | ||||
|         super().__init__(**kwargs) | ||||
|         obs_low = np.hstack([[-np.inf] * 37]) | ||||
|         obs_high = np.hstack([[np.inf] * 37]) | ||||
|         self.wrapper_obs_space = spaces.Box(low=obs_low, high=obs_high, dtype=np.float64) | ||||
|         self.wrapper_act_space = spaces.Box(low=np.repeat(-100., 6), high=np.repeat(100., 6)) | ||||
|         self.noise = False | ||||
|         self.noise = noise | ||||
| 
 | ||||
|     # We don't need puck yaw observations | ||||
|     def filter_obs(self, obs): | ||||
|         obs = np.hstack([obs[0:2], obs[3:5], obs[6:12], obs[13:19], obs[20:]]) | ||||
|         return obs | ||||
| 
 | ||||
|     # These are roughly the noise levels for a noisy environment, turned-off by default, enable in the constructor | ||||
|     def add_noise(self, obs): | ||||
|         if not self.noise: | ||||
|             return | ||||
|  | ||||
| @ -46,8 +46,8 @@ class AirHockeyDefend(AirHockeySingle): | ||||
|         return super().is_absorbing(state) | ||||
| 
 | ||||
| class AirHockeyDefendAirhocKIT2023(AirhocKIT2023BaseEnv): | ||||
|     def __init__(self, gamma=0.99, horizon=200, viewer_params={}): | ||||
|         super().__init__(gamma=gamma, horizon=horizon, viewer_params=viewer_params) | ||||
|     def __init__(self, gamma=0.99, horizon=200, viewer_params={}, **kwargs): | ||||
|         super().__init__(gamma=gamma, horizon=horizon, viewer_params=viewer_params, **kwargs) | ||||
|         self.init_velocity_range = (1, 3) | ||||
|         self.start_range = np.array([[0.4, 0.75], [-0.4, 0.4]])  # Table Frame | ||||
|         self._setup_metrics() | ||||
|  | ||||
| @ -57,8 +57,8 @@ class AirHockeyHit(AirHockeySingle): | ||||
|         return super(AirHockeyHit, self).is_absorbing(obs) | ||||
| 
 | ||||
| class AirHockeyHitAirhocKIT2023(AirhocKIT2023BaseEnv): | ||||
|     def __init__(self, gamma=0.99, horizon=500, moving_init=True, viewer_params={}): | ||||
|         super().__init__(gamma=gamma, horizon=horizon, viewer_params=viewer_params) | ||||
|     def __init__(self, gamma=0.99, horizon=500, moving_init=True, viewer_params={}, **kwargs): | ||||
|         super().__init__(gamma=gamma, horizon=horizon, viewer_params=viewer_params, **kwargs) | ||||
| 
 | ||||
|         self.moving_init = moving_init | ||||
|         hit_width = self.env_info['table']['width'] / 2 - self.env_info['puck']['radius'] - \ | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user