fixes in holereacher

This commit is contained in:
Maximilian Huettenrauch 2021-01-15 17:16:52 +01:00
parent b7400c477d
commit 2d9e7fb3eb

View File

@@ -38,6 +38,7 @@ class HoleReacher(gym.Env):
self.weight_matrix_scale = 50 # for the holereacher, the dmp weights become quite large compared to the values of the goal attractor. this scaling is to ensure they are on similar scale for the optimizer self.weight_matrix_scale = 50 # for the holereacher, the dmp weights become quite large compared to the values of the goal attractor. this scaling is to ensure they are on similar scale for the optimizer
self._dt = 0.01 self._dt = 0.01
self.time_limit = 2
action_bound = np.pi * np.ones((self.num_links,)) action_bound = np.pi * np.ones((self.num_links,))
state_bound = np.hstack([ state_bound = np.hstack([
@@ -103,9 +104,7 @@ class HoleReacher(gym.Env):
reward -= 1e-6 * np.sum(acc**2) reward -= 1e-6 * np.sum(acc**2)
if self._steps == 180: if self._steps == 180:
reward -= (0.1 * np.sum(vel**2) ** 2 reward -= 0.1 * np.sum(vel**2) ** 2
+ 1e-3 * np.sum(action**2)
)
if self._is_collided: if self._is_collided:
reward -= self.collision_penalty reward -= self.collision_penalty
@@ -114,7 +113,9 @@ class HoleReacher(gym.Env):
self._steps += 1 self._steps += 1
return self._get_obs().copy(), reward, self._is_collided, info done = self._steps * self._dt > self.time_limit or self._is_collided
return self._get_obs().copy(), reward, done, info
def _update_joints(self): def _update_joints(self):
""" """
@@ -301,12 +302,12 @@ if __name__ == '__main__':
# test with random actions # test with random actions
ac = 2 * env.action_space.sample() ac = 2 * env.action_space.sample()
# ac[0] += np.pi/2 # ac[0] += np.pi/2
obs, rew, done, info = env.step(ac) obs, rew, d, info = env.step(ac)
env.render(mode=render_mode) env.render(mode=render_mode)
print(rew) print(rew)
if done: if d:
break break
env.close() env.close()