fixed open issues

This commit is contained in:
ottofabian 2021-05-17 17:58:33 +02:00
parent b39104a449
commit 14c60766c2
4 changed files with 56 additions and 73 deletions

View File

@ -214,16 +214,17 @@ register(
) )
# MP environments # MP environments
reacher_envs = ["SimpleReacher-v0", "SimpleReacher-v1", "LongSimpleReacher-v0", "LongSimpleReacher-v1"] ## Simple Reacher
for env in reacher_envs: versions = ["SimpleReacher-v0", "SimpleReacher-v1", "LongSimpleReacher-v0", "LongSimpleReacher-v1"]
name = env.split("-") for v in versions:
name = v.split("-")
register( register(
id=f'{name[0]}DMP-{name[1]}', id=f'{name[0]}DMP-{name[1]}',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env', entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
# max_episode_steps=1, # max_episode_steps=1,
kwargs={ kwargs={
"name": f"alr_envs:{env}", "name": f"alr_envs:{v}",
"num_dof": 2 if "long" not in env.lower() else 5 , "num_dof": 2 if "long" not in v.lower() else 5,
"num_basis": 5, "num_basis": 5,
"duration": 2, "duration": 2,
"alpha_phase": 2, "alpha_phase": 2,
@ -249,12 +250,15 @@ register(
} }
) )
## Hole Reacher
versions = ["v0", "v1", "v2"]
for v in versions:
register( register(
id='HoleReacherDMP-v0', id=f'HoleReacherDMP-{v}',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env', entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
# max_episode_steps=1, # max_episode_steps=1,
kwargs={ kwargs={
"name": "alr_envs:HoleReacher-v0", "name": f"alr_envs:HoleReacher-{v}",
"num_dof": 5, "num_dof": 5,
"num_basis": 5, "num_basis": 5,
"duration": 2, "duration": 2,
@ -267,41 +271,12 @@ register(
} }
) )
register( # register(
id='HoleReacherDMP-v1', # id='HoleReacherDetPMP-v0',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env', # entry_point='alr_envs.classic_control.hole_reacher:holereacher_detpmp',
# max_episode_steps=1, # # max_episode_steps=1,
kwargs={ # # TODO: add mp kwargs
"name": "alr_envs:HoleReacher-v1", # )
"num_dof": 5,
"num_basis": 5,
"duration": 2,
"learn_goal": True,
"alpha_phase": 2,
"bandwidth_factor": 2,
"policy_type": "velocity",
"weights_scale": 50,
"goal_scale": 0.1
}
)
register(
id='HoleReacherDMP-v2',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
# max_episode_steps=1,
kwargs={
"name": "alr_envs:HoleReacher-v2",
"num_dof": 5,
"num_basis": 5,
"duration": 2,
"learn_goal": True,
"alpha_phase": 2,
"bandwidth_factor": 2,
"policy_type": "velocity",
"weights_scale": 50,
"goal_scale": 0.1
}
)
# TODO: properly add final_pos # TODO: properly add final_pos
register( register(
@ -321,12 +296,7 @@ register(
} }
) )
# register( ## Ball in Cup
# id='HoleReacherDetPMP-v0',
# entry_point='alr_envs.classic_control.hole_reacher:holereacher_detpmp',
# # max_episode_steps=1,
# # TODO: add mp kwargs
# )
register( register(
id='ALRBallInACupSimpleDMP-v0', id='ALRBallInACupSimpleDMP-v0',

View File

@ -127,12 +127,6 @@ class SimpleReacherEnv(MPEnv):
def _generate_goal(self): def _generate_goal(self):
if self._target is None: if self._target is None:
# center = self._joints[0]
# # Sample uniformly in circle with radius R around center of reacher.
# R = np.sum(self.link_lengths)
# r = R * np.sqrt(self.np_random.uniform())
# theta = self.np_random.uniform() * 2 * np.pi
# goal = center + r * np.stack([np.cos(theta), np.sin(theta)])
total_length = np.sum(self.link_lengths) total_length = np.sum(self.link_lengths)
goal = np.array([total_length, total_length]) goal = np.array([total_length, total_length])

View File

@ -99,9 +99,28 @@ class ViaPointReacher(MPEnv):
return self._get_obs().copy() return self._get_obs().copy()
def _generate_goal(self): def _generate_goal(self):
self._via_point = self.np_random.uniform(0.5, 3.5, 2) if self._via_target is None else np.copy(self._via_target) # TODO: Maybe improve this later, this can yield quite a lot of invalid settings
self._goal = self.np_random.uniform(0.5, 0.1, 2) if self._target is None else np.copy(self._target)
# raise NotImplementedError("How to properly sample points??") total_length = np.sum(self.link_lengths)
# rejection sampled point in inner circle with 0.5*Radius
if self._via_target is None:
via_target = np.array([total_length, total_length])
while np.linalg.norm(via_target) >= 0.5 * total_length:
via_target = self.np_random.uniform(low=-0.5 * total_length, high=0.5 * total_length, size=2)
else:
via_target = np.copy(self._via_target)
# rejection sampled point in outer circle
if self._target is None:
goal = np.array([total_length, total_length])
while np.linalg.norm(goal) >= total_length or np.linalg.norm(goal) <= 0.5 * total_length:
goal = self.np_random.uniform(low=-total_length, high=total_length, size=2)
else:
goal = np.copy(self._target)
self._via_target = via_target
self._goal = goal
def _update_joints(self): def _update_joints(self):
""" """

View File

@ -70,7 +70,7 @@ class AlrMpEnvSampler:
class AlrContextualMpEnvSampler: class AlrContextualMpEnvSampler:
""" """
An asynchronous sampler for non contextual MPWrapper environments. A sampler object can be called with a set of An asynchronous sampler for contextual MPWrapper environments. A sampler object can be called with a set of
parameters and returns the corresponding final obs, rewards, dones and info dicts. parameters and returns the corresponding final obs, rewards, dones and info dicts.
""" """
def __init__(self, env_id, num_envs, seed=0, **env_kwargs): def __init__(self, env_id, num_envs, seed=0, **env_kwargs):