holereach025 add success flag

This commit is contained in:
Maximilian Huettenrauch 2021-04-10 13:37:48 +02:00
parent f6cef69225
commit 448ebcde95
2 changed files with 8 additions and 4 deletions

View File

@ -94,19 +94,23 @@ class HoleReacher(gym.Env):
# compute reward directly in step function # compute reward directly in step function
success = False
reward = 0 reward = 0
if not self._is_collided: if not self._is_collided:
if self._steps == 199: if self._steps == 199:
reward = - np.linalg.norm(self.end_effector - self.bottom_center_of_hole) ** 2 dist = np.linalg.norm(self.end_effector - self.bottom_center_of_hole)
reward = - dist ** 2
success = dist < 0.005
else: else:
dist = np.linalg.norm(self.end_effector - self.bottom_center_of_hole)
# if self.collision_penalty != 0: # if self.collision_penalty != 0:
# reward = -self.collision_penalty # reward = -self.collision_penalty
# else: # else:
reward = - np.linalg.norm(self.end_effector - self.bottom_center_of_hole) ** 2 - self.collision_penalty reward = - dist ** 2 - self.collision_penalty
reward -= 5e-8 * np.sum(acc ** 2) reward -= 5e-8 * np.sum(acc ** 2)
info = {"is_collided": self._is_collided} info = {"is_collided": self._is_collided, "is_success": success}
self._steps += 1 self._steps += 1

View File

@ -28,7 +28,7 @@ if __name__ == "__main__":
# params = np.hstack([50 * np.random.randn(n_samples, 25), np.tile(np.array([np.pi/2, -np.pi/4, -np.pi/4, -np.pi/4, -np.pi/4]), [n_samples, 1])]) # params = np.hstack([50 * np.random.randn(n_samples, 25), np.tile(np.array([np.pi/2, -np.pi/4, -np.pi/4, -np.pi/4, -np.pi/4]), [n_samples, 1])])
rew, info = test_env.rollout(params, render=True) rew, info = test_env.rollout(params, render=False)
print(rew) print(rew)
# out = env(params) # out = env(params)