updated hole reacher example to new structure
This commit is contained in:
		
							parent
							
								
									e7525f61aa
								
							
						
					
					
						commit
						c5109ec2e7
					
				@ -8,6 +8,7 @@ from matplotlib import patches
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from alr_envs.classic_control.utils import check_self_collision
 | 
					from alr_envs.classic_control.utils import check_self_collision
 | 
				
			||||||
from mp_env_api.envs.mp_env import MpEnv
 | 
					from mp_env_api.envs.mp_env import MpEnv
 | 
				
			||||||
 | 
					from mp_env_api.envs.mp_env_wrapper import MpEnvWrapper
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class HoleReacherEnv(MpEnv):
 | 
					class HoleReacherEnv(MpEnv):
 | 
				
			||||||
@ -92,7 +93,7 @@ class HoleReacherEnv(MpEnv):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def reset(self):
 | 
					    def reset(self):
 | 
				
			||||||
        if self.random_start:
 | 
					        if self.random_start:
 | 
				
			||||||
            # Maybe change more than dirst seed
 | 
					            # Maybe change more than first seed
 | 
				
			||||||
            first_joint = self.np_random.uniform(np.pi / 4, 3 * np.pi / 4)
 | 
					            first_joint = self.np_random.uniform(np.pi / 4, 3 * np.pi / 4)
 | 
				
			||||||
            self._joint_angles = np.hstack([[first_joint], np.zeros(self.n_links - 1)])
 | 
					            self._joint_angles = np.hstack([[first_joint], np.zeros(self.n_links - 1)])
 | 
				
			||||||
            self._start_pos = self._joint_angles.copy()
 | 
					            self._start_pos = self._joint_angles.copy()
 | 
				
			||||||
@ -276,6 +277,21 @@ class HoleReacherEnv(MpEnv):
 | 
				
			|||||||
            self.fig.gca().add_patch(right_block)
 | 
					            self.fig.gca().add_patch(right_block)
 | 
				
			||||||
            self.fig.gca().add_patch(hole_floor)
 | 
					            self.fig.gca().add_patch(hole_floor)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def seed(self, seed=None):
 | 
				
			||||||
 | 
					        self.np_random, seed = seeding.np_random(seed)
 | 
				
			||||||
 | 
					        return [seed]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def end_effector(self):
 | 
				
			||||||
 | 
					        return self._joints[self.n_links].T
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def close(self):
 | 
				
			||||||
 | 
					        super().close()
 | 
				
			||||||
 | 
					        if self.fig is not None:
 | 
				
			||||||
 | 
					            plt.close(self.fig)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class HoleReacherMPWrapper(MpEnvWrapper):
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def active_obs(self):
 | 
					    def active_obs(self):
 | 
				
			||||||
        return np.hstack([
 | 
					        return np.hstack([
 | 
				
			||||||
@ -296,17 +312,9 @@ class HoleReacherEnv(MpEnv):
 | 
				
			|||||||
    def goal_pos(self) -> Union[float, int, np.ndarray]:
 | 
					    def goal_pos(self) -> Union[float, int, np.ndarray]:
 | 
				
			||||||
        raise ValueError("Goal position is not available and has to be learnt based on the environment.")
 | 
					        raise ValueError("Goal position is not available and has to be learnt based on the environment.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def seed(self, seed=None):
 | 
					 | 
				
			||||||
        self.np_random, seed = seeding.np_random(seed)
 | 
					 | 
				
			||||||
        return [seed]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def end_effector(self):
 | 
					    def dt(self) -> Union[float, int]:
 | 
				
			||||||
        return self._joints[self.n_links].T
 | 
					        return self.env.dt
 | 
				
			||||||
 | 
					 | 
				
			||||||
    def close(self):
 | 
					 | 
				
			||||||
        if self.fig is not None:
 | 
					 | 
				
			||||||
            plt.close(self.fig)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == '__main__':
 | 
					if __name__ == '__main__':
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user