check time validity before pos validity
This commit is contained in:
parent
d384e6e764
commit
b9c2348855
@ -164,6 +164,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
|
|
||||||
def step(self, action: np.ndarray):
|
def step(self, action: np.ndarray):
|
||||||
""" This function generates a trajectory based on a MP and then does the usual loop over reset and step"""
|
""" This function generates a trajectory based on a MP and then does the usual loop over reset and step"""
|
||||||
|
time_valid = self.env.check_time_validity(action)
|
||||||
|
|
||||||
|
if time_valid:
|
||||||
|
|
||||||
if self.plan_counts == 0:
|
if self.plan_counts == 0:
|
||||||
self.tau_first_prediction = action[0]
|
self.tau_first_prediction = action[0]
|
||||||
@ -253,6 +256,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
else:
|
else:
|
||||||
obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity)
|
obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity)
|
||||||
return self.observation(obs), trajectory_return, done, infos
|
return self.observation(obs), trajectory_return, done, infos
|
||||||
|
else:
|
||||||
|
obs, trajectory_return, done, infos = self.env.time_invalid_traj_callback(action)
|
||||||
|
return self.observation(obs), trajectory_return, done, infos
|
||||||
def render(self, **kwargs):
|
def render(self, **kwargs):
|
||||||
"""Only set render options here, such that they can be used during the rollout.
|
"""Only set render options here, such that they can be used during the rollout.
|
||||||
This only needs to be called once"""
|
This only needs to be called once"""
|
||||||
|
@ -560,7 +560,7 @@ for _v in _versions:
|
|||||||
_name = _v.split("-")
|
_name = _v.split("-")
|
||||||
_env_id = f'{_name[0]}ProDMP-{_name[1]}'
|
_env_id = f'{_name[0]}ProDMP-{_name[1]}'
|
||||||
kwargs_dict_tt_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
|
kwargs_dict_tt_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
|
||||||
kwargs_dict_tt_prodmp['wrappers'].append(mujoco.box_pushing.MPWrapper)
|
kwargs_dict_tt_prodmp['wrappers'].append(mujoco.table_tennis.MPWrapper)
|
||||||
kwargs_dict_tt_prodmp['name'] = _v
|
kwargs_dict_tt_prodmp['name'] = _v
|
||||||
kwargs_dict_tt_prodmp['controller_kwargs']['p_gains'] = 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0])
|
kwargs_dict_tt_prodmp['controller_kwargs']['p_gains'] = 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0])
|
||||||
kwargs_dict_tt_prodmp['controller_kwargs']['d_gains'] = 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1])
|
kwargs_dict_tt_prodmp['controller_kwargs']['d_gains'] = 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1])
|
||||||
|
@ -28,12 +28,31 @@ class MPWrapper(RawInterfaceWrapper):
|
|||||||
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||||
return self.data.qvel[:7].copy()
|
return self.data.qvel[:7].copy()
|
||||||
|
|
||||||
|
def check_time_validity(self, action):
|
||||||
|
return action[0] <= tau_bound[1] and action[0] >= tau_bound[0] \
|
||||||
|
and action[1] <= delay_bound[1] and action[1] >= delay_bound[0]
|
||||||
|
|
||||||
|
def time_invalid_traj_callback(self, action) \
|
||||||
|
-> Tuple[np.ndarray, float, bool, dict]:
|
||||||
|
tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
|
||||||
|
delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
|
||||||
|
invalid_penalty = tau_invalid_penalty + delay_invalid_penalty
|
||||||
|
obs = np.concatenate([self.get_obs(), np.array([0])])
|
||||||
|
return obs, -invalid_penalty, True, {
|
||||||
|
"hit_ball": [False],
|
||||||
|
"ball_returned_success": [False],
|
||||||
|
"land_dist_error": [10.],
|
||||||
|
"is_success": [False],
|
||||||
|
'trajectory_length': 1,
|
||||||
|
"num_steps": [1]
|
||||||
|
}
|
||||||
|
|
||||||
def episode_callback(self, action, pos_traj, vel_traj):
|
def episode_callback(self, action, pos_traj, vel_traj):
|
||||||
time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \
|
time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \
|
||||||
or action[1] > delay_bound[1] or action[1] < delay_bound[0]
|
or action[1] > delay_bound[1] or action[1] < delay_bound[0]
|
||||||
if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low):
|
if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low):
|
||||||
return False
|
|
||||||
return True
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def invalid_traj_callback(self, action, pos_traj: np.ndarray, vel_traj: np.ndarray) \
|
def invalid_traj_callback(self, action, pos_traj: np.ndarray, vel_traj: np.ndarray) \
|
||||||
-> Tuple[np.ndarray, float, bool, dict]:
|
-> Tuple[np.ndarray, float, bool, dict]:
|
||||||
@ -43,7 +62,8 @@ class MPWrapper(RawInterfaceWrapper):
|
|||||||
violate_low_bound_error = np.mean(np.maximum(jnt_pos_low - pos_traj, 0))
|
violate_low_bound_error = np.mean(np.maximum(jnt_pos_low - pos_traj, 0))
|
||||||
invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \
|
invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \
|
||||||
violate_high_bound_error + violate_low_bound_error
|
violate_high_bound_error + violate_low_bound_error
|
||||||
return self.get_obs(), -invalid_penalty, True, {
|
obs = np.concatenate([self.get_obs(), np.array([0])])
|
||||||
|
return obs, -invalid_penalty, True, {
|
||||||
"hit_ball": [False],
|
"hit_ball": [False],
|
||||||
"ball_returned_success": [False],
|
"ball_returned_success": [False],
|
||||||
"land_dist_error": [10.],
|
"land_dist_error": [10.],
|
||||||
|
Loading…
Reference in New Issue
Block a user