From 60e1673ee164771b106e4f4fe2b8d12f998e039d Mon Sep 17 00:00:00 2001 From: Maximilian Huettenrauch Date: Fri, 19 Feb 2021 16:17:55 +0100 Subject: [PATCH] viapoint reacher reward bug fix --- alr_envs/classic_control/utils.py | 5 +++-- alr_envs/classic_control/viapoint_reacher.py | 2 +- dmp_env_wrapper_example.py | 13 +++++++++++-- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/alr_envs/classic_control/utils.py b/alr_envs/classic_control/utils.py index 9da138f..f2ead72 100644 --- a/alr_envs/classic_control/utils.py +++ b/alr_envs/classic_control/utils.py @@ -23,11 +23,12 @@ def make_viapointreacher_env(rank, seed=0): num_dof=5, num_basis=5, duration=2, - alpha_phase=2, + alpha_phase=2.5, dt=_env.dt, start_pos=_env.start_pos, learn_goal=False, - policy_type="velocity") + policy_type="velocity", + weights_scale=50) _env.seed(seed + rank) return _env diff --git a/alr_envs/classic_control/viapoint_reacher.py b/alr_envs/classic_control/viapoint_reacher.py index 1cad10e..f6a3474 100644 --- a/alr_envs/classic_control/viapoint_reacher.py +++ b/alr_envs/classic_control/viapoint_reacher.py @@ -83,7 +83,7 @@ class ViaPointReacher(gym.Env): if not self._is_collided: if self._steps == 100: dist_reward = np.linalg.norm(self.end_effector - self.via_point) - if self._steps == 200: + if self._steps == 199: dist_reward = np.linalg.norm(self.end_effector - self.goal_point) reward = - dist_reward ** 2 diff --git a/dmp_env_wrapper_example.py b/dmp_env_wrapper_example.py index b971574..6fd0a57 100644 --- a/dmp_env_wrapper_example.py +++ b/dmp_env_wrapper_example.py @@ -14,10 +14,19 @@ if __name__ == "__main__": test_env = make_viapointreacher_env(0)() - params = np.random.randn(n_samples, dim) + # params = np.random.randn(n_samples, dim) + params = np.array([ 217.54494933, -1.85169983, 24.08414447, 42.23816868, + 23.32071702, 7.60780651, -31.74777741, 265.50634253, + 463.43822562, 245.93948374, -272.64003621, -45.24999553, + 503.21185823, 809.17742517, 393.12387021, -196.54196471, + 6.79327307, 374.82429078, 552.4119579 , 197.3963343 , + 243.87357056, -39.56041541, -616.93957463, -710.0772516 , + -414.21769789]) + # params = np.hstack([50 * np.random.randn(n_samples, 25), np.tile(np.array([np.pi/2, -np.pi/4, -np.pi/4, -np.pi/4, -np.pi/4]), [n_samples, 1])]) - test_env.rollout(params, render=True) + rew, info = test_env.rollout(params, render=True) + print(rew) # out = env(params) # print(out)