From 1fd4a1e848500cd3dfce92654555d249eaff6dcf Mon Sep 17 00:00:00 2001 From: Hongyi Zhou Date: Thu, 13 Oct 2022 14:58:28 +0200 Subject: [PATCH] resample the context if target is too near to the start position --- fancy_gym/black_box/black_box_wrapper.py | 7 ++++--- fancy_gym/envs/__init__.py | 5 ++--- fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py | 2 ++ fancy_gym/examples/examples_movement_primitives.py | 6 +++--- 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/fancy_gym/black_box/black_box_wrapper.py b/fancy_gym/black_box/black_box_wrapper.py index 2d2b037..fbcf58d 100644 --- a/fancy_gym/black_box/black_box_wrapper.py +++ b/fancy_gym/black_box/black_box_wrapper.py @@ -81,9 +81,10 @@ class BlackBoxWrapper(gym.ObservationWrapper): bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt) # TODO we could think about initializing with the previous desired value in order to have a smooth transition # at least from the planning point of view. - self.traj_gen.set_boundary_conditions(torch.as_tensor(bc_time), - torch.as_tensor(self.current_pos), - torch.as_tensor(self.current_vel)) + + self.traj_gen.set_boundary_conditions(bc_time, + self.current_pos, + self.current_vel) duration = None if self.learn_sub_trajectories else self.duration self.traj_gen.set_duration(duration, self.dt) # traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True) diff --git a/fancy_gym/envs/__init__.py b/fancy_gym/envs/__init__.py index 23bba25..9fc778e 100644 --- a/fancy_gym/envs/__init__.py +++ b/fancy_gym/envs/__init__.py @@ -369,7 +369,6 @@ for _v in _versions: kwargs=kwargs_dict_reacher_promp ) ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) - ######################################################################################################################## ## Beerpong ProMP _versions = ['BeerPong-v0'] @@ -485,7 +484,7 @@ for _v in _versions: kwargs_dict_box_pushing_promp['name'] = _v kwargs_dict_box_pushing_promp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.]) kwargs_dict_box_pushing_promp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.]) - kwargs_dict_box_pushing_promp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3.5 # 3.5, 4 to try + kwargs_dict_box_pushing_promp['basis_generator_kwargs']['basis_bandwidth_factor'] = 2 # 3.5, 4 to try register( id=_env_id, @@ -502,7 +501,7 @@ for _v in _versions: kwargs_dict_box_pushing_prodmp['name'] = _v kwargs_dict_box_pushing_prodmp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.]) kwargs_dict_box_pushing_prodmp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.]) - + kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t: t % 25 == 0 register( id=_env_id, entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', diff --git a/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py b/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py index e70ecec..aff2a16 100644 --- a/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py +++ b/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py @@ -94,6 +94,8 @@ class BoxPushingEnv(MujocoEnv, utils.EzPickle): # set target position box_target_pos = self.sample_context() + while np.linalg.norm(box_target_pos[:2] - box_init_pos[:2]) < 0.3: + box_target_pos = self.sample_context() #box_target_pos[0] = 0.4 #box_target_pos[1] = -0.3 #box_target_pos[-4:] = np.array([0.0, 0.0, 0.0, 1.0]) diff --git a/fancy_gym/examples/examples_movement_primitives.py b/fancy_gym/examples/examples_movement_primitives.py index cda1064..56f6de0 100644 --- a/fancy_gym/examples/examples_movement_primitives.py +++ b/fancy_gym/examples/examples_movement_primitives.py @@ -24,7 +24,7 @@ def example_mp(env_name="HoleReacherProMP-v0", seed=1, iterations=1, render=True # number of samples/full trajectories (multiple environment steps) for i in range(iterations): - if render and i % 2 == 0: + if render and i % 1 == 0: # This renders the full MP trajectory # It is only required to call render() once in the beginning, which renders every consecutive trajectory. # Resetting to no rendering, can be achieved by render(mode=None). @@ -161,10 +161,10 @@ if __name__ == '__main__': # ProMP # example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render) - example_mp("BoxPushingDenseProMP-v0", seed=10, iterations=50, render=render) + # example_mp("BoxPushingDenseProMP-v0", seed=10, iterations=50, render=render) # ProDMP - # example_mp("BoxPushingDenseProDMP-v0", seed=10, iterations=50, render=render) + example_mp("BoxPushingDenseProDMP-v0", seed=10, iterations=50, render=render) # Altered basis functions # obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)