From bd7e811a645eadc1b21611887e942e120cdd7033 Mon Sep 17 00:00:00 2001
From: "hongyi.zhou"
Date: Mon, 3 Jul 2023 17:19:41 +0200
Subject: [PATCH] fix tau bound and init bound bug

---
 fancy_gym/black_box/black_box_wrapper.py          | 14 ++++++++++++--
 fancy_gym/black_box/raw_interface_wrapper.py      | 10 ++++++++--
 fancy_gym/envs/__init__.py                        |  4 ++--
 fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py  |  9 +++++----
 .../envs/mujoco/table_tennis/table_tennis_env.py  | 10 +++++-----
 5 files changed, 32 insertions(+), 15 deletions(-)

diff --git a/fancy_gym/black_box/black_box_wrapper.py b/fancy_gym/black_box/black_box_wrapper.py
index 0c108d7..7c33428 100644
--- a/fancy_gym/black_box/black_box_wrapper.py
+++ b/fancy_gym/black_box/black_box_wrapper.py
@@ -55,6 +55,14 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         # self.traj_gen.set_mp_times(self.time_steps)
         self.traj_gen.set_duration(self.duration, self.dt)

+        # default to unbounded tau/delay; use the phase generator's bounds if it defines them
+        self.tau_bound = [-np.inf, np.inf]
+        self.delay_bound = [-np.inf, np.inf]
+        if hasattr(self.traj_gen.phase_gn, "tau_bound"):
+            self.tau_bound = self.traj_gen.phase_gn.tau_bound
+        if hasattr(self.traj_gen.phase_gn, "delay_bound"):
+            self.delay_bound = self.traj_gen.phase_gn.delay_bound
+
         # reward computation
         self.reward_aggregation = reward_aggregation

@@ -139,7 +147,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):

             position, velocity = self.get_trajectory(action)
             position, velocity = self.env.set_episode_arguments(action, position, velocity)
-            traj_is_valid, position, velocity = self.env.preprocessing_and_validity_callback(action, position, velocity)
+            traj_is_valid, position, velocity = self.env.preprocessing_and_validity_callback(action, position, velocity,
+                                                                                             self.tau_bound, self.delay_bound)

             trajectory_length = len(position)
             rewards = np.zeros(shape=(trajectory_length,))
@@ -153,7 +162,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):

             if not traj_is_valid:
                 obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity,
-                                                                                     self.return_context_observation)
+                                                                                     self.return_context_observation,
+                                                                                     self.tau_bound, self.delay_bound)
                 return self.observation(obs), trajectory_return, done, infos

             self.plan_steps += 1
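The new attributes make the wrapper tolerant of phase generators that do not expose learnable timing parameters: such a generator defines neither tau_bound nor delay_bound, so the wrapper falls back to an unbounded range and the validity check can never fail on timing grounds. A minimal, self-contained sketch of this lookup pattern, assuming only that the phase generator is reachable as traj_gen.phase_gn as in the patch (resolve_timing_bounds is an illustrative helper, not repo API):

    import numpy as np

    def resolve_timing_bounds(phase_gn) -> tuple:
        """Return (tau_bound, delay_bound), defaulting to unbounded ranges."""
        tau_bound = [-np.inf, np.inf]
        delay_bound = [-np.inf, np.inf]
        if hasattr(phase_gn, "tau_bound"):
            tau_bound = phase_gn.tau_bound
        if hasattr(phase_gn, "delay_bound"):
            delay_bound = phase_gn.delay_bound
        return tau_bound, delay_bound

diff --git a/fancy_gym/black_box/raw_interface_wrapper.py b/fancy_gym/black_box/raw_interface_wrapper.py
index 7647924..c8f7273 100644
--- a/fancy_gym/black_box/raw_interface_wrapper.py
+++ b/fancy_gym/black_box/raw_interface_wrapper.py
@@ -52,7 +52,8 @@ class RawInterfaceWrapper(gym.Wrapper):
         """
         return self.env.dt

-    def preprocessing_and_validity_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray) \
+    def preprocessing_and_validity_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray,
+                                            tau_bound: list = None, delay_bound: list = None) \
             -> Tuple[bool, np.ndarray, np.ndarray]:
         """
         Used to preprocess the action and check if the desired trajectory is valid.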
@@ -61,6 +62,8 @@ class RawInterfaceWrapper(gym.Wrapper):
                 specified, else only traj_gen parameters
             pos_traj: a vector instance of the raw position trajectory
             vel_traj: a vector instance of the raw velocity trajectory
+            tau_bound: a list of two elements, the lower and upper bound of the trajectory length scaling factor
+            delay_bound: a list of two elements, the lower and upper bound of the time to wait before execution
         Returns:
             validity flag: bool, True if the raw trajectory is valid, False if not
             pos_traj: a vector instance of the preprocessed position trajectory
@@ -97,7 +100,8 @@ class RawInterfaceWrapper(gym.Wrapper):
         """
         return True

-    def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray) -> Tuple[np.ndarray, float, bool, dict]:
+    def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray,
+                              tau_bound: list, delay_bound: list) -> Tuple[np.ndarray, float, bool, dict]:
         """
         Used to return an artificial return from the env if the desired trajectory is invalid.
         Args:
@@ -105,6 +109,8 @@ class RawInterfaceWrapper(gym.Wrapper):
                 specified, else only traj_gen parameters
             pos_traj: a vector instance of the raw position trajectory
             vel_traj: a vector instance of the raw velocity trajectory
+            tau_bound: a list of two elements, the lower and upper bound of the trajectory length scaling factor
+            delay_bound: a list of two elements, the lower and upper bound of the time to wait before execution
         Returns:
             obs: artificial observation if the trajectory is invalid, by default a zero vector
             reward: artificial reward if the trajectory is invalid, by default 0
diff --git a/fancy_gym/envs/__init__.py b/fancy_gym/envs/__init__.py
index c23e879..b5bc154 100644
--- a/fancy_gym/envs/__init__.py
+++ b/fancy_gym/envs/__init__.py
@@ -557,8 +557,8 @@ for _v in _versions:
     kwargs_dict_tt_promp['name'] = _v
     kwargs_dict_tt_promp['controller_kwargs']['p_gains'] = 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0])
     kwargs_dict_tt_promp['controller_kwargs']['d_gains'] = 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1])
-    kwargs_dict_tt_promp['phase_generator_kwargs']['learn_tau'] = False
-    kwargs_dict_tt_promp['phase_generator_kwargs']['learn_delay'] = False
+    kwargs_dict_tt_promp['phase_generator_kwargs']['learn_tau'] = True
+    kwargs_dict_tt_promp['phase_generator_kwargs']['learn_delay'] = True
     kwargs_dict_tt_promp['phase_generator_kwargs']['tau_bound'] = [0.8, 1.5]
     kwargs_dict_tt_promp['phase_generator_kwargs']['delay_bound'] = [0.05, 0.15]
     kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis'] = 3
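With learn_tau and learn_delay enabled, the flat MP action gains two leading entries, which the table-tennis callbacks below read as action[0] (tau, the trajectory length scaling factor) and action[1] (delay, the wait time before execution). A hedged sketch of that decomposition, assuming this leading-entry layout (the exact layout is owned by the trajectory generator; split_action is an illustrative helper, not repo API):

    import numpy as np

    TAU_BOUND = [0.8, 1.5]      # values from phase_generator_kwargs above
    DELAY_BOUND = [0.05, 0.15]

    def split_action(action: np.ndarray):
        """Illustrative: peel tau and delay off a flat MP action vector."""
        tau, delay, weights = action[0], action[1], action[2:]
        timing_valid = (TAU_BOUND[0] <= tau <= TAU_BOUND[1]
                        and DELAY_BOUND[0] <= delay <= DELAY_BOUND[1])
        return tau, delay, weights, timing_valid

diff --git a/fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py b/fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py
index e33ed6c..3370047 100644
--- a/fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py
+++ b/fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py
@@ -29,15 +29,16 @@ class TT_MPWrapper(RawInterfaceWrapper):
     def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
         return self.data.qvel[:7].copy()

-    def preprocessing_and_validity_callback(self, action, pos_traj, vel_traj):
-        return self.check_traj_validity(action, pos_traj, vel_traj)
+    def preprocessing_and_validity_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray,
+                                            tau_bound: list, delay_bound: list):
+        return self.check_traj_validity(action, pos_traj, vel_traj, tau_bound, delay_bound)

     def set_episode_arguments(self, action, pos_traj, vel_traj):
         return pos_traj, vel_traj

     def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray,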
-                              return_contextual_obs: bool) -> Tuple[np.ndarray, float, bool, dict]:
-        return self.get_invalid_traj_step_return(action, pos_traj, return_contextual_obs)
+                              return_contextual_obs: bool, tau_bound: list, delay_bound: list) -> Tuple[np.ndarray, float, bool, dict]:
+        return self.get_invalid_traj_step_return(action, pos_traj, return_contextual_obs, tau_bound, delay_bound)


 class TTVelObs_MPWrapper(TT_MPWrapper):
diff --git a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py
index 7fb5e9f..3f86463 100644
--- a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py
+++ b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py
@@ -5,7 +5,7 @@ from gym import utils, spaces
 from gym.envs.mujoco import MujocoEnv

 from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import is_init_state_valid, magnus_force
-from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import jnt_pos_low, jnt_pos_high, delay_bound, tau_bound
+from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import jnt_pos_low, jnt_pos_high

 import mujoco

@@ -225,7 +225,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
             init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel)
         return init_ball_state

-    def _get_traj_invalid_penalty(self, action, pos_traj):
+    def _get_traj_invalid_penalty(self, action, pos_traj, tau_bound, delay_bound):
         tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
         delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
         violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0))
@@ -234,9 +234,9 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
                           violate_high_bound_error + violate_low_bound_error
         return -invalid_penalty

-    def get_invalid_traj_step_return(self, action, pos_traj, contextual_obs):
+    def get_invalid_traj_step_return(self, action, pos_traj, contextual_obs, tau_bound, delay_bound):
         obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])])  # 0 for invalid traj
-        penalty = self._get_traj_invalid_penalty(action, pos_traj)
+        penalty = self._get_traj_invalid_penalty(action, pos_traj, tau_bound, delay_bound)
         return obs, penalty, True, {
             "hit_ball": [False],
             "ball_returned_success": [False],
@@ -247,7 +247,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
         }

     @staticmethod
-    def check_traj_validity(action, pos_traj, vel_traj):
+    def check_traj_validity(action, pos_traj, vel_traj, tau_bound, delay_bound):
         time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \
                        or action[1] > delay_bound[1] or action[1] < delay_bound[0]
         if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low):
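For reference, the penalty now threaded through these callbacks reduces to hinge terms that grow linearly with how far tau, delay, or any joint position leaves its bound. A self-contained sketch under that reading, with the joint limits passed in explicitly (in the repo they come from table_tennis_utils):

    import numpy as np

    def traj_invalid_penalty(action, pos_traj, tau_bound, delay_bound,
                             jnt_pos_low, jnt_pos_high):
        """Mirrors _get_traj_invalid_penalty: hinge losses on timing and joints."""
        tau_pen = 3 * (max(0.0, action[0] - tau_bound[1]) + max(0.0, tau_bound[0] - action[0]))
        delay_pen = 3 * (max(0.0, action[1] - delay_bound[1]) + max(0.0, delay_bound[0] - action[1]))
        high_err = np.mean(np.maximum(pos_traj - jnt_pos_high, 0))
        low_err = np.mean(np.maximum(jnt_pos_low - pos_traj, 0))
        return -(tau_pen + delay_pen + high_err + low_err)

    # e.g. tau = 1.7 against tau_bound = [0.8, 1.5] contributes 3 * 0.2 = 0.6,
    # so an otherwise clean but too-slow trajectory yields a return of -0.6.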