Updated for the new set_duration function.
This commit is contained in:
		
							parent
							
								
									623adaacd8
								
							
						
					
					
						commit
						1ca14f1c93
					
				| @ -50,7 +50,7 @@ class BlackBoxWrapper(gym.ObservationWrapper): | |||||||
|         self.tracking_controller = tracking_controller |         self.tracking_controller = tracking_controller | ||||||
|         # self.time_steps = np.linspace(0, self.duration, self.traj_steps) |         # self.time_steps = np.linspace(0, self.duration, self.traj_steps) | ||||||
|         # self.traj_gen.set_mp_times(self.time_steps) |         # self.traj_gen.set_mp_times(self.time_steps) | ||||||
|         self.traj_gen.set_duration(self.duration - self.dt, self.dt) |         self.traj_gen.set_duration(self.duration, self.dt) | ||||||
| 
 | 
 | ||||||
|         # reward computation |         # reward computation | ||||||
|         self.reward_aggregation = reward_aggregation |         self.reward_aggregation = reward_aggregation | ||||||
| @ -78,11 +78,13 @@ class BlackBoxWrapper(gym.ObservationWrapper): | |||||||
|         bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt) |         bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt) | ||||||
|         self.traj_gen.set_boundary_conditions(bc_time, self.current_pos, self.current_vel) |         self.traj_gen.set_boundary_conditions(bc_time, self.current_pos, self.current_vel) | ||||||
|         # TODO: remove the - self.dt after Bruce's fix. |         # TODO: remove the - self.dt after Bruce's fix. | ||||||
|         self.traj_gen.set_duration(None if self.learn_sub_trajectories else self.duration - self.dt, self.dt) |         self.traj_gen.set_duration(None if self.learn_sub_trajectories else self.duration, self.dt) | ||||||
|         # traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True) |         # traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True) | ||||||
|         trajectory = get_numpy(self.traj_gen.get_traj_pos()) |         trajectory = get_numpy(self.traj_gen.get_traj_pos()) | ||||||
|         velocity = get_numpy(self.traj_gen.get_traj_vel()) |         velocity = get_numpy(self.traj_gen.get_traj_vel()) | ||||||
| 
 | 
 | ||||||
|  |         print(len(trajectory)) | ||||||
|  | 
 | ||||||
|         if self.do_replanning: |         if self.do_replanning: | ||||||
|             # Remove first part of trajectory as this is already over |             # Remove first part of trajectory as this is already over | ||||||
|             trajectory = trajectory[self.current_traj_steps:] |             trajectory = trajectory[self.current_traj_steps:] | ||||||
|  | |||||||
| @ -106,31 +106,31 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True): | |||||||
| 
 | 
 | ||||||
|     """ |     """ | ||||||
| 
 | 
 | ||||||
|     base_env_id = "HoleReacher-v0" |     base_env_id = "Reacher5d-v0" | ||||||
| 
 | 
 | ||||||
|     # Replace this wrapper with the custom wrapper for your environment by inheriting from the RawInterfaceWrapper. |     # Replace this wrapper with the custom wrapper for your environment by inheriting from the RawInterfaceWrapper. | ||||||
|     # You can also add other gym.Wrappers in case they are needed. |     # You can also add other gym.Wrappers in case they are needed. | ||||||
|     wrappers = [fancy_gym.envs.classic_control.hole_reacher.MPWrapper] |     wrappers = [fancy_gym.envs.mujoco.reacher.MPWrapper] | ||||||
| 
 | 
 | ||||||
|     # # For a ProMP |     # For a ProMP | ||||||
|     # trajectory_generator_kwargs = {'trajectory_generator_type': 'promp', |     trajectory_generator_kwargs = {'trajectory_generator_type': 'promp', | ||||||
|     #                                'weight_scale': 2} |                                    'weight_scale': 2} | ||||||
|     # phase_generator_kwargs = {'phase_generator_type': 'linear'} |     phase_generator_kwargs = {'phase_generator_type': 'linear'} | ||||||
|     # controller_kwargs = {'controller_type': 'velocity'} |  | ||||||
|     # basis_generator_kwargs = {'basis_generator_type': 'zero_rbf', |  | ||||||
|     #                           'num_basis': 5, |  | ||||||
|     #                           'num_basis_zero_start': 1 |  | ||||||
|     #                           } |  | ||||||
| 
 |  | ||||||
|     # For a DMP |  | ||||||
|     trajectory_generator_kwargs = {'trajectory_generator_type': 'dmp', |  | ||||||
|                                    'weight_scale': 500} |  | ||||||
|     phase_generator_kwargs = {'phase_generator_type': 'exp', |  | ||||||
|                               'alpha_phase': 2.5} |  | ||||||
|     controller_kwargs = {'controller_type': 'velocity'} |     controller_kwargs = {'controller_type': 'velocity'} | ||||||
|     basis_generator_kwargs = {'basis_generator_type': 'rbf', |     basis_generator_kwargs = {'basis_generator_type': 'zero_rbf', | ||||||
|                               'num_basis': 5 |                               'num_basis': 5, | ||||||
|  |                               'num_basis_zero_start': 1 | ||||||
|                               } |                               } | ||||||
|  | 
 | ||||||
|  |     # # For a DMP | ||||||
|  |     # trajectory_generator_kwargs = {'trajectory_generator_type': 'dmp', | ||||||
|  |     #                                'weight_scale': 500} | ||||||
|  |     # phase_generator_kwargs = {'phase_generator_type': 'exp', | ||||||
|  |     #                           'alpha_phase': 2.5} | ||||||
|  |     # controller_kwargs = {'controller_type': 'velocity'} | ||||||
|  |     # basis_generator_kwargs = {'basis_generator_type': 'rbf', | ||||||
|  |     #                           'num_basis': 5 | ||||||
|  |     #                           } | ||||||
|     env = fancy_gym.make_bb(env_id=base_env_id, wrappers=wrappers, black_box_kwargs={}, |     env = fancy_gym.make_bb(env_id=base_env_id, wrappers=wrappers, black_box_kwargs={}, | ||||||
|                             traj_gen_kwargs=trajectory_generator_kwargs, controller_kwargs=controller_kwargs, |                             traj_gen_kwargs=trajectory_generator_kwargs, controller_kwargs=controller_kwargs, | ||||||
|                             phase_kwargs=phase_generator_kwargs, basis_kwargs=basis_generator_kwargs, |                             phase_kwargs=phase_generator_kwargs, basis_kwargs=basis_generator_kwargs, | ||||||
| @ -155,15 +155,15 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True): | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     render = True |     render = False | ||||||
|     # DMP |     # DMP | ||||||
|     example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render) |     # example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render) | ||||||
|     # |     # | ||||||
|     # # ProMP |     # # ProMP | ||||||
|     example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render) |     # example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render) | ||||||
| 
 | 
 | ||||||
|     # Altered basis functions |     # Altered basis functions | ||||||
|     obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=5, render=render) |     # obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render) | ||||||
| 
 | 
 | ||||||
|     # Custom MP |     # Custom MP | ||||||
|     example_fully_custom_mp(seed=10, iterations=1, render=render) |     example_fully_custom_mp(seed=10, iterations=1, render=render) | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user