updated hole reacher example to new structure
This commit is contained in:
		
							parent
							
								
									e7525f61aa
								
							
						
					
					
						commit
						c5109ec2e7
					
				| @ -8,6 +8,7 @@ from matplotlib import patches | |||||||
| 
 | 
 | ||||||
| from alr_envs.classic_control.utils import check_self_collision | from alr_envs.classic_control.utils import check_self_collision | ||||||
| from mp_env_api.envs.mp_env import MpEnv | from mp_env_api.envs.mp_env import MpEnv | ||||||
|  | from mp_env_api.envs.mp_env_wrapper import MpEnvWrapper | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class HoleReacherEnv(MpEnv): | class HoleReacherEnv(MpEnv): | ||||||
| @ -92,7 +93,7 @@ class HoleReacherEnv(MpEnv): | |||||||
| 
 | 
 | ||||||
|     def reset(self): |     def reset(self): | ||||||
|         if self.random_start: |         if self.random_start: | ||||||
|             # Maybe change more than dirst seed |             # Maybe change more than first seed | ||||||
|             first_joint = self.np_random.uniform(np.pi / 4, 3 * np.pi / 4) |             first_joint = self.np_random.uniform(np.pi / 4, 3 * np.pi / 4) | ||||||
|             self._joint_angles = np.hstack([[first_joint], np.zeros(self.n_links - 1)]) |             self._joint_angles = np.hstack([[first_joint], np.zeros(self.n_links - 1)]) | ||||||
|             self._start_pos = self._joint_angles.copy() |             self._start_pos = self._joint_angles.copy() | ||||||
| @ -276,6 +277,21 @@ class HoleReacherEnv(MpEnv): | |||||||
|             self.fig.gca().add_patch(right_block) |             self.fig.gca().add_patch(right_block) | ||||||
|             self.fig.gca().add_patch(hole_floor) |             self.fig.gca().add_patch(hole_floor) | ||||||
| 
 | 
 | ||||||
|  |     def seed(self, seed=None): | ||||||
|  |         self.np_random, seed = seeding.np_random(seed) | ||||||
|  |         return [seed] | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     def end_effector(self): | ||||||
|  |         return self._joints[self.n_links].T | ||||||
|  | 
 | ||||||
|  |     def close(self): | ||||||
|  |         super().close() | ||||||
|  |         if self.fig is not None: | ||||||
|  |             plt.close(self.fig) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class HoleReacherMPWrapper(MpEnvWrapper): | ||||||
|     @property |     @property | ||||||
|     def active_obs(self): |     def active_obs(self): | ||||||
|         return np.hstack([ |         return np.hstack([ | ||||||
| @ -296,17 +312,9 @@ class HoleReacherEnv(MpEnv): | |||||||
|     def goal_pos(self) -> Union[float, int, np.ndarray]: |     def goal_pos(self) -> Union[float, int, np.ndarray]: | ||||||
|         raise ValueError("Goal position is not available and has to be learnt based on the environment.") |         raise ValueError("Goal position is not available and has to be learnt based on the environment.") | ||||||
| 
 | 
 | ||||||
|     def seed(self, seed=None): |  | ||||||
|         self.np_random, seed = seeding.np_random(seed) |  | ||||||
|         return [seed] |  | ||||||
| 
 |  | ||||||
|     @property |     @property | ||||||
|     def end_effector(self): |     def dt(self) -> Union[float, int]: | ||||||
|         return self._joints[self.n_links].T |         return self.env.dt | ||||||
| 
 |  | ||||||
|     def close(self): |  | ||||||
|         if self.fig is not None: |  | ||||||
|             plt.close(self.fig) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user