matched file structure of classic control with other tasks
This commit is contained in:
		
							parent
							
								
									a30bdb8ce5
								
							
						
					
					
						commit
						dffa3e3682
					
				@ -1,3 +1,3 @@
 | 
				
			|||||||
from alr_envs.classic_control.simple_reacher import SimpleReacherEnv
 | 
					from alr_envs.classic_control.viapoint_reacher.viapoint_reacher import ViaPointReacher
 | 
				
			||||||
from alr_envs.classic_control.viapoint_reacher import ViaPointReacher
 | 
					from alr_envs.classic_control.simple_reacher.simple_reacher import SimpleReacherEnv
 | 
				
			||||||
from alr_envs.classic_control.hole_reacher import HoleReacherEnv
 | 
					from alr_envs.classic_control.hole_reacher.hole_reacher import HoleReacherEnv
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										0
									
								
								alr_envs/classic_control/hole_reacher/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								alr_envs/classic_control/hole_reacher/__init__.py
									
									
									
									
									
										Normal file
									
								
							@ -7,8 +7,6 @@ from gym.utils import seeding
 | 
				
			|||||||
from matplotlib import patches
 | 
					from matplotlib import patches
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from alr_envs.classic_control.utils import check_self_collision
 | 
					from alr_envs.classic_control.utils import check_self_collision
 | 
				
			||||||
from mp_env_api.envs.mp_env import MpEnv
 | 
					 | 
				
			||||||
from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class HoleReacherEnv(gym.Env):
 | 
					class HoleReacherEnv(gym.Env):
 | 
				
			||||||
@ -289,29 +287,3 @@ class HoleReacherEnv(gym.Env):
 | 
				
			|||||||
        super().close()
 | 
					        super().close()
 | 
				
			||||||
        if self.fig is not None:
 | 
					        if self.fig is not None:
 | 
				
			||||||
            plt.close(self.fig)
 | 
					            plt.close(self.fig)
 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class HoleReacherMPWrapper(MPEnvWrapper):
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def active_obs(self):
 | 
					 | 
				
			||||||
        return np.hstack([
 | 
					 | 
				
			||||||
            [self.env.random_start] * self.env.n_links,  # cos
 | 
					 | 
				
			||||||
            [self.env.random_start] * self.env.n_links,  # sin
 | 
					 | 
				
			||||||
            [self.env.random_start] * self.env.n_links,  # velocity
 | 
					 | 
				
			||||||
            [self.env.hole_width is None],  # hole width
 | 
					 | 
				
			||||||
            # [self.env.hole_depth is None],  # hole depth
 | 
					 | 
				
			||||||
            [True] * 2,  # x-y coordinates of target distance
 | 
					 | 
				
			||||||
            [False]  # env steps
 | 
					 | 
				
			||||||
        ])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def start_pos(self) -> Union[float, int, np.ndarray]:
 | 
					 | 
				
			||||||
        return self._start_pos
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def goal_pos(self) -> Union[float, int, np.ndarray]:
 | 
					 | 
				
			||||||
        raise ValueError("Goal position is not available and has to be learnt based on the environment.")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def dt(self) -> Union[float, int]:
 | 
					 | 
				
			||||||
        return self.env.dt
 | 
					 | 
				
			||||||
@ -0,0 +1,31 @@
 | 
				
			|||||||
 | 
					from typing import Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class HoleReacherMPWrapper(MPEnvWrapper):
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def active_obs(self):
 | 
				
			||||||
 | 
					        return np.hstack([
 | 
				
			||||||
 | 
					            [self.env.random_start] * self.env.n_links,  # cos
 | 
				
			||||||
 | 
					            [self.env.random_start] * self.env.n_links,  # sin
 | 
				
			||||||
 | 
					            [self.env.random_start] * self.env.n_links,  # velocity
 | 
				
			||||||
 | 
					            [self.env.hole_width is None],  # hole width
 | 
				
			||||||
 | 
					            # [self.env.hole_depth is None],  # hole depth
 | 
				
			||||||
 | 
					            [True] * 2,  # x-y coordinates of target distance
 | 
				
			||||||
 | 
					            [False]  # env steps
 | 
				
			||||||
 | 
					        ])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def start_pos(self) -> Union[float, int, np.ndarray]:
 | 
				
			||||||
 | 
					        return self._start_pos
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def goal_pos(self) -> Union[float, int, np.ndarray]:
 | 
				
			||||||
 | 
					        raise ValueError("Goal position is not available and has to be learnt based on the environment.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def dt(self) -> Union[float, int]:
 | 
				
			||||||
 | 
					        return self.env.dt
 | 
				
			||||||
							
								
								
									
										0
									
								
								alr_envs/classic_control/simple_reacher/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								alr_envs/classic_control/simple_reacher/__init__.py
									
									
									
									
									
										Normal file
									
								
							@ -6,8 +6,6 @@ import numpy as np
 | 
				
			|||||||
from gym import spaces
 | 
					from gym import spaces
 | 
				
			||||||
from gym.utils import seeding
 | 
					from gym.utils import seeding
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
class SimpleReacherEnv(gym.Env):
 | 
					class SimpleReacherEnv(gym.Env):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
@ -187,27 +185,3 @@ class SimpleReacherEnv(gym.Env):
 | 
				
			|||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def end_effector(self):
 | 
					    def end_effector(self):
 | 
				
			||||||
        return self._joints[self.n_links].T
 | 
					        return self._joints[self.n_links].T
 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class SimpleReacherMPWrapper(MPEnvWrapper):
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def active_obs(self):
 | 
					 | 
				
			||||||
        return np.hstack([
 | 
					 | 
				
			||||||
            [self.env.random_start] * self.env.n_links,  # cos
 | 
					 | 
				
			||||||
            [self.env.random_start] * self.env.n_links,  # sin
 | 
					 | 
				
			||||||
            [self.env.random_start] * self.env.n_links,  # velocity
 | 
					 | 
				
			||||||
            [True] * 2,  # x-y coordinates of target distance
 | 
					 | 
				
			||||||
            [False]  # env steps
 | 
					 | 
				
			||||||
        ])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def start_pos(self):
 | 
					 | 
				
			||||||
        return self._start_pos
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def goal_pos(self):
 | 
					 | 
				
			||||||
        raise ValueError("Goal position is not available and has to be learnt based on the environment.")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def dt(self) -> Union[float, int]:
 | 
					 | 
				
			||||||
        return self.env.dt
 | 
					 | 
				
			||||||
@ -0,0 +1,29 @@
 | 
				
			|||||||
 | 
					from typing import Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class SimpleReacherMPWrapper(MPEnvWrapper):
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def active_obs(self):
 | 
				
			||||||
 | 
					        return np.hstack([
 | 
				
			||||||
 | 
					            [self.env.random_start] * self.env.n_links,  # cos
 | 
				
			||||||
 | 
					            [self.env.random_start] * self.env.n_links,  # sin
 | 
				
			||||||
 | 
					            [self.env.random_start] * self.env.n_links,  # velocity
 | 
				
			||||||
 | 
					            [True] * 2,  # x-y coordinates of target distance
 | 
				
			||||||
 | 
					            [False]  # env steps
 | 
				
			||||||
 | 
					        ])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def start_pos(self):
 | 
				
			||||||
 | 
					        return self._start_pos
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def goal_pos(self):
 | 
				
			||||||
 | 
					        raise ValueError("Goal position is not available and has to be learnt based on the environment.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def dt(self) -> Union[float, int]:
 | 
				
			||||||
 | 
					        return self.env.dt
 | 
				
			||||||
@ -10,7 +10,7 @@ def intersect(A, B, C, D):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def check_self_collision(line_points):
 | 
					def check_self_collision(line_points):
 | 
				
			||||||
    "Checks whether line segments and intersect"
 | 
					    """Checks whether line segments and intersect"""
 | 
				
			||||||
    for i, line1 in enumerate(line_points):
 | 
					    for i, line1 in enumerate(line_points):
 | 
				
			||||||
        for line2 in line_points[i + 2:, :, :]:
 | 
					        for line2 in line_points[i + 2:, :, :]:
 | 
				
			||||||
            if intersect(line1[0], line1[-1], line2[0], line2[-1]):
 | 
					            if intersect(line1[0], line1[-1], line2[0], line2[-1]):
 | 
				
			||||||
 | 
				
			|||||||
@ -6,8 +6,6 @@ import numpy as np
 | 
				
			|||||||
from gym.utils import seeding
 | 
					from gym.utils import seeding
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from alr_envs.classic_control.utils import check_self_collision
 | 
					from alr_envs.classic_control.utils import check_self_collision
 | 
				
			||||||
from mp_env_api.envs.mp_env import MpEnv
 | 
					 | 
				
			||||||
from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ViaPointReacher(gym.Env):
 | 
					class ViaPointReacher(gym.Env):
 | 
				
			||||||
@ -282,27 +280,3 @@ class ViaPointReacher(gym.Env):
 | 
				
			|||||||
    def close(self):
 | 
					    def close(self):
 | 
				
			||||||
        if self.fig is not None:
 | 
					        if self.fig is not None:
 | 
				
			||||||
            plt.close(self.fig)
 | 
					            plt.close(self.fig)
 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ViaPointReacherMPWrapper(MPEnvWrapper):
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def active_obs(self):
 | 
					 | 
				
			||||||
        return np.hstack([
 | 
					 | 
				
			||||||
            [self.env.random_start] * self.env.n_links,  # cos
 | 
					 | 
				
			||||||
            [self.env.random_start] * self.env.n_links,  # sin
 | 
					 | 
				
			||||||
            [self.env.random_start] * self.env.n_links,  # velocity
 | 
					 | 
				
			||||||
            [self.env.via_target is None] * 2,  # x-y coordinates of via point distance
 | 
					 | 
				
			||||||
            [True] * 2,  # x-y coordinates of target distance
 | 
					 | 
				
			||||||
            [False]  # env steps
 | 
					 | 
				
			||||||
        ])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def start_pos(self) -> Union[float, int, np.ndarray]:
 | 
					 | 
				
			||||||
        return self._start_pos
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def goal_pos(self) -> Union[float, int, np.ndarray]:
 | 
					 | 
				
			||||||
        raise ValueError("Goal position is not available and has to be learnt based on the environment.")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def dt(self) -> Union[float, int]:
 | 
					 | 
				
			||||||
        return self.env.dt
 | 
					 | 
				
			||||||
@ -0,0 +1,29 @@
 | 
				
			|||||||
 | 
					from typing import Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ViaPointReacherMPWrapper(MPEnvWrapper):
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def active_obs(self):
 | 
				
			||||||
 | 
					        return np.hstack([
 | 
				
			||||||
 | 
					            [self.env.random_start] * self.env.n_links,  # cos
 | 
				
			||||||
 | 
					            [self.env.random_start] * self.env.n_links,  # sin
 | 
				
			||||||
 | 
					            [self.env.random_start] * self.env.n_links,  # velocity
 | 
				
			||||||
 | 
					            [self.env.via_target is None] * 2,  # x-y coordinates of via point distance
 | 
				
			||||||
 | 
					            [True] * 2,  # x-y coordinates of target distance
 | 
				
			||||||
 | 
					            [False]  # env steps
 | 
				
			||||||
 | 
					        ])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def start_pos(self) -> Union[float, int, np.ndarray]:
 | 
				
			||||||
 | 
					        return self._start_pos
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def goal_pos(self) -> Union[float, int, np.ndarray]:
 | 
				
			||||||
 | 
					        raise ValueError("Goal position is not available and has to be learnt based on the environment.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def dt(self) -> Union[float, int]:
 | 
				
			||||||
 | 
					        return self.env.dt
 | 
				
			||||||
@ -98,14 +98,14 @@ class AlrMujocoEnv(PositionalEnv, AlrEnv):
 | 
				
			|||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def start_pos(self):
 | 
					    def start_pos(self):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Start position of the agent, for example joint angles of a Panda robot. Necessary for MP wrapped envs.
 | 
					        Start position of the agent, for example joint angles of a Panda robot. Necessary for MP wrapped simple_reacher.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        return self._start_pos
 | 
					        return self._start_pos
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def start_vel(self):
 | 
					    def start_vel(self):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Start velocity of the agent. Necessary for MP wrapped envs.
 | 
					        Start velocity of the agent. Necessary for MP wrapped simple_reacher.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        return self._start_vel
 | 
					        return self._start_vel
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -39,7 +39,7 @@ def config_save(dir_path, config):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def change_kp_in_xml(kp_list,
 | 
					def change_kp_in_xml(kp_list,
 | 
				
			||||||
                     model_path="/home/zhou/slow/table_tennis_rl/simulation/gymTableTennis/gym_table_tennis/envs/robotics/assets/table_tennis/right_arm_actuator.xml"):
 | 
					                     model_path="/home/zhou/slow/table_tennis_rl/simulation/gymTableTennis/gym_table_tennis/simple_reacher/robotics/assets/table_tennis/right_arm_actuator.xml"):
 | 
				
			||||||
    tree = ET.parse(model_path)
 | 
					    tree = ET.parse(model_path)
 | 
				
			||||||
    root = tree.getroot()
 | 
					    root = tree.getroot()
 | 
				
			||||||
    # for actuator in root.find("actuator"):
 | 
					    # for actuator in root.find("actuator"):
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user