From 6233c859044ee26dff051acab0e076ef347a68e1 Mon Sep 17 00:00:00 2001
From: Maximilian Huettenrauch
Date: Mon, 22 Mar 2021 15:28:50 +0100
Subject: [PATCH] update on holereacher reward

---
 alr_envs/classic_control/hole_reacher.py | 16 ++++------------
 dmp_pd_control_example.py                |  2 +-
 2 files changed, 5 insertions(+), 13 deletions(-)

diff --git a/alr_envs/classic_control/hole_reacher.py b/alr_envs/classic_control/hole_reacher.py
index d3faf13..681183b 100644
--- a/alr_envs/classic_control/hole_reacher.py
+++ b/alr_envs/classic_control/hole_reacher.py
@@ -94,23 +94,15 @@ class HoleReacher(gym.Env):
 
         # compute reward directly in step function
 
-        dist_reward = 0
+        reward = 0
         if not self._is_collided:
             if self._steps == 199:
-                dist_reward = np.linalg.norm(self.end_effector - self.bottom_center_of_hole)
+                reward = - np.linalg.norm(self.end_effector - self.bottom_center_of_hole) ** 2
         else:
-            dist_reward = np.linalg.norm(self.end_effector - self.bottom_center_of_hole)
-
-        reward = - dist_reward ** 2
-
-        reward -= 5e-8 * np.sum(acc**2)
-
-        # if self._steps == 180:
-        #     reward -= 0.1 * np.sum(vel**2) ** 2
-
-        if self._is_collided:
             reward = -self.collision_penalty
 
+        reward -= 5e-8 * np.sum(acc ** 2)
+
         info = {"is_collided": self._is_collided}
 
         self._steps += 1
diff --git a/dmp_pd_control_example.py b/dmp_pd_control_example.py
index 303f979..5abf8fa 100644
--- a/dmp_pd_control_example.py
+++ b/dmp_pd_control_example.py
@@ -8,7 +8,7 @@ if __name__ == "__main__":
     dim = 15
     n_cpus = 4
 
-    n_samples = 1
+    n_samples = 10
 
     vec_env = DmpAsyncVectorEnv([make_simple_env(i) for i in range(n_cpus)],
                                 n_samples=n_samples)