diff --git a/alr_envs/classic_control/hole_reacher.py b/alr_envs/classic_control/hole_reacher.py
index f19915b..0db772a 100644
--- a/alr_envs/classic_control/hole_reacher.py
+++ b/alr_envs/classic_control/hole_reacher.py
@@ -38,6 +38,7 @@ class HoleReacher(gym.Env):
         self.weight_matrix_scale = 50  # for the holereacher, the dmp weights become quite large compared to the values of the goal attractor. this scaling is to ensure they are on similar scale for the optimizer
 
         self._dt = 0.01
+        self.time_limit = 2
 
         action_bound = np.pi * np.ones((self.num_links,))
         state_bound = np.hstack([
@@ -103,9 +104,7 @@ class HoleReacher(gym.Env):
         reward -= 1e-6 * np.sum(acc**2)
 
         if self._steps == 180:
-            reward -= (0.1 * np.sum(vel**2) ** 2
-                       + 1e-3 * np.sum(action**2)
-                       )
+            reward -= 0.1 * np.sum(vel**2) ** 2
 
         if self._is_collided:
             reward -= self.collision_penalty
@@ -114,7 +113,9 @@ class HoleReacher(gym.Env):
 
         self._steps += 1
 
-        return self._get_obs().copy(), reward, self._is_collided, info
+        done = self._steps * self._dt > self.time_limit or self._is_collided
+
+        return self._get_obs().copy(), reward, done, info
 
     def _update_joints(self):
         """
@@ -301,12 +302,12 @@ if __name__ == '__main__':
         # test with random actions
         ac = 2 * env.action_space.sample()
         # ac[0] += np.pi/2
-        obs, rew, done, info = env.step(ac)
+        obs, rew, d, info = env.step(ac)
 
         env.render(mode=render_mode)
 
         print(rew)
-        if done:
+        if d:
             break
 
     env.close()
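
Note on the behavioural change (a hedged sketch, not part of the patch): with self._dt = 0.01 and the new self.time_limit = 2, the done flag now ends an episode either on collision or once more than 200 steps have elapsed. The standalone snippet below only illustrates that arithmetic; the values are copied from the diff, while the loop and variable names are hypothetical.

# Minimal standalone sketch of the termination rule added by this patch:
# done = steps * dt > time_limit or collided
dt = 0.01          # corresponds to self._dt in the environment
time_limit = 2     # corresponds to the new self.time_limit
is_collided = False

steps = 0
done = False
while not done:
    steps += 1                                    # mirrors self._steps += 1 in step()
    done = steps * dt > time_limit or is_collided

print(steps)  # 201: the first step where steps * dt (2.01) exceeds time_limit (2)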