Updated hole reacher example to the new structure
This commit is contained in:
		
							parent
							
								
									c5109ec2e7
								
							
						
					
					
						commit
						fa7dfdc081
					
				| @ -8,10 +8,10 @@ from matplotlib import patches | ||||
| 
 | ||||
| from alr_envs.classic_control.utils import check_self_collision | ||||
| from mp_env_api.envs.mp_env import MpEnv | ||||
| from mp_env_api.envs.mp_env_wrapper import MpEnvWrapper | ||||
| from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper | ||||
| 
 | ||||
| 
 | ||||
| class HoleReacherEnv(MpEnv): | ||||
| class HoleReacherEnv(gym.Env): | ||||
| 
 | ||||
|     def __init__(self, n_links: int, hole_x: Union[None, float] = None, hole_depth: Union[None, float] = None, | ||||
|                  hole_width: float = 1., random_start: bool = False, allow_self_collision: bool = False, | ||||
| @ -23,9 +23,9 @@ class HoleReacherEnv(MpEnv): | ||||
|         self.random_start = random_start | ||||
| 
 | ||||
|         # provided initial parameters | ||||
|         self._hole_x = hole_x  # x-position of center of hole | ||||
|         self._hole_width = hole_width  # width of hole | ||||
|         self._hole_depth = hole_depth  # depth of hole | ||||
|         self.hole_x = hole_x  # x-position of center of hole | ||||
|         self.hole_width = hole_width  # width of hole | ||||
|         self.hole_depth = hole_depth  # depth of hole | ||||
| 
 | ||||
|         # temp container for current env state | ||||
|         self._tmp_hole_x = None | ||||
| @ -112,12 +112,12 @@ class HoleReacherEnv(MpEnv): | ||||
|         return self._get_obs().copy() | ||||
| 
 | ||||
|     def _generate_hole(self): | ||||
|         self._tmp_hole_x = self.np_random.uniform(1, 3.5, 1) if self._hole_x is None else np.copy(self._hole_x) | ||||
|         self._tmp_hole_width = self.np_random.uniform(0.15, 0.5, 1) if self._hole_width is None else np.copy( | ||||
|             self._hole_width) | ||||
|         self._tmp_hole_x = self.np_random.uniform(1, 3.5, 1) if self.hole_x is None else np.copy(self.hole_x) | ||||
|         self._tmp_hole_width = self.np_random.uniform(0.15, 0.5, 1) if self.hole_width is None else np.copy( | ||||
|             self.hole_width) | ||||
|         # TODO we do not want this right now. | ||||
|         self._tmp_hole_depth = self.np_random.uniform(1, 1, 1) if self._hole_depth is None else np.copy( | ||||
|             self._hole_depth) | ||||
|         self._tmp_hole_depth = self.np_random.uniform(1, 1, 1) if self.hole_depth is None else np.copy( | ||||
|             self.hole_depth) | ||||
|         self._goal = np.hstack([self._tmp_hole_x, -self._tmp_hole_depth]) | ||||
| 
 | ||||
|     def _update_joints(self): | ||||
| @ -291,15 +291,15 @@ class HoleReacherEnv(MpEnv): | ||||
|             plt.close(self.fig) | ||||
| 
 | ||||
| 
 | ||||
| class HoleReacherMPWrapper(MpEnvWrapper): | ||||
| class HoleReacherMPWrapper(MPEnvWrapper): | ||||
|     @property | ||||
|     def active_obs(self): | ||||
|         return np.hstack([ | ||||
|             [self.random_start] * self.n_links,  # cos | ||||
|             [self.random_start] * self.n_links,  # sin | ||||
|             [self.random_start] * self.n_links,  # velocity | ||||
|             [self._hole_width is None],  # hole width | ||||
|             # [self._hole_depth is None],  # hole depth | ||||
|             [self.env.random_start] * self.env.n_links,  # cos | ||||
|             [self.env.random_start] * self.env.n_links,  # sin | ||||
|             [self.env.random_start] * self.env.n_links,  # velocity | ||||
|             [self.env.hole_width is None],  # hole width | ||||
|             # [self.env.hole_depth is None],  # hole depth | ||||
|             [True] * 2,  # x-y coordinates of target distance | ||||
|             [False]  # env steps | ||||
|         ]) | ||||
| @ -315,26 +315,3 @@ class HoleReacherMPWrapper(MpEnvWrapper): | ||||
|     @property | ||||
|     def dt(self) -> Union[float, int]: | ||||
|         return self.env.dt | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     nl = 5 | ||||
|     render_mode = "human"  # "human" or "partial" or "final" | ||||
|     env = HoleReacherEnv(n_links=nl, allow_self_collision=False, allow_wall_collision=False, hole_width=None, | ||||
|                          hole_depth=1, hole_x=None) | ||||
|     obs = env.reset() | ||||
| 
 | ||||
|     for i in range(2000): | ||||
|         # objective.load_result("/tmp/cma") | ||||
|         # test with random actions | ||||
|         ac = 2 * env.action_space.sample() | ||||
|         obs, rew, d, info = env.step(ac) | ||||
|         if i % 10 == 0: | ||||
|             env.render(mode=render_mode) | ||||
| 
 | ||||
|         print(rew) | ||||
| 
 | ||||
|         if d: | ||||
|             env.reset() | ||||
| 
 | ||||
|     env.close() | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user