Pass noise as parameter
This commit is contained in:
parent
fde5e33be6
commit
86abbe0d97
@ -5,20 +5,19 @@ from fancy_gym.envs.mujoco.air_hockey.seven_dof.env_single import AirHockeySingl
|
||||
from fancy_gym.envs.mujoco.air_hockey.utils import inverse_kinematics, forward_kinematics, jacobian
|
||||
|
||||
class AirhocKIT2023BaseEnv(AirHockeySingle):
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, noise=False, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
obs_low = np.hstack([[-np.inf] * 37])
|
||||
obs_high = np.hstack([[np.inf] * 37])
|
||||
self.wrapper_obs_space = spaces.Box(low=obs_low, high=obs_high, dtype=np.float64)
|
||||
self.wrapper_act_space = spaces.Box(low=np.repeat(-100., 6), high=np.repeat(100., 6))
|
||||
self.noise = False
|
||||
self.noise = noise
|
||||
|
||||
# We don't need puck yaw observations
|
||||
def filter_obs(self, obs):
|
||||
obs = np.hstack([obs[0:2], obs[3:5], obs[6:12], obs[13:19], obs[20:]])
|
||||
return obs
|
||||
|
||||
# These are roughly the noise levels for a noisy environment, turned-off by default, enable in the constructor
|
||||
def add_noise(self, obs):
|
||||
if not self.noise:
|
||||
return
|
||||
|
@ -46,8 +46,8 @@ class AirHockeyDefend(AirHockeySingle):
|
||||
return super().is_absorbing(state)
|
||||
|
||||
class AirHockeyDefendAirhocKIT2023(AirhocKIT2023BaseEnv):
|
||||
def __init__(self, gamma=0.99, horizon=200, viewer_params={}):
|
||||
super().__init__(gamma=gamma, horizon=horizon, viewer_params=viewer_params)
|
||||
def __init__(self, gamma=0.99, horizon=200, viewer_params={}, **kwargs):
|
||||
super().__init__(gamma=gamma, horizon=horizon, viewer_params=viewer_params, **kwargs)
|
||||
self.init_velocity_range = (1, 3)
|
||||
self.start_range = np.array([[0.4, 0.75], [-0.4, 0.4]]) # Table Frame
|
||||
self._setup_metrics()
|
||||
|
@ -57,8 +57,8 @@ class AirHockeyHit(AirHockeySingle):
|
||||
return super(AirHockeyHit, self).is_absorbing(obs)
|
||||
|
||||
class AirHockeyHitAirhocKIT2023(AirhocKIT2023BaseEnv):
|
||||
def __init__(self, gamma=0.99, horizon=500, moving_init=True, viewer_params={}):
|
||||
super().__init__(gamma=gamma, horizon=horizon, viewer_params=viewer_params)
|
||||
def __init__(self, gamma=0.99, horizon=500, moving_init=True, viewer_params={}, **kwargs):
|
||||
super().__init__(gamma=gamma, horizon=horizon, viewer_params=viewer_params, **kwargs)
|
||||
|
||||
self.moving_init = moving_init
|
||||
hit_width = self.env_info['table']['width'] / 2 - self.env_info['puck']['radius'] - \
|
||||
|
Loading…
Reference in New Issue
Block a user