updated hole reacher example to new structure
This commit is contained in:
parent
e7525f61aa
commit
c5109ec2e7
@ -8,6 +8,7 @@ from matplotlib import patches
|
|||||||
|
|
||||||
from alr_envs.classic_control.utils import check_self_collision
|
from alr_envs.classic_control.utils import check_self_collision
|
||||||
from mp_env_api.envs.mp_env import MpEnv
|
from mp_env_api.envs.mp_env import MpEnv
|
||||||
|
from mp_env_api.envs.mp_env_wrapper import MpEnvWrapper
|
||||||
|
|
||||||
|
|
||||||
class HoleReacherEnv(MpEnv):
|
class HoleReacherEnv(MpEnv):
|
||||||
@ -92,7 +93,7 @@ class HoleReacherEnv(MpEnv):
|
|||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
if self.random_start:
|
if self.random_start:
|
||||||
# Maybe change more than dirst seed
|
# Maybe change more than first seed
|
||||||
first_joint = self.np_random.uniform(np.pi / 4, 3 * np.pi / 4)
|
first_joint = self.np_random.uniform(np.pi / 4, 3 * np.pi / 4)
|
||||||
self._joint_angles = np.hstack([[first_joint], np.zeros(self.n_links - 1)])
|
self._joint_angles = np.hstack([[first_joint], np.zeros(self.n_links - 1)])
|
||||||
self._start_pos = self._joint_angles.copy()
|
self._start_pos = self._joint_angles.copy()
|
||||||
@ -276,6 +277,21 @@ class HoleReacherEnv(MpEnv):
|
|||||||
self.fig.gca().add_patch(right_block)
|
self.fig.gca().add_patch(right_block)
|
||||||
self.fig.gca().add_patch(hole_floor)
|
self.fig.gca().add_patch(hole_floor)
|
||||||
|
|
||||||
|
def seed(self, seed=None):
|
||||||
|
self.np_random, seed = seeding.np_random(seed)
|
||||||
|
return [seed]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def end_effector(self):
|
||||||
|
return self._joints[self.n_links].T
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
super().close()
|
||||||
|
if self.fig is not None:
|
||||||
|
plt.close(self.fig)
|
||||||
|
|
||||||
|
|
||||||
|
class HoleReacherMPWrapper(MpEnvWrapper):
|
||||||
@property
|
@property
|
||||||
def active_obs(self):
|
def active_obs(self):
|
||||||
return np.hstack([
|
return np.hstack([
|
||||||
@ -296,17 +312,9 @@ class HoleReacherEnv(MpEnv):
|
|||||||
def goal_pos(self) -> Union[float, int, np.ndarray]:
|
def goal_pos(self) -> Union[float, int, np.ndarray]:
|
||||||
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
|
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
|
||||||
|
|
||||||
def seed(self, seed=None):
|
|
||||||
self.np_random, seed = seeding.np_random(seed)
|
|
||||||
return [seed]
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def end_effector(self):
|
def dt(self) -> Union[float, int]:
|
||||||
return self._joints[self.n_links].T
|
return self.env.dt
|
||||||
|
|
||||||
def close(self):
|
|
||||||
if self.fig is not None:
|
|
||||||
plt.close(self.fig)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
Loading…
Reference in New Issue
Block a user