From b681129a46f4bbc01b05b9cb07de55a870f5435f Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Mon, 23 Oct 2023 12:26:26 +0200 Subject: [PATCH] Classical Controll envs: Follow new spec for render_mode --- .../envs/classic_control/base_reacher/base_reacher.py | 6 ++++-- .../classic_control/base_reacher/base_reacher_direct.py | 4 ++-- .../classic_control/base_reacher/base_reacher_torque.py | 4 ++-- .../envs/classic_control/hole_reacher/hole_reacher.py | 8 +++++--- .../envs/classic_control/simple_reacher/simple_reacher.py | 8 +++++--- .../classic_control/viapoint_reacher/viapoint_reacher.py | 8 +++++--- 6 files changed, 23 insertions(+), 15 deletions(-) diff --git a/fancy_gym/envs/classic_control/base_reacher/base_reacher.py b/fancy_gym/envs/classic_control/base_reacher/base_reacher.py index 18305fd..a03115d 100644 --- a/fancy_gym/envs/classic_control/base_reacher/base_reacher.py +++ b/fancy_gym/envs/classic_control/base_reacher/base_reacher.py @@ -14,12 +14,14 @@ class BaseReacherEnv(gym.Env): Base class for all reaching environments. """ - def __init__(self, n_links: int, random_start: bool = True, allow_self_collision: bool = False): - super().__init__() + def __init__(self, n_links: int, random_start: bool = True, allow_self_collision: bool = False, render_mode: str = None, **kwargs): + super().__init__(render_mode=render_mode, **kwargs) self.link_lengths = np.ones(n_links) self.n_links = n_links self._dt = 0.01 + self.render_mode = render_mode + self.random_start = random_start self.allow_self_collision = allow_self_collision diff --git a/fancy_gym/envs/classic_control/base_reacher/base_reacher_direct.py b/fancy_gym/envs/classic_control/base_reacher/base_reacher_direct.py index 6878922..8e47df8 100644 --- a/fancy_gym/envs/classic_control/base_reacher/base_reacher_direct.py +++ b/fancy_gym/envs/classic_control/base_reacher/base_reacher_direct.py @@ -10,8 +10,8 @@ class BaseReacherDirectEnv(BaseReacherEnv): """ def __init__(self, n_links: int, random_start: bool = True, - allow_self_collision: bool = False): - super().__init__(n_links, random_start, allow_self_collision) + allow_self_collision: bool = False, **kwargs): + super().__init__(n_links, random_start, allow_self_collision, **kwargs) self.max_vel = 2 * np.pi action_bound = np.ones((self.n_links,)) * self.max_vel diff --git a/fancy_gym/envs/classic_control/base_reacher/base_reacher_torque.py b/fancy_gym/envs/classic_control/base_reacher/base_reacher_torque.py index c9a7d4f..1dcd9ba 100644 --- a/fancy_gym/envs/classic_control/base_reacher/base_reacher_torque.py +++ b/fancy_gym/envs/classic_control/base_reacher/base_reacher_torque.py @@ -10,8 +10,8 @@ class BaseReacherTorqueEnv(BaseReacherEnv): """ def __init__(self, n_links: int, random_start: bool = True, - allow_self_collision: bool = False): - super().__init__(n_links, random_start, allow_self_collision) + allow_self_collision: bool = False, **kwargs): + super().__init__(n_links, random_start, allow_self_collision, **kwargs) self.max_torque = 1000 action_bound = np.ones((self.n_links,)) * self.max_torque diff --git a/fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py b/fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py index c9e0a61..4e5caaf 100644 --- a/fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py +++ b/fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py @@ -17,9 +17,9 @@ class HoleReacherEnv(BaseReacherDirectEnv): def __init__(self, n_links: int, hole_x: Union[None, float] = None, hole_depth: Union[None, float] = None, hole_width: float = 1., random_start: bool = False, allow_self_collision: bool = False, - allow_wall_collision: bool = False, collision_penalty: float = 1000, rew_fct: str = "simple"): + allow_wall_collision: bool = False, collision_penalty: float = 1000, rew_fct: str = "simple", **kwargs): - super().__init__(n_links, random_start, allow_self_collision) + super().__init__(n_links, random_start, allow_self_collision, **kwargs) # provided initial parameters self.initial_x = hole_x # x-position of center of hole @@ -178,7 +178,9 @@ class HoleReacherEnv(BaseReacherDirectEnv): return False - def render(self, mode='human'): + def render(self, mode=None): + if mode==None: + mode = self.render_mode if self.fig is None: # Create base figure once on the beginning. Afterwards only update plt.ion() diff --git a/fancy_gym/envs/classic_control/simple_reacher/simple_reacher.py b/fancy_gym/envs/classic_control/simple_reacher/simple_reacher.py index 3afd021..40a8153 100644 --- a/fancy_gym/envs/classic_control/simple_reacher/simple_reacher.py +++ b/fancy_gym/envs/classic_control/simple_reacher/simple_reacher.py @@ -17,8 +17,8 @@ class SimpleReacherEnv(BaseReacherTorqueEnv): """ def __init__(self, n_links: int, target: Union[None, Iterable] = None, random_start: bool = True, - allow_self_collision: bool = False, ): - super().__init__(n_links, random_start, allow_self_collision) + allow_self_collision: bool = False, **kwargs): + super().__init__(n_links, random_start, allow_self_collision, **kwargs) # provided initial parameters self.inital_target = target @@ -98,7 +98,9 @@ class SimpleReacherEnv(BaseReacherTorqueEnv): def _check_collisions(self) -> bool: return self._check_self_collision() - def render(self, mode='human'): # pragma: no cover + def render(self, mode=None): # pragma: no cover + if mode==None: + mode = self.render_mode if self.fig is None: # Create base figure once on the beginning. Afterwards only update plt.ion() diff --git a/fancy_gym/envs/classic_control/viapoint_reacher/viapoint_reacher.py b/fancy_gym/envs/classic_control/viapoint_reacher/viapoint_reacher.py index e4d9091..932f50a 100644 --- a/fancy_gym/envs/classic_control/viapoint_reacher/viapoint_reacher.py +++ b/fancy_gym/envs/classic_control/viapoint_reacher/viapoint_reacher.py @@ -13,9 +13,9 @@ from . import MPWrapper class ViaPointReacherEnv(BaseReacherDirectEnv): def __init__(self, n_links, random_start: bool = False, via_target: Union[None, Iterable] = None, - target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000): + target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000, **kwargs): - super().__init__(n_links, random_start, allow_self_collision) + super().__init__(n_links, random_start, allow_self_collision, **kwargs) # provided initial parameters self.intitial_target = target # provided target value @@ -123,7 +123,9 @@ class ViaPointReacherEnv(BaseReacherDirectEnv): def _check_collisions(self) -> bool: return self._check_self_collision() - def render(self, mode='human'): + def render(self, mode=None): + if mode==None: + mode = self.render_mode goal_pos = self._goal.T via_pos = self._via_point.T