Classical Controll envs: Follow new spec for render_mode
This commit is contained in:
parent
5db73f90c4
commit
b681129a46
@ -14,12 +14,14 @@ class BaseReacherEnv(gym.Env):
|
||||
Base class for all reaching environments.
|
||||
"""
|
||||
|
||||
def __init__(self, n_links: int, random_start: bool = True, allow_self_collision: bool = False):
|
||||
super().__init__()
|
||||
def __init__(self, n_links: int, random_start: bool = True, allow_self_collision: bool = False, render_mode: str = None, **kwargs):
|
||||
super().__init__(render_mode=render_mode, **kwargs)
|
||||
self.link_lengths = np.ones(n_links)
|
||||
self.n_links = n_links
|
||||
self._dt = 0.01
|
||||
|
||||
self.render_mode = render_mode
|
||||
|
||||
self.random_start = random_start
|
||||
|
||||
self.allow_self_collision = allow_self_collision
|
||||
|
@ -10,8 +10,8 @@ class BaseReacherDirectEnv(BaseReacherEnv):
|
||||
"""
|
||||
|
||||
def __init__(self, n_links: int, random_start: bool = True,
|
||||
allow_self_collision: bool = False):
|
||||
super().__init__(n_links, random_start, allow_self_collision)
|
||||
allow_self_collision: bool = False, **kwargs):
|
||||
super().__init__(n_links, random_start, allow_self_collision, **kwargs)
|
||||
|
||||
self.max_vel = 2 * np.pi
|
||||
action_bound = np.ones((self.n_links,)) * self.max_vel
|
||||
|
@ -10,8 +10,8 @@ class BaseReacherTorqueEnv(BaseReacherEnv):
|
||||
"""
|
||||
|
||||
def __init__(self, n_links: int, random_start: bool = True,
|
||||
allow_self_collision: bool = False):
|
||||
super().__init__(n_links, random_start, allow_self_collision)
|
||||
allow_self_collision: bool = False, **kwargs):
|
||||
super().__init__(n_links, random_start, allow_self_collision, **kwargs)
|
||||
|
||||
self.max_torque = 1000
|
||||
action_bound = np.ones((self.n_links,)) * self.max_torque
|
||||
|
@ -17,9 +17,9 @@ class HoleReacherEnv(BaseReacherDirectEnv):
|
||||
|
||||
def __init__(self, n_links: int, hole_x: Union[None, float] = None, hole_depth: Union[None, float] = None,
|
||||
hole_width: float = 1., random_start: bool = False, allow_self_collision: bool = False,
|
||||
allow_wall_collision: bool = False, collision_penalty: float = 1000, rew_fct: str = "simple"):
|
||||
allow_wall_collision: bool = False, collision_penalty: float = 1000, rew_fct: str = "simple", **kwargs):
|
||||
|
||||
super().__init__(n_links, random_start, allow_self_collision)
|
||||
super().__init__(n_links, random_start, allow_self_collision, **kwargs)
|
||||
|
||||
# provided initial parameters
|
||||
self.initial_x = hole_x # x-position of center of hole
|
||||
@ -178,7 +178,9 @@ class HoleReacherEnv(BaseReacherDirectEnv):
|
||||
|
||||
return False
|
||||
|
||||
def render(self, mode='human'):
|
||||
def render(self, mode=None):
|
||||
if mode==None:
|
||||
mode = self.render_mode
|
||||
if self.fig is None:
|
||||
# Create base figure once on the beginning. Afterwards only update
|
||||
plt.ion()
|
||||
|
@ -17,8 +17,8 @@ class SimpleReacherEnv(BaseReacherTorqueEnv):
|
||||
"""
|
||||
|
||||
def __init__(self, n_links: int, target: Union[None, Iterable] = None, random_start: bool = True,
|
||||
allow_self_collision: bool = False, ):
|
||||
super().__init__(n_links, random_start, allow_self_collision)
|
||||
allow_self_collision: bool = False, **kwargs):
|
||||
super().__init__(n_links, random_start, allow_self_collision, **kwargs)
|
||||
|
||||
# provided initial parameters
|
||||
self.inital_target = target
|
||||
@ -98,7 +98,9 @@ class SimpleReacherEnv(BaseReacherTorqueEnv):
|
||||
def _check_collisions(self) -> bool:
|
||||
return self._check_self_collision()
|
||||
|
||||
def render(self, mode='human'): # pragma: no cover
|
||||
def render(self, mode=None): # pragma: no cover
|
||||
if mode==None:
|
||||
mode = self.render_mode
|
||||
if self.fig is None:
|
||||
# Create base figure once on the beginning. Afterwards only update
|
||||
plt.ion()
|
||||
|
@ -13,9 +13,9 @@ from . import MPWrapper
|
||||
class ViaPointReacherEnv(BaseReacherDirectEnv):
|
||||
|
||||
def __init__(self, n_links, random_start: bool = False, via_target: Union[None, Iterable] = None,
|
||||
target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000):
|
||||
target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000, **kwargs):
|
||||
|
||||
super().__init__(n_links, random_start, allow_self_collision)
|
||||
super().__init__(n_links, random_start, allow_self_collision, **kwargs)
|
||||
|
||||
# provided initial parameters
|
||||
self.intitial_target = target # provided target value
|
||||
@ -123,7 +123,9 @@ class ViaPointReacherEnv(BaseReacherDirectEnv):
|
||||
def _check_collisions(self) -> bool:
|
||||
return self._check_self_collision()
|
||||
|
||||
def render(self, mode='human'):
|
||||
def render(self, mode=None):
|
||||
if mode==None:
|
||||
mode = self.render_mode
|
||||
goal_pos = self._goal.T
|
||||
via_pos = self._via_point.T
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user