updated hole reacher example to new structure

This commit is contained in:
ottofabian 2021-06-24 15:06:25 +02:00
parent e7525f61aa
commit c5109ec2e7

View File

@ -8,6 +8,7 @@ from matplotlib import patches
from alr_envs.classic_control.utils import check_self_collision
from mp_env_api.envs.mp_env import MpEnv
from mp_env_api.envs.mp_env_wrapper import MpEnvWrapper
class HoleReacherEnv(MpEnv):
@ -92,7 +93,7 @@ class HoleReacherEnv(MpEnv):
def reset(self):
if self.random_start:
# Maybe change more than dirst seed
# Maybe change more than first seed
first_joint = self.np_random.uniform(np.pi / 4, 3 * np.pi / 4)
self._joint_angles = np.hstack([[first_joint], np.zeros(self.n_links - 1)])
self._start_pos = self._joint_angles.copy()
@ -276,6 +277,21 @@ class HoleReacherEnv(MpEnv):
self.fig.gca().add_patch(right_block)
self.fig.gca().add_patch(hole_floor)
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
@property
def end_effector(self):
return self._joints[self.n_links].T
def close(self):
super().close()
if self.fig is not None:
plt.close(self.fig)
class HoleReacherMPWrapper(MpEnvWrapper):
@property
def active_obs(self):
return np.hstack([
@ -296,17 +312,9 @@ class HoleReacherEnv(MpEnv):
def goal_pos(self) -> Union[float, int, np.ndarray]:
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
@property
def end_effector(self):
return self._joints[self.n_links].T
def close(self):
if self.fig is not None:
plt.close(self.fig)
def dt(self) -> Union[float, int]:
return self.env.dt
if __name__ == '__main__':