resample the context if target is too near to the start position
This commit is contained in:
parent
eec171e04a
commit
1fd4a1e848
@ -81,9 +81,10 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt)
|
bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt)
|
||||||
# TODO we could think about initializing with the previous desired value in order to have a smooth transition
|
# TODO we could think about initializing with the previous desired value in order to have a smooth transition
|
||||||
# at least from the planning point of view.
|
# at least from the planning point of view.
|
||||||
self.traj_gen.set_boundary_conditions(torch.as_tensor(bc_time),
|
|
||||||
torch.as_tensor(self.current_pos),
|
self.traj_gen.set_boundary_conditions(bc_time,
|
||||||
torch.as_tensor(self.current_vel))
|
self.current_pos,
|
||||||
|
self.current_vel)
|
||||||
duration = None if self.learn_sub_trajectories else self.duration
|
duration = None if self.learn_sub_trajectories else self.duration
|
||||||
self.traj_gen.set_duration(duration, self.dt)
|
self.traj_gen.set_duration(duration, self.dt)
|
||||||
# traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
|
# traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
|
||||||
|
@ -369,7 +369,6 @@ for _v in _versions:
|
|||||||
kwargs=kwargs_dict_reacher_promp
|
kwargs=kwargs_dict_reacher_promp
|
||||||
)
|
)
|
||||||
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||||
|
|
||||||
########################################################################################################################
|
########################################################################################################################
|
||||||
## Beerpong ProMP
|
## Beerpong ProMP
|
||||||
_versions = ['BeerPong-v0']
|
_versions = ['BeerPong-v0']
|
||||||
@ -485,7 +484,7 @@ for _v in _versions:
|
|||||||
kwargs_dict_box_pushing_promp['name'] = _v
|
kwargs_dict_box_pushing_promp['name'] = _v
|
||||||
kwargs_dict_box_pushing_promp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.])
|
kwargs_dict_box_pushing_promp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.])
|
||||||
kwargs_dict_box_pushing_promp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.])
|
kwargs_dict_box_pushing_promp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.])
|
||||||
kwargs_dict_box_pushing_promp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3.5 # 3.5, 4 to try
|
kwargs_dict_box_pushing_promp['basis_generator_kwargs']['basis_bandwidth_factor'] = 2 # 3.5, 4 to try
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id=_env_id,
|
id=_env_id,
|
||||||
@ -502,7 +501,7 @@ for _v in _versions:
|
|||||||
kwargs_dict_box_pushing_prodmp['name'] = _v
|
kwargs_dict_box_pushing_prodmp['name'] = _v
|
||||||
kwargs_dict_box_pushing_prodmp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.])
|
kwargs_dict_box_pushing_prodmp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.])
|
||||||
kwargs_dict_box_pushing_prodmp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.])
|
kwargs_dict_box_pushing_prodmp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.])
|
||||||
|
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t: t % 25 == 0
|
||||||
register(
|
register(
|
||||||
id=_env_id,
|
id=_env_id,
|
||||||
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||||
|
@ -94,6 +94,8 @@ class BoxPushingEnv(MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
# set target position
|
# set target position
|
||||||
box_target_pos = self.sample_context()
|
box_target_pos = self.sample_context()
|
||||||
|
while np.linalg.norm(box_target_pos[:2] - box_init_pos[:2]) < 0.3:
|
||||||
|
box_target_pos = self.sample_context()
|
||||||
#box_target_pos[0] = 0.4
|
#box_target_pos[0] = 0.4
|
||||||
#box_target_pos[1] = -0.3
|
#box_target_pos[1] = -0.3
|
||||||
#box_target_pos[-4:] = np.array([0.0, 0.0, 0.0, 1.0])
|
#box_target_pos[-4:] = np.array([0.0, 0.0, 0.0, 1.0])
|
||||||
|
@ -24,7 +24,7 @@ def example_mp(env_name="HoleReacherProMP-v0", seed=1, iterations=1, render=True
|
|||||||
# number of samples/full trajectories (multiple environment steps)
|
# number of samples/full trajectories (multiple environment steps)
|
||||||
for i in range(iterations):
|
for i in range(iterations):
|
||||||
|
|
||||||
if render and i % 2 == 0:
|
if render and i % 1 == 0:
|
||||||
# This renders the full MP trajectory
|
# This renders the full MP trajectory
|
||||||
# It is only required to call render() once in the beginning, which renders every consecutive trajectory.
|
# It is only required to call render() once in the beginning, which renders every consecutive trajectory.
|
||||||
# Resetting to no rendering, can be achieved by render(mode=None).
|
# Resetting to no rendering, can be achieved by render(mode=None).
|
||||||
@ -161,10 +161,10 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
# ProMP
|
# ProMP
|
||||||
# example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
|
# example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
|
||||||
example_mp("BoxPushingDenseProMP-v0", seed=10, iterations=50, render=render)
|
# example_mp("BoxPushingDenseProMP-v0", seed=10, iterations=50, render=render)
|
||||||
|
|
||||||
# ProDMP
|
# ProDMP
|
||||||
# example_mp("BoxPushingDenseProDMP-v0", seed=10, iterations=50, render=render)
|
example_mp("BoxPushingDenseProDMP-v0", seed=10, iterations=50, render=render)
|
||||||
|
|
||||||
# Altered basis functions
|
# Altered basis functions
|
||||||
# obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
|
# obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
|
||||||
|
Loading…
Reference in New Issue
Block a user