check time validity before pos validity
This commit is contained in:
		
							parent
							
								
									d384e6e764
								
							
						
					
					
						commit
						b9c2348855
					
				| @ -164,6 +164,9 @@ class BlackBoxWrapper(gym.ObservationWrapper): | |||||||
| 
 | 
 | ||||||
|     def step(self, action: np.ndarray): |     def step(self, action: np.ndarray): | ||||||
|         """ This function generates a trajectory based on a MP and then does the usual loop over reset and step""" |         """ This function generates a trajectory based on a MP and then does the usual loop over reset and step""" | ||||||
|  |         time_valid = self.env.check_time_validity(action) | ||||||
|  | 
 | ||||||
|  |         if time_valid: | ||||||
| 
 | 
 | ||||||
|             if self.plan_counts == 0: |             if self.plan_counts == 0: | ||||||
|                 self.tau_first_prediction = action[0] |                 self.tau_first_prediction = action[0] | ||||||
| @ -253,6 +256,9 @@ class BlackBoxWrapper(gym.ObservationWrapper): | |||||||
|             else: |             else: | ||||||
|                 obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity) |                 obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity) | ||||||
|                 return self.observation(obs), trajectory_return, done, infos |                 return self.observation(obs), trajectory_return, done, infos | ||||||
|  |         else: | ||||||
|  |                 obs, trajectory_return, done, infos = self.env.time_invalid_traj_callback(action) | ||||||
|  |                 return self.observation(obs), trajectory_return, done, infos | ||||||
|     def render(self, **kwargs): |     def render(self, **kwargs): | ||||||
|         """Only set render options here, such that they can be used during the rollout. |         """Only set render options here, such that they can be used during the rollout. | ||||||
|         This only needs to be called once""" |         This only needs to be called once""" | ||||||
|  | |||||||
| @ -560,7 +560,7 @@ for _v in _versions: | |||||||
|     _name = _v.split("-") |     _name = _v.split("-") | ||||||
|     _env_id = f'{_name[0]}ProDMP-{_name[1]}' |     _env_id = f'{_name[0]}ProDMP-{_name[1]}' | ||||||
|     kwargs_dict_tt_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP) |     kwargs_dict_tt_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP) | ||||||
|     kwargs_dict_tt_prodmp['wrappers'].append(mujoco.box_pushing.MPWrapper) |     kwargs_dict_tt_prodmp['wrappers'].append(mujoco.table_tennis.MPWrapper) | ||||||
|     kwargs_dict_tt_prodmp['name'] = _v |     kwargs_dict_tt_prodmp['name'] = _v | ||||||
|     kwargs_dict_tt_prodmp['controller_kwargs']['p_gains'] = 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0]) |     kwargs_dict_tt_prodmp['controller_kwargs']['p_gains'] = 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0]) | ||||||
|     kwargs_dict_tt_prodmp['controller_kwargs']['d_gains'] = 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1]) |     kwargs_dict_tt_prodmp['controller_kwargs']['d_gains'] = 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1]) | ||||||
|  | |||||||
| @ -28,12 +28,31 @@ class MPWrapper(RawInterfaceWrapper): | |||||||
|     def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: |     def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: | ||||||
|         return self.data.qvel[:7].copy() |         return self.data.qvel[:7].copy() | ||||||
| 
 | 
 | ||||||
|  |     def check_time_validity(self, action): | ||||||
|  |         return action[0] <= tau_bound[1] and action[0] >= tau_bound[0] \ | ||||||
|  |                and action[1] <= delay_bound[1] and action[1] >= delay_bound[0] | ||||||
|  | 
 | ||||||
|  |     def time_invalid_traj_callback(self, action) \ | ||||||
|  |         -> Tuple[np.ndarray, float, bool, dict]: | ||||||
|  |         tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]])) | ||||||
|  |         delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]])) | ||||||
|  |         invalid_penalty = tau_invalid_penalty + delay_invalid_penalty | ||||||
|  |         obs = np.concatenate([self.get_obs(), np.array([0])]) | ||||||
|  |         return obs, -invalid_penalty, True, { | ||||||
|  |         "hit_ball": [False], | ||||||
|  |         "ball_returned_success": [False], | ||||||
|  |         "land_dist_error": [10.], | ||||||
|  |         "is_success": [False], | ||||||
|  |         'trajectory_length': 1, | ||||||
|  |         "num_steps": [1] | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|     def episode_callback(self, action, pos_traj, vel_traj): |     def episode_callback(self, action, pos_traj, vel_traj): | ||||||
|         time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \ |         time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \ | ||||||
|                      or action[1] > delay_bound[1] or action[1] < delay_bound[0] |                      or action[1] > delay_bound[1] or action[1] < delay_bound[0] | ||||||
|         if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low): |         if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low): | ||||||
|             return False |  | ||||||
|             return True |             return True | ||||||
|  |         return False | ||||||
| 
 | 
 | ||||||
|     def invalid_traj_callback(self, action, pos_traj: np.ndarray, vel_traj: np.ndarray) \ |     def invalid_traj_callback(self, action, pos_traj: np.ndarray, vel_traj: np.ndarray) \ | ||||||
|             -> Tuple[np.ndarray, float, bool, dict]: |             -> Tuple[np.ndarray, float, bool, dict]: | ||||||
| @ -43,7 +62,8 @@ class MPWrapper(RawInterfaceWrapper): | |||||||
|         violate_low_bound_error = np.mean(np.maximum(jnt_pos_low - pos_traj, 0)) |         violate_low_bound_error = np.mean(np.maximum(jnt_pos_low - pos_traj, 0)) | ||||||
|         invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \ |         invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \ | ||||||
|                           violate_high_bound_error + violate_low_bound_error |                           violate_high_bound_error + violate_low_bound_error | ||||||
|         return self.get_obs(), -invalid_penalty, True, { |         obs = np.concatenate([self.get_obs(), np.array([0])]) | ||||||
|  |         return obs, -invalid_penalty, True, { | ||||||
|         "hit_ball": [False], |         "hit_ball": [False], | ||||||
|         "ball_returned_success": [False], |         "ball_returned_success": [False], | ||||||
|         "land_dist_error": [10.], |         "land_dist_error": [10.], | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user