Classical Controll envs: Follow new spec for render_mode
This commit is contained in:
parent
5db73f90c4
commit
b681129a46
@ -14,12 +14,14 @@ class BaseReacherEnv(gym.Env):
|
|||||||
Base class for all reaching environments.
|
Base class for all reaching environments.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, n_links: int, random_start: bool = True, allow_self_collision: bool = False):
|
def __init__(self, n_links: int, random_start: bool = True, allow_self_collision: bool = False, render_mode: str = None, **kwargs):
|
||||||
super().__init__()
|
super().__init__(render_mode=render_mode, **kwargs)
|
||||||
self.link_lengths = np.ones(n_links)
|
self.link_lengths = np.ones(n_links)
|
||||||
self.n_links = n_links
|
self.n_links = n_links
|
||||||
self._dt = 0.01
|
self._dt = 0.01
|
||||||
|
|
||||||
|
self.render_mode = render_mode
|
||||||
|
|
||||||
self.random_start = random_start
|
self.random_start = random_start
|
||||||
|
|
||||||
self.allow_self_collision = allow_self_collision
|
self.allow_self_collision = allow_self_collision
|
||||||
|
@ -10,8 +10,8 @@ class BaseReacherDirectEnv(BaseReacherEnv):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, n_links: int, random_start: bool = True,
|
def __init__(self, n_links: int, random_start: bool = True,
|
||||||
allow_self_collision: bool = False):
|
allow_self_collision: bool = False, **kwargs):
|
||||||
super().__init__(n_links, random_start, allow_self_collision)
|
super().__init__(n_links, random_start, allow_self_collision, **kwargs)
|
||||||
|
|
||||||
self.max_vel = 2 * np.pi
|
self.max_vel = 2 * np.pi
|
||||||
action_bound = np.ones((self.n_links,)) * self.max_vel
|
action_bound = np.ones((self.n_links,)) * self.max_vel
|
||||||
|
@ -10,8 +10,8 @@ class BaseReacherTorqueEnv(BaseReacherEnv):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, n_links: int, random_start: bool = True,
|
def __init__(self, n_links: int, random_start: bool = True,
|
||||||
allow_self_collision: bool = False):
|
allow_self_collision: bool = False, **kwargs):
|
||||||
super().__init__(n_links, random_start, allow_self_collision)
|
super().__init__(n_links, random_start, allow_self_collision, **kwargs)
|
||||||
|
|
||||||
self.max_torque = 1000
|
self.max_torque = 1000
|
||||||
action_bound = np.ones((self.n_links,)) * self.max_torque
|
action_bound = np.ones((self.n_links,)) * self.max_torque
|
||||||
|
@ -17,9 +17,9 @@ class HoleReacherEnv(BaseReacherDirectEnv):
|
|||||||
|
|
||||||
def __init__(self, n_links: int, hole_x: Union[None, float] = None, hole_depth: Union[None, float] = None,
|
def __init__(self, n_links: int, hole_x: Union[None, float] = None, hole_depth: Union[None, float] = None,
|
||||||
hole_width: float = 1., random_start: bool = False, allow_self_collision: bool = False,
|
hole_width: float = 1., random_start: bool = False, allow_self_collision: bool = False,
|
||||||
allow_wall_collision: bool = False, collision_penalty: float = 1000, rew_fct: str = "simple"):
|
allow_wall_collision: bool = False, collision_penalty: float = 1000, rew_fct: str = "simple", **kwargs):
|
||||||
|
|
||||||
super().__init__(n_links, random_start, allow_self_collision)
|
super().__init__(n_links, random_start, allow_self_collision, **kwargs)
|
||||||
|
|
||||||
# provided initial parameters
|
# provided initial parameters
|
||||||
self.initial_x = hole_x # x-position of center of hole
|
self.initial_x = hole_x # x-position of center of hole
|
||||||
@ -178,7 +178,9 @@ class HoleReacherEnv(BaseReacherDirectEnv):
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def render(self, mode='human'):
|
def render(self, mode=None):
|
||||||
|
if mode==None:
|
||||||
|
mode = self.render_mode
|
||||||
if self.fig is None:
|
if self.fig is None:
|
||||||
# Create base figure once on the beginning. Afterwards only update
|
# Create base figure once on the beginning. Afterwards only update
|
||||||
plt.ion()
|
plt.ion()
|
||||||
|
@ -17,8 +17,8 @@ class SimpleReacherEnv(BaseReacherTorqueEnv):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, n_links: int, target: Union[None, Iterable] = None, random_start: bool = True,
|
def __init__(self, n_links: int, target: Union[None, Iterable] = None, random_start: bool = True,
|
||||||
allow_self_collision: bool = False, ):
|
allow_self_collision: bool = False, **kwargs):
|
||||||
super().__init__(n_links, random_start, allow_self_collision)
|
super().__init__(n_links, random_start, allow_self_collision, **kwargs)
|
||||||
|
|
||||||
# provided initial parameters
|
# provided initial parameters
|
||||||
self.inital_target = target
|
self.inital_target = target
|
||||||
@ -98,7 +98,9 @@ class SimpleReacherEnv(BaseReacherTorqueEnv):
|
|||||||
def _check_collisions(self) -> bool:
|
def _check_collisions(self) -> bool:
|
||||||
return self._check_self_collision()
|
return self._check_self_collision()
|
||||||
|
|
||||||
def render(self, mode='human'): # pragma: no cover
|
def render(self, mode=None): # pragma: no cover
|
||||||
|
if mode==None:
|
||||||
|
mode = self.render_mode
|
||||||
if self.fig is None:
|
if self.fig is None:
|
||||||
# Create base figure once on the beginning. Afterwards only update
|
# Create base figure once on the beginning. Afterwards only update
|
||||||
plt.ion()
|
plt.ion()
|
||||||
|
@ -13,9 +13,9 @@ from . import MPWrapper
|
|||||||
class ViaPointReacherEnv(BaseReacherDirectEnv):
|
class ViaPointReacherEnv(BaseReacherDirectEnv):
|
||||||
|
|
||||||
def __init__(self, n_links, random_start: bool = False, via_target: Union[None, Iterable] = None,
|
def __init__(self, n_links, random_start: bool = False, via_target: Union[None, Iterable] = None,
|
||||||
target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000):
|
target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000, **kwargs):
|
||||||
|
|
||||||
super().__init__(n_links, random_start, allow_self_collision)
|
super().__init__(n_links, random_start, allow_self_collision, **kwargs)
|
||||||
|
|
||||||
# provided initial parameters
|
# provided initial parameters
|
||||||
self.intitial_target = target # provided target value
|
self.intitial_target = target # provided target value
|
||||||
@ -123,7 +123,9 @@ class ViaPointReacherEnv(BaseReacherDirectEnv):
|
|||||||
def _check_collisions(self) -> bool:
|
def _check_collisions(self) -> bool:
|
||||||
return self._check_self_collision()
|
return self._check_self_collision()
|
||||||
|
|
||||||
def render(self, mode='human'):
|
def render(self, mode=None):
|
||||||
|
if mode==None:
|
||||||
|
mode = self.render_mode
|
||||||
goal_pos = self._goal.T
|
goal_pos = self._goal.T
|
||||||
via_pos = self._via_point.T
|
via_pos = self._via_point.T
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user