From 655d52aa35d8fe1e5b148f68555eb256fd8be580 Mon Sep 17 00:00:00 2001 From: Maximilian Huettenrauch Date: Tue, 30 Nov 2021 12:05:19 +0100 Subject: [PATCH] reacher envs pass unittests --- alr_envs/alr/__init__.py | 4 ++-- .../alr/classic_control/base_reacher/base_reacher.py | 5 +++++ .../alr/classic_control/hole_reacher/hole_reacher.py | 5 ++++- .../classic_control/simple_reacher/simple_reacher.py | 3 +++ .../viapoint_reacher/viapoint_reacher.py | 6 ++++-- alr_envs/alr/mujoco/reacher/alr_reacher.py | 1 + alr_envs/examples/examples_motion_primitives.py | 10 +++++----- 7 files changed, 24 insertions(+), 10 deletions(-) diff --git a/alr_envs/alr/__init__.py b/alr_envs/alr/__init__.py index 2f7a81d..03a986f 100644 --- a/alr_envs/alr/__init__.py +++ b/alr_envs/alr/__init__.py @@ -210,7 +210,7 @@ for _v in _versions: "mp_kwargs": { "num_dof": 2 if "long" not in _v.lower() else 5, "num_basis": 5, - "duration": 20, + "duration": 2, "alpha_phase": 2, "learn_goal": True, "policy_type": "velocity", @@ -250,7 +250,7 @@ for _v in _versions: "mp_kwargs": { "num_dof": 2 if "long" not in _v.lower() else 5, "num_basis": 5, - "duration": 20, + "duration": 2, "width": 0.025, "policy_type": "velocity", "weights_scale": 0.2, diff --git a/alr_envs/alr/classic_control/base_reacher/base_reacher.py b/alr_envs/alr/classic_control/base_reacher/base_reacher.py index dd8c64b..1b1ad19 100644 --- a/alr_envs/alr/classic_control/base_reacher/base_reacher.py +++ b/alr_envs/alr/classic_control/base_reacher/base_reacher.py @@ -22,6 +22,8 @@ class BaseReacherEnv(gym.Env, ABC): self.random_start = random_start + self.allow_self_collision = allow_self_collision + # state self._joints = None self._joint_angles = None @@ -103,6 +105,9 @@ class BaseReacherEnv(gym.Env, ABC): def _check_self_collision(self): """Checks whether line segments intersect""" + if self.allow_self_collision: + return False + if np.any(self._joint_angles > self.j_max) or np.any(self._joint_angles < self.j_min): return True diff --git a/alr_envs/alr/classic_control/hole_reacher/hole_reacher.py b/alr_envs/alr/classic_control/hole_reacher/hole_reacher.py index 03ceee2..dd7321a 100644 --- a/alr_envs/alr/classic_control/hole_reacher/hole_reacher.py +++ b/alr_envs/alr/classic_control/hole_reacher/hole_reacher.py @@ -127,6 +127,9 @@ class HoleReacherEnv(BaseReacherDirectEnv): return np.squeeze(end_effector + self._joints[0, :]) + def _check_collisions(self) -> bool: + return self._check_self_collision() or self.check_wall_collision() + def check_wall_collision(self): line_points = self._get_line_points(num_points_per_link=100) @@ -223,6 +226,6 @@ if __name__ == "__main__": for i in range(10000): ac = env.action_space.sample() obs, rew, done, info = env.step(ac) - # env.render() + env.render() if done: env.reset() diff --git a/alr_envs/alr/classic_control/simple_reacher/simple_reacher.py b/alr_envs/alr/classic_control/simple_reacher/simple_reacher.py index 40a4d95..758f824 100644 --- a/alr_envs/alr/classic_control/simple_reacher/simple_reacher.py +++ b/alr_envs/alr/classic_control/simple_reacher/simple_reacher.py @@ -88,6 +88,9 @@ class SimpleReacherEnv(BaseReacherTorqueEnv): self._goal = goal + def _check_collisions(self) -> bool: + return self._check_self_collision() + def render(self, mode='human'): # pragma: no cover if self.fig is None: # Create base figure once on the beginning. Afterwards only update diff --git a/alr_envs/alr/classic_control/viapoint_reacher/viapoint_reacher.py b/alr_envs/alr/classic_control/viapoint_reacher/viapoint_reacher.py index 3b47969..748eb99 100644 --- a/alr_envs/alr/classic_control/viapoint_reacher/viapoint_reacher.py +++ b/alr_envs/alr/classic_control/viapoint_reacher/viapoint_reacher.py @@ -25,7 +25,6 @@ class ViaPointReacherEnv(BaseReacherDirectEnv): self._goal = np.array((n_links, 0)) # collision - self.allow_self_collision = allow_self_collision self.collision_penalty = collision_penalty state_bound = np.hstack([ @@ -96,7 +95,7 @@ class ViaPointReacherEnv(BaseReacherDirectEnv): reward -= 5e-8 * np.sum(acc ** 2) info = {"is_success": success, "is_collided": self._is_collided, - "end_effector": np.copy(env.end_effector)} + "end_effector": np.copy(self.end_effector)} return reward, info @@ -114,6 +113,9 @@ class ViaPointReacherEnv(BaseReacherDirectEnv): self._steps ]) + def _check_collisions(self) -> bool: + return self._check_self_collision() + def render(self, mode='human'): goal_pos = self._goal.T via_pos = self._via_point.T diff --git a/alr_envs/alr/mujoco/reacher/alr_reacher.py b/alr_envs/alr/mujoco/reacher/alr_reacher.py index 6ca66d1..2d122d2 100644 --- a/alr_envs/alr/mujoco/reacher/alr_reacher.py +++ b/alr_envs/alr/mujoco/reacher/alr_reacher.py @@ -87,6 +87,7 @@ class ALRReacherEnv(MujocoEnv, utils.EzPickle): [self._steps], ]) + if __name__ == '__main__': nl = 5 render_mode = "human" # "human" or "partial" or "final" diff --git a/alr_envs/examples/examples_motion_primitives.py b/alr_envs/examples/examples_motion_primitives.py index e84da2f..0df05c1 100644 --- a/alr_envs/examples/examples_motion_primitives.py +++ b/alr_envs/examples/examples_motion_primitives.py @@ -112,7 +112,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True): # Replace this wrapper with the custom wrapper for your environment by inheriting from the MPEnvWrapper. # You can also add other gym.Wrappers in case they are needed. - wrappers = [alr_envs.classic_control.hole_reacher.MPWrapper] + wrappers = [alr_envs.alr.classic_control.hole_reacher.MPWrapper] mp_kwargs = { "num_dof": 5, "num_basis": 5, @@ -147,15 +147,15 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True): if __name__ == '__main__': - render = False + render = True # DMP - example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=10, render=render) + example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render) # ProMP - example_mp("alr_envs:HoleReacherProMP-v1", seed=10, iterations=100, render=render) + example_mp("alr_envs:HoleReacherProMP-v1", seed=10, iterations=1, render=render) # DetProMP - example_mp("alr_envs:HoleReacherDetPMP-v1", seed=10, iterations=100, render=render) + example_mp("alr_envs:HoleReacherDetPMP-v1", seed=10, iterations=1, render=render) # Altered basis functions example_custom_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render)