From 14c60766c24ce5170a9c267253059e5896850346 Mon Sep 17 00:00:00 2001 From: ottofabian Date: Mon, 17 May 2021 17:58:33 +0200 Subject: [PATCH] fixed open issues --- alr_envs/__init__.py | 96 +++++++------------- alr_envs/classic_control/simple_reacher.py | 6 -- alr_envs/classic_control/viapoint_reacher.py | 25 ++++- alr_envs/utils/mp_env_async_sampler.py | 2 +- 4 files changed, 56 insertions(+), 73 deletions(-) diff --git a/alr_envs/__init__.py b/alr_envs/__init__.py index 17d5541..0007982 100644 --- a/alr_envs/__init__.py +++ b/alr_envs/__init__.py @@ -214,16 +214,17 @@ register( ) # MP environments -reacher_envs = ["SimpleReacher-v0", "SimpleReacher-v1", "LongSimpleReacher-v0", "LongSimpleReacher-v1"] -for env in reacher_envs: - name = env.split("-") +## Simple Reacher +versions = ["SimpleReacher-v0", "SimpleReacher-v1", "LongSimpleReacher-v0", "LongSimpleReacher-v1"] +for v in versions: + name = v.split("-") register( id=f'{name[0]}DMP-{name[1]}', entry_point='alr_envs.utils.make_env_helpers:make_dmp_env', # max_episode_steps=1, kwargs={ - "name": f"alr_envs:{env}", - "num_dof": 2 if "long" not in env.lower() else 5 , + "name": f"alr_envs:{v}", + "num_dof": 2 if "long" not in v.lower() else 5, "num_basis": 5, "duration": 2, "alpha_phase": 2, @@ -249,59 +250,33 @@ register( } ) -register( - id='HoleReacherDMP-v0', - entry_point='alr_envs.utils.make_env_helpers:make_dmp_env', - # max_episode_steps=1, - kwargs={ - "name": "alr_envs:HoleReacher-v0", - "num_dof": 5, - "num_basis": 5, - "duration": 2, - "learn_goal": True, - "alpha_phase": 2, - "bandwidth_factor": 2, - "policy_type": "velocity", - "weights_scale": 50, - "goal_scale": 0.1 - } -) +## Hole Reacher +versions = ["v0", "v1", "v2"] +for v in versions: + register( + id=f'HoleReacherDMP-{v}', + entry_point='alr_envs.utils.make_env_helpers:make_dmp_env', + # max_episode_steps=1, + kwargs={ + "name": f"alr_envs:HoleReacher-{v}", + "num_dof": 5, + "num_basis": 5, + "duration": 2, + "learn_goal": True, + "alpha_phase": 2, + "bandwidth_factor": 2, + "policy_type": "velocity", + "weights_scale": 50, + "goal_scale": 0.1 + } + ) -register( - id='HoleReacherDMP-v1', - entry_point='alr_envs.utils.make_env_helpers:make_dmp_env', - # max_episode_steps=1, - kwargs={ - "name": "alr_envs:HoleReacher-v1", - "num_dof": 5, - "num_basis": 5, - "duration": 2, - "learn_goal": True, - "alpha_phase": 2, - "bandwidth_factor": 2, - "policy_type": "velocity", - "weights_scale": 50, - "goal_scale": 0.1 - } -) - -register( - id='HoleReacherDMP-v2', - entry_point='alr_envs.utils.make_env_helpers:make_dmp_env', - # max_episode_steps=1, - kwargs={ - "name": "alr_envs:HoleReacher-v2", - "num_dof": 5, - "num_basis": 5, - "duration": 2, - "learn_goal": True, - "alpha_phase": 2, - "bandwidth_factor": 2, - "policy_type": "velocity", - "weights_scale": 50, - "goal_scale": 0.1 - } -) +# register( +# id='HoleReacherDetPMP-v0', +# entry_point='alr_envs.classic_control.hole_reacher:holereacher_detpmp', +# # max_episode_steps=1, +# # TODO: add mp kwargs +# ) # TODO: properly add final_pos register( @@ -321,12 +296,7 @@ register( } ) -# register( -# id='HoleReacherDetPMP-v0', -# entry_point='alr_envs.classic_control.hole_reacher:holereacher_detpmp', -# # max_episode_steps=1, -# # TODO: add mp kwargs -# ) +## Ball in Cup register( id='ALRBallInACupSimpleDMP-v0', diff --git a/alr_envs/classic_control/simple_reacher.py b/alr_envs/classic_control/simple_reacher.py index 4e99ff1..425134d 100644 --- a/alr_envs/classic_control/simple_reacher.py +++ b/alr_envs/classic_control/simple_reacher.py @@ -127,12 +127,6 @@ class SimpleReacherEnv(MPEnv): def _generate_goal(self): if self._target is None: - # center = self._joints[0] - # # Sample uniformly in circle with radius R around center of reacher. - # R = np.sum(self.link_lengths) - # r = R * np.sqrt(self.np_random.uniform()) - # theta = self.np_random.uniform() * 2 * np.pi - # goal = center + r * np.stack([np.cos(theta), np.sin(theta)]) total_length = np.sum(self.link_lengths) goal = np.array([total_length, total_length]) diff --git a/alr_envs/classic_control/viapoint_reacher.py b/alr_envs/classic_control/viapoint_reacher.py index ac36360..2897f31 100644 --- a/alr_envs/classic_control/viapoint_reacher.py +++ b/alr_envs/classic_control/viapoint_reacher.py @@ -99,9 +99,28 @@ class ViaPointReacher(MPEnv): return self._get_obs().copy() def _generate_goal(self): - self._via_point = self.np_random.uniform(0.5, 3.5, 2) if self._via_target is None else np.copy(self._via_target) - self._goal = self.np_random.uniform(0.5, 0.1, 2) if self._target is None else np.copy(self._target) - # raise NotImplementedError("How to properly sample points??") + # TODO: Maybe improve this later, this can yield quite a lot of invalid settings + + total_length = np.sum(self.link_lengths) + + # rejection sampled point in inner circle with 0.5*Radius + if self._via_target is None: + via_target = np.array([total_length, total_length]) + while np.linalg.norm(via_target) >= 0.5 * total_length: + via_target = self.np_random.uniform(low=-0.5 * total_length, high=0.5 * total_length, size=2) + else: + via_target = np.copy(self._via_target) + + # rejection sampled point in outer circle + if self._target is None: + goal = np.array([total_length, total_length]) + while np.linalg.norm(goal) >= total_length or np.linalg.norm(goal) <= 0.5 * total_length: + goal = self.np_random.uniform(low=-total_length, high=total_length, size=2) + else: + goal = np.copy(self._target) + + self._via_target = via_target + self._goal = goal def _update_joints(self): """ diff --git a/alr_envs/utils/mp_env_async_sampler.py b/alr_envs/utils/mp_env_async_sampler.py index 2fb3645..e935ba6 100644 --- a/alr_envs/utils/mp_env_async_sampler.py +++ b/alr_envs/utils/mp_env_async_sampler.py @@ -70,7 +70,7 @@ class AlrMpEnvSampler: class AlrContextualMpEnvSampler: """ - An asynchronous sampler for non contextual MPWrapper environments. A sampler object can be called with a set of + An asynchronous sampler for contextual MPWrapper environments. A sampler object can be called with a set of parameters and returns the corresponding final obs, rewards, dones and info dicts. """ def __init__(self, env_id, num_envs, seed=0, **env_kwargs):