From 2d9e7fb3eb4fa9cf894efe3dc3603b170a643151 Mon Sep 17 00:00:00 2001 From: Maximilian Huettenrauch Date: Fri, 15 Jan 2021 17:16:52 +0100 Subject: [PATCH] fixes in holereacher --- alr_envs/classic_control/hole_reacher.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/alr_envs/classic_control/hole_reacher.py b/alr_envs/classic_control/hole_reacher.py index f19915b..0db772a 100644 --- a/alr_envs/classic_control/hole_reacher.py +++ b/alr_envs/classic_control/hole_reacher.py @@ -38,6 +38,7 @@ class HoleReacher(gym.Env): self.weight_matrix_scale = 50 # for the holereacher, the dmp weights become quite large compared to the values of the goal attractor. this scaling is to ensure they are on similar scale for the optimizer self._dt = 0.01 + self.time_limit = 2 action_bound = np.pi * np.ones((self.num_links,)) state_bound = np.hstack([ @@ -103,9 +104,7 @@ class HoleReacher(gym.Env): reward -= 1e-6 * np.sum(acc**2) if self._steps == 180: - reward -= (0.1 * np.sum(vel**2) ** 2 - + 1e-3 * np.sum(action**2) - ) + reward -= 0.1 * np.sum(vel**2) ** 2 if self._is_collided: reward -= self.collision_penalty @@ -114,7 +113,9 @@ class HoleReacher(gym.Env): self._steps += 1 - return self._get_obs().copy(), reward, self._is_collided, info + done = self._steps * self._dt > self.time_limit or self._is_collided + + return self._get_obs().copy(), reward, done, info def _update_joints(self): """ @@ -301,12 +302,12 @@ if __name__ == '__main__': # test with random actions ac = 2 * env.action_space.sample() # ac[0] += np.pi/2 - obs, rew, done, info = env.step(ac) + obs, rew, d, info = env.step(ac) env.render(mode=render_mode) print(rew) - if done: + if d: break env.close()