diff --git a/fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py b/fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py index 372f0af..70b9bfd 100644 --- a/fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py +++ b/fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py @@ -197,7 +197,7 @@ class HoleReacherEnv(BaseReacherDirectEnv): self.fig.gca().set_title( f"Iteration: {self._steps}, distance: {np.linalg.norm(self.end_effector - self._goal) ** 2}") - if mode == "human": + if self.render_mode == "human": # arm self.line.set_data(self._joints[:, 0], self._joints[:, 1]) @@ -205,7 +205,7 @@ class HoleReacherEnv(BaseReacherDirectEnv): self.fig.canvas.draw() self.fig.canvas.flush_events() - elif mode == "partial": + elif self.render_mode == "partial": if self._steps % 20 == 0 or self._steps in [1, 199] or self._is_collided: # Arm plt.plot(self._joints[:, 0], self._joints[:, 1], 'ro-', markerfacecolor='k', diff --git a/fancy_gym/envs/classic_control/viapoint_reacher/viapoint_reacher.py b/fancy_gym/envs/classic_control/viapoint_reacher/viapoint_reacher.py index 05455e7..4003319 100644 --- a/fancy_gym/envs/classic_control/viapoint_reacher/viapoint_reacher.py +++ b/fancy_gym/envs/classic_control/viapoint_reacher/viapoint_reacher.py @@ -146,7 +146,7 @@ class ViaPointReacherEnv(BaseReacherDirectEnv): self.fig.gca().set_title(f"Iteration: {self._steps}, distance: {self.end_effector - self._goal}") - if mode == "human": + if self.render_mode == "human": # goal if self._steps == 1: self.goal_point_plot.set_data(goal_pos[0], goal_pos[1]) @@ -158,7 +158,7 @@ class ViaPointReacherEnv(BaseReacherDirectEnv): self.fig.canvas.draw() self.fig.canvas.flush_events() - elif mode == "partial": + elif self.render_mode == "partial": if self._steps == 1: # fig, ax = plt.subplots() # Add the patch to the Axes @@ -178,7 +178,7 @@ class ViaPointReacherEnv(BaseReacherDirectEnv): plt.ylim([-1.1, lim]) plt.pause(0.01) - elif mode == "final": + elif self.render_mode == "final": if self._steps == 199 or self._is_collided: # fig, ax = plt.subplots()