From 448ebcde950acfffb9bf87fbaf3c3eb59cf25d8e Mon Sep 17 00:00:00 2001 From: Maximilian Huettenrauch Date: Sat, 10 Apr 2021 13:37:48 +0200 Subject: [PATCH] holereach025 add success flag --- alr_envs/classic_control/hole_reacher.py | 10 +++++++--- dmp_env_wrapper_example.py | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/alr_envs/classic_control/hole_reacher.py b/alr_envs/classic_control/hole_reacher.py index a5a153c..2b39a2c 100644 --- a/alr_envs/classic_control/hole_reacher.py +++ b/alr_envs/classic_control/hole_reacher.py @@ -94,19 +94,23 @@ class HoleReacher(gym.Env): # compute reward directly in step function + success = False reward = 0 if not self._is_collided: if self._steps == 199: - reward = - np.linalg.norm(self.end_effector - self.bottom_center_of_hole) ** 2 + dist = np.linalg.norm(self.end_effector - self.bottom_center_of_hole) + reward = - dist ** 2 + success = dist < 0.005 else: + dist = np.linalg.norm(self.end_effector - self.bottom_center_of_hole) # if self.collision_penalty != 0: # reward = -self.collision_penalty # else: - reward = - np.linalg.norm(self.end_effector - self.bottom_center_of_hole) ** 2 - self.collision_penalty + reward = - dist ** 2 - self.collision_penalty reward -= 5e-8 * np.sum(acc ** 2) - info = {"is_collided": self._is_collided} + info = {"is_collided": self._is_collided, "is_success": success} self._steps += 1 diff --git a/dmp_env_wrapper_example.py b/dmp_env_wrapper_example.py index bb886c6..91309e8 100644 --- a/dmp_env_wrapper_example.py +++ b/dmp_env_wrapper_example.py @@ -28,7 +28,7 @@ if __name__ == "__main__": # params = np.hstack([50 * np.random.randn(n_samples, 25), np.tile(np.array([np.pi/2, -np.pi/4, -np.pi/4, -np.pi/4, -np.pi/4]), [n_samples, 1])]) - rew, info = test_env.rollout(params, render=True) + rew, info = test_env.rollout(params, render=False) print(rew) # out = env(params)