Classical Controll envs: Follow new spec for render_mode

This commit is contained in:
Dominik Moritz Roth 2023-10-23 12:26:26 +02:00
parent 5db73f90c4
commit b681129a46
6 changed files with 23 additions and 15 deletions

View File

@ -14,12 +14,14 @@ class BaseReacherEnv(gym.Env):
Base class for all reaching environments.
"""
def __init__(self, n_links: int, random_start: bool = True, allow_self_collision: bool = False):
super().__init__()
def __init__(self, n_links: int, random_start: bool = True, allow_self_collision: bool = False, render_mode: str = None, **kwargs):
super().__init__(render_mode=render_mode, **kwargs)
self.link_lengths = np.ones(n_links)
self.n_links = n_links
self._dt = 0.01
self.render_mode = render_mode
self.random_start = random_start
self.allow_self_collision = allow_self_collision

View File

@ -10,8 +10,8 @@ class BaseReacherDirectEnv(BaseReacherEnv):
"""
def __init__(self, n_links: int, random_start: bool = True,
allow_self_collision: bool = False):
super().__init__(n_links, random_start, allow_self_collision)
allow_self_collision: bool = False, **kwargs):
super().__init__(n_links, random_start, allow_self_collision, **kwargs)
self.max_vel = 2 * np.pi
action_bound = np.ones((self.n_links,)) * self.max_vel

View File

@ -10,8 +10,8 @@ class BaseReacherTorqueEnv(BaseReacherEnv):
"""
def __init__(self, n_links: int, random_start: bool = True,
allow_self_collision: bool = False):
super().__init__(n_links, random_start, allow_self_collision)
allow_self_collision: bool = False, **kwargs):
super().__init__(n_links, random_start, allow_self_collision, **kwargs)
self.max_torque = 1000
action_bound = np.ones((self.n_links,)) * self.max_torque

View File

@ -17,9 +17,9 @@ class HoleReacherEnv(BaseReacherDirectEnv):
def __init__(self, n_links: int, hole_x: Union[None, float] = None, hole_depth: Union[None, float] = None,
hole_width: float = 1., random_start: bool = False, allow_self_collision: bool = False,
allow_wall_collision: bool = False, collision_penalty: float = 1000, rew_fct: str = "simple"):
allow_wall_collision: bool = False, collision_penalty: float = 1000, rew_fct: str = "simple", **kwargs):
super().__init__(n_links, random_start, allow_self_collision)
super().__init__(n_links, random_start, allow_self_collision, **kwargs)
# provided initial parameters
self.initial_x = hole_x # x-position of center of hole
@ -178,7 +178,9 @@ class HoleReacherEnv(BaseReacherDirectEnv):
return False
def render(self, mode='human'):
def render(self, mode=None):
if mode==None:
mode = self.render_mode
if self.fig is None:
# Create base figure once on the beginning. Afterwards only update
plt.ion()

View File

@ -17,8 +17,8 @@ class SimpleReacherEnv(BaseReacherTorqueEnv):
"""
def __init__(self, n_links: int, target: Union[None, Iterable] = None, random_start: bool = True,
allow_self_collision: bool = False, ):
super().__init__(n_links, random_start, allow_self_collision)
allow_self_collision: bool = False, **kwargs):
super().__init__(n_links, random_start, allow_self_collision, **kwargs)
# provided initial parameters
self.inital_target = target
@ -98,7 +98,9 @@ class SimpleReacherEnv(BaseReacherTorqueEnv):
def _check_collisions(self) -> bool:
return self._check_self_collision()
def render(self, mode='human'): # pragma: no cover
def render(self, mode=None): # pragma: no cover
if mode==None:
mode = self.render_mode
if self.fig is None:
# Create base figure once on the beginning. Afterwards only update
plt.ion()

View File

@ -13,9 +13,9 @@ from . import MPWrapper
class ViaPointReacherEnv(BaseReacherDirectEnv):
def __init__(self, n_links, random_start: bool = False, via_target: Union[None, Iterable] = None,
target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000):
target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000, **kwargs):
super().__init__(n_links, random_start, allow_self_collision)
super().__init__(n_links, random_start, allow_self_collision, **kwargs)
# provided initial parameters
self.intitial_target = target # provided target value
@ -123,7 +123,9 @@ class ViaPointReacherEnv(BaseReacherDirectEnv):
def _check_collisions(self) -> bool:
return self._check_self_collision()
def render(self, mode='human'):
def render(self, mode=None):
if mode==None:
mode = self.render_mode
goal_pos = self._goal.T
via_pos = self._via_point.T