updated hole reacher example to new structure
This commit is contained in:
parent
c5109ec2e7
commit
fa7dfdc081
@ -8,10 +8,10 @@ from matplotlib import patches
|
|||||||
|
|
||||||
from alr_envs.classic_control.utils import check_self_collision
|
from alr_envs.classic_control.utils import check_self_collision
|
||||||
from mp_env_api.envs.mp_env import MpEnv
|
from mp_env_api.envs.mp_env import MpEnv
|
||||||
from mp_env_api.envs.mp_env_wrapper import MpEnvWrapper
|
from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
|
||||||
|
|
||||||
|
|
||||||
class HoleReacherEnv(MpEnv):
|
class HoleReacherEnv(gym.Env):
|
||||||
|
|
||||||
def __init__(self, n_links: int, hole_x: Union[None, float] = None, hole_depth: Union[None, float] = None,
|
def __init__(self, n_links: int, hole_x: Union[None, float] = None, hole_depth: Union[None, float] = None,
|
||||||
hole_width: float = 1., random_start: bool = False, allow_self_collision: bool = False,
|
hole_width: float = 1., random_start: bool = False, allow_self_collision: bool = False,
|
||||||
@ -23,9 +23,9 @@ class HoleReacherEnv(MpEnv):
|
|||||||
self.random_start = random_start
|
self.random_start = random_start
|
||||||
|
|
||||||
# provided initial parameters
|
# provided initial parameters
|
||||||
self._hole_x = hole_x # x-position of center of hole
|
self.hole_x = hole_x # x-position of center of hole
|
||||||
self._hole_width = hole_width # width of hole
|
self.hole_width = hole_width # width of hole
|
||||||
self._hole_depth = hole_depth # depth of hole
|
self.hole_depth = hole_depth # depth of hole
|
||||||
|
|
||||||
# temp container for current env state
|
# temp container for current env state
|
||||||
self._tmp_hole_x = None
|
self._tmp_hole_x = None
|
||||||
@ -112,12 +112,12 @@ class HoleReacherEnv(MpEnv):
|
|||||||
return self._get_obs().copy()
|
return self._get_obs().copy()
|
||||||
|
|
||||||
def _generate_hole(self):
|
def _generate_hole(self):
|
||||||
self._tmp_hole_x = self.np_random.uniform(1, 3.5, 1) if self._hole_x is None else np.copy(self._hole_x)
|
self._tmp_hole_x = self.np_random.uniform(1, 3.5, 1) if self.hole_x is None else np.copy(self.hole_x)
|
||||||
self._tmp_hole_width = self.np_random.uniform(0.15, 0.5, 1) if self._hole_width is None else np.copy(
|
self._tmp_hole_width = self.np_random.uniform(0.15, 0.5, 1) if self.hole_width is None else np.copy(
|
||||||
self._hole_width)
|
self.hole_width)
|
||||||
# TODO we do not want this right now.
|
# TODO we do not want this right now.
|
||||||
self._tmp_hole_depth = self.np_random.uniform(1, 1, 1) if self._hole_depth is None else np.copy(
|
self._tmp_hole_depth = self.np_random.uniform(1, 1, 1) if self.hole_depth is None else np.copy(
|
||||||
self._hole_depth)
|
self.hole_depth)
|
||||||
self._goal = np.hstack([self._tmp_hole_x, -self._tmp_hole_depth])
|
self._goal = np.hstack([self._tmp_hole_x, -self._tmp_hole_depth])
|
||||||
|
|
||||||
def _update_joints(self):
|
def _update_joints(self):
|
||||||
@ -291,15 +291,15 @@ class HoleReacherEnv(MpEnv):
|
|||||||
plt.close(self.fig)
|
plt.close(self.fig)
|
||||||
|
|
||||||
|
|
||||||
class HoleReacherMPWrapper(MpEnvWrapper):
|
class HoleReacherMPWrapper(MPEnvWrapper):
|
||||||
@property
|
@property
|
||||||
def active_obs(self):
|
def active_obs(self):
|
||||||
return np.hstack([
|
return np.hstack([
|
||||||
[self.random_start] * self.n_links, # cos
|
[self.env.random_start] * self.env.n_links, # cos
|
||||||
[self.random_start] * self.n_links, # sin
|
[self.env.random_start] * self.env.n_links, # sin
|
||||||
[self.random_start] * self.n_links, # velocity
|
[self.env.random_start] * self.env.n_links, # velocity
|
||||||
[self._hole_width is None], # hole width
|
[self.env.hole_width is None], # hole width
|
||||||
# [self._hole_depth is None], # hole depth
|
# [self.env.hole_depth is None], # hole depth
|
||||||
[True] * 2, # x-y coordinates of target distance
|
[True] * 2, # x-y coordinates of target distance
|
||||||
[False] # env steps
|
[False] # env steps
|
||||||
])
|
])
|
||||||
@ -315,26 +315,3 @@ class HoleReacherMPWrapper(MpEnvWrapper):
|
|||||||
@property
|
@property
|
||||||
def dt(self) -> Union[float, int]:
|
def dt(self) -> Union[float, int]:
|
||||||
return self.env.dt
|
return self.env.dt
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
nl = 5
|
|
||||||
render_mode = "human" # "human" or "partial" or "final"
|
|
||||||
env = HoleReacherEnv(n_links=nl, allow_self_collision=False, allow_wall_collision=False, hole_width=None,
|
|
||||||
hole_depth=1, hole_x=None)
|
|
||||||
obs = env.reset()
|
|
||||||
|
|
||||||
for i in range(2000):
|
|
||||||
# objective.load_result("/tmp/cma")
|
|
||||||
# test with random actions
|
|
||||||
ac = 2 * env.action_space.sample()
|
|
||||||
obs, rew, d, info = env.step(ac)
|
|
||||||
if i % 10 == 0:
|
|
||||||
env.render(mode=render_mode)
|
|
||||||
|
|
||||||
print(rew)
|
|
||||||
|
|
||||||
if d:
|
|
||||||
env.reset()
|
|
||||||
|
|
||||||
env.close()
|
|
||||||
|
Loading…
Reference in New Issue
Block a user