From 5db73f90c43acedaafb73bbee54ecda41001d76c Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Mon, 23 Oct 2023 12:25:53 +0200 Subject: [PATCH 1/8] Update BB wrapper to follow new spec for render_kwargs --- fancy_gym/black_box/black_box_wrapper.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/fancy_gym/black_box/black_box_wrapper.py b/fancy_gym/black_box/black_box_wrapper.py index 6da24c7..c963f0c 100644 --- a/fancy_gym/black_box/black_box_wrapper.py +++ b/fancy_gym/black_box/black_box_wrapper.py @@ -75,7 +75,7 @@ class BlackBoxWrapper(gym.ObservationWrapper): self.observation_space = self._get_observation_space() # rendering - self.render_kwargs = {} + self.do_render = False self.verbose = verbose # condition value @@ -164,7 +164,7 @@ class BlackBoxWrapper(gym.ObservationWrapper): dtype=self.env.observation_space.dtype) infos = dict() - done = False + terminated, truncated = False, False if not traj_is_valid: obs, trajectory_return, terminated, truncated, infos = self.env.invalid_traj_callback(action, position, velocity, @@ -190,8 +190,9 @@ class BlackBoxWrapper(gym.ObservationWrapper): elems[t] = v infos[k] = elems - if self.render_kwargs: - self.env.render(**self.render_kwargs) + if self.do_render: + self.env.render() + if terminated or truncated or (self.replanning_schedule(self.env.get_wrapper_attr('current_pos'), self.env.get_wrapper_attr('current_vel'), obs, c_action, t + 1 + self.current_traj_steps) and self.plan_steps < self.max_planning_times): @@ -215,10 +216,8 @@ class BlackBoxWrapper(gym.ObservationWrapper): trajectory_return = self.reward_aggregation(rewards[:t + 1]) return self.observation(obs), trajectory_return, terminated, truncated, infos - def render(self, **kwargs): - """Only set render options here, such that they can be used during the rollout. - This only needs to be called once""" - self.render_kwargs = kwargs + def render(self): + self.do_render = True def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \ -> Tuple[ObsType, Dict[str, Any]]: From b681129a46f4bbc01b05b9cb07de55a870f5435f Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Mon, 23 Oct 2023 12:26:26 +0200 Subject: [PATCH 2/8] Classical Controll envs: Follow new spec for render_mode --- .../envs/classic_control/base_reacher/base_reacher.py | 6 ++++-- .../classic_control/base_reacher/base_reacher_direct.py | 4 ++-- .../classic_control/base_reacher/base_reacher_torque.py | 4 ++-- .../envs/classic_control/hole_reacher/hole_reacher.py | 8 +++++--- .../envs/classic_control/simple_reacher/simple_reacher.py | 8 +++++--- .../classic_control/viapoint_reacher/viapoint_reacher.py | 8 +++++--- 6 files changed, 23 insertions(+), 15 deletions(-) diff --git a/fancy_gym/envs/classic_control/base_reacher/base_reacher.py b/fancy_gym/envs/classic_control/base_reacher/base_reacher.py index 18305fd..a03115d 100644 --- a/fancy_gym/envs/classic_control/base_reacher/base_reacher.py +++ b/fancy_gym/envs/classic_control/base_reacher/base_reacher.py @@ -14,12 +14,14 @@ class BaseReacherEnv(gym.Env): Base class for all reaching environments. """ - def __init__(self, n_links: int, random_start: bool = True, allow_self_collision: bool = False): - super().__init__() + def __init__(self, n_links: int, random_start: bool = True, allow_self_collision: bool = False, render_mode: str = None, **kwargs): + super().__init__(render_mode=render_mode, **kwargs) self.link_lengths = np.ones(n_links) self.n_links = n_links self._dt = 0.01 + self.render_mode = render_mode + self.random_start = random_start self.allow_self_collision = allow_self_collision diff --git a/fancy_gym/envs/classic_control/base_reacher/base_reacher_direct.py b/fancy_gym/envs/classic_control/base_reacher/base_reacher_direct.py index 6878922..8e47df8 100644 --- a/fancy_gym/envs/classic_control/base_reacher/base_reacher_direct.py +++ b/fancy_gym/envs/classic_control/base_reacher/base_reacher_direct.py @@ -10,8 +10,8 @@ class BaseReacherDirectEnv(BaseReacherEnv): """ def __init__(self, n_links: int, random_start: bool = True, - allow_self_collision: bool = False): - super().__init__(n_links, random_start, allow_self_collision) + allow_self_collision: bool = False, **kwargs): + super().__init__(n_links, random_start, allow_self_collision, **kwargs) self.max_vel = 2 * np.pi action_bound = np.ones((self.n_links,)) * self.max_vel diff --git a/fancy_gym/envs/classic_control/base_reacher/base_reacher_torque.py b/fancy_gym/envs/classic_control/base_reacher/base_reacher_torque.py index c9a7d4f..1dcd9ba 100644 --- a/fancy_gym/envs/classic_control/base_reacher/base_reacher_torque.py +++ b/fancy_gym/envs/classic_control/base_reacher/base_reacher_torque.py @@ -10,8 +10,8 @@ class BaseReacherTorqueEnv(BaseReacherEnv): """ def __init__(self, n_links: int, random_start: bool = True, - allow_self_collision: bool = False): - super().__init__(n_links, random_start, allow_self_collision) + allow_self_collision: bool = False, **kwargs): + super().__init__(n_links, random_start, allow_self_collision, **kwargs) self.max_torque = 1000 action_bound = np.ones((self.n_links,)) * self.max_torque diff --git a/fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py b/fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py index c9e0a61..4e5caaf 100644 --- a/fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py +++ b/fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py @@ -17,9 +17,9 @@ class HoleReacherEnv(BaseReacherDirectEnv): def __init__(self, n_links: int, hole_x: Union[None, float] = None, hole_depth: Union[None, float] = None, hole_width: float = 1., random_start: bool = False, allow_self_collision: bool = False, - allow_wall_collision: bool = False, collision_penalty: float = 1000, rew_fct: str = "simple"): + allow_wall_collision: bool = False, collision_penalty: float = 1000, rew_fct: str = "simple", **kwargs): - super().__init__(n_links, random_start, allow_self_collision) + super().__init__(n_links, random_start, allow_self_collision, **kwargs) # provided initial parameters self.initial_x = hole_x # x-position of center of hole @@ -178,7 +178,9 @@ class HoleReacherEnv(BaseReacherDirectEnv): return False - def render(self, mode='human'): + def render(self, mode=None): + if mode==None: + mode = self.render_mode if self.fig is None: # Create base figure once on the beginning. Afterwards only update plt.ion() diff --git a/fancy_gym/envs/classic_control/simple_reacher/simple_reacher.py b/fancy_gym/envs/classic_control/simple_reacher/simple_reacher.py index 3afd021..40a8153 100644 --- a/fancy_gym/envs/classic_control/simple_reacher/simple_reacher.py +++ b/fancy_gym/envs/classic_control/simple_reacher/simple_reacher.py @@ -17,8 +17,8 @@ class SimpleReacherEnv(BaseReacherTorqueEnv): """ def __init__(self, n_links: int, target: Union[None, Iterable] = None, random_start: bool = True, - allow_self_collision: bool = False, ): - super().__init__(n_links, random_start, allow_self_collision) + allow_self_collision: bool = False, **kwargs): + super().__init__(n_links, random_start, allow_self_collision, **kwargs) # provided initial parameters self.inital_target = target @@ -98,7 +98,9 @@ class SimpleReacherEnv(BaseReacherTorqueEnv): def _check_collisions(self) -> bool: return self._check_self_collision() - def render(self, mode='human'): # pragma: no cover + def render(self, mode=None): # pragma: no cover + if mode==None: + mode = self.render_mode if self.fig is None: # Create base figure once on the beginning. Afterwards only update plt.ion() diff --git a/fancy_gym/envs/classic_control/viapoint_reacher/viapoint_reacher.py b/fancy_gym/envs/classic_control/viapoint_reacher/viapoint_reacher.py index e4d9091..932f50a 100644 --- a/fancy_gym/envs/classic_control/viapoint_reacher/viapoint_reacher.py +++ b/fancy_gym/envs/classic_control/viapoint_reacher/viapoint_reacher.py @@ -13,9 +13,9 @@ from . import MPWrapper class ViaPointReacherEnv(BaseReacherDirectEnv): def __init__(self, n_links, random_start: bool = False, via_target: Union[None, Iterable] = None, - target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000): + target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000, **kwargs): - super().__init__(n_links, random_start, allow_self_collision) + super().__init__(n_links, random_start, allow_self_collision, **kwargs) # provided initial parameters self.intitial_target = target # provided target value @@ -123,7 +123,9 @@ class ViaPointReacherEnv(BaseReacherDirectEnv): def _check_collisions(self) -> bool: return self._check_self_collision() - def render(self, mode='human'): + def render(self, mode=None): + if mode==None: + mode = self.render_mode goal_pos = self._goal.T via_pos = self._via_point.T From f7a493d8e5a6259e8b3f3e0e4ba8e792d60b0c5d Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Mon, 23 Oct 2023 12:27:13 +0200 Subject: [PATCH 3/8] Mujoco envs: Follow new spec for render_mode --- fancy_gym/envs/mujoco/ant_jump/ant_jump.py | 3 ++- .../mujoco/box_pushing/box_pushing_env.py | 24 +++++++++---------- .../half_cheetah_jump/half_cheetah_jump.py | 6 +++-- .../envs/mujoco/hopper_jump/hopper_jump.py | 4 +++- .../mujoco/hopper_jump/hopper_jump_on_box.py | 6 +++-- .../envs/mujoco/hopper_throw/hopper_throw.py | 6 +++-- .../hopper_throw/hopper_throw_in_basket.py | 6 +++-- .../mujoco/table_tennis/table_tennis_env.py | 5 ++-- .../mujoco/walker_2d_jump/walker_2d_jump.py | 6 +++-- 9 files changed, 40 insertions(+), 26 deletions(-) diff --git a/fancy_gym/envs/mujoco/ant_jump/ant_jump.py b/fancy_gym/envs/mujoco/ant_jump/ant_jump.py index ed6bea5..97cde0e 100644 --- a/fancy_gym/envs/mujoco/ant_jump/ant_jump.py +++ b/fancy_gym/envs/mujoco/ant_jump/ant_jump.py @@ -101,6 +101,7 @@ class AntJumpEnv(AntEnvCustomXML): contact_force_range=(-1.0, 1.0), reset_noise_scale=0.1, exclude_current_positions_from_observation=True, + **kwargs ): self.current_step = 0 self.max_height = 0 @@ -113,7 +114,7 @@ class AntJumpEnv(AntEnvCustomXML): healthy_z_range=healthy_z_range, contact_force_range=contact_force_range, reset_noise_scale=reset_noise_scale, - exclude_current_positions_from_observation=exclude_current_positions_from_observation) + exclude_current_positions_from_observation=exclude_current_positions_from_observation, **kwargs) def step(self, action): self.current_step += 1 diff --git a/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py b/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py index 932e3df..9b512a8 100644 --- a/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py +++ b/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py @@ -36,7 +36,7 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle): "render_fps": 50 } - def __init__(self, frame_skip: int = 10, random_init: bool = False): + def __init__(self, frame_skip: int = 10, random_init: bool = False, **kwargs): utils.EzPickle.__init__(**locals()) self._steps = 0 self.init_qpos_box_pushing = np.array([0., 0., 0., -1.5, 0., 1.5, 0., 0., 0., 0.6, 0.45, 0.0, 1., 0., 0., 0.]) @@ -58,7 +58,7 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle): MujocoEnv.__init__(self, model_path=os.path.join(os.path.dirname(__file__), "assets", "box_pushing.xml"), frame_skip=self.frame_skip, - observation_space=self.observation_space) + observation_space=self.observation_space, **kwargs) self.action_space = spaces.Box(low=-1, high=1, shape=(7,)) def step(self, action): @@ -305,8 +305,8 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle): class BoxPushingDense(BoxPushingEnvBase): - def __init__(self, frame_skip: int = 10, random_init: bool = False): - super(BoxPushingDense, self).__init__(frame_skip=frame_skip, random_init=random_init) + def __init__(self, **kwargs): + super(BoxPushingDense, self).__init__(**kwargs) def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat, rod_tip_pos, rod_quat, qpos, qvel, action): joint_penalty = self._joint_limit_violate_penalty(qpos, @@ -329,8 +329,8 @@ class BoxPushingDense(BoxPushingEnvBase): class BoxPushingTemporalSparse(BoxPushingEnvBase): - def __init__(self, frame_skip: int = 10, random_init: bool = False): - super(BoxPushingTemporalSparse, self).__init__(frame_skip=frame_skip, random_init=random_init) + def __init__(self, **kwargs): + super(BoxPushingTemporalSparse, self).__init__(**kwargs) def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat, rod_tip_pos, rod_quat, qpos, qvel, action): @@ -361,8 +361,8 @@ class BoxPushingTemporalSparse(BoxPushingEnvBase): class BoxPushingTemporalSpatialSparse(BoxPushingEnvBase): - def __init__(self, frame_skip: int = 10, random_init: bool = False): - super(BoxPushingTemporalSpatialSparse, self).__init__(frame_skip=frame_skip, random_init=random_init) + def __init__(self, **kwargs): + super(BoxPushingTemporalSpatialSparse, self).__init__(**kwargs) def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat, rod_tip_pos, rod_quat, qpos, qvel, action): @@ -392,8 +392,8 @@ class BoxPushingTemporalSpatialSparse(BoxPushingEnvBase): class BoxPushingTemporalSpatialSparse2(BoxPushingEnvBase): - def __init__(self, frame_skip: int = 10, random_init: bool = False): - super(BoxPushingTemporalSpatialSparse2, self).__init__(frame_skip=frame_skip, random_init=random_init) + def __init__(self, **kwargs): + super(BoxPushingTemporalSpatialSparse2, self).__init__(**kwargs) def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat, rod_tip_pos, rod_quat, qpos, qvel, action): @@ -428,8 +428,8 @@ class BoxPushingTemporalSpatialSparse2(BoxPushingEnvBase): class BoxPushingNoConstraintSparse(BoxPushingEnvBase): - def __init__(self, frame_skip: int = 10, random_init: bool = False): - super(BoxPushingNoConstraintSparse, self).__init__(frame_skip=frame_skip, random_init=random_init) + def __init__(self, **kwargs): + super(BoxPushingNoConstraintSparse, self).__init__(**kwargs) def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat, rod_tip_pos, rod_quat, qpos, qvel, action): diff --git a/fancy_gym/envs/mujoco/half_cheetah_jump/half_cheetah_jump.py b/fancy_gym/envs/mujoco/half_cheetah_jump/half_cheetah_jump.py index f15a9f4..088f959 100644 --- a/fancy_gym/envs/mujoco/half_cheetah_jump/half_cheetah_jump.py +++ b/fancy_gym/envs/mujoco/half_cheetah_jump/half_cheetah_jump.py @@ -74,7 +74,8 @@ class HalfCheetahJumpEnv(HalfCheetahEnvCustomXML): reset_noise_scale=0.1, context=True, exclude_current_positions_from_observation=True, - max_episode_steps=100): + max_episode_steps=100, + **kwargs): self.current_step = 0 self.max_height = 0 # self.max_episode_steps = max_episode_steps @@ -85,7 +86,8 @@ class HalfCheetahJumpEnv(HalfCheetahEnvCustomXML): forward_reward_weight=forward_reward_weight, ctrl_cost_weight=ctrl_cost_weight, reset_noise_scale=reset_noise_scale, - exclude_current_positions_from_observation=exclude_current_positions_from_observation) + exclude_current_positions_from_observation=exclude_current_positions_from_observation, + **kwargs) def step(self, action): diff --git a/fancy_gym/envs/mujoco/hopper_jump/hopper_jump.py b/fancy_gym/envs/mujoco/hopper_jump/hopper_jump.py index b77cab1..ae431ab 100644 --- a/fancy_gym/envs/mujoco/hopper_jump/hopper_jump.py +++ b/fancy_gym/envs/mujoco/hopper_jump/hopper_jump.py @@ -115,6 +115,7 @@ class HopperJumpEnv(HopperEnvCustomXML): reset_noise_scale=5e-3, exclude_current_positions_from_observation=False, sparse=False, + **kwargs ): self.sparse = sparse @@ -141,7 +142,8 @@ class HopperJumpEnv(HopperEnvCustomXML): healthy_z_range=healthy_z_range, healthy_angle_range=healthy_angle_range, reset_noise_scale=reset_noise_scale, - exclude_current_positions_from_observation=exclude_current_positions_from_observation) + exclude_current_positions_from_observation=exclude_current_positions_from_observation, + **kwargs) # increase initial height self.init_qpos[1] = 1.5 diff --git a/fancy_gym/envs/mujoco/hopper_jump/hopper_jump_on_box.py b/fancy_gym/envs/mujoco/hopper_jump/hopper_jump_on_box.py index 506344b..c0c57c2 100644 --- a/fancy_gym/envs/mujoco/hopper_jump/hopper_jump_on_box.py +++ b/fancy_gym/envs/mujoco/hopper_jump/hopper_jump_on_box.py @@ -29,7 +29,8 @@ class HopperJumpOnBoxEnv(HopperEnvCustomXML): reset_noise_scale=5e-3, context=True, exclude_current_positions_from_observation=True, - max_episode_steps=250): + max_episode_steps=250, + **kwargs): self.current_step = 0 self.max_height = 0 self.max_episode_steps = max_episode_steps @@ -50,7 +51,8 @@ class HopperJumpOnBoxEnv(HopperEnvCustomXML): xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file) super().__init__(xml_file, forward_reward_weight, ctrl_cost_weight, healthy_reward, terminate_when_unhealthy, healthy_state_range, healthy_z_range, healthy_angle_range, reset_noise_scale, - exclude_current_positions_from_observation) + exclude_current_positions_from_observation, + **kwargs) def step(self, action): diff --git a/fancy_gym/envs/mujoco/hopper_throw/hopper_throw.py b/fancy_gym/envs/mujoco/hopper_throw/hopper_throw.py index b5afc8b..7a39cd8 100644 --- a/fancy_gym/envs/mujoco/hopper_throw/hopper_throw.py +++ b/fancy_gym/envs/mujoco/hopper_throw/hopper_throw.py @@ -32,7 +32,8 @@ class HopperThrowEnv(HopperEnvCustomXML): reset_noise_scale=5e-3, context=True, exclude_current_positions_from_observation=True, - max_episode_steps=250): + max_episode_steps=250, + **kwargs): xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file) self.current_step = 0 self.max_episode_steps = max_episode_steps @@ -57,7 +58,8 @@ class HopperThrowEnv(HopperEnvCustomXML): healthy_z_range=healthy_z_range, healthy_state_range=healthy_angle_range, reset_noise_scale=reset_noise_scale, - exclude_current_positions_from_observation=exclude_current_positions_from_observation) + exclude_current_positions_from_observation=exclude_current_positions_from_observation, + **kwargs) def step(self, action): self.current_step += 1 diff --git a/fancy_gym/envs/mujoco/hopper_throw/hopper_throw_in_basket.py b/fancy_gym/envs/mujoco/hopper_throw/hopper_throw_in_basket.py index 00d1bdb..24ad402 100644 --- a/fancy_gym/envs/mujoco/hopper_throw/hopper_throw_in_basket.py +++ b/fancy_gym/envs/mujoco/hopper_throw/hopper_throw_in_basket.py @@ -36,7 +36,8 @@ class HopperThrowInBasketEnv(HopperEnvCustomXML): context=True, penalty=0.0, exclude_current_positions_from_observation=True, - max_episode_steps=250): + max_episode_steps=250, + **kwargs): self.hit_basket_reward = hit_basket_reward self.current_step = 0 self.max_episode_steps = max_episode_steps @@ -65,7 +66,8 @@ class HopperThrowInBasketEnv(HopperEnvCustomXML): healthy_z_range=healthy_z_range, healthy_angle_range=healthy_angle_range, reset_noise_scale=reset_noise_scale, - exclude_current_positions_from_observation=exclude_current_positions_from_observation) + exclude_current_positions_from_observation=exclude_current_positions_from_observation, + **kwargs) def step(self, action): diff --git a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py index 5395de7..6a08ed1 100644 --- a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py +++ b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py @@ -34,7 +34,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4, goal_switching_step: int = None, - enable_artificial_wind: bool = False): + enable_artificial_wind: bool = False, **kwargs): utils.EzPickle.__init__(**locals()) self._steps = 0 @@ -68,7 +68,8 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): MujocoEnv.__init__(self, model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"), frame_skip=frame_skip, - observation_space=self.observation_space) + observation_space=self.observation_space, + **kwargs) if ctxt_dim == 2: self.context_bounds = CONTEXT_BOUNDS_2DIMS diff --git a/fancy_gym/envs/mujoco/walker_2d_jump/walker_2d_jump.py b/fancy_gym/envs/mujoco/walker_2d_jump/walker_2d_jump.py index 6ad2be0..d9085ee 100644 --- a/fancy_gym/envs/mujoco/walker_2d_jump/walker_2d_jump.py +++ b/fancy_gym/envs/mujoco/walker_2d_jump/walker_2d_jump.py @@ -97,7 +97,8 @@ class Walker2dJumpEnv(Walker2dEnvCustomXML): reset_noise_scale=5e-3, penalty=0, exclude_current_positions_from_observation=True, - max_episode_steps=300): + max_episode_steps=300, + **kwargs): self.current_step = 0 self.max_episode_steps = max_episode_steps self.max_height = 0 @@ -112,7 +113,8 @@ class Walker2dJumpEnv(Walker2dEnvCustomXML): healthy_z_range=healthy_z_range, healthy_angle_range=healthy_angle_range, reset_noise_scale=reset_noise_scale, - exclude_current_positions_from_observation=exclude_current_positions_from_observation) + exclude_current_positions_from_observation=exclude_current_positions_from_observation, + **kwargs) def step(self, action): self.current_step += 1 From d7ea1f80a00ecef1f60a9db884dc0ab4045da63f Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Mon, 23 Oct 2023 13:12:34 +0200 Subject: [PATCH 4/8] gym.Env actually does not want to known about render_mode --- fancy_gym/envs/classic_control/base_reacher/base_reacher.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fancy_gym/envs/classic_control/base_reacher/base_reacher.py b/fancy_gym/envs/classic_control/base_reacher/base_reacher.py index a03115d..2463134 100644 --- a/fancy_gym/envs/classic_control/base_reacher/base_reacher.py +++ b/fancy_gym/envs/classic_control/base_reacher/base_reacher.py @@ -14,8 +14,8 @@ class BaseReacherEnv(gym.Env): Base class for all reaching environments. """ - def __init__(self, n_links: int, random_start: bool = True, allow_self_collision: bool = False, render_mode: str = None, **kwargs): - super().__init__(render_mode=render_mode, **kwargs) + def __init__(self, n_links: int, random_start: bool = True, allow_self_collision: bool = False, render_mode: str = None): + super().__init__() self.link_lengths = np.ones(n_links) self.n_links = n_links self._dt = 0.01 From dc773c6c1072ffbb1f7f5a46b4722b4600fc4f55 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sun, 29 Oct 2023 12:53:06 +0100 Subject: [PATCH 5/8] Fixed TableTennisWind and TableTennisGoalSwitching not correctly passing args to parent class --- fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py index 6a08ed1..216ca1f 100644 --- a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py +++ b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py @@ -276,11 +276,11 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): class TableTennisWind(TableTennisEnv): - def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4): + def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4, **kwargs): self.observation_space = spaces.Box( low=-np.inf, high=np.inf, shape=(22,), dtype=np.float64 ) - super().__init__(ctxt_dim=ctxt_dim, frame_skip=frame_skip, enable_artificial_wind=True) + super().__init__(ctxt_dim=ctxt_dim, frame_skip=frame_skip, enable_artificial_wind=True, **kwargs) def _get_obs(self): obs = np.concatenate([ @@ -298,5 +298,5 @@ class TableTennisWind(TableTennisEnv): class TableTennisGoalSwitching(TableTennisEnv): - def __init__(self, frame_skip: int = 4, goal_switching_step: int = 99): - super().__init__(frame_skip=frame_skip, goal_switching_step=goal_switching_step) + def __init__(self, frame_skip: int = 4, goal_switching_step: int = 99, **kwargs): + super().__init__(frame_skip=frame_skip, goal_switching_step=goal_switching_step, **kwargs) From 01d2bf44ba77fd304b6e8be722fb60521e2626e4 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Tue, 21 Nov 2023 20:17:19 +0100 Subject: [PATCH 6/8] remove 'mode' from render() for all classic_control envs --- fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py | 4 +--- .../envs/classic_control/simple_reacher/simple_reacher.py | 4 +--- .../envs/classic_control/viapoint_reacher/viapoint_reacher.py | 4 +--- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py b/fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py index 4e5caaf..372f0af 100644 --- a/fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py +++ b/fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py @@ -178,9 +178,7 @@ class HoleReacherEnv(BaseReacherDirectEnv): return False - def render(self, mode=None): - if mode==None: - mode = self.render_mode + def render(self): if self.fig is None: # Create base figure once on the beginning. Afterwards only update plt.ion() diff --git a/fancy_gym/envs/classic_control/simple_reacher/simple_reacher.py b/fancy_gym/envs/classic_control/simple_reacher/simple_reacher.py index 40a8153..9264b39 100644 --- a/fancy_gym/envs/classic_control/simple_reacher/simple_reacher.py +++ b/fancy_gym/envs/classic_control/simple_reacher/simple_reacher.py @@ -98,9 +98,7 @@ class SimpleReacherEnv(BaseReacherTorqueEnv): def _check_collisions(self) -> bool: return self._check_self_collision() - def render(self, mode=None): # pragma: no cover - if mode==None: - mode = self.render_mode + def render(self): # pragma: no cover if self.fig is None: # Create base figure once on the beginning. Afterwards only update plt.ion() diff --git a/fancy_gym/envs/classic_control/viapoint_reacher/viapoint_reacher.py b/fancy_gym/envs/classic_control/viapoint_reacher/viapoint_reacher.py index 932f50a..05455e7 100644 --- a/fancy_gym/envs/classic_control/viapoint_reacher/viapoint_reacher.py +++ b/fancy_gym/envs/classic_control/viapoint_reacher/viapoint_reacher.py @@ -123,9 +123,7 @@ class ViaPointReacherEnv(BaseReacherDirectEnv): def _check_collisions(self) -> bool: return self._check_self_collision() - def render(self, mode=None): - if mode==None: - mode = self.render_mode + def render(self): goal_pos = self._goal.T via_pos = self._via_point.T From cee8d5910708cfc191b99d63f55a88266997f604 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Tue, 21 Nov 2023 20:19:47 +0100 Subject: [PATCH 7/8] remove 'mode' from render() for all dummy envs (test suite) --- test/test_black_box.py | 2 +- test/test_fancy_registry.py | 2 +- test/test_replanning_sequencing.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_black_box.py b/test/test_black_box.py index 8cdc543..c1a760a 100644 --- a/test/test_black_box.py +++ b/test/test_black_box.py @@ -41,7 +41,7 @@ class ToyEnv(gym.Env): obs, reward, terminated, truncated, info = np.array([-1]), 1, False, False, {} return obs, reward, terminated, truncated, info - def render(self, mode="human"): + def render(self): pass diff --git a/test/test_fancy_registry.py b/test/test_fancy_registry.py index aad076b..849b36f 100644 --- a/test/test_fancy_registry.py +++ b/test/test_fancy_registry.py @@ -33,7 +33,7 @@ class ToyEnv(gym.Env): obs, reward, terminated, truncated, info = np.array([-1]), 1, False, False, {} return obs, reward, terminated, truncated, info - def render(self, mode="human"): + def render(self): pass diff --git a/test/test_replanning_sequencing.py b/test/test_replanning_sequencing.py index c2edf42..db463c8 100644 --- a/test/test_replanning_sequencing.py +++ b/test/test_replanning_sequencing.py @@ -37,7 +37,7 @@ class ToyEnv(gym.Env): obs, reward, terminated, truncated, info = np.array([-1]), 1, False, False, {} return obs, reward, terminated, truncated, info - def render(self, mode="human"): + def render(self): pass From 052abcbf1ccf7c16ca474513f9e59c3e1a598613 Mon Sep 17 00:00:00 2001 From: Hongyi Zhou Date: Wed, 22 Nov 2023 11:24:23 +0100 Subject: [PATCH 8/8] fix the render_mode in via_point_reacher and hole reacher --- fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py | 4 ++-- .../classic_control/viapoint_reacher/viapoint_reacher.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py b/fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py index 372f0af..70b9bfd 100644 --- a/fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py +++ b/fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py @@ -197,7 +197,7 @@ class HoleReacherEnv(BaseReacherDirectEnv): self.fig.gca().set_title( f"Iteration: {self._steps}, distance: {np.linalg.norm(self.end_effector - self._goal) ** 2}") - if mode == "human": + if self.render_mode == "human": # arm self.line.set_data(self._joints[:, 0], self._joints[:, 1]) @@ -205,7 +205,7 @@ class HoleReacherEnv(BaseReacherDirectEnv): self.fig.canvas.draw() self.fig.canvas.flush_events() - elif mode == "partial": + elif self.render_mode == "partial": if self._steps % 20 == 0 or self._steps in [1, 199] or self._is_collided: # Arm plt.plot(self._joints[:, 0], self._joints[:, 1], 'ro-', markerfacecolor='k', diff --git a/fancy_gym/envs/classic_control/viapoint_reacher/viapoint_reacher.py b/fancy_gym/envs/classic_control/viapoint_reacher/viapoint_reacher.py index 05455e7..4003319 100644 --- a/fancy_gym/envs/classic_control/viapoint_reacher/viapoint_reacher.py +++ b/fancy_gym/envs/classic_control/viapoint_reacher/viapoint_reacher.py @@ -146,7 +146,7 @@ class ViaPointReacherEnv(BaseReacherDirectEnv): self.fig.gca().set_title(f"Iteration: {self._steps}, distance: {self.end_effector - self._goal}") - if mode == "human": + if self.render_mode == "human": # goal if self._steps == 1: self.goal_point_plot.set_data(goal_pos[0], goal_pos[1]) @@ -158,7 +158,7 @@ class ViaPointReacherEnv(BaseReacherDirectEnv): self.fig.canvas.draw() self.fig.canvas.flush_events() - elif mode == "partial": + elif self.render_mode == "partial": if self._steps == 1: # fig, ax = plt.subplots() # Add the patch to the Axes @@ -178,7 +178,7 @@ class ViaPointReacherEnv(BaseReacherDirectEnv): plt.ylim([-1.1, lim]) plt.pause(0.01) - elif mode == "final": + elif self.render_mode == "final": if self._steps == 199 or self._is_collided: # fig, ax = plt.subplots()