check time validity before pos validity

This commit is contained in:
Hongyi Zhou 2022-11-08 13:57:32 +01:00
parent d384e6e764
commit b9c2348855
3 changed files with 109 additions and 83 deletions

View File

@ -164,6 +164,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
def step(self, action: np.ndarray): def step(self, action: np.ndarray):
""" This function generates a trajectory based on a MP and then does the usual loop over reset and step""" """ This function generates a trajectory based on a MP and then does the usual loop over reset and step"""
time_valid = self.env.check_time_validity(action)
if time_valid:
if self.plan_counts == 0: if self.plan_counts == 0:
self.tau_first_prediction = action[0] self.tau_first_prediction = action[0]
@ -253,6 +256,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
else: else:
obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity) obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity)
return self.observation(obs), trajectory_return, done, infos return self.observation(obs), trajectory_return, done, infos
else:
obs, trajectory_return, done, infos = self.env.time_invalid_traj_callback(action)
return self.observation(obs), trajectory_return, done, infos
def render(self, **kwargs): def render(self, **kwargs):
"""Only set render options here, such that they can be used during the rollout. """Only set render options here, such that they can be used during the rollout.
This only needs to be called once""" This only needs to be called once"""

View File

@ -560,7 +560,7 @@ for _v in _versions:
_name = _v.split("-") _name = _v.split("-")
_env_id = f'{_name[0]}ProDMP-{_name[1]}' _env_id = f'{_name[0]}ProDMP-{_name[1]}'
kwargs_dict_tt_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP) kwargs_dict_tt_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
kwargs_dict_tt_prodmp['wrappers'].append(mujoco.box_pushing.MPWrapper) kwargs_dict_tt_prodmp['wrappers'].append(mujoco.table_tennis.MPWrapper)
kwargs_dict_tt_prodmp['name'] = _v kwargs_dict_tt_prodmp['name'] = _v
kwargs_dict_tt_prodmp['controller_kwargs']['p_gains'] = 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0]) kwargs_dict_tt_prodmp['controller_kwargs']['p_gains'] = 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0])
kwargs_dict_tt_prodmp['controller_kwargs']['d_gains'] = 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1]) kwargs_dict_tt_prodmp['controller_kwargs']['d_gains'] = 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1])

View File

@ -28,12 +28,31 @@ class MPWrapper(RawInterfaceWrapper):
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
return self.data.qvel[:7].copy() return self.data.qvel[:7].copy()
def check_time_validity(self, action):
return action[0] <= tau_bound[1] and action[0] >= tau_bound[0] \
and action[1] <= delay_bound[1] and action[1] >= delay_bound[0]
def time_invalid_traj_callback(self, action) \
-> Tuple[np.ndarray, float, bool, dict]:
tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
invalid_penalty = tau_invalid_penalty + delay_invalid_penalty
obs = np.concatenate([self.get_obs(), np.array([0])])
return obs, -invalid_penalty, True, {
"hit_ball": [False],
"ball_returned_success": [False],
"land_dist_error": [10.],
"is_success": [False],
'trajectory_length': 1,
"num_steps": [1]
}
def episode_callback(self, action, pos_traj, vel_traj): def episode_callback(self, action, pos_traj, vel_traj):
time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \ time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \
or action[1] > delay_bound[1] or action[1] < delay_bound[0] or action[1] > delay_bound[1] or action[1] < delay_bound[0]
if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low): if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low):
return False
return True return True
return False
def invalid_traj_callback(self, action, pos_traj: np.ndarray, vel_traj: np.ndarray) \ def invalid_traj_callback(self, action, pos_traj: np.ndarray, vel_traj: np.ndarray) \
-> Tuple[np.ndarray, float, bool, dict]: -> Tuple[np.ndarray, float, bool, dict]:
@ -43,7 +62,8 @@ class MPWrapper(RawInterfaceWrapper):
violate_low_bound_error = np.mean(np.maximum(jnt_pos_low - pos_traj, 0)) violate_low_bound_error = np.mean(np.maximum(jnt_pos_low - pos_traj, 0))
invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \ invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \
violate_high_bound_error + violate_low_bound_error violate_high_bound_error + violate_low_bound_error
return self.get_obs(), -invalid_penalty, True, { obs = np.concatenate([self.get_obs(), np.array([0])])
return obs, -invalid_penalty, True, {
"hit_ball": [False], "hit_ball": [False],
"ball_returned_success": [False], "ball_returned_success": [False],
"land_dist_error": [10.], "land_dist_error": [10.],