diff --git a/fancy_gym/envs/__init__.py b/fancy_gym/envs/__init__.py index 2fddbf3..ea9aec7 100644 --- a/fancy_gym/envs/__init__.py +++ b/fancy_gym/envs/__init__.py @@ -571,7 +571,7 @@ for _v in _versions: for _v in _versions: _name = _v.split("-") - _env_id = f'{_name[0]}ProDMP-{_name[1]}' + _env_id = f'{_name[0]}ReplanProDMP-{_name[1]}' kwargs_dict_tt_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP) if _v == 'TableTennisWind-v0': kwargs_dict_tt_prodmp['wrappers'].append(mujoco.table_tennis.TTVelObs_MPWrapper) @@ -580,8 +580,6 @@ for _v in _versions: kwargs_dict_tt_prodmp['name'] = _v kwargs_dict_tt_prodmp['controller_kwargs']['p_gains'] = 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0]) kwargs_dict_tt_prodmp['controller_kwargs']['d_gains'] = 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1]) - kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['weights_scale'] = 1.0 - kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['goal_scale'] = 1.0 kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['auto_scale_basis'] = False kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['goal_offset'] = 1.0 kwargs_dict_tt_prodmp['phase_generator_kwargs']['tau_bound'] = [0.8, 1.5] @@ -590,7 +588,7 @@ for _v in _versions: kwargs_dict_tt_prodmp['phase_generator_kwargs']['learn_delay'] = True kwargs_dict_tt_prodmp['basis_generator_kwargs']['num_basis'] = 2 kwargs_dict_tt_prodmp['basis_generator_kwargs']['alpha'] = 25. - kwargs_dict_tt_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3 # 3.5, 4 to try + kwargs_dict_tt_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3 kwargs_dict_tt_prodmp['phase_generator_kwargs']['alpha_phase'] = 3 kwargs_dict_tt_prodmp['black_box_kwargs']['max_planning_times'] = 3 kwargs_dict_tt_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 50 == 0 diff --git a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py index 3f30256..dc717c2 100644 --- a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py +++ b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py @@ -191,9 +191,6 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): self.data.joint("tar_x").qpos.copy(), self.data.joint("tar_y").qpos.copy(), self.data.joint("tar_z").qpos.copy(), - # self.data.joint("tar_x").qvel.copy(), - # self.data.joint("tar_y").qvel.copy(), - # self.data.joint("tar_z").qvel.copy(), self._goal_pos.copy(), ]) return obs @@ -234,7 +231,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel) return init_ball_state - def _get_traj_invalid_reward(self, action, pos_traj): + def _get_traj_invalid_penalty(self, action, pos_traj): tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]])) delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]])) violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0)) @@ -245,7 +242,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): def get_invalid_traj_step_return(self, action, pos_traj, contextual_obs): obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])]) # 0 for invalid traj - penalty = self._get_traj_invalid_reward(action, pos_traj) + penalty = self._get_traj_invalid_penalty(action, pos_traj) return obs, penalty, True, { "hit_ball": [False], "ball_returned_success": [False], diff --git a/fancy_gym/examples/examples_movement_primitives.py b/fancy_gym/examples/examples_movement_primitives.py index 4aeeecc..b533e9c 100644 --- a/fancy_gym/examples/examples_movement_primitives.py +++ b/fancy_gym/examples/examples_movement_primitives.py @@ -155,25 +155,22 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True): if __name__ == '__main__': - render = True + render = False # DMP - # example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render) + example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render) # ProMP - # example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render) - # example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render) + example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render) + example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render) example_mp("TableTennis4DProMP-v0", seed=10, iterations=20, render=render) - example_mp("TableTennisWindProMP-v0", seed=10, iterations=20, render=render) - example_mp("TableTennisGoalSwitchingProMP-v0", seed=10, iterations=20, render=render) # ProDMP with Replanning - # example_mp("BoxPushingDenseReplanProDMP-v0", seed=10, iterations=4, render=render) - # example_mp("TableTennis4DProDMP-v0", seed=10, iterations=100, render=render) - # example_mp("TableTennisWindProDMP-v0", seed=10, iterations=100, render=render) - # example_mp("TableTennisGoalSwitchingProDMP-v0", seed=10, iterations=100, render=render) + example_mp("BoxPushingDenseReplanProDMP-v0", seed=10, iterations=4, render=render) + example_mp("TableTennis4DReplanProDMP-v0", seed=10, iterations=20, render=render) + example_mp("TableTennisWindReplanProDMP-v0", seed=10, iterations=20, render=render) # Altered basis functions - # obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render) + obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render) # Custom MP - # example_fully_custom_mp(seed=10, iterations=1, render=render) + example_fully_custom_mp(seed=10, iterations=1, render=render)