updated hole reacher example to new structure
This commit is contained in:
parent
c5109ec2e7
commit
fa7dfdc081
@ -8,10 +8,10 @@ from matplotlib import patches
|
||||
|
||||
from alr_envs.classic_control.utils import check_self_collision
|
||||
from mp_env_api.envs.mp_env import MpEnv
|
||||
from mp_env_api.envs.mp_env_wrapper import MpEnvWrapper
|
||||
from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
|
||||
|
||||
|
||||
class HoleReacherEnv(MpEnv):
|
||||
class HoleReacherEnv(gym.Env):
|
||||
|
||||
def __init__(self, n_links: int, hole_x: Union[None, float] = None, hole_depth: Union[None, float] = None,
|
||||
hole_width: float = 1., random_start: bool = False, allow_self_collision: bool = False,
|
||||
@ -23,9 +23,9 @@ class HoleReacherEnv(MpEnv):
|
||||
self.random_start = random_start
|
||||
|
||||
# provided initial parameters
|
||||
self._hole_x = hole_x # x-position of center of hole
|
||||
self._hole_width = hole_width # width of hole
|
||||
self._hole_depth = hole_depth # depth of hole
|
||||
self.hole_x = hole_x # x-position of center of hole
|
||||
self.hole_width = hole_width # width of hole
|
||||
self.hole_depth = hole_depth # depth of hole
|
||||
|
||||
# temp container for current env state
|
||||
self._tmp_hole_x = None
|
||||
@ -112,12 +112,12 @@ class HoleReacherEnv(MpEnv):
|
||||
return self._get_obs().copy()
|
||||
|
||||
def _generate_hole(self):
|
||||
self._tmp_hole_x = self.np_random.uniform(1, 3.5, 1) if self._hole_x is None else np.copy(self._hole_x)
|
||||
self._tmp_hole_width = self.np_random.uniform(0.15, 0.5, 1) if self._hole_width is None else np.copy(
|
||||
self._hole_width)
|
||||
self._tmp_hole_x = self.np_random.uniform(1, 3.5, 1) if self.hole_x is None else np.copy(self.hole_x)
|
||||
self._tmp_hole_width = self.np_random.uniform(0.15, 0.5, 1) if self.hole_width is None else np.copy(
|
||||
self.hole_width)
|
||||
# TODO we do not want this right now.
|
||||
self._tmp_hole_depth = self.np_random.uniform(1, 1, 1) if self._hole_depth is None else np.copy(
|
||||
self._hole_depth)
|
||||
self._tmp_hole_depth = self.np_random.uniform(1, 1, 1) if self.hole_depth is None else np.copy(
|
||||
self.hole_depth)
|
||||
self._goal = np.hstack([self._tmp_hole_x, -self._tmp_hole_depth])
|
||||
|
||||
def _update_joints(self):
|
||||
@ -291,15 +291,15 @@ class HoleReacherEnv(MpEnv):
|
||||
plt.close(self.fig)
|
||||
|
||||
|
||||
class HoleReacherMPWrapper(MpEnvWrapper):
|
||||
class HoleReacherMPWrapper(MPEnvWrapper):
|
||||
@property
|
||||
def active_obs(self):
|
||||
return np.hstack([
|
||||
[self.random_start] * self.n_links, # cos
|
||||
[self.random_start] * self.n_links, # sin
|
||||
[self.random_start] * self.n_links, # velocity
|
||||
[self._hole_width is None], # hole width
|
||||
# [self._hole_depth is None], # hole depth
|
||||
[self.env.random_start] * self.env.n_links, # cos
|
||||
[self.env.random_start] * self.env.n_links, # sin
|
||||
[self.env.random_start] * self.env.n_links, # velocity
|
||||
[self.env.hole_width is None], # hole width
|
||||
# [self.env.hole_depth is None], # hole depth
|
||||
[True] * 2, # x-y coordinates of target distance
|
||||
[False] # env steps
|
||||
])
|
||||
@ -315,26 +315,3 @@ class HoleReacherMPWrapper(MpEnvWrapper):
|
||||
@property
|
||||
def dt(self) -> Union[float, int]:
|
||||
return self.env.dt
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
nl = 5
|
||||
render_mode = "human" # "human" or "partial" or "final"
|
||||
env = HoleReacherEnv(n_links=nl, allow_self_collision=False, allow_wall_collision=False, hole_width=None,
|
||||
hole_depth=1, hole_x=None)
|
||||
obs = env.reset()
|
||||
|
||||
for i in range(2000):
|
||||
# objective.load_result("/tmp/cma")
|
||||
# test with random actions
|
||||
ac = 2 * env.action_space.sample()
|
||||
obs, rew, d, info = env.step(ac)
|
||||
if i % 10 == 0:
|
||||
env.render(mode=render_mode)
|
||||
|
||||
print(rew)
|
||||
|
||||
if d:
|
||||
env.reset()
|
||||
|
||||
env.close()
|
||||
|
Loading…
Reference in New Issue
Block a user