diff --git a/alr_envs/__init__.py b/alr_envs/__init__.py index 14359b6..2439822 100644 --- a/alr_envs/__init__.py +++ b/alr_envs/__init__.py @@ -11,7 +11,7 @@ register( ) register( - id='ALRReacherShort-v0', + id='ALRReacherShortSparse-v0', entry_point='alr_envs.mujoco:ALRReacherEnv', max_episode_steps=50, kwargs={ @@ -20,6 +20,16 @@ register( } ) +register( + id='ALRReacherShort-v0', + entry_point='alr_envs.mujoco:ALRReacherEnv', + max_episode_steps=50, + kwargs={ + "steps_before_reward": 40, + "n_links": 5, + } +) + register( id='ALRReacherSparse-v0', entry_point='alr_envs.mujoco:ALRReacherEnv', diff --git a/alr_envs/mujoco/alr_reacher.py b/alr_envs/mujoco/alr_reacher.py index 4fd5c66..a7e4e18 100644 --- a/alr_envs/mujoco/alr_reacher.py +++ b/alr_envs/mujoco/alr_reacher.py @@ -10,7 +10,11 @@ class ALRReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle): self.steps_before_reward = steps_before_reward self.n_links = n_links - self.reward_weight = 1 if self.steps_before_reward != 200 and self.steps_before_reward != 50 else 200 + self.reward_weight = 1 + if steps_before_reward == 200: + self.reward_weight = 200 + elif steps_before_reward == 50: + self.reward_weight = 50 if n_links == 5: file_name = 'reacher_5links.xml'