From dffa3e3682bf48dd3059c4cb4a03a5c67244a15a Mon Sep 17 00:00:00 2001
From: ottofabian
Date: Thu, 24 Jun 2021 15:52:21 +0200
Subject: [PATCH] matched file structure of classic control with other tasks

---
 alr_envs/classic_control/__init__.py          |  6 ++--
 .../classic_control/hole_reacher/__init__.py  |  0
 .../{ => hole_reacher}/hole_reacher.py        | 28 -----------------
 .../hole_reacher/hole_reacher_mp_wrapper.py   | 31 +++++++++++++++++++
 .../simple_reacher/__init__.py                |  0
 .../{ => simple_reacher}/simple_reacher.py    | 26 ----------------
 .../simple_reacher_mp_wrapper.py              | 29 +++++++++++++++++
 alr_envs/classic_control/utils.py             |  2 +-
 .../viapoint_reacher/__init__.py              |  0
 .../viapoint_reacher.py                       | 26 ----------------
 .../viapoint_reacher_mp_wrapper.py            | 29 +++++++++++++++++
 alr_envs/mujoco/alr_mujoco_env.py             |  4 +--
 .../mujoco/gym_table_tennis/utils/util.py     |  2 +-
 13 files changed, 96 insertions(+), 87 deletions(-)
 create mode 100644 alr_envs/classic_control/hole_reacher/__init__.py
 rename alr_envs/classic_control/{ => hole_reacher}/hole_reacher.py (91%)
 create mode 100644 alr_envs/classic_control/hole_reacher/hole_reacher_mp_wrapper.py
 create mode 100644 alr_envs/classic_control/simple_reacher/__init__.py
 rename alr_envs/classic_control/{ => simple_reacher}/simple_reacher.py (89%)
 create mode 100644 alr_envs/classic_control/simple_reacher/simple_reacher_mp_wrapper.py
 create mode 100644 alr_envs/classic_control/viapoint_reacher/__init__.py
 rename alr_envs/classic_control/{ => viapoint_reacher}/viapoint_reacher.py (91%)
 create mode 100644 alr_envs/classic_control/viapoint_reacher/viapoint_reacher_mp_wrapper.py

diff --git a/alr_envs/classic_control/__init__.py b/alr_envs/classic_control/__init__.py
index 4a26eaa..1397e41 100644
--- a/alr_envs/classic_control/__init__.py
+++ b/alr_envs/classic_control/__init__.py
@@ -1,3 +1,3 @@
-from alr_envs.classic_control.simple_reacher import SimpleReacherEnv
-from alr_envs.classic_control.viapoint_reacher import ViaPointReacher
-from alr_envs.classic_control.hole_reacher import HoleReacherEnv
+from alr_envs.classic_control.viapoint_reacher.viapoint_reacher import ViaPointReacher
+from alr_envs.classic_control.simple_reacher.simple_reacher import SimpleReacherEnv
+from alr_envs.classic_control.hole_reacher.hole_reacher import HoleReacherEnv
diff --git a/alr_envs/classic_control/hole_reacher/__init__.py b/alr_envs/classic_control/hole_reacher/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/alr_envs/classic_control/hole_reacher.py b/alr_envs/classic_control/hole_reacher/hole_reacher.py
similarity index 91%
rename from alr_envs/classic_control/hole_reacher.py
rename to alr_envs/classic_control/hole_reacher/hole_reacher.py
index 9100686..04065f2 100644
--- a/alr_envs/classic_control/hole_reacher.py
+++ b/alr_envs/classic_control/hole_reacher/hole_reacher.py
@@ -7,8 +7,6 @@ from gym.utils import seeding
 from matplotlib import patches
 
 from alr_envs.classic_control.utils import check_self_collision
-from mp_env_api.envs.mp_env import MpEnv
-from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
 
 
 class HoleReacherEnv(gym.Env):
@@ -289,29 +287,3 @@ class HoleReacherEnv(gym.Env):
         super().close()
         if self.fig is not None:
             plt.close(self.fig)
-
-
-class HoleReacherMPWrapper(MPEnvWrapper):
-    @property
-    def active_obs(self):
-        return np.hstack([
-            [self.env.random_start] * self.env.n_links,  # cos
-            [self.env.random_start] * self.env.n_links,  # sin
-            [self.env.random_start] * self.env.n_links,  # velocity
-            [self.env.hole_width is None],  # hole width
-            # [self.env.hole_depth is None],  # hole depth
-            [True] * 2,  # x-y coordinates of target distance
-            [False]  # env steps
-        ])
-
-    @property
-    def start_pos(self) -> Union[float, int, np.ndarray]:
-        return self._start_pos
-
-    @property
-    def goal_pos(self) -> Union[float, int, np.ndarray]:
-        raise ValueError("Goal position is not available and has to be learnt based on the environment.")
-
-    @property
-    def dt(self) -> Union[float, int]:
-        return self.env.dt
diff --git a/alr_envs/classic_control/hole_reacher/hole_reacher_mp_wrapper.py b/alr_envs/classic_control/hole_reacher/hole_reacher_mp_wrapper.py
new file mode 100644
index 0000000..3d95b8c
--- /dev/null
+++ b/alr_envs/classic_control/hole_reacher/hole_reacher_mp_wrapper.py
@@ -0,0 +1,31 @@
+from typing import Union
+
+import numpy as np
+
+from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
+
+
+class HoleReacherMPWrapper(MPEnvWrapper):
+    @property
+    def active_obs(self):
+        return np.hstack([
+            [self.env.random_start] * self.env.n_links,  # cos
+            [self.env.random_start] * self.env.n_links,  # sin
+            [self.env.random_start] * self.env.n_links,  # velocity
+            [self.env.hole_width is None],  # hole width
+            # [self.env.hole_depth is None],  # hole depth
+            [True] * 2,  # x-y coordinates of target distance
+            [False]  # env steps
+        ])
+
+    @property
+    def start_pos(self) -> Union[float, int, np.ndarray]:
+        return self._start_pos
+
+    @property
+    def goal_pos(self) -> Union[float, int, np.ndarray]:
+        raise ValueError("Goal position is not available and has to be learnt based on the environment.")
+
+    @property
+    def dt(self) -> Union[float, int]:
+        return self.env.dt
\ No newline at end of file
diff --git a/alr_envs/classic_control/simple_reacher/__init__.py b/alr_envs/classic_control/simple_reacher/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/alr_envs/classic_control/simple_reacher.py b/alr_envs/classic_control/simple_reacher/simple_reacher.py
similarity index 89%
rename from alr_envs/classic_control/simple_reacher.py
rename to alr_envs/classic_control/simple_reacher/simple_reacher.py
index b61266f..4b65daf 100644
--- a/alr_envs/classic_control/simple_reacher.py
+++ b/alr_envs/classic_control/simple_reacher/simple_reacher.py
@@ -6,8 +6,6 @@ import numpy as np
 from gym import spaces
 from gym.utils import seeding
 
-from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
-
 
 class SimpleReacherEnv(gym.Env):
     """
@@ -187,27 +185,3 @@ class SimpleReacherEnv(gym.Env):
     @property
     def end_effector(self):
         return self._joints[self.n_links].T
-
-
-class SimpleReacherMPWrapper(MPEnvWrapper):
-    @property
-    def active_obs(self):
-        return np.hstack([
-            [self.env.random_start] * self.env.n_links,  # cos
-            [self.env.random_start] * self.env.n_links,  # sin
-            [self.env.random_start] * self.env.n_links,  # velocity
-            [True] * 2,  # x-y coordinates of target distance
-            [False]  # env steps
-        ])
-
-    @property
-    def start_pos(self):
-        return self._start_pos
-
-    @property
-    def goal_pos(self):
-        raise ValueError("Goal position is not available and has to be learnt based on the environment.")
-
-    @property
-    def dt(self) -> Union[float, int]:
-        return self.env.dt
diff --git a/alr_envs/classic_control/simple_reacher/simple_reacher_mp_wrapper.py b/alr_envs/classic_control/simple_reacher/simple_reacher_mp_wrapper.py
new file mode 100644
index 0000000..613f30d
--- /dev/null
+++ b/alr_envs/classic_control/simple_reacher/simple_reacher_mp_wrapper.py
@@ -0,0 +1,29 @@
+from typing import Union
+
+import numpy as np
+
+from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
+
+
+class SimpleReacherMPWrapper(MPEnvWrapper):
+    @property
+    def active_obs(self):
+        return np.hstack([
+            [self.env.random_start] * self.env.n_links,  # cos
+            [self.env.random_start] * self.env.n_links,  # sin
+            [self.env.random_start] * self.env.n_links,  # velocity
+            [True] * 2,  # x-y coordinates of target distance
+            [False]  # env steps
+        ])
+
+    @property
+    def start_pos(self):
+        return self._start_pos
+
+    @property
+    def goal_pos(self):
+        raise ValueError("Goal position is not available and has to be learnt based on the environment.")
+
+    @property
+    def dt(self) -> Union[float, int]:
+        return self.env.dt
diff --git a/alr_envs/classic_control/utils.py b/alr_envs/classic_control/utils.py
index dbaa88e..fa8176a 100644
--- a/alr_envs/classic_control/utils.py
+++ b/alr_envs/classic_control/utils.py
@@ -10,7 +10,7 @@ def intersect(A, B, C, D):
 
 
 def check_self_collision(line_points):
-    "Checks whether line segments and intersect"
+    """Checks whether line segments and intersect"""
     for i, line1 in enumerate(line_points):
         for line2 in line_points[i + 2:, :, :]:
             if intersect(line1[0], line1[-1], line2[0], line2[-1]):
diff --git a/alr_envs/classic_control/viapoint_reacher/__init__.py b/alr_envs/classic_control/viapoint_reacher/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/alr_envs/classic_control/viapoint_reacher.py b/alr_envs/classic_control/viapoint_reacher/viapoint_reacher.py
similarity index 91%
rename from alr_envs/classic_control/viapoint_reacher.py
rename to alr_envs/classic_control/viapoint_reacher/viapoint_reacher.py
index 2965df4..fc3264d 100644
--- a/alr_envs/classic_control/viapoint_reacher.py
+++ b/alr_envs/classic_control/viapoint_reacher/viapoint_reacher.py
@@ -6,8 +6,6 @@ import numpy as np
 from gym.utils import seeding
 
 from alr_envs.classic_control.utils import check_self_collision
-from mp_env_api.envs.mp_env import MpEnv
-from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
 
 
 class ViaPointReacher(gym.Env):
@@ -282,27 +280,3 @@ class ViaPointReacher(gym.Env):
     def close(self):
         if self.fig is not None:
             plt.close(self.fig)
-
-
-class ViaPointReacherMPWrapper(MPEnvWrapper):
-    @property
-    def active_obs(self):
-        return np.hstack([
-            [self.env.random_start] * self.env.n_links,  # cos
-            [self.env.random_start] * self.env.n_links,  # sin
-            [self.env.random_start] * self.env.n_links,  # velocity
-            [self.env.via_target is None] * 2,  # x-y coordinates of via point distance
-            [True] * 2,  # x-y coordinates of target distance
-            [False]  # env steps
-        ])
-
-    @property
-    def start_pos(self) -> Union[float, int, np.ndarray]:
-        return self._start_pos
-
-    @property
-    def goal_pos(self) -> Union[float, int, np.ndarray]:
-        raise ValueError("Goal position is not available and has to be learnt based on the environment.")
-
-    def dt(self) -> Union[float, int]:
-        return self.env.dt
diff --git a/alr_envs/classic_control/viapoint_reacher/viapoint_reacher_mp_wrapper.py b/alr_envs/classic_control/viapoint_reacher/viapoint_reacher_mp_wrapper.py
new file mode 100644
index 0000000..57aff1b
--- /dev/null
+++ b/alr_envs/classic_control/viapoint_reacher/viapoint_reacher_mp_wrapper.py
@@ -0,0 +1,29 @@
+from typing import Union
+
+import numpy as np
+
+from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
+
+
+class ViaPointReacherMPWrapper(MPEnvWrapper):
+    @property
+    def active_obs(self):
+        return np.hstack([
+            [self.env.random_start] * self.env.n_links,  # cos
+            [self.env.random_start] * self.env.n_links,  # sin
+            [self.env.random_start] * self.env.n_links,  # velocity
+            [self.env.via_target is None] * 2,  # x-y coordinates of via point distance
+            [True] * 2,  # x-y coordinates of target distance
+            [False]  # env steps
+        ])
+
+    @property
+    def start_pos(self) -> Union[float, int, np.ndarray]:
+        return self._start_pos
+
+    @property
+    def goal_pos(self) -> Union[float, int, np.ndarray]:
+        raise ValueError("Goal position is not available and has to be learnt based on the environment.")
+
+    def dt(self) -> Union[float, int]:
+        return self.env.dt
diff --git a/alr_envs/mujoco/alr_mujoco_env.py b/alr_envs/mujoco/alr_mujoco_env.py
index a95165c..01384c2 100644
--- a/alr_envs/mujoco/alr_mujoco_env.py
+++ b/alr_envs/mujoco/alr_mujoco_env.py
@@ -98,14 +98,14 @@ class AlrMujocoEnv(PositionalEnv, AlrEnv):
     @property
     def start_pos(self):
         """
-        Start position of the agent, for example joint angles of a Panda robot. Necessary for MP wrapped envs.
+        Start position of the agent, for example joint angles of a Panda robot. Necessary for MP wrapped simple_reacher.
         """
         return self._start_pos
 
     @property
     def start_vel(self):
         """
-        Start velocity of the agent. Necessary for MP wrapped envs.
+        Start velocity of the agent. Necessary for MP wrapped simple_reacher.
         """
         return self._start_vel
 
diff --git a/alr_envs/mujoco/gym_table_tennis/utils/util.py b/alr_envs/mujoco/gym_table_tennis/utils/util.py
index 716b3c6..fa308e3 100644
--- a/alr_envs/mujoco/gym_table_tennis/utils/util.py
+++ b/alr_envs/mujoco/gym_table_tennis/utils/util.py
@@ -39,7 +39,7 @@ def config_save(dir_path, config):
 
 
 def change_kp_in_xml(kp_list,
-                     model_path="/home/zhou/slow/table_tennis_rl/simulation/gymTableTennis/gym_table_tennis/envs/robotics/assets/table_tennis/right_arm_actuator.xml"):
+                     model_path="/home/zhou/slow/table_tennis_rl/simulation/gymTableTennis/gym_table_tennis/simple_reacher/robotics/assets/table_tennis/right_arm_actuator.xml"):
     tree = ET.parse(model_path)
     root = tree.getroot()
     # for actuator in root.find("actuator"):