reacher envs pass unittests

This commit is contained in:
Maximilian Huettenrauch 2021-11-30 12:05:19 +01:00
parent cfa49a04ba
commit ca90f257d4
7 changed files with 24 additions and 10 deletions

View File

@ -237,7 +237,7 @@ for _v in _versions:
"mp_kwargs": { "mp_kwargs": {
"num_dof": 2 if "long" not in _v.lower() else 5, "num_dof": 2 if "long" not in _v.lower() else 5,
"num_basis": 5, "num_basis": 5,
"duration": 20, "duration": 2,
"alpha_phase": 2, "alpha_phase": 2,
"learn_goal": True, "learn_goal": True,
"policy_type": "velocity", "policy_type": "velocity",
@ -277,7 +277,7 @@ for _v in _versions:
"mp_kwargs": { "mp_kwargs": {
"num_dof": 2 if "long" not in _v.lower() else 5, "num_dof": 2 if "long" not in _v.lower() else 5,
"num_basis": 5, "num_basis": 5,
"duration": 20, "duration": 2,
"width": 0.025, "width": 0.025,
"policy_type": "velocity", "policy_type": "velocity",
"weights_scale": 0.2, "weights_scale": 0.2,

View File

@ -22,6 +22,8 @@ class BaseReacherEnv(gym.Env, ABC):
self.random_start = random_start self.random_start = random_start
self.allow_self_collision = allow_self_collision
# state # state
self._joints = None self._joints = None
self._joint_angles = None self._joint_angles = None
@ -103,6 +105,9 @@ class BaseReacherEnv(gym.Env, ABC):
def _check_self_collision(self): def _check_self_collision(self):
"""Checks whether line segments intersect""" """Checks whether line segments intersect"""
if self.allow_self_collision:
return False
if np.any(self._joint_angles > self.j_max) or np.any(self._joint_angles < self.j_min): if np.any(self._joint_angles > self.j_max) or np.any(self._joint_angles < self.j_min):
return True return True

View File

@ -127,6 +127,9 @@ class HoleReacherEnv(BaseReacherDirectEnv):
return np.squeeze(end_effector + self._joints[0, :]) return np.squeeze(end_effector + self._joints[0, :])
def _check_collisions(self) -> bool:
return self._check_self_collision() or self.check_wall_collision()
def check_wall_collision(self): def check_wall_collision(self):
line_points = self._get_line_points(num_points_per_link=100) line_points = self._get_line_points(num_points_per_link=100)
@ -223,6 +226,6 @@ if __name__ == "__main__":
for i in range(10000): for i in range(10000):
ac = env.action_space.sample() ac = env.action_space.sample()
obs, rew, done, info = env.step(ac) obs, rew, done, info = env.step(ac)
# env.render() env.render()
if done: if done:
env.reset() env.reset()

View File

@ -88,6 +88,9 @@ class SimpleReacherEnv(BaseReacherTorqueEnv):
self._goal = goal self._goal = goal
def _check_collisions(self) -> bool:
return self._check_self_collision()
def render(self, mode='human'): # pragma: no cover def render(self, mode='human'): # pragma: no cover
if self.fig is None: if self.fig is None:
# Create base figure once on the beginning. Afterwards only update # Create base figure once on the beginning. Afterwards only update

View File

@ -25,7 +25,6 @@ class ViaPointReacherEnv(BaseReacherDirectEnv):
self._goal = np.array((n_links, 0)) self._goal = np.array((n_links, 0))
# collision # collision
self.allow_self_collision = allow_self_collision
self.collision_penalty = collision_penalty self.collision_penalty = collision_penalty
state_bound = np.hstack([ state_bound = np.hstack([
@ -96,7 +95,7 @@ class ViaPointReacherEnv(BaseReacherDirectEnv):
reward -= 5e-8 * np.sum(acc ** 2) reward -= 5e-8 * np.sum(acc ** 2)
info = {"is_success": success, info = {"is_success": success,
"is_collided": self._is_collided, "is_collided": self._is_collided,
"end_effector": np.copy(env.end_effector)} "end_effector": np.copy(self.end_effector)}
return reward, info return reward, info
@ -114,6 +113,9 @@ class ViaPointReacherEnv(BaseReacherDirectEnv):
self._steps self._steps
]) ])
def _check_collisions(self) -> bool:
return self._check_self_collision()
def render(self, mode='human'): def render(self, mode='human'):
goal_pos = self._goal.T goal_pos = self._goal.T
via_pos = self._via_point.T via_pos = self._via_point.T

View File

@ -87,6 +87,7 @@ class ALRReacherEnv(MujocoEnv, utils.EzPickle):
[self._steps], [self._steps],
]) ])
if __name__ == '__main__': if __name__ == '__main__':
nl = 5 nl = 5
render_mode = "human" # "human" or "partial" or "final" render_mode = "human" # "human" or "partial" or "final"

View File

@ -112,7 +112,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
# Replace this wrapper with the custom wrapper for your environment by inheriting from the MPEnvWrapper. # Replace this wrapper with the custom wrapper for your environment by inheriting from the MPEnvWrapper.
# You can also add other gym.Wrappers in case they are needed. # You can also add other gym.Wrappers in case they are needed.
wrappers = [alr_envs.classic_control.hole_reacher.MPWrapper] wrappers = [alr_envs.alr.classic_control.hole_reacher.MPWrapper]
mp_kwargs = { mp_kwargs = {
"num_dof": 5, "num_dof": 5,
"num_basis": 5, "num_basis": 5,
@ -147,15 +147,15 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
if __name__ == '__main__': if __name__ == '__main__':
render = False render = True
# DMP # DMP
example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=10, render=render) example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render)
# ProMP # ProMP
example_mp("alr_envs:HoleReacherProMP-v1", seed=10, iterations=100, render=render) example_mp("alr_envs:HoleReacherProMP-v1", seed=10, iterations=1, render=render)
# DetProMP # DetProMP
example_mp("alr_envs:HoleReacherDetPMP-v1", seed=10, iterations=100, render=render) example_mp("alr_envs:HoleReacherDetPMP-v1", seed=10, iterations=1, render=render)
# Altered basis functions # Altered basis functions
example_custom_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render) example_custom_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render)