matched file structure of classic control with other tasks

This commit is contained in:
ottofabian 2021-06-24 15:52:21 +02:00
parent a30bdb8ce5
commit dffa3e3682
13 changed files with 96 additions and 87 deletions

View File

@ -1,3 +1,3 @@
from alr_envs.classic_control.simple_reacher import SimpleReacherEnv from alr_envs.classic_control.viapoint_reacher.viapoint_reacher import ViaPointReacher
from alr_envs.classic_control.viapoint_reacher import ViaPointReacher from alr_envs.classic_control.simple_reacher.simple_reacher import SimpleReacherEnv
from alr_envs.classic_control.hole_reacher import HoleReacherEnv from alr_envs.classic_control.hole_reacher.hole_reacher import HoleReacherEnv

View File

@ -7,8 +7,6 @@ from gym.utils import seeding
from matplotlib import patches from matplotlib import patches
from alr_envs.classic_control.utils import check_self_collision from alr_envs.classic_control.utils import check_self_collision
from mp_env_api.envs.mp_env import MpEnv
from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
class HoleReacherEnv(gym.Env): class HoleReacherEnv(gym.Env):
@ -289,29 +287,3 @@ class HoleReacherEnv(gym.Env):
super().close() super().close()
if self.fig is not None: if self.fig is not None:
plt.close(self.fig) plt.close(self.fig)
class HoleReacherMPWrapper(MPEnvWrapper):
@property
def active_obs(self):
return np.hstack([
[self.env.random_start] * self.env.n_links, # cos
[self.env.random_start] * self.env.n_links, # sin
[self.env.random_start] * self.env.n_links, # velocity
[self.env.hole_width is None], # hole width
# [self.env.hole_depth is None], # hole depth
[True] * 2, # x-y coordinates of target distance
[False] # env steps
])
@property
def start_pos(self) -> Union[float, int, np.ndarray]:
return self._start_pos
@property
def goal_pos(self) -> Union[float, int, np.ndarray]:
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
@property
def dt(self) -> Union[float, int]:
return self.env.dt

View File

@ -0,0 +1,31 @@
from typing import Union
import numpy as np
from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
class HoleReacherMPWrapper(MPEnvWrapper):
@property
def active_obs(self):
return np.hstack([
[self.env.random_start] * self.env.n_links, # cos
[self.env.random_start] * self.env.n_links, # sin
[self.env.random_start] * self.env.n_links, # velocity
[self.env.hole_width is None], # hole width
# [self.env.hole_depth is None], # hole depth
[True] * 2, # x-y coordinates of target distance
[False] # env steps
])
@property
def start_pos(self) -> Union[float, int, np.ndarray]:
return self._start_pos
@property
def goal_pos(self) -> Union[float, int, np.ndarray]:
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
@property
def dt(self) -> Union[float, int]:
return self.env.dt

View File

@ -6,8 +6,6 @@ import numpy as np
from gym import spaces from gym import spaces
from gym.utils import seeding from gym.utils import seeding
from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
class SimpleReacherEnv(gym.Env): class SimpleReacherEnv(gym.Env):
""" """
@ -187,27 +185,3 @@ class SimpleReacherEnv(gym.Env):
@property @property
def end_effector(self): def end_effector(self):
return self._joints[self.n_links].T return self._joints[self.n_links].T
class SimpleReacherMPWrapper(MPEnvWrapper):
@property
def active_obs(self):
return np.hstack([
[self.env.random_start] * self.env.n_links, # cos
[self.env.random_start] * self.env.n_links, # sin
[self.env.random_start] * self.env.n_links, # velocity
[True] * 2, # x-y coordinates of target distance
[False] # env steps
])
@property
def start_pos(self):
return self._start_pos
@property
def goal_pos(self):
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
@property
def dt(self) -> Union[float, int]:
return self.env.dt

View File

@ -0,0 +1,29 @@
from typing import Union
import numpy as np
from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
class SimpleReacherMPWrapper(MPEnvWrapper):
@property
def active_obs(self):
return np.hstack([
[self.env.random_start] * self.env.n_links, # cos
[self.env.random_start] * self.env.n_links, # sin
[self.env.random_start] * self.env.n_links, # velocity
[True] * 2, # x-y coordinates of target distance
[False] # env steps
])
@property
def start_pos(self):
return self._start_pos
@property
def goal_pos(self):
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
@property
def dt(self) -> Union[float, int]:
return self.env.dt

View File

@ -10,7 +10,7 @@ def intersect(A, B, C, D):
def check_self_collision(line_points): def check_self_collision(line_points):
"Checks whether line segments and intersect" """Checks whether line segments and intersect"""
for i, line1 in enumerate(line_points): for i, line1 in enumerate(line_points):
for line2 in line_points[i + 2:, :, :]: for line2 in line_points[i + 2:, :, :]:
if intersect(line1[0], line1[-1], line2[0], line2[-1]): if intersect(line1[0], line1[-1], line2[0], line2[-1]):

View File

@ -6,8 +6,6 @@ import numpy as np
from gym.utils import seeding from gym.utils import seeding
from alr_envs.classic_control.utils import check_self_collision from alr_envs.classic_control.utils import check_self_collision
from mp_env_api.envs.mp_env import MpEnv
from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
class ViaPointReacher(gym.Env): class ViaPointReacher(gym.Env):
@ -282,27 +280,3 @@ class ViaPointReacher(gym.Env):
def close(self): def close(self):
if self.fig is not None: if self.fig is not None:
plt.close(self.fig) plt.close(self.fig)
class ViaPointReacherMPWrapper(MPEnvWrapper):
@property
def active_obs(self):
return np.hstack([
[self.env.random_start] * self.env.n_links, # cos
[self.env.random_start] * self.env.n_links, # sin
[self.env.random_start] * self.env.n_links, # velocity
[self.env.via_target is None] * 2, # x-y coordinates of via point distance
[True] * 2, # x-y coordinates of target distance
[False] # env steps
])
@property
def start_pos(self) -> Union[float, int, np.ndarray]:
return self._start_pos
@property
def goal_pos(self) -> Union[float, int, np.ndarray]:
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
def dt(self) -> Union[float, int]:
return self.env.dt

View File

@ -0,0 +1,29 @@
from typing import Union
import numpy as np
from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
class ViaPointReacherMPWrapper(MPEnvWrapper):
@property
def active_obs(self):
return np.hstack([
[self.env.random_start] * self.env.n_links, # cos
[self.env.random_start] * self.env.n_links, # sin
[self.env.random_start] * self.env.n_links, # velocity
[self.env.via_target is None] * 2, # x-y coordinates of via point distance
[True] * 2, # x-y coordinates of target distance
[False] # env steps
])
@property
def start_pos(self) -> Union[float, int, np.ndarray]:
return self._start_pos
@property
def goal_pos(self) -> Union[float, int, np.ndarray]:
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
def dt(self) -> Union[float, int]:
return self.env.dt

View File

@ -98,14 +98,14 @@ class AlrMujocoEnv(PositionalEnv, AlrEnv):
@property @property
def start_pos(self): def start_pos(self):
""" """
Start position of the agent, for example joint angles of a Panda robot. Necessary for MP wrapped envs. Start position of the agent, for example joint angles of a Panda robot. Necessary for MP wrapped simple_reacher.
""" """
return self._start_pos return self._start_pos
@property @property
def start_vel(self): def start_vel(self):
""" """
Start velocity of the agent. Necessary for MP wrapped envs. Start velocity of the agent. Necessary for MP wrapped simple_reacher.
""" """
return self._start_vel return self._start_vel

View File

@ -39,7 +39,7 @@ def config_save(dir_path, config):
def change_kp_in_xml(kp_list, def change_kp_in_xml(kp_list,
model_path="/home/zhou/slow/table_tennis_rl/simulation/gymTableTennis/gym_table_tennis/envs/robotics/assets/table_tennis/right_arm_actuator.xml"): model_path="/home/zhou/slow/table_tennis_rl/simulation/gymTableTennis/gym_table_tennis/simple_reacher/robotics/assets/table_tennis/right_arm_actuator.xml"):
tree = ET.parse(model_path) tree = ET.parse(model_path)
root = tree.getroot() root = tree.getroot()
# for actuator in root.find("actuator"): # for actuator in root.find("actuator"):