diff --git a/alr_envs/classic_control/simple_reacher.py b/alr_envs/classic_control/simple_reacher.py index 945bc3e..265c70b 100644 --- a/alr_envs/classic_control/simple_reacher.py +++ b/alr_envs/classic_control/simple_reacher.py @@ -1,4 +1,5 @@ import os +import time import gym import numpy as np @@ -32,13 +33,15 @@ class SimpleReacherEnv(gym.Env, utils.EzPickle): self._angle_velocity = None self.max_torque = 1 # 10 + self.steps_before_reward = 0 action_bound = np.ones((self.n_links,)) state_bound = np.hstack([ - [np.pi] * self.n_links, - [np.inf] * self.n_links, + [np.pi] * self.n_links, # cos + [np.pi] * self.n_links, # sin + [np.inf] * self.n_links, # velocity [np.inf] * 2, # x-y coordinates of target distance - [np.inf] # TODO: Maybe + [np.inf] # env steps, because reward start after n steps TODO: Maybe ]) self.action_space = spaces.Box(low=-action_bound, high=action_bound, shape=action_bound.shape) self.observation_space = spaces.Box(low=-state_bound, high=state_bound, shape=state_bound.shape) @@ -83,7 +86,14 @@ class SimpleReacherEnv(gym.Env, utils.EzPickle): return np.clip(action, lb, ub) def _get_obs(self): - return np.hstack([self._joint_angle, self._angle_velocity, self.end_effector - self._goal_pos, self._steps]) + theta = self._joint_angle + return np.hstack([ + np.cos(theta), + np.sin(theta), + self._angle_velocity, + self.end_effector - self._goal_pos, + self._steps + ]) def _update_joints(self): """ @@ -100,7 +110,7 @@ class SimpleReacherEnv(gym.Env, utils.EzPickle): distance = 0 # TODO: Is this the best option - if self._steps > 150: + if self._steps >= self.steps_before_reward: distance = np.exp(-0.1 * diff ** 2).mean() # distance -= (diff ** 2).mean() @@ -115,6 +125,7 @@ class SimpleReacherEnv(gym.Env, utils.EzPickle): self._angle_velocity = np.zeros(self.n_links) self._joints = np.zeros((self.n_links + 1, 2)) self._update_joints() + self._steps = 0 self._goal_pos = self._get_random_goal() return self._get_obs().copy() @@ -154,8 +165,9 @@ class SimpleReacherEnv(gym.Env, utils.EzPickle): lim = np.sum(self.link_lengths) + 0.5 plt.xlim([-lim, lim]) plt.ylim([-lim, lim]) - plt.draw() - plt.pause(0.0001) + # plt.draw() + # plt.pause(1e-4) pushed window to foreground, which is annoying. + self.fig.canvas.flush_events() def close(self): del self.fig