resample the context if target is too near to the start position
This commit is contained in:
		
							parent
							
								
									eec171e04a
								
							
						
					
					
						commit
						1fd4a1e848
					
				| @ -81,9 +81,10 @@ class BlackBoxWrapper(gym.ObservationWrapper): | |||||||
|         bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt) |         bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt) | ||||||
|         # TODO we could think about initializing with the previous desired value in order to have a smooth transition |         # TODO we could think about initializing with the previous desired value in order to have a smooth transition | ||||||
|         #  at least from the planning point of view. |         #  at least from the planning point of view. | ||||||
|         self.traj_gen.set_boundary_conditions(torch.as_tensor(bc_time), | 
 | ||||||
|                                               torch.as_tensor(self.current_pos), |         self.traj_gen.set_boundary_conditions(bc_time, | ||||||
|                                               torch.as_tensor(self.current_vel)) |                                                 self.current_pos, | ||||||
|  |                                                 self.current_vel) | ||||||
|         duration = None if self.learn_sub_trajectories else self.duration |         duration = None if self.learn_sub_trajectories else self.duration | ||||||
|         self.traj_gen.set_duration(duration, self.dt) |         self.traj_gen.set_duration(duration, self.dt) | ||||||
|         # traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True) |         # traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True) | ||||||
|  | |||||||
| @ -369,7 +369,6 @@ for _v in _versions: | |||||||
|         kwargs=kwargs_dict_reacher_promp |         kwargs=kwargs_dict_reacher_promp | ||||||
|     ) |     ) | ||||||
|     ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) |     ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) | ||||||
| 
 |  | ||||||
| ######################################################################################################################## | ######################################################################################################################## | ||||||
| ## Beerpong ProMP | ## Beerpong ProMP | ||||||
| _versions = ['BeerPong-v0'] | _versions = ['BeerPong-v0'] | ||||||
| @ -485,7 +484,7 @@ for _v in _versions: | |||||||
|     kwargs_dict_box_pushing_promp['name'] = _v |     kwargs_dict_box_pushing_promp['name'] = _v | ||||||
|     kwargs_dict_box_pushing_promp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.]) |     kwargs_dict_box_pushing_promp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.]) | ||||||
|     kwargs_dict_box_pushing_promp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.]) |     kwargs_dict_box_pushing_promp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.]) | ||||||
|     kwargs_dict_box_pushing_promp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3.5 # 3.5, 4 to try |     kwargs_dict_box_pushing_promp['basis_generator_kwargs']['basis_bandwidth_factor'] = 2 # 3.5, 4 to try | ||||||
| 
 | 
 | ||||||
|     register( |     register( | ||||||
|         id=_env_id, |         id=_env_id, | ||||||
| @ -502,7 +501,7 @@ for _v in _versions: | |||||||
|     kwargs_dict_box_pushing_prodmp['name'] = _v |     kwargs_dict_box_pushing_prodmp['name'] = _v | ||||||
|     kwargs_dict_box_pushing_prodmp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.]) |     kwargs_dict_box_pushing_prodmp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.]) | ||||||
|     kwargs_dict_box_pushing_prodmp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.]) |     kwargs_dict_box_pushing_prodmp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.]) | ||||||
| 
 |     kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t: t % 25 == 0 | ||||||
|     register( |     register( | ||||||
|         id=_env_id, |         id=_env_id, | ||||||
|         entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', |         entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', | ||||||
|  | |||||||
| @ -94,6 +94,8 @@ class BoxPushingEnv(MujocoEnv, utils.EzPickle): | |||||||
| 
 | 
 | ||||||
|         # set target position |         # set target position | ||||||
|         box_target_pos = self.sample_context() |         box_target_pos = self.sample_context() | ||||||
|  |         while np.linalg.norm(box_target_pos[:2] - box_init_pos[:2]) < 0.3: | ||||||
|  |             box_target_pos = self.sample_context() | ||||||
|         #box_target_pos[0] = 0.4 |         #box_target_pos[0] = 0.4 | ||||||
|         #box_target_pos[1] = -0.3 |         #box_target_pos[1] = -0.3 | ||||||
|         #box_target_pos[-4:] = np.array([0.0, 0.0, 0.0, 1.0]) |         #box_target_pos[-4:] = np.array([0.0, 0.0, 0.0, 1.0]) | ||||||
|  | |||||||
| @ -24,7 +24,7 @@ def example_mp(env_name="HoleReacherProMP-v0", seed=1, iterations=1, render=True | |||||||
|     # number of samples/full trajectories (multiple environment steps) |     # number of samples/full trajectories (multiple environment steps) | ||||||
|     for i in range(iterations): |     for i in range(iterations): | ||||||
| 
 | 
 | ||||||
|         if render and i % 2 == 0: |         if render and i % 1 == 0: | ||||||
|             # This renders the full MP trajectory |             # This renders the full MP trajectory | ||||||
|             # It is only required to call render() once in the beginning, which renders every consecutive trajectory. |             # It is only required to call render() once in the beginning, which renders every consecutive trajectory. | ||||||
|             # Resetting to no rendering, can be achieved by render(mode=None). |             # Resetting to no rendering, can be achieved by render(mode=None). | ||||||
| @ -161,10 +161,10 @@ if __name__ == '__main__': | |||||||
| 
 | 
 | ||||||
|     # ProMP |     # ProMP | ||||||
|     # example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render) |     # example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render) | ||||||
|     example_mp("BoxPushingDenseProMP-v0", seed=10, iterations=50, render=render) |     # example_mp("BoxPushingDenseProMP-v0", seed=10, iterations=50, render=render) | ||||||
| 
 | 
 | ||||||
|     # ProDMP |     # ProDMP | ||||||
|     # example_mp("BoxPushingDenseProDMP-v0", seed=10, iterations=50, render=render) |     example_mp("BoxPushingDenseProDMP-v0", seed=10, iterations=50, render=render) | ||||||
| 
 | 
 | ||||||
|     # Altered basis functions |     # Altered basis functions | ||||||
|     # obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render) |     # obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render) | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user