check time validity before pos validity
This commit is contained in:
		
							parent
							
								
									d384e6e764
								
							
						
					
					
						commit
						b9c2348855
					
				| @ -164,6 +164,9 @@ class BlackBoxWrapper(gym.ObservationWrapper): | ||||
| 
 | ||||
|     def step(self, action: np.ndarray): | ||||
|         """ This function generates a trajectory based on a MP and then does the usual loop over reset and step""" | ||||
|         time_valid = self.env.check_time_validity(action) | ||||
| 
 | ||||
|         if time_valid: | ||||
| 
 | ||||
|             if self.plan_counts == 0: | ||||
|                 self.tau_first_prediction = action[0] | ||||
| @ -253,6 +256,9 @@ class BlackBoxWrapper(gym.ObservationWrapper): | ||||
|             else: | ||||
|                 obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity) | ||||
|                 return self.observation(obs), trajectory_return, done, infos | ||||
|         else: | ||||
|                 obs, trajectory_return, done, infos = self.env.time_invalid_traj_callback(action) | ||||
|                 return self.observation(obs), trajectory_return, done, infos | ||||
|     def render(self, **kwargs): | ||||
|         """Only set render options here, such that they can be used during the rollout. | ||||
|         This only needs to be called once""" | ||||
|  | ||||
| @ -560,7 +560,7 @@ for _v in _versions: | ||||
|     _name = _v.split("-") | ||||
|     _env_id = f'{_name[0]}ProDMP-{_name[1]}' | ||||
|     kwargs_dict_tt_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP) | ||||
|     kwargs_dict_tt_prodmp['wrappers'].append(mujoco.box_pushing.MPWrapper) | ||||
|     kwargs_dict_tt_prodmp['wrappers'].append(mujoco.table_tennis.MPWrapper) | ||||
|     kwargs_dict_tt_prodmp['name'] = _v | ||||
|     kwargs_dict_tt_prodmp['controller_kwargs']['p_gains'] = 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0]) | ||||
|     kwargs_dict_tt_prodmp['controller_kwargs']['d_gains'] = 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1]) | ||||
|  | ||||
| @ -28,12 +28,31 @@ class MPWrapper(RawInterfaceWrapper): | ||||
|     def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: | ||||
|         return self.data.qvel[:7].copy() | ||||
| 
 | ||||
|     def check_time_validity(self, action): | ||||
|         return action[0] <= tau_bound[1] and action[0] >= tau_bound[0] \ | ||||
|                and action[1] <= delay_bound[1] and action[1] >= delay_bound[0] | ||||
| 
 | ||||
|     def time_invalid_traj_callback(self, action) \ | ||||
|         -> Tuple[np.ndarray, float, bool, dict]: | ||||
|         tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]])) | ||||
|         delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]])) | ||||
|         invalid_penalty = tau_invalid_penalty + delay_invalid_penalty | ||||
|         obs = np.concatenate([self.get_obs(), np.array([0])]) | ||||
|         return obs, -invalid_penalty, True, { | ||||
|         "hit_ball": [False], | ||||
|         "ball_returned_success": [False], | ||||
|         "land_dist_error": [10.], | ||||
|         "is_success": [False], | ||||
|         'trajectory_length': 1, | ||||
|         "num_steps": [1] | ||||
|         } | ||||
| 
 | ||||
|     def episode_callback(self, action, pos_traj, vel_traj): | ||||
|         time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \ | ||||
|                      or action[1] > delay_bound[1] or action[1] < delay_bound[0] | ||||
|         if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low): | ||||
|             return False | ||||
|             return True | ||||
|         return False | ||||
| 
 | ||||
|     def invalid_traj_callback(self, action, pos_traj: np.ndarray, vel_traj: np.ndarray) \ | ||||
|             -> Tuple[np.ndarray, float, bool, dict]: | ||||
| @ -43,7 +62,8 @@ class MPWrapper(RawInterfaceWrapper): | ||||
|         violate_low_bound_error = np.mean(np.maximum(jnt_pos_low - pos_traj, 0)) | ||||
|         invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \ | ||||
|                           violate_high_bound_error + violate_low_bound_error | ||||
|         return self.get_obs(), -invalid_penalty, True, { | ||||
|         obs = np.concatenate([self.get_obs(), np.array([0])]) | ||||
|         return obs, -invalid_penalty, True, { | ||||
|         "hit_ball": [False], | ||||
|         "ball_returned_success": [False], | ||||
|         "land_dist_error": [10.], | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user