resample the context if target is too near to the start position

This commit is contained in:
Hongyi Zhou 2022-10-13 14:58:28 +02:00
parent eec171e04a
commit 1fd4a1e848
4 changed files with 11 additions and 9 deletions

View File

@ -81,9 +81,10 @@ class BlackBoxWrapper(gym.ObservationWrapper):
bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt) bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt)
# TODO we could think about initializing with the previous desired value in order to have a smooth transition # TODO we could think about initializing with the previous desired value in order to have a smooth transition
# at least from the planning point of view. # at least from the planning point of view.
self.traj_gen.set_boundary_conditions(torch.as_tensor(bc_time),
torch.as_tensor(self.current_pos), self.traj_gen.set_boundary_conditions(bc_time,
torch.as_tensor(self.current_vel)) self.current_pos,
self.current_vel)
duration = None if self.learn_sub_trajectories else self.duration duration = None if self.learn_sub_trajectories else self.duration
self.traj_gen.set_duration(duration, self.dt) self.traj_gen.set_duration(duration, self.dt)
# traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True) # traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)

View File

@ -369,7 +369,6 @@ for _v in _versions:
kwargs=kwargs_dict_reacher_promp kwargs=kwargs_dict_reacher_promp
) )
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
######################################################################################################################## ########################################################################################################################
## Beerpong ProMP ## Beerpong ProMP
_versions = ['BeerPong-v0'] _versions = ['BeerPong-v0']
@ -485,7 +484,7 @@ for _v in _versions:
kwargs_dict_box_pushing_promp['name'] = _v kwargs_dict_box_pushing_promp['name'] = _v
kwargs_dict_box_pushing_promp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.]) kwargs_dict_box_pushing_promp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.])
kwargs_dict_box_pushing_promp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.]) kwargs_dict_box_pushing_promp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.])
kwargs_dict_box_pushing_promp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3.5 # 3.5, 4 to try kwargs_dict_box_pushing_promp['basis_generator_kwargs']['basis_bandwidth_factor'] = 2 # 3.5, 4 to try
register( register(
id=_env_id, id=_env_id,
@ -502,7 +501,7 @@ for _v in _versions:
kwargs_dict_box_pushing_prodmp['name'] = _v kwargs_dict_box_pushing_prodmp['name'] = _v
kwargs_dict_box_pushing_prodmp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.]) kwargs_dict_box_pushing_prodmp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.])
kwargs_dict_box_pushing_prodmp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.]) kwargs_dict_box_pushing_prodmp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.])
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t: t % 25 == 0
register( register(
id=_env_id, id=_env_id,
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',

View File

@ -94,6 +94,8 @@ class BoxPushingEnv(MujocoEnv, utils.EzPickle):
# set target position # set target position
box_target_pos = self.sample_context() box_target_pos = self.sample_context()
while np.linalg.norm(box_target_pos[:2] - box_init_pos[:2]) < 0.3:
box_target_pos = self.sample_context()
#box_target_pos[0] = 0.4 #box_target_pos[0] = 0.4
#box_target_pos[1] = -0.3 #box_target_pos[1] = -0.3
#box_target_pos[-4:] = np.array([0.0, 0.0, 0.0, 1.0]) #box_target_pos[-4:] = np.array([0.0, 0.0, 0.0, 1.0])

View File

@ -24,7 +24,7 @@ def example_mp(env_name="HoleReacherProMP-v0", seed=1, iterations=1, render=True
# number of samples/full trajectories (multiple environment steps) # number of samples/full trajectories (multiple environment steps)
for i in range(iterations): for i in range(iterations):
if render and i % 2 == 0: if render and i % 1 == 0:
# This renders the full MP trajectory # This renders the full MP trajectory
# It is only required to call render() once in the beginning, which renders every consecutive trajectory. # It is only required to call render() once in the beginning, which renders every consecutive trajectory.
# Resetting to no rendering, can be achieved by render(mode=None). # Resetting to no rendering, can be achieved by render(mode=None).
@ -161,10 +161,10 @@ if __name__ == '__main__':
# ProMP # ProMP
# example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render) # example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
example_mp("BoxPushingDenseProMP-v0", seed=10, iterations=50, render=render) # example_mp("BoxPushingDenseProMP-v0", seed=10, iterations=50, render=render)
# ProDMP # ProDMP
# example_mp("BoxPushingDenseProDMP-v0", seed=10, iterations=50, render=render) example_mp("BoxPushingDenseProDMP-v0", seed=10, iterations=50, render=render)
# Altered basis functions # Altered basis functions
# obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render) # obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)