Classical Controll envs: Follow new spec for render_mode

This commit is contained in:
Dominik Moritz Roth 2023-10-23 12:26:26 +02:00
parent 5db73f90c4
commit b681129a46
6 changed files with 23 additions and 15 deletions

View File

@ -14,12 +14,14 @@ class BaseReacherEnv(gym.Env):
Base class for all reaching environments. Base class for all reaching environments.
""" """
def __init__(self, n_links: int, random_start: bool = True, allow_self_collision: bool = False): def __init__(self, n_links: int, random_start: bool = True, allow_self_collision: bool = False, render_mode: str = None, **kwargs):
super().__init__() super().__init__(render_mode=render_mode, **kwargs)
self.link_lengths = np.ones(n_links) self.link_lengths = np.ones(n_links)
self.n_links = n_links self.n_links = n_links
self._dt = 0.01 self._dt = 0.01
self.render_mode = render_mode
self.random_start = random_start self.random_start = random_start
self.allow_self_collision = allow_self_collision self.allow_self_collision = allow_self_collision

View File

@ -10,8 +10,8 @@ class BaseReacherDirectEnv(BaseReacherEnv):
""" """
def __init__(self, n_links: int, random_start: bool = True, def __init__(self, n_links: int, random_start: bool = True,
allow_self_collision: bool = False): allow_self_collision: bool = False, **kwargs):
super().__init__(n_links, random_start, allow_self_collision) super().__init__(n_links, random_start, allow_self_collision, **kwargs)
self.max_vel = 2 * np.pi self.max_vel = 2 * np.pi
action_bound = np.ones((self.n_links,)) * self.max_vel action_bound = np.ones((self.n_links,)) * self.max_vel

View File

@ -10,8 +10,8 @@ class BaseReacherTorqueEnv(BaseReacherEnv):
""" """
def __init__(self, n_links: int, random_start: bool = True, def __init__(self, n_links: int, random_start: bool = True,
allow_self_collision: bool = False): allow_self_collision: bool = False, **kwargs):
super().__init__(n_links, random_start, allow_self_collision) super().__init__(n_links, random_start, allow_self_collision, **kwargs)
self.max_torque = 1000 self.max_torque = 1000
action_bound = np.ones((self.n_links,)) * self.max_torque action_bound = np.ones((self.n_links,)) * self.max_torque

View File

@ -17,9 +17,9 @@ class HoleReacherEnv(BaseReacherDirectEnv):
def __init__(self, n_links: int, hole_x: Union[None, float] = None, hole_depth: Union[None, float] = None, def __init__(self, n_links: int, hole_x: Union[None, float] = None, hole_depth: Union[None, float] = None,
hole_width: float = 1., random_start: bool = False, allow_self_collision: bool = False, hole_width: float = 1., random_start: bool = False, allow_self_collision: bool = False,
allow_wall_collision: bool = False, collision_penalty: float = 1000, rew_fct: str = "simple"): allow_wall_collision: bool = False, collision_penalty: float = 1000, rew_fct: str = "simple", **kwargs):
super().__init__(n_links, random_start, allow_self_collision) super().__init__(n_links, random_start, allow_self_collision, **kwargs)
# provided initial parameters # provided initial parameters
self.initial_x = hole_x # x-position of center of hole self.initial_x = hole_x # x-position of center of hole
@ -178,7 +178,9 @@ class HoleReacherEnv(BaseReacherDirectEnv):
return False return False
def render(self, mode='human'): def render(self, mode=None):
if mode==None:
mode = self.render_mode
if self.fig is None: if self.fig is None:
# Create base figure once on the beginning. Afterwards only update # Create base figure once on the beginning. Afterwards only update
plt.ion() plt.ion()

View File

@ -17,8 +17,8 @@ class SimpleReacherEnv(BaseReacherTorqueEnv):
""" """
def __init__(self, n_links: int, target: Union[None, Iterable] = None, random_start: bool = True, def __init__(self, n_links: int, target: Union[None, Iterable] = None, random_start: bool = True,
allow_self_collision: bool = False, ): allow_self_collision: bool = False, **kwargs):
super().__init__(n_links, random_start, allow_self_collision) super().__init__(n_links, random_start, allow_self_collision, **kwargs)
# provided initial parameters # provided initial parameters
self.inital_target = target self.inital_target = target
@ -98,7 +98,9 @@ class SimpleReacherEnv(BaseReacherTorqueEnv):
def _check_collisions(self) -> bool: def _check_collisions(self) -> bool:
return self._check_self_collision() return self._check_self_collision()
def render(self, mode='human'): # pragma: no cover def render(self, mode=None): # pragma: no cover
if mode==None:
mode = self.render_mode
if self.fig is None: if self.fig is None:
# Create base figure once on the beginning. Afterwards only update # Create base figure once on the beginning. Afterwards only update
plt.ion() plt.ion()

View File

@ -13,9 +13,9 @@ from . import MPWrapper
class ViaPointReacherEnv(BaseReacherDirectEnv): class ViaPointReacherEnv(BaseReacherDirectEnv):
def __init__(self, n_links, random_start: bool = False, via_target: Union[None, Iterable] = None, def __init__(self, n_links, random_start: bool = False, via_target: Union[None, Iterable] = None,
target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000): target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000, **kwargs):
super().__init__(n_links, random_start, allow_self_collision) super().__init__(n_links, random_start, allow_self_collision, **kwargs)
# provided initial parameters # provided initial parameters
self.intitial_target = target # provided target value self.intitial_target = target # provided target value
@ -123,7 +123,9 @@ class ViaPointReacherEnv(BaseReacherDirectEnv):
def _check_collisions(self) -> bool: def _check_collisions(self) -> bool:
return self._check_self_collision() return self._check_self_collision()
def render(self, mode='human'): def render(self, mode=None):
if mode==None:
mode = self.render_mode
goal_pos = self._goal.T goal_pos = self._goal.T
via_pos = self._via_point.T via_pos = self._via_point.T