matched file structure of classic control with other tasks
This commit is contained in:
parent
a30bdb8ce5
commit
dffa3e3682
@ -1,3 +1,3 @@
|
|||||||
from alr_envs.classic_control.simple_reacher import SimpleReacherEnv
|
from alr_envs.classic_control.viapoint_reacher.viapoint_reacher import ViaPointReacher
|
||||||
from alr_envs.classic_control.viapoint_reacher import ViaPointReacher
|
from alr_envs.classic_control.simple_reacher.simple_reacher import SimpleReacherEnv
|
||||||
from alr_envs.classic_control.hole_reacher import HoleReacherEnv
|
from alr_envs.classic_control.hole_reacher.hole_reacher import HoleReacherEnv
|
||||||
|
0
alr_envs/classic_control/hole_reacher/__init__.py
Normal file
0
alr_envs/classic_control/hole_reacher/__init__.py
Normal file
@ -7,8 +7,6 @@ from gym.utils import seeding
|
|||||||
from matplotlib import patches
|
from matplotlib import patches
|
||||||
|
|
||||||
from alr_envs.classic_control.utils import check_self_collision
|
from alr_envs.classic_control.utils import check_self_collision
|
||||||
from mp_env_api.envs.mp_env import MpEnv
|
|
||||||
from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
|
|
||||||
|
|
||||||
|
|
||||||
class HoleReacherEnv(gym.Env):
|
class HoleReacherEnv(gym.Env):
|
||||||
@ -289,29 +287,3 @@ class HoleReacherEnv(gym.Env):
|
|||||||
super().close()
|
super().close()
|
||||||
if self.fig is not None:
|
if self.fig is not None:
|
||||||
plt.close(self.fig)
|
plt.close(self.fig)
|
||||||
|
|
||||||
|
|
||||||
class HoleReacherMPWrapper(MPEnvWrapper):
|
|
||||||
@property
|
|
||||||
def active_obs(self):
|
|
||||||
return np.hstack([
|
|
||||||
[self.env.random_start] * self.env.n_links, # cos
|
|
||||||
[self.env.random_start] * self.env.n_links, # sin
|
|
||||||
[self.env.random_start] * self.env.n_links, # velocity
|
|
||||||
[self.env.hole_width is None], # hole width
|
|
||||||
# [self.env.hole_depth is None], # hole depth
|
|
||||||
[True] * 2, # x-y coordinates of target distance
|
|
||||||
[False] # env steps
|
|
||||||
])
|
|
||||||
|
|
||||||
@property
|
|
||||||
def start_pos(self) -> Union[float, int, np.ndarray]:
|
|
||||||
return self._start_pos
|
|
||||||
|
|
||||||
@property
|
|
||||||
def goal_pos(self) -> Union[float, int, np.ndarray]:
|
|
||||||
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dt(self) -> Union[float, int]:
|
|
||||||
return self.env.dt
|
|
@ -0,0 +1,31 @@
|
|||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
|
||||||
|
|
||||||
|
|
||||||
|
class HoleReacherMPWrapper(MPEnvWrapper):
|
||||||
|
@property
|
||||||
|
def active_obs(self):
|
||||||
|
return np.hstack([
|
||||||
|
[self.env.random_start] * self.env.n_links, # cos
|
||||||
|
[self.env.random_start] * self.env.n_links, # sin
|
||||||
|
[self.env.random_start] * self.env.n_links, # velocity
|
||||||
|
[self.env.hole_width is None], # hole width
|
||||||
|
# [self.env.hole_depth is None], # hole depth
|
||||||
|
[True] * 2, # x-y coordinates of target distance
|
||||||
|
[False] # env steps
|
||||||
|
])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def start_pos(self) -> Union[float, int, np.ndarray]:
|
||||||
|
return self._start_pos
|
||||||
|
|
||||||
|
@property
|
||||||
|
def goal_pos(self) -> Union[float, int, np.ndarray]:
|
||||||
|
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dt(self) -> Union[float, int]:
|
||||||
|
return self.env.dt
|
0
alr_envs/classic_control/simple_reacher/__init__.py
Normal file
0
alr_envs/classic_control/simple_reacher/__init__.py
Normal file
@ -6,8 +6,6 @@ import numpy as np
|
|||||||
from gym import spaces
|
from gym import spaces
|
||||||
from gym.utils import seeding
|
from gym.utils import seeding
|
||||||
|
|
||||||
from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleReacherEnv(gym.Env):
|
class SimpleReacherEnv(gym.Env):
|
||||||
"""
|
"""
|
||||||
@ -187,27 +185,3 @@ class SimpleReacherEnv(gym.Env):
|
|||||||
@property
|
@property
|
||||||
def end_effector(self):
|
def end_effector(self):
|
||||||
return self._joints[self.n_links].T
|
return self._joints[self.n_links].T
|
||||||
|
|
||||||
|
|
||||||
class SimpleReacherMPWrapper(MPEnvWrapper):
|
|
||||||
@property
|
|
||||||
def active_obs(self):
|
|
||||||
return np.hstack([
|
|
||||||
[self.env.random_start] * self.env.n_links, # cos
|
|
||||||
[self.env.random_start] * self.env.n_links, # sin
|
|
||||||
[self.env.random_start] * self.env.n_links, # velocity
|
|
||||||
[True] * 2, # x-y coordinates of target distance
|
|
||||||
[False] # env steps
|
|
||||||
])
|
|
||||||
|
|
||||||
@property
|
|
||||||
def start_pos(self):
|
|
||||||
return self._start_pos
|
|
||||||
|
|
||||||
@property
|
|
||||||
def goal_pos(self):
|
|
||||||
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dt(self) -> Union[float, int]:
|
|
||||||
return self.env.dt
|
|
@ -0,0 +1,29 @@
|
|||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleReacherMPWrapper(MPEnvWrapper):
|
||||||
|
@property
|
||||||
|
def active_obs(self):
|
||||||
|
return np.hstack([
|
||||||
|
[self.env.random_start] * self.env.n_links, # cos
|
||||||
|
[self.env.random_start] * self.env.n_links, # sin
|
||||||
|
[self.env.random_start] * self.env.n_links, # velocity
|
||||||
|
[True] * 2, # x-y coordinates of target distance
|
||||||
|
[False] # env steps
|
||||||
|
])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def start_pos(self):
|
||||||
|
return self._start_pos
|
||||||
|
|
||||||
|
@property
|
||||||
|
def goal_pos(self):
|
||||||
|
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dt(self) -> Union[float, int]:
|
||||||
|
return self.env.dt
|
@ -10,7 +10,7 @@ def intersect(A, B, C, D):
|
|||||||
|
|
||||||
|
|
||||||
def check_self_collision(line_points):
|
def check_self_collision(line_points):
|
||||||
"Checks whether line segments and intersect"
|
"""Checks whether line segments and intersect"""
|
||||||
for i, line1 in enumerate(line_points):
|
for i, line1 in enumerate(line_points):
|
||||||
for line2 in line_points[i + 2:, :, :]:
|
for line2 in line_points[i + 2:, :, :]:
|
||||||
if intersect(line1[0], line1[-1], line2[0], line2[-1]):
|
if intersect(line1[0], line1[-1], line2[0], line2[-1]):
|
||||||
|
@ -6,8 +6,6 @@ import numpy as np
|
|||||||
from gym.utils import seeding
|
from gym.utils import seeding
|
||||||
|
|
||||||
from alr_envs.classic_control.utils import check_self_collision
|
from alr_envs.classic_control.utils import check_self_collision
|
||||||
from mp_env_api.envs.mp_env import MpEnv
|
|
||||||
from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
|
|
||||||
|
|
||||||
|
|
||||||
class ViaPointReacher(gym.Env):
|
class ViaPointReacher(gym.Env):
|
||||||
@ -282,27 +280,3 @@ class ViaPointReacher(gym.Env):
|
|||||||
def close(self):
|
def close(self):
|
||||||
if self.fig is not None:
|
if self.fig is not None:
|
||||||
plt.close(self.fig)
|
plt.close(self.fig)
|
||||||
|
|
||||||
|
|
||||||
class ViaPointReacherMPWrapper(MPEnvWrapper):
|
|
||||||
@property
|
|
||||||
def active_obs(self):
|
|
||||||
return np.hstack([
|
|
||||||
[self.env.random_start] * self.env.n_links, # cos
|
|
||||||
[self.env.random_start] * self.env.n_links, # sin
|
|
||||||
[self.env.random_start] * self.env.n_links, # velocity
|
|
||||||
[self.env.via_target is None] * 2, # x-y coordinates of via point distance
|
|
||||||
[True] * 2, # x-y coordinates of target distance
|
|
||||||
[False] # env steps
|
|
||||||
])
|
|
||||||
|
|
||||||
@property
|
|
||||||
def start_pos(self) -> Union[float, int, np.ndarray]:
|
|
||||||
return self._start_pos
|
|
||||||
|
|
||||||
@property
|
|
||||||
def goal_pos(self) -> Union[float, int, np.ndarray]:
|
|
||||||
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
|
|
||||||
|
|
||||||
def dt(self) -> Union[float, int]:
|
|
||||||
return self.env.dt
|
|
@ -0,0 +1,29 @@
|
|||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
|
||||||
|
|
||||||
|
|
||||||
|
class ViaPointReacherMPWrapper(MPEnvWrapper):
|
||||||
|
@property
|
||||||
|
def active_obs(self):
|
||||||
|
return np.hstack([
|
||||||
|
[self.env.random_start] * self.env.n_links, # cos
|
||||||
|
[self.env.random_start] * self.env.n_links, # sin
|
||||||
|
[self.env.random_start] * self.env.n_links, # velocity
|
||||||
|
[self.env.via_target is None] * 2, # x-y coordinates of via point distance
|
||||||
|
[True] * 2, # x-y coordinates of target distance
|
||||||
|
[False] # env steps
|
||||||
|
])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def start_pos(self) -> Union[float, int, np.ndarray]:
|
||||||
|
return self._start_pos
|
||||||
|
|
||||||
|
@property
|
||||||
|
def goal_pos(self) -> Union[float, int, np.ndarray]:
|
||||||
|
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
|
||||||
|
|
||||||
|
def dt(self) -> Union[float, int]:
|
||||||
|
return self.env.dt
|
@ -98,14 +98,14 @@ class AlrMujocoEnv(PositionalEnv, AlrEnv):
|
|||||||
@property
|
@property
|
||||||
def start_pos(self):
|
def start_pos(self):
|
||||||
"""
|
"""
|
||||||
Start position of the agent, for example joint angles of a Panda robot. Necessary for MP wrapped envs.
|
Start position of the agent, for example joint angles of a Panda robot. Necessary for MP wrapped simple_reacher.
|
||||||
"""
|
"""
|
||||||
return self._start_pos
|
return self._start_pos
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def start_vel(self):
|
def start_vel(self):
|
||||||
"""
|
"""
|
||||||
Start velocity of the agent. Necessary for MP wrapped envs.
|
Start velocity of the agent. Necessary for MP wrapped simple_reacher.
|
||||||
"""
|
"""
|
||||||
return self._start_vel
|
return self._start_vel
|
||||||
|
|
||||||
|
@ -39,7 +39,7 @@ def config_save(dir_path, config):
|
|||||||
|
|
||||||
|
|
||||||
def change_kp_in_xml(kp_list,
|
def change_kp_in_xml(kp_list,
|
||||||
model_path="/home/zhou/slow/table_tennis_rl/simulation/gymTableTennis/gym_table_tennis/envs/robotics/assets/table_tennis/right_arm_actuator.xml"):
|
model_path="/home/zhou/slow/table_tennis_rl/simulation/gymTableTennis/gym_table_tennis/simple_reacher/robotics/assets/table_tennis/right_arm_actuator.xml"):
|
||||||
tree = ET.parse(model_path)
|
tree = ET.parse(model_path)
|
||||||
root = tree.getroot()
|
root = tree.getroot()
|
||||||
# for actuator in root.find("actuator"):
|
# for actuator in root.find("actuator"):
|
||||||
|
Loading…
Reference in New Issue
Block a user