updated hole reacher example to new structure

This commit is contained in:
ottofabian 2021-06-24 15:19:05 +02:00
parent c5109ec2e7
commit fa7dfdc081

View File

@ -8,10 +8,10 @@ from matplotlib import patches
from alr_envs.classic_control.utils import check_self_collision
from mp_env_api.envs.mp_env import MpEnv
from mp_env_api.envs.mp_env_wrapper import MpEnvWrapper
from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
class HoleReacherEnv(MpEnv):
class HoleReacherEnv(gym.Env):
def __init__(self, n_links: int, hole_x: Union[None, float] = None, hole_depth: Union[None, float] = None,
hole_width: float = 1., random_start: bool = False, allow_self_collision: bool = False,
@ -23,9 +23,9 @@ class HoleReacherEnv(MpEnv):
self.random_start = random_start
# provided initial parameters
self._hole_x = hole_x # x-position of center of hole
self._hole_width = hole_width # width of hole
self._hole_depth = hole_depth # depth of hole
self.hole_x = hole_x # x-position of center of hole
self.hole_width = hole_width # width of hole
self.hole_depth = hole_depth # depth of hole
# temp container for current env state
self._tmp_hole_x = None
@ -112,12 +112,12 @@ class HoleReacherEnv(MpEnv):
return self._get_obs().copy()
def _generate_hole(self):
self._tmp_hole_x = self.np_random.uniform(1, 3.5, 1) if self._hole_x is None else np.copy(self._hole_x)
self._tmp_hole_width = self.np_random.uniform(0.15, 0.5, 1) if self._hole_width is None else np.copy(
self._hole_width)
self._tmp_hole_x = self.np_random.uniform(1, 3.5, 1) if self.hole_x is None else np.copy(self.hole_x)
self._tmp_hole_width = self.np_random.uniform(0.15, 0.5, 1) if self.hole_width is None else np.copy(
self.hole_width)
# TODO we do not want this right now.
self._tmp_hole_depth = self.np_random.uniform(1, 1, 1) if self._hole_depth is None else np.copy(
self._hole_depth)
self._tmp_hole_depth = self.np_random.uniform(1, 1, 1) if self.hole_depth is None else np.copy(
self.hole_depth)
self._goal = np.hstack([self._tmp_hole_x, -self._tmp_hole_depth])
def _update_joints(self):
@ -291,15 +291,15 @@ class HoleReacherEnv(MpEnv):
plt.close(self.fig)
class HoleReacherMPWrapper(MpEnvWrapper):
class HoleReacherMPWrapper(MPEnvWrapper):
@property
def active_obs(self):
return np.hstack([
[self.random_start] * self.n_links, # cos
[self.random_start] * self.n_links, # sin
[self.random_start] * self.n_links, # velocity
[self._hole_width is None], # hole width
# [self._hole_depth is None], # hole depth
[self.env.random_start] * self.env.n_links, # cos
[self.env.random_start] * self.env.n_links, # sin
[self.env.random_start] * self.env.n_links, # velocity
[self.env.hole_width is None], # hole width
# [self.env.hole_depth is None], # hole depth
[True] * 2, # x-y coordinates of target distance
[False] # env steps
])
@ -315,26 +315,3 @@ class HoleReacherMPWrapper(MpEnvWrapper):
@property
def dt(self) -> Union[float, int]:
return self.env.dt
if __name__ == '__main__':
nl = 5
render_mode = "human" # "human" or "partial" or "final"
env = HoleReacherEnv(n_links=nl, allow_self_collision=False, allow_wall_collision=False, hole_width=None,
hole_depth=1, hole_x=None)
obs = env.reset()
for i in range(2000):
# objective.load_result("/tmp/cma")
# test with random actions
ac = 2 * env.action_space.sample()
obs, rew, d, info = env.step(ac)
if i % 10 == 0:
env.render(mode=render_mode)
print(rew)
if d:
env.reset()
env.close()