updated hole reacher example to new structure

This commit is contained in:
ottofabian 2021-06-24 15:06:25 +02:00
parent e7525f61aa
commit c5109ec2e7

View File

@ -8,6 +8,7 @@ from matplotlib import patches
from alr_envs.classic_control.utils import check_self_collision from alr_envs.classic_control.utils import check_self_collision
from mp_env_api.envs.mp_env import MpEnv from mp_env_api.envs.mp_env import MpEnv
from mp_env_api.envs.mp_env_wrapper import MpEnvWrapper
class HoleReacherEnv(MpEnv): class HoleReacherEnv(MpEnv):
@ -92,7 +93,7 @@ class HoleReacherEnv(MpEnv):
def reset(self): def reset(self):
if self.random_start: if self.random_start:
# Maybe change more than dirst seed # Maybe change more than first seed
first_joint = self.np_random.uniform(np.pi / 4, 3 * np.pi / 4) first_joint = self.np_random.uniform(np.pi / 4, 3 * np.pi / 4)
self._joint_angles = np.hstack([[first_joint], np.zeros(self.n_links - 1)]) self._joint_angles = np.hstack([[first_joint], np.zeros(self.n_links - 1)])
self._start_pos = self._joint_angles.copy() self._start_pos = self._joint_angles.copy()
@ -276,6 +277,21 @@ class HoleReacherEnv(MpEnv):
self.fig.gca().add_patch(right_block) self.fig.gca().add_patch(right_block)
self.fig.gca().add_patch(hole_floor) self.fig.gca().add_patch(hole_floor)
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
@property
def end_effector(self):
return self._joints[self.n_links].T
def close(self):
super().close()
if self.fig is not None:
plt.close(self.fig)
class HoleReacherMPWrapper(MpEnvWrapper):
@property @property
def active_obs(self): def active_obs(self):
return np.hstack([ return np.hstack([
@ -296,17 +312,9 @@ class HoleReacherEnv(MpEnv):
def goal_pos(self) -> Union[float, int, np.ndarray]: def goal_pos(self) -> Union[float, int, np.ndarray]:
raise ValueError("Goal position is not available and has to be learnt based on the environment.") raise ValueError("Goal position is not available and has to be learnt based on the environment.")
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
@property @property
def end_effector(self): def dt(self) -> Union[float, int]:
return self._joints[self.n_links].T return self.env.dt
def close(self):
if self.fig is not None:
plt.close(self.fig)
if __name__ == '__main__': if __name__ == '__main__':